diff --git a/.bazelrc b/.bazelrc index ad3ca15d1..1246d336b 100644 --- a/.bazelrc +++ b/.bazelrc @@ -1,5 +1,45 @@ -build --cxxopt=-std=c++17 +common --enable_platform_specific_config + +build --enable_bzlmod +build --compilation_mode=fastbuild + +build:linux --cxxopt=-std=c++17 --host_cxxopt=-std=c++17 +build:linux --cxxopt=-fsized-deallocation +build:linux --copt=-Wno-deprecated-declarations + +# you will typically need to spell out the compiler for local dev +# BAZEL_VC= +# BAZEL_VC_FULL_VERSION=14.44.3520 +# Some dependencies rely on bash so you will likely need msys2 +# BAZEL_SH=C:\msys64\usr\bin\bash.exe +build:msvc --cxxopt="-std:c++20" --cxxopt="-utf-8" --host_cxxopt="-std:c++20" +build:msvc --define=protobuf_allow_msvc=true +build:msvc --test_tag_filters=-benchmark,-notap,-no_test_msvc +build:msvc --build_tag_filters=-no_test_msvc + +build:macos --cxxopt=-faligned-allocation +build:macos --cxxopt=-mmacosx-version-min=10.13 +build:macos --linkopt=-mmacosx-version-min=10.13 + +# ANTLR tool requires Java 17+. +build --java_runtime_version=remotejdk_17 + +test --test_output=errors # Enable matchers in googletest build --define absl=1 +build:asan --linkopt -ldl +build:asan --linkopt -fsanitize=address +build:asan --copt -fsanitize=address +build:asan --copt -DADDRESS_SANITIZER=1 +build:asan --copt -D__SANITIZE_ADDRESS__ +build:asan --test_env=ASAN_OPTIONS=handle_abort=1:allow_addr2line=true:check_initialization_order=true:strict_init_order=true:detect_odr_violation=1 +build:asan --test_env=ASAN_SYMBOLIZER_PATH +build:asan --copt -O1 +build:asan --copt -fno-optimize-sibling-calls +build:asan --linkopt=-fuse-ld=lld + +try-import %workspace%/clang.bazelrc +try-import %workspace%/user.bazelrc +try-import %workspace%/local_tsan.bazelrc diff --git a/.bazelversion b/.bazelversion index 4a36342fc..df5119ec6 100644 --- a/.bazelversion +++ b/.bazelversion @@ -1 +1 @@ -3.0.0 +8.7.0 diff --git a/.bcr/README.md b/.bcr/README.md new file mode 100644 index 000000000..5dc023f4e --- /dev/null +++ b/.bcr/README.md @@ -0,0 +1,35 @@ +# BCR Publishing Templates + +This directory contains templates used by the +[Publish to BCR](https://github.com/bazel-contrib/publish-to-bcr) GitHub Action +to automatically publish new versions of cel-cpp to the +[Bazel Central Registry (BCR)](https://github.com/bazelbuild/bazel-central-registry). + +## Files + +- **metadata.template.json**: Contains repository metadata including homepage, + maintainers, and repository location +- **source.template.json**: Template for generating the source.json file that + tells BCR where to download release archives +- **presubmit.yml**: Defines build and test tasks that BCR will run to verify + each published version + +## How it works + +When a new tag matching the pattern `v*.*.*` is created: 1. The GitHub Actions +workflow `.github/workflows/publish_to_bcr.yml` is triggered 2. The workflow +uses these templates to generate a BCR entry 3. A pull request is automatically +created against the Bazel Central Registry 4. Once merged, the new version +becomes available to Bazel users via bzlmod + +## Template Variables + +The following variables are automatically substituted: - `{OWNER}`: Repository +owner (google) - `{REPO}`: Repository name (cel-cpp) - `{VERSION}`: Version +number extracted from the tag (e.g., `0.14.0` from `v0.14.0`) - `{TAG}`: Full +tag name (e.g., `v0.14.0`) + +## More Information + +- [Publish to BCR documentation](https://github.com/bazel-contrib/publish-to-bcr) +- [BCR documentation](https://bazel.build/external/registry) diff --git a/.bcr/metadata.template.json b/.bcr/metadata.template.json new file mode 100644 index 000000000..00106b58f --- /dev/null +++ b/.bcr/metadata.template.json @@ -0,0 +1,34 @@ +{ + "homepage": "https://cel.dev", + "maintainers": [ + { + "email": "ferstl@intrinsic.ai", + "github": "ferstlf", + "github_user_id": 64520639, + "name": "Florian Ferstl" + }, + { + "email": "cel-lang-discuss@googlegroups.com", + "github": "cel-expr", + "github_user_id": 186625994, + "name": "CEL Team" + }, + { + "github": "jnthntatum", + "github_user_id": 733856 + }, + { + "github": "jcking", + "github_user_id": 997958 + }, + { + "github": "tristonianjones", + "github_user_id": 483300 + } + ], + "repository": [ + "github:google/cel-cpp" + ], + "versions": [], + "yanked_versions": {} +} diff --git a/.bcr/presubmit.yml b/.bcr/presubmit.yml new file mode 100644 index 000000000..b711847e0 --- /dev/null +++ b/.bcr/presubmit.yml @@ -0,0 +1,19 @@ +matrix: + platform: + - debian11 + - ubuntu2004 + bazel: + - 8.x + - 7.x +tasks: + verify_targets: + name: Verify build targets + platform: ${{ platform }} + bazel: ${{ bazel }} + build_flags: + - '--cxxopt=-std=c++17' + - '--host_cxxopt=-std=c++17' + - '--copt=-Wno-deprecated-declarations' + - '--define=absl=1' + build_targets: + - '@cel-cpp//...' diff --git a/.bcr/source.template.json b/.bcr/source.template.json new file mode 100644 index 000000000..df5af957c --- /dev/null +++ b/.bcr/source.template.json @@ -0,0 +1,5 @@ +{ + "integrity": "", + "strip_prefix": "cel-cpp-{VERSION}", + "url": "https://github.com/{OWNER}/{REPO}/archive/refs/tags/{TAG}.tar.gz" +} diff --git a/.github/workflows/publish_to_bcr.yml b/.github/workflows/publish_to_bcr.yml new file mode 100644 index 000000000..3ad6e91b8 --- /dev/null +++ b/.github/workflows/publish_to_bcr.yml @@ -0,0 +1,19 @@ +name: Publish to BCR + +on: + push: + tags: + - "v*.*.*" + +permissions: + id-token: write + attestations: write + contents: write + +jobs: + publish: + uses: bazel-contrib/publish-to-bcr/.github/workflows/publish.yaml@v1.0.0 + with: + tag_name: ${{ github.ref_name }} + secrets: + publish_token: ${{ secrets.BCR_PUBLISH_TOKEN }} diff --git a/.github/workflows/windows_bazel_test.yml b/.github/workflows/windows_bazel_test.yml new file mode 100644 index 000000000..6d12e6861 --- /dev/null +++ b/.github/workflows/windows_bazel_test.yml @@ -0,0 +1,28 @@ +name: Windows Bazel Test + +on: + workflow_call: + workflow_dispatch: + +jobs: + test: + name: Run Bazel Tests + runs-on: windows-latest + steps: + - name: Checkout Repository + uses: actions/checkout@v4 + + - name: Setup Bazel and Bazelisk + uses: bazel-contrib/setup-bazel@0.19.0 + with: + bazelisk-cache: true + disk-cache: ${{ github.workflow }} + repository-cache: true + + - name: Run Tests + # msys2 'bash' on Windows will try to 'fix' the label prefix to + # work as a directory. + # //... won't work. + shell: bash + run: | + bazelisk test --config=msvc conformance:all conformance/policy:all \ No newline at end of file diff --git a/.github/workflows/windows_bazel_test_post_merge.yml b/.github/workflows/windows_bazel_test_post_merge.yml new file mode 100644 index 000000000..569177fcc --- /dev/null +++ b/.github/workflows/windows_bazel_test_post_merge.yml @@ -0,0 +1,13 @@ +name: Windows Bazel Test (Post-Merge) + +on: + push: + branches: + - master + +jobs: + trigger-test: + # This prevents the workflow from running automatically when someone + # pushes to their fork. + if: github.repository == 'cel-expr/cel-cpp' + uses: ./.github/workflows/windows_bazel_test.yml \ No newline at end of file diff --git a/.gitignore b/.gitignore index 6d3e1b8bb..8594eee37 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,11 @@ -# bazel produces these as symlinks, not directories bazel-bin -bazel-cel-cpp +bazel-eval bazel-genfiles bazel-out bazel-testlogs +bazel-cel-cpp +*~ +clang.bazelrc +user.bazelrc +local_tsan.bazelrc +MODULE.bazel.lock \ No newline at end of file diff --git a/BUILD.bazel b/BUILD.bazel new file mode 100644 index 000000000..ffd0fb0cd --- /dev/null +++ b/BUILD.bazel @@ -0,0 +1 @@ +package(default_visibility = ["//visibility:public"]) diff --git a/Dockerfile b/Dockerfile index 2561f3a82..97611fc75 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,15 +1,69 @@ -FROM ubuntu:bionic +# This Dockerfile is used to create a container around gcc9 and bazel for +# building the CEL C++ library on GitHub. +# +# To update a new version of this container, use gcloud. You may need to run +# `gcloud auth login` and `gcloud auth configure-docker` first. +# +# Note, if you need to run docker using `sudo` use the following commands +# instead: +# +# sudo gcloud auth login --no-launch-browser +# sudo gcloud auth configure-docker +# +# Run the following command from the root of the CEL repository: +# +# gcloud builds submit --region=us -t gcr.io/cel-analysis/cel-cpp/ubuntu_floor . +# +# Once complete get the sha256 digest from the output using the following +# command: +# +# gcloud artifacts versions list --package=cel-cpp/ubuntu_floor --repository=gcr.io \ +# --location=us +# +# The cloudbuild.yaml file must be updated to use the new digest like so: +# +# - name: 'gcr.io/cel-analysis/cel-cpp/ubuntu_floor@' +FROM gcr.io/cloud-marketplace/google/ubuntu2204:latest -ENV DEBIAN_FRONTEND=noninteractive +# Install Bazel prerequesites and required tools. +# See https://docs.bazel.build/versions/master/install-ubuntu.html +RUN apt-get update && apt-get upgrade -y && \ + DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ + bash \ + ca-certificates \ + git \ + libssl-dev \ + make \ + pkg-config \ + python3 \ + unzip \ + wget \ + zip \ + zlib1g-dev \ + default-jdk-headless \ + clang-11 \ + gcc-9 g++-9 \ + tzdata \ + && apt-get clean -RUN rm -rf /var/lib/apt/lists/* \ - && apt-get update --fix-missing -qq \ - && apt-get install -qqy --no-install-recommends ca-certificates tzdata wget git clang-10 patch \ - && apt-get clean && rm -rf /var/lib/apt/lists/* +# Install Bazelisk. +# https://github.com/bazelbuild/bazelisk/releases +ARG BAZELISK_URL="https://github.com/bazelbuild/bazelisk/releases/download/v1.27.0/bazelisk-amd64.deb" +ARG BAZELISK_CHKSUM="d8b00ea975c823e15263c80200ac42979e17368547fbff4ab177af035badfa83" +ADD ${BAZELISK_URL} /tmp/bazelisk.deb -RUN wget https://github.com/bazelbuild/bazelisk/releases/download/v1.5.0/bazelisk-linux-amd64 && chmod +x bazelisk-linux-amd64 && mv bazelisk-linux-amd64 /bin/bazel +ENV BAZELISK_CHKSUM=${BAZELISK_CHKSUM} +RUN echo "${BAZELISK_CHKSUM} */tmp/bazelisk.deb" | sha256sum --check -ENV CC=clang-10 -ENV CXX=clang++-10 +RUN apt-get install /tmp/bazelisk.deb -ENTRYPOINT ["/bin/bazel"] +RUN mkdir -p /workspace +RUN mkdir -p /bazel + +RUN USE_BAZEL_VERSION=8.7.0 bazelisk help +RUN USE_BAZEL_VERSION=7.3.2 bazelisk help + +ENV CC=gcc-9 +ENV CXX=g++-9 + +ENTRYPOINT ["/usr/bin/bazelisk"] diff --git a/MODULE.bazel b/MODULE.bazel new file mode 100644 index 000000000..187d68164 --- /dev/null +++ b/MODULE.bazel @@ -0,0 +1,112 @@ +module( + name = "cel-cpp", +) + +bazel_dep( + name = "bazel_skylib", + version = "1.9.0", +) +bazel_dep( + name = "googleapis", + version = "0.0.0-20241220-5e258e33.bcr.1", + repo_name = "com_google_googleapis", +) +bazel_dep( + name = "googleapis-cc", + version = "1.0.0", +) +bazel_dep( + name = "rules_cc", + version = "0.2.14", +) +bazel_dep( + name = "rules_java", + version = "8.6.1", +) +bazel_dep( + name = "rules_proto", + version = "7.1.0", +) +bazel_dep( + name = "rules_python", + version = "1.6.3", +) +bazel_dep(name = "rules_license", version = "1.0.0") +bazel_dep( + name = "protobuf", + version = "34.1", + repo_name = "com_google_protobuf", +) +bazel_dep( + name = "abseil-cpp", + version = "20260107.0", + repo_name = "com_google_absl", +) +bazel_dep( + name = "googletest", + version = "1.17.0.bcr.2", + repo_name = "com_google_googletest", +) +bazel_dep( + name = "google_benchmark", + version = "1.9.2", + repo_name = "com_github_google_benchmark", +) +bazel_dep( + name = "re2", + version = "2025-11-05.bcr.1", + repo_name = "com_googlesource_code_re2", +) +bazel_dep( + name = "flatbuffers", + version = "25.9.23", + repo_name = "com_github_google_flatbuffers", +) +bazel_dep( + name = "cel-spec", + version = "0.25.1", + repo_name = "com_google_cel_spec", +) +bazel_dep( + name = "platforms", + version = "1.0.0", +) +bazel_dep( + name = "antlr4-cpp-runtime", + version = "4.13.2.bcr.2", +) + +python = use_extension("@rules_python//python/extensions:python.bzl", "python") +python.toolchain( + configure_coverage_tool = False, + ignore_root_user_error = True, + python_version = "3.11", +) + +http_jar = use_repo_rule("@bazel_tools//tools/build_defs/repo:http.bzl", "http_jar") + +ANTLR4_VERSION = "4.13.2" + +http_jar( + name = "antlr4_jar", + sha256 = "eae2dfa119a64327444672aff63e9ec35a20180dc5b8090b7a6ab85125df4d76", + urls = ["https://www.antlr.org/download/antlr-" + ANTLR4_VERSION + "-complete.jar"], +) + +bazel_dep( + name = "yaml-cpp", + version = "0.9.0", +) + +_CEL_POLICY_TAG = "ebfb2361f47080af643c14cf4da4c2b551a68740" + +_CEL_POLICY_SHA = "ea69e9c6b7bd5bc37d358148aebd2fcca38bc7c45a23feb635de72338e0327c1" + +http_archive = use_repo_rule("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") + +http_archive( + name = "cel_policy", + sha256 = _CEL_POLICY_SHA, + strip_prefix = "cel-policy-%s" % _CEL_POLICY_TAG, + url = "https://github.com/cel-expr/cel-policy/archive/%s.tar.gz" % _CEL_POLICY_TAG, +) diff --git a/README.md b/README.md index b70501dde..7c3c26be0 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,18 @@ # C++ Implementations of the Common Expression Language +> [!WARNING] +> **On June 16, 2026, this repository will move to +> github.com/cel-expr/cel-cpp!** +> +> Please update your links and dependencies. See the [pinned +> issue](https://github.com/google/cel-cpp/issues/2029) for details. + For background on the Common Expression Language see the [cel-spec][1] repo. -This is a C++ implementation of a [Common Expression Language][1] runtime. +This is a C++ implementation of a [Common Expression Language][1] runtime, +parser, and type checker. Released under the [Apache License](LICENSE). -Disclaimer: This is not an official Google product. - -[1]: https://github.com/google/cel-spec +[1]: https://github.com/cel-expr/cel-spec diff --git a/WORKSPACE b/WORKSPACE deleted file mode 100644 index a910efe12..000000000 --- a/WORKSPACE +++ /dev/null @@ -1,143 +0,0 @@ -load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") -load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository") - -CEL_SPEC_GIT_SHA = "02a12e7cffe452a611b0e6ef47872963bbd87028" # 4/17/2020 - -CEL_SPEC_SHA = "757cfdb00dc76fd0d12dadbae982c22a9218711d5e4cf30c94cfe6c05b1cdf2b" - -http_archive( - name = "com_google_cel_spec", - sha256 = CEL_SPEC_SHA, - strip_prefix = "cel-spec-" + CEL_SPEC_GIT_SHA, - urls = ["https://github.com/google/cel-spec/archive/" + CEL_SPEC_GIT_SHA + ".zip"], -) - -http_archive( - name = "com_google_absl", - strip_prefix = "abseil-cpp-master", - urls = ["https://github.com/abseil/abseil-cpp/archive/master.zip"], -) - -# Google RE2 (Regular Expression) C++ Library -http_archive( - name = "com_googlesource_code_re2", - strip_prefix = "re2-master", - urls = ["https://github.com/google/re2/archive/master.zip"], -) - -# gRPC dependencies: -http_archive( - name = "com_github_grpc_grpc", - sha256 = "1236514199d3deb111a6dd7f6092f67617cd2b147f7eda7adbafccea95de7381", - strip_prefix = "grpc-1.31.0", - urls = ["https://github.com/grpc/grpc/archive/v1.31.0.tar.gz"], -) - -load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps") - -grpc_deps() - -GOOGLEAPIS_GIT_SHA = "be480e391cc88a75cf2a81960ef79c80d5012068" # Jul 24, 2019 - -GOOGLEAPIS_SHA = "c1969e5b72eab6d9b6cfcff748e45ba57294aeea1d96fd04cd081995de0605c2" - -http_archive( - name = "com_google_googleapis", - sha256 = GOOGLEAPIS_SHA, - strip_prefix = "googleapis-" + GOOGLEAPIS_GIT_SHA, - urls = ["https://github.com/googleapis/googleapis/archive/" + GOOGLEAPIS_GIT_SHA + ".tar.gz"], -) - -load("@com_google_googleapis//:repository_rules.bzl", "switched_rules_by_language") - -switched_rules_by_language( - name = "com_google_googleapis_imports", - cc = True, - go = True, - grpc = True, -) - -http_archive( - name = "io_bazel_rules_go", - sha256 = "f04d2373bcaf8aa09bccb08a98a57e721306c8f6043a2a0ee610fd6853dcde3d", - urls = ["https://github.com/bazelbuild/rules_go/releases/download/0.18.6/rules_go-0.18.6.tar.gz"], -) - -load("@io_bazel_rules_go//go:deps.bzl", "go_register_toolchains", "go_rules_dependencies") - -# cel-go dependencies: -http_archive( - name = "bazel_gazelle", - sha256 = "3c681998538231a2d24d0c07ed5a7658cb72bfb5fd4bf9911157c0e9ac6a2687", - urls = ["https://github.com/bazelbuild/bazel-gazelle/releases/download/0.17.0/bazel-gazelle-0.17.0.tar.gz"], -) - -load("@bazel_gazelle//:deps.bzl", "gazelle_dependencies", "go_repository") - -git_repository( - name = "com_google_cel_go", - remote = "https://github.com/google/cel-go.git", - tag = "v0.5.1", -) - -go_repository( - name = "org_golang_google_genproto", - build_file_proto_mode = "disable", - commit = "bd91e49a0898e27abb88c339b432fa53d7497ac0", - importpath = "google.golang.org/genproto", -) - -go_repository( - name = "com_github_antlr", - commit = "621b933c7a7f01c67ae9de15103151fa0f9d6d90", - importpath = "github.com/antlr/antlr4", -) - -go_rules_dependencies() - -go_register_toolchains() - -gazelle_dependencies() - -# Parser dependencies -http_archive( - name = "rules_antlr", - sha256 = "7249d1569293d9b239e23c65f6b4c81a07da921738bde0dfeb231ed98be40429", - strip_prefix = "rules_antlr-3cc2f9502a54ceb7b79b37383316b23c4da66f9a", - urls = ["https://github.com/marcohu/rules_antlr/archive/3cc2f9502a54ceb7b79b37383316b23c4da66f9a.tar.gz"], -) - -load("@rules_antlr//antlr:deps.bzl", "antlr_dependencies") - -antlr_dependencies(472) - -http_archive( - name = "antlr4_runtimes", - build_file_content = """ -package(default_visibility = ["//visibility:public"]) -cc_library( - name = "cpp", - srcs = glob(["runtime/Cpp/runtime/src/**/*.cpp"]), - hdrs = glob(["runtime/Cpp/runtime/src/**/*.h"]), - includes = ["runtime/Cpp/runtime/src"], -) -""", - sha256 = "46f5e1af5f4bd28ade55cb632f9a069656b31fc8c2408f9aa045f9b5f5caad64", - strip_prefix = "antlr4-4.7.2", - urls = ["https://github.com/antlr/antlr4/archive/4.7.2.tar.gz"], -) - -# tools/flatbuffers dependencies -FLAT_BUFFERS_SHA = "a83caf5910644ba1c421c002ef68e42f21c15f9f" - -http_archive( - name = "com_github_google_flatbuffers", - sha256 = "b8efbc25721e76780752bad775a97c3f77a0250271e2db37fc747b20e8b0f24a", - strip_prefix = "flatbuffers-" + FLAT_BUFFERS_SHA, - url = "https://github.com/google/flatbuffers/archive/" + FLAT_BUFFERS_SHA + ".tar.gz", -) - -# Needed by gRPC build rules (but not used). Should be after genproto. -load("@com_github_grpc_grpc//bazel:grpc_extra_deps.bzl", "grpc_extra_deps") - -grpc_extra_deps() diff --git a/base/BUILD b/base/BUILD index 61633b38f..a239d4751 100644 --- a/base/BUILD +++ b/base/BUILD @@ -1,12 +1,156 @@ -licenses(["notice"]) # Apache v2.0 +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") package( + # Under active development, not yet being released. default_visibility = ["//visibility:public"], ) +licenses(["notice"]) + +cc_library( + name = "attributes", + srcs = [ + "attribute.cc", + ], + hdrs = [ + "attribute.h", + "attribute_set.h", + ], + deps = [ + ":kind", + "//internal:status_macros", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + "@com_google_absl//absl/types:variant", + ], +) + +cc_library( + name = "kind", + hdrs = ["kind.h"], + deps = [ + "//common:kind", + "//common:type_kind", + "//common:value_kind", + ], +) + +cc_library( + name = "operators", + srcs = ["operators.cc"], + hdrs = ["operators.h"], + deps = [ + "//base/internal:operators", + "@com_google_absl//absl/base", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + ], +) + +cc_test( + name = "operators_test", + srcs = ["operators_test.cc"], + deps = [ + ":operators", + "//base/internal:operators", + "//internal:testing", + "@com_google_absl//absl/hash:hash_testing", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + ], +) + +# Build target encompassing cel::Type, cel::Value, and their related classes. +cc_library( + name = "data", + hdrs = [ + "type_provider.h", + ], + deps = [ + "//common:value", + ], +) + cc_library( - name = "status_macros", + name = "function", hdrs = [ - "status_macros.h", + "function.h", + ], + deps = [ + "//runtime:function", ], ) + +cc_library( + name = "function_descriptor", + hdrs = [ + "function_descriptor.h", + ], + deps = [ + "//common:function_descriptor", + ], +) + +cc_library( + name = "function_result", + hdrs = [ + "function_result.h", + ], + deps = [":function_descriptor"], +) + +cc_library( + name = "function_result_set", + srcs = [ + "function_result_set.cc", + ], + hdrs = [ + "function_result_set.h", + ], + deps = [ + ":function_result", + "@com_google_absl//absl/container:btree", + ], +) + +cc_library( + name = "ast", + hdrs = ["ast.h"], + deps = ["//common:ast"], +) + +cc_library( + name = "function_adapter", + hdrs = ["function_adapter.h"], + deps = [ + "//runtime:function_adapter", + ], +) + +cc_library( + name = "builtins", + hdrs = ["builtins.h"], +) diff --git a/base/README b/base/README deleted file mode 100644 index 26c974c82..000000000 --- a/base/README +++ /dev/null @@ -1,5 +0,0 @@ -This directory contains forked copies of google libraries not already available -in open source. Generally, these libraries should always be considered -'internal' and subject to change without notice. - -The original copy is located in https://github.com/google/zetasql/tree/master/zetasql/base diff --git a/base/ast.h b/base/ast.h new file mode 100644 index 000000000..9f5dfaaa7 --- /dev/null +++ b/base/ast.h @@ -0,0 +1,20 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_AST_H_ +#define THIRD_PARTY_CEL_CPP_BASE_AST_H_ + +#include "common/ast.h" // IWYU pragma: export + +#endif // THIRD_PARTY_CEL_CPP_BASE_AST_H_ diff --git a/base/attribute.cc b/base/attribute.cc new file mode 100644 index 000000000..f750a1850 --- /dev/null +++ b/base/attribute.cc @@ -0,0 +1,330 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/attribute.h" + +#include +#include +#include + +#include "absl/base/macros.h" +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/types/variant.h" +#include "base/kind.h" +#include "internal/status_macros.h" + +namespace cel { + +namespace { + +// Visitor for appending string representation for different qualifier kinds. +class AttributeStringPrinter { + public: + // String representation for the given qualifier is appended to output. + // output must be non-null. + explicit AttributeStringPrinter(std::string* output, Kind type) + : output_(*output), type_(type) {} + + absl::Status operator()(const Kind& ignored) const { + // Attributes are represented as a variant, with illegal attribute + // qualifiers represented with their type as the first alternative. + return absl::InvalidArgumentError( + absl::StrCat("Unsupported attribute qualifier ", KindToString(type_))); + } + + absl::Status operator()(int64_t index) { + absl::StrAppend(&output_, "[", index, "]"); + return absl::OkStatus(); + } + + absl::Status operator()(uint64_t index) { + absl::StrAppend(&output_, "[", index, "]"); + return absl::OkStatus(); + } + + absl::Status operator()(bool bool_key) { + absl::StrAppend(&output_, "[", (bool_key) ? "true" : "false", "]"); + return absl::OkStatus(); + } + + absl::Status operator()(const std::string& field) { + absl::StrAppend(&output_, ".", field); + return absl::OkStatus(); + } + + private: + std::string& output_; + Kind type_; +}; + +// Visitor for appending string representation for different qualifier kinds. +class AttributeQualifierStringPrinter { + public: + // String representation for the given qualifier is appended to output. + explicit AttributeQualifierStringPrinter(std::string* absl_nonnull output, + Kind type) + : output_(*output), type_(type) {} + + absl::Status operator()(const Kind& ignored) const { + // Attributes are represented as a variant, with illegal attribute + // qualifiers represented with their type as the first alternative. + return absl::InvalidArgumentError( + absl::StrCat("Unsupported attribute qualifier ", KindToString(type_))); + } + + absl::Status operator()(int64_t index) { + absl::StrAppend(&output_, index); + return absl::OkStatus(); + } + + absl::Status operator()(uint64_t index) { + absl::StrAppend(&output_, index); + return absl::OkStatus(); + } + + absl::Status operator()(bool bool_key) { + absl::StrAppend(&output_, (bool_key) ? "true" : "false"); + return absl::OkStatus(); + } + + absl::Status operator()(const std::string& field) { + absl::StrAppend(&output_, field); + return absl::OkStatus(); + } + + private: + std::string& output_; + Kind type_; +}; + +struct AttributeQualifierTypeVisitor final { + Kind operator()(const Kind& type) const { return type; } + + Kind operator()(int64_t ignored) const { + static_cast(ignored); + return Kind::kInt64; + } + + Kind operator()(uint64_t ignored) const { + static_cast(ignored); + return Kind::kUint64; + } + + Kind operator()(const std::string& ignored) const { + static_cast(ignored); + return Kind::kString; + } + + Kind operator()(bool ignored) const { + static_cast(ignored); + return Kind::kBool; + } +}; + +struct AttributeQualifierTypeComparator final { + const Kind lhs; + + bool operator()(const Kind& rhs) const { + return static_cast(lhs) < static_cast(rhs); + } + + bool operator()(int64_t) const { return false; } + + bool operator()(uint64_t other) const { return false; } + + bool operator()(const std::string&) const { return false; } + + bool operator()(bool other) const { return false; } +}; + +struct AttributeQualifierIntComparator final { + const int64_t lhs; + + bool operator()(const Kind&) const { return true; } + + bool operator()(int64_t rhs) const { return lhs < rhs; } + + bool operator()(uint64_t) const { return true; } + + bool operator()(const std::string&) const { return true; } + + bool operator()(bool) const { return false; } +}; + +struct AttributeQualifierUintComparator final { + const uint64_t lhs; + + bool operator()(const Kind&) const { return true; } + + bool operator()(int64_t) const { return false; } + + bool operator()(uint64_t rhs) const { return lhs < rhs; } + + bool operator()(const std::string&) const { return true; } + + bool operator()(bool) const { return false; } +}; + +struct AttributeQualifierStringComparator final { + const std::string& lhs; + + bool operator()(const Kind&) const { return true; } + + bool operator()(int64_t) const { return false; } + + bool operator()(uint64_t) const { return false; } + + bool operator()(const std::string& rhs) const { return lhs < rhs; } + + bool operator()(bool) const { return false; } +}; + +struct AttributeQualifierBoolComparator final { + const bool lhs; + + bool operator()(const Kind&) const { return true; } + + bool operator()(int64_t) const { return true; } + + bool operator()(uint64_t) const { return true; } + + bool operator()(const std::string&) const { return true; } + + bool operator()(bool rhs) const { return lhs < rhs; } +}; + +} // namespace + +struct AttributeQualifier::ComparatorVisitor final { + const AttributeQualifier::Variant& rhs; + + bool operator()(const Kind& lhs) const { + return absl::visit(AttributeQualifierTypeComparator{lhs}, rhs); + } + + bool operator()(int64_t lhs) const { + return absl::visit(AttributeQualifierIntComparator{lhs}, rhs); + } + + bool operator()(uint64_t lhs) const { + return absl::visit(AttributeQualifierUintComparator{lhs}, rhs); + } + + bool operator()(const std::string& lhs) const { + return absl::visit(AttributeQualifierStringComparator{lhs}, rhs); + } + + bool operator()(bool lhs) const { + return absl::visit(AttributeQualifierBoolComparator{lhs}, rhs); + } +}; + +Kind AttributeQualifier::kind() const { + return absl::visit(AttributeQualifierTypeVisitor{}, value_); +} + +bool AttributeQualifier::operator<(const AttributeQualifier& other) const { + // The order is not publicly documented because it is subject to change. + // Currently we sort in the following order, with each type being sorted + // against itself: bool, int, uint, string, type. + return absl::visit(ComparatorVisitor{other.value_}, value_); +} + +bool Attribute::operator==(const Attribute& other) const { + // We cannot check pointer equality as a short circuit because we have to + // treat all invalid AttributeQualifier as not equal to each other. + // TODO(issues/41) we only support Ident-rooted attributes at the moment. + if (variable_name() != other.variable_name()) { + return false; + } + + if (qualifier_path().size() != other.qualifier_path().size()) { + return false; + } + + for (size_t i = 0; i < qualifier_path().size(); i++) { + if (!(qualifier_path()[i] == other.qualifier_path()[i])) { + return false; + } + } + + return true; +} + +bool Attribute::operator<(const Attribute& other) const { + if (impl_.get() == other.impl_.get()) { + return false; + } + auto lhs_begin = qualifier_path().begin(); + auto lhs_end = qualifier_path().end(); + auto rhs_begin = other.qualifier_path().begin(); + auto rhs_end = other.qualifier_path().end(); + while (lhs_begin != lhs_end && rhs_begin != rhs_end) { + if (*lhs_begin < *rhs_begin) { + return true; + } + if (!(*lhs_begin == *rhs_begin)) { + return false; + } + lhs_begin++; + rhs_begin++; + } + if (lhs_begin == lhs_end && rhs_begin == rhs_end) { + // Neither has any elements left, they are equal. Compare variable names. + return variable_name() < other.variable_name(); + } + if (lhs_begin == lhs_end) { + // Left has no more elements. Right is greater. + return true; + } + // Right has no more elements. Left is greater. + ABSL_ASSERT(rhs_begin == rhs_end); + return false; +} + +const absl::StatusOr Attribute::AsString() const { + if (variable_name().empty()) { + return absl::InvalidArgumentError( + "Only ident rooted attributes are supported."); + } + + std::string result = std::string(variable_name()); + + for (const auto& qualifier : qualifier_path()) { + CEL_RETURN_IF_ERROR(absl::visit( + AttributeStringPrinter(&result, qualifier.kind()), qualifier.value_)); + } + + return result; +} + +bool AttributeQualifier::IsMatch(const AttributeQualifier& other) const { + if (absl::holds_alternative(value_) || + absl::holds_alternative(other.value_)) { + return false; + } + return value_ == other.value_; +} + +absl::StatusOr AttributeQualifier::AsString() const { + std::string result; + CEL_RETURN_IF_ERROR( + absl::visit(AttributeQualifierStringPrinter(&result, kind()), value_)); + return result; +} + +} // namespace cel diff --git a/base/attribute.h b/base/attribute.h new file mode 100644 index 000000000..69dcaf161 --- /dev/null +++ b/base/attribute.h @@ -0,0 +1,278 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_ATTRIBUTE_H_ +#define THIRD_PARTY_CEL_CPP_BASE_ATTRIBUTE_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "absl/types/variant.h" +#include "base/kind.h" + +namespace cel { + +// AttributeQualifier represents a segment in +// attribute resolutuion path. A segment can be qualified by values of +// following types: string/int64_t/uint64_t/bool. +class AttributeQualifier final { + private: + struct ComparatorVisitor; + + using Variant = absl::variant; + + public: + static AttributeQualifier OfInt(int64_t value) { + return AttributeQualifier(absl::in_place_type, std::move(value)); + } + + static AttributeQualifier OfUint(uint64_t value) { + return AttributeQualifier(absl::in_place_type, std::move(value)); + } + + static AttributeQualifier OfString(std::string value) { + return AttributeQualifier(absl::in_place_type, + std::move(value)); + } + + static AttributeQualifier OfBool(bool value) { + return AttributeQualifier(absl::in_place_type, std::move(value)); + } + + AttributeQualifier() = default; + + AttributeQualifier(const AttributeQualifier&) = default; + AttributeQualifier(AttributeQualifier&&) = default; + + AttributeQualifier& operator=(const AttributeQualifier&) = default; + AttributeQualifier& operator=(AttributeQualifier&&) = default; + + Kind kind() const; + + // Family of Get... methods. Return values if requested type matches the + // stored one. + absl::optional GetInt64Key() const { + return absl::holds_alternative(value_) + ? absl::optional(absl::get<1>(value_)) + : absl::nullopt; + } + + absl::optional GetUint64Key() const { + return absl::holds_alternative(value_) + ? absl::optional(absl::get<2>(value_)) + : absl::nullopt; + } + + absl::optional GetStringKey() const { + return absl::holds_alternative(value_) + ? absl::optional(absl::get<3>(value_)) + : absl::nullopt; + } + + absl::optional GetBoolKey() const { + return absl::holds_alternative(value_) + ? absl::optional(absl::get<4>(value_)) + : absl::nullopt; + } + + bool operator==(const AttributeQualifier& other) const { + return IsMatch(other); + } + + bool operator<(const AttributeQualifier& other) const; + + bool IsMatch(absl::string_view other_key) const { + absl::optional key = GetStringKey(); + return (key.has_value() && key.value() == other_key); + } + + absl::StatusOr AsString() const; + + private: + friend class Attribute; + friend struct ComparatorVisitor; + + template + AttributeQualifier(absl::in_place_type_t in_place_type, T&& value) + : value_(in_place_type, std::forward(value)) {} + + bool IsMatch(const AttributeQualifier& other) const; + + // The previous implementation of Attribute preserved all value + // instances, regardless of whether they are supported in this context or not. + // We represented unsupported types by using the first alternative and thus + // preserve backwards compatibility with the result of `type()` above. + Variant value_; +}; + +// AttributeQualifierPattern matches a segment in +// attribute resolutuion path. AttributeQualifierPattern is capable of +// matching path elements of types string/int64/uint64/bool. +class AttributeQualifierPattern final { + private: + // Qualifier value. If not set, treated as wildcard. + std::optional value_; + + explicit AttributeQualifierPattern(std::optional value) + : value_(std::move(value)) {} + + public: + static AttributeQualifierPattern OfInt(int64_t value) { + return AttributeQualifierPattern(AttributeQualifier::OfInt(value)); + } + + static AttributeQualifierPattern OfUint(uint64_t value) { + return AttributeQualifierPattern(AttributeQualifier::OfUint(value)); + } + + static AttributeQualifierPattern OfString(std::string value) { + return AttributeQualifierPattern( + AttributeQualifier::OfString(std::move(value))); + } + + static AttributeQualifierPattern OfBool(bool value) { + return AttributeQualifierPattern(AttributeQualifier::OfBool(value)); + } + + static AttributeQualifierPattern CreateWildcard() { + return AttributeQualifierPattern(std::nullopt); + } + + explicit AttributeQualifierPattern(AttributeQualifier qualifier) + : AttributeQualifierPattern( + std::optional(std::move(qualifier))) {} + + bool IsWildcard() const { return !value_.has_value(); } + + bool IsMatch(const AttributeQualifier& qualifier) const { + if (IsWildcard()) return true; + return value_.value() == qualifier; + } + + bool IsMatch(absl::string_view other_key) const { + if (!value_.has_value()) return true; + return value_->IsMatch(other_key); + } +}; + +// Attribute represents resolved attribute path. +class Attribute final { + public: + explicit Attribute(std::string variable_name) + : Attribute(std::move(variable_name), {}) {} + + Attribute(std::string variable_name, + std::vector qualifier_path) + : impl_(std::make_shared(std::move(variable_name), + std::move(qualifier_path))) {} + + absl::string_view variable_name() const { return impl_->variable_name; } + + bool has_variable_name() const { return !impl_->variable_name.empty(); } + + absl::Span qualifier_path() const { + return impl_->qualifier_path; + } + + bool operator==(const Attribute& other) const; + + bool operator<(const Attribute& other) const; + + const absl::StatusOr AsString() const; + + private: + struct Impl final { + Impl(std::string variable_name, + std::vector qualifier_path) + : variable_name(std::move(variable_name)), + qualifier_path(std::move(qualifier_path)) {} + + std::string variable_name; + std::vector qualifier_path; + }; + + std::shared_ptr impl_; +}; + +// AttributePattern is a fully-qualified absolute attribute path pattern. +// Supported segments steps in the path are: +// - field selection; +// - map lookup by key; +// - list access by index. +class AttributePattern final { + public: + // MatchType enum specifies how closely pattern is matching the attribute: + enum class MatchType { + NONE, // Pattern does not match attribute itself nor its children + PARTIAL, // Pattern matches an entity nested within attribute; + FULL // Pattern matches an attribute itself. + }; + + AttributePattern(std::string variable, + std::vector qualifier_path) + : variable_(std::move(variable)), + qualifier_path_(std::move(qualifier_path)) {} + + absl::string_view variable() const { return variable_; } + + absl::Span qualifier_path() const { + return qualifier_path_; + } + + // Matches the pattern to an attribute. + // Distinguishes between no-match, partial match and full match cases. + MatchType IsMatch(const Attribute& attribute) const { + MatchType result = MatchType::NONE; + if (attribute.variable_name() != variable_) { + return result; + } + + auto max_index = qualifier_path().size(); + result = MatchType::FULL; + if (qualifier_path().size() > attribute.qualifier_path().size()) { + max_index = attribute.qualifier_path().size(); + result = MatchType::PARTIAL; + } + + for (size_t i = 0; i < max_index; i++) { + if (!(qualifier_path()[i].IsMatch(attribute.qualifier_path()[i]))) { + return MatchType::NONE; + } + } + return result; + } + + private: + std::string variable_; + std::vector qualifier_path_; +}; + +struct FieldSpecifier { + int64_t number; + std::string name; +}; + +using SelectQualifier = absl::variant; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_ATTRIBUTE_H_ diff --git a/base/attribute_set.h b/base/attribute_set.h new file mode 100644 index 000000000..078f37881 --- /dev/null +++ b/base/attribute_set.h @@ -0,0 +1,108 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_ATTRIBUTE_SET_H_ +#define THIRD_PARTY_CEL_CPP_BASE_ATTRIBUTE_SET_H_ + +#include "absl/container/btree_set.h" +#include "absl/types/span.h" +#include "base/attribute.h" + +namespace google::api::expr::runtime { +class AttributeUtility; +} // namespace google::api::expr::runtime + +namespace cel { + +class UnknownValue; +namespace base_internal { +class UnknownSet; +} + +// AttributeSet is a container for CEL attributes that are identified as +// unknown during expression evaluation. +class AttributeSet final { + private: + using Container = absl::btree_set; + + public: + using value_type = typename Container::value_type; + using size_type = typename Container::size_type; + using iterator = typename Container::const_iterator; + using const_iterator = typename Container::const_iterator; + + AttributeSet() = default; + AttributeSet(const AttributeSet&) = default; + AttributeSet(AttributeSet&&) = default; + AttributeSet& operator=(const AttributeSet&) = default; + AttributeSet& operator=(AttributeSet&&) = default; + + explicit AttributeSet(absl::Span attributes) { + for (const auto& attr : attributes) { + Add(attr); + } + } + + AttributeSet(const AttributeSet& set1, const AttributeSet& set2) + : attributes_(set1.attributes_) { + for (const auto& attr : set2.attributes_) { + Add(attr); + } + } + + iterator begin() const { return attributes_.begin(); } + + const_iterator cbegin() const { return attributes_.cbegin(); } + + iterator end() const { return attributes_.end(); } + + const_iterator cend() const { return attributes_.cend(); } + + size_type size() const { return attributes_.size(); } + + bool empty() const { return attributes_.empty(); } + + bool operator==(const AttributeSet& other) const { + return this == &other || attributes_ == other.attributes_; + } + + bool operator!=(const AttributeSet& other) const { + return !operator==(other); + } + + static AttributeSet Merge(const AttributeSet& set1, + const AttributeSet& set2) { + return AttributeSet(set1, set2); + } + + private: + friend class google::api::expr::runtime::AttributeUtility; + friend class UnknownValue; + friend class base_internal::UnknownSet; + + void Add(const Attribute& attribute) { attributes_.insert(attribute); } + + void Add(const AttributeSet& other) { + for (const auto& attribute : other) { + Add(attribute); + } + } + + // Attribute container. + Container attributes_; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_ATTRIBUTE_SET_H_ diff --git a/base/builtins.h b/base/builtins.h new file mode 100644 index 000000000..871c2e608 --- /dev/null +++ b/base/builtins.h @@ -0,0 +1,106 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_BUILTINS_H_ +#define THIRD_PARTY_CEL_CPP_BASE_BUILTINS_H_ + +namespace cel { + +// Constants specifying names for CEL builtins. +// +// Prefer to use the constants in `common/standard_definitions.h`. +namespace builtin { + +// Comparison +constexpr char kEqual[] = "_==_"; +constexpr char kInequal[] = "_!=_"; +constexpr char kLess[] = "_<_"; +constexpr char kLessOrEqual[] = "_<=_"; +constexpr char kGreater[] = "_>_"; +constexpr char kGreaterOrEqual[] = "_>=_"; + +// Logical +constexpr char kAnd[] = "_&&_"; +constexpr char kOr[] = "_||_"; +constexpr char kNot[] = "!_"; + +// Strictness +constexpr char kNotStrictlyFalse[] = "@not_strictly_false"; +// Deprecated '__not_strictly_false__' function. Preserved for backwards +// compatibility with stored expressions. +constexpr char kNotStrictlyFalseDeprecated[] = "__not_strictly_false__"; + +// Arithmetical +constexpr char kAdd[] = "_+_"; +constexpr char kSubtract[] = "_-_"; +constexpr char kNeg[] = "-_"; +constexpr char kMultiply[] = "_*_"; +constexpr char kDivide[] = "_/_"; +constexpr char kModulo[] = "_%_"; + +// String operations +constexpr char kRegexMatch[] = "matches"; +constexpr char kStringContains[] = "contains"; +constexpr char kStringEndsWith[] = "endsWith"; +constexpr char kStringStartsWith[] = "startsWith"; + +// Container operations +constexpr char kIn[] = "@in"; +// Deprecated '_in_' operator. Preserved for backwards compatibility with stored +// expressions. +constexpr char kInDeprecated[] = "_in_"; +// Deprecated 'in()' function. Preserved for backwards compatibility with stored +// expressions. +constexpr char kInFunction[] = "in"; +constexpr char kIndex[] = "_[_]"; +constexpr char kSize[] = "size"; + +constexpr char kTernary[] = "_?_:_"; + +// Timestamp and Duration +constexpr char kDuration[] = "duration"; +constexpr char kTimestamp[] = "timestamp"; +constexpr char kFullYear[] = "getFullYear"; +constexpr char kMonth[] = "getMonth"; +constexpr char kDayOfYear[] = "getDayOfYear"; +constexpr char kDayOfMonth[] = "getDayOfMonth"; +constexpr char kDate[] = "getDate"; +constexpr char kDayOfWeek[] = "getDayOfWeek"; +constexpr char kHours[] = "getHours"; +constexpr char kMinutes[] = "getMinutes"; +constexpr char kSeconds[] = "getSeconds"; +constexpr char kMilliseconds[] = "getMilliseconds"; + +// Type conversions +constexpr char kBool[] = "bool"; +constexpr char kBytes[] = "bytes"; +constexpr char kDouble[] = "double"; +constexpr char kDyn[] = "dyn"; +constexpr char kInt[] = "int"; +constexpr char kString[] = "string"; +constexpr char kType[] = "type"; +constexpr char kUint[] = "uint"; + +// Runtime-only functions. +// The convention for runtime-only functions where only the runtime needs to +// differentiate behavior is to prefix the function with `#`. +// Note, this is a different convention from CEL internal functions where the +// whole stack needs to be aware of the function id. +constexpr char kRuntimeListAppend[] = "#list_append"; + +} // namespace builtin + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_BUILTINS_H_ diff --git a/base/function.h b/base/function.h new file mode 100644 index 000000000..c209feb25 --- /dev/null +++ b/base/function.h @@ -0,0 +1,20 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_FUNCTION_H_ +#define THIRD_PARTY_CEL_CPP_BASE_FUNCTION_H_ + +#include "runtime/function.h" // IWYU pragma: export + +#endif // THIRD_PARTY_CEL_CPP_BASE_FUNCTION_H_ diff --git a/base/function_adapter.h b/base/function_adapter.h new file mode 100644 index 000000000..d4c4f38e2 --- /dev/null +++ b/base/function_adapter.h @@ -0,0 +1,19 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#ifndef THIRD_PARTY_CEL_CPP_BASE_FUNCTION_ADAPTER_H_ +#define THIRD_PARTY_CEL_CPP_BASE_FUNCTION_ADAPTER_H_ + +#include "runtime/function_adapter.h" // IWYU pragma: export + +#endif // THIRD_PARTY_CEL_CPP_BASE_FUNCTION_ADAPTER_H_ diff --git a/base/function_descriptor.h b/base/function_descriptor.h new file mode 100644 index 000000000..3b2a88672 --- /dev/null +++ b/base/function_descriptor.h @@ -0,0 +1,20 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_FUNCTION_DESCRIPTOR_H_ +#define THIRD_PARTY_CEL_CPP_BASE_FUNCTION_DESCRIPTOR_H_ + +#include "common/function_descriptor.h" // IWYU pragma: export + +#endif // THIRD_PARTY_CEL_CPP_BASE_FUNCTION_DESCRIPTOR_H_ diff --git a/base/function_result.h b/base/function_result.h new file mode 100644 index 000000000..977ceeb90 --- /dev/null +++ b/base/function_result.h @@ -0,0 +1,70 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_FUNCTION_RESULT_H_ +#define THIRD_PARTY_CEL_CPP_BASE_FUNCTION_RESULT_H_ + +#include +#include + +#include "base/function_descriptor.h" + +namespace cel { + +// Represents a function result that is unknown at the time of execution. This +// allows for lazy evaluation of expensive functions. +class FunctionResult final { + public: + FunctionResult() = delete; + FunctionResult(const FunctionResult&) = default; + FunctionResult(FunctionResult&&) = default; + FunctionResult& operator=(const FunctionResult&) = default; + FunctionResult& operator=(FunctionResult&&) = default; + + FunctionResult(FunctionDescriptor descriptor, int64_t expr_id) + : descriptor_(std::move(descriptor)), expr_id_(expr_id) {} + + // The descriptor of the called function that return Unknown. + const FunctionDescriptor& descriptor() const { return descriptor_; } + + // The id of the |Expr| that triggered the function call step. Provided + // informationally -- if two different |Expr|s generate the same unknown call, + // they will be treated as the same unknown function result. + int64_t call_expr_id() const { return expr_id_; } + + // Equality operator provided for testing. Compatible with set less-than + // comparator. + // Compares descriptor then arguments elementwise. + bool IsEqualTo(const FunctionResult& other) const { + return descriptor() == other.descriptor(); + } + + // TODO(uncreated-issue/5): re-implement argument capture + + private: + FunctionDescriptor descriptor_; + int64_t expr_id_; +}; + +inline bool operator==(const FunctionResult& lhs, const FunctionResult& rhs) { + return lhs.IsEqualTo(rhs); +} + +inline bool operator<(const FunctionResult& lhs, const FunctionResult& rhs) { + return lhs.descriptor() < rhs.descriptor(); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_FUNCTION_RESULT_H_ diff --git a/base/function_result_set.cc b/base/function_result_set.cc new file mode 100644 index 000000000..a03a0c5db --- /dev/null +++ b/base/function_result_set.cc @@ -0,0 +1,28 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/function_result_set.h" + +namespace cel { + +// Implementation for merge constructor. +FunctionResultSet::FunctionResultSet(const FunctionResultSet& lhs, + const FunctionResultSet& rhs) + : function_results_(lhs.function_results_) { + for (const auto& function_result : rhs) { + function_results_.insert(function_result); + } +} + +} // namespace cel diff --git a/base/function_result_set.h b/base/function_result_set.h new file mode 100644 index 000000000..ac81f14d2 --- /dev/null +++ b/base/function_result_set.h @@ -0,0 +1,105 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_FUNCTION_RESULT_SET_H_ +#define THIRD_PARTY_CEL_CPP_BASE_FUNCTION_RESULT_SET_H_ + +#include +#include + +#include "absl/container/btree_set.h" +#include "base/function_result.h" + +namespace google::api::expr::runtime { +class AttributeUtility; +} // namespace google::api::expr::runtime + +namespace cel { + +class UnknownValue; +namespace base_internal { +class UnknownSet; +} + +// Represents a collection of unknown function results at a particular point in +// execution. Execution should advance further if this set of unknowns are +// provided. It may not advance if only a subset are provided. +// Set semantics use |IsEqualTo()| defined on |FunctionResult|. +class FunctionResultSet final { + private: + using Container = absl::btree_set; + + public: + using value_type = typename Container::value_type; + using size_type = typename Container::size_type; + using iterator = typename Container::const_iterator; + using const_iterator = typename Container::const_iterator; + + FunctionResultSet() = default; + FunctionResultSet(const FunctionResultSet&) = default; + FunctionResultSet(FunctionResultSet&&) = default; + FunctionResultSet& operator=(const FunctionResultSet&) = default; + FunctionResultSet& operator=(FunctionResultSet&&) = default; + + // Merge constructor -- effectively union(lhs, rhs). + FunctionResultSet(const FunctionResultSet& lhs, const FunctionResultSet& rhs); + + // Initialize with a single FunctionResult. + explicit FunctionResultSet(FunctionResult initial) + : function_results_{std::move(initial)} {} + + FunctionResultSet(std::initializer_list il) + : function_results_(il) {} + + iterator begin() const { return function_results_.begin(); } + + const_iterator cbegin() const { return function_results_.cbegin(); } + + iterator end() const { return function_results_.end(); } + + const_iterator cend() const { return function_results_.cend(); } + + size_type size() const { return function_results_.size(); } + + bool empty() const { return function_results_.empty(); } + + bool operator==(const FunctionResultSet& other) const { + return this == &other || function_results_ == other.function_results_; + } + + bool operator!=(const FunctionResultSet& other) const { + return !operator==(other); + } + + private: + friend class google::api::expr::runtime::AttributeUtility; + friend class UnknownValue; + friend class base_internal::UnknownSet; + + void Add(const FunctionResult& function_result) { + function_results_.insert(function_result); + } + + void Add(const FunctionResultSet& other) { + for (const auto& function_result : other) { + Add(function_result); + } + } + + Container function_results_; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_FUNCTION_RESULT_SET_H_ diff --git a/base/internal/BUILD b/base/internal/BUILD new file mode 100644 index 000000000..187b008c0 --- /dev/null +++ b/base/internal/BUILD @@ -0,0 +1,54 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:cc_library.bzl", "cc_library") + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) + +cc_library( + name = "memory_manager_testing", + testonly = True, + srcs = ["memory_manager_testing.cc"], + hdrs = ["memory_manager_testing.h"], + deps = [ + "//internal:testing", + ], +) + +cc_library( + name = "message_wrapper", + hdrs = ["message_wrapper.h"], +) + +cc_library( + name = "operators", + hdrs = ["operators.h"], + deps = [ + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "unknown_set", + srcs = ["unknown_set.cc"], + hdrs = ["unknown_set.h"], + deps = [ + "//base:attributes", + "//base:function_result_set", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:no_destructor", + ], +) diff --git a/base/internal/memory_manager_testing.cc b/base/internal/memory_manager_testing.cc new file mode 100644 index 000000000..5b403e3c1 --- /dev/null +++ b/base/internal/memory_manager_testing.cc @@ -0,0 +1,30 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/internal/memory_manager_testing.h" + +#include + +namespace cel::base_internal { + +std::string MemoryManagerTestModeToString(MemoryManagerTestMode mode) { + switch (mode) { + case MemoryManagerTestMode::kGlobal: + return "Global"; + case MemoryManagerTestMode::kArena: + return "Arena"; + } +} + +} // namespace cel::base_internal diff --git a/base/internal/memory_manager_testing.h b/base/internal/memory_manager_testing.h new file mode 100644 index 000000000..946660fec --- /dev/null +++ b/base/internal/memory_manager_testing.h @@ -0,0 +1,54 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_INTERNAL_MEMORY_MANAGER_TESTING_H_ +#define THIRD_PARTY_CEL_CPP_BASE_INTERNAL_MEMORY_MANAGER_TESTING_H_ + +#include +#include + +#include "internal/testing.h" + +namespace cel::base_internal { + +enum class MemoryManagerTestMode { + kGlobal = 0, + kArena, +}; + +std::string MemoryManagerTestModeToString(MemoryManagerTestMode mode); + +template +void AbslStringify(S& sink, MemoryManagerTestMode mode) { + sink.Append(MemoryManagerTestModeToString(mode)); +} + +inline auto MemoryManagerTestModeAll() { + return testing::Values(MemoryManagerTestMode::kGlobal, + MemoryManagerTestMode::kArena); +} + +inline std::string MemoryManagerTestModeName( + const testing::TestParamInfo& info) { + return MemoryManagerTestModeToString(info.param); +} + +inline std::string MemoryManagerTestModeTupleName( + const testing::TestParamInfo>& info) { + return MemoryManagerTestModeToString(std::get<0>(info.param)); +} + +} // namespace cel::base_internal + +#endif // THIRD_PARTY_CEL_CPP_BASE_INTERNAL_MEMORY_MANAGER_TESTING_H_ diff --git a/base/internal/message_wrapper.h b/base/internal/message_wrapper.h new file mode 100644 index 000000000..616ae0df6 --- /dev/null +++ b/base/internal/message_wrapper.h @@ -0,0 +1,30 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_INTERNAL_MESSAGE_WRAPPER_H_ +#define THIRD_PARTY_CEL_CPP_BASE_INTERNAL_MESSAGE_WRAPPER_H_ + +#include + +namespace cel::base_internal { + +inline constexpr uintptr_t kMessageWrapperTagMask = 0b1; +inline constexpr uintptr_t kMessageWrapperPtrMask = ~kMessageWrapperTagMask; +inline constexpr int kMessageWrapperTagSize = 1; +inline constexpr uintptr_t kMessageWrapperTagTypeInfoValue = 0b0; +inline constexpr uintptr_t kMessageWrapperTagMessageValue = 0b1; + +} // namespace cel::base_internal + +#endif // THIRD_PARTY_CEL_CPP_BASE_INTERNAL_MESSAGE_WRAPPER_H_ diff --git a/base/internal/operators.h b/base/internal/operators.h new file mode 100644 index 000000000..04ffe2d79 --- /dev/null +++ b/base/internal/operators.h @@ -0,0 +1,91 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_INTERNAL_OPERATORS_H_ +#define THIRD_PARTY_CEL_CPP_BASE_INTERNAL_OPERATORS_H_ + +#include "absl/strings/string_view.h" + +namespace cel { + +enum class OperatorId; + +namespace base_internal { + +struct OperatorData final { + OperatorData() = delete; + OperatorData(const OperatorData&) = delete; + OperatorData(OperatorData&&) = delete; + OperatorData& operator=(const OperatorData&) = delete; + OperatorData& operator=(OperatorData&&) = delete; + + constexpr OperatorData(cel::OperatorId id, absl::string_view name, + absl::string_view display_name, int precedence, + int arity) + : id(id), + name(name), + display_name(display_name), + precedence(precedence), + arity(arity) {} + + const cel::OperatorId id; + const absl::string_view name; + const absl::string_view display_name; + const int precedence; + const int arity; +}; + +#define CEL_INTERNAL_UNARY_OPERATORS_ENUM(XX) \ + XX(LogicalNot, "!", "!_", 2, 1) \ + XX(Negate, "-", "-_", 2, 1) \ + XX(NotStrictlyFalse, "", "@not_strictly_false", 0, 1) \ + XX(OldNotStrictlyFalse, "", "__not_strictly_false__", 0, 1) + +#define CEL_INTERNAL_BINARY_OPERATORS_ENUM(XX) \ + XX(Equals, "==", "_==_", 5, 2) \ + XX(NotEquals, "!=", "_!=_", 5, 2) \ + XX(Less, "<", "_<_", 5, 2) \ + XX(LessEquals, "<=", "_<=_", 5, 2) \ + XX(Greater, ">", "_>_", 5, 2) \ + XX(GreaterEquals, ">=", "_>=_", 5, 2) \ + XX(In, "in", "@in", 5, 2) \ + XX(OldIn, "in", "_in_", 5, 2) \ + XX(Index, "", "_[_]", 1, 2) \ + XX(LogicalOr, "||", "_||_", 7, 2) \ + XX(LogicalAnd, "&&", "_&&_", 6, 2) \ + XX(Add, "+", "_+_", 4, 2) \ + XX(Subtract, "-", "_-_", 4, 2) \ + XX(Multiply, "*", "_*_", 3, 2) \ + XX(Divide, "/", "_/_", 3, 2) \ + XX(Modulo, "%", "_%_", 3, 2) + +#define CEL_INTERNAL_TERNARY_OPERATORS_ENUM(XX) \ + XX(Conditional, "", "_?_:_", 8, 3) + +// Macro definining all the operators and their properties. +// (1) - The identifier. +// (2) - The display name if applicable, otherwise an empty string. +// (3) - The name. +// (4) - The precedence if applicable, otherwise 0. +// (5) - The arity. +#define CEL_INTERNAL_OPERATORS_ENUM(XX) \ + CEL_INTERNAL_TERNARY_OPERATORS_ENUM(XX) \ + CEL_INTERNAL_BINARY_OPERATORS_ENUM(XX) \ + CEL_INTERNAL_UNARY_OPERATORS_ENUM(XX) + +} // namespace base_internal + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_INTERNAL_OPERATORS_H_ diff --git a/base/internal/unknown_set.cc b/base/internal/unknown_set.cc new file mode 100644 index 000000000..32c891857 --- /dev/null +++ b/base/internal/unknown_set.cc @@ -0,0 +1,31 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/internal/unknown_set.h" + +#include "absl/base/no_destructor.h" + +namespace cel::base_internal { + +const AttributeSet& EmptyAttributeSet() { + static const absl::NoDestructor empty_attribute_set; + return *empty_attribute_set; +} + +const FunctionResultSet& EmptyFunctionResultSet() { + static const absl::NoDestructor empty_function_result_set; + return *empty_function_result_set; +} + +} // namespace cel::base_internal diff --git a/base/internal/unknown_set.h b/base/internal/unknown_set.h new file mode 100644 index 000000000..2ef9020d7 --- /dev/null +++ b/base/internal/unknown_set.h @@ -0,0 +1,131 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_INTERNAL_UNKNOWN_SET_H_ +#define THIRD_PARTY_CEL_CPP_BASE_INTERNAL_UNKNOWN_SET_H_ + +#include +#include + +#include "absl/base/attributes.h" +#include "base/attribute_set.h" +#include "base/function_result_set.h" + +namespace cel::base_internal { + +// For compatibility with the old API and to avoid unnecessary copying when +// converting between the old and new representations, we store the historical +// members of google::api::expr::runtime::UnknownSet in this struct for use with +// std::shared_ptr. +struct UnknownSetRep final { + UnknownSetRep() = default; + + UnknownSetRep(AttributeSet attributes, FunctionResultSet function_results) + : attributes(std::move(attributes)), + function_results(std::move(function_results)) {} + + explicit UnknownSetRep(AttributeSet attributes) + : attributes(std::move(attributes)) {} + + explicit UnknownSetRep(FunctionResultSet function_results) + : function_results(std::move(function_results)) {} + + AttributeSet attributes; + FunctionResultSet function_results; +}; + +const AttributeSet& EmptyAttributeSet(); + +const FunctionResultSet& EmptyFunctionResultSet(); + +struct UnknownSetAccess; + +class UnknownSet final { + private: + using Rep = UnknownSetRep; + + public: + // Construct the empty set. + // Uses singletons instead of allocating new containers. + UnknownSet() = default; + + UnknownSet(const UnknownSet&) = default; + UnknownSet(UnknownSet&&) = default; + UnknownSet& operator=(const UnknownSet&) = default; + UnknownSet& operator=(UnknownSet&&) = default; + + // Initialization specifying subcontainers + explicit UnknownSet(AttributeSet attributes) + : rep_(std::make_shared(std::move(attributes))) {} + + explicit UnknownSet(FunctionResultSet function_results) + : rep_(std::make_shared(std::move(function_results))) {} + + UnknownSet(AttributeSet attributes, FunctionResultSet function_results) + : rep_(std::make_shared(std::move(attributes), + std::move(function_results))) {} + + // Merge constructor + UnknownSet(const UnknownSet& set1, const UnknownSet& set2) + : UnknownSet( + AttributeSet(set1.unknown_attributes(), set2.unknown_attributes()), + FunctionResultSet(set1.unknown_function_results(), + set2.unknown_function_results())) {} + + const AttributeSet& unknown_attributes() const { + return rep_ != nullptr ? rep_->attributes : EmptyAttributeSet(); + } + const FunctionResultSet& unknown_function_results() const { + return rep_ != nullptr ? rep_->function_results : EmptyFunctionResultSet(); + } + + bool operator==(const UnknownSet& other) const { + return this == &other || + (unknown_attributes() == other.unknown_attributes() && + unknown_function_results() == other.unknown_function_results()); + } + + bool operator!=(const UnknownSet& other) const { return !operator==(other); } + + private: + friend struct UnknownSetAccess; + + explicit UnknownSet(std::shared_ptr impl) : rep_(std::move(impl)) {} + + void Add(const UnknownSet& other) { + if (rep_ == nullptr) { + rep_ = std::make_shared(); + } + rep_->attributes.Add(other.unknown_attributes()); + rep_->function_results.Add(other.unknown_function_results()); + } + + std::shared_ptr rep_; +}; + +struct UnknownSetAccess final { + static UnknownSet Construct(std::shared_ptr rep) { + return UnknownSet(std::move(rep)); + } + + static void Add(UnknownSet& dest, const UnknownSet& src) { dest.Add(src); } + + static const std::shared_ptr& Rep(const UnknownSet& value) { + return value.rep_; + } +}; + +} // namespace cel::base_internal + +#endif // THIRD_PARTY_CEL_CPP_BASE_INTERNAL_UNKNOWN_SET_H_ diff --git a/base/kind.h b/base/kind.h new file mode 100644 index 000000000..3ec0133b0 --- /dev/null +++ b/base/kind.h @@ -0,0 +1,25 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_KIND_H_ +#define THIRD_PARTY_CEL_CPP_BASE_KIND_H_ + +// This header exists for compatibility and should be removed once all includes +// have been updated. + +#include "common/kind.h" // IWYU pragma: export +#include "common/type_kind.h" // IWYU pragma: export +#include "common/value_kind.h" // IWYU pragma: export + +#endif // THIRD_PARTY_CEL_CPP_BASE_KIND_H_ diff --git a/base/operators.cc b/base/operators.cc new file mode 100644 index 000000000..b7df40b27 --- /dev/null +++ b/base/operators.cc @@ -0,0 +1,300 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/operators.h" + +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/call_once.h" +#include "absl/log/absl_check.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "base/internal/operators.h" + +namespace cel { + +namespace { + +using base_internal::OperatorData; + +struct OperatorDataNameComparer { + using is_transparent = void; + + bool operator()(const OperatorData* lhs, const OperatorData* rhs) const { + return lhs->name < rhs->name; + } + + bool operator()(const OperatorData* lhs, absl::string_view rhs) const { + return lhs->name < rhs; + } + + bool operator()(absl::string_view lhs, const OperatorData* rhs) const { + return lhs < rhs->name; + } +}; + +struct OperatorDataDisplayNameComparer { + using is_transparent = void; + + bool operator()(const OperatorData* lhs, const OperatorData* rhs) const { + return lhs->display_name < rhs->display_name; + } + + bool operator()(const OperatorData* lhs, absl::string_view rhs) const { + return lhs->display_name < rhs; + } + + bool operator()(absl::string_view lhs, const OperatorData* rhs) const { + return lhs < rhs->display_name; + } +}; + +#define CEL_OPERATORS_DATA(id, symbol, name, precedence, arity) \ + ABSL_CONST_INIT const OperatorData id##_storage = { \ + OperatorId::k##id, name, symbol, precedence, arity}; +CEL_INTERNAL_OPERATORS_ENUM(CEL_OPERATORS_DATA) +#undef CEL_OPERATORS_DATA + +#define CEL_OPERATORS_COUNT(id, symbol, name, precedence, arity) +1 + +using OperatorsArray = + std::array; + +using UnaryOperatorsArray = + std::array; + +using BinaryOperatorsArray = + std::array; + +using TernaryOperatorsArray = + std::array; + +#undef CEL_OPERATORS_COUNT + +ABSL_CONST_INIT absl::once_flag operators_once_flag; + +#define CEL_OPERATORS_DO(id, symbol, name, precedence, arity) &id##_storage, + +OperatorsArray operators_by_name = { + CEL_INTERNAL_OPERATORS_ENUM(CEL_OPERATORS_DO)}; + +OperatorsArray operators_by_display_name = { + CEL_INTERNAL_OPERATORS_ENUM(CEL_OPERATORS_DO)}; + +UnaryOperatorsArray unary_operators_by_name = { + CEL_INTERNAL_UNARY_OPERATORS_ENUM(CEL_OPERATORS_DO)}; + +UnaryOperatorsArray unary_operators_by_display_name = { + CEL_INTERNAL_UNARY_OPERATORS_ENUM(CEL_OPERATORS_DO)}; + +BinaryOperatorsArray binary_operators_by_name = { + CEL_INTERNAL_BINARY_OPERATORS_ENUM(CEL_OPERATORS_DO)}; + +BinaryOperatorsArray binary_operators_by_display_name = { + CEL_INTERNAL_BINARY_OPERATORS_ENUM(CEL_OPERATORS_DO)}; + +TernaryOperatorsArray ternary_operators_by_name = { + CEL_INTERNAL_TERNARY_OPERATORS_ENUM(CEL_OPERATORS_DO)}; + +TernaryOperatorsArray ternary_operators_by_display_name = { + CEL_INTERNAL_TERNARY_OPERATORS_ENUM(CEL_OPERATORS_DO)}; + +#undef CEL_OPERATORS_DO + +void InitializeOperators() { + std::stable_sort(operators_by_name.begin(), operators_by_name.end(), + OperatorDataNameComparer{}); + std::stable_sort(operators_by_display_name.begin(), + operators_by_display_name.end(), + OperatorDataDisplayNameComparer{}); + std::stable_sort(unary_operators_by_name.begin(), + unary_operators_by_name.end(), OperatorDataNameComparer{}); + std::stable_sort(unary_operators_by_display_name.begin(), + unary_operators_by_display_name.end(), + OperatorDataDisplayNameComparer{}); + std::stable_sort(binary_operators_by_name.begin(), + binary_operators_by_name.end(), OperatorDataNameComparer{}); + std::stable_sort(binary_operators_by_display_name.begin(), + binary_operators_by_display_name.end(), + OperatorDataDisplayNameComparer{}); + std::stable_sort(ternary_operators_by_name.begin(), + ternary_operators_by_name.end(), OperatorDataNameComparer{}); + std::stable_sort(ternary_operators_by_display_name.begin(), + ternary_operators_by_display_name.end(), + OperatorDataDisplayNameComparer{}); +} + +} // namespace + +UnaryOperator::UnaryOperator(Operator op) : data_(op.data_) { + ABSL_CHECK(op.arity() == Arity::kUnary); // Crask OK +} + +BinaryOperator::BinaryOperator(Operator op) : data_(op.data_) { + ABSL_CHECK(op.arity() == Arity::kBinary); // Crask OK +} + +TernaryOperator::TernaryOperator(Operator op) : data_(op.data_) { + ABSL_CHECK(op.arity() == Arity::kTernary); // Crask OK +} + +#define CEL_UNARY_OPERATOR(id, symbol, name, precedence, arity) \ + UnaryOperator Operator::id() { return UnaryOperator(&id##_storage); } + +CEL_INTERNAL_UNARY_OPERATORS_ENUM(CEL_UNARY_OPERATOR) + +#undef CEL_UNARY_OPERATOR + +#define CEL_BINARY_OPERATOR(id, symbol, name, precedence, arity) \ + BinaryOperator Operator::id() { return BinaryOperator(&id##_storage); } + +CEL_INTERNAL_BINARY_OPERATORS_ENUM(CEL_BINARY_OPERATOR) + +#undef CEL_BINARY_OPERATOR + +#define CEL_TERNARY_OPERATOR(id, symbol, name, precedence, arity) \ + TernaryOperator Operator::id() { return TernaryOperator(&id##_storage); } + +CEL_INTERNAL_TERNARY_OPERATORS_ENUM(CEL_TERNARY_OPERATOR) + +#undef CEL_TERNARY_OPERATOR + +absl::optional Operator::FindByName(absl::string_view input) { + absl::call_once(operators_once_flag, InitializeOperators); + if (input.empty()) { + return std::nullopt; + } + auto it = + std::lower_bound(operators_by_name.cbegin(), operators_by_name.cend(), + input, OperatorDataNameComparer{}); + if (it == operators_by_name.cend() || (*it)->name != input) { + return std::nullopt; + } + return Operator(*it); +} + +absl::optional Operator::FindByDisplayName(absl::string_view input) { + absl::call_once(operators_once_flag, InitializeOperators); + if (input.empty()) { + return std::nullopt; + } + auto it = std::lower_bound(operators_by_display_name.cbegin(), + operators_by_display_name.cend(), input, + OperatorDataDisplayNameComparer{}); + if (it == operators_by_name.cend() || (*it)->display_name != input) { + return std::nullopt; + } + return Operator(*it); +} + +absl::optional UnaryOperator::FindByName( + absl::string_view input) { + absl::call_once(operators_once_flag, InitializeOperators); + if (input.empty()) { + return std::nullopt; + } + auto it = std::lower_bound(unary_operators_by_name.cbegin(), + unary_operators_by_name.cend(), input, + OperatorDataNameComparer{}); + if (it == unary_operators_by_name.cend() || (*it)->name != input) { + return std::nullopt; + } + return UnaryOperator(*it); +} + +absl::optional UnaryOperator::FindByDisplayName( + absl::string_view input) { + absl::call_once(operators_once_flag, InitializeOperators); + if (input.empty()) { + return std::nullopt; + } + auto it = std::lower_bound(unary_operators_by_display_name.cbegin(), + unary_operators_by_display_name.cend(), input, + OperatorDataDisplayNameComparer{}); + if (it == unary_operators_by_display_name.cend() || + (*it)->display_name != input) { + return std::nullopt; + } + return UnaryOperator(*it); +} + +absl::optional BinaryOperator::FindByName( + absl::string_view input) { + absl::call_once(operators_once_flag, InitializeOperators); + if (input.empty()) { + return std::nullopt; + } + auto it = std::lower_bound(binary_operators_by_name.cbegin(), + binary_operators_by_name.cend(), input, + OperatorDataNameComparer{}); + if (it == binary_operators_by_name.cend() || (*it)->name != input) { + return std::nullopt; + } + return BinaryOperator(*it); +} + +absl::optional BinaryOperator::FindByDisplayName( + absl::string_view input) { + absl::call_once(operators_once_flag, InitializeOperators); + if (input.empty()) { + return std::nullopt; + } + auto it = std::lower_bound(binary_operators_by_display_name.cbegin(), + binary_operators_by_display_name.cend(), input, + OperatorDataDisplayNameComparer{}); + if (it == binary_operators_by_display_name.cend() || + (*it)->display_name != input) { + return std::nullopt; + } + return BinaryOperator(*it); +} + +absl::optional TernaryOperator::FindByName( + absl::string_view input) { + absl::call_once(operators_once_flag, InitializeOperators); + if (input.empty()) { + return std::nullopt; + } + auto it = std::lower_bound(ternary_operators_by_name.cbegin(), + ternary_operators_by_name.cend(), input, + OperatorDataNameComparer{}); + if (it == ternary_operators_by_name.cend() || (*it)->name != input) { + return std::nullopt; + } + return TernaryOperator(*it); +} + +absl::optional TernaryOperator::FindByDisplayName( + absl::string_view input) { + absl::call_once(operators_once_flag, InitializeOperators); + if (input.empty()) { + return std::nullopt; + } + auto it = std::lower_bound(ternary_operators_by_display_name.cbegin(), + ternary_operators_by_display_name.cend(), input, + OperatorDataDisplayNameComparer{}); + if (it == ternary_operators_by_display_name.cend() || + (*it)->display_name != input) { + return std::nullopt; + } + return TernaryOperator(*it); +} + +} // namespace cel diff --git a/base/operators.h b/base/operators.h new file mode 100644 index 000000000..778262c4b --- /dev/null +++ b/base/operators.h @@ -0,0 +1,507 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_OPERATORS_H_ +#define THIRD_PARTY_CEL_CPP_BASE_OPERATORS_H_ + +#include + +#include "absl/base/attributes.h" +#include "absl/base/macros.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "base/internal/operators.h" + +namespace cel { + +enum class Arity { + kUnary = 1, + kBinary = 2, + kTernary = 3, +}; + +enum class OperatorId { + kConditional = 1, + kLogicalAnd, + kLogicalOr, + kLogicalNot, + kEquals, + kNotEquals, + kLess, + kLessEquals, + kGreater, + kGreaterEquals, + kAdd, + kSubtract, + kMultiply, + kDivide, + kModulo, + kNegate, + kIndex, + kIn, + kNotStrictlyFalse, + kOldIn, + kOldNotStrictlyFalse, +}; + +enum class UnaryOperatorId { + kLogicalNot = static_cast(OperatorId::kLogicalNot), + kNegate = static_cast(OperatorId::kNegate), + kNotStrictlyFalse = static_cast(OperatorId::kNotStrictlyFalse), + kOldNotStrictlyFalse = static_cast(OperatorId::kOldNotStrictlyFalse), +}; + +enum class BinaryOperatorId { + kLogicalAnd = static_cast(OperatorId::kLogicalAnd), + kLogicalOr = static_cast(OperatorId::kLogicalOr), + kEquals = static_cast(OperatorId::kEquals), + kNotEquals = static_cast(OperatorId::kNotEquals), + kLess = static_cast(OperatorId::kLess), + kLessEquals = static_cast(OperatorId::kLessEquals), + kGreater = static_cast(OperatorId::kGreater), + kGreaterEquals = static_cast(OperatorId::kGreaterEquals), + kAdd = static_cast(OperatorId::kAdd), + kSubtract = static_cast(OperatorId::kSubtract), + kMultiply = static_cast(OperatorId::kMultiply), + kDivide = static_cast(OperatorId::kDivide), + kModulo = static_cast(OperatorId::kModulo), + kIndex = static_cast(OperatorId::kIndex), + kIn = static_cast(OperatorId::kIn), + kOldIn = static_cast(OperatorId::kOldIn), +}; + +enum class TernaryOperatorId { + kConditional = static_cast(OperatorId::kConditional), +}; + +class UnaryOperator; +class BinaryOperator; +class TernaryOperator; + +class Operator final { + public: + ABSL_ATTRIBUTE_PURE_FUNCTION static TernaryOperator Conditional(); + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator LogicalAnd(); + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator LogicalOr(); + ABSL_ATTRIBUTE_PURE_FUNCTION static UnaryOperator LogicalNot(); + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator Equals(); + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator NotEquals(); + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator Less(); + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator LessEquals(); + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator Greater(); + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator GreaterEquals(); + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator Add(); + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator Subtract(); + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator Multiply(); + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator Divide(); + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator Modulo(); + ABSL_ATTRIBUTE_PURE_FUNCTION static UnaryOperator Negate(); + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator Index(); + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator In(); + ABSL_ATTRIBUTE_PURE_FUNCTION static UnaryOperator NotStrictlyFalse(); + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator OldIn(); + ABSL_ATTRIBUTE_PURE_FUNCTION static UnaryOperator OldNotStrictlyFalse(); + + ABSL_ATTRIBUTE_PURE_FUNCTION static absl::optional FindByName( + absl::string_view input); + + ABSL_ATTRIBUTE_PURE_FUNCTION static absl::optional + FindByDisplayName(absl::string_view input); + + Operator() = delete; + Operator(const Operator&) = default; + Operator(Operator&&) = default; + Operator& operator=(const Operator&) = default; + Operator& operator=(Operator&&) = default; + + constexpr OperatorId id() const { return data_->id; } + + // Returns the name of the operator. This is the managed representation of the + // operator, for example "_&&_". + constexpr absl::string_view name() const { return data_->name; } + + // Returns the source text representation of the operator. This is the + // unmanaged text representation of the operator, for example "&&". + // + // Note that this will be empty for operators like Conditional() and Index(). + constexpr absl::string_view display_name() const { + return data_->display_name; + } + + constexpr int precedence() const { return data_->precedence; } + + constexpr Arity arity() const { return static_cast(data_->arity); } + + private: + friend class UnaryOperator; + friend class BinaryOperator; + friend class TernaryOperator; + + constexpr explicit Operator(const base_internal::OperatorData* data) + : data_(data) {} + + const base_internal::OperatorData* data_; +}; + +constexpr bool operator==(const Operator& lhs, const Operator& rhs) { + return lhs.id() == rhs.id(); +} + +constexpr bool operator==(OperatorId lhs, const Operator& rhs) { + return lhs == rhs.id(); +} + +constexpr bool operator==(const Operator& lhs, OperatorId rhs) { + return operator==(rhs, lhs); +} + +constexpr bool operator!=(const Operator& lhs, const Operator& rhs) { + return !operator==(lhs, rhs); +} + +constexpr bool operator!=(OperatorId lhs, const Operator& rhs) { + return !operator==(lhs, rhs); +} + +constexpr bool operator!=(const Operator& lhs, OperatorId rhs) { + return !operator==(lhs, rhs); +} + +template +H AbslHashValue(H state, const Operator& op) { + return H::combine(std::move(state), static_cast(op.id())); +} + +class UnaryOperator final { + public: + ABSL_ATTRIBUTE_PURE_FUNCTION static UnaryOperator LogicalNot() { + return Operator::LogicalNot(); + } + ABSL_ATTRIBUTE_PURE_FUNCTION static UnaryOperator Negate() { + return Operator::Negate(); + } + ABSL_ATTRIBUTE_PURE_FUNCTION static UnaryOperator NotStrictlyFalse() { + return Operator::NotStrictlyFalse(); + } + ABSL_ATTRIBUTE_PURE_FUNCTION static UnaryOperator OldNotStrictlyFalse() { + return Operator::OldNotStrictlyFalse(); + } + + ABSL_ATTRIBUTE_PURE_FUNCTION static absl::optional FindByName( + absl::string_view input); + + ABSL_ATTRIBUTE_PURE_FUNCTION static absl::optional + FindByDisplayName(absl::string_view input); + + UnaryOperator() = delete; + UnaryOperator(const UnaryOperator&) = default; + UnaryOperator(UnaryOperator&&) = default; + UnaryOperator& operator=(const UnaryOperator&) = default; + UnaryOperator& operator=(UnaryOperator&&) = default; + + // Support for explicit casting of Operator to UnaryOperator. + // `Operator::arity()` must return `Arity::kUnary`, or this will crash. + explicit UnaryOperator(Operator op); + + constexpr UnaryOperatorId id() const { + return static_cast(data_->id); + } + + // Returns the name of the operator. This is the managed representation of the + // operator, for example "_&&_". + constexpr absl::string_view name() const { return data_->name; } + + // Returns the source text representation of the operator. This is the + // unmanaged text representation of the operator, for example "&&". + // + // Note that this will be empty for operators like Conditional() and Index(). + constexpr absl::string_view display_name() const { + return data_->display_name; + } + + constexpr int precedence() const { return data_->precedence; } + + constexpr Arity arity() const { + ABSL_ASSERT(data_->arity == 1); + return Arity::kUnary; + } + + constexpr operator Operator() const { // NOLINT(google-explicit-constructor) + return Operator(data_); + } + + private: + friend class Operator; + + constexpr explicit UnaryOperator(const base_internal::OperatorData* data) + : data_(data) {} + + const base_internal::OperatorData* data_; +}; + +constexpr bool operator==(const UnaryOperator& lhs, const UnaryOperator& rhs) { + return lhs.id() == rhs.id(); +} + +constexpr bool operator==(UnaryOperatorId lhs, const UnaryOperator& rhs) { + return lhs == rhs.id(); +} + +constexpr bool operator==(const UnaryOperator& lhs, UnaryOperatorId rhs) { + return operator==(rhs, lhs); +} + +constexpr bool operator!=(const UnaryOperator& lhs, const UnaryOperator& rhs) { + return !operator==(lhs, rhs); +} + +constexpr bool operator!=(UnaryOperatorId lhs, const UnaryOperator& rhs) { + return !operator==(lhs, rhs); +} + +constexpr bool operator!=(const UnaryOperator& lhs, UnaryOperatorId rhs) { + return !operator==(lhs, rhs); +} + +template +H AbslHashValue(H state, const UnaryOperator& op) { + return H::combine(std::move(state), static_cast(op.id())); +} + +class BinaryOperator final { + public: + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator LogicalAnd() { + return Operator::LogicalAnd(); + } + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator LogicalOr() { + return Operator::LogicalOr(); + } + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator Equals() { + return Operator::Equals(); + } + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator NotEquals() { + return Operator::NotEquals(); + } + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator Less() { + return Operator::Less(); + } + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator LessEquals() { + return Operator::LessEquals(); + } + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator Greater() { + return Operator::Greater(); + } + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator GreaterEquals() { + return Operator::GreaterEquals(); + } + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator Add() { + return Operator::Add(); + } + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator Subtract() { + return Operator::Subtract(); + } + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator Multiply() { + return Operator::Multiply(); + } + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator Divide() { + return Operator::Divide(); + } + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator Modulo() { + return Operator::Modulo(); + } + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator Index() { + return Operator::Index(); + } + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator In() { + return Operator::In(); + } + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator OldIn() { + return Operator::OldIn(); + } + + ABSL_ATTRIBUTE_PURE_FUNCTION static absl::optional FindByName( + absl::string_view input); + + ABSL_ATTRIBUTE_PURE_FUNCTION static absl::optional + FindByDisplayName(absl::string_view input); + + BinaryOperator() = delete; + BinaryOperator(const BinaryOperator&) = default; + BinaryOperator(BinaryOperator&&) = default; + BinaryOperator& operator=(const BinaryOperator&) = default; + BinaryOperator& operator=(BinaryOperator&&) = default; + + // Support for explicit casting of Operator to BinaryOperator. + // `Operator::arity()` must return `Arity::kBinary`, or this will crash. + explicit BinaryOperator(Operator op); + + constexpr BinaryOperatorId id() const { + return static_cast(data_->id); + } + + // Returns the name of the operator. This is the managed representation of the + // operator, for example "_&&_". + constexpr absl::string_view name() const { return data_->name; } + + // Returns the source text representation of the operator. This is the + // unmanaged text representation of the operator, for example "&&". + // + // Note that this will be empty for operators like Conditional() and Index(). + constexpr absl::string_view display_name() const { + return data_->display_name; + } + + constexpr int precedence() const { return data_->precedence; } + + constexpr Arity arity() const { + ABSL_ASSERT(data_->arity == 2); + return Arity::kBinary; + } + + constexpr operator Operator() const { // NOLINT(google-explicit-constructor) + return Operator(data_); + } + + private: + friend class Operator; + + constexpr explicit BinaryOperator(const base_internal::OperatorData* data) + : data_(data) {} + + const base_internal::OperatorData* data_; +}; + +constexpr bool operator==(const BinaryOperator& lhs, + const BinaryOperator& rhs) { + return lhs.id() == rhs.id(); +} + +constexpr bool operator==(BinaryOperatorId lhs, const BinaryOperator& rhs) { + return lhs == rhs.id(); +} + +constexpr bool operator==(const BinaryOperator& lhs, BinaryOperatorId rhs) { + return operator==(rhs, lhs); +} + +constexpr bool operator!=(const BinaryOperator& lhs, + const BinaryOperator& rhs) { + return !operator==(lhs, rhs); +} + +constexpr bool operator!=(BinaryOperatorId lhs, const BinaryOperator& rhs) { + return !operator==(lhs, rhs); +} + +constexpr bool operator!=(const BinaryOperator& lhs, BinaryOperatorId rhs) { + return !operator==(lhs, rhs); +} + +template +H AbslHashValue(H state, const BinaryOperator& op) { + return H::combine(std::move(state), static_cast(op.id())); +} + +class TernaryOperator final { + public: + ABSL_ATTRIBUTE_PURE_FUNCTION static TernaryOperator Conditional() { + return Operator::Conditional(); + } + + ABSL_ATTRIBUTE_PURE_FUNCTION static absl::optional + FindByName(absl::string_view input); + + ABSL_ATTRIBUTE_PURE_FUNCTION static absl::optional + FindByDisplayName(absl::string_view input); + + TernaryOperator() = delete; + TernaryOperator(const TernaryOperator&) = default; + TernaryOperator(TernaryOperator&&) = default; + TernaryOperator& operator=(const TernaryOperator&) = default; + TernaryOperator& operator=(TernaryOperator&&) = default; + + // Support for explicit casting of Operator to TernaryOperator. + // `Operator::arity()` must return `Arity::kTernary`, or this will crash. + explicit TernaryOperator(Operator op); + + constexpr TernaryOperatorId id() const { + return static_cast(data_->id); + } + + // Returns the name of the operator. This is the managed representation of the + // operator, for example "_&&_". + constexpr absl::string_view name() const { return data_->name; } + + // Returns the source text representation of the operator. This is the + // unmanaged text representation of the operator, for example "&&". + // + // Note that this will be empty for operators like Conditional() and Index(). + constexpr absl::string_view display_name() const { + return data_->display_name; + } + + constexpr int precedence() const { return data_->precedence; } + + constexpr Arity arity() const { + ABSL_ASSERT(data_->arity == 3); + return Arity::kTernary; + } + + constexpr operator Operator() const { // NOLINT(google-explicit-constructor) + return Operator(data_); + } + + private: + friend class Operator; + + constexpr explicit TernaryOperator(const base_internal::OperatorData* data) + : data_(data) {} + + const base_internal::OperatorData* data_; +}; + +constexpr bool operator==(const TernaryOperator& lhs, + const TernaryOperator& rhs) { + return lhs.id() == rhs.id(); +} + +constexpr bool operator==(TernaryOperatorId lhs, const TernaryOperator& rhs) { + return lhs == rhs.id(); +} + +constexpr bool operator==(const TernaryOperator& lhs, TernaryOperatorId rhs) { + return operator==(rhs, lhs); +} + +constexpr bool operator!=(const TernaryOperator& lhs, + const TernaryOperator& rhs) { + return !operator==(lhs, rhs); +} + +constexpr bool operator!=(TernaryOperatorId lhs, const TernaryOperator& rhs) { + return !operator==(lhs, rhs); +} + +constexpr bool operator!=(const TernaryOperator& lhs, TernaryOperatorId rhs) { + return !operator==(lhs, rhs); +} + +template +H AbslHashValue(H state, const TernaryOperator& op) { + return H::combine(std::move(state), static_cast(op.id())); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_OPERATORS_H_ diff --git a/base/operators_test.cc b/base/operators_test.cc new file mode 100644 index 000000000..6049f76c8 --- /dev/null +++ b/base/operators_test.cc @@ -0,0 +1,209 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/operators.h" + +#include + +#include "absl/hash/hash_testing.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "base/internal/operators.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +using ::testing::Eq; +using ::testing::Optional; + +template +void TestOperator(Op op, OpId id, absl::string_view name, + absl::string_view display_name, int precedence, Arity arity) { + EXPECT_EQ(op.id(), id); + EXPECT_EQ(Operator(op).id(), static_cast(id)); + EXPECT_EQ(op.name(), name); + EXPECT_EQ(op.display_name(), display_name); + EXPECT_EQ(op.precedence(), precedence); + EXPECT_EQ(op.arity(), arity); + EXPECT_EQ(Operator(op).arity(), arity); + EXPECT_EQ(Op(Operator(op)), op); +} + +void TestUnaryOperator(UnaryOperator op, UnaryOperatorId id, + absl::string_view name, absl::string_view display_name, + int precedence) { + TestOperator(op, id, name, display_name, precedence, Arity::kUnary); +} + +void TestBinaryOperator(BinaryOperator op, BinaryOperatorId id, + absl::string_view name, absl::string_view display_name, + int precedence) { + TestOperator(op, id, name, display_name, precedence, Arity::kBinary); +} + +void TestTernaryOperator(TernaryOperator op, TernaryOperatorId id, + absl::string_view name, absl::string_view display_name, + int precedence) { + TestOperator(op, id, name, display_name, precedence, Arity::kTernary); +} + +TEST(Operator, TypeTraits) { + EXPECT_FALSE(std::is_default_constructible_v); + EXPECT_TRUE(std::is_copy_constructible_v); + EXPECT_TRUE(std::is_move_constructible_v); + EXPECT_TRUE(std::is_copy_assignable_v); + EXPECT_TRUE(std::is_move_assignable_v); + EXPECT_FALSE((std::is_convertible_v)); + EXPECT_FALSE((std::is_convertible_v)); + EXPECT_FALSE((std::is_convertible_v)); +} + +TEST(UnaryOperator, TypeTraits) { + EXPECT_FALSE(std::is_default_constructible_v); + EXPECT_TRUE(std::is_copy_constructible_v); + EXPECT_TRUE(std::is_move_constructible_v); + EXPECT_TRUE(std::is_copy_assignable_v); + EXPECT_TRUE(std::is_move_assignable_v); + EXPECT_TRUE((std::is_convertible_v)); +} + +TEST(BinaryOperator, TypeTraits) { + EXPECT_FALSE(std::is_default_constructible_v); + EXPECT_TRUE(std::is_copy_constructible_v); + EXPECT_TRUE(std::is_move_constructible_v); + EXPECT_TRUE(std::is_copy_assignable_v); + EXPECT_TRUE(std::is_move_assignable_v); + EXPECT_TRUE((std::is_convertible_v)); +} + +TEST(TernaryOperator, TypeTraits) { + EXPECT_FALSE(std::is_default_constructible_v); + EXPECT_TRUE(std::is_copy_constructible_v); + EXPECT_TRUE(std::is_move_constructible_v); + EXPECT_TRUE(std::is_copy_assignable_v); + EXPECT_TRUE(std::is_move_assignable_v); + EXPECT_TRUE((std::is_convertible_v)); +} + +#define CEL_UNARY_OPERATOR(id, symbol, name, precedence, arity) \ + TEST(UnaryOperator, id) { \ + TestUnaryOperator(UnaryOperator::id(), UnaryOperatorId::k##id, name, \ + symbol, precedence); \ + } + +CEL_INTERNAL_UNARY_OPERATORS_ENUM(CEL_UNARY_OPERATOR) + +#undef CEL_UNARY_OPERATOR + +#define CEL_BINARY_OPERATOR(id, symbol, name, precedence, arity) \ + TEST(BinaryOperator, id) { \ + TestBinaryOperator(BinaryOperator::id(), BinaryOperatorId::k##id, name, \ + symbol, precedence); \ + } + +CEL_INTERNAL_BINARY_OPERATORS_ENUM(CEL_BINARY_OPERATOR) + +#undef CEL_BINARY_OPERATOR + +#define CEL_TERNARY_OPERATOR(id, symbol, name, precedence, arity) \ + TEST(TernaryOperator, id) { \ + TestTernaryOperator(TernaryOperator::id(), TernaryOperatorId::k##id, name, \ + symbol, precedence); \ + } + +CEL_INTERNAL_TERNARY_OPERATORS_ENUM(CEL_TERNARY_OPERATOR) + +#undef CEL_TERNARY_OPERATOR + +TEST(Operator, FindByName) { + EXPECT_THAT(Operator::FindByName("@in"), Optional(Eq(Operator::In()))); + EXPECT_THAT(Operator::FindByName("_in_"), Optional(Eq(Operator::OldIn()))); + EXPECT_THAT(Operator::FindByName("in"), Eq(std::nullopt)); + EXPECT_THAT(Operator::FindByName(""), Eq(std::nullopt)); +} + +TEST(Operator, FindByDisplayName) { + EXPECT_THAT(Operator::FindByDisplayName("-"), + Optional(Eq(Operator::Subtract()))); + EXPECT_THAT(Operator::FindByDisplayName("@in"), Eq(std::nullopt)); + EXPECT_THAT(Operator::FindByDisplayName(""), Eq(std::nullopt)); +} + +TEST(UnaryOperator, FindByName) { + EXPECT_THAT(UnaryOperator::FindByName("-_"), + Optional(Eq(Operator::Negate()))); + EXPECT_THAT(UnaryOperator::FindByName("_-_"), Eq(std::nullopt)); + EXPECT_THAT(UnaryOperator::FindByName(""), Eq(std::nullopt)); +} + +TEST(UnaryOperator, FindByDisplayName) { + EXPECT_THAT(UnaryOperator::FindByDisplayName("-"), + Optional(Eq(Operator::Negate()))); + EXPECT_THAT(UnaryOperator::FindByDisplayName("&&"), Eq(std::nullopt)); + EXPECT_THAT(UnaryOperator::FindByDisplayName(""), Eq(std::nullopt)); +} + +TEST(BinaryOperator, FindByName) { + EXPECT_THAT(BinaryOperator::FindByName("_-_"), + Optional(Eq(Operator::Subtract()))); + EXPECT_THAT(BinaryOperator::FindByName("-_"), Eq(std::nullopt)); + EXPECT_THAT(BinaryOperator::FindByName(""), Eq(std::nullopt)); +} + +TEST(BinaryOperator, FindByDisplayName) { + EXPECT_THAT(BinaryOperator::FindByDisplayName("-"), + Optional(Eq(Operator::Subtract()))); + EXPECT_THAT(BinaryOperator::FindByDisplayName("!"), Eq(std::nullopt)); + EXPECT_THAT(BinaryOperator::FindByDisplayName(""), Eq(std::nullopt)); +} + +TEST(TernaryOperator, FindByName) { + EXPECT_THAT(TernaryOperator::FindByName("_?_:_"), + Optional(Eq(TernaryOperator::Conditional()))); + EXPECT_THAT(TernaryOperator::FindByName("-_"), Eq(std::nullopt)); + EXPECT_THAT(TernaryOperator::FindByName(""), Eq(std::nullopt)); +} + +TEST(TernaryOperator, FindByDisplayName) { + EXPECT_THAT(TernaryOperator::FindByDisplayName(""), Eq(std::nullopt)); + EXPECT_THAT(TernaryOperator::FindByDisplayName("!"), Eq(std::nullopt)); +} + +TEST(Operator, SupportsAbslHash) { +#define CEL_OPERATOR(id, symbol, name, precedence, arity) \ + Operator(Operator::id()), + EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly( + {CEL_INTERNAL_OPERATORS_ENUM(CEL_OPERATOR)})); +#undef CEL_OPERATOR +} + +TEST(UnaryOperator, SupportsAbslHash) { +#define CEL_UNARY_OPERATOR(id, symbol, name, precedence, arity) \ + UnaryOperator::id(), + EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly( + {CEL_INTERNAL_UNARY_OPERATORS_ENUM(CEL_UNARY_OPERATOR)})); +#undef CEL_UNARY_OPERATOR +} + +TEST(BinaryOperator, SupportsAbslHash) { +#define CEL_BINARY_OPERATOR(id, symbol, name, precedence, arity) \ + BinaryOperator::id(), + EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly( + {CEL_INTERNAL_BINARY_OPERATORS_ENUM(CEL_BINARY_OPERATOR)})); +#undef CEL_BINARY_OPERATOR +} + +} // namespace +} // namespace cel diff --git a/base/status_macros.h b/base/status_macros.h deleted file mode 100644 index 4e7da02e1..000000000 --- a/base/status_macros.h +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Copyright 2020 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef THIRD_PARTY_CEL_CPP_BASE_STATUS_MACROS_H_ -#define THIRD_PARTY_CEL_CPP_BASE_STATUS_MACROS_H_ - -#include // for use with down_cast<> - -#include - -// Early-returns the status if it is in error; otherwise, proceeds. -// -// The argument expression is guaranteed to be evaluated exactly once. -#if !defined(RETURN_IF_ERROR) -#define RETURN_IF_ERROR(__status) \ - do { \ - auto _status = __status; \ - if (!_status.ok()) { \ - return _status; \ - } \ - } while (false) -#endif - -#if !defined(ASSIGN_OR_RETURN) -#define CEL_CONCAT_(x, y) x##y -#define CEL_CONCAT(x, y) CEL_CONCAT_(x, y) -#define ASSIGN_OR_RETURN(lhs, rexpr) \ - auto CEL_CONCAT(_statusor, __LINE__) = \ - static_cast(rexpr); \ - RETURN_IF_ERROR(CEL_CONCAT(_statusor, __LINE__).status()); \ - lhs = std::move(CEL_CONCAT(_statusor, __LINE__).value()); -#endif - -template // use like this: down_cast(foo); -inline To down_cast(From* f) { // so we only accept pointers - static_assert( - (std::is_base_of::type>::value), - "target type not derived from source type"); - - // We skip the assert and hence the dynamic_cast if RTTI is disabled. -#if !defined(__GNUC__) || defined(__GXX_RTTI) - // Uses RTTI in dbg and fastbuild. asserts are disabled in opt builds. - assert(f == nullptr || dynamic_cast(f) != nullptr); -#endif // !defined(__GNUC__) || defined(__GXX_RTTI) - - return static_cast(f); -} - -#if !defined(ASSERT_OK) -#define ASSERT_OK(expression) ASSERT_TRUE(expression.ok()) -#endif - -#if !defined(EXPECT_OK) -#define EXPECT_OK(expression) EXPECT_TRUE(expression.ok()) -#endif - -#endif // THIRD_PARTY_CEL_CPP_BASE_STATUS_MACROS_H_ diff --git a/base/type_provider.h b/base/type_provider.h new file mode 100644 index 000000000..9ed8524e1 --- /dev/null +++ b/base/type_provider.h @@ -0,0 +1,26 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_TYPE_PROVIDER_H_ +#define THIRD_PARTY_CEL_CPP_BASE_TYPE_PROVIDER_H_ + +#include "common/type_reflector.h" // IWYU pragma: export + +namespace cel { + +using TypeProvider = TypeReflector; + +} + +#endif // THIRD_PARTY_CEL_CPP_BASE_TYPE_PROVIDER_H_ diff --git a/bazel/BUILD b/bazel/BUILD index ffd0fb0cd..5b3cb2d2c 100644 --- a/bazel/BUILD +++ b/bazel/BUILD @@ -1 +1,42 @@ +load("@rules_cc//cc:cc_binary.bzl", "cc_binary") +load("@rules_java//java:defs.bzl", "java_binary") + +java_binary( + name = "antlr4_tool", + main_class = "org.antlr.v4.Tool", + runtime_deps = ["@antlr4_jar//jar"], +) + package(default_visibility = ["//visibility:public"]) + +exports_files( + srcs = [ + "antlr.patch", + ], + visibility = ["//:__subpackages__"], +) + +cc_binary( + name = "cel_cc_embed", + srcs = ["cel_cc_embed.cc"], + visibility = ["//:__subpackages__"], + deps = [ + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/flags:parse", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:initialize", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + +cc_binary( + name = "cat_param_file", + srcs = ["cat_param_file.cc"], + visibility = ["//:__subpackages__"], + deps = [ + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", + "@com_google_absl//absl/log:initialize", + ], +) diff --git a/bazel/antlr.bzl b/bazel/antlr.bzl index bce7f9577..a4d28cdf8 100644 --- a/bazel/antlr.bzl +++ b/bazel/antlr.bzl @@ -1,46 +1,140 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """ Generate C++ parser and lexer from a grammar file. """ -load("@rules_antlr//antlr:antlr4.bzl", "antlr", "headers", "sources") +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc/common:cc_common.bzl", "cc_common") +load("@rules_cc//cc/common:cc_info.bzl", "CcInfo") -def antlr_cc_library(name, src, listener = False, visitor = True): +def antlr_cc_library(name, src, package): """Creates a C++ lexer and parser from a source grammar. - Args: name: Base name for the lexer and the parser rules. src: source ANTLR grammar file - listener: generate ANTLR listener (default: False) - visitor: generate ANTLR visitor (default: True) + package: The namespace for the generated code """ generated = name + "_grammar" - antlr( + antlr_library( name = generated, - srcs = [src], - language = "Cpp", - listener = listener, - visitor = visitor, - package = generated, + src = src, + package = package, + shell = select( + { + "@platforms//os:windows": "PowerShell.exe", + "//conditions:default": "bash", + }, + ), + genfiles_prefixed = select( + { + "@platforms//os:windows": False, + "//conditions:default": True, + }, + ), ) - - headers( - name = "headers", - rule = ":" + generated, + cc_library( + name = name + "_cc_parser", + srcs = [generated], + defines = [ + "ANTLR4CPP_STATIC", + ], + deps = [ + generated, + "@antlr4-cpp-runtime//:antlr4-cpp-runtime", + ], + copts = ["-fexceptions"], + linkstatic = 1, ) - sources( - name = "sources", - rule = ":" + generated, - ) +def _antlr_library(ctx): + output = ctx.actions.declare_directory(ctx.attr.name) - native.cc_library( - name = name + "_cc_parser", - hdrs = [":headers"], - srcs = [":sources"], - includes = ["$(INCLUDES)"], - deps = ["@antlr4_runtimes//:cpp"], - toolchains = [":" + generated], - # ANTLR runtime does not build with dynamic linking - linkstatic = True, - alwayslink = 1, + antlr_args = ctx.actions.args() + antlr_args.add("-Dlanguage=Cpp") + antlr_args.add("-no-listener") + antlr_args.add("-visitor") + antlr_args.add("-o", output.path) + antlr_args.add("-package", ctx.attr.package) + antlr_args.add(ctx.file.src) + + # Strip ".g4" extension. + basename = ctx.file.src.basename[:-3] + + suffixes = ["Lexer", "Parser", "BaseVisitor", "Visitor"] + + ctx.actions.run( + mnemonic = "GenAntlr", + arguments = [antlr_args], + inputs = [ctx.file.src], + outputs = [output], + executable = ctx.executable._tool, + progress_message = "Processing ANTLR grammar. -o " + output.path, ) + + files = [] + for suffix in suffixes: + header = ctx.actions.declare_file(basename + suffix + ".h") + source = ctx.actions.declare_file(basename + suffix + ".cpp") + prefix = ctx.file.src.path[:-3] if ctx.attr.genfiles_prefixed else basename + generated = output.path + "/" + prefix + suffix + + executable = ctx.attr.shell + + ctx.actions.run( + mnemonic = "CopyHeader" + suffix, + inputs = [output], + outputs = [header], + executable = executable, + arguments = [ + "-c", + 'cp "{generated}" "{out}"'.format(generated = generated + ".h", out = header.path), + ], + ) + ctx.actions.run( + mnemonic = "CopySource" + suffix, + inputs = [output], + outputs = [source], + executable = executable, + arguments = [ + "-c", + 'cp "{generated}" "{out}"'.format(generated = generated + ".cpp", out = source.path), + ], + ) + + files.append(header) + files.append(source) + + compilation_context = cc_common.create_compilation_context(headers = depset(files)) + return [DefaultInfo(files = depset(files)), CcInfo(compilation_context = compilation_context)] + +antlr_library = rule( + implementation = _antlr_library, + attrs = { + "src": attr.label(allow_single_file = [".g4"], mandatory = True), + "package": attr.string(), + "_tool": attr.label( + executable = True, + cfg = "exec", # buildifier: disable=attr-cfg + default = Label("//bazel:antlr4_tool"), + ), + "shell": attr.string( + mandatory = True, + ), + "genfiles_prefixed": attr.bool( + mandatory = True, + ), + }, +) diff --git a/bazel/cat_param_file.cc b/bazel/cat_param_file.cc new file mode 100644 index 000000000..0bc497597 --- /dev/null +++ b/bazel/cat_param_file.cc @@ -0,0 +1,63 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" +#include "absl/log/initialize.h" + +// Read a bazel param file and concatenate the inputs. +// The param file is line delimited with each line a file to concat. +int main(int argc, char** argv) { + absl::InitializeLog(); + if (argc != 3) { + std::cerr << "usage: cat_param_file " << std::endl; + std::cerr << "args " << argc << std::endl; + return 2; + } + + const char* param_file = argv[1]; + const char* out_file = argv[2]; + std::ifstream ifs(param_file, std::ios::binary); + std::ofstream ofs(out_file, std::ios::binary); + + ABSL_QCHECK(ifs.good()) << "failed to open param file " << param_file; + ABSL_QCHECK(ofs.good()) << "failed to open out file " << out_file; + + for (std::string line; std::getline(ifs, line);) { + std::ifstream in(line, std::ios::binary); + if (!in.good()) { + ABSL_LOG(ERROR) << "failed to open input file " << line; + continue; + } + constexpr size_t kBufSize = 256; + char buf[kBufSize]; + while (true) { + in.read(buf, kBufSize); + size_t read = in.gcount(); + if (read == 0) { + break; + } + ofs.write(buf, read); + } + } + + ofs.flush(); + + return 0; +} diff --git a/bazel/cel_cc_embed.bzl b/bazel/cel_cc_embed.bzl new file mode 100644 index 000000000..8f0144b22 --- /dev/null +++ b/bazel/cel_cc_embed.bzl @@ -0,0 +1,49 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Provides the `cel_cc_embed` build rule. +""" + +def _cel_cc_embed(ctx): + output = ctx.actions.declare_file(ctx.attr.name + ".inc") + args = ctx.actions.args() + src = ctx.file.src + args.add("--in", src) + args.add("--out", output.path) + ctx.actions.run( + mnemonic = "GenerateEmbedTextualHeader", + outputs = [output], + inputs = [src], + progress_message = "generating embed textual header", + executable = ctx.executable.gen_tool, + arguments = [args], + ) + + return DefaultInfo( + files = depset([output]), + ) + +cel_cc_embed = rule( + implementation = _cel_cc_embed, + attrs = { + "src": attr.label(allow_single_file = True, mandatory = True), + "gen_tool": attr.label( + executable = True, + cfg = "exec", + allow_files = True, + default = Label("//bazel:cel_cc_embed"), + ), + }, +) diff --git a/bazel/cel_cc_embed.cc b/bazel/cel_cc_embed.cc new file mode 100644 index 000000000..805154571 --- /dev/null +++ b/bazel/cel_cc_embed.cc @@ -0,0 +1,85 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include + +#include "absl/flags/flag.h" +#include "absl/flags/parse.h" +#include "absl/log/absl_check.h" +#include "absl/log/initialize.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/types/span.h" + +ABSL_FLAG(std::string, in, "", ""); +ABSL_FLAG(std::string, out, "", ""); + +namespace { + +std::vector ReadFile(const std::string& path) { + ABSL_CHECK(!path.empty()) << "--in is required"; + std::ifstream file(path, std::ifstream::binary); + ABSL_CHECK(file.is_open()) << path; + file.seekg(0, file.end); + ABSL_CHECK(file.good()); + size_t size = static_cast(file.tellg()); + file.seekg(0, file.beg); + ABSL_CHECK(file.good()); + std::vector buffer; + buffer.resize(size); + file.read(reinterpret_cast(buffer.data()), size); + ABSL_CHECK(file.good()); + return buffer; +} + +void WriteFile(const std::string& path, absl::Span data) { + ABSL_CHECK(!path.empty()) << "--out is required"; + std::ofstream file(path); + ABSL_CHECK(file.is_open()) << path; + file.write(data.data(), data.size()); + ABSL_CHECK(file.good()); + file.flush(); + ABSL_CHECK(file.good()); +} + +} // namespace + +int main(int argc, char** argv) { + { + auto args = absl::ParseCommandLine(argc, argv); + ABSL_CHECK(args.empty() || args.size() == 1) + << "unexpected positional args: " << absl::StrJoin(args, ", "); + } + absl::InitializeLog(); + + auto in_buffer = ReadFile(absl::GetFlag(FLAGS_in)); + std::string out_buffer; + out_buffer.reserve(in_buffer.size() * 6); + for (const auto& in_byte : in_buffer) { + absl::StrAppend(&out_buffer, "0x", + absl::Hex(in_byte, absl::PadSpec::kZeroPad2), ", "); + } + if (!in_buffer.empty()) { + // Replace last space with newline. + out_buffer.back() = '\n'; + } + WriteFile(absl::GetFlag(FLAGS_out), out_buffer); + + return EXIT_SUCCESS; +} diff --git a/bazel/cel_proto_transitive_descriptor_set.bzl b/bazel/cel_proto_transitive_descriptor_set.bzl new file mode 100644 index 000000000..1b735fe59 --- /dev/null +++ b/bazel/cel_proto_transitive_descriptor_set.bzl @@ -0,0 +1,54 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Provides the `cel_proto_transitive_descriptor_set` build rule. +""" + +load("@com_google_protobuf//bazel/common:proto_info.bzl", "ProtoInfo") + +def _cel_proto_transitive_descriptor_set(ctx): + output = ctx.actions.declare_file(ctx.attr.name + ".binarypb") + transitive_descriptor_sets = depset(transitive = [dep[ProtoInfo].transitive_descriptor_sets for dep in ctx.attr.deps]) + args = ctx.actions.args() + args.use_param_file(param_file_arg = "%s", use_always = True) + args.add_all(transitive_descriptor_sets) + ctx.actions.run( + mnemonic = "CelProtoTransitiveDescriptorSet", + outputs = [output], + inputs = transitive_descriptor_sets, + progress_message = "Joining descriptors.", + executable = ctx.executable.cat_tool, + arguments = [args] + [output.path], + ) + return DefaultInfo( + files = depset([output]), + runfiles = ctx.runfiles(files = [output]), + ) + +cel_proto_transitive_descriptor_set = rule( + attrs = { + "deps": attr.label_list(providers = [[ProtoInfo]]), + "cat_tool": attr.label( + executable = True, + cfg = "exec", + allow_files = True, + default = Label("//bazel:cat_param_file"), + ), + }, + outputs = { + "out": "%{name}.binarypb", + }, + implementation = _cel_proto_transitive_descriptor_set, +) diff --git a/bazel/deps.bzl b/bazel/deps.bzl new file mode 100644 index 000000000..477eb2c6d --- /dev/null +++ b/bazel/deps.bzl @@ -0,0 +1,246 @@ +""" +Legacy workspace dependencies of cel-cpp. + +Dependencies are now managed by MODULE.bazel. The values here are not updated, but this file is +retained for clients that referenced it directly. +""" + +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive", "http_jar") + +def base_deps(): + """Base evaluator and test dependencies.""" + + # Abseil LTS 20240722.0 + ABSL_SHA1 = "4447c7562e3bc702ade25105912dce503f0c4010" + ABSL_SHA256 = "d8342ad77aa9e16103c486b615460c24a695a1f04cdb760eb02fef780df99759" + http_archive( + name = "com_google_absl", + urls = ["https://github.com/abseil/abseil-cpp/archive/" + ABSL_SHA1 + ".zip"], + strip_prefix = "abseil-cpp-" + ABSL_SHA1, + sha256 = ABSL_SHA256, + ) + + # v1.15.2 + GOOGLETEST_SHA1 = "b514bdc898e2951020cbdca1304b75f5950d1f59" + GOOGLETEST_SHA256 = "8c0ceafa3ea24bf78e3519b7846d99e76c45899aa4dac4d64e7dd62e495de9fd" + http_archive( + name = "com_google_googletest", + urls = ["https://github.com/google/googletest/archive/" + GOOGLETEST_SHA1 + ".zip"], + strip_prefix = "googletest-" + GOOGLETEST_SHA1, + sha256 = GOOGLETEST_SHA256, + ) + + # v1.6.0 + BENCHMARK_SHA1 = "f91b6b42b1b9854772a90ae9501464a161707d1e" + BENCHMARK_SHA256 = "00bd0837db9266c758a087cdf0831a0d3e337c6bb9e3fad75d2be4f9bf480d95" + http_archive( + name = "com_github_google_benchmark", + urls = ["https://github.com/google/benchmark/archive/" + BENCHMARK_SHA1 + ".zip"], + strip_prefix = "benchmark-" + BENCHMARK_SHA1, + sha256 = BENCHMARK_SHA256, + ) + + # 2024-02-01 + RE2_SHA1 = "9665465b69ab699279ef9fb9454559d90fed1d76" + RE2_SHA256 = "dcd82922c7a1d3b7c2a147c045585a9f76066f9c0269a06b857eccbbf6f96dba" + http_archive( + name = "com_googlesource_code_re2", + urls = ["https://github.com/google/re2/archive/" + RE2_SHA1 + ".zip"], + strip_prefix = "re2-" + RE2_SHA1, + sha256 = RE2_SHA256, + ) + + # v28.0 + PROTOBUF_SHA1 = "439c42c735ae1efed57ab7771986f2a3c0b99319" + PROTOBUF_SHA256 = "495b76871df8d102e5c539f9d43f990f5ca53ac183702f5ed90070ba8c8759d1" + http_archive( + name = "com_google_protobuf", + sha256 = PROTOBUF_SHA256, + strip_prefix = "protobuf-" + PROTOBUF_SHA1, + urls = ["https://github.com/protocolbuffers/protobuf/archive/" + PROTOBUF_SHA1 + ".zip"], + ) + + GOOGLEAPIS_GIT_SHA = "6eb56cdf5f54f70d0dbfce051add28a35c1203ce" # June 26, 2024 + GOOGLEAPIS_SHA = "6321a7eac9e5280e7abca07ddf2cab9179cbd49a6828c26f4c7c73d5a45f39ad" + http_archive( + name = "com_google_googleapis", + sha256 = GOOGLEAPIS_SHA, + strip_prefix = "googleapis-" + GOOGLEAPIS_GIT_SHA, + urls = ["https://github.com/googleapis/googleapis/archive/" + GOOGLEAPIS_GIT_SHA + ".tar.gz"], + ) + + http_archive( + name = "rules_cc", + urls = ["https://github.com/bazelbuild/rules_cc/releases/download/0.0.10-rc1/rules_cc-0.0.10-rc1.tar.gz"], + sha256 = "d75a040c32954da0d308d3f2ea2ba735490f49b3a7aa3e4b40259ca4b814f825", + ) + + http_archive( + name = "rules_proto", + sha256 = "6fb6767d1bef535310547e03247f7518b03487740c11b6c6adb7952033fe1295", + strip_prefix = "rules_proto-6.0.2", + url = "https://github.com/bazelbuild/rules_proto/releases/download/6.0.2/rules_proto-6.0.2.tar.gz", + ) + +def parser_deps(): + """ANTLR dependency for the parser.""" + + # Sept 4, 2023 + ANTLR4_VERSION = "4.13.1" + + http_archive( + name = "antlr4_runtimes", + build_file_content = """ +package(default_visibility = ["//visibility:public"]) +cc_library( + name = "cpp", + srcs = glob(["runtime/Cpp/runtime/src/**/*.cpp"]), + hdrs = glob(["runtime/Cpp/runtime/src/**/*.h"]), + defines = ["ANTLR4CPP_USING_ABSEIL"], + includes = ["runtime/Cpp/runtime/src"], + deps = [ + "@com_google_absl//absl/base", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/synchronization", + ], +) + """, + sha256 = "365ff6aec0b1612fb964a763ca73748d80e0b3379cbdd9f82d86333eb8ae4638", + strip_prefix = "antlr4-" + ANTLR4_VERSION, + urls = ["https://github.com/antlr/antlr4/archive/refs/tags/" + ANTLR4_VERSION + ".zip"], + ) + http_jar( + name = "antlr4_jar", + urls = ["https://www.antlr.org/download/antlr-" + ANTLR4_VERSION + "-complete.jar"], + sha256 = "bc13a9c57a8dd7d5196888211e5ede657cb64a3ce968608697e4f668251a8487", + ) + +def flatbuffers_deps(): + """FlatBuffers support.""" + FLAT_BUFFERS_SHA = "a83caf5910644ba1c421c002ef68e42f21c15f9f" + http_archive( + name = "com_github_google_flatbuffers", + sha256 = "b8efbc25721e76780752bad775a97c3f77a0250271e2db37fc747b20e8b0f24a", + strip_prefix = "flatbuffers-" + FLAT_BUFFERS_SHA, + url = "https://github.com/google/flatbuffers/archive/" + FLAT_BUFFERS_SHA + ".tar.gz", + ) + +def cel_spec_deps(): + """CEL Spec conformance testing.""" + http_archive( + name = "io_bazel_rules_go", + sha256 = "b2038e2de2cace18f032249cb4bb0048abf583a36369fa98f687af1b3f880b26", + urls = [ + "https://mirror.bazel.build/github.com/bazelbuild/rules_go/releases/download/v0.48.1/rules_go-v0.48.1.zip", + "https://github.com/bazelbuild/rules_go/releases/download/v0.48.1/rules_go-v0.48.1.zip", + ], + ) + + http_archive( + name = "rules_python", + sha256 = "e3f1cc7a04d9b09635afb3130731ed82b5f58eadc8233d4efb59944d92ffc06f", + strip_prefix = "rules_python-0.33.2", + url = "https://github.com/bazelbuild/rules_python/releases/download/0.33.2/rules_python-0.33.2.tar.gz", + ) + + CEL_SPEC_GIT_SHA = "afa18f9bd5a83f5960ca06c1f9faea406ab34ccc" # Dec 2, 2024 + http_archive( + name = "com_google_cel_spec", + sha256 = "19b4084ba33cc8da7a640d999e46731efbec585ad2995951dc61a7af24f059cb", + strip_prefix = "cel-spec-" + CEL_SPEC_GIT_SHA, + urls = ["https://github.com/google/cel-spec/archive/" + CEL_SPEC_GIT_SHA + ".zip"], + ) + +_ICU4C_VERSION_MAJOR = "76" +_ICU4C_VERSION_MINOR = "1" +_ICU4C_BUILD = """ +load("@rules_foreign_cc//foreign_cc:configure.bzl", "configure_make") + +filegroup( + name = "all", + srcs = glob(["**"]), + visibility = ["//visibility:private"], +) + +config_setting( + name = "dbg", + values = {{ + "compilation_mode": "dbg", + }}, + visibility = ["//visibility:private"], +) + +configure_make( + name = "icu4c", + configure_command = "source/configure", + configure_in_place = True, + configure_options = [ + "--enable-shared", + "--enable-static", + "--disable-extras", + "--disable-icuio", + "--disable-layoutex", + "--disable-icu-config", + ] + select({{ + ":dbg": ["--enable-debug"], + "//conditions:default": [], + }}), + lib_source = ":all", + out_shared_libs = [ + "libicudata.so", + "libicudata.so.{version_major}", + "libicudata.so.{version_major}.{version_minor}", + "libicui18n.so", + "libicui18n.so.{version_major}", + "libicui18n.so.{version_major}.{version_minor}", + "libicutu.so", + "libicutu.so.{version_major}", + "libicutu.so.{version_major}.{version_minor}", + "libicuuc.so", + "libicuuc.so.{version_major}", + "libicuuc.so.{version_major}.{version_minor}", + ], + out_static_libs = [ + "libicudata.a", + "libicui18n.a", + "libicutu.a", + "libicuuc.a", + ], + args = ["-j 8"], + visibility = ["//visibility:public"], +) +""".format(version_major = _ICU4C_VERSION_MAJOR, version_minor = _ICU4C_VERSION_MINOR) + +def cel_cpp_extensions_deps(): + http_archive( + name = "rules_foreign_cc", + sha256 = "8e5605dc2d16a4229cb8fbe398514b10528553ed4f5f7737b663fdd92f48e1c2", + strip_prefix = "rules_foreign_cc-0.13.0", + url = "https://github.com/bazel-contrib/rules_foreign_cc/releases/download/0.13.0/rules_foreign_cc-0.13.0.tar.gz", + ) + http_archive( + name = "icu4c", + sha256 = "dfacb46bfe4747410472ce3e1144bf28a102feeaa4e3875bac9b4c6cf30f4f3e", + url = "https://github.com/unicode-org/icu/releases/download/release-{version_major}-{version_minor}/icu4c-{version_major}_{version_minor}-src.tgz".format(version_major = _ICU4C_VERSION_MAJOR, version_minor = _ICU4C_VERSION_MINOR), + strip_prefix = "icu", + patch_cmds = [ + "rm -f source/common/BUILD.bazel", + "rm -f source/i18n/BUILD.bazel", + "rm -f source/stubdata/BUILD.bazel", + "rm -f source/tools/gennorm2/BUILD.bazel", + "rm -f source/tools/toolutil/BUILD.bazel", + "rm -f source/tools/unicode/c/genprops/BUILD.bazel", + "rm -f source/tools/unicode/c/genuca/BUILD.bazel", + "rm -f source/vendor/double-conversion/upstream/WORKSPACE", + ], + build_file_content = _ICU4C_BUILD, + ) + +def cel_cpp_deps(): + """All core dependencies of cel-cpp.""" + base_deps() + parser_deps() + flatbuffers_deps() + cel_spec_deps() diff --git a/checker/BUILD b/checker/BUILD new file mode 100644 index 000000000..7f3ccfef7 --- /dev/null +++ b/checker/BUILD @@ -0,0 +1,256 @@ +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "checker_options", + hdrs = ["checker_options.h"], +) + +cc_library( + name = "type_check_issue", + srcs = ["type_check_issue.cc"], + hdrs = ["type_check_issue.h"], + deps = [ + "//common:source", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + ], +) + +cc_test( + name = "type_check_issue_test", + srcs = ["type_check_issue_test.cc"], + deps = [ + ":type_check_issue", + "//common:source", + "//internal:testing", + ], +) + +cc_library( + name = "validation_result", + srcs = ["validation_result.cc"], + hdrs = ["validation_result.h"], + deps = [ + ":type_check_issue", + "//common:ast", + "//common:decl", + "//common:source", + "//common:type", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + +cc_test( + name = "validation_result_test", + srcs = ["validation_result_test.cc"], + deps = [ + ":type_check_issue", + ":validation_result", + "//common:ast", + "//common:source", + "//internal:testing", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + ], +) + +cc_library( + name = "type_checker", + srcs = ["type_checker.cc"], + hdrs = ["type_checker.h"], + deps = [ + ":validation_result", + "//common:ast", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/status:statusor", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "type_checker_builder", + hdrs = ["type_checker_builder.h"], + deps = [ + ":checker_options", + ":type_checker", + "//common:container", + "//common:decl", + "//common:type", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "type_checker_builder_factory", + srcs = ["type_checker_builder_factory.cc"], + hdrs = ["type_checker_builder_factory.h"], + deps = [ + ":checker_options", + ":type_checker_builder", + "//checker/internal:type_checker_impl", + "//internal:noop_delete", + "//internal:status_macros", + "//internal:well_known_types", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status:statusor", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "type_checker_builder_factory_test", + srcs = ["type_checker_builder_factory_test.cc"], + deps = [ + ":checker_options", + ":optional", + ":standard_library", + ":type_checker", + ":type_checker_builder", + ":type_checker_builder_factory", + ":validation_result", + "//checker/internal:test_ast_helpers", + "//common:ast", + "//common:decl", + "//common:type", + "//internal:status_macros", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/strings:string_view", + ], +) + +cc_library( + name = "standard_library", + srcs = ["standard_library.cc"], + hdrs = ["standard_library.h"], + deps = [ + ":type_checker_builder", + "//checker/internal:builtins_arena", + "//common:constant", + "//common:decl", + "//common:standard_definitions", + "//common:type", + "//internal:status_macros", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/status", + ], +) + +cc_test( + name = "standard_library_test", + srcs = ["standard_library_test.cc"], + deps = [ + ":checker_options", + ":standard_library", + ":type_checker", + ":type_checker_builder", + ":type_checker_builder_factory", + ":validation_result", + "//checker/internal:test_ast_helpers", + "//common:ast", + "//common:constant", + "//common:decl", + "//common:type", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "optional", + srcs = ["optional.cc"], + hdrs = ["optional.h"], + deps = [ + ":type_checker_builder", + "//base:builtins", + "//checker/internal:builtins_arena", + "//common:decl", + "//common:type", + "//internal:status_macros", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/status", + ], +) + +cc_test( + name = "optional_test", + srcs = ["optional_test.cc"], + deps = [ + ":checker_options", + ":optional", + ":standard_library", + ":type_check_issue", + ":type_checker", + ":type_checker_builder", + ":type_checker_builder_factory", + "//checker/internal:test_ast_helpers", + "//common:ast", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "type_checker_subset_factory", + srcs = ["type_checker_subset_factory.cc"], + hdrs = ["type_checker_subset_factory.h"], + deps = [ + ":type_checker_builder", + "//common:decl", + "//common:signature", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + ], +) + +cc_test( + name = "type_checker_subset_factory_test", + srcs = ["type_checker_subset_factory_test.cc"], + deps = [ + ":type_checker_subset_factory", + ":validation_result", + "//common:standard_definitions", + "//compiler", + "//compiler:compiler_factory", + "//compiler:standard_library", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/strings:string_view", + ], +) diff --git a/checker/checker_options.h b/checker/checker_options.h new file mode 100644 index 000000000..cb85337fa --- /dev/null +++ b/checker/checker_options.h @@ -0,0 +1,110 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_CHECKER_CHECKER_OPTIONS_H_ +#define THIRD_PARTY_CEL_CPP_CHECKER_CHECKER_OPTIONS_H_ + +namespace cel { + +// Options for enabling core type checker features. +struct CheckerOptions { + // Enable overloads for numeric comparisons across types. + // For example, 1.0 < 2 will resolve to lt_double_int. + // + // By default, this is disabled and expressions must explicitly cast to dyn or + // the same type to compare. + bool enable_cross_numeric_comparisons = false; + + // Enable legacy behavior for null assignment. + // + // Historically, CEL has allowed null to be assigned to structs, abstract + // types, durations, timestamps, and any types. This is inconsistent with + // CEL's usual interpretation of null as a literal JSON null. + // + // TODO(uncreated-issue/75): Need a concrete plan for updating existing CEL + // expressions that depend on the old behavior. + bool enable_legacy_null_assignment = true; + + // Enable updating parsed struct type names to the fully qualified type name + // when resolved. + // + // Enabled by default, but can be disabled to preserve the original type name + // as parsed. + bool update_struct_type_names = true; + + // Temporary flag to enable type parameter name validation. + // + // When enabled, the TypeCheckerBuilder will validate that type parameter + // names are simple identifiers when declared. + bool enable_type_parameter_name_validation = true; + + // Well-known types defined by protobuf are treated specially in CEL, and + // generally don't behave like other messages as runtime values. When used as + // context declarations, this introduces some ambiguity about the intended + // types of the field declarations, so it is disallowed by default. + // + // When enabled, the well-known types are treated like a normal message type + // for the purposes for declaring context bindings (i.e no unpacking or + // adapting), and use the Descriptor that is assumed by CEL. + // + // E.g. for google.protobuf.Any, the type checker will add a context binding + // with `type_url: string` and `value: bytes` as top level variables. + bool allow_well_known_type_context_declarations = false; + + // Maximum number (inclusive) of expression nodes to check for an input + // expression. + // + // If exceeded, the checker should return a status with code InvalidArgument. + int max_expression_node_count = 100000; + + // Maximum number (inclusive) of error-level issues to tolerate for an input + // ast. + // + // If exceeded, the checker will stop processing the ast and return + // the current set of issues. + int max_error_issues = 20; + + // Maximum amount of nesting allowed for type declarations in function + // signatures and variable declarations. + // + // If exceeded, the TypeCheckerBuilder will report an error when adding the + // declaration. + // + // For untrusted declarations, the caller should set a lower limit to mitigate + // expressions that compound nesting e.g. + // type5(T)->type(type(type(type(type(T)))))); type5(type5(T)) -> type10(T) + int max_type_decl_nesting = 13; + + // If true, the checker will include the resolved function name in the + // reference map for the function call expr. + // + // If false, the function name will be empty and implied by the overload id + // set. This matches the behavior in cel-go and cel-java. + // + // Temporary flag to allow rolling out the change. No functional changes to + // evaluation behavior in either mode. + bool enable_function_name_in_reference = true; + + // If true, the checker will use the proto json field names for protobuf + // messages. Unlike protojson parsers, it will not accept the standard proto + // field names as valid json field names. + // + // Note: The checked AST will contain the json field names and an extension + // tag, but will require runtime support for resolving the json field names. + bool use_json_field_names = false; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_CHECKER_CHECKER_OPTIONS_H_ diff --git a/checker/internal/BUILD b/checker/internal/BUILD new file mode 100644 index 000000000..20c476db2 --- /dev/null +++ b/checker/internal/BUILD @@ -0,0 +1,401 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +package( + # Implementation details for the checker library. + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "test_ast_helpers", + testonly = 1, + srcs = ["test_ast_helpers.cc"], + hdrs = ["test_ast_helpers.h"], + deps = [ + "//common:ast", + "//internal:status_macros", + "//parser", + "//parser:options", + "//parser:parser_interface", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + ], +) + +cc_test( + name = "test_ast_helpers_test", + srcs = ["test_ast_helpers_test.cc"], + deps = [ + ":test_ast_helpers", + "//common:ast", + "//internal:testing", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + ], +) + +cc_library( + name = "builtins_arena", + srcs = ["builtins_arena.cc"], + hdrs = ["builtins_arena.h"], + deps = [ + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/base:nullability", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "type_check_env", + srcs = ["type_check_env.cc"], + hdrs = ["type_check_env.h"], + deps = [ + ":descriptor_pool_type_introspector", + ":proto_type_mask", + ":proto_type_mask_registry", + "//common:constant", + "//common:container", + "//common:decl", + "//common:type", + "//internal:status_macros", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "namespace_generator", + srcs = ["namespace_generator.cc"], + hdrs = ["namespace_generator.h"], + deps = [ + "//common:container", + "//internal:lexis", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + ], +) + +cc_test( + name = "namespace_generator_test", + srcs = ["namespace_generator_test.cc"], + deps = [ + ":namespace_generator", + "//common:container", + "//internal:testing", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", + ], +) + +cc_library( + name = "type_checker_impl", + srcs = [ + "type_checker_builder_impl.cc", + "type_checker_impl.cc", + ], + hdrs = [ + "type_checker_builder_impl.h", + "type_checker_impl.h", + ], + deps = [ + ":namespace_generator", + ":proto_type_mask", + ":type_check_env", + ":type_inference_context", + "//checker:checker_options", + "//checker:type_check_issue", + "//checker:type_checker", + "//checker:type_checker_builder", + "//checker:validation_result", + "//common:ast", + "//common:ast_rewrite", + "//common:ast_traverse", + "//common:ast_visitor", + "//common:ast_visitor_base", + "//common:constant", + "//common:container", + "//common:decl", + "//common:expr", + "//common:format_type_name", + "//common:standard_definitions", + "//common:type", + "//common:type_kind", + "//internal:lexis", + "//internal:status_macros", + "//parser:macro", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "type_checker_impl_test", + srcs = ["type_checker_impl_test.cc"], + deps = [ + ":test_ast_helpers", + ":type_check_env", + ":type_checker_impl", + "//checker:checker_options", + "//checker:type_check_issue", + "//checker:type_checker_builder", + "//checker:validation_result", + "//common:ast", + "//common:ast_proto", + "//common:container", + "//common:decl", + "//common:expr", + "//common:source", + "//common:type", + "//internal:status_macros", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//parser", + "//parser:macro_registry", + "//testutil:baseline_tests", + "//testutil:test_macros", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_cel_spec//proto/cel/expr/conformance/proto2:test_all_types_cc_proto", + "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "type_checker_builder_impl_test", + srcs = ["type_checker_builder_impl_test.cc"], + deps = [ + ":test_ast_helpers", + ":type_checker_impl", + "//checker:checker_options", + "//checker:type_checker", + "//checker:type_checker_builder", + "//checker:validation_result", + "//common:ast", + "//common:decl", + "//common:type", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:optional", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "type_inference_context", + srcs = ["type_inference_context.cc"], + hdrs = ["type_inference_context.h"], + deps = [ + "//common:decl", + "//common:format_type_name", + "//common:standard_definitions", + "//common:type", + "//common:type_kind", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "type_inference_context_test", + srcs = ["type_inference_context_test.cc"], + deps = [ + ":type_inference_context", + "//common:decl", + "//common:type", + "//common:type_kind", + "//internal:testing", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/types:optional", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "descriptor_pool_type_introspector", + srcs = ["descriptor_pool_type_introspector.cc"], + hdrs = ["descriptor_pool_type_introspector.h"], + deps = [ + "//common:type", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "descriptor_pool_type_introspector_test", + srcs = ["descriptor_pool_type_introspector_test.cc"], + deps = [ + ":descriptor_pool_type_introspector", + "//common:type", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:optional", + ], +) + +cc_library( + name = "field_path", + srcs = ["field_path.cc"], + hdrs = ["field_path.h"], + deps = [ + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + ], +) + +cc_test( + name = "field_path_test", + srcs = ["field_path_test.cc"], + deps = [ + ":field_path", + "//internal:testing", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "proto_type_mask", + srcs = ["proto_type_mask.cc"], + hdrs = ["proto_type_mask.h"], + deps = [ + ":field_path", + "//internal:status_macros", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "proto_type_mask_test", + srcs = ["proto_type_mask_test.cc"], + deps = [ + ":field_path", + ":proto_type_mask", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "proto_type_mask_registry", + srcs = ["proto_type_mask_registry.cc"], + hdrs = ["proto_type_mask_registry.h"], + deps = [ + ":field_path", + ":proto_type_mask", + "//common:type", + "//internal:status_macros", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "proto_type_mask_registry_test", + srcs = ["proto_type_mask_registry_test.cc"], + deps = [ + ":proto_type_mask", + ":proto_type_mask_registry", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + ], +) diff --git a/checker/internal/builtins_arena.cc b/checker/internal/builtins_arena.cc new file mode 100644 index 000000000..7a9d1ba6d --- /dev/null +++ b/checker/internal/builtins_arena.cc @@ -0,0 +1,28 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/internal/builtins_arena.h" + +#include "absl/base/no_destructor.h" +#include "absl/base/nullability.h" +#include "google/protobuf/arena.h" + +namespace cel::checker_internal { + +google::protobuf::Arena* absl_nonnull BuiltinsArena() { + static absl::NoDestructor kArena; + return &(*kArena); +} + +} // namespace cel::checker_internal diff --git a/checker/internal/builtins_arena.h b/checker/internal/builtins_arena.h new file mode 100644 index 000000000..333e09d68 --- /dev/null +++ b/checker/internal/builtins_arena.h @@ -0,0 +1,29 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_BUILTINS_ARENA_H_ +#define THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_BUILTINS_ARENA_H_ + +#include "absl/base/nullability.h" +#include "google/protobuf/arena.h" + +namespace cel::checker_internal { + +// Shared arena for builtin types that are shared across all type checker +// instances. +google::protobuf::Arena* absl_nonnull BuiltinsArena(); + +} // namespace cel::checker_internal + +#endif // THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_BUILTINS_ARENA_H_ diff --git a/checker/internal/descriptor_pool_type_introspector.cc b/checker/internal/descriptor_pool_type_introspector.cc new file mode 100644 index 000000000..733e4a3cb --- /dev/null +++ b/checker/internal/descriptor_pool_type_introspector.cc @@ -0,0 +1,245 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/internal/descriptor_pool_type_introspector.h" + +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" +#include "absl/log/absl_check.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "common/type.h" +#include "common/type_introspector.h" +#include "google/protobuf/descriptor.h" + +namespace cel::checker_internal { +namespace { + +// Standard implementation for field lookups. +// Avoids building a FieldTable and just checks the DescriptorPool directly. +absl::StatusOr> +FindStructTypeFieldByNameDirectly( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + absl::string_view type, absl::string_view name) { + const google::protobuf::Descriptor* absl_nullable descriptor = + descriptor_pool->FindMessageTypeByName(type); + if (descriptor == nullptr) { + return std::nullopt; + } + const google::protobuf::FieldDescriptor* absl_nullable field = + descriptor->FindFieldByName(name); + if (field != nullptr) { + return StructTypeField(MessageTypeField(field)); + } + + field = descriptor_pool->FindExtensionByPrintableName(descriptor, name); + if (field != nullptr) { + return StructTypeField(MessageTypeField(field)); + } + return std::nullopt; +} + +// Standard implementation for listing fields. +// Avoids building a FieldTable and just checks the DescriptorPool directly. +absl::StatusOr< + std::optional>> +ListStructTypeFieldsDirectly( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + absl::string_view type) { + const google::protobuf::Descriptor* absl_nullable descriptor = + descriptor_pool->FindMessageTypeByName(type); + if (descriptor == nullptr) { + return std::nullopt; + } + + std::vector extensions; + descriptor_pool->FindAllExtensions(descriptor, &extensions); + + std::vector fields; + fields.reserve(descriptor->field_count() + extensions.size()); + + for (int i = 0; i < descriptor->field_count(); ++i) { + const google::protobuf::FieldDescriptor* field = descriptor->field(i); + fields.push_back({field->name(), StructTypeField(MessageTypeField(field))}); + } + + return fields; +} + +} // namespace + +using Field = DescriptorPoolTypeIntrospector::Field; + +absl::StatusOr> +DescriptorPoolTypeIntrospector::FindTypeImpl(absl::string_view name) const { + const google::protobuf::Descriptor* absl_nullable descriptor = + descriptor_pool_->FindMessageTypeByName(name); + if (descriptor != nullptr) { + return Type::Message(descriptor); + } + const google::protobuf::EnumDescriptor* absl_nullable enum_descriptor = + descriptor_pool_->FindEnumTypeByName(name); + if (enum_descriptor != nullptr) { + return Type::Enum(enum_descriptor); + } + return std::nullopt; +} + +absl::StatusOr> +DescriptorPoolTypeIntrospector::FindEnumConstantImpl( + absl::string_view type, absl::string_view value) const { + const google::protobuf::EnumDescriptor* absl_nullable enum_descriptor = + descriptor_pool_->FindEnumTypeByName(type); + if (enum_descriptor != nullptr) { + const google::protobuf::EnumValueDescriptor* absl_nullable enum_value_descriptor = + enum_descriptor->FindValueByName(value); + if (enum_value_descriptor == nullptr) { + return std::nullopt; + } + return EnumConstant{ + .type = Type::Enum(enum_descriptor), + .type_full_name = enum_descriptor->full_name(), + .value_name = enum_value_descriptor->name(), + .number = enum_value_descriptor->number(), + }; + } + return std::nullopt; +} + +absl::StatusOr> +DescriptorPoolTypeIntrospector::FindStructTypeFieldByNameImpl( + absl::string_view type, absl::string_view name) const { + if (!use_json_name_) { + return FindStructTypeFieldByNameDirectly(descriptor_pool_, type, name); + } + + const FieldTable* field_table = GetFieldTable(type); + + if (field_table == nullptr) { + return std::nullopt; + } + + if (auto it = field_table->json_name_map.find(name); + it != field_table->json_name_map.end()) { + return field_table->fields[it->second].field; + } + + if (auto it = field_table->extension_name_map.find(name); + it != field_table->extension_name_map.end()) { + return field_table->fields[it->second].field; + } + + return std::nullopt; +} + +absl::StatusOr< + std::optional>> +DescriptorPoolTypeIntrospector::ListFieldsForStructTypeImpl( + absl::string_view type) const { + if (!use_json_name_) { + return ListStructTypeFieldsDirectly(descriptor_pool_, type); + } + + const FieldTable* field_table = GetFieldTable(type); + if (field_table == nullptr) { + return std::nullopt; + } + std::vector fields; + fields.reserve(field_table->non_extensions.size()); + for (const auto& field : field_table->non_extensions) { + fields.push_back({field.json_name, field.field}); + } + return fields; +} + +const DescriptorPoolTypeIntrospector::FieldTable* +DescriptorPoolTypeIntrospector::GetFieldTable( + absl::string_view type_name) const { + absl::MutexLock lock(mu_); + if (auto it = field_tables_.find(type_name); it != field_tables_.end()) { + return it->second.get(); + } + if (cel::IsWellKnownMessageType(type_name)) { + return nullptr; + } + const google::protobuf::Descriptor* absl_nullable descriptor = + descriptor_pool_->FindMessageTypeByName(type_name); + if (descriptor == nullptr) { + return nullptr; + } + absl::string_view stable_type_name = descriptor->full_name(); + ABSL_DCHECK(stable_type_name == type_name); + std::unique_ptr field_table = CreateFieldTable(descriptor); + const FieldTable* field_table_ptr = field_table.get(); + field_tables_[stable_type_name] = std::move(field_table); + return field_table_ptr; +} + +std::unique_ptr +DescriptorPoolTypeIntrospector::CreateFieldTable( + const google::protobuf::Descriptor* absl_nonnull descriptor) const { + ABSL_DCHECK(!IsWellKnownMessageType(descriptor)); + std::vector fields; + absl::flat_hash_map json_name_map; + absl::flat_hash_map field_name_map; + absl::flat_hash_map extension_name_map; + + std::vector extensions; + descriptor_pool_->FindAllExtensions(descriptor, &extensions); + fields.reserve(descriptor->field_count() + extensions.size()); + + for (int i = 0; i < descriptor->field_count(); i++) { + const google::protobuf::FieldDescriptor* field = descriptor->field(i); + fields.push_back(Field{ + .field = StructTypeField(MessageTypeField(field)), + .json_name = field->json_name(), + .is_extension = false, + }); + field_name_map[field->name()] = fields.size() - 1; + if (use_json_name_ && !field->json_name().empty()) { + json_name_map[field->json_name()] = fields.size() - 1; + } + } + int non_extension_count = fields.size(); + + for (const google::protobuf::FieldDescriptor* extension : extensions) { + fields.push_back(Field{ + .field = StructTypeField(MessageTypeField(extension)), + .json_name = "", + .is_extension = true, + }); + extension_name_map[extension->full_name()] = fields.size() - 1; + } + int extension_count = fields.size() - non_extension_count; + auto result = std::make_unique(); + result->descriptor = descriptor; + result->fields = std::move(fields); + result->non_extensions = + absl::MakeConstSpan(result->fields).subspan(0, non_extension_count); + result->extensions = absl::MakeConstSpan(result->fields) + .subspan(non_extension_count, extension_count); + result->json_name_map = std::move(json_name_map); + result->field_name_map = std::move(field_name_map); + result->extension_name_map = std::move(extension_name_map); + return result; +} + +} // namespace cel::checker_internal diff --git a/checker/internal/descriptor_pool_type_introspector.h b/checker/internal/descriptor_pool_type_introspector.h new file mode 100644 index 000000000..8a970ea00 --- /dev/null +++ b/checker/internal/descriptor_pool_type_introspector.h @@ -0,0 +1,105 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_DESCRIPTOR_POOL_TYPE_INTROSPECTOR_H_ +#define THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_DESCRIPTOR_POOL_TYPE_INTROSPECTOR_H_ + +#include +#include + +#include "absl/base/nullability.h" +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "common/type.h" +#include "common/type_introspector.h" +#include "google/protobuf/descriptor.h" + +namespace cel::checker_internal { + +// Implementation of `TypeIntrospector` that uses a `google::protobuf::DescriptorPool`. +// +// This is used by the type checker to resolve protobuf types and their fields +// and apply any options like using JSON names. +// +// Neither copyable nor movable. Should be managed by a TypeCheckEnv. +class DescriptorPoolTypeIntrospector : public TypeIntrospector { + public: + struct Field { + StructTypeField field; + absl::string_view json_name; + bool is_extension = false; + }; + + DescriptorPoolTypeIntrospector() = delete; + explicit DescriptorPoolTypeIntrospector( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool) + : descriptor_pool_(descriptor_pool) {} + + DescriptorPoolTypeIntrospector(const DescriptorPoolTypeIntrospector&) = + delete; + DescriptorPoolTypeIntrospector& operator=( + const DescriptorPoolTypeIntrospector&) = delete; + DescriptorPoolTypeIntrospector(DescriptorPoolTypeIntrospector&&) = delete; + DescriptorPoolTypeIntrospector& operator=(DescriptorPoolTypeIntrospector&&) = + delete; + + void set_use_json_name(bool use_json_name) { use_json_name_ = use_json_name; } + + bool use_json_name() const { return use_json_name_; } + + private: + struct FieldTable { + const google::protobuf::Descriptor* absl_nonnull descriptor; + std::vector fields; + absl::Span non_extensions; + absl::Span extensions; + absl::flat_hash_map json_name_map; + absl::flat_hash_map field_name_map; + absl::flat_hash_map extension_name_map; + }; + + absl::StatusOr> FindTypeImpl( + absl::string_view name) const final; + + absl::StatusOr> FindEnumConstantImpl( + absl::string_view type, absl::string_view value) const final; + + absl::StatusOr> FindStructTypeFieldByNameImpl( + absl::string_view type, absl::string_view name) const final; + + absl::StatusOr>> + ListFieldsForStructTypeImpl(absl::string_view type) const final; + + std::unique_ptr CreateFieldTable( + const google::protobuf::Descriptor* absl_nonnull descriptor) const; + + const FieldTable* GetFieldTable(absl::string_view type_name) const; + + // Cached map of type to field table. + mutable absl::flat_hash_map> + field_tables_ ABSL_GUARDED_BY(mu_); + + mutable absl::Mutex mu_; + bool use_json_name_ = false; + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool_; +}; + +} // namespace cel::checker_internal + +#endif // THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_DESCRIPTOR_POOL_TYPE_INTROSPECTOR_H_ diff --git a/checker/internal/descriptor_pool_type_introspector_test.cc b/checker/internal/descriptor_pool_type_introspector_test.cc new file mode 100644 index 000000000..db766b347 --- /dev/null +++ b/checker/internal/descriptor_pool_type_introspector_test.cc @@ -0,0 +1,175 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/internal/descriptor_pool_type_introspector.h" + +#include + +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/types/optional.h" +#include "common/type.h" +#include "common/type_introspector.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" + +namespace cel::checker_internal { +namespace { + +using ::absl_testing::IsOkAndHolds; +using ::testing::AllOf; +using ::testing::Contains; +using ::testing::Eq; +using ::testing::Not; +using ::testing::Optional; +using ::testing::Property; +using ::testing::SizeIs; +using ::testing::Truly; + +TEST(DescriptorPoolTypeIntrospectorTest, FindType) { + DescriptorPoolTypeIntrospector introspector( + internal::GetTestingDescriptorPool()); + + EXPECT_THAT(introspector.FindType("cel.expr.conformance.proto3.TestAllTypes"), + IsOkAndHolds(Optional(Property(&Type::IsMessage, true)))); + EXPECT_THAT(introspector.FindType( + "cel.expr.conformance.proto3.TestAllTypes.NestedEnum"), + IsOkAndHolds(Optional(Property(&Type::IsEnum, true)))); + EXPECT_THAT(introspector.FindType("non.existent.Type"), + IsOkAndHolds(Eq(std::nullopt))); +} + +TEST(DescriptorPoolTypeIntrospectorTest, FindEnumConstant) { + DescriptorPoolTypeIntrospector introspector( + internal::GetTestingDescriptorPool()); + + auto result = introspector.FindEnumConstant( + "cel.expr.conformance.proto3.TestAllTypes.NestedEnum", "FOO"); + ASSERT_THAT(result, IsOkAndHolds(Optional(AllOf( + Truly([](const TypeIntrospector::EnumConstant& v) { + return v.value_name == "FOO" && v.number == 0; + }))))); +} + +TEST(DescriptorPoolTypeIntrospectorTest, FindStructTypeFieldByName) { + DescriptorPoolTypeIntrospector introspector( + internal::GetTestingDescriptorPool()); + + auto field = introspector.FindStructTypeFieldByName( + "cel.expr.conformance.proto3.TestAllTypes", "single_int64"); + introspector.set_use_json_name(false); + + ASSERT_THAT(field, + IsOkAndHolds(Optional(Property(&StructTypeField::GetType, + Property(&Type::IsInt, true))))); +} + +TEST(DescriptorPoolTypeIntrospectorTest, + FindStructTypeFieldByNameJsonNameIgnored) { + DescriptorPoolTypeIntrospector introspector( + internal::GetTestingDescriptorPool()); + introspector.set_use_json_name(false); + + auto field = introspector.FindStructTypeFieldByName( + "cel.expr.conformance.proto3.TestAllTypes", "singleInt64"); + + EXPECT_THAT(field, IsOkAndHolds(Eq(std::nullopt))); +} + +TEST(DescriptorPoolTypeIntrospectorTest, FindExtension) { + DescriptorPoolTypeIntrospector introspector( + internal::GetTestingDescriptorPool()); + + auto field = introspector.FindStructTypeFieldByName( + "cel.expr.conformance.proto2.TestAllTypes", + "cel.expr.conformance.proto2.int32_ext"); + + ASSERT_THAT(field, + IsOkAndHolds(Optional(Property(&StructTypeField::GetType, + Property(&Type::IsInt, true))))); +} + +TEST(DescriptorPoolTypeIntrospectorTest, FindStructTypeFieldByNameWithJsonOpt) { + DescriptorPoolTypeIntrospector introspector( + internal::GetTestingDescriptorPool()); + introspector.set_use_json_name(true); + + auto field = introspector.FindStructTypeFieldByName( + "cel.expr.conformance.proto3.TestAllTypes", "single_int64"); + + ASSERT_THAT(field, IsOkAndHolds(Eq(std::nullopt))); +} + +TEST(DescriptorPoolTypeIntrospectorTest, + FindStructTypeFieldByNameWithJsonNameOpt) { + DescriptorPoolTypeIntrospector introspector( + internal::GetTestingDescriptorPool()); + introspector.set_use_json_name(true); + + absl::StatusOr> field = + introspector.FindStructTypeFieldByName( + "cel.expr.conformance.proto3.TestAllTypes", "singleInt64"); + + ASSERT_THAT(field, + IsOkAndHolds(Optional(Property(&StructTypeField::GetType, + Property(&Type::IsInt, true))))); +} + +MATCHER_P(FieldListingIs, field_name, "") { return arg.name == field_name; } + +TEST(DescriptorPoolTypeIntrospectorTest, ListFieldsForStructType) { + DescriptorPoolTypeIntrospector introspector( + internal::GetTestingDescriptorPool()); + absl::StatusOr< + std::optional>> + fields = introspector.ListFieldsForStructType( + "cel.expr.conformance.proto3.TestAllTypes"); + ASSERT_THAT(fields, IsOkAndHolds(Optional(SizeIs(260)))); + EXPECT_THAT(*fields, Optional(Contains(FieldListingIs("single_int64")))); +} + +TEST(DescriptorPoolTypeIntrospectorTest, ListFieldsForStructTypeExtensions) { + DescriptorPoolTypeIntrospector introspector( + internal::GetTestingDescriptorPool()); + auto fields = introspector.ListFieldsForStructType( + "cel.expr.conformance.proto2.TestAllTypes"); + ASSERT_THAT(fields, IsOkAndHolds(Optional(SizeIs(259)))); + EXPECT_THAT(**fields, Contains(FieldListingIs("single_int64"))); + EXPECT_THAT( + **fields, + Not(Contains(FieldListingIs("cel.expr.conformance.proto2.int32_ext")))); +} + +TEST(DescriptorPoolTypeIntrospectorTest, + ListFieldsForStructTypeWithJsonNameOpt) { + DescriptorPoolTypeIntrospector introspector( + internal::GetTestingDescriptorPool()); + introspector.set_use_json_name(true); + auto fields = introspector.ListFieldsForStructType( + "cel.expr.conformance.proto3.TestAllTypes"); + ASSERT_THAT(fields, IsOkAndHolds(Optional(SizeIs(260)))); + EXPECT_THAT(**fields, Contains(FieldListingIs("singleInt64"))); + EXPECT_THAT(**fields, Not(Contains(FieldListingIs("single_int64")))); +} + +TEST(DescriptorPoolTypeIntrospectorTest, ListFieldsForStructTypeNotFound) { + DescriptorPoolTypeIntrospector introspector( + internal::GetTestingDescriptorPool()); + auto fields = introspector.ListFieldsForStructType( + "cel.expr.conformance.proto3.SomeOtherType"); + EXPECT_THAT(fields, IsOkAndHolds(Eq(std::nullopt))); +} + +} // namespace +} // namespace cel::checker_internal diff --git a/checker/internal/field_path.cc b/checker/internal/field_path.cc new file mode 100644 index 000000000..5ecc4219b --- /dev/null +++ b/checker/internal/field_path.cc @@ -0,0 +1,30 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/internal/field_path.h" + +#include + +#include "absl/strings/str_join.h" +#include "absl/strings/substitute.h" + +namespace cel::checker_internal { + +std::string FieldPath::DebugString() const { + return absl::Substitute( + "FieldPath { field path: '$0', field selection: {'$1'} }", path_, + absl::StrJoin(field_selection_, "', '")); +} + +} // namespace cel::checker_internal diff --git a/checker/internal/field_path.h b/checker/internal/field_path.h new file mode 100644 index 000000000..d67d9b935 --- /dev/null +++ b/checker/internal/field_path.h @@ -0,0 +1,77 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_CHECKER_FIELD_PATH_H_ +#define THIRD_PARTY_CEL_CPP_CHECKER_FIELD_PATH_H_ + +#include +#include +#include + +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" + +namespace cel::checker_internal { + +// Represents a single path within a FieldMask. +class FieldPath { + public: + explicit FieldPath(std::string path) + : path_(std::move(path)), + field_selection_(absl::StrSplit(path_, kPathDelimiter)) {} + + // Returns the input path. + // For example: "f.b.d". + absl::string_view GetPath() const { return path_; } + + // Returns the list of nested field names in the path. + // For example: {"f", "b", "d"}. + absl::Span GetFieldSelection() const { + return field_selection_; + } + + // Returns the first field name in the path. + // For example: "f". + std::string GetFieldName() const { return field_selection_.front(); } + + template + friend void AbslStringify(Sink& sink, const FieldPath& field_path) { + sink.Append(field_path.DebugString()); + } + + private: + static constexpr char kPathDelimiter = '.'; + + std::string DebugString() const; + + // The input path. For example: "f.b.d". + std::string path_; + // The list of nested field names in the path. For example: {"f", "b", "d"}. + std::vector field_selection_; +}; + +inline bool operator==(const FieldPath& lhs, const FieldPath& rhs) { + return lhs.GetFieldSelection() == rhs.GetFieldSelection(); +} + +// Compares the field selections in the field paths. +// This is only intended as an arbitrary ordering for a set. +inline bool operator<(const FieldPath& lhs, const FieldPath& rhs) { + return lhs.GetFieldSelection() < rhs.GetFieldSelection(); +} + +} // namespace cel::checker_internal + +#endif // THIRD_PARTY_CEL_CPP_CHECKER_FIELD_PATH_H_ diff --git a/checker/internal/field_path_test.cc b/checker/internal/field_path_test.cc new file mode 100644 index 000000000..9a1434954 --- /dev/null +++ b/checker/internal/field_path_test.cc @@ -0,0 +1,85 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/internal/field_path.h" + +#include "absl/strings/str_cat.h" +#include "internal/testing.h" + +namespace cel::checker_internal { +namespace { + +using ::testing::ElementsAre; + +TEST(FieldPathTest, EmptyPathReturnsEmptyString) { + FieldPath field_path(""); + EXPECT_EQ(field_path.GetPath(), ""); + EXPECT_THAT(field_path.GetFieldSelection(), ElementsAre("")); + EXPECT_EQ(field_path.GetFieldName(), ""); +} + +TEST(FieldPathTest, DelimiterPathReturnsEmptyStrings) { + FieldPath field_path("."); + EXPECT_EQ(field_path.GetPath(), "."); + EXPECT_THAT(field_path.GetFieldSelection(), ElementsAre("", "")); + EXPECT_EQ(field_path.GetFieldName(), ""); +} + +TEST(FieldPathTest, FieldPathReturnsFields) { + FieldPath field_path("resource.name.other_field"); + EXPECT_EQ(field_path.GetPath(), "resource.name.other_field"); + EXPECT_THAT(field_path.GetFieldSelection(), + ElementsAre("resource", "name", "other_field")); + EXPECT_EQ(field_path.GetFieldName(), "resource"); +} + +TEST(FieldPathTest, AbslStringifyPrintsFieldSelection) { + FieldPath field_path("resource.name"); + EXPECT_EQ(absl::StrCat(field_path), + "FieldPath { field path: 'resource.name', field selection: " + "{'resource', 'name'} }"); +} + +TEST(FieldPathTest, EqualsComparesFieldSelectionAndReturnsTrue) { + FieldPath field_path_1("resource.name"); + FieldPath field_path_2("resource.name"); + EXPECT_TRUE(field_path_1 == field_path_2); +} + +TEST(FieldPathTest, EqualsComparesFieldSelectionAndReturnsFalse) { + FieldPath field_path_1("resource.name"); + FieldPath field_path_2("resource.type"); + EXPECT_FALSE(field_path_1 == field_path_2); +} + +TEST(FieldPathTest, LessThanComparesFieldSelectionAndReturnsTrue) { + FieldPath field_path_1("resource.name"); + FieldPath field_path_2("resource.type"); + EXPECT_TRUE(field_path_1 < field_path_2); +} + +TEST(FieldPathTest, LessThanComparesIdenticalFieldSelectionAndReturnsFalse) { + FieldPath field_path_1("resource.name"); + FieldPath field_path_2("resource.name"); + EXPECT_FALSE(field_path_1 < field_path_2); +} + +TEST(FieldPathTest, LessThanComparesFieldSelectionAndReturnsFalse) { + FieldPath field_path_1("resource.type"); + FieldPath field_path_2("resource.name"); + EXPECT_FALSE(field_path_1 < field_path_2); +} + +} // namespace +} // namespace cel::checker_internal diff --git a/checker/internal/namespace_generator.cc b/checker/internal/namespace_generator.cc new file mode 100644 index 000000000..7ab7628e4 --- /dev/null +++ b/checker/internal/namespace_generator.cc @@ -0,0 +1,186 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/internal/namespace_generator.h" + +#include +#include +#include +#include + +#include "absl/functional/function_ref.h" +#include "absl/log/absl_check.h" +#include "absl/status/statusor.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "common/container.h" +#include "internal/lexis.h" + +namespace cel::checker_internal { +namespace { + +bool FieldSelectInterpretationCandidatesImpl( + absl::string_view prefix, + absl::Span partly_qualified_name, bool prefix_is_alias, + absl::FunctionRef callback) { + for (int i = 0; i < partly_qualified_name.size(); ++i) { + std::string buf; + int count = partly_qualified_name.size() - i; + auto end_idx = count - (prefix_is_alias ? 0 : 1); + auto ident = absl::StrJoin(partly_qualified_name.subspan(0, count), "."); + absl::string_view candidate = ident; + if (absl::StartsWith(candidate, ".")) { + candidate = candidate.substr(1); + } + if (!prefix.empty()) { + buf = absl::StrCat(prefix, ".", candidate); + candidate = buf; + } + if (!callback(candidate, end_idx)) { + return false; + } + } + if (prefix_is_alias) { + return callback(prefix, 0); + } + return true; +} + +bool FieldSelectInterpretationCandidates( + absl::string_view prefix, + absl::Span partly_qualified_name, + absl::FunctionRef callback) { + return FieldSelectInterpretationCandidatesImpl( + prefix, partly_qualified_name, /*prefix_is_alias=*/false, callback); +} + +bool FieldSelectInterpretationCandidatesWithAlias( + absl::string_view prefix, + absl::Span partly_qualified_name, + absl::FunctionRef callback) { + return FieldSelectInterpretationCandidatesImpl( + prefix, partly_qualified_name, /*prefix_is_alias=*/true, callback); +} + +} // namespace + +absl::StatusOr NamespaceGenerator::Create( + const ExpressionContainer& expression_container) { + std::vector candidates; + + absl::string_view container = expression_container.container(); + if (container.empty()) { + return NamespaceGenerator(&expression_container, std::move(candidates)); + } + + std::string prefix; + for (auto segment : absl::StrSplit(container, '.')) { + // Assumes the the ExpressionContainer has already validated the container + // and aliases. + ABSL_DCHECK(internal::LexisIsIdentifier(segment)); + if (prefix.empty()) { + prefix = segment; + } else { + absl::StrAppend(&prefix, ".", segment); + } + candidates.push_back(prefix); + } + std::reverse(candidates.begin(), candidates.end()); + return NamespaceGenerator(&expression_container, std::move(candidates)); +} + +void NamespaceGenerator::GenerateCandidates( + absl::string_view simple_name, + absl::FunctionRef callback) const { + // Special case for root-relative names. Aliases still apply first. + bool is_root_relative = absl::StartsWith(simple_name, "."); + if (is_root_relative) { + simple_name = simple_name.substr(1); + } + + // The name is unqualified, but may include a namespace (struct creation). + // This is just a quirk of the parser. + if (auto dot_pos = simple_name.find('.'); + dot_pos != absl::string_view::npos) { + absl::string_view first_segment = simple_name.substr(0, dot_pos); + absl::string_view rest = simple_name.substr(dot_pos + 1); + if (auto resolved_alias = expression_container_->FindAlias(first_segment); + !resolved_alias.empty()) { + callback(absl::StrCat(resolved_alias, ".", rest)); + return; + } + } else { + if (auto resolved_alias = expression_container_->FindAlias(simple_name); + !resolved_alias.empty()) { + callback(resolved_alias); + return; + } + } + + if (is_root_relative) { + callback(simple_name); + return; + } + + for (const auto& prefix : candidates_) { + std::string candidate = absl::StrCat(prefix, ".", simple_name); + if (!callback(candidate)) { + return; + } + } + callback(simple_name); +} + +void NamespaceGenerator::GenerateCandidates( + absl::Span partly_qualified_name, + absl::FunctionRef callback) const { + if (partly_qualified_name.empty()) { + return; + } + + // Special case for root-relative names. Aliases still apply first. + absl::string_view first_segment = partly_qualified_name[0]; + bool is_root_relative = absl::StartsWith(first_segment, "."); + if (is_root_relative) { + first_segment = first_segment.substr(1); + } + + if (auto resolved_alias = expression_container_->FindAlias(first_segment); + !resolved_alias.empty()) { + FieldSelectInterpretationCandidatesWithAlias( + resolved_alias, partly_qualified_name.subspan(1), callback); + // If the alias matches, we don't check the container even if name + // resolution fails. + return; + } + + if (is_root_relative) { + FieldSelectInterpretationCandidates("", partly_qualified_name, callback); + return; + } + + for (const auto& prefix : candidates_) { + if (!FieldSelectInterpretationCandidates(prefix, partly_qualified_name, + callback)) { + return; + } + } + FieldSelectInterpretationCandidates("", partly_qualified_name, callback); +} + +} // namespace cel::checker_internal diff --git a/checker/internal/namespace_generator.h b/checker/internal/namespace_generator.h new file mode 100644 index 000000000..61cb1956b --- /dev/null +++ b/checker/internal/namespace_generator.h @@ -0,0 +1,120 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_NAMESPACE_GENERATOR_H_ +#define THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_NAMESPACE_GENERATOR_H_ + +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/functional/function_ref.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "common/container.h" + +namespace cel::checker_internal { + +// Utility class for generating namespace qualified candidates for reference +// resolution. +// +// This class is expected to be scoped to a single type checking operation and +// borrows the ExpressionContainer from the TypeCheckEnv. +class NamespaceGenerator { + public: + static absl::StatusOr Create( + const ExpressionContainer& expression_container + ABSL_ATTRIBUTE_LIFETIME_BOUND); + + // Copyable and movable. + NamespaceGenerator(const NamespaceGenerator&) = default; + NamespaceGenerator& operator=(const NamespaceGenerator&) = default; + NamespaceGenerator(NamespaceGenerator&&) = default; + NamespaceGenerator& operator=(NamespaceGenerator&&) = default; + + // For the simple case of an unqualified name, generate all qualified + // candidates and pass them to the provided callback. The callback may return + // false to terminate early. + // + // The supplied string_view is only valid for the duration of the callback + // invocation: the callback must handle copying the underlying string if the + // value needs to be persisted. + // + // Example: + // For container (com.google) + // and unqualified name foo + // + // com.google.foo, com.foo, foo + // + // If aliases are present, they override the normal container resolution. + // + // Example: + // container (com.google) + // alias (foo = com.example) + // unqualified name foo + // + // com.example + void GenerateCandidates( + absl::string_view simple_name, + absl::FunctionRef callback) const; + + // For a partially qualified name, generate all the qualified candidates in + // order of resolution precedence and pass them to the provided callback. The + // callback may return false to terminate early. + // + // The supplied string_view is only valid for the duration of the callback + // invocation: the callback must handle copying the underlying string if the + // value needs to be persisted. + // + // Example: + // For container (com.google) + // and partially qualified name Foo.bar + // + // (com.google.Foo.bar), + // (com.google.Foo).bar, + // (com.Foo.bar), + // (com.Foo).bar, + // (Foo.bar), + // (Foo).bar, + // + // If aliases are present, they override the normal container resolution. + // + // Example: + // container (com.google) + // alias (Foo = com.example.Foo) + // partially qualified name Foo.bar + // + // (com.example.Foo.bar), + // (com.example.Foo).bar, + void GenerateCandidates( + absl::Span partly_qualified_name, + absl::FunctionRef callback) const; + + private: + explicit NamespaceGenerator( + const ExpressionContainer* absl_nonnull expression_container, + std::vector candidates) + : candidates_(std::move(candidates)), + expression_container_(expression_container) {} + + // list of prefixes ordered from most qualified to least. + std::vector candidates_; + const ExpressionContainer* absl_nonnull expression_container_; +}; +} // namespace cel::checker_internal + +#endif // THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_NAMESPACE_GENERATOR_H_ diff --git a/checker/internal/namespace_generator_test.cc b/checker/internal/namespace_generator_test.cc new file mode 100644 index 000000000..ba9bb88a4 --- /dev/null +++ b/checker/internal/namespace_generator_test.cc @@ -0,0 +1,137 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/internal/namespace_generator.h" + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "common/container.h" +#include "internal/testing.h" + +namespace cel::checker_internal { +namespace { + +using ::absl_testing::IsOk; +using ::testing::ElementsAre; +using ::testing::Pair; + +TEST(NamespaceGeneratorTest, EmptyContainer) { + ExpressionContainer container; + ASSERT_OK_AND_ASSIGN(auto generator, NamespaceGenerator::Create(container)); + std::vector candidates; + generator.GenerateCandidates("foo", [&](absl::string_view candidate) { + candidates.push_back(std::string(candidate)); + return true; + }); + EXPECT_THAT(candidates, ElementsAre("foo")); +} + +TEST(NamespaceGeneratorTest, MultipleSegments) { + ExpressionContainer container; + ASSERT_THAT(container.SetContainer("com.example"), IsOk()); + ASSERT_OK_AND_ASSIGN(auto generator, NamespaceGenerator::Create(container)); + std::vector candidates; + generator.GenerateCandidates("foo", [&](absl::string_view candidate) { + candidates.push_back(std::string(candidate)); + return true; + }); + EXPECT_THAT(candidates, ElementsAre("com.example.foo", "com.foo", "foo")); +} + +TEST(NamespaceGeneratorTest, MultipleSegmentsRootNamespace) { + ExpressionContainer container; + ASSERT_THAT(container.SetContainer("com.example"), IsOk()); + ASSERT_OK_AND_ASSIGN(auto generator, NamespaceGenerator::Create(container)); + std::vector candidates; + generator.GenerateCandidates(".foo", [&](absl::string_view candidate) { + candidates.push_back(std::string(candidate)); + return true; + }); + EXPECT_THAT(candidates, ElementsAre("foo")); +} + +TEST(NamespaceGeneratorTest, MultipleSegmentsSelectInterpretation) { + ExpressionContainer container; + ASSERT_THAT(container.SetContainer("com.example"), IsOk()); + ASSERT_OK_AND_ASSIGN(auto generator, NamespaceGenerator::Create(container)); + std::vector qualified_ident = {"foo", "Bar"}; + std::vector> candidates; + generator.GenerateCandidates( + qualified_ident, [&](absl::string_view candidate, int segment_index) { + candidates.push_back(std::pair(std::string(candidate), segment_index)); + return true; + }); + EXPECT_THAT( + candidates, + ElementsAre(Pair("com.example.foo.Bar", 1), Pair("com.example.foo", 0), + Pair("com.foo.Bar", 1), Pair("com.foo", 0), + Pair("foo.Bar", 1), Pair("foo", 0))); +} + +TEST(NamespaceGeneratorTest, MultipleSegmentsSelectInterpretationAliasMatch) { + ExpressionContainer container; + ASSERT_THAT(container.SetContainer("com.example"), IsOk()); + ASSERT_THAT(container.AddAlias("foo", "bar.baz"), IsOk()); + ASSERT_OK_AND_ASSIGN(auto generator, NamespaceGenerator::Create(container)); + std::vector qualified_ident = {"foo", "Bar"}; + std::vector> candidates; + generator.GenerateCandidates( + qualified_ident, [&](absl::string_view candidate, int segment_index) { + candidates.push_back(std::pair(std::string(candidate), segment_index)); + return true; + }); + EXPECT_THAT(candidates, + ElementsAre(Pair("bar.baz.Bar", 1), Pair("bar.baz", 0))); +} + +TEST(NamespaceGeneratorTest, MultipleSegmentsSelectInterpretationAliasNoMatch) { + ExpressionContainer container; + ASSERT_THAT(container.SetContainer("com.example"), IsOk()); + ASSERT_THAT(container.AddAbbreviation("foo.Bar"), IsOk()); + ASSERT_OK_AND_ASSIGN(auto generator, NamespaceGenerator::Create(container)); + // No match on the alias (Bar) since it's not the first segment. + std::vector qualified_ident = {"foo", "Bar"}; + std::vector> candidates; + generator.GenerateCandidates( + qualified_ident, [&](absl::string_view candidate, int segment_index) { + candidates.push_back(std::pair(std::string(candidate), segment_index)); + return true; + }); + EXPECT_THAT( + candidates, + ElementsAre(Pair("com.example.foo.Bar", 1), Pair("com.example.foo", 0), + Pair("com.foo.Bar", 1), Pair("com.foo", 0), + Pair("foo.Bar", 1), Pair("foo", 0))); +} + +TEST(NamespaceGeneratorTest, + MultipleSegmentsSelectInterpretationRootNamespace) { + ExpressionContainer container; + ASSERT_THAT(container.SetContainer("com.example"), IsOk()); + ASSERT_OK_AND_ASSIGN(auto generator, NamespaceGenerator::Create(container)); + std::vector qualified_ident = {".foo", "Bar"}; + std::vector> candidates; + generator.GenerateCandidates( + qualified_ident, [&](absl::string_view candidate, int segment_index) { + candidates.push_back(std::pair(std::string(candidate), segment_index)); + return true; + }); + EXPECT_THAT(candidates, ElementsAre(Pair("foo.Bar", 1), Pair("foo", 0))); +} + +} // namespace +} // namespace cel::checker_internal diff --git a/checker/internal/proto_type_mask.cc b/checker/internal/proto_type_mask.cc new file mode 100644 index 000000000..85e39cb69 --- /dev/null +++ b/checker/internal/proto_type_mask.cc @@ -0,0 +1,87 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/internal/proto_type_mask.h" + +#include +#include + +#include "absl/base/nullability.h" +#include "absl/container/btree_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" +#include "checker/internal/field_path.h" +#include "internal/status_macros.h" +#include "google/protobuf/descriptor.h" + +namespace cel::checker_internal { + +using ::google::protobuf::Descriptor; +using ::google::protobuf::DescriptorPool; +using ::google::protobuf::FieldDescriptor; + +absl::StatusOr FindMessage( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + absl::string_view type_name) { + const Descriptor* descriptor = + descriptor_pool->FindMessageTypeByName(type_name); + if (descriptor == nullptr) { + return absl::InvalidArgumentError( + absl::Substitute("type '$0' not found", type_name)); + } + return descriptor; +} + +absl::StatusOr FindField(const Descriptor* descriptor, + absl::string_view field_name) { + const FieldDescriptor* field_descriptor = + descriptor->FindFieldByName(field_name); + if (field_descriptor == nullptr) { + return absl::InvalidArgumentError( + absl::Substitute("could not select field '$0' from type '$1'", + field_name, descriptor->full_name())); + } + return field_descriptor; +} + +absl::StatusOr> ProtoTypeMask::GetFieldNames( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool) const { + CEL_ASSIGN_OR_RETURN(const Descriptor* descriptor, + FindMessage(descriptor_pool, this->GetTypeName())); + absl::btree_set field_names; + for (const FieldPath& field_path : this->GetFieldPaths()) { + std::string field_name = field_path.GetFieldName(); + CEL_ASSIGN_OR_RETURN(const FieldDescriptor* field_descriptor, + FindField(descriptor, field_name)); + field_names.insert(field_descriptor->name()); + } + return field_names; +} + +std::string ProtoTypeMask::DebugString() const { + // Represent each FieldPath by its path because it is easiest to read. + std::vector paths; + paths.reserve(field_paths_.size()); + for (const FieldPath& field_path : field_paths_) { + paths.emplace_back(field_path.GetPath()); + } + return absl::Substitute( + "ProtoTypeMask { type name: '$0', field paths: { '$1' } }", type_name_, + absl::StrJoin(paths, "', '")); +} + +} // namespace cel::checker_internal diff --git a/checker/internal/proto_type_mask.h b/checker/internal/proto_type_mask.h new file mode 100644 index 000000000..f7d522cba --- /dev/null +++ b/checker/internal/proto_type_mask.h @@ -0,0 +1,111 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_CHECKER_PROTO_TYPE_MASK_H_ +#define THIRD_PARTY_CEL_CPP_CHECKER_PROTO_TYPE_MASK_H_ + +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/container/btree_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "checker/internal/field_path.h" +#include "google/protobuf/descriptor.h" + +namespace cel::checker_internal { + +// Returns a descriptor for the input type name. +// Returns an error if the type name is not found. +absl::StatusOr FindMessage( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + absl::string_view type_name); + +// Returns a field descriptor for the input field name. +// Returns an error if the field name is not found. +absl::StatusOr FindField( + const google::protobuf::Descriptor* descriptor, absl::string_view field_name); + +// Represents the fraction of a protobuf type's object graph that should be +// visible within CEL expressions. +class ProtoTypeMask { + public: + explicit ProtoTypeMask(std::string type_name, + const std::vector& field_paths) + : type_name_(std::move(type_name)) { + for (const std::string& field_path : field_paths) { + field_paths_.insert(FieldPath(field_path)); + } + } + + // Returns a set of field names. The set includes the first field name from + // each field path. We are able to return a set of absl::string_view because + // the result is backed by the descriptor pool. + absl::StatusOr> GetFieldNames( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool) const; + + // Returns the type's full name. + // For example: "google.rpc.context.AttributeContext". + absl::string_view GetTypeName() const { return type_name_; } + + // Returns a representation of the FieldMask, which is a set of field paths. + // For example: + // { + // FieldPath { + // field path: 'resource.name', + // field selection: {'resource', 'name'} + // }, + // FieldPath { + // field path: 'request.auth.claims', + // field selection: {'request', 'auth', 'claims'} + // } + // } + const absl::btree_set& GetFieldPaths() const { + return field_paths_; + } + + template + friend void AbslStringify(Sink& sink, const ProtoTypeMask& proto_type_mask) { + sink.Append(proto_type_mask.DebugString()); + } + + private: + std::string DebugString() const; + + // A type's full name. For example: "google.rpc.context.AttributeContext". + std::string type_name_; + // A representation of a FieldMask, which is a set of field paths. + // For example: + // { + // FieldPath { + // field path: 'resource.name', + // field selection: {'resource', 'name'} + // }, + // FieldPath { + // field path: 'request.auth.claims', + // field selection: {'request', 'auth', 'claims'} + // } + // } + // A FieldMask contains one or more paths which contain identifier characters + // that have been dot delimited, e.g. resource.name, request.auth.claims. + // For each path, all descendent fields after the last element in the path are + // visible. An empty set means all fields are hidden. + absl::btree_set field_paths_; +}; + +} // namespace cel::checker_internal + +#endif // THIRD_PARTY_CEL_CPP_CHECKER_PROTO_TYPE_MASK_H_ diff --git a/checker/internal/proto_type_mask_registry.cc b/checker/internal/proto_type_mask_registry.cc new file mode 100644 index 000000000..9c50c9784 --- /dev/null +++ b/checker/internal/proto_type_mask_registry.cc @@ -0,0 +1,180 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/internal/proto_type_mask_registry.h" + +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/container/btree_set.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "checker/internal/field_path.h" +#include "checker/internal/proto_type_mask.h" +#include "common/type.h" +#include "internal/status_macros.h" +#include "google/protobuf/descriptor.h" + +namespace cel::checker_internal { +namespace { + +using ::google::protobuf::Descriptor; +using ::google::protobuf::DescriptorPool; +using ::google::protobuf::FieldDescriptor; +using TypeMap = + absl::flat_hash_map>; + +// Returns a message type descriptor for the input field descriptor. +// Returns an error if the field is not a message type. +absl::StatusOr GetMessage( + const FieldDescriptor* field_descriptor) { + cel::MessageTypeField field(field_descriptor); + cel::Type type = field.GetType(); + absl::optional message_type = type.AsMessage(); + if (!message_type.has_value()) { + return absl::InvalidArgumentError(absl::Substitute( + "field '$0' is not a message type", field_descriptor->name())); + } + return &(*message_type.value()); +} + +// Inserts the type name with an empty set into types_and_visible_fields. +// Returns an error if the type name is already present with a non-empty set. +absl::Status AddAllHiddenFields(TypeMap& types_and_visible_fields, + absl::string_view type_name) { + auto result = types_and_visible_fields.find(type_name); + if (result != types_and_visible_fields.end()) { + if (!result->second.empty()) { + return absl::InvalidArgumentError( + absl::Substitute("cannot insert a proto type mask with all hidden " + "fields when type '$0' has already been inserted " + "with a proto type mask with a visible field", + type_name)); + } + return absl::OkStatus(); + } + types_and_visible_fields.insert({std::string(type_name), {}}); + return absl::OkStatus(); +} + +// Inserts the type name and field name into types_and_visible_fields. +// Returns an error if the type name is already present with an empty set. +absl::Status AddVisibleField(TypeMap& types_and_visible_fields, + absl::string_view type_name, + absl::string_view field_name) { + auto result = types_and_visible_fields.find(type_name); + if (result != types_and_visible_fields.end()) { + if (result->second.empty()) { + return absl::InvalidArgumentError(absl::Substitute( + "cannot insert a proto type mask with visible " + "field '$0' when type '$1' has already been inserted " + "with a proto type mask with all hidden fields", + field_name, type_name)); + } + result->second.insert(std::string(field_name)); + return absl::OkStatus(); + } + types_and_visible_fields.insert( + {std::string(type_name), {std::string(field_name)}}); + return absl::OkStatus(); +} + +// Processes the input proto type masks to create and return the +// types_and_visible_fields map. +// Returns an error if one of the proto type masks is not valid. For example, +// if a type is not found in the descriptor pool, if a field name is not +// found, or if a field is not a message type when we are expecting it to be. +// Returns an error if there is a conflict in field visibility when +// updating the map. +absl::StatusOr ComputeVisibleFieldsMap( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + const std::vector& proto_type_masks) { + TypeMap types_and_visible_fields; + for (const ProtoTypeMask& proto_type_mask : proto_type_masks) { + absl::string_view type_name = proto_type_mask.GetTypeName(); + CEL_ASSIGN_OR_RETURN(const Descriptor* descriptor, + FindMessage(descriptor_pool, type_name)); + const absl::btree_set& field_paths = + proto_type_mask.GetFieldPaths(); + if (field_paths.empty()) { + CEL_RETURN_IF_ERROR( + AddAllHiddenFields(types_and_visible_fields, type_name)); + } + for (const FieldPath& field_path : field_paths) { + const Descriptor* target_descriptor = descriptor; + absl::Span field_selection = + field_path.GetFieldSelection(); + for (auto iterator = field_selection.begin(); + iterator != field_selection.end(); ++iterator) { + CEL_ASSIGN_OR_RETURN(const FieldDescriptor* field_descriptor, + FindField(target_descriptor, *iterator)); + CEL_RETURN_IF_ERROR(AddVisibleField(types_and_visible_fields, + target_descriptor->full_name(), + *iterator)); + if (std::next(iterator) != field_selection.end()) { + CEL_ASSIGN_OR_RETURN(target_descriptor, GetMessage(field_descriptor)); + } + } + } + } + return types_and_visible_fields; +} + +} // namespace + +absl::StatusOr> +ProtoTypeMaskRegistry::Create( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + const std::vector& proto_type_masks) { + CEL_ASSIGN_OR_RETURN( + auto types_and_visible_fields, + ComputeVisibleFieldsMap(descriptor_pool, proto_type_masks)); + std::shared_ptr proto_type_mask_registry = + absl::WrapUnique(new ProtoTypeMaskRegistry(types_and_visible_fields)); + return proto_type_mask_registry; +} + +bool ProtoTypeMaskRegistry::FieldIsVisible(absl::string_view type_name, + absl::string_view field_name) const { + auto iterator = types_and_visible_fields_.find(type_name); + if (iterator != types_and_visible_fields_.end() && + !iterator->second.contains(field_name)) { + return false; + } + return true; +} + +std::string ProtoTypeMaskRegistry::DebugString() const { + std::string output = "ProtoTypeMaskRegistry { "; + for (auto& element : types_and_visible_fields_) { + absl::StrAppend(&output, "{type: '", element.first, "', visible_fields: '", + absl::StrJoin(element.second, "', '"), "'} "); + } + absl::StrAppend(&output, "}"); + return output; +} + +} // namespace cel::checker_internal diff --git a/checker/internal/proto_type_mask_registry.h b/checker/internal/proto_type_mask_registry.h new file mode 100644 index 000000000..338353e7d --- /dev/null +++ b/checker/internal/proto_type_mask_registry.h @@ -0,0 +1,83 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_CHECKER_PROTO_TYPE_MASK_REGISTRY_H_ +#define THIRD_PARTY_CEL_CPP_CHECKER_PROTO_TYPE_MASK_REGISTRY_H_ + +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "checker/internal/proto_type_mask.h" +#include "google/protobuf/descriptor.h" + +namespace cel::checker_internal { + +// Stores information related to ProtoTypeMasks. Visibility is defined per type, +// meaning that all messages of a type have the same visible fields. +class ProtoTypeMaskRegistry { + public: + // Processes the input proto type masks to create a ProtoTypeMaskRegistry. + // Returns an error if one of the proto type masks is not valid. For example, + // if a type is not found in the descriptor pool, if a field name is not + // found, or if a field is not a message type when we are expecting it to be. + // Returns an error if there is a conflict in field visibility when + // updating the map. + static absl::StatusOr> Create( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + const std::vector& proto_type_masks); + + const absl::flat_hash_map>& + GetTypesAndVisibleFields() const { + return types_and_visible_fields_; + } + + // Returns true when the field name is visible. A field is visible if: + // 1. The type name is not a key in the map. + // 2. The type name is a key in the map and the field name is in the set of + // field names that are visible for the type. + bool FieldIsVisible(absl::string_view type_name, + absl::string_view field_name) const; + + template + friend void AbslStringify( + Sink& sink, + const std::shared_ptr& proto_type_mask_registry) { + sink.Append(proto_type_mask_registry->DebugString()); + } + + private: + explicit ProtoTypeMaskRegistry( + absl::flat_hash_map> + types_and_visible_fields) + : types_and_visible_fields_(std::move(types_and_visible_fields)) {} + + std::string DebugString() const; + + // Map of types that have a field mask where the keys are + // fully qualified type names and the values are the set of field names that + // are visible for the type. + absl::flat_hash_map> + types_and_visible_fields_; +}; + +} // namespace cel::checker_internal + +#endif // THIRD_PARTY_CEL_CPP_CHECKER_PROTO_TYPE_MASK_REGISTRY_H_ diff --git a/checker/internal/proto_type_mask_registry_test.cc b/checker/internal/proto_type_mask_registry_test.cc new file mode 100644 index 000000000..3a73c8823 --- /dev/null +++ b/checker/internal/proto_type_mask_registry_test.cc @@ -0,0 +1,402 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/internal/proto_type_mask_registry.h" + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "checker/internal/proto_type_mask.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" + +namespace cel::checker_internal { +namespace { + +using ::absl_testing::StatusIs; +using ::cel::internal::GetSharedTestingDescriptorPool; +using ::testing::AllOf; +using ::testing::HasSubstr; +using ::testing::IsEmpty; +using ::testing::Pair; +using ::testing::UnorderedElementsAre; + +using TypeMap = + absl::flat_hash_map>; + +TEST(ProtoTypeMaskRegistryTest, + CreateWithEmptyInputSucceedsAndAllFieldsAreVisible) { + std::vector proto_type_masks = {}; + ASSERT_OK_AND_ASSIGN( + std::shared_ptr proto_type_mask_registry, + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks)); + EXPECT_THAT(proto_type_mask_registry->GetTypesAndVisibleFields(), IsEmpty()); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible("any_type_name", + "any_field_name")); +} + +TEST(ProtoTypeMaskRegistryTest, CreateWithEmptyTypeReturnsError) { + std::vector proto_type_masks = {ProtoTypeMask("", {})}; + EXPECT_THAT(ProtoTypeMaskRegistry::Create( + GetSharedTestingDescriptorPool().get(), proto_type_masks), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("type '' not found"))); +} + +TEST(ProtoTypeMaskRegistryTest, CreateWithUnknownTypeReturnsError) { + std::vector proto_type_masks = { + ProtoTypeMask("com.example.UnknownType", {})}; + EXPECT_THAT(ProtoTypeMaskRegistry::Create( + GetSharedTestingDescriptorPool().get(), proto_type_masks), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("type 'com.example.UnknownType' not found"))); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateWithEmptySetFieldPathSucceedsAndFieldsAreHidden) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", {})}; + ASSERT_OK_AND_ASSIGN( + std::shared_ptr proto_type_mask_registry, + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks)); + EXPECT_THAT(proto_type_mask_registry->GetTypesAndVisibleFields(), + UnorderedElementsAre( + Pair("cel.expr.conformance.proto3.TestAllTypes", IsEmpty()))); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible("any_type_name", + "any_field_name")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "any_field_name")); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateWithDuplicateEmptySetFieldPathSucceedsAndFieldsAreHidden) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", {}), + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", {})}; + ASSERT_OK_AND_ASSIGN( + std::shared_ptr proto_type_mask_registry, + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks)); + EXPECT_THAT(proto_type_mask_registry->GetTypesAndVisibleFields(), + UnorderedElementsAre( + Pair("cel.expr.conformance.proto3.TestAllTypes", IsEmpty()))); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible("any_type_name", + "any_field_name")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "any_field_name")); +} + +TEST(ProtoTypeMaskRegistryTest, CreateWithEmptyFieldPathReturnsError) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", {""})}; + EXPECT_THAT( + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("could not select field '' from type " + "'cel.expr.conformance.proto3.TestAllTypes'"))); +} + +TEST(ProtoTypeMaskRegistryTest, CreateWithDelimiterFieldPathReturnsError) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", {"."})}; + EXPECT_THAT( + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("could not select field '' from type " + "'cel.expr.conformance.proto3.TestAllTypes'"))); +} + +TEST(ProtoTypeMaskRegistryTest, CreateWithUnknownFieldReturnsError) { + std::vector proto_type_masks = {ProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"unknown_field"})}; + EXPECT_THAT( + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("could not select field 'unknown_field' from type " + "'cel.expr.conformance.proto3.TestAllTypes'"))); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateWithDepthOneNonMessageFieldsSucceedsAndFieldsAreVisible) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", + {"single_int32", "single_any", "single_timestamp"})}; + ASSERT_OK_AND_ASSIGN( + std::shared_ptr proto_type_mask_registry, + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks)); + EXPECT_THAT(proto_type_mask_registry->GetTypesAndVisibleFields(), + UnorderedElementsAre( + Pair("cel.expr.conformance.proto3.TestAllTypes", + UnorderedElementsAre("single_int32", "single_any", + "single_timestamp")))); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible("any_type_name", + "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "single_int32")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "single_any")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "single_timestamp")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "any_field_name")); +} + +TEST(ProtoTypeMaskRegistryTest, CreateWithDepthTwoNonMessageFieldReturnsError) { + std::vector proto_type_masks; + proto_type_masks.push_back( + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", + {"single_int32.any_field_name"})); + EXPECT_THAT( + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("field 'single_int32' is not a message type"))); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateWithDepthOneMessageFieldSucceedsAndFieldsAreVisible) { + std::vector proto_type_masks = {ProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"standalone_message"})}; + ASSERT_OK_AND_ASSIGN( + std::shared_ptr proto_type_mask_registry, + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks)); + EXPECT_THAT( + proto_type_mask_registry->GetTypesAndVisibleFields(), + UnorderedElementsAre(Pair("cel.expr.conformance.proto3.TestAllTypes", + UnorderedElementsAre("standalone_message")))); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible("any_type_name", + "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "standalone_message")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "any_field_name")); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateWithDepthTwoMessageFieldSucceedsAndFieldsAreVisible) { + std::vector proto_type_masks = {ProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"standalone_message.bb"})}; + ASSERT_OK_AND_ASSIGN( + std::shared_ptr proto_type_mask_registry, + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks)); + EXPECT_THAT(proto_type_mask_registry->GetTypesAndVisibleFields(), + UnorderedElementsAre( + Pair("cel.expr.conformance.proto3.TestAllTypes", + UnorderedElementsAre("standalone_message")), + Pair("cel.expr.conformance.proto3.TestAllTypes.NestedMessage", + UnorderedElementsAre("bb")))); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible("any_type_name", + "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "standalone_message")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes.NestedMessage", "bb")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes.NestedMessage", + "any_field_name")); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateWithDepthTwoMessageUnknownFieldReturnsError) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", + {"standalone_message.unknown_field"})}; + EXPECT_THAT( + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks), + StatusIs( + absl::StatusCode::kInvalidArgument, + HasSubstr( + "could not select field 'unknown_field' from type " + "'cel.expr.conformance.proto3.TestAllTypes.NestedMessage'"))); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateWithDepthThreeMessageFieldSucceedsAndFieldsAreVisible) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.NestedTestAllTypes", + {"payload.standalone_message.bb"})}; + ASSERT_OK_AND_ASSIGN( + std::shared_ptr proto_type_mask_registry, + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks)); + EXPECT_THAT(proto_type_mask_registry->GetTypesAndVisibleFields(), + UnorderedElementsAre( + Pair("cel.expr.conformance.proto3.NestedTestAllTypes", + UnorderedElementsAre("payload")), + Pair("cel.expr.conformance.proto3.TestAllTypes", + UnorderedElementsAre("standalone_message")), + Pair("cel.expr.conformance.proto3.TestAllTypes.NestedMessage", + UnorderedElementsAre("bb")))); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible("any_type_name", + "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.NestedTestAllTypes", "payload")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.NestedTestAllTypes", "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "standalone_message")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes.NestedMessage", "bb")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes.NestedMessage", + "any_field_name")); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateWithDepthOneRepeatedMessageFieldSucceedsAndFieldsAreVisible) { + std::vector proto_type_masks = {ProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"repeated_nested_message"})}; + ASSERT_OK_AND_ASSIGN( + std::shared_ptr proto_type_mask_registry, + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks)); + EXPECT_THAT(proto_type_mask_registry->GetTypesAndVisibleFields(), + UnorderedElementsAre( + Pair("cel.expr.conformance.proto3.TestAllTypes", + UnorderedElementsAre("repeated_nested_message")))); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible("any_type_name", + "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "repeated_nested_message")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "any_field_name")); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateWithDepthTwoRepeatedMessageFieldReturnsError) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", + {"repeated_nested_message.bb"})}; + EXPECT_THAT( + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks), + StatusIs( + absl::StatusCode::kInvalidArgument, + HasSubstr("field 'repeated_nested_message' is not a message type"))); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateWithListOfFieldPathsSucceedsAndFieldsAreVisible) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.NestedTestAllTypes", + {"payload.standalone_message.bb", "payload.single_int32"})}; + ASSERT_OK_AND_ASSIGN( + std::shared_ptr proto_type_mask_registry, + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks)); + EXPECT_THAT( + proto_type_mask_registry->GetTypesAndVisibleFields(), + UnorderedElementsAre( + Pair("cel.expr.conformance.proto3.NestedTestAllTypes", + UnorderedElementsAre("payload")), + Pair("cel.expr.conformance.proto3.TestAllTypes", + UnorderedElementsAre("standalone_message", "single_int32")), + Pair("cel.expr.conformance.proto3.TestAllTypes.NestedMessage", + UnorderedElementsAre("bb")))); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible("any_type_name", + "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.NestedTestAllTypes", "payload")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.NestedTestAllTypes", "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "standalone_message")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "single_int32")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes.NestedMessage", "bb")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes.NestedMessage", + "any_field_name")); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateAddVisibleFieldThenAllHiddenFieldsReturnsError) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", + {"standalone_message.bb"}), + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes.NestedMessage", + {})}; + EXPECT_THAT( + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks), + StatusIs( + absl::StatusCode::kInvalidArgument, + HasSubstr( + "cannot insert a proto type mask with all hidden fields when " + "type 'cel.expr.conformance.proto3.TestAllTypes.NestedMessage' " + "has already been inserted with a proto type mask with a visible " + "field"))); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateAddAllHiddenThenVisibleFieldReturnsError) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes.NestedMessage", + {}), + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", + {"standalone_message.bb"})}; + EXPECT_THAT( + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks), + StatusIs( + absl::StatusCode::kInvalidArgument, + HasSubstr( + "cannot insert a proto type mask with visible field 'bb' when " + "type 'cel.expr.conformance.proto3.TestAllTypes.NestedMessage' " + "has already been inserted with a proto type mask with all " + "hidden fields"))); +} + +TEST(ProtoTypeMaskRegistryTest, AbslStringifyPrintsTypesAndVisibleFieldsMap) { + std::vector proto_type_masks = {ProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"standalone_message.bb"})}; + ASSERT_OK_AND_ASSIGN( + std::shared_ptr proto_type_mask_registry, + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks)); + EXPECT_THAT( + absl::StrCat(proto_type_mask_registry), + AllOf(HasSubstr("ProtoTypeMaskRegistry {"), + HasSubstr("{type: 'cel.expr.conformance.proto3.TestAllTypes', " + "visible_fields: 'standalone_message'}"), + HasSubstr("{type: " + "'cel.expr.conformance.proto3.TestAllTypes.NestedMessage'" + ", visible_fields: 'bb'}"))); +} + +} // namespace +} // namespace cel::checker_internal diff --git a/checker/internal/proto_type_mask_test.cc b/checker/internal/proto_type_mask_test.cc new file mode 100644 index 000000000..0c534f8cf --- /dev/null +++ b/checker/internal/proto_type_mask_test.cc @@ -0,0 +1,143 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/internal/proto_type_mask.h" + +#include +#include + +#include "absl/container/btree_set.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "checker/internal/field_path.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" + +namespace cel::checker_internal { +namespace { + +using ::absl_testing::StatusIs; +using ::cel::internal::GetSharedTestingDescriptorPool; +using ::testing::HasSubstr; +using ::testing::IsEmpty; +using ::testing::UnorderedElementsAre; + +TEST(ProtoTypeMaskTest, EmptyTypeNameAndEmptyFieldPathsSucceeds) { + std::string type_name = ""; + std::vector field_paths; + ProtoTypeMask proto_type_mask(type_name, field_paths); + EXPECT_EQ(proto_type_mask.GetTypeName(), ""); + EXPECT_THAT(proto_type_mask.GetFieldPaths(), IsEmpty()); +} + +TEST(ProtoTypeMaskTest, NotEmptyTypeNameAndNotEmptyFieldPathsSucceeds) { + std::string type_name = "google.type.Expr"; + std::vector field_paths = {"resource.name", "resource.type"}; + ProtoTypeMask proto_type_mask(type_name, field_paths); + EXPECT_EQ(proto_type_mask.GetTypeName(), "google.type.Expr"); + EXPECT_THAT(proto_type_mask.GetFieldPaths(), + UnorderedElementsAre(FieldPath("resource.name"), + FieldPath("resource.type"))); +} + +TEST(ProtoTypeMaskTest, GetFieldNamesWithEmptyTypeReturnsError) { + ProtoTypeMask proto_type_mask("", {}); + EXPECT_THAT( + proto_type_mask.GetFieldNames(GetSharedTestingDescriptorPool().get()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("type '' not found"))); +} + +TEST(ProtoTypeMaskTest, GetFieldNamesWithUnknownTypeReturnsError) { + ProtoTypeMask proto_type_mask("com.example.UnknownType", {}); + EXPECT_THAT( + proto_type_mask.GetFieldNames(GetSharedTestingDescriptorPool().get()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("type 'com.example.UnknownType' not found"))); +} + +TEST(ProtoTypeMaskTest, + GetFieldNamesWithEmptySetFieldPathSucceedsAndReturnsEmptySet) { + ProtoTypeMask proto_type_mask("cel.expr.conformance.proto3.TestAllTypes", {}); + ASSERT_OK_AND_ASSIGN( + absl::btree_set field_names, + proto_type_mask.GetFieldNames(GetSharedTestingDescriptorPool().get())); + EXPECT_THAT(field_names, IsEmpty()); +} + +TEST(ProtoTypeMaskTest, GetFieldNamesWithEmptyFieldPathReturnsError) { + ProtoTypeMask proto_type_mask("cel.expr.conformance.proto3.TestAllTypes", + {""}); + EXPECT_THAT( + proto_type_mask.GetFieldNames(GetSharedTestingDescriptorPool().get()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("could not select field '' from type " + "'cel.expr.conformance.proto3.TestAllTypes'"))); +} + +TEST(ProtoTypeMaskTest, GetFieldNamesWithDelimiterFieldPathReturnsError) { + ProtoTypeMask proto_type_mask("cel.expr.conformance.proto3.TestAllTypes", + {"single_int32", "."}); + EXPECT_THAT( + proto_type_mask.GetFieldNames(GetSharedTestingDescriptorPool().get()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("could not select field '' from type " + "'cel.expr.conformance.proto3.TestAllTypes'"))); +} + +TEST(ProtoTypeMaskTest, GetFieldNamesWithUnknownFieldReturnsError) { + ProtoTypeMask proto_type_mask("cel.expr.conformance.proto3.TestAllTypes", + {"unknown_field"}); + EXPECT_THAT( + proto_type_mask.GetFieldNames(GetSharedTestingDescriptorPool().get()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("could not select field 'unknown_field' from type " + "'cel.expr.conformance.proto3.TestAllTypes'"))); +} + +TEST(ProtoTypeMaskTest, + GetFieldNamesWithValidFieldsSucceedsAndReturnsFieldNames) { + ProtoTypeMask proto_type_mask("cel.expr.conformance.proto3.TestAllTypes", + {"single_int32", "single_string"}); + ASSERT_OK_AND_ASSIGN( + absl::btree_set field_names, + proto_type_mask.GetFieldNames(GetSharedTestingDescriptorPool().get())); + EXPECT_THAT(field_names, + UnorderedElementsAre("single_int32", "single_string")); +} + +TEST(ProtoTypeMaskTest, + GetFieldNamesWithValidFieldPathsSucceedsAndReturnsFieldNames) { + ProtoTypeMask proto_type_mask( + "cel.expr.conformance.proto3.NestedTestAllTypes", + {"payload.standalone_message.bb", "payload.single_int32", + "child.any_field_name"}); + ASSERT_OK_AND_ASSIGN( + absl::btree_set field_names, + proto_type_mask.GetFieldNames(GetSharedTestingDescriptorPool().get())); + EXPECT_THAT(field_names, UnorderedElementsAre("payload", "child")); +} + +TEST(ProtoTypeMaskTest, AbslStringifyPrintsTypeNameAndFieldPaths) { + std::string type_name = "google.type.Expr"; + std::vector field_paths = {"resource.name", "resource.type"}; + ProtoTypeMask proto_type_mask(type_name, field_paths); + EXPECT_THAT(absl::StrCat(proto_type_mask), + HasSubstr("ProtoTypeMask { type name: 'google.type.Expr', field " + "paths: { 'resource.name', 'resource.type' } }")); +} + +} // namespace +} // namespace cel::checker_internal diff --git a/checker/internal/test_ast_helpers.cc b/checker/internal/test_ast_helpers.cc new file mode 100644 index 000000000..543f70a89 --- /dev/null +++ b/checker/internal/test_ast_helpers.cc @@ -0,0 +1,44 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "checker/internal/test_ast_helpers.h" + +#include + +#include "absl/log/absl_check.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "common/ast.h" +#include "internal/status_macros.h" +#include "parser/options.h" +#include "parser/parser.h" +#include "parser/parser_interface.h" + +namespace cel::checker_internal { + +absl::StatusOr> MakeTestParsedAst( + absl::string_view expression) { + static const cel::Parser* parser = []() { + cel::ParserOptions options = {.enable_optional_syntax = true}; + auto parser = NewParserBuilder(options)->Build(); + ABSL_CHECK_OK(parser); + return parser->release(); + }(); + + CEL_ASSIGN_OR_RETURN( + auto source, + cel::NewSource(expression, /*description=*/std::string(expression))); + return parser->Parse(*source); +} + +} // namespace cel::checker_internal diff --git a/checker/internal/test_ast_helpers.h b/checker/internal/test_ast_helpers.h new file mode 100644 index 000000000..44a1e0a0f --- /dev/null +++ b/checker/internal/test_ast_helpers.h @@ -0,0 +1,31 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_TESTING_H_ +#define THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_TESTING_H_ + +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "common/ast.h" + +namespace cel::checker_internal { + +absl::StatusOr> MakeTestParsedAst( + absl::string_view expression); + +} // namespace cel::checker_internal + +#endif // THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_TESTING_H_ diff --git a/checker/internal/test_ast_helpers_test.cc b/checker/internal/test_ast_helpers_test.cc new file mode 100644 index 000000000..51fb8461a --- /dev/null +++ b/checker/internal/test_ast_helpers_test.cc @@ -0,0 +1,40 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/internal/test_ast_helpers.h" + +#include + +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "common/ast.h" +#include "internal/testing.h" + +namespace cel::checker_internal { +namespace { + +using ::absl_testing::StatusIs; + +TEST(MakeTestParsedAstTest, Works) { + ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, MakeTestParsedAst("123")); + EXPECT_TRUE(ast->root_expr().has_const_expr()); +} + +TEST(MakeTestParsedAstTest, ForwardsParseError) { + EXPECT_THAT(MakeTestParsedAst("%123"), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +} // namespace +} // namespace cel::checker_internal diff --git a/checker/internal/type_check_env.cc b/checker/internal/type_check_env.cc new file mode 100644 index 000000000..8dc83518d --- /dev/null +++ b/checker/internal/type_check_env.cc @@ -0,0 +1,131 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/internal/type_check_env.h" + +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "common/constant.h" +#include "common/decl.h" +#include "common/type.h" +#include "common/type_introspector.h" +#include "internal/status_macros.h" +#include "google/protobuf/arena.h" + +namespace cel::checker_internal { + +const VariableDecl* absl_nullable TypeCheckEnv::LookupVariable( + absl::string_view name) const { + if (auto it = variables_.find(name); it != variables_.end()) { + return &it->second; + } + return nullptr; +} + +const FunctionDecl* absl_nullable TypeCheckEnv::LookupFunction( + absl::string_view name) const { + if (auto it = functions_.find(name); it != functions_.end()) { + return &it->second; + } + + return nullptr; +} + +absl::StatusOr> TypeCheckEnv::LookupTypeName( + absl::string_view name) const { + for (auto iter = type_providers_.begin(); iter != type_providers_.end(); + ++iter) { + CEL_ASSIGN_OR_RETURN(auto type, (*iter)->FindType(name)); + if (type.has_value()) { + return type; + } + } + return std::nullopt; +} + +absl::StatusOr> TypeCheckEnv::LookupEnumConstant( + absl::string_view type, absl::string_view value) const { + for (auto iter = type_providers_.begin(); iter != type_providers_.end(); + ++iter) { + CEL_ASSIGN_OR_RETURN(auto enum_constant, + (*iter)->FindEnumConstant(type, value)); + if (enum_constant.has_value()) { + auto decl = MakeVariableDecl(absl::StrCat(enum_constant->type_full_name, + ".", enum_constant->value_name), + enum_constant->type); + decl.set_value(Constant(static_cast(enum_constant->number))); + return decl; + } + } + return std::nullopt; +} + +absl::StatusOr> TypeCheckEnv::LookupTypeConstant( + google::protobuf::Arena* absl_nonnull arena, absl::string_view name) const { + CEL_ASSIGN_OR_RETURN(std::optional type, LookupTypeName(name)); + if (type.has_value()) { + return MakeVariableDecl(type->name(), TypeType(arena, *type)); + } + + if (name.find('.') != name.npos) { + size_t last_dot = name.rfind('.'); + absl::string_view enum_name_candidate = name.substr(0, last_dot); + absl::string_view value_name_candidate = name.substr(last_dot + 1); + return LookupEnumConstant(enum_name_candidate, value_name_candidate); + } + + return std::nullopt; +} + +absl::StatusOr> TypeCheckEnv::LookupStructField( + absl::string_view type_name, absl::string_view field_name) const { + if (proto_type_mask_registry_ != nullptr && + !proto_type_mask_registry_->FieldIsVisible(type_name, field_name)) { + return std::nullopt; + } + // Check the type providers in registration order. + // Note: this doesn't allow for shadowing a type with a subset type of the + // same name -- the later type provider will still be considered when + // checking field accesses. + for (auto iter = type_providers_.begin(); iter != type_providers_.end(); + ++iter) { + CEL_ASSIGN_OR_RETURN( + auto field, (*iter)->FindStructTypeFieldByName(type_name, field_name)); + if (field.has_value()) { + return field; + } + } + return std::nullopt; +} + +const VariableDecl* absl_nullable VariableScope::LookupLocalVariable( + absl::string_view name) const { + const VariableScope* scope = this; + while (scope != nullptr) { + if (auto it = scope->variables_.find(name); it != scope->variables_.end()) { + return &it->second; + } + scope = scope->parent_; + } + return nullptr; +} + +} // namespace cel::checker_internal diff --git a/checker/internal/type_check_env.h b/checker/internal/type_check_env.h new file mode 100644 index 000000000..00fea0ba3 --- /dev/null +++ b/checker/internal/type_check_env.h @@ -0,0 +1,249 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_TYPE_CHECK_ENV_H_ +#define THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_TYPE_CHECK_ENV_H_ + +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" +#include "absl/log/absl_check.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "checker/internal/descriptor_pool_type_introspector.h" +#include "checker/internal/proto_type_mask.h" +#include "checker/internal/proto_type_mask_registry.h" +#include "common/constant.h" +#include "common/container.h" +#include "common/decl.h" +#include "common/type.h" +#include "common/type_introspector.h" +#include "internal/status_macros.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" + +namespace cel::checker_internal { + +class TypeCheckEnv; + +// Helper class for managing nested scopes and the local variables they +// implicitly declare. +// +// Nested scopes have a lifetime dependency on any parent scopes and should +// generally be managed by unique_ptrs. +class VariableScope { + public: + explicit VariableScope() : parent_(nullptr) {} + + VariableScope(const VariableScope&) = delete; + VariableScope& operator=(const VariableScope&) = delete; + VariableScope(VariableScope&&) = default; + VariableScope& operator=(VariableScope&&) = default; + + bool InsertVariableIfAbsent(VariableDecl decl) { + return variables_.insert({decl.name(), std::move(decl)}).second; + } + + std::unique_ptr MakeNestedScope() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return absl::WrapUnique(new VariableScope(this)); + } + + const VariableDecl* absl_nullable LookupLocalVariable( + absl::string_view name) const; + + private: + explicit VariableScope( + const VariableScope* parent ABSL_ATTRIBUTE_LIFETIME_BOUND) + : parent_(parent) {} + + const VariableScope* absl_nullable parent_; + absl::flat_hash_map variables_; +}; + +// Class managing the state of the type check environment. +// +// Maintains lookup maps for variables and functions and the set of type +// providers. +// +// This class is thread-compatible. +class TypeCheckEnv { + private: + using VariableDeclPtr = const VariableDecl* absl_nonnull; + using FunctionDeclPtr = const FunctionDecl* absl_nonnull; + + public: + explicit TypeCheckEnv( + absl_nonnull std::shared_ptr + descriptor_pool) + : descriptor_pool_(std::move(descriptor_pool)), + proto_type_introspector_( + std::make_shared( + descriptor_pool_.get())) { + type_providers_.push_back( + std::make_shared()); + type_providers_.push_back(proto_type_introspector_); + } + + TypeCheckEnv(const TypeCheckEnv&) = default; + TypeCheckEnv& operator=(const TypeCheckEnv&) = default; + TypeCheckEnv(TypeCheckEnv&&) = default; + TypeCheckEnv& operator=(TypeCheckEnv&&) = default; + + const ExpressionContainer& container() const { return container_; } + + void set_container(ExpressionContainer container) { + container_ = std::move(container); + } + + const DescriptorPoolTypeIntrospector& proto_type_introspector() const { + return *proto_type_introspector_; + } + DescriptorPoolTypeIntrospector& proto_type_introspector() { + return *proto_type_introspector_; + } + + void set_expected_type(const Type& type) { expected_type_ = std::move(type); } + + const absl::optional& expected_type() const { return expected_type_; } + + absl::Span> type_providers() + const { + return type_providers_; + } + + void AddTypeProvider(std::unique_ptr provider) { + type_providers_.push_back(std::move(provider)); + } + + void AddTypeProvider(std::shared_ptr provider) { + type_providers_.push_back(std::move(provider)); + } + + const absl::flat_hash_map& variables() const { + return variables_; + } + + // Inserts a variable declaration into the environment of the current scope if + // is is not already present. Parent scopes are not searched. + // + // Returns true if the variable was inserted, false otherwise. + bool InsertVariableIfAbsent(VariableDecl decl) { + return variables_.insert({decl.name(), std::move(decl)}).second; + } + + // Inserts a variable declaration into the environment of the current scope. + // Parent scopes are not searched. + void InsertOrReplaceVariable(VariableDecl decl) { + variables_[decl.name()] = std::move(decl); + } + + absl::Status CreateProtoTypeMaskRegistry( + const std::vector& proto_type_masks) { + CEL_ASSIGN_OR_RETURN(proto_type_mask_registry_, + ProtoTypeMaskRegistry::Create(descriptor_pool_.get(), + proto_type_masks)); + return absl::OkStatus(); + } + + const absl::flat_hash_map& functions() const { + return functions_; + } + + // Inserts a function declaration into the environment of the current scope if + // is is not already present. Parent scopes are not searched (allowing for + // shadowing). + // + // Returns true if the decl was inserted, false otherwise. + bool InsertFunctionIfAbsent(FunctionDecl decl) { + return functions_.insert({decl.name(), std::move(decl)}).second; + } + + void InsertOrReplaceFunction(FunctionDecl decl) { + functions_[decl.name()] = std::move(decl); + } + + // Returns the declaration for the given name if it is found in the current + // or any parent scope. + // Note: the returned declaration ptr is only valid as long as no changes are + // made to the environment. + const VariableDecl* absl_nullable LookupVariable( + absl::string_view name) const; + const FunctionDecl* absl_nullable LookupFunction( + absl::string_view name) const; + + absl::StatusOr> LookupTypeName( + absl::string_view name) const; + + absl::StatusOr> LookupStructField( + absl::string_view type_name, absl::string_view field_name) const; + + absl::StatusOr> LookupTypeConstant( + google::protobuf::Arena* absl_nonnull arena, absl::string_view type_name) const; + + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool() const { + return descriptor_pool_.get(); + } + + // Used to keep an arena alive if one was needed to allocate types. + // + // Expected to be called exactly once if at all. + void set_arena(std::shared_ptr arena) { + ABSL_DCHECK(arena_ == nullptr || arena == arena_); + arena_ = std::move(arena); + } + + // Returns the arena if one was set, nullptr otherwise. + std::shared_ptr arena() const { return arena_; } + + private: + absl::StatusOr> LookupEnumConstant( + absl::string_view type, absl::string_view value) const; + + absl_nonnull std::shared_ptr descriptor_pool_; + + // If set, an arena was needed to allocate types in the environment. + // + // The TypeCheckEnv does not otherwise use the arena, though it may be used by + // derived TypeCheckerBuilders. + absl_nullable std::shared_ptr arena_; + ExpressionContainer container_; + + // Used to resolve fields on message types. + std::shared_ptr proto_type_introspector_; + + // Maps fully qualified names to declarations. + absl::flat_hash_map variables_; + absl::flat_hash_map functions_; + + std::shared_ptr proto_type_mask_registry_; + + // Type providers for custom types. + std::vector> type_providers_; + + absl::optional expected_type_; +}; + +} // namespace cel::checker_internal + +#endif // THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_TYPE_CHECK_ENV_H_ diff --git a/checker/internal/type_checker_builder_impl.cc b/checker/internal/type_checker_builder_impl.cc new file mode 100644 index 000000000..4289fb528 --- /dev/null +++ b/checker/internal/type_checker_builder_impl.cc @@ -0,0 +1,548 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/internal/type_checker_builder_impl.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/base/no_destructor.h" +#include "absl/base/nullability.h" +#include "absl/cleanup/cleanup.h" +#include "absl/container/btree_set.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/absl_log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "checker/internal/proto_type_mask.h" +#include "checker/internal/type_check_env.h" +#include "checker/internal/type_checker_impl.h" +#include "checker/type_checker.h" +#include "checker/type_checker_builder.h" +#include "common/container.h" +#include "common/decl.h" +#include "common/type.h" +#include "common/type_introspector.h" +#include "common/type_kind.h" +#include "internal/lexis.h" +#include "internal/status_macros.h" +#include "parser/macro.h" +#include "google/protobuf/descriptor.h" + +namespace cel::checker_internal { +namespace { + +const absl::flat_hash_map>& GetStdMacros() { + static const absl::NoDestructor< + absl::flat_hash_map>> + kStdMacros({ + {"has", {HasMacro()}}, + {"all", {AllMacro()}}, + {"exists", {ExistsMacro()}}, + {"exists_one", {ExistsOneMacro()}}, + {"filter", {FilterMacro()}}, + {"map", {Map2Macro(), Map3Macro()}}, + {"optMap", {OptMapMacro()}}, + {"optFlatMap", {OptFlatMapMacro()}}, + }); + return *kStdMacros; +} + +absl::Status CheckStdMacroOverlap(const FunctionDecl& decl) { + const auto& std_macros = GetStdMacros(); + auto it = std_macros.find(decl.name()); + if (it == std_macros.end()) { + return absl::OkStatus(); + } + const auto& macros = it->second; + for (const auto& macro : macros) { + bool macro_member = macro.is_receiver_style(); + size_t macro_arg_count = macro.argument_count() + (macro_member ? 1 : 0); + for (const auto& ovl : decl.overloads()) { + if (ovl.member() == macro_member && + ovl.args().size() == macro_arg_count) { + return absl::InvalidArgumentError(absl::StrCat( + "overload for name '", macro.function(), "' with ", macro_arg_count, + " argument(s) overlaps with predefined macro")); + } + } + } + return absl::OkStatus(); +} + +absl::Status AddWellKnownContextDeclarationVariables( + const google::protobuf::Descriptor* absl_nonnull descriptor, + const absl::flat_hash_map>& + context_type_fields, + TypeCheckEnv& env, bool use_json_name) { + for (int i = 0; i < descriptor->field_count(); ++i) { + const google::protobuf::FieldDescriptor* field = descriptor->field(i); + // Skip fields that are hidden because of a proto type mask. + auto map_iterator = context_type_fields.find(descriptor->full_name()); + if (map_iterator != context_type_fields.end() && + !map_iterator->second.contains(field->name())) { + continue; + } + Type type = MessageTypeField(field).GetType(); + if (type.IsEnum()) { + type = IntType(); + } + absl::string_view name = field->name(); + if (use_json_name) { + name = field->json_name(); + } + if (!env.InsertVariableIfAbsent(MakeVariableDecl(name, type))) { + return absl::AlreadyExistsError( + absl::StrCat("variable '", name, + "' declared multiple times (from context declaration: '", + descriptor->full_name(), "')")); + } + } + return absl::OkStatus(); +} + +absl::Status AddContextDeclarationVariables( + const google::protobuf::Descriptor* absl_nonnull descriptor, + const absl::flat_hash_map>& + context_type_fields, + TypeCheckEnv& env) { + const bool use_json_name = env.proto_type_introspector().use_json_name(); + if (IsWellKnownMessageType(descriptor)) { + return AddWellKnownContextDeclarationVariables( + descriptor, context_type_fields, env, use_json_name); + } + CEL_ASSIGN_OR_RETURN(auto fields, + env.proto_type_introspector().ListFieldsForStructType( + descriptor->full_name())); + if (!fields.has_value()) { + return absl::InternalError(absl::StrCat("context declaration '", + descriptor->full_name(), + "' not found, but was expected")); + } + for (const auto& field_entry : *fields) { + Type type = field_entry.field.GetType(); + if (type.IsEnum()) { + type = IntType(); + } + + absl::string_view name = field_entry.name; + + // Skip fields that are hidden because of a proto type mask. + auto map_iterator = context_type_fields.find(descriptor->full_name()); + if (map_iterator != context_type_fields.end() && + !map_iterator->second.contains(name)) { + continue; + } + + if (!env.InsertVariableIfAbsent(MakeVariableDecl(name, type))) { + return absl::AlreadyExistsError( + absl::StrCat("variable '", name, + "' declared multiple times (from context declaration: '", + descriptor->full_name(), "')")); + } + } + + return absl::OkStatus(); +} + +absl::StatusOr MergeFunctionDecls( + const FunctionDecl& existing_decl, const FunctionDecl& new_decl) { + if (existing_decl.name() != new_decl.name()) { + return absl::InternalError( + "Attempted to merge function decls with different names"); + } + + FunctionDecl merged_decl = existing_decl; + for (const auto& ovl : new_decl.overloads()) { + // We do not tolerate signature collisions, even if they are exact matches. + CEL_RETURN_IF_ERROR(merged_decl.AddOverload(ovl)); + } + + return merged_decl; +} + +std::optional FilterDecl(FunctionDecl decl, + const TypeCheckerSubset& subset) { + FunctionDecl filtered; + std::string name = decl.release_name(); + std::vector overloads = decl.release_overloads(); + for (auto& ovl : overloads) { + if (subset.should_include_overload(name, ovl)) { + absl::Status s = filtered.AddOverload(std::move(ovl)); + if (!s.ok()) { + // Should not be possible to construct the original decl in a way that + // would cause this to fail. + ABSL_LOG(DFATAL) << "failed to add overload to filtered decl: " << s; + } + } + } + if (filtered.overloads().empty()) { + return std::nullopt; + } + filtered.set_name(std::move(name)); + return filtered; +} + +absl::Status ValidateType(const Type& t, bool check_type_param_name, + int depth_limit, int remaining_depth) { + if (remaining_depth-- <= 0) { + return absl::InvalidArgumentError( + absl::StrCat("type nesting limit of ", depth_limit, " exceeded")); + } + switch (t.kind()) { + case TypeKind::kTypeParam: { + if (!check_type_param_name) { + return absl::OkStatus(); + } + const TypeParamType& type_param = t.GetTypeParam(); + if (!internal::LexisIsIdentifier(type_param.name())) { + return absl::InvalidArgumentError( + absl::StrCat("type parameter name '", type_param.name(), + "' is not a valid identifier")); + } + return absl::OkStatus(); + } + case TypeKind::kList: { + Type element_type = t.AsList()->GetElement(); + return ValidateType(element_type, check_type_param_name, depth_limit, + remaining_depth); + } + case TypeKind::kMap: { + Type key_type = t.AsMap()->GetKey(); + Type value_type = t.AsMap()->GetValue(); + CEL_RETURN_IF_ERROR(ValidateType(key_type, check_type_param_name, + depth_limit, remaining_depth)); + return ValidateType(value_type, check_type_param_name, depth_limit, + remaining_depth); + } + case TypeKind::kStruct: { + auto message_type = t.AsMessage(); + if (message_type.has_value() && !static_cast(*message_type)) { + return absl::InvalidArgumentError( + "an empty message type cannot be used in a type declaration"); + } + return absl::OkStatus(); + } + case TypeKind::kOpaque: { + for (Type type_param : t.AsOpaque()->GetParameters()) { + CEL_RETURN_IF_ERROR(ValidateType(type_param, check_type_param_name, + depth_limit, remaining_depth)); + } + return absl::OkStatus(); + } + case TypeKind::kType: { + for (Type type_param : t.AsType()->GetParameters()) { + CEL_RETURN_IF_ERROR(ValidateType(type_param, check_type_param_name, + depth_limit, remaining_depth)); + } + return absl::OkStatus(); + } + default: + break; + } + return absl::OkStatus(); +} + +absl::Status ValidateFunctionDecl(const FunctionDecl& decl, + bool check_type_param_name, int depth_limit) { + CEL_RETURN_IF_ERROR(CheckStdMacroOverlap(decl)); + for (const auto& ovl : decl.overloads()) { + CEL_RETURN_IF_ERROR(ValidateType(ovl.result(), check_type_param_name, + depth_limit, depth_limit)); + for (const auto& arg : ovl.args()) { + CEL_RETURN_IF_ERROR( + ValidateType(arg, check_type_param_name, depth_limit, depth_limit)); + } + } + return absl::OkStatus(); +} + +absl::Status ValidateVariableDecl(const VariableDecl& decl, + bool check_type_param_name, int depth_limit) { + return ValidateType(decl.type(), check_type_param_name, depth_limit, + depth_limit); +} + +} // namespace + +absl::Status TypeCheckerBuilderImpl::BuildLibraryConfig( + const CheckerLibrary& library, + TypeCheckerBuilderImpl::ConfigRecord* config) { + target_config_ = config; + absl::Cleanup reset([this] { target_config_ = &default_config_; }); + + return library.configure(*this); +} + +absl::Status TypeCheckerBuilderImpl::ApplyConfig( + TypeCheckerBuilderImpl::ConfigRecord config, + const TypeCheckerSubset* subset, TypeCheckEnv& env) { + using FunctionDeclRecord = TypeCheckerBuilderImpl::FunctionDeclRecord; + + for (auto& type_provider : config.type_providers) { + env.AddTypeProvider(std::move(type_provider)); + } + + for (FunctionDeclRecord& fn : config.functions) { + FunctionDecl decl = std::move(fn.decl); + if (subset != nullptr) { + std::optional filtered = + FilterDecl(std::move(decl), *subset); + if (!filtered.has_value()) { + continue; + } + decl = std::move(*filtered); + } + + switch (fn.add_semantic) { + case AddSemantic::kInsertIfAbsent: { + std::string name = decl.name(); + if (!env.InsertFunctionIfAbsent(std::move(decl))) { + return absl::AlreadyExistsError( + absl::StrCat("function '", name, "' declared multiple times")); + } + break; + } + case AddSemantic::kTryMerge: { + const FunctionDecl* existing_decl = env.LookupFunction(decl.name()); + FunctionDecl to_add = std::move(decl); + if (existing_decl != nullptr) { + CEL_ASSIGN_OR_RETURN( + to_add, MergeFunctionDecls(*existing_decl, std::move(to_add))); + } + env.InsertOrReplaceFunction(std::move(to_add)); + break; + } + default: + return absl::InternalError(absl::StrCat( + "unsupported function add semantic: ", fn.add_semantic)); + } + } + + for (const google::protobuf::Descriptor* context_type : config.context_types) { + CEL_RETURN_IF_ERROR(AddContextDeclarationVariables( + context_type, config.context_type_fields, env)); + } + + for (VariableDeclRecord& var : config.variables) { + switch (var.add_semantic) { + case AddSemantic::kInsertIfAbsent: { + if (!env.InsertVariableIfAbsent(var.decl)) { + return absl::AlreadyExistsError(absl::StrCat( + "variable '", var.decl.name(), "' declared multiple times")); + } + break; + } + case AddSemantic::kInsertOrReplace: { + env.InsertOrReplaceVariable(var.decl); + break; + } + default: + return absl::InternalError(absl::StrCat( + "unsupported variable add semantic: ", var.add_semantic)); + } + } + + CEL_RETURN_IF_ERROR(env.CreateProtoTypeMaskRegistry(config.proto_type_masks)); + + return absl::OkStatus(); +} + +absl::StatusOr> TypeCheckerBuilderImpl::Build() { + TypeCheckEnv env(template_env_); + CEL_RETURN_IF_ERROR(ConfigureTypeCheckEnv(env)); + return std::make_unique(std::move(env), + options_); +} + +absl::Status TypeCheckerBuilderImpl::ConfigureTypeCheckEnv(TypeCheckEnv& env) { + if (expression_container_.has_value()) { + env.set_container(*expression_container_); + } + if (expected_type_.has_value()) { + env.set_expected_type(*expected_type_); + } + + ConfigRecord anonymous_config; + std::vector configs; + for (const auto& library : libraries_) { + ConfigRecord* config = &anonymous_config; + if (!library.id.empty()) { + configs.emplace_back(); + config = &configs.back(); + config->id = library.id; + } + CEL_RETURN_IF_ERROR(BuildLibraryConfig(library, config)); + } + + env.proto_type_introspector().set_use_json_name( + options_.use_json_field_names); + + for (const ConfigRecord& config : configs) { + TypeCheckerSubset* subset = nullptr; + if (!config.id.empty()) { + auto it = subsets_.find(config.id); + if (it != subsets_.end()) { + subset = &it->second; + } + } + CEL_RETURN_IF_ERROR(ApplyConfig(std::move(config), subset, env)); + } + CEL_RETURN_IF_ERROR(ApplyConfig(std::move(anonymous_config), + /*subset=*/nullptr, env)); + + CEL_RETURN_IF_ERROR(ApplyConfig(default_config_, /*subset=*/nullptr, env)); + if (type_arena_ != nullptr) { + env.set_arena(type_arena_); + } + return absl::OkStatus(); +} + +absl::Status TypeCheckerBuilderImpl::AddLibrary(CheckerLibrary library) { + if (!library.id.empty() && !library_ids_.insert(library.id).second) { + return absl::AlreadyExistsError( + absl::StrCat("library '", library.id, "' already exists")); + } + if (!library.configure) { + return absl::OkStatus(); + } + + libraries_.push_back(std::move(library)); + return absl::OkStatus(); +} + +absl::Status TypeCheckerBuilderImpl::AddLibrarySubset( + TypeCheckerSubset subset) { + if (subset.library_id.empty()) { + return absl::InvalidArgumentError( + "library_id must not be empty for subset"); + } + std::string id = subset.library_id; + if (!subsets_.insert({id, std::move(subset)}).second) { + return absl::AlreadyExistsError( + absl::StrCat("library subset for '", id, "' already exists")); + } + return absl::OkStatus(); +} + +absl::Status TypeCheckerBuilderImpl::AddVariable(const VariableDecl& decl) { + CEL_RETURN_IF_ERROR( + ValidateVariableDecl(decl, options_.enable_type_parameter_name_validation, + options_.max_type_decl_nesting)); + target_config_->variables.push_back({decl, AddSemantic::kInsertIfAbsent}); + return absl::OkStatus(); +} + +absl::Status TypeCheckerBuilderImpl::AddOrReplaceVariable( + const VariableDecl& decl) { + CEL_RETURN_IF_ERROR( + ValidateVariableDecl(decl, options_.enable_type_parameter_name_validation, + options_.max_type_decl_nesting)); + target_config_->variables.push_back({decl, AddSemantic::kInsertOrReplace}); + return absl::OkStatus(); +} + +absl::Status TypeCheckerBuilderImpl::AddContextDeclaration( + absl::string_view type) { + const google::protobuf::Descriptor* desc = + template_env_.descriptor_pool()->FindMessageTypeByName(type); + if (desc == nullptr) { + return absl::NotFoundError( + absl::StrCat("context declaration '", type, "' not found")); + } + + if (IsWellKnownMessageType(desc) && + !options_.allow_well_known_type_context_declarations) { + return absl::InvalidArgumentError( + absl::StrCat("context declaration '", type, "' is not a struct")); + } + + for (const auto* context_type : target_config_->context_types) { + if (context_type->full_name() == desc->full_name()) { + return absl::AlreadyExistsError( + absl::StrCat("context declaration '", type, "' already exists")); + } + } + + target_config_->context_types.push_back(desc); + return absl::OkStatus(); +} + +absl::Status TypeCheckerBuilderImpl::AddContextDeclarationWithProtoTypeMask( + absl::string_view type, std::vector field_paths) { + if (field_paths.empty()) { + return absl::InvalidArgumentError("field paths cannot be the empty set"); + } + + ProtoTypeMask proto_type_mask(std::string(type), field_paths); + target_config_->proto_type_masks.push_back(proto_type_mask); + + CEL_RETURN_IF_ERROR(AddContextDeclaration(type)); + CEL_ASSIGN_OR_RETURN( + absl::btree_set field_names, + proto_type_mask.GetFieldNames(template_env_.descriptor_pool())); + target_config_->context_type_fields.insert({type, std::move(field_names)}); + return absl::OkStatus(); +} + +absl::Status TypeCheckerBuilderImpl::AddFunction(const FunctionDecl& decl) { + CEL_RETURN_IF_ERROR( + ValidateFunctionDecl(decl, options_.enable_type_parameter_name_validation, + options_.max_type_decl_nesting)); + target_config_->functions.push_back( + {std::move(decl), AddSemantic::kInsertIfAbsent}); + return absl::OkStatus(); +} + +absl::Status TypeCheckerBuilderImpl::MergeFunction(const FunctionDecl& decl) { + CEL_RETURN_IF_ERROR( + ValidateFunctionDecl(decl, options_.enable_type_parameter_name_validation, + options_.max_type_decl_nesting)); + target_config_->functions.push_back( + {std::move(decl), AddSemantic::kTryMerge}); + return absl::OkStatus(); +} + +void TypeCheckerBuilderImpl::AddTypeProvider( + std::unique_ptr provider) { + target_config_->type_providers.push_back(std::move(provider)); +} + +void TypeCheckerBuilderImpl::set_container(absl::string_view container) { + if (!expression_container_.has_value()) { + expression_container_.emplace(); + } + expression_container_->SetContainer(container).IgnoreError(); +} + +void TypeCheckerBuilderImpl::SetExpressionContainer( + ExpressionContainer container) { + expression_container_ = std::move(container); +} + +void TypeCheckerBuilderImpl::SetExpectedType(const Type& type) { + expected_type_ = type; +} + +} // namespace cel::checker_internal diff --git a/checker/internal/type_checker_builder_impl.h b/checker/internal/type_checker_builder_impl.h new file mode 100644 index 000000000..9895a8aee --- /dev/null +++ b/checker/internal/type_checker_builder_impl.h @@ -0,0 +1,170 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_TYPE_CHECKER_BUILDER_IMPL_H_ +#define THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_TYPE_CHECKER_BUILDER_IMPL_H_ + +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/container/btree_set.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "checker/checker_options.h" +#include "checker/internal/proto_type_mask.h" +#include "checker/internal/type_check_env.h" +#include "checker/type_checker.h" +#include "checker/type_checker_builder.h" +#include "common/container.h" +#include "common/decl.h" +#include "common/type.h" +#include "common/type_introspector.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" + +namespace cel::checker_internal { + +// Builder for TypeChecker instances. +class TypeCheckerBuilderImpl : public TypeCheckerBuilder { + public: + TypeCheckerBuilderImpl( + absl_nonnull std::shared_ptr + descriptor_pool, + const CheckerOptions& options) + : options_(options), + target_config_(&default_config_), + template_env_(std::move(descriptor_pool)) {} + + // Constructor for building an extended TypeChecker. + explicit TypeCheckerBuilderImpl(const CheckerOptions& options, + const TypeCheckEnv& template_env) + : options_(options), + target_config_(&default_config_), + template_env_(template_env) { + if (auto arena = template_env_.arena(); arena != nullptr) { + type_arena_ = std::move(arena); + } + } + + // Move only. + TypeCheckerBuilderImpl(const TypeCheckerBuilderImpl&) = delete; + TypeCheckerBuilderImpl(TypeCheckerBuilderImpl&&) = default; + TypeCheckerBuilderImpl& operator=(const TypeCheckerBuilderImpl&) = delete; + TypeCheckerBuilderImpl& operator=(TypeCheckerBuilderImpl&&) = default; + + absl::StatusOr> Build() override; + + absl::Status AddLibrary(CheckerLibrary library) override; + absl::Status AddLibrarySubset(TypeCheckerSubset subset) override; + + absl::Status AddVariable(const VariableDecl& decl) override; + absl::Status AddOrReplaceVariable(const VariableDecl& decl) override; + absl::Status AddContextDeclaration(absl::string_view type) override; + absl::Status AddContextDeclarationWithProtoTypeMask( + absl::string_view type, std::vector field_paths) override; + + absl::Status AddFunction(const FunctionDecl& decl) override; + absl::Status MergeFunction(const FunctionDecl& decl) override; + + void SetExpectedType(const Type& type) override; + + void AddTypeProvider(std::unique_ptr provider) override; + + void set_container(absl::string_view container) override; + + void SetExpressionContainer( + ExpressionContainer expression_container) override; + + const CheckerOptions& options() const override { return options_; } + + google::protobuf::Arena* absl_nonnull arena() override { + if (type_arena_ == nullptr) { + type_arena_ = std::make_shared(); + } + return type_arena_.get(); + } + + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool() const override { + return template_env_.descriptor_pool(); + } + + private: + // Sematic for adding a possibly duplicated declaration. + enum class AddSemantic { + kInsertIfAbsent, + kInsertOrReplace, + // Attempts to merge with any existing overloads for the same function. + // Will fail if any of the IDs or signatures collide. + kTryMerge, + }; + + struct VariableDeclRecord { + VariableDecl decl; + AddSemantic add_semantic; + }; + + struct FunctionDeclRecord { + FunctionDecl decl; + AddSemantic add_semantic; + }; + + // A record of configuration calls. + // Used to replay the configuration in calls to Build(). + struct ConfigRecord { + std::string id = ""; + std::vector variables; + std::vector functions; + std::vector> type_providers; + std::vector context_types; + // Maps context type names to fields names to add as variables. + // Only includes context types that are defined with proto type masks. + absl::flat_hash_map> + context_type_fields; + std::vector proto_type_masks; + }; + + absl::Status BuildLibraryConfig(const CheckerLibrary& library, + ConfigRecord* absl_nonnull config); + + absl::Status ApplyConfig(ConfigRecord config, const TypeCheckerSubset* subset, + TypeCheckEnv& env); + + absl::Status ConfigureTypeCheckEnv(TypeCheckEnv& env); + + CheckerOptions options_; + // Default target for configuration changes. Used for direct calls to + // AddVariable, AddFunction, etc. + ConfigRecord default_config_; + // Active target for configuration changes. + // This is used to track which library the change is made on behalf of. + ConfigRecord* absl_nonnull target_config_; + TypeCheckEnv template_env_; + std::shared_ptr type_arena_; + std::vector libraries_; + absl::flat_hash_map subsets_; + absl::flat_hash_set library_ids_; + absl::optional expression_container_; + absl::optional expected_type_; +}; + +} // namespace cel::checker_internal + +#endif // THIRD_PARTY_CEL_CPP_CHECKER_TYPE_CHECKER_BUILDER_H_ diff --git a/checker/internal/type_checker_builder_impl_test.cc b/checker/internal/type_checker_builder_impl_test.cc new file mode 100644 index 000000000..fa7f80960 --- /dev/null +++ b/checker/internal/type_checker_builder_impl_test.cc @@ -0,0 +1,659 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/internal/type_checker_builder_impl.h" + +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "checker/checker_options.h" +#include "checker/internal/test_ast_helpers.h" +#include "checker/type_checker.h" +#include "checker/type_checker_builder.h" +#include "checker/validation_result.h" +#include "common/ast.h" +#include "common/decl.h" +#include "common/type.h" +#include "common/type_introspector.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "google/protobuf/arena.h" + +namespace cel::checker_internal { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; + +struct ContextDeclsTestCase { + std::string expr; + TypeSpec expected_type; +}; + +class ContextDeclsFieldsDefinedTest + : public testing::TestWithParam {}; + +TEST_P(ContextDeclsFieldsDefinedTest, ContextDeclsFieldsDefined) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT( + builder.AddContextDeclaration("cel.expr.conformance.proto3.TestAllTypes"), + IsOk()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker, + builder.Build()); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst(GetParam().expr)); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + type_checker->Check(std::move(ast))); + + ASSERT_TRUE(result.IsValid()); + + EXPECT_EQ(result.GetAst()->GetReturnType(), GetParam().expected_type); +} + +INSTANTIATE_TEST_SUITE_P( + TestAllTypes, ContextDeclsFieldsDefinedTest, + testing::Values( + ContextDeclsTestCase{"single_int64", TypeSpec(PrimitiveType::kInt64)}, + ContextDeclsTestCase{"single_uint32", TypeSpec(PrimitiveType::kUint64)}, + ContextDeclsTestCase{"single_double", TypeSpec(PrimitiveType::kDouble)}, + ContextDeclsTestCase{"single_string", TypeSpec(PrimitiveType::kString)}, + ContextDeclsTestCase{"single_any", TypeSpec(WellKnownTypeSpec::kAny)}, + ContextDeclsTestCase{"single_duration", + TypeSpec(WellKnownTypeSpec::kDuration)}, + ContextDeclsTestCase{ + "single_bool_wrapper", + TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBool))}, + ContextDeclsTestCase{ + "list_value", + TypeSpec(ListTypeSpec(std::make_unique(DynTypeSpec())))}, + ContextDeclsTestCase{ + "standalone_message", + TypeSpec(MessageTypeSpec( + "cel.expr.conformance.proto3.TestAllTypes.NestedMessage"))}, + ContextDeclsTestCase{"standalone_enum", + TypeSpec(PrimitiveType::kInt64)}, + ContextDeclsTestCase{"repeated_bytes", + TypeSpec(ListTypeSpec(std::make_unique( + PrimitiveType::kBytes)))}, + ContextDeclsTestCase{ + "repeated_nested_message", + TypeSpec(ListTypeSpec(std::make_unique(MessageTypeSpec( + "cel.expr.conformance.proto3.TestAllTypes.NestedMessage"))))}, + ContextDeclsTestCase{ + "map_int32_timestamp", + TypeSpec(MapTypeSpec( + std::make_unique(PrimitiveType::kInt64), + std::make_unique(WellKnownTypeSpec::kTimestamp)))}, + ContextDeclsTestCase{ + "single_struct", + TypeSpec( + MapTypeSpec(std::make_unique(PrimitiveType::kString), + std::make_unique(DynTypeSpec())))})); + +TEST(ContextDeclsWithProtoTypeMaskTest, ErrorOnEmptyFieldPaths) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT(builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {}), + StatusIs(absl::StatusCode::kInvalidArgument, + "field paths cannot be the empty set")); +} + +TEST(ContextDeclsWithProtoTypeMaskTest, ErrorOnUnknownFieldPath) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT( + builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"unknown_field"}), + StatusIs(absl::StatusCode::kInvalidArgument, + "could not select field 'unknown_field' from type " + "'cel.expr.conformance.proto3.TestAllTypes'")); +} + +class ContextDeclsWithProtoTypeMaskFieldsDefinedTest + : public testing::TestWithParam {}; + +std::string LogFieldName(absl::string_view field_name, absl::string_view expr) { + return absl::StrCat("field_name: ", field_name, ", expr: ", expr); +} + +TEST_P(ContextDeclsWithProtoTypeMaskFieldsDefinedTest, + ContextDeclsWithProtoTypeMaskFieldsDefined) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT( + builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {GetParam().expr}), + IsOk()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker, + builder.Build()); + std::vector field_names = { + "single_int64", "single_uint32", "single_double", + "single_string", "single_any", "single_duration", + "single_bool_wrapper", "list_value", "standalone_message", + "standalone_enum", "repeated_bytes", "repeated_nested_message", + "map_int32_timestamp", "single_struct"}; + for (auto& field_name : field_names) { + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst(field_name)); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + type_checker->Check(std::move(ast))); + if (field_name == GetParam().expr) { + // The field name that is part of the proto type mask is visible. + ASSERT_TRUE(result.IsValid()) + << LogFieldName(field_name, GetParam().expr); + EXPECT_EQ(result.GetAst()->GetReturnType(), GetParam().expected_type) + << LogFieldName(field_name, GetParam().expr); + } else { + // The field names that are not part of the proto type mask are not + // visible. + EXPECT_FALSE(result.IsValid()) + << LogFieldName(field_name, GetParam().expr); + } + } +} + +INSTANTIATE_TEST_SUITE_P( + TestAllTypes, ContextDeclsWithProtoTypeMaskFieldsDefinedTest, + testing::Values( + ContextDeclsTestCase{"single_int64", TypeSpec(PrimitiveType::kInt64)}, + ContextDeclsTestCase{"single_uint32", TypeSpec(PrimitiveType::kUint64)}, + ContextDeclsTestCase{"single_double", TypeSpec(PrimitiveType::kDouble)}, + ContextDeclsTestCase{"single_string", TypeSpec(PrimitiveType::kString)}, + ContextDeclsTestCase{"single_any", TypeSpec(WellKnownTypeSpec::kAny)}, + ContextDeclsTestCase{"single_duration", + TypeSpec(WellKnownTypeSpec::kDuration)}, + ContextDeclsTestCase{ + "single_bool_wrapper", + TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBool))}, + ContextDeclsTestCase{ + "list_value", + TypeSpec(ListTypeSpec(std::make_unique(DynTypeSpec())))}, + ContextDeclsTestCase{ + "standalone_message", + TypeSpec(MessageTypeSpec( + "cel.expr.conformance.proto3.TestAllTypes.NestedMessage"))}, + ContextDeclsTestCase{"standalone_enum", + TypeSpec(PrimitiveType::kInt64)}, + ContextDeclsTestCase{"repeated_bytes", + TypeSpec(ListTypeSpec(std::make_unique( + PrimitiveType::kBytes)))}, + ContextDeclsTestCase{ + "repeated_nested_message", + TypeSpec(ListTypeSpec(std::make_unique(MessageTypeSpec( + "cel.expr.conformance.proto3.TestAllTypes.NestedMessage"))))}, + ContextDeclsTestCase{ + "map_int32_timestamp", + TypeSpec(MapTypeSpec( + std::make_unique(PrimitiveType::kInt64), + std::make_unique(WellKnownTypeSpec::kTimestamp)))}, + ContextDeclsTestCase{ + "single_struct", + TypeSpec( + MapTypeSpec(std::make_unique(PrimitiveType::kString), + std::make_unique(DynTypeSpec())))})); + +TEST(ContextDeclsWithProtoTypeMaskTest, FieldsInMaskAreVisibleFieldAccess) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT(builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto3.NestedTestAllTypes", + {"payload.standalone_message.bb", "payload.single_int32"}), + IsOk()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker, + builder.Build()); + // Visible field: standalone_message.bb + ASSERT_OK_AND_ASSIGN(auto ast, + MakeTestParsedAst("payload.standalone_message.bb")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + type_checker->Check(std::move(ast))); + ASSERT_TRUE(result.IsValid()); + EXPECT_EQ(result.GetAst()->GetReturnType(), TypeSpec(PrimitiveType::kInt64)); + // Visible field: single_int32 + ASSERT_OK_AND_ASSIGN(ast, MakeTestParsedAst("payload.single_int32")); + ASSERT_OK_AND_ASSIGN(result, type_checker->Check(std::move(ast))); + ASSERT_TRUE(result.IsValid()); + EXPECT_EQ(result.GetAst()->GetReturnType(), TypeSpec(PrimitiveType::kInt64)); + // Not Visible field: single_int64 + ASSERT_OK_AND_ASSIGN(ast, MakeTestParsedAst("payload.single_int64")); + ASSERT_OK_AND_ASSIGN(result, type_checker->Check(std::move(ast))); + ASSERT_FALSE(result.IsValid()); +} + +TEST(ContextDeclsWithProtoTypeMaskTest, FieldsInMaskAreVisibleFieldAssignment) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT(builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto3.NestedTestAllTypes", + {"payload.standalone_message.bb", "payload.single_int32"}), + IsOk()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker, + builder.Build()); + // Visible field: standalone_message.bb + ASSERT_OK_AND_ASSIGN( + auto ast, + MakeTestParsedAst( + R"(cel.expr.conformance.proto3.TestAllTypes.NestedMessage{bb: 12345})")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + type_checker->Check(std::move(ast))); + ASSERT_TRUE(result.IsValid()); + // Visible field: single_int32 + ASSERT_OK_AND_ASSIGN( + ast, + MakeTestParsedAst( + R"(cel.expr.conformance.proto3.TestAllTypes{single_int32: 12345})")); + ASSERT_OK_AND_ASSIGN(result, type_checker->Check(std::move(ast))); + ASSERT_TRUE(result.IsValid()); + // Not Visible field: single_int64 + ASSERT_OK_AND_ASSIGN( + ast, + MakeTestParsedAst( + R"(cel.expr.conformance.proto3.TestAllTypes{single_int64: 12345})")); + ASSERT_OK_AND_ASSIGN(result, type_checker->Check(std::move(ast))); + ASSERT_FALSE(result.IsValid()); +} + +TEST(ContextDeclsTest, ErrorOnDuplicateContextDeclaration) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT( + builder.AddContextDeclaration("cel.expr.conformance.proto3.TestAllTypes"), + IsOk()); + EXPECT_THAT( + builder.AddContextDeclaration("cel.expr.conformance.proto3.TestAllTypes"), + StatusIs(absl::StatusCode::kAlreadyExists, + "context declaration 'cel.expr.conformance.proto3.TestAllTypes' " + "already exists")); +} + +TEST(ContextDeclsWithProtoTypeMaskTest, ErrorOnDuplicateContextDeclaration) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT( + builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"standalone_message"}), + IsOk()); + EXPECT_THAT( + builder.AddContextDeclaration("cel.expr.conformance.proto3.TestAllTypes"), + StatusIs(absl::StatusCode::kAlreadyExists, + "context declaration 'cel.expr.conformance.proto3.TestAllTypes' " + "already exists")); +} + +TEST(ContextDeclsTest, ErrorOnContextDeclarationNotFound) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + EXPECT_THAT( + builder.AddContextDeclaration("com.example.UnknownType"), + StatusIs(absl::StatusCode::kNotFound, + "context declaration 'com.example.UnknownType' not found")); +} + +TEST(ContextDeclsWithProtoTypeMaskTest, ErrorOnContextDeclarationNotFound) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + EXPECT_THAT( + builder.AddContextDeclarationWithProtoTypeMask("com.example.UnknownType", + {"any_field_name"}), + StatusIs(absl::StatusCode::kNotFound, + "context declaration 'com.example.UnknownType' not found")); +} + +TEST(ContextDeclsTest, ErrorOnNonStructMessageType) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + EXPECT_THAT( + builder.AddContextDeclaration("google.protobuf.Timestamp"), + StatusIs( + absl::StatusCode::kInvalidArgument, + "context declaration 'google.protobuf.Timestamp' is not a struct")); +} + +TEST(ContextDeclsWithProtoTypeMaskTest, ErrorOnNonStructMessageType) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + EXPECT_THAT( + builder.AddContextDeclarationWithProtoTypeMask( + "google.protobuf.Timestamp", {"any_field_name"}), + StatusIs( + absl::StatusCode::kInvalidArgument, + "context declaration 'google.protobuf.Timestamp' is not a struct")); +} + +TEST(ContextDeclsTest, CustomStructNotSupported) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + class MyTypeProvider : public cel::TypeIntrospector { + public: + absl::StatusOr> FindTypeImpl( + absl::string_view name) const override { + if (name == "com.example.MyStruct") { + return common_internal::MakeBasicStructType("com.example.MyStruct"); + } + return std::nullopt; + } + }; + + builder.AddTypeProvider(std::make_unique()); + + EXPECT_THAT(builder.AddContextDeclaration("com.example.MyStruct"), + StatusIs(absl::StatusCode::kNotFound, + "context declaration 'com.example.MyStruct' not found")); +} + +TEST(ContextDeclsWithProtoTypeMaskTest, CustomStructNotSupported) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + class MyTypeProvider : public cel::TypeIntrospector { + public: + absl::StatusOr> FindTypeImpl( + absl::string_view name) const override { + if (name == "com.example.MyStruct") { + return common_internal::MakeBasicStructType("com.example.MyStruct"); + } + return std::nullopt; + } + }; + + builder.AddTypeProvider(std::make_unique()); + + EXPECT_THAT(builder.AddContextDeclarationWithProtoTypeMask( + "com.example.MyStruct", {"any_field_name"}), + StatusIs(absl::StatusCode::kNotFound, + "context declaration 'com.example.MyStruct' not found")); +} + +TEST(ContextDeclsTest, ErrorOnOverlappingContextDeclaration) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT( + builder.AddContextDeclaration("cel.expr.conformance.proto3.TestAllTypes"), + IsOk()); + // We resolve the context declaration variables at the Build() call, so the + // error surfaces then. + ASSERT_THAT( + builder.AddContextDeclaration("cel.expr.conformance.proto2.TestAllTypes"), + IsOk()); + + EXPECT_THAT( + builder.Build(), + StatusIs(absl::StatusCode::kAlreadyExists, + "variable 'single_int32' declared multiple times (from context " + "declaration: 'cel.expr.conformance.proto2.TestAllTypes')")); +} + +TEST(ContextDeclsWithProtoTypeMaskTest, ErrorOnOverlappingContextDeclaration) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT( + builder.AddContextDeclaration("cel.expr.conformance.proto3.TestAllTypes"), + IsOk()); + // We resolve the context declaration variables at the Build() call, so the + // error surfaces then. + ASSERT_THAT(builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto2.TestAllTypes", {"single_int32"}), + IsOk()); + + EXPECT_THAT( + builder.Build(), + StatusIs(absl::StatusCode::kAlreadyExists, + "variable 'single_int32' declared multiple times (from context " + "declaration: 'cel.expr.conformance.proto2.TestAllTypes')")); +} + +TEST(ContextDeclsWithProtoTypeMaskTest, + ErrorOnOverlappingContextDeclarationBothProtoTypeMasks) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT(builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"single_int32"}), + IsOk()); + // We resolve the context declaration variables at the Build() call, so the + // error surfaces then. + ASSERT_THAT(builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto2.TestAllTypes", {"single_int32"}), + IsOk()); + + EXPECT_THAT( + builder.Build(), + StatusIs(absl::StatusCode::kAlreadyExists, + "variable 'single_int32' declared multiple times (from context " + "declaration: 'cel.expr.conformance.proto2.TestAllTypes')")); +} + +TEST(ContextDeclsWithProtoTypeMaskTest, + NonOverlappingContextDeclarationBothProtoTypeMasks) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT(builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"single_int32"}), + IsOk()); + // We resolve the context declaration variables at the Build() call, so the + // error surfaces then. + ASSERT_THAT(builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto2.NestedTestAllTypes", + {"payload.single_int64"}), + IsOk()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker, + builder.Build()); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("single_int32")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + type_checker->Check(std::move(ast))); + ASSERT_TRUE(result.IsValid()); + ASSERT_OK_AND_ASSIGN(ast, MakeTestParsedAst("payload.single_int64")); + ASSERT_OK_AND_ASSIGN(result, type_checker->Check(std::move(ast))); +} + +TEST(ContextDeclsTest, ErrorOnOverlappingVariableDeclaration) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT( + builder.AddContextDeclaration("cel.expr.conformance.proto3.TestAllTypes"), + IsOk()); + ASSERT_THAT(builder.AddVariable(MakeVariableDecl("single_int64", IntType())), + IsOk()); + + EXPECT_THAT(builder.Build(), + StatusIs(absl::StatusCode::kAlreadyExists, + "variable 'single_int64' declared multiple times")); +} + +TEST(ContextDeclsWithProtoTypeMaskTest, ErrorOnOverlappingVariableDeclaration) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT(builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"single_int64"}), + IsOk()); + ASSERT_THAT(builder.AddVariable(MakeVariableDecl("single_int64", IntType())), + IsOk()); + + EXPECT_THAT(builder.Build(), + StatusIs(absl::StatusCode::kAlreadyExists, + "variable 'single_int64' declared multiple times")); +} + +TEST(ContextDeclsWithProtoTypeMaskTest, NonOverlappingVariableDeclaration) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT(builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"single_int32"}), + IsOk()); + ASSERT_THAT(builder.AddVariable(MakeVariableDecl("single_int64", IntType())), + IsOk()); + + EXPECT_THAT(builder.Build(), IsOk()); +} + +TEST(TypeCheckerBuilderImplTest, + InvalidTypeParamNameVariableValidationDisabled) { + CheckerOptions options; + options.enable_type_parameter_name_validation = false; + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + options); + ASSERT_THAT(builder.AddVariable(MakeVariableDecl("x", TypeParamType(""))), + IsOk()); + ASSERT_THAT(builder.AddOrReplaceVariable( + MakeVariableDecl("x", TypeParamType("T% foo"))), + IsOk()); +} + +TEST(TypeCheckerBuilderImplTest, ErrorOnUnspecifiedMessageType) { + CheckerOptions options; + options.enable_type_parameter_name_validation = true; + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + options); + ASSERT_THAT( + builder.AddVariable(MakeVariableDecl("x", MessageType())), + StatusIs(absl::StatusCode::kInvalidArgument, + "an empty message type cannot be used in a type declaration")); +} + +TEST(TypeCheckerBuilderImplTest, ErrorOnInvalidTypeParamNameVariable) { + CheckerOptions options; + options.enable_type_parameter_name_validation = true; + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + options); + ASSERT_THAT(builder.AddVariable(MakeVariableDecl("x", TypeParamType(""))), + StatusIs(absl::StatusCode::kInvalidArgument, + "type parameter name '' is not a valid identifier")); + ASSERT_THAT( + builder.AddOrReplaceVariable( + MakeVariableDecl("x", TypeParamType("T% foo"))), + StatusIs(absl::StatusCode::kInvalidArgument, + "type parameter name 'T% foo' is not a valid identifier")); +} + +TEST(TypeCheckerBuilderImplTest, ErrorOnTooDeepTypeNestingVariable) { + CheckerOptions options; + options.max_type_decl_nesting = 2; + google::protobuf::Arena arena; + + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + options); + ASSERT_THAT(builder.AddVariable( + MakeVariableDecl("x", TypeType(&arena, TypeParamType("T")))), + IsOk()); + ASSERT_THAT( + builder.AddOrReplaceVariable(MakeVariableDecl( + "x", TypeType(&arena, TypeType(&arena, TypeParamType("T% foo"))))), + StatusIs(absl::StatusCode::kInvalidArgument, + "type nesting limit of 2 exceeded")); +} + +TEST(TypeCheckerBuilderImplTest, ErrorOnInvalidTypeParamNameFunction) { + CheckerOptions options; + options.enable_type_parameter_name_validation = true; + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + options); + google::protobuf::Arena arena; + + ASSERT_OK_AND_ASSIGN( + auto fn_decl, + MakeFunctionDecl( + "type2", + MakeOverloadDecl("type2", TypeType(&arena, TypeParamType("")), + TypeParamType("")))); + ASSERT_THAT(builder.AddFunction(fn_decl), + StatusIs(absl::StatusCode::kInvalidArgument, + "type parameter name '' is not a valid identifier")); +} + +TEST(TypeCheckerBuilderImplTest, ErrorOnTooDeepTypeNestingFunction) { + CheckerOptions options; + options.max_type_decl_nesting = 2; + google::protobuf::Arena arena; + + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + options); + ASSERT_OK_AND_ASSIGN( + auto fn_decl, + MakeFunctionDecl( + "add", MakeOverloadDecl("add_int", IntType(), IntType(), IntType()))); + ASSERT_THAT(builder.AddFunction(fn_decl), IsOk()); + + Type list_type = ListType(&arena, ListType(&arena, IntType())); + + ASSERT_OK_AND_ASSIGN( + fn_decl, + MakeFunctionDecl("add", MakeOverloadDecl("add_list_list_int", list_type, + list_type, list_type))); + + ASSERT_THAT(builder.MergeFunction(fn_decl), + StatusIs(absl::StatusCode::kInvalidArgument, + "type nesting limit of 2 exceeded")); +} + +TEST(TypeCheckerBuilderImplTest, ReplaceVariable) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT( + builder.AddContextDeclaration("cel.expr.conformance.proto3.TestAllTypes"), + IsOk()); + ASSERT_THAT(builder.AddOrReplaceVariable( + MakeVariableDecl("single_int64", StringType())), + IsOk()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker, + builder.Build()); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("single_int64")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + type_checker->Check(std::move(ast))); + + ASSERT_TRUE(result.IsValid()); + + const auto& checked_ast = *result.GetAst(); + + EXPECT_EQ(checked_ast.GetReturnType(), TypeSpec(PrimitiveType::kString)); +} + +TEST(TypeCheckerBuilderImplTest, LazyArenaInitialization) { + auto builder = std::make_unique( + internal::GetSharedTestingDescriptorPool(), CheckerOptions{}); + + ASSERT_THAT(builder->AddLibrary(CheckerLibrary{ + .id = "test_lib", + .configure = [](TypeCheckerBuilder& builder) -> absl::Status { + auto l = ListType(builder.arena(), IntType()); + return builder.AddVariable(MakeVariableDecl("foo", l)); + }, + }), + IsOk()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker, + builder->Build()); + builder.reset(); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("foo")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + type_checker->Check(std::move(ast))); + + ASSERT_TRUE(result.IsValid()); + + const auto& checked_ast = *result.GetAst(); + + EXPECT_EQ(checked_ast.GetReturnType(), + TypeSpec(ListTypeSpec( + std::make_unique(PrimitiveType::kInt64)))); +} + +} // namespace +} // namespace cel::checker_internal diff --git a/checker/internal/type_checker_impl.cc b/checker/internal/type_checker_impl.cc new file mode 100644 index 000000000..bca187417 --- /dev/null +++ b/checker/internal/type_checker_impl.cc @@ -0,0 +1,1443 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/internal/type_checker_impl.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "checker/checker_options.h" +#include "checker/internal/namespace_generator.h" +#include "checker/internal/type_check_env.h" +#include "checker/internal/type_checker_builder_impl.h" +#include "checker/internal/type_inference_context.h" +#include "checker/type_check_issue.h" +#include "checker/type_checker_builder.h" +#include "checker/validation_result.h" +#include "common/ast.h" +#include "common/ast_rewrite.h" +#include "common/ast_traverse.h" +#include "common/ast_visitor.h" +#include "common/ast_visitor_base.h" +#include "common/constant.h" +#include "common/decl.h" +#include "common/expr.h" +#include "common/format_type_name.h" +#include "common/standard_definitions.h" +#include "common/type.h" +#include "common/type_kind.h" +#include "internal/status_macros.h" +#include "google/protobuf/arena.h" + +namespace cel::checker_internal { +namespace { + +bool MatchesBlock(const Expr& expr) { + if (!expr.has_call_expr()) { + return false; + } + const auto& call = expr.call_expr(); + return call.function() == "cel.@block" && call.args().size() == 2 && + call.args()[0].has_list_expr(); +} + +using AstType = cel::TypeSpec; +using Severity = TypeCheckIssue::Severity; + +constexpr const char kOptionalSelect[] = "_?._"; + +std::string FormatCandidate(absl::Span qualifiers) { + return absl::StrJoin(qualifiers, "."); +} + +// Flatten the type to the AST type representation to remove any lifecycle +// dependency between the type check environment and the AST. +// +// TODO(uncreated-issue/72): It may be better to do this at the point of serialization +// in the future, but requires corresponding change for the runtime to correctly +// rehydrate the serialized Ast. +absl::StatusOr FlattenType(const Type& type); + +absl::StatusOr FlattenAbstractType(const OpaqueType& type) { + std::vector parameter_types; + parameter_types.reserve(type.GetParameters().size()); + for (const auto& param : type.GetParameters()) { + CEL_ASSIGN_OR_RETURN(auto param_type, FlattenType(param)); + parameter_types.push_back(std::move(param_type)); + } + + return AstType( + AbstractType(std::string(type.name()), std::move(parameter_types))); +} + +absl::StatusOr FlattenMapType(const MapType& type) { + CEL_ASSIGN_OR_RETURN(auto key, FlattenType(type.key())); + CEL_ASSIGN_OR_RETURN(auto value, FlattenType(type.value())); + + return AstType(MapTypeSpec(std::make_unique(std::move(key)), + std::make_unique(std::move(value)))); +} + +absl::StatusOr FlattenListType(const ListType& type) { + CEL_ASSIGN_OR_RETURN(auto elem, FlattenType(type.element())); + + return AstType(ListTypeSpec(std::make_unique(std::move(elem)))); +} + +absl::StatusOr FlattenMessageType(const StructType& type) { + return AstType(MessageTypeSpec(std::string(type.name()))); +} + +absl::StatusOr FlattenTypeType(const TypeType& type) { + if (type.GetParameters().size() > 1) { + return absl::InternalError( + absl::StrCat("Unsupported type: ", type.DebugString())); + } + if (type.GetParameters().empty()) { + return AstType(std::make_unique()); + } + CEL_ASSIGN_OR_RETURN(auto param, FlattenType(type.GetParameters()[0])); + return AstType(std::make_unique(std::move(param))); +} + +absl::StatusOr FlattenType(const Type& type) { + switch (type.kind()) { + case TypeKind::kDyn: + return AstType(DynTypeSpec()); + case TypeKind::kError: + return AstType(ErrorTypeSpec()); + case TypeKind::kNull: + return AstType(NullTypeSpec()); + case TypeKind::kBool: + return AstType(PrimitiveType::kBool); + case TypeKind::kInt: + return AstType(PrimitiveType::kInt64); + case TypeKind::kEnum: + return AstType(PrimitiveType::kInt64); + case TypeKind::kUint: + return AstType(PrimitiveType::kUint64); + case TypeKind::kDouble: + return AstType(PrimitiveType::kDouble); + case TypeKind::kString: + return AstType(PrimitiveType::kString); + case TypeKind::kBytes: + return AstType(PrimitiveType::kBytes); + case TypeKind::kDuration: + return AstType(WellKnownTypeSpec::kDuration); + case TypeKind::kTimestamp: + return AstType(WellKnownTypeSpec::kTimestamp); + case TypeKind::kStruct: + return FlattenMessageType(type.GetStruct()); + case TypeKind::kList: + return FlattenListType(type.GetList()); + case TypeKind::kMap: + return FlattenMapType(type.GetMap()); + case TypeKind::kOpaque: + return FlattenAbstractType(type.GetOpaque()); + case TypeKind::kBoolWrapper: + return AstType(PrimitiveTypeWrapper(PrimitiveType::kBool)); + case TypeKind::kIntWrapper: + return AstType(PrimitiveTypeWrapper(PrimitiveType::kInt64)); + case TypeKind::kUintWrapper: + return AstType(PrimitiveTypeWrapper(PrimitiveType::kUint64)); + case TypeKind::kDoubleWrapper: + return AstType(PrimitiveTypeWrapper(PrimitiveType::kDouble)); + case TypeKind::kStringWrapper: + return AstType(PrimitiveTypeWrapper(PrimitiveType::kString)); + case TypeKind::kBytesWrapper: + return AstType(PrimitiveTypeWrapper(PrimitiveType::kBytes)); + case TypeKind::kTypeParam: + // Convert any remaining free type params to dyn. + return AstType(DynTypeSpec()); + case TypeKind::kType: + return FlattenTypeType(type.GetType()); + case TypeKind::kAny: + return AstType(WellKnownTypeSpec::kAny); + default: + return absl::InternalError( + absl::StrCat("unsupported type encountered making AST serializable: ", + type.DebugString())); + } +} + +class ResolveVisitor : public AstVisitorBase { + public: + struct FunctionResolution { + const FunctionDecl* decl; + bool namespace_rewrite; + }; + + struct AttributeResolution { + const VariableDecl* decl; + bool requires_disambiguation; + bool local; + }; + + ResolveVisitor(NamespaceGenerator namespace_generator, + const TypeCheckEnv& env, const Ast& ast, + TypeInferenceContext& inference_context, + std::vector& issues, + google::protobuf::Arena* absl_nonnull arena) + : namespace_generator_(std::move(namespace_generator)), + env_(&env), + inference_context_(&inference_context), + issues_(&issues), + ast_(&ast), + root_scope_(), + arena_(arena), + current_scope_(&root_scope_) {} + + void PreVisitExpr(const Expr& expr) override { + expr_stack_.push_back(&expr); + if (expr_stack_.size() == 1 && MatchesBlock(expr)) { + ABSL_DCHECK_EQ(expr.call_expr().args().size(), 2); + ABSL_DCHECK(block_init_list_ == nullptr); + block_init_list_ = &expr.call_expr().args()[0]; + } + } + + void PostVisitExpr(const Expr& expr) override { + if (expr_stack_.empty()) { + return; + } + expr_stack_.pop_back(); + if (expr_stack_.size() == 2 && expr_stack_.back() == block_init_list_) { + HandleBlockIndex(&expr); + } + } + + void PostVisitConst(const Expr& expr, const Constant& constant) override; + + void PreVisitComprehension(const Expr& expr, + const ComprehensionExpr& comprehension) override; + + void PostVisitComprehension(const Expr& expr, + const ComprehensionExpr& comprehension) override; + + void PostVisitMap(const Expr& expr, const MapExpr& map) override; + + void PostVisitList(const Expr& expr, const ListExpr& list) override; + + void PreVisitComprehensionSubexpression( + const Expr& expr, const ComprehensionExpr& comprehension, + ComprehensionArg comprehension_arg) override; + + void PostVisitComprehensionSubexpression( + const Expr& expr, const ComprehensionExpr& comprehension, + ComprehensionArg comprehension_arg) override; + + void PostVisitIdent(const Expr& expr, const IdentExpr& ident) override; + + void PostVisitSelect(const Expr& expr, const SelectExpr& select) override; + + void PostVisitCall(const Expr& expr, const CallExpr& call) override; + + void PostVisitStruct(const Expr& expr, + const StructExpr& create_struct) override; + + // Accessors for resolved values. + const absl::flat_hash_map& functions() + const { + return functions_; + } + + const absl::flat_hash_map& attributes() + const { + return attributes_; + } + + const absl::flat_hash_map& struct_types() const { + return struct_types_; + } + + const absl::flat_hash_map& types() const { return types_; } + + const absl::Status& status() const { return status_; } + + int error_count() const { return error_count_; } + + void AssertExpectedType(const Expr& expr, const Type& expected_type) { + Type observed = GetDeducedType(&expr); + if (!inference_context_->IsAssignable(observed, expected_type)) { + ReportTypeMismatch(expr.id(), expected_type, observed); + } + } + + private: + struct ComprehensionScope { + const Expr* comprehension_expr; + const VariableScope* parent; + VariableScope* accu_scope; + VariableScope* iter_scope; + }; + + struct FunctionOverloadMatch { + // Overall result type. + // If resolution is incomplete, this will be DynType. + Type result_type; + // A new declaration with the narrowed overload candidates. + // Owned by the Check call scoped arena. + const FunctionDecl* decl; + }; + + void ResolveSimpleIdentifier(const Expr& expr, absl::string_view name); + + void ResolveQualifiedIdentifier(const Expr& expr, + absl::Span qualifiers); + + // Resolves the function call shape (i.e. the number of arguments and call + // style) for the given function call. + const FunctionDecl* ResolveFunctionCallShape(const Expr& expr, + absl::string_view function_name, + int arg_count, bool is_receiver); + + // Resolves a global identifier (i.e. declared in the CEL environment). + const VariableDecl* absl_nullable LookupGlobalIdentifier( + absl::string_view name); + + // Resolves a local identifier (i.e. a bind or comprehension var). + const VariableDecl* absl_nullable LookupLocalIdentifier( + absl::string_view name); + + // Resolves the applicable function overloads for the given function call. + // + // If found, assigns a new function decl with the resolved overloads. + void ResolveFunctionOverloads(const Expr& expr, const FunctionDecl& decl, + int arg_count, bool is_receiver, + bool is_namespaced); + + void ResolveSelectOperation(const Expr& expr, absl::string_view field, + const Expr& operand); + + void ReportIssue(TypeCheckIssue issue) { + if (issue.severity() == Severity::kError) { + error_count_++; + } + issues_->push_back(std::move(issue)); + } + + void ReportMissingReference(const Expr& expr, absl::string_view name) { + ReportIssue(TypeCheckIssue::CreateError( + ast_->ComputeSourceLocation(expr.id()), + absl::StrCat("undeclared reference to '", name, "' (in container '", + env_->container().container(), "')"))); + } + + void ReportUndefinedField(int64_t expr_id, absl::string_view field_name, + absl::string_view struct_name) { + ReportIssue(TypeCheckIssue::CreateError( + ast_->ComputeSourceLocation(expr_id), + absl::StrCat("undefined field '", field_name, "' not found in struct '", + struct_name, "'"))); + } + + void ReportTypeMismatch(int64_t expr_id, const Type& expected, + const Type& actual) { + ReportIssue(TypeCheckIssue::CreateError( + ast_->ComputeSourceLocation(expr_id), + absl::StrCat("expected type '", + FormatTypeName(inference_context_->FinalizeType(expected)), + "' but found '", + FormatTypeName(inference_context_->FinalizeType(actual)), + "'"))); + } + + absl::Status CheckFieldAssignments(const Expr& expr, + const StructExpr& create_struct, + Type struct_type, + absl::string_view resolved_name) { + for (const auto& field : create_struct.fields()) { + const Expr* value = &field.value(); + Type value_type = GetDeducedType(value); + + // Lookup message type by name to support WellKnownType creation. + CEL_ASSIGN_OR_RETURN( + std::optional field_info, + env_->LookupStructField(resolved_name, field.name())); + if (!field_info.has_value()) { + ReportUndefinedField(field.id(), field.name(), resolved_name); + continue; + } + Type field_type = field_info->GetType(); + if (field.optional()) { + field_type = OptionalType(arena_, field_type); + } + if (!inference_context_->IsAssignable(value_type, field_type)) { + ReportIssue(TypeCheckIssue::CreateError( + ast_->ComputeSourceLocation(field.id()), + absl::StrCat( + "expected type of field '", field_info->name(), "' is '", + FormatTypeName(inference_context_->FinalizeType(field_type)), + "' but provided type is '", + FormatTypeName(inference_context_->FinalizeType(value_type)), + "'"))); + continue; + } + } + + return absl::OkStatus(); + } + + std::optional CheckFieldType(int64_t expr_id, const Type& operand_type, + absl::string_view field_name); + + void HandleOptSelect(const Expr& expr); + void HandleBlockIndex(const Expr* expr); + + // Get the assigned type of the given subexpression. Should only be called if + // the given subexpression is expected to have already been checked. + // + // If unknown, returns DynType as a placeholder and reports an error. + // Whether or not the subexpression is valid for the checker configuration, + // the type checker should have assigned a type (possibly ErrorType). If there + // is no assigned type, the type checker failed to handle the subexpression + // and should not attempt to continue type checking. + Type GetDeducedType(const Expr* expr) { + auto iter = types_.find(expr); + if (iter != types_.end()) { + return iter->second; + } + status_.Update(absl::InvalidArgumentError( + absl::StrCat("Could not deduce type for expression id: ", expr->id()))); + return DynType(); + } + + NamespaceGenerator namespace_generator_; + const TypeCheckEnv* absl_nonnull env_; + TypeInferenceContext* absl_nonnull inference_context_; + std::vector* absl_nonnull issues_; + const Ast* absl_nonnull ast_; + VariableScope root_scope_; + google::protobuf::Arena* absl_nonnull arena_; + + // state tracking for the traversal. + const VariableScope* current_scope_; + std::vector expr_stack_; + absl::flat_hash_map> + maybe_namespaced_functions_; + const Expr* block_init_list_ = nullptr; + // Select operations that need to be resolved outside of the traversal. + // These are handled separately to disambiguate between namespaces and field + // accesses + absl::flat_hash_set deferred_select_operations_; + std::vector> comprehension_vars_; + std::vector comprehension_scopes_; + absl::Status status_; + int error_count_ = 0; + + // References that were resolved and may require AST rewrites. + absl::flat_hash_map functions_; + absl::flat_hash_map attributes_; + absl::flat_hash_map struct_types_; + + absl::flat_hash_map types_; +}; + +void ResolveVisitor::PostVisitIdent(const Expr& expr, const IdentExpr& ident) { + if (expr_stack_.size() == 1) { + ResolveSimpleIdentifier(expr, ident.name()); + return; + } + + // Walk up the stack to find the qualifiers. + // + // If the identifier is the target of a receiver call, then note + // the function so we can disambiguate namespaced functions later. + int stack_pos = expr_stack_.size() - 1; + std::vector qualifiers; + qualifiers.push_back(ident.name()); + const Expr* receiver_call = nullptr; + const Expr* root_candidate = expr_stack_[stack_pos]; + + // Try to identify the root of the select chain, possibly as the receiver of + // a function call. + while (stack_pos > 0) { + --stack_pos; + const Expr* parent = expr_stack_[stack_pos]; + + if (parent->has_call_expr() && + (&parent->call_expr().target() == root_candidate)) { + receiver_call = parent; + break; + } else if (!parent->has_select_expr()) { + break; + } + + qualifiers.push_back(parent->select_expr().field()); + deferred_select_operations_.insert(parent); + root_candidate = parent; + if (parent->select_expr().test_only()) { + break; + } + } + + if (receiver_call == nullptr) { + ResolveQualifiedIdentifier(*root_candidate, qualifiers); + } else { + maybe_namespaced_functions_[receiver_call] = std::move(qualifiers); + } +} + +void ResolveVisitor::PostVisitConst(const Expr& expr, + const Constant& constant) { + switch (constant.kind().index()) { + case ConstantKindIndexOf(): + types_[&expr] = NullType(); + break; + case ConstantKindIndexOf(): + types_[&expr] = BoolType(); + break; + case ConstantKindIndexOf(): + types_[&expr] = IntType(); + break; + case ConstantKindIndexOf(): + types_[&expr] = UintType(); + break; + case ConstantKindIndexOf(): + types_[&expr] = DoubleType(); + break; + case ConstantKindIndexOf(): + types_[&expr] = BytesType(); + break; + case ConstantKindIndexOf(): + types_[&expr] = StringType(); + break; + case ConstantKindIndexOf(): + types_[&expr] = DurationType(); + break; + case ConstantKindIndexOf(): + types_[&expr] = TimestampType(); + break; + default: + ReportIssue(TypeCheckIssue::CreateError( + ast_->ComputeSourceLocation(expr.id()), + absl::StrCat("unsupported constant type: ", + constant.kind().index()))); + types_[&expr] = ErrorType(); + break; + } +} + +bool IsSupportedKeyType(const Type& type) { + switch (type.kind()) { + case TypeKind::kBool: + case TypeKind::kInt: + case TypeKind::kUint: + case TypeKind::kString: + case TypeKind::kDyn: + return true; + default: + return false; + } +} + +void ResolveVisitor::PostVisitMap(const Expr& expr, const MapExpr& map) { + // Roughly follows map type inferencing behavior in Go. + // + // We try to infer the type of the map if all of the keys or values are + // homogeneously typed, otherwise assume the type parameter is dyn (defer to + // runtime for enforcing type compatibility). + // + // TODO(uncreated-issue/72): Widening behavior is not well documented for map / list + // construction in the spec and is a bit inconsistent between implementations. + // + // In the future, we should probably default enforce homogeneously + // typed maps unless tagged as JSON (and the values are assignable to + // the JSON value union type). + + Type overall_key_type = + inference_context_->InstantiateTypeParams(TypeParamType("K")); + Type overall_value_type = + inference_context_->InstantiateTypeParams(TypeParamType("V")); + + auto assignability_context = inference_context_->CreateAssignabilityContext(); + for (const auto& entry : map.entries()) { + const Expr* key = &entry.key(); + Type key_type = GetDeducedType(key); + if (!IsSupportedKeyType(key_type)) { + // The Go type checker implementation can allow any type as a map key, but + // per the spec this should be limited to the types listed in + // IsSupportedKeyType. + // + // To match the Go implementation, we just warn here, but in the future + // we should consider making this an error. + ReportIssue(TypeCheckIssue( + Severity::kWarning, ast_->ComputeSourceLocation(key->id()), + absl::StrCat( + "unsupported map key type: ", + FormatTypeName(inference_context_->FinalizeType(key_type))))); + } + + if (!assignability_context.IsAssignable(key_type, overall_key_type)) { + overall_key_type = DynType(); + } + } + + if (!overall_key_type.IsDyn()) { + assignability_context.UpdateInferredTypeAssignments(); + } + + assignability_context.Reset(); + for (const auto& entry : map.entries()) { + const Expr* value = &entry.value(); + Type value_type = GetDeducedType(value); + if (entry.optional()) { + if (value_type.IsOptional()) { + value_type = value_type.GetOptional().GetParameter(); + } else { + ReportTypeMismatch(entry.value().id(), OptionalType(arena_, value_type), + value_type); + continue; + } + } + if (!inference_context_->IsAssignable(value_type, overall_value_type)) { + overall_value_type = DynType(); + } + } + + if (!overall_value_type.IsDyn()) { + assignability_context.UpdateInferredTypeAssignments(); + } + + types_[&expr] = inference_context_->FullySubstitute( + MapType(arena_, overall_key_type, overall_value_type)); +} + +void ResolveVisitor::PostVisitList(const Expr& expr, const ListExpr& list) { + if (&expr == block_init_list_) { + // Don't try to coalesce list type here because it can influence the + // resolved type of the list elements. cel.@block is always list and + // the elements are treated independently at runtime. + types_[&expr] = ListType(); + return; + } + + // Follows list type inferencing behavior in Go (see map comments above). + Type overall_elem_type = + inference_context_->InstantiateTypeParams(TypeParamType("E")); + auto assignability_context = inference_context_->CreateAssignabilityContext(); + for (const auto& element : list.elements()) { + const Expr* value = &element.expr(); + Type value_type = GetDeducedType(value); + if (element.optional()) { + if (value_type.IsOptional()) { + value_type = value_type.GetOptional().GetParameter(); + } else { + ReportTypeMismatch(element.expr().id(), + OptionalType(arena_, value_type), value_type); + continue; + } + } + + if (!assignability_context.IsAssignable(value_type, overall_elem_type)) { + overall_elem_type = DynType(); + } + } + + if (!overall_elem_type.IsDyn()) { + assignability_context.UpdateInferredTypeAssignments(); + } + + types_[&expr] = + inference_context_->FullySubstitute(ListType(arena_, overall_elem_type)); +} + +void ResolveVisitor::PostVisitStruct(const Expr& expr, + const StructExpr& create_struct) { + absl::Status status; + std::string resolved_name; + Type resolved_type; + namespace_generator_.GenerateCandidates( + create_struct.name(), [&](const absl::string_view name) { + auto type = env_->LookupTypeName(name); + if (!type.ok()) { + status.Update(type.status()); + return false; + } else if (type->has_value()) { + resolved_name = name; + resolved_type = **type; + return false; + } + return true; + }); + + if (!status.ok()) { + status_.Update(status); + return; + } + + if (resolved_name.empty()) { + ReportMissingReference(expr, create_struct.name()); + types_[&expr] = ErrorType(); + return; + } + + if (resolved_type.kind() != TypeKind::kStruct && + !IsWellKnownMessageType(resolved_name)) { + ReportIssue(TypeCheckIssue::CreateError( + ast_->ComputeSourceLocation(expr.id()), + absl::StrCat("type '", resolved_name, + "' does not support message creation"))); + types_[&expr] = ErrorType(); + return; + } + + types_[&expr] = resolved_type; + struct_types_[&expr] = resolved_name; + + status_.Update( + CheckFieldAssignments(expr, create_struct, resolved_type, resolved_name)); +} + +void ResolveVisitor::PostVisitCall(const Expr& expr, const CallExpr& call) { + if (call.function() == kOptionalSelect) { + HandleOptSelect(expr); + return; + } + // Handle disambiguation of namespaced functions. + if (auto iter = maybe_namespaced_functions_.find(&expr); + iter != maybe_namespaced_functions_.end()) { + std::string namespaced_name = + absl::StrCat(FormatCandidate(iter->second), ".", call.function()); + const FunctionDecl* decl = + ResolveFunctionCallShape(expr, namespaced_name, call.args().size(), + /* is_receiver= */ false); + if (decl != nullptr) { + ResolveFunctionOverloads(expr, *decl, call.args().size(), + /* is_receiver= */ false, + /* is_namespaced= */ true); + return; + } + // Else, resolve the target as an attribute (deferred earlier), then + // resolve the function call normally. + ResolveQualifiedIdentifier(call.target(), iter->second); + } + + int arg_count = call.args().size(); + if (call.has_target()) { + ++arg_count; + } + + const FunctionDecl* decl = ResolveFunctionCallShape( + expr, call.function(), arg_count, call.has_target()); + + if (decl == nullptr) { + ReportMissingReference(expr, call.function()); + types_[&expr] = ErrorType(); + return; + } + + ResolveFunctionOverloads(expr, *decl, arg_count, call.has_target(), + /* is_namespaced= */ false); +} + +void ResolveVisitor::PreVisitComprehension( + const Expr& expr, const ComprehensionExpr& comprehension) { + std::unique_ptr accu_scope = current_scope_->MakeNestedScope(); + auto* accu_scope_ptr = accu_scope.get(); + + std::unique_ptr iter_scope = accu_scope->MakeNestedScope(); + auto* iter_scope_ptr = iter_scope.get(); + + // Keep the temporary decls alive as long as the visitor. + comprehension_vars_.push_back(std::move(accu_scope)); + comprehension_vars_.push_back(std::move(iter_scope)); + + comprehension_scopes_.push_back( + {&expr, current_scope_, accu_scope_ptr, iter_scope_ptr}); +} + +void ResolveVisitor::PostVisitComprehension( + const Expr& expr, const ComprehensionExpr& comprehension) { + comprehension_scopes_.pop_back(); + types_[&expr] = inference_context_->FullySubstitute( + GetDeducedType(&comprehension.result())); +} + +void ResolveVisitor::PreVisitComprehensionSubexpression( + const Expr& expr, const ComprehensionExpr& comprehension, + ComprehensionArg comprehension_arg) { + if (comprehension_scopes_.empty()) { + status_.Update(absl::InternalError( + "Comprehension scope stack is empty in comprehension")); + return; + } + auto& scope = comprehension_scopes_.back(); + if (scope.comprehension_expr != &expr) { + status_.Update(absl::InternalError("Comprehension scope stack broken")); + return; + } + + switch (comprehension_arg) { + case ComprehensionArg::LOOP_CONDITION: + current_scope_ = scope.accu_scope; + break; + case ComprehensionArg::LOOP_STEP: + current_scope_ = scope.iter_scope; + break; + case ComprehensionArg::RESULT: + current_scope_ = scope.accu_scope; + break; + default: + current_scope_ = scope.parent; + break; + } +} + +void ResolveVisitor::PostVisitComprehensionSubexpression( + const Expr& expr, const ComprehensionExpr& comprehension, + ComprehensionArg comprehension_arg) { + if (comprehension_scopes_.empty()) { + status_.Update(absl::InternalError( + "Comprehension scope stack is empty in comprehension")); + return; + } + auto& scope = comprehension_scopes_.back(); + if (scope.comprehension_expr != &expr) { + status_.Update(absl::InternalError("Comprehension scope stack broken")); + return; + } + current_scope_ = scope.parent; + + // Setting the type depends on the order the visitor is called -- the visitor + // guarantees iter range and accu init are visited before subexpressions where + // the corresponding variables can be referenced. + switch (comprehension_arg) { + case ComprehensionArg::ACCU_INIT: + scope.accu_scope->InsertVariableIfAbsent( + MakeVariableDecl(comprehension.accu_var(), + GetDeducedType(&comprehension.accu_init()))); + break; + case ComprehensionArg::ITER_RANGE: { + Type range_type = GetDeducedType(&comprehension.iter_range()); + Type iter_type = DynType(); // iter_var for non comprehensions v2. + Type iter_type1 = DynType(); // iter_var for comprehensions v2. + Type iter_type2 = DynType(); // iter_var2 for comprehensions v2. + switch (range_type.kind()) { + case TypeKind::kList: + iter_type1 = IntType(); + iter_type = iter_type2 = range_type.GetList().element(); + break; + case TypeKind::kMap: + iter_type = iter_type1 = range_type.GetMap().key(); + iter_type2 = range_type.GetMap().value(); + break; + case TypeKind::kDyn: + break; + default: + ReportIssue(TypeCheckIssue::CreateError( + ast_->ComputeSourceLocation(comprehension.iter_range().id()), + absl::StrCat( + "expression of type '", + FormatTypeName(inference_context_->FinalizeType(range_type)), + "' cannot be the range of a comprehension (must be " + "list, map, or dynamic)"))); + break; + } + if (comprehension.iter_var2().empty()) { + scope.iter_scope->InsertVariableIfAbsent( + MakeVariableDecl(comprehension.iter_var(), iter_type)); + } else { + scope.iter_scope->InsertVariableIfAbsent( + MakeVariableDecl(comprehension.iter_var(), iter_type1)); + scope.iter_scope->InsertVariableIfAbsent( + MakeVariableDecl(comprehension.iter_var2(), iter_type2)); + } + break; + } + default: + break; + } +} + +void ResolveVisitor::PostVisitSelect(const Expr& expr, + const SelectExpr& select) { + if (!deferred_select_operations_.contains(&expr)) { + ResolveSelectOperation(expr, select.field(), select.operand()); + } +} + +const FunctionDecl* ResolveVisitor::ResolveFunctionCallShape( + const Expr& expr, absl::string_view function_name, int arg_count, + bool is_receiver) { + const FunctionDecl* decl = nullptr; + namespace_generator_.GenerateCandidates( + function_name, [&, this](absl::string_view candidate) -> bool { + decl = env_->LookupFunction(candidate); + if (decl == nullptr) { + return true; + } + bool is_logical_op = (candidate == cel::StandardFunctions::kAnd || + candidate == cel::StandardFunctions::kOr) && + arg_count >= 2; + for (const auto& ovl : decl->overloads()) { + if (ovl.member() == is_receiver && + (ovl.args().size() == arg_count || is_logical_op)) { + return false; + } + } + // Name match, but no matching overloads. + decl = nullptr; + return true; + }); + return decl; +} + +void ResolveVisitor::ResolveFunctionOverloads(const Expr& expr, + const FunctionDecl& decl, + int arg_count, bool is_receiver, + bool is_namespaced) { + std::vector arg_types; + arg_types.reserve(arg_count); + if (is_receiver) { + arg_types.push_back(GetDeducedType(&expr.call_expr().target())); + } + for (int i = 0; i < expr.call_expr().args().size(); ++i) { + arg_types.push_back(GetDeducedType(&expr.call_expr().args()[i])); + } + + std::optional resolution = + inference_context_->ResolveOverload(decl, arg_types, is_receiver); + + if (!resolution.has_value()) { + ReportIssue(TypeCheckIssue::CreateError( + ast_->ComputeSourceLocation(expr.id()), + absl::StrCat("found no matching overload for '", decl.name(), + "' applied to '(", + absl::StrJoin(arg_types, ", ", + [](std::string* out, const Type& type) { + out->append(FormatTypeName(type)); + }), + ")'"))); + types_[&expr] = ErrorType(); + return; + } + + auto* result_decl = google::protobuf::Arena::Create(arena_); + result_decl->set_name(decl.name()); + for (const auto& ovl : resolution->overloads) { + absl::Status s = result_decl->AddOverload(ovl); + if (!s.ok()) { + // Overloads should be filtered list from the original declaration, + // so a status means an invariant was broken. + status_.Update(absl::InternalError(absl::StrCat( + "failed to add overload to resolved function declaration: ", s))); + } + } + + functions_[&expr] = {result_decl, is_namespaced}; + types_[&expr] = resolution->result_type; +} + +const VariableDecl* absl_nullable ResolveVisitor::LookupLocalIdentifier( + absl::string_view name) { + if (absl::StartsWith(name, ".")) { + // Should not happen for normally parsed CEL, but prevent lookup in case + // of hand-crafted ASTs. + return nullptr; + } + return current_scope_->LookupLocalVariable(name); +} + +const VariableDecl* absl_nullable ResolveVisitor::LookupGlobalIdentifier( + absl::string_view name) { + if (const VariableDecl* decl = env_->LookupVariable(name); decl != nullptr) { + return decl; + } + absl::StatusOr> constant = + env_->LookupTypeConstant(arena_, name); + + if (!constant.ok()) { + status_.Update(constant.status()); + return nullptr; + } + + if (constant->has_value()) { + if (constant->value().type().kind() == TypeKind::kEnum) { + // Treat enum constant as just an int after resolving the reference. + // This preserves existing behavior in the other type checkers. + constant->value().set_type(IntType()); + } + return google::protobuf::Arena::Create( + arena_, std::move(constant).value().value()); + } + + return nullptr; +} + +void ResolveVisitor::ResolveSimpleIdentifier(const Expr& expr, + absl::string_view name) { + // Local variables (comprehension, bind) are simple identifiers so we can + // skip generating the different namespace-qualified candidates. + if (!absl::StartsWith(name, ".")) { + const VariableDecl* local_decl = LookupLocalIdentifier(name); + if (local_decl != nullptr) { + attributes_[&expr] = {local_decl, /*requires_disambiguation=*/false, + /*local=*/true}; + types_[&expr] = + inference_context_->InstantiateTypeParams(local_decl->type()); + return; + } + } + + const VariableDecl* decl = nullptr; + namespace_generator_.GenerateCandidates( + name, [&decl, this](absl::string_view candidate) { + decl = LookupGlobalIdentifier(candidate); + // continue searching. + return decl == nullptr; + }); + + bool requires_disambiguation = false; + if (absl::StartsWith(name, ".")) { + requires_disambiguation = LookupLocalIdentifier(name.substr(1)) != nullptr; + } + + if (decl != nullptr) { + attributes_[&expr] = {decl, requires_disambiguation, /*local=*/false}; + types_[&expr] = inference_context_->InstantiateTypeParams(decl->type()); + return; + } + + ReportMissingReference(expr, name); + types_[&expr] = ErrorType(); +} + +void ResolveVisitor::ResolveQualifiedIdentifier( + const Expr& expr, absl::Span qualifiers) { + if (qualifiers.size() == 1) { + ResolveSimpleIdentifier(expr, qualifiers[0]); + return; + } + + int matched_segment_index = -1; + const VariableDecl* decl = nullptr; + bool requires_disambiguation = false; + bool is_local = false; + // Local variables (comprehension, bind) are simple identifiers so we can + // skip generating the different namespace-qualified candidates. + if (!absl::StartsWith(qualifiers[0], ".")) { + const VariableDecl* local_decl = LookupLocalIdentifier(qualifiers[0]); + if (local_decl != nullptr) { + decl = local_decl; + matched_segment_index = 0; + is_local = true; + goto resolve_select_trail; + } + } + + namespace_generator_.GenerateCandidates( + qualifiers, [&decl, &matched_segment_index, this]( + absl::string_view candidate, int segment_index) { + decl = LookupGlobalIdentifier(candidate); + if (decl != nullptr) { + matched_segment_index = segment_index; + return false; + } + return true; + }); + + if (decl == nullptr) { + ReportMissingReference(expr, FormatCandidate(qualifiers)); + types_[&expr] = ErrorType(); + return; + } + + if (absl::StartsWith(qualifiers[0], ".")) { + const VariableDecl* local_decl = + LookupLocalIdentifier(qualifiers[0].substr(1)); + if (local_decl != nullptr) { + requires_disambiguation = true; + } + } + +resolve_select_trail: + + const int num_select_opts = qualifiers.size() - matched_segment_index - 1; + + const Expr* root = &expr; + std::vector select_opts; + select_opts.reserve(num_select_opts); + for (int i = 0; i < num_select_opts; ++i) { + select_opts.push_back(root); + root = &root->select_expr().operand(); + } + + attributes_[root] = {decl, requires_disambiguation, is_local}; + types_[root] = inference_context_->InstantiateTypeParams(decl->type()); + + // fix-up select operations that were deferred. + for (auto iter = select_opts.rbegin(); iter != select_opts.rend(); ++iter) { + ResolveSelectOperation(**iter, (*iter)->select_expr().field(), + (*iter)->select_expr().operand()); + } +} + +std::optional ResolveVisitor::CheckFieldType(int64_t id, + const Type& operand_type, + absl::string_view field) { + if (operand_type.kind() == TypeKind::kDyn || + operand_type.kind() == TypeKind::kAny) { + return DynType(); + } + + switch (operand_type.kind()) { + case TypeKind::kStruct: { + StructType struct_type = operand_type.GetStruct(); + auto field_info = env_->LookupStructField(struct_type.name(), field); + if (!field_info.ok()) { + status_.Update(field_info.status()); + return std::nullopt; + } + if (!field_info->has_value()) { + ReportUndefinedField(id, field, struct_type.name()); + return std::nullopt; + } + auto type = field_info->value().GetType(); + if (type.kind() == TypeKind::kEnum) { + // Treat enum as just an int. + return IntType(); + } + return type; + } + + case TypeKind::kMap: { + MapType map_type = operand_type.GetMap(); + return map_type.GetValue(); + } + case TypeKind::kTypeParam: { + // If the operand is a free type variable, bind it to dyn to prevent + // an alternative type from being inferred. + if (inference_context_->IsAssignable(DynType(), operand_type)) { + return DynType(); + } + break; + } + default: + break; + } + + ReportIssue(TypeCheckIssue::CreateError( + ast_->ComputeSourceLocation(id), + absl::StrCat( + "expression of type '", + FormatTypeName(inference_context_->FinalizeType(operand_type)), + "' cannot be the operand of a select operation"))); + return std::nullopt; +} + +void ResolveVisitor::ResolveSelectOperation(const Expr& expr, + absl::string_view field, + const Expr& operand) { + const Type& operand_type = GetDeducedType(&operand); + + std::optional result_type; + int64_t id = expr.id(); + // Support short-hand optional chaining. + if (operand_type.IsOptional()) { + auto optional_type = operand_type.GetOptional(); + Type held_type = optional_type.GetParameter(); + result_type = CheckFieldType(id, held_type, field); + if (result_type.has_value()) { + result_type = OptionalType(arena_, *result_type); + } + } else { + result_type = CheckFieldType(id, operand_type, field); + } + + if (!result_type.has_value()) { + types_[&expr] = ErrorType(); + return; + } + + if (expr.select_expr().test_only()) { + types_[&expr] = BoolType(); + } else { + types_[&expr] = *result_type; + } +} + +void ResolveVisitor::HandleOptSelect(const Expr& expr) { + if (expr.call_expr().function() != kOptionalSelect || + expr.call_expr().args().size() != 2) { + status_.Update( + absl::InvalidArgumentError("Malformed optional select expression.")); + return; + } + + const Expr* operand = &expr.call_expr().args().at(0); + const Expr* field = &expr.call_expr().args().at(1); + if (!field->has_const_expr() || !field->const_expr().has_string_value()) { + status_.Update( + absl::InvalidArgumentError("Malformed optional select expression.")); + return; + } + + Type operand_type = GetDeducedType(operand); + if (operand_type.IsOptional()) { + operand_type = operand_type.GetOptional().GetParameter(); + } + + std::optional field_type = CheckFieldType( + expr.id(), operand_type, field->const_expr().string_value()); + if (!field_type.has_value()) { + types_[&expr] = ErrorType(); + return; + } + const FunctionDecl* select_decl = env_->LookupFunction(kOptionalSelect); + types_[&expr] = OptionalType(arena_, field_type.value()); + // Remove the type annotation for the field now that we've validated it as + // a valid field access instead of a string literal. + types_.erase(field); + if (select_decl != nullptr) { + functions_[&expr] = FunctionResolution{select_decl, + /*.namespace_rewrite=*/false}; + } +} + +void ResolveVisitor::HandleBlockIndex(const Expr* expr) { + ABSL_DCHECK(block_init_list_ != nullptr); + ABSL_DCHECK(block_init_list_->has_list_expr()); + const auto& elements = block_init_list_->list_expr().elements(); + int index = -1; + for (size_t i = 0; i < elements.size(); ++i) { + if (&elements[i].expr() == expr) { + index = i; + break; + } + } + if (index < 0) { + status_.Update(absl::InternalError( + "could not resolve expression as a cel.@block subexpression")); + return; + } + std::string var_name = absl::StrCat("@index", index); + + // Block is typically manually assembled from logically separate + // expressions so fix the type instead of inferring any remaining free type + // params as for normal subexpressions. + auto type = inference_context_->FinalizeType(GetDeducedType(expr)); + + VariableDecl decl = MakeVariableDecl(var_name, std::move(type)); + + // The C++ runtime requires that the indexes are topologically ordered. + // They just come into scope in order as we walk the AST so we don't need + // to do any additional work to check references to other initializers in + // an init expr. + // + // TODO(uncreated-issue/90): This is slightly inconsistent with the java + // runtime implementation which just requires the references to be acyclic. + auto* scope = + comprehension_vars_.emplace_back(current_scope_->MakeNestedScope()).get(); + scope->InsertVariableIfAbsent(std::move(decl)); + current_scope_ = scope; +} + +class ResolveRewriter : public AstRewriterBase { + public: + explicit ResolveRewriter(const ResolveVisitor& visitor, + const TypeInferenceContext& inference_context, + const CheckerOptions& options, + Ast::ReferenceMap& references, Ast::TypeMap& types, + ValidationResult::TypeMap& resolved_types) + : visitor_(visitor), + inference_context_(inference_context), + reference_map_(references), + type_map_(types), + resolved_types_(resolved_types), + options_(options) {} + bool PostVisitRewrite(Expr& expr) override { + bool rewritten = false; + if (auto iter = visitor_.attributes().find(&expr); + iter != visitor_.attributes().end()) { + const VariableDecl* decl = iter->second.decl; + auto& ast_ref = reference_map_[expr.id()]; + std::string name = decl->name(); + if (iter->second.requires_disambiguation && + !absl::StartsWith(name, ".")) { + name = absl::StrCat(".", name); + } + ast_ref.set_name(name); + if (decl->has_value()) { + ast_ref.set_value(decl->value()); + } + expr.mutable_ident_expr().set_name(std::move(name)); + rewritten = true; + } else if (auto iter = visitor_.functions().find(&expr); + iter != visitor_.functions().end()) { + const FunctionDecl* decl = iter->second.decl; + const bool needs_rewrite = iter->second.namespace_rewrite; + auto& ast_ref = reference_map_[expr.id()]; + if (options_.enable_function_name_in_reference) { + ast_ref.set_name(decl->name()); + } + for (const auto& overload : decl->overloads()) { + ast_ref.mutable_overload_id().push_back(overload.id()); + } + expr.mutable_call_expr().set_function(decl->name()); + if (needs_rewrite && expr.call_expr().has_target()) { + expr.mutable_call_expr().set_target(nullptr); + } + rewritten = true; + } else if (auto iter = visitor_.struct_types().find(&expr); + iter != visitor_.struct_types().end()) { + auto& ast_ref = reference_map_[expr.id()]; + ast_ref.set_name(iter->second); + if (expr.has_struct_expr() && options_.update_struct_type_names) { + expr.mutable_struct_expr().set_name(iter->second); + } + rewritten = true; + } + + if (auto iter = visitor_.types().find(&expr); + iter != visitor_.types().end()) { + cel::Type finalized_type = inference_context_.FinalizeType(iter->second); + auto flattened_type = FlattenType(finalized_type); + + if (!flattened_type.ok()) { + status_.Update(flattened_type.status()); + return rewritten; + } + type_map_[expr.id()] = *std::move(flattened_type); + resolved_types_[expr.id()] = finalized_type; + rewritten = true; + } + + return rewritten; + } + + const absl::Status& status() const { return status_; } + + private: + absl::Status status_; + const ResolveVisitor& visitor_; + const TypeInferenceContext& inference_context_; + Ast::ReferenceMap& reference_map_; + Ast::TypeMap& type_map_; + ValidationResult::TypeMap& resolved_types_; + const CheckerOptions& options_; +}; + +} // namespace + +absl::StatusOr TypeCheckerImpl::CheckImpl( + std::unique_ptr ast, google::protobuf::Arena* arena) const { + std::optional type_arena; + if (arena == nullptr) { + type_arena.emplace(); + arena = &(*type_arena); + } + + std::vector issues; + CEL_ASSIGN_OR_RETURN(auto generator, + NamespaceGenerator::Create(env_.container())); + + TypeInferenceContext type_inference_context( + arena, options_.enable_legacy_null_assignment); + ResolveVisitor visitor(std::move(generator), env_, *ast, + type_inference_context, issues, arena); + + TraversalOptions opts; + opts.use_comprehension_callbacks = true; + bool error_limit_reached = false; + auto traversal = AstTraversal::Create(ast->root_expr(), opts); + + for (int step = 0; step < options_.max_expression_node_count * 2; ++step) { + bool has_next = traversal.Step(visitor); + if (!visitor.status().ok()) { + return visitor.status(); + } + if (visitor.error_count() > options_.max_error_issues) { + error_limit_reached = true; + break; + } + if (!has_next) { + break; + } + } + + if (!traversal.IsDone() && !error_limit_reached) { + return absl::InvalidArgumentError( + absl::StrCat("Maximum expression node count exceeded: ", + options_.max_expression_node_count)); + } + + if (error_limit_reached) { + issues.push_back(TypeCheckIssue::CreateError( + {}, absl::StrCat("maximum number of ERROR issues exceeded: ", + options_.max_error_issues))); + } else if (env_.expected_type().has_value()) { + visitor.AssertExpectedType(ast->root_expr(), *env_.expected_type()); + } + + // If any issues are errors, return without an AST. + for (const auto& issue : issues) { + if (issue.severity() == Severity::kError) { + return ValidationResult(std::move(issues)); + } + } + + // Apply updates as needed. + // Happens in a second pass to simplify validating that pointers haven't + // been invalidated by other updates. + ValidationResult::TypeMap resolved_types; + ResolveRewriter rewriter(visitor, type_inference_context, options_, + ast->mutable_reference_map(), + ast->mutable_type_map(), resolved_types); + AstRewrite(ast->mutable_root_expr(), rewriter); + + CEL_RETURN_IF_ERROR(rewriter.status()); + + ast->set_is_checked(true); + if (options_.use_json_field_names) { + ast->mutable_source_info().mutable_extensions().push_back( + cel::ExtensionSpec("json_name", + std::make_unique(1, 1), + {cel::ExtensionSpec::Component::kRuntime})); + } + + auto result = ValidationResult(std::move(ast), std::move(issues)); + if (!type_arena.has_value()) { + // cel::Type values will expire after this function returns when the local + // arena is destructed. Only set the resolved type map if we're using the + // caller's arena. + result.SetResolvedTypeMap(std::move(resolved_types)); + } + + return result; +} + +std::unique_ptr TypeCheckerImpl::ToBuilder() const { + return std::make_unique(options_, env_); +} + +} // namespace cel::checker_internal diff --git a/checker/internal/type_checker_impl.h b/checker/internal/type_checker_impl.h new file mode 100644 index 000000000..9ee9a50d0 --- /dev/null +++ b/checker/internal/type_checker_impl.h @@ -0,0 +1,58 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_TYPE_CHECKER_IMPL_H_ +#define THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_TYPE_CHECKER_IMPL_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "checker/checker_options.h" +#include "checker/internal/type_check_env.h" +#include "checker/type_checker.h" +#include "checker/type_checker_builder.h" +#include "checker/validation_result.h" +#include "common/ast.h" +#include "google/protobuf/arena.h" + +namespace cel::checker_internal { + +// Implementation of the TypeChecker interface. +// +// See cel::TypeCheckerBuilder for constructing instances. +class TypeCheckerImpl : public TypeChecker { + public: + explicit TypeCheckerImpl(TypeCheckEnv env, CheckerOptions options = {}) + : env_(std::move(env)), options_(options) {} + + TypeCheckerImpl(const TypeCheckerImpl&) = delete; + TypeCheckerImpl& operator=(const TypeCheckerImpl&) = delete; + TypeCheckerImpl(TypeCheckerImpl&&) = delete; + TypeCheckerImpl& operator=(TypeCheckerImpl&&) = delete; + + absl::StatusOr CheckImpl( + std::unique_ptr ast, google::protobuf::Arena* arena) const override; + + std::unique_ptr ToBuilder() const override; + + private: + TypeCheckEnv env_; + google::protobuf::Arena type_arena_; + CheckerOptions options_; +}; + +} // namespace cel::checker_internal + +#endif // THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_TYPE_CHECKER_IMPL_H_ diff --git a/checker/internal/type_checker_impl_test.cc b/checker/internal/type_checker_impl_test.cc new file mode 100644 index 000000000..61ef7d55b --- /dev/null +++ b/checker/internal/type_checker_impl_test.cc @@ -0,0 +1,2821 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/internal/type_checker_impl.h" + +#include +#include +#include +#include +#include + +#include "absl/base/no_destructor.h" +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "checker/checker_options.h" +#include "checker/internal/test_ast_helpers.h" +#include "checker/internal/type_check_env.h" +#include "checker/type_check_issue.h" +#include "checker/type_checker_builder.h" +#include "checker/validation_result.h" +#include "common/ast.h" +#include "common/ast_proto.h" +#include "common/container.h" +#include "common/decl.h" +#include "common/expr.h" +#include "common/source.h" +#include "common/type.h" +#include "common/type_introspector.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "parser/macro_registry.h" +#include "parser/parser.h" +#include "testutil/baseline_tests.h" +#include "testutil/test_macros.h" +#include "cel/expr/conformance/proto2/test_all_types.pb.h" +#include "cel/expr/conformance/proto3/test_all_types.pb.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/message.h" +#include "google/protobuf/text_format.h" + +namespace cel { +namespace checker_internal { + +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::cel::Reference; +using ::cel::expr::conformance::proto3::TestAllTypes; +using ::cel::internal::GetSharedTestingDescriptorPool; +using ::testing::_; +using ::testing::Contains; +using ::testing::ElementsAre; +using ::testing::Eq; +using ::testing::HasSubstr; +using ::testing::IsEmpty; +using ::testing::Not; +using ::testing::Pair; +using ::testing::Property; +using ::testing::SizeIs; + +using AstType = cel::TypeSpec; +using Severity = TypeCheckIssue::Severity; + +namespace testpb3 = ::cel::expr::conformance::proto3; +namespace testpb2 = ::cel::expr::conformance::proto2; + +std::string SevString(Severity severity) { + switch (severity) { + case Severity::kDeprecated: + return "Deprecated"; + case Severity::kError: + return "Error"; + case Severity::kWarning: + return "Warning"; + case Severity::kInformation: + return "Information"; + } +} + +} // namespace +} // namespace checker_internal + +template +void AbslStringify(Sink& sink, const TypeCheckIssue& issue) { + absl::Format(&sink, "TypeCheckIssue(%s): %s", + checker_internal::SevString(issue.severity()), issue.message()); +} + +namespace checker_internal { +namespace { + +google::protobuf::Arena* absl_nonnull TestTypeArena() { + static absl::NoDestructor kArena; + return &(*kArena); +} + +absl::StatusOr> MakeTestParsedAstWithMacros( + absl::string_view expression, const cel::MacroRegistry& registry) { + CEL_ASSIGN_OR_RETURN( + auto source, + cel::NewSource(expression, /*description=*/std::string(expression))); + CEL_ASSIGN_OR_RETURN(auto parsed_expr, google::api::expr::parser::Parse( + *source, registry, + {.enable_optional_syntax = true})); + return cel::CreateAstFromParsedExpr(parsed_expr); +} + +FunctionDecl MakeIdentFunction() { + auto decl = MakeFunctionDecl( + "identity", + MakeOverloadDecl("identity", TypeParamType("A"), TypeParamType("A"))); + ABSL_CHECK_OK(decl.status()); + return decl.value(); +} + +MATCHER_P2(IsIssueWithSubstring, severity, substring, "") { + const TypeCheckIssue& issue = arg; + if (issue.severity() == severity && + absl::StrContains(issue.message(), substring)) { + return true; + } + + *result_listener << "expected: " << SevString(severity) << " " << substring + << "\nactual: " << SevString(issue.severity()) << " " + << issue.message(); + + return false; +} + +MATCHER_P(IsVariableReference, var_name, "") { + const Reference& reference = arg; + if (reference.name() == var_name) { + return true; + } + *result_listener << "expected: " << var_name + << "\nactual: " << reference.name(); + + return false; +} + +MATCHER_P2(IsFunctionReference, fn_name, overloads, "") { + const Reference& reference = arg; + + absl::flat_hash_set got_overload_set( + reference.overload_id().begin(), reference.overload_id().end()); + absl::flat_hash_set want_overload_set(overloads.begin(), + overloads.end()); + + if (got_overload_set != want_overload_set) { + *result_listener << "reference to " << fn_name << "\n" + << "expected overload_ids: " + << absl::StrJoin(want_overload_set, ",") + << "\nactual: " << absl::StrJoin(got_overload_set, ","); + } + + return got_overload_set == want_overload_set; +} + +absl::Status RegisterMinimalBuiltins(google::protobuf::Arena* absl_nonnull arena, + TypeCheckEnv& env) { + Type list_of_a = ListType(arena, TypeParamType("A")); + + FunctionDecl add_op; + + add_op.set_name("_+_"); + CEL_RETURN_IF_ERROR(add_op.AddOverload( + MakeOverloadDecl("add_int_int", IntType(), IntType(), IntType()))); + CEL_RETURN_IF_ERROR(add_op.AddOverload( + MakeOverloadDecl("add_uint_uint", UintType(), UintType(), UintType()))); + CEL_RETURN_IF_ERROR(add_op.AddOverload(MakeOverloadDecl( + "add_double_double", DoubleType(), DoubleType(), DoubleType()))); + + CEL_RETURN_IF_ERROR(add_op.AddOverload( + MakeOverloadDecl("add_list", list_of_a, list_of_a, list_of_a))); + + FunctionDecl not_op; + not_op.set_name("!_"); + CEL_RETURN_IF_ERROR(not_op.AddOverload( + MakeOverloadDecl("logical_not", + /*return_type=*/BoolType{}, BoolType{}))); + FunctionDecl not_strictly_false; + not_strictly_false.set_name("@not_strictly_false"); + CEL_RETURN_IF_ERROR(not_strictly_false.AddOverload( + MakeOverloadDecl("not_strictly_false", + /*return_type=*/BoolType{}, DynType{}))); + FunctionDecl mult_op; + mult_op.set_name("_*_"); + CEL_RETURN_IF_ERROR(mult_op.AddOverload( + MakeOverloadDecl("mult_int_int", + /*return_type=*/IntType(), IntType(), IntType()))); + FunctionDecl or_op; + or_op.set_name("_||_"); + CEL_RETURN_IF_ERROR(or_op.AddOverload( + MakeOverloadDecl("logical_or", + /*return_type=*/BoolType{}, BoolType{}, BoolType{}))); + + FunctionDecl and_op; + and_op.set_name("_&&_"); + CEL_RETURN_IF_ERROR(and_op.AddOverload( + MakeOverloadDecl("logical_and", + /*return_type=*/BoolType{}, BoolType{}, BoolType{}))); + + FunctionDecl lt_op; + lt_op.set_name("_<_"); + CEL_RETURN_IF_ERROR(lt_op.AddOverload( + MakeOverloadDecl("lt_int_int", + /*return_type=*/BoolType{}, IntType(), IntType()))); + + FunctionDecl gt_op; + gt_op.set_name("_>_"); + CEL_RETURN_IF_ERROR(gt_op.AddOverload( + MakeOverloadDecl("gt_int_int", + /*return_type=*/BoolType{}, IntType(), IntType()))); + + FunctionDecl eq_op; + eq_op.set_name("_==_"); + CEL_RETURN_IF_ERROR(eq_op.AddOverload(MakeOverloadDecl( + "equals", + /*return_type=*/BoolType{}, TypeParamType("A"), TypeParamType("A")))); + + FunctionDecl ne_op; + ne_op.set_name("_!=_"); + CEL_RETURN_IF_ERROR(ne_op.AddOverload(MakeOverloadDecl( + "not_equals", + /*return_type=*/BoolType{}, TypeParamType("A"), TypeParamType("A")))); + + FunctionDecl ternary_op; + ternary_op.set_name("_?_:_"); + CEL_RETURN_IF_ERROR(ternary_op.AddOverload(MakeOverloadDecl( + "conditional", + /*return_type=*/ + TypeParamType("A"), BoolType{}, TypeParamType("A"), TypeParamType("A")))); + + FunctionDecl index_op; + index_op.set_name("_[_]"); + CEL_RETURN_IF_ERROR(index_op.AddOverload(MakeOverloadDecl( + "index", + /*return_type=*/ + TypeParamType("A"), ListType(arena, TypeParamType("A")), IntType()))); + + FunctionDecl to_int; + to_int.set_name("int"); + CEL_RETURN_IF_ERROR(to_int.AddOverload( + MakeOverloadDecl("to_int", + /*return_type=*/IntType(), DynType()))); + + FunctionDecl to_duration; + to_duration.set_name("duration"); + CEL_RETURN_IF_ERROR(to_duration.AddOverload( + MakeOverloadDecl("to_duration", + /*return_type=*/DurationType(), StringType()))); + + FunctionDecl to_timestamp; + to_timestamp.set_name("timestamp"); + CEL_RETURN_IF_ERROR(to_timestamp.AddOverload( + MakeOverloadDecl("to_timestamp", + /*return_type=*/TimestampType(), IntType()))); + + FunctionDecl to_dyn; + to_dyn.set_name("dyn"); + CEL_RETURN_IF_ERROR(to_dyn.AddOverload( + MakeOverloadDecl("to_dyn", + /*return_type=*/DynType(), TypeParamType("A")))); + + FunctionDecl to_type; + to_type.set_name("type"); + CEL_RETURN_IF_ERROR(to_type.AddOverload( + MakeOverloadDecl("to_type", + /*return_type=*/TypeType(arena, TypeParamType("A")), + TypeParamType("A")))); + + Type kParam(TypeParamType("T")); + CEL_ASSIGN_OR_RETURN( + auto block_decl, + MakeFunctionDecl("cel.@block", MakeOverloadDecl("cel_block_list", kParam, + ListType(), kParam))); + + env.InsertFunctionIfAbsent(std::move(not_op)); + env.InsertFunctionIfAbsent(std::move(not_strictly_false)); + env.InsertFunctionIfAbsent(std::move(add_op)); + env.InsertFunctionIfAbsent(std::move(mult_op)); + env.InsertFunctionIfAbsent(std::move(or_op)); + env.InsertFunctionIfAbsent(std::move(and_op)); + env.InsertFunctionIfAbsent(std::move(lt_op)); + env.InsertFunctionIfAbsent(std::move(gt_op)); + env.InsertFunctionIfAbsent(std::move(to_int)); + env.InsertFunctionIfAbsent(std::move(eq_op)); + env.InsertFunctionIfAbsent(std::move(ne_op)); + env.InsertFunctionIfAbsent(std::move(ternary_op)); + env.InsertFunctionIfAbsent(std::move(index_op)); + env.InsertFunctionIfAbsent(std::move(to_dyn)); + env.InsertFunctionIfAbsent(std::move(to_type)); + env.InsertFunctionIfAbsent(std::move(to_duration)); + env.InsertFunctionIfAbsent(std::move(to_timestamp)); + env.InsertFunctionIfAbsent(std::move(block_decl)); + + return absl::OkStatus(); +} + +TEST(TypeCheckerImplTest, SmokeTest) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("1 + 2")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_TRUE(result.IsValid()); + + EXPECT_THAT(result.GetIssues(), IsEmpty()); +} + +TEST(TypeCheckerImplTest, BlockMacroSupport) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + MacroRegistry registry; + ASSERT_THAT(cel::test::RegisterTestMacros(registry), IsOk()); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN( + auto ast, + MakeTestParsedAstWithMacros( + "cel.block([1, 2], cel.index(0) + cel.index(1))", registry)); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_TRUE(result.IsValid()); + EXPECT_THAT(result.GetIssues(), IsEmpty()); + + // Overall type should be int. + ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); + auto root_id = checked_ast->root_expr().id(); + EXPECT_EQ(checked_ast->type_map().at(root_id).primitive(), + PrimitiveType::kInt64); +} + +TEST(TypeCheckerImplTest, BlockMacroSupportMixedTypes) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + MacroRegistry registry; + ASSERT_THAT(cel::test::RegisterTestMacros(registry), IsOk()); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN( + auto ast, MakeTestParsedAstWithMacros("cel.block([1, 'a'], cel.index(1))", + registry)); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_TRUE(result.IsValid()); + EXPECT_THAT(result.GetIssues(), IsEmpty()); + + // cel.index(1) refers to 'a' which is string. + // So overall type should be string. + ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); + auto root_id = checked_ast->root_expr().id(); + EXPECT_EQ(checked_ast->type_map().at(root_id).primitive(), + PrimitiveType::kString); +} + +TEST(TypeCheckerImplTest, BadIndex) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + MacroRegistry registry; + ASSERT_THAT(cel::test::RegisterTestMacros(registry), IsOk()); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN( + auto ast, MakeTestParsedAstWithMacros("cel.block([1, 'a'], cel.index(2))", + registry)); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.FormatError(), + HasSubstr("undeclared reference to '@index2' (in container")); +} + +TEST(TypeCheckerImplTest, SimpleIdentsResolved) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + env.InsertVariableIfAbsent(MakeVariableDecl("x", IntType())); + env.InsertVariableIfAbsent(MakeVariableDecl("y", IntType())); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("x + y")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_TRUE(result.IsValid()); + + EXPECT_THAT(result.GetIssues(), IsEmpty()); +} + +TEST(TypeCheckerImplTest, ReportMissingIdentDecl) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + env.InsertVariableIfAbsent(MakeVariableDecl("x", IntType())); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("x + y")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_FALSE(result.IsValid()); + + EXPECT_THAT(result.GetIssues(), + ElementsAre(IsIssueWithSubstring(Severity::kError, + "undeclared reference to 'y'"))); +} + +TEST(TypeCheckerImplTest, ErrorLimitInclusive) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + CheckerOptions options; + options.max_error_issues = 1; + + TypeCheckerImpl impl(std::move(env), options); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("1 + y")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.GetIssues(), + ElementsAre(IsIssueWithSubstring(Severity::kError, + "undeclared reference to 'y'"))); + ASSERT_OK_AND_ASSIGN(ast, MakeTestParsedAst("x + y + z")); + ASSERT_OK_AND_ASSIGN(result, impl.Check(std::move(ast))); + + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT( + result.GetIssues(), + ElementsAre( + IsIssueWithSubstring(Severity::kError, "undeclared reference to 'x'"), + IsIssueWithSubstring(Severity::kError, "undeclared reference to 'y'"), + IsIssueWithSubstring(Severity::kError, + "maximum number of ERROR issues exceeded: 1"))); +} + +MATCHER_P3(IsIssueWithLocation, line, column, message, "") { + const TypeCheckIssue& issue = arg; + if (issue.location().line == line && issue.location().column == column && + absl::StrContains(issue.message(), message)) { + return true; + } + return false; +} + +TEST(TypeCheckerImplTest, LocationCalculation) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + env.InsertVariableIfAbsent(MakeVariableDecl("x", IntType())); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto source, NewSource("a ||\n" + "b ||\n" + " c ||\n" + " d")); + ASSERT_OK_AND_ASSIGN(auto ast, + MakeTestParsedAst(source->content().ToString())); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_FALSE(result.IsValid()); + + EXPECT_THAT( + result.GetIssues(), + ElementsAre(IsIssueWithLocation(1, 0, "undeclared reference to 'a'"), + IsIssueWithLocation(2, 0, "undeclared reference to 'b'"), + IsIssueWithLocation(3, 1, "undeclared reference to 'c'"), + IsIssueWithLocation(4, 1, "undeclared reference to 'd'"))) + << absl::StrJoin(result.GetIssues(), "\n", + [&](std::string* out, const TypeCheckIssue& issue) { + absl::StrAppend(out, issue.ToDisplayString(*source)); + }); +} + +TEST(TypeCheckerImplTest, QualifiedIdentsResolved) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + env.InsertVariableIfAbsent(MakeVariableDecl("x.y", IntType())); + env.InsertVariableIfAbsent(MakeVariableDecl("x.z", IntType())); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("x.y + x.z")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_TRUE(result.IsValid()); + + EXPECT_THAT(result.GetIssues(), IsEmpty()); +} + +TEST(TypeCheckerImplTest, ReportMissingQualifiedIdentDecl) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + env.InsertVariableIfAbsent(MakeVariableDecl("x", IntType())); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("y.x")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_FALSE(result.IsValid()); + + EXPECT_THAT(result.GetIssues(), + ElementsAre(IsIssueWithSubstring( + Severity::kError, "undeclared reference to 'y.x'"))); +} + +TEST(TypeCheckerImplTest, ResolveMostQualfiedIdent) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + env.InsertVariableIfAbsent(MakeVariableDecl("x", IntType())); + env.InsertVariableIfAbsent(MakeVariableDecl("x.y", MapType())); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("x.y.z")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); + EXPECT_THAT(checked_ast->reference_map(), + Contains(Pair(_, IsVariableReference("x.y")))); +} + +TEST(TypeCheckerImplTest, MemberFunctionCallResolved) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + + env.InsertVariableIfAbsent(MakeVariableDecl("x", IntType())); + + env.InsertVariableIfAbsent(MakeVariableDecl("y", IntType())); + FunctionDecl foo; + foo.set_name("foo"); + ASSERT_THAT(foo.AddOverload(MakeMemberOverloadDecl("int_foo_int", + /*return_type=*/IntType(), + IntType(), IntType())), + IsOk()); + env.InsertFunctionIfAbsent(std::move(foo)); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("x.foo(y)")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_TRUE(result.IsValid()); + + EXPECT_THAT(result.GetIssues(), IsEmpty()); +} + +TEST(TypeCheckerImplTest, MemberFunctionCallNotDeclared) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + + env.InsertVariableIfAbsent(MakeVariableDecl("x", IntType())); + env.InsertVariableIfAbsent(MakeVariableDecl("y", IntType())); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("x.foo(y)")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_FALSE(result.IsValid()); + + EXPECT_THAT(result.GetIssues(), + ElementsAre(IsIssueWithSubstring( + Severity::kError, "undeclared reference to 'foo'"))); +} + +TEST(TypeCheckerImplTest, FunctionShapeMismatch) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + // foo(int, int) -> int + ASSERT_OK_AND_ASSIGN( + auto foo, + MakeFunctionDecl("foo", MakeOverloadDecl("foo_int_int", IntType(), + IntType(), IntType()))); + env.InsertFunctionIfAbsent(foo); + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("foo(1, 2, 3)")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_FALSE(result.IsValid()); + + EXPECT_THAT(result.GetIssues(), + ElementsAre(IsIssueWithSubstring( + Severity::kError, "undeclared reference to 'foo'"))); +} + +TEST(TypeCheckerImplTest, NamespaceFunctionCallResolved) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + // Variables + env.InsertVariableIfAbsent(MakeVariableDecl("x", IntType())); + env.InsertVariableIfAbsent(MakeVariableDecl("y", IntType())); + + // add x.foo as a namespaced function. + FunctionDecl foo; + foo.set_name("x.foo"); + ASSERT_THAT( + foo.AddOverload(MakeOverloadDecl("x_foo_int", + /*return_type=*/IntType(), IntType())), + IsOk()); + env.InsertFunctionIfAbsent(std::move(foo)); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("x.foo(y)")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_TRUE(result.IsValid()); + EXPECT_THAT(result.GetIssues(), IsEmpty()); + + ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); + EXPECT_TRUE(checked_ast->root_expr().has_call_expr()) + << absl::StrCat("kind: ", checked_ast->root_expr().kind().index()); + EXPECT_EQ(checked_ast->root_expr().call_expr().function(), "x.foo"); + EXPECT_FALSE(checked_ast->root_expr().call_expr().has_target()); +} + +TEST(TypeCheckerImplTest, NamespacedFunctionSkipsFieldCheck) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + // Variables + env.InsertVariableIfAbsent(MakeVariableDecl("x", IntType())); + + // add x.foo as a namespaced function. + FunctionDecl foo; + foo.set_name("x.y.foo"); + ASSERT_THAT( + foo.AddOverload(MakeOverloadDecl("x_y_foo_int", + /*return_type=*/IntType(), IntType())), + IsOk()); + env.InsertFunctionIfAbsent(std::move(foo)); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("x.y.foo(x)")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_TRUE(result.IsValid()); + EXPECT_THAT(result.GetIssues(), IsEmpty()); + + ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); + EXPECT_TRUE(checked_ast->root_expr().has_call_expr()) + << absl::StrCat("kind: ", checked_ast->root_expr().kind().index()); + EXPECT_EQ(checked_ast->root_expr().call_expr().function(), "x.y.foo"); + EXPECT_FALSE(checked_ast->root_expr().call_expr().has_target()); +} + +TEST(TypeCheckerImplTest, NamespacedFunctionWithAbbreviation) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + // Variables + env.InsertVariableIfAbsent(MakeVariableDecl("x", IntType())); + + FunctionDecl foo; + foo.set_name("x.y.foo"); + ASSERT_THAT( + foo.AddOverload(MakeOverloadDecl("x_y_foo_int", + /*return_type=*/IntType(), IntType())), + IsOk()); + env.InsertFunctionIfAbsent(std::move(foo)); + env.set_container(*MakeExpressionContainer("", "x.y.foo")); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("foo(x)")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_TRUE(result.IsValid()); + EXPECT_THAT(result.GetIssues(), IsEmpty()); + + ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); + EXPECT_TRUE(checked_ast->root_expr().has_call_expr()) + << absl::StrCat("kind: ", checked_ast->root_expr().kind().index()); + EXPECT_EQ(checked_ast->root_expr().call_expr().function(), "x.y.foo"); + EXPECT_FALSE(checked_ast->root_expr().call_expr().has_target()); +} + +TEST(TypeCheckerImplTest, MixedListTypeToDyn) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("[1, 'a']")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + ASSERT_TRUE(result.IsValid()); + + EXPECT_THAT(result.GetIssues(), IsEmpty()); + EXPECT_TRUE( + result.GetAst()->type_map().at(1).list_type().elem_type().has_dyn()); +} + +TEST(TypeCheckerImplTest, FreeListTypeToDyn) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("[]")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + ASSERT_TRUE(result.IsValid()); + + EXPECT_THAT(result.GetIssues(), IsEmpty()); + EXPECT_TRUE( + result.GetAst()->type_map().at(1).list_type().elem_type().has_dyn()); +} + +TEST(TypeCheckerImplTest, FreeMapValueTypeToDyn) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("{}.field")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + ASSERT_TRUE(result.IsValid()); + + EXPECT_THAT(result.GetIssues(), IsEmpty()); + auto root_id = result.GetAst()->root_expr().id(); + EXPECT_TRUE(result.GetAst()->type_map().at(root_id).has_dyn()); +} + +TEST(TypeCheckerImplTest, FreeMapTypeToDyn) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("{}")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + ASSERT_TRUE(result.IsValid()); + + EXPECT_THAT(result.GetIssues(), IsEmpty()); + ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); + EXPECT_TRUE(checked_ast->type_map().at(1).map_type().key_type().has_dyn()); + EXPECT_TRUE(checked_ast->type_map().at(1).map_type().value_type().has_dyn()); +} + +TEST(TypeCheckerImplTest, MapTypeWithMixedKeys) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("{'a': 1, 2: 3}")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + ASSERT_TRUE(result.IsValid()); + + EXPECT_THAT(result.GetIssues(), IsEmpty()); + const auto* checked_ast = result.GetAst(); + EXPECT_TRUE(checked_ast->type_map().at(1).map_type().key_type().has_dyn()); + EXPECT_EQ(checked_ast->type_map().at(1).map_type().value_type().primitive(), + PrimitiveType::kInt64); +} + +TEST(TypeCheckerImplTest, MapTypeUnsupportedKeyWarns) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("{{}: 'a'}")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + ASSERT_TRUE(result.IsValid()); + + EXPECT_THAT(result.GetIssues(), + ElementsAre(IsIssueWithSubstring(Severity::kWarning, + "unsupported map key type:"))); +} + +TEST(TypeCheckerImplTest, MapTypeWithMixedValues) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("{'a': 1, 'b': '2'}")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + ASSERT_TRUE(result.IsValid()); + + EXPECT_THAT(result.GetIssues(), IsEmpty()); + ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); + EXPECT_EQ(checked_ast->type_map().at(1).map_type().key_type().primitive(), + PrimitiveType::kString); + EXPECT_TRUE(checked_ast->type_map().at(1).map_type().value_type().has_dyn()); +} + +TEST(TypeCheckerImplTest, ComprehensionVariablesResolved) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, + MakeTestParsedAst("[1, 2, 3].exists(x, x * x > 10)")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_TRUE(result.IsValid()); + + EXPECT_THAT(result.GetIssues(), IsEmpty()); +} + +TEST(TypeCheckerImplTest, MapComprehensionVariablesResolved) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, + MakeTestParsedAst("{1: 3, 2: 4}.exists(x, x == 2)")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_TRUE(result.IsValid()); + + EXPECT_THAT(result.GetIssues(), IsEmpty()); +} + +TEST(TypeCheckerImplTest, NestedComprehensions) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN( + auto ast, + MakeTestParsedAst("[1, 2].all(x, ['1', '2'].exists(y, int(y) == x))")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_TRUE(result.IsValid()); + + EXPECT_THAT(result.GetIssues(), IsEmpty()); +} + +TEST(TypeCheckerImplTest, ComprehensionVarsShadowNamespacePriorityRules) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + env.set_container(*MakeExpressionContainer("com")); + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + // Namespace compre var shadows com.x + env.InsertVariableIfAbsent(MakeVariableDecl("com.x", IntType())); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, + MakeTestParsedAst("['1', '2'].exists(x, x == '2')")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_TRUE(result.IsValid()); + + EXPECT_THAT(result.GetIssues(), IsEmpty()); + ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); + EXPECT_THAT(checked_ast->reference_map(), + Not(Contains(Pair(_, IsVariableReference("com.x"))))); +} + +TEST(TypeCheckerImplTest, ComprehensionVarsShadowsQualifiedIdent) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + env.InsertVariableIfAbsent(MakeVariableDecl("x.y", IntType())); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, + MakeTestParsedAst("[{'y': '2'}].all(x, x.y == '2')")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_TRUE(result.IsValid()); + + EXPECT_THAT(result.GetIssues(), IsEmpty()); + ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); + EXPECT_THAT(checked_ast->reference_map(), + Not(Contains(Pair(_, IsVariableReference("x.y"))))); +} + +TEST(TypeCheckerImplTest, ComprehensionVarsShadowsQualifiedIdentTypeError) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + env.InsertVariableIfAbsent(MakeVariableDecl("x.y", IntType())); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("[0].all(x, x.y == 0)")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_FALSE(result.IsValid()); + + EXPECT_THAT( + result.FormatError(), + HasSubstr("type 'int' cannot be the operand of a select operation")); +} + +TEST(TypeCheckerImplTest, ComprehensionVarsDisamgiguatesQualifiedIdent) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + env.InsertVariableIfAbsent(MakeVariableDecl("x.y", IntType())); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, + MakeTestParsedAst("[{'y': 0}].all(x, .x.y == 2)")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_TRUE(result.IsValid()); + + EXPECT_THAT(result.GetIssues(), IsEmpty()); + ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); + EXPECT_THAT(checked_ast->reference_map(), + Contains(Pair(_, IsVariableReference(".x.y")))); +} + +TEST(TypeCheckerImplTest, ComprehensionVarsDisamgiguatesQualifiedIdentMixed) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + env.InsertVariableIfAbsent(MakeVariableDecl("x.y", StringType())); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, + MakeTestParsedAst("[{'y': 0}].all(x, .x.y != x.y)")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT( + result.FormatError(), + HasSubstr("no matching overload for '_!=_' applied to '(string, int)'")); +} + +TEST(TypeCheckerImplTest, ComprehensionVarsDisamgiguatesIdent) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + env.InsertVariableIfAbsent(MakeVariableDecl("x", IntType())); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("['foo'].all(x, .x == 2)")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_TRUE(result.IsValid()); + + EXPECT_THAT(result.GetIssues(), IsEmpty()); + ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); + EXPECT_THAT(checked_ast->reference_map(), + Contains(Pair(_, IsVariableReference(".x")))); +} + +TEST(TypeCheckerImplTest, ComprehensionVarsCyclicParamAssignability) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + TypeCheckerImpl impl(std::move(env)); + // This is valid because the list construction in the transform will resolve + // to list(dyn) since candidates E1 -> E2 and list(E1) -> E2 don't agree. + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("[].map(c, [ c, [c] ])")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_TRUE(result.IsValid()); + + EXPECT_THAT(result.GetIssues(), IsEmpty()); + + // Remainder are conceptually the same, but confirm generality. + ASSERT_OK_AND_ASSIGN(ast, MakeTestParsedAst("[].map(c, [ c, [[c]] ])")); + ASSERT_OK_AND_ASSIGN(result, impl.Check(std::move(ast))); + + EXPECT_TRUE(result.IsValid()); + + ASSERT_OK_AND_ASSIGN(ast, MakeTestParsedAst("[].map(c, [ [c], [[c]] ])")); + ASSERT_OK_AND_ASSIGN(result, impl.Check(std::move(ast))); + + EXPECT_TRUE(result.IsValid()); + + ASSERT_OK_AND_ASSIGN(ast, MakeTestParsedAst("[].map(c, [ c, c ])")); + ASSERT_OK_AND_ASSIGN(result, impl.Check(std::move(ast))); + + EXPECT_TRUE(result.IsValid()); + + ASSERT_OK_AND_ASSIGN(ast, MakeTestParsedAst("[].map(c, [ [c], c ])")); + ASSERT_OK_AND_ASSIGN(result, impl.Check(std::move(ast))); + + EXPECT_TRUE(result.IsValid()); + + ASSERT_OK_AND_ASSIGN(ast, MakeTestParsedAst("[].map(c, [ [[c]], c ])")); + ASSERT_OK_AND_ASSIGN(result, impl.Check(std::move(ast))); + + EXPECT_TRUE(result.IsValid()); + + ASSERT_OK_AND_ASSIGN(ast, MakeTestParsedAst("[].map(c, [ c, type(c) ])")); + ASSERT_OK_AND_ASSIGN(result, impl.Check(std::move(ast))); + + EXPECT_TRUE(result.IsValid()); +} + +struct PrimitiveLiteralsTestCase { + std::string expr; + PrimitiveType expected_type; +}; + +class PrimitiveLiteralsTest + : public testing::TestWithParam {}; + +TEST_P(PrimitiveLiteralsTest, LiteralsTypeInferred) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + const PrimitiveLiteralsTestCase& test_case = GetParam(); + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst(test_case.expr)); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + ASSERT_TRUE(result.IsValid()); + ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); + EXPECT_EQ(checked_ast->mutable_type_map()[1].primitive(), + test_case.expected_type); +} + +INSTANTIATE_TEST_SUITE_P(PrimitiveLiteralsTests, PrimitiveLiteralsTest, + ::testing::Values( + PrimitiveLiteralsTestCase{ + .expr = "1", + .expected_type = PrimitiveType::kInt64, + }, + PrimitiveLiteralsTestCase{ + .expr = "1.0", + .expected_type = PrimitiveType::kDouble, + }, + PrimitiveLiteralsTestCase{ + .expr = "1u", + .expected_type = PrimitiveType::kUint64, + }, + PrimitiveLiteralsTestCase{ + .expr = "'string'", + .expected_type = PrimitiveType::kString, + }, + PrimitiveLiteralsTestCase{ + .expr = "b'bytes'", + .expected_type = PrimitiveType::kBytes, + }, + PrimitiveLiteralsTestCase{ + .expr = "false", + .expected_type = PrimitiveType::kBool, + })); +struct AstTypeConversionTestCase { + Type decl_type; + TypeSpec expected_type; +}; + +class AstTypeConversionTest + : public testing::TestWithParam {}; + +TEST_P(AstTypeConversionTest, TypeConversion) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + ASSERT_TRUE( + env.InsertVariableIfAbsent(MakeVariableDecl("x", GetParam().decl_type))); + const AstTypeConversionTestCase& test_case = GetParam(); + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("x")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + ASSERT_TRUE(result.IsValid()); + ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); + EXPECT_EQ(checked_ast->mutable_type_map()[1], test_case.expected_type) + << GetParam().decl_type.DebugString(); +} + +INSTANTIATE_TEST_SUITE_P( + Primitives, AstTypeConversionTest, + ::testing::Values( + AstTypeConversionTestCase{ + .decl_type = NullType(), + .expected_type = AstType(NullTypeSpec()), + }, + AstTypeConversionTestCase{ + .decl_type = DynType(), + .expected_type = AstType(DynTypeSpec()), + }, + AstTypeConversionTestCase{ + .decl_type = BoolType(), + .expected_type = AstType(PrimitiveType::kBool), + }, + AstTypeConversionTestCase{ + .decl_type = IntType(), + .expected_type = AstType(PrimitiveType::kInt64), + }, + AstTypeConversionTestCase{ + .decl_type = UintType(), + .expected_type = AstType(PrimitiveType::kUint64), + }, + AstTypeConversionTestCase{ + .decl_type = DoubleType(), + .expected_type = AstType(PrimitiveType::kDouble), + }, + AstTypeConversionTestCase{ + .decl_type = StringType(), + .expected_type = AstType(PrimitiveType::kString), + }, + AstTypeConversionTestCase{ + .decl_type = BytesType(), + .expected_type = AstType(PrimitiveType::kBytes), + }, + AstTypeConversionTestCase{ + .decl_type = TimestampType(), + .expected_type = AstType(WellKnownTypeSpec::kTimestamp), + }, + AstTypeConversionTestCase{ + .decl_type = DurationType(), + .expected_type = AstType(WellKnownTypeSpec::kDuration), + })); + +INSTANTIATE_TEST_SUITE_P( + Wrappers, AstTypeConversionTest, + ::testing::Values( + AstTypeConversionTestCase{ + .decl_type = IntWrapperType(), + .expected_type = + AstType(PrimitiveTypeWrapper(PrimitiveType::kInt64)), + }, + AstTypeConversionTestCase{ + .decl_type = UintWrapperType(), + .expected_type = + AstType(PrimitiveTypeWrapper(PrimitiveType::kUint64)), + }, + AstTypeConversionTestCase{ + .decl_type = DoubleWrapperType(), + .expected_type = + AstType(PrimitiveTypeWrapper(PrimitiveType::kDouble)), + }, + AstTypeConversionTestCase{ + .decl_type = BoolWrapperType(), + .expected_type = + AstType(PrimitiveTypeWrapper(PrimitiveType::kBool)), + }, + AstTypeConversionTestCase{ + .decl_type = StringWrapperType(), + .expected_type = + AstType(PrimitiveTypeWrapper(PrimitiveType::kString)), + }, + AstTypeConversionTestCase{ + .decl_type = BytesWrapperType(), + .expected_type = + AstType(PrimitiveTypeWrapper(PrimitiveType::kBytes)), + })); + +INSTANTIATE_TEST_SUITE_P( + ComplexTypes, AstTypeConversionTest, + ::testing::Values( + AstTypeConversionTestCase{ + .decl_type = ListType(TestTypeArena(), IntType()), + .expected_type = AstType( + ListTypeSpec(std::make_unique(PrimitiveType::kInt64))), + }, + AstTypeConversionTestCase{ + .decl_type = MapType(TestTypeArena(), IntType(), IntType()), + .expected_type = AstType( + MapTypeSpec(std::make_unique(PrimitiveType::kInt64), + std::make_unique(PrimitiveType::kInt64))), + }, + AstTypeConversionTestCase{ + .decl_type = TypeType(TestTypeArena(), IntType()), + .expected_type = + AstType(std::make_unique(PrimitiveType::kInt64)), + }, + AstTypeConversionTestCase{ + .decl_type = OpaqueType(TestTypeArena(), "tuple", + {IntType(), IntType()}), + .expected_type = AstType( + AbstractType("tuple", {AstType(PrimitiveType::kInt64), + AstType(PrimitiveType::kInt64)})), + }, + AstTypeConversionTestCase{ + .decl_type = StructType(MessageType(TestAllTypes::descriptor())), + .expected_type = AstType( + MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes"))})); + +TEST(TypeCheckerImplTest, NullLiteral) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("null")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + ASSERT_TRUE(result.IsValid()); + ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); + EXPECT_TRUE(checked_ast->mutable_type_map()[1].has_null()); +} + +TEST(TypeCheckerImplTest, ExpressionLimitInclusive) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + CheckerOptions options; + options.max_expression_node_count = 2; + TypeCheckerImpl impl(std::move(env), options); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("{}.foo")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + ASSERT_TRUE(result.IsValid()); + + ASSERT_OK_AND_ASSIGN(ast, MakeTestParsedAst("{}.foo.bar")); + EXPECT_THAT(impl.Check(std::move(ast)), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("expression node count exceeded: 2"))); +} + +TEST(TypeCheckerImplTest, ComprehensionUnsupportedRange) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + env.InsertVariableIfAbsent(MakeVariableDecl("y", IntType())); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("'abc'.all(x, y == 2)")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_FALSE(result.IsValid()); + + EXPECT_THAT(result.GetIssues(), Contains(IsIssueWithSubstring( + Severity::kError, + "expression of type 'string' cannot be " + "the range of a comprehension"))); +} + +TEST(TypeCheckerImplTest, ComprehensionDynRange) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + env.InsertVariableIfAbsent(MakeVariableDecl("range", DynType())); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("range.all(x, x == 2)")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_TRUE(result.IsValid()); + + EXPECT_THAT(result.GetIssues(), IsEmpty()); +} + +TEST(TypeCheckerImplTest, BasicOvlResolution) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + env.InsertVariableIfAbsent(MakeVariableDecl("x", DoubleType())); + env.InsertVariableIfAbsent(MakeVariableDecl("y", DoubleType())); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("x + y")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_TRUE(result.IsValid()); + + EXPECT_THAT(result.GetIssues(), IsEmpty()); + + // Assumes parser numbering: + should always be id 2. + ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); + EXPECT_THAT(checked_ast->mutable_reference_map()[2], + IsFunctionReference( + "_+_", std::vector{"add_double_double"})); +} + +TEST(TypeCheckerImplTest, OvlResolutionMultipleOverloads) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + env.InsertVariableIfAbsent(MakeVariableDecl("x", DoubleType())); + env.InsertVariableIfAbsent(MakeVariableDecl("y", DoubleType())); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("dyn(x) + dyn(y)")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_TRUE(result.IsValid()); + + EXPECT_THAT(result.GetIssues(), IsEmpty()); + + // Assumes parser numbering: + should always be id 3. + ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); + EXPECT_THAT(checked_ast->mutable_reference_map()[3], + IsFunctionReference("_+_", std::vector{ + "add_double_double", "add_int_int", + "add_list", "add_uint_uint"})); +} + +TEST(TypeCheckerImplTest, BasicFunctionResultTypeResolution) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + env.InsertVariableIfAbsent(MakeVariableDecl("x", DoubleType())); + env.InsertVariableIfAbsent(MakeVariableDecl("y", DoubleType())); + env.InsertVariableIfAbsent(MakeVariableDecl("z", DoubleType())); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("x + y + z")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_TRUE(result.IsValid()); + + EXPECT_THAT(result.GetIssues(), IsEmpty()); + + // Assumes parser numbering: + should always be id 2 and 4. + ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); + EXPECT_THAT(checked_ast->mutable_reference_map()[2], + IsFunctionReference( + "_+_", std::vector{"add_double_double"})); + EXPECT_THAT(checked_ast->mutable_reference_map()[4], + IsFunctionReference( + "_+_", std::vector{"add_double_double"})); + int64_t root_id = checked_ast->root_expr().id(); + EXPECT_EQ(checked_ast->mutable_type_map()[root_id].primitive(), + PrimitiveType::kDouble); +} + +TEST(TypeCheckerImplTest, BasicOvlResolutionNoMatch) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + env.InsertVariableIfAbsent(MakeVariableDecl("x", IntType())); + env.InsertVariableIfAbsent(MakeVariableDecl("y", StringType())); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("x + y")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_FALSE(result.IsValid()); + + EXPECT_THAT(result.GetIssues(), + Contains(IsIssueWithSubstring(Severity::kError, + "no matching overload for '_+_'" + " applied to '(int, string)'"))); +} + +TEST(TypeCheckerImplTest, ParmeterizedOvlResolutionMatch) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + env.InsertVariableIfAbsent(MakeVariableDecl("x", IntType())); + env.InsertVariableIfAbsent(MakeVariableDecl("y", StringType())); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("([x] + []) == [x]")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_TRUE(result.IsValid()); +} + +TEST(TypeCheckerImplTest, AliasedTypeVarSameType) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, + MakeTestParsedAst("[].exists(x, x == 10 || x == '10')")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT( + result.GetIssues(), + ElementsAre(IsIssueWithSubstring( + Severity::kError, "no matching overload for '_==_' applied to"))); +} + +TEST(TypeCheckerImplTest, TypeVarRange) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + google::protobuf::Arena arena; + + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + env.InsertFunctionIfAbsent(MakeIdentFunction()); + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, + MakeTestParsedAst("identity([]).exists(x, x == 10 )")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_TRUE(result.IsValid()) << absl::StrJoin(result.GetIssues(), "\n"); +} + +TEST(TypeCheckerImplTest, WellKnownTypeCreation) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + env.AddTypeProvider(std::make_unique()); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN( + auto ast, MakeTestParsedAst("google.protobuf.Int32Value{value: 10}")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); + + EXPECT_THAT( + checked_ast->type_map(), + Contains(Pair(checked_ast->root_expr().id(), + Eq(AstType(PrimitiveTypeWrapper(PrimitiveType::kInt64)))))); + EXPECT_THAT( + checked_ast->reference_map(), + Contains(Pair(checked_ast->root_expr().id(), + Property(&Reference::name, "google.protobuf.Int32Value")))); +} + +TEST(TypeCheckerImplTest, TypeInferredFromStructCreation) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + env.AddTypeProvider(std::make_unique()); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, + MakeTestParsedAst("google.protobuf.Struct{fields: {}}")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); + int64_t map_expr_id = + checked_ast->root_expr().struct_expr().fields().at(0).value().id(); + ASSERT_NE(map_expr_id, 0); + EXPECT_THAT( + checked_ast->type_map(), + Contains(Pair(map_expr_id, + Eq(AstType(MapTypeSpec( + std::make_unique(PrimitiveType::kString), + std::make_unique(DynTypeSpec()))))))); +} + +struct VariadicLogicalCheckerTestCase { + std::string expr; +}; + +class VariadicLogicalCheckerTest + : public testing::TestWithParam {}; + +TEST_P(VariadicLogicalCheckerTest, Check) { + const auto& test_case = GetParam(); + + auto builder = cel::NewParserBuilder(); + builder->GetOptions().enable_variadic_logical_operators = true; + ASSERT_OK_AND_ASSIGN(auto parser, std::move(*builder).Build()); + ASSERT_OK_AND_ASSIGN(auto source, cel::NewSource(test_case.expr)); + ASSERT_OK_AND_ASSIGN(auto parsed_ast, parser->Parse(*source)); + + google::protobuf::Arena arena; + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + TypeCheckerImpl impl(std::move(env)); + auto checker_builder = impl.ToBuilder(); + ASSERT_THAT(checker_builder->AddVariable(MakeVariableDecl("a", BoolType())), + IsOk()); + ASSERT_THAT(checker_builder->AddVariable(MakeVariableDecl("b", BoolType())), + IsOk()); + ASSERT_THAT(checker_builder->AddVariable(MakeVariableDecl("c", BoolType())), + IsOk()); + ASSERT_THAT(checker_builder->AddVariable(MakeVariableDecl("d", BoolType())), + IsOk()); + ASSERT_THAT(checker_builder->AddVariable(MakeVariableDecl("e", BoolType())), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto checker, checker_builder->Build()); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + checker->Check(std::move(parsed_ast))); + + ASSERT_TRUE(result.IsValid()) + << absl::StrJoin(result.GetIssues(), "\n", + [](std::string* out, const TypeCheckIssue& issue) { + absl::StrAppend(out, issue.message()); + }); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); + EXPECT_THAT(checked_ast->type_map(), + Contains(Pair(checked_ast->root_expr().id(), + Eq(AstType(PrimitiveType::kBool))))); +} + +INSTANTIATE_TEST_SUITE_P( + VariadicLogicalChecker, VariadicLogicalCheckerTest, + testing::Values(VariadicLogicalCheckerTestCase{"true && false && true"}, + VariadicLogicalCheckerTestCase{"a && b && c && d"}, + VariadicLogicalCheckerTestCase{"a || b || c || d"}, + VariadicLogicalCheckerTestCase{"a && b && (c || d || e)"}, + VariadicLogicalCheckerTestCase{"a && b && c"}, + VariadicLogicalCheckerTestCase{"a || b || c"}, + VariadicLogicalCheckerTestCase{"[a, b, c].exists(x, x)"}, + VariadicLogicalCheckerTestCase{"[a, b, c].all(x, x)"})); + +TEST(TypeCheckerImplTest, VariadicLogicalOperatorsError) { + cel::expr::ParsedExpr parsed_expr; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb( + expr { + call_expr { + function: "_&&_" + args { const_expr { bool_value: true } } + } + } + )pb", + &parsed_expr)); + ASSERT_OK_AND_ASSIGN(auto parsed_ast, + cel::CreateAstFromParsedExpr(parsed_expr)); + + google::protobuf::Arena arena; + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + impl.Check(std::move(parsed_ast))); + + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT( + result.GetIssues(), + Contains(IsIssueWithSubstring(Severity::kError, "undeclared reference"))); +} + +TEST(TypeCheckerImplTest, ExpectedTypeMatches) { + google::protobuf::Arena arena; + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + + env.set_expected_type(MapType(&arena, StringType(), StringType())); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("{}")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); + + EXPECT_THAT( + checked_ast->type_map(), + Contains(Pair(checked_ast->root_expr().id(), + Eq(AstType(MapTypeSpec( + std::make_unique(PrimitiveType::kString), + std::make_unique(PrimitiveType::kString))))))); +} + +TEST(TypeCheckerImplTest, ExpectedTypeDoesntMatch) { + google::protobuf::Arena arena; + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + + env.set_expected_type(MapType(&arena, StringType(), StringType())); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("{'abc': 123}")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT( + result.GetIssues(), + Contains(IsIssueWithSubstring( + Severity::kError, + "expected type 'map(string, string)' but found 'map(string, int)'"))); +} + +TEST(TypeCheckerImplTest, ToBuilder) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + TypeCheckerImpl impl(std::move(env)); + auto builder = impl.ToBuilder(); + ASSERT_THAT(builder->AddVariable(MakeVariableDecl("x", IntType())), IsOk()); + ASSERT_OK_AND_ASSIGN(auto new_checker, builder->Build()); + + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("x")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + new_checker->Check(std::move(ast))); + EXPECT_TRUE(result.IsValid()); +} + +TEST(TypeCheckerImplTest, ToBuilderPropagatesArena) { + auto arena = std::make_shared(); + + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + env.set_arena(arena); + + Type list_type = ListType(arena.get(), IntType()); + ASSERT_TRUE( + env.InsertVariableIfAbsent(MakeVariableDecl("my_list", list_type))); + + auto base_checker = std::make_unique(std::move(env)); + + std::unique_ptr builder = base_checker->ToBuilder(); + + base_checker.reset(); + arena.reset(); + + ASSERT_OK_AND_ASSIGN(auto derived_checker, builder->Build()); + + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("my_list")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + derived_checker->Check(std::move(ast))); + EXPECT_TRUE(result.IsValid()); +} + +TEST(TypeCheckerImplTest, BadSourcePosition) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("foo")); + ast->mutable_source_info().mutable_positions()[1] = -42; + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + ASSERT_OK_AND_ASSIGN(auto source, NewSource("foo")); + + EXPECT_FALSE(result.IsValid()); + ASSERT_THAT(result.GetIssues(), SizeIs(1)); + + EXPECT_EQ( + result.GetIssues()[0].ToDisplayString(*source), + "ERROR: :-1:-1: undeclared reference to 'foo' (in container '')"); +} + +// Check that the TypeChecker will fail if no type is deduced for a +// subexpression. This is meant to be a guard against failing to account for new +// types of expressions in the type checker logic. +TEST(TypeCheckerImplTest, FailsIfNoTypeDeduced) { + google::protobuf::Arena arena; + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + env.InsertVariableIfAbsent(MakeVariableDecl("a", BoolType())); + env.InsertVariableIfAbsent(MakeVariableDecl("b", BoolType())); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("a || b")); + + // Assume that an unspecified expr kind is not deducible. + Expr unspecified_expr; + unspecified_expr.set_id(3); + ast->mutable_root_expr().mutable_call_expr().mutable_args()[1] = + std::move(unspecified_expr); + + ASSERT_THAT(impl.Check(std::move(ast)), + StatusIs(absl::StatusCode::kInvalidArgument, + "Could not deduce type for expression id: 3")); +} + +TEST(TypeCheckerImplTest, BadLineOffsets) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto source, NewSource("\nfoo")); + + { + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("\nfoo")); + ast->mutable_source_info().mutable_line_offsets()[1] = 1; + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_FALSE(result.IsValid()); + ASSERT_THAT(result.GetIssues(), SizeIs(1)); + + EXPECT_EQ(result.GetIssues()[0].ToDisplayString(*source), + "ERROR: :-1:-1: undeclared reference to 'foo' (in " + "container '')"); + } + { + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("\nfoo")); + ast->mutable_source_info().mutable_line_offsets().clear(); + ast->mutable_source_info().mutable_line_offsets().push_back(-1); + ast->mutable_source_info().mutable_line_offsets().push_back(2); + + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_FALSE(result.IsValid()); + ASSERT_THAT(result.GetIssues(), SizeIs(1)); + + EXPECT_EQ(result.GetIssues()[0].ToDisplayString(*source), + "ERROR: :-1:-1: undeclared reference to 'foo' (in " + "container '')"); + } +} + +TEST(TypeCheckerImplTest, ContainerLookupForMessageCreation) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + env.set_container(*MakeExpressionContainer("google.protobuf")); + env.AddTypeProvider(std::make_unique()); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("Int32Value{value: 10}")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); + + EXPECT_THAT( + checked_ast->type_map(), + Contains(Pair(checked_ast->root_expr().id(), + Eq(AstType(PrimitiveTypeWrapper(PrimitiveType::kInt64)))))); + EXPECT_THAT( + checked_ast->reference_map(), + Contains(Pair(checked_ast->root_expr().id(), + Property(&Reference::name, "google.protobuf.Int32Value")))); +} + +TEST(TypeCheckerImplTest, ContainerLookupForMessageCreationNoRewrite) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + env.set_container(*MakeExpressionContainer("google.protobuf")); + env.AddTypeProvider(std::make_unique()); + + CheckerOptions options; + options.update_struct_type_names = false; + TypeCheckerImpl impl(std::move(env), options); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("Int32Value{value: 10}")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); + + EXPECT_THAT( + checked_ast->type_map(), + Contains(Pair(checked_ast->root_expr().id(), + Eq(AstType(PrimitiveTypeWrapper(PrimitiveType::kInt64)))))); + EXPECT_THAT( + checked_ast->reference_map(), + Contains(Pair(checked_ast->root_expr().id(), + Property(&Reference::name, "google.protobuf.Int32Value")))); + EXPECT_THAT(checked_ast->root_expr().struct_expr(), + Property(&StructExpr::name, "Int32Value")); +} + +TEST(TypeCheckerImplTest, EnumValueCopiedToReferenceMap) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + env.set_container(*MakeExpressionContainer("cel.expr.conformance.proto3")); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, + MakeTestParsedAst("TestAllTypes.NestedEnum.BAZ")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); + + auto ref_iter = + checked_ast->reference_map().find(checked_ast->root_expr().id()); + ASSERT_NE(ref_iter, checked_ast->reference_map().end()); + EXPECT_EQ(ref_iter->second.name(), + "cel.expr.conformance.proto3.TestAllTypes.NestedEnum.BAZ"); + EXPECT_EQ(ref_iter->second.value().int_value(), 2); +} + +struct CheckedExprTestCase { + std::string expr; + TypeSpec expected_result_type; + std::string error_substring; +}; + +class WktCreationTest : public testing::TestWithParam {}; + +TEST_P(WktCreationTest, MessageCreation) { + google::protobuf::Arena arena; + const CheckedExprTestCase& test_case = GetParam(); + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + env.AddTypeProvider(std::make_unique()); + env.set_container(*MakeExpressionContainer("google.protobuf")); + + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst(test_case.expr)); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + if (!test_case.error_substring.empty()) { + EXPECT_THAT(result.GetIssues(), + Contains(IsIssueWithSubstring(Severity::kError, + test_case.error_substring))); + return; + } + + ASSERT_TRUE(result.IsValid()) + << absl::StrJoin(result.GetIssues(), "\n", + [](std::string* out, const TypeCheckIssue& issue) { + absl::StrAppend(out, issue.message()); + }); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); + + EXPECT_THAT(checked_ast->type_map(), + Contains(Pair(checked_ast->root_expr().id(), + Eq(test_case.expected_result_type)))); +} + +INSTANTIATE_TEST_SUITE_P( + WellKnownTypes, WktCreationTest, + ::testing::Values( + CheckedExprTestCase{ + .expr = "google.protobuf.Int32Value{value: 10}", + .expected_result_type = + AstType(PrimitiveTypeWrapper(PrimitiveType::kInt64)), + }, + CheckedExprTestCase{ + .expr = ".google.protobuf.Int32Value{value: 10}", + .expected_result_type = + AstType(PrimitiveTypeWrapper(PrimitiveType::kInt64)), + }, + CheckedExprTestCase{ + .expr = "Int32Value{value: 10}", + .expected_result_type = + AstType(PrimitiveTypeWrapper(PrimitiveType::kInt64)), + }, + CheckedExprTestCase{ + .expr = "google.protobuf.Int32Value{value: '10'}", + .expected_result_type = AstType(), + .error_substring = "expected type of field 'value' is 'int' but " + "provided type is 'string'"}, + CheckedExprTestCase{ + .expr = "google.protobuf.Int32Value{not_a_field: '10'}", + .expected_result_type = AstType(), + .error_substring = "undefined field 'not_a_field' not found in " + "struct 'google.protobuf.Int32Value'"}, + CheckedExprTestCase{ + .expr = "NotAType{not_a_field: '10'}", + .expected_result_type = AstType(), + .error_substring = + "undeclared reference to 'NotAType' (in container " + "'google.protobuf')"}, + CheckedExprTestCase{ + .expr = ".protobuf.Int32Value{value: 10}", + .expected_result_type = AstType(), + .error_substring = + "undeclared reference to '.protobuf.Int32Value' (in container " + "'google.protobuf')"}, + CheckedExprTestCase{ + .expr = "Int32Value{value: 10}.value", + .expected_result_type = AstType(), + .error_substring = + "expression of type 'wrapper(int)' cannot be the " + "operand of a select operation"}, + CheckedExprTestCase{ + .expr = "Int64Value{value: 10}", + .expected_result_type = + AstType(PrimitiveTypeWrapper(PrimitiveType::kInt64)), + }, + CheckedExprTestCase{ + .expr = "BoolValue{value: true}", + .expected_result_type = + AstType(PrimitiveTypeWrapper(PrimitiveType::kBool)), + }, + CheckedExprTestCase{ + .expr = "UInt64Value{value: 10u}", + .expected_result_type = + AstType(PrimitiveTypeWrapper(PrimitiveType::kUint64)), + }, + CheckedExprTestCase{ + .expr = "UInt32Value{value: 10u}", + .expected_result_type = + AstType(PrimitiveTypeWrapper(PrimitiveType::kUint64)), + }, + CheckedExprTestCase{ + .expr = "FloatValue{value: 1.25}", + .expected_result_type = + AstType(PrimitiveTypeWrapper(PrimitiveType::kDouble)), + }, + CheckedExprTestCase{ + .expr = "DoubleValue{value: 1.25}", + .expected_result_type = + AstType(PrimitiveTypeWrapper(PrimitiveType::kDouble)), + }, + CheckedExprTestCase{ + .expr = "StringValue{value: 'test'}", + .expected_result_type = + AstType(PrimitiveTypeWrapper(PrimitiveType::kString)), + }, + CheckedExprTestCase{ + .expr = "BytesValue{value: b'test'}", + .expected_result_type = + AstType(PrimitiveTypeWrapper(PrimitiveType::kBytes)), + }, + CheckedExprTestCase{ + .expr = "Duration{seconds: 10, nanos: 11}", + .expected_result_type = AstType(WellKnownTypeSpec::kDuration), + }, + CheckedExprTestCase{ + .expr = "Timestamp{seconds: 10, nanos: 11}", + .expected_result_type = AstType(WellKnownTypeSpec::kTimestamp), + }, + CheckedExprTestCase{ + .expr = "Struct{fields: {'key': 'value'}}", + .expected_result_type = AstType( + MapTypeSpec(std::make_unique(PrimitiveType::kString), + std::make_unique(DynTypeSpec()))), + }, + CheckedExprTestCase{ + .expr = "ListValue{values: [1, 2, 3]}", + .expected_result_type = + AstType(ListTypeSpec(std::make_unique(DynTypeSpec()))), + }, + CheckedExprTestCase{ + .expr = R"cel( + Any{ + type_url:'type.googleapis.com/google.protobuf.Int32Value', + value: b'' + })cel", + .expected_result_type = AstType(WellKnownTypeSpec::kAny), + }, + CheckedExprTestCase{ + .expr = "Int64Value{value: 10} + 1", + .expected_result_type = AstType(PrimitiveType::kInt64), + }, + CheckedExprTestCase{ + .expr = "BoolValue{value: false} || true", + .expected_result_type = AstType(PrimitiveType::kBool), + })); + +TEST(AliasTest, ImportVariable) { + google::protobuf::Arena arena; + + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + ASSERT_OK_AND_ASSIGN(ExpressionContainer container, + MakeExpressionContainer("cel.expr.conformance", + "com.example.TestVariable1", + "com.example.TestVariable2")); + env.set_container(std::move(container)); + + google::protobuf::LinkMessageReflection(); + google::protobuf::LinkMessageReflection(); + + ASSERT_TRUE(env.InsertVariableIfAbsent( + MakeVariableDecl("com.example.TestVariable1", + MessageType(testpb3::TestAllTypes::descriptor())))); + ASSERT_TRUE(env.InsertVariableIfAbsent( + MakeVariableDecl("com.example.TestVariable2", + MessageType(testpb2::TestAllTypes::descriptor())))); + + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN( + auto ast, + MakeTestParsedAst( + "TestVariable1.single_int64 == TestVariable2.single_int64")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + ASSERT_TRUE(result.IsValid()) << result.FormatError(); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); + + ASSERT_TRUE(checked_ast->root_expr().has_call_expr()); + ASSERT_EQ(checked_ast->root_expr().call_expr().function(), "_==_"); + ASSERT_THAT(checked_ast->root_expr().call_expr().args(), SizeIs(2)); + ASSERT_EQ(checked_ast->root_expr() + .call_expr() + .args()[0] + .select_expr() + .operand() + .ident_expr() + .name(), + "com.example.TestVariable1"); + ASSERT_EQ(checked_ast->root_expr() + .call_expr() + .args()[1] + .select_expr() + .operand() + .ident_expr() + .name(), + "com.example.TestVariable2"); +} + +TEST(AliasTest, AliasToContainerResolvesMessage) { + google::protobuf::Arena arena; + + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + ExpressionContainer container; + ASSERT_THAT(container.AddAlias("pb3", "cel.expr.conformance.proto3"), IsOk()); + + env.set_container(std::move(container)); + + google::protobuf::LinkMessageReflection(); + + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, + MakeTestParsedAst("pb3.TestAllTypes{single_int64: 10}")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + ASSERT_TRUE(result.IsValid()) << result.FormatError(); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); + + EXPECT_THAT( + checked_ast->type_map(), + Contains(Pair(checked_ast->root_expr().id(), + Eq(AstType(MessageTypeSpec( + "cel.expr.conformance.proto3.TestAllTypes")))))); + + EXPECT_THAT( + checked_ast->reference_map(), + Contains(Pair(checked_ast->root_expr().id(), + Property(&Reference::name, + "cel.expr.conformance.proto3.TestAllTypes")))); + + EXPECT_EQ(checked_ast->root_expr().struct_expr().name(), + "cel.expr.conformance.proto3.TestAllTypes"); +} + +TEST(AliasTest, AliasSimpleName) { + google::protobuf::Arena arena; + + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + ExpressionContainer container; + ASSERT_THAT(container.AddAlias("foo", "bar"), IsOk()); + + env.set_container(std::move(container)); + + google::protobuf::LinkMessageReflection(); + + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + env.InsertOrReplaceVariable(MakeVariableDecl("bar", IntType())); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("foo")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + ASSERT_TRUE(result.IsValid()) << result.FormatError(); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); + + EXPECT_EQ(checked_ast->root_expr().ident_expr().name(), "bar"); +} + +TEST(AliasTest, AliasPreventsContainerResolution) { + google::protobuf::Arena arena; + + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + ASSERT_OK_AND_ASSIGN(ExpressionContainer container, + MakeExpressionContainer("cel.expr")); + ASSERT_THAT(container.AddAlias("pb3", "cel.expr.conformance.proto3"), IsOk()); + env.set_container(std::move(container)); + + ASSERT_TRUE(env.InsertVariableIfAbsent( + MakeVariableDecl("cel.expr.pb3.FooVariable", IntType()))); + + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + TypeCheckerImpl impl(std::move(env)); + + { + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("FooVariable")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT( + result.GetIssues(), + Contains(IsIssueWithSubstring( + Severity::kError, "undeclared reference to 'FooVariable'"))); + } + + { + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("pb3.FooVariable")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT( + result.GetIssues(), + Contains(IsIssueWithSubstring( + Severity::kError, "undeclared reference to 'pb3.FooVariable'"))); + } + + { + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("expr.pb3.FooVariable")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + ASSERT_TRUE(result.IsValid()) << result.FormatError(); + ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); + EXPECT_EQ(checked_ast->root_expr().ident_expr().name(), + "cel.expr.pb3.FooVariable"); + } +} + +TEST(AliasTest, AliasPreventsDisambiguation) { + // Copying behavior from cel-go and cel-java. + google::protobuf::Arena arena; + + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + ExpressionContainer container; + ASSERT_THAT(container.AddAlias("pb3", "cel.expr.conformance.proto3"), IsOk()); + env.set_container(std::move(container)); + env.InsertOrReplaceVariable(MakeVariableDecl("pb3.Foo", IntType())); + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + TypeCheckerImpl impl(std::move(env)); + + { + ASSERT_OK_AND_ASSIGN( + auto ast, MakeTestParsedAst("pb3.TestAllTypes{single_int64: 10}")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + ASSERT_TRUE(result.IsValid()) << result.FormatError(); + ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); + EXPECT_EQ(checked_ast->root_expr().struct_expr().name(), + "cel.expr.conformance.proto3.TestAllTypes"); + } + { + ASSERT_OK_AND_ASSIGN( + auto ast, MakeTestParsedAst(".pb3.TestAllTypes{single_int64: 10}")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + ASSERT_TRUE(result.IsValid()) << result.FormatError(); + ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); + EXPECT_EQ(checked_ast->root_expr().struct_expr().name(), + "cel.expr.conformance.proto3.TestAllTypes"); + } + { + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("pb3.Foo")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + ASSERT_FALSE(result.IsValid()); + EXPECT_THAT(result.GetIssues(), + Contains(IsIssueWithSubstring( + Severity::kError, "undeclared reference to 'pb3.Foo'"))); + } + { + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst(".pb3.Foo")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + ASSERT_FALSE(result.IsValid()); + EXPECT_THAT(result.GetIssues(), + Contains(IsIssueWithSubstring( + Severity::kError, "undeclared reference to '.pb3.Foo'"))); + } +} + +class GenericMessagesTest : public testing::TestWithParam { +}; + +TEST_P(GenericMessagesTest, TypeChecksProto3Imports) { + const CheckedExprTestCase& test_case = GetParam(); + google::protobuf::Arena arena; + + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + env.set_container(*MakeExpressionContainer( + "", "cel.expr.conformance.proto3.TestAllTypes", + "cel.expr.conformance.proto3.NestedTestAllTypes")); + google::protobuf::LinkMessageReflection(); + + ASSERT_TRUE(env.InsertVariableIfAbsent(MakeVariableDecl( + "test_msg", MessageType(testpb3::TestAllTypes::descriptor())))); + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst(test_case.expr)); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + if (!test_case.error_substring.empty()) { + EXPECT_THAT(result.GetIssues(), + Contains(IsIssueWithSubstring(Severity::kError, + test_case.error_substring))); + return; + } + + ASSERT_TRUE(result.IsValid()) << result.FormatError(); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); + + EXPECT_THAT(checked_ast->type_map(), + Contains(Pair(checked_ast->root_expr().id(), + Eq(test_case.expected_result_type)))) + << cel::test::FormatBaselineAst(*checked_ast); +} + +TEST_P(GenericMessagesTest, TypeChecksProto3Container) { + const CheckedExprTestCase& test_case = GetParam(); + google::protobuf::Arena arena; + + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + env.set_container(*MakeExpressionContainer("cel.expr.conformance.proto3")); + google::protobuf::LinkMessageReflection(); + + ASSERT_TRUE(env.InsertVariableIfAbsent(MakeVariableDecl( + "test_msg", MessageType(testpb3::TestAllTypes::descriptor())))); + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst(test_case.expr)); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + if (!test_case.error_substring.empty()) { + EXPECT_THAT(result.GetIssues(), + Contains(IsIssueWithSubstring(Severity::kError, + test_case.error_substring))); + return; + } + + ASSERT_TRUE(result.IsValid()) << result.FormatError(); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); + + EXPECT_THAT(checked_ast->type_map(), + Contains(Pair(checked_ast->root_expr().id(), + Eq(test_case.expected_result_type)))) + << cel::test::FormatBaselineAst(*checked_ast); +} + +INSTANTIATE_TEST_SUITE_P( + TestAllTypesCreation, GenericMessagesTest, + ::testing::Values( + CheckedExprTestCase{ + .expr = "TestAllTypes{not_a_field: 10}", + .expected_result_type = AstType(), + .error_substring = + "undefined field 'not_a_field' not found in " + "struct 'cel.expr.conformance.proto3.TestAllTypes'"}, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_int64: 10}", + .expected_result_type = AstType( + MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_int64: 'string'}", + .expected_result_type = AstType(), + .error_substring = + "expected type of field 'single_int64' is 'int' but " + "provided type is 'string'"}, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_int32: 10}", + .expected_result_type = AstType( + MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_uint64: 10u}", + .expected_result_type = AstType( + MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_uint32: 10u}", + .expected_result_type = AstType( + MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_sint64: 10}", + .expected_result_type = AstType( + MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_sint32: 10}", + .expected_result_type = AstType( + MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_fixed64: 10u}", + .expected_result_type = AstType( + MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_fixed32: 10u}", + .expected_result_type = AstType( + MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_sfixed64: 10}", + .expected_result_type = AstType( + MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_sfixed32: 10}", + .expected_result_type = AstType( + MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_double: 1.25}", + .expected_result_type = AstType( + MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_float: 1.25}", + .expected_result_type = AstType( + MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_string: 'string'}", + .expected_result_type = AstType( + MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_bool: true}", + .expected_result_type = AstType( + MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_bytes: b'string'}", + .expected_result_type = AstType( + MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), + }, + // Well-known + CheckedExprTestCase{ + .expr = "TestAllTypes{single_any: TestAllTypes{single_int64: 10}}", + .expected_result_type = AstType( + MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_any: 1}", + .expected_result_type = AstType( + MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_any: 'string'}", + .expected_result_type = AstType( + MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_any: ['string']}", + .expected_result_type = AstType( + MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{repeated_nested_message: " + "[TestAllTypes.NestedMessage{bb: 42}]}", + .expected_result_type = AstType( + MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_duration: duration('1s')}", + .expected_result_type = AstType( + MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_timestamp: timestamp(0)}", + .expected_result_type = AstType( + MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_struct: {}}", + .expected_result_type = AstType( + MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_struct: {'key': 'value'}}", + .expected_result_type = AstType( + MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_struct: {1: 2}}", + .expected_result_type = AstType(), + .error_substring = "expected type of field 'single_struct' is " + "'map(string, dyn)' but " + "provided type is 'map(int, int)'"}, + CheckedExprTestCase{ + .expr = "TestAllTypes{list_value: [1, 2, 3]}", + .expected_result_type = AstType( + MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{list_value: []}", + .expected_result_type = AstType( + MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{list_value: 1}", + .expected_result_type = AstType(), + .error_substring = + "expected type of field 'list_value' is 'list(dyn)' but " + "provided type is 'int'"}, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_int64_wrapper: 1}", + .expected_result_type = AstType( + MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_int64_wrapper: null}", + .expected_result_type = AstType( + MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_value: null}", + .expected_result_type = AstType( + MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_value: 1.0}", + .expected_result_type = AstType( + MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_value: 'string'}", + .expected_result_type = AstType( + MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_value: {'string': 'string'}}", + .expected_result_type = AstType( + MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_value: ['string']}", + .expected_result_type = AstType( + MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{repeated_int64: [1, 2, 3]}", + .expected_result_type = AstType( + MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{repeated_int64: ['string']}", + .expected_result_type = AstType(), + .error_substring = + "expected type of field 'repeated_int64' is 'list(int)'"}, + CheckedExprTestCase{ + .expr = "TestAllTypes{map_string_int64: ['string']}", + .expected_result_type = AstType(), + .error_substring = "expected type of field 'map_string_int64' is " + "'map(string, int)'"}, + CheckedExprTestCase{ + .expr = "TestAllTypes{map_string_int64: {'string': 1}}", + .expected_result_type = AstType( + MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_nested_enum: 1}", + .expected_result_type = AstType( + MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = + "TestAllTypes{single_nested_enum: TestAllTypes.NestedEnum.BAR}", + .expected_result_type = AstType( + MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes.NestedEnum.BAR", + .expected_result_type = AstType(PrimitiveType::kInt64), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes", + .expected_result_type = AstType(std::make_unique( + MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes"))), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes == type(TestAllTypes{})", + .expected_result_type = AstType(PrimitiveType::kBool), + }, + // Special case for the NullValue enum. + CheckedExprTestCase{ + .expr = "TestAllTypes{null_value: 0}", + .expected_result_type = AstType( + MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), + }, + // Legacy nullability behaviors. + CheckedExprTestCase{ + .expr = "TestAllTypes{single_duration: null}", + .expected_result_type = AstType( + MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_timestamp: null}", + .expected_result_type = AstType( + MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_nested_message: null}", + .expected_result_type = AstType( + MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{}.single_duration == null", + .expected_result_type = AstType(PrimitiveType::kBool), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{}.single_timestamp == null", + .expected_result_type = AstType(PrimitiveType::kBool), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{}.single_nested_message == null", + .expected_result_type = AstType(PrimitiveType::kBool), + })); + +INSTANTIATE_TEST_SUITE_P( + TestAllTypesFieldSelection, GenericMessagesTest, + ::testing::Values( + CheckedExprTestCase{ + .expr = "test_msg.not_a_field", + .expected_result_type = AstType(), + .error_substring = + "undefined field 'not_a_field' not found in " + "struct 'cel.expr.conformance.proto3.TestAllTypes'"}, + CheckedExprTestCase{ + .expr = "test_msg.single_int64", + .expected_result_type = AstType(PrimitiveType::kInt64), + }, + CheckedExprTestCase{ + .expr = "test_msg.single_nested_enum", + .expected_result_type = AstType(PrimitiveType::kInt64), + }, + CheckedExprTestCase{ + .expr = "test_msg.single_nested_enum == 1", + .expected_result_type = AstType(PrimitiveType::kBool), + }, + CheckedExprTestCase{ + .expr = + "test_msg.single_nested_enum == TestAllTypes.NestedEnum.BAR", + .expected_result_type = AstType(PrimitiveType::kBool), + }, + CheckedExprTestCase{ + .expr = "has(test_msg.not_a_field)", + .expected_result_type = AstType(), + .error_substring = + "undefined field 'not_a_field' not found in " + "struct 'cel.expr.conformance.proto3.TestAllTypes'"}, + CheckedExprTestCase{ + .expr = "has(test_msg.single_int64)", + .expected_result_type = AstType(PrimitiveType::kBool), + }, + CheckedExprTestCase{ + .expr = "test_msg.single_int32", + .expected_result_type = AstType(PrimitiveType::kInt64), + }, + CheckedExprTestCase{ + .expr = "test_msg.single_uint64", + .expected_result_type = AstType(PrimitiveType::kUint64), + }, + CheckedExprTestCase{ + .expr = "test_msg.single_uint32", + .expected_result_type = AstType(PrimitiveType::kUint64), + }, + CheckedExprTestCase{ + .expr = "test_msg.single_sint64", + .expected_result_type = AstType(PrimitiveType::kInt64), + }, + CheckedExprTestCase{ + .expr = "test_msg.single_sint32", + .expected_result_type = AstType(PrimitiveType::kInt64), + }, + CheckedExprTestCase{ + .expr = "test_msg.single_fixed64", + .expected_result_type = AstType(PrimitiveType::kUint64), + }, + CheckedExprTestCase{ + .expr = "test_msg.single_fixed32", + .expected_result_type = AstType(PrimitiveType::kUint64), + }, + CheckedExprTestCase{ + .expr = "test_msg.single_sfixed64", + .expected_result_type = AstType(PrimitiveType::kInt64), + }, + CheckedExprTestCase{ + .expr = "test_msg.single_sfixed32", + .expected_result_type = AstType(PrimitiveType::kInt64), + }, + CheckedExprTestCase{ + .expr = "test_msg.single_float", + .expected_result_type = AstType(PrimitiveType::kDouble), + }, + CheckedExprTestCase{ + .expr = "test_msg.single_double", + .expected_result_type = AstType(PrimitiveType::kDouble), + }, + CheckedExprTestCase{ + .expr = "test_msg.single_string", + .expected_result_type = AstType(PrimitiveType::kString), + }, + CheckedExprTestCase{ + .expr = "test_msg.single_bool", + .expected_result_type = AstType(PrimitiveType::kBool), + }, + CheckedExprTestCase{ + .expr = "test_msg.single_bytes", + .expected_result_type = AstType(PrimitiveType::kBytes), + }, + // Basic tests for containers. This is covered in more detail in + // conformance tests and the type provider implementation. + CheckedExprTestCase{ + .expr = "test_msg.repeated_int32", + .expected_result_type = AstType( + ListTypeSpec(std::make_unique(PrimitiveType::kInt64))), + }, + CheckedExprTestCase{ + .expr = "test_msg.repeated_string", + .expected_result_type = AstType(ListTypeSpec( + std::make_unique(PrimitiveType::kString))), + }, + CheckedExprTestCase{ + .expr = "test_msg.map_bool_bool", + .expected_result_type = AstType( + MapTypeSpec(std::make_unique(PrimitiveType::kBool), + std::make_unique(PrimitiveType::kBool))), + }, + // Note: The Go type checker permits this so C++ does as well. Some + // test cases expect that field selection on a map is always allowed, + // even if a specific, non-string key type is known. + CheckedExprTestCase{ + .expr = "test_msg.map_bool_bool.field_like_key", + .expected_result_type = AstType(PrimitiveType::kBool), + }, + CheckedExprTestCase{ + .expr = "test_msg.map_string_int64", + .expected_result_type = AstType( + MapTypeSpec(std::make_unique(PrimitiveType::kString), + std::make_unique(PrimitiveType::kInt64))), + }, + CheckedExprTestCase{ + .expr = "test_msg.map_string_int64.field_like_key", + .expected_result_type = AstType(PrimitiveType::kInt64), + }, + // Well-known + CheckedExprTestCase{ + .expr = "test_msg.single_duration", + .expected_result_type = AstType(WellKnownTypeSpec::kDuration), + }, + CheckedExprTestCase{ + .expr = "test_msg.single_timestamp", + .expected_result_type = AstType(WellKnownTypeSpec::kTimestamp), + }, + CheckedExprTestCase{ + .expr = "test_msg.single_any", + .expected_result_type = AstType(WellKnownTypeSpec::kAny), + }, + CheckedExprTestCase{ + .expr = "test_msg.single_int64_wrapper", + .expected_result_type = + AstType(PrimitiveTypeWrapper(PrimitiveType::kInt64)), + }, + CheckedExprTestCase{ + .expr = "test_msg.single_struct", + .expected_result_type = AstType( + MapTypeSpec(std::make_unique(PrimitiveType::kString), + std::make_unique(DynTypeSpec()))), + }, + CheckedExprTestCase{ + .expr = "test_msg.list_value", + .expected_result_type = + AstType(ListTypeSpec(std::make_unique(DynTypeSpec()))), + }, + CheckedExprTestCase{ + .expr = "test_msg.list_value", + .expected_result_type = + AstType(ListTypeSpec(std::make_unique(DynTypeSpec()))), + }, + // Basic tests for nested messages. + CheckedExprTestCase{ + .expr = "NestedTestAllTypes{}.child.child.payload.single_int64", + .expected_result_type = AstType(PrimitiveType::kInt64), + }, + CheckedExprTestCase{ + .expr = "test_msg.single_struct.field.nested_field", + .expected_result_type = AstType(DynTypeSpec()), + }, + CheckedExprTestCase{ + .expr = "{}.field.nested_field", + .expected_result_type = AstType(DynTypeSpec()), + })); + +INSTANTIATE_TEST_SUITE_P( + TypeInferences, GenericMessagesTest, + ::testing::Values( + CheckedExprTestCase{.expr = "[1, test_msg.single_int64_wrapper]", + .expected_result_type = AstType(ListTypeSpec( + std::make_unique(PrimitiveTypeWrapper( + PrimitiveType::kInt64))))}, + CheckedExprTestCase{.expr = "[1, 2, test_msg.single_int64_wrapper]", + .expected_result_type = AstType(ListTypeSpec( + std::make_unique(PrimitiveTypeWrapper( + PrimitiveType::kInt64))))}, + CheckedExprTestCase{.expr = "[test_msg.single_int64_wrapper, 1]", + .expected_result_type = AstType(ListTypeSpec( + std::make_unique(PrimitiveTypeWrapper( + PrimitiveType::kInt64))))}, + CheckedExprTestCase{ + .expr = "[1, 2, test_msg.single_int64_wrapper, dyn(1)]", + .expected_result_type = AstType( + ListTypeSpec(std::make_unique(DynTypeSpec())))}, + CheckedExprTestCase{.expr = "[null, test_msg][0]", + .expected_result_type = AstType(MessageTypeSpec( + "cel.expr.conformance.proto3.TestAllTypes"))}, + CheckedExprTestCase{ + .expr = "[{'k': dyn(1)}, {dyn('k'): 1}][0]", + // Ambiguous type resolution, but we prefer the first option. + .expected_result_type = AstType( + MapTypeSpec(std::make_unique(PrimitiveType::kString), + std::make_unique(DynTypeSpec())))}, + CheckedExprTestCase{ + .expr = "[{'k': 1}, {dyn('k'): 1}][0]", + .expected_result_type = AstType( + MapTypeSpec(std::make_unique(DynTypeSpec()), + std::make_unique(PrimitiveType::kInt64)))}, + CheckedExprTestCase{ + .expr = "[{dyn('k'): 1}, {'k': 1}][0]", + .expected_result_type = AstType( + MapTypeSpec(std::make_unique(DynTypeSpec()), + std::make_unique(PrimitiveType::kInt64)))}, + CheckedExprTestCase{ + .expr = "[{'k': 1}, {'k': dyn(1)}][0]", + .expected_result_type = AstType( + MapTypeSpec(std::make_unique(PrimitiveType::kString), + std::make_unique(DynTypeSpec())))}, + CheckedExprTestCase{.expr = "[{'k': 1}, {dyn('k'): dyn(1)}][0]", + .expected_result_type = AstType(MapTypeSpec( + std::make_unique(DynTypeSpec()), + std::make_unique(DynTypeSpec())))}, + CheckedExprTestCase{ + .expr = + "[{'k': 1.0}, {dyn('k'): test_msg.single_int64_wrapper}][0]", + .expected_result_type = AstType(DynTypeSpec())}, + CheckedExprTestCase{ + .expr = "test_msg.single_int64", + .expected_result_type = AstType(PrimitiveType::kInt64), + }, + CheckedExprTestCase{ + .expr = "[[1], {1: 2u}][0]", + .expected_result_type = AstType(DynTypeSpec()), + }, + CheckedExprTestCase{ + .expr = "[{1: 2u}, [1]][0]", + .expected_result_type = AstType(DynTypeSpec()), + }, + CheckedExprTestCase{ + .expr = "[test_msg.single_int64_wrapper," + " test_msg.single_string_wrapper][0]", + .expected_result_type = AstType(DynTypeSpec()), + })); + +class StrictNullAssignmentTest + : public testing::TestWithParam {}; + +TEST_P(StrictNullAssignmentTest, TypeChecksProto3) { + const CheckedExprTestCase& test_case = GetParam(); + google::protobuf::Arena arena; + + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + env.set_container(*MakeExpressionContainer("cel.expr.conformance.proto3")); + google::protobuf::LinkMessageReflection(); + + ASSERT_TRUE(env.InsertVariableIfAbsent(MakeVariableDecl( + "test_msg", MessageType(testpb3::TestAllTypes::descriptor())))); + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + CheckerOptions options; + options.enable_legacy_null_assignment = false; + TypeCheckerImpl impl(std::move(env), options); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst(test_case.expr)); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + if (!test_case.error_substring.empty()) { + EXPECT_THAT(result.GetIssues(), + Contains(IsIssueWithSubstring(Severity::kError, + test_case.error_substring))); + return; + } + + ASSERT_TRUE(result.IsValid()) + << absl::StrJoin(result.GetIssues(), "\n", + [](std::string* out, const TypeCheckIssue& issue) { + absl::StrAppend(out, issue.message()); + }); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); + + EXPECT_THAT(checked_ast->type_map(), + Contains(Pair(checked_ast->root_expr().id(), + Eq(test_case.expected_result_type)))); +} + +INSTANTIATE_TEST_SUITE_P( + TestStrictNullAssignment, StrictNullAssignmentTest, + ::testing::Values( + // Legacy nullability behaviors rejected. + CheckedExprTestCase{ + .expr = "TestAllTypes{single_duration: null}", + .expected_result_type = AstType(), + .error_substring = + "'single_duration' is 'google.protobuf.Duration' but provided " + "type is 'null_type'"}, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_timestamp: null}", + .expected_result_type = AstType(), + .error_substring = + "'single_timestamp' is 'google.protobuf.Timestamp' but " + "provided type is 'null_type'"}, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_nested_message: null}", + .expected_result_type = AstType(), + // Debug string includes descriptor address. + .error_substring = "but provided type is 'null_type'"}, + CheckedExprTestCase{ + .expr = "TestAllTypes{}.single_duration == null", + .expected_result_type = AstType(), + .error_substring = "no matching overload for '_==_'", + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{}.single_timestamp == null", + .expected_result_type = AstType(), + .error_substring = "no matching overload for '_==_'"}, + CheckedExprTestCase{ + .expr = "TestAllTypes{}.single_nested_message == null", + .expected_result_type = AstType(), + .error_substring = "no matching overload for '_==_'", + })); + +} // namespace +} // namespace checker_internal +} // namespace cel diff --git a/checker/internal/type_inference_context.cc b/checker/internal/type_inference_context.cc new file mode 100644 index 000000000..4f738b804 --- /dev/null +++ b/checker/internal/type_inference_context.cc @@ -0,0 +1,680 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/internal/type_inference_context.h" + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "common/decl.h" +#include "common/format_type_name.h" +#include "common/standard_definitions.h" +#include "common/type.h" +#include "common/type_kind.h" + +namespace cel::checker_internal { +namespace { + +bool IsWildCardType(Type type) { + switch (type.kind()) { + case TypeKind::kAny: + case TypeKind::kDyn: + case TypeKind::kError: + return true; + default: + return false; + } +} + +// Returns true if the given type is a legacy nullable type. +// +// Historically, structs and abstract types were considered nullable. This is +// inconsistent with CEL's usual interpretation of null as a literal JSON null. +// +// TODO(uncreated-issue/74): Need a concrete plan for updating existing CEL expressions +// that depend on the old behavior. +bool IsLegacyNullable(Type type) { + switch (type.kind()) { + case TypeKind::kStruct: + case TypeKind::kDuration: + case TypeKind::kTimestamp: + case TypeKind::kAny: + case TypeKind::kOpaque: + return true; + default: + return false; + } +} + +bool IsTypeVar(absl::string_view name) { return absl::StartsWith(name, "T%"); } + +bool IsUnionType(Type t) { + switch (t.kind()) { + case TypeKind::kAny: + case TypeKind::kBoolWrapper: + case TypeKind::kBytesWrapper: + case TypeKind::kDyn: + case TypeKind::kDoubleWrapper: + case TypeKind::kIntWrapper: + case TypeKind::kStringWrapper: + case TypeKind::kUintWrapper: + return true; + default: + return false; + } +} + +// Returns true if `a` is a subset of `b`. +// (b is more general than a and admits a). +bool IsSubsetOf(Type a, Type b) { + switch (b.kind()) { + case TypeKind::kAny: + return true; + case TypeKind::kBoolWrapper: + return a.IsBool() || a.IsNull(); + case TypeKind::kBytesWrapper: + return a.IsBytes() || a.IsNull(); + case TypeKind::kDoubleWrapper: + return a.IsDouble() || a.IsNull(); + case TypeKind::kDyn: + return true; + case TypeKind::kIntWrapper: + return a.IsInt() || a.IsNull(); + case TypeKind::kStringWrapper: + return a.IsString() || a.IsNull(); + case TypeKind::kUintWrapper: + return a.IsUint() || a.IsNull(); + default: + return false; + } +} + +struct FunctionOverloadInstance { + Type result_type; + std::vector param_types; +}; + +FunctionOverloadInstance InstantiateFunctionOverload( + TypeInferenceContext& inference_context, const OverloadDecl& ovl) { + FunctionOverloadInstance result; + result.param_types.reserve(ovl.args().size()); + + TypeInferenceContext::InstanceMap substitutions; + result.result_type = + inference_context.InstantiateTypeParams(ovl.result(), substitutions); + + for (int i = 0; i < ovl.args().size(); ++i) { + result.param_types.push_back( + inference_context.InstantiateTypeParams(ovl.args()[i], substitutions)); + } + return result; +} + +// Converts a wrapper type to its corresponding primitive type. +// Returns nullopt if the type is not a wrapper type. +std::optional WrapperToPrimitive(const Type& t) { + switch (t.kind()) { + case TypeKind::kBoolWrapper: + return BoolType(); + case TypeKind::kBytesWrapper: + return BytesType(); + case TypeKind::kDoubleWrapper: + return DoubleType(); + case TypeKind::kStringWrapper: + return StringType(); + case TypeKind::kIntWrapper: + return IntType(); + case TypeKind::kUintWrapper: + return UintType(); + default: + return std::nullopt; + } +} + +} // namespace + +Type TypeInferenceContext::InstantiateTypeParams(const Type& type) { + InstanceMap substitutions; + return InstantiateTypeParams(type, substitutions); +} + +Type TypeInferenceContext::InstantiateTypeParams( + const Type& type, + absl::flat_hash_map& substitutions) { + switch (type.kind()) { + // Unparameterized types -- just forward. + case TypeKind::kAny: + case TypeKind::kBool: + case TypeKind::kBoolWrapper: + case TypeKind::kBytes: + case TypeKind::kBytesWrapper: + case TypeKind::kDouble: + case TypeKind::kDoubleWrapper: + case TypeKind::kDuration: + case TypeKind::kDyn: + case TypeKind::kError: + case TypeKind::kInt: + case TypeKind::kNull: + case TypeKind::kString: + case TypeKind::kStringWrapper: + case TypeKind::kStruct: + case TypeKind::kTimestamp: + case TypeKind::kUint: + case TypeKind::kIntWrapper: + case TypeKind::kUintWrapper: + return type; + case TypeKind::kTypeParam: { + absl::string_view name = type.AsTypeParam()->name(); + if (IsTypeVar(name)) { + // Already instantiated (e.g. list comprehension variable). + return type; + } + if (auto it = substitutions.find(name); it != substitutions.end()) { + return TypeParamType(it->second); + } + absl::string_view substitution = NewTypeVar(name); + substitutions[type.AsTypeParam()->name()] = substitution; + return TypeParamType(substitution); + } + case TypeKind::kType: { + auto type_type = type.AsType(); + auto parameters = type_type->GetParameters(); + if (parameters.size() == 1) { + Type param = InstantiateTypeParams(parameters[0], substitutions); + return TypeType(arena_, param); + } else if (parameters.size() > 1) { + return ErrorType(); + } else { // generic type + return type; + } + } + case TypeKind::kList: { + Type elem = + InstantiateTypeParams(type.AsList()->element(), substitutions); + return ListType(arena_, elem); + } + case TypeKind::kMap: { + Type key = InstantiateTypeParams(type.AsMap()->key(), substitutions); + Type value = InstantiateTypeParams(type.AsMap()->value(), substitutions); + return MapType(arena_, key, value); + } + case TypeKind::kOpaque: { + auto opaque_type = type.AsOpaque(); + auto parameters = opaque_type->GetParameters(); + std::vector param_instances; + param_instances.reserve(parameters.size()); + + for (int i = 0; i < parameters.size(); ++i) { + param_instances.push_back( + InstantiateTypeParams(parameters[i], substitutions)); + } + return OpaqueType(arena_, type.AsOpaque()->name(), param_instances); + } + default: + return ErrorType(); + } +} + +bool TypeInferenceContext::IsAssignable(const Type& from, const Type& to) { + SubstitutionMap prospective_substitutions; + bool result = IsAssignableInternal(from, to, prospective_substitutions); + if (result) { + UpdateTypeParameterBindings(prospective_substitutions); + } + return result; +} + +bool TypeInferenceContext::IsAssignableInternal( + const Type& from, const Type& to, + SubstitutionMap& prospective_substitutions) { + Type to_subs = Substitute(to, prospective_substitutions); + Type from_subs = Substitute(from, prospective_substitutions); + + // Types always assignable to themselves. + // Remainder is checking for assignability across different types. + if (to_subs == from_subs) { + return true; + } + + // Resolve free type parameters. + if (to_subs.kind() == TypeKind::kTypeParam || + from_subs.kind() == TypeKind::kTypeParam) { + return IsAssignableWithConstraints(from_subs, to_subs, + prospective_substitutions); + } + + // Maybe widen a prospective type binding if another potential binding is + // more general and admits the previous binding. + if ( + // Checking assignability to a specific type var + // that has a prospective type assignment. + to.kind() == TypeKind::kTypeParam && + prospective_substitutions.contains(to.GetTypeParam().name())) { + SubstitutionMap prospective_subs_cpy = prospective_substitutions; + if (CompareGenerality(from_subs, to_subs, prospective_subs_cpy) == + RelativeGenerality::kMoreGeneral) { + if (IsAssignableInternal(to_subs, from_subs, prospective_subs_cpy) && + !OccursWithin(to.GetTypeParam().name(), from_subs, + prospective_subs_cpy)) { + prospective_subs_cpy[to.GetTypeParam().name()] = from_subs; + prospective_substitutions = std::move(prospective_subs_cpy); + return true; + // otherwise, continue with normal assignability check. + } + } + } + + // Type is as concrete as it can be under current substitutions. + if (std::optional wrapped_type = WrapperToPrimitive(to_subs); + wrapped_type.has_value()) { + return from_subs.IsNull() || + IsAssignableInternal(*wrapped_type, from_subs, + prospective_substitutions); + } + + // Wrapper types are assignable to their corresponding primitive type ( + // somewhat similar to auto unboxing). This is a bit odd with CEL's null_type, + // but there isn't a dedicated syntax for narrowing from the nullable. + if (auto from_wrapper = WrapperToPrimitive(from_subs); + from_wrapper.has_value()) { + return IsAssignableInternal(*from_wrapper, to_subs, + prospective_substitutions); + } + + if (enable_legacy_null_assignment_) { + if (from_subs.IsNull() && IsLegacyNullable(to_subs)) { + return true; + } + + if (to_subs.IsNull() && IsLegacyNullable(from_subs)) { + return true; + } + } + + if (from_subs.kind() == TypeKind::kType && + to_subs.kind() == TypeKind::kType) { + // Types are always assignable to themselves (even if differently + // parameterized). + return true; + } + + if (to_subs.kind() == TypeKind::kEnum && from_subs.kind() == TypeKind::kInt) { + return true; + } + + if (from_subs.kind() == TypeKind::kEnum && to_subs.kind() == TypeKind::kInt) { + return true; + } + + if (IsWildCardType(from_subs) || IsWildCardType(to_subs)) { + return true; + } + + if (to_subs.kind() != from_subs.kind() || + to_subs.name() != from_subs.name()) { + return false; + } + + // Recurse for the type parameters. + auto to_params = to_subs.GetParameters(); + auto from_params = from_subs.GetParameters(); + const auto params_size = to_params.size(); + + if (params_size != from_params.size()) { + return false; + } + for (size_t i = 0; i < params_size; ++i) { + if (!IsAssignableInternal(from_params[i], to_params[i], + prospective_substitutions)) { + return false; + } + } + return true; +} + +Type TypeInferenceContext::Substitute( + const Type& type, const SubstitutionMap& substitutions) const { + Type subs = type; + while (subs.kind() == TypeKind::kTypeParam) { + TypeParamType t = subs.GetTypeParam(); + if (auto it = substitutions.find(t.name()); it != substitutions.end()) { + subs = it->second; + continue; + } + if (auto it = type_parameter_bindings_.find(t.name()); + it != type_parameter_bindings_.end()) { + if (it->second.type.has_value()) { + subs = *it->second.type; + continue; + } + } + break; + } + return subs; +} + +TypeInferenceContext::RelativeGenerality +TypeInferenceContext::CompareGenerality( + const Type& from, const Type& to, + const SubstitutionMap& prospective_substitutions) const { + Type from_subs = Substitute(from, prospective_substitutions); + Type to_subs = Substitute(to, prospective_substitutions); + + if (from_subs == to_subs) { + return RelativeGenerality::kEquivalent; + } + + if (IsUnionType(from_subs) && IsSubsetOf(to_subs, from_subs)) { + return RelativeGenerality::kMoreGeneral; + } + + if (IsUnionType(to_subs)) { + return RelativeGenerality::kLessGeneral; + } + + if (enable_legacy_null_assignment_ && IsLegacyNullable(from_subs) && + to_subs.IsNull()) { + return RelativeGenerality::kMoreGeneral; + } + + // Not a polytype. Check if it is a parameterized type and all parameters are + // equivalent and at least one is more general. + if (from_subs.IsList() && to_subs.IsList()) { + return CompareGenerality(from_subs.AsList()->GetElement(), + to_subs.AsList()->GetElement(), + prospective_substitutions); + } + + if (from_subs.IsMap() && to_subs.IsMap()) { + RelativeGenerality key_generality = + CompareGenerality(from_subs.AsMap()->GetKey(), + to_subs.AsMap()->GetKey(), prospective_substitutions); + RelativeGenerality value_generality = CompareGenerality( + from_subs.AsMap()->GetValue(), to_subs.AsMap()->GetValue(), + prospective_substitutions); + if (key_generality == RelativeGenerality::kLessGeneral || + value_generality == RelativeGenerality::kLessGeneral) { + return RelativeGenerality::kLessGeneral; + } + if (key_generality == RelativeGenerality::kMoreGeneral || + value_generality == RelativeGenerality::kMoreGeneral) { + return RelativeGenerality::kMoreGeneral; + } + return RelativeGenerality::kEquivalent; + } + + if (from_subs.IsOpaque() && to_subs.IsOpaque() && + from_subs.AsOpaque()->name() == to_subs.AsOpaque()->name() && + from_subs.AsOpaque()->GetParameters().size() == + to_subs.AsOpaque()->GetParameters().size()) { + RelativeGenerality max_generality = RelativeGenerality::kEquivalent; + for (int i = 0; i < from_subs.AsOpaque()->GetParameters().size(); ++i) { + RelativeGenerality generality = CompareGenerality( + from_subs.AsOpaque()->GetParameters()[i], + to_subs.AsOpaque()->GetParameters()[i], prospective_substitutions); + if (generality == RelativeGenerality::kLessGeneral) { + return RelativeGenerality::kLessGeneral; + } + if (generality == RelativeGenerality::kMoreGeneral) { + max_generality = RelativeGenerality::kMoreGeneral; + } + } + return max_generality; + } + + // Default not comparable. Since we ruled out polytypes, they should be + // equivalent for the purposes of deciding the most general eligible + // substitution. + return RelativeGenerality::kEquivalent; +} + +bool TypeInferenceContext::OccursWithin( + absl::string_view var_name, const Type& type, + const SubstitutionMap& substitutions) const { + // This is difficult to trigger in normal CEL expressions, but may + // happen with comprehensions where we can potentially reference a variable + // with a free type var in different ways. + // + // This check guarantees that we don't introduce a recursive type definition + // (a cycle in the substitution map). + // + // We can't reuse Substitute here because it does the pointer chasing and + // might hide a cycle. + // + // E.g. + // T2 in T3 when + // T3 -> T2 -> null_type; + Type substitution = type; + while (substitution.kind() == TypeKind::kTypeParam) { + absl::string_view param_name = substitution.AsTypeParam()->name(); + if (param_name == var_name) { + return true; + } + + if (auto it = substitutions.find(param_name); it != substitutions.end()) { + substitution = it->second; + continue; + } + if (auto it = type_parameter_bindings_.find(param_name); + it != type_parameter_bindings_.end() && it->second.type.has_value()) { + substitution = it->second.type.value(); + continue; + } + + // Type parameter is free. + return false; + } + + for (const auto& param : substitution.GetParameters()) { + if (OccursWithin(var_name, param, substitutions)) { + return true; + } + } + return false; +} + +bool TypeInferenceContext::IsAssignableWithConstraints( + const Type& from, const Type& to, + SubstitutionMap& prospective_substitutions) { + if (to.kind() == TypeKind::kTypeParam && + from.kind() == TypeKind::kTypeParam) { + if (to.AsTypeParam()->name() != from.AsTypeParam()->name()) { + // Simple case, bind from to 'to' if both are free. + prospective_substitutions[from.AsTypeParam()->name()] = to; + } + return true; + } + + if (to.kind() == TypeKind::kTypeParam) { + absl::string_view name = to.AsTypeParam()->name(); + if (!OccursWithin(name, from, prospective_substitutions)) { + prospective_substitutions[name] = from; + return true; + } + } + + if (from.kind() == TypeKind::kTypeParam) { + absl::string_view name = from.AsTypeParam()->name(); + if (!OccursWithin(name, to, prospective_substitutions)) { + prospective_substitutions[name] = to; + return true; + } + } + + // If either types are wild cards but we weren't able to specialize, + // assume assignable and continue. + if (IsWildCardType(from) || IsWildCardType(to)) { + return true; + } + + return false; +} + +std::optional +TypeInferenceContext::ResolveOverload(const FunctionDecl& decl, + absl::Span argument_types, + bool is_receiver) { + std::optional result_type; + + bool is_logical_op = (decl.name() == cel::StandardFunctions::kAnd || + decl.name() == cel::StandardFunctions::kOr) && + argument_types.size() >= 2; + + std::vector matching_overloads; + for (const auto& ovl : decl.overloads()) { + if (ovl.member() != is_receiver || + (!is_logical_op && argument_types.size() != ovl.args().size())) { + continue; + } + + auto call_type_instance = InstantiateFunctionOverload(*this, ovl); + if (!is_logical_op) { + ABSL_DCHECK_EQ(argument_types.size(), + call_type_instance.param_types.size()); + } + bool is_match = true; + AssignabilityContext assignability_context = CreateAssignabilityContext(); + for (int i = 0; i < argument_types.size(); ++i) { + int param_index = is_logical_op ? 0 : i; + if (!assignability_context.IsAssignable( + argument_types[i], call_type_instance.param_types[param_index])) { + is_match = false; + break; + } + } + + if (is_match) { + matching_overloads.push_back(ovl); + assignability_context.UpdateInferredTypeAssignments(); + if (!result_type.has_value()) { + result_type = call_type_instance.result_type; + } else { + if (!TypeEquivalent(*result_type, call_type_instance.result_type)) { + result_type = DynType(); + } + } + } + } + + if (!result_type.has_value() || matching_overloads.empty()) { + return std::nullopt; + } + return OverloadResolution{ + .result_type = FullySubstitute(*result_type, /*free_to_dyn=*/false), + .overloads = std::move(matching_overloads), + }; +} + +void TypeInferenceContext::UpdateTypeParameterBindings( + const SubstitutionMap& prospective_substitutions) { + if (prospective_substitutions.empty()) { + return; + } + for (auto iter = prospective_substitutions.begin(); + iter != prospective_substitutions.end(); ++iter) { + if (auto binding_iter = type_parameter_bindings_.find(iter->first); + binding_iter != type_parameter_bindings_.end()) { + binding_iter->second.type = iter->second; + } else { + ABSL_LOG(WARNING) << "Uninstantiated type parameter: " << iter->first; + } + } +} + +bool TypeInferenceContext::TypeEquivalent(const Type& a, const Type& b) { + return a == b; +} + +Type TypeInferenceContext::FullySubstitute(const Type& type, + bool free_to_dyn) const { + switch (type.kind()) { + case TypeKind::kTypeParam: { + Type subs = Substitute(type, {}); + if (subs.kind() == TypeKind::kTypeParam) { + if (free_to_dyn) { + return DynType(); + } + return subs; + } + return FullySubstitute(subs, free_to_dyn); + } + case TypeKind::kType: { + if (type.AsType()->GetParameters().empty()) { + return type; + } + Type param = FullySubstitute(type.AsType()->GetType(), free_to_dyn); + return TypeType(arena_, param); + } + case TypeKind::kList: { + Type elem = FullySubstitute(type.AsList()->GetElement(), free_to_dyn); + return ListType(arena_, elem); + } + case TypeKind::kMap: { + Type key = FullySubstitute(type.AsMap()->GetKey(), free_to_dyn); + Type value = FullySubstitute(type.AsMap()->GetValue(), free_to_dyn); + return MapType(arena_, key, value); + } + case TypeKind::kOpaque: { + std::vector types; + for (const auto& param : type.AsOpaque()->GetParameters()) { + types.push_back(FullySubstitute(param, free_to_dyn)); + } + return OpaqueType(arena_, type.AsOpaque()->name(), types); + } + default: + return type; + } +} + +bool TypeInferenceContext::AssignabilityContext::IsAssignable(const Type& from, + const Type& to) { + return inference_context_.IsAssignableInternal(from, to, + prospective_substitutions_); +} + +std::string TypeInferenceContext::DebugString() const { + return absl::StrCat( + "type_parameter_bindings: ", + absl::StrJoin(type_parameter_bindings_, "\n ", + [](std::string* out, const auto& binding) { + absl::StrAppend( + out, binding.first, " (", binding.second.name, + ") -> ", + cel::FormatTypeName(binding.second.type.value_or( + Type(TypeParamType("none"))))); + })); +} + +void TypeInferenceContext::AssignabilityContext:: + UpdateInferredTypeAssignments() { + inference_context_.UpdateTypeParameterBindings(prospective_substitutions_); + prospective_substitutions_.clear(); +} + +void TypeInferenceContext::AssignabilityContext::Reset() { + prospective_substitutions_.clear(); +} + +} // namespace cel::checker_internal diff --git a/checker/internal/type_inference_context.h b/checker/internal/type_inference_context.h new file mode 100644 index 000000000..1a1043047 --- /dev/null +++ b/checker/internal/type_inference_context.h @@ -0,0 +1,229 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_TYPE_INFERENCE_CONTEXT_H_ +#define THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_TYPE_INFERENCE_CONTEXT_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/node_hash_map.h" +#include "absl/log/absl_check.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "common/decl.h" +#include "common/type.h" +#include "google/protobuf/arena.h" + +namespace cel::checker_internal { + +// Class manages context for type inferences in the type checker. +// TODO(uncreated-issue/72): for now, just checks assignability for concrete types. +// Support for finding substitutions of type parameters will be added in a +// follow-up CL. +class TypeInferenceContext { + public: + // Convenience alias for an instance map for type parameters mapped to type + // vars in a given context. + // + // This should be treated as opaque, the client should not manually modify. + using InstanceMap = absl::flat_hash_map; + + struct OverloadResolution { + Type result_type; + std::vector overloads; + }; + + private: + // Alias for a map from type var name to the type it is bound to. + // + // Used for prospective substitutions during type inference to make progress + // without affecting final assigned types. + using SubstitutionMap = absl::flat_hash_map; + + public: + // Helper class for managing several dependent type assignability checks. + // + // Note: while allowed, updating multiple AssignabilityContexts concurrently + // can lead to inconsistencies in the final type bindings. + class AssignabilityContext { + public: + // Checks if `from` is assignable to `to` with the current type + // substitutions and any additional prospective substitutions in the parent + // inference context. + bool IsAssignable(const Type& from, const Type& to); + + // Applies any prospective type assignments to the parent inference context. + // + // This should only be called after all assignability checks have completed. + // + // Leaves the AssignabilityContext in the starting state (i.e. no + // prospective substitutions). + void UpdateInferredTypeAssignments(); + + // Return the AssignabilityContext to the starting state (i.e. no + // prospective substitutions). + void Reset(); + + private: + explicit AssignabilityContext(TypeInferenceContext& inference_context) + : inference_context_(inference_context) {} + + AssignabilityContext(const AssignabilityContext&) = delete; + AssignabilityContext& operator=(const AssignabilityContext&) = delete; + AssignabilityContext(AssignabilityContext&&) = delete; + AssignabilityContext& operator=(AssignabilityContext&&) = delete; + + friend class TypeInferenceContext; + + TypeInferenceContext& inference_context_; + SubstitutionMap prospective_substitutions_; + }; + + explicit TypeInferenceContext(google::protobuf::Arena* arena, + bool enable_legacy_null_assignment = true) + : arena_(arena), + enable_legacy_null_assignment_(enable_legacy_null_assignment) {} + + // Creates a new AssignabilityContext for the current inference context. + // + // This is intended for managing several dependent type assignability checks + // that should only be added to the final type bindings if all checks succeed. + // + // Note: while allowed, updating multiple AssignabilityContexts concurrently + // can lead to inconsistencies in the final type bindings. + AssignabilityContext CreateAssignabilityContext() + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return AssignabilityContext(*this); + } + // Resolves any remaining type parameters in the given type to a concrete + // type or dyn. + Type FinalizeType(const Type& type) const { + return FullySubstitute(type, /*free_to_dyn=*/true); + } + + // Recursively apply any substitutions to the given type. + Type FullySubstitute(const Type& type, bool free_to_dyn = false) const; + + // Replace any generic type parameters in the given type with specific type + // variables. Internally, type variables are just a unique string parameter + // name. + Type InstantiateTypeParams(const Type& type); + + // Overload for function overload types that need coordination across + // multiple function parameters. + Type InstantiateTypeParams(const Type& type, InstanceMap& substitutions); + + // Resolves the applicable overloads for the given function call given the + // inferred argument types. + // + // If found, returns the result type and the list of applicable overloads. + absl::optional ResolveOverload( + const FunctionDecl& decl, absl::Span argument_types, + bool is_receiver); + + // Checks if `from` is assignable to `to`. + bool IsAssignable(const Type& from, const Type& to); + + std::string DebugString() const; + + private: + struct TypeVar { + absl::optional type; + absl::string_view name; + }; + + // Relative generality between two types. + enum class RelativeGenerality { + kMoreGeneral, + // Note: kLessGeneral does not imply it is definitely more specific, only + // that we cannot determine if equivalent or more general. + kLessGeneral, + kEquivalent, + }; + + absl::string_view NewTypeVar(absl::string_view name = "") { + next_type_parameter_id_++; + auto inserted = type_parameter_bindings_.insert( + {absl::StrCat("T%", next_type_parameter_id_), {absl::nullopt, name}}); + ABSL_DCHECK(inserted.second); + return inserted.first->first; + } + + // Returns true if the two types are equivalent with the current type + // substitutions. + bool TypeEquivalent(const Type& a, const Type& b); + + // Returns true if `from` is assignable to `to` with the current type + // substitutions and any additional prospective substitutions. + // + // `prospective_substitutions` is a map from type var name to the type it + // should be bound to in the current context, augmenting any existing + // substitutions. + // + // If the types are not assignable, returns false and leaves + // `prospective_substitutions` unmodified. + // + // If the types are assignable, returns true and updates + // `prospective_substitutions` with any new type parameter bindings. + bool IsAssignableInternal(const Type& from, const Type& to, + SubstitutionMap& prospective_substitutions); + + bool IsAssignableWithConstraints(const Type& from, const Type& to, + SubstitutionMap& prospective_substitutions); + + // Relative generality of `from` as compared to `to` with the current type + // substitutions and any additional prospective substitutions. + // + // Generality is only defined as a partial ordering. Some types are + // incomparable. However we only need to know if a type is definitely more + // general or not. + RelativeGenerality CompareGenerality( + const Type& from, const Type& to, + const SubstitutionMap& prospective_substitutions) const; + + Type Substitute(const Type& type, const SubstitutionMap& substitutions) const; + + bool OccursWithin(absl::string_view var_name, const Type& type, + const SubstitutionMap& substitutions) const; + + void UpdateTypeParameterBindings( + const SubstitutionMap& prospective_substitutions); + + // Map from type var parameter name to the type it is bound to. + // + // Type var parameters are formatted as "T%" to avoid collisions with + // provided type parameter names. + // + // node_hash_map is used to preserve pointer stability for use with + // TypeParamType. + // + // Type parameter instances should be resolved to a concrete type during type + // checking to remove the lifecycle dependency on the inference context + // instance. + // + // nullopt signifies a free type variable. + absl::node_hash_map type_parameter_bindings_; + int64_t next_type_parameter_id_ = 0; + google::protobuf::Arena* arena_; + bool enable_legacy_null_assignment_; +}; + +} // namespace cel::checker_internal + +#endif // THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_TYPE_INFERENCE_CONTEXT_H_ diff --git a/checker/internal/type_inference_context_test.cc b/checker/internal/type_inference_context_test.cc new file mode 100644 index 000000000..458d08ff1 --- /dev/null +++ b/checker/internal/type_inference_context_test.cc @@ -0,0 +1,850 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/internal/type_inference_context.h" + +#include +#include + +#include "absl/log/absl_check.h" +#include "absl/types/optional.h" +#include "common/decl.h" +#include "common/type.h" +#include "common/type_kind.h" +#include "internal/testing.h" +#include "google/protobuf/arena.h" + +namespace cel::checker_internal { +namespace { + +using ::testing::ElementsAre; +using ::testing::IsEmpty; +using ::testing::SafeMatcherCast; +using ::testing::SizeIs; + +MATCHER_P(IsTypeParam, param, "") { + const Type& got = arg; + if (got.kind() != TypeKind::kTypeParam) { + return false; + } + TypeParamType type = got.GetTypeParam(); + + return type.name() == param; +} + +MATCHER_P(IsListType, elems_matcher, "") { + const Type& got = arg; + if (got.kind() != TypeKind::kList) { + return false; + } + ListType type = got.GetList(); + + Type elem = type.element(); + return SafeMatcherCast(elems_matcher) + .MatchAndExplain(elem, result_listener); +} + +MATCHER_P2(IsMapType, key_matcher, value_matcher, "") { + const Type& got = arg; + if (got.kind() != TypeKind::kMap) { + return false; + } + MapType type = got.GetMap(); + + Type key = type.key(); + Type value = type.value(); + return SafeMatcherCast(key_matcher) + .MatchAndExplain(key, result_listener) && + SafeMatcherCast(value_matcher) + .MatchAndExplain(value, result_listener); +} + +MATCHER_P(IsTypeKind, kind, "") { + const Type& got = arg; + TypeKind want_kind = kind; + if (got.kind() == want_kind) { + return true; + } + *result_listener << "got: " << TypeKindToString(got.kind()); + *result_listener << "\n"; + *result_listener << "wanted: " << TypeKindToString(want_kind); + return false; +} + +MATCHER_P(IsTypeType, matcher, "") { + const Type& got = arg; + + if (got.kind() != TypeKind::kType) { + return false; + } + + TypeType type_type = got.GetType(); + if (type_type.GetParameters().size() != 1) { + return false; + } + + return SafeMatcherCast(matcher).MatchAndExplain(got.GetParameters()[0], + result_listener); +} + +TEST(TypeInferenceContextTest, InstantiateTypeParams) { + google::protobuf::Arena arena; + TypeInferenceContext context(&arena); + Type type = context.InstantiateTypeParams(TypeParamType("MyType")); + EXPECT_THAT(type, IsTypeParam("T%1")); + Type type2 = context.InstantiateTypeParams(TypeParamType("MyType")); + EXPECT_THAT(type2, IsTypeParam("T%2")); +} + +TEST(TypeInferenceContextTest, InstantiateTypeParamsWithSubstitutions) { + google::protobuf::Arena arena; + TypeInferenceContext context(&arena); + TypeInferenceContext::InstanceMap instance_map; + Type type = + context.InstantiateTypeParams(TypeParamType("MyType"), instance_map); + EXPECT_THAT(type, IsTypeParam("T%1")); + Type type2 = + context.InstantiateTypeParams(TypeParamType("MyType"), instance_map); + EXPECT_THAT(type2, IsTypeParam("T%1")); +} + +TEST(TypeInferenceContextTest, InstantiateTypeParamsUnparameterized) { + google::protobuf::Arena arena; + TypeInferenceContext context(&arena); + + Type type = context.InstantiateTypeParams(IntType()); + EXPECT_TRUE(type.IsInt()); +} + +TEST(TypeInferenceContextTest, InstantiateTypeParamsList) { + google::protobuf::Arena arena; + TypeInferenceContext context(&arena); + Type list_type = ListType(&arena, TypeParamType("MyType")); + + Type type = context.InstantiateTypeParams(list_type); + EXPECT_THAT(type, IsListType(IsTypeParam("T%1"))); +} + +TEST(TypeInferenceContextTest, InstantiateTypeParamsListPrimitive) { + google::protobuf::Arena arena; + TypeInferenceContext context(&arena); + Type list_type = ListType(&arena, IntType()); + + Type type = context.InstantiateTypeParams(list_type); + EXPECT_THAT(type, IsListType(IsTypeKind(TypeKind::kInt))); +} + +TEST(TypeInferenceContextTest, InstantiateTypeParamsMap) { + google::protobuf::Arena arena; + TypeInferenceContext context(&arena); + Type map_type = MapType(&arena, TypeParamType("K"), TypeParamType("V")); + + Type type = context.InstantiateTypeParams(map_type); + EXPECT_THAT(type, IsMapType(IsTypeParam("T%1"), IsTypeParam("T%2"))); +} + +TEST(TypeInferenceContextTest, InstantiateTypeParamsMapSameParam) { + google::protobuf::Arena arena; + TypeInferenceContext context(&arena); + Type map_type = MapType(&arena, TypeParamType("E"), TypeParamType("E")); + + Type type = context.InstantiateTypeParams(map_type); + EXPECT_THAT(type, IsMapType(IsTypeParam("T%1"), IsTypeParam("T%1"))); +} + +TEST(TypeInferenceContextTest, InstantiateTypeParamsMapPrimitive) { + google::protobuf::Arena arena; + TypeInferenceContext context(&arena); + Type map_type = MapType(&arena, StringType(), IntType()); + + Type type = context.InstantiateTypeParams(map_type); + EXPECT_THAT(type, IsMapType(IsTypeKind(TypeKind::kString), + IsTypeKind(TypeKind::kInt))); +} + +TEST(TypeInferenceContextTest, InstantiateTypeParamsType) { + google::protobuf::Arena arena; + TypeInferenceContext context(&arena); + Type type_type = TypeType(&arena, TypeParamType("T")); + + Type type = context.InstantiateTypeParams(type_type); + EXPECT_THAT(type, IsTypeType(IsTypeParam("T%1"))); +} + +TEST(TypeInferenceContextTest, InstantiateTypeParamsTypeEmpty) { + google::protobuf::Arena arena; + TypeInferenceContext context(&arena); + Type type_type = TypeType(); + + Type type = context.InstantiateTypeParams(type_type); + EXPECT_THAT(type, IsTypeKind(TypeKind::kType)); + EXPECT_THAT(type.AsType()->GetParameters(), IsEmpty()); +} + +TEST(TypeInferenceContextTest, InstantiateTypeParamsOpaque) { + google::protobuf::Arena arena; + TypeInferenceContext context(&arena); + std::vector parameters = {TypeParamType("T"), IntType(), + TypeParamType("U"), TypeParamType("T")}; + + Type type_type = OpaqueType(&arena, "MyTuple", parameters); + + Type type = context.InstantiateTypeParams(type_type); + ASSERT_THAT(type, IsTypeKind(TypeKind::kOpaque)); + EXPECT_EQ(type.AsOpaque()->name(), "MyTuple"); + EXPECT_THAT(type.AsOpaque()->GetParameters(), + ElementsAre(IsTypeParam("T%1"), IsTypeKind(TypeKind::kInt), + IsTypeParam("T%2"), IsTypeParam("T%1"))); +} + +// TODO(uncreated-issue/72): Does not consider any substitutions based on type +// inferences yet. +TEST(TypeInferenceContextTest, OpaqueTypeAssignable) { + google::protobuf::Arena arena; + TypeInferenceContext context(&arena); + std::vector parameters = {TypeParamType("T"), IntType()}; + + Type type_type = OpaqueType(&arena, "MyTuple", parameters); + + Type type = context.InstantiateTypeParams(type_type); + ASSERT_THAT(type, IsTypeKind(TypeKind::kOpaque)); + EXPECT_TRUE(context.IsAssignable(type, type)); +} + +TEST(TypeInferenceContextTest, WrapperTypeAssignable) { + google::protobuf::Arena arena; + TypeInferenceContext context(&arena); + + EXPECT_TRUE(context.IsAssignable(StringType(), StringWrapperType())); + EXPECT_TRUE(context.IsAssignable(NullType(), StringWrapperType())); +} + +TEST(TypeInferenceContextTest, MismatchedTypeNotAssignable) { + google::protobuf::Arena arena; + TypeInferenceContext context(&arena); + + EXPECT_FALSE(context.IsAssignable(IntType(), StringWrapperType())); +} + +TEST(TypeInferenceContextTest, OverloadResolution) { + google::protobuf::Arena arena; + TypeInferenceContext context(&arena); + + ASSERT_OK_AND_ASSIGN( + auto decl, + MakeFunctionDecl( + "foo", + MakeOverloadDecl("foo_int_int", IntType(), IntType(), IntType()), + MakeOverloadDecl("foo_double_double", DoubleType(), DoubleType(), + DoubleType()))); + + auto resolution = context.ResolveOverload(decl, {IntType(), IntType()}, + /*is_receiver=*/false); + ASSERT_TRUE(resolution.has_value()); + EXPECT_THAT(resolution->result_type, IsTypeKind(TypeKind::kInt)); + EXPECT_THAT(resolution->overloads, SizeIs(1)); +} + +TEST(TypeInferenceContextTest, MultipleOverloadsResultTypeDyn) { + google::protobuf::Arena arena; + TypeInferenceContext context(&arena); + + ASSERT_OK_AND_ASSIGN( + auto decl, + MakeFunctionDecl( + "foo", + MakeOverloadDecl("foo_int_int", IntType(), IntType(), IntType()), + MakeOverloadDecl("foo_double_double", DoubleType(), DoubleType(), + DoubleType()))); + + auto resolution = context.ResolveOverload(decl, {DynType(), DynType()}, + /*is_receiver=*/false); + ASSERT_TRUE(resolution.has_value()); + EXPECT_THAT(resolution->result_type, IsTypeKind(TypeKind::kDyn)); + EXPECT_THAT(resolution->overloads, SizeIs(2)); +} + +MATCHER_P(IsOverloadDecl, name, "") { + const OverloadDecl& got = arg; + return got.id() == name; +} + +TEST(TypeInferenceContextTest, ResolveOverloadBasic) { + google::protobuf::Arena arena; + TypeInferenceContext context(&arena); + + ASSERT_OK_AND_ASSIGN( + FunctionDecl decl, + MakeFunctionDecl( + "_+_", MakeOverloadDecl("add_int", IntType(), IntType(), IntType()), + MakeOverloadDecl("add_double", DoubleType(), DoubleType(), + DoubleType()))); + + std::optional resolution = + context.ResolveOverload(decl, {IntType(), IntType()}, false); + ASSERT_TRUE(resolution.has_value()); + EXPECT_THAT(resolution->result_type, IsTypeKind(TypeKind::kInt)); + EXPECT_THAT(resolution->overloads, ElementsAre(IsOverloadDecl("add_int"))); +} + +TEST(TypeInferenceContextTest, ResolveOverloadFails) { + google::protobuf::Arena arena; + TypeInferenceContext context(&arena); + + ASSERT_OK_AND_ASSIGN( + FunctionDecl decl, + MakeFunctionDecl( + "_+_", MakeOverloadDecl("add_int", IntType(), IntType(), IntType()), + MakeOverloadDecl("add_double", DoubleType(), DoubleType(), + DoubleType()))); + + std::optional resolution = + context.ResolveOverload(decl, {IntType(), DoubleType()}, false); + ASSERT_FALSE(resolution.has_value()); +} + +TEST(TypeInferenceContextTest, ResolveOverloadWithParamsNoMatch) { + google::protobuf::Arena arena; + TypeInferenceContext context(&arena); + + ASSERT_OK_AND_ASSIGN( + FunctionDecl decl, + MakeFunctionDecl( + "_==_", MakeOverloadDecl("equals", BoolType(), TypeParamType("A"), + TypeParamType("A")))); + + std::optional resolution = + context.ResolveOverload(decl, {IntType(), DoubleType()}, false); + ASSERT_FALSE(resolution.has_value()); +} + +TEST(TypeInferenceContextTest, ResolveOverloadWithMixedParamsMatch) { + google::protobuf::Arena arena; + TypeInferenceContext context(&arena); + + Type list_of_a = ListType(&arena, TypeParamType("A")); + + ASSERT_OK_AND_ASSIGN( + FunctionDecl decl, + MakeFunctionDecl( + "_==_", MakeOverloadDecl("equals", BoolType(), TypeParamType("A"), + TypeParamType("A")))); + + std::optional resolution = + context.ResolveOverload(decl, {list_of_a, list_of_a}, false); + ASSERT_TRUE(resolution.has_value()) << context.DebugString(); +} + +TEST(TypeInferenceContextTest, ResolveOverloadWithMixedParamsMatch2) { + google::protobuf::Arena arena; + TypeInferenceContext context(&arena); + + Type list_of_a = ListType(&arena, TypeParamType("A")); + Type list_of_int = ListType(&arena, IntType()); + + ASSERT_OK_AND_ASSIGN( + FunctionDecl decl, + MakeFunctionDecl( + "_==_", MakeOverloadDecl("equals", BoolType(), TypeParamType("A"), + TypeParamType("A")))); + + std::optional resolution = + context.ResolveOverload(decl, {list_of_a, list_of_int}, false); + ASSERT_TRUE(resolution.has_value()) << context.DebugString(); + EXPECT_THAT(resolution->overloads, ElementsAre(IsOverloadDecl("equals"))); +} + +TEST(TypeInferenceContextTest, ResolveOverloadWithParamsMatches) { + google::protobuf::Arena arena; + TypeInferenceContext context(&arena); + + ASSERT_OK_AND_ASSIGN( + FunctionDecl decl, + MakeFunctionDecl( + "_==_", MakeOverloadDecl("equals", BoolType(), TypeParamType("A"), + TypeParamType("A")))); + + std::optional resolution = + context.ResolveOverload(decl, {IntType(), IntType()}, false); + ASSERT_TRUE(resolution.has_value()); + EXPECT_TRUE(resolution->result_type.IsBool()); + EXPECT_THAT(resolution->overloads, ElementsAre(IsOverloadDecl("equals"))); +} + +TEST(TypeInferenceContextTest, ResolveOverloadWithNestedParamsMatch) { + google::protobuf::Arena arena; + TypeInferenceContext context(&arena); + + Type list_of_a = ListType(&arena, TypeParamType("A")); + ASSERT_OK_AND_ASSIGN( + FunctionDecl decl, + MakeFunctionDecl("_+_", MakeOverloadDecl("add_list", list_of_a, list_of_a, + list_of_a))); + + Type list_of_a_instance = context.InstantiateTypeParams(list_of_a); + + std::optional resolution = + context.ResolveOverload( + decl, {list_of_a_instance, ListType(&arena, IntType())}, false); + ASSERT_TRUE(resolution.has_value()); + EXPECT_TRUE(resolution->result_type.IsList()); + + EXPECT_THAT( + context.FinalizeType(resolution->result_type).AsList()->GetElement(), + IsTypeKind(TypeKind::kInt)) + << context.DebugString(); + + EXPECT_THAT(resolution->overloads, ElementsAre(IsOverloadDecl("add_list"))); + + std::optional resolution2 = + context.ResolveOverload( + decl, {ListType(&arena, IntType()), list_of_a_instance}, false); + ASSERT_TRUE(resolution2.has_value()); + EXPECT_TRUE(resolution2->result_type.IsList()); + + EXPECT_THAT( + context.FinalizeType(resolution2->result_type).AsList()->GetElement(), + IsTypeKind(TypeKind::kInt)) + << context.DebugString(); + + EXPECT_THAT(resolution2->overloads, ElementsAre(IsOverloadDecl("add_list"))); +} + +TEST(TypeInferenceContextTest, ResolveOverloadWithNestedParamsNoMatch) { + google::protobuf::Arena arena; + TypeInferenceContext context(&arena); + + Type list_of_a = ListType(&arena, TypeParamType("A")); + ASSERT_OK_AND_ASSIGN( + FunctionDecl decl, + MakeFunctionDecl("_+_", MakeOverloadDecl("add_list", list_of_a, list_of_a, + list_of_a))); + + Type list_of_a_instance = context.InstantiateTypeParams(list_of_a); + + std::optional resolution = + context.ResolveOverload(decl, {list_of_a_instance, IntType()}, false); + EXPECT_FALSE(resolution.has_value()); +} + +TEST(TypeInferenceContextTest, InferencesAccumulate) { + google::protobuf::Arena arena; + TypeInferenceContext context(&arena); + + Type list_of_a = ListType(&arena, TypeParamType("A")); + ASSERT_OK_AND_ASSIGN( + FunctionDecl decl, + MakeFunctionDecl("_+_", MakeOverloadDecl("add_list", list_of_a, list_of_a, + list_of_a))); + + Type list_of_a_instance = context.InstantiateTypeParams(list_of_a); + + std::optional resolution1 = + context.ResolveOverload(decl, {list_of_a_instance, list_of_a_instance}, + false); + ASSERT_TRUE(resolution1.has_value()); + EXPECT_TRUE(resolution1->result_type.IsList()); + + std::optional resolution2 = + context.ResolveOverload( + decl, {resolution1->result_type, ListType(&arena, IntType())}, false); + ASSERT_TRUE(resolution2.has_value()); + EXPECT_TRUE(resolution2->result_type.IsList()); + + EXPECT_THAT( + context.FinalizeType(resolution2->result_type).AsList()->GetElement(), + IsTypeKind(TypeKind::kInt)); + + EXPECT_THAT(resolution2->overloads, ElementsAre(IsOverloadDecl("add_list"))); +} + +TEST(TypeInferenceContextTest, DebugString) { + google::protobuf::Arena arena; + TypeInferenceContext context(&arena); + + Type list_of_a = ListType(&arena, TypeParamType("A")); + Type list_of_int = ListType(&arena, IntType()); + ASSERT_OK_AND_ASSIGN( + FunctionDecl decl, + MakeFunctionDecl("_+_", MakeOverloadDecl("add_list", list_of_a, list_of_a, + list_of_a))); + + std::optional resolution = + context.ResolveOverload(decl, {list_of_int, list_of_int}, false); + ASSERT_TRUE(resolution.has_value()); + EXPECT_TRUE(resolution->result_type.IsList()); + + EXPECT_EQ(context.DebugString(), "type_parameter_bindings: T%1 (A) -> int"); +} + +struct TypeInferenceContextWrapperTypesTestCase { + Type wrapper_type; + Type wrapped_primitive_type; +}; + +class TypeInferenceContextWrapperTypesTest + : public ::testing::TestWithParam< + TypeInferenceContextWrapperTypesTestCase> { + public: + TypeInferenceContextWrapperTypesTest() : context_(&arena_) { + auto decl = MakeFunctionDecl( + "_?_:_", + MakeOverloadDecl("ternary", + /*result_type=*/TypeParamType("A"), BoolType(), + TypeParamType("A"), TypeParamType("A"))); + + ABSL_CHECK_OK(decl.status()); + ternary_decl_ = *std::move(decl); + } + + protected: + google::protobuf::Arena arena_; + TypeInferenceContext context_{&arena_}; + FunctionDecl ternary_decl_; +}; + +TEST_P(TypeInferenceContextWrapperTypesTest, ResolvePrimitiveArg) { + const TypeInferenceContextWrapperTypesTestCase& test_case = GetParam(); + + std::optional resolution = + context_.ResolveOverload(ternary_decl_, + {BoolType(), test_case.wrapper_type, + test_case.wrapped_primitive_type}, + false); + ASSERT_TRUE(resolution.has_value()); + + EXPECT_THAT(context_.FinalizeType(resolution->result_type), + IsTypeKind(test_case.wrapper_type.kind())) + << context_.DebugString(); + + EXPECT_THAT(resolution->overloads, ElementsAre(IsOverloadDecl("ternary"))); +} + +TEST_P(TypeInferenceContextWrapperTypesTest, ResolveWrapperArg) { + const TypeInferenceContextWrapperTypesTestCase& test_case = GetParam(); + + std::optional resolution = + context_.ResolveOverload( + ternary_decl_, + {BoolType(), test_case.wrapper_type, test_case.wrapper_type}, false); + ASSERT_TRUE(resolution.has_value()); + + EXPECT_THAT(context_.FinalizeType(resolution->result_type), + IsTypeKind(test_case.wrapper_type.kind())) + << context_.DebugString(); + + EXPECT_THAT(resolution->overloads, ElementsAre(IsOverloadDecl("ternary"))); +} + +TEST_P(TypeInferenceContextWrapperTypesTest, ResolveNullArg) { + const TypeInferenceContextWrapperTypesTestCase& test_case = GetParam(); + + std::optional resolution = + context_.ResolveOverload(ternary_decl_, + {BoolType(), test_case.wrapper_type, NullType()}, + false); + ASSERT_TRUE(resolution.has_value()); + + EXPECT_THAT(context_.FinalizeType(resolution->result_type), + IsTypeKind(test_case.wrapper_type.kind())) + << context_.DebugString(); + + EXPECT_THAT(resolution->overloads, ElementsAre(IsOverloadDecl("ternary"))); +} + +TEST_P(TypeInferenceContextWrapperTypesTest, NullWidens) { + const TypeInferenceContextWrapperTypesTestCase& test_case = GetParam(); + + std::optional resolution = + context_.ResolveOverload(ternary_decl_, + {BoolType(), NullType(), test_case.wrapper_type}, + false); + ASSERT_TRUE(resolution.has_value()); + + EXPECT_THAT(context_.FinalizeType(resolution->result_type), + IsTypeKind(test_case.wrapper_type.kind())) + << context_.DebugString(); + + EXPECT_THAT(resolution->overloads, ElementsAre(IsOverloadDecl("ternary"))); +} + +TEST_P(TypeInferenceContextWrapperTypesTest, PrimitiveWidens) { + const TypeInferenceContextWrapperTypesTestCase& test_case = GetParam(); + + std::optional resolution = + context_.ResolveOverload(ternary_decl_, + {BoolType(), test_case.wrapped_primitive_type, + test_case.wrapper_type}, + false); + ASSERT_TRUE(resolution.has_value()); + + EXPECT_THAT(context_.FinalizeType(resolution->result_type), + IsTypeKind(test_case.wrapper_type.kind())) + << context_.DebugString(); + + EXPECT_THAT(resolution->overloads, ElementsAre(IsOverloadDecl("ternary"))); +} + +INSTANTIATE_TEST_SUITE_P( + Types, TypeInferenceContextWrapperTypesTest, + ::testing::Values( + TypeInferenceContextWrapperTypesTestCase{IntWrapperType(), IntType()}, + TypeInferenceContextWrapperTypesTestCase{UintWrapperType(), UintType()}, + TypeInferenceContextWrapperTypesTestCase{DoubleWrapperType(), + DoubleType()}, + TypeInferenceContextWrapperTypesTestCase{StringWrapperType(), + StringType()}, + TypeInferenceContextWrapperTypesTestCase{BytesWrapperType(), + BytesType()}, + TypeInferenceContextWrapperTypesTestCase{BoolWrapperType(), BoolType()}, + TypeInferenceContextWrapperTypesTestCase{DynType(), IntType()})); + +TEST(TypeInferenceContextTest, ResolveOverloadWithUnionTypePromotion) { + google::protobuf::Arena arena; + TypeInferenceContext context(&arena); + + ASSERT_OK_AND_ASSIGN( + FunctionDecl decl, + MakeFunctionDecl( + "_?_:_", + MakeOverloadDecl("ternary", + /*result_type=*/TypeParamType("A"), BoolType(), + TypeParamType("A"), TypeParamType("A")))); + + std::optional resolution = + context.ResolveOverload(decl, {BoolType(), NullType(), IntWrapperType()}, + false); + ASSERT_TRUE(resolution.has_value()); + + EXPECT_THAT(context.FinalizeType(resolution->result_type), + IsTypeKind(TypeKind::kIntWrapper)) + << context.DebugString(); + + EXPECT_THAT(resolution->overloads, ElementsAre(IsOverloadDecl("ternary"))); +} + +// TypeType has special handling (differently-parameterized type-types are +// always assignable for the sake of comparisons). +TEST(TypeInferenceContextTest, ResolveOverloadWithTypeType) { + google::protobuf::Arena arena; + TypeInferenceContext context(&arena); + + ASSERT_OK_AND_ASSIGN( + FunctionDecl decl, + MakeFunctionDecl("type", + MakeOverloadDecl("to_type", + /*result_type=*/ + TypeType(&arena, TypeParamType("A")), + TypeParamType("A")))); + + std::optional resolution = + context.ResolveOverload(decl, {StringType()}, false); + ASSERT_TRUE(resolution.has_value()); + + auto result_type = context.FinalizeType(resolution->result_type); + ASSERT_THAT(result_type, IsTypeKind(TypeKind::kType)); + + EXPECT_THAT(result_type.AsType()->GetParameters(), + ElementsAre(IsTypeKind(TypeKind::kString))); + + EXPECT_THAT(resolution->overloads, ElementsAre(IsOverloadDecl("to_type"))); +} + +TEST(TypeInferenceContextTest, ResolveOverloadWithInferredTypeType) { + google::protobuf::Arena arena; + TypeInferenceContext context(&arena); + + ASSERT_OK_AND_ASSIGN( + FunctionDecl to_type_decl, + MakeFunctionDecl("type", + MakeOverloadDecl("to_type", + /*result_type=*/ + TypeType(&arena, TypeParamType("A")), + TypeParamType("A")))); + + ASSERT_OK_AND_ASSIGN( + FunctionDecl equals_decl, + MakeFunctionDecl("_==_", MakeOverloadDecl("equals", + /*result_type=*/ + BoolType(), TypeParamType("A"), + TypeParamType("A")))); + + std::optional resolution = + context.ResolveOverload(to_type_decl, {StringType()}, false); + ASSERT_TRUE(resolution.has_value()); + + auto lhs_result_type = resolution->result_type; + ASSERT_THAT(lhs_result_type, IsTypeKind(TypeKind::kType)); + + resolution = context.ResolveOverload(to_type_decl, {IntType()}, false); + ASSERT_TRUE(resolution.has_value()); + + auto rhs_result_type = resolution->result_type; + ASSERT_THAT(rhs_result_type, IsTypeKind(TypeKind::kType)); + + resolution = context.ResolveOverload( + equals_decl, {rhs_result_type, lhs_result_type}, false); + ASSERT_TRUE(resolution.has_value()); + auto result_type = context.FinalizeType(resolution->result_type); + ASSERT_THAT(result_type, IsTypeKind(TypeKind::kBool)); + + auto inferred_lhs = context.FinalizeType(lhs_result_type); + auto inferred_rhs = context.FinalizeType(rhs_result_type); + + ASSERT_THAT(inferred_rhs, IsTypeKind(TypeKind::kType)); + ASSERT_THAT(inferred_lhs, IsTypeKind(TypeKind::kType)); + + ASSERT_THAT(inferred_lhs.AsType()->GetParameters(), + ElementsAre(IsTypeKind(TypeKind::kString))); + ASSERT_THAT(inferred_rhs.AsType()->GetParameters(), + ElementsAre(IsTypeKind(TypeKind::kInt))); +} + +TEST(TypeInferenceContextTest, AssignabilityContext) { + google::protobuf::Arena arena; + TypeInferenceContext context(&arena); + + Type list_of_a = ListType(&arena, TypeParamType("A")); + + Type list_of_a_instance = context.InstantiateTypeParams(list_of_a); + + { + auto assignability_context = context.CreateAssignabilityContext(); + EXPECT_TRUE(assignability_context.IsAssignable( + IntType(), list_of_a_instance.AsList()->GetElement())); + EXPECT_TRUE(assignability_context.IsAssignable( + IntType(), list_of_a_instance.AsList()->GetElement())); + EXPECT_TRUE(assignability_context.IsAssignable( + IntWrapperType(), list_of_a_instance.AsList()->GetElement())); + + assignability_context.UpdateInferredTypeAssignments(); + } + Type resolved_type = context.FinalizeType(list_of_a_instance); + + ASSERT_THAT(resolved_type, IsTypeKind(TypeKind::kList)); + EXPECT_THAT(resolved_type.AsList()->GetElement(), + IsTypeKind(TypeKind::kIntWrapper)); +} + +TEST(TypeInferenceContextTest, AssignabilityContextAbstractType) { + google::protobuf::Arena arena; + TypeInferenceContext context(&arena); + + Type list_of_a = ListType(&arena, TypeParamType("A")); + + Type list_of_a_instance = context.InstantiateTypeParams(list_of_a); + + { + auto assignability_context = context.CreateAssignabilityContext(); + EXPECT_TRUE(assignability_context.IsAssignable( + OptionalType(&arena, IntType()), + list_of_a_instance.AsList()->GetElement())); + EXPECT_TRUE(assignability_context.IsAssignable( + OptionalType(&arena, DynType()), + list_of_a_instance.AsList()->GetElement())); + + assignability_context.UpdateInferredTypeAssignments(); + } + Type resolved_type = context.FinalizeType(list_of_a_instance); + + ASSERT_THAT(resolved_type, IsTypeKind(TypeKind::kList)); + ASSERT_THAT(resolved_type.AsList()->GetElement(), + IsTypeKind(TypeKind::kOpaque)); + EXPECT_THAT(resolved_type.AsList()->GetElement().AsOpaque()->name(), + "optional_type"); + EXPECT_THAT(resolved_type.AsList()->GetElement().AsOpaque()->GetParameters(), + ElementsAre(IsTypeKind(TypeKind::kDyn))); +} + +TEST(TypeInferenceContextTest, AssignabilityContextAbstractTypeWrapper) { + google::protobuf::Arena arena; + TypeInferenceContext context(&arena); + + Type list_of_a = ListType(&arena, TypeParamType("A")); + + Type list_of_a_instance = context.InstantiateTypeParams(list_of_a); + + { + auto assignability_context = context.CreateAssignabilityContext(); + EXPECT_TRUE(assignability_context.IsAssignable( + OptionalType(&arena, IntType()), + list_of_a_instance.AsList()->GetElement())); + EXPECT_TRUE(assignability_context.IsAssignable( + OptionalType(&arena, IntWrapperType()), + list_of_a_instance.AsList()->GetElement())); + + assignability_context.UpdateInferredTypeAssignments(); + } + Type resolved_type = context.FinalizeType(list_of_a_instance); + + ASSERT_THAT(resolved_type, IsTypeKind(TypeKind::kList)); + ASSERT_THAT(resolved_type.AsList()->GetElement(), + IsTypeKind(TypeKind::kOpaque)); + EXPECT_THAT(resolved_type.AsList()->GetElement().AsOpaque()->name(), + "optional_type"); + EXPECT_THAT(resolved_type.AsList()->GetElement().AsOpaque()->GetParameters(), + ElementsAre(IsTypeKind(TypeKind::kIntWrapper))); +} + +TEST(TypeInferenceContextTest, AssignabilityContextNotApplied) { + google::protobuf::Arena arena; + TypeInferenceContext context(&arena); + + Type list_of_a = ListType(&arena, TypeParamType("A")); + + Type list_of_a_instance = context.InstantiateTypeParams(list_of_a); + + { + auto assignability_context = context.CreateAssignabilityContext(); + EXPECT_TRUE(assignability_context.IsAssignable( + IntType(), list_of_a_instance.AsList()->GetElement())); + EXPECT_TRUE(assignability_context.IsAssignable( + IntType(), list_of_a_instance.AsList()->GetElement())); + EXPECT_TRUE(assignability_context.IsAssignable( + IntWrapperType(), list_of_a_instance.AsList()->GetElement())); + } + + Type resolved_type = context.FinalizeType(list_of_a_instance); + + ASSERT_THAT(resolved_type, IsTypeKind(TypeKind::kList)); + EXPECT_THAT(resolved_type.AsList()->GetElement(), IsTypeKind(TypeKind::kDyn)); +} + +TEST(TypeInferenceContextTest, AssignabilityContextReset) { + google::protobuf::Arena arena; + TypeInferenceContext context(&arena); + + Type list_of_a = ListType(&arena, TypeParamType("A")); + + Type list_of_a_instance = context.InstantiateTypeParams(list_of_a); + + { + auto assignability_context = context.CreateAssignabilityContext(); + EXPECT_TRUE(assignability_context.IsAssignable( + IntType(), list_of_a_instance.AsList()->GetElement())); + assignability_context.Reset(); + EXPECT_TRUE(assignability_context.IsAssignable( + DoubleType(), list_of_a_instance.AsList()->GetElement())); + assignability_context.UpdateInferredTypeAssignments(); + } + + Type resolved_type = context.FinalizeType(list_of_a_instance); + + ASSERT_THAT(resolved_type, IsTypeKind(TypeKind::kList)); + EXPECT_THAT(resolved_type.AsList()->GetElement(), + IsTypeKind(TypeKind::kDouble)); +} + +} // namespace +} // namespace cel::checker_internal diff --git a/checker/optional.cc b/checker/optional.cc new file mode 100644 index 000000000..d41e68aa1 --- /dev/null +++ b/checker/optional.cc @@ -0,0 +1,245 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/optional.h" + +#include + +#include "absl/base/no_destructor.h" +#include "absl/status/status.h" +#include "base/builtins.h" +#include "checker/internal/builtins_arena.h" +#include "checker/type_checker_builder.h" +#include "common/decl.h" +#include "common/type.h" +#include "internal/status_macros.h" + +namespace cel { +namespace { + +Type OptionalOfV() { + static const absl::NoDestructor kInstance( + checker_internal::BuiltinsArena(), TypeParamType("V")); + + return *kInstance; +} + +Type TypeOfOptionalOfV() { + static const absl::NoDestructor kInstance( + checker_internal::BuiltinsArena(), OptionalOfV()); + + return *kInstance; +} + +Type ListOfV() { + static const absl::NoDestructor kInstance( + checker_internal::BuiltinsArena(), TypeParamType("V")); + + return *kInstance; +} + +Type OptionalListOfV() { + static const absl::NoDestructor kInstance( + checker_internal::BuiltinsArena(), ListOfV()); + + return *kInstance; +} + +Type MapOfKV() { + static const absl::NoDestructor kInstance( + checker_internal::BuiltinsArena(), TypeParamType("K"), + TypeParamType("V")); + + return *kInstance; +} + +Type OptionalMapOfKV() { + static const absl::NoDestructor kInstance( + checker_internal::BuiltinsArena(), MapOfKV()); + + return *kInstance; +} + +class OptionalNames { + public: + static constexpr char kOptionalType[] = "optional_type"; + static constexpr char kOptionalOf[] = "optional.of"; + static constexpr char kOptionalOfNonZeroValue[] = "optional.ofNonZeroValue"; + static constexpr char kOptionalNone[] = "optional.none"; + static constexpr char kOptionalValue[] = "value"; + static constexpr char kOptionalHasValue[] = "hasValue"; + static constexpr char kOptionalOr[] = "or"; + static constexpr char kOptionalOrValue[] = "orValue"; + static constexpr char kOptionalSelect[] = "_?._"; + static constexpr char kOptionalIndex[] = "_[?_]"; + static constexpr char kOptionalFirst[] = "first"; + static constexpr char kOptionalLast[] = "last"; +}; + +class OptionalOverloads { + public: + // Creation + static constexpr char kOptionalOf[] = "optional_of"; + static constexpr char kOptionalOfNonZeroValue[] = "optional_ofNonZeroValue"; + static constexpr char kOptionalNone[] = "optional_none"; + // Basic accessors + static constexpr char kOptionalValue[] = "optional_value"; + static constexpr char kOptionalHasValue[] = "optional_hasValue"; + // Chaining `or` overloads. + static constexpr char kOptionalOr[] = "optional_or_optional"; + static constexpr char kOptionalOrValue[] = "optional_orValue_value"; + // Selection + static constexpr char kOptionalSelect[] = "select_optional_field"; + // Indexing + static constexpr char kListOptionalIndexInt[] = "list_optindex_optional_int"; + static constexpr char kOptionalListOptionalIndexInt[] = + "optional_list_optindex_optional_int"; + static constexpr char kMapOptionalIndexValue[] = + "map_optindex_optional_value"; + static constexpr char kOptionalMapOptionalIndexValue[] = + "optional_map_optindex_optional_value"; + static constexpr char kListFirst[] = "list_first"; + static constexpr char kListLast[] = "list_last"; + // Syntactic sugar for chained indexing. + static constexpr char kOptionalListIndexInt[] = "optional_list_index_int"; + static constexpr char kOptionalMapIndexValue[] = "optional_map_index_value"; +}; + +absl::Status RegisterOptionalDecls(TypeCheckerBuilder& builder, int version) { + CEL_ASSIGN_OR_RETURN( + auto of, + MakeFunctionDecl(OptionalNames::kOptionalOf, + MakeOverloadDecl(OptionalOverloads::kOptionalOf, + OptionalOfV(), TypeParamType("V")))); + + CEL_ASSIGN_OR_RETURN( + auto of_non_zero, + MakeFunctionDecl( + OptionalNames::kOptionalOfNonZeroValue, + MakeOverloadDecl(OptionalOverloads::kOptionalOfNonZeroValue, + OptionalOfV(), TypeParamType("V")))); + + CEL_ASSIGN_OR_RETURN( + auto none, + MakeFunctionDecl( + OptionalNames::kOptionalNone, + MakeOverloadDecl(OptionalOverloads::kOptionalNone, OptionalOfV()))); + + CEL_ASSIGN_OR_RETURN( + auto value, MakeFunctionDecl(OptionalNames::kOptionalValue, + MakeMemberOverloadDecl( + OptionalOverloads::kOptionalValue, + TypeParamType("V"), OptionalOfV()))); + + CEL_ASSIGN_OR_RETURN( + auto has_value, MakeFunctionDecl(OptionalNames::kOptionalHasValue, + MakeMemberOverloadDecl( + OptionalOverloads::kOptionalHasValue, + BoolType(), OptionalOfV()))); + + CEL_ASSIGN_OR_RETURN( + auto or_, + MakeFunctionDecl( + OptionalNames::kOptionalOr, + MakeMemberOverloadDecl(OptionalOverloads::kOptionalOr, OptionalOfV(), + OptionalOfV(), OptionalOfV()))); + + CEL_ASSIGN_OR_RETURN(auto or_value, + MakeFunctionDecl(OptionalNames::kOptionalOrValue, + MakeMemberOverloadDecl( + OptionalOverloads::kOptionalOrValue, + TypeParamType("V"), OptionalOfV(), + TypeParamType("V")))); + + // This is special cased by the type checker -- just adding a Decl to prevent + // accidental user overloading. + CEL_ASSIGN_OR_RETURN( + auto select, + MakeFunctionDecl( + OptionalNames::kOptionalSelect, + MakeOverloadDecl(OptionalOverloads::kOptionalSelect, OptionalOfV(), + DynType(), StringType()))); + + CEL_ASSIGN_OR_RETURN( + auto opt_index, + MakeFunctionDecl( + OptionalNames::kOptionalIndex, + MakeOverloadDecl(OptionalOverloads::kOptionalListOptionalIndexInt, + OptionalOfV(), OptionalListOfV(), IntType()), + MakeOverloadDecl(OptionalOverloads::kListOptionalIndexInt, + OptionalOfV(), ListOfV(), IntType()), + MakeOverloadDecl(OptionalOverloads::kMapOptionalIndexValue, + OptionalOfV(), MapOfKV(), TypeParamType("K")), + MakeOverloadDecl(OptionalOverloads::kOptionalMapOptionalIndexValue, + OptionalOfV(), OptionalMapOfKV(), + TypeParamType("K")))); + + CEL_ASSIGN_OR_RETURN( + auto first, + MakeFunctionDecl(OptionalNames::kOptionalFirst, + MakeMemberOverloadDecl(OptionalOverloads::kListFirst, + OptionalOfV(), ListOfV()))); + + CEL_ASSIGN_OR_RETURN( + auto last, + MakeFunctionDecl(OptionalNames::kOptionalLast, + MakeMemberOverloadDecl(OptionalOverloads::kListLast, + OptionalOfV(), ListOfV()))); + + CEL_ASSIGN_OR_RETURN( + auto index, + MakeFunctionDecl( + cel::builtin::kIndex, + MakeOverloadDecl(OptionalOverloads::kOptionalListIndexInt, + OptionalOfV(), OptionalListOfV(), IntType()), + MakeOverloadDecl(OptionalOverloads::kOptionalMapIndexValue, + OptionalOfV(), OptionalMapOfKV(), + TypeParamType("K")))); + + CEL_RETURN_IF_ERROR(builder.AddVariable( + MakeVariableDecl(OptionalNames::kOptionalType, TypeOfOptionalOfV()))); + + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(of))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(of_non_zero))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(none))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(value))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(has_value))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(or_))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(or_value))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(opt_index))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(select))); + CEL_RETURN_IF_ERROR(builder.MergeFunction(std::move(index))); + + if (version == 0 || version == 1) { + return absl::OkStatus(); + } + + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(first))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(last))); + + return absl::OkStatus(); +} + +} // namespace + +CheckerLibrary OptionalCheckerLibrary(int version) { + return CheckerLibrary({ + "optional", + [version](TypeCheckerBuilder& builder) { + return RegisterOptionalDecls(builder, version); + }, + }); +} + +} // namespace cel diff --git a/checker/optional.h b/checker/optional.h new file mode 100644 index 000000000..c96737c31 --- /dev/null +++ b/checker/optional.h @@ -0,0 +1,30 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_CHECKER_OPTIONAL_H_ +#define THIRD_PARTY_CEL_CPP_CHECKER_OPTIONAL_H_ + +#include "checker/type_checker_builder.h" + +namespace cel { + +constexpr int kOptionalExtensionLatestVersion = 2; + +// Library for CEL optional definitions. +CheckerLibrary OptionalCheckerLibrary( + int version = kOptionalExtensionLatestVersion); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_CHECKER_OPTIONAL_H_ diff --git a/checker/optional_test.cc b/checker/optional_test.cc new file mode 100644 index 000000000..87c14f0cd --- /dev/null +++ b/checker/optional_test.cc @@ -0,0 +1,339 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/optional.h" + +#include +#include +#include +#include + +#include "absl/status/status_matchers.h" +#include "absl/strings/str_join.h" +#include "checker/checker_options.h" +#include "checker/internal/test_ast_helpers.h" +#include "checker/standard_library.h" +#include "checker/type_check_issue.h" +#include "checker/type_checker.h" +#include "checker/type_checker_builder.h" +#include "checker/type_checker_builder_factory.h" +#include "common/ast.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::cel::checker_internal::MakeTestParsedAst; +using ::cel::internal::GetSharedTestingDescriptorPool; +using ::testing::_; +using ::testing::Contains; +using ::testing::Eq; +using ::testing::HasSubstr; +using ::testing::IsEmpty; +using ::testing::Key; +using ::testing::Not; +using ::testing::Property; +using ::testing::SizeIs; + +MATCHER_P(IsOptionalType, inner_type, "") { + const TypeSpec& type = arg; + if (!type.has_abstract_type()) { + return false; + } + const auto& abs_type = type.abstract_type(); + if (abs_type.name() != "optional_type") { + *result_listener << "expected optional_type, got: " << abs_type.name(); + return false; + } + if (abs_type.parameter_types().size() != 1) { + *result_listener << "unexpected number of parameters: " + << abs_type.parameter_types().size(); + return false; + } + + if (inner_type == abs_type.parameter_types()[0]) { + return true; + } + + *result_listener << "unexpected inner type: " + << abs_type.parameter_types()[0].type_kind().index(); + return false; +} + +TEST(OptionalTest, OptSelectDoesNotAnnotateFieldType) { + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); + ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); + ASSERT_THAT(builder->AddLibrary(OptionalCheckerLibrary()), IsOk()); + builder->set_container("cel.expr.conformance.proto3"); + ASSERT_OK_AND_ASSIGN(std::unique_ptr checker, + std::move(*builder).Build()); + + ASSERT_OK_AND_ASSIGN(auto ast, + MakeTestParsedAst("TestAllTypes{}.?single_int64")); + + ASSERT_OK_AND_ASSIGN(auto result, checker->Check(std::move(ast))); + + EXPECT_THAT(result.GetIssues(), IsEmpty()); + ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); + + ASSERT_THAT(checked_ast->root_expr().call_expr().args(), SizeIs(2)); + int64_t field_id = checked_ast->root_expr().call_expr().args()[1].id(); + EXPECT_NE(field_id, 0); + + EXPECT_THAT(checked_ast->type_map(), Not(Contains(Key(field_id)))); + EXPECT_THAT(checked_ast->GetTypeOrDyn(checked_ast->root_expr().id()), + IsOptionalType(TypeSpec(PrimitiveType::kInt64))); +} + +struct TestCase { + std::string expr; + testing::Matcher result_type_matcher; + std::string error_substring; +}; + +class OptionalTest : public testing::TestWithParam {}; + +TEST_P(OptionalTest, Runner) { + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); + const TestCase& test_case = GetParam(); + ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); + ASSERT_THAT(builder->AddLibrary(OptionalCheckerLibrary()), IsOk()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr checker, + std::move(*builder).Build()); + + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst(test_case.expr)); + + ASSERT_OK_AND_ASSIGN(auto result, checker->Check(std::move(ast))); + + if (!test_case.error_substring.empty()) { + EXPECT_THAT(result.GetIssues(), + Contains(Property(&TypeCheckIssue::message, + HasSubstr(test_case.error_substring)))) + << absl::StrJoin(result.GetIssues(), "\n", + [](std::string* out, const auto& i) { + absl::StrAppend(out, i.message()); + }); + return; + } + + EXPECT_THAT(result.GetIssues(), IsEmpty()) + << "for expression: " << test_case.expr; + ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); + + int64_t root_id = checked_ast->root_expr().id(); + + EXPECT_THAT(checked_ast->GetTypeOrDyn(root_id), test_case.result_type_matcher) + << "for expression: " << test_case.expr; +} + +INSTANTIATE_TEST_SUITE_P( + OptionalTests, OptionalTest, + ::testing::Values( + TestCase{ + "optional.of('abc')", + IsOptionalType(TypeSpec(PrimitiveType::kString)), + }, + TestCase{ + "optional.ofNonZeroValue('')", + IsOptionalType(TypeSpec(PrimitiveType::kString)), + }, + TestCase{ + "optional.none()", + IsOptionalType(TypeSpec(DynTypeSpec())), + }, + // Odd case -- the correct result might be a bespoke recursively-defined + // type but CEL doesn't support that. Null is used because it is + // implicitly assignable to optional types. This allows for a recursive + // type to be non-trivial and verify the checker is actually avoiding + // introducing a cyclic type. + TestCase{ + "[optional.none()].map(x, [?x, null, x])", + Eq(TypeSpec(ListTypeSpec(std::make_unique( + ListTypeSpec(std::make_unique(NullTypeSpec())))))), + }, + TestCase{ + "optional.of('abc').hasValue()", + Eq(TypeSpec(PrimitiveType::kBool)), + }, + TestCase{ + "optional.of('abc').value()", + Eq(TypeSpec(PrimitiveType::kString)), + }, + TestCase{ + "type(optional.of('abc')) == optional_type", + Eq(TypeSpec(PrimitiveType::kBool)), + }, + TestCase{ + "type(optional.of('abc')) == optional_type", + Eq(TypeSpec(PrimitiveType::kBool)), + }, + TestCase{ + "optional.of('abc').or(optional.of('def'))", + IsOptionalType(TypeSpec(PrimitiveType::kString)), + }, + TestCase{"optional.of('abc').or(optional.of(1))", _, + "no matching overload for 'or'"}, + TestCase{ + "optional.of('abc').orValue('def')", + Eq(TypeSpec(PrimitiveType::kString)), + }, + TestCase{"optional.of('abc').orValue(1)", _, + "no matching overload for 'orValue'"}, + TestCase{ + "{'k': 'v'}.?k", + IsOptionalType(TypeSpec(PrimitiveType::kString)), + }, + TestCase{"1.?k", _, + "expression of type 'int' cannot be the operand of a select " + "operation"}, + TestCase{ + "{'k': {'k': 'v'}}.?k.?k2", + IsOptionalType(TypeSpec(PrimitiveType::kString)), + }, + TestCase{ + "{'k': {'k': 'v'}}.?k.k2", + IsOptionalType(TypeSpec(PrimitiveType::kString)), + }, + TestCase{"{?'k': optional.of('v')}", + Eq(TypeSpec(MapTypeSpec(std::unique_ptr(new TypeSpec( + PrimitiveType::kString)), + std::unique_ptr(new TypeSpec( + PrimitiveType::kString)))))}, + TestCase{"{'k': 'v', ?'k2': optional.none()}", + Eq(TypeSpec(MapTypeSpec(std::unique_ptr(new TypeSpec( + PrimitiveType::kString)), + std::unique_ptr(new TypeSpec( + PrimitiveType::kString)))))}, + TestCase{"{'k': 'v', ?'k2': 'v'}", _, + "expected type 'optional_type(string)' but found 'string'"}, + TestCase{"[?optional.of('v')]", + Eq(TypeSpec(ListTypeSpec(std::unique_ptr( + new TypeSpec(PrimitiveType::kString)))))}, + TestCase{"['v', ?optional.none()]", + Eq(TypeSpec(ListTypeSpec(std::unique_ptr( + new TypeSpec(PrimitiveType::kString)))))}, + TestCase{"['v1', ?'v2']", _, + "expected type 'optional_type(string)' but found 'string'"}, + TestCase{"[optional.of(dyn('1')), optional.of('2')][0]", + IsOptionalType(TypeSpec(DynTypeSpec()))}, + TestCase{"[optional.of('1'), optional.of(dyn('2'))][0]", + IsOptionalType(TypeSpec(DynTypeSpec()))}, + TestCase{"[{1: optional.of(1)}, {1: optional.of(dyn(1))}][0][1]", + IsOptionalType(TypeSpec(DynTypeSpec()))}, + TestCase{"[{1: optional.of(dyn(1))}, {1: optional.of(1)}][0][1]", + IsOptionalType(TypeSpec(DynTypeSpec()))}, + TestCase{"[optional.of('1'), optional.of(2)][0]", + Eq(TypeSpec(DynTypeSpec()))}, + TestCase{"['v1', ?'v2']", _, + "expected type 'optional_type(string)' but found 'string'"}, + TestCase{"cel.expr.conformance.proto3.TestAllTypes{?single_int64: " + "optional.of(1)}", + Eq(TypeSpec(MessageTypeSpec( + "cel.expr.conformance.proto3.TestAllTypes")))}, + TestCase{"[0][?1]", IsOptionalType(TypeSpec(PrimitiveType::kInt64))}, + TestCase{"[[0]][?1][?1]", + IsOptionalType(TypeSpec(PrimitiveType::kInt64))}, + TestCase{"[[0]][?1][1]", + IsOptionalType(TypeSpec(PrimitiveType::kInt64))}, + TestCase{"{0: 1}[?1]", IsOptionalType(TypeSpec(PrimitiveType::kInt64))}, + TestCase{"{0: {0: 1}}[?1][?1]", + IsOptionalType(TypeSpec(PrimitiveType::kInt64))}, + TestCase{"{0: {0: 1}}[?1][1]", + IsOptionalType(TypeSpec(PrimitiveType::kInt64))}, + TestCase{"{0: {0: 1}}[?1]['']", _, "no matching overload for '_[_]'"}, + TestCase{"{0: {0: 1}}[?1][?'']", _, "no matching overload for '_[?_]'"}, + TestCase{"[1, 2, 3].first()", + IsOptionalType(TypeSpec(PrimitiveType::kInt64))}, + TestCase{"[1, 2, 3].last()", + IsOptionalType(TypeSpec(PrimitiveType::kInt64))}, + TestCase{"optional.of('abc').optMap(x, x + 'def')", + IsOptionalType(TypeSpec(PrimitiveType::kString))}, + TestCase{"optional.of('abc').optFlatMap(x, optional.of(x + 'def'))", + IsOptionalType(TypeSpec(PrimitiveType::kString))}, + TestCase{"cel.expr.conformance.proto3.TestAllTypes{?null_value: " + "optional.of(0)}", + Eq(TypeSpec(MessageTypeSpec( + "cel.expr.conformance.proto3.TestAllTypes")))}, + // Legacy nullability behaviors. + TestCase{ + "cel.expr.conformance.proto3.TestAllTypes{?single_value: null}", + Eq(TypeSpec( + MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")))}, + TestCase{"cel.expr.conformance.proto3.TestAllTypes{?single_value: " + "optional.of(null)}", + Eq(TypeSpec(MessageTypeSpec( + "cel.expr.conformance.proto3.TestAllTypes")))}, + TestCase{"cel.expr.conformance.proto3.TestAllTypes{}.?single_int64 " + "== null", + Eq(TypeSpec(PrimitiveType::kBool))})); + +class OptionalStrictNullAssignmentTest + : public testing::TestWithParam {}; + +TEST_P(OptionalStrictNullAssignmentTest, Runner) { + CheckerOptions options; + options.enable_legacy_null_assignment = false; + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool(), options)); + const TestCase& test_case = GetParam(); + ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); + ASSERT_THAT(builder->AddLibrary(OptionalCheckerLibrary()), IsOk()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr checker, + std::move(*builder).Build()); + + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst(test_case.expr)); + + ASSERT_OK_AND_ASSIGN(auto result, checker->Check(std::move(ast))); + + if (!test_case.error_substring.empty()) { + EXPECT_THAT(result.GetIssues(), + Contains(Property(&TypeCheckIssue::message, + HasSubstr(test_case.error_substring)))) + << absl::StrJoin(result.GetIssues(), "\n", + [](std::string* out, const auto& i) { + absl::StrAppend(out, i.message()); + }); + return; + } + + EXPECT_THAT(result.GetIssues(), IsEmpty()) + << "for expression: " << test_case.expr; + ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); + + int64_t root_id = checked_ast->root_expr().id(); + + EXPECT_THAT(checked_ast->GetTypeOrDyn(root_id), test_case.result_type_matcher) + << "for expression: " << test_case.expr; +} + +INSTANTIATE_TEST_SUITE_P( + OptionalTests, OptionalStrictNullAssignmentTest, + ::testing::Values( + TestCase{ + "cel.expr.conformance.proto3.TestAllTypes{?single_int64: null}", _, + "expected type of field 'single_int64' is 'optional_type(int)' but " + "provided type is 'null_type'"}, + TestCase{"cel.expr.conformance.proto3.TestAllTypes{}.?single_int64 " + "== null", + _, "no matching overload for '_==_'"})); + +} // namespace +} // namespace cel diff --git a/checker/standard_library.cc b/checker/standard_library.cc new file mode 100644 index 000000000..744a171ef --- /dev/null +++ b/checker/standard_library.cc @@ -0,0 +1,864 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/standard_library.h" + +#include +#include + +#include "absl/base/no_destructor.h" +#include "absl/status/status.h" +#include "checker/internal/builtins_arena.h" +#include "checker/type_checker_builder.h" +#include "common/constant.h" +#include "common/decl.h" +#include "common/standard_definitions.h" +#include "common/type.h" +#include "internal/status_macros.h" + +namespace cel { +namespace { + +using ::cel::checker_internal::BuiltinsArena; + +// Arbitrary type parameter name A. +TypeParamType TypeParamA() { return TypeParamType("A"); } + +// Arbitrary type parameter name B. +TypeParamType TypeParamB() { return TypeParamType("B"); } + +Type ListOfA() { + static absl::NoDestructor kInstance( + ListType(BuiltinsArena(), TypeParamA())); + return *kInstance; +} + +Type MapOfAB() { + static absl::NoDestructor kInstance( + MapType(BuiltinsArena(), TypeParamA(), TypeParamB())); + return *kInstance; +} + +Type TypeOfType() { + static absl::NoDestructor kInstance( + TypeType(BuiltinsArena(), TypeType())); + return *kInstance; +} + +Type TypeOfA() { + static absl::NoDestructor kInstance( + TypeType(BuiltinsArena(), TypeParamA())); + return *kInstance; +} + +Type TypeNullType() { + static absl::NoDestructor kInstance( + TypeType(BuiltinsArena(), NullType())); + return *kInstance; +} + +Type TypeBoolType() { + static absl::NoDestructor kInstance( + TypeType(BuiltinsArena(), BoolType())); + return *kInstance; +} + +Type TypeIntType() { + static absl::NoDestructor kInstance( + TypeType(BuiltinsArena(), IntType())); + return *kInstance; +} + +Type TypeUintType() { + static absl::NoDestructor kInstance( + TypeType(BuiltinsArena(), UintType())); + return *kInstance; +} + +Type TypeDoubleType() { + static absl::NoDestructor kInstance( + TypeType(BuiltinsArena(), DoubleType())); + return *kInstance; +} + +Type TypeStringType() { + static absl::NoDestructor kInstance( + TypeType(BuiltinsArena(), StringType())); + return *kInstance; +} + +Type TypeBytesType() { + static absl::NoDestructor kInstance( + TypeType(BuiltinsArena(), BytesType())); + return *kInstance; +} + +Type TypeDynType() { + static absl::NoDestructor kInstance( + TypeType(BuiltinsArena(), DynType())); + return *kInstance; +} + +Type TypeListType() { + static absl::NoDestructor kInstance( + TypeType(BuiltinsArena(), ListOfA())); + return *kInstance; +} + +Type TypeMapType() { + static absl::NoDestructor kInstance( + TypeType(BuiltinsArena(), MapOfAB())); + return *kInstance; +} + +absl::Status AddArithmeticOps(TypeCheckerBuilder& builder) { + FunctionDecl add_op; + add_op.set_name(StandardFunctions::kAdd); + CEL_RETURN_IF_ERROR(add_op.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kAddInt, IntType(), IntType(), IntType()))); + CEL_RETURN_IF_ERROR(add_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kAddDouble, DoubleType(), + DoubleType(), DoubleType()))); + CEL_RETURN_IF_ERROR(add_op.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kAddUint, UintType(), UintType(), UintType()))); + // timestamp math + CEL_RETURN_IF_ERROR(add_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kAddDurationDuration, + DurationType(), DurationType(), DurationType()))); + CEL_RETURN_IF_ERROR(add_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kAddDurationTimestamp, + TimestampType(), DurationType(), TimestampType()))); + CEL_RETURN_IF_ERROR(add_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kAddTimestampDuration, + TimestampType(), TimestampType(), DurationType()))); + // string concat + CEL_RETURN_IF_ERROR(add_op.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kAddBytes, BytesType(), BytesType(), BytesType()))); + CEL_RETURN_IF_ERROR(add_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kAddString, StringType(), + StringType(), StringType()))); + // list concat + CEL_RETURN_IF_ERROR(add_op.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kAddList, ListOfA(), ListOfA(), ListOfA()))); + + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(add_op))); + + FunctionDecl subtract_op; + subtract_op.set_name(StandardFunctions::kSubtract); + CEL_RETURN_IF_ERROR(subtract_op.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kSubtractInt, IntType(), IntType(), IntType()))); + CEL_RETURN_IF_ERROR(subtract_op.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kSubtractUint, UintType(), UintType(), UintType()))); + CEL_RETURN_IF_ERROR(subtract_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kSubtractDouble, DoubleType(), + DoubleType(), DoubleType()))); + // Timestamp math + CEL_RETURN_IF_ERROR(subtract_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kSubtractDurationDuration, + DurationType(), DurationType(), DurationType()))); + CEL_RETURN_IF_ERROR(subtract_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kSubtractTimestampDuration, + TimestampType(), TimestampType(), DurationType()))); + CEL_RETURN_IF_ERROR(subtract_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kSubtractTimestampTimestamp, + DurationType(), TimestampType(), TimestampType()))); + + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(subtract_op))); + + FunctionDecl multiply_op; + multiply_op.set_name(StandardFunctions::kMultiply); + CEL_RETURN_IF_ERROR(multiply_op.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kMultiplyInt, IntType(), IntType(), IntType()))); + CEL_RETURN_IF_ERROR(multiply_op.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kMultiplyUint, UintType(), UintType(), UintType()))); + CEL_RETURN_IF_ERROR(multiply_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kMultiplyDouble, DoubleType(), + DoubleType(), DoubleType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(multiply_op))); + + FunctionDecl division_op; + division_op.set_name(StandardFunctions::kDivide); + CEL_RETURN_IF_ERROR(division_op.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kDivideInt, IntType(), IntType(), IntType()))); + CEL_RETURN_IF_ERROR(division_op.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kDivideUint, UintType(), UintType(), UintType()))); + CEL_RETURN_IF_ERROR(division_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kDivideDouble, DoubleType(), + DoubleType(), DoubleType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(division_op))); + + FunctionDecl modulo_op; + modulo_op.set_name(StandardFunctions::kModulo); + CEL_RETURN_IF_ERROR(modulo_op.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kModuloInt, IntType(), IntType(), IntType()))); + CEL_RETURN_IF_ERROR(modulo_op.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kModuloUint, UintType(), UintType(), UintType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(modulo_op))); + + FunctionDecl negate_op; + negate_op.set_name(StandardFunctions::kNeg); + CEL_RETURN_IF_ERROR(negate_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kNegateInt, IntType(), IntType()))); + CEL_RETURN_IF_ERROR(negate_op.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kNegateDouble, DoubleType(), DoubleType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(negate_op))); + + return absl::OkStatus(); +} + +absl::Status AddLogicalOps(TypeCheckerBuilder& builder) { + FunctionDecl not_op; + not_op.set_name(StandardFunctions::kNot); + CEL_RETURN_IF_ERROR(not_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kNot, BoolType(), BoolType()))); + + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(not_op))); + + FunctionDecl and_op; + and_op.set_name(StandardFunctions::kAnd); + CEL_RETURN_IF_ERROR(and_op.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kAnd, BoolType(), BoolType(), BoolType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(and_op))); + + FunctionDecl or_op; + or_op.set_name(StandardFunctions::kOr); + CEL_RETURN_IF_ERROR(or_op.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kOr, BoolType(), BoolType(), BoolType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(or_op))); + + FunctionDecl conditional_op; + conditional_op.set_name(StandardFunctions::kTernary); + CEL_RETURN_IF_ERROR(conditional_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kConditional, TypeParamA(), + BoolType(), TypeParamA(), TypeParamA()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(conditional_op))); + + FunctionDecl not_strictly_false; + not_strictly_false.set_name(StandardFunctions::kNotStrictlyFalse); + CEL_RETURN_IF_ERROR(not_strictly_false.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kNotStrictlyFalse, BoolType(), BoolType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(not_strictly_false))); + + FunctionDecl not_strictly_false_deprecated; + not_strictly_false_deprecated.set_name( + StandardFunctions::kNotStrictlyFalseDeprecated); + CEL_RETURN_IF_ERROR(not_strictly_false_deprecated.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kNotStrictlyFalseDeprecated, + BoolType(), BoolType()))); + CEL_RETURN_IF_ERROR( + builder.AddFunction(std::move(not_strictly_false_deprecated))); + + return absl::OkStatus(); +} + +absl::Status AddTypeConversions(TypeCheckerBuilder& builder) { + FunctionDecl to_dyn; + to_dyn.set_name(StandardFunctions::kDyn); + CEL_RETURN_IF_ERROR(to_dyn.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kToDyn, DynType(), TypeParamA()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(to_dyn))); + + // Uint + FunctionDecl to_uint; + to_uint.set_name(StandardFunctions::kUint); + CEL_RETURN_IF_ERROR(to_uint.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kUintToUint, UintType(), UintType()))); + CEL_RETURN_IF_ERROR(to_uint.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kIntToUint, UintType(), IntType()))); + CEL_RETURN_IF_ERROR(to_uint.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kDoubleToUint, UintType(), DoubleType()))); + CEL_RETURN_IF_ERROR(to_uint.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kStringToUint, UintType(), StringType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(to_uint))); + + // Int + FunctionDecl to_int; + to_int.set_name(StandardFunctions::kInt); + CEL_RETURN_IF_ERROR(to_int.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kIntToInt, IntType(), IntType()))); + CEL_RETURN_IF_ERROR(to_int.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kUintToInt, IntType(), UintType()))); + CEL_RETURN_IF_ERROR(to_int.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kDoubleToInt, IntType(), DoubleType()))); + CEL_RETURN_IF_ERROR(to_int.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kStringToInt, IntType(), StringType()))); + CEL_RETURN_IF_ERROR(to_int.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kTimestampToInt, IntType(), TimestampType()))); + CEL_RETURN_IF_ERROR(to_int.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kDurationToInt, IntType(), DurationType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(to_int))); + + FunctionDecl to_double; + to_double.set_name(StandardFunctions::kDouble); + CEL_RETURN_IF_ERROR(to_double.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kDoubleToDouble, DoubleType(), DoubleType()))); + CEL_RETURN_IF_ERROR(to_double.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kIntToDouble, DoubleType(), IntType()))); + CEL_RETURN_IF_ERROR(to_double.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kUintToDouble, DoubleType(), UintType()))); + CEL_RETURN_IF_ERROR(to_double.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kStringToDouble, DoubleType(), StringType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(to_double))); + + FunctionDecl to_bool; + to_bool.set_name("bool"); + CEL_RETURN_IF_ERROR(to_bool.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kBoolToBool, BoolType(), BoolType()))); + CEL_RETURN_IF_ERROR(to_bool.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kStringToBool, BoolType(), StringType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(to_bool))); + + FunctionDecl to_string; + to_string.set_name(StandardFunctions::kString); + CEL_RETURN_IF_ERROR(to_string.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kStringToString, StringType(), StringType()))); + CEL_RETURN_IF_ERROR(to_string.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kBytesToString, StringType(), BytesType()))); + CEL_RETURN_IF_ERROR(to_string.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kBoolToString, StringType(), BoolType()))); + CEL_RETURN_IF_ERROR(to_string.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kDoubleToString, StringType(), DoubleType()))); + CEL_RETURN_IF_ERROR(to_string.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kIntToString, StringType(), IntType()))); + CEL_RETURN_IF_ERROR(to_string.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kUintToString, StringType(), UintType()))); + CEL_RETURN_IF_ERROR(to_string.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kTimestampToString, StringType(), TimestampType()))); + CEL_RETURN_IF_ERROR(to_string.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kDurationToString, StringType(), DurationType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(to_string))); + + FunctionDecl to_bytes; + to_bytes.set_name(StandardFunctions::kBytes); + CEL_RETURN_IF_ERROR(to_bytes.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kBytesToBytes, BytesType(), BytesType()))); + CEL_RETURN_IF_ERROR(to_bytes.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kStringToBytes, BytesType(), StringType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(to_bytes))); + + FunctionDecl to_timestamp; + to_timestamp.set_name(StandardFunctions::kTimestamp); + CEL_RETURN_IF_ERROR(to_timestamp.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kTimestampToTimestamp, + TimestampType(), TimestampType()))); + CEL_RETURN_IF_ERROR(to_timestamp.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kStringToTimestamp, TimestampType(), StringType()))); + CEL_RETURN_IF_ERROR(to_timestamp.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kIntToTimestamp, TimestampType(), IntType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(to_timestamp))); + + FunctionDecl to_duration; + to_duration.set_name(StandardFunctions::kDuration); + CEL_RETURN_IF_ERROR(to_duration.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kDurationToDuration, DurationType(), + DurationType()))); + CEL_RETURN_IF_ERROR(to_duration.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kStringToDuration, DurationType(), StringType()))); + CEL_RETURN_IF_ERROR(to_duration.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kIntToDuration, DurationType(), IntType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(to_duration))); + + FunctionDecl to_type; + to_type.set_name(StandardFunctions::kType); + CEL_RETURN_IF_ERROR(to_type.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kToType, Type(TypeOfA()), TypeParamA()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(to_type))); + + return absl::OkStatus(); +} + +absl::Status AddEqualityOps(TypeCheckerBuilder& builder) { + FunctionDecl equals_op; + equals_op.set_name(StandardFunctions::kEqual); + CEL_RETURN_IF_ERROR(equals_op.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kEquals, BoolType(), TypeParamA(), TypeParamA()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(equals_op))); + + FunctionDecl not_equals_op; + not_equals_op.set_name(StandardFunctions::kInequal); + CEL_RETURN_IF_ERROR(not_equals_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kNotEquals, BoolType(), + TypeParamA(), TypeParamA()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(not_equals_op))); + + return absl::OkStatus(); +} + +absl::Status AddContainerOps(TypeCheckerBuilder& builder) { + FunctionDecl index; + index.set_name(StandardFunctions::kIndex); + CEL_RETURN_IF_ERROR(index.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kIndexList, TypeParamA(), ListOfA(), IntType()))); + CEL_RETURN_IF_ERROR(index.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kIndexMap, TypeParamB(), MapOfAB(), TypeParamA()))); + CEL_RETURN_IF_ERROR(builder.MergeFunction(std::move(index))); + + FunctionDecl in_op; + in_op.set_name(StandardFunctions::kIn); + CEL_RETURN_IF_ERROR(in_op.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kInList, BoolType(), TypeParamA(), ListOfA()))); + CEL_RETURN_IF_ERROR(in_op.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kInMap, BoolType(), TypeParamA(), MapOfAB()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(in_op))); + + FunctionDecl in_function_deprecated; + in_function_deprecated.set_name(StandardFunctions::kInFunction); + CEL_RETURN_IF_ERROR(in_function_deprecated.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kInList, BoolType(), TypeParamA(), ListOfA()))); + CEL_RETURN_IF_ERROR(in_function_deprecated.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kInMap, BoolType(), TypeParamA(), MapOfAB()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(in_function_deprecated))); + + FunctionDecl in_op_deprecated; + in_op_deprecated.set_name(StandardFunctions::kInDeprecated); + CEL_RETURN_IF_ERROR(in_op_deprecated.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kInList, BoolType(), TypeParamA(), ListOfA()))); + CEL_RETURN_IF_ERROR(in_op_deprecated.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kInMap, BoolType(), TypeParamA(), MapOfAB()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(in_op_deprecated))); + + FunctionDecl size; + size.set_name(StandardFunctions::kSize); + CEL_RETURN_IF_ERROR(size.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kSizeList, IntType(), ListOfA()))); + CEL_RETURN_IF_ERROR(size.AddOverload(MakeMemberOverloadDecl( + StandardOverloadIds::kSizeListMember, IntType(), ListOfA()))); + CEL_RETURN_IF_ERROR(size.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kSizeMap, IntType(), MapOfAB()))); + CEL_RETURN_IF_ERROR(size.AddOverload(MakeMemberOverloadDecl( + StandardOverloadIds::kSizeMapMember, IntType(), MapOfAB()))); + CEL_RETURN_IF_ERROR(size.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kSizeBytes, IntType(), BytesType()))); + CEL_RETURN_IF_ERROR(size.AddOverload(MakeMemberOverloadDecl( + StandardOverloadIds::kSizeBytesMember, IntType(), BytesType()))); + CEL_RETURN_IF_ERROR(size.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kSizeString, IntType(), StringType()))); + CEL_RETURN_IF_ERROR(size.AddOverload(MakeMemberOverloadDecl( + StandardOverloadIds::kSizeStringMember, IntType(), StringType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(size))); + + return absl::OkStatus(); +} + +absl::Status AddRelationOps(TypeCheckerBuilder& builder) { + FunctionDecl less_op; + less_op.set_name(StandardFunctions::kLess); + // Numeric types + CEL_RETURN_IF_ERROR(less_op.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kLessInt, BoolType(), IntType(), IntType()))); + CEL_RETURN_IF_ERROR(less_op.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kLessUint, BoolType(), UintType(), UintType()))); + CEL_RETURN_IF_ERROR(less_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kLessDouble, BoolType(), + DoubleType(), DoubleType()))); + + // Non-numeric types + CEL_RETURN_IF_ERROR(less_op.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kLessBool, BoolType(), BoolType(), BoolType()))); + CEL_RETURN_IF_ERROR(less_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kLessString, BoolType(), + StringType(), StringType()))); + CEL_RETURN_IF_ERROR(less_op.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kLessBytes, BoolType(), BytesType(), BytesType()))); + CEL_RETURN_IF_ERROR(less_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kLessDuration, BoolType(), + DurationType(), DurationType()))); + CEL_RETURN_IF_ERROR(less_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kLessTimestamp, BoolType(), + TimestampType(), TimestampType()))); + + FunctionDecl greater_op; + greater_op.set_name(StandardFunctions::kGreater); + // Numeric types + CEL_RETURN_IF_ERROR(greater_op.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kGreaterInt, BoolType(), IntType(), IntType()))); + CEL_RETURN_IF_ERROR(greater_op.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kGreaterUint, BoolType(), UintType(), UintType()))); + CEL_RETURN_IF_ERROR(greater_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kGreaterDouble, BoolType(), + DoubleType(), DoubleType()))); + + // Non-numeric types + CEL_RETURN_IF_ERROR(greater_op.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kGreaterBool, BoolType(), BoolType(), BoolType()))); + CEL_RETURN_IF_ERROR(greater_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kGreaterString, BoolType(), + StringType(), StringType()))); + CEL_RETURN_IF_ERROR(greater_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kGreaterBytes, BoolType(), + BytesType(), BytesType()))); + CEL_RETURN_IF_ERROR(greater_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kGreaterDuration, BoolType(), + DurationType(), DurationType()))); + CEL_RETURN_IF_ERROR(greater_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kGreaterTimestamp, BoolType(), + TimestampType(), TimestampType()))); + + FunctionDecl less_equals_op; + less_equals_op.set_name(StandardFunctions::kLessOrEqual); + // Numeric types + CEL_RETURN_IF_ERROR(less_equals_op.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kLessEqualsInt, BoolType(), IntType(), IntType()))); + CEL_RETURN_IF_ERROR(less_equals_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kLessEqualsUint, BoolType(), + UintType(), UintType()))); + CEL_RETURN_IF_ERROR(less_equals_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kLessEqualsDouble, BoolType(), + DoubleType(), DoubleType()))); + + // Non-numeric types + CEL_RETURN_IF_ERROR(less_equals_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kLessEqualsBool, BoolType(), + BoolType(), BoolType()))); + CEL_RETURN_IF_ERROR(less_equals_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kLessEqualsString, BoolType(), + StringType(), StringType()))); + CEL_RETURN_IF_ERROR(less_equals_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kLessEqualsBytes, BoolType(), + BytesType(), BytesType()))); + CEL_RETURN_IF_ERROR(less_equals_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kLessEqualsDuration, BoolType(), + DurationType(), DurationType()))); + CEL_RETURN_IF_ERROR(less_equals_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kLessEqualsTimestamp, BoolType(), + TimestampType(), TimestampType()))); + + FunctionDecl greater_equals_op; + greater_equals_op.set_name(StandardFunctions::kGreaterOrEqual); + // Numeric types + CEL_RETURN_IF_ERROR(greater_equals_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kGreaterEqualsInt, BoolType(), + IntType(), IntType()))); + CEL_RETURN_IF_ERROR(greater_equals_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kGreaterEqualsUint, BoolType(), + UintType(), UintType()))); + CEL_RETURN_IF_ERROR(greater_equals_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kGreaterEqualsDouble, BoolType(), + DoubleType(), DoubleType()))); + // Non-numeric types + CEL_RETURN_IF_ERROR(greater_equals_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kGreaterEqualsBool, BoolType(), + BoolType(), BoolType()))); + CEL_RETURN_IF_ERROR(greater_equals_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kGreaterEqualsString, BoolType(), + StringType(), StringType()))); + CEL_RETURN_IF_ERROR(greater_equals_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kGreaterEqualsBytes, BoolType(), + BytesType(), BytesType()))); + CEL_RETURN_IF_ERROR(greater_equals_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kGreaterEqualsDuration, BoolType(), + DurationType(), DurationType()))); + CEL_RETURN_IF_ERROR(greater_equals_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kGreaterEqualsTimestamp, BoolType(), + TimestampType(), TimestampType()))); + + if (builder.options().enable_cross_numeric_comparisons) { + // Less + CEL_RETURN_IF_ERROR(less_op.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kLessIntUint, BoolType(), IntType(), UintType()))); + CEL_RETURN_IF_ERROR(less_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kLessIntDouble, BoolType(), + IntType(), DoubleType()))); + CEL_RETURN_IF_ERROR(less_op.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kLessUintInt, BoolType(), UintType(), IntType()))); + CEL_RETURN_IF_ERROR(less_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kLessUintDouble, BoolType(), + UintType(), DoubleType()))); + CEL_RETURN_IF_ERROR(less_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kLessDoubleInt, BoolType(), + DoubleType(), IntType()))); + CEL_RETURN_IF_ERROR(less_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kLessDoubleUint, BoolType(), + DoubleType(), UintType()))); + // Greater + CEL_RETURN_IF_ERROR(greater_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kGreaterIntUint, BoolType(), + IntType(), UintType()))); + CEL_RETURN_IF_ERROR(greater_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kGreaterIntDouble, BoolType(), + IntType(), DoubleType()))); + CEL_RETURN_IF_ERROR(greater_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kGreaterUintInt, BoolType(), + UintType(), IntType()))); + CEL_RETURN_IF_ERROR(greater_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kGreaterUintDouble, BoolType(), + UintType(), DoubleType()))); + CEL_RETURN_IF_ERROR(greater_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kGreaterDoubleInt, BoolType(), + DoubleType(), IntType()))); + CEL_RETURN_IF_ERROR(greater_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kGreaterDoubleUint, BoolType(), + DoubleType(), UintType()))); + // LessEqual + CEL_RETURN_IF_ERROR(less_equals_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kLessEqualsIntUint, BoolType(), + IntType(), UintType()))); + CEL_RETURN_IF_ERROR(less_equals_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kLessEqualsIntDouble, BoolType(), + IntType(), DoubleType()))); + + CEL_RETURN_IF_ERROR(less_equals_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kLessEqualsUintInt, BoolType(), + UintType(), IntType()))); + CEL_RETURN_IF_ERROR(less_equals_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kLessEqualsUintDouble, BoolType(), + UintType(), DoubleType()))); + CEL_RETURN_IF_ERROR(less_equals_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kLessEqualsDoubleInt, BoolType(), + DoubleType(), IntType()))); + CEL_RETURN_IF_ERROR(less_equals_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kLessEqualsDoubleUint, BoolType(), + DoubleType(), UintType()))); + // GreaterEqual + CEL_RETURN_IF_ERROR(greater_equals_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kGreaterEqualsIntUint, BoolType(), + IntType(), UintType()))); + CEL_RETURN_IF_ERROR(greater_equals_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kGreaterEqualsIntDouble, + BoolType(), IntType(), DoubleType()))); + CEL_RETURN_IF_ERROR(greater_equals_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kGreaterEqualsUintInt, BoolType(), + UintType(), IntType()))); + CEL_RETURN_IF_ERROR(greater_equals_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kGreaterEqualsUintDouble, + BoolType(), UintType(), DoubleType()))); + CEL_RETURN_IF_ERROR(greater_equals_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kGreaterEqualsDoubleInt, + BoolType(), DoubleType(), IntType()))); + CEL_RETURN_IF_ERROR(greater_equals_op.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kGreaterEqualsDoubleUint, + BoolType(), DoubleType(), UintType()))); + } + + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(less_op))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(greater_op))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(less_equals_op))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(greater_equals_op))); + + return absl::OkStatus(); +} + +absl::Status AddStringFunctions(TypeCheckerBuilder& builder) { + FunctionDecl contains; + contains.set_name(StandardFunctions::kStringContains); + CEL_RETURN_IF_ERROR(contains.AddOverload( + MakeMemberOverloadDecl(StandardOverloadIds::kContainsString, BoolType(), + StringType(), StringType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(contains))); + + FunctionDecl starts_with; + starts_with.set_name(StandardFunctions::kStringStartsWith); + CEL_RETURN_IF_ERROR(starts_with.AddOverload( + MakeMemberOverloadDecl(StandardOverloadIds::kStartsWithString, BoolType(), + StringType(), StringType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(starts_with))); + + FunctionDecl ends_with; + ends_with.set_name(StandardFunctions::kStringEndsWith); + CEL_RETURN_IF_ERROR(ends_with.AddOverload( + MakeMemberOverloadDecl(StandardOverloadIds::kEndsWithString, BoolType(), + StringType(), StringType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(ends_with))); + + return absl::OkStatus(); +} + +absl::Status AddRegexFunctions(TypeCheckerBuilder& builder) { + FunctionDecl matches; + matches.set_name(StandardFunctions::kRegexMatch); + CEL_RETURN_IF_ERROR(matches.AddOverload( + MakeMemberOverloadDecl(StandardOverloadIds::kMatchesMember, BoolType(), + StringType(), StringType()))); + CEL_RETURN_IF_ERROR(matches.AddOverload(MakeOverloadDecl( + StandardOverloadIds::kMatches, BoolType(), StringType(), StringType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(matches))); + return absl::OkStatus(); +} + +absl::Status AddTimeFunctions(TypeCheckerBuilder& builder) { + FunctionDecl get_full_year; + get_full_year.set_name(StandardFunctions::kFullYear); + CEL_RETURN_IF_ERROR(get_full_year.AddOverload(MakeMemberOverloadDecl( + StandardOverloadIds::kTimestampToYear, IntType(), TimestampType()))); + CEL_RETURN_IF_ERROR(get_full_year.AddOverload( + MakeMemberOverloadDecl(StandardOverloadIds::kTimestampToYearWithTz, + IntType(), TimestampType(), StringType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(get_full_year))); + + FunctionDecl get_month; + get_month.set_name(StandardFunctions::kMonth); + CEL_RETURN_IF_ERROR(get_month.AddOverload(MakeMemberOverloadDecl( + StandardOverloadIds::kTimestampToMonth, IntType(), TimestampType()))); + CEL_RETURN_IF_ERROR(get_month.AddOverload( + MakeMemberOverloadDecl(StandardOverloadIds::kTimestampToMonthWithTz, + IntType(), TimestampType(), StringType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(get_month))); + + FunctionDecl get_day_of_year; + get_day_of_year.set_name(StandardFunctions::kDayOfYear); + CEL_RETURN_IF_ERROR(get_day_of_year.AddOverload(MakeMemberOverloadDecl( + StandardOverloadIds::kTimestampToDayOfYear, IntType(), TimestampType()))); + CEL_RETURN_IF_ERROR(get_day_of_year.AddOverload( + MakeMemberOverloadDecl(StandardOverloadIds::kTimestampToDayOfYearWithTz, + IntType(), TimestampType(), StringType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(get_day_of_year))); + + FunctionDecl get_day_of_month; + get_day_of_month.set_name(StandardFunctions::kDayOfMonth); + CEL_RETURN_IF_ERROR(get_day_of_month.AddOverload( + MakeMemberOverloadDecl(StandardOverloadIds::kTimestampToDayOfMonth, + IntType(), TimestampType()))); + CEL_RETURN_IF_ERROR(get_day_of_month.AddOverload( + MakeMemberOverloadDecl(StandardOverloadIds::kTimestampToDayOfMonthWithTz, + IntType(), TimestampType(), StringType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(get_day_of_month))); + + FunctionDecl get_date; + get_date.set_name(StandardFunctions::kDate); + CEL_RETURN_IF_ERROR(get_date.AddOverload(MakeMemberOverloadDecl( + StandardOverloadIds::kTimestampToDate, IntType(), TimestampType()))); + CEL_RETURN_IF_ERROR(get_date.AddOverload( + MakeMemberOverloadDecl(StandardOverloadIds::kTimestampToDateWithTz, + IntType(), TimestampType(), StringType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(get_date))); + + FunctionDecl get_day_of_week; + get_day_of_week.set_name(StandardFunctions::kDayOfWeek); + CEL_RETURN_IF_ERROR(get_day_of_week.AddOverload(MakeMemberOverloadDecl( + StandardOverloadIds::kTimestampToDayOfWeek, IntType(), TimestampType()))); + CEL_RETURN_IF_ERROR(get_day_of_week.AddOverload( + MakeMemberOverloadDecl(StandardOverloadIds::kTimestampToDayOfWeekWithTz, + IntType(), TimestampType(), StringType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(get_day_of_week))); + + FunctionDecl get_hours; + get_hours.set_name(StandardFunctions::kHours); + CEL_RETURN_IF_ERROR(get_hours.AddOverload(MakeMemberOverloadDecl( + StandardOverloadIds::kTimestampToHours, IntType(), TimestampType()))); + CEL_RETURN_IF_ERROR(get_hours.AddOverload( + MakeMemberOverloadDecl(StandardOverloadIds::kTimestampToHoursWithTz, + IntType(), TimestampType(), StringType()))); + CEL_RETURN_IF_ERROR(get_hours.AddOverload(MakeMemberOverloadDecl( + StandardOverloadIds::kDurationToHours, IntType(), DurationType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(get_hours))); + + FunctionDecl get_minutes; + get_minutes.set_name(StandardFunctions::kMinutes); + CEL_RETURN_IF_ERROR(get_minutes.AddOverload(MakeMemberOverloadDecl( + StandardOverloadIds::kTimestampToMinutes, IntType(), TimestampType()))); + CEL_RETURN_IF_ERROR(get_minutes.AddOverload( + MakeMemberOverloadDecl(StandardOverloadIds::kTimestampToMinutesWithTz, + IntType(), TimestampType(), StringType()))); + CEL_RETURN_IF_ERROR(get_minutes.AddOverload(MakeMemberOverloadDecl( + StandardOverloadIds::kDurationToMinutes, IntType(), DurationType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(get_minutes))); + + FunctionDecl get_seconds; + get_seconds.set_name(StandardFunctions::kSeconds); + CEL_RETURN_IF_ERROR(get_seconds.AddOverload(MakeMemberOverloadDecl( + StandardOverloadIds::kTimestampToSeconds, IntType(), TimestampType()))); + CEL_RETURN_IF_ERROR(get_seconds.AddOverload( + MakeMemberOverloadDecl(StandardOverloadIds::kTimestampToSecondsWithTz, + IntType(), TimestampType(), StringType()))); + CEL_RETURN_IF_ERROR(get_seconds.AddOverload(MakeMemberOverloadDecl( + StandardOverloadIds::kDurationToSeconds, IntType(), DurationType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(get_seconds))); + + FunctionDecl get_milliseconds; + get_milliseconds.set_name(StandardFunctions::kMilliseconds); + CEL_RETURN_IF_ERROR(get_milliseconds.AddOverload( + MakeMemberOverloadDecl(StandardOverloadIds::kTimestampToMilliseconds, + IntType(), TimestampType()))); + CEL_RETURN_IF_ERROR(get_milliseconds.AddOverload(MakeMemberOverloadDecl( + StandardOverloadIds::kTimestampToMillisecondsWithTz, IntType(), + TimestampType(), StringType()))); + CEL_RETURN_IF_ERROR(get_milliseconds.AddOverload( + MakeMemberOverloadDecl(StandardOverloadIds::kDurationToMilliseconds, + IntType(), DurationType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(get_milliseconds))); + + return absl::OkStatus(); +} + +absl::Status AddTypeConstantVariables(TypeCheckerBuilder& builder) { + CEL_RETURN_IF_ERROR(builder.AddVariable( + MakeVariableDecl(StandardFunctions::kDyn, TypeDynType()))); + + CEL_RETURN_IF_ERROR( + builder.AddVariable(MakeVariableDecl("bool", TypeBoolType()))); + + CEL_RETURN_IF_ERROR( + builder.AddVariable(MakeVariableDecl("null_type", TypeNullType()))); + + CEL_RETURN_IF_ERROR(builder.AddVariable( + MakeVariableDecl(StandardFunctions::kInt, TypeIntType()))); + + CEL_RETURN_IF_ERROR(builder.AddVariable( + MakeVariableDecl(StandardFunctions::kUint, TypeUintType()))); + + CEL_RETURN_IF_ERROR(builder.AddVariable( + MakeVariableDecl(StandardFunctions::kDouble, TypeDoubleType()))); + + CEL_RETURN_IF_ERROR(builder.AddVariable( + MakeVariableDecl(StandardFunctions::kString, TypeStringType()))); + + CEL_RETURN_IF_ERROR(builder.AddVariable( + MakeVariableDecl(StandardFunctions::kBytes, TypeBytesType()))); + + // Note: timestamp and duration are only referenced by the corresponding + // protobuf type names and handled by the type lookup logic. + + CEL_RETURN_IF_ERROR( + builder.AddVariable(MakeVariableDecl("list", TypeListType()))); + + CEL_RETURN_IF_ERROR( + builder.AddVariable(MakeVariableDecl("map", TypeMapType()))); + + CEL_RETURN_IF_ERROR( + builder.AddVariable(MakeVariableDecl("type", TypeOfType()))); + + return absl::OkStatus(); +} + +absl::Status AddEnumConstants(TypeCheckerBuilder& builder) { + VariableDecl pb_null; + pb_null.set_name("google.protobuf.NullValue.NULL_VALUE"); + pb_null.set_type(IntType()); + pb_null.set_value(Constant(int64_t{0})); + CEL_RETURN_IF_ERROR(builder.AddVariable(std::move(pb_null))); + return absl::OkStatus(); +} + +absl::Status AddStandardLibraryDecls(TypeCheckerBuilder& builder) { + CEL_RETURN_IF_ERROR(AddLogicalOps(builder)); + CEL_RETURN_IF_ERROR(AddArithmeticOps(builder)); + CEL_RETURN_IF_ERROR(AddTypeConversions(builder)); + CEL_RETURN_IF_ERROR(AddEqualityOps(builder)); + CEL_RETURN_IF_ERROR(AddContainerOps(builder)); + CEL_RETURN_IF_ERROR(AddRelationOps(builder)); + CEL_RETURN_IF_ERROR(AddStringFunctions(builder)); + CEL_RETURN_IF_ERROR(AddRegexFunctions(builder)); + CEL_RETURN_IF_ERROR(AddTimeFunctions(builder)); + CEL_RETURN_IF_ERROR(AddTypeConstantVariables(builder)); + CEL_RETURN_IF_ERROR(AddEnumConstants(builder)); + return absl::OkStatus(); +} + +} // namespace + +// Returns a CheckerLibrary containing all of the standard CEL declarations. +CheckerLibrary StandardCheckerLibrary() { + return {"stdlib", AddStandardLibraryDecls}; +} +} // namespace cel diff --git a/checker/standard_library.h b/checker/standard_library.h new file mode 100644 index 000000000..05f6d5bb7 --- /dev/null +++ b/checker/standard_library.h @@ -0,0 +1,26 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#ifndef THIRD_PARTY_CEL_CPP_CHECKER_STANDARD_LIBRARY_H_ +#define THIRD_PARTY_CEL_CPP_CHECKER_STANDARD_LIBRARY_H_ + +#include "checker/type_checker_builder.h" + +namespace cel { + +// Returns a CheckerLibrary containing all of the standard CEL declarations. +CheckerLibrary StandardCheckerLibrary(); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_CHECKER_STANDARD_LIBRARY_H_ diff --git a/checker/standard_library_test.cc b/checker/standard_library_test.cc new file mode 100644 index 000000000..f3330a76d --- /dev/null +++ b/checker/standard_library_test.cc @@ -0,0 +1,498 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/standard_library.h" + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "checker/checker_options.h" +#include "checker/internal/test_ast_helpers.h" +#include "checker/type_checker.h" +#include "checker/type_checker_builder.h" +#include "checker/type_checker_builder_factory.h" +#include "checker/validation_result.h" +#include "common/ast.h" +#include "common/constant.h" +#include "common/decl.h" +#include "common/type.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "google/protobuf/arena.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::cel::Reference; +using ::cel::internal::GetSharedTestingDescriptorPool; +using ::testing::IsEmpty; +using ::testing::Pointee; +using ::testing::Property; + +using AstType = cel::TypeSpec; + +TEST(StandardLibraryTest, StandardLibraryAddsDecls) { + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); + EXPECT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); + EXPECT_THAT(builder->Build(), IsOk()); +} + +TEST(StandardLibraryTest, StandardLibraryErrorsIfAddedTwice) { + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); + EXPECT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); + EXPECT_THAT(builder->AddLibrary(StandardCheckerLibrary()), + StatusIs(absl::StatusCode::kAlreadyExists)); +} + +TEST(StandardLibraryTest, ComprehensionVarsIndirectCyclicParamAssignability) { + google::protobuf::Arena arena; + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); + ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); + + // Note: this is atypical -- parameterized variables aren't well supported + // outside of built-in syntax. + // e.g. `list : Type(List(A))` is instantiated per reference to bind A to + // the concrete type of a list in the same assignability context. + // + // Validate that parameterization is sanitized to be contextual + // List(V) -> List(T%1) + // Map(K, V) -> Map(T%2, T%3) + Type list_type = ListType(&arena, TypeParamType("V")); + Type map_type = MapType(&arena, TypeParamType("K"), TypeParamType("V")); + + ASSERT_THAT(builder->AddVariable(MakeVariableDecl("list_var", list_type)), + IsOk()); + ASSERT_THAT(builder->AddVariable(MakeVariableDecl("map_var", map_type)), + IsOk()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker, + builder->Build()); + + ASSERT_OK_AND_ASSIGN( + auto ast, checker_internal::MakeTestParsedAst( + "list_var.exists(v," + " map_var.filter(k, map_var[k] > 1.0).size() > int(v)" + ")")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + type_checker->Check(std::move(ast))); + + EXPECT_TRUE(result.IsValid()); + + EXPECT_THAT(result.GetIssues(), IsEmpty()); +} + +TEST(StandardLibraryTest, ComprehensionResultTypeIsSubstituted) { + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); + ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); + + // Test that type for the result list of .map is resolved to a concrete type + // when it is known. Checks for a bug where the result type is considered to + // still be flexible and may widen to dyn. + builder->set_container("cel.expr.conformance.proto2"); + ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker, + builder->Build()); + + ASSERT_OK_AND_ASSIGN(auto ast, checker_internal::MakeTestParsedAst( + "[TestAllTypes{}]" + ".map(x, x.repeated_nested_message[0])" + ".map(x, x.bb)[0]")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + type_checker->Check(std::move(ast))); + + EXPECT_TRUE(result.IsValid()); + + EXPECT_THAT(result.GetIssues(), IsEmpty()) << result.FormatError(); + + ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); + + TypeSpec type = checked_ast->GetTypeOrDyn(checked_ast->root_expr().id()); + EXPECT_TRUE(type.has_primitive() && + type.primitive() == PrimitiveType::kInt64); +} + +class StandardLibraryDefinitionsTest : public ::testing::Test { + public: + void SetUp() override { + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); + ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); + ASSERT_OK_AND_ASSIGN(stdlib_type_checker_, builder->Build()); + } + + protected: + std::unique_ptr stdlib_type_checker_; +}; + +class StdlibTypeVarDefinitionTest + : public StandardLibraryDefinitionsTest, + public testing::WithParamInterface {}; + +TEST_P(StdlibTypeVarDefinitionTest, DefinesTypeConstants) { + auto ast = std::make_unique(); + ast->mutable_root_expr().mutable_ident_expr().set_name(GetParam()); + ast->mutable_root_expr().set_id(1); + + ASSERT_OK_AND_ASSIGN(ValidationResult result, + stdlib_type_checker_->Check(std::move(ast))); + + EXPECT_THAT(result.GetIssues(), IsEmpty()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); + EXPECT_THAT(checked_ast->GetReference(1), + Pointee(Property(&Reference::name, GetParam()))); + EXPECT_THAT(checked_ast->GetTypeOrDyn(1), Property(&AstType::has_type, true)); +} + +INSTANTIATE_TEST_SUITE_P(StdlibTypeVarDefinitions, StdlibTypeVarDefinitionTest, + ::testing::Values("bool", "bytes", "double", "dyn", + "int", "list", "map", "null_type", + "string", "type", "uint"), + [](const auto& info) -> std::string { + return info.param; + }); + +TEST_F(StandardLibraryDefinitionsTest, DefinesProtoStructNull) { + auto ast = std::make_unique(); + + auto& enumerator = ast->mutable_root_expr(); + enumerator.set_id(4); + enumerator.mutable_select_expr().set_field("NULL_VALUE"); + auto& enumeration = enumerator.mutable_select_expr().mutable_operand(); + enumeration.set_id(3); + enumeration.mutable_select_expr().set_field("NullValue"); + auto& protobuf = enumeration.mutable_select_expr().mutable_operand(); + protobuf.set_id(2); + protobuf.mutable_select_expr().set_field("protobuf"); + auto& google = protobuf.mutable_select_expr().mutable_operand(); + google.set_id(1); + google.mutable_ident_expr().set_name("google"); + + ASSERT_OK_AND_ASSIGN(ValidationResult result, + stdlib_type_checker_->Check(std::move(ast))); + + EXPECT_THAT(result.GetIssues(), IsEmpty()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); + EXPECT_THAT(checked_ast->GetReference(4), + Pointee(Property(&Reference::name, + "google.protobuf.NullValue.NULL_VALUE"))); +} + +TEST_F(StandardLibraryDefinitionsTest, DefinesTypeType) { + auto ast = std::make_unique(); + + auto& ident = ast->mutable_root_expr(); + ident.set_id(1); + ident.mutable_ident_expr().set_name("type"); + + ASSERT_OK_AND_ASSIGN(ValidationResult result, + stdlib_type_checker_->Check(std::move(ast))); + + EXPECT_THAT(result.GetIssues(), IsEmpty()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); + EXPECT_THAT(checked_ast->GetReference(1), + Pointee(Property(&Reference::name, "type"))); + EXPECT_THAT(checked_ast->GetTypeOrDyn(1), Property(&AstType::has_type, true)); +} + +struct DefinitionsTestCase { + std::string expr; + bool type_check_success = true; + CheckerOptions options; +}; + +class StdLibDefinitionsTest + : public ::testing::TestWithParam { + public: +}; + +// Basic coverage that the standard library definitions are defined. +// This is not intended to be exhaustive since it is expected to be covered by +// spec conformance tests. +// +// TODO(uncreated-issue/72): Tests are fairly minimal right now -- it's not possible to +// test thoroughly without a more complete implementation of the type checker. +// Type-parameterized functions are not yet checkable. +TEST_P(StdLibDefinitionsTest, Runner) { + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool(), + GetParam().options)); + ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker, + builder->Build()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, + checker_internal::MakeTestParsedAst(GetParam().expr)); + + ASSERT_OK_AND_ASSIGN(auto result, type_checker->Check(std::move(ast))); + EXPECT_EQ(result.IsValid(), GetParam().type_check_success); +} + +INSTANTIATE_TEST_SUITE_P( + Strings, StdLibDefinitionsTest, + ::testing::Values(DefinitionsTestCase{ + /* .expr = */ "'123'.size()", + }, + DefinitionsTestCase{ + /* .expr = */ "size('123')", + }, + DefinitionsTestCase{ + /* .expr = */ "'123' + '123'", + }, + DefinitionsTestCase{ + /* .expr = */ "'123' + '123'", + }, + DefinitionsTestCase{ + /* .expr = */ "'123' + '123'", + }, + DefinitionsTestCase{ + /* .expr = */ "'123'.endsWith('123')", + }, + DefinitionsTestCase{ + /* .expr = */ "'123'.startsWith('123')", + }, + DefinitionsTestCase{ + /* .expr = */ "'123'.contains('123')", + }, + DefinitionsTestCase{ + /* .expr = */ "'123'.matches(r'123')", + }, + DefinitionsTestCase{ + /* .expr = */ "matches('123', r'123')", + })); + +INSTANTIATE_TEST_SUITE_P(TypeCasts, StdLibDefinitionsTest, + ::testing::Values(DefinitionsTestCase{ + /* .expr = */ "int(1)", + }, + DefinitionsTestCase{ + /* .expr = */ "uint(1)", + }, + DefinitionsTestCase{ + /* .expr = */ "double(1)", + }, + DefinitionsTestCase{ + /* .expr = */ "string(1)", + }, + DefinitionsTestCase{ + /* .expr = */ "bool('true')", + }, + DefinitionsTestCase{ + /* .expr = */ "timestamp(0)", + }, + DefinitionsTestCase{ + /* .expr = */ "duration('1s')", + }, + DefinitionsTestCase{ + /* .expr = */ "type(1)", + })); + +INSTANTIATE_TEST_SUITE_P(Arithmetic, StdLibDefinitionsTest, + ::testing::Values(DefinitionsTestCase{ + /* .expr = */ "1 + 2", + }, + DefinitionsTestCase{ + /* .expr = */ "1 - 2", + }, + DefinitionsTestCase{ + /* .expr = */ "1 / 2", + }, + DefinitionsTestCase{ + /* .expr = */ "1 * 2", + }, + DefinitionsTestCase{ + /* .expr = */ "2 % 1", + }, + DefinitionsTestCase{ + /* .expr = */ "-1", + })); + +INSTANTIATE_TEST_SUITE_P( + TimeArithmetic, StdLibDefinitionsTest, + ::testing::Values(DefinitionsTestCase{ + /* .expr = */ "timestamp(0) + duration('1s')", + }, + DefinitionsTestCase{ + /* .expr = */ "timestamp(0) - duration('1s')", + }, + DefinitionsTestCase{ + /* .expr = */ "timestamp(0) - timestamp(0)", + }, + DefinitionsTestCase{ + /* .expr = */ "duration('1s') + duration('1s')", + }, + DefinitionsTestCase{ + /* .expr = */ "duration('1s') - duration('1s')", + })); + +INSTANTIATE_TEST_SUITE_P(NumericComparisons, StdLibDefinitionsTest, + ::testing::Values(DefinitionsTestCase{ + /* .expr = */ "1 > 2", + }, + DefinitionsTestCase{ + /* .expr = */ "1 < 2", + }, + DefinitionsTestCase{ + /* .expr = */ "1 >= 2", + }, + DefinitionsTestCase{ + /* .expr = */ "1 <= 2", + })); + +INSTANTIATE_TEST_SUITE_P( + CrossNumericComparisons, StdLibDefinitionsTest, + ::testing::Values( + DefinitionsTestCase{ + /* .expr = */ "1u < 2", + /* .type_check_success = */ true, + /* .options = */ {.enable_cross_numeric_comparisons = true}}, + DefinitionsTestCase{ + /* .expr = */ "1u > 2", + /* .type_check_success = */ true, + /* .options = */ {.enable_cross_numeric_comparisons = true}}, + DefinitionsTestCase{ + /* .expr = */ "1u <= 2", + /* .type_check_success = */ true, + /* .options = */ {.enable_cross_numeric_comparisons = true}}, + DefinitionsTestCase{ + /* .expr = */ "1u >= 2", + /* .type_check_success = */ true, + /* .options = */ {.enable_cross_numeric_comparisons = true}})); + +INSTANTIATE_TEST_SUITE_P( + TimeComparisons, StdLibDefinitionsTest, + ::testing::Values(DefinitionsTestCase{ + /* .expr = */ "duration('1s') < duration('1s')", + }, + DefinitionsTestCase{ + /* .expr = */ "duration('1s') > duration('1s')", + }, + DefinitionsTestCase{ + /* .expr = */ "duration('1s') <= duration('1s')", + }, + DefinitionsTestCase{ + /* .expr = */ "duration('1s') >= duration('1s')", + }, + DefinitionsTestCase{ + /* .expr = */ "timestamp(0) < timestamp(0)", + }, + DefinitionsTestCase{ + /* .expr = */ "timestamp(0) > timestamp(0)", + }, + DefinitionsTestCase{ + /* .expr = */ "timestamp(0) <= timestamp(0)", + }, + DefinitionsTestCase{ + /* .expr = */ "timestamp(0) >= timestamp(0)", + })); + +INSTANTIATE_TEST_SUITE_P( + TimeAccessors, StdLibDefinitionsTest, + ::testing::Values( + DefinitionsTestCase{ + /* .expr = */ "timestamp(0).getFullYear()", + }, + DefinitionsTestCase{ + /* .expr = */ "timestamp(0).getFullYear('-08:00')", + }, + DefinitionsTestCase{ + /* .expr = */ "timestamp(0).getMonth()", + }, + DefinitionsTestCase{ + /* .expr = */ "timestamp(0).getMonth('-08:00')", + }, + DefinitionsTestCase{ + /* .expr = */ "timestamp(0).getDayOfYear()", + }, + DefinitionsTestCase{ + /* .expr = */ "timestamp(0).getDayOfYear('-08:00')", + }, + DefinitionsTestCase{ + /* .expr = */ "timestamp(0).getDate()", + }, + DefinitionsTestCase{ + /* .expr = */ "timestamp(0).getDate('-08:00')", + }, + DefinitionsTestCase{ + /* .expr = */ "timestamp(0).getDayOfWeek()", + }, + DefinitionsTestCase{ + /* .expr = */ "timestamp(0).getDayOfWeek('-08:00')", + }, + DefinitionsTestCase{ + /* .expr = */ "timestamp(0).getHours()", + }, + DefinitionsTestCase{ + /* .expr = */ "duration('1s').getHours()", + }, + DefinitionsTestCase{ + /* .expr = */ "timestamp(0).getHours('-08:00')", + }, + DefinitionsTestCase{ + /* .expr = */ "timestamp(0).getMinutes()", + }, + DefinitionsTestCase{ + /* .expr = */ "duration('1s').getMinutes()", + }, + DefinitionsTestCase{ + /* .expr = */ "timestamp(0).getMinutes('-08:00')", + }, + DefinitionsTestCase{ + /* .expr = */ "timestamp(0).getSeconds()", + }, + DefinitionsTestCase{ + /* .expr = */ "duration('1s').getSeconds()", + }, + DefinitionsTestCase{ + /* .expr = */ "timestamp(0).getSeconds('-08:00')", + }, + DefinitionsTestCase{ + /* .expr = */ "timestamp(0).getMilliseconds()", + }, + DefinitionsTestCase{ + /* .expr = */ "duration('1s').getMilliseconds()", + }, + DefinitionsTestCase{ + /* .expr = */ "timestamp(0).getMilliseconds('-08:00')", + })); + +INSTANTIATE_TEST_SUITE_P(Logic, StdLibDefinitionsTest, + ::testing::Values(DefinitionsTestCase{ + /* .expr = */ "true || false", + }, + DefinitionsTestCase{ + /* .expr = */ "true && false", + }, + DefinitionsTestCase{ + /* .expr = */ "!true", + }, + DefinitionsTestCase{ + /* .expr = */ "true ? 1 : 2", + })); + +} // namespace +} // namespace cel diff --git a/checker/type_check_issue.cc b/checker/type_check_issue.cc new file mode 100644 index 000000000..b1d3caa11 --- /dev/null +++ b/checker/type_check_issue.cc @@ -0,0 +1,59 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/type_check_issue.h" + +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "common/source.h" + +namespace cel { + +namespace { + +absl::string_view SeverityString(TypeCheckIssue::Severity severity) { + switch (severity) { + case TypeCheckIssue::Severity::kInformation: + return "INFORMATION"; + case TypeCheckIssue::Severity::kWarning: + return "WARNING"; + case TypeCheckIssue::Severity::kError: + return "ERROR"; + case TypeCheckIssue::Severity::kDeprecated: + return "DEPRECATED"; + default: + return "SEVERITY_UNSPECIFIED"; + } +} + +} // namespace + +std::string TypeCheckIssue::ToDisplayString(const Source* source) const { + int column = location_.column; + // convert to 1-based if it's in range. + int display_column = column >= 0 ? column + 1 : column; + if (source) { + return absl::StrFormat("%s: %s:%d:%d: %s%s", SeverityString(severity_), + source->description(), location_.line, + display_column, message_, + source->DisplayErrorLocation(location_)); + } + + return absl::StrFormat("%s: :%d:%d: %s", SeverityString(severity_), + location_.line, display_column, message_); +} + +} // namespace cel diff --git a/checker/type_check_issue.h b/checker/type_check_issue.h new file mode 100644 index 000000000..9f6f57a3d --- /dev/null +++ b/checker/type_check_issue.h @@ -0,0 +1,69 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_CHECKER_TYPE_CHECK_ISSUE_H_ +#define THIRD_PARTY_CEL_CPP_CHECKER_TYPE_CHECK_ISSUE_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "common/source.h" + +namespace cel { + +// Represents a single issue identified in type checking. +class TypeCheckIssue { + public: + enum class Severity { kError, kWarning, kInformation, kDeprecated }; + + TypeCheckIssue(Severity severity, SourceLocation location, + std::string message) + : severity_(severity), + location_(location), + message_(std::move(message)) {} + + // Factory for error-severity issues. + static TypeCheckIssue CreateError(SourceLocation location, + std::string message) { + return TypeCheckIssue(Severity::kError, location, std::move(message)); + } + + // Factory for error-severity issues. + // line is 1-based, column is 0-based. + static TypeCheckIssue CreateError(int line, int column, std::string message) { + return TypeCheckIssue(Severity::kError, SourceLocation{line, column}, + std::move(message)); + } + + // Format the issue highlighting the source position. + std::string ToDisplayString(const Source* source) const; + + std::string ToDisplayString(const Source& source) const { + return ToDisplayString(&source); + } + + absl::string_view message() const { return message_; } + Severity severity() const { return severity_; } + SourceLocation location() const { return location_; } + + private: + Severity severity_; + SourceLocation location_; + std::string message_; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_CHECKER_TYPE_CHECK_ISSUE_H_ diff --git a/checker/type_check_issue_test.cc b/checker/type_check_issue_test.cc new file mode 100644 index 000000000..9017fea99 --- /dev/null +++ b/checker/type_check_issue_test.cc @@ -0,0 +1,48 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/type_check_issue.h" + +#include "common/source.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +TEST(TypeCheckIssueTest, DisplayString) { + ASSERT_OK_AND_ASSIGN(auto source, NewSource("test{\n\tfield1: 123\n}")); + TypeCheckIssue issue = TypeCheckIssue::CreateError(2, 2, "test error"); + // Note: The column is displayed as 1 based to match the Go checker. + EXPECT_EQ(issue.ToDisplayString(*source), + "ERROR: :2:3: test error\n" + " | field1: 123\n" + " | ..^"); +} + +TEST(TypeCheckIssueTest, DisplayStringNoPosition) { + ASSERT_OK_AND_ASSIGN(auto source, NewSource("test{\n\tfield1: 123\n}")); + TypeCheckIssue issue = TypeCheckIssue::CreateError(-1, -1, "test error"); + EXPECT_EQ(issue.ToDisplayString(*source), "ERROR: :-1:-1: test error"); +} + +TEST(TypeCheckIssueTest, DisplayStringDeprecated) { + ASSERT_OK_AND_ASSIGN(auto source, NewSource("test{\n\tfield1: 123\n}")); + TypeCheckIssue issue = TypeCheckIssue(TypeCheckIssue::Severity::kDeprecated, + {-1, -1}, "test error 2"); + EXPECT_EQ(issue.ToDisplayString(*source), + "DEPRECATED: :-1:-1: test error 2"); +} + +} // namespace +} // namespace cel diff --git a/checker/type_checker.cc b/checker/type_checker.cc new file mode 100644 index 000000000..6d59e144d --- /dev/null +++ b/checker/type_checker.cc @@ -0,0 +1,36 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/type_checker.h" + +namespace cel { +absl::StatusOr TypeChecker::Check( + std::unique_ptr ast) const { + return CheckImpl(std::move(ast), nullptr); +} + +absl::StatusOr TypeChecker::Check( + std::unique_ptr ast, google::protobuf::Arena* arena) const { + return CheckImpl(std::move(ast), arena); +} + +absl::StatusOr TypeChecker::Check(const Ast& ast) const { + return CheckImpl(std::make_unique(ast), nullptr); +} + +absl::StatusOr TypeChecker::Check( + const Ast& ast, google::protobuf::Arena* arena) const { + return CheckImpl(std::make_unique(ast), arena); +} +} // namespace cel diff --git a/checker/type_checker.h b/checker/type_checker.h new file mode 100644 index 000000000..edb6cc91f --- /dev/null +++ b/checker/type_checker.h @@ -0,0 +1,65 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_CHECKER_TYPE_CHECKER_H_ +#define THIRD_PARTY_CEL_CPP_CHECKER_TYPE_CHECKER_H_ + +#include +#include + +#include "absl/base/nullability.h" +#include "absl/status/statusor.h" +#include "checker/validation_result.h" +#include "common/ast.h" +#include "google/protobuf/arena.h" + +namespace cel { + +class TypeCheckerBuilder; + +// TypeChecker interface. +// +// Checks references and type agreement for a parsed CEL expression. +// +// See Compiler for bundled parse and type check from a source expression +// string. +class TypeChecker { + public: + virtual ~TypeChecker() = default; + + // Checks the references and type agreement of the given parsed expression + // based on the configured CEL environment. + // + // Most type checking errors are returned as Issues in the validation result. + // A non-ok status is returned if type checking can't reasonably complete + // (e.g. if an internal precondition is violated or an extension returns an + // error). + absl::StatusOr Check(std::unique_ptr ast) const; + absl::StatusOr Check(std::unique_ptr ast, + google::protobuf::Arena* arena) const; + absl::StatusOr Check(const Ast& ast) const; + absl::StatusOr Check(const Ast& ast, + google::protobuf::Arena* arena) const; + + // Returns a builder initialized with the configuration of this type checker. + virtual std::unique_ptr ToBuilder() const = 0; + + private: + virtual absl::StatusOr CheckImpl( + std::unique_ptr ast, google::protobuf::Arena* absl_nullable arena) const = 0; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_CHECKER_TYPE_CHECKER_H_ diff --git a/checker/type_checker_builder.h b/checker/type_checker_builder.h new file mode 100644 index 000000000..c2d0cbf7b --- /dev/null +++ b/checker/type_checker_builder.h @@ -0,0 +1,186 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_CHECKER_TYPE_CHECKER_BUILDER_H_ +#define THIRD_PARTY_CEL_CPP_CHECKER_TYPE_CHECKER_BUILDER_H_ + +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/functional/any_invocable.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "checker/checker_options.h" +#include "checker/type_checker.h" +#include "common/container.h" +#include "common/decl.h" +#include "common/type.h" +#include "common/type_introspector.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +class TypeCheckerBuilder; + +// Functional implementation to apply the library features to a +// TypeCheckerBuilder. +using TypeCheckerBuilderConfigurer = + absl::AnyInvocable; + +struct CheckerLibrary { + // Optional identifier to avoid collisions re-adding the same declarations. + // If id is empty, it is not considered. + std::string id; + TypeCheckerBuilderConfigurer configure; +}; + +// Represents a declaration to only use a subset of a library. +struct TypeCheckerSubset { + using FunctionPredicate = absl::AnyInvocable; + + // The id of the library to subset. Only one subset can be applied per + // library id. + // + // Must be non-empty. + std::string library_id; + // Predicate to apply to function overloads. If true, the overload will be + // included in the subset. If no overload for a function is included, the + // entire function is excluded. + FunctionPredicate should_include_overload; +}; + +// Interface for TypeCheckerBuilders. +class TypeCheckerBuilder { + public: + virtual ~TypeCheckerBuilder() = default; + + // Adds a library to the TypeChecker being built. + // + // Libraries are applied in the order they are added. They effectively + // apply before any direct calls to AddVariable, AddFunction, etc. + virtual absl::Status AddLibrary(CheckerLibrary library) = 0; + + // Adds a subset declaration for a library to the TypeChecker being built. + // + // At most one subset can be applied per library id. + virtual absl::Status AddLibrarySubset(TypeCheckerSubset subset) = 0; + + // Adds a variable declaration that may be referenced in expressions checked + // with the resulting type checker. + virtual absl::Status AddVariable(const VariableDecl& decl) = 0; + + // Adds a variable declaration that may be referenced in expressions checked + // with the resulting type checker. + // + // This version replaces any existing variable declaration with the same name. + virtual absl::Status AddOrReplaceVariable(const VariableDecl& decl) = 0; + + // Declares struct type by fully qualified name as a context declaration. + // + // Context declarations are a way to declare a group of variables based on the + // definition of a struct type. Each top level field of the struct is declared + // as an individual variable of the field type. + // + // It is an error if the type contains a field that overlaps with another + // declared variable. + // + // Note: only protobuf backed struct types are supported at this time. + virtual absl::Status AddContextDeclaration(absl::string_view type) = 0; + + // Declares struct type by fully qualified name as a context declaration. + // + // This version accepts a mask in terms of field selections from the + // context type. The mask specifies which fields are visible on the + // struct and its members. The visible fields for a type accumulate + // across calls. This is a lightweight way to adjust the type checking + // behavior for a group of related types. + // + // Context declarations are a way to declare a group of variables based on the + // definition of a struct type. Each top level field of the struct that is + // also the first field name in a field path is declared as an individual + // variable of the field type. + // + // It is an error if the type contains a field that overlaps with another + // declared variable. It is an error if the input field paths is the empty + // set. + // + // Note: only protobuf backed struct types are supported at this time. + virtual absl::Status AddContextDeclarationWithProtoTypeMask( + absl::string_view type, std::vector field_paths) = 0; + + // Adds a function declaration that may be referenced in expressions checked + // with the resulting TypeChecker. + virtual absl::Status AddFunction(const FunctionDecl& decl) = 0; + + // Adds function declaration overloads to the TypeChecker being built. + // + // Attempts to merge with any existing overloads for a function decl with the + // same name. If the overloads are not compatible, an error is returned and + // no change is made. + virtual absl::Status MergeFunction(const FunctionDecl& decl) = 0; + + // Sets the expected type for checked expressions. + // + // Validation will fail with an ERROR level issue if the deduced type of the + // expression is not assignable to this type. + // + // Note: if set multiple times, the last value is used. + virtual void SetExpectedType(const Type& type) = 0; + + // Adds a type provider to the TypeChecker being built. + // + // Type providers are used to describe custom types with typed field + // traversal. This is not needed for built-in types or protobuf messages + // described by the associated descriptor pool. + virtual void AddTypeProvider(std::unique_ptr provider) = 0; + + // Set the container for the TypeChecker being built. + // + // This is used for resolving references in the expressions being built. + // + // Prefer setting the container via SetExpressionContainer(). + // + // Note: if set multiple times, the last value is used. This can lead to + // surprising behavior if used in a custom library. If container is not a + // valid container name, the operation is ignored. + virtual void set_container(absl::string_view container) = 0; + + virtual void SetExpressionContainer( + ExpressionContainer expression_container) = 0; + + // The current options for the TypeChecker being built. + virtual const CheckerOptions& options() const = 0; + + // Builds a new TypeChecker instance. + virtual absl::StatusOr> Build() = 0; + + // Returns a pointer to an arena that can be used to allocate memory for types + // that will be used by the TypeChecker being built. + // + // On Build(), the arena is transferred to the TypeChecker being built. + virtual google::protobuf::Arena* absl_nonnull arena() = 0; + + // The configured descriptor pool. + virtual const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool() + const = 0; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_CHECKER_TYPE_CHECKER_BUILDER_H_ diff --git a/checker/type_checker_builder_factory.cc b/checker/type_checker_builder_factory.cc new file mode 100644 index 000000000..23c411996 --- /dev/null +++ b/checker/type_checker_builder_factory.cc @@ -0,0 +1,56 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/type_checker_builder_factory.h" + +#include +#include + +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/status/statusor.h" +#include "checker/checker_options.h" +#include "checker/internal/type_checker_builder_impl.h" +#include "checker/type_checker_builder.h" +#include "internal/noop_delete.h" +#include "internal/status_macros.h" +#include "internal/well_known_types.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +absl::StatusOr> CreateTypeCheckerBuilder( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + const CheckerOptions& options) { + ABSL_DCHECK(descriptor_pool != nullptr); + return CreateTypeCheckerBuilder( + std::shared_ptr( + descriptor_pool, + internal::NoopDeleteFor()), + options); +} + +absl::StatusOr> CreateTypeCheckerBuilder( + absl_nonnull std::shared_ptr descriptor_pool, + const CheckerOptions& options) { + ABSL_DCHECK(descriptor_pool != nullptr); + // Verify the standard descriptors, we do not need to keep + // `well_known_types::Reflection` at the moment here. + CEL_RETURN_IF_ERROR( + well_known_types::Reflection().Initialize(descriptor_pool.get())); + return std::make_unique( + std::move(descriptor_pool), options); +} + +} // namespace cel diff --git a/checker/type_checker_builder_factory.h b/checker/type_checker_builder_factory.h new file mode 100644 index 000000000..3f830c7c7 --- /dev/null +++ b/checker/type_checker_builder_factory.h @@ -0,0 +1,59 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#ifndef THIRD_PARTY_CEL_CPP_CHECKER_TYPE_CHECKER_BUILDER_FACTORY_H_ +#define THIRD_PARTY_CEL_CPP_CHECKER_TYPE_CHECKER_BUILDER_FACTORY_H_ + +#include + +#include "absl/base/nullability.h" +#include "absl/status/statusor.h" +#include "checker/checker_options.h" +#include "checker/type_checker_builder.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +// Creates a new `TypeCheckerBuilder`. +// +// The builder implementation is thread-hostile and should only be used from a +// single thread, but the resulting `TypeChecker` instance is thread-safe. +// +// When passing a raw pointer to a descriptor pool, the descriptor pool must +// outlive the type checker builder and the type checker builder it creates. +// +// The descriptor pool must include the minimally necessary +// descriptors required by CEL. Those are the following: +// - google.protobuf.NullValue +// - google.protobuf.BoolValue +// - google.protobuf.Int32Value +// - google.protobuf.Int64Value +// - google.protobuf.UInt32Value +// - google.protobuf.UInt64Value +// - google.protobuf.FloatValue +// - google.protobuf.DoubleValue +// - google.protobuf.BytesValue +// - google.protobuf.StringValue +// - google.protobuf.Any +// - google.protobuf.Duration +// - google.protobuf.Timestamp +absl::StatusOr> CreateTypeCheckerBuilder( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + const CheckerOptions& options = {}); +absl::StatusOr> CreateTypeCheckerBuilder( + absl_nonnull std::shared_ptr descriptor_pool, + const CheckerOptions& options = {}); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_CHECKER_TYPE_CHECKER_BUILDER_FACTORY_H_ diff --git a/checker/type_checker_builder_factory_test.cc b/checker/type_checker_builder_factory_test.cc new file mode 100644 index 000000000..40406948d --- /dev/null +++ b/checker/type_checker_builder_factory_test.cc @@ -0,0 +1,852 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/type_checker_builder_factory.h" + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/strings/string_view.h" +#include "checker/checker_options.h" +#include "checker/internal/test_ast_helpers.h" +#include "checker/optional.h" +#include "checker/standard_library.h" +#include "checker/type_checker.h" +#include "checker/type_checker_builder.h" +#include "checker/validation_result.h" +#include "common/ast.h" +#include "common/decl.h" +#include "common/type.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::cel::checker_internal::MakeTestParsedAst; +using ::cel::internal::GetSharedTestingDescriptorPool; +using ::testing::ElementsAre; +using ::testing::HasSubstr; +using ::testing::Truly; + +TEST(TypeCheckerBuilderTest, AddVariable) { + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); + + ASSERT_THAT(builder->AddVariable(MakeVariableDecl("x", IntType())), IsOk()); + + ASSERT_OK_AND_ASSIGN(auto checker, builder->Build()); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("x")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, checker->Check(std::move(ast))); + EXPECT_TRUE(result.IsValid()); +} + +TEST(TypeCheckerBuilderTest, AddComplexType) { + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); + + MapType map_type(builder->arena(), StringType(), IntType()); + + ASSERT_THAT(builder->AddVariable(MakeVariableDecl("m", map_type)), IsOk()); + + ASSERT_OK_AND_ASSIGN(auto checker, builder->Build()); + builder.reset(); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("m.foo")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, checker->Check(std::move(ast))); + EXPECT_TRUE(result.IsValid()); +} + +TEST(TypeCheckerBuilderTest, TypeCheckersIndependent) { + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); + + MapType map_type(builder->arena(), StringType(), IntType()); + + ASSERT_THAT(builder->AddVariable(MakeVariableDecl("m", map_type)), IsOk()); + ASSERT_OK_AND_ASSIGN( + FunctionDecl fn, + MakeFunctionDecl( + "foo", MakeOverloadDecl("foo", IntType(), IntType(), IntType()))); + ASSERT_THAT(builder->AddFunction(std::move(fn)), IsOk()); + + ASSERT_OK_AND_ASSIGN(auto checker1, builder->Build()); + + ASSERT_THAT(builder->AddVariable(MakeVariableDecl("ns.m2", map_type)), + IsOk()); + builder->set_container("ns"); + ASSERT_OK_AND_ASSIGN(auto checker2, builder->Build()); + // Test for lifetime issues between separate type checker instances from the + // same builder. + builder.reset(); + + { + ASSERT_OK_AND_ASSIGN(auto ast1, MakeTestParsedAst("foo(m.bar, m.bar)")); + ASSERT_OK_AND_ASSIGN(auto ast2, MakeTestParsedAst("foo(m.bar, m2.bar)")); + + ASSERT_OK_AND_ASSIGN(ValidationResult result, + checker1->Check(std::move(ast1))); + EXPECT_TRUE(result.IsValid()); + ASSERT_OK_AND_ASSIGN(ValidationResult result2, + checker1->Check(std::move(ast2))); + EXPECT_FALSE(result2.IsValid()); + } + checker1.reset(); + + { + ASSERT_OK_AND_ASSIGN(auto ast1, MakeTestParsedAst("foo(m.bar, m.bar)")); + ASSERT_OK_AND_ASSIGN(auto ast2, MakeTestParsedAst("foo(m.bar, m2.bar)")); + + ASSERT_OK_AND_ASSIGN(ValidationResult result, + checker2->Check(std::move(ast1))); + EXPECT_TRUE(result.IsValid()); + ASSERT_OK_AND_ASSIGN(ValidationResult result2, + checker2->Check(std::move(ast2))); + EXPECT_TRUE(result2.IsValid()); + } +} + +TEST(TypeCheckerBuilderTest, AddVariableRedeclaredError) { + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); + + ASSERT_THAT(builder->AddVariable(MakeVariableDecl("x", IntType())), IsOk()); + // We resolve the variable declarations at the Build() call, so the error + // surfaces then. + ASSERT_THAT(builder->AddVariable(MakeVariableDecl("x", IntType())), IsOk()); + + EXPECT_THAT(builder->Build(), + StatusIs(absl::StatusCode::kAlreadyExists, + "variable 'x' declared multiple times")); +} + +TEST(TypeCheckerBuilderTest, AddFunction) { + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); + + ASSERT_OK_AND_ASSIGN( + auto fn_decl, + MakeFunctionDecl( + "add", MakeOverloadDecl("add_int", IntType(), IntType(), IntType()))); + + ASSERT_THAT(builder->AddFunction(fn_decl), IsOk()); + ASSERT_OK_AND_ASSIGN(auto checker, builder->Build()); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("add(1, 2)")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, checker->Check(std::move(ast))); + EXPECT_TRUE(result.IsValid()); +} + +TEST(TypeCheckerBuilderTest, AddFunctionRedeclaredError) { + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); + + ASSERT_OK_AND_ASSIGN( + auto fn_decl, + MakeFunctionDecl( + "add", MakeOverloadDecl("add_int", IntType(), IntType(), IntType()))); + + ASSERT_THAT(builder->AddFunction(fn_decl), IsOk()); + ASSERT_THAT(builder->AddFunction(fn_decl), IsOk()); + + EXPECT_THAT(builder->Build(), + StatusIs(absl::StatusCode::kAlreadyExists, + "function 'add' declared multiple times")); +} + +TEST(TypeCheckerBuilderTest, AddLibrary) { + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); + + ASSERT_OK_AND_ASSIGN( + auto fn_decl, + MakeFunctionDecl( + "add", MakeOverloadDecl("add_int", IntType(), IntType(), IntType()))); + + ASSERT_THAT(builder->AddLibrary({"", + [&](TypeCheckerBuilder& b) { + return builder->AddFunction(fn_decl); + }}), + + IsOk()); + ASSERT_OK_AND_ASSIGN(auto checker, builder->Build()); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("add(1, 2)")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, checker->Check(std::move(ast))); + EXPECT_TRUE(result.IsValid()); +} + +// Example test lib that adds: +// - add(int, int) -> int +// - add(double, double) -> double +// - sub(int, int) -> int +// - sub(double, double) -> double +absl::Status SubsetTestlibConfigurer(TypeCheckerBuilder& builder) { + absl::Status s; + CEL_ASSIGN_OR_RETURN( + FunctionDecl fn_decl, + MakeFunctionDecl( + "add", MakeOverloadDecl("add_int", IntType(), IntType(), IntType()), + MakeOverloadDecl("add_double", DoubleType(), DoubleType(), + DoubleType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(fn_decl))); + + CEL_ASSIGN_OR_RETURN( + fn_decl, + MakeFunctionDecl( + "sub", MakeOverloadDecl("sub_int", IntType(), IntType(), IntType()), + MakeOverloadDecl("sub_double", DoubleType(), DoubleType(), + DoubleType()))); + + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(fn_decl))); + + return absl::OkStatus(); +} + +CheckerLibrary SubsetTestlib() { return {"testlib", SubsetTestlibConfigurer}; } + +TEST(TypeCheckerBuilderTest, AddLibraryIncludeSubset) { + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); + + ASSERT_THAT(builder->AddLibrary(SubsetTestlib()), IsOk()); + ASSERT_THAT( + builder->AddLibrarySubset( + {"testlib", + [](absl::string_view /*function*/, const OverloadDecl& overload) { + return (overload.id() == "add_int" || overload.id() == "sub_int"); + }}), + IsOk()); + ASSERT_OK_AND_ASSIGN(auto checker, builder->Build()); + + std::vector results; + for (const auto& expr : + {"sub(1, 2)", "add(1, 2)", "sub(1.0, 2.0)", "add(1.0, 2.0)"}) { + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst(expr)); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + checker->Check(std::move(ast))); + results.push_back(std::move(result)); + } + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("add(1, 2)")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, checker->Check(std::move(ast))); + EXPECT_THAT(results, ElementsAre(Truly([](const ValidationResult& result) { + return result.IsValid(); + }), + Truly([](const ValidationResult& result) { + return result.IsValid(); + }), + Truly([](const ValidationResult& result) { + return !result.IsValid(); + }), + Truly([](const ValidationResult& result) { + return !result.IsValid(); + }))); +} + +TEST(TypeCheckerBuilderTest, AddLibraryExcludeSubset) { + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); + + ASSERT_THAT(builder->AddLibrary(SubsetTestlib()), IsOk()); + ASSERT_THAT( + builder->AddLibrarySubset( + {"testlib", + [](absl::string_view /*function*/, const OverloadDecl& overload) { + return (overload.id() != "add_int" && overload.id() != "sub_int"); + }}), + IsOk()); + ASSERT_OK_AND_ASSIGN(auto checker, builder->Build()); + + std::vector results; + for (const auto& expr : + {"sub(1, 2)", "add(1, 2)", "sub(1.0, 2.0)", "add(1.0, 2.0)"}) { + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst(expr)); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + checker->Check(std::move(ast))); + results.push_back(std::move(result)); + } + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("add(1, 2)")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, checker->Check(std::move(ast))); + EXPECT_THAT(results, ElementsAre(Truly([](const ValidationResult& result) { + return !result.IsValid(); + }), + Truly([](const ValidationResult& result) { + return !result.IsValid(); + }), + Truly([](const ValidationResult& result) { + return result.IsValid(); + }), + Truly([](const ValidationResult& result) { + return result.IsValid(); + }))); +} + +TEST(TypeCheckerBuilderTest, AddLibrarySubsetRemoveAllOvl) { + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); + + ASSERT_THAT(builder->AddLibrary(SubsetTestlib()), IsOk()); + ASSERT_THAT(builder->AddLibrarySubset({"testlib", + [](absl::string_view function, + const OverloadDecl& /*overload*/) { + return function != "add"; + }}), + IsOk()); + ASSERT_OK_AND_ASSIGN(auto checker, builder->Build()); + + std::vector results; + for (const auto& expr : + {"sub(1, 2)", "add(1, 2)", "sub(1.0, 2.0)", "add(1.0, 2.0)"}) { + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst(expr)); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + checker->Check(std::move(ast))); + results.push_back(std::move(result)); + } + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("add(1, 2)")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, checker->Check(std::move(ast))); + EXPECT_THAT(results, ElementsAre(Truly([](const ValidationResult& result) { + return result.IsValid(); + }), + Truly([](const ValidationResult& result) { + return !result.IsValid(); + }), + Truly([](const ValidationResult& result) { + return result.IsValid(); + }), + Truly([](const ValidationResult& result) { + return !result.IsValid(); + }))); +} + +TEST(TypeCheckerBuilderTest, AddLibraryOneSubsetPerLibraryId) { + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); + + ASSERT_THAT(builder->AddLibrary(SubsetTestlib()), IsOk()); + ASSERT_THAT( + builder->AddLibrarySubset( + {"testlib", [](absl::string_view function, + const OverloadDecl& /*overload*/) { return true; }}), + IsOk()); + EXPECT_THAT( + builder->AddLibrarySubset( + {"testlib", [](absl::string_view function, + const OverloadDecl& /*overload*/) { return true; }}), + StatusIs(absl::StatusCode::kAlreadyExists)); +} + +TEST(TypeCheckerBuilderTest, AddLibrarySubsetLibraryIdRequireds) { + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); + + ASSERT_THAT(builder->AddLibrary(SubsetTestlib()), IsOk()); + EXPECT_THAT(builder->AddLibrarySubset({"", + [](absl::string_view function, + const OverloadDecl& /*overload*/) { + return function == "add"; + }}), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(TypeCheckerBuilderTest, AddContextDeclaration) { + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); + + ASSERT_OK_AND_ASSIGN( + auto fn_decl, + MakeFunctionDecl("increment", MakeOverloadDecl("increment_int", IntType(), + IntType()))); + + ASSERT_THAT(builder->AddContextDeclaration( + "cel.expr.conformance.proto3.TestAllTypes"), + IsOk()); + ASSERT_THAT(builder->AddFunction(fn_decl), IsOk()); + + ASSERT_OK_AND_ASSIGN(auto checker, builder->Build()); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("increment(single_int64)")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, checker->Check(std::move(ast))); + EXPECT_TRUE(result.IsValid()); +} + +TEST(TypeCheckerBuilderTest, AddContextDeclarationWithProtoTypeMask) { + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); + + ASSERT_OK_AND_ASSIGN( + auto fn_decl, + MakeFunctionDecl("increment", MakeOverloadDecl("increment_int", IntType(), + IntType()))); + + ASSERT_THAT(builder->AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"single_int64"}), + IsOk()); + ASSERT_THAT(builder->AddFunction(fn_decl), IsOk()); + + ASSERT_OK_AND_ASSIGN(auto checker, builder->Build()); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("increment(single_int64)")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, checker->Check(std::move(ast))); + EXPECT_TRUE(result.IsValid()); +} + +TEST(TypeCheckerBuilderTest, WellKnownTypeContextDeclarationError) { + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); + + ASSERT_THAT(builder->AddContextDeclaration("google.protobuf.Any"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("'google.protobuf.Any' is not a struct"))); +} + +TEST(TypeCheckerBuilderTest, AllowWellKnownTypeContextDeclaration) { + CheckerOptions options; + options.allow_well_known_type_context_declarations = true; + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool(), options)); + + ASSERT_THAT(builder->AddContextDeclaration("google.protobuf.Any"), IsOk()); + ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker, + builder->Build()); + ASSERT_OK_AND_ASSIGN( + auto ast, + MakeTestParsedAst( + R"cel(value == b'' && type_url == 'type.googleapis.com/google.protobuf.Duration')cel")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + type_checker->Check(std::move(ast))); + + ASSERT_TRUE(result.IsValid()); +} + +TEST(TypeCheckerBuilderTest, + AllowWellKnownTypeContextDeclarationWithProtoTypeMask) { + CheckerOptions options; + options.allow_well_known_type_context_declarations = true; + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool(), options)); + + ASSERT_THAT(builder->AddContextDeclarationWithProtoTypeMask( + "google.protobuf.Any", {"value"}), + IsOk()); + ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker, + builder->Build()); + // Visible field: value + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("value")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + type_checker->Check(std::move(ast))); + ASSERT_TRUE(result.IsValid()); + // Not visible field: type_url + ASSERT_OK_AND_ASSIGN(ast, MakeTestParsedAst("type_url")); + ASSERT_OK_AND_ASSIGN(result, type_checker->Check(std::move(ast))); + ASSERT_FALSE(result.IsValid()); +} + +TEST(TypeCheckerBuilderTest, AllowWellKnownTypeContextDeclarationStruct) { + CheckerOptions options; + options.allow_well_known_type_context_declarations = true; + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool(), options)); + + ASSERT_THAT(builder->AddContextDeclaration("google.protobuf.Struct"), IsOk()); + ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker, + builder->Build()); + ASSERT_OK_AND_ASSIGN( + auto ast, + MakeTestParsedAst(R"cel(fields.foo.bar_list.exists(x, x == 1))cel")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + type_checker->Check(std::move(ast))); + + ASSERT_TRUE(result.IsValid()); +} + +TEST(TypeCheckerBuilderTest, AllowWellKnownTypeContextDeclarationValue) { + CheckerOptions options; + options.allow_well_known_type_context_declarations = true; + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool(), options)); + + ASSERT_THAT(builder->AddContextDeclaration("google.protobuf.Value"), IsOk()); + ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker, + builder->Build()); + ASSERT_OK_AND_ASSIGN( + auto ast, MakeTestParsedAst( + // Note: one of fields are all added with safe traversal, so + // we lose the union discriminator information. + R"cel( + null_value == 0 && + number_value == 0.0 && + string_value == '' && + list_value == [] && + struct_value == {} && + bool_value == false)cel")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + type_checker->Check(std::move(ast))); + + ASSERT_TRUE(result.IsValid()); +} + +TEST(TypeCheckerBuilderTest, AllowWellKnownTypeContextDeclarationInt64Value) { + CheckerOptions options; + options.allow_well_known_type_context_declarations = true; + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool(), options)); + + ASSERT_THAT(builder->AddContextDeclaration("google.protobuf.Int64Value"), + IsOk()); + ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker, + builder->Build()); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst(R"cel(value == 0)cel")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + type_checker->Check(std::move(ast))); + + ASSERT_TRUE(result.IsValid()); +} + +TEST(TypeCheckerBuilderTest, ContextDeclarationWithJsonName) { + CheckerOptions options; + options.use_json_field_names = true; + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool(), options)); + + ASSERT_THAT(builder->AddContextDeclaration("cel.cpp.testutil.TestJsonNames"), + IsOk()); + ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker, + builder->Build()); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst( + R"cel(int32_snake_case_json_name == 1 && + int64CamelCaseJsonName == 2 && + uint32DefaultJsonName == 3u && + // `uint64-custom-json-name` == 4u && + single_string == 'shadows' && + singleString == 'shadowed')cel")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + type_checker->Check(std::move(ast))); + + ASSERT_TRUE(result.IsValid()); + ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); + EXPECT_EQ(checked_ast->GetReturnType(), TypeSpec(PrimitiveType::kBool)); + EXPECT_THAT( + checked_ast->source_info().extensions(), + ElementsAre(cel::ExtensionSpec( + "json_name", std::make_unique(1, 1), + {cel::ExtensionSpec::Component::kRuntime}))); +} + +TEST(TypeCheckerBuilderTest, JsonFieldNameOptionStructCreation) { + CheckerOptions options; + options.use_json_field_names = true; + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool(), options)); + ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker, + builder->Build()); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst( + R"cel(cel.cpp.testutil.TestJsonNames{ + int32_snake_case_json_name: 1, + int64CamelCaseJsonName: 2, + uint32DefaultJsonName: 3u, + `uint64-custom-json-name`: 4u, + single_string: 'shadows', + singleString: 'shadowed' + })cel")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + type_checker->Check(std::move(ast))); + + ASSERT_TRUE(result.IsValid()); + + ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); + EXPECT_EQ(checked_ast->GetReturnType(), + TypeSpec(MessageTypeSpec("cel.cpp.testutil.TestJsonNames"))); + EXPECT_THAT( + checked_ast->source_info().extensions(), + ElementsAre(cel::ExtensionSpec( + "json_name", std::make_unique(1, 1), + {cel::ExtensionSpec::Component::kRuntime}))); +} + +TEST(TypeCheckerBuilderTest, JsonFieldNameOptionFieldAccess) { + CheckerOptions options; + options.use_json_field_names = true; + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool(), options)); + ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); + ASSERT_THAT( + builder->AddVariable(MakeVariableDecl( + "jsonObj", + cel::MessageType(builder->descriptor_pool()->FindMessageTypeByName( + "cel.cpp.testutil.TestJsonNames")))), + IsOk()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker, + builder->Build()); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst( + R"cel( + jsonObj.int32_snake_case_json_name == 1 && + jsonObj.int64CamelCaseJsonName == 2 && + jsonObj.uint32DefaultJsonName == 3u && + jsonObj.`uint64-custom-json-name` == 4u && + jsonObj.single_string == 'shadows' && + jsonObj.singleString == 'shadowed' && + jsonObj.`cel.cpp.testutil.int32_snake_case_ext` == 5 && + jsonObj.`cel.cpp.testutil.int64CamelCaseExt` == 6 + )cel")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + type_checker->Check(std::move(ast))); + + ASSERT_TRUE(result.IsValid()) << result.FormatError(); + ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); + EXPECT_EQ(checked_ast->GetReturnType(), TypeSpec(PrimitiveType::kBool)); + EXPECT_THAT( + checked_ast->source_info().extensions(), + ElementsAre(cel::ExtensionSpec( + "json_name", std::make_unique(1, 1), + {cel::ExtensionSpec::Component::kRuntime}))); +} + +TEST(TypeCheckerBuilderTest, AddLibraryRedeclaredError) { + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); + + ASSERT_OK_AND_ASSIGN( + auto fn_decl, + MakeFunctionDecl( + "add", MakeOverloadDecl("add_int", IntType(), IntType(), IntType()))); + + ASSERT_THAT(builder->AddLibrary({"testlib", + [&](TypeCheckerBuilder& b) { + return builder->AddFunction(fn_decl); + }}), + IsOk()); + EXPECT_THAT(builder->AddLibrary({"testlib", + [&](TypeCheckerBuilder& b) { + return builder->AddFunction(fn_decl); + }}), + StatusIs(absl::StatusCode::kAlreadyExists, HasSubstr("testlib"))); +} + +TEST(TypeCheckerBuilderTest, BuildForwardsLibraryErrors) { + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); + + ASSERT_OK_AND_ASSIGN( + auto fn_decl, + MakeFunctionDecl( + "add", MakeOverloadDecl("add_int", IntType(), IntType(), IntType()))); + + ASSERT_THAT(builder->AddLibrary({"", + [&](TypeCheckerBuilder& b) { + return builder->AddFunction(fn_decl); + }}), + IsOk()); + ASSERT_THAT(builder->AddLibrary({"", + [](TypeCheckerBuilder& b) { + return absl::InternalError("test error"); + }}), + IsOk()); + + EXPECT_THAT(builder->Build(), + StatusIs(absl::StatusCode::kInternal, "test error")); +} + +TEST(TypeCheckerBuilderTest, AddFunctionOverlapsWithStdMacroError) { + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); + + ASSERT_OK_AND_ASSIGN( + auto fn_decl, MakeFunctionDecl("map", MakeMemberOverloadDecl( + "ovl_3", ListType(), ListType(), + DynType(), DynType()))); + + EXPECT_THAT(builder->AddFunction(fn_decl), + StatusIs(absl::StatusCode::kInvalidArgument, + "overload for name 'map' with 3 argument(s) overlaps " + "with predefined macro")); + + fn_decl.set_name("filter"); + + EXPECT_THAT(builder->AddFunction(fn_decl), + StatusIs(absl::StatusCode::kInvalidArgument, + "overload for name 'filter' with 3 argument(s) overlaps " + "with predefined macro")); + + fn_decl.set_name("exists"); + + EXPECT_THAT(builder->AddFunction(fn_decl), + StatusIs(absl::StatusCode::kInvalidArgument, + "overload for name 'exists' with 3 argument(s) overlaps " + "with predefined macro")); + + fn_decl.set_name("exists_one"); + + EXPECT_THAT(builder->AddFunction(fn_decl), + StatusIs(absl::StatusCode::kInvalidArgument, + "overload for name 'exists_one' with 3 argument(s) " + "overlaps with predefined macro")); + + fn_decl.set_name("all"); + + EXPECT_THAT(builder->AddFunction(fn_decl), + StatusIs(absl::StatusCode::kInvalidArgument, + "overload for name 'all' with 3 argument(s) overlaps " + "with predefined macro")); + + fn_decl.set_name("optMap"); + + EXPECT_THAT(builder->AddFunction(fn_decl), + StatusIs(absl::StatusCode::kInvalidArgument, + "overload for name 'optMap' with 3 argument(s) overlaps " + "with predefined macro")); + + fn_decl.set_name("optFlatMap"); + + EXPECT_THAT( + builder->AddFunction(fn_decl), + StatusIs(absl::StatusCode::kInvalidArgument, + "overload for name 'optFlatMap' with 3 argument(s) overlaps " + "with predefined macro")); + + ASSERT_OK_AND_ASSIGN( + fn_decl, MakeFunctionDecl( + "has", MakeOverloadDecl("ovl_1", BoolType(), DynType()))); + + EXPECT_THAT(builder->AddFunction(fn_decl), + StatusIs(absl::StatusCode::kInvalidArgument, + "overload for name 'has' with 1 argument(s) overlaps " + "with predefined macro")); + + ASSERT_OK_AND_ASSIGN( + fn_decl, MakeFunctionDecl("map", MakeMemberOverloadDecl( + "ovl_4", ListType(), ListType(), + + DynType(), DynType(), DynType()))); + + EXPECT_THAT(builder->AddFunction(fn_decl), + StatusIs(absl::StatusCode::kInvalidArgument, + "overload for name 'map' with 4 argument(s) overlaps " + "with predefined macro")); +} + +TEST(TypeCheckerBuilderTest, AddFunctionNoOverlapWithStdMacroError) { + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); + + ASSERT_OK_AND_ASSIGN( + auto fn_decl, + MakeFunctionDecl("has", MakeMemberOverloadDecl("ovl", BoolType(), + DynType(), StringType()))); + + EXPECT_THAT(builder->AddFunction(fn_decl), IsOk()); +} + +TEST(TypeCheckerBuilderTest, ToBuilderIndependenceAndInheritance) { + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); + + ASSERT_THAT(builder->AddVariable(MakeVariableDecl("x", IntType())), IsOk()); + ASSERT_OK_AND_ASSIGN( + auto fn_decl, + MakeFunctionDecl("addOne", + MakeOverloadDecl("addOne_int", IntType(), IntType()))); + ASSERT_THAT(builder->AddFunction(fn_decl), IsOk()); + ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); + + ASSERT_OK_AND_ASSIGN(auto checker1, builder->Build()); + + // Exercise checker1. + { + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("addOne(x)")); + ASSERT_OK_AND_ASSIGN(ValidationResult result1, + checker1->Check(std::move(ast))); + EXPECT_TRUE(result1.IsValid()); + } + + // Start new builder via ToBuilder. + auto builder2 = checker1->ToBuilder(); + ASSERT_THAT(builder2->AddVariable(MakeVariableDecl("y", IntType())), IsOk()); + ASSERT_THAT(builder2->AddLibrary(OptionalCheckerLibrary()), IsOk()); + builder2->SetExpectedType(IntType()); + + ASSERT_OK_AND_ASSIGN(auto checker2, builder2->Build()); + + { + ASSERT_OK_AND_ASSIGN( + auto ast, MakeTestParsedAst("optional.of(addOne(x)).orValue(0) + y")); + ASSERT_OK_AND_ASSIGN(ValidationResult result2, + checker2->Check(std::move(ast))); + EXPECT_TRUE(result2.IsValid()); + } + + // Demonstrate checker1 is unmodified and independent (still does not know + // about y). + { + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("y")); + ASSERT_OK_AND_ASSIGN(ValidationResult result_y_checker1_again, + checker1->Check(std::move(ast))); + EXPECT_FALSE(result_y_checker1_again.IsValid()); + } + + // Same for optional library functions. + { + ASSERT_OK_AND_ASSIGN(auto ast, + MakeTestParsedAst("optional.none().orValue(x)")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + checker1->Check(std::move(ast))); + EXPECT_FALSE(result.IsValid()); + } +} + +} // namespace +} // namespace cel diff --git a/checker/type_checker_subset_factory.cc b/checker/type_checker_subset_factory.cc new file mode 100644 index 000000000..1b146c5a5 --- /dev/null +++ b/checker/type_checker_subset_factory.cc @@ -0,0 +1,67 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/type_checker_subset_factory.h" + +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "checker/type_checker_builder.h" +#include "common/decl.h" +#include "common/signature.h" + +namespace cel { + +TypeCheckerSubset::FunctionPredicate IncludeOverloadsByIdPredicate( + absl::flat_hash_set overload_ids) { + return [overload_ids = std::move(overload_ids)]( + absl::string_view function, const OverloadDecl& overload) { + if (overload_ids.contains(overload.id())) { + return true; + } + auto signature = + MakeOverloadSignature(function, overload.args(), overload.member()); + return signature.ok() && overload_ids.contains(*signature); + }; +} + +TypeCheckerSubset::FunctionPredicate IncludeOverloadsByIdPredicate( + absl::Span overload_ids) { + return IncludeOverloadsByIdPredicate(absl::flat_hash_set( + overload_ids.begin(), overload_ids.end())); +} + +TypeCheckerSubset::FunctionPredicate ExcludeOverloadsByIdPredicate( + absl::flat_hash_set overload_ids) { + return [overload_ids = std::move(overload_ids)]( + absl::string_view function, const OverloadDecl& overload) { + if (overload_ids.contains(overload.id())) { + return false; + } + auto signature = + MakeOverloadSignature(function, overload.args(), overload.member()); + return !signature.ok() || !overload_ids.contains(*signature); + }; +} + +TypeCheckerSubset::FunctionPredicate ExcludeOverloadsByIdPredicate( + absl::Span overload_ids) { + return ExcludeOverloadsByIdPredicate(absl::flat_hash_set( + overload_ids.begin(), overload_ids.end())); +} + +} // namespace cel diff --git a/checker/type_checker_subset_factory.h b/checker/type_checker_subset_factory.h new file mode 100644 index 000000000..5db5660bd --- /dev/null +++ b/checker/type_checker_subset_factory.h @@ -0,0 +1,45 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Factory functions for creating typical type checker library subsets. + +#ifndef THIRD_PARTY_CEL_CPP_CHECKER_TYPE_CHECKER_SUBSET_FACTORY_H_ +#define THIRD_PARTY_CEL_CPP_CHECKER_TYPE_CHECKER_SUBSET_FACTORY_H_ + +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "checker/type_checker_builder.h" + +namespace cel { + +// Subsets a type checker library to only include the given overload ids. +TypeCheckerSubset::FunctionPredicate IncludeOverloadsByIdPredicate( + absl::flat_hash_set overload_ids); + +TypeCheckerSubset::FunctionPredicate IncludeOverloadsByIdPredicate( + absl::Span overload_ids); + +// Subsets a type checker library to exclude the given overload ids. +TypeCheckerSubset::FunctionPredicate ExcludeOverloadsByIdPredicate( + absl::flat_hash_set overload_ids); + +TypeCheckerSubset::FunctionPredicate ExcludeOverloadsByIdPredicate( + absl::Span overload_ids); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_CHECKER_TYPE_CHECKER_SUBSET_FACTORY_H_ diff --git a/checker/type_checker_subset_factory_test.cc b/checker/type_checker_subset_factory_test.cc new file mode 100644 index 000000000..5b644ec7c --- /dev/null +++ b/checker/type_checker_subset_factory_test.cc @@ -0,0 +1,149 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/type_checker_subset_factory.h" + +#include + +#include "absl/status/status_matchers.h" +#include "absl/strings/string_view.h" +#include "checker/validation_result.h" +#include "common/standard_definitions.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "compiler/standard_library.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" + +using ::absl_testing::IsOk; + +namespace cel { +namespace { + +TEST(TypeCheckerSubsetFactoryTest, IncludeOverloadsByIdPredicate) { + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + NewCompilerBuilder(internal::GetSharedTestingDescriptorPool())); + absl::string_view allowlist[] = { + StandardOverloadIds::kNot, + StandardOverloadIds::kAnd, + StandardOverloadIds::kOr, + StandardOverloadIds::kConditional, + StandardOverloadIds::kEquals, + StandardOverloadIds::kNotEquals, + StandardOverloadIds::kNotStrictlyFalse, + "matches(string,string)", + "string.matches(string)", + }; + ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk()); + ASSERT_THAT(builder->GetCheckerBuilder().AddLibrarySubset({ + "stdlib", + IncludeOverloadsByIdPredicate(allowlist), + }), + IsOk()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, builder->Build()); + + ASSERT_OK_AND_ASSIGN( + ValidationResult r, + compiler->Compile( + "!true || !false && (false) ? true : false && 1 == 2 || 3.0 != 2.1")); + + EXPECT_TRUE(r.IsValid()); + + ASSERT_OK_AND_ASSIGN( + r, compiler->Compile("[true, false, true, false].exists(x, x && !x)")); + + EXPECT_TRUE(r.IsValid()); + + // Allowed by signature. + ASSERT_OK_AND_ASSIGN(r, compiler->Compile("r'foo.*'.matches('foobar')")); + EXPECT_TRUE(r.IsValid()); + + ASSERT_OK_AND_ASSIGN(r, compiler->Compile("matches(r'foo.*', 'foobar')")); + EXPECT_TRUE(r.IsValid()); + + // Not in allowlist. + ASSERT_OK_AND_ASSIGN(r, compiler->Compile("1 + 2 < 3")); + EXPECT_FALSE(r.IsValid()); + + ASSERT_OK_AND_ASSIGN(r, compiler->Compile("'abc' + 'def'")); + EXPECT_FALSE(r.IsValid()); +} + +TEST(TypeCheckerSubsetFactoryTest, ExcludeOverloadsByIdPredicate) { + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + NewCompilerBuilder(internal::GetSharedTestingDescriptorPool())); + absl::string_view exclude_list[] = { + StandardOverloadIds::kMatches, + StandardOverloadIds::kMatchesMember, + "size(string)", + "string.size()", + }; + ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk()); + ASSERT_THAT(builder->GetCheckerBuilder().AddLibrarySubset({ + "stdlib", + ExcludeOverloadsByIdPredicate(exclude_list), + }), + IsOk()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, builder->Build()); + + ASSERT_OK_AND_ASSIGN( + ValidationResult r, + compiler->Compile( + "!true || !false && (false) ? true : false && 1 == 2 || 3.0 != 2.1")); + + EXPECT_TRUE(r.IsValid()); + + ASSERT_OK_AND_ASSIGN( + r, compiler->Compile("[true, false, true, false].exists(x, x && !x)")); + + EXPECT_TRUE(r.IsValid()); + + // Allowed. + ASSERT_OK_AND_ASSIGN(r, compiler->Compile("1 + 2 < 3")); + EXPECT_TRUE(r.IsValid()); + + ASSERT_OK_AND_ASSIGN(r, compiler->Compile("'abc' + 'def'")); + EXPECT_TRUE(r.IsValid()); + + // Excluded by ID. + ASSERT_OK_AND_ASSIGN(r, compiler->Compile("r'foo.*'.matches('foobar')")); + EXPECT_FALSE(r.IsValid()); + + ASSERT_OK_AND_ASSIGN(r, compiler->Compile("matches(r'foo.*', 'foobar')")); + EXPECT_FALSE(r.IsValid()); + + // Excluded by signature (top-level function). + ASSERT_OK_AND_ASSIGN(r, compiler->Compile("size('abc')")); + EXPECT_FALSE(r.IsValid()); + + // Allowed (other overloads of size). + ASSERT_OK_AND_ASSIGN(r, compiler->Compile("size([1, 2, 3])")); + EXPECT_TRUE(r.IsValid()); + + // Excluded by signature (member function). + ASSERT_OK_AND_ASSIGN(r, compiler->Compile("'abc'.size()")); + EXPECT_FALSE(r.IsValid()); + + // Allowed (other overloads of size member). + ASSERT_OK_AND_ASSIGN(r, compiler->Compile("[1, 2, 3].size()")); + EXPECT_TRUE(r.IsValid()); +} + +} // namespace + +} // namespace cel diff --git a/checker/validation_result.cc b/checker/validation_result.cc new file mode 100644 index 000000000..88d52932a --- /dev/null +++ b/checker/validation_result.cc @@ -0,0 +1,32 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/validation_result.h" + +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "checker/type_check_issue.h" + +namespace cel { + +std::string ValidationResult::FormatError() const { + return absl::StrJoin( + issues_, "\n", [this](std::string* out, const TypeCheckIssue& issue) { + absl::StrAppend(out, issue.ToDisplayString(source_.get())); + }); +} + +} // namespace cel diff --git a/checker/validation_result.h b/checker/validation_result.h new file mode 100644 index 000000000..7417e9969 --- /dev/null +++ b/checker/validation_result.h @@ -0,0 +1,117 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_CHECKER_VALIDATION_RESULT_H_ +#define THIRD_PARTY_CEL_CPP_CHECKER_VALIDATION_RESULT_H_ + +#include +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "checker/type_check_issue.h" +#include "common/ast.h" +#include "common/decl.h" +#include "common/source.h" +#include "common/type.h" + +namespace cel { + +// ValidationResult holds the result of type checking. +// +// Error states are captured as type check issues where possible. +class ValidationResult { + public: + using TypeMap = absl::flat_hash_map; + + ValidationResult(std::unique_ptr ast, std::vector issues) + : ast_(std::move(ast)), issues_(std::move(issues)) {} + + explicit ValidationResult(std::vector issues) + : ast_(nullptr), issues_(std::move(issues)) {} + + bool IsValid() const { return ast_ != nullptr; } + + // Returns the AST if validation was successful. + // + // This is a non-null pointer if IsValid() is true. + const Ast* absl_nullable GetAst() const { return ast_.get(); } + + absl::StatusOr> ReleaseAst() { + if (ast_ == nullptr) { + return absl::FailedPreconditionError( + "ValidationResult is empty. Check for TypeCheckIssues."); + } + return std::move(ast_); + } + + absl::Span GetIssues() const { return issues_; } + + void AddIssue(TypeCheckIssue issue) { issues_.push_back(std::move(issue)); } + + // The source expression may optionally be set if it is available. + const cel::Source* absl_nullable GetSource() const { return source_.get(); } + + void SetSource(std::unique_ptr source) { + source_ = std::move(source); + } + + absl_nullable std::unique_ptr ReleaseSource() { + return std::move(source_); + } + + // Returns the resolved type map for the AST. + // + // Only populated if the AST was checked with an explicit arena. + // + // The type entries may have storage in the arena or reference type + // information from the type checker that produced the AST. This means the map + // is only valid as long as both the type checker and the arena are valid. + const TypeMap& GetResolvedTypeMap() const { return resolved_type_map_; } + void SetResolvedTypeMap(TypeMap resolved_type_map) { + resolved_type_map_ = std::move(resolved_type_map); + } + + // Returns a string representation of the issues in the result suitable for + // display. + // + // The result is empty if no issues are present. + // + // The result is formatted similarly to CEL-Java and CEL-Go, but we do not + // give strong guarantees on the format or stability. + // + // Example: + // + // ERROR: :1:3: Issue1 + // | source.cel + // | ..^ + // INFORMATION: :-1:-1: Issue2 + std::string FormatError() const; + + private: + absl_nullable std::unique_ptr ast_; + TypeMap resolved_type_map_; + std::vector issues_; + absl_nullable std::unique_ptr source_; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_CHECKER_VALIDATION_RESULT_H_ diff --git a/checker/validation_result_test.cc b/checker/validation_result_test.cc new file mode 100644 index 000000000..dd9b05a4c --- /dev/null +++ b/checker/validation_result_test.cc @@ -0,0 +1,89 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/validation_result.h" + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "checker/type_check_issue.h" +#include "common/ast.h" +#include "common/source.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOkAndHolds; +using ::absl_testing::StatusIs; +using ::testing::_; +using ::testing::IsNull; +using ::testing::NotNull; +using ::testing::SizeIs; + +using Severity = TypeCheckIssue::Severity; + +TEST(ValidationResultTest, IsValidWithAst) { + ValidationResult result(std::make_unique(), {}); + EXPECT_TRUE(result.IsValid()); + EXPECT_THAT(result.GetAst(), NotNull()); + EXPECT_THAT(result.ReleaseAst(), IsOkAndHolds(NotNull())); +} + +TEST(ValidationResultTest, IsNotValidWithoutAst) { + ValidationResult result({}); + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.GetAst(), IsNull()); + EXPECT_THAT(result.ReleaseAst(), + StatusIs(absl::StatusCode::kFailedPrecondition, _)); +} + +TEST(ValidationResultTest, GetIssues) { + ValidationResult result( + {TypeCheckIssue::CreateError({-1, -1}, "Issue1"), + TypeCheckIssue(Severity::kInformation, {-1, -1}, "Issue2")}); + EXPECT_FALSE(result.IsValid()); + + ASSERT_THAT(result.GetIssues(), SizeIs(2)); + + EXPECT_THAT(result.GetIssues()[0].message(), "Issue1"); + EXPECT_THAT(result.GetIssues()[0].severity(), Severity::kError); + + EXPECT_THAT(result.GetIssues()[1].message(), "Issue2"); + EXPECT_THAT(result.GetIssues()[1].severity(), Severity::kInformation); +} + +TEST(ValidationResultTest, FormatError) { + ValidationResult result( + {TypeCheckIssue::CreateError({1, 2}, "Issue1"), + TypeCheckIssue(Severity::kInformation, {-1, -1}, "Issue2")}); + EXPECT_FALSE(result.IsValid()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr source, + NewSource("source.cel", "")); + result.SetSource(std::move(source)); + + ASSERT_THAT(result.GetIssues(), SizeIs(2)); + + EXPECT_THAT(result.FormatError(), + "ERROR: :1:3: Issue1\n" + " | source.cel\n" + " | ..^\n" + "INFORMATION: :-1:-1: Issue2"); +} + +} // namespace +} // namespace cel diff --git a/cloudbuild.yaml b/cloudbuild.yaml index 5845145b7..dec359f25 100644 --- a/cloudbuild.yaml +++ b/cloudbuild.yaml @@ -1,9 +1,41 @@ steps: -- name: 'gcr.io/cel-analysis/bazel:bionic-3.0.0' - entrypoint: bazel - args: ['test', '--test_output=errors', '...'] - id: bazel-test +- name: 'gcr.io/cel-analysis/cel-cpp/ubuntu_floor@sha256:211a0c505b361d987b3d8b08a5144a84e62cb95edc3f897fe46d5cd3f556f79d' + args: + - '--output_base=/bazel' # This is mandatory to avoid steps accidently sharing data. + - 'test' + - '...' + - '--enable_bzlmod' + - '--copt=-Wno-deprecated-declarations' + - '--compilation_mode=fastbuild' + - '--test_output=errors' + - '--show_timestamps' + - '--test_tag_filters=-benchmark,-notap' + - '--jobs=HOST_CPUS*.5' + - '--local_ram_resources=HOST_RAM*.4' + - '--remote_cache=https://storage.googleapis.com/cel-cpp-remote-cache' + - '--google_default_credentials' + id: gcc-9 + waitFor: ['-'] +- name: 'gcr.io/cel-analysis/cel-cpp/ubuntu_floor@sha256:211a0c505b361d987b3d8b08a5144a84e62cb95edc3f897fe46d5cd3f556f79d' + env: + - 'CC=clang-11' + - 'CXX=clang++-11' + args: + - '--output_base=/bazel' # This is mandatory to avoid steps accidently sharing data. + - 'test' + - '...' + - '--enable_bzlmod' + - '--copt=-Wno-deprecated-declarations' + - '--compilation_mode=fastbuild' + - '--test_output=errors' + - '--show_timestamps' + - '--test_tag_filters=-benchmark,-notap' + - '--jobs=HOST_CPUS*.5' + - '--local_ram_resources=HOST_RAM*.4' + - '--remote_cache=https://storage.googleapis.com/cel-cpp-remote-cache' + - '--google_default_credentials' + id: clang-11 waitFor: ['-'] timeout: 1h options: - machineType: 'N1_HIGHCPU_8' + machineType: 'E2_HIGHCPU_32' diff --git a/codelab/BUILD b/codelab/BUILD new file mode 100644 index 000000000..69c2825e2 --- /dev/null +++ b/codelab/BUILD @@ -0,0 +1,302 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +licenses(["notice"]) + +package(default_visibility = ["//visibility:public"]) + +exports_files( + srcs = glob([ + "exercise*.h", + "exercise*_test.cc", + ]), + visibility = ["//codelab/solutions:__pkg__"], +) + +# Exclude tests from tap and glob runs since they start failing for the codelab. +# The solutions directory has test targets that are included to catch breaking changes. +EXERCISE_TEST_TAGS = [ + "manual", + "notap", + "norapid", +] + +cc_library( + name = "exercise1", + srcs = ["exercise1.cc"], + hdrs = ["exercise1.h"], + tags = [ + "manual", + "nobuilder", + ], + deps = [ + "//eval/public:activation", + "//eval/public:builtin_func_registrar", + "//eval/public:cel_expr_builder_factory", + "//eval/public:cel_expression", + "//eval/public:cel_options", + "//eval/public:cel_value", + "//internal:status_macros", + "//parser", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "exercise1_test", + srcs = ["exercise1_test.cc"], + tags = EXERCISE_TEST_TAGS, + deps = [ + ":exercise1", + "//internal:testing", + "@com_google_absl//absl/status", + ], +) + +cc_library( + name = "exercise2", + srcs = ["exercise2.cc"], + hdrs = ["exercise2.h"], + deps = [ + ":cel_compiler", + "//compiler", + "//compiler:compiler_factory", + "//compiler:standard_library", + "//eval/public:activation", + "//eval/public:builtin_func_registrar", + "//eval/public:cel_expr_builder_factory", + "//eval/public:cel_expression", + "//eval/public:cel_options", + "//eval/public:cel_value", + "//internal:status_macros", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_googleapis//google/rpc/context:attribute_context_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "exercise2_test", + srcs = ["exercise2_test.cc"], + tags = EXERCISE_TEST_TAGS, + deps = [ + ":exercise2", + "//internal:testing", + "@com_google_absl//absl/status", + "@com_google_googleapis//google/rpc/context:attribute_context_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "exercise3_test", + srcs = ["exercise3_test.cc"], + tags = EXERCISE_TEST_TAGS, + deps = [ + ":exercise2", + "//internal:testing", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_googleapis//google/rpc/context:attribute_context_cc_proto", + ], +) + +cc_library( + name = "cel_compiler", + hdrs = ["cel_compiler.h"], + deps = [ + "//checker:validation_result", + "//common:ast_proto", + "//compiler", + "//internal:status_macros", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", + ], +) + +cc_test( + name = "cel_compiler_test", + srcs = ["cel_compiler_test.cc"], + deps = [ + ":cel_compiler", + "//common:decl", + "//common:type", + "//compiler", + "//compiler:compiler_factory", + "//compiler:standard_library", + "//eval/public:activation", + "//eval/public:activation_bind_helper", + "//eval/public:builtin_func_registrar", + "//eval/public:cel_expr_builder_factory", + "//eval/public:cel_function_adapter", + "//eval/public:cel_value", + "//eval/public/testing:matchers", + "//internal:testing", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_googleapis//google/rpc/context:attribute_context_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "exercise4", + srcs = ["exercise4.cc"], + hdrs = ["exercise4.h"], + deps = [ + ":cel_compiler", + "//compiler", + "//compiler:compiler_factory", + "//compiler:standard_library", + "//eval/public:activation", + "//eval/public:activation_bind_helper", + "//eval/public:builtin_func_registrar", + "//eval/public:cel_expr_builder_factory", + "//eval/public:cel_expression", + "//eval/public:cel_options", + "//eval/public:cel_value", + "//internal:status_macros", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", + "@com_google_googleapis//google/rpc/context:attribute_context_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "exercise4_test", + srcs = ["exercise4_test.cc"], + tags = EXERCISE_TEST_TAGS, + deps = [ + ":exercise4", + "//internal:testing", + "@com_google_googleapis//google/rpc/context:attribute_context_cc_proto", + "@com_google_protobuf//:protobuf", + "@com_google_protobuf//:struct_cc_proto", + ], +) + +cc_library( + name = "network_functions", + srcs = ["network_functions.cc"], + hdrs = ["network_functions.h"], + deps = [ + "//checker:type_checker_builder", + "//common:decl", + "//common:native_type", + "//common:type", + "//common:typeinfo", + "//common:value", + "//compiler", + "//internal:status_macros", + "//runtime:function_adapter", + "//runtime:function_registry", + "//runtime:runtime_options", + "//runtime:type_registry", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "network_functions_test", + srcs = ["network_functions_test.cc"], + deps = [ + ":network_functions", + "//common:decl", + "//common:minimal_descriptor_pool", + "//common:type", + "//common:value", + "//compiler", + "//compiler:compiler_factory", + "//compiler:standard_library", + "//internal:benchmark", + "//internal:status_macros", + "//internal:testing", + "//runtime", + "//runtime:activation", + "//runtime:constant_folding", + "//runtime:runtime_options", + "//runtime:standard_runtime_builder_factory", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "exercise10", + srcs = ["exercise10.cc"], + hdrs = ["exercise10.h"], + deps = [ + ":network_functions", + "//checker:validation_result", + "//common:decl", + "//common:minimal_descriptor_pool", + "//common:type", + "//common:value", + "//compiler", + "//compiler:compiler_factory", + "//compiler:standard_library", + "//runtime", + "//runtime:activation", + "//runtime:runtime_builder", + "//runtime:runtime_options", + "//runtime:standard_runtime_builder_factory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "exercise10_test", + srcs = ["exercise10_test.cc"], + tags = EXERCISE_TEST_TAGS, + deps = [ + ":exercise10", + "//internal:testing", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + ], +) diff --git a/codelab/Dockerfile b/codelab/Dockerfile new file mode 100644 index 000000000..c98a08f39 --- /dev/null +++ b/codelab/Dockerfile @@ -0,0 +1,19 @@ +ARG DEBIAN_IMAGE="marketplace.gcr.io/google/debian11:latest" +FROM ${DEBIAN_IMAGE} + +ARG BAZELISK_RELEASE="https://github.com/bazelbuild/bazelisk/releases/download/v1.25.0/bazelisk-amd64.deb" + +RUN apt update && apt upgrade -y && apt install -y gcc-9 g++-9 clang-13 git curl bash openjdk-11-jdk-headless + +RUN curl -L ${BAZELISK_RELEASE} > ./bazelisk.deb +RUN apt install ./bazelisk.deb + +RUN git clone https://github.com/google/cel-cpp.git + +ENV CXX=clang++-13 +ENV CC=clang-13 + +WORKDIR /cel-cpp +# not generally recommended to cache the bazel build in the image, +# but works ok for prototyping. +RUN bazelisk build ... && bazelisk test //codelab/solutions:all \ No newline at end of file diff --git a/codelab/README.md b/codelab/README.md new file mode 100644 index 000000000..96f7598ba --- /dev/null +++ b/codelab/README.md @@ -0,0 +1,328 @@ +# What is CEL? +Common Expression Language (CEL) is an expression language that’s fast, portable, and safe to execute in performance-critical applications. CEL is designed to be embedded in an application, with application-specific extensions, and is ideal for extending declarative configurations that your applications might already use. + +## What is covered in this Codelab? +This codelab is aimed at developers who would like to learn CEL to use services that already support CEL. This Codelab covers common use cases. This codelab doesn't cover how to integrate CEL into your own project. For a more in-depth look at the language, semantics, and features see the [CEL Language Definition on GitHub](https://github.com/google/cel-spec). + +Some key areas covered are: + +* [Hello, World: Using CEL to evaluate a String](#hello-world) +* [Creating variables](#creating-variables) +* [Commutative logical AND/OR](#logical-andor) +* [Adding custom functions](#custom-functions) + +### Prerequisites +This codelab builds upon a basic understanding of Protocol Buffers and C++. + +If you're not familiar with Protocol Buffers, the first exercise will give you a sense of how CEL works, but because the more advanced examples use Protocol Buffers as the input into CEL, they may be harder to understand. Consider working through one of these tutorials, first. See the devsite for [Protocol Buffers](https://protobuf.dev). + +Notes on portability: Protocol Buffers are not required to use CEL +generally, but the C++ implementation has a hard dependency on the library +and some APIs reference protobuf types directly. Automated builds test +against gcc9 and clang11 on linux. We accept requests for portability +fixes for other OSes and compilers, but don't actively maintain support at +this time. A simple Docker file is provided as a reference for a known good +environment configuration for running the codelab solutions. + +What you'll need: + +- Git +- Bazel +- C/C++ Compiler (GCC, Clang, Visual Studio). +- Optional: bazelisk is a wrapper around bazel that simplifies version + management. If using, substitute all bazel commands below with `bazelisk`. + +## GitHub Setup + +GitHub Repo: + +The code for this codelab lives in the `codelab` folder of the cel-cpp repo. The solution is available in the `codelab/solution` folder of the same repo. + +Clone and cd into the repo: + +``` +git clone git@github.com:google/cel-cpp.git +cd cel-cpp +``` + +Make sure everything is working by building the codelab: + +``` +bazel build //codelab:all +``` + +## Hello, World +In the tried and true tradition of all programming languages, let's start with "Hello, World!". + +Update exercise1.cc with the following: + +Using declarations: + +```c++ +using ::google::api::expr::parser::Parse; +using ::google::api::expr::runtime::Activation; +using ::google::api::expr::runtime::CelExpression; +using ::google::api::expr::runtime::CelExpressionBuilder; +using ::google::api::expr::runtime::CelValue; +using ::google::api::expr::runtime::CreateCelExpressionBuilder; +using ::google::api::expr::runtime::InterpreterOptions; +using ::google::api::expr::runtime::RegisterBuiltinFunctions; +``` + +Implementation: + +```c++ +absl::StatusOr ParseAndEvaluate(absl::string_view cel_expr) +{ + // === Start Codelab === + // Setup a default environment for building expressions. + InterpreterOptions options; + std::unique_ptr builder = + CreateCelExpressionBuilder(options); + + CEL_RETURN_IF_ERROR( + RegisterBuiltinFunctions(builder->GetRegistry(), options)); + + // Parse the expression. This is fine for codelabs, but this skips the type + // checking phase. It won't check that functions and variables are available + // in the environment, and it won't handle certain ambiguous identifier + // expressions (e.g. container lookup vs namespaced name, packaged function + // vs. receiver call style function). + ParsedExpr parsed_expr; + CEL_ASSIGN_OR_RETURN(parsed_expr, Parse(cel_expr)); + + // The evaluator uses a proto Arena for incidental allocations during + // evaluation. + proto2::Arena arena; + // The activation provides variables and functions that are bound into the + // expression environment. In this example, there's no context expected, so + // we just provide an empty one to the evaluator. + Activation activation; + + // Build the expression plan. This assumes that the source expression AST and + // the expression builder outlives the CelExpression object. + CEL_ASSIGN_OR_RETURN(std::unique_ptr expression_plan, + builder->CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + // Actually run the expression plan. We don't support any environment + // variables at the moment so just use an empty activation. + CEL_ASSIGN_OR_RETURN(CelValue result, + expression_plan->Evaluate(activation, &arena)); + + // Convert the result to a c++ string. CelValues may reference instances from + // either the input expression, or objects allocated on the arena, so we need + // to pass ownership (in this case by copying to a new instance and returning + // that). + return ConvertResult(result); + // === End Codelab === +} +``` + +Run the following to check your work: + +``` +bazel test //codelab:exercise1_test +``` + +You can add additional test cases or experiment with different return types. + +Hello, World! Now, let's break down what's happening. + + +### Setup the Environment +CEL applications evaluate an expression against an environment. + +The standard CEL environment supports all of the types, operators, functions, and macros defined within the language spec. The environment can be customized by providing options to disable macros, declare custom variables and functions, etc. + +An ExpressionBuilder maintains C++ evaluation environment. This creates a builder with the standard environment. + +```c++ +#include "eval/public/builtin_func_registrar.h" +#include "eval/public/cel_expr_builder_factory.h" +#include "eval/public/cel_options.h" +... +// Setup a default environment for building expressions. + +// Breaking behavior changes and optional features are controlled by +// InterpreterOptions. +InterpreterOptions options; + +// Environment used for planning and evaluating expressions is managed by an +// ExpressionBuilder. +std::unique_ptr builder = + CreateCelExpressionBuilder(options); + +// Add standard function bindings e.g. for +,-,==,||,&& operators. +// Custom functions (implementing the CelFunction interface) can be added to the +// registry similarly. +CEL_RETURN_IF_ERROR( + RegisterBuiltinFunctions(builder->GetRegistry(), options)); +``` + +### Parse +After the environment is configured, you can parse and check the expressions: + +```c++ +#include "google/api/expr/syntax.proto.h" +#include "parser/parser.h" +// ... +ASSIGN_OR_RETURN(google::api::expr::ParsedExpr parsed_expr, google::api::expr::parser::Parse(cel_expr)); +``` + +The C++ parser is a stand-alone utility. It's not aware of the evaluation environment and does not perform any semantic checks on the expression. A status is returned if the input string isn't a syntactically valid CEL expression or if it exceeds the configured complexity limits (see cel::ParserOptions and default limits). + +### Evaluate +After the expressions have been parsed and checked into an AST representation, it can be converted into an evaluable program whose function bindings and evaluation modes can be customized depending on the stack you are using. +Once a CEL expression is planned, it can be evaluated against an evaluation context (an activation). The evaluation result will be either a value or an error state. +The InterpreterOptions to create the expression plan are honored at evaluation. C++ uses the proto representation of either a parsed `google.api.expr.ParsedExpr` or parsed and type-checked `google.api.expr.CheckedExpr` AST directly. +Once a CEL program is planned (represented by a `google::api::expr::runtime::CelExpression`), it can be evaluated against an `google::api::expr::runtime::Activation`. The Activation provides per-evaluation bindings for variables and functions in the expression's environment. + +```c++ +#include "third_party/protobuf/arena.h" +#include "eval/public/activation.h" +#include "eval/public/cel_expression.h" +#include "eval/public/cel_value.h" +#include "internal/status_macros.h" +#include "parser/parser.h" +... +// The evaluator uses a proto Arena for incidental allocations during +// evaluation. +proto2::Arena arena; +// The activation provides variables and functions that are bound into the +// expression environment. In this example, there's no context expected, so +// we just provide an empty one to the evaluator. +Activation activation; + +// Build the expression plan. This assumes that the source expression AST and +// the expression builder outlives the CelExpression object. +CEL_ASSIGN_OR_RETURN(std::unique_ptr expression_plan, + builder->CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + +// Actually run the expression plan. We don't support any environment +// variables at the moment so just use an empty activation. +CEL_ASSIGN_OR_RETURN(CelValue result, + expression_plan->Evaluate(activation, &arena)); + +// Convert the result to a C++ string. CelValues may reference instances from +// either the input expression, or objects allocated on the arena, so we need +// to pass ownership (in this case by copying to a new instance and returning +// that). +return ConvertResult(result); +``` + +## Creating variables +Most CEL applications will declare variables that can be referenced within expressions. Variables declarations specify a name and a type. A variable's type may either be a CEL builtin type, a protocol buffer well-known type, or any protobuf message type so long as its descriptor is also provided to CEL. + +At runtime, the hosting program binds instances of variables to the evaluation context (using the variable name as a key). + +For the C++ evaluator at runtime, the values are managed by the `google::api::expr::runtime::CelValue` type, a variant over the C++ representations of supported CEL types. + +Update exercise2.cc: + +```c++ +// The Variables exercise shows how to declare and use variables in expressions. +// There are two overloads for preparing an expression either granularly for +// individual variables or using a helper to bind a context proto. + +// The first overload shows manually populating individual variables in the +// evaluation environment. This allows cel_expr to reference 'bool_var'. +absl::StatusOr ParseAndEvaluate(absl::string_view cel_expr, + bool bool_var) { + Activation activation; + proto2::Arena arena; + // === Start Codelab === + activation.InsertValue("bool_var", CelValue::CreateBool(bool_var)); + // === End Codelab === + + return ParseAndEvaluate(cel_expr, activation, &arena); +} +``` + +Run the following to check your work. You should have fixed the first two test cases in exercise2_test.cc. + +``` +bazel test //codelab:exercise2_test +``` + +The second overload uses a protocol buffer message to represent the environment variables. For this use case, there is a helper to automatically bind in fields from a top level message (see `google::api::expr::runtime::BindProtoToActivation`). In this example, we assume that unset fields should be bound to default values. + +```c++ +#include "eval/public/activation_bind_helper.h" +// ... +using ::google::api::expr::runtime::ProtoUnsetFieldOptions; +// ... +absl::StatusOr ParseAndEvaluate(absl::string_view cel_expr, + const AttributeContext& context) { + Activation activation; + google::protobuf::Arena arena; + // === Start Codelab === + + CEL_RETURN_IF_ERROR(BindProtoToActivation( + &context, &arena, &activation, ProtoUnsetFieldOptions::kBindDefault)); + // === End Codelab === + + return ParseAndEvaluate(cel_expr, activation, &arena); +} +``` + +Note: You can experiment with unset values and the alternative bind option for BindProtoToActivation. With ProtoUnsetFieldOptions::kSkip unset values will not be bound at all, and accesses in expressions will cause errors. + +## Logical And/Or +One of CEL's more distinctive features is its use of commutative logical operators. Either side of a conditional branch can short-circuit the evaluation, even in the face of errors or partial input. +Note: If you are skipping ahead, copy the solution for exercise2 -- we'll be using it to test the behavior of some simple expressions. + +exercise3_test.cc lists truth tables for simple expressions using the 'or', 'and', and 'ternary' operators. + +Running the following should result in some failing expectations. + +``` +bazel test //codelab:exercise3_test +``` + +Open exercise3_test.cc in your editor: + +```c++ +TEST(Exercise3Var, LogicalOr) { + // Some of these expectations are incorrect. + // If a logical operation can short-circuit a branch that results in an error, + // CEL evaluation will return the logical result instead of propagating the + // error. For logical or, this means if one branch is true, the result will + // always be true, regardless of the other branch. + // Wrong + EXPECT_THAT(TruthTableTest("true || (1 / 0 > 2)"), + StatusIs(absl::StatusCode::kInvalidArgument, "divide by zero")); + EXPECT_THAT(TruthTableTest("false || (1 / 0 > 2)"), + StatusIs(absl::StatusCode::kInvalidArgument, "divide by zero")); + // Wrong + EXPECT_THAT(TruthTableTest("(1 / 0 > 2) || true"), + StatusIs(absl::StatusCode::kInvalidArgument, "divide by zero")); + EXPECT_THAT(TruthTableTest("(1 / 0 > 2) || false"), + StatusIs(absl::StatusCode::kInvalidArgument, "divide by zero")); + EXPECT_THAT(TruthTableTest("(1 / 0 > 2) || (1 / 0 > 2)"), + StatusIs(absl::StatusCode::kInvalidArgument, "divide by zero")); + EXPECT_THAT(TruthTableTest("true || true"), IsOkAndHolds(true)); + EXPECT_THAT(TruthTableTest("true || false"), IsOkAndHolds(true)); + EXPECT_THAT(TruthTableTest("false || true"), IsOkAndHolds(true)); + EXPECT_THAT(TruthTableTest("false || false"), IsOkAndHolds(false)); +} +``` + +Updating the two failing cases "true || (1 / 0 > 2)" and "(1 / 0 > 2) || true" should fix this test: + +```c++ +// ... + // Correct + EXPECT_THAT(TruthTableTest("true || (1 / 0 > 2)"), + IsOkAndHolds(true)); + EXPECT_THAT(TruthTableTest("false || (1 / 0 > 2)"), + StatusIs(absl::StatusCode::kInvalidArgument, "divide by zero")); + // Correct + EXPECT_THAT(TruthTableTest("(1 / 0 > 2) || true"), + IsOkAndHolds(true)); +``` + +You can examine the other tests for other cases for corresponding behavior for the 'and' and ternary operators. + +CEL finds an evaluation order which gives results whenever possible, ignoring errors or even missing data that might occur in other evaluation orders. Applications like IAM conditions rely on this property to minimize the cost of evaluation, deferring the gathering of expensive inputs when a result can be reached without them. diff --git a/codelab/cel_compiler.h b/codelab/cel_compiler.h new file mode 100644 index 000000000..0ff2f699b --- /dev/null +++ b/codelab/cel_compiler.h @@ -0,0 +1,47 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_CODELAB_COMPILER_H_ +#define THIRD_PARTY_CEL_CPP_CODELAB_COMPILER_H_ + +#include "cel/expr/checked.pb.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "checker/validation_result.h" +#include "common/ast_proto.h" +#include "compiler/compiler.h" +#include "internal/status_macros.h" + +namespace cel_codelab { + +// Helper for compiling expression and converting to proto. +// +// Simplifies error handling for brevity in the codelab. +inline absl::StatusOr CompileToCheckedExpr( + const cel::Compiler& compiler, absl::string_view expr) { + CEL_ASSIGN_OR_RETURN(cel::ValidationResult result, compiler.Compile(expr)); + + if (!result.IsValid() || result.GetAst() == nullptr) { + return absl::InvalidArgumentError(result.FormatError()); + } + + cel::expr::CheckedExpr pb; + CEL_RETURN_IF_ERROR(cel::AstToCheckedExpr(*result.GetAst(), &pb)); + return pb; +}; + +} // namespace cel_codelab + +#endif // THIRD_PARTY_CEL_CPP_CODELAB_COMPILER_H_ diff --git a/codelab/cel_compiler_test.cc b/codelab/cel_compiler_test.cc new file mode 100644 index 000000000..635b4d54d --- /dev/null +++ b/codelab/cel_compiler_test.cc @@ -0,0 +1,146 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "codelab/cel_compiler.h" + +#include +#include + +#include "google/rpc/context/attribute_context.pb.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "common/decl.h" +#include "common/type.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "compiler/standard_library.h" +#include "eval/public/activation.h" +#include "eval/public/activation_bind_helper.h" +#include "eval/public/builtin_func_registrar.h" +#include "eval/public/cel_expr_builder_factory.h" +#include "eval/public/cel_function_adapter.h" +#include "eval/public/cel_value.h" +#include "eval/public/testing/matchers.h" +#include "internal/testing.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel_codelab { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::cel::BoolType; +using ::cel::MakeFunctionDecl; +using ::cel::MakeOverloadDecl; +using ::cel::MakeVariableDecl; +using ::cel::StringType; +using ::google::api::expr::runtime::Activation; +using ::google::api::expr::runtime::BindProtoToActivation; +using ::google::api::expr::runtime::CelValue; +using ::google::api::expr::runtime::CreateCelExpressionBuilder; +using ::google::api::expr::runtime::FunctionAdapter; +using ::google::api::expr::runtime::RegisterBuiltinFunctions; +using ::google::api::expr::runtime::test::IsCelBool; +using ::google::rpc::context::AttributeContext; +using ::testing::HasSubstr; + +std::unique_ptr MakeDefaultCompilerBuilder() { + google::protobuf::LinkMessageReflection(); + auto builder = + cel::NewCompilerBuilder(google::protobuf::DescriptorPool::generated_pool()); + ABSL_CHECK_OK(builder.status()); + + ABSL_CHECK_OK((*builder)->AddLibrary(cel::StandardCompilerLibrary())); + ABSL_CHECK_OK((*builder)->GetCheckerBuilder().AddContextDeclaration( + "google.rpc.context.AttributeContext")); + + return std::move(builder).value(); +} + +TEST(DefaultCompiler, Basic) { + ASSERT_OK_AND_ASSIGN(auto compiler, MakeDefaultCompilerBuilder()->Build()); + EXPECT_THAT(compiler->Compile("1 < 2").status(), IsOk()); +} + +TEST(DefaultCompiler, AddFunctionDecl) { + auto builder = MakeDefaultCompilerBuilder(); + ASSERT_OK_AND_ASSIGN( + cel::FunctionDecl decl, + MakeFunctionDecl("IpMatch", + MakeOverloadDecl("IpMatch_string_string", BoolType(), + StringType(), StringType()))); + EXPECT_THAT(builder->GetCheckerBuilder().AddFunction(decl), IsOk()); + + ASSERT_OK_AND_ASSIGN(auto compiler, builder->Build()); + + EXPECT_THAT(CompileToCheckedExpr( + *compiler, "IpMatch('255.255.255.255', '255.255.255.255')") + .status(), + IsOk()); + EXPECT_THAT( + CompileToCheckedExpr(*compiler, "IpMatch('255.255.255.255', 123436)") + .status(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("no matching overload"))); +} + +TEST(DefaultCompiler, EndToEnd) { + google::protobuf::Arena arena; + + auto compiler_builder = MakeDefaultCompilerBuilder(); + ASSERT_OK_AND_ASSIGN( + cel::FunctionDecl func_decl, + MakeFunctionDecl("MyFunc", MakeOverloadDecl("MyFunc", BoolType()))); + ASSERT_THAT(compiler_builder->GetCheckerBuilder().AddFunction(func_decl), + IsOk()); + + ASSERT_THAT(compiler_builder->GetCheckerBuilder().AddVariable( + MakeVariableDecl("my_var", BoolType())), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto compiler, compiler_builder->Build()); + + ASSERT_OK_AND_ASSIGN( + auto expr, + CompileToCheckedExpr( + *compiler, + "(my_var || MyFunc()) && request.host == 'www.google.com'")); + + auto builder = + CreateCelExpressionBuilder(google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory()); + ASSERT_THAT(RegisterBuiltinFunctions(builder->GetRegistry()), IsOk()); + ASSERT_THAT(FunctionAdapter::CreateAndRegister( + "MyFunc", false, [](google::protobuf::Arena*) { return true; }, + builder->GetRegistry()), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto plan, builder->CreateExpression(&expr)); + + AttributeContext context; + context.mutable_request()->set_host("www.google.com"); + Activation activation; + ASSERT_THAT(BindProtoToActivation(&context, &arena, &activation), IsOk()); + activation.InsertValue("my_var", CelValue::CreateBool(false)); + + ASSERT_OK_AND_ASSIGN(CelValue result, plan->Evaluate(activation, &arena)); + + EXPECT_THAT(result, IsCelBool(true)); +} + +} // namespace +} // namespace cel_codelab diff --git a/codelab/exercise1.cc b/codelab/exercise1.cc new file mode 100644 index 000000000..de7ccf6e0 --- /dev/null +++ b/codelab/exercise1.cc @@ -0,0 +1,84 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "codelab/exercise1.h" + +#include +#include + +#include "cel/expr/syntax.pb.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "eval/public/activation.h" +#include "eval/public/builtin_func_registrar.h" +#include "eval/public/cel_expr_builder_factory.h" +#include "eval/public/cel_expression.h" +#include "eval/public/cel_options.h" +#include "eval/public/cel_value.h" +#include "internal/status_macros.h" +#include "parser/parser.h" +#include "google/protobuf/arena.h" + +namespace cel_codelab { +namespace { + +using ::google::api::expr::runtime::Activation; +using ::google::api::expr::runtime::CelValue; + +// Convert the CelResult to a C++ string if it is string typed. Otherwise, +// return invalid argument error. This takes a copy to avoid lifecycle concerns +// (the evaluator may represent strings as stringviews backed by the input +// expression). +absl::StatusOr ConvertResult(const CelValue& value) { + if (CelValue::StringHolder inner_value; value.GetValue(&inner_value)) { + return std::string(inner_value.value()); + } else { + return absl::InvalidArgumentError(absl::StrCat( + "expected string result got '", CelValue::TypeName(value.type()), "'")); + } +} +} // namespace + +absl::StatusOr ParseAndEvaluate(absl::string_view cel_expr) { + // === Start Codelab === + // Parse the expression using ::google::api::expr::parser::Parse; + // This will return a cel::expr::ParsedExpr message. + + // Setup a default environment for building expressions. + // std::unique_ptr builder = + // CreateCelExpressionBuilder(options); + + // Register standard functions. + // CEL_RETURN_IF_ERROR( + // RegisterBuiltinFunctions(builder->GetRegistry(), options)); + + // The evaluator uses a proto Arena for incidental allocations during + // evaluation. + google::protobuf::Arena arena; + // The activation provides variables and functions that are bound into the + // expression environment. In this example, there's no context expected, so + // we just provide an empty one to the evaluator. + Activation activation; + + // Using the CelExpressionBuilder and the ParseExpr, create an execution plan + // (google::api::expr::runtime::CelExpression), evaluate, and return the + // result. Use the provided helper function ConvertResult to copy the value + // for return. + return absl::UnimplementedError("Not yet implemented"); + // === End Codelab === +} + +} // namespace cel_codelab diff --git a/codelab/exercise1.h b/codelab/exercise1.h new file mode 100644 index 000000000..327e7a629 --- /dev/null +++ b/codelab/exercise1.h @@ -0,0 +1,32 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_CODELAB_EXERCISE1_H_ +#define THIRD_PARTY_CEL_CPP_CODELAB_EXERCISE1_H_ + +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" + +namespace cel_codelab { + +// Parse a cel expression and evaluate it. This assumes no special setup for +// the evaluation environment, and that the expression results in a string +// value. +absl::StatusOr ParseAndEvaluate(absl::string_view cel_expr); + +} // namespace cel_codelab + +#endif // THIRD_PARTY_CEL_CPP_CODELAB_EXERCISE1_H_ diff --git a/codelab/exercise10.cc b/codelab/exercise10.cc new file mode 100644 index 000000000..37eaa7642 --- /dev/null +++ b/codelab/exercise10.cc @@ -0,0 +1,126 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "codelab/exercise10.h" + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "checker/validation_result.h" +#include "codelab/network_functions.h" +#include "common/decl.h" +#include "common/minimal_descriptor_pool.h" +#include "common/type.h" +#include "common/value.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "compiler/standard_library.h" +#include "runtime/activation.h" +#include "runtime/runtime.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_options.h" +#include "runtime/standard_runtime_builder_factory.h" +#include "google/protobuf/arena.h" + +namespace cel_codelab { + +namespace { + +absl::StatusOr> ConfigureCompiler() { + absl::StatusOr> compiler_builder = + cel::NewCompilerBuilder(cel::GetMinimalDescriptorPool()); + if (!compiler_builder.ok()) { + return std::move(compiler_builder).status(); + } + absl::Status s = + (*compiler_builder)->AddLibrary(cel::StandardCompilerLibrary()); + // =========================================================================== + // Codelab: Update compiler builder with functions from network_functions.h + // and add a varible for the input IP. + // =========================================================================== + if (!s.ok()) return s; + + return (*compiler_builder)->Build(); +} + +absl::StatusOr> ConfigureRuntime() { + cel::RuntimeOptions runtime_options; + // Note: this is needed to resolve net.Address as a `type` constant. + runtime_options.enable_qualified_type_identifiers = true; + absl::StatusOr runtime_builder = + cel::CreateStandardRuntimeBuilder(cel::GetMinimalDescriptorPool(), + runtime_options); + // =========================================================================== + // Codelab: Update runtime builder with functions from network_functions.h + // =========================================================================== + return std::move(runtime_builder).value().Build(); +} + +} // namespace + +absl::StatusOr CompileAndEvaluateExercise10(absl::string_view expression, + absl::string_view ip) { + absl::StatusOr> compiler = ConfigureCompiler(); + if (!compiler.ok()) { + return std::move(compiler).status(); + } + + absl::StatusOr> runtime = ConfigureRuntime(); + if (!runtime.ok()) { + return std::move(runtime).status(); + } + + absl::StatusOr checked = + (*compiler)->Compile(expression); + if (!checked.ok()) { + return std::move(checked).status(); + } + + if (!checked->IsValid() || checked->GetAst() == nullptr) { + return absl::InvalidArgumentError(checked->FormatError()); + } + + absl::StatusOr> program = + (*runtime)->CreateProgram(checked->ReleaseAst().value()); + + if (!program.ok()) { + return std::move(program).status(); + } + + cel::Activation activation; + google::protobuf::Arena arena; + activation.InsertOrAssignValue("ip", cel::StringValue::From(ip, &arena)); + absl::StatusOr result = (*program)->Evaluate(&arena, activation); + + if (!result.ok()) { + return std::move(result).status(); + } + + if (result->IsBool()) { + return result->GetBool(); + } + + if (result->IsError()) { + return result->GetError().ToStatus(); + } + + return absl::InvalidArgumentError( + absl::StrCat("unexpected result type: ", result->DebugString())); +} + +} // namespace cel_codelab diff --git a/codelab/exercise10.h b/codelab/exercise10.h new file mode 100644 index 000000000..c196441e9 --- /dev/null +++ b/codelab/exercise10.h @@ -0,0 +1,46 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_CODELAB_EXERCISE10_H_ +#define THIRD_PARTY_CEL_CPP_CODELAB_EXERCISE10_H_ + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" + +namespace cel_codelab { + +// Exercise10 -- extension types. +// +// This function compiles an expression then evaluates, expecting a bool +// return type. +// +// Example: +// net.ParseAddressMatcher("8.8.0.0-8.8.255.255") +// .containsAddress( +// net.parseAddress(ip) +// ) +// +// Variables: +// ip - string +// +// Functions: +// net.ParseAddress(string) -> net.Address +// net.ParseAddressMatcher(string) -> net.AddressMatcher +// (net.AddressMatcher). +absl::StatusOr CompileAndEvaluateExercise10(absl::string_view expression, + absl::string_view ip); + +} // namespace cel_codelab + +#endif // THIRD_PARTY_CEL_CPP_CODELAB_EXERCISE10_H_ diff --git a/codelab/exercise10_test.cc b/codelab/exercise10_test.cc new file mode 100644 index 000000000..7e7044aad --- /dev/null +++ b/codelab/exercise10_test.cc @@ -0,0 +1,81 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "codelab/exercise10.h" + +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "internal/testing.h" + +namespace cel_codelab { +namespace { + +using ::absl_testing::IsOkAndHolds; +using ::absl_testing::StatusIs; +using ::testing::HasSubstr; + +TEST(Exercise10, IpInRange) { + EXPECT_THAT(CompileAndEvaluateExercise10( + R"cel( + net.parseAddressMatcher("8.8.4.0-8.8.4.255") + .containsAddress( + net.parseAddress(ip) + ) + )cel", + "8.8.4.4"), + IsOkAndHolds(true)); +} + +TEST(Exercise10, IpNotInRange) { + EXPECT_THAT(CompileAndEvaluateExercise10( + R"cel( + net.parseAddressMatcher("8.8.4.0-8.8.4.255") + .containsAddress( + net.parseAddress(ip) + ) + )cel", + "8.8.8.8"), + IsOkAndHolds(false)); +} + +TEST(Exercise10, IpEqual) { + EXPECT_THAT(CompileAndEvaluateExercise10( + R"cel( + net.parseAddress("8.8.4.4") == net.parseAddress(ip) + )cel", + "8.8.4.4"), + IsOkAndHolds(true)); +} + +TEST(Exercise10, IpInequal) { + EXPECT_THAT(CompileAndEvaluateExercise10( + R"cel( + net.parseAddress("8.8.4.4") == net.parseAddress(ip) + )cel", + "8.8.8.8"), + IsOkAndHolds(false)); +} + +TEST(Exercise10, IpInvalid) { + EXPECT_THAT(CompileAndEvaluateExercise10( + R"cel( + net.parseAddress("8.8.4.4") == net.parseAddress(ip) + )cel", + "8.8"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("invalid address"))); +} + +} // namespace +} // namespace cel_codelab diff --git a/codelab/exercise1_test.cc b/codelab/exercise1_test.cc new file mode 100644 index 000000000..fab15aed1 --- /dev/null +++ b/codelab/exercise1_test.cc @@ -0,0 +1,43 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "codelab/exercise1.h" + +#include "absl/status/status.h" +#include "internal/testing.h" + +namespace cel_codelab { +namespace { + +using ::absl_testing::IsOkAndHolds; +using ::absl_testing::StatusIs; + +TEST(Exercise1, PrintHelloWorld) { + EXPECT_THAT(ParseAndEvaluate("'Hello, World!'"), + IsOkAndHolds("Hello, World!")); +} + +TEST(Exercise1, WrongTypeResultError) { + EXPECT_THAT(ParseAndEvaluate("true"), + StatusIs(absl::StatusCode::kInvalidArgument, + "expected string result got 'bool'")); +} + +TEST(Exercise1, Conditional) { + EXPECT_THAT(ParseAndEvaluate("(1 < 0)? 'Hello, World!' : '¡Hola, Mundo!'"), + IsOkAndHolds("¡Hola, Mundo!")); +} + +} // namespace +} // namespace cel_codelab diff --git a/codelab/exercise2.cc b/codelab/exercise2.cc new file mode 100644 index 000000000..373f63365 --- /dev/null +++ b/codelab/exercise2.cc @@ -0,0 +1,143 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "codelab/exercise2.h" + +#include + +#include "cel/expr/checked.pb.h" +#include "cel/expr/syntax.pb.h" +#include "google/rpc/context/attribute_context.pb.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "codelab/cel_compiler.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "compiler/standard_library.h" +#include "eval/public/activation.h" +#include "eval/public/builtin_func_registrar.h" +#include "eval/public/cel_expr_builder_factory.h" +#include "eval/public/cel_expression.h" +#include "eval/public/cel_options.h" +#include "eval/public/cel_value.h" +#include "internal/status_macros.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel_codelab { +namespace { + +using ::cel::expr::CheckedExpr; +using ::google::api::expr::runtime::Activation; +using ::google::api::expr::runtime::CelError; +using ::google::api::expr::runtime::CelExpression; +using ::google::api::expr::runtime::CelExpressionBuilder; +using ::google::api::expr::runtime::CelValue; +using ::google::api::expr::runtime::CreateCelExpressionBuilder; +using ::google::api::expr::runtime::InterpreterOptions; +using ::google::api::expr::runtime::RegisterBuiltinFunctions; +using ::google::rpc::context::AttributeContext; + +absl::StatusOr> MakeCelCompiler() { + // Note: we are using the generated descriptor pool here for simplicity, but + // it has the drawback of including all message types that are linked into the + // binary instead of just the ones expected for the CEL environment. + google::protobuf::LinkMessageReflection(); + CEL_ASSIGN_OR_RETURN( + std::unique_ptr builder, + cel::NewCompilerBuilder(google::protobuf::DescriptorPool::generated_pool())); + + CEL_RETURN_IF_ERROR(builder->AddLibrary(cel::StandardCompilerLibrary())); + // === Start Codelab === + // Add 'AttributeContext' as a context message to the type checker and a + // boolean variable 'bool_var'. Relevant functions are on the + // TypeCheckerBuilder class (see CompilerBuilder::GetCheckerBuilder). + // + // We're reusing the same compiler for both evaluation paths here for brevity, + // but it's likely a better fit to configure a separate compiler per use case. + // === End Codelab === + + return builder->Build(); +} + +// Parse a cel expression and evaluate it against the given activation and +// arena. +absl::StatusOr EvalCheckedExpr(const CheckedExpr& checked_expr, + const Activation& activation, + google::protobuf::Arena* arena) { + // Setup a default environment for building expressions. + InterpreterOptions options; + std::unique_ptr builder = CreateCelExpressionBuilder( + google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), options); + CEL_RETURN_IF_ERROR( + RegisterBuiltinFunctions(builder->GetRegistry(), options)); + + // Note, the expression_plan below is reusable for different inputs, but we + // create one just in time for evaluation here. + CEL_ASSIGN_OR_RETURN(std::unique_ptr expression_plan, + builder->CreateExpression(&checked_expr)); + + CEL_ASSIGN_OR_RETURN(CelValue result, + expression_plan->Evaluate(activation, arena)); + + if (bool value; result.GetValue(&value)) { + return value; + } else if (const CelError * value; result.GetValue(&value)) { + return *value; + } else { + return absl::InvalidArgumentError(absl::StrCat( + "expected 'bool' result got '", result.DebugString(), "'")); + } +} +} // namespace + +absl::StatusOr CompileAndEvaluateWithBoolVar(absl::string_view cel_expr, + bool bool_var) { + CEL_ASSIGN_OR_RETURN(std::unique_ptr compiler, + MakeCelCompiler()); + + CEL_ASSIGN_OR_RETURN(CheckedExpr checked_expr, + CompileToCheckedExpr(*compiler, cel_expr)); + + Activation activation; + google::protobuf::Arena arena; + // === Start Codelab === + // Update the activation to bind the bool argument to 'bool_var' + // === End Codelab === + + return EvalCheckedExpr(checked_expr, activation, &arena); +} + +absl::StatusOr CompileAndEvaluateWithContext( + absl::string_view cel_expr, const AttributeContext& context) { + CEL_ASSIGN_OR_RETURN(std::unique_ptr compiler, + MakeCelCompiler()); + + CEL_ASSIGN_OR_RETURN(CheckedExpr checked_expr, + CompileToCheckedExpr(*compiler, cel_expr)); + + Activation activation; + google::protobuf::Arena arena; + // === Start Codelab === + // Update the activation to bind the AttributeContext. + // === End Codelab === + + return EvalCheckedExpr(checked_expr, activation, &arena); +} + +} // namespace cel_codelab diff --git a/codelab/exercise2.h b/codelab/exercise2.h new file mode 100644 index 000000000..d4836dc2b --- /dev/null +++ b/codelab/exercise2.h @@ -0,0 +1,40 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_CODELAB_EXERCISE1_H_ +#define THIRD_PARTY_CEL_CPP_CODELAB_EXERCISE1_H_ + +#include "google/rpc/context/attribute_context.pb.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" + +namespace cel_codelab { + +// Compile a cel expression and evaluate it. Binds a simple boolean to the +// activation as 'bool_var' for use in the expression. +// +// cel_expr should result in a bool, otherwise an InvalidArgument error is +// returned. +absl::StatusOr CompileAndEvaluateWithBoolVar(absl::string_view cel_expr, + bool bool_var); + +// Compile a cel expression and evaluate it. Binds an instance of the +// AttributeContext message to the activation (binding the subfields directly). +absl::StatusOr CompileAndEvaluateWithContext( + absl::string_view cel_expr, + const google::rpc::context::AttributeContext& context); + +} // namespace cel_codelab + +#endif // THIRD_PARTY_CEL_CPP_CODELAB_EXERCISE1_H_ diff --git a/codelab/exercise2_test.cc b/codelab/exercise2_test.cc new file mode 100644 index 000000000..ced44faaa --- /dev/null +++ b/codelab/exercise2_test.cc @@ -0,0 +1,82 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "codelab/exercise2.h" + +#include "google/rpc/context/attribute_context.pb.h" +#include "absl/status/status.h" +#include "internal/testing.h" +#include "google/protobuf/text_format.h" + +namespace cel_codelab { +namespace { + +using ::absl_testing::IsOkAndHolds; +using ::absl_testing::StatusIs; +using ::google::rpc::context::AttributeContext; +using ::google::protobuf::TextFormat; +using ::testing::HasSubstr; + +TEST(Exercise2Var, Simple) { + EXPECT_THAT(CompileAndEvaluateWithBoolVar("bool_var", false), + IsOkAndHolds(false)); + EXPECT_THAT(CompileAndEvaluateWithBoolVar("bool_var", true), + IsOkAndHolds(true)); + EXPECT_THAT(CompileAndEvaluateWithBoolVar("bool_var || true", false), + IsOkAndHolds(true)); + EXPECT_THAT(CompileAndEvaluateWithBoolVar("bool_var && false", true), + IsOkAndHolds(false)); +} + +TEST(Exercise2Var, WrongTypeResultError) { + EXPECT_THAT(CompileAndEvaluateWithBoolVar("'not a bool'", false), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("expected 'bool' result got 'string"))); +} + +TEST(Exercise2Context, Simple) { + AttributeContext context; + ASSERT_TRUE(TextFormat::ParseFromString(R"pb( + source { ip: "192.168.28.1" } + request { host: "www.example.com" } + destination { ip: "192.168.56.1" } + )pb", + &context)); + + EXPECT_THAT( + CompileAndEvaluateWithContext("source.ip == '192.168.28.1'", context), + IsOkAndHolds(true)); + EXPECT_THAT(CompileAndEvaluateWithContext("request.host == 'api.example.com'", + context), + IsOkAndHolds(false)); + EXPECT_THAT(CompileAndEvaluateWithContext("request.host == 'www.example.com'", + context), + IsOkAndHolds(true)); + EXPECT_THAT(CompileAndEvaluateWithContext("destination.ip != '192.168.56.1'", + context), + IsOkAndHolds(false)); +} + +TEST(Exercise2Context, WrongTypeResultError) { + AttributeContext context; + + // For this codelab, we expect the bind default option which will return + // proto api defaults for unset fields. + EXPECT_THAT(CompileAndEvaluateWithContext("request.host", context), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("expected 'bool' result got 'string"))); +} + +} // namespace +} // namespace cel_codelab diff --git a/codelab/exercise3_test.cc b/codelab/exercise3_test.cc new file mode 100644 index 000000000..e1d2d5920 --- /dev/null +++ b/codelab/exercise3_test.cc @@ -0,0 +1,115 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "google/rpc/context/attribute_context.pb.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "codelab/exercise2.h" +#include "internal/testing.h" + +namespace cel_codelab { +namespace { + +using ::absl_testing::IsOkAndHolds; +using ::absl_testing::StatusIs; +using ::google::rpc::context::AttributeContext; + +// Helper for a simple CelExpression with no context. +absl::StatusOr TruthTableTest(absl::string_view statement) { + return CompileAndEvaluateWithBoolVar(statement, /*unused*/ false); +} + +TEST(Exercise3, LogicalOr) { + // Some of these expectations are incorrect. + // If a logical operation can short-circuit a branch that results in an error, + // CEL evaluation will return the logical result instead of propagating the + // error. For logical or, this means if one branch is true, the result will + // always be true, regardless of the other branch. + // Wrong + EXPECT_THAT(TruthTableTest("true || (1 / 0 > 2)"), + StatusIs(absl::StatusCode::kInvalidArgument, "divide by zero")); + EXPECT_THAT(TruthTableTest("false || (1 / 0 > 2)"), + StatusIs(absl::StatusCode::kInvalidArgument, "divide by zero")); + // Wrong + EXPECT_THAT(TruthTableTest("(1 / 0 > 2) || true"), + StatusIs(absl::StatusCode::kInvalidArgument, "divide by zero")); + EXPECT_THAT(TruthTableTest("(1 / 0 > 2) || false"), + StatusIs(absl::StatusCode::kInvalidArgument, "divide by zero")); + EXPECT_THAT(TruthTableTest("(1 / 0 > 2) || (1 / 0 > 2)"), + StatusIs(absl::StatusCode::kInvalidArgument, "divide by zero")); + EXPECT_THAT(TruthTableTest("true || true"), IsOkAndHolds(true)); + EXPECT_THAT(TruthTableTest("true || false"), IsOkAndHolds(true)); + EXPECT_THAT(TruthTableTest("false || true"), IsOkAndHolds(true)); + EXPECT_THAT(TruthTableTest("false || false"), IsOkAndHolds(false)); +} + +TEST(Exercise3, LogicalAnd) { + EXPECT_THAT(TruthTableTest("true && (1 / 0 > 2)"), + StatusIs(absl::StatusCode::kInvalidArgument, "divide by zero")); + // Wrong + EXPECT_THAT(TruthTableTest("false && (1 / 0 > 2)"), + StatusIs(absl::StatusCode::kInvalidArgument, "divide by zero")); + EXPECT_THAT(TruthTableTest("(1 / 0 > 2) && true"), + StatusIs(absl::StatusCode::kInvalidArgument, "divide by zero")); + // Wrong + EXPECT_THAT(TruthTableTest("(1 / 0 > 2) && false"), + StatusIs(absl::StatusCode::kInvalidArgument, "divide by zero")); + EXPECT_THAT(TruthTableTest("(1 / 0 > 2) && (1 / 0 > 2)"), + StatusIs(absl::StatusCode::kInvalidArgument, "divide by zero")); + EXPECT_THAT(TruthTableTest("true && true"), IsOkAndHolds(true)); + EXPECT_THAT(TruthTableTest("true && false"), IsOkAndHolds(false)); + EXPECT_THAT(TruthTableTest("false && true"), IsOkAndHolds(false)); + EXPECT_THAT(TruthTableTest("false && false"), IsOkAndHolds(false)); +} + +TEST(Exercise3, Ternary) { + EXPECT_THAT(TruthTableTest("(1 / 0 > 2) ? false : false"), + StatusIs(absl::StatusCode::kInvalidArgument, "divide by zero")); + EXPECT_THAT(TruthTableTest("true ? (1 / 0 > 2) : false"), + StatusIs(absl::StatusCode::kInvalidArgument, "divide by zero")); + // Wrong + EXPECT_THAT(TruthTableTest("false ? (1 / 0 > 2) : false"), + StatusIs(absl::StatusCode::kInvalidArgument, "divide by zero")); +} + +TEST(Exercise3, BadFieldAccess) { + AttributeContext context; + + // This type of error is normally caught by the type checker, to allow + // it to surface here we use the dyn() operator to defer checking to runtime. + // typo-ed field name from 'request.host' + EXPECT_THAT( + CompileAndEvaluateWithContext( + "dyn(request).hostname == 'localhost' && true", context), + StatusIs(absl::StatusCode::kNotFound, "no_such_field : hostname")); + // Wrong + EXPECT_THAT( + CompileAndEvaluateWithContext( + "dyn(request).hostname == 'localhost' && false", context), + StatusIs(absl::StatusCode::kNotFound, "no_such_field : hostname")); + + // Wrong + EXPECT_THAT( + CompileAndEvaluateWithContext( + "dyn(request).hostname == 'localhost' || true", context), + StatusIs(absl::StatusCode::kNotFound, "no_such_field : hostname")); + EXPECT_THAT( + CompileAndEvaluateWithContext( + "dyn(request).hostname == 'localhost' || false", context), + StatusIs(absl::StatusCode::kNotFound, "no_such_field : hostname")); +} + +} // namespace +} // namespace cel_codelab diff --git a/codelab/exercise4.cc b/codelab/exercise4.cc new file mode 100644 index 000000000..cf02a88bd --- /dev/null +++ b/codelab/exercise4.cc @@ -0,0 +1,132 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "codelab/exercise4.h" + +#include + +#include "cel/expr/checked.pb.h" +#include "google/rpc/context/attribute_context.pb.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "codelab/cel_compiler.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "compiler/standard_library.h" +#include "eval/public/activation.h" +#include "eval/public/activation_bind_helper.h" +#include "eval/public/builtin_func_registrar.h" +#include "eval/public/cel_expr_builder_factory.h" +#include "eval/public/cel_expression.h" +#include "eval/public/cel_options.h" +#include "eval/public/cel_value.h" +#include "internal/status_macros.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel_codelab { +namespace { + +using ::cel::expr::CheckedExpr; +using ::google::api::expr::runtime::Activation; +using ::google::api::expr::runtime::BindProtoToActivation; +using ::google::api::expr::runtime::CelError; +using ::google::api::expr::runtime::CelExpressionBuilder; +using ::google::api::expr::runtime::CelValue; +using ::google::api::expr::runtime::CreateCelExpressionBuilder; +using ::google::api::expr::runtime::InterpreterOptions; +using ::google::api::expr::runtime::RegisterBuiltinFunctions; +using ::google::rpc::context::AttributeContext; + +absl::StatusOr> MakeConfiguredCompiler() { + // Setup for handling for protobuf types. + // Using the generated descriptor pool is simpler to configure, but often + // adds more types than necessary. + google::protobuf::LinkMessageReflection(); + CEL_ASSIGN_OR_RETURN( + std::unique_ptr builder, + cel::NewCompilerBuilder(google::protobuf::DescriptorPool::generated_pool())); + CEL_RETURN_IF_ERROR(builder->AddLibrary(cel::StandardCompilerLibrary())); + // Adds fields of AttributeContext as variables. + CEL_RETURN_IF_ERROR(builder->GetCheckerBuilder().AddContextDeclaration( + AttributeContext::descriptor()->full_name())); + + // Codelab part 1: + // Add a declaration for the map.contains(string, V) function. + // Hint: use cel::MakeFunctionDecl and cel::TypeCheckerBuilder::MergeFunction. + return builder->Build(); +} + +class Evaluator { + public: + Evaluator() { + builder_ = CreateCelExpressionBuilder( + google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), options_); + } + + absl::Status SetupEvaluatorEnvironment() { + CEL_RETURN_IF_ERROR(RegisterBuiltinFunctions(builder_->GetRegistry())); + // Codelab part 2: + // Register the map.contains(string, value) function. + // Hint: use `CelFunctionAdapter::CreateAndRegister` to adapt from a free + // function ContainsExtensionFunction. + return absl::OkStatus(); + } + + absl::StatusOr Evaluate(const CheckedExpr& expr, + const AttributeContext& context) { + Activation activation; + CEL_RETURN_IF_ERROR(BindProtoToActivation(&context, &arena_, &activation)); + CEL_ASSIGN_OR_RETURN(auto plan, builder_->CreateExpression(&expr)); + CEL_ASSIGN_OR_RETURN(CelValue result, plan->Evaluate(activation, &arena_)); + + if (bool value; result.GetValue(&value)) { + return value; + } else if (const CelError * value; result.GetValue(&value)) { + return *value; + } else { + return absl::InvalidArgumentError( + absl::StrCat("unexpected return type: ", result.DebugString())); + } + } + + private: + google::protobuf::Arena arena_; + std::unique_ptr builder_; + InterpreterOptions options_; +}; + +} // namespace + +absl::StatusOr EvaluateWithExtensionFunction( + absl::string_view expr, const AttributeContext& context) { + // Prepare a checked expression. + CEL_ASSIGN_OR_RETURN(std::unique_ptr compiler, + MakeConfiguredCompiler()); + CEL_ASSIGN_OR_RETURN(auto checked_expr, + CompileToCheckedExpr(*compiler, expr)); + + // Prepare an evaluation environment. + Evaluator evaluator; + CEL_RETURN_IF_ERROR(evaluator.SetupEvaluatorEnvironment()); + + // Evaluate a checked expression against a particular activation + return evaluator.Evaluate(checked_expr, context); +} + +} // namespace cel_codelab diff --git a/codelab/exercise4.h b/codelab/exercise4.h new file mode 100644 index 000000000..d015cebfb --- /dev/null +++ b/codelab/exercise4.h @@ -0,0 +1,34 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_CODELAB_EXERCISE4_H_ +#define THIRD_PARTY_CEL_CPP_CODELAB_EXERCISE4_H_ + +#include "google/rpc/context/attribute_context.pb.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" + +namespace cel_codelab { + +// Compile and evaluate an expression with google.rpc.context.AttributeContext +// as context. +// The environment includes the custom map member function +// .contains(string, string). +absl::StatusOr EvaluateWithExtensionFunction( + absl::string_view cel_expr, + const google::rpc::context::AttributeContext& context); + +} // namespace cel_codelab + +#endif // THIRD_PARTY_CEL_CPP_CODELAB_EXERCISE4_H_ diff --git a/codelab/exercise4_test.cc b/codelab/exercise4_test.cc new file mode 100644 index 000000000..f2f2044fa --- /dev/null +++ b/codelab/exercise4_test.cc @@ -0,0 +1,80 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "codelab/exercise4.h" + +#include "google/protobuf/struct.pb.h" +#include "google/rpc/context/attribute_context.pb.h" +#include "internal/testing.h" +#include "google/protobuf/text_format.h" + +namespace cel_codelab { +namespace { + +using ::absl_testing::IsOkAndHolds; +using ::google::rpc::context::AttributeContext; + +TEST(EvaluateWithExtensionFunction, Baseline) { + AttributeContext context; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"(request { + path: "/" + auth { + claims { + fields { + key: "group" + value {string_value: "admin"} + } + } + } + })", + &context)); + EXPECT_THAT(EvaluateWithExtensionFunction("request.path == '/'", context), + IsOkAndHolds(true)); +} + +TEST(EvaluateWithExtensionFunction, ContainsTrue) { + AttributeContext context; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"(request { + path: "/" + auth { + claims { + fields { + key: "group" + value {string_value: "admin"} + } + } + } + })", + &context)); + EXPECT_THAT(EvaluateWithExtensionFunction( + "request.auth.claims.contains('group', 'admin')", context), + IsOkAndHolds(true)); +} + +TEST(EvaluateWithExtensionFunction, ContainsFalse) { + AttributeContext context; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"(request { + path: "/" + })", + &context)); + EXPECT_THAT(EvaluateWithExtensionFunction( + "request.auth.claims.contains('group', 'admin')", context), + IsOkAndHolds(false)); +} + +} // namespace +} // namespace cel_codelab diff --git a/codelab/network_functions.cc b/codelab/network_functions.cc new file mode 100644 index 000000000..6cc1505a9 --- /dev/null +++ b/codelab/network_functions.cc @@ -0,0 +1,541 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "codelab/network_functions.h" + +#include +#include +#include +#include + +#include "absl/base/no_destructor.h" +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/match.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "checker/type_checker_builder.h" +#include "common/decl.h" +#include "common/native_type.h" +#include "common/type.h" +#include "common/typeinfo.h" +#include "common/value.h" +#include "compiler/compiler.h" +#include "internal/status_macros.h" +#include "runtime/function_adapter.h" +#include "runtime/function_registry.h" +#include "runtime/runtime_options.h" +#include "runtime/type_registry.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel_codelab { +namespace { + +// TODO(uncreated-issue/86): This is how internal extensions create types, but it isn't +// a good pattern for client extensions (since they can't pool into one eternal +// arena). +google::protobuf::Arena* absl_nonnull BuiltinsArena() { + static absl::NoDestructor arena; + return arena.get(); +} + +cel::Type AddressType() { + static cel::Type kInstance( + cel::OpaqueType(BuiltinsArena(), "net.Address", {})); + return kInstance; +} + +cel::Type TypeOfAddressType() { + static cel::Type kInstance(cel::TypeType(BuiltinsArena(), AddressType())); + return kInstance; +} + +cel::Type AddressMatcherType() { + static cel::Type kInstance( + cel::OpaqueType(BuiltinsArena(), "net.AddressMatcher", {})); + return kInstance; +} + +cel::Type TypeOfAddressMatcherType() { + static cel::Type kInstance( + cel::TypeType(BuiltinsArena(), AddressMatcherType())); + return kInstance; +} + +absl::StatusOr ParseAddressImpl(absl::string_view str, + uint32_t* ipv4_out, + absl::Span ipv6_out) { + if (str.size() < 2 || str.size() > 39) { + return absl::InvalidArgumentError("unsupported address format (length)"); + } + if (absl::StrContains(str, ":")) { + if (ipv6_out.size() < 16) { + return absl::InternalError("invalid outbuffer in parse call"); + } + return absl::InvalidArgumentError("unsupported address format (ipv6)"); + } + uint32_t ipv4 = 0; + int octet = 0; + for (auto part : absl::StrSplit(str, '.')) { + if (octet >= 4) { + return absl::InvalidArgumentError( + "unsupported address format (invalid ipv4)"); + } + int octet_val; + if (!absl::SimpleAtoi(part, &octet_val) || octet_val > 255 || + octet_val < 0) { + return absl::InvalidArgumentError( + "unsupported address format (invalid ipv4)"); + } + ipv4 <<= 8; + ipv4 |= (uint32_t)octet_val; + + octet++; + } + if (octet != 4) { + return absl::InvalidArgumentError( + "unsupported address format (invalid ipv4)"); + } + *ipv4_out = ipv4; + return IpVersion::kIPv4; +} + +absl::Status ConfigureNetworkFunctions(cel::TypeCheckerBuilder& builder) { + // Type identifiers + CEL_RETURN_IF_ERROR(builder.AddVariable( + MakeVariableDecl("net.Address", TypeOfAddressType()))); + CEL_RETURN_IF_ERROR(builder.AddVariable( + MakeVariableDecl("net.AddressMatcher", TypeOfAddressMatcherType()))); + CEL_RETURN_IF_ERROR(builder.AddVariable( + MakeVariableDecl("net.addressZeroValue", AddressType()))); + + // net.parseAddress(string) -> net.Address + CEL_ASSIGN_OR_RETURN( + auto decl, + MakeFunctionDecl("net.parseAddress", + MakeOverloadDecl("net_parseAddress_string", + AddressType(), cel::StringType()))); + + CEL_RETURN_IF_ERROR(builder.AddFunction(decl)); + // net.parseAddressOrZero(string) -> net.Address + CEL_ASSIGN_OR_RETURN( + decl, + MakeFunctionDecl("net.parseAddressOrZero", + MakeOverloadDecl("net_parseAddressOrZero_string", + AddressType(), cel::StringType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(decl)); + + // net.parseAddressMatcher(string) -> net.AddressMatcher + CEL_ASSIGN_OR_RETURN( + decl, MakeFunctionDecl( + "net.parseAddressMatcher", + MakeOverloadDecl("net_parseAddressMatcher_string", + AddressMatcherType(), cel::StringType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(decl)); + + // (net.AddressMatcher).containsAddress(net.Address) -> bool + CEL_ASSIGN_OR_RETURN( + decl, MakeFunctionDecl( + "containsAddress", + MakeMemberOverloadDecl( + "net_AddressMatcher_containsAddress_net_Address", + cel::BoolType(), AddressMatcherType(), AddressType()), + MakeMemberOverloadDecl( + "net_AddressMatcher_containsAddress_string", + cel::BoolType(), AddressMatcherType(), cel::StringType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(decl)); + + return absl::OkStatus(); +} + +// ============================================================================= +// Opaque Value type implementations for NetworkAddressRep. +// ============================================================================= + +cel::NativeTypeId NetworkAddressRepGetTypeId( + const cel::OpaqueValueDispatcher* dispatcher, + cel::OpaqueValueContent content) { + return cel::TypeId(); +} + +google::protobuf::Arena* absl_nullable NetworkAddressRepGetArena( + const cel::OpaqueValueDispatcher* absl_nonnull dispatcher, + cel::OpaqueValueContent content) { + return nullptr; +} + +absl::string_view NetworkAddressRepGetTypeName( + const cel::OpaqueValueDispatcher* absl_nonnull dispatcher, + cel::OpaqueValueContent content) { + return "net.Address"; +} + +std::string NetworkAddressRepDebugString( + const cel::OpaqueValueDispatcher* absl_nonnull dispatcher, + cel::OpaqueValueContent content) { + return absl::StrCat("net.parseAddress('", + content.To().Format(), "')"); +} + +cel::OpaqueType NetworkAddressRepGetRuntimeType( + const cel::OpaqueValueDispatcher* absl_nonnull dispatcher, + cel::OpaqueValueContent content) { + return AddressType().GetOpaque(); +} + +absl::Status NetworkAddressRepEqual( + const cel::OpaqueValueDispatcher* absl_nonnull, + cel::OpaqueValueContent content, const cel::OpaqueValue& other, + const google::protobuf::DescriptorPool* absl_nonnull, + google::protobuf::MessageFactory* absl_nonnull, google::protobuf::Arena* absl_nonnull, + cel::Value* absl_nonnull result) { + if (other.GetTypeId() != cel::TypeId()) { + *result = cel::BoolValue(false); + return absl::OkStatus(); + } + const NetworkAddressRep rep = content.To(); + std::optional other_rep = NetworkAddressRep::Unwrap(other); + ABSL_DCHECK(other_rep.has_value()); + *result = cel::BoolValue(rep.IsEqualTo(*other_rep)); + return absl::OkStatus(); +} + +cel::OpaqueValue NetworkAddressRepClone( + const cel::OpaqueValueDispatcher* absl_nonnull, + cel::OpaqueValueContent content, google::protobuf::Arena* absl_nonnull arena) { + const NetworkAddressRep* rep = content.To(); + ABSL_DCHECK(rep != nullptr); + return NetworkAddressRep::MakeValue(*rep).GetOpaque(); +} + +// Opaque Value types can be implemented either with a shared dispatcher or +// with a subclass (using vtable dispatch). +// +// We use the shared dispatcher here since the address type has a compact +// representation and we don't need to support different implementations at +// runtime. +// +// If the data structure is more complex, benefits from runtime polymorphism, or +// doesn't have easily defined move, swap, and copy operations, it's +// recommended to use a subclass instead. +static const cel::OpaqueValueDispatcher kAddressDispatcher{ + /*.GetTypeId=*/NetworkAddressRepGetTypeId, + /*.GetArena=*/NetworkAddressRepGetArena, + /*.GetTypeName=*/NetworkAddressRepGetTypeName, + /*.DebugString=*/NetworkAddressRepDebugString, + /*.GetRuntimeType=*/NetworkAddressRepGetRuntimeType, + /*.Equal=*/NetworkAddressRepEqual, + /*.Clone=*/NetworkAddressRepClone}; + +// ============================================================================= +// Opaque Value type implementations for NetworkAddressMatcher. +// ============================================================================= + +// Implementation of the OpaqueValueInterface for NetworkAddressMatcher. +// +// This is simpler to implement, but adds an extra allocation and pointer +// indirection for every matcher. This is recommended if the data structure is +// more complex. +class NetworkAddressMatcherImpl : public cel::OpaqueValueInterface { + public: + explicit NetworkAddressMatcherImpl(NetworkAddressMatcher rep) + : rep_(std::move(rep)) {} + + const NetworkAddressMatcher& rep() const { return rep_; } + + // implement the OpaqueValueInterface + std::string DebugString() const final { + return absl::StrCat("net.ParseAddressMatcher('", "TODO(uncreated-issue/86)", "')"); + } + + absl::string_view GetTypeName() const final { return "net.AddressMatcher"; } + + cel::OpaqueType GetRuntimeType() const final { + return AddressMatcherType().GetOpaque(); + } + + absl::Status Equal(const cel::OpaqueValue& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + cel::Value* absl_nonnull result) const final { + if (other.GetTypeId() != cel::TypeId()) { + *result = cel::BoolValue(false); + return absl::OkStatus(); + } + const NetworkAddressMatcherImpl* other_rep = + static_cast(other.interface()); + *result = cel::BoolValue(rep_.IsEqualTo(other_rep->rep_)); + return absl::OkStatus(); + } + + cel::OpaqueValue Clone(google::protobuf::Arena* absl_nonnull arena) const final { + return NetworkAddressMatcher::MakeValue(arena, rep_).GetOpaque(); + } + + cel::NativeTypeId GetNativeTypeId() const final { + return cel::TypeId(); + } + + private: + NetworkAddressMatcher rep_; +}; + +// ============================================================================= +// Extension function implementations. +// ============================================================================= +cel::Value parseAddress( + const cel::StringValue& str, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + std::string buf; + absl::string_view addr = str.ToStringView(&buf); + std::optional rep = NetworkAddressRep::Parse(addr); + if (!rep.has_value()) { + return cel::ErrorValue(absl::InvalidArgumentError("invalid address")); + } + return NetworkAddressRep::MakeValue(*rep); +} + +cel::Value parseAddressOrZero(const cel::StringValue& str) { + std::string buf; + absl::string_view addr = str.ToStringView(&buf); + std::optional rep = NetworkAddressRep::Parse(addr); + static const NetworkAddressRep kZero; + if (!rep.has_value()) { + return NetworkAddressRep::MakeValue(kZero); + } + return NetworkAddressRep::MakeValue(*rep); +} + +cel::Value parseAddressMatcher( + const cel::StringValue& str, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + std::string buf; + absl::string_view addr = str.ToStringView(&buf); + std::optional rep = NetworkAddressMatcher::Parse(addr); + if (!rep.has_value()) { + return cel::ErrorValue( + absl::InvalidArgumentError("invalid address matcher")); + } + + return NetworkAddressMatcher::MakeValue(arena, std::move(rep).value()); +} + +cel::Value containsAddress(const cel::OpaqueValue& matcher, + const cel::OpaqueValue& addr) { + const auto* matcher_rep = NetworkAddressMatcher::Unwrap(matcher); + auto addr_rep = NetworkAddressRep::Unwrap(addr); + if (matcher_rep == nullptr || !addr_rep.has_value()) { + // dispatcher should catch this, but right now only distiguishes at the + // kind level. + return cel::ErrorValue(absl::InvalidArgumentError("no matching overload")); + } + return cel::BoolValue(matcher_rep->Match(*addr_rep)); +} + +} // namespace + +cel::Value NetworkAddressRep::MakeValue(const NetworkAddressRep& rep) { + return UnsafeOpaqueValue(&kAddressDispatcher, + cel::OpaqueValueContent::From(rep)); +} + +std::optional NetworkAddressRep::Unwrap( + const cel::Value& value) { + auto opaque = value.AsOpaque(); + if (!opaque.has_value() || + opaque->GetTypeId() != cel::TypeId()) { + return std::nullopt; + } + + // Note: safety depends on: + // 1) correctly implementing GetTypeId + // 2) the TypeId is unique + // 3) all calls to UnsafeOpaqueValue with the dispatcher provide the expected + // content type. + return opaque->content().To(); +} + +std::optional NetworkAddressRep::Parse( + absl::string_view str) { + uint32_t ipv4 = 0; + char ipv6[16]; + auto version = ParseAddressImpl(str, &ipv4, ipv6); + if (!version.ok()) { + return std::nullopt; + } + if (*version != IpVersion::kIPv4) { + return std::nullopt; + } + NetworkAddressRep rep; + rep.version_ = *version; + rep.addr_.v4 = ipv4; + return rep; +} + +bool NetworkAddressRep::IsEqualTo(const NetworkAddressRep& other) const { + if (version_ != other.version_) { + return false; + } + if (version_ == IpVersion::kIPv4) { + return addr_.v4 == other.addr_.v4; + } + return false; +} + +bool NetworkAddressRep::IsLessThan(const NetworkAddressRep& other) const { + if (version_ != other.version_) { + return version_ < other.version_; + } + if (version_ == IpVersion::kIPv4) { + return addr_.v4 < other.addr_.v4; + } + return false; +} + +std::optional NetworkAddressMatcher::Parse( + absl::string_view str) { + // range style addr-addr + int dash_pos = str.find('-'); + if (dash_pos == absl::string_view::npos) { + // TODO(uncreated-issue/86): CIDR style addr/prefix-length + return std::nullopt; + } + absl::string_view min_str = str.substr(0, dash_pos); + absl::string_view max_str = str.substr(dash_pos + 1); + + NetworkRangev4 v4; + NetworkRangev6 v6; + auto min_parse = ParseAddressImpl(min_str, &v4.min_incl, v6.min_incl); + if (!min_parse.ok()) { + return std::nullopt; + } + auto max_parse = ParseAddressImpl(max_str, &v4.max_incl, v6.max_incl); + if (!max_parse.ok()) { + return std::nullopt; + } + if (*min_parse != *max_parse) { + return std::nullopt; + } + NetworkAddressMatcher rep; + if (*min_parse == IpVersion::kIPv4) { + if (v4.min_incl > v4.max_incl) { + return std::nullopt; + } + rep.ranges_v4_.push_back(v4); + } else if (*min_parse == IpVersion::kIPv6) { + return std::nullopt; + } + + return rep; +} + +cel::Value NetworkAddressMatcher::MakeValue(google::protobuf::Arena* arena, + NetworkAddressMatcher rep) { + auto* iface = + google::protobuf::Arena::Create(arena, std::move(rep)); + + return cel::OpaqueValue(iface, arena); +} + +const NetworkAddressMatcher* NetworkAddressMatcher::Unwrap( + const cel::Value& value) { + auto opaque = value.AsOpaque(); + if (!opaque.has_value() || opaque->interface() == nullptr || + opaque->GetTypeId() != cel::TypeId()) { + return nullptr; + } + // Note: the safety of down casting like this depends on guaranteeing the + // GetTypeId implementation is correct and is a unique ID. The CEL runtime + // does not inspect or modify the interface type outside calling the interface + // member functions. + return &(static_cast(opaque->interface()) + ->rep()); +} + +bool NetworkAddressMatcher::Match(const NetworkAddressRep& addr) const { + if (addr.IsZeroValue()) { + return false; + } + if (addr.IsIPv4()) { + for (const auto& range : ranges_v4_) { + if (addr.GetIPv4() >= range.min_incl && + addr.GetIPv4() <= range.max_incl) { + return true; + } + } + } + + // TODO(uncreated-issue/86): ipv6 support + return false; +} + +bool NetworkAddressMatcher::IsEqualTo( + const NetworkAddressMatcher& other) const { + if (ranges_v4_.size() != other.ranges_v4_.size()) { + return false; + } + for (int i = 0; i < ranges_v4_.size(); ++i) { + if (ranges_v4_[i].min_incl != other.ranges_v4_[i].min_incl || + ranges_v4_[i].max_incl != other.ranges_v4_[i].max_incl) { + return false; + } + } + return true; +} + +cel::CompilerLibrary NetworkFunctionsCompilerLibrary() { + return cel::CompilerLibrary("cel_codelab.net", ConfigureNetworkFunctions); +} + +absl::Status RegisterNetworkTypes(cel::TypeRegistry& registry, + const cel::RuntimeOptions& options) { + CEL_RETURN_IF_ERROR(registry.RegisterType(AddressType().GetOpaque())); + CEL_RETURN_IF_ERROR(registry.RegisterType(AddressMatcherType().GetOpaque())); + return absl::OkStatus(); +} + +absl::Status RegisterNetworkFunctions(cel::FunctionRegistry& registry, + const cel::RuntimeOptions& options) { + // TODO(uncreated-issue/86): remaining functions + auto s = cel::UnaryFunctionAdapter:: + RegisterGlobalOverload("net.parseAddress", &parseAddress, registry); + s.Update(cel::UnaryFunctionAdapter:: + RegisterGlobalOverload("net.parseAddressOrZero", + &parseAddressOrZero, registry)); + + s.Update(cel::UnaryFunctionAdapter:: + RegisterGlobalOverload("net.parseAddressMatcher", + &parseAddressMatcher, registry)); + s.Update(cel::BinaryFunctionAdapter< + cel::Value, const cel::OpaqueValue&, + const cel::OpaqueValue&>::RegisterMemberOverload("containsAddress", + &containsAddress, + registry)); + return s; +} + +} // namespace cel_codelab diff --git a/codelab/network_functions.h b/codelab/network_functions.h new file mode 100644 index 000000000..5a90ac153 --- /dev/null +++ b/codelab/network_functions.h @@ -0,0 +1,197 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Example extension library for introducing an OpaqueValue type. +// +// The address handling is simplified for the example, and IPv6 is +// unimplemented. Do not use this as-is. + +#ifndef THIRD_PARTY_CEL_CPP_CODELAB_NETWORK_FUNCTIONS_H_ +#define THIRD_PARTY_CEL_CPP_CODELAB_NETWORK_FUNCTIONS_H_ + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "common/value.h" +#include "compiler/compiler.h" +#include "runtime/function_registry.h" +#include "runtime/runtime_options.h" +#include "runtime/type_registry.h" +#include "google/protobuf/arena.h" + +namespace cel_codelab { + +enum class IpVersion : uint8_t { + kUnset = 0, + kIPv4 = 4, + kIPv6 = 6, // unimplemented, but present for illustration. +}; + +// Represents a network address. To simplify the CEL type representation, this +// only supports IPv4. +// +// A the default value of 0v0 is special, and represents an invalid address, +// comparing unequal to anything except itself. For the purposes of ordering, +// compares less than any valid address. +// +// The example extension functions include a version that returns a zero value +// on error and a version that returns a CEL error. +// +// This class is stored inline in the OpaqueValue because it is compact and +// trivially copyable. +class NetworkAddressRep { + public: + // Creates a Value that wraps the given NetworkAddress. The representation is + // copied to the provided arena. + static cel::Value MakeValue(const NetworkAddressRep& rep); + + // Unwraps a Value into a NetworkAddressRep. Returns nullptr if the value is + // not a NetworkAddress. + static absl::optional Unwrap(const cel::Value& value); + + // Parses a string representation of a network address. Returns nullopt if + // the string is not a valid network address. + // + // TODO(uncreated-issue/86): error handling simplified for example, real usage should + // provide some diagnostic for the parse failure. + static absl::optional Parse(absl::string_view str); + + // Zero value for an invalid address. + NetworkAddressRep() : addr_({0}), version_(IpVersion::kUnset) {} + NetworkAddressRep(const NetworkAddressRep& other) = default; + NetworkAddressRep(NetworkAddressRep&& other) = default; + NetworkAddressRep& operator=(const NetworkAddressRep& other) = default; + NetworkAddressRep& operator=(NetworkAddressRep&& other) = default; + + IpVersion version() const { return version_; } + + bool IsZeroValue() const { return version_ == IpVersion::kUnset; } + bool IsIPv4() const { return version_ == IpVersion::kIPv4; } + bool IsIPv6() const { return false; } + + absl::optional TryGetIPv4() const { + if (version_ == IpVersion::kIPv4) { + return addr_.v4; + } + return absl::nullopt; + } + + absl::string_view TryGetIPv6() const { return absl::string_view(); } + + std::string Format() const { + if (version_ == IpVersion::kUnset) { + return "null"; + } + if (version_ == IpVersion::kIPv4) { + return absl::StrCat( + (addr_.v4 & 0xFF000000) >> 24, ".", (addr_.v4 & 0x00FF0000) >> 16, + ".", (addr_.v4 & 0x0000FF00) >> 8, ".", (addr_.v4 & 0x000000FF)); + } + return "v6 not yet implemented"; + } + + uint32_t GetIPv4() const { return addr_.v4; } + + bool IsEqualTo(const NetworkAddressRep& other) const; + bool IsLessThan(const NetworkAddressRep& other) const; + + private: + union { + uint32_t v0; // zero value + // Integer representation of an IPv4 address (system byte order) + uint32_t v4; + // TO_DO : add ipv6. this prevents storing the value inline due to size, so + // skipped here. + } addr_; + IpVersion version_; +}; + +// Represents a matcher for network addresses. +// +// Simple implementation that just stores a list of matching ranges. +// +// This is too big to store inline and has non-trivial copy and move behavior, +// so the inline representation is a pointer to an arena-allocated object. +class NetworkAddressMatcher { + public: + // Creates a Value that wraps the given NetworkAddress. + static cel::Value MakeValue(google::protobuf::Arena* arena, NetworkAddressMatcher rep); + + // Unwraps a Value into a NetworkAddressMatcher. Returns nullptr if the value + // is not a NetworkAddressMatcher. + static const NetworkAddressMatcher* Unwrap(const cel::Value& value); + + // Parses a string representation of a network address matcher. Returns + // nullopt if the string is not a valid network address matcher. + // + // TODO(uncreated-issue/86): supports a simple IPv4 range for illustration: e.g. + // 8.8.0.0-8.8.255.255 + static absl::optional Parse(absl::string_view str); + + // Default value for an empty matcher. Matches nothing. + NetworkAddressMatcher() = default; + NetworkAddressMatcher(const NetworkAddressMatcher& other) = default; + NetworkAddressMatcher(NetworkAddressMatcher&& other) = default; + NetworkAddressMatcher& operator=(const NetworkAddressMatcher& other) = + default; + NetworkAddressMatcher& operator=(NetworkAddressMatcher&& other) = default; + + bool IsEmpty() const { return ranges_v4_.empty(); } + + bool IsEqualTo(const NetworkAddressMatcher& other) const; + + bool Match(const NetworkAddressRep& addr) const; + + private: + struct NetworkRangev4 { + uint32_t min_incl; + uint32_t max_incl; + }; + + // placeholder for illustration, not implemented. + struct NetworkRangev6 { + char min_incl[16]; + char max_incl[16]; + }; + + friend void swap(NetworkAddressMatcher& lhs, NetworkAddressMatcher& rhs) { + using std::swap; + swap(lhs.ranges_v4_, rhs.ranges_v4_); + } + + // Sorted, non-overlapping ranges of matching IP addresses. + std::vector ranges_v4_; +}; + +// Returns a compiler library that adds the network functions to the type +// checker. +cel::CompilerLibrary NetworkFunctionsCompilerLibrary(); + +// Registers the network functions in a runtime for evaluation. +absl::Status RegisterNetworkFunctions(cel::FunctionRegistry& registry, + const cel::RuntimeOptions& options); + +// Registers the network types in a runtime for evaluation. This is needed +// for resolving the type name to a runtime type `net.Address != type('foo')`. +absl::Status RegisterNetworkTypes(cel::TypeRegistry& registry, + const cel::RuntimeOptions& options); + +} // namespace cel_codelab + +#endif // THIRD_PARTY_CEL_CPP_CODELAB_NETWORK_FUNCTIONS_H_ diff --git a/codelab/network_functions_test.cc b/codelab/network_functions_test.cc new file mode 100644 index 000000000..468221da7 --- /dev/null +++ b/codelab/network_functions_test.cc @@ -0,0 +1,347 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "codelab/network_functions.h" + +#include +#include +#include + +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "common/decl.h" +#include "common/minimal_descriptor_pool.h" +#include "common/type.h" +#include "common/value.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "compiler/standard_library.h" +#include "internal/benchmark.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "runtime/activation.h" +#include "runtime/constant_folding.h" +#include "runtime/runtime.h" +#include "runtime/runtime_options.h" +#include "runtime/standard_runtime_builder_factory.h" +#include "google/protobuf/arena.h" + +namespace cel_codelab { +namespace { + +using ::absl_testing::IsOk; +using ::cel::Activation; +using ::cel::Compiler; +using ::cel::Program; +using ::cel::Runtime; +using ::cel::RuntimeOptions; +using ::cel::StringValue; +using ::testing::HasSubstr; + +struct TestCase { + std::string name; + std::string expr; + std::string type_check_err_substr; +}; + +class NetworkFunctionsCheckerTest : public testing::TestWithParam {}; + +TEST_P(NetworkFunctionsCheckerTest, DeclarationsTest) { + const TestCase& test_case = GetParam(); + + ASSERT_OK_AND_ASSIGN( + auto compiler_builder, + cel::NewCompilerBuilder(cel::GetMinimalDescriptorPool())); + ASSERT_THAT(compiler_builder->AddLibrary(cel::StandardCompilerLibrary()), + IsOk()); + ASSERT_THAT(compiler_builder->AddLibrary(NetworkFunctionsCompilerLibrary()), + IsOk()); + ASSERT_OK_AND_ASSIGN(auto compiler, compiler_builder->Build()); + + ASSERT_OK_AND_ASSIGN(auto result, compiler->Compile(test_case.expr)); + + if (!test_case.type_check_err_substr.empty()) { + EXPECT_THAT(result.FormatError(), + HasSubstr(test_case.type_check_err_substr)); + return; + } + + EXPECT_TRUE(result.IsValid()) << result.FormatError(); +} + +INSTANTIATE_TEST_SUITE_P( + NetworkFunctionsCheckerTests, NetworkFunctionsCheckerTest, + testing::ValuesIn({ + {"type_identifier_addr", "net.Address != type(1)"}, + {"type_identifier_addr_2", "net.Address != list"}, + {"type_identifier_addr_matcher", "net.AddressMatcher != type(1)"}, + {"parse_address", "net.parseAddress('1.2.3.4')"}, + {"parse_address_or_zero", "net.parseAddressOrZero('1.2.3.4')"}, + {"parse_address_no_match", "net.parseAddress(1.0)", + "no matching overload for 'net.parseAddress'"}, + {"address_zero", "net.addressZeroValue"}, + {"equals", "net.parseAddress('1.2.3.4') != net.addressZeroValue"}, + {"address_matcher_parse", + "net.parseAddressMatcher('8.8.8.0-8.8.8.255')"}, + {"address_matcher_parse_invalid", + "net.parseAddressMatcher('8.8.8.0-8.8.4.255')"}, + {"address_matcher_contains", + "net.parseAddressMatcher('8.8.8.0-8.8.8.255').containsAddress(net." + "parseAddress('8.8.8.1'))"}, + {"address_matcher_contains_string", + "net.parseAddressMatcher('8.8.8.0-8.8.8.255').containsAddress('8.8.8." + "1')"}, + }), + [](const testing::TestParamInfo& + info) { return info.param.name; }); + +struct RuntimeTestCase { + std::string name; + std::string expr; + std::string runtime_err_substr; + bool expected_value = true; +}; + +class NetworkFunctionsRuntimeTest + : public testing::TestWithParam {}; + +TEST_P(NetworkFunctionsRuntimeTest, EvaluationTest) { + const RuntimeTestCase& test_case = GetParam(); + + ASSERT_OK_AND_ASSIGN( + auto compiler_builder, + cel::NewCompilerBuilder(cel::GetMinimalDescriptorPool())); + ASSERT_THAT(compiler_builder->AddLibrary(cel::StandardCompilerLibrary()), + IsOk()); + ASSERT_THAT(compiler_builder->AddLibrary(NetworkFunctionsCompilerLibrary()), + IsOk()); + ASSERT_OK_AND_ASSIGN(auto compiler, compiler_builder->Build()); + + ASSERT_OK_AND_ASSIGN(auto result, compiler->Compile(test_case.expr)); + + ASSERT_OK_AND_ASSIGN(auto ast, result.ReleaseAst()); + RuntimeOptions runtime_options; + runtime_options.enable_qualified_type_identifiers = true; + ASSERT_OK_AND_ASSIGN(auto runtime_builder, + CreateStandardRuntimeBuilder( + cel::GetMinimalDescriptorPool(), runtime_options)); + ASSERT_THAT( + RegisterNetworkTypes(runtime_builder.type_registry(), runtime_options), + IsOk()); + ASSERT_THAT(RegisterNetworkFunctions(runtime_builder.function_registry(), + runtime_options), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(runtime_builder).Build()); + + ASSERT_OK_AND_ASSIGN(auto program, runtime->CreateProgram(std::move(ast))); + + google::protobuf::Arena arena; + Activation activation; + ASSERT_OK_AND_ASSIGN(auto eval_result, program->Evaluate(&arena, activation)); + + if (!test_case.runtime_err_substr.empty()) { + if (!eval_result.IsError()) { + FAIL() << "Expected error, but got: " << eval_result.DebugString(); + } + EXPECT_THAT(eval_result.GetError().ToStatus().message(), + HasSubstr(test_case.runtime_err_substr)); + return; + } + + if (test_case.expected_value) { + EXPECT_TRUE(eval_result.IsBool() && eval_result.GetBool()) + << eval_result.DebugString(); + } +} + +INSTANTIATE_TEST_SUITE_P( + NetworkFunctionsRuntimeTests, NetworkFunctionsRuntimeTest, + testing::ValuesIn( + {{"type_identifier_addr", "net.Address != type(1)"}, + {"type_identifier_addr_2", "net.Address != list"}, + {"type_identifier_addr_matcher", "net.AddressMatcher != type(1)"}, + {"parse_address", + "net.parseAddress('1.2.3.4') == net.parseAddress('1.2.3.4')"}, + {"parse_address_2", + "net.parseAddress('1.2.3.4') != net.parseAddress('2.3.4.5')"}, + {"parse_address_invalid", + "net.parseAddress('256.2.3.4') != net.parseAddress('1.2.3.4')", + "invalid address"}, + {"parse_address_or_zero", + "net.parseAddressOrZero('256.2.3.4') != " + "net.parseAddressOrZero('1.2.3.4')"}, + {"parse_address_matcher", + "net.parseAddressMatcher('8.8.8.0-8.8.8.255') != " + "net.parseAddressMatcher('8.8.8.0-8.8.8.127')"}, + {"address_matcher_matches", + "net.parseAddressMatcher('8.8.8.0-8.8.8.255').containsAddress(net." + "parseAddress('8.8.8.1'))"}}), + [](const testing::TestParamInfo& + info) { return info.param.name; }); + +class BenchmarkState { + public: + static absl::StatusOr Create(bool optimize) { + CEL_ASSIGN_OR_RETURN( + auto compiler_builder, + cel::NewCompilerBuilder(cel::GetMinimalDescriptorPool())); + CEL_RETURN_IF_ERROR( + compiler_builder->AddLibrary(cel::StandardCompilerLibrary())); + CEL_RETURN_IF_ERROR( + compiler_builder->AddLibrary(NetworkFunctionsCompilerLibrary())); + compiler_builder->GetCheckerBuilder() + .AddVariable(MakeVariableDecl("ip", cel::StringType())) + .IgnoreError(); + + CEL_ASSIGN_OR_RETURN(auto compiler, compiler_builder->Build()); + + RuntimeOptions runtime_options; + CEL_ASSIGN_OR_RETURN(auto runtime_builder, + CreateStandardRuntimeBuilder( + cel::GetMinimalDescriptorPool(), runtime_options)); + CEL_RETURN_IF_ERROR( + RegisterNetworkTypes(runtime_builder.type_registry(), runtime_options)); + CEL_RETURN_IF_ERROR(RegisterNetworkFunctions( + runtime_builder.function_registry(), runtime_options)); + + if (optimize) { + CEL_RETURN_IF_ERROR( + cel::extensions::EnableConstantFolding(runtime_builder)); + } + CEL_ASSIGN_OR_RETURN(auto runtime, std::move(runtime_builder).Build()); + return BenchmarkState(std::move(compiler), std::move(runtime)); + } + + absl::StatusOr> MakeProgram(absl::string_view expr) { + CEL_ASSIGN_OR_RETURN(auto result, compiler_->Compile(expr)); + if (!result.IsValid()) { + return absl::InvalidArgumentError(result.FormatError()); + } + CEL_ASSIGN_OR_RETURN(auto ast, result.ReleaseAst()); + return runtime_->CreateProgram(std::move(ast)); + } + + private: + BenchmarkState(std::unique_ptr c, std::unique_ptr r) + : compiler_(std::move(c)), runtime_(std::move(r)) {} + + std::unique_ptr compiler_; + std::unique_ptr runtime_; + std::unique_ptr constants_; +}; + +void BM_ParseAddress(benchmark::State& state) { + bool optimize = state.range(0); + auto runner = BenchmarkState::Create(optimize); + + ABSL_CHECK_OK(runner.status()); + + auto program = runner->MakeProgram("net.parseAddress('1.2.3.4')"); + ABSL_CHECK_OK(program.status()); + + google::protobuf::Arena arena; + Activation activation; + for (auto s : state) { + auto result = (*program)->Evaluate(&arena, activation); + ABSL_CHECK_OK(result.status()); + } +} + +void BM_ParseAddressVar(benchmark::State& state) { + bool optimize = state.range(0); + auto runner = BenchmarkState::Create(optimize); + + ABSL_CHECK_OK(runner.status()); + + auto program = runner->MakeProgram("net.parseAddress(ip)"); + ABSL_CHECK_OK(program.status()); + + google::protobuf::Arena arena; + Activation activation; + activation.InsertOrAssignValue("ip", StringValue::From("8.8.8.8", &arena)); + for (auto s : state) { + auto result = (*program)->Evaluate(&arena, activation); + ABSL_CHECK_OK(result.status()); + } +} + +void BM_ParseAddressMatcher(benchmark::State& state) { + bool optimize = state.range(0); + auto runner = BenchmarkState::Create(optimize); + + ABSL_CHECK_OK(runner.status()); + + auto program = + runner->MakeProgram("net.parseAddressMatcher('8.8.8.0-8.8.8.255')"); + ABSL_CHECK_OK(program.status()); + + google::protobuf::Arena arena; + Activation activation; + for (auto s : state) { + auto result = (*program)->Evaluate(&arena, activation); + ABSL_CHECK_OK(result.status()); + } +} + +void BM_ParseAddressMatcherMatches(benchmark::State& state) { + bool optimize = state.range(0); + auto runner = BenchmarkState::Create(optimize); + + ABSL_CHECK_OK(runner.status()); + + auto program = runner->MakeProgram( + "net.parseAddressMatcher('8.8.8.0-8.8.8.255').containsAddress(net." + "parseAddress('8.8.8.1'))"); + ABSL_CHECK_OK(program.status()); + + google::protobuf::Arena arena; + Activation activation; + for (auto s : state) { + auto result = (*program)->Evaluate(&arena, activation); + ABSL_CHECK_OK(result.status()); + } +} + +void BM_ParseAddressMatcherMatchesVar(benchmark::State& state) { + bool optimize = state.range(0); + auto runner = BenchmarkState::Create(optimize); + + ABSL_CHECK_OK(runner.status()); + + auto program = runner->MakeProgram( + "net.parseAddressMatcher('8.8.0.0-8.8.255.255').containsAddress(net." + "parseAddress(ip))"); + ABSL_CHECK_OK(program.status()); + + google::protobuf::Arena arena; + Activation activation; + activation.InsertOrAssignValue("ip", StringValue::From("8.8.4.4", &arena)); + for (auto s : state) { + auto result = (*program)->Evaluate(&arena, activation); + ABSL_CHECK_OK(result.status()); + } +} + +BENCHMARK(BM_ParseAddress)->Arg(0)->Arg(1); +BENCHMARK(BM_ParseAddressVar)->Arg(0)->Arg(1); +BENCHMARK(BM_ParseAddressMatcher)->Arg(0)->Arg(1); +BENCHMARK(BM_ParseAddressMatcherMatches)->Arg(0)->Arg(1); +BENCHMARK(BM_ParseAddressMatcherMatchesVar)->Arg(0)->Arg(1); + +} // namespace +} // namespace cel_codelab diff --git a/codelab/solutions/BUILD b/codelab/solutions/BUILD new file mode 100644 index 000000000..a1597e182 --- /dev/null +++ b/codelab/solutions/BUILD @@ -0,0 +1,187 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) + +cc_library( + name = "exercise1", + srcs = ["exercise1.cc"], + hdrs = ["//codelab:exercise1.h"], + deps = [ + "//eval/public:activation", + "//eval/public:builtin_func_registrar", + "//eval/public:cel_expr_builder_factory", + "//eval/public:cel_expression", + "//eval/public:cel_options", + "//eval/public:cel_value", + "//internal:status_macros", + "//parser", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "exercise1_test", + srcs = ["//codelab:exercise1_test.cc"], + deps = [ + ":exercise1", + "//internal:testing", + "@com_google_absl//absl/status", + ], +) + +cc_library( + name = "exercise2", + srcs = ["exercise2.cc"], + hdrs = ["//codelab:exercise2.h"], + deps = [ + "//checker:type_checker_builder", + "//codelab:cel_compiler", + "//common:decl", + "//common:type", + "//compiler", + "//compiler:compiler_factory", + "//compiler:standard_library", + "//eval/public:activation", + "//eval/public:activation_bind_helper", + "//eval/public:builtin_func_registrar", + "//eval/public:cel_expr_builder_factory", + "//eval/public:cel_expression", + "//eval/public:cel_options", + "//eval/public:cel_value", + "//internal:status_macros", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_googleapis//google/rpc/context:attribute_context_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "exercise2_test", + srcs = ["//codelab:exercise2_test.cc"], + deps = [ + ":exercise2", + "//internal:testing", + "@com_google_absl//absl/status", + "@com_google_googleapis//google/rpc/context:attribute_context_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "exercise3_test", + srcs = ["exercise3_test.cc"], + deps = [ + ":exercise2", + "//internal:testing", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_googleapis//google/rpc/context:attribute_context_cc_proto", + ], +) + +cc_library( + name = "exercise4", + srcs = ["exercise4.cc"], + hdrs = ["//codelab:exercise4.h"], + deps = [ + "//codelab:cel_compiler", + "//common:decl", + "//common:type", + "//compiler", + "//compiler:compiler_factory", + "//compiler:standard_library", + "//eval/public:activation", + "//eval/public:activation_bind_helper", + "//eval/public:builtin_func_registrar", + "//eval/public:cel_expr_builder_factory", + "//eval/public:cel_expression", + "//eval/public:cel_function_adapter", + "//eval/public:cel_options", + "//eval/public:cel_value", + "//internal:status_macros", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", + "@com_google_googleapis//google/rpc/context:attribute_context_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "exercise4_test", + srcs = ["//codelab:exercise4_test.cc"], + deps = [ + ":exercise4", + "//internal:testing", + "@com_google_googleapis//google/rpc/context:attribute_context_cc_proto", + "@com_google_protobuf//:protobuf", + "@com_google_protobuf//:struct_cc_proto", + ], +) + +cc_library( + name = "exercise10", + srcs = ["exercise10.cc"], + hdrs = ["//codelab:exercise10.h"], + deps = [ + "//checker:validation_result", + "//codelab:network_functions", + "//common:decl", + "//common:minimal_descriptor_pool", + "//common:type", + "//common:value", + "//compiler", + "//compiler:compiler_factory", + "//compiler:standard_library", + "//runtime", + "//runtime:activation", + "//runtime:runtime_builder", + "//runtime:runtime_options", + "//runtime:standard_runtime_builder_factory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "exercise10_test", + srcs = ["//codelab:exercise10_test.cc"], + deps = [ + ":exercise10", + "//internal:testing", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_protobuf//:protobuf", + ], +) diff --git a/codelab/solutions/exercise1.cc b/codelab/solutions/exercise1.cc new file mode 100644 index 000000000..aef6c0efe --- /dev/null +++ b/codelab/solutions/exercise1.cc @@ -0,0 +1,107 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "codelab/exercise1.h" + +#include +#include + +#include "cel/expr/syntax.pb.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "eval/public/activation.h" +#include "eval/public/builtin_func_registrar.h" +#include "eval/public/cel_expr_builder_factory.h" +#include "eval/public/cel_expression.h" +#include "eval/public/cel_options.h" +#include "eval/public/cel_value.h" +#include "internal/status_macros.h" +#include "parser/parser.h" +#include "google/protobuf/arena.h" + +namespace cel_codelab { +namespace { + +using ::cel::expr::ParsedExpr; +using ::google::api::expr::parser::Parse; +using ::google::api::expr::runtime::Activation; +using ::google::api::expr::runtime::CelExpression; +using ::google::api::expr::runtime::CelExpressionBuilder; +using ::google::api::expr::runtime::CelValue; +using ::google::api::expr::runtime::CreateCelExpressionBuilder; +using ::google::api::expr::runtime::InterpreterOptions; +using ::google::api::expr::runtime::RegisterBuiltinFunctions; + +// Convert the CelResult to a C++ string if it is string typed. Otherwise, +// return invalid argument error. This takes a copy to avoid lifecycle concerns +// (the evaluator may represent strings as stringviews backed by the input +// expression). +absl::StatusOr ConvertResult(const CelValue& value) { + if (CelValue::StringHolder inner_value; value.GetValue(&inner_value)) { + return std::string(inner_value.value()); + } else { + return absl::InvalidArgumentError(absl::StrCat( + "expected string result got '", CelValue::TypeName(value.type()), "'")); + } +} +} // namespace + +absl::StatusOr ParseAndEvaluate(absl::string_view cel_expr) { + // === Start Codelab === + // Setup a default environment for building expressions. + InterpreterOptions options; + std::unique_ptr builder = + CreateCelExpressionBuilder(options); + + CEL_RETURN_IF_ERROR( + RegisterBuiltinFunctions(builder->GetRegistry(), options)); + + // Parse the expression. This is fine for codelabs, but this skips the type + // checking phase. It won't check that functions and variables are available + // in the environment, and it won't handle certain ambiguous identifier + // expressions (e.g. container lookup vs namespaced name, packaged function + // vs. receiver call style function). + ParsedExpr parsed_expr; + CEL_ASSIGN_OR_RETURN(parsed_expr, Parse(cel_expr)); + + // The evaluator uses a proto Arena for incidental allocations during + // evaluation. + google::protobuf::Arena arena; + // The activation provides variables and functions that are bound into the + // expression environment. In this example, there's no context expected, so + // we just provide an empty one to the evaluator. + Activation activation; + + // Build the expression plan. This assumes that the source expression AST and + // the expression builder outlive the CelExpression object. + CEL_ASSIGN_OR_RETURN(std::unique_ptr expression_plan, + builder->CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + // Actually run the expression plan. We don't support any environment + // variables at the moment so just use an empty activation. + CEL_ASSIGN_OR_RETURN(CelValue result, + expression_plan->Evaluate(activation, &arena)); + + // Convert the result to a c++ string. CelValues may reference instances from + // either the input expression, or objects allocated on the arena, so we need + // to pass ownership (in this case by copying to a new instance and returning + // that). + return ConvertResult(result); + // === End Codelab === +} + +} // namespace cel_codelab diff --git a/codelab/solutions/exercise10.cc b/codelab/solutions/exercise10.cc new file mode 100644 index 000000000..0d2c197d6 --- /dev/null +++ b/codelab/solutions/exercise10.cc @@ -0,0 +1,136 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "codelab/exercise10.h" + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "checker/validation_result.h" +#include "codelab/network_functions.h" +#include "common/decl.h" +#include "common/minimal_descriptor_pool.h" +#include "common/type.h" +#include "common/value.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "compiler/standard_library.h" +#include "runtime/activation.h" +#include "runtime/runtime.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_options.h" +#include "runtime/standard_runtime_builder_factory.h" +#include "google/protobuf/arena.h" + +namespace cel_codelab { + +namespace { + +absl::StatusOr> ConfigureCompiler() { + absl::StatusOr> compiler_builder = + cel::NewCompilerBuilder(cel::GetMinimalDescriptorPool()); + if (!compiler_builder.ok()) { + return std::move(compiler_builder).status(); + } + absl::Status s = + (*compiler_builder)->AddLibrary(cel::StandardCompilerLibrary()); + // =========================================================================== + // Codelab: Update compiler builder with functions from network_functions.h + // and add a varible for the input IP. + // =========================================================================== + s.Update((*compiler_builder)->AddLibrary(NetworkFunctionsCompilerLibrary())); + s.Update((*compiler_builder) + ->GetCheckerBuilder() + .AddVariable(cel::MakeVariableDecl("ip", cel::StringType()))); + if (!s.ok()) return s; + + return (*compiler_builder)->Build(); +} + +absl::StatusOr> ConfigureRuntime() { + cel::RuntimeOptions runtime_options; + // Note: this is needed to resolve net.Address as a `type` constant. + runtime_options.enable_qualified_type_identifiers = true; + absl::StatusOr runtime_builder = + cel::CreateStandardRuntimeBuilder(cel::GetMinimalDescriptorPool(), + runtime_options); + // =========================================================================== + // Codelab: Update runtime builder with functions from network_functions.h + // =========================================================================== + absl::Status s = + RegisterNetworkTypes(runtime_builder->type_registry(), runtime_options); + s.Update(RegisterNetworkFunctions(runtime_builder->function_registry(), + runtime_options)); + if (!s.ok()) return s; + + return std::move(runtime_builder).value().Build(); +} + +} // namespace + +absl::StatusOr CompileAndEvaluateExercise10(absl::string_view expression, + absl::string_view ip) { + absl::StatusOr> compiler = ConfigureCompiler(); + if (!compiler.ok()) { + return std::move(compiler).status(); + } + + absl::StatusOr> runtime = ConfigureRuntime(); + if (!runtime.ok()) { + return std::move(runtime).status(); + } + + absl::StatusOr checked = + (*compiler)->Compile(expression); + if (!checked.ok()) { + return std::move(checked).status(); + } + + if (!checked->IsValid() || checked->GetAst() == nullptr) { + return absl::InvalidArgumentError(checked->FormatError()); + } + + absl::StatusOr> program = + (*runtime)->CreateProgram(checked->ReleaseAst().value()); + + if (!program.ok()) { + return std::move(program).status(); + } + + cel::Activation activation; + google::protobuf::Arena arena; + activation.InsertOrAssignValue("ip", cel::StringValue::From(ip, &arena)); + absl::StatusOr result = (*program)->Evaluate(&arena, activation); + + if (!result.ok()) { + return std::move(result).status(); + } + + if (result->IsBool()) { + return result->GetBool(); + } + + if (result->IsError()) { + return result->GetError().ToStatus(); + } + + return absl::InvalidArgumentError( + absl::StrCat("unexpected result type: ", result->DebugString())); +} + +} // namespace cel_codelab diff --git a/codelab/solutions/exercise2.cc b/codelab/solutions/exercise2.cc new file mode 100644 index 000000000..d07645aed --- /dev/null +++ b/codelab/solutions/exercise2.cc @@ -0,0 +1,148 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "codelab/exercise2.h" + +#include + +#include "cel/expr/checked.pb.h" +#include "cel/expr/syntax.pb.h" +#include "google/rpc/context/attribute_context.pb.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "checker/type_checker_builder.h" +#include "codelab/cel_compiler.h" +#include "common/decl.h" +#include "common/type.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "compiler/standard_library.h" +#include "eval/public/activation.h" +#include "eval/public/activation_bind_helper.h" +#include "eval/public/builtin_func_registrar.h" +#include "eval/public/cel_expr_builder_factory.h" +#include "eval/public/cel_expression.h" +#include "eval/public/cel_options.h" +#include "eval/public/cel_value.h" +#include "internal/status_macros.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel_codelab { +namespace { + +using ::cel::expr::CheckedExpr; +using ::google::api::expr::runtime::Activation; +using ::google::api::expr::runtime::CelError; +using ::google::api::expr::runtime::CelExpression; +using ::google::api::expr::runtime::CelExpressionBuilder; +using ::google::api::expr::runtime::CelValue; +using ::google::api::expr::runtime::CreateCelExpressionBuilder; +using ::google::api::expr::runtime::InterpreterOptions; +using ::google::api::expr::runtime::ProtoUnsetFieldOptions; +using ::google::api::expr::runtime::RegisterBuiltinFunctions; +using ::google::rpc::context::AttributeContext; + +absl::StatusOr> MakeCelCompiler() { + // Note: we are using the generated descriptor pool here for simplicity, but + // it has the drawback of including all message types that are linked into the + // binary instead of just the ones expected for the CEL environment. + google::protobuf::LinkMessageReflection(); + CEL_ASSIGN_OR_RETURN( + std::unique_ptr builder, + cel::NewCompilerBuilder(google::protobuf::DescriptorPool::generated_pool())); + + CEL_RETURN_IF_ERROR(builder->AddLibrary(cel::StandardCompilerLibrary())); + // === Start Codelab === + cel::TypeCheckerBuilder& checker_builder = builder->GetCheckerBuilder(); + CEL_RETURN_IF_ERROR(checker_builder.AddVariable( + cel::MakeVariableDecl("bool_var", cel::BoolType()))); + CEL_RETURN_IF_ERROR(checker_builder.AddContextDeclaration( + AttributeContext::descriptor()->full_name())); + // === End Codelab === + + return builder->Build(); +} + +// Parse a cel expression and evaluate it against the given activation and +// arena. +absl::StatusOr EvalCheckedExpr(const CheckedExpr& checked_expr, + const Activation& activation, + google::protobuf::Arena* arena) { + // Setup a default environment for building expressions. + InterpreterOptions options; + std::unique_ptr builder = CreateCelExpressionBuilder( + google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), options); + CEL_RETURN_IF_ERROR( + RegisterBuiltinFunctions(builder->GetRegistry(), options)); + + // Note, the expression_plan below is reusable for different inputs, but we + // create one just in time for evaluation here. + CEL_ASSIGN_OR_RETURN(std::unique_ptr expression_plan, + builder->CreateExpression(&checked_expr)); + + CEL_ASSIGN_OR_RETURN(CelValue result, + expression_plan->Evaluate(activation, arena)); + + if (bool value; result.GetValue(&value)) { + return value; + } else if (const CelError * value; result.GetValue(&value)) { + return *value; + } else { + return absl::InvalidArgumentError(absl::StrCat( + "expected 'bool' result got '", result.DebugString(), "'")); + } +} +} // namespace + +absl::StatusOr CompileAndEvaluateWithBoolVar(absl::string_view cel_expr, + bool bool_var) { + CEL_ASSIGN_OR_RETURN(std::unique_ptr compiler, + MakeCelCompiler()); + + CEL_ASSIGN_OR_RETURN(CheckedExpr checked_expr, + CompileToCheckedExpr(*compiler, cel_expr)); + + Activation activation; + google::protobuf::Arena arena; + // === Start Codelab === + activation.InsertValue("bool_var", CelValue::CreateBool(bool_var)); + // === End Codelab === + + return EvalCheckedExpr(checked_expr, activation, &arena); +} + +absl::StatusOr CompileAndEvaluateWithContext( + absl::string_view cel_expr, const AttributeContext& context) { + CEL_ASSIGN_OR_RETURN(std::unique_ptr compiler, + MakeCelCompiler()); + + CEL_ASSIGN_OR_RETURN(CheckedExpr checked_expr, + CompileToCheckedExpr(*compiler, cel_expr)); + + Activation activation; + google::protobuf::Arena arena; + // === Start Codelab === + CEL_RETURN_IF_ERROR(BindProtoToActivation( + &context, &arena, &activation, ProtoUnsetFieldOptions::kBindDefault)); + // === End Codelab === + + return EvalCheckedExpr(checked_expr, activation, &arena); +} + +} // namespace cel_codelab diff --git a/codelab/solutions/exercise3_test.cc b/codelab/solutions/exercise3_test.cc new file mode 100644 index 000000000..8cc919527 --- /dev/null +++ b/codelab/solutions/exercise3_test.cc @@ -0,0 +1,97 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "google/rpc/context/attribute_context.pb.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "codelab/exercise2.h" +#include "internal/testing.h" + +namespace cel_codelab { +namespace { + +using ::absl_testing::IsOkAndHolds; +using ::absl_testing::StatusIs; +using ::google::rpc::context::AttributeContext; + +// Helper for a simple CelExpression with no context. +absl::StatusOr TruthTableTest(absl::string_view statement) { + return CompileAndEvaluateWithBoolVar(statement, /*unused*/ false); +} + +TEST(Exercise3, LogicalOr) { + EXPECT_THAT(TruthTableTest("true || (1 / 0 > 2)"), IsOkAndHolds(true)); + EXPECT_THAT(TruthTableTest("false || (1 / 0 > 2)"), + StatusIs(absl::StatusCode::kInvalidArgument, "divide by zero")); + EXPECT_THAT(TruthTableTest("(1 / 0 > 2) || true"), IsOkAndHolds(true)); + EXPECT_THAT(TruthTableTest("(1 / 0 > 2) || false"), + StatusIs(absl::StatusCode::kInvalidArgument, "divide by zero")); + EXPECT_THAT(TruthTableTest("(1 / 0 > 2) || (1 / 0 > 2)"), + StatusIs(absl::StatusCode::kInvalidArgument, "divide by zero")); + EXPECT_THAT(TruthTableTest("true || true"), IsOkAndHolds(true)); + EXPECT_THAT(TruthTableTest("true || false"), IsOkAndHolds(true)); + EXPECT_THAT(TruthTableTest("false || true"), IsOkAndHolds(true)); + EXPECT_THAT(TruthTableTest("false || false"), IsOkAndHolds(false)); +} + +TEST(Exercise3, LogicalAnd) { + EXPECT_THAT(TruthTableTest("true && (1 / 0 > 2)"), + StatusIs(absl::StatusCode::kInvalidArgument, "divide by zero")); + EXPECT_THAT(TruthTableTest("false && (1 / 0 > 2)"), IsOkAndHolds(false)); + EXPECT_THAT(TruthTableTest("(1 / 0 > 2) && true"), + StatusIs(absl::StatusCode::kInvalidArgument, "divide by zero")); + EXPECT_THAT(TruthTableTest("(1 / 0 > 2) && false"), IsOkAndHolds(false)); + EXPECT_THAT(TruthTableTest("(1 / 0 > 2) && (1 / 0 > 2)"), + StatusIs(absl::StatusCode::kInvalidArgument, "divide by zero")); + EXPECT_THAT(TruthTableTest("true && true"), IsOkAndHolds(true)); + EXPECT_THAT(TruthTableTest("true && false"), IsOkAndHolds(false)); + EXPECT_THAT(TruthTableTest("false && true"), IsOkAndHolds(false)); + EXPECT_THAT(TruthTableTest("false && false"), IsOkAndHolds(false)); +} + +TEST(Exercise3, Ternary) { + EXPECT_THAT(TruthTableTest("(1 / 0 > 2) ? false : false"), + StatusIs(absl::StatusCode::kInvalidArgument, "divide by zero")); + EXPECT_THAT(TruthTableTest("true ? (1 / 0 > 2) : false"), + StatusIs(absl::StatusCode::kInvalidArgument, "divide by zero")); + EXPECT_THAT(TruthTableTest("false ? (1 / 0 > 2) : false"), + IsOkAndHolds(false)); +} + +TEST(Exercise3Context, BadFieldAccess) { + AttributeContext context; + + // This type of error is normally caught by the type checker, to allow + // it to pass we use the dyn() operator to defer checking to runtime. + // typo-ed field name from 'request.host' + EXPECT_THAT( + CompileAndEvaluateWithContext( + "dyn(request).hostname == 'localhost' && true", context), + StatusIs(absl::StatusCode::kNotFound, "no_such_field : hostname")); + EXPECT_THAT(CompileAndEvaluateWithContext( + "dyn(request).hostname == 'localhost' && false", context), + IsOkAndHolds(false)); + + EXPECT_THAT(CompileAndEvaluateWithContext( + "dyn(request).hostname == 'localhost' || true", context), + IsOkAndHolds(true)); + EXPECT_THAT( + CompileAndEvaluateWithContext( + "dyn(request).hostname == 'localhost' || false", context), + StatusIs(absl::StatusCode::kNotFound, "no_such_field : hostname")); +} + +} // namespace +} // namespace cel_codelab diff --git a/codelab/solutions/exercise4.cc b/codelab/solutions/exercise4.cc new file mode 100644 index 000000000..244fdac05 --- /dev/null +++ b/codelab/solutions/exercise4.cc @@ -0,0 +1,175 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "codelab/exercise4.h" + +#include + +#include "cel/expr/checked.pb.h" +#include "google/rpc/context/attribute_context.pb.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "codelab/cel_compiler.h" +#include "common/decl.h" +#include "common/type.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "compiler/standard_library.h" +#include "eval/public/activation.h" +#include "eval/public/activation_bind_helper.h" +#include "eval/public/builtin_func_registrar.h" +#include "eval/public/cel_expr_builder_factory.h" +#include "eval/public/cel_expression.h" +#include "eval/public/cel_function_adapter.h" +#include "eval/public/cel_options.h" +#include "eval/public/cel_value.h" +#include "internal/status_macros.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel_codelab { +namespace { + +using ::cel::expr::CheckedExpr; +using ::google::api::expr::runtime::Activation; +using ::google::api::expr::runtime::BindProtoToActivation; +using ::google::api::expr::runtime::CelError; +using ::google::api::expr::runtime::CelExpressionBuilder; +using ::google::api::expr::runtime::CelMap; +using ::google::api::expr::runtime::CelValue; +using ::google::api::expr::runtime::CreateCelExpressionBuilder; +using ::google::api::expr::runtime::FunctionAdapter; +using ::google::api::expr::runtime::InterpreterOptions; +using ::google::api::expr::runtime::RegisterBuiltinFunctions; +using ::google::rpc::context::AttributeContext; + +// Handle the parametric type overload with a single generic CelValue overload. +absl::StatusOr ContainsExtensionFunction(google::protobuf::Arena* arena, + const CelMap* map, + CelValue::StringHolder key, + const CelValue& value) { + absl::optional entry = (*map)[CelValue::CreateString(key)]; + if (!entry.has_value()) { + return false; + } + if (value.IsInt64() && entry->IsInt64()) { + return value.Int64OrDie() == entry->Int64OrDie(); + } else if (value.IsString() && entry->IsString()) { + return value.StringOrDie().value() == entry->StringOrDie().value(); + } + return false; +} + +absl::StatusOr> MakeConfiguredCompiler() { + // Setup for handling for protobuf types. + // Using the generated descriptor pool is simpler to configure, but often + // adds more types than necessary. + google::protobuf::LinkMessageReflection(); + CEL_ASSIGN_OR_RETURN( + std::unique_ptr builder, + cel::NewCompilerBuilder(google::protobuf::DescriptorPool::generated_pool())); + CEL_RETURN_IF_ERROR(builder->AddLibrary(cel::StandardCompilerLibrary())); + // Adds fields of AttributeContext as variables. + CEL_RETURN_IF_ERROR(builder->GetCheckerBuilder().AddContextDeclaration( + AttributeContext::descriptor()->full_name())); + + // Codelab part 1: + // Add a declaration for the map.contains(string, V) function. + auto& checker_builder = builder->GetCheckerBuilder(); + // Note: we use MakeMemberOverloadDecl instead of MakeOverloadDecl + // because the function is receiver style, meaning that it is called as + // e1.f(e2) instead of f(e1, e2). + CEL_ASSIGN_OR_RETURN( + cel::FunctionDecl decl, + cel::MakeFunctionDecl( + "contains", + cel::MakeMemberOverloadDecl( + "map_contains_string_string", cel::BoolType(), + cel::MapType(checker_builder.arena(), cel::StringType(), + cel::TypeParamType("V")), + cel::StringType(), cel::TypeParamType("V")))); + // Note: we use MergeFunction instead of AddFunction because we are adding + // an overload to an already declared function with the same name. + CEL_RETURN_IF_ERROR(checker_builder.MergeFunction(decl)); + return builder->Build(); +} + +class Evaluator { + public: + Evaluator() { + builder_ = CreateCelExpressionBuilder( + google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), options_); + } + + absl::Status SetupEvaluatorEnvironment() { + CEL_RETURN_IF_ERROR(RegisterBuiltinFunctions(builder_->GetRegistry())); + // Codelab part 2: + // Register the map.contains(string, string) function. + // Hint: use `FunctionAdapter::CreateAndRegister` to adapt from a free + // function ContainsExtensionFunction. + using AdapterT = FunctionAdapter, const CelMap*, + CelValue::StringHolder, CelValue>; + CEL_RETURN_IF_ERROR(AdapterT::CreateAndRegister( + "contains", /*receiver_style=*/true, &ContainsExtensionFunction, + builder_->GetRegistry())); + return absl::OkStatus(); + } + + absl::StatusOr Evaluate(const CheckedExpr& expr, + const AttributeContext& context) { + Activation activation; + CEL_RETURN_IF_ERROR(BindProtoToActivation(&context, &arena_, &activation)); + CEL_ASSIGN_OR_RETURN(auto plan, builder_->CreateExpression(&expr)); + CEL_ASSIGN_OR_RETURN(CelValue result, plan->Evaluate(activation, &arena_)); + + if (bool value; result.GetValue(&value)) { + return value; + } else if (const CelError* value; result.GetValue(&value)) { + return *value; + } else { + return absl::InvalidArgumentError( + absl::StrCat("unexpected return type: ", result.DebugString())); + } + } + + private: + google::protobuf::Arena arena_; + std::unique_ptr builder_; + InterpreterOptions options_; +}; + +} // namespace + +absl::StatusOr EvaluateWithExtensionFunction( + absl::string_view expr, const AttributeContext& context) { + // Prepare a checked expression. + CEL_ASSIGN_OR_RETURN(std::unique_ptr compiler, + MakeConfiguredCompiler()); + CEL_ASSIGN_OR_RETURN(auto checked_expr, + CompileToCheckedExpr(*compiler, expr)); + + // Prepare an evaluation environment. + Evaluator evaluator; + CEL_RETURN_IF_ERROR(evaluator.SetupEvaluatorEnvironment()); + + // Evaluate a checked expression against a particular activation + return evaluator.Evaluate(checked_expr, context); +} + +} // namespace cel_codelab diff --git a/common/BUILD b/common/BUILD index 1a5342a5d..0426c0827 100644 --- a/common/BUILD +++ b/common/BUILD @@ -1,233 +1,1248 @@ -# Description -# Common cel libraries +# Copyright 2021 Google LLC # -# Uses the namespace google::api::expr::common +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") package(default_visibility = ["//visibility:public"]) -licenses(["notice"]) # Apache 2.0 +licenses(["notice"]) cc_library( - name = "macros", - hdrs = [ - "macros.h", + name = "ast", + srcs = ["ast.cc"], + hdrs = ["ast.h"], + deps = [ + ":expr", + ":source", + "//common/ast:metadata", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings:string_view", + ], +) + +cc_test( + name = "ast_test", + srcs = ["ast_test.cc"], + deps = [ + ":ast", + ":expr", + ":source", + "//internal:testing", + "@com_google_absl//absl/container:flat_hash_map", ], ) cc_library( - name = "id", - hdrs = [ - "id.h", + name = "type_spec_resolver", + srcs = ["type_spec_resolver.cc"], + hdrs = ["type_spec_resolver.h"], + deps = [ + ":ast", + ":type", + ":type_kind", + "//internal:status_macros", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf", ], +) + +cc_test( + name = "type_spec_resolver_test", + srcs = ["type_spec_resolver_test.cc"], deps = [ - "//internal:cel_printer", - "//internal:handle", - "//internal:hash_util", + ":ast", + ":type", + ":type_kind", + ":type_spec_resolver", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_protobuf//:protobuf", ], ) cc_library( - name = "type", - srcs = ["type.cc"], - hdrs = [ - "type.h", - ], + name = "signature", + srcs = ["signature.cc"], + hdrs = ["signature.h"], deps = [ - "//internal:handle", - "@com_google_absl//absl/memory", + ":ast", + ":type", + ":type_spec_resolver", + "//internal:status_macros", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:variant", + "@com_google_absl//absl/types:optional", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "signature_test", + srcs = ["signature_test.cc"], + deps = [ + ":ast", + ":signature", + ":type", + ":type_kind", + ":type_spec_resolver", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", "@com_google_protobuf//:protobuf", ], ) cc_library( - name = "enum", - srcs = ["enum.cc"], - hdrs = [ - "enum.h", + name = "expr", + srcs = ["expr.cc"], + hdrs = ["expr.h"], + deps = [ + ":constant", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/functional:overload", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@com_google_absl//absl/types:variant", ], +) + +cc_test( + name = "expr_test", + srcs = ["expr_test.cc"], deps = [ + ":expr", + "//internal:testing", + ], +) + +cc_library( + name = "navigable_ast", + srcs = ["navigable_ast.cc"], + hdrs = ["navigable_ast.h"], + deps = [ + ":ast_traverse", + ":ast_visitor", + ":ast_visitor_base", + ":expr", + "//common/ast:navigable_ast_internal", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:optional", + ], +) + +cc_test( + name = "navigable_ast_test", + srcs = ["navigable_ast_test.cc"], + deps = [ + ":ast", + ":expr", + ":navigable_ast", + ":source", + ":standard_definitions", + "//internal:status_macros", + "//internal:testing", + "//parser", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + ], +) + +cc_library( + name = "decl", + srcs = ["decl.cc"], + hdrs = ["decl.h"], + deps = [ + ":constant", + ":signature", + ":type", + ":type_kind", + "//internal:status_macros", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + ], +) + +cc_test( + name = "decl_test", + srcs = ["decl_test.cc"], + deps = [ + ":constant", + ":decl", ":type", - "//internal:cel_printer", - "//internal:ref_countable", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/log:die_if_null", + "@com_google_absl//absl/status", "@com_google_protobuf//:protobuf", ], ) cc_library( - name = "error", - srcs = ["error.cc"], - hdrs = [ - "error.h", + name = "reference", + srcs = ["reference.cc"], + hdrs = ["reference.h"], + deps = [ + ":constant", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:variant", ], +) + +cc_test( + name = "reference_test", + srcs = ["reference_test.cc"], deps = [ - "//internal:cel_printer", - "//internal:hash_util", - "//internal:proto_util", - "@com_google_absl//absl/container:node_hash_set", + ":constant", + ":reference", + "//internal:testing", + ], +) + +cc_library( + name = "ast_rewrite", + srcs = ["ast_rewrite.cc"], + hdrs = ["ast_rewrite.h"], + deps = [ + ":ast_visitor", + ":constant", + ":expr", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/types:span", - "@com_google_googleapis//google/rpc:code_cc_proto", - "@com_google_googleapis//google/rpc:status_cc_proto", + "@com_google_absl//absl/types:variant", + ], +) + +cc_test( + name = "ast_rewrite_test", + srcs = ["ast_rewrite_test.cc"], + deps = [ + ":ast", + ":ast_rewrite", + ":ast_visitor", + ":expr", + "//common/ast:expr_proto", + "//extensions/protobuf:ast_converters", + "//internal:testing", + "//parser", + "@com_google_absl//absl/status:status_matchers", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_protobuf//:protobuf", ], ) cc_library( - name = "unknown", - srcs = ["unknown.cc"], - hdrs = [ - "unknown.h", + name = "ast_traverse", + srcs = ["ast_traverse.cc"], + hdrs = ["ast_traverse.h"], + deps = [ + ":ast_visitor", + ":constant", + ":expr", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log:absl_log", + "@com_google_absl//absl/types:variant", + ], +) + +cc_test( + name = "ast_traverse_test", + srcs = ["ast_traverse_test.cc"], + deps = [ + ":ast_traverse", + ":ast_visitor", + ":constant", + ":expr", + "//internal:testing", + ], +) + +cc_library( + name = "ast_visitor", + hdrs = ["ast_visitor.h"], + deps = [ + ":constant", + ":expr", + ], +) + +cc_library( + name = "ast_visitor_base", + hdrs = ["ast_visitor_base.h"], + deps = [ + ":ast_visitor", + ":constant", + ":expr", + ], +) + +cc_library( + name = "constant", + srcs = ["constant.cc"], + hdrs = ["constant.h"], + deps = [ + "//internal:strings", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/functional:overload", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:variant", + ], +) + +cc_test( + name = "constant_test", + srcs = ["constant_test.cc"], + deps = [ + ":constant", + "//internal:testing", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/time", ], +) + +cc_library( + name = "expr_factory", + hdrs = ["expr_factory.h"], deps = [ - ":id", - ":macros", + ":constant", + ":expr", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", ], ) cc_library( - name = "parent_ref", + name = "operators", srcs = [ - "parent_ref.cc", + "operators.cc", ], hdrs = [ - "parent_ref.h", + "operators.h", ], deps = [ - "//internal:ref_countable", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], ) cc_library( - name = "custom_object", - srcs = [ - "custom_object.cc", + name = "any", + srcs = ["any.cc"], + hdrs = ["any.h"], + deps = [ + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + "@com_google_protobuf//:any_cc_proto", ], - hdrs = [ - "custom_object.h", +) + +cc_test( + name = "any_test", + srcs = ["any_test.cc"], + deps = [ + ":any", + "//internal:testing", + "@com_google_absl//absl/strings:cord", + "@com_google_absl//absl/strings:string_view", + "@com_google_protobuf//:any_cc_proto", ], +) + +cc_library( + name = "casting", + hdrs = ["casting.h"], + deps = [ + "//common/internal:casting", + "@com_google_absl//absl/base:core_headers", + ], +) + +cc_library( + name = "json", + hdrs = ["json.h"], +) + +cc_library( + name = "kind", + srcs = ["kind.cc"], + hdrs = ["kind.h"], + deps = [ + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "kind_test", + srcs = ["kind_test.cc"], + deps = [ + ":kind", + ":type_kind", + ":value_kind", + "//internal:testing", + ], +) + +cc_library( + name = "memory", + srcs = ["memory.cc"], + hdrs = ["memory.h"], + deps = [ + ":allocator", + ":arena", + ":data", + ":reference_count", + "//common/internal:metadata", + "//common/internal:reference_count", + "//internal:exceptions", + "//internal:to_address", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/numeric:bits", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "memory_test", + srcs = ["memory_test.cc"], + deps = [ + ":allocator", + ":data", + ":memory", + "//common/internal:reference_count", + "//internal:testing", + "@com_google_absl//absl/base:nullability", + "@com_google_protobuf//:protobuf", + "@com_google_protobuf//:struct_cc_proto", + ], +) + +cc_library( + name = "memory_testing", + testonly = True, + hdrs = ["memory_testing.h"], + deps = [ + ":memory", + "//internal:testing", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "type_testing", + testonly = True, + hdrs = ["type_testing.h"], +) + +cc_library( + name = "value_testing", + testonly = True, + srcs = ["value_testing.cc"], + hdrs = ["value_testing.h"], deps = [ ":value", + ":value_kind", + "//internal:equals_text_proto", + "//internal:parse_text_proto", + "//internal:testing_descriptor_pool", + "//internal:testing_message_factory", + "//internal:testing_no_main", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/log:die_if_null", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/time", + "@com_google_protobuf//:protobuf", + "@com_google_protobuf//:struct_cc_proto", + ], +) + +cc_test( + name = "value_testing_test", + srcs = ["value_testing_test.cc"], + deps = [ + ":value", + ":value_testing", + "//internal:testing", + "@com_google_absl//absl/status", + "@com_google_absl//absl/time", ], ) cc_library( - name = "operators", - srcs = [ - "operators.cc", + name = "type_kind", + hdrs = ["type_kind.h"], + deps = [ + ":kind", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/strings", ], - hdrs = [ - "operators.h", +) + +cc_library( + name = "value_kind", + hdrs = ["value_kind.h"], + deps = [ + ":kind", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/strings", ], +) + +cc_library( + name = "source", + srcs = ["source.cc"], + hdrs = ["source.h"], deps = [ + "//internal:unicode", + "//internal:utf8", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/functional:overload", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + "@com_google_absl//absl/types:variant", + ], +) + +cc_test( + name = "source_test", + srcs = ["source_test.cc"], + deps = [ + ":source", + "//internal:testing", + "@com_google_absl//absl/strings:cord", "@com_google_absl//absl/types:optional", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", ], ) cc_library( - name = "escaping", - srcs = [ - "escaping.cc", + name = "native_type", + hdrs = ["native_type.h"], + deps = [ + ":typeinfo", ], - hdrs = [ - "escaping.h", +) + +cc_library( + name = "type", + srcs = glob( + [ + "types/*.cc", + ], + exclude = [ + "types/*_test.cc", + ], + ) + [ + "type.cc", + "type_introspector.cc", + ], + hdrs = glob( + [ + "types/*.h", + ], + exclude = [ + "types/*_test.h", + ], + ) + [ + "type.h", + "type_introspector.h", ], deps = [ + ":type_kind", + "//internal:string_pool", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/functional:overload", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:die_if_null", + "@com_google_absl//absl/meta:type_traits", + "@com_google_absl//absl/numeric:bits", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + "@com_google_absl//absl/types:variant", + "@com_google_absl//absl/utility", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "format_type_name", + srcs = ["format_type_name.cc"], + hdrs = ["format_type_name.h"], + deps = [ + ":type", + ":type_kind", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "type_test", + srcs = glob([ + "types/*_test.cc", + ]) + [ + "type_test.cc", + ], + deps = [ + ":memory", + ":type", + ":type_kind", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/hash:hash_testing", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:die_if_null", + "@com_google_protobuf//:protobuf", ], ) cc_test( - name = "escaping_test", - srcs = ["escaping_test.cc"], + name = "format_type_name_test", + srcs = ["format_type_name_test.cc"], deps = [ - ":escaping", - "@com_google_googletest//:gtest_main", + ":format_type_name", + ":type", + "//internal:testing", + "@com_google_cel_spec//proto/cel/expr/conformance/proto2:test_all_types_cc_proto", + "@com_google_protobuf//:protobuf", ], ) cc_library( name = "value", - srcs = [ + srcs = glob( + [ + "values/*.cc", + ], + exclude = [ + "values/*_test.cc", + ], + ) + [ + "legacy_value.cc", "value.cc", ], - hdrs = [ + hdrs = glob( + [ + "values/*.h", + ], + exclude = [ + "values/*_test.h", + ], + ) + [ + "legacy_value.h", + "type_reflector.h", "value.h", ], deps = [ - ":enum", - ":error", - ":id", - ":macros", - ":parent_ref", + ":allocator", + ":any", + ":arena", + ":casting", + ":kind", + ":memory", + ":native_type", + ":optional_ref", ":type", + ":typeinfo", ":unknown", - "//internal:cel_printer", - "//internal:hash_util", - "//internal:holder", - "//internal:ref_countable", - "//internal:status_util", - "//internal:value_internal", - "//internal:visitor_util", - "@com_google_absl//absl/memory", + ":value_kind", + "//base:attributes", + "//common/internal:byte_string", + "//common/internal:reference_count", + "//eval/internal:cel_value_equal", + "//eval/public:cel_value", + "//eval/public:message_wrapper", + "//eval/public/containers:field_backed_list_impl", + "//eval/public/containers:field_backed_map_impl", + "//eval/public/structs:cel_proto_wrap_util", + "//eval/public/structs:legacy_type_adapter", + "//eval/public/structs:legacy_type_info_apis", + "//eval/public/structs:proto_message_type_adapter", + "//extensions/protobuf/internal:map_reflection", + "//extensions/protobuf/internal:qualify", + "//internal:casts", + "//internal:empty_descriptors", + "//internal:json", + "//internal:manual", + "//internal:message_equality", + "//internal:number", + "//internal:protobuf_runtime_version", + "//internal:status_macros", + "//internal:strings", + "//internal:time", + "//internal:utf8", + "//internal:well_known_types", + "//runtime:runtime_options", + "//runtime/internal:errors", + "@com_google_absl//absl/base", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/functional:overload", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", + "@com_google_absl//absl/meta:type_traits", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/time", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", "@com_google_absl//absl/types:variant", "@com_google_absl//absl/utility", - "@com_google_googleapis//google/rpc:status_cc_proto", + "@com_google_protobuf//:any_cc_proto", + "@com_google_protobuf//:duration_cc_proto", + "@com_google_protobuf//:empty_cc_proto", "@com_google_protobuf//:protobuf", + "@com_google_protobuf//:struct_cc_proto", + "@com_google_protobuf//:timestamp_cc_proto", + "@com_google_protobuf//:wrappers_cc_proto", + "@com_google_protobuf//src/google/protobuf/io", ], ) cc_test( name = "value_test", - srcs = ["value_test.cc"], + srcs = glob([ + "values/*_test.cc", + ]) + [ + "type_reflector_test.cc", + "value_test.cc", + ], deps = [ - ":custom_object", + ":casting", + ":memory", + ":native_type", + ":type", ":value", - "//internal:status_util", - "//internal:types", - "//internal:value_internal", - "//testutil:util", + ":value_kind", + ":value_testing", + "//base:attributes", + "//internal:parse_text_proto", + "//internal:status_macros", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//internal:testing_message_factory", + "//runtime:runtime_options", + "//runtime/internal:errors", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/log:die_if_null", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + "@com_google_absl//absl/strings:cord_test_helpers", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:optional", + "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto", + "@com_google_protobuf//:protobuf", + "@com_google_protobuf//:struct_cc_proto", + "@com_google_protobuf//:type_cc_proto", + "@com_google_protobuf//src/google/protobuf/io", + ], +) + +cc_library( + name = "unknown", + hdrs = ["unknown.h"], + deps = ["//base/internal:unknown_set"], +) + +alias( + name = "legacy_value", + actual = ":value", +) + +cc_library( + name = "arena", + hdrs = ["arena.h"], + deps = [ + "@com_google_absl//absl/base:nullability", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "reference_count", + hdrs = ["reference_count.h"], + deps = ["//common/internal:reference_count"], +) + +cc_library( + name = "allocator", + hdrs = ["allocator.h"], + deps = [ + ":arena", + ":data", + "//internal:new", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:die_if_null", + "@com_google_absl//absl/numeric:bits", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "allocator_test", + srcs = ["allocator_test.cc"], + deps = [ + ":allocator", + "//internal:testing", + "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "data", + hdrs = ["data.h"], + deps = [ + "//common/internal:metadata", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/log:absl_check", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "data_test", + srcs = ["data_test.cc"], + deps = [ + ":data", + "//common/internal:reference_count", + "//internal:testing", + "@com_google_absl//absl/base:nullability", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "optional_ref", + hdrs = ["optional_ref.h"], + deps = [ + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/utility", + ], +) + +cc_library( + name = "arena_string", + hdrs = [ + "arena_string.h", + "arena_string_view.h", + ], + deps = [ + "@com_google_absl//absl/base", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/strings:string_view", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "arena_string_test", + srcs = [ + "arena_string_test.cc", + "arena_string_view_test.cc", + ], + tags = ["no_test_msvc"], + deps = [ + ":arena_string", + "//internal:testing", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/hash:hash_testing", + "@com_google_absl//absl/strings:string_view", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "arena_string_pool", + hdrs = ["arena_string_pool.h"], + deps = [ + ":arena_string", + "//internal:string_pool", + "@com_google_absl//absl/base", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/strings:cord", + "@com_google_absl//absl/strings:string_view", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "arena_string_pool_test", + srcs = ["arena_string_pool_test.cc"], + tags = ["no_test_msvc"], + deps = [ + ":arena_string_pool", + "//internal:testing", + "@com_google_absl//absl/strings:cord_test_helpers", + "@com_google_absl//absl/strings:string_view", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "minimal_descriptor_pool", + srcs = ["minimal_descriptor_pool.cc"], + hdrs = ["minimal_descriptor_pool.h"], + deps = [ + "//internal:minimal_descriptors", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/status", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "minimal_descriptor_pool_test", + srcs = ["minimal_descriptor_pool_test.cc"], + deps = [ + ":minimal_descriptor_pool", + "//internal:testing", + "@com_google_absl//absl/status:status_matchers", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "minimal_descriptor_database", + srcs = ["minimal_descriptor_database.cc"], + hdrs = ["minimal_descriptor_database.h"], + deps = [ + "//internal:minimal_descriptors", + "@com_google_absl//absl/base:nullability", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "minimal_descriptor_database_test", + srcs = ["minimal_descriptor_database_test.cc"], + deps = [ + ":minimal_descriptor_database", + "//internal:testing", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "function_descriptor", + srcs = [ + "function_descriptor.cc", + ], + hdrs = [ + "function_descriptor.h", + ], + deps = [ + ":kind", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/strings", - "@com_google_googleapis//google/rpc:status_cc_proto", - "@com_google_googleapis//google/type:money_cc_proto", - "@com_google_googletest//:gtest_main", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "decl_proto", + srcs = ["decl_proto.cc"], + hdrs = ["decl_proto.h"], + deps = [ + ":decl", + ":type", + ":type_proto", + "//internal:status_macros", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:variant", + "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "decl_proto_test", + srcs = ["decl_proto_test.cc"], + deps = [ + ":decl", + ":decl_proto", + ":decl_proto_v1alpha1", + "//internal:testing", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:variant", + "@com_google_googleapis//google/api/expr/v1alpha1:checked_cc_proto", "@com_google_protobuf//:protobuf", ], ) cc_library( - name = "converters", + name = "decl_proto_v1alpha1", + srcs = ["decl_proto_v1alpha1.cc"], + hdrs = ["decl_proto_v1alpha1.h"], + deps = [ + ":decl", + ":decl_proto", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:variant", + "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", + "@com_google_googleapis//google/api/expr/v1alpha1:checked_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "type_proto", + srcs = ["type_proto.cc"], + hdrs = ["type_proto.h"], + deps = [ + ":type", + ":type_kind", + "//internal:status_macros", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", + "@com_google_protobuf//:protobuf", + "@com_google_protobuf//:struct_cc_proto", + ], +) + +cc_test( + name = "type_proto_test", + srcs = ["type_proto_test.cc"], + deps = [ + ":type", + ":type_kind", + ":type_proto", + "//internal:proto_matchers", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "ast_proto", + srcs = ["ast_proto.cc"], + hdrs = ["ast_proto.h"], + deps = [ + ":ast", + ":constant", + ":expr", + "//base:ast", + "//common/ast:constant_proto", + "//common/ast:expr_proto", + "//common/ast:source_info_proto", + "//internal:status_macros", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:variant", + "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_protobuf//:duration_cc_proto", + "@com_google_protobuf//:struct_cc_proto", + "@com_google_protobuf//:timestamp_cc_proto", + ], +) + +cc_test( + name = "ast_proto_test", srcs = [ - "converters.cc", + "ast_proto_test.cc", + ], + deps = [ + ":ast", + ":ast_proto", + ":decl", + ":expr", + ":source", + ":type", + "//compiler", + "//compiler:compiler_factory", + "//compiler:optional", + "//compiler:standard_library", + "//extensions:comprehensions_v2", + "//internal:proto_matchers", + "//internal:status_macros", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:variant", + "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_protobuf//:duration_cc_proto", + "@com_google_protobuf//:protobuf", + "@com_google_protobuf//:struct_cc_proto", + "@com_google_protobuf//:timestamp_cc_proto", ], +) + +cc_library( + name = "standard_definitions", hdrs = [ - "converters.h", + "standard_definitions.h", ], deps = [ - ":parent_ref", - ":value", - "//internal:list_impl", - "//internal:map_impl", - "//internal:types", + "@com_google_absl//absl/strings:string_view", + ], +) + +cc_library( + name = "typeinfo", + srcs = ["typeinfo.cc"], + hdrs = ["typeinfo.h"], + deps = [ + "@com_google_absl//absl/base", + "@com_google_absl//absl/base:config", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/meta:type_traits", + "@com_google_absl//absl/strings", ], ) cc_test( - name = "converters_test", - srcs = ["converters_test.cc"], + name = "typeinfo_test", + srcs = ["typeinfo_test.cc"], deps = [ - ":converters", - ":value", - "@com_google_absl//absl/memory", - "@com_google_googletest//:gtest_main", + ":typeinfo", + "//internal:testing", + "@com_google_absl//absl/hash:hash_testing", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "container", + srcs = ["container.cc"], + hdrs = ["container.h"], + deps = [ + "//internal:lexis", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "container_test", + srcs = ["container_test.cc"], + deps = [ + ":container", + "//internal:testing", + "@com_google_absl//absl/status", ], ) diff --git a/common/allocator.h b/common/allocator.h new file mode 100644 index 000000000..81d56b096 --- /dev/null +++ b/common/allocator.h @@ -0,0 +1,606 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_ALLOCATOR_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_ALLOCATOR_H_ + +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/base/optimization.h" +#include "absl/log/absl_check.h" +#include "absl/log/die_if_null.h" +#include "absl/numeric/bits.h" +#include "common/arena.h" +#include "common/data.h" +#include "internal/new.h" +#include "google/protobuf/arena.h" + +namespace cel { + +enum class AllocatorKind { + kArena = 1, + kNewDelete = 2, +}; + +template +void AbslStringify(S& sink, AllocatorKind kind) { + switch (kind) { + case AllocatorKind::kArena: + sink.Append("ARENA"); + return; + case AllocatorKind::kNewDelete: + sink.Append("NEW_DELETE"); + return; + default: + sink.Append("ERROR"); + return; + } +} + +template +class NewDeleteAllocator; +template +class ArenaAllocator; +template +class Allocator; + +// `NewDeleteAllocator<>` is a type-erased vocabulary type capable of performing +// allocation/deallocation and construction/destruction using memory owned by +// `operator new`. +template <> +class NewDeleteAllocator { + public: + using size_type = size_t; + using difference_type = ptrdiff_t; + using propagate_on_container_copy_assignment = std::true_type; + using propagate_on_container_move_assignment = std::true_type; + using propagate_on_container_swap = std::true_type; + using is_always_equal = std::true_type; + + NewDeleteAllocator() = default; + NewDeleteAllocator(const NewDeleteAllocator&) = default; + NewDeleteAllocator& operator=(const NewDeleteAllocator&) = default; + + template >> + // NOLINTNEXTLINE(google-explicit-constructor) + constexpr NewDeleteAllocator( + [[maybe_unused]] const NewDeleteAllocator& other) noexcept {} + + // Allocates at least `nbytes` bytes with a minimum alignment of `alignment` + // from the underlying memory resource. When the underlying memory resource is + // `operator new`, `deallocate_bytes` must be called at some point, otherwise + // calling `deallocate_bytes` is optional. The caller must not pass an object + // constructed in the return memory to `delete_object`, doing so is undefined + // behavior. + ABSL_MUST_USE_RESULT void* allocate_bytes( + size_type nbytes, size_type alignment = alignof(std::max_align_t)) { + ABSL_DCHECK(absl::has_single_bit(alignment)); + if (nbytes == 0) { + return nullptr; + } + return internal::AlignedNew(nbytes, + static_cast(alignment)); + } + + // Deallocates memory previously returned by `allocate_bytes`. + void deallocate_bytes( + void* p, size_type nbytes, + size_type alignment = alignof(std::max_align_t)) noexcept { + ABSL_DCHECK((p == nullptr && nbytes == 0) || (p != nullptr && nbytes != 0)); + ABSL_DCHECK(absl::has_single_bit(alignment)); + internal::SizedAlignedDelete(p, nbytes, + static_cast(alignment)); + } + + template + ABSL_MUST_USE_RESULT T* allocate_object(size_type n = 1) { + return static_cast(allocate_bytes(sizeof(T) * n, alignof(T))); + } + + template + void deallocate_object(T* p, size_type n = 1) { + deallocate_bytes(p, sizeof(T) * n, alignof(T)); + } + + // Allocates memory suitable for an object of type `T` and constructs the + // object by forwarding the provided arguments. If the underlying memory + // resource is `operator new` is false, `delete_object` must eventually be + // called. + template + ABSL_MUST_USE_RESULT T* new_object(Args&&... args) { + return new T(std::forward(args)...); + } + + // Destructs the object of type `T` located at address `p` and deallocates the + // memory, `p` must have been previously returned by `new_object`. + template + void delete_object(T* p) noexcept { + ABSL_DCHECK(p != nullptr); + delete p; + } + + void delete_object(std::nullptr_t) = delete; + + private: + template + friend class NewDeleteAllocator; +}; + +// `NewDeleteAllocator` is an extension of `NewDeleteAllocator<>` which +// adheres to the named C++ requirements for `Allocator`, allowing it to be used +// in places which accept custom STL allocators. +template +class NewDeleteAllocator : public NewDeleteAllocator { + public: + static_assert(!std::is_const_v, "T must not be const qualified"); + static_assert(!std::is_volatile_v, "T must not be volatile qualified"); + static_assert(std::is_object_v, "T must be an object type"); + + using value_type = T; + using pointer = value_type*; + using const_pointer = const value_type*; + using reference = value_type&; + using const_reference = const value_type&; + + using NewDeleteAllocator::NewDeleteAllocator; + + template >> + // NOLINTNEXTLINE(google-explicit-constructor) + constexpr NewDeleteAllocator( + [[maybe_unused]] const NewDeleteAllocator& other) noexcept {} + + pointer allocate(size_type n, const void* /*hint*/ = nullptr) { + return reinterpret_cast(internal::AlignedNew( + n * sizeof(T), static_cast(alignof(T)))); + } + +#if defined(__cpp_lib_allocate_at_least) && \ + __cpp_lib_allocate_at_least >= 202302L + std::allocation_result allocate_at_least(size_type n) { + void* addr; + size_type size; + std::tie(addr, size) = internal::SizeReturningAlignedNew( + n * sizeof(T), static_cast(alignof(T))); + std::allocation_result result; + result.ptr = reinterpret_cast(addr); + result.count = size / sizeof(T); + return result; + } +#endif + + void deallocate(pointer p, size_type n) noexcept { + internal::SizedAlignedDelete(p, n * sizeof(T), + static_cast(alignof(T))); + } + + template + void construct(U* p, Args&&... args) { + ::new (static_cast(p)) U(std::forward(args)...); + } + + template + void destroy(U* p) noexcept { + std::destroy_at(p); + } +}; + +template +inline bool operator==(NewDeleteAllocator, NewDeleteAllocator) noexcept { + return true; +} + +template +inline bool operator!=(NewDeleteAllocator lhs, + NewDeleteAllocator rhs) noexcept { + return !operator==(lhs, rhs); +} + +NewDeleteAllocator() -> NewDeleteAllocator; +template +NewDeleteAllocator(const NewDeleteAllocator&) -> NewDeleteAllocator; + +// `ArenaAllocator<>` is a type-erased vocabulary type capable of performing +// allocation/deallocation and construction/destruction using memory owned by +// `google::protobuf::Arena`. +template <> +class ArenaAllocator { + public: + using size_type = size_t; + using difference_type = ptrdiff_t; + using propagate_on_container_copy_assignment = std::true_type; + using propagate_on_container_move_assignment = std::true_type; + using propagate_on_container_swap = std::true_type; + + ArenaAllocator() = delete; + + ArenaAllocator(const ArenaAllocator&) = default; + ArenaAllocator& operator=(const ArenaAllocator&) = delete; + + ArenaAllocator(std::nullptr_t) = delete; + + template >> + // NOLINTNEXTLINE(google-explicit-constructor) + constexpr ArenaAllocator(const ArenaAllocator& other) noexcept + : arena_(other.arena()) {} + + // NOLINTNEXTLINE(google-explicit-constructor) + ArenaAllocator(google::protobuf::Arena* absl_nonnull arena) noexcept + : arena_(ABSL_DIE_IF_NULL(arena)) // Crash OK + {} + + constexpr google::protobuf::Arena* absl_nonnull arena() const noexcept { + ABSL_ASSUME(arena_ != nullptr); + return arena_; + } + + // Allocates at least `nbytes` bytes with a minimum alignment of `alignment` + // from the underlying memory resource. When the underlying memory resource is + // `operator new`, `deallocate_bytes` must be called at some point, otherwise + // calling `deallocate_bytes` is optional. The caller must not pass an object + // constructed in the return memory to `delete_object`, doing so is undefined + // behavior. + ABSL_MUST_USE_RESULT void* allocate_bytes( + size_type nbytes, size_type alignment = alignof(std::max_align_t)) { + ABSL_DCHECK(absl::has_single_bit(alignment)); + if (nbytes == 0) { + return nullptr; + } + return arena()->AllocateAligned(nbytes, alignment); + } + + // Deallocates memory previously returned by `allocate_bytes`. + void deallocate_bytes( + void* p, size_type nbytes, + size_type alignment = alignof(std::max_align_t)) noexcept { + ABSL_DCHECK((p == nullptr && nbytes == 0) || (p != nullptr && nbytes != 0)); + ABSL_DCHECK(absl::has_single_bit(alignment)); + } + + template + ABSL_MUST_USE_RESULT T* allocate_object(size_type n = 1) { + return static_cast(allocate_bytes(sizeof(T) * n, alignof(T))); + } + + template + void deallocate_object(T* p, size_type n = 1) { + deallocate_bytes(p, sizeof(T) * n, alignof(T)); + } + + // Allocates memory suitable for an object of type `T` and constructs the + // object by forwarding the provided arguments. If the underlying memory + // resource is `operator new` is false, `delete_object` must eventually be + // called. + template + ABSL_MUST_USE_RESULT T* new_object(Args&&... args) { + using U = std::remove_const_t; + U* object; + if constexpr (google::protobuf::Arena::is_arena_constructable::value) { + // Classes derived from `cel::Data` are manually allocated and constructed + // as those class support determining whether the destructor is skippable + // at runtime. + object = google::protobuf::Arena::Create(arena(), std::forward(args)...); + } else { + if constexpr (ArenaTraits<>::constructible()) { + object = ::new (static_cast(arena()->AllocateAligned( + sizeof(U), alignof(U)))) U(arena(), std::forward(args)...); + } else { + object = ::new (static_cast(arena()->AllocateAligned( + sizeof(U), alignof(U)))) U(std::forward(args)...); + } + if constexpr (!ArenaTraits<>::always_trivially_destructible()) { + if (!ArenaTraits<>::trivially_destructible(*object)) { + arena()->OwnDestructor(object); + } + } + } + if constexpr (google::protobuf::Arena::is_arena_constructable::value || + std::is_base_of_v) { + ABSL_DCHECK_EQ(object->GetArena(), arena()); + } + return object; + } + + // Destructs the object of type `T` located at address `p` and deallocates the + // memory, `p` must have been previously returned by `new_object`. + template + void delete_object(T* p) noexcept { + using U = std::remove_const_t; + ABSL_DCHECK(p != nullptr); + if constexpr (google::protobuf::Arena::is_arena_constructable::value || + std::is_base_of_v) { + ABSL_DCHECK_EQ(p->GetArena(), arena()); + } + } + + void delete_object(std::nullptr_t) = delete; + + private: + template + friend class ArenaAllocator; + + google::protobuf::Arena* absl_nonnull arena_; +}; + +// `ArenaAllocator` is an extension of `ArenaAllocator<>` which adheres to +// the named C++ requirements for `Allocator`, allowing it to be used in places +// which accept custom STL allocators. +template +class ArenaAllocator : public ArenaAllocator { + private: + using Base = ArenaAllocator; + + public: + static_assert(!std::is_const_v, "T must not be const qualified"); + static_assert(!std::is_volatile_v, "T must not be volatile qualified"); + static_assert(std::is_object_v, "T must be an object type"); + + using value_type = T; + using pointer = value_type*; + using const_pointer = const value_type*; + using reference = value_type&; + using const_reference = const value_type&; + + using ArenaAllocator::ArenaAllocator; + + template >> + // NOLINTNEXTLINE(google-explicit-constructor) + constexpr ArenaAllocator(const ArenaAllocator& other) noexcept + : Base(other) {} + + pointer allocate(size_type n, const void* /*hint*/ = nullptr) { + return static_cast( + arena()->AllocateAligned(n * sizeof(T), alignof(T))); + } + +#if defined(__cpp_lib_allocate_at_least) && \ + __cpp_lib_allocate_at_least >= 202302L + std::allocation_result allocate_at_least(size_type n) { + std::allocation_result result; + result.ptr = allocate(n); + result.count = n; + return result; + } +#endif + + void deallocate(pointer, size_type) noexcept {} + + template + void construct(U* p, Args&&... args) { + static_assert(!google::protobuf::Arena::is_arena_constructable::value); + ::new (static_cast(p)) U(std::forward(args)...); + } + + template + void destroy(U* p) noexcept { + static_assert(!google::protobuf::Arena::is_arena_constructable::value); + std::destroy_at(p); + } +}; + +template +inline bool operator==(ArenaAllocator lhs, ArenaAllocator rhs) noexcept { + return lhs.arena() == rhs.arena(); +} + +template +inline bool operator!=(ArenaAllocator lhs, ArenaAllocator rhs) noexcept { + return !operator==(lhs, rhs); +} + +ArenaAllocator(google::protobuf::Arena* absl_nonnull) -> ArenaAllocator; +template +ArenaAllocator(const ArenaAllocator&) -> ArenaAllocator; + +// `Allocator<>` is a type-erased vocabulary type capable of performing +// allocation/deallocation and construction/destruction using memory owned by +// `google::protobuf::Arena` or `operator new`. +template <> +class Allocator { + public: + using size_type = size_t; + using difference_type = ptrdiff_t; + using propagate_on_container_copy_assignment = std::true_type; + using propagate_on_container_move_assignment = std::true_type; + using propagate_on_container_swap = std::true_type; + + Allocator() = delete; + + Allocator(const Allocator&) = default; + Allocator& operator=(const Allocator&) = delete; + + Allocator(std::nullptr_t) = delete; + + template >> + // NOLINTNEXTLINE(google-explicit-constructor) + constexpr Allocator(const Allocator& other) noexcept + : arena_(other.arena_) {} + + // NOLINTNEXTLINE(google-explicit-constructor) + constexpr Allocator(google::protobuf::Arena* absl_nullable arena) noexcept + : arena_(arena) {} + + template + // NOLINTNEXTLINE(google-explicit-constructor) + constexpr Allocator( + [[maybe_unused]] const NewDeleteAllocator& other) noexcept + : arena_(nullptr) {} + + template + // NOLINTNEXTLINE(google-explicit-constructor) + constexpr Allocator(const ArenaAllocator& other) noexcept + : arena_(other.arena()) {} + + constexpr google::protobuf::Arena* absl_nullable arena() const noexcept { + return arena_; + } + + // Allocates at least `nbytes` bytes with a minimum alignment of `alignment` + // from the underlying memory resource. When the underlying memory resource is + // `operator new`, `deallocate_bytes` must be called at some point, otherwise + // calling `deallocate_bytes` is optional. The caller must not pass an object + // constructed in the return memory to `delete_object`, doing so is undefined + // behavior. + ABSL_MUST_USE_RESULT void* allocate_bytes( + size_type nbytes, size_type alignment = alignof(std::max_align_t)) { + return arena() != nullptr + ? ArenaAllocator(arena()).allocate_bytes(nbytes, alignment) + : NewDeleteAllocator().allocate_bytes(nbytes, alignment); + } + + // Deallocates memory previously returned by `allocate_bytes`. + void deallocate_bytes( + void* p, size_type nbytes, + size_type alignment = alignof(std::max_align_t)) noexcept { + arena() != nullptr + ? ArenaAllocator(arena()).deallocate_bytes(p, nbytes, alignment) + : NewDeleteAllocator().deallocate_bytes(p, nbytes, alignment); + } + + template + ABSL_MUST_USE_RESULT T* allocate_object(size_type n = 1) { + return arena() != nullptr + ? ArenaAllocator(arena()).allocate_object(n) + : NewDeleteAllocator().allocate_object(n); + } + + template + void deallocate_object(T* p, size_type n = 1) { + arena() != nullptr ? ArenaAllocator(arena()).deallocate_object(p, n) + : NewDeleteAllocator().deallocate_object(p, n); + } + + // Allocates memory suitable for an object of type `T` and constructs the + // object by forwarding the provided arguments. If the underlying memory + // resource is `operator new` is false, `delete_object` must eventually be + // called. + template + ABSL_MUST_USE_RESULT T* new_object(Args&&... args) { + return arena() != nullptr ? ArenaAllocator(arena()).new_object( + std::forward(args)...) + : NewDeleteAllocator().new_object( + std::forward(args)...); + } + + // Destructs the object of type `T` located at address `p` and deallocates the + // memory, `p` must have been previously returned by `new_object`. + template + void delete_object(T* p) noexcept { + arena() != nullptr ? ArenaAllocator(arena()).delete_object(p) + : NewDeleteAllocator().delete_object(p); + } + + void delete_object(std::nullptr_t) = delete; + + private: + template + friend class Allocator; + + google::protobuf::Arena* absl_nullable arena_; +}; + +// `Allocator` is an extension of `Allocator<>` which adheres to the named +// C++ requirements for `Allocator`, allowing it to be used in places which +// accept custom STL allocators. +template +class Allocator : public Allocator { + public: + static_assert(!std::is_const_v, "T must not be const qualified"); + static_assert(!std::is_volatile_v, "T must not be volatile qualified"); + static_assert(std::is_object_v, "T must be an object type"); + + using value_type = T; + using pointer = value_type*; + using const_pointer = const value_type*; + using reference = value_type&; + using const_reference = const value_type&; + + using Allocator::Allocator; + + template >> + // NOLINTNEXTLINE(google-explicit-constructor) + constexpr Allocator(const Allocator& other) noexcept + : Allocator(other.arena_) {} + + pointer allocate(size_type n, const void* /*hint*/ = nullptr) { + return arena() != nullptr ? ArenaAllocator(arena()).allocate(n) + : NewDeleteAllocator().allocate(n); + } + +#if defined(__cpp_lib_allocate_at_least) && \ + __cpp_lib_allocate_at_least >= 202302L + std::allocation_result allocate_at_least(size_type n) { + return arena() != nullptr ? ArenaAllocator(arena()).allocate_at_least(n) + : NewDeleteAllocator().allocate_at_least(n); + } +#endif + + void deallocate(pointer p, size_type n) noexcept { + arena() != nullptr ? ArenaAllocator(arena()).deallocate(p, n) + : NewDeleteAllocator().deallocate(p, n); + } + + template + void construct(U* p, Args&&... args) { + arena() != nullptr + ? ArenaAllocator(arena()).construct(p, std::forward(args)...) + : NewDeleteAllocator().construct(p, std::forward(args)...); + } + + template + void destroy(U* p) noexcept { + arena() != nullptr ? ArenaAllocator(arena()).destroy(p) + : NewDeleteAllocator().destroy(p); + } +}; + +template +inline bool operator==(Allocator lhs, Allocator rhs) noexcept { + return lhs.arena() == rhs.arena(); +} + +template +inline bool operator!=(Allocator lhs, Allocator rhs) noexcept { + return !operator==(lhs, rhs); +} + +Allocator(google::protobuf::Arena* absl_nullable) -> Allocator; +template +Allocator(const Allocator&) -> Allocator; +template +Allocator(const NewDeleteAllocator&) -> Allocator; +template +Allocator(const ArenaAllocator&) -> Allocator; + +template +inline NewDeleteAllocator NewDeleteAllocatorFor() noexcept { + static_assert(!std::is_void_v); + return NewDeleteAllocator(); +} + +template +inline Allocator ArenaAllocatorFor( + google::protobuf::Arena* absl_nonnull arena) noexcept { + static_assert(!std::is_void_v); + ABSL_DCHECK(arena != nullptr); + return Allocator(arena); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_ALLOCATOR_H_ diff --git a/common/allocator_test.cc b/common/allocator_test.cc new file mode 100644 index 000000000..7fa924bd4 --- /dev/null +++ b/common/allocator_test.cc @@ -0,0 +1,196 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This header contains primitives for reference counting, roughly equivalent to +// the primitives used to implement `std::shared_ptr`. These primitives should +// not be used directly in most cases, instead `cel::ManagedMemory` should be +// used instead. + +#include "common/allocator.h" + +#include + +#include "absl/strings/str_cat.h" +#include "internal/testing.h" +#include "google/protobuf/arena.h" + +namespace cel { +namespace { + +using ::testing::NotNull; + +TEST(AllocatorKind, AbslStringify) { + EXPECT_EQ(absl::StrCat(AllocatorKind::kArena), "ARENA"); + EXPECT_EQ(absl::StrCat(AllocatorKind::kNewDelete), "NEW_DELETE"); + EXPECT_EQ(absl::StrCat(static_cast(0)), "ERROR"); +} + +TEST(NewDeleteAllocator, Bytes) { + auto allocator = NewDeleteAllocator<>(); + void* p = allocator.allocate_bytes(17, 8); + EXPECT_THAT(p, NotNull()); + allocator.deallocate_bytes(p, 17, 8); +} + +TEST(ArenaAllocator, Bytes) { + google::protobuf::Arena arena; + auto allocator = ArenaAllocator<>(&arena); + void* p = allocator.allocate_bytes(17, 8); + EXPECT_THAT(p, NotNull()); + allocator.deallocate_bytes(p, 17, 8); +} + +struct TrivialObject { + char data[17]; +}; + +TEST(NewDeleteAllocator, NewDeleteObject) { + auto allocator = NewDeleteAllocator<>(); + auto* p = allocator.new_object(); + EXPECT_THAT(p, NotNull()); + allocator.delete_object(p); +} + +TEST(ArenaAllocator, NewDeleteObject) { + google::protobuf::Arena arena; + auto allocator = ArenaAllocator<>(&arena); + auto* p = allocator.new_object(); + EXPECT_THAT(p, NotNull()); + allocator.delete_object(p); +} + +TEST(NewDeleteAllocator, Object) { + auto allocator = NewDeleteAllocator<>(); + auto* p = allocator.allocate_object(); + EXPECT_THAT(p, NotNull()); + allocator.deallocate_object(p); +} + +TEST(ArenaAllocator, Object) { + google::protobuf::Arena arena; + auto allocator = ArenaAllocator<>(&arena); + auto* p = allocator.allocate_object(); + EXPECT_THAT(p, NotNull()); + allocator.deallocate_object(p); +} + +TEST(NewDeleteAllocator, ObjectArray) { + auto allocator = NewDeleteAllocator<>(); + auto* p = allocator.allocate_object(2); + EXPECT_THAT(p, NotNull()); + allocator.deallocate_object(p, 2); +} + +TEST(ArenaAllocator, ObjectArray) { + google::protobuf::Arena arena; + auto allocator = ArenaAllocator<>(&arena); + auto* p = allocator.allocate_object(2); + EXPECT_THAT(p, NotNull()); + allocator.deallocate_object(p, 2); +} + +TEST(NewDeleteAllocator, T) { + auto allocator = NewDeleteAllocatorFor(); + auto* p = allocator.allocate(1); + EXPECT_THAT(p, NotNull()); + allocator.construct(p); + allocator.destroy(p); + allocator.deallocate(p, 1); +} + +TEST(ArenaAllocator, T) { + google::protobuf::Arena arena; + auto allocator = ArenaAllocatorFor(&arena); + auto* p = allocator.allocate(1); + EXPECT_THAT(p, NotNull()); + allocator.construct(p); + allocator.destroy(p); + allocator.deallocate(p, 1); +} + +TEST(NewDeleteAllocator, CopyConstructible) { + EXPECT_TRUE( + (std::is_trivially_constructible_v, + const NewDeleteAllocator&>)); + EXPECT_TRUE( + (std::is_trivially_constructible_v, + const NewDeleteAllocator&>)); + EXPECT_TRUE((std::is_constructible_v, + const NewDeleteAllocator&>)); + EXPECT_TRUE((std::is_constructible_v, + const NewDeleteAllocator&>)); + EXPECT_TRUE((std::is_constructible_v, + const NewDeleteAllocator&>)); + EXPECT_TRUE((std::is_constructible_v, + const NewDeleteAllocator&>)); +} + +TEST(ArenaAllocator, CopyConstructible) { + EXPECT_TRUE((std::is_trivially_constructible_v, + const ArenaAllocator&>)); + EXPECT_TRUE((std::is_trivially_constructible_v, + const ArenaAllocator&>)); + EXPECT_TRUE((std::is_constructible_v, + const ArenaAllocator&>)); + EXPECT_TRUE((std::is_constructible_v, + const ArenaAllocator&>)); + EXPECT_TRUE((std::is_constructible_v, + const ArenaAllocator&>)); + EXPECT_TRUE((std::is_constructible_v, + const ArenaAllocator&>)); +} + +TEST(Allocator, CopyConstructible) { + EXPECT_TRUE((std::is_trivially_constructible_v, + const Allocator&>)); + EXPECT_TRUE((std::is_trivially_constructible_v, + const Allocator&>)); + EXPECT_TRUE( + (std::is_constructible_v, const Allocator&>)); + EXPECT_TRUE( + (std::is_constructible_v, const Allocator&>)); + EXPECT_TRUE( + (std::is_constructible_v, const Allocator&>)); + EXPECT_TRUE( + (std::is_constructible_v, const Allocator&>)); + + EXPECT_TRUE((std::is_constructible_v, + const NewDeleteAllocator&>)); + EXPECT_TRUE((std::is_constructible_v, + const NewDeleteAllocator&>)); + EXPECT_TRUE((std::is_constructible_v, + const NewDeleteAllocator&>)); + EXPECT_TRUE((std::is_constructible_v, + const NewDeleteAllocator&>)); + EXPECT_TRUE((std::is_constructible_v, + const NewDeleteAllocator&>)); + EXPECT_TRUE((std::is_constructible_v, + const NewDeleteAllocator&>)); + + EXPECT_TRUE( + (std::is_constructible_v, const ArenaAllocator&>)); + EXPECT_TRUE( + (std::is_constructible_v, const ArenaAllocator&>)); + EXPECT_TRUE( + (std::is_constructible_v, const ArenaAllocator&>)); + EXPECT_TRUE( + (std::is_constructible_v, const ArenaAllocator&>)); + EXPECT_TRUE( + (std::is_constructible_v, const ArenaAllocator&>)); + EXPECT_TRUE( + (std::is_constructible_v, const ArenaAllocator&>)); +} + +} // namespace +} // namespace cel diff --git a/common/any.cc b/common/any.cc new file mode 100644 index 000000000..6ddcc5887 --- /dev/null +++ b/common/any.cc @@ -0,0 +1,38 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/any.h" + +#include "absl/base/nullability.h" +#include "absl/strings/string_view.h" + +namespace cel { + +bool ParseTypeUrl(absl::string_view type_url, + absl::string_view* absl_nullable prefix, + absl::string_view* absl_nullable type_name) { + auto pos = type_url.find_last_of('/'); + if (pos == absl::string_view::npos || pos + 1 == type_url.size()) { + return false; + } + if (prefix) { + *prefix = type_url.substr(0, pos + 1); + } + if (type_name) { + *type_name = type_url.substr(pos + 1); + } + return true; +} + +} // namespace cel diff --git a/common/any.h b/common/any.h new file mode 100644 index 000000000..cf86aa636 --- /dev/null +++ b/common/any.h @@ -0,0 +1,90 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_ANY_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_ANY_H_ + +#include + +#include "google/protobuf/any.pb.h" +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/strings/cord.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/strings/strip.h" + +namespace cel { + +inline google::protobuf::Any MakeAny(absl::string_view type_url, + const absl::Cord& value) { + google::protobuf::Any any; + any.set_type_url(type_url); + any.set_value(static_cast(value)); + return any; +} + +inline google::protobuf::Any MakeAny(absl::string_view type_url, + absl::string_view value) { + google::protobuf::Any any; + any.set_type_url(type_url); + any.set_value(value); + return any; +} + +inline absl::Cord GetAnyValueAsCord(const google::protobuf::Any& any) { + return absl::Cord(any.value()); +} + +inline std::string GetAnyValueAsString(const google::protobuf::Any& any) { + return std::string(any.value()); +} + +inline void SetAnyValueFromCord(google::protobuf::Any* absl_nonnull any, + const absl::Cord& value) { + any->set_value(static_cast(value)); +} + +inline absl::string_view GetAnyValueAsStringView( + const google::protobuf::Any& any ABSL_ATTRIBUTE_LIFETIME_BOUND, + std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) { + return absl::string_view(any.value()); +} + +inline constexpr absl::string_view kTypeGoogleApisComPrefix = + "type.googleapis.com/"; + +inline std::string MakeTypeUrlWithPrefix(absl::string_view prefix, + absl::string_view type_name) { + return absl::StrCat(absl::StripSuffix(prefix, "/"), "/", type_name); +} + +inline std::string MakeTypeUrl(absl::string_view type_name) { + return MakeTypeUrlWithPrefix(kTypeGoogleApisComPrefix, type_name); +} + +bool ParseTypeUrl(absl::string_view type_url, + absl::string_view* absl_nullable prefix, + absl::string_view* absl_nullable type_name); +inline bool ParseTypeUrl(absl::string_view type_url, + absl::string_view* absl_nullable type_name) { + return ParseTypeUrl(type_url, nullptr, type_name); +} +inline bool ParseTypeUrl(absl::string_view type_url) { + return ParseTypeUrl(type_url, nullptr); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_ANY_H_ diff --git a/common/any_test.cc b/common/any_test.cc new file mode 100644 index 000000000..ddf914150 --- /dev/null +++ b/common/any_test.cc @@ -0,0 +1,73 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/any.h" + +#include + +#include "google/protobuf/any.pb.h" +#include "absl/strings/cord.h" +#include "absl/strings/string_view.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +TEST(Any, Value) { + google::protobuf::Any any; + std::string scratch; + SetAnyValueFromCord(&any, absl::Cord("Hello World!")); + EXPECT_EQ(GetAnyValueAsCord(any), "Hello World!"); + EXPECT_EQ(GetAnyValueAsString(any), "Hello World!"); + EXPECT_EQ(GetAnyValueAsStringView(any, scratch), "Hello World!"); +} + +TEST(MakeTypeUrlWithPrefix, Basic) { + EXPECT_EQ(MakeTypeUrlWithPrefix("foo", "bar.Baz"), "foo/bar.Baz"); + EXPECT_EQ(MakeTypeUrlWithPrefix("foo/", "bar.Baz"), "foo/bar.Baz"); +} + +TEST(MakeTypeUrl, Basic) { + EXPECT_EQ(MakeTypeUrl("bar.Baz"), "type.googleapis.com/bar.Baz"); +} + +TEST(ParseTypeUrl, Valid) { + EXPECT_TRUE(ParseTypeUrl("type.googleapis.com/bar.Baz")); + EXPECT_FALSE(ParseTypeUrl("type.googleapis.com")); + EXPECT_FALSE(ParseTypeUrl("type.googleapis.com/")); + EXPECT_FALSE(ParseTypeUrl("type.googleapis.com/foo/")); +} + +TEST(ParseTypeUrl, TypeName) { + absl::string_view type_name; + EXPECT_TRUE(ParseTypeUrl("type.googleapis.com/bar.Baz", &type_name)); + EXPECT_EQ(type_name, "bar.Baz"); + EXPECT_FALSE(ParseTypeUrl("type.googleapis.com", &type_name)); + EXPECT_FALSE(ParseTypeUrl("type.googleapis.com/", &type_name)); + EXPECT_FALSE(ParseTypeUrl("type.googleapis.com/foo/", &type_name)); +} + +TEST(ParseTypeUrl, PrefixAndTypeName) { + absl::string_view prefix; + absl::string_view type_name; + EXPECT_TRUE(ParseTypeUrl("type.googleapis.com/bar.Baz", &prefix, &type_name)); + EXPECT_EQ(prefix, "type.googleapis.com/"); + EXPECT_EQ(type_name, "bar.Baz"); + EXPECT_FALSE(ParseTypeUrl("type.googleapis.com", &prefix, &type_name)); + EXPECT_FALSE(ParseTypeUrl("type.googleapis.com/", &prefix, &type_name)); + EXPECT_FALSE(ParseTypeUrl("type.googleapis.com/foo/", &prefix, &type_name)); +} + +} // namespace +} // namespace cel diff --git a/common/arena.h b/common/arena.h new file mode 100644 index 000000000..fa2c6f67b --- /dev/null +++ b/common/arena.h @@ -0,0 +1,110 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_ARENA_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_ARENA_H_ + +#include +#include + +#include "absl/base/nullability.h" +#include "google/protobuf/arena.h" + +namespace cel { + +template +struct ArenaTraits; + +namespace common_internal { + +template +struct AssertArenaType : std::false_type { + static_assert(!std::is_void_v, "T must not be void"); + static_assert(!std::is_reference_v, "T must not be a reference"); + static_assert(!std::is_volatile_v, "T must not be volatile qualified"); + static_assert(!std::is_const_v, "T must not be const qualified"); + static_assert(!std::is_array_v, "T must not be an array"); +}; + +template +struct ArenaTraitsConstructible { + using type = std::false_type; +}; + +template +struct ArenaTraitsConstructible< + T, std::void_t::constructible)>> { + using type = typename ArenaTraits::constructible; +}; + +template +std::enable_if_t::value, + google::protobuf::Arena* absl_nullable> +GetArena(const T* absl_nullable ptr) { + return ptr != nullptr ? ptr->GetArena() : nullptr; +} + +template +std::enable_if_t::value, + google::protobuf::Arena* absl_nullable> +GetArena([[maybe_unused]] const T* absl_nullable ptr) { + return nullptr; +} + +template +struct HasArenaTraitsTriviallyDestructible : std::false_type {}; + +template +struct HasArenaTraitsTriviallyDestructible< + T, std::void_t::trivially_destructible( + std::declval()))>> : std::true_type {}; + +} // namespace common_internal + +template <> +struct ArenaTraits { + template + using constructible = std::disjunction< + typename common_internal::AssertArenaType::type, + typename common_internal::ArenaTraitsConstructible::type>; + + template + using always_trivially_destructible = + std::disjunction::type, + std::is_trivially_destructible>; + + template + static bool trivially_destructible(const U& obj) { + static_assert(!std::is_void_v, "T must not be void"); + static_assert(!std::is_reference_v, "T must not be a reference"); + static_assert(!std::is_volatile_v, "T must not be volatile qualified"); + static_assert(!std::is_const_v, "T must not be const qualified"); + static_assert(!std::is_array_v, "T must not be an array"); + + if constexpr (always_trivially_destructible()) { + return true; + } else if constexpr (google::protobuf::Arena::is_destructor_skippable::value) { + return obj.GetArena() != nullptr; + } else if constexpr (common_internal::HasArenaTraitsTriviallyDestructible< + U>::value) { + return ArenaTraits::trivially_destructible(obj); + } else { + return false; + } + } +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_ARENA_H_ diff --git a/common/arena_string.h b/common/arena_string.h new file mode 100644 index 000000000..942600b41 --- /dev/null +++ b/common/arena_string.h @@ -0,0 +1,365 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_ARENA_STRING_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_ARENA_STRING_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/casts.h" +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/strings/string_view.h" +#include "common/arena_string_view.h" +#include "google/protobuf/arena.h" + +namespace cel { + +class ArenaStringPool; + +// Bug in current Abseil LTS. Fixed in +// https://github.com/abseil/abseil-cpp/commit/fd7713cb9a97c49096211ff40de280b6cebbb21c +// which is not yet in an LTS. +#if defined(__clang__) && (!defined(__clang_major__) || __clang_major__ >= 13) +#define CEL_ATTRIBUTE_ARENA_STRING_OWNER ABSL_ATTRIBUTE_OWNER +#else +#define CEL_ATTRIBUTE_ARENA_STRING_OWNER +#endif + +namespace common_internal { + +enum class ArenaStringKind : unsigned int { + kSmall = 0, + kLarge, +}; + +struct ArenaStringSmallRep final { + ArenaStringKind kind : 1; + uint8_t size : 7; + char data[23 - sizeof(google::protobuf::Arena*)]; + google::protobuf::Arena* absl_nullable arena; +}; + +struct ArenaStringLargeRep final { + ArenaStringKind kind : 1; + size_t size : sizeof(size_t) * 8 - 1; + const char* absl_nonnull data; + google::protobuf::Arena* absl_nullable arena; +}; + +inline constexpr size_t kArenaStringSmallCapacity = + sizeof(ArenaStringSmallRep::data); + +union ArenaStringRep final { + struct { + ArenaStringKind kind : 1; + }; + ArenaStringSmallRep small; + ArenaStringLargeRep large; +}; + +} // namespace common_internal + +// `ArenaString` is a read-only string which is either backed by a static string +// literal or owned by the `ArenaStringPool` that created it. It is compatible +// with `absl::string_view` and is implicitly convertible to it. +class CEL_ATTRIBUTE_ARENA_STRING_OWNER ArenaString final { + public: + using traits_type = std::char_traits; + using value_type = char; + using pointer = char*; + using const_pointer = const char*; + using reference = char&; + using const_reference = const char&; + using const_iterator = const_pointer; + using iterator = const_iterator; + using const_reverse_iterator = std::reverse_iterator; + using reverse_iterator = const_reverse_iterator; + using size_type = size_t; + using difference_type = ptrdiff_t; + using absl_internal_is_view = std::false_type; + + ArenaString() : ArenaString(static_cast(nullptr)) {} + + ArenaString(const ArenaString&) = default; + ArenaString& operator=(const ArenaString&) = default; + + explicit ArenaString( + google::protobuf::Arena* absl_nullable arena ABSL_ATTRIBUTE_LIFETIME_BOUND) + : ArenaString(absl::string_view(), arena) {} + + ArenaString(std::nullptr_t) = delete; + + ArenaString(absl::string_view string, google::protobuf::Arena* absl_nullable arena + ABSL_ATTRIBUTE_LIFETIME_BOUND) { + if (string.size() <= common_internal::kArenaStringSmallCapacity) { + rep_.small.kind = common_internal::ArenaStringKind::kSmall; + rep_.small.size = string.size(); + std::memcpy(rep_.small.data, string.data(), string.size()); + rep_.small.arena = arena; + } else { + rep_.large.kind = common_internal::ArenaStringKind::kLarge; + rep_.large.size = string.size(); + rep_.large.data = string.data(); + rep_.large.arena = arena; + } + } + + ArenaString(absl::string_view, std::nullptr_t) = delete; + + explicit ArenaString(ArenaStringView other) + : ArenaString(absl::implicit_cast(other), + other.arena()) {} + + google::protobuf::Arena* absl_nullable arena() const { + switch (rep_.kind) { + case common_internal::ArenaStringKind::kSmall: + return rep_.small.arena; + case common_internal::ArenaStringKind::kLarge: + return rep_.large.arena; + } + } + + size_type size() const { + switch (rep_.kind) { + case common_internal::ArenaStringKind::kSmall: + return rep_.small.size; + case common_internal::ArenaStringKind::kLarge: + return rep_.large.size; + } + } + + bool empty() const { return size() == 0; } + + size_type max_size() const { return std::numeric_limits::max() >> 1; } + + absl_nonnull const_pointer data() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + switch (rep_.kind) { + case common_internal::ArenaStringKind::kSmall: + return rep_.small.data; + case common_internal::ArenaStringKind::kLarge: + return rep_.large.data; + } + } + + const_reference front() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + ABSL_DCHECK(!empty()); + + return data()[0]; + } + + const_reference back() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + ABSL_DCHECK(!empty()); + + return data()[size() - 1]; + } + + const_reference operator[](size_type index) const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + ABSL_DCHECK_LT(index, size()); + + return data()[index]; + } + + void remove_prefix(size_type n) { + ABSL_DCHECK_LE(n, size()); + + switch (rep_.kind) { + case common_internal::ArenaStringKind::kSmall: + std::memmove(rep_.small.data, rep_.small.data + n, rep_.small.size - n); + rep_.small.size = rep_.small.size - n; + break; + case common_internal::ArenaStringKind::kLarge: + rep_.large.data += n; + rep_.large.size = rep_.large.size - n; + break; + } + } + + void remove_suffix(size_type n) { + ABSL_DCHECK_LE(n, size()); + + switch (rep_.kind) { + case common_internal::ArenaStringKind::kSmall: + rep_.small.size = rep_.small.size - n; + break; + case common_internal::ArenaStringKind::kLarge: + rep_.large.size = rep_.large.size - n; + break; + } + } + + const_iterator begin() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return data(); } + + const_iterator cbegin() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return begin(); + } + + const_iterator end() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return data() + size(); + } + + const_iterator cend() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return end(); } + + const_reverse_iterator rbegin() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return std::make_reverse_iterator(end()); + } + + const_reverse_iterator crbegin() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return rbegin(); + } + + const_reverse_iterator rend() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return std::make_reverse_iterator(begin()); + } + + const_reverse_iterator crend() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return rend(); + } + + private: + friend class ArenaStringView; + + common_internal::ArenaStringRep rep_; +}; + +inline ArenaStringView::ArenaStringView( + const ArenaString& arena_string ABSL_ATTRIBUTE_LIFETIME_BOUND) { + switch (arena_string.rep_.kind) { + case common_internal::ArenaStringKind::kSmall: + string_ = absl::string_view(arena_string.rep_.small.data, + arena_string.rep_.small.size); + arena_ = arena_string.rep_.small.arena; + break; + case common_internal::ArenaStringKind::kLarge: + string_ = absl::string_view(arena_string.rep_.large.data, + arena_string.rep_.large.size); + arena_ = arena_string.rep_.large.arena; + break; + } +} + +inline ArenaStringView& ArenaStringView::operator=( + const ArenaString& arena_string ABSL_ATTRIBUTE_LIFETIME_BOUND) { + switch (arena_string.rep_.kind) { + case common_internal::ArenaStringKind::kSmall: + string_ = absl::string_view(arena_string.rep_.small.data, + arena_string.rep_.small.size); + arena_ = arena_string.rep_.small.arena; + break; + case common_internal::ArenaStringKind::kLarge: + string_ = absl::string_view(arena_string.rep_.large.data, + arena_string.rep_.large.size); + arena_ = arena_string.rep_.large.arena; + break; + } + return *this; +} + +inline bool operator==(const ArenaString& lhs, const ArenaString& rhs) { + return absl::implicit_cast(lhs) == + absl::implicit_cast(rhs); +} + +inline bool operator==(const ArenaString& lhs, absl::string_view rhs) { + return absl::implicit_cast(lhs) == rhs; +} + +inline bool operator==(absl::string_view lhs, const ArenaString& rhs) { + return lhs == absl::implicit_cast(rhs); +} + +inline bool operator!=(const ArenaString& lhs, const ArenaString& rhs) { + return absl::implicit_cast(lhs) != + absl::implicit_cast(rhs); +} + +inline bool operator!=(const ArenaString& lhs, absl::string_view rhs) { + return absl::implicit_cast(lhs) != rhs; +} + +inline bool operator!=(absl::string_view lhs, const ArenaString& rhs) { + return lhs != absl::implicit_cast(rhs); +} + +inline bool operator<(const ArenaString& lhs, const ArenaString& rhs) { + return absl::implicit_cast(lhs) < + absl::implicit_cast(rhs); +} + +inline bool operator<(const ArenaString& lhs, absl::string_view rhs) { + return absl::implicit_cast(lhs) < rhs; +} + +inline bool operator<(absl::string_view lhs, const ArenaString& rhs) { + return lhs < absl::implicit_cast(rhs); +} + +inline bool operator<=(const ArenaString& lhs, const ArenaString& rhs) { + return absl::implicit_cast(lhs) <= + absl::implicit_cast(rhs); +} + +inline bool operator<=(const ArenaString& lhs, absl::string_view rhs) { + return absl::implicit_cast(lhs) <= rhs; +} + +inline bool operator<=(absl::string_view lhs, const ArenaString& rhs) { + return lhs <= absl::implicit_cast(rhs); +} + +inline bool operator>(const ArenaString& lhs, const ArenaString& rhs) { + return absl::implicit_cast(lhs) > + absl::implicit_cast(rhs); +} + +inline bool operator>(const ArenaString& lhs, absl::string_view rhs) { + return absl::implicit_cast(lhs) > rhs; +} + +inline bool operator>(absl::string_view lhs, const ArenaString& rhs) { + return lhs > absl::implicit_cast(rhs); +} + +inline bool operator>=(const ArenaString& lhs, const ArenaString& rhs) { + return absl::implicit_cast(lhs) >= + absl::implicit_cast(rhs); +} + +inline bool operator>=(const ArenaString& lhs, absl::string_view rhs) { + return absl::implicit_cast(lhs) >= rhs; +} + +inline bool operator>=(absl::string_view lhs, const ArenaString& rhs) { + return lhs >= absl::implicit_cast(rhs); +} + +template +H AbslHashValue(H state, const ArenaString& arena_string) { + return H::combine(std::move(state), + absl::implicit_cast(arena_string)); +} + +#undef CEL_ATTRIBUTE_ARENA_STRING_OWNER + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_ARENA_STRING_H_ diff --git a/common/arena_string_pool.h b/common/arena_string_pool.h new file mode 100644 index 000000000..bddd9c8e4 --- /dev/null +++ b/common/arena_string_pool.h @@ -0,0 +1,86 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_ARENA_STRING_POOL_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_ARENA_STRING_POOL_H_ + +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/casts.h" +#include "absl/base/nullability.h" +#include "absl/strings/cord.h" +#include "absl/strings/string_view.h" +#include "common/arena_string_view.h" +#include "internal/string_pool.h" +#include "google/protobuf/arena.h" + +namespace cel { + +class ArenaStringPool; + +absl_nonnull std::unique_ptr NewArenaStringPool( + google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND); + +class ArenaStringPool final { + public: + ArenaStringPool(const ArenaStringPool&) = delete; + ArenaStringPool(ArenaStringPool&&) = delete; + ArenaStringPool& operator=(const ArenaStringPool&) = delete; + ArenaStringPool& operator=(ArenaStringPool&&) = delete; + + ArenaStringView InternString(const char* absl_nullable string) { + return ArenaStringView(strings_.InternString(string), strings_.arena()); + } + + ArenaStringView InternString(absl::string_view string) { + return ArenaStringView(strings_.InternString(string), strings_.arena()); + } + + ArenaStringView InternString(std::string&& string) { + return ArenaStringView(strings_.InternString(std::move(string)), + strings_.arena()); + } + + ArenaStringView InternString(const absl::Cord& string) { + return ArenaStringView(strings_.InternString(string), strings_.arena()); + } + + ArenaStringView InternString(ArenaStringView string) { + if (string.arena() == strings_.arena()) { + return string; + } + return InternString(absl::implicit_cast(string)); + } + + private: + friend absl_nonnull std::unique_ptr NewArenaStringPool( + google::protobuf::Arena* absl_nonnull); + + explicit ArenaStringPool(google::protobuf::Arena* absl_nonnull arena) + : strings_(arena) {} + + internal::StringPool strings_; +}; + +inline absl_nonnull std::unique_ptr NewArenaStringPool( + google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND) { + return std::unique_ptr(new ArenaStringPool(arena)); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_ARENA_STRING_POOL_H_ diff --git a/common/arena_string_pool_test.cc b/common/arena_string_pool_test.cc new file mode 100644 index 000000000..59921ae48 --- /dev/null +++ b/common/arena_string_pool_test.cc @@ -0,0 +1,72 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/arena_string_pool.h" + +#include + +#include "absl/strings/cord_test_helpers.h" +#include "absl/strings/string_view.h" +#include "internal/testing.h" +#include "google/protobuf/arena.h" + +namespace cel { +namespace { + +TEST(ArenaStringPool, InternCString) { + google::protobuf::Arena arena; + auto string_pool = NewArenaStringPool(&arena); + auto expected = string_pool->InternString("Hello World!"); + auto got = string_pool->InternString("Hello World!"); + EXPECT_EQ(expected.data(), got.data()); +} + +TEST(ArenaStringPool, InternStringView) { + google::protobuf::Arena arena; + auto string_pool = NewArenaStringPool(&arena); + auto expected = string_pool->InternString(absl::string_view("Hello World!")); + auto got = string_pool->InternString("Hello World!"); + EXPECT_EQ(expected.data(), got.data()); +} + +TEST(ArenaStringPool, InternStringSmall) { + google::protobuf::Arena arena; + auto string_pool = NewArenaStringPool(&arena); + auto expected = string_pool->InternString(std::string("Hello World!")); + auto got = string_pool->InternString("Hello World!"); + EXPECT_EQ(expected.data(), got.data()); +} + +TEST(ArenaStringPool, InternStringLarge) { + google::protobuf::Arena arena; + auto string_pool = NewArenaStringPool(&arena); + auto expected = string_pool->InternString( + std::string("This string is larger than std::string itself!")); + auto got = string_pool->InternString( + "This string is larger than std::string itself!"); + EXPECT_EQ(expected.data(), got.data()); +} + +TEST(ArenaStringPool, InternCord) { + google::protobuf::Arena arena; + auto string_pool = NewArenaStringPool(&arena); + auto expected = string_pool->InternString(absl::MakeFragmentedCord( + {"This string is larger", " ", "than absl::Cord itself!"})); + auto got = string_pool->InternString( + "This string is larger than absl::Cord itself!"); + EXPECT_EQ(expected.data(), got.data()); +} + +} // namespace +} // namespace cel diff --git a/common/arena_string_test.cc b/common/arena_string_test.cc new file mode 100644 index 000000000..a3135a93f --- /dev/null +++ b/common/arena_string_test.cc @@ -0,0 +1,160 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/arena_string.h" + +#include "absl/base/nullability.h" +#include "absl/hash/hash.h" +#include "absl/hash/hash_testing.h" +#include "absl/strings/string_view.h" +#include "internal/testing.h" +#include "google/protobuf/arena.h" + +namespace cel { +namespace { + +using ::testing::Eq; +using ::testing::Ge; +using ::testing::Gt; +using ::testing::IsEmpty; +using ::testing::Le; +using ::testing::Lt; +using ::testing::Ne; +using ::testing::Not; +using ::testing::NotNull; +using ::testing::SizeIs; + +class ArenaStringTest : public ::testing::Test { + protected: + google::protobuf::Arena* absl_nonnull arena() { return &arena_; } + + private: + google::protobuf::Arena arena_; +}; + +TEST_F(ArenaStringTest, Default) { + ArenaString string; + EXPECT_THAT(string, IsEmpty()); + EXPECT_THAT(string, SizeIs(0)); + EXPECT_THAT(string, Eq(ArenaString())); +} + +TEST_F(ArenaStringTest, Small) { + static constexpr absl::string_view kSmall = "Hello World!"; + + ArenaString string(kSmall, arena()); + EXPECT_THAT(string, Not(IsEmpty())); + EXPECT_THAT(string, SizeIs(kSmall.size())); + EXPECT_THAT(string.data(), NotNull()); + EXPECT_THAT(string, kSmall); +} + +TEST_F(ArenaStringTest, Large) { + static constexpr absl::string_view kLarge = + "This string is larger than the inline storage!"; + + ArenaString string(kLarge, arena()); + EXPECT_THAT(string, Not(IsEmpty())); + EXPECT_THAT(string, SizeIs(kLarge.size())); + EXPECT_THAT(string.data(), NotNull()); + EXPECT_THAT(string, kLarge); +} + +TEST_F(ArenaStringTest, Iterator) { + ArenaString string = ArenaString("Hello World!", arena()); + auto it = string.cbegin(); + EXPECT_THAT(*it++, Eq('H')); + EXPECT_THAT(*it++, Eq('e')); + EXPECT_THAT(*it++, Eq('l')); + EXPECT_THAT(*it++, Eq('l')); + EXPECT_THAT(*it++, Eq('o')); + EXPECT_THAT(*it++, Eq(' ')); + EXPECT_THAT(*it++, Eq('W')); + EXPECT_THAT(*it++, Eq('o')); + EXPECT_THAT(*it++, Eq('r')); + EXPECT_THAT(*it++, Eq('l')); + EXPECT_THAT(*it++, Eq('d')); + EXPECT_THAT(*it++, Eq('!')); + EXPECT_THAT(it, Eq(string.cend())); +} + +TEST_F(ArenaStringTest, ReverseIterator) { + ArenaString string = ArenaString("Hello World!", arena()); + auto it = string.crbegin(); + EXPECT_THAT(*it++, Eq('!')); + EXPECT_THAT(*it++, Eq('d')); + EXPECT_THAT(*it++, Eq('l')); + EXPECT_THAT(*it++, Eq('r')); + EXPECT_THAT(*it++, Eq('o')); + EXPECT_THAT(*it++, Eq('W')); + EXPECT_THAT(*it++, Eq(' ')); + EXPECT_THAT(*it++, Eq('o')); + EXPECT_THAT(*it++, Eq('l')); + EXPECT_THAT(*it++, Eq('l')); + EXPECT_THAT(*it++, Eq('e')); + EXPECT_THAT(*it++, Eq('H')); + EXPECT_THAT(it, Eq(string.crend())); +} + +TEST_F(ArenaStringTest, RemovePrefix) { + ArenaString string = ArenaString("Hello World!", arena()); + string.remove_prefix(6); + EXPECT_EQ(string, "World!"); +} + +TEST_F(ArenaStringTest, RemoveSuffix) { + ArenaString string = ArenaString("Hello World!", arena()); + string.remove_suffix(7); + EXPECT_EQ(string, "Hello"); +} + +TEST_F(ArenaStringTest, Equal) { + EXPECT_THAT(ArenaString("1", arena()), Eq(ArenaString("1", arena()))); +} + +TEST_F(ArenaStringTest, NotEqual) { + EXPECT_THAT(ArenaString("1", arena()), Ne(ArenaString("2", arena()))); +} + +TEST_F(ArenaStringTest, Less) { + EXPECT_THAT(ArenaString("1", arena()), Lt(ArenaString("2", arena()))); +} + +TEST_F(ArenaStringTest, LessEqual) { + EXPECT_THAT(ArenaString("1", arena()), Le(ArenaString("1", arena()))); +} + +TEST_F(ArenaStringTest, Greater) { + EXPECT_THAT(ArenaString("2", arena()), Gt(ArenaString("1", arena()))); +} + +TEST_F(ArenaStringTest, GreaterEqual) { + EXPECT_THAT(ArenaString("1", arena()), Ge(ArenaString("1", arena()))); +} + +TEST_F(ArenaStringTest, ImplementsAbslHashCorrectly) { + EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly( + {ArenaString("", arena()), ArenaString("Hello World!", arena()), + ArenaString("How much wood could a woodchuck chuck if a " + "woodchuck could chuck wood?", + arena())})); +} + +TEST_F(ArenaStringTest, Hash) { + EXPECT_EQ(absl::HashOf(ArenaString("Hello World!", arena())), + absl::HashOf(absl::string_view("Hello World!"))); +} + +} // namespace +} // namespace cel diff --git a/common/arena_string_view.h b/common/arena_string_view.h new file mode 100644 index 000000000..2c750ba99 --- /dev/null +++ b/common/arena_string_view.h @@ -0,0 +1,239 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_ARENA_STRING_VIEW_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_ARENA_STRING_VIEW_H_ + +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/casts.h" +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/strings/string_view.h" +#include "google/protobuf/arena.h" + +namespace cel { + +class ArenaString; + +// Bug in current Abseil LTS. Fixed in +// https://github.com/abseil/abseil-cpp/commit/fd7713cb9a97c49096211ff40de280b6cebbb21c +// which is not yet in an LTS. +#if defined(__clang__) && (!defined(__clang_major__) || __clang_major__ >= 13) +#define CEL_ATTRIBUTE_ARENA_STRING_VIEW ABSL_ATTRIBUTE_VIEW +#else +#define CEL_ATTRIBUTE_ARENA_STRING_VIEW +#endif + +class CEL_ATTRIBUTE_ARENA_STRING_VIEW ArenaStringView final { + public: + using traits_type = std::char_traits; + using value_type = char; + using pointer = char*; + using const_pointer = const char*; + using reference = char&; + using const_reference = const char&; + using const_iterator = typename absl::string_view::const_pointer; + using iterator = typename absl::string_view::const_iterator; + using const_reverse_iterator = + typename absl::string_view::const_reverse_iterator; + using reverse_iterator = typename absl::string_view::reverse_iterator; + using size_type = size_t; + using difference_type = ptrdiff_t; + using absl_internal_is_view = std::true_type; + + ArenaStringView() = default; + ArenaStringView(const ArenaStringView&) = default; + ArenaStringView& operator=(const ArenaStringView&) = default; + + // NOLINTNEXTLINE(google-explicit-constructor) + ArenaStringView( + const ArenaString& arena_string ABSL_ATTRIBUTE_LIFETIME_BOUND); + + // NOLINTNEXTLINE(google-explicit-constructor) + ArenaStringView& operator=( + const ArenaString& arena_string ABSL_ATTRIBUTE_LIFETIME_BOUND); + + ArenaStringView& operator=(ArenaString&&) = delete; + + explicit ArenaStringView( + google::protobuf::Arena* absl_nullable arena ABSL_ATTRIBUTE_LIFETIME_BOUND) + : arena_(arena) {} + + ArenaStringView(std::nullptr_t) = delete; + + ArenaStringView(absl::string_view string ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::Arena* absl_nullable arena + ABSL_ATTRIBUTE_LIFETIME_BOUND) + : string_(string), arena_(arena) {} + + ArenaStringView(absl::string_view, std::nullptr_t) = delete; + + google::protobuf::Arena* absl_nullable arena() const { return arena_; } + + size_type size() const { return string_.size(); } + + bool empty() const { return string_.empty(); } + + size_type max_size() const { return std::numeric_limits::max() >> 1; } + + absl_nonnull const_pointer data() const { return string_.data(); } + + const_reference front() const { + ABSL_DCHECK(!empty()); + + return string_.front(); + } + + const_reference back() const { + ABSL_DCHECK(!empty()); + + return string_.back(); + } + + const_reference operator[](size_type index) const { + ABSL_DCHECK_LT(index, size()); + + return string_[index]; + } + + void remove_prefix(size_type n) { + ABSL_DCHECK_LE(n, size()); + + string_.remove_prefix(n); + } + + void remove_suffix(size_type n) { + ABSL_DCHECK_LE(n, size()); + + string_.remove_suffix(n); + } + + const_iterator begin() const { return string_.begin(); } + + const_iterator cbegin() const { return string_.cbegin(); } + + const_iterator end() const { return string_.end(); } + + const_iterator cend() const { return string_.cend(); } + + const_reverse_iterator rbegin() const { return string_.rbegin(); } + + const_reverse_iterator crbegin() const { return string_.crbegin(); } + + const_reverse_iterator rend() const { return string_.rend(); } + + const_reverse_iterator crend() const { return string_.crend(); } + + // NOLINTNEXTLINE(google-explicit-constructor) + operator absl::string_view() const { return string_; } + + private: + absl::string_view string_; + google::protobuf::Arena* absl_nullable arena_ = nullptr; +}; + +inline bool operator==(ArenaStringView lhs, ArenaStringView rhs) { + return absl::implicit_cast(lhs) == + absl::implicit_cast(rhs); +} + +inline bool operator==(ArenaStringView lhs, absl::string_view rhs) { + return absl::implicit_cast(lhs) == rhs; +} + +inline bool operator==(absl::string_view lhs, ArenaStringView rhs) { + return lhs == absl::implicit_cast(rhs); +} + +inline bool operator!=(ArenaStringView lhs, ArenaStringView rhs) { + return absl::implicit_cast(lhs) != + absl::implicit_cast(rhs); +} + +inline bool operator!=(ArenaStringView lhs, absl::string_view rhs) { + return absl::implicit_cast(lhs) != rhs; +} + +inline bool operator!=(absl::string_view lhs, ArenaStringView rhs) { + return lhs != absl::implicit_cast(rhs); +} + +inline bool operator<(ArenaStringView lhs, ArenaStringView rhs) { + return absl::implicit_cast(lhs) < + absl::implicit_cast(rhs); +} + +inline bool operator<(ArenaStringView lhs, absl::string_view rhs) { + return absl::implicit_cast(lhs) < rhs; +} + +inline bool operator<(absl::string_view lhs, ArenaStringView rhs) { + return lhs < absl::implicit_cast(rhs); +} + +inline bool operator<=(ArenaStringView lhs, ArenaStringView rhs) { + return absl::implicit_cast(lhs) <= + absl::implicit_cast(rhs); +} + +inline bool operator<=(ArenaStringView lhs, absl::string_view rhs) { + return absl::implicit_cast(lhs) <= rhs; +} + +inline bool operator<=(absl::string_view lhs, ArenaStringView rhs) { + return lhs <= absl::implicit_cast(rhs); +} + +inline bool operator>(ArenaStringView lhs, ArenaStringView rhs) { + return absl::implicit_cast(lhs) > + absl::implicit_cast(rhs); +} + +inline bool operator>(ArenaStringView lhs, absl::string_view rhs) { + return absl::implicit_cast(lhs) > rhs; +} + +inline bool operator>(absl::string_view lhs, ArenaStringView rhs) { + return lhs > absl::implicit_cast(rhs); +} + +inline bool operator>=(ArenaStringView lhs, ArenaStringView rhs) { + return absl::implicit_cast(lhs) >= + absl::implicit_cast(rhs); +} + +inline bool operator>=(ArenaStringView lhs, absl::string_view rhs) { + return absl::implicit_cast(lhs) >= rhs; +} + +inline bool operator>=(absl::string_view lhs, ArenaStringView rhs) { + return lhs >= absl::implicit_cast(rhs); +} + +template +H AbslHashValue(H state, ArenaStringView arena_string_view) { + return H::combine(std::move(state), + absl::implicit_cast(arena_string_view)); +} + +#undef CEL_ATTRIBUTE_ARENA_STRING_VIEW + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_ARENA_STRING_VIEW_H_ diff --git a/common/arena_string_view_test.cc b/common/arena_string_view_test.cc new file mode 100644 index 000000000..f3fa055db --- /dev/null +++ b/common/arena_string_view_test.cc @@ -0,0 +1,137 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/arena_string_view.h" + +#include "absl/base/nullability.h" +#include "absl/hash/hash.h" +#include "absl/hash/hash_testing.h" +#include "absl/strings/string_view.h" +#include "internal/testing.h" +#include "google/protobuf/arena.h" + +namespace cel { +namespace { + +using ::testing::Eq; +using ::testing::Ge; +using ::testing::Gt; +using ::testing::IsEmpty; +using ::testing::Le; +using ::testing::Lt; +using ::testing::Ne; +using ::testing::SizeIs; + +class ArenaStringViewTest : public ::testing::Test { + protected: + google::protobuf::Arena* absl_nonnull arena() { return &arena_; } + + private: + google::protobuf::Arena arena_; +}; + +TEST_F(ArenaStringViewTest, Default) { + ArenaStringView string; + EXPECT_THAT(string, IsEmpty()); + EXPECT_THAT(string, SizeIs(0)); + EXPECT_THAT(string, Eq(ArenaStringView())); +} + +TEST_F(ArenaStringViewTest, Iterator) { + ArenaStringView string = ArenaStringView("Hello World!", arena()); + auto it = string.cbegin(); + EXPECT_THAT(*it++, Eq('H')); + EXPECT_THAT(*it++, Eq('e')); + EXPECT_THAT(*it++, Eq('l')); + EXPECT_THAT(*it++, Eq('l')); + EXPECT_THAT(*it++, Eq('o')); + EXPECT_THAT(*it++, Eq(' ')); + EXPECT_THAT(*it++, Eq('W')); + EXPECT_THAT(*it++, Eq('o')); + EXPECT_THAT(*it++, Eq('r')); + EXPECT_THAT(*it++, Eq('l')); + EXPECT_THAT(*it++, Eq('d')); + EXPECT_THAT(*it++, Eq('!')); + EXPECT_THAT(it, Eq(string.cend())); +} + +TEST_F(ArenaStringViewTest, ReverseIterator) { + ArenaStringView string = ArenaStringView("Hello World!", arena()); + auto it = string.crbegin(); + EXPECT_THAT(*it++, Eq('!')); + EXPECT_THAT(*it++, Eq('d')); + EXPECT_THAT(*it++, Eq('l')); + EXPECT_THAT(*it++, Eq('r')); + EXPECT_THAT(*it++, Eq('o')); + EXPECT_THAT(*it++, Eq('W')); + EXPECT_THAT(*it++, Eq(' ')); + EXPECT_THAT(*it++, Eq('o')); + EXPECT_THAT(*it++, Eq('l')); + EXPECT_THAT(*it++, Eq('l')); + EXPECT_THAT(*it++, Eq('e')); + EXPECT_THAT(*it++, Eq('H')); + EXPECT_THAT(it, Eq(string.crend())); +} + +TEST_F(ArenaStringViewTest, RemovePrefix) { + ArenaStringView string = ArenaStringView("Hello World!", arena()); + string.remove_prefix(6); + EXPECT_EQ(string, "World!"); +} + +TEST_F(ArenaStringViewTest, RemoveSuffix) { + ArenaStringView string = ArenaStringView("Hello World!", arena()); + string.remove_suffix(7); + EXPECT_EQ(string, "Hello"); +} + +TEST_F(ArenaStringViewTest, Equal) { + EXPECT_THAT(ArenaStringView("1", arena()), Eq(ArenaStringView("1", arena()))); +} + +TEST_F(ArenaStringViewTest, NotEqual) { + EXPECT_THAT(ArenaStringView("1", arena()), Ne(ArenaStringView("2", arena()))); +} + +TEST_F(ArenaStringViewTest, Less) { + EXPECT_THAT(ArenaStringView("1", arena()), Lt(ArenaStringView("2", arena()))); +} + +TEST_F(ArenaStringViewTest, LessEqual) { + EXPECT_THAT(ArenaStringView("1", arena()), Le(ArenaStringView("1", arena()))); +} + +TEST_F(ArenaStringViewTest, Greater) { + EXPECT_THAT(ArenaStringView("2", arena()), Gt(ArenaStringView("1", arena()))); +} + +TEST_F(ArenaStringViewTest, GreaterEqual) { + EXPECT_THAT(ArenaStringView("1", arena()), Ge(ArenaStringView("1", arena()))); +} + +TEST_F(ArenaStringViewTest, ImplementsAbslHashCorrectly) { + EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly( + {ArenaStringView("", arena()), ArenaStringView("Hello World!", arena()), + ArenaStringView("How much wood could a woodchuck chuck if a " + "woodchuck could chuck wood?", + arena())})); +} + +TEST_F(ArenaStringViewTest, Hash) { + EXPECT_EQ(absl::HashOf(ArenaStringView("Hello World!", arena())), + absl::HashOf(absl::string_view("Hello World!"))); +} + +} // namespace +} // namespace cel diff --git a/common/ast.cc b/common/ast.cc new file mode 100644 index 000000000..48b6f5e0b --- /dev/null +++ b/common/ast.cc @@ -0,0 +1,98 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/ast.h" + +#include + +#include "absl/base/no_destructor.h" +#include "absl/base/nullability.h" +#include "common/ast/metadata.h" +#include "common/source.h" + +namespace cel { +namespace { + +const TypeSpec& DynSingleton() { + static absl::NoDestructor singleton{TypeSpecKind(DynTypeSpec())}; + return *singleton; +} + +} // namespace + +const TypeSpec* absl_nullable Ast::GetType(int64_t expr_id) const { + auto iter = type_map_.find(expr_id); + if (iter == type_map_.end()) { + return nullptr; + } + return &iter->second; +} + +const TypeSpec& Ast::GetTypeOrDyn(int64_t expr_id) const { + if (const TypeSpec* type = GetType(expr_id); type != nullptr) { + return *type; + } + return DynSingleton(); +} + +const TypeSpec& Ast::GetReturnType() const { + return GetTypeOrDyn(root_expr().id()); +} + +const Reference* absl_nullable Ast::GetReference(int64_t expr_id) const { + auto iter = reference_map_.find(expr_id); + if (iter == reference_map_.end()) { + return nullptr; + } + return &iter->second; +} + +SourceLocation Ast::ComputeSourceLocation(int64_t expr_id) const { + const auto& source_info = this->source_info(); + auto iter = source_info.positions().find(expr_id); + if (iter == source_info.positions().end()) { + return SourceLocation{}; + } + int32_t absolute_position = iter->second; + if (absolute_position < 0) { + return SourceLocation{}; + } + + // Find the first line offset that is greater than the absolute position. + int32_t line_idx = -1; + int32_t offset = 0; + for (int32_t i = 0; i < source_info.line_offsets().size(); ++i) { + int32_t next_offset = source_info.line_offsets()[i]; + if (next_offset <= offset) { + // Line offset is not monotonically increasing, so line information is + // invalid. + return SourceLocation{}; + } + if (absolute_position < next_offset) { + line_idx = i; + break; + } + offset = next_offset; + } + + if (line_idx < 0 || line_idx >= source_info.line_offsets().size()) { + return SourceLocation{}; + } + + int32_t rel_position = absolute_position - offset; + + return SourceLocation{line_idx + 1, rel_position}; +} + +} // namespace cel diff --git a/common/ast.h b/common/ast.h new file mode 100644 index 000000000..afd0575ad --- /dev/null +++ b/common/ast.h @@ -0,0 +1,157 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_AST_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_AST_H_ + +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "common/ast/metadata.h" // IWYU pragma: export +#include "common/expr.h" +#include "common/source.h" + +namespace cel { + +// In memory representation of a CEL abstract syntax tree. +// +// If AST inspection or manipulation is needed, prefer to use an existing tool +// or traverse the protobuf representation rather than directly manipulating +// through this class. See `cel::NavigableAst` and `cel::AstTraverse`. +// +// Type and reference maps are only populated if the AST is checked. Any changes +// to the AST are not automatically reflected in the type or reference maps. +// +// To create a new instance from a protobuf representation, use the conversion +// utilities in `common/ast_proto.h`. +class Ast final { + public: + using ReferenceMap = absl::flat_hash_map; + using TypeMap = absl::flat_hash_map; + + Ast() : is_checked_(false) {} + + Ast(Expr expr, SourceInfo source_info) + : root_expr_(std::move(expr)), + source_info_(std::move(source_info)), + is_checked_(false) {} + + Ast(Expr expr, SourceInfo source_info, ReferenceMap reference_map, + TypeMap type_map, std::string expr_version) + : root_expr_(std::move(expr)), + source_info_(std::move(source_info)), + reference_map_(std::move(reference_map)), + type_map_(std::move(type_map)), + expr_version_(std::move(expr_version)), + is_checked_(true) {} + + Ast(const Ast& other) = default; + Ast& operator=(const Ast& other) = default; + Ast(Ast&& other) = default; + Ast& operator=(Ast&& other) = default; + + // Deprecated. Use `is_checked()` instead. + bool IsChecked() const { return is_checked_; } + + bool is_checked() const { return is_checked_; } + void set_is_checked(bool is_checked) { is_checked_ = is_checked; } + + // The root expression of the AST. + // + // This is the entry point for evaluation and determines the overall result + // of the expression given a context. + const Expr& root_expr() const { return root_expr_; } + Expr& mutable_root_expr() { return root_expr_; } + + // Metadata about the source expression. + const SourceInfo& source_info() const { return source_info_; } + SourceInfo& mutable_source_info() { return source_info_; } + + // Returns the type of the expression with the given `expr_id`. + // + // Returns `nullptr` if the expression node is not found or has dynamic type. + const TypeSpec* absl_nullable GetType(int64_t expr_id) const; + const TypeSpec& GetTypeOrDyn(int64_t expr_id) const; + const TypeSpec& GetReturnType() const; + + // Returns the resolved reference for the expression with the given `expr_id`. + // + // Returns `nullptr` if the expression node is not found or no reference was + // resolved. + const Reference* absl_nullable GetReference(int64_t expr_id) const; + + // A map from expression ids to resolved references. + // + // The following entries are in this table: + // + // - An Ident or Select expression is represented here if it resolves to a + // declaration. For instance, if `a.b.c` is represented by + // `select(select(id(a), b), c)`, and `a.b` resolves to a declaration, + // while `c` is a field selection, then the reference is attached to the + // nested select expression (but not to the id or or the outer select). + // In turn, if `a` resolves to a declaration and `b.c` are field selections, + // the reference is attached to the ident expression. + // - Every Call expression has an entry here, identifying the function being + // called. + // - Every CreateStruct expression for a message has an entry, identifying + // the message. + // + // Unpopulated if the AST is not checked. + const ReferenceMap& reference_map() const { return reference_map_; } + ReferenceMap& mutable_reference_map() { return reference_map_; } + + // A map from expression ids to types. + // + // Every expression node which has a type different than DYN has a mapping + // here. If an expression has type DYN, it is omitted from this map to save + // space. + // + // Unpopulated if the AST is not checked. + const TypeMap& type_map() const { return type_map_; } + TypeMap& mutable_type_map() { return type_map_; } + + // The expr version indicates the major / minor version number of the `expr` + // representation. + // + // The most common reason for a version change will be to indicate to the CEL + // runtimes that transformations have been performed on the expr during static + // analysis. + absl::string_view expr_version() const { return expr_version_; } + void set_expr_version(absl::string_view expr_version) { + expr_version_ = expr_version; + } + + // Computes the source location (line and column) for the given expression ID + // from the source info (which stores absolute positions). + // + // Returns a default (empty) source location if the expression ID is not found + // or the source info is not populated correctly. + SourceLocation ComputeSourceLocation(int64_t expr_id) const; + + private: + Expr root_expr_; + SourceInfo source_info_; + ReferenceMap reference_map_; + TypeMap type_map_; + std::string expr_version_; + bool is_checked_; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_AST_H_ diff --git a/common/ast/BUILD b/common/ast/BUILD new file mode 100644 index 000000000..17456566b --- /dev/null +++ b/common/ast/BUILD @@ -0,0 +1,151 @@ +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Internal AST implementation and utilities +# These are needed by various parts of the CEL-C++ library, but are not intended for public use at +# this time. +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "constant_proto", + srcs = ["constant_proto.cc"], + hdrs = ["constant_proto.h"], + deps = [ + "//common:constant", + "//internal:proto_time_encoding", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/functional:overload", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:variant", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_protobuf//:struct_cc_proto", + ], +) + +cc_library( + name = "expr_proto", + srcs = ["expr_proto.cc"], + hdrs = ["expr_proto.h"], + deps = [ + ":constant_proto", + "//common:constant", + "//common:expr", + "//internal:status_macros", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/functional:overload", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:variant", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_protobuf//:struct_cc_proto", + ], +) + +cc_test( + name = "expr_proto_test", + srcs = ["expr_proto_test.cc"], + deps = [ + ":expr_proto", + "//common:expr", + "//internal:proto_matchers", + "//internal:testing", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "source_info_proto", + srcs = ["source_info_proto.cc"], + hdrs = ["source_info_proto.h"], + deps = [ + ":expr_proto", + "//common:ast", + "//internal:status_macros", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/status", + "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_protobuf//:duration_cc_proto", + "@com_google_protobuf//:struct_cc_proto", + "@com_google_protobuf//:timestamp_cc_proto", + ], +) + +cc_library( + name = "metadata", + srcs = ["metadata.cc"], + hdrs = ["metadata.h"], + deps = [ + "//common:constant", + "//common:expr", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/functional:overload", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:variant", + ], +) + +cc_test( + name = "metadata_test", + srcs = ["metadata_test.cc"], + deps = [ + ":metadata", + "//common:expr", + "//internal:testing", + "@com_google_absl//absl/types:variant", + ], +) + +cc_library( + name = "navigable_ast_internal", + srcs = ["navigable_ast_kinds.cc"], + hdrs = [ + "navigable_ast_internal.h", + "navigable_ast_kinds.h", + ], + deps = [ + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + ], +) + +cc_test( + name = "navigable_ast_internal_test", + srcs = ["navigable_ast_internal_test.cc"], + deps = [ + ":navigable_ast_internal", + "//internal:testing", + "@com_google_absl//absl/base", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) diff --git a/common/ast/constant_proto.cc b/common/ast/constant_proto.cc new file mode 100644 index 000000000..1982c05b4 --- /dev/null +++ b/common/ast/constant_proto.cc @@ -0,0 +1,123 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/ast/constant_proto.h" + +#include +#include + +#include "cel/expr/syntax.pb.h" +#include "google/protobuf/struct.pb.h" +#include "absl/base/nullability.h" +#include "absl/functional/overload.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/time/time.h" +#include "absl/types/variant.h" +#include "common/constant.h" +#include "internal/proto_time_encoding.h" + +namespace cel::ast_internal { + +using ConstantProto = cel::expr::Constant; + +absl::Status ConstantToProto(const Constant& constant, + ConstantProto* absl_nonnull proto) { + return absl::visit(absl::Overload( + [proto](std::monostate) -> absl::Status { + proto->clear_constant_kind(); + return absl::OkStatus(); + }, + [proto](std::nullptr_t) -> absl::Status { + proto->set_null_value(google::protobuf::NULL_VALUE); + return absl::OkStatus(); + }, + [proto](bool value) -> absl::Status { + proto->set_bool_value(value); + return absl::OkStatus(); + }, + [proto](int64_t value) -> absl::Status { + proto->set_int64_value(value); + return absl::OkStatus(); + }, + [proto](uint64_t value) -> absl::Status { + proto->set_uint64_value(value); + return absl::OkStatus(); + }, + [proto](double value) -> absl::Status { + proto->set_double_value(value); + return absl::OkStatus(); + }, + [proto](const BytesConstant& value) -> absl::Status { + proto->set_bytes_value(value); + return absl::OkStatus(); + }, + [proto](const StringConstant& value) -> absl::Status { + proto->set_string_value(value); + return absl::OkStatus(); + }, + [proto](absl::Duration value) -> absl::Status { + return internal::EncodeDuration( + value, proto->mutable_duration_value()); + }, + [proto](absl::Time value) -> absl::Status { + return internal::EncodeTime( + value, proto->mutable_timestamp_value()); + }), + constant.kind()); +} + +absl::Status ConstantFromProto(const ConstantProto& proto, Constant& constant) { + switch (proto.constant_kind_case()) { + case ConstantProto::CONSTANT_KIND_NOT_SET: + constant = Constant{}; + break; + case ConstantProto::kNullValue: + constant.set_null_value(); + break; + case ConstantProto::kBoolValue: + constant.set_bool_value(proto.bool_value()); + break; + case ConstantProto::kInt64Value: + constant.set_int_value(proto.int64_value()); + break; + case ConstantProto::kUint64Value: + constant.set_uint_value(proto.uint64_value()); + break; + case ConstantProto::kDoubleValue: + constant.set_double_value(proto.double_value()); + break; + case ConstantProto::kStringValue: + constant.set_string_value(proto.string_value()); + break; + case ConstantProto::kBytesValue: + constant.set_bytes_value(proto.bytes_value()); + break; + case ConstantProto::kDurationValue: + constant.set_duration_value( + internal::DecodeDuration(proto.duration_value())); + break; + case ConstantProto::kTimestampValue: + constant.set_timestamp_value( + internal::DecodeTime(proto.timestamp_value())); + break; + default: + return absl::InvalidArgumentError( + absl::StrCat("unexpected ConstantKindCase: ", + static_cast(proto.constant_kind_case()))); + } + return absl::OkStatus(); +} + +} // namespace cel::ast_internal diff --git a/common/ast/constant_proto.h b/common/ast/constant_proto.h new file mode 100644 index 000000000..c00adbdf3 --- /dev/null +++ b/common/ast/constant_proto.h @@ -0,0 +1,37 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_AST_CONSTANT_PROTO_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_AST_CONSTANT_PROTO_H_ + +#include "cel/expr/syntax.pb.h" +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "common/constant.h" + +namespace cel::ast_internal { + +// `ConstantToProto` converts from native `Constant` to its protocol buffer +// message equivalent. +absl::Status ConstantToProto(const Constant& constant, + cel::expr::Constant* absl_nonnull proto); + +// `ConstantToProto` converts to native `Constant` from its protocol buffer +// message equivalent. +absl::Status ConstantFromProto(const cel::expr::Constant& proto, + Constant& constant); + +} // namespace cel::ast_internal + +#endif // THIRD_PARTY_CEL_CPP_COMMON_AST_CONSTANT_PROTO_H_ diff --git a/common/ast/expr_proto.cc b/common/ast/expr_proto.cc new file mode 100644 index 000000000..d0efea567 --- /dev/null +++ b/common/ast/expr_proto.cc @@ -0,0 +1,514 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/ast/expr_proto.h" + +#include +#include +#include +#include +#include + +#include "cel/expr/syntax.pb.h" +#include "google/protobuf/struct.pb.h" +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/functional/overload.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/types/variant.h" +#include "common/ast/constant_proto.h" +#include "common/constant.h" +#include "common/expr.h" +#include "internal/status_macros.h" + +namespace cel::ast_internal { + +namespace { + +using ExprProto = cel::expr::Expr; +using ConstantProto = cel::expr::Constant; +using StructExprProto = cel::expr::Expr::CreateStruct; + +class ExprToProtoState final { + private: + struct Frame final { + const Expr* absl_nonnull expr; + cel::expr::Expr* absl_nonnull proto; + }; + + public: + absl::Status ExprToProto(const Expr& expr, + cel::expr::Expr* absl_nonnull proto) { + Push(expr, proto); + Frame frame; + while (Pop(frame)) { + CEL_RETURN_IF_ERROR(ExprToProtoImpl(*frame.expr, frame.proto)); + } + return absl::OkStatus(); + } + + private: + absl::Status ExprToProtoImpl(const Expr& expr, + cel::expr::Expr* absl_nonnull proto) { + return absl::visit( + absl::Overload( + [&expr, proto](const UnspecifiedExpr&) -> absl::Status { + proto->Clear(); + proto->set_id(expr.id()); + return absl::OkStatus(); + }, + [this, &expr, proto](const Constant& const_expr) -> absl::Status { + return ConstExprToProto(expr, const_expr, proto); + }, + [this, &expr, proto](const IdentExpr& ident_expr) -> absl::Status { + return IdentExprToProto(expr, ident_expr, proto); + }, + [this, &expr, + proto](const SelectExpr& select_expr) -> absl::Status { + return SelectExprToProto(expr, select_expr, proto); + }, + [this, &expr, proto](const CallExpr& call_expr) -> absl::Status { + return CallExprToProto(expr, call_expr, proto); + }, + [this, &expr, proto](const ListExpr& list_expr) -> absl::Status { + return ListExprToProto(expr, list_expr, proto); + }, + [this, &expr, + proto](const StructExpr& struct_expr) -> absl::Status { + return StructExprToProto(expr, struct_expr, proto); + }, + [this, &expr, proto](const MapExpr& map_expr) -> absl::Status { + return MapExprToProto(expr, map_expr, proto); + }, + [this, &expr, proto]( + const ComprehensionExpr& comprehension_expr) -> absl::Status { + return ComprehensionExprToProto(expr, comprehension_expr, proto); + }), + expr.kind()); + } + + absl::Status ConstExprToProto(const Expr& expr, const Constant& const_expr, + ExprProto* absl_nonnull proto) { + proto->Clear(); + proto->set_id(expr.id()); + return ConstantToProto(const_expr, proto->mutable_const_expr()); + } + + absl::Status IdentExprToProto(const Expr& expr, const IdentExpr& ident_expr, + ExprProto* absl_nonnull proto) { + proto->Clear(); + auto* ident_proto = proto->mutable_ident_expr(); + proto->set_id(expr.id()); + ident_proto->set_name(ident_expr.name()); + return absl::OkStatus(); + } + + absl::Status SelectExprToProto(const Expr& expr, + const SelectExpr& select_expr, + ExprProto* absl_nonnull proto) { + proto->Clear(); + auto* select_proto = proto->mutable_select_expr(); + proto->set_id(expr.id()); + if (select_expr.has_operand()) { + Push(select_expr.operand(), select_proto->mutable_operand()); + } + select_proto->set_field(select_expr.field()); + select_proto->set_test_only(select_expr.test_only()); + return absl::OkStatus(); + } + + absl::Status CallExprToProto(const Expr& expr, const CallExpr& call_expr, + ExprProto* absl_nonnull proto) { + proto->Clear(); + auto* call_proto = proto->mutable_call_expr(); + proto->set_id(expr.id()); + if (call_expr.has_target()) { + Push(call_expr.target(), call_proto->mutable_target()); + } + call_proto->set_function(call_expr.function()); + if (!call_expr.args().empty()) { + call_proto->mutable_args()->Reserve( + static_cast(call_expr.args().size())); + for (const auto& argument : call_expr.args()) { + Push(argument, call_proto->add_args()); + } + } + return absl::OkStatus(); + } + + absl::Status ListExprToProto(const Expr& expr, const ListExpr& list_expr, + ExprProto* absl_nonnull proto) { + proto->Clear(); + auto* list_proto = proto->mutable_list_expr(); + proto->set_id(expr.id()); + if (!list_expr.elements().empty()) { + list_proto->mutable_elements()->Reserve( + static_cast(list_expr.elements().size())); + for (size_t i = 0; i < list_expr.elements().size(); ++i) { + const auto& element_expr = list_expr.elements()[i]; + auto* element_proto = list_proto->add_elements(); + if (element_expr.has_expr()) { + Push(element_expr.expr(), element_proto); + } + if (element_expr.optional()) { + list_proto->add_optional_indices(static_cast(i)); + } + } + } + return absl::OkStatus(); + } + + absl::Status StructExprToProto(const Expr& expr, + const StructExpr& struct_expr, + ExprProto* absl_nonnull proto) { + proto->Clear(); + auto* struct_proto = proto->mutable_struct_expr(); + proto->set_id(expr.id()); + struct_proto->set_message_name(struct_expr.name()); + if (!struct_expr.fields().empty()) { + struct_proto->mutable_entries()->Reserve( + static_cast(struct_expr.fields().size())); + for (const auto& field_expr : struct_expr.fields()) { + auto* field_proto = struct_proto->add_entries(); + field_proto->set_id(field_expr.id()); + field_proto->set_field_key(field_expr.name()); + if (field_expr.has_value()) { + Push(field_expr.value(), field_proto->mutable_value()); + } + if (field_expr.optional()) { + field_proto->set_optional_entry(true); + } + } + } + return absl::OkStatus(); + } + + absl::Status MapExprToProto(const Expr& expr, const MapExpr& map_expr, + ExprProto* absl_nonnull proto) { + proto->Clear(); + auto* map_proto = proto->mutable_struct_expr(); + proto->set_id(expr.id()); + if (!map_expr.entries().empty()) { + map_proto->mutable_entries()->Reserve( + static_cast(map_expr.entries().size())); + for (const auto& entry_expr : map_expr.entries()) { + auto* entry_proto = map_proto->add_entries(); + entry_proto->set_id(entry_expr.id()); + if (entry_expr.has_key()) { + Push(entry_expr.key(), entry_proto->mutable_map_key()); + } + if (entry_expr.has_value()) { + Push(entry_expr.value(), entry_proto->mutable_value()); + } + if (entry_expr.optional()) { + entry_proto->set_optional_entry(true); + } + } + } + return absl::OkStatus(); + } + + absl::Status ComprehensionExprToProto( + const Expr& expr, const ComprehensionExpr& comprehension_expr, + ExprProto* absl_nonnull proto) { + proto->Clear(); + auto* comprehension_proto = proto->mutable_comprehension_expr(); + proto->set_id(expr.id()); + comprehension_proto->set_iter_var(comprehension_expr.iter_var()); + comprehension_proto->set_iter_var2(comprehension_expr.iter_var2()); + if (comprehension_expr.has_iter_range()) { + Push(comprehension_expr.iter_range(), + comprehension_proto->mutable_iter_range()); + } + comprehension_proto->set_accu_var(comprehension_expr.accu_var()); + if (comprehension_expr.has_accu_init()) { + Push(comprehension_expr.accu_init(), + comprehension_proto->mutable_accu_init()); + } + if (comprehension_expr.has_loop_condition()) { + Push(comprehension_expr.loop_condition(), + comprehension_proto->mutable_loop_condition()); + } + if (comprehension_expr.has_loop_step()) { + Push(comprehension_expr.loop_step(), + comprehension_proto->mutable_loop_step()); + } + if (comprehension_expr.has_result()) { + Push(comprehension_expr.result(), comprehension_proto->mutable_result()); + } + return absl::OkStatus(); + } + + void Push(const Expr& expr, ExprProto* absl_nonnull proto) { + frames_.push(Frame{&expr, proto}); + } + + bool Pop(Frame& frame) { + if (frames_.empty()) { + return false; + } + frame = frames_.top(); + frames_.pop(); + return true; + } + + std::stack> frames_; +}; + +class ExprFromProtoState final { + private: + struct Frame final { + const ExprProto* absl_nonnull proto; + Expr* absl_nonnull expr; + }; + + public: + absl::Status ExprFromProto(const ExprProto& proto, Expr& expr) { + Push(proto, expr); + Frame frame; + while (Pop(frame)) { + CEL_RETURN_IF_ERROR(ExprFromProtoImpl(*frame.proto, *frame.expr)); + } + return absl::OkStatus(); + } + + private: + absl::Status ExprFromProtoImpl(const ExprProto& proto, Expr& expr) { + switch (proto.expr_kind_case()) { + case ExprProto::EXPR_KIND_NOT_SET: + expr.Clear(); + expr.set_id(proto.id()); + return absl::OkStatus(); + case ExprProto::kConstExpr: + return ConstExprFromProto(proto, proto.const_expr(), expr); + case ExprProto::kIdentExpr: + return IdentExprFromProto(proto, proto.ident_expr(), expr); + case ExprProto::kSelectExpr: + return SelectExprFromProto(proto, proto.select_expr(), expr); + case ExprProto::kCallExpr: + return CallExprFromProto(proto, proto.call_expr(), expr); + case ExprProto::kListExpr: + return ListExprFromProto(proto, proto.list_expr(), expr); + case ExprProto::kStructExpr: + if (proto.struct_expr().message_name().empty()) { + return MapExprFromProto(proto, proto.struct_expr(), expr); + } + return StructExprFromProto(proto, proto.struct_expr(), expr); + case ExprProto::kComprehensionExpr: + return ComprehensionExprFromProto(proto, proto.comprehension_expr(), + expr); + default: + return absl::InvalidArgumentError( + absl::StrCat("unexpected ExprKindCase: ", + static_cast(proto.expr_kind_case()))); + } + } + + absl::Status ConstExprFromProto(const ExprProto& proto, + const ConstantProto& const_proto, + Expr& expr) { + expr.Clear(); + expr.set_id(proto.id()); + return ConstantFromProto(const_proto, expr.mutable_const_expr()); + } + + absl::Status IdentExprFromProto(const ExprProto& proto, + const ExprProto::Ident& ident_proto, + Expr& expr) { + expr.Clear(); + expr.set_id(proto.id()); + auto& ident_expr = expr.mutable_ident_expr(); + ident_expr.set_name(ident_proto.name()); + return absl::OkStatus(); + } + + absl::Status SelectExprFromProto(const ExprProto& proto, + const ExprProto::Select& select_proto, + Expr& expr) { + expr.Clear(); + expr.set_id(proto.id()); + auto& select_expr = expr.mutable_select_expr(); + if (select_proto.has_operand()) { + Push(select_proto.operand(), select_expr.mutable_operand()); + } + select_expr.set_field(select_proto.field()); + select_expr.set_test_only(select_proto.test_only()); + return absl::OkStatus(); + } + + absl::Status CallExprFromProto(const ExprProto& proto, + const ExprProto::Call& call_proto, + Expr& expr) { + expr.Clear(); + expr.set_id(proto.id()); + auto& call_expr = expr.mutable_call_expr(); + call_expr.set_function(call_proto.function()); + if (call_proto.has_target()) { + Push(call_proto.target(), call_expr.mutable_target()); + } + call_expr.mutable_args().reserve( + static_cast(call_proto.args().size())); + for (const auto& argument_proto : call_proto.args()) { + Push(argument_proto, call_expr.add_args()); + } + return absl::OkStatus(); + } + + absl::Status ListExprFromProto(const ExprProto& proto, + const ExprProto::CreateList& list_proto, + Expr& expr) { + expr.Clear(); + expr.set_id(proto.id()); + auto& list_expr = expr.mutable_list_expr(); + list_expr.mutable_elements().reserve( + static_cast(list_proto.elements().size())); + for (int i = 0; i < list_proto.elements().size(); ++i) { + const auto& element_proto = list_proto.elements()[i]; + auto& element_expr = list_expr.add_elements(); + Push(element_proto, element_expr.mutable_expr()); + const auto& optional_indicies_proto = list_proto.optional_indices(); + element_expr.set_optional(std::find(optional_indicies_proto.begin(), + optional_indicies_proto.end(), + i) != optional_indicies_proto.end()); + } + return absl::OkStatus(); + } + + absl::Status StructExprFromProto(const ExprProto& proto, + const StructExprProto& struct_proto, + Expr& expr) { + expr.Clear(); + expr.set_id(proto.id()); + auto& struct_expr = expr.mutable_struct_expr(); + struct_expr.set_name(struct_proto.message_name()); + struct_expr.mutable_fields().reserve( + static_cast(struct_proto.entries().size())); + for (const auto& field_proto : struct_proto.entries()) { + switch (field_proto.key_kind_case()) { + case StructExprProto::Entry::KEY_KIND_NOT_SET: + ABSL_FALLTHROUGH_INTENDED; + case StructExprProto::Entry::kFieldKey: + break; + case StructExprProto::Entry::kMapKey: + return absl::InvalidArgumentError("encountered map entry in struct"); + default: + return absl::InvalidArgumentError(absl::StrCat( + "unexpected struct field kind: ", field_proto.key_kind_case())); + } + auto& field_expr = struct_expr.add_fields(); + field_expr.set_id(field_proto.id()); + field_expr.set_name(field_proto.field_key()); + if (field_proto.has_value()) { + Push(field_proto.value(), field_expr.mutable_value()); + } + field_expr.set_optional(field_proto.optional_entry()); + } + return absl::OkStatus(); + } + + absl::Status MapExprFromProto(const ExprProto& proto, + const ExprProto::CreateStruct& map_proto, + Expr& expr) { + expr.Clear(); + expr.set_id(proto.id()); + auto& map_expr = expr.mutable_map_expr(); + map_expr.mutable_entries().reserve( + static_cast(map_proto.entries().size())); + for (const auto& entry_proto : map_proto.entries()) { + switch (entry_proto.key_kind_case()) { + case StructExprProto::Entry::KEY_KIND_NOT_SET: + ABSL_FALLTHROUGH_INTENDED; + case StructExprProto::Entry::kMapKey: + break; + case StructExprProto::Entry::kFieldKey: + return absl::InvalidArgumentError("encountered struct field in map"); + default: + return absl::InvalidArgumentError(absl::StrCat( + "unexpected map entry kind: ", entry_proto.key_kind_case())); + } + auto& entry_expr = map_expr.add_entries(); + entry_expr.set_id(entry_proto.id()); + if (entry_proto.has_map_key()) { + Push(entry_proto.map_key(), entry_expr.mutable_key()); + } + if (entry_proto.has_value()) { + Push(entry_proto.value(), entry_expr.mutable_value()); + } + entry_expr.set_optional(entry_proto.optional_entry()); + } + return absl::OkStatus(); + } + + absl::Status ComprehensionExprFromProto( + const ExprProto& proto, + const ExprProto::Comprehension& comprehension_proto, Expr& expr) { + expr.Clear(); + expr.set_id(proto.id()); + auto& comprehension_expr = expr.mutable_comprehension_expr(); + comprehension_expr.set_iter_var(comprehension_proto.iter_var()); + comprehension_expr.set_iter_var2(comprehension_proto.iter_var2()); + comprehension_expr.set_accu_var(comprehension_proto.accu_var()); + if (comprehension_proto.has_iter_range()) { + Push(comprehension_proto.iter_range(), + comprehension_expr.mutable_iter_range()); + } + if (comprehension_proto.has_accu_init()) { + Push(comprehension_proto.accu_init(), + comprehension_expr.mutable_accu_init()); + } + if (comprehension_proto.has_loop_condition()) { + Push(comprehension_proto.loop_condition(), + comprehension_expr.mutable_loop_condition()); + } + if (comprehension_proto.has_loop_step()) { + Push(comprehension_proto.loop_step(), + comprehension_expr.mutable_loop_step()); + } + if (comprehension_proto.has_result()) { + Push(comprehension_proto.result(), comprehension_expr.mutable_result()); + } + return absl::OkStatus(); + } + + void Push(const ExprProto& proto, Expr& expr) { + frames_.push(Frame{&proto, &expr}); + } + + bool Pop(Frame& frame) { + if (frames_.empty()) { + return false; + } + frame = frames_.top(); + frames_.pop(); + return true; + } + + std::stack> frames_; +}; + +} // namespace + +absl::Status ExprToProto(const Expr& expr, + cel::expr::Expr* absl_nonnull proto) { + ExprToProtoState state; + return state.ExprToProto(expr, proto); +} + +absl::Status ExprFromProto(const cel::expr::Expr& proto, Expr& expr) { + ExprFromProtoState state; + return state.ExprFromProto(proto, expr); +} + +} // namespace cel::ast_internal diff --git a/common/ast/expr_proto.h b/common/ast/expr_proto.h new file mode 100644 index 000000000..ebb071dfe --- /dev/null +++ b/common/ast/expr_proto.h @@ -0,0 +1,32 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_AST_EXPR_PROTO_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_AST_EXPR_PROTO_H_ + +#include "cel/expr/syntax.pb.h" +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "common/expr.h" + +namespace cel::ast_internal { + +absl::Status ExprToProto(const Expr& expr, + cel::expr::Expr* absl_nonnull proto); + +absl::Status ExprFromProto(const cel::expr::Expr& proto, Expr& expr); + +} // namespace cel::ast_internal + +#endif // THIRD_PARTY_CEL_CPP_COMMON_AST_EXPR_PROTO_H_ diff --git a/common/ast/expr_proto_test.cc b/common/ast/expr_proto_test.cc new file mode 100644 index 000000000..54379eb30 --- /dev/null +++ b/common/ast/expr_proto_test.cc @@ -0,0 +1,303 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/ast/expr_proto.h" + +#include + +#include "cel/expr/syntax.pb.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "common/expr.h" +#include "internal/proto_matchers.h" +#include "internal/testing.h" +#include "google/protobuf/text_format.h" + +namespace cel::ast_internal { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::cel::internal::test::EqualsProto; + +using ExprProto = cel::expr::Expr; + +struct ExprRoundtripTestCase { + std::string input; +}; + +using ExprRoundTripTest = ::testing::TestWithParam; + +TEST_P(ExprRoundTripTest, RoundTrip) { + const auto& test_case = GetParam(); + ExprProto original_proto; + ASSERT_TRUE( + google::protobuf::TextFormat::ParseFromString(test_case.input, &original_proto)); + Expr expr; + ASSERT_THAT(ExprFromProto(original_proto, expr), IsOk()); + ExprProto proto; + ASSERT_THAT(ExprToProto(expr, &proto), IsOk()); + EXPECT_THAT(proto, EqualsProto(original_proto)); +} + +INSTANTIATE_TEST_SUITE_P( + ExprRoundTripTest, ExprRoundTripTest, + ::testing::ValuesIn({ + {R"pb( + )pb"}, + {R"pb( + id: 1 + )pb"}, + {R"pb( + id: 1 + const_expr {} + )pb"}, + {R"pb( + id: 1 + const_expr { null_value: NULL_VALUE } + )pb"}, + {R"pb( + id: 1 + const_expr { bool_value: true } + )pb"}, + {R"pb( + id: 1 + const_expr { int64_value: 1 } + )pb"}, + {R"pb( + id: 1 + const_expr { uint64_value: 1 } + )pb"}, + {R"pb( + id: 1 + const_expr { double_value: 1 } + )pb"}, + {R"pb( + id: 1 + const_expr { string_value: "foo" } + )pb"}, + {R"pb( + id: 1 + const_expr { bytes_value: "foo" } + )pb"}, + {R"pb( + id: 1 + const_expr { duration_value { seconds: 1 nanos: 1 } } + )pb"}, + {R"pb( + id: 1 + const_expr { timestamp_value { seconds: 1 nanos: 1 } } + )pb"}, + {R"pb( + id: 1 + ident_expr { name: "foo" } + )pb"}, + {R"pb( + id: 1 + select_expr { + operand { + id: 2 + ident_expr { name: "bar" } + } + field: "foo" + test_only: true + } + )pb"}, + {R"pb( + id: 1 + call_expr { + target { + id: 2 + ident_expr { name: "bar" } + } + function: "foo" + args { + id: 3 + ident_expr { name: "baz" } + } + } + )pb"}, + {R"pb( + id: 1 + list_expr { + elements { + id: 2 + ident_expr { name: "bar" } + } + elements { + id: 3 + ident_expr { name: "baz" } + } + optional_indices: 0 + } + )pb"}, + {R"pb( + id: 1 + struct_expr { + message_name: "google.type.Expr" + entries { + id: 2 + field_key: "description" + value { + id: 3 + const_expr { string_value: "foo" } + } + optional_entry: true + } + entries { + id: 4 + field_key: "expr" + value { + id: 5 + const_expr { string_value: "bar" } + } + } + } + )pb"}, + {R"pb( + id: 1 + struct_expr { + entries { + id: 2 + map_key { + id: 3 + const_expr { string_value: "description" } + } + value { + id: 4 + const_expr { string_value: "foo" } + } + optional_entry: true + } + entries { + id: 5 + map_key { + id: 6 + const_expr { string_value: "expr" } + } + value { + id: 7 + const_expr { string_value: "foo" } + } + optional_entry: true + } + } + )pb"}, + {R"pb( + id: 1 + comprehension_expr { + iter_var: "foo" + iter_range { + id: 2 + list_expr {} + } + accu_var: "bar" + accu_init { + id: 3 + list_expr {} + } + loop_condition { + id: 4 + const_expr { bool_value: true } + } + loop_step { + id: 4 + ident_expr { name: "bar" } + } + result { + id: 5 + ident_expr { name: "foo" } + } + } + )pb"}, + {R"pb( + id: 1 + comprehension_expr { + iter_var: "foo" + iter_var2: "baz" + iter_range { + id: 2 + list_expr {} + } + accu_var: "bar" + accu_init { + id: 3 + list_expr {} + } + loop_condition { + id: 4 + const_expr { bool_value: true } + } + loop_step { + id: 4 + ident_expr { name: "bar" } + } + result { + id: 5 + ident_expr { name: "foo" } + } + } + )pb"}, + })); + +TEST(ExprFromProto, StructFieldInMap) { + ExprProto original_proto; + ASSERT_TRUE( + google::protobuf::TextFormat::ParseFromString(R"pb( + id: 1 + struct_expr: { + entries: { + id: 2 + field_key: "foo" + value: { + id: 3 + ident_expr: { name: "bar" } + } + } + } + )pb", + &original_proto)); + Expr expr; + ASSERT_THAT(ExprFromProto(original_proto, expr), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(ExprFromProto, MapEntryInStruct) { + ExprProto original_proto; + ASSERT_TRUE( + google::protobuf::TextFormat::ParseFromString(R"pb( + id: 1 + struct_expr: { + message_name: "some.Message" + entries: { + id: 2 + map_key: { + id: 3 + ident_expr: { name: "foo" } + } + value: { + id: 4 + ident_expr: { name: "bar" } + } + } + } + )pb", + &original_proto)); + Expr expr; + ASSERT_THAT(ExprFromProto(original_proto, expr), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +} // namespace +} // namespace cel::ast_internal diff --git a/common/ast/metadata.cc b/common/ast/metadata.cc new file mode 100644 index 000000000..38f7ef610 --- /dev/null +++ b/common/ast/metadata.cc @@ -0,0 +1,262 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/ast/metadata.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/base/no_destructor.h" +#include "absl/base/nullability.h" +#include "absl/functional/overload.h" +#include "absl/log/absl_check.h" +#include "absl/strings/str_cat.h" +#include "absl/types/variant.h" + +namespace cel { + +namespace { + +const TypeSpec& DefaultTypeSpec() { + static absl::NoDestructor type(TypeSpecKind{UnsetTypeSpec()}); + return *type; +} + +std::string FormatPrimitive(PrimitiveType t) { + switch (t) { + case PrimitiveType::kBool: + return "bool"; + case PrimitiveType::kInt64: + return "int"; + case PrimitiveType::kUint64: + return "uint"; + case PrimitiveType::kDouble: + return "double"; + case PrimitiveType::kString: + return "string"; + case PrimitiveType::kBytes: + return "bytes"; + default: + return "*unspecified primitive*"; + } +} + +std::string FormatWellKnown(WellKnownTypeSpec t) { + switch (t) { + case WellKnownTypeSpec::kAny: + return "google.protobuf.Any"; + case WellKnownTypeSpec::kDuration: + return "google.protobuf.Duration"; + case WellKnownTypeSpec::kTimestamp: + return "google.protobuf.Timestamp"; + default: + return "*unspecified well known*"; + } +} + +using FormatIns = std::variant; +using FormatStack = std::vector; + +void HandleFormatTypeSpec(const TypeSpec& t, FormatStack& stack, + std::string* out) { + if (t.has_dyn()) { + absl::StrAppend(out, "dyn"); + } else if (t.has_null()) { + absl::StrAppend(out, "null"); + } else if (t.has_primitive()) { + absl::StrAppend(out, FormatPrimitive(t.primitive())); + } else if (t.has_wrapper()) { + absl::StrAppend(out, "wrapper(", FormatPrimitive(t.wrapper()), ")"); + } else if (t.has_well_known()) { + absl::StrAppend(out, FormatWellKnown(t.well_known())); + return; + } else if (t.has_abstract_type()) { + const auto& abs_type = t.abstract_type(); + if (abs_type.parameter_types().empty()) { + absl::StrAppend(out, abs_type.name()); + return; + } + absl::StrAppend(out, abs_type.name(), "("); + stack.push_back(")"); + for (size_t i = abs_type.parameter_types().size(); i > 0; --i) { + stack.push_back(&abs_type.parameter_types()[i - 1]); + if (i > 1) { + stack.push_back(", "); + } + } + + } else if (t.has_type()) { + if (t.type() == TypeSpec()) { + absl::StrAppend(out, "type"); + return; + } + absl::StrAppend(out, "type("); + stack.push_back(")"); + stack.push_back(&t.type()); + } else if (t.has_message_type()) { + absl::StrAppend(out, t.message_type().type()); + } else if (t.has_type_param()) { + absl::StrAppend(out, t.type_param().type()); + } else if (t.has_list_type()) { + absl::StrAppend(out, "list("); + stack.push_back(")"); + stack.push_back(&t.list_type().elem_type()); + } else if (t.has_map_type()) { + absl::StrAppend(out, "map("); + stack.push_back(")"); + stack.push_back(&t.map_type().value_type()); + stack.push_back(", "); + stack.push_back(&t.map_type().key_type()); + } else { + absl::StrAppend(out, "*error*"); + } +} + +TypeSpecKind CopyImpl(const TypeSpecKind& other) { + return absl::visit( + absl::Overload( + [](const std::unique_ptr& other) -> TypeSpecKind { + if (other == nullptr) { + return std::make_unique(); + } + return std::make_unique(*other); + }, + [](const auto& other) -> TypeSpecKind { + // Other variants define copy ctor. + return other; + }), + other); +} + +} // namespace + +const ExtensionSpec::Version& ExtensionSpec::Version::DefaultInstance() { + static absl::NoDestructor instance; + return *instance; +} + +const ExtensionSpec& ExtensionSpec::DefaultInstance() { + static absl::NoDestructor instance; + return *instance; +} + +ExtensionSpec::ExtensionSpec(const ExtensionSpec& other) + : id_(other.id_), + affected_components_(other.affected_components_), + version_(other.version_ == nullptr + ? nullptr + : std::make_unique(*other.version_)) {} + +ExtensionSpec& ExtensionSpec::operator=(const ExtensionSpec& other) { + id_ = other.id_; + affected_components_ = other.affected_components_; + if (other.version_ != nullptr) { + version_ = std::make_unique(other.version()); + } else { + version_ = nullptr; + } + return *this; +} + +const TypeSpec& ListTypeSpec::elem_type() const { + if (elem_type_ != nullptr) { + return *elem_type_; + } + return DefaultTypeSpec(); +} + +bool ListTypeSpec::operator==(const ListTypeSpec& other) const { + return elem_type() == other.elem_type(); +} + +const TypeSpec& MapTypeSpec::key_type() const { + if (key_type_ != nullptr) { + return *key_type_; + } + return DefaultTypeSpec(); +} + +const TypeSpec& MapTypeSpec::value_type() const { + if (value_type_ != nullptr) { + return *value_type_; + } + return DefaultTypeSpec(); +} + +bool MapTypeSpec::operator==(const MapTypeSpec& other) const { + return key_type() == other.key_type() && value_type() == other.value_type(); +} + +const TypeSpec& FunctionTypeSpec::result_type() const { + if (result_type_ != nullptr) { + return *result_type_; + } + return DefaultTypeSpec(); +} + +bool FunctionTypeSpec::operator==(const FunctionTypeSpec& other) const { + return result_type() == other.result_type() && arg_types_ == other.arg_types_; +} + +const TypeSpec& TypeSpec::type() const { + auto* value = absl::get_if>(&type_kind_); + if (value != nullptr) { + if (*value != nullptr) return **value; + } + return DefaultTypeSpec(); +} + +TypeSpec::TypeSpec(const TypeSpec& other) + : type_kind_(CopyImpl(other.type_kind_)) {} + +TypeSpec& TypeSpec::operator=(const TypeSpec& other) { + type_kind_ = CopyImpl(other.type_kind_); + return *this; +} + +FunctionTypeSpec::FunctionTypeSpec(const FunctionTypeSpec& other) + : result_type_(std::make_unique(other.result_type())), + arg_types_(other.arg_types()) {} + +FunctionTypeSpec& FunctionTypeSpec::operator=(const FunctionTypeSpec& other) { + result_type_ = std::make_unique(other.result_type()); + arg_types_ = other.arg_types(); + return *this; +} + +std::string FormatTypeSpec(const TypeSpec& t) { + // Use a stack to avoid recursion. + // Probably overly defensive, but fuzzers will often notice the recursion + // and try to trigger it. + std::string out; + FormatStack seq; + seq.push_back(&t); + while (!seq.empty()) { + FormatIns ins = std::move(seq.back()); + seq.pop_back(); + if (std::holds_alternative(ins)) { + absl::StrAppend(&out, std::get(ins)); + continue; + } + ABSL_DCHECK(std::holds_alternative(ins)); + HandleFormatTypeSpec(*std::get(ins), seq, &out); + } + return out; +} + +} // namespace cel diff --git a/common/ast/metadata.h b/common/ast/metadata.h new file mode 100644 index 000000000..1a69b5b50 --- /dev/null +++ b/common/ast/metadata.h @@ -0,0 +1,920 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Type definitions for auxiliary structures in the AST. +// +// These are more direct equivalents to the public protobuf definitions. +// +// IWYU pragma: private, include "common/ast.h" +#ifndef THIRD_PARTY_CEL_CPP_COMMON_AST_METADATA_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_AST_METADATA_H_ + +#include +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" +#include "absl/types/optional.h" +#include "absl/types/variant.h" +#include "common/constant.h" +#include "common/expr.h" + +namespace cel { + +// An extension that was requested for the source expression. +class ExtensionSpec { + public: + // Version + class Version { + public: + Version() : major_(0), minor_(0) {} + Version(int64_t major, int64_t minor) : major_(major), minor_(minor) {} + + Version(const Version& other) = default; + Version(Version&& other) = default; + Version& operator=(const Version& other) = default; + Version& operator=(Version&& other) = default; + + static const Version& DefaultInstance(); + + // Major version changes indicate different required support level from + // the required components. + int64_t major() const { return major_; } + void set_major(int64_t val) { major_ = val; } + + // Minor version changes must not change the observed behavior from + // existing implementations, but may be provided informationally. + int64_t minor() const { return minor_; } + void set_minor(int64_t val) { minor_ = val; } + + bool operator==(const Version& other) const { + return major_ == other.major_ && minor_ == other.minor_; + } + + bool operator!=(const Version& other) const { return !operator==(other); } + + private: + int64_t major_; + int64_t minor_; + }; + + // CEL component specifier. + enum class Component { + // Unspecified, default. + kUnspecified, + // Parser. Converts a CEL string to an AST. + kParser, + // Type checker. Checks that references in an AST are defined and types + // agree. + kTypeChecker, + // Runtime. Evaluates a parsed and optionally checked CEL AST against a + // context. + kRuntime + }; + + static const ExtensionSpec& DefaultInstance(); + + ExtensionSpec() = default; + ExtensionSpec(std::string id, std::unique_ptr version, + std::vector affected_components) + : id_(std::move(id)), + affected_components_(std::move(affected_components)), + version_(std::move(version)) {} + + ExtensionSpec(const ExtensionSpec& other); + ExtensionSpec(ExtensionSpec&& other) = default; + ExtensionSpec& operator=(const ExtensionSpec& other); + ExtensionSpec& operator=(ExtensionSpec&& other) = default; + + // Identifier for the extension. Example: constant_folding + const std::string& id() const { return id_; } + void set_id(std::string id) { id_ = std::move(id); } + + // If set, the listed components must understand the extension for the + // expression to evaluate correctly. + // + // This field has set semantics, repeated values should be deduplicated. + const std::vector& affected_components() const { + return affected_components_; + } + + std::vector& mutable_affected_components() { + return affected_components_; + } + + // Version info. May be skipped if it isn't meaningful for the extension. + // (for example constant_folding might always be v0.0). + const Version& version() const { + if (version_ == nullptr) { + return Version::DefaultInstance(); + } + return *version_; + } + + Version& mutable_version() { + if (version_ == nullptr) { + version_ = std::make_unique(); + } + return *version_; + } + + void set_version(std::unique_ptr version) { + version_ = std::move(version); + } + + bool operator==(const ExtensionSpec& other) const { + return id_ == other.id_ && + affected_components_ == other.affected_components_ && + version() == other.version(); + } + + bool operator!=(const ExtensionSpec& other) const { + return !operator==(other); + } + + private: + std::string id_; + std::vector affected_components_; + std::unique_ptr version_; +}; + +// Source information collected at parse time. +class SourceInfo { + public: + SourceInfo() = default; + SourceInfo(std::string syntax_version, std::string location, + std::vector line_offsets, + absl::flat_hash_map positions, + absl::flat_hash_map macro_calls, + std::vector extensions) + : syntax_version_(std::move(syntax_version)), + location_(std::move(location)), + line_offsets_(std::move(line_offsets)), + positions_(std::move(positions)), + macro_calls_(std::move(macro_calls)), + extensions_(std::move(extensions)) {} + + SourceInfo(const SourceInfo& other) = default; + SourceInfo(SourceInfo&& other) = default; + SourceInfo& operator=(const SourceInfo& other) = default; + SourceInfo& operator=(SourceInfo&& other) = default; + + void set_syntax_version(std::string syntax_version) { + syntax_version_ = std::move(syntax_version); + } + + void set_location(std::string location) { location_ = std::move(location); } + + void set_line_offsets(std::vector line_offsets) { + line_offsets_ = std::move(line_offsets); + } + + void set_positions(absl::flat_hash_map positions) { + positions_ = std::move(positions); + } + + void set_macro_calls(absl::flat_hash_map macro_calls) { + macro_calls_ = std::move(macro_calls); + } + + const std::string& syntax_version() const { return syntax_version_; } + + const std::string& location() const { return location_; } + + const std::vector& line_offsets() const { return line_offsets_; } + + std::vector& mutable_line_offsets() { return line_offsets_; } + + const absl::flat_hash_map& positions() const { + return positions_; + } + + absl::flat_hash_map& mutable_positions() { + return positions_; + } + + const absl::flat_hash_map& macro_calls() const { + return macro_calls_; + } + + absl::flat_hash_map& mutable_macro_calls() { + return macro_calls_; + } + + bool operator==(const SourceInfo& other) const { + return syntax_version_ == other.syntax_version_ && + location_ == other.location_ && + line_offsets_ == other.line_offsets_ && + positions_ == other.positions_ && + macro_calls_ == other.macro_calls_ && + extensions_ == other.extensions_; + } + + bool operator!=(const SourceInfo& other) const { return !operator==(other); } + + const std::vector& extensions() const { return extensions_; } + + std::vector& mutable_extensions() { return extensions_; } + + private: + // The syntax version of the source, e.g. `cel1`. + std::string syntax_version_; + + // The location name. All position information attached to an expression is + // relative to this location. + // + // The location could be a file, UI element, or similar. For example, + // `acme/app/AnvilPolicy.cel`. + std::string location_; + + // Monotonically increasing list of code point offsets where newlines + // `\n` appear. + // + // The line number of a given position is the index `i` where for a given + // `id` the `line_offsets[i] < id_positions[id] < line_offsets[i+1]`. The + // column may be derivd from `id_positions[id] - line_offsets[i]`. + // + // TODO(uncreated-issue/14): clarify this documentation + std::vector line_offsets_; + + // A map from the parse node id (e.g. `Expr.id`) to the code point offset + // within source. + absl::flat_hash_map positions_; + + // A map from the parse node id where a macro replacement was made to the + // call `Expr` that resulted in a macro expansion. + // + // For example, `has(value.field)` is a function call that is replaced by a + // `test_only` field selection in the AST. Likewise, the call + // `list.exists(e, e > 10)` translates to a comprehension expression. The key + // in the map corresponds to the expression id of the expanded macro, and the + // value is the call `Expr` that was replaced. + absl::flat_hash_map macro_calls_; + + // A list of tags for extensions that were used while parsing or type checking + // the source expression. For example, optimizations that require special + // runtime support may be specified. + // + // These are used to check feature support between components in separate + // implementations. This can be used to either skip redundant work or + // report an error if the extension is unsupported. + std::vector extensions_; +}; + +// CEL primitive types. +enum class PrimitiveType { + // Unspecified type. + kPrimitiveTypeUnspecified = 0, + // Boolean type. + kBool = 1, + // Int64 type. + // + // Proto-based integer values are widened to int64. + kInt64 = 2, + // Uint64 type. + // + // Proto-based unsigned integer values are widened to uint64. + kUint64 = 3, + // Double type. + // + // Proto-based float values are widened to double values. + kDouble = 4, + // String type. + kString = 5, + // Bytes type. + kBytes = 6, +}; + +// Well-known protobuf types treated with first-class support in CEL. +// +// TODO(uncreated-issue/15): represent well-known via abstract types (or however) +// they will be named. +enum class WellKnownTypeSpec { + // Unspecified type. + kWellKnownTypeUnspecified = 0, + // Well-known protobuf.Any type. + // + // Any types are a polymorphic message type. During type-checking they are + // treated like `DYN` types, but at runtime they are resolved to a specific + // message type specified at evaluation time. + kAny = 1, + // Well-known protobuf.Timestamp type, internally referenced as `timestamp`. + kTimestamp = 2, + // Well-known protobuf.Duration type, internally referenced as `duration`. + kDuration = 3, +}; + +// forward declare for recursive types. +class TypeSpec; + +// List type with typed elements, e.g. `list`. +class ListTypeSpec { + public: + ListTypeSpec() = default; + + ListTypeSpec(const ListTypeSpec& rhs); + ListTypeSpec& operator=(const ListTypeSpec& rhs); + ListTypeSpec(ListTypeSpec&& rhs) = default; + ListTypeSpec& operator=(ListTypeSpec&& rhs) = default; + + explicit ListTypeSpec(std::unique_ptr elem_type); + + void set_elem_type(std::unique_ptr elem_type); + + bool has_elem_type() const { return elem_type_ != nullptr; } + + const TypeSpec& elem_type() const; + + TypeSpec& mutable_elem_type(); + + bool operator==(const ListTypeSpec& other) const; + + private: + std::unique_ptr elem_type_; +}; + +// Map type specifier with parameterized key and value types, e.g. +// `map`. +class MapTypeSpec { + public: + MapTypeSpec() = default; + MapTypeSpec(std::unique_ptr key_type, + std::unique_ptr value_type); + + MapTypeSpec(const MapTypeSpec& rhs); + MapTypeSpec& operator=(const MapTypeSpec& rhs); + MapTypeSpec(MapTypeSpec&& rhs) = default; + MapTypeSpec& operator=(MapTypeSpec&& rhs) = default; + + void set_key_type(std::unique_ptr key_type); + + void set_value_type(std::unique_ptr value_type); + + bool has_key_type() const { return key_type_ != nullptr; } + + bool has_value_type() const { return value_type_ != nullptr; } + + const TypeSpec& key_type() const; + + const TypeSpec& value_type() const; + + bool operator==(const MapTypeSpec& other) const; + + TypeSpec& mutable_key_type(); + + TypeSpec& mutable_value_type(); + + private: + // The type of the key. + std::unique_ptr key_type_; + + // The type of the value. + std::unique_ptr value_type_; +}; + +// Function type specifiers with result and arg types. +// +// NOTE: function type represents a lambda-style argument to another function. +// Supported through macros, but not yet a first-class concept in CEL. +class FunctionTypeSpec { + public: + FunctionTypeSpec() = default; + FunctionTypeSpec(std::unique_ptr result_type, + std::vector arg_types); + + FunctionTypeSpec(const FunctionTypeSpec& other); + FunctionTypeSpec& operator=(const FunctionTypeSpec& other); + FunctionTypeSpec(FunctionTypeSpec&&) = default; + FunctionTypeSpec& operator=(FunctionTypeSpec&&) = default; + + void set_result_type(std::unique_ptr result_type); + + void set_arg_types(std::vector arg_types); + + bool has_result_type() const { return result_type_ != nullptr; } + + const TypeSpec& result_type() const; + + TypeSpec& mutable_result_type(); + + const std::vector& arg_types() const { return arg_types_; } + + std::vector& mutable_arg_types() { return arg_types_; } + + bool operator==(const FunctionTypeSpec& other) const; + + private: + // Result type of the function. + std::unique_ptr result_type_; + + // Argument types of the function. + std::vector arg_types_; +}; + +// Application defined abstract type. +// +// Abstract types provide a name as an identifier for the application, and +// optionally one or more type parameters. +// +// For cel::Type representation, see OpaqueType. +class AbstractType { + public: + AbstractType() = default; + AbstractType(std::string name, std::vector parameter_types); + + void set_name(std::string name) { name_ = std::move(name); } + + void set_parameter_types(std::vector parameter_types); + + const std::string& name() const { return name_; } + + const std::vector& parameter_types() const { + return parameter_types_; + } + + std::vector& mutable_parameter_types() { return parameter_types_; } + + bool operator==(const AbstractType& other) const; + + private: + // The fully qualified name of this abstract type. + std::string name_; + + // Parameter types for this abstract type. + std::vector parameter_types_; +}; + +// Wrapper of a primitive type, e.g. `google.protobuf.Int64Value`. +class PrimitiveTypeWrapper { + public: + explicit PrimitiveTypeWrapper(PrimitiveType type) : type_(std::move(type)) {} + + void set_type(PrimitiveType type) { type_ = std::move(type); } + + const PrimitiveType& type() const { return type_; } + + PrimitiveType& mutable_type() { return type_; } + + bool operator==(const PrimitiveTypeWrapper& other) const { + return type_ == other.type_; + } + + private: + PrimitiveType type_; +}; + +// Protocol buffer message type specifier. +// +// The `message_type` string specifies the qualified message type name. For +// example, `google.plus.Profile`. This must be mapped to a google::protobuf::Descriptor +// for type checking. +class MessageTypeSpec { + public: + MessageTypeSpec() = default; + explicit MessageTypeSpec(std::string type) : type_(std::move(type)) {} + + void set_type(std::string type) { type_ = std::move(type); } + + const std::string& type() const { return type_; } + + bool operator==(const MessageTypeSpec& other) const { + return type_ == other.type_; + } + + private: + std::string type_; +}; + +// TypeSpec param type. +// +// The `type_param` string specifies the type parameter name, e.g. `list` +// would be a `list_type` whose element type was a `type_param` type +// named `E`. +class ParamTypeSpec { + public: + ParamTypeSpec() = default; + explicit ParamTypeSpec(std::string type) : type_(std::move(type)) {} + + void set_type(std::string type) { type_ = std::move(type); } + + const std::string& type() const { return type_; } + + bool operator==(const ParamTypeSpec& other) const { + return type_ == other.type_; + } + + private: + std::string type_; +}; + +// Error type specifier. +// +// During type-checking if an expression is an error, its type is propagated +// as the `ERROR` type. This permits the type-checker to discover other +// errors present in the expression. +enum class ErrorTypeSpec { kValue = 0 }; + +using UnsetTypeSpec = absl::monostate; + +struct DynTypeSpec {}; + +inline bool operator==(const DynTypeSpec&, const DynTypeSpec&) { return true; } +inline bool operator!=(const DynTypeSpec&, const DynTypeSpec&) { return false; } + +struct NullTypeSpec {}; +inline bool operator==(const NullTypeSpec&, const NullTypeSpec&) { + return true; +} +inline bool operator!=(const NullTypeSpec&, const NullTypeSpec&) { + return false; +} + +using TypeSpecKind = + absl::variant, ErrorTypeSpec, + AbstractType>; + +// Analogous to cel::expr::Type. +// Represents a CEL type. +// +// TODO(uncreated-issue/15): align with value.proto +class TypeSpec { + public: + TypeSpec() = default; + explicit TypeSpec(TypeSpecKind type_kind) + : type_kind_(std::move(type_kind)) {} + + TypeSpec(const TypeSpec& other); + TypeSpec& operator=(const TypeSpec& other); + TypeSpec(TypeSpec&&) = default; + TypeSpec& operator=(TypeSpec&&) = default; + + void set_type_kind(TypeSpecKind type_kind) { + type_kind_ = std::move(type_kind); + } + + const TypeSpecKind& type_kind() const { return type_kind_; } + + TypeSpecKind& mutable_type_kind() { return type_kind_; } + + bool is_specified() const { + return !absl::holds_alternative(type_kind_); + } + + bool has_dyn() const { + return absl::holds_alternative(type_kind_); + } + + bool has_null() const { + return absl::holds_alternative(type_kind_); + } + + bool has_primitive() const { + return absl::holds_alternative(type_kind_); + } + + bool has_wrapper() const { + return absl::holds_alternative(type_kind_); + } + + bool has_well_known() const { + return absl::holds_alternative(type_kind_); + } + + bool has_list_type() const { + return absl::holds_alternative(type_kind_); + } + + bool has_map_type() const { + return absl::holds_alternative(type_kind_); + } + + bool has_function() const { + return absl::holds_alternative(type_kind_); + } + + bool has_message_type() const { + return absl::holds_alternative(type_kind_); + } + + bool has_type_param() const { + return absl::holds_alternative(type_kind_); + } + + bool has_type() const { + return absl::holds_alternative>(type_kind_); + } + + bool has_error() const { + return absl::holds_alternative(type_kind_); + } + + bool has_abstract_type() const { + return absl::holds_alternative(type_kind_); + } + + NullTypeSpec null() const { + auto* value = absl::get_if(&type_kind_); + if (value != nullptr) { + return *value; + } + return {}; + } + + PrimitiveType primitive() const { + auto* value = absl::get_if(&type_kind_); + if (value != nullptr) { + return *value; + } + return PrimitiveType::kPrimitiveTypeUnspecified; + } + + PrimitiveType wrapper() const { + auto* value = absl::get_if(&type_kind_); + if (value != nullptr) { + return value->type(); + } + return PrimitiveType::kPrimitiveTypeUnspecified; + } + + WellKnownTypeSpec well_known() const { + auto* value = absl::get_if(&type_kind_); + if (value != nullptr) { + return *value; + } + return WellKnownTypeSpec::kWellKnownTypeUnspecified; + } + + const ListTypeSpec& list_type() const { + auto* value = absl::get_if(&type_kind_); + if (value != nullptr) { + return *value; + } + static const ListTypeSpec* default_list_type = new ListTypeSpec(); + return *default_list_type; + } + + const MapTypeSpec& map_type() const { + auto* value = absl::get_if(&type_kind_); + if (value != nullptr) { + return *value; + } + static const MapTypeSpec* default_map_type = new MapTypeSpec(); + return *default_map_type; + } + + const FunctionTypeSpec& function() const { + auto* value = absl::get_if(&type_kind_); + if (value != nullptr) { + return *value; + } + static const FunctionTypeSpec* default_function_type = + new FunctionTypeSpec(); + return *default_function_type; + } + + const MessageTypeSpec& message_type() const { + auto* value = absl::get_if(&type_kind_); + if (value != nullptr) { + return *value; + } + static const MessageTypeSpec* default_message_type = new MessageTypeSpec(); + return *default_message_type; + } + + const ParamTypeSpec& type_param() const { + auto* value = absl::get_if(&type_kind_); + if (value != nullptr) { + return *value; + } + static const ParamTypeSpec* default_param_type = new ParamTypeSpec(); + return *default_param_type; + } + + const TypeSpec& type() const; + + ErrorTypeSpec error_type() const { + auto* value = absl::get_if(&type_kind_); + if (value != nullptr) { + return *value; + } + return ErrorTypeSpec::kValue; + } + + const AbstractType& abstract_type() const { + auto* value = absl::get_if(&type_kind_); + if (value != nullptr) { + return *value; + } + static const AbstractType* default_abstract_type = new AbstractType(); + return *default_abstract_type; + } + + bool operator==(const TypeSpec& other) const { + if (absl::holds_alternative>(type_kind_) && + absl::holds_alternative>(other.type_kind_)) { + const auto& self_type = absl::get>(type_kind_); + const auto& other_type = + absl::get>(other.type_kind_); + if (self_type == nullptr || other_type == nullptr) { + return self_type == other_type; + } + return *self_type == *other_type; + } + return type_kind_ == other.type_kind_; + } + + private: + TypeSpecKind type_kind_; +}; + +// Returns a string representation of the given TypeSpec. +std::string FormatTypeSpec(const TypeSpec& t); + +// Describes a resolved reference to a declaration. +class Reference { + public: + Reference() = default; + + Reference(std::string name, std::vector overload_id, + Constant value) + : name_(std::move(name)), + overload_id_(std::move(overload_id)), + value_(std::move(value)) {} + + Reference(const Reference& other) = default; + Reference& operator=(const Reference& other) = default; + Reference(Reference&&) = default; + Reference& operator=(Reference&&) = default; + + void set_name(std::string name) { name_ = std::move(name); } + + void set_overload_id(std::vector overload_id) { + overload_id_ = std::move(overload_id); + } + + void set_value(Constant value) { value_ = std::move(value); } + + const std::string& name() const { return name_; } + + const std::vector& overload_id() const { return overload_id_; } + + const Constant& value() const { + if (value_.has_value()) { + return value_.value(); + } + static const Constant* default_constant = new Constant; + return *default_constant; + } + + std::vector& mutable_overload_id() { return overload_id_; } + + Constant& mutable_value() { + if (!value_.has_value()) { + value_.emplace(); + } + return *value_; + } + + bool has_value() const { return value_.has_value(); } + + bool operator==(const Reference& other) const { + return name_ == other.name_ && overload_id_ == other.overload_id_ && + value() == other.value(); + } + + private: + // The fully qualified name of the declaration. + std::string name_; + // For references to functions, this is a list of `Overload.overload_id` + // values which match according to typing rules. + // + // If the list has more than one element, overload resolution among the + // presented candidates must happen at runtime because of dynamic types. The + // type checker attempts to narrow down this list as much as possible. + // + // Empty if this is not a reference to a [Decl.FunctionDecl][]. + std::vector overload_id_; + // For references to constants, this may contain the value of the + // constant if known at compile time. + absl::optional value_; +}; + +//////////////////////////////////////////////////////////////////////// +// Out-of-line method declarations +//////////////////////////////////////////////////////////////////////// + +inline ListTypeSpec::ListTypeSpec(const ListTypeSpec& rhs) + : elem_type_(std::make_unique(rhs.elem_type())) {} + +inline ListTypeSpec& ListTypeSpec::operator=(const ListTypeSpec& rhs) { + elem_type_ = std::make_unique(rhs.elem_type()); + return *this; +} + +inline ListTypeSpec::ListTypeSpec(std::unique_ptr elem_type) + : elem_type_(std::move(elem_type)) {} + +inline void ListTypeSpec::set_elem_type(std::unique_ptr elem_type) { + elem_type_ = std::move(elem_type); +} + +inline TypeSpec& ListTypeSpec::mutable_elem_type() { + if (elem_type_ == nullptr) { + elem_type_ = std::make_unique(); + } + return *elem_type_; +} + +inline MapTypeSpec::MapTypeSpec(std::unique_ptr key_type, + std::unique_ptr value_type) + : key_type_(std::move(key_type)), value_type_(std::move(value_type)) {} + +inline MapTypeSpec::MapTypeSpec(const MapTypeSpec& rhs) + : key_type_(std::make_unique(rhs.key_type())), + value_type_(std::make_unique(rhs.value_type())) {} + +inline MapTypeSpec& MapTypeSpec::operator=(const MapTypeSpec& rhs) { + key_type_ = std::make_unique(rhs.key_type()); + value_type_ = std::make_unique(rhs.value_type()); + return *this; +} + +inline void MapTypeSpec::set_key_type(std::unique_ptr key_type) { + key_type_ = std::move(key_type); +} + +inline void MapTypeSpec::set_value_type(std::unique_ptr value_type) { + value_type_ = std::move(value_type); +} + +inline TypeSpec& MapTypeSpec::mutable_key_type() { + if (key_type_ == nullptr) { + key_type_ = std::make_unique(); + } + return *key_type_; +} + +inline TypeSpec& MapTypeSpec::mutable_value_type() { + if (value_type_ == nullptr) { + value_type_ = std::make_unique(); + } + return *value_type_; +} + +inline void FunctionTypeSpec::set_result_type( + std::unique_ptr result_type) { + result_type_ = std::move(result_type); +} + +inline TypeSpec& FunctionTypeSpec::mutable_result_type() { + if (result_type_ == nullptr) { + result_type_ = std::make_unique(); + } + return *result_type_; +} + +//////////////////////////////////////////////////////////////////////// +// Implementation details +//////////////////////////////////////////////////////////////////////// + +inline FunctionTypeSpec::FunctionTypeSpec(std::unique_ptr result_type, + std::vector arg_types) + : result_type_(std::move(result_type)), arg_types_(std::move(arg_types)) {} + +inline void FunctionTypeSpec::set_arg_types(std::vector arg_types) { + arg_types_ = std::move(arg_types); +} + +inline AbstractType::AbstractType(std::string name, + std::vector parameter_types) + : name_(std::move(name)), parameter_types_(std::move(parameter_types)) {} + +inline void AbstractType::set_parameter_types( + std::vector parameter_types) { + parameter_types_ = std::move(parameter_types); +} + +inline bool AbstractType::operator==(const AbstractType& other) const { + return name_ == other.name_ && parameter_types_ == other.parameter_types_; +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_AST_METADATA_H_ diff --git a/common/ast/metadata_test.cc b/common/ast/metadata_test.cc new file mode 100644 index 000000000..5553f4c8f --- /dev/null +++ b/common/ast/metadata_test.cc @@ -0,0 +1,299 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/ast/metadata.h" + +#include +#include +#include + +#include "absl/types/variant.h" +#include "common/expr.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +using ::testing::ElementsAre; + +TEST(AstTest, ListTypeSpecMutableConstruction) { + ListTypeSpec type; + type.mutable_elem_type() = TypeSpec(PrimitiveType::kBool); + EXPECT_EQ(absl::get(type.elem_type().type_kind()), + PrimitiveType::kBool); +} + +TEST(AstTest, MapTypeSpecMutableConstruction) { + MapTypeSpec type; + type.mutable_key_type() = TypeSpec(PrimitiveType::kBool); + type.mutable_value_type() = TypeSpec(PrimitiveType::kBool); + EXPECT_EQ(absl::get(type.key_type().type_kind()), + PrimitiveType::kBool); + EXPECT_EQ(absl::get(type.value_type().type_kind()), + PrimitiveType::kBool); +} + +TEST(AstTest, MapTypeSpecComparatorKeyType) { + MapTypeSpec type; + type.mutable_key_type() = TypeSpec(PrimitiveType::kBool); + EXPECT_FALSE(type == MapTypeSpec()); +} + +TEST(AstTest, MapTypeSpecComparatorValueType) { + MapTypeSpec type; + type.mutable_value_type() = TypeSpec(PrimitiveType::kBool); + EXPECT_FALSE(type == MapTypeSpec()); +} + +TEST(AstTest, FunctionTypeSpecMutableConstruction) { + FunctionTypeSpec type; + type.mutable_result_type() = TypeSpec(PrimitiveType::kBool); + EXPECT_EQ(absl::get(type.result_type().type_kind()), + PrimitiveType::kBool); +} + +TEST(AstTest, FunctionTypeSpecComparatorArgTypes) { + FunctionTypeSpec type; + type.mutable_arg_types().emplace_back(TypeSpec()); + EXPECT_FALSE(type == FunctionTypeSpec()); +} + +TEST(AstTest, ListTypeSpecDefaults) { + EXPECT_EQ(ListTypeSpec().elem_type(), TypeSpec()); +} + +TEST(AstTest, MapTypeSpecDefaults) { + EXPECT_EQ(MapTypeSpec().key_type(), TypeSpec()); + EXPECT_EQ(MapTypeSpec().value_type(), TypeSpec()); +} + +TEST(AstTest, FunctionTypeSpecDefaults) { + EXPECT_EQ(FunctionTypeSpec().result_type(), TypeSpec()); +} + +TEST(AstTest, TypeDefaults) { + EXPECT_EQ(TypeSpec().null(), NullTypeSpec()); + EXPECT_EQ(TypeSpec().primitive(), PrimitiveType::kPrimitiveTypeUnspecified); + EXPECT_EQ(TypeSpec().wrapper(), PrimitiveType::kPrimitiveTypeUnspecified); + EXPECT_EQ(TypeSpec().well_known(), + WellKnownTypeSpec::kWellKnownTypeUnspecified); + EXPECT_EQ(TypeSpec().list_type(), ListTypeSpec()); + EXPECT_EQ(TypeSpec().map_type(), MapTypeSpec()); + EXPECT_EQ(TypeSpec().function(), FunctionTypeSpec()); + EXPECT_EQ(TypeSpec().message_type(), MessageTypeSpec()); + EXPECT_EQ(TypeSpec().type_param(), ParamTypeSpec()); + EXPECT_EQ(TypeSpec().type(), TypeSpec()); + EXPECT_EQ(TypeSpec().error_type(), ErrorTypeSpec()); + EXPECT_EQ(TypeSpec().abstract_type(), AbstractType()); +} + +TEST(AstTest, TypeComparatorTest) { + TypeSpec type; + type.set_type_kind(std::make_unique(PrimitiveType::kBool)); + + EXPECT_TRUE(type == + TypeSpec(std::make_unique(PrimitiveType::kBool))); + EXPECT_FALSE(type == TypeSpec(PrimitiveType::kBool)); + EXPECT_FALSE(type == TypeSpec(std::unique_ptr())); + EXPECT_FALSE(type == + TypeSpec(std::make_unique(PrimitiveType::kInt64))); +} + +TEST(AstTest, ExprMutableConstruction) { + Expr expr; + expr.mutable_const_expr().set_bool_value(true); + ASSERT_TRUE(expr.has_const_expr()); + EXPECT_TRUE(expr.const_expr().bool_value()); + expr.mutable_ident_expr().set_name("expr"); + ASSERT_TRUE(expr.has_ident_expr()); + EXPECT_FALSE(expr.has_const_expr()); + EXPECT_EQ(expr.ident_expr().name(), "expr"); + expr.mutable_select_expr().set_field("field"); + ASSERT_TRUE(expr.has_select_expr()); + EXPECT_FALSE(expr.has_ident_expr()); + EXPECT_EQ(expr.select_expr().field(), "field"); + expr.mutable_call_expr().set_function("function"); + ASSERT_TRUE(expr.has_call_expr()); + EXPECT_FALSE(expr.has_select_expr()); + EXPECT_EQ(expr.call_expr().function(), "function"); + expr.mutable_list_expr(); + EXPECT_TRUE(expr.has_list_expr()); + EXPECT_FALSE(expr.has_call_expr()); + expr.mutable_struct_expr().set_name("name"); + ASSERT_TRUE(expr.has_struct_expr()); + EXPECT_EQ(expr.struct_expr().name(), "name"); + EXPECT_FALSE(expr.has_list_expr()); + expr.mutable_comprehension_expr().set_accu_var("accu_var"); + ASSERT_TRUE(expr.has_comprehension_expr()); + EXPECT_FALSE(expr.has_list_expr()); + EXPECT_EQ(expr.comprehension_expr().accu_var(), "accu_var"); +} + +TEST(AstTest, ReferenceConstantDefaultValue) { + Reference reference; + EXPECT_EQ(reference.value(), Constant()); +} + +TEST(AstTest, TypeCopyable) { + TypeSpec type = TypeSpec(PrimitiveType::kBool); + TypeSpec type2 = type; + EXPECT_TRUE(type2.has_primitive()); + EXPECT_EQ(type2, type); + + type = + TypeSpec(ListTypeSpec(std::make_unique(PrimitiveType::kBool))); + type2 = type; + EXPECT_TRUE(type2.has_list_type()); + EXPECT_EQ(type2, type); + + type = + TypeSpec(MapTypeSpec(std::make_unique(PrimitiveType::kBool), + std::make_unique(PrimitiveType::kBool))); + type2 = type; + EXPECT_TRUE(type2.has_map_type()); + EXPECT_EQ(type2, type); + + type = TypeSpec( + FunctionTypeSpec(std::make_unique(PrimitiveType::kBool), {})); + type2 = type; + EXPECT_TRUE(type2.has_function()); + EXPECT_EQ(type2, type); + + type = TypeSpec(AbstractType("optional", {TypeSpec(PrimitiveType::kBool)})); + type2 = type; + EXPECT_TRUE(type2.has_abstract_type()); + EXPECT_EQ(type2, type); +} + +TEST(AstTest, TypeMoveable) { + TypeSpec type = TypeSpec(PrimitiveType::kBool); + TypeSpec type2 = type; + TypeSpec type3 = std::move(type); + EXPECT_TRUE(type2.has_primitive()); + EXPECT_EQ(type2, type3); + + type = + TypeSpec(ListTypeSpec(std::make_unique(PrimitiveType::kBool))); + type2 = type; + type3 = std::move(type); + EXPECT_TRUE(type2.has_list_type()); + EXPECT_EQ(type2, type3); + + type = + TypeSpec(MapTypeSpec(std::make_unique(PrimitiveType::kBool), + std::make_unique(PrimitiveType::kBool))); + type2 = type; + type3 = std::move(type); + EXPECT_TRUE(type2.has_map_type()); + EXPECT_EQ(type2, type3); + + type = TypeSpec( + FunctionTypeSpec(std::make_unique(PrimitiveType::kBool), {})); + type2 = type; + type3 = std::move(type); + EXPECT_TRUE(type2.has_function()); + EXPECT_EQ(type2, type3); + + type = TypeSpec(AbstractType("optional", {TypeSpec(PrimitiveType::kBool)})); + type2 = type; + type3 = std::move(type); + EXPECT_TRUE(type2.has_abstract_type()); + EXPECT_EQ(type2, type3); +} + +TEST(AstTest, NestedTypeKindCopyAssignable) { + ListTypeSpec list_type(std::make_unique(PrimitiveType::kBool)); + ListTypeSpec list_type2; + list_type2 = list_type; + + EXPECT_EQ(list_type2, list_type); + + MapTypeSpec map_type(std::make_unique(PrimitiveType::kBool), + std::make_unique(PrimitiveType::kBool)); + MapTypeSpec map_type2; + map_type2 = map_type; + + AbstractType abstract_type("abstract", {TypeSpec(PrimitiveType::kBool), + TypeSpec(PrimitiveType::kBool)}); + AbstractType abstract_type2; + abstract_type2 = abstract_type; + + EXPECT_EQ(abstract_type2, abstract_type); + + FunctionTypeSpec function_type( + std::make_unique(PrimitiveType::kBool), + {TypeSpec(PrimitiveType::kBool), TypeSpec(PrimitiveType::kBool)}); + FunctionTypeSpec function_type2; + function_type2 = function_type; + + EXPECT_EQ(function_type2, function_type); +} + +TEST(AstTest, ExtensionSupported) { + SourceInfo source_info; + + source_info.mutable_extensions().push_back( + ExtensionSpec("constant_folding", nullptr, {})); + + EXPECT_EQ(source_info.extensions()[0], + ExtensionSpec("constant_folding", nullptr, {})); +} + +TEST(AstTest, ExtensionSpecEquality) { + ExtensionSpec extension1("constant_folding", nullptr, {}); + + EXPECT_EQ(extension1, ExtensionSpec("constant_folding", nullptr, {})); + + EXPECT_NE(extension1, + ExtensionSpec("constant_folding", + std::make_unique(1, 0), {})); + EXPECT_NE(extension1, ExtensionSpec("constant_folding", nullptr, + {ExtensionSpec::Component::kRuntime})); + + EXPECT_EQ(extension1, + ExtensionSpec("constant_folding", + std::make_unique(0, 0), {})); +} + +TEST(AstTest, ExtensionCopyMove) { + ExtensionSpec a("constant_folding", nullptr, {}); + a.mutable_version().set_major(1); + a.mutable_version().set_minor(2); + a.mutable_affected_components().push_back(ExtensionSpec::Component::kRuntime); + + ExtensionSpec b(a); + + EXPECT_EQ(b.id(), "constant_folding"); + EXPECT_EQ(b.version().major(), 1); + EXPECT_EQ(b.version().minor(), 2); + EXPECT_THAT(b.affected_components(), + ElementsAre(ExtensionSpec::Component::kRuntime)); + + ExtensionSpec c(std::move(b)); + EXPECT_EQ(c, a); + + a.set_version(nullptr); + b = a; + EXPECT_EQ(b.id(), "constant_folding"); + EXPECT_EQ(b.version().major(), 0); + EXPECT_EQ(b.version().minor(), 0); + EXPECT_THAT(b.affected_components(), + ElementsAre(ExtensionSpec::Component::kRuntime)); + + c = std::move(b); + EXPECT_EQ(c, a); +} + +} // namespace +} // namespace cel diff --git a/common/ast/navigable_ast_internal.h b/common/ast/navigable_ast_internal.h new file mode 100644 index 000000000..6759212a1 --- /dev/null +++ b/common/ast/navigable_ast_internal.h @@ -0,0 +1,311 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#ifndef THIRD_PARTY_CEL_CPP_COMMON_AST_NAVIGABLE_AST_INTERNAL_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_AST_NAVIGABLE_AST_INTERNAL_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" +#include "absl/log/absl_check.h" +#include "absl/types/span.h" +#include "common/ast/navigable_ast_kinds.h" // IWYU pragma: keep + +namespace cel::common_internal { + +// Implementation for range used for traversals backed by an absl::Span. +// +// This is intended to abstract the metadata layout from clients using the +// traversal methods in navigable_expr.h +// +// RangeTraits provide type info needed to construct the span and adapt to the +// range element type. +template +class NavigableAstRange { + private: + using UnderlyingType = typename RangeTraits::UnderlyingType; + using PtrType = const UnderlyingType*; + using SpanType = absl::Span; + + public: + class Iterator { + public: + using difference_type = ptrdiff_t; + using value_type = decltype(RangeTraits::Adapt(*PtrType())); + using iterator_category = std::bidirectional_iterator_tag; + + Iterator() : ptr_(nullptr), span_() {} + Iterator(SpanType span, size_t i) : ptr_(span.data() + i), span_(span) {} + + value_type operator*() const { + ABSL_DCHECK(ptr_ != nullptr); + ABSL_DCHECK(span_.data() != nullptr); + ABSL_DCHECK_GE(ptr_, span_.data()); + ABSL_DCHECK_LT(ptr_, span_.data() + span_.size()); + return RangeTraits::Adapt(*ptr_); + } + + template + std::enable_if_t::value, + std::add_pointer_t>> + operator->() const { + return &operator*(); + } + + Iterator& operator++() { + ++ptr_; + return *this; + } + + Iterator operator++(int) { + Iterator tmp = *this; + ++ptr_; + return tmp; + } + + Iterator& operator--() { + --ptr_; + return *this; + } + + Iterator operator--(int) { + Iterator tmp = *this; + --ptr_; + return tmp; + } + + bool operator==(const Iterator& other) const { + return ptr_ == other.ptr_ && span_ == other.span_; + } + + bool operator!=(const Iterator& other) const { return !(*this == other); } + + private: + PtrType ptr_; + SpanType span_; + }; + + explicit NavigableAstRange(SpanType span) : span_(span) {} + + Iterator begin() const { return Iterator(span_, 0); } + Iterator end() const { return Iterator(span_, span_.size()); } + + explicit operator bool() const { return !span_.empty(); } + + private: + SpanType span_; +}; + +template +struct NavigableAstMetadata; + +// Internal implementation for data-structures handling cross-referencing nodes. +// +// This is exposed separately to allow building up the AST relationships +// without exposing too much mutable state on the client facing classes. +template +struct NavigableAstNodeData { + typename AstTraits::NodeType* parent; + const typename AstTraits::ExprType* expr; + ChildKind parent_relation; + NodeKind node_kind; + const NavigableAstMetadata* absl_nonnull metadata; + size_t index; + size_t tree_size; + size_t height; + int child_index; + std::vector children; +}; + +template +struct NavigableAstMetadata { + // The nodes in the AST in preorder. + // + // unique_ptr is used to guarantee pointer stability in the other tables. + std::vector> nodes; + std::vector postorder; + absl::flat_hash_map + id_to_node; + absl::flat_hash_map + expr_to_node; +}; + +template +struct PostorderTraits { + using UnderlyingType = const AstNode*; + + static const AstNode& Adapt(const AstNode* const node) { return *node; } +}; + +template +struct PreorderTraits { + using UnderlyingType = std::unique_ptr; + static const AstNode& Adapt(const std::unique_ptr& node) { + return *node; + } +}; + +// Base class for NavigableAstNode and NavigableProtoAstNode. +template +class NavigableAstNodeBase { + private: + using MetadataType = NavigableAstMetadata; + using NodeDataType = NavigableAstNodeData; + using Derived = typename AstTraits::NodeType; + using ExprType = typename AstTraits::ExprType; + + public: + using PreorderRange = NavigableAstRange>; + using PostorderRange = NavigableAstRange>; + + // The parent of this node or nullptr if it is a root. + const Derived* absl_nullable parent() const { return data_.parent; } + + const ExprType* absl_nonnull expr() const { return data_.expr; } + + // The index of this node in the parent's children. -1 if this is a root. + int child_index() const { return data_.child_index; } + + // The type of traversal from parent to this node. + ChildKind parent_relation() const { return data_.parent_relation; } + + // The type of this node, analogous to Expr::ExprKindCase. + NodeKind node_kind() const { return data_.node_kind; } + + // The number of nodes in the tree rooted at this node (including self). + size_t tree_size() const { return data_.tree_size; } + + // The height of this node in the tree (the number of descendants including + // self on the longest path). + size_t height() const { return data_.height; } + + absl::Span children() const { + return absl::MakeConstSpan(data_.children); + } + + // Range over the descendants of this node (including self) using preorder + // semantics. Each node is visited immediately before all of its descendants. + PreorderRange DescendantsPreorder() const { + return PreorderRange(absl::MakeConstSpan(data_.metadata->nodes) + .subspan(data_.index, data_.tree_size)); + } + + // Range over the descendants of this node (including self) using postorder + // semantics. Each node is visited immediately after all of its descendants. + PostorderRange DescendantsPostorder() const { + return PostorderRange(absl::MakeConstSpan(data_.metadata->postorder) + .subspan(data_.index, data_.tree_size)); + } + + private: + friend Derived; + + NavigableAstNodeBase() = default; + NavigableAstNodeBase(const NavigableAstNodeBase&) = delete; + NavigableAstNodeBase& operator=(const NavigableAstNodeBase&) = delete; + + protected: + NodeDataType data_; +}; + +// Shared implementation for NavigableAst and NavigableProtoAst. +// +// AstTraits provides type info for the derived classes that implement building +// the traversal metadata. It provides the following types: +// +// ExprType is the expression node type of the source AST. +// +// AstType is the subclass of NavigableAstBase for the implementation. +// +// NodeType is the subclass of NavigableAstNodeBase for the implementation. +template +class NavigableAstBase { + private: + using MetadataType = NavigableAstMetadata; + using Derived = typename AstTraits::AstType; + using NodeType = typename AstTraits::NodeType; + using ExprType = typename AstTraits::ExprType; + + public: + NavigableAstBase(const NavigableAstBase&) = delete; + NavigableAstBase& operator=(const NavigableAstBase&) = delete; + NavigableAstBase(NavigableAstBase&&) = default; + NavigableAstBase& operator=(NavigableAstBase&&) = default; + + // Return ptr to the AST node with id if present. Otherwise returns nullptr. + // + // If ids are non-unique, the first pre-order node encountered with id is + // returned. + const NodeType* absl_nullable FindId(int64_t id) const { + auto it = metadata_->id_to_node.find(id); + if (it == metadata_->id_to_node.end()) { + return nullptr; + } + return it->second; + } + + // Return ptr to the AST node representing the given Expr protobuf node. + const NodeType* absl_nullable FindExpr( + const ExprType* absl_nonnull expr) const { + auto it = metadata_->expr_to_node.find(expr); + if (it == metadata_->expr_to_node.end()) { + return nullptr; + } + return it->second; + } + + // The root of the AST. + const NodeType& Root() const { return *metadata_->nodes[0]; } + + // Check whether the source AST used unique IDs for each node. + // + // This is typically the case, but older versions of the parsers didn't + // guarantee uniqueness for nodes generated by some macros and ASTs modified + // outside of CEL's parse/type check may not have unique IDs. + bool IdsAreUnique() const { + return metadata_->id_to_node.size() == metadata_->nodes.size(); + } + + // Equality operators test for identity. They are intended to distinguish + // moved-from or uninitialized instances from initialized. + bool operator==(const NavigableAstBase& other) const { + return metadata_ == other.metadata_; + } + + bool operator!=(const NavigableAstBase& other) const { + return metadata_ != other.metadata_; + } + + // Return true if this instance is initialized. + explicit operator bool() const { return metadata_ != nullptr; } + + private: + friend Derived; + + NavigableAstBase() = default; + explicit NavigableAstBase(std::unique_ptr metadata) + : metadata_(std::move(metadata)) {} + + std::unique_ptr metadata_; +}; + +} // namespace cel::common_internal + +#endif // THIRD_PARTY_CEL_CPP_COMMON_AST_NAVIGABLE_AST_INTERNAL_H_ diff --git a/common/ast/navigable_ast_internal_test.cc b/common/ast/navigable_ast_internal_test.cc new file mode 100644 index 000000000..c05d5afb7 --- /dev/null +++ b/common/ast/navigable_ast_internal_test.cc @@ -0,0 +1,91 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "common/ast/navigable_ast_internal.h" + +#include +#include + +#include "absl/base/casts.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "common/ast/navigable_ast_kinds.h" +#include "internal/testing.h" + +namespace cel::common_internal { +namespace { + +struct TestRangeTraits { + using UnderlyingType = int; + static double Adapt(const UnderlyingType& value) { + return static_cast(value) + 0.5; + } +}; + +TEST(NavigableAstRangeTest, BasicIteration) { + std::vector values{1, 2, 3}; + NavigableAstRange range(absl::MakeConstSpan(values)); + absl::Span span(values); + auto it = range.begin(); + EXPECT_EQ(*it, 1.5); + EXPECT_EQ(*++it, 2.5); + EXPECT_EQ(*++it, 3.5); + EXPECT_EQ(++it, range.end()); + EXPECT_EQ(*--it, 3.5); + EXPECT_EQ(*--it, 2.5); + EXPECT_EQ(*--it, 1.5); + EXPECT_EQ(it, range.begin()); +} + +TEST(NodeKind, Stringify) { + // Note: the specific values are not important or guaranteed to be stable, + // they are only intended to make test outputs clearer. + EXPECT_EQ(absl::StrCat(NodeKind::kConstant), "Constant"); + EXPECT_EQ(absl::StrCat(NodeKind::kIdent), "Ident"); + EXPECT_EQ(absl::StrCat(NodeKind::kSelect), "Select"); + EXPECT_EQ(absl::StrCat(NodeKind::kCall), "Call"); + EXPECT_EQ(absl::StrCat(NodeKind::kList), "List"); + EXPECT_EQ(absl::StrCat(NodeKind::kMap), "Map"); + EXPECT_EQ(absl::StrCat(NodeKind::kStruct), "Struct"); + EXPECT_EQ(absl::StrCat(NodeKind::kComprehension), "Comprehension"); + EXPECT_EQ(absl::StrCat(NodeKind::kUnspecified), "Unspecified"); + + EXPECT_EQ(absl::StrCat(absl::bit_cast(255)), + "Unknown NodeKind 255"); +} + +TEST(ChildKind, Stringify) { + // Note: the specific values are not important or guaranteed to be stable, + // they are only intended to make test outputs clearer. + EXPECT_EQ(absl::StrCat(ChildKind::kSelectOperand), "SelectOperand"); + EXPECT_EQ(absl::StrCat(ChildKind::kCallReceiver), "CallReceiver"); + EXPECT_EQ(absl::StrCat(ChildKind::kCallArg), "CallArg"); + EXPECT_EQ(absl::StrCat(ChildKind::kListElem), "ListElem"); + EXPECT_EQ(absl::StrCat(ChildKind::kMapKey), "MapKey"); + EXPECT_EQ(absl::StrCat(ChildKind::kMapValue), "MapValue"); + EXPECT_EQ(absl::StrCat(ChildKind::kStructValue), "StructValue"); + EXPECT_EQ(absl::StrCat(ChildKind::kComprehensionRange), "ComprehensionRange"); + EXPECT_EQ(absl::StrCat(ChildKind::kComprehensionInit), "ComprehensionInit"); + EXPECT_EQ(absl::StrCat(ChildKind::kComprehensionCondition), + "ComprehensionCondition"); + EXPECT_EQ(absl::StrCat(ChildKind::kComprehensionLoopStep), + "ComprehensionLoopStep"); + EXPECT_EQ(absl::StrCat(ChildKind::kComprensionResult), "ComprehensionResult"); + EXPECT_EQ(absl::StrCat(ChildKind::kUnspecified), "Unspecified"); + + EXPECT_EQ(absl::StrCat(absl::bit_cast(255)), + "Unknown ChildKind 255"); +} + +} // namespace +} // namespace cel::common_internal diff --git a/common/ast/navigable_ast_kinds.cc b/common/ast/navigable_ast_kinds.cc new file mode 100644 index 000000000..4ef2da731 --- /dev/null +++ b/common/ast/navigable_ast_kinds.cc @@ -0,0 +1,80 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "common/ast/navigable_ast_kinds.h" + +#include + +#include "absl/strings/str_cat.h" + +namespace cel { + +std::string ChildKindName(ChildKind kind) { + switch (kind) { + case ChildKind::kUnspecified: + return "Unspecified"; + case ChildKind::kSelectOperand: + return "SelectOperand"; + case ChildKind::kCallReceiver: + return "CallReceiver"; + case ChildKind::kCallArg: + return "CallArg"; + case ChildKind::kListElem: + return "ListElem"; + case ChildKind::kMapKey: + return "MapKey"; + case ChildKind::kMapValue: + return "MapValue"; + case ChildKind::kStructValue: + return "StructValue"; + case ChildKind::kComprehensionRange: + return "ComprehensionRange"; + case ChildKind::kComprehensionInit: + return "ComprehensionInit"; + case ChildKind::kComprehensionCondition: + return "ComprehensionCondition"; + case ChildKind::kComprehensionLoopStep: + return "ComprehensionLoopStep"; + case ChildKind::kComprensionResult: + return "ComprehensionResult"; + default: + return absl::StrCat("Unknown ChildKind ", static_cast(kind)); + } +} + +std::string NodeKindName(NodeKind kind) { + switch (kind) { + case NodeKind::kUnspecified: + return "Unspecified"; + case NodeKind::kConstant: + return "Constant"; + case NodeKind::kIdent: + return "Ident"; + case NodeKind::kSelect: + return "Select"; + case NodeKind::kCall: + return "Call"; + case NodeKind::kList: + return "List"; + case NodeKind::kMap: + return "Map"; + case NodeKind::kStruct: + return "Struct"; + case NodeKind::kComprehension: + return "Comprehension"; + default: + return absl::StrCat("Unknown NodeKind ", static_cast(kind)); + } +} + +} // namespace cel diff --git a/common/ast/navigable_ast_kinds.h b/common/ast/navigable_ast_kinds.h new file mode 100644 index 000000000..ac8c2d4be --- /dev/null +++ b/common/ast/navigable_ast_kinds.h @@ -0,0 +1,74 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// IWYU pragma: private +#ifndef THIRD_PARTY_CEL_CPP_COMMON_AST_NAVIGABLE_AST_KINDS_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_AST_NAVIGABLE_AST_KINDS_H_ + +#include + +#include "absl/strings/str_format.h" + +namespace cel { + +// The traversal relationship from parent to the given node in a NavigableAst. +enum class ChildKind { + kUnspecified, + kSelectOperand, + kCallReceiver, + kCallArg, + kListElem, + kMapKey, + kMapValue, + kStructValue, + kComprehensionRange, + kComprehensionInit, + kComprehensionCondition, + kComprehensionLoopStep, + kComprensionResult +}; + +// The type of the node in a NavigableAst. +enum class NodeKind { + kUnspecified, + kConstant, + kIdent, + kSelect, + kCall, + kList, + kMap, + kStruct, + kComprehension, +}; + +// Human readable ChildKind name. Provided for test readability -- do not depend +// on the specific values. +std::string ChildKindName(ChildKind kind); + +template +void AbslStringify(Sink& sink, ChildKind kind) { + absl::Format(&sink, "%s", ChildKindName(kind)); +} + +// Human readable NodeKind name. Provided for test readability -- do not depend +// on the specific values. +std::string NodeKindName(NodeKind kind); + +template +void AbslStringify(Sink& sink, NodeKind kind) { + absl::Format(&sink, "%s", NodeKindName(kind)); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_AST_NAVIGABLE_AST_KINDS_H_ diff --git a/common/ast/source_info_proto.cc b/common/ast/source_info_proto.cc new file mode 100644 index 000000000..ae1803fbb --- /dev/null +++ b/common/ast/source_info_proto.cc @@ -0,0 +1,90 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/ast/source_info_proto.h" + +#include +#include + +#include "cel/expr/checked.pb.h" +#include "cel/expr/syntax.pb.h" +#include "google/protobuf/duration.pb.h" +#include "google/protobuf/struct.pb.h" +#include "google/protobuf/timestamp.pb.h" +#include "absl/status/status.h" +#include "common/ast.h" +#include "common/ast/expr_proto.h" +#include "internal/status_macros.h" + +namespace cel::ast_internal { + +using ::cel::ast_internal::ExprToProto; + +using ExprPb = cel::expr::Expr; +using ParsedExprPb = cel::expr::ParsedExpr; +using CheckedExprPb = cel::expr::CheckedExpr; +using ExtensionPb = cel::expr::SourceInfo::Extension; + +absl::Status SourceInfoToProto(const cel::SourceInfo& source_info, + cel::expr::SourceInfo* out) { + cel::expr::SourceInfo& result = *out; + result.set_syntax_version(source_info.syntax_version()); + result.set_location(source_info.location()); + + for (int32_t line_offset : source_info.line_offsets()) { + result.add_line_offsets(line_offset); + } + + for (auto pos_iter = source_info.positions().begin(); + pos_iter != source_info.positions().end(); ++pos_iter) { + (*result.mutable_positions())[pos_iter->first] = pos_iter->second; + } + + for (auto macro_iter = source_info.macro_calls().begin(); + macro_iter != source_info.macro_calls().end(); ++macro_iter) { + ExprPb& dest_macro = (*result.mutable_macro_calls())[macro_iter->first]; + CEL_RETURN_IF_ERROR(ExprToProto(macro_iter->second, &dest_macro)); + } + + for (const auto& extension : source_info.extensions()) { + auto* extension_pb = result.add_extensions(); + extension_pb->set_id(extension.id()); + auto* version_pb = extension_pb->mutable_version(); + version_pb->set_major(extension.version().major()); + version_pb->set_minor(extension.version().minor()); + + for (auto component : extension.affected_components()) { + switch (component) { + case cel::ExtensionSpec::Component::kParser: + extension_pb->add_affected_components(ExtensionPb::COMPONENT_PARSER); + break; + case cel::ExtensionSpec::Component::kTypeChecker: + extension_pb->add_affected_components( + ExtensionPb::COMPONENT_TYPE_CHECKER); + break; + case cel::ExtensionSpec::Component::kRuntime: + extension_pb->add_affected_components(ExtensionPb::COMPONENT_RUNTIME); + break; + default: + extension_pb->add_affected_components( + ExtensionPb::COMPONENT_UNSPECIFIED); + break; + } + } + } + + return absl::OkStatus(); +} + +} // namespace cel::ast_internal diff --git a/common/ast/source_info_proto.h b/common/ast/source_info_proto.h new file mode 100644 index 000000000..c44bb2a73 --- /dev/null +++ b/common/ast/source_info_proto.h @@ -0,0 +1,32 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_AST_SOURCE_INFO_PROTO_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_AST_SOURCE_INFO_PROTO_H_ + +#include "cel/expr/syntax.pb.h" +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "common/ast.h" + +namespace cel::ast_internal { + +// Conversion utility for the CEL-C++ source info representation to the protobuf +// representation. +absl::Status SourceInfoToProto(const SourceInfo& source_info, + cel::expr::SourceInfo* absl_nonnull out); + +} // namespace cel::ast_internal + +#endif // THIRD_PARTY_CEL_CPP_COMMON_AST_SOURCE_INFO_PROTO_H_ diff --git a/common/ast_proto.cc b/common/ast_proto.cc new file mode 100644 index 000000000..ee990f0e5 --- /dev/null +++ b/common/ast_proto.cc @@ -0,0 +1,547 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/ast_proto.h" + +#include +#include +#include +#include +#include + +#include "cel/expr/checked.pb.h" +#include "cel/expr/syntax.pb.h" +#include "google/protobuf/duration.pb.h" +#include "google/protobuf/struct.pb.h" +#include "google/protobuf/timestamp.pb.h" +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/variant.h" +#include "common/ast.h" +#include "common/ast/constant_proto.h" +#include "common/ast/expr_proto.h" +#include "common/ast/source_info_proto.h" +#include "common/constant.h" +#include "common/expr.h" +#include "internal/status_macros.h" + +namespace cel { +namespace { + +using ::cel::ast_internal::ConstantFromProto; +using ::cel::ast_internal::ConstantToProto; +using ::cel::ast_internal::ExprFromProto; +using ::cel::ast_internal::ExprToProto; + +using ExprPb = cel::expr::Expr; +using ParsedExprPb = cel::expr::ParsedExpr; +using CheckedExprPb = cel::expr::CheckedExpr; +using SourceInfoPb = cel::expr::SourceInfo; +using ExtensionPb = cel::expr::SourceInfo::Extension; +using ReferencePb = cel::expr::Reference; +using TypePb = cel::expr::Type; +using ExtensionPb = cel::expr::SourceInfo::Extension; + +absl::StatusOr ExprValueFromProto(const ExprPb& expr) { + Expr result; + CEL_RETURN_IF_ERROR(ExprFromProto(expr, result)); + return result; +} + +absl::StatusOr ConvertProtoSourceInfoToNative( + const cel::expr::SourceInfo& source_info) { + absl::flat_hash_map macro_calls; + for (const auto& pair : source_info.macro_calls()) { + auto native_expr = ExprValueFromProto(pair.second); + if (!native_expr.ok()) { + return native_expr.status(); + } + macro_calls.emplace(pair.first, *(std::move(native_expr))); + } + std::vector extensions; + extensions.reserve(source_info.extensions_size()); + for (const auto& extension : source_info.extensions()) { + std::vector components; + components.reserve(extension.affected_components().size()); + for (const auto& component : extension.affected_components()) { + switch (component) { + case ExtensionPb::COMPONENT_PARSER: + components.push_back(ExtensionSpec::Component::kParser); + break; + case ExtensionPb::COMPONENT_TYPE_CHECKER: + components.push_back(ExtensionSpec::Component::kTypeChecker); + break; + case ExtensionPb::COMPONENT_RUNTIME: + components.push_back(ExtensionSpec::Component::kRuntime); + break; + default: + components.push_back(ExtensionSpec::Component::kUnspecified); + break; + } + } + extensions.push_back(ExtensionSpec( + extension.id(), + std::make_unique(extension.version().major(), + extension.version().minor()), + std::move(components))); + } + return SourceInfo( + source_info.syntax_version(), source_info.location(), + std::vector(source_info.line_offsets().begin(), + source_info.line_offsets().end()), + absl::flat_hash_map(source_info.positions().begin(), + source_info.positions().end()), + std::move(macro_calls), std::move(extensions)); +} + +absl::StatusOr ConvertProtoTypeToNative( + const cel::expr::Type& type); + +absl::StatusOr ToNative( + cel::expr::Type::PrimitiveType primitive_type) { + switch (primitive_type) { + case cel::expr::Type::PRIMITIVE_TYPE_UNSPECIFIED: + return PrimitiveType::kPrimitiveTypeUnspecified; + case cel::expr::Type::BOOL: + return PrimitiveType::kBool; + case cel::expr::Type::INT64: + return PrimitiveType::kInt64; + case cel::expr::Type::UINT64: + return PrimitiveType::kUint64; + case cel::expr::Type::DOUBLE: + return PrimitiveType::kDouble; + case cel::expr::Type::STRING: + return PrimitiveType::kString; + case cel::expr::Type::BYTES: + return PrimitiveType::kBytes; + default: + return absl::InvalidArgumentError( + "Illegal type specified for " + "cel::expr::Type::PrimitiveType."); + } +} + +absl::StatusOr ToNative( + cel::expr::Type::WellKnownType well_known_type) { + switch (well_known_type) { + case cel::expr::Type::WELL_KNOWN_TYPE_UNSPECIFIED: + return WellKnownTypeSpec::kWellKnownTypeUnspecified; + case cel::expr::Type::ANY: + return WellKnownTypeSpec::kAny; + case cel::expr::Type::TIMESTAMP: + return WellKnownTypeSpec::kTimestamp; + case cel::expr::Type::DURATION: + return WellKnownTypeSpec::kDuration; + default: + return absl::InvalidArgumentError( + "Illegal type specified for " + "cel::expr::Type::WellKnownType."); + } +} + +absl::StatusOr ToNative( + const cel::expr::Type::ListType& list_type) { + auto native_elem_type = ConvertProtoTypeToNative(list_type.elem_type()); + if (!native_elem_type.ok()) { + return native_elem_type.status(); + } + return ListTypeSpec( + std::make_unique(*(std::move(native_elem_type)))); +} + +absl::StatusOr ToNative( + const cel::expr::Type::MapType& map_type) { + auto native_key_type = ConvertProtoTypeToNative(map_type.key_type()); + if (!native_key_type.ok()) { + return native_key_type.status(); + } + auto native_value_type = ConvertProtoTypeToNative(map_type.value_type()); + if (!native_value_type.ok()) { + return native_value_type.status(); + } + return MapTypeSpec( + std::make_unique(*(std::move(native_key_type))), + std::make_unique(*(std::move(native_value_type)))); +} + +absl::StatusOr ToNative( + const cel::expr::Type::FunctionType& function_type) { + std::vector arg_types; + arg_types.reserve(function_type.arg_types_size()); + for (const auto& arg_type : function_type.arg_types()) { + auto native_arg = ConvertProtoTypeToNative(arg_type); + if (!native_arg.ok()) { + return native_arg.status(); + } + arg_types.emplace_back(*(std::move(native_arg))); + } + auto native_result = ConvertProtoTypeToNative(function_type.result_type()); + if (!native_result.ok()) { + return native_result.status(); + } + return FunctionTypeSpec( + std::make_unique(*(std::move(native_result))), + std::move(arg_types)); +} + +absl::StatusOr ToNative( + const cel::expr::Type::AbstractType& abstract_type) { + std::vector parameter_types; + for (const auto& parameter_type : abstract_type.parameter_types()) { + auto native_parameter_type = ConvertProtoTypeToNative(parameter_type); + if (!native_parameter_type.ok()) { + return native_parameter_type.status(); + } + parameter_types.emplace_back(*(std::move(native_parameter_type))); + } + return AbstractType(abstract_type.name(), std::move(parameter_types)); +} + +absl::StatusOr ConvertProtoTypeToNative( + const cel::expr::Type& type) { + switch (type.type_kind_case()) { + case cel::expr::Type::kDyn: + return TypeSpec(DynTypeSpec()); + case cel::expr::Type::kNull: + return TypeSpec(NullTypeSpec()); + case cel::expr::Type::kPrimitive: { + auto native_primitive = ToNative(type.primitive()); + if (!native_primitive.ok()) { + return native_primitive.status(); + } + return TypeSpec(*(std::move(native_primitive))); + } + case cel::expr::Type::kWrapper: { + auto native_wrapper = ToNative(type.wrapper()); + if (!native_wrapper.ok()) { + return native_wrapper.status(); + } + return TypeSpec(PrimitiveTypeWrapper(*(std::move(native_wrapper)))); + } + case cel::expr::Type::kWellKnown: { + auto native_well_known = ToNative(type.well_known()); + if (!native_well_known.ok()) { + return native_well_known.status(); + } + return TypeSpec(*std::move(native_well_known)); + } + case cel::expr::Type::kListType: { + auto native_list_type = ToNative(type.list_type()); + if (!native_list_type.ok()) { + return native_list_type.status(); + } + return TypeSpec(*(std::move(native_list_type))); + } + case cel::expr::Type::kMapType: { + auto native_map_type = ToNative(type.map_type()); + if (!native_map_type.ok()) { + return native_map_type.status(); + } + return TypeSpec(*(std::move(native_map_type))); + } + case cel::expr::Type::kFunction: { + auto native_function = ToNative(type.function()); + if (!native_function.ok()) { + return native_function.status(); + } + return TypeSpec(*(std::move(native_function))); + } + case cel::expr::Type::kMessageType: + return TypeSpec(MessageTypeSpec(type.message_type())); + case cel::expr::Type::kTypeParam: + return TypeSpec(ParamTypeSpec(type.type_param())); + case cel::expr::Type::kType: { + if (type.type().type_kind_case() == + cel::expr::Type::TypeKindCase::TYPE_KIND_NOT_SET) { + return TypeSpec(std::unique_ptr()); + } + auto native_type = ConvertProtoTypeToNative(type.type()); + if (!native_type.ok()) { + return native_type.status(); + } + return TypeSpec(std::make_unique(*std::move(native_type))); + } + case cel::expr::Type::kError: + return TypeSpec(ErrorTypeSpec::kValue); + case cel::expr::Type::kAbstractType: { + auto native_abstract = ToNative(type.abstract_type()); + if (!native_abstract.ok()) { + return native_abstract.status(); + } + return TypeSpec(*(std::move(native_abstract))); + } + case cel::expr::Type::TYPE_KIND_NOT_SET: + return TypeSpec(UnsetTypeSpec()); + default: + return absl::InvalidArgumentError( + "Illegal type specified for cel::expr::Type."); + } +} + +absl::StatusOr ConvertProtoReferenceToNative( + const cel::expr::Reference& reference) { + Reference ret_val; + ret_val.set_name(reference.name()); + ret_val.mutable_overload_id().reserve(reference.overload_id_size()); + for (const auto& elem : reference.overload_id()) { + ret_val.mutable_overload_id().emplace_back(elem); + } + if (reference.has_value()) { + CEL_RETURN_IF_ERROR( + ConstantFromProto(reference.value(), ret_val.mutable_value())); + } + return ret_val; +} + +absl::StatusOr ReferenceToProto(const Reference& reference) { + ReferencePb result; + + result.set_name(reference.name()); + + for (const auto& overload_id : reference.overload_id()) { + result.add_overload_id(overload_id); + } + + if (reference.has_value()) { + CEL_RETURN_IF_ERROR( + ConstantToProto(reference.value(), result.mutable_value())); + } + + return result; +} + +absl::Status TypeToProto(const TypeSpec& type, TypePb* result); + +struct TypeKindToProtoVisitor { + absl::Status operator()(PrimitiveType primitive) { + switch (primitive) { + case PrimitiveType::kPrimitiveTypeUnspecified: + result->set_primitive(TypePb::PRIMITIVE_TYPE_UNSPECIFIED); + return absl::OkStatus(); + case PrimitiveType::kBool: + result->set_primitive(TypePb::BOOL); + return absl::OkStatus(); + case PrimitiveType::kInt64: + result->set_primitive(TypePb::INT64); + return absl::OkStatus(); + case PrimitiveType::kUint64: + result->set_primitive(TypePb::UINT64); + return absl::OkStatus(); + case PrimitiveType::kDouble: + result->set_primitive(TypePb::DOUBLE); + return absl::OkStatus(); + case PrimitiveType::kString: + result->set_primitive(TypePb::STRING); + return absl::OkStatus(); + case PrimitiveType::kBytes: + result->set_primitive(TypePb::BYTES); + return absl::OkStatus(); + default: + break; + } + return absl::InvalidArgumentError("Unsupported primitive type"); + } + + absl::Status operator()(PrimitiveTypeWrapper wrapper) { + CEL_RETURN_IF_ERROR(this->operator()(wrapper.type())); + auto wrapped = result->primitive(); + result->set_wrapper(wrapped); + return absl::OkStatus(); + } + + absl::Status operator()(UnsetTypeSpec) { + result->clear_type_kind(); + return absl::OkStatus(); + } + + absl::Status operator()(DynTypeSpec) { + result->mutable_dyn(); + return absl::OkStatus(); + } + + absl::Status operator()(ErrorTypeSpec) { + result->mutable_error(); + return absl::OkStatus(); + } + + absl::Status operator()(NullTypeSpec) { + result->set_null(google::protobuf::NULL_VALUE); + return absl::OkStatus(); + } + + absl::Status operator()(const ListTypeSpec& list_type) { + return TypeToProto(list_type.elem_type(), + result->mutable_list_type()->mutable_elem_type()); + } + + absl::Status operator()(const MapTypeSpec& map_type) { + CEL_RETURN_IF_ERROR(TypeToProto( + map_type.key_type(), result->mutable_map_type()->mutable_key_type())); + return TypeToProto(map_type.value_type(), + result->mutable_map_type()->mutable_value_type()); + } + + absl::Status operator()(const MessageTypeSpec& message_type) { + result->set_message_type(message_type.type()); + return absl::OkStatus(); + } + + absl::Status operator()(const WellKnownTypeSpec& well_known_type) { + switch (well_known_type) { + case WellKnownTypeSpec::kWellKnownTypeUnspecified: + result->set_well_known(TypePb::WELL_KNOWN_TYPE_UNSPECIFIED); + return absl::OkStatus(); + case WellKnownTypeSpec::kAny: + result->set_well_known(TypePb::ANY); + return absl::OkStatus(); + + case WellKnownTypeSpec::kDuration: + result->set_well_known(TypePb::DURATION); + return absl::OkStatus(); + case WellKnownTypeSpec::kTimestamp: + result->set_well_known(TypePb::TIMESTAMP); + return absl::OkStatus(); + default: + break; + } + return absl::InvalidArgumentError("Unsupported well-known type"); + } + + absl::Status operator()(const FunctionTypeSpec& function_type) { + CEL_RETURN_IF_ERROR( + TypeToProto(function_type.result_type(), + result->mutable_function()->mutable_result_type())); + + for (const TypeSpec& arg_type : function_type.arg_types()) { + CEL_RETURN_IF_ERROR( + TypeToProto(arg_type, result->mutable_function()->add_arg_types())); + } + return absl::OkStatus(); + } + + absl::Status operator()(const AbstractType& type) { + auto* abstract_type_pb = result->mutable_abstract_type(); + abstract_type_pb->set_name(type.name()); + for (const TypeSpec& type_param : type.parameter_types()) { + CEL_RETURN_IF_ERROR( + TypeToProto(type_param, abstract_type_pb->add_parameter_types())); + } + return absl::OkStatus(); + } + + absl::Status operator()(const std::unique_ptr& type_type) { + return TypeToProto((type_type != nullptr) ? *type_type : TypeSpec(), + result->mutable_type()); + } + + absl::Status operator()(const ParamTypeSpec& param_type) { + result->set_type_param(param_type.type()); + return absl::OkStatus(); + } + + TypePb* result; +}; + +absl::Status TypeToProto(const TypeSpec& type, TypePb* result) { + return absl::visit(TypeKindToProtoVisitor{result}, type.type_kind()); +} + +} // namespace + +absl::StatusOr> CreateAstFromParsedExpr( + const cel::expr::Expr& expr, + const cel::expr::SourceInfo* source_info) { + CEL_ASSIGN_OR_RETURN(auto runtime_expr, ExprValueFromProto(expr)); + SourceInfo runtime_source_info; + if (source_info != nullptr) { + CEL_ASSIGN_OR_RETURN(runtime_source_info, + ConvertProtoSourceInfoToNative(*source_info)); + } + return std::make_unique(std::move(runtime_expr), + std::move(runtime_source_info)); +} + +absl::StatusOr> CreateAstFromParsedExpr( + const ParsedExprPb& parsed_expr) { + return CreateAstFromParsedExpr(parsed_expr.expr(), + &parsed_expr.source_info()); +} + +absl::Status AstToParsedExpr(const Ast& ast, + cel::expr::ParsedExpr* absl_nonnull out) { + ParsedExprPb& parsed_expr = *out; + CEL_RETURN_IF_ERROR(ExprToProto(ast.root_expr(), parsed_expr.mutable_expr())); + CEL_RETURN_IF_ERROR(ast_internal::SourceInfoToProto( + ast.source_info(), parsed_expr.mutable_source_info())); + + return absl::OkStatus(); +} + +absl::StatusOr> CreateAstFromCheckedExpr( + const CheckedExprPb& checked_expr) { + CEL_ASSIGN_OR_RETURN(Expr expr, ExprValueFromProto(checked_expr.expr())); + CEL_ASSIGN_OR_RETURN(SourceInfo source_info, ConvertProtoSourceInfoToNative( + checked_expr.source_info())); + + Ast::ReferenceMap reference_map; + for (const auto& pair : checked_expr.reference_map()) { + auto native_reference = ConvertProtoReferenceToNative(pair.second); + if (!native_reference.ok()) { + return native_reference.status(); + } + reference_map.emplace(pair.first, *(std::move(native_reference))); + } + Ast::TypeMap type_map; + for (const auto& pair : checked_expr.type_map()) { + auto native_type = ConvertProtoTypeToNative(pair.second); + if (!native_type.ok()) { + return native_type.status(); + } + type_map.emplace(pair.first, *(std::move(native_type))); + } + + return std::make_unique(std::move(expr), std::move(source_info), + std::move(reference_map), std::move(type_map), + checked_expr.expr_version()); +} + +absl::Status AstToCheckedExpr( + const Ast& ast, cel::expr::CheckedExpr* absl_nonnull out) { + if (!ast.is_checked()) { + return absl::InvalidArgumentError("AST is not type-checked"); + } + CheckedExprPb& checked_expr = *out; + checked_expr.set_expr_version(ast.expr_version()); + CEL_RETURN_IF_ERROR( + ExprToProto(ast.root_expr(), checked_expr.mutable_expr())); + CEL_RETURN_IF_ERROR(ast_internal::SourceInfoToProto( + ast.source_info(), checked_expr.mutable_source_info())); + for (auto it = ast.reference_map().begin(); it != ast.reference_map().end(); + ++it) { + ReferencePb& dest_reference = + (*checked_expr.mutable_reference_map())[it->first]; + CEL_ASSIGN_OR_RETURN(dest_reference, ReferenceToProto(it->second)); + } + + for (auto it = ast.type_map().begin(); it != ast.type_map().end(); ++it) { + TypePb& dest_type = (*checked_expr.mutable_type_map())[it->first]; + CEL_RETURN_IF_ERROR(TypeToProto(it->second, &dest_type)); + } + + return absl::OkStatus(); +} + +} // namespace cel diff --git a/common/ast_proto.h b/common/ast_proto.h new file mode 100644 index 000000000..e8dce81c3 --- /dev/null +++ b/common/ast_proto.h @@ -0,0 +1,52 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_AST_PROTO_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_AST_PROTO_H_ + +#include + +#include "cel/expr/checked.pb.h" +#include "cel/expr/syntax.pb.h" +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "base/ast.h" + +namespace cel { + +// Creates a runtime AST from a parsed-only protobuf AST. +// May return a non-ok Status if the AST is malformed (e.g. unset required +// fields). +absl::StatusOr> CreateAstFromParsedExpr( + const cel::expr::Expr& expr, + const cel::expr::SourceInfo* source_info = nullptr); +absl::StatusOr> CreateAstFromParsedExpr( + const cel::expr::ParsedExpr& parsed_expr); + +absl::Status AstToParsedExpr(const Ast& ast, + cel::expr::ParsedExpr* absl_nonnull out); + +// Creates a runtime AST from a checked protobuf AST. +// May return a non-ok Status if the AST is malformed (e.g. unset required +// fields). +absl::StatusOr> CreateAstFromCheckedExpr( + const cel::expr::CheckedExpr& checked_expr); + +absl::Status AstToCheckedExpr(const Ast& ast, + cel::expr::CheckedExpr* absl_nonnull out); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_AST_PROTO_H_ diff --git a/common/ast_proto_test.cc b/common/ast_proto_test.cc new file mode 100644 index 000000000..ddaa4191a --- /dev/null +++ b/common/ast_proto_test.cc @@ -0,0 +1,959 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "common/ast_proto.h" + +#include +#include +#include +#include +#include + +#include "cel/expr/checked.pb.h" +#include "cel/expr/syntax.pb.h" +#include "google/protobuf/duration.pb.h" +#include "google/protobuf/struct.pb.h" +#include "google/protobuf/timestamp.pb.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/variant.h" +#include "common/ast.h" +#include "common/decl.h" +#include "common/expr.h" +#include "common/source.h" +#include "common/type.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "compiler/optional.h" +#include "compiler/standard_library.h" +#include "extensions/comprehensions_v2.h" +#include "internal/proto_matchers.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "google/protobuf/text_format.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::cel::PrimitiveType; +using ::cel::WellKnownTypeSpec; +using ::cel::internal::test::EqualsProto; +using ::cel::expr::CheckedExpr; +using ::cel::expr::ParsedExpr; +using ::testing::HasSubstr; + +using TypePb = cel::expr::Type; + +absl::StatusOr ConvertProtoTypeToNative( + const cel::expr::Type& type) { + CheckedExpr checked_expr; + checked_expr.mutable_expr()->mutable_ident_expr()->set_name("foo"); + + (*checked_expr.mutable_type_map())[1] = type; + + CEL_ASSIGN_OR_RETURN(auto ast, CreateAstFromCheckedExpr(checked_expr)); + + const auto& type_map = ast->type_map(); + auto iter = type_map.find(1); + if (iter != type_map.end()) { + return iter->second; + } + return absl::InternalError("conversion failed but reported success"); +} + +TEST(AstConvertersTest, PrimitiveTypeUnspecifiedToNative) { + cel::expr::Type type; + type.set_primitive(cel::expr::Type::PRIMITIVE_TYPE_UNSPECIFIED); + + auto native_type = ConvertProtoTypeToNative(type); + + ASSERT_TRUE(native_type->has_primitive()); + EXPECT_EQ(native_type->primitive(), PrimitiveType::kPrimitiveTypeUnspecified); +} + +TEST(AstConvertersTest, PrimitiveTypeBoolToNative) { + cel::expr::Type type; + type.set_primitive(cel::expr::Type::BOOL); + + auto native_type = ConvertProtoTypeToNative(type); + + ASSERT_TRUE(native_type->has_primitive()); + EXPECT_EQ(native_type->primitive(), PrimitiveType::kBool); +} + +TEST(AstConvertersTest, PrimitiveTypeInt64ToNative) { + cel::expr::Type type; + type.set_primitive(cel::expr::Type::INT64); + + auto native_type = ConvertProtoTypeToNative(type); + + ASSERT_TRUE(native_type->has_primitive()); + EXPECT_EQ(native_type->primitive(), PrimitiveType::kInt64); +} + +TEST(AstConvertersTest, PrimitiveTypeUint64ToNative) { + cel::expr::Type type; + type.set_primitive(cel::expr::Type::UINT64); + + auto native_type = ConvertProtoTypeToNative(type); + + ASSERT_TRUE(native_type->has_primitive()); + EXPECT_EQ(native_type->primitive(), PrimitiveType::kUint64); +} + +TEST(AstConvertersTest, PrimitiveTypeDoubleToNative) { + cel::expr::Type type; + type.set_primitive(cel::expr::Type::DOUBLE); + + auto native_type = ConvertProtoTypeToNative(type); + + ASSERT_TRUE(native_type->has_primitive()); + EXPECT_EQ(native_type->primitive(), PrimitiveType::kDouble); +} + +TEST(AstConvertersTest, PrimitiveTypeStringToNative) { + cel::expr::Type type; + type.set_primitive(cel::expr::Type::STRING); + + auto native_type = ConvertProtoTypeToNative(type); + + ASSERT_TRUE(native_type->has_primitive()); + EXPECT_EQ(native_type->primitive(), PrimitiveType::kString); +} + +TEST(AstConvertersTest, PrimitiveTypeBytesToNative) { + cel::expr::Type type; + type.set_primitive(cel::expr::Type::BYTES); + + auto native_type = ConvertProtoTypeToNative(type); + + ASSERT_TRUE(native_type->has_primitive()); + EXPECT_EQ(native_type->primitive(), PrimitiveType::kBytes); +} + +TEST(AstConvertersTest, PrimitiveTypeError) { + cel::expr::Type type; + type.set_primitive(::cel::expr::Type_PrimitiveType(7)); + + auto native_type = ConvertProtoTypeToNative(type); + + EXPECT_EQ(native_type.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(native_type.status().message(), + ::testing::HasSubstr("Illegal type specified for " + "cel::expr::Type::PrimitiveType.")); +} + +TEST(AstConvertersTest, WellKnownTypeUnspecifiedToNative) { + cel::expr::Type type; + type.set_well_known(cel::expr::Type::WELL_KNOWN_TYPE_UNSPECIFIED); + + auto native_type = ConvertProtoTypeToNative(type); + + ASSERT_TRUE(native_type->has_well_known()); + EXPECT_EQ(native_type->well_known(), + WellKnownTypeSpec::kWellKnownTypeUnspecified); +} + +TEST(AstConvertersTest, WellKnownTypeAnyToNative) { + cel::expr::Type type; + type.set_well_known(cel::expr::Type::ANY); + + auto native_type = ConvertProtoTypeToNative(type); + + ASSERT_TRUE(native_type->has_well_known()); + EXPECT_EQ(native_type->well_known(), WellKnownTypeSpec::kAny); +} + +TEST(AstConvertersTest, WellKnownTypeTimestampToNative) { + cel::expr::Type type; + type.set_well_known(cel::expr::Type::TIMESTAMP); + + auto native_type = ConvertProtoTypeToNative(type); + + ASSERT_TRUE(native_type->has_well_known()); + EXPECT_EQ(native_type->well_known(), WellKnownTypeSpec::kTimestamp); +} + +TEST(AstConvertersTest, WellKnownTypeDuraionToNative) { + cel::expr::Type type; + type.set_well_known(cel::expr::Type::DURATION); + + auto native_type = ConvertProtoTypeToNative(type); + + ASSERT_TRUE(native_type->has_well_known()); + EXPECT_EQ(native_type->well_known(), WellKnownTypeSpec::kDuration); +} + +TEST(AstConvertersTest, WellKnownTypeError) { + cel::expr::Type type; + type.set_well_known(::cel::expr::Type_WellKnownType(4)); + + auto native_type = ConvertProtoTypeToNative(type); + + EXPECT_EQ(native_type.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(native_type.status().message(), + ::testing::HasSubstr("Illegal type specified for " + "cel::expr::Type::WellKnownType.")); +} + +TEST(AstConvertersTest, ListTypeToNative) { + cel::expr::Type type; + type.mutable_list_type()->mutable_elem_type()->set_primitive( + cel::expr::Type::BOOL); + + auto native_type = ConvertProtoTypeToNative(type); + + ASSERT_TRUE(native_type->has_list_type()); + auto& native_list_type = native_type->list_type(); + ASSERT_TRUE(native_list_type.elem_type().has_primitive()); + EXPECT_EQ(native_list_type.elem_type().primitive(), PrimitiveType::kBool); +} + +TEST(AstConvertersTest, MapTypeToNative) { + cel::expr::Type type; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb( + map_type { + key_type { primitive: BOOL } + value_type { primitive: DOUBLE } + } + )pb", + &type)); + + auto native_type = ConvertProtoTypeToNative(type); + + ASSERT_TRUE(native_type->has_map_type()); + auto& native_map_type = native_type->map_type(); + ASSERT_TRUE(native_map_type.key_type().has_primitive()); + EXPECT_EQ(native_map_type.key_type().primitive(), PrimitiveType::kBool); + ASSERT_TRUE(native_map_type.value_type().has_primitive()); + EXPECT_EQ(native_map_type.value_type().primitive(), PrimitiveType::kDouble); +} + +TEST(AstConvertersTest, FunctionTypeToNative) { + cel::expr::Type type; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb( + function { + result_type { primitive: BOOL } + arg_types { primitive: DOUBLE } + arg_types { primitive: STRING } + } + )pb", + &type)); + + auto native_type = ConvertProtoTypeToNative(type); + + ASSERT_TRUE(native_type->has_function()); + auto& native_function_type = native_type->function(); + ASSERT_TRUE(native_function_type.result_type().has_primitive()); + EXPECT_EQ(native_function_type.result_type().primitive(), + PrimitiveType::kBool); + ASSERT_TRUE(native_function_type.arg_types().at(0).has_primitive()); + EXPECT_EQ(native_function_type.arg_types().at(0).primitive(), + PrimitiveType::kDouble); + ASSERT_TRUE(native_function_type.arg_types().at(1).has_primitive()); + EXPECT_EQ(native_function_type.arg_types().at(1).primitive(), + PrimitiveType::kString); +} + +TEST(AstConvertersTest, AbstractTypeToNative) { + cel::expr::Type type; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb( + abstract_type { + name: "name" + parameter_types { primitive: DOUBLE } + parameter_types { primitive: STRING } + } + )pb", + &type)); + + auto native_type = ConvertProtoTypeToNative(type); + + ASSERT_TRUE(native_type->has_abstract_type()); + auto& native_abstract_type = native_type->abstract_type(); + EXPECT_EQ(native_abstract_type.name(), "name"); + ASSERT_TRUE(native_abstract_type.parameter_types().at(0).has_primitive()); + EXPECT_EQ(native_abstract_type.parameter_types().at(0).primitive(), + PrimitiveType::kDouble); + ASSERT_TRUE(native_abstract_type.parameter_types().at(1).has_primitive()); + EXPECT_EQ(native_abstract_type.parameter_types().at(1).primitive(), + PrimitiveType::kString); +} + +TEST(AstConvertersTest, DynamicTypeToNative) { + cel::expr::Type type; + type.mutable_dyn(); + + auto native_type = ConvertProtoTypeToNative(type); + + ASSERT_TRUE(native_type->has_dyn()); +} + +TEST(AstConvertersTest, NullTypeToNative) { + cel::expr::Type type; + type.set_null(google::protobuf::NULL_VALUE); + + auto native_type = ConvertProtoTypeToNative(type); + + ASSERT_TRUE(native_type->has_null()); + EXPECT_EQ(native_type->null(), NullTypeSpec()); +} + +TEST(AstConvertersTest, PrimitiveTypeWrapperToNative) { + cel::expr::Type type; + type.set_wrapper(cel::expr::Type::BOOL); + + auto native_type = ConvertProtoTypeToNative(type); + + ASSERT_TRUE(native_type->has_wrapper()); + EXPECT_EQ(native_type->wrapper(), PrimitiveType::kBool); +} + +TEST(AstConvertersTest, MessageTypeToNative) { + cel::expr::Type type; + type.set_message_type("message"); + + auto native_type = ConvertProtoTypeToNative(type); + + ASSERT_TRUE(native_type->has_message_type()); + EXPECT_EQ(native_type->message_type().type(), "message"); +} + +TEST(AstConvertersTest, ParamTypeToNative) { + cel::expr::Type type; + type.set_type_param("param"); + + auto native_type = ConvertProtoTypeToNative(type); + + ASSERT_TRUE(native_type->has_type_param()); + EXPECT_EQ(native_type->type_param().type(), "param"); +} + +TEST(AstConvertersTest, NestedTypeToNative) { + cel::expr::Type type; + type.mutable_type()->mutable_dyn(); + + auto native_type = ConvertProtoTypeToNative(type); + + ASSERT_TRUE(native_type->has_type()); + EXPECT_TRUE(native_type->type().has_dyn()); +} + +TEST(AstConvertersTest, TypeTypeDefault) { + auto native_type = ConvertProtoTypeToNative(cel::expr::Type()); + + ASSERT_THAT(native_type, IsOk()); + EXPECT_TRUE(absl::holds_alternative(native_type->type_kind())); +} + +TEST(AstConvertersTest, ReferenceToNative) { + cel::expr::CheckedExpr reference_wrapper; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb( + reference_map { + key: 1 + value { + name: "name" + overload_id: "id1" + overload_id: "id2" + value { bool_value: true } + } + })pb", + &reference_wrapper)); + + ASSERT_OK_AND_ASSIGN(auto ast, CreateAstFromCheckedExpr(reference_wrapper)); + const auto& native_references = ast->reference_map(); + + auto native_reference = native_references.at(1); + + EXPECT_EQ(native_reference.name(), "name"); + EXPECT_EQ(native_reference.overload_id(), + std::vector({"id1", "id2"})); + EXPECT_TRUE(native_reference.value().bool_value()); +} + +TEST(AstConvertersTest, SourceInfoToNative) { + cel::expr::ParsedExpr source_info_wrapper; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb( + source_info { + syntax_version: "version" + location: "location" + line_offsets: 1 + line_offsets: 2 + positions { key: 1 value: 2 } + positions { key: 3 value: 4 } + macro_calls { + key: 1 + value { ident_expr { name: "name" } } + } + })pb", + &source_info_wrapper)); + + ASSERT_OK_AND_ASSIGN(auto ast, CreateAstFromParsedExpr(source_info_wrapper)); + const auto& native_source_info = ast->source_info(); + + EXPECT_EQ(native_source_info.syntax_version(), "version"); + EXPECT_EQ(native_source_info.location(), "location"); + EXPECT_EQ(native_source_info.line_offsets(), std::vector({1, 2})); + EXPECT_EQ(native_source_info.positions().at(1), 2); + EXPECT_EQ(native_source_info.positions().at(3), 4); + ASSERT_TRUE(native_source_info.macro_calls().at(1).has_ident_expr()); + ASSERT_EQ(native_source_info.macro_calls().at(1).ident_expr().name(), "name"); +} + +TEST(AstConvertersTest, CheckedExprToAst) { + CheckedExpr checked_expr; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb( + reference_map { + key: 1 + value { + name: "name" + overload_id: "id1" + overload_id: "id2" + value { bool_value: true } + } + } + type_map { + key: 1 + value { dyn {} } + } + source_info { + syntax_version: "version" + location: "location" + line_offsets: 1 + line_offsets: 2 + positions { key: 1 value: 2 } + positions { key: 3 value: 4 } + macro_calls { + key: 1 + value { ident_expr { name: "name" } } + } + } + expr_version: "version" + expr { ident_expr { name: "expr" } } + )pb", + &checked_expr)); + + ASSERT_OK_AND_ASSIGN(auto ast, CreateAstFromCheckedExpr(checked_expr)); + + ASSERT_TRUE(ast->IsChecked()); +} + +TEST(AstConvertersTest, AstToCheckedExprBasic) { + Ast ast; + ast.mutable_root_expr().set_id(1); + ast.mutable_root_expr().mutable_ident_expr().set_name("expr"); + + ast.mutable_source_info().set_syntax_version("version"); + ast.mutable_source_info().set_location("location"); + ast.mutable_source_info().mutable_line_offsets().push_back(1); + ast.mutable_source_info().mutable_line_offsets().push_back(2); + ast.mutable_source_info().mutable_positions().insert({1, 2}); + ast.mutable_source_info().mutable_positions().insert({3, 4}); + + Expr macro; + macro.mutable_ident_expr().set_name("name"); + ast.mutable_source_info().mutable_macro_calls().insert({1, std::move(macro)}); + + Reference reference; + reference.set_name("name"); + reference.mutable_overload_id().push_back("id1"); + reference.mutable_overload_id().push_back("id2"); + reference.mutable_value().set_bool_value(true); + + TypeSpec type; + type.set_type_kind(DynTypeSpec()); + + ast.mutable_reference_map().insert({1, std::move(reference)}); + ast.mutable_type_map().insert({1, std::move(type)}); + + ast.set_expr_version("version"); + ast.set_is_checked(true); + + CheckedExpr checked_expr; + ASSERT_THAT(AstToCheckedExpr(ast, &checked_expr), IsOk()); + + EXPECT_THAT(checked_expr, EqualsProto(R"pb( + reference_map { + key: 1 + value { + name: "name" + overload_id: "id1" + overload_id: "id2" + value { bool_value: true } + } + } + type_map { + key: 1 + value { dyn {} } + } + source_info { + syntax_version: "version" + location: "location" + line_offsets: 1 + line_offsets: 2 + positions { key: 1 value: 2 } + positions { key: 3 value: 4 } + macro_calls { + key: 1 + value { ident_expr { name: "name" } } + } + } + expr_version: "version" + expr { + id: 1 + ident_expr { name: "expr" } + } + )pb")); +} + +constexpr absl::string_view kTypesTestCheckedExpr = + R"pb(reference_map: { + key: 1 + value: { name: "x" } + } + type_map: { + key: 1 + value: { primitive: INT64 } + } + source_info: { + location: "" + line_offsets: 2 + positions: { key: 1 value: 0 } + } + expr: { + id: 1 + ident_expr: { name: "x" } + })pb"; + +struct CheckedExprToAstTypesTestCase { + absl::string_view type; +}; + +class CheckedExprToAstTypesTest + : public testing::TestWithParam { + public: + void SetUp() override { + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kTypesTestCheckedExpr, + &checked_expr_)); + } + + protected: + CheckedExpr checked_expr_; +}; + +TEST_P(CheckedExprToAstTypesTest, CheckedExprToAstTypes) { + TypePb test_type; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(GetParam().type, &test_type)); + (*checked_expr_.mutable_type_map())[1] = test_type; + + ASSERT_OK_AND_ASSIGN(auto ast, CreateAstFromCheckedExpr(checked_expr_)); + + CheckedExpr checked_expr; + ASSERT_THAT(AstToCheckedExpr(*ast, &checked_expr), IsOk()); + + EXPECT_THAT(checked_expr, EqualsProto(checked_expr_)); +} + +INSTANTIATE_TEST_SUITE_P( + Types, CheckedExprToAstTypesTest, + testing::ValuesIn({ + {R"pb(list_type { elem_type { primitive: INT64 } })pb"}, + {R"pb(map_type { + key_type { primitive: STRING } + value_type { primitive: INT64 } + })pb"}, + {R"pb(message_type: "com.example.TestType")pb"}, + {R"pb(primitive: BOOL)pb"}, + {R"pb(primitive: INT64)pb"}, + {R"pb(primitive: UINT64)pb"}, + {R"pb(primitive: DOUBLE)pb"}, + {R"pb(primitive: STRING)pb"}, + {R"pb(primitive: BYTES)pb"}, + {R"pb(wrapper: BOOL)pb"}, + {R"pb(wrapper: INT64)pb"}, + {R"pb(wrapper: UINT64)pb"}, + {R"pb(wrapper: DOUBLE)pb"}, + {R"pb(wrapper: STRING)pb"}, + {R"pb(wrapper: BYTES)pb"}, + {R"pb(well_known: TIMESTAMP)pb"}, + {R"pb(well_known: DURATION)pb"}, + {R"pb(well_known: ANY)pb"}, + {R"pb(dyn {})pb"}, + {R"pb(error {})pb"}, + {R"pb(null: NULL_VALUE)pb"}, + {R"pb( + abstract_type { + name: "MyType" + parameter_types { primitive: INT64 } + } + )pb"}, + {R"pb( + type { primitive: INT64 } + )pb"}, + {R"pb( + type { type {} } + )pb"}, + {R"pb(type_param: "T")pb"}, + {R"pb( + function { + result_type { primitive: INT64 } + arg_types { primitive: INT64 } + } + )pb"}, + })); + +TEST(AstConvertersTest, ParsedExprToAst) { + ParsedExpr parsed_expr; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb( + source_info { + syntax_version: "version" + location: "location" + line_offsets: 1 + line_offsets: 2 + positions { key: 1 value: 2 } + positions { key: 3 value: 4 } + macro_calls { + key: 1 + value { ident_expr { name: "name" } } + } + } + expr { ident_expr { name: "expr" } } + )pb", + &parsed_expr)); + + ASSERT_OK_AND_ASSIGN(auto ast, CreateAstFromParsedExpr(parsed_expr)); +} + +TEST(AstConvertersTest, AstToParsedExprBasic) { + Expr expr; + expr.set_id(1); + expr.mutable_ident_expr().set_name("expr"); + + SourceInfo source_info; + source_info.set_syntax_version("version"); + source_info.set_location("location"); + source_info.mutable_line_offsets().push_back(1); + source_info.mutable_line_offsets().push_back(2); + source_info.mutable_positions().insert({1, 2}); + source_info.mutable_positions().insert({3, 4}); + + Expr macro; + macro.mutable_ident_expr().set_name("name"); + source_info.mutable_macro_calls().insert({1, std::move(macro)}); + + Ast ast(std::move(expr), std::move(source_info)); + + ParsedExpr parsed_expr; + ASSERT_THAT(AstToParsedExpr(ast, &parsed_expr), IsOk()); + + EXPECT_THAT(parsed_expr, EqualsProto(R"pb( + source_info { + syntax_version: "version" + location: "location" + line_offsets: 1 + line_offsets: 2 + positions { key: 1 value: 2 } + positions { key: 3 value: 4 } + macro_calls { + key: 1 + value { ident_expr { name: "name" } } + } + } + expr { + id: 1 + ident_expr { name: "expr" } + } + )pb")); +} + +TEST(AstConvertersTest, ExprToAst) { + cel::expr::Expr expr; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb( + ident_expr { name: "expr" } + )pb", + &expr)); + + ASSERT_OK_AND_ASSIGN(auto ast, CreateAstFromParsedExpr(expr)); +} + +TEST(AstConvertersTest, ExprAndSourceInfoToAst) { + cel::expr::Expr expr; + cel::expr::SourceInfo source_info; + + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb( + syntax_version: "version" + location: "location" + line_offsets: 1 + line_offsets: 2 + positions { key: 1 value: 2 } + positions { key: 3 value: 4 } + macro_calls { + key: 1 + value { ident_expr { name: "name" } } + } + )pb", + &source_info)); + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb( + ident_expr { name: "expr" } + )pb", + &expr)); + + ASSERT_OK_AND_ASSIGN(auto ast, CreateAstFromParsedExpr(expr, &source_info)); +} + +TEST(AstConvertersTest, EmptyNodeRoundTrip) { + ParsedExpr parsed_expr; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb( + expr { + id: 1 + select_expr { + operand { + id: 2 + # no kind set. + } + field: "field" + } + } + source_info {} + )pb", + &parsed_expr)); + + ASSERT_OK_AND_ASSIGN(auto ast, CreateAstFromParsedExpr(parsed_expr)); + ParsedExpr copy; + ASSERT_THAT(AstToParsedExpr(*ast, ©), IsOk()); + EXPECT_THAT(copy, EqualsProto(parsed_expr)); +} + +TEST(AstConvertersTest, DurationConstantRoundTrip) { + ParsedExpr parsed_expr; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb( + expr { + id: 1 + const_expr { + # deprecated, but support existing ASTs. + duration_value { seconds: 10 } + } + } + source_info {} + )pb", + &parsed_expr)); + + ASSERT_OK_AND_ASSIGN(auto ast, CreateAstFromParsedExpr(parsed_expr)); + + ParsedExpr copy; + ASSERT_THAT(AstToParsedExpr(*ast, ©), IsOk()); + EXPECT_THAT(copy, EqualsProto(parsed_expr)); +} + +TEST(AstConvertersTest, TimestampConstantRoundTrip) { + ParsedExpr parsed_expr; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb( + expr { + id: 1 + const_expr { + # deprecated, but support existing ASTs. + timestamp_value { seconds: 10 } + } + } + source_info {} + )pb", + &parsed_expr)); + + ASSERT_OK_AND_ASSIGN(auto ast, CreateAstFromParsedExpr(parsed_expr)); + ParsedExpr copy; + ASSERT_THAT(AstToParsedExpr(*ast, ©), IsOk()); + EXPECT_THAT(copy, EqualsProto(parsed_expr)); +} + +struct ConversionRoundTripCase { + absl::string_view expr; +}; + +class ConversionRoundTripTest + : public testing::TestWithParam { + public: + ConversionRoundTripTest() { + auto builder = + cel::NewCompilerBuilder(internal::GetTestingDescriptorPool()).value(); + builder->AddLibrary(cel::StandardCompilerLibrary()).IgnoreError(); + builder->AddLibrary(OptionalCompilerLibrary()).IgnoreError(); + builder->AddLibrary(extensions::ComprehensionsV2CompilerLibrary()) + .IgnoreError(); + builder->GetCheckerBuilder().set_container("cel.expr.conformance.proto3"); + builder->GetCheckerBuilder() + .AddVariable(MakeVariableDecl("ident", IntType())) + .IgnoreError(); + builder->GetCheckerBuilder() + .AddVariable(MakeVariableDecl("map_ident", JsonMapType())) + .IgnoreError(); + compiler_ = builder->Build().value(); + } + + absl::StatusOr ParseToProto(absl::string_view expr) { + CEL_ASSIGN_OR_RETURN(auto source, cel::NewSource(expr)); + + CEL_ASSIGN_OR_RETURN(auto result, compiler_->GetParser().Parse(*source)); + ParsedExpr parsed_expr; + + CEL_RETURN_IF_ERROR(AstToParsedExpr(*result, &parsed_expr)); + return parsed_expr; + } + + absl::StatusOr CompileToProto(absl::string_view expr) { + CEL_ASSIGN_OR_RETURN(auto result, compiler_->Compile(expr)); + if (!result.IsValid()) { + return absl::InvalidArgumentError(absl::StrCat( + "Compilation failed: '", expr, "': ", result.FormatError())); + } + CEL_ASSIGN_OR_RETURN(auto ast, result.ReleaseAst()); + CheckedExpr checked_expr; + CEL_RETURN_IF_ERROR(AstToCheckedExpr(*ast, &checked_expr)); + return checked_expr; + } + + protected: + std::unique_ptr compiler_; +}; + +TEST_P(ConversionRoundTripTest, ParsedExprCopyable) { + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, ParseToProto(GetParam().expr)); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, + CreateAstFromParsedExpr(parsed_expr)); + + CheckedExpr expr_pb; + EXPECT_THAT(AstToCheckedExpr(*ast, &expr_pb), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("AST is not type-checked"))); + ParsedExpr proto_out; + ASSERT_THAT(AstToParsedExpr(*ast, &proto_out), IsOk()); + EXPECT_THAT(proto_out, EqualsProto(parsed_expr)); +} + +TEST_P(ConversionRoundTripTest, ExprCopyable) { + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, ParseToProto(GetParam().expr)); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, + CreateAstFromParsedExpr(parsed_expr)); + + Expr copy = ast->root_expr(); + ast->mutable_root_expr() = std::move(copy); + + ParsedExpr parsed_pb_out; + CheckedExpr checked_pb_out; + EXPECT_THAT(AstToCheckedExpr(*ast, &checked_pb_out), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("AST is not type-checked"))); + ASSERT_THAT(AstToParsedExpr(*ast, &parsed_pb_out), IsOk()); + EXPECT_THAT(parsed_pb_out, EqualsProto(parsed_expr)); +} + +TEST_P(ConversionRoundTripTest, CheckedExprRoundTrip) { + ASSERT_OK_AND_ASSIGN(CheckedExpr checked_expr, + CompileToProto(GetParam().expr)); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, + CreateAstFromCheckedExpr(checked_expr)); + + CheckedExpr checked_pb_out; + ASSERT_THAT(AstToCheckedExpr(*ast, &checked_pb_out), IsOk()); + EXPECT_THAT(checked_pb_out, EqualsProto(checked_expr)); +} + +TEST_P(ConversionRoundTripTest, CheckedExprCopyRoundTrip) { + ASSERT_OK_AND_ASSIGN(CheckedExpr checked_expr, + CompileToProto(GetParam().expr)); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, + CreateAstFromCheckedExpr(checked_expr)); + + Ast copy = *ast; + CheckedExpr checked_pb_out; + ASSERT_THAT(AstToCheckedExpr(copy, &checked_pb_out), IsOk()); + EXPECT_THAT(checked_pb_out, EqualsProto(checked_expr)); +} + +INSTANTIATE_TEST_SUITE_P( + ExpressionCases, ConversionRoundTripTest, + testing::ValuesIn( + {{R"cel(null == null)cel"}, + {R"cel(1 == 2)cel"}, + {R"cel(1u == 2u)cel"}, + {R"cel(1.1 == 2.1)cel"}, + {R"cel(b"1" == b"2")cel"}, + {R"cel("42" == "42")cel"}, + {R"cel("s".startsWith("s") == true)cel"}, + {R"cel([1, 2, 3] == [1, 2, 3])cel"}, + {R"cel([1, 2, 3].all(i, e, i == e - 1) == true)cel"}, + {R"cel(TestAllTypes{single_int64: 42}.single_int64 == 42)cel"}, + {R"cel([1, 2, 3].map(x, x + 2).size() == 3)cel"}, + {R"cel({"a": 1, "b": 2}["a"] == 1)cel"}, + {R"cel(ident == 42)cel"}, + {R"cel(map_ident.field == 42)cel"}, + {R"cel({?"abc": {}[?1]}.?abc.orValue(42) == 42)cel"}, + {R"cel([1, 2, ?optional.none()].size() == 2)cel"}})); + +TEST(ExtensionConversionRoundTripTest, RoundTrip) { + ParsedExpr parsed_expr; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb( + expr { + id: 1 + ident_expr { name: "unused" } + } + source_info { + extensions { + id: "extension" + version { major: 1 minor: 2 } + affected_components: COMPONENT_UNSPECIFIED + affected_components: COMPONENT_PARSER + affected_components: COMPONENT_TYPE_CHECKER + affected_components: COMPONENT_RUNTIME + } + } + )pb", + &parsed_expr)); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, + CreateAstFromParsedExpr(parsed_expr)); + + CheckedExpr expr_pb; + EXPECT_THAT(AstToCheckedExpr(*ast, &expr_pb), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("AST is not type-checked"))); + ParsedExpr copy; + ASSERT_THAT(AstToParsedExpr(*ast, ©), IsOk()); + EXPECT_THAT(copy, EqualsProto(parsed_expr)); +} + +} // namespace +} // namespace cel diff --git a/common/ast_rewrite.cc b/common/ast_rewrite.cc new file mode 100644 index 000000000..b61e1fab6 --- /dev/null +++ b/common/ast_rewrite.cc @@ -0,0 +1,389 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/ast_rewrite.h" + +#include +#include + +#include "absl/log/absl_log.h" +#include "absl/types/span.h" +#include "absl/types/variant.h" +#include "common/ast_visitor.h" +#include "common/constant.h" +#include "common/expr.h" + +namespace cel { + +namespace { + +struct ArgRecord { + // Not null. + Expr* expr; + + // For records that are direct arguments to call, we need to call + // the CallArg visitor immediately after the argument is evaluated. + const Expr* calling_expr; + int call_arg; +}; + +struct ComprehensionRecord { + // Not null. + Expr* expr; + + const ComprehensionExpr* comprehension; + const Expr* comprehension_expr; + ComprehensionArg comprehension_arg; + bool use_comprehension_callbacks; +}; + +struct ExprRecord { + // Not null. + Expr* expr; +}; + +using StackRecordKind = + std::variant; + +struct StackRecord { + public: + static constexpr int kTarget = -2; + + explicit StackRecord(Expr* e) { + ExprRecord record; + record.expr = e; + record_variant = record; + } + + StackRecord(Expr* e, ComprehensionExpr* comprehension, + Expr* comprehension_expr, ComprehensionArg comprehension_arg, + bool use_comprehension_callbacks) { + if (use_comprehension_callbacks) { + ComprehensionRecord record; + record.expr = e; + record.comprehension = comprehension; + record.comprehension_expr = comprehension_expr; + record.comprehension_arg = comprehension_arg; + record.use_comprehension_callbacks = use_comprehension_callbacks; + record_variant = record; + return; + } + ArgRecord record; + record.expr = e; + record.calling_expr = comprehension_expr; + record.call_arg = comprehension_arg; + record_variant = record; + } + + StackRecord(Expr* e, const Expr* call, int argnum) { + ArgRecord record; + record.expr = e; + record.calling_expr = call; + record.call_arg = argnum; + record_variant = record; + } + + Expr* expr() const { return absl::get(record_variant).expr; } + + bool IsExprRecord() const { + return absl::holds_alternative(record_variant); + } + + StackRecordKind record_variant; + bool visited = false; +}; + +struct PreVisitor { + void operator()(const ExprRecord& record) { + struct { + AstVisitor* visitor; + const Expr* expr; + void operator()(const Constant&) { + // No pre-visit action. + } + void operator()(const IdentExpr&) { + // No pre-visit action. + } + void operator()(const SelectExpr& select) { + visitor->PreVisitSelect(*expr, select); + } + void operator()(const CallExpr& call) { + visitor->PreVisitCall(*expr, call); + } + void operator()(const ListExpr&) { + // No pre-visit action. + } + void operator()(const StructExpr&) { + // No pre-visit action. + } + void operator()(const MapExpr&) { + // No pre-visit action. + } + void operator()(const ComprehensionExpr& comprehension) { + visitor->PreVisitComprehension(*expr, comprehension); + } + void operator()(const UnspecifiedExpr&) { + // No pre-visit action. + } + } handler{visitor, record.expr}; + visitor->PreVisitExpr(*record.expr); + absl::visit(handler, record.expr->kind()); + } + + // Do nothing for Arg variant. + void operator()(const ArgRecord&) {} + + void operator()(const ComprehensionRecord& record) { + visitor->PreVisitComprehensionSubexpression(*record.comprehension_expr, + *record.comprehension, + record.comprehension_arg); + } + + AstVisitor* visitor; +}; + +void PreVisit(const StackRecord& record, AstVisitor* visitor) { + absl::visit(PreVisitor{visitor}, record.record_variant); +} + +struct PostVisitor { + void operator()(const ExprRecord& record) { + struct { + AstVisitor* visitor; + const Expr* expr; + void operator()(const Constant& constant) { + visitor->PostVisitConst(*expr, constant); + } + void operator()(const IdentExpr& ident) { + visitor->PostVisitIdent(*expr, ident); + } + void operator()(const SelectExpr& select) { + visitor->PostVisitSelect(*expr, select); + } + void operator()(const CallExpr& call) { + visitor->PostVisitCall(*expr, call); + } + void operator()(const ListExpr& create_list) { + visitor->PostVisitList(*expr, create_list); + } + void operator()(const StructExpr& create_struct) { + visitor->PostVisitStruct(*expr, create_struct); + } + void operator()(const MapExpr& map_expr) { + visitor->PostVisitMap(*expr, map_expr); + } + void operator()(const ComprehensionExpr& comprehension) { + visitor->PostVisitComprehension(*expr, comprehension); + } + void operator()(const UnspecifiedExpr&) { + ABSL_LOG(ERROR) << "Unsupported Expr kind"; + } + } handler{visitor, record.expr}; + absl::visit(handler, record.expr->kind()); + + visitor->PostVisitExpr(*record.expr); + } + + void operator()(const ArgRecord& record) { + if (record.call_arg == StackRecord::kTarget) { + visitor->PostVisitTarget(*record.calling_expr); + } else { + visitor->PostVisitArg(*record.calling_expr, record.call_arg); + } + } + + void operator()(const ComprehensionRecord& record) { + visitor->PostVisitComprehensionSubexpression(*record.comprehension_expr, + *record.comprehension, + record.comprehension_arg); + } + + AstVisitor* visitor; +}; + +void PostVisit(const StackRecord& record, AstVisitor* visitor) { + absl::visit(PostVisitor{visitor}, record.record_variant); +} + +void PushSelectDeps(SelectExpr* select_expr, std::stack* stack) { + if (select_expr->has_operand()) { + stack->push(StackRecord(&select_expr->mutable_operand())); + } +} + +void PushCallDeps(CallExpr* call_expr, Expr* expr, + std::stack* stack) { + const int arg_size = call_expr->args().size(); + // Our contract is that we visit arguments in order. To do that, we need + // to push them onto the stack in reverse order. + for (int i = arg_size - 1; i >= 0; --i) { + stack->push(StackRecord(&call_expr->mutable_args()[i], expr, i)); + } + // Are we receiver-style? + if (call_expr->has_target()) { + stack->push( + StackRecord(&call_expr->mutable_target(), expr, StackRecord::kTarget)); + } +} + +void PushListDeps(ListExpr* list_expr, std::stack* stack) { + auto& elements = list_expr->mutable_elements(); + for (auto it = elements.rbegin(); it != elements.rend(); ++it) { + auto& element = *it; + stack->push(StackRecord(&element.mutable_expr())); + } +} + +void PushStructDeps(StructExpr* struct_expr, std::stack* stack) { + auto& entries = struct_expr->mutable_fields(); + for (auto it = entries.rbegin(); it != entries.rend(); ++it) { + auto& entry = *it; + // The contract is to visit key, then value. So put them on the stack + // in the opposite order. + if (entry.has_value()) { + stack->push(StackRecord(&entry.mutable_value())); + } + } +} + +void PushMapDeps(MapExpr* struct_expr, std::stack* stack) { + auto& entries = struct_expr->mutable_entries(); + for (auto it = entries.rbegin(); it != entries.rend(); ++it) { + auto& entry = *it; + // The contract is to visit key, then value. So put them on the stack + // in the opposite order. + if (entry.has_value()) { + stack->push(StackRecord(&entry.mutable_value())); + } + // The contract is to visit key, then value. So put them on the stack + // in the opposite order. + if (entry.has_key()) { + stack->push(StackRecord(&entry.mutable_key())); + } + } +} + +void PushComprehensionDeps(ComprehensionExpr* c, Expr* expr, + std::stack* stack, + bool use_comprehension_callbacks) { + StackRecord iter_range(&c->mutable_iter_range(), c, expr, ITER_RANGE, + use_comprehension_callbacks); + StackRecord accu_init(&c->mutable_accu_init(), c, expr, ACCU_INIT, + use_comprehension_callbacks); + StackRecord loop_condition(&c->mutable_loop_condition(), c, expr, + LOOP_CONDITION, use_comprehension_callbacks); + StackRecord loop_step(&c->mutable_loop_step(), c, expr, LOOP_STEP, + use_comprehension_callbacks); + StackRecord result(&c->mutable_result(), c, expr, RESULT, + use_comprehension_callbacks); + // Push them in reverse order. + stack->push(result); + stack->push(loop_step); + stack->push(loop_condition); + stack->push(accu_init); + stack->push(iter_range); +} + +struct PushDepsVisitor { + void operator()(const ExprRecord& record) { + struct { + std::stack& stack; + const RewriteTraversalOptions& options; + const ExprRecord& record; + void operator()(const Constant&) {} + void operator()(const IdentExpr&) {} + void operator()(const SelectExpr&) { + PushSelectDeps(&record.expr->mutable_select_expr(), &stack); + } + void operator()(const CallExpr&) { + PushCallDeps(&record.expr->mutable_call_expr(), record.expr, &stack); + } + void operator()(const ListExpr&) { + PushListDeps(&record.expr->mutable_list_expr(), &stack); + } + void operator()(const StructExpr&) { + PushStructDeps(&record.expr->mutable_struct_expr(), &stack); + } + void operator()(const MapExpr&) { + PushMapDeps(&record.expr->mutable_map_expr(), &stack); + } + void operator()(const ComprehensionExpr&) { + PushComprehensionDeps(&record.expr->mutable_comprehension_expr(), + record.expr, &stack, + options.use_comprehension_callbacks); + } + void operator()(const UnspecifiedExpr&) {} + } handler{stack, options, record}; + absl::visit(handler, record.expr->kind()); + } + + void operator()(const ArgRecord& record) { + stack.push(StackRecord(record.expr)); + } + + void operator()(const ComprehensionRecord& record) { + stack.push(StackRecord(record.expr)); + } + + std::stack& stack; + const RewriteTraversalOptions& options; +}; + +void PushDependencies(const StackRecord& record, std::stack& stack, + const RewriteTraversalOptions& options) { + absl::visit(PushDepsVisitor{stack, options}, record.record_variant); +} + +} // namespace + +bool AstRewrite(Expr& expr, AstRewriter& visitor, + RewriteTraversalOptions options) { + std::stack stack; + std::vector traversal_path; + + stack.push(StackRecord(&expr)); + bool rewritten = false; + + while (!stack.empty()) { + StackRecord& record = stack.top(); + if (!record.visited) { + if (record.IsExprRecord()) { + traversal_path.push_back(record.expr()); + visitor.TraversalStackUpdate(absl::MakeSpan(traversal_path)); + + if (visitor.PreVisitRewrite(*record.expr())) { + rewritten = true; + } + } + PreVisit(record, &visitor); + PushDependencies(record, stack, options); + record.visited = true; + } else { + PostVisit(record, &visitor); + if (record.IsExprRecord()) { + if (visitor.PostVisitRewrite(*record.expr())) { + rewritten = true; + } + + traversal_path.pop_back(); + visitor.TraversalStackUpdate(absl::MakeSpan(traversal_path)); + } + stack.pop(); + } + } + + return rewritten; +} + +} // namespace cel diff --git a/common/ast_rewrite.h b/common/ast_rewrite.h new file mode 100644 index 000000000..e24108ae4 --- /dev/null +++ b/common/ast_rewrite.h @@ -0,0 +1,146 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_AST_REWRITE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_AST_REWRITE_H_ + +#include "absl/base/nullability.h" +#include "absl/types/span.h" +#include "common/ast_visitor.h" +#include "common/constant.h" +#include "common/expr.h" + +namespace cel { + +// Traversal options for AstRewrite. +struct RewriteTraversalOptions { + // If enabled, use comprehension specific callbacks instead of the general + // arguments callbacks. + bool use_comprehension_callbacks; + + RewriteTraversalOptions() : use_comprehension_callbacks(false) {} +}; + +// Interface for AST rewriters. +// Extends AstVisitor interface with update methods. +// see AstRewrite for more details on usage. +class AstRewriter : public AstVisitor { + public: + ~AstRewriter() override {} + + // Rewrite a sub expression before visiting. + // Occurs before visiting Expr. If expr is modified, it the new value will be + // visited. + virtual bool PreVisitRewrite(Expr& expr) = 0; + + // Rewrite a sub expression after visiting. + // Occurs after visiting expr and it's children. If expr is modified, the old + // sub expression is visited. + virtual bool PostVisitRewrite(Expr& expr) = 0; + + // Notify the visitor of updates to the traversal stack. + virtual void TraversalStackUpdate( + absl::Span path) = 0; +}; + +// Trivial implementation for AST rewriters. +// Virtual methods are overridden with no-op callbacks. +class AstRewriterBase : public AstRewriter { + public: + ~AstRewriterBase() override {} + + void PreVisitExpr(const Expr&) override {} + + void PostVisitExpr(const Expr&) override {} + + void PostVisitConst(const Expr&, const Constant&) override {} + + void PostVisitIdent(const Expr&, const IdentExpr&) override {} + + void PreVisitSelect(const Expr&, const SelectExpr&) override {} + + void PostVisitSelect(const Expr&, const SelectExpr&) override {} + + void PreVisitCall(const Expr&, const CallExpr&) override {} + + void PostVisitCall(const Expr&, const CallExpr&) override {} + + void PreVisitComprehension(const Expr&, const ComprehensionExpr&) override {} + + void PostVisitComprehension(const Expr&, const ComprehensionExpr&) override {} + + void PostVisitArg(const Expr&, int) override {} + + void PostVisitTarget(const Expr&) override {} + + void PostVisitList(const Expr&, const ListExpr&) override {} + + void PostVisitStruct(const Expr&, const StructExpr&) override {} + + void PostVisitMap(const Expr&, const MapExpr&) override {} + + bool PreVisitRewrite(Expr& expr) override { return false; } + + bool PostVisitRewrite(Expr& expr) override { return false; } + + void TraversalStackUpdate( + absl::Span path) override {} +}; + +// Traverses the AST representation in an expr proto. Returns true if any +// rewrites occur. +// +// Rewrites may happen before and/or after visiting an expr subtree. If a +// change happens during the pre-visit rewrite, the updated subtree will be +// visited. If a change happens during the post-visit rewrite, the old subtree +// will be visited. +// +// expr: root node of the tree. +// source_info: optional additional parse information about the expression +// visitor: the callback object that receives the visitation notifications +// options: options for traversal. see RewriteTraversalOptions. Defaults are +// used if not sepecified. +// +// Traversal order follows the pattern: +// PreVisitRewrite +// PreVisitExpr +// ..PreVisit{ExprKind} +// ....PreVisit{ArgumentIndex} +// .......PreVisitExpr (subtree) +// .......PostVisitExpr (subtree) +// ....PostVisit{ArgumentIndex} +// ..PostVisit{ExprKind} +// PostVisitExpr +// PostVisitRewrite +// +// Example callback order for fn(1, var): +// PreVisitExpr +// ..PreVisitCall(fn) +// ......PreVisitExpr +// ........PostVisitConst(1) +// ......PostVisitExpr +// ....PostVisitArg(fn, 0) +// ......PreVisitExpr +// ........PostVisitIdent(var) +// ......PostVisitExpr +// ....PostVisitArg(fn, 1) +// ..PostVisitCall(fn) +// PostVisitExpr + +bool AstRewrite(Expr& expr, AstRewriter& visitor, + RewriteTraversalOptions options = RewriteTraversalOptions()); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_AST_REWRITE_H_ diff --git a/common/ast_rewrite_test.cc b/common/ast_rewrite_test.cc new file mode 100644 index 000000000..5417b23ac --- /dev/null +++ b/common/ast_rewrite_test.cc @@ -0,0 +1,609 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/ast_rewrite.h" + +#include +#include +#include + +#include "cel/expr/syntax.pb.h" +#include "absl/status/status_matchers.h" +#include "common/ast.h" +#include "common/ast/expr_proto.h" +#include "common/ast_visitor.h" +#include "common/expr.h" +#include "extensions/protobuf/ast_converters.h" +#include "internal/testing.h" +#include "parser/parser.h" +#include "google/protobuf/text_format.h" + +namespace cel { + +namespace { + +using ::absl_testing::IsOk; +using ::cel::ast_internal::ExprFromProto; +using ::cel::extensions::CreateAstFromParsedExpr; +using ::testing::_; +using ::testing::ElementsAre; +using ::testing::InSequence; +using ::testing::Ref; + +class MockAstRewriter : public AstRewriter { + public: + // Expr handler. + MOCK_METHOD(void, PreVisitExpr, (const Expr& expr), (override)); + + // Expr handler. + MOCK_METHOD(void, PostVisitExpr, (const Expr& expr), (override)); + + MOCK_METHOD(void, PostVisitConst, + (const Expr& expr, const Constant& const_expr), (override)); + + // Ident node handler. + MOCK_METHOD(void, PostVisitIdent, + (const Expr& expr, const IdentExpr& ident_expr), (override)); + + // Select node handler group + MOCK_METHOD(void, PreVisitSelect, + (const Expr& expr, const SelectExpr& select_expr), (override)); + + MOCK_METHOD(void, PostVisitSelect, + (const Expr& expr, const SelectExpr& select_expr), (override)); + + // Call node handler group + MOCK_METHOD(void, PreVisitCall, (const Expr& expr, const CallExpr& call_expr), + (override)); + MOCK_METHOD(void, PostVisitCall, + (const Expr& expr, const CallExpr& call_expr), (override)); + + // Comprehension node handler group + MOCK_METHOD(void, PreVisitComprehension, + (const Expr& expr, const ComprehensionExpr& comprehension_expr), + (override)); + MOCK_METHOD(void, PostVisitComprehension, + (const Expr& expr, const ComprehensionExpr& comprehension_expr), + (override)); + + // Comprehension node handler group + MOCK_METHOD(void, PreVisitComprehensionSubexpression, + (const Expr& expr, const ComprehensionExpr& comprehension_expr, + ComprehensionArg comprehension_arg), + (override)); + MOCK_METHOD(void, PostVisitComprehensionSubexpression, + (const Expr& expr, const ComprehensionExpr& comprehension_expr, + ComprehensionArg comprehension_arg), + (override)); + + // We provide finer granularity for Call and Comprehension node callbacks + // to allow special handling for short-circuiting. + MOCK_METHOD(void, PostVisitTarget, (const Expr& expr), (override)); + MOCK_METHOD(void, PostVisitArg, (const Expr& expr, int arg_num), (override)); + + // List node handler group + MOCK_METHOD(void, PostVisitList, + (const Expr& expr, const ListExpr& list_expr), (override)); + + // Struct node handler group + MOCK_METHOD(void, PostVisitStruct, + (const Expr& expr, const StructExpr& struct_expr), (override)); + + // Map node handler group + MOCK_METHOD(void, PostVisitMap, (const Expr& expr, const MapExpr& map_expr), + (override)); + + MOCK_METHOD(bool, PreVisitRewrite, (Expr & expr), (override)); + + MOCK_METHOD(bool, PostVisitRewrite, (Expr & expr), (override)); + + MOCK_METHOD(void, TraversalStackUpdate, + (absl::Span path), (override)); +}; + +TEST(AstCrawlerTest, CheckCrawlConstant) { + MockAstRewriter handler; + + Expr expr; + auto& const_expr = expr.mutable_const_expr(); + + EXPECT_CALL(handler, PostVisitConst(Ref(expr), Ref(const_expr))).Times(1); + + AstRewrite(expr, handler); +} + +TEST(AstCrawlerTest, CheckCrawlIdent) { + MockAstRewriter handler; + + Expr expr; + auto& ident_expr = expr.mutable_ident_expr(); + + EXPECT_CALL(handler, PostVisitIdent(Ref(expr), Ref(ident_expr))).Times(1); + + AstRewrite(expr, handler); +} + +// Test handling of Select node when operand is not set. +TEST(AstCrawlerTest, CheckCrawlSelectNotCrashingPostVisitAbsentOperand) { + MockAstRewriter handler; + + Expr expr; + auto& select_expr = expr.mutable_select_expr(); + + // Lowest level entry will be called first + EXPECT_CALL(handler, PostVisitSelect(Ref(expr), Ref(select_expr))).Times(1); + + AstRewrite(expr, handler); +} + +// Test handling of Select node +TEST(AstCrawlerTest, CheckCrawlSelect) { + MockAstRewriter handler; + + Expr expr; + auto& select_expr = expr.mutable_select_expr(); + auto& operand = select_expr.mutable_operand(); + auto& ident_expr = operand.mutable_ident_expr(); + + testing::InSequence seq; + + // Lowest level entry will be called first + EXPECT_CALL(handler, PostVisitIdent(Ref(operand), Ref(ident_expr))).Times(1); + EXPECT_CALL(handler, PostVisitSelect(Ref(expr), Ref(select_expr))).Times(1); + + AstRewrite(expr, handler); +} + +// Test handling of Call node without receiver +TEST(AstCrawlerTest, CheckCrawlCallNoReceiver) { + MockAstRewriter handler; + + // (, ) + Expr expr; + auto& call_expr = expr.mutable_call_expr(); + call_expr.mutable_args().reserve(2); + Expr& arg0 = call_expr.mutable_args().emplace_back(); + auto& const_expr = arg0.mutable_const_expr(); + Expr& arg1 = call_expr.mutable_args().emplace_back(); + auto& ident_expr = arg1.mutable_ident_expr(); + + testing::InSequence seq; + + // Lowest level entry will be called first + EXPECT_CALL(handler, PreVisitCall(Ref(expr), Ref(call_expr))).Times(1); + EXPECT_CALL(handler, PostVisitTarget(_)).Times(0); + + // Arg0 + EXPECT_CALL(handler, PostVisitConst(Ref(arg0), Ref(const_expr))).Times(1); + EXPECT_CALL(handler, PostVisitExpr(Ref(arg0))).Times(1); + EXPECT_CALL(handler, PostVisitArg(Ref(expr), 0)).Times(1); + + // Arg1 + EXPECT_CALL(handler, PostVisitIdent(Ref(arg1), Ref(ident_expr))).Times(1); + EXPECT_CALL(handler, PostVisitExpr(Ref(arg1))).Times(1); + EXPECT_CALL(handler, PostVisitArg(Ref(expr), 1)).Times(1); + + // Back to call + EXPECT_CALL(handler, PostVisitCall(Ref(expr), Ref(call_expr))).Times(1); + EXPECT_CALL(handler, PostVisitExpr(Ref(expr))).Times(1); + + AstRewrite(expr, handler); +} + +// Test handling of Call node with receiver +TEST(AstCrawlerTest, CheckCrawlCallReceiver) { + MockAstRewriter handler; + + // .(, ) + Expr expr; + auto& call_expr = expr.mutable_call_expr(); + Expr& target = call_expr.mutable_target(); + auto& target_ident = target.mutable_ident_expr(); + call_expr.mutable_args().reserve(2); + Expr& arg0 = call_expr.mutable_args().emplace_back(); + auto& const_expr = arg0.mutable_const_expr(); + Expr& arg1 = call_expr.mutable_args().emplace_back(); + auto& ident_expr = arg1.mutable_ident_expr(); + + testing::InSequence seq; + + // Lowest level entry will be called first + EXPECT_CALL(handler, PreVisitCall(Ref(expr), Ref(call_expr))).Times(1); + + // Target + EXPECT_CALL(handler, PostVisitIdent(Ref(target), Ref(target_ident))).Times(1); + EXPECT_CALL(handler, PostVisitExpr(Ref(target))).Times(1); + EXPECT_CALL(handler, PostVisitTarget(Ref(expr))).Times(1); + + // Arg0 + EXPECT_CALL(handler, PostVisitConst(Ref(arg0), Ref(const_expr))).Times(1); + EXPECT_CALL(handler, PostVisitExpr(Ref(arg0))).Times(1); + EXPECT_CALL(handler, PostVisitArg(Ref(expr), 0)).Times(1); + + // Arg1 + EXPECT_CALL(handler, PostVisitIdent(Ref(arg1), Ref(ident_expr))).Times(1); + EXPECT_CALL(handler, PostVisitExpr(Ref(arg1))).Times(1); + EXPECT_CALL(handler, PostVisitArg(Ref(expr), 1)).Times(1); + + // Back to call + EXPECT_CALL(handler, PostVisitCall(Ref(expr), Ref(call_expr))).Times(1); + EXPECT_CALL(handler, PostVisitExpr(Ref(expr))).Times(1); + + AstRewrite(expr, handler); +} + +// Test handling of Comprehension node +TEST(AstCrawlerTest, CheckCrawlComprehension) { + MockAstRewriter handler; + + Expr expr; + auto& c = expr.mutable_comprehension_expr(); + auto& iter_range = c.mutable_iter_range(); + auto& iter_range_expr = iter_range.mutable_const_expr(); + auto& accu_init = c.mutable_accu_init(); + auto& accu_init_expr = accu_init.mutable_ident_expr(); + auto& loop_condition = c.mutable_loop_condition(); + auto& loop_condition_expr = loop_condition.mutable_const_expr(); + auto& loop_step = c.mutable_loop_step(); + auto& loop_step_expr = loop_step.mutable_ident_expr(); + auto& result = c.mutable_result(); + auto& result_expr = result.mutable_const_expr(); + + testing::InSequence seq; + + // Lowest level entry will be called first + EXPECT_CALL(handler, PreVisitComprehension(Ref(expr), Ref(c))).Times(1); + + EXPECT_CALL(handler, + PreVisitComprehensionSubexpression(Ref(expr), Ref(c), ITER_RANGE)) + .Times(1); + EXPECT_CALL(handler, PostVisitConst(Ref(iter_range), Ref(iter_range_expr))) + .Times(1); + EXPECT_CALL(handler, PostVisitComprehensionSubexpression(Ref(expr), Ref(c), + ITER_RANGE)) + .Times(1); + + // ACCU_INIT + EXPECT_CALL(handler, + PreVisitComprehensionSubexpression(Ref(expr), Ref(c), ACCU_INIT)) + .Times(1); + EXPECT_CALL(handler, PostVisitIdent(Ref(accu_init), Ref(accu_init_expr))) + .Times(1); + EXPECT_CALL(handler, + PostVisitComprehensionSubexpression(Ref(expr), Ref(c), ACCU_INIT)) + .Times(1); + + // LOOP CONDITION + EXPECT_CALL(handler, PreVisitComprehensionSubexpression(Ref(expr), Ref(c), + LOOP_CONDITION)) + .Times(1); + EXPECT_CALL(handler, + PostVisitConst(Ref(loop_condition), Ref(loop_condition_expr))) + .Times(1); + EXPECT_CALL(handler, PostVisitComprehensionSubexpression(Ref(expr), Ref(c), + LOOP_CONDITION)) + .Times(1); + + // LOOP STEP + EXPECT_CALL(handler, + PreVisitComprehensionSubexpression(Ref(expr), Ref(c), LOOP_STEP)) + .Times(1); + EXPECT_CALL(handler, PostVisitIdent(Ref(loop_step), Ref(loop_step_expr))) + .Times(1); + EXPECT_CALL(handler, + PostVisitComprehensionSubexpression(Ref(expr), Ref(c), LOOP_STEP)) + .Times(1); + + // RESULT + EXPECT_CALL(handler, + PreVisitComprehensionSubexpression(Ref(expr), Ref(c), RESULT)) + .Times(1); + + EXPECT_CALL(handler, PostVisitConst(Ref(result), Ref(result_expr))).Times(1); + + EXPECT_CALL(handler, + PostVisitComprehensionSubexpression(Ref(expr), Ref(c), RESULT)) + .Times(1); + + EXPECT_CALL(handler, PostVisitComprehension(Ref(expr), Ref(c))).Times(1); + + RewriteTraversalOptions opts; + opts.use_comprehension_callbacks = true; + AstRewrite(expr, handler, opts); +} + +// Test handling of Comprehension node +TEST(AstCrawlerTest, CheckCrawlComprehensionLegacyCallbacks) { + MockAstRewriter handler; + + Expr expr; + auto& c = expr.mutable_comprehension_expr(); + auto& iter_range = c.mutable_iter_range(); + auto& iter_range_expr = iter_range.mutable_const_expr(); + auto& accu_init = c.mutable_accu_init(); + auto& accu_init_expr = accu_init.mutable_ident_expr(); + auto& loop_condition = c.mutable_loop_condition(); + auto& loop_condition_expr = loop_condition.mutable_const_expr(); + auto& loop_step = c.mutable_loop_step(); + auto& loop_step_expr = loop_step.mutable_ident_expr(); + auto& result = c.mutable_result(); + auto& result_expr = result.mutable_const_expr(); + + testing::InSequence seq; + + // Lowest level entry will be called first + EXPECT_CALL(handler, PreVisitComprehension(Ref(expr), Ref(c))).Times(1); + + EXPECT_CALL(handler, PostVisitConst(Ref(iter_range), Ref(iter_range_expr))) + .Times(1); + EXPECT_CALL(handler, PostVisitArg(Ref(expr), ITER_RANGE)).Times(1); + + // ACCU_INIT + EXPECT_CALL(handler, PostVisitIdent(Ref(accu_init), Ref(accu_init_expr))) + .Times(1); + EXPECT_CALL(handler, PostVisitArg(Ref(expr), ACCU_INIT)).Times(1); + + // LOOP CONDITION + EXPECT_CALL(handler, + PostVisitConst(Ref(loop_condition), Ref(loop_condition_expr))) + .Times(1); + EXPECT_CALL(handler, PostVisitArg(Ref(expr), LOOP_CONDITION)).Times(1); + + // LOOP STEP + EXPECT_CALL(handler, PostVisitIdent(Ref(loop_step), Ref(loop_step_expr))) + .Times(1); + EXPECT_CALL(handler, PostVisitArg(Ref(expr), LOOP_STEP)).Times(1); + + // RESULT + EXPECT_CALL(handler, PostVisitConst(Ref(result), Ref(result_expr))).Times(1); + EXPECT_CALL(handler, PostVisitArg(Ref(expr), RESULT)).Times(1); + + EXPECT_CALL(handler, PostVisitComprehension(Ref(expr), Ref(c))).Times(1); + + AstRewrite(expr, handler); +} + +// Test handling of List node. +TEST(AstCrawlerTest, CheckList) { + MockAstRewriter handler; + + Expr expr; + auto& list_expr = expr.mutable_list_expr(); + list_expr.mutable_elements().reserve(2); + auto& arg0 = list_expr.mutable_elements().emplace_back().mutable_expr(); + auto& const_expr = arg0.mutable_const_expr(); + auto& arg1 = list_expr.mutable_elements().emplace_back().mutable_expr(); + auto& ident_expr = arg1.mutable_ident_expr(); + + testing::InSequence seq; + + EXPECT_CALL(handler, PostVisitConst(Ref(arg0), Ref(const_expr))).Times(1); + EXPECT_CALL(handler, PostVisitIdent(Ref(arg1), Ref(ident_expr))).Times(1); + EXPECT_CALL(handler, PostVisitList(Ref(expr), Ref(list_expr))).Times(1); + + AstRewrite(expr, handler); +} + +// Test handling of Struct node. +TEST(AstCrawlerTest, CheckStruct) { + MockAstRewriter handler; + + Expr expr; + auto& struct_expr = expr.mutable_struct_expr(); + auto& entry0 = struct_expr.mutable_fields().emplace_back(); + + auto& value = entry0.mutable_value().mutable_ident_expr(); + + testing::InSequence seq; + + EXPECT_CALL(handler, PostVisitIdent(Ref(entry0.value()), Ref(value))) + .Times(1); + EXPECT_CALL(handler, PostVisitStruct(Ref(expr), Ref(struct_expr))).Times(1); + + AstRewrite(expr, handler); +} + +// Test handling of Map node. +TEST(AstCrawlerTest, CheckMap) { + MockAstRewriter handler; + + Expr expr; + auto& map_expr = expr.mutable_map_expr(); + auto& entry0 = map_expr.mutable_entries().emplace_back(); + + auto& key = entry0.mutable_key().mutable_const_expr(); + auto& value = entry0.mutable_value().mutable_ident_expr(); + + testing::InSequence seq; + + EXPECT_CALL(handler, PostVisitConst(Ref(entry0.key()), Ref(key))).Times(1); + EXPECT_CALL(handler, PostVisitIdent(Ref(entry0.value()), Ref(value))) + .Times(1); + EXPECT_CALL(handler, PostVisitMap(Ref(expr), Ref(map_expr))).Times(1); + + AstRewrite(expr, handler); +} + +// Test generic Expr handlers. +TEST(AstCrawlerTest, CheckExprHandlers) { + MockAstRewriter handler; + + Expr expr; + auto& map_expr = expr.mutable_map_expr(); + auto& entry0 = map_expr.mutable_entries().emplace_back(); + + entry0.mutable_key().mutable_const_expr(); + entry0.mutable_value().mutable_ident_expr(); + + EXPECT_CALL(handler, PreVisitExpr(_)).Times(3); + EXPECT_CALL(handler, PostVisitExpr(_)).Times(3); + + AstRewrite(expr, handler); +} + +// Test generic Expr handlers. +TEST(AstCrawlerTest, CheckExprRewriteHandlers) { + MockAstRewriter handler; + + Expr select_expr; + select_expr.mutable_select_expr().set_field("var"); + auto& inner_select_expr = select_expr.mutable_select_expr().mutable_operand(); + inner_select_expr.mutable_select_expr().set_field("mid"); + auto& ident = inner_select_expr.mutable_select_expr().mutable_operand(); + ident.mutable_ident_expr().set_name("top"); + + { + InSequence sequence; + EXPECT_CALL(handler, + TraversalStackUpdate(testing::ElementsAre(&select_expr))); + EXPECT_CALL(handler, PreVisitRewrite(Ref(select_expr))); + + EXPECT_CALL(handler, TraversalStackUpdate(testing::ElementsAre( + &select_expr, &inner_select_expr))); + EXPECT_CALL(handler, PreVisitRewrite(Ref(inner_select_expr))); + + EXPECT_CALL(handler, TraversalStackUpdate(testing::ElementsAre( + &select_expr, &inner_select_expr, &ident))); + EXPECT_CALL(handler, PreVisitRewrite(Ref(ident))); + + EXPECT_CALL(handler, PostVisitRewrite(Ref(ident))); + EXPECT_CALL(handler, TraversalStackUpdate(testing::ElementsAre( + &select_expr, &inner_select_expr))); + + EXPECT_CALL(handler, PostVisitRewrite(Ref(inner_select_expr))); + EXPECT_CALL(handler, + TraversalStackUpdate(testing::ElementsAre(&select_expr))); + + EXPECT_CALL(handler, PostVisitRewrite(Ref(select_expr))); + EXPECT_CALL(handler, TraversalStackUpdate(testing::IsEmpty())); + } + + EXPECT_FALSE(AstRewrite(select_expr, handler)); +} + +// Simple rewrite that replaces a select path with a dot-qualified identifier. +class RewriterExample : public AstRewriterBase { + public: + RewriterExample() {} + bool PostVisitRewrite(Expr& expr) override { + if (target_.has_value() && expr.id() == *target_) { + expr.mutable_ident_expr().set_name("com.google.Identifier"); + return true; + } + return false; + } + + void PostVisitIdent(const Expr& expr, const IdentExpr& ident) override { + if (path_.size() >= 3) { + if (ident.name() == "com") { + const Expr* p1 = path_.at(path_.size() - 2); + const Expr* p2 = path_.at(path_.size() - 3); + + if (p1->has_select_expr() && p1->select_expr().field() == "google" && + p2->has_select_expr() && + p2->select_expr().field() == "Identifier") { + target_ = p2->id(); + } + } + } + } + + void TraversalStackUpdate(absl::Span path) override { + path_ = path; + } + + private: + absl::Span path_; + absl::optional target_; +}; + +TEST(AstRewrite, SelectRewriteExample) { + ASSERT_OK_AND_ASSIGN( + std::unique_ptr ast, + CreateAstFromParsedExpr( + google::api::expr::parser::Parse("com.google.Identifier").value())); + RewriterExample example; + ASSERT_TRUE(AstRewrite(ast->mutable_root_expr(), example)); + + cel::expr::Expr expected_expr; + google::protobuf::TextFormat::ParseFromString( + R"pb( + id: 3 + ident_expr { name: "com.google.Identifier" } + )pb", + &expected_expr); + + cel::Expr expected_native; + ASSERT_THAT(ExprFromProto(expected_expr, expected_native), IsOk()); + + EXPECT_EQ(ast->root_expr(), expected_native); +} + +// Rewrites x -> y -> z to demonstrate traversal when a node is rewritten on +// both passes. +class PreRewriterExample : public AstRewriterBase { + public: + PreRewriterExample() {} + bool PreVisitRewrite(Expr& expr) override { + if (expr.ident_expr().name() == "x") { + expr.mutable_ident_expr().set_name("y"); + return true; + } + return false; + } + + bool PostVisitRewrite(Expr& expr) override { + if (expr.ident_expr().name() == "y") { + expr.mutable_ident_expr().set_name("z"); + return true; + } + return false; + } + + void PostVisitIdent(const Expr& expr, const IdentExpr& ident) override { + visited_idents_.push_back(ident.name()); + } + + const std::vector& visited_idents() const { + return visited_idents_; + } + + private: + std::vector visited_idents_; +}; + +TEST(AstRewrite, PreAndPostVisitExpample) { + ASSERT_OK_AND_ASSIGN( + std::unique_ptr ast, + CreateAstFromParsedExpr(google::api::expr::parser::Parse("x").value())); + PreRewriterExample visitor; + ASSERT_TRUE(AstRewrite(ast->mutable_root_expr(), visitor)); + + cel::expr::Expr expected_expr; + google::protobuf::TextFormat::ParseFromString( + R"pb( + id: 1 + ident_expr { name: "z" } + )pb", + &expected_expr); + cel::Expr expected_native; + ASSERT_THAT(ExprFromProto(expected_expr, expected_native), IsOk()); + + EXPECT_EQ(ast->root_expr(), expected_native); + EXPECT_THAT(visitor.visited_idents(), ElementsAre("y")); +} + +} // namespace + +} // namespace cel diff --git a/common/ast_test.cc b/common/ast_test.cc new file mode 100644 index 000000000..56e1bcd1e --- /dev/null +++ b/common/ast_test.cc @@ -0,0 +1,188 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/ast.h" + +#include + +#include "absl/container/flat_hash_map.h" +#include "common/expr.h" +#include "common/source.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +using ::testing::Pointee; +using ::testing::Truly; + +TEST(AstImpl, RawExprCtor) { + // arrange + // make ast for 2 + 1 == 3 + Expr expr; + auto& call = expr.mutable_call_expr(); + expr.set_id(5); + call.set_function("_==_"); + auto& eq_lhs = call.mutable_args().emplace_back(); + eq_lhs.mutable_call_expr().set_function("_+_"); + eq_lhs.set_id(3); + auto& sum_lhs = eq_lhs.mutable_call_expr().mutable_args().emplace_back(); + sum_lhs.mutable_const_expr().set_int_value(2); + sum_lhs.set_id(1); + auto& sum_rhs = eq_lhs.mutable_call_expr().mutable_args().emplace_back(); + sum_rhs.mutable_const_expr().set_int_value(1); + sum_rhs.set_id(2); + auto& eq_rhs = call.mutable_args().emplace_back(); + eq_rhs.mutable_const_expr().set_int_value(3); + eq_rhs.set_id(4); + + SourceInfo source_info; + source_info.mutable_positions()[5] = 6; + + // act + Ast ast(std::move(expr), std::move(source_info)); + + // assert + ASSERT_FALSE(ast.is_checked()); + EXPECT_EQ(ast.GetTypeOrDyn(1), TypeSpec(DynTypeSpec())); + EXPECT_EQ(ast.GetReturnType(), TypeSpec(DynTypeSpec())); + EXPECT_EQ(ast.GetReference(1), nullptr); + EXPECT_TRUE(ast.root_expr().has_call_expr()); + EXPECT_EQ(ast.root_expr().call_expr().function(), "_==_"); + EXPECT_EQ(ast.root_expr().id(), 5); // Parser IDs leaf to root. + EXPECT_EQ(ast.source_info().positions().at(5), 6); // start pos of == +} + +TEST(AstImpl, CheckedExprCtor) { + Expr expr; + expr.mutable_ident_expr().set_name("int_value"); + expr.set_id(1); + Reference ref; + ref.set_name("com.int_value"); + Ast::ReferenceMap reference_map; + reference_map[1] = Reference(ref); + Ast::TypeMap type_map; + type_map[1] = TypeSpec(PrimitiveType::kInt64); + SourceInfo source_info; + source_info.set_syntax_version("1.0"); + + Ast ast(std::move(expr), std::move(source_info), std::move(reference_map), + std::move(type_map), "1.0"); + + ASSERT_TRUE(ast.is_checked()); + EXPECT_EQ(ast.GetTypeOrDyn(1), TypeSpec(PrimitiveType::kInt64)); + EXPECT_THAT(ast.GetReference(1), Pointee(Truly([&ref](const Reference& arg) { + return arg.name() == ref.name(); + }))); + EXPECT_EQ(ast.GetReturnType(), TypeSpec(PrimitiveType::kInt64)); + EXPECT_TRUE(ast.root_expr().has_ident_expr()); + EXPECT_EQ(ast.root_expr().ident_expr().name(), "int_value"); + EXPECT_EQ(ast.root_expr().id(), 1); + EXPECT_EQ(ast.source_info().syntax_version(), "1.0"); + EXPECT_EQ(ast.expr_version(), "1.0"); +} + +TEST(AstImpl, CheckedExprDeepCopy) { + Expr root; + root.set_id(3); + root.mutable_call_expr().set_function("_==_"); + root.mutable_call_expr().mutable_args().resize(2); + auto& lhs = root.mutable_call_expr().mutable_args()[0]; + auto& rhs = root.mutable_call_expr().mutable_args()[1]; + Ast::TypeMap type_map; + Ast::ReferenceMap reference_map; + SourceInfo source_info; + + type_map[3] = TypeSpec(PrimitiveType::kBool); + + lhs.mutable_ident_expr().set_name("int_value"); + lhs.set_id(1); + Reference ref; + ref.set_name("com.int_value"); + reference_map[1] = std::move(ref); + type_map[1] = TypeSpec(PrimitiveType::kInt64); + + rhs.mutable_const_expr().set_int_value(2); + rhs.set_id(2); + type_map[2] = TypeSpec(PrimitiveType::kInt64); + source_info.set_syntax_version("1.0"); + + Ast ast(std::move(root), std::move(source_info), std::move(reference_map), + std::move(type_map), "1.0"); + + ASSERT_TRUE(ast.IsChecked()); + EXPECT_EQ(ast.GetTypeOrDyn(1), TypeSpec(PrimitiveType::kInt64)); + EXPECT_THAT(ast.GetReference(1), Pointee(Truly([](const Reference& arg) { + return arg.name() == "com.int_value"; + }))); + EXPECT_EQ(ast.GetReturnType(), TypeSpec(PrimitiveType::kBool)); + EXPECT_TRUE(ast.root_expr().has_call_expr()); + EXPECT_EQ(ast.root_expr().call_expr().function(), "_==_"); + EXPECT_EQ(ast.root_expr().id(), 3); + EXPECT_EQ(ast.source_info().syntax_version(), "1.0"); +} + +TEST(AstImpl, ComputeSourceLocation) { + SourceInfo source_info; + source_info.set_line_offsets({10, 20, 30}); + source_info.mutable_positions()[1] = 0; // Start of first line + source_info.mutable_positions()[2] = 5; // Middle of first line + source_info.mutable_positions()[3] = 10; // ... + source_info.mutable_positions()[4] = 15; + source_info.mutable_positions()[5] = 20; + source_info.mutable_positions()[6] = 25; + + Ast ast(Expr{}, std::move(source_info)); + + EXPECT_EQ(ast.ComputeSourceLocation(1), (SourceLocation{1, 0})); + EXPECT_EQ(ast.ComputeSourceLocation(2), (SourceLocation{1, 5})); + EXPECT_EQ(ast.ComputeSourceLocation(3), (SourceLocation{2, 0})); + EXPECT_EQ(ast.ComputeSourceLocation(4), (SourceLocation{2, 5})); + EXPECT_EQ(ast.ComputeSourceLocation(5), (SourceLocation{3, 0})); + EXPECT_EQ(ast.ComputeSourceLocation(6), (SourceLocation{3, 5})); +} + +TEST(AstImpl, ComputeSourceLocationFailures) { + SourceInfo source_info; + source_info.set_line_offsets({10, 20}); + source_info.mutable_positions()[1] = -1; // Negative position + source_info.mutable_positions()[2] = 25; // Beyond last line offset + // ID 3 is missing + + Ast ast; + ast.mutable_source_info() = std::move(source_info); + + EXPECT_EQ(ast.ComputeSourceLocation(1), SourceLocation{}); + EXPECT_EQ(ast.ComputeSourceLocation(2), SourceLocation{}); + EXPECT_EQ(ast.ComputeSourceLocation(3), SourceLocation{}); +} + +TEST(AstImpl, ComputeSourceLocationInvalidLineOffsets) { + { + // Empty line offsets + Ast ast; + EXPECT_EQ(ast.ComputeSourceLocation(1), SourceLocation{}); + } + { + // Non-monotonic + SourceInfo source_info; + source_info.set_line_offsets({10, 5}); + source_info.mutable_positions()[1] = 12; + Ast ast(Expr{}, std::move(source_info)); + EXPECT_EQ(ast.ComputeSourceLocation(1), SourceLocation{}); + } +} + +} // namespace +} // namespace cel diff --git a/common/ast_traverse.cc b/common/ast_traverse.cc new file mode 100644 index 000000000..fb4f9731e --- /dev/null +++ b/common/ast_traverse.cc @@ -0,0 +1,380 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/ast_traverse.h" + +#include +#include + +#include "absl/log/absl_log.h" +#include "absl/types/variant.h" +#include "common/ast_visitor.h" +#include "common/constant.h" +#include "common/expr.h" + +namespace cel { + +namespace { + +struct ArgRecord { + // Not null. + const Expr* expr; + + // For records that are direct arguments to call, we need to call + // the CallArg visitor immediately after the argument is evaluated. + const Expr* calling_expr; + int call_arg; +}; + +struct ComprehensionRecord { + // Not null. + const Expr* expr; + + const ComprehensionExpr* comprehension; + const Expr* comprehension_expr; + ComprehensionArg comprehension_arg; + bool use_comprehension_callbacks; +}; + +struct ExprRecord { + // Not null. + const Expr* expr; +}; + +using StackRecordKind = + std::variant; + +struct StackRecord { + public: + static constexpr int kTarget = -2; + + explicit StackRecord(const Expr* e) { + ExprRecord record; + record.expr = e; + record_variant = record; + } + + StackRecord(const Expr* e, const ComprehensionExpr* comprehension, + const Expr* comprehension_expr, + ComprehensionArg comprehension_arg, + bool use_comprehension_callbacks) { + if (use_comprehension_callbacks) { + ComprehensionRecord record; + record.expr = e; + record.comprehension = comprehension; + record.comprehension_expr = comprehension_expr; + record.comprehension_arg = comprehension_arg; + record.use_comprehension_callbacks = use_comprehension_callbacks; + record_variant = record; + return; + } + ArgRecord record; + record.expr = e; + record.calling_expr = comprehension_expr; + record.call_arg = comprehension_arg; + record_variant = record; + } + + StackRecord(const Expr* e, const Expr* call, int argnum) { + ArgRecord record; + record.expr = e; + record.calling_expr = call; + record.call_arg = argnum; + record_variant = record; + } + StackRecordKind record_variant; + bool visited = false; +}; + +struct PreVisitor { + void operator()(const ExprRecord& record) { + const Expr* expr = record.expr; + visitor->PreVisitExpr(*expr); + if (expr->has_select_expr()) { + visitor->PreVisitSelect(*expr, expr->select_expr()); + } else if (expr->has_call_expr()) { + visitor->PreVisitCall(*expr, expr->call_expr()); + } else if (expr->has_comprehension_expr()) { + visitor->PreVisitComprehension(*expr, expr->comprehension_expr()); + } else { + // No pre-visit action. + } + } + + // Do nothing for Arg variant. + void operator()(const ArgRecord&) {} + + void operator()(const ComprehensionRecord& record) { + visitor->PreVisitComprehensionSubexpression(*record.comprehension_expr, + *record.comprehension, + record.comprehension_arg); + } + + AstVisitor* visitor; +}; + +void PreVisit(const StackRecord& record, AstVisitor* visitor) { + absl::visit(PreVisitor{visitor}, record.record_variant); +} + +struct PostVisitor { + void operator()(const ExprRecord& record) { + const Expr* expr = record.expr; + struct { + AstVisitor* visitor; + const Expr* expr; + void operator()(const Constant& constant) { + visitor->PostVisitConst(*expr, expr->const_expr()); + } + void operator()(const IdentExpr& ident) { + visitor->PostVisitIdent(*expr, expr->ident_expr()); + } + void operator()(const SelectExpr& select) { + visitor->PostVisitSelect(*expr, expr->select_expr()); + } + void operator()(const CallExpr& call) { + visitor->PostVisitCall(*expr, expr->call_expr()); + } + void operator()(const ListExpr& create_list) { + visitor->PostVisitList(*expr, expr->list_expr()); + } + void operator()(const StructExpr& create_struct) { + visitor->PostVisitStruct(*expr, expr->struct_expr()); + } + void operator()(const MapExpr& map_expr) { + visitor->PostVisitMap(*expr, expr->map_expr()); + } + void operator()(const ComprehensionExpr& comprehension) { + visitor->PostVisitComprehension(*expr, expr->comprehension_expr()); + } + void operator()(const UnspecifiedExpr&) { + ABSL_LOG(ERROR) << "Unsupported Expr kind"; + } + } handler{visitor, record.expr}; + absl::visit(handler, record.expr->kind()); + + visitor->PostVisitExpr(*expr); + } + + void operator()(const ArgRecord& record) { + if (record.call_arg == StackRecord::kTarget) { + visitor->PostVisitTarget(*record.calling_expr); + } else { + visitor->PostVisitArg(*record.calling_expr, record.call_arg); + } + } + + void operator()(const ComprehensionRecord& record) { + visitor->PostVisitComprehensionSubexpression(*record.comprehension_expr, + *record.comprehension, + record.comprehension_arg); + } + + AstVisitor* visitor; +}; + +void PostVisit(const StackRecord& record, AstVisitor* visitor) { + absl::visit(PostVisitor{visitor}, record.record_variant); +} + +void PushSelectDeps(const SelectExpr* select_expr, + std::stack* stack) { + if (select_expr->has_operand()) { + stack->push(StackRecord(&select_expr->operand())); + } +} + +void PushCallDeps(const CallExpr* call_expr, const Expr* expr, + std::stack* stack) { + const int arg_size = call_expr->args().size(); + // Our contract is that we visit arguments in order. To do that, we need + // to push them onto the stack in reverse order. + for (int i = arg_size - 1; i >= 0; --i) { + stack->push(StackRecord(&call_expr->args()[i], expr, i)); + } + // Are we receiver-style? + if (call_expr->has_target()) { + stack->push(StackRecord(&call_expr->target(), expr, StackRecord::kTarget)); + } +} + +void PushListDeps(const ListExpr* list_expr, std::stack* stack) { + const auto& elements = list_expr->elements(); + for (auto it = elements.rbegin(); it != elements.rend(); ++it) { + const auto& element = *it; + stack->push(StackRecord(&element.expr())); + } +} + +void PushStructDeps(const StructExpr* struct_expr, + std::stack* stack) { + const auto& entries = struct_expr->fields(); + for (auto it = entries.rbegin(); it != entries.rend(); ++it) { + const auto& entry = *it; + // The contract is to visit key, then value. So put them on the stack + // in the opposite order. + if (entry.has_value()) { + stack->push(StackRecord(&entry.value())); + } + } +} + +void PushMapDeps(const MapExpr* map_expr, std::stack* stack) { + const auto& entries = map_expr->entries(); + for (auto it = entries.rbegin(); it != entries.rend(); ++it) { + const auto& entry = *it; + // The contract is to visit key, then value. So put them on the stack + // in the opposite order. + if (entry.has_value()) { + stack->push(StackRecord(&entry.value())); + } + // The contract is to visit key, then value. So put them on the stack + // in the opposite order. + if (entry.has_key()) { + stack->push(StackRecord(&entry.key())); + } + } +} + +void PushComprehensionDeps(const ComprehensionExpr* c, const Expr* expr, + std::stack* stack, + bool use_comprehension_callbacks) { + StackRecord iter_range(&c->iter_range(), c, expr, ITER_RANGE, + use_comprehension_callbacks); + StackRecord accu_init(&c->accu_init(), c, expr, ACCU_INIT, + use_comprehension_callbacks); + StackRecord loop_condition(&c->loop_condition(), c, expr, LOOP_CONDITION, + use_comprehension_callbacks); + StackRecord loop_step(&c->loop_step(), c, expr, LOOP_STEP, + use_comprehension_callbacks); + StackRecord result(&c->result(), c, expr, RESULT, + use_comprehension_callbacks); + // Push them in reverse order. + stack->push(result); + stack->push(loop_step); + stack->push(loop_condition); + stack->push(accu_init); + stack->push(iter_range); +} + +struct PushDepsVisitor { + void operator()(const ExprRecord& record) { + struct { + std::stack& stack; + const TraversalOptions& options; + const ExprRecord& record; + void operator()(const Constant& constant) {} + void operator()(const IdentExpr& ident) {} + void operator()(const SelectExpr& select) { + PushSelectDeps(&record.expr->select_expr(), &stack); + } + void operator()(const CallExpr& call) { + PushCallDeps(&record.expr->call_expr(), record.expr, &stack); + } + void operator()(const ListExpr& create_list) { + PushListDeps(&record.expr->list_expr(), &stack); + } + void operator()(const StructExpr& create_struct) { + PushStructDeps(&record.expr->struct_expr(), &stack); + } + void operator()(const MapExpr& map_expr) { + PushMapDeps(&record.expr->map_expr(), &stack); + } + void operator()(const ComprehensionExpr& comprehension) { + PushComprehensionDeps(&record.expr->comprehension_expr(), record.expr, + &stack, options.use_comprehension_callbacks); + } + void operator()(const UnspecifiedExpr&) {} + } handler{stack, options, record}; + absl::visit(handler, record.expr->kind()); + } + + void operator()(const ArgRecord& record) { + stack.push(StackRecord(record.expr)); + } + + void operator()(const ComprehensionRecord& record) { + stack.push(StackRecord(record.expr)); + } + + std::stack& stack; + const TraversalOptions& options; +}; + +void PushDependencies(const StackRecord& record, std::stack& stack, + const TraversalOptions& options) { + absl::visit(PushDepsVisitor{stack, options}, record.record_variant); +} + +} // namespace + +namespace common_internal { +struct AstTraversalState { + std::stack stack; +}; +} // namespace common_internal + +AstTraversal AstTraversal::Create(const cel::Expr& ast, + const TraversalOptions& options) { + AstTraversal instance(options); + instance.state_ = std::make_unique(); + instance.state_->stack.push(StackRecord(&ast)); + return instance; +} + +AstTraversal::AstTraversal(TraversalOptions options) : options_(options) {} + +AstTraversal::~AstTraversal() = default; + +bool AstTraversal::Step(AstVisitor& visitor) { + if (IsDone()) { + return false; + } + auto& stack = state_->stack; + StackRecord& record = stack.top(); + if (!record.visited) { + PreVisit(record, &visitor); + PushDependencies(record, stack, options_); + record.visited = true; + } else { + PostVisit(record, &visitor); + stack.pop(); + } + + return !stack.empty(); +} + +bool AstTraversal::IsDone() { + return state_ == nullptr || state_->stack.empty(); +} + +void AstTraverse(const Expr& expr, AstVisitor& visitor, + TraversalOptions options) { + std::stack stack; + stack.push(StackRecord(&expr)); + + while (!stack.empty()) { + StackRecord& record = stack.top(); + if (!record.visited) { + PreVisit(record, &visitor); + PushDependencies(record, stack, options); + record.visited = true; + } else { + PostVisit(record, &visitor); + stack.pop(); + } + } +} + +} // namespace cel diff --git a/common/ast_traverse.h b/common/ast_traverse.h new file mode 100644 index 000000000..004727e49 --- /dev/null +++ b/common/ast_traverse.h @@ -0,0 +1,107 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_AST_TRAVERSE_NATIVE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_AST_TRAVERSE_NATIVE_H_ + +#include + +#include "absl/base/attributes.h" +#include "common/ast_visitor.h" +#include "common/expr.h" + +namespace cel { + +namespace common_internal { +struct AstTraversalState; +} + +struct TraversalOptions { + // Enable use of the comprehension specific callbacks. + bool use_comprehension_callbacks = false; +}; + +// Helper class for managing the traversal of the AST. +// Allows caller to step through the traversal. +// +// Usage: +// +// AstTraversal traversal = AstTraversal::Create(expr); +// +// MyVisitor visitor(); +// while(!traversal.IsDone()) { +// traversal.Step(visitor); +// } +// +// This class is thread-hostile and should only be used in synchronous code. +class AstTraversal { + public: + static AstTraversal Create(const cel::Expr& ast ABSL_ATTRIBUTE_LIFETIME_BOUND, + const TraversalOptions& options = {}); + + ~AstTraversal(); + + AstTraversal(const AstTraversal&) = delete; + AstTraversal& operator=(const AstTraversal&) = delete; + AstTraversal(AstTraversal&&) = default; + AstTraversal& operator=(AstTraversal&&) = default; + + // Advances the traversal. Returns true if there is more work to do. This is a + // no-op if the traversal is done and IsDone() is true. + bool Step(AstVisitor& visitor); + + // Returns true if there is no work left to do. + bool IsDone(); + + private: + explicit AstTraversal(TraversalOptions options); + TraversalOptions options_; + std::unique_ptr state_; +}; + +// Traverses the AST representation in an expr proto. +// +// expr: root node of the tree. +// source_info: optional additional parse information about the expression +// visitor: the callback object that receives the visitation notifications +// +// Traversal order follows the pattern: +// PreVisitExpr +// ..PreVisit{ExprKind} +// ....PreVisit{ArgumentIndex} +// .......PreVisitExpr (subtree) +// .......PostVisitExpr (subtree) +// ....PostVisit{ArgumentIndex} +// ..PostVisit{ExprKind} +// PostVisitExpr +// +// Example callback order for fn(1, var): +// PreVisitExpr +// ..PreVisitCall(fn) +// ......PreVisitExpr +// ........PostVisitConst(1) +// ......PostVisitExpr +// ....PostVisitArg(fn, 0) +// ......PreVisitExpr +// ........PostVisitIdent(var) +// ......PostVisitExpr +// ....PostVisitArg(fn, 1) +// ..PostVisitCall(fn) +// PostVisitExpr +void AstTraverse(const Expr& expr, AstVisitor& visitor, + TraversalOptions options = TraversalOptions()); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_AST_TRAVERSE_NATIVE_H_ diff --git a/common/ast_traverse_test.cc b/common/ast_traverse_test.cc new file mode 100644 index 000000000..16ee40ce0 --- /dev/null +++ b/common/ast_traverse_test.cc @@ -0,0 +1,478 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/ast_traverse.h" + +#include "common/ast_visitor.h" +#include "common/constant.h" +#include "common/expr.h" +#include "internal/testing.h" + +namespace cel::ast_internal { + +namespace { + +using ::testing::_; +using ::testing::Ref; + +class MockAstVisitor : public AstVisitor { + public: + // Expr handler. + MOCK_METHOD(void, PreVisitExpr, (const Expr& expr), (override)); + + // Expr handler. + MOCK_METHOD(void, PostVisitExpr, (const Expr& expr), (override)); + + MOCK_METHOD(void, PostVisitConst, + (const Expr& expr, const Constant& const_expr), (override)); + + // Ident node handler. + MOCK_METHOD(void, PostVisitIdent, + (const Expr& expr, const IdentExpr& ident_expr), (override)); + + // Select node handler group + MOCK_METHOD(void, PreVisitSelect, + (const Expr& expr, const SelectExpr& select_expr), (override)); + + MOCK_METHOD(void, PostVisitSelect, + (const Expr& expr, const SelectExpr& select_expr), (override)); + + // Call node handler group + MOCK_METHOD(void, PreVisitCall, (const Expr& expr, const CallExpr& call_expr), + (override)); + MOCK_METHOD(void, PostVisitCall, + (const Expr& expr, const CallExpr& call_expr), (override)); + + // Comprehension node handler group + MOCK_METHOD(void, PreVisitComprehension, + (const Expr& expr, const ComprehensionExpr& comprehension_expr), + (override)); + MOCK_METHOD(void, PostVisitComprehension, + (const Expr& expr, const ComprehensionExpr& comprehension_expr), + (override)); + + // Comprehension node handler group + MOCK_METHOD(void, PreVisitComprehensionSubexpression, + (const Expr& expr, const ComprehensionExpr& comprehension_expr, + ComprehensionArg comprehension_arg), + (override)); + MOCK_METHOD(void, PostVisitComprehensionSubexpression, + (const Expr& expr, const ComprehensionExpr& comprehension_expr, + ComprehensionArg comprehension_arg), + (override)); + + // We provide finer granularity for Call and Comprehension node callbacks + // to allow special handling for short-circuiting. + MOCK_METHOD(void, PostVisitTarget, (const Expr& expr), (override)); + MOCK_METHOD(void, PostVisitArg, (const Expr& expr, int arg_num), (override)); + + // List node handler group + MOCK_METHOD(void, PostVisitList, + (const Expr& expr, const ListExpr& list_expr), (override)); + + // Struct node handler group + MOCK_METHOD(void, PostVisitStruct, + (const Expr& expr, const StructExpr& struct_expr), (override)); + + // Map node handler group + MOCK_METHOD(void, PostVisitMap, (const Expr& expr, const MapExpr& map_expr), + (override)); +}; + +TEST(AstCrawlerTest, CheckCrawlConstant) { + MockAstVisitor handler; + + Expr expr; + auto& const_expr = expr.mutable_const_expr(); + + EXPECT_CALL(handler, PostVisitConst(Ref(expr), Ref(const_expr))).Times(1); + + AstTraverse(expr, handler); +} + +TEST(AstCrawlerTest, CheckCrawlIdent) { + MockAstVisitor handler; + + Expr expr; + auto& ident_expr = expr.mutable_ident_expr(); + + EXPECT_CALL(handler, PostVisitIdent(Ref(expr), Ref(ident_expr))).Times(1); + + AstTraverse(expr, handler); +} + +// Test handling of Select node when operand is not set. +TEST(AstCrawlerTest, CheckCrawlSelectNotCrashingPostVisitAbsentOperand) { + MockAstVisitor handler; + + Expr expr; + auto& select_expr = expr.mutable_select_expr(); + + // Lowest level entry will be called first + EXPECT_CALL(handler, PostVisitSelect(Ref(expr), Ref(select_expr))).Times(1); + + AstTraverse(expr, handler); +} + +// Test handling of Select node +TEST(AstCrawlerTest, CheckCrawlSelect) { + MockAstVisitor handler; + + Expr expr; + auto& select_expr = expr.mutable_select_expr(); + auto& operand = select_expr.mutable_operand(); + auto& ident_expr = operand.mutable_ident_expr(); + + testing::InSequence seq; + + // Lowest level entry will be called first + EXPECT_CALL(handler, PostVisitIdent(Ref(operand), Ref(ident_expr))).Times(1); + EXPECT_CALL(handler, PostVisitSelect(Ref(expr), Ref(select_expr))).Times(1); + + AstTraverse(expr, handler); +} + +// Test handling of Call node without receiver +TEST(AstCrawlerTest, CheckCrawlCallNoReceiver) { + MockAstVisitor handler; + + // (, ) + Expr expr; + auto& call_expr = expr.mutable_call_expr(); + call_expr.mutable_args().reserve(2); + auto& arg0 = call_expr.mutable_args().emplace_back(); + auto& const_expr = arg0.mutable_const_expr(); + auto& arg1 = call_expr.mutable_args().emplace_back(); + auto& ident_expr = arg1.mutable_ident_expr(); + + testing::InSequence seq; + + // Lowest level entry will be called first + EXPECT_CALL(handler, PreVisitCall(Ref(expr), Ref(call_expr))).Times(1); + EXPECT_CALL(handler, PostVisitTarget(_)).Times(0); + + // Arg0 + EXPECT_CALL(handler, PostVisitConst(Ref(arg0), Ref(const_expr))).Times(1); + EXPECT_CALL(handler, PostVisitExpr(Ref(arg0))).Times(1); + EXPECT_CALL(handler, PostVisitArg(Ref(expr), 0)).Times(1); + + // Arg1 + EXPECT_CALL(handler, PostVisitIdent(Ref(arg1), Ref(ident_expr))).Times(1); + EXPECT_CALL(handler, PostVisitExpr(Ref(arg1))).Times(1); + EXPECT_CALL(handler, PostVisitArg(Ref(expr), 1)).Times(1); + + // Back to call + EXPECT_CALL(handler, PostVisitCall(Ref(expr), Ref(call_expr))).Times(1); + EXPECT_CALL(handler, PostVisitExpr(Ref(expr))).Times(1); + + AstTraverse(expr, handler); +} + +// Test handling of Call node with receiver +TEST(AstCrawlerTest, CheckCrawlCallReceiver) { + MockAstVisitor handler; + + // .(, ) + Expr expr; + auto& call_expr = expr.mutable_call_expr(); + auto& target = call_expr.mutable_target(); + auto& target_ident = target.mutable_ident_expr(); + call_expr.mutable_args().reserve(2); + auto& arg0 = call_expr.mutable_args().emplace_back(); + auto& const_expr = arg0.mutable_const_expr(); + auto& arg1 = call_expr.mutable_args().emplace_back(); + auto& ident_expr = arg1.mutable_ident_expr(); + + testing::InSequence seq; + + // Lowest level entry will be called first + EXPECT_CALL(handler, PreVisitCall(Ref(expr), Ref(call_expr))).Times(1); + + // Target + EXPECT_CALL(handler, PostVisitIdent(Ref(target), Ref(target_ident))).Times(1); + EXPECT_CALL(handler, PostVisitExpr(Ref(target))).Times(1); + EXPECT_CALL(handler, PostVisitTarget(Ref(expr))).Times(1); + + // Arg0 + EXPECT_CALL(handler, PostVisitConst(Ref(arg0), Ref(const_expr))).Times(1); + EXPECT_CALL(handler, PostVisitExpr(Ref(arg0))).Times(1); + EXPECT_CALL(handler, PostVisitArg(Ref(expr), 0)).Times(1); + + // Arg1 + EXPECT_CALL(handler, PostVisitIdent(Ref(arg1), Ref(ident_expr))).Times(1); + EXPECT_CALL(handler, PostVisitExpr(Ref(arg1))).Times(1); + EXPECT_CALL(handler, PostVisitArg(Ref(expr), 1)).Times(1); + + // Back to call + EXPECT_CALL(handler, PostVisitCall(Ref(expr), Ref(call_expr))).Times(1); + EXPECT_CALL(handler, PostVisitExpr(Ref(expr))).Times(1); + + AstTraverse(expr, handler); +} + +// Test handling of Comprehension node +TEST(AstCrawlerTest, CheckCrawlComprehension) { + MockAstVisitor handler; + + Expr expr; + auto& c = expr.mutable_comprehension_expr(); + auto& iter_range = c.mutable_iter_range(); + auto& iter_range_expr = iter_range.mutable_const_expr(); + auto& accu_init = c.mutable_accu_init(); + auto& accu_init_expr = accu_init.mutable_ident_expr(); + auto& loop_condition = c.mutable_loop_condition(); + auto& loop_condition_expr = loop_condition.mutable_const_expr(); + auto& loop_step = c.mutable_loop_step(); + auto& loop_step_expr = loop_step.mutable_ident_expr(); + auto& result = c.mutable_result(); + auto& result_expr = result.mutable_const_expr(); + + testing::InSequence seq; + + // Lowest level entry will be called first + EXPECT_CALL(handler, PreVisitComprehension(Ref(expr), Ref(c))).Times(1); + + EXPECT_CALL(handler, + PreVisitComprehensionSubexpression(Ref(expr), Ref(c), ITER_RANGE)) + .Times(1); + EXPECT_CALL(handler, PostVisitConst(Ref(iter_range), Ref(iter_range_expr))) + .Times(1); + EXPECT_CALL(handler, PostVisitComprehensionSubexpression(Ref(expr), Ref(c), + ITER_RANGE)) + .Times(1); + + // ACCU_INIT + EXPECT_CALL(handler, + PreVisitComprehensionSubexpression(Ref(expr), Ref(c), ACCU_INIT)) + .Times(1); + EXPECT_CALL(handler, PostVisitIdent(Ref(accu_init), Ref(accu_init_expr))) + .Times(1); + EXPECT_CALL(handler, + PostVisitComprehensionSubexpression(Ref(expr), Ref(c), ACCU_INIT)) + .Times(1); + + // LOOP CONDITION + EXPECT_CALL(handler, PreVisitComprehensionSubexpression(Ref(expr), Ref(c), + LOOP_CONDITION)) + .Times(1); + EXPECT_CALL(handler, + PostVisitConst(Ref(loop_condition), Ref(loop_condition_expr))) + .Times(1); + EXPECT_CALL(handler, PostVisitComprehensionSubexpression(Ref(expr), Ref(c), + LOOP_CONDITION)) + .Times(1); + + // LOOP STEP + EXPECT_CALL(handler, + PreVisitComprehensionSubexpression(Ref(expr), Ref(c), LOOP_STEP)) + .Times(1); + EXPECT_CALL(handler, PostVisitIdent(Ref(loop_step), Ref(loop_step_expr))) + .Times(1); + EXPECT_CALL(handler, + PostVisitComprehensionSubexpression(Ref(expr), Ref(c), LOOP_STEP)) + .Times(1); + + // RESULT + EXPECT_CALL(handler, + PreVisitComprehensionSubexpression(Ref(expr), Ref(c), RESULT)) + .Times(1); + + EXPECT_CALL(handler, PostVisitConst(Ref(result), Ref(result_expr))).Times(1); + + EXPECT_CALL(handler, + PostVisitComprehensionSubexpression(Ref(expr), Ref(c), RESULT)) + .Times(1); + + EXPECT_CALL(handler, PostVisitComprehension(Ref(expr), Ref(c))).Times(1); + + TraversalOptions opts; + opts.use_comprehension_callbacks = true; + AstTraverse(expr, handler, opts); +} + +// Test handling of Comprehension node +TEST(AstCrawlerTest, CheckCrawlComprehensionLegacyCallbacks) { + MockAstVisitor handler; + + Expr expr; + auto& c = expr.mutable_comprehension_expr(); + auto& iter_range = c.mutable_iter_range(); + auto& iter_range_expr = iter_range.mutable_const_expr(); + auto& accu_init = c.mutable_accu_init(); + auto& accu_init_expr = accu_init.mutable_ident_expr(); + auto& loop_condition = c.mutable_loop_condition(); + auto& loop_condition_expr = loop_condition.mutable_const_expr(); + auto& loop_step = c.mutable_loop_step(); + auto& loop_step_expr = loop_step.mutable_ident_expr(); + auto& result = c.mutable_result(); + auto& result_expr = result.mutable_const_expr(); + + testing::InSequence seq; + + // Lowest level entry will be called first + EXPECT_CALL(handler, PreVisitComprehension(Ref(expr), Ref(c))).Times(1); + + EXPECT_CALL(handler, PostVisitConst(Ref(iter_range), Ref(iter_range_expr))) + .Times(1); + EXPECT_CALL(handler, PostVisitArg(Ref(expr), ITER_RANGE)).Times(1); + + // ACCU_INIT + EXPECT_CALL(handler, PostVisitIdent(Ref(accu_init), Ref(accu_init_expr))) + .Times(1); + EXPECT_CALL(handler, PostVisitArg(Ref(expr), ACCU_INIT)).Times(1); + + // LOOP CONDITION + EXPECT_CALL(handler, + PostVisitConst(Ref(loop_condition), Ref(loop_condition_expr))) + .Times(1); + EXPECT_CALL(handler, PostVisitArg(Ref(expr), LOOP_CONDITION)).Times(1); + + // LOOP STEP + EXPECT_CALL(handler, PostVisitIdent(Ref(loop_step), Ref(loop_step_expr))) + .Times(1); + EXPECT_CALL(handler, PostVisitArg(Ref(expr), LOOP_STEP)).Times(1); + + // RESULT + EXPECT_CALL(handler, PostVisitConst(Ref(result), Ref(result_expr))).Times(1); + EXPECT_CALL(handler, PostVisitArg(Ref(expr), RESULT)).Times(1); + + EXPECT_CALL(handler, PostVisitComprehension(Ref(expr), Ref(c))).Times(1); + + AstTraverse(expr, handler); +} + +// Test handling of List node. +TEST(AstCrawlerTest, CheckList) { + MockAstVisitor handler; + + Expr expr; + auto& list_expr = expr.mutable_list_expr(); + list_expr.mutable_elements().reserve(2); + auto& arg0 = list_expr.mutable_elements().emplace_back().mutable_expr(); + auto& const_expr = arg0.mutable_const_expr(); + auto& arg1 = list_expr.mutable_elements().emplace_back().mutable_expr(); + auto& ident_expr = arg1.mutable_ident_expr(); + + testing::InSequence seq; + + EXPECT_CALL(handler, PostVisitConst(Ref(arg0), Ref(const_expr))).Times(1); + EXPECT_CALL(handler, PostVisitIdent(Ref(arg1), Ref(ident_expr))).Times(1); + EXPECT_CALL(handler, PostVisitList(Ref(expr), Ref(list_expr))).Times(1); + + AstTraverse(expr, handler); +} + +// Test handling of Struct node. +TEST(AstCrawlerTest, CheckStruct) { + MockAstVisitor handler; + + Expr expr; + auto& struct_expr = expr.mutable_struct_expr(); + auto& entry0 = struct_expr.mutable_fields().emplace_back(); + + auto& value = entry0.mutable_value().mutable_ident_expr(); + + testing::InSequence seq; + + EXPECT_CALL(handler, PostVisitIdent(Ref(entry0.value()), Ref(value))) + .Times(1); + EXPECT_CALL(handler, PostVisitStruct(Ref(expr), Ref(struct_expr))).Times(1); + + AstTraverse(expr, handler); +} + +// Test handling of Map node. +TEST(AstCrawlerTest, CheckMap) { + MockAstVisitor handler; + + Expr expr; + auto& map_expr = expr.mutable_map_expr(); + auto& entry0 = map_expr.mutable_entries().emplace_back(); + + auto& key = entry0.mutable_key().mutable_const_expr(); + auto& value = entry0.mutable_value().mutable_ident_expr(); + + testing::InSequence seq; + + EXPECT_CALL(handler, PostVisitConst(Ref(entry0.key()), Ref(key))).Times(1); + EXPECT_CALL(handler, PostVisitIdent(Ref(entry0.value()), Ref(value))) + .Times(1); + EXPECT_CALL(handler, PostVisitMap(Ref(expr), Ref(map_expr))).Times(1); + + AstTraverse(expr, handler); +} + +// Test generic Expr handlers. +TEST(AstCrawlerTest, CheckExprHandlers) { + MockAstVisitor handler; + + Expr expr; + auto& map_expr = expr.mutable_map_expr(); + auto& entry0 = map_expr.mutable_entries().emplace_back(); + + entry0.mutable_key().mutable_const_expr(); + entry0.mutable_value().mutable_ident_expr(); + + EXPECT_CALL(handler, PreVisitExpr(_)).Times(3); + EXPECT_CALL(handler, PostVisitExpr(_)).Times(3); + + AstTraverse(expr, handler); +} + +TEST(AstTraversal, Interrupt) { + MockAstVisitor handler; + + Expr expr; + auto& select_expr = expr.mutable_select_expr(); + auto& operand = select_expr.mutable_operand(); + auto& ident_expr = operand.mutable_ident_expr(); + + testing::InSequence seq; + + auto traversal = AstTraversal::Create(expr); + + EXPECT_CALL(handler, PreVisitExpr(_)).Times(2); + + EXPECT_CALL(handler, PostVisitIdent(Ref(operand), Ref(ident_expr))).Times(1); + EXPECT_CALL(handler, PostVisitSelect(Ref(expr), Ref(select_expr))).Times(0); + + EXPECT_TRUE(traversal.Step(handler)); + EXPECT_TRUE(traversal.Step(handler)); + EXPECT_TRUE(traversal.Step(handler)); + + EXPECT_FALSE(traversal.IsDone()); +} + +TEST(AstTraversal, NoInterrupt) { + MockAstVisitor handler; + + Expr expr; + auto& select_expr = expr.mutable_select_expr(); + auto& operand = select_expr.mutable_operand(); + auto& ident_expr = operand.mutable_ident_expr(); + + testing::InSequence seq; + + auto traversal = AstTraversal::Create(expr); + + EXPECT_CALL(handler, PostVisitIdent(Ref(operand), Ref(ident_expr))).Times(1); + EXPECT_CALL(handler, PostVisitSelect(Ref(expr), Ref(select_expr))).Times(1); + + while (traversal.Step(handler)) continue; + EXPECT_TRUE(traversal.IsDone()); +} + +} // namespace + +} // namespace cel::ast_internal diff --git a/common/ast_visitor.h b/common/ast_visitor.h new file mode 100644 index 000000000..3e1f4929e --- /dev/null +++ b/common/ast_visitor.h @@ -0,0 +1,118 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_AST_VISITOR_NATIVE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_AST_VISITOR_NATIVE_H_ + +#include "common/constant.h" +#include "common/expr.h" + +namespace cel { + +// ComprehensionArg specifies arg_num values passed to PostVisitArg +// for subexpressions of Comprehension. +enum ComprehensionArg { + ITER_RANGE, + ACCU_INIT, + LOOP_CONDITION, + LOOP_STEP, + RESULT, +}; + +// Callback handler class, used in conjunction with AstTraverse. +// Methods of this class are invoked when AST nodes with corresponding +// types are processed. +// +// For all types with children, the children will be visited in the natural +// order from first to last. For structs, keys are visited before values. +class AstVisitor { + public: + virtual ~AstVisitor() = default; + + // Expr node handler method. Called for all Expr nodes. + // Is invoked before child Expr nodes being processed. + virtual void PreVisitExpr(const Expr&) = 0; + + // Expr node handler method. Called for all Expr nodes. + // Is invoked after child Expr nodes are processed. + virtual void PostVisitExpr(const Expr&) = 0; + + // Const node handler. + // Invoked after child nodes are processed. + virtual void PostVisitConst(const Expr&, const Constant&) = 0; + + // Ident node handler. + // Invoked after child nodes are processed. + virtual void PostVisitIdent(const Expr&, const IdentExpr&) = 0; + + // Select node handler + // Invoked before child nodes are processed. + virtual void PreVisitSelect(const Expr&, const SelectExpr&) = 0; + + // Select node handler + // Invoked after child nodes are processed. + virtual void PostVisitSelect(const Expr&, const SelectExpr&) = 0; + + // Call node handler group + // We provide finer granularity for Call node callbacks to allow special + // handling for short-circuiting + // PreVisitCall is invoked before child nodes are processed. + virtual void PreVisitCall(const Expr&, const CallExpr&) = 0; + + // Invoked after all child nodes are processed. + virtual void PostVisitCall(const Expr&, const CallExpr&) = 0; + + // Invoked after target node is processed. + // Expr is the call expression. + virtual void PostVisitTarget(const Expr&) = 0; + + // Invoked before all child nodes are processed. + virtual void PreVisitComprehension(const Expr&, const ComprehensionExpr&) = 0; + + // Invoked before comprehension child node is processed. + virtual void PreVisitComprehensionSubexpression( + const Expr&, const ComprehensionExpr& compr, + ComprehensionArg comprehension_arg) {} + + // Invoked after comprehension child node is processed. + virtual void PostVisitComprehensionSubexpression( + const Expr&, const ComprehensionExpr& compr, + ComprehensionArg comprehension_arg) {} + + // Invoked after all child nodes are processed. + virtual void PostVisitComprehension(const Expr&, + const ComprehensionExpr&) = 0; + + // Invoked after each argument node processed. + // For Call arg_num is the index of the argument. + // For Comprehension arg_num is specified by ComprehensionArg. + // Expr is the call expression. + virtual void PostVisitArg(const Expr&, int arg_num) = 0; + + // List node handler + // Invoked after child nodes are processed. + virtual void PostVisitList(const Expr&, const ListExpr&) = 0; + + // Struct node handler + // Invoked after child nodes are processed. + virtual void PostVisitStruct(const Expr&, const StructExpr&) = 0; + + // Map node handler + // Invoked after child nodes are processed. + virtual void PostVisitMap(const Expr&, const MapExpr&) = 0; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_AST_VISITOR_NATIVE_H_ diff --git a/common/ast_visitor_base.h b/common/ast_visitor_base.h new file mode 100644 index 000000000..e78d3f46c --- /dev/null +++ b/common/ast_visitor_base.h @@ -0,0 +1,88 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_AST_VISITOR_BASE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_AST_VISITOR_BASE_H_ + +#include "common/ast_visitor.h" +#include "common/constant.h" +#include "common/expr.h" + +namespace cel { + +// Trivial base implementation of AstVisitor. +class AstVisitorBase : public AstVisitor { + public: + AstVisitorBase() = default; + + // Non-copyable + AstVisitorBase(const AstVisitorBase&) = delete; + AstVisitorBase& operator=(AstVisitorBase const&) = delete; + + ~AstVisitorBase() override {} + + // Const node handler. + // Invoked after child nodes are processed. + void PostVisitConst(const Expr&, const Constant&) override {} + + // Ident node handler. + // Invoked after child nodes are processed. + void PostVisitIdent(const Expr&, const IdentExpr&) override {} + + void PreVisitSelect(const Expr&, const SelectExpr&) override {} + + // Select node handler + // Invoked after child nodes are processed. + void PostVisitSelect(const Expr&, const SelectExpr&) override {} + + // Call node handler group + // We provide finer granularity for Call node callbacks to allow special + // handling for short-circuiting + // PreVisitCall is invoked before child nodes are processed. + void PreVisitCall(const Expr&, const CallExpr&) override {} + + // Invoked after all child nodes are processed. + void PostVisitCall(const Expr&, const CallExpr&) override {} + + // Invoked before all child nodes are processed. + void PreVisitComprehension(const Expr&, const ComprehensionExpr&) override {} + + // Invoked after all child nodes are processed. + void PostVisitComprehension(const Expr&, const ComprehensionExpr&) override {} + + // Invoked after each argument node processed. + // For Call arg_num is the index of the argument. + // For Comprehension arg_num is specified by ComprehensionArg. + // Expr is the call expression. + void PostVisitArg(const Expr&, int) override {} + + // Invoked after target node processed. + void PostVisitTarget(const Expr&) override {} + + // List node handler + // Invoked after child nodes are processed. + void PostVisitList(const Expr&, const ListExpr&) override {} + + // Struct node handler + // Invoked after child nodes are processed. + void PostVisitStruct(const Expr&, const StructExpr&) override {} + + // Map node handler + // Invoked after child nodes are processed. + void PostVisitMap(const Expr&, const MapExpr&) override {} +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_AST_VISITOR_BASE_H_ diff --git a/common/casting.h b/common/casting.h new file mode 100644 index 000000000..69074d4d9 --- /dev/null +++ b/common/casting.h @@ -0,0 +1,75 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_CASTING_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_CASTING_H_ + +#include "absl/base/attributes.h" +#include "common/internal/casting.h" + +namespace cel { + +// `InstanceOf(const From&)` determines whether `From` holds or is `To`. +// +// `To` must be a plain non-union class type that is not qualified. +// +// We expose `InstanceOf` this way to avoid ADL. +// +// Example: +// +// if (InstanceOf(superclass)) { +// Cast(superclass).SomeMethod(); +// } +template +ABSL_DEPRECATED("Use Is member functions instead.") +inline constexpr common_internal::InstanceOfImpl InstanceOf{}; + +// `Cast(From)` is a "checked cast". In debug builds an assertion is emitted +// which verifies `From` is an instance-of `To`. In non-debug builds, invalid +// casts are undefined behavior. +// +// We expose `Cast` this way to avoid ADL. +// +// Example: +// +// if (InstanceOf(superclass)) { +// Cast(superclass).SomeMethod(); +// } +template +ABSL_DEPRECATED( + "Use explicit conversion functions instead through static_cast.") +inline constexpr common_internal::CastImpl Cast{}; + +// `As(From)` is a "checking cast". The result is explicitly convertible to +// `bool`, such that it can be used with `if` statements. The result can be +// accessed with `operator*` or `operator->`. The return type should be treated +// as an implementation detail, with no assumptions on the concrete type. You +// should use `auto`. +// +// `As` is analogous to the paradigm `if (InstanceOf(a)) Cast(a)`. +// +// We expose `As` this way to avoid ADL. +// +// Example: +// +// if (auto subclass = As(superclass); subclass) { +// subclass->SomeMethod(); +// } +template +ABSL_DEPRECATED("Use As member functions instead.") +inline constexpr common_internal::AsImpl As{}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_INSTANCE_OF_H_ diff --git a/common/constant.cc b/common/constant.cc new file mode 100644 index 000000000..f335fb535 --- /dev/null +++ b/common/constant.cc @@ -0,0 +1,101 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/constant.h" + +#include +#include +#include + +#include "absl/base/no_destructor.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "internal/strings.h" + +namespace cel { + +const BytesConstant& BytesConstant::default_instance() { + static const absl::NoDestructor instance; + return *instance; +} + +const StringConstant& StringConstant::default_instance() { + static const absl::NoDestructor instance; + return *instance; +} + +const Constant& Constant::default_instance() { + static const absl::NoDestructor instance; + return *instance; +} + +std::string FormatNullConstant() { return "null"; } + +std::string FormatBoolConstant(bool value) { + return value ? std::string("true") : std::string("false"); +} + +std::string FormatIntConstant(int64_t value) { return absl::StrCat(value); } + +std::string FormatUintConstant(uint64_t value) { + return absl::StrCat(value, "u"); +} + +std::string FormatDoubleConstant(double value) { + if (std::isfinite(value)) { + if (std::floor(value) != value) { + // The double is not representable as a whole number, so use + // absl::StrCat which will add decimal places. + return absl::StrCat(value); + } + // absl::StrCat historically would represent 0.0 as 0, and we want the + // decimal places so ZetaSQL correctly assumes the type as double + // instead of int64. + std::string stringified = absl::StrCat(value); + if (!absl::StrContains(stringified, '.')) { + absl::StrAppend(&stringified, ".0"); + } + return stringified; + } + if (std::isnan(value)) { + return "nan"; + } + if (std::signbit(value)) { + return "-infinity"; + } + return "+infinity"; +} + +std::string FormatBytesConstant(absl::string_view value) { + return internal::FormatBytesLiteral(value); +} + +std::string FormatStringConstant(absl::string_view value) { + return internal::FormatStringLiteral(value); +} + +std::string FormatDurationConstant(absl::Duration value) { + return absl::StrCat("duration(\"", absl::FormatDuration(value), "\")"); +} + +std::string FormatTimestampConstant(absl::Time value) { + return absl::StrCat( + "timestamp(\"", + absl::FormatTime("%Y-%m-%d%ET%H:%M:%E*SZ", value, absl::UTCTimeZone()), + "\")"); +} + +} // namespace cel diff --git a/common/constant.h b/common/constant.h new file mode 100644 index 000000000..ac9a2942b --- /dev/null +++ b/common/constant.h @@ -0,0 +1,491 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_CONSTANT_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_CONSTANT_H_ + +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/functional/overload.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "absl/types/variant.h" + +namespace cel { + +class Expr; +class Constant; +class BytesConstant; +class StringConstant; +class VariableDecl; + +class BytesConstant final : public std::string { + public: + explicit BytesConstant(std::string string) : std::string(std::move(string)) {} + + explicit BytesConstant(absl::string_view string) + : BytesConstant(std::string(string)) {} + + explicit BytesConstant(const char* string) + : BytesConstant(absl::NullSafeStringView(string)) {} + + BytesConstant() = default; + BytesConstant(const BytesConstant&) = default; + BytesConstant(BytesConstant&&) = default; + BytesConstant& operator=(const BytesConstant&) = default; + BytesConstant& operator=(BytesConstant&&) = default; + + BytesConstant(const StringConstant&) = delete; + BytesConstant(StringConstant&&) = delete; + BytesConstant& operator=(const StringConstant&) = delete; + BytesConstant& operator=(StringConstant&&) = delete; + + private: + static const BytesConstant& default_instance(); + + friend class Constant; +}; + +class StringConstant final : public std::string { + public: + explicit StringConstant(std::string string) + : std::string(std::move(string)) {} + + explicit StringConstant(absl::string_view string) + : StringConstant(std::string(string)) {} + + explicit StringConstant(const char* string) + : StringConstant(absl::NullSafeStringView(string)) {} + + StringConstant() = default; + StringConstant(const StringConstant&) = default; + StringConstant(StringConstant&&) = default; + StringConstant& operator=(const StringConstant&) = default; + StringConstant& operator=(StringConstant&&) = default; + + StringConstant(const BytesConstant&) = delete; + StringConstant(BytesConstant&&) = delete; + StringConstant& operator=(const BytesConstant&) = delete; + StringConstant& operator=(BytesConstant&&) = delete; + + private: + static const StringConstant& default_instance(); + + friend class Constant; +}; + +namespace common_internal { + +template +struct ConstantKindIndexer { + static constexpr size_t value = + std::conditional_t, + std::integral_constant, + ConstantKindIndexer>::value; +}; + +template +struct ConstantKindIndexer { + static constexpr size_t value = std::conditional_t< + std::is_same_v, std::integral_constant, + std::integral_constant>::value; +}; + +template +struct ConstantKindImpl { + using VariantType = absl::variant; + + template + static constexpr size_t IndexOf() { + return ConstantKindIndexer<0, U, Ts...>::value; + } +}; + +using ConstantKind = + ConstantKindImpl; + +static_assert(ConstantKind::IndexOf() == 0); +static_assert(ConstantKind::IndexOf() == 1); +static_assert(ConstantKind::IndexOf() == 2); +static_assert(ConstantKind::IndexOf() == 3); +static_assert(ConstantKind::IndexOf() == 4); +static_assert(ConstantKind::IndexOf() == 5); +static_assert(ConstantKind::IndexOf() == 6); +static_assert(ConstantKind::IndexOf() == 7); +static_assert(ConstantKind::IndexOf() == 8); +static_assert(ConstantKind::IndexOf() == 9); +static_assert(ConstantKind::IndexOf() == absl::variant_npos); + +} // namespace common_internal + +// Constant is a variant composed of all the literal types support by the Common +// Expression Language. +using ConstantKind = common_internal::ConstantKind::VariantType; + +enum class ConstantKindCase { + kUnspecified, + kNull, + kBool, + kInt, + kUint, + kDouble, + kBytes, + kString, + kDuration, + kTimestamp, +}; + +template +constexpr size_t ConstantKindIndexOf() { + return common_internal::ConstantKind::IndexOf(); +} + +// Returns the `null` literal. +std::string FormatNullConstant(); +inline std::string FormatNullConstant(std::nullptr_t) { + return FormatNullConstant(); +} + +// Formats `value` as a bool literal. +std::string FormatBoolConstant(bool value); + +// Formats `value` as a int literal. +std::string FormatIntConstant(int64_t value); + +// Formats `value` as a uint literal. +std::string FormatUintConstant(uint64_t value); + +// Formats `value` as a double literal-like representation. Due to Common +// Expression Language not having NaN or infinity literals, the result will not +// always be syntactically valid. +std::string FormatDoubleConstant(double value); + +// Formats `value` as a bytes literal. +std::string FormatBytesConstant(absl::string_view value); + +// Formats `value` as a string literal. +std::string FormatStringConstant(absl::string_view value); + +// Formats `value` as a duration constant. +std::string FormatDurationConstant(absl::Duration value); + +// Formats `value` as a timestamp constant. +std::string FormatTimestampConstant(absl::Time value); + +// Represents a primitive literal. +// +// This is similar as the primitives supported in the well-known type +// `google.protobuf.Value`, but richer so it can represent CEL's full range of +// primitives. +// +// Lists and structs are not included as constants as these aggregate types may +// contain [Expr][] elements which require evaluation and are thus not constant. +// +// Examples of constants include: `"hello"`, `b'bytes'`, `1u`, `4.2`, `-2`, +// `true`, `null`. +class Constant final { + public: + Constant() = default; + Constant(const Constant&) = default; + Constant(Constant&&) = default; + Constant& operator=(const Constant&) = default; + Constant& operator=(Constant&&) = default; + + explicit Constant(ConstantKind kind) : kind_(std::move(kind)) {} + + ABSL_MUST_USE_RESULT const ConstantKind& kind() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return kind_; + } + + ABSL_DEPRECATED("Use kind()") + ABSL_MUST_USE_RESULT const ConstantKind& constant_kind() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return kind(); + } + + ABSL_MUST_USE_RESULT bool has_value() const { + return !absl::holds_alternative(kind()); + } + + ABSL_MUST_USE_RESULT bool has_null_value() const { + return absl::holds_alternative(kind()); + } + + ABSL_MUST_USE_RESULT std::nullptr_t null_value() const { return nullptr; } + + void set_null_value() { mutable_kind().emplace(); } + + void set_null_value(std::nullptr_t) { set_null_value(); } + + ABSL_MUST_USE_RESULT bool has_bool_value() const { + return absl::holds_alternative(kind()); + } + + void set_bool_value(bool value) { mutable_kind().emplace(value); } + + ABSL_MUST_USE_RESULT bool bool_value() const { return get_value(); } + + ABSL_MUST_USE_RESULT bool has_int_value() const { + return absl::holds_alternative(kind()); + } + + void set_int_value(int64_t value) { mutable_kind().emplace(value); } + + ABSL_MUST_USE_RESULT int64_t int_value() const { + return get_value(); + } + + ABSL_MUST_USE_RESULT bool has_uint_value() const { + return absl::holds_alternative(kind()); + } + + void set_uint_value(uint64_t value) { + mutable_kind().emplace(value); + } + + ABSL_MUST_USE_RESULT uint64_t uint_value() const { + return get_value(); + } + + ABSL_DEPRECATED("Use has_int_value") + ABSL_MUST_USE_RESULT bool has_int64_value() const { return has_int_value(); } + + ABSL_DEPRECATED("Use set_int_value()") + void set_int64_value(int64_t value) { set_int_value(value); } + + ABSL_DEPRECATED("Use int_value()") + ABSL_MUST_USE_RESULT int64_t int64_value() const { return int_value(); } + + ABSL_DEPRECATED("Use has_uint_value()") + ABSL_MUST_USE_RESULT bool has_uint64_value() const { + return has_uint_value(); + } + + ABSL_DEPRECATED("Use set_uint_value()") + void set_uint64_value(uint64_t value) { set_uint_value(value); } + + ABSL_DEPRECATED("Use uint_value()") + ABSL_MUST_USE_RESULT uint64_t uint64_value() const { return uint_value(); } + + ABSL_MUST_USE_RESULT bool has_double_value() const { + return absl::holds_alternative(kind()); + } + + void set_double_value(double value) { mutable_kind().emplace(value); } + + ABSL_MUST_USE_RESULT double double_value() const { + return get_value(); + } + + ABSL_MUST_USE_RESULT bool has_bytes_value() const { + return absl::holds_alternative(kind()); + } + + void set_bytes_value(BytesConstant value) { + mutable_kind().emplace(std::move(value)); + } + + void set_bytes_value(std::string value) { + set_bytes_value(BytesConstant{std::move(value)}); + } + + void set_bytes_value(absl::string_view value) { + set_bytes_value(BytesConstant{value}); + } + + void set_bytes_value(const char* value) { + set_bytes_value(absl::NullSafeStringView(value)); + } + + ABSL_MUST_USE_RESULT const std::string& bytes_value() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + if (const auto* alt = absl::get_if(&kind()); alt) { + return *alt; + } + return BytesConstant::default_instance(); + } + + ABSL_MUST_USE_RESULT std::string release_bytes_value() { + std::string string; + if (auto* alt = absl::get_if(&mutable_kind()); alt) { + string.swap(*alt); + } + mutable_kind().emplace(); + return string; + } + + ABSL_MUST_USE_RESULT bool has_string_value() const { + return absl::holds_alternative(kind()); + } + + void set_string_value(StringConstant value) { + mutable_kind().emplace(std::move(value)); + } + + void set_string_value(std::string value) { + set_string_value(StringConstant{std::move(value)}); + } + + void set_string_value(absl::string_view value) { + set_string_value(StringConstant{value}); + } + + void set_string_value(const char* value) { + set_string_value(absl::NullSafeStringView(value)); + } + + ABSL_MUST_USE_RESULT const std::string& string_value() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + if (const auto* alt = absl::get_if(&kind()); alt) { + return *alt; + } + return StringConstant::default_instance(); + } + + ABSL_MUST_USE_RESULT std::string release_string_value() { + std::string string; + if (auto* alt = absl::get_if(&mutable_kind()); alt) { + string.swap(*alt); + } + mutable_kind().emplace(); + return string; + } + + ABSL_DEPRECATED("duration is no longer considered a builtin type") + ABSL_MUST_USE_RESULT bool has_duration_value() const { + return absl::holds_alternative(kind()); + } + + ABSL_DEPRECATED("duration is no longer considered a builtin type") + void set_duration_value(absl::Duration value) { + mutable_kind().emplace(value); + } + + ABSL_DEPRECATED("duration is no longer considered a builtin type") + ABSL_MUST_USE_RESULT absl::Duration duration_value() const { + return get_value(); + } + + ABSL_DEPRECATED("timestamp is no longer considered a builtin type") + ABSL_MUST_USE_RESULT bool has_timestamp_value() const { + return absl::holds_alternative(kind()); + } + + ABSL_DEPRECATED("timestamp is no longer considered a builtin type") + void set_timestamp_value(absl::Time value) { + mutable_kind().emplace(value); + } + + ABSL_DEPRECATED("timestamp is no longer considered a builtin type") + ABSL_MUST_USE_RESULT absl::Time timestamp_value() const { + return get_value(); + } + + ABSL_DEPRECATED("Use has_timestamp_value()") + ABSL_MUST_USE_RESULT bool has_time_value() const { + return has_timestamp_value(); + } + + ABSL_DEPRECATED("Use set_timestamp_value()") + void set_time_value(absl::Time value) { set_timestamp_value(value); } + + ABSL_DEPRECATED("Use timestamp_value()") + ABSL_MUST_USE_RESULT absl::Time time_value() const { + return timestamp_value(); + } + + ConstantKindCase kind_case() const { + static_assert(absl::variant_size_v == 10); + if (kind_.index() <= 10) { + return static_cast(kind_.index()); + } + return ConstantKindCase::kUnspecified; + } + + private: + friend class Expr; + friend class VariableDecl; + + static const Constant& default_instance(); + + ABSL_MUST_USE_RESULT ConstantKind& mutable_kind() + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return kind_; + } + + template + T get_value() const { + if (const auto* alt = absl::get_if(&kind()); alt) { + return *alt; + } + return T{}; + } + + ConstantKind kind_; +}; + +inline bool operator==(const Constant& lhs, const Constant& rhs) { + return lhs.kind() == rhs.kind(); +} + +inline bool operator!=(const Constant& lhs, const Constant& rhs) { + return lhs.kind() != rhs.kind(); +} + +template +void AbslStringify(Sink& sink, const Constant& constant) { + absl::visit( + absl::Overload( + [&sink](absl::monostate) -> void { sink.Append(""); }, + [&sink](std::nullptr_t value) -> void { + sink.Append(FormatNullConstant(value)); + }, + [&sink](bool value) -> void { + sink.Append(FormatBoolConstant(value)); + }, + [&sink](int64_t value) -> void { + sink.Append(FormatIntConstant(value)); + }, + [&sink](uint64_t value) -> void { + sink.Append(FormatUintConstant(value)); + }, + [&sink](double value) -> void { + sink.Append(FormatDoubleConstant(value)); + }, + [&sink](const BytesConstant& value) -> void { + sink.Append(FormatBytesConstant(value)); + }, + [&sink](const StringConstant& value) -> void { + sink.Append(FormatStringConstant(value)); + }, + [&sink](absl::Duration value) -> void { + sink.Append(FormatDurationConstant(value)); + }, + [&sink](absl::Time value) -> void { + sink.Append(FormatTimestampConstant(value)); + }), + constant.kind()); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_CONSTANT_H_ diff --git a/common/constant_test.cc b/common/constant_test.cc new file mode 100644 index 000000000..1f8448ecb --- /dev/null +++ b/common/constant_test.cc @@ -0,0 +1,286 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/constant.h" + +#include +#include +#include +#include + +#include "absl/strings/has_absl_stringify.h" +#include "absl/strings/str_format.h" +#include "absl/time/time.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +using ::testing::IsEmpty; +using ::testing::IsFalse; +using ::testing::IsTrue; + +TEST(Constant, NullValue) { + Constant const_expr; + EXPECT_THAT(const_expr.has_null_value(), IsFalse()); + const_expr.set_null_value(); + EXPECT_THAT(const_expr.has_null_value(), IsTrue()); + EXPECT_EQ(const_expr.kind().index(), ConstantKindIndexOf()); + EXPECT_EQ(const_expr.kind_case(), ConstantKindCase::kNull); +} + +TEST(Constant, BoolValue) { + Constant const_expr; + EXPECT_THAT(const_expr.has_bool_value(), IsFalse()); + EXPECT_EQ(const_expr.bool_value(), false); + const_expr.set_bool_value(false); + EXPECT_THAT(const_expr.has_bool_value(), IsTrue()); + EXPECT_EQ(const_expr.bool_value(), false); + EXPECT_EQ(const_expr.kind().index(), ConstantKindIndexOf()); + EXPECT_EQ(const_expr.kind_case(), ConstantKindCase::kBool); +} + +TEST(Constant, IntValue) { + Constant const_expr; + EXPECT_THAT(const_expr.has_int_value(), IsFalse()); + EXPECT_EQ(const_expr.int_value(), 0); + const_expr.set_int_value(0); + EXPECT_THAT(const_expr.has_int_value(), IsTrue()); + EXPECT_EQ(const_expr.int_value(), 0); + EXPECT_EQ(const_expr.kind().index(), ConstantKindIndexOf()); + EXPECT_EQ(const_expr.kind_case(), ConstantKindCase::kInt); +} + +TEST(Constant, UintValue) { + Constant const_expr; + EXPECT_THAT(const_expr.has_uint_value(), IsFalse()); + EXPECT_EQ(const_expr.uint_value(), 0); + const_expr.set_uint_value(0); + EXPECT_THAT(const_expr.has_uint_value(), IsTrue()); + EXPECT_EQ(const_expr.uint_value(), 0); + EXPECT_EQ(const_expr.kind().index(), ConstantKindIndexOf()); + EXPECT_EQ(const_expr.kind_case(), ConstantKindCase::kUint); +} + +TEST(Constant, DoubleValue) { + Constant const_expr; + EXPECT_THAT(const_expr.has_double_value(), IsFalse()); + EXPECT_EQ(const_expr.double_value(), 0); + const_expr.set_double_value(0); + EXPECT_THAT(const_expr.has_double_value(), IsTrue()); + EXPECT_EQ(const_expr.double_value(), 0); + EXPECT_EQ(const_expr.kind().index(), ConstantKindIndexOf()); + EXPECT_EQ(const_expr.kind_case(), ConstantKindCase::kDouble); +} + +TEST(Constant, BytesValue) { + Constant const_expr; + EXPECT_THAT(const_expr.has_bytes_value(), IsFalse()); + EXPECT_THAT(const_expr.bytes_value(), IsEmpty()); + const_expr.set_bytes_value("foo"); + EXPECT_THAT(const_expr.has_bytes_value(), IsTrue()); + EXPECT_EQ(const_expr.bytes_value(), "foo"); + EXPECT_EQ(const_expr.kind().index(), ConstantKindIndexOf()); + EXPECT_EQ(const_expr.kind_case(), ConstantKindCase::kBytes); +} + +TEST(Constant, StringValue) { + Constant const_expr; + EXPECT_THAT(const_expr.has_string_value(), IsFalse()); + EXPECT_THAT(const_expr.string_value(), IsEmpty()); + const_expr.set_string_value("foo"); + EXPECT_THAT(const_expr.has_string_value(), IsTrue()); + EXPECT_EQ(const_expr.string_value(), "foo"); + EXPECT_EQ(const_expr.kind().index(), ConstantKindIndexOf()); + EXPECT_EQ(const_expr.kind_case(), ConstantKindCase::kString); +} + +TEST(Constant, DurationValue) { + Constant const_expr; + EXPECT_THAT(const_expr.has_duration_value(), IsFalse()); + EXPECT_EQ(const_expr.duration_value(), absl::ZeroDuration()); + const_expr.set_duration_value(absl::ZeroDuration()); + EXPECT_THAT(const_expr.has_duration_value(), IsTrue()); + EXPECT_EQ(const_expr.duration_value(), absl::ZeroDuration()); + EXPECT_EQ(const_expr.kind().index(), ConstantKindIndexOf()); + EXPECT_EQ(const_expr.kind_case(), ConstantKindCase::kDuration); +} + +TEST(Constant, TimestampValue) { + Constant const_expr; + EXPECT_THAT(const_expr.has_timestamp_value(), IsFalse()); + EXPECT_EQ(const_expr.timestamp_value(), absl::UnixEpoch()); + const_expr.set_timestamp_value(absl::UnixEpoch()); + EXPECT_THAT(const_expr.has_timestamp_value(), IsTrue()); + EXPECT_EQ(const_expr.timestamp_value(), absl::UnixEpoch()); + EXPECT_EQ(const_expr.kind().index(), ConstantKindIndexOf()); + EXPECT_EQ(const_expr.kind_case(), ConstantKindCase::kTimestamp); +} + +TEST(Constant, DefaultConstructed) { + Constant const_expr; + EXPECT_EQ(const_expr.kind_case(), ConstantKindCase::kUnspecified); +} + +TEST(Constant, Equality) { + EXPECT_EQ(Constant{}, Constant{}); + + Constant lhs_const_expr; + Constant rhs_const_expr; + + lhs_const_expr.set_null_value(); + rhs_const_expr.set_null_value(); + EXPECT_EQ(lhs_const_expr, rhs_const_expr); + EXPECT_EQ(rhs_const_expr, lhs_const_expr); + EXPECT_NE(lhs_const_expr, Constant{}); + EXPECT_NE(Constant{}, rhs_const_expr); + + lhs_const_expr.set_bool_value(false); + rhs_const_expr.set_null_value(); + EXPECT_NE(lhs_const_expr, rhs_const_expr); + EXPECT_NE(rhs_const_expr, lhs_const_expr); + EXPECT_NE(lhs_const_expr, Constant{}); + EXPECT_NE(Constant{}, rhs_const_expr); + rhs_const_expr.set_bool_value(false); + EXPECT_EQ(lhs_const_expr, rhs_const_expr); + EXPECT_EQ(rhs_const_expr, lhs_const_expr); + EXPECT_NE(lhs_const_expr, Constant{}); + EXPECT_NE(Constant{}, rhs_const_expr); + + lhs_const_expr.set_int_value(0); + rhs_const_expr.set_null_value(); + EXPECT_NE(lhs_const_expr, rhs_const_expr); + EXPECT_NE(rhs_const_expr, lhs_const_expr); + EXPECT_NE(lhs_const_expr, Constant{}); + EXPECT_NE(Constant{}, rhs_const_expr); + rhs_const_expr.set_int_value(0); + EXPECT_EQ(lhs_const_expr, rhs_const_expr); + EXPECT_EQ(rhs_const_expr, lhs_const_expr); + EXPECT_NE(lhs_const_expr, Constant{}); + EXPECT_NE(Constant{}, rhs_const_expr); + + lhs_const_expr.set_uint_value(0); + rhs_const_expr.set_null_value(); + EXPECT_NE(lhs_const_expr, rhs_const_expr); + EXPECT_NE(rhs_const_expr, lhs_const_expr); + EXPECT_NE(lhs_const_expr, Constant{}); + EXPECT_NE(Constant{}, rhs_const_expr); + rhs_const_expr.set_uint_value(0); + EXPECT_EQ(lhs_const_expr, rhs_const_expr); + EXPECT_EQ(rhs_const_expr, lhs_const_expr); + EXPECT_NE(lhs_const_expr, Constant{}); + EXPECT_NE(Constant{}, rhs_const_expr); + + lhs_const_expr.set_double_value(0); + rhs_const_expr.set_null_value(); + EXPECT_NE(lhs_const_expr, rhs_const_expr); + EXPECT_NE(rhs_const_expr, lhs_const_expr); + EXPECT_NE(lhs_const_expr, Constant{}); + EXPECT_NE(Constant{}, rhs_const_expr); + rhs_const_expr.set_double_value(0); + EXPECT_EQ(lhs_const_expr, rhs_const_expr); + EXPECT_EQ(rhs_const_expr, lhs_const_expr); + EXPECT_NE(lhs_const_expr, Constant{}); + EXPECT_NE(Constant{}, rhs_const_expr); + + lhs_const_expr.set_bytes_value("foo"); + rhs_const_expr.set_null_value(); + EXPECT_NE(lhs_const_expr, rhs_const_expr); + EXPECT_NE(rhs_const_expr, lhs_const_expr); + EXPECT_NE(lhs_const_expr, Constant{}); + EXPECT_NE(Constant{}, rhs_const_expr); + rhs_const_expr.set_bytes_value("foo"); + EXPECT_EQ(lhs_const_expr, rhs_const_expr); + EXPECT_EQ(rhs_const_expr, lhs_const_expr); + EXPECT_NE(lhs_const_expr, Constant{}); + EXPECT_NE(Constant{}, rhs_const_expr); + + lhs_const_expr.set_string_value("foo"); + rhs_const_expr.set_null_value(); + EXPECT_NE(lhs_const_expr, rhs_const_expr); + EXPECT_NE(rhs_const_expr, lhs_const_expr); + EXPECT_NE(lhs_const_expr, Constant{}); + EXPECT_NE(Constant{}, rhs_const_expr); + rhs_const_expr.set_string_value("foo"); + EXPECT_EQ(lhs_const_expr, rhs_const_expr); + EXPECT_EQ(rhs_const_expr, lhs_const_expr); + EXPECT_NE(lhs_const_expr, Constant{}); + EXPECT_NE(Constant{}, rhs_const_expr); + + lhs_const_expr.set_duration_value(absl::ZeroDuration()); + rhs_const_expr.set_null_value(); + EXPECT_NE(lhs_const_expr, rhs_const_expr); + EXPECT_NE(rhs_const_expr, lhs_const_expr); + EXPECT_NE(lhs_const_expr, Constant{}); + EXPECT_NE(Constant{}, rhs_const_expr); + rhs_const_expr.set_duration_value(absl::ZeroDuration()); + EXPECT_EQ(lhs_const_expr, rhs_const_expr); + EXPECT_EQ(rhs_const_expr, lhs_const_expr); + EXPECT_NE(lhs_const_expr, Constant{}); + EXPECT_NE(Constant{}, rhs_const_expr); + + lhs_const_expr.set_timestamp_value(absl::UnixEpoch()); + rhs_const_expr.set_null_value(); + EXPECT_NE(lhs_const_expr, rhs_const_expr); + EXPECT_NE(rhs_const_expr, lhs_const_expr); + EXPECT_NE(lhs_const_expr, Constant{}); + EXPECT_NE(Constant{}, rhs_const_expr); + rhs_const_expr.set_timestamp_value(absl::UnixEpoch()); + EXPECT_EQ(lhs_const_expr, rhs_const_expr); + EXPECT_EQ(rhs_const_expr, lhs_const_expr); + EXPECT_NE(lhs_const_expr, Constant{}); + EXPECT_NE(Constant{}, rhs_const_expr); +} + +std::string Stringify(const Constant& constant) { + return absl::StrFormat("%v", constant); +} + +TEST(Constant, HasAbslStringify) { + EXPECT_TRUE(absl::HasAbslStringify::value); +} + +TEST(Constant, AbslStringify) { + Constant constant; + EXPECT_EQ(Stringify(constant), ""); + constant.set_null_value(); + EXPECT_EQ(Stringify(constant), "null"); + constant.set_bool_value(true); + EXPECT_EQ(Stringify(constant), "true"); + constant.set_int_value(1); + EXPECT_EQ(Stringify(constant), "1"); + constant.set_uint_value(1); + EXPECT_EQ(Stringify(constant), "1u"); + constant.set_double_value(1); + EXPECT_EQ(Stringify(constant), "1.0"); + constant.set_double_value(1.1); + EXPECT_EQ(Stringify(constant), "1.1"); + constant.set_double_value(NAN); + EXPECT_EQ(Stringify(constant), "nan"); + constant.set_double_value(INFINITY); + EXPECT_EQ(Stringify(constant), "+infinity"); + constant.set_double_value(-INFINITY); + EXPECT_EQ(Stringify(constant), "-infinity"); + constant.set_bytes_value("foo"); + EXPECT_EQ(Stringify(constant), "b\"foo\""); + constant.set_string_value("foo"); + EXPECT_EQ(Stringify(constant), "\"foo\""); + constant.set_duration_value(absl::Seconds(1)); + EXPECT_EQ(Stringify(constant), "duration(\"1s\")"); + constant.set_timestamp_value(absl::UnixEpoch() + absl::Seconds(1)); + EXPECT_EQ(Stringify(constant), "timestamp(\"1970-01-01T00:00:01Z\")"); +} + +} // namespace +} // namespace cel diff --git a/common/container.cc b/common/container.cc new file mode 100644 index 000000000..e1db8f86c --- /dev/null +++ b/common/container.cc @@ -0,0 +1,171 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/container.h" + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/ascii.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "internal/lexis.h" + +namespace cel { +namespace { + +bool IsValidQualifiedName(absl::string_view name) { + auto dot_pos = name.find('.'); + while (dot_pos != absl::string_view::npos) { + if (!internal::LexisIsIdentifier(name.substr(0, dot_pos))) { + return false; + } + name = name.substr(dot_pos + 1); + dot_pos = name.find('.'); + } + return internal::LexisIsIdentifier(name); +} + +bool IsValidAlias(absl::string_view alias) { + return internal::LexisIsIdentifier(alias); +} + +bool IsAbbreviationImpl(absl::string_view alias, absl::string_view name) { + auto pos = name.rfind('.'); + return pos != std::string::npos && pos > 0 && pos < name.size() - 1 && + alias == name.substr(pos + 1); +} + +} // namespace + +bool ExpressionContainer::AliasListing::IsAbbreviation() const { + return IsAbbreviationImpl(alias, name); +} + +absl::StatusOr MakeExpressionContainer( + absl::string_view name) { + ExpressionContainer container; + + absl::Status status = container.SetContainer(name); + if (!status.ok()) { + return status; + } + return container; +} + +absl::Status ExpressionContainer::SetContainer(absl::string_view name) { + if (name.empty()) { + container_ = ""; + return absl::OkStatus(); + } + + if (!IsValidQualifiedName(name)) { + return absl::InvalidArgumentError( + absl::StrCat("invalid qualified name: ", name)); + } + + for (const auto& entry : aliases_) { + const std::string& alias = entry.first; + if (name == alias || + (name.size() > alias.size() && + absl::string_view(name).substr(0, alias.size()) == alias && + name.at(alias.size()) == '.')) { + return absl::InvalidArgumentError( + absl::StrCat("container name collides with alias: ", alias)); + } + } + + container_ = std::string(name); + return absl::OkStatus(); +} + +absl::Status ExpressionContainer::AddAbbreviation(absl::string_view abrev) { + abrev = absl::StripAsciiWhitespace(abrev); + if (!IsValidQualifiedName(abrev)) { + return absl::InvalidArgumentError( + absl::StrCat("invalid qualified name: ", abrev, + ", wanted name of the form 'qualified.name'")); + } + + auto pos = abrev.rfind('.'); + if (pos == 0 || pos == absl::string_view::npos || pos == abrev.size() - 1) { + return absl::InvalidArgumentError( + absl::StrCat("invalid qualified name: ", abrev, + ", wanted name of the form 'qualified.name'")); + } + + absl::string_view alias = abrev.substr(pos + 1); + return AddAlias(alias, abrev); +} + +absl::Status ExpressionContainer::AddAlias(absl::string_view alias, + absl::string_view name) { + if (!IsValidAlias(alias)) { + return absl::InvalidArgumentError(absl::StrCat( + "alias must be non-empty and simple (not qualified): ", alias)); + } + + if (!IsValidQualifiedName(name)) { + return absl::InvalidArgumentError( + absl::StrCat("invalid qualified name: ", name)); + } + + if (auto it = aliases_.find(alias); it != aliases_.end()) { + return absl::InvalidArgumentError(absl::StrCat( + "alias collides with existing reference: ", alias, " -> ", it->second)); + } + + if (container_ == alias || + (container_.size() > alias.size() && + absl::string_view(container_).substr(0, alias.size()) == alias && + container_.at(alias.size()) == '.')) { + return absl::InvalidArgumentError( + absl::StrCat("alias collides with container name: ", alias)); + } + + aliases_.insert({std::string(alias), std::string(name)}); + return absl::OkStatus(); +} + +absl::string_view ExpressionContainer::FindAlias( + absl::string_view alias) const { + auto it = aliases_.find(alias); + if (it != aliases_.end()) { + return it->second; + } + return ""; +} + +std::vector ExpressionContainer::ListAbbreviations() const { + std::vector res; + for (const auto& entry : aliases_) { + if (IsAbbreviationImpl(entry.first, entry.second)) { + res.push_back(entry.second); + } + } + return res; +} + +std::vector +ExpressionContainer::ListAliases() const { + std::vector res; + for (const auto& entry : aliases_) { + res.push_back({entry.first, entry.second}); + } + return res; +} + +} // namespace cel diff --git a/common/container.h b/common/container.h new file mode 100644 index 000000000..ad8d91c35 --- /dev/null +++ b/common/container.h @@ -0,0 +1,138 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_CONTAINER_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_CONTAINER_H_ + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" + +namespace cel { + +// ExpressionContainer represents the namespace configuration for a CEL +// expression. +// +// The container defines the default resolution order for names referenced in +// the expression. It generally maps to a protobuf package and follows +// approximately the same resolution rules as protobuf or C++ namespaces. +// +// Aliases declare short names that can be referenced without resolving against +// the scopes defined by the container. An alias cannot be a prefix of the +// container name, (otherwise re-type-checking an expression could +// change the meaning). Aliases are always unqualified identifiers. +// +// An abbreviation is a special case of alias that behaves like an import or +// using declaration in other languages. (pkg.TypeName -> TypeName). +// +// For better traceability, prefer using abbreviations over aliases. +class ExpressionContainer { + public: + struct AliasListing { + std::string alias; + std::string name; + + bool IsAbbreviation() const; + }; + + ExpressionContainer() = default; + + ExpressionContainer(const ExpressionContainer&) = default; + ExpressionContainer(ExpressionContainer&&) = default; + ExpressionContainer& operator=(const ExpressionContainer&) = default; + ExpressionContainer& operator=(ExpressionContainer&&) = default; + + // Returns the full name of the container. + // + // The default value is an empty string meaning no container. + absl::string_view container() const { return container_; } + + // Sets the container name. + // + // Returns an error if the container name is malformed or conflicts with an + // existing alias. + absl::Status SetContainer(absl::string_view name); + + // Adds an abbreviation to the container. + // + // Returns an error if the abbreviation is malformed or conflicts with the + // container or an existing alias. + absl::Status AddAbbreviation(absl::string_view abrev); + + // Adds an alias to the container. + // + // Returns an error if the alias is malformed or conflicts with the container + // or an existing alias. + absl::Status AddAlias(absl::string_view alias, absl::string_view name); + + // Returns the full name of the alias or an empty string if not found. + // + // The returned string view may be invalidated by updates to the + // ExpressionContainer. + absl::string_view FindAlias(absl::string_view alias) const; + + // Utility method for listing the abbreviations in the container. + // Order is not guaranteed. + std::vector ListAbbreviations() const; + + // Utility method for listing the aliases in the container. + // Includes abbreviations. + // Order is not guaranteed. + std::vector ListAliases() const; + + // Removes all aliases and abbreviations from the container. + void clear() { + container_.clear(); + aliases_.clear(); + } + + private: + std::string container_; + + // alias -> full name. + absl::flat_hash_map aliases_; +}; + +// Factory function for creating an ExpressionContainer. +absl::StatusOr MakeExpressionContainer( + absl::string_view name); + +// Factory function for creating an ExpressionContainer with a list of +// abbreviations. +template +absl::StatusOr MakeExpressionContainer( + absl::string_view name, Args&&... abbrevs) { + ExpressionContainer container; + absl::Status status = container.SetContainer(name); + if (!status.ok()) { + return status; + } + absl::string_view abbrevs_view[] = {std::forward(abbrevs)...}; + for (absl::string_view abrev : abbrevs_view) { + status.Update(container.AddAbbreviation(abrev)); + if (!status.ok()) { + return status; + } + } + + return container; +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_CONTAINER_H_ diff --git a/common/container_test.cc b/common/container_test.cc new file mode 100644 index 000000000..e40814f54 --- /dev/null +++ b/common/container_test.cc @@ -0,0 +1,126 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/container.h" + +#include "absl/status/status.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::testing::Eq; +using ::testing::IsEmpty; +using ::testing::SizeIs; +using ::testing::UnorderedElementsAre; + +TEST(ExpressionContainerTest, DefaultConstructed) { + ExpressionContainer container; + EXPECT_THAT(container.container(), IsEmpty()); + EXPECT_THAT(container.FindAlias("foo"), IsEmpty()); +} + +TEST(ExpressionContainerTest, MakeExpressionContainer) { + ASSERT_OK_AND_ASSIGN(ExpressionContainer container, + MakeExpressionContainer("my.container")); + EXPECT_THAT(container.container(), Eq("my.container")); + + EXPECT_THAT(MakeExpressionContainer("..invalid"), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(ExpressionContainerTest, MakeExpressionContainerWithAbbrevs) { + ASSERT_OK_AND_ASSIGN( + ExpressionContainer container, + MakeExpressionContainer("my.container", "pkg.Abbr", "qual.pkg.Abbr2")); + EXPECT_THAT(container.container(), Eq("my.container")); + EXPECT_THAT(container.FindAlias("Abbr"), Eq("pkg.Abbr")); + EXPECT_THAT(container.FindAlias("Abbr2"), Eq("qual.pkg.Abbr2")); + + EXPECT_THAT(MakeExpressionContainer("my.container", "invalid"), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(ExpressionContainerTest, SetContainer) { + ExpressionContainer container; + EXPECT_THAT(container.SetContainer("my.container.name"), IsOk()); + EXPECT_THAT(container.container(), Eq("my.container.name")); + EXPECT_THAT(container.SetContainer("..invalid"), + StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(container.SetContainer("foo.1invalid"), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(ExpressionContainerTest, AddAlias) { + ASSERT_OK_AND_ASSIGN(ExpressionContainer container, + MakeExpressionContainer("my.container")); + EXPECT_THAT(container.AddAlias("foo", "bar.baz"), IsOk()); + EXPECT_THAT(container.FindAlias("foo"), Eq("bar.baz")); +} + +TEST(ExpressionContainerTest, AddAbbreviation) { + ASSERT_OK_AND_ASSIGN(ExpressionContainer container, + MakeExpressionContainer("my.container")); + EXPECT_THAT(container.AddAbbreviation("qual.pkg.TypeName"), IsOk()); + EXPECT_THAT(container.FindAlias("TypeName"), Eq("qual.pkg.TypeName")); +} + +TEST(ExpressionContainerTest, ListAbbreviationsAndAliases) { + ASSERT_OK_AND_ASSIGN(ExpressionContainer container, + MakeExpressionContainer("my.container")); + EXPECT_THAT(container.AddAbbreviation("qual.pkg.Abbr"), IsOk()); + EXPECT_THAT(container.AddAlias("AliasSym", "some.long.name"), IsOk()); + + EXPECT_THAT(container.ListAbbreviations(), + UnorderedElementsAre("qual.pkg.Abbr")); + + auto aliases = container.ListAliases(); + EXPECT_THAT(aliases, SizeIs(2)); +} + +TEST(ExpressionContainerTest, InvalidAbbreviation) { + ASSERT_OK_AND_ASSIGN(ExpressionContainer container, + MakeExpressionContainer("my.container")); + EXPECT_THAT(container.AddAbbreviation(""), + StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(container.AddAbbreviation("pkg"), + StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(container.AddAbbreviation(".pkg"), + StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(container.AddAbbreviation("pkg."), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(ExpressionContainerTest, InvalidAlias) { + ASSERT_OK_AND_ASSIGN(ExpressionContainer container, + MakeExpressionContainer("my.container")); + EXPECT_THAT(container.AddAlias("", "bar"), + StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(container.AddAlias("foo.bar", "baz"), + StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(container.AddAlias("foo", ".baz"), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(ExpressionContainerTest, CollidesWithContainer) { + ASSERT_OK_AND_ASSIGN(ExpressionContainer container, + MakeExpressionContainer("my.container")); + EXPECT_THAT(container.AddAlias("my", "bar"), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +} // namespace +} // namespace cel diff --git a/common/converters.cc b/common/converters.cc deleted file mode 100644 index 6501f071e..000000000 --- a/common/converters.cc +++ /dev/null @@ -1,9 +0,0 @@ -#include "common/converters.h" - -namespace google { -namespace api { -namespace expr { -namespace common {} // namespace common -} // namespace expr -} // namespace api -} // namespace google diff --git a/common/converters.h b/common/converters.h deleted file mode 100644 index 2a1bc59a7..000000000 --- a/common/converters.h +++ /dev/null @@ -1,132 +0,0 @@ -// Converter functions from common c++ representations to Value. - -#ifndef THIRD_PARTY_CEL_CPP_COMMON_CONVERTERS_H_ -#define THIRD_PARTY_CEL_CPP_COMMON_CONVERTERS_H_ - -#include - -#include "common/parent_ref.h" -#include "common/value.h" -#include "internal/list_impl.h" -#include "internal/map_impl.h" -#include "internal/types.h" - -namespace google { -namespace api { -namespace expr { -namespace common { - -// Converters for native c++ list types. - -// Creates a Value from the given list. -template -Value ValueFromList(T&& value); - -// Creates a Value from the given list. -template -Value ValueFromList(std::unique_ptr value); - -// Creates a Value for the given list with a reference to the given parent. -// If 'parent' is not provided, `value` must live longer than the returned -// Value. -template -Value ValueForList(T* value, ParentRef parent = NoParent()); - -// Converters for native c++ map types. - -// Creates a Value from the given map. -template -Value ValueFromMap(T&& value); - -// Creates a Value from the given map. -template -Value ValueFromMap(std::unique_ptr value); - -// Creates a Value for the given list with a reference to the given parent. -// If 'parent' is not provided, `value` must live longer than the returned -// Value. -template -Value ValueForMap(T* value, ParentRef parent = NoParent()); - -// Creates a Value from the given list. -template -Value ValueFromList(T&& value) { - static_assert(!std::is_pointer::value, "use ValueForList"); - return Value::MakeList< - internal::ListWrapper>>( - std::forward(value)); -} - -// Creates a Value from the given list. -template -Value ValueFromList(std::unique_ptr value) { - return Value::MakeList< - internal::ListWrapper>( - std::move(value)); -} - -// Creates a Value for the given map with a reference to the given parent. -// If 'parent' is not provided, `value` must live longer than the returned -// Value. -template -Value ValueForList(T* value, ParentRef parent) { - if (!parent.has_value()) { - // Parent does not support refs, so copy the value. - return ValueFromList(*value); - } - if (parent->RequiresReference()) { - // Create with reference. - return Value::MakeList>>(parent->GetRef(), - value); - } - // Parent is not provided. - return Value::MakeList< - internal::ListWrapper>(value); -} - -// Converters for native c++ map types. - -// Creates a Value from the given map. -template -Value ValueFromMap(T&& value) { - static_assert(!std::is_pointer::value, "use ValueForList"); - return Value::MakeMap< - internal::MapWrapper>>( - std::forward(value)); -} - -// Creates a Value from the given map. -template -Value ValueFromMap(std::unique_ptr value) { - return Value::MakeMap< - internal::MapWrapper>( - std::move(value)); -} - -// Creates a Value for the given list with a reference to the given parent. -// If 'parent' is not provided, `value` must live longer than the returned -// Value. -template -Value ValueForMap(T* value, ParentRef parent) { - if (!parent.has_value()) { - // Parent does not support refs, so copy the value. - return ValueFromMap(*value); - } - if (parent->RequiresReference()) { - // Create with reference. - return Value::MakeMap>>( - parent->GetRef(), value); - } - // Parent is not provided. - return Value::MakeMap< - internal::MapWrapper>(value); -} - -} // namespace common -} // namespace expr -} // namespace api -} // namespace google - -#endif // THIRD_PARTY_CEL_CPP_COMMON_INTERNAL_PROTO_CONVERTERS_H_ diff --git a/common/converters_test.cc b/common/converters_test.cc deleted file mode 100644 index ec832b36c..000000000 --- a/common/converters_test.cc +++ /dev/null @@ -1,112 +0,0 @@ -#include "common/converters.h" - -#include - -#include "gtest/gtest.h" -#include "absl/memory/memory.h" -#include "common/value.h" - -namespace google { -namespace api { -namespace expr { -namespace common { -namespace { - -void TestListImpl(const Value& list, const Value& first, const Value& second, - const Value& missing) { - EXPECT_EQ(2, list.list_value().size()); - EXPECT_EQ(second, list.list_value().Get(1)); - EXPECT_EQ(Value::FalseValue(), list.list_value().Contains(missing)); - EXPECT_EQ(Value::FalseValue(), - list.list_value().Contains(Value::NullValue())); - EXPECT_EQ(Value::TrueValue(), list.list_value().Contains(first)); -} - -template -void TestList(T&& first, T&& second, T&& missing) { - std::vector list({first, second}); - std::vector values = { - ValueFromList(list), // Copy - ValueFromList( - absl::make_unique>(list)), // OwnedPtr - ValueForList(&list), // Ptr - ValueForList(&list, absl::nullopt) // Copy - }; - EXPECT_TRUE(values[0].owns_value()); - EXPECT_TRUE(values[1].owns_value()); - EXPECT_FALSE(values[2].owns_value()); - EXPECT_TRUE(values[3].owns_value()); - - for (const auto& value : values) { - TestListImpl(value, Value::From(first), - Value::From(second), - Value::From(missing)); - } -} - -TEST(ConverterTest, List_Int) { - TestList(1, 3, 5); - TestList(1, 3, 5); - TestList(1, 3, 5); -} - -TEST(ConverterTest, List_Bool) { - TestList(true, true, false); -} - -void TestMapImpl(const Value& map, const Value& k1, const Value& v1, - const Value& k2, const Value& v2) { - EXPECT_EQ(2, map.map_value().size()); - - EXPECT_EQ(v1, map.map_value().Get(k1)); - EXPECT_NE(v1, map.map_value().Get(k2)); - EXPECT_EQ(v2, map.map_value().Get(k2)); - EXPECT_NE(v2, map.map_value().Get(k1)); - - EXPECT_EQ(Value::TrueValue(), map.map_value().ContainsKey(k1)); - EXPECT_EQ(Value::TrueValue(), map.map_value().ContainsKey(k2)); - EXPECT_EQ(Value::FalseValue(), - map.map_value().ContainsKey(Value::NullValue())); - - EXPECT_EQ(Value::TrueValue(), map.map_value().ContainsValue(v1)); - EXPECT_EQ(Value::TrueValue(), map.map_value().ContainsValue(v2)); - EXPECT_EQ(Value::FalseValue(), - map.map_value().ContainsValue(Value::NullValue())); -} - -template -void TestMap(K&& k1, V&& v1, K&& k2, V&& v2) { - std::map map({{k1, v1}, {k2, v2}}); - std::vector values = { - ValueFromMap(map), // Copy - ValueFromMap( - absl::make_unique>(map)), // OwnedPtr - ValueForMap(&map), // Ptr - ValueForMap(&map, absl::nullopt) // Copy - }; - EXPECT_TRUE(values[0].owns_value()); - EXPECT_TRUE(values[1].owns_value()); - EXPECT_FALSE(values[2].owns_value()); - EXPECT_TRUE(values[3].owns_value()); - - for (const auto& value : values) { - TestMapImpl(value, Value::From(k1), Value::From(v1), - Value::From(k2), Value::From(v2)); - } -} - -TEST(ConverterTest, Map_Int) { - TestMap(1, true, 3, false); - TestMap(1, 7, 3, 8); - TestMap(1, 7, 3, 5); -} - -TEST(ConverterTest, Map_Bool) { - TestMap(true, false, false, true); -} - -} // namespace -} // namespace common -} // namespace expr -} // namespace api -} // namespace google diff --git a/common/custom_object.cc b/common/custom_object.cc deleted file mode 100644 index 8dd81e53c..000000000 --- a/common/custom_object.cc +++ /dev/null @@ -1,22 +0,0 @@ -#include "common/custom_object.h" - -namespace google { -namespace api { -namespace expr { -namespace common { - -Value OpaqueObject::GetMember(absl::string_view name) const { - return Value::FromError( - internal::NoSuchMember(name, object_type().full_name())); -} - -google::rpc::Status OpaqueObject::ForEach( - const std::function& - call) const { - return internal::OkStatus(); -} - -} // namespace common -} // namespace expr -} // namespace api -} // namespace google diff --git a/common/custom_object.h b/common/custom_object.h deleted file mode 100644 index 06e68f649..000000000 --- a/common/custom_object.h +++ /dev/null @@ -1,38 +0,0 @@ -// Helper base classes and function for custom objects. - -#ifndef THIRD_PARTY_CEL_CPP_COMMON_CUSTOM_OBJECT_H_ -#define THIRD_PARTY_CEL_CPP_COMMON_CUSTOM_OBJECT_H_ - -#include "common/value.h" - -namespace google { -namespace api { -namespace expr { -namespace common { - -// An object base class for objects that do not support direct member access. -class OpaqueObject : public Object { - public: - // Does not contain any accessible members. - Value GetMember(absl::string_view name) const final; - google::rpc::Status ForEach( - const std::function& - call) const final; - - // Assume to own value (can be overridden). - inline bool owns_value() const override { return true; } - - // Require a custom ToString function. - std::string ToString() const override = 0; - - protected: - // Require a custom hash function. - std::size_t ComputeHash() const override = 0; -}; - -} // namespace common -} // namespace expr -} // namespace api -} // namespace google - -#endif // THIRD_PARTY_CEL_CPP_COMMON_CUSTOM_OBJECT_H_ diff --git a/common/data.h b/common/data.h new file mode 100644 index 000000000..cefc21fa4 --- /dev/null +++ b/common/data.h @@ -0,0 +1,120 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_DATA_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_DATA_H_ + +#include +#include + +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "common/internal/metadata.h" +#include "google/protobuf/arena.h" + +namespace cel { + +class Data; +template +struct Ownable; +template +struct Borrowable; + +namespace common_internal { + +class ReferenceCount; + +void SetDataReferenceCount(const Data* absl_nonnull data, + const ReferenceCount* absl_nonnull refcount); + +const ReferenceCount* absl_nullable GetDataReferenceCount( + const Data* absl_nonnull data); + +} // namespace common_internal + +// `Data` is one of the base classes of objects that can be managed by +// `MemoryManager`, the other is `google::protobuf::MessageLite`. +class Data { + public: + Data(const Data&) = default; + Data(Data&&) = default; + ~Data() = default; + Data& operator=(const Data&) = default; + Data& operator=(Data&&) = default; + + google::protobuf::Arena* absl_nullable GetArena() const { + return (owner_ & kOwnerBits) == kOwnerArenaBit + ? reinterpret_cast(owner_ & kOwnerPointerMask) + : nullptr; + } + + protected: + // At this point, the reference count has not been created. So we create it + // unowned and set the reference count after. In theory we could create the + // reference count ahead of time and then update it with the data it has to + // delete, but that is a bit counter intuitive. Doing it this way is also + // similar to how std::enable_shared_from_this works. + Data() = default; + + Data(std::nullptr_t) = delete; + + explicit Data(google::protobuf::Arena* absl_nullable arena) + : owner_(reinterpret_cast(arena) | + (arena != nullptr ? kOwnerArenaBit : kOwnerNone)) {} + + private: + static constexpr uintptr_t kOwnerNone = common_internal::kMetadataOwnerNone; + static constexpr uintptr_t kOwnerReferenceCountBit = + common_internal::kMetadataOwnerReferenceCountBit; + static constexpr uintptr_t kOwnerArenaBit = + common_internal::kMetadataOwnerArenaBit; + static constexpr uintptr_t kOwnerBits = common_internal::kMetadataOwnerBits; + static constexpr uintptr_t kOwnerPointerMask = + common_internal::kMetadataOwnerPointerMask; + + friend void common_internal::SetDataReferenceCount( + const Data* absl_nonnull data, + const common_internal::ReferenceCount* absl_nonnull refcount); + friend const common_internal::ReferenceCount* absl_nullable + common_internal::GetDataReferenceCount(const Data* absl_nonnull data); + template + friend struct Ownable; + template + friend struct Borrowable; + + mutable uintptr_t owner_ = kOwnerNone; +}; + +namespace common_internal { + +inline void SetDataReferenceCount(const Data* absl_nonnull data, + const ReferenceCount* absl_nonnull refcount) { + ABSL_DCHECK_EQ(data->owner_, Data::kOwnerNone); + data->owner_ = + reinterpret_cast(refcount) | Data::kOwnerReferenceCountBit; +} + +inline const ReferenceCount* absl_nullable GetDataReferenceCount( + const Data* absl_nonnull data) { + return (data->owner_ & Data::kOwnerBits) == Data::kOwnerReferenceCountBit + ? reinterpret_cast(data->owner_ & + Data::kOwnerPointerMask) + : nullptr; +} + +} // namespace common_internal + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_DATA_H_ diff --git a/common/data_test.cc b/common/data_test.cc new file mode 100644 index 000000000..a6b2a788f --- /dev/null +++ b/common/data_test.cc @@ -0,0 +1,67 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This header contains primitives for reference counting, roughly equivalent to +// the primitives used to implement `std::shared_ptr`. These primitives should +// not be used directly in most cases, instead `cel::ManagedMemory` should be +// used instead. + +#include "common/data.h" + +#include "absl/base/nullability.h" +#include "common/internal/reference_count.h" +#include "internal/testing.h" +#include "google/protobuf/arena.h" + +namespace cel { +namespace { + +using ::testing::IsNull; + +class DataTest final : public Data { + public: + DataTest() noexcept : Data() {} + + explicit DataTest(google::protobuf::Arena* absl_nullable arena) noexcept + : Data(arena) {} +}; + +class DataReferenceCount final : public common_internal::ReferenceCounted { + public: + explicit DataReferenceCount(const Data* data) : data_(data) {} + + private: + void Finalize() noexcept override { delete data_; } + + const Data* data_; +}; + +TEST(Data, Arena) { + google::protobuf::Arena arena; + DataTest data(&arena); + EXPECT_EQ(data.GetArena(), &arena); + EXPECT_THAT(common_internal::GetDataReferenceCount(&data), IsNull()); +} + +TEST(Data, ReferenceCount) { + auto* data = new DataTest(); + EXPECT_THAT(data->GetArena(), IsNull()); + auto* refcount = new DataReferenceCount(data); + common_internal::SetDataReferenceCount(data, refcount); + EXPECT_EQ(common_internal::GetDataReferenceCount(data), refcount); + common_internal::StrongUnref(refcount); +} + +} // namespace +} // namespace cel diff --git a/common/decl.cc b/common/decl.cc new file mode 100644 index 000000000..858e6fb49 --- /dev/null +++ b/common/decl.cc @@ -0,0 +1,221 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/decl.h" + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "common/signature.h" +#include "common/type.h" +#include "common/type_kind.h" + +namespace cel { + +namespace common_internal { + +bool TypeIsAssignable(const Type& to, const Type& from) { + if (to == from) { + return true; + } + const auto to_kind = to.kind(); + if (to_kind == TypeKind::kDyn) { + return true; + } + switch (to_kind) { + case TypeKind::kBoolWrapper: + return TypeIsAssignable(NullType{}, from) || + TypeIsAssignable(BoolType{}, from); + case TypeKind::kIntWrapper: + return TypeIsAssignable(NullType{}, from) || + TypeIsAssignable(IntType{}, from); + case TypeKind::kUintWrapper: + return TypeIsAssignable(NullType{}, from) || + TypeIsAssignable(UintType{}, from); + case TypeKind::kDoubleWrapper: + return TypeIsAssignable(NullType{}, from) || + TypeIsAssignable(DoubleType{}, from); + case TypeKind::kBytesWrapper: + return TypeIsAssignable(NullType{}, from) || + TypeIsAssignable(BytesType{}, from); + case TypeKind::kStringWrapper: + return TypeIsAssignable(NullType{}, from) || + TypeIsAssignable(StringType{}, from); + default: + break; + } + const auto from_kind = from.kind(); + if (to_kind != from_kind || to.name() != from.name()) { + return false; + } + auto to_params = to.GetParameters(); + auto from_params = from.GetParameters(); + const auto params_size = to_params.size(); + if (params_size != from_params.size()) { + return false; + } + for (size_t i = 0; i < params_size; ++i) { + if (!TypeIsAssignable(to_params[i], from_params[i])) { + return false; + } + } + return true; +} + +} // namespace common_internal + +namespace { + +bool SignaturesOverlap(const OverloadDecl& lhs, const OverloadDecl& rhs) { + if (lhs.member() != rhs.member()) { + return false; + } + const auto& lhs_args = lhs.args(); + const auto& rhs_args = rhs.args(); + const auto args_size = lhs_args.size(); + if (args_size != rhs_args.size()) { + return false; + } + bool args_overlap = true; + for (size_t i = 0; i < args_size; ++i) { + args_overlap = + args_overlap && + (common_internal::TypeIsAssignable(lhs_args[i], rhs_args[i]) || + common_internal::TypeIsAssignable(rhs_args[i], lhs_args[i])); + } + return args_overlap; +} + +template +void AddOverloadInternal(std::string_view function_name, + std::vector& insertion_order, + absl::flat_hash_map& by_id, + absl::flat_hash_map& by_signature, + Overload&& overload, absl::Status& status) { + if (!status.ok()) { + return; + } + + absl::StatusOr signature = + MakeOverloadSignature(function_name, overload.args(), overload.member()); + if (!signature.ok()) { + status = signature.status(); + return; + } + + OverloadDecl mutable_overload = std::forward(overload); + + if (mutable_overload.id().empty()) { + mutable_overload.set_id(*signature); + } + + if (auto it = by_id.find(mutable_overload.id()); it != by_id.end()) { + status = absl::AlreadyExistsError( + absl::StrCat("overload exists: ", mutable_overload.id())); + return; + } + + for (const auto& existing : insertion_order) { + if (SignaturesOverlap(mutable_overload, existing)) { + status = absl::InvalidArgumentError( + absl::StrCat("overload signature collision: ", existing.id(), + " collides with ", mutable_overload.id())); + return; + } + } + + size_t index = insertion_order.size(); + by_id[mutable_overload.id()] = index; + by_signature[*signature] = index; + insertion_order.push_back(std::move(mutable_overload)); +} + +void CollectTypeParams(absl::flat_hash_set& type_params, + const Type& type) { + const auto kind = type.kind(); + switch (kind) { + case TypeKind::kList: { + const auto& list_type = type.GetList(); + CollectTypeParams(type_params, list_type.element()); + } break; + case TypeKind::kMap: { + const auto& map_type = type.GetMap(); + CollectTypeParams(type_params, map_type.key()); + CollectTypeParams(type_params, map_type.value()); + } break; + case TypeKind::kOpaque: { + const auto& opaque_type = type.GetOpaque(); + for (const auto& param : opaque_type.GetParameters()) { + CollectTypeParams(type_params, param); + } + } break; + case TypeKind::kFunction: { + const auto& function_type = type.GetFunction(); + CollectTypeParams(type_params, function_type.result()); + for (const auto& arg : function_type.args()) { + CollectTypeParams(type_params, arg); + } + } break; + case TypeKind::kTypeParam: + type_params.emplace(type.GetTypeParam().name()); + break; + default: + break; + } +} + +} // namespace + +absl::flat_hash_set OverloadDecl::GetTypeParams() const { + absl::flat_hash_set type_params; + CollectTypeParams(type_params, result()); + for (const auto& arg : args()) { + CollectTypeParams(type_params, arg); + } + return type_params; +} + +void FunctionDecl::AddOverloadImpl(const OverloadDecl& overload, + absl::Status& status) { + AddOverloadInternal(name_, overloads_.insertion_order, overloads_.by_id, + overloads_.by_signature, overload, status); +} + +void FunctionDecl::AddOverloadImpl(OverloadDecl&& overload, + absl::Status& status) { + AddOverloadInternal(name_, overloads_.insertion_order, overloads_.by_id, + overloads_.by_signature, std::move(overload), status); +} + +const OverloadDecl* FunctionDecl::FindOverloadById(absl::string_view id) const { + if (auto it = overloads_.by_id.find(id); it != overloads_.by_id.end()) { + return &overloads_.insertion_order[it->second]; + } + if (auto it = overloads_.by_signature.find(id); + it != overloads_.by_signature.end()) { + return &overloads_.insertion_order[it->second]; + } + return nullptr; +} + +} // namespace cel diff --git a/common/decl.h b/common/decl.h new file mode 100644 index 000000000..b15645236 --- /dev/null +++ b/common/decl.h @@ -0,0 +1,446 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_DECL_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_DECL_H_ + +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/attributes.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "common/constant.h" +#include "common/type.h" +#include "internal/status_macros.h" + +namespace cel { + +class VariableDecl; +class OverloadDecl; +class FunctionDecl; + +// `VariableDecl` represents a declaration of a variable, composed of its name +// and type, and optionally a constant value. +class VariableDecl final { + public: + VariableDecl() = default; + VariableDecl(const VariableDecl&) = default; + VariableDecl(VariableDecl&&) = default; + VariableDecl& operator=(const VariableDecl&) = default; + VariableDecl& operator=(VariableDecl&&) = default; + + ABSL_MUST_USE_RESULT const std::string& name() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return name_; + } + + void set_name(std::string name) { name_ = std::move(name); } + + void set_name(absl::string_view name) { + name_.assign(name.data(), name.size()); + } + + void set_name(const char* name) { set_name(absl::NullSafeStringView(name)); } + + ABSL_MUST_USE_RESULT std::string release_name() { + std::string released; + released.swap(name_); + return released; + } + + ABSL_MUST_USE_RESULT const Type& type() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return type_; + } + + ABSL_MUST_USE_RESULT Type& mutable_type() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return type_; + } + + void set_type(Type type) { mutable_type() = std::move(type); } + + ABSL_MUST_USE_RESULT bool has_value() const { return value_.has_value(); } + + ABSL_MUST_USE_RESULT const Constant& value() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return has_value() ? *value_ : Constant::default_instance(); + } + + Constant& mutable_value() ABSL_ATTRIBUTE_LIFETIME_BOUND { + if (!has_value()) { + value_.emplace(); + } + return *value_; + } + + void set_value(absl::optional value) { value_ = std::move(value); } + + void set_value(Constant value) { mutable_value() = std::move(value); } + + ABSL_MUST_USE_RESULT Constant release_value() { + absl::optional released; + released.swap(value_); + return std::move(released).value_or(Constant{}); + } + + private: + std::string name_; + Type type_ = DynType{}; + absl::optional value_; +}; + +inline VariableDecl MakeVariableDecl(absl::string_view name, Type type) { + VariableDecl variable_decl; + variable_decl.set_name(std::string(name)); + variable_decl.set_type(std::move(type)); + return variable_decl; +} + +inline VariableDecl MakeConstantVariableDecl(std::string name, Type type, + Constant value) { + VariableDecl variable_decl; + variable_decl.set_name(std::move(name)); + variable_decl.set_type(std::move(type)); + variable_decl.set_value(std::move(value)); + return variable_decl; +} + +inline bool operator==(const VariableDecl& lhs, const VariableDecl& rhs) { + return lhs.name() == rhs.name() && lhs.type() == rhs.type() && + lhs.has_value() == rhs.has_value() && lhs.value() == rhs.value(); +} + +inline bool operator!=(const VariableDecl& lhs, const VariableDecl& rhs) { + return !operator==(lhs, rhs); +} + +// `OverloadDecl` represents a single overload of `FunctionDecl`. +class OverloadDecl final { + public: + OverloadDecl() = default; + OverloadDecl(const OverloadDecl&) = default; + OverloadDecl(OverloadDecl&&) = default; + OverloadDecl& operator=(const OverloadDecl&) = default; + OverloadDecl& operator=(OverloadDecl&&) = default; + + ABSL_MUST_USE_RESULT const std::string& id() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return id_; + } + + void set_id(std::string id) { id_ = std::move(id); } + + void set_id(absl::string_view id) { id_.assign(id.data(), id.size()); } + + void set_id(const char* id) { set_id(absl::NullSafeStringView(id)); } + + ABSL_MUST_USE_RESULT std::string release_id() { + std::string released; + released.swap(id_); + return released; + } + + ABSL_MUST_USE_RESULT const std::vector& args() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return args_; + } + + ABSL_MUST_USE_RESULT std::vector& mutable_args() + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return args_; + } + + ABSL_MUST_USE_RESULT std::vector release_args() { + std::vector released; + released.swap(mutable_args()); + return released; + } + + ABSL_MUST_USE_RESULT const Type& result() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return result_; + } + + ABSL_MUST_USE_RESULT Type& mutable_result() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return result_; + } + + void set_result(Type result) { mutable_result() = std::move(result); } + + ABSL_MUST_USE_RESULT bool member() const { return member_; } + + void set_member(bool member) { member_ = member; } + + absl::flat_hash_set GetTypeParams() const; + + private: + std::string id_; + std::vector args_; + Type result_ = DynType{}; + bool member_ = false; +}; + +inline bool operator==(const OverloadDecl& lhs, const OverloadDecl& rhs) { + return lhs.id() == rhs.id() && absl::c_equal(lhs.args(), rhs.args()) && + lhs.result() == rhs.result() && lhs.member() == rhs.member(); +} + +inline bool operator!=(const OverloadDecl& lhs, const OverloadDecl& rhs) { + return !operator==(lhs, rhs); +} + +template +OverloadDecl MakeOverloadDecl(Type result, Args&&... args) { + OverloadDecl overload_decl; + overload_decl.set_result(std::move(result)); + overload_decl.set_member(false); + auto& mutable_args = overload_decl.mutable_args(); + mutable_args.reserve(sizeof...(Args)); + (mutable_args.push_back(std::forward(args)), ...); + return overload_decl; +} + +// Prefer the version of `MakeOverloadDecl` that does not specify the id. +// This version is less robust than the version that automatically generates a +// descriptive overload id at the time the overload is added to the function +// declaration. +template +OverloadDecl MakeOverloadDecl(absl::string_view id, Type result, + Args&&... args) { + OverloadDecl overload_decl; + overload_decl.set_id(std::string(id)); + overload_decl.set_result(std::move(result)); + overload_decl.set_member(false); + auto& mutable_args = overload_decl.mutable_args(); + mutable_args.reserve(sizeof...(Args)); + (mutable_args.push_back(std::forward(args)), ...); + return overload_decl; +} + +template +OverloadDecl MakeMemberOverloadDecl(Type result, Args&&... args) { + OverloadDecl overload_decl; + overload_decl.set_result(std::move(result)); + overload_decl.set_member(true); + auto& mutable_args = overload_decl.mutable_args(); + mutable_args.reserve(sizeof...(Args)); + (mutable_args.push_back(std::forward(args)), ...); + return overload_decl; +} + +// Avoid this version of `MakeMemberOverloadDecl`, it is less robust than the +// version that automatically generates a descriptive overload id at the time +// the overload is added to the function declaration. +template +OverloadDecl MakeMemberOverloadDecl(absl::string_view id, Type result, + Args&&... args) { + OverloadDecl overload_decl; + overload_decl.set_id(std::string(id)); + overload_decl.set_result(std::move(result)); + overload_decl.set_member(true); + auto& mutable_args = overload_decl.mutable_args(); + mutable_args.reserve(sizeof...(Args)); + (mutable_args.push_back(std::forward(args)), ...); + return overload_decl; +} + +template +absl::StatusOr MakeFunctionDecl(std::string name, + Overloads&&... overloads); + +// `FunctionDecl` represents a function declaration. +class FunctionDecl final { + public: + FunctionDecl() = default; + FunctionDecl(const FunctionDecl&) = default; + FunctionDecl(FunctionDecl&&) = default; + FunctionDecl& operator=(const FunctionDecl&) = default; + FunctionDecl& operator=(FunctionDecl&&) = default; + + ABSL_MUST_USE_RESULT const std::string& name() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return name_; + } + + void set_name(std::string name) { name_ = std::move(name); } + + void set_name(absl::string_view name) { + name_.assign(name.data(), name.size()); + } + + void set_name(const char* name) { set_name(absl::NullSafeStringView(name)); } + + ABSL_MUST_USE_RESULT std::string release_name() { + std::string released; + released.swap(name_); + return released; + } + + absl::Status AddOverload(const OverloadDecl& overload) { + absl::Status status; + AddOverloadImpl(overload, status); + return status; + } + + absl::Status AddOverload(OverloadDecl&& overload) { + absl::Status status; + AddOverloadImpl(std::move(overload), status); + return status; + } + + ABSL_MUST_USE_RESULT absl::Span overloads() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return overloads_.insertion_order; + } + + ABSL_MUST_USE_RESULT const OverloadDecl* FindOverloadById( + absl::string_view id) const; + + std::vector release_overloads() { + std::vector released = std::move(overloads_.insertion_order); + overloads_.insertion_order.clear(); + overloads_.by_id.clear(); + overloads_.by_signature.clear(); + return released; + } + + private: + struct Overloads { + std::vector insertion_order; + absl::flat_hash_map by_id; + absl::flat_hash_map by_signature; + + void Reserve(size_t size) { + insertion_order.reserve(size); + by_id.reserve(size); + by_signature.reserve(size); + } + }; + + template + friend absl::StatusOr MakeFunctionDecl( + std::string name, Overloads&&... overloads); + + void AddOverloadImpl(const OverloadDecl& overload, absl::Status& status); + void AddOverloadImpl(OverloadDecl&& overload, absl::Status& status); + + std::string name_; + Overloads overloads_; +}; + +inline bool operator==(const FunctionDecl& lhs, const FunctionDecl& rhs) { + return lhs.name() == rhs.name() && + absl::c_equal(lhs.overloads(), rhs.overloads()); +} + +inline bool operator!=(const FunctionDecl& lhs, const FunctionDecl& rhs) { + return !operator==(lhs, rhs); +} + +template +absl::StatusOr MakeFunctionDecl(std::string name, + Overloads&&... overloads) { + FunctionDecl function_decl; + function_decl.set_name(std::move(name)); + function_decl.overloads_.Reserve(sizeof...(Overloads)); + absl::Status status; + (function_decl.AddOverloadImpl(std::forward(overloads), status), + ...); + CEL_RETURN_IF_ERROR(status); + return function_decl; +} + +namespace common_internal { + +// Checks whether `from` is assignable to `to`. +// This can probably be in a better place, it is here currently to ease testing. +bool TypeIsAssignable(const Type& to, const Type& from); + +} // namespace common_internal + +struct VariableDeclEqualTo { + using is_transparent = void; + + bool operator()(const cel::VariableDecl& lhs, + const cel::VariableDecl& rhs) const { + return lhs.name() == rhs.name(); + } + + bool operator()(const cel::VariableDecl& lhs, std::string_view rhs) const { + return lhs.name() == rhs; + } + + bool operator()(std::string_view lhs, const cel::VariableDecl& rhs) const { + return lhs == rhs.name(); + } +}; + +struct VariableDeclHash { + using is_transparent = void; + + size_t operator()(const cel::VariableDecl& decl) const { + return (*this)(decl.name()); + } + + size_t operator()(std::string_view name) const { return absl::HashOf(name); } +}; + +using VariableDeclSet = absl::flat_hash_set; + +struct FunctionDeclEqualTo { + using is_transparent = void; + + bool operator()(const cel::FunctionDecl& lhs, + const cel::FunctionDecl& rhs) const { + return (*this)(lhs.name(), rhs.name()); + } + + bool operator()(const cel::FunctionDecl& lhs, std::string_view rhs) const { + return (*this)(lhs.name(), rhs); + } + + bool operator()(std::string_view lhs, const cel::FunctionDecl& rhs) const { + return (*this)(lhs, rhs.name()); + } + + bool operator()(std::string_view lhs, std::string_view rhs) const { + return lhs == rhs; + } +}; + +struct FunctionDeclHash { + using is_transparent = void; + + size_t operator()(const cel::FunctionDecl& decl) const { + return absl::HashOf(decl.name()); + } + + size_t operator()(std::string_view name) const { return absl::HashOf(name); } +}; + +using FunctionDeclSet = absl::flat_hash_set; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_DECL_H_ diff --git a/common/decl_proto.cc b/common/decl_proto.cc new file mode 100644 index 000000000..098c5068c --- /dev/null +++ b/common/decl_proto.cc @@ -0,0 +1,86 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/decl_proto.h" + +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/variant.h" +#include "common/decl.h" +#include "common/type.h" +#include "common/type_proto.h" +#include "internal/status_macros.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +absl::StatusOr VariableDeclFromProto( + absl::string_view name, const cel::expr::Decl::IdentDecl& variable, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::Arena* absl_nonnull arena) { + CEL_ASSIGN_OR_RETURN(Type type, + TypeFromProto(variable.type(), descriptor_pool, arena)); + return cel::MakeVariableDecl(std::string(name), type); +} + +absl::StatusOr FunctionDeclFromProto( + absl::string_view name, + const cel::expr::Decl::FunctionDecl& function, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::Arena* absl_nonnull arena) { + cel::FunctionDecl decl; + decl.set_name(name); + for (const auto& overload_pb : function.overloads()) { + cel::OverloadDecl ovl_decl; + ovl_decl.set_id(overload_pb.overload_id()); + ovl_decl.set_member(overload_pb.is_instance_function()); + CEL_ASSIGN_OR_RETURN( + cel::Type result, + TypeFromProto(overload_pb.result_type(), descriptor_pool, arena)); + ovl_decl.set_result(result); + std::vector param_types; + param_types.reserve(overload_pb.params_size()); + for (const auto& param_type_pb : overload_pb.params()) { + CEL_ASSIGN_OR_RETURN( + param_types.emplace_back(), + TypeFromProto(param_type_pb, descriptor_pool, arena)); + } + ovl_decl.mutable_args() = std::move(param_types); + CEL_RETURN_IF_ERROR(decl.AddOverload(std::move(ovl_decl))); + } + return decl; +} + +absl::StatusOr> DeclFromProto( + const cel::expr::Decl& decl, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::Arena* absl_nonnull arena) { + if (decl.has_ident()) { + return VariableDeclFromProto(decl.name(), decl.ident(), descriptor_pool, + arena); + } else if (decl.has_function()) { + return FunctionDeclFromProto(decl.name(), decl.function(), descriptor_pool, + arena); + } + return absl::InvalidArgumentError("empty google.api.expr.Decl proto"); +} + +} // namespace cel diff --git a/common/decl_proto.h b/common/decl_proto.h new file mode 100644 index 000000000..3b5744e0e --- /dev/null +++ b/common/decl_proto.h @@ -0,0 +1,50 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_DECL_PROTO_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_DECL_PROTO_H_ + +#include "cel/expr/checked.pb.h" +#include "absl/base/nullability.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/variant.h" +#include "common/decl.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +// Creates a VariableDecl from a google.api.expr.Decl.IdentDecl proto. +absl::StatusOr VariableDeclFromProto( + absl::string_view name, const cel::expr::Decl::IdentDecl& variable, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::Arena* absl_nonnull arena); + +// Creates a FunctionDecl from a google.api.expr.Decl.FunctionDecl proto. +absl::StatusOr FunctionDeclFromProto( + absl::string_view name, + const cel::expr::Decl::FunctionDecl& function, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::Arena* absl_nonnull arena); + +// Creates a VariableDecl or FunctionDecl from a google.api.expr.Decl proto. +absl::StatusOr> DeclFromProto( + const cel::expr::Decl& decl, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::Arena* absl_nonnull arena); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_DECL_PROTO_H_ diff --git a/common/decl_proto_test.cc b/common/decl_proto_test.cc new file mode 100644 index 000000000..d72d97e09 --- /dev/null +++ b/common/decl_proto_test.cc @@ -0,0 +1,147 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "common/decl_proto.h" + +#include + +#include "google/api/expr/v1alpha1/checked.pb.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/types/variant.h" +#include "common/decl.h" +#include "common/decl_proto_v1alpha1.h" +#include "internal/testing.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/text_format.h" + +namespace cel { +namespace { + +using ::absl_testing::StatusIs; + +enum class DeclType { kVariable, kFunction, kInvalid }; + +struct TestCase { + std::string proto_decl; + DeclType decl_type; +}; + +class DeclFromProtoTest : public ::testing::TestWithParam {}; + +TEST_P(DeclFromProtoTest, FromProtoWorks) { + const TestCase& test_case = GetParam(); + google::protobuf::Arena arena; + const google::protobuf::DescriptorPool* descriptor_pool = + google::protobuf::DescriptorPool::generated_pool(); + cel::expr::Decl decl_pb; + ASSERT_TRUE( + google::protobuf::TextFormat::ParseFromString(test_case.proto_decl, &decl_pb)); + absl::StatusOr> decl_or = + DeclFromProto(decl_pb, descriptor_pool, &arena); + switch (test_case.decl_type) { + case DeclType::kVariable: { + ASSERT_OK_AND_ASSIGN(auto decl, decl_or); + EXPECT_TRUE(absl::holds_alternative(decl)); + break; + } + case DeclType::kFunction: { + ASSERT_OK_AND_ASSIGN(auto decl, decl_or); + EXPECT_TRUE(absl::holds_alternative(decl)); + break; + } + case DeclType::kInvalid: { + EXPECT_THAT(decl_or, StatusIs(absl::StatusCode::kInvalidArgument)); + break; + } + } +} + +// Tests that the v1alpha1 proto can be converted to the unversioned proto. +// Same underlying implementation. +TEST_P(DeclFromProtoTest, FromV1Alpha1ProtoWorks) { + const TestCase& test_case = GetParam(); + google::protobuf::Arena arena; + const google::protobuf::DescriptorPool* descriptor_pool = + google::protobuf::DescriptorPool::generated_pool(); + google::api::expr::v1alpha1::Decl decl_pb; + ASSERT_TRUE( + google::protobuf::TextFormat::ParseFromString(test_case.proto_decl, &decl_pb)); + absl::StatusOr> decl_or = + DeclFromV1Alpha1Proto(decl_pb, descriptor_pool, &arena); + switch (test_case.decl_type) { + case DeclType::kVariable: { + ASSERT_OK_AND_ASSIGN(auto decl, decl_or); + EXPECT_TRUE(absl::holds_alternative(decl)); + break; + } + case DeclType::kFunction: { + ASSERT_OK_AND_ASSIGN(auto decl, decl_or); + EXPECT_TRUE(absl::holds_alternative(decl)); + break; + } + case DeclType::kInvalid: { + EXPECT_THAT(decl_or, StatusIs(absl::StatusCode::kInvalidArgument)); + break; + } + } +} + +// TODO(uncreated-issue/80): Add tests for round-trip conversion after the ToProto +// functions are implemented. + +INSTANTIATE_TEST_SUITE_P( + DeclFromProtoTest, DeclFromProtoTest, + testing::Values( + TestCase{ + R"pb( + name: "foo_var" + ident { type { primitive: BOOL } })pb", + DeclType::kVariable}, + TestCase{ + R"pb( + name: "foo_fn" + function { + overloads { + overload_id: "foo_fn_int" + params { primitive: INT64 } + result_type { primitive: BOOL } + } + overloads { + overload_id: "int_foo_fn" + is_instance_function: true + params { primitive: INT64 } + result_type { primitive: BOOL } + } + overloads { + overload_id: "foo_fn_T" + params { type_param: "T" } + type_params: "T" + result_type { primitive: BOOL } + } + + })pb", + DeclType::kFunction}, + // Need a descriptor to lookup a struct type. + TestCase{ + R"pb( + name: "foo_fn" + ident { type { message_type: "com.example.UnknownType" } })pb", + DeclType::kInvalid}, + // Empty decl is invalid. + TestCase{R"pb(name: "foo_fn")pb", DeclType::kInvalid})); + +} // namespace +} // namespace cel diff --git a/common/decl_proto_v1alpha1.cc b/common/decl_proto_v1alpha1.cc new file mode 100644 index 000000000..a8d73e5c2 --- /dev/null +++ b/common/decl_proto_v1alpha1.cc @@ -0,0 +1,67 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "common/decl_proto_v1alpha1.h" + +#include "cel/expr/checked.pb.h" +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/variant.h" +#include "common/decl.h" +#include "common/decl_proto.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +absl::StatusOr VariableDeclFromV1Alpha1Proto( + absl::string_view name, + const google::api::expr::v1alpha1::Decl::IdentDecl& variable, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::Arena* absl_nonnull arena) { + cel::expr::Decl::IdentDecl unversioned; + if (!unversioned.MergeFromString(variable.SerializeAsString())) { + return absl::InternalError( + "failed to convert versioned to unversioned Decl proto"); + } + return VariableDeclFromProto(name, unversioned, descriptor_pool, arena); +} + +absl::StatusOr FunctionDeclFromV1Alpha1Proto( + absl::string_view name, + const google::api::expr::v1alpha1::Decl::FunctionDecl& function, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::Arena* absl_nonnull arena) { + cel::expr::Decl::FunctionDecl unversioned; + if (!unversioned.MergeFromString(function.SerializeAsString())) { + return absl::InternalError( + "failed to convert versioned to unversioned Decl proto"); + } + return FunctionDeclFromProto(name, unversioned, descriptor_pool, arena); +} + +absl::StatusOr> DeclFromV1Alpha1Proto( + const google::api::expr::v1alpha1::Decl& decl, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::Arena* absl_nonnull arena) { + cel::expr::Decl unversioned; + if (!unversioned.MergeFromString(decl.SerializeAsString())) { + return absl::InternalError( + "failed to convert versioned to unversioned Decl proto"); + } + return DeclFromProto(unversioned, descriptor_pool, arena); +} + +} // namespace cel diff --git a/common/decl_proto_v1alpha1.h b/common/decl_proto_v1alpha1.h new file mode 100644 index 000000000..449c921b5 --- /dev/null +++ b/common/decl_proto_v1alpha1.h @@ -0,0 +1,55 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Converters to/from versioned Decl protos to the equivalent CEL C++ types. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_DECL_PROTO_V1ALPHA1_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_DECL_PROTO_V1ALPHA1_H_ + +#include "google/api/expr/v1alpha1/checked.pb.h" +#include "absl/base/nullability.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/variant.h" +#include "common/decl.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +// Creates a VariableDecl from a google.api.expr.v1alpha1.Decl.IdentDecl proto. +absl::StatusOr VariableDeclFromV1Alpha1Proto( + absl::string_view name, + const google::api::expr::v1alpha1::Decl::IdentDecl& variable, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::Arena* absl_nonnull arena); + +// Creates a FunctionDecl from a google.api.expr.v1alpha1.Decl.FunctionDecl +// proto. +absl::StatusOr FunctionDeclFromV1Alpha1Proto( + absl::string_view name, + const google::api::expr::v1alpha1::Decl::FunctionDecl& function, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::Arena* absl_nonnull arena); + +// Creates a VariableDecl or FunctionDecl from a google.api.expr.v1alpha1.Decl +// proto. +absl::StatusOr> DeclFromV1Alpha1Proto( + const google::api::expr::v1alpha1::Decl& decl, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::Arena* absl_nonnull arena); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_DECL_PROTO_V1ALPHA1_H_ diff --git a/common/decl_test.cc b/common/decl_test.cc new file mode 100644 index 000000000..72e7f1b93 --- /dev/null +++ b/common/decl_test.cc @@ -0,0 +1,317 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/decl.h" + +#include +#include + +#include "absl/log/die_if_null.h" // IWYU pragma: keep +#include "absl/status/status.h" +#include "common/constant.h" +#include "common/type.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "google/protobuf/arena.h" + +namespace cel { +namespace { + +using ::absl_testing::StatusIs; +using ::cel::internal::GetTestingDescriptorPool; +using ::testing::ElementsAre; +using ::testing::IsEmpty; +using ::testing::Property; +using ::testing::UnorderedElementsAre; + +TEST(VariableDecl, Name) { + VariableDecl variable_decl; + EXPECT_THAT(variable_decl.name(), IsEmpty()); + variable_decl.set_name("foo"); + EXPECT_EQ(variable_decl.name(), "foo"); + EXPECT_EQ(variable_decl.release_name(), "foo"); + EXPECT_THAT(variable_decl.name(), IsEmpty()); +} + +TEST(VariableDecl, Type) { + VariableDecl variable_decl; + EXPECT_EQ(variable_decl.type(), DynType{}); + variable_decl.set_type(StringType{}); + EXPECT_EQ(variable_decl.type(), StringType{}); +} + +TEST(VariableDecl, Value) { + VariableDecl variable_decl; + EXPECT_FALSE(variable_decl.has_value()); + EXPECT_EQ(variable_decl.value(), Constant{}); + Constant value; + value.set_bool_value(true); + variable_decl.set_value(value); + EXPECT_TRUE(variable_decl.has_value()); + EXPECT_EQ(variable_decl.value(), value); + EXPECT_EQ(variable_decl.release_value(), value); + EXPECT_EQ(variable_decl.value(), Constant{}); +} + +Constant MakeBoolConstant(bool value) { + Constant constant; + constant.set_bool_value(value); + return constant; +} + +TEST(VariableDecl, Equality) { + VariableDecl variable_decl; + EXPECT_EQ(variable_decl, VariableDecl{}); + variable_decl.mutable_value().set_bool_value(true); + EXPECT_NE(variable_decl, VariableDecl{}); + + EXPECT_EQ(MakeVariableDecl("foo", StringType{}), + MakeVariableDecl("foo", StringType{})); + EXPECT_EQ(MakeVariableDecl("foo", StringType{}), + MakeVariableDecl("foo", StringType{})); + EXPECT_EQ( + MakeConstantVariableDecl("foo", StringType{}, MakeBoolConstant(true)), + MakeConstantVariableDecl("foo", StringType{}, MakeBoolConstant(true))); + EXPECT_EQ( + MakeConstantVariableDecl("foo", StringType{}, MakeBoolConstant(true)), + MakeConstantVariableDecl("foo", StringType{}, MakeBoolConstant(true))); +} + +TEST(OverloadDecl, Id) { + OverloadDecl overload_decl; + EXPECT_THAT(overload_decl.id(), IsEmpty()); + overload_decl.set_id("foo"); + EXPECT_EQ(overload_decl.id(), "foo"); + EXPECT_EQ(overload_decl.release_id(), "foo"); + EXPECT_THAT(overload_decl.id(), IsEmpty()); +} + +TEST(OverloadDecl, Result) { + OverloadDecl overload_decl; + EXPECT_EQ(overload_decl.result(), DynType{}); + overload_decl.set_result(StringType{}); + EXPECT_EQ(overload_decl.result(), StringType{}); +} + +TEST(OverloadDecl, Args) { + OverloadDecl overload_decl; + EXPECT_THAT(overload_decl.args(), IsEmpty()); + overload_decl.mutable_args().push_back(StringType{}); + EXPECT_THAT(overload_decl.args(), ElementsAre(StringType{})); + EXPECT_THAT(overload_decl.release_args(), ElementsAre(StringType{})); + EXPECT_THAT(overload_decl.args(), IsEmpty()); +} + +TEST(OverloadDecl, Member) { + OverloadDecl overload_decl; + EXPECT_FALSE(overload_decl.member()); + overload_decl.set_member(true); + EXPECT_TRUE(overload_decl.member()); +} + +TEST(OverloadDecl, Equality) { + OverloadDecl overload_decl; + EXPECT_EQ(overload_decl, OverloadDecl{}); + overload_decl.set_member(true); + EXPECT_NE(overload_decl, OverloadDecl{}); +} + +TEST(OverloadDecl, GetTypeParams) { + google::protobuf::Arena arena; + auto overload_decl = MakeOverloadDecl( + "foo", ListType(&arena, TypeParamType("A")), + MapType(&arena, TypeParamType("B"), TypeParamType("C")), + OpaqueType(&arena, "bar", + {FunctionType(&arena, TypeParamType("D"), {})})); + EXPECT_THAT(overload_decl.GetTypeParams(), + UnorderedElementsAre("A", "B", "C", "D")); +} + +TEST(FunctionDecl, Name) { + FunctionDecl function_decl; + EXPECT_THAT(function_decl.name(), IsEmpty()); + function_decl.set_name("foo"); + EXPECT_EQ(function_decl.name(), "foo"); + EXPECT_EQ(function_decl.release_name(), "foo"); + EXPECT_THAT(function_decl.name(), IsEmpty()); +} + +TEST(FunctionDecl, Overloads) { + ASSERT_OK_AND_ASSIGN( + auto function_decl, + MakeFunctionDecl( + "hello", MakeOverloadDecl("foo", StringType{}, StringType{}), + MakeMemberOverloadDecl("bar", StringType{}, StringType{}), + MakeOverloadDecl("baz", IntType{}, IntType{}))); + + EXPECT_THAT(function_decl.overloads(), + ElementsAre(Property(&OverloadDecl::id, "foo"), + Property(&OverloadDecl::id, "bar"), + Property(&OverloadDecl::id, "baz"))); + + EXPECT_THAT(function_decl.AddOverload( + MakeOverloadDecl("qux", DynType{}, StringType{})), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(FunctionDecl, AddOverloadInvalidSignature) { + FunctionDecl function_decl; + function_decl.set_name("foo"); + // Member overload must have at least one argument (the receiver). + // This should fail to add because signature generation fails. + EXPECT_THAT(function_decl.AddOverload(MakeMemberOverloadDecl(StringType{})), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(FunctionDecl, AddOverloadDuplicateId) { + ASSERT_OK_AND_ASSIGN( + auto function_decl, + MakeFunctionDecl("hello", + MakeOverloadDecl("foo", StringType{}, StringType{}))); + // Adding another overload with the same ID "foo" should fail. + EXPECT_THAT( + function_decl.AddOverload(MakeOverloadDecl("foo", IntType{}, IntType{})), + StatusIs(absl::StatusCode::kAlreadyExists)); +} + +TEST(FunctionDecl, FindOverload) { + ASSERT_OK_AND_ASSIGN( + auto function_decl, + MakeFunctionDecl( + "hello", MakeOverloadDecl("foo", StringType{}, StringType{}), + MakeMemberOverloadDecl("bar", StringType{}, StringType{}), + MakeOverloadDecl(IntType{}, IntType{}))); + + // Find by explicit ID + const OverloadDecl* overload = function_decl.FindOverloadById("foo"); + ASSERT_NE(overload, nullptr); + EXPECT_EQ(overload->id(), "foo"); + + // Find by ID fallback to signature + overload = function_decl.FindOverloadById("hello(string)"); + ASSERT_NE(overload, nullptr); + EXPECT_EQ(overload->id(), "foo"); + + // Find implicit overload (where ID == signature) + overload = function_decl.FindOverloadById("hello(int)"); + ASSERT_NE(overload, nullptr); + EXPECT_EQ(overload->id(), "hello(int)"); + + // Non-existent + EXPECT_EQ(function_decl.FindOverloadById("non_existent"), nullptr); +} + +TEST(FunctionDecl, OverloadId) { + google::protobuf::Arena arena; + const auto* descriptor = + ABSL_DIE_IF_NULL(GetTestingDescriptorPool()->FindMessageTypeByName( + "cel.expr.conformance.proto3.TestAllTypes")); + + ASSERT_OK_AND_ASSIGN( + auto function_decl, + MakeFunctionDecl( + "hello", MakeOverloadDecl(DoubleType{}), + MakeOverloadDecl(StringType{}, StringType{}), + MakeOverloadDecl(IntType{}, IntType{}, UintType{}), + MakeOverloadDecl(IntType{}, ListType(&arena, TypeParamType("A"))), + MakeOverloadDecl(IntType{}, MapType(&arena, TypeParamType("B"), + TypeParamType("C"))), + MakeOverloadDecl( + IntType{}, + OpaqueType(&arena, "bar", + {FunctionType(&arena, TypeParamType("D"), {})})), + MakeOverloadDecl(IntType{}, AnyType{}), + MakeOverloadDecl(IntType{}, DurationType{}), + MakeOverloadDecl(IntType{}, TimestampType{}), + MakeOverloadDecl(IntType{}, IntWrapperType{}), + MakeOverloadDecl(IntType{}, MessageType(descriptor)), + MakeMemberOverloadDecl(StringType{}, StringType{}), + MakeMemberOverloadDecl(StringType{}, StringType{}, + ListType(&arena, BoolType{})), + MakeMemberOverloadDecl(StringType{}, StringType{}, BoolType{}, + DynType{}))); + + EXPECT_THAT( + function_decl.overloads(), + ElementsAre(Property(&OverloadDecl::id, "hello()"), + Property(&OverloadDecl::id, "hello(string)"), + Property(&OverloadDecl::id, "hello(int,uint)"), + Property(&OverloadDecl::id, "hello(list<~A>)"), + Property(&OverloadDecl::id, "hello(map<~B,~C>)"), + Property(&OverloadDecl::id, "hello(bar>)"), + Property(&OverloadDecl::id, "hello(any)"), + Property(&OverloadDecl::id, "hello(duration)"), + Property(&OverloadDecl::id, "hello(timestamp)"), + Property(&OverloadDecl::id, "hello(int_wrapper)"), + Property(&OverloadDecl::id, + "hello(cel.expr.conformance.proto3.TestAllTypes)"), + Property(&OverloadDecl::id, "string.hello()"), + Property(&OverloadDecl::id, "string.hello(list)"), + Property(&OverloadDecl::id, "string.hello(bool,dyn)"))); +} + +using common_internal::TypeIsAssignable; + +TEST(TypeIsAssignable, BoolWrapper) { + EXPECT_TRUE(TypeIsAssignable(BoolWrapperType{}, BoolWrapperType{})); + EXPECT_TRUE(TypeIsAssignable(BoolWrapperType{}, NullType{})); + EXPECT_TRUE(TypeIsAssignable(BoolWrapperType{}, BoolType{})); + EXPECT_FALSE(TypeIsAssignable(BoolWrapperType{}, DurationType{})); +} + +TEST(TypeIsAssignable, IntWrapper) { + EXPECT_TRUE(TypeIsAssignable(IntWrapperType{}, IntWrapperType{})); + EXPECT_TRUE(TypeIsAssignable(IntWrapperType{}, NullType{})); + EXPECT_TRUE(TypeIsAssignable(IntWrapperType{}, IntType{})); + EXPECT_FALSE(TypeIsAssignable(IntWrapperType{}, DurationType{})); +} + +TEST(TypeIsAssignable, UintWrapper) { + EXPECT_TRUE(TypeIsAssignable(UintWrapperType{}, UintWrapperType{})); + EXPECT_TRUE(TypeIsAssignable(UintWrapperType{}, NullType{})); + EXPECT_TRUE(TypeIsAssignable(UintWrapperType{}, UintType{})); + EXPECT_FALSE(TypeIsAssignable(UintWrapperType{}, DurationType{})); +} + +TEST(TypeIsAssignable, DoubleWrapper) { + EXPECT_TRUE(TypeIsAssignable(DoubleWrapperType{}, DoubleWrapperType{})); + EXPECT_TRUE(TypeIsAssignable(DoubleWrapperType{}, NullType{})); + EXPECT_TRUE(TypeIsAssignable(DoubleWrapperType{}, DoubleType{})); + EXPECT_FALSE(TypeIsAssignable(DoubleWrapperType{}, DurationType{})); +} + +TEST(TypeIsAssignable, BytesWrapper) { + EXPECT_TRUE(TypeIsAssignable(BytesWrapperType{}, BytesWrapperType{})); + EXPECT_TRUE(TypeIsAssignable(BytesWrapperType{}, NullType{})); + EXPECT_TRUE(TypeIsAssignable(BytesWrapperType{}, BytesType{})); + EXPECT_FALSE(TypeIsAssignable(BytesWrapperType{}, DurationType{})); +} + +TEST(TypeIsAssignable, StringWrapper) { + EXPECT_TRUE(TypeIsAssignable(StringWrapperType{}, StringWrapperType{})); + EXPECT_TRUE(TypeIsAssignable(StringWrapperType{}, NullType{})); + EXPECT_TRUE(TypeIsAssignable(StringWrapperType{}, StringType{})); + EXPECT_FALSE(TypeIsAssignable(StringWrapperType{}, DurationType{})); +} + +TEST(TypeIsAssignable, Complex) { + google::protobuf::Arena arena; + EXPECT_TRUE(TypeIsAssignable(OptionalType(&arena, DynType{}), + OptionalType(&arena, StringType{}))); + EXPECT_FALSE(TypeIsAssignable(OptionalType(&arena, BoolType{}), + OptionalType(&arena, StringType{}))); +} + +} // namespace +} // namespace cel diff --git a/common/enum.cc b/common/enum.cc deleted file mode 100644 index 2bf753cc6..000000000 --- a/common/enum.cc +++ /dev/null @@ -1,58 +0,0 @@ -#include "common/enum.h" - -#include "internal/cel_printer.h" - -namespace google { -namespace api { -namespace expr { -namespace common { - -namespace { -using ::google::api::expr::internal::ToCallString; - -struct ValueVisitor { - template - int32_t operator()(const T& value) { - return value.value(); - } -}; - -struct TypeVisitor { - template - EnumType operator()(const T& value) { - return value.type(); - } -}; - -struct ToStringVisitor { - template - std::string operator()(const T& value) { - return value.ToString(); - } -}; - -} // namespace - -std::string UnnamedEnumValue::ToString() const { - return internal::ToCallString(type_.value()->full_name(), value_); -} - -EnumValue::EnumValue(EnumType type, int32_t value) - : data_(NamedEnumValue(type.value()->FindValueByNumber(value))) { - if (absl::get(data_).Handle::value() == nullptr) { - data_ = UnnamedEnumValue(type, value); - } -} - -int32_t EnumValue::value() const { return absl::visit(ValueVisitor(), data_); } - -EnumType EnumValue::type() const { return absl::visit(TypeVisitor(), data_); } - -std::string EnumValue::ToString() const { - return absl::visit(ToStringVisitor(), data_); -} - -} // namespace common -} // namespace expr -} // namespace api -} // namespace google diff --git a/common/enum.h b/common/enum.h deleted file mode 100644 index beb6f5e84..000000000 --- a/common/enum.h +++ /dev/null @@ -1,92 +0,0 @@ -#ifndef THIRD_PARTY_CEL_CPP_COMMON_ENUM_H_ -#define THIRD_PARTY_CEL_CPP_COMMON_ENUM_H_ - -#include "common/type.h" -#include "internal/ref_countable.h" - -namespace google { -namespace api { -namespace expr { -namespace common { - -/** - * A recognized named enum value. - */ -class NamedEnumValue final - : public internal::Handle { - public: - constexpr NamedEnumValue(const google::protobuf::EnumValueDescriptor* desc) - : Handle(desc) {} - - inline int32_t value() const { return value_->number(); } - inline EnumType type() const { return EnumType(value_->type()); } - - inline const std::string& ToString() const { return value_->full_name(); } -}; - -/** - * An unnamed or unrecognized enum value. - * - * Constructed by 'EnumValue'. - */ -class UnnamedEnumValue final : public internal::RefCountable { - public: - ~UnnamedEnumValue() = default; - - inline int32_t value() const { return value_; } - inline EnumType type() const { return type_; } - - std::string ToString() const; - - inline bool operator==(const UnnamedEnumValue& rhs) const { - return value_ == rhs.value_ && type_ == rhs.type_; - } - inline bool operator!=(const UnnamedEnumValue& rhs) const { - return value_ != rhs.value_ || type_ != rhs.type_; - } - - std::size_t hash_code() const { return internal::Hash(value_, type_); } - - private: - friend class EnumValue; - friend class Value; - - constexpr UnnamedEnumValue(EnumType type, int32_t value) - : type_(type), value_(value) {} - - EnumType type_; - int32_t value_; -}; - -class EnumValue final { - public: - // Allow implicit conversion so visitors can overload using EnumValue or - // the explicit class. - constexpr EnumValue(NamedEnumValue value) : data_(value) {} - EnumValue(UnnamedEnumValue value) : data_(value) {} - EnumValue(EnumType type, int32_t value); - - inline bool is_named() const { return data_.index() == 0; } - - inline NamedEnumValue named_value() const { return absl::get<0>(data_); } - - inline const UnnamedEnumValue& unnamed_value() const { - return absl::get<1>(data_); - } - - int32_t value() const; - EnumType type() const; - - std::string ToString() const; - - private: - absl::variant data_; -}; - -} // namespace common -} // namespace expr -} // namespace api -} // namespace google - -#endif // THIRD_PARTY_CEL_CPP_COMMON_ENUM_H_ diff --git a/common/error.cc b/common/error.cc deleted file mode 100644 index 58ab769bf..000000000 --- a/common/error.cc +++ /dev/null @@ -1,66 +0,0 @@ -#include "common/error.h" -#include "google/rpc/code.pb.h" -#include "internal/cel_printer.h" -#include "internal/hash_util.h" - -namespace google { -namespace api { -namespace expr { -namespace common { - -Error::Error(const google::rpc::Status& error) { errors_.insert(error); } - -Error::Error(absl::Span errors) { - for (const auto* error : errors) { - errors_.insert(*error); - } -} - -Error::Error(absl::Span errors) { - for (const auto& error : errors) { - errors_.insert(error); - } -} - -std::size_t Error::hash_code() const { - std::size_t code = internal::kIntegralTypeOffset; - for (const auto& error : errors_) { - code = internal::MixHashNoOrder( - internal::Hash(error.code(), error.message()), code); - } - return code; -} - -const Error::ErrorData& Error::errors() const { return errors_; } - -bool Error::operator==(const Error& rhs) const { - if (this == &rhs) { - return true; - } - if (hash_code() != rhs.hash_code() || errors_.size() != rhs.errors_.size()) { - return false; - } - for (const auto& error : errors_) { - if (rhs.errors_.find(error) == rhs.errors_.end()) { - return false; - } - } - return true; -} - -std::string Error::ToDebugString() const { - std::multiset codes; - for (const auto& error : errors_) { - codes.emplace( - google::rpc::Code_Name(static_cast(error.code()))); - } - - internal::VarSequencePrinter printer; - return printer("Error", internal::RawString{absl::StrJoin( - codes, internal::SetJoinPolicy::kValueDelim)}); -} - -} // namespace common -} // namespace expr -} // namespace api -} // namespace google diff --git a/common/error.h b/common/error.h deleted file mode 100644 index 3a33a8763..000000000 --- a/common/error.h +++ /dev/null @@ -1,53 +0,0 @@ -#ifndef THIRD_PARTY_CEL_CPP_COMMON_ERROR_H_ -#define THIRD_PARTY_CEL_CPP_COMMON_ERROR_H_ - -#include "google/rpc/status.pb.h" -#include "absl/container/node_hash_set.h" -#include "absl/types/span.h" -#include "internal/hash_util.h" -#include "internal/proto_util.h" - -namespace google { -namespace api { -namespace expr { -namespace common { - -/** A CEL Error. */ -class Error { - public: - using ErrorData = absl::node_hash_set; - - explicit Error(const google::rpc::Status& error); - explicit Error(absl::Span errors); - explicit Error(absl::Span errors); - - const ErrorData& errors() const; - - bool operator==(const Error& rhs) const; - inline bool operator!=(const Error& rhs) const { return !(*this == rhs); } - - /** The hash code for this value. */ - std::size_t hash_code() const; - - /** - * A string useful for debugging. - * - * Format may change, and computation may be expensive. - */ - std::string ToDebugString() const; - - private: - ErrorData errors_; -}; - -inline std::ostream& operator<<(std::ostream& os, const Error& value) { - return os << value.ToDebugString(); -} - -} // namespace common -} // namespace expr -} // namespace api -} // namespace google - -#endif // THIRD_PARTY_CEL_CPP_COMMON_ERROR_H_ diff --git a/common/escaping.cc b/common/escaping.cc deleted file mode 100644 index 39aba6d0a..000000000 --- a/common/escaping.cc +++ /dev/null @@ -1,401 +0,0 @@ -#include "common/escaping.h" - -#include "absl/strings/escaping.h" -#include "absl/strings/match.h" -#include "absl/strings/str_format.h" -#include "absl/strings/str_replace.h" - -namespace google { -namespace api { -namespace expr { -namespace parser { - -inline std::pair unhex(char c) { - if ('0' <= c && c <= '9') { - return std::make_pair(c - '0', true); - } - if ('a' <= c && c <= 'f') { - return std::make_pair(c - 'a' + 10, true); - } - if ('A' <= c && c <= 'F') { - return std::make_pair(c - 'A' + 10, true); - } - return std::make_pair(0, false); -} - -// Write the characters from the first code point into output, which must be at -// least 4 bytes long. Return the number of bytes written. -inline int get_utf8(absl::string_view s, char* buffer) { - buffer[0] = s[0]; - if (static_cast(s[0]) < 0x80 || s.size() < 2) return 1; - buffer[1] = s[1]; - if (static_cast(s[0]) < 0xE0 || s.size() < 3) return 2; - buffer[2] = s[2]; - if (static_cast(s[0]) < 0xF0 || s.size() < 4) return 3; - buffer[3] = s[3]; - return 4; -} - -// Write UTF-8 encoding into a buffer, which must be at least 4 bytes long. -// Return the number of bytes written. -inline int encode_utf8(char* buffer, char32_t utf8_char) { - if (utf8_char <= 0x7F) { - *buffer = static_cast(utf8_char); - return 1; - } else if (utf8_char <= 0x7FF) { - buffer[1] = 0x80 | (utf8_char & 0x3F); - utf8_char >>= 6; - buffer[0] = 0xC0 | utf8_char; - return 2; - } else if (utf8_char <= 0xFFFF) { - buffer[2] = 0x80 | (utf8_char & 0x3F); - utf8_char >>= 6; - buffer[1] = 0x80 | (utf8_char & 0x3F); - utf8_char >>= 6; - buffer[0] = 0xE0 | utf8_char; - return 3; - } else { - buffer[3] = 0x80 | (utf8_char & 0x3F); - utf8_char >>= 6; - buffer[2] = 0x80 | (utf8_char & 0x3F); - utf8_char >>= 6; - buffer[1] = 0x80 | (utf8_char & 0x3F); - utf8_char >>= 6; - buffer[0] = 0xF0 | utf8_char; - return 4; - } -} - -// unescape_char takes a string input and returns the following info: -// -// value - the escaped unicode rune at the front of the string. -// encode - the value should be unicode-encoded -// tail - the remainder of the input string. -// err - error value, if the character could not be unescaped. -// -// When encode is true the return value may still fit within a single byte, -// but unicode encoding is attempted which is more expensive than when the -// value is known to self-represent as a single byte. -// -// If is_bytes is set, unescape as a bytes literal so octal and hex escapes -// represent byte values, not unicode code points. -inline std::tuple unescape_char( - absl::string_view s, bool is_bytes) { - char c = s[0]; - - // 1. Character is not an escape sequence. - if (static_cast(c) >= 0x80 && !is_bytes) { - char tmp[5]; - int len = get_utf8(s, tmp); - tmp[len] = '\0'; - return std::make_tuple(std::string(tmp), s.substr(len), ""); - } else if (c != '\\') { - char tmp[2] = {c, '\0'}; - return std::make_tuple(std::string(tmp), s.substr(1), ""); - } - - // 2. Last character is the start of an escape sequence. - if (s.size() <= 1) { - return std::make_tuple("", s, - "unable to unescape string, " - "found '\\' as last character"); - } - - c = s[1]; - s = s.substr(2); - - char32_t value; - bool encode = false; - - // 3. Common escape sequences shared with Google SQL - switch (c) { - case 'a': - value = '\a'; - break; - case 'b': - value = '\b'; - break; - case 'f': - value = '\f'; - break; - case 'n': - value = '\n'; - break; - case 'r': - value = '\r'; - break; - case 't': - value = '\t'; - break; - case 'v': - value = '\v'; - break; - case '\\': - value = '\\'; - break; - case '\'': - value = '\''; - break; - case '"': - value = '"'; - break; - case '`': - value = '`'; - break; - case '?': - value = '?'; - break; - - // 4. Unicode escape sequences, reproduced from `strconv/quote.go` - case 'x': - [[fallthrough]]; - case 'X': - [[fallthrough]]; - case 'u': - [[fallthrough]]; - case 'U': { - int n = 0; - encode = true; - switch (c) { - case 'x': - [[fallthrough]]; - case 'X': - n = 2; - encode = !is_bytes; - break; - case 'u': - n = 4; - if (is_bytes) { - return std::make_tuple("", s, - "unable to unescape string " - "(\\u in bytes)"); - } - break; - case 'U': - n = 8; - if (is_bytes) { - return std::make_tuple("", s, - "unable to unescape string " - "(\\U in bytes)"); - } - break; - } - char32_t v = 0; - if (static_cast(s.size()) < n) { - return std::make_tuple("", s, - "unable to unescape string " - "(string too short after \\xXuU)"); - } - for (int j = 0; j < n; ++j) { - auto x = unhex(s[j]); - if (!x.second) { - return std::make_tuple("", s, - "unable to unescape string " - "(invalid hex)"); - } - v = v << 4 | x.first; - } - s = s.substr(n); - if (!is_bytes && v > 0x0010FFFF) { - return std::make_tuple("", s, - "unable to unescape string" - "(value out of bounds)"); - } - value = v; - break; - } - - // 5. Octal escape sequences, must be three digits \[0-3][0-7][0-7] - case '0': - [[fallthrough]]; - case '1': - [[fallthrough]]; - case '2': - [[fallthrough]]; - case '3': { - if (s.size() < 2) { - return std::make_tuple("", s, - "unable to unescape octal sequence in string"); - } - char32_t v = c - '0'; - for (int j = 0; j < 2; ++j) { - char x = s[j]; - if (x < '0' || x > '7') { - return std::make_tuple("", s, - "unable to unescape octal sequence " - "in string"); - } - v = v * 8 + (x - '0'); - } - if (!is_bytes && v > 0x0010FFFF) { - return std::make_tuple("", s, "unable to unescape string"); - } - value = v; - s = s.substr(2); - encode = !is_bytes; - } break; - - // Unknown escape sequence. - default: - return std::make_tuple("", s, "unable to unescape string"); - } - - if (value < 0x80 || !encode) { - char tmp[2] = {(char)value, '\0'}; - return std::make_tuple(std::string(tmp), s, ""); - } else { - char tmp[5]; - int len = encode_utf8(tmp, value); - tmp[len] = '\0'; - return std::make_tuple(std::string(tmp), s, ""); - } -} - -// Unescape takes a quoted string, unquotes, and unescapes it. -absl::optional unescape(const std::string& s, bool is_bytes) { - // All strings normalize newlines to the \n representation. - std::string value = absl::StrReplaceAll(s, {{"\r\n", "\n"}, {"\r", "\n"}}); - - size_t n = value.size(); - - // Nothing to unescape / decode. - if (n < 2) { - return value; - } - - // Raw string preceded by the 'r|R' prefix. - bool is_raw_literal = false; - if (value[0] == 'r' || value[0] == 'R') { - value = value.substr(1, n - 1); - n = value.size(); - is_raw_literal = true; - } - - // Quoted string of some form, must have same first and last char. - if (value[0] != value[n - 1] || (value[0] != '"' && value[0] != '\'')) { - return absl::optional(); - } - - // Normalize the multi-line CEL string representation to a standard - // Google SQL or Go quoted string, as accepted by CEL. - if (n >= 6) { - if (absl::StartsWith(value, "'''")) { - if (!absl::EndsWith(value, "'''")) { - return absl::optional(); - } - value = "\"" + value.substr(3, n - 6) + "\""; - } else if (absl::StartsWith(value, "\"\"\"")) { - if (!absl::EndsWith(value, "\"\"\"")) { - return absl::optional(); - } - value = "\"" + value.substr(3, n - 6) + "\""; - } - n = value.size(); - } - value = value.substr(1, n - 2); - // If there is nothing to escape, then return. - if (is_raw_literal || (value.find("\\") == std::string::npos)) { - return value; - } - - if (is_bytes) { - // first convert byte values the non-UTF8 way - std::string new_value; - for (std::string::size_type i = 0; i < value.size() - 1; ++i) { - if (value[i] == '\\') { - if (value[i + 1] == 'x' || value[i + 1] == 'X') { - if (i > (std::numeric_limits::max() - 3) || - i + 3 >= value.size()) { - return absl::optional(); - } - char v = 0; - for (int j = 2; j <= 3; ++j) { - auto x = unhex(value[i + j]); - v = v << 4 | x.first; - } - i += 3; - new_value += v; - } else if (value[i + 1] == '0' || value[i + 1] == '1' || - value[i + 1] == '2' || value[i + 1] == '3') { - if (i > (std::numeric_limits::max() - 3) || - i + 3 >= value.size()) { - return absl::optional(); - } - char v = value[i + 1] - '0'; - for (int j = 1; j <= 3; ++j) { - char x = value[i + j]; - if (x < '0' || x > '7') { - return absl::optional(); - } - v = v * 8 + (x - '0'); - } - i += 3; - new_value += v; - } else { - return absl::optional(); - } - } else { - new_value += value[i]; - } - } - value = std::move(new_value); - } - - std::string unescaped; - unescaped.reserve(3 * value.size() / 2); - absl::string_view value_sv(value); - while (!value_sv.empty()) { - std::tuple c = - unescape_char(value_sv, is_bytes); - if (!std::get<2>(c).empty()) { - return absl::optional(); - } - - unescaped.append(std::get<0>(c)); - value_sv = std::get<1>(c); - } - return unescaped; -} - -std::string escapeAndQuote(absl::string_view str) { - const std::string lowerhex = "0123456789abcdef"; - - std::string s; - for (auto c : str) { - switch (c) { - case '\a': - s.append("\\a"); - break; - case '\b': - s.append("\\b"); - break; - case '\f': - s.append("\\f"); - break; - case '\n': - s.append("\\n"); - break; - case '\r': - s.append("\\r"); - break; - case '\t': - s.append("\\t"); - break; - case '\v': - s.append("\\v"); - break; - case '"': - s.append("\\\""); - break; - default: - s += c; - break; - } - } - return absl::StrFormat("\"%s\"", s); -} - -} // namespace parser -} // namespace expr -} // namespace api -} // namespace google diff --git a/common/escaping.h b/common/escaping.h deleted file mode 100644 index 86273486b..000000000 --- a/common/escaping.h +++ /dev/null @@ -1,23 +0,0 @@ -#ifndef THIRD_PARTY_CEL_CPP_PARSER_UNESCAPE_H_ -#define THIRD_PARTY_CEL_CPP_PARSER_UNESCAPE_H_ - -#include "absl/strings/string_view.h" -#include "absl/types/optional.h" - -namespace google { -namespace api { -namespace expr { -namespace parser { - -// Unescape takes a quoted string, unquotes, and unescapes it. -absl::optional unescape(const std::string& s, bool is_bytes); - -// Takes a string, and escapes values according to CEL and quotes -std::string escapeAndQuote(absl::string_view str); - -} // namespace parser -} // namespace expr -} // namespace api -} // namespace google - -#endif // THIRD_PARTY_CEL_CPP_PARSER_UNESCAPE_H_ diff --git a/common/escaping_test.cc b/common/escaping_test.cc deleted file mode 100644 index 8275b48ec..000000000 --- a/common/escaping_test.cc +++ /dev/null @@ -1,104 +0,0 @@ -#include "common/escaping.h" - -#include "gmock/gmock.h" -#include "gtest/gtest.h" - -namespace google { -namespace api { -namespace expr { -namespace parser { -namespace { - -using testing::Eq; -using testing::Ne; - -constexpr char EXPECT_ERROR[] = "--ERROR--"; - -struct TestInfo { - TestInfo(const std::string& I, const std::string& O, bool is_bytes = false) - : I(I), O(O), is_bytes(is_bytes) {} - - // Input string - std::string I; - - // Expected output string - std::string O; - - // Indicator whether this is a byte or text string - bool is_bytes; -}; - -std::vector test_cases = { - {"'hello'", "hello"}, - {R"("")", ""}, - {R"("\\\"")", R"(\")"}, - {R"("\\")", "\\"}, - {"'''x''x'''", "x''x"}, - {R"("""x""x""")", R"(x""x)"}, - {R"(r"")", ""}, - // Octal 303 -> Code point 195 (Ã) - // Octal 277 -> Code point 191 (¿) - {R"("\303\277")", "ÿ"}, - // Octal 377 -> Code point 255 (ÿ) - {R"("\377")", "ÿ"}, - {R"("\u263A\u263A")", "☺☺"}, - {R"("\a\b\f\n\r\t\v\'\"\\\? Legal escapes")", - "\a\b\f\n\r\t\v'\"\\? Legal escapes"}, - // Illegal escape, expect error - {R"("\a\b\f\n\r\t\v\'\\"\\\? Illegal escape \>")", EXPECT_ERROR}, - {R"("\u1")", EXPECT_ERROR}, - - // The following are interpreted as byte sequences, hence "true" - {"\"abc\"", "\x61\x62\x63", true}, - {"\"ÿ\"", "\xc3\xbf", true}, - {R"("\303\277")", "\xc3\xbf", true}, - {R"("\377")", "\xff", true}, - {R"("\xc3\xbf")", "\xc3\xbf", true}, - {R"("\xff")", "\xff", true}, - // Bytes unicode escape, expect error - {R"("\u00ff")", EXPECT_ERROR, true}, - {R"("\z")", EXPECT_ERROR, true}, - {R"("\x1")", EXPECT_ERROR, true}, - {R"("\u1")", EXPECT_ERROR, true}, -}; - -class UnescapeTest : public testing::TestWithParam {}; - -TEST_P(UnescapeTest, Unescape) { - const TestInfo& test_info = GetParam(); - /* - ::testing::internal::ColoredPrintf(::testing::internal::COLOR_GREEN, - "[ ]"); - ::testing::internal::ColoredPrintf(::testing::internal::COLOR_DEFAULT, - " Input: "); - ::testing::internal::ColoredPrintf(::testing::internal::COLOR_YELLOW, "%s%s", - test_info.I.c_str(), - test_info.is_bytes ? " BYTES" : ""); - if (test_info.O != EXPECT_ERROR) { - ::testing::internal::ColoredPrintf(::testing::internal::COLOR_DEFAULT, - " Expected Output: "); - ::testing::internal::ColoredPrintf(::testing::internal::COLOR_YELLOW, - "%s\n", test_info.O.c_str()); - } else { - ::testing::internal::ColoredPrintf(::testing::internal::COLOR_YELLOW, - " Expecting ERROR\n"); - } - */ - - auto result = unescape(test_info.I, test_info.is_bytes); - if (test_info.O == EXPECT_ERROR) { - EXPECT_THAT(result, Eq(absl::nullopt)); - } else { - ASSERT_THAT(result, Ne(absl::nullopt)); - EXPECT_EQ(*result, test_info.O); - } -} - -INSTANTIATE_TEST_SUITE_P(UnescapeSuite, UnescapeTest, - testing::ValuesIn(test_cases)); - -} // namespace -} // namespace parser -} // namespace expr -} // namespace api -} // namespace google diff --git a/common/expr.cc b/common/expr.cc new file mode 100644 index 000000000..b9ee29d3b --- /dev/null +++ b/common/expr.cc @@ -0,0 +1,320 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/expr.h" + +#include +#include + +#include "absl/base/no_destructor.h" +#include "absl/functional/overload.h" +#include "absl/log/absl_check.h" +#include "absl/types/variant.h" +#include "common/constant.h" + +namespace cel { + +namespace { + +struct CopyStackRecord { + const Expr* src; + Expr* dst; +}; + +void CopyNode(CopyStackRecord element, std::vector& stack) { + const Expr* src = element.src; + Expr* dst = element.dst; + dst->set_id(src->id()); + absl::visit( + absl::Overload( + [=](const UnspecifiedExpr&) { + dst->mutable_kind().emplace(); + }, + [=](const IdentExpr& i) { + dst->mutable_ident_expr().set_name(i.name()); + }, + [=](const Constant& c) { dst->mutable_const_expr() = c; }, + [&](const SelectExpr& s) { + dst->mutable_select_expr().set_field(s.field()); + dst->mutable_select_expr().set_test_only(s.test_only()); + + if (s.has_operand()) { + stack.push_back({&s.operand(), + &dst->mutable_select_expr().mutable_operand()}); + } + }, + [&](const CallExpr& c) { + dst->mutable_call_expr().set_function(c.function()); + if (c.has_target()) { + stack.push_back( + {&c.target(), &dst->mutable_call_expr().mutable_target()}); + } + dst->mutable_call_expr().mutable_args().resize(c.args().size()); + for (int i = 0; i < dst->mutable_call_expr().mutable_args().size(); + ++i) { + stack.push_back( + {&c.args()[i], &dst->mutable_call_expr().mutable_args()[i]}); + } + }, + [&](const ListExpr& c) { + auto& dst_list = dst->mutable_list_expr(); + dst_list.mutable_elements().resize(c.elements().size()); + for (int i = 0; i < src->list_expr().elements().size(); ++i) { + dst_list.mutable_elements()[i].set_optional( + c.elements()[i].optional()); + stack.push_back({&c.elements()[i].expr(), + &dst_list.mutable_elements()[i].mutable_expr()}); + } + }, + [&](const StructExpr& s) { + auto& dst_struct = dst->mutable_struct_expr(); + dst_struct.mutable_fields().resize(s.fields().size()); + dst_struct.set_name(s.name()); + for (int i = 0; i < s.fields().size(); ++i) { + dst_struct.mutable_fields()[i].set_optional( + s.fields()[i].optional()); + dst_struct.mutable_fields()[i].set_name(s.fields()[i].name()); + dst_struct.mutable_fields()[i].set_id(s.fields()[i].id()); + stack.push_back( + {&s.fields()[i].value(), + &dst_struct.mutable_fields()[i].mutable_value()}); + } + }, + [&](const MapExpr& c) { + auto& dst_map = dst->mutable_map_expr(); + dst_map.mutable_entries().resize(c.entries().size()); + for (int i = 0; i < c.entries().size(); ++i) { + dst_map.mutable_entries()[i].set_optional( + c.entries()[i].optional()); + dst_map.mutable_entries()[i].set_id(c.entries()[i].id()); + stack.push_back({&c.entries()[i].key(), + &dst_map.mutable_entries()[i].mutable_key()}); + stack.push_back({&c.entries()[i].value(), + &dst_map.mutable_entries()[i].mutable_value()}); + } + }, + [&](const ComprehensionExpr& c) { + auto& dst_comprehension = dst->mutable_comprehension_expr(); + dst_comprehension.set_iter_var(c.iter_var()); + dst_comprehension.set_iter_var2(c.iter_var2()); + dst_comprehension.set_accu_var(c.accu_var()); + if (c.has_accu_init()) { + stack.push_back( + {&c.accu_init(), &dst_comprehension.mutable_accu_init()}); + } + if (c.has_iter_range()) { + stack.push_back( + {&c.iter_range(), &dst_comprehension.mutable_iter_range()}); + } + if (c.has_loop_condition()) { + stack.push_back({&c.loop_condition(), + &dst_comprehension.mutable_loop_condition()}); + } + if (c.has_loop_step()) { + stack.push_back( + {&c.loop_step(), &dst_comprehension.mutable_loop_step()}); + } + if (c.has_result()) { + stack.push_back( + {&c.result(), &dst_comprehension.mutable_result()}); + } + }), + src->kind()); +} + +void CloneImpl(const Expr& expr, Expr& dst) { + std::vector stack; + stack.push_back({&expr, &dst}); + while (!stack.empty()) { + CopyStackRecord element = stack.back(); + stack.pop_back(); + CopyNode(element, stack); + } +} + +} // namespace + +const UnspecifiedExpr& UnspecifiedExpr::default_instance() { + static const absl::NoDestructor instance; + return *instance; +} + +const IdentExpr& IdentExpr::default_instance() { + static const absl::NoDestructor instance; + return *instance; +} + +const SelectExpr& SelectExpr::default_instance() { + static const absl::NoDestructor instance; + return *instance; +} + +const CallExpr& CallExpr::default_instance() { + static const absl::NoDestructor instance; + return *instance; +} + +const ListExpr& ListExpr::default_instance() { + static const absl::NoDestructor instance; + return *instance; +} + +const StructExpr& StructExpr::default_instance() { + static const absl::NoDestructor instance; + return *instance; +} + +const MapExpr& MapExpr::default_instance() { + static const absl::NoDestructor instance; + return *instance; +} + +const ComprehensionExpr& ComprehensionExpr::default_instance() { + static const absl::NoDestructor instance; + return *instance; +} + +const Expr& Expr::default_instance() { + static const absl::NoDestructor instance; + return *instance; +} + +Expr& Expr::operator=(const Expr& other) { + if (this == &other) { + return *this; + } + Expr tmp; + CloneImpl(other, tmp); + *this = std::move(tmp); + return *this; +} + +Expr::Expr(const Expr& other) { CloneImpl(other, *this); } + +namespace common_internal { +struct ExprEraseTag {}; +} // namespace common_internal + +namespace { +void Expand(Expr** tail, Expr* cur) { + static common_internal::ExprEraseTag tag; + switch (cur->kind_case()) { + case ExprKindCase::kSelectExpr: { + SelectExpr& select = cur->mutable_select_expr(); + if (select.has_operand()) { + select.mutable_operand().SetNext(tag, *tail); + *tail = &select.mutable_operand(); + } + break; + } + case ExprKindCase::kCallExpr: { + CallExpr& call = cur->mutable_call_expr(); + if (call.has_target()) { + call.mutable_target().SetNext(tag, *tail); + *tail = &call.mutable_target(); + } + for (auto& arg : call.mutable_args()) { + arg.SetNext(tag, *tail); + *tail = &arg; + } + break; + } + case ExprKindCase::kListExpr: { + for (auto& arg : cur->mutable_list_expr().mutable_elements()) { + arg.mutable_expr().SetNext(tag, *tail); + *tail = &arg.mutable_expr(); + } + break; + } + case ExprKindCase::kStructExpr: { + for (auto& field : cur->mutable_struct_expr().mutable_fields()) { + field.mutable_value().SetNext(tag, *tail); + *tail = &field.mutable_value(); + } + break; + } + case ExprKindCase::kMapExpr: { + for (auto& entry : cur->mutable_map_expr().mutable_entries()) { + entry.mutable_key().SetNext(tag, *tail); + *tail = &entry.mutable_key(); + entry.mutable_value().SetNext(tag, *tail); + *tail = &entry.mutable_value(); + } + break; + } + case ExprKindCase::kComprehensionExpr: { + if (cur->comprehension_expr().has_accu_init()) { + cur->mutable_comprehension_expr().mutable_accu_init().SetNext(tag, + *tail); + *tail = &cur->mutable_comprehension_expr().mutable_accu_init(); + } + if (cur->comprehension_expr().has_iter_range()) { + cur->mutable_comprehension_expr().mutable_iter_range().SetNext(tag, + *tail); + *tail = &cur->mutable_comprehension_expr().mutable_iter_range(); + } + if (cur->comprehension_expr().has_loop_condition()) { + cur->mutable_comprehension_expr().mutable_loop_condition().SetNext( + tag, *tail); + *tail = &cur->mutable_comprehension_expr().mutable_loop_condition(); + } + if (cur->comprehension_expr().has_loop_step()) { + cur->mutable_comprehension_expr().mutable_loop_step().SetNext(tag, + *tail); + *tail = &cur->mutable_comprehension_expr().mutable_loop_step(); + } + if (cur->comprehension_expr().has_result()) { + cur->mutable_comprehension_expr().mutable_result().SetNext(tag, *tail); + *tail = &cur->mutable_comprehension_expr().mutable_result(); + } + break; + } + default: + // Leaf node, nothing to expand. + // Also a fallback in case we add a new node type. + // Note: already in the deleter list so we can't delete now, will be + // deleted after ordering the AST. + break; + } +} +} // namespace + +void Expr::FlattenedErase() { + // High level idea is to build a topological ordering of the AST, then erase + // leaf to root. + this->u_.next = nullptr; + Expr* prev_tail = nullptr; + Expr* tail = this; + + while (tail != prev_tail) { + Expr* next_prev_tail = tail; + Expr* expand_ptr = tail; + while (expand_ptr != prev_tail) { + ABSL_DCHECK(expand_ptr != nullptr); // Linked list is broken or changed. + Expr* next_expand_ptr = expand_ptr->u_.next; + Expand(&tail, expand_ptr); + expand_ptr = next_expand_ptr; + } + prev_tail = next_prev_tail; + } + + Expr* node = tail; + while (node != nullptr) { + Expr* next = node->u_.next; + node->Clear(); + node = next; + } +} + +} // namespace cel diff --git a/common/expr.h b/common/expr.h new file mode 100644 index 000000000..7305c2c9f --- /dev/null +++ b/common/expr.h @@ -0,0 +1,1720 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_EXPR_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_EXPR_H_ + +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/attributes.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "absl/types/variant.h" +#include "common/constant.h" + +namespace cel { + +using ExprId = int64_t; + +class Expr; +class UnspecifiedExpr; +class IdentExpr; +class SelectExpr; +class CallExpr; +class ListExprElement; +class ListExpr; +class StructExprField; +class StructExpr; +class MapExprEntry; +class MapExpr; +class ComprehensionExpr; + +inline constexpr absl::string_view kAccumulatorVariableName = "@result"; +inline constexpr absl::string_view kDeprecatedAccumulatorVariableName = + "__result__"; + +bool operator==(const Expr& lhs, const Expr& rhs); + +inline bool operator!=(const Expr& lhs, const Expr& rhs) { + return !operator==(lhs, rhs); +} + +bool operator==(const ListExprElement& lhs, const ListExprElement& rhs); + +inline bool operator!=(const ListExprElement& lhs, const ListExprElement& rhs) { + return !operator==(lhs, rhs); +} + +bool operator==(const StructExprField& lhs, const StructExprField& rhs); + +inline bool operator!=(const StructExprField& lhs, const StructExprField& rhs) { + return !operator==(lhs, rhs); +} + +bool operator==(const MapExprEntry& lhs, const MapExprEntry& rhs); + +inline bool operator!=(const MapExprEntry& lhs, const MapExprEntry& rhs) { + return !operator==(lhs, rhs); +} + +// `UnspecifiedExpr` is the default alternative of `Expr`. It is used for +// default construction of `Expr` or as a placeholder for when errors occur. +class UnspecifiedExpr final { + public: + UnspecifiedExpr() = default; + UnspecifiedExpr(UnspecifiedExpr&&) = default; + UnspecifiedExpr& operator=(UnspecifiedExpr&&) = default; + + UnspecifiedExpr(const UnspecifiedExpr&) = delete; + UnspecifiedExpr& operator=(const UnspecifiedExpr&) = delete; + + void Clear() {} + + friend void swap(UnspecifiedExpr&, UnspecifiedExpr&) noexcept {} + + private: + friend class Expr; + + static const UnspecifiedExpr& default_instance(); +}; + +inline bool operator==(const UnspecifiedExpr&, const UnspecifiedExpr&) { + return true; +} + +inline bool operator!=(const UnspecifiedExpr& lhs, const UnspecifiedExpr& rhs) { + return !operator==(lhs, rhs); +} + +// `IdentExpr` is an alternative of `Expr`, representing an identifier. +class IdentExpr final { + public: + IdentExpr() = default; + IdentExpr(IdentExpr&&) = default; + IdentExpr& operator=(IdentExpr&&) = default; + + explicit IdentExpr(std::string name) { set_name(std::move(name)); } + + explicit IdentExpr(absl::string_view name) { set_name(name); } + + explicit IdentExpr(const char* name) { set_name(name); } + + IdentExpr(const IdentExpr&) = delete; + IdentExpr& operator=(const IdentExpr&) = delete; + + void Clear() { name_.clear(); } + + // Holds a single, unqualified identifier, possibly preceded by a '.'. + // + // Qualified names are represented by the [Expr.Select][] expression. + ABSL_MUST_USE_RESULT const std::string& name() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return name_; + } + + void set_name(std::string name) { name_ = std::move(name); } + + void set_name(absl::string_view name) { + name_.assign(name.data(), name.size()); + } + + void set_name(const char* name) { set_name(absl::NullSafeStringView(name)); } + + ABSL_MUST_USE_RESULT std::string release_name() { return release(name_); } + + friend void swap(IdentExpr& lhs, IdentExpr& rhs) noexcept { + using std::swap; + swap(lhs.name_, rhs.name_); + } + + private: + friend class Expr; + + static const IdentExpr& default_instance(); + + static std::string release(std::string& property) { + std::string result; + result.swap(property); + return result; + } + + std::string name_; +}; + +inline bool operator==(const IdentExpr& lhs, const IdentExpr& rhs) { + return lhs.name() == rhs.name(); +} + +inline bool operator!=(const IdentExpr& lhs, const IdentExpr& rhs) { + return !operator==(lhs, rhs); +} + +// `SelectExpr` is an alternative of `Expr`, representing field access. +class SelectExpr final { + public: + SelectExpr() = default; + SelectExpr(SelectExpr&&) = default; + SelectExpr& operator=(SelectExpr&&) = default; + + SelectExpr(const SelectExpr&) = delete; + SelectExpr& operator=(const SelectExpr&) = delete; + + void Clear(); + + ABSL_MUST_USE_RESULT bool has_operand() const { return operand_ != nullptr; } + + // The target of the selection expression. + // + // For example, in the select expression `request.auth`, the `request` + // portion of the expression is the `operand`. + ABSL_MUST_USE_RESULT const Expr& operand() const + ABSL_ATTRIBUTE_LIFETIME_BOUND; + + Expr& mutable_operand() ABSL_ATTRIBUTE_LIFETIME_BOUND; + + void set_operand(Expr operand); + + void set_operand(std::unique_ptr operand); + + ABSL_MUST_USE_RESULT std::unique_ptr release_operand(); + + // The name of the field to select. + // + // For example, in the select expression `request.auth`, the `auth` portion + // of the expression would be the `field`. + ABSL_MUST_USE_RESULT const std::string& field() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return field_; + } + + void set_field(std::string field) { field_ = std::move(field); } + + void set_field(absl::string_view field) { + field_.assign(field.data(), field.size()); + } + + void set_field(const char* field) { + set_field(absl::NullSafeStringView(field)); + } + + ABSL_MUST_USE_RESULT std::string release_field() { return release(field_); } + + // Whether the select is to be interpreted as a field presence test. + // + // This results from the macro `has(request.auth)`. + ABSL_MUST_USE_RESULT bool test_only() const { return test_only_; } + + void set_test_only(bool test_only) { test_only_ = test_only; } + + friend void swap(SelectExpr& lhs, SelectExpr& rhs) noexcept { + using std::swap; + swap(lhs.operand_, rhs.operand_); + swap(lhs.field_, rhs.field_); + swap(lhs.test_only_, rhs.test_only_); + } + + private: + friend class Expr; + + static const SelectExpr& default_instance(); + + static std::string release(std::string& property) { + std::string result; + result.swap(property); + return result; + } + + static std::unique_ptr release(std::unique_ptr& property); + + std::unique_ptr operand_; + std::string field_; + bool test_only_ = false; +}; + +inline bool operator==(const SelectExpr& lhs, const SelectExpr& rhs) { + return lhs.operand() == rhs.operand() && lhs.field() == rhs.field() && + lhs.test_only() == rhs.test_only(); +} + +inline bool operator!=(const SelectExpr& lhs, const SelectExpr& rhs) { + return !operator==(lhs, rhs); +} + +// `CallExpr` is an alternative of `Expr`, representing a function call. +class CallExpr final { + public: + CallExpr() = default; + CallExpr(CallExpr&&) = default; + CallExpr& operator=(CallExpr&&) = default; + + CallExpr(const CallExpr&) = delete; + CallExpr& operator=(const CallExpr&) = delete; + + void Clear(); + + // The name of the function or method being called. + ABSL_MUST_USE_RESULT const std::string& function() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return function_; + } + + void set_function(std::string function) { function_ = std::move(function); } + + void set_function(absl::string_view function) { + function_.assign(function.data(), function.size()); + } + + void set_function(const char* function) { + set_function(absl::NullSafeStringView(function)); + } + + ABSL_MUST_USE_RESULT std::string release_function() { + return release(function_); + } + + ABSL_MUST_USE_RESULT bool has_target() const { return target_ != nullptr; } + + // The target of an method call-style expression. For example, `x` in `x.f()`. + ABSL_MUST_USE_RESULT const Expr& target() const ABSL_ATTRIBUTE_LIFETIME_BOUND; + + Expr& mutable_target() ABSL_ATTRIBUTE_LIFETIME_BOUND; + + void set_target(Expr target); + + void set_target(std::unique_ptr target); + + ABSL_MUST_USE_RESULT std::unique_ptr release_target(); + + // The arguments. + ABSL_MUST_USE_RESULT const std::vector& args() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return args_; + } + + ABSL_MUST_USE_RESULT std::vector& mutable_args() + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return args_; + } + + void set_args(std::vector args); + + void set_args(absl::Span args); + + Expr& add_args() ABSL_ATTRIBUTE_LIFETIME_BOUND; + + ABSL_MUST_USE_RESULT std::vector release_args(); + + friend void swap(CallExpr& lhs, CallExpr& rhs) noexcept { + using std::swap; + swap(lhs.function_, rhs.function_); + swap(lhs.target_, rhs.target_); + swap(lhs.args_, rhs.args_); + } + + private: + friend class Expr; + + static const CallExpr& default_instance(); + + static std::string release(std::string& property) { + std::string result; + result.swap(property); + return result; + } + + static std::unique_ptr release(std::unique_ptr& property); + + std::string function_; + std::unique_ptr target_; + std::vector args_; +}; + +bool operator==(const CallExpr& lhs, const CallExpr& rhs); + +inline bool operator!=(const CallExpr& lhs, const CallExpr& rhs) { + return !operator==(lhs, rhs); +} + +// `ListExprElement` represents an element in `ListExpr`. +class ListExprElement final { + public: + ListExprElement() = default; + ListExprElement(ListExprElement&&) = default; + ListExprElement& operator=(ListExprElement&&) = default; + + ListExprElement(const ListExprElement&) = delete; + ListExprElement& operator=(const ListExprElement&) = delete; + + void Clear(); + + ABSL_MUST_USE_RESULT bool has_expr() const { return expr_ != nullptr; } + + ABSL_MUST_USE_RESULT const Expr& expr() const ABSL_ATTRIBUTE_LIFETIME_BOUND; + + ABSL_MUST_USE_RESULT Expr& mutable_expr() ABSL_ATTRIBUTE_LIFETIME_BOUND; + + void set_expr(Expr expr); + + void set_expr(std::unique_ptr expr); + + ABSL_MUST_USE_RESULT Expr release_expr(); + + ABSL_MUST_USE_RESULT bool optional() const { return optional_; } + + void set_optional(bool optional) { optional_ = optional; } + + friend void swap(ListExprElement& lhs, ListExprElement& rhs) noexcept; + + private: + static Expr release(std::unique_ptr& property); + + std::unique_ptr expr_; + bool optional_ = false; +}; + +// `ListExpr` is an alternative of `Expr`, representing a list. +class ListExpr final { + public: + ListExpr() = default; + ListExpr(ListExpr&&) = default; + ListExpr& operator=(ListExpr&&) = default; + + ListExpr(const ListExpr&) = delete; + ListExpr& operator=(const ListExpr&) = delete; + + void Clear(); + + // The elements of the list. + ABSL_MUST_USE_RESULT const std::vector& elements() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return elements_; + } + + ABSL_MUST_USE_RESULT std::vector& mutable_elements() + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return elements_; + } + + void set_elements(std::vector elements); + + void set_elements(absl::Span elements); + + ListExprElement& add_elements() ABSL_ATTRIBUTE_LIFETIME_BOUND; + + ABSL_MUST_USE_RESULT std::vector release_elements(); + + friend void swap(ListExpr& lhs, ListExpr& rhs) noexcept { + using std::swap; + swap(lhs.elements_, rhs.elements_); + } + + private: + friend class Expr; + + static const ListExpr& default_instance(); + + std::vector elements_; +}; + +bool operator==(const ListExpr& lhs, const ListExpr& rhs); + +inline bool operator!=(const ListExpr& lhs, const ListExpr& rhs) { + return !operator==(lhs, rhs); +} + +// `StructExprField` represents a field in `StructExpr`. +class StructExprField final { + public: + StructExprField() = default; + StructExprField(StructExprField&&) = default; + StructExprField& operator=(StructExprField&&) = default; + + StructExprField(const StructExprField&) = delete; + StructExprField& operator=(const StructExprField&) = delete; + + void Clear(); + + ABSL_MUST_USE_RESULT ExprId id() const { return id_; } + + void set_id(ExprId id) { id_ = id; } + + ABSL_MUST_USE_RESULT const std::string& name() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return name_; + } + + void set_name(std::string name) { name_ = std::move(name); } + + void set_name(absl::string_view name) { + name_.assign(name.data(), name.size()); + } + + void set_name(const char* name) { set_name(absl::NullSafeStringView(name)); } + + ABSL_MUST_USE_RESULT std::string release_name() { return std::move(name_); } + + ABSL_MUST_USE_RESULT bool has_value() const { return value_ != nullptr; } + + ABSL_MUST_USE_RESULT const Expr& value() const ABSL_ATTRIBUTE_LIFETIME_BOUND; + + ABSL_MUST_USE_RESULT Expr& mutable_value() ABSL_ATTRIBUTE_LIFETIME_BOUND; + + void set_value(Expr value); + + void set_value(std::unique_ptr value); + + ABSL_MUST_USE_RESULT Expr release_value(); + + ABSL_MUST_USE_RESULT bool optional() const { return optional_; } + + void set_optional(bool optional) { optional_ = optional; } + + friend void swap(StructExprField& lhs, StructExprField& rhs) noexcept; + + private: + static Expr release(std::unique_ptr& property); + + ExprId id_ = 0; + std::string name_; + std::unique_ptr value_; + bool optional_ = false; +}; + +// `StructExpr` is an alternative of `Expr`, representing a struct. +class StructExpr final { + public: + StructExpr() = default; + StructExpr(StructExpr&&) = default; + StructExpr& operator=(StructExpr&&) = default; + + StructExpr(const StructExpr&) = delete; + StructExpr& operator=(const StructExpr&) = delete; + + void Clear(); + + // The type name of the struct to be created. + ABSL_MUST_USE_RESULT const std::string& name() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return name_; + } + + void set_name(std::string name) { name_ = std::move(name); } + + void set_name(absl::string_view name) { + name_.assign(name.data(), name.size()); + } + + void set_name(const char* name) { set_name(absl::NullSafeStringView(name)); } + + ABSL_MUST_USE_RESULT std::string release_name() { return release(name_); } + + // The fields of the struct. + ABSL_MUST_USE_RESULT const std::vector& fields() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return fields_; + } + + ABSL_MUST_USE_RESULT std::vector& mutable_fields() + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return fields_; + } + + void set_fields(std::vector fields); + + void set_fields(absl::Span fields); + + StructExprField& add_fields() ABSL_ATTRIBUTE_LIFETIME_BOUND; + + ABSL_MUST_USE_RESULT std::vector release_fields(); + + friend void swap(StructExpr& lhs, StructExpr& rhs) noexcept { + using std::swap; + swap(lhs.name_, rhs.name_); + swap(lhs.fields_, rhs.fields_); + } + + private: + friend class Expr; + + static const StructExpr& default_instance(); + + static std::string release(std::string& property) { + std::string result; + result.swap(property); + return result; + } + + std::string name_; + std::vector fields_; +}; + +bool operator==(const StructExpr& lhs, const StructExpr& rhs); + +inline bool operator!=(const StructExpr& lhs, const StructExpr& rhs) { + return !operator==(lhs, rhs); +} + +// `MapExprEntry` represents an entry in `MapExpr`. +class MapExprEntry final { + public: + MapExprEntry() = default; + MapExprEntry(MapExprEntry&&) = default; + MapExprEntry& operator=(MapExprEntry&&) = default; + + MapExprEntry(const MapExprEntry&) = delete; + MapExprEntry& operator=(const MapExprEntry&) = delete; + + void Clear(); + + ABSL_MUST_USE_RESULT ExprId id() const { return id_; } + + void set_id(ExprId id) { id_ = id; } + + ABSL_MUST_USE_RESULT bool has_key() const { return key_ != nullptr; } + + ABSL_MUST_USE_RESULT const Expr& key() const ABSL_ATTRIBUTE_LIFETIME_BOUND; + + ABSL_MUST_USE_RESULT Expr& mutable_key() ABSL_ATTRIBUTE_LIFETIME_BOUND; + + void set_key(Expr key); + + void set_key(std::unique_ptr key); + + ABSL_MUST_USE_RESULT Expr release_key(); + + ABSL_MUST_USE_RESULT bool has_value() const { return value_ != nullptr; } + + ABSL_MUST_USE_RESULT const Expr& value() const ABSL_ATTRIBUTE_LIFETIME_BOUND; + + ABSL_MUST_USE_RESULT Expr& mutable_value() ABSL_ATTRIBUTE_LIFETIME_BOUND; + + void set_value(Expr value); + + void set_value(std::unique_ptr value); + + ABSL_MUST_USE_RESULT Expr release_value(); + + ABSL_MUST_USE_RESULT bool optional() const { return optional_; } + + void set_optional(bool optional) { optional_ = optional; } + + friend void swap(MapExprEntry& lhs, MapExprEntry& rhs) noexcept; + + private: + static Expr release(std::unique_ptr& property); + + ExprId id_ = 0; + std::unique_ptr key_; + std::unique_ptr value_; + bool optional_ = false; +}; + +// `MapExpr` is an alternative of `Expr`, representing a map. +class MapExpr final { + public: + MapExpr() = default; + MapExpr(MapExpr&&) = default; + MapExpr& operator=(MapExpr&&) = default; + + MapExpr(const MapExpr&) = delete; + MapExpr& operator=(const MapExpr&) = delete; + + void Clear(); + + // The entries of the map. + ABSL_MUST_USE_RESULT const std::vector& entries() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return entries_; + } + + ABSL_MUST_USE_RESULT std::vector& mutable_entries() + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return entries_; + } + + void set_entries(std::vector entries); + + void set_entries(absl::Span entries); + + MapExprEntry& add_entries() ABSL_ATTRIBUTE_LIFETIME_BOUND; + + ABSL_MUST_USE_RESULT std::vector release_entries(); + + friend void swap(MapExpr& lhs, MapExpr& rhs) noexcept { + using std::swap; + swap(lhs.entries_, rhs.entries_); + } + + private: + friend class Expr; + + static const MapExpr& default_instance(); + + std::vector entries_; +}; + +bool operator==(const MapExpr& lhs, const MapExpr& rhs); + +inline bool operator!=(const MapExpr& lhs, const MapExpr& rhs) { + return !operator==(lhs, rhs); +} + +// `ComprehensionExpr` is an alternative of `Expr`, representing a +// comprehension. These are always synthetic as there is no way to express them +// directly in the Common Expression Language, and are created by macros. +class ComprehensionExpr final { + public: + ComprehensionExpr() = default; + ComprehensionExpr(ComprehensionExpr&&) = default; + ComprehensionExpr& operator=(ComprehensionExpr&&) = default; + + ComprehensionExpr(const ComprehensionExpr&) = delete; + ComprehensionExpr& operator=(const ComprehensionExpr&) = delete; + + void Clear(); + + ABSL_MUST_USE_RESULT const std::string& iter_var() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return iter_var_; + } + + void set_iter_var(std::string iter_var) { iter_var_ = std::move(iter_var); } + + void set_iter_var(absl::string_view iter_var) { + iter_var_.assign(iter_var.data(), iter_var.size()); + } + + void set_iter_var(const char* iter_var) { + set_iter_var(absl::NullSafeStringView(iter_var)); + } + + ABSL_MUST_USE_RESULT std::string release_iter_var() { + return release(iter_var_); + } + + ABSL_MUST_USE_RESULT const std::string& iter_var2() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return iter_var2_; + } + + void set_iter_var2(std::string iter_var2) { + iter_var2_ = std::move(iter_var2); + } + + void set_iter_var2(absl::string_view iter_var2) { + iter_var2_.assign(iter_var2.data(), iter_var2.size()); + } + + void set_iter_var2(const char* iter_var2) { + set_iter_var2(absl::NullSafeStringView(iter_var2)); + } + + ABSL_MUST_USE_RESULT std::string release_iter_var2() { + return release(iter_var2_); + } + + ABSL_MUST_USE_RESULT bool has_iter_range() const { + return iter_range_ != nullptr; + } + + ABSL_MUST_USE_RESULT const Expr& iter_range() const + ABSL_ATTRIBUTE_LIFETIME_BOUND; + + Expr& mutable_iter_range() ABSL_ATTRIBUTE_LIFETIME_BOUND; + + void set_iter_range(Expr iter_range); + + void set_iter_range(std::unique_ptr iter_range); + + ABSL_MUST_USE_RESULT std::unique_ptr release_iter_range(); + + ABSL_MUST_USE_RESULT const std::string& accu_var() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return accu_var_; + } + + void set_accu_var(std::string accu_var) { accu_var_ = std::move(accu_var); } + + void set_accu_var(absl::string_view accu_var) { + accu_var_.assign(accu_var.data(), accu_var.size()); + } + + void set_accu_var(const char* accu_var) { + set_accu_var(absl::NullSafeStringView(accu_var)); + } + + ABSL_MUST_USE_RESULT std::string release_accu_var() { + return release(accu_var_); + } + + ABSL_MUST_USE_RESULT bool has_accu_init() const { + return accu_init_ != nullptr; + } + + ABSL_MUST_USE_RESULT const Expr& accu_init() const + ABSL_ATTRIBUTE_LIFETIME_BOUND; + + Expr& mutable_accu_init() ABSL_ATTRIBUTE_LIFETIME_BOUND; + + void set_accu_init(Expr accu_init); + + void set_accu_init(std::unique_ptr accu_init); + + ABSL_MUST_USE_RESULT std::unique_ptr release_accu_init(); + + ABSL_MUST_USE_RESULT bool has_loop_condition() const { + return loop_condition_ != nullptr; + } + + ABSL_MUST_USE_RESULT const Expr& loop_condition() const + ABSL_ATTRIBUTE_LIFETIME_BOUND; + + Expr& mutable_loop_condition() ABSL_ATTRIBUTE_LIFETIME_BOUND; + + void set_loop_condition(Expr loop_condition); + + void set_loop_condition(std::unique_ptr loop_condition); + + ABSL_MUST_USE_RESULT std::unique_ptr release_loop_condition(); + + ABSL_MUST_USE_RESULT bool has_loop_step() const { + return loop_step_ != nullptr; + } + + ABSL_MUST_USE_RESULT const Expr& loop_step() const + ABSL_ATTRIBUTE_LIFETIME_BOUND; + + Expr& mutable_loop_step() ABSL_ATTRIBUTE_LIFETIME_BOUND; + + void set_loop_step(Expr loop_step); + + void set_loop_step(std::unique_ptr loop_step); + + ABSL_MUST_USE_RESULT std::unique_ptr release_loop_step(); + + ABSL_MUST_USE_RESULT bool has_result() const { return result_ != nullptr; } + + ABSL_MUST_USE_RESULT const Expr& result() const ABSL_ATTRIBUTE_LIFETIME_BOUND; + + Expr& mutable_result() ABSL_ATTRIBUTE_LIFETIME_BOUND; + + void set_result(Expr result); + + void set_result(std::unique_ptr result); + + ABSL_MUST_USE_RESULT std::unique_ptr release_result(); + + friend void swap(ComprehensionExpr& lhs, ComprehensionExpr& rhs) noexcept { + using std::swap; + swap(lhs.iter_var_, rhs.iter_var_); + swap(lhs.iter_var2_, rhs.iter_var2_); + swap(lhs.iter_range_, rhs.iter_range_); + swap(lhs.accu_var_, rhs.accu_var_); + swap(lhs.accu_init_, rhs.accu_init_); + swap(lhs.loop_condition_, rhs.loop_condition_); + swap(lhs.loop_step_, rhs.loop_step_); + swap(lhs.result_, rhs.result_); + } + + private: + friend class Expr; + + static const ComprehensionExpr& default_instance(); + + static std::string release(std::string& property) { + std::string result; + result.swap(property); + return result; + } + + static std::unique_ptr release(std::unique_ptr& property); + + std::string iter_var_; + std::string iter_var2_; + std::unique_ptr iter_range_; + std::string accu_var_; + std::unique_ptr accu_init_; + std::unique_ptr loop_condition_; + std::unique_ptr loop_step_; + std::unique_ptr result_; +}; + +inline bool operator==(const ComprehensionExpr& lhs, + const ComprehensionExpr& rhs) { + return lhs.iter_var() == rhs.iter_var() && + lhs.iter_range() == rhs.iter_range() && + lhs.accu_var() == rhs.accu_var() && + lhs.accu_init() == rhs.accu_init() && + lhs.loop_condition() == rhs.loop_condition() && + lhs.loop_step() == rhs.loop_step() && lhs.result() == rhs.result(); +} + +inline bool operator!=(const ComprehensionExpr& lhs, + const ComprehensionExpr& rhs) { + return !operator==(lhs, rhs); +} + +using ExprKind = + absl::variant; + +enum class ExprKindCase { + kUnspecifiedExpr, + kConstant, + kIdentExpr, + kSelectExpr, + kCallExpr, + kListExpr, + kStructExpr, + kMapExpr, + kComprehensionExpr, +}; + +namespace common_internal { +struct ExprEraseTag; +} // namespace common_internal + +// `Expr` is a node in the Common Expression Language's abstract syntax tree. It +// is composed of a numeric ID and a kind variant. +class Expr final { + public: + Expr() = default; + Expr(Expr&&) = default; + Expr& operator=(Expr&&); + + Expr(const Expr&); + Expr& operator=(const Expr&); + + void Clear(); + + ABSL_MUST_USE_RESULT ExprId id() const { return u_.id; } + + void set_id(ExprId id) { u_.id = id; } + + ABSL_MUST_USE_RESULT const ExprKind& kind() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return kind_; + } + + ABSL_MUST_USE_RESULT ExprKind& mutable_kind() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return kind_; + } + + void set_kind(ExprKind kind); + + ABSL_MUST_USE_RESULT ExprKind release_kind(); + + ABSL_MUST_USE_RESULT bool has_const_expr() const { + return absl::holds_alternative(kind()); + } + + ABSL_MUST_USE_RESULT const Constant& const_expr() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return get_kind(); + } + + Constant& mutable_const_expr() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return try_emplace_kind(); + } + + void set_const_expr(Constant const_expr) { + try_emplace_kind() = std::move(const_expr); + } + + ABSL_MUST_USE_RESULT Constant release_const_expr() { + return release_kind(); + } + + ABSL_MUST_USE_RESULT bool has_ident_expr() const { + return absl::holds_alternative(kind()); + } + + ABSL_MUST_USE_RESULT const IdentExpr& ident_expr() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return get_kind(); + } + + IdentExpr& mutable_ident_expr() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return try_emplace_kind(); + } + + void set_ident_expr(IdentExpr ident_expr) { + try_emplace_kind() = std::move(ident_expr); + } + + ABSL_MUST_USE_RESULT IdentExpr release_ident_expr() { + return release_kind(); + } + + ABSL_MUST_USE_RESULT bool has_select_expr() const { + return absl::holds_alternative(kind()); + } + + ABSL_MUST_USE_RESULT const SelectExpr& select_expr() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return get_kind(); + } + + SelectExpr& mutable_select_expr() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return try_emplace_kind(); + } + + void set_select_expr(SelectExpr select_expr) { + try_emplace_kind() = std::move(select_expr); + } + + ABSL_MUST_USE_RESULT SelectExpr release_select_expr() { + return release_kind(); + } + + ABSL_MUST_USE_RESULT bool has_call_expr() const { + return absl::holds_alternative(kind()); + } + + ABSL_MUST_USE_RESULT const CallExpr& call_expr() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return get_kind(); + } + + CallExpr& mutable_call_expr() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return try_emplace_kind(); + } + + void set_call_expr(CallExpr call_expr); + + ABSL_MUST_USE_RESULT CallExpr release_call_expr(); + + ABSL_MUST_USE_RESULT bool has_list_expr() const { + return absl::holds_alternative(kind()); + } + + ABSL_MUST_USE_RESULT const ListExpr& list_expr() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return get_kind(); + } + + ListExpr& mutable_list_expr() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return try_emplace_kind(); + } + + void set_list_expr(ListExpr list_expr); + + ABSL_MUST_USE_RESULT ListExpr release_list_expr(); + + ABSL_MUST_USE_RESULT bool has_struct_expr() const { + return absl::holds_alternative(kind()); + } + + ABSL_MUST_USE_RESULT const StructExpr& struct_expr() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return get_kind(); + } + + StructExpr& mutable_struct_expr() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return try_emplace_kind(); + } + + void set_struct_expr(StructExpr struct_expr); + + ABSL_MUST_USE_RESULT StructExpr release_struct_expr(); + + ABSL_MUST_USE_RESULT bool has_map_expr() const { + return absl::holds_alternative(kind()); + } + + ABSL_MUST_USE_RESULT const MapExpr& map_expr() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return get_kind(); + } + + MapExpr& mutable_map_expr() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return try_emplace_kind(); + } + + void set_map_expr(MapExpr map_expr); + + ABSL_MUST_USE_RESULT MapExpr release_map_expr(); + + ABSL_MUST_USE_RESULT bool has_comprehension_expr() const { + return absl::holds_alternative(kind()); + } + + ABSL_MUST_USE_RESULT const ComprehensionExpr& comprehension_expr() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return get_kind(); + } + + ComprehensionExpr& mutable_comprehension_expr() + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return try_emplace_kind(); + } + + void set_comprehension_expr(ComprehensionExpr comprehension_expr) { + try_emplace_kind() = std::move(comprehension_expr); + } + + ABSL_MUST_USE_RESULT ComprehensionExpr release_comprehension_expr() { + return release_kind(); + } + + ExprKindCase kind_case() const; + + friend void swap(Expr& lhs, Expr& rhs) noexcept; + + // Erases the expr in place without recursion. + void FlattenedErase(); + + inline void SetNext(common_internal::ExprEraseTag&, Expr* next); + + private: + friend class IdentExpr; + friend class SelectExpr; + friend class CallExpr; + friend class ListExpr; + friend class StructExpr; + friend class MapExpr; + friend class ComprehensionExpr; + friend class ListExprElement; + friend class StructExprField; + friend class MapExprEntry; + + static const Expr& default_instance(); + + template + ABSL_MUST_USE_RESULT T& try_emplace_kind(Args&&... args) + ABSL_ATTRIBUTE_LIFETIME_BOUND { + if (auto* alt = absl::get_if(&mutable_kind()); alt) { + return *alt; + } + return kind_.emplace(std::forward(args)...); + } + + template + ABSL_MUST_USE_RESULT const T& get_kind() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + if (const auto* alt = absl::get_if(&kind()); alt) { + return *alt; + } + return T::default_instance(); + } + + template + ABSL_MUST_USE_RESULT T release_kind(); + + union { + ExprId id = 0; + // Intrusive pointer to the next element in the destructor chain. + // Only set from FlattenedErase. + Expr* next; + } u_; + ExprKind kind_; +}; + +inline bool operator==(const Expr& lhs, const Expr& rhs) { + return lhs.id() == rhs.id() && lhs.kind() == rhs.kind(); +} + +inline bool operator==(const CallExpr& lhs, const CallExpr& rhs) { + return lhs.function() == rhs.function() && lhs.target() == rhs.target() && + absl::c_equal(lhs.args(), rhs.args()); +} + +inline void SelectExpr::Clear() { + operand_.reset(); + field_.clear(); + test_only_ = false; +} + +ABSL_MUST_USE_RESULT inline std::unique_ptr +SelectExpr::release_operand() { + return release(operand_); +} + +inline const Expr& SelectExpr::operand() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return has_operand() ? *operand_ : Expr::default_instance(); +} + +inline Expr& SelectExpr::mutable_operand() ABSL_ATTRIBUTE_LIFETIME_BOUND { + if (!has_operand()) { + operand_ = std::make_unique(); + } + return *operand_; +} + +inline void SelectExpr::set_operand(Expr operand) { + mutable_operand() = std::move(operand); +} + +inline void SelectExpr::set_operand(std::unique_ptr operand) { + operand_ = std::move(operand); +} + +inline std::unique_ptr SelectExpr::release( + std::unique_ptr& property) { + std::unique_ptr result; + result.swap(property); + return result; +} + +inline void ComprehensionExpr::Clear() { + iter_var_.clear(); + iter_range_.reset(); + accu_var_.clear(); + accu_init_.reset(); + loop_condition_.reset(); + loop_step_.reset(); + result_.reset(); +} + +inline const Expr& ComprehensionExpr::iter_range() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return has_iter_range() ? *iter_range_ : Expr::default_instance(); +} + +inline Expr& ComprehensionExpr::mutable_iter_range() + ABSL_ATTRIBUTE_LIFETIME_BOUND { + if (!has_iter_range()) { + iter_range_ = std::make_unique(); + } + return *iter_range_; +} + +inline void ComprehensionExpr::set_iter_range(Expr iter_range) { + mutable_iter_range() = std::move(iter_range); +} + +inline void ComprehensionExpr::set_iter_range( + std::unique_ptr iter_range) { + iter_range_ = std::move(iter_range); +} + +ABSL_MUST_USE_RESULT inline std::unique_ptr +ComprehensionExpr::release_iter_range() { + return release(iter_range_); +} + +inline const Expr& ComprehensionExpr::accu_init() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return has_accu_init() ? *accu_init_ : Expr::default_instance(); +} + +ABSL_MUST_USE_RESULT inline std::unique_ptr +ComprehensionExpr::release_accu_init() { + return release(accu_init_); +} + +inline Expr& ComprehensionExpr::mutable_accu_init() + ABSL_ATTRIBUTE_LIFETIME_BOUND { + if (!has_accu_init()) { + accu_init_ = std::make_unique(); + } + return *accu_init_; +} + +inline void ComprehensionExpr::set_accu_init(Expr accu_init) { + mutable_accu_init() = std::move(accu_init); +} + +inline void ComprehensionExpr::set_accu_init(std::unique_ptr accu_init) { + accu_init_ = std::move(accu_init); +} + +inline const Expr& ComprehensionExpr::loop_step() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return has_loop_step() ? *loop_step_ : Expr::default_instance(); +} + +inline Expr& ComprehensionExpr::mutable_loop_step() + ABSL_ATTRIBUTE_LIFETIME_BOUND { + if (!has_loop_step()) { + loop_step_ = std::make_unique(); + } + return *loop_step_; +} + +inline void ComprehensionExpr::set_loop_step(Expr loop_step) { + mutable_loop_step() = std::move(loop_step); +} + +inline void ComprehensionExpr::set_loop_step(std::unique_ptr loop_step) { + loop_step_ = std::move(loop_step); +} + +ABSL_MUST_USE_RESULT inline std::unique_ptr +ComprehensionExpr::release_loop_step() { + return release(loop_step_); +} + +inline const Expr& ComprehensionExpr::loop_condition() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return has_loop_condition() ? *loop_condition_ : Expr::default_instance(); +} + +ABSL_MUST_USE_RESULT inline std::unique_ptr +ComprehensionExpr::release_loop_condition() { + return release(loop_condition_); +} + +inline Expr& ComprehensionExpr::mutable_loop_condition() + ABSL_ATTRIBUTE_LIFETIME_BOUND { + if (!has_loop_condition()) { + loop_condition_ = std::make_unique(); + } + return *loop_condition_; +} + +inline void ComprehensionExpr::set_loop_condition(Expr loop_condition) { + mutable_loop_condition() = std::move(loop_condition); +} + +inline void ComprehensionExpr::set_loop_condition( + std::unique_ptr loop_condition) { + loop_condition_ = std::move(loop_condition); +} + +inline const Expr& ComprehensionExpr::result() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return has_result() ? *result_ : Expr::default_instance(); +} + +inline Expr& ComprehensionExpr::mutable_result() ABSL_ATTRIBUTE_LIFETIME_BOUND { + if (!has_result()) { + result_ = std::make_unique(); + } + return *result_; +} + +inline void ComprehensionExpr::set_result(Expr result) { + mutable_result() = std::move(result); +} + +inline void ComprehensionExpr::set_result(std::unique_ptr result) { + result_ = std::move(result); +} + +ABSL_MUST_USE_RESULT inline std::unique_ptr +ComprehensionExpr::release_result() { + return release(result_); +} + +inline std::unique_ptr ComprehensionExpr::release( + std::unique_ptr& property) { + std::unique_ptr result; + result.swap(property); + return result; +} + +inline bool operator==(const ListExprElement& lhs, const ListExprElement& rhs) { + return lhs.expr() == rhs.expr() && lhs.optional() == rhs.optional(); +} + +inline bool operator==(const ListExpr& lhs, const ListExpr& rhs) { + return absl::c_equal(lhs.elements(), rhs.elements()); +} + +inline bool operator==(const StructExprField& lhs, const StructExprField& rhs) { + return lhs.id() == rhs.id() && lhs.name() == rhs.name() && + lhs.value() == rhs.value() && lhs.optional() == rhs.optional(); +} + +inline bool operator==(const StructExpr& lhs, const StructExpr& rhs) { + return lhs.name() == rhs.name() && absl::c_equal(lhs.fields(), rhs.fields()); +} + +inline bool operator==(const MapExprEntry& lhs, const MapExprEntry& rhs) { + return lhs.id() == rhs.id() && lhs.key() == rhs.key() && + lhs.value() == rhs.value() && lhs.optional() == rhs.optional(); +} + +inline bool operator==(const MapExpr& lhs, const MapExpr& rhs) { + return absl::c_equal(lhs.entries(), rhs.entries()); +} + +inline void MapExpr::Clear() { entries_.clear(); } + +inline void MapExpr::set_entries(std::vector entries) { + entries_ = std::move(entries); +} + +inline void MapExpr::set_entries(absl::Span entries) { + entries_.clear(); + entries_.reserve(entries.size()); + for (auto& entry : entries) { + entries_.push_back(std::move(entry)); + } +} + +inline MapExprEntry& MapExpr::add_entries() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return mutable_entries().emplace_back(); +} + +inline std::vector MapExpr::release_entries() { + std::vector entries; + entries.swap(entries_); + return entries; +} + +inline void Expr::Clear() { + u_.id = 0; + mutable_kind().emplace(); +} + +inline Expr& Expr::operator=(Expr&&) = default; + +inline void Expr::set_kind(ExprKind kind) { kind_ = std::move(kind); } + +inline ABSL_MUST_USE_RESULT ExprKind Expr::release_kind() { + ExprKind kind = std::move(kind_); + kind_.emplace(); + return kind; +} + +inline void Expr::set_call_expr(CallExpr call_expr) { + try_emplace_kind() = std::move(call_expr); +} + +inline ABSL_MUST_USE_RESULT CallExpr Expr::release_call_expr() { + return release_kind(); +} + +inline void Expr::set_list_expr(ListExpr list_expr) { + try_emplace_kind() = std::move(list_expr); +} + +inline ListExpr Expr::release_list_expr() { return release_kind(); } + +inline void Expr::set_struct_expr(StructExpr struct_expr) { + try_emplace_kind() = std::move(struct_expr); +} + +inline StructExpr Expr::release_struct_expr() { + return release_kind(); +} + +inline void Expr::set_map_expr(MapExpr map_expr) { + try_emplace_kind() = std::move(map_expr); +} + +inline MapExpr Expr::release_map_expr() { return release_kind(); } + +template +ABSL_MUST_USE_RESULT T Expr::release_kind() { + T result; + if (auto* alt = absl::get_if(&mutable_kind()); alt) { + result = std::move(*alt); + } + kind_.emplace(); + return result; +} + +inline ExprKindCase Expr::kind_case() const { + static_assert(absl::variant_size_v == 9); + if (kind_.index() <= 9) { + return static_cast(kind_.index()); + } + return ExprKindCase::kUnspecifiedExpr; +} + +inline void swap(Expr& lhs, Expr& rhs) noexcept { + using std::swap; + swap(lhs.u_, rhs.u_); + swap(lhs.kind_, rhs.kind_); +} + +inline void CallExpr::Clear() { + function_.clear(); + target_.reset(); + args_.clear(); +} + +inline const Expr& CallExpr::target() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return has_target() ? *target_ : Expr::default_instance(); +} + +inline Expr& CallExpr::mutable_target() ABSL_ATTRIBUTE_LIFETIME_BOUND { + if (!has_target()) { + target_ = std::make_unique(); + } + return *target_; +} + +inline void CallExpr::set_target(Expr target) { + mutable_target() = std::move(target); +} + +inline void CallExpr::set_target(std::unique_ptr target) { + target_ = std::move(target); +} + +ABSL_MUST_USE_RESULT inline std::unique_ptr CallExpr::release_target() { + return release(target_); +} + +inline void CallExpr::set_args(std::vector args) { + args_ = std::move(args); +} + +inline void CallExpr::set_args(absl::Span args) { + args_.clear(); + args_.reserve(args.size()); + for (auto& arg : args) { + args_.push_back(std::move(arg)); + } +} + +inline Expr& CallExpr::add_args() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return mutable_args().emplace_back(); +} + +inline std::vector CallExpr::release_args() { + std::vector args; + args.swap(args_); + return args; +} + +inline std::unique_ptr CallExpr::release( + std::unique_ptr& property) { + std::unique_ptr result; + result.swap(property); + return result; +} + +inline void ListExprElement::Clear() { + expr_.reset(); + optional_ = false; +} + +inline ABSL_MUST_USE_RESULT const Expr& ListExprElement::expr() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return has_expr() ? *expr_ : Expr::default_instance(); +} + +inline ABSL_MUST_USE_RESULT Expr& ListExprElement::mutable_expr() + ABSL_ATTRIBUTE_LIFETIME_BOUND { + if (!has_expr()) { + expr_ = std::make_unique(); + } + return *expr_; +} + +inline void ListExprElement::set_expr(Expr expr) { + mutable_expr() = std::move(expr); +} + +inline void ListExprElement::set_expr(std::unique_ptr expr) { + expr_ = std::move(expr); +} + +inline ABSL_MUST_USE_RESULT Expr ListExprElement::release_expr() { + return release(expr_); +} + +inline void swap(ListExprElement& lhs, ListExprElement& rhs) noexcept { + using std::swap; + swap(lhs.expr_, rhs.expr_); + swap(lhs.optional_, rhs.optional_); +} + +inline Expr ListExprElement::release(std::unique_ptr& property) { + std::unique_ptr result; + result.swap(property); + if (result != nullptr) { + return std::move(*result); + } + return Expr{}; +} + +inline void ListExpr::Clear() { elements_.clear(); } + +inline void ListExpr::set_elements(std::vector elements) { + elements_ = std::move(elements); +} + +inline void ListExpr::set_elements(absl::Span elements) { + elements_.clear(); + elements_.reserve(elements.size()); + for (auto& element : elements) { + elements_.push_back(std::move(element)); + } +} + +inline ListExprElement& ListExpr::add_elements() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return mutable_elements().emplace_back(); +} + +inline std::vector ListExpr::release_elements() { + std::vector elements; + elements.swap(elements_); + return elements; +} + +inline void StructExprField::Clear() { + id_ = 0; + name_.clear(); + value_.reset(); + optional_ = false; +} + +inline ABSL_MUST_USE_RESULT const Expr& StructExprField::value() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return has_value() ? *value_ : Expr::default_instance(); +} + +inline ABSL_MUST_USE_RESULT Expr& StructExprField::mutable_value() + ABSL_ATTRIBUTE_LIFETIME_BOUND { + if (!has_value()) { + value_ = std::make_unique(); + } + return *value_; +} + +inline void StructExprField::set_value(Expr value) { + mutable_value() = std::move(value); +} + +inline void StructExprField::set_value(std::unique_ptr value) { + value_ = std::move(value); +} + +inline ABSL_MUST_USE_RESULT Expr StructExprField::release_value() { + return release(value_); +} + +inline void swap(StructExprField& lhs, StructExprField& rhs) noexcept { + using std::swap; + swap(lhs.id_, rhs.id_); + swap(lhs.name_, rhs.name_); + swap(lhs.value_, rhs.value_); + swap(lhs.optional_, rhs.optional_); +} + +inline Expr StructExprField::release(std::unique_ptr& property) { + std::unique_ptr result; + result.swap(property); + if (result != nullptr) { + return std::move(*result); + } + return Expr{}; +} + +inline void StructExpr::Clear() { + name_.clear(); + fields_.clear(); +} + +inline void StructExpr::set_fields(std::vector fields) { + fields_ = std::move(fields); +} + +inline void StructExpr::set_fields(absl::Span fields) { + fields_.clear(); + fields_.reserve(fields.size()); + for (auto& field : fields) { + fields_.push_back(std::move(field)); + } +} + +inline StructExprField& StructExpr::add_fields() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return mutable_fields().emplace_back(); +} + +inline std::vector StructExpr::release_fields() { + std::vector fields; + fields.swap(fields_); + return fields; +} + +inline void MapExprEntry::Clear() { + id_ = 0; + key_.reset(); + value_.reset(); + optional_ = false; +} + +inline ABSL_MUST_USE_RESULT const Expr& MapExprEntry::key() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return has_key() ? *key_ : Expr::default_instance(); +} + +inline ABSL_MUST_USE_RESULT Expr& MapExprEntry::mutable_key() + ABSL_ATTRIBUTE_LIFETIME_BOUND { + if (!has_key()) { + key_ = std::make_unique(); + } + return *key_; +} + +inline void MapExprEntry::set_key(Expr key) { mutable_key() = std::move(key); } + +inline void MapExprEntry::set_key(std::unique_ptr key) { + key_ = std::move(key); +} + +inline ABSL_MUST_USE_RESULT Expr MapExprEntry::release_key() { + return release(key_); +} + +inline ABSL_MUST_USE_RESULT const Expr& MapExprEntry::value() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return has_value() ? *value_ : Expr::default_instance(); +} + +inline ABSL_MUST_USE_RESULT Expr& MapExprEntry::mutable_value() + ABSL_ATTRIBUTE_LIFETIME_BOUND { + if (!has_value()) { + value_ = std::make_unique(); + } + return *value_; +} + +inline void MapExprEntry::set_value(Expr value) { + mutable_value() = std::move(value); +} + +inline void MapExprEntry::set_value(std::unique_ptr value) { + value_ = std::move(value); +} + +inline ABSL_MUST_USE_RESULT Expr MapExprEntry::release_value() { + return release(value_); +} + +inline void swap(MapExprEntry& lhs, MapExprEntry& rhs) noexcept { + using std::swap; + swap(lhs.id_, rhs.id_); + swap(lhs.key_, rhs.key_); + swap(lhs.value_, rhs.value_); + swap(lhs.optional_, rhs.optional_); +} + +inline Expr MapExprEntry::release(std::unique_ptr& property) { + std::unique_ptr result; + result.swap(property); + if (result != nullptr) { + return std::move(*result); + } + return Expr{}; +} + +inline void Expr::SetNext(common_internal::ExprEraseTag&, Expr* next) { + u_.next = next; +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_EXPR_H_ diff --git a/common/expr_factory.h b/common/expr_factory.h new file mode 100644 index 000000000..757318545 --- /dev/null +++ b/common/expr_factory.h @@ -0,0 +1,396 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_EXPR_FACTORY_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_EXPR_FACTORY_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "common/constant.h" +#include "common/expr.h" + +namespace cel { + +class MacroExprFactory; +class ParserMacroExprFactory; +class OptimizerExprFactory; + +namespace tools { +class ProtoToPredicateBuilder; +} + +class ExprFactory { + protected: + // `IsExprLike` determines whether `T` is some `Expr`. Currently that means + // either `Expr` or `std::unique_ptr`. This allows us to make the + // factory functions generic and avoid redefining them for every argument + // combination. + template + struct IsExprLike + : std::bool_constant, std::is_same>>> {}; + + // `IsStringLike` determines whether `T` is something that looks like a + // string. Currently that means `const char*`, `std::string`, or + // `absl::string_view`. This allows us to make the factory functions generic + // and avoid redefining them for every argument combination. This is necessary + // to avoid copies if possible. + template + struct IsStringLike + : std::bool_constant, std::is_same, + std::is_same, std::is_same>> { + }; + + template + struct IsStringLike : std::true_type {}; + + // `IsArrayLike` determines whether `T` is something that looks like an array + // or span of some element. + template + struct IsArrayLike : std::false_type {}; + + template + struct IsArrayLike> : std::true_type {}; + + template + struct IsArrayLike> : std::true_type {}; + + public: + ExprFactory(const ExprFactory&) = delete; + ExprFactory(ExprFactory&&) = delete; + ExprFactory& operator=(const ExprFactory&) = delete; + ExprFactory& operator=(ExprFactory&&) = delete; + + virtual ~ExprFactory() = default; + + Expr NewUnspecified(ExprId id) { + Expr expr; + expr.set_id(id); + return expr; + } + + Expr NewConst(ExprId id, Constant value) { + Expr expr; + expr.set_id(id); + expr.mutable_const_expr() = std::move(value); + return expr; + } + + Expr NewNullConst(ExprId id) { + Constant constant; + constant.set_null_value(); + return NewConst(id, std::move(constant)); + } + + Expr NewBoolConst(ExprId id, bool value) { + Constant constant; + constant.set_bool_value(value); + return NewConst(id, std::move(constant)); + } + + Expr NewIntConst(ExprId id, int64_t value) { + Constant constant; + constant.set_int_value(value); + return NewConst(id, std::move(constant)); + } + + Expr NewUintConst(ExprId id, uint64_t value) { + Constant constant; + constant.set_uint_value(value); + return NewConst(id, std::move(constant)); + } + + Expr NewDoubleConst(ExprId id, double value) { + Constant constant; + constant.set_double_value(value); + return NewConst(id, std::move(constant)); + } + + Expr NewBytesConst(ExprId id, BytesConstant value) { + Constant constant; + constant.set_bytes_value(std::move(value)); + return NewConst(id, std::move(constant)); + } + + Expr NewBytesConst(ExprId id, std::string value) { + Constant constant; + constant.set_bytes_value(std::move(value)); + return NewConst(id, std::move(constant)); + } + + Expr NewBytesConst(ExprId id, absl::string_view value) { + Constant constant; + constant.set_bytes_value(value); + return NewConst(id, std::move(constant)); + } + + Expr NewBytesConst(ExprId id, const char* value) { + Constant constant; + constant.set_bytes_value(value); + return NewConst(id, std::move(constant)); + } + + Expr NewStringConst(ExprId id, StringConstant value) { + Constant constant; + constant.set_string_value(std::move(value)); + return NewConst(id, std::move(constant)); + } + + Expr NewStringConst(ExprId id, std::string value) { + Constant constant; + constant.set_string_value(std::move(value)); + return NewConst(id, std::move(constant)); + } + + Expr NewStringConst(ExprId id, absl::string_view value) { + Constant constant; + constant.set_string_value(value); + return NewConst(id, std::move(constant)); + } + + Expr NewStringConst(ExprId id, const char* value) { + Constant constant; + constant.set_string_value(value); + return NewConst(id, std::move(constant)); + } + + template ::value>> + Expr NewIdent(ExprId id, Name name) { + Expr expr; + expr.set_id(id); + auto& ident_expr = expr.mutable_ident_expr(); + ident_expr.set_name(std::move(name)); + return expr; + } + + absl::string_view AccuVarName() { return accu_var_; } + + Expr NewAccuIdent(ExprId id) { return NewIdent(id, AccuVarName()); } + + template ::value>, + typename = std::enable_if_t::value>> + Expr NewSelect(ExprId id, Operand operand, Field field) { + Expr expr; + expr.set_id(id); + auto& select_expr = expr.mutable_select_expr(); + select_expr.set_operand(std::move(operand)); + select_expr.set_field(std::move(field)); + select_expr.set_test_only(false); + return expr; + } + + template ::value>, + typename = std::enable_if_t::value>> + Expr NewPresenceTest(ExprId id, Operand operand, Field field) { + Expr expr; + expr.set_id(id); + auto& select_expr = expr.mutable_select_expr(); + select_expr.set_operand(std::move(operand)); + select_expr.set_field(std::move(field)); + select_expr.set_test_only(true); + return expr; + } + + template ::value>, + typename = std::enable_if_t::value>> + Expr NewCall(ExprId id, Function function, Args args) { + Expr expr; + expr.set_id(id); + auto& call_expr = expr.mutable_call_expr(); + call_expr.set_function(std::move(function)); + call_expr.set_args(std::move(args)); + return expr; + } + + template ::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>> + Expr NewMemberCall(ExprId id, Function function, Target target, Args args) { + Expr expr; + expr.set_id(id); + auto& call_expr = expr.mutable_call_expr(); + call_expr.set_function(std::move(function)); + call_expr.set_target(std::move(target)); + call_expr.set_args(std::move(args)); + return expr; + } + + template ::value>> + ListExprElement NewListElement(Expr expr, bool optional = false) { + ListExprElement element; + element.set_expr(std::move(expr)); + element.set_optional(optional); + return element; + } + + template ::value>> + Expr NewList(ExprId id, Elements elements) { + Expr expr; + expr.set_id(id); + auto& list_expr = expr.mutable_list_expr(); + list_expr.set_elements(std::move(elements)); + return expr; + } + + template ::value>, + typename = std::enable_if_t::value>> + StructExprField NewStructField(ExprId id, Name name, Value value, + bool optional = false) { + StructExprField field; + field.set_id(id); + field.set_name(std::move(name)); + field.set_value(std::move(value)); + field.set_optional(optional); + return field; + } + + template < + typename Name, typename Fields, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>> + Expr NewStruct(ExprId id, Name name, Fields fields) { + Expr expr; + expr.set_id(id); + auto& struct_expr = expr.mutable_struct_expr(); + struct_expr.set_name(std::move(name)); + struct_expr.set_fields(std::move(fields)); + return expr; + } + + template ::value>, + typename = std::enable_if_t::value>> + MapExprEntry NewMapEntry(ExprId id, Key key, Value value, + bool optional = false) { + MapExprEntry entry; + entry.set_id(id); + entry.set_key(std::move(key)); + entry.set_value(std::move(value)); + entry.set_optional(optional); + return entry; + } + + template ::value>> + Expr NewMap(ExprId id, Entries entries) { + Expr expr; + expr.set_id(id); + auto& map_expr = expr.mutable_map_expr(); + map_expr.set_entries(std::move(entries)); + return expr; + } + + template ::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>> + Expr NewComprehension(ExprId id, IterVar iter_var, IterRange iter_range, + AccuVar accu_var, AccuInit accu_init, + LoopCondition loop_condition, LoopStep loop_step, + Result result) { + return NewComprehension(id, std::move(iter_var), "", std::move(iter_range), + std::move(accu_var), std::move(accu_init), + std::move(loop_condition), std::move(loop_step), + std::move(result)); + } + + template ::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>> + Expr NewComprehension(ExprId id, IterVar iter_var, IterVar2 iter_var2, + IterRange iter_range, AccuVar accu_var, + AccuInit accu_init, LoopCondition loop_condition, + LoopStep loop_step, Result result) { + Expr expr; + expr.set_id(id); + auto& comprehension_expr = expr.mutable_comprehension_expr(); + comprehension_expr.set_iter_var(std::move(iter_var)); + comprehension_expr.set_iter_var2(std::move(iter_var2)); + comprehension_expr.set_iter_range(std::move(iter_range)); + comprehension_expr.set_accu_var(std::move(accu_var)); + comprehension_expr.set_accu_init(std::move(accu_init)); + comprehension_expr.set_loop_condition(std::move(loop_condition)); + comprehension_expr.set_loop_step(std::move(loop_step)); + comprehension_expr.set_result(std::move(result)); + return expr; + } + + template ::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>> + Expr NewBind(NextIdFunc next_id, BindVar bind_var, BindExpr bind_expr, + RestExpr rest_expr) { + Expr expr; + expr.set_id(next_id()); + auto& comprehension_expr = expr.mutable_comprehension_expr(); + comprehension_expr.set_iter_var("#unused"); + comprehension_expr.set_iter_range( + NewList(next_id(), std::vector{})); + comprehension_expr.set_accu_var(bind_var); + comprehension_expr.set_accu_init(std::move(bind_expr)); + comprehension_expr.set_loop_condition(NewBoolConst(next_id(), false)); + comprehension_expr.set_loop_step(NewIdent(next_id(), bind_var)); + comprehension_expr.set_result(std::move(rest_expr)); + return expr; + } + + private: + friend class MacroExprFactory; + friend class ParserMacroExprFactory; + friend class OptimizerExprFactory; + friend class tools::ProtoToPredicateBuilder; + + ExprFactory() : accu_var_(kAccumulatorVariableName) {} + + std::string accu_var_; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_EXPR_FACTORY_H_ diff --git a/common/expr_test.cc b/common/expr_test.cc new file mode 100644 index 000000000..4f30dbd6f --- /dev/null +++ b/common/expr_test.cc @@ -0,0 +1,674 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/expr.h" + +#include + +#include "internal/testing.h" + +namespace cel { +namespace { + +using ::testing::_; +using ::testing::Eq; +using ::testing::IsEmpty; +using ::testing::IsFalse; +using ::testing::IsTrue; +using ::testing::SizeIs; +using ::testing::VariantWith; + +Expr MakeUnspecifiedExpr(ExprId id) { + Expr expr; + expr.set_id(id); + return expr; +} + +ListExprElement MakeListExprElement(Expr expr, bool optional = false) { + ListExprElement element; + element.set_expr(std::move(expr)); + element.set_optional(optional); + return element; +} + +StructExprField MakeStructExprField(ExprId id, const char* name, Expr value, + bool optional = false) { + StructExprField field; + field.set_id(id); + field.set_name(name); + field.set_value(std::move(value)); + field.set_optional(optional); + return field; +} + +MapExprEntry MakeMapExprEntry(ExprId id, Expr key, Expr value, + bool optional = false) { + MapExprEntry entry; + entry.set_id(id); + entry.set_key(std::move(key)); + entry.set_value(std::move(value)); + entry.set_optional(optional); + return entry; +} + +TEST(UnspecifiedExpr, Equality) { + EXPECT_EQ(UnspecifiedExpr{}, UnspecifiedExpr{}); +} + +TEST(IdentExpr, Name) { + IdentExpr ident_expr; + EXPECT_THAT(ident_expr.name(), IsEmpty()); + ident_expr.set_name("foo"); + EXPECT_THAT(ident_expr.name(), Eq("foo")); + auto name = ident_expr.release_name(); + EXPECT_THAT(name, Eq("foo")); + EXPECT_THAT(ident_expr.name(), IsEmpty()); +} + +TEST(IdentExpr, Equality) { + EXPECT_EQ(IdentExpr{}, IdentExpr{}); + IdentExpr ident_expr; + ident_expr.set_name("foo"); + EXPECT_NE(IdentExpr{}, ident_expr); +} + +TEST(SelectExpr, Operand) { + SelectExpr select_expr; + EXPECT_THAT(select_expr.has_operand(), IsFalse()); + EXPECT_EQ(select_expr.operand(), Expr{}); + select_expr.set_operand(MakeUnspecifiedExpr(1)); + EXPECT_THAT(select_expr.has_operand(), IsTrue()); + EXPECT_EQ(select_expr.operand(), MakeUnspecifiedExpr(1)); + auto operand = select_expr.release_operand(); + EXPECT_THAT(select_expr.has_operand(), IsFalse()); + EXPECT_EQ(select_expr.operand(), Expr{}); +} + +TEST(SelectExpr, Field) { + SelectExpr select_expr; + EXPECT_THAT(select_expr.field(), IsEmpty()); + select_expr.set_field("foo"); + EXPECT_THAT(select_expr.field(), Eq("foo")); + auto field = select_expr.release_field(); + EXPECT_THAT(field, Eq("foo")); + EXPECT_THAT(select_expr.field(), IsEmpty()); +} + +TEST(SelectExpr, TestOnly) { + SelectExpr select_expr; + EXPECT_THAT(select_expr.test_only(), IsFalse()); + select_expr.set_test_only(true); + EXPECT_THAT(select_expr.test_only(), IsTrue()); +} + +TEST(SelectExpr, Equality) { + EXPECT_EQ(SelectExpr{}, SelectExpr{}); + SelectExpr select_expr; + select_expr.set_test_only(true); + EXPECT_NE(SelectExpr{}, select_expr); +} + +TEST(CallExpr, Function) { + CallExpr call_expr; + EXPECT_THAT(call_expr.function(), IsEmpty()); + call_expr.set_function("foo"); + EXPECT_THAT(call_expr.function(), Eq("foo")); + auto function = call_expr.release_function(); + EXPECT_THAT(function, Eq("foo")); + EXPECT_THAT(call_expr.function(), IsEmpty()); +} + +TEST(CallExpr, Target) { + CallExpr call_expr; + EXPECT_THAT(call_expr.has_target(), IsFalse()); + EXPECT_EQ(call_expr.target(), Expr{}); + call_expr.set_target(MakeUnspecifiedExpr(1)); + EXPECT_THAT(call_expr.has_target(), IsTrue()); + EXPECT_EQ(call_expr.target(), MakeUnspecifiedExpr(1)); + auto operand = call_expr.release_target(); + EXPECT_THAT(call_expr.has_target(), IsFalse()); + EXPECT_EQ(call_expr.target(), Expr{}); +} + +TEST(CallExpr, Args) { + CallExpr call_expr; + EXPECT_THAT(call_expr.args(), IsEmpty()); + call_expr.mutable_args().push_back(MakeUnspecifiedExpr(1)); + ASSERT_THAT(call_expr.args(), SizeIs(1)); + EXPECT_EQ(call_expr.args()[0], MakeUnspecifiedExpr(1)); + auto args = call_expr.release_args(); + static_cast(args); + EXPECT_THAT(call_expr.args(), IsEmpty()); +} + +TEST(CallExpr, Equality) { + EXPECT_EQ(CallExpr{}, CallExpr{}); + CallExpr call_expr; + call_expr.mutable_args().push_back(MakeUnspecifiedExpr(1)); + EXPECT_NE(CallExpr{}, call_expr); +} + +TEST(ListExprElement, Expr) { + ListExprElement element; + EXPECT_THAT(element.has_expr(), IsFalse()); + EXPECT_EQ(element.expr(), Expr{}); + element.set_expr(MakeUnspecifiedExpr(1)); + EXPECT_THAT(element.has_expr(), IsTrue()); + EXPECT_EQ(element.expr(), MakeUnspecifiedExpr(1)); + auto operand = element.release_expr(); + EXPECT_THAT(element.has_expr(), IsFalse()); + EXPECT_EQ(element.expr(), Expr{}); +} + +TEST(ListExprElement, Optional) { + ListExprElement element; + EXPECT_THAT(element.optional(), IsFalse()); + element.set_optional(true); + EXPECT_THAT(element.optional(), IsTrue()); +} + +TEST(ListExprElement, Equality) { + EXPECT_EQ(ListExprElement{}, ListExprElement{}); + ListExprElement element; + element.set_optional(true); + EXPECT_NE(ListExprElement{}, element); +} + +TEST(ListExpr, Elements) { + ListExpr list_expr; + EXPECT_THAT(list_expr.elements(), IsEmpty()); + list_expr.mutable_elements().push_back( + MakeListExprElement(MakeUnspecifiedExpr(1))); + ASSERT_THAT(list_expr.elements(), SizeIs(1)); + EXPECT_EQ(list_expr.elements()[0], + MakeListExprElement(MakeUnspecifiedExpr(1))); + auto elements = list_expr.release_elements(); + static_cast(elements); + EXPECT_THAT(list_expr.elements(), IsEmpty()); +} + +TEST(ListExpr, Equality) { + EXPECT_EQ(ListExpr{}, ListExpr{}); + ListExpr list_expr; + list_expr.mutable_elements().push_back( + MakeListExprElement(MakeUnspecifiedExpr(0), true)); + EXPECT_NE(ListExpr{}, list_expr); +} + +TEST(StructExprField, Id) { + StructExprField field; + EXPECT_THAT(field.id(), Eq(0)); + field.set_id(1); + EXPECT_THAT(field.id(), Eq(1)); +} + +TEST(StructExprField, Name) { + StructExprField field; + EXPECT_THAT(field.name(), IsEmpty()); + field.set_name("foo"); + EXPECT_THAT(field.name(), Eq("foo")); + auto name = field.release_name(); + EXPECT_THAT(name, Eq("foo")); + EXPECT_THAT(field.name(), IsEmpty()); +} + +TEST(StructExprField, Value) { + StructExprField field; + EXPECT_THAT(field.has_value(), IsFalse()); + EXPECT_EQ(field.value(), Expr{}); + field.set_value(MakeUnspecifiedExpr(1)); + EXPECT_THAT(field.has_value(), IsTrue()); + EXPECT_EQ(field.value(), MakeUnspecifiedExpr(1)); + auto value = field.release_value(); + EXPECT_THAT(field.has_value(), IsFalse()); + EXPECT_EQ(field.value(), Expr{}); +} + +TEST(StructExprField, Optional) { + StructExprField field; + EXPECT_THAT(field.optional(), IsFalse()); + field.set_optional(true); + EXPECT_THAT(field.optional(), IsTrue()); +} + +TEST(StructExprField, Equality) { + EXPECT_EQ(StructExprField{}, StructExprField{}); + StructExprField field; + field.set_optional(true); + EXPECT_NE(StructExprField{}, field); +} + +TEST(StructExpr, Name) { + StructExpr struct_expr; + EXPECT_THAT(struct_expr.name(), IsEmpty()); + struct_expr.set_name("foo"); + EXPECT_THAT(struct_expr.name(), Eq("foo")); + auto name = struct_expr.release_name(); + EXPECT_THAT(name, Eq("foo")); + EXPECT_THAT(struct_expr.name(), IsEmpty()); +} + +TEST(StructExpr, Fields) { + StructExpr struct_expr; + EXPECT_THAT(struct_expr.fields(), IsEmpty()); + struct_expr.mutable_fields().push_back( + MakeStructExprField(1, "foo", MakeUnspecifiedExpr(1))); + ASSERT_THAT(struct_expr.fields(), SizeIs(1)); + EXPECT_EQ(struct_expr.fields()[0], + MakeStructExprField(1, "foo", MakeUnspecifiedExpr(1))); + auto fields = struct_expr.release_fields(); + static_cast(fields); + EXPECT_THAT(struct_expr.fields(), IsEmpty()); +} + +TEST(StructExpr, Equality) { + EXPECT_EQ(StructExpr{}, StructExpr{}); + StructExpr struct_expr; + struct_expr.mutable_fields().push_back( + MakeStructExprField(0, "", MakeUnspecifiedExpr(0), true)); + EXPECT_NE(StructExpr{}, struct_expr); +} + +TEST(MapExprEntry, Id) { + MapExprEntry entry; + EXPECT_THAT(entry.id(), Eq(0)); + entry.set_id(1); + EXPECT_THAT(entry.id(), Eq(1)); +} + +TEST(MapExprEntry, Key) { + MapExprEntry entry; + EXPECT_THAT(entry.has_key(), IsFalse()); + EXPECT_EQ(entry.key(), Expr{}); + entry.set_key(MakeUnspecifiedExpr(1)); + EXPECT_THAT(entry.has_key(), IsTrue()); + EXPECT_EQ(entry.key(), MakeUnspecifiedExpr(1)); + auto key = entry.release_key(); + static_cast(key); + EXPECT_THAT(entry.has_key(), IsFalse()); + EXPECT_EQ(entry.key(), Expr{}); +} + +TEST(MapExprEntry, Value) { + MapExprEntry entry; + EXPECT_THAT(entry.has_value(), IsFalse()); + EXPECT_EQ(entry.value(), Expr{}); + entry.set_value(MakeUnspecifiedExpr(1)); + EXPECT_THAT(entry.has_value(), IsTrue()); + EXPECT_EQ(entry.value(), MakeUnspecifiedExpr(1)); + auto value = entry.release_value(); + static_cast(value); + EXPECT_THAT(entry.has_value(), IsFalse()); + EXPECT_EQ(entry.value(), Expr{}); +} + +TEST(MapExprEntry, Optional) { + MapExprEntry entry; + EXPECT_THAT(entry.optional(), IsFalse()); + entry.set_optional(true); + EXPECT_THAT(entry.optional(), IsTrue()); +} + +TEST(MapExprEntry, Equality) { + EXPECT_EQ(StructExprField{}, StructExprField{}); + StructExprField field; + field.set_optional(true); + EXPECT_NE(StructExprField{}, field); +} + +TEST(MapExpr, Entries) { + MapExpr map_expr; + EXPECT_THAT(map_expr.entries(), IsEmpty()); + map_expr.mutable_entries().push_back( + MakeMapExprEntry(1, MakeUnspecifiedExpr(1), MakeUnspecifiedExpr(1))); + ASSERT_THAT(map_expr.entries(), SizeIs(1)); + EXPECT_EQ(map_expr.entries()[0], MakeMapExprEntry(1, MakeUnspecifiedExpr(1), + MakeUnspecifiedExpr(1))); + auto entries = map_expr.release_entries(); + static_cast(entries); + EXPECT_THAT(map_expr.entries(), IsEmpty()); +} + +TEST(MapExpr, Equality) { + EXPECT_EQ(MapExpr{}, MapExpr{}); + MapExpr map_expr; + map_expr.mutable_entries().push_back(MakeMapExprEntry( + 0, MakeUnspecifiedExpr(0), MakeUnspecifiedExpr(0), true)); + EXPECT_NE(MapExpr{}, map_expr); +} + +TEST(ComprehensionExpr, IterVar) { + ComprehensionExpr comprehension_expr; + EXPECT_THAT(comprehension_expr.iter_var(), IsEmpty()); + comprehension_expr.set_iter_var("foo"); + EXPECT_THAT(comprehension_expr.iter_var(), Eq("foo")); + auto iter_var = comprehension_expr.release_iter_var(); + EXPECT_THAT(iter_var, Eq("foo")); + EXPECT_THAT(comprehension_expr.iter_var(), IsEmpty()); +} + +TEST(ComprehensionExpr, IterRange) { + ComprehensionExpr comprehension_expr; + EXPECT_THAT(comprehension_expr.has_iter_range(), IsFalse()); + EXPECT_EQ(comprehension_expr.iter_range(), Expr{}); + comprehension_expr.set_iter_range(MakeUnspecifiedExpr(1)); + EXPECT_THAT(comprehension_expr.has_iter_range(), IsTrue()); + EXPECT_EQ(comprehension_expr.iter_range(), MakeUnspecifiedExpr(1)); + auto operand = comprehension_expr.release_iter_range(); + EXPECT_THAT(comprehension_expr.has_iter_range(), IsFalse()); + EXPECT_EQ(comprehension_expr.iter_range(), Expr{}); +} + +TEST(ComprehensionExpr, AccuVar) { + ComprehensionExpr comprehension_expr; + EXPECT_THAT(comprehension_expr.accu_var(), IsEmpty()); + comprehension_expr.set_accu_var("foo"); + EXPECT_THAT(comprehension_expr.accu_var(), Eq("foo")); + auto accu_var = comprehension_expr.release_accu_var(); + EXPECT_THAT(accu_var, Eq("foo")); + EXPECT_THAT(comprehension_expr.accu_var(), IsEmpty()); +} + +TEST(ComprehensionExpr, AccuInit) { + ComprehensionExpr comprehension_expr; + EXPECT_THAT(comprehension_expr.has_accu_init(), IsFalse()); + EXPECT_EQ(comprehension_expr.accu_init(), Expr{}); + comprehension_expr.set_accu_init(MakeUnspecifiedExpr(1)); + EXPECT_THAT(comprehension_expr.has_accu_init(), IsTrue()); + EXPECT_EQ(comprehension_expr.accu_init(), MakeUnspecifiedExpr(1)); + auto operand = comprehension_expr.release_accu_init(); + EXPECT_THAT(comprehension_expr.has_accu_init(), IsFalse()); + EXPECT_EQ(comprehension_expr.accu_init(), Expr{}); +} + +TEST(ComprehensionExpr, LoopCondition) { + ComprehensionExpr comprehension_expr; + EXPECT_THAT(comprehension_expr.has_loop_condition(), IsFalse()); + EXPECT_EQ(comprehension_expr.loop_condition(), Expr{}); + comprehension_expr.set_loop_condition(MakeUnspecifiedExpr(1)); + EXPECT_THAT(comprehension_expr.has_loop_condition(), IsTrue()); + EXPECT_EQ(comprehension_expr.loop_condition(), MakeUnspecifiedExpr(1)); + auto operand = comprehension_expr.release_loop_condition(); + EXPECT_THAT(comprehension_expr.has_loop_condition(), IsFalse()); + EXPECT_EQ(comprehension_expr.loop_condition(), Expr{}); +} + +TEST(ComprehensionExpr, LoopStep) { + ComprehensionExpr comprehension_expr; + EXPECT_THAT(comprehension_expr.has_loop_step(), IsFalse()); + EXPECT_EQ(comprehension_expr.loop_step(), Expr{}); + comprehension_expr.set_loop_step(MakeUnspecifiedExpr(1)); + EXPECT_THAT(comprehension_expr.has_loop_step(), IsTrue()); + EXPECT_EQ(comprehension_expr.loop_step(), MakeUnspecifiedExpr(1)); + auto operand = comprehension_expr.release_loop_step(); + EXPECT_THAT(comprehension_expr.has_loop_step(), IsFalse()); + EXPECT_EQ(comprehension_expr.loop_step(), Expr{}); +} + +TEST(ComprehensionExpr, Result) { + ComprehensionExpr comprehension_expr; + EXPECT_THAT(comprehension_expr.has_result(), IsFalse()); + EXPECT_EQ(comprehension_expr.result(), Expr{}); + comprehension_expr.set_result(MakeUnspecifiedExpr(1)); + EXPECT_THAT(comprehension_expr.has_result(), IsTrue()); + EXPECT_EQ(comprehension_expr.result(), MakeUnspecifiedExpr(1)); + auto operand = comprehension_expr.release_result(); + EXPECT_THAT(comprehension_expr.has_result(), IsFalse()); + EXPECT_EQ(comprehension_expr.result(), Expr{}); +} + +TEST(ComprehensionExpr, Equality) { + EXPECT_EQ(ComprehensionExpr{}, ComprehensionExpr{}); + ComprehensionExpr comprehension_expr; + comprehension_expr.set_result(MakeUnspecifiedExpr(1)); + EXPECT_NE(ComprehensionExpr{}, comprehension_expr); +} + +TEST(Expr, Unspecified) { + Expr expr; + EXPECT_THAT(expr.id(), Eq(ExprId{0})); + EXPECT_THAT(expr.kind(), VariantWith(_)); + EXPECT_EQ(expr.kind_case(), ExprKindCase::kUnspecifiedExpr); + EXPECT_EQ(expr, Expr{}); +} + +TEST(Expr, Ident) { + Expr expr; + EXPECT_THAT(expr.has_ident_expr(), IsFalse()); + EXPECT_EQ(expr.ident_expr(), IdentExpr{}); + auto& ident_expr = expr.mutable_ident_expr(); + EXPECT_THAT(expr.has_ident_expr(), IsTrue()); + EXPECT_NE(expr, Expr{}); + ident_expr.set_name("foo"); + EXPECT_NE(expr.ident_expr(), IdentExpr{}); + EXPECT_EQ(expr.kind_case(), ExprKindCase::kIdentExpr); + static_cast(expr.release_ident_expr()); + EXPECT_THAT(expr.has_ident_expr(), IsFalse()); + EXPECT_EQ(expr.ident_expr(), IdentExpr{}); + EXPECT_EQ(expr, Expr{}); +} + +TEST(Expr, Select) { + Expr expr; + EXPECT_THAT(expr.has_select_expr(), IsFalse()); + EXPECT_EQ(expr.select_expr(), SelectExpr{}); + auto& select_expr = expr.mutable_select_expr(); + EXPECT_THAT(expr.has_select_expr(), IsTrue()); + EXPECT_NE(expr, Expr{}); + select_expr.set_field("foo"); + EXPECT_NE(expr.select_expr(), SelectExpr{}); + EXPECT_EQ(expr.kind_case(), ExprKindCase::kSelectExpr); + static_cast(expr.release_select_expr()); + EXPECT_THAT(expr.has_select_expr(), IsFalse()); + EXPECT_EQ(expr.select_expr(), SelectExpr{}); + EXPECT_EQ(expr, Expr{}); +} + +TEST(Expr, Call) { + Expr expr; + EXPECT_THAT(expr.has_call_expr(), IsFalse()); + EXPECT_EQ(expr.call_expr(), CallExpr{}); + auto& call_expr = expr.mutable_call_expr(); + EXPECT_THAT(expr.has_call_expr(), IsTrue()); + EXPECT_NE(expr, Expr{}); + call_expr.set_function("foo"); + EXPECT_NE(expr.call_expr(), CallExpr{}); + EXPECT_EQ(expr.kind_case(), ExprKindCase::kCallExpr); + static_cast(expr.release_call_expr()); + EXPECT_THAT(expr.has_call_expr(), IsFalse()); + EXPECT_EQ(expr.call_expr(), CallExpr{}); + EXPECT_EQ(expr, Expr{}); +} + +TEST(Expr, List) { + Expr expr; + EXPECT_THAT(expr.has_list_expr(), IsFalse()); + EXPECT_EQ(expr.list_expr(), ListExpr{}); + auto& list_expr = expr.mutable_list_expr(); + EXPECT_THAT(expr.has_list_expr(), IsTrue()); + EXPECT_NE(expr, Expr{}); + list_expr.mutable_elements().push_back(MakeListExprElement(Expr{}, true)); + EXPECT_NE(expr.list_expr(), ListExpr{}); + EXPECT_EQ(expr.kind_case(), ExprKindCase::kListExpr); + static_cast(expr.release_list_expr()); + EXPECT_THAT(expr.has_list_expr(), IsFalse()); + EXPECT_EQ(expr.list_expr(), ListExpr{}); + EXPECT_EQ(expr, Expr{}); +} + +TEST(Expr, Struct) { + Expr expr; + EXPECT_THAT(expr.has_struct_expr(), IsFalse()); + EXPECT_EQ(expr.struct_expr(), StructExpr{}); + auto& struct_expr = expr.mutable_struct_expr(); + EXPECT_THAT(expr.has_struct_expr(), IsTrue()); + EXPECT_NE(expr, Expr{}); + struct_expr.set_name("foo"); + EXPECT_NE(expr.struct_expr(), StructExpr{}); + EXPECT_EQ(expr.kind_case(), ExprKindCase::kStructExpr); + static_cast(expr.release_struct_expr()); + EXPECT_THAT(expr.has_struct_expr(), IsFalse()); + EXPECT_EQ(expr.struct_expr(), StructExpr{}); + EXPECT_EQ(expr, Expr{}); +} + +TEST(Expr, Map) { + Expr expr; + EXPECT_THAT(expr.has_map_expr(), IsFalse()); + EXPECT_EQ(expr.map_expr(), MapExpr{}); + auto& map_expr = expr.mutable_map_expr(); + EXPECT_THAT(expr.has_map_expr(), IsTrue()); + EXPECT_NE(expr, Expr{}); + map_expr.mutable_entries().push_back(MakeMapExprEntry(1, Expr{}, Expr{})); + EXPECT_NE(expr.map_expr(), MapExpr{}); + EXPECT_EQ(expr.kind_case(), ExprKindCase::kMapExpr); + static_cast(expr.release_map_expr()); + EXPECT_THAT(expr.has_map_expr(), IsFalse()); + EXPECT_EQ(expr.map_expr(), MapExpr{}); + EXPECT_EQ(expr, Expr{}); +} + +TEST(Expr, Comprehension) { + Expr expr; + EXPECT_THAT(expr.has_comprehension_expr(), IsFalse()); + EXPECT_EQ(expr.comprehension_expr(), ComprehensionExpr{}); + auto& comprehension_expr = expr.mutable_comprehension_expr(); + EXPECT_THAT(expr.has_comprehension_expr(), IsTrue()); + EXPECT_NE(expr, Expr{}); + comprehension_expr.set_iter_var("foo"); + EXPECT_NE(expr.comprehension_expr(), ComprehensionExpr{}); + EXPECT_EQ(expr.kind_case(), ExprKindCase::kComprehensionExpr); + static_cast(expr.release_comprehension_expr()); + EXPECT_THAT(expr.has_comprehension_expr(), IsFalse()); + EXPECT_EQ(expr.comprehension_expr(), ComprehensionExpr{}); + EXPECT_EQ(expr, Expr{}); +} + +TEST(Expr, CopyCtor) { + Expr expr; + expr.mutable_select_expr().set_field("foo"); + Expr& operand = expr.mutable_select_expr().mutable_operand(); + operand.mutable_ident_expr().set_name("bar"); + Expr expr_copy = expr; + EXPECT_EQ(expr_copy.select_expr().field(), "foo"); + EXPECT_EQ(expr_copy.select_expr().operand().ident_expr().name(), "bar"); +} + +TEST(Expr, CopyAssignChildReference) { + Expr expr; + expr.mutable_select_expr().set_field("foo"); + Expr& operand = expr.mutable_select_expr().mutable_operand(); + operand.mutable_call_expr().set_function("bar"); + auto& args = operand.mutable_call_expr().mutable_args(); + args.emplace_back().mutable_ident_expr().set_name("baz"); + args.emplace_back().mutable_ident_expr().set_name("qux"); + expr = expr.mutable_select_expr().mutable_operand(); + EXPECT_EQ(expr.call_expr().function(), "bar"); + EXPECT_EQ(expr.call_expr().args().size(), 2); + EXPECT_EQ(expr.call_expr().args()[0].ident_expr().name(), "baz"); + EXPECT_EQ(expr.call_expr().args()[1].ident_expr().name(), "qux"); +} + +TEST(Expr, FlattenedErase) { + Expr expr; + auto& list_expr = expr.mutable_list_expr(); + list_expr.mutable_elements() + .emplace_back() + .mutable_expr() + .mutable_ident_expr() + .set_name("foo"); + + list_expr.mutable_elements() + .emplace_back() + .mutable_expr() + .mutable_select_expr() + .mutable_operand() + .mutable_ident_expr() + .set_name("foo"); + + auto& call_expr = list_expr.mutable_elements() + .emplace_back() + .mutable_expr() + .mutable_call_expr(); + call_expr.set_function("foo"); + call_expr.mutable_target().mutable_ident_expr().set_name("bar"); + call_expr.mutable_args().emplace_back().mutable_ident_expr().set_name("baz"); + + auto& struct_expr = list_expr.mutable_elements() + .emplace_back() + .mutable_expr() + .mutable_struct_expr(); + struct_expr.set_name("foo"); + auto& field = struct_expr.mutable_fields().emplace_back(); + field.set_name("bar"); + field.mutable_value().mutable_ident_expr().set_name("baz"); + + auto& map_expr = list_expr.mutable_elements() + .emplace_back() + .mutable_expr() + .mutable_map_expr(); + auto& map_entry = map_expr.mutable_entries().emplace_back(); + map_entry.mutable_key().mutable_const_expr().set_string_value("foo"); + map_entry.mutable_value().mutable_ident_expr().set_name("bar"); + + auto& comprehension_expr = list_expr.mutable_elements() + .emplace_back() + .mutable_expr() + .mutable_comprehension_expr(); + comprehension_expr.set_iter_var("foo"); + comprehension_expr.set_accu_var("bar"); + comprehension_expr.set_iter_range(Expr{}); + comprehension_expr.set_accu_init(Expr{}); + comprehension_expr.set_loop_condition(Expr{}); + comprehension_expr.set_loop_step(Expr{}); + comprehension_expr.set_result(Expr{}); + + expr.FlattenedErase(); + EXPECT_EQ(expr.kind_case(), ExprKindCase::kUnspecifiedExpr); +} + +Expr MakeNestedList(int size) { + Expr e; + Expr* node = &e; + e.set_id(1); + for (int i = 0; i < size; ++i) { + node = &node->mutable_list_expr() + .mutable_elements() + .emplace_back() + .mutable_expr(); + node->set_id(i + 2); + } + return e; +} + +TEST(Expr, FlattenedErase256k) { + // Large expr to ensure we're not recursing. Would likely hit stack limits + // with default destructor. + constexpr int size = 256 * 1024; + + Expr expr = MakeNestedList(size); + + expr.FlattenedErase(); + EXPECT_EQ(expr.kind_case(), ExprKindCase::kUnspecifiedExpr); +} + +TEST(Expr, Id) { + Expr expr; + EXPECT_THAT(expr.id(), Eq(0)); + expr.set_id(1); + EXPECT_THAT(expr.id(), Eq(1)); +} + +} // namespace +} // namespace cel diff --git a/common/format_type_name.cc b/common/format_type_name.cc new file mode 100644 index 000000000..4bd6c2e61 --- /dev/null +++ b/common/format_type_name.cc @@ -0,0 +1,180 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "common/format_type_name.h" + +#include +#include + +#include "absl/strings/str_cat.h" +#include "common/type.h" +#include "common/type_kind.h" + +namespace cel { + +namespace { +struct FormatImplRecord { + Type type; + int offset; +}; + +// Parameterized types can be arbitrarily nested, so we use a vector as +// a stack to avoid overflow. Practically, we don't expect nesting +// to ever be very deep, but fuzzers and pathological inputs can easily +// trigger stack overflow with a recursive implementation. +void FormatImpl(const Type& cur, int offset, + std::vector& stack, std::string* out) { + switch (cur.kind()) { + case TypeKind::kDyn: + absl::StrAppend(out, "dyn"); + return; + case TypeKind::kAny: + absl::StrAppend(out, "any"); + return; + case TypeKind::kBool: + absl::StrAppend(out, "bool"); + return; + case TypeKind::kBoolWrapper: + absl::StrAppend(out, "wrapper(bool)"); + return; + case TypeKind::kBytes: + absl::StrAppend(out, "bytes"); + return; + case TypeKind::kBytesWrapper: + absl::StrAppend(out, "wrapper(bytes)"); + return; + case TypeKind::kDouble: + absl::StrAppend(out, "double"); + return; + case TypeKind::kDoubleWrapper: + absl::StrAppend(out, "wrapper(double)"); + return; + case TypeKind::kDuration: + absl::StrAppend(out, "google.protobuf.Duration"); + return; + case TypeKind::kEnum: + absl::StrAppend(out, "int"); + return; + case TypeKind::kInt: + absl::StrAppend(out, "int"); + return; + case TypeKind::kIntWrapper: + absl::StrAppend(out, "wrapper(int)"); + return; + case TypeKind::kList: + if (offset == 0) { + absl::StrAppend(out, "list("); + stack.push_back({cur, 1}); + stack.push_back({cur.AsList()->GetElement(), 0}); + } else { + absl::StrAppend(out, ")"); + } + return; + case TypeKind::kMap: + if (offset == 0) { + absl::StrAppend(out, "map("); + stack.push_back({cur, 1}); + stack.push_back({cur.AsMap()->GetKey(), 0}); + return; + } + if (offset == 1) { + absl::StrAppend(out, ", "); + stack.push_back({cur, 2}); + stack.push_back({cur.AsMap()->GetValue(), 0}); + return; + } + absl::StrAppend(out, ")"); + return; + case TypeKind::kNull: + absl::StrAppend(out, "null_type"); + return; + case TypeKind::kOpaque: { + OpaqueType opaque = *cur.AsOpaque(); + if (offset == 0) { + absl::StrAppend(out, cur.AsOpaque()->name()); + if (!opaque.GetParameters().empty()) { + absl::StrAppend(out, "("); + stack.push_back({cur, 1}); + stack.push_back({cur.AsOpaque()->GetParameters()[0], 0}); + } + return; + } + if (offset >= opaque.GetParameters().size()) { + absl::StrAppend(out, ")"); + return; + } + absl::StrAppend(out, ", "); + stack.push_back({cur, offset + 1}); + stack.push_back({cur.AsOpaque()->GetParameters()[offset], 0}); + return; + } + case TypeKind::kString: + absl::StrAppend(out, "string"); + return; + case TypeKind::kStringWrapper: + absl::StrAppend(out, "wrapper(string)"); + return; + case TypeKind::kStruct: + absl::StrAppend(out, cur.AsStruct()->name()); + return; + case TypeKind::kTimestamp: + absl::StrAppend(out, "google.protobuf.Timestamp"); + return; + case TypeKind::kType: { + TypeType type_type = *cur.AsType(); + if (offset == 0) { + absl::StrAppend(out, type_type.name()); + if (!type_type.GetParameters().empty()) { + absl::StrAppend(out, "("); + stack.push_back({cur, 1}); + stack.push_back({cur.AsType()->GetParameters()[0], 0}); + } + return; + } + absl::StrAppend(out, ")"); + return; + } + case TypeKind::kTypeParam: + absl::StrAppend(out, cur.AsTypeParam()->name()); + return; + case TypeKind::kUint: + absl::StrAppend(out, "uint"); + return; + case TypeKind::kUintWrapper: + absl::StrAppend(out, "wrapper(uint)"); + return; + case TypeKind::kUnknown: + absl::StrAppend(out, "*unknown*"); + return; + case TypeKind::kError: + case TypeKind::kFunction: + default: + absl::StrAppend(out, "*error*"); + return; + } +} +} // namespace + +std::string FormatTypeName(const Type& type) { + std::vector stack; + std::string out; + stack.push_back({type, 0}); + while (!stack.empty()) { + auto [type, offset] = stack.back(); + stack.pop_back(); + FormatImpl(type, offset, stack, &out); + } + return out; +} + +} // namespace cel diff --git a/common/format_type_name.h b/common/format_type_name.h new file mode 100644 index 000000000..723ac20fd --- /dev/null +++ b/common/format_type_name.h @@ -0,0 +1,30 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_FORMAT_TYPE_NAME_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_FORMAT_TYPE_NAME_H_ + +#include + +#include "common/type.h" + +namespace cel { + +// Format the type name for presentation in error messages. Matches the +// formatting used in github.com/cel-spec. +std::string FormatTypeName(const Type& type); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_FORMAT_TYPE_NAME_H_ diff --git a/common/format_type_name_test.cc b/common/format_type_name_test.cc new file mode 100644 index 000000000..ca63f60b0 --- /dev/null +++ b/common/format_type_name_test.cc @@ -0,0 +1,118 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/format_type_name.h" + +#include "common/type.h" +#include "internal/testing.h" +#include "cel/expr/conformance/proto2/test_all_types.pb.h" +#include "google/protobuf/arena.h" + +namespace cel { +namespace { + +using ::cel::expr::conformance::proto2::GlobalEnum_descriptor; +using ::cel::expr::conformance::proto2::TestAllTypes; +using ::testing::MatchesRegex; + +TEST(FormatTypeNameTest, PrimitiveTypes) { + EXPECT_EQ(FormatTypeName(IntType()), "int"); + EXPECT_EQ(FormatTypeName(UintType()), "uint"); + EXPECT_EQ(FormatTypeName(DoubleType()), "double"); + EXPECT_EQ(FormatTypeName(StringType()), "string"); + EXPECT_EQ(FormatTypeName(BytesType()), "bytes"); + EXPECT_EQ(FormatTypeName(BoolType()), "bool"); + EXPECT_EQ(FormatTypeName(NullType()), "null_type"); + EXPECT_EQ(FormatTypeName(DynType()), "dyn"); +} + +TEST(FormatTypeNameTest, SpecialTypes) { + EXPECT_EQ(FormatTypeName(ErrorType()), "*error*"); + EXPECT_EQ(FormatTypeName(UnknownType()), "*unknown*"); + EXPECT_EQ(FormatTypeName(FunctionType()), "*error*"); +} + +TEST(FormatTypeNameTest, WellKnownTypes) { + EXPECT_EQ(FormatTypeName(AnyType()), "any"); + EXPECT_EQ(FormatTypeName(DurationType()), "google.protobuf.Duration"); + EXPECT_EQ(FormatTypeName(TimestampType()), "google.protobuf.Timestamp"); +} + +TEST(FormatTypeNameTest, Wrappers) { + EXPECT_EQ(FormatTypeName(IntWrapperType()), "wrapper(int)"); + EXPECT_EQ(FormatTypeName(UintWrapperType()), "wrapper(uint)"); + EXPECT_EQ(FormatTypeName(DoubleWrapperType()), "wrapper(double)"); + EXPECT_EQ(FormatTypeName(StringWrapperType()), "wrapper(string)"); + EXPECT_EQ(FormatTypeName(BytesWrapperType()), "wrapper(bytes)"); + EXPECT_EQ(FormatTypeName(BoolWrapperType()), "wrapper(bool)"); +} + +TEST(FormatTypeNameTest, ProtobufTypes) { + EXPECT_EQ(FormatTypeName(MessageType(TestAllTypes::descriptor())), + "cel.expr.conformance.proto2.TestAllTypes"); + EXPECT_EQ(FormatTypeName(EnumType(GlobalEnum_descriptor())), "int"); +} + +TEST(FormatTypeNameTest, Type) { + google::protobuf::Arena arena; + EXPECT_EQ(FormatTypeName(TypeType()), "type"); + EXPECT_EQ(FormatTypeName(TypeType(&arena, IntType())), "type(int)"); + EXPECT_EQ(FormatTypeName(TypeType(&arena, TypeType(&arena, IntType()))), + "type(type(int))"); + EXPECT_EQ(FormatTypeName(TypeType(&arena, TypeParamType("T"))), "type(T)"); +} + +TEST(FormatTypeNameTest, List) { + google::protobuf::Arena arena; + EXPECT_EQ(FormatTypeName(ListType()), "list(dyn)"); + EXPECT_EQ(FormatTypeName(ListType(&arena, IntType())), "list(int)"); + EXPECT_EQ(FormatTypeName(ListType(&arena, ListType(&arena, IntType()))), + "list(list(int))"); +} + +TEST(FormatTypeNameTest, Map) { + google::protobuf::Arena arena; + EXPECT_EQ(FormatTypeName(MapType()), "map(dyn, dyn)"); + EXPECT_EQ(FormatTypeName(MapType(&arena, IntType(), IntType())), + "map(int, int)"); + EXPECT_EQ(FormatTypeName(MapType(&arena, IntType(), + MapType(&arena, IntType(), IntType()))), + "map(int, map(int, int))"); +} + +TEST(FormatTypeNameTest, Opaque) { + google::protobuf::Arena arena; + EXPECT_EQ(FormatTypeName(OpaqueType(&arena, "opaque", {})), "opaque"); + Type two_tuple_type = OpaqueType(&arena, "tuple", {IntType(), IntType()}); + Type three_tuple_type = OpaqueType( + &arena, "tuple", {two_tuple_type, two_tuple_type, two_tuple_type}); + EXPECT_EQ(FormatTypeName(three_tuple_type), + "tuple(tuple(int, int), tuple(int, int), tuple(int, int))"); +} + +#ifndef __APPLE__ +TEST(FormatTypeNameTest, ArbitraryNesting) { + google::protobuf::Arena arena; + Type type = IntType(); + for (int i = 0; i < 1000; ++i) { + type = OpaqueType(&arena, "ptype", {type}); + } + + EXPECT_THAT(FormatTypeName(type), + MatchesRegex(R"(^(ptype\(){1000}int(\)){1000})")); +} +#endif + +} // namespace +} // namespace cel diff --git a/common/function_descriptor.cc b/common/function_descriptor.cc new file mode 100644 index 000000000..be32e8616 --- /dev/null +++ b/common/function_descriptor.cc @@ -0,0 +1,98 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/function_descriptor.h" + +#include +#include + +#include "absl/base/macros.h" +#include "absl/types/span.h" +#include "common/kind.h" + +namespace cel { + +bool FunctionDescriptor::ShapeMatches(bool receiver_style, + absl::Span types) const { + if (this->receiver_style() != receiver_style) { + return false; + } + + if (this->types().size() != types.size()) { + return false; + } + + for (size_t i = 0; i < this->types().size(); i++) { + Kind this_type = this->types()[i]; + Kind other_type = types[i]; + if (this_type != Kind::kAny && other_type != Kind::kAny && + this_type != other_type) { + return false; + } + } + return true; +} + +bool FunctionDescriptor::operator==(const FunctionDescriptor& other) const { + return impl_.get() == other.impl_.get() || + (name() == other.name() && + receiver_style() == other.receiver_style() && + types().size() == other.types().size() && + std::equal(types().begin(), types().end(), other.types().begin())); +} + +bool FunctionDescriptor::operator<(const FunctionDescriptor& other) const { + if (impl_.get() == other.impl_.get()) { + return false; + } + if (name() < other.name()) { + return true; + } + if (name() != other.name()) { + return false; + } + if (receiver_style() < other.receiver_style()) { + return true; + } + if (receiver_style() != other.receiver_style()) { + return false; + } + auto lhs_begin = types().begin(); + auto lhs_end = types().end(); + auto rhs_begin = other.types().begin(); + auto rhs_end = other.types().end(); + while (lhs_begin != lhs_end && rhs_begin != rhs_end) { + if (*lhs_begin < *rhs_begin) { + return true; + } + if (!(*lhs_begin == *rhs_begin)) { + return false; + } + lhs_begin++; + rhs_begin++; + } + if (lhs_begin == lhs_end && rhs_begin == rhs_end) { + // Neither has any elements left, they are equal. + return false; + } + if (lhs_begin == lhs_end) { + // Left has no more elements. Right is greater. + return true; + } + // Right has no more elements. Left is greater. + ABSL_ASSERT(rhs_begin == rhs_end); + return false; +} + +} // namespace cel diff --git a/common/function_descriptor.h b/common/function_descriptor.h new file mode 100644 index 000000000..75c61e13a --- /dev/null +++ b/common/function_descriptor.h @@ -0,0 +1,124 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_FUNCTION_DESCRIPTOR_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_FUNCTION_DESCRIPTOR_H_ + +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "common/kind.h" + +namespace cel { + +struct FunctionDescriptorOptions { + // If true (strict, default), error or unknown arguments are propagated + // instead of calling the function. if false (non-strict), the function may + // receive error or unknown values as arguments. + bool is_strict = true; + + // Whether the function is impure or context-sensitive. + // + // Impure functions depend on state other than the arguments received during + // the CEL expression evaluation or have visible side effects. This breaks + // some of the assumptions of the CEL evaluation model. This flag is used as a + // hint to the planner that some optimizations are not safe or not effective. + bool is_contextual = false; +}; + +// Coarsely describes a function for the purpose of runtime resolution of +// overloads. +class FunctionDescriptor final { + public: + FunctionDescriptor(absl::string_view name, bool receiver_style, + std::vector types, bool is_strict) + : impl_(std::make_shared( + name, std::move(types), receiver_style, + FunctionDescriptorOptions{is_strict, + /*is_contextual=*/false})) {} + + FunctionDescriptor(absl::string_view name, bool receiver_style, + std::vector types, bool is_strict, + bool is_contextual) + : impl_(std::make_shared( + name, std::move(types), receiver_style, + FunctionDescriptorOptions{is_strict, is_contextual})) {} + + FunctionDescriptor(absl::string_view name, bool is_receiver_style, + std::vector types, + FunctionDescriptorOptions options = {}) + : impl_(std::make_shared(name, std::move(types), is_receiver_style, + options)) {} + + // Function name. + const std::string& name() const { return impl_->name; } + + // Whether function is receiver style i.e. true means arg0.name(args[1:]...). + bool receiver_style() const { return impl_->is_receiver_style; } + + // The argument types the function accepts. + // + // TODO(uncreated-issue/17): make this kinds + const std::vector& types() const { return impl_->types; } + + // if true (strict, default), error or unknown arguments are propagated + // instead of calling the function. if false (non-strict), the function may + // receive error or unknown values as arguments. + bool is_strict() const { return impl_->options.is_strict; } + + // Whether the function is contextual (impure). + // + // Contextual functions depend on state other than the arguments received in + // the CEL expression evaluation or have visible side effects. This breaks + // some of the assumptions of CEL. This flag is used as a hint to the planner + // that some optimizations are not safe or not effective. + bool is_contextual() const { return impl_->options.is_contextual; } + + // Helper for matching a descriptor. This tests that the shape is the same -- + // |other| accepts the same number and types of arguments and is the same call + // style). + bool ShapeMatches(const FunctionDescriptor& other) const { + return ShapeMatches(other.receiver_style(), other.types()); + } + bool ShapeMatches(bool receiver_style, absl::Span types) const; + + bool operator==(const FunctionDescriptor& other) const; + + bool operator<(const FunctionDescriptor& other) const; + + private: + struct Impl final { + Impl(absl::string_view name, std::vector types, + bool is_receiver_style, FunctionDescriptorOptions options) + : name(name), + types(std::move(types)), + is_receiver_style(is_receiver_style), + options(options) {} + + std::string name; + std::vector types; + bool is_receiver_style; + FunctionDescriptorOptions options; + }; + + std::shared_ptr impl_; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_FUNCTION_DESCRIPTOR_H_ diff --git a/common/id.h b/common/id.h deleted file mode 100644 index 9f12aa2aa..000000000 --- a/common/id.h +++ /dev/null @@ -1,40 +0,0 @@ -#ifndef THIRD_PARTY_CEL_CPP_COMMON_ID_H_ -#define THIRD_PARTY_CEL_CPP_COMMON_ID_H_ - -#include "internal/cel_printer.h" -#include "internal/handle.h" -#include "internal/hash_util.h" - -namespace google { -namespace api { -namespace expr { -namespace common { - -// A expression, statement, or variable id. -class Id : public internal::Handle { - public: - constexpr explicit Id(int32_t value) : Handle(value) {} - - inline std::string ToDebugString() const { - return internal::ToCallString("Id", value_); - } -}; - -inline std::ostream& operator<<(std::ostream& os, Id arg) { - return os << arg.ToDebugString(); -} - -} // namespace common -} // namespace expr -} // namespace api -} // namespace google - -namespace std { - -template <> -struct hash - : google::api::expr::common::Id::Hasher {}; - -} // namespace std - -#endif // THIRD_PARTY_CEL_CPP_COMMON_ID_H_ diff --git a/common/internal/BUILD b/common/internal/BUILD new file mode 100644 index 000000000..3be350754 --- /dev/null +++ b/common/internal/BUILD @@ -0,0 +1,137 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "casting", + hdrs = ["casting.h"], + deps = [ + "//internal:casts", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/meta:type_traits", + "@com_google_absl//absl/types:optional", + ], +) + +cc_library( + name = "reference_count", + srcs = ["reference_count.cc"], + hdrs = ["reference_count.h"], + deps = [ + "//common:data", + "//internal:new", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/strings:string_view", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "reference_count_test", + srcs = ["reference_count_test.cc"], + deps = [ + ":reference_count", + "//common:data", + "//internal:testing", + "@com_google_absl//absl/base:nullability", + "@com_google_protobuf//:protobuf", + "@com_google_protobuf//:struct_cc_proto", + ], +) + +cc_library( + name = "metadata", + hdrs = ["metadata.h"], + deps = ["@com_google_protobuf//:protobuf"], +) + +cc_library( + name = "byte_string", + srcs = ["byte_string.cc"], + hdrs = ["byte_string.h"], + deps = [ + ":metadata", + ":reference_count", + "//common:allocator", + "//common:arena", + "//common:memory", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/functional:overload", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:optional", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "byte_string_test", + srcs = ["byte_string_test.cc"], + deps = [ + ":byte_string", + ":reference_count", + "//common:allocator", + "//common:memory", + "//internal:testing", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/strings:cord", + "@com_google_absl//absl/strings:cord_test_helpers", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:optional", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "value_conversion", + srcs = ["value_conversion.cc"], + hdrs = ["value_conversion.h"], + deps = [ + "//common:any", + "//common:value", + "//common:value_kind", + "//extensions/protobuf:value", + "//internal:proto_time_encoding", + "//internal:status_macros", + "//internal:time", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + "@com_google_absl//absl/time", + "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:value_cc_proto", + "@com_google_googleapis//google/api/expr/v1alpha1:checked_cc_proto", + "@com_google_googleapis//google/api/expr/v1alpha1:value_cc_proto", + "@com_google_protobuf//:any_cc_proto", + "@com_google_protobuf//:protobuf", + "@com_google_protobuf//:struct_cc_proto", + "@com_google_protobuf//:timestamp_cc_proto", + "@com_google_protobuf//src/google/protobuf/io", + ], +) diff --git a/common/internal/byte_string.cc b/common/internal/byte_string.cc new file mode 100644 index 000000000..304104a87 --- /dev/null +++ b/common/internal/byte_string.cc @@ -0,0 +1,1074 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/internal/byte_string.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/base/optimization.h" +#include "absl/functional/overload.h" +#include "absl/hash/hash.h" +#include "absl/log/absl_check.h" +#include "absl/strings/cord.h" +#include "absl/strings/match.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "common/allocator.h" +#include "common/internal/metadata.h" +#include "common/internal/reference_count.h" +#include "common/memory.h" +#include "google/protobuf/arena.h" + +namespace cel::common_internal { + +namespace { + +char* CopyCordToArray(const absl::Cord& cord, char* data) { + for (auto chunk : cord.Chunks()) { + std::memcpy(data, chunk.data(), chunk.size()); + data += chunk.size(); + } + return data; +} + +template +T ConsumeAndDestroy(T& object) { + T consumed = std::move(object); + object.~T(); // NOLINT(bugprone-use-after-move) + return consumed; +} + +} // namespace + +ByteString ByteString::Concat(const ByteString& lhs, const ByteString& rhs, + google::protobuf::Arena* absl_nonnull arena) { + ABSL_DCHECK(arena != nullptr); + + if (lhs.empty()) { + return rhs; + } + if (rhs.empty()) { + return lhs; + } + + if (lhs.GetKind() == ByteStringKind::kLarge || + rhs.GetKind() == ByteStringKind::kLarge) { + // If either the left or right are absl::Cord, use absl::Cord. + absl::Cord result; + result.Append(lhs.ToCord()); + result.Append(rhs.ToCord()); + return ByteString(std::move(result)); + } + + const size_t lhs_size = lhs.size(); + const size_t rhs_size = rhs.size(); + const size_t result_size = lhs_size + rhs_size; + ByteString result; + if (result_size <= kSmallByteStringCapacity) { + // If the resulting string fits in inline storage, do it. + result.rep_.small.size = result_size; + result.rep_.small.arena = arena; + lhs.CopyToArray(result.rep_.small.data); + rhs.CopyToArray(result.rep_.small.data + lhs_size); + } else { + // Otherwise allocate on the arena. + char* result_data = + reinterpret_cast(arena->AllocateAligned(result_size)); + lhs.CopyToArray(result_data); + rhs.CopyToArray(result_data + lhs_size); + result.rep_.medium.data = result_data; + result.rep_.medium.size = result_size; + result.rep_.medium.owner = + reinterpret_cast(arena) | kMetadataOwnerArenaBit; + result.rep_.header.kind = ByteStringKind::kMedium; + } + return result; +} + +ByteString::ByteString(Allocator<> allocator, absl::string_view string) { + ABSL_DCHECK_LE(string.size(), max_size()); + auto* arena = allocator.arena(); + if (string.size() <= kSmallByteStringCapacity) { + SetSmall(arena, string); + } else { + SetMedium(arena, string); + } +} + +ByteString::ByteString(Allocator<> allocator, const std::string& string) { + ABSL_DCHECK_LE(string.size(), max_size()); + auto* arena = allocator.arena(); + if (string.size() <= kSmallByteStringCapacity) { + SetSmall(arena, string); + } else { + SetMedium(arena, string); + } +} + +ByteString::ByteString(Allocator<> allocator, std::string&& string) { + ABSL_DCHECK_LE(string.size(), max_size()); + auto* arena = allocator.arena(); + if (string.size() <= kSmallByteStringCapacity) { + SetSmall(arena, string); + } else { + SetMedium(arena, std::move(string)); + } +} + +ByteString::ByteString(Allocator<> allocator, const absl::Cord& cord) { + ABSL_DCHECK_LE(cord.size(), max_size()); + auto* arena = allocator.arena(); + if (cord.size() <= kSmallByteStringCapacity) { + SetSmall(arena, cord); + } else if (arena != nullptr) { + SetMedium(arena, cord); + } else { + SetLarge(cord); + } +} + +ByteString ByteString::Borrowed(Borrower borrower, absl::string_view string) { + ABSL_DCHECK(borrower != Borrower::None()) << "Borrowing from Owner::None()"; + auto* arena = borrower.arena(); + if (string.size() <= kSmallByteStringCapacity || arena != nullptr) { + return ByteString(arena, string); + } + const auto* refcount = BorrowerRelease(borrower); + // A nullptr refcount indicates somebody called us to borrow something that + // has no owner. If this is the case, we fallback to assuming operator + // new/delete and convert it to a reference count. + if (refcount == nullptr) { + std::tie(refcount, string) = MakeReferenceCountedString(string); + } else { + StrongRef(*refcount); + } + return ByteString(refcount, string); +} + +ByteString ByteString::Borrowed(Borrower borrower, const absl::Cord& cord) { + ABSL_DCHECK(borrower != Borrower::None()) << "Borrowing from Owner::None()"; + return ByteString(borrower.arena(), cord); +} + +ByteString::ByteString(const ReferenceCount* absl_nonnull refcount, + absl::string_view string) { + ABSL_DCHECK_LE(string.size(), max_size()); + SetMedium(string, reinterpret_cast(refcount) | + kMetadataOwnerReferenceCountBit); +} + +ByteString::ByteString(ByteString::ExternalStringTag, + absl::string_view string) { + if (string.size() <= kSmallByteStringCapacity) { + SetSmall(nullptr, string); + } else { + SetExternalMedium(string); + } +} + +ByteString ByteString::FromExternal(absl::string_view string) { + return ByteString(ExternalStringTag{}, string); +} + +google::protobuf::Arena* absl_nullable ByteString::GetArena() const { + switch (GetKind()) { + case ByteStringKind::kSmall: + return GetSmallArena(); + case ByteStringKind::kMedium: + return GetMediumArena(); + case ByteStringKind::kLarge: + return nullptr; + } +} + +bool ByteString::empty() const { + switch (GetKind()) { + case ByteStringKind::kSmall: + return rep_.small.size == 0; + case ByteStringKind::kMedium: + return rep_.medium.size == 0; + case ByteStringKind::kLarge: + return GetLarge().empty(); + } +} + +size_t ByteString::size() const { + switch (GetKind()) { + case ByteStringKind::kSmall: + return rep_.small.size; + case ByteStringKind::kMedium: + return rep_.medium.size; + case ByteStringKind::kLarge: + return GetLarge().size(); + } +} + +absl::string_view ByteString::Flatten() { + switch (GetKind()) { + case ByteStringKind::kSmall: + return GetSmall(); + case ByteStringKind::kMedium: + return GetMedium(); + case ByteStringKind::kLarge: + return GetLarge().Flatten(); + } +} + +absl::optional ByteString::TryFlat() const { + switch (GetKind()) { + case ByteStringKind::kSmall: + return GetSmall(); + case ByteStringKind::kMedium: + return GetMedium(); + case ByteStringKind::kLarge: + return GetLarge().TryFlat(); + } +} + +bool ByteString::Equals(absl::string_view rhs) const { + return Visit(absl::Overload( + [&rhs](absl::string_view lhs) -> bool { return lhs == rhs; }, + [&rhs](const absl::Cord& lhs) -> bool { return lhs == rhs; })); +} + +bool ByteString::Equals(const absl::Cord& rhs) const { + return Visit(absl::Overload( + [&rhs](absl::string_view lhs) -> bool { return lhs == rhs; }, + [&rhs](const absl::Cord& lhs) -> bool { return lhs == rhs; })); +} + +int ByteString::Compare(absl::string_view rhs) const { + return Visit(absl::Overload( + [&rhs](absl::string_view lhs) -> int { return lhs.compare(rhs); }, + [&rhs](const absl::Cord& lhs) -> int { return lhs.Compare(rhs); })); +} + +int ByteString::Compare(const absl::Cord& rhs) const { + return Visit(absl::Overload( + [&rhs](absl::string_view lhs) -> int { return -rhs.Compare(lhs); }, + [&rhs](const absl::Cord& lhs) -> int { return lhs.Compare(rhs); })); +} + +bool ByteString::StartsWith(absl::string_view rhs) const { + return Visit(absl::Overload( + [&rhs](absl::string_view lhs) -> bool { + return absl::StartsWith(lhs, rhs); + }, + [&rhs](const absl::Cord& lhs) -> bool { return lhs.StartsWith(rhs); })); +} + +bool ByteString::StartsWith(const absl::Cord& rhs) const { + return Visit(absl::Overload( + [&rhs](absl::string_view lhs) -> bool { + return lhs.size() >= rhs.size() && lhs.substr(0, rhs.size()) == rhs; + }, + [&rhs](const absl::Cord& lhs) -> bool { return lhs.StartsWith(rhs); })); +} + +bool ByteString::EndsWith(absl::string_view rhs) const { + return Visit(absl::Overload( + [&rhs](absl::string_view lhs) -> bool { + return absl::EndsWith(lhs, rhs); + }, + [&rhs](const absl::Cord& lhs) -> bool { return lhs.EndsWith(rhs); })); +} + +bool ByteString::EndsWith(const absl::Cord& rhs) const { + return Visit(absl::Overload( + [&rhs](absl::string_view lhs) -> bool { + return lhs.size() >= rhs.size() && + lhs.substr(lhs.size() - rhs.size()) == rhs; + }, + [&rhs](const absl::Cord& lhs) -> bool { return lhs.EndsWith(rhs); })); +} + +absl::optional ByteString::Find(absl::string_view needle, + size_t pos) const { + ABSL_DCHECK_LE(pos, size()); + + return Visit(absl::Overload( + [&needle, pos](absl::string_view lhs) -> absl::optional { + absl::string_view::size_type i = lhs.find(needle, pos); + if (i == absl::string_view::npos) { + return absl::nullopt; + } + return i; + }, + [&needle, pos](const absl::Cord& lhs) -> absl::optional { + absl::Cord cord = lhs.Subcord(pos, lhs.size() - pos); + absl::Cord::CharIterator it = cord.Find(needle); + if (it == cord.char_end()) { + return absl::nullopt; + } + return pos + + static_cast(absl::Cord::Distance(cord.char_begin(), it)); + })); +} + +absl::optional ByteString::Find(const absl::Cord& needle, + size_t pos) const { + ABSL_DCHECK_LE(pos, size()); + + return Visit(absl::Overload( + [&needle, pos](absl::string_view lhs) -> absl::optional { + if (auto flat_needle = needle.TryFlat(); flat_needle) { + absl::string_view::size_type i = lhs.find(*flat_needle, pos); + if (i == absl::string_view::npos) { + return absl::nullopt; + } + return i; + } + // Needle is fragmented, we have to do a linear scan. + const size_t needle_size = needle.size(); + if (pos + needle_size > lhs.size()) { + return absl::nullopt; + } + if (ABSL_PREDICT_FALSE(needle_size == 0)) { + return pos; + } + // Optimization: find the first chunk of the needle, then compare the + // rest. If the first chunk is empty, `lhs.find` will return + // `current_pos`, which correctly degrades to a linear scan. + absl::string_view first_chunk = *needle.Chunks().begin(); + absl::Cord rest_of_needle = needle.Subcord( + first_chunk.size(), needle_size - first_chunk.size()); + size_t current_pos = pos; + while (true) { + size_t found_pos = lhs.find(first_chunk, current_pos); + if (found_pos == absl::string_view::npos || + found_pos > lhs.size() - needle_size) { + return absl::nullopt; + } + if (lhs.substr(found_pos + first_chunk.size(), + rest_of_needle.size()) == rest_of_needle) { + return found_pos; + } + current_pos = found_pos + 1; + } + }, + [&needle, pos](const absl::Cord& lhs) -> absl::optional { + absl::Cord cord = lhs.Subcord(pos, lhs.size() - pos); + absl::Cord::CharIterator it = cord.Find(needle); + if (it == cord.char_end()) { + return absl::nullopt; + } + return pos + + static_cast(absl::Cord::Distance(cord.char_begin(), it)); + })); +} + +ByteString ByteString::Substring(size_t pos, size_t npos) const { + ABSL_DCHECK_LE(npos, size()); + ABSL_DCHECK_LE(pos, npos); + + switch (GetKind()) { + case ByteStringKind::kSmall: { + ByteString result; + result.rep_.header.kind = ByteStringKind::kSmall; + result.rep_.small.size = npos - pos; + std::memcpy(result.rep_.small.data, rep_.small.data + pos, + result.rep_.small.size); + result.rep_.small.arena = GetSmallArena(); + return result; + } + case ByteStringKind::kMedium: { + ByteString result(*this); + result.rep_.medium.data += pos; + result.rep_.medium.size = npos - pos; + return result; + } + case ByteStringKind::kLarge: + return ByteString(GetLarge().Subcord(pos, npos - pos)); + } +} + +void ByteString::RemovePrefix(size_t n) { + ABSL_DCHECK_LE(n, size()); + if (n == 0) { + return; + } + switch (GetKind()) { + case ByteStringKind::kSmall: + std::memmove(rep_.small.data, rep_.small.data + n, rep_.small.size - n); + rep_.small.size -= n; + break; + case ByteStringKind::kMedium: + rep_.medium.data += n; + rep_.medium.size -= n; + if (rep_.medium.size <= kSmallByteStringCapacity) { + const auto* refcount = GetMediumReferenceCount(); + SetSmall(GetMediumArena(), GetMedium()); + StrongUnref(refcount); + } + break; + case ByteStringKind::kLarge: { + auto& large = GetLarge(); + const auto large_size = large.size(); + const auto new_large_pos = n; + const auto new_large_size = large_size - n; + large = large.Subcord(new_large_pos, new_large_size); + if (new_large_size <= kSmallByteStringCapacity) { + auto large_copy = std::move(large); + DestroyLarge(); + SetSmall(nullptr, large_copy); + } + } break; + } +} + +void ByteString::RemoveSuffix(size_t n) { + ABSL_DCHECK_LE(n, size()); + if (n == 0) { + return; + } + switch (GetKind()) { + case ByteStringKind::kSmall: + rep_.small.size -= n; + break; + case ByteStringKind::kMedium: + rep_.medium.size -= n; + if (rep_.medium.size <= kSmallByteStringCapacity) { + const auto* refcount = GetMediumReferenceCount(); + SetSmall(GetMediumArena(), GetMedium()); + StrongUnref(refcount); + } + break; + case ByteStringKind::kLarge: { + auto& large = GetLarge(); + const auto large_size = large.size(); + const auto new_large_pos = 0; + const auto new_large_size = large_size - n; + large = large.Subcord(new_large_pos, new_large_size); + if (new_large_size <= kSmallByteStringCapacity) { + auto large_copy = std::move(large); + DestroyLarge(); + SetSmall(nullptr, large_copy); + } + } break; + } +} + +void ByteString::CopyToArray(char* absl_nonnull out) const { + ABSL_DCHECK(out != nullptr); + + switch (GetKind()) { + case ByteStringKind::kSmall: { + absl::string_view small = GetSmall(); + std::memcpy(out, small.data(), small.size()); + } break; + case ByteStringKind::kMedium: { + absl::string_view medium = GetMedium(); + std::memcpy(out, medium.data(), medium.size()); + } break; + case ByteStringKind::kLarge: { + const absl::Cord& large = GetLarge(); + (CopyCordToArray)(large, out); + } break; + } +} + +std::string ByteString::ToString() const { + switch (GetKind()) { + case ByteStringKind::kSmall: + return std::string(GetSmall()); + case ByteStringKind::kMedium: + return std::string(GetMedium()); + case ByteStringKind::kLarge: + return static_cast(GetLarge()); + } +} + +void ByteString::CopyToString(std::string* absl_nonnull out) const { + ABSL_DCHECK(out != nullptr); + + switch (GetKind()) { + case ByteStringKind::kSmall: + out->assign(GetSmall()); + break; + case ByteStringKind::kMedium: + out->assign(GetMedium()); + break; + case ByteStringKind::kLarge: + absl::CopyCordToString(GetLarge(), out); + break; + } +} + +void ByteString::AppendToString(std::string* absl_nonnull out) const { + ABSL_DCHECK(out != nullptr); + + switch (GetKind()) { + case ByteStringKind::kSmall: + out->append(GetSmall()); + break; + case ByteStringKind::kMedium: + out->append(GetMedium()); + break; + case ByteStringKind::kLarge: + absl::AppendCordToString(GetLarge(), out); + break; + } +} + +namespace { + +struct ReferenceCountReleaser { + const ReferenceCount* absl_nonnull refcount; + + void operator()() const { StrongUnref(*refcount); } +}; + +} // namespace + +absl::Cord ByteString::ToCord() const& { + switch (GetKind()) { + case ByteStringKind::kSmall: + return absl::Cord(GetSmall()); + case ByteStringKind::kMedium: { + const auto* refcount = GetMediumReferenceCount(); + if (refcount != nullptr) { + StrongRef(*refcount); + return absl::MakeCordFromExternal(GetMedium(), + ReferenceCountReleaser{refcount}); + } + return absl::Cord(GetMedium()); + } + case ByteStringKind::kLarge: + return GetLarge(); + } +} + +absl::Cord ByteString::ToCord() && { + switch (GetKind()) { + case ByteStringKind::kSmall: + return absl::Cord(GetSmall()); + case ByteStringKind::kMedium: { + const auto* refcount = GetMediumReferenceCount(); + if (refcount != nullptr) { + auto medium = GetMedium(); + SetSmallEmpty(nullptr); + return absl::MakeCordFromExternal(medium, + ReferenceCountReleaser{refcount}); + } + return absl::Cord(GetMedium()); + } + case ByteStringKind::kLarge: + return GetLarge(); + } +} + +void ByteString::CopyToCord(absl::Cord* absl_nonnull out) const { + ABSL_DCHECK(out != nullptr); + + switch (GetKind()) { + case ByteStringKind::kSmall: + *out = absl::Cord(GetSmall()); + break; + case ByteStringKind::kMedium: { + const auto* refcount = GetMediumReferenceCount(); + if (refcount != nullptr) { + StrongRef(*refcount); + *out = absl::MakeCordFromExternal(GetMedium(), + ReferenceCountReleaser{refcount}); + } else { + *out = absl::Cord(GetMedium()); + } + } break; + case ByteStringKind::kLarge: + *out = GetLarge(); + break; + } +} + +void ByteString::AppendToCord(absl::Cord* absl_nonnull out) const { + ABSL_DCHECK(out != nullptr); + + switch (GetKind()) { + case ByteStringKind::kSmall: + out->Append(GetSmall()); + break; + case ByteStringKind::kMedium: { + const auto* refcount = GetMediumReferenceCount(); + if (refcount != nullptr) { + StrongRef(*refcount); + out->Append(absl::MakeCordFromExternal( + GetMedium(), ReferenceCountReleaser{refcount})); + } else { + out->Append(GetMedium()); + } + } break; + case ByteStringKind::kLarge: + out->Append(GetLarge()); + break; + } +} + +absl::string_view ByteString::ToStringView( + std::string* absl_nonnull scratch) const { + ABSL_DCHECK(scratch != nullptr); + + switch (GetKind()) { + case ByteStringKind::kSmall: + return GetSmall(); + case ByteStringKind::kMedium: + return GetMedium(); + case ByteStringKind::kLarge: + if (auto flat = GetLarge().TryFlat(); flat) { + return *flat; + } + absl::CopyCordToString(GetLarge(), scratch); + return absl::string_view(*scratch); + } +} + +absl::string_view ByteString::AsStringView() const { + const ByteStringKind kind = GetKind(); + ABSL_CHECK(kind == ByteStringKind::kSmall || // Crash OK + kind == ByteStringKind::kMedium); + switch (kind) { + case ByteStringKind::kSmall: + return GetSmall(); + case ByteStringKind::kMedium: + return GetMedium(); + case ByteStringKind::kLarge: + ABSL_UNREACHABLE(); + } +} + +google::protobuf::Arena* absl_nullable ByteString::GetMediumArena( + const MediumByteStringRep& rep) { + if ((rep.owner & kMetadataOwnerBits) == kMetadataOwnerArenaBit) { + return reinterpret_cast(rep.owner & + kMetadataOwnerPointerMask); + } + return nullptr; +} + +const ReferenceCount* absl_nullable ByteString::GetMediumReferenceCount( + const MediumByteStringRep& rep) { + if ((rep.owner & kMetadataOwnerBits) == kMetadataOwnerReferenceCountBit) { + return reinterpret_cast(rep.owner & + kMetadataOwnerPointerMask); + } + return nullptr; +} + +void ByteString::Construct(const ByteString& other, + absl::optional> allocator) { + switch (other.GetKind()) { + case ByteStringKind::kSmall: + rep_.small = other.rep_.small; + if (allocator.has_value()) { + rep_.small.arena = allocator->arena(); + } + break; + case ByteStringKind::kMedium: + if (allocator.has_value() && + allocator->arena() != other.GetMediumArena()) { + SetMedium(allocator->arena(), other.GetMedium()); + } else { + rep_.medium = other.rep_.medium; + StrongRef(GetMediumReferenceCount()); + } + break; + case ByteStringKind::kLarge: + if (allocator.has_value() && allocator->arena() != nullptr) { + SetMedium(allocator->arena(), other.GetLarge()); + } else { + SetLarge(other.GetLarge()); + } + break; + } +} + +void ByteString::Construct(ByteString& other, + absl::optional> allocator) { + switch (other.GetKind()) { + case ByteStringKind::kSmall: + rep_.small = other.rep_.small; + if (allocator.has_value()) { + rep_.small.arena = allocator->arena(); + } + break; + case ByteStringKind::kMedium: + if (allocator.has_value() && + allocator->arena() != other.GetMediumArena()) { + SetMedium(allocator->arena(), other.GetMedium()); + } else { + rep_.medium = other.rep_.medium; + other.rep_.medium.owner = 0; + } + break; + case ByteStringKind::kLarge: + if (allocator.has_value() && allocator->arena() != nullptr) { + SetMedium(allocator->arena(), other.GetLarge()); + } else { + SetLarge(std::move(other.GetLarge())); + } + break; + } +} + +void ByteString::CopyFrom(const ByteString& other) { + ABSL_DCHECK_NE(&other, this); + + switch (other.GetKind()) { + case ByteStringKind::kSmall: + switch (GetKind()) { + case ByteStringKind::kSmall: + break; + case ByteStringKind::kMedium: + DestroyMedium(); + break; + case ByteStringKind::kLarge: + DestroyLarge(); + break; + } + rep_.small = other.rep_.small; + break; + case ByteStringKind::kMedium: + switch (GetKind()) { + case ByteStringKind::kSmall: + rep_.medium = other.rep_.medium; + StrongRef(GetMediumReferenceCount()); + break; + case ByteStringKind::kMedium: + StrongRef(other.GetMediumReferenceCount()); + DestroyMedium(); + rep_.medium = other.rep_.medium; + break; + case ByteStringKind::kLarge: + DestroyLarge(); + rep_.medium = other.rep_.medium; + StrongRef(GetMediumReferenceCount()); + break; + } + break; + case ByteStringKind::kLarge: + switch (GetKind()) { + case ByteStringKind::kSmall: + SetLarge(other.GetLarge()); + break; + case ByteStringKind::kMedium: + DestroyMedium(); + SetLarge(other.GetLarge()); + break; + case ByteStringKind::kLarge: + GetLarge() = other.GetLarge(); + break; + } + break; + } +} + +void ByteString::MoveFrom(ByteString& other) { + ABSL_DCHECK_NE(&other, this); + + switch (other.GetKind()) { + case ByteStringKind::kSmall: + switch (GetKind()) { + case ByteStringKind::kSmall: + break; + case ByteStringKind::kMedium: + DestroyMedium(); + break; + case ByteStringKind::kLarge: + DestroyLarge(); + break; + } + rep_.small = other.rep_.small; + break; + case ByteStringKind::kMedium: + switch (GetKind()) { + case ByteStringKind::kSmall: + rep_.medium = other.rep_.medium; + break; + case ByteStringKind::kMedium: + DestroyMedium(); + rep_.medium = other.rep_.medium; + break; + case ByteStringKind::kLarge: + DestroyLarge(); + rep_.medium = other.rep_.medium; + break; + } + other.rep_.medium.owner = 0; + break; + case ByteStringKind::kLarge: + switch (GetKind()) { + case ByteStringKind::kSmall: + SetLarge(std::move(other.GetLarge())); + break; + case ByteStringKind::kMedium: + DestroyMedium(); + SetLarge(std::move(other.GetLarge())); + break; + case ByteStringKind::kLarge: + GetLarge() = std::move(other.GetLarge()); + break; + } + break; + } +} + +ByteString ByteString::Clone(google::protobuf::Arena* absl_nonnull arena) const { + ABSL_DCHECK(arena != nullptr); + + switch (GetKind()) { + case ByteStringKind::kSmall: + return ByteString(arena, GetSmall()); + case ByteStringKind::kMedium: { + google::protobuf::Arena* absl_nullable other_arena = GetMediumArena(); + if (arena != nullptr) { + if (arena == other_arena) { + return *this; + } + return ByteString(arena, GetMedium()); + } + if (other_arena != nullptr) { + return ByteString(arena, GetMedium()); + } + return *this; + } + case ByteStringKind::kLarge: + return ByteString(arena, GetLarge()); + } +} + +void ByteString::HashValue(absl::HashState state) const { + switch (GetKind()) { + case ByteStringKind::kSmall: + absl::HashState::combine(std::move(state), GetSmall()); + break; + case ByteStringKind::kMedium: + absl::HashState::combine(std::move(state), GetMedium()); + break; + case ByteStringKind::kLarge: + absl::HashState::combine(std::move(state), GetLarge()); + break; + } +} + +void ByteString::Swap(ByteString& other) { + ABSL_DCHECK_NE(&other, this); + using std::swap; + + switch (other.GetKind()) { + case ByteStringKind::kSmall: + switch (GetKind()) { + case ByteStringKind::kSmall: + // small <=> small + swap(rep_.small, other.rep_.small); + break; + case ByteStringKind::kMedium: + // medium <=> small + swap(rep_, other.rep_); + break; + case ByteStringKind::kLarge: { + absl::Cord cord = std::move(GetLarge()); + DestroyLarge(); + rep_ = other.rep_; + other.SetLarge(std::move(cord)); + } break; + } + break; + case ByteStringKind::kMedium: + switch (GetKind()) { + case ByteStringKind::kSmall: + swap(rep_, other.rep_); + break; + case ByteStringKind::kMedium: + swap(rep_.medium, other.rep_.medium); + break; + case ByteStringKind::kLarge: { + absl::Cord cord = std::move(GetLarge()); + DestroyLarge(); + rep_ = other.rep_; + other.SetLarge(std::move(cord)); + } break; + } + break; + case ByteStringKind::kLarge: + switch (GetKind()) { + case ByteStringKind::kSmall: { + absl::Cord cord = std::move(other.GetLarge()); + other.DestroyLarge(); + other.rep_.small = rep_.small; + SetLarge(std::move(cord)); + } break; + case ByteStringKind::kMedium: { + absl::Cord cord = std::move(other.GetLarge()); + other.DestroyLarge(); + other.rep_.medium = rep_.medium; + SetLarge(std::move(cord)); + } break; + case ByteStringKind::kLarge: + swap(GetLarge(), other.GetLarge()); + break; + } + break; + } +} + +void ByteString::Destroy() { + switch (GetKind()) { + case ByteStringKind::kSmall: + break; + case ByteStringKind::kMedium: + DestroyMedium(); + break; + case ByteStringKind::kLarge: + DestroyLarge(); + break; + } +} + +void ByteString::SetSmall(google::protobuf::Arena* absl_nullable arena, + absl::string_view string) { + ABSL_DCHECK_LE(string.size(), kSmallByteStringCapacity); + rep_.header.kind = ByteStringKind::kSmall; + rep_.small.size = string.size(); + rep_.small.arena = arena; + std::memcpy(rep_.small.data, string.data(), rep_.small.size); +} + +void ByteString::SetSmall(google::protobuf::Arena* absl_nullable arena, + const absl::Cord& cord) { + ABSL_DCHECK_LE(cord.size(), kSmallByteStringCapacity); + rep_.header.kind = ByteStringKind::kSmall; + rep_.small.size = cord.size(); + rep_.small.arena = arena; + (CopyCordToArray)(cord, rep_.small.data); +} + +void ByteString::SetMedium(google::protobuf::Arena* absl_nullable arena, + absl::string_view string) { + ABSL_DCHECK_GT(string.size(), kSmallByteStringCapacity); + rep_.header.kind = ByteStringKind::kMedium; + rep_.medium.size = string.size(); + if (arena != nullptr) { + char* data = static_cast( + arena->AllocateAligned(rep_.medium.size, alignof(char))); + std::memcpy(data, string.data(), rep_.medium.size); + rep_.medium.data = data; + rep_.medium.owner = + reinterpret_cast(arena) | kMetadataOwnerArenaBit; + } else { + auto pair = MakeReferenceCountedString(string); + rep_.medium.data = pair.second.data(); + rep_.medium.owner = reinterpret_cast(pair.first) | + kMetadataOwnerReferenceCountBit; + } +} + +void ByteString::SetExternalMedium(absl::string_view string) { + ABSL_DCHECK_GT(string.size(), kSmallByteStringCapacity); + rep_.header.kind = ByteStringKind::kMedium; + rep_.medium.size = string.size(); + rep_.medium.data = string.data(); + rep_.medium.owner = 0; +} + +void ByteString::SetMedium(google::protobuf::Arena* absl_nullable arena, + std::string&& string) { + ABSL_DCHECK_GT(string.size(), kSmallByteStringCapacity); + rep_.header.kind = ByteStringKind::kMedium; + rep_.medium.size = string.size(); + if (arena != nullptr) { + auto* data = google::protobuf::Arena::Create(arena, std::move(string)); + rep_.medium.data = data->data(); + rep_.medium.owner = + reinterpret_cast(arena) | kMetadataOwnerArenaBit; + } else { + auto pair = MakeReferenceCountedString(std::move(string)); + rep_.medium.data = pair.second.data(); + rep_.medium.owner = reinterpret_cast(pair.first) | + kMetadataOwnerReferenceCountBit; + } +} + +void ByteString::SetMedium(google::protobuf::Arena* absl_nonnull arena, + const absl::Cord& cord) { + ABSL_DCHECK_GT(cord.size(), kSmallByteStringCapacity); + rep_.header.kind = ByteStringKind::kMedium; + rep_.medium.size = cord.size(); + char* data = static_cast( + arena->AllocateAligned(rep_.medium.size, alignof(char))); + (CopyCordToArray)(cord, data); + rep_.medium.data = data; + rep_.medium.owner = + reinterpret_cast(arena) | kMetadataOwnerArenaBit; +} + +void ByteString::SetMedium(absl::string_view string, uintptr_t owner) { + ABSL_DCHECK_GT(string.size(), kSmallByteStringCapacity); + ABSL_DCHECK_NE(owner, 0); + rep_.header.kind = ByteStringKind::kMedium; + rep_.medium.size = string.size(); + rep_.medium.data = string.data(); + rep_.medium.owner = owner; +} + +void ByteString::SetLarge(const absl::Cord& cord) { + ABSL_DCHECK_GT(cord.size(), kSmallByteStringCapacity); + rep_.header.kind = ByteStringKind::kLarge; + ::new (static_cast(&rep_.large.data[0])) absl::Cord(cord); +} + +void ByteString::SetLarge(absl::Cord&& cord) { + ABSL_DCHECK_GT(cord.size(), kSmallByteStringCapacity); + rep_.header.kind = ByteStringKind::kLarge; + ::new (static_cast(&rep_.large.data[0])) absl::Cord(std::move(cord)); +} + +absl::string_view LegacyByteString(const ByteString& string, bool stable, + google::protobuf::Arena* absl_nonnull arena) { + ABSL_DCHECK(arena != nullptr); + if (string.empty()) { + return absl::string_view(); + } + const ByteStringKind kind = string.GetKind(); + if (kind == ByteStringKind::kMedium && string.GetMediumArena() == arena) { + google::protobuf::Arena* absl_nullable other_arena = string.GetMediumArena(); + if (other_arena == arena || other_arena == nullptr) { + // Legacy values do not preserve arena. For speed, we assume the arena is + // compatible. + return string.GetMedium(); + } + } + if (stable && kind == ByteStringKind::kSmall) { + return string.GetSmall(); + } + std::string* absl_nonnull result = google::protobuf::Arena::Create(arena); + switch (kind) { + case ByteStringKind::kSmall: + result->assign(string.GetSmall()); + break; + case ByteStringKind::kMedium: + result->assign(string.GetMedium()); + break; + case ByteStringKind::kLarge: + absl::CopyCordToString(string.GetLarge(), result); + break; + } + return absl::string_view(*result); +} + +} // namespace cel::common_internal diff --git a/common/internal/byte_string.h b/common/internal/byte_string.h new file mode 100644 index 000000000..c576e5634 --- /dev/null +++ b/common/internal/byte_string.h @@ -0,0 +1,688 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_INTERNAL_BYTE_STRING_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_INTERNAL_BYTE_STRING_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/base/optimization.h" +#include "absl/functional/overload.h" +#include "absl/hash/hash.h" +#include "absl/log/absl_check.h" +#include "absl/strings/cord.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "common/allocator.h" +#include "common/arena.h" +#include "common/internal/reference_count.h" +#include "common/memory.h" +#include "google/protobuf/arena.h" + +namespace cel { + +class BytesValueInputStream; +class BytesValueOutputStream; +class StringValue; + +namespace common_internal { + +// absl::Cord is trivially relocatable IFF we are not using ASan or MSan. When +// using ASan or MSan absl::Cord will poison/unpoison its inline storage. +#if defined(ABSL_HAVE_ADDRESS_SANITIZER) || defined(ABSL_HAVE_MEMORY_SANITIZER) +#define CEL_COMMON_INTERNAL_BYTE_STRING_TRIVIAL_ABI +#else +#define CEL_COMMON_INTERNAL_BYTE_STRING_TRIVIAL_ABI ABSL_ATTRIBUTE_TRIVIAL_ABI +#endif + +class CEL_COMMON_INTERNAL_BYTE_STRING_TRIVIAL_ABI [[nodiscard]] ByteString; + +struct ByteStringTestFriend; + +enum class ByteStringKind : unsigned int { + kSmall = 0, + kMedium, + kLarge, +}; + +inline std::ostream& operator<<(std::ostream& out, ByteStringKind kind) { + switch (kind) { + case ByteStringKind::kSmall: + return out << "SMALL"; + case ByteStringKind::kMedium: + return out << "MEDIUM"; + case ByteStringKind::kLarge: + return out << "LARGE"; + } +} + +// Representation of small strings in ByteString, which are stored in place. +struct CEL_COMMON_INTERNAL_BYTE_STRING_TRIVIAL_ABI SmallByteStringRep final { +#ifdef _MSC_VER +#pragma pack(push, 1) +#endif + struct ABSL_ATTRIBUTE_PACKED CEL_COMMON_INTERNAL_BYTE_STRING_TRIVIAL_ABI { + std::uint8_t kind : 2; + std::uint8_t size : 6; + }; +#ifdef _MSC_VER +#pragma pack(pop) +#endif + char data[23 - sizeof(google::protobuf::Arena*)]; + google::protobuf::Arena* absl_nullable arena; +}; + +inline constexpr size_t kSmallByteStringCapacity = + sizeof(SmallByteStringRep::data); + +inline constexpr size_t kMediumByteStringSizeBits = sizeof(size_t) * 8 - 2; +inline constexpr size_t kMediumByteStringMaxSize = + (size_t{1} << kMediumByteStringSizeBits) - 1; + +inline constexpr size_t kByteStringViewSizeBits = sizeof(size_t) * 8 - 1; +inline constexpr size_t kByteStringViewMaxSize = + (size_t{1} << kByteStringViewSizeBits) - 1; + +// Representation of medium strings in ByteString. These are either owned by an +// arena or managed by a reference count. This is encoded in `owner` following +// the same semantics as `cel::Owner`. +struct CEL_COMMON_INTERNAL_BYTE_STRING_TRIVIAL_ABI MediumByteStringRep final { +#ifdef _MSC_VER +#pragma pack(push, 1) +#endif + struct ABSL_ATTRIBUTE_PACKED CEL_COMMON_INTERNAL_BYTE_STRING_TRIVIAL_ABI { + size_t kind : 2; + size_t size : kMediumByteStringSizeBits; + }; +#ifdef _MSC_VER +#pragma pack(pop) +#endif + const char* data; + uintptr_t owner; +}; + +// Representation of large strings in ByteString. These are stored as +// `absl::Cord` and never owned by an arena. +struct CEL_COMMON_INTERNAL_BYTE_STRING_TRIVIAL_ABI LargeByteStringRep final { +#ifdef _MSC_VER +#pragma pack(push, 1) +#endif + struct ABSL_ATTRIBUTE_PACKED CEL_COMMON_INTERNAL_BYTE_STRING_TRIVIAL_ABI { + size_t kind : 2; + size_t padding : kMediumByteStringSizeBits; + }; +#ifdef _MSC_VER +#pragma pack(pop) +#endif + alignas(absl::Cord) std::byte data[sizeof(absl::Cord)]; +}; + +// Representation of ByteString. +union CEL_COMMON_INTERNAL_BYTE_STRING_TRIVIAL_ABI ByteStringRep final { +#ifdef _MSC_VER +#pragma pack(push, 1) +#endif + struct ABSL_ATTRIBUTE_PACKED CEL_COMMON_INTERNAL_BYTE_STRING_TRIVIAL_ABI { + ByteStringKind kind : 2; + } header; +#ifdef _MSC_VER +#pragma pack(pop) +#endif + SmallByteStringRep small; + MediumByteStringRep medium; + LargeByteStringRep large; +}; + +// Returns a `absl::string_view` from `ByteString`, using `arena` to make memory +// allocations if necessary. `stable` indicates whether `cel::Value` is in a +// location where it will not be moved, so that inline string/bytes storage can +// be referenced. +absl::string_view LegacyByteString(const ByteString& string, bool stable, + google::protobuf::Arena* absl_nonnull arena); + +// `ByteString` is a vocabulary type capable of representing copy-on-write +// strings efficiently for arenas and reference counting. The contents of the +// byte string are owned by an arena or managed by a reference count. All byte +// strings have an associated allocator specified at construction, once the byte +// string is constructed the allocator will not and cannot change. Copying and +// moving between different allocators is supported and dealt with +// transparently by copying. +class CEL_COMMON_INTERNAL_BYTE_STRING_TRIVIAL_ABI [[nodiscard]] +ByteString final { + public: + static ByteString Concat(const ByteString& lhs, const ByteString& rhs, + google::protobuf::Arena* absl_nonnull arena); + + ByteString() : ByteString(NewDeleteAllocator()) {} + + explicit ByteString(const char* absl_nullable string) + : ByteString(NewDeleteAllocator(), string) {} + + explicit ByteString(absl::string_view string) + : ByteString(NewDeleteAllocator(), string) {} + + explicit ByteString(const std::string& string) + : ByteString(NewDeleteAllocator(), string) {} + + explicit ByteString(std::string&& string) + : ByteString(NewDeleteAllocator(), std::move(string)) {} + + explicit ByteString(const absl::Cord& cord) + : ByteString(NewDeleteAllocator(), cord) {} + + ByteString(const ByteString& other) noexcept { + Construct(other, /*allocator=*/absl::nullopt); + } + + ByteString(ByteString&& other) noexcept { + Construct(other, /*allocator=*/absl::nullopt); + } + + explicit ByteString(Allocator<> allocator) { + SetSmallEmpty(allocator.arena()); + } + + ByteString(Allocator<> allocator, const char* absl_nullable string) + : ByteString(allocator, absl::NullSafeStringView(string)) {} + + ByteString(Allocator<> allocator, absl::string_view string); + + ByteString(Allocator<> allocator, const std::string& string); + + ByteString(Allocator<> allocator, std::string&& string); + + ByteString(Allocator<> allocator, const absl::Cord& cord); + + ByteString(Allocator<> allocator, const ByteString& other) { + Construct(other, allocator); + } + + ByteString(Allocator<> allocator, ByteString&& other) { + Construct(other, allocator); + } + + ByteString(Borrower borrower, + const char* absl_nullable string ABSL_ATTRIBUTE_LIFETIME_BOUND) + : ByteString(borrower, absl::NullSafeStringView(string)) {} + + ByteString(Borrower borrower, + absl::string_view string ABSL_ATTRIBUTE_LIFETIME_BOUND) + : ByteString(Borrowed(borrower, string)) {} + + ByteString(Borrower borrower, + const absl::Cord& cord ABSL_ATTRIBUTE_LIFETIME_BOUND) + : ByteString(Borrowed(borrower, cord)) {} + + // Creates a medium byte string that is backed by an external string. Should + // only be called from explicit 'Unsafe' factories. + static ByteString FromExternal(absl::string_view string); + + ~ByteString() { Destroy(); } + + ByteString& operator=(const ByteString& other) noexcept { + if (ABSL_PREDICT_TRUE(this != &other)) { + CopyFrom(other); + } + return *this; + } + + ByteString& operator=(ByteString&& other) noexcept { + if (ABSL_PREDICT_TRUE(this != &other)) { + MoveFrom(other); + } + return *this; + } + + bool empty() const; + + size_t size() const; + + size_t max_size() const { return kByteStringViewMaxSize; } + + absl::string_view Flatten() ABSL_ATTRIBUTE_LIFETIME_BOUND; + + absl::optional TryFlat() const + ABSL_ATTRIBUTE_LIFETIME_BOUND; + + bool Equals(absl::string_view rhs) const; + bool Equals(const absl::Cord& rhs) const; + bool Equals(const ByteString& rhs) const; + + int Compare(absl::string_view rhs) const; + int Compare(const absl::Cord& rhs) const; + int Compare(const ByteString& rhs) const; + + bool StartsWith(absl::string_view rhs) const; + bool StartsWith(const absl::Cord& rhs) const; + bool StartsWith(const ByteString& rhs) const; + + bool EndsWith(absl::string_view rhs) const; + bool EndsWith(const absl::Cord& rhs) const; + bool EndsWith(const ByteString& rhs) const; + + // Finds the first occurrence of `needle` in this object, starting at byte + // position `pos`. Returns `absl::nullopt` if `needle` is not found. + // Note: Positions are byte-based, not code point based as in + // `cel::StringValue`. + absl::optional Find(absl::string_view needle, size_t pos = 0) const; + absl::optional Find(const absl::Cord& needle, size_t pos = 0) const; + absl::optional Find(const ByteString& needle, size_t pos = 0) const; + + // Returns a new `ByteString` that is a substring of this object, starting at + // byte position `pos` and with a length of `npos` bytes. + // Note: Positions are byte-based, not code point based as in + // `cel::StringValue`. + ByteString Substring(size_t pos, size_t npos) const; + ByteString Substring(size_t pos) const { + ABSL_DCHECK_LE(pos, size()); + return Substring(pos, size()); + } + + void RemovePrefix(size_t n); + + void RemoveSuffix(size_t n); + + std::string ToString() const; + + void CopyToString(std::string* absl_nonnull out) const; + + void AppendToString(std::string* absl_nonnull out) const; + + absl::Cord ToCord() const&; + + absl::Cord ToCord() &&; + + void CopyToCord(absl::Cord* absl_nonnull out) const; + + void AppendToCord(absl::Cord* absl_nonnull out) const; + + absl::string_view ToStringView( + std::string* absl_nonnull scratch + ABSL_ATTRIBUTE_LIFETIME_BOUND) const ABSL_ATTRIBUTE_LIFETIME_BOUND; + + absl::string_view AsStringView() const ABSL_ATTRIBUTE_LIFETIME_BOUND; + + google::protobuf::Arena* absl_nullable GetArena() const; + + ByteString Clone(google::protobuf::Arena* absl_nonnull arena) const; + + void HashValue(absl::HashState state) const; + + template + decltype(auto) Visit(Visitor&& visitor) const { + switch (GetKind()) { + case ByteStringKind::kSmall: + return std::forward(visitor)(GetSmall()); + case ByteStringKind::kMedium: + return std::forward(visitor)(GetMedium()); + case ByteStringKind::kLarge: + return std::forward(visitor)(GetLarge()); + } + } + + friend void swap(ByteString& lhs, ByteString& rhs) { + if (&lhs != &rhs) { + lhs.Swap(rhs); + } + } + + template + friend H AbslHashValue(H state, const ByteString& byte_string) { + byte_string.HashValue(absl::HashState::Create(&state)); + return state; + } + + private: + friend class ByteStringView; + friend struct ByteStringTestFriend; + friend class cel::BytesValueInputStream; + friend class cel::BytesValueOutputStream; + friend class cel::StringValue; + friend absl::string_view LegacyByteString(const ByteString& string, + bool stable, + google::protobuf::Arena* absl_nonnull arena); + friend struct cel::ArenaTraits; + + struct ExternalStringTag {}; + + static ByteString Borrowed(Borrower borrower, + absl::string_view string + ABSL_ATTRIBUTE_LIFETIME_BOUND); + + static ByteString Borrowed( + Borrower borrower, const absl::Cord& cord ABSL_ATTRIBUTE_LIFETIME_BOUND); + + ByteString(const ReferenceCount* absl_nonnull refcount, + absl::string_view string); + + ByteString(ExternalStringTag, absl::string_view string); + + constexpr ByteStringKind GetKind() const { return rep_.header.kind; } + + absl::string_view GetSmall() const { + ABSL_DCHECK_EQ(GetKind(), ByteStringKind::kSmall); + return GetSmall(rep_.small); + } + + static absl::string_view GetSmall(const SmallByteStringRep& rep) { + return absl::string_view(rep.data, rep.size); + } + + absl::string_view GetMedium() const { + ABSL_DCHECK_EQ(GetKind(), ByteStringKind::kMedium); + return GetMedium(rep_.medium); + } + + static absl::string_view GetMedium(const MediumByteStringRep& rep) { + return absl::string_view(rep.data, rep.size); + } + + google::protobuf::Arena* absl_nullable GetSmallArena() const { + ABSL_DCHECK_EQ(GetKind(), ByteStringKind::kSmall); + return GetSmallArena(rep_.small); + } + + static google::protobuf::Arena* absl_nullable GetSmallArena( + const SmallByteStringRep& rep) { + return rep.arena; + } + + google::protobuf::Arena* absl_nullable GetMediumArena() const { + ABSL_DCHECK_EQ(GetKind(), ByteStringKind::kMedium); + return GetMediumArena(rep_.medium); + } + + static google::protobuf::Arena* absl_nullable GetMediumArena( + const MediumByteStringRep& rep); + + const ReferenceCount* absl_nullable GetMediumReferenceCount() const { + ABSL_DCHECK_EQ(GetKind(), ByteStringKind::kMedium); + return GetMediumReferenceCount(rep_.medium); + } + + static const ReferenceCount* absl_nullable GetMediumReferenceCount( + const MediumByteStringRep& rep); + + uintptr_t GetMediumOwner() const { + ABSL_DCHECK_EQ(GetKind(), ByteStringKind::kMedium); + return rep_.medium.owner; + } + + absl::Cord& GetLarge() ABSL_ATTRIBUTE_LIFETIME_BOUND { + ABSL_DCHECK_EQ(GetKind(), ByteStringKind::kLarge); + return GetLarge(rep_.large); + } + + static absl::Cord& GetLarge( + LargeByteStringRep& rep ABSL_ATTRIBUTE_LIFETIME_BOUND) { + return *std::launder(reinterpret_cast(&rep.data[0])); + } + + const absl::Cord& GetLarge() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + ABSL_DCHECK_EQ(GetKind(), ByteStringKind::kLarge); + return GetLarge(rep_.large); + } + + static const absl::Cord& GetLarge( + const LargeByteStringRep& rep ABSL_ATTRIBUTE_LIFETIME_BOUND) { + return *std::launder(reinterpret_cast(&rep.data[0])); + } + + void SetSmallEmpty(google::protobuf::Arena* absl_nullable arena) { + rep_.header.kind = ByteStringKind::kSmall; + rep_.small.size = 0; + rep_.small.arena = arena; + } + + void SetSmall(google::protobuf::Arena* absl_nullable arena, absl::string_view string); + + void SetSmall(google::protobuf::Arena* absl_nullable arena, const absl::Cord& cord); + + void SetMedium(google::protobuf::Arena* absl_nullable arena, absl::string_view string); + + // This is used to create a medium byte string that is backed by an external + // string. Should only be called from explicit 'Unsafe' factories. + void SetExternalMedium(absl::string_view string); + + void SetMedium(google::protobuf::Arena* absl_nullable arena, std::string&& string); + + void SetMedium(google::protobuf::Arena* absl_nonnull arena, const absl::Cord& cord); + + void SetMedium(absl::string_view string, uintptr_t owner); + + void SetLarge(const absl::Cord& cord); + + void SetLarge(absl::Cord&& cord); + + void Swap(ByteString& other); + + void Construct(const ByteString& other, + absl::optional> allocator); + + void Construct(ByteString& other, absl::optional> allocator); + + void CopyFrom(const ByteString& other); + + void MoveFrom(ByteString& other); + + void Destroy(); + + void DestroyMedium() { + ABSL_DCHECK_EQ(GetKind(), ByteStringKind::kMedium); + DestroyMedium(rep_.medium); + } + + static void DestroyMedium(const MediumByteStringRep& rep) { + StrongUnref(GetMediumReferenceCount(rep)); + } + + void DestroyLarge() { + ABSL_DCHECK_EQ(GetKind(), ByteStringKind::kLarge); + DestroyLarge(rep_.large); + } + + static void DestroyLarge(LargeByteStringRep& rep) { GetLarge(rep).~Cord(); } + + void CopyToArray(char* absl_nonnull out) const; + + ByteStringRep rep_; +}; + +inline bool ByteString::Equals(const ByteString& rhs) const { + return rhs.Visit(absl::Overload( + [this](absl::string_view rhs) -> bool { return Equals(rhs); }, + [this](const absl::Cord& rhs) -> bool { return Equals(rhs); })); +} + +inline int ByteString::Compare(const ByteString& rhs) const { + return rhs.Visit(absl::Overload( + [this](absl::string_view rhs) -> int { return Compare(rhs); }, + [this](const absl::Cord& rhs) -> int { return Compare(rhs); })); +} + +inline bool ByteString::StartsWith(const ByteString& rhs) const { + return rhs.Visit(absl::Overload( + [this](absl::string_view rhs) -> bool { return StartsWith(rhs); }, + [this](const absl::Cord& rhs) -> bool { return StartsWith(rhs); })); +} + +inline bool ByteString::EndsWith(const ByteString& rhs) const { + return rhs.Visit(absl::Overload( + [this](absl::string_view rhs) -> bool { return EndsWith(rhs); }, + [this](const absl::Cord& rhs) -> bool { return EndsWith(rhs); })); +} + +inline absl::optional ByteString::Find(const ByteString& needle, + size_t pos) const { + return needle.Visit(absl::Overload( + [this, pos](absl::string_view rhs) -> absl::optional { + return Find(rhs, pos); + }, + [this, pos](const absl::Cord& rhs) -> absl::optional { + return Find(rhs, pos); + })); +} + +inline bool operator==(const ByteString& lhs, const ByteString& rhs) { + return lhs.Equals(rhs); +} + +inline bool operator==(const ByteString& lhs, absl::string_view rhs) { + return lhs.Equals(rhs); +} + +inline bool operator==(absl::string_view lhs, const ByteString& rhs) { + return rhs.Equals(lhs); +} + +inline bool operator==(const ByteString& lhs, const absl::Cord& rhs) { + return lhs.Equals(rhs); +} + +inline bool operator==(const absl::Cord& lhs, const ByteString& rhs) { + return rhs.Equals(lhs); +} + +inline bool operator!=(const ByteString& lhs, const ByteString& rhs) { + return !operator==(lhs, rhs); +} + +inline bool operator!=(const ByteString& lhs, absl::string_view rhs) { + return !operator==(lhs, rhs); +} + +inline bool operator!=(absl::string_view lhs, const ByteString& rhs) { + return !operator==(lhs, rhs); +} + +inline bool operator!=(const ByteString& lhs, const absl::Cord& rhs) { + return !operator==(lhs, rhs); +} + +inline bool operator!=(const absl::Cord& lhs, const ByteString& rhs) { + return !operator==(lhs, rhs); +} + +inline bool operator<(const ByteString& lhs, const ByteString& rhs) { + return lhs.Compare(rhs) < 0; +} + +inline bool operator<(const ByteString& lhs, absl::string_view rhs) { + return lhs.Compare(rhs) < 0; +} + +inline bool operator<(absl::string_view lhs, const ByteString& rhs) { + return -rhs.Compare(lhs) < 0; +} + +inline bool operator<(const ByteString& lhs, const absl::Cord& rhs) { + return lhs.Compare(rhs) < 0; +} + +inline bool operator<(const absl::Cord& lhs, const ByteString& rhs) { + return -rhs.Compare(lhs) < 0; +} + +inline bool operator<=(const ByteString& lhs, const ByteString& rhs) { + return lhs.Compare(rhs) <= 0; +} + +inline bool operator<=(const ByteString& lhs, absl::string_view rhs) { + return lhs.Compare(rhs) <= 0; +} + +inline bool operator<=(absl::string_view lhs, const ByteString& rhs) { + return -rhs.Compare(lhs) <= 0; +} + +inline bool operator<=(const ByteString& lhs, const absl::Cord& rhs) { + return lhs.Compare(rhs) <= 0; +} + +inline bool operator<=(const absl::Cord& lhs, const ByteString& rhs) { + return -rhs.Compare(lhs) <= 0; +} + +inline bool operator>(const ByteString& lhs, const ByteString& rhs) { + return lhs.Compare(rhs) > 0; +} + +inline bool operator>(const ByteString& lhs, absl::string_view rhs) { + return lhs.Compare(rhs) > 0; +} + +inline bool operator>(absl::string_view lhs, const ByteString& rhs) { + return -rhs.Compare(lhs) > 0; +} + +inline bool operator>(const ByteString& lhs, const absl::Cord& rhs) { + return lhs.Compare(rhs) > 0; +} + +inline bool operator>(const absl::Cord& lhs, const ByteString& rhs) { + return -rhs.Compare(lhs) > 0; +} + +inline bool operator>=(const ByteString& lhs, const ByteString& rhs) { + return lhs.Compare(rhs) >= 0; +} + +inline bool operator>=(const ByteString& lhs, absl::string_view rhs) { + return lhs.Compare(rhs) >= 0; +} + +inline bool operator>=(absl::string_view lhs, const ByteString& rhs) { + return -rhs.Compare(lhs) >= 0; +} + +inline bool operator>=(const ByteString& lhs, const absl::Cord& rhs) { + return lhs.Compare(rhs) >= 0; +} + +inline bool operator>=(const absl::Cord& lhs, const ByteString& rhs) { + return -rhs.Compare(lhs) >= 0; +} + +#undef CEL_COMMON_INTERNAL_BYTE_STRING_TRIVIAL_ABI + +} // namespace common_internal + +template <> +struct ArenaTraits { + using constructible = std::true_type; + + static bool trivially_destructible( + const common_internal::ByteString& byte_string) { + switch (byte_string.GetKind()) { + case common_internal::ByteStringKind::kSmall: + return true; + case common_internal::ByteStringKind::kMedium: + return byte_string.GetMediumReferenceCount() == nullptr; + case common_internal::ByteStringKind::kLarge: + return false; + } + } +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_INTERNAL_BYTE_STRING_H_ diff --git a/common/internal/byte_string_test.cc b/common/internal/byte_string_test.cc new file mode 100644 index 000000000..553c9c13a --- /dev/null +++ b/common/internal/byte_string_test.cc @@ -0,0 +1,1204 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/internal/byte_string.h" + +#include +#include +#include +#include + +#include "absl/base/no_destructor.h" +#include "absl/hash/hash.h" +#include "absl/strings/cord.h" +#include "absl/strings/cord_test_helpers.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "common/allocator.h" +#include "common/internal/reference_count.h" +#include "common/memory.h" +#include "internal/testing.h" +#include "google/protobuf/arena.h" + +namespace cel::common_internal { + +struct ByteStringTestFriend { + static ByteStringKind GetKind(const ByteString& byte_string) { + return byte_string.GetKind(); + } +}; + +namespace { + +using ::testing::_; +using ::testing::Eq; +using ::testing::IsEmpty; +using ::testing::Not; +using ::testing::Optional; +using ::testing::SizeIs; +using ::testing::TestWithParam; + +TEST(ByteStringKind, Ostream) { + { + std::ostringstream out; + out << ByteStringKind::kSmall; + EXPECT_EQ(out.str(), "SMALL"); + } + { + std::ostringstream out; + out << ByteStringKind::kMedium; + EXPECT_EQ(out.str(), "MEDIUM"); + } + { + std::ostringstream out; + out << ByteStringKind::kLarge; + EXPECT_EQ(out.str(), "LARGE"); + } +} + +class ByteStringTest : public TestWithParam, + public ByteStringTestFriend { + public: + Allocator<> GetAllocator() { + switch (GetParam()) { + case AllocatorKind::kNewDelete: + return NewDeleteAllocator<>{}; + case AllocatorKind::kArena: + return ArenaAllocator<>(&arena_); + } + } + + private: + google::protobuf::Arena arena_; +}; + +absl::string_view GetSmallStringView() { + static constexpr absl::string_view small = "A small string!"; + return small.substr(0, std::min(kSmallByteStringCapacity, small.size())); +} + +std::string GetSmallString() { return std::string(GetSmallStringView()); } + +absl::Cord GetSmallCord() { + static const absl::NoDestructor small(GetSmallStringView()); + return *small; +} + +absl::string_view GetMediumStringView() { + static constexpr absl::string_view medium = + "A string that is too large for the small string optimization!"; + return medium; +} + +std::string GetMediumString() { return std::string(GetMediumStringView()); } + +const absl::Cord& GetMediumOrLargeCord() { + static const absl::NoDestructor medium_or_large( + GetMediumStringView()); + return *medium_or_large; +} + +const absl::Cord& GetMediumOrLargeFragmentedCord() { + static const absl::NoDestructor medium_or_large( + absl::MakeFragmentedCord( + {GetMediumStringView().substr(0, kSmallByteStringCapacity), + GetMediumStringView().substr(kSmallByteStringCapacity)})); + return *medium_or_large; +} + +TEST_P(ByteStringTest, Default) { + ByteString byte_string = ByteString(GetAllocator(), ""); + EXPECT_THAT(byte_string, SizeIs(0)); + EXPECT_THAT(byte_string, IsEmpty()); + EXPECT_EQ(GetKind(byte_string), ByteStringKind::kSmall); +} + +TEST_P(ByteStringTest, ConstructSmallCString) { + ByteString byte_string = ByteString(GetAllocator(), GetSmallString().c_str()); + EXPECT_THAT(byte_string, SizeIs(GetSmallStringView().size())); + EXPECT_THAT(byte_string, Not(IsEmpty())); + EXPECT_EQ(byte_string, GetSmallStringView()); + EXPECT_EQ(GetKind(byte_string), ByteStringKind::kSmall); + EXPECT_EQ(byte_string.GetArena(), GetAllocator().arena()); +} + +TEST_P(ByteStringTest, ConstructMediumCString) { + ByteString byte_string = + ByteString(GetAllocator(), GetMediumString().c_str()); + EXPECT_THAT(byte_string, SizeIs(GetMediumStringView().size())); + EXPECT_THAT(byte_string, Not(IsEmpty())); + EXPECT_EQ(byte_string, GetMediumStringView()); + EXPECT_EQ(GetKind(byte_string), ByteStringKind::kMedium); + EXPECT_EQ(byte_string.GetArena(), GetAllocator().arena()); +} + +TEST_P(ByteStringTest, ConstructSmallRValueString) { + ByteString byte_string = ByteString(GetAllocator(), GetSmallString()); + EXPECT_THAT(byte_string, SizeIs(GetSmallStringView().size())); + EXPECT_THAT(byte_string, Not(IsEmpty())); + EXPECT_EQ(byte_string, GetSmallStringView()); + EXPECT_EQ(GetKind(byte_string), ByteStringKind::kSmall); + EXPECT_EQ(byte_string.GetArena(), GetAllocator().arena()); +} + +TEST_P(ByteStringTest, ConstructSmallLValueString) { + ByteString byte_string = ByteString( + GetAllocator(), static_cast(GetSmallString())); + EXPECT_THAT(byte_string, SizeIs(GetSmallStringView().size())); + EXPECT_THAT(byte_string, Not(IsEmpty())); + EXPECT_EQ(byte_string, GetSmallStringView()); + EXPECT_EQ(GetKind(byte_string), ByteStringKind::kSmall); + EXPECT_EQ(byte_string.GetArena(), GetAllocator().arena()); +} + +TEST_P(ByteStringTest, ConstructMediumRValueString) { + ByteString byte_string = ByteString(GetAllocator(), GetMediumString()); + EXPECT_THAT(byte_string, SizeIs(GetMediumStringView().size())); + EXPECT_THAT(byte_string, Not(IsEmpty())); + EXPECT_EQ(byte_string, GetMediumStringView()); + EXPECT_EQ(GetKind(byte_string), ByteStringKind::kMedium); + EXPECT_EQ(byte_string.GetArena(), GetAllocator().arena()); +} + +TEST_P(ByteStringTest, ConstructMediumLValueString) { + ByteString byte_string = ByteString( + GetAllocator(), static_cast(GetMediumString())); + EXPECT_THAT(byte_string, SizeIs(GetMediumStringView().size())); + EXPECT_THAT(byte_string, Not(IsEmpty())); + EXPECT_EQ(byte_string, GetMediumStringView()); + EXPECT_EQ(GetKind(byte_string), ByteStringKind::kMedium); + EXPECT_EQ(byte_string.GetArena(), GetAllocator().arena()); +} + +TEST_P(ByteStringTest, ConstructSmallCord) { + ByteString byte_string = ByteString(GetAllocator(), GetSmallCord()); + EXPECT_THAT(byte_string, SizeIs(GetSmallStringView().size())); + EXPECT_THAT(byte_string, Not(IsEmpty())); + EXPECT_EQ(byte_string, GetSmallStringView()); + EXPECT_EQ(GetKind(byte_string), ByteStringKind::kSmall); + EXPECT_EQ(byte_string.GetArena(), GetAllocator().arena()); +} + +TEST_P(ByteStringTest, ConstructMediumOrLargeCord) { + ByteString byte_string = ByteString(GetAllocator(), GetMediumOrLargeCord()); + EXPECT_THAT(byte_string, SizeIs(GetMediumStringView().size())); + EXPECT_THAT(byte_string, Not(IsEmpty())); + EXPECT_EQ(byte_string, GetMediumStringView()); + if (GetAllocator().arena() == nullptr) { + EXPECT_EQ(GetKind(byte_string), ByteStringKind::kLarge); + } else { + EXPECT_EQ(GetKind(byte_string), ByteStringKind::kMedium); + } + EXPECT_EQ(byte_string.GetArena(), GetAllocator().arena()); +} + +TEST(ByteStringTest, BorrowedUnownedString) { +#ifdef NDEBUG + ByteString byte_string = ByteString(Owner::None(), GetMediumStringView()); + EXPECT_EQ(ByteStringTestFriend::GetKind(byte_string), + ByteStringKind::kMedium); + EXPECT_EQ(byte_string.GetArena(), nullptr); + EXPECT_EQ(byte_string, GetMediumStringView()); +#else + EXPECT_DEBUG_DEATH( + static_cast(ByteString(Owner::None(), GetMediumStringView())), + ::testing::_); +#endif +} + +TEST(ByteStringTest, BorrowedUnownedCord) { +#ifdef NDEBUG + ByteString byte_string = ByteString(Owner::None(), GetMediumOrLargeCord()); + EXPECT_EQ(ByteStringTestFriend::GetKind(byte_string), ByteStringKind::kLarge); + EXPECT_EQ(byte_string.GetArena(), nullptr); + EXPECT_EQ(byte_string, GetMediumOrLargeCord()); +#else + EXPECT_DEBUG_DEATH( + static_cast(ByteString(Owner::None(), GetMediumOrLargeCord())), + ::testing::_); +#endif +} + +TEST(ByteStringTest, BorrowedReferenceCountSmallString) { + auto* refcount = new ReferenceCounted(); + Owner owner = Owner::ReferenceCount(refcount); + StrongUnref(refcount); + ByteString byte_string = ByteString(owner, GetSmallStringView()); + EXPECT_EQ(ByteStringTestFriend::GetKind(byte_string), ByteStringKind::kSmall); + EXPECT_EQ(byte_string.GetArena(), nullptr); + EXPECT_EQ(byte_string, GetSmallStringView()); +} + +TEST(ByteStringTest, BorrowedReferenceCountMediumString) { + auto* refcount = new ReferenceCounted(); + Owner owner = Owner::ReferenceCount(refcount); + StrongUnref(refcount); + ByteString byte_string = ByteString(owner, GetMediumStringView()); + EXPECT_EQ(ByteStringTestFriend::GetKind(byte_string), + ByteStringKind::kMedium); + EXPECT_EQ(byte_string.GetArena(), nullptr); + EXPECT_EQ(byte_string, GetMediumStringView()); +} + +TEST(ByteStringTest, BorrowedArenaSmallString) { + google::protobuf::Arena arena; + ByteString byte_string = + ByteString(Owner::Arena(&arena), GetSmallStringView()); + EXPECT_EQ(ByteStringTestFriend::GetKind(byte_string), ByteStringKind::kSmall); + EXPECT_EQ(byte_string.GetArena(), &arena); + EXPECT_EQ(byte_string, GetSmallStringView()); +} + +TEST(ByteStringTest, BorrowedArenaMediumString) { + google::protobuf::Arena arena; + ByteString byte_string = + ByteString(Owner::Arena(&arena), GetMediumStringView()); + EXPECT_EQ(ByteStringTestFriend::GetKind(byte_string), + ByteStringKind::kMedium); + EXPECT_EQ(byte_string.GetArena(), &arena); + EXPECT_EQ(byte_string, GetMediumStringView()); +} + +TEST(ByteStringTest, BorrowedReferenceCountCord) { + auto* refcount = new ReferenceCounted(); + Owner owner = Owner::ReferenceCount(refcount); + StrongUnref(refcount); + ByteString byte_string = ByteString(owner, GetMediumOrLargeCord()); + EXPECT_EQ(ByteStringTestFriend::GetKind(byte_string), ByteStringKind::kLarge); + EXPECT_EQ(byte_string.GetArena(), nullptr); + EXPECT_EQ(byte_string, GetMediumOrLargeCord()); +} + +TEST(ByteStringTest, BorrowedArenaCord) { + google::protobuf::Arena arena; + Owner owner = Owner::Arena(&arena); + ByteString byte_string = ByteString(owner, GetMediumOrLargeCord()); + EXPECT_EQ(ByteStringTestFriend::GetKind(byte_string), + ByteStringKind::kMedium); + EXPECT_EQ(byte_string.GetArena(), &arena); + EXPECT_EQ(byte_string, GetMediumOrLargeCord()); +} + +TEST_P(ByteStringTest, CopyConstruct) { + ByteString small_byte_string = + ByteString(GetAllocator(), GetSmallStringView()); + ByteString medium_byte_string = + ByteString(GetAllocator(), GetMediumStringView()); + ByteString large_byte_string = + ByteString(GetAllocator(), GetMediumOrLargeCord()); + + EXPECT_EQ(ByteString(NewDeleteAllocator(), small_byte_string), + small_byte_string); + EXPECT_EQ(ByteString(NewDeleteAllocator(), medium_byte_string), + medium_byte_string); + EXPECT_EQ(ByteString(NewDeleteAllocator(), large_byte_string), + large_byte_string); + + google::protobuf::Arena arena; + EXPECT_EQ(ByteString(ArenaAllocator(&arena), small_byte_string), + small_byte_string); + EXPECT_EQ(ByteString(ArenaAllocator(&arena), medium_byte_string), + medium_byte_string); + EXPECT_EQ(ByteString(ArenaAllocator(&arena), large_byte_string), + large_byte_string); + + EXPECT_EQ(ByteString(GetAllocator(), small_byte_string), small_byte_string); + EXPECT_EQ(ByteString(GetAllocator(), medium_byte_string), medium_byte_string); + EXPECT_EQ(ByteString(GetAllocator(), large_byte_string), large_byte_string); + + EXPECT_EQ(ByteString(small_byte_string), small_byte_string); + EXPECT_EQ(ByteString(medium_byte_string), medium_byte_string); + EXPECT_EQ(ByteString(large_byte_string), large_byte_string); +} + +TEST_P(ByteStringTest, CopyConstructFromExternal) { + ByteString small_byte_string = ByteString::FromExternal(GetSmallStringView()); + ByteString medium_byte_string = + ByteString::FromExternal(GetMediumStringView()); + + EXPECT_EQ(ByteString(NewDeleteAllocator(), small_byte_string), + small_byte_string); + EXPECT_EQ(ByteString(NewDeleteAllocator(), medium_byte_string), + medium_byte_string); + + google::protobuf::Arena arena; + EXPECT_EQ(ByteString(ArenaAllocator(&arena), small_byte_string), + small_byte_string); + EXPECT_EQ(ByteString(ArenaAllocator(&arena), medium_byte_string), + medium_byte_string); + + EXPECT_EQ(ByteString(GetAllocator(), small_byte_string), small_byte_string); + EXPECT_EQ(ByteString(GetAllocator(), medium_byte_string), medium_byte_string); + + EXPECT_EQ(ByteString(small_byte_string), small_byte_string); + EXPECT_EQ(ByteString(medium_byte_string), medium_byte_string); +} + +TEST_P(ByteStringTest, MoveConstruct) { + const auto& small_byte_string = [this]() { + return ByteString(GetAllocator(), GetSmallStringView()); + }; + const auto& medium_byte_string = [this]() { + return ByteString(GetAllocator(), GetMediumStringView()); + }; + const auto& large_byte_string = [this]() { + return ByteString(GetAllocator(), GetMediumOrLargeCord()); + }; + + EXPECT_EQ(ByteString(NewDeleteAllocator(), small_byte_string()), + small_byte_string()); + EXPECT_EQ(ByteString(NewDeleteAllocator(), medium_byte_string()), + medium_byte_string()); + EXPECT_EQ(ByteString(NewDeleteAllocator(), large_byte_string()), + large_byte_string()); + + google::protobuf::Arena arena; + EXPECT_EQ(ByteString(ArenaAllocator(&arena), small_byte_string()), + small_byte_string()); + EXPECT_EQ(ByteString(ArenaAllocator(&arena), medium_byte_string()), + medium_byte_string()); + EXPECT_EQ(ByteString(ArenaAllocator(&arena), large_byte_string()), + large_byte_string()); + + EXPECT_EQ(ByteString(GetAllocator(), small_byte_string()), + small_byte_string()); + EXPECT_EQ(ByteString(GetAllocator(), medium_byte_string()), + medium_byte_string()); + EXPECT_EQ(ByteString(GetAllocator(), large_byte_string()), + large_byte_string()); + + EXPECT_EQ(ByteString(small_byte_string()), small_byte_string()); + EXPECT_EQ(ByteString(medium_byte_string()), medium_byte_string()); + EXPECT_EQ(ByteString(large_byte_string()), large_byte_string()); +} + +TEST_P(ByteStringTest, MoveConstructFromExternal) { + const auto& small_byte_string = []() { + return ByteString::FromExternal(GetSmallStringView()); + }; + const auto& medium_byte_string = []() { + return ByteString::FromExternal(GetMediumStringView()); + }; + + EXPECT_EQ(ByteString(NewDeleteAllocator(), small_byte_string()), + small_byte_string()); + EXPECT_EQ(ByteString(NewDeleteAllocator(), medium_byte_string()), + medium_byte_string()); + + google::protobuf::Arena arena; + EXPECT_EQ(ByteString(ArenaAllocator(&arena), small_byte_string()), + small_byte_string()); + EXPECT_EQ(ByteString(ArenaAllocator(&arena), medium_byte_string()), + medium_byte_string()); + + EXPECT_EQ(ByteString(GetAllocator(), small_byte_string()), + small_byte_string()); + EXPECT_EQ(ByteString(GetAllocator(), medium_byte_string()), + medium_byte_string()); + + EXPECT_EQ(ByteString(small_byte_string()), small_byte_string()); + EXPECT_EQ(ByteString(medium_byte_string()), medium_byte_string()); +} + +TEST_P(ByteStringTest, CopyFromByteString) { + ByteString small_byte_string = + ByteString(GetAllocator(), GetSmallStringView()); + ByteString medium_byte_string = + ByteString(GetAllocator(), GetMediumStringView()); + ByteString large_byte_string = + ByteString(GetAllocator(), GetMediumOrLargeCord()); + + ByteString new_delete_byte_string(NewDeleteAllocator<>{}); + // Small <= Small + new_delete_byte_string = small_byte_string; + EXPECT_EQ(new_delete_byte_string, small_byte_string); + // Small <= Medium + new_delete_byte_string = medium_byte_string; + EXPECT_EQ(new_delete_byte_string, medium_byte_string); + // Medium <= Medium + new_delete_byte_string = medium_byte_string; + EXPECT_EQ(new_delete_byte_string, medium_byte_string); + // Medium <= Large + new_delete_byte_string = large_byte_string; + EXPECT_EQ(new_delete_byte_string, large_byte_string); + // Large <= Large + new_delete_byte_string = large_byte_string; + EXPECT_EQ(new_delete_byte_string, large_byte_string); + // Large <= Small + new_delete_byte_string = small_byte_string; + EXPECT_EQ(new_delete_byte_string, small_byte_string); + // Small <= Large + new_delete_byte_string = large_byte_string; + EXPECT_EQ(new_delete_byte_string, large_byte_string); + // Large <= Medium + new_delete_byte_string = medium_byte_string; + EXPECT_EQ(new_delete_byte_string, medium_byte_string); + // Medium <= Small + new_delete_byte_string = small_byte_string; + EXPECT_EQ(new_delete_byte_string, small_byte_string); + + google::protobuf::Arena arena; + ByteString arena_byte_string(ArenaAllocator<>{&arena}); + // Small <= Small + arena_byte_string = small_byte_string; + EXPECT_EQ(arena_byte_string, small_byte_string); + // Small <= Medium + arena_byte_string = medium_byte_string; + EXPECT_EQ(arena_byte_string, medium_byte_string); + // Medium <= Medium + arena_byte_string = medium_byte_string; + EXPECT_EQ(arena_byte_string, medium_byte_string); + // Medium <= Large + arena_byte_string = large_byte_string; + EXPECT_EQ(arena_byte_string, large_byte_string); + // Large <= Large + arena_byte_string = large_byte_string; + EXPECT_EQ(arena_byte_string, large_byte_string); + // Large <= Small + arena_byte_string = small_byte_string; + EXPECT_EQ(arena_byte_string, small_byte_string); + // Small <= Large + arena_byte_string = large_byte_string; + EXPECT_EQ(arena_byte_string, large_byte_string); + // Large <= Medium + arena_byte_string = medium_byte_string; + EXPECT_EQ(arena_byte_string, medium_byte_string); + // Medium <= Small + arena_byte_string = small_byte_string; + EXPECT_EQ(arena_byte_string, small_byte_string); + + ByteString allocator_byte_string(GetAllocator()); + // Small <= Small + allocator_byte_string = small_byte_string; + EXPECT_EQ(allocator_byte_string, small_byte_string); + // Small <= Medium + allocator_byte_string = medium_byte_string; + EXPECT_EQ(allocator_byte_string, medium_byte_string); + // Medium <= Medium + allocator_byte_string = medium_byte_string; + EXPECT_EQ(allocator_byte_string, medium_byte_string); + // Medium <= Large + allocator_byte_string = large_byte_string; + EXPECT_EQ(allocator_byte_string, large_byte_string); + // Large <= Large + allocator_byte_string = large_byte_string; + EXPECT_EQ(allocator_byte_string, large_byte_string); + // Large <= Small + allocator_byte_string = small_byte_string; + EXPECT_EQ(allocator_byte_string, small_byte_string); + // Small <= Large + allocator_byte_string = large_byte_string; + EXPECT_EQ(allocator_byte_string, large_byte_string); + // Large <= Medium + allocator_byte_string = medium_byte_string; + EXPECT_EQ(allocator_byte_string, medium_byte_string); + // Medium <= Small + allocator_byte_string = small_byte_string; + EXPECT_EQ(allocator_byte_string, small_byte_string); + + // Miscellaneous cases not covered above. + // Large <= Medium Arena String + ByteString large_new_delete_byte_string(NewDeleteAllocator<>{}, + GetMediumOrLargeCord()); + ByteString medium_arena_byte_string(ArenaAllocator<>{&arena}, + GetMediumStringView()); + large_new_delete_byte_string = medium_arena_byte_string; + EXPECT_EQ(large_new_delete_byte_string, medium_arena_byte_string); +} + +TEST_P(ByteStringTest, MoveFrom) { + const auto& small_byte_string = [this]() { + return ByteString(GetAllocator(), GetSmallStringView()); + }; + const auto& medium_byte_string = [this]() { + return ByteString(GetAllocator(), GetMediumStringView()); + }; + const auto& large_byte_string = [this]() { + return ByteString(GetAllocator(), GetMediumOrLargeCord()); + }; + + ByteString new_delete_byte_string(NewDeleteAllocator<>{}); + // Small <= Small + new_delete_byte_string = small_byte_string(); + EXPECT_EQ(new_delete_byte_string, small_byte_string()); + // Small <= Medium + new_delete_byte_string = medium_byte_string(); + EXPECT_EQ(new_delete_byte_string, medium_byte_string()); + // Medium <= Medium + new_delete_byte_string = medium_byte_string(); + EXPECT_EQ(new_delete_byte_string, medium_byte_string()); + // Medium <= Large + new_delete_byte_string = large_byte_string(); + EXPECT_EQ(new_delete_byte_string, large_byte_string()); + // Large <= Large + new_delete_byte_string = large_byte_string(); + EXPECT_EQ(new_delete_byte_string, large_byte_string()); + // Large <= Small + new_delete_byte_string = small_byte_string(); + EXPECT_EQ(new_delete_byte_string, small_byte_string()); + // Small <= Large + new_delete_byte_string = large_byte_string(); + EXPECT_EQ(new_delete_byte_string, large_byte_string()); + // Large <= Medium + new_delete_byte_string = medium_byte_string(); + EXPECT_EQ(new_delete_byte_string, medium_byte_string()); + // Medium <= Small + new_delete_byte_string = small_byte_string(); + EXPECT_EQ(new_delete_byte_string, small_byte_string()); + + google::protobuf::Arena arena; + ByteString arena_byte_string(ArenaAllocator<>{&arena}); + // Small <= Small + arena_byte_string = small_byte_string(); + EXPECT_EQ(arena_byte_string, small_byte_string()); + // Small <= Medium + arena_byte_string = medium_byte_string(); + EXPECT_EQ(arena_byte_string, medium_byte_string()); + // Medium <= Medium + arena_byte_string = medium_byte_string(); + EXPECT_EQ(arena_byte_string, medium_byte_string()); + // Medium <= Large + arena_byte_string = large_byte_string(); + EXPECT_EQ(arena_byte_string, large_byte_string()); + // Large <= Large + arena_byte_string = large_byte_string(); + EXPECT_EQ(arena_byte_string, large_byte_string()); + // Large <= Small + arena_byte_string = small_byte_string(); + EXPECT_EQ(arena_byte_string, small_byte_string()); + // Small <= Large + arena_byte_string = large_byte_string(); + EXPECT_EQ(arena_byte_string, large_byte_string()); + // Large <= Medium + arena_byte_string = medium_byte_string(); + EXPECT_EQ(arena_byte_string, medium_byte_string()); + // Medium <= Small + arena_byte_string = small_byte_string(); + EXPECT_EQ(arena_byte_string, small_byte_string()); + + ByteString allocator_byte_string(GetAllocator()); + // Small <= Small + allocator_byte_string = small_byte_string(); + EXPECT_EQ(allocator_byte_string, small_byte_string()); + // Small <= Medium + allocator_byte_string = medium_byte_string(); + EXPECT_EQ(allocator_byte_string, medium_byte_string()); + // Medium <= Medium + allocator_byte_string = medium_byte_string(); + EXPECT_EQ(allocator_byte_string, medium_byte_string()); + // Medium <= Large + allocator_byte_string = large_byte_string(); + EXPECT_EQ(allocator_byte_string, large_byte_string()); + // Large <= Large + allocator_byte_string = large_byte_string(); + EXPECT_EQ(allocator_byte_string, large_byte_string()); + // Large <= Small + allocator_byte_string = small_byte_string(); + EXPECT_EQ(allocator_byte_string, small_byte_string()); + // Small <= Large + allocator_byte_string = large_byte_string(); + EXPECT_EQ(allocator_byte_string, large_byte_string()); + // Large <= Medium + allocator_byte_string = medium_byte_string(); + EXPECT_EQ(allocator_byte_string, medium_byte_string()); + // Medium <= Small + allocator_byte_string = small_byte_string(); + EXPECT_EQ(allocator_byte_string, small_byte_string()); + + // Miscellaneous cases not covered above. + // Large <= Medium Arena String + ByteString large_new_delete_byte_string(NewDeleteAllocator<>{}, + GetMediumOrLargeCord()); + ByteString medium_arena_byte_string(ArenaAllocator<>{&arena}, + GetMediumStringView()); + large_new_delete_byte_string = std::move(medium_arena_byte_string); + EXPECT_EQ(large_new_delete_byte_string, GetMediumStringView()); +} + +TEST_P(ByteStringTest, Swap) { + using std::swap; + ByteString empty_byte_string(GetAllocator()); + ByteString small_byte_string = + ByteString(GetAllocator(), GetSmallStringView()); + ByteString medium_byte_string = + ByteString(GetAllocator(), GetMediumStringView()); + ByteString large_byte_string = + ByteString(GetAllocator(), GetMediumOrLargeCord()); + + // Small <=> Small + swap(empty_byte_string, small_byte_string); + EXPECT_EQ(empty_byte_string, GetSmallStringView()); + EXPECT_EQ(small_byte_string, ""); + swap(empty_byte_string, small_byte_string); + EXPECT_EQ(empty_byte_string, ""); + EXPECT_EQ(small_byte_string, GetSmallStringView()); + + // Small <=> Medium + swap(small_byte_string, medium_byte_string); + EXPECT_EQ(small_byte_string, GetMediumStringView()); + EXPECT_EQ(medium_byte_string, GetSmallStringView()); + swap(small_byte_string, medium_byte_string); + EXPECT_EQ(small_byte_string, GetSmallStringView()); + EXPECT_EQ(medium_byte_string, GetMediumStringView()); + + // Small <=> Large + swap(small_byte_string, large_byte_string); + EXPECT_EQ(small_byte_string, GetMediumOrLargeCord()); + EXPECT_EQ(large_byte_string, GetSmallStringView()); + swap(small_byte_string, large_byte_string); + EXPECT_EQ(small_byte_string, GetSmallStringView()); + EXPECT_EQ(large_byte_string, GetMediumOrLargeCord()); + + // Medium <=> Medium + static constexpr absl::string_view kDifferentMediumStringView = + "A different string that is too large for the small string optimization!"; + ByteString other_medium_byte_string = + ByteString(GetAllocator(), kDifferentMediumStringView); + swap(medium_byte_string, other_medium_byte_string); + EXPECT_EQ(medium_byte_string, kDifferentMediumStringView); + EXPECT_EQ(other_medium_byte_string, GetMediumStringView()); + swap(medium_byte_string, other_medium_byte_string); + EXPECT_EQ(medium_byte_string, GetMediumStringView()); + EXPECT_EQ(other_medium_byte_string, kDifferentMediumStringView); + + // Medium <=> Large + swap(medium_byte_string, large_byte_string); + EXPECT_EQ(medium_byte_string, GetMediumOrLargeCord()); + EXPECT_EQ(large_byte_string, GetMediumStringView()); + swap(medium_byte_string, large_byte_string); + EXPECT_EQ(medium_byte_string, GetMediumStringView()); + EXPECT_EQ(large_byte_string, GetMediumOrLargeCord()); + + // Large <=> Large + const absl::Cord different_medium_or_large_cord = + absl::Cord(kDifferentMediumStringView); + ByteString other_large_byte_string = + ByteString(GetAllocator(), different_medium_or_large_cord); + swap(large_byte_string, other_large_byte_string); + EXPECT_EQ(large_byte_string, different_medium_or_large_cord); + EXPECT_EQ(other_large_byte_string, GetMediumStringView()); + swap(large_byte_string, other_large_byte_string); + EXPECT_EQ(large_byte_string, GetMediumStringView()); + EXPECT_EQ(other_large_byte_string, different_medium_or_large_cord); + + // Miscellaneous cases not covered above. These do not swap a second time to + // restore state, so they are destructive. + // Small <=> Different Allocator Medium + ByteString medium_new_delete_byte_string = + ByteString(NewDeleteAllocator<>{}, kDifferentMediumStringView); + swap(empty_byte_string, medium_new_delete_byte_string); + EXPECT_EQ(empty_byte_string, kDifferentMediumStringView); + EXPECT_EQ(medium_new_delete_byte_string, ""); + // Small <=> Different Allocator Large + ByteString large_new_delete_byte_string = + ByteString(NewDeleteAllocator<>{}, GetMediumOrLargeCord()); + swap(small_byte_string, large_new_delete_byte_string); + EXPECT_EQ(small_byte_string, GetMediumOrLargeCord()); + EXPECT_EQ(large_new_delete_byte_string, GetSmallStringView()); + // Medium <=> Different Allocator Large + large_new_delete_byte_string = + ByteString(NewDeleteAllocator<>{}, different_medium_or_large_cord); + swap(medium_byte_string, large_new_delete_byte_string); + EXPECT_EQ(medium_byte_string, different_medium_or_large_cord); + EXPECT_EQ(large_new_delete_byte_string, GetMediumStringView()); + // Medium <=> Different Allocator Medium + medium_byte_string = ByteString(GetAllocator(), GetMediumStringView()); + medium_new_delete_byte_string = + ByteString(NewDeleteAllocator<>{}, kDifferentMediumStringView); + swap(medium_byte_string, medium_new_delete_byte_string); + EXPECT_EQ(medium_byte_string, kDifferentMediumStringView); + EXPECT_EQ(medium_new_delete_byte_string, GetMediumStringView()); +} + +TEST_P(ByteStringTest, FlattenSmall) { + ByteString byte_string = ByteString(GetAllocator(), GetSmallStringView()); + EXPECT_EQ(GetKind(byte_string), ByteStringKind::kSmall); + EXPECT_EQ(byte_string.Flatten(), GetSmallStringView()); + EXPECT_EQ(GetKind(byte_string), ByteStringKind::kSmall); +} + +TEST_P(ByteStringTest, FlattenMedium) { + ByteString byte_string = ByteString(GetAllocator(), GetMediumStringView()); + EXPECT_EQ(GetKind(byte_string), ByteStringKind::kMedium); + EXPECT_EQ(byte_string.Flatten(), GetMediumStringView()); + EXPECT_EQ(GetKind(byte_string), ByteStringKind::kMedium); +} + +TEST_P(ByteStringTest, FlattenLarge) { + if (GetAllocator().arena() != nullptr) { + GTEST_SKIP(); + } + ByteString byte_string = ByteString(GetAllocator(), GetMediumOrLargeCord()); + EXPECT_EQ(GetKind(byte_string), ByteStringKind::kLarge); + EXPECT_EQ(byte_string.Flatten(), GetMediumStringView()); + EXPECT_EQ(GetKind(byte_string), ByteStringKind::kLarge); +} + +TEST_P(ByteStringTest, TryFlatSmall) { + ByteString byte_string = ByteString(GetAllocator(), GetSmallStringView()); + EXPECT_EQ(GetKind(byte_string), ByteStringKind::kSmall); + EXPECT_THAT(byte_string.TryFlat(), Optional(GetSmallStringView())); + EXPECT_EQ(GetKind(byte_string), ByteStringKind::kSmall); +} + +TEST_P(ByteStringTest, TryFlatMedium) { + ByteString byte_string = ByteString(GetAllocator(), GetMediumStringView()); + EXPECT_EQ(GetKind(byte_string), ByteStringKind::kMedium); + EXPECT_THAT(byte_string.TryFlat(), Optional(GetMediumStringView())); + EXPECT_EQ(GetKind(byte_string), ByteStringKind::kMedium); +} + +TEST_P(ByteStringTest, TryFlatLarge) { + if (GetAllocator().arena() != nullptr) { + GTEST_SKIP(); + } + ByteString byte_string = + ByteString(GetAllocator(), GetMediumOrLargeFragmentedCord()); + EXPECT_EQ(GetKind(byte_string), ByteStringKind::kLarge); + EXPECT_THAT(byte_string.TryFlat(), Eq(absl::nullopt)); + EXPECT_EQ(GetKind(byte_string), ByteStringKind::kLarge); +} + +TEST_P(ByteStringTest, Equals) { + ByteString byte_string = ByteString(GetAllocator(), GetMediumOrLargeCord()); + EXPECT_TRUE(byte_string.Equals(GetMediumStringView())); +} + +TEST_P(ByteStringTest, Compare) { + ByteString byte_string = ByteString(GetAllocator(), GetMediumOrLargeCord()); + EXPECT_EQ(byte_string.Compare(GetMediumStringView()), 0); + EXPECT_EQ(byte_string.Compare(GetMediumOrLargeCord()), 0); +} + +TEST_P(ByteStringTest, StartsWith) { + ByteString byte_string = ByteString(GetAllocator(), GetMediumOrLargeCord()); + EXPECT_TRUE(byte_string.StartsWith( + GetMediumStringView().substr(0, kSmallByteStringCapacity))); + EXPECT_TRUE(byte_string.StartsWith( + GetMediumOrLargeCord().Subcord(0, kSmallByteStringCapacity))); +} + +TEST_P(ByteStringTest, EndsWith) { + ByteString byte_string = ByteString(GetAllocator(), GetMediumOrLargeCord()); + EXPECT_TRUE(byte_string.EndsWith( + GetMediumStringView().substr(kSmallByteStringCapacity))); + EXPECT_TRUE(byte_string.EndsWith(GetMediumOrLargeCord().Subcord( + kSmallByteStringCapacity, + GetMediumOrLargeCord().size() - kSmallByteStringCapacity))); +} + +TEST_P(ByteStringTest, Find) { + ByteString byte_string = ByteString(GetAllocator(), GetMediumStringView()); + + // Find string_view + EXPECT_THAT(byte_string.Find("A string"), Optional(0)); + EXPECT_THAT( + byte_string.Find("small string optimization!"), + Optional(GetMediumStringView().find("small string optimization!"))); + EXPECT_THAT(byte_string.Find("not found"), Eq(absl::nullopt)); + EXPECT_THAT(byte_string.Find(""), Optional(0)); + EXPECT_THAT(byte_string.Find("", 3), Optional(3)); + EXPECT_THAT(byte_string.Find("A string", 1), Eq(absl::nullopt)); + + // Find cord + EXPECT_THAT(byte_string.Find(absl::Cord("A string")), Optional(0)); + EXPECT_THAT( + byte_string.Find(absl::Cord("small string optimization!")), + Optional(GetMediumStringView().find("small string optimization!"))); + EXPECT_THAT( + byte_string.Find(absl::MakeFragmentedCord( + {"A string", " that is too large for the small string optimization!", + " extra"})), + Eq(absl::nullopt)); + EXPECT_THAT(byte_string.Find(GetMediumOrLargeFragmentedCord()), Optional(0)); + EXPECT_THAT(byte_string.Find(absl::Cord("not found")), Eq(absl::nullopt)); + EXPECT_THAT(byte_string.Find(absl::Cord("")), Optional(0)); + EXPECT_THAT(byte_string.Find(absl::Cord(""), 3), Optional(3)); +} + +TEST_P(ByteStringTest, FindEdgeCases) { + ByteString empty_byte_string(GetAllocator(), ""); + EXPECT_THAT(empty_byte_string.Find("a"), Eq(absl::nullopt)); + EXPECT_THAT(empty_byte_string.Find(""), Optional(0)); + ByteString cord_byte_string = + ByteString(GetAllocator(), GetMediumOrLargeCord()); + EXPECT_THAT(cord_byte_string.Find("not found"), Eq(absl::nullopt)); + ByteString byte_string = ByteString(GetAllocator(), GetMediumStringView()); + + // Needle longer than haystack. + EXPECT_THAT(byte_string.Find(std::string(byte_string.size() + 1, 'a')), + Eq(absl::nullopt)); + + // Needle at the end. + absl::string_view suffix = "optimization!"; + EXPECT_THAT(byte_string.Find(suffix), + Optional(byte_string.size() - suffix.size())); + + // pos at the end. + EXPECT_THAT(byte_string.Find("a", byte_string.size()), Eq(absl::nullopt)); + EXPECT_THAT(byte_string.Find("", byte_string.size()), + Optional(byte_string.size())); + + // Search in a cord-backed ByteString with pos > 0. + EXPECT_THAT(cord_byte_string.Find("string", 1), + Optional(GetMediumStringView().find("string", 1))); + + // Needle at the end of a cord-backed ByteString. + absl::string_view suffix_sv = "optimization!"; + EXPECT_THAT(cord_byte_string.Find(suffix_sv), + Optional(cord_byte_string.size() - suffix_sv.size())); + EXPECT_THAT(cord_byte_string.Find(absl::Cord(suffix_sv)), + Optional(cord_byte_string.size() - suffix_sv.size())); + + // Fragmented needle with empty first chunk. + absl::Cord fragmented_with_empty_chunk; + fragmented_with_empty_chunk.Append(""); + fragmented_with_empty_chunk.Append("A string"); + EXPECT_THAT(byte_string.Find(fragmented_with_empty_chunk), Optional(0)); + + // Search with fragmented cord needle on string_view backed ByteString with + // partial match. + ByteString partial_match_haystack(GetAllocator(), "abababac"); + absl::Cord partial_match_needle = absl::MakeFragmentedCord({"aba", "c"}); + EXPECT_THAT(partial_match_haystack.Find(partial_match_needle), Optional(4)); + + // Search with fragmented cord needle where first chunk is found but not + // enough space for the rest. + ByteString short_haystack(GetAllocator(), "abcdefg"); + absl::Cord needle_too_long = absl::MakeFragmentedCord({"ef", "gh"}); + EXPECT_THAT(short_haystack.Find(needle_too_long), Eq(absl::nullopt)); + + // Search with a fragmented empty cord. + absl::Cord fragmented_empty_cord = absl::MakeFragmentedCord({"", ""}); + EXPECT_THAT(byte_string.Find(fragmented_empty_cord), Optional(0)); + EXPECT_THAT(byte_string.Find(fragmented_empty_cord, 3), Optional(3)); + + // Search for suffix in a fragmented cord. + ByteString fragmented_cord_byte_string(GetAllocator(), + GetMediumOrLargeFragmentedCord()); + EXPECT_THAT(fragmented_cord_byte_string.Find(suffix_sv), + Optional(fragmented_cord_byte_string.size() - suffix_sv.size())); + EXPECT_THAT(fragmented_cord_byte_string.Find(absl::Cord(suffix_sv)), + Optional(fragmented_cord_byte_string.size() - suffix_sv.size())); +} + +#ifndef NDEBUG +TEST_P(ByteStringTest, FindOutOfBounds) { + ByteString byte_string = ByteString(GetAllocator(), "test"); + EXPECT_DEATH(byte_string.Find("t", 5), _); +} +#endif + +TEST_P(ByteStringTest, Substring) { + // small byte_string substring + ByteString small_byte_string = + ByteString(GetAllocator(), GetSmallStringView()); + EXPECT_EQ(small_byte_string.Substring(1, 5), + GetSmallStringView().substr(1, 4)); + EXPECT_EQ(small_byte_string.Substring(0, small_byte_string.size()), + GetSmallStringView()); + EXPECT_EQ(small_byte_string.Substring(1, 1), ""); + // medium byte_string substring + ByteString medium_byte_string = + ByteString(GetAllocator(), GetMediumStringView()); + EXPECT_EQ(medium_byte_string.Substring(2, 12), + GetMediumStringView().substr(2, 10)); + EXPECT_EQ(medium_byte_string.Substring(0, medium_byte_string.size()), + GetMediumStringView()); + // large byte_string substring + ByteString large_byte_string = + ByteString(GetAllocator(), GetMediumOrLargeCord()); + EXPECT_EQ(large_byte_string.Substring(3, 15), + GetMediumOrLargeCord().Subcord(3, 12)); + EXPECT_EQ(large_byte_string.Substring(0, large_byte_string.size()), + GetMediumOrLargeCord()); + // substring with one parameter + ByteString tacocat_byte_string = ByteString(GetAllocator(), "tacocat"); + EXPECT_EQ(tacocat_byte_string.Substring(4), "cat"); +} + +TEST_P(ByteStringTest, SubstringEdgeCases) { + ByteString byte_string = ByteString(GetAllocator(), GetSmallStringView()); + EXPECT_EQ(byte_string.Substring(byte_string.size(), byte_string.size()), ""); + EXPECT_EQ(byte_string.Substring(0, 0), ""); +} + +#ifndef NDEBUG +TEST_P(ByteStringTest, SubstringOutOfBounds) { + ByteString byte_string = ByteString(GetAllocator(), "test"); + EXPECT_DEATH(static_cast(byte_string.Substring(5, 5)), _); + EXPECT_DEATH(static_cast(byte_string.Substring(0, 5)), _); + EXPECT_DEATH(static_cast(byte_string.Substring(3, 2)), _); +} +#endif + +TEST_P(ByteStringTest, RemovePrefixSmall) { + ByteString byte_string = ByteString(GetAllocator(), GetSmallStringView()); + byte_string.RemovePrefix(1); + EXPECT_EQ(byte_string, GetSmallStringView().substr(1)); +} + +TEST_P(ByteStringTest, RemovePrefixMedium) { + ByteString byte_string = ByteString(GetAllocator(), GetMediumStringView()); + EXPECT_EQ(GetKind(byte_string), ByteStringKind::kMedium); + byte_string.RemovePrefix(byte_string.size() - kSmallByteStringCapacity); + EXPECT_EQ(GetKind(byte_string), ByteStringKind::kSmall); + EXPECT_EQ(byte_string, + GetMediumStringView().substr(GetMediumStringView().size() - + kSmallByteStringCapacity)); +} + +TEST_P(ByteStringTest, RemovePrefixMediumOrLarge) { + ByteString byte_string = ByteString(GetAllocator(), GetMediumOrLargeCord()); + byte_string.RemovePrefix(byte_string.size() - kSmallByteStringCapacity); + EXPECT_EQ(GetKind(byte_string), ByteStringKind::kSmall); + EXPECT_EQ(byte_string, + GetMediumStringView().substr(GetMediumStringView().size() - + kSmallByteStringCapacity)); +} + +TEST_P(ByteStringTest, RemoveSuffixSmall) { + ByteString byte_string = ByteString(GetAllocator(), GetSmallStringView()); + byte_string.RemoveSuffix(1); + EXPECT_EQ(byte_string, + GetSmallStringView().substr(0, GetSmallStringView().size() - 1)); +} + +TEST_P(ByteStringTest, RemoveSuffixMedium) { + ByteString byte_string = ByteString(GetAllocator(), GetMediumStringView()); + EXPECT_EQ(GetKind(byte_string), ByteStringKind::kMedium); + byte_string.RemoveSuffix(byte_string.size() - kSmallByteStringCapacity); + EXPECT_EQ(GetKind(byte_string), ByteStringKind::kSmall); + EXPECT_EQ(byte_string, + GetMediumStringView().substr(0, kSmallByteStringCapacity)); +} + +TEST_P(ByteStringTest, RemoveSuffixMediumOrLarge) { + ByteString byte_string = ByteString(GetAllocator(), GetMediumOrLargeCord()); + byte_string.RemoveSuffix(byte_string.size() - kSmallByteStringCapacity); + EXPECT_EQ(GetKind(byte_string), ByteStringKind::kSmall); + EXPECT_EQ(byte_string, + GetMediumStringView().substr(0, kSmallByteStringCapacity)); +} + +TEST_P(ByteStringTest, ToStringSmall) { + ByteString byte_string = ByteString(GetAllocator(), GetSmallStringView()); + EXPECT_EQ(byte_string.ToString(), byte_string); +} + +TEST_P(ByteStringTest, ToStringMedium) { + ByteString byte_string = ByteString(GetAllocator(), GetMediumStringView()); + EXPECT_EQ(byte_string.ToString(), byte_string); +} + +TEST_P(ByteStringTest, ToStringLarge) { + ByteString byte_string = ByteString(GetAllocator(), GetMediumOrLargeCord()); + EXPECT_EQ(byte_string.ToString(), byte_string); +} + +TEST_P(ByteStringTest, ToStringViewSmall) { + std::string scratch; + ByteString byte_string = ByteString(GetAllocator(), GetSmallStringView()); + EXPECT_EQ(byte_string.ToStringView(&scratch), GetSmallStringView()); +} + +TEST_P(ByteStringTest, ToStringViewMedium) { + std::string scratch; + ByteString byte_string = ByteString(GetAllocator(), GetMediumStringView()); + EXPECT_EQ(byte_string.ToStringView(&scratch), GetMediumStringView()); +} + +TEST_P(ByteStringTest, ToStringViewLarge) { + std::string scratch; + ByteString byte_string = ByteString(GetAllocator(), GetMediumOrLargeCord()); + EXPECT_EQ(byte_string.ToStringView(&scratch), GetMediumOrLargeCord()); +} + +TEST_P(ByteStringTest, AsStringViewSmall) { + ByteString byte_string = ByteString(GetAllocator(), GetSmallStringView()); + EXPECT_EQ(byte_string.AsStringView(), GetSmallStringView()); +} + +TEST_P(ByteStringTest, AsStringViewMedium) { + ByteString byte_string = ByteString(GetAllocator(), GetMediumStringView()); + EXPECT_EQ(byte_string.AsStringView(), GetMediumStringView()); +} + +TEST_P(ByteStringTest, AsStringViewLarge) { + ByteString byte_string = ByteString(GetMediumOrLargeCord()); + EXPECT_DEATH(byte_string.AsStringView(), _); +} + +TEST_P(ByteStringTest, CopyToStringSmall) { + std::string out; + + ByteString(GetAllocator(), GetSmallStringView()).CopyToString(&out); + EXPECT_EQ(out, GetSmallStringView()); +} + +TEST_P(ByteStringTest, CopyToStringMedium) { + std::string out; + + ByteString(GetAllocator(), GetMediumStringView()).CopyToString(&out); + EXPECT_EQ(out, GetMediumStringView()); +} + +TEST_P(ByteStringTest, CopyToStringLarge) { + std::string out; + + ByteString(GetAllocator(), GetMediumOrLargeCord()).CopyToString(&out); + EXPECT_EQ(out, GetMediumOrLargeCord()); +} + +TEST_P(ByteStringTest, AppendToStringSmall) { + std::string out; + + ByteString(GetAllocator(), GetSmallStringView()).AppendToString(&out); + EXPECT_EQ(out, GetSmallStringView()); +} + +TEST_P(ByteStringTest, AppendToStringMedium) { + std::string out; + + ByteString(GetAllocator(), GetMediumStringView()).AppendToString(&out); + EXPECT_EQ(out, GetMediumStringView()); +} + +TEST_P(ByteStringTest, AppendToStringLarge) { + std::string out; + + ByteString(GetAllocator(), GetMediumOrLargeCord()).AppendToString(&out); + EXPECT_EQ(out, GetMediumOrLargeCord()); +} + +TEST_P(ByteStringTest, ToCordSmall) { + ByteString byte_string = ByteString(GetAllocator(), GetSmallStringView()); + EXPECT_EQ(byte_string.ToCord(), byte_string); + EXPECT_EQ(std::move(byte_string).ToCord(), GetSmallStringView()); +} + +TEST_P(ByteStringTest, ToCordMedium) { + ByteString byte_string = ByteString(GetAllocator(), GetMediumStringView()); + EXPECT_EQ(byte_string.ToCord(), byte_string); + EXPECT_EQ(std::move(byte_string).ToCord(), GetMediumStringView()); +} + +TEST_P(ByteStringTest, ToCordLarge) { + ByteString byte_string = ByteString(GetAllocator(), GetMediumOrLargeCord()); + EXPECT_EQ(byte_string.ToCord(), byte_string); + EXPECT_EQ(std::move(byte_string).ToCord(), GetMediumOrLargeCord()); +} + +TEST_P(ByteStringTest, CopyToCordSmall) { + absl::Cord out; + + ByteString(GetAllocator(), GetSmallStringView()).CopyToCord(&out); + EXPECT_EQ(out, GetSmallStringView()); +} + +TEST_P(ByteStringTest, CopyToCordMedium) { + absl::Cord out; + + ByteString(GetAllocator(), GetMediumStringView()).CopyToCord(&out); + EXPECT_EQ(out, GetMediumStringView()); +} + +TEST_P(ByteStringTest, CopyToCordLarge) { + absl::Cord out; + + ByteString(GetAllocator(), GetMediumOrLargeCord()).CopyToCord(&out); + EXPECT_EQ(out, GetMediumOrLargeCord()); +} + +TEST_P(ByteStringTest, AppendToCordSmall) { + absl::Cord out; + + ByteString(GetAllocator(), GetSmallStringView()).AppendToCord(&out); + EXPECT_EQ(out, GetSmallStringView()); +} + +TEST_P(ByteStringTest, AppendToCordMedium) { + absl::Cord out; + + ByteString(GetAllocator(), GetMediumStringView()).AppendToCord(&out); + EXPECT_EQ(out, GetMediumStringView()); +} + +TEST_P(ByteStringTest, AppendToCordLarge) { + absl::Cord out; + + ByteString(GetAllocator(), GetMediumOrLargeCord()).AppendToCord(&out); + EXPECT_EQ(out, GetMediumOrLargeCord()); +} + +TEST_P(ByteStringTest, CloneSmall) { + google::protobuf::Arena arena; + ByteString byte_string = ByteString(GetAllocator(), GetSmallStringView()); + EXPECT_EQ(byte_string.Clone(&arena), byte_string); +} + +TEST_P(ByteStringTest, CloneMedium) { + google::protobuf::Arena arena; + ByteString byte_string = ByteString(GetAllocator(), GetMediumStringView()); + EXPECT_EQ(byte_string.Clone(&arena), byte_string); +} + +TEST_P(ByteStringTest, CloneLarge) { + google::protobuf::Arena arena; + ByteString byte_string = ByteString(GetAllocator(), GetMediumOrLargeCord()); + EXPECT_EQ(byte_string.Clone(&arena), byte_string); +} + +TEST_P(ByteStringTest, LegacyByteStringSmall) { + google::protobuf::Arena arena; + ByteString byte_string = ByteString(GetAllocator(), GetSmallStringView()); + EXPECT_EQ(LegacyByteString(byte_string, /*stable=*/false, &arena), + GetSmallStringView()); + EXPECT_EQ(LegacyByteString(byte_string, /*stable=*/true, &arena), + GetSmallStringView()); +} + +TEST_P(ByteStringTest, LegacyByteStringMedium) { + google::protobuf::Arena arena; + ByteString byte_string = ByteString(GetAllocator(), GetMediumStringView()); + EXPECT_EQ(LegacyByteString(byte_string, /*stable=*/false, &arena), + GetMediumStringView()); + EXPECT_EQ(LegacyByteString(byte_string, /*stable=*/true, &arena), + GetMediumStringView()); +} + +TEST_P(ByteStringTest, LegacyByteStringLarge) { + google::protobuf::Arena arena; + ByteString byte_string = ByteString(GetAllocator(), GetMediumOrLargeCord()); + EXPECT_EQ(LegacyByteString(byte_string, /*stable=*/false, &arena), + GetMediumOrLargeCord()); + EXPECT_EQ(LegacyByteString(byte_string, /*stable=*/true, &arena), + GetMediumOrLargeCord()); +} + +TEST_P(ByteStringTest, HashValue) { + EXPECT_EQ(absl::HashOf(ByteString(GetAllocator(), GetSmallStringView())), + absl::HashOf(GetSmallStringView())); + EXPECT_EQ(absl::HashOf(ByteString(GetAllocator(), GetMediumStringView())), + absl::HashOf(GetMediumStringView())); + EXPECT_EQ(absl::HashOf(ByteString(GetAllocator(), GetMediumOrLargeCord())), + absl::HashOf(GetMediumOrLargeCord())); +} + +INSTANTIATE_TEST_SUITE_P(ByteStringTest, ByteStringTest, + ::testing::Values(AllocatorKind::kNewDelete, + AllocatorKind::kArena)); + +} // namespace +} // namespace cel::common_internal diff --git a/common/internal/casting.h b/common/internal/casting.h new file mode 100644 index 000000000..fe7d03279 --- /dev/null +++ b/common/internal/casting.h @@ -0,0 +1,237 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/casting.h" + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_INTERNAL_CASTING_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_INTERNAL_CASTING_H_ + +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/meta/type_traits.h" +#include "absl/types/optional.h" +#include "internal/casts.h" + +namespace cel { + +namespace common_internal { + +template +using propagate_const_t = + std::conditional_t>, + std::add_const_t, To>; + +template +using propagate_volatile_t = + std::conditional_t>, + std::add_volatile_t, To>; + +template +using propagate_reference_t = + std::conditional_t, + std::add_lvalue_reference_t, + std::conditional_t, + std::add_rvalue_reference_t, To>>; + +template +using propagate_cvref_t = propagate_reference_t< + propagate_volatile_t, From>, From>; + +} // namespace common_internal + +namespace common_internal { + +// Implementation of `cel::InstanceOf`. +template +struct ABSL_DEPRECATED("Use Is member functions instead.") + InstanceOfImpl final { + static_assert(!std::is_pointer_v, "To must not be a pointer"); + static_assert(!std::is_array_v, "To must not be an array"); + static_assert(!std::is_lvalue_reference_v, + "To must not be a lvalue reference"); + static_assert(!std::is_rvalue_reference_v, + "To must not be a lvalue reference"); + static_assert(!std::is_const_v, "To must not be const qualified"); + static_assert(!std::is_volatile_v, "To must not be volatile qualified"); + static_assert(std::is_class_v, "To must be a non-union class"); + + explicit InstanceOfImpl() = default; + + template + ABSL_DEPRECATED("Use Is member functions instead.") + ABSL_MUST_USE_RESULT bool operator()(const From& from) const { + static_assert(!std::is_volatile_v, + "From must not be volatile qualified"); + static_assert(std::is_class_v, "From must be a non-union class"); + if constexpr (std::is_same_v, To>) { + // Same type. Separate from the next `else if` to work on in-complete + // types. + return true; + } else if constexpr (std::is_polymorphic_v && + std::is_polymorphic_v> && + std::is_base_of_v>) { + // Polymorphic upcast. + return true; + } else if constexpr (!std::is_polymorphic_v && + !std::is_polymorphic_v> && + (std::is_convertible_v || + std::is_convertible_v || + std::is_convertible_v || + std::is_convertible_v)) { + // Implicitly convertible. + return true; + } else { + // Something else. + return from.template Is(); + } + } + + template + ABSL_DEPRECATED("Use Is member functions instead.") + ABSL_MUST_USE_RESULT bool operator()(const From* from) const { + static_assert(!std::is_volatile_v, + "From must not be volatile qualified"); + static_assert(std::is_class_v, "From must be a non-union class"); + return from != nullptr && (*this)(*from); + } +}; + +// Implementation of `cel::Cast`. +template +struct ABSL_DEPRECATED( + "Use explicit conversion functions instead through static_cast.") + CastImpl final { + static_assert(!std::is_pointer_v, "To must not be a pointer"); + static_assert(!std::is_array_v, "To must not be an array"); + static_assert(!std::is_lvalue_reference_v, + "To must not be a lvalue reference"); + static_assert(!std::is_rvalue_reference_v, + "To must not be a lvalue reference"); + static_assert(!std::is_const_v, "To must not be const qualified"); + static_assert(!std::is_volatile_v, "To must not be volatile qualified"); + static_assert(std::is_class_v, "To must be a non-union class"); + + explicit CastImpl() = default; + + template + ABSL_DEPRECATED( + "Use explicit conversion functions instead through static_cast.") + ABSL_MUST_USE_RESULT decltype(auto) + operator()(From&& from) const { + static_assert(!std::is_volatile_v, + "From must not be volatile qualified"); + static_assert(std::is_class_v>, + "From must be a non-union class"); + if constexpr (std::is_polymorphic_v) { + static_assert(std::is_lvalue_reference_v, + "polymorphic casts are only possible on lvalue references"); + } + if constexpr (std::is_same_v, To>) { + // Same type. Separate from the next `else if` to work on in-complete + // types. + return static_cast>(from); + } else if constexpr (std::is_polymorphic_v && + std::is_polymorphic_v> && + std::is_base_of_v>) { + // Polymorphic upcast. + return static_cast>(from); + } else if constexpr (std::is_polymorphic_v && + std::is_polymorphic_v> && + std::is_base_of_v, To>) { + // Polymorphic downcast. + return cel::internal::down_cast>( + std::forward(from)); + } else if constexpr (std::is_convertible_v && + !std::is_polymorphic_v && + !std::is_polymorphic_v>) { + return static_cast(std::forward(from)); + } else { + // Something else. + return std::forward(from).template Get(); + } + } + + template + ABSL_DEPRECATED( + "Use explicit conversion functions instead through static_cast.") + ABSL_MUST_USE_RESULT decltype(auto) + operator()(From* from) const { + static_assert(!std::is_volatile_v, + "From must not be volatile qualified"); + static_assert(std::is_class_v, "From must be a non-union class"); + using R = decltype((*this)(*from)); + static_assert(std::is_lvalue_reference_v); + if (from == nullptr) { + return static_cast>>( + nullptr); + } + return static_cast>>( + std::addressof((*this)(*from))); + } +}; + +// Implementation of `cel::As`. +template +struct ABSL_DEPRECATED("Use As member functions instead.") AsImpl final { + static_assert(!std::is_pointer_v, "To must not be a pointer"); + static_assert(!std::is_array_v, "To must not be an array"); + static_assert(!std::is_lvalue_reference_v, + "To must not be a lvalue reference"); + static_assert(!std::is_rvalue_reference_v, + "To must not be a lvalue reference"); + static_assert(!std::is_const_v, "To must not be const qualified"); + static_assert(!std::is_volatile_v, "To must not be volatile qualified"); + static_assert(std::is_class_v, "To must be a non-union class"); + + explicit AsImpl() = default; + + template + ABSL_DEPRECATED("Use As member functions instead.") + ABSL_MUST_USE_RESULT decltype(auto) operator()(From&& from) const { + // Returns either `absl::optional` or `cel::optional_ref` + // depending on the return type of `CastTraits::Convert`. The use of these + // two types is an implementation detail. + static_assert(!std::is_volatile_v, + "From must not be volatile qualified"); + static_assert(std::is_class_v>, + "From must be a non-union class"); + return std::forward(from).template As(); + } + + // Returns a pointer. + template + ABSL_DEPRECATED("Use As member functions instead.") + ABSL_MUST_USE_RESULT decltype(auto) operator()(From* from) const { + // Returns either `absl::optional` or `To*` depending on the return type of + // `CastTraits::Convert`. The use of these two types is an implementation + // detail. + static_assert(!std::is_volatile_v, + "From must not be volatile qualified"); + static_assert(std::is_class_v, "From must be a non-union class"); + using R = decltype(from->template As()); + if (from == nullptr) { + return R{absl::nullopt}; + } + return from->template As(); + } +}; + +} // namespace common_internal + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_INTERNAL_CASTING_H_ diff --git a/common/internal/metadata.h b/common/internal/metadata.h new file mode 100644 index 000000000..5d2fa8322 --- /dev/null +++ b/common/internal/metadata.h @@ -0,0 +1,41 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_INTERNAL_METADATA_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_INTERNAL_METADATA_H_ + +#include + +#include "google/protobuf/arena.h" + +namespace cel::common_internal { + +// `google::protobuf::Arena` has a minimum alignment of 8. `ReferenceCount` has a minimum +// alignment that is guaranteed to be greater than or equal to `google::protobuf::Arena`. +inline constexpr uintptr_t kMetadataOwnerNone = 0; +inline constexpr uintptr_t kMetadataOwnerReferenceCountBit = uintptr_t{1} << 0; +inline constexpr uintptr_t kMetadataOwnerArenaBit = uintptr_t{1} << 1; +inline constexpr uintptr_t kMetadataOwnerBits = alignof(google::protobuf::Arena) - 1; +inline constexpr uintptr_t kMetadataOwnerPointerMask = ~kMetadataOwnerBits; + +// Ensure kMetadataOwnerBits encompasses kMetadataOwnerReferenceCountBit and +// kMetadataOwnerArenaBit. +static_assert((kMetadataOwnerBits | kMetadataOwnerReferenceCountBit) == + kMetadataOwnerBits); +static_assert((kMetadataOwnerBits | kMetadataOwnerArenaBit) == + kMetadataOwnerBits); + +} // namespace cel::common_internal + +#endif // THIRD_PARTY_CEL_CPP_COMMON_INTERNAL_METADATA_H_ diff --git a/common/internal/reference_count.cc b/common/internal/reference_count.cc new file mode 100644 index 000000000..c954c685e --- /dev/null +++ b/common/internal/reference_count.cc @@ -0,0 +1,118 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/internal/reference_count.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/strings/string_view.h" +#include "common/data.h" +#include "internal/new.h" +#include "google/protobuf/message_lite.h" + +namespace cel::common_internal { + +template class DeletingReferenceCount; + +namespace { + +class ReferenceCountedStdString final : public ReferenceCounted { + public: + static std::pair New( + std::string&& string) { + const auto* const refcount = + new ReferenceCountedStdString(std::move(string)); + const auto* const refcount_string = std::launder( + reinterpret_cast(&refcount->string_[0])); + return std::pair{static_cast(refcount), + absl::string_view(*refcount_string)}; + } + + explicit ReferenceCountedStdString(std::string&& string) { + (::new (static_cast(&string_[0])) std::string(std::move(string))) + ->shrink_to_fit(); + } + + private: + void Finalize() noexcept override { + std::destroy_at(std::launder(reinterpret_cast(&string_[0]))); + } + + alignas(std::string) char string_[sizeof(std::string)]; +}; + +class ReferenceCountedString final : public ReferenceCounted { + public: + static std::pair New( + absl::string_view string) { + const auto* const refcount = + ::new (internal::New(Overhead() + string.size())) + ReferenceCountedString(string); + return std::pair{static_cast(refcount), + absl::string_view(refcount->data_, refcount->size_)}; + } + + private: +// ReferenceCountedString is non-standard-layout due to having virtual functions +// from a base class. This causes compilers to warn about the use of offsetof(), +// but it still works here, so silence the warning and proceed. +#if defined(__GNUC__) || defined(__clang__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Winvalid-offsetof" +#endif + + static size_t Overhead() { return offsetof(ReferenceCountedString, data_); } + +#if defined(__GNUC__) || defined(__clang__) +#pragma GCC diagnostic pop +#endif + + explicit ReferenceCountedString(absl::string_view string) + : size_(string.size()) { + std::memcpy(data_, string.data(), size_); + } + + void Delete() noexcept override { + void* const that = this; + const auto size = size_; + std::destroy_at(this); + internal::SizedDelete(that, Overhead() + size); + } + + const size_t size_; + char data_[]; +}; + +} // namespace + +std::pair +MakeReferenceCountedString(absl::string_view value) { + ABSL_DCHECK(!value.empty()); + return ReferenceCountedString::New(value); +} + +std::pair +MakeReferenceCountedString(std::string&& value) { + ABSL_DCHECK(!value.empty()); + return ReferenceCountedStdString::New(std::move(value)); +} + +} // namespace cel::common_internal diff --git a/common/internal/reference_count.h b/common/internal/reference_count.h new file mode 100644 index 000000000..9c7fb5371 --- /dev/null +++ b/common/internal/reference_count.h @@ -0,0 +1,406 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This header contains primitives for reference counting, roughly equivalent to +// the primitives used to implement `std::shared_ptr`. These primitives should +// not be used directly in most cases, instead `cel::Shared` should be +// used instead. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_INTERNAL_REFERENCE_COUNT_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_INTERNAL_REFERENCE_COUNT_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/base/optimization.h" +#include "absl/log/absl_check.h" +#include "absl/strings/string_view.h" +#include "common/data.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/message_lite.h" + +namespace cel::common_internal { + +struct AdoptRef final { + explicit AdoptRef() = default; +}; + +inline constexpr AdoptRef kAdoptRef{}; + +class ReferenceCount; +struct ReferenceCountFromThis; + +void SetReferenceCountForThat(ReferenceCountFromThis& that, + ReferenceCount* absl_nullable refcount); + +ReferenceCount* absl_nullable GetReferenceCountForThat( + const ReferenceCountFromThis& that); + +// `ReferenceCountFromThis` is similar to `std::enable_shared_from_this`. It +// allows the derived object to inspect its own reference count. It should not +// be used directly, but should be used through +// `cel::EnableManagedMemoryFromThis`. +struct ReferenceCountFromThis { + private: + friend void SetReferenceCountForThat(ReferenceCountFromThis& that, + ReferenceCount* absl_nullable refcount); + friend ReferenceCount* absl_nullable GetReferenceCountForThat( + const ReferenceCountFromThis& that); + + static constexpr uintptr_t kNullPtr = uintptr_t{0}; + static constexpr uintptr_t kSentinelPtr = ~kNullPtr; + + void* absl_nullable refcount = reinterpret_cast(kSentinelPtr); +}; + +inline void SetReferenceCountForThat(ReferenceCountFromThis& that, + ReferenceCount* absl_nullable refcount) { + ABSL_DCHECK_EQ(that.refcount, + reinterpret_cast(ReferenceCountFromThis::kSentinelPtr)); + that.refcount = static_cast(refcount); +} + +inline ReferenceCount* absl_nullable GetReferenceCountForThat( + const ReferenceCountFromThis& that) { + ABSL_DCHECK_NE(that.refcount, + reinterpret_cast(ReferenceCountFromThis::kSentinelPtr)); + return static_cast(that.refcount); +} + +void StrongRef(const ReferenceCount& refcount) noexcept; + +void StrongRef(const ReferenceCount* absl_nullable refcount) noexcept; + +void StrongUnref(const ReferenceCount& refcount) noexcept; + +void StrongUnref(const ReferenceCount* absl_nullable refcount) noexcept; + +ABSL_MUST_USE_RESULT +bool StrengthenRef(const ReferenceCount& refcount) noexcept; + +ABSL_MUST_USE_RESULT +bool StrengthenRef(const ReferenceCount* absl_nullable refcount) noexcept; + +void WeakRef(const ReferenceCount& refcount) noexcept; + +void WeakRef(const ReferenceCount* absl_nullable refcount) noexcept; + +void WeakUnref(const ReferenceCount& refcount) noexcept; + +void WeakUnref(const ReferenceCount* absl_nullable refcount) noexcept; + +ABSL_MUST_USE_RESULT +bool IsUniqueRef(const ReferenceCount& refcount) noexcept; + +ABSL_MUST_USE_RESULT +bool IsUniqueRef(const ReferenceCount* absl_nullable refcount) noexcept; + +ABSL_MUST_USE_RESULT +bool IsExpiredRef(const ReferenceCount& refcount) noexcept; + +ABSL_MUST_USE_RESULT +bool IsExpiredRef(const ReferenceCount* absl_nullable refcount) noexcept; + +// `ReferenceCount` is similar to the control block used by `std::shared_ptr`. +// It is not meant to be interacted with directly in most cases, instead +// `cel::Shared` should be used. +class alignas(8) ReferenceCount { + public: + ReferenceCount() = default; + + ReferenceCount(const ReferenceCount&) = delete; + ReferenceCount(ReferenceCount&&) = delete; + ReferenceCount& operator=(const ReferenceCount&) = delete; + ReferenceCount& operator=(ReferenceCount&&) = delete; + + virtual ~ReferenceCount() = default; + + private: + friend void StrongRef(const ReferenceCount& refcount) noexcept; + friend void StrongUnref(const ReferenceCount& refcount) noexcept; + friend bool StrengthenRef(const ReferenceCount& refcount) noexcept; + friend void WeakRef(const ReferenceCount& refcount) noexcept; + friend void WeakUnref(const ReferenceCount& refcount) noexcept; + friend bool IsUniqueRef(const ReferenceCount& refcount) noexcept; + friend bool IsExpiredRef(const ReferenceCount& refcount) noexcept; + + virtual void Finalize() noexcept = 0; + + virtual void Delete() noexcept = 0; + + mutable std::atomic strong_refcount_ = 1; + mutable std::atomic weak_refcount_ = 1; +}; + +// ReferenceCount and its derivations must be at least as aligned as +// google::protobuf::Arena. This is a requirement for the pointer tagging defined in +// common/internal/metadata.h. +static_assert(alignof(ReferenceCount) >= alignof(google::protobuf::Arena)); + +// `ReferenceCounted` is a base class for classes which should be reference +// counted. It provides default implementations for `Finalize()` and `Delete()`. +class ReferenceCounted : public ReferenceCount { + private: + void Finalize() noexcept override {} + + void Delete() noexcept override { delete this; } +}; + +// `EmplacedReferenceCount` adapts `T` to make it reference countable, by +// storing `T` inside the reference count. This only works when `T` has not yet +// been allocated. +template +class EmplacedReferenceCount final : public ReferenceCounted { + public: + static_assert(std::is_destructible_v, "T must be destructible"); + static_assert(!std::is_reference_v, "T must not be a reference"); + static_assert(!std::is_volatile_v, "T must not be volatile qualified"); + static_assert(!std::is_const_v, "T must not be const qualified"); + static_assert(!std::is_array_v, "T must not be an array"); + + template + explicit EmplacedReferenceCount(T*& value, Args&&... args) noexcept( + std::is_nothrow_constructible_v) { + value = + ::new (static_cast(&value_[0])) T(std::forward(args)...); + } + + private: + void Finalize() noexcept override { + std::destroy_at(std::launder(reinterpret_cast(&value_[0]))); + } + + // We store the instance of `T` in a char buffer and use placement new and + // direct calls to the destructor. The reason for this is `Finalize()` is + // called when the strong reference count hits 0. This allows us to destroy + // our instance of `T` once we are no longer strongly reachable and deallocate + // the memory once we are no longer weakly reachable. + alignas(T) char value_[sizeof(T)]; +}; + +// `DeletingReferenceCount` adapts `T` to make it reference countable, by taking +// ownership of `T` and deleting it. This only works when `T` has already been +// allocated and is to expensive to move or copy. +template +class DeletingReferenceCount final : public ReferenceCounted { + public: + explicit DeletingReferenceCount(const T* absl_nonnull to_delete) noexcept + : to_delete_(to_delete) {} + + private: + void Finalize() noexcept override { delete to_delete_; } + + const T* absl_nonnull const to_delete_; +}; + +extern template class DeletingReferenceCount; + +template +const ReferenceCount* absl_nonnull MakeDeletingReferenceCount( + const T* absl_nonnull to_delete) { + if constexpr (google::protobuf::Arena::is_arena_constructable::value) { + ABSL_DCHECK_EQ(to_delete->GetArena(), nullptr); + } + if constexpr (std::is_base_of_v) { + return new DeletingReferenceCount(to_delete); + } else { + auto* refcount = new DeletingReferenceCount(to_delete); + if constexpr (std::is_base_of_v) { + common_internal::SetDataReferenceCount(to_delete, refcount); + } + return refcount; + } +} + +template +std::pair +MakeEmplacedReferenceCount(Args&&... args) { + using U = std::remove_const_t; + U* pointer; + auto* const refcount = + new EmplacedReferenceCount(pointer, std::forward(args)...); + if constexpr (google::protobuf::Arena::is_arena_constructable::value) { + ABSL_DCHECK_EQ(pointer->GetArena(), nullptr); + } + if constexpr (std::is_base_of_v) { + common_internal::SetDataReferenceCount(pointer, refcount); + } + return std::pair{static_cast(pointer), + static_cast(refcount)}; +} + +template +class InlinedReferenceCount final : public ReferenceCounted { + public: + template + explicit InlinedReferenceCount(std::in_place_t, Args&&... args) + : ReferenceCounted() { + ::new (static_cast(value())) T(std::forward(args)...); + } + + ABSL_ATTRIBUTE_ALWAYS_INLINE T* absl_nonnull value() { + return reinterpret_cast(&value_[0]); + } + + ABSL_ATTRIBUTE_ALWAYS_INLINE const T* absl_nonnull value() const { + return reinterpret_cast(&value_[0]); + } + + private: + void Finalize() noexcept override { value()->~T(); } + + // We store the instance of `T` in a char buffer and use placement new and + // direct calls to the destructor. The reason for this is `Finalize()` is + // called when the strong reference count hits 0. This allows us to destroy + // our instance of `T` once we are no longer strongly reachable and deallocate + // the memory once we are no longer weakly reachable. + alignas(T) char value_[sizeof(T)]; +}; + +template +std::pair MakeReferenceCount( + Args&&... args) { + using U = std::remove_const_t; + auto* const refcount = + new InlinedReferenceCount(std::in_place, std::forward(args)...); + auto* const pointer = refcount->value(); + if constexpr (std::is_base_of_v) { + SetReferenceCountForThat(*pointer, refcount); + } + return std::make_pair(static_cast(pointer), + static_cast(refcount)); +} + +inline void StrongRef(const ReferenceCount& refcount) noexcept { + const auto count = + refcount.strong_refcount_.fetch_add(1, std::memory_order_relaxed); + ABSL_DCHECK_GT(count, 0); +} + +inline void StrongRef(const ReferenceCount* absl_nullable refcount) noexcept { + if (refcount != nullptr) { + StrongRef(*refcount); + } +} + +inline void StrongUnref(const ReferenceCount& refcount) noexcept { + const auto count = + refcount.strong_refcount_.fetch_sub(1, std::memory_order_acq_rel); + ABSL_DCHECK_GT(count, 0); + ABSL_ASSUME(count > 0); + if (ABSL_PREDICT_FALSE(count == 1)) { + const_cast(refcount).Finalize(); + WeakUnref(refcount); + } +} + +inline void StrongUnref(const ReferenceCount* absl_nullable refcount) noexcept { + if (refcount != nullptr) { + StrongUnref(*refcount); + } +} + +ABSL_MUST_USE_RESULT +inline bool StrengthenRef(const ReferenceCount& refcount) noexcept { + auto count = refcount.strong_refcount_.load(std::memory_order_relaxed); + while (true) { + ABSL_DCHECK_GE(count, 0); + ABSL_ASSUME(count >= 0); + if (count == 0) { + return false; + } + if (refcount.strong_refcount_.compare_exchange_weak( + count, count + 1, std::memory_order_release, + std::memory_order_relaxed)) { + return true; + } + } +} + +ABSL_MUST_USE_RESULT +inline bool StrengthenRef( + const ReferenceCount* absl_nullable refcount) noexcept { + return refcount != nullptr ? StrengthenRef(*refcount) : false; +} + +inline void WeakRef(const ReferenceCount& refcount) noexcept { + const auto count = + refcount.weak_refcount_.fetch_add(1, std::memory_order_relaxed); + ABSL_DCHECK_GT(count, 0); +} + +inline void WeakRef(const ReferenceCount* absl_nullable refcount) noexcept { + if (refcount != nullptr) { + WeakRef(*refcount); + } +} + +inline void WeakUnref(const ReferenceCount& refcount) noexcept { + const auto count = + refcount.weak_refcount_.fetch_sub(1, std::memory_order_acq_rel); + ABSL_DCHECK_GT(count, 0); + ABSL_ASSUME(count > 0); + if (ABSL_PREDICT_FALSE(count == 1)) { + const_cast(refcount).Delete(); + } +} + +inline void WeakUnref(const ReferenceCount* absl_nullable refcount) noexcept { + if (refcount != nullptr) { + WeakUnref(*refcount); + } +} + +ABSL_MUST_USE_RESULT +inline bool IsUniqueRef(const ReferenceCount& refcount) noexcept { + const auto count = refcount.strong_refcount_.load(std::memory_order_acquire); + ABSL_DCHECK_GT(count, 0); + ABSL_ASSUME(count > 0); + return count == 1; +} + +ABSL_MUST_USE_RESULT +inline bool IsUniqueRef(const ReferenceCount* absl_nullable refcount) noexcept { + return refcount != nullptr ? IsUniqueRef(*refcount) : false; +} + +ABSL_MUST_USE_RESULT +inline bool IsExpiredRef(const ReferenceCount& refcount) noexcept { + const auto count = refcount.strong_refcount_.load(std::memory_order_acquire); + ABSL_DCHECK_GE(count, 0); + ABSL_ASSUME(count >= 0); + return count == 0; +} + +ABSL_MUST_USE_RESULT +inline bool IsExpiredRef( + const ReferenceCount* absl_nullable refcount) noexcept { + return refcount != nullptr ? IsExpiredRef(*refcount) : false; +} + +std::pair +MakeReferenceCountedString(absl::string_view value); + +std::pair +MakeReferenceCountedString(std::string&& value); + +} // namespace cel::common_internal + +#endif // THIRD_PARTY_CEL_CPP_COMMON_INTERNAL_REFERENCE_COUNT_H_ diff --git a/common/internal/reference_count_test.cc b/common/internal/reference_count_test.cc new file mode 100644 index 000000000..af36fa9a5 --- /dev/null +++ b/common/internal/reference_count_test.cc @@ -0,0 +1,162 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/internal/reference_count.h" + +#include + +#include "google/protobuf/struct.pb.h" +#include "absl/base/nullability.h" +#include "common/data.h" +#include "internal/testing.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/message_lite.h" + +namespace cel::common_internal { +namespace { + +using ::testing::NotNull; +using ::testing::WhenDynamicCastTo; + +class Object : public virtual ReferenceCountFromThis { + public: + explicit Object(bool& destructed) : destructed_(destructed) {} + + ~Object() { destructed_ = true; } + + private: + bool& destructed_; +}; + +class Subobject : public Object, public virtual ReferenceCountFromThis { + public: + using Object::Object; +}; + +TEST(ReferenceCount, Strong) { + bool destructed = false; + Object* object; + ReferenceCount* refcount; + std::tie(object, refcount) = MakeReferenceCount(destructed); + EXPECT_EQ(GetReferenceCountForThat(*object), refcount); + EXPECT_EQ(GetReferenceCountForThat(*static_cast(object)), + refcount); + StrongRef(refcount); + StrongUnref(refcount); + EXPECT_TRUE(IsUniqueRef(refcount)); + EXPECT_FALSE(IsExpiredRef(refcount)); + EXPECT_FALSE(destructed); + StrongUnref(refcount); + EXPECT_TRUE(destructed); +} + +TEST(ReferenceCount, Weak) { + bool destructed = false; + Object* object; + ReferenceCount* refcount; + std::tie(object, refcount) = MakeReferenceCount(destructed); + EXPECT_EQ(GetReferenceCountForThat(*object), refcount); + EXPECT_EQ(GetReferenceCountForThat(*static_cast(object)), + refcount); + WeakRef(refcount); + ASSERT_TRUE(StrengthenRef(refcount)); + StrongUnref(refcount); + EXPECT_TRUE(IsUniqueRef(refcount)); + EXPECT_FALSE(IsExpiredRef(refcount)); + EXPECT_FALSE(destructed); + StrongUnref(refcount); + EXPECT_TRUE(destructed); + EXPECT_TRUE(IsExpiredRef(refcount)); + ASSERT_FALSE(StrengthenRef(refcount)); + WeakUnref(refcount); +} + +class DataObject final : public Data { + public: + DataObject() noexcept : Data() {} + + explicit DataObject(google::protobuf::Arena* absl_nullable arena) noexcept + : Data(arena) {} + + char member_[17]; +}; + +struct OtherObject final { + char data[17]; +}; + +TEST(DeletingReferenceCount, Data) { + auto* data = new DataObject(); + const auto* refcount = MakeDeletingReferenceCount(data); + EXPECT_THAT( + refcount, + WhenDynamicCastTo*>(NotNull())); + EXPECT_EQ(common_internal::GetDataReferenceCount(data), refcount); + StrongUnref(refcount); +} + +TEST(DeletingReferenceCount, MessageLite) { + auto* message_lite = new google::protobuf::Value(); + const auto* refcount = MakeDeletingReferenceCount(message_lite); + EXPECT_THAT( + refcount, + WhenDynamicCastTo*>( + NotNull())); + StrongUnref(refcount); +} + +TEST(DeletingReferenceCount, Other) { + auto* other = new OtherObject(); + const auto* refcount = MakeDeletingReferenceCount(other); + EXPECT_THAT( + refcount, + WhenDynamicCastTo*>(NotNull())); + StrongUnref(refcount); +} + +TEST(EmplacedReferenceCount, Data) { + Data* data; + const ReferenceCount* refcount; + std::tie(data, refcount) = MakeEmplacedReferenceCount(); + EXPECT_THAT( + refcount, + WhenDynamicCastTo*>(NotNull())); + EXPECT_EQ(common_internal::GetDataReferenceCount(data), refcount); + StrongUnref(refcount); +} + +TEST(EmplacedReferenceCount, MessageLite) { + google::protobuf::Value* message_lite; + const ReferenceCount* refcount; + std::tie(message_lite, refcount) = + MakeEmplacedReferenceCount(); + EXPECT_THAT( + refcount, + WhenDynamicCastTo*>( + NotNull())); + StrongUnref(refcount); +} + +TEST(EmplacedReferenceCount, Other) { + OtherObject* other; + const ReferenceCount* refcount; + std::tie(other, refcount) = MakeEmplacedReferenceCount(); + EXPECT_THAT( + refcount, + WhenDynamicCastTo*>(NotNull())); + StrongUnref(refcount); +} + +} // namespace +} // namespace cel::common_internal diff --git a/common/internal/value_conversion.cc b/common/internal/value_conversion.cc new file mode 100644 index 000000000..57cf2224b --- /dev/null +++ b/common/internal/value_conversion.cc @@ -0,0 +1,321 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "common/internal/value_conversion.h" + +#include +#include + +#include "cel/expr/value.pb.h" +#include "google/protobuf/any.pb.h" +#include "google/protobuf/struct.pb.h" +#include "google/protobuf/timestamp.pb.h" +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "common/any.h" +#include "common/value.h" +#include "common/value_kind.h" +#include "extensions/protobuf/value.h" +#include "internal/proto_time_encoding.h" +#include "internal/status_macros.h" +#include "internal/time.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/zero_copy_stream_impl_lite.h" +#include "google/protobuf/message.h" + +namespace cel::test { +namespace { + +using ExprValueKind = cel::expr::Value::KindCase; +using ExprMapValue = cel::expr::MapValue; +using ExprListValue = cel::expr::ListValue; + +std::string ToString(ExprValueKind kind_case) { + switch (kind_case) { + case ExprValueKind::kBoolValue: + return "bool_value"; + case ExprValueKind::kInt64Value: + return "int64_value"; + case ExprValueKind::kUint64Value: + return "uint64_value"; + case ExprValueKind::kDoubleValue: + return "double_value"; + case ExprValueKind::kStringValue: + return "string_value"; + case ExprValueKind::kBytesValue: + return "bytes_value"; + case ExprValueKind::kTypeValue: + return "type_value"; + case ExprValueKind::kEnumValue: + return "enum_value"; + case ExprValueKind::kMapValue: + return "map_value"; + case ExprValueKind::kListValue: + return "list_value"; + case ExprValueKind::kNullValue: + return "null_value"; + case ExprValueKind::kObjectValue: + return "object_value"; + default: + return "unknown kind case"; + } +} + +absl::StatusOr FromObject( + const google::protobuf::Any& any, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + if (any.type_url() == "type.googleapis.com/google.protobuf.Duration") { + google::protobuf::Duration duration; + if (!any.UnpackTo(&duration)) { + return absl::InvalidArgumentError("invalid duration"); + } + absl::Duration d = internal::DecodeDuration(duration); + CEL_RETURN_IF_ERROR(cel::internal::ValidateDuration(d)); + return cel::DurationValue(d); + } else if (any.type_url() == + "type.googleapis.com/google.protobuf.Timestamp") { + google::protobuf::Timestamp timestamp; + if (!any.UnpackTo(×tamp)) { + return absl::InvalidArgumentError("invalid timestamp"); + } + absl::Time time = internal::DecodeTime(timestamp); + CEL_RETURN_IF_ERROR(cel::internal::ValidateTimestamp(time)); + return cel::TimestampValue(time); + } + + return extensions::ProtoMessageToValue(any, descriptor_pool, message_factory, + arena); +} + +absl::StatusOr MapValueFromExpr( + const ExprMapValue& map_value, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + auto builder = cel::NewMapValueBuilder(arena); + for (const auto& entry : map_value.entries()) { + CEL_ASSIGN_OR_RETURN(auto key, + FromExprValue(entry.key(), descriptor_pool, + message_factory, arena)); + CEL_ASSIGN_OR_RETURN(auto value, + FromExprValue(entry.value(), descriptor_pool, + message_factory, arena)); + CEL_RETURN_IF_ERROR(builder->Put(std::move(key), std::move(value))); + } + + return std::move(*builder).Build(); +} + +absl::StatusOr ListValueFromExpr( + const ExprListValue& list_value, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + auto builder = cel::NewListValueBuilder(arena); + for (const auto& elem : list_value.values()) { + CEL_ASSIGN_OR_RETURN( + auto value, + FromExprValue(elem, descriptor_pool, message_factory, arena)); + CEL_RETURN_IF_ERROR(builder->Add(std::move(value))); + } + + return std::move(*builder).Build(); +} + +absl::StatusOr MapValueToExpr( + const MapValue& map_value, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + ExprMapValue result; + + CEL_ASSIGN_OR_RETURN(auto iter, map_value.NewIterator()); + + while (iter->HasNext()) { + CEL_ASSIGN_OR_RETURN(auto key_value, + iter->Next(descriptor_pool, message_factory, arena)); + CEL_ASSIGN_OR_RETURN( + auto value_value, + map_value.Get(key_value, descriptor_pool, message_factory, arena)); + + CEL_ASSIGN_OR_RETURN( + auto key, + ToExprValue(key_value, descriptor_pool, message_factory, arena)); + CEL_ASSIGN_OR_RETURN(auto value, + ToExprValue(value_value, descriptor_pool, + message_factory, arena)); + + auto* entry = result.add_entries(); + + *entry->mutable_key() = std::move(key); + *entry->mutable_value() = std::move(value); + } + + return result; +} + +absl::StatusOr ListValueToExpr( + const ListValue& list_value, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + ExprListValue result; + + CEL_ASSIGN_OR_RETURN(auto iter, list_value.NewIterator()); + + while (iter->HasNext()) { + CEL_ASSIGN_OR_RETURN(auto elem, + iter->Next(descriptor_pool, message_factory, arena)); + CEL_ASSIGN_OR_RETURN( + *result.add_values(), + ToExprValue(elem, descriptor_pool, message_factory, arena)); + } + + return result; +} + +absl::StatusOr ToProtobufAny( + const StructValue& struct_value, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + google::protobuf::io::CordOutputStream serialized; + CEL_RETURN_IF_ERROR( + struct_value.SerializeTo(descriptor_pool, message_factory, &serialized)); + google::protobuf::Any result; + result.set_type_url(MakeTypeUrl(struct_value.GetTypeName())); + result.set_value(std::string(std::move(serialized).Consume())); + + return result; +} + +} // namespace + +absl::StatusOr FromExprValue( + const cel::expr::Value& value, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + google::protobuf::LinkMessageReflection(); + switch (value.kind_case()) { + case ExprValueKind::kBoolValue: + return cel::BoolValue(value.bool_value()); + case ExprValueKind::kInt64Value: + return cel::IntValue(value.int64_value()); + case ExprValueKind::kUint64Value: + return cel::UintValue(value.uint64_value()); + case ExprValueKind::kDoubleValue: + return cel::DoubleValue(value.double_value()); + case ExprValueKind::kStringValue: + return cel::StringValue(value.string_value()); + case ExprValueKind::kBytesValue: + return cel::BytesValue(value.bytes_value()); + case ExprValueKind::kNullValue: + return cel::NullValue(); + case ExprValueKind::kObjectValue: + return FromObject(value.object_value(), descriptor_pool, message_factory, + arena); + case ExprValueKind::kMapValue: + return MapValueFromExpr(value.map_value(), descriptor_pool, + message_factory, arena); + case ExprValueKind::kListValue: + return ListValueFromExpr(value.list_value(), descriptor_pool, + message_factory, arena); + + default: + return absl::UnimplementedError(absl::StrCat( + "FromExprValue not supported ", ToString(value.kind_case()))); + } +} + +absl::StatusOr ToExprValue( + const Value& value, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + cel::expr::Value result; + switch (value->kind()) { + case ValueKind::kBool: + result.set_bool_value(value.GetBool().NativeValue()); + break; + case ValueKind::kInt: + result.set_int64_value(value.GetInt().NativeValue()); + break; + case ValueKind::kUint: + result.set_uint64_value(value.GetUint().NativeValue()); + break; + case ValueKind::kDouble: + result.set_double_value(value.GetDouble().NativeValue()); + break; + case ValueKind::kString: + result.set_string_value(value.GetString().ToString()); + break; + case ValueKind::kBytes: + result.set_bytes_value(value.GetBytes().ToString()); + break; + case ValueKind::kType: + result.set_type_value(value.GetType().name()); + break; + case ValueKind::kNull: + result.set_null_value(google::protobuf::NullValue::NULL_VALUE); + break; + case ValueKind::kDuration: { + google::protobuf::Duration duration; + CEL_RETURN_IF_ERROR(internal::EncodeDuration( + value.GetDuration().NativeValue(), &duration)); + result.mutable_object_value()->PackFrom(duration); + break; + } + case ValueKind::kTimestamp: { + google::protobuf::Timestamp timestamp; + CEL_RETURN_IF_ERROR( + internal::EncodeTime(value.GetTimestamp().NativeValue(), ×tamp)); + result.mutable_object_value()->PackFrom(timestamp); + break; + } + case ValueKind::kMap: { + CEL_ASSIGN_OR_RETURN( + *result.mutable_map_value(), + MapValueToExpr(value.GetMap(), descriptor_pool, + message_factory, arena)); + break; + } + case ValueKind::kList: { + CEL_ASSIGN_OR_RETURN( + *result.mutable_list_value(), + ListValueToExpr(value.GetList(), descriptor_pool, + message_factory, arena)); + break; + } + case ValueKind::kStruct: { + CEL_ASSIGN_OR_RETURN(*result.mutable_object_value(), + ToProtobufAny(value.GetStruct(), descriptor_pool, + message_factory, arena)); + break; + } + default: + return absl::UnimplementedError( + absl::StrCat("ToExprValue not supported ", + ValueKindToString(value->kind()))); + } + return result; +} + +} // namespace cel::test diff --git a/common/internal/value_conversion.h b/common/internal/value_conversion.h new file mode 100644 index 000000000..a25b30a39 --- /dev/null +++ b/common/internal/value_conversion.h @@ -0,0 +1,115 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Converters to/from serialized Value to/from runtime values. +#ifndef THIRD_PARTY_CEL_CPP_COMMON_INTERNAL_VALUE_CONVERSION_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_INTERNAL_VALUE_CONVERSION_H_ + +#include "cel/expr/checked.pb.h" +#include "cel/expr/syntax.pb.h" +#include "google/api/expr/v1alpha1/checked.pb.h" +#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "google/api/expr/v1alpha1/value.pb.h" +#include "cel/expr/value.pb.h" +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "common/value.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" +#include "google/protobuf/message_lite.h" + +// TODO(uncreated-issue/84): Clean up and expose cel::expr::Value converters +// in the common folder. +namespace cel::test { + +ABSL_MUST_USE_RESULT +inline bool UnsafeConvertWireCompatProto( + const google::protobuf::MessageLite& src, google::protobuf::MessageLite* absl_nonnull dest) { + absl::Cord serialized; + return src.SerializePartialToCord(&serialized) && + dest->ParsePartialFromCord(serialized); +} + +ABSL_MUST_USE_RESULT +inline bool ConvertWireCompatProto( + const cel::expr::CheckedExpr& src, + google::api::expr::v1alpha1::CheckedExpr* absl_nonnull dest) { + return UnsafeConvertWireCompatProto(src, dest); +} + +ABSL_MUST_USE_RESULT +inline bool ConvertWireCompatProto( + const google::api::expr::v1alpha1::CheckedExpr& src, + cel::expr::CheckedExpr* absl_nonnull dest) { + return UnsafeConvertWireCompatProto(src, dest); +} + +ABSL_MUST_USE_RESULT +inline bool ConvertWireCompatProto( + const cel::expr::ParsedExpr& src, + google::api::expr::v1alpha1::ParsedExpr* absl_nonnull dest) { + return UnsafeConvertWireCompatProto(src, dest); +} + +ABSL_MUST_USE_RESULT +inline bool ConvertWireCompatProto( + const google::api::expr::v1alpha1::ParsedExpr& src, + cel::expr::ParsedExpr* absl_nonnull dest) { + return UnsafeConvertWireCompatProto(src, dest); +} + +ABSL_MUST_USE_RESULT +inline bool ConvertWireCompatProto( + const cel::expr::Expr& src, + google::api::expr::v1alpha1::Expr* absl_nonnull dest) { + return UnsafeConvertWireCompatProto(src, dest); +} + +ABSL_MUST_USE_RESULT +inline bool ConvertWireCompatProto(const google::api::expr::v1alpha1::Expr& src, + cel::expr::Expr* absl_nonnull dest) { + return UnsafeConvertWireCompatProto(src, dest); +} + +ABSL_MUST_USE_RESULT +inline bool ConvertWireCompatProto( + const cel::expr::Value& src, + google::api::expr::v1alpha1::Value* absl_nonnull dest) { + return UnsafeConvertWireCompatProto(src, dest); +} + +ABSL_MUST_USE_RESULT +inline bool ConvertWireCompatProto( + const google::api::expr::v1alpha1::Value& src, + cel::expr::Value* absl_nonnull dest) { + return UnsafeConvertWireCompatProto(src, dest); +} + +absl::StatusOr FromExprValue( + const cel::expr::Value& value, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena); + +absl::StatusOr ToExprValue( + const Value& value, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena); + +} // namespace cel::test +#endif // THIRD_PARTY_CEL_CPP_COMMON_INTERNAL_VALUE_CONVERSION_H_ diff --git a/common/json.h b/common/json.h new file mode 100644 index 000000000..c51f434d5 --- /dev/null +++ b/common/json.h @@ -0,0 +1,35 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_JSON_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_JSON_H_ + +#include + +namespace cel { + +// Maximum `int64_t` value that can be represented as `double` without losing +// data. +inline constexpr int64_t kJsonMaxInt = (int64_t{1} << 53) - 1; +// Minimum `int64_t` value that can be represented as `double` without losing +// data. +inline constexpr int64_t kJsonMinInt = -kJsonMaxInt; + +// Maximum `uint64_t` value that can be represented as `double` without losing +// data. +inline constexpr uint64_t kJsonMaxUint = (uint64_t{1} << 53) - 1; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_JSON_H_ diff --git a/common/kind.cc b/common/kind.cc new file mode 100644 index 000000000..21fb9e9f3 --- /dev/null +++ b/common/kind.cc @@ -0,0 +1,80 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/kind.h" + +#include "absl/strings/string_view.h" + +namespace cel { + +absl::string_view KindToString(Kind kind) { + switch (kind) { + case Kind::kNullType: + return "null_type"; + case Kind::kDyn: + return "dyn"; + case Kind::kAny: + return "any"; + case Kind::kType: + return "type"; + case Kind::kTypeParam: + return "type_param"; + case Kind::kFunction: + return "function"; + case Kind::kBool: + return "bool"; + case Kind::kInt: + return "int"; + case Kind::kUint: + return "uint"; + case Kind::kDouble: + return "double"; + case Kind::kString: + return "string"; + case Kind::kBytes: + return "bytes"; + case Kind::kDuration: + return "duration"; + case Kind::kTimestamp: + return "timestamp"; + case Kind::kList: + return "list"; + case Kind::kMap: + return "map"; + case Kind::kStruct: + return "struct"; + case Kind::kUnknown: + return "*unknown*"; + case Kind::kOpaque: + return "*opaque*"; + case Kind::kBoolWrapper: + return "google.protobuf.BoolValue"; + case Kind::kIntWrapper: + return "google.protobuf.Int64Value"; + case Kind::kUintWrapper: + return "google.protobuf.UInt64Value"; + case Kind::kDoubleWrapper: + return "google.protobuf.DoubleValue"; + case Kind::kStringWrapper: + return "google.protobuf.StringValue"; + case Kind::kBytesWrapper: + return "google.protobuf.BytesValue"; + case Kind::kEnum: + return "enum"; + default: + return "*error*"; + } +} + +} // namespace cel diff --git a/common/kind.h b/common/kind.h new file mode 100644 index 000000000..c46fbdbaf --- /dev/null +++ b/common/kind.h @@ -0,0 +1,76 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_KIND_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_KIND_H_ + +#include + +#include "absl/base/attributes.h" +#include "absl/strings/string_view.h" + +namespace cel { + +enum class Kind : uint8_t { + // Must match legacy CelValue::Type. + kNull = 0, + kBool, + kInt, + kUint, + kDouble, + kString, + kBytes, + kStruct, + kDuration, + kTimestamp, + kList, + kMap, + kUnknown, + kType, + kError, + kAny, + + // New kinds not present in legacy CelValue. + kDyn, + kOpaque, + + kBoolWrapper, + kIntWrapper, + kUintWrapper, + kDoubleWrapper, + kStringWrapper, + kBytesWrapper, + + kTypeParam, + kFunction, + kEnum, + + // Legacy aliases, deprecated do not use. + kNullType = kNull, + kInt64 = kInt, + kUint64 = kUint, + kMessage = kStruct, + kUnknownSet = kUnknown, + kCelType = kType, + + // INTERNAL: Do not exceed 63. Implementation details rely on the fact that + // we can store `Kind` using 6 bits. + kNotForUseWithExhaustiveSwitchStatements = 63, +}; + +ABSL_ATTRIBUTE_PURE_FUNCTION absl::string_view KindToString(Kind kind); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_KIND_H_ diff --git a/common/kind_test.cc b/common/kind_test.cc new file mode 100644 index 000000000..3bd6db40e --- /dev/null +++ b/common/kind_test.cc @@ -0,0 +1,103 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/kind.h" + +#include +#include + +#include "common/type_kind.h" +#include "common/value_kind.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +static_assert(std::is_same_v, + std::underlying_type_t>, + "TypeKind and ValueKind must have the same underlying type"); + +TEST(Kind, ToString) { + EXPECT_EQ(KindToString(Kind::kError), "*error*"); + EXPECT_EQ(KindToString(Kind::kNullType), "null_type"); + EXPECT_EQ(KindToString(Kind::kDyn), "dyn"); + EXPECT_EQ(KindToString(Kind::kAny), "any"); + EXPECT_EQ(KindToString(Kind::kType), "type"); + EXPECT_EQ(KindToString(Kind::kBool), "bool"); + EXPECT_EQ(KindToString(Kind::kInt), "int"); + EXPECT_EQ(KindToString(Kind::kUint), "uint"); + EXPECT_EQ(KindToString(Kind::kDouble), "double"); + EXPECT_EQ(KindToString(Kind::kString), "string"); + EXPECT_EQ(KindToString(Kind::kBytes), "bytes"); + EXPECT_EQ(KindToString(Kind::kDuration), "duration"); + EXPECT_EQ(KindToString(Kind::kTimestamp), "timestamp"); + EXPECT_EQ(KindToString(Kind::kList), "list"); + EXPECT_EQ(KindToString(Kind::kMap), "map"); + EXPECT_EQ(KindToString(Kind::kStruct), "struct"); + EXPECT_EQ(KindToString(Kind::kUnknown), "*unknown*"); + EXPECT_EQ(KindToString(Kind::kOpaque), "*opaque*"); + EXPECT_EQ(KindToString(Kind::kBoolWrapper), "google.protobuf.BoolValue"); + EXPECT_EQ(KindToString(Kind::kIntWrapper), "google.protobuf.Int64Value"); + EXPECT_EQ(KindToString(Kind::kUintWrapper), "google.protobuf.UInt64Value"); + EXPECT_EQ(KindToString(Kind::kDoubleWrapper), "google.protobuf.DoubleValue"); + EXPECT_EQ(KindToString(Kind::kStringWrapper), "google.protobuf.StringValue"); + EXPECT_EQ(KindToString(Kind::kBytesWrapper), "google.protobuf.BytesValue"); + EXPECT_EQ(KindToString(static_cast(std::numeric_limits::max())), + "*error*"); +} + +TEST(Kind, TypeKindRoundtrip) { + EXPECT_EQ(TypeKindToKind(KindToTypeKind(Kind::kBool)), Kind::kBool); +} + +TEST(Kind, ValueKindRoundtrip) { + EXPECT_EQ(ValueKindToKind(KindToValueKind(Kind::kBool)), Kind::kBool); +} + +TEST(Kind, IsTypeKind) { + EXPECT_TRUE(KindIsTypeKind(Kind::kBool)); + EXPECT_TRUE(KindIsTypeKind(Kind::kAny)); + EXPECT_TRUE(KindIsTypeKind(Kind::kDyn)); +} + +TEST(Kind, IsValueKind) { + EXPECT_TRUE(KindIsValueKind(Kind::kBool)); + EXPECT_FALSE(KindIsValueKind(Kind::kAny)); + EXPECT_FALSE(KindIsValueKind(Kind::kDyn)); +} + +TEST(Kind, Equality) { + EXPECT_EQ(Kind::kBool, TypeKind::kBool); + EXPECT_EQ(TypeKind::kBool, Kind::kBool); + + EXPECT_EQ(Kind::kBool, ValueKind::kBool); + EXPECT_EQ(ValueKind::kBool, Kind::kBool); + + EXPECT_NE(Kind::kBool, TypeKind::kInt); + EXPECT_NE(TypeKind::kInt, Kind::kBool); + + EXPECT_NE(Kind::kBool, ValueKind::kInt); + EXPECT_NE(ValueKind::kInt, Kind::kBool); +} + +TEST(TypeKind, ToString) { + EXPECT_EQ(TypeKindToString(TypeKind::kBool), KindToString(Kind::kBool)); +} + +TEST(ValueKind, ToString) { + EXPECT_EQ(ValueKindToString(ValueKind::kBool), KindToString(Kind::kBool)); +} + +} // namespace +} // namespace cel diff --git a/common/legacy_value.cc b/common/legacy_value.cc new file mode 100644 index 000000000..7fbf16732 --- /dev/null +++ b/common/legacy_value.cc @@ -0,0 +1,1293 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/legacy_value.h" + +#include +#include +#include +#include +#include +#include + +#include "google/protobuf/struct.pb.h" +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/base/optimization.h" +#include "absl/functional/overload.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "absl/types/variant.h" +#include "base/attribute.h" +#include "common/casting.h" +#include "common/kind.h" +#include "common/memory.h" +#include "common/type.h" +#include "common/unknown.h" +#include "common/value.h" +#include "common/value_kind.h" +#include "common/values/list_value_builder.h" +#include "common/values/map_value_builder.h" +#include "common/values/values.h" +#include "eval/internal/cel_value_equal.h" +#include "eval/public/cel_value.h" +#include "eval/public/containers/field_backed_list_impl.h" +#include "eval/public/containers/field_backed_map_impl.h" +#include "eval/public/message_wrapper.h" +#include "eval/public/structs/cel_proto_wrap_util.h" +#include "eval/public/structs/legacy_type_adapter.h" +#include "eval/public/structs/legacy_type_info_apis.h" +#include "eval/public/structs/proto_message_type_adapter.h" +#include "internal/json.h" +#include "internal/status_macros.h" +#include "internal/well_known_types.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/message.h" +#include "google/protobuf/message_lite.h" + +// TODO(uncreated-issue/76): improve coverage for JSON/Any handling + +namespace cel { + +namespace { + +using google::api::expr::runtime::CelList; +using google::api::expr::runtime::CelMap; +using google::api::expr::runtime::CelValue; +using google::api::expr::runtime::FieldBackedListImpl; +using google::api::expr::runtime::FieldBackedMapImpl; +using google::api::expr::runtime::GetGenericProtoTypeInfoInstance; +using google::api::expr::runtime::LegacyTypeInfoApis; +using google::api::expr::runtime::MessageWrapper; +using ::google::api::expr::runtime::internal::MaybeWrapValueToMessage; + +absl::Status InvalidMapKeyTypeError(ValueKind kind) { + return absl::InvalidArgumentError( + absl::StrCat("Invalid map key type: '", ValueKindToString(kind), "'")); +} + +MessageWrapper AsMessageWrapper( + const google::protobuf::Message* absl_nullability_unknown message_ptr, + const LegacyTypeInfoApis* absl_nullability_unknown type_info) { + return MessageWrapper(message_ptr, type_info); +} + +class CelListIterator final : public ValueIterator { + public: + explicit CelListIterator(const CelList* cel_list) + : cel_list_(cel_list), size_(cel_list_->size()) {} + + bool HasNext() override { return index_ < size_; } + + absl::Status Next(const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) override { + if (!HasNext()) { + return absl::FailedPreconditionError( + "ValueIterator::Next() called when ValueIterator::HasNext() returns " + "false"); + } + auto cel_value = cel_list_->Get(arena, index_); + CEL_RETURN_IF_ERROR(ModernValue(arena, cel_value, *result)); + ++index_; + return absl::OkStatus(); + } + + absl::StatusOr Next1( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull key_or_value) override { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(key_or_value != nullptr); + + if (index_ >= size_) { + return false; + } + auto cel_value = cel_list_->Get(arena, index_); + CEL_RETURN_IF_ERROR(ModernValue(arena, cel_value, *key_or_value)); + ++index_; + return true; + } + + absl::StatusOr Next2( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull key, + Value* absl_nullable value) override { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(key != nullptr); + + if (index_ >= size_) { + return false; + } + if (value != nullptr) { + auto cel_value = cel_list_->Get(arena, index_); + CEL_RETURN_IF_ERROR(ModernValue(arena, cel_value, *value)); + } + *key = IntValue(index_); + ++index_; + return true; + } + + private: + const CelList* const cel_list_; + const int size_; + int index_ = 0; +}; + +class CelMapIterator final : public ValueIterator { + public: + explicit CelMapIterator(const CelMap* cel_map) + : cel_map_(cel_map), size_(cel_map->size()) {} + + bool HasNext() override { return index_ < size_; } + + absl::Status Next(const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) override { + if (!HasNext()) { + return absl::FailedPreconditionError( + "ValueIterator::Next() called when ValueIterator::HasNext() returns " + "false"); + } + CEL_RETURN_IF_ERROR(ProjectKeys(arena)); + auto cel_value = (*cel_list_)->Get(arena, index_); + CEL_RETURN_IF_ERROR(ModernValue(arena, cel_value, *result)); + ++index_; + return absl::OkStatus(); + } + + absl::StatusOr Next1( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull key_or_value) override { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(key_or_value != nullptr); + + if (index_ >= size_) { + return false; + } + CEL_RETURN_IF_ERROR(ProjectKeys(arena)); + auto cel_value = (*cel_list_)->Get(arena, index_); + CEL_RETURN_IF_ERROR(ModernValue(arena, cel_value, *key_or_value)); + ++index_; + return true; + } + + absl::StatusOr Next2( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull key, + Value* absl_nullable value) override { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(key != nullptr); + + if (index_ >= size_) { + return false; + } + CEL_RETURN_IF_ERROR(ProjectKeys(arena)); + auto cel_key = (*cel_list_)->Get(arena, index_); + if (value != nullptr) { + auto cel_value = cel_map_->Get(arena, cel_key); + if (!cel_value) { + return absl::DataLossError( + "map iterator returned key that was not present in the map"); + } + CEL_RETURN_IF_ERROR(ModernValue(arena, *cel_value, *value)); + } + CEL_RETURN_IF_ERROR(ModernValue(arena, cel_key, *key)); + ++index_; + return true; + } + + private: + absl::Status ProjectKeys(google::protobuf::Arena* arena) { + if (cel_list_.ok() && *cel_list_ == nullptr) { + cel_list_ = cel_map_->ListKeys(arena); + } + return cel_list_.status(); + } + + const CelMap* const cel_map_; + const int size_ = 0; + absl::StatusOr cel_list_ = nullptr; + int index_ = 0; +}; + +} // namespace + +namespace common_internal { + +namespace { + +CelValue LegacyTrivialStructValue(google::protobuf::Arena* absl_nonnull arena, + const Value& value) { + if (auto legacy_struct_value = common_internal::AsLegacyStructValue(value); + legacy_struct_value) { + return CelValue::CreateMessageWrapper( + AsMessageWrapper(legacy_struct_value->message_ptr(), + legacy_struct_value->legacy_type_info())); + } + if (auto parsed_message_value = value.AsParsedMessage(); + parsed_message_value) { + auto maybe_cloned = parsed_message_value->Clone(arena); + return CelValue::CreateMessageWrapper(MessageWrapper( + cel::to_address(maybe_cloned), &GetGenericProtoTypeInfoInstance())); + } + return CelValue::CreateError(google::protobuf::Arena::Create( + arena, absl::InvalidArgumentError(absl::StrCat( + "unsupported conversion from cel::StructValue to CelValue: ", + value.GetRuntimeType().DebugString())))); +} + +CelValue LegacyTrivialListValue(google::protobuf::Arena* absl_nonnull arena, + const Value& value) { + if (auto legacy_list_value = common_internal::AsLegacyListValue(value); + legacy_list_value) { + return CelValue::CreateList(legacy_list_value->cel_list()); + } + if (auto parsed_repeated_field_value = value.AsParsedRepeatedField(); + parsed_repeated_field_value) { + auto maybe_cloned = parsed_repeated_field_value->Clone(arena); + return CelValue::CreateList(google::protobuf::Arena::Create( + arena, &maybe_cloned.message(), maybe_cloned.field(), arena)); + } + if (auto parsed_json_list_value = value.AsParsedJsonList(); + parsed_json_list_value) { + auto maybe_cloned = parsed_json_list_value->Clone(arena); + return CelValue::CreateList(google::protobuf::Arena::Create( + arena, cel::to_address(maybe_cloned), + well_known_types::GetListValueReflectionOrDie( + maybe_cloned->GetDescriptor()) + .GetValuesDescriptor(), + arena)); + } + if (auto custom_list_value = value.AsCustomList(); custom_list_value) { + auto status_or_compat_list = common_internal::MakeCompatListValue( + *custom_list_value, google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), arena); + if (!status_or_compat_list.ok()) { + return CelValue::CreateError(google::protobuf::Arena::Create( + arena, std::move(status_or_compat_list).status())); + } + return CelValue::CreateList(*status_or_compat_list); + } + return CelValue::CreateError(google::protobuf::Arena::Create( + arena, absl::InvalidArgumentError(absl::StrCat( + "unsupported conversion from cel::ListValue to CelValue: ", + value.GetRuntimeType().DebugString())))); +} + +CelValue LegacyTrivialMapValue(google::protobuf::Arena* absl_nonnull arena, + const Value& value) { + if (auto legacy_map_value = common_internal::AsLegacyMapValue(value); + legacy_map_value) { + return CelValue::CreateMap(legacy_map_value->cel_map()); + } + if (auto parsed_map_field_value = value.AsParsedMapField(); + parsed_map_field_value) { + auto maybe_cloned = parsed_map_field_value->Clone(arena); + return CelValue::CreateMap(google::protobuf::Arena::Create( + arena, &maybe_cloned.message(), maybe_cloned.field(), arena)); + } + if (auto parsed_json_map_value = value.AsParsedJsonMap(); + parsed_json_map_value) { + auto maybe_cloned = parsed_json_map_value->Clone(arena); + return CelValue::CreateMap(google::protobuf::Arena::Create( + arena, cel::to_address(maybe_cloned), + well_known_types::GetStructReflectionOrDie( + maybe_cloned->GetDescriptor()) + .GetFieldsDescriptor(), + arena)); + } + if (auto custom_map_value = value.AsCustomMap(); custom_map_value) { + auto status_or_compat_map = common_internal::MakeCompatMapValue( + *custom_map_value, google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), arena); + if (!status_or_compat_map.ok()) { + return CelValue::CreateError(google::protobuf::Arena::Create( + arena, std::move(status_or_compat_map).status())); + } + return CelValue::CreateMap(*status_or_compat_map); + } + return CelValue::CreateError(google::protobuf::Arena::Create( + arena, absl::InvalidArgumentError(absl::StrCat( + "unsupported conversion from cel::MapValue to CelValue: ", + value.GetRuntimeType().DebugString())))); +} + +} // namespace + +google::api::expr::runtime::CelValue UnsafeLegacyValue( + const Value& value, bool stable, google::protobuf::Arena* absl_nonnull arena) { + switch (value.kind()) { + case ValueKind::kNull: + return CelValue::CreateNull(); + case ValueKind::kBool: + return CelValue::CreateBool(value.GetBool()); + case ValueKind::kInt: + return CelValue::CreateInt64(value.GetInt()); + case ValueKind::kUint: + return CelValue::CreateUint64(value.GetUint()); + case ValueKind::kDouble: + return CelValue::CreateDouble(value.GetDouble()); + case ValueKind::kString: + return CelValue::CreateStringView( + LegacyStringValue(value.GetString(), stable, arena)); + case ValueKind::kBytes: + return CelValue::CreateBytesView( + LegacyBytesValue(value.GetBytes(), stable, arena)); + case ValueKind::kStruct: + return LegacyTrivialStructValue(arena, value); + case ValueKind::kDuration: + return CelValue::CreateDuration(value.GetDuration().ToDuration()); + case ValueKind::kTimestamp: + return CelValue::CreateTimestamp(value.GetTimestamp().ToTime()); + case ValueKind::kList: + return LegacyTrivialListValue(arena, value); + case ValueKind::kMap: + return LegacyTrivialMapValue(arena, value); + case ValueKind::kType: + return CelValue::CreateCelTypeView(value.GetType().name()); + default: + // Everything else is unsupported. + return CelValue::CreateError(google::protobuf::Arena::Create( + arena, absl::InvalidArgumentError(absl::StrCat( + "unsupported conversion from cel::Value to CelValue: ", + value->GetRuntimeType().DebugString())))); + } +} + +} // namespace common_internal + +namespace common_internal { + +std::string LegacyListValue::DebugString() const { + return CelValue::CreateList(impl_).DebugString(); +} + +// See `ValueInterface::SerializeTo`. +absl::Status LegacyListValue::SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(output != nullptr); + + const google::protobuf::Descriptor* descriptor = + descriptor_pool->FindMessageTypeByName("google.protobuf.ListValue"); + if (descriptor == nullptr) { + return absl::InternalError( + "unable to locate descriptor for message type: " + "google.protobuf.ListValue"); + } + + google::protobuf::Arena arena; + const google::protobuf::Message* wrapped = MaybeWrapValueToMessage( + descriptor, message_factory, CelValue::CreateList(impl_), &arena); + if (wrapped == nullptr) { + return absl::UnknownError("failed to convert legacy map to JSON"); + } + if (!wrapped->SerializePartialToZeroCopyStream(output)) { + return absl::UnknownError( + absl::StrCat("failed to serialize message: ", wrapped->GetTypeName())); + } + return absl::OkStatus(); +} + +absl::Status LegacyListValue::ConvertToJson( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const { + { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(json != nullptr); + ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE); + + google::protobuf::Arena arena; + const google::protobuf::Message* wrapped = + MaybeWrapValueToMessage(json->GetDescriptor(), message_factory, + CelValue::CreateList(impl_), &arena); + if (wrapped == nullptr) { + return absl::UnknownError("failed to convert legacy list to JSON"); + } + + if (wrapped->GetDescriptor() == json->GetDescriptor()) { + // We can directly use google::protobuf::Message::Copy(). + json->CopyFrom(*wrapped); + } else { + // Equivalent descriptors but not identical. Must serialize and + // deserialize. + absl::Cord serialized; + if (!wrapped->SerializePartialToString(&serialized)) { + return absl::UnknownError(absl::StrCat("failed to serialize message: ", + wrapped->GetTypeName())); + } + if (!json->ParsePartialFromString(serialized)) { + return absl::UnknownError( + absl::StrCat("failed to parsed message: ", json->GetTypeName())); + } + } + return absl::OkStatus(); + } +} + +absl::Status LegacyListValue::ConvertToJsonArray( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const { + { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(json != nullptr); + ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE); + + google::protobuf::Arena arena; + const google::protobuf::Message* wrapped = + MaybeWrapValueToMessage(json->GetDescriptor(), message_factory, + CelValue::CreateList(impl_), &arena); + if (wrapped == nullptr) { + return absl::UnknownError("failed to convert legacy list to JSON"); + } + + if (wrapped->GetDescriptor() == json->GetDescriptor()) { + // We can directly use google::protobuf::Message::Copy(). + json->CopyFrom(*wrapped); + } else { + // Equivalent descriptors but not identical. Must serialize and + // deserialize. + absl::Cord serialized; + if (!wrapped->SerializePartialToString(&serialized)) { + return absl::UnknownError(absl::StrCat("failed to serialize message: ", + wrapped->GetTypeName())); + } + if (!json->ParsePartialFromString(serialized)) { + return absl::UnknownError( + absl::StrCat("failed to parsed message: ", json->GetTypeName())); + } + } + return absl::OkStatus(); + } +} + +bool LegacyListValue::IsEmpty() const { return impl_->empty(); } + +size_t LegacyListValue::Size() const { + return static_cast(impl_->size()); +} + +// See LegacyListValueInterface::Get for documentation. +absl::Status LegacyListValue::Get( + size_t index, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + if (ABSL_PREDICT_FALSE(index < 0 || index >= impl_->size())) { + *result = ErrorValue(absl::InvalidArgumentError("index out of bounds")); + return absl::OkStatus(); + } + CEL_RETURN_IF_ERROR( + ModernValue(arena, impl_->Get(arena, static_cast(index)), *result)); + return absl::OkStatus(); +} + +absl::Status LegacyListValue::ForEach( + ForEachWithIndexCallback callback, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const { + const auto size = impl_->size(); + Value element; + for (int index = 0; index < size; ++index) { + CEL_RETURN_IF_ERROR(ModernValue(arena, impl_->Get(arena, index), element)); + CEL_ASSIGN_OR_RETURN(auto ok, callback(index, Value(element))); + if (!ok) { + break; + } + } + return absl::OkStatus(); +} + +absl::StatusOr LegacyListValue::NewIterator() + const { + return std::make_unique(impl_); +} + +absl::Status LegacyListValue::Contains( + const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + CEL_ASSIGN_OR_RETURN(auto legacy_other, LegacyValue(arena, other)); + const auto* cel_list = impl_; + for (int i = 0; i < cel_list->size(); ++i) { + auto element = cel_list->Get(arena, i); + absl::optional equal = + interop_internal::CelValueEqualImpl(element, legacy_other); + // Heterogeneous equality behavior is to just return false if equality + // undefined. + if (equal.has_value() && *equal) { + *result = TrueValue(); + return absl::OkStatus(); + } + } + *result = FalseValue(); + return absl::OkStatus(); +} + +std::string LegacyMapValue::DebugString() const { + return CelValue::CreateMap(impl_).DebugString(); +} + +absl::Status LegacyMapValue::SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(output != nullptr); + + const google::protobuf::Descriptor* descriptor = + descriptor_pool->FindMessageTypeByName("google.protobuf.Struct"); + if (descriptor == nullptr) { + return absl::InternalError( + "unable to locate descriptor for message type: google.protobuf.Struct"); + } + + google::protobuf::Arena arena; + const google::protobuf::Message* wrapped = MaybeWrapValueToMessage( + descriptor, message_factory, CelValue::CreateMap(impl_), &arena); + if (wrapped == nullptr) { + return absl::UnknownError("failed to convert legacy map to JSON"); + } + if (!wrapped->SerializePartialToZeroCopyStream(output)) { + return absl::UnknownError( + absl::StrCat("failed to serialize message: ", wrapped->GetTypeName())); + } + return absl::OkStatus(); +} + +absl::Status LegacyMapValue::ConvertToJson( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(json != nullptr); + ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE); + + google::protobuf::Arena arena; + const google::protobuf::Message* wrapped = + MaybeWrapValueToMessage(json->GetDescriptor(), message_factory, + CelValue::CreateMap(impl_), &arena); + if (wrapped == nullptr) { + return absl::UnknownError("failed to convert legacy map to JSON"); + } + + if (wrapped->GetDescriptor() == json->GetDescriptor()) { + // We can directly use google::protobuf::Message::Copy(). + json->CopyFrom(*wrapped); + } else { + // Equivalent descriptors but not identical. Must serialize and deserialize. + absl::Cord serialized; + if (!wrapped->SerializePartialToString(&serialized)) { + return absl::UnknownError(absl::StrCat("failed to serialize message: ", + wrapped->GetTypeName())); + } + if (!json->ParsePartialFromString(serialized)) { + return absl::UnknownError( + absl::StrCat("failed to parsed message: ", json->GetTypeName())); + } + } + return absl::OkStatus(); +} + +absl::Status LegacyMapValue::ConvertToJsonObject( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(json != nullptr); + ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT); + + google::protobuf::Arena arena; + const google::protobuf::Message* wrapped = + MaybeWrapValueToMessage(json->GetDescriptor(), message_factory, + CelValue::CreateMap(impl_), &arena); + if (wrapped == nullptr) { + return absl::UnknownError("failed to convert legacy map to JSON"); + } + + if (wrapped->GetDescriptor() == json->GetDescriptor()) { + // We can directly use google::protobuf::Message::Copy(). + json->CopyFrom(*wrapped); + } else { + // Equivalent descriptors but not identical. Must serialize and deserialize. + absl::Cord serialized; + if (!wrapped->SerializePartialToString(&serialized)) { + return absl::UnknownError(absl::StrCat("failed to serialize message: ", + wrapped->GetTypeName())); + } + if (!json->ParsePartialFromString(serialized)) { + return absl::UnknownError( + absl::StrCat("failed to parsed message: ", json->GetTypeName())); + } + } + return absl::OkStatus(); +} + +bool LegacyMapValue::IsEmpty() const { return impl_->empty(); } + +size_t LegacyMapValue::Size() const { + return static_cast(impl_->size()); +} + +absl::Status LegacyMapValue::Get( + const Value& key, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + switch (key.kind()) { + case ValueKind::kError: + ABSL_FALLTHROUGH_INTENDED; + case ValueKind::kUnknown: + *result = Value{key}; + return absl::OkStatus(); + case ValueKind::kBool: + ABSL_FALLTHROUGH_INTENDED; + case ValueKind::kInt: + ABSL_FALLTHROUGH_INTENDED; + case ValueKind::kUint: + ABSL_FALLTHROUGH_INTENDED; + case ValueKind::kString: + break; + default: + *result = ErrorValue(InvalidMapKeyTypeError(key.kind())); + return absl::OkStatus(); + } + CEL_ASSIGN_OR_RETURN(auto cel_key, LegacyValue(arena, key)); + auto cel_value = impl_->Get(arena, cel_key); + if (!cel_value.has_value()) { + *result = NoSuchKeyError(key.DebugString()); + return absl::OkStatus(); + } + CEL_RETURN_IF_ERROR(ModernValue(arena, *cel_value, *result)); + return absl::OkStatus(); +} + +absl::StatusOr LegacyMapValue::Find( + const Value& key, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + switch (key.kind()) { + case ValueKind::kError: + ABSL_FALLTHROUGH_INTENDED; + case ValueKind::kUnknown: + *result = Value{key}; + return false; + case ValueKind::kBool: + ABSL_FALLTHROUGH_INTENDED; + case ValueKind::kInt: + ABSL_FALLTHROUGH_INTENDED; + case ValueKind::kUint: + ABSL_FALLTHROUGH_INTENDED; + case ValueKind::kString: + break; + default: + *result = ErrorValue(InvalidMapKeyTypeError(key.kind())); + } + CEL_ASSIGN_OR_RETURN(auto cel_key, LegacyValue(arena, key)); + auto cel_value = impl_->Get(arena, cel_key); + if (!cel_value.has_value()) { + *result = NullValue{}; + return false; + } + CEL_RETURN_IF_ERROR(ModernValue(arena, *cel_value, *result)); + return true; +} + +absl::Status LegacyMapValue::Has( + const Value& key, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + switch (key.kind()) { + case ValueKind::kError: + ABSL_FALLTHROUGH_INTENDED; + case ValueKind::kUnknown: + *result = Value{key}; + return absl::OkStatus(); + case ValueKind::kBool: + ABSL_FALLTHROUGH_INTENDED; + case ValueKind::kInt: + ABSL_FALLTHROUGH_INTENDED; + case ValueKind::kUint: + ABSL_FALLTHROUGH_INTENDED; + case ValueKind::kString: + break; + default: + *result = ErrorValue(InvalidMapKeyTypeError(key.kind())); + return absl::OkStatus(); + } + CEL_ASSIGN_OR_RETURN(auto cel_key, LegacyValue(arena, key)); + absl::StatusOr has = impl_->Has(cel_key); + if (!has.ok()) { + *result = ErrorValue(std::move(has).status()); + return absl::OkStatus(); + } + + *result = BoolValue(*has); + return absl::OkStatus(); +} + +absl::Status LegacyMapValue::ListKeys( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, ListValue* absl_nonnull result) const { + CEL_ASSIGN_OR_RETURN(auto keys, impl_->ListKeys(arena)); + *result = ListValue{common_internal::LegacyListValue(keys)}; + return absl::OkStatus(); +} + +absl::Status LegacyMapValue::ForEach( + ForEachCallback callback, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const { + CEL_ASSIGN_OR_RETURN(auto keys, impl_->ListKeys(arena)); + const auto size = keys->size(); + Value key; + Value value; + for (int index = 0; index < size; ++index) { + auto cel_key = keys->Get(arena, index); + auto cel_value = *impl_->Get(arena, cel_key); + CEL_RETURN_IF_ERROR(ModernValue(arena, cel_key, key)); + CEL_RETURN_IF_ERROR(ModernValue(arena, cel_value, value)); + CEL_ASSIGN_OR_RETURN(auto ok, callback(key, value)); + if (!ok) { + break; + } + } + return absl::OkStatus(); +} + +absl::StatusOr LegacyMapValue::NewIterator() + const { + return std::make_unique(impl_); +} + +absl::string_view LegacyStructValue::GetTypeName() const { + auto message_wrapper = AsMessageWrapper(message_ptr_, legacy_type_info_); + return message_wrapper.legacy_type_info()->GetTypename(message_wrapper); +} + +std::string LegacyStructValue::DebugString() const { + auto message_wrapper = AsMessageWrapper(message_ptr_, legacy_type_info_); + return message_wrapper.legacy_type_info()->DebugString(message_wrapper); +} + +absl::Status LegacyStructValue::SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(output != nullptr); + + auto message_wrapper = AsMessageWrapper(message_ptr_, legacy_type_info_); + if (ABSL_PREDICT_TRUE( + message_wrapper.message_ptr()->SerializePartialToZeroCopyStream( + output))) { + return absl::OkStatus(); + } + return absl::UnknownError("failed to serialize protocol buffer message"); +} + +absl::Status LegacyStructValue::ConvertToJson( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(json != nullptr); + ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE); + + auto message_wrapper = AsMessageWrapper(message_ptr_, legacy_type_info_); + + return internal::MessageToJson( + *google::protobuf::DownCastMessage(message_wrapper.message_ptr()), + descriptor_pool, message_factory, json); +} + +absl::Status LegacyStructValue::ConvertToJsonObject( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(json != nullptr); + ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT); + + auto message_wrapper = AsMessageWrapper(message_ptr_, legacy_type_info_); + + return internal::MessageToJson( + *google::protobuf::DownCastMessage(message_wrapper.message_ptr()), + descriptor_pool, message_factory, json); +} + +absl::Status LegacyStructValue::Equal( + const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + if (auto legacy_struct_value = common_internal::AsLegacyStructValue(other); + legacy_struct_value.has_value()) { + auto message_wrapper = AsMessageWrapper(message_ptr_, legacy_type_info_); + const auto* access_apis = + message_wrapper.legacy_type_info()->GetAccessApis(message_wrapper); + if (ABSL_PREDICT_FALSE(access_apis == nullptr)) { + return absl::UnimplementedError( + absl::StrCat("legacy access APIs missing for ", GetTypeName())); + } + auto other_message_wrapper = + AsMessageWrapper(legacy_struct_value->message_ptr(), + legacy_struct_value->legacy_type_info()); + *result = BoolValue{ + access_apis->IsEqualTo(message_wrapper, other_message_wrapper)}; + return absl::OkStatus(); + } + if (auto struct_value = other.AsStruct(); struct_value.has_value()) { + return common_internal::StructValueEqual( + common_internal::LegacyStructValue(message_ptr_, legacy_type_info_), + *struct_value, descriptor_pool, message_factory, arena, result); + } + *result = FalseValue(); + return absl::OkStatus(); +} + +bool LegacyStructValue::IsZeroValue() const { + auto message_wrapper = AsMessageWrapper(message_ptr_, legacy_type_info_); + const auto* access_apis = + message_wrapper.legacy_type_info()->GetAccessApis(message_wrapper); + if (ABSL_PREDICT_FALSE(access_apis == nullptr)) { + return false; + } + return access_apis->ListFields(message_wrapper).empty(); +} + +absl::Status LegacyStructValue::GetFieldByName( + absl::string_view name, ProtoWrapperTypeOptions unboxing_options, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + auto message_wrapper = AsMessageWrapper(message_ptr_, legacy_type_info_); + const auto* access_apis = + message_wrapper.legacy_type_info()->GetAccessApis(message_wrapper); + if (ABSL_PREDICT_FALSE(access_apis == nullptr)) { + *result = NoSuchFieldError(name); + return absl::OkStatus(); + } + CEL_ASSIGN_OR_RETURN( + auto cel_value, + access_apis->GetField(name, message_wrapper, unboxing_options, + MemoryManagerRef::Pooling(arena))); + CEL_RETURN_IF_ERROR(ModernValue(arena, cel_value, *result)); + return absl::OkStatus(); +} + +absl::Status LegacyStructValue::GetFieldByNumber( + int64_t number, ProtoWrapperTypeOptions unboxing_options, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + return absl::UnimplementedError( + "access to fields by numbers is not available for legacy structs"); +} + +absl::StatusOr LegacyStructValue::HasFieldByName( + absl::string_view name) const { + auto message_wrapper = AsMessageWrapper(message_ptr_, legacy_type_info_); + const auto* access_apis = + message_wrapper.legacy_type_info()->GetAccessApis(message_wrapper); + if (ABSL_PREDICT_FALSE(access_apis == nullptr)) { + return NoSuchFieldError(name).NativeValue(); + } + return access_apis->HasField(name, message_wrapper); +} + +absl::StatusOr LegacyStructValue::HasFieldByNumber(int64_t number) const { + return absl::UnimplementedError( + "access to fields by numbers is not available for legacy structs"); +} + +absl::Status LegacyStructValue::ForEachField( + ForEachFieldCallback callback, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const { + auto message_wrapper = AsMessageWrapper(message_ptr_, legacy_type_info_); + const auto* access_apis = + message_wrapper.legacy_type_info()->GetAccessApis(message_wrapper); + if (ABSL_PREDICT_FALSE(access_apis == nullptr)) { + return absl::UnimplementedError( + absl::StrCat("legacy access APIs missing for ", GetTypeName())); + } + auto field_names = access_apis->ListFields(message_wrapper); + Value value; + for (const auto& field_name : field_names) { + CEL_ASSIGN_OR_RETURN( + auto cel_value, + access_apis->GetField(field_name, message_wrapper, + ProtoWrapperTypeOptions::kUnsetNull, + MemoryManagerRef::Pooling(arena))); + CEL_RETURN_IF_ERROR(ModernValue(arena, cel_value, value)); + CEL_ASSIGN_OR_RETURN(auto ok, callback(field_name, value)); + if (!ok) { + break; + } + } + return absl::OkStatus(); +} + +absl::Status LegacyStructValue::Qualify( + absl::Span qualifiers, bool presence_test, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result, + int* absl_nonnull count) const { + if (ABSL_PREDICT_FALSE(qualifiers.empty())) { + return absl::InvalidArgumentError("invalid select qualifier path."); + } + auto message_wrapper = AsMessageWrapper(message_ptr_, legacy_type_info_); + const auto* access_apis = + message_wrapper.legacy_type_info()->GetAccessApis(message_wrapper); + if (ABSL_PREDICT_FALSE(access_apis == nullptr)) { + absl::string_view field_name = absl::visit( + absl::Overload( + [](const FieldSpecifier& field) -> absl::string_view { + return field.name; + }, + [](const AttributeQualifier& field) -> absl::string_view { + return field.GetStringKey().value_or(""); + }), + qualifiers.front()); + *result = NoSuchFieldError(field_name); + *count = -1; + return absl::OkStatus(); + } + CEL_ASSIGN_OR_RETURN( + auto legacy_result, + access_apis->Qualify(qualifiers, message_wrapper, presence_test, + MemoryManager::Pooling(arena))); + CEL_RETURN_IF_ERROR(ModernValue(arena, legacy_result.value, *result)); + *count = legacy_result.qualifier_count; + return absl::OkStatus(); +} + +} // namespace common_internal + +absl::Status ModernValue(google::protobuf::Arena* arena, + google::api::expr::runtime::CelValue legacy_value, + Value& result) { + switch (legacy_value.type()) { + case CelValue::Type::kNullType: + result = NullValue{}; + return absl::OkStatus(); + case CelValue::Type::kBool: + result = BoolValue{legacy_value.BoolOrDie()}; + return absl::OkStatus(); + case CelValue::Type::kInt64: + result = IntValue{legacy_value.Int64OrDie()}; + return absl::OkStatus(); + case CelValue::Type::kUint64: + result = UintValue{legacy_value.Uint64OrDie()}; + return absl::OkStatus(); + case CelValue::Type::kDouble: + result = DoubleValue{legacy_value.DoubleOrDie()}; + return absl::OkStatus(); + case CelValue::Type::kString: + result = StringValue(Borrower::Arena(arena), + legacy_value.StringOrDie().value()); + return absl::OkStatus(); + case CelValue::Type::kBytes: + result = + BytesValue(Borrower::Arena(arena), legacy_value.BytesOrDie().value()); + return absl::OkStatus(); + case CelValue::Type::kMessage: { + auto message_wrapper = legacy_value.MessageWrapperOrDie(); + result = common_internal::LegacyStructValue( + google::protobuf::DownCastMessage( + message_wrapper.message_ptr()), + message_wrapper.legacy_type_info()); + return absl::OkStatus(); + } + case CelValue::Type::kDuration: + result = UnsafeDurationValue(legacy_value.DurationOrDie()); + return absl::OkStatus(); + case CelValue::Type::kTimestamp: + result = UnsafeTimestampValue(legacy_value.TimestampOrDie()); + return absl::OkStatus(); + case CelValue::Type::kList: + result = + ListValue(common_internal::LegacyListValue(legacy_value.ListOrDie())); + return absl::OkStatus(); + case CelValue::Type::kMap: + result = + MapValue(common_internal::LegacyMapValue(legacy_value.MapOrDie())); + return absl::OkStatus(); + case CelValue::Type::kUnknownSet: + result = UnknownValue{*legacy_value.UnknownSetOrDie()}; + return absl::OkStatus(); + case CelValue::Type::kCelType: { + auto type_name = legacy_value.CelTypeOrDie().value(); + if (type_name.empty()) { + return absl::InvalidArgumentError("empty type name in CelValue"); + } + result = TypeValue(common_internal::LegacyRuntimeType(type_name)); + return absl::OkStatus(); + } + case CelValue::Type::kError: + result = ErrorValue{*legacy_value.ErrorOrDie()}; + return absl::OkStatus(); + case CelValue::Type::kAny: + return absl::InternalError(absl::StrCat( + "illegal attempt to convert special CelValue type ", + CelValue::TypeName(legacy_value.type()), " to cel::Value")); + default: + break; + } + return absl::InvalidArgumentError(absl::StrCat( + "cel::Value does not support ", KindToString(legacy_value.type()))); +} + +absl::StatusOr LegacyValue( + google::protobuf::Arena* arena, const Value& modern_value) { + switch (modern_value.kind()) { + case ValueKind::kNull: + return CelValue::CreateNull(); + case ValueKind::kBool: + return CelValue::CreateBool(Cast(modern_value).NativeValue()); + case ValueKind::kInt: + return CelValue::CreateInt64(Cast(modern_value).NativeValue()); + case ValueKind::kUint: + return CelValue::CreateUint64( + Cast(modern_value).NativeValue()); + case ValueKind::kDouble: + return CelValue::CreateDouble( + Cast(modern_value).NativeValue()); + case ValueKind::kString: + return CelValue::CreateStringView(common_internal::LegacyStringValue( + modern_value.GetString(), /*stable=*/false, arena)); + case ValueKind::kBytes: + return CelValue::CreateBytesView(common_internal::LegacyBytesValue( + modern_value.GetBytes(), /*stable=*/false, arena)); + case ValueKind::kStruct: + return common_internal::LegacyTrivialStructValue(arena, modern_value); + case ValueKind::kDuration: + return CelValue::CreateUncheckedDuration( + modern_value.GetDuration().NativeValue()); + case ValueKind::kTimestamp: + return CelValue::CreateTimestamp( + modern_value.GetTimestamp().NativeValue()); + case ValueKind::kList: + return common_internal::LegacyTrivialListValue(arena, modern_value); + case ValueKind::kMap: + return common_internal::LegacyTrivialMapValue(arena, modern_value); + case ValueKind::kUnknown: + return CelValue::CreateUnknownSet(google::protobuf::Arena::Create( + arena, Cast(modern_value).NativeValue())); + case ValueKind::kType: + return CelValue::CreateCelType( + CelValue::CelTypeHolder(google::protobuf::Arena::Create( + arena, Cast(modern_value).NativeValue().name()))); + case ValueKind::kError: + return CelValue::CreateError(google::protobuf::Arena::Create( + arena, Cast(modern_value).NativeValue())); + default: + return absl::InvalidArgumentError( + absl::StrCat("google::api::expr::runtime::CelValue does not support ", + ValueKindToString(modern_value.kind()))); + } +} + +namespace interop_internal { + +absl::StatusOr FromLegacyValue(google::protobuf::Arena* arena, + const CelValue& legacy_value, bool) { + switch (legacy_value.type()) { + case CelValue::Type::kNullType: + return NullValue{}; + case CelValue::Type::kBool: + return BoolValue(legacy_value.BoolOrDie()); + case CelValue::Type::kInt64: + return IntValue(legacy_value.Int64OrDie()); + case CelValue::Type::kUint64: + return UintValue(legacy_value.Uint64OrDie()); + case CelValue::Type::kDouble: + return DoubleValue(legacy_value.DoubleOrDie()); + case CelValue::Type::kString: + return StringValue(Borrower::Arena(arena), + legacy_value.StringOrDie().value()); + case CelValue::Type::kBytes: + return BytesValue(Borrower::Arena(arena), + legacy_value.BytesOrDie().value()); + case CelValue::Type::kMessage: { + auto message_wrapper = legacy_value.MessageWrapperOrDie(); + return common_internal::LegacyStructValue( + google::protobuf::DownCastMessage( + message_wrapper.message_ptr()), + message_wrapper.legacy_type_info()); + } + case CelValue::Type::kDuration: + return UnsafeDurationValue(legacy_value.DurationOrDie()); + case CelValue::Type::kTimestamp: + return UnsafeTimestampValue(legacy_value.TimestampOrDie()); + case CelValue::Type::kList: + return ListValue( + common_internal::LegacyListValue(legacy_value.ListOrDie())); + case CelValue::Type::kMap: + return MapValue(common_internal::LegacyMapValue(legacy_value.MapOrDie())); + case CelValue::Type::kUnknownSet: + return UnknownValue{*legacy_value.UnknownSetOrDie()}; + case CelValue::Type::kCelType: + return CreateTypeValueFromView(arena, + legacy_value.CelTypeOrDie().value()); + case CelValue::Type::kError: + return ErrorValue(*legacy_value.ErrorOrDie()); + case CelValue::Type::kAny: + return absl::InternalError(absl::StrCat( + "illegal attempt to convert special CelValue type ", + CelValue::TypeName(legacy_value.type()), " to cel::Value")); + default: + break; + } + return absl::UnimplementedError(absl::StrCat( + "conversion from CelValue to cel::Value for type ", + CelValue::TypeName(legacy_value.type()), " is not yet implemented")); +} + +absl::StatusOr ToLegacyValue( + google::protobuf::Arena* arena, const Value& value, bool) { + switch (value.kind()) { + case ValueKind::kNull: + return CelValue::CreateNull(); + case ValueKind::kBool: + return CelValue::CreateBool(Cast(value).NativeValue()); + case ValueKind::kInt: + return CelValue::CreateInt64(Cast(value).NativeValue()); + case ValueKind::kUint: + return CelValue::CreateUint64(Cast(value).NativeValue()); + case ValueKind::kDouble: + return CelValue::CreateDouble(Cast(value).NativeValue()); + case ValueKind::kString: + return CelValue::CreateStringView(common_internal::LegacyStringValue( + value.GetString(), /*stable=*/false, arena)); + case ValueKind::kBytes: + return CelValue::CreateBytesView(common_internal::LegacyBytesValue( + value.GetBytes(), /*stable=*/false, arena)); + case ValueKind::kStruct: + return common_internal::LegacyTrivialStructValue(arena, value); + case ValueKind::kDuration: + return CelValue::CreateUncheckedDuration( + Cast(value).NativeValue()); + case ValueKind::kTimestamp: + return CelValue::CreateTimestamp( + Cast(value).NativeValue()); + case ValueKind::kList: + return common_internal::LegacyTrivialListValue(arena, value); + case ValueKind::kMap: + return common_internal::LegacyTrivialMapValue(arena, value); + case ValueKind::kUnknown: + return CelValue::CreateUnknownSet(google::protobuf::Arena::Create( + arena, Cast(value).NativeValue())); + case ValueKind::kType: + return CelValue::CreateCelType( + CelValue::CelTypeHolder(google::protobuf::Arena::Create( + arena, Cast(value).NativeValue().name()))); + case ValueKind::kError: + return CelValue::CreateError(google::protobuf::Arena::Create( + arena, Cast(value).NativeValue())); + default: + return absl::InvalidArgumentError( + absl::StrCat("google::api::expr::runtime::CelValue does not support ", + ValueKindToString(value.kind()))); + } +} + +Value LegacyValueToModernValueOrDie( + google::protobuf::Arena* arena, const google::api::expr::runtime::CelValue& value, + bool unchecked) { + auto status_or_value = FromLegacyValue(arena, value, unchecked); + ABSL_CHECK_OK(status_or_value.status()); // Crash OK + return std::move(*status_or_value); +} + +std::vector LegacyValueToModernValueOrDie( + google::protobuf::Arena* arena, + absl::Span values, + bool unchecked) { + std::vector modern_values; + modern_values.reserve(values.size()); + for (const auto& value : values) { + modern_values.push_back( + LegacyValueToModernValueOrDie(arena, value, unchecked)); + } + return modern_values; +} + +google::api::expr::runtime::CelValue ModernValueToLegacyValueOrDie( + google::protobuf::Arena* arena, const Value& value, bool unchecked) { + auto status_or_value = ToLegacyValue(arena, value, unchecked); + ABSL_CHECK_OK(status_or_value.status()); // Crash OK + return std::move(*status_or_value); +} + +TypeValue CreateTypeValueFromView(google::protobuf::Arena* arena, + absl::string_view input) { + return TypeValue(common_internal::LegacyRuntimeType(input)); +} + +} // namespace interop_internal + +} // namespace cel diff --git a/common/legacy_value.h b/common/legacy_value.h new file mode 100644 index 000000000..7e703cea1 --- /dev/null +++ b/common/legacy_value.h @@ -0,0 +1,116 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_LEGACY_VALUE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_LEGACY_VALUE_H_ + +#include +#include + +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "absl/types/span.h" +#include "common/value.h" +#include "eval/public/cel_value.h" +#include "internal/status_macros.h" +#include "google/protobuf/arena.h" + +namespace cel { + +absl::Status ModernValue(google::protobuf::Arena* arena, + google::api::expr::runtime::CelValue legacy_value, + Value& result); +inline absl::StatusOr ModernValue( + google::protobuf::Arena* arena, google::api::expr::runtime::CelValue legacy_value) { + Value result; + CEL_RETURN_IF_ERROR(ModernValue(arena, legacy_value, result)); + return result; +} + +absl::StatusOr LegacyValue( + google::protobuf::Arena* arena, const Value& modern_value); + +namespace common_internal { + +// Convert a `cel::Value` to `google::api::expr::runtime::CelValue`, using +// `arena` to make memory allocations if necessary. `stable` indicates whether +// `cel::Value` is in a location where it will not be moved, so that inline +// string/bytes storage can be referenced. +google::api::expr::runtime::CelValue UnsafeLegacyValue( + const Value& value, bool stable, google::protobuf::Arena* absl_nonnull arena); + +} // namespace common_internal + +} // namespace cel + +namespace cel::interop_internal { + +absl::StatusOr FromLegacyValue( + google::protobuf::Arena* arena, + const google::api::expr::runtime::CelValue& legacy_value, + bool unchecked = false); + +absl::StatusOr ToLegacyValue( + google::protobuf::Arena* arena, const Value& value, bool unchecked = false); + +inline NullValue CreateNullValue() { return NullValue{}; } + +inline BoolValue CreateBoolValue(bool value) { return BoolValue{value}; } + +inline IntValue CreateIntValue(int64_t value) { return IntValue{value}; } + +inline UintValue CreateUintValue(uint64_t value) { return UintValue{value}; } + +inline DoubleValue CreateDoubleValue(double value) { + return DoubleValue{value}; +} + +inline ListValue CreateLegacyListValue( + const google::api::expr::runtime::CelList* value) { + return common_internal::LegacyListValue(value); +} + +inline MapValue CreateLegacyMapValue( + const google::api::expr::runtime::CelMap* value) { + return common_internal::LegacyMapValue(value); +} + +inline Value CreateDurationValue(absl::Duration value, bool unchecked = false) { + return DurationValue{value}; +} + +inline TimestampValue CreateTimestampValue(absl::Time value) { + return TimestampValue{value}; +} + +Value LegacyValueToModernValueOrDie( + google::protobuf::Arena* arena, const google::api::expr::runtime::CelValue& value, + bool unchecked = false); +std::vector LegacyValueToModernValueOrDie( + google::protobuf::Arena* arena, + absl::Span values, + bool unchecked = false); + +google::api::expr::runtime::CelValue ModernValueToLegacyValueOrDie( + google::protobuf::Arena* arena, const Value& value, bool unchecked = false); + +TypeValue CreateTypeValueFromView(google::protobuf::Arena* arena, + absl::string_view input); + +} // namespace cel::interop_internal + +#endif // THIRD_PARTY_CEL_CPP_COMMON_LEGACY_VALUE_H_ diff --git a/common/macros.h b/common/macros.h deleted file mode 100644 index 8fa46eda2..000000000 --- a/common/macros.h +++ /dev/null @@ -1,25 +0,0 @@ -// Helper macros for dealing with CelValues. -// -// Never include this file in another header, as macros are declared in -// the global scope. - -#ifndef THIRD_PARTY_CEL_CPP_COMMON_CEL_MACROS_H_ -#define THIRD_PARTY_CEL_CPP_COMMON_CEL_MACROS_H_ - -/** Returns the CelValue immediately if it represents an Error or Unknown. */ -#define RETURN_IF_NOT_VALUE(expr) \ - do { \ - auto return_if_not_value_value = (expr); \ - if (!return_if_not_value_value.is_value()) \ - return return_if_not_value_value; \ - } while (false) - -/** Helper macro to return a status eagerly, if it represents an error. */ -#define RETURN_IF_STATUS_ERROR(expr) \ - do { \ - auto return_if_status_error_status = (expr); \ - if (return_if_status_error_status.code() != ::google::rpc::Code::OK) \ - return return_if_status_error_status; \ - } while (false) - -#endif // THIRD_PARTY_CEL_CPP_COMMON_CEL_MACROS_H_ diff --git a/common/memory.cc b/common/memory.cc new file mode 100644 index 000000000..c00c12ed8 --- /dev/null +++ b/common/memory.cc @@ -0,0 +1,83 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/memory.h" + +#include +#include +#include +#include + +#include "absl/base/no_destructor.h" +#include "absl/log/absl_check.h" +#include "absl/numeric/bits.h" +#include "google/protobuf/arena.h" + +namespace cel { + +std::ostream& operator<<(std::ostream& out, + MemoryManagement memory_management) { + switch (memory_management) { + case MemoryManagement::kPooling: + return out << "POOLING"; + case MemoryManagement::kReferenceCounting: + return out << "REFERENCE_COUNTING"; + } +} + +void* ReferenceCountingMemoryManager::Allocate(size_t size, size_t alignment) { + ABSL_DCHECK(absl::has_single_bit(alignment)) + << "alignment must be a power of 2: " << alignment; + if (size == 0) { + return nullptr; + } + if (alignment <= __STDCPP_DEFAULT_NEW_ALIGNMENT__) { + return ::operator new(size); + } + return ::operator new(size, static_cast(alignment)); +} + +bool ReferenceCountingMemoryManager::Deallocate(void* ptr, size_t size, + size_t alignment) noexcept { + ABSL_DCHECK(absl::has_single_bit(alignment)) + << "alignment must be a power of 2: " << alignment; + if (ptr == nullptr) { + ABSL_DCHECK_EQ(size, 0); + return false; + } + ABSL_DCHECK_GT(size, 0); + if (alignment <= __STDCPP_DEFAULT_NEW_ALIGNMENT__) { +#if defined(__cpp_sized_deallocation) && __cpp_sized_deallocation >= 201309L + ::operator delete(ptr, size); +#else + ::operator delete(ptr); +#endif + } else { +#if defined(__cpp_sized_deallocation) && __cpp_sized_deallocation >= 201309L + ::operator delete(ptr, size, static_cast(alignment)); +#else + ::operator delete(ptr, static_cast(alignment)); +#endif + } + return true; +} + +MemoryManager MemoryManager::Unmanaged() { + // A static singleton arena, using `absl::NoDestructor` to avoid warnings + // related static variables without trivial destructors. + static absl::NoDestructor arena; + return MemoryManager::Pooling(&*arena); +} + +} // namespace cel diff --git a/common/memory.h b/common/memory.h new file mode 100644 index 000000000..b19f54f94 --- /dev/null +++ b/common/memory.h @@ -0,0 +1,1502 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_MEMORY_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_MEMORY_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/macros.h" +#include "absl/base/nullability.h" +#include "absl/base/optimization.h" +#include "absl/log/absl_check.h" +#include "absl/numeric/bits.h" +#include "common/allocator.h" +#include "common/arena.h" +#include "common/data.h" +#include "common/internal/metadata.h" +#include "common/internal/reference_count.h" +#include "common/reference_count.h" +#include "internal/exceptions.h" +#include "internal/to_address.h" // IWYU pragma: keep +#include "google/protobuf/arena.h" + +namespace cel { + +// Obtain the address of the underlying element from a raw pointer or "fancy" +// pointer. +using internal::to_address; + +// MemoryManagement is an enumeration of supported memory management forms +// underlying `cel::MemoryManager`. +enum class MemoryManagement { + // Region-based (a.k.a. arena). Memory is allocated in fixed size blocks and + // deallocated all at once upon destruction of the `cel::MemoryManager`. + kPooling = 1, + // Reference counting. Memory is allocated with an associated reference + // counter. When the reference counter hits 0, it is deallocated. + kReferenceCounting, +}; + +std::ostream& operator<<(std::ostream& out, MemoryManagement memory_management); + +class ABSL_ATTRIBUTE_TRIVIAL_ABI [[nodiscard]] Owner; +class Borrower; +template +class ABSL_ATTRIBUTE_TRIVIAL_ABI [[nodiscard]] Unique; +template +class ABSL_ATTRIBUTE_TRIVIAL_ABI [[nodiscard]] Owned; +template +class Borrowed; +template +struct Ownable; +template +struct Borrowable; + +class MemoryManager; +class ReferenceCountingMemoryManager; +class PoolingMemoryManager; + +namespace common_internal { +template +inline constexpr bool kNotMessageLiteAndNotData = + std::conjunction_v>, + std::negation>>; +template +inline constexpr bool kIsPointerConvertible = std::is_convertible_v; +template +inline constexpr bool kNotSameAndIsPointerConvertible = + std::conjunction_v>, + std::bool_constant>>; + +// Clears the contents of `owner`, and returns the reference count if in use. +const ReferenceCount* absl_nullable OwnerRelease(Owner owner) noexcept; +const ReferenceCount* absl_nullable BorrowerRelease(Borrower borrower) noexcept; +template +Owned WrapEternal(const T* value); + +// Pointer tag used by `cel::Unique` to indicate that the destructor needs to be +// registered with the arena, but it has not been done yet. Must be done when +// releasing. +inline constexpr uintptr_t kUniqueArenaUnownedBit = uintptr_t{1} << 0; +inline constexpr uintptr_t kUniqueArenaBits = kUniqueArenaUnownedBit; +inline constexpr uintptr_t kUniqueArenaPointerMask = ~kUniqueArenaBits; +} // namespace common_internal + +template +Owned AllocateShared(Allocator<> allocator, Args&&... args); + +template +Owned WrapShared(T* object, Allocator<> allocator); + +// `Owner` represents a reference to some co-owned data, of which this owner is +// one of the co-owners. When using reference counting, `Owner` performs +// increment/decrement where appropriate similar to `std::shared_ptr`. +// `Borrower` is similar to `Owner`, except that it is always trivially +// copyable/destructible. In that sense, `Borrower` is similar to +// `std::reference_wrapper`. +class ABSL_ATTRIBUTE_TRIVIAL_ABI [[nodiscard]] Owner final { + private: + static constexpr uintptr_t kNone = common_internal::kMetadataOwnerNone; + static constexpr uintptr_t kReferenceCountBit = + common_internal::kMetadataOwnerReferenceCountBit; + static constexpr uintptr_t kArenaBit = + common_internal::kMetadataOwnerArenaBit; + static constexpr uintptr_t kBits = common_internal::kMetadataOwnerBits; + static constexpr uintptr_t kPointerMask = + common_internal::kMetadataOwnerPointerMask; + + public: + static Owner None() noexcept { return Owner(); } + + static Owner Allocator(Allocator<> allocator) noexcept { + auto* arena = allocator.arena(); + return arena != nullptr ? Arena(arena) : None(); + } + + static Owner Arena(google::protobuf::Arena* absl_nonnull arena + ABSL_ATTRIBUTE_LIFETIME_BOUND) noexcept { + ABSL_DCHECK(arena != nullptr); + return Owner(reinterpret_cast(arena) | kArenaBit); + } + + static Owner Arena(std::nullptr_t) = delete; + + static Owner ReferenceCount(const ReferenceCount* absl_nonnull reference_count + ABSL_ATTRIBUTE_LIFETIME_BOUND) noexcept { + ABSL_DCHECK(reference_count != nullptr); + common_internal::StrongRef(*reference_count); + return Owner(reinterpret_cast(reference_count) | + kReferenceCountBit); + } + + static Owner ReferenceCount(std::nullptr_t) = delete; + + Owner() = default; + + Owner(const Owner& other) noexcept : Owner(CopyFrom(other.ptr_)) {} + + Owner(Owner&& other) noexcept : Owner(MoveFrom(other.ptr_)) {} + + template + // NOLINTNEXTLINE(google-explicit-constructor) + Owner(const Owned& owned) noexcept; + + template + // NOLINTNEXTLINE(google-explicit-constructor) + Owner(Owned&& owned) noexcept; + + explicit Owner(Borrower borrower) noexcept; + + template + explicit Owner(Borrowed borrowed) noexcept; + + ~Owner() { Destroy(ptr_); } + + Owner& operator=(const Owner& other) noexcept { + if (ptr_ != other.ptr_) { + Destroy(ptr_); + ptr_ = CopyFrom(other.ptr_); + } + return *this; + } + + Owner& operator=(Owner&& other) noexcept { + if (ABSL_PREDICT_TRUE(this != &other)) { + Destroy(ptr_); + ptr_ = MoveFrom(other.ptr_); + } + return *this; + } + + template + // NOLINTNEXTLINE(google-explicit-constructor) + Owner& operator=(const Owned& owned) noexcept; + + template + // NOLINTNEXTLINE(google-explicit-constructor) + Owner& operator=(Owned&& owned) noexcept; + + explicit operator bool() const noexcept { return !IsNone(ptr_); } + + google::protobuf::Arena* absl_nullable arena() const noexcept { + return (ptr_ & Owner::kBits) == Owner::kArenaBit + ? reinterpret_cast(ptr_ & Owner::kPointerMask) + : nullptr; + } + + void reset() noexcept { + Destroy(ptr_); + ptr_ = 0; + } + + // Tests whether two owners have ownership over the same data, that is they + // are co-owners. + friend bool operator==(const Owner& lhs, const Owner& rhs) noexcept { + // A reference count and arena can never occupy the same memory address, so + // we can compare for equality without masking off the bits. + return lhs.ptr_ == rhs.ptr_; + } + + private: + template + friend class Unique; + friend class Borrower; + template + friend Owned AllocateShared(cel::Allocator<> allocator, Args&&... args); + template + friend Owned WrapShared(T* object, cel::Allocator<> allocator); + template + friend struct Ownable; + friend const common_internal::ReferenceCount* absl_nullable + common_internal::OwnerRelease(Owner owner) noexcept; + friend const common_internal::ReferenceCount* absl_nullable + common_internal::BorrowerRelease(Borrower borrower) noexcept; + friend struct ArenaTraits; + + constexpr explicit Owner(uintptr_t ptr) noexcept : ptr_(ptr) {} + + static constexpr bool IsNone(uintptr_t ptr) noexcept { return ptr == kNone; } + + static constexpr bool IsArena(uintptr_t ptr) noexcept { + return (ptr & kArenaBit) != kNone; + } + + static constexpr bool IsReferenceCount(uintptr_t ptr) noexcept { + return (ptr & kReferenceCountBit) != kNone; + } + + ABSL_ATTRIBUTE_RETURNS_NONNULL + static google::protobuf::Arena* absl_nonnull AsArena(uintptr_t ptr) noexcept { + ABSL_ASSERT(IsArena(ptr)); + return reinterpret_cast(ptr & kPointerMask); + } + + ABSL_ATTRIBUTE_RETURNS_NONNULL + static const common_internal::ReferenceCount* absl_nonnull AsReferenceCount( + uintptr_t ptr) noexcept { + ABSL_ASSERT(IsReferenceCount(ptr)); + return reinterpret_cast( + ptr & kPointerMask); + } + + static uintptr_t CopyFrom(uintptr_t other) noexcept { return Own(other); } + + static uintptr_t MoveFrom(uintptr_t& other) noexcept { + return std::exchange(other, kNone); + } + + static void Destroy(uintptr_t ptr) noexcept { Unown(ptr); } + + static uintptr_t Own(uintptr_t ptr) noexcept { + if (IsReferenceCount(ptr)) { + const auto* refcount = Owner::AsReferenceCount(ptr); + ABSL_ASSUME(refcount != nullptr); + common_internal::StrongRef(refcount); + } + return ptr; + } + + static void Unown(uintptr_t ptr) noexcept { + if (IsReferenceCount(ptr)) { + const auto* reference_count = AsReferenceCount(ptr); + ABSL_ASSUME(reference_count != nullptr); + common_internal::StrongUnref(reference_count); + } + } + + uintptr_t ptr_ = kNone; +}; + +inline bool operator!=(const Owner& lhs, const Owner& rhs) noexcept { + return !operator==(lhs, rhs); +} + +namespace common_internal { + +inline const ReferenceCount* absl_nullable OwnerRelease(Owner owner) noexcept { + uintptr_t ptr = std::exchange(owner.ptr_, kMetadataOwnerNone); + if (Owner::IsReferenceCount(ptr)) { + return Owner::AsReferenceCount(ptr); + } + return nullptr; +} + +} // namespace common_internal + +template <> +struct ArenaTraits { + static bool trivially_destructible(const Owner& owner) { + return !Owner::IsReferenceCount(owner.ptr_); + } +}; + +// `Borrower` represents a reference to some borrowed data, where the data has +// at least one owner. When using reference counting, `Borrower` does not +// participate in incrementing/decrementing the reference count. Thus `Borrower` +// will not keep the underlying data alive. +class Borrower final { + public: + static Borrower None() noexcept { return Borrower(); } + + static Borrower Allocator(Allocator<> allocator) noexcept { + auto* arena = allocator.arena(); + return arena != nullptr ? Arena(arena) : None(); + } + + static Borrower Arena(google::protobuf::Arena* absl_nonnull arena + ABSL_ATTRIBUTE_LIFETIME_BOUND) noexcept { + ABSL_DCHECK(arena != nullptr); + return Borrower(reinterpret_cast(arena) | Owner::kArenaBit); + } + + static Borrower Arena(std::nullptr_t) = delete; + + static Borrower ReferenceCount( + const ReferenceCount* absl_nonnull reference_count + ABSL_ATTRIBUTE_LIFETIME_BOUND) noexcept { + ABSL_DCHECK(reference_count != nullptr); + return Borrower(reinterpret_cast(reference_count) | + Owner::kReferenceCountBit); + } + + static Borrower ReferenceCount(std::nullptr_t) = delete; + + Borrower() = default; + Borrower(const Borrower&) = default; + Borrower(Borrower&&) = default; + Borrower& operator=(const Borrower&) = default; + Borrower& operator=(Borrower&&) = default; + + template + // NOLINTNEXTLINE(google-explicit-constructor) + Borrower(const Owned& owned ABSL_ATTRIBUTE_LIFETIME_BOUND) noexcept; + + template + // NOLINTNEXTLINE(google-explicit-constructor) + Borrower(Borrowed borrowed) noexcept; + + // NOLINTNEXTLINE(google-explicit-constructor) + Borrower(const Owner& owner ABSL_ATTRIBUTE_LIFETIME_BOUND) noexcept + : ptr_(owner.ptr_) {} + + // NOLINTNEXTLINE(google-explicit-constructor) + Borrower& operator=( + const Owner& owner ABSL_ATTRIBUTE_LIFETIME_BOUND) noexcept { + ptr_ = owner.ptr_; + return *this; + } + + Borrower& operator=(Owner&&) = delete; + + template + Borrower& operator=( + const Owned& owned ABSL_ATTRIBUTE_LIFETIME_BOUND) noexcept; + + template + Borrower& operator=(Owned&&) = delete; + + template + // NOLINTNEXTLINE(google-explicit-constructor) + Borrower& operator=(Borrowed borrowed) noexcept; + + explicit operator bool() const noexcept { return !Owner::IsNone(ptr_); } + + google::protobuf::Arena* absl_nullable arena() const noexcept { + return (ptr_ & Owner::kBits) == Owner::kArenaBit + ? reinterpret_cast(ptr_ & Owner::kPointerMask) + : nullptr; + } + + void reset() noexcept { ptr_ = 0; } + + // Tests whether two borrowers are borrowing the same data. + friend bool operator==(Borrower lhs, Borrower rhs) noexcept { + // A reference count and arena can never occupy the same memory address, so + // we can compare for equality without masking off the bits. + return lhs.ptr_ == rhs.ptr_; + } + + private: + friend class Owner; + template + friend struct Borrowable; + friend const common_internal::ReferenceCount* absl_nullable + common_internal::BorrowerRelease(Borrower borrower) noexcept; + + constexpr explicit Borrower(uintptr_t ptr) noexcept : ptr_(ptr) {} + + uintptr_t ptr_ = Owner::kNone; +}; + +inline bool operator!=(Borrower lhs, Borrower rhs) noexcept { + return !operator==(lhs, rhs); +} + +inline bool operator==(Borrower lhs, const Owner& rhs) noexcept { + return operator==(lhs, Borrower(rhs)); +} + +inline bool operator==(const Owner& lhs, Borrower rhs) noexcept { + return operator==(Borrower(lhs), rhs); +} + +inline bool operator!=(Borrower lhs, const Owner& rhs) noexcept { + return !operator==(lhs, rhs); +} + +inline bool operator!=(const Owner& lhs, Borrower rhs) noexcept { + return !operator==(lhs, rhs); +} + +inline Owner::Owner(Borrower borrower) noexcept + : ptr_(Owner::Own(borrower.ptr_)) {} + +namespace common_internal { + +inline const ReferenceCount* absl_nullable BorrowerRelease( + Borrower borrower) noexcept { + uintptr_t ptr = borrower.ptr_; + if (Owner::IsReferenceCount(ptr)) { + return Owner::AsReferenceCount(ptr); + } + return nullptr; +} + +} // namespace common_internal + +template +Unique AllocateUnique(Allocator<> allocator, Args&&... args); + +// Wrap an already created `T` in `Unique`. Requires that `T` is not const, +// otherwise `GetArena()` may return slightly unexpected results depending on if +// it is the default value. +template +std::enable_if_t, Unique> WrapUnique(T* object); + +template +Unique WrapUnique(T* object, Allocator<> allocator); + +// `Unique` points to an object which was allocated using `Allocator<>` or +// `Allocator`. It has ownership over the object, and will perform any +// destruction and deallocation required. `Unique` must not outlive the +// underlying arena, if any. Unlike `Owned` and `Borrowed`, `Unique` supports +// arena incompatible objects. It is very similar to `std::unique_ptr` when +// using a custom deleter. +// +// IMPLEMENTATION NOTES: +// When utilizing arenas, we optionally perform a risky optimization via +// `AllocateUnique`. We do not use `Arena::Create`, instead we directly allocate +// the bytes and construct it in place ourselves. This avoids registering the +// destructor when required. Instead we register the destructor ourselves, if +// required, during `Unique::release`. This allows us to avoid deferring +// destruction of the object until the arena is destroyed, avoiding the cost +// involved in doing so. +template +class ABSL_ATTRIBUTE_TRIVIAL_ABI [[nodiscard]] Unique final { + public: + using element_type = T; + + static_assert(!std::is_array_v, "T must not be an array"); + static_assert(!std::is_reference_v, "T must not be a reference"); + static_assert(!std::is_volatile_v, "T must not be volatile qualified"); + + Unique() = default; + Unique(const Unique&) = delete; + Unique& operator=(const Unique&) = delete; + + explicit Unique(T* ptr) noexcept + : Unique(ptr, common_internal::GetArena(ptr)) {} + + // NOLINTNEXTLINE(google-explicit-constructor) + Unique(std::nullptr_t) noexcept : Unique() {} + + Unique(Unique&& other) noexcept : Unique(other.ptr_, other.arena_) { + other.ptr_ = nullptr; + } + + template >> + // NOLINTNEXTLINE(google-explicit-constructor) + Unique(Unique&& other) noexcept : Unique(other.ptr_, other.arena_) { + other.ptr_ = nullptr; + } + + ~Unique() { Delete(); } + + Unique& operator=(Unique&& other) noexcept { + if (ABSL_PREDICT_TRUE(this != &other)) { + Delete(); + ptr_ = other.ptr_; + arena_ = other.arena_; + other.ptr_ = nullptr; + } + return *this; + } + + template >> + // NOLINTNEXTLINE(google-explicit-constructor) + Unique& operator=(U* other) noexcept { + reset(other); + return *this; + } + + template >> + // NOLINTNEXTLINE(google-explicit-constructor) + Unique& operator=(Unique&& other) noexcept { + Delete(); + ptr_ = other.ptr_; + arena_ = other.arena_; + other.ptr_ = nullptr; + return *this; + } + + // NOLINTNEXTLINE(google-explicit-constructor) + Unique& operator=(std::nullptr_t) noexcept { + reset(); + return *this; + } + + T& operator*() const noexcept ABSL_ATTRIBUTE_LIFETIME_BOUND { + ABSL_DCHECK(static_cast(*this)); + return *get(); + } + + T* absl_nonnull operator->() const noexcept ABSL_ATTRIBUTE_LIFETIME_BOUND { + ABSL_DCHECK(static_cast(*this)); + return get(); + } + + // Relinquishes ownership of `T*`, returning it. If `T` was allocated and + // constructed using an arena, no further action is required. If `T` was + // allocated and constructed without an arena, the caller must eventually call + // `delete`. + ABSL_MUST_USE_RESULT T* release() noexcept { + PreRelease(); + return std::exchange(ptr_, nullptr); + } + + void reset() noexcept { reset(nullptr); } + + void reset(T* ptr) noexcept { + Delete(); + ptr_ = ptr; + arena_ = reinterpret_cast(common_internal::GetArena(ptr)); + } + + void reset(std::nullptr_t) noexcept { + Delete(); + ptr_ = nullptr; + arena_ = 0; + } + + explicit operator bool() const noexcept { return get() != nullptr; } + + google::protobuf::Arena* absl_nullable arena() const noexcept { + return reinterpret_cast( + arena_ & common_internal::kUniqueArenaPointerMask); + } + + friend void swap(Unique& lhs, Unique& rhs) noexcept { + using std::swap; + swap(lhs.ptr_, rhs.ptr_); + swap(lhs.arena_, rhs.arena_); + } + + private: + template + friend class Unique; + template + friend class Owned; + template + friend Unique AllocateUnique(Allocator<> allocator, Args&&... args); + template + friend Unique WrapUnique(U* object, Allocator<> allocator); + friend class ReferenceCountingMemoryManager; + friend class PoolingMemoryManager; + friend struct std::pointer_traits>; + friend struct ArenaTraits>; + + Unique(T* ptr, uintptr_t arena) noexcept : ptr_(ptr), arena_(arena) {} + + Unique(T* ptr, google::protobuf::Arena* arena, bool unowned = false) noexcept + : Unique(ptr, + reinterpret_cast(arena) | + (unowned ? common_internal::kUniqueArenaUnownedBit : 0)) { + ABSL_ASSERT(!unowned || (unowned && arena != nullptr)); + } + + Unique(google::protobuf::Arena* arena, T* ptr, bool unowned = false) noexcept + : Unique(ptr, arena, unowned) {} + + T* get() const noexcept { return ptr_; } + + void Delete() const noexcept { + if (static_cast(*this)) { + if (arena_ != 0) { + if ((arena_ & common_internal::kUniqueArenaBits) == + common_internal::kUniqueArenaUnownedBit) { + // We never registered the destructor, call it if necessary. + if constexpr (!std::is_trivially_destructible_v && + !google::protobuf::Arena::is_destructor_skippable::value) { + std::destroy_at(ptr_); + } + } + } else { + delete ptr_; + } + } + } + + void PreRelease() noexcept { + if constexpr (!std::is_trivially_destructible_v && + !google::protobuf::Arena::is_destructor_skippable::value) { + if (static_cast(*this) && + (arena_ & common_internal::kUniqueArenaBits) == + common_internal::kUniqueArenaUnownedBit) { + // We never registered the destructor, call it if necessary. + arena()->OwnDestructor(const_cast*>(ptr_)); + arena_ &= common_internal::kUniqueArenaPointerMask; + } + } + } + + void Release(T** ptr, Owner* owner) noexcept { + if (ptr_ == nullptr) { + *ptr = nullptr; + return; + } + PreRelease(); + *ptr = std::exchange(ptr_, nullptr); + if (arena_ == 0) { + owner->ptr_ = reinterpret_cast( + common_internal::MakeDeletingReferenceCount(*ptr)) | + common_internal::kMetadataOwnerReferenceCountBit; + } else { + owner->ptr_ = reinterpret_cast(arena()) | + common_internal::kMetadataOwnerArenaBit; + } + } + + T* ptr_ = nullptr; + // Potentially tagged pointer to `google::protobuf::Arena`. The tag is used to determine + // whether we still need to register the destructor with the `google::protobuf::Arena`. + uintptr_t arena_ = 0; +}; + +template +Unique(T*) -> Unique; + +template +Unique AllocateUnique(Allocator<> allocator, Args&&... args) { + using U = std::remove_cv_t; + static_assert(!std::is_reference_v, "T must not be a reference"); + static_assert(!std::is_array_v, "T must not be an array"); + + U* object; + google::protobuf::Arena* absl_nullable arena = allocator.arena(); + bool unowned; + if constexpr (google::protobuf::Arena::is_arena_constructable::value) { + object = google::protobuf::Arena::Create(arena, std::forward(args)...); + // For arena-compatible proto types, let the Arena::Create handle + // registering the destructor call. + // Otherwise, Unique retains a pointer to the owning arena so it may + // conditionally register T::~T depending on usage. + unowned = false; + } else { + void* p = allocator.allocate_bytes(sizeof(U), alignof(U)); + CEL_INTERNAL_TRY { + if constexpr (ArenaTraits<>::constructible()) { + object = ::new (p) U(arena, std::forward(args)...); + } else { + object = ::new (p) U(std::forward(args)...); + } + } + CEL_INTERNAL_CATCH_ANY { + allocator.deallocate_bytes(p, sizeof(U), alignof(U)); + CEL_INTERNAL_RETHROW; + } + unowned = + arena != nullptr && !ArenaTraits<>::trivially_destructible(*object); + } + return Unique(object, arena, unowned); +} + +template +std::enable_if_t, Unique> WrapUnique(T* object) { + return Unique(object); +} + +template +Unique WrapUnique(T* object, Allocator<> allocator) { + return Unique(object, allocator.arena()); +} + +template +inline bool operator==(const Unique& lhs, std::nullptr_t) { + return !static_cast(lhs); +} + +template +inline bool operator==(std::nullptr_t, const Unique& rhs) { + return !static_cast(rhs); +} + +template +inline bool operator!=(const Unique& lhs, std::nullptr_t) { + return static_cast(lhs); +} + +template +inline bool operator!=(std::nullptr_t, const Unique& rhs) { + return static_cast(rhs); +} + +} // namespace cel + +namespace std { + +template +struct pointer_traits> { + using pointer = cel::Unique; + using element_type = typename cel::Unique::element_type; + using difference_type = ptrdiff_t; + + template + using rebind = cel::Unique; + + static element_type* to_address(const pointer& p) noexcept { return p.ptr_; } +}; + +} // namespace std + +namespace cel { + +template +struct ArenaTraits> { + static bool trivially_destructible(const Unique& unique) { + return unique.arena_ != 0 && + (unique.arena_ & common_internal::kUniqueArenaBits) == 0; + } +}; + +// `Owned` points to an object which was allocated using `Allocator<>` or +// `Allocator`. It has co-ownership over the object. `T` must meet the named +// requirement `ArenaConstructable`. +template +class ABSL_ATTRIBUTE_TRIVIAL_ABI [[nodiscard]] Owned final { + public: + using element_type = T; + + static_assert(!std::is_array_v, "T must not be an array"); + static_assert(!std::is_reference_v, "T must not be a reference"); + static_assert(!std::is_volatile_v, "T must not be volatile qualified"); + static_assert(!std::is_void_v, "T must not be void"); + + Owned() = default; + Owned(const Owned&) = default; + Owned& operator=(const Owned&) = default; + + Owned(Owned&& other) noexcept + : Owned(std::exchange(other.value_, nullptr), std::move(other.owner_)) {} + + template >> + // NOLINTNEXTLINE(google-explicit-constructor) + Owned(const Owned& other) noexcept : Owned(other.value_, other.owner_) {} + + template >> + // NOLINTNEXTLINE(google-explicit-constructor) + Owned(Owned&& other) noexcept + : Owned(std::exchange(other.value_, nullptr), std::move(other.owner_)) {} + + template >> + explicit Owned(Borrowed other) noexcept; + + template >> + // NOLINTNEXTLINE(google-explicit-constructor) + Owned(Unique&& other) : Owned() { + other.Release(&value_, &owner_); + } + + Owned(Owner owner, T* value ABSL_ATTRIBUTE_LIFETIME_BOUND) noexcept + : Owned(value, std::move(owner)) {} + + // NOLINTNEXTLINE(google-explicit-constructor) + Owned(std::nullptr_t) noexcept : Owned() {} + + Owned& operator=(Owned&& other) noexcept { + if (ABSL_PREDICT_TRUE(this != &other)) { + value_ = std::exchange(other.value_, nullptr); + owner_ = std::move(other.owner_); + } + return *this; + } + + template >> + // NOLINTNEXTLINE(google-explicit-constructor) + Owned& operator=(const Owned& other) noexcept { + value_ = other.value_; + owner_ = other.owner_; + return *this; + } + + template >> + // NOLINTNEXTLINE(google-explicit-constructor) + Owned& operator=(Owned&& other) noexcept { + value_ = std::exchange(other.value_, nullptr); + owner_ = std::move(other.owner_); + return *this; + } + + template >> + // NOLINTNEXTLINE(google-explicit-constructor) + Owned& operator=(Borrowed other) noexcept; + + template >> + // NOLINTNEXTLINE(google-explicit-constructor) + Owned& operator=(Unique&& other) { + owner_.reset(); + other.Release(&value_, &owner_); + return *this; + } + + // NOLINTNEXTLINE(google-explicit-constructor) + Owned& operator=(std::nullptr_t) noexcept { + reset(); + return *this; + } + + T& operator*() const noexcept ABSL_ATTRIBUTE_LIFETIME_BOUND { + ABSL_DCHECK(static_cast(*this)); + return *get(); + } + + T* absl_nonnull operator->() const noexcept ABSL_ATTRIBUTE_LIFETIME_BOUND { + ABSL_DCHECK(static_cast(*this)); + return get(); + } + + void reset() noexcept { + value_ = nullptr; + owner_.reset(); + } + + google::protobuf::Arena* absl_nullable arena() const noexcept { return owner_.arena(); } + + explicit operator bool() const noexcept { return get() != nullptr; } + + friend void swap(Owned& lhs, Owned& rhs) noexcept { + using std::swap; + swap(lhs.value_, rhs.value_); + swap(lhs.owner_, rhs.owner_); + } + + private: + friend class Owner; + friend class Borrower; + template + friend class Owned; + template + friend class Borrowed; + template + friend struct Ownable; + template + friend Owned AllocateShared(Allocator<> allocator, Args&&... args); + template + friend Owned WrapShared(U* object, Allocator<> allocator); + template + friend Owned common_internal::WrapEternal(const U* value); + friend struct std::pointer_traits>; + friend struct ArenaTraits>; + + Owned(T* value, Owner owner) noexcept + : value_(value), owner_(std::move(owner)) {} + + T* get() const noexcept { return value_; } + + T* value_ = nullptr; + Owner owner_; +}; + +template +Owned(T*) -> Owned; +template +Owned(Unique) -> Owned; +template +Owned(Owner, T*) -> Owned; +template +Owned(Borrowed) -> Owned; + +} // namespace cel + +namespace std { + +template +struct pointer_traits> { + using pointer = cel::Owned; + using element_type = typename cel::Owned::element_type; + using difference_type = ptrdiff_t; + + template + using rebind = cel::Owned; + + static element_type* to_address(const pointer& p) noexcept { + return p.value_; + } +}; + +} // namespace std + +namespace cel { + +template +struct ArenaTraits> { + static bool trivially_destructible(const Owned& owned) { + return ArenaTraits<>::trivially_destructible(owned.owner_); + } +}; + +template +Owner::Owner(const Owned& owned) noexcept : Owner(owned.owner_) {} + +template +Owner::Owner(Owned&& owned) noexcept : Owner(std::move(owned.owner_)) { + owned.value_ = nullptr; +} + +template +Owner& Owner::operator=(const Owned& owned) noexcept { + *this = owned.owner_; + return *this; +} + +template +Owner& Owner::operator=(Owned&& owned) noexcept { + *this = std::move(owned.owner_); + owned.value_ = nullptr; + return *this; +} + +template +bool operator==(const Owned& lhs, std::nullptr_t) noexcept { + return !static_cast(lhs); +} + +template +bool operator==(std::nullptr_t, const Owned& rhs) noexcept { + return rhs == nullptr; +} + +template +bool operator!=(const Owned& lhs, std::nullptr_t) noexcept { + return !operator==(lhs, nullptr); +} + +template +bool operator!=(std::nullptr_t, const Owned& rhs) noexcept { + return !operator==(nullptr, rhs); +} + +template +Owned AllocateShared(Allocator<> allocator, Args&&... args) { + using U = std::remove_cv_t; + static_assert(!std::is_reference_v, "T must not be a reference"); + static_assert(!std::is_array_v, "T must not be an array"); + + U* object; + Owner owner; + if (google::protobuf::Arena* absl_nullable arena = allocator.arena(); + arena != nullptr) { + object = ArenaAllocator(arena).template new_object( + std::forward(args)...); + owner.ptr_ = reinterpret_cast(arena) | + common_internal::kMetadataOwnerArenaBit; + } else { + const common_internal::ReferenceCount* refcount; + std::tie(object, refcount) = common_internal::MakeEmplacedReferenceCount( + std::forward(args)...); + owner.ptr_ = reinterpret_cast(refcount) | + common_internal::kMetadataOwnerReferenceCountBit; + } + return Owned(object, std::move(owner)); +} + +template +Owned WrapShared(T* object, Allocator<> allocator) { + Owner owner; + if (object == nullptr) { + } else if (allocator.arena() != nullptr) { + owner.ptr_ = reinterpret_cast( + static_cast(allocator.arena())) | + common_internal::kMetadataOwnerArenaBit; + } else { + owner.ptr_ = reinterpret_cast( + common_internal::MakeDeletingReferenceCount(object)) | + common_internal::kMetadataOwnerReferenceCountBit; + } + return Owned(object, std::move(owner)); +} + +template +std::enable_if_t, Owned> WrapShared(T* object) { + return WrapShared(object, object->GetArena()); +} + +namespace common_internal { + +template +Owned WrapEternal(const T* value) { + return Owned(value, Owner::None()); +} + +} // namespace common_internal + +// `Borrowed` points to an object which was allocated using `Allocator<>` or +// `Allocator`. It has no ownership over the object, and is only valid so +// long as one or more owners of the object exist. `T` must meet the named +// requirement `ArenaConstructable`. +template +class Borrowed final { + public: + using element_type = T; + + static_assert(!std::is_array_v, "T must not be an array"); + static_assert(!std::is_reference_v, "T must not be a reference"); + static_assert(!std::is_volatile_v, "T must not be volatile qualified"); + static_assert(!std::is_void_v, "T must not be void"); + + Borrowed() = default; + Borrowed(const Borrowed&) = default; + Borrowed(Borrowed&&) = default; + Borrowed& operator=(const Borrowed&) = default; + Borrowed& operator=(Borrowed&&) = default; + + template >> + // NOLINTNEXTLINE(google-explicit-constructor) + Borrowed(const Borrowed& other) noexcept + : Borrowed(other.value_, other.borrower_) {} + + template >> + // NOLINTNEXTLINE(google-explicit-constructor) + Borrowed(Borrowed&& other) noexcept + : Borrowed(other.value_, other.borrower_) {} + + template >> + // NOLINTNEXTLINE(google-explicit-constructor) + Borrowed(const Owned& other ABSL_ATTRIBUTE_LIFETIME_BOUND) noexcept + : Borrowed(other.value_, other.owner_) {} + + Borrowed(Borrower borrower, T* ptr) noexcept : Borrowed(ptr, borrower) {} + + // NOLINTNEXTLINE(google-explicit-constructor) + Borrowed(std::nullptr_t) noexcept : Borrowed() {} + + template >> + // NOLINTNEXTLINE(google-explicit-constructor) + Borrowed& operator=(const Borrowed& other) noexcept { + value_ = other.value_; + borrower_ = other.borrower_; + return *this; + } + + template >> + // NOLINTNEXTLINE(google-explicit-constructor) + Borrowed& operator=(Borrowed&& other) noexcept { + value_ = other.value_; + borrower_ = other.borrower_; + return *this; + } + + template >> + // NOLINTNEXTLINE(google-explicit-constructor) + Borrowed& operator=( + const Owned& other ABSL_ATTRIBUTE_LIFETIME_BOUND) noexcept { + value_ = other.value_; + borrower_ = other.borrower_; + return *this; + } + + template >> + // NOLINTNEXTLINE(google-explicit-constructor) + Borrowed& operator=(Owned&&) = delete; + + // NOLINTNEXTLINE(google-explicit-constructor) + Borrowed& operator=(std::nullptr_t) noexcept { + reset(); + return *this; + } + + T& operator*() const noexcept { + ABSL_DCHECK(static_cast(*this)); + return *get(); + } + + T* absl_nonnull operator->() const noexcept { + ABSL_DCHECK(static_cast(*this)); + return get(); + } + + void reset() noexcept { + value_ = nullptr; + borrower_.reset(); + } + + google::protobuf::Arena* absl_nullable arena() const noexcept { + return borrower_.arena(); + } + + explicit operator bool() const noexcept { return get() != nullptr; } + + private: + friend class Owner; + friend class Borrower; + template + friend class Owned; + template + friend class Borrowed; + template + friend struct Borrowable; + friend struct std::pointer_traits>; + + constexpr Borrowed(T* value, Borrower borrower) noexcept + : value_(value), borrower_(borrower) {} + + T* get() const noexcept { return value_; } + + T* value_ = nullptr; + Borrower borrower_; +}; + +template +Borrowed(T*) -> Borrowed; +template +Borrowed(Borrower, T*) -> Borrowed; +template +Borrowed(Owned) -> Borrowed; + +} // namespace cel + +namespace std { + +template +struct pointer_traits> { + using pointer = cel::Borrowed; + using element_type = typename cel::Borrowed::element_type; + using difference_type = ptrdiff_t; + + template + using rebind = cel::Borrowed; + + static element_type* to_address(pointer p) noexcept { return p.value_; } +}; + +} // namespace std + +namespace cel { + +template +Owner::Owner(Borrowed borrowed) noexcept : Owner(borrowed.borrower_) {} + +template +Borrower::Borrower(const Owned& owned ABSL_ATTRIBUTE_LIFETIME_BOUND) noexcept + : Borrower(owned.owner_) {} + +template +Borrower::Borrower(Borrowed borrowed) noexcept + : Borrower(borrowed.borrower_) {} + +template +Borrower& Borrower::operator=( + const Owned& owned ABSL_ATTRIBUTE_LIFETIME_BOUND) noexcept { + *this = owned.owner_; + return *this; +} + +template +Borrower& Borrower::operator=(Borrowed borrowed) noexcept { + *this = borrowed.borrower_; + return *this; +} + +template +bool operator==(Borrowed lhs, std::nullptr_t) noexcept { + return !static_cast(lhs); +} + +template +bool operator==(std::nullptr_t, Borrowed rhs) noexcept { + return rhs == nullptr; +} + +template +bool operator!=(Borrowed lhs, std::nullptr_t) noexcept { + return !operator==(lhs, nullptr); +} + +template +bool operator!=(std::nullptr_t, Borrowed rhs) noexcept { + return !operator==(nullptr, rhs); +} + +template +template +Owned::Owned(Borrowed other) noexcept + : Owned(other.value_, Owner(other.borrower_)) {} + +template +template +Owned& Owned::operator=(Borrowed other) noexcept { + value_ = other.value_; + owner_ = Owner(other.borrower_); + return *this; +} + +// `Ownable` is a mixin for enabling the ability to get `Owned` that refer to +// this. +template +struct Ownable { + protected: + Owned Own() const noexcept { + static_assert(std::is_base_of_v, "T must be derived from Data"); + const T* const that = static_cast(this); + return Owned( + Owner(Owner::Own(static_cast(that)->owner_)), that); + } + + Owned Own() noexcept { + static_assert(std::is_base_of_v, "T must be derived from Data"); + T* const that = static_cast(this); + return Owned(Owner(Owner::Own(static_cast(that)->owner_)), that); + } + + ABSL_DEPRECATED("Use Own") + Owned shared_from_this() const noexcept { return Own(); } + + ABSL_DEPRECATED("Use Own") + Owned shared_from_this() noexcept { return Own(); } +}; + +// `Borrowable` is a mixin for enabling the ability to get `Borrowed` that +// refer to this. +template +struct Borrowable { + protected: + Borrowed Borrow() const noexcept { + static_assert(std::is_base_of_v, "T must be derived from Data"); + const T* const that = static_cast(this); + return Borrowed(Borrower(static_cast(that)->owner_), + that); + } + + Borrowed Borrow() noexcept { + static_assert(std::is_base_of_v, "T must be derived from Data"); + T* const that = static_cast(this); + return Borrowed(Borrower(static_cast(that)->owner_), that); + } +}; + +// `ReferenceCountingMemoryManager` is a `MemoryManager` which employs automatic +// memory management through reference counting. +class ReferenceCountingMemoryManager final { + public: + ReferenceCountingMemoryManager(const ReferenceCountingMemoryManager&) = + delete; + ReferenceCountingMemoryManager(ReferenceCountingMemoryManager&&) = delete; + ReferenceCountingMemoryManager& operator=( + const ReferenceCountingMemoryManager&) = delete; + ReferenceCountingMemoryManager& operator=(ReferenceCountingMemoryManager&&) = + delete; + + private: + static void* Allocate(size_t size, size_t alignment); + + static bool Deallocate(void* ptr, size_t size, size_t alignment) noexcept; + + explicit ReferenceCountingMemoryManager() = default; + + friend class MemoryManager; +}; + +// `PoolingMemoryManager` is a `MemoryManager` which employs automatic +// memory management through memory pooling. +class PoolingMemoryManager final { + public: + PoolingMemoryManager(const PoolingMemoryManager&) = delete; + PoolingMemoryManager(PoolingMemoryManager&&) = delete; + PoolingMemoryManager& operator=(const PoolingMemoryManager&) = delete; + PoolingMemoryManager& operator=(PoolingMemoryManager&&) = delete; + + private: + // Allocates memory directly from the allocator used by this memory manager. + // If `memory_management()` returns `MemoryManagement::kReferenceCounting`, + // this allocation *must* be explicitly deallocated at some point via + // `Deallocate`. Otherwise deallocation is optional. + ABSL_MUST_USE_RESULT static void* Allocate(google::protobuf::Arena* absl_nonnull arena, + size_t size, size_t alignment) { + ABSL_DCHECK(absl::has_single_bit(alignment)) + << "alignment must be a power of 2"; + if (size == 0) { + return nullptr; + } + return arena->AllocateAligned(size, alignment); + } + + // Attempts to deallocate memory previously allocated via `Allocate`, `size` + // and `alignment` must match the values from the previous call to `Allocate`. + // Returns `true` if the deallocation was successful and additional calls to + // `Allocate` may re-use the memory, `false` otherwise. Returns `false` if + // given `nullptr`. + static bool Deallocate(google::protobuf::Arena* absl_nonnull, void*, size_t, + size_t alignment) noexcept { + ABSL_DCHECK(absl::has_single_bit(alignment)) + << "alignment must be a power of 2"; + return false; + } + + // Registers a custom destructor to be run upon destruction of the memory + // management implementation. Return value is always `true`, indicating that + // the destructor may be called at some point in the future. + static bool OwnCustomDestructor(google::protobuf::Arena* absl_nonnull arena, + void* object, + void (*absl_nonnull destruct)(void*)) { + ABSL_DCHECK(destruct != nullptr); + arena->OwnCustomDestructor(object, destruct); + return true; + } + + template + static void DefaultDestructor(void* ptr) { + static_assert(!std::is_trivially_destructible_v); + static_cast(ptr)->~T(); + } + + explicit PoolingMemoryManager() = default; + + friend class MemoryManager; +}; + +// `MemoryManager` is an abstraction for supporting automatic memory management. +// All objects created by the `MemoryManager` have a lifetime governed by the +// underlying memory management strategy. Currently `MemoryManager` is a +// composed type that holds either a reference to +// `ReferenceCountingMemoryManager` or owns a `PoolingMemoryManager`. +// +// ============================ Reference Counting ============================ +// `Unique`: The object is valid until destruction of the `Unique`. +// +// `Shared`: The object is valid so long as one or more `Shared` managing the +// object exist. +// +// ================================= Pooling ================================== +// `Unique`: The object is valid until destruction of the underlying memory +// resources or of the `Unique`. +// +// `Shared`: The object is valid until destruction of the underlying memory +// resources. +class MemoryManager final { + public: + // Returns a `MemoryManager` which utilizes an arena but never frees its + // memory. It is effectively a memory leak and should only be used for limited + // use cases, such as initializing singletons which live for the life of the + // program. + static MemoryManager Unmanaged(); + + // Returns a `MemoryManager` which utilizes reference counting. + ABSL_MUST_USE_RESULT static MemoryManager ReferenceCounting() { + return MemoryManager(nullptr); + } + + // Returns a `MemoryManager` which utilizes an arena. + ABSL_MUST_USE_RESULT static MemoryManager Pooling( + google::protobuf::Arena* absl_nonnull arena) { + return MemoryManager(arena); + } + + explicit MemoryManager(Allocator<> allocator) : arena_(allocator.arena()) {} + + MemoryManager() = delete; + MemoryManager(const MemoryManager&) = default; + MemoryManager& operator=(const MemoryManager&) = default; + + MemoryManagement memory_management() const noexcept { + return arena_ == nullptr ? MemoryManagement::kReferenceCounting + : MemoryManagement::kPooling; + } + + // Allocates memory directly from the allocator used by this memory manager. + // If `memory_management()` returns `MemoryManagement::kReferenceCounting`, + // this allocation *must* be explicitly deallocated at some point via + // `Deallocate`. Otherwise deallocation is optional. + ABSL_MUST_USE_RESULT void* Allocate(size_t size, size_t alignment) { + if (arena_ == nullptr) { + return ReferenceCountingMemoryManager::Allocate(size, alignment); + } else { + return PoolingMemoryManager::Allocate(arena_, size, alignment); + } + } + + // Attempts to deallocate memory previously allocated via `Allocate`, `size` + // and `alignment` must match the values from the previous call to `Allocate`. + // Returns `true` if the deallocation was successful and additional calls to + // `Allocate` may re-use the memory, `false` otherwise. Returns `false` if + // given `nullptr`. + bool Deallocate(void* ptr, size_t size, size_t alignment) noexcept { + if (arena_ == nullptr) { + return ReferenceCountingMemoryManager::Deallocate(ptr, size, alignment); + } else { + return PoolingMemoryManager::Deallocate(arena_, ptr, size, alignment); + } + } + + // Registers a custom destructor to be run upon destruction of the memory + // management implementation. A return of `true` indicates the destructor may + // be called at some point in the future, `false` if will definitely not be + // called. All pooling memory managers return `true` while the reference + // counting memory manager returns `false`. + bool OwnCustomDestructor(void* object, void (*absl_nonnull destruct)(void*)) { + ABSL_DCHECK(destruct != nullptr); + if (arena_ == nullptr) { + return false; + } else { + return PoolingMemoryManager::OwnCustomDestructor(arena_, object, + destruct); + } + } + + google::protobuf::Arena* absl_nullable arena() const noexcept { return arena_; } + + template + // NOLINTNEXTLINE(google-explicit-constructor) + operator Allocator() const { + return arena(); + } + + friend void swap(MemoryManager& lhs, MemoryManager& rhs) noexcept { + using std::swap; + swap(lhs.arena_, rhs.arena_); + } + + private: + friend class PoolingMemoryManager; + + explicit MemoryManager(std::nullptr_t) : arena_(nullptr) {} + + explicit MemoryManager(google::protobuf::Arena* absl_nonnull arena) : arena_(arena) {} + + // If `nullptr`, we are using reference counting. Otherwise we are using + // Pooling. We use `UnreachablePooling()` as a sentinel to detect use after + // move otherwise the moved-from `MemoryManager` would be in a valid state and + // utilize reference counting. + google::protobuf::Arena* absl_nullable arena_; +}; + +using MemoryManagerRef = MemoryManager; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_MEMORY_H_ diff --git a/common/memory_test.cc b/common/memory_test.cc new file mode 100644 index 000000000..7f3e7a82a --- /dev/null +++ b/common/memory_test.cc @@ -0,0 +1,466 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This header contains primitives for reference counting, roughly equivalent to +// the primitives used to implement `std::shared_ptr`. These primitives should +// not be used directly in most cases, instead `cel::ManagedMemory` should be +// used instead. + +#include "common/memory.h" + +#include + +#include "google/protobuf/struct.pb.h" +#include "absl/base/nullability.h" +#include "common/allocator.h" +#include "common/data.h" +#include "common/internal/reference_count.h" +#include "internal/testing.h" +#include "google/protobuf/arena.h" + +#ifdef ABSL_HAVE_EXCEPTIONS +#include +#endif + +namespace cel { +namespace { + +using ::testing::IsFalse; +using ::testing::IsNull; +using ::testing::IsTrue; +using ::testing::TestParamInfo; +using ::testing::TestWithParam; + +TEST(Owner, None) { + EXPECT_THAT(Owner::None(), IsFalse()); + EXPECT_THAT(Owner::None().arena(), IsNull()); +} + +TEST(Owner, Allocator) { + google::protobuf::Arena arena; + EXPECT_THAT(Owner::Allocator(NewDeleteAllocator<>{}), IsFalse()); + EXPECT_THAT(Owner::Allocator(ArenaAllocator<>{&arena}), IsTrue()); +} + +TEST(Owner, Arena) { + google::protobuf::Arena arena; + EXPECT_THAT(Owner::Arena(&arena), IsTrue()); + EXPECT_EQ(Owner::Arena(&arena).arena(), &arena); +} + +TEST(Owner, ReferenceCount) { + auto* refcount = new common_internal::ReferenceCounted(); + EXPECT_THAT(Owner::ReferenceCount(refcount), IsTrue()); + EXPECT_THAT(Owner::ReferenceCount(refcount).arena(), IsNull()); + common_internal::StrongUnref(refcount); +} + +TEST(Owner, Equality) { + google::protobuf::Arena arena1; + google::protobuf::Arena arena2; + EXPECT_EQ(Owner::None(), Owner::None()); + EXPECT_EQ(Owner::Allocator(NewDeleteAllocator<>{}), Owner::None()); + EXPECT_EQ(Owner::Arena(&arena1), Owner::Arena(&arena1)); + EXPECT_NE(Owner::Arena(&arena1), Owner::None()); + EXPECT_NE(Owner::None(), Owner::Arena(&arena1)); + EXPECT_NE(Owner::Arena(&arena1), Owner::Arena(&arena2)); + EXPECT_EQ(Owner::Allocator(ArenaAllocator<>{&arena1}), Owner::Arena(&arena1)); +} + +TEST(Borrower, None) { + EXPECT_THAT(Borrower::None(), IsFalse()); + EXPECT_THAT(Borrower::None().arena(), IsNull()); +} + +TEST(Borrower, Allocator) { + google::protobuf::Arena arena; + EXPECT_THAT(Borrower::Allocator(NewDeleteAllocator<>{}), IsFalse()); + EXPECT_THAT(Borrower::Allocator(ArenaAllocator<>{&arena}), IsTrue()); +} + +TEST(Borrower, Arena) { + google::protobuf::Arena arena; + EXPECT_THAT(Borrower::Arena(&arena), IsTrue()); + EXPECT_EQ(Borrower::Arena(&arena).arena(), &arena); +} + +TEST(Borrower, ReferenceCount) { + auto* refcount = new common_internal::ReferenceCounted(); + EXPECT_THAT(Borrower::ReferenceCount(refcount), IsTrue()); + EXPECT_THAT(Borrower::ReferenceCount(refcount).arena(), IsNull()); + common_internal::StrongUnref(refcount); +} + +TEST(Borrower, Equality) { + google::protobuf::Arena arena1; + google::protobuf::Arena arena2; + EXPECT_EQ(Borrower::None(), Borrower::None()); + EXPECT_EQ(Borrower::Allocator(NewDeleteAllocator<>{}), Borrower::None()); + EXPECT_EQ(Borrower::Arena(&arena1), Borrower::Arena(&arena1)); + EXPECT_NE(Borrower::Arena(&arena1), Borrower::None()); + EXPECT_NE(Borrower::None(), Borrower::Arena(&arena1)); + EXPECT_NE(Borrower::Arena(&arena1), Borrower::Arena(&arena2)); + EXPECT_EQ(Borrower::Allocator(ArenaAllocator<>{&arena1}), + Borrower::Arena(&arena1)); +} + +TEST(OwnerBorrower, CopyConstruct) { + auto* refcount = new common_internal::ReferenceCounted(); + Owner owner1 = Owner::ReferenceCount(refcount); + common_internal::StrongUnref(refcount); + Owner owner2(owner1); + Borrower borrower(owner1); + EXPECT_EQ(owner1, owner2); + EXPECT_EQ(owner1, borrower); + EXPECT_EQ(borrower, owner1); +} + +TEST(OwnerBorrower, MoveConstruct) { + auto* refcount = new common_internal::ReferenceCounted(); + Owner owner1 = Owner::ReferenceCount(refcount); + common_internal::StrongUnref(refcount); + Owner owner2(std::move(owner1)); + Borrower borrower(owner2); + EXPECT_EQ(owner2, borrower); + EXPECT_EQ(borrower, owner2); +} + +TEST(OwnerBorrower, CopyAssign) { + auto* refcount = new common_internal::ReferenceCounted(); + Owner owner1 = Owner::ReferenceCount(refcount); + common_internal::StrongUnref(refcount); + Owner owner2; + owner2 = owner1; + Borrower borrower(owner1); + EXPECT_EQ(owner1, owner2); + EXPECT_EQ(owner1, borrower); + EXPECT_EQ(borrower, owner1); +} + +TEST(OwnerBorrower, MoveAssign) { + auto* refcount = new common_internal::ReferenceCounted(); + Owner owner1 = Owner::ReferenceCount(refcount); + common_internal::StrongUnref(refcount); + Owner owner2; + owner2 = std::move(owner1); + Borrower borrower(owner2); + EXPECT_EQ(owner2, borrower); + EXPECT_EQ(borrower, owner2); +} + +TEST(Unique, ToAddress) { + Unique unique; + EXPECT_EQ(cel::to_address(unique), nullptr); + unique = AllocateUnique(NewDeleteAllocator<>{}); + EXPECT_EQ(cel::to_address(unique), unique.operator->()); +} + +class OwnedTest : public TestWithParam { + public: + Allocator<> GetAllocator() { + switch (GetParam()) { + case AllocatorKind::kArena: + return ArenaAllocator<>{&arena_}; + case AllocatorKind::kNewDelete: + return NewDeleteAllocator<>{}; + } + } + + private: + google::protobuf::Arena arena_; +}; + +TEST_P(OwnedTest, Default) { + Owned owned; + EXPECT_FALSE(owned); + EXPECT_EQ(cel::to_address(owned), nullptr); + EXPECT_FALSE(owned != nullptr); + EXPECT_FALSE(nullptr != owned); +} + +class TestData final : public Data { + public: + using InternalArenaConstructable_ = void; + using DestructorSkippable_ = void; + + TestData() noexcept : Data() {} + + explicit TestData(google::protobuf::Arena* absl_nullable arena) noexcept + : Data(arena) {} +}; + +TEST_P(OwnedTest, AllocateSharedData) { + auto owned = AllocateShared(GetAllocator()); + EXPECT_EQ(owned->GetArena(), GetAllocator().arena()); + EXPECT_EQ(Owner(owned).arena(), GetAllocator().arena()); + EXPECT_EQ(Borrower(owned).arena(), GetAllocator().arena()); +} + +TEST_P(OwnedTest, AllocateSharedMessageLite) { + auto owned = AllocateShared(GetAllocator()); + EXPECT_EQ(owned->GetArena(), GetAllocator().arena()); + EXPECT_EQ(Owner(owned).arena(), GetAllocator().arena()); + EXPECT_EQ(Borrower(owned).arena(), GetAllocator().arena()); +} + +TEST_P(OwnedTest, WrapSharedData) { + auto owned = + WrapShared(google::protobuf::Arena::Create(GetAllocator().arena())); + EXPECT_EQ(owned->GetArena(), GetAllocator().arena()); + EXPECT_EQ(Owner(owned).arena(), GetAllocator().arena()); + EXPECT_EQ(Borrower(owned).arena(), GetAllocator().arena()); +} + +TEST_P(OwnedTest, WrapSharedMessageLite) { + auto owned = WrapShared( + google::protobuf::Arena::Create(GetAllocator().arena())); + EXPECT_EQ(owned->GetArena(), GetAllocator().arena()); + EXPECT_EQ(Owner(owned).arena(), GetAllocator().arena()); + EXPECT_EQ(Borrower(owned).arena(), GetAllocator().arena()); +} + +TEST_P(OwnedTest, SharedFromUniqueData) { + auto owned = Owned(AllocateUnique(GetAllocator())); + EXPECT_EQ(owned->GetArena(), GetAllocator().arena()); + EXPECT_EQ(Owner(owned).arena(), GetAllocator().arena()); + EXPECT_EQ(Borrower(owned).arena(), GetAllocator().arena()); +} + +TEST_P(OwnedTest, SharedFromUniqueMessageLite) { + auto owned = Owned(AllocateUnique(GetAllocator())); + EXPECT_EQ(owned->GetArena(), GetAllocator().arena()); + EXPECT_EQ(Owner(owned).arena(), GetAllocator().arena()); + EXPECT_EQ(Borrower(owned).arena(), GetAllocator().arena()); +} + +TEST_P(OwnedTest, CopyConstruct) { + auto owned = Owned(AllocateUnique(GetAllocator())); + EXPECT_EQ(owned->GetArena(), GetAllocator().arena()); + Owned copied_owned(owned); + EXPECT_EQ(copied_owned->GetArena(), GetAllocator().arena()); +} + +TEST_P(OwnedTest, MoveConstruct) { + auto owned = Owned(AllocateUnique(GetAllocator())); + EXPECT_EQ(owned->GetArena(), GetAllocator().arena()); + Owned moved_owned(std::move(owned)); + EXPECT_EQ(moved_owned->GetArena(), GetAllocator().arena()); +} + +TEST_P(OwnedTest, CopyConstructOther) { + auto owned = Owned(AllocateUnique(GetAllocator())); + EXPECT_EQ(owned->GetArena(), GetAllocator().arena()); + Owned copied_owned(owned); + EXPECT_EQ(copied_owned->GetArena(), GetAllocator().arena()); +} + +TEST_P(OwnedTest, MoveConstructOther) { + auto owned = Owned(AllocateUnique(GetAllocator())); + EXPECT_EQ(owned->GetArena(), GetAllocator().arena()); + Owned moved_owned(std::move(owned)); + EXPECT_EQ(moved_owned->GetArena(), GetAllocator().arena()); +} + +TEST_P(OwnedTest, ConstructBorrowed) { + auto owned = Owned(AllocateUnique(GetAllocator())); + EXPECT_EQ(owned->GetArena(), GetAllocator().arena()); + Owned borrowed_owned(Borrowed{owned}); + EXPECT_EQ(borrowed_owned->GetArena(), GetAllocator().arena()); +} + +TEST_P(OwnedTest, ConstructOwner) { + auto owned = Owned(AllocateUnique(GetAllocator())); + EXPECT_EQ(owned->GetArena(), GetAllocator().arena()); + Owned owner_owned(Owner(owned), cel::to_address(owned)); + EXPECT_EQ(owner_owned->GetArena(), GetAllocator().arena()); +} + +TEST_P(OwnedTest, ConstructNullPtr) { + Owned owned(nullptr); + EXPECT_EQ(owned, nullptr); +} + +TEST_P(OwnedTest, CopyAssign) { + auto owned = Owned(AllocateUnique(GetAllocator())); + EXPECT_EQ(owned->GetArena(), GetAllocator().arena()); + Owned copied_owned; + copied_owned = owned; + EXPECT_EQ(copied_owned->GetArena(), GetAllocator().arena()); +} + +TEST_P(OwnedTest, MoveAssign) { + auto owned = Owned(AllocateUnique(GetAllocator())); + EXPECT_EQ(owned->GetArena(), GetAllocator().arena()); + Owned moved_owned; + moved_owned = std::move(owned); + EXPECT_EQ(moved_owned->GetArena(), GetAllocator().arena()); +} + +TEST_P(OwnedTest, CopyAssignOther) { + auto owned = Owned(AllocateUnique(GetAllocator())); + EXPECT_EQ(owned->GetArena(), GetAllocator().arena()); + Owned copied_owned; + copied_owned = owned; + EXPECT_EQ(copied_owned->GetArena(), GetAllocator().arena()); +} + +TEST_P(OwnedTest, MoveAssignOther) { + auto owned = Owned(AllocateUnique(GetAllocator())); + EXPECT_EQ(owned->GetArena(), GetAllocator().arena()); + Owned moved_owned; + moved_owned = std::move(owned); + EXPECT_EQ(moved_owned->GetArena(), GetAllocator().arena()); +} + +TEST_P(OwnedTest, AssignBorrowed) { + auto owned = Owned(AllocateUnique(GetAllocator())); + EXPECT_EQ(owned->GetArena(), GetAllocator().arena()); + Owned borrowed_owned; + borrowed_owned = Borrowed{owned}; + EXPECT_EQ(borrowed_owned->GetArena(), GetAllocator().arena()); +} + +TEST_P(OwnedTest, AssignUnique) { + Owned owned; + owned = AllocateUnique(GetAllocator()); + EXPECT_EQ(owned->GetArena(), GetAllocator().arena()); +} + +TEST_P(OwnedTest, AssignNullPtr) { + auto owned = Owned(AllocateUnique(GetAllocator())); + EXPECT_EQ(owned->GetArena(), GetAllocator().arena()); + EXPECT_TRUE(owned); + owned = nullptr; + EXPECT_FALSE(owned); +} + +INSTANTIATE_TEST_SUITE_P(OwnedTest, OwnedTest, + ::testing::Values(AllocatorKind::kArena, + AllocatorKind::kNewDelete)); + +class BorrowedTest : public TestWithParam { + public: + Allocator<> GetAllocator() { + switch (GetParam()) { + case AllocatorKind::kArena: + return ArenaAllocator<>{&arena_}; + case AllocatorKind::kNewDelete: + return NewDeleteAllocator<>{}; + } + } + + private: + google::protobuf::Arena arena_; +}; + +TEST_P(BorrowedTest, Default) { + Borrowed borrowed; + EXPECT_FALSE(borrowed); + EXPECT_EQ(cel::to_address(borrowed), nullptr); + EXPECT_FALSE(borrowed != nullptr); + EXPECT_FALSE(nullptr != borrowed); +} + +TEST_P(BorrowedTest, CopyConstruct) { + auto owned = Owned(AllocateUnique(GetAllocator())); + auto borrowed = Borrowed(owned); + EXPECT_EQ(borrowed->GetArena(), GetAllocator().arena()); + Borrowed copied_borrowed(borrowed); + EXPECT_EQ(copied_borrowed->GetArena(), GetAllocator().arena()); +} + +TEST_P(BorrowedTest, MoveConstruct) { + auto owned = Owned(AllocateUnique(GetAllocator())); + auto borrowed = Borrowed(owned); + EXPECT_EQ(borrowed->GetArena(), GetAllocator().arena()); + Borrowed moved_borrowed(std::move(borrowed)); + EXPECT_EQ(moved_borrowed->GetArena(), GetAllocator().arena()); +} + +TEST_P(BorrowedTest, CopyConstructOther) { + auto owned = Owned(AllocateUnique(GetAllocator())); + auto borrowed = Borrowed(owned); + EXPECT_EQ(borrowed->GetArena(), GetAllocator().arena()); + Borrowed copied_borrowed(borrowed); + EXPECT_EQ(copied_borrowed->GetArena(), GetAllocator().arena()); +} + +TEST_P(BorrowedTest, MoveConstructOther) { + auto owned = Owned(AllocateUnique(GetAllocator())); + auto borrowed = Borrowed(owned); + EXPECT_EQ(borrowed->GetArena(), GetAllocator().arena()); + Borrowed moved_borrowed(std::move(borrowed)); + EXPECT_EQ(moved_borrowed->GetArena(), GetAllocator().arena()); +} + +TEST_P(BorrowedTest, ConstructNullPtr) { + Borrowed borrowed(nullptr); + EXPECT_FALSE(borrowed); +} + +TEST_P(BorrowedTest, CopyAssign) { + auto owned = Owned(AllocateUnique(GetAllocator())); + auto borrowed = Borrowed(owned); + EXPECT_EQ(borrowed->GetArena(), GetAllocator().arena()); + Borrowed copied_borrowed; + copied_borrowed = borrowed; + EXPECT_EQ(copied_borrowed->GetArena(), GetAllocator().arena()); +} + +TEST_P(BorrowedTest, MoveAssign) { + auto owned = Owned(AllocateUnique(GetAllocator())); + auto borrowed = Borrowed(owned); + EXPECT_EQ(borrowed->GetArena(), GetAllocator().arena()); + Borrowed moved_borrowed; + moved_borrowed = std::move(borrowed); + EXPECT_EQ(moved_borrowed->GetArena(), GetAllocator().arena()); +} + +TEST_P(BorrowedTest, CopyAssignOther) { + auto owned = Owned(AllocateUnique(GetAllocator())); + auto borrowed = Borrowed(owned); + EXPECT_EQ(borrowed->GetArena(), GetAllocator().arena()); + Borrowed copied_borrowed; + copied_borrowed = borrowed; + EXPECT_EQ(copied_borrowed->GetArena(), GetAllocator().arena()); +} + +TEST_P(BorrowedTest, MoveAssignOther) { + auto owned = Owned(AllocateUnique(GetAllocator())); + auto borrowed = Borrowed(owned); + EXPECT_EQ(borrowed->GetArena(), GetAllocator().arena()); + Borrowed moved_borrowed; + moved_borrowed = std::move(borrowed); + EXPECT_EQ(moved_borrowed->GetArena(), GetAllocator().arena()); +} + +TEST_P(BorrowedTest, AssignOwned) { + auto owned = Owned(AllocateUnique(GetAllocator())); + EXPECT_EQ(owned->GetArena(), GetAllocator().arena()); + Borrowed borrowed = owned; + EXPECT_EQ(borrowed->GetArena(), GetAllocator().arena()); +} + +TEST_P(BorrowedTest, AssignNullPtr) { + Borrowed borrowed; + borrowed = nullptr; + EXPECT_FALSE(borrowed); +} + +INSTANTIATE_TEST_SUITE_P(BorrowedTest, BorrowedTest, + ::testing::Values(AllocatorKind::kArena, + AllocatorKind::kNewDelete)); + +} // namespace +} // namespace cel diff --git a/common/memory_testing.h b/common/memory_testing.h new file mode 100644 index 000000000..37244dd8f --- /dev/null +++ b/common/memory_testing.h @@ -0,0 +1,71 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_MEMORY_TESTING_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_MEMORY_TESTING_H_ + +#include +#include + +#include "absl/strings/str_join.h" +#include "absl/types/optional.h" +#include "common/memory.h" +#include "internal/testing.h" +#include "google/protobuf/arena.h" + +namespace cel::common_internal { + +template +class ThreadCompatibleMemoryTest + : public ::testing::TestWithParam> { + public: + void SetUp() override {} + + void TearDown() override { Finish(); } + + MemoryManagement memory_management() { return std::get<0>(this->GetParam()); } + + MemoryManagerRef memory_manager() { + switch (memory_management()) { + case MemoryManagement::kReferenceCounting: + return MemoryManager::ReferenceCounting(); + break; + case MemoryManagement::kPooling: + if (!arena_) { + arena_.emplace(); + } + return MemoryManager::Pooling(&*arena_); + break; + } + } + + void Finish() { arena_.reset(); } + + static std::string ToString( + ::testing::TestParamInfo> param) { + return absl::StrJoin(param.param, "_", absl::StreamFormatter()); + } + + protected: + virtual MemoryManager NewThreadCompatiblePoolingMemoryManager() { + return MemoryManager::Pooling(&*arena_); + } + + private: + absl::optional arena_; +}; + +} // namespace cel::common_internal + +#endif // THIRD_PARTY_CEL_CPP_COMMON_MEMORY_TESTING_H_ diff --git a/common/minimal_descriptor_database.cc b/common/minimal_descriptor_database.cc new file mode 100644 index 000000000..20c9bf6b1 --- /dev/null +++ b/common/minimal_descriptor_database.cc @@ -0,0 +1,27 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/minimal_descriptor_database.h" + +#include "absl/base/nullability.h" +#include "internal/minimal_descriptor_database.h" +#include "google/protobuf/descriptor_database.h" + +namespace cel { + +google::protobuf::DescriptorDatabase* absl_nonnull GetMinimalDescriptorDatabase() { + return internal::GetMinimalDescriptorDatabase(); +} + +} // namespace cel diff --git a/common/minimal_descriptor_database.h b/common/minimal_descriptor_database.h new file mode 100644 index 000000000..ba0dbc3b7 --- /dev/null +++ b/common/minimal_descriptor_database.h @@ -0,0 +1,32 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_MINIMAL_DESCRIPTOR_DATABASE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_MINIMAL_DESCRIPTOR_DATABASE_H_ + +#include "absl/base/nullability.h" +#include "google/protobuf/descriptor_database.h" + +namespace cel { + +// GetMinimalDescriptorDatabase returns a pointer to a +// `google::protobuf::DescriptorDatabase` which includes has the minimally necessary +// descriptors required by the Common Expression Language. The returned +// `google::protobuf::DescriptorDatabase` is valid for the lifetime of the process and +// should not be deleted. +google::protobuf::DescriptorDatabase* absl_nonnull GetMinimalDescriptorDatabase(); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_MINIMAL_DESCRIPTOR_DATABASE_H_ diff --git a/common/minimal_descriptor_database_test.cc b/common/minimal_descriptor_database_test.cc new file mode 100644 index 000000000..e91d73cf6 --- /dev/null +++ b/common/minimal_descriptor_database_test.cc @@ -0,0 +1,139 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/minimal_descriptor_database.h" + +#include "google/protobuf/descriptor.pb.h" +#include "internal/testing.h" +#include "google/protobuf/descriptor.h" + +namespace cel { +namespace { + +using ::testing::IsTrue; + +TEST(GetMinimalDescriptorDatabase, NullValue) { + google::protobuf::FileDescriptorProto fd; + EXPECT_THAT(GetMinimalDescriptorDatabase()->FindFileContainingSymbol( + "google.protobuf.NullValue", &fd), + IsTrue()); +} + +TEST(GetMinimalDescriptorDatabase, BoolValue) { + google::protobuf::FileDescriptorProto fd; + EXPECT_THAT(GetMinimalDescriptorDatabase()->FindFileContainingSymbol( + "google.protobuf.BoolValue", &fd), + IsTrue()); +} + +TEST(GetMinimalDescriptorDatabase, Int32Value) { + google::protobuf::FileDescriptorProto fd; + EXPECT_THAT(GetMinimalDescriptorDatabase()->FindFileContainingSymbol( + "google.protobuf.Int32Value", &fd), + IsTrue()); +} + +TEST(GetMinimalDescriptorDatabase, Int64Value) { + google::protobuf::FileDescriptorProto fd; + EXPECT_THAT(GetMinimalDescriptorDatabase()->FindFileContainingSymbol( + "google.protobuf.Int64Value", &fd), + IsTrue()); +} + +TEST(GetMinimalDescriptorDatabase, UInt32Value) { + google::protobuf::FileDescriptorProto fd; + EXPECT_THAT(GetMinimalDescriptorDatabase()->FindFileContainingSymbol( + "google.protobuf.UInt32Value", &fd), + IsTrue()); +} + +TEST(GetMinimalDescriptorDatabase, UInt64Value) { + google::protobuf::FileDescriptorProto fd; + EXPECT_THAT(GetMinimalDescriptorDatabase()->FindFileContainingSymbol( + "google.protobuf.UInt64Value", &fd), + IsTrue()); +} + +TEST(GetMinimalDescriptorDatabase, FloatValue) { + google::protobuf::FileDescriptorProto fd; + EXPECT_THAT(GetMinimalDescriptorDatabase()->FindFileContainingSymbol( + "google.protobuf.FloatValue", &fd), + IsTrue()); +} + +TEST(GetMinimalDescriptorDatabase, DoubleValue) { + google::protobuf::FileDescriptorProto fd; + EXPECT_THAT(GetMinimalDescriptorDatabase()->FindFileContainingSymbol( + "google.protobuf.DoubleValue", &fd), + IsTrue()); +} + +TEST(GetMinimalDescriptorDatabase, BytesValue) { + google::protobuf::FileDescriptorProto fd; + EXPECT_THAT(GetMinimalDescriptorDatabase()->FindFileContainingSymbol( + "google.protobuf.BytesValue", &fd), + IsTrue()); +} + +TEST(GetMinimalDescriptorDatabase, StringValue) { + google::protobuf::FileDescriptorProto fd; + EXPECT_THAT(GetMinimalDescriptorDatabase()->FindFileContainingSymbol( + "google.protobuf.StringValue", &fd), + IsTrue()); +} + +TEST(GetMinimalDescriptorDatabase, Any) { + google::protobuf::FileDescriptorProto fd; + EXPECT_THAT(GetMinimalDescriptorDatabase()->FindFileContainingSymbol( + "google.protobuf.Any", &fd), + IsTrue()); +} + +TEST(GetMinimalDescriptorDatabase, Duration) { + google::protobuf::FileDescriptorProto fd; + EXPECT_THAT(GetMinimalDescriptorDatabase()->FindFileContainingSymbol( + "google.protobuf.Duration", &fd), + IsTrue()); +} + +TEST(GetMinimalDescriptorDatabase, Timestamp) { + google::protobuf::FileDescriptorProto fd; + EXPECT_THAT(GetMinimalDescriptorDatabase()->FindFileContainingSymbol( + "google.protobuf.Timestamp", &fd), + IsTrue()); +} + +TEST(GetMinimalDescriptorDatabase, Value) { + google::protobuf::FileDescriptorProto fd; + EXPECT_THAT(GetMinimalDescriptorDatabase()->FindFileContainingSymbol( + "google.protobuf.Value", &fd), + IsTrue()); +} + +TEST(GetMinimalDescriptorDatabase, ListValue) { + google::protobuf::FileDescriptorProto fd; + EXPECT_THAT(GetMinimalDescriptorDatabase()->FindFileContainingSymbol( + "google.protobuf.ListValue", &fd), + IsTrue()); +} + +TEST(GetMinimalDescriptorDatabase, Struct) { + google::protobuf::FileDescriptorProto fd; + EXPECT_THAT(GetMinimalDescriptorDatabase()->FindFileContainingSymbol( + "google.protobuf.Struct", &fd), + IsTrue()); +} + +} // namespace +} // namespace cel diff --git a/common/minimal_descriptor_pool.cc b/common/minimal_descriptor_pool.cc new file mode 100644 index 000000000..e52614acb --- /dev/null +++ b/common/minimal_descriptor_pool.cc @@ -0,0 +1,34 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/minimal_descriptor_pool.h" + +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "internal/minimal_descriptor_pool.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +const google::protobuf::DescriptorPool* absl_nonnull GetMinimalDescriptorPool() { + return internal::GetMinimalDescriptorPool(); +} + +// If required, adds the minimally required descriptors to the pool. +absl::Status AddMinimumRequiredDescriptorsToPool( + google::protobuf::DescriptorPool* absl_nonnull pool) { + return internal::AddMinimumRequiredDescriptorsToPool(pool); +} + +} // namespace cel diff --git a/common/minimal_descriptor_pool.h b/common/minimal_descriptor_pool.h new file mode 100644 index 000000000..e1582f36a --- /dev/null +++ b/common/minimal_descriptor_pool.h @@ -0,0 +1,36 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_MINIMAL_DESCRIPTOR_POOL_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_MINIMAL_DESCRIPTOR_POOL_H_ + +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +// GetMinimalDescriptorPool returns a pointer to a `google::protobuf::DescriptorPool` +// which includes has the minimally necessary descriptors required by the Common +// Expression Language. The returned `google::protobuf::DescriptorPool` is valid for the +// lifetime of the process and should not be deleted. +const google::protobuf::DescriptorPool* absl_nonnull GetMinimalDescriptorPool(); + +// If required, adds the minimally required descriptors to the pool. +absl::Status AddMinimumRequiredDescriptorsToPool( + google::protobuf::DescriptorPool* absl_nonnull pool); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_MINIMAL_DESCRIPTOR_POOL_H_ diff --git a/common/minimal_descriptor_pool_test.cc b/common/minimal_descriptor_pool_test.cc new file mode 100644 index 000000000..c8932505e --- /dev/null +++ b/common/minimal_descriptor_pool_test.cc @@ -0,0 +1,184 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/minimal_descriptor_pool.h" + +#include "absl/status/status_matchers.h" +#include "internal/testing.h" +#include "google/protobuf/descriptor.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::testing::NotNull; + +TEST(GetMinimalDescriptorPool, NullValue) { + ASSERT_THAT(GetMinimalDescriptorPool()->FindEnumTypeByName( + "google.protobuf.NullValue"), + NotNull()); +} + +TEST(GetMinimalDescriptorPool, BoolValue) { + const auto* desc = GetMinimalDescriptorPool()->FindMessageTypeByName( + "google.protobuf.BoolValue"); + ASSERT_THAT(desc, NotNull()); + EXPECT_EQ(desc->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_BOOLVALUE); +} + +TEST(GetMinimalDescriptorPool, Int32Value) { + const auto* desc = GetMinimalDescriptorPool()->FindMessageTypeByName( + "google.protobuf.Int32Value"); + ASSERT_THAT(desc, NotNull()); + EXPECT_EQ(desc->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_INT32VALUE); +} + +TEST(GetMinimalDescriptorPool, Int64Value) { + const auto* desc = GetMinimalDescriptorPool()->FindMessageTypeByName( + "google.protobuf.Int64Value"); + ASSERT_THAT(desc, NotNull()); + EXPECT_EQ(desc->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_INT64VALUE); +} + +TEST(GetMinimalDescriptorPool, UInt32Value) { + const auto* desc = GetMinimalDescriptorPool()->FindMessageTypeByName( + "google.protobuf.UInt32Value"); + ASSERT_THAT(desc, NotNull()); + EXPECT_EQ(desc->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_UINT32VALUE); +} + +TEST(GetMinimalDescriptorPool, UInt64Value) { + const auto* desc = GetMinimalDescriptorPool()->FindMessageTypeByName( + "google.protobuf.UInt64Value"); + ASSERT_THAT(desc, NotNull()); + EXPECT_EQ(desc->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_UINT64VALUE); +} + +TEST(GetMinimalDescriptorPool, FloatValue) { + const auto* desc = GetMinimalDescriptorPool()->FindMessageTypeByName( + "google.protobuf.FloatValue"); + ASSERT_THAT(desc, NotNull()); + EXPECT_EQ(desc->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_FLOATVALUE); +} + +TEST(GetMinimalDescriptorPool, DoubleValue) { + const auto* desc = GetMinimalDescriptorPool()->FindMessageTypeByName( + "google.protobuf.DoubleValue"); + ASSERT_THAT(desc, NotNull()); + EXPECT_EQ(desc->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_DOUBLEVALUE); +} + +TEST(GetMinimalDescriptorPool, BytesValue) { + const auto* desc = GetMinimalDescriptorPool()->FindMessageTypeByName( + "google.protobuf.BytesValue"); + ASSERT_THAT(desc, NotNull()); + EXPECT_EQ(desc->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_BYTESVALUE); +} + +TEST(GetMinimalDescriptorPool, StringValue) { + const auto* desc = GetMinimalDescriptorPool()->FindMessageTypeByName( + "google.protobuf.StringValue"); + ASSERT_THAT(desc, NotNull()); + EXPECT_EQ(desc->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_STRINGVALUE); +} + +TEST(GetMinimalDescriptorPool, Any) { + const auto* desc = + GetMinimalDescriptorPool()->FindMessageTypeByName("google.protobuf.Any"); + ASSERT_THAT(desc, NotNull()); + EXPECT_EQ(desc->well_known_type(), google::protobuf::Descriptor::WELLKNOWNTYPE_ANY); +} + +TEST(GetMinimalDescriptorPool, Duration) { + const auto* desc = GetMinimalDescriptorPool()->FindMessageTypeByName( + "google.protobuf.Duration"); + ASSERT_THAT(desc, NotNull()); + EXPECT_EQ(desc->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_DURATION); +} + +TEST(GetMinimalDescriptorPool, Timestamp) { + const auto* desc = GetMinimalDescriptorPool()->FindMessageTypeByName( + "google.protobuf.Timestamp"); + ASSERT_THAT(desc, NotNull()); + EXPECT_EQ(desc->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_TIMESTAMP); +} + +TEST(GetMinimalDescriptorPool, Value) { + const auto* desc = GetMinimalDescriptorPool()->FindMessageTypeByName( + "google.protobuf.Value"); + ASSERT_THAT(desc, NotNull()); + EXPECT_EQ(desc->well_known_type(), google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE); +} + +TEST(GetMinimalDescriptorPool, ListValue) { + const auto* desc = GetMinimalDescriptorPool()->FindMessageTypeByName( + "google.protobuf.ListValue"); + ASSERT_THAT(desc, NotNull()); + EXPECT_EQ(desc->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE); +} + +TEST(GetMinimalDescriptorPool, Struct) { + const auto* desc = GetMinimalDescriptorPool()->FindMessageTypeByName( + "google.protobuf.Struct"); + ASSERT_THAT(desc, NotNull()); + EXPECT_EQ(desc->well_known_type(), google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT); +} + +TEST(AddMinimumRequiredDescriptorsToPool, Adds) { + google::protobuf::DescriptorPool pool; + ASSERT_THAT(AddMinimumRequiredDescriptorsToPool(&pool), IsOk()); + EXPECT_THAT(pool.FindEnumTypeByName("google.protobuf.NullValue"), NotNull()); + EXPECT_THAT(pool.FindMessageTypeByName("google.protobuf.BoolValue"), + NotNull()); + EXPECT_THAT(pool.FindMessageTypeByName("google.protobuf.Int32Value"), + NotNull()); + EXPECT_THAT(pool.FindMessageTypeByName("google.protobuf.Int64Value"), + NotNull()); + EXPECT_THAT(pool.FindMessageTypeByName("google.protobuf.UInt32Value"), + NotNull()); + EXPECT_THAT(pool.FindMessageTypeByName("google.protobuf.UInt64Value"), + NotNull()); + EXPECT_THAT(pool.FindMessageTypeByName("google.protobuf.FloatValue"), + NotNull()); + EXPECT_THAT(pool.FindMessageTypeByName("google.protobuf.DoubleValue"), + NotNull()); + EXPECT_THAT(pool.FindMessageTypeByName("google.protobuf.BytesValue"), + NotNull()); + EXPECT_THAT(pool.FindMessageTypeByName("google.protobuf.StringValue"), + NotNull()); + EXPECT_THAT(pool.FindMessageTypeByName("google.protobuf.Any"), NotNull()); + EXPECT_THAT(pool.FindMessageTypeByName("google.protobuf.Duration"), + NotNull()); + EXPECT_THAT(pool.FindMessageTypeByName("google.protobuf.Timestamp"), + NotNull()); + EXPECT_THAT(pool.FindMessageTypeByName("google.protobuf.Value"), NotNull()); + EXPECT_THAT(pool.FindMessageTypeByName("google.protobuf.ListValue"), + NotNull()); + EXPECT_THAT(pool.FindMessageTypeByName("google.protobuf.Struct"), NotNull()); +} + +} // namespace +} // namespace cel diff --git a/common/native_type.h b/common/native_type.h new file mode 100644 index 000000000..96c53c1da --- /dev/null +++ b/common/native_type.h @@ -0,0 +1,26 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_NATIVE_TYPE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_NATIVE_TYPE_H_ + +#include "common/typeinfo.h" + +namespace cel { + +using NativeTypeId = TypeInfo; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_NATIVE_TYPE_H_ diff --git a/common/navigable_ast.cc b/common/navigable_ast.cc new file mode 100644 index 000000000..941c37921 --- /dev/null +++ b/common/navigable_ast.cc @@ -0,0 +1,202 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/navigable_ast.h" + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/functional/any_invocable.h" +#include "absl/memory/memory.h" +#include "absl/types/optional.h" +#include "common/ast/navigable_ast_internal.h" +#include "common/ast_traverse.h" +#include "common/ast_visitor.h" +#include "common/ast_visitor_base.h" +#include "common/expr.h" + +namespace cel { + +namespace { + +using NavigableAstNodeData = + common_internal::NavigableAstNodeData; +using NavigableAstMetadata = + common_internal::NavigableAstMetadata; + +NodeKind GetNodeKind(const Expr& expr) { + switch (expr.kind_case()) { + case ExprKindCase::kConstant: + return NodeKind::kConstant; + case ExprKindCase::kIdentExpr: + return NodeKind::kIdent; + case ExprKindCase::kSelectExpr: + return NodeKind::kSelect; + case ExprKindCase::kCallExpr: + return NodeKind::kCall; + case ExprKindCase::kListExpr: + return NodeKind::kList; + case ExprKindCase::kStructExpr: + return NodeKind::kStruct; + case ExprKindCase::kMapExpr: + return NodeKind::kMap; + case ExprKindCase::kComprehensionExpr: + return NodeKind::kComprehension; + case ExprKindCase::kUnspecifiedExpr: + default: + return NodeKind::kUnspecified; + } +} + +// Get the traversal relationship from parent to the given node. +// Note: these depend on the ast_visitor utility's traversal ordering. +ChildKind GetChildKind(const NavigableAstNodeData& parent_node, + size_t child_index, + absl::optional comprehension_arg) { + switch (parent_node.node_kind) { + case NodeKind::kStruct: + return ChildKind::kStructValue; + case NodeKind::kMap: + if (child_index % 2 == 0) { + return ChildKind::kMapKey; + } + return ChildKind::kMapValue; + case NodeKind::kList: + return ChildKind::kListElem; + case NodeKind::kSelect: + return ChildKind::kSelectOperand; + case NodeKind::kCall: + if (child_index == 0 && parent_node.expr->call_expr().has_target()) { + return ChildKind::kCallReceiver; + } + return ChildKind::kCallArg; + case NodeKind::kComprehension: + if (!comprehension_arg.has_value()) { + return ChildKind::kUnspecified; + } + switch (*comprehension_arg) { + case ComprehensionArg::ITER_RANGE: + return ChildKind::kComprehensionRange; + case ComprehensionArg::ACCU_INIT: + return ChildKind::kComprehensionInit; + case ComprehensionArg::LOOP_CONDITION: + return ChildKind::kComprehensionCondition; + case ComprehensionArg::LOOP_STEP: + return ChildKind::kComprehensionLoopStep; + case ComprehensionArg::RESULT: + return ChildKind::kComprensionResult; + default: + return ChildKind::kUnspecified; + } + default: + return ChildKind::kUnspecified; + } +} + +class NavigableExprBuilderVisitor : public cel::AstVisitorBase { + public: + NavigableExprBuilderVisitor( + absl::AnyInvocable()> node_factory, + absl::AnyInvocable + node_data_accessor) + : node_factory_(std::move(node_factory)), + node_data_accessor_(std::move(node_data_accessor)), + metadata_(std::make_unique()) {} + + NavigableAstNodeData& NodeDataAt(size_t index) { + return node_data_accessor_(*metadata_->nodes[index]); + } + + void PreVisitExpr(const Expr& expr) override { + NavigableAstNode* parent = + parent_stack_.empty() ? nullptr + : metadata_->nodes[parent_stack_.back()].get(); + size_t index = metadata_->nodes.size(); + metadata_->nodes.push_back(node_factory_()); + NavigableAstNode* node = metadata_->nodes[index].get(); + auto& node_data = NodeDataAt(index); + node_data.parent = parent; + node_data.expr = &expr; + node_data.parent_relation = ChildKind::kUnspecified; + node_data.node_kind = GetNodeKind(expr); + node_data.tree_size = 1; + node_data.height = 1; + node_data.index = index; + node_data.child_index = -1; + node_data.metadata = metadata_.get(); + + metadata_->id_to_node.insert({expr.id(), node}); + metadata_->expr_to_node.insert({&expr, node}); + if (!parent_stack_.empty()) { + auto& parent_node_data = NodeDataAt(parent_stack_.back()); + size_t child_index = parent_node_data.children.size(); + parent_node_data.children.push_back(node); + node_data.parent_relation = + GetChildKind(parent_node_data, child_index, comprehension_arg_); + node_data.child_index = child_index; + } + parent_stack_.push_back(index); + } + + void PreVisitComprehensionSubexpression( + const Expr& expr, const ComprehensionExpr& comprehension, + ComprehensionArg comprehension_arg) override { + comprehension_arg_ = comprehension_arg; + } + + void PostVisitExpr(const Expr& expr) override { + size_t idx = parent_stack_.back(); + parent_stack_.pop_back(); + metadata_->postorder.push_back(metadata_->nodes[idx].get()); + NavigableAstNodeData& node = NodeDataAt(idx); + if (!parent_stack_.empty()) { + auto& parent_node_data = NodeDataAt(parent_stack_.back()); + parent_node_data.tree_size += node.tree_size; + parent_node_data.height = + std::max(parent_node_data.height, node.height + 1); + } + } + + std::unique_ptr Consume() && { + return std::move(metadata_); + } + + private: + absl::AnyInvocable()> node_factory_; + absl::AnyInvocable + node_data_accessor_; + std::unique_ptr metadata_; + std::vector parent_stack_; + absl::optional comprehension_arg_; +}; + +} // namespace + +NavigableAst NavigableAst::Build(const Expr& expr) { + cel::TraversalOptions opts; + opts.use_comprehension_callbacks = true; + NavigableExprBuilderVisitor visitor( + []() { return absl::WrapUnique(new NavigableAstNode()); }, + [](NavigableAstNode& node) -> NavigableAstNodeData& { + return node.data_; + }); + AstTraverse(expr, visitor, opts); + return NavigableAst(std::move(visitor).Consume()); +} + +} // namespace cel diff --git a/common/navigable_ast.h b/common/navigable_ast.h new file mode 100644 index 000000000..a8c608e24 --- /dev/null +++ b/common/navigable_ast.h @@ -0,0 +1,168 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_NAVIGABLE_AST_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_NAVIGABLE_AST_H_ + +#include "common/ast/navigable_ast_internal.h" +#include "common/ast/navigable_ast_kinds.h" // IWYU pragma: export +#include "common/expr.h" + +namespace cel { + +class NavigableAst; +class NavigableAstNode; + +namespace common_internal { + +struct NativeAstTraits { + using ExprType = Expr; + using AstType = NavigableAst; + using NodeType = NavigableAstNode; +}; + +} // namespace common_internal + +// Wrapper around a CEL AST node that exposes traversal information. +class NavigableAstNode : public common_internal::NavigableAstNodeBase< + common_internal::NativeAstTraits> { + private: + using Base = + common_internal::NavigableAstNodeBase; + + public: + // A const Span like type that provides pre-order traversal for a sub tree. + // provides .begin() and .end() returning bidirectional iterators to + // const AstNode&. + using PreorderRange = Base::PreorderRange; + + // A const Span like type that provides post-order traversal for a sub tree. + // provides .begin() and .end() returning bidirectional iterators to + // const AstNode&. + using PostorderRange = Base::PostorderRange; + + // The parent of this node or nullptr if it is a root. + using Base::parent; + + // The ptr to the backing Expr in the source AST. + // + // This may dangle if the source AST is mutated or destroyed. + using Base::expr; + + // The index of this node in the parent's children. -1 if this is a root. + using Base::child_index; + + // The type of traversal from parent to this node. + using Base::parent_relation; + + // The type of this node, analogous to Expr::ExprKindCase. + using Base::node_kind; + + // The number of nodes in the tree rooted at this node (including self). + using Base::tree_size; + + // The height of this node in the tree (the number of descendants including + // self on the longest path). + using Base::height; + + // The children of this node in their natural order. + using Base::children; + + // Range over the descendants of this node (including self) using preorder + // semantics. Each node is visited immediately before all of its descendants. + // + // example: + // for (const cel::NavigableAstNode& node : + // ast.Root().DescendantsPreorder()) { + // ... + // } + // + // Children are traversed in their natural order: + // - call arguments are traversed in order (receiver if present is first) + // - list elements are traversed in order + // - maps are traversed in order (alternating key, value per entry) + // - comprehensions are traversed in the order: range, accu_init, condition, + // step, result + using Base::DescendantsPreorder; + + // Range over the descendants of this node (including self) using postorder + // semantics. Each node is visited immediately after all of its descendants. + using Base::DescendantsPostorder; + + private: + friend class NavigableAst; + + NavigableAstNode() = default; +}; + +// NavigableExpr provides a view over a CEL AST that allows for generalized +// traversal. The traversal structures are eagerly built on construction, +// requiring a full traversal of the AST. This is intended for use in tools that +// might require random access or multiple passes over the AST, amortizing the +// cost of building the traversal structures. +// +// Pointers to AstNodes are owned by this instance and must not outlive it. +// +// `NavigableAst` and Navigable nodes are independent of the input Expr and may +// outlive it, but may contain dangling pointers if the input Expr is modified +// or destroyed. +class NavigableAst : public common_internal::NavigableAstBase< + common_internal::NativeAstTraits> { + private: + using Base = + common_internal::NavigableAstBase; + + public: + static NavigableAst Build(const Expr& expr); + + // Default constructor creates an empty instance. + // + // Operations other than equality are undefined on an empty instance. + // + // This is intended for composed object construction, a new NavigableAst + // should be obtained from the Build factory function. + NavigableAst() = default; + + // Move only. + NavigableAst(const NavigableAst&) = delete; + NavigableAst& operator=(const NavigableAst&) = delete; + NavigableAst(NavigableAst&&) = default; + NavigableAst& operator=(NavigableAst&&) = default; + + // Return ptr to the AST node with id if present. Otherwise returns nullptr. + // + // If ids are non-unique, the first pre-order node encountered with id is + // returned. + using Base::FindId; + + // Return ptr to the AST node representing the given Expr node. + using Base::FindExpr; + + // Returns the root of the AST. + using Base::Root; + + // Return whether the source AST used unique IDs for each node. + // + // This is typically the case, but older versions of the parsers didn't + // guarantee uniqueness for nodes generated by some macros and ASTs modified + // outside of CEL's parse/type check may not have unique IDs. + using Base::IdsAreUnique; + + private: + using Base::Base; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_NAVIGABLE_AST_H_ diff --git a/common/navigable_ast_test.cc b/common/navigable_ast_test.cc new file mode 100644 index 000000000..2891a105d --- /dev/null +++ b/common/navigable_ast_test.cc @@ -0,0 +1,410 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/navigable_ast.h" + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "common/ast.h" +#include "common/expr.h" +#include "common/source.h" +#include "common/standard_definitions.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "parser/parser.h" + +namespace cel { +namespace { + +using ::testing::ElementsAre; +using ::testing::IsEmpty; +using ::testing::Pair; +using ::testing::SizeIs; + +absl::StatusOr> Parse(absl::string_view expr) { + static const auto* parser = cel::NewParserBuilder()->Build()->release(); + CEL_ASSIGN_OR_RETURN(auto source, cel::NewSource(expr)); + return parser->Parse(*source); +} + +TEST(NavigableAst, Basic) { + Expr const_node; + const_node.set_id(1); + const_node.mutable_const_expr().set_int_value(42); + + NavigableAst ast = NavigableAst::Build(const_node); + EXPECT_TRUE(ast.IdsAreUnique()); + + const NavigableAstNode& root = ast.Root(); + + EXPECT_EQ(root.expr(), &const_node); + EXPECT_THAT(root.children(), IsEmpty()); + EXPECT_TRUE(root.parent() == nullptr); + EXPECT_EQ(root.child_index(), -1); + EXPECT_EQ(root.node_kind(), NodeKind::kConstant); + EXPECT_EQ(root.parent_relation(), ChildKind::kUnspecified); +} + +TEST(NavigableAst, DefaultCtorEmpty) { + Expr const_node; + const_node.set_id(1); + const_node.mutable_const_expr().set_int_value(42); + + NavigableAst ast = NavigableAst::Build(const_node); + EXPECT_EQ(ast, ast); + + NavigableAst empty; + + EXPECT_NE(ast, empty); + EXPECT_EQ(empty, empty); + + EXPECT_TRUE(static_cast(ast)); + EXPECT_FALSE(static_cast(empty)); + + NavigableAst moved = std::move(ast); + EXPECT_EQ(ast, empty); + EXPECT_FALSE(static_cast(ast)); + EXPECT_TRUE(static_cast(moved)); +} + +TEST(NavigableAst, FindById) { + Expr const_node; + const_node.set_id(1); + const_node.mutable_const_expr().set_int_value(42); + + NavigableAst ast = NavigableAst::Build(const_node); + + const NavigableAstNode& root = ast.Root(); + + EXPECT_EQ(ast.FindId(const_node.id()), &root); + EXPECT_EQ(ast.FindId(-1), nullptr); +} + +MATCHER_P(AstNodeWrapping, expr, "") { + const NavigableAstNode* ptr = arg; + return ptr != nullptr && ptr->expr() == expr; +} + +TEST(NavigableAst, ToleratesNonUnique) { + Expr call_node; + call_node.set_id(1); + call_node.mutable_call_expr().set_function(cel::StandardFunctions::kNot); + Expr* const_node = + &call_node.mutable_call_expr().mutable_args().emplace_back(); + const_node->mutable_const_expr().set_bool_value(false); + const_node->set_id(1); + + NavigableAst ast = NavigableAst::Build(call_node); + + const NavigableAstNode& root = ast.Root(); + + EXPECT_EQ(ast.FindId(1), &root); + EXPECT_EQ(ast.FindExpr(&call_node), &root); + EXPECT_FALSE(ast.IdsAreUnique()); + EXPECT_THAT(ast.FindExpr(const_node), AstNodeWrapping(const_node)); +} + +TEST(NavigableAst, FindByExprPtr) { + Expr const_node; + const_node.set_id(1); + const_node.mutable_const_expr().set_int_value(42); + + NavigableAst ast = NavigableAst::Build(const_node); + + const NavigableAstNode& root = ast.Root(); + + Expr other_expr; + + EXPECT_EQ(ast.FindExpr(&const_node), &root); + EXPECT_EQ(ast.FindExpr(&other_expr), nullptr); +} + +TEST(NavigableAst, Children) { + ASSERT_OK_AND_ASSIGN(auto parsed_expr, Parse("1 + 2")); + + NavigableAst ast = NavigableAst::Build(parsed_expr->root_expr()); + const NavigableAstNode& root = ast.Root(); + + EXPECT_EQ(root.expr(), &parsed_expr->root_expr()); + EXPECT_THAT(root.children(), SizeIs(2)); + EXPECT_TRUE(root.parent() == nullptr); + EXPECT_EQ(root.child_index(), -1); + EXPECT_EQ(root.parent_relation(), ChildKind::kUnspecified); + EXPECT_EQ(root.node_kind(), NodeKind::kCall); + + EXPECT_THAT( + root.children(), + ElementsAre( + AstNodeWrapping(&parsed_expr->root_expr().call_expr().args().at(0)), + AstNodeWrapping(&parsed_expr->root_expr().call_expr().args().at(1)))); + + ASSERT_THAT(root.children(), SizeIs(2)); + const auto* child1 = root.children()[0]; + EXPECT_EQ(child1->child_index(), 0); + EXPECT_EQ(child1->parent(), &root); + EXPECT_EQ(child1->parent_relation(), ChildKind::kCallArg); + EXPECT_EQ(child1->node_kind(), NodeKind::kConstant); + EXPECT_THAT(child1->children(), IsEmpty()); + + const auto* child2 = root.children()[1]; + EXPECT_EQ(child2->child_index(), 1); +} + +TEST(NavigableAst, UnspecifiedExpr) { + Expr expr; + expr.set_id(1); + NavigableAst ast = NavigableAst::Build(expr); + const NavigableAstNode& root = ast.Root(); + + EXPECT_EQ(root.expr(), &expr); + EXPECT_THAT(root.children(), SizeIs(0)); + EXPECT_TRUE(root.parent() == nullptr); + EXPECT_EQ(root.child_index(), -1); + EXPECT_EQ(root.node_kind(), NodeKind::kUnspecified); +} + +TEST(NavigableAst, ParentRelationSelect) { + ASSERT_OK_AND_ASSIGN(auto parsed_expr, Parse("a.b")); + + NavigableAst ast = NavigableAst::Build(parsed_expr->root_expr()); + const NavigableAstNode& root = ast.Root(); + + ASSERT_THAT(root.children(), SizeIs(1)); + const auto* child = root.children()[0]; + + EXPECT_EQ(child->parent_relation(), ChildKind::kSelectOperand); + EXPECT_EQ(child->node_kind(), NodeKind::kIdent); +} + +TEST(NavigableAst, ParentRelationCallReceiver) { + ASSERT_OK_AND_ASSIGN(auto parsed_expr, Parse("a.b()")); + + NavigableAst ast = NavigableAst::Build(parsed_expr->root_expr()); + const NavigableAstNode& root = ast.Root(); + + ASSERT_THAT(root.children(), SizeIs(1)); + const auto* child = root.children()[0]; + + EXPECT_EQ(child->parent_relation(), ChildKind::kCallReceiver); + EXPECT_EQ(child->node_kind(), NodeKind::kIdent); +} + +TEST(NavigableAst, ParentRelationCreateStruct) { + ASSERT_OK_AND_ASSIGN(auto parsed_expr, + Parse("com.example.Type{field: '123'}")); + + NavigableAst ast = NavigableAst::Build(parsed_expr->root_expr()); + const NavigableAstNode& root = ast.Root(); + + EXPECT_EQ(root.node_kind(), NodeKind::kStruct); + ASSERT_THAT(root.children(), SizeIs(1)); + const auto* child = root.children()[0]; + + EXPECT_EQ(child->parent_relation(), ChildKind::kStructValue); + EXPECT_EQ(child->node_kind(), NodeKind::kConstant); +} + +TEST(NavigableAst, ParentRelationCreateMap) { + ASSERT_OK_AND_ASSIGN(auto parsed_expr, Parse("{'a': 123}")); + + NavigableAst ast = NavigableAst::Build(parsed_expr->root_expr()); + const NavigableAstNode& root = ast.Root(); + + EXPECT_EQ(root.node_kind(), NodeKind::kMap); + ASSERT_THAT(root.children(), SizeIs(2)); + const auto* key = root.children()[0]; + const auto* value = root.children()[1]; + + EXPECT_EQ(key->parent_relation(), ChildKind::kMapKey); + EXPECT_EQ(key->node_kind(), NodeKind::kConstant); + + EXPECT_EQ(value->parent_relation(), ChildKind::kMapValue); + EXPECT_EQ(value->node_kind(), NodeKind::kConstant); +} + +TEST(NavigableAst, ParentRelationCreateList) { + ASSERT_OK_AND_ASSIGN(auto parsed_expr, Parse("[123]")); + + NavigableAst ast = NavigableAst::Build(parsed_expr->root_expr()); + const NavigableAstNode& root = ast.Root(); + + EXPECT_EQ(root.node_kind(), NodeKind::kList); + ASSERT_THAT(root.children(), SizeIs(1)); + const auto* child = root.children()[0]; + + EXPECT_EQ(child->parent_relation(), ChildKind::kListElem); + EXPECT_EQ(child->node_kind(), NodeKind::kConstant); +} + +TEST(NavigableAst, ParentRelationComprehension) { + ASSERT_OK_AND_ASSIGN(auto parsed_expr, Parse("[1].all(x, x < 2)")); + + NavigableAst ast = NavigableAst::Build(parsed_expr->root_expr()); + const NavigableAstNode& root = ast.Root(); + + EXPECT_EQ(root.node_kind(), NodeKind::kComprehension); + ASSERT_THAT(root.children(), SizeIs(5)); + const auto* range = root.children()[0]; + const auto* init = root.children()[1]; + const auto* condition = root.children()[2]; + const auto* step = root.children()[3]; + const auto* finish = root.children()[4]; + + EXPECT_EQ(range->parent_relation(), ChildKind::kComprehensionRange); + EXPECT_EQ(init->parent_relation(), ChildKind::kComprehensionInit); + EXPECT_EQ(condition->parent_relation(), ChildKind::kComprehensionCondition); + EXPECT_EQ(step->parent_relation(), ChildKind::kComprehensionLoopStep); + EXPECT_EQ(finish->parent_relation(), ChildKind::kComprensionResult); +} + +TEST(NavigableAst, DescendantsPostorder) { + ASSERT_OK_AND_ASSIGN(auto parsed_expr, Parse("1 + (x * 3)")); + + NavigableAst ast = NavigableAst::Build(parsed_expr->root_expr()); + const NavigableAstNode& root = ast.Root(); + + EXPECT_EQ(root.node_kind(), NodeKind::kCall); + + std::vector constants; + std::vector node_kinds; + + for (const NavigableAstNode& node : root.DescendantsPostorder()) { + if (node.node_kind() == NodeKind::kConstant) { + constants.push_back(node.expr()->const_expr().int64_value()); + } + node_kinds.push_back(node.node_kind()); + } + + EXPECT_THAT(node_kinds, ElementsAre(NodeKind::kConstant, NodeKind::kIdent, + NodeKind::kConstant, NodeKind::kCall, + NodeKind::kCall)); + EXPECT_THAT(constants, ElementsAre(1, 3)); +} + +TEST(NavigableAst, DescendantsPreorder) { + ASSERT_OK_AND_ASSIGN(auto parsed_expr, Parse("1 + (x * 3)")); + + NavigableAst ast = NavigableAst::Build(parsed_expr->root_expr()); + const NavigableAstNode& root = ast.Root(); + + EXPECT_EQ(root.node_kind(), NodeKind::kCall); + + std::vector constants; + std::vector node_kinds; + + for (const NavigableAstNode& node : root.DescendantsPreorder()) { + if (node.node_kind() == NodeKind::kConstant) { + constants.push_back(node.expr()->const_expr().int64_value()); + } + node_kinds.push_back(node.node_kind()); + } + + EXPECT_THAT(node_kinds, + ElementsAre(NodeKind::kCall, NodeKind::kConstant, NodeKind::kCall, + NodeKind::kIdent, NodeKind::kConstant)); + EXPECT_THAT(constants, ElementsAre(1, 3)); +} + +TEST(NavigableAst, DescendantsPreorderComprehension) { + ASSERT_OK_AND_ASSIGN(auto parsed_expr, Parse("[1, 2, 3].map(x, x + 1)")); + + NavigableAst ast = NavigableAst::Build(parsed_expr->root_expr()); + const NavigableAstNode& root = ast.Root(); + + EXPECT_EQ(root.node_kind(), NodeKind::kComprehension); + + std::vector> node_kinds; + + for (const NavigableAstNode& node : root.DescendantsPreorder()) { + node_kinds.push_back( + std::make_pair(node.node_kind(), node.parent_relation())); + } + + EXPECT_THAT( + node_kinds, + ElementsAre(Pair(NodeKind::kComprehension, ChildKind::kUnspecified), + Pair(NodeKind::kList, ChildKind::kComprehensionRange), + Pair(NodeKind::kConstant, ChildKind::kListElem), + Pair(NodeKind::kConstant, ChildKind::kListElem), + Pair(NodeKind::kConstant, ChildKind::kListElem), + Pair(NodeKind::kList, ChildKind::kComprehensionInit), + Pair(NodeKind::kConstant, ChildKind::kComprehensionCondition), + Pair(NodeKind::kCall, ChildKind::kComprehensionLoopStep), + Pair(NodeKind::kIdent, ChildKind::kCallArg), + Pair(NodeKind::kList, ChildKind::kCallArg), + Pair(NodeKind::kCall, ChildKind::kListElem), + Pair(NodeKind::kIdent, ChildKind::kCallArg), + Pair(NodeKind::kConstant, ChildKind::kCallArg), + Pair(NodeKind::kIdent, ChildKind::kComprensionResult))); +} + +TEST(NavigableAst, TreeSize) { + ASSERT_OK_AND_ASSIGN(auto parsed_expr, Parse("[1, 2, 3].map(x, x + 1)")); + + NavigableAst ast = NavigableAst::Build(parsed_expr->root_expr()); + const NavigableAstNode& root = ast.Root(); + + EXPECT_EQ(root.node_kind(), NodeKind::kComprehension); + + std::vector> node_kinds; + + EXPECT_EQ(root.tree_size(), 14); + auto it = root.DescendantsPostorder().begin(); + EXPECT_EQ(it->tree_size(), 1); +} + +TEST(NavigableAst, Height) { + ASSERT_OK_AND_ASSIGN(auto parsed_expr, Parse("[1, 2, 3].map(x, x + 1)")); + + NavigableAst ast = NavigableAst::Build(parsed_expr->root_expr()); + const NavigableAstNode& root = ast.Root(); + + EXPECT_EQ(root.node_kind(), NodeKind::kComprehension); + + std::vector> node_kinds; + + EXPECT_EQ(root.height(), 5); + auto it = root.DescendantsPostorder().begin(); + EXPECT_EQ(it->height(), 1); +} + +TEST(NavigableAst, DescendantsPreorderCreateMap) { + ASSERT_OK_AND_ASSIGN(auto parsed_expr, Parse("{'key1': 1, 'key2': 2}")); + + NavigableAst ast = NavigableAst::Build(parsed_expr->root_expr()); + const NavigableAstNode& root = ast.Root(); + + EXPECT_EQ(root.node_kind(), NodeKind::kMap); + + std::vector> node_kinds; + + for (const NavigableAstNode& node : root.DescendantsPreorder()) { + node_kinds.push_back( + std::make_pair(node.node_kind(), node.parent_relation())); + } + + EXPECT_THAT(node_kinds, + ElementsAre(Pair(NodeKind::kMap, ChildKind::kUnspecified), + Pair(NodeKind::kConstant, ChildKind::kMapKey), + Pair(NodeKind::kConstant, ChildKind::kMapValue), + Pair(NodeKind::kConstant, ChildKind::kMapKey), + Pair(NodeKind::kConstant, ChildKind::kMapValue))); +} + +} // namespace +} // namespace cel diff --git a/common/operators.cc b/common/operators.cc index 669469c9a..2e2ab47d3 100644 --- a/common/operators.cc +++ b/common/operators.cc @@ -1,11 +1,28 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #include "common/operators.h" -#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" -namespace google { -namespace api { -namespace expr { -namespace common { +#undef IN + +namespace google::api::expr::common { namespace { // These functions provide access to reverse mappings for operators. @@ -13,127 +30,106 @@ namespace { // e.g., from "&&" to "_&&_". Reverse operators provides a mapping from // Expr to textual mapping, e.g., from "_&&_" to "&&". -const std::map& UnaryOperators() { - static std::shared_ptr> unaries_map = - [&]() { - auto u = std::make_shared>( - std::map{ - {CelOperator::NEGATE, "-"}, {CelOperator::LOGICAL_NOT, "!"}}); - return u; - }(); +const absl::flat_hash_map& UnaryOperators() { + static auto* unaries_map = new absl::flat_hash_map{ + {CelOperator::NEGATE, "-"}, {CelOperator::LOGICAL_NOT, "!"}}; return *unaries_map; } -const std::map& BinaryOperators() { - static std::shared_ptr> binops_map = - [&]() { - auto c = std::make_shared>( - std::map{ - {CelOperator::LOGICAL_OR, "||"}, - {CelOperator::LOGICAL_AND, "&&"}, - {CelOperator::LESS_EQUALS, "<="}, - {CelOperator::LESS, "<"}, - {CelOperator::GREATER_EQUALS, ">="}, - {CelOperator::GREATER, ">"}, - {CelOperator::EQUALS, "=="}, - {CelOperator::NOT_EQUALS, "!="}, - {CelOperator::IN_DEPRECATED, "in"}, - {CelOperator::IN, "in"}, - {CelOperator::ADD, "+"}, - {CelOperator::SUBTRACT, "-"}, - {CelOperator::MULTIPLY, "*"}, - {CelOperator::DIVIDE, "/"}, - {CelOperator::MODULO, "%"}}); - return c; - }(); +const absl::flat_hash_map& BinaryOperators() { + static auto* binops_map = new absl::flat_hash_map{ + {CelOperator::LOGICAL_OR, "||"}, + {CelOperator::LOGICAL_AND, "&&"}, + {CelOperator::LESS_EQUALS, "<="}, + {CelOperator::LESS, "<"}, + {CelOperator::GREATER_EQUALS, ">="}, + {CelOperator::GREATER, ">"}, + {CelOperator::EQUALS, "=="}, + {CelOperator::NOT_EQUALS, "!="}, + {CelOperator::IN_DEPRECATED, "in"}, + {CelOperator::IN, "in"}, + {CelOperator::ADD, "+"}, + {CelOperator::SUBTRACT, "-"}, + {CelOperator::MULTIPLY, "*"}, + {CelOperator::DIVIDE, "/"}, + {CelOperator::MODULO, "%"}}; return *binops_map; } -const std::map& ReverseOperators() { - static std::shared_ptr> operators_map = - [&]() { - auto c = std::make_shared>( - std::map{ - {"+", CelOperator::ADD}, - {"-", CelOperator::SUBTRACT}, - {"*", CelOperator::MULTIPLY}, - {"/", CelOperator::DIVIDE}, - {"%", CelOperator::MODULO}, - {"==", CelOperator::EQUALS}, - {"!=", CelOperator::NOT_EQUALS}, - {">", CelOperator::GREATER}, - {">=", CelOperator::GREATER_EQUALS}, - {"<", CelOperator::LESS}, - {"<=", CelOperator::LESS_EQUALS}, - {"&&", CelOperator::LOGICAL_AND}, - {"!", CelOperator::LOGICAL_NOT}, - {"||", CelOperator::LOGICAL_OR}, - {"in", CelOperator::IN}, - }); - return c; - }(); +const absl::flat_hash_map& ReverseOperators() { + static auto* operators_map = + new absl::flat_hash_map{ + {"+", CelOperator::ADD}, + {"-", CelOperator::SUBTRACT}, + {"*", CelOperator::MULTIPLY}, + {"/", CelOperator::DIVIDE}, + {"%", CelOperator::MODULO}, + {"==", CelOperator::EQUALS}, + {"!=", CelOperator::NOT_EQUALS}, + {">", CelOperator::GREATER}, + {">=", CelOperator::GREATER_EQUALS}, + {"<", CelOperator::LESS}, + {"<=", CelOperator::LESS_EQUALS}, + {"&&", CelOperator::LOGICAL_AND}, + {"!", CelOperator::LOGICAL_NOT}, + {"||", CelOperator::LOGICAL_OR}, + {"in", CelOperator::IN}, + }; return *operators_map; } -const std::map& Operators() { - static std::shared_ptr> operators_map = - [&]() { - auto c = std::make_shared>( - std::map{ - {CelOperator::ADD, "+"}, - {CelOperator::SUBTRACT, "-"}, - {CelOperator::MULTIPLY, "*"}, - {CelOperator::DIVIDE, "/"}, - {CelOperator::MODULO, "%"}, - {CelOperator::EQUALS, "=="}, - {CelOperator::NOT_EQUALS, "!="}, - {CelOperator::GREATER, ">"}, - {CelOperator::GREATER_EQUALS, ">="}, - {CelOperator::LESS, "<"}, - {CelOperator::LESS_EQUALS, "<="}, - {CelOperator::LOGICAL_AND, "&&"}, - {CelOperator::LOGICAL_NOT, "!"}, - {CelOperator::LOGICAL_OR, "||"}, - {CelOperator::IN, "in"}, - {CelOperator::IN_DEPRECATED, "in"}, - {CelOperator::NEGATE, "-"}}); - return c; - }(); +const absl::flat_hash_map& Operators() { + static auto* operators_map = + new absl::flat_hash_map{ + {CelOperator::ADD, "+"}, + {CelOperator::SUBTRACT, "-"}, + {CelOperator::MULTIPLY, "*"}, + {CelOperator::DIVIDE, "/"}, + {CelOperator::MODULO, "%"}, + {CelOperator::EQUALS, "=="}, + {CelOperator::NOT_EQUALS, "!="}, + {CelOperator::GREATER, ">"}, + {CelOperator::GREATER_EQUALS, ">="}, + {CelOperator::LESS, "<"}, + {CelOperator::LESS_EQUALS, "<="}, + {CelOperator::LOGICAL_AND, "&&"}, + {CelOperator::LOGICAL_NOT, "!"}, + {CelOperator::LOGICAL_OR, "||"}, + {CelOperator::IN, "in"}, + {CelOperator::IN_DEPRECATED, "in"}, + {CelOperator::NEGATE, "-"}}; return *operators_map; } // precedence of the operator, where the higher value means higher. -const std::map& Precedences() { - static std::shared_ptr> precedence_map = [&]() { - auto c = std::make_shared>( - std::map{{CelOperator::CONDITIONAL, 8}, +const absl::flat_hash_map& Precedences() { + static auto* precedence_map = new absl::flat_hash_map{ + {CelOperator::CONDITIONAL, 8}, - {CelOperator::LOGICAL_OR, 7}, + {CelOperator::LOGICAL_OR, 7}, - {CelOperator::LOGICAL_AND, 6}, + {CelOperator::LOGICAL_AND, 6}, - {CelOperator::EQUALS, 5}, - {CelOperator::GREATER, 5}, - {CelOperator::GREATER_EQUALS, 5}, - {CelOperator::IN, 5}, - {CelOperator::LESS, 5}, - {CelOperator::LESS_EQUALS, 5}, - {CelOperator::NOT_EQUALS, 5}, - {CelOperator::IN_DEPRECATED, 5}, + {CelOperator::EQUALS, 5}, + {CelOperator::GREATER, 5}, + {CelOperator::GREATER_EQUALS, 5}, + {CelOperator::IN, 5}, + {CelOperator::LESS, 5}, + {CelOperator::LESS_EQUALS, 5}, + {CelOperator::NOT_EQUALS, 5}, + {CelOperator::IN_DEPRECATED, 5}, - {CelOperator::ADD, 4}, - {CelOperator::SUBTRACT, 4}, + {CelOperator::ADD, 4}, + {CelOperator::SUBTRACT, 4}, - {CelOperator::DIVIDE, 3}, - {CelOperator::MODULO, 3}, - {CelOperator::MULTIPLY, 3}, + {CelOperator::DIVIDE, 3}, + {CelOperator::MODULO, 3}, + {CelOperator::MULTIPLY, 3}, - {CelOperator::LOGICAL_NOT, 2}, - {CelOperator::NEGATE, 2}, + {CelOperator::LOGICAL_NOT, 2}, + {CelOperator::NEGATE, 2}, - {CelOperator::INDEX, 1}}); - return c; - }(); + {CelOperator::INDEX, 1}}; return *precedence_map; } @@ -166,8 +162,11 @@ const char* CelOperator::FILTER = "filter"; const char* CelOperator::NOT_STRICTLY_FALSE = "@not_strictly_false"; const char* CelOperator::IN = "@in"; -int LookupPrecedence(const std::string& op) { - auto precs = Precedences(); +const absl::string_view CelOperator::OPT_INDEX = "_[?_]"; +const absl::string_view CelOperator::OPT_SELECT = "_?._"; + +int LookupPrecedence(absl::string_view op) { + const auto& precs = Precedences(); auto p = precs.find(op); if (p != precs.end()) { return p->second; @@ -175,8 +174,8 @@ int LookupPrecedence(const std::string& op) { return 0; } -absl::optional LookupUnaryOperator(const std::string& op) { - auto unary_ops = UnaryOperators(); +absl::optional LookupUnaryOperator(absl::string_view op) { + const auto& unary_ops = UnaryOperators(); auto o = unary_ops.find(op); if (o == unary_ops.end()) { return absl::optional(); @@ -184,8 +183,8 @@ absl::optional LookupUnaryOperator(const std::string& op) { return o->second; } -absl::optional LookupBinaryOperator(const std::string& op) { - auto bin_ops = BinaryOperators(); +absl::optional LookupBinaryOperator(absl::string_view op) { + const auto& bin_ops = BinaryOperators(); auto o = bin_ops.find(op); if (o == bin_ops.end()) { return absl::optional(); @@ -193,8 +192,8 @@ absl::optional LookupBinaryOperator(const std::string& op) { return o->second; } -absl::optional LookupOperator(const std::string& op) { - auto ops = Operators(); +absl::optional LookupOperator(absl::string_view op) { + const auto& ops = Operators(); auto o = ops.find(op); if (o == ops.end()) { return absl::optional(); @@ -202,8 +201,8 @@ absl::optional LookupOperator(const std::string& op) { return o->second; } -absl::optional ReverseLookupOperator(const std::string& op) { - auto rev_ops = ReverseOperators(); +absl::optional ReverseLookupOperator(absl::string_view op) { + const auto& rev_ops = ReverseOperators(); auto o = rev_ops.find(op); if (o == rev_ops.end()) { return absl::optional(); @@ -211,27 +210,24 @@ absl::optional ReverseLookupOperator(const std::string& op) { return o->second; } -bool IsOperatorSamePrecedence(const std::string& op, - const google::api::expr::v1alpha1::Expr& expr) { +bool IsOperatorSamePrecedence(absl::string_view op, + const cel::expr::Expr& expr) { if (!expr.has_call_expr()) { return false; } return LookupPrecedence(op) == LookupPrecedence(expr.call_expr().function()); } -bool IsOperatorLowerPrecedence(const std::string& op, - const google::api::expr::v1alpha1::Expr& expr) { +bool IsOperatorLowerPrecedence(absl::string_view op, + const cel::expr::Expr& expr) { if (!expr.has_call_expr()) { return false; } return LookupPrecedence(op) < LookupPrecedence(expr.call_expr().function()); } -bool IsOperatorLeftRecursive(const std::string& op) { +bool IsOperatorLeftRecursive(absl::string_view op) { return op != CelOperator::LOGICAL_AND && op != CelOperator::LOGICAL_OR; } -} // namespace common -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::common diff --git a/common/operators.h b/common/operators.h index d005a1582..5d7a775b0 100644 --- a/common/operators.h +++ b/common/operators.h @@ -1,17 +1,28 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #ifndef THIRD_PARTY_CEL_CPP_COMMON_OPERATORS_H_ #define THIRD_PARTY_CEL_CPP_COMMON_OPERATORS_H_ #include #include -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" -namespace google { -namespace api { -namespace expr { -namespace common { +namespace google::api::expr::common { // Operator function names. struct CelOperator { @@ -43,31 +54,34 @@ struct CelOperator { // Named operators, must not have be valid identifiers. static const char* NOT_STRICTLY_FALSE; +#pragma push_macro("IN") +#undef IN static const char* IN; +#pragma pop_macro("IN") + + static const absl::string_view OPT_INDEX; + static const absl::string_view OPT_SELECT; }; // These give access to all or some specific precedence value. // Higher value means higher precedence, 0 means no precedence, i.e., // custom function and not builtin operator. -int LookupPrecedence(const std::string& op); +int LookupPrecedence(absl::string_view op); -absl::optional LookupUnaryOperator(const std::string& op); -absl::optional LookupBinaryOperator(const std::string& op); -absl::optional LookupOperator(const std::string& op); -absl::optional ReverseLookupOperator(const std::string& op); +absl::optional LookupUnaryOperator(absl::string_view op); +absl::optional LookupBinaryOperator(absl::string_view op); +absl::optional LookupOperator(absl::string_view op); +absl::optional ReverseLookupOperator(absl::string_view op); // returns true if op has a lower precedence than the one expressed in expr -bool IsOperatorLowerPrecedence(const std::string& op, - const google::api::expr::v1alpha1::Expr& expr); +bool IsOperatorLowerPrecedence(absl::string_view op, + const cel::expr::Expr& expr); // returns true if op has the same precedence as the one expressed in expr -bool IsOperatorSamePrecedence(const std::string& op, - const google::api::expr::v1alpha1::Expr& expr); +bool IsOperatorSamePrecedence(absl::string_view op, + const cel::expr::Expr& expr); // return true if operator is left recursive, i.e., neither && nor ||. -bool IsOperatorLeftRecursive(const std::string& op); +bool IsOperatorLeftRecursive(absl::string_view op); -} // namespace common -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::common #endif // THIRD_PARTY_CEL_CPP_COMMON_OPERATORS_H_ diff --git a/common/optional_ref.h b/common/optional_ref.h new file mode 100644 index 000000000..c7ba580fc --- /dev/null +++ b/common/optional_ref.h @@ -0,0 +1,163 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_OPTIONAL_REF_H_ +#define THIRD_PARTY_CEL_CPP_OPTIONAL_REF_H_ + +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/macros.h" +#include "absl/base/nullability.h" +#include "absl/base/optimization.h" +#include "absl/types/optional.h" +#include "absl/utility/utility.h" + +namespace cel { + +// `optional_ref` looks and feels like `absl::optional`, but instead of +// owning the underlying value, it retains a reference to the value it accepts +// in its constructor. +template +class optional_ref final { + public: + static_assert(!std::is_reference_v, "T must not be a reference."); + static_assert(!std::is_same_v>, + "optional_ref is not allowed."); + static_assert(!std::is_same_v>, + "optional_ref is not allowed."); + + using value_type = T; + + optional_ref() = default; + + // NOLINTNEXTLINE(google-explicit-constructor) + constexpr optional_ref(absl::nullopt_t) : optional_ref() {} + + // NOLINTNEXTLINE(google-explicit-constructor) + constexpr optional_ref(T& value ABSL_ATTRIBUTE_LIFETIME_BOUND) + : value_(std::addressof(value)) {} + + template < + typename U, + typename = std::enable_if_t, std::is_same, std::decay_t>>>> + // NOLINTNEXTLINE(google-explicit-constructor) + constexpr optional_ref( + const absl::optional& value ABSL_ATTRIBUTE_LIFETIME_BOUND) + : value_(value.has_value() ? std::addressof(*value) : nullptr) {} + + template , std::decay_t>>> + // NOLINTNEXTLINE(google-explicit-constructor) + constexpr optional_ref(absl::optional& value ABSL_ATTRIBUTE_LIFETIME_BOUND) + : value_(value.has_value() ? std::addressof(*value) : nullptr) {} + + template < + typename U, + typename = std::enable_if_t>, + std::is_convertible, std::add_pointer_t>>>> + // NOLINTNEXTLINE(google-explicit-constructor) + constexpr optional_ref(const optional_ref& other) : value_(other.value_) {} + + optional_ref(const optional_ref&) = default; + + optional_ref& operator=(const optional_ref&) = delete; + + constexpr bool has_value() const { return value_ != nullptr; } + + constexpr explicit operator bool() const { return has_value(); } + + constexpr T& value() const { + return ABSL_PREDICT_TRUE(has_value()) + ? *value_ + // Replicate the same error logic as in `absl::optional`'s + // `value()`. It either throws an exception or aborts the + // program. We intentionally ignore the return value of + // the constructed optional's value as we only need to run + // the code for error checking. + : ((void)absl::optional().value(), *value_); + } + + constexpr T& operator*() const { + ABSL_ASSERT(has_value()); + return *value_; + } + + constexpr T* absl_nonnull operator->() const { + ABSL_ASSERT(has_value()); + return value_; + } + + private: + template + friend class optional_ref; + + T* const value_ = nullptr; +}; + +template +optional_ref(const T&) -> optional_ref; + +template +optional_ref(T&) -> optional_ref; + +template +optional_ref(const absl::optional&) -> optional_ref; + +template +optional_ref(absl::optional&) -> optional_ref; + +template +constexpr bool operator==(const optional_ref& lhs, absl::nullopt_t) { + return !lhs.has_value(); +} + +template +constexpr bool operator==(absl::nullopt_t, const optional_ref& rhs) { + return !rhs.has_value(); +} + +template +constexpr bool operator!=(const optional_ref& lhs, absl::nullopt_t) { + return !operator==(lhs, absl::nullopt); +} + +template +constexpr bool operator!=(absl::nullopt_t, const optional_ref& rhs) { + return !operator==(absl::nullopt, rhs); +} + +namespace common_internal { + +template +absl::optional> AsOptional(optional_ref ref) { + if (ref) { + return *ref; + } + return absl::nullopt; +} + +template +absl::optional AsOptional(absl::optional opt) { + return opt; +} + +} // namespace common_internal + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_OPTIONAL_REF_H_ diff --git a/common/parent_ref.cc b/common/parent_ref.cc deleted file mode 100644 index d2b9f86c2..000000000 --- a/common/parent_ref.cc +++ /dev/null @@ -1,23 +0,0 @@ -#include "common/parent_ref.h" - -namespace google { -namespace api { -namespace expr { -namespace common { - -absl::optional SharedValue::SelfRefProvider() const { - if (!owns_value()) { - // No reference needed. - return RefProvider(nullptr); - } - if (unowned()) { - // Not shareable. - return absl::nullopt; - } - return RefProvider(this); -} - -} // namespace common -} // namespace expr -} // namespace api -} // namespace google diff --git a/common/parent_ref.h b/common/parent_ref.h deleted file mode 100644 index 6ca86aa3a..000000000 --- a/common/parent_ref.h +++ /dev/null @@ -1,140 +0,0 @@ -// Helper classes for creating 'view' values using parent references. -// -// Can class that inherits from 'SharedValue' can be used as the parent of -// of another value. Note that `SharedValue` should typically not be inherited -// from directly, instead inherit from `List`, `Map`, or `Object`. -// -// Shared values support: -// - Down-casting to a specific implementation through `Shared::cast_if`. -// - Creating self-references that can be used to create 'views' of the data -// owned by the shared value. For example, a 'list of strings' can -// create element `Value` instances that reference (vs copy) the underlying -// string. - -#ifndef THIRD_PARTY_CEL_CPP_COMMON_PARENT_REF_H_ -#define THIRD_PARTY_CEL_CPP_COMMON_PARENT_REF_H_ - -#include "absl/types/optional.h" -#include "internal/ref_countable.h" - -namespace google { -namespace api { -namespace expr { -namespace common { - -class SharedValue; - -/** - * An opaque reference to a value that prevents the value from being deleted. - * - * Used to support 'views' by allowing the view prevent a parent value from - * being deleted. - * - * Only constructable via `RefProvider::GetRef`. - */ -class ValueRef { - public: - // Value semantics. - ValueRef() = default; - ValueRef(const ValueRef& other) = default; - ValueRef(ValueRef&& other) = default; - ValueRef& operator=(const ValueRef& other) = default; - ValueRef& operator=(ValueRef&& other) = default; - - operator bool() const { return ptr_ != nullptr; } - - private: - friend class RefProvider; - - internal::ReffedPtr ptr_; - - explicit ValueRef(const SharedValue* ptr) : ptr_(ptr) {} -}; - -/** - * A class that can create references for a shared value. - * - * Only constructable via `SharedValue::SelfRefProvider()` - */ -class RefProvider { - public: - RefProvider() = default; - - RefProvider(const RefProvider&) = default; - RefProvider(RefProvider&&) = default; - RefProvider& operator=(const RefProvider&) = default; - RefProvider& operator=(RefProvider&&) = default; - - /** - * Returns true if a reference is required for the given parent. - * - * False when the parent does not own its value, thus ownership need not be - * tracked. - */ - bool RequiresReference() const { return ptr_ != nullptr; } - - ValueRef GetRef() const { return ValueRef(ptr_); } - - private: - friend class SharedValue; - friend constexpr RefProvider NoParent(); - - constexpr explicit RefProvider(const SharedValue* ptr) : ptr_(ptr) {} - - const SharedValue* ptr_; -}; - -constexpr inline RefProvider NoParent() { return RefProvider(nullptr); } - -// The type by which a parent reference provider should be passed around as. -// -// A value of absl::nullopt indicates that the parent cannot be referenced. -using ParentRef = absl::optional; - -/** The base class for custom value implementations. */ -class SharedValue : public internal::RefCountable { - public: - virtual ~SharedValue() {} - - virtual bool owns_value() const = 0; - - /** - * Returns a canonical cel expression for the value. - * - * Computation may be expensive. - */ - virtual std::string ToString() const = 0; - - /** - * Attempts to cast the given Container to the given type. - * - * Returns nullptr if value is null or is not the requested type. - */ - template - static const T* cast_if(const SharedValue* value) { - return dynamic_cast(value); - } - - protected: - SharedValue() = default; - - /** - * Construct a self reference provider, to be passed to a 'view' Value. - * - * `this` must live longer than the returned value. - * - * Returns absl::nullopt if the container cannot be reffed, in which case - * all needed data must be copied. - * - * No synchronization is performed until ParentRef->GetRef() is called, so - * calls to this function do not incur a synchronization performance penalty. - */ - ParentRef SelfRefProvider() const; -}; - -} // namespace common -} // namespace expr -} // namespace api -} // namespace google - -#endif // THIRD_PARTY_CEL_CPP_COMMON_PARENT_REF_H_ diff --git a/common/reference.cc b/common/reference.cc new file mode 100644 index 000000000..75cc36e80 --- /dev/null +++ b/common/reference.cc @@ -0,0 +1,31 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/reference.h" + +#include "absl/base/no_destructor.h" + +namespace cel { + +const VariableReference& VariableReference::default_instance() { + static const absl::NoDestructor instance; + return *instance; +} + +const FunctionReference& FunctionReference::default_instance() { + static const absl::NoDestructor instance; + return *instance; +} + +} // namespace cel diff --git a/common/reference.h b/common/reference.h new file mode 100644 index 000000000..5a8ac9706 --- /dev/null +++ b/common/reference.h @@ -0,0 +1,269 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_REFERENCE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_REFERENCE_H_ + +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/attributes.h" +#include "absl/strings/string_view.h" +#include "absl/types/variant.h" +#include "common/constant.h" + +namespace cel { + +class Reference; +class VariableReference; +class FunctionReference; + +using ReferenceKind = absl::variant; + +// `VariableReference` is a resolved reference to a `VariableDecl`. +class VariableReference final { + public: + bool has_value() const { return value_.has_value(); } + + void set_value(Constant value) { value_ = std::move(value); } + + const Constant& value() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return value_; } + + Constant& mutable_value() ABSL_ATTRIBUTE_LIFETIME_BOUND { return value_; } + + ABSL_MUST_USE_RESULT Constant release_value() { + using std::swap; + Constant value; + swap(mutable_value(), value); + return value; + } + + friend void swap(VariableReference& lhs, VariableReference& rhs) noexcept { + using std::swap; + swap(lhs.value_, rhs.value_); + } + + private: + friend class Reference; + + static const VariableReference& default_instance(); + + Constant value_; +}; + +inline bool operator==(const VariableReference& lhs, + const VariableReference& rhs) { + return lhs.value() == rhs.value(); +} + +inline bool operator!=(const VariableReference& lhs, + const VariableReference& rhs) { + return !operator==(lhs, rhs); +} + +// `FunctionReference` is a resolved reference to a `FunctionDecl`. +class FunctionReference final { + public: + const std::vector& overloads() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return overloads_; + } + + void set_overloads(std::vector overloads) { + mutable_overloads() = std::move(overloads); + } + + std::vector& mutable_overloads() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return overloads_; + } + + ABSL_MUST_USE_RESULT std::vector release_overloads() { + std::vector overloads; + overloads.swap(mutable_overloads()); + return overloads; + } + + friend void swap(FunctionReference& lhs, FunctionReference& rhs) noexcept { + using std::swap; + swap(lhs.overloads_, rhs.overloads_); + } + + private: + friend class Reference; + + static const FunctionReference& default_instance(); + + std::vector overloads_; +}; + +inline bool operator==(const FunctionReference& lhs, + const FunctionReference& rhs) { + return absl::c_equal(lhs.overloads(), rhs.overloads()); +} + +inline bool operator!=(const FunctionReference& lhs, + const FunctionReference& rhs) { + return !operator==(lhs, rhs); +} + +// `Reference` is a resolved reference to a `VariableDecl` or `FunctionDecl`. By +// default `Reference` is a `VariableReference`. +class Reference final { + public: + const std::string& name() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return name_; + } + + void set_name(std::string name) { name_ = std::move(name); } + + void set_name(absl::string_view name) { + name_.assign(name.data(), name.size()); + } + + void set_name(const char* name) { set_name(absl::NullSafeStringView(name)); } + + ABSL_MUST_USE_RESULT std::string release_name() { + std::string name; + name.swap(name_); + return name; + } + + void set_kind(ReferenceKind kind) { kind_ = std::move(kind); } + + const ReferenceKind& kind() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return kind_; + } + + ReferenceKind& mutable_kind() ABSL_ATTRIBUTE_LIFETIME_BOUND { return kind_; } + + ABSL_MUST_USE_RESULT ReferenceKind release_kind() { + using std::swap; + ReferenceKind kind; + swap(kind, kind_); + return kind; + } + + ABSL_MUST_USE_RESULT bool has_variable() const { + return absl::holds_alternative(kind()); + } + + ABSL_MUST_USE_RESULT const VariableReference& variable() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + if (const auto* alt = absl::get_if(&kind()); alt) { + return *alt; + } + return VariableReference::default_instance(); + } + + void set_variable(VariableReference variable) { + mutable_variable() = std::move(variable); + } + + VariableReference& mutable_variable() ABSL_ATTRIBUTE_LIFETIME_BOUND { + if (!has_variable()) { + mutable_kind().emplace(); + } + return absl::get(mutable_kind()); + } + + ABSL_MUST_USE_RESULT VariableReference release_variable() { + VariableReference variable_reference; + if (auto* alt = absl::get_if(&mutable_kind()); alt) { + variable_reference = std::move(*alt); + } + mutable_kind().emplace(); + return variable_reference; + } + + ABSL_MUST_USE_RESULT bool has_function() const { + return absl::holds_alternative(kind()); + } + + ABSL_MUST_USE_RESULT const FunctionReference& function() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + if (const auto* alt = absl::get_if(&kind()); alt) { + return *alt; + } + return FunctionReference::default_instance(); + } + + void set_function(FunctionReference function) { + mutable_function() = std::move(function); + } + + FunctionReference& mutable_function() ABSL_ATTRIBUTE_LIFETIME_BOUND { + if (!has_function()) { + mutable_kind().emplace(); + } + return absl::get(mutable_kind()); + } + + ABSL_MUST_USE_RESULT FunctionReference release_function() { + FunctionReference function_reference; + if (auto* alt = absl::get_if(&mutable_kind()); alt) { + function_reference = std::move(*alt); + } + mutable_kind().emplace(); + return function_reference; + } + + friend void swap(Reference& lhs, Reference& rhs) noexcept { + using std::swap; + swap(lhs.name_, rhs.name_); + swap(lhs.kind_, rhs.kind_); + } + + private: + std::string name_; + ReferenceKind kind_; +}; + +inline bool operator==(const Reference& lhs, const Reference& rhs) { + return lhs.name() == rhs.name() && lhs.kind() == rhs.kind(); +} + +inline bool operator!=(const Reference& lhs, const Reference& rhs) { + return !operator==(lhs, rhs); +} + +inline Reference MakeVariableReference(std::string name) { + Reference reference; + reference.set_name(std::move(name)); + reference.mutable_kind().emplace(); + return reference; +} + +inline Reference MakeConstantVariableReference(std::string name, + Constant constant) { + Reference reference; + reference.set_name(std::move(name)); + reference.mutable_kind().emplace().set_value( + std::move(constant)); + return reference; +} + +inline Reference MakeFunctionReference(std::string name, + std::vector overloads) { + Reference reference; + reference.set_name(std::move(name)); + reference.mutable_kind().emplace().set_overloads( + std::move(overloads)); + return reference; +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_REFERENCE_H_ diff --git a/common/reference_count.h b/common/reference_count.h new file mode 100644 index 000000000..0a07670bd --- /dev/null +++ b/common/reference_count.h @@ -0,0 +1,26 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_REFERENCE_COUNT_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_REFERENCE_COUNT_H_ + +#include "common/internal/reference_count.h" + +namespace cel { + +using ReferenceCount = common_internal::ReferenceCount; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_REFERENCE_COUNT_H_ diff --git a/common/reference_test.cc b/common/reference_test.cc new file mode 100644 index 000000000..54a1f383d --- /dev/null +++ b/common/reference_test.cc @@ -0,0 +1,113 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/reference.h" + +#include +#include +#include + +#include "common/constant.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +using ::testing::_; +using ::testing::ElementsAre; +using ::testing::Eq; +using ::testing::IsEmpty; +using ::testing::VariantWith; + +TEST(VariableReference, Value) { + VariableReference variable_reference; + EXPECT_FALSE(variable_reference.has_value()); + EXPECT_EQ(variable_reference.value(), Constant{}); + Constant value; + value.set_bool_value(true); + variable_reference.set_value(value); + EXPECT_TRUE(variable_reference.has_value()); + EXPECT_EQ(variable_reference.value(), value); + EXPECT_EQ(variable_reference.release_value(), value); + EXPECT_EQ(variable_reference.value(), Constant{}); +} + +TEST(VariableReference, Equality) { + VariableReference variable_reference; + EXPECT_EQ(variable_reference, VariableReference{}); + variable_reference.mutable_value().set_bool_value(true); + EXPECT_NE(variable_reference, VariableReference{}); +} + +TEST(FunctionReference, Overloads) { + FunctionReference function_reference; + EXPECT_THAT(function_reference.overloads(), IsEmpty()); + function_reference.mutable_overloads().reserve(2); + function_reference.mutable_overloads().push_back("foo"); + function_reference.mutable_overloads().push_back("bar"); + EXPECT_THAT(function_reference.release_overloads(), + ElementsAre("foo", "bar")); + EXPECT_THAT(function_reference.overloads(), IsEmpty()); +} + +TEST(FunctionReference, Equality) { + FunctionReference function_reference; + EXPECT_EQ(function_reference, FunctionReference{}); + function_reference.mutable_overloads().push_back("foo"); + EXPECT_NE(function_reference, FunctionReference{}); +} + +TEST(Reference, Name) { + Reference reference; + EXPECT_THAT(reference.name(), IsEmpty()); + reference.set_name("foo"); + EXPECT_EQ(reference.name(), "foo"); + EXPECT_EQ(reference.release_name(), "foo"); + EXPECT_THAT(reference.name(), IsEmpty()); +} + +TEST(Reference, Variable) { + Reference reference; + EXPECT_THAT(reference.kind(), VariantWith(_)); + EXPECT_TRUE(reference.has_variable()); + EXPECT_THAT(reference.release_variable(), Eq(VariableReference{})); + EXPECT_TRUE(reference.has_variable()); +} + +TEST(Reference, Function) { + Reference reference; + EXPECT_FALSE(reference.has_function()); + EXPECT_THAT(reference.function(), Eq(FunctionReference{})); + reference.mutable_function(); + EXPECT_TRUE(reference.has_function()); + EXPECT_THAT(reference.variable(), Eq(VariableReference{})); + EXPECT_THAT(reference.kind(), VariantWith(_)); + EXPECT_THAT(reference.release_function(), Eq(FunctionReference{})); + EXPECT_FALSE(reference.has_function()); +} + +TEST(Reference, Equality) { + EXPECT_EQ(MakeVariableReference("foo"), MakeVariableReference("foo")); + EXPECT_NE(MakeVariableReference("foo"), + MakeConstantVariableReference("foo", Constant(int64_t{1}))); + EXPECT_EQ( + MakeFunctionReference("foo", std::vector{"bar", "baz"}), + MakeFunctionReference("foo", std::vector{"bar", "baz"})); + EXPECT_NE( + MakeFunctionReference("foo", std::vector{"bar", "baz"}), + MakeFunctionReference("foo", std::vector{"bar"})); +} + +} // namespace +} // namespace cel diff --git a/common/signature.cc b/common/signature.cc new file mode 100644 index 000000000..e497e780d --- /dev/null +++ b/common/signature.cc @@ -0,0 +1,640 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/signature.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/types/optional.h" +#include "common/ast.h" +#include "common/type.h" +#include "common/type_spec_resolver.h" +#include "internal/status_macros.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +// Signature generator helper functions. +namespace { + +void AppendEscaped(std::string* result, std::string_view str, bool escape_dot) { + for (char c : str) { + switch (c) { + case '\\': + case '(': + case ')': + case '<': + case '>': + case '"': + case ',': + case '~': + result->push_back('\\'); + break; + case '.': + if (escape_dot) { + result->push_back('\\'); + } + break; + } + result->push_back(c); + } +} + +absl::Status AppendTypeDesc(std::string* result, const TypeSpec& type_spec); + +absl::Status AppendTypeSpecList(std::string* result, + const std::vector& params) { + if (!params.empty()) { + result->push_back('<'); + for (size_t i = 0; i < params.size(); ++i) { + CEL_RETURN_IF_ERROR(AppendTypeDesc(result, params[i])); + if (i < params.size() - 1) { + result->push_back(','); + } + } + result->push_back('>'); + } + return absl::OkStatus(); +} + +absl::Status AppendTypeDesc(std::string* result, const TypeSpec& type_spec) { + if (type_spec.has_null()) { + absl::StrAppend(result, "null"); + } else if (type_spec.has_dyn()) { + absl::StrAppend(result, "dyn"); + } else if (type_spec.has_primitive()) { + switch (type_spec.primitive()) { + case PrimitiveType::kBool: + absl::StrAppend(result, "bool"); + break; + case PrimitiveType::kInt64: + absl::StrAppend(result, "int"); + break; + case PrimitiveType::kUint64: + absl::StrAppend(result, "uint"); + break; + case PrimitiveType::kDouble: + absl::StrAppend(result, "double"); + break; + case PrimitiveType::kString: + absl::StrAppend(result, "string"); + break; + case PrimitiveType::kBytes: + absl::StrAppend(result, "bytes"); + break; + default: + return absl::InvalidArgumentError("Unsupported primitive type"); + } + } else if (type_spec.has_well_known()) { + switch (type_spec.well_known()) { + case WellKnownTypeSpec::kAny: + absl::StrAppend(result, "any"); + break; + case WellKnownTypeSpec::kTimestamp: + absl::StrAppend(result, "timestamp"); + break; + case WellKnownTypeSpec::kDuration: + absl::StrAppend(result, "duration"); + break; + default: + return absl::InvalidArgumentError("Unsupported well-known type"); + } + } else if (type_spec.has_wrapper()) { + switch (type_spec.wrapper()) { + case PrimitiveType::kBool: + absl::StrAppend(result, "bool_wrapper"); + break; + case PrimitiveType::kInt64: + absl::StrAppend(result, "int_wrapper"); + break; + case PrimitiveType::kUint64: + absl::StrAppend(result, "uint_wrapper"); + break; + case PrimitiveType::kDouble: + absl::StrAppend(result, "double_wrapper"); + break; + case PrimitiveType::kString: + absl::StrAppend(result, "string_wrapper"); + break; + case PrimitiveType::kBytes: + absl::StrAppend(result, "bytes_wrapper"); + break; + default: + return absl::InvalidArgumentError("Unsupported wrapper type"); + } + } else if (type_spec.has_list_type()) { + absl::StrAppend(result, "list<"); + if (type_spec.list_type().elem_type().is_specified()) { + CEL_RETURN_IF_ERROR( + AppendTypeDesc(result, type_spec.list_type().elem_type())); + } else { + absl::StrAppend(result, "dyn"); + } + result->push_back('>'); + } else if (type_spec.has_map_type()) { + absl::StrAppend(result, "map<"); + if (type_spec.map_type().key_type().is_specified()) { + CEL_RETURN_IF_ERROR( + AppendTypeDesc(result, type_spec.map_type().key_type())); + } else { + absl::StrAppend(result, "dyn"); + } + result->push_back(','); + if (type_spec.map_type().value_type().is_specified()) { + CEL_RETURN_IF_ERROR( + AppendTypeDesc(result, type_spec.map_type().value_type())); + } else { + absl::StrAppend(result, "dyn"); + } + result->push_back('>'); + } else if (type_spec.has_function()) { + absl::StrAppend(result, "function<"); + if (type_spec.function().result_type().is_specified()) { + CEL_RETURN_IF_ERROR( + AppendTypeDesc(result, type_spec.function().result_type())); + } else { + absl::StrAppend(result, "dyn"); + } + for (const auto& arg : type_spec.function().arg_types()) { + result->push_back(','); + CEL_RETURN_IF_ERROR(AppendTypeDesc(result, arg)); + } + result->push_back('>'); + } else if (type_spec.has_type()) { + absl::StrAppend(result, "type"); + result->push_back('<'); + CEL_RETURN_IF_ERROR(AppendTypeDesc(result, type_spec.type())); + result->push_back('>'); + } else if (type_spec.has_type_param()) { + absl::StrAppend(result, "~"); + AppendEscaped(result, type_spec.type_param().type(), /*escape_dot=*/true); + } else if (type_spec.has_abstract_type()) { + AppendEscaped(result, type_spec.abstract_type().name(), + /*escape_dot=*/false); + CEL_RETURN_IF_ERROR(AppendTypeSpecList( + result, type_spec.abstract_type().parameter_types())); + } else if (type_spec.has_message_type()) { + AppendEscaped(result, type_spec.message_type().type(), + /*escape_dot=*/false); + } else { + return absl::InvalidArgumentError(absl::StrCat( + "Unsupported type in signature: ", FormatTypeSpec(type_spec))); + } + return absl::OkStatus(); +} +} // namespace + +absl::StatusOr MakeTypeSignature(const Type& type) { + std::string result; + CEL_ASSIGN_OR_RETURN(TypeSpec type_spec, ConvertTypeToTypeSpec(type)); + CEL_RETURN_IF_ERROR(AppendTypeDesc(&result, type_spec)); + return result; +} + +absl::StatusOr MakeTypeSpecSignature(const TypeSpec& type_spec) { + std::string result; + CEL_RETURN_IF_ERROR(AppendTypeDesc(&result, type_spec)); + return result; +} + +absl::StatusOr MakeOverloadSignature( + std::string_view function_name, const std::vector& args, + bool is_member) { + std::vector arg_type_specs; + arg_type_specs.reserve(args.size()); + for (const auto& arg : args) { + CEL_ASSIGN_OR_RETURN(TypeSpec type_spec, ConvertTypeToTypeSpec(arg)); + arg_type_specs.push_back(type_spec); + } + return MakeOverloadSignature(function_name, arg_type_specs, is_member); +} + +absl::StatusOr MakeOverloadSignature( + std::string_view function_name, const std::vector& args, + bool is_member) { + std::string result; + if (is_member) { + if (!args.empty()) { + CEL_RETURN_IF_ERROR(AppendTypeDesc(&result, args[0])); + } else { + return absl::InvalidArgumentError("Member function with no receiver"); + } + result.push_back('.'); + } + AppendEscaped(&result, function_name, /*escape_dot=*/true); + result.push_back('('); + for (size_t i = is_member ? 1 : 0; i < args.size(); ++i) { + CEL_RETURN_IF_ERROR(AppendTypeDesc(&result, args[i])); + if (i < args.size() - 1) { + result.push_back(','); + } + } + result.push_back(')'); + + return result; +} + +// Signature parser helper functions. +namespace { + +std::string StripUnescapedWhitespace(std::string_view str) { + std::string result; + result.reserve(str.size()); + bool escaped = false; + for (char c : str) { + if (escaped) { + result.push_back(c); + escaped = false; + continue; + } + if (c == '\\') { + result.push_back(c); + escaped = true; + continue; + } + if (c == ' ' || c == '\t' || c == '\n' || c == '\r') { + continue; + } + result.push_back(c); + } + return result; +} + +absl::optional ParseBuiltinOrWrapper(std::string_view name_str) { + if (name_str == "null") return TypeSpec(NullTypeSpec()); + if (name_str == "bool") return TypeSpec(PrimitiveType::kBool); + if (name_str == "int") return TypeSpec(PrimitiveType::kInt64); + if (name_str == "uint") return TypeSpec(PrimitiveType::kUint64); + if (name_str == "double") return TypeSpec(PrimitiveType::kDouble); + if (name_str == "string") return TypeSpec(PrimitiveType::kString); + if (name_str == "bytes") return TypeSpec(PrimitiveType::kBytes); + if (name_str == "any" || name_str == "google.protobuf.Any") + return TypeSpec(WellKnownTypeSpec::kAny); + if (name_str == "timestamp" || name_str == "google.protobuf.Timestamp") + return TypeSpec(WellKnownTypeSpec::kTimestamp); + if (name_str == "duration" || name_str == "google.protobuf.Duration") + return TypeSpec(WellKnownTypeSpec::kDuration); + if (name_str == "dyn" || name_str == "google.protobuf.Value") + return TypeSpec(DynTypeSpec()); + + // Handle standard Protobuf well-known wrapper types to preserve + // backward compatibility for users migrating YAML configuration files. + if (name_str == "bool_wrapper" || name_str == "google.protobuf.BoolValue") + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBool)); + if (name_str == "int_wrapper" || name_str == "google.protobuf.Int64Value" || + name_str == "google.protobuf.Int32Value") + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kInt64)); + if (name_str == "uint_wrapper" || name_str == "google.protobuf.UInt64Value" || + name_str == "google.protobuf.UInt32Value") + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kUint64)); + if (name_str == "double_wrapper" || + name_str == "google.protobuf.DoubleValue" || + name_str == "google.protobuf.FloatValue") + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kDouble)); + if (name_str == "string_wrapper" || name_str == "google.protobuf.StringValue") + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kString)); + if (name_str == "bytes_wrapper" || name_str == "google.protobuf.BytesValue") + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBytes)); + + if (name_str == "google.protobuf.ListValue") { + return TypeSpec(ListTypeSpec(std::make_unique(DynTypeSpec()))); + } + if (name_str == "google.protobuf.Struct") { + return TypeSpec( + MapTypeSpec(std::make_unique(PrimitiveType::kString), + std::make_unique(DynTypeSpec()))); + } + + return absl::nullopt; +} + +std::string Unescape(std::string_view str) { + size_t first_escape = str.find('\\'); + if (first_escape == std::string_view::npos) { + return std::string(str); + } + std::string result; + result.reserve(str.size()); + result.append(str.substr(0, first_escape)); + bool escaped = false; + for (size_t i = first_escape; i < str.size(); ++i) { + char c = str[i]; + if (escaped) { + result.push_back(c); + escaped = false; + } else if (c == '\\') { + escaped = true; + } else { + result.push_back(c); + } + } + if (escaped) { + result.push_back('\\'); + } + return result; +} + +class SignatureScanner { + public: + explicit SignatureScanner(std::string_view input, + std::string_view error_prefix = "Invalid signature") + : input_(input), error_prefix_(error_prefix) {} + + absl::StatusOr FindTopLevelChar(char target, bool find_last = false) { + size_t found_idx = std::string_view::npos; + int nesting = 0; + bool escaped = false; + // Scanning str for delimiter boundaries while ensuring + // brackets are balanced and escape backslashes are bypassed. + for (size_t i = 0; i < input_.size(); ++i) { + char c = input_[i]; + if (escaped) { + escaped = false; + continue; + } + if (c == '\\') { + escaped = true; + continue; + } + if (c == target && nesting == 0) { + if (find_last || found_idx == std::string_view::npos) { + found_idx = i; + } + } + if (c == '<') { + nesting++; + } else if (c == '>') { + nesting--; + if (nesting < 0) { + return absl::InvalidArgumentError( + absl::StrCat(error_prefix_, ": mismatched brackets")); + } + } + } + if (nesting != 0) { + return absl::InvalidArgumentError( + absl::StrCat(error_prefix_, ": mismatched brackets")); + } + return found_idx; + } + + absl::StatusOr> SplitTopLevel(char delimiter) { + std::vector result; + int nesting = 0; + bool escaped = false; + size_t start = 0; + // Scanning str for delimiter while ensuring brackets are balanced and + // escape backslashes are bypassed. + for (size_t i = 0; i < input_.size(); ++i) { + char c = input_[i]; + if (escaped) { + escaped = false; + continue; + } + if (c == '\\') { + escaped = true; + continue; + } + if (c == delimiter && nesting == 0) { + result.push_back(input_.substr(start, i - start)); + start = i + 1; + } + if (c == '<') { + nesting++; + } else if (c == '>') { + nesting--; + if (nesting < 0) { + return absl::InvalidArgumentError( + absl::StrCat(error_prefix_, ": mismatched brackets")); + } + } + } + if (nesting != 0) { + return absl::InvalidArgumentError( + absl::StrCat(error_prefix_, ": mismatched brackets")); + } + result.push_back(input_.substr(start)); + return result; + } + + private: + std::string_view input_; + std::string_view error_prefix_; +}; + +absl::StatusOr> SplitTypeList( + std::string_view params) { + return SignatureScanner(params, "Invalid type signature").SplitTopLevel(','); +} + +absl::StatusOr ParseTypeSignature(std::string_view signature) { + if (signature.empty()) { + return absl::InvalidArgumentError("Empty type signature"); + } + + if (signature[0] == '~') { + std::string_view param_name = signature.substr(1); + if (param_name.empty()) { + return absl::InvalidArgumentError( + "Invalid type signature: invalid type parameter name"); + } + CEL_ASSIGN_OR_RETURN(size_t less_idx, + SignatureScanner(param_name) + .FindTopLevelChar('<', /*find_last=*/false)); + CEL_ASSIGN_OR_RETURN(size_t comma_idx, + SignatureScanner(param_name) + .FindTopLevelChar(',', /*find_last=*/false)); + if (less_idx != std::string_view::npos || + comma_idx != std::string_view::npos) { + return absl::InvalidArgumentError( + "Invalid type signature: invalid type parameter name"); + } + return TypeSpec(ParamTypeSpec(Unescape(param_name))); + } + + CEL_ASSIGN_OR_RETURN(size_t less_idx, + SignatureScanner(signature, "Invalid type signature") + .FindTopLevelChar('<', /*find_last=*/false)); + + std::string name_str; + std::vector params; + + if (less_idx != std::string_view::npos) { + // If the signature contains a '<', it must also contain a matching '>'. + if (signature.back() != '>') { + return absl::InvalidArgumentError( + "Invalid type signature: missing closing >"); + } + name_str = Unescape(signature.substr(0, less_idx)); + std::string_view params_str = + signature.substr(less_idx + 1, signature.size() - less_idx - 2); + CEL_ASSIGN_OR_RETURN(auto param_list, SplitTypeList(params_str)); + for (std::string_view param_str : param_list) { + CEL_ASSIGN_OR_RETURN(auto param, ParseTypeSignature(param_str)); + params.push_back(std::move(param)); + } + } else { + name_str = Unescape(signature); + } + + auto read_param_or_dyn = [¶ms](size_t index) { + auto spec = std::make_unique(DynTypeSpec()); + if (params.size() > index) { + *spec = std::move(params[index]); + } + return spec; + }; + + if (!params.empty()) { + if (ParseBuiltinOrWrapper(name_str).has_value()) { + return absl::InvalidArgumentError( + absl::StrCat("Invalid type signature: ", name_str, + " cannot have type parameters")); + } + } else { + if (auto builtin = ParseBuiltinOrWrapper(name_str); builtin.has_value()) { + return *builtin; + } + } + + if (name_str == "type") { + if (params.size() > 1) { + return absl::InvalidArgumentError( + "Invalid type signature: type expects at most 1 parameter"); + } + return TypeSpec(read_param_or_dyn(0)); + } + + if (name_str == "list") { + if (params.size() > 1) { + return absl::InvalidArgumentError( + "Invalid type signature: list expects at most 1 parameter"); + } + return TypeSpec(ListTypeSpec(read_param_or_dyn(0))); + } + + if (name_str == "map") { + if (!params.empty() && params.size() != 2) { + return absl::InvalidArgumentError( + "Invalid type signature: map expects 0 or 2 parameters"); + } + auto key = read_param_or_dyn(0); + auto value = read_param_or_dyn(1); + return TypeSpec(MapTypeSpec(std::move(key), std::move(value))); + } + + if (name_str == "function") { + auto result_type = read_param_or_dyn(0); + std::vector arg_types; + for (size_t i = 1; i < params.size(); ++i) { + arg_types.push_back(std::move(params[i])); + } + return TypeSpec( + FunctionTypeSpec(std::move(result_type), std::move(arg_types))); + } + + if (name_str.empty() || absl::StrContains(name_str, "..")) { + return absl::InvalidArgumentError( + "Invalid type signature: invalid identifier"); + } + + return TypeSpec(AbstractType(name_str, std::move(params))); +} + +} // namespace + +absl::StatusOr ParseFunctionSignature( + std::string_view signature) { + std::string stripped_sig = StripUnescapedWhitespace(signature); + if (stripped_sig.empty()) { + return absl::InvalidArgumentError("Empty function signature"); + } + + CEL_ASSIGN_OR_RETURN( + size_t paren_idx, + SignatureScanner(stripped_sig, "Invalid function signature") + .FindTopLevelChar('(', /*find_last=*/false)); + + if (paren_idx == std::string_view::npos || stripped_sig.back() != ')') { + return absl::InvalidArgumentError("Invalid function signature"); + } + + std::string_view prefix = std::string_view(stripped_sig).substr(0, paren_idx); + std::string_view args_str = + std::string_view(stripped_sig) + .substr(paren_idx + 1, stripped_sig.size() - paren_idx - 2); + + std::vector arg_types; + ParsedFunctionOverload out; + + CEL_ASSIGN_OR_RETURN(size_t dot_idx, + SignatureScanner(prefix, "Invalid function signature") + .FindTopLevelChar('.', /*find_last=*/true)); + + if (dot_idx != std::string_view::npos) { + out.is_member = true; + std::string_view receiver_str = prefix.substr(0, dot_idx); + std::string_view func_str = prefix.substr(dot_idx + 1); + + CEL_ASSIGN_OR_RETURN(auto receiver_param, ParseTypeSignature(receiver_str)); + arg_types.push_back(std::move(receiver_param)); + out.function_name = Unescape(func_str); + } else { + out.is_member = false; + out.function_name = Unescape(prefix); + } + + if (out.function_name.empty()) { + return absl::InvalidArgumentError( + "Invalid function signature: empty function name"); + } + + if (!args_str.empty()) { + CEL_ASSIGN_OR_RETURN(auto arg_list, SplitTypeList(args_str)); + for (std::string_view arg_str : arg_list) { + CEL_ASSIGN_OR_RETURN(auto arg_param, ParseTypeSignature(arg_str)); + arg_types.push_back(std::move(arg_param)); + } + } + + auto result_type = std::make_unique(DynTypeSpec()); + out.signature_type = + TypeSpec(FunctionTypeSpec(std::move(result_type), std::move(arg_types))); + + return out; +} + +absl::StatusOr ParseTypeSpec(std::string_view signature) { + std::string stripped_sig = StripUnescapedWhitespace(signature); + return ParseTypeSignature(stripped_sig); +} + +absl::StatusOr ParseType(std::string_view signature, google::protobuf::Arena* arena, + const google::protobuf::DescriptorPool& pool) { + CEL_ASSIGN_OR_RETURN(auto type_spec, ParseTypeSpec(signature)); + return cel::ConvertTypeSpecToType(type_spec, arena, pool); +} + +} // namespace cel diff --git a/common/signature.h b/common/signature.h new file mode 100644 index 000000000..777f03439 --- /dev/null +++ b/common/signature.h @@ -0,0 +1,101 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_SIGNATURE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_SIGNATURE_H_ + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "common/ast.h" +#include "common/type.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +// Generates a signature for a `cel::Type`, which is a string representation of +// the type. +// +// Examples: +// +// - `int` +// - `list` +// - `list>` +absl::StatusOr MakeTypeSignature(const Type& type); + +// Generates a signature for a `cel::TypeSpec`, which is a string +// representation of the type. +// +// Examples: +// +// - `int` +// - `list` +// - `list>` +absl::StatusOr MakeTypeSpecSignature(const TypeSpec& type_spec); + +// Generates a signature for a function overload based on the function name +// and the types of the arguments. If `is_member` is true, the first argument +// type is used as the receiver and is prepended to the function name, followed +// by a dollar sign. +// +// Examples: +// +// - `foo()` +// - `foo(int)` +// - `bar.foo(int)` +// - `foo(int,string)` +// - `foo(list,list)` +// - `bar.foo(list,list>)` +// +// If the function name contains a period, it is escaped with a backslash, e.g. +// `foo.bar` becomes `foo\.bar`. This allows to disambiguate between a member +// function and qualified target type name. +// +absl::StatusOr MakeOverloadSignature( + std::string_view function_name, const std::vector& args, + bool is_member); + +// Generates a signature for a function overload based on the function name +// and the type specs of the arguments. See above for more details. +absl::StatusOr MakeOverloadSignature( + std::string_view function_name, const std::vector& args, + bool is_member); + +// Parses a string type signature directly into a `cel::TypeSpec`. +absl::StatusOr ParseTypeSpec(std::string_view signature); + +// Parses a string type signature directly into a `cel::Type`. +absl::StatusOr ParseType(std::string_view signature, google::protobuf::Arena* arena, + const google::protobuf::DescriptorPool& pool); + +// A parsed function overload signature with the function name, flag for member +// function, and the function signature type. +struct ParsedFunctionOverload { + std::string function_name; + bool is_member = false; + // The function signature type, configured as a `FunctionTypeSpec`. + TypeSpec signature_type; +}; + +// Parses a string function overload signature directly into a +// `cel::TypeSpec` configured as a `FunctionTypeSpec`. +absl::StatusOr ParseFunctionSignature( + std::string_view signature); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_SIGNATURE_H_ diff --git a/common/signature_test.cc b/common/signature_test.cc new file mode 100644 index 000000000..ea51eb566 --- /dev/null +++ b/common/signature_test.cc @@ -0,0 +1,784 @@ +#include "common/signature.h" +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "absl/base/no_destructor.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "common/ast.h" +#include "common/type.h" +#include "common/type_kind.h" +#include "common/type_spec_resolver.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "google/protobuf/arena.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOkAndHolds; +using ::absl_testing::StatusIs; +using ::cel::internal::GetTestingDescriptorPool; +using ::testing::HasSubstr; +using ::testing::ValuesIn; + +google::protobuf::Arena* GetTestArena() { + static absl::NoDestructor arena; + return &*arena; +} + +void VerifyParsedMatchesType(const TypeSpec& parsed, const TypeSpec& expected) { + EXPECT_EQ(parsed, expected); +} +void VerifyTypesEqual(const Type& lhs, const Type& rhs) { + EXPECT_EQ(lhs.kind(), rhs.kind()); + if (lhs.kind() != rhs.kind()) return; + + if (lhs.kind() == TypeKind::kOpaque || lhs.kind() == TypeKind::kStruct || + lhs.kind() == TypeKind::kTypeParam) { + EXPECT_EQ(lhs.name(), rhs.name()); + } + + const auto& lhs_params = lhs.GetParameters(); + const auto& rhs_params = rhs.GetParameters(); + EXPECT_EQ(lhs_params.size(), rhs_params.size()); + if (lhs_params.size() == rhs_params.size()) { + for (size_t i = 0; i < lhs_params.size(); ++i) { + VerifyTypesEqual(lhs_params[i], rhs_params[i]); + } + } +} + +struct TypeSignatureTestCase { + TypeSpec type; + std::string expected_signature; + std::string expected_error; +}; + +using TypeSignatureTest = testing::TestWithParam; + +TEST_P(TypeSignatureTest, TypeSignature) { + const auto& param = GetParam(); + + absl::StatusOr signature = MakeTypeSpecSignature(param.type); + if (!param.expected_error.empty()) { + EXPECT_THAT(signature, StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr(param.expected_error))); + } else { + EXPECT_THAT(signature, IsOkAndHolds(param.expected_signature)); + + absl::StatusOr type = ConvertTypeSpecToType( + param.type, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(type, ::absl_testing::IsOk()); + EXPECT_THAT(MakeTypeSignature(*type), + IsOkAndHolds(param.expected_signature)); + } +} + +std::vector GetTypeSignatureTestCases() { + return { + { + .type = TypeSpec(NullTypeSpec{}), + .expected_signature = "null", + }, + { + .type = TypeSpec(PrimitiveType::kBool), + .expected_signature = "bool", + }, + { + .type = TypeSpec(PrimitiveType::kInt64), + .expected_signature = "int", + }, + { + .type = TypeSpec(PrimitiveType::kUint64), + .expected_signature = "uint", + }, + { + .type = TypeSpec(PrimitiveType::kDouble), + .expected_signature = "double", + }, + { + .type = TypeSpec(PrimitiveType::kString), + .expected_signature = "string", + }, + { + .type = TypeSpec(PrimitiveType::kBytes), + .expected_signature = "bytes", + }, + { + .type = TypeSpec( + MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), + .expected_signature = "cel.expr.conformance.proto3.TestAllTypes", + }, + { + .type = TypeSpec( + AbstractType("cel.expr.conformance.proto3.TestAllTypes", {})), + .expected_signature = "cel.expr.conformance.proto3.TestAllTypes", + }, + { + .type = TypeSpec(WellKnownTypeSpec::kDuration), + .expected_signature = "duration", + }, + { + .type = TypeSpec(WellKnownTypeSpec::kTimestamp), + .expected_signature = "timestamp", + }, + { + .type = TypeSpec( + ListTypeSpec(std::make_unique(PrimitiveType::kString))), + .expected_signature = "list", + }, + { + .type = TypeSpec( + ListTypeSpec(std::make_unique(ParamTypeSpec("A")))), + .expected_signature = "list<~A>", + }, + { + .type = TypeSpec( + ListTypeSpec(std::make_unique(ParamTypeSpec("A(ParamTypeSpec(R"(a,b..(d)\e)")))), + .expected_signature = R"(list<~a\,b\.\\.\(d\)\\e>)", + }, + { + .type = TypeSpec( + MapTypeSpec(std::make_unique(PrimitiveType::kInt64), + std::make_unique(DynTypeSpec()))), + .expected_signature = "map", + }, + { + .type = TypeSpec( + MapTypeSpec(std::make_unique(ParamTypeSpec("B")), + std::make_unique(ParamTypeSpec("C")))), + .expected_signature = "map<~B,~C>", + }, + { + .type = TypeSpec(MapTypeSpec( + std::make_unique(PrimitiveType::kInt64), nullptr)), + .expected_signature = "map", + }, + { + .type = TypeSpec(MapTypeSpec(nullptr, nullptr)), + .expected_signature = "map", + }, + { + .type = TypeSpec(std::make_unique(PrimitiveType::kInt64)), + .expected_signature = "type", + }, + { + .type = TypeSpec(WellKnownTypeSpec::kAny), + .expected_signature = "any", + }, + { + .type = TypeSpec(DynTypeSpec{}), + .expected_signature = "dyn", + }, + { + .type = TypeSpec(AbstractType( + "bar", {TypeSpec(FunctionTypeSpec( + std::make_unique(ParamTypeSpec("D")), + {TypeSpec(PrimitiveType::kString), + TypeSpec(PrimitiveType::kBool)}))})), + .expected_signature = "bar>", + }, + { + .type = + TypeSpec(AbstractType("bar", {TypeSpec(PrimitiveType::kInt64), + TypeSpec(PrimitiveType::kString)})), + .expected_signature = "bar", + }, + { + .type = TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBool)), + .expected_signature = "bool_wrapper", + }, + { + .type = TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kInt64)), + .expected_signature = "int_wrapper", + }, + { + .type = TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kUint64)), + .expected_signature = "uint_wrapper", + }, + { + .type = TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kDouble)), + .expected_signature = "double_wrapper", + }, + { + .type = TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kString)), + .expected_signature = "string_wrapper", + }, + { + .type = TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBytes)), + .expected_signature = "bytes_wrapper", + }, + { + .type = TypeSpec( + FunctionTypeSpec(nullptr, {TypeSpec(PrimitiveType::kInt64)})), + .expected_signature = "function", + }, + { + .type = TypeSpec(FunctionTypeSpec( + std::make_unique(PrimitiveType::kInt64), {})), + .expected_signature = "function", + }, + { + .type = TypeSpec(FunctionTypeSpec(nullptr, {})), + .expected_signature = "function", + }, + }; +} + +INSTANTIATE_TEST_SUITE_P(TypeSignatureTest, TypeSignatureTest, + ValuesIn(GetTypeSignatureTestCases())); + +TEST(TypeSignatureTest, UnsupportedTypes) { + EXPECT_THAT(MakeTypeSignature(UnknownType{}), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Unsupported Type kind: *unknown*"))); + + EXPECT_THAT(MakeTypeSignature(ErrorType{}), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Unsupported type in signature: *error*"))); + + EXPECT_THAT(MakeTypeSpecSignature(TypeSpec(static_cast(999))), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Unsupported primitive type"))); + + EXPECT_THAT( + MakeTypeSpecSignature(TypeSpec(static_cast(999))), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Unsupported well-known type"))); + + EXPECT_THAT(MakeTypeSpecSignature(TypeSpec( + PrimitiveTypeWrapper(static_cast(999)))), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Unsupported wrapper type"))); +} + +TEST_P(TypeSignatureTest, ParseTypeCheck) { + const auto& param = GetParam(); + if (!param.expected_signature.empty() && param.expected_error.empty()) { + auto parsed = ParseType(param.expected_signature, GetTestArena(), + *GetTestingDescriptorPool()); + ASSERT_THAT(parsed, ::absl_testing::IsOk()); + ASSERT_OK_AND_ASSIGN(auto expected_type, + ConvertTypeSpecToType(param.type, GetTestArena(), + *GetTestingDescriptorPool())); + VerifyTypesEqual(*parsed, expected_type); + } +} + +struct OverloadSignatureTestCase { + std::string function_name = "hello"; + std::vector args; + bool is_member = false; + std::string expected_signature; + std::string expected_error; +}; + +using OverloadSignatureTest = testing::TestWithParam; + +TEST_P(OverloadSignatureTest, OverloadSignature) { + const auto& param = GetParam(); + + absl::StatusOr signature = + MakeOverloadSignature(param.function_name, param.args, param.is_member); + if (!param.expected_error.empty()) { + EXPECT_THAT(signature, StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr(param.expected_error))); + } else { + EXPECT_THAT(signature, IsOkAndHolds(param.expected_signature)); + } +} + +std::vector GetOverloadSignatureTestCases() { + return { + { + .args = {TypeSpec(PrimitiveType::kString)}, + .expected_signature = "hello(string)", + }, + { + .args = {TypeSpec(PrimitiveType::kInt64), + TypeSpec(PrimitiveType::kUint64)}, + .expected_signature = "hello(int,uint)", + }, + { + .args = {TypeSpec(ListTypeSpec( + std::make_unique(PrimitiveType::kString)))}, + .expected_signature = "hello(list)", + }, + { + .args = {TypeSpec( + ListTypeSpec(std::make_unique(ParamTypeSpec("A"))))}, + .expected_signature = "hello(list<~A>)", + }, + { + .args = {TypeSpec( + MapTypeSpec(std::make_unique(PrimitiveType::kInt64), + std::make_unique(DynTypeSpec())))}, + .expected_signature = "hello(map)", + }, + { + .args = {TypeSpec( + MapTypeSpec(std::make_unique(ParamTypeSpec("B")), + std::make_unique(ParamTypeSpec("C"))))}, + .expected_signature = "hello(map<~B,~C>)", + }, + + { + .args = {TypeSpec(AbstractType( + "bar", + {TypeSpec(FunctionTypeSpec( + std::make_unique(ParamTypeSpec("D")), {}))}))}, + .expected_signature = "hello(bar>)", + }, + { + .args = {TypeSpec(WellKnownTypeSpec::kAny)}, + .expected_signature = "hello(any)", + }, + { + .args = {TypeSpec(WellKnownTypeSpec::kDuration)}, + .expected_signature = "hello(duration)", + }, + { + .args = {TypeSpec(WellKnownTypeSpec::kTimestamp)}, + .expected_signature = "hello(timestamp)", + }, + { + .args = {TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBool))}, + .expected_signature = "hello(bool_wrapper)", + }, + { + .args = {TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kInt64))}, + .expected_signature = "hello(int_wrapper)", + }, + { + .args = {TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kUint64))}, + .expected_signature = "hello(uint_wrapper)", + }, + { + .args = {TypeSpec( + AbstractType("cel.expr.conformance.proto3.TestAllTypes", {}))}, + .expected_signature = + "hello(cel.expr.conformance.proto3.TestAllTypes)", + }, + { + .args = {TypeSpec(PrimitiveType::kString)}, + .is_member = true, + .expected_signature = "string.hello()", + }, + { + .args = {TypeSpec(PrimitiveType::kString), + TypeSpec(ListTypeSpec( + std::make_unique(PrimitiveType::kBool)))}, + .is_member = true, + .expected_signature = "string.hello(list)", + }, + { + .args = {TypeSpec(PrimitiveType::kString), + TypeSpec(PrimitiveType::kBool), TypeSpec(DynTypeSpec())}, + .is_member = true, + .expected_signature = "string.hello(bool,dyn)", + }, + { + .function_name = "hello", + .args = {TypeSpec( + AbstractType("bar", {TypeSpec(ParamTypeSpec("dummy.type"))}))}, + .is_member = true, + .expected_signature = R"(bar<~dummy\.type>.hello())", + }, + { + .function_name = "inspect", + .args = {TypeSpec( + std::make_unique(PrimitiveType::kString))}, + .expected_signature = "inspect(type)", + }, + { + .function_name = R"(h.(e),l\o)", + .args = {TypeSpec(PrimitiveType::kString), + TypeSpec(ListTypeSpec(std::make_unique( + ParamTypeSpec(R"(a,b..(d)\e)"))))}, + .is_member = true, + .expected_signature = + R"(string.h\.\(e\)\,l\\\o(list<~a\,b\.\\.\(d\)\\e>))", + }, + }; +} + +TEST(OverloadSignatureTest, MemberFunctionNoReceiverError) { + auto signature = + MakeOverloadSignature("hello", std::vector{}, true); + EXPECT_THAT(signature, + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Member function with no receiver"))); +} + +INSTANTIATE_TEST_SUITE_P(OverloadIdTest, OverloadSignatureTest, + ValuesIn(GetOverloadSignatureTestCases())); + +TEST_P(OverloadSignatureTest, ExhaustiveFunctionParseCheck) { + const auto& param = GetParam(); + if (!param.expected_signature.empty()) { + auto parsed = ParseFunctionSignature(param.expected_signature); + ASSERT_THAT(parsed, ::absl_testing::IsOk()); + EXPECT_EQ(parsed->function_name, param.function_name); + EXPECT_EQ(parsed->is_member, param.is_member); + EXPECT_TRUE(parsed->signature_type.has_function()); + const auto& func = parsed->signature_type.function(); + for (size_t i = 0; i < param.args.size(); ++i) { + VerifyParsedMatchesType(func.arg_types()[i], param.args[i]); + } + } +} + +TEST(ParseSignatureTest, ProtoParsing) { + ASSERT_OK_AND_ASSIGN( + auto t1, ParseType("int", GetTestArena(), *GetTestingDescriptorPool())); + EXPECT_TRUE(t1.IsInt()); + + ASSERT_OK_AND_ASSIGN(auto t2, ParseType("list<~A>", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(t2.IsList()); + + ASSERT_OK_AND_ASSIGN(auto t3, ParseType(R"(~abc\)", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(t3.IsTypeParam()); + EXPECT_EQ(t3.GetTypeParam().name(), R"(abc\)"); + + ASSERT_OK_AND_ASSIGN(auto w1, + ParseType("google.protobuf.BoolValue", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w1.IsBoolWrapper()); + + ASSERT_OK_AND_ASSIGN(auto w2, + ParseType("google.protobuf.Int64Value", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w2.IsIntWrapper()); + + ASSERT_OK_AND_ASSIGN(auto w3, + ParseType("google.protobuf.Int32Value", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w3.IsIntWrapper()); + + ASSERT_OK_AND_ASSIGN(auto w4, + ParseType("google.protobuf.UInt64Value", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w4.IsUintWrapper()); + + ASSERT_OK_AND_ASSIGN(auto w5, + ParseType("google.protobuf.UInt32Value", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w5.IsUintWrapper()); + + ASSERT_OK_AND_ASSIGN(auto w6, + ParseType("google.protobuf.DoubleValue", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w6.IsDoubleWrapper()); + + ASSERT_OK_AND_ASSIGN(auto w7, + ParseType("google.protobuf.FloatValue", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w7.IsDoubleWrapper()); + + ASSERT_OK_AND_ASSIGN(auto w8, + ParseType("google.protobuf.StringValue", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w8.IsStringWrapper()); + + ASSERT_OK_AND_ASSIGN(auto w9, + ParseType("google.protobuf.BytesValue", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w9.IsBytesWrapper()); + + ASSERT_OK_AND_ASSIGN(auto w10, ParseType("string_wrapper", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w10.IsStringWrapper()); + + ASSERT_OK_AND_ASSIGN(auto w11, ParseType("bytes_wrapper", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w11.IsBytesWrapper()); + + ASSERT_OK_AND_ASSIGN(auto gp_any, + ParseType("google.protobuf.Any", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(gp_any.IsAny()); + + ASSERT_OK_AND_ASSIGN(auto gp_timestamp, + ParseType("google.protobuf.Timestamp", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(gp_timestamp.IsTimestamp()); + + ASSERT_OK_AND_ASSIGN(auto gp_duration, + ParseType("google.protobuf.Duration", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(gp_duration.IsDuration()); + + ASSERT_OK_AND_ASSIGN(auto gp_value, + ParseType("google.protobuf.Value", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(gp_value.IsDyn()); + + ASSERT_OK_AND_ASSIGN(auto gp_list_value, + ParseType("google.protobuf.ListValue", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(gp_list_value.IsList()); + + ASSERT_OK_AND_ASSIGN(auto gp_struct, + ParseType("google.protobuf.Struct", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(gp_struct.IsMap()); + + // Legal whitespace handling tests + ASSERT_OK_AND_ASSIGN(auto ws_type1, + ParseType("map < int , string > ", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(ws_type1.IsMap()); + + ASSERT_OK_AND_ASSIGN(auto ws_type2, + ParseType("map\t<\nint\r,\tstring\n>\r", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(ws_type2.IsMap()); +} + +TEST(ParseSignatureTest, FunctionParsing) { + ASSERT_OK_AND_ASSIGN(auto f1, ParseFunctionSignature("hello(string)")); + EXPECT_TRUE(f1.signature_type.has_function()); + EXPECT_EQ(f1.signature_type.function().arg_types().size(), 1); + + // Legal whitespace handling tests + ASSERT_OK_AND_ASSIGN(auto ws_func1, + ParseFunctionSignature(" hello ( string ) ")); + EXPECT_TRUE(ws_func1.signature_type.has_function()); + EXPECT_EQ(ws_func1.signature_type.function().arg_types().size(), 1); + + ASSERT_OK_AND_ASSIGN(auto ws_func2, + ParseFunctionSignature("\thello\n(\rstring\t)\n\r")); + EXPECT_TRUE(ws_func2.signature_type.has_function()); + EXPECT_EQ(ws_func2.signature_type.function().arg_types().size(), 1); + + ASSERT_OK_AND_ASSIGN(auto f2, ParseFunctionSignature("a.b.c()")); + EXPECT_TRUE(f2.is_member); + EXPECT_EQ(f2.function_name, "c"); +} + +TEST(ParseSignatureTest, ParsingErrors) { + // Mismatched template brackets and parentheses. + EXPECT_THAT( + ParseType("list>", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); + EXPECT_THAT( + ParseType("list", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); + EXPECT_THAT(ParseType("list><", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); + EXPECT_THAT(ParseFunctionSignature("hello(list>)"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); + EXPECT_THAT(ParseFunctionSignature("hello(list)"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); + EXPECT_THAT(ParseFunctionSignature("foo"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid function signature"))); + EXPECT_THAT( + ParseType("list b < c>", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); + + // Parameter count validations for list, map and type types. + EXPECT_THAT(ParseType("list", GetTestArena(), + *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("list expects at most 1 parameter"))); + EXPECT_THAT( + ParseType("map", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("map expects 0 or 2 parameters"))); + EXPECT_THAT(ParseType("map", GetTestArena(), + *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("map expects 0 or 2 parameters"))); + EXPECT_THAT(ParseType("type", GetTestArena(), + *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("type expects at most 1 parameter"))); + + // Invalid parameter name validations. + EXPECT_THAT(ParseType("~", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("invalid type parameter name"))); + EXPECT_THAT(ParseType("~A", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("invalid type parameter name"))); + + // Enforcing valid function and identifier names. + EXPECT_THAT(ParseFunctionSignature("()"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("empty function name"))); + EXPECT_THAT(ParseFunctionSignature("string.()"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("empty function name"))); + + // Missing closing operators and boundary checks. + EXPECT_THAT( + ParseType("listfoo", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("missing closing >"))); + + EXPECT_THAT(ParseFunctionSignature("hello>(string)"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); + EXPECT_THAT( + ParseType("list<", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); + + EXPECT_THAT(ParseType("map", GetTestArena(), + *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); + + EXPECT_THAT(ParseType("map int, string>", GetTestArena(), + *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); + + EXPECT_THAT(ParseType("list", GetTestArena(), + *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid type signature"))); + + EXPECT_THAT(ParseFunctionSignature("a..b.c()"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid type signature"))); + EXPECT_THAT( + ParseType("list", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Empty type signature"))); + + EXPECT_THAT( + ParseType("~list", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid type signature"))); + + // Checks that builtin types cannot have type parameters. + EXPECT_THAT( + ParseType("int", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("cannot have type parameters"))); +} + +TEST(ParseSignatureTest, MessageTypeWithParamsError) { + EXPECT_THAT(ParseType("cel.expr.conformance.proto3.TestAllTypes", + GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("cannot have type parameters"))); +} + +TEST(ParseSignatureTest, MissingClosingParenthesisError) { + EXPECT_THAT(ParseFunctionSignature("hello(string"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid function signature"))); + EXPECT_THAT(ParseFunctionSignature("hello)"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid function signature"))); +} + +TEST(ParseSignatureTest, NestedDotsNonMember) { + auto f1 = ParseFunctionSignature( + "my_opaque()"); + ASSERT_THAT(f1, ::absl_testing::IsOk()); + EXPECT_FALSE(f1->is_member); + EXPECT_EQ(f1->function_name, + "my_opaque"); +} + +TEST(ParseSignatureTest, OverlyComplexSignatures) { + auto t1 = ParseType("map>,map>>", + GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t1, ::absl_testing::IsOk()); + EXPECT_TRUE(t1->IsMap()); + + auto t2 = ParseType(R"(~abc\\)", GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t2, ::absl_testing::IsOk()); + EXPECT_TRUE(t2->IsTypeParam()); + EXPECT_EQ(t2->GetTypeParam().name(), R"(abc\)"); + + auto t3 = + ParseType(R"(~abc\\\\)", GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t3, ::absl_testing::IsOk()); + EXPECT_TRUE(t3->IsTypeParam()); + EXPECT_EQ(t3->GetTypeParam().name(), R"(abc\\)"); + + auto f1 = ParseFunctionSignature( + "bar>,map>.func(string)"); + ASSERT_THAT(f1, ::absl_testing::IsOk()); + EXPECT_TRUE(f1->is_member); + EXPECT_EQ(f1->function_name, "func"); + EXPECT_TRUE(f1->signature_type.has_function()); + EXPECT_EQ(f1->signature_type.function().arg_types().size(), 2); +} + +TEST(ParseSignatureTest, EmptyOrWhitespaceErrors) { + EXPECT_THAT(ParseType("", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Empty type signature"))); + EXPECT_THAT(ParseFunctionSignature(""), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Empty function signature"))); + EXPECT_THAT(ParseType("list>", GetTestArena(), + *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Empty type signature"))); +} + +TEST(OverloadSignatureTest, ArgumentTypeVector) { + std::vector args; + args.push_back(Type(IntType())); + args.push_back(Type(StringType())); + args.push_back(Type(ListType(GetTestArena(), IntType()))); + args.push_back( + Type(MessageType(GetTestingDescriptorPool()->FindMessageTypeByName( + "cel.expr.conformance.proto3.TestAllTypes")))); + args.push_back(Type(OpaqueType(GetTestArena(), "Foo", {TypeParamType("T")}))); + ASSERT_OK_AND_ASSIGN(auto sig, MakeOverloadSignature("foo", args, false)); + EXPECT_EQ(sig, + "foo(int,string,list,cel.expr.conformance.proto3.TestAllTypes," + "Foo<~T>)"); +} + +} // namespace +} // namespace cel diff --git a/common/source.cc b/common/source.cc new file mode 100644 index 000000000..5fa4cca0e --- /dev/null +++ b/common/source.cc @@ -0,0 +1,600 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/source.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/base/optimization.h" +#include "absl/container/inlined_vector.h" +#include "absl/functional/overload.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_replace.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "absl/types/variant.h" +#include "internal/unicode.h" +#include "internal/utf8.h" + +namespace cel { + +SourcePosition SourceContentView::size() const { + return static_cast(absl::visit( + absl::Overload( + [](absl::Span view) { return view.size(); }, + [](absl::Span view) { return view.size(); }, + [](absl::Span view) { return view.size(); }, + [](absl::Span view) { return view.size(); }), + view_)); +} + +bool SourceContentView::empty() const { + return absl::visit( + absl::Overload( + [](absl::Span view) { return view.empty(); }, + [](absl::Span view) { return view.empty(); }, + [](absl::Span view) { return view.empty(); }, + [](absl::Span view) { return view.empty(); }), + view_); +} + +char32_t SourceContentView::at(SourcePosition position) const { + ABSL_DCHECK_GE(position, 0); + ABSL_DCHECK_LT(position, size()); + return absl::visit( + absl::Overload( + [position = + static_cast(position)](absl::Span view) { + return static_cast(static_cast(view[position])); + }, + [position = + static_cast(position)](absl::Span view) { + return static_cast(view[position]); + }, + [position = + static_cast(position)](absl::Span view) { + return static_cast(view[position]); + }, + [position = + static_cast(position)](absl::Span view) { + return static_cast(view[position]); + }), + view_); +} + +std::string SourceContentView::ToString(SourcePosition begin, + SourcePosition end) const { + ABSL_DCHECK_GE(begin, 0); + ABSL_DCHECK_LE(end, size()); + ABSL_DCHECK_LE(begin, end); + return absl::visit( + absl::Overload( + [begin = static_cast(begin), + end = static_cast(end)](absl::Span view) { + view = view.subspan(begin, end - begin); + return std::string(view.data(), view.size()); + }, + [begin = static_cast(begin), + end = static_cast(end)](absl::Span view) { + view = view.subspan(begin, end - begin); + std::string result; + result.reserve(view.size() * 2); + for (const auto& code_point : view) { + internal::Utf8Encode(result, code_point); + } + result.shrink_to_fit(); + return result; + }, + [begin = static_cast(begin), + end = static_cast(end)](absl::Span view) { + view = view.subspan(begin, end - begin); + std::string result; + result.reserve(view.size() * 3); + for (const auto& code_point : view) { + internal::Utf8Encode(result, code_point); + } + result.shrink_to_fit(); + return result; + }, + [begin = static_cast(begin), + end = static_cast(end)](absl::Span view) { + view = view.subspan(begin, end - begin); + std::string result; + result.reserve(view.size() * 4); + for (const auto& code_point : view) { + internal::Utf8Encode(result, code_point); + } + result.shrink_to_fit(); + return result; + }), + view_); +} + +void SourceContentView::AppendToString(std::string& dest) const { + absl::visit(absl::Overload( + [&dest](absl::Span view) { + dest.append(view.data(), view.size()); + }, + [&dest](absl::Span view) { + for (const auto& code_point : view) { + internal::Utf8Encode(dest, code_point); + } + }, + [&dest](absl::Span view) { + for (const auto& code_point : view) { + internal::Utf8Encode(dest, code_point); + } + }, + [&dest](absl::Span view) { + for (const auto& code_point : view) { + internal::Utf8Encode(dest, code_point); + } + }), + view_); +} + +namespace common_internal { + +class SourceImpl : public Source { + public: + SourceImpl(std::string description, + absl::InlinedVector line_offsets) + : description_(std::move(description)), + line_offsets_(std::move(line_offsets)) {} + + absl::string_view description() const final { return description_; } + + absl::Span line_offsets() const final { + return absl::MakeConstSpan(line_offsets_); + } + + private: + const std::string description_; + const absl::InlinedVector line_offsets_; +}; + +namespace { + +class AsciiSource final : public SourceImpl { + public: + AsciiSource(std::string description, + absl::InlinedVector line_offsets, + std::vector text) + : SourceImpl(std::move(description), std::move(line_offsets)), + text_(std::move(text)) {} + + ContentView content() const override { + return MakeContentView(absl::MakeConstSpan(text_)); + } + + private: + const std::vector text_; +}; + +class Latin1Source final : public SourceImpl { + public: + Latin1Source(std::string description, + absl::InlinedVector line_offsets, + std::vector text) + : SourceImpl(std::move(description), std::move(line_offsets)), + text_(std::move(text)) {} + + ContentView content() const override { + return MakeContentView(absl::MakeConstSpan(text_)); + } + + private: + const std::vector text_; +}; + +class BasicPlaneSource final : public SourceImpl { + public: + BasicPlaneSource(std::string description, + absl::InlinedVector line_offsets, + std::vector text) + : SourceImpl(std::move(description), std::move(line_offsets)), + text_(std::move(text)) {} + + ContentView content() const override { + return MakeContentView(absl::MakeConstSpan(text_)); + } + + private: + const std::vector text_; +}; + +class SupplementalPlaneSource final : public SourceImpl { + public: + SupplementalPlaneSource(std::string description, + absl::InlinedVector line_offsets, + std::vector text) + : SourceImpl(std::move(description), std::move(line_offsets)), + text_(std::move(text)) {} + + ContentView content() const override { + return MakeContentView(absl::MakeConstSpan(text_)); + } + + private: + const std::vector text_; +}; + +template +struct SourceTextTraits; + +template <> +struct SourceTextTraits { + using iterator_type = absl::string_view; + + static iterator_type Begin(absl::string_view text) { return text; } + + static void Advance(iterator_type& it, size_t n) { it.remove_prefix(n); } + + static void AppendTo(std::vector& out, absl::string_view text, + size_t n) { + const auto* in = reinterpret_cast(text.data()); + out.insert(out.end(), in, in + n); + } + + static std::vector ToVector(absl::string_view in) { + std::vector out; + out.reserve(in.size()); + out.insert(out.end(), in.begin(), in.end()); + return out; + } +}; + +template <> +struct SourceTextTraits { + using iterator_type = absl::Cord::CharIterator; + + static iterator_type Begin(const absl::Cord& text) { + return text.char_begin(); + } + + static void Advance(iterator_type& it, size_t n) { + absl::Cord::Advance(&it, n); + } + + static void AppendTo(std::vector& out, const absl::Cord& text, + size_t n) { + auto it = text.char_begin(); + while (n > 0) { + auto str = absl::Cord::ChunkRemaining(it); + size_t to_append = std::min(n, str.size()); + const auto* in = reinterpret_cast(str.data()); + out.insert(out.end(), in, in + to_append); + n -= to_append; + absl::Cord::Advance(&it, to_append); + } + } + + static std::vector ToVector(const absl::Cord& in) { + std::vector out; + out.reserve(in.size()); + for (const auto& chunk : in.Chunks()) { + out.insert(out.end(), chunk.begin(), chunk.end()); + } + return out; + } +}; + +template +absl::StatusOr NewSourceImpl(std::string description, const T& text, + const size_t text_size) { + if (ABSL_PREDICT_FALSE( + text_size > + static_cast(std::numeric_limits::max()))) { + return absl::InvalidArgumentError("expression larger than 2GiB limit"); + } + using Traits = SourceTextTraits; + size_t index = 0; + typename Traits::iterator_type it = Traits::Begin(text); + SourcePosition offset = 0; + char32_t code_point; + size_t code_units; + std::vector data8; + std::vector data16; + std::vector data32; + absl::InlinedVector line_offsets; + while (index < text_size) { + std::tie(code_point, code_units) = cel::internal::Utf8Decode(it); + if (ABSL_PREDICT_FALSE(code_point == + cel::internal::kUnicodeReplacementCharacter && + code_units == 1)) { + // Thats an invalid UTF-8 encoding. + return absl::InvalidArgumentError("cannot parse malformed UTF-8 input"); + } + if (code_point == '\n') { + line_offsets.push_back(offset + 1); + } + if (code_point <= 0x7f) { + Traits::Advance(it, code_units); + index += code_units; + ++offset; + continue; + } + if (code_point <= 0xff) { + data8.reserve(text_size); + Traits::AppendTo(data8, text, index); + data8.push_back(static_cast(code_point)); + Traits::Advance(it, code_units); + index += code_units; + ++offset; + goto latin1; + } + if (code_point <= 0xffff) { + data16.reserve(text_size); + for (size_t offset = 0; offset < index; offset++) { + data16.push_back(static_cast(text[offset])); + } + data16.push_back(static_cast(code_point)); + Traits::Advance(it, code_units); + index += code_units; + ++offset; + goto basic; + } + data32.reserve(text_size); + for (size_t offset = 0; offset < index; offset++) { + data32.push_back(static_cast(text[offset])); + } + data32.push_back(code_point); + Traits::Advance(it, code_units); + index += code_units; + ++offset; + goto supplemental; + } + line_offsets.push_back(offset + 1); + return std::make_unique( + std::move(description), std::move(line_offsets), Traits::ToVector(text)); +latin1: + while (index < text_size) { + std::tie(code_point, code_units) = internal::Utf8Decode(it); + if (ABSL_PREDICT_FALSE(code_point == + internal::kUnicodeReplacementCharacter && + code_units == 1)) { + // Thats an invalid UTF-8 encoding. + return absl::InvalidArgumentError("cannot parse malformed UTF-8 input"); + } + if (code_point == '\n') { + line_offsets.push_back(offset + 1); + } + if (code_point <= 0xff) { + data8.push_back(static_cast(code_point)); + Traits::Advance(it, code_units); + index += code_units; + ++offset; + continue; + } + if (code_point <= 0xffff) { + data16.reserve(text_size); + for (const auto& value : data8) { + data16.push_back(value); + } + std::vector().swap(data8); + data16.push_back(static_cast(code_point)); + Traits::Advance(it, code_units); + index += code_units; + ++offset; + goto basic; + } + data32.reserve(text_size); + for (const auto& value : data8) { + data32.push_back(value); + } + std::vector().swap(data8); + data32.push_back(code_point); + Traits::Advance(it, code_units); + index += code_units; + ++offset; + goto supplemental; + } + line_offsets.push_back(offset + 1); + return std::make_unique( + std::move(description), std::move(line_offsets), std::move(data8)); +basic: + while (index < text_size) { + std::tie(code_point, code_units) = internal::Utf8Decode(it); + if (ABSL_PREDICT_FALSE(code_point == + internal::kUnicodeReplacementCharacter && + code_units == 1)) { + // Thats an invalid UTF-8 encoding. + return absl::InvalidArgumentError("cannot parse malformed UTF-8 input"); + } + if (code_point == '\n') { + line_offsets.push_back(offset + 1); + } + if (code_point <= 0xffff) { + data16.push_back(static_cast(code_point)); + Traits::Advance(it, code_units); + index += code_units; + ++offset; + continue; + } + data32.reserve(text_size); + for (const auto& value : data16) { + data32.push_back(static_cast(value)); + } + std::vector().swap(data16); + data32.push_back(code_point); + Traits::Advance(it, code_units); + index += code_units; + ++offset; + goto supplemental; + } + line_offsets.push_back(offset + 1); + return std::make_unique( + std::move(description), std::move(line_offsets), std::move(data16)); +supplemental: + while (index < text_size) { + std::tie(code_point, code_units) = internal::Utf8Decode(it); + if (ABSL_PREDICT_FALSE(code_point == + internal::kUnicodeReplacementCharacter && + code_units == 1)) { + // Thats an invalid UTF-8 encoding. + return absl::InvalidArgumentError("cannot parse malformed UTF-8 input"); + } + if (code_point == '\n') { + line_offsets.push_back(offset + 1); + } + data32.push_back(code_point); + Traits::Advance(it, code_units); + index += code_units; + ++offset; + } + line_offsets.push_back(offset + 1); + return std::make_unique( + std::move(description), std::move(line_offsets), std::move(data32)); +} + +} // namespace + +} // namespace common_internal + +absl::optional Source::GetLocation( + SourcePosition position) const { + if (auto line_and_offset = FindLine(position); + ABSL_PREDICT_TRUE(line_and_offset.has_value())) { + return SourceLocation{line_and_offset->first, + position - line_and_offset->second}; + } + return std::nullopt; +} + +absl::optional Source::GetPosition( + const SourceLocation& location) const { + if (ABSL_PREDICT_FALSE(location.line < 1 || location.column < 0)) { + return std::nullopt; + } + if (auto position = FindLinePosition(location.line); + ABSL_PREDICT_TRUE(position.has_value())) { + return *position + location.column; + } + return std::nullopt; +} + +absl::optional Source::Snippet(int32_t line) const { + auto content = this->content(); + auto start = FindLinePosition(line); + if (ABSL_PREDICT_FALSE(!start.has_value() || content.empty())) { + return std::nullopt; + } + auto end = FindLinePosition(line + 1); + if (end.has_value()) { + return content.ToString(*start, *end - 1); + } + return content.ToString(*start); +} + +std::string Source::DisplayErrorLocation(SourceLocation location) const { + constexpr char32_t kDot = '.'; + constexpr char32_t kHat = '^'; + + constexpr char32_t kWideDot = 0xff0e; + constexpr char32_t kWideHat = 0xff3e; + absl::optional snippet = Snippet(location.line); + if (!snippet || snippet->empty()) { + return ""; + } + + *snippet = absl::StrReplaceAll(*snippet, {{"\t", " "}}); + absl::string_view snippet_view(*snippet); + std::string result; + absl::StrAppend(&result, "\n | ", *snippet); + absl::StrAppend(&result, "\n | "); + + std::string index_line; + for (int32_t i = 0; i < location.column && !snippet_view.empty(); ++i) { + size_t count; + std::tie(std::ignore, count) = internal::Utf8Decode(snippet_view); + snippet_view.remove_prefix(count); + if (count > 1) { + internal::Utf8Encode(index_line, kWideDot); + } else { + internal::Utf8Encode(index_line, kDot); + } + } + size_t count = 0; + if (!snippet_view.empty()) { + std::tie(std::ignore, count) = internal::Utf8Decode(snippet_view); + } + if (count > 1) { + internal::Utf8Encode(index_line, kWideHat); + } else { + internal::Utf8Encode(index_line, kHat); + } + absl::StrAppend(&result, index_line); + return result; +} + +absl::optional Source::FindLinePosition(int32_t line) const { + if (ABSL_PREDICT_FALSE(line < 1)) { + return std::nullopt; + } + if (line == 1) { + return SourcePosition{0}; + } + const auto line_offsets = this->line_offsets(); + if (ABSL_PREDICT_TRUE(line <= static_cast(line_offsets.size()))) { + return line_offsets[static_cast(line - 2)]; + } + return std::nullopt; +} + +absl::optional> Source::FindLine( + SourcePosition position) const { + if (ABSL_PREDICT_FALSE(position < 0)) { + return std::nullopt; + } + int32_t line = 1; + const auto line_offsets = this->line_offsets(); + for (const auto& line_offset : line_offsets) { + if (line_offset > position) { + break; + } + ++line; + } + if (line == 1) { + return std::make_pair(line, SourcePosition{0}); + } + return std::make_pair(line, line_offsets[static_cast(line) - 2]); +} + +absl::StatusOr NewSource(absl::string_view content, + std::string description) { + return common_internal::NewSourceImpl(std::move(description), content, + content.size()); +} + +absl::StatusOr NewSource(const absl::Cord& content, + std::string description) { + return common_internal::NewSourceImpl(std::move(description), content, + content.size()); +} + +} // namespace cel diff --git a/common/source.h b/common/source.h new file mode 100644 index 000000000..6453363a8 --- /dev/null +++ b/common/source.h @@ -0,0 +1,200 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_SOURCE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_SOURCE_H_ + +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "absl/types/variant.h" + +namespace cel { + +namespace common_internal { +class SourceImpl; +} // namespace common_internal + +class Source; + +// SourcePosition represents an offset in source text. +using SourcePosition = int32_t; + +// SourceRange represents a range of positions, where `begin` is inclusive and +// `end` is exclusive. +struct SourceRange final { + SourcePosition begin = -1; + SourcePosition end = -1; +}; + +inline bool operator==(const SourceRange& lhs, const SourceRange& rhs) { + return lhs.begin == rhs.begin && lhs.end == rhs.end; +} + +inline bool operator!=(const SourceRange& lhs, const SourceRange& rhs) { + return !operator==(lhs, rhs); +} + +// `SourceLocation` is a representation of a line and column in source text. +struct SourceLocation final { + int32_t line = -1; // 1-based line number. + int32_t column = -1; // 0-based column number. +}; + +inline bool operator==(const SourceLocation& lhs, const SourceLocation& rhs) { + return lhs.line == rhs.line && lhs.column == rhs.column; +} + +inline bool operator!=(const SourceLocation& lhs, const SourceLocation& rhs) { + return !operator==(lhs, rhs); +} + +// `SourceContentView` is a view of the content owned by `Source`, which is a +// sequence of Unicode code points. +class SourceContentView final { + public: + SourceContentView(const SourceContentView&) = default; + SourceContentView(SourceContentView&&) = default; + SourceContentView& operator=(const SourceContentView&) = default; + SourceContentView& operator=(SourceContentView&&) = default; + + SourcePosition size() const; + + bool empty() const; + + char32_t at(SourcePosition position) const; + + std::string ToString(SourcePosition begin, SourcePosition end) const; + std::string ToString(SourcePosition begin) const { + return ToString(begin, size()); + } + std::string ToString() const { return ToString(0); } + + void AppendToString(std::string& dest) const; + + private: + friend class Source; + + constexpr SourceContentView() = default; + + constexpr explicit SourceContentView(absl::Span view) + : view_(view) {} + + constexpr explicit SourceContentView(absl::Span view) + : view_(view) {} + + constexpr explicit SourceContentView(absl::Span view) + : view_(view) {} + + constexpr explicit SourceContentView(absl::Span view) + : view_(view) {} + + absl::variant, absl::Span, + absl::Span, absl::Span> + view_; +}; + +// `Source` represents the source expression. +class Source { + public: + using ContentView = SourceContentView; + + Source(const Source&) = delete; + Source(Source&&) = delete; + + virtual ~Source() = default; + + Source& operator=(const Source&) = delete; + Source& operator=(Source&&) = delete; + + virtual absl::string_view description() const + ABSL_ATTRIBUTE_LIFETIME_BOUND = 0; + + // Maps a `SourcePosition` to a `SourceLocation`. Returns an empty + // `absl::optional` when `SourcePosition` is invalid or the information + // required to perform the mapping is not present. + absl::optional GetLocation(SourcePosition position) const; + + // Maps a `SourceLocation` to a `SourcePosition`. Returns an empty + // `absl::optional` when `SourceLocation` is invalid or the information + // required to perform the mapping is not present. + absl::optional GetPosition( + const SourceLocation& location) const; + + absl::optional Snippet(int32_t line) const; + + // Formats an annotated snippet highlighting an error at location, e.g. + // + // "\n | $SOURCE_SNIPPET" + + // "\n | .......^" + // + // Returns an empty string if location is not a valid location in this source. + std::string DisplayErrorLocation(SourceLocation location) const; + + // Returns a view of the underlying expression text, if present. + virtual ContentView content() const ABSL_ATTRIBUTE_LIFETIME_BOUND = 0; + + // Returns a `absl::Span` of `SourcePosition` which represent the positions + // where new lines occur. + virtual absl::Span line_offsets() const + ABSL_ATTRIBUTE_LIFETIME_BOUND = 0; + + protected: + static constexpr ContentView EmptyContentView() { return ContentView(); } + static constexpr ContentView MakeContentView(absl::Span view) { + return ContentView(view); + } + static constexpr ContentView MakeContentView(absl::Span view) { + return ContentView(view); + } + static constexpr ContentView MakeContentView( + absl::Span view) { + return ContentView(view); + } + static constexpr ContentView MakeContentView( + absl::Span view) { + return ContentView(view); + } + + private: + friend class common_internal::SourceImpl; + + Source() = default; + + absl::optional FindLinePosition(int32_t line) const; + + absl::optional> FindLine( + SourcePosition position) const; +}; + +using SourcePtr = std::unique_ptr; + +absl::StatusOr NewSource( + absl::string_view content, std::string description = ""); + +absl::StatusOr NewSource( + const absl::Cord& content, std::string description = ""); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_SOURCE_H_ diff --git a/common/source_test.cc b/common/source_test.cc new file mode 100644 index 000000000..30a2ce9b0 --- /dev/null +++ b/common/source_test.cc @@ -0,0 +1,227 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/source.h" + +#include "absl/strings/cord.h" +#include "absl/types/optional.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +using ::testing::ElementsAre; +using ::testing::Eq; +using ::testing::Ne; +using ::testing::Optional; + +TEST(SourceRange, Default) { + SourceRange range; + EXPECT_EQ(range.begin, -1); + EXPECT_EQ(range.end, -1); +} + +TEST(SourceRange, Equality) { + EXPECT_THAT((SourceRange{}), (Eq(SourceRange{}))); + EXPECT_THAT((SourceRange{0, 1}), (Ne(SourceRange{0, 0}))); +} + +TEST(SourceLocation, Default) { + SourceLocation location; + EXPECT_EQ(location.line, -1); + EXPECT_EQ(location.column, -1); +} + +TEST(SourceLocation, Equality) { + EXPECT_THAT((SourceLocation{}), (Eq(SourceLocation{}))); + EXPECT_THAT((SourceLocation{1, 1}), (Ne(SourceLocation{1, 0}))); +} + +TEST(StringSource, Description) { + ASSERT_OK_AND_ASSIGN( + auto source, + NewSource("c.d &&\n\t b.c.arg(10) &&\n\t test(10)", "offset-test")); + + EXPECT_THAT(source->description(), Eq("offset-test")); +} + +TEST(StringSource, Content) { + ASSERT_OK_AND_ASSIGN( + auto source, + NewSource("c.d &&\n\t b.c.arg(10) &&\n\t test(10)", "offset-test")); + + EXPECT_THAT(source->content().ToString(), + Eq("c.d &&\n\t b.c.arg(10) &&\n\t test(10)")); +} + +TEST(StringSource, PositionAndLocation) { + ASSERT_OK_AND_ASSIGN( + auto source, + NewSource("c.d &&\n\t b.c.arg(10) &&\n\t test(10)", "offset-test")); + + EXPECT_THAT(source->line_offsets(), ElementsAre(7, 24, 35)); + + auto start = source->GetPosition(SourceLocation{int32_t{1}, int32_t{2}}); + auto end = source->GetPosition(SourceLocation{int32_t{3}, int32_t{2}}); + ASSERT_TRUE(start.has_value()); + ASSERT_TRUE(end.has_value()); + + EXPECT_THAT(source->GetLocation(*start), + Optional(Eq(SourceLocation{int32_t{1}, int32_t{2}}))); + EXPECT_THAT(source->GetLocation(*end), + Optional(Eq(SourceLocation{int32_t{3}, int32_t{2}}))); + EXPECT_THAT(source->GetLocation(-1), Eq(std::nullopt)); + + EXPECT_THAT(source->content().ToString(*start, *end), + Eq("d &&\n\t b.c.arg(10) &&\n\t ")); + + EXPECT_THAT(source->GetPosition(SourceLocation{int32_t{0}, int32_t{0}}), + Eq(std::nullopt)); + EXPECT_THAT(source->GetPosition(SourceLocation{int32_t{1}, int32_t{-1}}), + Eq(std::nullopt)); + EXPECT_THAT(source->GetPosition(SourceLocation{int32_t{4}, int32_t{0}}), + Eq(std::nullopt)); +} + +TEST(StringSource, SnippetSingle) { + ASSERT_OK_AND_ASSIGN(auto source, NewSource("hello, world", "one-line-test")); + + EXPECT_THAT(source->Snippet(1), Optional(Eq("hello, world"))); + EXPECT_THAT(source->Snippet(2), Eq(std::nullopt)); +} + +TEST(StringSource, SnippetMulti) { + ASSERT_OK_AND_ASSIGN(auto source, + NewSource("hello\nworld\nmy\nbub\n", "four-line-test")); + + EXPECT_THAT(source->Snippet(0), Eq(std::nullopt)); + EXPECT_THAT(source->Snippet(1), Optional(Eq("hello"))); + EXPECT_THAT(source->Snippet(2), Optional(Eq("world"))); + EXPECT_THAT(source->Snippet(3), Optional(Eq("my"))); + EXPECT_THAT(source->Snippet(4), Optional(Eq("bub"))); + EXPECT_THAT(source->Snippet(5), Optional(Eq(""))); + EXPECT_THAT(source->Snippet(6), Eq(std::nullopt)); +} + +TEST(CordSource, Description) { + ASSERT_OK_AND_ASSIGN( + auto source, + NewSource(absl::Cord("c.d &&\n\t b.c.arg(10) &&\n\t test(10)"), + "offset-test")); + + EXPECT_THAT(source->description(), Eq("offset-test")); +} + +TEST(CordSource, Content) { + ASSERT_OK_AND_ASSIGN( + auto source, + NewSource(absl::Cord("c.d &&\n\t b.c.arg(10) &&\n\t test(10)"), + "offset-test")); + + EXPECT_THAT(source->content().ToString(), + Eq("c.d &&\n\t b.c.arg(10) &&\n\t test(10)")); +} + +TEST(CordSource, PositionAndLocation) { + ASSERT_OK_AND_ASSIGN( + auto source, + NewSource(absl::Cord("c.d &&\n\t b.c.arg(10) &&\n\t test(10)"), + "offset-test")); + + EXPECT_THAT(source->line_offsets(), ElementsAre(7, 24, 35)); + + auto start = source->GetPosition(SourceLocation{int32_t{1}, int32_t{2}}); + auto end = source->GetPosition(SourceLocation{int32_t{3}, int32_t{2}}); + ASSERT_TRUE(start.has_value()); + ASSERT_TRUE(end.has_value()); + + EXPECT_THAT(source->GetLocation(*start), + Optional(Eq(SourceLocation{int32_t{1}, int32_t{2}}))); + EXPECT_THAT(source->GetLocation(*end), + Optional(Eq(SourceLocation{int32_t{3}, int32_t{2}}))); + EXPECT_THAT(source->GetLocation(-1), Eq(std::nullopt)); + + EXPECT_THAT(source->content().ToString(*start, *end), + Eq("d &&\n\t b.c.arg(10) &&\n\t ")); + + EXPECT_THAT(source->GetPosition(SourceLocation{int32_t{0}, int32_t{0}}), + Eq(std::nullopt)); + EXPECT_THAT(source->GetPosition(SourceLocation{int32_t{1}, int32_t{-1}}), + Eq(std::nullopt)); + EXPECT_THAT(source->GetPosition(SourceLocation{int32_t{4}, int32_t{0}}), + Eq(std::nullopt)); +} + +TEST(CordSource, SnippetSingle) { + ASSERT_OK_AND_ASSIGN(auto source, + NewSource(absl::Cord("hello, world"), "one-line-test")); + + EXPECT_THAT(source->Snippet(1), Optional(Eq("hello, world"))); + EXPECT_THAT(source->Snippet(2), Eq(std::nullopt)); +} + +TEST(CordSource, SnippetMulti) { + ASSERT_OK_AND_ASSIGN( + auto source, + NewSource(absl::Cord("hello\nworld\nmy\nbub\n"), "four-line-test")); + + EXPECT_THAT(source->Snippet(0), Eq(std::nullopt)); + EXPECT_THAT(source->Snippet(1), Optional(Eq("hello"))); + EXPECT_THAT(source->Snippet(2), Optional(Eq("world"))); + EXPECT_THAT(source->Snippet(3), Optional(Eq("my"))); + EXPECT_THAT(source->Snippet(4), Optional(Eq("bub"))); + EXPECT_THAT(source->Snippet(5), Optional(Eq(""))); + EXPECT_THAT(source->Snippet(6), Eq(std::nullopt)); +} + +TEST(Source, DisplayErrorLocationBasic) { + ASSERT_OK_AND_ASSIGN(auto source, NewSource("'Hello' +\n 'world'")); + + SourceLocation location{/*line=*/2, /*column=*/3}; + + EXPECT_EQ(source->DisplayErrorLocation(location), + "\n | 'world'" + "\n | ...^"); +} + +TEST(Source, DisplayErrorLocationOutOfRange) { + ASSERT_OK_AND_ASSIGN(auto source, NewSource("'Hello world!'")); + + SourceLocation location{/*line=*/3, /*column=*/3}; + + EXPECT_EQ(source->DisplayErrorLocation(location), ""); +} + +TEST(Source, DisplayErrorLocationTabsShortened) { + ASSERT_OK_AND_ASSIGN(auto source, NewSource("'Hello' +\n\t\t'world!'")); + + SourceLocation location{/*line=*/2, /*column=*/4}; + + EXPECT_EQ(source->DisplayErrorLocation(location), + "\n | 'world!'" + "\n | ....^"); +} + +TEST(Source, DisplayErrorLocationFullWidth) { + ASSERT_OK_AND_ASSIGN(auto source, NewSource("'Hello'")); + + SourceLocation location{/*line=*/1, /*column=*/2}; + + EXPECT_EQ(source->DisplayErrorLocation(location), + "\n | 'Hello'" + "\n | ..^"); +} + +} // namespace +} // namespace cel diff --git a/common/standard_definitions.h b/common/standard_definitions.h new file mode 100644 index 000000000..eea185f6b --- /dev/null +++ b/common/standard_definitions.h @@ -0,0 +1,349 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Constants used for standard definitions for CEL. +#ifndef THIRD_PARTY_CEL_CPP_COMMON_STANDARD_DEFINITIONS_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_STANDARD_DEFINITIONS_H_ + +#include "absl/strings/string_view.h" + +namespace cel { + +// Standard function names as represented in an AST. +// TODO(uncreated-issue/71): use a namespace instead of a class. +struct StandardFunctions { + // Comparison + static constexpr absl::string_view kEqual = "_==_"; + static constexpr absl::string_view kInequal = "_!=_"; + static constexpr absl::string_view kLess = "_<_"; + static constexpr absl::string_view kLessOrEqual = "_<=_"; + static constexpr absl::string_view kGreater = "_>_"; + static constexpr absl::string_view kGreaterOrEqual = "_>=_"; + + // Logical + static constexpr absl::string_view kAnd = "_&&_"; + static constexpr absl::string_view kOr = "_||_"; + static constexpr absl::string_view kNot = "!_"; + + // Strictness + static constexpr absl::string_view kNotStrictlyFalse = "@not_strictly_false"; + // Deprecated '__not_strictly_false__' function. Preserved for backwards + // compatibility with stored expressions. + static constexpr absl::string_view kNotStrictlyFalseDeprecated = + "__not_strictly_false__"; + + // Arithmetical + static constexpr absl::string_view kAdd = "_+_"; + static constexpr absl::string_view kSubtract = "_-_"; + static constexpr absl::string_view kNeg = "-_"; + static constexpr absl::string_view kMultiply = "_*_"; + static constexpr absl::string_view kDivide = "_/_"; + static constexpr absl::string_view kModulo = "_%_"; + + // String operations + static constexpr absl::string_view kRegexMatch = "matches"; + static constexpr absl::string_view kStringContains = "contains"; + static constexpr absl::string_view kStringEndsWith = "endsWith"; + static constexpr absl::string_view kStringStartsWith = "startsWith"; + + // Container operations + static constexpr absl::string_view kIn = "@in"; + // Deprecated '_in_' operator. Preserved for backwards compatibility with + // stored expressions. + static constexpr absl::string_view kInDeprecated = "_in_"; + // Deprecated 'in()' function. Preserved for backwards compatibility with + // stored expressions. + static constexpr absl::string_view kInFunction = "in"; + static constexpr absl::string_view kIndex = "_[_]"; + static constexpr absl::string_view kSize = "size"; + + static constexpr absl::string_view kTernary = "_?_:_"; + + // Timestamp and Duration + static constexpr absl::string_view kDuration = "duration"; + static constexpr absl::string_view kTimestamp = "timestamp"; + static constexpr absl::string_view kFullYear = "getFullYear"; + static constexpr absl::string_view kMonth = "getMonth"; + static constexpr absl::string_view kDayOfYear = "getDayOfYear"; + static constexpr absl::string_view kDayOfMonth = "getDayOfMonth"; + static constexpr absl::string_view kDate = "getDate"; + static constexpr absl::string_view kDayOfWeek = "getDayOfWeek"; + static constexpr absl::string_view kHours = "getHours"; + static constexpr absl::string_view kMinutes = "getMinutes"; + static constexpr absl::string_view kSeconds = "getSeconds"; + static constexpr absl::string_view kMilliseconds = "getMilliseconds"; + + // Type conversions + static constexpr absl::string_view kBool = "bool"; + static constexpr absl::string_view kBytes = "bytes"; + static constexpr absl::string_view kDouble = "double"; + static constexpr absl::string_view kDyn = "dyn"; + static constexpr absl::string_view kInt = "int"; + static constexpr absl::string_view kString = "string"; + static constexpr absl::string_view kType = "type"; + static constexpr absl::string_view kUint = "uint"; + + // Runtime-only functions. + // The convention for runtime-only functions where only the runtime needs to + // differentiate behavior is to prefix the function with `#`. + // Note, this is a different convention from CEL internal functions where the + // whole stack needs to be aware of the function id. + static constexpr absl::string_view kRuntimeListAppend = "#list_append"; +}; + +// Standard overload IDs used by type checkers. +// TODO(uncreated-issue/71): use a namespace instead of a class. +struct StandardOverloadIds { + // Add operator _+_ + static constexpr absl::string_view kAddInt = "add_int64"; + static constexpr absl::string_view kAddUint = "add_uint64"; + static constexpr absl::string_view kAddDouble = "add_double"; + static constexpr absl::string_view kAddDurationDuration = + "add_duration_duration"; + static constexpr absl::string_view kAddDurationTimestamp = + "add_duration_timestamp"; + static constexpr absl::string_view kAddTimestampDuration = + "add_timestamp_duration"; + static constexpr absl::string_view kAddString = "add_string"; + static constexpr absl::string_view kAddBytes = "add_bytes"; + static constexpr absl::string_view kAddList = "add_list"; + // Subtract operator _-_ + static constexpr absl::string_view kSubtractInt = "subtract_int64"; + static constexpr absl::string_view kSubtractUint = "subtract_uint64"; + static constexpr absl::string_view kSubtractDouble = "subtract_double"; + static constexpr absl::string_view kSubtractDurationDuration = + "subtract_duration_duration"; + static constexpr absl::string_view kSubtractTimestampDuration = + "subtract_timestamp_duration"; + static constexpr absl::string_view kSubtractTimestampTimestamp = + "subtract_timestamp_timestamp"; + // Multiply operator _*_ + static constexpr absl::string_view kMultiplyInt = "multiply_int64"; + static constexpr absl::string_view kMultiplyUint = "multiply_uint64"; + static constexpr absl::string_view kMultiplyDouble = "multiply_double"; + // Division operator _/_ + static constexpr absl::string_view kDivideInt = "divide_int64"; + static constexpr absl::string_view kDivideUint = "divide_uint64"; + static constexpr absl::string_view kDivideDouble = "divide_double"; + // Modulo operator _%_ + static constexpr absl::string_view kModuloInt = "modulo_int64"; + static constexpr absl::string_view kModuloUint = "modulo_uint64"; + // Negation operator -_ + static constexpr absl::string_view kNegateInt = "negate_int64"; + static constexpr absl::string_view kNegateDouble = "negate_double"; + // Logical operators + static constexpr absl::string_view kNot = "logical_not"; + static constexpr absl::string_view kAnd = "logical_and"; + static constexpr absl::string_view kOr = "logical_or"; + static constexpr absl::string_view kConditional = "conditional"; + // Comprehension logic + static constexpr absl::string_view kNotStrictlyFalse = "not_strictly_false"; + static constexpr absl::string_view kNotStrictlyFalseDeprecated = + "__not_strictly_false__"; + // Equality operators + static constexpr absl::string_view kEquals = "equals"; + static constexpr absl::string_view kNotEquals = "not_equals"; + // Relational operators + static constexpr absl::string_view kLessBool = "less_bool"; + static constexpr absl::string_view kLessString = "less_string"; + static constexpr absl::string_view kLessBytes = "less_bytes"; + static constexpr absl::string_view kLessDuration = "less_duration"; + static constexpr absl::string_view kLessTimestamp = "less_timestamp"; + static constexpr absl::string_view kLessInt = "less_int64"; + static constexpr absl::string_view kLessIntUint = "less_int64_uint64"; + static constexpr absl::string_view kLessIntDouble = "less_int64_double"; + static constexpr absl::string_view kLessDouble = "less_double"; + static constexpr absl::string_view kLessDoubleInt = "less_double_int64"; + static constexpr absl::string_view kLessDoubleUint = "less_double_uint64"; + static constexpr absl::string_view kLessUint = "less_uint64"; + static constexpr absl::string_view kLessUintInt = "less_uint64_int64"; + static constexpr absl::string_view kLessUintDouble = "less_uint64_double"; + static constexpr absl::string_view kGreaterBool = "greater_bool"; + static constexpr absl::string_view kGreaterString = "greater_string"; + static constexpr absl::string_view kGreaterBytes = "greater_bytes"; + static constexpr absl::string_view kGreaterDuration = "greater_duration"; + static constexpr absl::string_view kGreaterTimestamp = "greater_timestamp"; + static constexpr absl::string_view kGreaterInt = "greater_int64"; + static constexpr absl::string_view kGreaterIntUint = "greater_int64_uint64"; + static constexpr absl::string_view kGreaterIntDouble = "greater_int64_double"; + static constexpr absl::string_view kGreaterDouble = "greater_double"; + static constexpr absl::string_view kGreaterDoubleInt = "greater_double_int64"; + static constexpr absl::string_view kGreaterDoubleUint = + "greater_double_uint64"; + static constexpr absl::string_view kGreaterUint = "greater_uint64"; + static constexpr absl::string_view kGreaterUintInt = "greater_uint64_int64"; + static constexpr absl::string_view kGreaterUintDouble = + "greater_uint64_double"; + static constexpr absl::string_view kGreaterEqualsBool = "greater_equals_bool"; + static constexpr absl::string_view kGreaterEqualsString = + "greater_equals_string"; + static constexpr absl::string_view kGreaterEqualsBytes = + "greater_equals_bytes"; + static constexpr absl::string_view kGreaterEqualsDuration = + "greater_equals_duration"; + static constexpr absl::string_view kGreaterEqualsTimestamp = + "greater_equals_timestamp"; + static constexpr absl::string_view kGreaterEqualsInt = "greater_equals_int64"; + static constexpr absl::string_view kGreaterEqualsIntUint = + "greater_equals_int64_uint64"; + static constexpr absl::string_view kGreaterEqualsIntDouble = + "greater_equals_int64_double"; + static constexpr absl::string_view kGreaterEqualsDouble = + "greater_equals_double"; + static constexpr absl::string_view kGreaterEqualsDoubleInt = + "greater_equals_double_int64"; + static constexpr absl::string_view kGreaterEqualsDoubleUint = + "greater_equals_double_uint64"; + static constexpr absl::string_view kGreaterEqualsUint = + "greater_equals_uint64"; + static constexpr absl::string_view kGreaterEqualsUintInt = + "greater_equals_uint64_int64"; + static constexpr absl::string_view kGreaterEqualsUintDouble = + "greater_equals_uint_double"; + static constexpr absl::string_view kLessEqualsBool = "less_equals_bool"; + static constexpr absl::string_view kLessEqualsString = "less_equals_string"; + static constexpr absl::string_view kLessEqualsBytes = "less_equals_bytes"; + static constexpr absl::string_view kLessEqualsDuration = + "less_equals_duration"; + static constexpr absl::string_view kLessEqualsTimestamp = + "less_equals_timestamp"; + static constexpr absl::string_view kLessEqualsInt = "less_equals_int64"; + static constexpr absl::string_view kLessEqualsIntUint = + "less_equals_int64_uint64"; + static constexpr absl::string_view kLessEqualsIntDouble = + "less_equals_int64_double"; + static constexpr absl::string_view kLessEqualsDouble = "less_equals_double"; + static constexpr absl::string_view kLessEqualsDoubleInt = + "less_equals_double_int64"; + static constexpr absl::string_view kLessEqualsDoubleUint = + "less_equals_double_uint64"; + static constexpr absl::string_view kLessEqualsUint = "less_equals_uint64"; + static constexpr absl::string_view kLessEqualsUintInt = + "less_equals_uint64_int64"; + static constexpr absl::string_view kLessEqualsUintDouble = + "less_equals_uint64_double"; + // Container operators + static constexpr absl::string_view kIndexList = "index_list"; + static constexpr absl::string_view kIndexMap = "index_map"; + static constexpr absl::string_view kInList = "in_list"; + static constexpr absl::string_view kInMap = "in_map"; + static constexpr absl::string_view kSizeBytes = "size_bytes"; + static constexpr absl::string_view kSizeList = "size_list"; + static constexpr absl::string_view kSizeMap = "size_map"; + static constexpr absl::string_view kSizeString = "size_string"; + static constexpr absl::string_view kSizeBytesMember = "bytes_size"; + static constexpr absl::string_view kSizeListMember = "list_size"; + static constexpr absl::string_view kSizeMapMember = "map_size"; + static constexpr absl::string_view kSizeStringMember = "string_size"; + // String functions + static constexpr absl::string_view kContainsString = "contains_string"; + static constexpr absl::string_view kEndsWithString = "ends_with_string"; + static constexpr absl::string_view kStartsWithString = "starts_with_string"; + // String RE2 functions + static constexpr absl::string_view kMatches = "matches"; + static constexpr absl::string_view kMatchesMember = "matches_string"; + // Timestamp / duration accessors + static constexpr absl::string_view kTimestampToYear = "timestamp_to_year"; + static constexpr absl::string_view kTimestampToYearWithTz = + "timestamp_to_year_with_tz"; + static constexpr absl::string_view kTimestampToMonth = "timestamp_to_month"; + static constexpr absl::string_view kTimestampToMonthWithTz = + "timestamp_to_month_with_tz"; + static constexpr absl::string_view kTimestampToDayOfYear = + "timestamp_to_day_of_year"; + static constexpr absl::string_view kTimestampToDayOfYearWithTz = + "timestamp_to_day_of_year_with_tz"; + static constexpr absl::string_view kTimestampToDayOfMonth = + "timestamp_to_day_of_month"; + static constexpr absl::string_view kTimestampToDayOfMonthWithTz = + "timestamp_to_day_of_month_with_tz"; + static constexpr absl::string_view kTimestampToDayOfWeek = + "timestamp_to_day_of_week"; + static constexpr absl::string_view kTimestampToDayOfWeekWithTz = + "timestamp_to_day_of_week_with_tz"; + static constexpr absl::string_view kTimestampToDate = + "timestamp_to_day_of_month_1_based"; + static constexpr absl::string_view kTimestampToDateWithTz = + "timestamp_to_day_of_month_1_based_with_tz"; + static constexpr absl::string_view kTimestampToHours = "timestamp_to_hours"; + static constexpr absl::string_view kTimestampToHoursWithTz = + "timestamp_to_hours_with_tz"; + static constexpr absl::string_view kDurationToHours = "duration_to_hours"; + static constexpr absl::string_view kTimestampToMinutes = + "timestamp_to_minutes"; + static constexpr absl::string_view kTimestampToMinutesWithTz = + "timestamp_to_minutes_with_tz"; + static constexpr absl::string_view kDurationToMinutes = "duration_to_minutes"; + static constexpr absl::string_view kTimestampToSeconds = + "timestamp_to_seconds"; + static constexpr absl::string_view kTimestampToSecondsWithTz = + "timestamp_to_seconds_tz"; + static constexpr absl::string_view kDurationToSeconds = "duration_to_seconds"; + static constexpr absl::string_view kTimestampToMilliseconds = + "timestamp_to_milliseconds"; + static constexpr absl::string_view kTimestampToMillisecondsWithTz = + "timestamp_to_milliseconds_with_tz"; + static constexpr absl::string_view kDurationToMilliseconds = + "duration_to_milliseconds"; + // Type conversions + static constexpr absl::string_view kToDyn = "to_dyn"; + // to_uint + static constexpr absl::string_view kUintToUint = "uint64_to_uint64"; + static constexpr absl::string_view kDoubleToUint = "double_to_uint64"; + static constexpr absl::string_view kIntToUint = "int64_to_uint64"; + static constexpr absl::string_view kStringToUint = "string_to_uint64"; + // to_int + static constexpr absl::string_view kUintToInt = "uint64_to_int64"; + static constexpr absl::string_view kDoubleToInt = "double_to_int64"; + static constexpr absl::string_view kIntToInt = "int64_to_int64"; + static constexpr absl::string_view kStringToInt = "string_to_int64"; + static constexpr absl::string_view kTimestampToInt = "timestamp_to_int64"; + static constexpr absl::string_view kDurationToInt = "duration_to_int64"; + // to_double + static constexpr absl::string_view kDoubleToDouble = "double_to_double"; + static constexpr absl::string_view kUintToDouble = "uint64_to_double"; + static constexpr absl::string_view kIntToDouble = "int64_to_double"; + static constexpr absl::string_view kStringToDouble = "string_to_double"; + // to_bool + static constexpr absl::string_view kBoolToBool = "bool_to_bool"; + static constexpr absl::string_view kStringToBool = "string_to_bool"; + // to_bytes + static constexpr absl::string_view kBytesToBytes = "bytes_to_bytes"; + static constexpr absl::string_view kStringToBytes = "string_to_bytes"; + // to_string + static constexpr absl::string_view kStringToString = "string_to_string"; + static constexpr absl::string_view kBytesToString = "bytes_to_string"; + static constexpr absl::string_view kBoolToString = "bool_to_string"; + static constexpr absl::string_view kDoubleToString = "double_to_string"; + static constexpr absl::string_view kIntToString = "int64_to_string"; + static constexpr absl::string_view kUintToString = "uint64_to_string"; + static constexpr absl::string_view kDurationToString = "duration_to_string"; + static constexpr absl::string_view kTimestampToString = "timestamp_to_string"; + // to_timestamp + static constexpr absl::string_view kTimestampToTimestamp = + "timestamp_to_timestamp"; + static constexpr absl::string_view kIntToTimestamp = "int64_to_timestamp"; + static constexpr absl::string_view kStringToTimestamp = "string_to_timestamp"; + // to_duration + static constexpr absl::string_view kDurationToDuration = + "duration_to_duration"; + static constexpr absl::string_view kIntToDuration = "int64_to_duration"; + static constexpr absl::string_view kStringToDuration = "string_to_duration"; + // to_type + static constexpr absl::string_view kToType = "type"; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_STANDARD_DEFINITIONS_H_ diff --git a/common/type.cc b/common/type.cc index ad9fa0ddc..9ea85954c 100644 --- a/common/type.cc +++ b/common/type.cc @@ -1,120 +1,732 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #include "common/type.h" -#include "absl/memory/memory.h" +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "absl/types/variant.h" +#include "common/type_kind.h" +#include "common/types/types.h" +#include "google/protobuf/descriptor.h" -namespace google { -namespace api { -namespace expr { -namespace common { +namespace cel { + +using ::google::protobuf::Descriptor; +using ::google::protobuf::FieldDescriptor; + +Type Type::Message(const Descriptor* absl_nonnull descriptor) { + switch (descriptor->well_known_type()) { + case Descriptor::WELLKNOWNTYPE_BOOLVALUE: + return BoolWrapperType(); + case Descriptor::WELLKNOWNTYPE_INT32VALUE: + ABSL_FALLTHROUGH_INTENDED; + case Descriptor::WELLKNOWNTYPE_INT64VALUE: + return IntWrapperType(); + case Descriptor::WELLKNOWNTYPE_UINT32VALUE: + ABSL_FALLTHROUGH_INTENDED; + case Descriptor::WELLKNOWNTYPE_UINT64VALUE: + return UintWrapperType(); + case Descriptor::WELLKNOWNTYPE_FLOATVALUE: + ABSL_FALLTHROUGH_INTENDED; + case Descriptor::WELLKNOWNTYPE_DOUBLEVALUE: + return DoubleWrapperType(); + case Descriptor::WELLKNOWNTYPE_BYTESVALUE: + return BytesWrapperType(); + case Descriptor::WELLKNOWNTYPE_STRINGVALUE: + return StringWrapperType(); + case Descriptor::WELLKNOWNTYPE_ANY: + return AnyType(); + case Descriptor::WELLKNOWNTYPE_DURATION: + return DurationType(); + case Descriptor::WELLKNOWNTYPE_TIMESTAMP: + return TimestampType(); + case Descriptor::WELLKNOWNTYPE_VALUE: + return DynType(); + case Descriptor::WELLKNOWNTYPE_LISTVALUE: + return ListType(); + case Descriptor::WELLKNOWNTYPE_STRUCT: + return JsonMapType(); + default: + return MessageType(descriptor); + } +} + +Type Type::Enum(const google::protobuf::EnumDescriptor* absl_nonnull descriptor) { + if (descriptor->full_name() == "google.protobuf.NullValue") { + // Special case NullValue to prevent the emebedder providing a different + // descriptor for it and it leaking. + return IntType(); + } + return EnumType(descriptor); +} namespace { -constexpr const std::size_t kBasicTypeNamesSize = 10; -const auto* kBasicTypeNames = new std::array({ - "null_type", // kNull - "bool", // kBool - "int", // kInt - "uint", // kUInt - "double", // kDouble - "string", // kString, - "bytes", // kBytes - "type", // kType - "map", // kMap - "list", // kList -}); - -static_assert(kBasicTypeNamesSize == - static_cast(BasicTypeValue::DO_NOT_USE), - "unexpected size"); - -static const std::map* const kBasicTypeMap = - []() { - auto result = new std::map(); - for (std::size_t i = 0; i < kBasicTypeNames->size(); ++i) { - result->emplace(kBasicTypeNames->at(i), - BasicType(static_cast(i))); - } - return result; - }(); - -struct ToStringVisitor { - template - const std::string& operator()(const T& value) { - return value.ToString(); - } -}; - -struct FullNameVisitor { - template - absl::string_view operator()(const T& value) { - return value.full_name(); - } -}; +static constexpr std::array kTypeToKindArray = { + TypeKind::kDyn, TypeKind::kAny, TypeKind::kBool, + TypeKind::kBoolWrapper, TypeKind::kBytes, TypeKind::kBytesWrapper, + TypeKind::kDouble, TypeKind::kDoubleWrapper, TypeKind::kDuration, + TypeKind::kEnum, TypeKind::kError, TypeKind::kFunction, + TypeKind::kInt, TypeKind::kIntWrapper, TypeKind::kList, + TypeKind::kMap, TypeKind::kNull, TypeKind::kOpaque, + TypeKind::kString, TypeKind::kStringWrapper, TypeKind::kStruct, + TypeKind::kStruct, TypeKind::kTimestamp, TypeKind::kTypeParam, + TypeKind::kType, TypeKind::kUint, TypeKind::kUintWrapper, + TypeKind::kUnknown}; + +static_assert(kTypeToKindArray.size() == + std::variant_size(), + "Kind indexer must match variant declaration for cel::Type."); } // namespace -const std::string& BasicType::ToString() const { - return kBasicTypeNames->at(static_cast(value_)); +TypeKind Type::kind() const { return kTypeToKindArray[variant_.index()]; } + +absl::string_view Type::name() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return absl::visit( + [](const auto& alternative) -> absl::string_view { + return alternative.name(); + }, + variant_); +} + +std::string Type::DebugString() const { + return absl::visit( + [](const auto& alternative) -> std::string { + return alternative.DebugString(); + }, + variant_); } -std::unique_ptr ObjectType::Unpack( - const google::protobuf::Any& value) { - auto msg = absl::WrapUnique( - google::protobuf::MessageFactory::generated_factory()->GetPrototype(value_)->New()); - if (!value.UnpackTo(msg.get())) { - return nullptr; +TypeParameters Type::GetParameters() const { + return absl::visit( + [](const auto& alternative) -> TypeParameters { + return alternative.GetParameters(); + }, + variant_); +} + +bool operator==(const Type& lhs, const Type& rhs) { + if (lhs.IsStruct() && rhs.IsStruct()) { + return lhs.GetStruct() == rhs.GetStruct(); + } else if (lhs.IsStruct() || rhs.IsStruct()) { + return false; + } else { + return lhs.variant_ == rhs.variant_; + } +} + +common_internal::StructTypeVariant Type::ToStructTypeVariant() const { + if (const auto* other = absl::get_if(&variant_); + other != nullptr) { + return common_internal::StructTypeVariant(*other); + } + if (const auto* other = + absl::get_if(&variant_); + other != nullptr) { + return common_internal::StructTypeVariant(*other); } - return msg; + return common_internal::StructTypeVariant(); +} + +namespace { + +template +absl::optional GetOrNullopt(const common_internal::TypeVariant& variant) { + if (const auto* alt = absl::get_if(&variant); alt != nullptr) { + return *alt; + } + return std::nullopt; +} + +} // namespace + +absl::optional Type::AsAny() const { + return GetOrNullopt(variant_); +} + +absl::optional Type::AsBool() const { + return GetOrNullopt(variant_); +} + +absl::optional Type::AsBoolWrapper() const { + return GetOrNullopt(variant_); +} + +absl::optional Type::AsBytes() const { + return GetOrNullopt(variant_); +} + +absl::optional Type::AsBytesWrapper() const { + return GetOrNullopt(variant_); +} + +absl::optional Type::AsDouble() const { + return GetOrNullopt(variant_); +} + +absl::optional Type::AsDoubleWrapper() const { + return GetOrNullopt(variant_); +} + +absl::optional Type::AsDuration() const { + return GetOrNullopt(variant_); +} + +absl::optional Type::AsDyn() const { + return GetOrNullopt(variant_); +} + +absl::optional Type::AsEnum() const { + return GetOrNullopt(variant_); +} + +absl::optional Type::AsError() const { + return GetOrNullopt(variant_); +} + +absl::optional Type::AsFunction() const { + return GetOrNullopt(variant_); } -UnrecognizedType::UnrecognizedType(absl::string_view full_name) - : string_rep_(absl::StrCat("type(\"", full_name, "\")")), - hash_code_(internal::Hash(full_name)) { - assert(google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( - std::string(full_name)) == nullptr); +absl::optional Type::AsInt() const { + return GetOrNullopt(variant_); } -absl::string_view UnrecognizedType::full_name() const { - return absl::string_view(string_rep_).substr(6, string_rep_.size() - 8); +absl::optional Type::AsIntWrapper() const { + return GetOrNullopt(variant_); } -Type::Type(const std::string& full_name) - : data_(BasicType(BasicTypeValue::kNull)) { - auto itr = kBasicTypeMap->find(full_name); - if (itr != kBasicTypeMap->end()) { - data_ = itr->second; - return; +absl::optional Type::AsList() const { + return GetOrNullopt(variant_); +} + +absl::optional Type::AsMap() const { + return GetOrNullopt(variant_); +} + +absl::optional Type::AsMessage() const { + return GetOrNullopt(variant_); +} + +absl::optional Type::AsNull() const { + return GetOrNullopt(variant_); +} + +absl::optional Type::AsOpaque() const { + return GetOrNullopt(variant_); +} + +absl::optional Type::AsOptional() const { + if (auto maybe_opaque = AsOpaque(); maybe_opaque.has_value()) { + return maybe_opaque->AsOptional(); } + return std::nullopt; +} + +absl::optional Type::AsString() const { + return GetOrNullopt(variant_); +} - auto obj_desc = - google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( - full_name); - if (obj_desc != nullptr) { - data_ = ObjectType(obj_desc); - return; +absl::optional Type::AsStringWrapper() const { + return GetOrNullopt(variant_); +} + +absl::optional Type::AsStruct() const { + if (const auto* alt = + absl::get_if(&variant_); + alt != nullptr) { + return *alt; + } + if (const auto* alt = absl::get_if(&variant_); alt != nullptr) { + return *alt; } + return std::nullopt; +} + +absl::optional Type::AsTimestamp() const { + return GetOrNullopt(variant_); +} + +absl::optional Type::AsTypeParam() const { + return GetOrNullopt(variant_); +} + +absl::optional Type::AsType() const { + return GetOrNullopt(variant_); +} + +absl::optional Type::AsUint() const { + return GetOrNullopt(variant_); +} - auto enum_desc = - google::protobuf::DescriptorPool::generated_pool()->FindEnumTypeByName(full_name); - if (enum_desc != nullptr) { - data_ = EnumType(enum_desc); - return; +absl::optional Type::AsUintWrapper() const { + return GetOrNullopt(variant_); +} + +absl::optional Type::AsUnknown() const { + return GetOrNullopt(variant_); +} + +namespace { + +template +T GetOrDie(const common_internal::TypeVariant& variant) { + return absl::get(variant); +} + +} // namespace + +AnyType Type::GetAny() const { + ABSL_DCHECK(IsAny()) << DebugString(); + return GetOrDie(variant_); +} + +BoolType Type::GetBool() const { + ABSL_DCHECK(IsBool()) << DebugString(); + return GetOrDie(variant_); +} + +BoolWrapperType Type::GetBoolWrapper() const { + ABSL_DCHECK(IsBoolWrapper()) << DebugString(); + return GetOrDie(variant_); +} + +BytesType Type::GetBytes() const { + ABSL_DCHECK(IsBytes()) << DebugString(); + return GetOrDie(variant_); +} + +BytesWrapperType Type::GetBytesWrapper() const { + ABSL_DCHECK(IsBytesWrapper()) << DebugString(); + return GetOrDie(variant_); +} + +DoubleType Type::GetDouble() const { + ABSL_DCHECK(IsDouble()) << DebugString(); + return GetOrDie(variant_); +} + +DoubleWrapperType Type::GetDoubleWrapper() const { + ABSL_DCHECK(IsDoubleWrapper()) << DebugString(); + return GetOrDie(variant_); +} + +DurationType Type::GetDuration() const { + ABSL_DCHECK(IsDuration()) << DebugString(); + return GetOrDie(variant_); +} + +DynType Type::GetDyn() const { + ABSL_DCHECK(IsDyn()) << DebugString(); + return GetOrDie(variant_); +} + +EnumType Type::GetEnum() const { + ABSL_DCHECK(IsEnum()) << DebugString(); + return GetOrDie(variant_); +} + +ErrorType Type::GetError() const { + ABSL_DCHECK(IsError()) << DebugString(); + return GetOrDie(variant_); +} + +FunctionType Type::GetFunction() const { + ABSL_DCHECK(IsFunction()) << DebugString(); + return GetOrDie(variant_); +} + +IntType Type::GetInt() const { + ABSL_DCHECK(IsInt()) << DebugString(); + return GetOrDie(variant_); +} + +IntWrapperType Type::GetIntWrapper() const { + ABSL_DCHECK(IsIntWrapper()) << DebugString(); + return GetOrDie(variant_); +} + +ListType Type::GetList() const { + ABSL_DCHECK(IsList()) << DebugString(); + return GetOrDie(variant_); +} + +MapType Type::GetMap() const { + ABSL_DCHECK(IsMap()) << DebugString(); + return GetOrDie(variant_); +} + +MessageType Type::GetMessage() const { + ABSL_DCHECK(IsMessage()) << DebugString(); + return GetOrDie(variant_); +} + +NullType Type::GetNull() const { + ABSL_DCHECK(IsNull()) << DebugString(); + return GetOrDie(variant_); +} + +OpaqueType Type::GetOpaque() const { + ABSL_DCHECK(IsOpaque()) << DebugString(); + return GetOrDie(variant_); +} + +OptionalType Type::GetOptional() const { + ABSL_DCHECK(IsOptional()) << DebugString(); + return GetOrDie(variant_).GetOptional(); +} + +StringType Type::GetString() const { + ABSL_DCHECK(IsString()) << DebugString(); + return GetOrDie(variant_); +} + +StringWrapperType Type::GetStringWrapper() const { + ABSL_DCHECK(IsStringWrapper()) << DebugString(); + return GetOrDie(variant_); +} + +StructType Type::GetStruct() const { + ABSL_DCHECK(IsStruct()) << DebugString(); + if (const auto* alt = + absl::get_if(&variant_); + alt != nullptr) { + return *alt; } + if (const auto* alt = absl::get_if(&variant_); alt != nullptr) { + return *alt; + } + return StructType(); +} - auto value = UnrecognizedType(full_name); - data_ = value; +TimestampType Type::GetTimestamp() const { + ABSL_DCHECK(IsTimestamp()) << DebugString(); + return GetOrDie(variant_); } -absl::string_view Type::full_name() const { - return absl::visit(FullNameVisitor(), data_); +TypeParamType Type::GetTypeParam() const { + ABSL_DCHECK(IsTypeParam()) << DebugString(); + return GetOrDie(variant_); } -const std::string& Type::ToString() const { - return absl::visit(ToStringVisitor(), data_); +TypeType Type::GetType() const { + ABSL_DCHECK(IsType()) << DebugString(); + return GetOrDie(variant_); } -} // namespace common -} // namespace expr -} // namespace api -} // namespace google +UintType Type::GetUint() const { + ABSL_DCHECK(IsUint()) << DebugString(); + return GetOrDie(variant_); +} + +UintWrapperType Type::GetUintWrapper() const { + ABSL_DCHECK(IsUintWrapper()) << DebugString(); + return GetOrDie(variant_); +} + +UnknownType Type::GetUnknown() const { + ABSL_DCHECK(IsUnknown()) << DebugString(); + return GetOrDie(variant_); +} + +Type Type::Unwrap() const { + switch (kind()) { + case TypeKind::kBoolWrapper: + return BoolType(); + case TypeKind::kIntWrapper: + return IntType(); + case TypeKind::kUintWrapper: + return UintType(); + case TypeKind::kDoubleWrapper: + return DoubleType(); + case TypeKind::kBytesWrapper: + return BytesType(); + case TypeKind::kStringWrapper: + return StringType(); + default: + return *this; + } +} + +Type Type::Wrap() const { + switch (kind()) { + case TypeKind::kBool: + return BoolWrapperType(); + case TypeKind::kInt: + return IntWrapperType(); + case TypeKind::kUint: + return UintWrapperType(); + case TypeKind::kDouble: + return DoubleWrapperType(); + case TypeKind::kBytes: + return BytesWrapperType(); + case TypeKind::kString: + return StringWrapperType(); + default: + return *this; + } +} + +namespace common_internal { + +Type SingularMessageFieldType( + const google::protobuf::FieldDescriptor* absl_nonnull descriptor) { + ABSL_DCHECK(!descriptor->is_map()); + switch (descriptor->type()) { + case FieldDescriptor::TYPE_BOOL: + return BoolType(); + case FieldDescriptor::TYPE_SFIXED32: + ABSL_FALLTHROUGH_INTENDED; + case FieldDescriptor::TYPE_SINT32: + ABSL_FALLTHROUGH_INTENDED; + case FieldDescriptor::TYPE_INT32: + ABSL_FALLTHROUGH_INTENDED; + case FieldDescriptor::TYPE_SFIXED64: + ABSL_FALLTHROUGH_INTENDED; + case FieldDescriptor::TYPE_SINT64: + ABSL_FALLTHROUGH_INTENDED; + case FieldDescriptor::TYPE_INT64: + return IntType(); + case FieldDescriptor::TYPE_FIXED32: + ABSL_FALLTHROUGH_INTENDED; + case FieldDescriptor::TYPE_UINT32: + ABSL_FALLTHROUGH_INTENDED; + case FieldDescriptor::TYPE_FIXED64: + ABSL_FALLTHROUGH_INTENDED; + case FieldDescriptor::TYPE_UINT64: + return UintType(); + case FieldDescriptor::TYPE_FLOAT: + ABSL_FALLTHROUGH_INTENDED; + case FieldDescriptor::TYPE_DOUBLE: + return DoubleType(); + case FieldDescriptor::TYPE_BYTES: + return BytesType(); + case FieldDescriptor::TYPE_STRING: + return StringType(); + case FieldDescriptor::TYPE_GROUP: + ABSL_FALLTHROUGH_INTENDED; + case FieldDescriptor::TYPE_MESSAGE: + return Type::Message(descriptor->message_type()); + case FieldDescriptor::TYPE_ENUM: + return Type::Enum(descriptor->enum_type()); + default: + return Type(); + } +} + +std::string BasicStructTypeField::DebugString() const { + if (!name().empty() && number() >= 1) { + return absl::StrCat("[", number(), "]", name()); + } + if (!name().empty()) { + return std::string(name()); + } + if (number() >= 1) { + return absl::StrCat(number()); + } + return std::string(); +} + +} // namespace common_internal + +Type Type::Field(const google::protobuf::FieldDescriptor* absl_nonnull descriptor) { + if (descriptor->is_map()) { + return MapType(descriptor->message_type()); + } + if (descriptor->is_repeated()) { + return ListType(descriptor); + } + return common_internal::SingularMessageFieldType(descriptor); +} + +std::string StructTypeField::DebugString() const { + return absl::visit( + [](const auto& alternative) -> std::string { + return alternative.DebugString(); + }, + variant_); +} + +absl::string_view StructTypeField::name() const { + return absl::visit( + [](const auto& alternative) -> absl::string_view { + return alternative.name(); + }, + variant_); +} + +int32_t StructTypeField::number() const { + return absl::visit( + [](const auto& alternative) -> int32_t { return alternative.number(); }, + variant_); +} + +Type StructTypeField::GetType() const { + return absl::visit( + [](const auto& alternative) -> Type { return alternative.GetType(); }, + variant_); +} + +StructTypeField::operator bool() const { + return absl::visit( + [](const auto& alternative) -> bool { + return static_cast(alternative); + }, + variant_); +} + +absl::optional StructTypeField::AsMessage() const { + if (const auto* alternative = absl::get_if(&variant_); + alternative != nullptr) { + return *alternative; + } + return std::nullopt; +} + +StructTypeField::operator MessageTypeField() const { + ABSL_DCHECK(IsMessage()); + return absl::get(variant_); +} + +TypeParameters::TypeParameters(absl::Span types) + : size_(types.size()) { + if (size_ <= 2) { + std::memcpy(&internal_[0], types.data(), size_ * sizeof(Type)); + } else { + external_ = types.data(); + } +} + +TypeParameters::TypeParameters(const Type& element) : size_(1) { + std::memcpy(&internal_[0], &element, sizeof(element)); +} + +TypeParameters::TypeParameters(const Type& key, const Type& value) : size_(2) { + std::memcpy(&internal_[0], &key, sizeof(key)); + std::memcpy(&internal_[0] + sizeof(key), &value, sizeof(value)); +} + +namespace common_internal { + +namespace { + +constexpr absl::string_view kNullTypeName = "null_type"; +constexpr absl::string_view kBoolTypeName = "bool"; +constexpr absl::string_view kInt64TypeName = "int"; +constexpr absl::string_view kUInt64TypeName = "uint"; +constexpr absl::string_view kDoubleTypeName = "double"; +constexpr absl::string_view kStringTypeName = "string"; +constexpr absl::string_view kBytesTypeName = "bytes"; +constexpr absl::string_view kListTypeName = "list"; +constexpr absl::string_view kMapTypeName = "map"; +constexpr absl::string_view kCelTypeTypeName = "type"; + +} // namespace + +Type LegacyRuntimeType(absl::string_view name) { + if (name == kNullTypeName) { + return NullType{}; + } + if (name == kBoolTypeName) { + return BoolType{}; + } + if (name == kInt64TypeName) { + return IntType{}; + } + if (name == kUInt64TypeName) { + return UintType{}; + } + if (name == kDoubleTypeName) { + return DoubleType{}; + } + if (name == kStringTypeName) { + return StringType{}; + } + if (name == kBytesTypeName) { + return BytesType{}; + } + if (name == kListTypeName) { + return ListType{}; + } + if (name == kMapTypeName) { + return MapType{}; + } + if (name == kCelTypeTypeName) { + return TypeType{}; + } + if (cel::IsWellKnownMessageType(name)) { + if (name == "google.protobuf.Any") { + return AnyType(); + } + if (name == "google.protobuf.BoolValue") { + return BoolWrapperType(); + } + if (name == "google.protobuf.BytesValue") { + return BytesWrapperType(); + } + if (name == "google.protobuf.DoubleValue") { + return DoubleWrapperType(); + } + if (name == "google.protobuf.Duration") { + return DurationType(); + } + if (name == "google.protobuf.FloatValue") { + return DoubleWrapperType(); + } + if (name == "google.protobuf.Int32Value") { + return IntWrapperType(); + } + if (name == "google.protobuf.Int64Value") { + return IntWrapperType(); + } + if (name == "google.protobuf.ListValue") { + return ListType(); + } + if (name == "google.protobuf.StringValue") { + return StringWrapperType(); + } + if (name == "google.protobuf.Struct") { + return JsonMapType(); + } + if (name == "google.protobuf.Timestamp") { + return TimestampType(); + } + if (name == "google.protobuf.UInt32Value") { + return UintWrapperType(); + } + if (name == "google.protobuf.UInt64Value") { + return UintWrapperType(); + } + if (name == "google.protobuf.Value") { + return DynType(); + } + } + return common_internal::MakeBasicStructType(name); +} + +} // namespace common_internal + +} // namespace cel diff --git a/common/type.h b/common/type.h index ce7a5ce84..c8851dd4e 100644 --- a/common/type.h +++ b/common/type.h @@ -1,183 +1,1302 @@ -#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_H_ -#define THIRD_PARTY_CEL_CPP_COMMON_TYPES_H_ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. -#include "google/protobuf/descriptor.h" +#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_TYPE_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/meta/type_traits.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" #include "absl/types/variant.h" -#include "internal/handle.h" - -namespace google { -namespace api { -namespace expr { -namespace common { - -/** The basic value types. */ -enum class BasicTypeValue { - kNull, - kBool, - kInt, - kUint, - kDouble, - kString, - kBytes, - kType, - kMap, - kList, - - // Special value to require 'default' case in switch statements. - DO_NOT_USE -}; +#include "absl/utility/utility.h" +#include "common/type_kind.h" +#include "common/types/any_type.h" // IWYU pragma: export +#include "common/types/bool_type.h" // IWYU pragma: export +#include "common/types/bool_wrapper_type.h" // IWYU pragma: export +#include "common/types/bytes_type.h" // IWYU pragma: export +#include "common/types/bytes_wrapper_type.h" // IWYU pragma: export +#include "common/types/double_type.h" // IWYU pragma: export +#include "common/types/double_wrapper_type.h" // IWYU pragma: export +#include "common/types/duration_type.h" // IWYU pragma: export +#include "common/types/dyn_type.h" // IWYU pragma: export +#include "common/types/enum_type.h" // IWYU pragma: export +#include "common/types/error_type.h" // IWYU pragma: export +#include "common/types/function_type.h" // IWYU pragma: export +#include "common/types/int_type.h" // IWYU pragma: export +#include "common/types/int_wrapper_type.h" // IWYU pragma: export +#include "common/types/list_type.h" // IWYU pragma: export +#include "common/types/map_type.h" // IWYU pragma: export +#include "common/types/message_type.h" // IWYU pragma: export +#include "common/types/null_type.h" // IWYU pragma: export +#include "common/types/opaque_type.h" // IWYU pragma: export +#include "common/types/optional_type.h" // IWYU pragma: export +#include "common/types/string_type.h" // IWYU pragma: export +#include "common/types/string_wrapper_type.h" // IWYU pragma: export +#include "common/types/struct_type.h" // IWYU pragma: export +#include "common/types/timestamp_type.h" // IWYU pragma: export +#include "common/types/type_param_type.h" // IWYU pragma: export +#include "common/types/type_type.h" // IWYU pragma: export +#include "common/types/types.h" +#include "common/types/uint_type.h" // IWYU pragma: export +#include "common/types/uint_wrapper_type.h" // IWYU pragma: export +#include "common/types/unknown_type.h" // IWYU pragma: export +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +class Type; +class TypeParameters; -class BasicType : public internal::Handle { +// `Type` is a composition type which encompasses all types supported by the +// Common Expression Language. When default constructed, `Type` is in a +// known but invalid state. Any attempt to use it from then on, without +// assigning another type, is undefined behavior. In debug builds, we do our +// best to fail. +// +// The data underlying `Type` is either static or owned by `google::protobuf::Arena`. As +// such, care must be taken to ensure types remain valid throughout their use. +class Type final { public: - constexpr explicit BasicType(BasicTypeValue value) : Handle(value) {} + // Returns an appropriate `Type` for the dynamic protobuf message. For well + // known message types, the appropriate `Type` is returned. All others return + // `MessageType`. + static Type Message(const google::protobuf::Descriptor* absl_nonnull descriptor + ABSL_ATTRIBUTE_LIFETIME_BOUND); + + // Returns an appropriate `Type` for the dynamic protobuf message field. + static Type Field(const google::protobuf::FieldDescriptor* absl_nonnull descriptor + ABSL_ATTRIBUTE_LIFETIME_BOUND); + + // Returns an appropriate `Type` for the dynamic protobuf enum. For well + // known enum types, the appropriate `Type` is returned. All others return + // `EnumType`. + static Type Enum(const google::protobuf::EnumDescriptor* absl_nonnull descriptor + ABSL_ATTRIBUTE_LIFETIME_BOUND); + + using Parameters = TypeParameters; + + // The default constructor results in Type being DynType. + Type() = default; + Type(const Type&) = default; + Type(Type&&) = default; + Type& operator=(const Type&) = default; + Type& operator=(Type&&) = default; + + template >>> + // NOLINTNEXTLINE(google-explicit-constructor) + constexpr Type(T&& alternative) noexcept + : variant_(absl::in_place_type>, + std::forward(alternative)) {} + + template >>> + // NOLINTNEXTLINE(google-explicit-constructor) + Type& operator=(T&& type) noexcept { + variant_.emplace>(std::forward(type)); + return *this; + } + + // NOLINTNEXTLINE(google-explicit-constructor) + Type(StructType alternative) : variant_(alternative.ToTypeVariant()) {} + + // NOLINTNEXTLINE(google-explicit-constructor) + Type& operator=(StructType alternative) { + variant_ = alternative.ToTypeVariant(); + return *this; + } + + // NOLINTNEXTLINE(google-explicit-constructor) + Type(OptionalType alternative) : Type(OpaqueType(std::move(alternative))) {} + + // NOLINTNEXTLINE(google-explicit-constructor) + Type& operator=(OptionalType alternative) { + return *this = OpaqueType(std::move(alternative)); + } + + TypeKind kind() const; + + absl::string_view name() const ABSL_ATTRIBUTE_LIFETIME_BOUND; + + // Returns a debug string for the type. Not suitable for user-facing error + // messages. + std::string DebugString() const; + + Parameters GetParameters() const ABSL_ATTRIBUTE_LIFETIME_BOUND; + + template + friend H AbslHashValue(H state, const Type& type) { + return absl::visit( + [state = std::move(state)](const auto& alternative) mutable -> H { + return H::combine(std::move(state), alternative, alternative.kind()); + }, + type.variant_); + } - inline const absl::string_view full_name() const { return ToString(); } + friend bool operator==(const Type& lhs, const Type& rhs); + + friend std::ostream& operator<<(std::ostream& out, const Type& type) { + return absl::visit( + [&out](const auto& alternative) -> std::ostream& { + return out << alternative; + }, + type.variant_); + } + + bool IsAny() const { return absl::holds_alternative(variant_); } + + bool IsBool() const { return absl::holds_alternative(variant_); } + + bool IsBoolWrapper() const { + return absl::holds_alternative(variant_); + } + + bool IsBytes() const { return absl::holds_alternative(variant_); } + + bool IsBytesWrapper() const { + return absl::holds_alternative(variant_); + } + + bool IsDouble() const { + return absl::holds_alternative(variant_); + } + + bool IsDoubleWrapper() const { + return absl::holds_alternative(variant_); + } + + bool IsDuration() const { + return absl::holds_alternative(variant_); + } + + bool IsDyn() const { return absl::holds_alternative(variant_); } + + bool IsEnum() const { return absl::holds_alternative(variant_); } + + bool IsError() const { return absl::holds_alternative(variant_); } + + bool IsFunction() const { + return absl::holds_alternative(variant_); + } + + bool IsInt() const { return absl::holds_alternative(variant_); } + + bool IsIntWrapper() const { + return absl::holds_alternative(variant_); + } + + bool IsList() const { return absl::holds_alternative(variant_); } + + bool IsMap() const { return absl::holds_alternative(variant_); } + + bool IsMessage() const { + return absl::holds_alternative(variant_); + } - inline bool operator==(BasicTypeValue value) const { return value_ == value; } - inline bool operator!=(BasicTypeValue value) const { return value_ != value; } + bool IsNull() const { return absl::holds_alternative(variant_); } + + bool IsOpaque() const { + return absl::holds_alternative(variant_); + } + + bool IsOptional() const { return IsOpaque() && GetOpaque().IsOptional(); } + + bool IsString() const { + return absl::holds_alternative(variant_); + } + + bool IsStringWrapper() const { + return absl::holds_alternative(variant_); + } + + bool IsStruct() const { + return absl::holds_alternative( + variant_) || + absl::holds_alternative(variant_); + } + + bool IsTimestamp() const { + return absl::holds_alternative(variant_); + } + + bool IsTypeParam() const { + return absl::holds_alternative(variant_); + } + + bool IsType() const { return absl::holds_alternative(variant_); } + + bool IsUint() const { return absl::holds_alternative(variant_); } + + bool IsUintWrapper() const { + return absl::holds_alternative(variant_); + } + + bool IsUnknown() const { + return absl::holds_alternative(variant_); + } + + bool IsWrapper() const { + return IsBoolWrapper() || IsIntWrapper() || IsUintWrapper() || + IsDoubleWrapper() || IsBytesWrapper() || IsStringWrapper(); + } + + template + std::enable_if_t, bool> Is() const { + return IsAny(); + } + + template + std::enable_if_t, bool> Is() const { + return IsBool(); + } + + template + std::enable_if_t, bool> Is() const { + return IsBoolWrapper(); + } + + template + std::enable_if_t, bool> Is() const { + return IsBytes(); + } + + template + std::enable_if_t, bool> Is() const { + return IsBytesWrapper(); + } + + template + std::enable_if_t, bool> Is() const { + return IsDouble(); + } + + template + std::enable_if_t, bool> Is() const { + return IsDoubleWrapper(); + } + + template + std::enable_if_t, bool> Is() const { + return IsDuration(); + } + + template + std::enable_if_t, bool> Is() const { + return IsDyn(); + } + + template + std::enable_if_t, bool> Is() const { + return IsEnum(); + } + + template + std::enable_if_t, bool> Is() const { + return IsError(); + } + + template + std::enable_if_t, bool> Is() const { + return IsFunction(); + } + + template + std::enable_if_t, bool> Is() const { + return IsInt(); + } + + template + std::enable_if_t, bool> Is() const { + return IsIntWrapper(); + } + + template + std::enable_if_t, bool> Is() const { + return IsList(); + } + + template + std::enable_if_t, bool> Is() const { + return IsMap(); + } - /** - * Returns a canonical cel expression for the value. - */ - const std::string& ToString() const; + template + std::enable_if_t, bool> Is() const { + return IsMessage(); + } + + template + std::enable_if_t, bool> Is() const { + return IsNull(); + } + + template + std::enable_if_t, bool> Is() const { + return IsOpaque(); + } + + template + std::enable_if_t, bool> Is() const { + return IsOptional(); + } + + template + std::enable_if_t, bool> Is() const { + return IsString(); + } + + template + std::enable_if_t, bool> Is() const { + return IsStringWrapper(); + } + + template + std::enable_if_t, bool> Is() const { + return IsStruct(); + } + + template + std::enable_if_t, bool> Is() const { + return IsTimestamp(); + } + + template + std::enable_if_t, bool> Is() const { + return IsTypeParam(); + } + + template + std::enable_if_t, bool> Is() const { + return IsType(); + } + + template + std::enable_if_t, bool> Is() const { + return IsUint(); + } + + template + std::enable_if_t, bool> Is() const { + return IsUintWrapper(); + } + + template + std::enable_if_t, bool> Is() const { + return IsUnknown(); + } + + absl::optional AsAny() const; + + absl::optional AsBool() const; + + absl::optional AsBoolWrapper() const; + + absl::optional AsBytes() const; + + absl::optional AsBytesWrapper() const; + + absl::optional AsDouble() const; + + absl::optional AsDoubleWrapper() const; + + absl::optional AsDuration() const; + + absl::optional AsDyn() const; + + absl::optional AsEnum() const; + + absl::optional AsError() const; + + absl::optional AsFunction() const; + + absl::optional AsInt() const; + + absl::optional AsIntWrapper() const; + + absl::optional AsList() const; + + absl::optional AsMap() const; + + // AsMessage performs a checked cast, returning `MessageType` if this type is + // both a struct and a message or `absl::nullopt` otherwise. If you have + // already called `IsMessage()` it is more performant to perform to do + // `static_cast(type)`. + absl::optional AsMessage() const; + + absl::optional AsNull() const; + + absl::optional AsOpaque() const; + + absl::optional AsOptional() const; + + absl::optional AsString() const; + + absl::optional AsStringWrapper() const; + + // AsStruct performs a checked cast, returning `StructType` if this type is a + // struct or `absl::nullopt` otherwise. If you have already called + // `IsStruct()` it is more performant to perform to do + // `static_cast(type)`. + absl::optional AsStruct() const; + + absl::optional AsTimestamp() const; + + absl::optional AsTypeParam() const; + + absl::optional AsType() const; + + absl::optional AsUint() const; + + absl::optional AsUintWrapper() const; + + absl::optional AsUnknown() const; + + template + std::enable_if_t, absl::optional> As() + const { + return AsAny(); + } + + template + std::enable_if_t, absl::optional> As() + const { + return AsBool(); + } + + template + std::enable_if_t, + absl::optional> + As() const { + return AsBoolWrapper(); + } + + template + std::enable_if_t, absl::optional> As() + const { + return AsBytes(); + } + + template + std::enable_if_t, + absl::optional> + As() const { + return AsBytesWrapper(); + } + + template + std::enable_if_t, absl::optional> + As() const { + return AsDouble(); + } + + template + std::enable_if_t, + absl::optional> + As() const { + return AsDoubleWrapper(); + } + + template + std::enable_if_t, + absl::optional> + As() const { + return AsDuration(); + } + + template + std::enable_if_t, absl::optional> As() + const { + return AsDyn(); + } + + template + std::enable_if_t, absl::optional> As() + const { + return AsEnum(); + } + + template + std::enable_if_t, absl::optional> As() + const { + return AsError(); + } + + template + std::enable_if_t, + absl::optional> + As() const { + return AsFunction(); + } + + template + std::enable_if_t, absl::optional> As() + const { + return AsInt(); + } + + template + std::enable_if_t, + absl::optional> + As() const { + return AsIntWrapper(); + } + + template + std::enable_if_t, absl::optional> As() + const { + return AsList(); + } + + template + std::enable_if_t, absl::optional> As() + const { + return AsMap(); + } + + template + std::enable_if_t, absl::optional> + As() const { + return AsMessage(); + } + + template + std::enable_if_t, absl::optional> As() + const { + return AsNull(); + } + + template + std::enable_if_t, absl::optional> + As() const { + return AsOpaque(); + } + + template + std::enable_if_t, + absl::optional> + As() const { + return AsOptional(); + } + + template + std::enable_if_t, absl::optional> + As() const { + return AsString(); + } + + template + std::enable_if_t, + absl::optional> + As() const { + return AsStringWrapper(); + } + + template + std::enable_if_t, absl::optional> + As() const { + return AsStruct(); + } + + template + std::enable_if_t, + absl::optional> + As() const { + return AsTimestamp(); + } + + template + std::enable_if_t, + absl::optional> + As() const { + return AsTypeParam(); + } + + template + std::enable_if_t, absl::optional> As() + const { + return AsType(); + } + + template + std::enable_if_t, absl::optional> As() + const { + return AsUint(); + } + + template + std::enable_if_t, + absl::optional> + As() const { + return AsUintWrapper(); + } + + template + std::enable_if_t, absl::optional> + As() const { + return AsUnknown(); + } + + AnyType GetAny() const; + + BoolType GetBool() const; + + BoolWrapperType GetBoolWrapper() const; + + BytesType GetBytes() const; + + BytesWrapperType GetBytesWrapper() const; + + DoubleType GetDouble() const; + + DoubleWrapperType GetDoubleWrapper() const; + + DurationType GetDuration() const; + + DynType GetDyn() const; + + EnumType GetEnum() const; + + ErrorType GetError() const; + + FunctionType GetFunction() const; + + IntType GetInt() const; + + IntWrapperType GetIntWrapper() const; + + ListType GetList() const; + + MapType GetMap() const; + + MessageType GetMessage() const; + + NullType GetNull() const; + + OpaqueType GetOpaque() const; + + OptionalType GetOptional() const; + + StringType GetString() const; + + StringWrapperType GetStringWrapper() const; + + StructType GetStruct() const; + + TimestampType GetTimestamp() const; + + TypeParamType GetTypeParam() const; + + TypeType GetType() const; + + UintType GetUint() const; + + UintWrapperType GetUintWrapper() const; + + UnknownType GetUnknown() const; + + template + std::enable_if_t, AnyType> Get() const { + return GetAny(); + } + + template + std::enable_if_t, BoolType> Get() const { + return GetBool(); + } + + template + std::enable_if_t, BoolWrapperType> Get() + const { + return GetBoolWrapper(); + } + + template + std::enable_if_t, BytesType> Get() const { + return GetBytes(); + } + + template + std::enable_if_t, BytesWrapperType> Get() + const { + return GetBytesWrapper(); + } + + template + std::enable_if_t, DoubleType> Get() const { + return GetDouble(); + } + + template + std::enable_if_t, DoubleWrapperType> + Get() const { + return GetDoubleWrapper(); + } + + template + std::enable_if_t, DurationType> Get() const { + return GetDuration(); + } + + template + std::enable_if_t, DynType> Get() const { + return GetDyn(); + } + + template + std::enable_if_t, EnumType> Get() const { + return GetEnum(); + } + + template + std::enable_if_t, ErrorType> Get() const { + return GetError(); + } + + template + std::enable_if_t, FunctionType> Get() const { + return GetFunction(); + } + + template + std::enable_if_t, IntType> Get() const { + return GetInt(); + } + + template + std::enable_if_t, IntWrapperType> Get() + const { + return GetIntWrapper(); + } + + template + std::enable_if_t, ListType> Get() const { + return GetList(); + } + + template + std::enable_if_t, MapType> Get() const { + return GetMap(); + } + + template + std::enable_if_t, MessageType> Get() const { + return GetMessage(); + } + + template + std::enable_if_t, NullType> Get() const { + return GetNull(); + } + + template + std::enable_if_t, OpaqueType> Get() const { + return GetOpaque(); + } + + template + std::enable_if_t, OptionalType> Get() const { + return GetOptional(); + } + + template + std::enable_if_t, StringType> Get() const { + return GetString(); + } + + template + std::enable_if_t, StringWrapperType> + Get() const { + return GetStringWrapper(); + } + + template + std::enable_if_t, StructType> Get() const { + return GetStruct(); + } + + template + std::enable_if_t, TimestampType> Get() + const { + return GetTimestamp(); + } + + template + std::enable_if_t, TypeParamType> Get() + const { + return GetTypeParam(); + } + + template + std::enable_if_t, TypeType> Get() const { + return GetType(); + } + + template + std::enable_if_t, UintType> Get() const { + return GetUint(); + } + + template + std::enable_if_t, UintWrapperType> Get() + const { + return GetUintWrapper(); + } + + template + std::enable_if_t, UnknownType> Get() const { + return GetUnknown(); + } + + // Returns an unwrapped `Type` for a wrapped type, otherwise just returns + // this. + Type Unwrap() const; + + // Returns an wrapped `Type` for a primitive type, otherwise just returns + // this. + Type Wrap() const; + + private: + friend class StructType; + friend class MessageType; + friend class common_internal::BasicStructType; + + common_internal::StructTypeVariant ToStructTypeVariant() const; + + common_internal::TypeVariant variant_; }; -/** An object type. */ -class ObjectType - : public internal::Handle { +inline bool operator!=(const Type& lhs, const Type& rhs) { + return !operator==(lhs, rhs); +} + +inline Type JsonType() { return DynType(); } + +// Statically assert some expectations. +static_assert(std::is_default_constructible_v); +static_assert(std::is_copy_constructible_v); +static_assert(std::is_copy_assignable_v); +static_assert(std::is_nothrow_move_constructible_v); +static_assert(std::is_nothrow_move_assignable_v); + +// TypeParameters is a specialized view of a contiguous list of `Type`. It is +// very similar to `absl::Span`, except that it has a small amount +// of inline storage. Thus the pointers and references returned by +// TypeParameters are invalidated upon copying or moving. +// +// We store up to 2 types inline. This is done to accommodate list and map types +// which correspond to protocol buffer message fields. We launder around their +// descriptors and would have to allocate to return the type parameters. We want +// to avoid this, as types are supposed to be constant after creation. +class TypeParameters final { public: - constexpr explicit ObjectType(const google::protobuf::Descriptor* desc) - : Handle(desc) {} + using element_type = const Type; + using value_type = Type; + using pointer = element_type*; + using const_pointer = const element_type*; + using reference = element_type&; + using const_reference = const element_type&; + using iterator = pointer; + using const_iterator = const_pointer; + using reverse_iterator = std::reverse_iterator; + using const_reverse_iterator = std::reverse_iterator; + using size_type = size_t; + using difference_type = ptrdiff_t; + + explicit TypeParameters(absl::Span types); + + TypeParameters() = default; + TypeParameters(const TypeParameters&) = default; + TypeParameters(TypeParameters&&) = default; + TypeParameters& operator=(const TypeParameters&) = default; + TypeParameters& operator=(TypeParameters&&) = default; + + size_type size() const { return size_; } - template - constexpr static ObjectType For() { - return ObjectType(T::descriptor()); + bool empty() const { return size() == 0; } + + const_reference front() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + ABSL_DCHECK(!empty()); + return data()[0]; + } + + const_reference back() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + ABSL_DCHECK(!empty()); + return data()[size() - 1]; + } + + const_reference operator[](size_type index) const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + ABSL_DCHECK_LT(index, size()); + return data()[index]; + } + + const_pointer data() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return size() <= 2 ? reinterpret_cast(&internal_[0]) + : external_; } - inline absl::string_view full_name() const { return value_->full_name(); } + const_iterator begin() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return data(); } - /** - * Returns a canonical cel expression for the value. - */ - inline const std::string& ToString() const { return value_->full_name(); } + const_iterator cbegin() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return begin(); + } + + const_iterator end() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return data() + size(); + } + + const_iterator cend() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return end(); } - std::unique_ptr Unpack(const google::protobuf::Any& value); + const_reverse_iterator rbegin() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return std::make_reverse_iterator(end()); + } + + const_reverse_iterator crbegin() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return rbegin(); + } + + const_reverse_iterator rend() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return std::make_reverse_iterator(begin()); + } + + const_reverse_iterator crend() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return rend(); + } + + private: + friend class ListType; + friend class MapType; + + explicit TypeParameters(const Type& element); + + explicit TypeParameters(const Type& key, const Type& value); + + // When size_ <= 2, elements are stored directly in `internal_`. Otherwise we + // store a pointer to the elements in `external_`. + size_t size_ = 0; + union { + const Type* external_ = nullptr; + // Old versions of GCC do not like `Type internal_[2]`, so we cheat. + alignas(Type) char internal_[sizeof(Type) * 2]; + }; }; -/** An enum type. */ -class EnumType - : public internal::Handle { +// Now that TypeParameters is defined, we can define `GetParameters()` for most +// types. + +inline TypeParameters AnyType::GetParameters() { return {}; } + +inline TypeParameters BoolType::GetParameters() { return {}; } + +inline TypeParameters BoolWrapperType::GetParameters() { return {}; } + +inline TypeParameters BytesType::GetParameters() { return {}; } + +inline TypeParameters BytesWrapperType::GetParameters() { return {}; } + +inline TypeParameters DoubleType::GetParameters() { return {}; } + +inline TypeParameters DoubleWrapperType::GetParameters() { return {}; } + +inline TypeParameters DurationType::GetParameters() { return {}; } + +inline TypeParameters DynType::GetParameters() { return {}; } + +inline TypeParameters EnumType::GetParameters() { return {}; } + +inline TypeParameters ErrorType::GetParameters() { return {}; } + +inline TypeParameters IntType::GetParameters() { return {}; } + +inline TypeParameters IntWrapperType::GetParameters() { return {}; } + +inline TypeParameters MessageType::GetParameters() { return {}; } + +inline TypeParameters NullType::GetParameters() { return {}; } + +inline TypeParameters OptionalType::GetParameters() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return opaque_.GetParameters(); +} + +inline TypeParameters StringType::GetParameters() { return {}; } + +inline TypeParameters StringWrapperType::GetParameters() { return {}; } + +inline TypeParameters TimestampType::GetParameters() { return {}; } + +inline TypeParameters TypeParamType::GetParameters() { return {}; } + +inline TypeParameters UintType::GetParameters() { return {}; } + +inline TypeParameters UintWrapperType::GetParameters() { return {}; } + +inline TypeParameters UnknownType::GetParameters() { return {}; } + +namespace common_internal { + +inline TypeParameters BasicStructType::GetParameters() { return {}; } + +Type SingularMessageFieldType( + const google::protobuf::FieldDescriptor* absl_nonnull descriptor); + +class BasicStructTypeField final { public: - constexpr explicit EnumType(const google::protobuf::EnumDescriptor* desc) - : Handle(desc) {} + BasicStructTypeField(absl::string_view name, int32_t number, Type type) + : name_(name), number_(number), type_(type) {} + + BasicStructTypeField(const BasicStructTypeField&) = default; + BasicStructTypeField(BasicStructTypeField&&) = default; + BasicStructTypeField& operator=(const BasicStructTypeField&) = default; + BasicStructTypeField& operator=(BasicStructTypeField&&) = default; + + std::string DebugString() const; - inline absl::string_view full_name() const { return value_->full_name(); } + absl::string_view name() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return name_; } - /** - * Returns a canonical cel expression for the value. - */ - inline const std::string& ToString() const { return value_->full_name(); } + int32_t number() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return number_; } + + Type GetType() const { return type_; } + + explicit operator bool() const { return !name_.empty() || number_ >= 1; } + + private: + absl::string_view name_; + int32_t number_ = 0; + Type type_; }; -/** - * An unrecognized type. - */ -class UnrecognizedType final { +inline bool operator==(const BasicStructTypeField& lhs, + const BasicStructTypeField& rhs) { + return lhs.name() == rhs.name() && lhs.number() == rhs.number() && + lhs.GetType() == rhs.GetType(); +} + +inline bool operator!=(const BasicStructTypeField& lhs, + const BasicStructTypeField& rhs) { + return !operator==(lhs, rhs); +} + +} // namespace common_internal + +class StructTypeField final { public: - explicit UnrecognizedType(absl::string_view full_name); + // NOLINTNEXTLINE(google-explicit-constructor) + StructTypeField(common_internal::BasicStructTypeField field) + : variant_(absl::in_place_type, + field) {} - absl::string_view full_name() const; - inline std::size_t hash_code() const { return hash_code_; } + // NOLINTNEXTLINE(google-explicit-constructor) + StructTypeField(MessageTypeField field) + : variant_(absl::in_place_type, field) {} - inline bool operator==(const UnrecognizedType& rhs) const { - return hash_code_ == rhs.hash_code_ && full_name() == rhs.full_name(); - } + StructTypeField() = delete; + StructTypeField(const StructTypeField&) = default; + StructTypeField(StructTypeField&&) = default; + StructTypeField& operator=(const StructTypeField&) = default; + StructTypeField& operator=(StructTypeField&&) = default; + + std::string DebugString() const; + + absl::string_view name() const ABSL_ATTRIBUTE_LIFETIME_BOUND; + + int32_t number() const ABSL_ATTRIBUTE_LIFETIME_BOUND; + + Type GetType() const; - inline bool operator!=(const UnrecognizedType& rhs) const { - return !(*this == rhs); + explicit operator bool() const; + + bool IsMessage() const { + return absl::holds_alternative(variant_); } - /** - * Returns a canonical cel expression for the value. - */ - inline const std::string& ToString() const { return string_rep_; } + absl::optional AsMessage() const; + + explicit operator MessageTypeField() const; private: - std::string string_rep_; - std::size_t hash_code_; + absl::variant + variant_; }; -/** A type value. */ -class Type final { - public: - // Allow for implicit conversion, so visitors can implement overloads - // for either Type or more specific instances. - constexpr Type(BasicType basic_type) : data_(basic_type) {} - constexpr Type(EnumType enum_type) : data_(enum_type) {} - constexpr Type(ObjectType object_type) : data_(object_type) {} - explicit Type(const std::string& full_name); - explicit Type(absl::string_view full_name) : Type(std::string(full_name)) {} - explicit Type(const char* full_name) : Type(std::string(full_name)) {} +inline bool operator==(const StructTypeField& lhs, const StructTypeField& rhs) { + return lhs.name() == rhs.name() && lhs.number() == rhs.number() && + lhs.GetType() == rhs.GetType(); +} - Type(const UnrecognizedType& unrecognized_type) : data_(unrecognized_type) {} - Type(UnrecognizedType&& unrecognized_type) - : data_(std::move(unrecognized_type)) {} +inline bool operator!=(const StructTypeField& lhs, const StructTypeField& rhs) { + return !operator==(lhs, rhs); +} - absl::string_view full_name() const; +// Now that Type is defined, we can define everything else. - bool is_basic() const { return absl::holds_alternative(data_); } - bool is_enum() const { return absl::holds_alternative(data_); } - bool is_object() const { return absl::holds_alternative(data_); } - bool is_unrecognized() const { - return absl::holds_alternative(data_); - } +namespace common_internal { - inline BasicType basic_type() const { return absl::get(data_); } - inline EnumType enum_type() const { return absl::get(data_); } - inline ObjectType object_type() const { return absl::get(data_); } +struct ListTypeData final { + static ListTypeData* absl_nonnull Create(google::protobuf::Arena* absl_nonnull arena, + const Type& element); - inline bool operator==(const Type& rhs) const { return data_ == rhs.data_; } - inline bool operator!=(const Type& rhs) const { return data_ != rhs.data_; } + ListTypeData() = default; + ListTypeData(const ListTypeData&) = delete; + ListTypeData(ListTypeData&&) = delete; + ListTypeData& operator=(const ListTypeData&) = delete; + ListTypeData& operator=(ListTypeData&&) = delete; - /** The hash code for this value. */ - inline std::size_t hash_code() const { return internal::Hash(data_); } + Type element = DynType(); + + private: + explicit ListTypeData(const Type& element); +}; + +struct MapTypeData final { + static MapTypeData* absl_nonnull Create(google::protobuf::Arena* absl_nonnull arena, + const Type& key, const Type& value); + + Type key_and_value[2]; +}; - /** - * Returns a canonical cel expression for the value. - */ - const std::string& ToString() const; +struct FunctionTypeData final { + static FunctionTypeData* absl_nonnull Create( + google::protobuf::Arena* absl_nonnull arena, const Type& result, + absl::Span args); + + FunctionTypeData() = delete; + FunctionTypeData(const FunctionTypeData&) = delete; + FunctionTypeData(FunctionTypeData&&) = delete; + FunctionTypeData& operator=(const FunctionTypeData&) = delete; + FunctionTypeData& operator=(FunctionTypeData&&) = delete; + + const size_t args_size; + // Flexible array, has `args_size` elements, with the first element being the + // return type. FunctionTypeData has a variable length size, which includes + // this flexible array. + Type args[]; private: - absl::variant data_; + FunctionTypeData(const Type& result, absl::Span args); }; -inline std::ostream& operator<<(std::ostream& os, const Type& value) { - return os << value.ToString(); +struct OpaqueTypeData final { + static OpaqueTypeData* absl_nonnull Create(google::protobuf::Arena* absl_nonnull arena, + absl::string_view name, + absl::Span parameters); + + OpaqueTypeData() = delete; + OpaqueTypeData(const OpaqueTypeData&) = delete; + OpaqueTypeData(OpaqueTypeData&&) = delete; + OpaqueTypeData& operator=(const OpaqueTypeData&) = delete; + OpaqueTypeData& operator=(OpaqueTypeData&&) = delete; + + const absl::string_view name; + const size_t parameters_size; + // Flexible array, has `parameters_size` elements. OpaqueTypeData has a + // variable length size, which includes this flexible array. + Type parameters[]; + + private: + OpaqueTypeData(absl::string_view name, absl::Span parameters); +}; + +} // namespace common_internal + +inline bool operator==(const MessageTypeField& lhs, + const MessageTypeField& rhs) { + return lhs.name() == rhs.name() && lhs.number() == rhs.number() && + lhs.GetType() == rhs.GetType(); +} + +inline bool operator!=(const MessageTypeField& lhs, + const MessageTypeField& rhs) { + return !operator==(lhs, rhs); +} + +inline bool operator==(const ListType& lhs, const ListType& rhs) { + return &lhs == &rhs || lhs.GetElement() == rhs.GetElement(); } -} // namespace common -} // namespace expr -} // namespace api -} // namespace google +template +inline H AbslHashValue(H state, const ListType& type) { + return H::combine(std::move(state), type.GetElement(), size_t{1}); +} + +inline bool operator==(const MapType& lhs, const MapType& rhs) { + return &lhs == &rhs || + (lhs.GetKey() == rhs.GetKey() && lhs.GetValue() == rhs.GetValue()); +} -namespace std { +template +inline H AbslHashValue(H state, const MapType& type) { + return H::combine(std::move(state), type.GetKey(), type.GetValue(), + size_t{2}); +} -template <> -struct hash - : google::api::expr::internal::Hasher {}; +inline bool operator==(const OpaqueType& lhs, const OpaqueType& rhs) { + return lhs.name() == rhs.name() && + absl::c_equal(lhs.GetParameters(), rhs.GetParameters()); +} + +template +inline H AbslHashValue(H state, const OpaqueType& type) { + state = H::combine(std::move(state), type.name()); + auto parameters = type.GetParameters(); + for (const auto& parameter : parameters) { + state = H::combine(std::move(state), parameter); + } + return H::combine(std::move(state), parameters.size()); +} + +inline bool operator==(const FunctionType& lhs, const FunctionType& rhs) { + return lhs.result() == rhs.result() && absl::c_equal(lhs.args(), rhs.args()); +} + +template +inline H AbslHashValue(H state, const FunctionType& type) { + state = H::combine(std::move(state), type.result()); + auto args = type.args(); + for (const auto& arg : args) { + state = H::combine(std::move(state), arg); + } + return H::combine(std::move(state), args.size()); +} -template <> -struct hash - : google::api::expr::internal::Hasher {}; +namespace common_internal { -template <> -struct hash - : google::api::expr::common::EnumType::Hasher {}; +// Converts the string returned from `CelValue::CelTypeHolder` to `cel::Type`. +// The underlying content of `name` must outlive the resulting type and any of +// its shallow copies. +Type LegacyRuntimeType(absl::string_view name); -template <> -struct hash - : google::api::expr::common::ObjectType::Hasher {}; +} // namespace common_internal -} // namespace std +} // namespace cel -#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_H_ +#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPE_H_ diff --git a/common/type_introspector.cc b/common/type_introspector.cc new file mode 100644 index 000000000..3846ab58b --- /dev/null +++ b/common/type_introspector.cc @@ -0,0 +1,277 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/type_introspector.h" + +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "common/type.h" + +namespace cel { + +namespace { + +common_internal::BasicStructTypeField MakeBasicStructTypeField( + absl::string_view name, Type type, int32_t number) { + return common_internal::BasicStructTypeField(name, number, type); +} + +struct FieldNameComparer { + using is_transparent = void; + + bool operator()(const common_internal::BasicStructTypeField& lhs, + const common_internal::BasicStructTypeField& rhs) const { + return (*this)(lhs.name(), rhs.name()); + } + + bool operator()(const common_internal::BasicStructTypeField& lhs, + absl::string_view rhs) const { + return (*this)(lhs.name(), rhs); + } + + bool operator()(absl::string_view lhs, + const common_internal::BasicStructTypeField& rhs) const { + return (*this)(lhs, rhs.name()); + } + + bool operator()(absl::string_view lhs, absl::string_view rhs) const { + return lhs < rhs; + } +}; + +struct FieldNumberComparer { + using is_transparent = void; + + bool operator()(const common_internal::BasicStructTypeField& lhs, + const common_internal::BasicStructTypeField& rhs) const { + return (*this)(lhs.number(), rhs.number()); + } + + bool operator()(const common_internal::BasicStructTypeField& lhs, + int64_t rhs) const { + return (*this)(lhs.number(), rhs); + } + + bool operator()(int64_t lhs, + const common_internal::BasicStructTypeField& rhs) const { + return (*this)(lhs, rhs.number()); + } + + bool operator()(int64_t lhs, int64_t rhs) const { return lhs < rhs; } +}; + +struct WellKnownType { + WellKnownType( + const Type& type, + std::initializer_list fields) + : type(type), fields_by_name(fields), fields_by_number(fields) { + std::sort(fields_by_name.begin(), fields_by_name.end(), + FieldNameComparer{}); + std::sort(fields_by_number.begin(), fields_by_number.end(), + FieldNumberComparer{}); + } + + explicit WellKnownType(const Type& type) : WellKnownType(type, {}) {} + + Type type; + // We use `2` as that accommodates most well known types. + absl::InlinedVector fields_by_name; + absl::InlinedVector + fields_by_number; + + absl::optional FieldByName(absl::string_view name) const { + // Basically `std::binary_search`. + auto it = std::lower_bound(fields_by_name.begin(), fields_by_name.end(), + name, FieldNameComparer{}); + if (it == fields_by_name.end() || it->name() != name) { + return std::nullopt; + } + return *it; + } + + absl::optional FieldByNumber(int64_t number) const { + // Basically `std::binary_search`. + auto it = std::lower_bound(fields_by_number.begin(), fields_by_number.end(), + number, FieldNumberComparer{}); + if (it == fields_by_number.end() || it->number() != number) { + return std::nullopt; + } + return *it; + } +}; + +using WellKnownTypesMap = absl::flat_hash_map; + +const WellKnownTypesMap& GetWellKnownTypesMap() { + static const WellKnownTypesMap* types = []() -> WellKnownTypesMap* { + WellKnownTypesMap* types = new WellKnownTypesMap(); + types->insert_or_assign( + "google.protobuf.BoolValue", + WellKnownType{BoolWrapperType{}, + {MakeBasicStructTypeField("value", BoolType{}, 1)}}); + types->insert_or_assign( + "google.protobuf.Int32Value", + WellKnownType{IntWrapperType{}, + {MakeBasicStructTypeField("value", IntType{}, 1)}}); + types->insert_or_assign( + "google.protobuf.Int64Value", + WellKnownType{IntWrapperType{}, + {MakeBasicStructTypeField("value", IntType{}, 1)}}); + types->insert_or_assign( + "google.protobuf.UInt32Value", + WellKnownType{UintWrapperType{}, + {MakeBasicStructTypeField("value", UintType{}, 1)}}); + types->insert_or_assign( + "google.protobuf.UInt64Value", + WellKnownType{UintWrapperType{}, + {MakeBasicStructTypeField("value", UintType{}, 1)}}); + types->insert_or_assign( + "google.protobuf.FloatValue", + WellKnownType{DoubleWrapperType{}, + {MakeBasicStructTypeField("value", DoubleType{}, 1)}}); + types->insert_or_assign( + "google.protobuf.DoubleValue", + WellKnownType{DoubleWrapperType{}, + {MakeBasicStructTypeField("value", DoubleType{}, 1)}}); + types->insert_or_assign( + "google.protobuf.StringValue", + WellKnownType{StringWrapperType{}, + {MakeBasicStructTypeField("value", StringType{}, 1)}}); + types->insert_or_assign( + "google.protobuf.BytesValue", + WellKnownType{BytesWrapperType{}, + {MakeBasicStructTypeField("value", BytesType{}, 1)}}); + types->insert_or_assign( + "google.protobuf.Duration", + WellKnownType{DurationType{}, + {MakeBasicStructTypeField("seconds", IntType{}, 1), + MakeBasicStructTypeField("nanos", IntType{}, 2)}}); + types->insert_or_assign( + "google.protobuf.Timestamp", + WellKnownType{TimestampType{}, + {MakeBasicStructTypeField("seconds", IntType{}, 1), + MakeBasicStructTypeField("nanos", IntType{}, 2)}}); + types->insert_or_assign( + "google.protobuf.Value", + WellKnownType{ + DynType{}, + {// NullValue enum is an int. Not normally referenced directly. + MakeBasicStructTypeField("null_value", IntType{}, 1), + MakeBasicStructTypeField("number_value", DoubleType{}, 2), + MakeBasicStructTypeField("string_value", StringType{}, 3), + MakeBasicStructTypeField("bool_value", BoolType{}, 4), + MakeBasicStructTypeField("struct_value", JsonMapType(), 5), + MakeBasicStructTypeField("list_value", ListType{}, 6)}}); + types->insert_or_assign( + "google.protobuf.ListValue", + WellKnownType{ListType{}, + {MakeBasicStructTypeField("values", ListType{}, 1)}}); + types->insert_or_assign( + "google.protobuf.Struct", + WellKnownType{JsonMapType(), + {MakeBasicStructTypeField("fields", JsonMapType(), 1)}}); + types->insert_or_assign( + "google.protobuf.Any", + WellKnownType{AnyType{}, + {MakeBasicStructTypeField("type_url", StringType{}, 1), + MakeBasicStructTypeField("value", BytesType{}, 2)}}); + types->insert_or_assign("null_type", WellKnownType{NullType{}}); + types->insert_or_assign("google.protobuf.NullValue", + WellKnownType{NullType{}}); + types->insert_or_assign("bool", WellKnownType{BoolType{}}); + types->insert_or_assign("int", WellKnownType{IntType{}}); + types->insert_or_assign("uint", WellKnownType{UintType{}}); + types->insert_or_assign("double", WellKnownType{DoubleType{}}); + types->insert_or_assign("bytes", WellKnownType{BytesType{}}); + types->insert_or_assign("string", WellKnownType{StringType{}}); + types->insert_or_assign("list", WellKnownType{ListType{}}); + types->insert_or_assign("map", WellKnownType{MapType{}}); + types->insert_or_assign("type", WellKnownType{TypeType{}}); + return types; + }(); + return *types; +} + +} // namespace + +absl::StatusOr> TypeIntrospector::FindTypeImpl( + absl::string_view) const { + return std::nullopt; +} + +absl::StatusOr> +TypeIntrospector::FindEnumConstantImpl(absl::string_view, + absl::string_view) const { + return std::nullopt; +} + +absl::StatusOr> +TypeIntrospector::FindStructTypeFieldByNameImpl(absl::string_view, + absl::string_view) const { + return std::nullopt; +} + +absl::StatusOr< + absl::optional>> +TypeIntrospector::ListFieldsForStructTypeImpl(absl::string_view) const { + return std::nullopt; +} + +absl::optional FindWellKnownType(absl::string_view name) { + const auto& well_known_types = GetWellKnownTypesMap(); + if (auto it = well_known_types.find(name); it != well_known_types.end()) { + return it->second.type; + } + return std::nullopt; +} + +absl::optional FindWellKnownTypeEnumConstant( + absl::string_view type, absl::string_view value) { + if (type == "google.protobuf.NullValue" && value == "NULL_VALUE") { + return TypeIntrospector::EnumConstant{ + IntType{}, "google.protobuf.NullValue", "NULL_VALUE", 0}; + } + return std::nullopt; +} + +absl::optional FindWellKnownTypeFieldByName( + absl::string_view type, absl::string_view name) { + const auto& well_known_types = GetWellKnownTypesMap(); + if (auto it = well_known_types.find(type); it != well_known_types.end()) { + return it->second.FieldByName(name); + } + return std::nullopt; +} + +absl::optional> +ListFieldsForWellKnownType(absl::string_view type) { + const auto& well_known_types = GetWellKnownTypesMap(); + auto it = well_known_types.find(type); + if (it == well_known_types.end()) { + return std::nullopt; + } + // The fields are not normally gettable. + return {}; +} + +} // namespace cel diff --git a/common/type_introspector.h b/common/type_introspector.h new file mode 100644 index 000000000..932fb108e --- /dev/null +++ b/common/type_introspector.h @@ -0,0 +1,157 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPE_INTROSPECTOR_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_TYPE_INTROSPECTOR_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "common/type.h" + +namespace cel { + +// `TypeIntrospector` is an interface which allows querying type-related +// information. It handles type introspection, but not type reflection. That is, +// it is not capable of instantiating new values or understanding values. Its +// primary usage is for type checking, and a subset of that shared functionality +// is used by the runtime. +class TypeIntrospector { + public: + struct EnumConstant { + // The type of the enum. For JSON null, this may be a specific type rather + // than an enum type. + Type type; + absl::string_view type_full_name; + absl::string_view value_name; + int32_t number; + }; + + struct StructTypeFieldListing { + // The name used to access the field in source CEL. + // This is assumed owned by the TypeIntrospector or a dependency that + // outlives it. + absl::string_view name; + // The field description. + StructTypeField field; + }; + + virtual ~TypeIntrospector() = default; + + // `FindType` find the type corresponding to name `name`. + absl::StatusOr> FindType(absl::string_view name) const { + return FindTypeImpl(name); + } + + // `FindEnumConstant` find a fully qualified enumerator name `name` in enum + // type `type`. + absl::StatusOr> FindEnumConstant( + absl::string_view type, absl::string_view value) const { + return FindEnumConstantImpl(type, value); + } + + // `FindStructTypeFieldByName` find the name, number, and type of the field + // `name` in type `type`. + absl::StatusOr> FindStructTypeFieldByName( + absl::string_view type, absl::string_view name) const { + return FindStructTypeFieldByNameImpl(type, name); + } + + // `ListFieldsForStructType` returns the fields of struct type `type`. + // + // This is used when the struct is declared as a context type. + // + // If the type is not found, returns `absl::nullopt`. + // If the type exists but is not a struct or has no fields, returns an empty + // vector. + absl::StatusOr>> + ListFieldsForStructType(absl::string_view type) const { + return ListFieldsForStructTypeImpl(type); + } + + // `FindStructTypeFieldByName` find the name, number, and type of the field + // `name` in struct type `type`. + absl::StatusOr> FindStructTypeFieldByName( + const StructType& type, absl::string_view name) const { + return FindStructTypeFieldByName(type.name(), name); + } + + protected: + virtual absl::StatusOr> FindTypeImpl( + absl::string_view name) const; + + virtual absl::StatusOr> FindEnumConstantImpl( + absl::string_view type, absl::string_view value) const; + + virtual absl::StatusOr> + FindStructTypeFieldByNameImpl(absl::string_view type, + absl::string_view name) const; + + virtual absl::StatusOr>> + ListFieldsForStructTypeImpl(absl::string_view type) const; +}; + +// Looks up a well-known type by name. +absl::optional FindWellKnownType(absl::string_view name); + +// Looks up a well-known enum constant by type and value. +absl::optional FindWellKnownTypeEnumConstant( + absl::string_view type, absl::string_view value); + +// Looks up a well-known struct type field by type and field name. +absl::optional FindWellKnownTypeFieldByName( + absl::string_view type, absl::string_view name); + +absl::optional> +ListFieldsForWellKnownType(absl::string_view type); + +// `WellKnownTypeIntrospector` is an implementation of `TypeIntrospector` which +// handles well known types that are treated specially by CEL. +// +// This also serves as a minimal implementation of a TypeInstrospector when no +// custom types are present. +// +// This class has no mutable state, so trivially thread-safe. +class WellKnownTypeIntrospector : public virtual TypeIntrospector { + public: + WellKnownTypeIntrospector() = default; + + private: + absl::StatusOr> FindTypeImpl( + absl::string_view name) const final { + return FindWellKnownType(name); + } + + absl::StatusOr> FindEnumConstantImpl( + absl::string_view type, absl::string_view value) const final { + return FindWellKnownTypeEnumConstant(type, value); + } + + absl::StatusOr> FindStructTypeFieldByNameImpl( + absl::string_view type, absl::string_view name) const final { + return FindWellKnownTypeFieldByName(type, name); + } + + absl::StatusOr>> + ListFieldsForStructTypeImpl(absl::string_view type) const final { + return ListFieldsForWellKnownType(type); + } +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPE_INTROSPECTOR_H_ diff --git a/common/type_kind.h b/common/type_kind.h new file mode 100644 index 000000000..34df8e385 --- /dev/null +++ b/common/type_kind.h @@ -0,0 +1,113 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPE_KIND_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_TYPE_KIND_H_ + +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/macros.h" +#include "absl/strings/string_view.h" +#include "common/kind.h" + +namespace cel { + +// `TypeKind` is a subset of `Kind`, representing all valid `Kind` for `Type`. +// All `TypeKind` are valid `Kind`, but it is not guaranteed that all `Kind` are +// valid `TypeKind`. +enum class TypeKind : std::underlying_type_t { + kNull = static_cast(Kind::kNull), + kBool = static_cast(Kind::kBool), + kInt = static_cast(Kind::kInt), + kUint = static_cast(Kind::kUint), + kDouble = static_cast(Kind::kDouble), + kString = static_cast(Kind::kString), + kBytes = static_cast(Kind::kBytes), + kStruct = static_cast(Kind::kStruct), + kDuration = static_cast(Kind::kDuration), + kTimestamp = static_cast(Kind::kTimestamp), + kList = static_cast(Kind::kList), + kMap = static_cast(Kind::kMap), + kUnknown = static_cast(Kind::kUnknown), + kType = static_cast(Kind::kType), + kError = static_cast(Kind::kError), + kAny = static_cast(Kind::kAny), + kDyn = static_cast(Kind::kDyn), + kOpaque = static_cast(Kind::kOpaque), + + kBoolWrapper = static_cast(Kind::kBoolWrapper), + kIntWrapper = static_cast(Kind::kIntWrapper), + kUintWrapper = static_cast(Kind::kUintWrapper), + kDoubleWrapper = static_cast(Kind::kDoubleWrapper), + kStringWrapper = static_cast(Kind::kStringWrapper), + kBytesWrapper = static_cast(Kind::kBytesWrapper), + + kTypeParam = static_cast(Kind::kTypeParam), + kFunction = static_cast(Kind::kFunction), + kEnum = static_cast(Kind::kEnum), + + // Legacy aliases, deprecated do not use. + kNullType = kNull, + kInt64 = kInt, + kUint64 = kUint, + kMessage = kStruct, + kUnknownSet = kUnknown, + kCelType = kType, + + // INTERNAL: Do not exceed 63. Implementation details rely on the fact that + // we can store `Kind` using 6 bits. + kNotForUseWithExhaustiveSwitchStatements = + static_cast(Kind::kNotForUseWithExhaustiveSwitchStatements), +}; + +constexpr Kind TypeKindToKind(TypeKind kind) { + return static_cast(static_cast>(kind)); +} + +constexpr bool KindIsTypeKind(Kind kind ABSL_ATTRIBUTE_UNUSED) { + // Currently all Kind are valid TypeKind. + return true; +} + +constexpr bool operator==(Kind lhs, TypeKind rhs) { + return lhs == TypeKindToKind(rhs); +} + +constexpr bool operator==(TypeKind lhs, Kind rhs) { + return TypeKindToKind(lhs) == rhs; +} + +constexpr bool operator!=(Kind lhs, TypeKind rhs) { + return !operator==(lhs, rhs); +} + +constexpr bool operator!=(TypeKind lhs, Kind rhs) { + return !operator==(lhs, rhs); +} + +inline absl::string_view TypeKindToString(TypeKind kind) { + // All TypeKind are valid Kind. + return KindToString(TypeKindToKind(kind)); +} + +constexpr TypeKind KindToTypeKind(Kind kind) { + ABSL_ASSERT(KindIsTypeKind(kind)); + return static_cast(static_cast>(kind)); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPE_KIND_H_ diff --git a/common/type_proto.cc b/common/type_proto.cc new file mode 100644 index 000000000..b6b66f73a --- /dev/null +++ b/common/type_proto.cc @@ -0,0 +1,333 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/type_proto.h" + +#include +#include +#include + +#include "google/protobuf/struct.pb.h" +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "common/type.h" +#include "common/type_kind.h" +#include "internal/status_macros.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +namespace { + +using ::google::protobuf::NullValue; + +using TypePb = cel::expr::Type; + +// filter well-known types from message types. +absl::optional MaybeWellKnownType(absl::string_view type_name) { + static const absl::flat_hash_map* kWellKnownTypes = + []() { + auto* instance = new absl::flat_hash_map{ + // keep-sorted start + {"google.protobuf.Any", AnyType()}, + {"google.protobuf.BoolValue", BoolWrapperType()}, + {"google.protobuf.BytesValue", BytesWrapperType()}, + {"google.protobuf.DoubleValue", DoubleWrapperType()}, + {"google.protobuf.Duration", DurationType()}, + {"google.protobuf.FloatValue", DoubleWrapperType()}, + {"google.protobuf.Int32Value", IntWrapperType()}, + {"google.protobuf.Int64Value", IntWrapperType()}, + {"google.protobuf.ListValue", ListType()}, + {"google.protobuf.StringValue", StringWrapperType()}, + {"google.protobuf.Struct", JsonMapType()}, + {"google.protobuf.Timestamp", TimestampType()}, + {"google.protobuf.UInt32Value", UintWrapperType()}, + {"google.protobuf.UInt64Value", UintWrapperType()}, + {"google.protobuf.Value", DynType()}, + // keep-sorted end + }; + return instance; + }(); + + if (auto it = kWellKnownTypes->find(type_name); + it != kWellKnownTypes->end()) { + return it->second; + } + + return std::nullopt; +} + +absl::Status TypeToProtoInternal(const cel::Type& type, + TypePb* absl_nonnull type_pb); + +absl::Status ToProtoAbstractType(const cel::OpaqueType& type, + TypePb* absl_nonnull type_pb) { + auto* abstract_type = type_pb->mutable_abstract_type(); + abstract_type->set_name(type.name()); + abstract_type->mutable_parameter_types()->Reserve( + type.GetParameters().size()); + + for (const auto& param : type.GetParameters()) { + CEL_RETURN_IF_ERROR( + TypeToProtoInternal(param, abstract_type->add_parameter_types())); + } + + return absl::OkStatus(); +} + +absl::Status ToProtoMapType(const cel::MapType& type, + TypePb* absl_nonnull type_pb) { + auto* map_type = type_pb->mutable_map_type(); + CEL_RETURN_IF_ERROR( + TypeToProtoInternal(type.key(), map_type->mutable_key_type())); + CEL_RETURN_IF_ERROR( + TypeToProtoInternal(type.value(), map_type->mutable_value_type())); + + return absl::OkStatus(); +} + +absl::Status ToProtoListType(const cel::ListType& type, + TypePb* absl_nonnull type_pb) { + auto* list_type = type_pb->mutable_list_type(); + CEL_RETURN_IF_ERROR( + TypeToProtoInternal(type.element(), list_type->mutable_elem_type())); + + return absl::OkStatus(); +} + +absl::Status ToProtoTypeType(const cel::TypeType& type, + TypePb* absl_nonnull type_pb) { + if (type.GetParameters().size() > 1) { + return absl::InternalError( + absl::StrCat("unsupported type: ", type.DebugString())); + } + auto* type_type = type_pb->mutable_type(); + if (type.GetParameters().empty()) { + return absl::OkStatus(); + } + CEL_RETURN_IF_ERROR(TypeToProtoInternal(type.GetParameters()[0], type_type)); + return absl::OkStatus(); +} + +absl::Status TypeToProtoInternal(const cel::Type& type, + TypePb* absl_nonnull type_pb) { + switch (type.kind()) { + case TypeKind::kDyn: + type_pb->mutable_dyn(); + return absl::OkStatus(); + case TypeKind::kError: + type_pb->mutable_error(); + return absl::OkStatus(); + case TypeKind::kNull: + type_pb->set_null(NullValue::NULL_VALUE); + return absl::OkStatus(); + case TypeKind::kBool: + type_pb->set_primitive(TypePb::BOOL); + return absl::OkStatus(); + case TypeKind::kInt: + type_pb->set_primitive(TypePb::INT64); + return absl::OkStatus(); + case TypeKind::kUint: + type_pb->set_primitive(TypePb::UINT64); + return absl::OkStatus(); + case TypeKind::kDouble: + type_pb->set_primitive(TypePb::DOUBLE); + return absl::OkStatus(); + case TypeKind::kString: + type_pb->set_primitive(TypePb::STRING); + return absl::OkStatus(); + case TypeKind::kBytes: + type_pb->set_primitive(TypePb::BYTES); + return absl::OkStatus(); + case TypeKind::kEnum: + type_pb->set_primitive(TypePb::INT64); + return absl::OkStatus(); + case TypeKind::kDuration: + type_pb->set_well_known(TypePb::DURATION); + return absl::OkStatus(); + case TypeKind::kTimestamp: + type_pb->set_well_known(TypePb::TIMESTAMP); + return absl::OkStatus(); + case TypeKind::kStruct: + type_pb->set_message_type(type.GetStruct().name()); + return absl::OkStatus(); + case TypeKind::kList: + return ToProtoListType(type.GetList(), type_pb); + case TypeKind::kMap: + return ToProtoMapType(type.GetMap(), type_pb); + case TypeKind::kOpaque: + return ToProtoAbstractType(type.GetOpaque(), type_pb); + case TypeKind::kBoolWrapper: + type_pb->set_wrapper(TypePb::BOOL); + return absl::OkStatus(); + case TypeKind::kIntWrapper: + type_pb->set_wrapper(TypePb::INT64); + return absl::OkStatus(); + case TypeKind::kUintWrapper: + type_pb->set_wrapper(TypePb::UINT64); + return absl::OkStatus(); + case TypeKind::kDoubleWrapper: + type_pb->set_wrapper(TypePb::DOUBLE); + return absl::OkStatus(); + case TypeKind::kStringWrapper: + type_pb->set_wrapper(TypePb::STRING); + return absl::OkStatus(); + case TypeKind::kBytesWrapper: + type_pb->set_wrapper(TypePb::BYTES); + return absl::OkStatus(); + case TypeKind::kTypeParam: + type_pb->set_type_param(type.GetTypeParam().name()); + return absl::OkStatus(); + case TypeKind::kType: + return ToProtoTypeType(type.GetType(), type_pb); + case TypeKind::kAny: + type_pb->set_well_known(TypePb::ANY); + return absl::OkStatus(); + default: + return absl::InternalError( + absl::StrCat("unsupported type: ", type.DebugString())); + } +} + +} // namespace + +absl::StatusOr TypeFromProto( + const cel::expr::Type& type_pb, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::Arena* absl_nonnull arena) { + switch (type_pb.type_kind_case()) { + case TypePb::kAbstractType: { + auto* name = google::protobuf::Arena::Create( + arena, type_pb.abstract_type().name()); + std::vector params; + params.resize(type_pb.abstract_type().parameter_types_size()); + size_t i = 0; + for (const auto& p : type_pb.abstract_type().parameter_types()) { + CEL_ASSIGN_OR_RETURN(params[i], + TypeFromProto(p, descriptor_pool, arena)); + i++; + } + return OpaqueType(arena, *name, params); + } + case TypePb::kDyn: + return DynType(); + case TypePb::kError: + return ErrorType(); + case TypePb::kListType: { + CEL_ASSIGN_OR_RETURN(Type element, + TypeFromProto(type_pb.list_type().elem_type(), + descriptor_pool, arena)); + return ListType(arena, element); + } + case TypePb::kMapType: { + CEL_ASSIGN_OR_RETURN( + Type key, + TypeFromProto(type_pb.map_type().key_type(), descriptor_pool, arena)); + CEL_ASSIGN_OR_RETURN(Type value, + TypeFromProto(type_pb.map_type().value_type(), + descriptor_pool, arena)); + return MapType(arena, key, value); + } + case TypePb::kMessageType: { + if (auto well_known = MaybeWellKnownType(type_pb.message_type()); + well_known.has_value()) { + return *well_known; + } + + const auto* descriptor = + descriptor_pool->FindMessageTypeByName(type_pb.message_type()); + if (descriptor == nullptr) { + return absl::InvalidArgumentError( + absl::StrCat("unknown message type: ", type_pb.message_type())); + } + return MessageType(descriptor); + } + case TypePb::kNull: + return NullType(); + case TypePb::kPrimitive: + switch (type_pb.primitive()) { + case TypePb::BOOL: + return BoolType(); + case TypePb::BYTES: + return BytesType(); + case TypePb::DOUBLE: + return DoubleType(); + case TypePb::INT64: + return IntType(); + case TypePb::STRING: + return StringType(); + case TypePb::UINT64: + return UintType(); + default: + return absl::InvalidArgumentError("unknown primitive kind"); + } + case TypePb::kType: { + CEL_ASSIGN_OR_RETURN( + Type nested, TypeFromProto(type_pb.type(), descriptor_pool, arena)); + return TypeType(arena, nested); + } + case TypePb::kTypeParam: { + auto* name = + google::protobuf::Arena::Create(arena, type_pb.type_param()); + return TypeParamType(*name); + } + case TypePb::kWellKnown: + switch (type_pb.well_known()) { + case TypePb::ANY: + return AnyType(); + case TypePb::DURATION: + return DurationType(); + case TypePb::TIMESTAMP: + return TimestampType(); + default: + break; + } + return absl::InvalidArgumentError("unknown well known type."); + case TypePb::kWrapper: { + switch (type_pb.wrapper()) { + case TypePb::BOOL: + return BoolWrapperType(); + case TypePb::BYTES: + return BytesWrapperType(); + case TypePb::DOUBLE: + return DoubleWrapperType(); + case TypePb::INT64: + return IntWrapperType(); + case TypePb::STRING: + return StringWrapperType(); + case TypePb::UINT64: + return UintWrapperType(); + default: + return absl::InvalidArgumentError("unknown primitive wrapper kind"); + } + } + // Function types are not supported in the C++ type checker. + case TypePb::kFunction: + default: + return absl::InvalidArgumentError( + absl::StrCat("unsupported type kind: ", type_pb.type_kind_case())); + } +} + +absl::Status TypeToProto(const Type& type, TypePb* absl_nonnull type_pb) { + return TypeToProtoInternal(type, type_pb); +} + +} // namespace cel diff --git a/common/type_proto.h b/common/type_proto.h new file mode 100644 index 000000000..4336c1da2 --- /dev/null +++ b/common/type_proto.h @@ -0,0 +1,39 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPE_PROTO_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_TYPE_PROTO_H_ + +#include "cel/expr/checked.pb.h" +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "common/type.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +// Creates a Type from a google.api.expr.Type proto. +absl::StatusOr TypeFromProto( + const cel::expr::Type& type_pb, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::Arena* absl_nonnull arena); + +absl::Status TypeToProto(const Type& type, + cel::expr::Type* absl_nonnull type_pb); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPE_PROTO_H_ diff --git a/common/type_proto_test.cc b/common/type_proto_test.cc new file mode 100644 index 000000000..5cb81824e --- /dev/null +++ b/common/type_proto_test.cc @@ -0,0 +1,267 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/type_proto.h" + +#include + +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "common/type.h" +#include "common/type_kind.h" +#include "internal/proto_matchers.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/text_format.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::cel::internal::test::EqualsProto; + +enum class RoundTrip { + kYes, + kNo, +}; + +struct TestCase { + std::string type_pb; + absl::StatusOr type_kind; + RoundTrip round_trip = RoundTrip::kYes; +}; + +class TypeFromProtoTest : public ::testing::TestWithParam {}; + +TEST_P(TypeFromProtoTest, FromProtoWorks) { + const google::protobuf::DescriptorPool* descriptor_pool = + internal::GetTestingDescriptorPool(); + google::protobuf::Arena arena; + + const TestCase& test_case = GetParam(); + cel::expr::Type type_pb; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(test_case.type_pb, &type_pb)); + absl::StatusOr result = TypeFromProto(type_pb, descriptor_pool, &arena); + + if (test_case.type_kind.ok()) { + ASSERT_OK_AND_ASSIGN(Type type, result); + + EXPECT_EQ(type.kind(), *test_case.type_kind) + << absl::StrCat("got: ", type.DebugString(), + " want: ", TypeKindToString(*test_case.type_kind)); + } else { + EXPECT_THAT(result, StatusIs(test_case.type_kind.status().code())); + } +} + +TEST_P(TypeFromProtoTest, RoundTripProtoWorks) { + const google::protobuf::DescriptorPool* descriptor_pool = + internal::GetTestingDescriptorPool(); + google::protobuf::Arena arena; + + const TestCase& test_case = GetParam(); + if (!test_case.type_kind.ok() || test_case.round_trip == RoundTrip::kNo) { + return GTEST_SUCCEED(); + } + cel::expr::Type type_pb; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(test_case.type_pb, &type_pb)); + absl::StatusOr result = TypeFromProto(type_pb, descriptor_pool, &arena); + + ASSERT_THAT(test_case.type_kind, IsOk()); + ASSERT_OK_AND_ASSIGN(Type type, result); + + EXPECT_EQ(type.kind(), *test_case.type_kind) + << absl::StrCat("got: ", type.DebugString(), + " want: ", TypeKindToString(*test_case.type_kind)); + cel::expr::Type round_trip_pb; + ASSERT_THAT(TypeToProto(type, &round_trip_pb), IsOk()); + EXPECT_THAT(round_trip_pb, EqualsProto(type_pb)); +} + +INSTANTIATE_TEST_SUITE_P( + TypeFromProtoTest, TypeFromProtoTest, + testing::Values( + TestCase{ + R"pb( + abstract_type { + name: "foo" + parameter_types { primitive: INT64 } + parameter_types { primitive: STRING } + } + )pb", + TypeKind::kOpaque}, + TestCase{R"pb( + dyn {} + )pb", + TypeKind::kDyn}, + TestCase{R"pb( + error {} + )pb", + TypeKind::kError}, + TestCase{R"pb( + list_type { elem_type { primitive: INT64 } } + )pb", + TypeKind::kList}, + TestCase{R"pb( + map_type { + key_type { primitive: INT64 } + value_type { primitive: STRING } + } + )pb", + TypeKind::kMap}, + TestCase{R"pb( + message_type: "google.api.expr.runtime.TestExtensions" + )pb", + TypeKind::kMessage}, + TestCase{R"pb( + message_type: "com.example.UnknownMessage" + )pb", + absl::InvalidArgumentError("")}, + // Special-case well known types referenced by + // equivalent proto message types. + TestCase{R"pb( + message_type: "google.protobuf.Any" + )pb", + TypeKind::kAny, RoundTrip::kNo}, + TestCase{R"pb( + message_type: "google.protobuf.Timestamp" + )pb", + TypeKind::kTimestamp, RoundTrip::kNo}, + TestCase{R"pb( + message_type: "google.protobuf.Duration" + )pb", + TypeKind::kDuration, RoundTrip::kNo}, + TestCase{R"pb( + message_type: "google.protobuf.Struct" + )pb", + TypeKind::kMap, RoundTrip::kNo}, + TestCase{R"pb( + message_type: "google.protobuf.ListValue" + )pb", + TypeKind::kList, RoundTrip::kNo}, + TestCase{R"pb( + message_type: "google.protobuf.Value" + )pb", + TypeKind::kDyn, RoundTrip::kNo}, + TestCase{R"pb( + message_type: "google.protobuf.Int64Value" + )pb", + TypeKind::kIntWrapper, RoundTrip::kNo}, + TestCase{R"pb( + null: 0 + )pb", + TypeKind::kNull}, + TestCase{ + R"pb( + primitive: BOOL)pb", + TypeKind::kBool}, + TestCase{ + R"pb( + primitive: BYTES)pb", + TypeKind::kBytes}, + TestCase{ + R"pb( + primitive: DOUBLE)pb", + TypeKind::kDouble}, + TestCase{ + R"pb( + primitive: INT64)pb", + TypeKind::kInt}, + TestCase{ + R"pb( + primitive: STRING)pb", + TypeKind::kString}, + TestCase{ + R"pb( + primitive: UINT64)pb", + TypeKind::kUint}, + TestCase{ + R"pb( + primitive: PRIMITIVE_TYPE_UNSPECIFIED)pb", + absl::InvalidArgumentError("")}, + TestCase{ + R"pb( + type { type { primitive: UINT64 } })pb", + TypeKind::kType}, + TestCase{ + R"pb( + type_param: "T")pb", + TypeKind::kTypeParam}, + TestCase{ + R"pb( + well_known: ANY)pb", + TypeKind::kAny}, + TestCase{ + R"pb( + well_known: TIMESTAMP)pb", + TypeKind::kTimestamp}, + TestCase{ + R"pb( + well_known: DURATION)pb", + TypeKind::kDuration}, + TestCase{ + R"pb( + well_known: WELL_KNOWN_TYPE_UNSPECIFIED)pb", + absl::InvalidArgumentError("")}, + TestCase{ + R"pb( + wrapper: BOOL + )pb", + TypeKind::kBoolWrapper}, + TestCase{ + R"pb( + wrapper: BYTES + )pb", + TypeKind::kBytesWrapper}, + TestCase{ + R"pb( + wrapper: DOUBLE + )pb", + TypeKind::kDoubleWrapper}, + TestCase{ + R"pb( + wrapper: INT64 + )pb", + TypeKind::kIntWrapper}, + TestCase{ + R"pb( + wrapper: STRING + )pb", + TypeKind::kStringWrapper}, + TestCase{ + R"pb( + wrapper: UINT64 + )pb", + TypeKind::kUintWrapper}, + TestCase{ + R"pb( + wrapper: PRIMITIVE_TYPE_UNSPECIFIED + )pb", + absl::InvalidArgumentError("")}, + TestCase{ + R"pb( + function { + result_type { primitive: BOOL } + arg_types { primitive: INT64 } + arg_types { primitive: STRING } + })pb", + absl::InvalidArgumentError("")})); + +} // namespace +} // namespace cel diff --git a/common/type_reflector.h b/common/type_reflector.h new file mode 100644 index 000000000..8378ed36c --- /dev/null +++ b/common/type_reflector.h @@ -0,0 +1,43 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPE_REFLECTOR_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_TYPE_REFLECTOR_H_ + +#include "absl/base/nullability.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "common/type_introspector.h" +#include "common/value.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/message.h" + +namespace cel { + +// `TypeReflector` is an interface for constructing new instances of types are +// runtime. It handles type reflection. +class TypeReflector : public virtual TypeIntrospector { + public: + // `NewValueBuilder` returns a new `ValueBuilder` for the corresponding type + // `name`. It is primarily used to handle wrapper types which sometimes show + // up literally in expressions. + virtual absl::StatusOr NewValueBuilder( + absl::string_view name, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const = 0; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPE_REFLECTOR_H_ diff --git a/common/type_reflector_test.cc b/common/type_reflector_test.cc new file mode 100644 index 000000000..d9c855e4b --- /dev/null +++ b/common/type_reflector_test.cc @@ -0,0 +1,588 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/time/time.h" +#include "absl/types/optional.h" +#include "common/casting.h" +#include "common/value.h" +#include "common/value_testing.h" +#include "common/values/list_value.h" +#include "common/values/value_builder.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "internal/testing_message_factory.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOkAndHolds; +using ::absl_testing::StatusIs; +using ::cel::test::ErrorValueIs; +using ::testing::Eq; +using ::testing::IsEmpty; +using ::testing::Not; +using ::testing::NotNull; +using ::testing::Optional; + +using TypeReflectorTest = common_internal::ValueTest<>; + +#define TYPE_REFLECTOR_NEW_LIST_VALUE_BUILDER_TEST(element_type) \ + TEST_F(TypeReflectorTest, NewListValueBuilder_##element_type) { \ + auto list_value_builder = NewListValueBuilder(arena()); \ + EXPECT_TRUE(list_value_builder->IsEmpty()); \ + EXPECT_EQ(list_value_builder->Size(), 0); \ + auto list_value = std::move(*list_value_builder).Build(); \ + EXPECT_THAT(list_value.IsEmpty(), IsOkAndHolds(true)); \ + EXPECT_THAT(list_value.Size(), IsOkAndHolds(0)); \ + EXPECT_EQ(list_value.DebugString(), "[]"); \ + } + +TYPE_REFLECTOR_NEW_LIST_VALUE_BUILDER_TEST(BoolType) +TYPE_REFLECTOR_NEW_LIST_VALUE_BUILDER_TEST(BytesType) +TYPE_REFLECTOR_NEW_LIST_VALUE_BUILDER_TEST(DoubleType) +TYPE_REFLECTOR_NEW_LIST_VALUE_BUILDER_TEST(DurationType) +TYPE_REFLECTOR_NEW_LIST_VALUE_BUILDER_TEST(IntType) +TYPE_REFLECTOR_NEW_LIST_VALUE_BUILDER_TEST(ListType) +TYPE_REFLECTOR_NEW_LIST_VALUE_BUILDER_TEST(MapType) +TYPE_REFLECTOR_NEW_LIST_VALUE_BUILDER_TEST(NullType) +TYPE_REFLECTOR_NEW_LIST_VALUE_BUILDER_TEST(OptionalType) +TYPE_REFLECTOR_NEW_LIST_VALUE_BUILDER_TEST(StringType) +TYPE_REFLECTOR_NEW_LIST_VALUE_BUILDER_TEST(TimestampType) +TYPE_REFLECTOR_NEW_LIST_VALUE_BUILDER_TEST(TypeType) +TYPE_REFLECTOR_NEW_LIST_VALUE_BUILDER_TEST(UintType) +TYPE_REFLECTOR_NEW_LIST_VALUE_BUILDER_TEST(DynType) + +#undef TYPE_REFLECTOR_NEW_LIST_VALUE_BUILDER_TEST + +#define TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(key_type, value_type) \ + TEST_F(TypeReflectorTest, NewMapValueBuilder_##key_type##_##value_type) { \ + auto map_value_builder = NewMapValueBuilder(arena()); \ + EXPECT_TRUE(map_value_builder->IsEmpty()); \ + EXPECT_EQ(map_value_builder->Size(), 0); \ + auto map_value = std::move(*map_value_builder).Build(); \ + EXPECT_THAT(map_value.IsEmpty(), IsOkAndHolds(true)); \ + EXPECT_THAT(map_value.Size(), IsOkAndHolds(0)); \ + EXPECT_EQ(map_value.DebugString(), "{}"); \ + } + +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(BoolType, BoolType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(BoolType, BytesType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(BoolType, DoubleType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(BoolType, DurationType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(BoolType, IntType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(BoolType, ListType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(BoolType, MapType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(BoolType, NullType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(BoolType, OptionalType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(BoolType, StringType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(BoolType, TimestampType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(BoolType, TypeType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(BoolType, UintType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(BoolType, DynType) + +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(IntType, BoolType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(IntType, BytesType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(IntType, DoubleType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(IntType, DurationType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(IntType, IntType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(IntType, ListType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(IntType, MapType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(IntType, NullType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(IntType, OptionalType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(IntType, StringType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(IntType, TimestampType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(IntType, TypeType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(IntType, UintType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(IntType, DynType) + +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(UintType, BoolType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(UintType, BytesType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(UintType, DoubleType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(UintType, DurationType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(UintType, IntType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(UintType, ListType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(UintType, MapType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(UintType, NullType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(UintType, OptionalType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(UintType, StringType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(UintType, TimestampType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(UintType, TypeType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(UintType, UintType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(UintType, DynType) + +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(StringType, BoolType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(StringType, BytesType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(StringType, DoubleType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(StringType, DurationType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(StringType, IntType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(StringType, ListType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(StringType, MapType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(StringType, NullType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(StringType, OptionalType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(StringType, StringType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(StringType, TimestampType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(StringType, TypeType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(StringType, UintType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(StringType, DynType) + +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(DynType, BoolType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(DynType, BytesType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(DynType, DoubleType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(DynType, DurationType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(DynType, IntType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(DynType, ListType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(DynType, MapType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(DynType, NullType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(DynType, OptionalType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(DynType, StringType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(DynType, TimestampType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(DynType, TypeType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(DynType, UintType) +TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(DynType, DynType) + +#undef TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST + +TEST_F(TypeReflectorTest, NewListValueBuilderCoverage_Dynamic) { + auto builder = NewListValueBuilder(arena()); + EXPECT_OK(builder->Add(IntValue(0))); + EXPECT_OK(builder->Add(IntValue(1))); + EXPECT_OK(builder->Add(IntValue(2))); + EXPECT_EQ(builder->Size(), 3); + EXPECT_FALSE(builder->IsEmpty()); + auto value = std::move(*builder).Build(); + EXPECT_EQ(value.DebugString(), "[0, 1, 2]"); +} + +TEST_F(TypeReflectorTest, NewMapValueBuilderCoverage_DynamicDynamic) { + auto builder = NewMapValueBuilder(arena()); + EXPECT_OK(builder->Put(BoolValue(false), IntValue(1))); + EXPECT_OK(builder->Put(BoolValue(true), IntValue(2))); + EXPECT_OK(builder->Put(IntValue(0), IntValue(3))); + EXPECT_OK(builder->Put(IntValue(1), IntValue(4))); + EXPECT_OK(builder->Put(UintValue(0), IntValue(5))); + EXPECT_OK(builder->Put(UintValue(1), IntValue(6))); + EXPECT_OK(builder->Put(StringValue("a"), IntValue(7))); + EXPECT_OK(builder->Put(StringValue("b"), IntValue(8))); + EXPECT_EQ(builder->Size(), 8); + EXPECT_FALSE(builder->IsEmpty()); + auto value = std::move(*builder).Build(); + EXPECT_THAT(value.DebugString(), Not(IsEmpty())); +} + +TEST_F(TypeReflectorTest, NewMapValueBuilderCoverage_StaticDynamic) { + auto builder = NewMapValueBuilder(arena()); + EXPECT_OK(builder->Put(BoolValue(true), IntValue(0))); + EXPECT_EQ(builder->Size(), 1); + EXPECT_FALSE(builder->IsEmpty()); + auto value = std::move(*builder).Build(); + EXPECT_EQ(value.DebugString(), "{true: 0}"); +} + +TEST_F(TypeReflectorTest, NewMapValueBuilderCoverage_DynamicStatic) { + auto builder = NewMapValueBuilder(arena()); + EXPECT_OK(builder->Put(BoolValue(true), IntValue(0))); + EXPECT_EQ(builder->Size(), 1); + EXPECT_FALSE(builder->IsEmpty()); + auto value = std::move(*builder).Build(); + EXPECT_EQ(value.DebugString(), "{true: 0}"); +} + +TEST_F(TypeReflectorTest, NewValueBuilder_BoolValue) { + auto builder = common_internal::NewValueBuilder( + arena(), internal::GetTestingDescriptorPool(), + internal::GetTestingMessageFactory(), "google.protobuf.BoolValue"); + ASSERT_THAT(builder, NotNull()); + EXPECT_THAT(builder->SetFieldByName("value", BoolValue(true)), + IsOkAndHolds(Eq(std::nullopt))); + EXPECT_THAT(builder->SetFieldByName("does_not_exist", BoolValue(true)), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); + EXPECT_THAT(builder->SetFieldByName("value", IntValue(1)), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); + EXPECT_THAT(builder->SetFieldByNumber(1, BoolValue(true)), + IsOkAndHolds(Eq(std::nullopt))); + EXPECT_THAT(builder->SetFieldByNumber(2, BoolValue(true)), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); + EXPECT_THAT(builder->SetFieldByNumber(1, IntValue(1)), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); + ASSERT_OK_AND_ASSIGN(auto value, std::move(*builder).Build()); + EXPECT_TRUE(InstanceOf(value)); + EXPECT_EQ(Cast(value).NativeValue(), true); +} + +TEST_F(TypeReflectorTest, NewValueBuilder_Int32Value) { + auto builder = common_internal::NewValueBuilder( + arena(), internal::GetTestingDescriptorPool(), + internal::GetTestingMessageFactory(), "google.protobuf.Int32Value"); + ASSERT_THAT(builder, NotNull()); + EXPECT_THAT(builder->SetFieldByName("value", IntValue(1)), + IsOkAndHolds(Eq(std::nullopt))); + EXPECT_THAT(builder->SetFieldByName("does_not_exist", IntValue(1)), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); + EXPECT_THAT(builder->SetFieldByName("value", BoolValue(true)), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); + EXPECT_THAT(builder->SetFieldByName( + "value", IntValue(std::numeric_limits::max())), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kOutOfRange))))); + EXPECT_THAT(builder->SetFieldByNumber(1, IntValue(1)), + IsOkAndHolds(Eq(std::nullopt))); + EXPECT_THAT(builder->SetFieldByNumber(2, IntValue(1)), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); + EXPECT_THAT(builder->SetFieldByNumber(1, BoolValue(true)), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); + EXPECT_THAT(builder->SetFieldByNumber( + 1, IntValue(std::numeric_limits::max())), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kOutOfRange))))); + ASSERT_OK_AND_ASSIGN(auto value, std::move(*builder).Build()); + EXPECT_TRUE(InstanceOf(value)); + EXPECT_EQ(Cast(value).NativeValue(), 1); +} + +TEST_F(TypeReflectorTest, NewValueBuilder_Int64Value) { + auto builder = common_internal::NewValueBuilder( + arena(), internal::GetTestingDescriptorPool(), + internal::GetTestingMessageFactory(), "google.protobuf.Int64Value"); + ASSERT_THAT(builder, NotNull()); + EXPECT_THAT(builder->SetFieldByName("value", IntValue(1)), + IsOkAndHolds(Eq(std::nullopt))); + EXPECT_THAT(builder->SetFieldByName("does_not_exist", IntValue(1)), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); + EXPECT_THAT(builder->SetFieldByName("value", BoolValue(true)), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); + EXPECT_THAT(builder->SetFieldByNumber(1, IntValue(1)), + IsOkAndHolds(Eq(std::nullopt))); + EXPECT_THAT(builder->SetFieldByNumber(2, IntValue(1)), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); + EXPECT_THAT(builder->SetFieldByNumber(1, BoolValue(true)), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); + ASSERT_OK_AND_ASSIGN(auto value, std::move(*builder).Build()); + EXPECT_TRUE(InstanceOf(value)); + EXPECT_EQ(Cast(value).NativeValue(), 1); +} + +TEST_F(TypeReflectorTest, NewValueBuilder_UInt32Value) { + auto builder = common_internal::NewValueBuilder( + arena(), internal::GetTestingDescriptorPool(), + internal::GetTestingMessageFactory(), "google.protobuf.UInt32Value"); + ASSERT_THAT(builder, NotNull()); + EXPECT_THAT(builder->SetFieldByName("value", UintValue(1)), + IsOkAndHolds(Eq(std::nullopt))); + EXPECT_THAT(builder->SetFieldByName("does_not_exist", UintValue(1)), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); + EXPECT_THAT(builder->SetFieldByName("value", BoolValue(true)), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); + EXPECT_THAT(builder->SetFieldByName( + "value", UintValue(std::numeric_limits::max())), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kOutOfRange))))); + EXPECT_THAT(builder->SetFieldByNumber(1, UintValue(1)), + IsOkAndHolds(Eq(std::nullopt))); + EXPECT_THAT(builder->SetFieldByNumber(2, UintValue(1)), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); + EXPECT_THAT(builder->SetFieldByNumber(1, BoolValue(true)), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); + EXPECT_THAT(builder->SetFieldByNumber( + 1, UintValue(std::numeric_limits::max())), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kOutOfRange))))); + ASSERT_OK_AND_ASSIGN(auto value, std::move(*builder).Build()); + EXPECT_TRUE(InstanceOf(value)); + EXPECT_EQ(Cast(value).NativeValue(), 1); +} + +TEST_F(TypeReflectorTest, NewValueBuilder_UInt64Value) { + auto builder = common_internal::NewValueBuilder( + arena(), internal::GetTestingDescriptorPool(), + internal::GetTestingMessageFactory(), "google.protobuf.UInt64Value"); + ASSERT_THAT(builder, NotNull()); + EXPECT_THAT(builder->SetFieldByName("value", UintValue(1)), + IsOkAndHolds(Eq(std::nullopt))); + EXPECT_THAT(builder->SetFieldByName("does_not_exist", UintValue(1)), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); + EXPECT_THAT(builder->SetFieldByName("value", BoolValue(true)), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); + EXPECT_THAT(builder->SetFieldByNumber(1, UintValue(1)), + IsOkAndHolds(Eq(std::nullopt))); + EXPECT_THAT(builder->SetFieldByNumber(2, UintValue(1)), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); + EXPECT_THAT(builder->SetFieldByNumber(1, BoolValue(true)), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); + ASSERT_OK_AND_ASSIGN(auto value, std::move(*builder).Build()); + EXPECT_TRUE(InstanceOf(value)); + EXPECT_EQ(Cast(value).NativeValue(), 1); +} + +TEST_F(TypeReflectorTest, NewValueBuilder_FloatValue) { + auto builder = common_internal::NewValueBuilder( + arena(), internal::GetTestingDescriptorPool(), + internal::GetTestingMessageFactory(), "google.protobuf.FloatValue"); + ASSERT_THAT(builder, NotNull()); + EXPECT_THAT(builder->SetFieldByName("value", DoubleValue(1)), + IsOkAndHolds(Eq(std::nullopt))); + EXPECT_THAT(builder->SetFieldByName("does_not_exist", DoubleValue(1)), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); + EXPECT_THAT(builder->SetFieldByName("value", BoolValue(true)), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); + EXPECT_THAT(builder->SetFieldByNumber(1, DoubleValue(1)), + IsOkAndHolds(Eq(std::nullopt))); + EXPECT_THAT(builder->SetFieldByNumber(2, DoubleValue(1)), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); + EXPECT_THAT(builder->SetFieldByNumber(1, BoolValue(true)), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); + ASSERT_OK_AND_ASSIGN(auto value, std::move(*builder).Build()); + EXPECT_TRUE(InstanceOf(value)); + EXPECT_EQ(Cast(value).NativeValue(), 1); +} + +TEST_F(TypeReflectorTest, NewValueBuilder_DoubleValue) { + auto builder = common_internal::NewValueBuilder( + arena(), internal::GetTestingDescriptorPool(), + internal::GetTestingMessageFactory(), "google.protobuf.DoubleValue"); + ASSERT_THAT(builder, NotNull()); + EXPECT_THAT(builder->SetFieldByName("value", DoubleValue(1)), + IsOkAndHolds(Eq(std::nullopt))); + EXPECT_THAT(builder->SetFieldByName("does_not_exist", DoubleValue(1)), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); + EXPECT_THAT(builder->SetFieldByName("value", BoolValue(true)), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); + EXPECT_THAT(builder->SetFieldByNumber(1, DoubleValue(1)), + IsOkAndHolds(Eq(std::nullopt))); + EXPECT_THAT(builder->SetFieldByNumber(2, DoubleValue(1)), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); + EXPECT_THAT(builder->SetFieldByNumber(1, BoolValue(true)), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); + ASSERT_OK_AND_ASSIGN(auto value, std::move(*builder).Build()); + EXPECT_TRUE(InstanceOf(value)); + EXPECT_EQ(Cast(value).NativeValue(), 1); +} + +TEST_F(TypeReflectorTest, NewValueBuilder_StringValue) { + auto builder = common_internal::NewValueBuilder( + arena(), internal::GetTestingDescriptorPool(), + internal::GetTestingMessageFactory(), "google.protobuf.StringValue"); + ASSERT_THAT(builder, NotNull()); + EXPECT_THAT(builder->SetFieldByName("value", StringValue("foo")), + IsOkAndHolds(Eq(std::nullopt))); + EXPECT_THAT(builder->SetFieldByName("does_not_exist", StringValue("foo")), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); + EXPECT_THAT(builder->SetFieldByName("value", BoolValue(true)), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); + EXPECT_THAT(builder->SetFieldByNumber(1, StringValue("foo")), + IsOkAndHolds(Eq(std::nullopt))); + EXPECT_THAT(builder->SetFieldByNumber(2, StringValue("foo")), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); + EXPECT_THAT(builder->SetFieldByNumber(1, BoolValue(true)), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); + ASSERT_OK_AND_ASSIGN(auto value, std::move(*builder).Build()); + EXPECT_TRUE(InstanceOf(value)); + EXPECT_EQ(Cast(value).NativeString(), "foo"); +} + +TEST_F(TypeReflectorTest, NewValueBuilder_BytesValue) { + auto builder = common_internal::NewValueBuilder( + arena(), internal::GetTestingDescriptorPool(), + internal::GetTestingMessageFactory(), "google.protobuf.BytesValue"); + ASSERT_THAT(builder, NotNull()); + EXPECT_THAT(builder->SetFieldByName("value", BytesValue("foo")), + IsOkAndHolds(Eq(std::nullopt))); + EXPECT_THAT(builder->SetFieldByName("does_not_exist", BytesValue("foo")), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); + EXPECT_THAT(builder->SetFieldByName("value", BoolValue(true)), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); + EXPECT_THAT(builder->SetFieldByNumber(1, BytesValue("foo")), + IsOkAndHolds(Eq(std::nullopt))); + EXPECT_THAT(builder->SetFieldByNumber(2, BytesValue("foo")), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); + EXPECT_THAT(builder->SetFieldByNumber(1, BoolValue(true)), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); + ASSERT_OK_AND_ASSIGN(auto value, std::move(*builder).Build()); + EXPECT_TRUE(InstanceOf(value)); + EXPECT_EQ(Cast(value).NativeString(), "foo"); +} + +TEST_F(TypeReflectorTest, NewValueBuilder_Duration) { + auto builder = common_internal::NewValueBuilder( + arena(), internal::GetTestingDescriptorPool(), + internal::GetTestingMessageFactory(), "google.protobuf.Duration"); + ASSERT_THAT(builder, NotNull()); + EXPECT_THAT(builder->SetFieldByName("seconds", IntValue(1)), + IsOkAndHolds(Eq(std::nullopt))); + EXPECT_THAT(builder->SetFieldByName("does_not_exist", IntValue(1)), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); + EXPECT_THAT(builder->SetFieldByName("seconds", BoolValue(true)), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); + EXPECT_THAT(builder->SetFieldByName("nanos", IntValue(1)), + IsOkAndHolds(Eq(std::nullopt))); + EXPECT_THAT(builder->SetFieldByName( + "nanos", IntValue(std::numeric_limits::max())), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kOutOfRange))))); + EXPECT_THAT(builder->SetFieldByName("nanos", BoolValue(true)), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); + EXPECT_THAT(builder->SetFieldByNumber(1, IntValue(1)), + IsOkAndHolds(Eq(std::nullopt))); + EXPECT_THAT(builder->SetFieldByNumber(3, IntValue(1)), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); + EXPECT_THAT(builder->SetFieldByNumber(1, BoolValue(true)), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); + EXPECT_THAT(builder->SetFieldByNumber(2, IntValue(1)), + IsOkAndHolds(Eq(std::nullopt))); + EXPECT_THAT(builder->SetFieldByNumber( + 2, IntValue(std::numeric_limits::max())), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kOutOfRange))))); + EXPECT_THAT(builder->SetFieldByNumber(2, BoolValue(true)), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); + ASSERT_OK_AND_ASSIGN(auto value, std::move(*builder).Build()); + EXPECT_TRUE(InstanceOf(value)); + EXPECT_EQ(Cast(value).NativeValue(), + absl::Seconds(1) + absl::Nanoseconds(1)); +} + +TEST_F(TypeReflectorTest, NewValueBuilder_Timestamp) { + auto builder = common_internal::NewValueBuilder( + arena(), internal::GetTestingDescriptorPool(), + internal::GetTestingMessageFactory(), "google.protobuf.Timestamp"); + ASSERT_THAT(builder, NotNull()); + EXPECT_THAT(builder->SetFieldByName("seconds", IntValue(1)), + IsOkAndHolds(Eq(std::nullopt))); + EXPECT_THAT(builder->SetFieldByName("does_not_exist", IntValue(1)), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); + EXPECT_THAT(builder->SetFieldByName("seconds", BoolValue(true)), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); + EXPECT_THAT(builder->SetFieldByName("nanos", IntValue(1)), + IsOkAndHolds(Eq(std::nullopt))); + EXPECT_THAT(builder->SetFieldByName( + "nanos", IntValue(std::numeric_limits::max())), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kOutOfRange))))); + EXPECT_THAT(builder->SetFieldByName("nanos", BoolValue(true)), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); + EXPECT_THAT(builder->SetFieldByNumber(1, IntValue(1)), + IsOkAndHolds(Eq(std::nullopt))); + EXPECT_THAT(builder->SetFieldByNumber(3, IntValue(1)), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); + EXPECT_THAT(builder->SetFieldByNumber(1, BoolValue(true)), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); + EXPECT_THAT(builder->SetFieldByNumber(2, IntValue(1)), + IsOkAndHolds(Eq(std::nullopt))); + EXPECT_THAT(builder->SetFieldByNumber( + 2, IntValue(std::numeric_limits::max())), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kOutOfRange))))); + EXPECT_THAT(builder->SetFieldByNumber(2, BoolValue(true)), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); + ASSERT_OK_AND_ASSIGN(auto value, std::move(*builder).Build()); + EXPECT_TRUE(InstanceOf(value)); + EXPECT_EQ(Cast(value).NativeValue(), + absl::UnixEpoch() + absl::Seconds(1) + absl::Nanoseconds(1)); +} + +TEST_F(TypeReflectorTest, NewValueBuilder_Any) { + auto builder = common_internal::NewValueBuilder( + arena(), internal::GetTestingDescriptorPool(), + internal::GetTestingMessageFactory(), "google.protobuf.Any"); + ASSERT_THAT(builder, NotNull()); + EXPECT_THAT(builder->SetFieldByName( + "type_url", + StringValue("type.googleapis.com/google.protobuf.BoolValue")), + IsOkAndHolds(Eq(std::nullopt))); + EXPECT_THAT(builder->SetFieldByName("does_not_exist", IntValue(1)), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); + EXPECT_THAT(builder->SetFieldByName("type_url", BoolValue(true)), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); + EXPECT_THAT(builder->SetFieldByName("value", BytesValue()), + IsOkAndHolds(Eq(std::nullopt))); + EXPECT_THAT(builder->SetFieldByName("value", BoolValue(true)), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); + EXPECT_THAT( + builder->SetFieldByNumber( + 1, StringValue("type.googleapis.com/google.protobuf.BoolValue")), + IsOkAndHolds(Eq(std::nullopt))); + EXPECT_THAT(builder->SetFieldByNumber(3, IntValue(1)), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); + EXPECT_THAT(builder->SetFieldByNumber(1, BoolValue(true)), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); + EXPECT_THAT(builder->SetFieldByNumber(2, BytesValue()), + IsOkAndHolds(Eq(std::nullopt))); + EXPECT_THAT(builder->SetFieldByNumber(2, BoolValue(true)), + IsOkAndHolds(Optional( + ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); + ASSERT_OK_AND_ASSIGN(auto value, std::move(*builder).Build()); + EXPECT_TRUE(InstanceOf(value)); + EXPECT_EQ(Cast(value).NativeValue(), false); +} + +} // namespace +} // namespace cel diff --git a/common/type_spec_resolver.cc b/common/type_spec_resolver.cc new file mode 100644 index 000000000..90c9930a8 --- /dev/null +++ b/common/type_spec_resolver.cc @@ -0,0 +1,301 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/type_spec_resolver.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "common/ast.h" +#include "common/type.h" +#include "common/type_kind.h" +#include "internal/status_macros.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +absl::StatusOr ConvertTypeSpecToType(const TypeSpec& type_spec, + google::protobuf::Arena* arena, + const google::protobuf::DescriptorPool& pool) { + if (type_spec.has_null()) return Type(NullType{}); + if (type_spec.has_dyn()) return Type(DynType{}); + + if (type_spec.has_primitive()) { + switch (type_spec.primitive()) { + case PrimitiveType::kBool: + return Type(BoolType{}); + case PrimitiveType::kInt64: + return Type(IntType{}); + case PrimitiveType::kUint64: + return Type(UintType{}); + case PrimitiveType::kDouble: + return Type(DoubleType{}); + case PrimitiveType::kString: + return Type(StringType{}); + case PrimitiveType::kBytes: + return Type(BytesType{}); + default: + return absl::InvalidArgumentError("Unsupported primitive type"); + } + } + + if (type_spec.has_well_known()) { + switch (type_spec.well_known()) { + case WellKnownTypeSpec::kAny: + return Type(AnyType{}); + case WellKnownTypeSpec::kTimestamp: + return Type(TimestampType{}); + case WellKnownTypeSpec::kDuration: + return Type(DurationType{}); + default: + return absl::InvalidArgumentError("Unsupported well-known type"); + } + } + + if (type_spec.has_wrapper()) { + switch (type_spec.wrapper()) { + case PrimitiveType::kBool: + return Type(BoolWrapperType{}); + case PrimitiveType::kInt64: + return Type(IntWrapperType{}); + case PrimitiveType::kUint64: + return Type(UintWrapperType{}); + case PrimitiveType::kDouble: + return Type(DoubleWrapperType{}); + case PrimitiveType::kString: + return Type(StringWrapperType{}); + case PrimitiveType::kBytes: + return Type(BytesWrapperType{}); + default: + return absl::InvalidArgumentError("Unsupported wrapper type"); + } + } + + if (type_spec.has_list_type()) { + Type elem_type; + if (type_spec.list_type().elem_type().is_specified()) { + CEL_ASSIGN_OR_RETURN( + elem_type, ConvertTypeSpecToType(type_spec.list_type().elem_type(), + arena, pool)); + } + return Type(ListType(arena, elem_type)); + } + + if (type_spec.has_map_type()) { + Type key_type; + if (type_spec.map_type().key_type().is_specified()) { + CEL_ASSIGN_OR_RETURN( + key_type, + ConvertTypeSpecToType(type_spec.map_type().key_type(), arena, pool)); + } + + Type value_type; + if (type_spec.map_type().value_type().is_specified()) { + CEL_ASSIGN_OR_RETURN( + value_type, ConvertTypeSpecToType(type_spec.map_type().value_type(), + arena, pool)); + } + return Type(MapType(arena, key_type, value_type)); + } + + if (type_spec.has_function()) { + const auto& func_spec = type_spec.function(); + Type result_type; + if (func_spec.result_type().is_specified()) { + CEL_ASSIGN_OR_RETURN( + result_type, + ConvertTypeSpecToType(func_spec.result_type(), arena, pool)); + } + std::vector arg_types; + arg_types.reserve(func_spec.arg_types().size()); + for (const auto& arg_spec : func_spec.arg_types()) { + CEL_ASSIGN_OR_RETURN(auto arg_type, + ConvertTypeSpecToType(arg_spec, arena, pool)); + arg_types.push_back(std::move(arg_type)); + } + return Type(FunctionType(arena, result_type, arg_types)); + } + + if (type_spec.has_type_param()) { + const std::string& name = type_spec.type_param().type(); + auto* allocated_name = google::protobuf::Arena::Create(arena, name); + return Type(TypeParamType(absl::string_view(*allocated_name))); + } + + if (type_spec.has_message_type()) { + const std::string& name = type_spec.message_type().type(); + const google::protobuf::Descriptor* descriptor = pool.FindMessageTypeByName(name); + if (descriptor == nullptr) { + return absl::InvalidArgumentError(absl::StrCat( + "Message type '", name, "' not found in descriptor pool")); + } + return Type::Message(descriptor); + } + + if (type_spec.has_abstract_type()) { + const std::string& name = type_spec.abstract_type().name(); + + // Check if it's a message type in the pool + const google::protobuf::Descriptor* descriptor = pool.FindMessageTypeByName(name); + if (descriptor != nullptr) { + if (!type_spec.abstract_type().parameter_types().empty()) { + return absl::InvalidArgumentError(absl::StrCat( + "Message type '", name, "' cannot have type parameters")); + } + return Type::Message(descriptor); + } + + // Check if it's an enum type in the pool + const google::protobuf::EnumDescriptor* enum_descriptor = + pool.FindEnumTypeByName(name); + if (enum_descriptor != nullptr) { + if (!type_spec.abstract_type().parameter_types().empty()) { + return absl::InvalidArgumentError( + absl::StrCat("Enum type '", name, "' cannot have type parameters")); + } + return Type::Enum(enum_descriptor); + } + + // Otherwise fallback to OpaqueType + std::vector params; + for (const auto& param_spec : type_spec.abstract_type().parameter_types()) { + CEL_ASSIGN_OR_RETURN(auto param, + ConvertTypeSpecToType(param_spec, arena, pool)); + params.push_back(std::move(param)); + } + auto* allocated_name = google::protobuf::Arena::Create(arena, name); + return Type(OpaqueType(arena, absl::string_view(*allocated_name), params)); + } + + if (type_spec.has_type()) { + CEL_ASSIGN_OR_RETURN(auto contained_type, + ConvertTypeSpecToType(type_spec.type(), arena, pool)); + return Type(TypeType(arena, contained_type)); + } + + if (type_spec.has_error()) { + return Type(ErrorType{}); + } + + return absl::InvalidArgumentError("Unknown TypeSpec kind"); +} + +absl::StatusOr ConvertTypeToTypeSpec(const Type& type) { + switch (type.kind()) { + case TypeKind::kNull: + return TypeSpec(NullTypeSpec{}); + case TypeKind::kDyn: + return TypeSpec(DynTypeSpec{}); + case TypeKind::kBool: + return TypeSpec(PrimitiveType::kBool); + case TypeKind::kInt: + return TypeSpec(PrimitiveType::kInt64); + case TypeKind::kUint: + return TypeSpec(PrimitiveType::kUint64); + case TypeKind::kDouble: + return TypeSpec(PrimitiveType::kDouble); + case TypeKind::kString: + return TypeSpec(PrimitiveType::kString); + case TypeKind::kBytes: + return TypeSpec(PrimitiveType::kBytes); + case TypeKind::kAny: + return TypeSpec(WellKnownTypeSpec::kAny); + case TypeKind::kTimestamp: + return TypeSpec(WellKnownTypeSpec::kTimestamp); + case TypeKind::kDuration: + return TypeSpec(WellKnownTypeSpec::kDuration); + case TypeKind::kBoolWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBool)); + case TypeKind::kIntWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kInt64)); + case TypeKind::kUintWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kUint64)); + case TypeKind::kDoubleWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kDouble)); + case TypeKind::kStringWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kString)); + case TypeKind::kBytesWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBytes)); + case TypeKind::kList: { + CEL_ASSIGN_OR_RETURN(auto elem_type, + ConvertTypeToTypeSpec(type.GetList().element())); + return TypeSpec( + ListTypeSpec(std::make_unique(std::move(elem_type)))); + } + case TypeKind::kMap: { + CEL_ASSIGN_OR_RETURN(auto key_type, + ConvertTypeToTypeSpec(type.GetMap().key())); + CEL_ASSIGN_OR_RETURN(auto value_type, + ConvertTypeToTypeSpec(type.GetMap().value())); + return TypeSpec( + MapTypeSpec(std::make_unique(std::move(key_type)), + std::make_unique(std::move(value_type)))); + } + case TypeKind::kFunction: { + auto func_type = type.GetFunction(); + CEL_ASSIGN_OR_RETURN(auto result_type, + ConvertTypeToTypeSpec(func_type.result())); + std::vector arg_types; + arg_types.reserve(func_type.args().size()); + for (const auto& arg : func_type.args()) { + CEL_ASSIGN_OR_RETURN(auto arg_type, ConvertTypeToTypeSpec(arg)); + arg_types.push_back(std::move(arg_type)); + } + return TypeSpec( + FunctionTypeSpec(std::make_unique(std::move(result_type)), + std::move(arg_types))); + } + case TypeKind::kTypeParam: + return TypeSpec(ParamTypeSpec(std::string(type.GetTypeParam().name()))); + case TypeKind::kStruct: { + if (type.IsMessage()) { + return TypeSpec(MessageTypeSpec(std::string(type.GetMessage().name()))); + } + return absl::InvalidArgumentError("Unsupported struct type"); + } + case TypeKind::kOpaque: { + auto opaque_type = type.GetOpaque(); + std::vector params; + params.reserve(opaque_type.GetParameters().size()); + for (const auto& param : opaque_type.GetParameters()) { + CEL_ASSIGN_OR_RETURN(auto param_type, ConvertTypeToTypeSpec(param)); + params.push_back(std::move(param_type)); + } + return TypeSpec( + AbstractType(std::string(opaque_type.name()), std::move(params))); + } + case TypeKind::kType: { + CEL_ASSIGN_OR_RETURN(auto nested_type, + ConvertTypeToTypeSpec(type.GetType().GetType())); + return TypeSpec(std::make_unique(std::move(nested_type))); + } + case TypeKind::kError: + return TypeSpec(ErrorTypeSpec::kValue); + case TypeKind::kEnum: + return TypeSpec( + AbstractType(std::string(type.GetEnum().name()), /*params=*/{})); + default: + return absl::InvalidArgumentError(absl::StrCat( + "Unsupported Type kind: ", TypeKindToString(type.kind()))); + } +} + +} // namespace cel diff --git a/common/type_spec_resolver.h b/common/type_spec_resolver.h new file mode 100644 index 000000000..edbfa3bde --- /dev/null +++ b/common/type_spec_resolver.h @@ -0,0 +1,40 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPE_SPEC_RESOLVER_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_TYPE_SPEC_RESOLVER_H_ + +#include "absl/status/statusor.h" +#include "common/ast.h" +#include "common/type.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +// Resolves a `cel::TypeSpec` to a `cel::Type`. +// +// TypeSpec only specifies a type while Type provides support for inspecting +// properties of the type when used in CEL. Returns a status with code +// `InvalidArgument` if the input cannot be resolved to a type. +absl::StatusOr ConvertTypeSpecToType(const TypeSpec& type_spec, + google::protobuf::Arena* arena, + const google::protobuf::DescriptorPool& pool); + +// Resolves a `cel::Type` to a `cel::TypeSpec`. +absl::StatusOr ConvertTypeToTypeSpec(const Type& type); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPE_SPEC_RESOLVER_H_ diff --git a/common/type_spec_resolver_test.cc b/common/type_spec_resolver_test.cc new file mode 100644 index 000000000..1cda7280f --- /dev/null +++ b/common/type_spec_resolver_test.cc @@ -0,0 +1,284 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/type_spec_resolver.h" + +#include +#include +#include +#include +#include + +#include "absl/base/no_destructor.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "common/ast.h" +#include "common/type.h" +#include "common/type_kind.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "google/protobuf/arena.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::IsOkAndHolds; +using ::absl_testing::StatusIs; +using ::cel::internal::GetTestingDescriptorPool; +using ::testing::HasSubstr; +using ::testing::TestWithParam; +using ::testing::Values; + +google::protobuf::Arena* GetTestArena() { + static absl::NoDestructor arena; + return &*arena; +} + +TEST(TypeSpecResolverTest, NullTypeSpec) { + TypeSpec spec(NullTypeSpec{}); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsNull()); +} + +TEST(TypeSpecResolverTest, DynTypeSpec) { + TypeSpec spec(DynTypeSpec{}); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsDyn()); +} + +using ConversionTest = testing::TestWithParam>; + +TEST_P(ConversionTest, TestTypeSpecConversion) { + ASSERT_OK_AND_ASSIGN( + auto t, ConvertTypeSpecToType(std::get<0>(GetParam()), GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_EQ(t.kind(), std::get<1>(GetParam())); + EXPECT_THAT(ConvertTypeToTypeSpec(t), IsOkAndHolds(std::get<0>(GetParam()))); +} + +INSTANTIATE_TEST_SUITE_P( + TypeSpecResolverTest, ConversionTest, + testing::Values( + std::make_tuple(TypeSpec(PrimitiveType::kBool), TypeKind::kBool), + std::make_tuple(TypeSpec(PrimitiveType::kInt64), TypeKind::kInt), + std::make_tuple(TypeSpec(PrimitiveType::kUint64), TypeKind::kUint), + std::make_tuple(TypeSpec(PrimitiveType::kDouble), TypeKind::kDouble), + std::make_tuple(TypeSpec(PrimitiveType::kString), TypeKind::kString), + std::make_tuple(TypeSpec(PrimitiveType::kBytes), TypeKind::kBytes), + std::make_tuple(TypeSpec(WellKnownTypeSpec::kAny), TypeKind::kAny), + std::make_tuple(TypeSpec(WellKnownTypeSpec::kTimestamp), + TypeKind::kTimestamp), + std::make_tuple(TypeSpec(WellKnownTypeSpec::kDuration), + TypeKind::kDuration), + std::make_tuple(TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBool)), + TypeKind::kBoolWrapper), + std::make_tuple(TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kInt64)), + TypeKind::kIntWrapper), + std::make_tuple(TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kUint64)), + TypeKind::kUintWrapper), + std::make_tuple(TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kDouble)), + TypeKind::kDoubleWrapper), + std::make_tuple(TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kString)), + TypeKind::kStringWrapper), + std::make_tuple(TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBytes)), + TypeKind::kBytesWrapper))); + +TEST(TypeSpecResolverTest, ListTypeConversion) { + auto elem = std::make_unique(PrimitiveType::kInt64); + TypeSpec spec(ListTypeSpec(std::move(elem))); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsList()); + EXPECT_TRUE(t->GetList().element().IsInt()); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ(spec2, spec); +} + +TEST(TypeSpecResolverTest, MapTypeConversion) { + auto key = std::make_unique(PrimitiveType::kString); + auto val = std::make_unique(PrimitiveType::kBytes); + TypeSpec spec(MapTypeSpec(std::move(key), std::move(val))); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsMap()); + EXPECT_TRUE(t->GetMap().key().IsString()); + EXPECT_TRUE(t->GetMap().value().IsBytes()); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ(spec2, spec); +} + +TEST(TypeSpecResolverTest, FunctionTypeConversion) { + auto result = std::make_unique(PrimitiveType::kBool); + std::vector args; + args.push_back(TypeSpec(PrimitiveType::kString)); + TypeSpec spec(FunctionTypeSpec(std::move(result), std::move(args))); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsFunction()); + EXPECT_EQ(t->GetFunction().args().size(), 1); + EXPECT_TRUE(t->GetFunction().result().IsBool()); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ(spec2, spec); +} + +TEST(TypeSpecResolverTest, TypeParamConversion) { + TypeSpec spec(ParamTypeSpec("T")); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsTypeParam()); + EXPECT_EQ(t->GetTypeParam().name(), "T"); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ(spec2, spec); +} + +TEST(TypeSpecResolverTest, MessageTypeConversion) { + TypeSpec spec( + AbstractType("cel.expr.conformance.proto3.TestAllTypes", /*params=*/{})); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsMessage()); + EXPECT_EQ(t->name(), "cel.expr.conformance.proto3.TestAllTypes"); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ( + spec2, + TypeSpec(MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes"))); +} + +TEST(TypeSpecResolverTest, MessageTypeWithParamsError) { + std::vector params; + params.push_back(TypeSpec(PrimitiveType::kInt64)); + TypeSpec spec(AbstractType("cel.expr.conformance.proto3.TestAllTypes", + std::move(params))); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + EXPECT_THAT(t, StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("cannot have type parameters"))); +} + +TEST(TypeSpecResolverTest, UnresolvedAbstractTypeFallbackToOpaque) { + std::vector params; + params.push_back(TypeSpec(PrimitiveType::kInt64)); + TypeSpec spec(AbstractType("my.custom.OpaqueType", std::move(params))); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsOpaque()); + EXPECT_EQ(t->name(), "my.custom.OpaqueType"); + EXPECT_EQ(t->GetParameters().size(), 1); + EXPECT_TRUE(t->GetParameters()[0].IsInt()); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ(spec2, spec); +} + +TEST(TypeSpecResolverTest, OptionalType) { + std::vector params; + params.push_back(TypeSpec(PrimitiveType::kInt64)); + TypeSpec spec(AbstractType("optional_type", std::move(params))); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsOpaque()); + EXPECT_EQ(t->name(), "optional_type"); + EXPECT_EQ(t->GetParameters().size(), 1); + EXPECT_TRUE(t->GetParameters()[0].IsInt()); + EXPECT_TRUE(t->IsOptional()); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ(spec2, spec); +} + +TEST(TypeSpecResolverTest, TypeTypeConversion) { + auto nested = std::make_unique(PrimitiveType::kInt64); + TypeSpec spec(std::move(nested)); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsType()); + EXPECT_TRUE(t->GetType().GetType().IsInt()); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ(spec2, spec); +} + +TEST(TypeSpecResolverTest, ErrorTypeConversion) { + TypeSpec spec(ErrorTypeSpec::kValue); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsError()); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ(spec2, spec); +} + +TEST(TypeSpecResolverTest, MessageTypeSpecConversion) { + TypeSpec spec(MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsMessage()); + EXPECT_EQ(t->name(), "cel.expr.conformance.proto3.TestAllTypes"); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ(spec2, spec); +} + +TEST(TypeSpecResolverTest, MessageTypeSpecNotFoundError) { + TypeSpec spec(MessageTypeSpec("cel.expr.conformance.proto3.NonExistentType")); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + EXPECT_THAT(t, StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("not found in descriptor pool"))); +} + +TEST(TypeSpecResolverTest, EnumTypeConversion) { + TypeSpec spec(AbstractType( + "cel.expr.conformance.proto3.TestAllTypes.NestedEnum", /*params=*/{})); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsEnum()); + EXPECT_EQ(t->name(), "cel.expr.conformance.proto3.TestAllTypes.NestedEnum"); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ(spec2, spec); +} + +TEST(TypeSpecResolverTest, EnumTypeWithParamsError) { + std::vector params; + params.push_back(TypeSpec(PrimitiveType::kInt64)); + TypeSpec spec( + AbstractType("cel.expr.conformance.proto3.TestAllTypes.NestedEnum", + std::move(params))); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + EXPECT_THAT(t, StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("cannot have type parameters"))); +} + +TEST(TypeSpecResolverTest, UnknownTypeSpecKindError) { + TypeSpec spec; + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + EXPECT_THAT(t, StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Unknown TypeSpec kind"))); +} + +} // namespace +} // namespace cel diff --git a/common/type_test.cc b/common/type_test.cc new file mode 100644 index 000000000..d6a613c3c --- /dev/null +++ b/common/type_test.cc @@ -0,0 +1,676 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/type.h" + +#include "absl/hash/hash.h" +#include "absl/hash/hash_testing.h" +#include "absl/log/die_if_null.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "google/protobuf/arena.h" + +namespace cel { +namespace { + +using ::cel::internal::GetTestingDescriptorPool; +using ::testing::An; +using ::testing::ElementsAre; +using ::testing::IsEmpty; +using ::testing::Optional; + +TEST(Type, Default) { + EXPECT_EQ(Type(), DynType()); + EXPECT_TRUE(Type().IsDyn()); +} + +TEST(Type, Enum) { + EXPECT_EQ( + Type::Enum( + ABSL_DIE_IF_NULL(GetTestingDescriptorPool()->FindEnumTypeByName( + "cel.expr.conformance.proto3.TestAllTypes.NestedEnum"))), + EnumType(ABSL_DIE_IF_NULL(GetTestingDescriptorPool()->FindEnumTypeByName( + "cel.expr.conformance.proto3.TestAllTypes.NestedEnum")))); + EXPECT_EQ(Type::Enum( + ABSL_DIE_IF_NULL(GetTestingDescriptorPool()->FindEnumTypeByName( + "google.protobuf.NullValue"))), + IntType()); +} + +TEST(Type, Field) { + google::protobuf::Arena arena; + const auto* descriptor = + ABSL_DIE_IF_NULL(GetTestingDescriptorPool()->FindMessageTypeByName( + "cel.expr.conformance.proto3.TestAllTypes")); + EXPECT_EQ( + Type::Field(ABSL_DIE_IF_NULL(descriptor->FindFieldByName("single_bool"))), + BoolType()); + EXPECT_EQ( + Type::Field(ABSL_DIE_IF_NULL(descriptor->FindFieldByName("null_value"))), + IntType()); + EXPECT_EQ(Type::Field( + ABSL_DIE_IF_NULL(descriptor->FindFieldByName("single_int32"))), + IntType()); + EXPECT_EQ(Type::Field( + ABSL_DIE_IF_NULL(descriptor->FindFieldByName("single_sint32"))), + IntType()); + EXPECT_EQ(Type::Field(ABSL_DIE_IF_NULL( + descriptor->FindFieldByName("single_sfixed32"))), + IntType()); + EXPECT_EQ(Type::Field( + ABSL_DIE_IF_NULL(descriptor->FindFieldByName("single_int64"))), + IntType()); + EXPECT_EQ(Type::Field( + ABSL_DIE_IF_NULL(descriptor->FindFieldByName("single_sint64"))), + IntType()); + EXPECT_EQ(Type::Field(ABSL_DIE_IF_NULL( + descriptor->FindFieldByName("single_sfixed64"))), + IntType()); + EXPECT_EQ(Type::Field(ABSL_DIE_IF_NULL( + descriptor->FindFieldByName("single_fixed32"))), + UintType()); + EXPECT_EQ(Type::Field( + ABSL_DIE_IF_NULL(descriptor->FindFieldByName("single_uint32"))), + UintType()); + EXPECT_EQ(Type::Field(ABSL_DIE_IF_NULL( + descriptor->FindFieldByName("single_fixed64"))), + UintType()); + EXPECT_EQ(Type::Field( + ABSL_DIE_IF_NULL(descriptor->FindFieldByName("single_uint64"))), + UintType()); + EXPECT_EQ(Type::Field( + ABSL_DIE_IF_NULL(descriptor->FindFieldByName("single_float"))), + DoubleType()); + EXPECT_EQ(Type::Field( + ABSL_DIE_IF_NULL(descriptor->FindFieldByName("single_double"))), + DoubleType()); + EXPECT_EQ(Type::Field( + ABSL_DIE_IF_NULL(descriptor->FindFieldByName("single_bytes"))), + BytesType()); + EXPECT_EQ(Type::Field( + ABSL_DIE_IF_NULL(descriptor->FindFieldByName("single_string"))), + StringType()); + EXPECT_EQ( + Type::Field(ABSL_DIE_IF_NULL(descriptor->FindFieldByName("single_any"))), + AnyType()); + EXPECT_EQ(Type::Field(ABSL_DIE_IF_NULL( + descriptor->FindFieldByName("single_duration"))), + DurationType()); + EXPECT_EQ(Type::Field(ABSL_DIE_IF_NULL( + descriptor->FindFieldByName("single_timestamp"))), + TimestampType()); + EXPECT_EQ(Type::Field( + ABSL_DIE_IF_NULL(descriptor->FindFieldByName("single_struct"))), + JsonMapType()); + EXPECT_EQ( + Type::Field(ABSL_DIE_IF_NULL(descriptor->FindFieldByName("list_value"))), + JsonListType()); + EXPECT_EQ(Type::Field( + ABSL_DIE_IF_NULL(descriptor->FindFieldByName("single_value"))), + JsonType()); + EXPECT_EQ(Type::Field(ABSL_DIE_IF_NULL( + descriptor->FindFieldByName("single_bool_wrapper"))), + BoolWrapperType()); + EXPECT_EQ(Type::Field(ABSL_DIE_IF_NULL( + descriptor->FindFieldByName("single_int32_wrapper"))), + IntWrapperType()); + EXPECT_EQ(Type::Field(ABSL_DIE_IF_NULL( + descriptor->FindFieldByName("single_int64_wrapper"))), + IntWrapperType()); + EXPECT_EQ(Type::Field(ABSL_DIE_IF_NULL( + descriptor->FindFieldByName("single_uint32_wrapper"))), + UintWrapperType()); + EXPECT_EQ(Type::Field(ABSL_DIE_IF_NULL( + descriptor->FindFieldByName("single_uint64_wrapper"))), + UintWrapperType()); + EXPECT_EQ(Type::Field(ABSL_DIE_IF_NULL( + descriptor->FindFieldByName("single_float_wrapper"))), + DoubleWrapperType()); + EXPECT_EQ(Type::Field(ABSL_DIE_IF_NULL( + descriptor->FindFieldByName("single_double_wrapper"))), + DoubleWrapperType()); + EXPECT_EQ(Type::Field(ABSL_DIE_IF_NULL( + descriptor->FindFieldByName("single_bytes_wrapper"))), + BytesWrapperType()); + EXPECT_EQ(Type::Field(ABSL_DIE_IF_NULL( + descriptor->FindFieldByName("single_string_wrapper"))), + StringWrapperType()); + EXPECT_EQ( + Type::Field( + ABSL_DIE_IF_NULL(descriptor->FindFieldByName("standalone_enum"))), + EnumType(ABSL_DIE_IF_NULL(GetTestingDescriptorPool()->FindEnumTypeByName( + "cel.expr.conformance.proto3.TestAllTypes.NestedEnum")))); + EXPECT_EQ(Type::Field(ABSL_DIE_IF_NULL( + descriptor->FindFieldByName("repeated_int32"))), + ListType(&arena, IntType())); + EXPECT_EQ(Type::Field(ABSL_DIE_IF_NULL( + descriptor->FindFieldByName("map_int32_int32"))), + MapType(&arena, IntType(), IntType())); +} + +TEST(Type, Kind) { + google::protobuf::Arena arena; + + EXPECT_EQ(Type(AnyType()).kind(), AnyType::kKind); + + EXPECT_EQ(Type(BoolType()).kind(), BoolType::kKind); + + EXPECT_EQ(Type(BoolWrapperType()).kind(), BoolWrapperType::kKind); + + EXPECT_EQ(Type(BytesType()).kind(), BytesType::kKind); + + EXPECT_EQ(Type(BytesWrapperType()).kind(), BytesWrapperType::kKind); + + EXPECT_EQ(Type(DoubleType()).kind(), DoubleType::kKind); + + EXPECT_EQ(Type(DoubleWrapperType()).kind(), DoubleWrapperType::kKind); + + EXPECT_EQ(Type(DurationType()).kind(), DurationType::kKind); + + EXPECT_EQ(Type(DynType()).kind(), DynType::kKind); + + EXPECT_EQ( + Type(EnumType( + ABSL_DIE_IF_NULL(GetTestingDescriptorPool()->FindEnumTypeByName( + "cel.expr.conformance.proto3.TestAllTypes.NestedEnum")))) + .kind(), + EnumType::kKind); + + EXPECT_EQ(Type(ErrorType()).kind(), ErrorType::kKind); + + EXPECT_EQ(Type(FunctionType(&arena, DynType(), {})).kind(), + FunctionType::kKind); + + EXPECT_EQ(Type(IntType()).kind(), IntType::kKind); + + EXPECT_EQ(Type(IntWrapperType()).kind(), IntWrapperType::kKind); + + EXPECT_EQ(Type(ListType()).kind(), ListType::kKind); + + EXPECT_EQ(Type(MapType()).kind(), MapType::kKind); + + EXPECT_EQ(Type(MessageType(ABSL_DIE_IF_NULL( + GetTestingDescriptorPool()->FindMessageTypeByName( + "cel.expr.conformance.proto3.TestAllTypes")))) + .kind(), + MessageType::kKind); + EXPECT_EQ(Type(MessageType(ABSL_DIE_IF_NULL( + GetTestingDescriptorPool()->FindMessageTypeByName( + "cel.expr.conformance.proto3.TestAllTypes")))) + .kind(), + MessageType::kKind); + + EXPECT_EQ(Type(NullType()).kind(), NullType::kKind); + + EXPECT_EQ(Type(OptionalType()).kind(), OpaqueType::kKind); + + EXPECT_EQ(Type(StringType()).kind(), StringType::kKind); + + EXPECT_EQ(Type(StringWrapperType()).kind(), StringWrapperType::kKind); + + EXPECT_EQ(Type(TimestampType()).kind(), TimestampType::kKind); + + EXPECT_EQ(Type(UintType()).kind(), UintType::kKind); + + EXPECT_EQ(Type(UintWrapperType()).kind(), UintWrapperType::kKind); + + EXPECT_EQ(Type(UnknownType()).kind(), UnknownType::kKind); +} + +TEST(Type, GetParameters) { + google::protobuf::Arena arena; + + EXPECT_THAT(Type(AnyType()).GetParameters(), IsEmpty()); + + EXPECT_THAT(Type(BoolType()).GetParameters(), IsEmpty()); + + EXPECT_THAT(Type(BoolWrapperType()).GetParameters(), IsEmpty()); + + EXPECT_THAT(Type(BytesType()).GetParameters(), IsEmpty()); + + EXPECT_THAT(Type(BytesWrapperType()).GetParameters(), IsEmpty()); + + EXPECT_THAT(Type(DoubleType()).GetParameters(), IsEmpty()); + + EXPECT_THAT(Type(DoubleWrapperType()).GetParameters(), IsEmpty()); + + EXPECT_THAT(Type(DurationType()).GetParameters(), IsEmpty()); + + EXPECT_THAT(Type(DynType()).GetParameters(), IsEmpty()); + + EXPECT_THAT( + Type(EnumType( + ABSL_DIE_IF_NULL(GetTestingDescriptorPool()->FindEnumTypeByName( + "cel.expr.conformance.proto3.TestAllTypes.NestedEnum")))) + .GetParameters(), + IsEmpty()); + + EXPECT_THAT(Type(ErrorType()).GetParameters(), IsEmpty()); + + EXPECT_THAT(Type(FunctionType(&arena, DynType(), + {IntType(), StringType(), DynType()})) + .GetParameters(), + ElementsAre(DynType(), IntType(), StringType(), DynType())); + + EXPECT_THAT(Type(IntType()).GetParameters(), IsEmpty()); + + EXPECT_THAT(Type(IntWrapperType()).GetParameters(), IsEmpty()); + + EXPECT_THAT(Type(ListType()).GetParameters(), ElementsAre(DynType())); + + EXPECT_THAT(Type(MapType()).GetParameters(), + ElementsAre(DynType(), DynType())); + + EXPECT_THAT(Type(MessageType(ABSL_DIE_IF_NULL( + GetTestingDescriptorPool()->FindMessageTypeByName( + "cel.expr.conformance.proto3.TestAllTypes")))) + .GetParameters(), + IsEmpty()); + + EXPECT_THAT(Type(NullType()).GetParameters(), IsEmpty()); + + EXPECT_THAT(Type(OptionalType()).GetParameters(), ElementsAre(DynType())); + + EXPECT_THAT(Type(StringType()).GetParameters(), IsEmpty()); + + EXPECT_THAT(Type(StringWrapperType()).GetParameters(), IsEmpty()); + + EXPECT_THAT(Type(TimestampType()).GetParameters(), IsEmpty()); + + EXPECT_THAT(Type(UintType()).GetParameters(), IsEmpty()); + + EXPECT_THAT(Type(UintWrapperType()).GetParameters(), IsEmpty()); + + EXPECT_THAT(Type(UnknownType()).GetParameters(), IsEmpty()); +} + +TEST(Type, Is) { + google::protobuf::Arena arena; + + EXPECT_TRUE(Type(AnyType()).Is()); + + EXPECT_TRUE(Type(BoolType()).Is()); + + EXPECT_TRUE(Type(BoolWrapperType()).Is()); + EXPECT_TRUE(Type(BoolWrapperType()).IsWrapper()); + + EXPECT_TRUE(Type(BytesType()).Is()); + + EXPECT_TRUE(Type(BytesWrapperType()).Is()); + EXPECT_TRUE(Type(BytesWrapperType()).IsWrapper()); + + EXPECT_TRUE(Type(DoubleType()).Is()); + + EXPECT_TRUE(Type(DoubleWrapperType()).Is()); + EXPECT_TRUE(Type(DoubleWrapperType()).IsWrapper()); + + EXPECT_TRUE(Type(DurationType()).Is()); + + EXPECT_TRUE(Type(DynType()).Is()); + + EXPECT_TRUE( + Type(EnumType( + ABSL_DIE_IF_NULL(GetTestingDescriptorPool()->FindEnumTypeByName( + "cel.expr.conformance.proto3.TestAllTypes.NestedEnum")))) + .Is()); + + EXPECT_TRUE(Type(ErrorType()).Is()); + + EXPECT_TRUE(Type(FunctionType(&arena, DynType(), {})).Is()); + + EXPECT_TRUE(Type(IntType()).Is()); + + EXPECT_TRUE(Type(IntWrapperType()).Is()); + EXPECT_TRUE(Type(IntWrapperType()).IsWrapper()); + + EXPECT_TRUE(Type(ListType()).Is()); + + EXPECT_TRUE(Type(MapType()).Is()); + + EXPECT_TRUE(Type(MessageType(ABSL_DIE_IF_NULL( + GetTestingDescriptorPool()->FindMessageTypeByName( + "cel.expr.conformance.proto3.TestAllTypes")))) + .IsStruct()); + EXPECT_TRUE(Type(MessageType(ABSL_DIE_IF_NULL( + GetTestingDescriptorPool()->FindMessageTypeByName( + "cel.expr.conformance.proto3.TestAllTypes")))) + .IsMessage()); + + EXPECT_TRUE(Type(NullType()).Is()); + + EXPECT_TRUE(Type(OptionalType()).Is()); + EXPECT_TRUE(Type(OptionalType()).Is()); + + EXPECT_TRUE(Type(StringType()).Is()); + + EXPECT_TRUE(Type(StringWrapperType()).Is()); + EXPECT_TRUE(Type(StringWrapperType()).IsWrapper()); + + EXPECT_TRUE(Type(TimestampType()).Is()); + + EXPECT_TRUE(Type(TypeType()).Is()); + + EXPECT_TRUE(Type(TypeParamType("T")).Is()); + + EXPECT_TRUE(Type(UintType()).Is()); + + EXPECT_TRUE(Type(UintWrapperType()).Is()); + EXPECT_TRUE(Type(UintWrapperType()).IsWrapper()); + + EXPECT_TRUE(Type(UnknownType()).Is()); +} + +TEST(Type, As) { + google::protobuf::Arena arena; + + EXPECT_THAT(Type(AnyType()).As(), Optional(An())); + + EXPECT_THAT(Type(BoolType()).As(), Optional(An())); + + EXPECT_THAT(Type(BoolWrapperType()).As(), + Optional(An())); + + EXPECT_THAT(Type(BytesType()).As(), Optional(An())); + + EXPECT_THAT(Type(BytesWrapperType()).As(), + Optional(An())); + + EXPECT_THAT(Type(DoubleType()).As(), Optional(An())); + + EXPECT_THAT(Type(DoubleWrapperType()).As(), + Optional(An())); + + EXPECT_THAT(Type(DurationType()).As(), + Optional(An())); + + EXPECT_THAT(Type(DynType()).As(), Optional(An())); + + EXPECT_THAT( + Type(EnumType( + ABSL_DIE_IF_NULL(GetTestingDescriptorPool()->FindEnumTypeByName( + "cel.expr.conformance.proto3.TestAllTypes.NestedEnum")))) + .As(), + Optional(An())); + + EXPECT_THAT(Type(ErrorType()).As(), Optional(An())); + + EXPECT_TRUE(Type(FunctionType(&arena, DynType(), {})).Is()); + + EXPECT_THAT(Type(IntType()).As(), Optional(An())); + + EXPECT_THAT(Type(IntWrapperType()).As(), + Optional(An())); + + EXPECT_THAT(Type(ListType()).As(), Optional(An())); + + EXPECT_THAT(Type(MapType()).As(), Optional(An())); + + EXPECT_THAT(Type(MessageType(ABSL_DIE_IF_NULL( + GetTestingDescriptorPool()->FindMessageTypeByName( + "cel.expr.conformance.proto3.TestAllTypes")))) + .As(), + Optional(An())); + EXPECT_THAT(Type(MessageType(ABSL_DIE_IF_NULL( + GetTestingDescriptorPool()->FindMessageTypeByName( + "cel.expr.conformance.proto3.TestAllTypes")))) + .As(), + Optional(An())); + + EXPECT_THAT(Type(NullType()).As(), Optional(An())); + + EXPECT_THAT(Type(OptionalType()).As(), + Optional(An())); + EXPECT_THAT(Type(OptionalType()).As(), + Optional(An())); + + EXPECT_THAT(Type(StringType()).As(), Optional(An())); + + EXPECT_THAT(Type(StringWrapperType()).As(), + Optional(An())); + + EXPECT_THAT(Type(TimestampType()).As(), + Optional(An())); + + EXPECT_THAT(Type(TypeType()).As(), Optional(An())); + + EXPECT_THAT(Type(TypeParamType("T")).As(), + Optional(An())); + + EXPECT_THAT(Type(UintType()).As(), Optional(An())); + + EXPECT_THAT(Type(UintWrapperType()).As(), + Optional(An())); + + EXPECT_THAT(Type(UnknownType()).As(), + Optional(An())); +} + +template +T DoGet(const Type& type) { + return type.template Get(); +} + +TEST(Type, Get) { + google::protobuf::Arena arena; + + EXPECT_THAT(DoGet(Type(AnyType())), An()); + + EXPECT_THAT(DoGet(Type(BoolType())), An()); + + EXPECT_THAT(DoGet(Type(BoolWrapperType())), + An()); + EXPECT_THAT(DoGet(Type(BoolWrapperType())), + An()); + + EXPECT_THAT(DoGet(Type(BytesType())), An()); + + EXPECT_THAT(DoGet(Type(BytesWrapperType())), + An()); + EXPECT_THAT(DoGet(Type(BytesWrapperType())), + An()); + + EXPECT_THAT(DoGet(Type(DoubleType())), An()); + + EXPECT_THAT(DoGet(Type(DoubleWrapperType())), + An()); + EXPECT_THAT(DoGet(Type(DoubleWrapperType())), + An()); + + EXPECT_THAT(DoGet(Type(DurationType())), An()); + + EXPECT_THAT(DoGet(Type(DynType())), An()); + + EXPECT_THAT( + DoGet(Type(EnumType( + ABSL_DIE_IF_NULL(GetTestingDescriptorPool()->FindEnumTypeByName( + "cel.expr.conformance.proto3.TestAllTypes.NestedEnum"))))), + An()); + + EXPECT_THAT(DoGet(Type(ErrorType())), An()); + + EXPECT_THAT(DoGet(Type(FunctionType(&arena, DynType(), {}))), + An()); + + EXPECT_THAT(DoGet(Type(IntType())), An()); + + EXPECT_THAT(DoGet(Type(IntWrapperType())), + An()); + EXPECT_THAT(DoGet(Type(IntWrapperType())), + An()); + + EXPECT_THAT(DoGet(Type(ListType())), An()); + + EXPECT_THAT(DoGet(Type(MapType())), An()); + + EXPECT_THAT(DoGet(Type(MessageType(ABSL_DIE_IF_NULL( + GetTestingDescriptorPool()->FindMessageTypeByName( + "cel.expr.conformance.proto3.TestAllTypes"))))), + An()); + EXPECT_THAT(DoGet(Type(MessageType(ABSL_DIE_IF_NULL( + GetTestingDescriptorPool()->FindMessageTypeByName( + "cel.expr.conformance.proto3.TestAllTypes"))))), + An()); + + EXPECT_THAT(DoGet(Type(NullType())), An()); + + EXPECT_THAT(DoGet(Type(OptionalType())), An()); + EXPECT_THAT(DoGet(Type(OptionalType())), An()); + + EXPECT_THAT(DoGet(Type(StringType())), An()); + + EXPECT_THAT(DoGet(Type(StringWrapperType())), + An()); + EXPECT_THAT(DoGet(Type(StringWrapperType())), + An()); + + EXPECT_THAT(DoGet(Type(TimestampType())), An()); + + EXPECT_THAT(DoGet(Type(TypeType())), An()); + + EXPECT_THAT(DoGet(Type(TypeParamType("T"))), + An()); + + EXPECT_THAT(DoGet(Type(UintType())), An()); + + EXPECT_THAT(DoGet(Type(UintWrapperType())), + An()); + EXPECT_THAT(DoGet(Type(UintWrapperType())), + An()); + + EXPECT_THAT(DoGet(Type(UnknownType())), An()); +} + +TEST(Type, VerifyTypeImplementsAbslHashCorrectly) { + google::protobuf::Arena arena; + + EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly( + {Type(AnyType()), + Type(BoolType()), + Type(BoolWrapperType()), + Type(BytesType()), + Type(BytesWrapperType()), + Type(DoubleType()), + Type(DoubleWrapperType()), + Type(DurationType()), + Type(DynType()), + Type(ErrorType()), + Type(FunctionType(&arena, DynType(), {DynType()})), + Type(IntType()), + Type(IntWrapperType()), + Type(ListType(&arena, DynType())), + Type(MapType(&arena, DynType(), DynType())), + Type(NullType()), + Type(OptionalType(&arena, DynType())), + Type(StringType()), + Type(StringWrapperType()), + Type(StructType(common_internal::MakeBasicStructType("test.Struct"))), + Type(TimestampType()), + Type(TypeParamType("T")), + Type(TypeType()), + Type(UintType()), + Type(UintWrapperType()), + Type(UnknownType())})); + + EXPECT_EQ( + absl::HashOf(Type::Field( + ABSL_DIE_IF_NULL(GetTestingDescriptorPool()->FindMessageTypeByName( + "cel.expr.conformance.proto3.TestAllTypes")) + ->FindFieldByName("repeated_int64"))), + absl::HashOf(Type(ListType(&arena, IntType())))); + EXPECT_EQ(Type::Field(ABSL_DIE_IF_NULL( + GetTestingDescriptorPool()->FindMessageTypeByName( + "cel.expr.conformance.proto3.TestAllTypes")) + ->FindFieldByName("repeated_int64")), + Type(ListType(&arena, IntType()))); + + EXPECT_EQ( + absl::HashOf(Type::Field( + ABSL_DIE_IF_NULL(GetTestingDescriptorPool()->FindMessageTypeByName( + "cel.expr.conformance.proto3.TestAllTypes")) + ->FindFieldByName("map_int64_int64"))), + absl::HashOf(Type(MapType(&arena, IntType(), IntType())))); + EXPECT_EQ(Type::Field(ABSL_DIE_IF_NULL( + GetTestingDescriptorPool()->FindMessageTypeByName( + "cel.expr.conformance.proto3.TestAllTypes")) + ->FindFieldByName("map_int64_int64")), + Type(MapType(&arena, IntType(), IntType()))); + + EXPECT_EQ(absl::HashOf(Type(MessageType(ABSL_DIE_IF_NULL( + GetTestingDescriptorPool()->FindMessageTypeByName( + "cel.expr.conformance.proto3.TestAllTypes"))))), + absl::HashOf(Type(StructType(common_internal::MakeBasicStructType( + "cel.expr.conformance.proto3.TestAllTypes"))))); + EXPECT_EQ(Type(MessageType(ABSL_DIE_IF_NULL( + GetTestingDescriptorPool()->FindMessageTypeByName( + "cel.expr.conformance.proto3.TestAllTypes")))), + Type(StructType(common_internal::MakeBasicStructType( + "cel.expr.conformance.proto3.TestAllTypes")))); +} + +TEST(Type, Unwrap) { + EXPECT_EQ(Type(BoolWrapperType()).Unwrap(), BoolType()); + EXPECT_EQ(Type(IntWrapperType()).Unwrap(), IntType()); + EXPECT_EQ(Type(UintWrapperType()).Unwrap(), UintType()); + EXPECT_EQ(Type(DoubleWrapperType()).Unwrap(), DoubleType()); + EXPECT_EQ(Type(BytesWrapperType()).Unwrap(), BytesType()); + EXPECT_EQ(Type(StringWrapperType()).Unwrap(), StringType()); + EXPECT_EQ(Type(AnyType()).Unwrap(), AnyType()); +} + +TEST(Type, Wrap) { + EXPECT_EQ(Type(BoolType()).Wrap(), BoolWrapperType()); + EXPECT_EQ(Type(IntType()).Wrap(), IntWrapperType()); + EXPECT_EQ(Type(UintType()).Wrap(), UintWrapperType()); + EXPECT_EQ(Type(DoubleType()).Wrap(), DoubleWrapperType()); + EXPECT_EQ(Type(BytesType()).Wrap(), BytesWrapperType()); + EXPECT_EQ(Type(StringType()).Wrap(), StringWrapperType()); + EXPECT_EQ(Type(AnyType()).Wrap(), AnyType()); +} + +TEST(Type, LegacyRuntimeType) { + EXPECT_EQ(common_internal::LegacyRuntimeType("bool"), BoolType()); + EXPECT_EQ(common_internal::LegacyRuntimeType("google.protobuf.Any"), + AnyType()); + EXPECT_EQ(common_internal::LegacyRuntimeType("google.protobuf.BoolValue"), + BoolWrapperType()); + EXPECT_EQ(common_internal::LegacyRuntimeType("google.protobuf.BytesValue"), + BytesWrapperType()); + EXPECT_EQ(common_internal::LegacyRuntimeType("google.protobuf.DoubleValue"), + DoubleWrapperType()); + EXPECT_EQ(common_internal::LegacyRuntimeType("google.protobuf.Duration"), + DurationType()); + EXPECT_EQ(common_internal::LegacyRuntimeType("google.protobuf.FloatValue"), + DoubleWrapperType()); + EXPECT_EQ(common_internal::LegacyRuntimeType("google.protobuf.Int32Value"), + IntWrapperType()); + EXPECT_EQ(common_internal::LegacyRuntimeType("google.protobuf.Int64Value"), + IntWrapperType()); + EXPECT_EQ(common_internal::LegacyRuntimeType("google.protobuf.ListValue"), + ListType()); + EXPECT_EQ(common_internal::LegacyRuntimeType("google.protobuf.StringValue"), + StringWrapperType()); + EXPECT_EQ(common_internal::LegacyRuntimeType("google.protobuf.Struct"), + JsonMapType()); + EXPECT_EQ(common_internal::LegacyRuntimeType("google.protobuf.Timestamp"), + TimestampType()); + EXPECT_EQ(common_internal::LegacyRuntimeType("google.protobuf.UInt32Value"), + UintWrapperType()); + EXPECT_EQ(common_internal::LegacyRuntimeType("google.protobuf.UInt64Value"), + UintWrapperType()); + EXPECT_EQ(common_internal::LegacyRuntimeType("google.protobuf.Value"), + DynType()); +} + +} // namespace +} // namespace cel diff --git a/common/type_testing.h b/common/type_testing.h new file mode 100644 index 000000000..284201101 --- /dev/null +++ b/common/type_testing.h @@ -0,0 +1,24 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPE_TESTING_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_TYPE_TESTING_H_ + +namespace cel::common_internal { + +// Empty for now. + +} // namespace cel::common_internal + +#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPE_TESTING_H_ diff --git a/common/typeinfo.cc b/common/typeinfo.cc new file mode 100644 index 000000000..b07275712 --- /dev/null +++ b/common/typeinfo.cc @@ -0,0 +1,73 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/typeinfo.h" + +#include +#include // IWYU pragma: keep +#include +#include +#include + +#include "absl/base/casts.h" // IWYU pragma: keep +#include "absl/strings/str_cat.h" // IWYU pragma: keep + +#ifdef CEL_INTERNAL_HAVE_RTTI +#ifdef _WIN32 +extern "C" char* __unDName(char*, const char*, int, void* (*)(size_t), + void (*)(void*), int); +#else +#include +#endif +#endif + +namespace cel { + +namespace { + +#ifdef CEL_INTERNAL_HAVE_RTTI +struct FreeDeleter { + void operator()(char* ptr) const { std::free(ptr); } +}; +#endif + +} // namespace + +std::string TypeInfo::DebugString() const { + if (rep_ == nullptr) { + return std::string(); + } +#ifdef CEL_INTERNAL_HAVE_RTTI +#ifdef _WIN32 + std::unique_ptr demangled( + __unDName(nullptr, rep_->raw_name(), 0, std::malloc, std::free, 0x2800)); + if (demangled == nullptr) { + return std::string(rep_->name()); + } + return std::string(demangled.get()); +#else + int status = 0; + std::unique_ptr demangled( + abi::__cxa_demangle(rep_->name(), nullptr, nullptr, &status)); + if (status != 0 || demangled == nullptr) { + return std::string(rep_->name()); + } + return std::string(demangled.get()); +#endif +#else + return absl::StrCat("0x", absl::Hex(absl::bit_cast(rep_))); +#endif +} + +} // namespace cel diff --git a/common/typeinfo.h b/common/typeinfo.h new file mode 100644 index 000000000..dadc42cba --- /dev/null +++ b/common/typeinfo.h @@ -0,0 +1,221 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPEINFO_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_TYPEINFO_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/casts.h" // IWYU pragma: keep +#include "absl/base/config.h" +#include "absl/base/nullability.h" +#include "absl/meta/type_traits.h" + +#if ABSL_HAVE_FEATURE(cxx_rtti) +#define CEL_INTERNAL_HAVE_RTTI 1 +#elif defined(__GNUC__) && defined(__GXX_RTTI) +#define CEL_INTERNAL_HAVE_RTTI 1 +#elif defined(_MSC_VER) && defined(_CPPRTTI) +#define CEL_INTERNAL_HAVE_RTTI 1 +#elif !defined(__GNUC__) && !defined(_MSC_VER) +#define CEL_INTERNAL_HAVE_RTTI 1 +#endif + +#ifdef CEL_INTERNAL_HAVE_RTTI +#include +#endif + +namespace cel { + +class TypeInfo; + +template +struct NativeTypeTraits; + +namespace common_internal { + +template +struct HasNativeTypeTraitsId : std::false_type {}; + +template +struct HasNativeTypeTraitsId< + T, std::void_t::Id(std::declval()))>> + : std::true_type {}; + +template +static constexpr bool HasNativeTypeTraitsIdV = HasNativeTypeTraitsId::value; + +template +struct HasCelTypeId : std::false_type {}; + +template +struct HasCelTypeId< + T, std::enable_if_t()))>, + TypeInfo>>> : std::true_type {}; + +} // namespace common_internal + +template +TypeInfo TypeId(); + +template +std::enable_if_t< + std::conjunction_v, + std::negation>>, + TypeInfo> +TypeId(const T& t [[maybe_unused]]) { + return NativeTypeTraits>::Id(t); +} + +template +std::enable_if_t< + std::conjunction_v>, + std::negation>, + std::is_final>, + TypeInfo> +TypeId(const T& t [[maybe_unused]]) { + return cel::TypeId>(); +} + +template +std::enable_if_t< + std::conjunction_v>, + common_internal::HasCelTypeId>, + TypeInfo> +TypeId(const T& t [[maybe_unused]]) { + return CelTypeId(t); +} + +class TypeInfo final { + public: + template + ABSL_DEPRECATED("Use cel::TypeId() instead") + static TypeInfo For() { + return cel::TypeId(); + } + + template + ABSL_DEPRECATED("Use cel::TypeId(...) instead") + static TypeInfo Of(const T& type) { + return cel::TypeId(type); + } + + TypeInfo() = default; + TypeInfo(const TypeInfo&) = default; + TypeInfo& operator=(const TypeInfo&) = default; + + std::string DebugString() const; + + template + friend void AbslStringify(S& sink, TypeInfo type_info) { + sink.Append(type_info.DebugString()); + } + + friend constexpr bool operator==(TypeInfo lhs, TypeInfo rhs) noexcept { +#ifdef CEL_INTERNAL_HAVE_RTTI + return lhs.rep_ == rhs.rep_ || + (lhs.rep_ != nullptr && rhs.rep_ != nullptr && + *lhs.rep_ == *rhs.rep_); +#else + return lhs.rep_ == rhs.rep_; +#endif + } + + template + friend H AbslHashValue(H state, TypeInfo id) { +#ifdef CEL_INTERNAL_HAVE_RTTI + return H::combine(std::move(state), + id.rep_ != nullptr ? id.rep_->hash_code() : size_t{0}); +#else + return H::combine(std::move(state), absl::bit_cast(id.rep_)); +#endif + } + + private: + template + friend TypeInfo TypeId(); + +#ifdef CEL_INTERNAL_HAVE_RTTI + constexpr explicit TypeInfo(const std::type_info* absl_nullable rep) + : rep_(rep) {} + + const std::type_info* absl_nullable rep_ = nullptr; +#else + constexpr explicit TypeInfo(const void* absl_nullable rep) : rep_(rep) {} + + const void* absl_nullable rep_ = nullptr; +#endif +}; + +#ifndef CEL_INTERNAL_HAVE_RTTI +namespace common_internal { +template +struct TypeTag final { + static constexpr char value = 0; +}; +} // namespace common_internal +#endif + +template +TypeInfo TypeId() { + static_assert(std::is_same_v>); + static_assert(!std::is_same_v>); +#ifdef CEL_INTERNAL_HAVE_RTTI + return TypeInfo(&typeid(T)); +#else + return TypeInfo(&common_internal::TypeTag::value); +#endif +} + +inline constexpr bool operator!=(TypeInfo lhs, TypeInfo rhs) noexcept { + return !operator==(lhs, rhs); +} + +inline std::ostream& operator<<(std::ostream& out, TypeInfo id) { + return out << id.DebugString(); +} + +// Helper class for adapting a type to an index in a tuple or array. +// Scope is an arbitrary type used as a namespace for the index. +template +class TypeIdInSet { + public: + template + static size_t IndexFor() { + static size_t index = + type_id_set_index_.fetch_add(1, std::memory_order_relaxed); + return index; + } + + static size_t Size() { + return type_id_set_index_.load(std::memory_order_relaxed); + } + + private: + static std::atomic type_id_set_index_; +}; + +template +std::atomic TypeIdInSet::type_id_set_index_ = 0; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPEINFO_H_ diff --git a/common/typeinfo_test.cc b/common/typeinfo_test.cc new file mode 100644 index 000000000..cf5b5f877 --- /dev/null +++ b/common/typeinfo_test.cc @@ -0,0 +1,75 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/typeinfo.h" + +#include +#include + +#include "absl/hash/hash_testing.h" +#include "absl/strings/str_cat.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +using ::testing::IsEmpty; +using ::testing::Not; +using ::testing::SizeIs; + +struct Type1 {}; + +struct Type2 {}; + +struct Type3 {}; + +TEST(TypeInfo, ImplementsAbslHashCorrectly) { + EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly( + {TypeInfo(), cel::TypeId(), cel::TypeId(), + cel::TypeId()})); +} + +TEST(TypeInfo, Ostream) { + std::ostringstream out; + out << TypeInfo(); + EXPECT_THAT(out.str(), IsEmpty()); + out << cel::TypeId(); + auto string = out.str(); + EXPECT_THAT(string, Not(IsEmpty())); + EXPECT_THAT(string, SizeIs(std::strlen(string.c_str()))); +} + +TEST(TypeInfo, AbslStringify) { + EXPECT_THAT(absl::StrCat(TypeInfo()), IsEmpty()); + EXPECT_THAT(absl::StrCat(cel::TypeId()), Not(IsEmpty())); +} + +struct TestType {}; + +} // namespace + +template <> +struct NativeTypeTraits final { + static TypeInfo Id(const TestType&) { return cel::TypeId(); } +}; + +namespace { + +TEST(TypeInfo, Of) { + EXPECT_EQ(cel::TypeId(TestType()), cel::TypeId()); +} + +} // namespace + +} // namespace cel diff --git a/common/types/any_type.h b/common/types/any_type.h new file mode 100644 index 000000000..32a9fe3ce --- /dev/null +++ b/common/types/any_type.h @@ -0,0 +1,74 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/type.h" +// IWYU pragma: friend "common/type.h" + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_ANY_TYPE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_TYPES_ANY_TYPE_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "common/type_kind.h" + +namespace cel { + +class Type; +class TypeParameters; + +// `AnyType` is a special type which has no direct value representation. It is +// used to represent `google.protobuf.Any`, which never exists at runtime as +// a value. Its primary usage is for type checking and unpacking at runtime. +class AnyType final { + public: + static constexpr TypeKind kKind = TypeKind::kAny; + static constexpr absl::string_view kName = "google.protobuf.Any"; + + AnyType() = default; + AnyType(const AnyType&) = default; + AnyType(AnyType&&) = default; + AnyType& operator=(const AnyType&) = default; + AnyType& operator=(AnyType&&) = default; + + static TypeKind kind() { return kKind; } + + static absl::string_view name() { return kName; } + + static TypeParameters GetParameters(); + + static std::string DebugString() { return std::string(name()); } +}; + +inline constexpr bool operator==(AnyType, AnyType) { return true; } + +inline constexpr bool operator!=(AnyType lhs, AnyType rhs) { + return !operator==(lhs, rhs); +} + +template +H AbslHashValue(H state, AnyType) { + // AnyType is really a singleton and all instances are equal. Nothing to hash. + return std::move(state); +} + +inline std::ostream& operator<<(std::ostream& out, const AnyType& type) { + return out << type.DebugString(); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_ANY_TYPE_H_ diff --git a/common/types/any_type_test.cc b/common/types/any_type_test.cc new file mode 100644 index 000000000..5e0342a7d --- /dev/null +++ b/common/types/any_type_test.cc @@ -0,0 +1,59 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "absl/hash/hash.h" +#include "common/type.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +TEST(AnyType, Kind) { + EXPECT_EQ(AnyType().kind(), AnyType::kKind); + EXPECT_EQ(Type(AnyType()).kind(), AnyType::kKind); +} + +TEST(AnyType, Name) { + EXPECT_EQ(AnyType().name(), AnyType::kName); + EXPECT_EQ(Type(AnyType()).name(), AnyType::kName); +} + +TEST(AnyType, DebugString) { + { + std::ostringstream out; + out << AnyType(); + EXPECT_EQ(out.str(), AnyType::kName); + } + { + std::ostringstream out; + out << Type(AnyType()); + EXPECT_EQ(out.str(), AnyType::kName); + } +} + +TEST(AnyType, Hash) { + EXPECT_EQ(absl::HashOf(AnyType()), absl::HashOf(AnyType())); +} + +TEST(AnyType, Equal) { + EXPECT_EQ(AnyType(), AnyType()); + EXPECT_EQ(Type(AnyType()), AnyType()); + EXPECT_EQ(AnyType(), Type(AnyType())); + EXPECT_EQ(Type(AnyType()), Type(AnyType())); +} + +} // namespace +} // namespace cel diff --git a/common/types/basic_struct_type.cc b/common/types/basic_struct_type.cc new file mode 100644 index 000000000..a3b31544c --- /dev/null +++ b/common/types/basic_struct_type.cc @@ -0,0 +1,53 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "absl/algorithm/container.h" +#include "absl/strings/string_view.h" +#include "absl/strings/strip.h" +#include "common/type.h" + +namespace cel { + +bool IsWellKnownMessageType(absl::string_view name) { + static constexpr absl::string_view kPrefix = "google.protobuf."; + static constexpr std::array kNames = { + // clang-format off + // keep-sorted start + "Any", + "BoolValue", + "BytesValue", + "DoubleValue", + "Duration", + "FloatValue", + "Int32Value", + "Int64Value", + "ListValue", + "StringValue", + "Struct", + "Timestamp", + "UInt32Value", + "UInt64Value", + "Value", + // keep-sorted end + // clang-format on + }; + if (!absl::ConsumePrefix(&name, kPrefix)) { + return false; + } + return absl::c_binary_search(kNames, name); +} + +} // namespace cel diff --git a/common/types/basic_struct_type.h b/common/types/basic_struct_type.h new file mode 100644 index 000000000..74200dc17 --- /dev/null +++ b/common/types/basic_struct_type.h @@ -0,0 +1,119 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/type.h" +// IWYU pragma: friend "common/types/struct_type.h" + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_BASIC_STRUCT_TYPE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_TYPES_BASIC_STRUCT_TYPE_H_ + +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/log/absl_check.h" +#include "absl/strings/string_view.h" +#include "common/type_kind.h" + +namespace cel { + +class Type; +class TypeParameters; + +// Returns true if the given type name is one of the well known message types +// that CEL treats specially. +// +// For familiarity with textproto, these types may be created using the struct +// creation syntax, even though they are not considered a struct type in CEL. +bool IsWellKnownMessageType(absl::string_view name); + +namespace common_internal { + +class BasicStructType; +class BasicStructTypeField; + +// Constructs `BasicStructType` from a type name. The type name must not be one +// of the well known message types we treat specially, if it is behavior is +// undefined. The name must also outlive the resulting type. +BasicStructType MakeBasicStructType( + absl::string_view name ABSL_ATTRIBUTE_LIFETIME_BOUND); + +class BasicStructType final { + public: + static constexpr TypeKind kKind = TypeKind::kStruct; + + BasicStructType() = default; + BasicStructType(const BasicStructType&) = default; + BasicStructType(BasicStructType&&) = default; + BasicStructType& operator=(const BasicStructType&) = default; + BasicStructType& operator=(BasicStructType&&) = default; + + static TypeKind kind() { return kKind; } + + absl::string_view name() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + ABSL_DCHECK(*this); + return name_; + } + + static TypeParameters GetParameters(); + + std::string DebugString() const { + return std::string(static_cast(*this) ? name() : absl::string_view()); + } + + explicit operator bool() const { return !name_.empty(); } + + private: + friend BasicStructType MakeBasicStructType( + absl::string_view name ABSL_ATTRIBUTE_LIFETIME_BOUND); + + explicit BasicStructType(absl::string_view name ABSL_ATTRIBUTE_LIFETIME_BOUND) + : name_(name) {} + + absl::string_view name_; +}; + +inline bool operator==(BasicStructType lhs, BasicStructType rhs) { + return static_cast(lhs) == static_cast(rhs) && + (!static_cast(lhs) || lhs.name() == rhs.name()); +} + +inline bool operator!=(BasicStructType lhs, BasicStructType rhs) { + return !operator==(lhs, rhs); +} + +template +H AbslHashValue(H state, BasicStructType type) { + ABSL_DCHECK(type); + return H::combine(std::move(state), static_cast(type) + ? type.name() + : absl::string_view()); +} + +inline std::ostream& operator<<(std::ostream& out, BasicStructType type) { + return out << type.DebugString(); +} + +inline BasicStructType MakeBasicStructType( + absl::string_view name ABSL_ATTRIBUTE_LIFETIME_BOUND) { + ABSL_DCHECK(!IsWellKnownMessageType(name)) << name; + return BasicStructType(name); +} + +} // namespace common_internal + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_BASIC_STRUCT_TYPE_H_ diff --git a/common/types/basic_struct_type_test.cc b/common/types/basic_struct_type_test.cc new file mode 100644 index 000000000..670c1f6e8 --- /dev/null +++ b/common/types/basic_struct_type_test.cc @@ -0,0 +1,47 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/type.h" +#include "common/type_kind.h" +#include "internal/testing.h" + +namespace cel::common_internal { +namespace { + +using ::testing::Eq; +using ::testing::IsEmpty; + +TEST(BasicStructType, Kind) { + EXPECT_EQ(BasicStructType::kind(), TypeKind::kStruct); +} + +TEST(BasicStructType, Default) { + BasicStructType type; + EXPECT_FALSE(type); + EXPECT_THAT(type.DebugString(), Eq("")); + EXPECT_EQ(type, BasicStructType()); +} + +TEST(BasicStructType, Name) { + BasicStructType type = MakeBasicStructType("test.Struct"); + EXPECT_TRUE(type); + EXPECT_THAT(type.name(), Eq("test.Struct")); + EXPECT_THAT(type.DebugString(), Eq("test.Struct")); + EXPECT_THAT(type.GetParameters(), IsEmpty()); + EXPECT_NE(type, BasicStructType()); + EXPECT_NE(BasicStructType(), type); +} + +} // namespace +} // namespace cel::common_internal diff --git a/common/types/bool_type.h b/common/types/bool_type.h new file mode 100644 index 000000000..545bc3c05 --- /dev/null +++ b/common/types/bool_type.h @@ -0,0 +1,73 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/type.h" +// IWYU pragma: friend "common/type.h" + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_BOOL_TYPE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_TYPES_BOOL_TYPE_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "common/type_kind.h" + +namespace cel { + +class Type; +class TypeParameters; + +// `BoolType` represents the primitive `bool` type. +class BoolType final { + public: + static constexpr TypeKind kKind = TypeKind::kBool; + static constexpr absl::string_view kName = "bool"; + + BoolType() = default; + BoolType(const BoolType&) = default; + BoolType(BoolType&&) = default; + BoolType& operator=(const BoolType&) = default; + BoolType& operator=(BoolType&&) = default; + + static TypeKind kind() { return kKind; } + + static absl::string_view name() { return kName; } + + static TypeParameters GetParameters(); + + static std::string DebugString() { return std::string(name()); } +}; + +inline constexpr bool operator==(BoolType, BoolType) { return true; } + +inline constexpr bool operator!=(BoolType lhs, BoolType rhs) { + return !operator==(lhs, rhs); +} + +template +H AbslHashValue(H state, BoolType) { + // BoolType is really a singleton and all instances are equal. Nothing to + // hash. + return std::move(state); +} + +inline std::ostream& operator<<(std::ostream& out, const BoolType& type) { + return out << type.DebugString(); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_BOOL_TYPE_H_ diff --git a/common/types/bool_type_test.cc b/common/types/bool_type_test.cc new file mode 100644 index 000000000..c9434caec --- /dev/null +++ b/common/types/bool_type_test.cc @@ -0,0 +1,59 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "absl/hash/hash.h" +#include "common/type.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +TEST(BoolType, Kind) { + EXPECT_EQ(BoolType().kind(), BoolType::kKind); + EXPECT_EQ(Type(BoolType()).kind(), BoolType::kKind); +} + +TEST(BoolType, Name) { + EXPECT_EQ(BoolType().name(), BoolType::kName); + EXPECT_EQ(Type(BoolType()).name(), BoolType::kName); +} + +TEST(BoolType, DebugString) { + { + std::ostringstream out; + out << BoolType(); + EXPECT_EQ(out.str(), BoolType::kName); + } + { + std::ostringstream out; + out << Type(BoolType()); + EXPECT_EQ(out.str(), BoolType::kName); + } +} + +TEST(BoolType, Hash) { + EXPECT_EQ(absl::HashOf(BoolType()), absl::HashOf(BoolType())); +} + +TEST(BoolType, Equal) { + EXPECT_EQ(BoolType(), BoolType()); + EXPECT_EQ(Type(BoolType()), BoolType()); + EXPECT_EQ(BoolType(), Type(BoolType())); + EXPECT_EQ(Type(BoolType()), Type(BoolType())); +} + +} // namespace +} // namespace cel diff --git a/common/types/bool_wrapper_type.h b/common/types/bool_wrapper_type.h new file mode 100644 index 000000000..2149a59b7 --- /dev/null +++ b/common/types/bool_wrapper_type.h @@ -0,0 +1,79 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/type.h" +// IWYU pragma: friend "common/type.h" + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_BOOL_WRAPPER_TYPE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_TYPES_BOOL_WRAPPER_TYPE_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "common/type_kind.h" + +namespace cel { + +class Type; +class TypeParameters; + +// `BoolWrapperType` is a special type which has no direct value representation. +// It is used to represent `google.protobuf.BoolValue`, which never exists at +// runtime as a value. Its primary usage is for type checking and unpacking at +// runtime. +class BoolWrapperType final { + public: + static constexpr TypeKind kKind = TypeKind::kBoolWrapper; + static constexpr absl::string_view kName = "google.protobuf.BoolValue"; + + BoolWrapperType() = default; + BoolWrapperType(const BoolWrapperType&) = default; + BoolWrapperType(BoolWrapperType&&) = default; + BoolWrapperType& operator=(const BoolWrapperType&) = default; + BoolWrapperType& operator=(BoolWrapperType&&) = default; + + static TypeKind kind() { return kKind; } + + static absl::string_view name() { return kName; } + + static TypeParameters GetParameters(); + + static std::string DebugString() { return std::string(name()); } +}; + +inline constexpr bool operator==(BoolWrapperType, BoolWrapperType) { + return true; +} + +inline constexpr bool operator!=(BoolWrapperType lhs, BoolWrapperType rhs) { + return !operator==(lhs, rhs); +} + +template +H AbslHashValue(H state, BoolWrapperType) { + // BoolWrapperType is really a singleton and all instances are equal. Nothing + // to hash. + return std::move(state); +} + +inline std::ostream& operator<<(std::ostream& out, + const BoolWrapperType& type) { + return out << type.DebugString(); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_BOOL_WRAPPER_TYPE_H_ diff --git a/common/types/bool_wrapper_type_test.cc b/common/types/bool_wrapper_type_test.cc new file mode 100644 index 000000000..d66342982 --- /dev/null +++ b/common/types/bool_wrapper_type_test.cc @@ -0,0 +1,59 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "absl/hash/hash.h" +#include "common/type.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +TEST(BoolWrapperType, Kind) { + EXPECT_EQ(BoolWrapperType().kind(), BoolWrapperType::kKind); + EXPECT_EQ(Type(BoolWrapperType()).kind(), BoolWrapperType::kKind); +} + +TEST(BoolWrapperType, Name) { + EXPECT_EQ(BoolWrapperType().name(), BoolWrapperType::kName); + EXPECT_EQ(Type(BoolWrapperType()).name(), BoolWrapperType::kName); +} + +TEST(BoolWrapperType, DebugString) { + { + std::ostringstream out; + out << BoolWrapperType(); + EXPECT_EQ(out.str(), BoolWrapperType::kName); + } + { + std::ostringstream out; + out << Type(BoolWrapperType()); + EXPECT_EQ(out.str(), BoolWrapperType::kName); + } +} + +TEST(BoolWrapperType, Hash) { + EXPECT_EQ(absl::HashOf(BoolWrapperType()), absl::HashOf(BoolWrapperType())); +} + +TEST(BoolWrapperType, Equal) { + EXPECT_EQ(BoolWrapperType(), BoolWrapperType()); + EXPECT_EQ(Type(BoolWrapperType()), BoolWrapperType()); + EXPECT_EQ(BoolWrapperType(), Type(BoolWrapperType())); + EXPECT_EQ(Type(BoolWrapperType()), Type(BoolWrapperType())); +} + +} // namespace +} // namespace cel diff --git a/common/types/bytes_type.h b/common/types/bytes_type.h new file mode 100644 index 000000000..eb56edb41 --- /dev/null +++ b/common/types/bytes_type.h @@ -0,0 +1,73 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/type.h" +// IWYU pragma: friend "common/type.h" + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_BYTES_TYPE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_TYPES_BYTES_TYPE_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "common/type_kind.h" + +namespace cel { + +class Type; +class TypeParameters; + +// `BoolType` represents the primitive `bytes` type. +class BytesType final { + public: + static constexpr TypeKind kKind = TypeKind::kBytes; + static constexpr absl::string_view kName = "bytes"; + + BytesType() = default; + BytesType(const BytesType&) = default; + BytesType(BytesType&&) = default; + BytesType& operator=(const BytesType&) = default; + BytesType& operator=(BytesType&&) = default; + + static TypeKind kind() { return kKind; } + + static absl::string_view name() { return kName; } + + static TypeParameters GetParameters(); + + static std::string DebugString() { return std::string(name()); } +}; + +inline constexpr bool operator==(BytesType, BytesType) { return true; } + +inline constexpr bool operator!=(BytesType lhs, BytesType rhs) { + return !operator==(lhs, rhs); +} + +template +H AbslHashValue(H state, BytesType) { + // BytesType is really a singleton and all instances are equal. Nothing to + // hash. + return std::move(state); +} + +inline std::ostream& operator<<(std::ostream& out, const BytesType& type) { + return out << type.DebugString(); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_BYTES_TYPE_H_ diff --git a/common/types/bytes_type_test.cc b/common/types/bytes_type_test.cc new file mode 100644 index 000000000..79346a34f --- /dev/null +++ b/common/types/bytes_type_test.cc @@ -0,0 +1,59 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "absl/hash/hash.h" +#include "common/type.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +TEST(BytesType, Kind) { + EXPECT_EQ(BytesType().kind(), BytesType::kKind); + EXPECT_EQ(Type(BytesType()).kind(), BytesType::kKind); +} + +TEST(BytesType, Name) { + EXPECT_EQ(BytesType().name(), BytesType::kName); + EXPECT_EQ(Type(BytesType()).name(), BytesType::kName); +} + +TEST(BytesType, DebugString) { + { + std::ostringstream out; + out << BytesType(); + EXPECT_EQ(out.str(), BytesType::kName); + } + { + std::ostringstream out; + out << Type(BytesType()); + EXPECT_EQ(out.str(), BytesType::kName); + } +} + +TEST(BytesType, Hash) { + EXPECT_EQ(absl::HashOf(BytesType()), absl::HashOf(BytesType())); +} + +TEST(BytesType, Equal) { + EXPECT_EQ(BytesType(), BytesType()); + EXPECT_EQ(Type(BytesType()), BytesType()); + EXPECT_EQ(BytesType(), Type(BytesType())); + EXPECT_EQ(Type(BytesType()), Type(BytesType())); +} + +} // namespace +} // namespace cel diff --git a/common/types/bytes_wrapper_type.h b/common/types/bytes_wrapper_type.h new file mode 100644 index 000000000..7360fba8b --- /dev/null +++ b/common/types/bytes_wrapper_type.h @@ -0,0 +1,79 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/type.h" +// IWYU pragma: friend "common/type.h" + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_BYTES_WRAPPER_TYPE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_TYPES_BYTES_WRAPPER_TYPE_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "common/type_kind.h" + +namespace cel { + +class Type; +class TypeParameters; + +// `BytesWrapperType` is a special type which has no direct value +// representation. It is used to represent `google.protobuf.BytesValue`, which +// never exists at runtime as a value. Its primary usage is for type checking +// and unpacking at runtime. +class BytesWrapperType final { + public: + static constexpr TypeKind kKind = TypeKind::kBytesWrapper; + static constexpr absl::string_view kName = "google.protobuf.BytesValue"; + + BytesWrapperType() = default; + BytesWrapperType(const BytesWrapperType&) = default; + BytesWrapperType(BytesWrapperType&&) = default; + BytesWrapperType& operator=(const BytesWrapperType&) = default; + BytesWrapperType& operator=(BytesWrapperType&&) = default; + + static TypeKind kind() { return kKind; } + + static absl::string_view name() { return kName; } + + static TypeParameters GetParameters(); + + static std::string DebugString() { return std::string(name()); } +}; + +inline constexpr bool operator==(BytesWrapperType, BytesWrapperType) { + return true; +} + +inline constexpr bool operator!=(BytesWrapperType lhs, BytesWrapperType rhs) { + return !operator==(lhs, rhs); +} + +template +H AbslHashValue(H state, BytesWrapperType) { + // BytesWrapperType is really a singleton and all instances are equal. Nothing + // to hash. + return std::move(state); +} + +inline std::ostream& operator<<(std::ostream& out, + const BytesWrapperType& type) { + return out << type.DebugString(); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_BYTES_WRAPPER_TYPE_H_ diff --git a/common/types/bytes_wrapper_type_test.cc b/common/types/bytes_wrapper_type_test.cc new file mode 100644 index 000000000..eb14a16ad --- /dev/null +++ b/common/types/bytes_wrapper_type_test.cc @@ -0,0 +1,59 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "absl/hash/hash.h" +#include "common/type.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +TEST(BytesWrapperType, Kind) { + EXPECT_EQ(BytesWrapperType().kind(), BytesWrapperType::kKind); + EXPECT_EQ(Type(BytesWrapperType()).kind(), BytesWrapperType::kKind); +} + +TEST(BytesWrapperType, Name) { + EXPECT_EQ(BytesWrapperType().name(), BytesWrapperType::kName); + EXPECT_EQ(Type(BytesWrapperType()).name(), BytesWrapperType::kName); +} + +TEST(BytesWrapperType, DebugString) { + { + std::ostringstream out; + out << BytesWrapperType(); + EXPECT_EQ(out.str(), BytesWrapperType::kName); + } + { + std::ostringstream out; + out << Type(BytesWrapperType()); + EXPECT_EQ(out.str(), BytesWrapperType::kName); + } +} + +TEST(BytesWrapperType, Hash) { + EXPECT_EQ(absl::HashOf(BytesWrapperType()), absl::HashOf(BytesWrapperType())); +} + +TEST(BytesWrapperType, Equal) { + EXPECT_EQ(BytesWrapperType(), BytesWrapperType()); + EXPECT_EQ(Type(BytesWrapperType()), BytesWrapperType()); + EXPECT_EQ(BytesWrapperType(), Type(BytesWrapperType())); + EXPECT_EQ(Type(BytesWrapperType()), Type(BytesWrapperType())); +} + +} // namespace +} // namespace cel diff --git a/common/types/double_type.h b/common/types/double_type.h new file mode 100644 index 000000000..73f904938 --- /dev/null +++ b/common/types/double_type.h @@ -0,0 +1,73 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/type.h" +// IWYU pragma: friend "common/type.h" + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_DOUBLE_TYPE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_TYPES_DOUBLE_TYPE_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "common/type_kind.h" + +namespace cel { + +class Type; +class TypeParameters; + +// `BoolType` represents the primitive `double` type. +class DoubleType final { + public: + static constexpr TypeKind kKind = TypeKind::kDouble; + static constexpr absl::string_view kName = "double"; + + DoubleType() = default; + DoubleType(const DoubleType&) = default; + DoubleType(DoubleType&&) = default; + DoubleType& operator=(const DoubleType&) = default; + DoubleType& operator=(DoubleType&&) = default; + + static TypeKind kind() { return kKind; } + + static absl::string_view name() { return kName; } + + static TypeParameters GetParameters(); + + static std::string DebugString() { return std::string(name()); } +}; + +inline constexpr bool operator==(DoubleType, DoubleType) { return true; } + +inline constexpr bool operator!=(DoubleType lhs, DoubleType rhs) { + return !operator==(lhs, rhs); +} + +template +H AbslHashValue(H state, DoubleType) { + // DoubleType is really a singleton and all instances are equal. Nothing to + // hash. + return std::move(state); +} + +inline std::ostream& operator<<(std::ostream& out, const DoubleType& type) { + return out << type.DebugString(); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_DOUBLE_TYPE_H_ diff --git a/common/types/double_type_test.cc b/common/types/double_type_test.cc new file mode 100644 index 000000000..9e708141e --- /dev/null +++ b/common/types/double_type_test.cc @@ -0,0 +1,59 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "absl/hash/hash.h" +#include "common/type.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +TEST(DoubleType, Kind) { + EXPECT_EQ(DoubleType().kind(), DoubleType::kKind); + EXPECT_EQ(Type(DoubleType()).kind(), DoubleType::kKind); +} + +TEST(DoubleType, Name) { + EXPECT_EQ(DoubleType().name(), DoubleType::kName); + EXPECT_EQ(Type(DoubleType()).name(), DoubleType::kName); +} + +TEST(DoubleType, DebugString) { + { + std::ostringstream out; + out << DoubleType(); + EXPECT_EQ(out.str(), DoubleType::kName); + } + { + std::ostringstream out; + out << Type(DoubleType()); + EXPECT_EQ(out.str(), DoubleType::kName); + } +} + +TEST(DoubleType, Hash) { + EXPECT_EQ(absl::HashOf(DoubleType()), absl::HashOf(DoubleType())); +} + +TEST(DoubleType, Equal) { + EXPECT_EQ(DoubleType(), DoubleType()); + EXPECT_EQ(Type(DoubleType()), DoubleType()); + EXPECT_EQ(DoubleType(), Type(DoubleType())); + EXPECT_EQ(Type(DoubleType()), Type(DoubleType())); +} + +} // namespace +} // namespace cel diff --git a/common/types/double_wrapper_type.h b/common/types/double_wrapper_type.h new file mode 100644 index 000000000..fabaf322e --- /dev/null +++ b/common/types/double_wrapper_type.h @@ -0,0 +1,79 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/type.h" +// IWYU pragma: friend "common/type.h" + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_DOUBLE_WRAPPER_TYPE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_TYPES_DOUBLE_WRAPPER_TYPE_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "common/type_kind.h" + +namespace cel { + +class Type; +class TypeParameters; + +// `DoubleWrapperType` is a special type which has no direct value +// representation. It is used to represent `google.protobuf.DoubleValue`, which +// never exists at runtime as a value. Its primary usage is for type checking +// and unpacking at runtime. +class DoubleWrapperType final { + public: + static constexpr TypeKind kKind = TypeKind::kDoubleWrapper; + static constexpr absl::string_view kName = "google.protobuf.DoubleValue"; + + DoubleWrapperType() = default; + DoubleWrapperType(const DoubleWrapperType&) = default; + DoubleWrapperType(DoubleWrapperType&&) = default; + DoubleWrapperType& operator=(const DoubleWrapperType&) = default; + DoubleWrapperType& operator=(DoubleWrapperType&&) = default; + + static TypeKind kind() { return kKind; } + + static absl::string_view name() { return kName; } + + static TypeParameters GetParameters(); + + static std::string DebugString() { return std::string(name()); } +}; + +inline constexpr bool operator==(DoubleWrapperType, DoubleWrapperType) { + return true; +} + +inline constexpr bool operator!=(DoubleWrapperType lhs, DoubleWrapperType rhs) { + return !operator==(lhs, rhs); +} + +template +H AbslHashValue(H state, DoubleWrapperType) { + // DoubleWrapperType is really a singleton and all instances are equal. + // Nothing to hash. + return std::move(state); +} + +inline std::ostream& operator<<(std::ostream& out, + const DoubleWrapperType& type) { + return out << type.DebugString(); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_DOUBLE_WRAPPER_TYPE_H_ diff --git a/common/types/double_wrapper_type_test.cc b/common/types/double_wrapper_type_test.cc new file mode 100644 index 000000000..9b9a53b53 --- /dev/null +++ b/common/types/double_wrapper_type_test.cc @@ -0,0 +1,60 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "absl/hash/hash.h" +#include "common/type.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +TEST(DoubleWrapperType, Kind) { + EXPECT_EQ(DoubleWrapperType().kind(), DoubleWrapperType::kKind); + EXPECT_EQ(Type(DoubleWrapperType()).kind(), DoubleWrapperType::kKind); +} + +TEST(DoubleWrapperType, Name) { + EXPECT_EQ(DoubleWrapperType().name(), DoubleWrapperType::kName); + EXPECT_EQ(Type(DoubleWrapperType()).name(), DoubleWrapperType::kName); +} + +TEST(DoubleWrapperType, DebugString) { + { + std::ostringstream out; + out << DoubleWrapperType(); + EXPECT_EQ(out.str(), DoubleWrapperType::kName); + } + { + std::ostringstream out; + out << Type(DoubleWrapperType()); + EXPECT_EQ(out.str(), DoubleWrapperType::kName); + } +} + +TEST(DoubleWrapperType, Hash) { + EXPECT_EQ(absl::HashOf(DoubleWrapperType()), + absl::HashOf(DoubleWrapperType())); +} + +TEST(DoubleWrapperType, Equal) { + EXPECT_EQ(DoubleWrapperType(), DoubleWrapperType()); + EXPECT_EQ(Type(DoubleWrapperType()), DoubleWrapperType()); + EXPECT_EQ(DoubleWrapperType(), Type(DoubleWrapperType())); + EXPECT_EQ(Type(DoubleWrapperType()), Type(DoubleWrapperType())); +} + +} // namespace +} // namespace cel diff --git a/common/types/duration_type.h b/common/types/duration_type.h new file mode 100644 index 000000000..8d98137bf --- /dev/null +++ b/common/types/duration_type.h @@ -0,0 +1,73 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/type.h" +// IWYU pragma: friend "common/type.h" + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_DURATION_TYPE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_TYPES_DURATION_TYPE_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "common/type_kind.h" + +namespace cel { + +class Type; +class TypeParameters; + +// `DurationType` represents the primitive `duration` type. +class DurationType final { + public: + static constexpr TypeKind kKind = TypeKind::kDuration; + static constexpr absl::string_view kName = "google.protobuf.Duration"; + + DurationType() = default; + DurationType(const DurationType&) = default; + DurationType(DurationType&&) = default; + DurationType& operator=(const DurationType&) = default; + DurationType& operator=(DurationType&&) = default; + + static TypeKind kind() { return kKind; } + + static absl::string_view name() { return kName; } + + static TypeParameters GetParameters(); + + static std::string DebugString() { return std::string(name()); } +}; + +inline constexpr bool operator==(DurationType, DurationType) { return true; } + +inline constexpr bool operator!=(DurationType lhs, DurationType rhs) { + return !operator==(lhs, rhs); +} + +template +H AbslHashValue(H state, DurationType) { + // DurationType is really a singleton and all instances are equal. + // Nothing to hash. + return std::move(state); +} + +inline std::ostream& operator<<(std::ostream& out, const DurationType& type) { + return out << type.DebugString(); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_DURATION_TYPE_H_ diff --git a/common/types/duration_type_test.cc b/common/types/duration_type_test.cc new file mode 100644 index 000000000..1a3b77d96 --- /dev/null +++ b/common/types/duration_type_test.cc @@ -0,0 +1,59 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "absl/hash/hash.h" +#include "common/type.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +TEST(DurationType, Kind) { + EXPECT_EQ(DurationType().kind(), DurationType::kKind); + EXPECT_EQ(Type(DurationType()).kind(), DurationType::kKind); +} + +TEST(DurationType, Name) { + EXPECT_EQ(DurationType().name(), DurationType::kName); + EXPECT_EQ(Type(DurationType()).name(), DurationType::kName); +} + +TEST(DurationType, DebugString) { + { + std::ostringstream out; + out << DurationType(); + EXPECT_EQ(out.str(), DurationType::kName); + } + { + std::ostringstream out; + out << Type(DurationType()); + EXPECT_EQ(out.str(), DurationType::kName); + } +} + +TEST(DurationType, Hash) { + EXPECT_EQ(absl::HashOf(DurationType()), absl::HashOf(DurationType())); +} + +TEST(DurationType, Equal) { + EXPECT_EQ(DurationType(), DurationType()); + EXPECT_EQ(Type(DurationType()), DurationType()); + EXPECT_EQ(DurationType(), Type(DurationType())); + EXPECT_EQ(Type(DurationType()), Type(DurationType())); +} + +} // namespace +} // namespace cel diff --git a/common/types/dyn_type.h b/common/types/dyn_type.h new file mode 100644 index 000000000..68545a22d --- /dev/null +++ b/common/types/dyn_type.h @@ -0,0 +1,73 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/type.h" +// IWYU pragma: friend "common/type.h" + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_DYN_TYPE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_TYPES_DYN_TYPE_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "common/type_kind.h" + +namespace cel { + +class Type; +class TypeParameters; + +// `DynType` is a special type which represents any type and has no direct value +// representation. +class DynType final { + public: + static constexpr TypeKind kKind = TypeKind::kDyn; + static constexpr absl::string_view kName = "dyn"; + + DynType() = default; + DynType(const DynType&) = default; + DynType(DynType&&) = default; + DynType& operator=(const DynType&) = default; + DynType& operator=(DynType&&) = default; + + static TypeKind kind() { return kKind; } + + static absl::string_view name() { return kName; } + + static TypeParameters GetParameters(); + + static std::string DebugString() { return std::string(name()); } +}; + +inline constexpr bool operator==(DynType, DynType) { return true; } + +inline constexpr bool operator!=(DynType lhs, DynType rhs) { + return !operator==(lhs, rhs); +} + +template +H AbslHashValue(H state, DynType) { + // DynType is really a singleton and all instances are equal. Nothing to hash. + return std::move(state); +} + +inline std::ostream& operator<<(std::ostream& out, const DynType& type) { + return out << type.DebugString(); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_DYN_TYPE_H_ diff --git a/common/types/dyn_type_test.cc b/common/types/dyn_type_test.cc new file mode 100644 index 000000000..acebead1c --- /dev/null +++ b/common/types/dyn_type_test.cc @@ -0,0 +1,59 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "absl/hash/hash.h" +#include "common/type.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +TEST(DynType, Kind) { + EXPECT_EQ(DynType().kind(), DynType::kKind); + EXPECT_EQ(Type(DynType()).kind(), DynType::kKind); +} + +TEST(DynType, Name) { + EXPECT_EQ(DynType().name(), DynType::kName); + EXPECT_EQ(Type(DynType()).name(), DynType::kName); +} + +TEST(DynType, DebugString) { + { + std::ostringstream out; + out << DynType(); + EXPECT_EQ(out.str(), DynType::kName); + } + { + std::ostringstream out; + out << Type(DynType()); + EXPECT_EQ(out.str(), DynType::kName); + } +} + +TEST(DynType, Hash) { + EXPECT_EQ(absl::HashOf(DynType()), absl::HashOf(DynType())); +} + +TEST(DynType, Equal) { + EXPECT_EQ(DynType(), DynType()); + EXPECT_EQ(Type(DynType()), DynType()); + EXPECT_EQ(DynType(), Type(DynType())); + EXPECT_EQ(Type(DynType()), Type(DynType())); +} + +} // namespace +} // namespace cel diff --git a/common/types/enum_type.cc b/common/types/enum_type.cc new file mode 100644 index 000000000..2e358b53c --- /dev/null +++ b/common/types/enum_type.cc @@ -0,0 +1,43 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "absl/base/nullability.h" +#include "absl/base/optimization.h" +#include "absl/strings/str_cat.h" +#include "common/type.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +using google::protobuf::EnumDescriptor; + +bool IsWellKnownEnumType(const EnumDescriptor* absl_nonnull descriptor) { + return descriptor->full_name() == "google.protobuf.NullValue"; +} + +std::string EnumType::DebugString() const { + if (ABSL_PREDICT_TRUE(static_cast(*this))) { + static_assert(sizeof(descriptor_) == 8 || sizeof(descriptor_) == 4, + "sizeof(void*) is neither 8 nor 4"); + return absl::StrCat(name(), "@0x", + absl::Hex(descriptor_, sizeof(descriptor_) == 8 + ? absl::PadSpec::kZeroPad16 + : absl::PadSpec::kZeroPad8)); + } + return std::string(); +} + +} // namespace cel diff --git a/common/types/enum_type.h b/common/types/enum_type.h new file mode 100644 index 000000000..60db1231d --- /dev/null +++ b/common/types/enum_type.h @@ -0,0 +1,128 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/type.h" +// IWYU pragma: friend "common/type.h" + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_ENUM_TYPE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_TYPES_ENUM_TYPE_H_ + +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/strings/string_view.h" +#include "common/type_kind.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +class Type; +class TypeParameters; + +bool IsWellKnownEnumType(const google::protobuf::EnumDescriptor* absl_nonnull descriptor); + +class EnumType final { + public: + using element_type = const google::protobuf::EnumDescriptor; + + static constexpr TypeKind kKind = TypeKind::kEnum; + + // Constructs `EnumType` from a pointer to `google::protobuf::EnumDescriptor`. The + // `google::protobuf::EnumDescriptor` must not be one of the well known enum types we + // treat specially, if it is behavior is undefined. If you are unsure, you + // should use `Type::Enum`. + explicit EnumType(const google::protobuf::EnumDescriptor* absl_nullable descriptor) + : descriptor_(descriptor) { + ABSL_DCHECK(descriptor == nullptr || !IsWellKnownEnumType(descriptor)) + << descriptor->full_name(); + } + + EnumType() = default; + EnumType(const EnumType&) = default; + EnumType(EnumType&&) = default; + EnumType& operator=(const EnumType&) = default; + EnumType& operator=(EnumType&&) = default; + + static TypeKind kind() { return kKind; } + + absl::string_view name() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return (*this)->full_name(); + } + + std::string DebugString() const; + + static TypeParameters GetParameters(); + + const google::protobuf::EnumDescriptor& operator*() const { + ABSL_DCHECK(*this); + return *descriptor_; + } + + const google::protobuf::EnumDescriptor* absl_nonnull operator->() const { + ABSL_DCHECK(*this); + return descriptor_; + } + + explicit operator bool() const { return descriptor_ != nullptr; } + + private: + friend struct std::pointer_traits; + + const google::protobuf::EnumDescriptor* absl_nullable descriptor_ = nullptr; +}; + +inline bool operator==(EnumType lhs, EnumType rhs) { + return static_cast(lhs) == static_cast(rhs) && + (!static_cast(lhs) || lhs.name() == rhs.name()); +} + +inline bool operator!=(EnumType lhs, EnumType rhs) { + return !operator==(lhs, rhs); +} + +template +H AbslHashValue(H state, EnumType enum_type) { + return H::combine(std::move(state), static_cast(enum_type) + ? enum_type.name() + : absl::string_view()); +} + +inline std::ostream& operator<<(std::ostream& out, EnumType type) { + return out << type.DebugString(); +} + +} // namespace cel + +namespace std { + +template <> +struct pointer_traits { + using pointer = cel::EnumType; + using element_type = typename cel::EnumType::element_type; + using difference_type = ptrdiff_t; + + static element_type* to_address(const pointer& p) noexcept { + return p.descriptor_; + } +}; + +} // namespace std + +#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_ENUM_TYPE_H_ diff --git a/common/types/enum_type_test.cc b/common/types/enum_type_test.cc new file mode 100644 index 000000000..907740738 --- /dev/null +++ b/common/types/enum_type_test.cc @@ -0,0 +1,66 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "google/protobuf/descriptor.pb.h" +#include "common/memory.h" +#include "common/type.h" +#include "common/type_kind.h" +#include "internal/testing.h" +#include "google/protobuf/descriptor.h" + +namespace cel { +namespace { + +using ::testing::Eq; +using ::testing::IsEmpty; +using ::testing::NotNull; +using ::testing::StartsWith; + +TEST(EnumType, Kind) { EXPECT_EQ(EnumType::kind(), TypeKind::kEnum); } + +TEST(EnumType, Default) { + EnumType type; + EXPECT_FALSE(type); + EXPECT_THAT(type.DebugString(), Eq("")); + EXPECT_EQ(type, EnumType()); +} + +TEST(EnumType, Descriptor) { + google::protobuf::DescriptorPool pool; + { + google::protobuf::FileDescriptorProto file_desc_proto; + file_desc_proto.set_syntax("proto3"); + file_desc_proto.set_package("test"); + file_desc_proto.set_name("test/enum.proto"); + auto* enum_desc = file_desc_proto.add_enum_type(); + enum_desc->set_name("Enum"); + auto* enum_value_desc = enum_desc->add_value(); + enum_value_desc->set_number(0); + enum_value_desc->set_name("VALUE"); + ASSERT_THAT(pool.BuildFile(file_desc_proto), NotNull()); + } + const google::protobuf::EnumDescriptor* desc = pool.FindEnumTypeByName("test.Enum"); + ASSERT_THAT(desc, NotNull()); + EnumType type(desc); + EXPECT_TRUE(type); + EXPECT_THAT(type.name(), Eq("test.Enum")); + EXPECT_THAT(type.DebugString(), StartsWith("test.Enum@0x")); + EXPECT_THAT(type.GetParameters(), IsEmpty()); + EXPECT_NE(type, EnumType()); + EXPECT_NE(EnumType(), type); + EXPECT_EQ(cel::to_address(type), desc); +} + +} // namespace +} // namespace cel diff --git a/common/types/error_type.h b/common/types/error_type.h new file mode 100644 index 000000000..fdbf5fb36 --- /dev/null +++ b/common/types/error_type.h @@ -0,0 +1,75 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/type.h" +// IWYU pragma: friend "common/type.h" + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_ERROR_TYPE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_TYPES_ERROR_TYPE_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "common/type_kind.h" + +namespace cel { + +class Type; +class TypeParameters; + +// `ErrorType` is a special type which represents an error during type checking +// or an error value at runtime. See +// https://github.com/google/cel-spec/blob/master/doc/langdef.md#runtime-errors. +class ErrorType final { + public: + static constexpr TypeKind kKind = TypeKind::kError; + static constexpr absl::string_view kName = "*error*"; + + ErrorType() = default; + ErrorType(const ErrorType&) = default; + ErrorType(ErrorType&&) = default; + ErrorType& operator=(const ErrorType&) = default; + ErrorType& operator=(ErrorType&&) = default; + + static TypeKind kind() { return kKind; } + + static absl::string_view name() { return kName; } + + static TypeParameters GetParameters(); + + static std::string DebugString() { return std::string(name()); } +}; + +inline constexpr bool operator==(ErrorType, ErrorType) { return true; } + +inline constexpr bool operator!=(ErrorType lhs, ErrorType rhs) { + return !operator==(lhs, rhs); +} + +template +H AbslHashValue(H state, ErrorType) { + // ErrorType is really a singleton and all instances are equal. Nothing to + // hash. + return std::move(state); +} + +inline std::ostream& operator<<(std::ostream& out, const ErrorType& type) { + return out << type.DebugString(); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_ERROR_TYPE_H_ diff --git a/common/types/error_type_test.cc b/common/types/error_type_test.cc new file mode 100644 index 000000000..f48c2966b --- /dev/null +++ b/common/types/error_type_test.cc @@ -0,0 +1,59 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "absl/hash/hash.h" +#include "common/type.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +TEST(ErrorType, Kind) { + EXPECT_EQ(ErrorType().kind(), ErrorType::kKind); + EXPECT_EQ(Type(ErrorType()).kind(), ErrorType::kKind); +} + +TEST(ErrorType, Name) { + EXPECT_EQ(ErrorType().name(), ErrorType::kName); + EXPECT_EQ(Type(ErrorType()).name(), ErrorType::kName); +} + +TEST(ErrorType, DebugString) { + { + std::ostringstream out; + out << ErrorType(); + EXPECT_EQ(out.str(), ErrorType::kName); + } + { + std::ostringstream out; + out << Type(ErrorType()); + EXPECT_EQ(out.str(), ErrorType::kName); + } +} + +TEST(ErrorType, Hash) { + EXPECT_EQ(absl::HashOf(ErrorType()), absl::HashOf(ErrorType())); +} + +TEST(ErrorType, Equal) { + EXPECT_EQ(ErrorType(), ErrorType()); + EXPECT_EQ(Type(ErrorType()), ErrorType()); + EXPECT_EQ(ErrorType(), Type(ErrorType())); + EXPECT_EQ(Type(ErrorType()), Type(ErrorType())); +} + +} // namespace +} // namespace cel diff --git a/common/types/function_type.cc b/common/types/function_type.cc new file mode 100644 index 000000000..2e632b9cb --- /dev/null +++ b/common/types/function_type.cc @@ -0,0 +1,89 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "common/type.h" +#include "google/protobuf/arena.h" + +namespace cel { + +namespace { + +struct TypeFormatter { + void operator()(std::string* out, const Type& type) const { + out->append(type.DebugString()); + } +}; + +std::string FunctionDebugString(const Type& result, + absl::Span args) { + return absl::StrCat("(", absl::StrJoin(args, ", ", TypeFormatter{}), ") -> ", + result.DebugString()); +} + +} // namespace + +namespace common_internal { + +FunctionTypeData* absl_nonnull FunctionTypeData::Create( + google::protobuf::Arena* absl_nonnull arena, const Type& result, + absl::Span args) { + return ::new (arena->AllocateAligned( + offsetof(FunctionTypeData, args) + ((1 + args.size()) * sizeof(Type)), + alignof(FunctionTypeData))) FunctionTypeData(result, args); +} + +FunctionTypeData::FunctionTypeData(const Type& result, + absl::Span args) + : args_size(1 + args.size()) { + this->args[0] = result; + std::memcpy(this->args + 1, args.data(), args.size() * sizeof(Type)); +} + +} // namespace common_internal + +FunctionType::FunctionType(google::protobuf::Arena* absl_nonnull arena, + const Type& result, absl::Span args) + : FunctionType( + common_internal::FunctionTypeData::Create(arena, result, args)) {} + +std::string FunctionType::DebugString() const { + return FunctionDebugString(result(), args()); +} + +TypeParameters FunctionType::GetParameters() const { + ABSL_DCHECK(*this); + return TypeParameters(absl::MakeConstSpan(data_->args, data_->args_size)); +} + +const Type& FunctionType::result() const { + ABSL_DCHECK(*this); + return data_->args[0]; +} + +absl::Span FunctionType::args() const { + ABSL_DCHECK(*this); + return absl::MakeConstSpan(data_->args + 1, data_->args_size - 1); +} + +} // namespace cel diff --git a/common/types/function_type.h b/common/types/function_type.h new file mode 100644 index 000000000..a71c412aa --- /dev/null +++ b/common/types/function_type.h @@ -0,0 +1,91 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/type.h" +// IWYU pragma: friend "common/type.h" + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_FUNCTION_TYPE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_TYPES_FUNCTION_TYPE_H_ + +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "common/type_kind.h" +#include "google/protobuf/arena.h" + +namespace cel { + +class Type; +class TypeParameters; + +namespace common_internal { +struct FunctionTypeData; +} // namespace common_internal + +class FunctionType final { + public: + static constexpr TypeKind kKind = TypeKind::kFunction; + static constexpr absl::string_view kName = "function"; + + FunctionType(google::protobuf::Arena* absl_nonnull arena, const Type& result, + absl::Span args); + + FunctionType() = default; + FunctionType(const FunctionType&) = default; + FunctionType(FunctionType&&) = default; + FunctionType& operator=(const FunctionType&) = default; + FunctionType& operator=(FunctionType&&) = default; + + static TypeKind kind() { return kKind; } + + static absl::string_view name() { return kName; } + + TypeParameters GetParameters() const ABSL_ATTRIBUTE_LIFETIME_BOUND; + + std::string DebugString() const; + + const Type& result() const ABSL_ATTRIBUTE_LIFETIME_BOUND; + + absl::Span args() const ABSL_ATTRIBUTE_LIFETIME_BOUND; + + explicit operator bool() const { return data_ != nullptr; } + + private: + explicit FunctionType( + const common_internal::FunctionTypeData* absl_nullable data) + : data_(data) {} + + const common_internal::FunctionTypeData* absl_nullable data_ = nullptr; +}; + +bool operator==(const FunctionType& lhs, const FunctionType& rhs); + +inline bool operator!=(const FunctionType& lhs, const FunctionType& rhs) { + return !operator==(lhs, rhs); +} + +template +H AbslHashValue(H state, const FunctionType& type); + +inline std::ostream& operator<<(std::ostream& out, const FunctionType& type) { + return out << type.DebugString(); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_FUNCTION_TYPE_H_ diff --git a/common/types/function_type_pool.cc b/common/types/function_type_pool.cc new file mode 100644 index 000000000..451fa0647 --- /dev/null +++ b/common/types/function_type_pool.cc @@ -0,0 +1,29 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/types/function_type_pool.h" + +#include "absl/types/span.h" +#include "common/type.h" + +namespace cel::common_internal { + +FunctionType FunctionTypePool::InternFunctionType(const Type& result, + absl::Span args) { + return *function_types_.lazy_emplace( + AsTuple(result, args), + [&](const auto& ctor) { ctor(FunctionType(arena_, result, args)); }); +} + +} // namespace cel::common_internal diff --git a/common/types/function_type_pool.h b/common/types/function_type_pool.h new file mode 100644 index 000000000..8cac333da --- /dev/null +++ b/common/types/function_type_pool.h @@ -0,0 +1,102 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_FUNCTION_TYPE_POOL_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_TYPES_FUNCTION_TYPE_POOL_H_ + +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_set.h" +#include "absl/hash/hash.h" +#include "absl/log/die_if_null.h" +#include "absl/types/span.h" +#include "common/type.h" +#include "google/protobuf/arena.h" + +namespace cel::common_internal { + +// `FunctionTypePool` is a thread unsafe interning factory for `FunctionType`. +class FunctionTypePool final { + public: + explicit FunctionTypePool(google::protobuf::Arena* absl_nonnull arena) + : arena_(ABSL_DIE_IF_NULL(arena)) {} // Crash OK + + // Returns a `FunctionType` which has the provided parameters, interning as + // necessary. + FunctionType InternFunctionType(const Type& result, + absl::Span args); + + private: + using FunctionTypeTuple = + std::tuple, absl::Span>; + + static FunctionTypeTuple AsTuple(const FunctionType& function_type) { + return AsTuple(function_type.result(), function_type.args()); + } + + static FunctionTypeTuple AsTuple(const Type& result, + absl::Span args) { + return FunctionTypeTuple{std::cref(result), args}; + } + + struct Hasher { + using is_transparent = void; + + size_t operator()(const FunctionType& data) const { + return (*this)(AsTuple(data)); + } + + size_t operator()(const FunctionTypeTuple& tuple) const { + return absl::Hash{}(tuple); + } + }; + + struct Equaler { + using is_transparent = void; + + bool operator()(const FunctionType& lhs, const FunctionType& rhs) const { + return (*this)(AsTuple(lhs), AsTuple(rhs)); + } + + bool operator()(const FunctionType& lhs, + const FunctionTypeTuple& rhs) const { + return (*this)(AsTuple(lhs), rhs); + } + + bool operator()(const FunctionTypeTuple& lhs, + const FunctionType& rhs) const { + return (*this)(lhs, AsTuple(rhs)); + } + + bool operator()(const FunctionTypeTuple& lhs, + const FunctionTypeTuple& rhs) const { + return std::get<0>(lhs) == std::get<0>(rhs) && + absl::c_equal(std::get<1>(lhs), std::get<1>(rhs)); + } + }; + + google::protobuf::Arena* absl_nonnull const arena_; + absl::flat_hash_set function_types_; +}; + +} // namespace cel::common_internal + +#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_FUNCTION_TYPE_POOL_H_ diff --git a/common/types/function_type_test.cc b/common/types/function_type_test.cc new file mode 100644 index 000000000..57aee1785 --- /dev/null +++ b/common/types/function_type_test.cc @@ -0,0 +1,73 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "absl/hash/hash.h" +#include "common/type.h" +#include "internal/testing.h" +#include "google/protobuf/arena.h" + +namespace cel { +namespace { + +TEST(FunctionType, Kind) { + google::protobuf::Arena arena; + EXPECT_EQ(FunctionType(&arena, DynType{}, {BytesType()}).kind(), + FunctionType::kKind); + EXPECT_EQ(Type(FunctionType(&arena, DynType{}, {BytesType()})).kind(), + FunctionType::kKind); +} + +TEST(FunctionType, Name) { + google::protobuf::Arena arena; + EXPECT_EQ(FunctionType(&arena, DynType{}, {BytesType()}).name(), "function"); + EXPECT_EQ(Type(FunctionType(&arena, DynType{}, {BytesType()})).name(), + "function"); +} + +TEST(FunctionType, DebugString) { + google::protobuf::Arena arena; + { + std::ostringstream out; + out << FunctionType(&arena, DynType{}, {BytesType()}); + EXPECT_EQ(out.str(), "(bytes) -> dyn"); + } + { + std::ostringstream out; + out << Type(FunctionType(&arena, DynType{}, {BytesType()})); + EXPECT_EQ(out.str(), "(bytes) -> dyn"); + } +} + +TEST(FunctionType, Hash) { + google::protobuf::Arena arena; + EXPECT_EQ(absl::HashOf(FunctionType(&arena, DynType{}, {BytesType()})), + absl::HashOf(FunctionType(&arena, DynType{}, {BytesType()}))); +} + +TEST(FunctionType, Equal) { + google::protobuf::Arena arena; + EXPECT_EQ(FunctionType(&arena, DynType{}, {BytesType()}), + FunctionType(&arena, DynType{}, {BytesType()})); + EXPECT_EQ(Type(FunctionType(&arena, DynType{}, {BytesType()})), + FunctionType(&arena, DynType{}, {BytesType()})); + EXPECT_EQ(FunctionType(&arena, DynType{}, {BytesType()}), + Type(FunctionType(&arena, DynType{}, {BytesType()}))); + EXPECT_EQ(Type(FunctionType(&arena, DynType{}, {BytesType()})), + Type(FunctionType(&arena, DynType{}, {BytesType()}))); +} + +} // namespace +} // namespace cel diff --git a/common/types/int_type.h b/common/types/int_type.h new file mode 100644 index 000000000..dfa4491c4 --- /dev/null +++ b/common/types/int_type.h @@ -0,0 +1,72 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/type.h" +// IWYU pragma: friend "common/type.h" + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_INT_TYPE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_TYPES_INT_TYPE_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "common/type_kind.h" + +namespace cel { + +class Type; +class TypeParameters; + +// `IntType` represents the primitive `int` type. +class IntType final { + public: + static constexpr TypeKind kKind = TypeKind::kInt; + static constexpr absl::string_view kName = "int"; + + IntType() = default; + IntType(const IntType&) = default; + IntType(IntType&&) = default; + IntType& operator=(const IntType&) = default; + IntType& operator=(IntType&&) = default; + + static TypeKind kind() { return kKind; } + + static absl::string_view name() { return kName; } + + static TypeParameters GetParameters(); + + static std::string DebugString() { return std::string(name()); } +}; + +inline constexpr bool operator==(IntType, IntType) { return true; } + +inline constexpr bool operator!=(IntType lhs, IntType rhs) { + return !operator==(lhs, rhs); +} + +template +H AbslHashValue(H state, IntType) { + // IntType is really a singleton and all instances are equal. Nothing to hash. + return std::move(state); +} + +inline std::ostream& operator<<(std::ostream& out, const IntType& type) { + return out << type.DebugString(); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_INT_TYPE_H_ diff --git a/common/types/int_type_test.cc b/common/types/int_type_test.cc new file mode 100644 index 000000000..98e019491 --- /dev/null +++ b/common/types/int_type_test.cc @@ -0,0 +1,59 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "absl/hash/hash.h" +#include "common/type.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +TEST(IntType, Kind) { + EXPECT_EQ(IntType().kind(), IntType::kKind); + EXPECT_EQ(Type(IntType()).kind(), IntType::kKind); +} + +TEST(IntType, Name) { + EXPECT_EQ(IntType().name(), IntType::kName); + EXPECT_EQ(Type(IntType()).name(), IntType::kName); +} + +TEST(IntType, DebugString) { + { + std::ostringstream out; + out << IntType(); + EXPECT_EQ(out.str(), IntType::kName); + } + { + std::ostringstream out; + out << Type(IntType()); + EXPECT_EQ(out.str(), IntType::kName); + } +} + +TEST(IntType, Hash) { + EXPECT_EQ(absl::HashOf(IntType()), absl::HashOf(IntType())); +} + +TEST(IntType, Equal) { + EXPECT_EQ(IntType(), IntType()); + EXPECT_EQ(Type(IntType()), IntType()); + EXPECT_EQ(IntType(), Type(IntType())); + EXPECT_EQ(Type(IntType()), Type(IntType())); +} + +} // namespace +} // namespace cel diff --git a/common/types/int_wrapper_type.h b/common/types/int_wrapper_type.h new file mode 100644 index 000000000..6e954b902 --- /dev/null +++ b/common/types/int_wrapper_type.h @@ -0,0 +1,78 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/type.h" +// IWYU pragma: friend "common/type.h" + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_INT_WRAPPER_TYPE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_TYPES_INT_WRAPPER_TYPE_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "common/type_kind.h" + +namespace cel { + +class Type; +class TypeParameters; + +// `IntWrapperType` is a special type which has no direct value +// representation. It is used to represent `google.protobuf.Int64Value`, which +// never exists at runtime as a value. Its primary usage is for type checking +// and unpacking at runtime. +class IntWrapperType final { + public: + static constexpr TypeKind kKind = TypeKind::kIntWrapper; + static constexpr absl::string_view kName = "google.protobuf.Int64Value"; + + IntWrapperType() = default; + IntWrapperType(const IntWrapperType&) = default; + IntWrapperType(IntWrapperType&&) = default; + IntWrapperType& operator=(const IntWrapperType&) = default; + IntWrapperType& operator=(IntWrapperType&&) = default; + + static TypeKind kind() { return kKind; } + + static absl::string_view name() { return kName; } + + static TypeParameters GetParameters(); + + static std::string DebugString() { return std::string(name()); } +}; + +inline constexpr bool operator==(IntWrapperType, IntWrapperType) { + return true; +} + +inline constexpr bool operator!=(IntWrapperType lhs, IntWrapperType rhs) { + return !operator==(lhs, rhs); +} + +template +H AbslHashValue(H state, IntWrapperType) { + // IntWrapperType is really a singleton and all instances are equal. Nothing + // to hash. + return std::move(state); +} + +inline std::ostream& operator<<(std::ostream& out, const IntWrapperType& type) { + return out << type.DebugString(); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_INT_WRAPPER_TYPE_H_ diff --git a/common/types/int_wrapper_type_test.cc b/common/types/int_wrapper_type_test.cc new file mode 100644 index 000000000..d95715405 --- /dev/null +++ b/common/types/int_wrapper_type_test.cc @@ -0,0 +1,59 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "absl/hash/hash.h" +#include "common/type.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +TEST(IntWrapperType, Kind) { + EXPECT_EQ(IntWrapperType().kind(), IntWrapperType::kKind); + EXPECT_EQ(Type(IntWrapperType()).kind(), IntWrapperType::kKind); +} + +TEST(IntWrapperType, Name) { + EXPECT_EQ(IntWrapperType().name(), IntWrapperType::kName); + EXPECT_EQ(Type(IntWrapperType()).name(), IntWrapperType::kName); +} + +TEST(IntWrapperType, DebugString) { + { + std::ostringstream out; + out << IntWrapperType(); + EXPECT_EQ(out.str(), IntWrapperType::kName); + } + { + std::ostringstream out; + out << Type(IntWrapperType()); + EXPECT_EQ(out.str(), IntWrapperType::kName); + } +} + +TEST(IntWrapperType, Hash) { + EXPECT_EQ(absl::HashOf(IntWrapperType()), absl::HashOf(IntWrapperType())); +} + +TEST(IntWrapperType, Equal) { + EXPECT_EQ(IntWrapperType(), IntWrapperType()); + EXPECT_EQ(Type(IntWrapperType()), IntWrapperType()); + EXPECT_EQ(IntWrapperType(), Type(IntWrapperType())); + EXPECT_EQ(Type(IntWrapperType()), Type(IntWrapperType())); +} + +} // namespace +} // namespace cel diff --git a/common/types/legacy_type_introspector.h b/common/types/legacy_type_introspector.h new file mode 100644 index 000000000..37118b685 --- /dev/null +++ b/common/types/legacy_type_introspector.h @@ -0,0 +1,34 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_LEGACY_TYPE_INTROSPECTOR_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_TYPES_LEGACY_TYPE_INTROSPECTOR_H_ + +#include "common/type_introspector.h" + +namespace cel::common_internal { + +// `LegacyTypeIntrospector` is an implementation which should be used when +// converting between `cel::Value` and `google::api::expr::runtime::CelValue` +// and only then. +class LegacyTypeIntrospector : public virtual TypeIntrospector { + public: + LegacyTypeIntrospector() = default; +}; + +} // namespace cel::common_internal + +#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_LEGACY_TYPE_INTROSPECTOR_H_ diff --git a/common/types/list_type.cc b/common/types/list_type.cc new file mode 100644 index 000000000..118ea15b0 --- /dev/null +++ b/common/types/list_type.cc @@ -0,0 +1,77 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/strings/str_cat.h" +#include "common/type.h" +#include "common/type_kind.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +namespace common_internal { + +namespace { + +ABSL_CONST_INIT const ListTypeData kDynListTypeData; + +} // namespace + +ListTypeData* absl_nonnull ListTypeData::Create( + google::protobuf::Arena* absl_nonnull arena, const Type& element) { + return ::new (arena->AllocateAligned( + sizeof(ListTypeData), alignof(ListTypeData))) ListTypeData(element); +} + +ListTypeData::ListTypeData(const Type& element) : element(element) {} + +} // namespace common_internal + +ListType::ListType() : ListType(&common_internal::kDynListTypeData) {} + +ListType::ListType(google::protobuf::Arena* absl_nonnull arena, const Type& element) + : ListType(element.IsDyn() + ? &common_internal::kDynListTypeData + : common_internal::ListTypeData::Create(arena, element)) {} + +std::string ListType::DebugString() const { + return absl::StrCat("list<", TypeKindToString(GetElement().kind()), ">"); +} + +TypeParameters ListType::GetParameters() const { + return TypeParameters(GetElement()); +} + +Type ListType::GetElement() const { + ABSL_DCHECK_NE(data_, 0); + if ((data_ & kBasicBit) == kBasicBit) { + return reinterpret_cast(data_ & + kPointerMask) + ->element; + } + if ((data_ & kProtoBit) == kProtoBit) { + return common_internal::SingularMessageFieldType( + reinterpret_cast(data_ & kPointerMask)); + } + return Type(); +} + +Type ListType::element() const { return GetElement(); } + +} // namespace cel diff --git a/common/types/list_type.h b/common/types/list_type.h new file mode 100644 index 000000000..b42994d91 --- /dev/null +++ b/common/types/list_type.h @@ -0,0 +1,115 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/type.h" +// IWYU pragma: friend "common/type.h" + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_LIST_TYPE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_TYPES_LIST_TYPE_H_ + +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/numeric/bits.h" +#include "absl/strings/string_view.h" +#include "common/type_kind.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +class Type; +class TypeParameters; + +namespace common_internal { +struct ListTypeData; +} // namespace common_internal + +class ListType final { + private: + static constexpr uintptr_t kBasicBit = 1; + static constexpr uintptr_t kProtoBit = 2; + static constexpr uintptr_t kBits = kBasicBit | kProtoBit; + static constexpr uintptr_t kPointerMask = ~kBits; + + public: + static constexpr TypeKind kKind = TypeKind::kList; + static constexpr absl::string_view kName = "list"; + + ListType(google::protobuf::Arena* absl_nonnull arena, const Type& element); + + // By default, this type is `list(dyn)`. Unless you can help it, you should + // use a more specific list type. + ListType(); + ListType(const ListType&) = default; + ListType(ListType&&) = default; + ListType& operator=(const ListType&) = default; + ListType& operator=(ListType&&) = default; + + static TypeKind kind() { return kKind; } + + static absl::string_view name() { return kName; } + + std::string DebugString() const; + + TypeParameters GetParameters() const ABSL_ATTRIBUTE_LIFETIME_BOUND; + + ABSL_DEPRECATED("Use GetElement") + Type element() const; + + Type GetElement() const; + + private: + friend class Type; + + explicit ListType(const common_internal::ListTypeData* absl_nonnull data) + : data_(reinterpret_cast(data) | kBasicBit) { + ABSL_DCHECK_GE(absl::countr_zero(reinterpret_cast(data)), 2) + << "alignment must be greater than 2"; + } + + explicit ListType(const google::protobuf::FieldDescriptor* absl_nonnull descriptor) + : data_(reinterpret_cast(descriptor) | kProtoBit) { + ABSL_DCHECK_GE(absl::countr_zero(reinterpret_cast(descriptor)), + 2) + << "alignment must be greater than 2"; + ABSL_DCHECK(descriptor->is_repeated()); + ABSL_DCHECK(!descriptor->is_map()); + } + + uintptr_t data_; +}; + +bool operator==(const ListType& lhs, const ListType& rhs); + +inline bool operator!=(const ListType& lhs, const ListType& rhs) { + return !operator==(lhs, rhs); +} + +template +H AbslHashValue(H state, const ListType& type); + +inline std::ostream& operator<<(std::ostream& out, const ListType& type) { + return out << type.DebugString(); +} + +inline ListType JsonListType() { return ListType(); } + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_LIST_TYPE_H_ diff --git a/common/types/list_type_pool.cc b/common/types/list_type_pool.cc new file mode 100644 index 000000000..c76998ee5 --- /dev/null +++ b/common/types/list_type_pool.cc @@ -0,0 +1,29 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/types/list_type_pool.h" + +#include "common/type.h" + +namespace cel::common_internal { + +ListType ListTypePool::InternListType(const Type& element) { + if (element.IsDyn()) { + return ListType(); + } + return *list_types_.lazy_emplace( + element, [&](const auto& ctor) { ctor(ListType(arena_, element)); }); +} + +} // namespace cel::common_internal diff --git a/common/types/list_type_pool.h b/common/types/list_type_pool.h new file mode 100644 index 000000000..120627424 --- /dev/null +++ b/common/types/list_type_pool.h @@ -0,0 +1,80 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_LIST_TYPE_POOL_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_TYPES_LIST_TYPE_POOL_H_ + +#include + +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_set.h" +#include "absl/hash/hash.h" +#include "absl/log/die_if_null.h" +#include "common/type.h" +#include "google/protobuf/arena.h" + +namespace cel::common_internal { + +// `ListTypePool` is a thread unsafe interning factory for `ListType`. +class ListTypePool final { + public: + explicit ListTypePool(google::protobuf::Arena* absl_nonnull arena) + : arena_(ABSL_DIE_IF_NULL(arena)) {} // Crash OK + + // Returns a `ListType` which has the provided parameters, interning as + // necessary. + ListType InternListType(const Type& element); + + private: + struct Hasher { + using is_transparent = void; + + size_t operator()(const ListType& list_type) const { + return (*this)(list_type.element()); + } + + size_t operator()(const Type& type) const { + return absl::Hash{}(type); + } + }; + + struct Equaler { + using is_transparent = void; + + bool operator()(const ListType& lhs, const ListType& rhs) const { + return (*this)(lhs.element(), rhs.element()); + } + + bool operator()(const ListType& lhs, const Type& rhs) const { + return (*this)(lhs.element(), rhs); + } + + bool operator()(const Type& lhs, const ListType& rhs) const { + return (*this)(lhs, rhs.element()); + } + + bool operator()(const Type& lhs, const Type& rhs) const { + return lhs == rhs; + } + }; + + google::protobuf::Arena* absl_nonnull const arena_; + absl::flat_hash_set list_types_; +}; + +} // namespace cel::common_internal + +#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_LIST_TYPE_POOL_H_ diff --git a/common/types/list_type_test.cc b/common/types/list_type_test.cc new file mode 100644 index 000000000..db40b1ff2 --- /dev/null +++ b/common/types/list_type_test.cc @@ -0,0 +1,72 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "absl/hash/hash.h" +#include "common/type.h" +#include "internal/testing.h" +#include "google/protobuf/arena.h" + +namespace cel { +namespace { + +TEST(ListType, Default) { + ListType list_type; + EXPECT_EQ(list_type.element(), DynType()); +} + +TEST(ListType, Kind) { + google::protobuf::Arena arena; + EXPECT_EQ(ListType(&arena, BoolType()).kind(), ListType::kKind); + EXPECT_EQ(Type(ListType(&arena, BoolType())).kind(), ListType::kKind); +} + +TEST(ListType, Name) { + google::protobuf::Arena arena; + EXPECT_EQ(ListType(&arena, BoolType()).name(), ListType::kName); + EXPECT_EQ(Type(ListType(&arena, BoolType())).name(), ListType::kName); +} + +TEST(ListType, DebugString) { + google::protobuf::Arena arena; + { + std::ostringstream out; + out << ListType(&arena, BoolType()); + EXPECT_EQ(out.str(), "list"); + } + { + std::ostringstream out; + out << Type(ListType(&arena, BoolType())); + EXPECT_EQ(out.str(), "list"); + } +} + +TEST(ListType, Hash) { + google::protobuf::Arena arena; + EXPECT_EQ(absl::HashOf(ListType(&arena, BoolType())), + absl::HashOf(ListType(&arena, BoolType()))); +} + +TEST(ListType, Equal) { + google::protobuf::Arena arena; + EXPECT_EQ(ListType(&arena, BoolType()), ListType(&arena, BoolType())); + EXPECT_EQ(Type(ListType(&arena, BoolType())), ListType(&arena, BoolType())); + EXPECT_EQ(ListType(&arena, BoolType()), Type(ListType(&arena, BoolType()))); + EXPECT_EQ(Type(ListType(&arena, BoolType())), + Type(ListType(&arena, BoolType()))); +} + +} // namespace +} // namespace cel diff --git a/common/types/map_type.cc b/common/types/map_type.cc new file mode 100644 index 000000000..bd294fc26 --- /dev/null +++ b/common/types/map_type.cc @@ -0,0 +1,122 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/strings/str_cat.h" +#include "common/type.h" +#include "common/type_kind.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +namespace common_internal { + +namespace { + +ABSL_CONST_INIT const MapTypeData kDynDynMapTypeData = { + .key_and_value = {DynType(), DynType()}, +}; + +ABSL_CONST_INIT const MapTypeData kStringDynMapTypeData = { + .key_and_value = {StringType(), DynType()}, +}; + +} // namespace + +MapTypeData* absl_nonnull MapTypeData::Create(google::protobuf::Arena* absl_nonnull arena, + const Type& key, + const Type& value) { + MapTypeData* data = + ::new (arena->AllocateAligned(sizeof(MapTypeData), alignof(MapTypeData))) + MapTypeData; + data->key_and_value[0] = key; + data->key_and_value[1] = value; + return data; +} + +} // namespace common_internal + +MapType::MapType() : MapType(&common_internal::kDynDynMapTypeData) {} + +MapType::MapType(google::protobuf::Arena* absl_nonnull arena, const Type& key, + const Type& value) + : MapType(key.IsDyn() && value.IsDyn() + ? &common_internal::kDynDynMapTypeData + : common_internal::MapTypeData::Create(arena, key, value)) {} + +std::string MapType::DebugString() const { + return absl::StrCat("map<", TypeKindToString(key().kind()), ", ", + TypeKindToString(value().kind()), ">"); +} + +TypeParameters MapType::GetParameters() const { + ABSL_DCHECK_NE(data_, 0); + if ((data_ & kBasicBit) == kBasicBit) { + const auto* data = reinterpret_cast( + data_ & kPointerMask); + return TypeParameters(data->key_and_value[0], data->key_and_value[1]); + } + if ((data_ & kProtoBit) == kProtoBit) { + const auto* descriptor = + reinterpret_cast(data_ & kPointerMask); + return TypeParameters(Type::Field(descriptor->map_key()), + Type::Field(descriptor->map_value())); + } + return TypeParameters(Type(), Type()); +} + +Type MapType::GetKey() const { + ABSL_DCHECK_NE(data_, 0); + if ((data_ & kBasicBit) == kBasicBit) { + return reinterpret_cast(data_ & + kPointerMask) + ->key_and_value[0]; + } + if ((data_ & kProtoBit) == kProtoBit) { + return Type::Field( + reinterpret_cast(data_ & kPointerMask) + ->map_key()); + } + return Type(); +} + +Type MapType::key() const { return GetKey(); } + +Type MapType::GetValue() const { + ABSL_DCHECK_NE(data_, 0); + if ((data_ & kBasicBit) == kBasicBit) { + return reinterpret_cast(data_ & + kPointerMask) + ->key_and_value[1]; + } + if ((data_ & kProtoBit) == kProtoBit) { + return Type::Field( + reinterpret_cast(data_ & kPointerMask) + ->map_value()); + } + return Type(); +} + +Type MapType::value() const { return GetValue(); } + +MapType JsonMapType() { + return MapType(&common_internal::kStringDynMapTypeData); +} + +} // namespace cel diff --git a/common/types/map_type.h b/common/types/map_type.h new file mode 100644 index 000000000..1c198f991 --- /dev/null +++ b/common/types/map_type.h @@ -0,0 +1,124 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/type.h" +// IWYU pragma: friend "common/type.h" + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_MAP_TYPE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_TYPES_MAP_TYPE_H_ + +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/numeric/bits.h" +#include "absl/strings/string_view.h" +#include "common/type_kind.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +class Type; +class TypeParameters; + +namespace common_internal { +struct MapTypeData; +} // namespace common_internal + +class MapType; + +MapType JsonMapType(); + +class MapType final { + private: + static constexpr uintptr_t kBasicBit = 1; + static constexpr uintptr_t kProtoBit = 2; + static constexpr uintptr_t kBits = kBasicBit | kProtoBit; + static constexpr uintptr_t kPointerMask = ~kBits; + + public: + static constexpr TypeKind kKind = TypeKind::kMap; + static constexpr absl::string_view kName = "map"; + + MapType(google::protobuf::Arena* absl_nonnull arena, const Type& key, + const Type& value); + + // By default, this type is `map(dyn, dyn)`. Unless you can help it, you + // should use a more specific map type. + MapType(); + MapType(const MapType&) = default; + MapType(MapType&&) = default; + MapType& operator=(const MapType&) = default; + MapType& operator=(MapType&&) = default; + + static TypeKind kind() { return kKind; } + + static absl::string_view name() { return kName; } + + std::string DebugString() const; + + TypeParameters GetParameters() const ABSL_ATTRIBUTE_LIFETIME_BOUND; + + ABSL_DEPRECATED("Use GetKey") + Type key() const; + + Type GetKey() const; + + ABSL_DEPRECATED("Use GetValue") + Type value() const; + + Type GetValue() const; + + private: + friend class Type; + friend MapType JsonMapType(); + + explicit MapType(const common_internal::MapTypeData* absl_nonnull data) + : data_(reinterpret_cast(data) | kBasicBit) { + ABSL_DCHECK_GE(absl::countr_zero(reinterpret_cast(data)), 2) + << "alignment must be greater than 2"; + } + + explicit MapType(const google::protobuf::Descriptor* absl_nonnull descriptor) + : data_(reinterpret_cast(descriptor) | kProtoBit) { + ABSL_DCHECK_GE(absl::countr_zero(reinterpret_cast(descriptor)), + 2) + << "alignment must be greater than 2"; + ABSL_DCHECK(descriptor->map_key() != nullptr); + ABSL_DCHECK(descriptor->map_value() != nullptr); + } + + uintptr_t data_; +}; + +bool operator==(const MapType& lhs, const MapType& rhs); + +inline bool operator!=(const MapType& lhs, const MapType& rhs) { + return !operator==(lhs, rhs); +} + +template +H AbslHashValue(H state, const MapType& type); + +inline std::ostream& operator<<(std::ostream& out, const MapType& type) { + return out << type.DebugString(); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_MAP_TYPE_H_ diff --git a/common/types/map_type_pool.cc b/common/types/map_type_pool.cc new file mode 100644 index 000000000..cc4a5fb09 --- /dev/null +++ b/common/types/map_type_pool.cc @@ -0,0 +1,30 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/types/map_type_pool.h" + +#include "common/type.h" + +namespace cel::common_internal { + +MapType MapTypePool::InternMapType(const Type& key, const Type& value) { + if (key.IsDyn() && value.IsDyn()) { + return MapType(); + } + return *map_types_.lazy_emplace(AsTuple(key, value), [&](const auto& ctor) { + ctor(MapType(arena_, key, value)); + }); +} + +} // namespace cel::common_internal diff --git a/common/types/map_type_pool.h b/common/types/map_type_pool.h new file mode 100644 index 000000000..461e880a6 --- /dev/null +++ b/common/types/map_type_pool.h @@ -0,0 +1,93 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_MAP_TYPE_POOL_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_TYPES_MAP_TYPE_POOL_H_ + +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_set.h" +#include "absl/hash/hash.h" +#include "absl/log/die_if_null.h" +#include "common/type.h" +#include "google/protobuf/arena.h" + +namespace cel::common_internal { + +// `MapTypePool` is a thread unsafe interning factory for `MapType`. +class MapTypePool final { + public: + explicit MapTypePool(google::protobuf::Arena* absl_nonnull arena) + : arena_(ABSL_DIE_IF_NULL(arena)) {} // Crash OK + + // Returns a `MapType` which has the provided parameters, interning as + // necessary. + MapType InternMapType(const Type& key, const Type& value); + + private: + using MapTypeTuple = std::tuple, + std::reference_wrapper>; + + static MapTypeTuple AsTuple(const MapType& map_type) { + return AsTuple(map_type.key(), map_type.value()); + } + + static MapTypeTuple AsTuple(const Type& key, const Type& value) { + return MapTypeTuple{std::cref(key), std::cref(value)}; + } + + struct Hasher { + using is_transparent = void; + + size_t operator()(const MapType& map_type) const { + return (*this)(AsTuple(map_type)); + } + + size_t operator()(const MapTypeTuple& tuple) const { + return absl::Hash{}(tuple); + } + }; + + struct Equaler { + using is_transparent = void; + + bool operator()(const MapType& lhs, const MapType& rhs) const { + return (*this)(AsTuple(lhs), AsTuple(rhs)); + } + + bool operator()(const MapType& lhs, const MapTypeTuple& rhs) const { + return (*this)(AsTuple(lhs), rhs); + } + + bool operator()(const MapTypeTuple& lhs, const MapType& rhs) const { + return (*this)(lhs, AsTuple(rhs)); + } + + bool operator()(const MapTypeTuple& lhs, const MapTypeTuple& rhs) const { + return lhs == rhs; + } + }; + + google::protobuf::Arena* absl_nonnull const arena_; + absl::flat_hash_set map_types_; +}; + +} // namespace cel::common_internal + +#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_MAP_TYPE_POOL_H_ diff --git a/common/types/map_type_test.cc b/common/types/map_type_test.cc new file mode 100644 index 000000000..0489ff67e --- /dev/null +++ b/common/types/map_type_test.cc @@ -0,0 +1,78 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "absl/hash/hash.h" +#include "common/type.h" +#include "internal/testing.h" +#include "google/protobuf/arena.h" + +namespace cel { +namespace { + +TEST(MapType, Default) { + MapType map_type; + EXPECT_EQ(map_type.key(), DynType()); + EXPECT_EQ(map_type.value(), DynType()); +} + +TEST(MapType, Kind) { + google::protobuf::Arena arena; + EXPECT_EQ(MapType(&arena, StringType(), BytesType()).kind(), MapType::kKind); + EXPECT_EQ(Type(MapType(&arena, StringType(), BytesType())).kind(), + MapType::kKind); +} + +TEST(MapType, Name) { + google::protobuf::Arena arena; + EXPECT_EQ(MapType(&arena, StringType(), BytesType()).name(), MapType::kName); + EXPECT_EQ(Type(MapType(&arena, StringType(), BytesType())).name(), + MapType::kName); +} + +TEST(MapType, DebugString) { + google::protobuf::Arena arena; + { + std::ostringstream out; + out << MapType(&arena, StringType(), BytesType()); + EXPECT_EQ(out.str(), "map"); + } + { + std::ostringstream out; + out << Type(MapType(&arena, StringType(), BytesType())); + EXPECT_EQ(out.str(), "map"); + } +} + +TEST(MapType, Hash) { + google::protobuf::Arena arena; + EXPECT_EQ(absl::HashOf(MapType(&arena, StringType(), BytesType())), + absl::HashOf(MapType(&arena, StringType(), BytesType()))); +} + +TEST(MapType, Equal) { + google::protobuf::Arena arena; + EXPECT_EQ(MapType(&arena, StringType(), BytesType()), + MapType(&arena, StringType(), BytesType())); + EXPECT_EQ(Type(MapType(&arena, StringType(), BytesType())), + MapType(&arena, StringType(), BytesType())); + EXPECT_EQ(MapType(&arena, StringType(), BytesType()), + Type(MapType(&arena, StringType(), BytesType()))); + EXPECT_EQ(Type(MapType(&arena, StringType(), BytesType())), + Type(MapType(&arena, StringType(), BytesType()))); +} + +} // namespace +} // namespace cel diff --git a/common/types/message_type.cc b/common/types/message_type.cc new file mode 100644 index 000000000..c5708cbbd --- /dev/null +++ b/common/types/message_type.cc @@ -0,0 +1,95 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/base/optimization.h" +#include "absl/log/absl_check.h" +#include "absl/strings/str_cat.h" +#include "common/type.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +using google::protobuf::Descriptor; + +bool IsWellKnownMessageType(const Descriptor* absl_nonnull descriptor) { + switch (descriptor->well_known_type()) { + case Descriptor::WELLKNOWNTYPE_BOOLVALUE: + ABSL_FALLTHROUGH_INTENDED; + case Descriptor::WELLKNOWNTYPE_INT32VALUE: + ABSL_FALLTHROUGH_INTENDED; + case Descriptor::WELLKNOWNTYPE_INT64VALUE: + ABSL_FALLTHROUGH_INTENDED; + case Descriptor::WELLKNOWNTYPE_UINT32VALUE: + ABSL_FALLTHROUGH_INTENDED; + case Descriptor::WELLKNOWNTYPE_UINT64VALUE: + ABSL_FALLTHROUGH_INTENDED; + case Descriptor::WELLKNOWNTYPE_FLOATVALUE: + ABSL_FALLTHROUGH_INTENDED; + case Descriptor::WELLKNOWNTYPE_DOUBLEVALUE: + ABSL_FALLTHROUGH_INTENDED; + case Descriptor::WELLKNOWNTYPE_BYTESVALUE: + ABSL_FALLTHROUGH_INTENDED; + case Descriptor::WELLKNOWNTYPE_STRINGVALUE: + ABSL_FALLTHROUGH_INTENDED; + case Descriptor::WELLKNOWNTYPE_ANY: + ABSL_FALLTHROUGH_INTENDED; + case Descriptor::WELLKNOWNTYPE_DURATION: + ABSL_FALLTHROUGH_INTENDED; + case Descriptor::WELLKNOWNTYPE_TIMESTAMP: + ABSL_FALLTHROUGH_INTENDED; + case Descriptor::WELLKNOWNTYPE_VALUE: + ABSL_FALLTHROUGH_INTENDED; + case Descriptor::WELLKNOWNTYPE_LISTVALUE: + ABSL_FALLTHROUGH_INTENDED; + case Descriptor::WELLKNOWNTYPE_STRUCT: + return true; + default: + return false; + } +} + +std::string MessageType::DebugString() const { + if (ABSL_PREDICT_TRUE(static_cast(*this))) { + static_assert(sizeof(descriptor_) == 8 || sizeof(descriptor_) == 4, + "sizeof(void*) is neither 8 nor 4"); + return absl::StrCat(name(), "@0x", + absl::Hex(descriptor_, sizeof(descriptor_) == 8 + ? absl::PadSpec::kZeroPad16 + : absl::PadSpec::kZeroPad8)); + } + return std::string(); +} + +std::string MessageTypeField::DebugString() const { + if (ABSL_PREDICT_TRUE(static_cast(*this))) { + static_assert(sizeof(descriptor_) == 8 || sizeof(descriptor_) == 4, + "sizeof(void*) is neither 8 nor 4"); + return absl::StrCat("[", (*this)->number(), "]", (*this)->name(), "@0x", + absl::Hex(descriptor_, sizeof(descriptor_) == 8 + ? absl::PadSpec::kZeroPad16 + : absl::PadSpec::kZeroPad8)); + } + return std::string(); +} + +Type MessageTypeField::GetType() const { + ABSL_DCHECK(*this); + return Type::Field(descriptor_); +} + +} // namespace cel diff --git a/common/types/message_type.h b/common/types/message_type.h new file mode 100644 index 000000000..782af87aa --- /dev/null +++ b/common/types/message_type.h @@ -0,0 +1,200 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/type.h" +// IWYU pragma: friend "common/types/struct_type.h" + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_MESSAGE_TYPE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_TYPES_MESSAGE_TYPE_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/strings/string_view.h" +#include "common/type_kind.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +class Type; +class TypeParameters; + +bool IsWellKnownMessageType(const google::protobuf::Descriptor* absl_nonnull descriptor); + +class MessageTypeField; + +class MessageType final { + public: + using element_type = const google::protobuf::Descriptor; + + static constexpr TypeKind kKind = TypeKind::kStruct; + + // Constructs `MessageType` from a pointer to `google::protobuf::Descriptor`. The + // `google::protobuf::Descriptor` must not be one of the well known message types we + // treat specially, if it is behavior is undefined. If you are unsure, you + // should use `Type::Message`. + explicit MessageType(const google::protobuf::Descriptor* absl_nullable descriptor) + : descriptor_(descriptor) { + ABSL_DCHECK(descriptor == nullptr || !IsWellKnownMessageType(descriptor)) + << descriptor->full_name(); + } + + // Constructs a `MessageType` in an empty state. + // + // Most operations on an empty `MessageType` result in undefined behavior. Use + // `operator bool` to test if a `MessageType` is empty. + MessageType() = default; + MessageType(const MessageType&) = default; + MessageType(MessageType&&) = default; + MessageType& operator=(const MessageType&) = default; + MessageType& operator=(MessageType&&) = default; + + static TypeKind kind() { return kKind; } + + absl::string_view name() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return (*this)->full_name(); + } + + std::string DebugString() const; + + static TypeParameters GetParameters(); + + const google::protobuf::Descriptor& operator*() const { + ABSL_DCHECK(*this); + return *descriptor_; + } + + const google::protobuf::Descriptor* absl_nonnull operator->() const { + ABSL_DCHECK(*this); + return descriptor_; + } + + explicit operator bool() const { return descriptor_ != nullptr; } + + private: + friend struct std::pointer_traits; + + const google::protobuf::Descriptor* absl_nullable descriptor_ = nullptr; +}; + +inline bool operator==(MessageType lhs, MessageType rhs) { + return static_cast(lhs) == static_cast(rhs) && + (!static_cast(lhs) || lhs.name() == rhs.name()); +} + +inline bool operator!=(MessageType lhs, MessageType rhs) { + return !operator==(lhs, rhs); +} + +template +H AbslHashValue(H state, MessageType message_type) { + return H::combine(std::move(state), static_cast(message_type) + ? message_type.name() + : absl::string_view()); +} + +inline std::ostream& operator<<(std::ostream& out, MessageType type) { + return out << type.DebugString(); +} + +} // namespace cel + +namespace std { + +template <> +struct pointer_traits { + using pointer = cel::MessageType; + using element_type = typename cel::MessageType::element_type; + using difference_type = ptrdiff_t; + + static element_type* to_address(const pointer& p) noexcept { + return p.descriptor_; + } +}; + +} // namespace std + +namespace cel { + +class MessageTypeField final { + public: + using element_type = const google::protobuf::FieldDescriptor; + + explicit MessageTypeField( + const google::protobuf::FieldDescriptor* absl_nullable descriptor) + : descriptor_(descriptor) {} + + MessageTypeField() = default; + MessageTypeField(const MessageTypeField&) = default; + MessageTypeField(MessageTypeField&&) = default; + MessageTypeField& operator=(const MessageTypeField&) = default; + MessageTypeField& operator=(MessageTypeField&&) = default; + + std::string DebugString() const; + + absl::string_view name() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return (*this)->name(); + } + + int32_t number() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return (*this)->number(); + } + + Type GetType() const; + + const google::protobuf::FieldDescriptor& operator*() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + ABSL_DCHECK(*this); + return *descriptor_; + } + + const google::protobuf::FieldDescriptor* absl_nonnull operator->() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + ABSL_DCHECK(*this); + return descriptor_; + } + + explicit operator bool() const { return descriptor_ != nullptr; } + + private: + friend struct std::pointer_traits; + + const google::protobuf::FieldDescriptor* absl_nullable descriptor_ = nullptr; +}; + +} // namespace cel + +namespace std { + +template <> +struct pointer_traits { + using pointer = cel::MessageTypeField; + using element_type = typename cel::MessageTypeField::element_type; + using difference_type = ptrdiff_t; + + static element_type* to_address(const pointer& p) noexcept { + return p.descriptor_; + } +}; + +} // namespace std + +#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_MESSAGE_TYPE_H_ diff --git a/common/types/message_type_test.cc b/common/types/message_type_test.cc new file mode 100644 index 000000000..497434e14 --- /dev/null +++ b/common/types/message_type_test.cc @@ -0,0 +1,102 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "google/protobuf/descriptor.pb.h" +#include "common/memory.h" +#include "common/type.h" +#include "common/type_kind.h" +#include "internal/testing.h" +#include "google/protobuf/descriptor.h" + +namespace cel { +namespace { + +using ::testing::An; +using ::testing::Eq; +using ::testing::IsEmpty; +using ::testing::NotNull; +using ::testing::Optional; +using ::testing::StartsWith; + +TEST(MessageType, Kind) { EXPECT_EQ(MessageType::kind(), TypeKind::kStruct); } + +TEST(MessageType, Default) { + MessageType type; + EXPECT_FALSE(type); + EXPECT_THAT(type.DebugString(), Eq("")); + EXPECT_EQ(type, MessageType()); +} + +TEST(MessageType, Descriptor) { + google::protobuf::DescriptorPool pool; + { + google::protobuf::FileDescriptorProto file_desc_proto; + file_desc_proto.set_syntax("proto3"); + file_desc_proto.set_package("test"); + file_desc_proto.set_name("test/struct.proto"); + file_desc_proto.add_message_type()->set_name("Struct"); + ASSERT_THAT(pool.BuildFile(file_desc_proto), NotNull()); + } + const google::protobuf::Descriptor* desc = pool.FindMessageTypeByName("test.Struct"); + ASSERT_THAT(desc, NotNull()); + MessageType type(desc); + EXPECT_TRUE(type); + EXPECT_THAT(type.name(), Eq("test.Struct")); + EXPECT_THAT(type.DebugString(), StartsWith("test.Struct@0x")); + EXPECT_THAT(type.GetParameters(), IsEmpty()); + EXPECT_NE(type, MessageType()); + EXPECT_NE(MessageType(), type); + EXPECT_EQ(cel::to_address(type), desc); +} + +TEST(MessageTypeField, Descriptor) { + google::protobuf::DescriptorPool pool; + { + google::protobuf::FileDescriptorProto file_desc_proto; + file_desc_proto.set_syntax("proto3"); + file_desc_proto.set_package("test"); + file_desc_proto.set_name("test/struct.proto"); + auto* message_type = file_desc_proto.add_message_type(); + message_type->set_name("Struct"); + auto* field = message_type->add_field(); + field->set_name("foo"); + field->set_json_name("foo"); + field->set_number(1); + field->set_type(google::protobuf::FieldDescriptorProto::TYPE_INT64); + field->set_label(google::protobuf::FieldDescriptorProto::LABEL_OPTIONAL); + ASSERT_THAT(pool.BuildFile(file_desc_proto), NotNull()); + } + const google::protobuf::Descriptor* desc = pool.FindMessageTypeByName("test.Struct"); + ASSERT_THAT(desc, NotNull()); + const google::protobuf::FieldDescriptor* field_desc = desc->FindFieldByName("foo"); + ASSERT_THAT(desc, NotNull()); + MessageTypeField message_type_field(field_desc); + EXPECT_TRUE(message_type_field); + EXPECT_THAT(message_type_field.name(), Eq("foo")); + EXPECT_THAT(message_type_field.DebugString(), StartsWith("[1]foo@0x")); + EXPECT_THAT(message_type_field.number(), Eq(1)); + EXPECT_THAT(message_type_field.GetType(), IntType()); + EXPECT_EQ(cel::to_address(message_type_field), field_desc); + StructTypeField struct_type_field = message_type_field; + EXPECT_TRUE(struct_type_field.IsMessage()); + EXPECT_THAT(struct_type_field.AsMessage(), Optional(An())); + EXPECT_THAT(static_cast(struct_type_field), + An()); + EXPECT_EQ(struct_type_field.name(), message_type_field.name()); + EXPECT_EQ(struct_type_field.number(), message_type_field.number()); + EXPECT_EQ(struct_type_field.GetType(), message_type_field.GetType()); +} + +} // namespace +} // namespace cel diff --git a/common/types/null_type.h b/common/types/null_type.h new file mode 100644 index 000000000..053cd9abb --- /dev/null +++ b/common/types/null_type.h @@ -0,0 +1,73 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/type.h" +// IWYU pragma: friend "common/type.h" + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_NULL_TYPE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_TYPES_NULL_TYPE_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "common/type_kind.h" + +namespace cel { + +class Type; +class TypeParameters; + +// `NullType` represents the primitive `null_type` type. +class NullType final { + public: + static constexpr TypeKind kKind = TypeKind::kNull; + static constexpr absl::string_view kName = "null_type"; + + NullType() = default; + NullType(const NullType&) = default; + NullType(NullType&&) = default; + NullType& operator=(const NullType&) = default; + NullType& operator=(NullType&&) = default; + + static TypeKind kind() { return kKind; } + + static absl::string_view name() { return kName; } + + static TypeParameters GetParameters(); + + static std::string DebugString() { return std::string(name()); } +}; + +inline constexpr bool operator==(NullType, NullType) { return true; } + +inline constexpr bool operator!=(NullType lhs, NullType rhs) { + return !operator==(lhs, rhs); +} + +template +H AbslHashValue(H state, NullType) { + // NullType is really a singleton and all instances are equal. Nothing to + // hash. + return std::move(state); +} + +inline std::ostream& operator<<(std::ostream& out, const NullType& type) { + return out << type.DebugString(); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_NULL_TYPE_H_ diff --git a/common/types/null_type_test.cc b/common/types/null_type_test.cc new file mode 100644 index 000000000..66cd5fa05 --- /dev/null +++ b/common/types/null_type_test.cc @@ -0,0 +1,59 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "absl/hash/hash.h" +#include "common/type.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +TEST(NullType, Kind) { + EXPECT_EQ(NullType().kind(), NullType::kKind); + EXPECT_EQ(Type(NullType()).kind(), NullType::kKind); +} + +TEST(NullType, Name) { + EXPECT_EQ(NullType().name(), NullType::kName); + EXPECT_EQ(Type(NullType()).name(), NullType::kName); +} + +TEST(NullType, DebugString) { + { + std::ostringstream out; + out << NullType(); + EXPECT_EQ(out.str(), NullType::kName); + } + { + std::ostringstream out; + out << Type(NullType()); + EXPECT_EQ(out.str(), NullType::kName); + } +} + +TEST(NullType, Hash) { + EXPECT_EQ(absl::HashOf(NullType()), absl::HashOf(NullType())); +} + +TEST(NullType, Equal) { + EXPECT_EQ(NullType(), NullType()); + EXPECT_EQ(Type(NullType()), NullType()); + EXPECT_EQ(NullType(), Type(NullType())); + EXPECT_EQ(Type(NullType()), Type(NullType())); +} + +} // namespace +} // namespace cel diff --git a/common/types/opaque_type.cc b/common/types/opaque_type.cc new file mode 100644 index 000000000..9c58e8289 --- /dev/null +++ b/common/types/opaque_type.cc @@ -0,0 +1,109 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "absl/utility/utility.h" +#include "common/type.h" +#include "common/type_kind.h" +#include "google/protobuf/arena.h" + +namespace cel { + +namespace { + +std::string OpaqueDebugString(absl::string_view name, + absl::Span parameters) { + if (parameters.empty()) { + return std::string(name); + } + return absl::StrCat(name, "<", + absl::StrJoin(parameters, ", ", + [](std::string* out, const Type& type) { + absl::StrAppend( + out, TypeKindToString(type.kind())); + }), + ">"); +} + +} // namespace + +namespace common_internal { + +OpaqueTypeData* absl_nonnull OpaqueTypeData::Create( + google::protobuf::Arena* absl_nonnull arena, absl::string_view name, + absl::Span parameters) { + return ::new (arena->AllocateAligned( + offsetof(OpaqueTypeData, parameters) + (parameters.size() * sizeof(Type)), + alignof(OpaqueTypeData))) OpaqueTypeData(name, parameters); +} + +OpaqueTypeData::OpaqueTypeData(absl::string_view name, + absl::Span parameters) + : name(name), parameters_size(parameters.size()) { + std::memcpy(this->parameters, parameters.data(), + parameters_size * sizeof(Type)); +} + +} // namespace common_internal + +OpaqueType::OpaqueType(google::protobuf::Arena* absl_nonnull arena, + absl::string_view name, + absl::Span parameters) + : OpaqueType( + common_internal::OpaqueTypeData::Create(arena, name, parameters)) {} + +std::string OpaqueType::DebugString() const { + ABSL_DCHECK(*this); + return OpaqueDebugString(name(), GetParameters()); +} + +absl::string_view OpaqueType::name() const { + ABSL_DCHECK(*this); + return data_->name; +} + +TypeParameters OpaqueType::GetParameters() const { + ABSL_DCHECK(*this); + return TypeParameters( + absl::MakeConstSpan(data_->parameters, data_->parameters_size)); +} + +bool OpaqueType::IsOptional() const { + return name() == OptionalType::kName && GetParameters().size() == 1; +} + +absl::optional OpaqueType::AsOptional() const { + if (IsOptional()) { + return OptionalType(absl::in_place, *this); + } + return std::nullopt; +} + +OptionalType OpaqueType::GetOptional() const { + ABSL_DCHECK(IsOptional()) << DebugString(); + return OptionalType(absl::in_place, *this); +} + +} // namespace cel diff --git a/common/types/opaque_type.h b/common/types/opaque_type.h new file mode 100644 index 000000000..2b4fe8185 --- /dev/null +++ b/common/types/opaque_type.h @@ -0,0 +1,118 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/type.h" +// IWYU pragma: friend "common/type.h" +// IWYU pragma: friend "common/types/optional_type.h" + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_OPAQUE_TYPE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_TYPES_OPAQUE_TYPE_H_ + +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "common/type_kind.h" +#include "google/protobuf/arena.h" + +namespace cel { + +class Type; +class OptionalType; +class TypeParameters; + +namespace common_internal { +struct OpaqueTypeData; +} // namespace common_internal + +class OpaqueType final { + public: + static constexpr TypeKind kKind = TypeKind::kOpaque; + + // `name` must outlive the instance. + OpaqueType(google::protobuf::Arena* absl_nonnull arena, absl::string_view name, + absl::Span parameters); + + // NOLINTNEXTLINE(google-explicit-constructor) + OpaqueType(OptionalType type); + + // NOLINTNEXTLINE(google-explicit-constructor) + OpaqueType& operator=(OptionalType type); + + OpaqueType() = default; + OpaqueType(const OpaqueType&) = default; + OpaqueType(OpaqueType&&) = default; + OpaqueType& operator=(const OpaqueType&) = default; + OpaqueType& operator=(OpaqueType&&) = default; + + static TypeKind kind() { return kKind; } + + absl::string_view name() const ABSL_ATTRIBUTE_LIFETIME_BOUND; + + std::string DebugString() const; + + TypeParameters GetParameters() const ABSL_ATTRIBUTE_LIFETIME_BOUND; + + explicit operator bool() const { return data_ != nullptr; } + + bool IsOptional() const; + + template + std::enable_if_t, bool> Is() const { + return IsOptional(); + } + + absl::optional AsOptional() const; + + template + std::enable_if_t, + absl::optional> + As() const; + + OptionalType GetOptional() const; + + template + std::enable_if_t, OptionalType> Get() const; + + private: + friend class OptionalType; + + constexpr explicit OpaqueType( + const common_internal::OpaqueTypeData* absl_nullable data) + : data_(data) {} + + const common_internal::OpaqueTypeData* absl_nullable data_ = nullptr; +}; + +bool operator==(const OpaqueType& lhs, const OpaqueType& rhs); + +inline bool operator!=(const OpaqueType& lhs, const OpaqueType& rhs) { + return !operator==(lhs, rhs); +} + +template +H AbslHashValue(H state, const OpaqueType& type); + +inline std::ostream& operator<<(std::ostream& out, const OpaqueType& type) { + return out << type.DebugString(); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_OPAQUE_TYPE_H_ diff --git a/common/types/opaque_type_pool.cc b/common/types/opaque_type_pool.cc new file mode 100644 index 000000000..a4f86e656 --- /dev/null +++ b/common/types/opaque_type_pool.cc @@ -0,0 +1,33 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/types/opaque_type_pool.h" + +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "common/type.h" + +namespace cel::common_internal { + +OpaqueType OpaqueTypePool::InternOpaqueType(absl::string_view name, + absl::Span parameters) { + if (name.empty() && parameters.empty()) { + return OpaqueType(); + } + return *opaque_types_.lazy_emplace( + AsTuple(name, parameters), + [&](const auto& ctor) { ctor(OpaqueType(arena_, name, parameters)); }); +} + +} // namespace cel::common_internal diff --git a/common/types/opaque_type_pool.h b/common/types/opaque_type_pool.h new file mode 100644 index 000000000..1d2d5be17 --- /dev/null +++ b/common/types/opaque_type_pool.h @@ -0,0 +1,99 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_OPAQUE_TYPE_POOL_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_TYPES_OPAQUE_TYPE_POOL_H_ + +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_set.h" +#include "absl/hash/hash.h" +#include "absl/log/die_if_null.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "common/type.h" +#include "google/protobuf/arena.h" + +namespace cel::common_internal { + +// `OpaqueTypePool` is a thread unsafe interning factory for `OpaqueType`. +class OpaqueTypePool final { + public: + explicit OpaqueTypePool(google::protobuf::Arena* absl_nonnull arena) + : arena_(ABSL_DIE_IF_NULL(arena)) {} // Crash OK + + // Returns a `OpaqueType` which has the provided parameters, interning as + // necessary. + OpaqueType InternOpaqueType(absl::string_view name, + absl::Span parameters); + + private: + using OpaqueTypeTuple = std::tuple>; + + static OpaqueTypeTuple AsTuple(const OpaqueType& opaque_type) { + return AsTuple(opaque_type.name(), opaque_type.GetParameters()); + } + + static OpaqueTypeTuple AsTuple(absl::string_view name, + absl::Span parameters) { + return OpaqueTypeTuple{name, parameters}; + } + + struct Hasher { + using is_transparent = void; + + size_t operator()(const OpaqueType& data) const { + return (*this)(AsTuple(data)); + } + + size_t operator()(const OpaqueTypeTuple& tuple) const { + return absl::Hash{}(tuple); + } + }; + + struct Equaler { + using is_transparent = void; + + bool operator()(const OpaqueType& lhs, const OpaqueType& rhs) const { + return (*this)(AsTuple(lhs), AsTuple(rhs)); + } + + bool operator()(const OpaqueType& lhs, const OpaqueTypeTuple& rhs) const { + return (*this)(AsTuple(lhs), rhs); + } + + bool operator()(const OpaqueTypeTuple& lhs, const OpaqueType& rhs) const { + return (*this)(lhs, AsTuple(rhs)); + } + + bool operator()(const OpaqueTypeTuple& lhs, + const OpaqueTypeTuple& rhs) const { + return std::get<0>(lhs) == std::get<0>(rhs) && + absl::c_equal(std::get<1>(lhs), std::get<1>(rhs)); + } + }; + + google::protobuf::Arena* absl_nonnull const arena_; + absl::flat_hash_set opaque_types_; +}; + +} // namespace cel::common_internal + +#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_OPAQUE_TYPE_POOL_H_ diff --git a/common/types/opaque_type_test.cc b/common/types/opaque_type_test.cc new file mode 100644 index 000000000..d34b6936c --- /dev/null +++ b/common/types/opaque_type_test.cc @@ -0,0 +1,79 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "absl/hash/hash.h" +#include "common/type.h" +#include "internal/testing.h" +#include "google/protobuf/arena.h" + +namespace cel { +namespace { + +TEST(OpaqueType, Kind) { + google::protobuf::Arena arena; + EXPECT_EQ(OpaqueType(&arena, "test.Opaque", {BytesType()}).kind(), + OpaqueType::kKind); + EXPECT_EQ(Type(OpaqueType(&arena, "test.Opaque", {BytesType()})).kind(), + OpaqueType::kKind); +} + +TEST(OpaqueType, Name) { + google::protobuf::Arena arena; + EXPECT_EQ(OpaqueType(&arena, "test.Opaque", {BytesType()}).name(), + "test.Opaque"); + EXPECT_EQ(Type(OpaqueType(&arena, "test.Opaque", {BytesType()})).name(), + "test.Opaque"); +} + +TEST(OpaqueType, DebugString) { + google::protobuf::Arena arena; + { + std::ostringstream out; + out << OpaqueType(&arena, "test.Opaque", {BytesType()}); + EXPECT_EQ(out.str(), "test.Opaque"); + } + { + std::ostringstream out; + out << Type(OpaqueType(&arena, "test.Opaque", {BytesType()})); + EXPECT_EQ(out.str(), "test.Opaque"); + } + { + std::ostringstream out; + out << OpaqueType(&arena, "test.Opaque", {}); + EXPECT_EQ(out.str(), "test.Opaque"); + } +} + +TEST(OpaqueType, Hash) { + google::protobuf::Arena arena; + EXPECT_EQ(absl::HashOf(OpaqueType(&arena, "test.Opaque", {BytesType()})), + absl::HashOf(OpaqueType(&arena, "test.Opaque", {BytesType()}))); +} + +TEST(OpaqueType, Equal) { + google::protobuf::Arena arena; + EXPECT_EQ(OpaqueType(&arena, "test.Opaque", {BytesType()}), + OpaqueType(&arena, "test.Opaque", {BytesType()})); + EXPECT_EQ(Type(OpaqueType(&arena, "test.Opaque", {BytesType()})), + OpaqueType(&arena, "test.Opaque", {BytesType()})); + EXPECT_EQ(OpaqueType(&arena, "test.Opaque", {BytesType()}), + Type(OpaqueType(&arena, "test.Opaque", {BytesType()}))); + EXPECT_EQ(Type(OpaqueType(&arena, "test.Opaque", {BytesType()})), + Type(OpaqueType(&arena, "test.Opaque", {BytesType()}))); +} + +} // namespace +} // namespace cel diff --git a/common/types/optional_type.cc b/common/types/optional_type.cc new file mode 100644 index 000000000..a37300bba --- /dev/null +++ b/common/types/optional_type.cc @@ -0,0 +1,68 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "absl/base/attributes.h" +#include "absl/strings/string_view.h" +#include "common/type.h" + +namespace cel { + +namespace common_internal { + +namespace { + +struct OptionalTypeData final { + const absl::string_view name; + const size_t parameters_size; + const Type parameter; +}; + +// Here by dragons. In order to make `OptionalType` default constructible +// without some sort of dynamic static initializer, we perform some +// type-punning. `OptionalTypeData` and `OpaqueTypeData` must have the same +// layout, with the only exception being that `OptionalTypeData` as a single +// `Type` where `OpaqueTypeData` as a flexible array. +union DynOptionalTypeData final { + OptionalTypeData optional; + OpaqueTypeData opaque; +}; + +static_assert(offsetof(OptionalTypeData, name) == + offsetof(OpaqueTypeData, name)); +static_assert(offsetof(OptionalTypeData, parameters_size) == + offsetof(OpaqueTypeData, parameters_size)); +static_assert(offsetof(OptionalTypeData, parameter) == + offsetof(OpaqueTypeData, parameters)); + +ABSL_CONST_INIT const DynOptionalTypeData kDynOptionalTypeData = { + .optional = + { + .name = OptionalType::kName, + .parameters_size = 1, + .parameter = DynType(), + }, +}; + +} // namespace + +} // namespace common_internal + +OptionalType::OptionalType() + : opaque_(&common_internal::kDynOptionalTypeData.opaque) {} + +Type OptionalType::GetParameter() const { return GetParameters().front(); } + +} // namespace cel diff --git a/common/types/optional_type.h b/common/types/optional_type.h new file mode 100644 index 000000000..922e6372e --- /dev/null +++ b/common/types/optional_type.h @@ -0,0 +1,114 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/type.h" +// IWYU pragma: friend "common/type.h" + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_OPTIONAL_TYPE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_TYPES_OPTIONAL_TYPE_H_ + +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "absl/utility/utility.h" +#include "common/type_kind.h" +#include "common/types/opaque_type.h" +#include "google/protobuf/arena.h" + +namespace cel { + +class Type; +class TypeParameters; + +class OptionalType final { + public: + static constexpr TypeKind kKind = TypeKind::kOpaque; + static constexpr absl::string_view kName = "optional_type"; + + // By default, this type is `optional(dyn)`. Unless you can help it, you + // should choose a more specific optional type. + OptionalType(); + + OptionalType(google::protobuf::Arena* absl_nonnull arena, const Type& parameter) + : OptionalType( + absl::in_place, + OpaqueType(arena, kName, absl::MakeConstSpan(¶meter, 1))) {} + + static TypeKind kind() { return kKind; } + + static absl::string_view name() { return kName; } + + std::string DebugString() const { return opaque_.DebugString(); } + + TypeParameters GetParameters() const ABSL_ATTRIBUTE_LIFETIME_BOUND; + + Type GetParameter() const ABSL_ATTRIBUTE_LIFETIME_BOUND; + + explicit operator bool() const { return static_cast(opaque_); } + + template + friend H AbslHashValue(H state, const OptionalType& type) { + return H::combine(std::move(state), type.opaque_); + } + + friend bool operator==(const OptionalType& lhs, const OptionalType& rhs) { + return lhs.opaque_ == rhs.opaque_; + } + + private: + friend class OpaqueType; + + OptionalType(absl::in_place_t, OpaqueType type) : opaque_(std::move(type)) {} + + OpaqueType opaque_; +}; + +inline bool operator!=(const OptionalType& lhs, const OptionalType& rhs) { + return !operator==(lhs, rhs); +} + +inline std::ostream& operator<<(std::ostream& out, const OptionalType& type) { + return out << type.DebugString(); +} + +inline OpaqueType::OpaqueType(OptionalType type) + : OpaqueType(std::move(type.opaque_)) {} + +inline OpaqueType& OpaqueType::operator=(OptionalType type) { + return *this = std::move(type.opaque_); +} + +template +inline std::enable_if_t, + absl::optional> +OpaqueType::As() const { + return AsOptional(); +} + +template +inline std::enable_if_t, OptionalType> +OpaqueType::Get() const { + return GetOptional(); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_OPTIONAL_TYPE_H_ diff --git a/common/types/optional_type_test.cc b/common/types/optional_type_test.cc new file mode 100644 index 000000000..aa3a60385 --- /dev/null +++ b/common/types/optional_type_test.cc @@ -0,0 +1,79 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "absl/hash/hash.h" +#include "common/type.h" +#include "internal/testing.h" +#include "google/protobuf/arena.h" + +namespace cel { +namespace { + +TEST(OptionalType, Default) { + OptionalType optional_type; + EXPECT_EQ(optional_type.GetParameter(), DynType()); +} + +TEST(OptionalType, Kind) { + google::protobuf::Arena arena; + EXPECT_EQ(OptionalType(&arena, BoolType()).kind(), OptionalType::kKind); + EXPECT_EQ(Type(OptionalType(&arena, BoolType())).kind(), OptionalType::kKind); +} + +TEST(OptionalType, Name) { + google::protobuf::Arena arena; + EXPECT_EQ(OptionalType(&arena, BoolType()).name(), OptionalType::kName); + EXPECT_EQ(Type(OptionalType(&arena, BoolType())).name(), OptionalType::kName); +} + +TEST(OptionalType, DebugString) { + google::protobuf::Arena arena; + { + std::ostringstream out; + out << OptionalType(&arena, BoolType()); + EXPECT_EQ(out.str(), "optional_type"); + } + { + std::ostringstream out; + out << Type(OptionalType(&arena, BoolType())); + EXPECT_EQ(out.str(), "optional_type"); + } +} + +TEST(OptionalType, Parameter) { + google::protobuf::Arena arena; + EXPECT_EQ(OptionalType(&arena, BoolType()).GetParameter(), BoolType()); +} + +TEST(OptionalType, Hash) { + google::protobuf::Arena arena; + EXPECT_EQ(absl::HashOf(OptionalType(&arena, BoolType())), + absl::HashOf(OptionalType(&arena, BoolType()))); +} + +TEST(OptionalType, Equal) { + google::protobuf::Arena arena; + EXPECT_EQ(OptionalType(&arena, BoolType()), OptionalType(&arena, BoolType())); + EXPECT_EQ(Type(OptionalType(&arena, BoolType())), + OptionalType(&arena, BoolType())); + EXPECT_EQ(OptionalType(&arena, BoolType()), + Type(OptionalType(&arena, BoolType()))); + EXPECT_EQ(Type(OptionalType(&arena, BoolType())), + Type(OptionalType(&arena, BoolType()))); +} + +} // namespace +} // namespace cel diff --git a/common/types/string_type.h b/common/types/string_type.h new file mode 100644 index 000000000..4bb6963ed --- /dev/null +++ b/common/types/string_type.h @@ -0,0 +1,73 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/type.h" +// IWYU pragma: friend "common/type.h" + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_STRING_TYPE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_TYPES_STRING_TYPE_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "common/type_kind.h" + +namespace cel { + +class Type; +class TypeParameters; + +// `StringType` represents the primitive `string` type. +class StringType final { + public: + static constexpr TypeKind kKind = TypeKind::kString; + static constexpr absl::string_view kName = "string"; + + StringType() = default; + StringType(const StringType&) = default; + StringType(StringType&&) = default; + StringType& operator=(const StringType&) = default; + StringType& operator=(StringType&&) = default; + + static TypeKind kind() { return kKind; } + + static absl::string_view name() { return kName; } + + static TypeParameters GetParameters(); + + std::string DebugString() const { return std::string(name()); } +}; + +inline constexpr bool operator==(StringType, StringType) { return true; } + +inline constexpr bool operator!=(StringType lhs, StringType rhs) { + return !operator==(lhs, rhs); +} + +template +H AbslHashValue(H state, StringType) { + // StringType is really a singleton and all instances are equal. Nothing to + // hash. + return std::move(state); +} + +inline std::ostream& operator<<(std::ostream& out, const StringType& type) { + return out << type.DebugString(); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_STRING_TYPE_H_ diff --git a/common/types/string_type_test.cc b/common/types/string_type_test.cc new file mode 100644 index 000000000..e668392d5 --- /dev/null +++ b/common/types/string_type_test.cc @@ -0,0 +1,59 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "absl/hash/hash.h" +#include "common/type.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +TEST(StringType, Kind) { + EXPECT_EQ(StringType().kind(), StringType::kKind); + EXPECT_EQ(Type(StringType()).kind(), StringType::kKind); +} + +TEST(StringType, Name) { + EXPECT_EQ(StringType().name(), StringType::kName); + EXPECT_EQ(Type(StringType()).name(), StringType::kName); +} + +TEST(StringType, DebugString) { + { + std::ostringstream out; + out << StringType(); + EXPECT_EQ(out.str(), StringType::kName); + } + { + std::ostringstream out; + out << Type(StringType()); + EXPECT_EQ(out.str(), StringType::kName); + } +} + +TEST(StringType, Hash) { + EXPECT_EQ(absl::HashOf(StringType()), absl::HashOf(StringType())); +} + +TEST(StringType, Equal) { + EXPECT_EQ(StringType(), StringType()); + EXPECT_EQ(Type(StringType()), StringType()); + EXPECT_EQ(StringType(), Type(StringType())); + EXPECT_EQ(Type(StringType()), Type(StringType())); +} + +} // namespace +} // namespace cel diff --git a/common/types/string_wrapper_type.h b/common/types/string_wrapper_type.h new file mode 100644 index 000000000..530845a9d --- /dev/null +++ b/common/types/string_wrapper_type.h @@ -0,0 +1,86 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/type.h" +// IWYU pragma: friend "common/type.h" + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_STRING_WRAPPER_TYPE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_TYPES_STRING_WRAPPER_TYPE_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "common/type_kind.h" + +namespace cel { + +class Type; +class TypeParameters; + +// `StringWrapperType` is a special type which has no direct value +// representation. It is used to represent `google.protobuf.StringValue`, which +// never exists at runtime as a value. Its primary usage is for type checking +// and unpacking at runtime. +class StringWrapperType final { + public: + static constexpr TypeKind kKind = TypeKind::kStringWrapper; + static constexpr absl::string_view kName = "google.protobuf.StringValue"; + + StringWrapperType() = default; + StringWrapperType(const StringWrapperType&) = default; + StringWrapperType(StringWrapperType&&) = default; + StringWrapperType& operator=(const StringWrapperType&) = default; + StringWrapperType& operator=(StringWrapperType&&) = default; + + static TypeKind kind() { return kKind; } + + static absl::string_view name() { return kName; } + + static TypeParameters GetParameters(); + + static std::string DebugString() { return std::string(name()); } + + constexpr void swap(StringWrapperType&) noexcept {} +}; + +inline constexpr void swap(StringWrapperType& lhs, + StringWrapperType& rhs) noexcept { + lhs.swap(rhs); +} + +inline constexpr bool operator==(StringWrapperType, StringWrapperType) { + return true; +} + +inline constexpr bool operator!=(StringWrapperType lhs, StringWrapperType rhs) { + return !operator==(lhs, rhs); +} + +template +H AbslHashValue(H state, StringWrapperType) { + // StringWrapperType is really a singleton and all instances are equal. + // Nothing to hash. + return std::move(state); +} + +inline std::ostream& operator<<(std::ostream& out, + const StringWrapperType& type) { + return out << type.DebugString(); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_STRING_WRAPPER_TYPE_H_ diff --git a/common/types/string_wrapper_type_test.cc b/common/types/string_wrapper_type_test.cc new file mode 100644 index 000000000..a863177b3 --- /dev/null +++ b/common/types/string_wrapper_type_test.cc @@ -0,0 +1,60 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "absl/hash/hash.h" +#include "common/type.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +TEST(StringWrapperType, Kind) { + EXPECT_EQ(StringWrapperType().kind(), StringWrapperType::kKind); + EXPECT_EQ(Type(StringWrapperType()).kind(), StringWrapperType::kKind); +} + +TEST(StringWrapperType, Name) { + EXPECT_EQ(StringWrapperType().name(), StringWrapperType::kName); + EXPECT_EQ(Type(StringWrapperType()).name(), StringWrapperType::kName); +} + +TEST(StringWrapperType, DebugString) { + { + std::ostringstream out; + out << StringWrapperType(); + EXPECT_EQ(out.str(), StringWrapperType::kName); + } + { + std::ostringstream out; + out << Type(StringWrapperType()); + EXPECT_EQ(out.str(), StringWrapperType::kName); + } +} + +TEST(StringWrapperType, Hash) { + EXPECT_EQ(absl::HashOf(StringWrapperType()), + absl::HashOf(StringWrapperType())); +} + +TEST(StringWrapperType, Equal) { + EXPECT_EQ(StringWrapperType(), StringWrapperType()); + EXPECT_EQ(Type(StringWrapperType()), StringWrapperType()); + EXPECT_EQ(StringWrapperType(), Type(StringWrapperType())); + EXPECT_EQ(Type(StringWrapperType()), Type(StringWrapperType())); +} + +} // namespace +} // namespace cel diff --git a/common/types/struct_type.cc b/common/types/struct_type.cc new file mode 100644 index 000000000..69f531a2f --- /dev/null +++ b/common/types/struct_type.cc @@ -0,0 +1,87 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "absl/functional/overload.h" +#include "absl/log/absl_check.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/variant.h" +#include "common/type.h" +#include "common/types/types.h" + +namespace cel { + +absl::string_view StructType::name() const { + ABSL_DCHECK(*this); + return absl::visit( + absl::Overload([](std::monostate) { return absl::string_view(); }, + [](const common_internal::BasicStructType& alt) { + return alt.name(); + }, + [](const MessageType& alt) { return alt.name(); }), + variant_); +} + +TypeParameters StructType::GetParameters() const { + ABSL_DCHECK(*this); + return absl::visit( + absl::Overload( + [](std::monostate) { return TypeParameters(); }, + [](const common_internal::BasicStructType& alt) { + return alt.GetParameters(); + }, + [](const MessageType& alt) { return alt.GetParameters(); }), + variant_); +} + +std::string StructType::DebugString() const { + return absl::visit( + absl::Overload([](std::monostate) { return std::string(); }, + [](common_internal::BasicStructType alt) { + return alt.DebugString(); + }, + [](MessageType alt) { return alt.DebugString(); }), + variant_); +} + +absl::optional StructType::AsMessage() const { + if (const auto* alt = absl::get_if(&variant_); alt != nullptr) { + return *alt; + } + return std::nullopt; +} + +MessageType StructType::GetMessage() const { + ABSL_DCHECK(IsMessage()) << DebugString(); + return absl::get(variant_); +} + +common_internal::TypeVariant StructType::ToTypeVariant() const { + return absl::visit( + absl::Overload( + [](std::monostate) { return common_internal::TypeVariant(); }, + [](common_internal::BasicStructType alt) { + return static_cast(alt) ? common_internal::TypeVariant(alt) + : common_internal::TypeVariant(); + }, + [](MessageType alt) { + return static_cast(alt) ? common_internal::TypeVariant(alt) + : common_internal::TypeVariant(); + }), + variant_); +} + +} // namespace cel diff --git a/common/types/struct_type.h b/common/types/struct_type.h new file mode 100644 index 000000000..6e20ea007 --- /dev/null +++ b/common/types/struct_type.h @@ -0,0 +1,158 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/type.h" +// IWYU pragma: friend "common/type.h" + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_STRUCT_TYPE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_TYPES_STRUCT_TYPE_H_ + +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/optimization.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/variant.h" +#include "common/type_kind.h" +#include "common/types/basic_struct_type.h" +#include "common/types/message_type.h" +#include "common/types/types.h" + +namespace cel { + +class Type; +class TypeParameters; + +class StructType final { + public: + static constexpr TypeKind kKind = TypeKind::kStruct; + + // NOLINTNEXTLINE(google-explicit-constructor) + StructType(MessageType other) : StructType() { + if (ABSL_PREDICT_TRUE(other)) { + variant_.emplace(other); + } + } + + // NOLINTNEXTLINE(google-explicit-constructor) + StructType(common_internal::BasicStructType other) : StructType() { + if (ABSL_PREDICT_TRUE(other)) { + variant_.emplace(other); + } + } + + // NOLINTNEXTLINE(google-explicit-constructor) + StructType& operator=(MessageType other) { + if (ABSL_PREDICT_TRUE(other)) { + variant_.emplace(other); + } else { + variant_.emplace(); + } + return *this; + } + + // NOLINTNEXTLINE(google-explicit-constructor) + StructType& operator=(common_internal::BasicStructType other) { + if (ABSL_PREDICT_TRUE(other)) { + variant_.emplace(other); + } else { + variant_.emplace(); + } + return *this; + } + + StructType() = default; + StructType(const StructType&) = default; + StructType(StructType&&) = default; + StructType& operator=(const StructType&) = default; + StructType& operator=(StructType&&) = default; + + static TypeKind kind() { return kKind; } + + absl::string_view name() const ABSL_ATTRIBUTE_LIFETIME_BOUND; + + TypeParameters GetParameters() const ABSL_ATTRIBUTE_LIFETIME_BOUND; + + std::string DebugString() const; + + bool IsMessage() const { + return absl::holds_alternative(variant_); + } + + template + std::enable_if_t, bool> Is() const { + return IsMessage(); + } + + absl::optional AsMessage() const; + + template + std::enable_if_t, absl::optional> + As() const { + return AsMessage(); + } + + MessageType GetMessage() const; + + template + std::enable_if_t, MessageType> Get() const { + return GetMessage(); + } + + explicit operator bool() const { + return !absl::holds_alternative(variant_); + } + + private: + friend class Type; + friend class MessageType; + friend class common_internal::BasicStructType; + + common_internal::TypeVariant ToTypeVariant() const; + + // The default state is well formed but invalid. It can be checked by using + // the explicit bool operator. This is to allow cases where you want to + // construct the type and later assign to it before using it. It is required + // that any instance returned from a function call or passed to a function + // call must not be in the default state. + common_internal::StructTypeVariant variant_; +}; + +inline bool operator==(const StructType& lhs, const StructType& rhs) { + return static_cast(lhs) == static_cast(rhs) && + (!static_cast(lhs) || lhs.name() == rhs.name()); +} + +inline bool operator!=(const StructType& lhs, const StructType& rhs) { + return !operator==(lhs, rhs); +} + +template +H AbslHashValue(H state, const StructType& type) { + return H::combine(std::move(state), static_cast(type) + ? type.name() + : absl::string_view()); +} + +inline std::ostream& operator<<(std::ostream& out, const StructType& type) { + return out << type.DebugString(); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_STRUCT_TYPE_H_ diff --git a/common/types/struct_type_test.cc b/common/types/struct_type_test.cc new file mode 100644 index 000000000..f50a0a938 --- /dev/null +++ b/common/types/struct_type_test.cc @@ -0,0 +1,82 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "google/protobuf/descriptor.pb.h" +#include "absl/base/nullability.h" +#include "absl/hash/hash.h" +#include "absl/log/absl_check.h" +#include "absl/log/die_if_null.h" +#include "common/type.h" +#include "common/type_kind.h" +#include "internal/testing.h" +#include "google/protobuf/descriptor.h" + +namespace cel { +namespace { + +using ::testing::Test; + +class StructTypeTest : public Test { + public: + void SetUp() override { + { + google::protobuf::FileDescriptorProto file_desc_proto; + file_desc_proto.set_syntax("proto3"); + file_desc_proto.set_package("test"); + file_desc_proto.set_name("test/struct.proto"); + file_desc_proto.add_message_type()->set_name("Struct"); + ABSL_CHECK(pool_.BuildFile(file_desc_proto) != nullptr); + } + } + + const google::protobuf::Descriptor* absl_nonnull GetDescriptor() const { + return ABSL_DIE_IF_NULL(pool_.FindMessageTypeByName("test.Struct")); + } + + MessageType GetMessageType() const { return MessageType(GetDescriptor()); } + + common_internal::BasicStructType GetBasicStructType() const { + return common_internal::MakeBasicStructType("test.Struct"); + } + + private: + google::protobuf::DescriptorPool pool_; +}; + +TEST(StructType, Kind) { EXPECT_EQ(StructType::kind(), TypeKind::kStruct); } + +TEST_F(StructTypeTest, Name) { + EXPECT_EQ(StructType(GetMessageType()).name(), GetMessageType().name()); + EXPECT_EQ(StructType(GetBasicStructType()).name(), + GetBasicStructType().name()); +} + +TEST_F(StructTypeTest, DebugString) { + EXPECT_EQ(StructType(GetMessageType()).DebugString(), + GetMessageType().DebugString()); + EXPECT_EQ(StructType(GetBasicStructType()).DebugString(), + GetBasicStructType().DebugString()); +} + +TEST_F(StructTypeTest, Hash) { + EXPECT_EQ(absl::HashOf(StructType(GetMessageType())), + absl::HashOf(StructType(GetBasicStructType()))); +} + +TEST_F(StructTypeTest, Equal) { + EXPECT_EQ(StructType(GetMessageType()), StructType(GetBasicStructType())); +} + +} // namespace +} // namespace cel diff --git a/common/types/timestamp_type.h b/common/types/timestamp_type.h new file mode 100644 index 000000000..13cc8ca62 --- /dev/null +++ b/common/types/timestamp_type.h @@ -0,0 +1,73 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/type.h" +// IWYU pragma: friend "common/type.h" + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_TIMESTAMP_TYPE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_TYPES_TIMESTAMP_TYPE_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "common/type_kind.h" + +namespace cel { + +class Type; +class TypeParameters; + +// `TimestampType` represents the primitive `timestamp` type. +class TimestampType final { + public: + static constexpr TypeKind kKind = TypeKind::kTimestamp; + static constexpr absl::string_view kName = "google.protobuf.Timestamp"; + + TimestampType() = default; + TimestampType(const TimestampType&) = default; + TimestampType(TimestampType&&) = default; + TimestampType& operator=(const TimestampType&) = default; + TimestampType& operator=(TimestampType&&) = default; + + static TypeKind kind() { return kKind; } + + static absl::string_view name() { return kName; } + + static TypeParameters GetParameters(); + + static std::string DebugString() { return std::string(name()); } +}; + +inline constexpr bool operator==(TimestampType, TimestampType) { return true; } + +inline constexpr bool operator!=(TimestampType lhs, TimestampType rhs) { + return !operator==(lhs, rhs); +} + +template +H AbslHashValue(H state, TimestampType) { + // TimestampType is really a singleton and all instances are equal. Nothing to + // hash. + return std::move(state); +} + +inline std::ostream& operator<<(std::ostream& out, const TimestampType& type) { + return out << type.DebugString(); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_TIMESTAMP_TYPE_H_ diff --git a/common/types/timestamp_type_test.cc b/common/types/timestamp_type_test.cc new file mode 100644 index 000000000..648ba3df3 --- /dev/null +++ b/common/types/timestamp_type_test.cc @@ -0,0 +1,59 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "absl/hash/hash.h" +#include "common/type.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +TEST(TimestampType, Kind) { + EXPECT_EQ(TimestampType().kind(), TimestampType::kKind); + EXPECT_EQ(Type(TimestampType()).kind(), TimestampType::kKind); +} + +TEST(TimestampType, Name) { + EXPECT_EQ(TimestampType().name(), TimestampType::kName); + EXPECT_EQ(Type(TimestampType()).name(), TimestampType::kName); +} + +TEST(TimestampType, DebugString) { + { + std::ostringstream out; + out << TimestampType(); + EXPECT_EQ(out.str(), TimestampType::kName); + } + { + std::ostringstream out; + out << Type(TimestampType()); + EXPECT_EQ(out.str(), TimestampType::kName); + } +} + +TEST(TimestampType, Hash) { + EXPECT_EQ(absl::HashOf(TimestampType()), absl::HashOf(TimestampType())); +} + +TEST(TimestampType, Equal) { + EXPECT_EQ(TimestampType(), TimestampType()); + EXPECT_EQ(Type(TimestampType()), TimestampType()); + EXPECT_EQ(TimestampType(), Type(TimestampType())); + EXPECT_EQ(Type(TimestampType()), Type(TimestampType())); +} + +} // namespace +} // namespace cel diff --git a/common/types/type_param_type.h b/common/types/type_param_type.h new file mode 100644 index 000000000..4fa8b9612 --- /dev/null +++ b/common/types/type_param_type.h @@ -0,0 +1,78 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/type.h" +// IWYU pragma: friend "common/type.h" + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_TYPE_PARAM_TYPE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_TYPES_TYPE_PARAM_TYPE_H_ + +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/strings/string_view.h" +#include "common/type_kind.h" + +namespace cel { + +class Type; +class TypeParameters; + +class TypeParamType final { + public: + static constexpr TypeKind kKind = TypeKind::kTypeParam; + + explicit TypeParamType(absl::string_view name ABSL_ATTRIBUTE_LIFETIME_BOUND) + : name_(name) {} + + TypeParamType() = default; + TypeParamType(const TypeParamType&) = default; + TypeParamType(TypeParamType&&) = default; + TypeParamType& operator=(const TypeParamType&) = default; + TypeParamType& operator=(TypeParamType&&) = default; + + static TypeKind kind() { return kKind; } + + absl::string_view name() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return name_; } + + static TypeParameters GetParameters(); + + std::string DebugString() const { return std::string(name()); } + + private: + absl::string_view name_; +}; + +inline bool operator==(const TypeParamType& lhs, const TypeParamType& rhs) { + return lhs.name() == rhs.name(); +} + +inline bool operator!=(const TypeParamType& lhs, const TypeParamType& rhs) { + return !operator==(lhs, rhs); +} + +template +H AbslHashValue(H state, const TypeParamType& type) { + return H::combine(std::move(state), type.name()); +} + +inline std::ostream& operator<<(std::ostream& out, const TypeParamType& type) { + return out << type.DebugString(); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_TYPE_PARAM_TYPE_H_ diff --git a/common/types/type_param_type_test.cc b/common/types/type_param_type_test.cc new file mode 100644 index 000000000..69c902070 --- /dev/null +++ b/common/types/type_param_type_test.cc @@ -0,0 +1,60 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/type.h" + +#include + +#include "absl/hash/hash.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +TEST(TypeParamType, Kind) { + EXPECT_EQ(TypeParamType("T").kind(), TypeParamType::kKind); + EXPECT_EQ(Type(TypeParamType("T")).kind(), TypeParamType::kKind); +} + +TEST(TypeParamType, Name) { + EXPECT_EQ(TypeParamType("T").name(), "T"); + EXPECT_EQ(Type(TypeParamType("T")).name(), "T"); +} + +TEST(TypeParamType, DebugString) { + { + std::ostringstream out; + out << TypeParamType("T"); + EXPECT_EQ(out.str(), "T"); + } + { + std::ostringstream out; + out << Type(TypeParamType("T")); + EXPECT_EQ(out.str(), "T"); + } +} + +TEST(TypeParamType, Hash) { + EXPECT_EQ(absl::HashOf(TypeParamType("T")), absl::HashOf(TypeParamType("T"))); +} + +TEST(TypeParamType, Equal) { + EXPECT_EQ(TypeParamType("T"), TypeParamType("T")); + EXPECT_EQ(Type(TypeParamType("T")), TypeParamType("T")); + EXPECT_EQ(TypeParamType("T"), Type(TypeParamType("T"))); + EXPECT_EQ(Type(TypeParamType("T")), Type(TypeParamType("T"))); +} + +} // namespace +} // namespace cel diff --git a/common/types/type_pool.cc b/common/types/type_pool.cc new file mode 100644 index 000000000..3db7ef288 --- /dev/null +++ b/common/types/type_pool.cc @@ -0,0 +1,96 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/types/type_pool.h" + +#include "absl/base/optimization.h" +#include "absl/log/absl_check.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "common/type.h" + +namespace cel::common_internal { + +StructType TypePool::MakeStructType(absl::string_view name) { + ABSL_DCHECK(!IsWellKnownMessageType(name)) << name; + if (ABSL_PREDICT_FALSE(name.empty())) { + return StructType(); + } + if (const auto* descriptor = descriptors_->FindMessageTypeByName(name); + descriptor != nullptr) { + return MessageType(descriptor); + } + return MakeBasicStructType(InternString(name)); +} + +FunctionType TypePool::MakeFunctionType(const Type& result, + absl::Span args) { + absl::MutexLock lock(functions_mutex_); + return functions_.InternFunctionType(result, args); +} + +ListType TypePool::MakeListType(const Type& element) { + if (element.IsDyn()) { + return ListType(); + } + absl::MutexLock lock(lists_mutex_); + return lists_.InternListType(element); +} + +MapType TypePool::MakeMapType(const Type& key, const Type& value) { + if (key.IsDyn() && value.IsDyn()) { + return MapType(); + } + if (key.IsString() && value.IsDyn()) { + return JsonMapType(); + } + absl::MutexLock lock(maps_mutex_); + return maps_.InternMapType(key, value); +} + +OpaqueType TypePool::MakeOpaqueType(absl::string_view name, + absl::Span parameters) { + if (name == OptionalType::kName) { + if (parameters.size() == 1 && parameters.front().IsDyn()) { + return OptionalType(); + } + name = OptionalType::kName; + } else { + name = InternString(name); + } + absl::MutexLock lock(opaques_mutex_); + return opaques_.InternOpaqueType(name, parameters); +} + +OptionalType TypePool::MakeOptionalType(const Type& parameter) { + return MakeOpaqueType(OptionalType::kName, absl::MakeConstSpan(¶meter, 1)) + .GetOptional(); +} + +TypeParamType TypePool::MakeTypeParamType(absl::string_view name) { + return TypeParamType(InternString(name)); +} + +TypeType TypePool::MakeTypeType(const Type& type) { + absl::MutexLock lock(types_mutex_); + return types_.InternTypeType(type); +} + +absl::string_view TypePool::InternString(absl::string_view string) { + absl::MutexLock lock(strings_mutex_); + return strings_.InternString(string); +} + +} // namespace cel::common_internal diff --git a/common/types/type_pool.h b/common/types/type_pool.h new file mode 100644 index 000000000..921bf9d07 --- /dev/null +++ b/common/types/type_pool.h @@ -0,0 +1,99 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_TYPE_POOL_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_TYPES_TYPE_POOL_H_ + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/base/thread_annotations.h" +#include "absl/log/die_if_null.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "common/type.h" +#include "common/types/function_type_pool.h" +#include "common/types/list_type_pool.h" +#include "common/types/map_type_pool.h" +#include "common/types/opaque_type_pool.h" +#include "common/types/type_type_pool.h" +#include "internal/string_pool.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" + +namespace cel::common_internal { + +// `TypePool` is a thread safe interning factory for complex types. All types +// are allocated using the provided `google::protobuf::Arena`. +class TypePool final { + public: + TypePool(const google::protobuf::DescriptorPool* absl_nonnull descriptors + ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND) + : descriptors_(ABSL_DIE_IF_NULL(descriptors)), // Crash OK + arena_(ABSL_DIE_IF_NULL(arena)), // Crash OK + strings_(arena_), + functions_(arena_), + lists_(arena_), + maps_(arena_), + opaques_(arena_), + types_(arena_) {} + + TypePool(const TypePool&) = delete; + TypePool(TypePool&&) = delete; + TypePool& operator=(const TypePool&) = delete; + TypePool& operator=(TypePool&&) = delete; + + StructType MakeStructType(absl::string_view name); + + FunctionType MakeFunctionType(const Type& result, + absl::Span args); + + ListType MakeListType(const Type& element); + + MapType MakeMapType(const Type& key, const Type& value); + + OpaqueType MakeOpaqueType(absl::string_view name, + absl::Span parameters); + + OptionalType MakeOptionalType(const Type& parameter); + + TypeParamType MakeTypeParamType(absl::string_view name); + + TypeType MakeTypeType(const Type& type); + + private: + absl::string_view InternString(absl::string_view string); + + const google::protobuf::DescriptorPool* absl_nonnull const descriptors_; + google::protobuf::Arena* absl_nonnull const arena_; + absl::Mutex strings_mutex_; + internal::StringPool strings_ ABSL_GUARDED_BY(strings_mutex_); + absl::Mutex functions_mutex_; + FunctionTypePool functions_ ABSL_GUARDED_BY(functions_mutex_); + absl::Mutex lists_mutex_; + ListTypePool lists_ ABSL_GUARDED_BY(lists_mutex_); + absl::Mutex maps_mutex_; + MapTypePool maps_ ABSL_GUARDED_BY(maps_mutex_); + absl::Mutex opaques_mutex_; + OpaqueTypePool opaques_ ABSL_GUARDED_BY(opaques_mutex_); + absl::Mutex types_mutex_; + TypeTypePool types_ ABSL_GUARDED_BY(types_mutex_); +}; + +} // namespace cel::common_internal + +#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_TYPE_POOL_H_ diff --git a/common/types/type_pool_test.cc b/common/types/type_pool_test.cc new file mode 100644 index 000000000..4d32113d0 --- /dev/null +++ b/common/types/type_pool_test.cc @@ -0,0 +1,94 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/types/type_pool.h" + +#include "common/type.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "google/protobuf/arena.h" + +namespace cel::common_internal { +namespace { + +using ::cel::internal::GetTestingDescriptorPool; +using ::testing::_; + +TEST(TypePool, MakeStructType) { + google::protobuf::Arena arena; + TypePool type_pool(GetTestingDescriptorPool(), &arena); + EXPECT_EQ(type_pool.MakeStructType("foo.Bar"), + MakeBasicStructType("foo.Bar")); + EXPECT_TRUE( + type_pool.MakeStructType("cel.expr.conformance.proto3.TestAllTypes") + .IsMessage()); + EXPECT_DEBUG_DEATH( + static_cast(type_pool.MakeStructType("google.protobuf.BoolValue")), + _); +} + +TEST(TypePool, MakeFunctionType) { + google::protobuf::Arena arena; + TypePool type_pool(GetTestingDescriptorPool(), &arena); + EXPECT_EQ(type_pool.MakeFunctionType(BoolType(), {IntType(), IntType()}), + FunctionType(&arena, BoolType(), {IntType(), IntType()})); +} + +TEST(TypePool, MakeListType) { + google::protobuf::Arena arena; + TypePool type_pool(GetTestingDescriptorPool(), &arena); + EXPECT_EQ(type_pool.MakeListType(DynType()), ListType()); + EXPECT_EQ(type_pool.MakeListType(DynType()), JsonListType()); + EXPECT_EQ(type_pool.MakeListType(StringType()), + ListType(&arena, StringType())); +} + +TEST(TypePool, MakeMapType) { + google::protobuf::Arena arena; + TypePool type_pool(GetTestingDescriptorPool(), &arena); + EXPECT_EQ(type_pool.MakeMapType(DynType(), DynType()), MapType()); + EXPECT_EQ(type_pool.MakeMapType(StringType(), DynType()), JsonMapType()); + EXPECT_EQ(type_pool.MakeMapType(StringType(), StringType()), + MapType(&arena, StringType(), StringType())); +} + +TEST(TypePool, MakeOpaqueType) { + google::protobuf::Arena arena; + TypePool type_pool(GetTestingDescriptorPool(), &arena); + EXPECT_EQ(type_pool.MakeOpaqueType("custom_type", {DynType(), DynType()}), + OpaqueType(&arena, "custom_type", {DynType(), DynType()})); +} + +TEST(TypePool, MakeOptionalType) { + google::protobuf::Arena arena; + TypePool type_pool(GetTestingDescriptorPool(), &arena); + EXPECT_EQ(type_pool.MakeOptionalType(DynType()), OptionalType()); + EXPECT_EQ(type_pool.MakeOptionalType(StringType()), + OptionalType(&arena, StringType())); +} + +TEST(TypePool, MakeTypeParamType) { + google::protobuf::Arena arena; + TypePool type_pool(GetTestingDescriptorPool(), &arena); + EXPECT_EQ(type_pool.MakeTypeParamType("T"), TypeParamType("T")); +} + +TEST(TypePool, MakeTypeType) { + google::protobuf::Arena arena; + TypePool type_pool(GetTestingDescriptorPool(), &arena); + EXPECT_EQ(type_pool.MakeTypeType(BoolType()), TypeType(&arena, BoolType())); +} + +} // namespace +} // namespace cel::common_internal diff --git a/common/types/type_type.cc b/common/types/type_type.cc new file mode 100644 index 000000000..831b8069b --- /dev/null +++ b/common/types/type_type.cc @@ -0,0 +1,74 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/type.h" + +#include + +#include "absl/base/nullability.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "common/type_kind.h" +#include "google/protobuf/arena.h" + +namespace cel { + +namespace common_internal { + +struct TypeTypeData final { + static TypeTypeData* Create(google::protobuf::Arena* absl_nonnull arena, + const Type& type) { + return google::protobuf::Arena::Create(arena, type); + } + + explicit TypeTypeData(const Type& type) : type(type) {} + + TypeTypeData() = delete; + TypeTypeData(const TypeTypeData&) = delete; + TypeTypeData(TypeTypeData&&) = delete; + TypeTypeData& operator=(const TypeTypeData&) = delete; + TypeTypeData& operator=(TypeTypeData&&) = delete; + + const Type type; +}; + +} // namespace common_internal + +std::string TypeType::DebugString() const { + std::string s(name()); + if (!GetParameters().empty()) { + absl::StrAppend(&s, "(", TypeKindToString(GetParameters().front().kind()), + ")"); + } + return s; +} + +TypeType::TypeType(google::protobuf::Arena* absl_nonnull arena, const Type& parameter) + : TypeType(common_internal::TypeTypeData::Create(arena, parameter)) {} + +TypeParameters TypeType::GetParameters() const { + if (data_) { + return TypeParameters(absl::MakeConstSpan(&data_->type, 1)); + } + return {}; +} + +Type TypeType::GetType() const { + if (data_) { + return data_->type; + } + return Type(); +} + +} // namespace cel diff --git a/common/types/type_type.h b/common/types/type_type.h new file mode 100644 index 000000000..652f99008 --- /dev/null +++ b/common/types/type_type.h @@ -0,0 +1,92 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/type.h" +// IWYU pragma: friend "common/type.h" + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_TYPE_TYPE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_TYPES_TYPE_TYPE_H_ + +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/strings/string_view.h" +#include "common/type_kind.h" +#include "google/protobuf/arena.h" + +namespace cel { + +class Type; +class TypeParameters; + +namespace common_internal { +struct TypeTypeData; +} // namespace common_internal + +// `TypeType` is a special type which represents the type of a type. +class TypeType final { + public: + static constexpr TypeKind kKind = TypeKind::kType; + static constexpr absl::string_view kName = "type"; + + TypeType(google::protobuf::Arena* absl_nonnull arena, const Type& parameter); + + TypeType() = default; + TypeType(const TypeType&) = default; + TypeType(TypeType&&) = default; + TypeType& operator=(const TypeType&) = default; + TypeType& operator=(TypeType&&) = default; + + static TypeKind kind() { return kKind; } + + static absl::string_view name() { return kName; } + + TypeParameters GetParameters() const ABSL_ATTRIBUTE_LIFETIME_BOUND; + + std::string DebugString() const; + + Type GetType() const; + + private: + explicit TypeType(const common_internal::TypeTypeData* absl_nullable data) + : data_(data) {} + + const common_internal::TypeTypeData* absl_nullable data_ = nullptr; +}; + +inline constexpr bool operator==(const TypeType&, const TypeType&) { + return true; +} + +inline constexpr bool operator!=(const TypeType& lhs, const TypeType& rhs) { + return !operator==(lhs, rhs); +} + +template +H AbslHashValue(H state, const TypeType&) { + // TypeType is really a singleton and all instances are equal. Nothing to + // hash. + return std::move(state); +} + +inline std::ostream& operator<<(std::ostream& out, const TypeType& type) { + return out << type.DebugString(); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_TYPE_TYPE_H_ diff --git a/common/types/type_type_pool.cc b/common/types/type_type_pool.cc new file mode 100644 index 000000000..1d9238535 --- /dev/null +++ b/common/types/type_type_pool.cc @@ -0,0 +1,26 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/types/type_type_pool.h" + +#include "common/type.h" + +namespace cel::common_internal { + +TypeType TypeTypePool::InternTypeType(const Type& type) { + return *type_types_.lazy_emplace( + type, [&](const auto& ctor) { ctor(TypeType(arena_, type)); }); +} + +} // namespace cel::common_internal diff --git a/common/types/type_type_pool.h b/common/types/type_type_pool.h new file mode 100644 index 000000000..480ee6f7d --- /dev/null +++ b/common/types/type_type_pool.h @@ -0,0 +1,86 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_TYPE_TYPE_POOL_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_TYPES_TYPE_TYPE_POOL_H_ + +#include + +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_set.h" +#include "absl/hash/hash.h" +#include "absl/log/absl_check.h" +#include "absl/log/die_if_null.h" +#include "common/type.h" +#include "google/protobuf/arena.h" + +namespace cel::common_internal { + +// `TypeTypePool` is a thread unsafe interning factory for `TypeType`. +class TypeTypePool final { + public: + explicit TypeTypePool(google::protobuf::Arena* absl_nonnull arena) + : arena_(ABSL_DIE_IF_NULL(arena)) {} // Crash OK + + // Returns a `TypeType` which has the provided parameters, interning as + // necessary. + TypeType InternTypeType(const Type& type); + + private: + struct Hasher { + using is_transparent = void; + + size_t operator()(const TypeType& type_type) const { + ABSL_DCHECK_EQ(type_type.GetParameters().size(), 1); + return (*this)(type_type.GetParameters().front()); + } + + size_t operator()(const Type& type) const { + return absl::Hash{}(type); + } + }; + + struct Equaler { + using is_transparent = void; + + bool operator()(const TypeType& lhs, const TypeType& rhs) const { + ABSL_DCHECK_EQ(lhs.GetParameters().size(), 1); + ABSL_DCHECK_EQ(rhs.GetParameters().size(), 1); + return (*this)(lhs.GetParameters().front(), rhs.GetParameters().front()); + } + + bool operator()(const TypeType& lhs, const Type& rhs) const { + ABSL_DCHECK_EQ(lhs.GetParameters().size(), 1); + return (*this)(lhs.GetParameters().front(), rhs); + } + + bool operator()(const Type& lhs, const TypeType& rhs) const { + ABSL_DCHECK_EQ(rhs.GetParameters().size(), 1); + return (*this)(lhs, rhs.GetParameters().front()); + } + + bool operator()(const Type& lhs, const Type& rhs) const { + return lhs == rhs; + } + }; + + google::protobuf::Arena* absl_nonnull const arena_; + absl::flat_hash_set type_types_; +}; + +} // namespace cel::common_internal + +#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_TYPE_TYPE_POOL_H_ diff --git a/common/types/type_type_test.cc b/common/types/type_type_test.cc new file mode 100644 index 000000000..978027f98 --- /dev/null +++ b/common/types/type_type_test.cc @@ -0,0 +1,60 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/type.h" + +#include + +#include "absl/hash/hash.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +TEST(TypeType, Kind) { + EXPECT_EQ(TypeType().kind(), TypeType::kKind); + EXPECT_EQ(Type(TypeType()).kind(), TypeType::kKind); +} + +TEST(TypeType, Name) { + EXPECT_EQ(TypeType().name(), TypeType::kName); + EXPECT_EQ(Type(TypeType()).name(), TypeType::kName); +} + +TEST(TypeType, DebugString) { + { + std::ostringstream out; + out << TypeType(); + EXPECT_EQ(out.str(), TypeType::kName); + } + { + std::ostringstream out; + out << Type(TypeType()); + EXPECT_EQ(out.str(), TypeType::kName); + } +} + +TEST(TypeType, Hash) { + EXPECT_EQ(absl::HashOf(TypeType()), absl::HashOf(TypeType())); +} + +TEST(TypeType, Equal) { + EXPECT_EQ(TypeType(), TypeType()); + EXPECT_EQ(Type(TypeType()), TypeType()); + EXPECT_EQ(TypeType(), Type(TypeType())); + EXPECT_EQ(Type(TypeType()), Type(TypeType())); +} + +} // namespace +} // namespace cel diff --git a/common/types/types.h b/common/types/types.h new file mode 100644 index 000000000..50c1eefc8 --- /dev/null +++ b/common/types/types.h @@ -0,0 +1,99 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_TYPES_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_TYPES_TYPES_H_ + +#include + +#include "absl/meta/type_traits.h" +#include "absl/types/variant.h" + +namespace cel { + +class Type; +class AnyType; +class BoolType; +class BoolWrapperType; +class BytesType; +class BytesWrapperType; +class DoubleType; +class DoubleWrapperType; +class DurationType; +class DynType; +class EnumType; +class ErrorType; +class FunctionType; +class IntType; +class IntWrapperType; +class ListType; +class MapType; +class NullType; +class OpaqueType; +class OptionalType; +class StringType; +class StringWrapperType; +class StructType; +class MessageType; +class TimestampType; +class TypeParamType; +class TypeType; +class UintType; +class UintWrapperType; +class UnknownType; + +namespace common_internal { + +class BasicStructType; + +template > +struct IsTypeAlternative + : std::bool_constant, std::is_same, + std::is_same, std::is_same, + std::is_same, std::is_same, + std::is_same, std::is_same, + std::is_same, std::is_same, + std::is_same, std::is_same, + std::is_same, std::is_same, + std::is_same, std::is_same, + std::is_same, std::is_same, + std::is_same, std::is_same, + std::is_same, std::is_same, + std::is_same, std::is_same, + std::is_same, std::is_same, + std::is_same, std::is_same>> {}; + +template +inline constexpr bool IsTypeAlternativeV = IsTypeAlternative::value; + +using TypeVariant = + absl::variant; + +using StructTypeVariant = + absl::variant; + +} // namespace common_internal + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_TYPES_H_ diff --git a/common/types/uint_type.h b/common/types/uint_type.h new file mode 100644 index 000000000..122ad77a9 --- /dev/null +++ b/common/types/uint_type.h @@ -0,0 +1,73 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/type.h" +// IWYU pragma: friend "common/type.h" + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_UINT_TYPE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_TYPES_UINT_TYPE_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "common/type_kind.h" + +namespace cel { + +class Type; +class TypeParameters; + +// `UintType` represents the primitive `uint` type. +class UintType final { + public: + static constexpr TypeKind kKind = TypeKind::kUint; + static constexpr absl::string_view kName = "uint"; + + UintType() = default; + UintType(const UintType&) = default; + UintType(UintType&&) = default; + UintType& operator=(const UintType&) = default; + UintType& operator=(UintType&&) = default; + + static TypeKind kind() { return kKind; } + + static absl::string_view name() { return kName; } + + static TypeParameters GetParameters(); + + static std::string DebugString() { return std::string(name()); } +}; + +inline constexpr bool operator==(UintType, UintType) { return true; } + +inline constexpr bool operator!=(UintType lhs, UintType rhs) { + return !operator==(lhs, rhs); +} + +template +H AbslHashValue(H state, UintType) { + // UintType is really a singleton and all instances are equal. Nothing to + // hash. + return std::move(state); +} + +inline std::ostream& operator<<(std::ostream& out, const UintType& type) { + return out << type.DebugString(); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_UINT_TYPE_H_ diff --git a/common/types/uint_type_test.cc b/common/types/uint_type_test.cc new file mode 100644 index 000000000..2adea78d9 --- /dev/null +++ b/common/types/uint_type_test.cc @@ -0,0 +1,59 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "absl/hash/hash.h" +#include "common/type.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +TEST(UintType, Kind) { + EXPECT_EQ(UintType().kind(), UintType::kKind); + EXPECT_EQ(Type(UintType()).kind(), UintType::kKind); +} + +TEST(UintType, Name) { + EXPECT_EQ(UintType().name(), UintType::kName); + EXPECT_EQ(Type(UintType()).name(), UintType::kName); +} + +TEST(UintType, DebugString) { + { + std::ostringstream out; + out << UintType(); + EXPECT_EQ(out.str(), UintType::kName); + } + { + std::ostringstream out; + out << Type(UintType()); + EXPECT_EQ(out.str(), UintType::kName); + } +} + +TEST(UintType, Hash) { + EXPECT_EQ(absl::HashOf(UintType()), absl::HashOf(UintType())); +} + +TEST(UintType, Equal) { + EXPECT_EQ(UintType(), UintType()); + EXPECT_EQ(Type(UintType()), UintType()); + EXPECT_EQ(UintType(), Type(UintType())); + EXPECT_EQ(Type(UintType()), Type(UintType())); +} + +} // namespace +} // namespace cel diff --git a/common/types/uint_wrapper_type.h b/common/types/uint_wrapper_type.h new file mode 100644 index 000000000..88ffb8e49 --- /dev/null +++ b/common/types/uint_wrapper_type.h @@ -0,0 +1,79 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/type.h" +// IWYU pragma: friend "common/type.h" + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_UINT_WRAPPER_TYPE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_TYPES_UINT_WRAPPER_TYPE_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "common/type_kind.h" + +namespace cel { + +class Type; +class TypeParameters; + +// `UintWrapperType` is a special type which has no direct value +// representation. It is used to represent `google.protobuf.UInt64Value`, which +// never exists at runtime as a value. Its primary usage is for type checking +// and unpacking at runtime. +class UintWrapperType final { + public: + static constexpr TypeKind kKind = TypeKind::kUintWrapper; + static constexpr absl::string_view kName = "google.protobuf.UInt64Value"; + + UintWrapperType() = default; + UintWrapperType(const UintWrapperType&) = default; + UintWrapperType(UintWrapperType&&) = default; + UintWrapperType& operator=(const UintWrapperType&) = default; + UintWrapperType& operator=(UintWrapperType&&) = default; + + static TypeKind kind() { return kKind; } + + static absl::string_view name() { return kName; } + + static TypeParameters GetParameters(); + + static std::string DebugString() { return std::string(name()); } +}; + +inline constexpr bool operator==(UintWrapperType, UintWrapperType) { + return true; +} + +inline constexpr bool operator!=(UintWrapperType lhs, UintWrapperType rhs) { + return !operator==(lhs, rhs); +} + +template +H AbslHashValue(H state, UintWrapperType) { + // UintWrapperType is really a singleton and all instances are equal. Nothing + // to hash. + return std::move(state); +} + +inline std::ostream& operator<<(std::ostream& out, + const UintWrapperType& type) { + return out << type.DebugString(); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_UINT_WRAPPER_TYPE_H_ diff --git a/common/types/uint_wrapper_type_test.cc b/common/types/uint_wrapper_type_test.cc new file mode 100644 index 000000000..a2fe47d8d --- /dev/null +++ b/common/types/uint_wrapper_type_test.cc @@ -0,0 +1,59 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "absl/hash/hash.h" +#include "common/type.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +TEST(UintWrapperType, Kind) { + EXPECT_EQ(UintWrapperType().kind(), UintWrapperType::kKind); + EXPECT_EQ(Type(UintWrapperType()).kind(), UintWrapperType::kKind); +} + +TEST(UintWrapperType, Name) { + EXPECT_EQ(UintWrapperType().name(), UintWrapperType::kName); + EXPECT_EQ(Type(UintWrapperType()).name(), UintWrapperType::kName); +} + +TEST(UintWrapperType, DebugString) { + { + std::ostringstream out; + out << UintWrapperType(); + EXPECT_EQ(out.str(), UintWrapperType::kName); + } + { + std::ostringstream out; + out << Type(UintWrapperType()); + EXPECT_EQ(out.str(), UintWrapperType::kName); + } +} + +TEST(UintWrapperType, Hash) { + EXPECT_EQ(absl::HashOf(UintWrapperType()), absl::HashOf(UintWrapperType())); +} + +TEST(UintWrapperType, Equal) { + EXPECT_EQ(UintWrapperType(), UintWrapperType()); + EXPECT_EQ(Type(UintWrapperType()), UintWrapperType()); + EXPECT_EQ(UintWrapperType(), Type(UintWrapperType())); + EXPECT_EQ(Type(UintWrapperType()), Type(UintWrapperType())); +} + +} // namespace +} // namespace cel diff --git a/common/types/unknown_type.h b/common/types/unknown_type.h new file mode 100644 index 000000000..5ea7d92aa --- /dev/null +++ b/common/types/unknown_type.h @@ -0,0 +1,74 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/type.h" +// IWYU pragma: friend "common/type.h" + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_UNKNOWN_TYPE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_TYPES_UNKNOWN_TYPE_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "common/type_kind.h" + +namespace cel { + +class Type; +class TypeParameters; + +// `UnknownType` is a special type which represents an unknown at runtime. It +// has no in-language representation. +class UnknownType final { + public: + static constexpr TypeKind kKind = TypeKind::kUnknown; + static constexpr absl::string_view kName = "*unknown*"; + + UnknownType() = default; + UnknownType(const UnknownType&) = default; + UnknownType(UnknownType&&) = default; + UnknownType& operator=(const UnknownType&) = default; + UnknownType& operator=(UnknownType&&) = default; + + static TypeKind kind() { return kKind; } + + static absl::string_view name() { return kName; } + + static TypeParameters GetParameters(); + + static std::string DebugString() { return std::string(name()); } +}; + +inline constexpr bool operator==(UnknownType, UnknownType) { return true; } + +inline constexpr bool operator!=(UnknownType lhs, UnknownType rhs) { + return !operator==(lhs, rhs); +} + +template +H AbslHashValue(H state, UnknownType) { + // UnknownType is really a singleton and all instances are equal. Nothing to + // hash. + return std::move(state); +} + +inline std::ostream& operator<<(std::ostream& out, const UnknownType& type) { + return out << type.DebugString(); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_UNKNOWN_TYPE_H_ diff --git a/common/types/unknown_type_test.cc b/common/types/unknown_type_test.cc new file mode 100644 index 000000000..2f105540d --- /dev/null +++ b/common/types/unknown_type_test.cc @@ -0,0 +1,59 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "absl/hash/hash.h" +#include "common/type.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +TEST(UnknownType, Kind) { + EXPECT_EQ(UnknownType().kind(), UnknownType::kKind); + EXPECT_EQ(Type(UnknownType()).kind(), UnknownType::kKind); +} + +TEST(UnknownType, Name) { + EXPECT_EQ(UnknownType().name(), UnknownType::kName); + EXPECT_EQ(Type(UnknownType()).name(), UnknownType::kName); +} + +TEST(UnknownType, DebugString) { + { + std::ostringstream out; + out << UnknownType(); + EXPECT_EQ(out.str(), UnknownType::kName); + } + { + std::ostringstream out; + out << Type(UnknownType()); + EXPECT_EQ(out.str(), UnknownType::kName); + } +} + +TEST(UnknownType, Hash) { + EXPECT_EQ(absl::HashOf(UnknownType()), absl::HashOf(UnknownType())); +} + +TEST(UnknownType, Equal) { + EXPECT_EQ(UnknownType(), UnknownType()); + EXPECT_EQ(Type(UnknownType()), UnknownType()); + EXPECT_EQ(UnknownType(), Type(UnknownType())); + EXPECT_EQ(Type(UnknownType()), Type(UnknownType())); +} + +} // namespace +} // namespace cel diff --git a/common/unknown.cc b/common/unknown.cc deleted file mode 100644 index 98a74ece2..000000000 --- a/common/unknown.cc +++ /dev/null @@ -1,32 +0,0 @@ -#include "common/unknown.h" -#include "common/macros.h" - -namespace google { -namespace api { -namespace expr { -namespace common { - -Unknown::Unknown(Id id) { ids_.insert(id); } - -Unknown::Unknown(absl::Span ids) { - assert(ids.begin() != ids.end()); - ids_.insert(ids.begin(), ids.end()); -} - -std::size_t Unknown::hash_code() const { - std::size_t code = internal::kIntegralTypeOffset; - for (const auto& id : ids_) { - internal::AccumulateHashNoOrder(id, &code); - } - return code; -} - -std::string Unknown::ToString() const { - internal::SequencePrinter printer; - return printer("Unknown", ids_); -} - -} // namespace common -} // namespace expr -} // namespace api -} // namespace google diff --git a/common/unknown.h b/common/unknown.h index 746d64803..1e0001879 100644 --- a/common/unknown.h +++ b/common/unknown.h @@ -1,49 +1,27 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #ifndef THIRD_PARTY_CEL_CPP_COMMON_UNKNOWN_H_ #define THIRD_PARTY_CEL_CPP_COMMON_UNKNOWN_H_ -#include "absl/types/span.h" -#include "common/id.h" - -namespace google { -namespace api { -namespace expr { -namespace common { - -/** An unknown CEL value. */ -class Unknown { - public: - explicit Unknown(Id id); - explicit Unknown(absl::Span ids); - - const std::set& ids() const { return ids_; } - - inline bool operator==(const Unknown& rhs) const { - return this == &rhs || ids_ == rhs.ids_; - } - - inline bool operator!=(const Unknown& rhs) const { return !(*this == rhs); } - - /** The hash code for this value. */ - std::size_t hash_code() const; - - /** - * A string useful for debugging. - * - * Format may change, and computation may be expensive. - */ - std::string ToString() const; +#include "base/internal/unknown_set.h" - private: - std::set ids_; -}; +namespace cel { -inline std::ostream& operator<<(std::ostream& os, const Unknown& value) { - return os << value.ToString(); -} +// `Unknown` is a collection of unknown attributes and function results. +using Unknown = base_internal::UnknownSet; -} // namespace common -} // namespace expr -} // namespace api -} // namespace google +} // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_UNKNOWN_H_ diff --git a/common/value.cc b/common/value.cc index 429d2b412..1cd3f54e1 100644 --- a/common/value.cc +++ b/common/value.cc @@ -1,397 +1,2790 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #include "common/value.h" -#include -#include +#include +#include +#include +#include +#include #include +#include -#include "google/protobuf/duration.pb.h" -#include "google/protobuf/timestamp.pb.h" -#include "google/protobuf/util/message_differencer.h" -#include "absl/strings/escaping.h" +#include "google/protobuf/struct.pb.h" +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/base/optimization.h" +#include "absl/functional/overload.h" +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" +#include "absl/meta/type_traits.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" #include "absl/strings/str_cat.h" -#include "absl/utility/utility.h" -#include "common/macros.h" -#include "internal/cel_printer.h" -#include "internal/hash_util.h" -#include "internal/status_util.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "absl/types/optional.h" +#include "absl/types/variant.h" +#include "common/allocator.h" +#include "common/memory.h" +#include "common/optional_ref.h" +#include "common/type.h" +#include "common/value_kind.h" +#include "common/values/list_value_builder.h" +#include "common/values/map_value_builder.h" +#include "common/values/struct_value_builder.h" +#include "common/values/values.h" +#include "internal/number.h" +#include "internal/protobuf_runtime_version.h" +#include "internal/status_macros.h" +#include "internal/well_known_types.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +#undef GetMessage + +namespace cel { +namespace { + +google::protobuf::Arena* absl_nonnull MessageArenaOr( + const google::protobuf::Message* absl_nonnull message, + google::protobuf::Arena* absl_nonnull or_arena) { + google::protobuf::Arena* absl_nullable arena = message->GetArena(); + if (arena == nullptr) { + arena = or_arena; + } + return arena; +} + +} // namespace + +Type Value::GetRuntimeType() const { + switch (kind()) { + case ValueKind::kNull: + return NullType(); + case ValueKind::kBool: + return BoolType(); + case ValueKind::kInt: + return IntType(); + case ValueKind::kUint: + return UintType(); + case ValueKind::kDouble: + return DoubleType(); + case ValueKind::kString: + return StringType(); + case ValueKind::kBytes: + return BytesType(); + case ValueKind::kStruct: + return this->GetStruct().GetRuntimeType(); + case ValueKind::kDuration: + return DurationType(); + case ValueKind::kTimestamp: + return TimestampType(); + case ValueKind::kList: + return ListType(); + case ValueKind::kMap: + return MapType(); + case ValueKind::kUnknown: + return UnknownType(); + case ValueKind::kType: + return TypeType(); + case ValueKind::kError: + return ErrorType(); + case ValueKind::kOpaque: + return this->GetOpaque().GetRuntimeType(); + default: + return cel::Type(); + } +} + +namespace { + +template +struct IsMonostate : std::is_same, std::monostate> {}; + +} // namespace + +absl::string_view Value::GetTypeName() const { + return variant_.Visit([](const auto& alternative) -> absl::string_view { + return alternative.GetTypeName(); + }); +} + +std::string Value::DebugString() const { + return variant_.Visit([](const auto& alternative) -> std::string { + return alternative.DebugString(); + }); +} + +absl::Status Value::SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(output != nullptr); + + return variant_.Visit([&](const auto& alternative) -> absl::Status { + return alternative.SerializeTo(descriptor_pool, message_factory, output); + }); +} + +absl::Status Value::ConvertToJson( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(json != nullptr); + ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE); + + return variant_.Visit([descriptor_pool, message_factory, + json](const auto& alternative) -> absl::Status { + return alternative.ConvertToJson(descriptor_pool, message_factory, json); + }); +} + +absl::Status Value::ConvertToJsonArray( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(json != nullptr); + ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE); + + return variant_.Visit(absl::Overload( + [](std::monostate) -> absl::Status { + return absl::InternalError("use of invalid Value"); + }, + [descriptor_pool, message_factory, json]( + const common_internal::LegacyListValue& alternative) -> absl::Status { + return alternative.ConvertToJsonArray(descriptor_pool, message_factory, + json); + }, + [descriptor_pool, message_factory, + json](const CustomListValue& alternative) -> absl::Status { + return alternative.ConvertToJsonArray(descriptor_pool, message_factory, + json); + }, + [descriptor_pool, message_factory, + json](const ParsedRepeatedFieldValue& alternative) -> absl::Status { + return alternative.ConvertToJsonArray(descriptor_pool, message_factory, + json); + }, + [descriptor_pool, message_factory, + json](const ParsedJsonListValue& alternative) -> absl::Status { + return alternative.ConvertToJsonArray(descriptor_pool, message_factory, + json); + }, + [](const auto& alternative) -> absl::Status { + return TypeConversionError(alternative.GetTypeName(), + "google.protobuf.ListValue") + .NativeValue(); + })); +} + +absl::Status Value::ConvertToJsonObject( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(json != nullptr); + ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT); + + return variant_.Visit(absl::Overload( + [](std::monostate) -> absl::Status { + return absl::InternalError("use of invalid Value"); + }, + [descriptor_pool, message_factory, json]( + const common_internal::LegacyMapValue& alternative) -> absl::Status { + return alternative.ConvertToJsonObject(descriptor_pool, message_factory, + json); + }, + [descriptor_pool, message_factory, + json](const CustomMapValue& alternative) -> absl::Status { + return alternative.ConvertToJsonObject(descriptor_pool, message_factory, + json); + }, + [descriptor_pool, message_factory, + json](const ParsedMapFieldValue& alternative) -> absl::Status { + return alternative.ConvertToJsonObject(descriptor_pool, message_factory, + json); + }, + [descriptor_pool, message_factory, + json](const ParsedJsonMapValue& alternative) -> absl::Status { + return alternative.ConvertToJsonObject(descriptor_pool, message_factory, + json); + }, + [descriptor_pool, message_factory, + json](const common_internal::LegacyStructValue& alternative) + -> absl::Status { + return alternative.ConvertToJsonObject(descriptor_pool, message_factory, + json); + }, + [descriptor_pool, message_factory, + json](const CustomStructValue& alternative) -> absl::Status { + return alternative.ConvertToJsonObject(descriptor_pool, message_factory, + json); + }, + [descriptor_pool, message_factory, + json](const ParsedMessageValue& alternative) -> absl::Status { + return alternative.ConvertToJsonObject(descriptor_pool, message_factory, + json); + }, + [](const auto& alternative) -> absl::Status { + return TypeConversionError(alternative.GetTypeName(), + "google.protobuf.Struct") + .NativeValue(); + })); +} + +absl::Status Value::Equal( + const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + return variant_.Visit([&other, descriptor_pool, message_factory, arena, + result](const auto& alternative) -> absl::Status { + return alternative.Equal(other, descriptor_pool, message_factory, arena, + result); + }); +} + +bool Value::IsZeroValue() const { + return variant_.Visit([](const auto& alternative) -> bool { + return alternative.IsZeroValue(); + }); +} + +namespace { + +template +struct HasCloneMethod : std::false_type {}; + +template +struct HasCloneMethod().Clone( + std::declval()))>> + : std::true_type {}; + +} // namespace + +Value Value::Clone(google::protobuf::Arena* absl_nonnull arena) const { + return variant_.Visit([arena](const auto& alternative) -> Value { + if constexpr (IsMonostate::value) { + return Value(); + } else if constexpr (HasCloneMethod>::value) { + return alternative.Clone(arena); + } else { + return alternative; + } + }); +} + +std::ostream& operator<<(std::ostream& out, const Value& value) { + return value.variant_.Visit([&out](const auto& alternative) -> std::ostream& { + return out << alternative; + }); +} + +namespace { + +Value NonNullEnumValue(const google::protobuf::EnumValueDescriptor* absl_nonnull value) { + ABSL_DCHECK(value != nullptr); + return IntValue(value->number()); +} + +Value NonNullEnumValue(const google::protobuf::EnumDescriptor* absl_nonnull type, + int32_t number) { + ABSL_DCHECK(type != nullptr); + if (type->is_closed()) { + if (ABSL_PREDICT_FALSE(type->FindValueByNumber(number) == nullptr)) { + return ErrorValue(absl::InvalidArgumentError(absl::StrCat( + "closed enum has no such value: ", type->full_name(), ".", number))); + } + } + return IntValue(number); +} + +} // namespace + +Value Value::Enum(const google::protobuf::EnumValueDescriptor* absl_nonnull value) { + ABSL_DCHECK(value != nullptr); + if (value->type()->full_name() == "google.protobuf.NullValue") { + ABSL_DCHECK_EQ(value->number(), 0); + return NullValue(); + } + return NonNullEnumValue(value); +} + +Value Value::Enum(const google::protobuf::EnumDescriptor* absl_nonnull type, + int32_t number) { + ABSL_DCHECK(type != nullptr); + if (type->full_name() == "google.protobuf.NullValue") { + ABSL_DCHECK_EQ(number, 0); + return NullValue(); + } + return NonNullEnumValue(type, number); +} + +namespace common_internal { + +namespace { + +void BoolMapFieldKeyAccessor(const google::protobuf::MapKey& key, + const google::protobuf::Message* absl_nonnull message, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) { + ABSL_DCHECK(message != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + *result = BoolValue(key.GetBoolValue()); +} + +void Int32MapFieldKeyAccessor(const google::protobuf::MapKey& key, + const google::protobuf::Message* absl_nonnull message, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) { + ABSL_DCHECK(message != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + *result = IntValue(key.GetInt32Value()); +} + +void Int64MapFieldKeyAccessor(const google::protobuf::MapKey& key, + const google::protobuf::Message* absl_nonnull message, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) { + ABSL_DCHECK(message != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + *result = IntValue(key.GetInt64Value()); +} + +void UInt32MapFieldKeyAccessor(const google::protobuf::MapKey& key, + const google::protobuf::Message* absl_nonnull message, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) { + ABSL_DCHECK(message != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + *result = UintValue(key.GetUInt32Value()); +} + +void UInt64MapFieldKeyAccessor(const google::protobuf::MapKey& key, + const google::protobuf::Message* absl_nonnull message, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) { + ABSL_DCHECK(message != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + *result = UintValue(key.GetUInt64Value()); +} + +void StringMapFieldKeyAccessor(const google::protobuf::MapKey& key, + const google::protobuf::Message* absl_nonnull message, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) { + ABSL_DCHECK(message != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + +#if CEL_INTERNAL_PROTOBUF_OSS_VERSION_PREREQ(5, 30, 0) + *result = StringValue(Borrower::Arena(MessageArenaOr(message, arena)), + key.GetStringValue()); +#else + *result = StringValue(arena, key.GetStringValue()); +#endif +} + +} // namespace + +absl::StatusOr MapFieldKeyAccessorFor( + const google::protobuf::FieldDescriptor* absl_nonnull field) { + switch (field->cpp_type()) { + case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: + return &BoolMapFieldKeyAccessor; + case google::protobuf::FieldDescriptor::CPPTYPE_INT32: + return &Int32MapFieldKeyAccessor; + case google::protobuf::FieldDescriptor::CPPTYPE_INT64: + return &Int64MapFieldKeyAccessor; + case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: + return &UInt32MapFieldKeyAccessor; + case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: + return &UInt64MapFieldKeyAccessor; + case google::protobuf::FieldDescriptor::CPPTYPE_STRING: + return &StringMapFieldKeyAccessor; + default: + return absl::InvalidArgumentError( + absl::StrCat("unexpected map key type: ", field->cpp_type_name())); + } +} + +namespace { + +void DoubleMapFieldValueAccessor( + const google::protobuf::MapValueConstRef& value, + const google::protobuf::Message* absl_nonnull message, + const google::protobuf::FieldDescriptor* absl_nonnull field, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { + ABSL_DCHECK(message != nullptr); + ABSL_DCHECK(field != nullptr); + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + ABSL_DCHECK(!field->is_repeated()); + ABSL_DCHECK_EQ(field->cpp_type(), google::protobuf::FieldDescriptor::CPPTYPE_DOUBLE); + + *result = DoubleValue(value.GetDoubleValue()); +} + +void FloatMapFieldValueAccessor( + const google::protobuf::MapValueConstRef& value, + const google::protobuf::Message* absl_nonnull message, + const google::protobuf::FieldDescriptor* absl_nonnull field, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { + ABSL_DCHECK(message != nullptr); + ABSL_DCHECK(field != nullptr); + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + ABSL_DCHECK(!field->is_repeated()); + ABSL_DCHECK_EQ(field->cpp_type(), google::protobuf::FieldDescriptor::CPPTYPE_FLOAT); + + *result = DoubleValue(value.GetFloatValue()); +} + +void Int64MapFieldValueAccessor( + const google::protobuf::MapValueConstRef& value, + const google::protobuf::Message* absl_nonnull message, + const google::protobuf::FieldDescriptor* absl_nonnull field, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { + ABSL_DCHECK(message != nullptr); + ABSL_DCHECK(field != nullptr); + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + ABSL_DCHECK(!field->is_repeated()); + ABSL_DCHECK_EQ(field->cpp_type(), google::protobuf::FieldDescriptor::CPPTYPE_INT64); + + *result = IntValue(value.GetInt64Value()); +} + +void UInt64MapFieldValueAccessor( + const google::protobuf::MapValueConstRef& value, + const google::protobuf::Message* absl_nonnull message, + const google::protobuf::FieldDescriptor* absl_nonnull field, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { + ABSL_DCHECK(message != nullptr); + ABSL_DCHECK(field != nullptr); + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + ABSL_DCHECK(!field->is_repeated()); + ABSL_DCHECK_EQ(field->cpp_type(), google::protobuf::FieldDescriptor::CPPTYPE_UINT64); + + *result = UintValue(value.GetUInt64Value()); +} + +void Int32MapFieldValueAccessor( + const google::protobuf::MapValueConstRef& value, + const google::protobuf::Message* absl_nonnull message, + const google::protobuf::FieldDescriptor* absl_nonnull field, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { + ABSL_DCHECK(message != nullptr); + ABSL_DCHECK(field != nullptr); + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + ABSL_DCHECK(!field->is_repeated()); + ABSL_DCHECK_EQ(field->cpp_type(), google::protobuf::FieldDescriptor::CPPTYPE_INT32); + + *result = IntValue(value.GetInt32Value()); +} + +void UInt32MapFieldValueAccessor( + const google::protobuf::MapValueConstRef& value, + const google::protobuf::Message* absl_nonnull message, + const google::protobuf::FieldDescriptor* absl_nonnull field, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { + ABSL_DCHECK(message != nullptr); + ABSL_DCHECK(field != nullptr); + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + ABSL_DCHECK(!field->is_repeated()); + ABSL_DCHECK_EQ(field->cpp_type(), google::protobuf::FieldDescriptor::CPPTYPE_UINT32); + + *result = UintValue(value.GetUInt32Value()); +} + +void BoolMapFieldValueAccessor( + const google::protobuf::MapValueConstRef& value, + const google::protobuf::Message* absl_nonnull message, + const google::protobuf::FieldDescriptor* absl_nonnull field, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { + ABSL_DCHECK(message != nullptr); + ABSL_DCHECK(field != nullptr); + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + ABSL_DCHECK(!field->is_repeated()); + ABSL_DCHECK_EQ(field->cpp_type(), google::protobuf::FieldDescriptor::CPPTYPE_BOOL); + + *result = BoolValue(value.GetBoolValue()); +} + +void StringMapFieldValueAccessor( + const google::protobuf::MapValueConstRef& value, + const google::protobuf::Message* absl_nonnull message, + const google::protobuf::FieldDescriptor* absl_nonnull field, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { + ABSL_DCHECK(message != nullptr); + ABSL_DCHECK(field != nullptr); + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + ABSL_DCHECK(!field->is_repeated()); + ABSL_DCHECK_EQ(field->type(), google::protobuf::FieldDescriptor::TYPE_STRING); -namespace google { -namespace api { -namespace expr { -namespace common { + if (message->GetArena() == nullptr) { + *result = StringValue(arena, value.GetStringValue()); + } else { + *result = StringValue(Borrower::Arena(arena), value.GetStringValue()); + } +} + +void MessageMapFieldValueAccessor( + const google::protobuf::MapValueConstRef& value, + const google::protobuf::Message* absl_nonnull message, + const google::protobuf::FieldDescriptor* absl_nonnull field, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { + ABSL_DCHECK(message != nullptr); + ABSL_DCHECK(field != nullptr); + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + ABSL_DCHECK(!field->is_repeated()); + ABSL_DCHECK_EQ(field->cpp_type(), google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE); + + *result = Value::WrapMessage(&value.GetMessageValue(), descriptor_pool, + message_factory, arena); +} + +void BytesMapFieldValueAccessor( + const google::protobuf::MapValueConstRef& value, + const google::protobuf::Message* absl_nonnull message, + const google::protobuf::FieldDescriptor* absl_nonnull field, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { + ABSL_DCHECK(message != nullptr); + ABSL_DCHECK(field != nullptr); + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + ABSL_DCHECK(!field->is_repeated()); + ABSL_DCHECK_EQ(field->type(), google::protobuf::FieldDescriptor::TYPE_BYTES); + + if (message->GetArena() == nullptr) { + *result = BytesValue(arena, value.GetStringValue()); + } else { + *result = BytesValue(Borrower::Arena(arena), value.GetStringValue()); + } +} + +void EnumMapFieldValueAccessor( + const google::protobuf::MapValueConstRef& value, + const google::protobuf::Message* absl_nonnull message, + const google::protobuf::FieldDescriptor* absl_nonnull field, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { + ABSL_DCHECK(message != nullptr); + ABSL_DCHECK(field != nullptr); + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + ABSL_DCHECK(!field->is_repeated()); + ABSL_DCHECK_EQ(field->cpp_type(), google::protobuf::FieldDescriptor::CPPTYPE_ENUM); + + *result = NonNullEnumValue(field->enum_type(), value.GetEnumValue()); +} + +void NullMapFieldValueAccessor( + const google::protobuf::MapValueConstRef& value, + const google::protobuf::Message* absl_nonnull message, + const google::protobuf::FieldDescriptor* absl_nonnull field, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { + ABSL_DCHECK(message != nullptr); + ABSL_DCHECK(field != nullptr); + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + ABSL_DCHECK(!field->is_repeated()); + ABSL_DCHECK(field->cpp_type() == google::protobuf::FieldDescriptor::CPPTYPE_ENUM && + field->enum_type()->full_name() == "google.protobuf.NullValue"); + + *result = NullValue(); +} + +} // namespace + +absl::StatusOr MapFieldValueAccessorFor( + const google::protobuf::FieldDescriptor* absl_nonnull field) { + switch (field->type()) { + case google::protobuf::FieldDescriptor::TYPE_DOUBLE: + return &DoubleMapFieldValueAccessor; + case google::protobuf::FieldDescriptor::TYPE_FLOAT: + return &FloatMapFieldValueAccessor; + case google::protobuf::FieldDescriptor::TYPE_SFIXED64: + ABSL_FALLTHROUGH_INTENDED; + case google::protobuf::FieldDescriptor::TYPE_SINT64: + ABSL_FALLTHROUGH_INTENDED; + case google::protobuf::FieldDescriptor::TYPE_INT64: + return &Int64MapFieldValueAccessor; + case google::protobuf::FieldDescriptor::TYPE_FIXED64: + ABSL_FALLTHROUGH_INTENDED; + case google::protobuf::FieldDescriptor::TYPE_UINT64: + return &UInt64MapFieldValueAccessor; + case google::protobuf::FieldDescriptor::TYPE_SFIXED32: + ABSL_FALLTHROUGH_INTENDED; + case google::protobuf::FieldDescriptor::TYPE_SINT32: + ABSL_FALLTHROUGH_INTENDED; + case google::protobuf::FieldDescriptor::TYPE_INT32: + return &Int32MapFieldValueAccessor; + case google::protobuf::FieldDescriptor::TYPE_BOOL: + return &BoolMapFieldValueAccessor; + case google::protobuf::FieldDescriptor::TYPE_STRING: + return &StringMapFieldValueAccessor; + case google::protobuf::FieldDescriptor::TYPE_GROUP: + ABSL_FALLTHROUGH_INTENDED; + case google::protobuf::FieldDescriptor::TYPE_MESSAGE: + return &MessageMapFieldValueAccessor; + case google::protobuf::FieldDescriptor::TYPE_BYTES: + return &BytesMapFieldValueAccessor; + case google::protobuf::FieldDescriptor::TYPE_FIXED32: + ABSL_FALLTHROUGH_INTENDED; + case google::protobuf::FieldDescriptor::TYPE_UINT32: + return &UInt32MapFieldValueAccessor; + case google::protobuf::FieldDescriptor::TYPE_ENUM: + if (field->enum_type()->full_name() == "google.protobuf.NullValue") { + return &NullMapFieldValueAccessor; + } + return &EnumMapFieldValueAccessor; + default: + return absl::InvalidArgumentError( + absl::StrCat("unexpected protocol buffer message field type: ", + field->type_name())); + } +} + +namespace { + +void DoubleRepeatedFieldAccessor( + int index, const google::protobuf::Message* absl_nonnull message, + const google::protobuf::FieldDescriptor* absl_nonnull field, + const google::protobuf::Reflection* absl_nonnull reflection, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { + ABSL_DCHECK(message != nullptr); + ABSL_DCHECK(field != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(reflection != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + ABSL_DCHECK_EQ(reflection, message->GetReflection()); + ABSL_DCHECK_EQ(field->containing_type(), message->GetDescriptor()); + ABSL_DCHECK(field->is_repeated()); + ABSL_DCHECK_EQ(field->cpp_type(), google::protobuf::FieldDescriptor::CPPTYPE_DOUBLE); + ABSL_DCHECK_GE(index, 0); + ABSL_DCHECK_LT(index, reflection->FieldSize(*message, field)); + + *result = DoubleValue(reflection->GetRepeatedDouble(*message, field, index)); +} + +void FloatRepeatedFieldAccessor( + int index, const google::protobuf::Message* absl_nonnull message, + const google::protobuf::FieldDescriptor* absl_nonnull field, + const google::protobuf::Reflection* absl_nonnull reflection, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { + ABSL_DCHECK(message != nullptr); + ABSL_DCHECK(field != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(reflection != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + ABSL_DCHECK_EQ(reflection, message->GetReflection()); + ABSL_DCHECK_EQ(field->containing_type(), message->GetDescriptor()); + ABSL_DCHECK(field->is_repeated()); + ABSL_DCHECK_EQ(field->cpp_type(), google::protobuf::FieldDescriptor::CPPTYPE_FLOAT); + ABSL_DCHECK_GE(index, 0); + ABSL_DCHECK_LT(index, reflection->FieldSize(*message, field)); + + *result = DoubleValue(reflection->GetRepeatedFloat(*message, field, index)); +} + +void Int64RepeatedFieldAccessor( + int index, const google::protobuf::Message* absl_nonnull message, + const google::protobuf::FieldDescriptor* absl_nonnull field, + const google::protobuf::Reflection* absl_nonnull reflection, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { + ABSL_DCHECK(message != nullptr); + ABSL_DCHECK(field != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(reflection != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + ABSL_DCHECK_EQ(reflection, message->GetReflection()); + ABSL_DCHECK_EQ(field->containing_type(), message->GetDescriptor()); + ABSL_DCHECK(field->is_repeated()); + ABSL_DCHECK_EQ(field->cpp_type(), google::protobuf::FieldDescriptor::CPPTYPE_INT64); + ABSL_DCHECK_GE(index, 0); + ABSL_DCHECK_LT(index, reflection->FieldSize(*message, field)); + + *result = IntValue(reflection->GetRepeatedInt64(*message, field, index)); +} + +void UInt64RepeatedFieldAccessor( + int index, const google::protobuf::Message* absl_nonnull message, + const google::protobuf::FieldDescriptor* absl_nonnull field, + const google::protobuf::Reflection* absl_nonnull reflection, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { + ABSL_DCHECK(message != nullptr); + ABSL_DCHECK(field != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(reflection != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + ABSL_DCHECK_EQ(reflection, message->GetReflection()); + ABSL_DCHECK_EQ(field->containing_type(), message->GetDescriptor()); + ABSL_DCHECK(field->is_repeated()); + ABSL_DCHECK_EQ(field->cpp_type(), google::protobuf::FieldDescriptor::CPPTYPE_UINT64); + ABSL_DCHECK_GE(index, 0); + ABSL_DCHECK_LT(index, reflection->FieldSize(*message, field)); + + *result = UintValue(reflection->GetRepeatedUInt64(*message, field, index)); +} + +void Int32RepeatedFieldAccessor( + int index, const google::protobuf::Message* absl_nonnull message, + const google::protobuf::FieldDescriptor* absl_nonnull field, + const google::protobuf::Reflection* absl_nonnull reflection, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { + ABSL_DCHECK(message != nullptr); + ABSL_DCHECK(field != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(reflection != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + ABSL_DCHECK_EQ(reflection, message->GetReflection()); + ABSL_DCHECK_EQ(field->containing_type(), message->GetDescriptor()); + ABSL_DCHECK(field->is_repeated()); + ABSL_DCHECK_EQ(field->cpp_type(), google::protobuf::FieldDescriptor::CPPTYPE_INT32); + ABSL_DCHECK_GE(index, 0); + ABSL_DCHECK_LT(index, reflection->FieldSize(*message, field)); + + *result = IntValue(reflection->GetRepeatedInt32(*message, field, index)); +} + +void UInt32RepeatedFieldAccessor( + int index, const google::protobuf::Message* absl_nonnull message, + const google::protobuf::FieldDescriptor* absl_nonnull field, + const google::protobuf::Reflection* absl_nonnull reflection, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { + ABSL_DCHECK(message != nullptr); + ABSL_DCHECK(field != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(reflection != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + ABSL_DCHECK_EQ(reflection, message->GetReflection()); + ABSL_DCHECK_EQ(field->containing_type(), message->GetDescriptor()); + ABSL_DCHECK(field->is_repeated()); + ABSL_DCHECK_EQ(field->cpp_type(), google::protobuf::FieldDescriptor::CPPTYPE_UINT32); + ABSL_DCHECK_GE(index, 0); + ABSL_DCHECK_LT(index, reflection->FieldSize(*message, field)); + + *result = UintValue(reflection->GetRepeatedUInt32(*message, field, index)); +} + +void BoolRepeatedFieldAccessor( + int index, const google::protobuf::Message* absl_nonnull message, + const google::protobuf::FieldDescriptor* absl_nonnull field, + const google::protobuf::Reflection* absl_nonnull reflection, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { + ABSL_DCHECK(message != nullptr); + ABSL_DCHECK(field != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(reflection != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + ABSL_DCHECK_EQ(reflection, message->GetReflection()); + ABSL_DCHECK_EQ(field->containing_type(), message->GetDescriptor()); + ABSL_DCHECK(field->is_repeated()); + ABSL_DCHECK_EQ(field->cpp_type(), google::protobuf::FieldDescriptor::CPPTYPE_BOOL); + ABSL_DCHECK_GE(index, 0); + ABSL_DCHECK_LT(index, reflection->FieldSize(*message, field)); + + *result = BoolValue(reflection->GetRepeatedBool(*message, field, index)); +} + +void StringRepeatedFieldAccessor( + int index, const google::protobuf::Message* absl_nonnull message, + const google::protobuf::FieldDescriptor* absl_nonnull field, + const google::protobuf::Reflection* absl_nonnull reflection, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { + ABSL_DCHECK(message != nullptr); + ABSL_DCHECK(field != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(reflection != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + ABSL_DCHECK_EQ(reflection, message->GetReflection()); + ABSL_DCHECK_EQ(field->containing_type(), message->GetDescriptor()); + ABSL_DCHECK(field->is_repeated()); + ABSL_DCHECK_EQ(field->type(), google::protobuf::FieldDescriptor::TYPE_STRING); + ABSL_DCHECK_GE(index, 0); + ABSL_DCHECK_LT(index, reflection->FieldSize(*message, field)); + + std::string scratch; + absl::visit( + absl::Overload( + [&](absl::string_view string) { + if (string.data() == scratch.data() && + string.size() == scratch.size()) { + *result = StringValue(arena, std::move(scratch)); + } else { + if (message->GetArena() == nullptr) { + *result = StringValue(arena, string); + } else { + *result = StringValue(Borrower::Arena(arena), string); + } + } + }, + [&](absl::Cord&& cord) { *result = StringValue(std::move(cord)); }), + well_known_types::AsVariant(well_known_types::GetRepeatedStringField( + *message, field, index, scratch))); +} + +void MessageRepeatedFieldAccessor( + int index, const google::protobuf::Message* absl_nonnull message, + const google::protobuf::FieldDescriptor* absl_nonnull field, + const google::protobuf::Reflection* absl_nonnull reflection, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { + ABSL_DCHECK(message != nullptr); + ABSL_DCHECK(field != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(reflection != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + ABSL_DCHECK_EQ(reflection, message->GetReflection()); + ABSL_DCHECK_EQ(field->containing_type(), message->GetDescriptor()); + ABSL_DCHECK(field->is_repeated()); + ABSL_DCHECK_EQ(field->cpp_type(), google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE); + ABSL_DCHECK_GE(index, 0); + ABSL_DCHECK_LT(index, reflection->FieldSize(*message, field)); + + *result = Value::WrapMessage( + &reflection->GetRepeatedMessage(*message, field, index), descriptor_pool, + message_factory, arena); +} + +void BytesRepeatedFieldAccessor( + int index, const google::protobuf::Message* absl_nonnull message, + const google::protobuf::FieldDescriptor* absl_nonnull field, + const google::protobuf::Reflection* absl_nonnull reflection, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { + ABSL_DCHECK(message != nullptr); + ABSL_DCHECK(field != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(reflection != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + ABSL_DCHECK_EQ(reflection, message->GetReflection()); + ABSL_DCHECK_EQ(field->containing_type(), message->GetDescriptor()); + ABSL_DCHECK(field->is_repeated()); + ABSL_DCHECK_EQ(field->type(), google::protobuf::FieldDescriptor::TYPE_BYTES); + ABSL_DCHECK_GE(index, 0); + ABSL_DCHECK_LT(index, reflection->FieldSize(*message, field)); + + std::string scratch; + absl::visit( + absl::Overload( + [&](absl::string_view string) { + if (string.data() == scratch.data() && + string.size() == scratch.size()) { + *result = BytesValue(arena, std::move(scratch)); + } else { + if (message->GetArena() == nullptr) { + *result = BytesValue(arena, string); + } else { + *result = BytesValue(Borrower::Arena(arena), string); + } + } + }, + [&](absl::Cord&& cord) { *result = BytesValue(std::move(cord)); }), + well_known_types::AsVariant(well_known_types::GetRepeatedBytesField( + *message, field, index, scratch))); +} + +void EnumRepeatedFieldAccessor( + int index, const google::protobuf::Message* absl_nonnull message, + const google::protobuf::FieldDescriptor* absl_nonnull field, + const google::protobuf::Reflection* absl_nonnull reflection, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { + ABSL_DCHECK(message != nullptr); + ABSL_DCHECK(field != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(reflection != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + ABSL_DCHECK_EQ(reflection, message->GetReflection()); + ABSL_DCHECK_EQ(field->containing_type(), message->GetDescriptor()); + ABSL_DCHECK(field->is_repeated()); + ABSL_DCHECK_EQ(field->cpp_type(), google::protobuf::FieldDescriptor::CPPTYPE_ENUM); + ABSL_DCHECK_GE(index, 0); + ABSL_DCHECK_LT(index, reflection->FieldSize(*message, field)); + + *result = NonNullEnumValue( + field->enum_type(), + reflection->GetRepeatedEnumValue(*message, field, index)); +} + +void NullRepeatedFieldAccessor( + int index, const google::protobuf::Message* absl_nonnull message, + const google::protobuf::FieldDescriptor* absl_nonnull field, + const google::protobuf::Reflection* absl_nonnull reflection, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { + ABSL_DCHECK(message != nullptr); + ABSL_DCHECK(field != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(reflection != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + ABSL_DCHECK_EQ(reflection, message->GetReflection()); + ABSL_DCHECK_EQ(field->containing_type(), message->GetDescriptor()); + ABSL_DCHECK(field->is_repeated()); + ABSL_DCHECK(field->cpp_type() == google::protobuf::FieldDescriptor::CPPTYPE_ENUM && + field->enum_type()->full_name() == "google.protobuf.NullValue"); + ABSL_DCHECK_GE(index, 0); + ABSL_DCHECK_LT(index, reflection->FieldSize(*message, field)); + + *result = NullValue(); +} + +} // namespace + +absl::StatusOr RepeatedFieldAccessorFor( + const google::protobuf::FieldDescriptor* absl_nonnull field) { + switch (field->type()) { + case google::protobuf::FieldDescriptor::TYPE_DOUBLE: + return &DoubleRepeatedFieldAccessor; + case google::protobuf::FieldDescriptor::TYPE_FLOAT: + return &FloatRepeatedFieldAccessor; + case google::protobuf::FieldDescriptor::TYPE_SFIXED64: + ABSL_FALLTHROUGH_INTENDED; + case google::protobuf::FieldDescriptor::TYPE_SINT64: + ABSL_FALLTHROUGH_INTENDED; + case google::protobuf::FieldDescriptor::TYPE_INT64: + return &Int64RepeatedFieldAccessor; + case google::protobuf::FieldDescriptor::TYPE_FIXED64: + ABSL_FALLTHROUGH_INTENDED; + case google::protobuf::FieldDescriptor::TYPE_UINT64: + return &UInt64RepeatedFieldAccessor; + case google::protobuf::FieldDescriptor::TYPE_SFIXED32: + ABSL_FALLTHROUGH_INTENDED; + case google::protobuf::FieldDescriptor::TYPE_SINT32: + ABSL_FALLTHROUGH_INTENDED; + case google::protobuf::FieldDescriptor::TYPE_INT32: + return &Int32RepeatedFieldAccessor; + case google::protobuf::FieldDescriptor::TYPE_BOOL: + return &BoolRepeatedFieldAccessor; + case google::protobuf::FieldDescriptor::TYPE_STRING: + return &StringRepeatedFieldAccessor; + case google::protobuf::FieldDescriptor::TYPE_GROUP: + ABSL_FALLTHROUGH_INTENDED; + case google::protobuf::FieldDescriptor::TYPE_MESSAGE: + return &MessageRepeatedFieldAccessor; + case google::protobuf::FieldDescriptor::TYPE_BYTES: + return &BytesRepeatedFieldAccessor; + case google::protobuf::FieldDescriptor::TYPE_FIXED32: + ABSL_FALLTHROUGH_INTENDED; + case google::protobuf::FieldDescriptor::TYPE_UINT32: + return &UInt32RepeatedFieldAccessor; + case google::protobuf::FieldDescriptor::TYPE_ENUM: + if (field->enum_type()->full_name() == "google.protobuf.NullValue") { + return &NullRepeatedFieldAccessor; + } + return &EnumRepeatedFieldAccessor; + default: + return absl::InvalidArgumentError( + absl::StrCat("unexpected protocol buffer message field type: ", + field->type_name())); + } +} -using ::google::api::expr::internal::NotFoundError; +} // namespace common_internal namespace { -static constexpr const Value::Kind kIndexToKind[] = { - Value::Kind::kNull, // null value - Value::Kind::kBool, // bool value - Value::Kind::kInt, // int value - Value::Kind::kUInt, // uint value - Value::Kind::kDouble, // double value - Value::Kind::kEnum, // enum_value - Value::Kind::kType, // type value - Value::Kind::kType, // type value - Value::Kind::kType, // type value - Value::Kind::kUnknown, // unknown - Value::Kind::kString, // string value - Value::Kind::kString, // string value - Value::Kind::kString, // string value - Value::Kind::kBytes, // bytes value - Value::Kind::kBytes, // bytes value - Value::Kind::kBytes, // bytes value - Value::Kind::kMap, // map value - Value::Kind::kMap, // map value - Value::Kind::kList, // list value - Value::Kind::kList, // list value - Value::Kind::kObject, // object value - Value::Kind::kObject, // object value - Value::Kind::kDuration, // duration - Value::Kind::kTime, // time - Value::Kind::kEnum, // enum_value - Value::Kind::kType, // type value - Value::Kind::kError, // error - Value::Kind::kUnknown, // unknown +// Overloads for `well_known_types::Value` which handles the primitive values +// which require no special handling based on allocators. +Value VistWellKnownTypeValue(std::nullptr_t) { return NullValue(); } + +Value VistWellKnownTypeValue(bool value) { return BoolValue(value); } + +Value VistWellKnownTypeValue(int32_t value) { return IntValue(value); } + +Value VistWellKnownTypeValue(int64_t value) { return IntValue(value); } + +Value VistWellKnownTypeValue(uint32_t value) { return UintValue(value); } + +Value VistWellKnownTypeValue(uint64_t value) { return UintValue(value); } + +Value VistWellKnownTypeValue(float value) { return DoubleValue(value); } + +Value VistWellKnownTypeValue(double value) { return DoubleValue(value); } + +Value VistWellKnownTypeValue(absl::Duration value) { + return DurationValue(value); +} + +Value VistWellKnownTypeValue(absl::Time value) { return TimestampValue(value); } + +struct OwningWellKnownTypesValueVisitor { + google::protobuf::Arena* absl_nullable arena; + std::string* absl_nonnull scratch; + + Value operator()(well_known_types::BytesValue&& value) const { + return absl::visit(absl::Overload( + [&](absl::string_view string) -> BytesValue { + if (string.empty()) { + return BytesValue(); + } + if (scratch->data() == string.data() && + scratch->size() == string.size()) { + return BytesValue(arena, std::move(*scratch)); + } + return BytesValue(arena, string); + }, + [&](absl::Cord&& cord) -> BytesValue { + if (cord.empty()) { + return BytesValue(); + } + return BytesValue(arena, cord); + }), + well_known_types::AsVariant(std::move(value))); + } + + Value operator()(well_known_types::StringValue&& value) const { + return absl::visit(absl::Overload( + [&](absl::string_view string) -> StringValue { + if (string.empty()) { + return StringValue(); + } + if (scratch->data() == string.data() && + scratch->size() == string.size()) { + return StringValue(arena, std::move(*scratch)); + } + return StringValue(arena, string); + }, + [&](absl::Cord&& cord) -> StringValue { + if (cord.empty()) { + return StringValue(); + } + return StringValue(arena, cord); + }), + well_known_types::AsVariant(std::move(value))); + } + + Value operator()(well_known_types::ListValue&& value) const { + return absl::visit( + absl::Overload( + [&](well_known_types::ListValueConstRef value) -> ListValue { + auto* cloned = value.get().New(arena); + cloned->CopyFrom(value.get()); + return ParsedJsonListValue(cloned, arena); + }, + [&](well_known_types::ListValuePtr value) -> ListValue { + if (value->GetArena() != arena) { + auto* cloned = value->New(arena); + cloned->CopyFrom(*value); + return ParsedJsonListValue(cloned, arena); + } + return ParsedJsonListValue(value.release(), arena); + }), + well_known_types::AsVariant(std::move(value))); + } + + Value operator()(well_known_types::Struct&& value) const { + return absl::visit( + absl::Overload( + [&](well_known_types::StructConstRef value) -> MapValue { + auto* cloned = value.get().New(arena); + cloned->CopyFrom(value.get()); + return ParsedJsonMapValue(cloned, arena); + }, + [&](well_known_types::StructPtr value) -> MapValue { + if (value.arena() != arena) { + auto* cloned = value->New(arena); + cloned->CopyFrom(*value); + return ParsedJsonMapValue(cloned, arena); + } + return ParsedJsonMapValue(value.release(), arena); + }), + well_known_types::AsVariant(std::move(value))); + } + + Value operator()(Unique value) const { + if (value->GetArena() != arena) { + auto* cloned = value->New(arena); + cloned->CopyFrom(*value); + return ParsedMessageValue(cloned, arena); + } + return ParsedMessageValue(value.release(), arena); + } + + template + Value operator()(T t) const { + return VistWellKnownTypeValue(t); + } +}; + +struct BorrowingWellKnownTypesValueVisitor { + const google::protobuf::Message* absl_nonnull message; + google::protobuf::Arena* absl_nonnull arena; + std::string* absl_nonnull scratch; + + Value operator()(well_known_types::BytesValue&& value) const { + return absl::visit( + absl::Overload( + [&](absl::string_view string) -> BytesValue { + if (string.data() == scratch->data() && + string.size() == scratch->size()) { + return BytesValue(arena, std::move(*scratch)); + } else { + return BytesValue( + Borrower::Arena(MessageArenaOr(message, arena)), string); + } + }, + [&](absl::Cord&& cord) -> BytesValue { + return BytesValue(std::move(cord)); + }), + well_known_types::AsVariant(std::move(value))); + } + + Value operator()(well_known_types::StringValue&& value) const { + return absl::visit( + absl::Overload( + [&](absl::string_view string) -> StringValue { + if (string.data() == scratch->data() && + string.size() == scratch->size()) { + return StringValue(arena, std::move(*scratch)); + } else { + return StringValue( + Borrower::Arena(MessageArenaOr(message, arena)), string); + } + }, + [&](absl::Cord&& cord) -> StringValue { + return StringValue(std::move(cord)); + }), + well_known_types::AsVariant(std::move(value))); + } + + Value operator()(well_known_types::ListValue&& value) const { + return absl::visit( + absl::Overload( + [&](well_known_types::ListValueConstRef value) + -> ParsedJsonListValue { + return ParsedJsonListValue(&value.get(), + MessageArenaOr(&value.get(), arena)); + }, + [&](well_known_types::ListValuePtr value) -> ParsedJsonListValue { + if (value->GetArena() != arena) { + auto* cloned = value->New(arena); + cloned->CopyFrom(*value); + return ParsedJsonListValue(cloned, arena); + } + return ParsedJsonListValue(value.release(), arena); + }), + well_known_types::AsVariant(std::move(value))); + } + + Value operator()(well_known_types::Struct&& value) const { + return absl::visit( + absl::Overload( + [&](well_known_types::StructConstRef value) -> ParsedJsonMapValue { + return ParsedJsonMapValue(&value.get(), + MessageArenaOr(&value.get(), arena)); + }, + [&](well_known_types::StructPtr value) -> ParsedJsonMapValue { + if (value->GetArena() != arena) { + auto* cloned = value->New(arena); + cloned->CopyFrom(*value); + return ParsedJsonMapValue(cloned, arena); + } + return ParsedJsonMapValue(value.release(), arena); + }), + well_known_types::AsVariant(std::move(value))); + } + + Value operator()(Unique&& value) const { + if (value->GetArena() != arena) { + auto* cloned = value->New(arena); + cloned->CopyFrom(*value); + return ParsedMessageValue(cloned, arena); + } + return ParsedMessageValue(value.release(), arena); + } + + template + Value operator()(T t) const { + return VistWellKnownTypeValue(t); + } }; } // namespace -Value::~Value() { - // All the enums should line up. - static_assert(28 == absl::variant_size::value, "size mismatch"); - static_assert(DATA_SIZE == absl::variant_size::value, - "size mismatch"); - static_assert( - ABSL_ARRAYSIZE(kIndexToKind) == absl::variant_size::value, - "size mismatch"); +Value Value::FromMessage( + const google::protobuf::Message& message, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool + ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::MessageFactory* absl_nonnull message_factory + ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND) { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); - static_assert(static_cast(Kind::DO_NOT_USE) == - internal::types_size::value, - "size mismatch"); + std::string scratch; + auto status_or_adapted = well_known_types::AdaptFromMessage( + arena, message, descriptor_pool, message_factory, scratch); + if (ABSL_PREDICT_FALSE(!status_or_adapted.ok())) { + return ErrorValue(std::move(status_or_adapted).status()); + } + return absl::visit( + absl::Overload(OwningWellKnownTypesValueVisitor{ + /* .arena = */ arena, /* .scratch = */ &scratch}, + [&](std::monostate) -> Value { + auto* cloned = message.New(arena); + cloned->CopyFrom(message); + return ParsedMessageValue(cloned, arena); + }), + std::move(status_or_adapted).value()); +} - // Value size should not exceed the size of a basic variant type. - static_assert(sizeof(Value) <= sizeof(absl::variant), - "Value oversized"); +Value Value::FromMessage( + google::protobuf::Message&& message, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool + ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::MessageFactory* absl_nonnull message_factory + ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND) { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + + std::string scratch; + auto status_or_adapted = well_known_types::AdaptFromMessage( + arena, message, descriptor_pool, message_factory, scratch); + if (ABSL_PREDICT_FALSE(!status_or_adapted.ok())) { + return ErrorValue(std::move(status_or_adapted).status()); + } + return absl::visit( + absl::Overload(OwningWellKnownTypesValueVisitor{ + /* .arena = */ arena, /* .scratch = */ &scratch}, + [&](std::monostate) -> Value { + auto* cloned = message.New(arena); + cloned->GetReflection()->Swap(cloned, &message); + return ParsedMessageValue(cloned, arena); + }), + std::move(status_or_adapted).value()); } -Value::Kind Value::kind() const { return kIndexToKind[data_.index()]; } +Value Value::WrapMessage( + const google::protobuf::Message* absl_nonnull message ABSL_ATTRIBUTE_LIFETIME_BOUND, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool + ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::MessageFactory* absl_nonnull message_factory + ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND) { + ABSL_DCHECK(message != nullptr); + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); -bool Value::owns_value() const { - if (is_inline() || data_.index() >= kOptionalOwnershipEnd) { - return true; + std::string scratch; + absl::StatusOr adapted_value = + well_known_types::AdaptFromMessage(arena, *message, descriptor_pool, + message_factory, scratch); + if (ABSL_PREDICT_FALSE(!adapted_value.ok())) { + return ErrorValue(std::move(adapted_value).status()); + } + return absl::visit( + absl::Overload(BorrowingWellKnownTypesValueVisitor{ + /* .message = */ message, /* .arena = */ arena, + /* .scratch = */ &scratch}, + [&](std::monostate) -> Value { + if (message->GetArena() != arena) { + auto* cloned = message->New(arena); + cloned->CopyFrom(*message); + return ParsedMessageValue(cloned, arena); + } + return ParsedMessageValue(message, arena); + }), + std::move(adapted_value).value()); +} + +Value Value::WrapMessageUnsafe( + const google::protobuf::Message* absl_nonnull message ABSL_ATTRIBUTE_LIFETIME_BOUND, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool + ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::MessageFactory* absl_nonnull message_factory + ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND) { + ABSL_DCHECK(message != nullptr); + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + + std::string scratch; + absl::StatusOr adapted_value = + well_known_types::AdaptFromMessage(arena, *message, descriptor_pool, + message_factory, scratch); + if (ABSL_PREDICT_FALSE(!adapted_value.ok())) { + return ErrorValue(std::move(adapted_value).status()); } + return absl::visit( + absl::Overload(BorrowingWellKnownTypesValueVisitor{ + /* .message = */ message, /* .arena = */ arena, + /* .scratch = */ &scratch}, + [&](std::monostate) -> Value { + if (message->GetArena() != arena) { + return UnsafeParsedMessageValue(message); + } + return ParsedMessageValue(message, arena); + }), + std::move(adapted_value).value()); +} + +namespace { - switch (data_.index()) { - case kBytes: - case kStr: +bool IsWellKnownMessageWrapperType( + const google::protobuf::Descriptor* absl_nonnull descriptor) { + switch (descriptor->well_known_type()) { + case google::protobuf::Descriptor::WELLKNOWNTYPE_BOOLVALUE: + ABSL_FALLTHROUGH_INTENDED; + case google::protobuf::Descriptor::WELLKNOWNTYPE_INT32VALUE: + ABSL_FALLTHROUGH_INTENDED; + case google::protobuf::Descriptor::WELLKNOWNTYPE_INT64VALUE: + ABSL_FALLTHROUGH_INTENDED; + case google::protobuf::Descriptor::WELLKNOWNTYPE_UINT32VALUE: + ABSL_FALLTHROUGH_INTENDED; + case google::protobuf::Descriptor::WELLKNOWNTYPE_UINT64VALUE: + ABSL_FALLTHROUGH_INTENDED; + case google::protobuf::Descriptor::WELLKNOWNTYPE_FLOATVALUE: + ABSL_FALLTHROUGH_INTENDED; + case google::protobuf::Descriptor::WELLKNOWNTYPE_DOUBLEVALUE: + ABSL_FALLTHROUGH_INTENDED; + case google::protobuf::Descriptor::WELLKNOWNTYPE_BYTESVALUE: + ABSL_FALLTHROUGH_INTENDED; + case google::protobuf::Descriptor::WELLKNOWNTYPE_STRINGVALUE: return true; - case kList: - return absl::get(data_)->owns_value(); - case kMap: - return absl::get(data_)->owns_value(); - case kObject: - return absl::get(data_)->owns_value(); default: return false; } } -Value Value::GetType() const { - switch (kind()) { - // Basic values - case Kind::kNull: - return Value::FromType(BasicTypeValue::kNull); - case Kind::kBool: - return Value::FromType(BasicTypeValue::kBool); - case Kind::kInt: - return Value::FromType(BasicTypeValue::kInt); - case Kind::kUInt: - return Value::FromType(BasicTypeValue::kUint); - case Kind::kDouble: - return Value::FromType(BasicTypeValue::kDouble); - case Kind::kString: - return Value::FromType(BasicTypeValue::kString); - case Kind::kBytes: - return Value::FromType(BasicTypeValue::kBytes); - case Kind::kType: - return Value::FromType(BasicTypeValue::kType); - case Kind::kMap: - return Value::FromType(BasicTypeValue::kMap); - case Kind::kList: - return Value::FromType(BasicTypeValue::kList); +template +Value WrapFieldImpl( + ProtoWrapperTypeOptions wrapper_type_options, + const google::protobuf::Message* absl_nonnull message ABSL_ATTRIBUTE_LIFETIME_BOUND, + const google::protobuf::FieldDescriptor* absl_nonnull field + ABSL_ATTRIBUTE_LIFETIME_BOUND, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool + ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::MessageFactory* absl_nonnull message_factory + ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND) { + ABSL_DCHECK(field != nullptr); + ABSL_DCHECK_EQ(message->GetDescriptor(), field->containing_type()); + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(!IsWellKnownMessageType(message->GetDescriptor())); + + const auto* reflection = message->GetReflection(); + if (field->is_map()) { + if (reflection->FieldSize(*message, field) == 0) { + return MapValue(); + } + if constexpr (Unsafe::value) { + return UnsafeParsedMapFieldValue(message, field); + } else { + return ParsedMapFieldValue(message, field, + MessageArenaOr(message, arena)); + } + } + if (field->is_repeated()) { + if (reflection->FieldSize(*message, field) == 0) { + return ListValue(); + } + if constexpr (Unsafe::value) { + return UnsafeParsedRepeatedFieldValue(message, field); + } else { + return ParsedRepeatedFieldValue(message, field, + MessageArenaOr(message, arena)); + } + } + switch (field->type()) { + case google::protobuf::FieldDescriptor::TYPE_DOUBLE: + return DoubleValue(reflection->GetDouble(*message, field)); + case google::protobuf::FieldDescriptor::TYPE_FLOAT: + return DoubleValue(reflection->GetFloat(*message, field)); + case google::protobuf::FieldDescriptor::TYPE_INT64: + return IntValue(reflection->GetInt64(*message, field)); + case google::protobuf::FieldDescriptor::TYPE_UINT64: + return UintValue(reflection->GetUInt64(*message, field)); + case google::protobuf::FieldDescriptor::TYPE_INT32: + return IntValue(reflection->GetInt32(*message, field)); + case google::protobuf::FieldDescriptor::TYPE_FIXED64: + return UintValue(reflection->GetUInt64(*message, field)); + case google::protobuf::FieldDescriptor::TYPE_FIXED32: + return UintValue(reflection->GetUInt32(*message, field)); + case google::protobuf::FieldDescriptor::TYPE_BOOL: + return BoolValue(reflection->GetBool(*message, field)); + case google::protobuf::FieldDescriptor::TYPE_STRING: { + std::string scratch; + return absl::visit( + absl::Overload( + [&](absl::string_view string) -> StringValue { + if (string.data() == scratch.data() && + string.size() == scratch.size()) { + return StringValue(arena, std::move(scratch)); + } + if constexpr (Unsafe::value) { + return StringValue::WrapUnsafe(string); + } else { + return StringValue( + Borrower::Arena(MessageArenaOr(message, arena)), string); + } + }, + [&](absl::Cord&& cord) -> StringValue { + return StringValue(std::move(cord)); + }), + well_known_types::AsVariant( + well_known_types::GetStringField(*message, field, scratch))); + } + case google::protobuf::FieldDescriptor::TYPE_GROUP: + ABSL_FALLTHROUGH_INTENDED; + case google::protobuf::FieldDescriptor::TYPE_MESSAGE: + if (wrapper_type_options == ProtoWrapperTypeOptions::kUnsetNull && + IsWellKnownMessageWrapperType(field->message_type()) && + !reflection->HasField(*message, field)) { + return NullValue(); + } + if constexpr (Unsafe::value) { + return Value::WrapMessageUnsafe( + &reflection->GetMessage(*message, field), descriptor_pool, + message_factory, arena); + } else { + return Value::WrapMessage(&reflection->GetMessage(*message, field), + descriptor_pool, message_factory, arena); + } + case google::protobuf::FieldDescriptor::TYPE_BYTES: { + std::string scratch; + return absl::visit( + absl::Overload( + [&](absl::string_view string) -> BytesValue { + if (string.data() == scratch.data() && + string.size() == scratch.size()) { + return BytesValue(arena, std::move(scratch)); + } + if constexpr (Unsafe::value) { + return BytesValue::WrapUnsafe(string); + } else { + return BytesValue( + Borrower::Arena(MessageArenaOr(message, arena)), string); + } + }, + [&](absl::Cord&& cord) -> BytesValue { + return BytesValue(std::move(cord)); + }), + well_known_types::AsVariant( + well_known_types::GetBytesField(*message, field, scratch))); + } + case google::protobuf::FieldDescriptor::TYPE_UINT32: + return UintValue(reflection->GetUInt32(*message, field)); + case google::protobuf::FieldDescriptor::TYPE_ENUM: + return Value::Enum(field->enum_type(), + reflection->GetEnumValue(*message, field)); + case google::protobuf::FieldDescriptor::TYPE_SFIXED32: + return IntValue(reflection->GetInt32(*message, field)); + case google::protobuf::FieldDescriptor::TYPE_SFIXED64: + return IntValue(reflection->GetInt64(*message, field)); + case google::protobuf::FieldDescriptor::TYPE_SINT32: + return IntValue(reflection->GetInt32(*message, field)); + case google::protobuf::FieldDescriptor::TYPE_SINT64: + return IntValue(reflection->GetInt64(*message, field)); + default: + return ErrorValue(absl::InvalidArgumentError( + absl::StrCat("unexpected protocol buffer message field type: ", + field->type_name()))); + } +} - // Enum - case Kind::kEnum: - return Value::FromType(enum_value().type()); +template +Value WrapRepeatedFieldImpl( + int index, + const google::protobuf::Message* absl_nonnull message ABSL_ATTRIBUTE_LIFETIME_BOUND, + const google::protobuf::FieldDescriptor* absl_nonnull field + ABSL_ATTRIBUTE_LIFETIME_BOUND, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool + ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::MessageFactory* absl_nonnull message_factory + ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND) { + ABSL_DCHECK(field != nullptr); + ABSL_DCHECK_EQ(field->containing_type(), message->GetDescriptor()); + ABSL_DCHECK(!field->is_map() && field->is_repeated()); + ABSL_DCHECK_GE(index, 0); + ABSL_DCHECK(message != nullptr); + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); - // Objects. - case Kind::kObject: - return Value::FromType(object_value().object_type()); - case Kind::kDuration: - return Value::FromType( - ObjectType(google::protobuf::Duration::descriptor())); - case Kind::kTime: - return Value::FromType( - ObjectType(google::protobuf::Timestamp::descriptor())); + const auto* reflection = message->GetReflection(); + const int size = reflection->FieldSize(*message, field); + if (ABSL_PREDICT_FALSE(index < 0 || index >= size)) { + return ErrorValue(absl::InvalidArgumentError( + absl::StrCat("index out of bounds: ", index))); + } + switch (field->type()) { + case google::protobuf::FieldDescriptor::TYPE_DOUBLE: + return DoubleValue(reflection->GetRepeatedDouble(*message, field, index)); + case google::protobuf::FieldDescriptor::TYPE_FLOAT: + return DoubleValue(reflection->GetRepeatedFloat(*message, field, index)); + case google::protobuf::FieldDescriptor::TYPE_SFIXED64: + ABSL_FALLTHROUGH_INTENDED; + case google::protobuf::FieldDescriptor::TYPE_SINT64: + ABSL_FALLTHROUGH_INTENDED; + case google::protobuf::FieldDescriptor::TYPE_INT64: + return IntValue(reflection->GetRepeatedInt64(*message, field, index)); + case google::protobuf::FieldDescriptor::TYPE_FIXED64: + ABSL_FALLTHROUGH_INTENDED; + case google::protobuf::FieldDescriptor::TYPE_UINT64: + return UintValue(reflection->GetRepeatedUInt64(*message, field, index)); + case google::protobuf::FieldDescriptor::TYPE_SFIXED32: + ABSL_FALLTHROUGH_INTENDED; + case google::protobuf::FieldDescriptor::TYPE_SINT32: + ABSL_FALLTHROUGH_INTENDED; + case google::protobuf::FieldDescriptor::TYPE_INT32: + return IntValue(reflection->GetRepeatedInt32(*message, field, index)); + case google::protobuf::FieldDescriptor::TYPE_BOOL: + return BoolValue(reflection->GetRepeatedBool(*message, field, index)); + case google::protobuf::FieldDescriptor::TYPE_STRING: { + std::string scratch; + return absl::visit( + absl::Overload( + [&](absl::string_view string) -> StringValue { + if (string.data() == scratch.data() && + string.size() == scratch.size()) { + return StringValue(arena, std::move(scratch)); + } + if constexpr (Unsafe::value) { + return StringValue::WrapUnsafe(string); + } else { + return StringValue( + Borrower::Arena(MessageArenaOr(message, arena)), string); + } + }, + [&](absl::Cord&& cord) -> StringValue { + return StringValue(std::move(cord)); + }), + well_known_types::AsVariant(well_known_types::GetRepeatedStringField( + reflection, *message, field, index, scratch))); + } + case google::protobuf::FieldDescriptor::TYPE_GROUP: + ABSL_FALLTHROUGH_INTENDED; + case google::protobuf::FieldDescriptor::TYPE_MESSAGE: + if constexpr (Unsafe::value) { + return Value::WrapMessageUnsafe( + &reflection->GetRepeatedMessage(*message, field, index), + descriptor_pool, message_factory, arena); + } else { + return Value::WrapMessage( + &reflection->GetRepeatedMessage(*message, field, index), + descriptor_pool, message_factory, arena); + } + case google::protobuf::FieldDescriptor::TYPE_BYTES: { + std::string scratch; + return absl::visit( + absl::Overload( + [&](absl::string_view string) -> BytesValue { + if (string.data() == scratch.data() && + string.size() == scratch.size()) { + return BytesValue(arena, std::move(scratch)); + } + if constexpr (Unsafe::value) { + return BytesValue::WrapUnsafe(string); + } else { + return BytesValue( + Borrower::Arena(MessageArenaOr(message, arena)), string); + } + }, + [&](absl::Cord&& cord) -> BytesValue { + return BytesValue(std::move(cord)); + }), + well_known_types::AsVariant(well_known_types::GetRepeatedBytesField( + reflection, *message, field, index, scratch))); + } + case google::protobuf::FieldDescriptor::TYPE_FIXED32: + ABSL_FALLTHROUGH_INTENDED; + case google::protobuf::FieldDescriptor::TYPE_UINT32: + return UintValue(reflection->GetRepeatedUInt32(*message, field, index)); + case google::protobuf::FieldDescriptor::TYPE_ENUM: + return Value::Enum(field->enum_type(), reflection->GetRepeatedEnumValue( + *message, field, index)); + default: + return ErrorValue(absl::InvalidArgumentError( + absl::StrCat("unexpected message field type: ", field->type_name()))); + } +} - // Non-values. - case Kind::kError: - case Kind::kUnknown: - return *this; +template +Value WrapMapFieldValueImpl( + const google::protobuf::MapValueConstRef& value, + const google::protobuf::Message* absl_nonnull message ABSL_ATTRIBUTE_LIFETIME_BOUND, + const google::protobuf::FieldDescriptor* absl_nonnull field + ABSL_ATTRIBUTE_LIFETIME_BOUND, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool + ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::MessageFactory* absl_nonnull message_factory + ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND) { + ABSL_DCHECK(field != nullptr); + ABSL_DCHECK_EQ(field->containing_type()->containing_type(), + message->GetDescriptor()); + ABSL_DCHECK(!field->is_map() && !field->is_repeated()); + ABSL_DCHECK_EQ(value.type(), field->cpp_type()); + ABSL_DCHECK(message != nullptr); + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); - // Cause a compiler error if this switch isn't complete. - case Kind::DO_NOT_USE: - assert(false); + switch (field->type()) { + case google::protobuf::FieldDescriptor::TYPE_DOUBLE: + return DoubleValue(value.GetDoubleValue()); + case google::protobuf::FieldDescriptor::TYPE_FLOAT: + return DoubleValue(value.GetFloatValue()); + case google::protobuf::FieldDescriptor::TYPE_SFIXED64: + ABSL_FALLTHROUGH_INTENDED; + case google::protobuf::FieldDescriptor::TYPE_SINT64: + ABSL_FALLTHROUGH_INTENDED; + case google::protobuf::FieldDescriptor::TYPE_INT64: + return IntValue(value.GetInt64Value()); + case google::protobuf::FieldDescriptor::TYPE_FIXED64: + ABSL_FALLTHROUGH_INTENDED; + case google::protobuf::FieldDescriptor::TYPE_UINT64: + return UintValue(value.GetUInt64Value()); + case google::protobuf::FieldDescriptor::TYPE_SFIXED32: + ABSL_FALLTHROUGH_INTENDED; + case google::protobuf::FieldDescriptor::TYPE_SINT32: + ABSL_FALLTHROUGH_INTENDED; + case google::protobuf::FieldDescriptor::TYPE_INT32: + return IntValue(value.GetInt32Value()); + case google::protobuf::FieldDescriptor::TYPE_BOOL: + return BoolValue(value.GetBoolValue()); + case google::protobuf::FieldDescriptor::TYPE_STRING: + if constexpr (Unsafe::value) { + return StringValue::WrapUnsafe(value.GetStringValue()); + } else { + return StringValue(Borrower::Arena(MessageArenaOr(message, arena)), + value.GetStringValue()); + } + case google::protobuf::FieldDescriptor::TYPE_GROUP: + ABSL_FALLTHROUGH_INTENDED; + case google::protobuf::FieldDescriptor::TYPE_MESSAGE: + if constexpr (Unsafe::value) { + return Value::WrapMessageUnsafe( + &value.GetMessageValue(), descriptor_pool, message_factory, arena); + } else { + return Value::WrapMessage(&value.GetMessageValue(), descriptor_pool, + message_factory, arena); + } + case google::protobuf::FieldDescriptor::TYPE_BYTES: + if constexpr (Unsafe::value) { + return BytesValue::WrapUnsafe(value.GetStringValue()); + } else { + return BytesValue(Borrower::Arena(MessageArenaOr(message, arena)), + value.GetStringValue()); + } + case google::protobuf::FieldDescriptor::TYPE_FIXED32: + ABSL_FALLTHROUGH_INTENDED; + case google::protobuf::FieldDescriptor::TYPE_UINT32: + return UintValue(value.GetUInt32Value()); + case google::protobuf::FieldDescriptor::TYPE_ENUM: + return Value::Enum(field->enum_type(), value.GetEnumValue()); + default: + return ErrorValue(absl::InvalidArgumentError( + absl::StrCat("unexpected message field type: ", field->type_name()))); } - // Should never happen. - return Value::FromError( - internal::InternalError(absl::StrCat("Bad value kind: ", kind()))); } -bool Value::operator==(const Value& rhs) const { - return kind() == rhs.kind() && - Value::visit(internal::StrictEqVisitor(), *this, rhs); +} // namespace + +Value Value::WrapField( + ProtoWrapperTypeOptions wrapper_type_options, + const google::protobuf::Message* absl_nonnull message ABSL_ATTRIBUTE_LIFETIME_BOUND, + const google::protobuf::FieldDescriptor* absl_nonnull field + ABSL_ATTRIBUTE_LIFETIME_BOUND, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool + ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::MessageFactory* absl_nonnull message_factory + ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND) { + using Unsafe = std::false_type; + return WrapFieldImpl(wrapper_type_options, message, field, + descriptor_pool, message_factory, arena); +} + +Value Value::WrapFieldUnsafe( + ProtoWrapperTypeOptions wrapper_type_options, + const google::protobuf::Message* absl_nonnull message ABSL_ATTRIBUTE_LIFETIME_BOUND, + const google::protobuf::FieldDescriptor* absl_nonnull field + ABSL_ATTRIBUTE_LIFETIME_BOUND, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool + ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::MessageFactory* absl_nonnull message_factory + ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND) { + using Unsafe = std::true_type; + return WrapFieldImpl(wrapper_type_options, message, field, + descriptor_pool, message_factory, arena); +} + +Value Value::WrapRepeatedField( + int index, + const google::protobuf::Message* absl_nonnull message ABSL_ATTRIBUTE_LIFETIME_BOUND, + const google::protobuf::FieldDescriptor* absl_nonnull field + ABSL_ATTRIBUTE_LIFETIME_BOUND, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool + ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::MessageFactory* absl_nonnull message_factory + ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND) { + using Unsafe = std::false_type; + return WrapRepeatedFieldImpl(index, message, field, descriptor_pool, + message_factory, arena); +} + +Value Value::WrapRepeatedFieldUnsafe( + int index, + const google::protobuf::Message* absl_nonnull message ABSL_ATTRIBUTE_LIFETIME_BOUND, + const google::protobuf::FieldDescriptor* absl_nonnull field + ABSL_ATTRIBUTE_LIFETIME_BOUND, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool + ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::MessageFactory* absl_nonnull message_factory + ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND) { + using Unsafe = std::true_type; + return WrapRepeatedFieldImpl(index, message, field, descriptor_pool, + message_factory, arena); } -std::size_t Value::hash_code() const { - return internal::MixHash(internal::Hash(kind()), visit(internal::Hasher())); +StringValue Value::WrapMapFieldKeyString( + const google::protobuf::MapKey& key, + const google::protobuf::Message* absl_nonnull message ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND) { + ABSL_DCHECK(message != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK_EQ(key.type(), google::protobuf::FieldDescriptor::CPPTYPE_STRING); + +#if CEL_INTERNAL_PROTOBUF_OSS_VERSION_PREREQ(5, 30, 0) + return StringValue(Borrower::Arena(MessageArenaOr(message, arena)), + key.GetStringValue()); +#else + return StringValue(arena, key.GetStringValue()); +#endif +} + +Value Value::WrapMapFieldValue( + const google::protobuf::MapValueConstRef& value, + const google::protobuf::Message* absl_nonnull message ABSL_ATTRIBUTE_LIFETIME_BOUND, + const google::protobuf::FieldDescriptor* absl_nonnull field + ABSL_ATTRIBUTE_LIFETIME_BOUND, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool + ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::MessageFactory* absl_nonnull message_factory + ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND) { + using Unsafe = std::false_type; + return WrapMapFieldValueImpl(value, message, field, descriptor_pool, + message_factory, arena); +} + +Value Value::WrapMapFieldValueUnsafe( + const google::protobuf::MapValueConstRef& value, + const google::protobuf::Message* absl_nonnull message ABSL_ATTRIBUTE_LIFETIME_BOUND, + const google::protobuf::FieldDescriptor* absl_nonnull field + ABSL_ATTRIBUTE_LIFETIME_BOUND, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool + ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::MessageFactory* absl_nonnull message_factory + ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND) { + using Unsafe = std::true_type; + return WrapMapFieldValueImpl(value, message, field, descriptor_pool, + message_factory, arena); } -std::string Value::ToString() const { - switch (data_.index()) { - case kId: - return unknown_value().ToString(); - case kBytes: - case kBytesView: - case kBytesPtr: - return absl::StrCat("b", internal::ToString(bytes_value())); +optional_ref Value::AsBytes() const& { + if (const auto* alternative = variant_.As(); + alternative != nullptr) { + return *alternative; } - return visit(internal::CelPrinter()); + return absl::nullopt; } -Value Value::ForString(absl::string_view value, - const absl::optional& parent) { - if (parent == absl::nullopt) { - return Value::FromString(value); - } else if (!parent->RequiresReference()) { - return Create(value); +absl::optional Value::AsBytes() && { + if (auto* alternative = variant_.As(); alternative != nullptr) { + return std::move(*alternative); } - return Create(parent->GetRef(), value); + return absl::nullopt; } -Value Value::ForBytes(absl::string_view value, const ParentRef& parent) { - if (parent == absl::nullopt) { - return FromBytes(value); - } else if (parent->RequiresReference()) { - return Create(parent->GetRef(), value); +absl::optional Value::AsDouble() const { + if (const auto* alternative = variant_.As(); + alternative != nullptr) { + return *alternative; } - return Create(value); + return absl::nullopt; } -Value Value::FromEnum(const EnumValue& value) { - if (value.is_named()) { - return FromEnum(value.named_value()); +absl::optional Value::AsDuration() const { + if (const auto* alternative = variant_.As(); + alternative != nullptr) { + return *alternative; } - return FromEnum(value.unnamed_value()); + return absl::nullopt; } -Value Value::FromType(const Type& value) { - if (value.is_basic()) { - return FromType(value.basic_type()); +optional_ref Value::AsError() const& { + if (const auto* alternative = variant_.As(); + alternative != nullptr) { + return *alternative; } - if (value.is_enum()) { - return FromType(value.enum_type()); + return absl::nullopt; +} + +absl::optional Value::AsError() && { + if (auto* alternative = variant_.As(); alternative != nullptr) { + return std::move(*alternative); } - if (value.is_object()) { - return FromType(value.object_type()); + return absl::nullopt; +} + +absl::optional Value::AsInt() const { + if (const auto* alternative = variant_.As(); + alternative != nullptr) { + return *alternative; } - return Create(value.full_name()); + return absl::nullopt; } -Value Value::FromUnknown(const Unknown& value) { - if (value.ids().size() == 1) { - // Only has a single id, so it can be stored inline. - return FromUnknown(*value.ids().begin()); +absl::optional Value::AsList() const& { + if (const auto* alternative = variant_.As(); + alternative != nullptr) { + return *alternative; + } + if (const auto* alternative = variant_.As(); + alternative != nullptr) { + return *alternative; + } + if (const auto* alternative = variant_.As(); + alternative != nullptr) { + return *alternative; + } + if (const auto* alternative = variant_.As(); + alternative != nullptr) { + return *alternative; } - return Create(value); + return absl::nullopt; } -Value Value::FromUnknown(Unknown&& value) { - if (value.ids().size() == 1) { - return FromUnknown(*value.ids().begin()); +absl::optional Value::AsList() && { + if (auto* alternative = variant_.As(); + alternative != nullptr) { + return std::move(*alternative); } - return Create(std::move(value)); + if (auto* alternative = variant_.As(); + alternative != nullptr) { + return std::move(*alternative); + } + if (auto* alternative = variant_.As(); + alternative != nullptr) { + return std::move(*alternative); + } + if (auto* alternative = variant_.As(); + alternative != nullptr) { + return std::move(*alternative); + } + return absl::nullopt; } -std::size_t Container::hash_code() const { - return internal::LazyComputeHash([this]() { return ComputeHash(); }, - &hash_code_); +absl::optional Value::AsMap() const& { + if (const auto* alternative = variant_.As(); + alternative != nullptr) { + return *alternative; + } + if (const auto* alternative = variant_.As(); + alternative != nullptr) { + return *alternative; + } + if (const auto* alternative = variant_.As(); + alternative != nullptr) { + return *alternative; + } + if (const auto* alternative = variant_.As(); + alternative != nullptr) { + return *alternative; + } + return absl::nullopt; } -Container::Container() : hash_code_(internal::kNoHash) {} +absl::optional Value::AsMap() && { + if (auto* alternative = variant_.As(); + alternative != nullptr) { + return std::move(*alternative); + } + if (auto* alternative = variant_.As(); + alternative != nullptr) { + return std::move(*alternative); + } + if (auto* alternative = variant_.As(); + alternative != nullptr) { + return std::move(*alternative); + } + if (auto* alternative = variant_.As(); + alternative != nullptr) { + return std::move(*alternative); + } + return absl::nullopt; +} -Value Container::GetToContainsResult(const Value& get_result) { - if (get_result.kind() == Value::Kind::kUnknown) { - return get_result; +absl::optional Value::AsMessage() const& { + if (const auto* alternative = variant_.As(); + alternative != nullptr) { + return *alternative; } - return Value::FromBool(get_result.is_value()); + return absl::nullopt; } -Value List::Get(const Value& value) const { - RETURN_IF_NOT_VALUE(value); - auto index = value.get_if(); - if (!index) { - return Value::FromError( - internal::UnexpectedType(value.GetType().ToString(), "list index")); +absl::optional Value::AsMessage() && { + if (auto* alternative = variant_.As(); + alternative != nullptr) { + return std::move(*alternative); } - return Get(*index); + return absl::nullopt; } -Value List::Contains(const Value& value) const { - RETURN_IF_NOT_VALUE(value); - return ContainsImpl(value); +absl::optional Value::AsNull() const { + if (const auto* alternative = variant_.As(); + alternative != nullptr) { + return *alternative; + } + return absl::nullopt; } -google::rpc::Status List::ForEach( - const std::function& call) const { - for (std::size_t i = 0; i < size(); ++i) { - RETURN_IF_STATUS_ERROR(call(Get(i))); +optional_ref Value::AsOpaque() const& { + if (const auto* alternative = variant_.As(); + alternative != nullptr) { + return *alternative; } - return internal::OkStatus(); + return absl::nullopt; } -Value List::ContainsImpl(const Value& value) const { - return Value::FromBool( - !ForEach([&value](const Value& elem) { return value != elem; })); +absl::optional Value::AsOpaque() && { + if (auto* alternative = variant_.As(); alternative != nullptr) { + return std::move(*alternative); + } + return absl::nullopt; } -bool List::operator==(const List& rhs) const { - if (this == &rhs) { - return true; +optional_ref Value::AsOptional() const& { + if (const auto* alternative = variant_.As(); + alternative != nullptr && alternative->IsOptional()) { + return static_cast(*alternative); } - if (size() != rhs.size() || hash_code() != rhs.hash_code()) { - return false; + return absl::nullopt; +} + +absl::optional Value::AsOptional() && { + if (auto* alternative = variant_.As(); + alternative != nullptr && alternative->IsOptional()) { + return static_cast(*alternative); } - for (std::size_t i = 0; i < size(); ++i) { - if (Get(i) != rhs.Get(i)) { - return false; - } + return absl::nullopt; +} + +optional_ref Value::AsParsedJsonList() const& { + if (const auto* alternative = variant_.As(); + alternative != nullptr) { + return *alternative; } - return true; + return absl::nullopt; } -std::size_t List::ComputeHash() const { - std::size_t code = 0; - ForEach( - [&code](const Value& value) { internal::AccumulateHash(value, &code); }); - return code; +absl::optional Value::AsParsedJsonList() && { + if (auto* alternative = variant_.As(); + alternative != nullptr) { + return std::move(*alternative); + } + return absl::nullopt; } -std::string List::ToString() const { - internal::SequenceBuilder builder; - ForEach([&builder](const Value& elem) { builder.Add(elem); }); - return builder.Build(); +optional_ref Value::AsParsedJsonMap() const& { + if (const auto* alternative = variant_.As(); + alternative != nullptr) { + return *alternative; + } + return absl::nullopt; } -Value Map::Get(const Value& key) const { - RETURN_IF_NOT_VALUE(key); - return GetImpl(key); +absl::optional Value::AsParsedJsonMap() && { + if (auto* alternative = variant_.As(); + alternative != nullptr) { + return std::move(*alternative); + } + return absl::nullopt; } -Value Map::ContainsKey(const Value& key) const { - RETURN_IF_NOT_VALUE(key); - return ContainsKeyImpl(key); +optional_ref Value::AsCustomList() const& { + if (const auto* alternative = variant_.As(); + alternative != nullptr) { + return *alternative; + } + return absl::nullopt; } -Value Map::ContainsKeyImpl(const Value& key) const { - return GetToContainsResult(GetImpl(key)); +absl::optional Value::AsCustomList() && { + if (auto* alternative = variant_.As(); + alternative != nullptr) { + return std::move(*alternative); + } + return absl::nullopt; } -Value Map::ContainsValue(const Value& value) const { - RETURN_IF_NOT_VALUE(value); - return ContainsValueImpl(value); +optional_ref Value::AsCustomMap() const& { + if (const auto* alternative = variant_.As(); + alternative != nullptr) { + return *alternative; + } + return absl::nullopt; } -Value Map::ContainsValueImpl(const Value& value) const { - return Value::FromBool( - !ForEach([&value](const Value& key, const Value& stored_value) { - return stored_value != value; - })); +absl::optional Value::AsCustomMap() && { + if (auto* alternative = variant_.As(); + alternative != nullptr) { + return std::move(*alternative); + } + return absl::nullopt; } -Value Map::GetImpl(const Value& key) const { - Value value; - bool found = !ForEach( - [&key, &value](const Value& stored_key, const Value& stored_value) { - if (stored_key == key) { - value = stored_value; - return false; - } - return true; - }); - return found ? value : Value::FromError(internal::NoSuchKey(key.ToString())); +optional_ref Value::AsParsedMapField() const& { + if (const auto* alternative = variant_.As(); + alternative != nullptr) { + return *alternative; + } + return absl::nullopt; } -std::size_t Map::ComputeHash() const { - std::size_t code = 0; - ForEach([&code](const Value& key, const Value& value) { - internal::AccumulateHashNoOrder(internal::Hash(key, value), &code); - }); - return code; +absl::optional Value::AsParsedMapField() && { + if (auto* alternative = variant_.As(); + alternative != nullptr) { + return std::move(*alternative); + } + return absl::nullopt; } -bool Map::operator==(const Map& rhs) const { - if (this == &rhs) { - return true; +optional_ref Value::AsParsedMessage() const& { + if (const auto* alternative = variant_.As(); + alternative != nullptr) { + return *alternative; } - if (size() != rhs.size() || hash_code() != rhs.hash_code()) { - return false; + return absl::nullopt; +} + +absl::optional Value::AsParsedMessage() && { + if (auto* alternative = variant_.As(); + alternative != nullptr) { + return std::move(*alternative); } - return rhs.ForEach([this](const Value& key, const Value& value) { - return Get(key) == value; - }); + return absl::nullopt; } -std::string Map::ToString() const { - internal::SequenceBuilder builder; - ForEach([&builder](const Value& key, const Value& value) { - builder.Add(key, value); - }); - return builder.Build(); +optional_ref Value::AsParsedRepeatedField() + const& { + if (const auto* alternative = variant_.As(); + alternative != nullptr) { + return *alternative; + } + return absl::nullopt; } -Value Object::ContainsMember(absl::string_view name) const { - return GetToContainsResult(GetMember(name)); +absl::optional Value::AsParsedRepeatedField() && { + if (auto* alternative = variant_.As(); + alternative != nullptr) { + return std::move(*alternative); + } + return absl::nullopt; } -std::string Object::ToString() const { - internal::SequenceBuilder builder; - ForEach([&builder](absl::string_view name, Value value) { - builder.Add(name, value); - }); - return builder.Build(object_type().full_name()); +optional_ref Value::AsCustomStruct() const& { + if (const auto* alternative = variant_.As(); + alternative != nullptr) { + return *alternative; + } + return absl::nullopt; } -bool Object::operator==(const Object& rhs) const { - if (this == &rhs) { - return true; +absl::optional Value::AsCustomStruct() && { + if (auto* alternative = variant_.As(); + alternative != nullptr) { + return std::move(*alternative); + } + return absl::nullopt; +} + +optional_ref Value::AsString() const& { + if (const auto* alternative = variant_.As(); + alternative != nullptr) { + return *alternative; + } + return absl::nullopt; +} + +absl::optional Value::AsString() && { + if (auto* alternative = variant_.As(); alternative != nullptr) { + return std::move(*alternative); + } + return absl::nullopt; +} + +absl::optional Value::AsStruct() const& { + if (const auto* alternative = + variant_.As(); + alternative != nullptr) { + return *alternative; + } + if (const auto* alternative = variant_.As(); + alternative != nullptr) { + return *alternative; + } + if (const auto* alternative = variant_.As(); + alternative != nullptr) { + return *alternative; + } + return absl::nullopt; +} + +absl::optional Value::AsStruct() && { + if (auto* alternative = variant_.As(); + alternative != nullptr) { + return std::move(*alternative); + } + if (auto* alternative = variant_.As(); + alternative != nullptr) { + return std::move(*alternative); + } + if (auto* alternative = variant_.As(); + alternative != nullptr) { + return std::move(*alternative); + } + return absl::nullopt; +} + +absl::optional Value::AsTimestamp() const { + if (const auto* alternative = variant_.As(); + alternative != nullptr) { + return *alternative; + } + return absl::nullopt; +} + +optional_ref Value::AsType() const& { + if (const auto* alternative = variant_.As(); + alternative != nullptr) { + return *alternative; + } + return absl::nullopt; +} + +absl::optional Value::AsType() && { + if (auto* alternative = variant_.As(); alternative != nullptr) { + return std::move(*alternative); + } + return absl::nullopt; +} + +absl::optional Value::AsUint() const { + if (const auto* alternative = variant_.As(); + alternative != nullptr) { + return *alternative; + } + return absl::nullopt; +} + +optional_ref Value::AsUnknown() const& { + if (const auto* alternative = variant_.As(); + alternative != nullptr) { + return *alternative; + } + return absl::nullopt; +} + +absl::optional Value::AsUnknown() && { + if (auto* alternative = variant_.As(); alternative != nullptr) { + return std::move(*alternative); + } + return absl::nullopt; +} + +const BytesValue& Value::GetBytes() const& { + ABSL_DCHECK(IsBytes()) << *this; + return variant_.Get(); +} + +BytesValue Value::GetBytes() && { + ABSL_DCHECK(IsBytes()) << *this; + return std::move(variant_).Get(); +} + +DoubleValue Value::GetDouble() const { + ABSL_DCHECK(IsDouble()) << *this; + return variant_.Get(); +} + +DurationValue Value::GetDuration() const { + ABSL_DCHECK(IsDuration()) << *this; + return variant_.Get(); +} + +const ErrorValue& Value::GetError() const& { + ABSL_DCHECK(IsError()) << *this; + return variant_.Get(); +} + +ErrorValue Value::GetError() && { + ABSL_DCHECK(IsError()) << *this; + return std::move(variant_).Get(); +} + +IntValue Value::GetInt() const { + ABSL_DCHECK(IsInt()) << *this; + return variant_.Get(); +} + +#ifdef ABSL_HAVE_EXCEPTIONS +#define CEL_VALUE_THROW_BAD_VARIANT_ACCESS() throw absl::bad_variant_access() +#else +#define CEL_VALUE_THROW_BAD_VARIANT_ACCESS() \ + ABSL_LOG(FATAL) << absl::bad_variant_access().what() /* Crash OK */ +#endif + +ListValue Value::GetList() const& { + ABSL_DCHECK(IsList()) << *this; + if (const auto* alternative = variant_.As(); + alternative != nullptr) { + return *alternative; + } + if (const auto* alternative = variant_.As(); + alternative != nullptr) { + return *alternative; + } + if (const auto* alternative = variant_.As(); + alternative != nullptr) { + return *alternative; + } + if (const auto* alternative = variant_.As(); + alternative != nullptr) { + return *alternative; + } + CEL_VALUE_THROW_BAD_VARIANT_ACCESS(); +} + +ListValue Value::GetList() && { + ABSL_DCHECK(IsList()) << *this; + if (auto* alternative = variant_.As(); + alternative != nullptr) { + return std::move(*alternative); + } + if (auto* alternative = variant_.As(); + alternative != nullptr) { + return std::move(*alternative); + } + if (auto* alternative = variant_.As(); + alternative != nullptr) { + return std::move(*alternative); + } + if (auto* alternative = variant_.As(); + alternative != nullptr) { + return std::move(*alternative); + } + CEL_VALUE_THROW_BAD_VARIANT_ACCESS(); +} + +MapValue Value::GetMap() const& { + ABSL_DCHECK(IsMap()) << *this; + if (const auto* alternative = variant_.As(); + alternative != nullptr) { + return *alternative; + } + if (const auto* alternative = variant_.As(); + alternative != nullptr) { + return *alternative; + } + if (const auto* alternative = variant_.As(); + alternative != nullptr) { + return *alternative; + } + if (const auto* alternative = variant_.As(); + alternative != nullptr) { + return *alternative; + } + CEL_VALUE_THROW_BAD_VARIANT_ACCESS(); +} + +MapValue Value::GetMap() && { + ABSL_DCHECK(IsMap()) << *this; + if (auto* alternative = variant_.As(); + alternative != nullptr) { + return std::move(*alternative); + } + if (auto* alternative = variant_.As(); + alternative != nullptr) { + return std::move(*alternative); + } + if (auto* alternative = variant_.As(); + alternative != nullptr) { + return std::move(*alternative); + } + if (auto* alternative = variant_.As(); + alternative != nullptr) { + return std::move(*alternative); + } + CEL_VALUE_THROW_BAD_VARIANT_ACCESS(); +} + +MessageValue Value::GetMessage() const& { + ABSL_DCHECK(IsMessage()) << *this; + return variant_.Get(); +} + +MessageValue Value::GetMessage() && { + ABSL_DCHECK(IsMessage()) << *this; + return std::move(variant_).Get(); +} + +NullValue Value::GetNull() const { + ABSL_DCHECK(IsNull()) << *this; + return variant_.Get(); +} + +const OpaqueValue& Value::GetOpaque() const& { + ABSL_DCHECK(IsOpaque()) << *this; + return variant_.Get(); +} + +OpaqueValue Value::GetOpaque() && { + ABSL_DCHECK(IsOpaque()) << *this; + return std::move(variant_).Get(); +} + +const OptionalValue& Value::GetOptional() const& { + ABSL_DCHECK(IsOptional()) << *this; + return static_cast(variant_.Get()); +} + +OptionalValue Value::GetOptional() && { + ABSL_DCHECK(IsOptional()) << *this; + return static_cast(std::move(variant_).Get()); +} + +const ParsedJsonListValue& Value::GetParsedJsonList() const& { + ABSL_DCHECK(IsParsedJsonList()) << *this; + return variant_.Get(); +} + +ParsedJsonListValue Value::GetParsedJsonList() && { + ABSL_DCHECK(IsParsedJsonList()) << *this; + return std::move(variant_).Get(); +} + +const ParsedJsonMapValue& Value::GetParsedJsonMap() const& { + ABSL_DCHECK(IsParsedJsonMap()) << *this; + return variant_.Get(); +} + +ParsedJsonMapValue Value::GetParsedJsonMap() && { + ABSL_DCHECK(IsParsedJsonMap()) << *this; + return std::move(variant_).Get(); +} + +const CustomListValue& Value::GetCustomList() const& { + ABSL_DCHECK(IsCustomList()) << *this; + return variant_.Get(); +} + +CustomListValue Value::GetCustomList() && { + ABSL_DCHECK(IsCustomList()) << *this; + return std::move(variant_).Get(); +} + +const CustomMapValue& Value::GetCustomMap() const& { + ABSL_DCHECK(IsCustomMap()) << *this; + return variant_.Get(); +} + +CustomMapValue Value::GetCustomMap() && { + ABSL_DCHECK(IsCustomMap()) << *this; + return std::move(variant_).Get(); +} + +const ParsedMapFieldValue& Value::GetParsedMapField() const& { + ABSL_DCHECK(IsParsedMapField()) << *this; + return variant_.Get(); +} + +ParsedMapFieldValue Value::GetParsedMapField() && { + ABSL_DCHECK(IsParsedMapField()) << *this; + return std::move(variant_).Get(); +} + +const ParsedMessageValue& Value::GetParsedMessage() const& { + ABSL_DCHECK(IsParsedMessage()) << *this; + return variant_.Get(); +} + +ParsedMessageValue Value::GetParsedMessage() && { + ABSL_DCHECK(IsParsedMessage()) << *this; + return std::move(variant_).Get(); +} + +const ParsedRepeatedFieldValue& Value::GetParsedRepeatedField() const& { + ABSL_DCHECK(IsParsedRepeatedField()) << *this; + return variant_.Get(); +} + +ParsedRepeatedFieldValue Value::GetParsedRepeatedField() && { + ABSL_DCHECK(IsParsedRepeatedField()) << *this; + return std::move(variant_).Get(); +} + +const CustomStructValue& Value::GetCustomStruct() const& { + ABSL_DCHECK(IsCustomStruct()) << *this; + return variant_.Get(); +} + +CustomStructValue Value::GetCustomStruct() && { + ABSL_DCHECK(IsCustomStruct()) << *this; + return std::move(variant_).Get(); +} + +const StringValue& Value::GetString() const& { + ABSL_DCHECK(IsString()) << *this; + return variant_.Get(); +} + +StringValue Value::GetString() && { + ABSL_DCHECK(IsString()) << *this; + return std::move(variant_).Get(); +} + +StructValue Value::GetStruct() const& { + ABSL_DCHECK(IsStruct()) << *this; + if (const auto* alternative = + variant_.As(); + alternative != nullptr) { + return *alternative; + } + if (const auto* alternative = variant_.As(); + alternative != nullptr) { + return *alternative; + } + if (const auto* alternative = variant_.As(); + alternative != nullptr) { + return *alternative; + } + CEL_VALUE_THROW_BAD_VARIANT_ACCESS(); +} + +StructValue Value::GetStruct() && { + ABSL_DCHECK(IsStruct()) << *this; + if (auto* alternative = variant_.As(); + alternative != nullptr) { + return std::move(*alternative); + } + if (auto* alternative = variant_.As(); + alternative != nullptr) { + return std::move(*alternative); + } + if (auto* alternative = variant_.As(); + alternative != nullptr) { + return std::move(*alternative); + } + CEL_VALUE_THROW_BAD_VARIANT_ACCESS(); +} + +TimestampValue Value::GetTimestamp() const { + ABSL_DCHECK(IsTimestamp()) << *this; + return variant_.Get(); +} + +const TypeValue& Value::GetType() const& { + ABSL_DCHECK(IsType()) << *this; + return variant_.Get(); +} + +TypeValue Value::GetType() && { + ABSL_DCHECK(IsType()) << *this; + return std::move(variant_).Get(); +} + +UintValue Value::GetUint() const { + ABSL_DCHECK(IsUint()) << *this; + return variant_.Get(); +} + +const UnknownValue& Value::GetUnknown() const& { + ABSL_DCHECK(IsUnknown()) << *this; + return variant_.Get(); +} + +UnknownValue Value::GetUnknown() && { + ABSL_DCHECK(IsUnknown()) << *this; + return std::move(variant_).Get(); +} + +namespace { + +class EmptyValueIterator final : public ValueIterator { + public: + bool HasNext() override { return false; } + + absl::Status Next(const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) override { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + return absl::FailedPreconditionError( + "`ValueIterator::Next` called after `ValueIterator::HasNext` returned " + "false"); } - if (object_type() != rhs.object_type() || hash_code() != rhs.hash_code()) { + + absl::StatusOr Next1( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull key_or_value) override { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(key_or_value != nullptr); + + return false; + } + + absl::StatusOr Next2( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull key, + Value* absl_nullable value) override { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(key != nullptr); + return false; } - return EqualsImpl(rhs); +}; + +} // namespace + +absl_nonnull std::unique_ptr NewEmptyValueIterator() { + return std::make_unique(); } -std::size_t Object::ComputeHash() const { - std::size_t hash_code = 0; - ForEach([&hash_code](absl::string_view key, const Value& value) { - internal::AccumulateHashNoOrder(internal::Hash(key, value), &hash_code); - }); - return hash_code; +absl_nonnull ListValueBuilderPtr +NewListValueBuilder(google::protobuf::Arena* absl_nonnull arena) { + ABSL_DCHECK(arena != nullptr); + return common_internal::NewListValueBuilder(arena); +} + +absl_nonnull MapValueBuilderPtr +NewMapValueBuilder(google::protobuf::Arena* absl_nonnull arena) { + ABSL_DCHECK(arena != nullptr); + return common_internal::NewMapValueBuilder(arena); +} + +absl_nullable StructValueBuilderPtr NewStructValueBuilder( + google::protobuf::Arena* absl_nonnull arena, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + absl::string_view name) { + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + return common_internal::NewStructValueBuilder(arena, descriptor_pool, + message_factory, name); +} + +bool operator==(IntValue lhs, UintValue rhs) { + return internal::Number::FromInt64(lhs.NativeValue()) == + internal::Number::FromUint64(rhs.NativeValue()); +} + +bool operator==(UintValue lhs, IntValue rhs) { + return internal::Number::FromUint64(lhs.NativeValue()) == + internal::Number::FromInt64(rhs.NativeValue()); +} + +bool operator==(IntValue lhs, DoubleValue rhs) { + return internal::Number::FromInt64(lhs.NativeValue()) == + internal::Number::FromDouble(rhs.NativeValue()); +} + +bool operator==(DoubleValue lhs, IntValue rhs) { + return internal::Number::FromDouble(lhs.NativeValue()) == + internal::Number::FromInt64(rhs.NativeValue()); +} + +bool operator==(UintValue lhs, DoubleValue rhs) { + return internal::Number::FromUint64(lhs.NativeValue()) == + internal::Number::FromDouble(rhs.NativeValue()); +} + +bool operator==(DoubleValue lhs, UintValue rhs) { + return internal::Number::FromDouble(lhs.NativeValue()) == + internal::Number::FromUint64(rhs.NativeValue()); +} + +absl::StatusOr ValueIterator::Next1( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull value) { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(value != nullptr); + + if (HasNext()) { + CEL_RETURN_IF_ERROR(Next(descriptor_pool, message_factory, arena, value)); + return true; + } + return false; } -} // namespace common -} // namespace expr -} // namespace api -} // namespace google +} // namespace cel diff --git a/common/value.h b/common/value.h index 777021a4b..34b4714a7 100644 --- a/common/value.h +++ b/common/value.h @@ -1,844 +1,2947 @@ -#ifndef THIRD_PARTY_CEL_CPP_COMMON_CEL_VALUE_H_ -#define THIRD_PARTY_CEL_CPP_COMMON_CEL_VALUE_H_ - -#include -#include -#include +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_VALUE_H_ + +#include +#include +#include #include +#include #include -#include +#include #include -#include "google/protobuf/any.pb.h" -#include "google/rpc/status.pb.h" -#include "absl/memory/memory.h" +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/meta/type_traits.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" -#include "absl/time/time.h" #include "absl/types/optional.h" #include "absl/types/span.h" -#include "absl/types/variant.h" -#include "common/enum.h" -#include "common/error.h" -#include "common/id.h" -#include "common/parent_ref.h" +#include "absl/utility/utility.h" +#include "base/attribute.h" +#include "common/arena.h" +#include "common/native_type.h" +#include "common/optional_ref.h" #include "common/type.h" -#include "common/unknown.h" -#include "internal/hash_util.h" -#include "internal/holder.h" -#include "internal/ref_countable.h" -#include "internal/status_util.h" -#include "internal/value_internal.h" -#include "internal/visitor_util.h" - -namespace google { -namespace api { -namespace expr { -namespace common { - -/** - * A CEL Value. - * - * Instances of this class can always be cheaply copied. - * - * The value held by a Value is either inlined, owned, or unowned. If - * inlined, the value is embedded directly in the Value object. If owned, the - * value is held via a shared pointer. If unowned, the value is held via a raw - * pointer and the creator of the Value is responsible for insuring the - * referenced value lives longer than the Value or any of its copies. - * - * Value provides a unified interface for all three cases. - * However, `Value::is_inline()` and `Value::owns_value()` can be - * used to distinguish these cases. - * - * Three types of constructor functions are provided: - * - From*: The resulting Value always owns the value being held. Depending - * on the type, the value may be inlined or owned. - * - For*: The resulting Value might not own its value, instead only holding a - * pointer to its value. If so, the provided pointer argument must live longer - * than the resulting Value (or any copy of the resulting Value). - * - Make*: Helper functions that behave like std::make_unique. The resulting - * Value always owns its value. - * - * List, Map and Object base are not implemented directly, but rather base - * classes for custom implementations. This allows for optimizations specific - * to various contexts or data. For example, a homogenious map might use a - * native c++ map to hold its data. - */ -class Value final : public internal::BaseValue { +#include "common/typeinfo.h" +#include "common/value_kind.h" +#include "common/values/bool_value.h" // IWYU pragma: export +#include "common/values/bytes_value.h" // IWYU pragma: export +#include "common/values/bytes_value_input_stream.h" // IWYU pragma: export +#include "common/values/bytes_value_output_stream.h" // IWYU pragma: export +#include "common/values/custom_list_value.h" // IWYU pragma: export +#include "common/values/custom_map_value.h" // IWYU pragma: export +#include "common/values/custom_struct_value.h" // IWYU pragma: export +#include "common/values/double_value.h" // IWYU pragma: export +#include "common/values/duration_value.h" // IWYU pragma: export +#include "common/values/enum_value.h" // IWYU pragma: export +#include "common/values/error_value.h" // IWYU pragma: export +#include "common/values/int_value.h" // IWYU pragma: export +#include "common/values/list_value.h" // IWYU pragma: export +#include "common/values/map_value.h" // IWYU pragma: export +#include "common/values/message_value.h" // IWYU pragma: export +#include "common/values/null_value.h" // IWYU pragma: export +#include "common/values/opaque_value.h" // IWYU pragma: export +#include "common/values/optional_value.h" // IWYU pragma: export +#include "common/values/parsed_json_list_value.h" // IWYU pragma: export +#include "common/values/parsed_json_map_value.h" // IWYU pragma: export +#include "common/values/parsed_map_field_value.h" // IWYU pragma: export +#include "common/values/parsed_message_value.h" // IWYU pragma: export +#include "common/values/parsed_repeated_field_value.h" // IWYU pragma: export +#include "common/values/string_value.h" // IWYU pragma: export +#include "common/values/struct_value.h" // IWYU pragma: export +#include "common/values/timestamp_value.h" // IWYU pragma: export +#include "common/values/type_value.h" // IWYU pragma: export +#include "common/values/uint_value.h" // IWYU pragma: export +#include "common/values/unknown_value.h" // IWYU pragma: export +#include "common/values/value_variant.h" +#include "common/values/values.h" +#include "internal/status_macros.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/generated_enum_reflection.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/map_field.h" +#include "google/protobuf/message.h" + +#pragma push_macro("GetMessage") +#ifdef GetMessage +// GetMessage in windows API headers might be defined as a macro. Depending on +// ordering, might cause issues with Value::GetMessage or +// google::protobuf::Reflection::GetMessage. +#undef GetMessage +#endif + +namespace cel { + +// `Value` is a composition type which encompasses all values supported by the +// Common Expression Language. When default constructed or moved, `Value` is in +// a known but invalid state. Any attempt to use it from then on, without +// assigning another type, is undefined behavior. In debug builds, we do our +// best to fail. +class Value final : private common_internal::ValueMixin { public: - enum class Kind { - // Values - kNull, - kBool, - kInt, - kUInt, - kDouble, - kString, - kBytes, - kType, - kMap, - kList, - kObject, - kEnum, - - // Objects that have well-known c++ representations. - kDuration, - kTime, - - // Non-values. - kError, - kUnknown, - - // Special value to require 'default' case in switch statements. - DO_NOT_USE - }; - - // Constructors from primitive types. - static inline Value NullValue() { return Create(); } - static inline Value FromBool(bool value) { return Create(value); } - static inline Value TrueValue() { return FromBool(true); } - static inline Value FalseValue() { return FromBool(false); } - static inline Value FromInt(int64_t value) { return Create(value); } - static inline Value FromUInt(uint64_t value) { return Create(value); } - static inline Value FromDouble(double value); - - // Constructors from well-known c++ types. - static inline Value FromString(absl::string_view value); - static inline Value FromString(const std::string& value); - static inline Value FromString(std::string&& value); - // For string literals (e.g. const char*) use ForString. - - static inline Value FromBytes(absl::string_view value); - static inline Value FromBytes(const std::string& value); - static inline Value FromBytes(std::string&& value); - // For byte literals (e.g. const char*) use ForBytes. - - static inline Value FromDuration(absl::Duration value); - static inline Value FromTime(absl::Time value); - - static Value FromEnum(const EnumValue& value); - static inline Value FromEnum(NamedEnumValue value); - static inline Value FromEnum(const UnnamedEnumValue& value); - - static Value FromType(const Type& value); - static inline Value FromType(BasicType value); - static inline Value FromType(ObjectType value); - static inline Value FromType(EnumType value); - static inline Value FromType(BasicTypeValue value); - static inline Value FromType(const char* value); - static inline Value FromType(absl::string_view full_name); - static inline Value FromType(const std::string& full_name); - - static inline Value FromError(const google::rpc::Status& value); - static inline Value FromError(google::rpc::Status&& value); - static inline Value FromError(const Error& value); - static inline Value FromError(Error&& value); - - static Value FromUnknown(const Unknown& value); - static Value FromUnknown(Unknown&& value); - static inline Value FromUnknown(Id value) { return Create(value); } - - // Constructors *for* well-known c++ types. - // - // If no parent is passed in, any value referenced by the provided arguments - // must live longer than the resulting Value (or any copy of the resulting - // value). - static Value ForString(absl::string_view value, - const ParentRef& parent = NoParent()); - static Value ForBytes(absl::string_view value, - const ParentRef& parent = NoParent()); - - // Constructors from container types. - static inline Value FromMap(std::unique_ptr value); - static inline Value FromList(std::unique_ptr value); - static inline Value FromObject(std::unique_ptr value); - - // Constructors *for* Container types. - // - // The arguments passed in must live longer than any Value referencing - // them. - static inline Value ForMap(const Map* value); - static inline Value ForList(const List* value); - static inline Value ForObject(const Object* value); - - // Inplace constructors - template - static Value MakeList(Args&&... args); - template - static Value MakeMap(Args&&... args); - template - static Value MakeObject(Args&&... args); - - private: - template - struct KindHelper; - - template - using GetIfKType = - GetIfType(K), KindToType>>; - - template - using GetKType = - GetType(K), KindToType>>; + // Returns an appropriate `Value` for the dynamic protobuf enum. For open + // enums, returns `cel::IntValue`. For closed enums, returns `cel::ErrorValue` + // if the value is not present in the enum otherwise returns `cel::IntValue`. + static Value Enum(const google::protobuf::EnumValueDescriptor* absl_nonnull value); + static Value Enum(const google::protobuf::EnumDescriptor* absl_nonnull type, + int32_t number); + + // SFINAE overload for generated protobuf enums which are not well-known. + // Always returns `cel::IntValue`. + template + static common_internal::EnableIfGeneratedEnum Enum(T value) { + return IntValue(value); + } - public: - template - static Value From(T&& value) { - return KindHelper::From(std::forward(value)); + // SFINAE overload for google::protobuf::NullValue. Always returns + // `cel::NullValue`. + template + static common_internal::EnableIfWellKnownEnum + Enum(T) { + return NullValue(); } - template - static Value For(T&& value, const ParentRef& parent = NoParent()) { - return KindHelper::For(std::forward(value), parent); + // Returns an appropriate `Value` for the dynamic protobuf message. If + // `message` is the well known type `google.protobuf.Any`, `descriptor_pool` + // and `message_factory` will be used to unpack the value. Both must outlive + // the resulting value and any of its shallow copies. Otherwise the message is + // copied using `arena`. + static Value FromMessage( + const google::protobuf::Message& message, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool + ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::MessageFactory* absl_nonnull message_factory + ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND); + static Value FromMessage( + google::protobuf::Message&& message, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool + ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::MessageFactory* absl_nonnull message_factory + ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND); + + // Returns an appropriate `Value` for the dynamic protobuf message. If + // `message` is the well known type `google.protobuf.Any`, `descriptor_pool` + // and `message_factory` will be used to unpack the value. Both must outlive + // the resulting value and any of its shallow copies. Otherwise the message is + // borrowed (no copying). If the message is on an arena, that arena will be + // attributed as the owner. Otherwise `arena` is used. + static Value WrapMessage( + const google::protobuf::Message* absl_nonnull message ABSL_ATTRIBUTE_LIFETIME_BOUND, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool + ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::MessageFactory* absl_nonnull message_factory + ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND); + + // Returns an appropriate `Value` for the dynamic protobuf message. If + // `message` is the well known type `google.protobuf.Any`, `descriptor_pool` + // and `message_factory` will be used to unpack the value. Both must outlive + // the resulting value and any of its shallow copies. Otherwise the message is + // borrowed (no copying). This function does not attempt to validate arena + // ownership of a dynamic message that was not unpacked from a well known + // type. Caller is responsible for ensuring the resulting value and any + // derived values do not outlive the input message. + static Value WrapMessageUnsafe( + const google::protobuf::Message* absl_nonnull message ABSL_ATTRIBUTE_LIFETIME_BOUND, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool + ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::MessageFactory* absl_nonnull message_factory + ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND); + + // Returns an appropriate `Value` for the dynamic protobuf message field. If + // `field` in `message` is the well known type `google.protobuf.Any`, + // `descriptor_pool` and `message_factory` will be used to unpack the value. + // Both must outlive the resulting value and any of its shallow copies. + // Otherwise the field is borrowed (no copying). If the message is on an + // arena, that arena will be attributed as the owner. Otherwise `arena` is + // used. + static Value WrapField( + ProtoWrapperTypeOptions wrapper_type_options, + const google::protobuf::Message* absl_nonnull message ABSL_ATTRIBUTE_LIFETIME_BOUND, + const google::protobuf::FieldDescriptor* absl_nonnull field + ABSL_ATTRIBUTE_LIFETIME_BOUND, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool + ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::MessageFactory* absl_nonnull message_factory + ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND); + static Value WrapField( + const google::protobuf::Message* absl_nonnull message ABSL_ATTRIBUTE_LIFETIME_BOUND, + const google::protobuf::FieldDescriptor* absl_nonnull field + ABSL_ATTRIBUTE_LIFETIME_BOUND, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool + ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::MessageFactory* absl_nonnull message_factory + ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND) { + return WrapField(ProtoWrapperTypeOptions::kUnsetNull, message, field, + descriptor_pool, message_factory, arena); } - Value() = default; - ~Value(); + // Returns an appropriate `Value` for the dynamic protobuf message field. If + // `field` in `message` is the well known type `google.protobuf.Any`, + // `descriptor_pool` and `message_factory` will be used to unpack the value. + // Both must outlive the resulting value and any of its shallow copies. + // Otherwise the field is borrowed (no copying). Caller is responsible for + // ensuring the resulting value and any derived values do not outlive the + // input message. + static Value WrapFieldUnsafe( + ProtoWrapperTypeOptions wrapper_type_options, + const google::protobuf::Message* absl_nonnull message ABSL_ATTRIBUTE_LIFETIME_BOUND, + const google::protobuf::FieldDescriptor* absl_nonnull field + ABSL_ATTRIBUTE_LIFETIME_BOUND, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool + ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::MessageFactory* absl_nonnull message_factory + ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND); + + // Returns an appropriate `Value` for the dynamic protobuf message repeated + // field. If `field` in `message` is the well known type + // `google.protobuf.Any`, `descriptor_pool` and `message_factory` will be used + // to unpack the value. Both must outlive the resulting value and any of its + // shallow copies. + static Value WrapRepeatedField( + int index, + const google::protobuf::Message* absl_nonnull message ABSL_ATTRIBUTE_LIFETIME_BOUND, + const google::protobuf::FieldDescriptor* absl_nonnull field + ABSL_ATTRIBUTE_LIFETIME_BOUND, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool + ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::MessageFactory* absl_nonnull message_factory + ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND); + + // Returns an appropriate `Value` for the dynamic protobuf message repeated + // field. If `field` in `message` is the well known type + // `google.protobuf.Any`, `descriptor_pool` and `message_factory` will be used + // to unpack the value. Both must outlive the resulting value and any of its + // shallow copies. + static Value WrapRepeatedFieldUnsafe( + int index, + const google::protobuf::Message* absl_nonnull message ABSL_ATTRIBUTE_LIFETIME_BOUND, + const google::protobuf::FieldDescriptor* absl_nonnull field + ABSL_ATTRIBUTE_LIFETIME_BOUND, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool + ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::MessageFactory* absl_nonnull message_factory + ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND); + + // Returns an appropriate `StringValue` for the dynamic protobuf message map + // field key. The map field key must be a string or the behavior is undefined. + static StringValue WrapMapFieldKeyString( + const google::protobuf::MapKey& key, + const google::protobuf::Message* absl_nonnull message ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND); + + // Returns an appropriate `Value` for the dynamic protobuf message map + // field value. If `field` in `message`, which is `value`, is the well known + // type `google.protobuf.Any`, `descriptor_pool` and `message_factory` will be + // used to unpack the value. Both must outlive the resulting value and any of + // its shallow copies. + static Value WrapMapFieldValue( + const google::protobuf::MapValueConstRef& value, + const google::protobuf::Message* absl_nonnull message ABSL_ATTRIBUTE_LIFETIME_BOUND, + const google::protobuf::FieldDescriptor* absl_nonnull field + ABSL_ATTRIBUTE_LIFETIME_BOUND, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool + ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::MessageFactory* absl_nonnull message_factory + ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND); + + // Returns an appropriate `Value` for the dynamic protobuf message map + // field value. If `field` in `message`, which is `value`, is the well known + // type `google.protobuf.Any`, `descriptor_pool` and `message_factory` will be + // used to unpack the value. Both must outlive the resulting value and any of + // its shallow copies. Caller is responsible for ensuring the resulting value + // and any derived values do not outlive the input message. + static Value WrapMapFieldValueUnsafe( + const google::protobuf::MapValueConstRef& value, + const google::protobuf::Message* absl_nonnull message ABSL_ATTRIBUTE_LIFETIME_BOUND, + const google::protobuf::FieldDescriptor* absl_nonnull field + ABSL_ATTRIBUTE_LIFETIME_BOUND, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool + ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::MessageFactory* absl_nonnull message_factory + ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND); - // Copy and move constructable. + Value() = default; Value(const Value&) = default; - Value(Value&) = default; - Value(Value&&) = default; Value& operator=(const Value&) = default; + Value(Value&& other) = default; Value& operator=(Value&&) = default; - // Accessors. - // Dies with absl::bad_variant_access if the wrong kind is accessed. - - // Mutable accessors for inlined values. - inline bool& bool_value() { return *absl::get(data_); } - inline int64_t& int_value() { return *absl::get(data_); } - inline uint64_t& uint_value() { return *absl::get(data_); } - inline double& double_value() { return *absl::get(data_); } - - // Const accessors for all value. - inline bool is_null() const { return data_.index() == kNull; } - inline const bool& bool_value() const; - inline const int64_t& int_value() const; - inline const uint64_t& uint_value() const; - inline const double& double_value() const; - inline const absl::Duration& duration() const; - inline const absl::Time& time() const; - inline const Error& error_value() const; - inline const Map& map_value() const; - inline const List& list_value() const; - inline const Object& object_value() const; - inline absl::string_view string_value() const; - inline absl::string_view bytes_value() const; - inline Type type_value() const; - inline EnumValue enum_value() const; - inline Unknown unknown_value() const; - - // Value metadata. - /** If the value is stored directly in the Value */ - inline bool is_inline() const { return data_.index() < kInlineEnd; } - - /** If the value is a real 'value', and not an Error or Unknown. */ - inline bool is_value() const; - - /** If the value is considered an 'object'. */ - inline bool is_object() const { return index_in(kObject, kObjectEnd); } - - /** If the value is owned, or if an external value is being referenced. */ - bool owns_value() const; - - /** The kind of value stored. */ - Kind kind() const; - - /** - * The hash code of this value. - * - * Cached internally when appropriate. - */ - std::size_t hash_code() const; - - /** - * Pure representation equality, e.g. Value(NaN) == Value(NaN) => true. - */ - bool operator==(const Value& rhs) const; - inline bool operator!=(const Value& rhs) const { return !(*this == rhs); } - - /** Applies the visitor to the given Values and returns the result. */ - template - static VisitType visit(V&& vis, F&& value, R&&... rest); - - /** Applies the visitor and returns the result. */ - template - inline VisitType visit(V&& vis) const; - - template - inline GetType get() const; - template - inline GetKType get() const; - - template - inline GetIfType get_if() const; - template - inline GetIfKType get_if() const; - - /** Returns the type of the stored value. */ - Value GetType() const; - - /** - * Returns a canonical cel expression for the value. - * - * Computation may be expensive. - */ - std::string ToString() const; + // NOLINTNEXTLINE(google-explicit-constructor) + Value(const ListValue& value) : variant_(value.ToValueVariant()) {} - private: - template - static Value Create(Args&&... args); + // NOLINTNEXTLINE(google-explicit-constructor) + Value(ListValue&& value) : variant_(std::move(value).ToValueVariant()) {} - template - explicit Value(Args&&... args) : data_(std::forward(args)...) {} + // NOLINTNEXTLINE(google-explicit-constructor) + Value& operator=(const ListValue& value) { + variant_ = value.ToValueVariant(); + return *this; + } - inline bool index_in(std::size_t start, std::size_t end) const; + // NOLINTNEXTLINE(google-explicit-constructor) + Value& operator=(ListValue&& value) { + variant_ = std::move(value).ToValueVariant(); + return *this; + } - ValueData data_; -}; + // NOLINTNEXTLINE(google-explicit-constructor) + Value(const MapValue& value) : variant_(value.ToValueVariant()) {} -/** A base class for shared values that may contain other values. */ -class Container : public SharedValue { - public: - /** The hash code for this value. Cached after the first call. */ - std::size_t hash_code() const; + // NOLINTNEXTLINE(google-explicit-constructor) + Value(MapValue&& value) : variant_(std::move(value).ToValueVariant()) {} - protected: - Container(); + // NOLINTNEXTLINE(google-explicit-constructor) + Value& operator=(const MapValue& value) { + variant_ = value.ToValueVariant(); + return *this; + } - /** - * The hash computation function. - * - * The result of this function is automatically cached. - */ - virtual std::size_t ComputeHash() const = 0; + // NOLINTNEXTLINE(google-explicit-constructor) + Value& operator=(MapValue&& value) { + variant_ = std::move(value).ToValueVariant(); + return *this; + } - // A helper alias that resolves to R iff the lambda type T returns a value - // of type R when called. This is (annoyingly) needed as c++ does not match - // overloads based on lambda return types by default. - template - using ForEachReturnType = internal::specialize_if_returns; + // NOLINTNEXTLINE(google-explicit-constructor) + Value(const StructValue& value) : variant_(value.ToValueVariant()) {} - // A helper function to convert a get result to a contains result. - static Value GetToContainsResult(const Value& get_result); + // NOLINTNEXTLINE(google-explicit-constructor) + Value(StructValue&& value) : variant_(std::move(value).ToValueVariant()) {} - // Helpers to Convert a contained value into a cel Value. - template - Value GetValue(V& value) { - return Value::For(&value, SelfRefProvider()); + // NOLINTNEXTLINE(google-explicit-constructor) + Value& operator=(const StructValue& value) { + variant_ = value.ToValueVariant(); + return *this; } - template - static Value GetValue(V&& value) { - return Value::From(std::move(value)); + + // NOLINTNEXTLINE(google-explicit-constructor) + Value& operator=(StructValue&& value) { + variant_ = std::move(value).ToValueVariant(); + return *this; } - private: - mutable std::atomic hash_code_; -}; + // NOLINTNEXTLINE(google-explicit-constructor) + Value(const MessageValue& value) : variant_(value.ToValueVariant()) {} -/** The base class for a CEL list value. */ -class List : public Container { - public: - /** The number of elements in the list */ - virtual std::size_t size() const = 0; + // NOLINTNEXTLINE(google-explicit-constructor) + Value(MessageValue&& value) : variant_(std::move(value).ToValueVariant()) {} + + // NOLINTNEXTLINE(google-explicit-constructor) + Value& operator=(const MessageValue& value) { + variant_ = value.ToValueVariant(); + return *this; + } + + // NOLINTNEXTLINE(google-explicit-constructor) + Value& operator=(MessageValue&& value) { + variant_ = std::move(value).ToValueVariant(); + return *this; + } + + // NOLINTNEXTLINE(google-explicit-constructor) + Value(const OptionalValue& value) + : variant_(absl::in_place_type, + static_cast(value)) {} + + // NOLINTNEXTLINE(google-explicit-constructor) + Value(OptionalValue&& value) + : variant_(absl::in_place_type, + static_cast(value)) {} + + // NOLINTNEXTLINE(google-explicit-constructor) + Value& operator=(const OptionalValue& value) { + variant_.Assign(static_cast(value)); + return *this; + } + + // NOLINTNEXTLINE(google-explicit-constructor) + Value& operator=(OptionalValue&& value) { + variant_.Assign(static_cast(value)); + return *this; + } - /** Returns the value at the index stored in value, or an error. */ - Value Get(const Value& value) const; + template >>> + // NOLINTNEXTLINE(google-explicit-constructor) + Value(T&& alternative) noexcept + : variant_(absl::in_place_type>, + std::forward(alternative)) {} + + template >>> + // NOLINTNEXTLINE(google-explicit-constructor) + Value& operator=(T&& alternative) noexcept { + variant_.Assign(std::forward(alternative)); + return *this; + } + + ValueKind kind() const { return variant_.kind(); } + + Type GetRuntimeType() const; + + absl::string_view GetTypeName() const; + + std::string DebugString() const; + + // `SerializeTo` serializes this value to `output`. If an error is returned, + // `output` is in a valid but unspecified state. If this value does not + // support serialization, `FAILED_PRECONDITION` is returned. + absl::Status SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const; + + // `ConvertToJson` converts this value to its JSON representation. The + // argument `json` **MUST** be an instance of `google.protobuf.Value` which is + // can either be the generated message or a dynamic message. The descriptor + // pool `descriptor_pool` and message factory `message_factory` are used to + // deal with serialized messages and a few corners cases. + absl::Status ConvertToJson( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const; + + // `ConvertToJsonArray` converts this value to its JSON representation if and + // only if it can be represented as an array. The argument `json` **MUST** be + // an instance of `google.protobuf.ListValue` which is can either be the + // generated message or a dynamic message. The descriptor pool + // `descriptor_pool` and message factory `message_factory` are used to deal + // with serialized messages and a few corners cases. + absl::Status ConvertToJsonArray( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const; + + // `ConvertToJsonArray` converts this value to its JSON representation if and + // only if it can be represented as an object. The argument `json` **MUST** be + // an instance of `google.protobuf.Struct` which is can either be the + // generated message or a dynamic message. The descriptor pool + // `descriptor_pool` and message factory `message_factory` are used to deal + // with serialized messages and a few corners cases. + absl::Status ConvertToJsonObject( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const; + + absl::Status Equal(const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const; + using ValueMixin::Equal; + + bool IsZeroValue() const; + + // Clones the value to another arena, if necessary, such that the lifetime of + // the value is tied to the arena. + Value Clone(google::protobuf::Arena* absl_nonnull arena) const; + + friend void swap(Value& lhs, Value& rhs) noexcept { + using std::swap; + swap(lhs.variant_, rhs.variant_); + } + + friend std::ostream& operator<<(std::ostream& out, const Value& value); + + ABSL_DEPRECATED("Just use operator.()") + Value* operator->() { return this; } + + ABSL_DEPRECATED("Just use operator.()") + const Value* operator->() const { return this; } + + // Returns `true` if this value is an instance of a bool value. + bool IsBool() const { return variant_.Is(); } + + // Returns `true` if this value is an instance of a bool value and true. + bool IsTrue() const { return IsBool() && GetBool().NativeValue(); } + + // Returns `true` if this value is an instance of a bool value and false. + bool IsFalse() const { return IsBool() && !GetBool().NativeValue(); } + + // Returns `true` if this value is an instance of a bytes value. + bool IsBytes() const { return variant_.Is(); } + + // Returns `true` if this value is an instance of a double value. + bool IsDouble() const { return variant_.Is(); } + + // Returns `true` if this value is an instance of a duration value. + bool IsDuration() const { return variant_.Is(); } + + // Returns `true` if this value is an instance of an error value. + bool IsError() const { return variant_.Is(); } + + // Returns `true` if this value is an instance of an int value. + bool IsInt() const { return variant_.Is(); } + + // Returns `true` if this value is an instance of a list value. + bool IsList() const { + return variant_.Is() || + variant_.Is() || + variant_.Is() || + variant_.Is(); + } - /** Returns if the list contains the given value. */ - Value Contains(const Value& value) const; + // Returns `true` if this value is an instance of a map value. + bool IsMap() const { + return variant_.Is() || + variant_.Is() || + variant_.Is() || + variant_.Is(); + } + + // Returns `true` if this value is an instance of a message value. If `true` + // is returned, it is implied that `IsStruct()` would also return true. + bool IsMessage() const { return variant_.Is(); } + + // Returns `true` if this value is an instance of a null value. + bool IsNull() const { return variant_.Is(); } + + // Returns `true` if this value is an instance of an opaque value. + bool IsOpaque() const { return variant_.Is(); } + + // Returns `true` if this value is an instance of an optional value. If `true` + // is returned, it is implied that `IsOpaque()` would also return true. + bool IsOptional() const { + if (const auto* alternative = variant_.As(); + alternative != nullptr) { + return alternative->IsOptional(); + } + return false; + } + + // Returns `true` if this value is an instance of a parsed JSON list value. If + // `true` is returned, it is implied that `IsList()` would also return + // true. + bool IsParsedJsonList() const { return variant_.Is(); } + + // Returns `true` if this value is an instance of a parsed JSON map value. If + // `true` is returned, it is implied that `IsMap()` would also return + // true. + bool IsParsedJsonMap() const { return variant_.Is(); } + + // Returns `true` if this value is an instance of a custom list value. If + // `true` is returned, it is implied that `IsList()` would also return + // true. + bool IsCustomList() const { return variant_.Is(); } + + // Returns `true` if this value is an instance of a custom map value. If + // `true` is returned, it is implied that `IsMap()` would also return + // true. + bool IsCustomMap() const { return variant_.Is(); } + + // Returns `true` if this value is an instance of a parsed map field value. If + // `true` is returned, it is implied that `IsMap()` would also return + // true. + bool IsParsedMapField() const { return variant_.Is(); } + + // Returns `true` if this value is an instance of a parsed message value. If + // `true` is returned, it is implied that `IsMessage()` would also return + // true. + bool IsParsedMessage() const { return variant_.Is(); } + + // Returns `true` if this value is an instance of a parsed repeated field + // value. If `true` is returned, it is implied that `IsList()` would also + // return true. + bool IsParsedRepeatedField() const { + return variant_.Is(); + } - /** Returns the value at the given index, or an error. */ - virtual Value Get(std::size_t index) const = 0; + // Returns `true` if this value is an instance of a custom struct value. If + // `true` is returned, it is implied that `IsStruct()` would also return + // true. + bool IsCustomStruct() const { return variant_.Is(); } - /** - * Calls the provided function with every element in the list, in order. - * - * If the provided function returns an error, iteration is stopped and that - * error is returned to the caller immediately. - */ - virtual google::rpc::Status ForEach( - const std::function& call) const; + // Returns `true` if this value is an instance of a string value. + bool IsString() const { return variant_.Is(); } + + // Returns `true` if this value is an instance of a struct value. + bool IsStruct() const { + return variant_.Is() || + variant_.Is() || + variant_.Is(); + } + + // Returns `true` if this value is an instance of a timestamp value. + bool IsTimestamp() const { return variant_.Is(); } + + // Returns `true` if this value is an instance of a type value. + bool IsType() const { return variant_.Is(); } + + // Returns `true` if this value is an instance of a uint value. + bool IsUint() const { return variant_.Is(); } + + // Returns `true` if this value is an instance of an unknown value. + bool IsUnknown() const { return variant_.Is(); } + + // Convenience method for use with template metaprogramming. See + // `IsBool()`. + template + std::enable_if_t, bool> Is() const { + return IsBool(); + } - /** - * ForEach helper method for a boolean return type. - * - * Iteration stops when false is returned by `call`. - */ + // Convenience method for use with template metaprogramming. See + // `IsBytes()`. template - ForEachReturnType ForEach(T&& call) const; + std::enable_if_t, bool> Is() const { + return IsBytes(); + } - /** - * ForEach helper method for a void return type. - */ + // Convenience method for use with template metaprogramming. See + // `IsDouble()`. template - ForEachReturnType ForEach(T&& call) const; + std::enable_if_t, bool> Is() const { + return IsDouble(); + } - bool operator==(const List& rhs) const; - inline bool operator!=(const List& rhs) { return !(*this == rhs); } + // Convenience method for use with template metaprogramming. See + // `IsDuration()`. + template + std::enable_if_t, bool> Is() const { + return IsDuration(); + } - std::string ToString() const override; + // Convenience method for use with template metaprogramming. See + // `IsError()`. + template + std::enable_if_t, bool> Is() const { + return IsError(); + } - protected: - std::size_t ComputeHash() const override; - virtual Value ContainsImpl(const Value& value) const; -}; + // Convenience method for use with template metaprogramming. See + // `IsInt()`. + template + std::enable_if_t, bool> Is() const { + return IsInt(); + } -/** The base class for a CEL map value. */ -class Map : public Container { - public: - /** The number of elements in the map. */ - virtual std::size_t size() const = 0; - - /** Returns the value for the given key, or an error. */ - Value Get(const Value& key) const; - - /** Returns if the map contains the given key. */ - Value ContainsKey(const Value& key) const; - - /** Returns if the map contains the given value. */ - Value ContainsValue(const Value& value) const; - - /** - * Calls the provided function with every entry in the map. - * - * If the provided function returns an error, iteration is stopped and that - * error is returned to the caller immediately. - */ - virtual google::rpc::Status ForEach( - const std::function& - call) const = 0; - /** - * ForEach specialization for a boolean return type. - * - * Iteration stops if false is returned by `call`. - */ - template - ForEachReturnType ForEach( - T&& call) const; - - /** - * ForEach specialization for a void return type. - */ - template - ForEachReturnType ForEach( - T&& call) const; - - bool operator==(const Map& rhs) const; - inline bool operator!=(const Map& rhs) const { return !(*this == rhs); } - - /** - * Returns a cel expression for the value. - * - * Computation may be expensive. - */ - std::string ToString() const override; - - protected: - virtual Value ContainsKeyImpl(const Value& key) const; - virtual Value ContainsValueImpl(const Value& value) const; - - // An implementation is provided, but not efficient. - virtual Value GetImpl(const Value& key) const = 0; - - std::size_t ComputeHash() const override; -}; + // Convenience method for use with template metaprogramming. See + // `IsList()`. + template + std::enable_if_t, bool> Is() const { + return IsList(); + } -/** The base class for a CEL object value. */ -class Object : public Container { - public: - /** Returns the value of a field or an error. */ - virtual Value GetMember(absl::string_view name) const = 0; - virtual Value ContainsMember(absl::string_view name) const; - - /** They object type for this value. */ - virtual Type object_type() const = 0; - - /** Serialize the object to protobuf any. */ - virtual void To(google::protobuf::Any* value) const = 0; - - bool operator==(const Object& rhs) const; - inline bool operator!=(const Object& rhs) const { return !(*this == rhs); } - - /** - * Returns a canonical cel expression for the value. - * - * Computation may be expensive. - */ - std::string ToString() const override; - - /** - * Loop over members. - * - * Order of calls must be stable for ToString() to produce - * a deterministic result. - */ - virtual google::rpc::Status ForEach( - const std::function& - call) const = 0; - - /** - * ForEach specialization for a boolean return type. - * - * Iteration stops when false is returned by `call`. - */ - template - ForEachReturnType ForEach( - T&& call) const; - - /** - * ForEach specialization for a void return type. - */ - template - ForEachReturnType ForEach( - T&& call) const; - - protected: - /** - * The equal implementation. - * - * @param same_type a Object of the same type as the current Object. - */ - virtual bool EqualsImpl(const Object& same_type) const = 0; - std::size_t ComputeHash() const override; -}; + // Convenience method for use with template metaprogramming. See + // `IsMap()`. + template + std::enable_if_t, bool> Is() const { + return IsMap(); + } -Value Value::FromDouble(double value) { return Create(value); } -Value Value::FromString(absl::string_view value) { return Create(value); } -Value Value::FromString(const std::string& value) { - return Create(value); -} -Value Value::FromString(std::string&& value) { - return Create(std::move(value)); -} -Value Value::FromBytes(absl::string_view value) { - return Create(value); -} -Value Value::FromBytes(const std::string& value) { - return Create(value); -} -Value Value::FromBytes(std::string&& value) { - return Create(std::move(value)); -} -Value Value::FromDuration(absl::Duration value) { - return Create(value); -} -Value Value::FromTime(absl::Time value) { return Create(value); } + // Convenience method for use with template metaprogramming. See + // `IsMessage()`. + template + std::enable_if_t, bool> Is() const { + return IsMessage(); + } -Value Value::FromEnum(NamedEnumValue value) { - return Create(value); -} -Value Value::FromEnum(const UnnamedEnumValue& value) { - return Create(value); -} -Value Value::FromType(BasicType value) { return Create(value); } -Value Value::FromType(ObjectType value) { return Create(value); } -Value Value::FromType(EnumType value) { return Create(value); } -Value Value::FromType(BasicTypeValue value) { - return Create(value); -} -Value Value::FromType(const char* value) { - return FromType(absl::string_view(value)); -} -Value Value::FromType(absl::string_view full_name) { - return FromType(Type(full_name)); -} -Value Value::FromType(const std::string& full_name) { - return FromType(Type(full_name)); -} -Value Value::FromError(const google::rpc::Status& value) { - return Create(value); -} -Value Value::FromError(google::rpc::Status&& value) { - return Create(std::move(value)); -} -Value Value::FromError(const Error& value) { return Create(value); } -Value Value::FromError(Error&& value) { - return Create(std::move(value)); -} -Value Value::FromMap(std::unique_ptr value) { - assert(value != nullptr); - return Create(std::move(value)); -} -Value Value::FromList(std::unique_ptr value) { - assert(value != nullptr); - return Create(std::move(value)); -} -Value Value::FromObject(std::unique_ptr value) { - assert(value != nullptr); - return Create(std::move(value)); -} -Value Value::ForMap(const Map* value) { - assert(value != nullptr); - return Create(value); -} -Value Value::ForList(const List* value) { - assert(value != nullptr); - return Create(value); -} -Value Value::ForObject(const Object* value) { - assert(value != nullptr); - return Create(value); -} -Type Value::type_value() const { return get(); } -const Map& Value::map_value() const { return get(); } -const List& Value::list_value() const { return get(); } -EnumValue Value::enum_value() const { return get(); } -const Object& Value::object_value() const { return get(); } -absl::string_view Value::string_value() const { return GetStr(data_); } -absl::string_view Value::bytes_value() const { return GetStr(data_); } -const absl::Duration& Value::duration() const { return get(); } -const absl::Time& Value::time() const { return get(); } -const Error& Value::error_value() const { return get(); } -Unknown Value::unknown_value() const { return get(); } - -template -Value Value::MakeList(Args&&... args) { - return FromList(absl::make_unique(std::forward(args)...)); -} + // Convenience method for use with template metaprogramming. See + // `IsNull()`. + template + std::enable_if_t, bool> Is() const { + return IsNull(); + } -template -Value Value::MakeMap(Args&&... args) { - return FromMap(absl::make_unique(std::forward(args)...)); -} + // Convenience method for use with template metaprogramming. See + // `IsOpaque()`. + template + std::enable_if_t, bool> Is() const { + return IsOpaque(); + } -template -Value Value::MakeObject(Args&&... args) { - return FromObject(absl::make_unique(std::forward(args)...)); -} + // Convenience method for use with template metaprogramming. See + // `IsOptional()`. + template + std::enable_if_t, bool> Is() const { + return IsOptional(); + } -bool Value::is_value() const { - return data_.index() < kValueEnd && data_.index() != kId; -} + // Convenience method for use with template metaprogramming. See + // `IsParsedJsonList()`. + template + std::enable_if_t, bool> Is() const { + return IsParsedJsonList(); + } -#define EXPR_INTERNAL_KIND_HELPER_BASE(name, type) \ - static Value::GetKType get(const Value* value) { \ - return value->get(); \ - } \ - static Value::GetIfKType get_if(const Value* value) { \ - return value->get_if(); \ - } \ - template \ - static Value From(T&& value) { \ - return Value::From##name(std::forward(value)); \ - } - -#define EXPR_INTERNAL_KIND_HELPER(name, type) \ - template <> \ - struct Value::KindHelper { \ - EXPR_INTERNAL_KIND_HELPER_BASE(name, type) \ - template \ - static Value For(T* value, const ParentRef& parent) { \ - return Value::From##name(std::forward(*value)); \ - } \ - }; - -#define EXPR_INTERNAL_KIND_HELPER_WITH_FOR(name, type) \ - template <> \ - struct Value::KindHelper { \ - EXPR_INTERNAL_KIND_HELPER_BASE(name, type) \ - template \ - static Value For(T&& value, const ParentRef& parent) { \ - return Value::For##name(std::forward(value), parent); \ - } \ - }; - -EXPR_INTERNAL_KIND_HELPER(Bool, bool); -EXPR_INTERNAL_KIND_HELPER(Int, int64_t); -EXPR_INTERNAL_KIND_HELPER(UInt, uint64_t); -EXPR_INTERNAL_KIND_HELPER(Double, double); -EXPR_INTERNAL_KIND_HELPER(Type, Type); -EXPR_INTERNAL_KIND_HELPER_WITH_FOR(Map, Map); -EXPR_INTERNAL_KIND_HELPER_WITH_FOR(List, List); -EXPR_INTERNAL_KIND_HELPER_WITH_FOR(Object, Object); -EXPR_INTERNAL_KIND_HELPER(Enum, EnumValue); -EXPR_INTERNAL_KIND_HELPER(Duration, absl::Duration); -EXPR_INTERNAL_KIND_HELPER(Time, absl::Time); -EXPR_INTERNAL_KIND_HELPER(Error, Error); -EXPR_INTERNAL_KIND_HELPER(Unknown, Unknown); + // Convenience method for use with template metaprogramming. See + // `IsParsedJsonMap()`. + template + std::enable_if_t, bool> Is() const { + return IsParsedJsonMap(); + } -template <> -struct Value::KindHelper { - static Value::GetKType get(const Value* value) { - return value->get(); + // Convenience method for use with template metaprogramming. See + // `IsCustomList()`. + template + std::enable_if_t, bool> Is() const { + return IsCustomList(); } - static Value::GetIfKType get_if(const Value* value) { - return value->get_if(); + + // Convenience method for use with template metaprogramming. See + // `IsCustomMap()`. + template + std::enable_if_t, bool> Is() const { + return IsCustomMap(); } - static Value From(std::nullptr_t) { return Value::NullValue(); } + + // Convenience method for use with template metaprogramming. See + // `IsParsedMapField()`. template - static Value For(const std::nullptr_t* value, const ParentRef& parent) { - return Value::NullValue(); + std::enable_if_t, bool> Is() const { + return IsParsedMapField(); } -}; -template <> -struct Value::KindHelper { - static absl::string_view get(const Value* value) { - return value->string_value(); + // Convenience method for use with template metaprogramming. See + // `IsParsedMessage()`. + template + std::enable_if_t, bool> Is() const { + return IsParsedMessage(); } - static absl::optional get_if(const Value* value) { - if (value->kind() == Value::Kind::kString) return value->string_value(); - return absl::nullopt; + + // Convenience method for use with template metaprogramming. See + // `IsParsedRepeatedField()`. + template + std::enable_if_t, bool> Is() + const { + return IsParsedRepeatedField(); } + + // Convenience method for use with template metaprogramming. See + // `IsParsedStruct()`. template - static Value From(T&& value) { - return Value::FromString(std::forward(value)); + std::enable_if_t, bool> Is() const { + return IsCustomStruct(); } + + // Convenience method for use with template metaprogramming. See + // `IsString()`. template - static Value For(T&& value, const ParentRef& parent) { - return Value::ForString(std::forward(value), parent); + std::enable_if_t, bool> Is() const { + return IsString(); } -}; -template <> -struct Value::KindHelper { - static absl::string_view get(const Value* value) { - return value->bytes_value(); + // Convenience method for use with template metaprogramming. See + // `IsStruct()`. + template + std::enable_if_t, bool> Is() const { + return IsStruct(); } - static absl::optional get_if(const Value* value) { - if (value->kind() == Value::Kind::kBytes) return value->bytes_value(); - return absl::nullopt; + + // Convenience method for use with template metaprogramming. See + // `IsTimestamp()`. + template + std::enable_if_t, bool> Is() const { + return IsTimestamp(); } + + // Convenience method for use with template metaprogramming. See + // `IsType()`. template - static Value From(T&& value) { - return Value::FromBytes(std::forward(value)); + std::enable_if_t, bool> Is() const { + return IsType(); } + + // Convenience method for use with template metaprogramming. See + // `IsUint()`. template - static Value For(T&& value, const ParentRef& parent) { - return Value::ForBytes(std::forward(value), parent); + std::enable_if_t, bool> Is() const { + return IsUint(); } -}; -#undef EXPR_INTERNAL_KIND_HELPER_BASE -#undef EXPR_INTERNAL_KIND_HELPER -#undef EXPR_INTERNAL_KIND_HELPER_WITH_FOR + // Convenience method for use with template metaprogramming. See + // `IsUnknown()`. + template + std::enable_if_t, bool> Is() const { + return IsUnknown(); + } -template -Value::VisitType Value::visit(V&& vis, F&& value, R&&... rest) { - return absl::visit(AdaptedVisitor(std::forward(vis)), - std::forward(value).data_, - std::forward(rest).data_...); -} + // Performs a checked cast from a value to a bool value, + // returning a non-empty optional with either a value or reference to the + // bool value. Otherwise an empty optional is returned. + absl::optional AsBool() const { + if (const auto* alternative = variant_.As(); + alternative != nullptr) { + return *alternative; + } + return absl::nullopt; + } -/** Applies the visitor and returns the result. */ -template -Value::VisitType Value::visit(V&& vis) const { - return absl::visit(AdaptedVisitor(std::forward(vis)), data_); -} + // Performs a checked cast from a value to a bytes value, + // returning a non-empty optional with either a value or reference to the + // bytes value. Otherwise an empty optional is returned. + optional_ref AsBytes() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + return std::as_const(*this).AsBytes(); + } + optional_ref AsBytes() const& ABSL_ATTRIBUTE_LIFETIME_BOUND; + absl::optional AsBytes() &&; + absl::optional AsBytes() const&& { + return common_internal::AsOptional(AsBytes()); + } -template -inline internal::BaseValue::GetType Value::get() const { - return TypeHelper::get(data_); -} + // Performs a checked cast from a value to a double value, + // returning a non-empty optional with either a value or reference to the + // double value. Otherwise an empty optional is returned. + absl::optional AsDouble() const; + + // Performs a checked cast from a value to a duration value, + // returning a non-empty optional with either a value or reference to the + // duration value. Otherwise an empty optional is returned. + absl::optional AsDuration() const; + + // Performs a checked cast from a value to an error value, + // returning a non-empty optional with either a value or reference to the + // error value. Otherwise an empty optional is returned. + optional_ref AsError() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + return std::as_const(*this).AsError(); + } + optional_ref AsError() const& ABSL_ATTRIBUTE_LIFETIME_BOUND; + absl::optional AsError() &&; + absl::optional AsError() const&& { + return common_internal::AsOptional(AsError()); + } -template -inline Value::GetKType Value::get() const { - return KindHelper::get(this); -} + // Performs a checked cast from a value to an int value, + // returning a non-empty optional with either a value or reference to the + // int value. Otherwise an empty optional is returned. + absl::optional AsInt() const; + + // Performs a checked cast from a value to a list value, + // returning a non-empty optional with either a value or reference to the + // list value. Otherwise an empty optional is returned. + absl::optional AsList() & { return std::as_const(*this).AsList(); } + absl::optional AsList() const&; + absl::optional AsList() &&; + absl::optional AsList() const&& { + return common_internal::AsOptional(AsList()); + } -template -inline internal::BaseValue::GetIfType Value::get_if() const { - return TypeHelper::get_if(&data_); -} + // Performs a checked cast from a value to a map value, + // returning a non-empty optional with either a value or reference to the + // map value. Otherwise an empty optional is returned. + absl::optional AsMap() & { return std::as_const(*this).AsMap(); } + absl::optional AsMap() const&; + absl::optional AsMap() &&; + absl::optional AsMap() const&& { + return common_internal::AsOptional(AsMap()); + } -template -inline Value::GetIfKType Value::get_if() const { - return KindHelper::get_if(this); -} + // Performs a checked cast from a value to a message value, + // returning a non-empty optional with either a value or reference to the + // message value. Otherwise an empty optional is returned. + absl::optional AsMessage() & { + return std::as_const(*this).AsMessage(); + } + absl::optional AsMessage() const&; + absl::optional AsMessage() &&; + absl::optional AsMessage() const&& { + return common_internal::AsOptional(AsMessage()); + } -template -Value Value::Create(Args&&... args) { - return Value(absl::in_place_index_t(), std::forward(args)...); -} + // Performs a checked cast from a value to a null value, + // returning a non-empty optional with either a value or reference to the + // null value. Otherwise an empty optional is returned. + absl::optional AsNull() const; -inline bool Value::index_in(std::size_t start, std::size_t end) const { - return data_.index() >= start && data_.index() < end; -} + // Performs a checked cast from a value to an opaque value, + // returning a non-empty optional with either a value or reference to the + // opaque value. Otherwise an empty optional is returned. + optional_ref AsOpaque() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + return std::as_const(*this).AsOpaque(); + } + optional_ref AsOpaque() + const& ABSL_ATTRIBUTE_LIFETIME_BOUND; + absl::optional AsOpaque() &&; + absl::optional AsOpaque() const&& { + return common_internal::AsOptional(AsOpaque()); + } -template -Map::ForEachReturnType Map::ForEach( - T&& call) const { - return internal::IsOk(ForEach([&call](const Value& key, const Value& value) { - return call(key, value) ? internal::OkStatus() : internal::CancelledError(); - })); -} + // Performs a checked cast from a value to an optional value, + // returning a non-empty optional with either a value or reference to the + // optional value. Otherwise an empty optional is returned. + optional_ref AsOptional() & + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return std::as_const(*this).AsOptional(); + } + optional_ref AsOptional() + const& ABSL_ATTRIBUTE_LIFETIME_BOUND; + absl::optional AsOptional() &&; + absl::optional AsOptional() const&& { + return common_internal::AsOptional(AsOptional()); + } -template -Map::ForEachReturnType Map::ForEach( - T&& call) const { - ForEach([&call](const Value& key, const Value& value) { - call(key, value); - return internal::OkStatus(); - }); -} + // Performs a checked cast from a value to a parsed JSON list value, + // returning a non-empty optional with either a value or reference to the + // parsed message value. Otherwise an empty optional is returned. + optional_ref AsParsedJsonList() & + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return std::as_const(*this).AsParsedJsonList(); + } + optional_ref AsParsedJsonList() + const& ABSL_ATTRIBUTE_LIFETIME_BOUND; + absl::optional AsParsedJsonList() &&; + absl::optional AsParsedJsonList() const&& { + return common_internal::AsOptional(AsParsedJsonList()); + } -template -List::ForEachReturnType List::ForEach(T&& call) const { - return internal::IsOk(ForEach([&call](const Value& value) { - return call(value) ? internal::OkStatus() : internal::CancelledError(); - })); -} + // Performs a checked cast from a value to a parsed JSON map value, + // returning a non-empty optional with either a value or reference to the + // parsed message value. Otherwise an empty optional is returned. + optional_ref AsParsedJsonMap() & + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return std::as_const(*this).AsParsedJsonMap(); + } + optional_ref AsParsedJsonMap() + const& ABSL_ATTRIBUTE_LIFETIME_BOUND; + absl::optional AsParsedJsonMap() &&; + absl::optional AsParsedJsonMap() const&& { + return common_internal::AsOptional(AsParsedJsonMap()); + } -template -List::ForEachReturnType List::ForEach(T&& call) const { - ForEach([&call](const Value& value) { - call(value); - return internal::OkStatus(); - }); -} + // Performs a checked cast from a value to a custom list value, + // returning a non-empty optional with either a value or reference to the + // custom list value. Otherwise an empty optional is returned. + optional_ref AsCustomList() & + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return std::as_const(*this).AsCustomList(); + } + optional_ref AsCustomList() + const& ABSL_ATTRIBUTE_LIFETIME_BOUND; + absl::optional AsCustomList() &&; + absl::optional AsCustomList() const&& { + return common_internal::AsOptional(AsCustomList()); + } -template -Object::ForEachReturnType -Object::ForEach(T&& call) const { - return internal::IsOk( - ForEach([&call](absl::string_view name, const Value& value) { - return call(name, value) ? internal::OkStatus() - : internal::CancelledError(); - })); -} + // Performs a checked cast from a value to a custom map value, + // returning a non-empty optional with either a value or reference to the + // custom map value. Otherwise an empty optional is returned. + optional_ref AsCustomMap() & + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return std::as_const(*this).AsCustomMap(); + } + optional_ref AsCustomMap() + const& ABSL_ATTRIBUTE_LIFETIME_BOUND; + absl::optional AsCustomMap() &&; + absl::optional AsCustomMap() const&& { + return common_internal::AsOptional(AsCustomMap()); + } -template -Object::ForEachReturnType -Object::ForEach(T&& call) const { - ForEach([&call](absl::string_view name, const Value& value) { - call(name, value); - return internal::OkStatus(); - }); -} + // Performs a checked cast from a value to a parsed map field value, + // returning a non-empty optional with either a value or reference to the + // parsed map field value. Otherwise an empty optional is returned. + optional_ref AsParsedMapField() & + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return std::as_const(*this).AsParsedMapField(); + } + optional_ref AsParsedMapField() + const& ABSL_ATTRIBUTE_LIFETIME_BOUND; + absl::optional AsParsedMapField() &&; + absl::optional AsParsedMapField() const&& { + return common_internal::AsOptional(AsParsedMapField()); + } -// Overloads for printing. -inline std::ostream& operator<<(std::ostream& os, const Value& value) { - return os << value.ToString(); -} + // Performs a checked cast from a value to a parsed message value, + // returning a non-empty optional with either a value or reference to the + // parsed message value. Otherwise an empty optional is returned. + optional_ref AsParsedMessage() & + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return std::as_const(*this).AsParsedMessage(); + } + optional_ref AsParsedMessage() + const& ABSL_ATTRIBUTE_LIFETIME_BOUND; + absl::optional AsParsedMessage() &&; + absl::optional AsParsedMessage() const&& { + return common_internal::AsOptional(AsParsedMessage()); + } -inline std::ostream& operator<<(std::ostream& os, const Object& value) { - return os << value.ToString(); -} + // Performs a checked cast from a value to a parsed repeated field value, + // returning a non-empty optional with either a value or reference to the + // parsed repeated field value. Otherwise an empty optional is returned. + optional_ref AsParsedRepeatedField() & + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return std::as_const(*this).AsParsedRepeatedField(); + } + optional_ref AsParsedRepeatedField() + const& ABSL_ATTRIBUTE_LIFETIME_BOUND; + absl::optional AsParsedRepeatedField() &&; + absl::optional AsParsedRepeatedField() const&& { + return common_internal::AsOptional(AsParsedRepeatedField()); + } -inline std::ostream& operator<<(std::ostream& os, const List& value) { - return os << value.ToString(); -} -inline std::ostream& operator<<(std::ostream& os, const Map& value) { - return os << value.ToString(); -} + // Performs a checked cast from a value to a custom struct value, + // returning a non-empty optional with either a value or reference to the + // custom struct value. Otherwise an empty optional is returned. + optional_ref AsCustomStruct() & + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return std::as_const(*this).AsCustomStruct(); + } + optional_ref AsCustomStruct() + const& ABSL_ATTRIBUTE_LIFETIME_BOUND; + absl::optional AsCustomStruct() &&; + absl::optional AsCustomStruct() const&& { + return common_internal::AsOptional(AsCustomStruct()); + } + + // Performs a checked cast from a value to a string value, + // returning a non-empty optional with either a value or reference to the + // string value. Otherwise an empty optional is returned. + optional_ref AsString() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + return std::as_const(*this).AsString(); + } + optional_ref AsString() + const& ABSL_ATTRIBUTE_LIFETIME_BOUND; + absl::optional AsString() &&; + absl::optional AsString() const&& { + return common_internal::AsOptional(AsString()); + } + + // Performs a checked cast from a value to a struct value, + // returning a non-empty optional with either a value or reference to the + // struct value. Otherwise an empty optional is returned. + absl::optional AsStruct() & { + return std::as_const(*this).AsStruct(); + } + absl::optional AsStruct() const&; + absl::optional AsStruct() &&; + absl::optional AsStruct() const&& { + return common_internal::AsOptional(AsStruct()); + } -} // namespace common -} // namespace expr -} // namespace api -} // namespace google + // Performs a checked cast from a value to a timestamp value, + // returning a non-empty optional with either a value or reference to the + // timestamp value. Otherwise an empty optional is returned. + absl::optional AsTimestamp() const; -// Custom specialization of std::hash for Value. -namespace std { -template <> -struct hash { - typedef google::api::expr::common::Value argument_type; - typedef std::size_t result_type; - result_type operator()(argument_type const& value) const noexcept { - return value.hash_code(); + // Performs a checked cast from a value to a type value, + // returning a non-empty optional with either a value or reference to the + // type value. Otherwise an empty optional is returned. + optional_ref AsType() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + return std::as_const(*this).AsType(); } -}; -} // namespace std + optional_ref AsType() const& ABSL_ATTRIBUTE_LIFETIME_BOUND; + absl::optional AsType() &&; + absl::optional AsType() const&& { + return common_internal::AsOptional(AsType()); + } + + // Performs a checked cast from a value to an uint value, + // returning a non-empty optional with either a value or reference to the + // uint value. Otherwise an empty optional is returned. + absl::optional AsUint() const; + + // Performs a checked cast from a value to an unknown value, + // returning a non-empty optional with either a value or reference to the + // unknown value. Otherwise an empty optional is returned. + optional_ref AsUnknown() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + return std::as_const(*this).AsUnknown(); + } + optional_ref AsUnknown() + const& ABSL_ATTRIBUTE_LIFETIME_BOUND; + absl::optional AsUnknown() &&; + absl::optional AsUnknown() const&& { + return common_internal::AsOptional(AsUnknown()); + } + + // Convenience method for use with template metaprogramming. See + // `AsBool()`. + template + std::enable_if_t, absl::optional> + As() & { + return AsBool(); + } + template + std::enable_if_t, absl::optional> As() + const& { + return AsBool(); + } + template + std::enable_if_t, absl::optional> + As() && { + return AsBool(); + } + template + std::enable_if_t, absl::optional> As() + const&& { + return AsBool(); + } + + // Convenience method for use with template metaprogramming. See + // `AsBytes()`. + template + std::enable_if_t, + optional_ref> + As() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + return AsBytes(); + } + template + std::enable_if_t, + optional_ref> + As() const& ABSL_ATTRIBUTE_LIFETIME_BOUND { + return AsBytes(); + } + template + std::enable_if_t, absl::optional> + As() && { + return std::move(*this).AsBytes(); + } + template + std::enable_if_t, absl::optional> + As() const&& { + return std::move(*this).AsBytes(); + } + + // Convenience method for use with template metaprogramming. See + // `AsDouble()`. + template + std::enable_if_t, absl::optional> + As() & { + return AsDouble(); + } + template + std::enable_if_t, absl::optional> + As() const& { + return AsDouble(); + } + template + std::enable_if_t, absl::optional> + As() && { + return AsDouble(); + } + template + std::enable_if_t, absl::optional> + As() const&& { + return AsDouble(); + } + + // Convenience method for use with template metaprogramming. See + // `AsDuration()`. + template + std::enable_if_t, + absl::optional> + As() & { + return AsDuration(); + } + template + std::enable_if_t, + absl::optional> + As() const& { + return AsDuration(); + } + template + std::enable_if_t, + absl::optional> + As() && { + return AsDuration(); + } + template + std::enable_if_t, + absl::optional> + As() const&& { + return AsDuration(); + } + + // Convenience method for use with template metaprogramming. See + // `AsError()`. + template + std::enable_if_t, + optional_ref> + As() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + return AsError(); + } + template + std::enable_if_t, + optional_ref> + As() const& ABSL_ATTRIBUTE_LIFETIME_BOUND { + return AsError(); + } + template + std::enable_if_t, absl::optional> + As() && { + return std::move(*this).AsError(); + } + template + std::enable_if_t, absl::optional> + As() const&& { + return std::move(*this).AsError(); + } + + // Convenience method for use with template metaprogramming. See + // `AsInt()`. + template + std::enable_if_t, absl::optional> + As() & { + return AsInt(); + } + template + std::enable_if_t, absl::optional> As() + const& { + return AsInt(); + } + template + std::enable_if_t, absl::optional> + As() && { + return AsInt(); + } + template + std::enable_if_t, absl::optional> As() + const&& { + return AsInt(); + } + + // Convenience method for use with template metaprogramming. See + // `AsList()`. + template + std::enable_if_t, absl::optional> + As() & { + return AsList(); + } + template + std::enable_if_t, absl::optional> As() + const& { + return AsList(); + } + template + std::enable_if_t, absl::optional> + As() && { + return std::move(*this).AsList(); + } + template + std::enable_if_t, absl::optional> As() + const&& { + return std::move(*this).AsList(); + } + + // Convenience method for use with template metaprogramming. See + // `AsMap()`. + template + std::enable_if_t, absl::optional> + As() & { + return AsMap(); + } + template + std::enable_if_t, absl::optional> As() + const& { + return AsMap(); + } + template + std::enable_if_t, absl::optional> + As() && { + return std::move(*this).AsMap(); + } + template + std::enable_if_t, absl::optional> As() + const&& { + return std::move(*this).AsMap(); + } + + // Convenience method for use with template metaprogramming. See + // `AsMessage()`. + template + std::enable_if_t, + absl::optional> + As() & { + return AsMessage(); + } + template + std::enable_if_t, + absl::optional> + As() const& { + return AsMessage(); + } + template + std::enable_if_t, + absl::optional> + As() && { + return std::move(*this).AsMessage(); + } + template + std::enable_if_t, + absl::optional> + As() const&& { + return std::move(*this).AsMessage(); + } + + // Convenience method for use with template metaprogramming. See + // `AsNull()`. + template + std::enable_if_t, absl::optional> + As() & { + return AsNull(); + } + template + std::enable_if_t, absl::optional> As() + const& { + return AsNull(); + } + template + std::enable_if_t, absl::optional> + As() && { + return AsNull(); + } + template + std::enable_if_t, absl::optional> As() + const&& { + return AsNull(); + } + + // Convenience method for use with template metaprogramming. See + // `AsOpaque()`. + template + std::enable_if_t, + optional_ref> + As() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + return AsOpaque(); + } + template + std::enable_if_t, + optional_ref> + As() const& ABSL_ATTRIBUTE_LIFETIME_BOUND { + return AsOpaque(); + } + template + std::enable_if_t, absl::optional> + As() && { + return std::move(*this).AsOpaque(); + } + template + std::enable_if_t, absl::optional> + As() const&& { + return std::move(*this).AsOpaque(); + } + + // Convenience method for use with template metaprogramming. See + // `AsOptional()`. + template + std::enable_if_t, + optional_ref> + As() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + return AsOptional(); + } + template + std::enable_if_t, + optional_ref> + As() const& ABSL_ATTRIBUTE_LIFETIME_BOUND { + return AsOptional(); + } + template + std::enable_if_t, + absl::optional> + As() && { + return std::move(*this).AsOptional(); + } + template + std::enable_if_t, + absl::optional> + As() const&& { + return std::move(*this).AsOptional(); + } + + // Convenience method for use with template metaprogramming. See + // `AsParsedJsonList()`. + template + std::enable_if_t, + optional_ref> + As() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + return AsParsedJsonList(); + } + template + std::enable_if_t, + optional_ref> + As() const& ABSL_ATTRIBUTE_LIFETIME_BOUND { + return AsParsedJsonList(); + } + template + std::enable_if_t, + absl::optional> + As() && { + return std::move(*this).AsParsedJsonList(); + } + template + std::enable_if_t, + absl::optional> + As() const&& { + return std::move(*this).AsParsedJsonList(); + } + + // Convenience method for use with template metaprogramming. See + // `AsParsedJsonMap()`. + template + std::enable_if_t, + optional_ref> + As() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + return AsParsedJsonMap(); + } + template + std::enable_if_t, + optional_ref> + As() const& ABSL_ATTRIBUTE_LIFETIME_BOUND { + return AsParsedJsonMap(); + } + template + std::enable_if_t, + absl::optional> + As() && { + return std::move(*this).AsParsedJsonMap(); + } + template + std::enable_if_t, + absl::optional> + As() const&& { + return std::move(*this).AsParsedJsonMap(); + } + + // Convenience method for use with template metaprogramming. See + // `AsCustomList()`. + template + std::enable_if_t, + optional_ref> + As() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + return AsCustomList(); + } + template + std::enable_if_t, + optional_ref> + As() const& ABSL_ATTRIBUTE_LIFETIME_BOUND { + return AsCustomList(); + } + template + std::enable_if_t, + absl::optional> + As() && { + return std::move(*this).AsCustomList(); + } + template + std::enable_if_t, + absl::optional> + As() const&& { + return std::move(*this).AsCustomList(); + } + + // Convenience method for use with template metaprogramming. See + // `AsCustomMap()`. + template + std::enable_if_t, + optional_ref> + As() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + return AsCustomMap(); + } + template + std::enable_if_t, + optional_ref> + As() const& ABSL_ATTRIBUTE_LIFETIME_BOUND { + return AsCustomMap(); + } + template + std::enable_if_t, + absl::optional> + As() && { + return std::move(*this).AsCustomMap(); + } + template + std::enable_if_t, + absl::optional> + As() const&& { + return std::move(*this).AsCustomMap(); + } + + // Convenience method for use with template metaprogramming. See + // `AsParsedMapField()`. + template + std::enable_if_t, + optional_ref> + As() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + return AsParsedMapField(); + } + template + std::enable_if_t, + optional_ref> + As() const& ABSL_ATTRIBUTE_LIFETIME_BOUND { + return AsParsedMapField(); + } + template + std::enable_if_t, + absl::optional> + As() && { + return std::move(*this).AsParsedMapField(); + } + template + std::enable_if_t, + absl::optional> + As() const&& { + return std::move(*this).AsParsedMapField(); + } + + // Convenience method for use with template metaprogramming. See + // `AsParsedMessage()`. + template + std::enable_if_t, + optional_ref> + As() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + return AsParsedMessage(); + } + template + std::enable_if_t, + optional_ref> + As() const& ABSL_ATTRIBUTE_LIFETIME_BOUND { + return AsParsedMessage(); + } + template + std::enable_if_t, + absl::optional> + As() && { + return std::move(*this).AsParsedMessage(); + } + template + std::enable_if_t, + absl::optional> + As() const&& { + return std::move(*this).AsParsedMessage(); + } + + // Convenience method for use with template metaprogramming. See + // `AsParsedRepeatedField()`. + template + std::enable_if_t, + optional_ref> + As() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + return AsParsedRepeatedField(); + } + template + std::enable_if_t, + optional_ref> + As() const& ABSL_ATTRIBUTE_LIFETIME_BOUND { + return AsParsedRepeatedField(); + } + template + std::enable_if_t, + absl::optional> + As() && { + return std::move(*this).AsParsedRepeatedField(); + } + template + std::enable_if_t, + absl::optional> + As() const&& { + return std::move(*this).AsParsedRepeatedField(); + } + + // Convenience method for use with template metaprogramming. See + // `AsCustomStruct()`. + template + std::enable_if_t, + optional_ref> + As() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + return AsCustomStruct(); + } + template + std::enable_if_t, + optional_ref> + As() const& ABSL_ATTRIBUTE_LIFETIME_BOUND { + return AsCustomStruct(); + } + template + std::enable_if_t, + absl::optional> + As() && { + return std::move(*this).AsCustomStruct(); + } + template + std::enable_if_t, + absl::optional> + As() const&& { + return std::move(*this).AsCustomStruct(); + } + + // Convenience method for use with template metaprogramming. See + // `AsString()`. + template + std::enable_if_t, + optional_ref> + As() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + return AsString(); + } + template + std::enable_if_t, + optional_ref> + As() const& ABSL_ATTRIBUTE_LIFETIME_BOUND { + return AsString(); + } + template + std::enable_if_t, absl::optional> + As() && { + return std::move(*this).AsString(); + } + template + std::enable_if_t, absl::optional> + As() const&& { + return std::move(*this).AsString(); + } + + // Convenience method for use with template metaprogramming. See + // `AsStruct()`. + template + std::enable_if_t, absl::optional> + As() & { + return AsStruct(); + } + template + std::enable_if_t, absl::optional> + As() const& { + return AsStruct(); + } + template + std::enable_if_t, absl::optional> + As() && { + return std::move(*this).AsStruct(); + } + template + std::enable_if_t, absl::optional> + As() const&& { + return std::move(*this).AsStruct(); + } + + // Convenience method for use with template metaprogramming. See + // `AsTimestamp()`. + template + std::enable_if_t, + absl::optional> + As() & { + return AsTimestamp(); + } + template + std::enable_if_t, + absl::optional> + As() const& { + return AsTimestamp(); + } + template + std::enable_if_t, + absl::optional> + As() && { + return AsTimestamp(); + } + template + std::enable_if_t, + absl::optional> + As() const&& { + return AsTimestamp(); + } + + // Convenience method for use with template metaprogramming. See + // `AsType()`. + template + std::enable_if_t, + optional_ref> + As() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + return AsType(); + } + template + std::enable_if_t, optional_ref> + As() const& ABSL_ATTRIBUTE_LIFETIME_BOUND { + return AsType(); + } + template + std::enable_if_t, absl::optional> + As() && { + return std::move(*this).AsType(); + } + template + std::enable_if_t, absl::optional> As() + const&& { + return std::move(*this).AsType(); + } + + // Convenience method for use with template metaprogramming. See + // `AsUint()`. + template + std::enable_if_t, absl::optional> + As() & { + return AsUint(); + } + template + std::enable_if_t, absl::optional> As() + const& { + return AsUint(); + } + template + std::enable_if_t, absl::optional> + As() && { + return AsUint(); + } + template + std::enable_if_t, absl::optional> As() + const&& { + return AsUint(); + } + + // Convenience method for use with template metaprogramming. See + // `AsUnknown()`. + template + std::enable_if_t, + optional_ref> + As() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + return AsUnknown(); + } + template + std::enable_if_t, + optional_ref> + As() const& ABSL_ATTRIBUTE_LIFETIME_BOUND { + return AsUnknown(); + } + template + std::enable_if_t, + absl::optional> + As() && { + return std::move(*this).AsUnknown(); + } + template + std::enable_if_t, + absl::optional> + As() const&& { + return std::move(*this).AsUnknown(); + } + + // Performs an unchecked cast from a value to a bool value. In + // debug builds a best effort is made to crash. If `IsBool()` would return + // false, calling this method is undefined behavior. + BoolValue GetBool() const { + ABSL_DCHECK(IsBool()) << *this; + return variant_.Get(); + } + + // Performs an unchecked cast from a value to a bytes value. In + // debug builds a best effort is made to crash. If `IsBytes()` would return + // false, calling this method is undefined behavior. + const BytesValue& GetBytes() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + return std::as_const(*this).GetBytes(); + } + const BytesValue& GetBytes() const& ABSL_ATTRIBUTE_LIFETIME_BOUND; + BytesValue GetBytes() &&; + BytesValue GetBytes() const&& { return GetBytes(); } + + // Performs an unchecked cast from a value to a double value. In + // debug builds a best effort is made to crash. If `IsDouble()` would return + // false, calling this method is undefined behavior. + DoubleValue GetDouble() const; + + // Performs an unchecked cast from a value to a duration value. In + // debug builds a best effort is made to crash. If `IsDuration()` would return + // false, calling this method is undefined behavior. + DurationValue GetDuration() const; + + // Performs an unchecked cast from a value to an error value. In + // debug builds a best effort is made to crash. If `IsError()` would return + // false, calling this method is undefined behavior. + const ErrorValue& GetError() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + return std::as_const(*this).GetError(); + } + const ErrorValue& GetError() const& ABSL_ATTRIBUTE_LIFETIME_BOUND; + ErrorValue GetError() &&; + ErrorValue GetError() const&& { return GetError(); } + + // Performs an unchecked cast from a value to an int value. In + // debug builds a best effort is made to crash. If `IsInt()` would return + // false, calling this method is undefined behavior. + IntValue GetInt() const; + + // Performs an unchecked cast from a value to a list value. In + // debug builds a best effort is made to crash. If `IsList()` would return + // false, calling this method is undefined behavior. + ListValue GetList() & { return std::as_const(*this).GetList(); } + ListValue GetList() const&; + ListValue GetList() &&; + ListValue GetList() const&& { return GetList(); } + + // Performs an unchecked cast from a value to a map value. In + // debug builds a best effort is made to crash. If `IsMap()` would return + // false, calling this method is undefined behavior. + MapValue GetMap() & { return std::as_const(*this).GetMap(); } + MapValue GetMap() const&; + MapValue GetMap() &&; + MapValue GetMap() const&& { return GetMap(); } + + // Performs an unchecked cast from a value to a message value. In + // debug builds a best effort is made to crash. If `IsMessage()` would return + // false, calling this method is undefined behavior. + MessageValue GetMessage() & { return std::as_const(*this).GetMessage(); } + MessageValue GetMessage() const&; + MessageValue GetMessage() &&; + MessageValue GetMessage() const&& { return GetMessage(); } + + // Performs an unchecked cast from a value to a null value. In + // debug builds a best effort is made to crash. If `IsNull()` would return + // false, calling this method is undefined behavior. + NullValue GetNull() const; + + // Performs an unchecked cast from a value to an opaque value. In + // debug builds a best effort is made to crash. If `IsOpaque()` would return + // false, calling this method is undefined behavior. + const OpaqueValue& GetOpaque() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + return std::as_const(*this).GetOpaque(); + } + const OpaqueValue& GetOpaque() const& ABSL_ATTRIBUTE_LIFETIME_BOUND; + OpaqueValue GetOpaque() &&; + OpaqueValue GetOpaque() const&& { return GetOpaque(); } + + // Performs an unchecked cast from a value to an optional value. In + // debug builds a best effort is made to crash. If `IsOptional()` would return + // false, calling this method is undefined behavior. + const OptionalValue& GetOptional() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + return std::as_const(*this).GetOptional(); + } + const OptionalValue& GetOptional() const& ABSL_ATTRIBUTE_LIFETIME_BOUND; + OptionalValue GetOptional() &&; + OptionalValue GetOptional() const&& { return GetOptional(); } + + // Performs an unchecked cast from a value to a parsed message value. In + // debug builds a best effort is made to crash. If `IsParsedJsonList()` would + // return false, calling this method is undefined behavior. + const ParsedJsonListValue& GetParsedJsonList() & + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return std::as_const(*this).GetParsedJsonList(); + } + const ParsedJsonListValue& GetParsedJsonList() + const& ABSL_ATTRIBUTE_LIFETIME_BOUND; + ParsedJsonListValue GetParsedJsonList() &&; + ParsedJsonListValue GetParsedJsonList() const&& { + return GetParsedJsonList(); + } + + // Performs an unchecked cast from a value to a parsed message value. In + // debug builds a best effort is made to crash. If `IsParsedJsonMap()` would + // return false, calling this method is undefined behavior. + const ParsedJsonMapValue& GetParsedJsonMap() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + return std::as_const(*this).GetParsedJsonMap(); + } + const ParsedJsonMapValue& GetParsedJsonMap() + const& ABSL_ATTRIBUTE_LIFETIME_BOUND; + ParsedJsonMapValue GetParsedJsonMap() &&; + ParsedJsonMapValue GetParsedJsonMap() const&& { return GetParsedJsonMap(); } + + // Performs an unchecked cast from a value to a custom list value. In + // debug builds a best effort is made to crash. If `IsCustomList()` would + // return false, calling this method is undefined behavior. + const CustomListValue& GetCustomList() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + return std::as_const(*this).GetCustomList(); + } + const CustomListValue& GetCustomList() const& ABSL_ATTRIBUTE_LIFETIME_BOUND; + CustomListValue GetCustomList() &&; + CustomListValue GetCustomList() const&& { return GetCustomList(); } + + // Performs an unchecked cast from a value to a custom map value. In + // debug builds a best effort is made to crash. If `IsCustomMap()` would + // return false, calling this method is undefined behavior. + const CustomMapValue& GetCustomMap() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + return std::as_const(*this).GetCustomMap(); + } + const CustomMapValue& GetCustomMap() const& ABSL_ATTRIBUTE_LIFETIME_BOUND; + CustomMapValue GetCustomMap() &&; + CustomMapValue GetCustomMap() const&& { return GetCustomMap(); } + + // Performs an unchecked cast from a value to a parsed map field value. In + // debug builds a best effort is made to crash. If `IsParsedMapField()` would + // return false, calling this method is undefined behavior. + const ParsedMapFieldValue& GetParsedMapField() & + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return std::as_const(*this).GetParsedMapField(); + } + const ParsedMapFieldValue& GetParsedMapField() + const& ABSL_ATTRIBUTE_LIFETIME_BOUND; + ParsedMapFieldValue GetParsedMapField() &&; + ParsedMapFieldValue GetParsedMapField() const&& { + return GetParsedMapField(); + } + + // Performs an unchecked cast from a value to a parsed message value. In + // debug builds a best effort is made to crash. If `IsParsedMessage()` would + // return false, calling this method is undefined behavior. + const ParsedMessageValue& GetParsedMessage() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + return std::as_const(*this).GetParsedMessage(); + } + const ParsedMessageValue& GetParsedMessage() + const& ABSL_ATTRIBUTE_LIFETIME_BOUND; + ParsedMessageValue GetParsedMessage() &&; + ParsedMessageValue GetParsedMessage() const&& { return GetParsedMessage(); } + + // Performs an unchecked cast from a value to a parsed repeated field value. + // In debug builds a best effort is made to crash. If + // `IsParsedRepeatedField()` would return false, calling this method is + // undefined behavior. + const ParsedRepeatedFieldValue& GetParsedRepeatedField() & + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return std::as_const(*this).GetParsedRepeatedField(); + } + const ParsedRepeatedFieldValue& GetParsedRepeatedField() + const& ABSL_ATTRIBUTE_LIFETIME_BOUND; + ParsedRepeatedFieldValue GetParsedRepeatedField() &&; + ParsedRepeatedFieldValue GetParsedRepeatedField() const&& { + return GetParsedRepeatedField(); + } + + // Performs an unchecked cast from a value to a custom struct value. In + // debug builds a best effort is made to crash. If `IsCustomStruct()` would + // return false, calling this method is undefined behavior. + const CustomStructValue& GetCustomStruct() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + return std::as_const(*this).GetCustomStruct(); + } + const CustomStructValue& GetCustomStruct() + const& ABSL_ATTRIBUTE_LIFETIME_BOUND; + CustomStructValue GetCustomStruct() &&; + CustomStructValue GetCustomStruct() const&& { return GetCustomStruct(); } + + // Performs an unchecked cast from a value to a string value. In + // debug builds a best effort is made to crash. If `IsString()` would return + // false, calling this method is undefined behavior. + const StringValue& GetString() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + return std::as_const(*this).GetString(); + } + const StringValue& GetString() const& ABSL_ATTRIBUTE_LIFETIME_BOUND; + StringValue GetString() &&; + StringValue GetString() const&& { return GetString(); } + + // Performs an unchecked cast from a value to a struct value. In + // debug builds a best effort is made to crash. If `IsStruct()` would return + // false, calling this method is undefined behavior. + StructValue GetStruct() & { return std::as_const(*this).GetStruct(); } + StructValue GetStruct() const&; + StructValue GetStruct() &&; + StructValue GetStruct() const&& { return GetStruct(); } + + // Performs an unchecked cast from a value to a timestamp value. In + // debug builds a best effort is made to crash. If `IsTimestamp()` would + // return false, calling this method is undefined behavior. + TimestampValue GetTimestamp() const; + + // Performs an unchecked cast from a value to a type value. In + // debug builds a best effort is made to crash. If `IsType()` would return + // false, calling this method is undefined behavior. + const TypeValue& GetType() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + return std::as_const(*this).GetType(); + } + const TypeValue& GetType() const& ABSL_ATTRIBUTE_LIFETIME_BOUND; + TypeValue GetType() &&; + TypeValue GetType() const&& { return GetType(); } + + // Performs an unchecked cast from a value to an uint value. In + // debug builds a best effort is made to crash. If `IsUint()` would return + // false, calling this method is undefined behavior. + UintValue GetUint() const; + + // Performs an unchecked cast from a value to an unknown value. In + // debug builds a best effort is made to crash. If `IsUnknown()` would return + // false, calling this method is undefined behavior. + const UnknownValue& GetUnknown() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + return std::as_const(*this).GetUnknown(); + } + const UnknownValue& GetUnknown() const& ABSL_ATTRIBUTE_LIFETIME_BOUND; + UnknownValue GetUnknown() &&; + UnknownValue GetUnknown() const&& { return GetUnknown(); } + + // Convenience method for use with template metaprogramming. See + // `GetBool()`. + template + std::enable_if_t, BoolValue> Get() & { + return GetBool(); + } + template + std::enable_if_t, BoolValue> Get() const& { + return GetBool(); + } + template + std::enable_if_t, BoolValue> Get() && { + return GetBool(); + } + template + std::enable_if_t, BoolValue> Get() const&& { + return GetBool(); + } + + // Convenience method for use with template metaprogramming. See + // `GetBytes()`. + template + std::enable_if_t, const BytesValue&> Get() & + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return GetBytes(); + } + template + std::enable_if_t, const BytesValue&> Get() + const& ABSL_ATTRIBUTE_LIFETIME_BOUND { + return GetBytes(); + } + template + std::enable_if_t, BytesValue> Get() && { + return std::move(*this).GetBytes(); + } + template + std::enable_if_t, BytesValue> Get() const&& { + return std::move(*this).GetBytes(); + } + + // Convenience method for use with template metaprogramming. See + // `GetDouble()`. + template + std::enable_if_t, DoubleValue> Get() & { + return GetDouble(); + } + template + std::enable_if_t, DoubleValue> Get() const& { + return GetDouble(); + } + template + std::enable_if_t, DoubleValue> Get() && { + return GetDouble(); + } + template + std::enable_if_t, DoubleValue> Get() const&& { + return GetDouble(); + } + + // Convenience method for use with template metaprogramming. See + // `GetDuration()`. + template + std::enable_if_t, DurationValue> Get() & { + return GetDuration(); + } + template + std::enable_if_t, DurationValue> Get() + const& { + return GetDuration(); + } + template + std::enable_if_t, DurationValue> Get() && { + return GetDuration(); + } + template + std::enable_if_t, DurationValue> Get() + const&& { + return GetDuration(); + } + + // Convenience method for use with template metaprogramming. See + // `GetError()`. + template + std::enable_if_t, const ErrorValue&> Get() & + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return GetError(); + } + template + std::enable_if_t, const ErrorValue&> Get() + const& ABSL_ATTRIBUTE_LIFETIME_BOUND { + return GetError(); + } + template + std::enable_if_t, ErrorValue> Get() && { + return std::move(*this).GetError(); + } + template + std::enable_if_t, ErrorValue> Get() const&& { + return std::move(*this).GetError(); + } + + // Convenience method for use with template metaprogramming. See + // `GetInt()`. + template + std::enable_if_t, IntValue> Get() & { + return GetInt(); + } + template + std::enable_if_t, IntValue> Get() const& { + return GetInt(); + } + template + std::enable_if_t, IntValue> Get() && { + return GetInt(); + } + template + std::enable_if_t, IntValue> Get() const&& { + return GetInt(); + } + + // Convenience method for use with template metaprogramming. See + // `GetList()`. + template + std::enable_if_t, ListValue> Get() & { + return GetList(); + } + template + std::enable_if_t, ListValue> Get() const& { + return GetList(); + } + template + std::enable_if_t, ListValue> Get() && { + return std::move(*this).GetList(); + } + template + std::enable_if_t, ListValue> Get() const&& { + return std::move(*this).GetList(); + } + + // Convenience method for use with template metaprogramming. See + // `GetMap()`. + template + std::enable_if_t, MapValue> Get() & { + return GetMap(); + } + template + std::enable_if_t, MapValue> Get() const& { + return GetMap(); + } + template + std::enable_if_t, MapValue> Get() && { + return std::move(*this).GetMap(); + } + template + std::enable_if_t, MapValue> Get() const&& { + return std::move(*this).GetMap(); + } + + // Convenience method for use with template metaprogramming. See + // `GetMessage()`. + template + std::enable_if_t, MessageValue> Get() & { + return GetMessage(); + } + template + std::enable_if_t, MessageValue> Get() const& { + return GetMessage(); + } + template + std::enable_if_t, MessageValue> Get() && { + return std::move(*this).GetMessage(); + } + template + std::enable_if_t, MessageValue> Get() + const&& { + return std::move(*this).GetMessage(); + } + + // Convenience method for use with template metaprogramming. See + // `GetNull()`. + template + std::enable_if_t, NullValue> Get() & { + return GetNull(); + } + template + std::enable_if_t, NullValue> Get() const& { + return GetNull(); + } + template + std::enable_if_t, NullValue> Get() && { + return GetNull(); + } + template + std::enable_if_t, NullValue> Get() const&& { + return GetNull(); + } + + // Convenience method for use with template metaprogramming. See + // `GetOpaque()`. + template + std::enable_if_t, const OpaqueValue&> + Get() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + return GetOpaque(); + } + template + std::enable_if_t, const OpaqueValue&> Get() + const& ABSL_ATTRIBUTE_LIFETIME_BOUND { + return GetOpaque(); + } + template + std::enable_if_t, OpaqueValue> Get() && { + return std::move(*this).GetOpaque(); + } + template + std::enable_if_t, OpaqueValue> Get() const&& { + return std::move(*this).GetOpaque(); + } + + // Convenience method for use with template metaprogramming. See + // `GetOptional()`. + template + std::enable_if_t, const OptionalValue&> + Get() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + return GetOptional(); + } + template + std::enable_if_t, const OptionalValue&> Get() + const& ABSL_ATTRIBUTE_LIFETIME_BOUND { + return GetOptional(); + } + template + std::enable_if_t, OptionalValue> Get() && { + return std::move(*this).GetOptional(); + } + template + std::enable_if_t, OptionalValue> Get() + const&& { + return std::move(*this).GetOptional(); + } + + // Convenience method for use with template metaprogramming. See + // `GetParsedJsonList()`. + template + std::enable_if_t, + const ParsedJsonListValue&> + Get() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + return GetParsedJsonList(); + } + template + std::enable_if_t, + const ParsedJsonListValue&> + Get() const& ABSL_ATTRIBUTE_LIFETIME_BOUND { + return GetParsedJsonList(); + } + template + std::enable_if_t, ParsedJsonListValue> + Get() && { + return std::move(*this).GetParsedJsonList(); + } + template + std::enable_if_t, ParsedJsonListValue> + Get() const&& { + return std::move(*this).GetParsedJsonList(); + } + + // Convenience method for use with template metaprogramming. See + // `GetParsedJsonMap()`. + template + std::enable_if_t, + const ParsedJsonMapValue&> + Get() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + return GetParsedJsonMap(); + } + template + std::enable_if_t, + const ParsedJsonMapValue&> + Get() const& ABSL_ATTRIBUTE_LIFETIME_BOUND { + return GetParsedJsonMap(); + } + template + std::enable_if_t, ParsedJsonMapValue> + Get() && { + return std::move(*this).GetParsedJsonMap(); + } + template + std::enable_if_t, ParsedJsonMapValue> + Get() const&& { + return std::move(*this).GetParsedJsonMap(); + } + + // Convenience method for use with template metaprogramming. See + // `GetCustomList()`. + template + std::enable_if_t, + const CustomListValue&> + Get() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + return GetCustomList(); + } + template + std::enable_if_t, const CustomListValue&> + Get() const& ABSL_ATTRIBUTE_LIFETIME_BOUND { + return GetCustomList(); + } + template + std::enable_if_t, CustomListValue> + Get() && { + return std::move(*this).GetCustomList(); + } + template + std::enable_if_t, CustomListValue> Get() + const&& { + return std::move(*this).GetCustomList(); + } + + // Convenience method for use with template metaprogramming. See + // `GetCustomMap()`. + template + std::enable_if_t, const CustomMapValue&> + Get() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + return GetCustomMap(); + } + template + std::enable_if_t, const CustomMapValue&> + Get() const& ABSL_ATTRIBUTE_LIFETIME_BOUND { + return GetCustomMap(); + } + template + std::enable_if_t, CustomMapValue> Get() && { + return std::move(*this).GetCustomMap(); + } + template + std::enable_if_t, CustomMapValue> Get() + const&& { + return std::move(*this).GetCustomMap(); + } + + // Convenience method for use with template metaprogramming. See + // `GetParsedMapField()`. + template + std::enable_if_t, + const ParsedMapFieldValue&> + Get() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + return GetParsedMapField(); + } + template + std::enable_if_t, + const ParsedMapFieldValue&> + Get() const& ABSL_ATTRIBUTE_LIFETIME_BOUND { + return GetParsedMapField(); + } + template + std::enable_if_t, ParsedMapFieldValue> + Get() && { + return std::move(*this).GetParsedMapField(); + } + template + std::enable_if_t, ParsedMapFieldValue> + Get() const&& { + return std::move(*this).GetParsedMapField(); + } + + // Convenience method for use with template metaprogramming. See + // `GetParsedMessage()`. + template + std::enable_if_t, + const ParsedMessageValue&> + Get() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + return GetParsedMessage(); + } + template + std::enable_if_t, + const ParsedMessageValue&> + Get() const& ABSL_ATTRIBUTE_LIFETIME_BOUND { + return GetParsedMessage(); + } + template + std::enable_if_t, ParsedMessageValue> + Get() && { + return std::move(*this).GetParsedMessage(); + } + template + std::enable_if_t, ParsedMessageValue> + Get() const&& { + return std::move(*this).GetParsedMessage(); + } + + // Convenience method for use with template metaprogramming. See + // `GetParsedRepeatedField()`. + template + std::enable_if_t, + const ParsedRepeatedFieldValue&> + Get() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + return GetParsedRepeatedField(); + } + template + std::enable_if_t, + const ParsedRepeatedFieldValue&> + Get() const& ABSL_ATTRIBUTE_LIFETIME_BOUND { + return GetParsedRepeatedField(); + } + template + std::enable_if_t, + ParsedRepeatedFieldValue> + Get() && { + return std::move(*this).GetParsedRepeatedField(); + } + template + std::enable_if_t, + ParsedRepeatedFieldValue> + Get() const&& { + return std::move(*this).GetParsedRepeatedField(); + } + + // Convenience method for use with template metaprogramming. See + // `GetCustomStruct()`. + template + std::enable_if_t, + const CustomStructValue&> + Get() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + return GetCustomStruct(); + } + template + std::enable_if_t, + const CustomStructValue&> + Get() const& ABSL_ATTRIBUTE_LIFETIME_BOUND { + return GetCustomStruct(); + } + template + std::enable_if_t, CustomStructValue> + Get() && { + return std::move(*this).GetCustomStruct(); + } + template + std::enable_if_t, CustomStructValue> + Get() const&& { + return std::move(*this).GetCustomStruct(); + } + + // Convenience method for use with template metaprogramming. See + // `GetString()`. + template + std::enable_if_t, const StringValue&> + Get() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + return GetString(); + } + template + std::enable_if_t, const StringValue&> Get() + const& ABSL_ATTRIBUTE_LIFETIME_BOUND { + return GetString(); + } + template + std::enable_if_t, StringValue> Get() && { + return std::move(*this).GetString(); + } + template + std::enable_if_t, StringValue> Get() const&& { + return std::move(*this).GetString(); + } + + // Convenience method for use with template metaprogramming. See + // `GetStruct()`. + template + std::enable_if_t, StructValue> Get() & { + return GetStruct(); + } + template + std::enable_if_t, StructValue> Get() const& { + return GetStruct(); + } + template + std::enable_if_t, StructValue> Get() && { + return std::move(*this).GetStruct(); + } + template + std::enable_if_t, StructValue> Get() const&& { + return std::move(*this).GetStruct(); + } + + // Convenience method for use with template metaprogramming. See + // `GetTimestamp()`. + template + std::enable_if_t, TimestampValue> Get() & { + return GetTimestamp(); + } + template + std::enable_if_t, TimestampValue> Get() + const& { + return GetTimestamp(); + } + template + std::enable_if_t, TimestampValue> Get() && { + return GetTimestamp(); + } + template + std::enable_if_t, TimestampValue> Get() + const&& { + return GetTimestamp(); + } + + // Convenience method for use with template metaprogramming. See + // `GetType()`. + template + std::enable_if_t, const TypeValue&> Get() & + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return GetType(); + } + template + std::enable_if_t, const TypeValue&> Get() + const& ABSL_ATTRIBUTE_LIFETIME_BOUND { + return GetType(); + } + template + std::enable_if_t, TypeValue> Get() && { + return std::move(*this).GetType(); + } + template + std::enable_if_t, TypeValue> Get() const&& { + return std::move(*this).GetType(); + } + + // Convenience method for use with template metaprogramming. See + // `GetUint()`. + template + std::enable_if_t, UintValue> Get() & { + return GetUint(); + } + template + std::enable_if_t, UintValue> Get() const& { + return GetUint(); + } + template + std::enable_if_t, UintValue> Get() && { + return GetUint(); + } + template + std::enable_if_t, UintValue> Get() const&& { + return GetUint(); + } + + // Convenience method for use with template metaprogramming. See + // `GetUnknown()`. + template + std::enable_if_t, const UnknownValue&> + Get() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + return GetUnknown(); + } + template + std::enable_if_t, const UnknownValue&> Get() + const& ABSL_ATTRIBUTE_LIFETIME_BOUND { + return GetUnknown(); + } + template + std::enable_if_t, UnknownValue> Get() && { + return std::move(*this).GetUnknown(); + } + template + std::enable_if_t, UnknownValue> Get() + const&& { + return std::move(*this).GetUnknown(); + } + + // When `Value` is default constructed, it is in a valid but undefined state. + // Any attempt to use it invokes undefined behavior. This mention can be used + // to test whether this value is valid. + explicit operator bool() const { return true; } + + private: + friend struct NativeTypeTraits; + friend bool common_internal::IsLegacyListValue(const Value& value); + friend common_internal::LegacyListValue common_internal::GetLegacyListValue( + const Value& value); + friend bool common_internal::IsLegacyMapValue(const Value& value); + friend common_internal::LegacyMapValue common_internal::GetLegacyMapValue( + const Value& value); + friend bool common_internal::IsLegacyStructValue(const Value& value); + friend common_internal::LegacyStructValue + common_internal::GetLegacyStructValue(const Value& value); + friend class common_internal::ValueMixin; + friend struct ArenaTraits; + + common_internal::ValueVariant variant_; +}; + +// Overloads for heterogeneous equality of numeric values. +bool operator==(IntValue lhs, UintValue rhs); +bool operator==(UintValue lhs, IntValue rhs); +bool operator==(IntValue lhs, DoubleValue rhs); +bool operator==(DoubleValue lhs, IntValue rhs); +bool operator==(UintValue lhs, DoubleValue rhs); +bool operator==(DoubleValue lhs, UintValue rhs); +inline bool operator!=(IntValue lhs, UintValue rhs) { + return !operator==(lhs, rhs); +} +inline bool operator!=(UintValue lhs, IntValue rhs) { + return !operator==(lhs, rhs); +} +inline bool operator!=(IntValue lhs, DoubleValue rhs) { + return !operator==(lhs, rhs); +} +inline bool operator!=(DoubleValue lhs, IntValue rhs) { + return !operator==(lhs, rhs); +} +inline bool operator!=(UintValue lhs, DoubleValue rhs) { + return !operator==(lhs, rhs); +} +inline bool operator!=(DoubleValue lhs, UintValue rhs) { + return !operator==(lhs, rhs); +} + +template <> +struct NativeTypeTraits final { + static NativeTypeId Id(const Value& value) { + return value.variant_.Visit([](const auto& alternative) -> NativeTypeId { + return NativeTypeId::Of(alternative); + }); + } +}; + +template <> +struct ArenaTraits { + static bool trivially_destructible(const Value& value) { + return value.variant_.Visit([](const auto& alternative) -> bool { + return ArenaTraits<>::trivially_destructible(alternative); + }); + } +}; + +// Statically assert some expectations. +static_assert(sizeof(Value) <= 32); +static_assert(alignof(Value) <= alignof(std::max_align_t)); +static_assert(std::is_default_constructible_v); +static_assert(std::is_copy_constructible_v); +static_assert(std::is_copy_assignable_v); +static_assert(std::is_nothrow_move_constructible_v); +static_assert(std::is_nothrow_move_assignable_v); +static_assert(std::is_nothrow_swappable_v); + +inline common_internal::ImplicitlyConvertibleStatus +ErrorValueAssign::operator()(absl::Status status) const { + *value_ = ErrorValue(std::move(status)); + return common_internal::ImplicitlyConvertibleStatus(); +} + +namespace common_internal { + +template +absl::StatusOr ValueMixin::Equal( + const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + + Value result; + CEL_RETURN_IF_ERROR(static_cast(this)->Equal( + other, descriptor_pool, message_factory, arena, &result)); + return result; +} + +template +absl::StatusOr ListValueMixin::Get( + size_t index, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + + Value result; + CEL_RETURN_IF_ERROR(static_cast(this)->Get( + index, descriptor_pool, message_factory, arena, &result)); + return result; +} + +template +absl::StatusOr ListValueMixin::Contains( + const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + + Value result; + CEL_RETURN_IF_ERROR(static_cast(this)->Contains( + other, descriptor_pool, message_factory, arena, &result)); + return result; +} + +template +absl::StatusOr MapValueMixin::Get( + const Value& key, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + + Value result; + CEL_RETURN_IF_ERROR(static_cast(this)->Get( + key, descriptor_pool, message_factory, arena, &result)); + return result; +} + +template +absl::StatusOr> MapValueMixin::Find( + const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + + Value result; + CEL_ASSIGN_OR_RETURN( + bool found, static_cast(this)->Find( + other, descriptor_pool, message_factory, arena, &result)); + if (found) { + return result; + } + return absl::nullopt; +} + +template +absl::StatusOr MapValueMixin::Has( + const Value& key, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + + Value result; + CEL_RETURN_IF_ERROR(static_cast(this)->Has( + key, descriptor_pool, message_factory, arena, &result)); + return result; +} + +template +absl::StatusOr MapValueMixin::ListKeys( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + + ListValue result; + CEL_RETURN_IF_ERROR(static_cast(this)->ListKeys( + descriptor_pool, message_factory, arena, &result)); + return result; +} + +template +absl::StatusOr StructValueMixin::GetFieldByName( + absl::string_view name, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + + Value result; + CEL_RETURN_IF_ERROR(static_cast(this)->GetFieldByName( + name, ProtoWrapperTypeOptions::kUnsetNull, descriptor_pool, + message_factory, arena, &result)); + return result; +} + +template +absl::StatusOr StructValueMixin::GetFieldByName( + absl::string_view name, ProtoWrapperTypeOptions unboxing_options, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + + Value result; + CEL_RETURN_IF_ERROR(static_cast(this)->GetFieldByName( + name, unboxing_options, descriptor_pool, message_factory, arena, + &result)); + return result; +} + +template +absl::StatusOr StructValueMixin::GetFieldByNumber( + int64_t number, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + + Value result; + CEL_RETURN_IF_ERROR(static_cast(this)->GetFieldByNumber( + number, ProtoWrapperTypeOptions::kUnsetNull, descriptor_pool, + message_factory, arena, &result)); + return result; +} + +template +absl::StatusOr StructValueMixin::GetFieldByNumber( + int64_t number, ProtoWrapperTypeOptions unboxing_options, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + + Value result; + CEL_RETURN_IF_ERROR(static_cast(this)->GetFieldByNumber( + number, unboxing_options, descriptor_pool, message_factory, arena, + &result)); + return result; +} + +template +absl::StatusOr> StructValueMixin::Qualify( + absl::Span qualifiers, bool presence_test, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const { + ABSL_DCHECK_GT(qualifiers.size(), 0); + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + + Value result; + int count; + CEL_RETURN_IF_ERROR(static_cast(this)->Qualify( + qualifiers, presence_test, descriptor_pool, message_factory, arena, + &result, &count)); + return std::pair{std::move(result), count}; +} + +} // namespace common_internal + +using ValueIteratorPtr = std::unique_ptr; + +inline absl::StatusOr ValueIterator::Next( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + + Value result; + CEL_RETURN_IF_ERROR(Next(descriptor_pool, message_factory, arena, &result)); + return result; +} + +inline absl::StatusOr> ValueIterator::Next1( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + + Value key_or_value; + CEL_ASSIGN_OR_RETURN( + bool ok, Next1(descriptor_pool, message_factory, arena, &key_or_value)); + if (!ok) { + return absl::nullopt; + } + return key_or_value; +} + +inline absl::StatusOr>> +ValueIterator::Next2(const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + + Value key; + Value value; + CEL_ASSIGN_OR_RETURN( + bool ok, Next2(descriptor_pool, message_factory, arena, &key, &value)); + if (!ok) { + return absl::nullopt; + } + return std::pair{std::move(key), std::move(value)}; +} + +absl_nonnull std::unique_ptr NewEmptyValueIterator(); + +class ValueBuilder { + public: + virtual ~ValueBuilder() = default; + + virtual absl::StatusOr> SetFieldByName( + absl::string_view name, Value value) = 0; + + virtual absl::StatusOr> SetFieldByNumber( + int64_t number, Value value) = 0; + + virtual absl::StatusOr Build() && = 0; +}; + +using ValueBuilderPtr = std::unique_ptr; + +absl_nonnull ListValueBuilderPtr +NewListValueBuilder(google::protobuf::Arena* absl_nonnull arena); + +absl_nonnull MapValueBuilderPtr +NewMapValueBuilder(google::protobuf::Arena* absl_nonnull arena); + +// Returns a new `StructValueBuilder`. Returns `nullptr` if there is no such +// message type with the name `name` in `descriptor_pool`. Returns an error if +// `message_factory` is unable to provide a prototype for the descriptor +// returned from `descriptor_pool`. +absl_nullable StructValueBuilderPtr NewStructValueBuilder( + google::protobuf::Arena* absl_nonnull arena, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + absl::string_view name); + +using ListValueBuilderInterface = ListValueBuilder; +using MapValueBuilderInterface = MapValueBuilder; +using StructValueBuilderInterface = StructValueBuilder; + +// Now that Value is complete, we can define various parts of list, map, opaque, +// and struct which depend on Value. + +namespace common_internal { + +using MapFieldKeyAccessor = void (*)(const google::protobuf::MapKey&, + const google::protobuf::Message* absl_nonnull, + google::protobuf::Arena* absl_nonnull, + Value* absl_nonnull); + +absl::StatusOr MapFieldKeyAccessorFor( + const google::protobuf::FieldDescriptor* absl_nonnull field); + +using MapFieldValueAccessor = void (*)( + const google::protobuf::MapValueConstRef&, const google::protobuf::Message* absl_nonnull, + const google::protobuf::FieldDescriptor* absl_nonnull, + const google::protobuf::DescriptorPool* absl_nonnull, + google::protobuf::MessageFactory* absl_nonnull, google::protobuf::Arena* absl_nonnull, + Value* absl_nonnull); + +absl::StatusOr MapFieldValueAccessorFor( + const google::protobuf::FieldDescriptor* absl_nonnull field); + +using RepeatedFieldAccessor = + void (*)(int, const google::protobuf::Message* absl_nonnull, + const google::protobuf::FieldDescriptor* absl_nonnull, + const google::protobuf::Reflection* absl_nonnull, + const google::protobuf::DescriptorPool* absl_nonnull, + google::protobuf::MessageFactory* absl_nonnull, google::protobuf::Arena* absl_nonnull, + Value* absl_nonnull); + +absl::StatusOr RepeatedFieldAccessorFor( + const google::protobuf::FieldDescriptor* absl_nonnull field); + +} // namespace common_internal + +} // namespace cel + +#pragma pop_macro("GetMessage") -#endif // THIRD_PARTY_CEL_CPP_COMMON_CEL_VALUE_H_ +#endif // THIRD_PARTY_CEL_CPP_COMMON_VALUE_H_ diff --git a/common/value_kind.h b/common/value_kind.h new file mode 100644 index 000000000..6bf60bcd4 --- /dev/null +++ b/common/value_kind.h @@ -0,0 +1,104 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUE_KIND_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_VALUE_KIND_H_ + +#include +#include + +#include "absl/base/macros.h" +#include "absl/strings/string_view.h" +#include "common/kind.h" + +namespace cel { + +// `ValueKind` is a subset of `Kind`, representing all valid `Kind` for `Value`. +// All `ValueKind` are valid `Kind`, but it is not guaranteed that all `Kind` +// are valid `ValueKind`. +enum class ValueKind : std::underlying_type_t { + kNull = static_cast(Kind::kNull), + kBool = static_cast(Kind::kBool), + kInt = static_cast(Kind::kInt), + kUint = static_cast(Kind::kUint), + kDouble = static_cast(Kind::kDouble), + kString = static_cast(Kind::kString), + kBytes = static_cast(Kind::kBytes), + kStruct = static_cast(Kind::kStruct), + kDuration = static_cast(Kind::kDuration), + kTimestamp = static_cast(Kind::kTimestamp), + kList = static_cast(Kind::kList), + kMap = static_cast(Kind::kMap), + kUnknown = static_cast(Kind::kUnknown), + kType = static_cast(Kind::kType), + kError = static_cast(Kind::kError), + kOpaque = static_cast(Kind::kOpaque), + + // Legacy aliases, deprecated do not use. + kNullType = kNull, + kInt64 = kInt, + kUint64 = kUint, + kMessage = kStruct, + kUnknownSet = kUnknown, + kCelType = kType, + + // INTERNAL: Do not exceed 63. Implementation details rely on the fact that + // we can store `Kind` using 6 bits. + kNotForUseWithExhaustiveSwitchStatements = + static_cast(Kind::kNotForUseWithExhaustiveSwitchStatements), +}; + +constexpr Kind ValueKindToKind(ValueKind kind) { + return static_cast( + static_cast>(kind)); +} + +constexpr bool KindIsValueKind(Kind kind) { + return kind != Kind::kBoolWrapper && kind != Kind::kIntWrapper && + kind != Kind::kUintWrapper && kind != Kind::kDoubleWrapper && + kind != Kind::kStringWrapper && kind != Kind::kBytesWrapper && + kind != Kind::kDyn && kind != Kind::kAny && kind != Kind::kTypeParam && + kind != Kind::kFunction; +} + +constexpr bool operator==(Kind lhs, ValueKind rhs) { + return lhs == ValueKindToKind(rhs); +} + +constexpr bool operator==(ValueKind lhs, Kind rhs) { + return ValueKindToKind(lhs) == rhs; +} + +constexpr bool operator!=(Kind lhs, ValueKind rhs) { + return !operator==(lhs, rhs); +} + +constexpr bool operator!=(ValueKind lhs, Kind rhs) { + return !operator==(lhs, rhs); +} + +inline absl::string_view ValueKindToString(ValueKind kind) { + // All ValueKind are valid Kind. + return KindToString(ValueKindToKind(kind)); +} + +constexpr ValueKind KindToValueKind(Kind kind) { + ABSL_ASSERT(KindIsValueKind(kind)); + return static_cast( + static_cast>(kind)); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_VALUE_KIND_H_ diff --git a/common/value_test.cc b/common/value_test.cc index 3a6a60c3a..fb346423b 100644 --- a/common/value_test.cc +++ b/common/value_test.cc @@ -1,800 +1,998 @@ -#include "common/value.h" +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. -#include -#include +#include "common/value.h" -#include "google/protobuf/any.pb.h" #include "google/protobuf/struct.pb.h" -#include "google/protobuf/timestamp.pb.h" -#include "google/rpc/status.pb.h" -#include "google/type/money.pb.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "absl/strings/str_cat.h" -#include "common/custom_object.h" -#include "internal/status_util.h" -#include "internal/types.h" -#include "internal/value_internal.h" -#include "testutil/util.h" - -namespace google { -namespace api { -namespace expr { -namespace common { +#include "google/protobuf/type.pb.h" +#include "google/protobuf/descriptor.pb.h" +#include "absl/base/attributes.h" +#include "absl/log/die_if_null.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/types/optional.h" +#include "common/type.h" +#include "common/value_testing.h" +#include "internal/parse_text_proto.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "internal/testing_message_factory.h" +#include "cel/expr/conformance/proto3/test_all_types.pb.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/generated_enum_reflection.h" + +namespace cel { namespace { -using testutil::ExpectSameType; - -class DummyObject final : public OpaqueObject { - public: - Type object_type() const override { - return Type(ObjectType(google::type::Money::descriptor())); - } - - void To(google::protobuf::Any* value) const override {} - - std::string ToString() const override { return Object::ToString(); } - - protected: - inline bool EqualsImpl(const Object& rhs) const override { return true; } - inline std::size_t ComputeHash() const override { return 0; } -}; - -class DummyMap final : public Map { - public: - std::size_t size() const override { return 1; } - - google::rpc::Status ForEach( - const std::function& - call) const override { - return call(Value::ForString("key"), Value::ForString("value")); - } - - bool owns_value() const override { return true; } - - protected: - Value GetImpl(const Value& key) const override { - if (key == Value::ForString("key")) return Value::ForString("value"); - return Value::FromError(internal::NotFoundError("")); - } -}; - -class DummyList final : public List { - public: - std::size_t size() const override { return 1; } - bool owns_value() const override { return true; } - - google::rpc::Status ForEach( - const std::function& call) - const override { - return call(Value::ForString("elem")); - } - - Value Get(std::size_t index) const override { - if (index == 0) { - return Value::ForString("elem"); - } - return Value::FromError(internal::OutOfRangeError(index, 1)); - } -}; - -struct ValueTestCase { - Value value; - Value value_copy; - Value::Kind kind; - Value type; - std::string debug_string; - std::string type_debug_string; - bool is_inline; - bool is_value; - bool owns_value; - - static Value::Kind GetKind(BasicTypeValue type) { - switch (type) { - case BasicTypeValue::kNull: - return Value::Kind::kNull; - case BasicTypeValue::kBool: - return Value::Kind::kBool; - case BasicTypeValue::kInt: - return Value::Kind::kInt; - case BasicTypeValue::kUint: - return Value::Kind::kUInt; - case BasicTypeValue::kDouble: - return Value::Kind::kDouble; - case BasicTypeValue::kString: - return Value::Kind::kString; - case BasicTypeValue::kBytes: - return Value::Kind::kBytes; - case BasicTypeValue::kList: - return Value::Kind::kList; - case BasicTypeValue::kMap: - return Value::Kind::kMap; - case BasicTypeValue::kType: - return Value::Kind::kType; - case BasicTypeValue::DO_NOT_USE: // Force compiler error if switch is not - // completed. - EXPECT_TRUE(false) << "not a basic type"; - } - return Value::Kind::kNull; - } - - static ValueTestCase ForInline(const Value& value, - const std::string& debug_string, - const std::string& type_debug_string) { - return ValueTestCase{value, value, value.kind(), - value, debug_string, type_debug_string, - true, false, true}; - } - - static ValueTestCase ForInline(const Value& value, EnumType type, - const std::string& debug_string, - const std::string& type_debug_string) { - return ValueTestCase{value, - value, - Value::Kind::kEnum, - Value::FromType(type), - debug_string, - type_debug_string, - true, - true, - true}; - } - - static ValueTestCase ForInline(const Value& value, BasicTypeValue type, - const std::string& debug_string, - const std::string& type_debug_string) { - return ValueTestCase{value, value, - GetKind(type), Value::FromType(type), - debug_string, type_debug_string, - true, true, - true}; - } - - static ValueTestCase ForNonInline(Value value, Value copy, - BasicTypeValue type, - const std::string& debug_string, - const std::string& type_debug_string, - bool owns_value = true) { - return ValueTestCase{std::move(value), - std::move(copy), - GetKind(type), - Value::FromType(type), - debug_string, - type_debug_string, - false, - true, - owns_value}; - } - - static ValueTestCase ForNonInline(Value value, Value copy, EnumType type, - const std::string& debug_string, - const std::string& type_debug_string) { - return ValueTestCase{std::move(value), - std::move(copy), - Value::Kind::kEnum, - Value::FromType(type), - debug_string, - type_debug_string, - false, - true, - true}; - } - static ValueTestCase ForNonInline(Value value, Value copy, Value::Kind kind, - absl::string_view full_name, - absl::string_view args) { - auto type = Value::FromType(value.GetType().type_value().object_type()); - return ValueTestCase{std::move(value), - std::move(copy), - kind, - type, - absl::StrCat(full_name, args), - std::string(full_name), - false, - true, - true}; - } - - static ValueTestCase ForNonValue(Value value, const Value& copy, - Value::Kind kind, - const std::string& debug_string) { - return ValueTestCase{std::move(value), copy, kind, copy, debug_string, - debug_string, false, false, true}; - } -}; - -std::ostream& operator<<(std::ostream& os, const ValueTestCase& test_case) { - return os << test_case.value; -} +using ::absl_testing::IsOkAndHolds; +using ::absl_testing::StatusIs; +using ::cel::internal::DynamicParseTextProto; +using ::cel::internal::GetTestingDescriptorPool; +using ::cel::internal::GetTestingMessageFactory; +using ::testing::An; +using ::testing::Eq; +using ::testing::NotNull; +using ::testing::Optional; + +using TestAllTypesProto3 = ::cel::expr::conformance::proto3::TestAllTypes; + +TEST(Value, GeneratedEnum) { + EXPECT_EQ(Value::Enum(google::protobuf::NULL_VALUE), NullValue()); + EXPECT_EQ(Value::Enum(google::protobuf::SYNTAX_EDITIONS), IntValue(2)); +} + +TEST(Value, DynamicEnum) { + EXPECT_THAT( + Value::Enum(google::protobuf::GetEnumDescriptor(), 0), + test::IsNullValue()); + EXPECT_THAT( + Value::Enum(google::protobuf::GetEnumDescriptor() + ->FindValueByNumber(0)), + test::IsNullValue()); + EXPECT_THAT( + Value::Enum(google::protobuf::GetEnumDescriptor(), 2), + test::IntValueIs(2)); + EXPECT_THAT(Value::Enum(google::protobuf::GetEnumDescriptor() + ->FindValueByNumber(2)), + test::IntValueIs(2)); +} + +TEST(Value, DynamicClosedEnum) { + google::protobuf::FileDescriptorProto file_descriptor; + file_descriptor.set_name("test/closed_enum.proto"); + file_descriptor.set_package("test"); + file_descriptor.set_syntax("editions"); + file_descriptor.set_edition(google::protobuf::EDITION_2023); + { + auto* enum_descriptor = file_descriptor.add_enum_type(); + enum_descriptor->set_name("ClosedEnum"); + enum_descriptor->mutable_options()->mutable_features()->set_enum_type( + google::protobuf::FeatureSet::CLOSED); + auto* enum_value_descriptor = enum_descriptor->add_value(); + enum_value_descriptor->set_number(1); + enum_value_descriptor->set_name("FOO"); + enum_value_descriptor = enum_descriptor->add_value(); + enum_value_descriptor->set_number(2); + enum_value_descriptor->set_name("BAR"); + } + google::protobuf::DescriptorPool pool; + ASSERT_THAT(pool.BuildFile(file_descriptor), NotNull()); + const auto* enum_descriptor = pool.FindEnumTypeByName("test.ClosedEnum"); + ASSERT_THAT(enum_descriptor, NotNull()); + EXPECT_THAT(Value::Enum(enum_descriptor, 0), + test::ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))); +} + +TEST(Value, Is) { + google::protobuf::Arena arena; + + EXPECT_TRUE(Value(BoolValue()).Is()); + EXPECT_TRUE(Value(BoolValue(true)).IsTrue()); + EXPECT_TRUE(Value(BoolValue(false)).IsFalse()); + + EXPECT_TRUE(Value(BytesValue()).Is()); + + EXPECT_TRUE(Value(DoubleValue()).Is()); + + EXPECT_TRUE(Value(DurationValue()).Is()); + + EXPECT_TRUE(Value(ErrorValue()).Is()); + + EXPECT_TRUE(Value(IntValue()).Is()); + + EXPECT_TRUE(Value(ListValue()).Is()); + EXPECT_TRUE(Value(CustomListValue()).Is()); + EXPECT_TRUE(Value(CustomListValue()).Is()); + EXPECT_TRUE(Value(ParsedJsonListValue()).Is()); + EXPECT_TRUE(Value(ParsedJsonListValue()).Is()); + { + auto message = DynamicParseTextProto( + &arena, R"pb()pb", GetTestingDescriptorPool(), + GetTestingMessageFactory()); + const auto* field = ABSL_DIE_IF_NULL( + message->GetDescriptor()->FindFieldByName("repeated_int32")); + EXPECT_TRUE(Value(ParsedRepeatedFieldValue(message, field, &arena)) + .Is()); + EXPECT_TRUE(Value(ParsedRepeatedFieldValue(message, field, &arena)) + .Is()); + } -class ValueTest : public ::testing::TestWithParam { - public: - static absl::optional CreateRef(const Value& value) { - switch (value.kind()) { - case Value::Kind::kNull: - case Value::Kind::kBool: - case Value::Kind::kInt: - case Value::Kind::kUInt: - case Value::Kind::kDouble: - // Inline values cannot be referenced. - break; - - case Value::Kind::kEnum: - case Value::Kind::kDuration: - case Value::Kind::kTime: - case Value::Kind::kError: - case Value::Kind::kUnknown: - case Value::Kind::kType: - // These value types are always copied. - break; - - case Value::Kind::kString: - return Value::ForString(value.string_value()); - case Value::Kind::kBytes: - return Value::ForBytes(value.bytes_value()); - case Value::Kind::kMap: - return Value::ForMap(&value.map_value()); - case Value::Kind::kList: - return Value::ForList(&value.list_value()); - case Value::Kind::kObject: - return Value::ForObject(&value.object_value()); - case Value::Kind::DO_NOT_USE: // Force a compiler error if this enum - // is not complete. - assert(false); - break; - } - return absl::nullopt; - } - - template - void TestKind(); - - template - void TestCustomKind(); -}; - -TEST_P(ValueTest, Kind) { EXPECT_EQ(GetParam().value.kind(), GetParam().kind); } - -TEST_P(ValueTest, GetType) { - EXPECT_EQ(GetParam().value.GetType(), GetParam().type); -} + EXPECT_TRUE(Value(MapValue()).Is()); + EXPECT_TRUE(Value(CustomMapValue()).Is()); + EXPECT_TRUE(Value(CustomMapValue()).Is()); + EXPECT_TRUE(Value(ParsedJsonMapValue()).Is()); + EXPECT_TRUE(Value(ParsedJsonMapValue()).Is()); + { + auto message = DynamicParseTextProto( + &arena, R"pb()pb", GetTestingDescriptorPool(), + GetTestingMessageFactory()); + const auto* field = ABSL_DIE_IF_NULL( + message->GetDescriptor()->FindFieldByName("map_int32_int32")); + EXPECT_TRUE( + Value(ParsedMapFieldValue(message, field, &arena)).Is()); + EXPECT_TRUE(Value(ParsedMapFieldValue(message, field, &arena)) + .Is()); + } + + EXPECT_TRUE(Value(NullValue()).Is()); + + EXPECT_TRUE(Value(OptionalValue()).Is()); + EXPECT_TRUE(Value(OptionalValue()).Is()); + + EXPECT_TRUE(Value(ParsedMessageValue()).Is()); + EXPECT_TRUE(Value(ParsedMessageValue()).Is()); + EXPECT_TRUE(Value(ParsedMessageValue()).Is()); + + EXPECT_TRUE(Value(StringValue()).Is()); + + EXPECT_TRUE(Value(TimestampValue()).Is()); -TEST_P(ValueTest, TypeString) { - EXPECT_EQ(GetParam().value.GetType().ToString(), - GetParam().type_debug_string); + EXPECT_TRUE(Value(TypeValue(StringType())).Is()); + + EXPECT_TRUE(Value(UintValue()).Is()); + + EXPECT_TRUE(Value(UnknownValue()).Is()); } -TEST_P(ValueTest, ToString) { - std::ostringstream os; - os << GetParam().value; - EXPECT_EQ(os.str(), GetParam().value.ToString()); - EXPECT_EQ(GetParam().value.ToString(), GetParam().debug_string); +template +constexpr T& AsLValueRef(T& t ABSL_ATTRIBUTE_LIFETIME_BOUND) { + return t; } -TEST_P(ValueTest, Inlined) { - EXPECT_EQ(GetParam().value.is_inline(), GetParam().is_inline); +template +constexpr const T& AsConstLValueRef(T& t ABSL_ATTRIBUTE_LIFETIME_BOUND) { + return t; } -TEST_P(ValueTest, IsValue) { - EXPECT_EQ(GetParam().value.is_value(), GetParam().is_value); +template +constexpr T&& AsRValueRef(T& t ABSL_ATTRIBUTE_LIFETIME_BOUND) { + return static_cast(t); } -TEST_P(ValueTest, IsObjet) { - bool is_object = GetParam().type.kind() == Value::Kind::kType && - GetParam().type.type_value().is_object(); - EXPECT_EQ(is_object, GetParam().value.is_object()); +template +constexpr const T&& AsConstRValueRef(T& t ABSL_ATTRIBUTE_LIFETIME_BOUND) { + return static_cast(t); } -TEST_P(ValueTest, OwnsValue) { - EXPECT_EQ(GetParam().value.owns_value(), GetParam().owns_value); - auto ref = CreateRef(GetParam().value); - if (ref.has_value()) { - EXPECT_FALSE(ref.value().owns_value()); +TEST(Value, As) { + google::protobuf::Arena arena; + + EXPECT_THAT(Value(BoolValue()).As(), Optional(An())); + EXPECT_THAT(Value(BoolValue()).As(), Eq(absl::nullopt)); + + { + Value value(BytesValue{}); + Value other_value = value; + EXPECT_THAT(AsLValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsConstLValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsRValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsConstRValueRef(other_value).As(), + Optional(An())); } -} -TEST_P(ValueTest, For) { - EXPECT_EQ(GetParam().value.owns_value(), GetParam().owns_value); - auto ref = CreateRef(GetParam().value); - if (ref.has_value()) { - EXPECT_FALSE(ref.value().owns_value()); + EXPECT_THAT(Value(DoubleValue()).As(), + Optional(An())); + EXPECT_THAT(Value(DoubleValue()).As(), Eq(absl::nullopt)); + + EXPECT_THAT(Value(DurationValue()).As(), + Optional(An())); + EXPECT_THAT(Value(DurationValue()).As(), Eq(absl::nullopt)); + + { + Value value(ErrorValue{}); + Value other_value = value; + EXPECT_THAT(AsLValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsConstLValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsRValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsConstRValueRef(other_value).As(), + Optional(An())); + EXPECT_THAT(Value(ErrorValue()).As(), Eq(absl::nullopt)); } -} -template -void ValueTest::TestKind() { - using Kind = Value::Kind; - Value value = GetParam().value; - Kind kind = GetParam().kind; - EXPECT_EQ(static_cast(value.get_if()), kind == K); - EXPECT_EQ(static_cast(value.get_if()), kind == K); - if (kind == K) { - EXPECT_EQ(value, Value::From(value.get())); - EXPECT_EQ(value, Value::From(value.get())); + EXPECT_THAT(Value(IntValue()).As(), Optional(An())); + EXPECT_THAT(Value(IntValue()).As(), Eq(absl::nullopt)); + + { + Value value(ListValue{}); + Value other_value = value; + EXPECT_THAT(AsLValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsConstLValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsRValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsConstRValueRef(other_value).As(), + Optional(An())); + EXPECT_THAT(Value(ListValue()).As(), Eq(absl::nullopt)); } -} -TEST_P(ValueTest, Null) { TestKind(); } + { + Value value(ParsedJsonListValue{}); + Value other_value = value; + EXPECT_THAT(AsLValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsConstLValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsRValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsConstRValueRef(other_value).As(), + Optional(An())); + EXPECT_THAT(Value(ListValue()).As(), Eq(absl::nullopt)); + } -TEST_P(ValueTest, Bool) { TestKind(); } + { + Value value(ParsedJsonListValue{}); + Value other_value = value; + EXPECT_THAT(AsLValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsConstLValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsRValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsConstRValueRef(other_value).As(), + Optional(An())); + } -TEST_P(ValueTest, Int) { TestKind(); } + { + Value value(CustomListValue{}); + Value other_value = value; + EXPECT_THAT(AsLValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsConstLValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsRValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsConstRValueRef(other_value).As(), + Optional(An())); + EXPECT_THAT(Value(ListValue()).As(), Eq(absl::nullopt)); + } -TEST_P(ValueTest, UInt) { TestKind(); } + { + Value value(CustomListValue{}); + Value other_value = value; + EXPECT_THAT(AsLValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsConstLValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsRValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsConstRValueRef(other_value).As(), + Optional(An())); + } -TEST_P(ValueTest, Double) { TestKind(); } + { + auto message = DynamicParseTextProto( + &arena, R"pb()pb", GetTestingDescriptorPool(), + GetTestingMessageFactory()); + const auto* field = ABSL_DIE_IF_NULL( + message->GetDescriptor()->FindFieldByName("repeated_int32")); + Value value(ParsedRepeatedFieldValue{message, field, &arena}); + Value other_value = value; + EXPECT_THAT(AsLValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsConstLValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsRValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsConstRValueRef(other_value).As(), + Optional(An())); + } -TEST_P(ValueTest, Time) { TestKind(); } + { + auto message = DynamicParseTextProto( + &arena, R"pb()pb", GetTestingDescriptorPool(), + GetTestingMessageFactory()); + const auto* field = ABSL_DIE_IF_NULL( + message->GetDescriptor()->FindFieldByName("repeated_int32")); + Value value(ParsedRepeatedFieldValue{message, field, &arena}); + Value other_value = value; + EXPECT_THAT(AsLValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsConstLValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsRValueRef(value).As(), + Optional(An())); + EXPECT_THAT( + AsConstRValueRef(other_value).As(), + Optional(An())); + } -TEST_P(ValueTest, Type) { TestKind(); } + { + Value value(MapValue{}); + Value other_value = value; + EXPECT_THAT(AsLValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsConstLValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsRValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsConstRValueRef(other_value).As(), + Optional(An())); + EXPECT_THAT(Value(MapValue()).As(), Eq(absl::nullopt)); + } -TEST_P(ValueTest, Error) { TestKind(); } + { + Value value(ParsedJsonMapValue{}); + Value other_value = value; + EXPECT_THAT(AsLValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsConstLValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsRValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsConstRValueRef(other_value).As(), + Optional(An())); + EXPECT_THAT(Value(MapValue()).As(), Eq(absl::nullopt)); + } -TEST_P(ValueTest, Unknown) { TestKind(); } + { + Value value(ParsedJsonMapValue{}); + Value other_value = value; + EXPECT_THAT(AsLValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsConstLValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsRValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsConstRValueRef(other_value).As(), + Optional(An())); + } -TEST_P(ValueTest, String) { - using Kind = Value::Kind; - Value value = GetParam().value; - Kind kind = GetParam().kind; - EXPECT_EQ(value.get_if(), - kind == Kind::kString ? absl::make_optional(value.string_value()) - : absl::nullopt); - if (kind == Kind::kString) { - EXPECT_EQ(value, Value::From(value.get())); - EXPECT_EQ(value, Value::For(value.get())); + { + Value value(CustomMapValue{}); + Value other_value = value; + EXPECT_THAT(AsLValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsConstLValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsRValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsConstRValueRef(other_value).As(), + Optional(An())); + EXPECT_THAT(Value(MapValue()).As(), Eq(absl::nullopt)); } -} -TEST_P(ValueTest, Bytes) { - using Kind = Value::Kind; - Value value = GetParam().value; - Kind kind = GetParam().kind; - EXPECT_EQ(value.get_if(), - kind == Kind::kBytes ? absl::make_optional(value.bytes_value()) - : absl::nullopt); - if (kind == Kind::kBytes) { - EXPECT_EQ(value, Value::From(value.get())); - EXPECT_EQ(value, Value::For(value.get())); + { + Value value(CustomMapValue{}); + Value other_value = value; + EXPECT_THAT(AsLValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsConstLValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsRValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsConstRValueRef(other_value).As(), + Optional(An())); } -} -TEST_P(ValueTest, TypeRoundTrip) { - if (GetParam().type.kind() != Value::Kind::kType) { - return; - } - Type expected = GetParam().type.type_value(); - Type actual(expected.full_name()); - EXPECT_EQ(expected, actual) << expected.full_name(); - EXPECT_EQ(expected.full_name(), actual.full_name()); - EXPECT_EQ(expected.is_object(), actual.is_object()); - EXPECT_EQ(expected.is_unrecognized(), actual.is_unrecognized()); - EXPECT_EQ(expected.is_enum(), actual.is_enum()); - EXPECT_EQ(expected.is_basic(), actual.is_basic()); -} + { + auto message = DynamicParseTextProto( + &arena, R"pb()pb", GetTestingDescriptorPool(), + GetTestingMessageFactory()); + const auto* field = ABSL_DIE_IF_NULL( + message->GetDescriptor()->FindFieldByName("map_int32_int32")); + Value value(ParsedMapFieldValue{message, field, &arena}); + Value other_value = value; + EXPECT_THAT(AsLValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsConstLValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsRValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsConstRValueRef(other_value).As(), + Optional(An())); + } -template -void ValueTest::TestCustomKind() { - using Kind = Value::Kind; - Value value = GetParam().value; - Kind kind = GetParam().kind; - EXPECT_EQ(value.get_if() != nullptr, kind == K); - EXPECT_EQ(value.get_if() != nullptr, kind == K); - EXPECT_EQ(value.get_if() != nullptr, kind == K); - if (kind == K) { - EXPECT_EQ(value.get_if(), &value.get()); - EXPECT_EQ(value.get_if(), &value.get()); - EXPECT_EQ(value.get_if(), &value.get()); + { + auto message = DynamicParseTextProto( + &arena, R"pb()pb", GetTestingDescriptorPool(), + GetTestingMessageFactory()); + const auto* field = ABSL_DIE_IF_NULL( + message->GetDescriptor()->FindFieldByName("map_int32_int32")); + Value value(ParsedMapFieldValue{message, field, &arena}); + Value other_value = value; + EXPECT_THAT(AsLValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsConstLValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsRValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsConstRValueRef(other_value).As(), + Optional(An())); } -} -TEST_P(ValueTest, Object) { - TestCustomKind(); -} + { + Value value(ParsedMessageValue{ + DynamicParseTextProto(&arena, R"pb()pb", + GetTestingDescriptorPool(), + GetTestingMessageFactory()), + &arena}); + Value other_value = value; + EXPECT_THAT(AsLValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsConstLValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsRValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsConstRValueRef(other_value).As(), + Optional(An())); + EXPECT_THAT(Value(ParsedMessageValue{ + DynamicParseTextProto( + &arena, R"pb()pb", GetTestingDescriptorPool(), + GetTestingMessageFactory()), + &arena}) + .As(), + Eq(absl::nullopt)); + } -TEST_P(ValueTest, Map) { TestCustomKind(); } + EXPECT_THAT(Value(NullValue()).As(), Optional(An())); + EXPECT_THAT(Value(NullValue()).As(), Eq(absl::nullopt)); + + { + Value value(OptionalValue{}); + Value other_value = value; + EXPECT_THAT(AsLValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsConstLValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsRValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsConstRValueRef(other_value).As(), + Optional(An())); + EXPECT_THAT(Value(OpaqueValue(OptionalValue())).As(), + Eq(absl::nullopt)); + } -TEST_P(ValueTest, List) { - TestCustomKind(); -} + { + Value value(OptionalValue{}); + Value other_value = value; + EXPECT_THAT(AsLValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsConstLValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsRValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsConstRValueRef(other_value).As(), + Optional(An())); + EXPECT_THAT(Value(OptionalValue()).As(), Eq(absl::nullopt)); + } -TEST_P(ValueTest, Equal) { - EXPECT_EQ(GetParam().value, GetParam().value); - EXPECT_FALSE(GetParam().value != GetParam().value); - EXPECT_EQ(GetParam().value, GetParam().value_copy); - EXPECT_FALSE(GetParam().value != GetParam().value_copy); - auto ref = CreateRef(GetParam().value); - if (ref) { - EXPECT_EQ(*ref, GetParam().value); - EXPECT_FALSE(*ref != GetParam().value); + { + OpaqueValue value(OptionalValue{}); + OpaqueValue other_value = value; + EXPECT_THAT(AsLValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsConstLValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsRValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsConstRValueRef(other_value).As(), + Optional(An())); + } + + { + Value value(ParsedMessageValue{ + DynamicParseTextProto(&arena, R"pb()pb", + GetTestingDescriptorPool(), + GetTestingMessageFactory()), + &arena}); + Value other_value = value; + EXPECT_THAT(AsLValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsConstLValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsRValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsConstRValueRef(other_value).As(), + Optional(An())); + } + + { + Value value(StringValue{}); + Value other_value = value; + EXPECT_THAT(AsLValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsConstLValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsRValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsConstRValueRef(other_value).As(), + Optional(An())); + EXPECT_THAT(Value(StringValue()).As(), Eq(absl::nullopt)); + } + + { + Value value(ParsedMessageValue{ + DynamicParseTextProto(&arena, R"pb()pb", + GetTestingDescriptorPool(), + GetTestingMessageFactory()), + &arena}); + Value other_value = value; + EXPECT_THAT(AsLValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsConstLValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsRValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsConstRValueRef(other_value).As(), + Optional(An())); + } + + EXPECT_THAT(Value(TimestampValue()).As(), + Optional(An())); + EXPECT_THAT(Value(TimestampValue()).As(), Eq(absl::nullopt)); + + { + Value value(TypeValue(StringType{})); + Value other_value = value; + EXPECT_THAT(AsLValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsConstLValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsRValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsConstRValueRef(other_value).As(), + Optional(An())); + EXPECT_THAT(Value(TypeValue(StringType())).As(), + Eq(absl::nullopt)); } -} -TEST_P(ValueTest, HashCode) { - EXPECT_EQ(GetParam().value.hash_code(), std::hash()(GetParam().value)); - EXPECT_EQ(GetParam().value.hash_code(), GetParam().value_copy.hash_code()); - auto ref = CreateRef(GetParam().value); - if (ref) { - EXPECT_EQ(ref->hash_code(), GetParam().value.hash_code()); + EXPECT_THAT(Value(UintValue()).As(), Optional(An())); + EXPECT_THAT(Value(UintValue()).As(), Eq(absl::nullopt)); + + { + Value value(UnknownValue{}); + Value other_value = value; + EXPECT_THAT(AsLValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsConstLValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsRValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsConstRValueRef(other_value).As(), + Optional(An())); + EXPECT_THAT(Value(UnknownValue()).As(), Eq(absl::nullopt)); } } -INSTANTIATE_TEST_SUITE_P( - Inlined, ValueTest, - ::testing::Values( - ValueTestCase::ForInline(Value::NullValue(), BasicTypeValue::kNull, - "null", "null_type"), - ValueTestCase::ForInline(Value::FromBool(true), BasicTypeValue::kBool, - "true", "bool"), - ValueTestCase::ForInline(Value::FromInt(1), BasicTypeValue::kInt, "1", - "int"), - ValueTestCase::ForInline(Value::FromUInt(1), BasicTypeValue::kUint, - "1u", "uint"), - ValueTestCase::ForInline(Value::FromDouble(1), BasicTypeValue::kDouble, - "1.0", "double"), - ValueTestCase::ForInline(Value::FromType(BasicTypeValue::kInt), - BasicTypeValue::kType, "int", "type"), - ValueTestCase::ForInline( - Value::FromType(ObjectType::For()), - BasicTypeValue::kType, "google.protobuf.Timestamp", "type"), - ValueTestCase::ForInline( - Value::FromType(EnumType(google::protobuf::NullValue_descriptor())), - BasicTypeValue::kType, "google.protobuf.NullValue", "type"), - ValueTestCase::ForInline( - Value::FromEnum(NamedEnumValue( - google::protobuf::NullValue_descriptor()->value(0))), - EnumType(google::protobuf::NullValue_descriptor()), - "google.protobuf.NULL_VALUE", "google.protobuf.NullValue"), - ValueTestCase::ForInline(Value::FromUnknown(Id(1)), "Unknown{Id(1)}", - "Unknown{Id(1)}"))); - -INSTANTIATE_TEST_SUITE_P( - NonInlined, ValueTest, - ::testing::Values( - ValueTestCase::ForNonInline(Value::ForString("hi"), - Value::ForString("hi"), - BasicTypeValue::kString, "\"hi\"", "string", - false), - ValueTestCase::ForNonInline( - Value::FromBytes(absl::string_view("h\000i", 3)), - Value::FromBytes(absl::string_view("h\000i", 3)), - BasicTypeValue::kBytes, "b\"h\\000i\"", "bytes"), - ValueTestCase::ForNonInline(Value::MakeList(), - Value::MakeList(), - BasicTypeValue::kList, "[\"elem\"]", - "list"), - ValueTestCase::ForNonInline(Value::MakeMap(), - Value::MakeMap(), - BasicTypeValue::kMap, - "{\"key\": \"value\"}", "map"), - ValueTestCase::ForNonInline(Value::MakeObject(), - Value::MakeObject(), - Value::Kind::kObject, "google.type.Money", - "{}"), - ValueTestCase::ForNonInline( - Value::FromDuration(absl::Hours(1) + absl::Nanoseconds(1)), - Value::FromDuration(absl::Hours(1) + absl::Nanoseconds(1)), - Value::Kind::kDuration, "google.protobuf.Duration", - "(\"1h0.000000001s\")"), - ValueTestCase::ForNonInline(Value::FromTime(absl::FromUnixNanos(1)), - Value::FromTime(absl::FromUnixNanos(1)), - Value::Kind::kTime, - "google.protobuf.Timestamp", - "(\"1970-01-01T00:00:00.000000001Z\")"), - ValueTestCase::ForNonInline( - Value::FromEnum(EnumValue( - EnumType(google::protobuf::NullValue_descriptor()), -1)), - Value::FromEnum(EnumValue( - EnumType(google::protobuf::NullValue_descriptor()), -1)), - EnumType(google::protobuf::NullValue_descriptor()), - "google.protobuf.NullValue(-1)", "google.protobuf.NullValue"), - ValueTestCase::ForNonInline(Value::FromType("bad type"), - Value::FromType("bad type"), - BasicTypeValue::kType, "type(\"bad type\")", - "type"))); - -INSTANTIATE_TEST_SUITE_P( - NonValue, ValueTest, - ::testing::Values( - ValueTestCase::ForNonValue( - Value::FromError(internal::NotFoundError("hi")), - Value::FromError(internal::NotFoundError("hi")), - Value::Kind::kError, "Error{NOT_FOUND}"), - ValueTestCase::ForNonValue( - Value::FromError(Error({internal::NotFoundError("hi"), - internal::OutOfRangeError("bye")})), - Value::FromError(Error({internal::NotFoundError("hi"), - internal::OutOfRangeError("bye")})), - Value::Kind::kError, "Error{NOT_FOUND, OUT_OF_RANGE}"), - ValueTestCase::ForNonValue(Value::FromUnknown(Unknown({Id(1), Id(2)})), - Value::FromUnknown(Unknown({Id(1), Id(2)})), - Value::Kind::kUnknown, - "Unknown{Id(1), Id(2)}"))); - -TEST(ValueTest, NotEqual) { - EXPECT_NE(Value::ForString("hi"), Value::ForBytes("hi")); - EXPECT_NE(Value::ForString("hi").hash_code(), - Value::ForBytes("hi").hash_code()); - EXPECT_NE(Value::FromInt(1), Value::FromUInt(1)); - EXPECT_NE(Value::FromInt(1).hash_code(), Value::FromUInt(1).hash_code()); +template +decltype(auto) DoGet(From&& from) { + return std::forward(from).template Get(); } -template -void TestGetIfNum(Value value, bool matches) { - ExpectSameType, decltype(value.get_if())>(); - ExpectSameType())>(); - if (matches) { - EXPECT_EQ(value.get(), *value.get_if()); - EXPECT_EQ(value.get(), value.get()); - } else { - EXPECT_EQ(absl::nullopt, value.get_if()); +TEST(Value, Get) { + google::protobuf::Arena arena; + + EXPECT_THAT(DoGet(Value(BoolValue())), An()); + + { + Value value(BytesValue{}); + Value other_value = value; + EXPECT_THAT(DoGet(AsLValueRef(value)), An()); + EXPECT_THAT(DoGet(AsConstLValueRef(value)), + An()); + EXPECT_THAT(DoGet(AsRValueRef(value)), An()); + EXPECT_THAT(DoGet(AsConstRValueRef(other_value)), + An()); } -} -TEST(ValueTest, GetIfInt) { - Value max_int = Value::FromInt(std::numeric_limits::max()); - EXPECT_EQ(std::numeric_limits::max(), max_int.get()); - EXPECT_EQ(std::numeric_limits::max(), *max_int.get_if()); + EXPECT_THAT(DoGet(Value(DoubleValue())), An()); - TestGetIfNum(max_int, false); - TestGetIfNum(max_int, false); - TestGetIfNum(max_int, false); + EXPECT_THAT(DoGet(Value(DurationValue())), + An()); - Value one = Value::FromInt(1); - EXPECT_EQ(1, one.get()); - EXPECT_EQ(1, *one.get_if()); - EXPECT_EQ(nullptr, one.get_if()); + { + Value value(ErrorValue{}); + Value other_value = value; + EXPECT_THAT(DoGet(AsLValueRef(value)), An()); + EXPECT_THAT(DoGet(AsConstLValueRef(value)), + An()); + EXPECT_THAT(DoGet(AsRValueRef(value)), An()); + EXPECT_THAT(DoGet(AsConstRValueRef(other_value)), + An()); + } - TestGetIfNum(one, true); - TestGetIfNum(one, true); - TestGetIfNum(one, true); -} + EXPECT_THAT(DoGet(Value(IntValue())), An()); + + { + Value value(ListValue{}); + Value other_value = value; + EXPECT_THAT(DoGet(AsLValueRef(value)), An()); + EXPECT_THAT(DoGet(AsConstLValueRef(value)), + An()); + EXPECT_THAT(DoGet(AsRValueRef(value)), An()); + EXPECT_THAT(DoGet(AsConstRValueRef(other_value)), + An()); + } -TEST(ValueTest, GetIfUInt) { - Value max_uint = Value::FromUInt(std::numeric_limits::max()); - EXPECT_EQ(std::numeric_limits::max(), max_uint.get()); - EXPECT_EQ(std::numeric_limits::max(), *max_uint.get_if()); + { + Value value(ParsedJsonListValue{}); + Value other_value = value; + EXPECT_THAT(DoGet(AsLValueRef(value)), An()); + EXPECT_THAT(DoGet(AsConstLValueRef(value)), + An()); + EXPECT_THAT(DoGet(AsRValueRef(value)), An()); + EXPECT_THAT(DoGet(AsConstRValueRef(other_value)), + An()); + } - TestGetIfNum(max_uint, false); - TestGetIfNum(max_uint, false); - TestGetIfNum(max_uint, false); + { + Value value(ParsedJsonListValue{}); + Value other_value = value; + EXPECT_THAT(DoGet(AsLValueRef(value)), + An()); + EXPECT_THAT(DoGet(AsConstLValueRef(value)), + An()); + EXPECT_THAT(DoGet(AsRValueRef(value)), + An()); + EXPECT_THAT( + DoGet(AsConstRValueRef(other_value)), + An()); + } - Value one = Value::FromUInt(1); - EXPECT_EQ(1, one.get()); - EXPECT_EQ(1, *one.get_if()); - EXPECT_EQ(nullptr, one.get_if()); + { + Value value(CustomListValue{}); + Value other_value = value; + EXPECT_THAT(DoGet(AsLValueRef(value)), An()); + EXPECT_THAT(DoGet(AsConstLValueRef(value)), + An()); + EXPECT_THAT(DoGet(AsRValueRef(value)), An()); + EXPECT_THAT(DoGet(AsConstRValueRef(other_value)), + An()); + } - TestGetIfNum(one, true); - TestGetIfNum(one, true); - TestGetIfNum(one, true); -} + { + Value value(CustomListValue{}); + Value other_value = value; + EXPECT_THAT(DoGet(AsLValueRef(value)), + An()); + EXPECT_THAT(DoGet(AsConstLValueRef(value)), + An()); + EXPECT_THAT(DoGet(AsRValueRef(value)), + An()); + EXPECT_THAT(DoGet(AsConstRValueRef(other_value)), + An()); + } -TEST(ValueTest, GetIfDouble) { - Value max_double = Value::FromDouble(std::numeric_limits::max()); - EXPECT_EQ(std::numeric_limits::max(), max_double.get()); - EXPECT_EQ(std::numeric_limits::max(), *max_double.get_if()); + { + auto message = DynamicParseTextProto( + &arena, R"pb()pb", GetTestingDescriptorPool(), + GetTestingMessageFactory()); + const auto* field = ABSL_DIE_IF_NULL( + message->GetDescriptor()->FindFieldByName("repeated_int32")); + Value value(ParsedRepeatedFieldValue{message, field, &arena}); + Value other_value = value; + EXPECT_THAT(DoGet(AsLValueRef(value)), An()); + EXPECT_THAT(DoGet(AsConstLValueRef(value)), + An()); + EXPECT_THAT(DoGet(AsRValueRef(value)), An()); + EXPECT_THAT(DoGet(AsConstRValueRef(other_value)), + An()); + } - TestGetIfNum(max_double, true); - TestGetIfNum(max_double, false); + { + auto message = DynamicParseTextProto( + &arena, R"pb()pb", GetTestingDescriptorPool(), + GetTestingMessageFactory()); + const auto* field = ABSL_DIE_IF_NULL( + message->GetDescriptor()->FindFieldByName("repeated_int32")); + Value value(ParsedRepeatedFieldValue{message, field, &arena}); + Value other_value = value; + EXPECT_THAT(DoGet(AsLValueRef(value)), + An()); + EXPECT_THAT(DoGet(AsConstLValueRef(value)), + An()); + EXPECT_THAT(DoGet(AsRValueRef(value)), + An()); + EXPECT_THAT( + DoGet(AsConstRValueRef(other_value)), + An()); + } - Value one = Value::FromDouble(1); - EXPECT_EQ(1, one.get()); - EXPECT_EQ(1, *one.get_if()); + { + Value value(MapValue{}); + Value other_value = value; + EXPECT_THAT(DoGet(AsLValueRef(value)), An()); + EXPECT_THAT(DoGet(AsConstLValueRef(value)), + An()); + EXPECT_THAT(DoGet(AsRValueRef(value)), An()); + EXPECT_THAT(DoGet(AsConstRValueRef(other_value)), + An()); + } - TestGetIfNum(one, true); - TestGetIfNum(one, true); + { + Value value(ParsedJsonMapValue{}); + Value other_value = value; + EXPECT_THAT(DoGet(AsLValueRef(value)), An()); + EXPECT_THAT(DoGet(AsConstLValueRef(value)), + An()); + EXPECT_THAT(DoGet(AsRValueRef(value)), An()); + EXPECT_THAT(DoGet(AsConstRValueRef(other_value)), + An()); + } - Value inf = Value::FromDouble(std::numeric_limits::infinity()); - EXPECT_TRUE(std::isinf(inf.get())); - EXPECT_TRUE(std::isinf(*inf.get_if())); - EXPECT_TRUE(std::isinf(inf.get())); - EXPECT_TRUE(std::isinf(*inf.get_if())); - EXPECT_TRUE(std::isinf(inf.get())); - EXPECT_TRUE(std::isinf(*inf.get_if())); -} + { + Value value(ParsedJsonMapValue{}); + Value other_value = value; + EXPECT_THAT(DoGet(AsLValueRef(value)), + An()); + EXPECT_THAT(DoGet(AsConstLValueRef(value)), + An()); + EXPECT_THAT(DoGet(AsRValueRef(value)), + An()); + EXPECT_THAT(DoGet(AsConstRValueRef(other_value)), + An()); + } -TEST(ErrorTest, FromStatus) { - Error value(internal::OutOfRangeError("hi")); - EXPECT_EQ(value.errors().size(), 1); - EXPECT_THAT(value.errors(), ::testing::Contains(testutil::EqualsProto( - internal::OutOfRangeError("hi")))); -} + { + Value value(CustomMapValue{}); + Value other_value = value; + EXPECT_THAT(DoGet(AsLValueRef(value)), An()); + EXPECT_THAT(DoGet(AsConstLValueRef(value)), + An()); + EXPECT_THAT(DoGet(AsRValueRef(value)), An()); + EXPECT_THAT(DoGet(AsConstRValueRef(other_value)), + An()); + } -TEST(ErrorTest, FromInitList) { - Error value( - {internal::OutOfRangeError("hi"), internal::OutOfRangeError("bye")}); - EXPECT_EQ(value.errors().size(), 2); - EXPECT_THAT(value.errors(), ::testing::Contains(testutil::EqualsProto( - internal::OutOfRangeError("hi")))); - EXPECT_THAT(value.errors(), ::testing::Contains(testutil::EqualsProto( - internal::OutOfRangeError("bye")))); -} + { + Value value(CustomMapValue{}); + Value other_value = value; + EXPECT_THAT(DoGet(AsLValueRef(value)), + An()); + EXPECT_THAT(DoGet(AsConstLValueRef(value)), + An()); + EXPECT_THAT(DoGet(AsRValueRef(value)), + An()); + EXPECT_THAT(DoGet(AsConstRValueRef(other_value)), + An()); + } -TEST(ErrorTest, FromRepeated) { - Error value( - {internal::OutOfRangeError("hi"), internal::OutOfRangeError("bye")}); - EXPECT_EQ(value.errors().size(), 2); - EXPECT_THAT(value.errors(), ::testing::Contains(testutil::EqualsProto( - internal::OutOfRangeError("hi")))); - EXPECT_THAT(value.errors(), ::testing::Contains(testutil::EqualsProto( - internal::OutOfRangeError("bye")))); -} + { + auto message = DynamicParseTextProto( + &arena, R"pb()pb", GetTestingDescriptorPool(), + GetTestingMessageFactory()); + const auto* field = ABSL_DIE_IF_NULL( + message->GetDescriptor()->FindFieldByName("map_int32_int32")); + Value value(ParsedMapFieldValue{message, field, &arena}); + Value other_value = value; + EXPECT_THAT(DoGet(AsLValueRef(value)), An()); + EXPECT_THAT(DoGet(AsConstLValueRef(value)), + An()); + EXPECT_THAT(DoGet(AsRValueRef(value)), An()); + EXPECT_THAT(DoGet(AsConstRValueRef(other_value)), + An()); + } -TEST(ErrorTest, OrderAgnostic) { - Error v1({internal::OutOfRangeError("hi"), internal::OutOfRangeError("bye")}); - Error v2({internal::OutOfRangeError("bye"), internal::OutOfRangeError("hi")}); - Error v3({internal::OutOfRangeError("hi"), internal::OutOfRangeError("hi")}); - EXPECT_EQ(v1, v2); - EXPECT_EQ(v1.hash_code(), v2.hash_code()); - EXPECT_NE(v1, v3); - // Technically could be equal (but likely shouldn't be); - EXPECT_NE(v1.hash_code(), v3.hash_code()); -} + { + auto message = DynamicParseTextProto( + &arena, R"pb()pb", GetTestingDescriptorPool(), + GetTestingMessageFactory()); + const auto* field = ABSL_DIE_IF_NULL( + message->GetDescriptor()->FindFieldByName("map_int32_int32")); + Value value(ParsedMapFieldValue{message, field, &arena}); + Value other_value = value; + EXPECT_THAT(DoGet(AsLValueRef(value)), + An()); + EXPECT_THAT(DoGet(AsConstLValueRef(value)), + An()); + EXPECT_THAT(DoGet(AsRValueRef(value)), + An()); + EXPECT_THAT( + DoGet(AsConstRValueRef(other_value)), + An()); + } -TEST(UnknownTest, FromId) { - Unknown value(Id(1)); - EXPECT_EQ(value.ids(), std::set({Id(1)})); -} + { + Value value(ParsedMessageValue{ + DynamicParseTextProto(&arena, R"pb()pb", + GetTestingDescriptorPool(), + GetTestingMessageFactory()), + &arena}); + Value other_value = value; + EXPECT_THAT(DoGet(AsLValueRef(value)), + An()); + EXPECT_THAT(DoGet(AsConstLValueRef(value)), + An()); + EXPECT_THAT(DoGet(AsRValueRef(value)), + An()); + EXPECT_THAT(DoGet(AsConstRValueRef(other_value)), + An()); + } -TEST(UnknownTest, FromInitList) { - Unknown value({Id(1), Id(2)}); - EXPECT_EQ(value.ids(), std::set({Id(1), Id(2)})); -} + EXPECT_THAT(DoGet(Value(NullValue())), An()); + + { + Value value(OptionalValue{}); + Value other_value = value; + EXPECT_THAT(DoGet(AsLValueRef(value)), + An()); + EXPECT_THAT(DoGet(AsConstLValueRef(value)), + An()); + EXPECT_THAT(DoGet(AsRValueRef(value)), + An()); + EXPECT_THAT(DoGet(AsConstRValueRef(other_value)), + An()); + } -TEST(UnknownTest, OrderAgnostic) { - Unknown v1({Id(1), Id(2), Id(3)}); - Unknown v2({Id(3), Id(1), Id(2)}); - Unknown v3({Id(3), Id(1)}); - EXPECT_EQ(v1, v2); - EXPECT_EQ(v1.hash_code(), v2.hash_code()); - EXPECT_NE(v1, v3); - // Technically could be equal (but likely shouldn't be); - EXPECT_NE(v1.hash_code(), v3.hash_code()); -} + { + Value value(OptionalValue{}); + Value other_value = value; + EXPECT_THAT(DoGet(AsLValueRef(value)), + An()); + EXPECT_THAT(DoGet(AsConstLValueRef(value)), + An()); + EXPECT_THAT(DoGet(AsRValueRef(value)), + An()); + EXPECT_THAT(DoGet(AsConstRValueRef(other_value)), + An()); + } -TEST(TypeTest, BasicType) { - Type value((BasicType(BasicTypeValue::kString))); - EXPECT_TRUE(value.is_basic()); - EXPECT_FALSE(value.is_object()); - EXPECT_FALSE(value.is_enum()); - EXPECT_EQ(value.basic_type().value(), BasicTypeValue::kString); - EXPECT_EQ("string", value.full_name()); -} + { + OpaqueValue value(OptionalValue{}); + OpaqueValue other_value = value; + EXPECT_THAT(DoGet(AsLValueRef(value)), + An()); + EXPECT_THAT(DoGet(AsConstLValueRef(value)), + An()); + EXPECT_THAT(DoGet(AsRValueRef(value)), + An()); + EXPECT_THAT( + DoGet(AsConstRValueRef(other_value)), + An()); + } -TEST(TypeTest, ObjectType) { - ObjectType type(google::type::Money::descriptor()); - Type value(type); - EXPECT_FALSE(value.is_basic()); - EXPECT_TRUE(value.is_object()); - EXPECT_FALSE(value.is_enum()); - EXPECT_EQ(value.object_type(), type); - EXPECT_EQ(value.full_name(), "google.type.Money"); - EXPECT_EQ(value.object_type().value()->full_name(), "google.type.Money"); + { + Value value(ParsedMessageValue{ + DynamicParseTextProto(&arena, R"pb()pb", + GetTestingDescriptorPool(), + GetTestingMessageFactory()), + &arena}); + Value other_value = value; + EXPECT_THAT(DoGet(AsLValueRef(value)), + An()); + EXPECT_THAT(DoGet(AsConstLValueRef(value)), + An()); + EXPECT_THAT(DoGet(AsRValueRef(value)), + An()); + EXPECT_THAT(DoGet(AsConstRValueRef(other_value)), + An()); + } + + { + Value value(StringValue{}); + Value other_value = value; + EXPECT_THAT(DoGet(AsLValueRef(value)), + An()); + EXPECT_THAT(DoGet(AsConstLValueRef(value)), + An()); + EXPECT_THAT(DoGet(AsRValueRef(value)), + An()); + EXPECT_THAT(DoGet(AsConstRValueRef(other_value)), + An()); + } + + { + Value value(ParsedMessageValue{ + DynamicParseTextProto(&arena, R"pb()pb", + GetTestingDescriptorPool(), + GetTestingMessageFactory()), + &arena}); + Value other_value = value; + EXPECT_THAT(DoGet(AsLValueRef(value)), + An()); + EXPECT_THAT(DoGet(AsConstLValueRef(value)), + An()); + EXPECT_THAT(DoGet(AsRValueRef(value)), + An()); + EXPECT_THAT(DoGet(AsConstRValueRef(other_value)), + An()); + } + + EXPECT_THAT(DoGet(Value(TimestampValue())), + An()); + + { + Value value(TypeValue(StringType{})); + Value other_value = value; + EXPECT_THAT(DoGet(AsLValueRef(value)), An()); + EXPECT_THAT(DoGet(AsConstLValueRef(value)), + An()); + EXPECT_THAT(DoGet(AsRValueRef(value)), An()); + EXPECT_THAT(DoGet(AsConstRValueRef(other_value)), + An()); + } + + EXPECT_THAT(DoGet(Value(UintValue())), An()); + + { + Value value(UnknownValue{}); + Value other_value = value; + EXPECT_THAT(DoGet(AsLValueRef(value)), + An()); + EXPECT_THAT(DoGet(AsConstLValueRef(value)), + An()); + EXPECT_THAT(DoGet(AsRValueRef(value)), + An()); + EXPECT_THAT(DoGet(AsConstRValueRef(other_value)), + An()); + } } -TEST(TypeTest, EnumType) { - EnumType type(google::protobuf::NullValue_descriptor()); - Type value(type); - EXPECT_FALSE(value.is_basic()); - EXPECT_FALSE(value.is_object()); - EXPECT_TRUE(value.is_enum()); - EXPECT_EQ(value.enum_type(), type); - EXPECT_EQ(value.full_name(), "google.protobuf.NullValue"); - EXPECT_EQ(value.enum_type().value()->full_name(), - "google.protobuf.NullValue"); +TEST(Value, NumericHeterogeneousEquality) { + EXPECT_EQ(IntValue(1), UintValue(1)); + EXPECT_EQ(UintValue(1), IntValue(1)); + EXPECT_EQ(IntValue(1), DoubleValue(1)); + EXPECT_EQ(DoubleValue(1), IntValue(1)); + EXPECT_EQ(UintValue(1), DoubleValue(1)); + EXPECT_EQ(DoubleValue(1), UintValue(1)); + + EXPECT_NE(IntValue(1), UintValue(2)); + EXPECT_NE(UintValue(1), IntValue(2)); + EXPECT_NE(IntValue(1), DoubleValue(2)); + EXPECT_NE(DoubleValue(1), IntValue(2)); + EXPECT_NE(UintValue(1), DoubleValue(2)); + EXPECT_NE(DoubleValue(1), UintValue(2)); } -} // namespace -} // namespace common +using ValueIteratorTest = common_internal::ValueTest<>; -namespace internal { +TEST_F(ValueIteratorTest, Empty) { + auto iterator = NewEmptyValueIterator(); + EXPECT_FALSE(iterator->HasNext()); + EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()), + StatusIs(absl::StatusCode::kFailedPrecondition)); +} -using testutil::ExpectSameType; +TEST_F(ValueIteratorTest, Empty1) { + auto iterator = NewEmptyValueIterator(); + EXPECT_THAT(iterator->Next1(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(Eq(absl::nullopt))); +} -class ValueVisitTest : public ::testing::Test { - public: - template - using ValueAdapterType = - decltype(MaybeAdapt(BaseValue::ValueAdapter(), inst_of())); - - template - using GetPtrVisitorType = - decltype(BaseValue::GetPtrVisitor, T>()(inst_of())); - - template - using GetVisitorType = - decltype(BaseValue::GetVisitor, R, T>()(inst_of())); - - using OwnedStr = BaseValue::OwnedStr; - using UnownedStr = BaseValue::UnownedStr; - using ParentOwnedStr = BaseValue::ParentOwnedStr; - - template - void TestGetPtrVisitor() { - using T = remove_reference_t())>; - ExpectSameType>(); - // Return by const ref. - ExpectSameType>(); - } - - template - void TestGetVisitor() { - using T = remove_reference_t())>; - // Return by optional. - ExpectSameType, - GetVisitorType, T>>(); - // Return by value. - ExpectSameType>(); - } - - template - void TestValueAdapter() { - using T = remove_reference_t())>; - ExpectSameType>(); - ExpectSameType>(); - } -}; - -TEST_F(ValueVisitTest, Types) { - TestValueAdapter>(); - - TestValueAdapter>(); - TestGetPtrVisitor>(); - - TestValueAdapter>(); - TestGetPtrVisitor>(); - - TestValueAdapter>(); - TestGetPtrVisitor>(); - - TestValueAdapter>(); - TestGetPtrVisitor>(); - - TestValueAdapter>(); - TestGetPtrVisitor>(); - TestGetVisitor, common::EnumValue>(); - - TestValueAdapter>(); - TestGetPtrVisitor>(); - TestGetVisitor, common::Type>(); - - TestValueAdapter>(); - TestGetPtrVisitor>(); - TestGetVisitor, common::Type>(); - - TestValueAdapter>(); - TestGetPtrVisitor>(); - TestGetVisitor, common::Type>(); - - ExpectSameType>(); - ExpectSameType>(); - ExpectSameType>(); - ExpectSameType>(); - ExpectSameType>(); - ExpectSameType>(); - - TestValueAdapter>(); - TestValueAdapter>(); - TestGetPtrVisitor>(); - TestGetPtrVisitor>(); - - TestValueAdapter>(); - TestValueAdapter>(); - TestGetPtrVisitor>(); - TestGetPtrVisitor>(); - - TestValueAdapter>(); - TestValueAdapter>(); - TestGetPtrVisitor>(); - TestGetPtrVisitor>(); - - TestValueAdapter>(); - TestValueAdapter>(); - TestValueAdapter>(); - TestValueAdapter>(); - TestValueAdapter>(); - TestValueAdapter>(); +TEST_F(ValueIteratorTest, Empty2) { + auto iterator = NewEmptyValueIterator(); + EXPECT_THAT(iterator->Next2(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(Eq(absl::nullopt))); } -} // namespace internal -} // namespace expr -} // namespace api -} // namespace google +} // namespace +} // namespace cel diff --git a/common/value_testing.cc b/common/value_testing.cc new file mode 100644 index 000000000..52240905b --- /dev/null +++ b/common/value_testing.cc @@ -0,0 +1,246 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/value_testing.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/time/time.h" +#include "common/value.h" +#include "common/value_kind.h" +#include "internal/testing.h" + +namespace cel { + +void PrintTo(const Value& value, std::ostream* os) { *os << value << "\n"; } + +namespace test { +namespace { + +using ::testing::Matcher; + +template +constexpr ValueKind ToValueKind() { + if constexpr (std::is_same_v) { + return ValueKind::kBool; + } else if constexpr (std::is_same_v) { + return ValueKind::kInt; + } else if constexpr (std::is_same_v) { + return ValueKind::kUint; + } else if constexpr (std::is_same_v) { + return ValueKind::kDouble; + } else if constexpr (std::is_same_v) { + return ValueKind::kString; + } else if constexpr (std::is_same_v) { + return ValueKind::kBytes; + } else if constexpr (std::is_same_v) { + return ValueKind::kDuration; + } else if constexpr (std::is_same_v) { + return ValueKind::kTimestamp; + } else if constexpr (std::is_same_v) { + return ValueKind::kError; + } else if constexpr (std::is_same_v) { + return ValueKind::kMap; + } else if constexpr (std::is_same_v) { + return ValueKind::kList; + } else if constexpr (std::is_same_v) { + return ValueKind::kStruct; + } else if constexpr (std::is_same_v) { + return ValueKind::kOpaque; + } else { + // Otherwise, unspecified (uninitialized value) + return ValueKind::kError; + } +} + +template +class SimpleTypeMatcherImpl : public testing::MatcherInterface { + public: + using MatcherType = Matcher; + + explicit SimpleTypeMatcherImpl(MatcherType&& matcher) + : matcher_(std::forward(matcher)) {} + + bool MatchAndExplain(const Value& v, + testing::MatchResultListener* listener) const override { + return v.Is() && + matcher_.MatchAndExplain(v.Get().NativeValue(), listener); + } + + void DescribeTo(std::ostream* os) const override { + *os << absl::StrCat("kind is ", ValueKindToString(ToValueKind()), + " and "); + matcher_.DescribeTo(os); + } + + private: + MatcherType matcher_; +}; + +template +class StringTypeMatcherImpl : public testing::MatcherInterface { + public: + using MatcherType = Matcher; + + explicit StringTypeMatcherImpl(MatcherType matcher) + : matcher_((std::move(matcher))) {} + + bool MatchAndExplain(const Value& v, + testing::MatchResultListener* listener) const override { + return v.Is() && matcher_.Matches(v.Get().ToString()); + } + + void DescribeTo(std::ostream* os) const override { + *os << absl::StrCat("kind is ", ValueKindToString(ToValueKind()), + " and "); + matcher_.DescribeTo(os); + } + + private: + MatcherType matcher_; +}; + +template +class AbstractTypeMatcherImpl : public testing::MatcherInterface { + public: + using MatcherType = Matcher; + + explicit AbstractTypeMatcherImpl(MatcherType&& matcher) + : matcher_(std::forward(matcher)) {} + + bool MatchAndExplain(const Value& v, + testing::MatchResultListener* listener) const override { + return v.Is() && matcher_.Matches(v.template Get()); + } + + void DescribeTo(std::ostream* os) const override { + *os << absl::StrCat("kind is ", ValueKindToString(ToValueKind()), + " and "); + matcher_.DescribeTo(os); + } + + private: + MatcherType matcher_; +}; + +class OptionalValueMatcherImpl + : public testing::MatcherInterface { + public: + explicit OptionalValueMatcherImpl(ValueMatcher matcher) + : matcher_(std::move(matcher)) {} + + bool MatchAndExplain(const Value& v, + testing::MatchResultListener* listener) const override { + if (!v.IsOptional()) { + *listener << "wanted OptionalValue, got " << ValueKindToString(v.kind()); + return false; + } + const auto& optional_value = v.GetOptional(); + if (!optional_value.HasValue()) { + *listener << "OptionalValue is not engaged"; + return false; + } + return matcher_.MatchAndExplain(optional_value.Value(), listener); + } + + void DescribeTo(std::ostream* os) const override { + *os << "is OptionalValue that is engaged with value whose "; + matcher_.DescribeTo(os); + } + + private: + ValueMatcher matcher_; +}; + +MATCHER(OptionalValueIsEmptyImpl, "is empty OptionalValue") { + const Value& v = arg; + if (!v.IsOptional()) { + *result_listener << "wanted OptionalValue, got " + << ValueKindToString(v.kind()); + return false; + } + const auto& optional_value = v.GetOptional(); + *result_listener << (optional_value.HasValue() ? "is not empty" : "is empty"); + return !optional_value.HasValue(); +} + +} // namespace + +ValueMatcher BoolValueIs(Matcher m) { + return ValueMatcher(new SimpleTypeMatcherImpl(std::move(m))); +} + +ValueMatcher IntValueIs(Matcher m) { + return ValueMatcher( + new SimpleTypeMatcherImpl(std::move(m))); +} + +ValueMatcher UintValueIs(Matcher m) { + return ValueMatcher( + new SimpleTypeMatcherImpl(std::move(m))); +} + +ValueMatcher DoubleValueIs(Matcher m) { + return ValueMatcher( + new SimpleTypeMatcherImpl(std::move(m))); +} + +ValueMatcher TimestampValueIs(Matcher m) { + return ValueMatcher( + new SimpleTypeMatcherImpl(std::move(m))); +} + +ValueMatcher DurationValueIs(Matcher m) { + return ValueMatcher( + new SimpleTypeMatcherImpl(std::move(m))); +} + +ValueMatcher ErrorValueIs(Matcher m) { + return ValueMatcher( + new SimpleTypeMatcherImpl(std::move(m))); +} + +ValueMatcher StringValueIs(Matcher m) { + return ValueMatcher(new StringTypeMatcherImpl(std::move(m))); +} + +ValueMatcher BytesValueIs(Matcher m) { + return ValueMatcher(new StringTypeMatcherImpl(std::move(m))); +} + +ValueMatcher MapValueIs(Matcher m) { + return ValueMatcher(new AbstractTypeMatcherImpl(std::move(m))); +} + +ValueMatcher ListValueIs(Matcher m) { + return ValueMatcher(new AbstractTypeMatcherImpl(std::move(m))); +} + +ValueMatcher StructValueIs(Matcher m) { + return ValueMatcher(new AbstractTypeMatcherImpl(std::move(m))); +} + +ValueMatcher OptionalValueIs(ValueMatcher m) { + return ValueMatcher(new OptionalValueMatcherImpl(std::move(m))); +} + +ValueMatcher OptionalValueIsEmpty() { return OptionalValueIsEmptyImpl(); } + +} // namespace test + +} // namespace cel diff --git a/common/value_testing.h b/common/value_testing.h new file mode 100644 index 000000000..f870712b9 --- /dev/null +++ b/common/value_testing.h @@ -0,0 +1,307 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUE_TESTING_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_VALUE_TESTING_H_ + +#include +#include +#include +#include +#include +#include + +#include "google/protobuf/struct.pb.h" +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/log/die_if_null.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "common/value.h" +#include "common/value_kind.h" +#include "internal/equals_text_proto.h" +#include "internal/parse_text_proto.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "internal/testing_message_factory.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel { + +// GTest Printer +void PrintTo(const Value& value, std::ostream* os); + +namespace test { + +using ValueMatcher = testing::Matcher; + +MATCHER_P(ValueKindIs, m, "") { + return ExplainMatchResult(m, arg.kind(), result_listener); +} + +// Returns a matcher for CEL null value. +inline ValueMatcher IsNullValue() { return ValueKindIs(ValueKind::kNull); } + +// Returns a matcher for CEL bool values. +ValueMatcher BoolValueIs(testing::Matcher m); + +// Returns a matcher for CEL int values. +ValueMatcher IntValueIs(testing::Matcher m); + +// Returns a matcher for CEL uint values. +ValueMatcher UintValueIs(testing::Matcher m); + +// Returns a matcher for CEL double values. +ValueMatcher DoubleValueIs(testing::Matcher m); + +// Returns a matcher for CEL duration values. +ValueMatcher DurationValueIs(testing::Matcher m); + +// Returns a matcher for CEL timestamp values. +ValueMatcher TimestampValueIs(testing::Matcher m); + +// Returns a matcher for CEL error values. +ValueMatcher ErrorValueIs(testing::Matcher m); + +// Returns a matcher for CEL string values. +ValueMatcher StringValueIs(testing::Matcher m); + +// Returns a matcher for CEL bytes values. +ValueMatcher BytesValueIs(testing::Matcher m); + +// Returns a matcher for CEL map values. +ValueMatcher MapValueIs(testing::Matcher m); + +// Returns a matcher for CEL list values. +ValueMatcher ListValueIs(testing::Matcher m); + +// Returns a matcher for CEL struct values. +ValueMatcher StructValueIs(testing::Matcher m); + +// Returns a matcher for CEL struct values. +ValueMatcher OptionalValueIsEmpty(); + +// Returns a matcher for CEL struct values. +ValueMatcher OptionalValueIs(ValueMatcher m); + +// Returns a Matcher that tests the value of a CEL struct's field. +// ValueManager* mgr must remain valid for the lifetime of the matcher. +MATCHER_P5(StructValueFieldIs, name, m, descriptor_pool, message_factory, arena, + "") { + auto wrapped_m = ::absl_testing::IsOkAndHolds(m); + + return ExplainMatchResult(wrapped_m, + cel::StructValue(arg).GetFieldByName( + name, descriptor_pool, message_factory, arena), + result_listener); +} + +// Returns a Matcher that tests the presence of a CEL struct's field. +// ValueManager* mgr must remain valid for the lifetime of the matcher. +MATCHER_P2(StructValueFieldHas, name, m, "") { + auto wrapped_m = ::absl_testing::IsOkAndHolds(m); + + return ExplainMatchResult( + wrapped_m, cel::StructValue(arg).HasFieldByName(name), result_listener); +} + +class ListValueElementsMatcher { + public: + using is_gtest_matcher = void; + + explicit ListValueElementsMatcher( + testing::Matcher>&& m, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool + ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::MessageFactory* absl_nonnull message_factory + ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND) + : m_(std::move(m)), + descriptor_pool_(ABSL_DIE_IF_NULL(descriptor_pool)), // Crash OK + message_factory_(ABSL_DIE_IF_NULL(message_factory)), // Crash OK + arena_(ABSL_DIE_IF_NULL(arena)) // Crash OK + {} + + bool MatchAndExplain(const ListValue& arg, + testing::MatchResultListener* result_listener) const { + std::vector elements; + absl::Status s = arg.ForEach( + [&](const Value& v) -> absl::StatusOr { + elements.push_back(v); + return true; + }, + descriptor_pool_, message_factory_, arena_); + if (!s.ok()) { + *result_listener << "cannot convert to list of values: " << s; + return false; + } + return m_.MatchAndExplain(elements, result_listener); + } + + void DescribeTo(std::ostream* os) const { *os << m_; } + void DescribeNegationTo(std::ostream* os) const { *os << m_; } + + private: + testing::Matcher> m_; + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool_; + google::protobuf::MessageFactory* absl_nonnull message_factory_; + google::protobuf::Arena* absl_nonnull arena_; +}; + +// Returns a matcher that tests the elements of a cel::ListValue on a given +// matcher as if they were a std::vector. +// ValueManager* mgr must remain valid for the lifetime of the matcher. +inline ListValueElementsMatcher ListValueElements( + testing::Matcher>&& m, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool + ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::MessageFactory* absl_nonnull message_factory + ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND) { + return ListValueElementsMatcher(std::move(m), descriptor_pool, + message_factory, arena); +} + +class MapValueElementsMatcher { + public: + using is_gtest_matcher = void; + + explicit MapValueElementsMatcher( + testing::Matcher>>&& m, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool + ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::MessageFactory* absl_nonnull message_factory + ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND) + : m_(std::move(m)), + descriptor_pool_(ABSL_DIE_IF_NULL(descriptor_pool)), // Crash OK + message_factory_(ABSL_DIE_IF_NULL(message_factory)), // Crash OK + arena_(ABSL_DIE_IF_NULL(arena)) // Crash OK + {} + + bool MatchAndExplain(const MapValue& arg, + testing::MatchResultListener* result_listener) const { + std::vector> elements; + absl::Status s = arg.ForEach( + [&](const Value& key, const Value& value) -> absl::StatusOr { + elements.push_back({key, value}); + return true; + }, + descriptor_pool_, message_factory_, arena_); + if (!s.ok()) { + *result_listener << "cannot convert to list of values: " << s; + return false; + } + return m_.MatchAndExplain(elements, result_listener); + } + + void DescribeTo(std::ostream* os) const { *os << m_; } + void DescribeNegationTo(std::ostream* os) const { *os << m_; } + + private: + testing::Matcher>> m_; + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool_; + google::protobuf::MessageFactory* absl_nonnull message_factory_; + google::protobuf::Arena* absl_nonnull arena_; +}; + +// Returns a matcher that tests the elements of a cel::MapValue on a given +// matcher as if they were a std::vector>. +// ValueManager* mgr must remain valid for the lifetime of the matcher. +inline MapValueElementsMatcher MapValueElements( + testing::Matcher>>&& m, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool + ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::MessageFactory* absl_nonnull message_factory + ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND) { + return MapValueElementsMatcher(std::move(m), descriptor_pool, message_factory, + arena); +} + +} // namespace test + +} // namespace cel + +namespace cel::common_internal { + +template +class ValueTest : public ::testing::TestWithParam> { + public: + google::protobuf::Arena* absl_nonnull arena() { return &arena_; } + + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool() { + return ::cel::internal::GetTestingDescriptorPool(); + } + + google::protobuf::MessageFactory* absl_nonnull message_factory() { + return ::cel::internal::GetTestingMessageFactory(); + } + + google::protobuf::Message* absl_nonnull NewArenaValueMessage() { + return ABSL_DIE_IF_NULL( // Crash OK + message_factory()->GetPrototype(ABSL_DIE_IF_NULL( // Crash OK + descriptor_pool()->FindMessageTypeByName( + "google.protobuf.Value")))) + ->New(arena()); + } + + template + auto GeneratedParseTextProto(absl::string_view text = "") { + return ::cel::internal::GeneratedParseTextProto( + arena(), text, descriptor_pool(), message_factory()); + } + + template + auto DynamicParseTextProto(absl::string_view text = "") { + return ::cel::internal::DynamicParseTextProto( + arena(), text, descriptor_pool(), message_factory()); + } + + template + auto EqualsTextProto(absl::string_view text) { + return ::cel::internal::EqualsTextProto(arena(), text, descriptor_pool(), + message_factory()); + } + + auto EqualsValueTextProto(absl::string_view text) { + return EqualsTextProto(text); + } + + template + const google::protobuf::FieldDescriptor* absl_nonnull DynamicGetField( + absl::string_view name) { + return ABSL_DIE_IF_NULL( // Crash OK + ABSL_DIE_IF_NULL(descriptor_pool()->FindMessageTypeByName( // Crash OK + internal::MessageTypeNameFor())) + ->FindFieldByName(name)); + } + + template + ParsedMessageValue MakeParsedMessage(absl::string_view text = R"pb()pb") { + return ParsedMessageValue(DynamicParseTextProto(text), arena()); + } + + private: + google::protobuf::Arena arena_; +}; + +} // namespace cel::common_internal + +#endif // THIRD_PARTY_CEL_CPP_COMMON_VALUE_TESTING_H_ diff --git a/common/value_testing_test.cc b/common/value_testing_test.cc new file mode 100644 index 000000000..d7a7a4c07 --- /dev/null +++ b/common/value_testing_test.cc @@ -0,0 +1,279 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/value_testing.h" + +#include + +#include "gtest/gtest-spi.h" +#include "absl/status/status.h" +#include "absl/time/time.h" +#include "common/value.h" +#include "internal/testing.h" + +namespace cel::test { +namespace { + +using ::absl_testing::StatusIs; +using ::testing::_; +using ::testing::ElementsAre; +using ::testing::Truly; +using ::testing::UnorderedElementsAre; + +TEST(BoolValueIs, Match) { EXPECT_THAT(BoolValue(true), BoolValueIs(true)); } + +TEST(BoolValueIs, NoMatch) { + EXPECT_THAT(BoolValue(false), Not(BoolValueIs(true))); + EXPECT_THAT(IntValue(2), Not(BoolValueIs(true))); +} + +TEST(BoolValueIs, NonMatchMessage) { + EXPECT_NONFATAL_FAILURE( + []() { EXPECT_THAT(IntValue(42), BoolValueIs(true)); }(), + "kind is bool and is equal to true"); +} + +TEST(IntValueIs, Match) { EXPECT_THAT(IntValue(42), IntValueIs(42)); } + +TEST(IntValueIs, NoMatch) { + EXPECT_THAT(IntValue(-42), Not(IntValueIs(42))); + EXPECT_THAT(UintValue(2), Not(IntValueIs(42))); +} + +TEST(IntValueIs, NonMatchMessage) { + EXPECT_NONFATAL_FAILURE( + []() { EXPECT_THAT(UintValue(42), IntValueIs(42)); }(), + "kind is int and is equal to 42"); +} + +TEST(UintValueIs, Match) { EXPECT_THAT(UintValue(42), UintValueIs(42)); } + +TEST(UintValueIs, NoMatch) { + EXPECT_THAT(UintValue(41), Not(UintValueIs(42))); + EXPECT_THAT(IntValue(2), Not(UintValueIs(42))); +} + +TEST(UintValueIs, NonMatchMessage) { + EXPECT_NONFATAL_FAILURE( + []() { EXPECT_THAT(IntValue(42), UintValueIs(42)); }(), + "kind is uint and is equal to 42"); +} + +TEST(DoubleValueIs, Match) { + EXPECT_THAT(DoubleValue(1.2), DoubleValueIs(1.2)); +} + +TEST(DoubleValueIs, NoMatch) { + EXPECT_THAT(DoubleValue(41), Not(DoubleValueIs(1.2))); + EXPECT_THAT(IntValue(2), Not(DoubleValueIs(1.2))); +} + +TEST(DoubleValueIs, NonMatchMessage) { + EXPECT_NONFATAL_FAILURE( + []() { EXPECT_THAT(IntValue(42), DoubleValueIs(1.2)); }(), + "kind is double and is equal to 1.2"); +} + +TEST(DurationValueIs, Match) { + EXPECT_THAT(DurationValue(absl::Minutes(2)), + DurationValueIs(absl::Minutes(2))); +} + +TEST(DurationValueIs, NoMatch) { + EXPECT_THAT(DurationValue(absl::Minutes(5)), + Not(DurationValueIs(absl::Minutes(2)))); + EXPECT_THAT(IntValue(2), Not(DurationValueIs(absl::Minutes(2)))); +} + +TEST(DurationValueIs, NonMatchMessage) { + EXPECT_NONFATAL_FAILURE( + []() { EXPECT_THAT(IntValue(42), DurationValueIs(absl::Minutes(2))); }(), + "kind is duration and is equal to 2m"); +} + +TEST(TimestampValueIs, Match) { + EXPECT_THAT(TimestampValue(absl::UnixEpoch() + absl::Minutes(2)), + TimestampValueIs(absl::UnixEpoch() + absl::Minutes(2))); +} + +TEST(TimestampValueIs, NoMatch) { + EXPECT_THAT(TimestampValue(absl::UnixEpoch()), + Not(TimestampValueIs(absl::UnixEpoch() + absl::Minutes(2)))); + EXPECT_THAT(IntValue(2), + Not(TimestampValueIs(absl::UnixEpoch() + absl::Minutes(2)))); +} + +TEST(TimestampValueIs, NonMatchMessage) { + EXPECT_NONFATAL_FAILURE( + []() { + EXPECT_THAT(IntValue(42), + TimestampValueIs(absl::UnixEpoch() + absl::Minutes(2))); + }(), + "kind is timestamp and is equal to 19"); +} + +TEST(StringValueIs, Match) { + EXPECT_THAT(StringValue("hello!"), StringValueIs("hello!")); +} + +TEST(StringValueIs, NoMatch) { + EXPECT_THAT(StringValue("hello!"), Not(StringValueIs("goodbye!"))); + EXPECT_THAT(IntValue(2), Not(StringValueIs("goodbye!"))); +} + +TEST(StringValueIs, NonMatchMessage) { + EXPECT_NONFATAL_FAILURE( + []() { EXPECT_THAT(IntValue(42), StringValueIs("hello!")); }(), + "kind is string and is equal to \"hello!\""); +} + +TEST(BytesValueIs, Match) { + EXPECT_THAT(BytesValue("hello!"), BytesValueIs("hello!")); +} + +TEST(BytesValueIs, NoMatch) { + EXPECT_THAT(BytesValue("hello!"), Not(BytesValueIs("goodbye!"))); + EXPECT_THAT(IntValue(2), Not(BytesValueIs("goodbye!"))); +} + +TEST(BytesValueIs, NonMatchMessage) { + EXPECT_NONFATAL_FAILURE( + []() { EXPECT_THAT(IntValue(42), BytesValueIs("hello!")); }(), + "kind is bytes and is equal to \"hello!\""); +} + +TEST(ErrorValueIs, Match) { + EXPECT_THAT(ErrorValue(absl::InternalError("test")), + ErrorValueIs(StatusIs(absl::StatusCode::kInternal, "test"))); +} + +TEST(ErrorValueIs, NoMatch) { + EXPECT_THAT(ErrorValue(absl::UnknownError("test")), + Not(ErrorValueIs(StatusIs(absl::StatusCode::kInternal, "test")))); + EXPECT_THAT(IntValue(2), Not(ErrorValueIs(_))); +} + +TEST(ErrorValueIs, NonMatchMessage) { + EXPECT_NONFATAL_FAILURE( + []() { + EXPECT_THAT(IntValue(42), ErrorValueIs(StatusIs( + absl::StatusCode::kInternal, "test"))); + }(), + "kind is *error* and"); +} + +using ValueMatcherTest = common_internal::ValueTest<>; + +TEST_F(ValueMatcherTest, OptionalValueIsMatch) { + EXPECT_THAT(OptionalValue::Of(IntValue(42), arena()), + OptionalValueIs(IntValueIs(42))); +} + +TEST_F(ValueMatcherTest, OptionalValueIsHeldValueDifferent) { + EXPECT_NONFATAL_FAILURE( + [&]() { + EXPECT_THAT(OptionalValue::Of(IntValue(-42), arena()), + OptionalValueIs(IntValueIs(42))); + }(), + "is OptionalValue that is engaged with value whose kind is int and is " + "equal to 42"); +} + +TEST_F(ValueMatcherTest, OptionalValueIsNotEngaged) { + EXPECT_NONFATAL_FAILURE( + [&]() { + EXPECT_THAT(OptionalValue::None(), OptionalValueIs(IntValueIs(42))); + }(), + "is not engaged"); +} + +TEST_F(ValueMatcherTest, OptionalValueIsNotAnOptional) { + EXPECT_NONFATAL_FAILURE( + [&]() { EXPECT_THAT(IntValue(42), OptionalValueIs(IntValueIs(42))); }(), + "wanted OptionalValue, got int"); +} + +TEST_F(ValueMatcherTest, OptionalValueIsEmptyMatch) { + EXPECT_THAT(OptionalValue::None(), OptionalValueIsEmpty()); +} + +TEST_F(ValueMatcherTest, OptionalValueIsEmptyNotEmpty) { + EXPECT_NONFATAL_FAILURE( + [&]() { + EXPECT_THAT(OptionalValue::Of(IntValue(42), arena()), + OptionalValueIsEmpty()); + }(), + "is not empty"); +} + +TEST_F(ValueMatcherTest, OptionalValueIsEmptyNotOptional) { + EXPECT_NONFATAL_FAILURE( + [&]() { EXPECT_THAT(IntValue(42), OptionalValueIsEmpty()); }(), + "wanted OptionalValue, got int"); +} + +TEST_F(ValueMatcherTest, ListMatcherBasic) { + auto builder = NewListValueBuilder(arena()); + + ASSERT_OK(builder->Add(IntValue(42))); + + Value list_value = std::move(*builder).Build(); + + EXPECT_THAT(list_value, ListValueIs(Truly([](const ListValue& v) { + auto size = v.Size(); + return size.ok() && *size == 1; + }))); +} + +TEST_F(ValueMatcherTest, ListMatcherMatchesElements) { + auto builder = NewListValueBuilder(arena()); + ASSERT_OK(builder->Add(IntValue(42))); + ASSERT_OK(builder->Add(IntValue(1337))); + ASSERT_OK(builder->Add(IntValue(42))); + ASSERT_OK(builder->Add(IntValue(100))); + EXPECT_THAT(std::move(*builder).Build(), + ListValueIs(ListValueElements( + ElementsAre(IntValueIs(42), IntValueIs(1337), IntValueIs(42), + IntValueIs(100)), + descriptor_pool(), message_factory(), arena()))); +} + +TEST_F(ValueMatcherTest, MapMatcherBasic) { + auto builder = NewMapValueBuilder(arena()); + + ASSERT_OK(builder->Put(IntValue(42), IntValue(42))); + + Value map_value = std::move(*builder).Build(); + + EXPECT_THAT(map_value, MapValueIs(Truly([](const MapValue& v) { + auto size = v.Size(); + return size.ok() && *size == 1; + }))); +} + +TEST_F(ValueMatcherTest, MapMatcherMatchesElements) { + auto builder = NewMapValueBuilder(arena()); + + ASSERT_OK(builder->Put(IntValue(42), StringValue("answer"))); + ASSERT_OK(builder->Put(IntValue(1337), StringValue("leet"))); + EXPECT_THAT( + std::move(*builder).Build(), + MapValueIs(MapValueElements( + UnorderedElementsAre(Pair(IntValueIs(42), StringValueIs("answer")), + Pair(IntValueIs(1337), StringValueIs("leet"))), + descriptor_pool(), message_factory(), arena()))); +} + +} // namespace +} // namespace cel::test diff --git a/common/values/bool_value.cc b/common/values/bool_value.cc new file mode 100644 index 000000000..07854e0f5 --- /dev/null +++ b/common/values/bool_value.cc @@ -0,0 +1,97 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "google/protobuf/wrappers.pb.h" +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "common/value.h" +#include "internal/status_macros.h" +#include "internal/well_known_types.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/message.h" + +namespace cel { + +namespace { + +using ::cel::well_known_types::ValueReflection; + +std::string BoolDebugString(bool value) { return value ? "true" : "false"; } + +} // namespace + +std::string BoolValue::DebugString() const { + return BoolDebugString(NativeValue()); +} + +absl::Status BoolValue::SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(output != nullptr); + + google::protobuf::BoolValue message; + message.set_value(NativeValue()); + if (!message.SerializePartialToZeroCopyStream(output)) { + return absl::UnknownError( + absl::StrCat("failed to serialize message: ", message.GetTypeName())); + } + + return absl::OkStatus(); +} + +absl::Status BoolValue::ConvertToJson( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(json != nullptr); + ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE); + + ValueReflection value_reflection; + CEL_RETURN_IF_ERROR(value_reflection.Initialize(json->GetDescriptor())); + value_reflection.SetBoolValue(json, NativeValue()); + + return absl::OkStatus(); +} + +absl::Status BoolValue::Equal( + const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + if (auto other_value = other.AsBool(); other_value.has_value()) { + *result = BoolValue{NativeValue() == other_value->NativeValue()}; + return absl::OkStatus(); + } + *result = FalseValue(); + return absl::OkStatus(); +} + +} // namespace cel diff --git a/common/values/bool_value.h b/common/values/bool_value.h new file mode 100644 index 000000000..58fb26ebc --- /dev/null +++ b/common/values/bool_value.h @@ -0,0 +1,111 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/value.h" +// IWYU pragma: friend "common/value.h" + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_BOOL_VALUE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_BOOL_VALUE_H_ + +#include +#include + +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "common/type.h" +#include "common/value_kind.h" +#include "common/values/values.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/message.h" + +namespace cel { + +class Value; +class BoolValue; + +// `BoolValue` represents values of the primitive `bool` type. +class BoolValue final : private common_internal::ValueMixin { + public: + static constexpr ValueKind kKind = ValueKind::kBool; + + BoolValue() = default; + BoolValue(const BoolValue&) = default; + BoolValue(BoolValue&&) = default; + BoolValue& operator=(const BoolValue&) = default; + BoolValue& operator=(BoolValue&&) = default; + + explicit BoolValue(bool value) noexcept : value_(value) {} + + // NOLINTNEXTLINE(google-explicit-constructor) + operator bool() const noexcept { return value_; } + + ValueKind kind() const { return kKind; } + + absl::string_view GetTypeName() const { return BoolType::kName; } + + std::string DebugString() const; + + // See Value::SerializeTo(). + absl::Status SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const; + + // See Value::ConvertToJson(). + absl::Status ConvertToJson( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const; + + absl::Status Equal(const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const; + using ValueMixin::Equal; + + bool IsZeroValue() const { return NativeValue() == false; } + + bool NativeValue() const { return static_cast(*this); } + + friend void swap(BoolValue& lhs, BoolValue& rhs) noexcept { + using std::swap; + swap(lhs.value_, rhs.value_); + } + + private: + friend class common_internal::ValueMixin; + + bool value_ = false; +}; + +template +H AbslHashValue(H state, BoolValue value) { + return H::combine(std::move(state), value.NativeValue()); +} + +inline std::ostream& operator<<(std::ostream& out, BoolValue value) { + return out << value.DebugString(); +} + +inline BoolValue FalseValue() noexcept { return BoolValue(false); } + +inline BoolValue TrueValue() noexcept { return BoolValue(true); } + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_BOOL_VALUE_H_ diff --git a/common/values/bool_value_test.cc b/common/values/bool_value_test.cc new file mode 100644 index 000000000..5f679627c --- /dev/null +++ b/common/values/bool_value_test.cc @@ -0,0 +1,80 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "absl/hash/hash.h" +#include "absl/status/status_matchers.h" +#include "common/native_type.h" +#include "common/value.h" +#include "common/value_testing.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; + +using BoolValueTest = common_internal::ValueTest<>; + +TEST_F(BoolValueTest, Kind) { + EXPECT_EQ(BoolValue(true).kind(), BoolValue::kKind); + EXPECT_EQ(Value(BoolValue(true)).kind(), BoolValue::kKind); +} + +TEST_F(BoolValueTest, DebugString) { + { + std::ostringstream out; + out << BoolValue(true); + EXPECT_EQ(out.str(), "true"); + } + { + std::ostringstream out; + out << Value(BoolValue(true)); + EXPECT_EQ(out.str(), "true"); + } +} + +TEST_F(BoolValueTest, ConvertToJson) { + auto* message = NewArenaValueMessage(); + EXPECT_THAT(BoolValue(false).ConvertToJson(descriptor_pool(), + message_factory(), message), + IsOk()); + EXPECT_THAT(*message, EqualsValueTextProto(R"pb(bool_value: false)pb")); +} + +TEST_F(BoolValueTest, NativeTypeId) { + EXPECT_EQ(NativeTypeId::Of(BoolValue(true)), NativeTypeId::For()); + EXPECT_EQ(NativeTypeId::Of(Value(BoolValue(true))), + NativeTypeId::For()); +} + +TEST_F(BoolValueTest, HashValue) { + EXPECT_EQ(absl::HashOf(BoolValue(true)), absl::HashOf(true)); +} + +TEST_F(BoolValueTest, Equality) { + EXPECT_NE(BoolValue(false), true); + EXPECT_NE(true, BoolValue(false)); + EXPECT_NE(BoolValue(false), BoolValue(true)); +} + +TEST_F(BoolValueTest, LessThan) { + EXPECT_LT(BoolValue(false), true); + EXPECT_LT(false, BoolValue(true)); + EXPECT_LT(BoolValue(false), BoolValue(true)); +} + +} // namespace +} // namespace cel diff --git a/common/values/bytes_value.cc b/common/values/bytes_value.cc new file mode 100644 index 000000000..c9fc32ac2 --- /dev/null +++ b/common/values/bytes_value.cc @@ -0,0 +1,194 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "google/protobuf/wrappers.pb.h" +#include "absl/base/nullability.h" +#include "absl/functional/overload.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/strings/cord.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "common/internal/byte_string.h" +#include "common/value.h" +#include "internal/status_macros.h" +#include "internal/strings.h" +#include "internal/well_known_types.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/message.h" + +namespace cel { + +namespace { + +using ::cel::well_known_types::ValueReflection; + +template +std::string BytesDebugString(const Bytes& value) { + return value.NativeValue(absl::Overload( + [](absl::string_view string) -> std::string { + return internal::FormatBytesLiteral(string); + }, + [](const absl::Cord& cord) -> std::string { + if (auto flat = cord.TryFlat(); flat.has_value()) { + return internal::FormatBytesLiteral(*flat); + } + return internal::FormatBytesLiteral(static_cast(cord)); + })); +} + +} // namespace + +BytesValue BytesValue::Concat(const BytesValue& lhs, const BytesValue& rhs, + google::protobuf::Arena* absl_nonnull arena) { + return BytesValue( + common_internal::ByteString::Concat(lhs.value_, rhs.value_, arena)); +} + +std::string BytesValue::DebugString() const { return BytesDebugString(*this); } + +absl::Status BytesValue::SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(output != nullptr); + + google::protobuf::BytesValue message; + message.set_value(NativeString()); + if (!message.SerializePartialToZeroCopyStream(output)) { + return absl::UnknownError( + absl::StrCat("failed to serialize message: ", message.GetTypeName())); + } + + return absl::OkStatus(); +} + +absl::Status BytesValue::ConvertToJson( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(json != nullptr); + ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE); + + ValueReflection value_reflection; + CEL_RETURN_IF_ERROR(value_reflection.Initialize(json->GetDescriptor())); + NativeValue([&](const auto& value) { + value_reflection.SetStringValueFromBytes(json, value); + }); + + return absl::OkStatus(); +} + +absl::Status BytesValue::Equal( + const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + if (auto other_value = other.AsBytes(); other_value.has_value()) { + *result = NativeValue([other_value](const auto& value) -> BoolValue { + return other_value->NativeValue( + [&value](const auto& other_value) -> BoolValue { + return BoolValue{value == other_value}; + }); + }); + return absl::OkStatus(); + } + *result = FalseValue(); + return absl::OkStatus(); +} + +BytesValue BytesValue::Clone(google::protobuf::Arena* absl_nonnull arena) const { + return BytesValue(value_.Clone(arena)); +} + +size_t BytesValue::Size() const { + return NativeValue( + [](const auto& alternative) -> size_t { return alternative.size(); }); +} + +bool BytesValue::IsEmpty() const { + return NativeValue( + [](const auto& alternative) -> bool { return alternative.empty(); }); +} + +bool BytesValue::Equals(absl::string_view bytes) const { + return NativeValue([bytes](const auto& alternative) -> bool { + return alternative == bytes; + }); +} + +bool BytesValue::Equals(const absl::Cord& bytes) const { + return NativeValue([&bytes](const auto& alternative) -> bool { + return alternative == bytes; + }); +} + +bool BytesValue::Equals(const BytesValue& bytes) const { + return bytes.NativeValue( + [this](const auto& alternative) -> bool { return Equals(alternative); }); +} + +namespace { + +int CompareImpl(absl::string_view lhs, absl::string_view rhs) { + return lhs.compare(rhs); +} + +int CompareImpl(absl::string_view lhs, const absl::Cord& rhs) { + return -rhs.Compare(lhs); +} + +int CompareImpl(const absl::Cord& lhs, absl::string_view rhs) { + return lhs.Compare(rhs); +} + +int CompareImpl(const absl::Cord& lhs, const absl::Cord& rhs) { + return lhs.Compare(rhs); +} + +} // namespace + +int BytesValue::Compare(absl::string_view bytes) const { + return NativeValue([bytes](const auto& alternative) -> int { + return CompareImpl(alternative, bytes); + }); +} + +int BytesValue::Compare(const absl::Cord& bytes) const { + return NativeValue([&bytes](const auto& alternative) -> int { + return CompareImpl(alternative, bytes); + }); +} + +int BytesValue::Compare(const BytesValue& bytes) const { + return bytes.NativeValue( + [this](const auto& alternative) -> int { return Compare(alternative); }); +} + +} // namespace cel diff --git a/common/values/bytes_value.h b/common/values/bytes_value.h new file mode 100644 index 000000000..c18381a6a --- /dev/null +++ b/common/values/bytes_value.h @@ -0,0 +1,338 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/value.h" +// IWYU pragma: friend "common/value.h" + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_BYTES_VALUE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_BYTES_VALUE_H_ + +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/strings/cord.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "common/allocator.h" +#include "common/arena.h" +#include "common/internal/byte_string.h" +#include "common/memory.h" +#include "common/type.h" +#include "common/value_kind.h" +#include "common/values/values.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/message.h" + +namespace cel { + +class Value; +class BytesValue; +class BytesValueInputStream; +class BytesValueOutputStream; + +namespace common_internal { +absl::string_view LegacyBytesValue(const BytesValue& value, bool stable, + google::protobuf::Arena* absl_nonnull arena); +} // namespace common_internal + +// `BytesValue` represents values of the primitive `bytes` type. +class BytesValue final : private common_internal::ValueMixin { + public: + static constexpr ValueKind kKind = ValueKind::kBytes; + + static BytesValue From(const char* absl_nullable value, + google::protobuf::Arena* absl_nonnull arena + ABSL_ATTRIBUTE_LIFETIME_BOUND); + static BytesValue From(absl::string_view value, + google::protobuf::Arena* absl_nonnull arena + ABSL_ATTRIBUTE_LIFETIME_BOUND); + static BytesValue From(const absl::Cord& value); + static BytesValue From(std::string&& value, + google::protobuf::Arena* absl_nonnull arena + ABSL_ATTRIBUTE_LIFETIME_BOUND); + + static BytesValue Wrap(absl::string_view value, + google::protobuf::Arena* absl_nullable arena + ABSL_ATTRIBUTE_LIFETIME_BOUND); + static BytesValue Wrap(absl::string_view value) = delete; + static BytesValue Wrap(const absl::Cord& value); + static BytesValue Wrap(std::string&& value) = delete; + static BytesValue Wrap(std::string&& value, + google::protobuf::Arena* absl_nullable arena + ABSL_ATTRIBUTE_LIFETIME_BOUND) = delete; + + // Returns a BytesValue that aliases the provided string. Caller must ensure + // the provided string outlives the use of the returned BytesValue. + static BytesValue WrapUnsafe(absl::string_view value); + + static BytesValue Concat(const BytesValue& lhs, const BytesValue& rhs, + google::protobuf::Arena* absl_nonnull arena + ABSL_ATTRIBUTE_LIFETIME_BOUND); + + ABSL_DEPRECATED("Use From") + explicit BytesValue(const char* absl_nullable value) : value_(value) {} + + ABSL_DEPRECATED("Use From") + explicit BytesValue(absl::string_view value) : value_(value) {} + + ABSL_DEPRECATED("Use From") + explicit BytesValue(const absl::Cord& value) : value_(value) {} + + ABSL_DEPRECATED("Use From") + explicit BytesValue(std::string&& value) : value_(std::move(value)) {} + + ABSL_DEPRECATED("Use From") + BytesValue(Allocator<> allocator, const char* absl_nullable value) + : value_(allocator, value) {} + + ABSL_DEPRECATED("Use From") + BytesValue(Allocator<> allocator, absl::string_view value) + : value_(allocator, value) {} + + ABSL_DEPRECATED("Use From") + BytesValue(Allocator<> allocator, const absl::Cord& value) + : value_(allocator, value) {} + + ABSL_DEPRECATED("Use From") + BytesValue(Allocator<> allocator, std::string&& value) + : value_(allocator, std::move(value)) {} + + ABSL_DEPRECATED("Use Wrap") + BytesValue(Borrower borrower, absl::string_view value) + : value_(borrower, value) {} + + ABSL_DEPRECATED("Use Wrap") + BytesValue(Borrower borrower, const absl::Cord& value) + : value_(borrower, value) {} + + BytesValue() = default; + BytesValue(const BytesValue&) = default; + BytesValue(BytesValue&&) = default; + BytesValue& operator=(const BytesValue&) = default; + BytesValue& operator=(BytesValue&&) = default; + + constexpr ValueKind kind() const { return kKind; } + + absl::string_view GetTypeName() const { return BytesType::kName; } + + std::string DebugString() const; + + // See Value::SerializeTo(). + absl::Status SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const; + + // See Value::ConvertToJson(). + absl::Status ConvertToJson( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const; + + absl::Status Equal(const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const; + using ValueMixin::Equal; + + bool IsZeroValue() const { + return NativeValue([](const auto& value) -> bool { return value.empty(); }); + } + + BytesValue Clone(google::protobuf::Arena* absl_nonnull arena) const; + + ABSL_DEPRECATED("Use ToString()") + std::string NativeString() const { return value_.ToString(); } + + ABSL_DEPRECATED("Use ToStringView()") + absl::string_view NativeString( + std::string& scratch + ABSL_ATTRIBUTE_LIFETIME_BOUND) const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return value_.ToStringView(&scratch); + } + + ABSL_DEPRECATED("Use ToCord()") + absl::Cord NativeCord() const { return value_.ToCord(); } + + template + ABSL_DEPRECATED("Use TryFlat()") + std::common_type_t< + std::invoke_result_t, + std::invoke_result_t> NativeValue(Visitor&& + visitor) + const { + return value_.Visit(std::forward(visitor)); + } + + void swap(BytesValue& other) noexcept { + using std::swap; + swap(value_, other.value_); + } + + size_t Size() const; + + bool IsEmpty() const; + + bool Equals(absl::string_view bytes) const; + bool Equals(const absl::Cord& bytes) const; + bool Equals(const BytesValue& bytes) const; + + int Compare(absl::string_view bytes) const; + int Compare(const absl::Cord& bytes) const; + int Compare(const BytesValue& bytes) const; + + absl::optional TryFlat() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return value_.TryFlat(); + } + + std::string ToString() const { return value_.ToString(); } + + void CopyToString(std::string* absl_nonnull out) const { + value_.CopyToString(out); + } + + void AppendToString(std::string* absl_nonnull out) const { + value_.AppendToString(out); + } + + absl::Cord ToCord() const { return value_.ToCord(); } + + void CopyToCord(absl::Cord* absl_nonnull out) const { + value_.CopyToCord(out); + } + + void AppendToCord(absl::Cord* absl_nonnull out) const { + value_.AppendToCord(out); + } + + absl::string_view ToStringView( + std::string* absl_nonnull scratch + ABSL_ATTRIBUTE_LIFETIME_BOUND) const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return value_.ToStringView(scratch); + } + + friend bool operator<(const BytesValue& lhs, const BytesValue& rhs) { + return lhs.value_ < rhs.value_; + } + + private: + friend class common_internal::ValueMixin; + friend class BytesValueInputStream; + friend class BytesValueOutputStream; + friend absl::string_view common_internal::LegacyBytesValue( + const BytesValue& value, bool stable, google::protobuf::Arena* absl_nonnull arena); + friend struct ArenaTraits; + + explicit BytesValue(common_internal::ByteString value) noexcept + : value_(std::move(value)) {} + + common_internal::ByteString value_; +}; + +inline void swap(BytesValue& lhs, BytesValue& rhs) noexcept { lhs.swap(rhs); } + +inline std::ostream& operator<<(std::ostream& out, const BytesValue& value) { + return out << value.DebugString(); +} + +inline bool operator==(const BytesValue& lhs, absl::string_view rhs) { + return lhs.Equals(rhs); +} + +inline bool operator==(absl::string_view lhs, const BytesValue& rhs) { + return rhs == lhs; +} + +inline bool operator!=(const BytesValue& lhs, absl::string_view rhs) { + return !lhs.Equals(rhs); +} + +inline bool operator!=(absl::string_view lhs, const BytesValue& rhs) { + return rhs != lhs; +} + +inline BytesValue BytesValue::From(const char* absl_nullable value, + google::protobuf::Arena* absl_nonnull arena + ABSL_ATTRIBUTE_LIFETIME_BOUND) { + return From(absl::NullSafeStringView(value), arena); +} + +inline BytesValue BytesValue::From(absl::string_view value, + google::protobuf::Arena* absl_nonnull arena + ABSL_ATTRIBUTE_LIFETIME_BOUND) { + ABSL_DCHECK(arena != nullptr); + + return BytesValue(arena, value); +} + +inline BytesValue BytesValue::From(const absl::Cord& value) { + return BytesValue(value); +} + +inline BytesValue BytesValue::From(std::string&& value, + google::protobuf::Arena* absl_nonnull arena + ABSL_ATTRIBUTE_LIFETIME_BOUND) { + ABSL_DCHECK(arena != nullptr); + + return BytesValue(arena, std::move(value)); +} + +inline BytesValue BytesValue::Wrap(absl::string_view value, + google::protobuf::Arena* absl_nullable arena + ABSL_ATTRIBUTE_LIFETIME_BOUND) { + ABSL_DCHECK(arena != nullptr); + + return BytesValue(Borrower::Arena(arena), value); +} + +inline BytesValue BytesValue::WrapUnsafe(absl::string_view value) { + return BytesValue(common_internal::ByteString::FromExternal(value)); +} + +inline BytesValue BytesValue::Wrap(const absl::Cord& value) { + return BytesValue(value); +} + +namespace common_internal { + +inline absl::string_view LegacyBytesValue(const BytesValue& value, bool stable, + google::protobuf::Arena* absl_nonnull arena) { + return LegacyByteString(value.value_, stable, arena); +} + +} // namespace common_internal + +template <> +struct ArenaTraits { + using constructible = std::true_type; + + static bool trivially_destructible(const BytesValue& value) { + return ArenaTraits<>::trivially_destructible(value.value_); + } +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_BYTES_VALUE_H_ diff --git a/common/values/bytes_value_input_stream.h b/common/values/bytes_value_input_stream.h new file mode 100644 index 000000000..c4224f30d --- /dev/null +++ b/common/values/bytes_value_input_stream.h @@ -0,0 +1,133 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/value.h" +// IWYU pragma: friend "common/value.h" + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_BYTES_VALUE_INPUT_STREAM_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_BYTES_VALUE_INPUT_STREAM_H_ + +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/strings/cord.h" +#include "absl/strings/string_view.h" +#include "absl/types/variant.h" +#include "absl/utility/utility.h" +#include "common/internal/byte_string.h" +#include "common/values/bytes_value.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/io/zero_copy_stream_impl_lite.h" + +namespace cel { + +class BytesValueInputStream final : public google::protobuf::io::ZeroCopyInputStream { + public: + explicit BytesValueInputStream( + const BytesValue* absl_nonnull value ABSL_ATTRIBUTE_LIFETIME_BOUND) { + Construct(value); + } + + ~BytesValueInputStream() override { AsVariant().~variant(); } + + bool Next(const void** data, int* size) override { + return absl::visit( + [&data, &size](auto& alternative) -> bool { + return alternative.Next(data, size); + }, + AsVariant()); + } + + void BackUp(int count) override { + absl::visit( + [&count](auto& alternative) -> void { alternative.BackUp(count); }, + AsVariant()); + } + + bool Skip(int count) override { + return absl::visit( + [&count](auto& alternative) -> bool { return alternative.Skip(count); }, + AsVariant()); + } + + int64_t ByteCount() const override { + return absl::visit( + [](const auto& alternative) -> int64_t { + return alternative.ByteCount(); + }, + AsVariant()); + } + + bool ReadCord(absl::Cord* cord, int count) override { + return absl::visit( + [&cord, &count](auto& alternative) -> bool { + return alternative.ReadCord(cord, count); + }, + AsVariant()); + } + + private: + using Variant = + absl::variant; + + void Construct(const BytesValue* absl_nonnull value) { + ABSL_DCHECK(value != nullptr); + + switch (value->value_.GetKind()) { + case common_internal::ByteStringKind::kSmall: + Construct(value->value_.GetSmall()); + break; + case common_internal::ByteStringKind::kMedium: + Construct(value->value_.GetMedium()); + break; + case common_internal::ByteStringKind::kLarge: + Construct(&value->value_.GetLarge()); + break; + } + } + + void Construct(absl::string_view value) { + ABSL_DCHECK_LE(value.size(), + static_cast(std::numeric_limits::max())); + ::new (static_cast(&impl_[0])) + Variant(absl::in_place_type, value.data(), + static_cast(value.size())); + } + + void Construct(const absl::Cord* absl_nonnull value) { + ::new (static_cast(&impl_[0])) + Variant(absl::in_place_type, value); + } + + void Destruct() { AsVariant().~variant(); } + + Variant& AsVariant() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return *std::launder(reinterpret_cast(&impl_[0])); + } + + const Variant& AsVariant() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return *std::launder(reinterpret_cast(&impl_[0])); + } + + alignas(Variant) char impl_[sizeof(Variant)]; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_BYTES_VALUE_INPUT_STREAM_H_ diff --git a/common/values/bytes_value_output_stream.h b/common/values/bytes_value_output_stream.h new file mode 100644 index 000000000..0773e40e7 --- /dev/null +++ b/common/values/bytes_value_output_stream.h @@ -0,0 +1,176 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/value.h" +// IWYU pragma: friend "common/value.h" + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_BYTES_VALUE_OUTPUT_STREAM_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_BYTES_VALUE_OUTPUT_STREAM_H_ + +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/functional/overload.h" +#include "absl/strings/cord.h" +#include "absl/strings/string_view.h" +#include "absl/types/variant.h" +#include "absl/utility/utility.h" +#include "common/internal/byte_string.h" +#include "common/values/bytes_value.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/io/zero_copy_stream_impl_lite.h" + +namespace cel { + +class BytesValueOutputStream final : public google::protobuf::io::ZeroCopyOutputStream { + public: + explicit BytesValueOutputStream(const BytesValue& value) + : BytesValueOutputStream(value, /*arena=*/nullptr) {} + + BytesValueOutputStream(const BytesValue& value, + google::protobuf::Arena* absl_nullable arena) { + Construct(value, arena); + } + + bool Next(void** data, int* size) override { + return absl::visit(absl::Overload( + [&data, &size](String& string) -> bool { + return string.stream.Next(data, size); + }, + [&data, &size](Cord& cord) -> bool { + return cord.Next(data, size); + }), + AsVariant()); + } + + void BackUp(int count) override { + absl::visit( + absl::Overload( + [&count](String& string) -> void { string.stream.BackUp(count); }, + [&count](Cord& cord) -> void { cord.BackUp(count); }), + AsVariant()); + } + + int64_t ByteCount() const override { + return absl::visit( + absl::Overload( + [](const String& string) -> int64_t { + return string.stream.ByteCount(); + }, + [](const Cord& cord) -> int64_t { return cord.ByteCount(); }), + AsVariant()); + } + + bool WriteAliasedRaw(const void* data, int size) override { + return absl::visit(absl::Overload( + [&data, &size](String& string) -> bool { + return string.stream.WriteAliasedRaw(data, size); + }, + [&data, &size](Cord& cord) -> bool { + return cord.WriteAliasedRaw(data, size); + }), + AsVariant()); + } + + bool AllowsAliasing() const override { + return absl::visit( + absl::Overload( + [](const String& string) -> bool { + return string.stream.AllowsAliasing(); + }, + [](const Cord& cord) -> bool { return cord.AllowsAliasing(); }), + AsVariant()); + } + + bool WriteCord(const absl::Cord& out) override { + return absl::visit( + absl::Overload( + [&out](String& string) -> bool { + return string.stream.WriteCord(out); + }, + [&out](Cord& cord) -> bool { return cord.WriteCord(out); }), + AsVariant()); + } + + BytesValue Consume() && { + return absl::visit(absl::Overload( + [](String& string) -> BytesValue { + return BytesValue(string.arena, + std::move(string.target)); + }, + [](Cord& cord) -> BytesValue { + return BytesValue(cord.Consume()); + }), + AsVariant()); + } + + private: + struct String final { + String(absl::string_view target, google::protobuf::Arena* absl_nullable arena) + : target(target), stream(&this->target), arena(arena) {} + + std::string target; + google::protobuf::io::StringOutputStream stream; + google::protobuf::Arena* absl_nullable arena; + }; + + using Cord = google::protobuf::io::CordOutputStream; + + using Variant = absl::variant; + + void Construct(const BytesValue& value, google::protobuf::Arena* absl_nullable arena) { + switch (value.value_.GetKind()) { + case common_internal::ByteStringKind::kSmall: + Construct(value.value_.GetSmall(), arena); + break; + case common_internal::ByteStringKind::kMedium: + Construct(value.value_.GetMedium(), arena); + break; + case common_internal::ByteStringKind::kLarge: + Construct(value.value_.GetLarge()); + break; + } + } + + void Construct(absl::string_view value, google::protobuf::Arena* absl_nullable arena) { + ::new (static_cast(&impl_[0])) + Variant(absl::in_place_type, value, arena); + } + + void Construct(const absl::Cord& value) { + ::new (static_cast(&impl_[0])) + Variant(absl::in_place_type, value); + } + + void Destruct() { AsVariant().~variant(); } + + Variant& AsVariant() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return *std::launder(reinterpret_cast(&impl_[0])); + } + + const Variant& AsVariant() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return *std::launder(reinterpret_cast(&impl_[0])); + } + + alignas(Variant) char impl_[sizeof(Variant)]; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_BYTES_VALUE_OUTPUT_STREAM_H_ diff --git a/common/values/bytes_value_test.cc b/common/values/bytes_value_test.cc new file mode 100644 index 000000000..58219e3a4 --- /dev/null +++ b/common/values/bytes_value_test.cc @@ -0,0 +1,256 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "google/protobuf/struct.pb.h" +#include "absl/status/status_matchers.h" +#include "absl/strings/cord.h" +#include "absl/strings/cord_test_helpers.h" +#include "absl/types/optional.h" +#include "common/native_type.h" +#include "common/value.h" +#include "common/value_testing.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::testing::An; +using ::testing::Eq; +using ::testing::NotNull; +using ::testing::Optional; + +using BytesValueTest = common_internal::ValueTest<>; + +TEST_F(BytesValueTest, Kind) { + EXPECT_EQ(BytesValue("foo").kind(), BytesValue::kKind); + EXPECT_EQ(Value(BytesValue(absl::Cord("foo"))).kind(), BytesValue::kKind); +} + +TEST_F(BytesValueTest, DebugString) { + { + std::ostringstream out; + out << BytesValue("foo"); + EXPECT_EQ(out.str(), "b\"foo\""); + } + { + std::ostringstream out; + out << BytesValue(absl::MakeFragmentedCord({"f", "o", "o"})); + EXPECT_EQ(out.str(), "b\"foo\""); + } + { + std::ostringstream out; + out << Value(BytesValue(absl::Cord("foo"))); + EXPECT_EQ(out.str(), "b\"foo\""); + } +} + +TEST_F(BytesValueTest, ConvertToJson) { + auto* message = NewArenaValueMessage(); + EXPECT_THAT(BytesValue("foo").ConvertToJson(descriptor_pool(), + message_factory(), message), + IsOk()); + EXPECT_THAT(*message, EqualsValueTextProto(R"pb(string_value: "Zm9v")pb")); +} + +TEST_F(BytesValueTest, NativeValue) { + std::string scratch; + EXPECT_EQ(BytesValue("foo").NativeString(), "foo"); + EXPECT_EQ(BytesValue("foo").NativeString(scratch), "foo"); + EXPECT_EQ(BytesValue("foo").NativeCord(), "foo"); +} + +TEST_F(BytesValueTest, TryFlat) { + EXPECT_THAT(BytesValue("foo").TryFlat(), Optional(Eq("foo"))); + EXPECT_THAT( + BytesValue(absl::MakeFragmentedCord({"Hello, World!", "World, Hello!"})) + .TryFlat(), + Eq(absl::nullopt)); +} + +TEST_F(BytesValueTest, ToString) { + EXPECT_EQ(BytesValue("foo").ToString(), "foo"); + EXPECT_EQ(BytesValue(absl::MakeFragmentedCord({"f", "o", "o"})).ToString(), + "foo"); +} + +TEST_F(BytesValueTest, CopyToString) { + std::string out; + BytesValue("foo").CopyToString(&out); + EXPECT_EQ(out, "foo"); + BytesValue(absl::MakeFragmentedCord({"f", "o", "o"})).CopyToString(&out); + EXPECT_EQ(out, "foo"); +} + +TEST_F(BytesValueTest, AppendToString) { + std::string out; + BytesValue("foo").AppendToString(&out); + EXPECT_EQ(out, "foo"); + BytesValue(absl::MakeFragmentedCord({"f", "o", "o"})).AppendToString(&out); + EXPECT_EQ(out, "foofoo"); +} + +TEST_F(BytesValueTest, ToCord) { + EXPECT_EQ(BytesValue("foo").ToCord(), "foo"); + EXPECT_EQ(BytesValue(absl::MakeFragmentedCord({"f", "o", "o"})).ToCord(), + "foo"); +} + +TEST_F(BytesValueTest, CopyToCord) { + absl::Cord out; + BytesValue("foo").CopyToCord(&out); + EXPECT_EQ(out, "foo"); + BytesValue(absl::MakeFragmentedCord({"f", "o", "o"})).CopyToCord(&out); + EXPECT_EQ(out, "foo"); +} + +TEST_F(BytesValueTest, AppendToCord) { + absl::Cord out; + BytesValue("foo").AppendToCord(&out); + EXPECT_EQ(out, "foo"); + BytesValue(absl::MakeFragmentedCord({"f", "o", "o"})).AppendToCord(&out); + EXPECT_EQ(out, "foofoo"); +} + +TEST_F(BytesValueTest, NativeTypeId) { + EXPECT_EQ(NativeTypeId::Of(BytesValue("foo")), + NativeTypeId::For()); + EXPECT_EQ(NativeTypeId::Of(Value(BytesValue(absl::Cord("foo")))), + NativeTypeId::For()); +} + +TEST_F(BytesValueTest, StringViewEquality) { + // NOLINTBEGIN(readability/check) + EXPECT_TRUE(BytesValue("foo") == "foo"); + EXPECT_FALSE(BytesValue("foo") == "bar"); + + EXPECT_TRUE("foo" == BytesValue("foo")); + EXPECT_FALSE("bar" == BytesValue("foo")); + // NOLINTEND(readability/check) +} + +TEST_F(BytesValueTest, StringViewInequality) { + // NOLINTBEGIN(readability/check) + EXPECT_FALSE(BytesValue("foo") != "foo"); + EXPECT_TRUE(BytesValue("foo") != "bar"); + + EXPECT_FALSE("foo" != BytesValue("foo")); + EXPECT_TRUE("bar" != BytesValue("foo")); + // NOLINTEND(readability/check) +} + +TEST_F(BytesValueTest, Comparison) { + EXPECT_LT(BytesValue("bar"), BytesValue("foo")); + EXPECT_FALSE(BytesValue("foo") < BytesValue("foo")); + EXPECT_FALSE(BytesValue("foo") < BytesValue("bar")); +} + +TEST_F(BytesValueTest, StringInputStream) { + BytesValue value = BytesValue("foo"); + BytesValueInputStream stream(&value); + const void* data; + int size; + absl::Cord cord; + ASSERT_TRUE(stream.Next(&data, &size)); + EXPECT_THAT(data, NotNull()); + EXPECT_EQ(size, 3); + EXPECT_EQ(stream.ByteCount(), 3); + stream.BackUp(size); + ASSERT_TRUE(stream.Skip(3)); + EXPECT_FALSE(stream.ReadCord(&cord, 3)); + EXPECT_FALSE(stream.Next(&data, &size)); +} + +TEST_F(BytesValueTest, CordInputStream) { + BytesValue value = BytesValue(absl::Cord("foo")); + BytesValueInputStream stream(&value); + const void* data; + int size; + absl::Cord cord; + ASSERT_TRUE(stream.Next(&data, &size)); + EXPECT_THAT(data, NotNull()); + EXPECT_EQ(size, 3); + EXPECT_EQ(stream.ByteCount(), 3); + stream.BackUp(size); + ASSERT_TRUE(stream.Skip(3)); + EXPECT_FALSE(stream.ReadCord(&cord, 3)); + EXPECT_FALSE(stream.Next(&data, &size)); +} + +TEST_F(BytesValueTest, ArenaStringOutputStream) { + BytesValue value = BytesValue(""); + { + BytesValueOutputStream stream(value, arena()); + EXPECT_THAT(stream.AllowsAliasing(), An()); + EXPECT_EQ(stream.ByteCount(), 0); + google::protobuf::Value value_proto; + auto* struct_proto = value_proto.mutable_struct_value(); + (*struct_proto->mutable_fields())["foo"].set_string_value("bar"); + (*struct_proto->mutable_fields())["baz"].set_number_value(3.14159); + ASSERT_TRUE(value_proto.SerializePartialToZeroCopyStream(&stream)); + EXPECT_EQ(std::move(stream).Consume(), + value_proto.SerializePartialAsString()); + } + { + BytesValueOutputStream stream(value); + EXPECT_EQ(std::move(stream).Consume(), ""); + } +} + +TEST_F(BytesValueTest, StringOutputStream) { + BytesValue value = BytesValue(""); + { + BytesValueOutputStream stream(value); + EXPECT_THAT(stream.AllowsAliasing(), An()); + EXPECT_EQ(stream.ByteCount(), 0); + google::protobuf::Value value_proto; + auto* struct_proto = value_proto.mutable_struct_value(); + (*struct_proto->mutable_fields())["foo"].set_string_value("bar"); + (*struct_proto->mutable_fields())["baz"].set_number_value(3.14159); + ASSERT_TRUE(value_proto.SerializePartialToZeroCopyStream(&stream)); + EXPECT_EQ(std::move(stream).Consume(), + value_proto.SerializePartialAsString()); + } + { + BytesValueOutputStream stream(value); + EXPECT_EQ(std::move(stream).Consume(), ""); + } +} + +TEST_F(BytesValueTest, CordOutputStream) { + BytesValue value = BytesValue(absl::Cord()); + { + BytesValueOutputStream stream(value); + EXPECT_THAT(stream.AllowsAliasing(), An()); + EXPECT_EQ(stream.ByteCount(), 0); + google::protobuf::Value value_proto; + auto* struct_proto = value_proto.mutable_struct_value(); + (*struct_proto->mutable_fields())["foo"].set_string_value("bar"); + (*struct_proto->mutable_fields())["baz"].set_number_value(3.14159); + ASSERT_TRUE(value_proto.SerializePartialToZeroCopyStream(&stream)); + EXPECT_EQ(std::move(stream).Consume(), + value_proto.SerializePartialAsString()); + } + { + BytesValueOutputStream stream(value); + EXPECT_EQ(std::move(stream).Consume(), ""); + } +} + +} // namespace +} // namespace cel diff --git a/common/values/custom_list_value.cc b/common/values/custom_list_value.cc new file mode 100644 index 000000000..fbba38cfa --- /dev/null +++ b/common/values/custom_list_value.cc @@ -0,0 +1,614 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "absl/base/no_destructor.h" +#include "absl/base/nullability.h" +#include "absl/base/optimization.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "common/casting.h" +#include "common/native_type.h" +#include "common/value.h" +#include "common/values/list_value_builder.h" +#include "common/values/values.h" +#include "eval/public/cel_value.h" +#include "internal/status_macros.h" +#include "internal/well_known_types.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/message.h" + +namespace cel { + +namespace { + +using ::cel::well_known_types::ListValueReflection; +using ::cel::well_known_types::ValueReflection; +using ::google::api::expr::runtime::CelValue; + +class EmptyListValue final : public common_internal::CompatListValue { + public: + static const EmptyListValue& Get() { + static const absl::NoDestructor empty; + return *empty; + } + + EmptyListValue() = default; + + std::string DebugString() const override { return "[]"; } + + bool IsEmpty() const override { return true; } + + size_t Size() const override { return 0; } + + absl::Status ConvertToJsonArray( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const override { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(json != nullptr); + ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE); + + json->Clear(); + return absl::OkStatus(); + } + + CustomListValue Clone(google::protobuf::Arena* absl_nonnull arena) const override { + return CustomListValue(&EmptyListValue::Get(), arena); + } + + int size() const override { return 0; } + + CelValue operator[](int index) const override { + static const absl::NoDestructor error( + absl::InvalidArgumentError("index out of bounds")); + return CelValue::CreateError(&*error); + } + + CelValue Get(google::protobuf::Arena* arena, int index) const override { + if (arena == nullptr) { + return (*this)[index]; + } + return CelValue::CreateError(google::protobuf::Arena::Create( + arena, absl::InvalidArgumentError("index out of bounds"))); + } + + private: + absl::Status Get(size_t index, const google::protobuf::DescriptorPool* absl_nonnull, + google::protobuf::MessageFactory* absl_nonnull, + google::protobuf::Arena* absl_nonnull, + Value* absl_nonnull result) const override { + *result = IndexOutOfBoundsError(index); + return absl::OkStatus(); + } +}; + +} // namespace + +namespace common_internal { + +const CompatListValue* absl_nonnull EmptyCompatListValue() { + return &EmptyListValue::Get(); +} + +} // namespace common_internal + +class CustomListValueInterfaceIterator final : public ValueIterator { + public: + explicit CustomListValueInterfaceIterator( + const CustomListValueInterface& interface) + : interface_(interface), size_(interface_.Size()) {} + + bool HasNext() override { return index_ < size_; } + + absl::Status Next(const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) override { + if (ABSL_PREDICT_FALSE(index_ >= size_)) { + return absl::FailedPreconditionError( + "ValueIterator::Next() called when " + "ValueIterator::HasNext() returns false"); + } + CEL_RETURN_IF_ERROR(interface_.Get(index_, descriptor_pool, message_factory, + arena, result)); + ++index_; + return absl::OkStatus(); + } + + absl::StatusOr Next1( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull key_or_value) override { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(key_or_value != nullptr); + + if (index_ >= size_) { + return false; + } + CEL_RETURN_IF_ERROR(interface_.Get(index_, descriptor_pool, message_factory, + arena, key_or_value)); + ++index_; + return true; + } + + absl::StatusOr Next2( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull key, + Value* absl_nullable value) override { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(key != nullptr); + + if (index_ >= size_) { + return false; + } + if (value != nullptr) { + CEL_RETURN_IF_ERROR(interface_.Get(index_, descriptor_pool, + message_factory, arena, value)); + } + *key = IntValue(index_); + ++index_; + return true; + } + + private: + const CustomListValueInterface& interface_; + const size_t size_; + size_t index_ = 0; +}; + +namespace { + +class CustomListValueDispatcherIterator final : public ValueIterator { + public: + explicit CustomListValueDispatcherIterator( + const CustomListValueDispatcher* absl_nonnull dispatcher, + CustomListValueContent content, size_t size) + : dispatcher_(dispatcher), content_(content), size_(size) {} + + bool HasNext() override { return index_ < size_; } + + absl::Status Next(const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) override { + if (ABSL_PREDICT_FALSE(index_ >= size_)) { + return absl::FailedPreconditionError( + "ValueIterator::Next() called when " + "ValueIterator::HasNext() returns false"); + } + CEL_RETURN_IF_ERROR(dispatcher_->get(dispatcher_, content_, index_, + descriptor_pool, message_factory, + arena, result)); + ++index_; + return absl::OkStatus(); + } + + absl::StatusOr Next1( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull key_or_value) override { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(key_or_value != nullptr); + + if (index_ >= size_) { + return false; + } + CEL_RETURN_IF_ERROR(dispatcher_->get(dispatcher_, content_, index_, + descriptor_pool, message_factory, + arena, key_or_value)); + ++index_; + return true; + } + + absl::StatusOr Next2( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull key, + Value* absl_nullable value) override { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(key != nullptr); + + if (index_ >= size_) { + return false; + } + if (value != nullptr) { + CEL_RETURN_IF_ERROR(dispatcher_->get(dispatcher_, content_, index_, + descriptor_pool, message_factory, + arena, value)); + } + *key = IntValue(index_); + ++index_; + return true; + } + + private: + const CustomListValueDispatcher* absl_nonnull const dispatcher_; + const CustomListValueContent content_; + const size_t size_; + size_t index_ = 0; +}; + +} // namespace + +absl::Status CustomListValueInterface::SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(output != nullptr); + + ListValueReflection reflection; + CEL_RETURN_IF_ERROR(reflection.Initialize(descriptor_pool)); + const google::protobuf::Message* prototype = + message_factory->GetPrototype(reflection.GetDescriptor()); + if (prototype == nullptr) { + return absl::UnknownError( + absl::StrCat("failed to get message prototype: ", + reflection.GetDescriptor()->full_name())); + } + google::protobuf::Arena arena; + google::protobuf::Message* message = prototype->New(&arena); + CEL_RETURN_IF_ERROR( + ConvertToJsonArray(descriptor_pool, message_factory, message)); + if (!message->SerializePartialToZeroCopyStream(output)) { + return absl::UnknownError( + "failed to serialize message: google.protobuf.ListValue"); + } + return absl::OkStatus(); +} + +absl::Status CustomListValueInterface::ForEach( + ForEachWithIndexCallback callback, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const { + const size_t size = Size(); + for (size_t index = 0; index < size; ++index) { + Value element; + CEL_RETURN_IF_ERROR( + Get(index, descriptor_pool, message_factory, arena, &element)); + CEL_ASSIGN_OR_RETURN(auto ok, callback(index, element)); + if (!ok) { + break; + } + } + return absl::OkStatus(); +} + +absl::StatusOr +CustomListValueInterface::NewIterator() const { + return std::make_unique(*this); +} + +absl::Status CustomListValueInterface::Equal( + const ListValue& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + return ListValueEqual(*this, other, descriptor_pool, message_factory, arena, + result); +} + +absl::Status CustomListValueInterface::Contains( + const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + Value outcome = BoolValue(false); + Value equal; + CEL_RETURN_IF_ERROR(ForEach( + [&](size_t index, const Value& element) -> absl::StatusOr { + CEL_RETURN_IF_ERROR(element.Equal(other, descriptor_pool, + message_factory, arena, &equal)); + if (auto bool_result = As(equal); + bool_result.has_value() && bool_result->NativeValue()) { + outcome = BoolValue(true); + return false; + } + return true; + }, + descriptor_pool, message_factory, arena)); + *result = outcome; + return absl::OkStatus(); +} + +CustomListValue::CustomListValue() { + content_ = CustomListValueContent::From(CustomListValueInterface::Content{ + .interface = &EmptyListValue::Get(), .arena = nullptr}); +} + +NativeTypeId CustomListValue::GetTypeId() const { + if (dispatcher_ == nullptr) { + CustomListValueInterface::Content content = + content_.To(); + ABSL_DCHECK(content.interface != nullptr); + return content.interface->GetNativeTypeId(); + } + return dispatcher_->get_type_id(dispatcher_, content_); +} + +absl::string_view CustomListValue::GetTypeName() const { return "list"; } + +std::string CustomListValue::DebugString() const { + if (dispatcher_ == nullptr) { + CustomListValueInterface::Content content = + content_.To(); + ABSL_DCHECK(content.interface != nullptr); + return content.interface->DebugString(); + } + if (dispatcher_->debug_string != nullptr) { + return dispatcher_->debug_string(dispatcher_, content_); + } + return "list"; +} + +absl::Status CustomListValue::SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const { + if (dispatcher_ == nullptr) { + CustomListValueInterface::Content content = + content_.To(); + ABSL_DCHECK(content.interface != nullptr); + return content.interface->SerializeTo(descriptor_pool, message_factory, + output); + } + if (dispatcher_->serialize_to != nullptr) { + return dispatcher_->serialize_to(dispatcher_, content_, descriptor_pool, + message_factory, output); + } + return absl::UnimplementedError( + absl::StrCat(GetTypeName(), " is unserializable")); +} + +absl::Status CustomListValue::ConvertToJson( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(json != nullptr); + ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE); + + ValueReflection value_reflection; + CEL_RETURN_IF_ERROR(value_reflection.Initialize(json->GetDescriptor())); + google::protobuf::Message* json_array = value_reflection.MutableListValue(json); + + return ConvertToJsonArray(descriptor_pool, message_factory, json_array); +} + +absl::Status CustomListValue::ConvertToJsonArray( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(json != nullptr); + ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE); + + if (dispatcher_ == nullptr) { + CustomListValueInterface::Content content = + content_.To(); + ABSL_DCHECK(content.interface != nullptr); + return content.interface->ConvertToJsonArray(descriptor_pool, + message_factory, json); + } + if (dispatcher_->convert_to_json_array != nullptr) { + return dispatcher_->convert_to_json_array( + dispatcher_, content_, descriptor_pool, message_factory, json); + } + return absl::UnimplementedError( + absl::StrCat(GetTypeName(), " is not convertable to JSON")); +} + +absl::Status CustomListValue::Equal( + const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + if (auto other_list_value = other.AsList(); other_list_value) { + if (dispatcher_ == nullptr) { + CustomListValueInterface::Content content = + content_.To(); + ABSL_DCHECK(content.interface != nullptr); + return content.interface->Equal(*other_list_value, descriptor_pool, + message_factory, arena, result); + } + if (dispatcher_->equal != nullptr) { + return dispatcher_->equal(dispatcher_, content_, *other_list_value, + descriptor_pool, message_factory, arena, + result); + } + return common_internal::ListValueEqual(*this, *other_list_value, + descriptor_pool, message_factory, + arena, result); + } + *result = FalseValue(); + return absl::OkStatus(); +} + +bool CustomListValue::IsZeroValue() const { + if (dispatcher_ == nullptr) { + CustomListValueInterface::Content content = + content_.To(); + ABSL_DCHECK(content.interface != nullptr); + return content.interface->IsZeroValue(); + } + return dispatcher_->is_zero_value(dispatcher_, content_); +} + +CustomListValue CustomListValue::Clone( + google::protobuf::Arena* absl_nonnull arena) const { + ABSL_DCHECK(arena != nullptr); + + if (dispatcher_ == nullptr) { + CustomListValueInterface::Content content = + content_.To(); + ABSL_DCHECK(content.interface != nullptr); + if (content.arena != arena) { + return content.interface->Clone(arena); + } + return *this; + } + return dispatcher_->clone(dispatcher_, content_, arena); +} + +bool CustomListValue::IsEmpty() const { + if (dispatcher_ == nullptr) { + CustomListValueInterface::Content content = + content_.To(); + ABSL_DCHECK(content.interface != nullptr); + return content.interface->IsEmpty(); + } + if (dispatcher_->is_empty != nullptr) { + return dispatcher_->is_empty(dispatcher_, content_); + } + return dispatcher_->size(dispatcher_, content_) == 0; +} + +size_t CustomListValue::Size() const { + if (dispatcher_ == nullptr) { + CustomListValueInterface::Content content = + content_.To(); + ABSL_DCHECK(content.interface != nullptr); + return content.interface->Size(); + } + return dispatcher_->size(dispatcher_, content_); +} + +absl::Status CustomListValue::Get( + size_t index, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + if (dispatcher_ == nullptr) { + CustomListValueInterface::Content content = + content_.To(); + ABSL_DCHECK(content.interface != nullptr); + return content.interface->Get(index, descriptor_pool, message_factory, + arena, result); + } + return dispatcher_->get(dispatcher_, content_, index, descriptor_pool, + message_factory, arena, result); +} + +absl::Status CustomListValue::ForEach( + ForEachWithIndexCallback callback, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const { + if (dispatcher_ == nullptr) { + CustomListValueInterface::Content content = + content_.To(); + ABSL_DCHECK(content.interface != nullptr); + return content.interface->ForEach(callback, descriptor_pool, + message_factory, arena); + } + if (dispatcher_->for_each != nullptr) { + return dispatcher_->for_each(dispatcher_, content_, callback, + descriptor_pool, message_factory, arena); + } + const size_t size = dispatcher_->size(dispatcher_, content_); + for (size_t index = 0; index < size; ++index) { + Value element; + CEL_RETURN_IF_ERROR(dispatcher_->get(dispatcher_, content_, index, + descriptor_pool, message_factory, + arena, &element)); + CEL_ASSIGN_OR_RETURN(auto ok, callback(index, element)); + if (!ok) { + break; + } + } + return absl::OkStatus(); +} + +absl::StatusOr CustomListValue::NewIterator() + const { + if (dispatcher_ == nullptr) { + CustomListValueInterface::Content content = + content_.To(); + ABSL_DCHECK(content.interface != nullptr); + return content.interface->NewIterator(); + } + if (dispatcher_->new_iterator != nullptr) { + return dispatcher_->new_iterator(dispatcher_, content_); + } + return std::make_unique( + dispatcher_, content_, dispatcher_->size(dispatcher_, content_)); +} + +absl::Status CustomListValue::Contains( + const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + if (dispatcher_ == nullptr) { + CustomListValueInterface::Content content = + content_.To(); + ABSL_DCHECK(content.interface != nullptr); + return content.interface->Contains(other, descriptor_pool, message_factory, + arena, result); + } + if (dispatcher_->contains != nullptr) { + return dispatcher_->contains(dispatcher_, content_, other, descriptor_pool, + message_factory, arena, result); + } + Value outcome = BoolValue(false); + Value equal; + CEL_RETURN_IF_ERROR(ForEach( + [&](size_t index, const Value& element) -> absl::StatusOr { + CEL_RETURN_IF_ERROR(element.Equal(other, descriptor_pool, + message_factory, arena, &equal)); + if (auto bool_result = As(equal); + bool_result.has_value() && bool_result->NativeValue()) { + outcome = BoolValue(true); + return false; + } + return true; + }, + descriptor_pool, message_factory, arena)); + *result = outcome; + return absl::OkStatus(); +} + +} // namespace cel diff --git a/common/values/custom_list_value.h b/common/values/custom_list_value.h new file mode 100644 index 000000000..e66eece43 --- /dev/null +++ b/common/values/custom_list_value.h @@ -0,0 +1,423 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/value.h" +// IWYU pragma: friend "common/value.h" + +// `CustomListValue` represents values of the primitive `list` type. +// `CustomListValueView` is a non-owning view of `CustomListValue`. +// `CustomListValueInterface` is the abstract base class of implementations. +// `CustomListValue` and `CustomListValueView` act as smart pointers to +// `CustomListValueInterface`. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_PARSED_LIST_VALUE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_PARSED_LIST_VALUE_H_ + +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/functional/function_ref.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "common/native_type.h" +#include "common/value_kind.h" +#include "common/values/custom_value.h" +#include "common/values/values.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/message.h" + +namespace cel { + +class Value; +class CustomListValueInterface; +class CustomListValueInterfaceIterator; +class CustomListValue; +struct CustomListValueDispatcher; +using CustomListValueContent = CustomValueContent; + +struct CustomListValueDispatcher { + using GetTypeId = + NativeTypeId (*)(const CustomListValueDispatcher* absl_nonnull dispatcher, + CustomListValueContent content); + + using GetArena = google::protobuf::Arena* absl_nullable (*)( + const CustomListValueDispatcher* absl_nonnull dispatcher, + CustomListValueContent content); + + using DebugString = + std::string (*)(const CustomListValueDispatcher* absl_nonnull dispatcher, + CustomListValueContent content); + + using SerializeTo = absl::Status (*)( + const CustomListValueDispatcher* absl_nonnull dispatcher, + CustomListValueContent content, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output); + + using ConvertToJsonArray = absl::Status (*)( + const CustomListValueDispatcher* absl_nonnull dispatcher, + CustomListValueContent content, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json); + + using Equal = absl::Status (*)( + const CustomListValueDispatcher* absl_nonnull dispatcher, + CustomListValueContent content, const ListValue& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result); + + using IsZeroValue = + bool (*)(const CustomListValueDispatcher* absl_nonnull dispatcher, + CustomListValueContent content); + + using IsEmpty = + bool (*)(const CustomListValueDispatcher* absl_nonnull dispatcher, + CustomListValueContent content); + + using Size = + size_t (*)(const CustomListValueDispatcher* absl_nonnull dispatcher, + CustomListValueContent content); + + using Get = absl::Status (*)( + const CustomListValueDispatcher* absl_nonnull dispatcher, + CustomListValueContent content, size_t index, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result); + + using ForEach = absl::Status (*)( + const CustomListValueDispatcher* absl_nonnull dispatcher, + CustomListValueContent content, + absl::FunctionRef(size_t, const Value&)> callback, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena); + + using NewIterator = absl::StatusOr (*)( + const CustomListValueDispatcher* absl_nonnull dispatcher, + CustomListValueContent content); + + using Contains = absl::Status (*)( + const CustomListValueDispatcher* absl_nonnull dispatcher, + CustomListValueContent content, const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result); + + using Clone = CustomListValue (*)( + const CustomListValueDispatcher* absl_nonnull dispatcher, + CustomListValueContent content, google::protobuf::Arena* absl_nonnull arena); + + absl_nonnull GetTypeId get_type_id; + + absl_nonnull GetArena get_arena; + + // If null, simply returns "list". + absl_nullable DebugString debug_string = nullptr; + + // If null, attempts to serialize results in an UNIMPLEMENTED error. + absl_nullable SerializeTo serialize_to = nullptr; + + // If null, attempts to convert to JSON results in an UNIMPLEMENTED error. + absl_nullable ConvertToJsonArray convert_to_json_array = nullptr; + + // If null, an nonoptimal fallback implementation for equality is used. + absl_nullable Equal equal = nullptr; + + absl_nonnull IsZeroValue is_zero_value; + + // If null, `size(...) == 0` is used. + absl_nullable IsEmpty is_empty = nullptr; + + absl_nonnull Size size; + + absl_nonnull Get get; + + // If null, a fallback implementation using `size` and `get` is used. + absl_nullable ForEach for_each = nullptr; + + // If null, a fallback implementation using `size` and `get` is used. + absl_nullable NewIterator new_iterator = nullptr; + + // If null, a fallback implementation is used. + absl_nullable Contains contains = nullptr; + + absl_nonnull Clone clone; +}; + +class CustomListValueInterface { + public: + CustomListValueInterface() = default; + CustomListValueInterface(const CustomListValueInterface&) = delete; + CustomListValueInterface(CustomListValueInterface&&) = delete; + + virtual ~CustomListValueInterface() = default; + + CustomListValueInterface& operator=(const CustomListValueInterface&) = delete; + CustomListValueInterface& operator=(CustomListValueInterface&&) = delete; + + using ForEachCallback = absl::FunctionRef(const Value&)>; + + using ForEachWithIndexCallback = + absl::FunctionRef(size_t, const Value&)>; + + private: + friend class CustomListValueInterfaceIterator; + friend class CustomListValue; + friend absl::Status common_internal::ListValueEqual( + const CustomListValueInterface& lhs, const ListValue& rhs, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result); + + virtual std::string DebugString() const = 0; + + virtual absl::Status SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const; + + virtual absl::Status ConvertToJsonArray( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const = 0; + + virtual absl::Status Equal( + const ListValue& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; + + virtual bool IsZeroValue() const { return IsEmpty(); } + + virtual bool IsEmpty() const { return Size() == 0; } + + virtual size_t Size() const = 0; + + virtual absl::Status Get( + size_t index, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const = 0; + + virtual absl::Status ForEach( + ForEachWithIndexCallback callback, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const; + + virtual absl::StatusOr NewIterator() const; + + virtual absl::Status Contains( + const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; + + virtual CustomListValue Clone(google::protobuf::Arena* absl_nonnull arena) const = 0; + + virtual NativeTypeId GetNativeTypeId() const = 0; + + struct Content { + const CustomListValueInterface* absl_nonnull interface; + const google::protobuf::Arena* absl_nullable arena; + }; +}; + +// Creates a custom list value from a manual dispatch table `dispatcher` and +// opaque data `content` whose format is only know to functions in the manual +// dispatch table. The dispatch table should probably be valid for the lifetime +// of the process, but at a minimum must outlive all instances of the resulting +// value. +// +// IMPORTANT: This approach to implementing CustomListValue should only be +// used when you know exactly what you are doing. When in doubt, just implement +// CustomListValueInterface. +CustomListValue UnsafeCustomListValue( + const CustomListValueDispatcher* absl_nonnull dispatcher + ABSL_ATTRIBUTE_LIFETIME_BOUND, + CustomListValueContent content); + +class CustomListValue final + : private common_internal::ListValueMixin { + public: + static constexpr ValueKind kKind = ValueKind::kList; + + // Constructs a custom list value from an implementation of + // `CustomListValueInterface` `interface` whose lifetime is tied to that of + // the arena `arena`. + CustomListValue(const CustomListValueInterface* absl_nonnull + interface ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::Arena* absl_nonnull arena + ABSL_ATTRIBUTE_LIFETIME_BOUND) { + ABSL_DCHECK(interface != nullptr); + ABSL_DCHECK(arena != nullptr); + content_ = CustomListValueContent::From(CustomListValueInterface::Content{ + .interface = interface, .arena = arena}); + } + + CustomListValue(); + CustomListValue(const CustomListValue&) = default; + CustomListValue(CustomListValue&&) = default; + CustomListValue& operator=(const CustomListValue&) = default; + CustomListValue& operator=(CustomListValue&&) = default; + + static constexpr ValueKind kind() { return kKind; } + + NativeTypeId GetTypeId() const; + + absl::string_view GetTypeName() const; + + std::string DebugString() const; + + // See Value::SerializeTo(). + absl::Status SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const; + + // See Value::ConvertToJson(). + absl::Status ConvertToJson( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const; + + // See Value::ConvertToJsonArray(). + absl::Status ConvertToJsonArray( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const; + + absl::Status Equal(const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const; + using ListValueMixin::Equal; + + bool IsZeroValue() const; + + CustomListValue Clone(google::protobuf::Arena* absl_nonnull arena) const; + + bool IsEmpty() const; + + size_t Size() const; + + // See ListValueInterface::Get for documentation. + absl::Status Get(size_t index, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const; + using ListValueMixin::Get; + + using ForEachCallback = typename CustomListValueInterface::ForEachCallback; + + using ForEachWithIndexCallback = + typename CustomListValueInterface::ForEachWithIndexCallback; + + absl::Status ForEach( + ForEachWithIndexCallback callback, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const; + using ListValueMixin::ForEach; + + absl::StatusOr NewIterator() const; + + absl::Status Contains( + const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; + using ListValueMixin::Contains; + + const CustomListValueDispatcher* absl_nullable dispatcher() const { + return dispatcher_; + } + + CustomListValueContent content() const { + ABSL_DCHECK(dispatcher_ != nullptr); + return content_; + } + + const CustomListValueInterface* absl_nullable interface() const { + if (dispatcher_ == nullptr) { + return content_.To().interface; + } + return nullptr; + } + + friend void swap(CustomListValue& lhs, CustomListValue& rhs) noexcept { + using std::swap; + swap(lhs.dispatcher_, rhs.dispatcher_); + swap(lhs.content_, rhs.content_); + } + + private: + friend class common_internal::ValueMixin; + friend class common_internal::ListValueMixin; + friend CustomListValue UnsafeCustomListValue( + const CustomListValueDispatcher* absl_nonnull dispatcher + ABSL_ATTRIBUTE_LIFETIME_BOUND, + CustomListValueContent content); + + CustomListValue(const CustomListValueDispatcher* absl_nonnull dispatcher, + CustomListValueContent content) + : dispatcher_(dispatcher), content_(content) { + ABSL_DCHECK(dispatcher != nullptr); + ABSL_DCHECK(dispatcher->get_type_id != nullptr); + ABSL_DCHECK(dispatcher->get_arena != nullptr); + ABSL_DCHECK(dispatcher->is_zero_value != nullptr); + ABSL_DCHECK(dispatcher->size != nullptr); + ABSL_DCHECK(dispatcher->get != nullptr); + ABSL_DCHECK(dispatcher->clone != nullptr); + } + + const CustomListValueDispatcher* absl_nullable dispatcher_ = nullptr; + CustomListValueContent content_ = CustomListValueContent::Zero(); +}; + +inline std::ostream& operator<<(std::ostream& out, + const CustomListValue& type) { + return out << type.DebugString(); +} + +template <> +struct NativeTypeTraits final { + static NativeTypeId Id(const CustomListValue& type) { + return type.GetTypeId(); + } +}; + +inline CustomListValue UnsafeCustomListValue( + const CustomListValueDispatcher* absl_nonnull dispatcher + ABSL_ATTRIBUTE_LIFETIME_BOUND, + CustomListValueContent content) { + return CustomListValue(dispatcher, content); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_PARSED_LIST_VALUE_H_ diff --git a/common/values/custom_list_value_test.cc b/common/values/custom_list_value_test.cc new file mode 100644 index 000000000..79c3f2419 --- /dev/null +++ b/common/values/custom_list_value_test.cc @@ -0,0 +1,548 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "google/protobuf/struct.pb.h" +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "common/memory.h" +#include "common/native_type.h" +#include "common/value.h" +#include "common/value_testing.h" +#include "internal/testing.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/io/zero_copy_stream_impl_lite.h" +#include "google/protobuf/message.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::IsOkAndHolds; +using ::absl_testing::StatusIs; +using ::cel::test::BoolValueIs; +using ::cel::test::ErrorValueIs; +using ::cel::test::IntValueIs; +using ::testing::Eq; +using ::testing::IsEmpty; +using ::testing::IsNull; +using ::testing::Not; +using ::testing::NotNull; +using ::testing::Optional; +using ::testing::Pair; +using ::testing::UnorderedElementsAre; + +class CustomListValueTest; + +struct CustomListValueTestContent { + google::protobuf::Arena* absl_nonnull arena; +}; + +class CustomListValueInterfaceTest final : public CustomListValueInterface { + public: + std::string DebugString() const override { return "[true, 1]"; } + + absl::Status SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const override { + google::protobuf::Value json; + google::protobuf::ListValue* json_array = json.mutable_list_value(); + json_array->add_values()->set_bool_value(true); + json_array->add_values()->set_number_value(1.0); + if (!json.SerializePartialToZeroCopyStream(output)) { + return absl::UnknownError( + "failed to serialize message: google.protobuf.Value"); + } + return absl::OkStatus(); + } + + absl::Status ConvertToJsonArray( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const override { + google::protobuf::ListValue json_array; + json_array.add_values()->set_bool_value(true); + json_array.add_values()->set_number_value(1.0); + absl::Cord serialized; + if (!json_array.SerializePartialToString(&serialized)) { + return absl::UnknownError( + "failed to serialize google.protobuf.ListValue"); + } + if (!json->ParsePartialFromString(serialized)) { + return absl::UnknownError("failed to parse google.protobuf.ListValue"); + } + return absl::OkStatus(); + } + + size_t Size() const override { return 2; } + + CustomListValue Clone(google::protobuf::Arena* absl_nonnull arena) const override { + return CustomListValue( + (::new (arena->AllocateAligned(sizeof(CustomListValueInterfaceTest), + alignof(CustomListValueInterfaceTest))) + CustomListValueInterfaceTest()), + arena); + } + + private: + absl::Status Get(size_t index, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const override { + if (index == 0) { + *result = TrueValue(); + return absl::OkStatus(); + } + if (index == 1) { + *result = IntValue(1); + return absl::OkStatus(); + } + *result = IndexOutOfBoundsError(index); + return absl::OkStatus(); + } + + NativeTypeId GetNativeTypeId() const override { + return NativeTypeId::For(); + } +}; + +class CustomListValueTest : public common_internal::ValueTest<> { + public: + CustomListValue MakeInterface() { + return CustomListValue( + (::new (arena()->AllocateAligned(sizeof(CustomListValueInterfaceTest), + alignof(CustomListValueInterfaceTest))) + CustomListValueInterfaceTest()), + arena()); + } + + CustomListValue MakeDispatcher() { + return UnsafeCustomListValue( + &test_dispatcher_, CustomValueContent::From( + CustomListValueTestContent{.arena = arena()})); + } + + protected: + CustomListValueDispatcher test_dispatcher_ = { + .get_type_id = + [](const CustomListValueDispatcher* absl_nonnull dispatcher, + CustomListValueContent content) -> NativeTypeId { + return NativeTypeId::For(); + }, + .get_arena = + [](const CustomListValueDispatcher* absl_nonnull dispatcher, + CustomListValueContent content) -> google::protobuf::Arena* absl_nullable { + return content.To().arena; + }, + .debug_string = + [](const CustomListValueDispatcher* absl_nonnull dispatcher, + CustomListValueContent content) -> std::string { + return "[true, 1]"; + }, + .serialize_to = + [](const CustomListValueDispatcher* absl_nonnull dispatcher, + CustomListValueContent content, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) + -> absl::Status { + google::protobuf::Value json; + google::protobuf::Struct* json_object = json.mutable_struct_value(); + (*json_object->mutable_fields())["foo"].set_bool_value(true); + (*json_object->mutable_fields())["bar"].set_number_value(1.0); + if (!json.SerializePartialToZeroCopyStream(output)) { + return absl::UnknownError( + "failed to serialize message: google.protobuf.Value"); + } + return absl::OkStatus(); + }, + .convert_to_json_array = + [](const CustomListValueDispatcher* absl_nonnull dispatcher, + CustomListValueContent content, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) -> absl::Status { + { + google::protobuf::ListValue json_array; + json_array.add_values()->set_bool_value(true); + json_array.add_values()->set_number_value(1.0); + absl::Cord serialized; + if (!json_array.SerializePartialToString(&serialized)) { + return absl::UnknownError( + "failed to serialize google.protobuf.ListValue"); + } + if (!json->ParsePartialFromString(serialized)) { + return absl::UnknownError( + "failed to parse google.protobuf.ListValue"); + } + return absl::OkStatus(); + } + }, + .is_zero_value = + [](const CustomListValueDispatcher* absl_nonnull dispatcher, + CustomListValueContent content) -> bool { return false; }, + .size = [](const CustomListValueDispatcher* absl_nonnull dispatcher, + CustomListValueContent content) -> size_t { return 2; }, + .get = [](const CustomListValueDispatcher* absl_nonnull dispatcher, + CustomListValueContent content, size_t index, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) -> absl::Status { + if (index == 0) { + *result = TrueValue(); + return absl::OkStatus(); + } + if (index == 1) { + *result = IntValue(1); + return absl::OkStatus(); + } + *result = IndexOutOfBoundsError(index); + return absl::OkStatus(); + }, + .clone = [](const CustomListValueDispatcher* absl_nonnull dispatcher, + CustomListValueContent content, + google::protobuf::Arena* absl_nonnull arena) -> CustomListValue { + return UnsafeCustomListValue( + dispatcher, CustomValueContent::From( + CustomListValueTestContent{.arena = arena})); + }, + }; +}; + +TEST_F(CustomListValueTest, Kind) { + EXPECT_EQ(CustomListValue::kind(), CustomListValue::kKind); +} + +TEST_F(CustomListValueTest, Dispatcher_GetTypeId) { + EXPECT_EQ(MakeDispatcher().GetTypeId(), + NativeTypeId::For()); +} + +TEST_F(CustomListValueTest, Interface_GetTypeId) { + EXPECT_EQ(MakeInterface().GetTypeId(), + NativeTypeId::For()); +} + +TEST_F(CustomListValueTest, Dispatcher_GetTypeName) { + EXPECT_EQ(MakeDispatcher().GetTypeName(), "list"); +} + +TEST_F(CustomListValueTest, Interface_GetTypeName) { + EXPECT_EQ(MakeInterface().GetTypeName(), "list"); +} + +TEST_F(CustomListValueTest, Dispatcher_DebugString) { + EXPECT_EQ(MakeDispatcher().DebugString(), "[true, 1]"); +} + +TEST_F(CustomListValueTest, Interface_DebugString) { + EXPECT_EQ(MakeInterface().DebugString(), "[true, 1]"); +} + +TEST_F(CustomListValueTest, Dispatcher_IsZeroValue) { + EXPECT_FALSE(MakeDispatcher().IsZeroValue()); +} + +TEST_F(CustomListValueTest, Interface_IsZeroValue) { + EXPECT_FALSE(MakeInterface().IsZeroValue()); +} + +TEST_F(CustomListValueTest, Dispatcher_SerializeTo) { + google::protobuf::io::CordOutputStream output; + EXPECT_THAT(MakeDispatcher().SerializeTo(descriptor_pool(), message_factory(), + &output), + IsOk()); + EXPECT_THAT(std::move(output).Consume(), Not(IsEmpty())); +} + +TEST_F(CustomListValueTest, Interface_SerializeTo) { + google::protobuf::io::CordOutputStream output; + EXPECT_THAT(MakeInterface().SerializeTo(descriptor_pool(), message_factory(), + &output), + IsOk()); + EXPECT_THAT(std::move(output).Consume(), Not(IsEmpty())); +} + +TEST_F(CustomListValueTest, Dispatcher_ConvertToJson) { + auto message = DynamicParseTextProto(); + EXPECT_THAT( + MakeDispatcher().ConvertToJson(descriptor_pool(), message_factory(), + cel::to_address(message)), + IsOk()); + EXPECT_THAT(*message, EqualsTextProto(R"pb( + list_value: { + values: { bool_value: true } + values: { number_value: 1.0 } + } + )pb")); +} + +TEST_F(CustomListValueTest, Interface_ConvertToJson) { + auto message = DynamicParseTextProto(); + EXPECT_THAT( + MakeInterface().ConvertToJson(descriptor_pool(), message_factory(), + cel::to_address(message)), + IsOk()); + EXPECT_THAT(*message, EqualsTextProto(R"pb( + list_value: { + values: { bool_value: true } + values: { number_value: 1.0 } + } + )pb")); +} + +TEST_F(CustomListValueTest, Dispatcher_ConvertToJsonArray) { + auto message = DynamicParseTextProto(); + EXPECT_THAT( + MakeDispatcher().ConvertToJsonArray(descriptor_pool(), message_factory(), + cel::to_address(message)), + IsOk()); + EXPECT_THAT(*message, EqualsTextProto(R"pb( + values: { bool_value: true } + values: { number_value: 1.0 } + )pb")); +} + +TEST_F(CustomListValueTest, Interface_ConvertToJsonArray) { + auto message = DynamicParseTextProto(); + EXPECT_THAT( + MakeInterface().ConvertToJsonArray(descriptor_pool(), message_factory(), + cel::to_address(message)), + IsOk()); + EXPECT_THAT(*message, EqualsTextProto(R"pb( + values: { bool_value: true } + values: { number_value: 1.0 } + )pb")); +} + +TEST_F(CustomListValueTest, Dispatcher_IsEmpty) { + EXPECT_FALSE(MakeDispatcher().IsEmpty()); +} + +TEST_F(CustomListValueTest, Interface_IsEmpty) { + EXPECT_FALSE(MakeInterface().IsEmpty()); +} + +TEST_F(CustomListValueTest, Dispatcher_Size) { + EXPECT_EQ(MakeDispatcher().Size(), 2); +} + +TEST_F(CustomListValueTest, Interface_Size) { + EXPECT_EQ(MakeInterface().Size(), 2); +} + +TEST_F(CustomListValueTest, Dispatcher_Get) { + CustomListValue list = MakeDispatcher(); + ASSERT_THAT(list.Get(0, descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(BoolValueIs(true))); + ASSERT_THAT(list.Get(1, descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(IntValueIs(1))); + ASSERT_THAT( + list.Get(2, descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument)))); +} + +TEST_F(CustomListValueTest, Interface_Get) { + CustomListValue list = MakeInterface(); + ASSERT_THAT(list.Get(0, descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(BoolValueIs(true))); + ASSERT_THAT(list.Get(1, descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(IntValueIs(1))); + ASSERT_THAT( + list.Get(2, descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument)))); +} + +TEST_F(CustomListValueTest, Dispatcher_ForEach) { + std::vector> fields; + EXPECT_THAT( + MakeDispatcher().ForEach( + [&](size_t index, const Value& value) -> absl::StatusOr { + fields.push_back(std::pair{index, value}); + return true; + }, + descriptor_pool(), message_factory(), arena()), + IsOk()); + EXPECT_THAT(fields, UnorderedElementsAre(Pair(0, BoolValueIs(true)), + Pair(1, IntValueIs(1)))); +} + +TEST_F(CustomListValueTest, Interface_ForEach) { + std::vector> fields; + EXPECT_THAT( + MakeInterface().ForEach( + [&](size_t index, const Value& value) -> absl::StatusOr { + fields.push_back(std::pair{index, value}); + return true; + }, + descriptor_pool(), message_factory(), arena()), + IsOk()); + EXPECT_THAT(fields, UnorderedElementsAre(Pair(0, BoolValueIs(true)), + Pair(1, IntValueIs(1)))); +} + +TEST_F(CustomListValueTest, Dispatcher_NewIterator) { + CustomListValue list = MakeDispatcher(); + ASSERT_OK_AND_ASSIGN(auto iterator, list.NewIterator()); + ASSERT_TRUE(iterator->HasNext()); + EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(BoolValueIs(true))); + ASSERT_TRUE(iterator->HasNext()); + EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(IntValueIs(1))); + EXPECT_FALSE(iterator->HasNext()); + EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()), + StatusIs(absl::StatusCode::kFailedPrecondition)); +} + +TEST_F(CustomListValueTest, Interface_NewIterator) { + CustomListValue list = MakeInterface(); + ASSERT_OK_AND_ASSIGN(auto iterator, list.NewIterator()); + ASSERT_TRUE(iterator->HasNext()); + EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(BoolValueIs(true))); + ASSERT_TRUE(iterator->HasNext()); + EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(IntValueIs(1))); + EXPECT_FALSE(iterator->HasNext()); + EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()), + StatusIs(absl::StatusCode::kFailedPrecondition)); +} + +TEST_F(CustomListValueTest, Dispatcher_NewIterator1) { + CustomListValue list = MakeDispatcher(); + ASSERT_OK_AND_ASSIGN(auto iterator, list.NewIterator()); + EXPECT_THAT(iterator->Next1(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(Optional(BoolValueIs(true)))); + EXPECT_THAT(iterator->Next1(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(Optional(IntValueIs(1)))); + EXPECT_THAT(iterator->Next1(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(Eq(absl::nullopt))); +} + +TEST_F(CustomListValueTest, Interface_NewIterator1) { + CustomListValue list = MakeInterface(); + ASSERT_OK_AND_ASSIGN(auto iterator, list.NewIterator()); + EXPECT_THAT(iterator->Next1(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(Optional(BoolValueIs(true)))); + EXPECT_THAT(iterator->Next1(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(Optional(IntValueIs(1)))); + EXPECT_THAT(iterator->Next1(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(Eq(absl::nullopt))); +} + +TEST_F(CustomListValueTest, Dispatcher_NewIterator2) { + CustomListValue list = MakeDispatcher(); + ASSERT_OK_AND_ASSIGN(auto iterator, list.NewIterator()); + EXPECT_THAT(iterator->Next2(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(Optional(Pair(IntValueIs(0), BoolValueIs(true))))); + EXPECT_THAT(iterator->Next2(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(Optional(Pair(IntValueIs(1), IntValueIs(1))))); + EXPECT_THAT(iterator->Next2(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(Eq(absl::nullopt))); +} + +TEST_F(CustomListValueTest, Interface_NewIterator2) { + CustomListValue list = MakeInterface(); + ASSERT_OK_AND_ASSIGN(auto iterator, list.NewIterator()); + EXPECT_THAT(iterator->Next2(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(Optional(Pair(IntValueIs(0), BoolValueIs(true))))); + EXPECT_THAT(iterator->Next2(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(Optional(Pair(IntValueIs(1), IntValueIs(1))))); + EXPECT_THAT(iterator->Next2(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(Eq(absl::nullopt))); +} + +TEST_F(CustomListValueTest, Dispatcher_Contains) { + CustomListValue list = MakeDispatcher(); + EXPECT_THAT( + list.Contains(TrueValue(), descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(BoolValueIs(true))); + EXPECT_THAT( + list.Contains(IntValue(1), descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(BoolValueIs(true))); + EXPECT_THAT(list.Contains(UintValue(1u), descriptor_pool(), message_factory(), + arena()), + IsOkAndHolds(BoolValueIs(true))); + EXPECT_THAT(list.Contains(DoubleValue(1.0), descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(BoolValueIs(true))); + EXPECT_THAT(list.Contains(FalseValue(), descriptor_pool(), message_factory(), + arena()), + IsOkAndHolds(BoolValueIs(false))); + EXPECT_THAT( + list.Contains(IntValue(0), descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(BoolValueIs(false))); + EXPECT_THAT(list.Contains(UintValue(0u), descriptor_pool(), message_factory(), + arena()), + IsOkAndHolds(BoolValueIs(false))); + EXPECT_THAT(list.Contains(DoubleValue(0.0), descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(BoolValueIs(false))); +} + +TEST_F(CustomListValueTest, Interface_Contains) { + CustomListValue list = MakeInterface(); + EXPECT_THAT( + list.Contains(TrueValue(), descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(BoolValueIs(true))); + EXPECT_THAT( + list.Contains(IntValue(1), descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(BoolValueIs(true))); + EXPECT_THAT(list.Contains(UintValue(1u), descriptor_pool(), message_factory(), + arena()), + IsOkAndHolds(BoolValueIs(true))); + EXPECT_THAT(list.Contains(DoubleValue(1.0), descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(BoolValueIs(true))); + EXPECT_THAT(list.Contains(FalseValue(), descriptor_pool(), message_factory(), + arena()), + IsOkAndHolds(BoolValueIs(false))); + EXPECT_THAT( + list.Contains(IntValue(0), descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(BoolValueIs(false))); + EXPECT_THAT(list.Contains(UintValue(0u), descriptor_pool(), message_factory(), + arena()), + IsOkAndHolds(BoolValueIs(false))); + EXPECT_THAT(list.Contains(DoubleValue(0.0), descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(BoolValueIs(false))); +} + +TEST_F(CustomListValueTest, Dispatcher) { + EXPECT_THAT(MakeDispatcher().dispatcher(), NotNull()); + EXPECT_THAT(MakeDispatcher().interface(), IsNull()); +} + +TEST_F(CustomListValueTest, Interface) { + EXPECT_THAT(MakeInterface().dispatcher(), IsNull()); + EXPECT_THAT(MakeInterface().interface(), NotNull()); +} + +} // namespace +} // namespace cel diff --git a/common/values/custom_map_value.cc b/common/values/custom_map_value.cc new file mode 100644 index 000000000..ae07f7723 --- /dev/null +++ b/common/values/custom_map_value.cc @@ -0,0 +1,823 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/no_destructor.h" +#include "absl/base/nullability.h" +#include "absl/base/optimization.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "common/native_type.h" +#include "common/value.h" +#include "common/value_kind.h" +#include "common/values/list_value_builder.h" +#include "common/values/map_value_builder.h" +#include "common/values/values.h" +#include "eval/public/cel_value.h" +#include "internal/status_macros.h" +#include "internal/well_known_types.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/message.h" + +namespace cel { + +namespace { + +using ::cel::well_known_types::StructReflection; +using ::cel::well_known_types::ValueReflection; +using ::google::api::expr::runtime::CelList; +using ::google::api::expr::runtime::CelValue; + +absl::Status NoSuchKeyError(const Value& key) { + return absl::NotFoundError( + absl::StrCat("Key not found in map : ", key.DebugString())); +} + +absl::Status InvalidMapKeyTypeError(ValueKind kind) { + return absl::InvalidArgumentError( + absl::StrCat("Invalid map key type: '", ValueKindToString(kind), "'")); +} + +class EmptyMapValue final : public common_internal::CompatMapValue { + public: + static const EmptyMapValue& Get() { + static const absl::NoDestructor empty; + return *empty; + } + + EmptyMapValue() = default; + + std::string DebugString() const override { return "{}"; } + + bool IsEmpty() const override { return true; } + + size_t Size() const override { return 0; } + + absl::Status ListKeys( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + ListValue* absl_nonnull result) const override { + *result = ListValue(); + return absl::OkStatus(); + } + + absl::StatusOr NewIterator() const override { + return NewEmptyValueIterator(); + } + + absl::Status ConvertToJsonObject( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const override { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(json != nullptr); + ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT); + + json->Clear(); + return absl::OkStatus(); + } + + CustomMapValue Clone(google::protobuf::Arena* absl_nonnull) const override { + return CustomMapValue(); + } + + absl::optional operator[](CelValue key) const override { + return absl::nullopt; + } + + using CompatMapValue::Get; + absl::optional Get(google::protobuf::Arena* arena, + CelValue key) const override { + return absl::nullopt; + } + + absl::StatusOr Has(const CelValue& key) const override { return false; } + + int size() const override { return static_cast(Size()); } + + absl::StatusOr ListKeys() const override { + return common_internal::EmptyCompatListValue(); + } + + absl::StatusOr ListKeys(google::protobuf::Arena*) const override { + return ListKeys(); + } + + private: + absl::StatusOr Find( + const Value& key, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const override { + return false; + } + + absl::StatusOr Has( + const Value& key, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const override { + return false; + } +}; + +} // namespace + +namespace common_internal { + +const CompatMapValue* absl_nonnull EmptyCompatMapValue() { + return &EmptyMapValue::Get(); +} + +} // namespace common_internal + +class CustomMapValueInterfaceIterator final : public ValueIterator { + public: + explicit CustomMapValueInterfaceIterator( + const CustomMapValueInterface* absl_nonnull interface) + : interface_(interface) {} + + bool HasNext() override { + if (keys_iterator_ == nullptr) { + return !interface_->IsEmpty(); + } + return keys_iterator_->HasNext(); + } + + absl::Status Next(const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) override { + if (keys_iterator_ == nullptr) { + if (interface_->IsEmpty()) { + return absl::FailedPreconditionError( + "ValueIterator::Next() called when " + "ValueIterator::HasNext() returns false"); + } + CEL_RETURN_IF_ERROR(ProjectKeys(descriptor_pool, message_factory, arena)); + } + return keys_iterator_->Next(descriptor_pool, message_factory, arena, + result); + } + + absl::StatusOr Next1( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull key_or_value) override { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(key_or_value != nullptr); + + if (keys_iterator_ == nullptr) { + if (interface_->IsEmpty()) { + return false; + } + CEL_RETURN_IF_ERROR(ProjectKeys(descriptor_pool, message_factory, arena)); + } + + return keys_iterator_->Next1(descriptor_pool, message_factory, arena, + key_or_value); + } + + absl::StatusOr Next2( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull key, + Value* absl_nullable value) override { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(key != nullptr); + + if (keys_iterator_ == nullptr) { + if (interface_->IsEmpty()) { + return false; + } + CEL_RETURN_IF_ERROR(ProjectKeys(descriptor_pool, message_factory, arena)); + } + + CEL_ASSIGN_OR_RETURN( + bool ok, + keys_iterator_->Next1(descriptor_pool, message_factory, arena, key)); + if (!ok) { + return false; + } + if (value != nullptr) { + CEL_ASSIGN_OR_RETURN(ok, interface_->Find(*key, descriptor_pool, + message_factory, arena, value)); + if (!ok) { + return absl::DataLossError( + "map iterator returned key that was not present in the map"); + } + } + return true; + } + + private: + // Projects the keys from the map, setting `keys_` and `keys_iterator_`. If + // this returns OK it is guaranteed that `keys_iterator_` is not null. + absl::Status ProjectKeys( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + ABSL_DCHECK(keys_iterator_ == nullptr); + + CEL_RETURN_IF_ERROR( + interface_->ListKeys(descriptor_pool, message_factory, arena, &keys_)); + CEL_ASSIGN_OR_RETURN(keys_iterator_, keys_.NewIterator()); + ABSL_CHECK(keys_iterator_->HasNext()); // Crash OK + return absl::OkStatus(); + } + + const CustomMapValueInterface* absl_nonnull const interface_; + ListValue keys_; + absl_nullable ValueIteratorPtr keys_iterator_; +}; + +namespace { + +class CustomMapValueDispatcherIterator final : public ValueIterator { + public: + explicit CustomMapValueDispatcherIterator( + const CustomMapValueDispatcher* absl_nonnull dispatcher, + CustomMapValueContent content) + : dispatcher_(dispatcher), content_(content) {} + + bool HasNext() override { + if (keys_iterator_ == nullptr) { + if (dispatcher_->is_empty != nullptr) { + return !dispatcher_->is_empty(dispatcher_, content_); + } + return dispatcher_->size(dispatcher_, content_) != 0; + } + return keys_iterator_->HasNext(); + } + + absl::Status Next(const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) override { + if (keys_iterator_ == nullptr) { + if (dispatcher_->is_empty != nullptr + ? dispatcher_->is_empty(dispatcher_, content_) + : dispatcher_->size(dispatcher_, content_) == 0) { + return absl::FailedPreconditionError( + "ValueIterator::Next() called when " + "ValueIterator::HasNext() returns false"); + } + CEL_RETURN_IF_ERROR(ProjectKeys(descriptor_pool, message_factory, arena)); + } + return keys_iterator_->Next(descriptor_pool, message_factory, arena, + result); + } + + absl::StatusOr Next1( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull key_or_value) override { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(key_or_value != nullptr); + + if (keys_iterator_ == nullptr) { + if (dispatcher_->is_empty != nullptr + ? dispatcher_->is_empty(dispatcher_, content_) + : dispatcher_->size(dispatcher_, content_) == 0) { + return false; + } + CEL_RETURN_IF_ERROR(ProjectKeys(descriptor_pool, message_factory, arena)); + } + + return keys_iterator_->Next1(descriptor_pool, message_factory, arena, + key_or_value); + } + + absl::StatusOr Next2( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull key, + Value* absl_nullable value) override { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(key != nullptr); + ABSL_DCHECK(value != nullptr); + + if (keys_iterator_ == nullptr) { + if (dispatcher_->is_empty != nullptr + ? dispatcher_->is_empty(dispatcher_, content_) + : dispatcher_->size(dispatcher_, content_) == 0) { + return false; + } + CEL_RETURN_IF_ERROR(ProjectKeys(descriptor_pool, message_factory, arena)); + } + + CEL_ASSIGN_OR_RETURN( + bool ok, + keys_iterator_->Next1(descriptor_pool, message_factory, arena, key)); + if (!ok) { + return false; + } + if (value != nullptr) { + CEL_ASSIGN_OR_RETURN( + ok, dispatcher_->find(dispatcher_, content_, *key, descriptor_pool, + message_factory, arena, value)); + if (!ok) { + return absl::DataLossError( + "map iterator returned key that was not present in the map"); + } + } + return true; + } + + private: + absl::Status ProjectKeys( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + ABSL_DCHECK(keys_iterator_ == nullptr); + + CEL_RETURN_IF_ERROR(dispatcher_->list_keys(dispatcher_, content_, + descriptor_pool, message_factory, + arena, &keys_)); + CEL_ASSIGN_OR_RETURN(keys_iterator_, keys_.NewIterator()); + ABSL_CHECK(keys_iterator_->HasNext()); // Crash OK + return absl::OkStatus(); + } + + const CustomMapValueDispatcher* absl_nonnull const dispatcher_; + const CustomMapValueContent content_; + ListValue keys_; + absl_nullable ValueIteratorPtr keys_iterator_; +}; + +} // namespace + +absl::Status CustomMapValueInterface::SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(output != nullptr); + + StructReflection reflection; + CEL_RETURN_IF_ERROR(reflection.Initialize(descriptor_pool)); + const google::protobuf::Message* prototype = + message_factory->GetPrototype(reflection.GetDescriptor()); + if (prototype == nullptr) { + return absl::UnknownError( + absl::StrCat("failed to get message prototype: ", + reflection.GetDescriptor()->full_name())); + } + google::protobuf::Arena arena; + google::protobuf::Message* message = prototype->New(&arena); + CEL_RETURN_IF_ERROR( + ConvertToJsonObject(descriptor_pool, message_factory, message)); + if (!message->SerializePartialToZeroCopyStream(output)) { + return absl::UnknownError( + "failed to serialize message: google.protobuf.Struct"); + } + return absl::OkStatus(); +} + +absl::Status CustomMapValueInterface::ForEach( + ForEachCallback callback, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const { + CEL_ASSIGN_OR_RETURN(auto iterator, NewIterator()); + while (iterator->HasNext()) { + Value key; + Value value; + CEL_RETURN_IF_ERROR( + iterator->Next(descriptor_pool, message_factory, arena, &key)); + CEL_ASSIGN_OR_RETURN( + bool found, Find(key, descriptor_pool, message_factory, arena, &value)); + if (!found) { + value = ErrorValue(NoSuchKeyError(key)); + } + CEL_ASSIGN_OR_RETURN(auto ok, callback(key, value)); + if (!ok) { + break; + } + } + return absl::OkStatus(); +} + +absl::StatusOr +CustomMapValueInterface::NewIterator() const { + return std::make_unique(this); +} + +absl::Status CustomMapValueInterface::Equal( + const MapValue& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + return MapValueEqual(*this, other, descriptor_pool, message_factory, arena, + result); +} + +CustomMapValue::CustomMapValue() { + content_ = CustomMapValueContent::From(CustomMapValueInterface::Content{ + .interface = &EmptyMapValue::Get(), .arena = nullptr}); +} + +NativeTypeId CustomMapValue::GetTypeId() const { + if (dispatcher_ == nullptr) { + CustomMapValueInterface::Content content = + content_.To(); + ABSL_DCHECK(content.interface != nullptr); + return content.interface->GetNativeTypeId(); + } + return dispatcher_->get_type_id(dispatcher_, content_); +} + +absl::string_view CustomMapValue::GetTypeName() const { return "map"; } + +std::string CustomMapValue::DebugString() const { + if (dispatcher_ == nullptr) { + CustomMapValueInterface::Content content = + content_.To(); + ABSL_DCHECK(content.interface != nullptr); + return content.interface->DebugString(); + } + if (dispatcher_->debug_string != nullptr) { + return dispatcher_->debug_string(dispatcher_, content_); + } + return "map"; +} + +absl::Status CustomMapValue::SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const { + if (dispatcher_ == nullptr) { + CustomMapValueInterface::Content content = + content_.To(); + ABSL_DCHECK(content.interface != nullptr); + return content.interface->SerializeTo(descriptor_pool, message_factory, + output); + } + if (dispatcher_->serialize_to != nullptr) { + return dispatcher_->serialize_to(dispatcher_, content_, descriptor_pool, + message_factory, output); + } + return absl::UnimplementedError( + absl::StrCat(GetTypeName(), " is unserializable")); +} + +absl::Status CustomMapValue::ConvertToJson( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(json != nullptr); + ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE); + + ValueReflection value_reflection; + CEL_RETURN_IF_ERROR(value_reflection.Initialize(json->GetDescriptor())); + google::protobuf::Message* json_object = value_reflection.MutableStructValue(json); + + return ConvertToJsonObject(descriptor_pool, message_factory, json_object); +} + +absl::Status CustomMapValue::ConvertToJsonObject( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(json != nullptr); + ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT); + + if (dispatcher_ == nullptr) { + CustomMapValueInterface::Content content = + content_.To(); + ABSL_DCHECK(content.interface != nullptr); + return content.interface->ConvertToJsonObject(descriptor_pool, + message_factory, json); + } + if (dispatcher_->convert_to_json_object != nullptr) { + return dispatcher_->convert_to_json_object( + dispatcher_, content_, descriptor_pool, message_factory, json); + } + return absl::UnimplementedError( + absl::StrCat(GetTypeName(), " is not convertable to JSON")); +} + +absl::Status CustomMapValue::Equal( + const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + if (auto other_map_value = other.AsMap(); other_map_value) { + if (dispatcher_ == nullptr) { + CustomMapValueInterface::Content content = + content_.To(); + ABSL_DCHECK(content.interface != nullptr); + return content.interface->Equal(*other_map_value, descriptor_pool, + message_factory, arena, result); + } + if (dispatcher_->equal != nullptr) { + return dispatcher_->equal(dispatcher_, content_, *other_map_value, + descriptor_pool, message_factory, arena, + result); + } + return common_internal::MapValueEqual(*this, *other_map_value, + descriptor_pool, message_factory, + arena, result); + } + *result = FalseValue(); + return absl::OkStatus(); +} + +bool CustomMapValue::IsZeroValue() const { + if (dispatcher_ == nullptr) { + CustomMapValueInterface::Content content = + content_.To(); + ABSL_DCHECK(content.interface != nullptr); + return content.interface->IsZeroValue(); + } + return dispatcher_->is_zero_value(dispatcher_, content_); +} + +CustomMapValue CustomMapValue::Clone(google::protobuf::Arena* absl_nonnull arena) const { + ABSL_DCHECK(arena != nullptr); + + if (dispatcher_ == nullptr) { + CustomMapValueInterface::Content content = + content_.To(); + ABSL_DCHECK(content.interface != nullptr); + if (content.arena != arena) { + return content.interface->Clone(arena); + } + return *this; + } + return dispatcher_->clone(dispatcher_, content_, arena); +} + +bool CustomMapValue::IsEmpty() const { + if (dispatcher_ == nullptr) { + CustomMapValueInterface::Content content = + content_.To(); + ABSL_DCHECK(content.interface != nullptr); + return content.interface->IsEmpty(); + } + if (dispatcher_->is_empty != nullptr) { + return dispatcher_->is_empty(dispatcher_, content_); + } + return dispatcher_->size(dispatcher_, content_) == 0; +} + +size_t CustomMapValue::Size() const { + if (dispatcher_ == nullptr) { + CustomMapValueInterface::Content content = + content_.To(); + ABSL_DCHECK(content.interface != nullptr); + return content.interface->Size(); + } + return dispatcher_->size(dispatcher_, content_); +} + +absl::Status CustomMapValue::Get( + const Value& key, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + CEL_ASSIGN_OR_RETURN( + bool ok, Find(key, descriptor_pool, message_factory, arena, result)); + if (ABSL_PREDICT_FALSE(!ok)) { + switch (result->kind()) { + case ValueKind::kError: + ABSL_FALLTHROUGH_INTENDED; + case ValueKind::kUnknown: + break; + default: + *result = ErrorValue(NoSuchKeyError(key)); + break; + } + } + return absl::OkStatus(); +} + +absl::StatusOr CustomMapValue::Find( + const Value& key, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + switch (key.kind()) { + case ValueKind::kError: + ABSL_FALLTHROUGH_INTENDED; + case ValueKind::kUnknown: + *result = key; + return false; + case ValueKind::kBool: + ABSL_FALLTHROUGH_INTENDED; + case ValueKind::kInt: + ABSL_FALLTHROUGH_INTENDED; + case ValueKind::kUint: + ABSL_FALLTHROUGH_INTENDED; + case ValueKind::kString: + break; + default: + *result = ErrorValue(InvalidMapKeyTypeError(key.kind())); + return false; + } + + bool ok; + if (dispatcher_ == nullptr) { + CustomMapValueInterface::Content content = + content_.To(); + ABSL_DCHECK(content.interface != nullptr); + CEL_ASSIGN_OR_RETURN( + ok, content.interface->Find(key, descriptor_pool, message_factory, + arena, result)); + } else { + CEL_ASSIGN_OR_RETURN( + ok, dispatcher_->find(dispatcher_, content_, key, descriptor_pool, + message_factory, arena, result)); + } + if (ok) { + return true; + } + *result = NullValue{}; + return false; +} + +absl::Status CustomMapValue::Has( + const Value& key, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + switch (key.kind()) { + case ValueKind::kError: + ABSL_FALLTHROUGH_INTENDED; + case ValueKind::kUnknown: + *result = key; + return absl::OkStatus(); + case ValueKind::kBool: + ABSL_FALLTHROUGH_INTENDED; + case ValueKind::kInt: + ABSL_FALLTHROUGH_INTENDED; + case ValueKind::kUint: + ABSL_FALLTHROUGH_INTENDED; + case ValueKind::kString: + break; + default: + *result = ErrorValue(InvalidMapKeyTypeError(key.kind())); + return absl::OkStatus(); + } + bool has; + if (dispatcher_ == nullptr) { + CustomMapValueInterface::Content content = + content_.To(); + ABSL_DCHECK(content.interface != nullptr); + CEL_ASSIGN_OR_RETURN(has, content.interface->Has(key, descriptor_pool, + message_factory, arena)); + } else { + CEL_ASSIGN_OR_RETURN( + has, dispatcher_->has(dispatcher_, content_, key, descriptor_pool, + message_factory, arena)); + } + *result = BoolValue(has); + return absl::OkStatus(); +} + +absl::Status CustomMapValue::ListKeys( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, ListValue* absl_nonnull result) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + if (dispatcher_ == nullptr) { + CustomMapValueInterface::Content content = + content_.To(); + ABSL_DCHECK(content.interface != nullptr); + return content.interface->ListKeys(descriptor_pool, message_factory, arena, + result); + } + return dispatcher_->list_keys(dispatcher_, content_, descriptor_pool, + message_factory, arena, result); +} + +absl::Status CustomMapValue::ForEach( + ForEachCallback callback, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + + if (dispatcher_ == nullptr) { + CustomMapValueInterface::Content content = + content_.To(); + ABSL_DCHECK(content.interface != nullptr); + return content.interface->ForEach(callback, descriptor_pool, + message_factory, arena); + } + if (dispatcher_->for_each != nullptr) { + return dispatcher_->for_each(dispatcher_, content_, callback, + descriptor_pool, message_factory, arena); + } + absl_nonnull ValueIteratorPtr iterator; + if (dispatcher_->new_iterator != nullptr) { + CEL_ASSIGN_OR_RETURN(iterator, + dispatcher_->new_iterator(dispatcher_, content_)); + } else { + iterator = std::make_unique(dispatcher_, + content_); + } + while (iterator->HasNext()) { + Value key; + Value value; + CEL_RETURN_IF_ERROR( + iterator->Next(descriptor_pool, message_factory, arena, &key)); + CEL_ASSIGN_OR_RETURN( + bool found, + dispatcher_->find(dispatcher_, content_, key, descriptor_pool, + message_factory, arena, &value)); + if (!found) { + value = ErrorValue(NoSuchKeyError(key)); + } + CEL_ASSIGN_OR_RETURN(auto ok, callback(key, value)); + if (!ok) { + break; + } + } + return absl::OkStatus(); +} + +absl::StatusOr CustomMapValue::NewIterator() + const { + if (dispatcher_ == nullptr) { + CustomMapValueInterface::Content content = + content_.To(); + ABSL_DCHECK(content.interface != nullptr); + return content.interface->NewIterator(); + } + if (dispatcher_->new_iterator != nullptr) { + return dispatcher_->new_iterator(dispatcher_, content_); + } + return std::make_unique(dispatcher_, + content_); +} + +} // namespace cel diff --git a/common/values/custom_map_value.h b/common/values/custom_map_value.h new file mode 100644 index 000000000..ca6e1e025 --- /dev/null +++ b/common/values/custom_map_value.h @@ -0,0 +1,469 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/value.h" +// IWYU pragma: friend "common/value.h" + +// `CustomMapValue` represents values of the primitive `map` type. +// `CustomMapValueView` is a non-owning view of `CustomMapValue`. +// `CustomMapValueInterface` is the abstract base class of implementations. +// `CustomMapValue` and `CustomMapValueView` act as smart pointers to +// `CustomMapValueInterface`. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_PARSED_MAP_VALUE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_PARSED_MAP_VALUE_H_ + +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/functional/function_ref.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "common/native_type.h" +#include "common/value_kind.h" +#include "common/values/custom_value.h" +#include "common/values/values.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/message.h" + +namespace cel { + +class Value; +class ListValue; +class CustomMapValueInterface; +class CustomMapValueInterfaceKeysIterator; +class CustomMapValue; +using CustomMapValueContent = CustomValueContent; + +struct CustomMapValueDispatcher { + using GetTypeId = + NativeTypeId (*)(const CustomMapValueDispatcher* absl_nonnull dispatcher, + CustomMapValueContent content); + + using GetArena = google::protobuf::Arena* absl_nullable (*)( + const CustomMapValueDispatcher* absl_nonnull dispatcher, + CustomMapValueContent content); + + using DebugString = + std::string (*)(const CustomMapValueDispatcher* absl_nonnull dispatcher, + CustomMapValueContent content); + + using SerializeTo = absl::Status (*)( + const CustomMapValueDispatcher* absl_nonnull dispatcher, + CustomMapValueContent content, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output); + + using ConvertToJsonObject = absl::Status (*)( + const CustomMapValueDispatcher* absl_nonnull dispatcher, + CustomMapValueContent content, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json); + + using Equal = absl::Status (*)( + const CustomMapValueDispatcher* absl_nonnull dispatcher, + CustomMapValueContent content, const MapValue& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result); + + using IsZeroValue = + bool (*)(const CustomMapValueDispatcher* absl_nonnull dispatcher, + CustomMapValueContent content); + + using IsEmpty = + bool (*)(const CustomMapValueDispatcher* absl_nonnull dispatcher, + CustomMapValueContent content); + + using Size = + size_t (*)(const CustomMapValueDispatcher* absl_nonnull dispatcher, + CustomMapValueContent content); + + using Find = absl::StatusOr (*)( + const CustomMapValueDispatcher* absl_nonnull dispatcher, + CustomMapValueContent content, const Value& key, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result); + + using Has = absl::StatusOr (*)( + const CustomMapValueDispatcher* absl_nonnull dispatcher, + CustomMapValueContent content, const Value& key, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena); + + using ListKeys = absl::Status (*)( + const CustomMapValueDispatcher* absl_nonnull dispatcher, + CustomMapValueContent content, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, ListValue* absl_nonnull result); + + using ForEach = absl::Status (*)( + const CustomMapValueDispatcher* absl_nonnull dispatcher, + CustomMapValueContent content, + absl::FunctionRef(const Value&, const Value&)> + callback, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena); + + using NewIterator = absl::StatusOr (*)( + const CustomMapValueDispatcher* absl_nonnull dispatcher, + CustomMapValueContent content); + + using Clone = CustomMapValue (*)( + const CustomMapValueDispatcher* absl_nonnull dispatcher, + CustomMapValueContent content, google::protobuf::Arena* absl_nonnull arena); + + absl_nonnull GetTypeId get_type_id; + + absl_nonnull GetArena get_arena; + + // If null, simply returns "map". + absl_nullable DebugString debug_string = nullptr; + + // If null, attempts to serialize results in an UNIMPLEMENTED error. + absl_nullable SerializeTo serialize_to = nullptr; + + // If null, attempts to convert to JSON results in an UNIMPLEMENTED error. + absl_nullable ConvertToJsonObject convert_to_json_object = nullptr; + + // If null, an nonoptimal fallback implementation for equality is used. + absl_nullable Equal equal = nullptr; + + absl_nonnull IsZeroValue is_zero_value; + + // If null, `size(...) == 0` is used. + absl_nullable IsEmpty is_empty = nullptr; + + absl_nonnull Size size; + + absl_nonnull Find find; + + absl_nonnull Has has; + + absl_nonnull ListKeys list_keys; + + // If null, a fallback implementation based on `list_keys` is used. + absl_nullable ForEach for_each = nullptr; + + // If null, a fallback implementation based on `list_keys` is used. + absl_nullable NewIterator new_iterator = nullptr; + + absl_nonnull Clone clone; +}; + +class CustomMapValueInterface { + public: + CustomMapValueInterface() = default; + CustomMapValueInterface(const CustomMapValueInterface&) = delete; + CustomMapValueInterface(CustomMapValueInterface&&) = delete; + + virtual ~CustomMapValueInterface() = default; + + CustomMapValueInterface& operator=(const CustomMapValueInterface&) = delete; + CustomMapValueInterface& operator=(CustomMapValueInterface&&) = delete; + + using ForEachCallback = + absl::FunctionRef(const Value&, const Value&)>; + + private: + friend class CustomMapValueInterfaceIterator; + friend class CustomMapValue; + friend absl::Status common_internal::MapValueEqual( + const CustomMapValueInterface& lhs, const MapValue& rhs, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result); + + virtual std::string DebugString() const = 0; + + virtual absl::Status SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const; + + virtual absl::Status ConvertToJsonObject( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const = 0; + + virtual absl::Status Equal( + const MapValue& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; + + virtual bool IsZeroValue() const { return IsEmpty(); } + + // Returns `true` if this map contains no entries, `false` otherwise. + virtual bool IsEmpty() const { return Size() == 0; } + + // Returns the number of entries in this map. + virtual size_t Size() const = 0; + + // See the corresponding member function of `MapValue` for + // documentation. + virtual absl::Status ListKeys( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + ListValue* absl_nonnull result) const = 0; + + // See the corresponding member function of `MapValue` for + // documentation. + virtual absl::Status ForEach( + ForEachCallback callback, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const; + + // By default, implementations do not guarantee any iteration order. Unless + // specified otherwise, assume the iteration order is random. + virtual absl::StatusOr NewIterator() const; + + virtual CustomMapValue Clone(google::protobuf::Arena* absl_nonnull arena) const = 0; + + virtual absl::StatusOr Find( + const Value& key, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const = 0; + + virtual absl::StatusOr Has( + const Value& key, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const = 0; + + virtual NativeTypeId GetNativeTypeId() const = 0; + + struct Content { + const CustomMapValueInterface* absl_nonnull interface; + google::protobuf::Arena* absl_nullable arena; + }; +}; + +// Creates a custom map value from a manual dispatch table `dispatcher` and +// opaque data `content` whose format is only know to functions in the manual +// dispatch table. The dispatch table should probably be valid for the lifetime +// of the process, but at a minimum must outlive all instances of the resulting +// value. +// +// IMPORTANT: This approach to implementing CustomMapValue should only be +// used when you know exactly what you are doing. When in doubt, just implement +// CustomMapValueInterface. +CustomMapValue UnsafeCustomMapValue(const CustomMapValueDispatcher* absl_nonnull + dispatcher ABSL_ATTRIBUTE_LIFETIME_BOUND, + CustomMapValueContent content); + +class CustomMapValue final + : private common_internal::MapValueMixin { + public: + static constexpr ValueKind kKind = ValueKind::kMap; + + // Constructs a custom map value from an implementation of + // `CustomMapValueInterface` `interface` whose lifetime is tied to that of + // the arena `arena`. + CustomMapValue(const CustomMapValueInterface* absl_nonnull + interface ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::Arena* absl_nonnull arena + ABSL_ATTRIBUTE_LIFETIME_BOUND) { + ABSL_DCHECK(interface != nullptr); + ABSL_DCHECK(arena != nullptr); + content_ = CustomMapValueContent::From(CustomMapValueInterface::Content{ + .interface = interface, .arena = arena}); + } + + // By default, this creates an empty map whose type is `map(dyn, dyn)`. Unless + // you can help it, you should use a more specific typed map value. + CustomMapValue(); + CustomMapValue(const CustomMapValue&) = default; + CustomMapValue(CustomMapValue&&) = default; + CustomMapValue& operator=(const CustomMapValue&) = default; + CustomMapValue& operator=(CustomMapValue&&) = default; + + static constexpr ValueKind kind() { return kKind; } + + NativeTypeId GetTypeId() const; + + absl::string_view GetTypeName() const; + + std::string DebugString() const; + + // See Value::SerializeTo(). + absl::Status SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const; + + // See Value::ConvertToJson(). + absl::Status ConvertToJson( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const; + + // See Value::ConvertToJsonObject(). + absl::Status ConvertToJsonObject( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const; + + absl::Status Equal(const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const; + using MapValueMixin::Equal; + + bool IsZeroValue() const; + + CustomMapValue Clone(google::protobuf::Arena* absl_nonnull arena) const; + + bool IsEmpty() const; + + size_t Size() const; + + // See the corresponding member function of `MapValue` for + // documentation. + absl::Status Get(const Value& key, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const; + using MapValueMixin::Get; + + // See the corresponding member function of `MapValue` for + // documentation. + absl::StatusOr Find( + const Value& key, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; + using MapValueMixin::Find; + + // See the corresponding member function of `MapValue` for + // documentation. + absl::Status Has(const Value& key, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const; + using MapValueMixin::Has; + + // See the corresponding member function of `MapValue` for + // documentation. + absl::Status ListKeys( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, ListValue* absl_nonnull result) const; + using MapValueMixin::ListKeys; + + // See the corresponding type declaration of `MapValueInterface` for + // documentation. + using ForEachCallback = typename CustomMapValueInterface::ForEachCallback; + + // See the corresponding member function of `MapValue` for + // documentation. + absl::Status ForEach( + ForEachCallback callback, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const; + + // See the corresponding member function of `MapValue` for + // documentation. + absl::StatusOr NewIterator() const; + + const CustomMapValueDispatcher* absl_nullable dispatcher() const { + return dispatcher_; + } + + CustomMapValueContent content() const { + ABSL_DCHECK(dispatcher_ != nullptr); + return content_; + } + + const CustomMapValueInterface* absl_nullable interface() const { + if (dispatcher_ == nullptr) { + return content_.To().interface; + } + return nullptr; + } + + friend void swap(CustomMapValue& lhs, CustomMapValue& rhs) noexcept { + using std::swap; + swap(lhs.dispatcher_, rhs.dispatcher_); + swap(lhs.content_, rhs.content_); + } + + private: + friend class common_internal::ValueMixin; + friend class common_internal::MapValueMixin; + friend CustomMapValue UnsafeCustomMapValue( + const CustomMapValueDispatcher* absl_nonnull dispatcher + ABSL_ATTRIBUTE_LIFETIME_BOUND, + CustomMapValueContent content); + + CustomMapValue(const CustomMapValueDispatcher* absl_nonnull dispatcher, + CustomMapValueContent content) + : dispatcher_(dispatcher), content_(content) { + ABSL_DCHECK(dispatcher != nullptr); + ABSL_DCHECK(dispatcher->get_type_id != nullptr); + ABSL_DCHECK(dispatcher->get_arena != nullptr); + ABSL_DCHECK(dispatcher->is_zero_value != nullptr); + ABSL_DCHECK(dispatcher->size != nullptr); + ABSL_DCHECK(dispatcher->find != nullptr); + ABSL_DCHECK(dispatcher->has != nullptr); + ABSL_DCHECK(dispatcher->list_keys != nullptr); + ABSL_DCHECK(dispatcher->clone != nullptr); + } + + const CustomMapValueDispatcher* absl_nullable dispatcher_ = nullptr; + CustomMapValueContent content_ = CustomMapValueContent::Zero(); +}; + +inline std::ostream& operator<<(std::ostream& out, const CustomMapValue& type) { + return out << type.DebugString(); +} + +template <> +struct NativeTypeTraits final { + static NativeTypeId Id(const CustomMapValue& type) { + return type.GetTypeId(); + } +}; + +inline CustomMapValue UnsafeCustomMapValue( + const CustomMapValueDispatcher* absl_nonnull dispatcher + ABSL_ATTRIBUTE_LIFETIME_BOUND, + CustomMapValueContent content) { + return CustomMapValue(dispatcher, content); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_PARSED_MAP_VALUE_H_ diff --git a/common/values/custom_map_value_test.cc b/common/values/custom_map_value_test.cc new file mode 100644 index 000000000..8c3183cf8 --- /dev/null +++ b/common/values/custom_map_value_test.cc @@ -0,0 +1,642 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "google/protobuf/struct.pb.h" +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "common/memory.h" +#include "common/native_type.h" +#include "common/value.h" +#include "common/value_testing.h" +#include "common/values/list_value_builder.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/zero_copy_stream_impl_lite.h" +#include "google/protobuf/message.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::IsOkAndHolds; +using ::absl_testing::StatusIs; +using ::cel::test::BoolValueIs; +using ::cel::test::ErrorValueIs; +using ::cel::test::IntValueIs; +using ::cel::test::StringValueIs; +using ::testing::Eq; +using ::testing::IsEmpty; +using ::testing::IsNull; +using ::testing::Not; +using ::testing::NotNull; +using ::testing::Optional; +using ::testing::Pair; +using ::testing::UnorderedElementsAre; + +class CustomMapValueTest; + +struct CustomMapValueTestContent { + google::protobuf::Arena* absl_nonnull arena; +}; + +class CustomMapValueInterfaceTest final : public CustomMapValueInterface { + public: + std::string DebugString() const override { + return "{\"foo\": true, \"bar\": 1}"; + } + + absl::Status SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const override { + google::protobuf::Value json; + google::protobuf::ListValue* json_array = json.mutable_list_value(); + json_array->add_values()->set_bool_value(true); + json_array->add_values()->set_number_value(1.0); + if (!json.SerializePartialToZeroCopyStream(output)) { + return absl::UnknownError( + "failed to serialize message: google.protobuf.Value"); + } + return absl::OkStatus(); + } + + absl::Status ConvertToJsonObject( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const override { + google::protobuf::Struct json_object; + (*json_object.mutable_fields())["foo"].set_bool_value(true); + (*json_object.mutable_fields())["bar"].set_number_value(1.0); + absl::Cord serialized; + if (!json_object.SerializePartialToString(&serialized)) { + return absl::UnknownError("failed to serialize google.protobuf.Struct"); + } + if (!json->ParsePartialFromString(serialized)) { + return absl::UnknownError("failed to parse google.protobuf.Struct"); + } + return absl::OkStatus(); + } + + size_t Size() const override { return 2; } + + absl::Status ListKeys( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + ListValue* absl_nonnull result) const override { + auto builder = common_internal::NewListValueBuilder(arena); + builder->Reserve(2); + CEL_RETURN_IF_ERROR(builder->Add(StringValue("foo"))); + CEL_RETURN_IF_ERROR(builder->Add(StringValue("bar"))); + *result = std::move(*builder).Build(); + return absl::OkStatus(); + } + + CustomMapValue Clone(google::protobuf::Arena* absl_nonnull arena) const override { + return CustomMapValue( + (::new (arena->AllocateAligned(sizeof(CustomMapValueInterfaceTest), + alignof(CustomMapValueInterfaceTest))) + CustomMapValueInterfaceTest()), + arena); + } + + private: + absl::StatusOr Find( + const Value& key, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const override { + if (auto string_key = key.AsString(); string_key) { + if (*string_key == "foo") { + *result = TrueValue(); + return true; + } + if (*string_key == "bar") { + *result = IntValue(1); + return true; + } + } + return false; + } + + absl::StatusOr Has( + const Value& key, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const override { + if (auto string_key = key.AsString(); string_key) { + if (*string_key == "foo") { + return true; + } + if (*string_key == "bar") { + return true; + } + } + return false; + } + + NativeTypeId GetNativeTypeId() const override { + return NativeTypeId::For(); + } +}; + +class CustomMapValueTest : public common_internal::ValueTest<> { + public: + CustomMapValue MakeInterface() { + return CustomMapValue( + (::new (arena()->AllocateAligned(sizeof(CustomMapValueInterfaceTest), + alignof(CustomMapValueInterfaceTest))) + CustomMapValueInterfaceTest()), + arena()); + } + + CustomMapValue MakeDispatcher() { + return UnsafeCustomMapValue( + &test_dispatcher_, CustomValueContent::From( + CustomMapValueTestContent{.arena = arena()})); + } + + protected: + CustomMapValueDispatcher test_dispatcher_ = { + .get_type_id = [](const CustomMapValueDispatcher* absl_nonnull dispatcher, + CustomMapValueContent content) -> NativeTypeId { + return NativeTypeId::For(); + }, + .get_arena = + [](const CustomMapValueDispatcher* absl_nonnull dispatcher, + CustomMapValueContent content) -> google::protobuf::Arena* absl_nullable { + return content.To().arena; + }, + .debug_string = + [](const CustomMapValueDispatcher* absl_nonnull dispatcher, + CustomMapValueContent content) -> std::string { + return "{\"foo\": true, \"bar\": 1}"; + }, + .serialize_to = + [](const CustomMapValueDispatcher* absl_nonnull dispatcher, + CustomMapValueContent content, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) + -> absl::Status { + google::protobuf::Value json; + google::protobuf::Struct* json_object = json.mutable_struct_value(); + (*json_object->mutable_fields())["foo"].set_bool_value(true); + (*json_object->mutable_fields())["bar"].set_number_value(1.0); + if (!json.SerializePartialToZeroCopyStream(output)) { + return absl::UnknownError( + "failed to serialize message: google.protobuf.Value"); + } + return absl::OkStatus(); + }, + .convert_to_json_object = + [](const CustomMapValueDispatcher* absl_nonnull dispatcher, + CustomMapValueContent content, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) -> absl::Status { + { + google::protobuf::Struct json_object; + (*json_object.mutable_fields())["foo"].set_bool_value(true); + (*json_object.mutable_fields())["bar"].set_number_value(1.0); + absl::Cord serialized; + if (!json_object.SerializePartialToString(&serialized)) { + return absl::UnknownError( + "failed to serialize google.protobuf.Struct"); + } + if (!json->ParsePartialFromString(serialized)) { + return absl::UnknownError("failed to parse google.protobuf.Struct"); + } + return absl::OkStatus(); + } + }, + .is_zero_value = + [](const CustomMapValueDispatcher* absl_nonnull dispatcher, + CustomMapValueContent content) -> bool { return false; }, + .size = [](const CustomMapValueDispatcher* absl_nonnull dispatcher, + CustomMapValueContent content) -> size_t { return 2; }, + .find = [](const CustomMapValueDispatcher* absl_nonnull dispatcher, + CustomMapValueContent content, const Value& key, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) -> absl::StatusOr { + if (auto string_key = key.AsString(); string_key) { + if (*string_key == "foo") { + *result = TrueValue(); + return true; + } + if (*string_key == "bar") { + *result = IntValue(1); + return true; + } + } + return false; + }, + .has = [](const CustomMapValueDispatcher* absl_nonnull dispatcher, + CustomMapValueContent content, const Value& key, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) -> absl::StatusOr { + if (auto string_key = key.AsString(); string_key) { + if (*string_key == "foo") { + return true; + } + if (*string_key == "bar") { + return true; + } + } + return false; + }, + .list_keys = + [](const CustomMapValueDispatcher* absl_nonnull dispatcher, + CustomMapValueContent content, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + ListValue* absl_nonnull result) -> absl::Status { + auto builder = common_internal::NewListValueBuilder(arena); + builder->Reserve(2); + CEL_RETURN_IF_ERROR(builder->Add(StringValue("foo"))); + CEL_RETURN_IF_ERROR(builder->Add(StringValue("bar"))); + *result = std::move(*builder).Build(); + return absl::OkStatus(); + }, + .clone = [](const CustomMapValueDispatcher* absl_nonnull dispatcher, + CustomMapValueContent content, + google::protobuf::Arena* absl_nonnull arena) -> CustomMapValue { + return UnsafeCustomMapValue( + dispatcher, CustomValueContent::From( + CustomMapValueTestContent{.arena = arena})); + }, + }; +}; + +TEST_F(CustomMapValueTest, Kind) { + EXPECT_EQ(CustomMapValue::kind(), CustomMapValue::kKind); +} + +TEST_F(CustomMapValueTest, Dispatcher_GetTypeId) { + EXPECT_EQ(MakeDispatcher().GetTypeId(), + NativeTypeId::For()); +} + +TEST_F(CustomMapValueTest, Interface_GetTypeId) { + EXPECT_EQ(MakeInterface().GetTypeId(), + NativeTypeId::For()); +} + +TEST_F(CustomMapValueTest, Dispatcher_GetTypeName) { + EXPECT_EQ(MakeDispatcher().GetTypeName(), "map"); +} + +TEST_F(CustomMapValueTest, Interface_GetTypeName) { + EXPECT_EQ(MakeInterface().GetTypeName(), "map"); +} + +TEST_F(CustomMapValueTest, Dispatcher_DebugString) { + EXPECT_EQ(MakeDispatcher().DebugString(), "{\"foo\": true, \"bar\": 1}"); +} + +TEST_F(CustomMapValueTest, Interface_DebugString) { + EXPECT_EQ(MakeInterface().DebugString(), "{\"foo\": true, \"bar\": 1}"); +} + +TEST_F(CustomMapValueTest, Dispatcher_IsZeroValue) { + EXPECT_FALSE(MakeDispatcher().IsZeroValue()); +} + +TEST_F(CustomMapValueTest, Interface_IsZeroValue) { + EXPECT_FALSE(MakeInterface().IsZeroValue()); +} + +TEST_F(CustomMapValueTest, Dispatcher_SerializeTo) { + google::protobuf::io::CordOutputStream output; + EXPECT_THAT(MakeDispatcher().SerializeTo(descriptor_pool(), message_factory(), + &output), + IsOk()); + EXPECT_THAT(std::move(output).Consume(), Not(IsEmpty())); +} + +TEST_F(CustomMapValueTest, Interface_SerializeTo) { + google::protobuf::io::CordOutputStream output; + EXPECT_THAT(MakeInterface().SerializeTo(descriptor_pool(), message_factory(), + &output), + IsOk()); + EXPECT_THAT(std::move(output).Consume(), Not(IsEmpty())); +} + +TEST_F(CustomMapValueTest, Dispatcher_ConvertToJson) { + auto message = DynamicParseTextProto(); + EXPECT_THAT( + MakeDispatcher().ConvertToJson(descriptor_pool(), message_factory(), + cel::to_address(message)), + IsOk()); + EXPECT_THAT(*message, EqualsTextProto(R"pb( + struct_value: { + fields: { + key: "foo" + value: { bool_value: true } + } + fields: { + key: "bar" + value: { number_value: 1.0 } + } + } + )pb")); +} + +TEST_F(CustomMapValueTest, Interface_ConvertToJson) { + auto message = DynamicParseTextProto(); + EXPECT_THAT( + MakeInterface().ConvertToJson(descriptor_pool(), message_factory(), + cel::to_address(message)), + IsOk()); + EXPECT_THAT(*message, EqualsTextProto(R"pb( + struct_value: { + fields: { + key: "foo" + value: { bool_value: true } + } + fields: { + key: "bar" + value: { number_value: 1.0 } + } + } + )pb")); +} + +TEST_F(CustomMapValueTest, Dispatcher_ConvertToJsonObject) { + auto message = DynamicParseTextProto(); + EXPECT_THAT( + MakeDispatcher().ConvertToJsonObject(descriptor_pool(), message_factory(), + cel::to_address(message)), + IsOk()); + EXPECT_THAT(*message, EqualsTextProto(R"pb( + fields: { + key: "foo" + value: { bool_value: true } + } + fields: { + key: "bar" + value: { number_value: 1.0 } + } + )pb")); +} + +TEST_F(CustomMapValueTest, Interface_ConvertToJsonObject) { + auto message = DynamicParseTextProto(); + EXPECT_THAT( + MakeInterface().ConvertToJsonObject(descriptor_pool(), message_factory(), + cel::to_address(message)), + IsOk()); + EXPECT_THAT(*message, EqualsTextProto(R"pb( + fields: { + key: "foo" + value: { bool_value: true } + } + fields: { + key: "bar" + value: { number_value: 1.0 } + } + )pb")); +} + +TEST_F(CustomMapValueTest, Dispatcher_IsEmpty) { + EXPECT_FALSE(MakeDispatcher().IsEmpty()); +} + +TEST_F(CustomMapValueTest, Interface_IsEmpty) { + EXPECT_FALSE(MakeInterface().IsEmpty()); +} + +TEST_F(CustomMapValueTest, Dispatcher_Size) { + EXPECT_EQ(MakeDispatcher().Size(), 2); +} + +TEST_F(CustomMapValueTest, Interface_Size) { + EXPECT_EQ(MakeInterface().Size(), 2); +} + +TEST_F(CustomMapValueTest, Dispatcher_Get) { + CustomMapValue map = MakeDispatcher(); + ASSERT_THAT(map.Get(StringValue("foo"), descriptor_pool(), message_factory(), + arena()), + IsOkAndHolds(BoolValueIs(true))); + ASSERT_THAT(map.Get(StringValue("bar"), descriptor_pool(), message_factory(), + arena()), + IsOkAndHolds(IntValueIs(1))); + ASSERT_THAT( + map.Get(StringValue("baz"), descriptor_pool(), message_factory(), + arena()), + IsOkAndHolds(ErrorValueIs(StatusIs(absl::StatusCode::kNotFound)))); +} + +TEST_F(CustomMapValueTest, Interface_Get) { + CustomMapValue map = MakeInterface(); + ASSERT_THAT(map.Get(StringValue("foo"), descriptor_pool(), message_factory(), + arena()), + IsOkAndHolds(BoolValueIs(true))); + ASSERT_THAT(map.Get(StringValue("bar"), descriptor_pool(), message_factory(), + arena()), + IsOkAndHolds(IntValueIs(1))); + ASSERT_THAT( + map.Get(StringValue("baz"), descriptor_pool(), message_factory(), + arena()), + IsOkAndHolds(ErrorValueIs(StatusIs(absl::StatusCode::kNotFound)))); +} + +TEST_F(CustomMapValueTest, Dispatcher_Find) { + CustomMapValue map = MakeDispatcher(); + ASSERT_THAT(map.Find(StringValue("foo"), descriptor_pool(), message_factory(), + arena()), + IsOkAndHolds(Optional(BoolValueIs(true)))); + ASSERT_THAT(map.Find(StringValue("bar"), descriptor_pool(), message_factory(), + arena()), + IsOkAndHolds(Optional(IntValueIs(1)))); + ASSERT_THAT(map.Find(StringValue("baz"), descriptor_pool(), message_factory(), + arena()), + IsOkAndHolds(Eq(absl::nullopt))); +} + +TEST_F(CustomMapValueTest, Interface_Find) { + CustomMapValue map = MakeInterface(); + ASSERT_THAT(map.Find(StringValue("foo"), descriptor_pool(), message_factory(), + arena()), + IsOkAndHolds(Optional(BoolValueIs(true)))); + ASSERT_THAT(map.Find(StringValue("bar"), descriptor_pool(), message_factory(), + arena()), + IsOkAndHolds(Optional(IntValueIs(1)))); + ASSERT_THAT(map.Find(StringValue("baz"), descriptor_pool(), message_factory(), + arena()), + IsOkAndHolds(Eq(absl::nullopt))); +} + +TEST_F(CustomMapValueTest, Dispatcher_Has) { + CustomMapValue map = MakeDispatcher(); + ASSERT_THAT(map.Has(StringValue("foo"), descriptor_pool(), message_factory(), + arena()), + IsOkAndHolds(BoolValueIs(true))); + ASSERT_THAT(map.Has(StringValue("bar"), descriptor_pool(), message_factory(), + arena()), + IsOkAndHolds(BoolValueIs(true))); + ASSERT_THAT(map.Has(StringValue("baz"), descriptor_pool(), message_factory(), + arena()), + IsOkAndHolds(BoolValueIs(false))); +} + +TEST_F(CustomMapValueTest, Interface_Has) { + CustomMapValue map = MakeInterface(); + ASSERT_THAT(map.Has(StringValue("foo"), descriptor_pool(), message_factory(), + arena()), + IsOkAndHolds(BoolValueIs(true))); + ASSERT_THAT(map.Has(StringValue("bar"), descriptor_pool(), message_factory(), + arena()), + IsOkAndHolds(BoolValueIs(true))); + ASSERT_THAT(map.Has(StringValue("baz"), descriptor_pool(), message_factory(), + arena()), + IsOkAndHolds(BoolValueIs(false))); +} + +TEST_F(CustomMapValueTest, Dispatcher_ForEach) { + std::vector> entries; + EXPECT_THAT( + MakeDispatcher().ForEach( + [&](const Value& key, const Value& value) -> absl::StatusOr { + entries.push_back(std::pair{key, value}); + return true; + }, + descriptor_pool(), message_factory(), arena()), + IsOk()); + EXPECT_THAT(entries, UnorderedElementsAre( + Pair(StringValueIs("foo"), BoolValueIs(true)), + Pair(StringValueIs("bar"), IntValueIs(1)))); +} + +TEST_F(CustomMapValueTest, Interface_ForEach) { + std::vector> entries; + EXPECT_THAT( + MakeInterface().ForEach( + [&](const Value& key, const Value& value) -> absl::StatusOr { + entries.push_back(std::pair{key, value}); + return true; + }, + descriptor_pool(), message_factory(), arena()), + IsOk()); + EXPECT_THAT(entries, UnorderedElementsAre( + Pair(StringValueIs("foo"), BoolValueIs(true)), + Pair(StringValueIs("bar"), IntValueIs(1)))); +} + +TEST_F(CustomMapValueTest, Dispatcher_NewIterator) { + CustomMapValue map = MakeDispatcher(); + ASSERT_OK_AND_ASSIGN(auto iterator, map.NewIterator()); + ASSERT_TRUE(iterator->HasNext()); + EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(StringValueIs("foo"))); + ASSERT_TRUE(iterator->HasNext()); + EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(StringValueIs("bar"))); + EXPECT_FALSE(iterator->HasNext()); + EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()), + StatusIs(absl::StatusCode::kFailedPrecondition)); +} + +TEST_F(CustomMapValueTest, Interface_NewIterator) { + CustomMapValue map = MakeInterface(); + ASSERT_OK_AND_ASSIGN(auto iterator, map.NewIterator()); + ASSERT_TRUE(iterator->HasNext()); + EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(StringValueIs("foo"))); + ASSERT_TRUE(iterator->HasNext()); + EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(StringValueIs("bar"))); + EXPECT_FALSE(iterator->HasNext()); + EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()), + StatusIs(absl::StatusCode::kFailedPrecondition)); +} + +TEST_F(CustomMapValueTest, Dispatcher_NewIterator1) { + CustomMapValue map = MakeDispatcher(); + ASSERT_OK_AND_ASSIGN(auto iterator, map.NewIterator()); + EXPECT_THAT(iterator->Next1(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(Optional(StringValueIs("foo")))); + EXPECT_THAT(iterator->Next1(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(Optional(StringValueIs("bar")))); + EXPECT_THAT(iterator->Next1(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(Eq(absl::nullopt))); +} + +TEST_F(CustomMapValueTest, Interface_NewIterator1) { + CustomMapValue map = MakeInterface(); + ASSERT_OK_AND_ASSIGN(auto iterator, map.NewIterator()); + EXPECT_THAT(iterator->Next1(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(Optional(StringValueIs("foo")))); + EXPECT_THAT(iterator->Next1(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(Optional(StringValueIs("bar")))); + EXPECT_THAT(iterator->Next1(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(Eq(absl::nullopt))); +} + +TEST_F(CustomMapValueTest, Dispatcher_NewIterator2) { + CustomMapValue map = MakeDispatcher(); + ASSERT_OK_AND_ASSIGN(auto iterator, map.NewIterator()); + EXPECT_THAT( + iterator->Next2(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(Optional(Pair(StringValueIs("foo"), BoolValueIs(true))))); + EXPECT_THAT( + iterator->Next2(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(Optional(Pair(StringValueIs("bar"), IntValueIs(1))))); + EXPECT_THAT(iterator->Next2(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(Eq(absl::nullopt))); +} + +TEST_F(CustomMapValueTest, Interface_NewIterator2) { + CustomMapValue map = MakeInterface(); + ASSERT_OK_AND_ASSIGN(auto iterator, map.NewIterator()); + EXPECT_THAT( + iterator->Next2(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(Optional(Pair(StringValueIs("foo"), BoolValueIs(true))))); + EXPECT_THAT( + iterator->Next2(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(Optional(Pair(StringValueIs("bar"), IntValueIs(1))))); + EXPECT_THAT(iterator->Next2(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(Eq(absl::nullopt))); +} + +TEST_F(CustomMapValueTest, Dispatcher) { + EXPECT_THAT(MakeDispatcher().dispatcher(), NotNull()); + EXPECT_THAT(MakeDispatcher().interface(), IsNull()); +} + +TEST_F(CustomMapValueTest, Interface) { + EXPECT_THAT(MakeInterface().dispatcher(), IsNull()); + EXPECT_THAT(MakeInterface().interface(), NotNull()); +} + +} // namespace +} // namespace cel diff --git a/common/values/custom_struct_value.cc b/common/values/custom_struct_value.cc new file mode 100644 index 000000000..0999cb80e --- /dev/null +++ b/common/values/custom_struct_value.cc @@ -0,0 +1,385 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "absl/base/nullability.h" +#include "absl/base/optimization.h" +#include "absl/functional/function_ref.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "base/attribute.h" +#include "common/native_type.h" +#include "common/type.h" +#include "common/value.h" +#include "common/values/values.h" +#include "internal/status_macros.h" +#include "internal/well_known_types.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/message.h" + +namespace cel { + +namespace { + +using ::cel::well_known_types::ValueReflection; + +} // namespace + +absl::Status CustomStructValueInterface::Equal( + const StructValue& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + return common_internal::StructValueEqual(*this, other, descriptor_pool, + message_factory, arena, result); +} + +absl::Status CustomStructValueInterface::Qualify( + absl::Span qualifiers, bool presence_test, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result, + int* absl_nonnull count) const { + return absl::UnimplementedError(absl::StrCat( + GetTypeName(), " does not implement field selection optimization")); +} + +NativeTypeId CustomStructValue::GetTypeId() const { + if (dispatcher_ == nullptr) { + CustomStructValueInterface::Content content = + content_.To(); + if (content.interface == nullptr) { + return NativeTypeId(); + } + return content.interface->GetNativeTypeId(); + } + return dispatcher_->get_type_id(dispatcher_, content_); +} + +StructType CustomStructValue::GetRuntimeType() const { + ABSL_DCHECK(*this); + + if (dispatcher_ == nullptr) { + CustomStructValueInterface::Content content = + content_.To(); + ABSL_DCHECK(content.interface != nullptr); + return content.interface->GetRuntimeType(); + } + if (dispatcher_->get_runtime_type != nullptr) { + return dispatcher_->get_runtime_type(dispatcher_, content_); + } + return common_internal::MakeBasicStructType(GetTypeName()); +} + +absl::string_view CustomStructValue::GetTypeName() const { + ABSL_DCHECK(*this); + + if (dispatcher_ == nullptr) { + CustomStructValueInterface::Content content = + content_.To(); + ABSL_DCHECK(content.interface != nullptr); + return content.interface->GetTypeName(); + } + return dispatcher_->get_type_name(dispatcher_, content_); +} + +std::string CustomStructValue::DebugString() const { + ABSL_DCHECK(*this); + + if (dispatcher_ == nullptr) { + CustomStructValueInterface::Content content = + content_.To(); + ABSL_DCHECK(content.interface != nullptr); + return content.interface->DebugString(); + } + if (dispatcher_->debug_string != nullptr) { + return dispatcher_->debug_string(dispatcher_, content_); + } + return std::string(GetTypeName()); +} + +absl::Status CustomStructValue::SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const { + ABSL_DCHECK(*this); + + if (dispatcher_ == nullptr) { + CustomStructValueInterface::Content content = + content_.To(); + ABSL_DCHECK(content.interface != nullptr); + return content.interface->SerializeTo(descriptor_pool, message_factory, + output); + } + if (dispatcher_->serialize_to != nullptr) { + return dispatcher_->serialize_to(dispatcher_, content_, descriptor_pool, + message_factory, output); + } + return absl::UnimplementedError( + absl::StrCat(GetTypeName(), " is unserializable")); +} + +absl::Status CustomStructValue::ConvertToJson( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(json != nullptr); + ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE); + ABSL_DCHECK(*this); + + ValueReflection value_reflection; + CEL_RETURN_IF_ERROR(value_reflection.Initialize(json->GetDescriptor())); + google::protobuf::Message* json_object = value_reflection.MutableStructValue(json); + + return ConvertToJsonObject(descriptor_pool, message_factory, json_object); +} + +absl::Status CustomStructValue::ConvertToJsonObject( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(json != nullptr); + ABSL_DCHECK(*this); + + if (dispatcher_ == nullptr) { + CustomStructValueInterface::Content content = + content_.To(); + if (ABSL_PREDICT_FALSE(content.interface == nullptr)) { + json->Clear(); + return absl::OkStatus(); + } + return content.interface->ConvertToJsonObject(descriptor_pool, + message_factory, json); + } + if (dispatcher_->convert_to_json_object != nullptr) { + return dispatcher_->convert_to_json_object( + dispatcher_, content_, descriptor_pool, message_factory, json); + } + return absl::UnimplementedError( + absl::StrCat(GetTypeName(), " is not convertable to JSON")); +} + +absl::Status CustomStructValue::Equal( + const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + ABSL_DCHECK(*this); + + if (auto other_struct_value = other.AsStruct(); other_struct_value) { + if (dispatcher_ == nullptr) { + CustomStructValueInterface::Content content = + content_.To(); + ABSL_DCHECK(content.interface != nullptr); + return content.interface->Equal(*other_struct_value, descriptor_pool, + message_factory, arena, result); + } + if (dispatcher_->equal != nullptr) { + return dispatcher_->equal(dispatcher_, content_, *other_struct_value, + descriptor_pool, message_factory, arena, + result); + } + return common_internal::StructValueEqual(*this, *other_struct_value, + descriptor_pool, message_factory, + arena, result); + } + *result = FalseValue(); + return absl::OkStatus(); +} + +bool CustomStructValue::IsZeroValue() const { + ABSL_DCHECK(*this); + + if (dispatcher_ == nullptr) { + CustomStructValueInterface::Content content = + content_.To(); + if (content.interface == nullptr) { + return true; + } + return content.interface->IsZeroValue(); + } + return dispatcher_->is_zero_value(dispatcher_, content_); +} + +CustomStructValue CustomStructValue::Clone( + google::protobuf::Arena* absl_nonnull arena) const { + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(*this); + + if (dispatcher_ == nullptr) { + CustomStructValueInterface::Content content = + content_.To(); + if (content.interface == nullptr) { + return *this; + } + if (content.arena != arena) { + return content.interface->Clone(arena); + } + return *this; + } + return dispatcher_->clone(dispatcher_, content_, arena); +} + +absl::Status CustomStructValue::GetFieldByName( + absl::string_view name, ProtoWrapperTypeOptions unboxing_options, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + ABSL_DCHECK(*this); + + if (dispatcher_ == nullptr) { + CustomStructValueInterface::Content content = + content_.To(); + ABSL_DCHECK(content.interface != nullptr); + return content.interface->GetFieldByName(name, unboxing_options, + descriptor_pool, message_factory, + arena, result); + } + return dispatcher_->get_field_by_name(dispatcher_, content_, name, + unboxing_options, descriptor_pool, + message_factory, arena, result); +} + +absl::Status CustomStructValue::GetFieldByNumber( + int64_t number, ProtoWrapperTypeOptions unboxing_options, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + ABSL_DCHECK(*this); + + if (dispatcher_ == nullptr) { + CustomStructValueInterface::Content content = + content_.To(); + ABSL_DCHECK(content.interface != nullptr); + return content.interface->GetFieldByNumber(number, unboxing_options, + descriptor_pool, message_factory, + arena, result); + } + if (dispatcher_->get_field_by_number != nullptr) { + return dispatcher_->get_field_by_number(dispatcher_, content_, number, + unboxing_options, descriptor_pool, + message_factory, arena, result); + } + return absl::UnimplementedError(absl::StrCat( + GetTypeName(), " does not implement access by field number")); +} + +absl::StatusOr CustomStructValue::HasFieldByName( + absl::string_view name) const { + ABSL_DCHECK(*this); + + if (dispatcher_ == nullptr) { + CustomStructValueInterface::Content content = + content_.To(); + ABSL_DCHECK(content.interface != nullptr); + return content.interface->HasFieldByName(name); + } + return dispatcher_->has_field_by_name(dispatcher_, content_, name); +} + +absl::StatusOr CustomStructValue::HasFieldByNumber(int64_t number) const { + ABSL_DCHECK(*this); + + if (dispatcher_ == nullptr) { + CustomStructValueInterface::Content content = + content_.To(); + ABSL_DCHECK(content.interface != nullptr); + return content.interface->HasFieldByNumber(number); + } + if (dispatcher_->has_field_by_number != nullptr) { + return dispatcher_->has_field_by_number(dispatcher_, content_, number); + } + return absl::UnimplementedError(absl::StrCat( + GetTypeName(), " does not implement access by field number")); +} + +absl::Status CustomStructValue::ForEachField( + ForEachFieldCallback callback, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(*this); + + if (dispatcher_ == nullptr) { + CustomStructValueInterface::Content content = + content_.To(); + ABSL_DCHECK(content.interface != nullptr); + return content.interface->ForEachField(callback, descriptor_pool, + message_factory, arena); + } + return dispatcher_->for_each_field(dispatcher_, content_, callback, + descriptor_pool, message_factory, arena); +} + +absl::Status CustomStructValue::Qualify( + absl::Span qualifiers, bool presence_test, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result, + int* absl_nonnull count) const { + ABSL_DCHECK_GT(qualifiers.size(), 0); + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + ABSL_DCHECK(count != nullptr); + ABSL_DCHECK(*this); + + if (dispatcher_ == nullptr) { + CustomStructValueInterface::Content content = + content_.To(); + ABSL_DCHECK(content.interface != nullptr); + return content.interface->Qualify(qualifiers, presence_test, + descriptor_pool, message_factory, arena, + result, count); + } + if (dispatcher_->qualify != nullptr) { + return dispatcher_->qualify(dispatcher_, content_, qualifiers, + presence_test, descriptor_pool, message_factory, + arena, result, count); + } + return absl::UnimplementedError(absl::StrCat( + GetTypeName(), " does not implement field selection optimization")); +} + +} // namespace cel diff --git a/common/values/custom_struct_value.h b/common/values/custom_struct_value.h new file mode 100644 index 000000000..6ffd153f8 --- /dev/null +++ b/common/values/custom_struct_value.h @@ -0,0 +1,459 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/value.h" +// IWYU pragma: friend "common/value.h" + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_PARSED_STRUCT_VALUE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_PARSED_STRUCT_VALUE_H_ + +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/functional/function_ref.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "base/attribute.h" +#include "common/native_type.h" +#include "common/type.h" +#include "common/value_kind.h" +#include "common/values/custom_value.h" +#include "common/values/values.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/message.h" + +namespace cel { + +class CustomStructValueInterface; +class CustomStructValue; +class Value; +struct CustomStructValueDispatcher; +using CustomStructValueContent = CustomValueContent; + +struct CustomStructValueDispatcher { + using GetTypeId = NativeTypeId (*)( + const CustomStructValueDispatcher* absl_nonnull dispatcher, + CustomStructValueContent content); + + using GetArena = google::protobuf::Arena* absl_nullable (*)( + const CustomStructValueDispatcher* absl_nonnull dispatcher, + CustomStructValueContent content); + + using GetTypeName = absl::string_view (*)( + const CustomStructValueDispatcher* absl_nonnull dispatcher, + CustomStructValueContent content); + + using DebugString = std::string (*)( + const CustomStructValueDispatcher* absl_nonnull dispatcher, + CustomStructValueContent content); + + using GetRuntimeType = + StructType (*)(const CustomStructValueDispatcher* absl_nonnull dispatcher, + CustomStructValueContent content); + + using SerializeTo = absl::Status (*)( + const CustomStructValueDispatcher* absl_nonnull dispatcher, + CustomStructValueContent content, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output); + + using ConvertToJsonObject = absl::Status (*)( + const CustomStructValueDispatcher* absl_nonnull dispatcher, + CustomStructValueContent content, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json); + + using Equal = absl::Status (*)( + const CustomStructValueDispatcher* absl_nonnull dispatcher, + CustomStructValueContent content, const StructValue& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result); + + using IsZeroValue = + bool (*)(const CustomStructValueDispatcher* absl_nonnull dispatcher, + CustomStructValueContent content); + + using GetFieldByName = absl::Status (*)( + const CustomStructValueDispatcher* absl_nonnull dispatcher, + CustomStructValueContent content, absl::string_view name, + ProtoWrapperTypeOptions unboxing_options, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result); + + using GetFieldByNumber = absl::Status (*)( + const CustomStructValueDispatcher* absl_nonnull dispatcher, + CustomStructValueContent content, int64_t number, + ProtoWrapperTypeOptions unboxing_options, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result); + + using HasFieldByName = absl::StatusOr (*)( + const CustomStructValueDispatcher* absl_nonnull dispatcher, + CustomStructValueContent content, absl::string_view name); + + using HasFieldByNumber = absl::StatusOr (*)( + const CustomStructValueDispatcher* absl_nonnull dispatcher, + CustomStructValueContent content, int64_t number); + + using ForEachField = absl::Status (*)( + const CustomStructValueDispatcher* absl_nonnull dispatcher, + CustomStructValueContent content, + absl::FunctionRef(absl::string_view, const Value&)> + callback, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena); + + using Quality = absl::Status (*)( + const CustomStructValueDispatcher* absl_nonnull dispatcher, + CustomStructValueContent content, + absl::Span qualifiers, bool presence_test, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result, + int* absl_nonnull count); + + using Clone = CustomStructValue (*)( + const CustomStructValueDispatcher* absl_nonnull dispatcher, + CustomStructValueContent content, google::protobuf::Arena* absl_nonnull arena); + + absl_nonnull GetTypeId get_type_id; + + absl_nonnull GetArena get_arena; + + absl_nonnull GetTypeName get_type_name; + + absl_nullable DebugString debug_string = nullptr; + + absl_nullable GetRuntimeType get_runtime_type = nullptr; + + absl_nullable SerializeTo serialize_to = nullptr; + + absl_nullable ConvertToJsonObject convert_to_json_object = nullptr; + + absl_nullable Equal equal = nullptr; + + absl_nonnull IsZeroValue is_zero_value; + + absl_nonnull GetFieldByName get_field_by_name; + + absl_nullable GetFieldByNumber get_field_by_number = nullptr; + + absl_nonnull HasFieldByName has_field_by_name; + + absl_nullable HasFieldByNumber has_field_by_number = nullptr; + + absl_nonnull ForEachField for_each_field; + + absl_nullable Quality qualify = nullptr; + + absl_nonnull Clone clone; +}; + +class CustomStructValueInterface { + public: + CustomStructValueInterface() = default; + CustomStructValueInterface(const CustomStructValueInterface&) = delete; + CustomStructValueInterface(CustomStructValueInterface&&) = delete; + + virtual ~CustomStructValueInterface() = default; + + CustomStructValueInterface& operator=(const CustomStructValueInterface&) = + delete; + CustomStructValueInterface& operator=(CustomStructValueInterface&&) = delete; + + using ForEachFieldCallback = + absl::FunctionRef(absl::string_view, const Value&)>; + + private: + friend class CustomStructValue; + friend absl::Status common_internal::StructValueEqual( + const CustomStructValueInterface& lhs, const StructValue& rhs, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result); + + virtual std::string DebugString() const = 0; + + virtual absl::Status SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const = 0; + + virtual absl::Status ConvertToJsonObject( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const = 0; + + virtual absl::string_view GetTypeName() const = 0; + + virtual StructType GetRuntimeType() const { + return common_internal::MakeBasicStructType(GetTypeName()); + } + + virtual absl::Status Equal( + const StructValue& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; + + virtual bool IsZeroValue() const = 0; + + virtual absl::Status GetFieldByName( + absl::string_view name, ProtoWrapperTypeOptions unboxing_options, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const = 0; + + virtual absl::Status GetFieldByNumber( + int64_t number, ProtoWrapperTypeOptions unboxing_options, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const = 0; + + virtual absl::StatusOr HasFieldByName(absl::string_view name) const = 0; + + virtual absl::StatusOr HasFieldByNumber(int64_t number) const = 0; + + virtual absl::Status ForEachField( + ForEachFieldCallback callback, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const = 0; + + virtual absl::Status Qualify( + absl::Span qualifiers, bool presence_test, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result, + int* absl_nonnull count) const; + + virtual CustomStructValue Clone(google::protobuf::Arena* absl_nonnull arena) const = 0; + + virtual NativeTypeId GetNativeTypeId() const = 0; + + struct Content { + const CustomStructValueInterface* absl_nonnull interface; + google::protobuf::Arena* absl_nonnull arena; + }; +}; + +// Creates a custom struct value from a manual dispatch table `dispatcher` and +// opaque data `content` whose format is only know to functions in the manual +// dispatch table. The dispatch table should probably be valid for the lifetime +// of the process, but at a minimum must outlive all instances of the resulting +// value. +// +// IMPORTANT: This approach to implementing CustomStructValues should only be +// used when you know exactly what you are doing. When in doubt, just implement +// CustomStructValueInterface. +CustomStructValue UnsafeCustomStructValue( + const CustomStructValueDispatcher* absl_nonnull dispatcher + ABSL_ATTRIBUTE_LIFETIME_BOUND, + CustomStructValueContent content); + +class CustomStructValue final + : private common_internal::StructValueMixin { + public: + static constexpr ValueKind kKind = ValueKind::kStruct; + + // Constructs a custom struct value from an implementation of + // `CustomStructValueInterface` `interface` whose lifetime is tied to that of + // the arena `arena`. + CustomStructValue(const CustomStructValueInterface* absl_nonnull + interface ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::Arena* absl_nonnull arena + ABSL_ATTRIBUTE_LIFETIME_BOUND) { + ABSL_DCHECK(interface != nullptr); + ABSL_DCHECK(arena != nullptr); + content_ = + CustomStructValueContent::From(CustomStructValueInterface::Content{ + .interface = interface, .arena = arena}); + } + + CustomStructValue() = default; + CustomStructValue(const CustomStructValue&) = default; + CustomStructValue(CustomStructValue&&) = default; + CustomStructValue& operator=(const CustomStructValue&) = default; + CustomStructValue& operator=(CustomStructValue&&) = default; + + static constexpr ValueKind kind() { return kKind; } + + NativeTypeId GetTypeId() const; + + StructType GetRuntimeType() const; + + absl::string_view GetTypeName() const; + + std::string DebugString() const; + + // See Value::SerializeTo(). + absl::Status SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const; + + // See Value::ConvertToJson(). + absl::Status ConvertToJson( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const; + + // See Value::ConvertToJsonObject(). + absl::Status ConvertToJsonObject( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const; + + absl::Status Equal(const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const; + using StructValueMixin::Equal; + + bool IsZeroValue() const; + + CustomStructValue Clone(google::protobuf::Arena* absl_nonnull arena) const; + + absl::Status GetFieldByName( + absl::string_view name, ProtoWrapperTypeOptions unboxing_options, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; + using StructValueMixin::GetFieldByName; + + absl::Status GetFieldByNumber( + int64_t number, ProtoWrapperTypeOptions unboxing_options, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; + using StructValueMixin::GetFieldByNumber; + + absl::StatusOr HasFieldByName(absl::string_view name) const; + + absl::StatusOr HasFieldByNumber(int64_t number) const; + + using ForEachFieldCallback = CustomStructValueInterface::ForEachFieldCallback; + + absl::Status ForEachField( + ForEachFieldCallback callback, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const; + + absl::Status Qualify( + absl::Span qualifiers, bool presence_test, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result, + int* absl_nonnull count) const; + using StructValueMixin::Qualify; + + const CustomStructValueDispatcher* absl_nullable dispatcher() const { + return dispatcher_; + } + + CustomStructValueContent content() const { + ABSL_DCHECK(dispatcher_ != nullptr); + return content_; + } + + const CustomStructValueInterface* absl_nullable interface() const { + if (dispatcher_ == nullptr) { + return content_.To().interface; + } + return nullptr; + } + + explicit operator bool() const { + if (dispatcher_ == nullptr) { + return content_.To().interface != + nullptr; + } + return true; + } + + friend void swap(CustomStructValue& lhs, CustomStructValue& rhs) noexcept { + using std::swap; + swap(lhs.dispatcher_, rhs.dispatcher_); + swap(lhs.content_, rhs.content_); + } + + private: + friend class common_internal::ValueMixin; + friend class common_internal::StructValueMixin; + friend CustomStructValue UnsafeCustomStructValue( + const CustomStructValueDispatcher* absl_nonnull dispatcher + ABSL_ATTRIBUTE_LIFETIME_BOUND, + CustomStructValueContent content); + + // Constructs a custom struct value from a dispatcher and content. Only + // accessible from `UnsafeCustomStructValue`. + CustomStructValue(const CustomStructValueDispatcher* absl_nonnull dispatcher + ABSL_ATTRIBUTE_LIFETIME_BOUND, + CustomStructValueContent content) + : dispatcher_(dispatcher), content_(content) { + ABSL_DCHECK(dispatcher != nullptr); + ABSL_DCHECK(dispatcher->get_type_id != nullptr); + ABSL_DCHECK(dispatcher->get_arena != nullptr); + ABSL_DCHECK(dispatcher->get_type_name != nullptr); + ABSL_DCHECK(dispatcher->is_zero_value != nullptr); + ABSL_DCHECK(dispatcher->get_field_by_name != nullptr); + ABSL_DCHECK(dispatcher->has_field_by_name != nullptr); + ABSL_DCHECK(dispatcher->for_each_field != nullptr); + ABSL_DCHECK(dispatcher->clone != nullptr); + } + + const CustomStructValueDispatcher* absl_nullable dispatcher_ = nullptr; + CustomStructValueContent content_ = CustomStructValueContent::Zero(); +}; + +inline std::ostream& operator<<(std::ostream& out, + const CustomStructValue& value) { + return out << value.DebugString(); +} + +template <> +struct NativeTypeTraits final { + static NativeTypeId Id(const CustomStructValue& type) { + return type.GetTypeId(); + } +}; + +inline CustomStructValue UnsafeCustomStructValue( + const CustomStructValueDispatcher* absl_nonnull dispatcher + ABSL_ATTRIBUTE_LIFETIME_BOUND, + CustomStructValueContent content) { + return CustomStructValue(dispatcher, content); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_PARSED_STRUCT_VALUE_H_ diff --git a/common/values/custom_struct_value_test.cc b/common/values/custom_struct_value_test.cc new file mode 100644 index 000000000..32d867a4d --- /dev/null +++ b/common/values/custom_struct_value_test.cc @@ -0,0 +1,615 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "google/protobuf/struct.pb.h" +#include "absl/base/nullability.h" +#include "absl/functional/function_ref.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "base/attribute.h" +#include "common/memory.h" +#include "common/native_type.h" +#include "common/type.h" +#include "common/value.h" +#include "common/value_testing.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/zero_copy_stream_impl_lite.h" +#include "google/protobuf/message.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::IsOkAndHolds; +using ::absl_testing::StatusIs; +using ::cel::test::BoolValueIs; +using ::cel::test::IntValueIs; +using ::testing::IsEmpty; +using ::testing::IsNull; +using ::testing::Not; +using ::testing::NotNull; +using ::testing::Pair; +using ::testing::UnorderedElementsAre; + +class CustomStructValueTest; + +struct CustomStructValueTestContent { + google::protobuf::Arena* absl_nonnull arena; +}; + +class CustomStructValueInterfaceTest final : public CustomStructValueInterface { + public: + absl::string_view GetTypeName() const override { return "test.Interface"; } + + std::string DebugString() const override { + return std::string(GetTypeName()); + } + + bool IsZeroValue() const override { return false; } + + absl::Status SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const override { + google::protobuf::Value json; + google::protobuf::Struct* json_object = json.mutable_struct_value(); + (*json_object->mutable_fields())["foo"].set_bool_value(true); + (*json_object->mutable_fields())["bar"].set_number_value(1.0); + if (!json.SerializePartialToZeroCopyStream(output)) { + return absl::UnknownError( + "failed to serialize message: google.protobuf.Value"); + } + return absl::OkStatus(); + } + + absl::Status ConvertToJsonObject( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const override { + google::protobuf::Struct json_object; + (*json_object.mutable_fields())["foo"].set_bool_value(true); + (*json_object.mutable_fields())["bar"].set_number_value(1.0); + absl::Cord serialized; + if (!json_object.SerializePartialToString(&serialized)) { + return absl::UnknownError("failed to serialize google.protobuf.Struct"); + } + if (!json->ParsePartialFromString(serialized)) { + return absl::UnknownError("failed to parse google.protobuf.Struct"); + } + return absl::OkStatus(); + } + + absl::Status GetFieldByName( + absl::string_view name, ProtoWrapperTypeOptions unboxing_options, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const override { + if (name == "foo") { + *result = TrueValue(); + return absl::OkStatus(); + } + if (name == "bar") { + *result = IntValue(1); + return absl::OkStatus(); + } + return NoSuchFieldError(name).ToStatus(); + } + + absl::Status GetFieldByNumber( + int64_t number, ProtoWrapperTypeOptions unboxing_options, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const override { + if (number == 1) { + *result = TrueValue(); + return absl::OkStatus(); + } + if (number == 2) { + *result = IntValue(1); + return absl::OkStatus(); + } + return NoSuchFieldError(absl::StrCat(number)).ToStatus(); + } + + absl::StatusOr HasFieldByName(absl::string_view name) const override { + if (name == "foo") { + return true; + } + if (name == "bar") { + return true; + } + return NoSuchFieldError(name).ToStatus(); + } + + absl::StatusOr HasFieldByNumber(int64_t number) const override { + if (number == 1) { + return true; + } + if (number == 2) { + return true; + } + return NoSuchFieldError(absl::StrCat(number)).ToStatus(); + } + + absl::Status ForEachField( + ForEachFieldCallback callback, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const override { + CEL_ASSIGN_OR_RETURN(bool ok, callback("foo", TrueValue())); + if (!ok) { + return absl::OkStatus(); + } + CEL_ASSIGN_OR_RETURN(ok, callback("bar", IntValue(1))); + return absl::OkStatus(); + } + + CustomStructValue Clone(google::protobuf::Arena* absl_nonnull arena) const override { + return CustomStructValue( + (::new (arena->AllocateAligned(sizeof(CustomStructValueInterfaceTest), + alignof(CustomStructValueInterfaceTest))) + CustomStructValueInterfaceTest()), + arena); + } + + private: + NativeTypeId GetNativeTypeId() const override { + return NativeTypeId::For(); + } +}; + +class CustomStructValueTest : public common_internal::ValueTest<> { + public: + CustomStructValue MakeInterface() { + return CustomStructValue((::new (arena()->AllocateAligned( + sizeof(CustomStructValueInterfaceTest), + alignof(CustomStructValueInterfaceTest))) + CustomStructValueInterfaceTest()), + arena()); + } + + CustomStructValue MakeDispatcher() { + return UnsafeCustomStructValue( + &test_dispatcher_, + CustomValueContent::From( + CustomStructValueTestContent{.arena = arena()})); + } + + protected: + CustomStructValueDispatcher test_dispatcher_ = { + .get_type_id = + [](const CustomStructValueDispatcher* absl_nonnull dispatcher, + CustomStructValueContent content) -> NativeTypeId { + return NativeTypeId::For(); + }, + .get_arena = + [](const CustomStructValueDispatcher* absl_nonnull dispatcher, + CustomStructValueContent content) -> google::protobuf::Arena* absl_nullable { + return content.To().arena; + }, + .get_type_name = + [](const CustomStructValueDispatcher* absl_nonnull dispatcher, + CustomStructValueContent content) -> absl::string_view { + return "test.Dispatcher"; + }, + .debug_string = + [](const CustomStructValueDispatcher* absl_nonnull dispatcher, + CustomStructValueContent content) -> std::string { + return "test.Dispatcher"; + }, + .get_runtime_type = + [](const CustomStructValueDispatcher* absl_nonnull dispatcher, + CustomStructValueContent content) -> StructType { + return common_internal::MakeBasicStructType("test.Dispatcher"); + }, + .serialize_to = + [](const CustomStructValueDispatcher* absl_nonnull dispatcher, + CustomStructValueContent content, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) + -> absl::Status { + google::protobuf::Value json; + google::protobuf::Struct* json_object = json.mutable_struct_value(); + (*json_object->mutable_fields())["foo"].set_bool_value(true); + (*json_object->mutable_fields())["bar"].set_number_value(1.0); + if (!json.SerializePartialToZeroCopyStream(output)) { + return absl::UnknownError( + "failed to serialize message: google.protobuf.Value"); + } + return absl::OkStatus(); + }, + .convert_to_json_object = + [](const CustomStructValueDispatcher* absl_nonnull dispatcher, + CustomStructValueContent content, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) -> absl::Status { + google::protobuf::Struct json_object; + (*json_object.mutable_fields())["foo"].set_bool_value(true); + (*json_object.mutable_fields())["bar"].set_number_value(1.0); + absl::Cord serialized; + if (!json_object.SerializePartialToString(&serialized)) { + return absl::UnknownError( + "failed to serialize google.protobuf.Struct"); + } + if (!json->ParsePartialFromString(serialized)) { + return absl::UnknownError("failed to parse google.protobuf.Struct"); + } + return absl::OkStatus(); + }, + .is_zero_value = + [](const CustomStructValueDispatcher* absl_nonnull dispatcher, + CustomStructValueContent content) -> bool { return false; }, + .get_field_by_name = + [](const CustomStructValueDispatcher* absl_nonnull dispatcher, + CustomStructValueContent content, absl::string_view name, + ProtoWrapperTypeOptions unboxing_options, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) -> absl::Status { + if (name == "foo") { + *result = TrueValue(); + return absl::OkStatus(); + } + if (name == "bar") { + *result = IntValue(1); + return absl::OkStatus(); + } + return NoSuchFieldError(name).ToStatus(); + }, + .get_field_by_number = + [](const CustomStructValueDispatcher* absl_nonnull dispatcher, + CustomStructValueContent content, int64_t number, + ProtoWrapperTypeOptions unboxing_options, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) -> absl::Status { + if (number == 1) { + *result = TrueValue(); + return absl::OkStatus(); + } + if (number == 2) { + *result = IntValue(1); + return absl::OkStatus(); + } + return NoSuchFieldError(absl::StrCat(number)).ToStatus(); + }, + .has_field_by_name = + [](const CustomStructValueDispatcher* absl_nonnull dispatcher, + CustomStructValueContent content, + absl::string_view name) -> absl::StatusOr { + if (name == "foo") { + return true; + } + if (name == "bar") { + return true; + } + return NoSuchFieldError(name).ToStatus(); + }, + .has_field_by_number = + [](const CustomStructValueDispatcher* absl_nonnull dispatcher, + CustomStructValueContent content, + int64_t number) -> absl::StatusOr { + if (number == 1) { + return true; + } + if (number == 2) { + return true; + } + return NoSuchFieldError(absl::StrCat(number)).ToStatus(); + }, + .for_each_field = + [](const CustomStructValueDispatcher* absl_nonnull dispatcher, + CustomStructValueContent content, + absl::FunctionRef(absl::string_view, + const Value&)> + callback, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) -> absl::Status { + CEL_ASSIGN_OR_RETURN(bool ok, callback("foo", TrueValue())); + if (!ok) { + return absl::OkStatus(); + } + CEL_ASSIGN_OR_RETURN(ok, callback("bar", IntValue(1))); + return absl::OkStatus(); + }, + .clone = [](const CustomStructValueDispatcher* absl_nonnull dispatcher, + CustomStructValueContent content, + google::protobuf::Arena* absl_nonnull arena) -> CustomStructValue { + return UnsafeCustomStructValue( + dispatcher, CustomValueContent::From( + CustomStructValueTestContent{.arena = arena})); + }, + }; +}; + +TEST_F(CustomStructValueTest, Kind) { + EXPECT_EQ(CustomStructValue::kind(), CustomStructValue::kKind); +} + +TEST_F(CustomStructValueTest, Dispatcher_GetTypeId) { + EXPECT_EQ(MakeDispatcher().GetTypeId(), + NativeTypeId::For()); +} + +TEST_F(CustomStructValueTest, Interface_GetTypeId) { + EXPECT_EQ(MakeInterface().GetTypeId(), + NativeTypeId::For()); +} + +TEST_F(CustomStructValueTest, Dispatcher_GetTypeName) { + EXPECT_EQ(MakeDispatcher().GetTypeName(), "test.Dispatcher"); +} + +TEST_F(CustomStructValueTest, Interface_GetTypeName) { + EXPECT_EQ(MakeInterface().GetTypeName(), "test.Interface"); +} + +TEST_F(CustomStructValueTest, Dispatcher_DebugString) { + EXPECT_EQ(MakeDispatcher().DebugString(), "test.Dispatcher"); +} + +TEST_F(CustomStructValueTest, Interface_DebugString) { + EXPECT_EQ(MakeInterface().DebugString(), "test.Interface"); +} + +TEST_F(CustomStructValueTest, Dispatcher_GetRuntimeType) { + EXPECT_EQ(MakeDispatcher().GetRuntimeType(), + common_internal::MakeBasicStructType("test.Dispatcher")); +} + +TEST_F(CustomStructValueTest, Interface_GetRuntimeType) { + EXPECT_EQ(MakeInterface().GetRuntimeType(), + common_internal::MakeBasicStructType("test.Interface")); +} + +TEST_F(CustomStructValueTest, Dispatcher_IsZeroValue) { + EXPECT_FALSE(MakeDispatcher().IsZeroValue()); +} + +TEST_F(CustomStructValueTest, Interface_IsZeroValue) { + EXPECT_FALSE(MakeInterface().IsZeroValue()); +} + +TEST_F(CustomStructValueTest, Dispatcher_SerializeTo) { + google::protobuf::io::CordOutputStream output; + EXPECT_THAT(MakeDispatcher().SerializeTo(descriptor_pool(), message_factory(), + &output), + IsOk()); + EXPECT_THAT(std::move(output).Consume(), Not(IsEmpty())); +} + +TEST_F(CustomStructValueTest, Interface_SerializeTo) { + google::protobuf::io::CordOutputStream output; + EXPECT_THAT(MakeInterface().SerializeTo(descriptor_pool(), message_factory(), + &output), + IsOk()); + EXPECT_THAT(std::move(output).Consume(), Not(IsEmpty())); +} + +TEST_F(CustomStructValueTest, Dispatcher_ConvertToJson) { + auto message = DynamicParseTextProto(); + EXPECT_THAT( + MakeDispatcher().ConvertToJson(descriptor_pool(), message_factory(), + cel::to_address(message)), + IsOk()); + EXPECT_THAT(*message, EqualsTextProto(R"pb( + struct_value: { + fields: { + key: "foo" + value: { bool_value: true } + } + fields: { + key: "bar" + value: { number_value: 1.0 } + } + } + )pb")); +} + +TEST_F(CustomStructValueTest, Interface_ConvertToJson) { + auto message = DynamicParseTextProto(); + EXPECT_THAT( + MakeInterface().ConvertToJson(descriptor_pool(), message_factory(), + cel::to_address(message)), + IsOk()); + EXPECT_THAT(*message, EqualsTextProto(R"pb( + struct_value: { + fields: { + key: "foo" + value: { bool_value: true } + } + fields: { + key: "bar" + value: { number_value: 1.0 } + } + } + )pb")); +} + +TEST_F(CustomStructValueTest, Dispatcher_ConvertToJsonObject) { + auto message = DynamicParseTextProto(); + EXPECT_THAT( + MakeDispatcher().ConvertToJsonObject(descriptor_pool(), message_factory(), + cel::to_address(message)), + IsOk()); + EXPECT_THAT(*message, EqualsTextProto(R"pb( + fields: { + key: "foo" + value: { bool_value: true } + } + fields: { + key: "bar" + value: { number_value: 1.0 } + } + )pb")); +} + +TEST_F(CustomStructValueTest, Interface_ConvertToJsonObject) { + auto message = DynamicParseTextProto(); + EXPECT_THAT( + MakeInterface().ConvertToJsonObject(descriptor_pool(), message_factory(), + cel::to_address(message)), + IsOk()); + EXPECT_THAT(*message, EqualsTextProto(R"pb( + fields: { + key: "foo" + value: { bool_value: true } + } + fields: { + key: "bar" + value: { number_value: 1.0 } + } + )pb")); +} + +TEST_F(CustomStructValueTest, Dispatcher_GetFieldByName) { + EXPECT_THAT(MakeDispatcher().GetFieldByName("foo", descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(BoolValueIs(true))); + EXPECT_THAT(MakeDispatcher().GetFieldByName("bar", descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(IntValueIs(1))); +} + +TEST_F(CustomStructValueTest, Interface_GetFieldByName) { + EXPECT_THAT(MakeInterface().GetFieldByName("foo", descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(BoolValueIs(true))); + EXPECT_THAT(MakeInterface().GetFieldByName("bar", descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(IntValueIs(1))); +} + +TEST_F(CustomStructValueTest, Dispatcher_GetFieldByNumber) { + EXPECT_THAT(MakeDispatcher().GetFieldByNumber(1, descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(BoolValueIs(true))); + EXPECT_THAT(MakeDispatcher().GetFieldByNumber(2, descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(IntValueIs(1))); +} + +TEST_F(CustomStructValueTest, Interface_GetFieldByNumber) { + EXPECT_THAT(MakeInterface().GetFieldByNumber(1, descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(BoolValueIs(true))); + EXPECT_THAT(MakeInterface().GetFieldByNumber(2, descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(IntValueIs(1))); +} + +TEST_F(CustomStructValueTest, Dispatcher_HasFieldByName) { + EXPECT_THAT(MakeDispatcher().HasFieldByName("foo"), IsOkAndHolds(true)); + EXPECT_THAT(MakeDispatcher().HasFieldByName("bar"), IsOkAndHolds(true)); +} + +TEST_F(CustomStructValueTest, Interface_HasFieldByName) { + EXPECT_THAT(MakeInterface().HasFieldByName("foo"), IsOkAndHolds(true)); + EXPECT_THAT(MakeInterface().HasFieldByName("bar"), IsOkAndHolds(true)); +} + +TEST_F(CustomStructValueTest, Dispatcher_HasFieldByNumber) { + EXPECT_THAT(MakeDispatcher().HasFieldByNumber(1), IsOkAndHolds(true)); + EXPECT_THAT(MakeDispatcher().HasFieldByNumber(2), IsOkAndHolds(true)); +} + +TEST_F(CustomStructValueTest, Interface_HasFieldByNumber) { + EXPECT_THAT(MakeInterface().HasFieldByNumber(1), IsOkAndHolds(true)); + EXPECT_THAT(MakeInterface().HasFieldByNumber(2), IsOkAndHolds(true)); +} + +TEST_F(CustomStructValueTest, Default_Bool) { + EXPECT_FALSE(CustomStructValue()); +} + +TEST_F(CustomStructValueTest, Dispatcher_Bool) { + EXPECT_TRUE(MakeDispatcher()); +} + +TEST_F(CustomStructValueTest, Interface_Bool) { EXPECT_TRUE(MakeInterface()); } + +TEST_F(CustomStructValueTest, Dispatcher_ForEachField) { + std::vector> fields; + EXPECT_THAT(MakeDispatcher().ForEachField( + [&](absl::string_view name, + const Value& value) -> absl::StatusOr { + fields.push_back(std::pair{std::string(name), value}); + return true; + }, + descriptor_pool(), message_factory(), arena()), + IsOk()); + EXPECT_THAT(fields, UnorderedElementsAre(Pair("foo", BoolValueIs(true)), + Pair("bar", IntValueIs(1)))); +} + +TEST_F(CustomStructValueTest, Interface_ForEachField) { + std::vector> fields; + EXPECT_THAT(MakeInterface().ForEachField( + [&](absl::string_view name, + const Value& value) -> absl::StatusOr { + fields.push_back(std::pair{std::string(name), value}); + return true; + }, + descriptor_pool(), message_factory(), arena()), + IsOk()); + EXPECT_THAT(fields, UnorderedElementsAre(Pair("foo", BoolValueIs(true)), + Pair("bar", IntValueIs(1)))); +} + +TEST_F(CustomStructValueTest, Dispatcher_Qualify) { + EXPECT_THAT( + MakeDispatcher().Qualify({AttributeQualifier::OfString("foo")}, false, + descriptor_pool(), message_factory(), arena()), + StatusIs(absl::StatusCode::kUnimplemented)); +} + +TEST_F(CustomStructValueTest, Interface_Qualify) { + EXPECT_THAT( + MakeInterface().Qualify({AttributeQualifier::OfString("foo")}, false, + descriptor_pool(), message_factory(), arena()), + StatusIs(absl::StatusCode::kUnimplemented)); +} + +TEST_F(CustomStructValueTest, Dispatcher) { + EXPECT_THAT(MakeDispatcher().dispatcher(), NotNull()); + EXPECT_THAT(MakeDispatcher().interface(), IsNull()); +} + +TEST_F(CustomStructValueTest, Interface) { + EXPECT_THAT(MakeInterface().dispatcher(), IsNull()); + EXPECT_THAT(MakeInterface().interface(), NotNull()); +} + +} // namespace +} // namespace cel diff --git a/common/values/custom_value.h b/common/values/custom_value.h new file mode 100644 index 000000000..b549fe774 --- /dev/null +++ b/common/values/custom_value.h @@ -0,0 +1,84 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/value.h" +// IWYU pragma: friend "common/value.h" + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_CUSTOM_VALUE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_CUSTOM_VALUE_H_ + +#include +#include +#include +#include + +namespace cel { + +// CustomValueContent is an opaque 16-byte trivially copyable value. The format +// of the data stored within is unknown to everything except the the caller +// which creates it. Do not try to interpret it otherwise. +class CustomValueContent final { + public: + static CustomValueContent Zero() { + CustomValueContent content; + std::memset(&content, 0, sizeof(content)); + return content; + } + + template + static CustomValueContent From(T value) { + static_assert(std::is_trivially_copyable_v, + "T must be trivially copyable"); + static_assert(sizeof(T) <= 16, "sizeof(T) must be no greater than 16"); + + CustomValueContent content; + std::memcpy(content.raw_, std::addressof(value), sizeof(T)); + return content; + } + + template + static CustomValueContent From(const T (&array)[N]) { + static_assert(std::is_trivially_copyable_v, + "T must be trivially copyable"); + static_assert((sizeof(T) * N) <= 16, + "sizeof(T[N]) must be no greater than 16"); + + CustomValueContent content; + std::memcpy(content.raw_, array, sizeof(T) * N); + return content; + } + + template + T To() const { + static_assert(std::is_trivially_copyable_v, + "T must be trivially copyable"); + static_assert(sizeof(T) <= 16, "sizeof(T) must be no greater than 16"); + + T value; + std::memcpy(std::addressof(value), raw_, sizeof(T)); + return value; + } + + bool IsZero() const { + static const CustomValueContent kZero = Zero(); + return std::memcmp(raw_, kZero.raw_, sizeof(raw_)) == 0; + } + + private: + alignas(void*) std::byte raw_[16]; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_CUSTOM_VALUE_H_ diff --git a/common/values/double_value.cc b/common/values/double_value.cc new file mode 100644 index 000000000..c2299a2bb --- /dev/null +++ b/common/values/double_value.cc @@ -0,0 +1,137 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "google/protobuf/wrappers.pb.h" +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "common/value.h" +#include "internal/number.h" +#include "internal/status_macros.h" +#include "internal/well_known_types.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/message.h" + +namespace cel { + +namespace { + +using ::cel::well_known_types::ValueReflection; + +std::string DoubleDebugString(double value) { + if (std::isfinite(value)) { + if (std::floor(value) != value) { + // The double is not representable as a whole number, so use + // absl::StrCat which will add decimal places. + return absl::StrCat(value); + } + // absl::StrCat historically would represent 0.0 as 0, and we want the + // decimal places so ZetaSQL correctly assumes the type as double + // instead of int64. + std::string stringified = absl::StrCat(value); + if (!absl::StrContains(stringified, '.')) { + absl::StrAppend(&stringified, ".0"); + } else { + // absl::StrCat has a decimal now? Use it directly. + } + return stringified; + } + if (std::isnan(value)) { + return "nan"; + } + if (std::signbit(value)) { + return "-infinity"; + } + return "+infinity"; +} + +} // namespace + +std::string DoubleValue::DebugString() const { + return DoubleDebugString(NativeValue()); +} + +absl::Status DoubleValue::SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(output != nullptr); + + google::protobuf::DoubleValue message; + message.set_value(NativeValue()); + if (!message.SerializePartialToZeroCopyStream(output)) { + return absl::UnknownError( + absl::StrCat("failed to serialize message: ", message.GetTypeName())); + } + + return absl::OkStatus(); +} + +absl::Status DoubleValue::ConvertToJson( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(json != nullptr); + ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE); + + ValueReflection value_reflection; + CEL_RETURN_IF_ERROR(value_reflection.Initialize(json->GetDescriptor())); + value_reflection.SetNumberValue(json, NativeValue()); + + return absl::OkStatus(); +} + +absl::Status DoubleValue::Equal( + const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + if (auto other_value = other.AsDouble(); other_value.has_value()) { + *result = BoolValue{NativeValue() == other_value->NativeValue()}; + return absl::OkStatus(); + } + if (auto other_value = other.AsInt(); other_value.has_value()) { + *result = + BoolValue{internal::Number::FromDouble(NativeValue()) == + internal::Number::FromInt64(other_value->NativeValue())}; + return absl::OkStatus(); + } + if (auto other_value = other.AsUint(); other_value.has_value()) { + *result = + BoolValue{internal::Number::FromDouble(NativeValue()) == + internal::Number::FromUint64(other_value->NativeValue())}; + return absl::OkStatus(); + } + *result = FalseValue(); + return absl::OkStatus(); +} + +} // namespace cel diff --git a/common/values/double_value.h b/common/values/double_value.h new file mode 100644 index 000000000..dc24aee20 --- /dev/null +++ b/common/values/double_value.h @@ -0,0 +1,101 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/value.h" +// IWYU pragma: friend "common/value.h" + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_DOUBLE_VALUE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_DOUBLE_VALUE_H_ + +#include +#include + +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "common/type.h" +#include "common/value_kind.h" +#include "common/values/values.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/message.h" + +namespace cel { + +class Value; +class DoubleValue; + +class DoubleValue final : private common_internal::ValueMixin { + public: + static constexpr ValueKind kKind = ValueKind::kDouble; + + explicit DoubleValue(double value) noexcept : value_(value) {} + + DoubleValue() = default; + DoubleValue(const DoubleValue&) = default; + DoubleValue(DoubleValue&&) = default; + DoubleValue& operator=(const DoubleValue&) = default; + DoubleValue& operator=(DoubleValue&&) = default; + + ValueKind kind() const { return kKind; } + + absl::string_view GetTypeName() const { return DoubleType::kName; } + + std::string DebugString() const; + + // See Value::SerializeTo(). + absl::Status SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const; + + // See Value::ConvertToJson(). + absl::Status ConvertToJson( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const; + + absl::Status Equal(const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const; + using ValueMixin::Equal; + + bool IsZeroValue() const { return NativeValue() == 0.0; } + + double NativeValue() const { return static_cast(*this); } + + // NOLINTNEXTLINE(google-explicit-constructor) + operator double() const noexcept { return value_; } + + friend void swap(DoubleValue& lhs, DoubleValue& rhs) noexcept { + using std::swap; + swap(lhs.value_, rhs.value_); + } + + private: + friend class common_internal::ValueMixin; + + double value_ = 0.0; +}; + +inline std::ostream& operator<<(std::ostream& out, DoubleValue value) { + return out << value.DebugString(); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_DOUBLE_VALUE_H_ diff --git a/common/values/double_value_test.cc b/common/values/double_value_test.cc new file mode 100644 index 000000000..fc33a941b --- /dev/null +++ b/common/values/double_value_test.cc @@ -0,0 +1,96 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "absl/status/status_matchers.h" +#include "common/native_type.h" +#include "common/value.h" +#include "common/value_testing.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; + +using DoubleValueTest = common_internal::ValueTest<>; + +TEST_F(DoubleValueTest, Kind) { + EXPECT_EQ(DoubleValue(1.0).kind(), DoubleValue::kKind); + EXPECT_EQ(Value(DoubleValue(1.0)).kind(), DoubleValue::kKind); +} + +TEST_F(DoubleValueTest, DebugString) { + { + std::ostringstream out; + out << DoubleValue(0.0); + EXPECT_EQ(out.str(), "0.0"); + } + { + std::ostringstream out; + out << DoubleValue(1.0); + EXPECT_EQ(out.str(), "1.0"); + } + { + std::ostringstream out; + out << DoubleValue(1.1); + EXPECT_EQ(out.str(), "1.1"); + } + { + std::ostringstream out; + out << DoubleValue(NAN); + EXPECT_EQ(out.str(), "nan"); + } + { + std::ostringstream out; + out << DoubleValue(INFINITY); + EXPECT_EQ(out.str(), "+infinity"); + } + { + std::ostringstream out; + out << DoubleValue(-INFINITY); + EXPECT_EQ(out.str(), "-infinity"); + } + { + std::ostringstream out; + out << Value(DoubleValue(0.0)); + EXPECT_EQ(out.str(), "0.0"); + } +} + +TEST_F(DoubleValueTest, ConvertToJson) { + auto* message = NewArenaValueMessage(); + EXPECT_THAT(DoubleValue(1.0).ConvertToJson(descriptor_pool(), + message_factory(), message), + IsOk()); + EXPECT_THAT(*message, EqualsValueTextProto(R"pb(number_value: 1)pb")); +} + +TEST_F(DoubleValueTest, NativeTypeId) { + EXPECT_EQ(NativeTypeId::Of(DoubleValue(1.0)), + NativeTypeId::For()); + EXPECT_EQ(NativeTypeId::Of(Value(DoubleValue(1.0))), + NativeTypeId::For()); +} + +TEST_F(DoubleValueTest, Equality) { + EXPECT_NE(DoubleValue(0.0), 1.0); + EXPECT_NE(1.0, DoubleValue(0.0)); + EXPECT_NE(DoubleValue(0.0), DoubleValue(1.0)); +} + +} // namespace +} // namespace cel diff --git a/common/values/duration_value.cc b/common/values/duration_value.cc new file mode 100644 index 000000000..a3b41e8ea --- /dev/null +++ b/common/values/duration_value.cc @@ -0,0 +1,103 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "google/protobuf/duration.pb.h" +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/time/time.h" +#include "common/value.h" +#include "internal/status_macros.h" +#include "internal/time.h" +#include "internal/well_known_types.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/message.h" + +namespace cel { + +namespace { + +using ::cel::well_known_types::DurationReflection; +using ::cel::well_known_types::ValueReflection; + +std::string DurationDebugString(absl::Duration value) { + return internal::DebugStringDuration(value); +} + +} // namespace + +std::string DurationValue::DebugString() const { + return DurationDebugString(NativeValue()); +} + +absl::Status DurationValue::SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(output != nullptr); + + google::protobuf::Duration message; + CEL_RETURN_IF_ERROR( + DurationReflection::SetFromAbslDuration(&message, NativeValue())); + if (!message.SerializePartialToZeroCopyStream(output)) { + return absl::UnknownError( + absl::StrCat("failed to serialize message: ", message.GetTypeName())); + } + + return absl::OkStatus(); +} + +absl::Status DurationValue::ConvertToJson( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(json != nullptr); + ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE); + + ValueReflection value_reflection; + CEL_RETURN_IF_ERROR(value_reflection.Initialize(json->GetDescriptor())); + value_reflection.SetStringValueFromDuration(json, NativeValue()); + + return absl::OkStatus(); +} + +absl::Status DurationValue::Equal( + const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + if (auto other_value = other.AsDuration(); other_value.has_value()) { + *result = BoolValue{NativeValue() == other_value->NativeValue()}; + return absl::OkStatus(); + } + *result = FalseValue(); + return absl::OkStatus(); +} + +} // namespace cel diff --git a/common/values/duration_value.h b/common/values/duration_value.h new file mode 100644 index 000000000..1b2468b60 --- /dev/null +++ b/common/values/duration_value.h @@ -0,0 +1,147 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/value.h" +// IWYU pragma: friend "common/value.h" + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_DURATION_VALUE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_DURATION_VALUE_H_ + +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "absl/utility/utility.h" +#include "common/type.h" +#include "common/value_kind.h" +#include "common/values/values.h" +#include "internal/time.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/message.h" + +namespace cel { + +class Value; +class DurationValue; + +DurationValue UnsafeDurationValue(absl::Duration value); +absl::StatusOr SafeDurationValue(absl::Duration value); + +// `DurationValue` represents values of the primitive `duration` type. +class DurationValue final : private common_internal::ValueMixin { + public: + static constexpr ValueKind kKind = ValueKind::kDuration; + + explicit DurationValue(absl::Duration value) noexcept + : DurationValue(absl::in_place, value) { + ABSL_DCHECK_OK(internal::ValidateDuration(value)); + } + + DurationValue() = default; + DurationValue(const DurationValue&) = default; + DurationValue(DurationValue&&) = default; + DurationValue& operator=(const DurationValue&) = default; + DurationValue& operator=(DurationValue&&) = default; + + ValueKind kind() const { return kKind; } + + absl::string_view GetTypeName() const { return DurationType::kName; } + + std::string DebugString() const; + + // See Value::SerializeTo(). + absl::Status SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const; + + // See Value::ConvertToJson(). + absl::Status ConvertToJson( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const; + + absl::Status Equal(const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const; + using ValueMixin::Equal; + + bool IsZeroValue() const { return ToDuration() == absl::ZeroDuration(); } + + ABSL_DEPRECATED("Use ToDuration()") + absl::Duration NativeValue() const { + return static_cast(*this); + } + + ABSL_DEPRECATED("Use ToDuration()") + // NOLINTNEXTLINE(google-explicit-constructor) + operator absl::Duration() const noexcept { return value_; } + + absl::Duration ToDuration() const { return value_; } + + friend void swap(DurationValue& lhs, DurationValue& rhs) noexcept { + using std::swap; + swap(lhs.value_, rhs.value_); + } + + friend bool operator==(DurationValue lhs, DurationValue rhs) { + return lhs.value_ == rhs.value_; + } + + friend bool operator<(const DurationValue& lhs, const DurationValue& rhs) { + return lhs.value_ < rhs.value_; + } + + private: + friend class common_internal::ValueMixin; + friend DurationValue UnsafeDurationValue(absl::Duration value); + + DurationValue(absl::in_place_t, absl::Duration value) : value_(value) {} + + absl::Duration value_ = absl::ZeroDuration(); +}; + +inline DurationValue UnsafeDurationValue(absl::Duration value) { + return DurationValue(absl::in_place, value); +} + +inline absl::StatusOr SafeDurationValue(absl::Duration value) { + absl::Status status = internal::ValidateDuration(value); + if (!status.ok()) { + return status; + } + return UnsafeDurationValue(value); +} + +inline bool operator!=(DurationValue lhs, DurationValue rhs) { + return !operator==(lhs, rhs); +} + +inline std::ostream& operator<<(std::ostream& out, DurationValue value) { + return out << value.DebugString(); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_DURATION_VALUE_H_ diff --git a/common/values/duration_value_test.cc b/common/values/duration_value_test.cc new file mode 100644 index 000000000..29d9b0f9e --- /dev/null +++ b/common/values/duration_value_test.cc @@ -0,0 +1,92 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "absl/status/status_matchers.h" +#include "absl/time/time.h" +#include "common/native_type.h" +#include "common/value.h" +#include "common/value_testing.h" +#include "internal/testing.h" +#include "google/protobuf/io/zero_copy_stream_impl_lite.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::testing::IsEmpty; + +using DurationValueTest = common_internal::ValueTest<>; + +TEST_F(DurationValueTest, Kind) { + EXPECT_EQ(DurationValue().kind(), DurationValue::kKind); + EXPECT_EQ(Value(DurationValue(absl::Seconds(1))).kind(), + DurationValue::kKind); +} + +TEST_F(DurationValueTest, DebugString) { + { + std::ostringstream out; + out << DurationValue(absl::Seconds(1)); + EXPECT_EQ(out.str(), "1s"); + } + { + std::ostringstream out; + out << Value(DurationValue(absl::Seconds(1))); + EXPECT_EQ(out.str(), "1s"); + } +} + +TEST_F(DurationValueTest, SerializeTo) { + google::protobuf::io::CordOutputStream output; + EXPECT_THAT(DurationValue().SerializeTo(descriptor_pool(), message_factory(), + &output), + IsOk()); + EXPECT_THAT(std::move(output).Consume(), IsEmpty()); +} + +TEST_F(DurationValueTest, ConvertToJson) { + auto* message = NewArenaValueMessage(); + EXPECT_THAT(DurationValue().ConvertToJson(descriptor_pool(), + message_factory(), message), + IsOk()); + EXPECT_THAT(*message, EqualsValueTextProto(R"pb(string_value: "0s")pb")); +} + +TEST_F(DurationValueTest, NativeTypeId) { + EXPECT_EQ(NativeTypeId::Of(DurationValue(absl::Seconds(1))), + NativeTypeId::For()); + EXPECT_EQ(NativeTypeId::Of(Value(DurationValue(absl::Seconds(1)))), + NativeTypeId::For()); +} + +TEST_F(DurationValueTest, Equality) { + EXPECT_NE(DurationValue(absl::ZeroDuration()), absl::Seconds(1)); + EXPECT_NE(absl::Seconds(1), DurationValue(absl::ZeroDuration())); + EXPECT_NE(DurationValue(absl::ZeroDuration()), + DurationValue(absl::Seconds(1))); +} + +TEST_F(DurationValueTest, Comparison) { + EXPECT_LT(DurationValue(absl::ZeroDuration()), absl::Seconds(1)); + EXPECT_FALSE(DurationValue(absl::Seconds(1)) < + DurationValue(absl::Seconds(1))); + EXPECT_FALSE(DurationValue(absl::Seconds(2)) < + DurationValue(absl::Seconds(1))); +} + +} // namespace +} // namespace cel diff --git a/common/values/enum_value.h b/common/values/enum_value.h new file mode 100644 index 000000000..71f437e62 --- /dev/null +++ b/common/values/enum_value.h @@ -0,0 +1,49 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/value.h" +// IWYU pragma: friend "common/value.h" + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_ENUM_VALUE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_ENUM_VALUE_H_ + +#include + +#include "google/protobuf/struct.pb.h" +#include "absl/meta/type_traits.h" +#include "google/protobuf/generated_enum_util.h" + +namespace cel::common_internal { + +template > +inline constexpr bool kIsWellKnownEnumType = + std::is_same::value; + +template > +inline constexpr bool kIsGeneratedEnum = google::protobuf::is_proto_enum::value; + +template +using EnableIfWellKnownEnum = std::enable_if_t< + kIsWellKnownEnumType && std::is_same, U>::value, R>; + +template +using EnableIfGeneratedEnum = std::enable_if_t< + absl::conjunction< + std::bool_constant>, + absl::negation>>>::value, + R>; + +} // namespace cel::common_internal + +#endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_ENUM_VALUE_H_ diff --git a/common/values/error_value.cc b/common/values/error_value.cc new file mode 100644 index 000000000..8ea6554ec --- /dev/null +++ b/common/values/error_value.cc @@ -0,0 +1,194 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "absl/base/no_destructor.h" +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "common/type.h" +#include "common/value.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/message.h" + +namespace cel { + +namespace { + +std::string ErrorDebugString(const absl::Status& value) { + ABSL_DCHECK(!value.ok()) << "use of moved-from ErrorValue"; + return value.ToString(absl::StatusToStringMode::kWithEverything); +} + +const absl::Status& DefaultErrorValue() { + static const absl::NoDestructor value( + absl::UnknownError("unknown error")); + return *value; +} + +} // namespace + +ErrorValue::ErrorValue() : ErrorValue(DefaultErrorValue()) {} + +ErrorValue NoSuchFieldError(absl::string_view field) { + return ErrorValue(absl::NotFoundError( + absl::StrCat("no_such_field", field.empty() ? "" : " : ", field))); +} + +ErrorValue NoSuchKeyError(absl::string_view key) { + return ErrorValue( + absl::NotFoundError(absl::StrCat("Key not found in map : ", key))); +} + +ErrorValue NoSuchTypeError(absl::string_view type) { + return ErrorValue( + absl::NotFoundError(absl::StrCat("type not found: ", type))); +} + +ErrorValue DuplicateKeyError() { + return ErrorValue(absl::AlreadyExistsError("duplicate key in map")); +} + +ErrorValue TypeConversionError(absl::string_view from, absl::string_view to) { + return ErrorValue(absl::InvalidArgumentError( + absl::StrCat("type conversion error from '", from, "' to '", to, "'"))); +} + +ErrorValue TypeConversionError(const Type& from, const Type& to) { + return TypeConversionError(from.DebugString(), to.DebugString()); +} + +ErrorValue IndexOutOfBoundsError(size_t index) { + return ErrorValue( + absl::InvalidArgumentError(absl::StrCat("index out of bounds: ", index))); +} + +ErrorValue IndexOutOfBoundsError(ptrdiff_t index) { + return ErrorValue( + absl::InvalidArgumentError(absl::StrCat("index out of bounds: ", index))); +} + +bool IsNoSuchField(const ErrorValue& value) { + return absl::IsNotFound(value.NativeValue()) && + absl::StartsWith(value.NativeValue().message(), "no_such_field"); +} + +bool IsNoSuchKey(const ErrorValue& value) { + return absl::IsNotFound(value.NativeValue()) && + absl::StartsWith(value.NativeValue().message(), + "Key not found in map"); +} + +std::string ErrorValue::DebugString() const { + return ErrorDebugString(NativeValue()); +} + +absl::Status ErrorValue::SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(output != nullptr); + ABSL_DCHECK(*this); + + return absl::FailedPreconditionError( + absl::StrCat(GetTypeName(), " is unserializable")); +} + +absl::Status ErrorValue::ConvertToJson( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(json != nullptr); + ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE); + ABSL_DCHECK(*this); + + return absl::FailedPreconditionError( + absl::StrCat(GetTypeName(), " is not convertable to JSON")); +} + +absl::Status ErrorValue::Equal( + const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + ABSL_DCHECK(*this); + + *result = FalseValue(); + return absl::OkStatus(); +} + +ErrorValue ErrorValue::Clone(google::protobuf::Arena* absl_nonnull arena) const { + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(*this); + + if (arena_ == nullptr || arena_ != arena) { + return ErrorValue(arena, + google::protobuf::Arena::Create(arena, ToStatus())); + } + return *this; +} + +absl::Status ErrorValue::ToStatus() const& { + ABSL_DCHECK(*this); + + if (arena_ == nullptr) { + return *std::launder( + reinterpret_cast(&status_.val[0])); + } + return *status_.ptr; +} + +absl::Status ErrorValue::ToStatus() && { + ABSL_DCHECK(*this); + + if (arena_ == nullptr) { + return std::move( + *std::launder(reinterpret_cast(&status_.val[0]))); + } + return *status_.ptr; +} + +ErrorValue::operator bool() const { + if (arena_ == nullptr) { + return !std::launder(reinterpret_cast(&status_.val[0])) + ->ok(); + } + return status_.ptr != nullptr && !status_.ptr->ok(); +} + +void swap(ErrorValue& lhs, ErrorValue& rhs) noexcept { + ErrorValue tmp(std::move(lhs)); + lhs = std::move(rhs); + rhs = std::move(tmp); +} + +} // namespace cel diff --git a/common/values/error_value.h b/common/values/error_value.h new file mode 100644 index 000000000..4e24c866b --- /dev/null +++ b/common/values/error_value.h @@ -0,0 +1,276 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/value.h" +// IWYU pragma: friend "common/value.h" + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_ERROR_VALUE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_ERROR_VALUE_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/string_view.h" +#include "common/arena.h" +#include "common/type.h" +#include "common/value_kind.h" +#include "common/values/values.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/message.h" + +namespace cel { + +class Value; + +// `ErrorValue` represents values of the `ErrorType`. +class ABSL_ATTRIBUTE_TRIVIAL_ABI ErrorValue final + : private common_internal::ValueMixin { + public: + static constexpr ValueKind kKind = ValueKind::kError; + + explicit ErrorValue(absl::Status value) : arena_(nullptr) { + ::new (static_cast(&status_.val[0])) absl::Status(std::move(value)); + ABSL_DCHECK(*this) << "ErrorValue requires a non-OK absl::Status"; + } + + // By default, this creates an UNKNOWN error. You should always create a more + // specific error value. + ErrorValue(); + + ErrorValue(const ErrorValue& other) { CopyConstruct(other); } + + ErrorValue(ErrorValue&& other) noexcept { MoveConstruct(other); } + + ~ErrorValue() { Destruct(); } + + ErrorValue& operator=(const ErrorValue& other) { + if (this != &other) { + Destruct(); + CopyConstruct(other); + } + return *this; + } + + ErrorValue& operator=(ErrorValue&& other) noexcept { + if (this != &other) { + Destruct(); + MoveConstruct(other); + } + return *this; + } + + static constexpr ValueKind kind() { return kKind; } + + static absl::string_view GetTypeName() { return ErrorType::kName; } + + std::string DebugString() const; + + // See Value::SerializeTo(). + absl::Status SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const; + + // See Value::ConvertToJson(). + absl::Status ConvertToJson( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const; + + absl::Status Equal(const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const; + using ValueMixin::Equal; + + bool IsZeroValue() const { return false; } + + ErrorValue Clone(google::protobuf::Arena* absl_nonnull arena) const; + + absl::Status ToStatus() const&; + + absl::Status ToStatus() &&; + + ABSL_DEPRECATED("Use ToStatus()") + absl::Status NativeValue() const& { return ToStatus(); } + + ABSL_DEPRECATED("Use ToStatus()") + absl::Status NativeValue() && { return std::move(*this).ToStatus(); } + + friend void swap(ErrorValue& lhs, ErrorValue& rhs) noexcept; + + explicit operator bool() const; + + private: + friend class common_internal::ValueMixin; + friend struct ArenaTraits; + + ErrorValue(google::protobuf::Arena* absl_nonnull arena, + const absl::Status* absl_nonnull status) + : arena_(arena) { + status_.ptr = status; + } + + void CopyConstruct(const ErrorValue& other) { + arena_ = other.arena_; + if (arena_ == nullptr) { + ::new (static_cast(&status_.val[0])) absl::Status(*std::launder( + reinterpret_cast(&other.status_.val[0]))); + } else { + status_.ptr = other.status_.ptr; + } + } + + void MoveConstruct(ErrorValue& other) { + arena_ = other.arena_; + if (arena_ == nullptr) { + ::new (static_cast(&status_.val[0])) + absl::Status(std::move(*std::launder( + reinterpret_cast(&other.status_.val[0])))); + } else { + status_.ptr = other.status_.ptr; + } + } + + void Destruct() { + if (arena_ == nullptr) { + std::launder(reinterpret_cast(&status_.val[0]))->~Status(); + } + } + + google::protobuf::Arena* absl_nullable arena_; + union { + alignas(absl::Status) char val[sizeof(absl::Status)]; + const absl::Status* absl_nonnull ptr; + } status_; +}; + +ErrorValue NoSuchFieldError(absl::string_view field); + +ErrorValue NoSuchKeyError(absl::string_view key); + +ErrorValue NoSuchTypeError(absl::string_view type); + +ErrorValue DuplicateKeyError(); + +ErrorValue TypeConversionError(absl::string_view from, absl::string_view to); + +ErrorValue TypeConversionError(const Type& from, const Type& to); + +ErrorValue IndexOutOfBoundsError(size_t index); + +ErrorValue IndexOutOfBoundsError(ptrdiff_t index); + +// Catch other integrals and forward them to the above ones. This is needed to +// avoid ambiguous overload issues for smaller integral types like `int`. +template +std::enable_if_t, std::is_unsigned, + std::negation>>, + ErrorValue> +IndexOutOfBoundsError(T index) { + static_assert(sizeof(T) <= sizeof(size_t)); + return IndexOutOfBoundsError(static_cast(index)); +} +template +std::enable_if_t, std::is_signed, + std::negation>>, + ErrorValue> +IndexOutOfBoundsError(T index) { + static_assert(sizeof(T) <= sizeof(ptrdiff_t)); + return IndexOutOfBoundsError(static_cast(index)); +} + +inline std::ostream& operator<<(std::ostream& out, const ErrorValue& value) { + return out << value.DebugString(); +} + +bool IsNoSuchField(const ErrorValue& value); + +bool IsNoSuchKey(const ErrorValue& value); + +class ErrorValueReturn final { + public: + ErrorValueReturn() = default; + + ErrorValue operator()(absl::Status status) const { + return ErrorValue(std::move(status)); + } +}; + +namespace common_internal { + +struct ImplicitlyConvertibleStatus { + // NOLINTNEXTLINE(google-explicit-constructor) + operator absl::Status() const { return absl::OkStatus(); } + + template + // NOLINTNEXTLINE(google-explicit-constructor) + operator absl::StatusOr() const { + return T(); + } +}; + +} // namespace common_internal + +// For use with `RETURN_IF_ERROR(...).With(cel::ErrorValueAssign(&result))` and +// `ASSIGN_OR_RETURN(..., ..., _.With(cel::ErrorValueAssign(&result)))`. +// +// IMPORTANT: +// If the returning type is `absl::Status` the result will be +// `absl::OkStatus()`. If the returning type is `absl::StatusOr` the result +// will be `T()`. +class ErrorValueAssign final { + public: + ErrorValueAssign() = delete; + + explicit ErrorValueAssign(Value& value ABSL_ATTRIBUTE_LIFETIME_BOUND) + : ErrorValueAssign(std::addressof(value)) {} + + explicit ErrorValueAssign( + Value* absl_nonnull value ABSL_ATTRIBUTE_LIFETIME_BOUND) + : value_(value) { + ABSL_DCHECK(value != nullptr); + } + + common_internal::ImplicitlyConvertibleStatus operator()( + absl::Status status) const; + + private: + Value* absl_nonnull value_; +}; + +template <> +struct ArenaTraits { + static bool trivially_destructible(const ErrorValue& value) { + return value.arena_ != nullptr; + } +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_ERROR_VALUE_H_ diff --git a/common/values/error_value_test.cc b/common/values/error_value_test.cc new file mode 100644 index 000000000..343a93d19 --- /dev/null +++ b/common/values/error_value_test.cc @@ -0,0 +1,84 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "absl/status/status.h" +#include "common/native_type.h" +#include "common/value.h" +#include "common/value_testing.h" +#include "internal/testing.h" +#include "google/protobuf/io/zero_copy_stream_impl_lite.h" + +namespace cel { +namespace { + +using ::absl_testing::StatusIs; +using ::testing::_; +using ::testing::IsEmpty; +using ::testing::Not; + +using ErrorValueTest = common_internal::ValueTest<>; + +TEST_F(ErrorValueTest, Default) { + ErrorValue value; + EXPECT_THAT(value.NativeValue(), StatusIs(absl::StatusCode::kUnknown)); +} + +TEST_F(ErrorValueTest, OkStatus) { + EXPECT_DEBUG_DEATH(static_cast(ErrorValue(absl::OkStatus())), _); +} + +TEST_F(ErrorValueTest, Kind) { + EXPECT_EQ(ErrorValue(absl::CancelledError()).kind(), ErrorValue::kKind); + EXPECT_EQ(Value(ErrorValue(absl::CancelledError())).kind(), + ErrorValue::kKind); +} + +TEST_F(ErrorValueTest, DebugString) { + { + std::ostringstream out; + out << ErrorValue(absl::CancelledError()); + EXPECT_THAT(out.str(), Not(IsEmpty())); + } + { + std::ostringstream out; + out << Value(ErrorValue(absl::CancelledError())); + EXPECT_THAT(out.str(), Not(IsEmpty())); + } +} + +TEST_F(ErrorValueTest, SerializeTo) { + google::protobuf::io::CordOutputStream output; + EXPECT_THAT( + ErrorValue().SerializeTo(descriptor_pool(), message_factory(), &output), + StatusIs(absl::StatusCode::kFailedPrecondition)); +} + +TEST_F(ErrorValueTest, ConvertToJson) { + auto* message = NewArenaValueMessage(); + EXPECT_THAT( + ErrorValue().ConvertToJson(descriptor_pool(), message_factory(), message), + StatusIs(absl::StatusCode::kFailedPrecondition)); +} + +TEST_F(ErrorValueTest, NativeTypeId) { + EXPECT_EQ(NativeTypeId::Of(ErrorValue(absl::CancelledError())), + NativeTypeId::For()); + EXPECT_EQ(NativeTypeId::Of(Value(ErrorValue(absl::CancelledError()))), + NativeTypeId::For()); +} + +} // namespace +} // namespace cel diff --git a/common/values/int_value.cc b/common/values/int_value.cc new file mode 100644 index 000000000..0232bad19 --- /dev/null +++ b/common/values/int_value.cc @@ -0,0 +1,111 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "google/protobuf/wrappers.pb.h" +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "common/value.h" +#include "internal/number.h" +#include "internal/status_macros.h" +#include "internal/well_known_types.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/message.h" + +namespace cel { + +namespace { + +using ::cel::well_known_types::ValueReflection; + +std::string IntDebugString(int64_t value) { return absl::StrCat(value); } + +} // namespace + +std::string IntValue::DebugString() const { + return IntDebugString(NativeValue()); +} + +absl::Status IntValue::SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(output != nullptr); + + google::protobuf::Int64Value message; + message.set_value(NativeValue()); + if (!message.SerializePartialToZeroCopyStream(output)) { + return absl::UnknownError( + absl::StrCat("failed to serialize message: ", message.GetTypeName())); + } + + return absl::OkStatus(); +} + +absl::Status IntValue::ConvertToJson( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(json != nullptr); + ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE); + + ValueReflection value_reflection; + CEL_RETURN_IF_ERROR(value_reflection.Initialize(json->GetDescriptor())); + value_reflection.SetNumberValue(json, NativeValue()); + + return absl::OkStatus(); +} + +absl::Status IntValue::Equal( + const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + if (auto other_value = other.AsInt(); other_value.has_value()) { + *result = BoolValue{NativeValue() == other_value->NativeValue()}; + return absl::OkStatus(); + } + if (auto other_value = other.AsDouble(); other_value.has_value()) { + *result = + BoolValue{internal::Number::FromInt64(NativeValue()) == + internal::Number::FromDouble(other_value->NativeValue())}; + return absl::OkStatus(); + } + if (auto other_value = other.AsUint(); other_value.has_value()) { + *result = + BoolValue{internal::Number::FromInt64(NativeValue()) == + internal::Number::FromUint64(other_value->NativeValue())}; + return absl::OkStatus(); + } + *result = FalseValue(); + return absl::OkStatus(); +} + +} // namespace cel diff --git a/common/values/int_value.h b/common/values/int_value.h new file mode 100644 index 000000000..af0db7ee7 --- /dev/null +++ b/common/values/int_value.h @@ -0,0 +1,117 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/value.h" +// IWYU pragma: friend "common/value.h" + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_INT_VALUE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_INT_VALUE_H_ + +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/strings/cord.h" +#include "absl/strings/string_view.h" +#include "common/type.h" +#include "common/value_kind.h" +#include "common/values/values.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/message.h" + +namespace cel { + +class Value; +class IntValue; + +// `IntValue` represents values of the primitive `int` type. +class IntValue final : private common_internal::ValueMixin { + public: + static constexpr ValueKind kKind = ValueKind::kInt; + + explicit IntValue(int64_t value) noexcept : value_(value) {} + + IntValue() = default; + IntValue(const IntValue&) = default; + IntValue(IntValue&&) = default; + IntValue& operator=(const IntValue&) = default; + IntValue& operator=(IntValue&&) = default; + + ValueKind kind() const { return kKind; } + + absl::string_view GetTypeName() const { return IntType::kName; } + + std::string DebugString() const; + + // See Value::SerializeTo(). + absl::Status SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const; + + // See Value::ConvertToJson(). + absl::Status ConvertToJson( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const; + + absl::Status Equal(const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const; + using ValueMixin::Equal; + + bool IsZeroValue() const { return NativeValue() == 0; } + + int64_t NativeValue() const { return static_cast(*this); } + + // NOLINTNEXTLINE(google-explicit-constructor) + operator int64_t() const noexcept { return value_; } + + friend void swap(IntValue& lhs, IntValue& rhs) noexcept { + using std::swap; + swap(lhs.value_, rhs.value_); + } + + private: + friend class common_internal::ValueMixin; + + int64_t value_ = 0; +}; + +template +H AbslHashValue(H state, IntValue value) { + return H::combine(std::move(state), value.NativeValue()); +} + +inline bool operator==(IntValue lhs, IntValue rhs) { + return lhs.NativeValue() == rhs.NativeValue(); +} + +inline bool operator!=(IntValue lhs, IntValue rhs) { + return !operator==(lhs, rhs); +} + +inline std::ostream& operator<<(std::ostream& out, IntValue value) { + return out << value.DebugString(); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_INT_VALUE_H_ diff --git a/common/values/int_value_test.cc b/common/values/int_value_test.cc new file mode 100644 index 000000000..0a3169606 --- /dev/null +++ b/common/values/int_value_test.cc @@ -0,0 +1,81 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "absl/hash/hash.h" +#include "absl/status/status_matchers.h" +#include "common/native_type.h" +#include "common/value.h" +#include "common/value_testing.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; + +using IntValueTest = common_internal::ValueTest<>; + +TEST_F(IntValueTest, Kind) { + EXPECT_EQ(IntValue(1).kind(), IntValue::kKind); + EXPECT_EQ(Value(IntValue(1)).kind(), IntValue::kKind); +} + +TEST_F(IntValueTest, DebugString) { + { + std::ostringstream out; + out << IntValue(1); + EXPECT_EQ(out.str(), "1"); + } + { + std::ostringstream out; + out << Value(IntValue(1)); + EXPECT_EQ(out.str(), "1"); + } +} + +TEST_F(IntValueTest, ConvertToJson) { + auto* message = NewArenaValueMessage(); + EXPECT_THAT( + IntValue(1).ConvertToJson(descriptor_pool(), message_factory(), message), + IsOk()); + EXPECT_THAT(*message, EqualsValueTextProto(R"pb(number_value: 1)pb")); +} + +TEST_F(IntValueTest, NativeTypeId) { + EXPECT_EQ(NativeTypeId::Of(IntValue(1)), NativeTypeId::For()); + EXPECT_EQ(NativeTypeId::Of(Value(IntValue(1))), + NativeTypeId::For()); +} + +TEST_F(IntValueTest, HashValue) { + EXPECT_EQ(absl::HashOf(IntValue(1)), absl::HashOf(int64_t{1})); +} + +TEST_F(IntValueTest, Equality) { + EXPECT_NE(IntValue(0), 1); + EXPECT_NE(1, IntValue(0)); + EXPECT_NE(IntValue(0), IntValue(1)); +} + +TEST_F(IntValueTest, LessThan) { + EXPECT_LT(IntValue(0), 1); + EXPECT_LT(0, IntValue(1)); + EXPECT_LT(IntValue(0), IntValue(1)); +} + +} // namespace +} // namespace cel diff --git a/common/values/legacy_list_value.cc b/common/values/legacy_list_value.cc new file mode 100644 index 000000000..93848ca44 --- /dev/null +++ b/common/values/legacy_list_value.cc @@ -0,0 +1,76 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/values/legacy_list_value.h" + +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/types/optional.h" +#include "common/native_type.h" +#include "common/value.h" +#include "common/values/list_value_builder.h" +#include "common/values/values.h" +#include "eval/public/cel_value.h" +#include "internal/casts.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel::common_internal { + +absl::Status LegacyListValue::Equal( + const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + if (auto list_value = other.AsList(); list_value.has_value()) { + return ListValueEqual(*this, *list_value, descriptor_pool, message_factory, + arena, result); + } + *result = FalseValue(); + return absl::OkStatus(); +} + +bool IsLegacyListValue(const Value& value) { + return value.variant_.Is(); +} + +LegacyListValue GetLegacyListValue(const Value& value) { + ABSL_DCHECK(IsLegacyListValue(value)); + return value.variant_.Get(); +} + +absl::optional AsLegacyListValue(const Value& value) { + if (IsLegacyListValue(value)) { + return GetLegacyListValue(value); + } + if (auto custom_list_value = value.AsCustomList(); custom_list_value) { + NativeTypeId native_type_id = custom_list_value->GetTypeId(); + if (native_type_id == NativeTypeId::For()) { + return LegacyListValue( + static_cast( + cel::internal::down_cast( + custom_list_value->interface()))); + } else if (native_type_id == NativeTypeId::For()) { + return LegacyListValue( + static_cast( + cel::internal::down_cast( + custom_list_value->interface()))); + } + } + return absl::nullopt; +} + +} // namespace cel::common_internal diff --git a/common/values/legacy_list_value.h b/common/values/legacy_list_value.h new file mode 100644 index 000000000..caffcbc25 --- /dev/null +++ b/common/values/legacy_list_value.h @@ -0,0 +1,167 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/values/list_value.h" +// IWYU pragma: friend "common/values/list_value.h" + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_LEGACY_LIST_VALUE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_LEGACY_LIST_VALUE_H_ + +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "common/value_kind.h" +#include "common/values/custom_list_value.h" +#include "common/values/values.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/message.h" + +namespace google::api::expr::runtime { +class CelList; +} + +namespace cel { + +class Value; + +namespace common_internal { + +class LegacyListValue; + +class LegacyListValue final + : private common_internal::ListValueMixin { + public: + static constexpr ValueKind kKind = ValueKind::kList; + + explicit LegacyListValue( + const google::api::expr::runtime::CelList* absl_nullability_unknown impl) + : impl_(impl) {} + + // By default, this creates an empty list whose type is `list(dyn)`. Unless + // you can help it, you should use a more specific typed list value. + LegacyListValue() = default; + LegacyListValue(const LegacyListValue&) = default; + LegacyListValue(LegacyListValue&&) = default; + LegacyListValue& operator=(const LegacyListValue&) = default; + LegacyListValue& operator=(LegacyListValue&&) = default; + + constexpr ValueKind kind() const { return kKind; } + + absl::string_view GetTypeName() const { return "list"; } + + std::string DebugString() const; + + // See Value::SerializeTo(). + absl::Status SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const; + + // See Value::ConvertToJson(). + absl::Status ConvertToJson( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const; + + // See Value::ConvertToJsonArray(). + absl::Status ConvertToJsonArray( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const; + + absl::Status Equal(const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const; + using ListValueMixin::Equal; + + bool IsZeroValue() const { return IsEmpty(); } + + bool IsEmpty() const; + + size_t Size() const; + + // See ListValueInterface::Get for documentation. + absl::Status Get(size_t index, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const; + using ListValueMixin::Get; + + using ForEachCallback = typename CustomListValueInterface::ForEachCallback; + + using ForEachWithIndexCallback = + typename CustomListValueInterface::ForEachWithIndexCallback; + + absl::Status ForEach( + ForEachWithIndexCallback callback, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const; + using ListValueMixin::ForEach; + + absl::StatusOr NewIterator() const; + + absl::Status Contains( + const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; + using ListValueMixin::Contains; + + const google::api::expr::runtime::CelList* absl_nullability_unknown cel_list() + const { + return impl_; + } + + friend void swap(LegacyListValue& lhs, LegacyListValue& rhs) noexcept { + using std::swap; + swap(lhs.impl_, rhs.impl_); + } + + private: + friend class common_internal::ValueMixin; + friend class common_internal::ListValueMixin; + + const google::api::expr::runtime::CelList* absl_nullability_unknown impl_ = + nullptr; +}; + +inline std::ostream& operator<<(std::ostream& out, + const LegacyListValue& type) { + return out << type.DebugString(); +} + +bool IsLegacyListValue(const Value& value); + +LegacyListValue GetLegacyListValue(const Value& value); + +absl::optional AsLegacyListValue(const Value& value); + +} // namespace common_internal + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_LEGACY_LIST_VALUE_H_ diff --git a/common/values/legacy_map_value.cc b/common/values/legacy_map_value.cc new file mode 100644 index 000000000..1f370761e --- /dev/null +++ b/common/values/legacy_map_value.cc @@ -0,0 +1,76 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/values/legacy_map_value.h" + +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/types/optional.h" +#include "common/native_type.h" +#include "common/value.h" +#include "common/values/map_value_builder.h" +#include "common/values/values.h" +#include "eval/public/cel_value.h" +#include "internal/casts.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel::common_internal { + +absl::Status LegacyMapValue::Equal( + const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + if (auto map_value = other.AsMap(); map_value.has_value()) { + return MapValueEqual(*this, *map_value, descriptor_pool, message_factory, + arena, result); + } + *result = FalseValue(); + return absl::OkStatus(); +} + +bool IsLegacyMapValue(const Value& value) { + return value.variant_.Is(); +} + +LegacyMapValue GetLegacyMapValue(const Value& value) { + ABSL_DCHECK(IsLegacyMapValue(value)); + return value.variant_.Get(); +} + +absl::optional AsLegacyMapValue(const Value& value) { + if (IsLegacyMapValue(value)) { + return GetLegacyMapValue(value); + } + if (auto custom_map_value = value.AsCustomMap(); custom_map_value) { + NativeTypeId native_type_id = NativeTypeId::Of(*custom_map_value); + if (native_type_id == NativeTypeId::For()) { + return LegacyMapValue( + static_cast( + cel::internal::down_cast( + custom_map_value->interface()))); + } else if (native_type_id == NativeTypeId::For()) { + return LegacyMapValue( + static_cast( + cel::internal::down_cast( + custom_map_value->interface()))); + } + } + return absl::nullopt; +} + +} // namespace cel::common_internal diff --git a/common/values/legacy_map_value.h b/common/values/legacy_map_value.h new file mode 100644 index 000000000..c83b7fc2f --- /dev/null +++ b/common/values/legacy_map_value.h @@ -0,0 +1,185 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/values/map_value.h" +// IWYU pragma: friend "common/values/map_value.h" + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_LEGACY_MAP_VALUE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_LEGACY_MAP_VALUE_H_ + +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "common/value_kind.h" +#include "common/values/custom_map_value.h" +#include "common/values/values.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/message.h" + +namespace google::api::expr::runtime { +class CelMap; +} + +namespace cel { + +class Value; + +namespace common_internal { + +class LegacyMapValue; + +class LegacyMapValue final + : private common_internal::MapValueMixin { + public: + static constexpr ValueKind kKind = ValueKind::kMap; + + explicit LegacyMapValue( + const google::api::expr::runtime::CelMap* absl_nullability_unknown impl) + : impl_(impl) {} + + // By default, this creates an empty map whose type is `map(dyn, dyn)`. + // Unless you can help it, you should use a more specific typed map value. + LegacyMapValue() = default; + LegacyMapValue(const LegacyMapValue&) = default; + LegacyMapValue(LegacyMapValue&&) = default; + LegacyMapValue& operator=(const LegacyMapValue&) = default; + LegacyMapValue& operator=(LegacyMapValue&&) = default; + + constexpr ValueKind kind() const { return kKind; } + + absl::string_view GetTypeName() const { return "map"; } + + std::string DebugString() const; + + // See Value::SerializeTo(). + absl::Status SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const; + + // See Value::ConvertToJson(). + absl::Status ConvertToJson( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const; + + // See Value::ConvertToJsonObject(). + absl::Status ConvertToJsonObject( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const; + + absl::Status Equal(const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const; + using MapValueMixin::Equal; + + bool IsZeroValue() const { return IsEmpty(); } + + bool IsEmpty() const; + + size_t Size() const; + + // See the corresponding member function of `MapValue` for + // documentation. + absl::Status Get(const Value& key, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const; + using MapValueMixin::Get; + + // See the corresponding member function of `MapValue` for + // documentation. + absl::StatusOr Find( + const Value& key, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; + using MapValueMixin::Find; + + // See the corresponding member function of `MapValue` for + // documentation. + absl::Status Has(const Value& key, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const; + using MapValueMixin::Has; + + // See the corresponding member function of `MapValue` for + // documentation. + absl::Status ListKeys( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, ListValue* absl_nonnull result) const; + using MapValueMixin::ListKeys; + + // See the corresponding type declaration of `MapValue` for + // documentation. + using ForEachCallback = typename CustomMapValueInterface::ForEachCallback; + + // See the corresponding member function of `MapValue` for + // documentation. + absl::Status ForEach( + ForEachCallback callback, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const; + + absl::StatusOr NewIterator() const; + + const google::api::expr::runtime::CelMap* absl_nonnull cel_map() const { + return impl_; + } + + friend void swap(LegacyMapValue& lhs, LegacyMapValue& rhs) noexcept { + using std::swap; + swap(lhs.impl_, rhs.impl_); + } + + private: + friend class common_internal::ValueMixin; + friend class common_internal::MapValueMixin; + + const google::api::expr::runtime::CelMap* absl_nullability_unknown impl_ = + nullptr; +}; + +inline std::ostream& operator<<(std::ostream& out, const LegacyMapValue& type) { + return out << type.DebugString(); +} + +bool IsLegacyMapValue(const Value& value); + +LegacyMapValue GetLegacyMapValue(const Value& value); + +absl::optional AsLegacyMapValue(const Value& value); + +} // namespace common_internal + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_LEGACY_MAP_VALUE_H_ diff --git a/common/values/legacy_struct_value.cc b/common/values/legacy_struct_value.cc new file mode 100644 index 000000000..4a91c5d42 --- /dev/null +++ b/common/values/legacy_struct_value.cc @@ -0,0 +1,43 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/log/absl_check.h" +#include "absl/types/optional.h" +#include "common/type.h" +#include "common/value.h" +#include "google/protobuf/message.h" + +namespace cel::common_internal { + +StructType LegacyStructValue::GetRuntimeType() const { + return MessageType(message_ptr_->GetDescriptor()); +} + +bool IsLegacyStructValue(const Value& value) { + return value.variant_.Is(); +} + +LegacyStructValue GetLegacyStructValue(const Value& value) { + ABSL_DCHECK(IsLegacyStructValue(value)); + return value.variant_.Get(); +} + +absl::optional AsLegacyStructValue(const Value& value) { + if (IsLegacyStructValue(value)) { + return GetLegacyStructValue(value); + } + return absl::nullopt; +} + +} // namespace cel::common_internal diff --git a/common/values/legacy_struct_value.h b/common/values/legacy_struct_value.h new file mode 100644 index 000000000..ab5baed1e --- /dev/null +++ b/common/values/legacy_struct_value.h @@ -0,0 +1,183 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/value.h" +// IWYU pragma: friend "common/value.h" + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_LEGACY_STRUCT_VALUE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_LEGACY_STRUCT_VALUE_H_ + +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "base/attribute.h" +#include "common/type.h" +#include "common/value_kind.h" +#include "common/values/custom_struct_value.h" +#include "common/values/values.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/message.h" + +namespace google::api::expr::runtime { +class LegacyTypeInfoApis; +} + +namespace cel { + +class Value; + +namespace common_internal { + +class LegacyStructValue; + +// `LegacyStructValue` is a wrapper around the old representation of protocol +// buffer messages in `google::api::expr::runtime::CelValue`. It only supports +// arena allocation. +class LegacyStructValue final + : private common_internal::StructValueMixin { + public: + static constexpr ValueKind kKind = ValueKind::kStruct; + + LegacyStructValue() = default; + + LegacyStructValue( + const google::protobuf::Message* absl_nullability_unknown message_ptr, + const google::api::expr::runtime:: + LegacyTypeInfoApis* absl_nullability_unknown legacy_type_info) + : message_ptr_(message_ptr), legacy_type_info_(legacy_type_info) {} + + LegacyStructValue(const LegacyStructValue&) = default; + LegacyStructValue& operator=(const LegacyStructValue&) = default; + + constexpr ValueKind kind() const { return kKind; } + + StructType GetRuntimeType() const; + + absl::string_view GetTypeName() const; + + std::string DebugString() const; + + // See Value::SerializeTo(). + absl::Status SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const; + + // See Value::ConvertToJson(). + absl::Status ConvertToJson( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const; + + // See Value::ConvertToJsonObject(). + absl::Status ConvertToJsonObject( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const; + + absl::Status Equal(const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const; + using StructValueMixin::Equal; + + bool IsZeroValue() const; + + absl::Status GetFieldByName( + absl::string_view name, ProtoWrapperTypeOptions unboxing_options, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; + using StructValueMixin::GetFieldByName; + + absl::Status GetFieldByNumber( + int64_t number, ProtoWrapperTypeOptions unboxing_options, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; + using StructValueMixin::GetFieldByNumber; + + absl::StatusOr HasFieldByName(absl::string_view name) const; + + absl::StatusOr HasFieldByNumber(int64_t number) const; + + using ForEachFieldCallback = CustomStructValueInterface::ForEachFieldCallback; + + absl::Status ForEachField( + ForEachFieldCallback callback, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const; + + absl::Status Qualify( + absl::Span qualifiers, bool presence_test, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result, + int* absl_nonnull count) const; + using StructValueMixin::Qualify; + + const google::protobuf::Message* absl_nullability_unknown message_ptr() const { + return message_ptr_; + } + + const google::api::expr::runtime::LegacyTypeInfoApis* absl_nullability_unknown + legacy_type_info() const { + return legacy_type_info_; + } + + friend void swap(LegacyStructValue& lhs, LegacyStructValue& rhs) noexcept { + using std::swap; + swap(lhs.message_ptr_, rhs.message_ptr_); + swap(lhs.legacy_type_info_, rhs.legacy_type_info_); + } + + private: + friend class common_internal::ValueMixin; + friend class common_internal::StructValueMixin; + + const google::protobuf::Message* absl_nullability_unknown message_ptr_ = nullptr; + const google::api::expr::runtime::LegacyTypeInfoApis* absl_nullability_unknown + legacy_type_info_ = nullptr; +}; + +inline std::ostream& operator<<(std::ostream& out, + const LegacyStructValue& value) { + return out << value.DebugString(); +} + +bool IsLegacyStructValue(const Value& value); + +LegacyStructValue GetLegacyStructValue(const Value& value); + +absl::optional AsLegacyStructValue(const Value& value); + +} // namespace common_internal + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_LEGACY_STRUCT_VALUE_H_ diff --git a/common/values/list_value.cc b/common/values/list_value.cc new file mode 100644 index 000000000..35df98c40 --- /dev/null +++ b/common/values/list_value.cc @@ -0,0 +1,304 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/optional.h" +#include "common/native_type.h" +#include "common/optional_ref.h" +#include "common/value.h" +#include "common/values/value_variant.h" +#include "internal/status_macros.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/message.h" + +namespace cel { + +NativeTypeId ListValue::GetTypeId() const { + return variant_.Visit([](const auto& alternative) -> NativeTypeId { + return NativeTypeId::Of(alternative); + }); +} + +std::string ListValue::DebugString() const { + return variant_.Visit([](const auto& alternative) -> std::string { + return alternative.DebugString(); + }); +} + +absl::Status ListValue::SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(output != nullptr); + + return variant_.Visit([&](const auto& alternative) -> absl::Status { + return alternative.SerializeTo(descriptor_pool, message_factory, output); + }); +} + +absl::Status ListValue::ConvertToJson( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(json != nullptr); + ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE); + + return variant_.Visit([&](const auto& alternative) -> absl::Status { + return alternative.ConvertToJson(descriptor_pool, message_factory, json); + }); +} + +absl::Status ListValue::ConvertToJsonArray( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(json != nullptr); + ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE); + + return variant_.Visit([&](const auto& alternative) -> absl::Status { + return alternative.ConvertToJsonArray(descriptor_pool, message_factory, + json); + }); +} + +absl::Status ListValue::Equal( + const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + return variant_.Visit([&](const auto& alternative) -> absl::Status { + return alternative.Equal(other, descriptor_pool, message_factory, arena, + result); + }); +} + +bool ListValue::IsZeroValue() const { + return variant_.Visit([](const auto& alternative) -> bool { + return alternative.IsZeroValue(); + }); +} + +absl::StatusOr ListValue::IsEmpty() const { + return variant_.Visit([](const auto& alternative) -> absl::StatusOr { + return alternative.IsEmpty(); + }); +} + +absl::StatusOr ListValue::Size() const { + return variant_.Visit([](const auto& alternative) -> absl::StatusOr { + return alternative.Size(); + }); +} + +absl::Status ListValue::Get( + size_t index, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + return variant_.Visit([&](const auto& alternative) -> absl::Status { + return alternative.Get(index, descriptor_pool, message_factory, arena, + result); + }); +} + +absl::Status ListValue::ForEach( + ForEachWithIndexCallback callback, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + + return variant_.Visit([&](const auto& alternative) -> absl::Status { + return alternative.ForEach(callback, descriptor_pool, message_factory, + arena); + }); +} + +absl::StatusOr ListValue::NewIterator() const { + return variant_.Visit([](const auto& alternative) + -> absl::StatusOr { + return alternative.NewIterator(); + }); +} + +absl::Status ListValue::Contains( + const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + return variant_.Visit([&](const auto& alternative) -> absl::Status { + return alternative.Contains(other, descriptor_pool, message_factory, arena, + result); + }); +} + +namespace common_internal { + +absl::Status ListValueEqual( + const ListValue& lhs, const ListValue& rhs, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + CEL_ASSIGN_OR_RETURN(auto lhs_size, lhs.Size()); + CEL_ASSIGN_OR_RETURN(auto rhs_size, rhs.Size()); + if (lhs_size != rhs_size) { + *result = FalseValue(); + return absl::OkStatus(); + } + CEL_ASSIGN_OR_RETURN(auto lhs_iterator, lhs.NewIterator()); + CEL_ASSIGN_OR_RETURN(auto rhs_iterator, rhs.NewIterator()); + Value lhs_element; + Value rhs_element; + for (size_t index = 0; index < lhs_size; ++index) { + ABSL_CHECK(lhs_iterator->HasNext()); // Crash OK + ABSL_CHECK(rhs_iterator->HasNext()); // Crash OK + CEL_RETURN_IF_ERROR(lhs_iterator->Next(descriptor_pool, message_factory, + arena, &lhs_element)); + CEL_RETURN_IF_ERROR(rhs_iterator->Next(descriptor_pool, message_factory, + arena, &rhs_element)); + CEL_RETURN_IF_ERROR(lhs_element.Equal(rhs_element, descriptor_pool, + message_factory, arena, result)); + if (result->IsFalse()) { + return absl::OkStatus(); + } + } + ABSL_DCHECK(!lhs_iterator->HasNext()); + ABSL_DCHECK(!rhs_iterator->HasNext()); + *result = TrueValue(); + return absl::OkStatus(); +} + +absl::Status ListValueEqual( + const CustomListValueInterface& lhs, const ListValue& rhs, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + auto lhs_size = lhs.Size(); + CEL_ASSIGN_OR_RETURN(auto rhs_size, rhs.Size()); + if (lhs_size != rhs_size) { + *result = FalseValue(); + return absl::OkStatus(); + } + CEL_ASSIGN_OR_RETURN(auto lhs_iterator, lhs.NewIterator()); + CEL_ASSIGN_OR_RETURN(auto rhs_iterator, rhs.NewIterator()); + Value lhs_element; + Value rhs_element; + for (size_t index = 0; index < lhs_size; ++index) { + ABSL_CHECK(lhs_iterator->HasNext()); // Crash OK + ABSL_CHECK(rhs_iterator->HasNext()); // Crash OK + CEL_RETURN_IF_ERROR(lhs_iterator->Next(descriptor_pool, message_factory, + arena, &lhs_element)); + CEL_RETURN_IF_ERROR(rhs_iterator->Next(descriptor_pool, message_factory, + arena, &rhs_element)); + CEL_RETURN_IF_ERROR(lhs_element.Equal(rhs_element, descriptor_pool, + message_factory, arena, result)); + if (result->IsFalse()) { + return absl::OkStatus(); + } + } + ABSL_DCHECK(!lhs_iterator->HasNext()); + ABSL_DCHECK(!rhs_iterator->HasNext()); + *result = TrueValue(); + return absl::OkStatus(); +} + +} // namespace common_internal + +optional_ref ListValue::AsCustom() const& { + if (const auto* alternative = variant_.As(); + alternative != nullptr) { + return *alternative; + } + return absl::nullopt; +} + +absl::optional ListValue::AsCustom() && { + if (auto* alternative = variant_.As(); + alternative != nullptr) { + return std::move(*alternative); + } + return absl::nullopt; +} + +const CustomListValue& ListValue::GetCustom() const& { + ABSL_DCHECK(IsCustom()); + + return variant_.Get(); +} + +CustomListValue ListValue::GetCustom() && { + ABSL_DCHECK(IsCustom()); + + return std::move(variant_).Get(); +} + +common_internal::ValueVariant ListValue::ToValueVariant() const& { + return variant_.Visit( + [](const auto& alternative) -> common_internal::ValueVariant { + return common_internal::ValueVariant(alternative); + }); +} + +common_internal::ValueVariant ListValue::ToValueVariant() && { + return std::move(variant_).Visit( + [](auto&& alternative) -> common_internal::ValueVariant { + // NOLINTNEXTLINE(bugprone-move-forwarding-reference) + return common_internal::ValueVariant(std::move(alternative)); + }); +} + +} // namespace cel diff --git a/common/values/list_value.h b/common/values/list_value.h new file mode 100644 index 000000000..516d16dcc --- /dev/null +++ b/common/values/list_value.h @@ -0,0 +1,284 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/value.h" +// IWYU pragma: friend "common/value.h" + +// `ListValue` represents values of the primitive `list` type. +// `ListValueInterface` is the abstract base class of implementations. +// `ListValue` acts as a smart pointer to `ListValueInterface`. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_LIST_VALUE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_LIST_VALUE_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/meta/type_traits.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/utility/utility.h" +#include "common/native_type.h" +#include "common/optional_ref.h" +#include "common/value_kind.h" +#include "common/values/custom_list_value.h" +#include "common/values/legacy_list_value.h" +#include "common/values/list_value_variant.h" +#include "common/values/parsed_json_list_value.h" +#include "common/values/parsed_repeated_field_value.h" +#include "common/values/values.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/message.h" + +namespace cel { + +class ListValueInterface; +class ListValue; +class Value; + +class ListValue final : private common_internal::ListValueMixin { + public: + static constexpr ValueKind kKind = ValueKind::kList; + + // Move constructor for alternative struct values. + template < + typename T, + typename = std::enable_if_t< + common_internal::IsListValueAlternativeV>>> + // NOLINTNEXTLINE(google-explicit-constructor) + ListValue(T&& value) + : variant_(absl::in_place_type>, + std::forward(value)) {} + + ListValue() = default; + ListValue(const ListValue&) = default; + ListValue(ListValue&&) = default; + ListValue& operator=(const ListValue&) = default; + ListValue& operator=(ListValue&&) = default; + + static constexpr ValueKind kind() { return kKind; } + + static absl::string_view GetTypeName() { return "list"; } + + NativeTypeId GetTypeId() const; + + std::string DebugString() const; + + // See Value::SerializeTo(). + absl::Status SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const; + + // See Value::ConvertToJson(). + absl::Status ConvertToJson( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const; + + // Like ConvertToJson(), except `json` **MUST** be an instance of + // `google.protobuf.ListValue`. + absl::Status ConvertToJsonArray( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const; + + absl::Status Equal(const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const; + using ListValueMixin::Equal; + + bool IsZeroValue() const; + + absl::StatusOr IsEmpty() const; + + absl::StatusOr Size() const; + + // See ListValueInterface::Get for documentation. + absl::Status Get(size_t index, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const; + using ListValueMixin::Get; + + using ForEachCallback = typename CustomListValueInterface::ForEachCallback; + + using ForEachWithIndexCallback = + typename CustomListValueInterface::ForEachWithIndexCallback; + + absl::Status ForEach( + ForEachWithIndexCallback callback, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const; + using ListValueMixin::ForEach; + + absl::StatusOr NewIterator() const; + + absl::Status Contains( + const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; + using ListValueMixin::Contains; + + // Returns `true` if this value is an instance of a custom list value. + bool IsCustom() const { return variant_.Is(); } + + // Convenience method for use with template metaprogramming. See + // `IsParsed()`. + template + std::enable_if_t, bool> Is() const { + return IsCustom(); + } + + // Performs a checked cast from a value to a custom list value, + // returning a non-empty optional with either a value or reference to the + // custom list value. Otherwise an empty optional is returned. + optional_ref AsCustom() & + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return std::as_const(*this).AsCustom(); + } + optional_ref AsCustom() + const& ABSL_ATTRIBUTE_LIFETIME_BOUND; + absl::optional AsCustom() &&; + absl::optional AsCustom() const&& { + return common_internal::AsOptional(AsCustom()); + } + + // Convenience method for use with template metaprogramming. See + // `AsCustom()`. + template + std::enable_if_t, + optional_ref> + As() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + return AsCustom(); + } + template + std::enable_if_t, + optional_ref> + As() const& ABSL_ATTRIBUTE_LIFETIME_BOUND { + return AsCustom(); + } + template + std::enable_if_t, + absl::optional> + As() && { + return std::move(*this).AsCustom(); + } + template + std::enable_if_t, + absl::optional> + As() const&& { + return std::move(*this).AsCustom(); + } + + // Performs an unchecked cast from a value to a custom list value. In + // debug builds a best effort is made to crash. If `IsCustom()` would + // return false, calling this method is undefined behavior. + const CustomListValue& GetCustom() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + return std::as_const(*this).GetCustom(); + } + const CustomListValue& GetCustom() const& ABSL_ATTRIBUTE_LIFETIME_BOUND; + CustomListValue GetCustom() &&; + CustomListValue GetCustom() const&& { return GetCustom(); } + + // Convenience method for use with template metaprogramming. See + // `GetCustom()`. + template + std::enable_if_t, + const CustomListValue&> + Get() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + return GetCustom(); + } + template + std::enable_if_t, const CustomListValue&> + Get() const& ABSL_ATTRIBUTE_LIFETIME_BOUND { + return GetCustom(); + } + template + std::enable_if_t, CustomListValue> + Get() && { + return std::move(*this).GetCustom(); + } + template + std::enable_if_t, CustomListValue> Get() + const&& { + return std::move(*this).GetCustom(); + } + + friend void swap(ListValue& lhs, ListValue& rhs) noexcept { + using std::swap; + swap(lhs.variant_, rhs.variant_); + } + + private: + friend class Value; + friend class common_internal::ValueMixin; + friend class common_internal::ListValueMixin; + + common_internal::ValueVariant ToValueVariant() const&; + common_internal::ValueVariant ToValueVariant() &&; + + // Unlike many of the other derived values, `ListValue` is itself a composed + // type. This is to avoid making `ListValue` too big and by extension + // `Value` too big. Instead we store the derived `ListValue` values in + // `Value` and not `ListValue` itself. + common_internal::ListValueVariant variant_; +}; + +inline std::ostream& operator<<(std::ostream& out, const ListValue& value) { + return out << value.DebugString(); +} + +template <> +struct NativeTypeTraits final { + static NativeTypeId Id(const ListValue& value) { return value.GetTypeId(); } +}; + +class ListValueBuilder { + public: + virtual ~ListValueBuilder() = default; + + virtual absl::Status Add(Value value) = 0; + + virtual void UnsafeAdd(Value value) = 0; + + virtual bool IsEmpty() const { return Size() == 0; } + + virtual size_t Size() const = 0; + + virtual void Reserve(size_t capacity [[maybe_unused]]) {} + + virtual ListValue Build() && = 0; +}; + +using ListValueBuilderPtr = std::unique_ptr; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_LIST_VALUE_H_ diff --git a/common/values/list_value_builder.h b/common/values/list_value_builder.h new file mode 100644 index 000000000..91cef066d --- /dev/null +++ b/common/values/list_value_builder.h @@ -0,0 +1,110 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_LIST_VALUE_BUILDER_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_LIST_VALUE_BUILDER_H_ + +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "common/native_type.h" +#include "common/value.h" +#include "eval/public/cel_value.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel { + +class ValueFactory; + +namespace common_internal { + +// Special implementation of list which is both a modern list and legacy list. +// Do not try this at home. This should only be implemented in +// `list_value_builder.cc`. +class CompatListValue : public CustomListValueInterface, + public google::api::expr::runtime::CelList { + private: + NativeTypeId GetNativeTypeId() const final { + return NativeTypeId::For(); + } +}; + +const CompatListValue* absl_nonnull EmptyCompatListValue(); + +absl::StatusOr MakeCompatListValue( + const CustomListValue& value, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena); + +// Extension of ParsedListValueInterface which is also mutable. Accessing this +// like a normal list before all elements are finished being appended is a bug. +// This is primarily used by the runtime to efficiently implement comprehensions +// which accumulate results into a list. +// +// IMPORTANT: This type is only meant to be utilized by the runtime. +class MutableListValue : public CustomListValueInterface { + public: + virtual absl::Status Append(Value value) const = 0; + + virtual void Reserve(size_t capacity) const {} + + private: + NativeTypeId GetNativeTypeId() const override { + return NativeTypeId::For(); + } +}; + +// Special implementation of list which is both a modern list, legacy list, and +// mutable. +// +// NOTE: We do not extend CompatListValue to avoid having to use virtual +// inheritance and `dynamic_cast`. +class MutableCompatListValue : public MutableListValue, + public google::api::expr::runtime::CelList { + private: + NativeTypeId GetNativeTypeId() const final { + return NativeTypeId::For(); + } +}; + +MutableListValue* absl_nonnull NewMutableListValue( + google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND); + +bool IsMutableListValue(const Value& value); +bool IsMutableListValue(const ListValue& value); + +const MutableListValue* absl_nullable AsMutableListValue( + const Value& value ABSL_ATTRIBUTE_LIFETIME_BOUND); +const MutableListValue* absl_nullable AsMutableListValue( + const ListValue& value ABSL_ATTRIBUTE_LIFETIME_BOUND); + +const MutableListValue& GetMutableListValue( + const Value& value ABSL_ATTRIBUTE_LIFETIME_BOUND); +const MutableListValue& GetMutableListValue( + const ListValue& value ABSL_ATTRIBUTE_LIFETIME_BOUND); + +absl_nonnull cel::ListValueBuilderPtr NewListValueBuilder( + google::protobuf::Arena* absl_nonnull arena); + +} // namespace common_internal + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_LIST_VALUE_BUILDER_H_ diff --git a/common/values/list_value_test.cc b/common/values/list_value_test.cc new file mode 100644 index 000000000..321c05249 --- /dev/null +++ b/common/values/list_value_test.cc @@ -0,0 +1,170 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "common/casting.h" +#include "common/value.h" +#include "common/value_testing.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::IsOkAndHolds; +using ::absl_testing::StatusIs; +using ::cel::test::ErrorValueIs; +using ::testing::ElementsAreArray; + +class ListValueTest : public common_internal::ValueTest<> { + public: + template + absl::StatusOr NewIntListValue(Args&&... args) { + auto builder = NewListValueBuilder(arena()); + (static_cast(builder->Add(std::forward(args))), ...); + return std::move(*builder).Build(); + } +}; + +TEST_F(ListValueTest, Default) { + ListValue value; + EXPECT_THAT(value.IsEmpty(), IsOkAndHolds(true)); + EXPECT_THAT(value.Size(), IsOkAndHolds(0)); + EXPECT_EQ(value.DebugString(), "[]"); +} + +TEST_F(ListValueTest, Kind) { + ASSERT_OK_AND_ASSIGN(auto value, + NewIntListValue(IntValue(0), IntValue(1), IntValue(2))); + EXPECT_EQ(value.kind(), ListValue::kKind); + EXPECT_EQ(Value(value).kind(), ListValue::kKind); +} + +TEST_F(ListValueTest, DebugString) { + ASSERT_OK_AND_ASSIGN(auto value, + NewIntListValue(IntValue(0), IntValue(1), IntValue(2))); + { + std::ostringstream out; + out << value; + EXPECT_EQ(out.str(), "[0, 1, 2]"); + } + { + std::ostringstream out; + out << Value(value); + EXPECT_EQ(out.str(), "[0, 1, 2]"); + } +} + +TEST_F(ListValueTest, IsEmpty) { + ASSERT_OK_AND_ASSIGN(auto value, + NewIntListValue(IntValue(0), IntValue(1), IntValue(2))); + EXPECT_THAT(value.IsEmpty(), IsOkAndHolds(false)); +} + +TEST_F(ListValueTest, Size) { + ASSERT_OK_AND_ASSIGN(auto value, + NewIntListValue(IntValue(0), IntValue(1), IntValue(2))); + EXPECT_THAT(value.Size(), IsOkAndHolds(3)); +} + +TEST_F(ListValueTest, Get) { + ASSERT_OK_AND_ASSIGN(auto value, + NewIntListValue(IntValue(0), IntValue(1), IntValue(2))); + ASSERT_OK_AND_ASSIGN(auto element, value.Get(0, descriptor_pool(), + message_factory(), arena())); + ASSERT_TRUE(InstanceOf(element)); + ASSERT_EQ(Cast(element).NativeValue(), 0); + ASSERT_OK_AND_ASSIGN( + element, value.Get(1, descriptor_pool(), message_factory(), arena())); + ASSERT_TRUE(InstanceOf(element)); + ASSERT_EQ(Cast(element).NativeValue(), 1); + ASSERT_OK_AND_ASSIGN( + element, value.Get(2, descriptor_pool(), message_factory(), arena())); + ASSERT_TRUE(InstanceOf(element)); + ASSERT_EQ(Cast(element).NativeValue(), 2); + EXPECT_THAT( + value.Get(3, descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument)))); +} + +TEST_F(ListValueTest, ForEach) { + ASSERT_OK_AND_ASSIGN(auto value, + NewIntListValue(IntValue(0), IntValue(1), IntValue(2))); + std::vector elements; + EXPECT_THAT(value.ForEach( + [&elements](const Value& element) { + elements.push_back(Cast(element).NativeValue()); + return true; + }, + descriptor_pool(), message_factory(), arena()), + IsOk()); + EXPECT_THAT(elements, ElementsAreArray({0, 1, 2})); +} + +TEST_F(ListValueTest, Contains) { + ASSERT_OK_AND_ASSIGN(auto value, + NewIntListValue(IntValue(0), IntValue(1), IntValue(2))); + ASSERT_OK_AND_ASSIGN(auto contained, + value.Contains(IntValue(2), descriptor_pool(), + message_factory(), arena())); + ASSERT_TRUE(InstanceOf(contained)); + EXPECT_TRUE(Cast(contained).NativeValue()); + ASSERT_OK_AND_ASSIGN(contained, value.Contains(IntValue(3), descriptor_pool(), + message_factory(), arena())); + ASSERT_TRUE(InstanceOf(contained)); + EXPECT_FALSE(Cast(contained).NativeValue()); +} + +TEST_F(ListValueTest, NewIterator) { + ASSERT_OK_AND_ASSIGN(auto value, + NewIntListValue(IntValue(0), IntValue(1), IntValue(2))); + ASSERT_OK_AND_ASSIGN(auto iterator, value.NewIterator()); + std::vector elements; + while (iterator->HasNext()) { + ASSERT_OK_AND_ASSIGN( + auto element, + iterator->Next(descriptor_pool(), message_factory(), arena())); + ASSERT_TRUE(InstanceOf(element)); + elements.push_back(Cast(element).NativeValue()); + } + EXPECT_EQ(iterator->HasNext(), false); + EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()), + StatusIs(absl::StatusCode::kFailedPrecondition)); + EXPECT_THAT(elements, ElementsAreArray({0, 1, 2})); +} + +TEST_F(ListValueTest, ConvertToJson) { + ASSERT_OK_AND_ASSIGN(auto value, + NewIntListValue(IntValue(0), IntValue(1), IntValue(2))); + auto* message = NewArenaValueMessage(); + EXPECT_THAT( + value.ConvertToJson(descriptor_pool(), message_factory(), message), + IsOk()); + EXPECT_THAT(*message, EqualsValueTextProto(R"pb(list_value: { + values: { number_value: 0 } + values: { number_value: 1 } + values: { number_value: 2 } + })pb")); +} + +} // namespace +} // namespace cel diff --git a/common/values/list_value_variant.h b/common/values/list_value_variant.h new file mode 100644 index 000000000..660c002b4 --- /dev/null +++ b/common/values/list_value_variant.h @@ -0,0 +1,214 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_LIST_VALUE_VARIANT_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_LIST_VALUE_VARIANT_H_ + +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/meta/type_traits.h" +#include "absl/utility/utility.h" +#include "common/values/custom_list_value.h" +#include "common/values/legacy_list_value.h" +#include "common/values/parsed_json_list_value.h" +#include "common/values/parsed_repeated_field_value.h" + +namespace cel::common_internal { + +enum class ListValueIndex : uint16_t { + kCustom = 0, + kParsedField, + kParsedJson, + kLegacy, +}; + +template +struct ListValueAlternative; + +template <> +struct ListValueAlternative { + static constexpr ListValueIndex kIndex = ListValueIndex::kCustom; +}; + +template <> +struct ListValueAlternative { + static constexpr ListValueIndex kIndex = ListValueIndex::kParsedField; +}; + +template <> +struct ListValueAlternative { + static constexpr ListValueIndex kIndex = ListValueIndex::kParsedJson; +}; + +template <> +struct ListValueAlternative { + static constexpr ListValueIndex kIndex = ListValueIndex::kLegacy; +}; + +template +struct IsListValueAlternative : std::false_type {}; + +template +struct IsListValueAlternative{})>> + : std::true_type {}; + +template +inline constexpr bool IsListValueAlternativeV = + IsListValueAlternative::value; + +inline constexpr size_t kListValueVariantAlign = 8; +inline constexpr size_t kListValueVariantSize = 24; + +// ListValueVariant is a subset of alternatives from the main ValueVariant that +// is only lists. It is not stored directly in ValueVariant. +class alignas(kListValueVariantAlign) ListValueVariant final { + public: + ListValueVariant() : ListValueVariant(absl::in_place_type) {} + + ListValueVariant(const ListValueVariant&) = default; + ListValueVariant(ListValueVariant&&) = default; + ListValueVariant& operator=(const ListValueVariant&) = default; + ListValueVariant& operator=(ListValueVariant&&) = default; + + template + explicit ListValueVariant(absl::in_place_type_t, Args&&... args) + : index_(ListValueAlternative::kIndex) { + static_assert(alignof(T) <= kListValueVariantAlign); + static_assert(sizeof(T) <= kListValueVariantSize); + static_assert(std::is_trivially_copyable_v); + + ::new (static_cast(&raw_[0])) T(std::forward(args)...); + } + + template >>> + explicit ListValueVariant(T&& value) + : ListValueVariant(absl::in_place_type>, + std::forward(value)) {} + + template + void Assign(T&& value) { + using U = absl::remove_cvref_t; + + static_assert(alignof(U) <= kListValueVariantAlign); + static_assert(sizeof(U) <= kListValueVariantSize); + static_assert(std::is_trivially_copyable_v); + + index_ = ListValueAlternative::kIndex; + ::new (static_cast(&raw_[0])) U(std::forward(value)); + } + + template + bool Is() const { + return index_ == ListValueAlternative::kIndex; + } + + template + T& Get() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + ABSL_DCHECK(Is()); + + return *At(); + } + + template + const T& Get() const& ABSL_ATTRIBUTE_LIFETIME_BOUND { + ABSL_DCHECK(Is()); + + return *At(); + } + + template + T&& Get() && ABSL_ATTRIBUTE_LIFETIME_BOUND { + ABSL_DCHECK(Is()); + + return std::move(*At()); + } + + template + const T&& Get() const&& ABSL_ATTRIBUTE_LIFETIME_BOUND { + ABSL_DCHECK(Is()); + + return std::move(*At()); + } + + template + T* absl_nullable As() ABSL_ATTRIBUTE_LIFETIME_BOUND { + if (Is()) { + return At(); + } + return nullptr; + } + + template + const T* absl_nullable As() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + if (Is()) { + return At(); + } + return nullptr; + } + + template + decltype(auto) Visit(Visitor&& visitor) const { + switch (index_) { + case ListValueIndex::kCustom: + return std::forward(visitor)(Get()); + case ListValueIndex::kParsedField: + return std::forward(visitor)(Get()); + case ListValueIndex::kParsedJson: + return std::forward(visitor)(Get()); + case ListValueIndex::kLegacy: + return std::forward(visitor)(Get()); + } + } + + friend void swap(ListValueVariant& lhs, ListValueVariant& rhs) noexcept { + using std::swap; + swap(lhs.index_, rhs.index_); + swap(lhs.raw_, rhs.raw_); + } + + private: + template + ABSL_ATTRIBUTE_ALWAYS_INLINE T* absl_nonnull At() + ABSL_ATTRIBUTE_LIFETIME_BOUND { + static_assert(alignof(T) <= kListValueVariantAlign); + static_assert(sizeof(T) <= kListValueVariantSize); + static_assert(std::is_trivially_copyable_v); + + return std::launder(reinterpret_cast(&raw_[0])); + } + + template + ABSL_ATTRIBUTE_ALWAYS_INLINE const T* absl_nonnull At() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + static_assert(alignof(T) <= kListValueVariantAlign); + static_assert(sizeof(T) <= kListValueVariantSize); + static_assert(std::is_trivially_copyable_v); + + return std::launder(reinterpret_cast(&raw_[0])); + } + + ListValueIndex index_ = ListValueIndex::kCustom; + alignas(8) std::byte raw_[kListValueVariantSize]; +}; + +} // namespace cel::common_internal + +#endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_LIST_VALUE_VARIANT_H_ diff --git a/common/values/map_value.cc b/common/values/map_value.cc new file mode 100644 index 000000000..c8bf7b785 --- /dev/null +++ b/common/values/map_value.cc @@ -0,0 +1,378 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/types/optional.h" +#include "common/native_type.h" +#include "common/optional_ref.h" +#include "common/value.h" +#include "common/value_kind.h" +#include "common/values/value_variant.h" +#include "internal/status_macros.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/message.h" + +namespace cel { + +namespace { + +absl::Status InvalidMapKeyTypeError(ValueKind kind) { + return absl::InvalidArgumentError( + absl::StrCat("Invalid map key type: '", ValueKindToString(kind), "'")); +} + +} // namespace + +NativeTypeId MapValue::GetTypeId() const { + return variant_.Visit([](const auto& alternative) -> NativeTypeId { + return NativeTypeId::Of(alternative); + }); +} + +std::string MapValue::DebugString() const { + return variant_.Visit([](const auto& alternative) -> std::string { + return alternative.DebugString(); + }); +} + +absl::Status MapValue::SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(output != nullptr); + + return variant_.Visit([&](const auto& alternative) -> absl::Status { + return alternative.SerializeTo(descriptor_pool, message_factory, output); + }); +} + +absl::Status MapValue::ConvertToJson( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(json != nullptr); + ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE); + + return variant_.Visit([&](const auto& alternative) -> absl::Status { + return alternative.ConvertToJson(descriptor_pool, message_factory, json); + }); +} + +absl::Status MapValue::ConvertToJsonObject( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(json != nullptr); + ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT); + + return variant_.Visit([&](const auto& alternative) -> absl::Status { + return alternative.ConvertToJsonObject(descriptor_pool, message_factory, + json); + }); +} + +absl::Status MapValue::Equal( + const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + return variant_.Visit([&](const auto& alternative) -> absl::Status { + return alternative.Equal(other, descriptor_pool, message_factory, arena, + result); + }); +} + +bool MapValue::IsZeroValue() const { + return variant_.Visit([](const auto& alternative) -> bool { + return alternative.IsZeroValue(); + }); +} + +absl::StatusOr MapValue::IsEmpty() const { + return variant_.Visit([](const auto& alternative) -> absl::StatusOr { + return alternative.IsEmpty(); + }); +} + +absl::StatusOr MapValue::Size() const { + return variant_.Visit([](const auto& alternative) -> absl::StatusOr { + return alternative.Size(); + }); +} + +absl::Status MapValue::Get( + const Value& key, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + return variant_.Visit([&](const auto& alternative) -> absl::Status { + return alternative.Get(key, descriptor_pool, message_factory, arena, + result); + }); +} + +absl::StatusOr MapValue::Find( + const Value& key, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + return variant_.Visit([&](const auto& alternative) -> absl::StatusOr { + return alternative.Find(key, descriptor_pool, message_factory, arena, + result); + }); +} + +absl::Status MapValue::Has( + const Value& key, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + return variant_.Visit([&](const auto& alternative) -> absl::Status { + return alternative.Has(key, descriptor_pool, message_factory, arena, + result); + }); +} + +absl::Status MapValue::ListKeys( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, ListValue* absl_nonnull result) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + return variant_.Visit([&](const auto& alternative) -> absl::Status { + return alternative.ListKeys(descriptor_pool, message_factory, arena, + result); + }); +} + +absl::Status MapValue::ForEach( + ForEachCallback callback, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + + return variant_.Visit([&](const auto& alternative) -> absl::Status { + return alternative.ForEach(callback, descriptor_pool, message_factory, + arena); + }); +} + +absl::StatusOr MapValue::NewIterator() const { + return variant_.Visit([](const auto& alternative) + -> absl::StatusOr { + return alternative.NewIterator(); + }); +} + +namespace common_internal { + +absl::Status MapValueEqual( + const MapValue& lhs, const MapValue& rhs, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + CEL_ASSIGN_OR_RETURN(auto lhs_size, lhs.Size()); + CEL_ASSIGN_OR_RETURN(auto rhs_size, rhs.Size()); + if (lhs_size != rhs_size) { + *result = FalseValue(); + return absl::OkStatus(); + } + CEL_ASSIGN_OR_RETURN(auto lhs_iterator, lhs.NewIterator()); + Value lhs_key; + Value lhs_value; + Value rhs_value; + for (size_t index = 0; index < lhs_size; ++index) { + ABSL_CHECK(lhs_iterator->HasNext()); // Crash OK + CEL_RETURN_IF_ERROR( + lhs_iterator->Next(descriptor_pool, message_factory, arena, &lhs_key)); + bool rhs_value_found; + CEL_ASSIGN_OR_RETURN( + rhs_value_found, + rhs.Find(lhs_key, descriptor_pool, message_factory, arena, &rhs_value)); + if (!rhs_value_found) { + *result = FalseValue(); + return absl::OkStatus(); + } + CEL_RETURN_IF_ERROR( + lhs.Get(lhs_key, descriptor_pool, message_factory, arena, &lhs_value)); + CEL_RETURN_IF_ERROR(lhs_value.Equal(rhs_value, descriptor_pool, + message_factory, arena, result)); + if (result->IsFalse()) { + return absl::OkStatus(); + } + } + ABSL_DCHECK(!lhs_iterator->HasNext()); + *result = TrueValue(); + return absl::OkStatus(); +} + +absl::Status MapValueEqual( + const CustomMapValueInterface& lhs, const MapValue& rhs, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + auto lhs_size = lhs.Size(); + CEL_ASSIGN_OR_RETURN(auto rhs_size, rhs.Size()); + if (lhs_size != rhs_size) { + *result = FalseValue(); + return absl::OkStatus(); + } + CEL_ASSIGN_OR_RETURN(auto lhs_iterator, lhs.NewIterator()); + Value lhs_key; + Value lhs_value; + Value rhs_value; + for (size_t index = 0; index < lhs_size; ++index) { + ABSL_CHECK(lhs_iterator->HasNext()); // Crash OK + CEL_RETURN_IF_ERROR( + lhs_iterator->Next(descriptor_pool, message_factory, arena, &lhs_key)); + bool rhs_value_found; + CEL_ASSIGN_OR_RETURN( + rhs_value_found, + rhs.Find(lhs_key, descriptor_pool, message_factory, arena, &rhs_value)); + if (!rhs_value_found) { + *result = FalseValue(); + return absl::OkStatus(); + } + CEL_RETURN_IF_ERROR( + CustomMapValue(&lhs, arena) + .Get(lhs_key, descriptor_pool, message_factory, arena, &lhs_value)); + CEL_RETURN_IF_ERROR(lhs_value.Equal(rhs_value, descriptor_pool, + message_factory, arena, result)); + if (result->IsFalse()) { + return absl::OkStatus(); + } + } + ABSL_DCHECK(!lhs_iterator->HasNext()); + *result = TrueValue(); + return absl::OkStatus(); +} + +} // namespace common_internal + +absl::Status CheckMapKey(const Value& key) { + switch (key.kind()) { + case ValueKind::kBool: + ABSL_FALLTHROUGH_INTENDED; + case ValueKind::kInt: + ABSL_FALLTHROUGH_INTENDED; + case ValueKind::kUint: + ABSL_FALLTHROUGH_INTENDED; + case ValueKind::kString: + return absl::OkStatus(); + case ValueKind::kError: + return key.GetError().NativeValue(); + default: + return InvalidMapKeyTypeError(key.kind()); + } +} + +optional_ref MapValue::AsCustom() const& { + if (const auto* alternative = variant_.As(); + alternative != nullptr) { + return *alternative; + } + return absl::nullopt; +} + +absl::optional MapValue::AsCustom() && { + if (auto* alternative = variant_.As(); + alternative != nullptr) { + return std::move(*alternative); + } + return absl::nullopt; +} + +const CustomMapValue& MapValue::GetCustom() const& { + ABSL_DCHECK(IsCustom()); + + return variant_.Get(); +} + +CustomMapValue MapValue::GetCustom() && { + ABSL_DCHECK(IsCustom()); + + return std::move(variant_).Get(); +} + +common_internal::ValueVariant MapValue::ToValueVariant() const& { + return variant_.Visit( + [](const auto& alternative) -> common_internal::ValueVariant { + return common_internal::ValueVariant(alternative); + }); +} + +common_internal::ValueVariant MapValue::ToValueVariant() && { + return std::move(variant_).Visit( + [](auto&& alternative) -> common_internal::ValueVariant { + // NOLINTNEXTLINE(bugprone-move-forwarding-reference) + return common_internal::ValueVariant(std::move(alternative)); + }); +} + +} // namespace cel diff --git a/common/values/map_value.h b/common/values/map_value.h new file mode 100644 index 000000000..b6e69ea57 --- /dev/null +++ b/common/values/map_value.h @@ -0,0 +1,323 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/value.h" +// IWYU pragma: friend "common/value.h" + +// `MapValue` represents values of the primitive `map` type. It provides a +// unified interface for accessing map contents, regardless of the underlying +// implementation (e.g., JSON, protobuf map field, or custom implementation). +// +// Public member functions: +// - `IsEmpty()` / `Size()`: Query map size. +// - `Get()` / `Find()` / `Has()`: Access entries by key. +// - `ListKeys()` / `NewIterator()` / `ForEach()`: Iterate over entries. +// - `ConvertToJson()` / `ConvertToJsonObject()`: JSON conversion. +// - `IsCustom()` / `AsCustom()` / `GetCustom()`: Access custom implementation. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_MAP_VALUE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_MAP_VALUE_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/meta/type_traits.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/utility/utility.h" +#include "common/native_type.h" +#include "common/optional_ref.h" +#include "common/value_kind.h" +#include "common/values/custom_map_value.h" +#include "common/values/legacy_map_value.h" +#include "common/values/map_value_variant.h" +#include "common/values/parsed_json_map_value.h" +#include "common/values/parsed_map_field_value.h" +#include "common/values/values.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/message.h" + +namespace cel { + +class MapValue; +class Value; + +absl::Status CheckMapKey(const Value& key); + +class MapValue final : private common_internal::MapValueMixin { + public: + static constexpr ValueKind kKind = ValueKind::kMap; + + // Move constructor for alternative struct values. + template >>> + // NOLINTNEXTLINE(google-explicit-constructor) + MapValue(T&& value) + : variant_(absl::in_place_type>, + std::forward(value)) {} + + MapValue() = default; + MapValue(const MapValue&) = default; + MapValue(MapValue&&) = default; + MapValue& operator=(const MapValue&) = default; + MapValue& operator=(MapValue&&) = default; + + constexpr ValueKind kind() const { return kKind; } + + static absl::string_view GetTypeName() { return "map"; } + + NativeTypeId GetTypeId() const; + + std::string DebugString() const; + + // See Value::SerializeTo(). + absl::Status SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const; + + // See Value::ConvertToJson(). + absl::Status ConvertToJson( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const; + + // Like ConvertToJson(), except `json` **MUST** be an instance of + // `google.protobuf.Struct`. + absl::Status ConvertToJsonObject( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const; + + absl::Status Equal(const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const; + using MapValueMixin::Equal; + + bool IsZeroValue() const; + + absl::StatusOr IsEmpty() const; + + absl::StatusOr Size() const; + + // `Get` sets the value `result` to (via `result`) the value associated with + // `key`. If `key` is not found, `no such key` is set to `result`. If an error + // occurs (e.g., invalid key type), an `no such key` is returned. + // + // A non-ok status may be returned if an unexpected error is encountered or to + // propagate an error from a custom implementation, in which case `result` is + // unspecified. + absl::Status Get(const Value& key, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const; + using MapValueMixin::Get; + + // `Find` returns `true` if `key` is found in the map, and stores the + // associated value in `result`. If `key` is not found, `false` is returned + // and `result` is unchanged. + // + // A non-ok status may be returned if an unexpected error is encountered or to + // propagate an error from a custom implementation, in which case `result` is + // unspecified. + absl::StatusOr Find( + const Value& key, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; + using MapValueMixin::Find; + + // `Has` returns `true` if `key` is found in the map, and stores the BoolValue + // result in `result`. In case of an error, the result is set to an + // ErrorValue. + // + // A non-ok status may be returned if an unexpected error is encountered or to + // propagate an error from a custom implementation, in which case `result` is + // unspecified. + absl::Status Has(const Value& key, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const; + using MapValueMixin::Has; + + // `ListKeys` returns a `ListValue` containing all keys in the map. + absl::Status ListKeys( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, ListValue* absl_nonnull result) const; + using MapValueMixin::ListKeys; + + // `ForEachCallback` is the callback type for `ForEach`. + using ForEachCallback = typename CustomMapValueInterface::ForEachCallback; + + // `ForEach` calls `callback` for each entry in the map. Iteration continues + // until all entries are visited or `callback` returns an error or `false`. + absl::Status ForEach( + ForEachCallback callback, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const; + + // `NewIterator` returns a new iterator for the map. + absl::StatusOr NewIterator() const; + + // Returns `true` if this value is an instance of a custom map value. + bool IsCustom() const { return variant_.Is(); } + + // Convenience method for use with template metaprogramming. See + // `IsCustom()`. + template + std::enable_if_t, bool> Is() const { + return IsCustom(); + } + + // Performs a checked cast from a value to a custom map value, + // returning a non-empty optional with either a value or reference to the + // custom map value. Otherwise an empty optional is returned. + optional_ref AsCustom() & + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return std::as_const(*this).AsCustom(); + } + optional_ref AsCustom() + const& ABSL_ATTRIBUTE_LIFETIME_BOUND; + absl::optional AsCustom() &&; + absl::optional AsCustom() const&& { + return common_internal::AsOptional(AsCustom()); + } + + // Convenience method for use with template metaprogramming. See + // `AsCustom()`. + template + std::enable_if_t, + optional_ref> + As() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + return AsCustom(); + } + template + std::enable_if_t, + optional_ref> + As() const& ABSL_ATTRIBUTE_LIFETIME_BOUND { + return AsCustom(); + } + template + std::enable_if_t, + absl::optional> + As() && { + return std::move(*this).AsCustom(); + } + template + std::enable_if_t, + absl::optional> + As() const&& { + return std::move(*this).AsCustom(); + } + + // Performs an unchecked cast from a value to a custom map value. In + // debug builds a best effort is made to crash. If `IsCustom()` would + // return false, calling this method is undefined behavior. + const CustomMapValue& GetCustom() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + return std::as_const(*this).GetCustom(); + } + const CustomMapValue& GetCustom() const& ABSL_ATTRIBUTE_LIFETIME_BOUND; + CustomMapValue GetCustom() &&; + CustomMapValue GetCustom() const&& { return GetCustom(); } + + // Convenience method for use with template metaprogramming. See + // `GetCustom()`. + template + std::enable_if_t, const CustomMapValue&> + Get() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + return GetCustom(); + } + template + std::enable_if_t, const CustomMapValue&> + Get() const& ABSL_ATTRIBUTE_LIFETIME_BOUND { + return GetCustom(); + } + template + std::enable_if_t, CustomMapValue> Get() && { + return std::move(*this).GetCustom(); + } + template + std::enable_if_t, CustomMapValue> Get() + const&& { + return std::move(*this).GetCustom(); + } + + friend void swap(MapValue& lhs, MapValue& rhs) noexcept { + using std::swap; + swap(lhs.variant_, rhs.variant_); + } + + private: + friend class Value; + friend class common_internal::ValueMixin; + friend class common_internal::MapValueMixin; + + common_internal::ValueVariant ToValueVariant() const&; + common_internal::ValueVariant ToValueVariant() &&; + + // Unlike many of the other derived values, `MapValue` is itself a composed + // type. This is to avoid making `MapValue` too big and by extension + // `Value` too big. Instead we store the derived `MapValue` values in + // `Value` and not `MapValue` itself. + common_internal::MapValueVariant variant_; +}; + +inline std::ostream& operator<<(std::ostream& out, const MapValue& value) { + return out << value.DebugString(); +} + +template <> +struct NativeTypeTraits final { + static NativeTypeId Id(const MapValue& value) { return value.GetTypeId(); } +}; + +class MapValueBuilder { + public: + virtual ~MapValueBuilder() = default; + + virtual absl::Status Put(Value key, Value value) = 0; + + virtual void UnsafePut(Value key, Value value) = 0; + + virtual bool IsEmpty() const { return Size() == 0; } + + virtual size_t Size() const = 0; + + virtual void Reserve(size_t capacity [[maybe_unused]]) {} + + virtual MapValue Build() && = 0; +}; + +using MapValueBuilderPtr = std::unique_ptr; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_MAP_VALUE_H_ diff --git a/common/values/map_value_builder.h b/common/values/map_value_builder.h new file mode 100644 index 000000000..a5a47eda9 --- /dev/null +++ b/common/values/map_value_builder.h @@ -0,0 +1,110 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_MAP_VALUE_BUILDER_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_MAP_VALUE_BUILDER_H_ + +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "common/native_type.h" +#include "common/value.h" +#include "eval/public/cel_value.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel { + +class ValueFactory; + +namespace common_internal { + +// Special implementation of map which is both a modern map and legacy map. Do +// not try this at home. This should only be implemented in +// `map_value_builder.cc`. +class CompatMapValue : public CustomMapValueInterface, + public google::api::expr::runtime::CelMap { + private: + NativeTypeId GetNativeTypeId() const final { + return NativeTypeId::For(); + } +}; + +const CompatMapValue* absl_nonnull EmptyCompatMapValue(); + +absl::StatusOr MakeCompatMapValue( + const CustomMapValue& value, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena); + +// Extension of ParsedMapValueInterface which is also mutable. Accessing this +// like a normal map before all entries are finished being inserted is a bug. +// This is primarily used by the runtime to efficiently implement comprehensions +// which accumulate results into a map. +// +// IMPORTANT: This type is only meant to be utilized by the runtime. +class MutableMapValue : public CustomMapValueInterface { + public: + virtual absl::Status Put(Value key, Value value) const = 0; + + virtual void Reserve(size_t capacity) const {} + + private: + NativeTypeId GetNativeTypeId() const override { + return NativeTypeId::For(); + } +}; + +// Special implementation of map which is both a modern map, legacy map, and +// mutable. +// +// NOTE: We do not extend CompatMapValue to avoid having to use virtual +// inheritance and `dynamic_cast`. +class MutableCompatMapValue : public MutableMapValue, + public google::api::expr::runtime::CelMap { + private: + NativeTypeId GetNativeTypeId() const final { + return NativeTypeId::For(); + } +}; + +MutableMapValue* absl_nonnull NewMutableMapValue( + google::protobuf::Arena* absl_nonnull arena); + +bool IsMutableMapValue(const Value& value); +bool IsMutableMapValue(const MapValue& value); + +const MutableMapValue* absl_nullable AsMutableMapValue( + const Value& value ABSL_ATTRIBUTE_LIFETIME_BOUND); +const MutableMapValue* absl_nullable AsMutableMapValue( + const MapValue& value ABSL_ATTRIBUTE_LIFETIME_BOUND); + +const MutableMapValue& GetMutableMapValue( + const Value& value ABSL_ATTRIBUTE_LIFETIME_BOUND); +const MutableMapValue& GetMutableMapValue( + const MapValue& value ABSL_ATTRIBUTE_LIFETIME_BOUND); + +absl_nonnull cel::MapValueBuilderPtr NewMapValueBuilder( + google::protobuf::Arena* absl_nonnull arena); + +} // namespace common_internal + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_MAP_VALUE_BUILDER_H_ diff --git a/common/values/map_value_test.cc b/common/values/map_value_test.cc new file mode 100644 index 000000000..f7d1c5197 --- /dev/null +++ b/common/values/map_value_test.cc @@ -0,0 +1,297 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/types/optional.h" +#include "common/casting.h" +#include "common/value.h" +#include "common/value_testing.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::IsOkAndHolds; +using ::absl_testing::StatusIs; +using ::cel::test::ErrorValueIs; +using ::testing::IsEmpty; +using ::testing::Not; +using ::testing::UnorderedElementsAreArray; + +TEST(MapValue, CheckKey) { + EXPECT_THAT(CheckMapKey(BoolValue()), IsOk()); + EXPECT_THAT(CheckMapKey(IntValue()), IsOk()); + EXPECT_THAT(CheckMapKey(UintValue()), IsOk()); + EXPECT_THAT(CheckMapKey(StringValue()), IsOk()); + EXPECT_THAT(CheckMapKey(BytesValue()), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +class MapValueTest : public common_internal::ValueTest<> { + public: + template + absl::StatusOr NewIntDoubleMapValue(Args&&... args) { + auto builder = NewMapValueBuilder(arena()); + (static_cast(builder->Put(std::forward(args).first, + std::forward(args).second)), + ...); + return std::move(*builder).Build(); + } + + template + absl::StatusOr NewJsonMapValue(Args&&... args) { + auto builder = NewMapValueBuilder(arena()); + (static_cast(builder->Put(std::forward(args).first, + std::forward(args).second)), + ...); + return std::move(*builder).Build(); + } +}; + +TEST_F(MapValueTest, Default) { + MapValue map_value; + EXPECT_THAT(map_value.IsEmpty(), IsOkAndHolds(true)); + EXPECT_THAT(map_value.Size(), IsOkAndHolds(0)); + EXPECT_EQ(map_value.DebugString(), "{}"); + ASSERT_OK_AND_ASSIGN( + auto list_value, + map_value.ListKeys(descriptor_pool(), message_factory(), arena())); + EXPECT_THAT(list_value.IsEmpty(), IsOkAndHolds(true)); + EXPECT_THAT(list_value.Size(), IsOkAndHolds(0)); + EXPECT_EQ(list_value.DebugString(), "[]"); + ASSERT_OK_AND_ASSIGN(auto iterator, map_value.NewIterator()); + EXPECT_FALSE(iterator->HasNext()); + EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()), + StatusIs(absl::StatusCode::kFailedPrecondition)); +} + +TEST_F(MapValueTest, Kind) { + ASSERT_OK_AND_ASSIGN( + auto value, + NewIntDoubleMapValue(std::pair{IntValue(0), DoubleValue(3.0)}, + std::pair{IntValue(1), DoubleValue(4.0)}, + std::pair{IntValue(2), DoubleValue(5.0)})); + EXPECT_EQ(value.kind(), MapValue::kKind); + EXPECT_EQ(Value(value).kind(), MapValue::kKind); +} + +TEST_F(MapValueTest, DebugString) { + ASSERT_OK_AND_ASSIGN( + auto value, + NewIntDoubleMapValue(std::pair{IntValue(0), DoubleValue(3.0)}, + std::pair{IntValue(1), DoubleValue(4.0)}, + std::pair{IntValue(2), DoubleValue(5.0)})); + { + std::ostringstream out; + out << value; + EXPECT_THAT(out.str(), Not(IsEmpty())); + } + { + std::ostringstream out; + out << Value(value); + EXPECT_THAT(out.str(), Not(IsEmpty())); + } +} + +TEST_F(MapValueTest, IsEmpty) { + ASSERT_OK_AND_ASSIGN( + auto value, + NewIntDoubleMapValue(std::pair{IntValue(0), DoubleValue(3.0)}, + std::pair{IntValue(1), DoubleValue(4.0)}, + std::pair{IntValue(2), DoubleValue(5.0)})); + EXPECT_THAT(value.IsEmpty(), IsOkAndHolds(false)); +} + +TEST_F(MapValueTest, Size) { + ASSERT_OK_AND_ASSIGN( + auto value, + NewIntDoubleMapValue(std::pair{IntValue(0), DoubleValue(3.0)}, + std::pair{IntValue(1), DoubleValue(4.0)}, + std::pair{IntValue(2), DoubleValue(5.0)})); + EXPECT_THAT(value.Size(), IsOkAndHolds(3)); +} + +TEST_F(MapValueTest, Get) { + ASSERT_OK_AND_ASSIGN( + auto map_value, + NewIntDoubleMapValue(std::pair{IntValue(0), DoubleValue(3.0)}, + std::pair{IntValue(1), DoubleValue(4.0)}, + std::pair{IntValue(2), DoubleValue(5.0)})); + ASSERT_OK_AND_ASSIGN(auto value, map_value.Get(IntValue(0), descriptor_pool(), + message_factory(), arena())); + ASSERT_TRUE(InstanceOf(value)); + ASSERT_EQ(Cast(value).NativeValue(), 3.0); + ASSERT_OK_AND_ASSIGN(value, map_value.Get(IntValue(1), descriptor_pool(), + message_factory(), arena())); + ASSERT_TRUE(InstanceOf(value)); + ASSERT_EQ(Cast(value).NativeValue(), 4.0); + ASSERT_OK_AND_ASSIGN(value, map_value.Get(IntValue(2), descriptor_pool(), + message_factory(), arena())); + ASSERT_TRUE(InstanceOf(value)); + ASSERT_EQ(Cast(value).NativeValue(), 5.0); + EXPECT_THAT( + map_value.Get(IntValue(3), descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(ErrorValueIs(StatusIs(absl::StatusCode::kNotFound)))); +} + +TEST_F(MapValueTest, Find) { + ASSERT_OK_AND_ASSIGN( + auto map_value, + NewIntDoubleMapValue(std::pair{IntValue(0), DoubleValue(3.0)}, + std::pair{IntValue(1), DoubleValue(4.0)}, + std::pair{IntValue(2), DoubleValue(5.0)})); + absl::optional entry; + ASSERT_OK_AND_ASSIGN(entry, map_value.Find(IntValue(0), descriptor_pool(), + message_factory(), arena())); + ASSERT_TRUE(entry); + ASSERT_TRUE(InstanceOf(*entry)); + ASSERT_EQ(Cast(*entry).NativeValue(), 3.0); + ASSERT_OK_AND_ASSIGN(entry, map_value.Find(IntValue(1), descriptor_pool(), + message_factory(), arena())); + ASSERT_TRUE(entry); + ASSERT_TRUE(InstanceOf(*entry)); + ASSERT_EQ(Cast(*entry).NativeValue(), 4.0); + ASSERT_OK_AND_ASSIGN(entry, map_value.Find(IntValue(2), descriptor_pool(), + message_factory(), arena())); + ASSERT_TRUE(entry); + ASSERT_TRUE(InstanceOf(*entry)); + ASSERT_EQ(Cast(*entry).NativeValue(), 5.0); + ASSERT_OK_AND_ASSIGN(entry, map_value.Find(IntValue(3), descriptor_pool(), + message_factory(), arena())); + ASSERT_FALSE(entry); +} + +TEST_F(MapValueTest, Has) { + ASSERT_OK_AND_ASSIGN( + auto map_value, + NewIntDoubleMapValue(std::pair{IntValue(0), DoubleValue(3.0)}, + std::pair{IntValue(1), DoubleValue(4.0)}, + std::pair{IntValue(2), DoubleValue(5.0)})); + ASSERT_OK_AND_ASSIGN(auto value, map_value.Has(IntValue(0), descriptor_pool(), + message_factory(), arena())); + ASSERT_TRUE(InstanceOf(value)); + ASSERT_TRUE(Cast(value).NativeValue()); + ASSERT_OK_AND_ASSIGN(value, map_value.Has(IntValue(1), descriptor_pool(), + message_factory(), arena())); + ASSERT_TRUE(InstanceOf(value)); + ASSERT_TRUE(Cast(value).NativeValue()); + ASSERT_OK_AND_ASSIGN(value, map_value.Has(IntValue(2), descriptor_pool(), + message_factory(), arena())); + ASSERT_TRUE(InstanceOf(value)); + ASSERT_TRUE(Cast(value).NativeValue()); + ASSERT_OK_AND_ASSIGN(value, map_value.Has(IntValue(3), descriptor_pool(), + message_factory(), arena())); + ASSERT_TRUE(InstanceOf(value)); + ASSERT_FALSE(Cast(value).NativeValue()); +} + +TEST_F(MapValueTest, ListKeys) { + ASSERT_OK_AND_ASSIGN( + auto map_value, + NewIntDoubleMapValue(std::pair{IntValue(0), DoubleValue(3.0)}, + std::pair{IntValue(1), DoubleValue(4.0)}, + std::pair{IntValue(2), DoubleValue(5.0)})); + ASSERT_OK_AND_ASSIGN( + auto list_keys, + map_value.ListKeys(descriptor_pool(), message_factory(), arena())); + std::vector keys; + ASSERT_THAT(list_keys.ForEach( + [&keys](const Value& element) -> bool { + keys.push_back(Cast(element).NativeValue()); + return true; + }, + descriptor_pool(), message_factory(), arena()), + IsOk()); + EXPECT_THAT(keys, UnorderedElementsAreArray({0, 1, 2})); +} + +TEST_F(MapValueTest, ForEach) { + ASSERT_OK_AND_ASSIGN( + auto value, + NewIntDoubleMapValue(std::pair{IntValue(0), DoubleValue(3.0)}, + std::pair{IntValue(1), DoubleValue(4.0)}, + std::pair{IntValue(2), DoubleValue(5.0)})); + std::vector> entries; + EXPECT_THAT(value.ForEach( + [&entries](const Value& key, const Value& value) { + entries.push_back( + std::pair{Cast(key).NativeValue(), + Cast(value).NativeValue()}); + return true; + }, + descriptor_pool(), message_factory(), arena()), + IsOk()); + EXPECT_THAT(entries, + UnorderedElementsAreArray( + {std::pair{0, 3.0}, std::pair{1, 4.0}, std::pair{2, 5.0}})); +} + +TEST_F(MapValueTest, NewIterator) { + ASSERT_OK_AND_ASSIGN( + auto value, + NewIntDoubleMapValue(std::pair{IntValue(0), DoubleValue(3.0)}, + std::pair{IntValue(1), DoubleValue(4.0)}, + std::pair{IntValue(2), DoubleValue(5.0)})); + ASSERT_OK_AND_ASSIGN(auto iterator, value.NewIterator()); + std::vector keys; + while (iterator->HasNext()) { + ASSERT_OK_AND_ASSIGN( + auto element, + iterator->Next(descriptor_pool(), message_factory(), arena())); + ASSERT_TRUE(InstanceOf(element)); + keys.push_back(Cast(element).NativeValue()); + } + EXPECT_EQ(iterator->HasNext(), false); + EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()), + StatusIs(absl::StatusCode::kFailedPrecondition)); + EXPECT_THAT(keys, UnorderedElementsAreArray({0, 1, 2})); +} + +TEST_F(MapValueTest, ConvertToJson) { + ASSERT_OK_AND_ASSIGN( + auto value, + NewJsonMapValue(std::pair{StringValue("0"), DoubleValue(3.0)}, + std::pair{StringValue("1"), DoubleValue(4.0)}, + std::pair{StringValue("2"), DoubleValue(5.0)})); + auto* message = NewArenaValueMessage(); + EXPECT_THAT( + value.ConvertToJson(descriptor_pool(), message_factory(), message), + IsOk()); + EXPECT_THAT(*message, EqualsValueTextProto(R"pb(struct_value: { + fields: { + key: "0" + value: { number_value: 3 } + } + fields: { + key: "1" + value: { number_value: 4 } + } + fields: { + key: "2" + value: { number_value: 5 } + } + })pb")); +} + +} // namespace +} // namespace cel diff --git a/common/values/map_value_variant.h b/common/values/map_value_variant.h new file mode 100644 index 000000000..e7cf5b6b7 --- /dev/null +++ b/common/values/map_value_variant.h @@ -0,0 +1,212 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_MAP_VALUE_VARIANT_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_MAP_VALUE_VARIANT_H_ + +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/meta/type_traits.h" +#include "absl/utility/utility.h" +#include "common/values/custom_map_value.h" +#include "common/values/legacy_map_value.h" +#include "common/values/parsed_json_map_value.h" +#include "common/values/parsed_map_field_value.h" + +namespace cel::common_internal { + +enum class MapValueIndex : uint16_t { + kCustom = 0, + kParsedField, + kParsedJson, + kLegacy, +}; + +template +struct MapValueAlternative; + +template <> +struct MapValueAlternative { + static constexpr MapValueIndex kIndex = MapValueIndex::kCustom; +}; + +template <> +struct MapValueAlternative { + static constexpr MapValueIndex kIndex = MapValueIndex::kParsedField; +}; + +template <> +struct MapValueAlternative { + static constexpr MapValueIndex kIndex = MapValueIndex::kParsedJson; +}; + +template <> +struct MapValueAlternative { + static constexpr MapValueIndex kIndex = MapValueIndex::kLegacy; +}; + +template +struct IsMapValueAlternative : std::false_type {}; + +template +struct IsMapValueAlternative{})>> + : std::true_type {}; + +template +inline constexpr bool IsMapValueAlternativeV = IsMapValueAlternative::value; + +inline constexpr size_t kMapValueVariantAlign = 8; +inline constexpr size_t kMapValueVariantSize = 24; + +// MapValueVariant is a subset of alternatives from the main ValueVariant that +// is only maps. It is not stored directly in ValueVariant. +class alignas(kMapValueVariantAlign) MapValueVariant final { + public: + MapValueVariant() : MapValueVariant(absl::in_place_type) {} + + MapValueVariant(const MapValueVariant&) = default; + MapValueVariant(MapValueVariant&&) = default; + MapValueVariant& operator=(const MapValueVariant&) = default; + MapValueVariant& operator=(MapValueVariant&&) = default; + + template + explicit MapValueVariant(absl::in_place_type_t, Args&&... args) + : index_(MapValueAlternative::kIndex) { + static_assert(alignof(T) <= kMapValueVariantAlign); + static_assert(sizeof(T) <= kMapValueVariantSize); + static_assert(std::is_trivially_copyable_v); + + ::new (static_cast(&raw_[0])) T(std::forward(args)...); + } + + template >>> + explicit MapValueVariant(T&& value) + : MapValueVariant(absl::in_place_type>, + std::forward(value)) {} + + template + void Assign(T&& value) { + using U = absl::remove_cvref_t; + + static_assert(alignof(U) <= kMapValueVariantAlign); + static_assert(sizeof(U) <= kMapValueVariantSize); + static_assert(std::is_trivially_copyable_v); + + index_ = MapValueAlternative::kIndex; + ::new (static_cast(&raw_[0])) U(std::forward(value)); + } + + template + bool Is() const { + return index_ == MapValueAlternative::kIndex; + } + + template + T& Get() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + ABSL_DCHECK(Is()); + + return *At(); + } + + template + const T& Get() const& ABSL_ATTRIBUTE_LIFETIME_BOUND { + ABSL_DCHECK(Is()); + + return *At(); + } + + template + T&& Get() && ABSL_ATTRIBUTE_LIFETIME_BOUND { + ABSL_DCHECK(Is()); + + return std::move(*At()); + } + + template + const T&& Get() const&& ABSL_ATTRIBUTE_LIFETIME_BOUND { + ABSL_DCHECK(Is()); + + return std::move(*At()); + } + + template + T* absl_nullable As() ABSL_ATTRIBUTE_LIFETIME_BOUND { + if (Is()) { + return At(); + } + return nullptr; + } + + template + const T* absl_nullable As() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + if (Is()) { + return At(); + } + return nullptr; + } + + template + decltype(auto) Visit(Visitor&& visitor) const { + switch (index_) { + case MapValueIndex::kCustom: + return std::forward(visitor)(Get()); + case MapValueIndex::kParsedField: + return std::forward(visitor)(Get()); + case MapValueIndex::kParsedJson: + return std::forward(visitor)(Get()); + case MapValueIndex::kLegacy: + return std::forward(visitor)(Get()); + } + } + + friend void swap(MapValueVariant& lhs, MapValueVariant& rhs) noexcept { + using std::swap; + swap(lhs.index_, rhs.index_); + swap(lhs.raw_, rhs.raw_); + } + + private: + template + ABSL_ATTRIBUTE_ALWAYS_INLINE T* absl_nonnull At() + ABSL_ATTRIBUTE_LIFETIME_BOUND { + static_assert(alignof(T) <= kMapValueVariantAlign); + static_assert(sizeof(T) <= kMapValueVariantSize); + static_assert(std::is_trivially_copyable_v); + + return std::launder(reinterpret_cast(&raw_[0])); + } + + template + ABSL_ATTRIBUTE_ALWAYS_INLINE const T* absl_nonnull At() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + static_assert(alignof(T) <= kMapValueVariantAlign); + static_assert(sizeof(T) <= kMapValueVariantSize); + static_assert(std::is_trivially_copyable_v); + + return std::launder(reinterpret_cast(&raw_[0])); + } + + MapValueIndex index_ = MapValueIndex::kCustom; + alignas(8) std::byte raw_[kMapValueVariantSize]; +}; + +} // namespace cel::common_internal + +#endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_MAP_VALUE_VARIANT_H_ diff --git a/common/values/message_value.cc b/common/values/message_value.cc new file mode 100644 index 000000000..66dfd9511 --- /dev/null +++ b/common/values/message_value.cc @@ -0,0 +1,306 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/values/message_value.h" + +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/base/optimization.h" +#include "absl/functional/overload.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "absl/types/variant.h" +#include "base/attribute.h" +#include "common/optional_ref.h" +#include "common/value.h" +#include "common/values/parsed_message_value.h" +#include "common/values/value_variant.h" +#include "common/values/values.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/message.h" + +namespace cel { + +const google::protobuf::Descriptor* absl_nonnull MessageValue::GetDescriptor() const { + ABSL_CHECK(*this); // Crash OK + return absl::visit( + absl::Overload( + [](std::monostate) -> const google::protobuf::Descriptor* absl_nonnull { + ABSL_UNREACHABLE(); + }, + [](const ParsedMessageValue& alternative) + -> const google::protobuf::Descriptor* absl_nonnull { + return alternative.GetDescriptor(); + }), + variant_); +} + +std::string MessageValue::DebugString() const { + return absl::visit( + absl::Overload([](std::monostate) -> std::string { return "INVALID"; }, + [](const ParsedMessageValue& alternative) -> std::string { + return alternative.DebugString(); + }), + variant_); +} + +bool MessageValue::IsZeroValue() const { + ABSL_DCHECK(*this); + return absl::visit( + absl::Overload([](std::monostate) -> bool { return true; }, + [](const ParsedMessageValue& alternative) -> bool { + return alternative.IsZeroValue(); + }), + variant_); +} + +absl::Status MessageValue::SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const { + return absl::visit( + absl::Overload( + [](std::monostate) -> absl::Status { + return absl::InternalError( + "unexpected attempt to invoke `ConvertToJson` on " + "an invalid `MessageValue`"); + }, + [&](const ParsedMessageValue& alternative) -> absl::Status { + return alternative.SerializeTo(descriptor_pool, message_factory, + output); + }), + variant_); +} + +absl::Status MessageValue::ConvertToJson( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const { + return absl::visit( + absl::Overload( + [](std::monostate) -> absl::Status { + return absl::InternalError( + "unexpected attempt to invoke `ConvertToJson` on " + "an invalid `MessageValue`"); + }, + [&](const ParsedMessageValue& alternative) -> absl::Status { + return alternative.ConvertToJson(descriptor_pool, message_factory, + json); + }), + variant_); +} + +absl::Status MessageValue::ConvertToJsonObject( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const { + return absl::visit( + absl::Overload( + [](std::monostate) -> absl::Status { + return absl::InternalError( + "unexpected attempt to invoke `ConvertToJsonObject` on " + "an invalid `MessageValue`"); + }, + [&](const ParsedMessageValue& alternative) -> absl::Status { + return alternative.ConvertToJsonObject(descriptor_pool, + message_factory, json); + }), + variant_); +} + +absl::Status MessageValue::Equal( + const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + return absl::visit( + absl::Overload( + [](std::monostate) -> absl::Status { + return absl::InternalError( + "unexpected attempt to invoke `Equal` on " + "an invalid `MessageValue`"); + }, + [&](const ParsedMessageValue& alternative) -> absl::Status { + return alternative.Equal(other, descriptor_pool, message_factory, + arena, result); + }), + variant_); +} + +absl::Status MessageValue::GetFieldByName( + absl::string_view name, ProtoWrapperTypeOptions unboxing_options, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + return absl::visit( + absl::Overload( + [](std::monostate) -> absl::Status { + return absl::InternalError( + "unexpected attempt to invoke `GetFieldByName` on " + "an invalid `MessageValue`"); + }, + [&](const ParsedMessageValue& alternative) -> absl::Status { + return alternative.GetFieldByName(name, unboxing_options, + descriptor_pool, message_factory, + arena, result); + }), + variant_); +} + +absl::Status MessageValue::GetFieldByNumber( + int64_t number, ProtoWrapperTypeOptions unboxing_options, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + return absl::visit( + absl::Overload( + [](std::monostate) -> absl::Status { + return absl::InternalError( + "unexpected attempt to invoke `GetFieldByNumber` on " + "an invalid `MessageValue`"); + }, + [&](const ParsedMessageValue& alternative) -> absl::Status { + return alternative.GetFieldByNumber(number, unboxing_options, + descriptor_pool, + message_factory, arena, result); + }), + variant_); +} + +absl::StatusOr MessageValue::HasFieldByName( + absl::string_view name) const { + return absl::visit( + absl::Overload( + [](std::monostate) -> absl::StatusOr { + return absl::InternalError( + "unexpected attempt to invoke `HasFieldByName` on " + "an invalid `MessageValue`"); + }, + [&](const ParsedMessageValue& alternative) -> absl::StatusOr { + return alternative.HasFieldByName(name); + }), + variant_); +} + +absl::StatusOr MessageValue::HasFieldByNumber(int64_t number) const { + return absl::visit( + absl::Overload( + [](std::monostate) -> absl::StatusOr { + return absl::InternalError( + "unexpected attempt to invoke `HasFieldByNumber` on " + "an invalid `MessageValue`"); + }, + [&](const ParsedMessageValue& alternative) -> absl::StatusOr { + return alternative.HasFieldByNumber(number); + }), + variant_); +} + +absl::Status MessageValue::ForEachField( + ForEachFieldCallback callback, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const { + return absl::visit( + absl::Overload( + [](std::monostate) -> absl::Status { + return absl::InternalError( + "unexpected attempt to invoke `ForEachField` on " + "an invalid `MessageValue`"); + }, + [&](const ParsedMessageValue& alternative) -> absl::Status { + return alternative.ForEachField(callback, descriptor_pool, + message_factory, arena); + }), + variant_); +} + +absl::Status MessageValue::Qualify( + absl::Span qualifiers, bool presence_test, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result, + int* absl_nonnull count) const { + return absl::visit( + absl::Overload( + [](std::monostate) -> absl::Status { + return absl::InternalError( + "unexpected attempt to invoke `Qualify` on " + "an invalid `MessageValue`"); + }, + [&](const ParsedMessageValue& alternative) -> absl::Status { + return alternative.Qualify(qualifiers, presence_test, + descriptor_pool, message_factory, arena, + result, count); + }), + variant_); +} + +cel::optional_ref MessageValue::AsParsed() const& { + if (const auto* alternative = absl::get_if(&variant_); + alternative != nullptr) { + return *alternative; + } + return absl::nullopt; +} + +absl::optional MessageValue::AsParsed() && { + if (auto* alternative = absl::get_if(&variant_); + alternative != nullptr) { + return std::move(*alternative); + } + return absl::nullopt; +} + +const ParsedMessageValue& MessageValue::GetParsed() const& { + ABSL_DCHECK(IsParsed()); + return absl::get(variant_); +} + +ParsedMessageValue MessageValue::GetParsed() && { + ABSL_DCHECK(IsParsed()); + return absl::get(std::move(variant_)); +} + +common_internal::ValueVariant MessageValue::ToValueVariant() const& { + return common_internal::ValueVariant(absl::get(variant_)); +} + +common_internal::ValueVariant MessageValue::ToValueVariant() && { + return common_internal::ValueVariant( + absl::get(std::move(variant_))); +} + +common_internal::StructValueVariant MessageValue::ToStructValueVariant() + const& { + return common_internal::StructValueVariant( + absl::get(variant_)); +} + +common_internal::StructValueVariant MessageValue::ToStructValueVariant() && { + return common_internal::StructValueVariant( + absl::get(std::move(variant_))); +} + +} // namespace cel diff --git a/common/values/message_value.h b/common/values/message_value.h new file mode 100644 index 000000000..480cdcc82 --- /dev/null +++ b/common/values/message_value.h @@ -0,0 +1,268 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/value.h" +// IWYU pragma: friend "common/value.h" + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_MESSAGE_VALUE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_MESSAGE_VALUE_H_ + +#include +#include +#include +#include +#include + +#include "google/protobuf/any.pb.h" +#include "google/protobuf/duration.pb.h" +#include "google/protobuf/struct.pb.h" +#include "google/protobuf/timestamp.pb.h" +#include "google/protobuf/wrappers.pb.h" +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "absl/types/variant.h" +#include "absl/utility/utility.h" +#include "base/attribute.h" +#include "common/arena.h" +#include "common/optional_ref.h" +#include "common/type.h" +#include "common/value_kind.h" +#include "common/values/custom_struct_value.h" +#include "common/values/parsed_message_value.h" +#include "common/values/values.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/message.h" + +namespace cel { + +class Value; +class StructValue; + +class MessageValue final + : private common_internal::StructValueMixin { + public: + static constexpr ValueKind kKind = ValueKind::kStruct; + + // NOLINTNEXTLINE(google-explicit-constructor) + MessageValue(const ParsedMessageValue& other) + : variant_(absl::in_place_type, other) {} + + // NOLINTNEXTLINE(google-explicit-constructor) + MessageValue(ParsedMessageValue&& other) + : variant_(absl::in_place_type, std::move(other)) {} + + // Places the `MessageValue` into an unspecified state. Anything except + // assigning to `MessageValue` is undefined behavior. + MessageValue() = default; + MessageValue(const MessageValue&) = default; + MessageValue(MessageValue&&) = default; + MessageValue& operator=(const MessageValue&) = default; + MessageValue& operator=(MessageValue&&) = default; + + static ValueKind kind() { return kKind; } + + absl::string_view GetTypeName() const { return GetDescriptor()->full_name(); } + + MessageType GetRuntimeType() const { return MessageType(GetDescriptor()); } + + const google::protobuf::Descriptor* absl_nonnull GetDescriptor() const; + + bool IsZeroValue() const; + + std::string DebugString() const; + + // See Value::SerializeTo(). + absl::Status SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const; + + // See Value::ConvertToJson(). + absl::Status ConvertToJson( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const; + + // See Value::ConvertToJsonObject(). + absl::Status ConvertToJsonObject( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const; + + absl::Status Equal(const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const; + using StructValueMixin::Equal; + + absl::Status GetFieldByName( + absl::string_view name, ProtoWrapperTypeOptions unboxing_options, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; + using StructValueMixin::GetFieldByName; + + absl::Status GetFieldByNumber( + int64_t number, ProtoWrapperTypeOptions unboxing_options, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; + using StructValueMixin::GetFieldByNumber; + + absl::StatusOr HasFieldByName(absl::string_view name) const; + + absl::StatusOr HasFieldByNumber(int64_t number) const; + + using ForEachFieldCallback = CustomStructValueInterface::ForEachFieldCallback; + + absl::Status ForEachField( + ForEachFieldCallback callback, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const; + + absl::Status Qualify( + absl::Span qualifiers, bool presence_test, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result, + int* absl_nonnull count) const; + using StructValueMixin::Qualify; + + bool IsParsed() const { + return absl::holds_alternative(variant_); + } + + template + std::enable_if_t, bool> Is() const { + return IsParsed(); + } + + cel::optional_ref AsParsed() & + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return std::as_const(*this).AsParsed(); + } + cel::optional_ref AsParsed() + const& ABSL_ATTRIBUTE_LIFETIME_BOUND; + absl::optional AsParsed() &&; + absl::optional AsParsed() const&& { + return common_internal::AsOptional(AsParsed()); + } + + template + std::enable_if_t, + cel::optional_ref> + As() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + return AsParsed(); + } + template + std::enable_if_t, + cel::optional_ref> + As() const& ABSL_ATTRIBUTE_LIFETIME_BOUND { + return IsParsed(); + } + template + std::enable_if_t, + absl::optional> + As() && ABSL_ATTRIBUTE_LIFETIME_BOUND { + return std::move(*this).AsParsed(); + } + template + std::enable_if_t, + absl::optional> + As() const&& ABSL_ATTRIBUTE_LIFETIME_BOUND { + return std::move(*this).AsParsed(); + } + + const ParsedMessageValue& GetParsed() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + return std::as_const(*this).GetParsed(); + } + const ParsedMessageValue& GetParsed() const& ABSL_ATTRIBUTE_LIFETIME_BOUND; + ParsedMessageValue GetParsed() &&; + ParsedMessageValue GetParsed() const&& { return GetParsed(); } + + template + std::enable_if_t, + const ParsedMessageValue&> + Get() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + return GetParsed(); + } + template + std::enable_if_t, + const ParsedMessageValue&> + Get() const& ABSL_ATTRIBUTE_LIFETIME_BOUND { + return GetParsed(); + } + template + std::enable_if_t, ParsedMessageValue> + Get() && { + return std::move(*this).GetParsed(); + } + template + std::enable_if_t, ParsedMessageValue> + Get() const&& { + return std::move(*this).GetParsed(); + } + + explicit operator bool() const { + return !absl::holds_alternative(variant_); + } + + friend void swap(MessageValue& lhs, MessageValue& rhs) noexcept { + lhs.variant_.swap(rhs.variant_); + } + + private: + friend class Value; + friend class StructValue; + friend class common_internal::ValueMixin; + friend class common_internal::StructValueMixin; + friend struct ArenaTraits; + + common_internal::ValueVariant ToValueVariant() const&; + common_internal::ValueVariant ToValueVariant() &&; + + common_internal::StructValueVariant ToStructValueVariant() const&; + common_internal::StructValueVariant ToStructValueVariant() &&; + + absl::variant variant_; +}; + +inline std::ostream& operator<<(std::ostream& out, const MessageValue& value) { + return out << value.DebugString(); +} + +template <> +struct ArenaTraits { + static bool trivially_destructible(const MessageValue& value) { + return absl::visit( + [](const auto& alternative) -> bool { + return ArenaTraits<>::trivially_destructible(alternative); + }, + value.variant_); + } +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_MESSAGE_VALUE_H_ diff --git a/common/values/message_value_test.cc b/common/values/message_value_test.cc new file mode 100644 index 000000000..2e3a8e711 --- /dev/null +++ b/common/values/message_value_test.cc @@ -0,0 +1,139 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/base/attributes.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "base/attribute.h" +#include "common/type.h" +#include "common/value.h" +#include "common/value_kind.h" +#include "common/value_testing.h" +#include "internal/testing.h" +#include "cel/expr/conformance/proto3/test_all_types.pb.h" +#include "google/protobuf/io/zero_copy_stream_impl_lite.h" + +namespace cel { +namespace { + +using ::absl_testing::StatusIs; +using ::testing::An; +using ::testing::Optional; + +using TestAllTypesProto3 = ::cel::expr::conformance::proto3::TestAllTypes; + +using MessageValueTest = common_internal::ValueTest<>; + +TEST_F(MessageValueTest, Default) { + MessageValue value; + EXPECT_FALSE(value); + google::protobuf::io::CordOutputStream output; + EXPECT_THAT(value.SerializeTo(descriptor_pool(), message_factory(), &output), + StatusIs(absl::StatusCode::kInternal)); + Value scratch; + int count; + EXPECT_THAT( + value.Equal(NullValue(), descriptor_pool(), message_factory(), arena()), + StatusIs(absl::StatusCode::kInternal)); + EXPECT_THAT(value.Equal(NullValue(), descriptor_pool(), message_factory(), + arena(), &scratch), + StatusIs(absl::StatusCode::kInternal)); + EXPECT_THAT( + value.GetFieldByName("", descriptor_pool(), message_factory(), arena()), + StatusIs(absl::StatusCode::kInternal)); + EXPECT_THAT(value.GetFieldByName("", descriptor_pool(), message_factory(), + arena(), &scratch), + StatusIs(absl::StatusCode::kInternal)); + EXPECT_THAT( + value.GetFieldByNumber(0, descriptor_pool(), message_factory(), arena()), + StatusIs(absl::StatusCode::kInternal)); + EXPECT_THAT(value.GetFieldByNumber(0, descriptor_pool(), message_factory(), + arena(), &scratch), + StatusIs(absl::StatusCode::kInternal)); + EXPECT_THAT(value.HasFieldByName(""), StatusIs(absl::StatusCode::kInternal)); + EXPECT_THAT(value.HasFieldByNumber(0), StatusIs(absl::StatusCode::kInternal)); + EXPECT_THAT(value.ForEachField([](absl::string_view, const Value&) + -> absl::StatusOr { return true; }, + descriptor_pool(), message_factory(), arena()), + StatusIs(absl::StatusCode::kInternal)); + EXPECT_THAT(value.Qualify({AttributeQualifier::OfString("foo")}, false, + descriptor_pool(), message_factory(), arena()), + StatusIs(absl::StatusCode::kInternal)); + EXPECT_THAT(value.Qualify({AttributeQualifier::OfString("foo")}, false, + descriptor_pool(), message_factory(), arena(), + &scratch, &count), + StatusIs(absl::StatusCode::kInternal)); +} + +template +constexpr T& AsLValueRef(T& t ABSL_ATTRIBUTE_LIFETIME_BOUND) { + return t; +} + +template +constexpr const T& AsConstLValueRef(T& t ABSL_ATTRIBUTE_LIFETIME_BOUND) { + return t; +} + +template +constexpr T&& AsRValueRef(T& t ABSL_ATTRIBUTE_LIFETIME_BOUND) { + return static_cast(t); +} + +template +constexpr const T&& AsConstRValueRef(T& t ABSL_ATTRIBUTE_LIFETIME_BOUND) { + return static_cast(t); +} + +TEST_F(MessageValueTest, Parsed) { + MessageValue value(ParsedMessageValue( + DynamicParseTextProto(R"pb()pb"), arena())); + MessageValue other_value = value; + EXPECT_TRUE(value); + EXPECT_TRUE(value.Is()); + EXPECT_THAT(value.As(), + Optional(An())); + EXPECT_THAT(AsLValueRef(value).Get(), + An()); + EXPECT_THAT(AsConstLValueRef(value).Get(), + An()); + EXPECT_THAT(AsRValueRef(value).Get(), + An()); + EXPECT_THAT( + AsConstRValueRef(other_value).Get(), + An()); +} + +TEST_F(MessageValueTest, Kind) { + MessageValue value; + EXPECT_EQ(value.kind(), ParsedMessageValue::kKind); + EXPECT_EQ(value.kind(), ValueKind::kStruct); +} + +TEST_F(MessageValueTest, GetTypeName) { + MessageValue value(ParsedMessageValue( + DynamicParseTextProto(R"pb()pb"), arena())); + EXPECT_EQ(value.GetTypeName(), "cel.expr.conformance.proto3.TestAllTypes"); +} + +TEST_F(MessageValueTest, GetRuntimeType) { + MessageValue value(ParsedMessageValue( + DynamicParseTextProto(R"pb()pb"), arena())); + EXPECT_EQ(value.GetRuntimeType(), MessageType(value.GetDescriptor())); +} + +} // namespace +} // namespace cel diff --git a/common/values/mutable_list_value_test.cc b/common/values/mutable_list_value_test.cc new file mode 100644 index 000000000..c08d7091c --- /dev/null +++ b/common/values/mutable_list_value_test.cc @@ -0,0 +1,150 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "common/value.h" +#include "common/value_testing.h" +#include "common/values/list_value_builder.h" +#include "internal/testing.h" + +namespace cel::common_internal { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::IsOkAndHolds; +using ::absl_testing::StatusIs; +using ::cel::test::ErrorValueIs; +using ::cel::test::StringValueIs; +using ::testing::IsEmpty; +using ::testing::Pair; +using ::testing::UnorderedElementsAre; + +using MutableListValueTest = common_internal::ValueTest<>; + +TEST_F(MutableListValueTest, DebugString) { + auto* mutable_list_value = NewMutableListValue(arena()); + EXPECT_THAT(CustomListValue(mutable_list_value, arena()).DebugString(), "[]"); +} + +TEST_F(MutableListValueTest, IsEmpty) { + auto* mutable_list_value = NewMutableListValue(arena()); + mutable_list_value->Reserve(1); + EXPECT_TRUE(CustomListValue(mutable_list_value, arena()).IsEmpty()); + EXPECT_THAT(mutable_list_value->Append(StringValue("foo")), IsOk()); + EXPECT_FALSE(CustomListValue(mutable_list_value, arena()).IsEmpty()); +} + +TEST_F(MutableListValueTest, Size) { + auto* mutable_list_value = NewMutableListValue(arena()); + mutable_list_value->Reserve(1); + EXPECT_THAT(CustomListValue(mutable_list_value, arena()).Size(), 0); + EXPECT_THAT(mutable_list_value->Append(StringValue("foo")), IsOk()); + EXPECT_THAT(CustomListValue(mutable_list_value, arena()).Size(), 1); +} + +TEST_F(MutableListValueTest, ForEach) { + auto* mutable_list_value = NewMutableListValue(arena()); + mutable_list_value->Reserve(1); + std::vector> elements; + auto for_each_callback = [&](size_t index, + const Value& value) -> absl::StatusOr { + elements.push_back(std::pair{index, value}); + return true; + }; + EXPECT_THAT(CustomListValue(mutable_list_value, arena()) + .ForEach(for_each_callback, descriptor_pool(), + message_factory(), arena()), + IsOk()); + EXPECT_THAT(elements, IsEmpty()); + EXPECT_THAT(mutable_list_value->Append(StringValue("foo")), IsOk()); + EXPECT_THAT(CustomListValue(mutable_list_value, arena()) + .ForEach(for_each_callback, descriptor_pool(), + message_factory(), arena()), + IsOk()); + EXPECT_THAT(elements, UnorderedElementsAre(Pair(0, StringValueIs("foo")))); +} + +TEST_F(MutableListValueTest, NewIterator) { + auto* mutable_list_value = NewMutableListValue(arena()); + mutable_list_value->Reserve(1); + ASSERT_OK_AND_ASSIGN( + auto iterator, + CustomListValue(mutable_list_value, arena()).NewIterator()); + EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()), + StatusIs(absl::StatusCode::kFailedPrecondition)); + EXPECT_THAT(mutable_list_value->Append(StringValue("foo")), IsOk()); + ASSERT_OK_AND_ASSIGN( + iterator, CustomListValue(mutable_list_value, arena()).NewIterator()); + EXPECT_TRUE(iterator->HasNext()); + EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(StringValueIs("foo"))); + EXPECT_FALSE(iterator->HasNext()); + EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()), + StatusIs(absl::StatusCode::kFailedPrecondition)); +} + +TEST_F(MutableListValueTest, Get) { + auto* mutable_list_value = NewMutableListValue(arena()); + mutable_list_value->Reserve(1); + Value value; + EXPECT_THAT( + CustomListValue(mutable_list_value, arena()) + .Get(0, descriptor_pool(), message_factory(), arena(), &value), + IsOk()); + EXPECT_THAT(value, + ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))); + EXPECT_THAT(mutable_list_value->Append(StringValue("foo")), IsOk()); + EXPECT_THAT( + CustomListValue(mutable_list_value, arena()) + .Get(0, descriptor_pool(), message_factory(), arena(), &value), + IsOk()); + EXPECT_THAT(value, StringValueIs("foo")); +} + +TEST_F(MutableListValueTest, IsMutablListValue) { + auto* mutable_list_value = NewMutableListValue(arena()); + EXPECT_TRUE( + IsMutableListValue(Value(CustomListValue(mutable_list_value, arena())))); + EXPECT_TRUE(IsMutableListValue( + ListValue(CustomListValue(mutable_list_value, arena())))); +} + +TEST_F(MutableListValueTest, AsMutableListValue) { + auto* mutable_list_value = NewMutableListValue(arena()); + EXPECT_EQ( + AsMutableListValue(Value(CustomListValue(mutable_list_value, arena()))), + mutable_list_value); + EXPECT_EQ(AsMutableListValue( + ListValue(CustomListValue(mutable_list_value, arena()))), + mutable_list_value); +} + +TEST_F(MutableListValueTest, GetMutableListValue) { + auto* mutable_list_value = NewMutableListValue(arena()); + EXPECT_EQ( + &GetMutableListValue(Value(CustomListValue(mutable_list_value, arena()))), + mutable_list_value); + EXPECT_EQ(&GetMutableListValue( + ListValue(CustomListValue(mutable_list_value, arena()))), + mutable_list_value); +} + +} // namespace +} // namespace cel::common_internal diff --git a/common/values/mutable_map_value_test.cc b/common/values/mutable_map_value_test.cc new file mode 100644 index 000000000..2f08abe3f --- /dev/null +++ b/common/values/mutable_map_value_test.cc @@ -0,0 +1,179 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "common/value.h" +#include "common/value_testing.h" +#include "common/values/map_value_builder.h" +#include "internal/testing.h" + +namespace cel::common_internal { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::IsOkAndHolds; +using ::absl_testing::StatusIs; +using ::cel::test::BoolValueIs; +using ::cel::test::IntValueIs; +using ::cel::test::IsNullValue; +using ::cel::test::ListValueElements; +using ::cel::test::ListValueIs; +using ::cel::test::StringValueIs; +using ::testing::IsEmpty; +using ::testing::IsFalse; +using ::testing::IsTrue; +using ::testing::Pair; +using ::testing::UnorderedElementsAre; + +using MutableMapValueTest = common_internal::ValueTest<>; + +TEST_F(MutableMapValueTest, DebugString) { + auto mutable_map_value = NewMutableMapValue(arena()); + EXPECT_THAT(CustomMapValue(mutable_map_value, arena()).DebugString(), "{}"); +} + +TEST_F(MutableMapValueTest, IsEmpty) { + auto mutable_map_value = NewMutableMapValue(arena()); + mutable_map_value->Reserve(1); + EXPECT_TRUE(CustomMapValue(mutable_map_value, arena()).IsEmpty()); + EXPECT_THAT(mutable_map_value->Put(StringValue("foo"), IntValue(1)), IsOk()); + EXPECT_FALSE(CustomMapValue(mutable_map_value, arena()).IsEmpty()); +} + +TEST_F(MutableMapValueTest, Size) { + auto mutable_map_value = NewMutableMapValue(arena()); + mutable_map_value->Reserve(1); + EXPECT_THAT(CustomMapValue(mutable_map_value, arena()).Size(), 0); + EXPECT_THAT(mutable_map_value->Put(StringValue("foo"), IntValue(1)), IsOk()); + EXPECT_THAT(CustomMapValue(mutable_map_value, arena()).Size(), 1); +} + +TEST_F(MutableMapValueTest, ListKeys) { + auto mutable_map_value = NewMutableMapValue(arena()); + mutable_map_value->Reserve(1); + ListValue keys; + EXPECT_THAT(mutable_map_value->Put(StringValue("foo"), IntValue(1)), IsOk()); + EXPECT_THAT( + CustomMapValue(mutable_map_value, arena()) + .ListKeys(descriptor_pool(), message_factory(), arena(), &keys), + IsOk()); + EXPECT_THAT(keys, ListValueIs(ListValueElements( + UnorderedElementsAre(StringValueIs("foo")), + descriptor_pool(), message_factory(), arena()))); +} + +TEST_F(MutableMapValueTest, ForEach) { + auto mutable_map_value = NewMutableMapValue(arena()); + mutable_map_value->Reserve(1); + std::vector> entries; + auto for_each_callback = [&](const Value& key, + const Value& value) -> absl::StatusOr { + entries.push_back(std::pair{key, value}); + return true; + }; + EXPECT_THAT(CustomMapValue(mutable_map_value, arena()) + .ForEach(for_each_callback, descriptor_pool(), + message_factory(), arena()), + IsOk()); + EXPECT_THAT(entries, IsEmpty()); + EXPECT_THAT(mutable_map_value->Put(StringValue("foo"), IntValue(1)), IsOk()); + EXPECT_THAT(CustomMapValue(mutable_map_value, arena()) + .ForEach(for_each_callback, descriptor_pool(), + message_factory(), arena()), + IsOk()); + EXPECT_THAT(entries, + UnorderedElementsAre(Pair(StringValueIs("foo"), IntValueIs(1)))); +} + +TEST_F(MutableMapValueTest, NewIterator) { + auto mutable_map_value = NewMutableMapValue(arena()); + mutable_map_value->Reserve(1); + ASSERT_OK_AND_ASSIGN( + auto iterator, CustomMapValue(mutable_map_value, arena()).NewIterator()); + EXPECT_FALSE(iterator->HasNext()); + EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()), + StatusIs(absl::StatusCode::kFailedPrecondition)); + EXPECT_THAT(mutable_map_value->Put(StringValue("foo"), IntValue(1)), IsOk()); + ASSERT_OK_AND_ASSIGN( + iterator, CustomMapValue(mutable_map_value, arena()).NewIterator()); + EXPECT_TRUE(iterator->HasNext()); + EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(StringValueIs("foo"))); + EXPECT_FALSE(iterator->HasNext()); + EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()), + StatusIs(absl::StatusCode::kFailedPrecondition)); +} + +TEST_F(MutableMapValueTest, FindHas) { + auto* mutable_map_value = NewMutableMapValue(arena()); + mutable_map_value->Reserve(1); + Value value; + EXPECT_THAT(CustomMapValue(mutable_map_value, arena()) + .Find(StringValue("foo"), descriptor_pool(), + message_factory(), arena(), &value), + IsOkAndHolds(IsFalse())); + EXPECT_THAT(value, IsNullValue()); + EXPECT_THAT(CustomMapValue(mutable_map_value, arena()) + .Has(StringValue("foo"), descriptor_pool(), message_factory(), + arena(), &value), + IsOk()); + EXPECT_THAT(value, BoolValueIs(false)); + EXPECT_THAT(mutable_map_value->Put(StringValue("foo"), IntValue(1)), IsOk()); + EXPECT_THAT(CustomMapValue(mutable_map_value, arena()) + .Find(StringValue("foo"), descriptor_pool(), + message_factory(), arena(), &value), + IsOkAndHolds(IsTrue())); + EXPECT_THAT(value, IntValueIs(1)); + EXPECT_THAT(CustomMapValue(mutable_map_value, arena()) + .Has(StringValue("foo"), descriptor_pool(), message_factory(), + arena(), &value), + IsOk()); + EXPECT_THAT(value, BoolValueIs(true)); +} + +TEST_F(MutableMapValueTest, IsMutableMapValue) { + auto* mutable_map_value = NewMutableMapValue(arena()); + EXPECT_TRUE( + IsMutableMapValue(Value(CustomMapValue(mutable_map_value, arena())))); + EXPECT_TRUE( + IsMutableMapValue(MapValue(CustomMapValue(mutable_map_value, arena())))); +} + +TEST_F(MutableMapValueTest, AsMutableMapValue) { + auto* mutable_map_value = NewMutableMapValue(arena()); + EXPECT_EQ( + AsMutableMapValue(Value(CustomMapValue(mutable_map_value, arena()))), + mutable_map_value); + EXPECT_EQ( + AsMutableMapValue(MapValue(CustomMapValue(mutable_map_value, arena()))), + mutable_map_value); +} + +TEST_F(MutableMapValueTest, GetMutableMapValue) { + auto* mutable_map_value = NewMutableMapValue(arena()); + EXPECT_EQ( + &GetMutableMapValue(Value(CustomMapValue(mutable_map_value, arena()))), + mutable_map_value); + EXPECT_EQ( + &GetMutableMapValue(MapValue(CustomMapValue(mutable_map_value, arena()))), + mutable_map_value); +} + +} // namespace +} // namespace cel::common_internal diff --git a/common/values/null_value.cc b/common/values/null_value.cc new file mode 100644 index 000000000..bae6cb34c --- /dev/null +++ b/common/values/null_value.cc @@ -0,0 +1,78 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "google/protobuf/struct.pb.h" +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "common/value.h" +#include "internal/status_macros.h" +#include "internal/well_known_types.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/message.h" + +namespace cel { + +using ::cel::well_known_types::ValueReflection; + +absl::Status NullValue::SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(output != nullptr); + + google::protobuf::Value message; + message.set_null_value(google::protobuf::NULL_VALUE); + if (!message.SerializePartialToZeroCopyStream(output)) { + return absl::UnknownError( + "failed to serialize message: google.protobuf.Value"); + } + return absl::OkStatus(); +} + +absl::Status NullValue::ConvertToJson( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(json != nullptr); + ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE); + + ValueReflection value_reflection; + CEL_RETURN_IF_ERROR(value_reflection.Initialize(json->GetDescriptor())); + value_reflection.SetNullValue(json); + return absl::OkStatus(); +} + +absl::Status NullValue::Equal( + const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + *result = BoolValue(other.IsNull()); + return absl::OkStatus(); +} + +} // namespace cel diff --git a/common/values/null_value.h b/common/values/null_value.h new file mode 100644 index 000000000..d4d05dba3 --- /dev/null +++ b/common/values/null_value.h @@ -0,0 +1,96 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/value.h" +// IWYU pragma: friend "common/value.h" + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_NULL_VALUE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_NULL_VALUE_H_ + +#include +#include + +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "common/type.h" +#include "common/value_kind.h" +#include "common/values/values.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/message.h" + +namespace cel { + +class Value; +class NullValue; + +// `NullValue` represents the CEL `null` value. +class NullValue final : private common_internal::ValueMixin { + public: + static constexpr ValueKind kKind = ValueKind::kNull; + + NullValue() = default; + NullValue(const NullValue&) = default; + NullValue(NullValue&&) = default; + NullValue& operator=(const NullValue&) = default; + NullValue& operator=(NullValue&&) = default; + + constexpr ValueKind kind() const { return kKind; } + + absl::string_view GetTypeName() const { return NullType::kName; } + + std::string DebugString() const { return "null"; } + + // See Value::SerializeTo(). + absl::Status SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const; + + // See Value::ConvertToJson(). + absl::Status ConvertToJson( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const; + + absl::Status Equal(const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const; + using ValueMixin::Equal; + + bool IsZeroValue() const { return true; } + + friend void swap(NullValue&, NullValue&) noexcept {} + + private: + friend class common_internal::ValueMixin; +}; + +inline bool operator==(NullValue, NullValue) { return true; } + +inline bool operator!=(NullValue lhs, NullValue rhs) { + return !operator==(lhs, rhs); +} + +inline std::ostream& operator<<(std::ostream& out, const NullValue& value) { + return out << value.DebugString(); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_NULL_VALUE_H_ diff --git a/common/values/null_value_test.cc b/common/values/null_value_test.cc new file mode 100644 index 000000000..5f244c532 --- /dev/null +++ b/common/values/null_value_test.cc @@ -0,0 +1,82 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "absl/status/status_matchers.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "common/casting.h" +#include "common/native_type.h" +#include "common/value.h" +#include "common/value_testing.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::testing::An; +using ::testing::Ne; + +using NullValueTest = common_internal::ValueTest<>; + +TEST_F(NullValueTest, Kind) { + EXPECT_EQ(NullValue().kind(), NullValue::kKind); + EXPECT_EQ(Value(NullValue()).kind(), NullValue::kKind); +} + +TEST_F(NullValueTest, DebugString) { + { + std::ostringstream out; + out << NullValue(); + EXPECT_EQ(out.str(), "null"); + } + { + std::ostringstream out; + out << Value(NullValue()); + EXPECT_EQ(out.str(), "null"); + } +} + +TEST_F(NullValueTest, ConvertToJson) { + auto* message = NewArenaValueMessage(); + EXPECT_THAT( + NullValue().ConvertToJson(descriptor_pool(), message_factory(), message), + IsOk()); + EXPECT_THAT(*message, EqualsValueTextProto(R"pb(null_value: NULL_VALUE)pb")); +} + +TEST_F(NullValueTest, NativeTypeId) { + EXPECT_EQ(NativeTypeId::Of(NullValue()), NativeTypeId::For()); + EXPECT_EQ(NativeTypeId::Of(Value(NullValue())), + NativeTypeId::For()); +} + +TEST_F(NullValueTest, InstanceOf) { + EXPECT_TRUE(InstanceOf(NullValue())); + EXPECT_TRUE(InstanceOf(Value(NullValue()))); +} + +TEST_F(NullValueTest, Cast) { + EXPECT_THAT(Cast(NullValue()), An()); + EXPECT_THAT(Cast(Value(NullValue())), An()); +} + +TEST_F(NullValueTest, As) { + EXPECT_THAT(As(Value(NullValue())), Ne(absl::nullopt)); +} + +} // namespace +} // namespace cel diff --git a/common/values/opaque_value.cc b/common/values/opaque_value.cc new file mode 100644 index 000000000..235d268e7 --- /dev/null +++ b/common/values/opaque_value.cc @@ -0,0 +1,194 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/base/optimization.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "common/native_type.h" +#include "common/optional_ref.h" +#include "common/type.h" +#include "common/value.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/message.h" + +namespace cel { + +// Code below assumes OptionalValue has the same layout as OpaqueValue. +static_assert(std::is_base_of_v); +static_assert(sizeof(OpaqueValue) == sizeof(OptionalValue)); +static_assert(alignof(OpaqueValue) == alignof(OptionalValue)); + +OpaqueValue OpaqueValue::Clone(google::protobuf::Arena* absl_nonnull arena) const { + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(*this); + + if (ABSL_PREDICT_FALSE(dispatcher_ == nullptr)) { + OpaqueValueInterface::Content content = + content_.To(); + if (content.interface == nullptr) { + return *this; + } + if (content.arena != arena) { + return content.interface->Clone(arena); + } + return *this; + } + if (dispatcher_->get_arena(dispatcher_, content_) != arena) { + return dispatcher_->clone(dispatcher_, content_, arena); + } + return *this; +} + +OpaqueType OpaqueValue::GetRuntimeType() const { + ABSL_DCHECK(*this); + + if (ABSL_PREDICT_FALSE(dispatcher_ == nullptr)) { + OpaqueValueInterface::Content content = + content_.To(); + ABSL_DCHECK(content.interface != nullptr); + return content.interface->GetRuntimeType(); + } + return dispatcher_->get_runtime_type(dispatcher_, content_); +} + +absl::string_view OpaqueValue::GetTypeName() const { + ABSL_DCHECK(*this); + + if (ABSL_PREDICT_FALSE(dispatcher_ == nullptr)) { + OpaqueValueInterface::Content content = + content_.To(); + ABSL_DCHECK(content.interface != nullptr); + return content.interface->GetTypeName(); + } + return dispatcher_->get_type_name(dispatcher_, content_); +} + +std::string OpaqueValue::DebugString() const { + ABSL_DCHECK(*this); + + if (ABSL_PREDICT_FALSE(dispatcher_ == nullptr)) { + OpaqueValueInterface::Content content = + content_.To(); + ABSL_DCHECK(content.interface != nullptr); + return content.interface->DebugString(); + } + return dispatcher_->debug_string(dispatcher_, content_); +} + +// See Value::SerializeTo(). +absl::Status OpaqueValue::SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(output != nullptr); + + return absl::FailedPreconditionError( + absl::StrCat(GetTypeName(), "is unserializable")); +} + +// See Value::ConvertToJson(). +absl::Status OpaqueValue::ConvertToJson( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(json != nullptr); + + return absl::FailedPreconditionError( + absl::StrCat(GetTypeName(), " is not convertable to JSON")); +} + +absl::Status OpaqueValue::Equal( + const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + if (auto other_opaque = other.AsOpaque(); other_opaque) { + if (ABSL_PREDICT_FALSE(dispatcher_ == nullptr)) { + OpaqueValueInterface::Content content = + content_.To(); + ABSL_DCHECK(content.interface != nullptr); + return content.interface->Equal(*other_opaque, descriptor_pool, + message_factory, arena, result); + } + return dispatcher_->equal(dispatcher_, content_, *other_opaque, + descriptor_pool, message_factory, arena, result); + } + *result = FalseValue(); + return absl::OkStatus(); +} + +NativeTypeId OpaqueValue::GetTypeId() const { + ABSL_DCHECK(*this); + + if (ABSL_PREDICT_FALSE(dispatcher_ == nullptr)) { + OpaqueValueInterface::Content content = + content_.To(); + if (content.interface == nullptr) { + return NativeTypeId(); + } + return content.interface->GetNativeTypeId(); + } + return dispatcher_->get_type_id(dispatcher_, content_); +} + +bool OpaqueValue::IsOptional() const { + return dispatcher_ != nullptr && + dispatcher_->get_type_id(dispatcher_, content_) == + NativeTypeId::For(); +} + +optional_ref OpaqueValue::AsOptional() const& { + if (IsOptional()) { + return *reinterpret_cast(this); + } + return absl::nullopt; +} + +absl::optional OpaqueValue::AsOptional() && { + if (IsOptional()) { + return std::move(*reinterpret_cast(this)); + } + return absl::nullopt; +} + +const OptionalValue& OpaqueValue::GetOptional() const& { + ABSL_DCHECK(IsOptional()) << *this; + return *reinterpret_cast(this); +} + +OptionalValue OpaqueValue::GetOptional() && { + ABSL_DCHECK(IsOptional()) << *this; + return std::move(*reinterpret_cast(this)); +} + +} // namespace cel diff --git a/common/values/opaque_value.h b/common/values/opaque_value.h new file mode 100644 index 000000000..57af78ae0 --- /dev/null +++ b/common/values/opaque_value.h @@ -0,0 +1,338 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/value.h" +// IWYU pragma: friend "common/value.h" +// IWYU pragma: friend "common/values/optional_value.h" + +// `OpaqueValue` represents values of the `opaque` type. `OpaqueValueView` +// is a non-owning view of `OpaqueValue`. `OpaqueValueInterface` is the abstract +// base class of implementations. `OpaqueValue` and `OpaqueValueView` act as +// smart pointers to `OpaqueValueInterface`. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_OPAQUE_VALUE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_OPAQUE_VALUE_H_ + +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "common/native_type.h" +#include "common/optional_ref.h" +#include "common/type.h" +#include "common/value_kind.h" +#include "common/values/custom_value.h" +#include "common/values/values.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/message.h" + +namespace cel { + +class Value; +class OpaqueValueInterface; +class OpaqueValueInterfaceIterator; +class OpaqueValue; + +using OpaqueValueContent = CustomValueContent; + +struct OpaqueValueDispatcher { + using GetTypeId = + NativeTypeId (*)(const OpaqueValueDispatcher* absl_nonnull dispatcher, + OpaqueValueContent content); + + using GetArena = google::protobuf::Arena* absl_nullable (*)( + const OpaqueValueDispatcher* absl_nonnull dispatcher, + OpaqueValueContent content); + + using GetTypeName = absl::string_view (*)( + const OpaqueValueDispatcher* absl_nonnull dispatcher, + OpaqueValueContent content); + + using DebugString = + std::string (*)(const OpaqueValueDispatcher* absl_nonnull dispatcher, + OpaqueValueContent content); + + using GetRuntimeType = + OpaqueType (*)(const OpaqueValueDispatcher* absl_nonnull dispatcher, + OpaqueValueContent content); + + using Equal = absl::Status (*)( + const OpaqueValueDispatcher* absl_nonnull dispatcher, + OpaqueValueContent content, const OpaqueValue& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result); + + using Clone = OpaqueValue (*)( + const OpaqueValueDispatcher* absl_nonnull dispatcher, + OpaqueValueContent content, google::protobuf::Arena* absl_nonnull arena); + + absl_nonnull GetTypeId get_type_id; + + absl_nonnull GetArena get_arena; + + absl_nonnull GetTypeName get_type_name; + + absl_nonnull DebugString debug_string; + + absl_nonnull GetRuntimeType get_runtime_type; + + absl_nonnull Equal equal; + + absl_nonnull Clone clone; +}; + +class OpaqueValueInterface { + public: + OpaqueValueInterface() = default; + OpaqueValueInterface(const OpaqueValueInterface&) = delete; + OpaqueValueInterface(OpaqueValueInterface&&) = delete; + + virtual ~OpaqueValueInterface() = default; + + OpaqueValueInterface& operator=(const OpaqueValueInterface&) = delete; + OpaqueValueInterface& operator=(OpaqueValueInterface&&) = delete; + + private: + friend class OpaqueValue; + + virtual std::string DebugString() const = 0; + + virtual absl::string_view GetTypeName() const = 0; + + virtual OpaqueType GetRuntimeType() const = 0; + + virtual absl::Status Equal( + const OpaqueValue& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const = 0; + + virtual OpaqueValue Clone(google::protobuf::Arena* absl_nonnull arena) const = 0; + + virtual NativeTypeId GetNativeTypeId() const = 0; + + struct Content { + const OpaqueValueInterface* absl_nonnull interface; + google::protobuf::Arena* absl_nonnull arena; + }; +}; + +// Creates an opaque value from a manual dispatch table `dispatcher` and +// opaque data `content` whose format is only know to functions in the manual +// dispatch table. The dispatch table should probably be valid for the lifetime +// of the process, but at a minimum must outlive all instances of the resulting +// value. +// +// IMPORTANT: This approach to implementing OpaqueValue should only be +// used when you know exactly what you are doing. When in doubt, just implement +// OpaqueValueInterface. +OpaqueValue UnsafeOpaqueValue(const OpaqueValueDispatcher* absl_nonnull + dispatcher ABSL_ATTRIBUTE_LIFETIME_BOUND, + OpaqueValueContent content); + +class OpaqueValue : private common_internal::OpaqueValueMixin { + public: + static constexpr ValueKind kKind = ValueKind::kOpaque; + + // Constructs an opaque value from an implementation of + // `OpaqueValueInterface` `interface` whose lifetime is tied to that of + // the arena `arena`. + OpaqueValue(const OpaqueValueInterface* absl_nonnull + interface ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND) { + ABSL_DCHECK(interface != nullptr); + ABSL_DCHECK(arena != nullptr); + content_ = OpaqueValueContent::From( + OpaqueValueInterface::Content{.interface = interface, .arena = arena}); + } + + OpaqueValue() = default; + OpaqueValue(const OpaqueValue&) = default; + OpaqueValue(OpaqueValue&&) = default; + OpaqueValue& operator=(const OpaqueValue&) = default; + OpaqueValue& operator=(OpaqueValue&&) = default; + + static constexpr ValueKind kind() { return kKind; } + + NativeTypeId GetTypeId() const; + + OpaqueType GetRuntimeType() const; + + absl::string_view GetTypeName() const; + + std::string DebugString() const; + + // See Value::SerializeTo(). + absl::Status SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const; + + // See Value::ConvertToJson(). + absl::Status ConvertToJson( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const; + + absl::Status Equal(const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const; + using OpaqueValueMixin::Equal; + + bool IsZeroValue() const { return false; } + + OpaqueValue Clone(google::protobuf::Arena* absl_nonnull arena) const; + + // Returns `true` if this opaque value is an instance of an optional value. + bool IsOptional() const; + + // Convenience method for use with template metaprogramming. See + // `IsOptional()`. + template + std::enable_if_t, bool> Is() const { + return IsOptional(); + } + + // Performs a checked cast from an opaque value to an optional value, + // returning a non-empty optional with either a value or reference to the + // optional value. Otherwise an empty optional is returned. + optional_ref AsOptional() & + ABSL_ATTRIBUTE_LIFETIME_BOUND; + optional_ref AsOptional() + const& ABSL_ATTRIBUTE_LIFETIME_BOUND; + absl::optional AsOptional() &&; + absl::optional AsOptional() const&&; + + // Convenience method for use with template metaprogramming. See + // `AsOptional()`. + template + std::enable_if_t, + optional_ref> + As() & ABSL_ATTRIBUTE_LIFETIME_BOUND; + template + std::enable_if_t, + optional_ref> + As() const& ABSL_ATTRIBUTE_LIFETIME_BOUND; + template + std::enable_if_t, + absl::optional> + As() &&; + template + std::enable_if_t, + absl::optional> + As() const&&; + + // Performs an unchecked cast from an opaque value to an optional value. In + // debug builds a best effort is made to crash. If `IsOptional()` would return + // false, calling this method is undefined behavior. + const OptionalValue& GetOptional() & ABSL_ATTRIBUTE_LIFETIME_BOUND; + const OptionalValue& GetOptional() const& ABSL_ATTRIBUTE_LIFETIME_BOUND; + OptionalValue GetOptional() &&; + OptionalValue GetOptional() const&&; + + // Convenience method for use with template metaprogramming. See + // `Optional()`. + template + std::enable_if_t, const OptionalValue&> + Get() & ABSL_ATTRIBUTE_LIFETIME_BOUND; + template + std::enable_if_t, const OptionalValue&> Get() + const& ABSL_ATTRIBUTE_LIFETIME_BOUND; + template + std::enable_if_t, OptionalValue> Get() &&; + template + std::enable_if_t, OptionalValue> Get() + const&&; + + const OpaqueValueDispatcher* absl_nullable dispatcher() const { + return dispatcher_; + } + + OpaqueValueContent content() const { + ABSL_DCHECK(dispatcher_ != nullptr); + return content_; + } + + const OpaqueValueInterface* absl_nullable interface() const { + if (dispatcher_ == nullptr) { + return content_.To().interface; + } + return nullptr; + } + + friend void swap(OpaqueValue& lhs, OpaqueValue& rhs) noexcept { + using std::swap; + swap(lhs.dispatcher_, rhs.dispatcher_); + swap(lhs.content_, rhs.content_); + } + + explicit operator bool() const { + if (dispatcher_ == nullptr) { + return content_.To().interface != nullptr; + } + return true; + } + + protected: + OpaqueValue(const OpaqueValueDispatcher* absl_nonnull dispatcher + ABSL_ATTRIBUTE_LIFETIME_BOUND, + OpaqueValueContent content) + : dispatcher_(dispatcher), content_(content) { + ABSL_DCHECK(dispatcher != nullptr); + ABSL_DCHECK(dispatcher->get_type_id != nullptr); + ABSL_DCHECK(dispatcher->get_type_name != nullptr); + ABSL_DCHECK(dispatcher->clone != nullptr); + } + + private: + friend class common_internal::ValueMixin; + friend class common_internal::OpaqueValueMixin; + friend OpaqueValue UnsafeOpaqueValue(const OpaqueValueDispatcher* absl_nonnull + dispatcher ABSL_ATTRIBUTE_LIFETIME_BOUND, + OpaqueValueContent content); + + const OpaqueValueDispatcher* absl_nullable dispatcher_ = nullptr; + OpaqueValueContent content_ = OpaqueValueContent::Zero(); +}; + +inline std::ostream& operator<<(std::ostream& out, const OpaqueValue& type) { + return out << type.DebugString(); +} + +template <> +struct NativeTypeTraits final { + static NativeTypeId Id(const OpaqueValue& type) { return type.GetTypeId(); } +}; + +inline OpaqueValue UnsafeOpaqueValue(const OpaqueValueDispatcher* absl_nonnull + dispatcher ABSL_ATTRIBUTE_LIFETIME_BOUND, + OpaqueValueContent content) { + return OpaqueValue(dispatcher, content); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_OPAQUE_VALUE_H_ diff --git a/common/values/optional_value.cc b/common/values/optional_value.cc new file mode 100644 index 000000000..7c214b9cb --- /dev/null +++ b/common/values/optional_value.cc @@ -0,0 +1,435 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/casts.h" +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "common/arena.h" +#include "common/native_type.h" +#include "common/type.h" +#include "common/value.h" +#include "common/value_kind.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel { + +namespace { + +struct OptionalValueDispatcher : public OpaqueValueDispatcher { + using HasValue = + bool (*)(const OptionalValueDispatcher* absl_nonnull dispatcher, + CustomValueContent content); + using Value = void (*)(const OptionalValueDispatcher* absl_nonnull dispatcher, + CustomValueContent content, + cel::Value* absl_nonnull result); + + absl_nonnull HasValue has_value; + + absl_nonnull Value value; +}; + +NativeTypeId OptionalValueGetTypeId(const OpaqueValueDispatcher* absl_nonnull, + OpaqueValueContent) { + return NativeTypeId::For(); +} + +absl::string_view OptionalValueGetTypeName( + const OpaqueValueDispatcher* absl_nonnull, OpaqueValueContent) { + return "optional_type"; +} + +OpaqueType OptionalValueGetRuntimeType( + const OpaqueValueDispatcher* absl_nonnull, OpaqueValueContent) { + return OptionalType(); +} + +std::string OptionalValueDebugString( + const OpaqueValueDispatcher* absl_nonnull dispatcher, + OpaqueValueContent content) { + if (!static_cast(dispatcher) + ->has_value(static_cast(dispatcher), + content)) { + return "optional.none()"; + } + Value value; + static_cast(dispatcher) + ->value(static_cast(dispatcher), content, + &value); + return absl::StrCat("optional.of(", value.DebugString(), ")"); +} + +bool OptionalValueHasValue(const OptionalValueDispatcher* absl_nonnull, + OpaqueValueContent) { + return true; +} + +absl::Status OptionalValueEqual( + const OpaqueValueDispatcher* absl_nonnull dispatcher, + OpaqueValueContent content, const OpaqueValue& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + if (auto other_optional = other.AsOptional(); other_optional) { + const bool lhs_has_value = + static_cast(dispatcher) + ->has_value(static_cast(dispatcher), + content); + const bool rhs_has_value = other_optional->HasValue(); + if (lhs_has_value != rhs_has_value) { + *result = FalseValue(); + return absl::OkStatus(); + } + if (!lhs_has_value) { + *result = TrueValue(); + return absl::OkStatus(); + } + Value lhs_value; + Value rhs_value; + static_cast(dispatcher) + ->value(static_cast(dispatcher), + content, &lhs_value); + other_optional->Value(&rhs_value); + return lhs_value.Equal(rhs_value, descriptor_pool, message_factory, arena, + result); + } + *result = FalseValue(); + return absl::OkStatus(); +} + +google::protobuf::Arena* absl_nullable OptionalValueGetArenaNull( + const OpaqueValueDispatcher* absl_nonnull, OpaqueValueContent) { + return nullptr; +} + +OpaqueValue OptionalValueClone( + const OpaqueValueDispatcher* absl_nonnull dispatcher, + OpaqueValueContent content, google::protobuf::Arena* absl_nonnull arena) { + return common_internal::MakeOptionalValue(dispatcher, content); +} + +bool OptionalValueHasNoValue(const OptionalValueDispatcher* absl_nonnull, + CustomValueContent content) { + return false; +} + +void EmptyOptionalValueValue(const OptionalValueDispatcher* absl_nonnull, + CustomValueContent content, + cel::Value* absl_nonnull result) { + *result = + ErrorValue(absl::FailedPreconditionError("optional.none() dereference")); +} + +void NullOptionalValueValue(const OptionalValueDispatcher* absl_nonnull, + CustomValueContent content, + cel::Value* absl_nonnull result) { + *result = NullValue(); +} + +void BoolOptionalValueValue(const OptionalValueDispatcher* absl_nonnull, + CustomValueContent content, + cel::Value* absl_nonnull result) { + *result = BoolValue(content.To()); +} + +void IntOptionalValueValue(const OptionalValueDispatcher* absl_nonnull, + CustomValueContent content, + cel::Value* absl_nonnull result) { + *result = IntValue(content.To()); +} + +void UintOptionalValueValue(const OptionalValueDispatcher* absl_nonnull, + CustomValueContent content, + cel::Value* absl_nonnull result) { + *result = UintValue(content.To()); +} + +void DoubleOptionalValueValue(const OptionalValueDispatcher* absl_nonnull, + CustomValueContent content, + cel::Value* absl_nonnull result) { + *result = DoubleValue(content.To()); +} + +void DurationOptionalValueValue(const OptionalValueDispatcher* absl_nonnull, + CustomValueContent content, + cel::Value* absl_nonnull result) { + *result = UnsafeDurationValue(content.To()); +} + +void TimestampOptionalValueValue(const OptionalValueDispatcher* absl_nonnull, + CustomValueContent content, + cel::Value* absl_nonnull result) { + *result = UnsafeTimestampValue(content.To()); +} + +ABSL_CONST_INIT const OptionalValueDispatcher + empty_optional_value_dispatcher = { + { + .get_type_id = &OptionalValueGetTypeId, + .get_arena = &OptionalValueGetArenaNull, + .get_type_name = &OptionalValueGetTypeName, + .debug_string = &OptionalValueDebugString, + .get_runtime_type = &OptionalValueGetRuntimeType, + .equal = &OptionalValueEqual, + .clone = &OptionalValueClone, + }, + &OptionalValueHasNoValue, + &EmptyOptionalValueValue, +}; + +ABSL_CONST_INIT const OptionalValueDispatcher null_optional_value_dispatcher = { + { + .get_type_id = &OptionalValueGetTypeId, + .get_arena = &OptionalValueGetArenaNull, + .get_type_name = &OptionalValueGetTypeName, + .debug_string = &OptionalValueDebugString, + .get_runtime_type = &OptionalValueGetRuntimeType, + .equal = &OptionalValueEqual, + .clone = &OptionalValueClone, + }, + &OptionalValueHasValue, + &NullOptionalValueValue, +}; + +ABSL_CONST_INIT const OptionalValueDispatcher bool_optional_value_dispatcher = { + { + .get_type_id = &OptionalValueGetTypeId, + .get_arena = &OptionalValueGetArenaNull, + .get_type_name = &OptionalValueGetTypeName, + .debug_string = &OptionalValueDebugString, + .get_runtime_type = &OptionalValueGetRuntimeType, + .equal = &OptionalValueEqual, + .clone = &OptionalValueClone, + }, + &OptionalValueHasValue, + &BoolOptionalValueValue, +}; + +ABSL_CONST_INIT const OptionalValueDispatcher int_optional_value_dispatcher = { + { + .get_type_id = &OptionalValueGetTypeId, + .get_arena = &OptionalValueGetArenaNull, + .get_type_name = &OptionalValueGetTypeName, + .debug_string = &OptionalValueDebugString, + .get_runtime_type = &OptionalValueGetRuntimeType, + .equal = &OptionalValueEqual, + .clone = &OptionalValueClone, + }, + &OptionalValueHasValue, + &IntOptionalValueValue, +}; + +ABSL_CONST_INIT const OptionalValueDispatcher uint_optional_value_dispatcher = { + { + .get_type_id = &OptionalValueGetTypeId, + .get_arena = &OptionalValueGetArenaNull, + .get_type_name = &OptionalValueGetTypeName, + .debug_string = &OptionalValueDebugString, + .get_runtime_type = &OptionalValueGetRuntimeType, + .equal = &OptionalValueEqual, + .clone = &OptionalValueClone, + }, + &OptionalValueHasValue, + &UintOptionalValueValue, +}; + +ABSL_CONST_INIT const OptionalValueDispatcher + double_optional_value_dispatcher = { + { + .get_type_id = &OptionalValueGetTypeId, + .get_arena = &OptionalValueGetArenaNull, + .get_type_name = &OptionalValueGetTypeName, + .debug_string = &OptionalValueDebugString, + .get_runtime_type = &OptionalValueGetRuntimeType, + .equal = &OptionalValueEqual, + .clone = &OptionalValueClone, + }, + &OptionalValueHasValue, + &DoubleOptionalValueValue, +}; + +ABSL_CONST_INIT const OptionalValueDispatcher + duration_optional_value_dispatcher = { + { + .get_type_id = &OptionalValueGetTypeId, + .get_arena = &OptionalValueGetArenaNull, + .get_type_name = &OptionalValueGetTypeName, + .debug_string = &OptionalValueDebugString, + .get_runtime_type = &OptionalValueGetRuntimeType, + .equal = &OptionalValueEqual, + .clone = &OptionalValueClone, + }, + &OptionalValueHasValue, + &DurationOptionalValueValue, +}; + +ABSL_CONST_INIT const OptionalValueDispatcher + timestamp_optional_value_dispatcher = { + { + .get_type_id = &OptionalValueGetTypeId, + .get_arena = &OptionalValueGetArenaNull, + .get_type_name = &OptionalValueGetTypeName, + .debug_string = &OptionalValueDebugString, + .get_runtime_type = &OptionalValueGetRuntimeType, + .equal = &OptionalValueEqual, + .clone = &OptionalValueClone, + }, + &OptionalValueHasValue, + &TimestampOptionalValueValue, +}; + +struct OptionalValueContent { + const Value* absl_nonnull value; + google::protobuf::Arena* absl_nonnull arena; +}; + +google::protobuf::Arena* absl_nullable GenericOptionalValueGetArena( + const OpaqueValueDispatcher* absl_nonnull, OpaqueValueContent content) { + return content.To().arena; +} + +OpaqueValue GenericOptionalValueClone( + const OpaqueValueDispatcher* absl_nonnull dispatcher, + OpaqueValueContent content, google::protobuf::Arena* absl_nonnull arena); + +void GenericOptionalValueValue(const OptionalValueDispatcher* absl_nonnull, + CustomValueContent content, + cel::Value* absl_nonnull result) { + *result = *content.To().value; +} + +ABSL_CONST_INIT const OptionalValueDispatcher optional_value_dispatcher = { + { + .get_type_id = &OptionalValueGetTypeId, + .get_arena = &GenericOptionalValueGetArena, + .get_type_name = &OptionalValueGetTypeName, + .debug_string = &OptionalValueDebugString, + .get_runtime_type = &OptionalValueGetRuntimeType, + .equal = &OptionalValueEqual, + .clone = &GenericOptionalValueClone, + }, + &OptionalValueHasValue, + &GenericOptionalValueValue, +}; + +OpaqueValue GenericOptionalValueClone( + const OpaqueValueDispatcher* absl_nonnull dispatcher, + OpaqueValueContent content, google::protobuf::Arena* absl_nonnull arena) { + ABSL_DCHECK(arena != nullptr); + + cel::Value* absl_nonnull result = + ::new (arena->AllocateAligned(sizeof(cel::Value), alignof(cel::Value))) + cel::Value(content.To().value->Clone(arena)); + if (!ArenaTraits<>::trivially_destructible(*result)) { + arena->OwnDestructor(result); + } + return common_internal::MakeOptionalValue( + &optional_value_dispatcher, OpaqueValueContent::From(OptionalValueContent{ + .value = result, .arena = arena})); +} + +} // namespace + +OptionalValue OptionalValue::Of(cel::Value value, + google::protobuf::Arena* absl_nonnull arena) { + ABSL_DCHECK(value.kind() != ValueKind::kError && + value.kind() != ValueKind::kUnknown); + ABSL_DCHECK(arena != nullptr); + + // We can actually fit a lot more of the underlying values, avoiding arena + // allocations and destructors. For now, we just do scalars. + switch (value.kind()) { + case ValueKind::kNull: + return OptionalValue(&null_optional_value_dispatcher, + OpaqueValueContent::Zero()); + case ValueKind::kBool: + return OptionalValue( + &bool_optional_value_dispatcher, + OpaqueValueContent::From(absl::implicit_cast(value.GetBool()))); + case ValueKind::kInt: + return OptionalValue(&int_optional_value_dispatcher, + OpaqueValueContent::From( + absl::implicit_cast(value.GetInt()))); + case ValueKind::kUint: + return OptionalValue(&uint_optional_value_dispatcher, + OpaqueValueContent::From( + absl::implicit_cast(value.GetUint()))); + case ValueKind::kDouble: + return OptionalValue(&double_optional_value_dispatcher, + OpaqueValueContent::From( + absl::implicit_cast(value.GetDouble()))); + case ValueKind::kDuration: + return OptionalValue( + &duration_optional_value_dispatcher, + OpaqueValueContent::From(value.GetDuration().ToDuration())); + case ValueKind::kTimestamp: + return OptionalValue( + ×tamp_optional_value_dispatcher, + OpaqueValueContent::From(value.GetTimestamp().ToTime())); + default: { + cel::Value* absl_nonnull result = ::new ( + arena->AllocateAligned(sizeof(cel::Value), alignof(cel::Value))) + cel::Value(std::move(value)); + if (!ArenaTraits<>::trivially_destructible(*result)) { + arena->OwnDestructor(result); + } + return OptionalValue(&optional_value_dispatcher, + OpaqueValueContent::From(OptionalValueContent{ + .value = result, .arena = arena})); + } + } +} + +OptionalValue OptionalValue::None() { + return OptionalValue(&empty_optional_value_dispatcher, + OpaqueValueContent::Zero()); +} + +bool OptionalValue::HasValue() const { + return static_cast(OpaqueValue::dispatcher()) + ->has_value(static_cast( + OpaqueValue::dispatcher()), + OpaqueValue::content()); +} + +void OptionalValue::Value(cel::Value* absl_nonnull result) const { + ABSL_DCHECK(result != nullptr); + + static_cast(OpaqueValue::dispatcher()) + ->value(static_cast( + OpaqueValue::dispatcher()), + OpaqueValue::content(), result); +} + +cel::Value OptionalValue::Value() const { + cel::Value result; + Value(&result); + return result; +} + +} // namespace cel diff --git a/common/values/optional_value.h b/common/values/optional_value.h new file mode 100644 index 000000000..e52251881 --- /dev/null +++ b/common/values/optional_value.h @@ -0,0 +1,207 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/value.h" +// IWYU pragma: friend "common/value.h" + +// `OptionalValue` represents values of the `optional_type` type. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_OPTIONAL_VALUE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_OPTIONAL_VALUE_H_ + +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/types/optional.h" +#include "common/optional_ref.h" +#include "common/type.h" +#include "common/values/opaque_value.h" +#include "google/protobuf/arena.h" + +namespace cel { + +class Value; +class OptionalValue; + +namespace common_internal { +OptionalValue MakeOptionalValue( + const OpaqueValueDispatcher* absl_nonnull dispatcher, + OpaqueValueContent content); +} + +class OptionalValue final : public OpaqueValue { + public: + static OptionalValue None(); + + static OptionalValue Of(cel::Value value, google::protobuf::Arena* absl_nonnull arena); + + OptionalValue() : OptionalValue(None()) {} + OptionalValue(const OptionalValue&) = default; + OptionalValue(OptionalValue&&) = default; + OptionalValue& operator=(const OptionalValue&) = default; + OptionalValue& operator=(OptionalValue&&) = default; + + OptionalType GetRuntimeType() const { + return OpaqueValue::GetRuntimeType().GetOptional(); + } + + bool HasValue() const; + + void Value(cel::Value* absl_nonnull result) const; + + cel::Value Value() const; + + bool IsOptional() const = delete; + template + std::enable_if_t, bool> Is() const = delete; + optional_ref AsOptional() & = delete; + optional_ref AsOptional() const& = delete; + absl::optional AsOptional() && = delete; + absl::optional AsOptional() const&& = delete; + const OptionalValue& GetOptional() & = delete; + const OptionalValue& GetOptional() const& = delete; + OptionalValue GetOptional() && = delete; + OptionalValue GetOptional() const&& = delete; + template + std::enable_if_t, + optional_ref> + As() & = delete; + template + std::enable_if_t, + optional_ref> + As() const& = delete; + template + std::enable_if_t, + absl::optional> + As() && = delete; + template + std::enable_if_t, + absl::optional> + As() const&& = delete; + template + std::enable_if_t, + optional_ref> + Get() & = delete; + template + std::enable_if_t, + optional_ref> + Get() const& = delete; + template + std::enable_if_t, + absl::optional> + Get() && = delete; + template + std::enable_if_t, + absl::optional> + Get() const&& = delete; + + private: + friend OptionalValue common_internal::MakeOptionalValue( + const OpaqueValueDispatcher* absl_nonnull dispatcher, + OpaqueValueContent content); + + OptionalValue(const OpaqueValueDispatcher* absl_nonnull dispatcher, + OpaqueValueContent content) + : OpaqueValue(dispatcher, content) {} + + using OpaqueValue::content; + using OpaqueValue::dispatcher; + using OpaqueValue::interface; +}; + +inline optional_ref OpaqueValue::AsOptional() & + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return std::as_const(*this).AsOptional(); +} + +inline absl::optional OpaqueValue::AsOptional() const&& { + return common_internal::AsOptional(AsOptional()); +} + +template + inline std::enable_if_t, + optional_ref> + OpaqueValue::As() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + return AsOptional(); +} + +template +inline std::enable_if_t, + optional_ref> +OpaqueValue::As() const& ABSL_ATTRIBUTE_LIFETIME_BOUND { + return AsOptional(); +} + +template +inline std::enable_if_t, + absl::optional> +OpaqueValue::As() && { + return std::move(*this).AsOptional(); +} + +template +inline std::enable_if_t, + absl::optional> +OpaqueValue::As() const&& { + return std::move(*this).AsOptional(); +} + +inline const OptionalValue& OpaqueValue::GetOptional() & + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return std::as_const(*this).GetOptional(); +} + +inline OptionalValue OpaqueValue::GetOptional() const&& { + return GetOptional(); +} + +template + std::enable_if_t, const OptionalValue&> + OpaqueValue::Get() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + return GetOptional(); +} + +template +std::enable_if_t, const OptionalValue&> +OpaqueValue::Get() const& ABSL_ATTRIBUTE_LIFETIME_BOUND { + return GetOptional(); +} + +template +std::enable_if_t, OptionalValue> +OpaqueValue::Get() && { + return std::move(*this).GetOptional(); +} + +template +std::enable_if_t, OptionalValue> +OpaqueValue::Get() const&& { + return std::move(*this).GetOptional(); +} + +namespace common_internal { + +inline OptionalValue MakeOptionalValue( + const OpaqueValueDispatcher* absl_nonnull dispatcher, + OpaqueValueContent content) { + return OptionalValue(dispatcher, content); +} + +} // namespace common_internal + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_OPTIONAL_VALUE_H_ diff --git a/common/values/optional_value_test.cc b/common/values/optional_value_test.cc new file mode 100644 index 000000000..8b044a7f0 --- /dev/null +++ b/common/values/optional_value_test.cc @@ -0,0 +1,141 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/time/time.h" +#include "common/native_type.h" +#include "common/type.h" +#include "common/value.h" +#include "common/value_testing.h" +#include "internal/testing.h" +#include "google/protobuf/io/zero_copy_stream_impl_lite.h" + +namespace cel { +namespace { + +using ::absl_testing::StatusIs; +using ::cel::test::BoolValueIs; +using ::cel::test::DoubleValueIs; +using ::cel::test::DurationValueIs; +using ::cel::test::ErrorValueIs; +using ::cel::test::IntValueIs; +using ::cel::test::IsNullValue; +using ::cel::test::StringValueIs; +using ::cel::test::TimestampValueIs; +using ::cel::test::UintValueIs; + +class OptionalValueTest : public common_internal::ValueTest<> { + public: + OptionalValue OptionalNone() { return OptionalValue::None(); } + + OptionalValue OptionalOf(Value value) { + return OptionalValue::Of(std::move(value), arena()); + } +}; + +TEST_F(OptionalValueTest, Kind) { + EXPECT_EQ(OptionalValue::kind(), OptionalValue::kKind); +} + +TEST_F(OptionalValueTest, GetRuntimeType) { + EXPECT_EQ(OptionalValue().GetRuntimeType(), OptionalType()); + EXPECT_EQ(OpaqueValue(OptionalValue()).GetRuntimeType(), OptionalType()); +} + +TEST_F(OptionalValueTest, DebugString) { + EXPECT_EQ(OptionalValue().DebugString(), "optional.none()"); + EXPECT_EQ(OptionalOf(NullValue()).DebugString(), "optional.of(null)"); + EXPECT_EQ(OptionalOf(TrueValue()).DebugString(), "optional.of(true)"); + EXPECT_EQ(OptionalOf(IntValue(1)).DebugString(), "optional.of(1)"); + EXPECT_EQ(OptionalOf(UintValue(1u)).DebugString(), "optional.of(1u)"); + EXPECT_EQ(OptionalOf(DoubleValue(1.0)).DebugString(), "optional.of(1.0)"); + EXPECT_EQ(OptionalOf(DurationValue()).DebugString(), "optional.of(0)"); + EXPECT_EQ(OptionalOf(TimestampValue()).DebugString(), + "optional.of(1970-01-01T00:00:00Z)"); + EXPECT_EQ(OptionalOf(StringValue()).DebugString(), "optional.of(\"\")"); +} + +TEST_F(OptionalValueTest, SerializeTo) { + google::protobuf::io::CordOutputStream output; + EXPECT_THAT(OptionalValue().SerializeTo(descriptor_pool(), message_factory(), + &output), + StatusIs(absl::StatusCode::kFailedPrecondition)); + EXPECT_THAT(OpaqueValue(OptionalValue()) + .SerializeTo(descriptor_pool(), message_factory(), &output), + StatusIs(absl::StatusCode::kFailedPrecondition)); +} + +TEST_F(OptionalValueTest, ConvertToJson) { + auto* message = NewArenaValueMessage(); + EXPECT_THAT(OptionalValue().ConvertToJson(descriptor_pool(), + message_factory(), message), + StatusIs(absl::StatusCode::kFailedPrecondition)); + EXPECT_THAT(OpaqueValue(OptionalValue()) + .ConvertToJson(descriptor_pool(), message_factory(), message), + StatusIs(absl::StatusCode::kFailedPrecondition)); +} + +TEST_F(OptionalValueTest, GetTypeId) { + EXPECT_EQ(OpaqueValue(OptionalValue()).GetTypeId(), + NativeTypeId::For()); + EXPECT_EQ(OpaqueValue(OptionalOf(NullValue())).GetTypeId(), + NativeTypeId::For()); + EXPECT_EQ(OpaqueValue(OptionalOf(TrueValue())).GetTypeId(), + NativeTypeId::For()); + EXPECT_EQ(OpaqueValue(OptionalOf(IntValue(1))).GetTypeId(), + NativeTypeId::For()); + EXPECT_EQ(OpaqueValue(OptionalOf(UintValue(1u))).GetTypeId(), + NativeTypeId::For()); + EXPECT_EQ(OpaqueValue(OptionalOf(DoubleValue(1.0))).GetTypeId(), + NativeTypeId::For()); + EXPECT_EQ(OpaqueValue(OptionalOf(DurationValue())).GetTypeId(), + NativeTypeId::For()); + EXPECT_EQ(OpaqueValue(OptionalOf(TimestampValue())).GetTypeId(), + NativeTypeId::For()); + EXPECT_EQ(OpaqueValue(OptionalOf(StringValue())).GetTypeId(), + NativeTypeId::For()); +} + +TEST_F(OptionalValueTest, HasValue) { + EXPECT_FALSE(OptionalValue().HasValue()); + EXPECT_TRUE(OptionalOf(NullValue()).HasValue()); + EXPECT_TRUE(OptionalOf(TrueValue()).HasValue()); + EXPECT_TRUE(OptionalOf(IntValue(1)).HasValue()); + EXPECT_TRUE(OptionalOf(UintValue(1u)).HasValue()); + EXPECT_TRUE(OptionalOf(DoubleValue(1.0)).HasValue()); + EXPECT_TRUE(OptionalOf(DurationValue()).HasValue()); + EXPECT_TRUE(OptionalOf(TimestampValue()).HasValue()); + EXPECT_TRUE(OptionalOf(StringValue()).HasValue()); +} + +TEST_F(OptionalValueTest, Value) { + EXPECT_THAT(OptionalValue().Value(), + ErrorValueIs(StatusIs(absl::StatusCode::kFailedPrecondition))); + EXPECT_THAT(OptionalOf(NullValue()).Value(), IsNullValue()); + EXPECT_THAT(OptionalOf(TrueValue()).Value(), BoolValueIs(true)); + EXPECT_THAT(OptionalOf(IntValue(1)).Value(), IntValueIs(1)); + EXPECT_THAT(OptionalOf(UintValue(1u)).Value(), UintValueIs(1u)); + EXPECT_THAT(OptionalOf(DoubleValue(1.0)).Value(), DoubleValueIs(1.0)); + EXPECT_THAT(OptionalOf(DurationValue()).Value(), + DurationValueIs(absl::ZeroDuration())); + EXPECT_THAT(OptionalOf(TimestampValue()).Value(), + TimestampValueIs(absl::UnixEpoch())); + EXPECT_THAT(OptionalOf(StringValue()).Value(), StringValueIs("")); +} + +} // namespace +} // namespace cel diff --git a/common/values/parsed_json_list_value.cc b/common/values/parsed_json_list_value.cc new file mode 100644 index 000000000..9acd23e3f --- /dev/null +++ b/common/values/parsed_json_list_value.cc @@ -0,0 +1,486 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/values/parsed_json_list_value.h" + +#include +#include +#include + +#include "google/protobuf/struct.pb.h" +#include "absl/base/nullability.h" +#include "absl/base/optimization.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/str_cat.h" +#include "absl/types/optional.h" +#include "absl/types/variant.h" +#include "common/memory.h" +#include "common/value.h" +#include "common/values/parsed_json_value.h" +#include "common/values/values.h" +#include "internal/json.h" +#include "internal/message_equality.h" +#include "internal/number.h" +#include "internal/status_macros.h" +#include "internal/well_known_types.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/message.h" + +namespace cel { + +using ::cel::well_known_types::ValueReflection; + +namespace common_internal { + +absl::Status CheckWellKnownListValueMessage(const google::protobuf::Message& message) { + return internal::CheckJsonList(message); +} + +} // namespace common_internal + +std::string ParsedJsonListValue::DebugString() const { + if (value_ == nullptr) { + return "[]"; + } + return internal::JsonListDebugString(*value_); +} + +absl::Status ParsedJsonListValue::SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(output != nullptr); + + if (value_ == nullptr) { + return absl::OkStatus(); + } + + if (!value_->SerializePartialToZeroCopyStream(output)) { + return absl::UnknownError( + "failed to serialize message: google.protobuf.ListValue"); + } + return absl::OkStatus(); +} + +absl::Status ParsedJsonListValue::ConvertToJson( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(json != nullptr); + ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE); + + ValueReflection value_reflection; + CEL_RETURN_IF_ERROR(value_reflection.Initialize(json->GetDescriptor())); + auto* message = value_reflection.MutableListValue(json); + message->Clear(); + + if (value_ == nullptr) { + return absl::OkStatus(); + } + + if (value_->GetDescriptor() == message->GetDescriptor()) { + // We can directly use google::protobuf::Message::Copy(). + message->CopyFrom(*value_); + } else { + // Equivalent descriptors but not identical. Must serialize and deserialize. + absl::Cord serialized; + if (!value_->SerializePartialToString(&serialized)) { + return absl::UnknownError( + absl::StrCat("failed to serialize message: ", value_->GetTypeName())); + } + if (!message->ParsePartialFromString(serialized)) { + return absl::UnknownError( + absl::StrCat("failed to parsed message: ", message->GetTypeName())); + } + } + return absl::OkStatus(); +} + +absl::Status ParsedJsonListValue::ConvertToJsonArray( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(json != nullptr); + ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE); + + if (value_ == nullptr) { + json->Clear(); + return absl::OkStatus(); + } + + if (value_->GetDescriptor() == json->GetDescriptor()) { + // We can directly use google::protobuf::Message::Copy(). + json->CopyFrom(*value_); + } else { + // Equivalent descriptors but not identical. Must serialize and deserialize. + absl::Cord serialized; + if (!value_->SerializePartialToString(&serialized)) { + return absl::UnknownError( + absl::StrCat("failed to serialize message: ", value_->GetTypeName())); + } + if (!json->ParsePartialFromString(serialized)) { + return absl::UnknownError( + absl::StrCat("failed to parsed message: ", json->GetTypeName())); + } + } + return absl::OkStatus(); +} + +absl::Status ParsedJsonListValue::Equal( + const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + if (auto other_value = other.AsParsedJsonList(); other_value) { + *result = BoolValue(*this == *other_value); + return absl::OkStatus(); + } + if (auto other_value = other.AsParsedRepeatedField(); other_value) { + if (value_ == nullptr) { + *result = BoolValue(other_value->IsEmpty()); + return absl::OkStatus(); + } + CEL_ASSIGN_OR_RETURN( + auto equal, internal::MessageFieldEquals( + *value_, *other_value->message_, other_value->field_, + descriptor_pool, message_factory)); + *result = BoolValue(equal); + return absl::OkStatus(); + } + if (auto other_value = other.AsList(); other_value) { + return common_internal::ListValueEqual(ListValue(*this), *other_value, + descriptor_pool, message_factory, + arena, result); + } + *result = BoolValue(false); + return absl::OkStatus(); +} + +ParsedJsonListValue ParsedJsonListValue::Clone( + google::protobuf::Arena* absl_nonnull arena) const { + ABSL_DCHECK(arena != nullptr); + + if (value_ == nullptr) { + return ParsedJsonListValue(); + } + if (arena_ == arena) { + return *this; + } + auto* cloned = value_->New(arena); + cloned->CopyFrom(*value_); + return ParsedJsonListValue(cloned, arena); +} + +size_t ParsedJsonListValue::Size() const { + if (value_ == nullptr) { + return 0; + } + return static_cast( + well_known_types::GetListValueReflectionOrDie(value_->GetDescriptor()) + .ValuesSize(*value_)); +} + +// See ListValueInterface::Get for documentation. +absl::Status ParsedJsonListValue::Get( + size_t index, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + if (value_ == nullptr) { + *result = IndexOutOfBoundsError(index); + return absl::OkStatus(); + } + const auto reflection = + well_known_types::GetListValueReflectionOrDie(value_->GetDescriptor()); + if (ABSL_PREDICT_FALSE(index >= + static_cast(reflection.ValuesSize(*value_)))) { + *result = IndexOutOfBoundsError(index); + return absl::OkStatus(); + } + *result = common_internal::ParsedJsonValue( + &reflection.Values(*value_, static_cast(index)), arena); + return absl::OkStatus(); +} + +absl::Status ParsedJsonListValue::ForEach( + ForEachWithIndexCallback callback, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + + if (value_ == nullptr) { + return absl::OkStatus(); + } + Value scratch; + const auto reflection = + well_known_types::GetListValueReflectionOrDie(value_->GetDescriptor()); + const int size = reflection.ValuesSize(*value_); + for (int i = 0; i < size; ++i) { + scratch = + common_internal::ParsedJsonValue(&reflection.Values(*value_, i), arena); + CEL_ASSIGN_OR_RETURN(auto ok, callback(static_cast(i), scratch)); + if (!ok) { + break; + } + } + return absl::OkStatus(); +} + +namespace { + +class ParsedJsonListValueIterator final : public ValueIterator { + public: + explicit ParsedJsonListValueIterator( + const google::protobuf::Message* absl_nonnull message) + : message_(message), + reflection_(well_known_types::GetListValueReflectionOrDie( + message_->GetDescriptor())), + size_(reflection_.ValuesSize(*message_)) {} + + bool HasNext() override { return index_ < size_; } + + absl::Status Next(const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) override { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + if (ABSL_PREDICT_FALSE(index_ >= size_)) { + return absl::FailedPreconditionError( + "`ValueIterator::Next` called after `ValueIterator::HasNext` " + "returned false"); + } + *result = common_internal::ParsedJsonValue( + &reflection_.Values(*message_, index_), arena); + ++index_; + return absl::OkStatus(); + } + + absl::StatusOr Next1( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull key_or_value) override { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(key_or_value != nullptr); + + if (index_ >= size_) { + return false; + } + *key_or_value = common_internal::ParsedJsonValue( + &reflection_.Values(*message_, index_), arena); + ++index_; + return true; + } + + absl::StatusOr Next2( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull key, + Value* absl_nullable value) override { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(key != nullptr); + + if (index_ >= size_) { + return false; + } + if (value != nullptr) { + *value = common_internal::ParsedJsonValue( + &reflection_.Values(*message_, index_), arena); + } + *key = IntValue(index_); + ++index_; + return true; + } + + private: + const google::protobuf::Message* absl_nonnull const message_; + const well_known_types::ListValueReflection reflection_; + const int size_; + int index_ = 0; +}; + +} // namespace + +absl::StatusOr> +ParsedJsonListValue::NewIterator() const { + if (value_ == nullptr) { + return NewEmptyValueIterator(); + } + return std::make_unique(value_); +} + +namespace { + +absl::optional AsNumber(const Value& value) { + if (auto int_value = value.AsInt(); int_value) { + return internal::Number::FromInt64(*int_value); + } + if (auto uint_value = value.AsUint(); uint_value) { + return internal::Number::FromUint64(*uint_value); + } + if (auto double_value = value.AsDouble(); double_value) { + return internal::Number::FromDouble(*double_value); + } + return absl::nullopt; +} + +} // namespace + +absl::Status ParsedJsonListValue::Contains( + const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + if (value_ == nullptr) { + *result = FalseValue(); + return absl::OkStatus(); + } + if (ABSL_PREDICT_FALSE(other.IsError() || other.IsUnknown())) { + *result = other; + return absl::OkStatus(); + } + // Other must be comparable to `null`, `double`, `string`, `list`, or `map`. + const auto reflection = + well_known_types::GetListValueReflectionOrDie(value_->GetDescriptor()); + if (reflection.ValuesSize(*value_) > 0) { + const auto value_reflection = well_known_types::GetValueReflectionOrDie( + reflection.GetValueDescriptor()); + if (other.IsNull()) { + for (const auto& element : reflection.Values(*value_)) { + const auto element_kind_case = value_reflection.GetKindCase(element); + if (element_kind_case == google::protobuf::Value::KIND_NOT_SET || + element_kind_case == google::protobuf::Value::kNullValue) { + *result = TrueValue(); + return absl::OkStatus(); + } + } + } else if (const auto other_value = other.AsBool(); other_value) { + for (const auto& element : reflection.Values(*value_)) { + if (value_reflection.GetKindCase(element) == + google::protobuf::Value::kBoolValue && + value_reflection.GetBoolValue(element) == *other_value) { + *result = TrueValue(); + return absl::OkStatus(); + } + } + } else if (const auto other_value = AsNumber(other); other_value) { + for (const auto& element : reflection.Values(*value_)) { + if (value_reflection.GetKindCase(element) == + google::protobuf::Value::kNumberValue && + internal::Number::FromDouble( + value_reflection.GetNumberValue(element)) == *other_value) { + *result = TrueValue(); + return absl::OkStatus(); + } + } + } else if (const auto other_value = other.AsString(); other_value) { + std::string scratch; + for (const auto& element : reflection.Values(*value_)) { + if (value_reflection.GetKindCase(element) == + google::protobuf::Value::kStringValue && + absl::visit( + [&](const auto& alternative) -> bool { + return *other_value == alternative; + }, + well_known_types::AsVariant( + value_reflection.GetStringValue(element, scratch)))) { + *result = TrueValue(); + return absl::OkStatus(); + } + } + } else if (const auto other_value = other.AsList(); other_value) { + for (const auto& element : reflection.Values(*value_)) { + if (value_reflection.GetKindCase(element) == + google::protobuf::Value::kListValue) { + CEL_RETURN_IF_ERROR(other_value->Equal( + ParsedJsonListValue(&value_reflection.GetListValue(element), + arena), + descriptor_pool, message_factory, arena, result)); + if (result->IsTrue()) { + return absl::OkStatus(); + } + } + } + } else if (const auto other_value = other.AsMap(); other_value) { + for (const auto& element : reflection.Values(*value_)) { + if (value_reflection.GetKindCase(element) == + google::protobuf::Value::kStructValue) { + CEL_RETURN_IF_ERROR(other_value->Equal( + ParsedJsonMapValue(&value_reflection.GetStructValue(element), + arena), + descriptor_pool, message_factory, arena, result)); + if (result->IsTrue()) { + return absl::OkStatus(); + } + } + } + } + } + *result = FalseValue(); + return absl::OkStatus(); +} + +bool operator==(const ParsedJsonListValue& lhs, + const ParsedJsonListValue& rhs) { + if (cel::to_address(lhs.value_) == cel::to_address(rhs.value_)) { + return true; + } + if (cel::to_address(lhs.value_) == nullptr) { + return rhs.IsEmpty(); + } + if (cel::to_address(rhs.value_) == nullptr) { + return lhs.IsEmpty(); + } + return internal::JsonListEquals(*lhs.value_, *rhs.value_); +} + +} // namespace cel diff --git a/common/values/parsed_json_list_value.h b/common/values/parsed_json_list_value.h new file mode 100644 index 000000000..d4f6c6e02 --- /dev/null +++ b/common/values/parsed_json_list_value.h @@ -0,0 +1,229 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/value.h" +// IWYU pragma: friend "common/value.h" + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_PARSED_JSON_LIST_VALUE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_PARSED_JSON_LIST_VALUE_H_ + +#include +#include +#include +#include +#include + +#include "google/protobuf/any.pb.h" +#include "google/protobuf/struct.pb.h" +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "common/memory.h" +#include "common/type.h" +#include "common/value_kind.h" +#include "common/values/custom_list_value.h" +#include "common/values/values.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/message.h" + +namespace cel { + +class Value; +class ValueIterator; +class ParsedRepeatedFieldValue; + +namespace common_internal { +absl::Status CheckWellKnownListValueMessage(const google::protobuf::Message& message); +} // namespace common_internal + +// ParsedJsonListValue is a ListValue backed by the google.protobuf.ListValue +// well known message type. +class ParsedJsonListValue final + : private common_internal::ListValueMixin { + public: + static constexpr ValueKind kKind = ValueKind::kList; + static constexpr absl::string_view kName = "google.protobuf.ListValue"; + + using element_type = const google::protobuf::Message; + + ParsedJsonListValue( + const google::protobuf::Message* absl_nonnull value ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND) + : value_(value), arena_(arena) { + ABSL_DCHECK(value != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK_OK(CheckListValue(value_)); + ABSL_DCHECK_OK(CheckArena(value_, arena_)); + } + + // Constructs an empty `ParsedJsonListValue`. + ParsedJsonListValue() = default; + ParsedJsonListValue(const ParsedJsonListValue&) = default; + ParsedJsonListValue(ParsedJsonListValue&&) = default; + ParsedJsonListValue& operator=(const ParsedJsonListValue&) = default; + ParsedJsonListValue& operator=(ParsedJsonListValue&&) = default; + + static ValueKind kind() { return kKind; } + + static absl::string_view GetTypeName() { return kName; } + + static ListType GetRuntimeType() { return JsonListType(); } + + const google::protobuf::Message& operator*() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + ABSL_DCHECK(*this); + return *value_; + } + + const google::protobuf::Message* absl_nonnull operator->() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + ABSL_DCHECK(*this); + return value_; + } + + std::string DebugString() const; + + // See Value::SerializeTo(). + absl::Status SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const; + + // See Value::ConvertToJson(). + absl::Status ConvertToJson( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const; + + // See Value::ConvertToJsonArray(). + absl::Status ConvertToJsonArray( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const; + + absl::Status Equal(const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const; + using ListValueMixin::Equal; + + bool IsZeroValue() const { return IsEmpty(); } + + ParsedJsonListValue Clone(google::protobuf::Arena* absl_nonnull arena) const; + + bool IsEmpty() const { return Size() == 0; } + + size_t Size() const; + + // See ListValueInterface::Get for documentation. + absl::Status Get(size_t index, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const; + using ListValueMixin::Get; + + using ForEachCallback = typename CustomListValueInterface::ForEachCallback; + + using ForEachWithIndexCallback = + typename CustomListValueInterface::ForEachWithIndexCallback; + + absl::Status ForEach( + ForEachWithIndexCallback callback, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const; + using ListValueMixin::ForEach; + + absl::StatusOr NewIterator() const; + + absl::Status Contains( + const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; + using ListValueMixin::Contains; + + explicit operator bool() const { return value_ != nullptr; } + + friend void swap(ParsedJsonListValue& lhs, + ParsedJsonListValue& rhs) noexcept { + using std::swap; + swap(lhs.value_, rhs.value_); + swap(lhs.arena_, rhs.arena_); + } + + friend bool operator==(const ParsedJsonListValue& lhs, + const ParsedJsonListValue& rhs); + + private: + friend std::pointer_traits; + friend class ParsedRepeatedFieldValue; + friend class common_internal::ValueMixin; + friend class common_internal::ListValueMixin; + + static absl::Status CheckListValue( + const google::protobuf::Message* absl_nullable message) { + return message == nullptr + ? absl::OkStatus() + : common_internal::CheckWellKnownListValueMessage(*message); + } + + static absl::Status CheckArena(const google::protobuf::Message* absl_nullable message, + google::protobuf::Arena* absl_nonnull arena) { + if (message != nullptr && message->GetArena() != nullptr && + message->GetArena() != arena) { + return absl::InvalidArgumentError( + "message arena must be the same as arena"); + } + return absl::OkStatus(); + } + + const google::protobuf::Message* absl_nullable value_ = nullptr; + google::protobuf::Arena* absl_nullable arena_ = nullptr; +}; + +inline bool operator!=(const ParsedJsonListValue& lhs, + const ParsedJsonListValue& rhs) { + return !operator==(lhs, rhs); +} + +inline std::ostream& operator<<(std::ostream& out, + const ParsedJsonListValue& value) { + return out << value.DebugString(); +} + +} // namespace cel + +namespace std { + +template <> +struct pointer_traits { + using pointer = cel::ParsedJsonListValue; + using element_type = typename cel::ParsedJsonListValue::element_type; + using difference_type = ptrdiff_t; + + static element_type* to_address(const pointer& p) noexcept { + return cel::to_address(p.value_); + } +}; + +} // namespace std + +#endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_PARSED_JSON_LIST_VALUE_H_ diff --git a/common/values/parsed_json_list_value_test.cc b/common/values/parsed_json_list_value_test.cc new file mode 100644 index 000000000..017a24f9d --- /dev/null +++ b/common/values/parsed_json_list_value_test.cc @@ -0,0 +1,289 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "google/protobuf/struct.pb.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "common/memory.h" +#include "common/type.h" +#include "common/value.h" +#include "common/value_kind.h" +#include "common/value_testing.h" +#include "internal/testing.h" +#include "cel/expr/conformance/proto3/test_all_types.pb.h" +#include "google/protobuf/io/zero_copy_stream_impl_lite.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::IsOkAndHolds; +using ::absl_testing::StatusIs; +using ::cel::test::BoolValueIs; +using ::cel::test::ErrorValueIs; +using ::cel::test::IntValueIs; +using ::cel::test::IsNullValue; +using ::testing::ElementsAre; +using ::testing::Eq; +using ::testing::IsEmpty; +using ::testing::Optional; +using ::testing::Pair; + +using TestAllTypesProto3 = ::cel::expr::conformance::proto3::TestAllTypes; + +using ParsedJsonListValueTest = common_internal::ValueTest<>; + +TEST_F(ParsedJsonListValueTest, Kind) { + EXPECT_EQ(ParsedJsonListValue::kind(), ParsedJsonListValue::kKind); + EXPECT_EQ(ParsedJsonListValue::kind(), ValueKind::kList); +} + +TEST_F(ParsedJsonListValueTest, GetTypeName) { + EXPECT_EQ(ParsedJsonListValue::GetTypeName(), ParsedJsonListValue::kName); + EXPECT_EQ(ParsedJsonListValue::GetTypeName(), "google.protobuf.ListValue"); +} + +TEST_F(ParsedJsonListValueTest, GetRuntimeType) { + EXPECT_EQ(ParsedJsonListValue::GetRuntimeType(), JsonListType()); +} + +TEST_F(ParsedJsonListValueTest, DebugString_Dynamic) { + ParsedJsonListValue valid_value( + DynamicParseTextProto(R"pb()pb"), arena()); + EXPECT_EQ(valid_value.DebugString(), "[]"); +} + +TEST_F(ParsedJsonListValueTest, IsZeroValue_Dynamic) { + ParsedJsonListValue valid_value( + DynamicParseTextProto(R"pb()pb"), arena()); + EXPECT_TRUE(valid_value.IsZeroValue()); +} + +TEST_F(ParsedJsonListValueTest, SerializeTo_Dynamic) { + ParsedJsonListValue valid_value( + DynamicParseTextProto(R"pb()pb"), arena()); + google::protobuf::io::CordOutputStream output; + EXPECT_THAT( + valid_value.SerializeTo(descriptor_pool(), message_factory(), &output), + IsOk()); + EXPECT_THAT(std::move(output).Consume(), IsEmpty()); +} + +TEST_F(ParsedJsonListValueTest, ConvertToJson_Dynamic) { + auto json = DynamicParseTextProto(R"pb()pb"); + ParsedJsonListValue valid_value( + DynamicParseTextProto(R"pb()pb"), arena()); + EXPECT_THAT(valid_value.ConvertToJson(descriptor_pool(), message_factory(), + cel::to_address(json)), + IsOk()); + EXPECT_THAT( + *json, EqualsTextProto(R"pb(list_value: {})pb")); +} + +TEST_F(ParsedJsonListValueTest, Equal_Dynamic) { + ParsedJsonListValue valid_value( + DynamicParseTextProto(R"pb()pb"), arena()); + EXPECT_THAT(valid_value.Equal(BoolValue(), descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(BoolValueIs(false))); + EXPECT_THAT( + valid_value.Equal( + ParsedJsonListValue( + DynamicParseTextProto(R"pb()pb"), + arena()), + descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(BoolValueIs(true))); + EXPECT_THAT(valid_value.Equal(ListValue(), descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(BoolValueIs(true))); +} + +TEST_F(ParsedJsonListValueTest, Empty_Dynamic) { + ParsedJsonListValue valid_value( + DynamicParseTextProto(R"pb()pb"), arena()); + EXPECT_TRUE(valid_value.IsEmpty()); +} + +TEST_F(ParsedJsonListValueTest, Size_Dynamic) { + ParsedJsonListValue valid_value( + DynamicParseTextProto(R"pb()pb"), arena()); + EXPECT_EQ(valid_value.Size(), 0); +} + +TEST_F(ParsedJsonListValueTest, Get_Dynamic) { + ParsedJsonListValue valid_value( + DynamicParseTextProto( + R"pb(values {} + values { bool_value: true })pb"), + arena()); + EXPECT_THAT(valid_value.Get(0, descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(IsNullValue())); + EXPECT_THAT(valid_value.Get(1, descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(BoolValueIs(true))); + EXPECT_THAT( + valid_value.Get(2, descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument)))); +} + +TEST_F(ParsedJsonListValueTest, ForEach_Dynamic) { + ParsedJsonListValue valid_value( + DynamicParseTextProto( + R"pb(values {} + values { bool_value: true })pb"), + arena()); + { + std::vector values; + EXPECT_THAT(valid_value.ForEach( + [&](const Value& element) -> absl::StatusOr { + values.push_back(element); + return true; + }, + descriptor_pool(), message_factory(), arena()), + IsOk()); + EXPECT_THAT(values, ElementsAre(IsNullValue(), BoolValueIs(true))); + } + { + std::vector values; + EXPECT_THAT(valid_value.ForEach( + [&](size_t, const Value& element) -> absl::StatusOr { + values.push_back(element); + return true; + }, + descriptor_pool(), message_factory(), arena()), + IsOk()); + EXPECT_THAT(values, ElementsAre(IsNullValue(), BoolValueIs(true))); + } +} + +TEST_F(ParsedJsonListValueTest, NewIterator_Dynamic) { + ParsedJsonListValue valid_value( + DynamicParseTextProto( + R"pb(values {} + values { bool_value: true })pb"), + arena()); + ASSERT_OK_AND_ASSIGN(auto iterator, valid_value.NewIterator()); + ASSERT_TRUE(iterator->HasNext()); + EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(IsNullValue())); + ASSERT_TRUE(iterator->HasNext()); + EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(BoolValueIs(true))); + ASSERT_FALSE(iterator->HasNext()); + EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()), + StatusIs(absl::StatusCode::kFailedPrecondition)); +} + +TEST_F(ParsedJsonListValueTest, NewIterator1) { + ParsedJsonListValue valid_value( + DynamicParseTextProto( + R"pb(values {} + values { bool_value: true })pb"), + arena()); + ASSERT_OK_AND_ASSIGN(auto iterator, valid_value.NewIterator()); + EXPECT_THAT(iterator->Next1(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(Optional(IsNullValue()))); + EXPECT_THAT(iterator->Next1(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(Optional(BoolValueIs(true)))); + EXPECT_THAT(iterator->Next1(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(Eq(absl::nullopt))); +} + +TEST_F(ParsedJsonListValueTest, NewIterator2) { + ParsedJsonListValue valid_value( + DynamicParseTextProto( + R"pb(values {} + values { bool_value: true })pb"), + arena()); + ASSERT_OK_AND_ASSIGN(auto iterator, valid_value.NewIterator()); + EXPECT_THAT(iterator->Next2(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(Optional(Pair(IntValueIs(0), IsNullValue())))); + EXPECT_THAT(iterator->Next2(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(Optional(Pair(IntValueIs(1), BoolValueIs(true))))); + EXPECT_THAT(iterator->Next2(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(Eq(absl::nullopt))); +} + +TEST_F(ParsedJsonListValueTest, Contains_Dynamic) { + ParsedJsonListValue valid_value( + DynamicParseTextProto( + R"pb(values {} + values { bool_value: true } + values { number_value: 1.0 } + values { string_value: "foo" } + values { list_value: {} } + values { struct_value: {} })pb"), + arena()); + EXPECT_THAT(valid_value.Contains(BytesValue(), descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(BoolValueIs(false))); + EXPECT_THAT(valid_value.Contains(NullValue(), descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(BoolValueIs(true))); + EXPECT_THAT(valid_value.Contains(BoolValue(false), descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(BoolValueIs(false))); + EXPECT_THAT(valid_value.Contains(BoolValue(true), descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(BoolValueIs(true))); + EXPECT_THAT(valid_value.Contains(DoubleValue(0.0), descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(BoolValueIs(false))); + EXPECT_THAT(valid_value.Contains(DoubleValue(1.0), descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(BoolValueIs(true))); + EXPECT_THAT(valid_value.Contains(StringValue("bar"), descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(BoolValueIs(false))); + EXPECT_THAT(valid_value.Contains(StringValue("foo"), descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(BoolValueIs(true))); + EXPECT_THAT(valid_value.Contains( + ParsedJsonListValue( + DynamicParseTextProto( + R"pb(values {} + values { bool_value: true } + values { number_value: 1.0 } + values { string_value: "foo" } + values { list_value: {} } + values { struct_value: {} })pb"), + arena()), + descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(BoolValueIs(false))); + EXPECT_THAT(valid_value.Contains(ListValue(), descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(BoolValueIs(true))); + EXPECT_THAT( + valid_value.Contains( + ParsedJsonMapValue(DynamicParseTextProto( + R"pb(fields { + key: "foo" + value: { bool_value: true } + })pb"), + arena()), + descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(BoolValueIs(false))); + EXPECT_THAT(valid_value.Contains(MapValue(), descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(BoolValueIs(true))); +} + +} // namespace +} // namespace cel diff --git a/common/values/parsed_json_map_value.cc b/common/values/parsed_json_map_value.cc new file mode 100644 index 000000000..ec8c91a4f --- /dev/null +++ b/common/values/parsed_json_map_value.cc @@ -0,0 +1,439 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/values/parsed_json_map_value.h" + +#include +#include +#include +#include + +#include "google/protobuf/struct.pb.h" +#include "absl/base/nullability.h" +#include "absl/base/optimization.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "common/allocator.h" +#include "common/memory.h" +#include "common/value.h" +#include "common/values/parsed_json_value.h" +#include "common/values/values.h" +#include "internal/json.h" +#include "internal/message_equality.h" +#include "internal/status_macros.h" +#include "internal/well_known_types.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/map.h" +#include "google/protobuf/map_field.h" +#include "google/protobuf/message.h" +#include "google/protobuf/message_lite.h" + +namespace cel { + +using ::cel::well_known_types::ValueReflection; + +namespace common_internal { + +absl::Status CheckWellKnownStructMessage(const google::protobuf::Message& message) { + return internal::CheckJsonMap(message); +} + +} // namespace common_internal + +std::string ParsedJsonMapValue::DebugString() const { + if (value_ == nullptr) { + return "{}"; + } + return internal::JsonMapDebugString(*value_); +} + +absl::Status ParsedJsonMapValue::SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(output != nullptr); + + if (value_ == nullptr) { + return absl::OkStatus(); + } + + if (!value_->SerializePartialToZeroCopyStream(output)) { + return absl::UnknownError( + "failed to serialize message: google.protobuf.Struct"); + } + return absl::OkStatus(); +} + +absl::Status ParsedJsonMapValue::ConvertToJson( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(json != nullptr); + ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE); + + ValueReflection value_reflection; + CEL_RETURN_IF_ERROR(value_reflection.Initialize(json->GetDescriptor())); + auto* message = value_reflection.MutableStructValue(json); + message->Clear(); + + if (value_ == nullptr) { + return absl::OkStatus(); + } + + if (value_->GetDescriptor() == message->GetDescriptor()) { + // We can directly use google::protobuf::Message::Copy(). + message->CopyFrom(*value_); + } else { + // Equivalent descriptors but not identical. Must serialize and deserialize. + absl::Cord serialized; + if (!value_->SerializePartialToString(&serialized)) { + return absl::UnknownError( + absl::StrCat("failed to serialize message: ", value_->GetTypeName())); + } + if (!message->ParsePartialFromString(serialized)) { + return absl::UnknownError( + absl::StrCat("failed to parsed message: ", message->GetTypeName())); + } + } + return absl::OkStatus(); +} + +absl::Status ParsedJsonMapValue::ConvertToJsonObject( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(json != nullptr); + ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT); + + if (value_ == nullptr) { + json->Clear(); + return absl::OkStatus(); + } + + if (value_->GetDescriptor() == json->GetDescriptor()) { + // We can directly use google::protobuf::Message::Copy(). + json->CopyFrom(*value_); + } else { + // Equivalent descriptors but not identical. Must serialize and deserialize. + absl::Cord serialized; + if (!value_->SerializePartialToString(&serialized)) { + return absl::UnknownError( + absl::StrCat("failed to serialize message: ", value_->GetTypeName())); + } + if (!json->ParsePartialFromString(serialized)) { + return absl::UnknownError( + absl::StrCat("failed to parsed message: ", json->GetTypeName())); + } + } + return absl::OkStatus(); +} + +absl::Status ParsedJsonMapValue::Equal( + const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + if (auto other_value = other.AsParsedJsonMap(); other_value) { + *result = BoolValue(*this == *other_value); + return absl::OkStatus(); + } + if (auto other_value = other.AsParsedMapField(); other_value) { + if (value_ == nullptr) { + *result = BoolValue(other_value->IsEmpty()); + return absl::OkStatus(); + } + ABSL_DCHECK(other_value->field_ != nullptr); + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + CEL_ASSIGN_OR_RETURN( + auto equal, internal::MessageFieldEquals( + *value_, *other_value->message_, other_value->field_, + descriptor_pool, message_factory)); + *result = BoolValue(equal); + return absl::OkStatus(); + } + if (auto other_value = other.AsMap(); other_value) { + return common_internal::MapValueEqual(MapValue(*this), *other_value, + descriptor_pool, message_factory, + arena, result); + } + *result = FalseValue(); + return absl::OkStatus(); +} + +ParsedJsonMapValue ParsedJsonMapValue::Clone( + google::protobuf::Arena* absl_nonnull arena) const { + ABSL_DCHECK(arena != nullptr); + + if (value_ == nullptr) { + return ParsedJsonMapValue(); + } + if (arena_ == arena) { + return *this; + } + auto* cloned = value_->New(arena); + cloned->CopyFrom(*value_); + return ParsedJsonMapValue(cloned, arena); +} + +size_t ParsedJsonMapValue::Size() const { + if (value_ == nullptr) { + return 0; + } + return static_cast( + well_known_types::GetStructReflectionOrDie(value_->GetDescriptor()) + .FieldsSize(*value_)); +} + +absl::Status ParsedJsonMapValue::Get( + const Value& key, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + CEL_ASSIGN_OR_RETURN( + bool ok, Find(key, descriptor_pool, message_factory, arena, result)); + if (ABSL_PREDICT_FALSE(!ok) && !(result->IsError() || result->IsUnknown())) { + *result = NoSuchKeyError(key.DebugString()); + } + return absl::OkStatus(); +} + +absl::StatusOr ParsedJsonMapValue::Find( + const Value& key, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + if (key.IsError() || key.IsUnknown()) { + *result = key; + return false; + } + if (value_ != nullptr) { + if (auto string_key = key.AsString(); string_key) { + if (ABSL_PREDICT_FALSE(value_ == nullptr)) { + *result = NullValue(); + return false; + } + std::string key_scratch; + if (const auto* value = + well_known_types::GetStructReflectionOrDie( + value_->GetDescriptor()) + .FindField(*value_, string_key->NativeString(key_scratch)); + value != nullptr) { + *result = common_internal::ParsedJsonValue(value, arena); + return true; + } + *result = NullValue(); + return false; + } + } + *result = NullValue(); + return false; +} + +absl::Status ParsedJsonMapValue::Has( + const Value& key, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + if (key.IsError() || key.IsUnknown()) { + *result = key; + return absl::OkStatus(); + } + if (value_ != nullptr) { + if (auto string_key = key.AsString(); string_key) { + if (ABSL_PREDICT_FALSE(value_ == nullptr)) { + *result = FalseValue(); + return absl::OkStatus(); + } + std::string key_scratch; + if (const auto* value = + well_known_types::GetStructReflectionOrDie( + value_->GetDescriptor()) + .FindField(*value_, string_key->NativeString(key_scratch)); + value != nullptr) { + *result = TrueValue(); + } else { + *result = FalseValue(); + } + return absl::OkStatus(); + } + } + *result = FalseValue(); + return absl::OkStatus(); +} + +absl::Status ParsedJsonMapValue::ListKeys( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, ListValue* absl_nonnull result) const { + if (value_ == nullptr) { + *result = ListValue(); + return absl::OkStatus(); + } + const auto reflection = + well_known_types::GetStructReflectionOrDie(value_->GetDescriptor()); + auto builder = NewListValueBuilder(arena); + builder->Reserve(static_cast(reflection.FieldsSize(*value_))); + auto keys_begin = reflection.BeginFields(*value_); + const auto keys_end = reflection.EndFields(*value_); + for (; keys_begin != keys_end; ++keys_begin) { + CEL_RETURN_IF_ERROR(builder->Add( + Value::WrapMapFieldKeyString(keys_begin.GetKey(), value_, arena))); + } + *result = std::move(*builder).Build(); + return absl::OkStatus(); +} + +absl::Status ParsedJsonMapValue::ForEach( + ForEachCallback callback, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const { + if (value_ == nullptr) { + return absl::OkStatus(); + } + const auto reflection = + well_known_types::GetStructReflectionOrDie(value_->GetDescriptor()); + Value key_scratch; + Value value_scratch; + auto map_begin = reflection.BeginFields(*value_); + const auto map_end = reflection.EndFields(*value_); + for (; map_begin != map_end; ++map_begin) { + // We have to copy until `google::protobuf::MapKey` is just a view. + key_scratch = StringValue(arena, map_begin.GetKey().GetStringValue()); + value_scratch = common_internal::ParsedJsonValue( + &map_begin.GetValueRef().GetMessageValue(), arena); + CEL_ASSIGN_OR_RETURN(auto ok, callback(key_scratch, value_scratch)); + if (!ok) { + break; + } + } + return absl::OkStatus(); +} + +namespace { + +class ParsedJsonMapValueIterator final : public ValueIterator { + public: + explicit ParsedJsonMapValueIterator( + const google::protobuf::Message* absl_nonnull message) + : message_(message), + reflection_(well_known_types::GetStructReflectionOrDie( + message_->GetDescriptor())), + begin_(reflection_.BeginFields(*message_)), + end_(reflection_.EndFields(*message_)) {} + + bool HasNext() override { return begin_ != end_; } + + absl::Status Next(const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) override { + if (ABSL_PREDICT_FALSE(begin_ == end_)) { + return absl::FailedPreconditionError( + "`ValueIterator::Next` called after `ValueIterator::HasNext` " + "returned false"); + } + *result = Value::WrapMapFieldKeyString(begin_.GetKey(), message_, arena); + ++begin_; + return absl::OkStatus(); + } + + absl::StatusOr Next1( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull key_or_value) override { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(key_or_value != nullptr); + + if (begin_ == end_) { + return false; + } + *key_or_value = + Value::WrapMapFieldKeyString(begin_.GetKey(), message_, arena); + ++begin_; + return true; + } + + absl::StatusOr Next2( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull key, + Value* absl_nullable value) override { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(key != nullptr); + + if (begin_ == end_) { + return false; + } + *key = Value::WrapMapFieldKeyString(begin_.GetKey(), message_, arena); + if (value != nullptr) { + *value = common_internal::ParsedJsonValue( + &begin_.GetValueRef().GetMessageValue(), arena); + } + ++begin_; + return true; + } + + private: + const google::protobuf::Message* absl_nonnull const message_; + const well_known_types::StructReflection reflection_; + google::protobuf::ConstMapIterator begin_; + const google::protobuf::ConstMapIterator end_; + std::string scratch_; +}; + +} // namespace + +absl::StatusOr> +ParsedJsonMapValue::NewIterator() const { + if (value_ == nullptr) { + return NewEmptyValueIterator(); + } + return std::make_unique(value_); +} + +bool operator==(const ParsedJsonMapValue& lhs, const ParsedJsonMapValue& rhs) { + if (cel::to_address(lhs.value_) == cel::to_address(rhs.value_)) { + return true; + } + if (cel::to_address(lhs.value_) == nullptr) { + return rhs.IsEmpty(); + } + if (cel::to_address(rhs.value_) == nullptr) { + return lhs.IsEmpty(); + } + return internal::JsonMapEquals(*lhs.value_, *rhs.value_); +} + +} // namespace cel diff --git a/common/values/parsed_json_map_value.h b/common/values/parsed_json_map_value.h new file mode 100644 index 000000000..ba8d3490d --- /dev/null +++ b/common/values/parsed_json_map_value.h @@ -0,0 +1,250 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/value.h" +// IWYU pragma: friend "common/value.h" + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_PARSED_JSON_MAP_VALUE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_PARSED_JSON_MAP_VALUE_H_ + +#include +#include +#include +#include +#include + +#include "google/protobuf/any.pb.h" +#include "google/protobuf/struct.pb.h" +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "common/memory.h" +#include "common/type.h" +#include "common/value_kind.h" +#include "common/values/custom_map_value.h" +#include "common/values/values.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/message.h" + +namespace cel { + +class Value; +class ListValue; +class ValueIterator; +class ParsedMapFieldValue; + +namespace common_internal { +absl::Status CheckWellKnownStructMessage(const google::protobuf::Message& message); +} // namespace common_internal + +// ParsedJsonMapValue is a MapValue backed by the google.protobuf.Struct +// well known message type. +class ParsedJsonMapValue final + : private common_internal::MapValueMixin { + public: + static constexpr ValueKind kKind = ValueKind::kMap; + static constexpr absl::string_view kName = "google.protobuf.Struct"; + + using element_type = const google::protobuf::Message; + + ParsedJsonMapValue( + const google::protobuf::Message* absl_nonnull value ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND) + : value_(value), arena_(arena) { + ABSL_DCHECK(value != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK_OK(CheckStruct(value_)); + ABSL_DCHECK_OK(CheckArena(value_, arena_)); + } + + // Constructs an empty `ParsedJsonMapValue`. + ParsedJsonMapValue() = default; + ParsedJsonMapValue(const ParsedJsonMapValue&) = default; + ParsedJsonMapValue(ParsedJsonMapValue&&) = default; + ParsedJsonMapValue& operator=(const ParsedJsonMapValue&) = default; + ParsedJsonMapValue& operator=(ParsedJsonMapValue&&) = default; + + static constexpr ValueKind kind() { return kKind; } + + static absl::string_view GetTypeName() { return kName; } + + static MapType GetRuntimeType() { return JsonMapType(); } + + const google::protobuf::Message& operator*() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + ABSL_DCHECK(*this); + return *value_; + } + + const google::protobuf::Message* absl_nonnull operator->() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + ABSL_DCHECK(*this); + return value_; + } + + std::string DebugString() const; + + // See Value::SerializeTo(). + absl::Status SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const; + + // See Value::ConvertToJson(). + absl::Status ConvertToJson( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const; + + // See Value::ConvertToJsonObject(). + absl::Status ConvertToJsonObject( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const; + + absl::Status Equal(const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const; + using MapValueMixin::Equal; + + bool IsZeroValue() const { return IsEmpty(); } + + ParsedJsonMapValue Clone(google::protobuf::Arena* absl_nonnull arena) const; + + bool IsEmpty() const { return Size() == 0; } + + size_t Size() const; + + // See the corresponding member function of `MapValue` for + // documentation. + absl::Status Get(const Value& key, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const; + using MapValueMixin::Get; + + // See the corresponding member function of `MapValue` for + // documentation. + absl::StatusOr Find( + const Value& key, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; + using MapValueMixin::Find; + + // See the corresponding member function of `MapValue` for + // documentation. + absl::Status Has(const Value& key, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const; + using MapValueMixin::Has; + + // See the corresponding member function of `MapValue` for + // documentation. + absl::Status ListKeys( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, ListValue* absl_nonnull result) const; + using MapValueMixin::ListKeys; + + // See the corresponding type declaration of `MapValue` for + // documentation. + using ForEachCallback = typename CustomMapValueInterface::ForEachCallback; + + // See the corresponding member function of `MapValue` for + // documentation. + absl::Status ForEach( + ForEachCallback callback, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const; + + absl::StatusOr> NewIterator() + const; + + explicit operator bool() const { return value_ != nullptr; } + + friend void swap(ParsedJsonMapValue& lhs, ParsedJsonMapValue& rhs) noexcept { + using std::swap; + swap(lhs.value_, rhs.value_); + swap(lhs.arena_, rhs.arena_); + } + + friend bool operator==(const ParsedJsonMapValue& lhs, + const ParsedJsonMapValue& rhs); + + private: + friend std::pointer_traits; + friend class ParsedMapFieldValue; + friend class common_internal::ValueMixin; + friend class common_internal::MapValueMixin; + + static absl::Status CheckStruct( + const google::protobuf::Message* absl_nullable message) { + return message == nullptr + ? absl::OkStatus() + : common_internal::CheckWellKnownStructMessage(*message); + } + + static absl::Status CheckArena(const google::protobuf::Message* absl_nullable message, + google::protobuf::Arena* absl_nonnull arena) { + if (message != nullptr && message->GetArena() != nullptr && + message->GetArena() != arena) { + return absl::InvalidArgumentError( + "message arena must be the same as arena"); + } + return absl::OkStatus(); + } + + const google::protobuf::Message* absl_nullable value_ = nullptr; + google::protobuf::Arena* absl_nullable arena_ = nullptr; +}; + +inline bool operator!=(const ParsedJsonMapValue& lhs, + const ParsedJsonMapValue& rhs) { + return !operator==(lhs, rhs); +} + +inline std::ostream& operator<<(std::ostream& out, + const ParsedJsonMapValue& value) { + return out << value.DebugString(); +} + +} // namespace cel + +namespace std { + +template <> +struct pointer_traits { + using pointer = cel::ParsedJsonMapValue; + using element_type = typename cel::ParsedJsonMapValue::element_type; + using difference_type = ptrdiff_t; + + static element_type* to_address(const pointer& p) noexcept { + return cel::to_address(p.value_); + } +}; + +} // namespace std + +#endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_PARSED_JSON_MAP_VALUE_H_ diff --git a/common/values/parsed_json_map_value_test.cc b/common/values/parsed_json_map_value_test.cc new file mode 100644 index 000000000..b65128076 --- /dev/null +++ b/common/values/parsed_json_map_value_test.cc @@ -0,0 +1,340 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "google/protobuf/struct.pb.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "common/memory.h" +#include "common/type.h" +#include "common/value.h" +#include "common/value_kind.h" +#include "common/value_testing.h" +#include "internal/testing.h" +#include "cel/expr/conformance/proto3/test_all_types.pb.h" +#include "google/protobuf/io/zero_copy_stream_impl_lite.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::IsOkAndHolds; +using ::absl_testing::StatusIs; +using ::cel::test::BoolValueIs; +using ::cel::test::ErrorValueIs; +using ::cel::test::IsNullValue; +using ::cel::test::StringValueIs; +using ::testing::AnyOf; +using ::testing::Eq; +using ::testing::IsEmpty; +using ::testing::Optional; +using ::testing::Pair; +using ::testing::UnorderedElementsAre; + +using TestAllTypesProto3 = ::cel::expr::conformance::proto3::TestAllTypes; + +using ParsedJsonMapValueTest = common_internal::ValueTest<>; + +TEST_F(ParsedJsonMapValueTest, Kind) { + EXPECT_EQ(ParsedJsonMapValue::kind(), ParsedJsonMapValue::kKind); + EXPECT_EQ(ParsedJsonMapValue::kind(), ValueKind::kMap); +} + +TEST_F(ParsedJsonMapValueTest, GetTypeName) { + EXPECT_EQ(ParsedJsonMapValue::GetTypeName(), ParsedJsonMapValue::kName); + EXPECT_EQ(ParsedJsonMapValue::GetTypeName(), "google.protobuf.Struct"); +} + +TEST_F(ParsedJsonMapValueTest, GetRuntimeType) { + EXPECT_EQ(ParsedJsonMapValue::GetRuntimeType(), JsonMapType()); +} + +TEST_F(ParsedJsonMapValueTest, DebugString_Dynamic) { + ParsedJsonMapValue valid_value( + DynamicParseTextProto(R"pb()pb"), arena()); + EXPECT_EQ(valid_value.DebugString(), "{}"); +} + +TEST_F(ParsedJsonMapValueTest, IsZeroValue_Dynamic) { + ParsedJsonMapValue valid_value( + DynamicParseTextProto(R"pb()pb"), arena()); + EXPECT_TRUE(valid_value.IsZeroValue()); +} + +TEST_F(ParsedJsonMapValueTest, SerializeTo_Dynamic) { + ParsedJsonMapValue valid_value( + DynamicParseTextProto(R"pb()pb"), arena()); + google::protobuf::io::CordOutputStream output; + EXPECT_THAT( + valid_value.SerializeTo(descriptor_pool(), message_factory(), &output), + IsOk()); + EXPECT_THAT(std::move(output).Consume(), IsEmpty()); +} + +TEST_F(ParsedJsonMapValueTest, ConvertToJson_Dynamic) { + auto json = DynamicParseTextProto(R"pb()pb"); + ParsedJsonMapValue valid_value( + DynamicParseTextProto(R"pb()pb"), arena()); + EXPECT_THAT(valid_value.ConvertToJson(descriptor_pool(), message_factory(), + cel::to_address(json)), + IsOk()); + EXPECT_THAT(*json, EqualsTextProto( + R"pb(struct_value: {})pb")); +} + +TEST_F(ParsedJsonMapValueTest, Equal_Dynamic) { + ParsedJsonMapValue valid_value( + DynamicParseTextProto(R"pb()pb"), arena()); + EXPECT_THAT(valid_value.Equal(BoolValue(), descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(BoolValueIs(false))); + EXPECT_THAT( + valid_value.Equal( + ParsedJsonMapValue( + DynamicParseTextProto(R"pb()pb"), + arena()), + descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(BoolValueIs(true))); + EXPECT_THAT(valid_value.Equal(MapValue(), descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(BoolValueIs(true))); +} + +TEST_F(ParsedJsonMapValueTest, Empty_Dynamic) { + ParsedJsonMapValue valid_value( + DynamicParseTextProto(R"pb()pb"), arena()); + EXPECT_TRUE(valid_value.IsEmpty()); +} + +TEST_F(ParsedJsonMapValueTest, Size_Dynamic) { + ParsedJsonMapValue valid_value( + DynamicParseTextProto(R"pb()pb"), arena()); + EXPECT_EQ(valid_value.Size(), 0); +} + +TEST_F(ParsedJsonMapValueTest, Get_Dynamic) { + ParsedJsonMapValue valid_value( + DynamicParseTextProto( + R"pb(fields { + key: "foo" + value: {} + } + fields { + key: "bar" + value: { bool_value: true } + })pb"), + arena()); + EXPECT_THAT( + valid_value.Get(BoolValue(), descriptor_pool(), message_factory(), + arena()), + IsOkAndHolds(ErrorValueIs(StatusIs(absl::StatusCode::kNotFound)))); + EXPECT_THAT(valid_value.Get(StringValue("foo"), descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(IsNullValue())); + EXPECT_THAT(valid_value.Get(StringValue("bar"), descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(BoolValueIs(true))); + EXPECT_THAT( + valid_value.Get(StringValue("baz"), descriptor_pool(), message_factory(), + arena()), + IsOkAndHolds(ErrorValueIs(StatusIs(absl::StatusCode::kNotFound)))); +} + +TEST_F(ParsedJsonMapValueTest, Find_Dynamic) { + ParsedJsonMapValue valid_value( + DynamicParseTextProto( + R"pb(fields { + key: "foo" + value: {} + } + fields { + key: "bar" + value: { bool_value: true } + })pb"), + arena()); + EXPECT_THAT(valid_value.Find(BoolValue(), descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(Eq(absl::nullopt))); + EXPECT_THAT(valid_value.Find(StringValue("foo"), descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(Optional(IsNullValue()))); + EXPECT_THAT(valid_value.Find(StringValue("bar"), descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(Optional(BoolValueIs(true)))); + EXPECT_THAT(valid_value.Find(StringValue("baz"), descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(Eq(absl::nullopt))); +} + +TEST_F(ParsedJsonMapValueTest, Has_Dynamic) { + ParsedJsonMapValue valid_value( + DynamicParseTextProto( + R"pb(fields { + key: "foo" + value: {} + } + fields { + key: "bar" + value: { bool_value: true } + })pb"), + arena()); + EXPECT_THAT(valid_value.Has(BoolValue(), descriptor_pool(), message_factory(), + arena()), + IsOkAndHolds(BoolValueIs(false))); + EXPECT_THAT(valid_value.Has(StringValue("foo"), descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(BoolValueIs(true))); + EXPECT_THAT(valid_value.Has(StringValue("bar"), descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(BoolValueIs(true))); + EXPECT_THAT(valid_value.Has(StringValue("baz"), descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(BoolValueIs(false))); +} + +TEST_F(ParsedJsonMapValueTest, ListKeys_Dynamic) { + ParsedJsonMapValue valid_value( + DynamicParseTextProto( + R"pb(fields { + key: "foo" + value: {} + } + fields { + key: "bar" + value: { bool_value: true } + })pb"), + arena()); + ASSERT_OK_AND_ASSIGN( + auto keys, + valid_value.ListKeys(descriptor_pool(), message_factory(), arena())); + EXPECT_THAT(keys.Size(), IsOkAndHolds(2)); + EXPECT_THAT(keys.DebugString(), + AnyOf("[\"foo\", \"bar\"]", "[\"bar\", \"foo\"]")); + EXPECT_THAT( + keys.Contains(BoolValue(), descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(BoolValueIs(false))); + EXPECT_THAT(keys.Contains(StringValue("bar"), descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(BoolValueIs(true))); + EXPECT_THAT(keys.Get(0, descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(AnyOf(StringValueIs("foo"), StringValueIs("bar")))); + EXPECT_THAT(keys.Get(1, descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(AnyOf(StringValueIs("foo"), StringValueIs("bar")))); +} + +TEST_F(ParsedJsonMapValueTest, ForEach_Dynamic) { + ParsedJsonMapValue valid_value( + DynamicParseTextProto( + R"pb(fields { + key: "foo" + value: {} + } + fields { + key: "bar" + value: { bool_value: true } + })pb"), + arena()); + std::vector> entries; + EXPECT_THAT( + valid_value.ForEach( + [&](const Value& key, const Value& value) -> absl::StatusOr { + entries.push_back(std::pair{std::move(key), std::move(value)}); + return true; + }, + descriptor_pool(), message_factory(), arena()), + IsOk()); + EXPECT_THAT(entries, UnorderedElementsAre( + Pair(StringValueIs("foo"), IsNullValue()), + Pair(StringValueIs("bar"), BoolValueIs(true)))); +} + +TEST_F(ParsedJsonMapValueTest, NewIterator_Dynamic) { + ParsedJsonMapValue valid_value( + DynamicParseTextProto( + R"pb(fields { + key: "foo" + value: {} + } + fields { + key: "bar" + value: { bool_value: true } + })pb"), + arena()); + ASSERT_OK_AND_ASSIGN(auto iterator, valid_value.NewIterator()); + ASSERT_TRUE(iterator->HasNext()); + EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(AnyOf(StringValueIs("foo"), StringValueIs("bar")))); + ASSERT_TRUE(iterator->HasNext()); + EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(AnyOf(StringValueIs("foo"), StringValueIs("bar")))); + ASSERT_FALSE(iterator->HasNext()); + EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()), + StatusIs(absl::StatusCode::kFailedPrecondition)); +} + +TEST_F(ParsedJsonMapValueTest, NewIterator1) { + ParsedJsonMapValue valid_value( + DynamicParseTextProto( + R"pb(fields { + key: "foo" + value: {} + } + fields { + key: "bar" + value: { bool_value: true } + })pb"), + arena()); + ASSERT_OK_AND_ASSIGN(auto iterator, valid_value.NewIterator()); + EXPECT_THAT(iterator->Next1(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds( + Optional(AnyOf(StringValueIs("foo"), StringValueIs("bar"))))); + EXPECT_THAT(iterator->Next1(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds( + Optional(AnyOf(StringValueIs("foo"), StringValueIs("bar"))))); + EXPECT_THAT(iterator->Next1(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(Eq(absl::nullopt))); +} + +TEST_F(ParsedJsonMapValueTest, NewIterator2) { + ParsedJsonMapValue valid_value( + DynamicParseTextProto( + R"pb(fields { + key: "foo" + value: {} + } + fields { + key: "bar" + value: { bool_value: true } + })pb"), + arena()); + ASSERT_OK_AND_ASSIGN(auto iterator, valid_value.NewIterator()); + EXPECT_THAT(iterator->Next2(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(Optional( + AnyOf(Pair(StringValueIs("foo"), IsNullValue()), + Pair(StringValueIs("bar"), BoolValueIs(true)))))); + EXPECT_THAT(iterator->Next2(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(Optional( + AnyOf(Pair(StringValueIs("foo"), IsNullValue()), + Pair(StringValueIs("bar"), BoolValueIs(true)))))); + EXPECT_THAT(iterator->Next2(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(Eq(absl::nullopt))); +} + +} // namespace +} // namespace cel diff --git a/common/values/parsed_json_value.cc b/common/values/parsed_json_value.cc new file mode 100644 index 000000000..6b10bea40 --- /dev/null +++ b/common/values/parsed_json_value.cc @@ -0,0 +1,103 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/values/parsed_json_value.h" + +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/functional/overload.h" +#include "absl/status/status.h" +#include "absl/strings/cord.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/variant.h" +#include "common/allocator.h" +#include "common/memory.h" +#include "common/value.h" +#include "internal/well_known_types.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/message.h" + +namespace cel::common_internal { + +namespace { + +using ::cel::well_known_types::AsVariant; +using ::cel::well_known_types::GetValueReflectionOrDie; + +google::protobuf::Arena* absl_nonnull MessageArenaOr( + const google::protobuf::Message* absl_nonnull message, + google::protobuf::Arena* absl_nonnull or_arena) { + google::protobuf::Arena* absl_nullable arena = message->GetArena(); + if (arena == nullptr) { + arena = or_arena; + } + return arena; +} + +} // namespace + +Value ParsedJsonValue(const google::protobuf::Message* absl_nonnull message, + google::protobuf::Arena* absl_nonnull arena) { + const auto reflection = GetValueReflectionOrDie(message->GetDescriptor()); + const auto kind_case = reflection.GetKindCase(*message); + switch (kind_case) { + case google::protobuf::Value::KIND_NOT_SET: + ABSL_FALLTHROUGH_INTENDED; + case google::protobuf::Value::kNullValue: + return NullValue(); + case google::protobuf::Value::kBoolValue: + return BoolValue(reflection.GetBoolValue(*message)); + case google::protobuf::Value::kNumberValue: + return DoubleValue(reflection.GetNumberValue(*message)); + case google::protobuf::Value::kStringValue: { + std::string scratch; + return absl::visit( + absl::Overload( + [&](absl::string_view string) -> StringValue { + if (string.empty()) { + return StringValue(); + } + if (string.data() == scratch.data() && + string.size() == scratch.size()) { + return StringValue(arena, std::move(scratch)); + } else { + return StringValue( + Borrower::Arena(MessageArenaOr(message, arena)), string); + } + }, + [&](absl::Cord&& cord) -> StringValue { + if (cord.empty()) { + return StringValue(); + } + return StringValue(std::move(cord)); + }), + AsVariant(reflection.GetStringValue(*message, scratch))); + } + case google::protobuf::Value::kListValue: + return ParsedJsonListValue(&reflection.GetListValue(*message), + MessageArenaOr(message, arena)); + case google::protobuf::Value::kStructValue: + return ParsedJsonMapValue(&reflection.GetStructValue(*message), + MessageArenaOr(message, arena)); + default: + return ErrorValue(absl::InvalidArgumentError( + absl::StrCat("unexpected value kind case: ", kind_case))); + } +} + +} // namespace cel::common_internal diff --git a/common/values/parsed_json_value.h b/common/values/parsed_json_value.h new file mode 100644 index 000000000..e781b855e --- /dev/null +++ b/common/values/parsed_json_value.h @@ -0,0 +1,40 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_PARSED_JSON_VALUE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_PARSED_JSON_VALUE_H_ + +#include "google/protobuf/struct.pb.h" +#include "absl/base/nullability.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/message.h" + +namespace cel { + +class Value; + +namespace common_internal { + +// Adapts the given instance of the well known message type +// `google.protobuf.Value` to `cel::Value`. If the underlying value is a string +// and the string had to be copied, `allocator` will be used to create a new +// string value. This should be rare and unlikely. +Value ParsedJsonValue(const google::protobuf::Message* absl_nonnull message, + google::protobuf::Arena* absl_nonnull arena); + +} // namespace common_internal + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_PARSED_JSON_VALUE_H_ diff --git a/common/values/parsed_json_value_test.cc b/common/values/parsed_json_value_test.cc new file mode 100644 index 000000000..7a6fbf5d4 --- /dev/null +++ b/common/values/parsed_json_value_test.cc @@ -0,0 +1,107 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/values/parsed_json_value.h" + +#include "google/protobuf/struct.pb.h" +#include "absl/strings/string_view.h" +#include "common/value_testing.h" +#include "internal/testing.h" +#include "cel/expr/conformance/proto3/test_all_types.pb.h" + +namespace cel::common_internal { +namespace { + +using ::cel::test::BoolValueIs; +using ::cel::test::DoubleValueIs; +using ::cel::test::IsNullValue; +using ::cel::test::ListValueElements; +using ::cel::test::ListValueIs; +using ::cel::test::MapValueElements; +using ::cel::test::MapValueIs; +using ::cel::test::StringValueIs; +using ::testing::ElementsAre; +using ::testing::Pair; +using ::testing::UnorderedElementsAre; + +using TestAllTypesProto3 = ::cel::expr::conformance::proto3::TestAllTypes; + +using ParsedJsonValueTest = common_internal::ValueTest<>; + +TEST_F(ParsedJsonValueTest, Null_Dynamic) { + EXPECT_THAT(ParsedJsonValue(DynamicParseTextProto( + R"pb(null_value: NULL_VALUE)pb"), + arena()), + IsNullValue()); + EXPECT_THAT(ParsedJsonValue(DynamicParseTextProto( + R"pb(null_value: NULL_VALUE)pb"), + arena()), + IsNullValue()); +} + +TEST_F(ParsedJsonValueTest, Bool_Dynamic) { + EXPECT_THAT(ParsedJsonValue(DynamicParseTextProto( + R"pb(bool_value: true)pb"), + arena()), + BoolValueIs(true)); +} + +TEST_F(ParsedJsonValueTest, Double_Dynamic) { + EXPECT_THAT(ParsedJsonValue(DynamicParseTextProto( + R"pb(number_value: 1.0)pb"), + arena()), + DoubleValueIs(1.0)); +} + +TEST_F(ParsedJsonValueTest, String_Dynamic) { + EXPECT_THAT(ParsedJsonValue(DynamicParseTextProto( + R"pb(string_value: "foo")pb"), + arena()), + StringValueIs("foo")); +} + +TEST_F(ParsedJsonValueTest, List_Dynamic) { + EXPECT_THAT(ParsedJsonValue(DynamicParseTextProto( + R"pb(list_value: { + values {} + values { bool_value: true } + })pb"), + arena()), + ListValueIs(ListValueElements( + ElementsAre(IsNullValue(), BoolValueIs(true)), + descriptor_pool(), message_factory(), arena()))); +} + +TEST_F(ParsedJsonValueTest, Map_Dynamic) { + EXPECT_THAT( + ParsedJsonValue(DynamicParseTextProto( + R"pb(struct_value: { + fields { + key: "foo" + value: {} + } + fields { + key: "bar" + value: { bool_value: true } + } + })pb"), + arena()), + MapValueIs(MapValueElements( + UnorderedElementsAre(Pair(StringValueIs("foo"), IsNullValue()), + Pair(StringValueIs("bar"), BoolValueIs(true))), + descriptor_pool(), message_factory(), arena()))); +} + +} // namespace +} // namespace cel::common_internal diff --git a/common/values/parsed_map_field_value.cc b/common/values/parsed_map_field_value.cc new file mode 100644 index 000000000..47b737f82 --- /dev/null +++ b/common/values/parsed_map_field_value.cc @@ -0,0 +1,575 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/values/parsed_map_field_value.h" + +#include +#include +#include +#include +#include +#include + +#include "google/protobuf/struct.pb.h" +#include "absl/base/nullability.h" +#include "absl/base/optimization.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/optional.h" +#include "common/value.h" +#include "common/values/values.h" +#include "extensions/protobuf/internal/map_reflection.h" +#include "internal/json.h" +#include "internal/message_equality.h" +#include "internal/status_macros.h" +#include "internal/well_known_types.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/map_field.h" +#include "google/protobuf/message.h" + +namespace cel { + +using ::cel::well_known_types::ValueReflection; + +std::string ParsedMapFieldValue::DebugString() const { + if (ABSL_PREDICT_FALSE(field_ == nullptr)) { + return "INVALID"; + } + return "VALID"; +} + +absl::Status ParsedMapFieldValue::SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(output != nullptr); + ABSL_DCHECK(*this); + + if (ABSL_PREDICT_FALSE(field_ == nullptr)) { + return absl::OkStatus(); + } + // We have to convert to google.protobuf.Struct first. + google::protobuf::Value message; + CEL_RETURN_IF_ERROR(internal::MessageFieldToJson( + *message_, field_, descriptor_pool, message_factory, &message)); + if (!message.list_value().SerializePartialToZeroCopyStream(output)) { + return absl::UnknownError("failed to serialize google.protobuf.Struct"); + } + return absl::OkStatus(); +} + +absl::Status ParsedMapFieldValue::ConvertToJson( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(json != nullptr); + ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE); + ABSL_DCHECK(*this); + + if (ABSL_PREDICT_FALSE(field_ == nullptr)) { + ValueReflection value_reflection; + CEL_RETURN_IF_ERROR(value_reflection.Initialize(json->GetDescriptor())); + value_reflection.MutableStructValue(json)->Clear(); + return absl::OkStatus(); + } + return internal::MessageFieldToJson(*message_, field_, descriptor_pool, + message_factory, json); +} + +absl::Status ParsedMapFieldValue::ConvertToJsonObject( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(json != nullptr); + ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT); + ABSL_DCHECK(*this); + + if (ABSL_PREDICT_FALSE(field_ == nullptr)) { + json->Clear(); + return absl::OkStatus(); + } + return internal::MessageFieldToJson(*message_, field_, descriptor_pool, + message_factory, json); +} + +absl::Status ParsedMapFieldValue::Equal( + const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + if (auto other_value = other.AsParsedMapField(); other_value) { + ABSL_DCHECK(field_ != nullptr); + ABSL_DCHECK(other_value->field_ != nullptr); + CEL_ASSIGN_OR_RETURN( + auto equal, internal::MessageFieldEquals( + *message_, field_, *other_value->message_, + other_value->field_, descriptor_pool, message_factory)); + *result = BoolValue(equal); + return absl::OkStatus(); + } + if (auto other_value = other.AsParsedJsonMap(); other_value) { + if (other_value->value_ == nullptr) { + *result = BoolValue(IsEmpty()); + return absl::OkStatus(); + } + ABSL_DCHECK(field_ != nullptr); + CEL_ASSIGN_OR_RETURN( + auto equal, + internal::MessageFieldEquals(*message_, field_, *other_value->value_, + descriptor_pool, message_factory)); + *result = BoolValue(equal); + return absl::OkStatus(); + } + if (auto other_value = other.AsMap(); other_value) { + return common_internal::MapValueEqual(MapValue(*this), *other_value, + descriptor_pool, message_factory, + arena, result); + } + *result = BoolValue(false); + return absl::OkStatus(); +} + +bool ParsedMapFieldValue::IsZeroValue() const { return IsEmpty(); } + +ParsedMapFieldValue ParsedMapFieldValue::Clone( + google::protobuf::Arena* absl_nonnull arena) const { + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(*this); + + if (ABSL_PREDICT_FALSE(field_ == nullptr)) { + return ParsedMapFieldValue(); + } + if (arena_ == arena) { + return *this; + } + auto field = message_->GetReflection()->GetRepeatedFieldRef( + *message_, field_); + auto* cloned = message_->New(arena); + auto cloned_field = + cloned->GetReflection()->GetMutableRepeatedFieldRef( + cloned, field_); + cloned_field.CopyFrom(field); + return ParsedMapFieldValue(cloned, field_, arena); +} + +bool ParsedMapFieldValue::IsEmpty() const { return Size() == 0; } + +size_t ParsedMapFieldValue::Size() const { + ABSL_DCHECK(*this); + if (ABSL_PREDICT_FALSE(field_ == nullptr)) { + return 0; + } + return static_cast(extensions::protobuf_internal::MapSize( + *GetReflection(), *message_, *field_)); +} + +namespace { + +absl::optional ValueAsInt32(const Value& value) { + if (auto int_value = value.AsInt(); + int_value && + int_value->NativeValue() >= std::numeric_limits::min() && + int_value->NativeValue() <= std::numeric_limits::max()) { + return static_cast(int_value->NativeValue()); + } else if (auto uint_value = value.AsUint(); + uint_value && + uint_value->NativeValue() <= std::numeric_limits::max()) { + return static_cast(uint_value->NativeValue()); + } else if (auto double_value = value.AsDouble(); + double_value && + static_cast(static_cast( + double_value->NativeValue())) == double_value->NativeValue()) { + return static_cast(double_value->NativeValue()); + } + return absl::nullopt; +} + +absl::optional ValueAsInt64(const Value& value) { + if (auto int_value = value.AsInt(); int_value) { + return int_value->NativeValue(); + } else if (auto uint_value = value.AsUint(); + uint_value && + uint_value->NativeValue() <= std::numeric_limits::max()) { + return static_cast(uint_value->NativeValue()); + } else if (auto double_value = value.AsDouble(); + double_value && + static_cast(static_cast( + double_value->NativeValue())) == double_value->NativeValue()) { + return static_cast(double_value->NativeValue()); + } + return absl::nullopt; +} + +absl::optional ValueAsUInt32(const Value& value) { + if (auto int_value = value.AsInt(); + int_value && int_value->NativeValue() >= 0 && + int_value->NativeValue() <= std::numeric_limits::max()) { + return static_cast(int_value->NativeValue()); + } else if (auto uint_value = value.AsUint(); + uint_value && uint_value->NativeValue() <= + std::numeric_limits::max()) { + return static_cast(uint_value->NativeValue()); + } else if (auto double_value = value.AsDouble(); + double_value && + static_cast(static_cast( + double_value->NativeValue())) == double_value->NativeValue()) { + return static_cast(double_value->NativeValue()); + } + return absl::nullopt; +} + +absl::optional ValueAsUInt64(const Value& value) { + if (auto int_value = value.AsInt(); + int_value && int_value->NativeValue() >= 0) { + return static_cast(int_value->NativeValue()); + } else if (auto uint_value = value.AsUint(); uint_value) { + return uint_value->NativeValue(); + } else if (auto double_value = value.AsDouble(); + double_value && + static_cast(static_cast( + double_value->NativeValue())) == double_value->NativeValue()) { + return static_cast(double_value->NativeValue()); + } + return absl::nullopt; +} + +bool ValueToProtoMapKey(const Value& key, + google::protobuf::FieldDescriptor::CppType cpp_type, + google::protobuf::MapKey* absl_nonnull proto_key, + std::string& proto_key_scratch) { + switch (cpp_type) { + case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: { + if (auto bool_key = key.AsBool(); bool_key) { + proto_key->SetBoolValue(bool_key->NativeValue()); + return true; + } + return false; + } + case google::protobuf::FieldDescriptor::CPPTYPE_INT32: { + if (auto int_key = ValueAsInt32(key); int_key) { + proto_key->SetInt32Value(*int_key); + return true; + } + return false; + } + case google::protobuf::FieldDescriptor::CPPTYPE_INT64: { + if (auto int_key = ValueAsInt64(key); int_key) { + proto_key->SetInt64Value(*int_key); + return true; + } + return false; + } + case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: { + if (auto int_key = ValueAsUInt32(key); int_key) { + proto_key->SetUInt32Value(*int_key); + return true; + } + return false; + } + case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: { + if (auto int_key = ValueAsUInt64(key); int_key) { + proto_key->SetUInt64Value(*int_key); + return true; + } + return false; + } + case google::protobuf::FieldDescriptor::CPPTYPE_STRING: { + if (auto string_key = key.AsString(); string_key) { + proto_key_scratch = string_key->NativeString(); + proto_key->SetStringValue(proto_key_scratch); + return true; + } + return false; + } + default: + // protobuf map keys can only be bool, integrals, or string. + return false; + } +} + +} // namespace + +absl::Status ParsedMapFieldValue::Get( + const Value& key, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + CEL_ASSIGN_OR_RETURN( + bool ok, Find(key, descriptor_pool, message_factory, arena, result)); + if (ABSL_PREDICT_FALSE(!ok) && !(result->IsError() || result->IsUnknown())) { + *result = ErrorValue(NoSuchKeyError(key.DebugString())); + } + return absl::OkStatus(); +} + +absl::StatusOr ParsedMapFieldValue::Find( + const Value& key, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + ABSL_DCHECK(*this); + ABSL_DCHECK(message_ != nullptr); + if (ABSL_PREDICT_FALSE(field_ == nullptr)) { + *result = NullValue(); + return false; + } + if (key.IsError() || key.IsUnknown()) { + *result = key; + return false; + } + const google::protobuf::Descriptor* absl_nonnull entry_descriptor = + field_->message_type(); + const google::protobuf::FieldDescriptor* absl_nonnull key_field = + entry_descriptor->map_key(); + const google::protobuf::FieldDescriptor* absl_nonnull value_field = + entry_descriptor->map_value(); + std::string proto_key_scratch; + google::protobuf::MapKey proto_key; + if (!ValueToProtoMapKey(key, key_field->cpp_type(), &proto_key, + proto_key_scratch)) { + *result = NullValue(); + return false; + } + google::protobuf::MapValueConstRef proto_value; + if (!extensions::protobuf_internal::LookupMapValue( + *GetReflection(), *message_, *field_, proto_key, &proto_value)) { + *result = NullValue(); + return false; + } + if (arena_ == nullptr) { + *result = + Value::WrapMapFieldValueUnsafe(proto_value, message_, value_field, + descriptor_pool, message_factory, arena); + } else { + *result = Value::WrapMapFieldValue(proto_value, message_, value_field, + descriptor_pool, message_factory, arena); + } + return true; +} + +absl::Status ParsedMapFieldValue::Has( + const Value& key, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + ABSL_DCHECK(*this); + if (ABSL_PREDICT_FALSE(field_ == nullptr)) { + *result = BoolValue(false); + return absl::OkStatus(); + } + const google::protobuf::FieldDescriptor* absl_nonnull key_field = + field_->message_type()->map_key(); + std::string proto_key_scratch; + google::protobuf::MapKey proto_key; + bool bool_result; + if (ValueToProtoMapKey(key, key_field->cpp_type(), &proto_key, + proto_key_scratch)) { + google::protobuf::MapValueConstRef proto_value; + bool_result = extensions::protobuf_internal::LookupMapValue( + *GetReflection(), *message_, *field_, proto_key, &proto_value); + } else { + bool_result = false; + } + *result = BoolValue(bool_result); + return absl::OkStatus(); +} + +absl::Status ParsedMapFieldValue::ListKeys( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, ListValue* absl_nonnull result) const { + ABSL_DCHECK(*this); + if (field_ == nullptr) { + *result = ListValue(); + return absl::OkStatus(); + } + const auto* reflection = message_->GetReflection(); + if (reflection->FieldSize(*message_, field_) == 0) { + *result = ListValue(); + return absl::OkStatus(); + } + CEL_ASSIGN_OR_RETURN(auto key_accessor, + common_internal::MapFieldKeyAccessorFor( + field_->message_type()->map_key())); + auto builder = NewListValueBuilder(arena); + builder->Reserve(Size()); + auto begin = extensions::protobuf_internal::ConstMapBegin(*reflection, + *message_, *field_); + const auto end = extensions::protobuf_internal::ConstMapEnd( + *reflection, *message_, *field_); + for (; begin != end; ++begin) { + Value scratch; + (*key_accessor)(begin.GetKey(), message_, arena, &scratch); + CEL_RETURN_IF_ERROR(builder->Add(std::move(scratch))); + } + *result = std::move(*builder).Build(); + return absl::OkStatus(); +} + +absl::Status ParsedMapFieldValue::ForEach( + ForEachCallback callback, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const { + ABSL_DCHECK(*this); + if (field_ == nullptr) { + return absl::OkStatus(); + } + const auto* reflection = message_->GetReflection(); + if (reflection->FieldSize(*message_, field_) > 0) { + const auto* value_field = field_->message_type()->map_value(); + CEL_ASSIGN_OR_RETURN(auto key_accessor, + common_internal::MapFieldKeyAccessorFor( + field_->message_type()->map_key())); + CEL_ASSIGN_OR_RETURN( + auto value_accessor, + common_internal::MapFieldValueAccessorFor(value_field)); + auto begin = extensions::protobuf_internal::ConstMapBegin( + *reflection, *message_, *field_); + const auto end = extensions::protobuf_internal::ConstMapEnd( + *reflection, *message_, *field_); + Value key_scratch; + Value value_scratch; + for (; begin != end; ++begin) { + (*key_accessor)(begin.GetKey(), message_, arena, &key_scratch); + (*value_accessor)(begin.GetValueRef(), message_, value_field, + descriptor_pool, message_factory, arena, + &value_scratch); + CEL_ASSIGN_OR_RETURN(auto ok, callback(key_scratch, value_scratch)); + if (!ok) { + break; + } + } + } + return absl::OkStatus(); +} + +namespace { + +class ParsedMapFieldValueIterator final : public ValueIterator { + public: + ParsedMapFieldValueIterator( + const google::protobuf::Message* absl_nonnull message, + const google::protobuf::FieldDescriptor* absl_nonnull field, + absl_nonnull common_internal::MapFieldKeyAccessor key_accessor, + absl_nonnull common_internal::MapFieldValueAccessor value_accessor) + : message_(message), + value_field_(field->message_type()->map_value()), + key_accessor_(key_accessor), + value_accessor_(value_accessor), + begin_(extensions::protobuf_internal::ConstMapBegin( + *message_->GetReflection(), *message_, *field)), + end_(extensions::protobuf_internal::ConstMapEnd( + *message_->GetReflection(), *message_, *field)) {} + + bool HasNext() override { return begin_ != end_; } + + absl::Status Next(const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) override { + if (ABSL_PREDICT_FALSE(begin_ == end_)) { + return absl::FailedPreconditionError( + "ValueIterator::Next called after ValueIterator::HasNext returned " + "false"); + } + (*key_accessor_)(begin_.GetKey(), message_, arena, result); + ++begin_; + return absl::OkStatus(); + } + + absl::StatusOr Next1( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull key_or_value) override { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(key_or_value != nullptr); + + if (begin_ == end_) { + return false; + } + (*key_accessor_)(begin_.GetKey(), message_, arena, key_or_value); + ++begin_; + return true; + } + + absl::StatusOr Next2( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull key, + Value* absl_nullable value) override { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(key != nullptr); + + if (begin_ == end_) { + return false; + } + (*key_accessor_)(begin_.GetKey(), message_, arena, key); + if (value != nullptr) { + (*value_accessor_)(begin_.GetValueRef(), message_, value_field_, + descriptor_pool, message_factory, arena, value); + } + ++begin_; + return true; + } + + private: + const google::protobuf::Message* absl_nonnull const message_; + const google::protobuf::FieldDescriptor* absl_nonnull const value_field_; + const absl_nonnull common_internal::MapFieldKeyAccessor key_accessor_; + const absl_nonnull common_internal::MapFieldValueAccessor value_accessor_; + google::protobuf::ConstMapIterator begin_; + const google::protobuf::ConstMapIterator end_; +}; + +} // namespace + +absl::StatusOr> +ParsedMapFieldValue::NewIterator() const { + ABSL_DCHECK(*this); + if (ABSL_PREDICT_FALSE(field_ == nullptr)) { + return NewEmptyValueIterator(); + } + CEL_ASSIGN_OR_RETURN(auto key_accessor, + common_internal::MapFieldKeyAccessorFor( + field_->message_type()->map_key())); + CEL_ASSIGN_OR_RETURN(auto value_accessor, + common_internal::MapFieldValueAccessorFor( + field_->message_type()->map_value())); + return std::make_unique( + message_, field_, key_accessor, value_accessor); +} + +const google::protobuf::Reflection* absl_nonnull ParsedMapFieldValue::GetReflection() + const { + return message_->GetReflection(); +} + +} // namespace cel diff --git a/common/values/parsed_map_field_value.h b/common/values/parsed_map_field_value.h new file mode 100644 index 000000000..21d686bfd --- /dev/null +++ b/common/values/parsed_map_field_value.h @@ -0,0 +1,242 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/value.h" +// IWYU pragma: friend "common/value.h" + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_PARSED_MAP_FIELD_VALUE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_PARSED_MAP_FIELD_VALUE_H_ + +#include +#include +#include +#include +#include + +#include "google/protobuf/any.pb.h" +#include "google/protobuf/struct.pb.h" +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "common/type.h" +#include "common/value_kind.h" +#include "common/values/custom_map_value.h" +#include "common/values/values.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/message.h" + +namespace cel { + +class Value; +class ValueIterator; +class ListValue; +class ParsedJsonMapValue; + +// ParsedMapFieldValue is a MapValue over a map field of a parsed protocol +// buffer message. +class ParsedMapFieldValue final + : private common_internal::MapValueMixin { + public: + static constexpr ValueKind kKind = ValueKind::kMap; + static constexpr absl::string_view kName = "map"; + + ParsedMapFieldValue(const google::protobuf::Message* absl_nonnull message, + const google::protobuf::FieldDescriptor* absl_nonnull field, + google::protobuf::Arena* absl_nonnull arena) + : message_(message), field_(field), arena_(arena) { + ABSL_DCHECK(message != nullptr); + ABSL_DCHECK(field != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(field_->is_map()) + << field_->full_name() << " must be a map field"; + ABSL_DCHECK_OK(CheckArena(message_, arena_)); + } + + // Places the `ParsedMapFieldValue` into an invalid state. Anything + // except assigning to `ParsedMapFieldValue` is undefined behavior. + ParsedMapFieldValue() = default; + + ParsedMapFieldValue(const ParsedMapFieldValue&) = default; + ParsedMapFieldValue(ParsedMapFieldValue&&) = default; + ParsedMapFieldValue& operator=(const ParsedMapFieldValue&) = default; + ParsedMapFieldValue& operator=(ParsedMapFieldValue&&) = default; + + static constexpr ValueKind kind() { return kKind; } + + static constexpr absl::string_view GetTypeName() { return kName; } + + static MapType GetRuntimeType() { return MapType(); } + + std::string DebugString() const; + + // See Value::SerializeTo(). + absl::Status SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const; + + // See Value::ConvertToJson(). + absl::Status ConvertToJson( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const; + + // See Value::ConvertToJsonObject(). + absl::Status ConvertToJsonObject( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const; + + absl::Status Equal(const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const; + using MapValueMixin::Equal; + + bool IsZeroValue() const; + + ParsedMapFieldValue Clone(google::protobuf::Arena* absl_nonnull arena) const; + + bool IsEmpty() const; + + size_t Size() const; + + // See the corresponding member function of `MapValue` for + // documentation. + absl::Status Get(const Value& key, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const; + using MapValueMixin::Get; + + // See the corresponding member function of `MapValue` for + // documentation. + absl::StatusOr Find( + const Value& key, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; + using MapValueMixin::Find; + + // See the corresponding member function of `MapValue` for + // documentation. + absl::Status Has(const Value& key, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const; + using MapValueMixin::Has; + + // See the corresponding member function of `MapValue` for + // documentation. + absl::Status ListKeys( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, ListValue* absl_nonnull result) const; + using MapValueMixin::ListKeys; + + // See the corresponding type declaration of `MapValue` for + // documentation. + using ForEachCallback = typename CustomMapValueInterface::ForEachCallback; + + // See the corresponding member function of `MapValue` for + // documentation. + absl::Status ForEach( + ForEachCallback callback, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const; + + absl::StatusOr> NewIterator() + const; + + const google::protobuf::Message& message() const { + ABSL_DCHECK(*this); + return *message_; + } + + const google::protobuf::FieldDescriptor* absl_nonnull field() const { + ABSL_DCHECK(*this); + return field_; + } + + // Returns `true` if `ParsedMapFieldValue` is in a valid state. + explicit operator bool() const { return field_ != nullptr; } + + friend void swap(ParsedMapFieldValue& lhs, + ParsedMapFieldValue& rhs) noexcept { + using std::swap; + swap(lhs.message_, rhs.message_); + swap(lhs.field_, rhs.field_); + swap(lhs.arena_, rhs.arena_); + } + + private: + friend class ParsedJsonMapValue; + friend class common_internal::ValueMixin; + friend class common_internal::MapValueMixin; + friend ParsedMapFieldValue UnsafeParsedMapFieldValue( + const google::protobuf::Message* absl_nonnull message, + const google::protobuf::FieldDescriptor* absl_nonnull field); + + ParsedMapFieldValue(const google::protobuf::Message* absl_nonnull message, + const google::protobuf::FieldDescriptor* absl_nonnull field) + : message_(message), field_(field), arena_(message->GetArena()) { + ABSL_DCHECK(message != nullptr); + ABSL_DCHECK(field != nullptr); + ABSL_DCHECK(field_->is_map()) + << field_->full_name() << " must be a map field"; + } + + static absl::Status CheckArena(const google::protobuf::Message* absl_nullable message, + google::protobuf::Arena* absl_nonnull arena) { + if (message != nullptr && message->GetArena() != nullptr && + message->GetArena() != arena) { + return absl::InvalidArgumentError( + "message arena must be the same as arena"); + } + return absl::OkStatus(); + } + + const google::protobuf::Reflection* absl_nonnull GetReflection() const; + + const google::protobuf::Message* absl_nullable message_ = nullptr; + const google::protobuf::FieldDescriptor* absl_nullable field_ = nullptr; + google::protobuf::Arena* absl_nullable arena_ = nullptr; +}; + +// Creates a `ParsedMapFieldValue` without specifying a managing arena. +// The message must outlive the `ParsedMapFieldValue` or any value that +// might be derived from it. Prefer to use +// `cel::Value::WrapMapFieldValueUnsafe()`. +inline ParsedMapFieldValue UnsafeParsedMapFieldValue( + const google::protobuf::Message* absl_nonnull message, + const google::protobuf::FieldDescriptor* absl_nonnull field) { + return ParsedMapFieldValue(message, field); +} + +inline std::ostream& operator<<(std::ostream& out, + const ParsedMapFieldValue& value) { + return out << value.DebugString(); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_PARSED_MAP_FIELD_VALUE_H_ diff --git a/common/values/parsed_map_field_value_test.cc b/common/values/parsed_map_field_value_test.cc new file mode 100644 index 000000000..271813f40 --- /dev/null +++ b/common/values/parsed_map_field_value_test.cc @@ -0,0 +1,571 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "google/protobuf/struct.pb.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "absl/types/optional.h" +#include "common/memory.h" +#include "common/type.h" +#include "common/value.h" +#include "common/value_kind.h" +#include "common/value_testing.h" +#include "internal/testing.h" +#include "cel/expr/conformance/proto3/test_all_types.pb.h" +#include "google/protobuf/io/zero_copy_stream_impl_lite.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::IsOkAndHolds; +using ::absl_testing::StatusIs; +using ::cel::test::BoolValueIs; +using ::cel::test::BytesValueIs; +using ::cel::test::DoubleValueIs; +using ::cel::test::DurationValueIs; +using ::cel::test::ErrorValueIs; +using ::cel::test::IntValueIs; +using ::cel::test::IsNullValue; +using ::cel::test::StringValueIs; +using ::cel::test::UintValueIs; +using ::testing::_; +using ::testing::AnyOf; +using ::testing::Eq; +using ::testing::IsEmpty; +using ::testing::Optional; +using ::testing::Pair; + +using TestAllTypesProto3 = ::cel::expr::conformance::proto3::TestAllTypes; + +using ParsedMapFieldValueTest = common_internal::ValueTest<>; + +TEST_F(ParsedMapFieldValueTest, Field) { + ParsedMapFieldValue value( + DynamicParseTextProto(R"pb()pb"), + DynamicGetField("map_int64_int64"), arena()); + EXPECT_TRUE(value); +} + +TEST_F(ParsedMapFieldValueTest, Kind) { + ParsedMapFieldValue value( + DynamicParseTextProto(R"pb()pb"), + DynamicGetField("map_int64_int64"), arena()); + EXPECT_EQ(value.kind(), ParsedMapFieldValue::kKind); + EXPECT_EQ(value.kind(), ValueKind::kMap); +} + +TEST_F(ParsedMapFieldValueTest, GetTypeName) { + ParsedMapFieldValue value( + DynamicParseTextProto(R"pb()pb"), + DynamicGetField("map_int64_int64"), arena()); + EXPECT_EQ(value.GetTypeName(), ParsedMapFieldValue::kName); + EXPECT_EQ(value.GetTypeName(), "map"); +} + +TEST_F(ParsedMapFieldValueTest, GetRuntimeType) { + ParsedMapFieldValue value( + DynamicParseTextProto(R"pb()pb"), + DynamicGetField("map_int64_int64"), arena()); + EXPECT_EQ(value.GetRuntimeType(), MapType()); +} + +TEST_F(ParsedMapFieldValueTest, DebugString) { + ParsedMapFieldValue value( + DynamicParseTextProto(R"pb()pb"), + DynamicGetField("map_int64_int64"), arena()); + EXPECT_THAT(value.DebugString(), _); +} + +TEST_F(ParsedMapFieldValueTest, IsZeroValue) { + ParsedMapFieldValue value( + DynamicParseTextProto(R"pb()pb"), + DynamicGetField("map_int64_int64"), arena()); + EXPECT_TRUE(value.IsZeroValue()); +} + +TEST_F(ParsedMapFieldValueTest, SerializeTo) { + ParsedMapFieldValue value( + DynamicParseTextProto(R"pb()pb"), + DynamicGetField("map_int64_int64"), arena()); + google::protobuf::io::CordOutputStream output; + EXPECT_THAT(value.SerializeTo(descriptor_pool(), message_factory(), &output), + IsOk()); + EXPECT_THAT(std::move(output).Consume(), IsEmpty()); +} + +TEST_F(ParsedMapFieldValueTest, ConvertToJson) { + auto json = DynamicParseTextProto(R"pb()pb"); + ParsedMapFieldValue value( + DynamicParseTextProto(R"pb()pb"), + DynamicGetField("map_int64_int64"), arena()); + EXPECT_THAT(value.ConvertToJson(descriptor_pool(), message_factory(), + cel::to_address(json)), + IsOk()); + EXPECT_THAT(*json, EqualsTextProto( + R"pb(struct_value: {})pb")); +} + +TEST_F(ParsedMapFieldValueTest, Equal_MapField) { + ParsedMapFieldValue value( + DynamicParseTextProto(R"pb()pb"), + DynamicGetField("map_int64_int64"), arena()); + EXPECT_THAT( + value.Equal(BoolValue(), descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(BoolValueIs(false))); + EXPECT_THAT( + value.Equal( + ParsedMapFieldValue( + DynamicParseTextProto(R"pb()pb"), + DynamicGetField("map_int32_int32"), arena()), + descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(BoolValueIs(true))); + EXPECT_THAT( + value.Equal(MapValue(), descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(BoolValueIs(true))); +} + +TEST_F(ParsedMapFieldValueTest, Equal_JsonMap) { + ParsedMapFieldValue map_value( + DynamicParseTextProto( + R"pb(map_string_string { key: "foo" value: "bar" } + map_string_string { key: "bar" value: "foo" })pb"), + DynamicGetField("map_string_string"), arena()); + ParsedJsonMapValue json_value(DynamicParseTextProto( + R"pb( + fields { + key: "foo" + value { string_value: "bar" } + } + fields { + key: "bar" + value { string_value: "foo" } + } + )pb"), + arena()); + EXPECT_THAT(map_value.Equal(json_value, descriptor_pool(), message_factory(), + arena()), + IsOkAndHolds(BoolValueIs(true))); + EXPECT_THAT(json_value.Equal(map_value, descriptor_pool(), message_factory(), + arena()), + IsOkAndHolds(BoolValueIs(true))); +} + +TEST_F(ParsedMapFieldValueTest, Empty) { + ParsedMapFieldValue value( + DynamicParseTextProto(R"pb()pb"), + DynamicGetField("map_int64_int64"), arena()); + EXPECT_TRUE(value.IsEmpty()); +} + +TEST_F(ParsedMapFieldValueTest, Size) { + ParsedMapFieldValue value( + DynamicParseTextProto(R"pb()pb"), + DynamicGetField("map_int64_int64"), arena()); + EXPECT_EQ(value.Size(), 0); +} + +TEST_F(ParsedMapFieldValueTest, Get) { + ParsedMapFieldValue value( + DynamicParseTextProto(R"pb( + map_string_bool { key: "foo" value: false } + map_string_bool { key: "bar" value: true } + )pb"), + DynamicGetField("map_string_bool"), arena()); + EXPECT_THAT( + value.Get(BoolValue(), descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(ErrorValueIs(StatusIs(absl::StatusCode::kNotFound)))); + EXPECT_THAT(value.Get(StringValue("foo"), descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(BoolValueIs(false))); + EXPECT_THAT(value.Get(StringValue("bar"), descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(BoolValueIs(true))); + EXPECT_THAT( + value.Get(StringValue("baz"), descriptor_pool(), message_factory(), + arena()), + IsOkAndHolds(ErrorValueIs(StatusIs(absl::StatusCode::kNotFound)))); +} + +TEST_F(ParsedMapFieldValueTest, Find) { + ParsedMapFieldValue value( + DynamicParseTextProto(R"pb( + map_string_bool { key: "foo" value: false } + map_string_bool { key: "bar" value: true } + )pb"), + DynamicGetField("map_string_bool"), arena()); + EXPECT_THAT( + value.Find(BoolValue(), descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(Eq(absl::nullopt))); + EXPECT_THAT(value.Find(StringValue("foo"), descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(Optional(BoolValueIs(false)))); + EXPECT_THAT(value.Find(StringValue("bar"), descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(Optional(BoolValueIs(true)))); + EXPECT_THAT(value.Find(StringValue("baz"), descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(Eq(absl::nullopt))); +} + +TEST_F(ParsedMapFieldValueTest, Has) { + ParsedMapFieldValue value( + DynamicParseTextProto(R"pb( + map_string_bool { key: "foo" value: false } + map_string_bool { key: "bar" value: true } + )pb"), + DynamicGetField("map_string_bool"), arena()); + EXPECT_THAT( + value.Has(BoolValue(), descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(BoolValueIs(false))); + EXPECT_THAT(value.Has(StringValue("foo"), descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(BoolValueIs(true))); + EXPECT_THAT(value.Has(StringValue("bar"), descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(BoolValueIs(true))); + EXPECT_THAT(value.Has(StringValue("baz"), descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(BoolValueIs(false))); +} + +TEST_F(ParsedMapFieldValueTest, ListKeys) { + ParsedMapFieldValue value( + DynamicParseTextProto(R"pb( + map_string_bool { key: "foo" value: false } + map_string_bool { key: "bar" value: true } + )pb"), + DynamicGetField("map_string_bool"), arena()); + ASSERT_OK_AND_ASSIGN( + auto keys, value.ListKeys(descriptor_pool(), message_factory(), arena())); + EXPECT_THAT(keys.Size(), IsOkAndHolds(2)); + EXPECT_THAT(keys.DebugString(), + AnyOf("[\"foo\", \"bar\"]", "[\"bar\", \"foo\"]")); + EXPECT_THAT( + keys.Contains(BoolValue(), descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(BoolValueIs(false))); + EXPECT_THAT(keys.Contains(StringValue("bar"), descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(BoolValueIs(true))); + EXPECT_THAT(keys.Get(0, descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(AnyOf(StringValueIs("foo"), StringValueIs("bar")))); + EXPECT_THAT(keys.Get(1, descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(AnyOf(StringValueIs("foo"), StringValueIs("bar")))); +} + +TEST_F(ParsedMapFieldValueTest, ForEach_StringBool) { + ParsedMapFieldValue value( + DynamicParseTextProto(R"pb( + map_string_bool { key: "foo" value: false } + map_string_bool { key: "bar" value: true } + )pb"), + DynamicGetField("map_string_bool"), arena()); + std::vector> entries; + EXPECT_THAT( + value.ForEach( + [&](const Value& key, const Value& value) -> absl::StatusOr { + entries.push_back(std::pair{std::move(key), std::move(value)}); + return true; + }, + descriptor_pool(), message_factory(), arena()), + IsOk()); + EXPECT_THAT(entries, UnorderedElementsAre( + Pair(StringValueIs("foo"), BoolValueIs(false)), + Pair(StringValueIs("bar"), BoolValueIs(true)))); +} + +TEST_F(ParsedMapFieldValueTest, ForEach_Int32Double) { + ParsedMapFieldValue value( + DynamicParseTextProto(R"pb( + map_int32_double { key: 1 value: 2 } + map_int32_double { key: 2 value: 1 } + )pb"), + DynamicGetField("map_int32_double"), arena()); + std::vector> entries; + EXPECT_THAT( + value.ForEach( + [&](const Value& key, const Value& value) -> absl::StatusOr { + entries.push_back(std::pair{std::move(key), std::move(value)}); + return true; + }, + descriptor_pool(), message_factory(), arena()), + IsOk()); + EXPECT_THAT(entries, + UnorderedElementsAre(Pair(IntValueIs(1), DoubleValueIs(2)), + Pair(IntValueIs(2), DoubleValueIs(1)))); +} + +TEST_F(ParsedMapFieldValueTest, ForEach_Int64Float) { + ParsedMapFieldValue value( + DynamicParseTextProto(R"pb( + map_int64_float { key: 1 value: 2 } + map_int64_float { key: 2 value: 1 } + )pb"), + DynamicGetField("map_int64_float"), arena()); + std::vector> entries; + EXPECT_THAT( + value.ForEach( + [&](const Value& key, const Value& value) -> absl::StatusOr { + entries.push_back(std::pair{std::move(key), std::move(value)}); + return true; + }, + descriptor_pool(), message_factory(), arena()), + IsOk()); + EXPECT_THAT(entries, + UnorderedElementsAre(Pair(IntValueIs(1), DoubleValueIs(2)), + Pair(IntValueIs(2), DoubleValueIs(1)))); +} + +TEST_F(ParsedMapFieldValueTest, ForEach_UInt32UInt64) { + ParsedMapFieldValue value( + DynamicParseTextProto(R"pb( + map_uint32_uint64 { key: 1 value: 2 } + map_uint32_uint64 { key: 2 value: 1 } + )pb"), + DynamicGetField("map_uint32_uint64"), arena()); + std::vector> entries; + EXPECT_THAT( + value.ForEach( + [&](const Value& key, const Value& value) -> absl::StatusOr { + entries.push_back(std::pair{std::move(key), std::move(value)}); + return true; + }, + descriptor_pool(), message_factory(), arena()), + IsOk()); + EXPECT_THAT(entries, + UnorderedElementsAre(Pair(UintValueIs(1), UintValueIs(2)), + Pair(UintValueIs(2), UintValueIs(1)))); +} + +TEST_F(ParsedMapFieldValueTest, ForEach_UInt64Int32) { + ParsedMapFieldValue value( + DynamicParseTextProto(R"pb( + map_uint64_int32 { key: 1 value: 2 } + map_uint64_int32 { key: 2 value: 1 } + )pb"), + DynamicGetField("map_uint64_int32"), arena()); + std::vector> entries; + EXPECT_THAT( + value.ForEach( + [&](const Value& key, const Value& value) -> absl::StatusOr { + entries.push_back(std::pair{std::move(key), std::move(value)}); + return true; + }, + descriptor_pool(), message_factory(), arena()), + IsOk()); + EXPECT_THAT(entries, + UnorderedElementsAre(Pair(UintValueIs(1), IntValueIs(2)), + Pair(UintValueIs(2), IntValueIs(1)))); +} + +TEST_F(ParsedMapFieldValueTest, ForEach_BoolUInt32) { + ParsedMapFieldValue value( + DynamicParseTextProto(R"pb( + map_bool_uint32 { key: true value: 2 } + map_bool_uint32 { key: false value: 1 } + )pb"), + DynamicGetField("map_bool_uint32"), arena()); + std::vector> entries; + EXPECT_THAT( + value.ForEach( + [&](const Value& key, const Value& value) -> absl::StatusOr { + entries.push_back(std::pair{std::move(key), std::move(value)}); + return true; + }, + descriptor_pool(), message_factory(), arena()), + IsOk()); + EXPECT_THAT(entries, + UnorderedElementsAre(Pair(BoolValueIs(true), UintValueIs(2)), + Pair(BoolValueIs(false), UintValueIs(1)))); +} + +TEST_F(ParsedMapFieldValueTest, ForEach_StringString) { + ParsedMapFieldValue value( + DynamicParseTextProto(R"pb( + map_string_string { key: "foo" value: "bar" } + map_string_string { key: "bar" value: "foo" } + )pb"), + DynamicGetField("map_string_string"), arena()); + std::vector> entries; + EXPECT_THAT( + value.ForEach( + [&](const Value& key, const Value& value) -> absl::StatusOr { + entries.push_back(std::pair{std::move(key), std::move(value)}); + return true; + }, + descriptor_pool(), message_factory(), arena()), + IsOk()); + EXPECT_THAT(entries, UnorderedElementsAre( + Pair(StringValueIs("foo"), StringValueIs("bar")), + Pair(StringValueIs("bar"), StringValueIs("foo")))); +} + +TEST_F(ParsedMapFieldValueTest, ForEach_StringDuration) { + ParsedMapFieldValue value( + DynamicParseTextProto(R"pb( + map_string_duration { + key: "foo" + value: { seconds: 1 nanos: 1 } + } + map_string_duration { + key: "bar" + value: {} + } + )pb"), + DynamicGetField("map_string_duration"), arena()); + std::vector> entries; + EXPECT_THAT( + value.ForEach( + [&](const Value& key, const Value& value) -> absl::StatusOr { + entries.push_back(std::pair{std::move(key), std::move(value)}); + return true; + }, + descriptor_pool(), message_factory(), arena()), + IsOk()); + EXPECT_THAT( + entries, + UnorderedElementsAre( + Pair(StringValueIs("foo"), + DurationValueIs(absl::Seconds(1) + absl::Nanoseconds(1))), + Pair(StringValueIs("bar"), DurationValueIs(absl::ZeroDuration())))); +} + +TEST_F(ParsedMapFieldValueTest, ForEach_StringBytes) { + ParsedMapFieldValue value( + DynamicParseTextProto(R"pb( + map_string_bytes { key: "foo" value: "bar" } + map_string_bytes { key: "bar" value: "foo" } + )pb"), + DynamicGetField("map_string_bytes"), arena()); + std::vector> entries; + EXPECT_THAT( + value.ForEach( + [&](const Value& key, const Value& value) -> absl::StatusOr { + entries.push_back(std::pair{std::move(key), std::move(value)}); + return true; + }, + descriptor_pool(), message_factory(), arena()), + IsOk()); + EXPECT_THAT(entries, UnorderedElementsAre( + Pair(StringValueIs("foo"), BytesValueIs("bar")), + Pair(StringValueIs("bar"), BytesValueIs("foo")))); +} + +TEST_F(ParsedMapFieldValueTest, ForEach_StringEnum) { + ParsedMapFieldValue value( + DynamicParseTextProto(R"pb( + map_string_enum { key: "foo" value: BAR } + map_string_enum { key: "bar" value: FOO } + )pb"), + DynamicGetField("map_string_enum"), arena()); + std::vector> entries; + EXPECT_THAT( + value.ForEach( + [&](const Value& key, const Value& value) -> absl::StatusOr { + entries.push_back(std::pair{std::move(key), std::move(value)}); + return true; + }, + descriptor_pool(), message_factory(), arena()), + IsOk()); + EXPECT_THAT(entries, + UnorderedElementsAre(Pair(StringValueIs("foo"), IntValueIs(1)), + Pair(StringValueIs("bar"), IntValueIs(0)))); +} + +TEST_F(ParsedMapFieldValueTest, ForEach_StringNull) { + ParsedMapFieldValue value( + DynamicParseTextProto(R"pb( + map_string_null_value { key: "foo" value: NULL_VALUE } + map_string_null_value { key: "bar" value: NULL_VALUE } + )pb"), + DynamicGetField("map_string_null_value"), arena()); + std::vector> entries; + EXPECT_THAT( + value.ForEach( + [&](const Value& key, const Value& value) -> absl::StatusOr { + entries.push_back(std::pair{std::move(key), std::move(value)}); + return true; + }, + descriptor_pool(), message_factory(), arena()), + IsOk()); + EXPECT_THAT(entries, + UnorderedElementsAre(Pair(StringValueIs("foo"), IsNullValue()), + Pair(StringValueIs("bar"), IsNullValue()))); +} + +TEST_F(ParsedMapFieldValueTest, NewIterator) { + ParsedMapFieldValue value( + DynamicParseTextProto(R"pb( + map_string_bool { key: "foo" value: false } + map_string_bool { key: "bar" value: true } + )pb"), + DynamicGetField("map_string_bool"), arena()); + ASSERT_OK_AND_ASSIGN(auto iterator, value.NewIterator()); + ASSERT_TRUE(iterator->HasNext()); + EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(AnyOf(StringValueIs("foo"), StringValueIs("bar")))); + ASSERT_TRUE(iterator->HasNext()); + EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(AnyOf(StringValueIs("foo"), StringValueIs("bar")))); + ASSERT_FALSE(iterator->HasNext()); + EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()), + StatusIs(absl::StatusCode::kFailedPrecondition)); +} + +TEST_F(ParsedMapFieldValueTest, NewIterator1) { + ParsedMapFieldValue value( + DynamicParseTextProto(R"pb( + map_string_bool { key: "foo" value: false } + map_string_bool { key: "bar" value: true } + )pb"), + DynamicGetField("map_string_bool"), arena()); + ASSERT_OK_AND_ASSIGN(auto iterator, value.NewIterator()); + EXPECT_THAT(iterator->Next1(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds( + Optional(AnyOf(StringValueIs("foo"), StringValueIs("bar"))))); + EXPECT_THAT(iterator->Next1(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds( + Optional(AnyOf(StringValueIs("foo"), StringValueIs("bar"))))); + EXPECT_THAT(iterator->Next1(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(Eq(absl::nullopt))); +} + +TEST_F(ParsedMapFieldValueTest, NewIterator2) { + ParsedMapFieldValue value( + DynamicParseTextProto(R"pb( + map_string_bool { key: "foo" value: false } + map_string_bool { key: "bar" value: true } + )pb"), + DynamicGetField("map_string_bool"), arena()); + ASSERT_OK_AND_ASSIGN(auto iterator, value.NewIterator()); + EXPECT_THAT(iterator->Next2(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(Optional( + AnyOf(Pair(StringValueIs("foo"), BoolValueIs(false)), + Pair(StringValueIs("bar"), BoolValueIs(true)))))); + EXPECT_THAT(iterator->Next2(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(Optional( + AnyOf(Pair(StringValueIs("foo"), BoolValueIs(false)), + Pair(StringValueIs("bar"), BoolValueIs(true)))))); + EXPECT_THAT(iterator->Next2(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(Eq(absl::nullopt))); +} + +} // namespace +} // namespace cel diff --git a/common/values/parsed_message_value.cc b/common/values/parsed_message_value.cc new file mode 100644 index 000000000..8a2b8030d --- /dev/null +++ b/common/values/parsed_message_value.cc @@ -0,0 +1,411 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/values/parsed_message_value.h" + +#include +#include +#include +#include +#include +#include + +#include "google/protobuf/empty.pb.h" +#include "absl/base/nullability.h" +#include "absl/base/optimization.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "base/attribute.h" +#include "common/memory.h" +#include "common/value.h" +#include "extensions/protobuf/internal/qualify.h" +#include "internal/empty_descriptors.h" +#include "internal/json.h" +#include "internal/message_equality.h" +#include "internal/status_macros.h" +#include "internal/well_known_types.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/message.h" +#include "google/protobuf/message_lite.h" + +namespace cel { + +namespace { + +using ::cel::well_known_types::ValueReflection; + +template +std::enable_if_t, + const google::protobuf::Message* absl_nonnull> +EmptyParsedMessageValue() { + return &T::default_instance(); +} + +template +std::enable_if_t< + std::conjunction_v, + std::negation>>, + const google::protobuf::Message* absl_nonnull> +EmptyParsedMessageValue() { + return internal::GetEmptyDefaultInstance(); +} + +} // namespace + +ParsedMessageValue::ParsedMessageValue() + : value_(EmptyParsedMessageValue()), + arena_(nullptr) {} + +bool ParsedMessageValue::IsZeroValue() const { + const auto* reflection = GetReflection(); + if (!reflection->GetUnknownFields(*value_).empty()) { + return false; + } + std::vector fields; + reflection->ListFields(*value_, &fields); + return fields.empty(); +} + +std::string ParsedMessageValue::DebugString() const { + return absl::StrCat(*value_); +} + +absl::Status ParsedMessageValue::SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(output != nullptr); + + if (!value_->SerializePartialToZeroCopyStream(output)) { + return absl::UnknownError( + absl::StrCat("failed to serialize message: ", value_->GetTypeName())); + } + return absl::OkStatus(); +} + +absl::Status ParsedMessageValue::ConvertToJson( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(json != nullptr); + ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE); + + ValueReflection value_reflection; + CEL_RETURN_IF_ERROR(value_reflection.Initialize(json->GetDescriptor())); + google::protobuf::Message* json_object = value_reflection.MutableStructValue(json); + + return internal::MessageToJson(*value_, descriptor_pool, message_factory, + json_object); +} + +absl::Status ParsedMessageValue::ConvertToJsonObject( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(json != nullptr); + ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT); + + return internal::MessageToJson(*value_, descriptor_pool, message_factory, + json); +} + +absl::Status ParsedMessageValue::Equal( + const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + if (auto other_message = other.AsParsedMessage(); other_message) { + CEL_ASSIGN_OR_RETURN( + auto equal, internal::MessageEquals(*value_, **other_message, + descriptor_pool, message_factory)); + *result = BoolValue(equal); + return absl::OkStatus(); + } + if (auto other_struct = other.AsStruct(); other_struct) { + return common_internal::StructValueEqual(StructValue(*this), *other_struct, + descriptor_pool, message_factory, + arena, result); + } + *result = BoolValue(false); + return absl::OkStatus(); +} + +ParsedMessageValue ParsedMessageValue::Clone( + google::protobuf::Arena* absl_nonnull arena) const { + ABSL_DCHECK(arena != nullptr); + + if (arena_ == arena) { + return *this; + } + auto* cloned = value_->New(arena); + cloned->CopyFrom(*value_); + return ParsedMessageValue(cloned, arena); +} + +absl::Status ParsedMessageValue::GetFieldByName( + absl::string_view name, ProtoWrapperTypeOptions unboxing_options, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + const auto* descriptor = GetDescriptor(); + const auto* field = descriptor->FindFieldByName(name); + if (field == nullptr) { + field = descriptor->file()->pool()->FindExtensionByPrintableName(descriptor, + name); + if (field == nullptr) { + *result = NoSuchFieldError(name); + return absl::OkStatus(); + } + } + return GetField(field, unboxing_options, descriptor_pool, message_factory, + arena, result); +} + +absl::Status ParsedMessageValue::GetFieldByNumber( + int64_t number, ProtoWrapperTypeOptions unboxing_options, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + const auto* descriptor = GetDescriptor(); + if (number < std::numeric_limits::min() || + number > std::numeric_limits::max()) { + *result = NoSuchFieldError(absl::StrCat(number)); + return absl::OkStatus(); + } + const auto* field = descriptor->FindFieldByNumber(static_cast(number)); + if (field == nullptr) { + *result = NoSuchFieldError(absl::StrCat(number)); + return absl::OkStatus(); + } + return GetField(field, unboxing_options, descriptor_pool, message_factory, + arena, result); +} + +absl::StatusOr ParsedMessageValue::HasFieldByName( + absl::string_view name) const { + const auto* descriptor = GetDescriptor(); + const auto* field = descriptor->FindFieldByName(name); + if (field == nullptr) { + field = descriptor->file()->pool()->FindExtensionByPrintableName(descriptor, + name); + if (field == nullptr) { + return NoSuchFieldError(name).NativeValue(); + } + } + return HasField(field); +} + +absl::StatusOr ParsedMessageValue::HasFieldByNumber( + int64_t number) const { + const auto* descriptor = GetDescriptor(); + if (number < std::numeric_limits::min() || + number > std::numeric_limits::max()) { + return NoSuchFieldError(absl::StrCat(number)).NativeValue(); + } + const auto* field = descriptor->FindFieldByNumber(static_cast(number)); + if (field == nullptr) { + return NoSuchFieldError(absl::StrCat(number)).NativeValue(); + } + return HasField(field); +} + +absl::Status ParsedMessageValue::ForEachField( + ForEachFieldCallback callback, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + + std::vector fields; + const auto* reflection = GetReflection(); + reflection->ListFields(*value_, &fields); + for (const auto* field : fields) { + auto value = Value::WrapField(value_, field, descriptor_pool, + message_factory, arena); + CEL_ASSIGN_OR_RETURN(auto ok, callback(field->name(), value)); + if (!ok) { + break; + } + } + return absl::OkStatus(); +} + +namespace { + +class ParsedMessageValueQualifyState final + : public extensions::protobuf_internal::ProtoQualifyState { + public: + ParsedMessageValueQualifyState( + const google::protobuf::Message* absl_nonnull message, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) + : ProtoQualifyState(message, message->GetDescriptor(), + message->GetReflection()), + descriptor_pool_(descriptor_pool), + message_factory_(message_factory), + arena_(arena) {} + + absl::optional& result() { return result_; } + + private: + void SetResultFromError(absl::Status status, cel::MemoryManagerRef) override { + result_ = ErrorValue(std::move(status)); + } + + void SetResultFromBool(bool value) override { result_ = BoolValue(value); } + + absl::Status SetResultFromField(const google::protobuf::Message* message, + const google::protobuf::FieldDescriptor* field, + ProtoWrapperTypeOptions unboxing_option, + cel::MemoryManagerRef) override { + result_ = Value::WrapField(unboxing_option, message, field, + descriptor_pool_, message_factory_, arena_); + return absl::OkStatus(); + } + + absl::Status SetResultFromRepeatedField(const google::protobuf::Message* message, + const google::protobuf::FieldDescriptor* field, + int index, + cel::MemoryManagerRef) override { + result_ = Value::WrapRepeatedField(index, message, field, descriptor_pool_, + message_factory_, arena_); + return absl::OkStatus(); + } + + absl::Status SetResultFromMapField(const google::protobuf::Message* message, + const google::protobuf::FieldDescriptor* field, + const google::protobuf::MapValueConstRef& value, + cel::MemoryManagerRef) override { + result_ = Value::WrapMapFieldValue(value, message, field, descriptor_pool_, + message_factory_, arena_); + return absl::OkStatus(); + } + + const google::protobuf::DescriptorPool* absl_nonnull const descriptor_pool_; + google::protobuf::MessageFactory* absl_nonnull const message_factory_; + google::protobuf::Arena* absl_nonnull const arena_; + absl::optional result_; +}; + +} // namespace + +absl::Status ParsedMessageValue::Qualify( + absl::Span qualifiers, bool presence_test, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result, + int* absl_nonnull count) const { + ABSL_DCHECK(!qualifiers.empty()); + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + ABSL_DCHECK(count != nullptr); + + if (ABSL_PREDICT_FALSE(qualifiers.empty())) { + return absl::InvalidArgumentError("invalid select qualifier path."); + } + ParsedMessageValueQualifyState qualify_state(value_, descriptor_pool, + message_factory, arena); + for (int i = 0; i < qualifiers.size() - 1; i++) { + const auto& qualifier = qualifiers[i]; + CEL_RETURN_IF_ERROR(qualify_state.ApplySelectQualifier( + qualifier, MemoryManagerRef::Pooling(arena))); + if (qualify_state.result().has_value()) { + *result = std::move(qualify_state.result()).value(); + *count = result->Is() ? -1 : i + 1; + return absl::OkStatus(); + } + } + const auto& last_qualifier = qualifiers.back(); + if (presence_test) { + CEL_RETURN_IF_ERROR(qualify_state.ApplyLastQualifierHas( + last_qualifier, MemoryManagerRef::Pooling(arena))); + } else { + CEL_RETURN_IF_ERROR(qualify_state.ApplyLastQualifierGet( + last_qualifier, MemoryManagerRef::Pooling(arena))); + } + *result = std::move(qualify_state.result()).value(); + *count = -1; + return absl::OkStatus(); +} + +absl::Status ParsedMessageValue::GetField( + const google::protobuf::FieldDescriptor* absl_nonnull field, + ProtoWrapperTypeOptions unboxing_options, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + ABSL_DCHECK(field != nullptr); + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + if (arena_ == nullptr) { + *result = Value::WrapFieldUnsafe(unboxing_options, value_, field, + descriptor_pool, message_factory, arena); + } else { + *result = Value::WrapField(unboxing_options, value_, field, descriptor_pool, + message_factory, arena); + } + return absl::OkStatus(); +} + +bool ParsedMessageValue::HasField( + const google::protobuf::FieldDescriptor* absl_nonnull field) const { + ABSL_DCHECK(field != nullptr); + + const auto* reflection = GetReflection(); + if (field->is_map() || field->is_repeated()) { + return reflection->FieldSize(*value_, field) > 0; + } + return reflection->HasField(*value_, field); +} + +} // namespace cel diff --git a/common/values/parsed_message_value.h b/common/values/parsed_message_value.h new file mode 100644 index 000000000..f3d1f7b40 --- /dev/null +++ b/common/values/parsed_message_value.h @@ -0,0 +1,251 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/value.h" +// IWYU pragma: friend "common/value.h" + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_PARSED_MESSAGE_VALUE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_PARSED_MESSAGE_VALUE_H_ + +#include +#include +#include +#include +#include +#include + +#include "google/protobuf/any.pb.h" +#include "google/protobuf/struct.pb.h" +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "base/attribute.h" +#include "common/memory.h" +#include "common/type.h" +#include "common/value_kind.h" +#include "common/values/custom_struct_value.h" +#include "common/values/values.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/message.h" + +namespace cel { + +class MessageValue; +class StructValue; +class Value; + +class ParsedMessageValue final + : private common_internal::StructValueMixin { + public: + static constexpr ValueKind kKind = ValueKind::kStruct; + + using element_type = const google::protobuf::Message; + + ParsedMessageValue( + const google::protobuf::Message* absl_nonnull value ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND) + : value_(value), arena_(arena) { + ABSL_DCHECK(value != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(!value_ || !IsWellKnownMessageType(value_->GetDescriptor())) + << value_->GetTypeName() << " is a well known type"; + ABSL_DCHECK(!value_ || value_->GetReflection() != nullptr) + << value_->GetTypeName() << " is missing reflection"; + ABSL_DCHECK_OK(CheckArena(value_, arena_)); + } + + // Places the `ParsedMessageValue` into a special state where it is logically + // equivalent to the default instance of `google.protobuf.Empty`, however + // dereferencing via `operator*` or `operator->` is not allowed. + ParsedMessageValue(); + ParsedMessageValue(const ParsedMessageValue&) = default; + ParsedMessageValue(ParsedMessageValue&&) = default; + ParsedMessageValue& operator=(const ParsedMessageValue&) = default; + ParsedMessageValue& operator=(ParsedMessageValue&&) = default; + + static constexpr ValueKind kind() { return kKind; } + + absl::string_view GetTypeName() const { return GetDescriptor()->full_name(); } + + MessageType GetRuntimeType() const { return MessageType(GetDescriptor()); } + + const google::protobuf::Descriptor* absl_nonnull GetDescriptor() const { + return (*this)->GetDescriptor(); + } + + const google::protobuf::Reflection* absl_nonnull GetReflection() const { + return (*this)->GetReflection(); + } + + const google::protobuf::Message& operator*() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return *value_; + } + + const google::protobuf::Message* absl_nonnull operator->() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return value_; + } + + bool IsZeroValue() const; + + std::string DebugString() const; + + // See Value::SerializeTo(). + absl::Status SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const; + + // See Value::ConvertToJson(). + absl::Status ConvertToJson( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const; + + // See Value::ConvertToJsonObject(). + absl::Status ConvertToJsonObject( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const; + + absl::Status Equal(const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const; + using StructValueMixin::Equal; + + ParsedMessageValue Clone(google::protobuf::Arena* absl_nonnull arena) const; + + absl::Status GetFieldByName( + absl::string_view name, ProtoWrapperTypeOptions unboxing_options, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; + using StructValueMixin::GetFieldByName; + + absl::Status GetFieldByNumber( + int64_t number, ProtoWrapperTypeOptions unboxing_options, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; + using StructValueMixin::GetFieldByNumber; + + absl::StatusOr HasFieldByName(absl::string_view name) const; + + absl::StatusOr HasFieldByNumber(int64_t number) const; + + using ForEachFieldCallback = CustomStructValueInterface::ForEachFieldCallback; + + absl::Status ForEachField( + ForEachFieldCallback callback, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const; + + absl::Status Qualify( + absl::Span qualifiers, bool presence_test, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result, + int* absl_nonnull count) const; + using StructValueMixin::Qualify; + + friend void swap(ParsedMessageValue& lhs, ParsedMessageValue& rhs) noexcept { + using std::swap; + swap(lhs.value_, rhs.value_); + swap(lhs.arena_, rhs.arena_); + } + + private: + friend std::pointer_traits; + friend class StructValue; + friend class common_internal::ValueMixin; + friend class common_internal::StructValueMixin; + friend ParsedMessageValue UnsafeParsedMessageValue( + const google::protobuf::Message* absl_nonnull value); + + explicit ParsedMessageValue( + const google::protobuf::Message* absl_nonnull value ABSL_ATTRIBUTE_LIFETIME_BOUND) + : value_(value), arena_(value->GetArena()) { + ABSL_DCHECK(value != nullptr); + ABSL_DCHECK(!value_ || !IsWellKnownMessageType(value_->GetDescriptor())) + << value_->GetTypeName() << " is a well known type"; + ABSL_DCHECK(!value_ || value_->GetReflection() != nullptr) + << value_->GetTypeName() << " is missing reflection"; + } + + static absl::Status CheckArena(const google::protobuf::Message* absl_nullable message, + google::protobuf::Arena* absl_nonnull arena) { + if (message != nullptr && message->GetArena() != nullptr && + message->GetArena() != arena) { + return absl::InvalidArgumentError( + "message arena must be the same as arena"); + } + return absl::OkStatus(); + } + + absl::Status GetField( + const google::protobuf::FieldDescriptor* absl_nonnull field, + ProtoWrapperTypeOptions unboxing_options, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; + + bool HasField(const google::protobuf::FieldDescriptor* absl_nonnull field) const; + + const google::protobuf::Message* absl_nonnull value_; + // Arena that is attributed as owning the value. May be null to indicate that + // the value is managed externally. + google::protobuf::Arena* absl_nullable arena_; +}; + +inline std::ostream& operator<<(std::ostream& out, + const ParsedMessageValue& value) { + return out << value.DebugString(); +} + +// Creates a `ParsedMessageValue` without specifying a managing arena. +// The message must outlive the `ParsedMessageValue` or any value that might +// be derived from it. Prefer to use `cel::Value::WrapMessageUnsafe()`. +inline ParsedMessageValue UnsafeParsedMessageValue( + const google::protobuf::Message* absl_nonnull value) { + return ParsedMessageValue(value); +} + +} // namespace cel + +namespace std { + +template <> +struct pointer_traits { + using pointer = cel::ParsedMessageValue; + using element_type = typename cel::ParsedMessageValue::element_type; + using difference_type = ptrdiff_t; + + static element_type* to_address(const pointer& p) noexcept { + return cel::to_address(p.value_); + } +}; + +} // namespace std + +#endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_PARSED_MESSAGE_VALUE_H_ diff --git a/common/values/parsed_message_value_test.cc b/common/values/parsed_message_value_test.cc new file mode 100644 index 000000000..7a84f82ba --- /dev/null +++ b/common/values/parsed_message_value_test.cc @@ -0,0 +1,112 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "google/protobuf/struct.pb.h" +#include "absl/status/status_matchers.h" +#include "absl/strings/cord.h" +#include "absl/strings/string_view.h" +#include "common/memory.h" +#include "common/type.h" +#include "common/value.h" +#include "common/value_kind.h" +#include "common/value_testing.h" +#include "internal/testing.h" +#include "cel/expr/conformance/proto3/test_all_types.pb.h" +#include "google/protobuf/io/zero_copy_stream_impl_lite.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::IsOkAndHolds; +using ::cel::test::BoolValueIs; +using ::testing::_; +using ::testing::IsEmpty; + +using TestAllTypesProto3 = ::cel::expr::conformance::proto3::TestAllTypes; + +using ParsedMessageValueTest = common_internal::ValueTest<>; + +TEST_F(ParsedMessageValueTest, Kind) { + ParsedMessageValue value = MakeParsedMessage(); + EXPECT_EQ(value.kind(), ParsedMessageValue::kKind); + EXPECT_EQ(value.kind(), ValueKind::kStruct); +} + +TEST_F(ParsedMessageValueTest, GetTypeName) { + ParsedMessageValue value = MakeParsedMessage(); + EXPECT_EQ(value.GetTypeName(), "cel.expr.conformance.proto3.TestAllTypes"); +} + +TEST_F(ParsedMessageValueTest, GetRuntimeType) { + ParsedMessageValue value = MakeParsedMessage(); + EXPECT_EQ(value.GetRuntimeType(), MessageType(value.GetDescriptor())); +} + +TEST_F(ParsedMessageValueTest, DebugString) { + ParsedMessageValue value = MakeParsedMessage(); + EXPECT_THAT(value.DebugString(), _); +} + +TEST_F(ParsedMessageValueTest, IsZeroValue) { + MessageValue value = MakeParsedMessage(); + EXPECT_TRUE(value.IsZeroValue()); +} + +TEST_F(ParsedMessageValueTest, SerializeTo) { + MessageValue value = MakeParsedMessage(); + google::protobuf::io::CordOutputStream output; + EXPECT_THAT(value.SerializeTo(descriptor_pool(), message_factory(), &output), + IsOk()); + EXPECT_THAT(std::move(output).Consume(), IsEmpty()); +} + +TEST_F(ParsedMessageValueTest, ConvertToJson) { + MessageValue value = MakeParsedMessage(); + auto json = DynamicParseTextProto(R"pb()pb"); + EXPECT_THAT(value.ConvertToJson(descriptor_pool(), message_factory(), + cel::to_address(json)), + IsOk()); + EXPECT_THAT(*json, EqualsTextProto( + R"pb(struct_value: {})pb")); +} + +TEST_F(ParsedMessageValueTest, Equal) { + MessageValue value = MakeParsedMessage(); + EXPECT_THAT( + value.Equal(BoolValue(), descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(BoolValueIs(false))); + EXPECT_THAT(value.Equal(MakeParsedMessage(), + descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(BoolValueIs(true))); +} + +TEST_F(ParsedMessageValueTest, GetFieldByName) { + MessageValue value = MakeParsedMessage(); + EXPECT_THAT(value.GetFieldByName("single_bool", descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(BoolValueIs(false))); +} + +TEST_F(ParsedMessageValueTest, GetFieldByNumber) { + MessageValue value = MakeParsedMessage(); + EXPECT_THAT( + value.GetFieldByNumber(13, descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(BoolValueIs(false))); +} + +} // namespace +} // namespace cel diff --git a/common/values/parsed_repeated_field_value.cc b/common/values/parsed_repeated_field_value.cc new file mode 100644 index 000000000..b990d3965 --- /dev/null +++ b/common/values/parsed_repeated_field_value.cc @@ -0,0 +1,365 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/values/parsed_repeated_field_value.h" + +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/base/optimization.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "common/value.h" +#include "internal/json.h" +#include "internal/message_equality.h" +#include "internal/status_macros.h" +#include "internal/well_known_types.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/message.h" + +namespace cel { + +using ::cel::well_known_types::ValueReflection; + +std::string ParsedRepeatedFieldValue::DebugString() const { + if (ABSL_PREDICT_FALSE(field_ == nullptr)) { + return "INVALID"; + } + return "VALID"; +} + +absl::Status ParsedRepeatedFieldValue::SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(output != nullptr); + ABSL_DCHECK(*this); + + if (ABSL_PREDICT_FALSE(field_ == nullptr)) { + return absl::OkStatus(); + } + // We have to convert to google.protobuf.Struct first. + google::protobuf::Value message; + CEL_RETURN_IF_ERROR(internal::MessageFieldToJson( + *message_, field_, descriptor_pool, message_factory, &message)); + if (!message.list_value().SerializePartialToZeroCopyStream(output)) { + return absl::UnknownError("failed to serialize google.protobuf.Struct"); + } + return absl::OkStatus(); +} + +absl::Status ParsedRepeatedFieldValue::ConvertToJson( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(json != nullptr); + ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE); + ABSL_DCHECK(*this); + + if (ABSL_PREDICT_FALSE(field_ == nullptr)) { + ValueReflection value_reflection; + CEL_RETURN_IF_ERROR(value_reflection.Initialize(json->GetDescriptor())); + value_reflection.MutableListValue(json)->Clear(); + return absl::OkStatus(); + } + return internal::MessageFieldToJson(*message_, field_, descriptor_pool, + message_factory, json); +} + +absl::Status ParsedRepeatedFieldValue::ConvertToJsonArray( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(json != nullptr); + ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE); + ABSL_DCHECK(*this); + + json->Clear(); + + if (ABSL_PREDICT_FALSE(field_ == nullptr)) { + return absl::OkStatus(); + } + return internal::MessageFieldToJson(*message_, field_, descriptor_pool, + message_factory, json); +} + +absl::Status ParsedRepeatedFieldValue::Equal( + const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + if (auto other_value = other.AsParsedRepeatedField(); other_value) { + ABSL_DCHECK(field_ != nullptr); + ABSL_DCHECK(other_value->field_ != nullptr); + CEL_ASSIGN_OR_RETURN( + auto equal, internal::MessageFieldEquals( + *message_, field_, *other_value->message_, + other_value->field_, descriptor_pool, message_factory)); + *result = BoolValue(equal); + return absl::OkStatus(); + } + if (auto other_value = other.AsParsedJsonList(); other_value) { + if (other_value->value_ == nullptr) { + *result = BoolValue(IsEmpty()); + return absl::OkStatus(); + } + ABSL_DCHECK(field_ != nullptr); + CEL_ASSIGN_OR_RETURN( + auto equal, + internal::MessageFieldEquals(*message_, field_, *other_value->value_, + descriptor_pool, message_factory)); + *result = BoolValue(equal); + return absl::OkStatus(); + } + if (auto other_value = other.AsList(); other_value) { + return common_internal::ListValueEqual(ListValue(*this), *other_value, + descriptor_pool, message_factory, + arena, result); + } + *result = BoolValue(false); + return absl::OkStatus(); +} + +bool ParsedRepeatedFieldValue::IsZeroValue() const { return IsEmpty(); } + +ParsedRepeatedFieldValue ParsedRepeatedFieldValue::Clone( + google::protobuf::Arena* absl_nonnull arena) const { + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(*this); + + if (ABSL_PREDICT_FALSE(field_ == nullptr)) { + return ParsedRepeatedFieldValue(); + } + if (arena_ == arena) { + return *this; + } + auto field = message_->GetReflection()->GetRepeatedFieldRef( + *message_, field_); + auto* cloned_message = message_->New(arena); + auto cloned_field = + cloned_message->GetReflection() + ->GetMutableRepeatedFieldRef(cloned_message, field_); + cloned_field.CopyFrom(field); + return ParsedRepeatedFieldValue(cloned_message, field_, arena); +} + +bool ParsedRepeatedFieldValue::IsEmpty() const { return Size() == 0; } + +size_t ParsedRepeatedFieldValue::Size() const { + ABSL_DCHECK(*this); + if (ABSL_PREDICT_FALSE(field_ == nullptr)) { + return 0; + } + return static_cast(GetReflection()->FieldSize(*message_, field_)); +} + +// See ListValueInterface::Get for documentation. +absl::Status ParsedRepeatedFieldValue::Get( + size_t index, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + ABSL_DCHECK(*this); + ABSL_DCHECK(message_ != nullptr); + + if (ABSL_PREDICT_FALSE(field_ == nullptr || + index >= std::numeric_limits::max() || + static_cast(index) >= + GetReflection()->FieldSize(*message_, field_))) { + *result = IndexOutOfBoundsError(index); + return absl::OkStatus(); + } + if (arena_ == nullptr) { + *result = Value::WrapRepeatedFieldUnsafe(static_cast(index), message_, + field_, descriptor_pool, + message_factory, arena); + } else { + *result = + Value::WrapRepeatedField(static_cast(index), message_, field_, + descriptor_pool, message_factory, arena); + } + return absl::OkStatus(); +} + +absl::Status ParsedRepeatedFieldValue::ForEach( + ForEachWithIndexCallback callback, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const { + ABSL_DCHECK(*this); + if (ABSL_PREDICT_FALSE(field_ == nullptr)) { + return absl::OkStatus(); + } + const auto* reflection = message_->GetReflection(); + const int size = reflection->FieldSize(*message_, field_); + if (size > 0) { + CEL_ASSIGN_OR_RETURN(auto accessor, + common_internal::RepeatedFieldAccessorFor(field_)); + Value scratch; + for (int i = 0; i < size; ++i) { + (*accessor)(i, message_, field_, reflection, descriptor_pool, + message_factory, arena, &scratch); + CEL_ASSIGN_OR_RETURN(auto ok, callback(static_cast(i), scratch)); + if (!ok) { + break; + } + } + } + return absl::OkStatus(); +} + +namespace { + +class ParsedRepeatedFieldValueIterator final : public ValueIterator { + public: + ParsedRepeatedFieldValueIterator( + const google::protobuf::Message* absl_nonnull message, + const google::protobuf::FieldDescriptor* absl_nonnull field, + absl_nonnull common_internal::RepeatedFieldAccessor accessor) + : message_(message), + field_(field), + reflection_(message_->GetReflection()), + accessor_(accessor), + size_(reflection_->FieldSize(*message_, field_)) {} + + bool HasNext() override { return index_ < size_; } + + absl::Status Next(const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) override { + if (ABSL_PREDICT_FALSE(index_ >= size_)) { + return absl::FailedPreconditionError( + "ValueIterator::Next called after ValueIterator::HasNext returned " + "false"); + } + (*accessor_)(index_, message_, field_, reflection_, descriptor_pool, + message_factory, arena, result); + ++index_; + return absl::OkStatus(); + } + + absl::StatusOr Next1( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull key_or_value) override { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(key_or_value != nullptr); + + if (index_ >= size_) { + return false; + } + (*accessor_)(index_, message_, field_, reflection_, descriptor_pool, + message_factory, arena, key_or_value); + ++index_; + return true; + } + + absl::StatusOr Next2( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull key, + Value* absl_nullable value) override { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(key != nullptr); + + if (index_ >= size_) { + return false; + } + if (value != nullptr) { + (*accessor_)(index_, message_, field_, reflection_, descriptor_pool, + message_factory, arena, value); + } + *key = IntValue(index_); + ++index_; + return true; + } + + private: + const google::protobuf::Message* absl_nonnull const message_; + const google::protobuf::FieldDescriptor* absl_nonnull const field_; + const google::protobuf::Reflection* absl_nonnull const reflection_; + const absl_nonnull common_internal::RepeatedFieldAccessor accessor_; + const int size_; + int index_ = 0; +}; + +} // namespace + +absl::StatusOr> +ParsedRepeatedFieldValue::NewIterator() const { + ABSL_DCHECK(*this); + if (ABSL_PREDICT_FALSE(field_ == nullptr)) { + return NewEmptyValueIterator(); + } + CEL_ASSIGN_OR_RETURN(auto accessor, + common_internal::RepeatedFieldAccessorFor(field_)); + return std::make_unique(message_, field_, + accessor); +} + +absl::Status ParsedRepeatedFieldValue::Contains( + const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + ABSL_DCHECK(*this); + if (ABSL_PREDICT_FALSE(field_ == nullptr)) { + *result = FalseValue(); + return absl::OkStatus(); + } + const auto* reflection = message_->GetReflection(); + const int size = reflection->FieldSize(*message_, field_); + if (size > 0) { + CEL_ASSIGN_OR_RETURN(auto accessor, + common_internal::RepeatedFieldAccessorFor(field_)); + Value scratch; + for (int i = 0; i < size; ++i) { + (*accessor)(i, message_, field_, reflection, descriptor_pool, + message_factory, arena, &scratch); + CEL_RETURN_IF_ERROR(scratch.Equal(other, descriptor_pool, message_factory, + arena, result)); + if (result->IsTrue()) { + return absl::OkStatus(); + } + } + } + *result = FalseValue(); + return absl::OkStatus(); +} + +const google::protobuf::Reflection* absl_nonnull ParsedRepeatedFieldValue::GetReflection() + const { + return message_->GetReflection(); +} + +} // namespace cel diff --git a/common/values/parsed_repeated_field_value.h b/common/values/parsed_repeated_field_value.h new file mode 100644 index 000000000..e345c8ffa --- /dev/null +++ b/common/values/parsed_repeated_field_value.h @@ -0,0 +1,220 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/value.h" +// IWYU pragma: friend "common/value.h" + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_PARSED_REPEATED_FIELD_VALUE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_PARSED_REPEATED_FIELD_VALUE_H_ + +#include +#include +#include +#include + +#include "google/protobuf/any.pb.h" +#include "google/protobuf/struct.pb.h" +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "common/type.h" +#include "common/value_kind.h" +#include "common/values/custom_list_value.h" +#include "common/values/values.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/message.h" + +namespace cel { + +class Value; +class ValueIterator; +class ParsedJsonListValue; + +// ParsedRepeatedFieldValue is a ListValue over a repeated field of a parsed +// protocol buffer message. +class ParsedRepeatedFieldValue final + : private common_internal::ListValueMixin { + public: + static constexpr ValueKind kKind = ValueKind::kList; + static constexpr absl::string_view kName = "list"; + + ParsedRepeatedFieldValue(const google::protobuf::Message* absl_nonnull message, + const google::protobuf::FieldDescriptor* absl_nonnull field, + google::protobuf::Arena* absl_nonnull arena) + : message_(message), field_(field), arena_(arena) { + ABSL_DCHECK(message != nullptr); + ABSL_DCHECK(field != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(field_->is_repeated() && !field_->is_map()) + << field_->full_name() << " must be a repeated field"; + ABSL_DCHECK_OK(CheckArena(message_, arena_)); + } + + // Places the `ParsedRepeatedFieldValue` into an invalid state. Anything + // except assigning to `ParsedRepeatedFieldValue` is undefined behavior. + ParsedRepeatedFieldValue() = default; + + ParsedRepeatedFieldValue(const ParsedRepeatedFieldValue&) = default; + ParsedRepeatedFieldValue(ParsedRepeatedFieldValue&&) = default; + ParsedRepeatedFieldValue& operator=(const ParsedRepeatedFieldValue&) = + default; + ParsedRepeatedFieldValue& operator=(ParsedRepeatedFieldValue&&) = default; + + static constexpr ValueKind kind() { return kKind; } + + static constexpr absl::string_view GetTypeName() { return kName; } + + static ListType GetRuntimeType() { return ListType(); } + + std::string DebugString() const; + + // See Value::SerializeTo(). + absl::Status SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const; + + // See Value::ConvertToJson(). + absl::Status ConvertToJson( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const; + + // See Value::ConvertToJsonArray(). + absl::Status ConvertToJsonArray( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const; + + absl::Status Equal(const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const; + using ListValueMixin::Equal; + + bool IsZeroValue() const; + + bool IsEmpty() const; + + ParsedRepeatedFieldValue Clone(google::protobuf::Arena* absl_nonnull arena) const; + + size_t Size() const; + + // See ListValueInterface::Get for documentation. + absl::Status Get(size_t index, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const; + using ListValueMixin::Get; + + using ForEachCallback = typename CustomListValueInterface::ForEachCallback; + + using ForEachWithIndexCallback = + typename CustomListValueInterface::ForEachWithIndexCallback; + + absl::Status ForEach( + ForEachWithIndexCallback callback, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const; + using ListValueMixin::ForEach; + + absl::StatusOr NewIterator() const; + + absl::Status Contains( + const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; + using ListValueMixin::Contains; + + const google::protobuf::Message& message() const { + ABSL_DCHECK(*this); + return *message_; + } + + const google::protobuf::FieldDescriptor* absl_nonnull field() const { + ABSL_DCHECK(*this); + return field_; + } + + // Returns `true` if `ParsedRepeatedFieldValue` is in a valid state. + explicit operator bool() const { return field_ != nullptr; } + + friend void swap(ParsedRepeatedFieldValue& lhs, + ParsedRepeatedFieldValue& rhs) noexcept { + using std::swap; + swap(lhs.message_, rhs.message_); + swap(lhs.field_, rhs.field_); + swap(lhs.arena_, rhs.arena_); + } + + private: + friend class ParsedJsonListValue; + friend class common_internal::ValueMixin; + friend class common_internal::ListValueMixin; + friend ParsedRepeatedFieldValue UnsafeParsedRepeatedFieldValue( + const google::protobuf::Message* absl_nonnull message, + const google::protobuf::FieldDescriptor* absl_nonnull field); + + ParsedRepeatedFieldValue(const google::protobuf::Message* absl_nonnull message, + const google::protobuf::FieldDescriptor* absl_nonnull field) + : message_(message), field_(field), arena_(message->GetArena()) { + ABSL_DCHECK(message != nullptr); + ABSL_DCHECK(field != nullptr); + ABSL_DCHECK(field_->is_repeated() && !field_->is_map()) + << field_->full_name() << " must be a repeated field"; + } + + static absl::Status CheckArena(const google::protobuf::Message* absl_nullable message, + google::protobuf::Arena* absl_nonnull arena) { + if (message != nullptr && message->GetArena() != nullptr && + message->GetArena() != arena) { + return absl::InvalidArgumentError( + "message arena must be the same as arena"); + } + return absl::OkStatus(); + } + + const google::protobuf::Reflection* absl_nonnull GetReflection() const; + + const google::protobuf::Message* absl_nullable message_ = nullptr; + const google::protobuf::FieldDescriptor* absl_nullable field_ = nullptr; + google::protobuf::Arena* absl_nullable arena_ = nullptr; +}; + +inline std::ostream& operator<<(std::ostream& out, + const ParsedRepeatedFieldValue& value) { + return out << value.DebugString(); +} + +// Creates a `ParsedRepeatedFieldValue` without specifying a managing arena. +// The message must outlive the `ParsedRepeatedFieldValue` or any value that +// might be derived from it. Prefer to use +// `cel::Value::WrapRepeatedFieldUnsafe()`. +inline ParsedRepeatedFieldValue UnsafeParsedRepeatedFieldValue( + const google::protobuf::Message* absl_nonnull message, + const google::protobuf::FieldDescriptor* absl_nonnull field) { + return ParsedRepeatedFieldValue(message, field); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_PARSED_REPEATED_FIELD_VALUE_H_ diff --git a/common/values/parsed_repeated_field_value_test.cc b/common/values/parsed_repeated_field_value_test.cc new file mode 100644 index 000000000..3155e7159 --- /dev/null +++ b/common/values/parsed_repeated_field_value_test.cc @@ -0,0 +1,450 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "google/protobuf/struct.pb.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "absl/types/optional.h" +#include "common/memory.h" +#include "common/type.h" +#include "common/value.h" +#include "common/value_kind.h" +#include "common/value_testing.h" +#include "internal/testing.h" +#include "cel/expr/conformance/proto3/test_all_types.pb.h" +#include "google/protobuf/io/zero_copy_stream_impl_lite.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::IsOkAndHolds; +using ::absl_testing::StatusIs; +using ::cel::test::BoolValueIs; +using ::cel::test::BytesValueIs; +using ::cel::test::DoubleValueIs; +using ::cel::test::DurationValueIs; +using ::cel::test::ErrorValueIs; +using ::cel::test::IntValueIs; +using ::cel::test::IsNullValue; +using ::cel::test::UintValueIs; +using ::testing::_; +using ::testing::ElementsAre; +using ::testing::Eq; +using ::testing::IsEmpty; +using ::testing::Optional; +using ::testing::Pair; + +using TestAllTypesProto3 = ::cel::expr::conformance::proto3::TestAllTypes; + +using ParsedRepeatedFieldValueTest = common_internal::ValueTest<>; + +TEST_F(ParsedRepeatedFieldValueTest, Field) { + ParsedRepeatedFieldValue value( + DynamicParseTextProto(R"pb()pb"), + DynamicGetField("repeated_int64"), arena()); + EXPECT_TRUE(value); +} + +TEST_F(ParsedRepeatedFieldValueTest, Kind) { + ParsedRepeatedFieldValue value( + DynamicParseTextProto(R"pb()pb"), + DynamicGetField("repeated_int64"), arena()); + EXPECT_EQ(value.kind(), ParsedRepeatedFieldValue::kKind); + EXPECT_EQ(value.kind(), ValueKind::kList); +} + +TEST_F(ParsedRepeatedFieldValueTest, GetTypeName) { + ParsedRepeatedFieldValue value( + DynamicParseTextProto(R"pb()pb"), + DynamicGetField("repeated_int64"), arena()); + EXPECT_EQ(value.GetTypeName(), ParsedRepeatedFieldValue::kName); + EXPECT_EQ(value.GetTypeName(), "list"); +} + +TEST_F(ParsedRepeatedFieldValueTest, GetRuntimeType) { + ParsedRepeatedFieldValue value( + DynamicParseTextProto(R"pb()pb"), + DynamicGetField("repeated_int64"), arena()); + EXPECT_EQ(value.GetRuntimeType(), ListType()); +} + +TEST_F(ParsedRepeatedFieldValueTest, DebugString) { + ParsedRepeatedFieldValue value( + DynamicParseTextProto(R"pb()pb"), + DynamicGetField("repeated_int64"), arena()); + EXPECT_THAT(value.DebugString(), _); +} + +TEST_F(ParsedRepeatedFieldValueTest, IsZeroValue) { + ParsedRepeatedFieldValue value( + DynamicParseTextProto(R"pb()pb"), + DynamicGetField("repeated_int64"), arena()); + EXPECT_TRUE(value.IsZeroValue()); +} + +TEST_F(ParsedRepeatedFieldValueTest, SerializeTo) { + ParsedRepeatedFieldValue value( + DynamicParseTextProto(R"pb()pb"), + DynamicGetField("repeated_int64"), arena()); + google::protobuf::io::CordOutputStream output; + EXPECT_THAT(value.SerializeTo(descriptor_pool(), message_factory(), &output), + IsOk()); + EXPECT_THAT(std::move(output).Consume(), IsEmpty()); +} + +TEST_F(ParsedRepeatedFieldValueTest, ConvertToJson) { + auto json = DynamicParseTextProto(R"pb()pb"); + ParsedRepeatedFieldValue value( + DynamicParseTextProto(R"pb()pb"), + DynamicGetField("repeated_int64"), arena()); + EXPECT_THAT(value.ConvertToJson(descriptor_pool(), message_factory(), + cel::to_address(json)), + IsOk()); + EXPECT_THAT( + *json, EqualsTextProto(R"pb(list_value: {})pb")); +} + +TEST_F(ParsedRepeatedFieldValueTest, Equal_RepeatedField) { + ParsedRepeatedFieldValue value( + DynamicParseTextProto(R"pb()pb"), + DynamicGetField("repeated_int64"), arena()); + EXPECT_THAT( + value.Equal(BoolValue(), descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(BoolValueIs(false))); + EXPECT_THAT( + value.Equal( + ParsedRepeatedFieldValue( + DynamicParseTextProto(R"pb()pb"), + DynamicGetField("repeated_int64"), arena()), + descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(BoolValueIs(true))); + EXPECT_THAT( + value.Equal(ListValue(), descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(BoolValueIs(true))); +} + +TEST_F(ParsedRepeatedFieldValueTest, Equal_JsonList) { + ParsedRepeatedFieldValue repeated_value( + DynamicParseTextProto(R"pb(repeated_int64: 1 + repeated_int64: 0)pb"), + DynamicGetField("repeated_int64"), arena()); + ParsedJsonListValue json_value( + DynamicParseTextProto( + R"pb( + values { number_value: 1 } + values { number_value: 0 } + )pb"), + arena()); + EXPECT_THAT(repeated_value.Equal(json_value, descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(BoolValueIs(true))); + EXPECT_THAT(json_value.Equal(repeated_value, descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(BoolValueIs(true))); +} + +TEST_F(ParsedRepeatedFieldValueTest, Empty) { + ParsedRepeatedFieldValue value( + DynamicParseTextProto(R"pb()pb"), + DynamicGetField("repeated_int64"), arena()); + EXPECT_TRUE(value.IsEmpty()); +} + +TEST_F(ParsedRepeatedFieldValueTest, Size) { + ParsedRepeatedFieldValue value( + DynamicParseTextProto(R"pb()pb"), + DynamicGetField("repeated_int64"), arena()); + EXPECT_EQ(value.Size(), 0); +} + +TEST_F(ParsedRepeatedFieldValueTest, Get) { + ParsedRepeatedFieldValue value( + DynamicParseTextProto(R"pb(repeated_bool: false + repeated_bool: true)pb"), + DynamicGetField("repeated_bool"), arena()); + EXPECT_THAT(value.Get(0, descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(BoolValueIs(false))); + EXPECT_THAT(value.Get(1, descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(BoolValueIs(true))); + EXPECT_THAT( + value.Get(2, descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument)))); +} + +TEST_F(ParsedRepeatedFieldValueTest, ForEach_Bool) { + ParsedRepeatedFieldValue value( + DynamicParseTextProto(R"pb(repeated_bool: false + repeated_bool: true)pb"), + DynamicGetField("repeated_bool"), arena()); + { + std::vector values; + EXPECT_THAT(value.ForEach( + [&](const Value& element) -> absl::StatusOr { + values.push_back(element); + return true; + }, + descriptor_pool(), message_factory(), arena()), + IsOk()); + EXPECT_THAT(values, ElementsAre(BoolValueIs(false), BoolValueIs(true))); + } + { + std::vector values; + EXPECT_THAT(value.ForEach( + [&](size_t, const Value& element) -> absl::StatusOr { + values.push_back(element); + return true; + }, + descriptor_pool(), message_factory(), arena()), + IsOk()); + EXPECT_THAT(values, ElementsAre(BoolValueIs(false), BoolValueIs(true))); + } +} + +TEST_F(ParsedRepeatedFieldValueTest, ForEach_Double) { + ParsedRepeatedFieldValue value( + DynamicParseTextProto(R"pb(repeated_double: 1 + repeated_double: 0)pb"), + DynamicGetField("repeated_double"), arena()); + std::vector values; + EXPECT_THAT(value.ForEach( + [&](const Value& element) -> absl::StatusOr { + values.push_back(element); + return true; + }, + descriptor_pool(), message_factory(), arena()), + IsOk()); + EXPECT_THAT(values, ElementsAre(DoubleValueIs(1), DoubleValueIs(0))); +} + +TEST_F(ParsedRepeatedFieldValueTest, ForEach_Float) { + ParsedRepeatedFieldValue value( + DynamicParseTextProto(R"pb(repeated_float: 1 + repeated_float: 0)pb"), + DynamicGetField("repeated_float"), arena()); + std::vector values; + EXPECT_THAT(value.ForEach( + [&](const Value& element) -> absl::StatusOr { + values.push_back(element); + return true; + }, + descriptor_pool(), message_factory(), arena()), + IsOk()); + EXPECT_THAT(values, ElementsAre(DoubleValueIs(1), DoubleValueIs(0))); +} + +TEST_F(ParsedRepeatedFieldValueTest, ForEach_UInt64) { + ParsedRepeatedFieldValue value( + DynamicParseTextProto(R"pb(repeated_uint64: 1 + repeated_uint64: 0)pb"), + DynamicGetField("repeated_uint64"), arena()); + std::vector values; + EXPECT_THAT(value.ForEach( + [&](const Value& element) -> absl::StatusOr { + values.push_back(element); + return true; + }, + descriptor_pool(), message_factory(), arena()), + IsOk()); + EXPECT_THAT(values, ElementsAre(UintValueIs(1), UintValueIs(0))); +} + +TEST_F(ParsedRepeatedFieldValueTest, ForEach_Int32) { + ParsedRepeatedFieldValue value( + DynamicParseTextProto(R"pb(repeated_int32: 1 + repeated_int32: 0)pb"), + DynamicGetField("repeated_int32"), arena()); + std::vector values; + EXPECT_THAT(value.ForEach( + [&](const Value& element) -> absl::StatusOr { + values.push_back(element); + return true; + }, + descriptor_pool(), message_factory(), arena()), + IsOk()); + EXPECT_THAT(values, ElementsAre(IntValueIs(1), IntValueIs(0))); +} + +TEST_F(ParsedRepeatedFieldValueTest, ForEach_UInt32) { + ParsedRepeatedFieldValue value( + DynamicParseTextProto(R"pb(repeated_uint32: 1 + repeated_uint32: 0)pb"), + DynamicGetField("repeated_uint32"), arena()); + std::vector values; + EXPECT_THAT(value.ForEach( + [&](const Value& element) -> absl::StatusOr { + values.push_back(element); + return true; + }, + descriptor_pool(), message_factory(), arena()), + IsOk()); + EXPECT_THAT(values, ElementsAre(UintValueIs(1), UintValueIs(0))); +} + +TEST_F(ParsedRepeatedFieldValueTest, ForEach_Duration) { + ParsedRepeatedFieldValue value( + DynamicParseTextProto( + R"pb(repeated_duration: { seconds: 1 nanos: 1 } + repeated_duration: {})pb"), + DynamicGetField("repeated_duration"), arena()); + std::vector values; + EXPECT_THAT(value.ForEach( + [&](const Value& element) -> absl::StatusOr { + values.push_back(element); + return true; + }, + descriptor_pool(), message_factory(), arena()), + IsOk()); + EXPECT_THAT(values, ElementsAre(DurationValueIs(absl::Seconds(1) + + absl::Nanoseconds(1)), + DurationValueIs(absl::ZeroDuration()))); +} + +TEST_F(ParsedRepeatedFieldValueTest, ForEach_Bytes) { + ParsedRepeatedFieldValue value( + DynamicParseTextProto( + R"pb(repeated_bytes: "bar" repeated_bytes: "foo")pb"), + DynamicGetField("repeated_bytes"), arena()); + std::vector values; + EXPECT_THAT(value.ForEach( + [&](const Value& element) -> absl::StatusOr { + values.push_back(element); + return true; + }, + descriptor_pool(), message_factory(), arena()), + IsOk()); + EXPECT_THAT(values, ElementsAre(BytesValueIs("bar"), BytesValueIs("foo"))); +} + +TEST_F(ParsedRepeatedFieldValueTest, ForEach_Enum) { + ParsedRepeatedFieldValue value( + DynamicParseTextProto( + R"pb(repeated_nested_enum: BAR repeated_nested_enum: FOO)pb"), + DynamicGetField("repeated_nested_enum"), arena()); + std::vector values; + EXPECT_THAT(value.ForEach( + [&](const Value& element) -> absl::StatusOr { + values.push_back(element); + return true; + }, + descriptor_pool(), message_factory(), arena()), + IsOk()); + EXPECT_THAT(values, ElementsAre(IntValueIs(1), IntValueIs(0))); +} + +TEST_F(ParsedRepeatedFieldValueTest, ForEach_Null) { + ParsedRepeatedFieldValue value( + DynamicParseTextProto(R"pb(repeated_null_value: + NULL_VALUE + repeated_null_value: + NULL_VALUE)pb"), + DynamicGetField("repeated_null_value"), arena()); + std::vector values; + EXPECT_THAT(value.ForEach( + [&](const Value& element) -> absl::StatusOr { + values.push_back(element); + return true; + }, + descriptor_pool(), message_factory(), arena()), + IsOk()); + EXPECT_THAT(values, ElementsAre(IsNullValue(), IsNullValue())); +} + +TEST_F(ParsedRepeatedFieldValueTest, NewIterator) { + ParsedRepeatedFieldValue value( + DynamicParseTextProto(R"pb(repeated_bool: false + repeated_bool: true)pb"), + DynamicGetField("repeated_bool"), arena()); + ASSERT_OK_AND_ASSIGN(auto iterator, value.NewIterator()); + ASSERT_TRUE(iterator->HasNext()); + EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(BoolValueIs(false))); + ASSERT_TRUE(iterator->HasNext()); + EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(BoolValueIs(true))); + ASSERT_FALSE(iterator->HasNext()); + EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()), + StatusIs(absl::StatusCode::kFailedPrecondition)); +} + +TEST_F(ParsedRepeatedFieldValueTest, NewIterator1) { + ParsedRepeatedFieldValue value( + DynamicParseTextProto(R"pb(repeated_bool: false + repeated_bool: true)pb"), + DynamicGetField("repeated_bool"), arena()); + ASSERT_OK_AND_ASSIGN(auto iterator, value.NewIterator()); + EXPECT_THAT(iterator->Next1(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(Optional(BoolValueIs(false)))); + EXPECT_THAT(iterator->Next1(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(Optional(BoolValueIs(true)))); + EXPECT_THAT(iterator->Next1(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(Eq(absl::nullopt))); +} + +TEST_F(ParsedRepeatedFieldValueTest, NewIterator2) { + ParsedRepeatedFieldValue value( + DynamicParseTextProto(R"pb(repeated_bool: false + repeated_bool: true)pb"), + DynamicGetField("repeated_bool"), arena()); + ASSERT_OK_AND_ASSIGN(auto iterator, value.NewIterator()); + EXPECT_THAT(iterator->Next2(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(Optional(Pair(IntValueIs(0), BoolValueIs(false))))); + EXPECT_THAT(iterator->Next2(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(Optional(Pair(IntValueIs(1), BoolValueIs(true))))); + EXPECT_THAT(iterator->Next2(descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(Eq(absl::nullopt))); +} + +TEST_F(ParsedRepeatedFieldValueTest, Contains) { + ParsedRepeatedFieldValue value( + DynamicParseTextProto(R"pb(repeated_bool: true)pb"), + DynamicGetField("repeated_bool"), arena()); + EXPECT_THAT(value.Contains(BytesValue(), descriptor_pool(), message_factory(), + arena()), + IsOkAndHolds(BoolValueIs(false))); + EXPECT_THAT(value.Contains(NullValue(), descriptor_pool(), message_factory(), + arena()), + IsOkAndHolds(BoolValueIs(false))); + EXPECT_THAT(value.Contains(BoolValue(false), descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(BoolValueIs(false))); + EXPECT_THAT(value.Contains(BoolValue(true), descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(BoolValueIs(true))); + EXPECT_THAT(value.Contains(DoubleValue(0.0), descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(BoolValueIs(false))); + EXPECT_THAT(value.Contains(DoubleValue(1.0), descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(BoolValueIs(false))); + EXPECT_THAT(value.Contains(StringValue("bar"), descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(BoolValueIs(false))); + EXPECT_THAT(value.Contains(StringValue("foo"), descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(BoolValueIs(false))); + EXPECT_THAT( + value.Contains(MapValue(), descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(BoolValueIs(false))); +} + +} // namespace +} // namespace cel diff --git a/common/values/string_value.cc b/common/values/string_value.cc new file mode 100644 index 000000000..98912d32c --- /dev/null +++ b/common/values/string_value.cc @@ -0,0 +1,1519 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "google/protobuf/wrappers.pb.h" +#include "absl/base/nullability.h" +#include "absl/functional/overload.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/ascii.h" +#include "absl/strings/cord.h" +#include "absl/strings/cord_buffer.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "common/internal/byte_string.h" +#include "common/internal/reference_count.h" +#include "common/value.h" +#include "internal/status_macros.h" +#include "internal/strings.h" +#include "internal/utf8.h" +#include "internal/well_known_types.h" +#include "runtime/internal/errors.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/message.h" + +namespace cel { + +namespace { + +using ::cel::well_known_types::ValueReflection; + +template +std::string StringDebugString(const Bytes& value) { + return value.NativeValue(absl::Overload( + [](absl::string_view string) -> std::string { + return internal::FormatStringLiteral(string); + }, + [](const absl::Cord& cord) -> std::string { + if (auto flat = cord.TryFlat(); flat.has_value()) { + return internal::FormatStringLiteral(*flat); + } + return internal::FormatStringLiteral(static_cast(cord)); + })); +} + +} // namespace + +StringValue StringValue::Concat(const StringValue& lhs, const StringValue& rhs, + google::protobuf::Arena* absl_nonnull arena) { + return StringValue( + common_internal::ByteString::Concat(lhs.value_, rhs.value_, arena)); +} + +std::string StringValue::DebugString() const { + return StringDebugString(*this); +} + +absl::Status StringValue::SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(output != nullptr); + + google::protobuf::StringValue message; + message.set_value(NativeString()); + if (!message.SerializePartialToZeroCopyStream(output)) { + return absl::UnknownError( + absl::StrCat("failed to serialize message: ", message.GetTypeName())); + } + + return absl::OkStatus(); +} + +absl::Status StringValue::ConvertToJson( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(json != nullptr); + ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE); + + ValueReflection value_reflection; + CEL_RETURN_IF_ERROR(value_reflection.Initialize(json->GetDescriptor())); + NativeValue( + [&](const auto& value) { value_reflection.SetStringValue(json, value); }); + + return absl::OkStatus(); +} + +absl::Status StringValue::Equal( + const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + if (auto other_value = other.AsString(); other_value.has_value()) { + *result = NativeValue([other_value](const auto& value) -> BoolValue { + return other_value->NativeValue( + [&value](const auto& other_value) -> BoolValue { + return BoolValue{value == other_value}; + }); + }); + return absl::OkStatus(); + } + *result = FalseValue(); + return absl::OkStatus(); +} + +size_t StringValue::Size() const { + return NativeValue([](const auto& alternative) -> size_t { + return internal::Utf8CodePointCount(alternative); + }); +} + +bool StringValue::IsEmpty() const { + return NativeValue( + [](const auto& alternative) -> bool { return alternative.empty(); }); +} + +bool StringValue::Equals(absl::string_view string) const { + return value_.Equals(string); +} + +bool StringValue::Equals(const absl::Cord& string) const { + return value_.Equals(string); +} + +bool StringValue::Equals(const StringValue& string) const { + return value_.Equals(string.value_); +} + +StringValue StringValue::Clone(google::protobuf::Arena* absl_nonnull arena) const { + return StringValue(value_.Clone(arena)); +} + +int StringValue::Compare(absl::string_view string) const { + return value_.Compare(string); +} + +int StringValue::Compare(const absl::Cord& string) const { + return value_.Compare(string); +} + +int StringValue::Compare(const StringValue& string) const { + return value_.Compare(string.value_); +} + +bool StringValue::StartsWith(absl::string_view string) const { + return value_.StartsWith(string); +} + +bool StringValue::StartsWith(const absl::Cord& string) const { + return value_.StartsWith(string); +} + +bool StringValue::StartsWith(const StringValue& string) const { + return value_.StartsWith(string.value_); +} + +bool StringValue::EndsWith(absl::string_view string) const { + return value_.EndsWith(string); +} + +bool StringValue::EndsWith(const absl::Cord& string) const { + return value_.EndsWith(string); +} + +bool StringValue::EndsWith(const StringValue& string) const { + return value_.EndsWith(string.value_); +} + +bool StringValue::Contains(absl::string_view string) const { + return value_.Visit(absl::Overload( + [&](absl::string_view lhs) -> bool { + return absl::StrContains(lhs, string); + }, + [&](const absl::Cord& lhs) -> bool { return lhs.Contains(string); })); +} + +bool StringValue::Contains(const absl::Cord& string) const { + return value_.Visit(absl::Overload( + [&](absl::string_view lhs) -> bool { + if (auto flat = string.TryFlat(); flat) { + return absl::StrContains(lhs, *flat); + } + // There is no nice way to do this. We cannot use std::search due to + // absl::Cord::CharIterator being an input iterator instead of a forward + // iterator. So just make an external cord with a noop releaser. We know + // the external cord will not outlive this function. + return absl::MakeCordFromExternal(lhs, []() {}).Contains(string); + }, + [&](const absl::Cord& lhs) -> bool { return lhs.Contains(string); })); +} + +bool StringValue::Contains(const StringValue& string) const { + return string.value_.Visit(absl::Overload( + [&](absl::string_view rhs) -> bool { return Contains(rhs); }, + [&](const absl::Cord& rhs) -> bool { return Contains(rhs); })); +} + +absl::optional StringValue::IndexOf(absl::string_view string) const { + return value_.Visit(absl::Overload( + [&](absl::string_view lhs) -> absl::optional { + int64_t code_points = 0; + while (lhs.size() >= string.size()) { + if (absl::StartsWith(lhs, string)) { + return code_points; + } + if (lhs.size() == string.size()) { + break; + } + size_t code_units = + cel::internal::Utf8Decode(lhs, /*code_point=*/nullptr); + lhs.remove_prefix(code_units); + ++code_points; + } + return absl::nullopt; + }, + [&](absl::Cord lhs) -> absl::optional { + int64_t code_points = 0; + while (lhs.size() >= string.size()) { + if (lhs.StartsWith(string)) { + return code_points; + } + if (lhs.size() == string.size()) { + break; + } + size_t code_units = cel::internal::Utf8Decode(lhs.char_begin(), + /*code_point=*/nullptr); + lhs.RemovePrefix(code_units); + ++code_points; + } + return absl::nullopt; + })); +} + +absl::optional StringValue::IndexOf(const absl::Cord& string) const { + return value_.Visit(absl::Overload( + [&](absl::string_view lhs) -> absl::optional { + int64_t code_points = 0; + while (lhs.size() >= string.size()) { + if (lhs.substr(0, string.size()) == string) { + return code_points; + } + if (lhs.size() == string.size()) { + break; + } + size_t code_units = + cel::internal::Utf8Decode(lhs, /*code_point=*/nullptr); + lhs.remove_prefix(code_units); + ++code_points; + } + return absl::nullopt; + }, + [&](absl::Cord lhs) -> absl::optional { + int64_t code_points = 0; + while (lhs.size() >= string.size()) { + if (lhs.StartsWith(string)) { + return code_points; + } + if (lhs.size() == string.size()) { + break; + } + size_t code_units = cel::internal::Utf8Decode(lhs.char_begin(), + /*code_point=*/nullptr); + lhs.RemovePrefix(code_units); + ++code_points; + } + return absl::nullopt; + })); +} + +absl::optional StringValue::IndexOf(const StringValue& string) const { + return string.value_.Visit(absl::Overload( + [this](absl::string_view rhs) -> absl::optional { + return IndexOf(rhs); + }, + [this](const absl::Cord& rhs) -> absl::optional { + return IndexOf(rhs); + })); +} + +absl::optional StringValue::IndexOf(absl::string_view string, + int64_t pos) const { + if (pos < 0) { + pos = 0; + } + return value_.Visit(absl::Overload( + [&](absl::string_view lhs) -> absl::optional { + int64_t code_points = 0; + while (lhs.size() >= string.size()) { + if (code_points >= pos && absl::StartsWith(lhs, string)) { + return code_points; + } + if (lhs.size() == string.size()) { + break; + } + size_t code_units = + cel::internal::Utf8Decode(lhs, /*code_point=*/nullptr); + lhs.remove_prefix(code_units); + ++code_points; + } + return absl::nullopt; + }, + [&](absl::Cord lhs) -> absl::optional { + int64_t code_points = 0; + while (lhs.size() >= string.size()) { + if (code_points >= pos && lhs.StartsWith(string)) { + return code_points; + } + if (lhs.size() == string.size()) { + break; + } + size_t code_units = cel::internal::Utf8Decode(lhs.char_begin(), + /*code_point=*/nullptr); + lhs.RemovePrefix(code_units); + ++code_points; + } + return absl::nullopt; + })); +} + +absl::optional StringValue::IndexOf(const absl::Cord& string, + int64_t pos) const { + if (pos < 0) { + pos = 0; + } + return value_.Visit(absl::Overload( + [&](absl::string_view lhs) -> absl::optional { + int64_t code_points = 0; + while (lhs.size() >= string.size()) { + if (code_points >= pos && lhs.substr(0, string.size()) == string) { + return code_points; + } + if (lhs.size() == string.size()) { + break; + } + size_t code_units = + cel::internal::Utf8Decode(lhs, /*code_point=*/nullptr); + lhs.remove_prefix(code_units); + ++code_points; + } + return absl::nullopt; + }, + [&](absl::Cord lhs) -> absl::optional { + int64_t code_points = 0; + while (lhs.size() >= string.size()) { + if (code_points >= pos && lhs.StartsWith(string)) { + return code_points; + } + if (lhs.size() == string.size()) { + break; + } + size_t code_units = cel::internal::Utf8Decode(lhs.char_begin(), + /*code_point=*/nullptr); + lhs.RemovePrefix(code_units); + ++code_points; + } + return absl::nullopt; + })); +} + +absl::optional StringValue::IndexOf(const StringValue& string, + int64_t pos) const { + return string.value_.Visit(absl::Overload( + [this, pos](absl::string_view rhs) -> absl::optional { + return IndexOf(rhs, pos); + }, + [this, pos](const absl::Cord& rhs) -> absl::optional { + return IndexOf(rhs, pos); + })); +} + +absl::optional StringValue::LastIndexOf( + absl::string_view string) const { + return value_.Visit(absl::Overload( + [&](absl::string_view lhs) -> absl::optional { + int64_t last_index = -1; + int64_t code_points = 0; + while (lhs.size() >= string.size()) { + if (absl::StartsWith(lhs, string)) { + last_index = code_points; + } + if (lhs.size() == string.size()) { + break; + } + size_t code_units = + cel::internal::Utf8Decode(lhs, /*code_point=*/nullptr); + lhs.remove_prefix(code_units); + ++code_points; + } + if (last_index < 0) return absl::nullopt; + return last_index; + }, + [&](absl::Cord lhs) -> absl::optional { + int64_t last_index = -1; + int64_t code_points = 0; + while (lhs.size() >= string.size()) { + if (lhs.StartsWith(string)) { + last_index = code_points; + } + if (lhs.size() == string.size()) { + break; + } + size_t code_units = cel::internal::Utf8Decode(lhs.char_begin(), + /*code_point=*/nullptr); + lhs.RemovePrefix(code_units); + ++code_points; + } + if (last_index < 0) return absl::nullopt; + return last_index; + })); +} + +absl::optional StringValue::LastIndexOf( + const absl::Cord& string) const { + return value_.Visit(absl::Overload( + [&](absl::string_view lhs) -> absl::optional { + int64_t last_index = -1; + int64_t code_points = 0; + while (lhs.size() >= string.size()) { + if (lhs.substr(0, string.size()) == string) { + last_index = code_points; + } + if (lhs.size() == string.size()) { + break; + } + size_t code_units = + cel::internal::Utf8Decode(lhs, /*code_point=*/nullptr); + lhs.remove_prefix(code_units); + ++code_points; + } + if (last_index < 0) return absl::nullopt; + return last_index; + }, + [&](absl::Cord lhs) -> absl::optional { + int64_t last_index = -1; + int64_t code_points = 0; + while (lhs.size() >= string.size()) { + if (lhs.StartsWith(string)) { + last_index = code_points; + } + if (lhs.size() == string.size()) { + break; + } + size_t code_units = cel::internal::Utf8Decode(lhs.char_begin(), + /*code_point=*/nullptr); + lhs.RemovePrefix(code_units); + ++code_points; + } + if (last_index < 0) return absl::nullopt; + return last_index; + })); +} + +absl::optional StringValue::LastIndexOf( + const StringValue& string) const { + return string.value_.Visit(absl::Overload( + [this](absl::string_view rhs) -> absl::optional { + return LastIndexOf(rhs); + }, + [this](const absl::Cord& rhs) -> absl::optional { + return LastIndexOf(rhs); + })); +} + +absl::optional StringValue::LastIndexOf(absl::string_view string, + int64_t pos) const { + if (pos < 0) { + return absl::nullopt; + } + return value_.Visit(absl::Overload( + [&](absl::string_view lhs) -> absl::optional { + int64_t last_index = -1; + int64_t code_points = 0; + while (lhs.size() >= string.size()) { + if (absl::StartsWith(lhs, string)) { + last_index = code_points; + } + if (code_points >= pos || lhs.size() == string.size()) { + break; + } + size_t code_units = + cel::internal::Utf8Decode(lhs, /*code_point=*/nullptr); + lhs.remove_prefix(code_units); + ++code_points; + } + if (last_index < 0) return absl::nullopt; + return last_index; + }, + [&](absl::Cord lhs) -> absl::optional { + int64_t last_index = -1; + int64_t code_points = 0; + while (lhs.size() >= string.size()) { + if (lhs.StartsWith(string)) { + last_index = code_points; + } + if (code_points >= pos || lhs.size() == string.size()) { + break; + } + size_t code_units = cel::internal::Utf8Decode(lhs.char_begin(), + /*code_point=*/nullptr); + lhs.RemovePrefix(code_units); + ++code_points; + } + if (last_index < 0) return absl::nullopt; + return last_index; + })); +} + +absl::optional StringValue::LastIndexOf(const absl::Cord& string, + int64_t pos) const { + if (pos < 0) { + return absl::nullopt; + } + return value_.Visit(absl::Overload( + [&](absl::string_view lhs) -> absl::optional { + int64_t last_index = -1; + int64_t code_points = 0; + while (lhs.size() >= string.size()) { + if (lhs.substr(0, string.size()) == string) { + last_index = code_points; + } + if (code_points >= pos || lhs.size() == string.size()) { + break; + } + size_t code_units = + cel::internal::Utf8Decode(lhs, /*code_point=*/nullptr); + lhs.remove_prefix(code_units); + ++code_points; + } + if (last_index < 0) return absl::nullopt; + return last_index; + }, + [&](absl::Cord lhs) -> absl::optional { + int64_t last_index = -1; + int64_t code_points = 0; + while (lhs.size() >= string.size()) { + if (lhs.StartsWith(string)) { + last_index = code_points; + } + if (code_points >= pos || lhs.size() == string.size()) { + break; + } + size_t code_units = cel::internal::Utf8Decode(lhs.char_begin(), + /*code_point=*/nullptr); + lhs.RemovePrefix(code_units); + ++code_points; + } + if (last_index < 0) return absl::nullopt; + return last_index; + })); +} + +absl::optional StringValue::LastIndexOf(const StringValue& string, + int64_t pos) const { + return string.value_.Visit(absl::Overload( + [this, pos](absl::string_view rhs) -> absl::optional { + return LastIndexOf(rhs, pos); + }, + [this, pos](const absl::Cord& rhs) -> absl::optional { + return LastIndexOf(rhs, pos); + })); +} + +namespace { + +absl::StatusOr SubstringImpl(absl::string_view string, uint64_t start) { + size_t size_code_points = 0; + size_t size_code_units = 0; + while (!string.empty()) { + char32_t code_point; + size_t code_units; + std::tie(code_point, code_units) = cel::internal::Utf8Decode(string); + if (size_code_points == start) { + return size_code_units; + } + string.remove_prefix(code_units); + ++size_code_points; + size_code_units += code_units; + } + if (size_code_points == start) { + return size_code_units; + } + return absl::InvalidArgumentError( + ".substring(): is greater than .size()"); +} + +absl::StatusOr SubstringImpl(const absl::Cord& cord, + uint64_t start) { + absl::Cord::CharIterator char_begin = cord.char_begin(); + absl::Cord::CharIterator char_end = cord.char_end(); + size_t size_code_points = 0; + size_t size_code_units = 0; + while (char_begin != char_end) { + char32_t code_point; + size_t code_units; + std::tie(code_point, code_units) = cel::internal::Utf8Decode(char_begin); + if (size_code_points == start) { + return cord.Subcord(size_code_units, std::numeric_limits::max()); + } + absl::Cord::Advance(&char_begin, code_units); + ++size_code_points; + size_code_units += code_units; + } + if (size_code_points == start) { + return cord; + } + return absl::InvalidArgumentError( + ".substring(): is greater than .size()"); +} + +} // namespace + +Value StringValue::Substring(int64_t start) const { + if (start < 0) { + return ErrorValue(absl::InvalidArgumentError( + ".substring(): is less than 0")); + } + if (static_cast(start) > value_.size()) { + return ErrorValue(absl::InvalidArgumentError( + ".substring(, ): or is greater than " + ".size()")); + } + if (start == 0) { + return *this; + } + switch (value_.GetKind()) { + case common_internal::ByteStringKind::kSmall: { + absl::StatusOr status_or_index = + (SubstringImpl)(value_.GetSmall(), start); + if (!status_or_index.ok()) { + return ErrorValue(std::move(status_or_index).status()); + } + StringValue result; + result.value_.rep_.header.kind = common_internal::ByteStringKind::kSmall; + result.value_.rep_.small.size = value_.rep_.small.size - *status_or_index; + std::memcpy(result.value_.rep_.small.data, + value_.rep_.small.data + *status_or_index, + result.value_.rep_.small.size); + result.value_.rep_.small.arena = value_.rep_.small.arena; + return result; + } + case common_internal::ByteStringKind::kMedium: { + absl::StatusOr status_or_index = + (SubstringImpl)(value_.GetMedium(), start); + if (!status_or_index.ok()) { + return ErrorValue(std::move(status_or_index).status()); + } + StringValue result; + result.value_.rep_.header.kind = common_internal::ByteStringKind::kMedium; + result.value_.rep_.medium.size = + value_.rep_.medium.size - *status_or_index; + result.value_.rep_.medium.data = + value_.rep_.medium.data + *status_or_index; + result.value_.rep_.medium.owner = value_.rep_.medium.owner; + common_internal::StrongRef(result.value_.GetMediumReferenceCount()); + return result; + } + case common_internal::ByteStringKind::kLarge: { + absl::StatusOr status_or_cord = + (SubstringImpl)(value_.GetLarge(), start); + if (!status_or_cord.ok()) { + return ErrorValue(std::move(status_or_cord).status()); + } + return StringValue::Wrap(*std::move(status_or_cord)); + } + } +} + +namespace { + +absl::StatusOr> SubstringImpl( + absl::string_view string, uint64_t start, uint64_t end) { + size_t size_code_points = 0; + size_t size_code_units = 0; + size_t start_code_units; + while (!string.empty()) { + if (size_code_points == start) { + start_code_units = size_code_units; + } + if (size_code_points == end) { + return std::pair{start_code_units, size_code_units}; + } + char32_t code_point; + size_t code_units; + std::tie(code_point, code_units) = cel::internal::Utf8Decode(string); + string.remove_prefix(code_units); + ++size_code_points; + size_code_units += code_units; + } + if (size_code_points == start && start == end) { + return std::pair{size_code_units, size_code_units}; + } + return absl::InvalidArgumentError( + ".substring(, ): or is greater than " + ".size()"); +} + +absl::StatusOr SubstringImpl(const absl::Cord& cord, uint64_t start, + uint64_t end) { + absl::Cord::CharIterator char_begin = cord.char_begin(); + absl::Cord::CharIterator char_end = cord.char_end(); + size_t size_code_points = 0; + size_t size_code_units = 0; + size_t start_code_units; + while (char_begin != char_end) { + if (size_code_points == start) { + start_code_units = size_code_units; + } + if (size_code_points == end) { + return cord.Subcord(start_code_units, + size_code_points - start_code_units); + } + char32_t code_point; + size_t code_units; + std::tie(code_point, code_units) = cel::internal::Utf8Decode(char_begin); + absl::Cord::Advance(&char_begin, code_units); + ++size_code_points; + size_code_units += code_units; + } + if (size_code_points == start && start == end) { + return absl::Cord(); + } + return absl::InvalidArgumentError( + ".substring(, ): or is greater than " + ".size()"); +} + +} // namespace + +Value StringValue::Substring(int64_t start, int64_t end) const { + if (start < 0) { + return ErrorValue(absl::InvalidArgumentError( + ".substring(, ): is less than 0")); + } + if (end < start) { + return ErrorValue(absl::InvalidArgumentError( + ".substring(, ): is less than ")); + } + if (static_cast(start) > value_.size() || + static_cast(end) > value_.size()) { + return ErrorValue(absl::InvalidArgumentError( + ".substring(, ): or is greater than " + ".size()")); + } + switch (value_.GetKind()) { + case common_internal::ByteStringKind::kSmall: { + absl::StatusOr> status_or_indices = + (SubstringImpl)(value_.GetSmall(), start, end); + if (!status_or_indices.ok()) { + return ErrorValue(std::move(status_or_indices).status()); + } + StringValue result; + result.value_.rep_.header.kind = common_internal::ByteStringKind::kSmall; + result.value_.rep_.small.size = + (status_or_indices->second - status_or_indices->first); + std::memcpy(result.value_.rep_.small.data, + value_.rep_.small.data + status_or_indices->first, + result.value_.rep_.small.size); + result.value_.rep_.small.arena = value_.rep_.small.arena; + return result; + } + case common_internal::ByteStringKind::kMedium: { + absl::StatusOr> status_or_indices = + (SubstringImpl)(value_.GetMedium(), start, end); + if (!status_or_indices.ok()) { + return ErrorValue(std::move(status_or_indices).status()); + } + StringValue result; + result.value_.rep_.header.kind = common_internal::ByteStringKind::kMedium; + result.value_.rep_.medium.size = + (status_or_indices->second - status_or_indices->first); + result.value_.rep_.medium.data = + value_.rep_.medium.data + status_or_indices->first; + result.value_.rep_.medium.owner = value_.rep_.medium.owner; + common_internal::StrongRef(result.value_.GetMediumReferenceCount()); + return result; + } + case common_internal::ByteStringKind::kLarge: { + absl::StatusOr status_or_cord = + (SubstringImpl)(value_.GetLarge(), start, end); + if (!status_or_cord.ok()) { + return ErrorValue(std::move(status_or_cord).status()); + } + return StringValue::Wrap(*std::move(status_or_cord)); + } + } +} + +namespace { + +bool LowerAsciiImpl(absl::string_view in, std::string* absl_nonnull out) { + if (in.empty()) { + return false; + } + bool needs_conversion = false; + for (char c : in) { + if (absl::ascii_isupper(c)) { + needs_conversion = true; + break; + } + } + + if (!needs_conversion) { + return false; + } + + *out = absl::AsciiStrToLower(in); + return true; +} + +absl::Cord LowerAsciiImpl(const absl::Cord& in) { + if (in.empty()) { + return in; + } + size_t pos = 0; + bool needs_conversion = false; + for (char c : in.Chars()) { + if (absl::ascii_isupper(c)) { + needs_conversion = true; + break; + } + pos++; + } + if (!needs_conversion) { + return in; + } + absl::Cord out = in.Subcord(0, pos); + absl::Cord rest = in.Subcord(pos, in.size() - pos); + std::string suffix; + suffix.resize(rest.size()); + size_t current = 0; + for (char c : rest.Chars()) { + suffix[current++] = absl::ascii_tolower(c); + } + out.Append(std::move(suffix)); + return out; +} + +} // namespace + +StringValue StringValue::LowerAscii(google::protobuf::Arena* absl_nonnull arena) const { + ABSL_DCHECK(arena != nullptr); + + switch (value_.GetKind()) { + case common_internal::ByteStringKind::kSmall: { + std::string out; + if (!(LowerAsciiImpl)(value_.GetSmall(), &out)) { + return *this; + } + return StringValue::From(std::move(out), arena); + } + case common_internal::ByteStringKind::kMedium: { + std::string out; + if (!(LowerAsciiImpl)(value_.GetMedium(), &out)) { + return *this; + } + return StringValue::From(std::move(out), arena); + } + case common_internal::ByteStringKind::kLarge: + return StringValue::Wrap((LowerAsciiImpl)(value_.GetLarge())); + } +} + +namespace { + +bool UpperAsciiImpl(absl::string_view in, std::string* absl_nonnull out) { + if (in.empty()) { + return false; + } + bool needs_conversion = false; + for (char c : in) { + if (absl::ascii_islower(c)) { + needs_conversion = true; + break; + } + } + + if (!needs_conversion) { + return false; + } + + *out = absl::AsciiStrToUpper(in); + return true; +} + +absl::Cord UpperAsciiImpl(const absl::Cord& in) { + if (in.empty()) { + return in; + } + size_t pos = 0; + bool needs_conversion = false; + for (char c : in.Chars()) { + if (absl::ascii_islower(c)) { + needs_conversion = true; + break; + } + pos++; + } + if (!needs_conversion) { + return in; + } + absl::Cord out = in.Subcord(0, pos); + absl::Cord rest = in.Subcord(pos, in.size() - pos); + std::string suffix; + suffix.resize(rest.size()); + size_t current = 0; + for (char c : rest.Chars()) { + suffix[current++] = absl::ascii_toupper(c); + } + out.Append(std::move(suffix)); + return out; +} + +} // namespace + +StringValue StringValue::UpperAscii(google::protobuf::Arena* absl_nonnull arena) const { + ABSL_DCHECK(arena != nullptr); + + switch (value_.GetKind()) { + case common_internal::ByteStringKind::kSmall: { + std::string out; + if (!(UpperAsciiImpl)(value_.GetSmall(), &out)) { + return *this; + } + return StringValue::From(std::move(out), arena); + } + case common_internal::ByteStringKind::kMedium: { + std::string out; + if (!(UpperAsciiImpl)(value_.GetMedium(), &out)) { + return *this; + } + return StringValue::From(std::move(out), arena); + } + case common_internal::ByteStringKind::kLarge: + return StringValue::Wrap((UpperAsciiImpl)(value_.GetLarge())); + } +} + +namespace { + +// Per CEL spec, checking for Unicode whitespace. +bool IsUnicodeWhitespace(char32_t c) { + if (c <= 0x0020) { + return c == 0x0020 || (c >= 0x0009 && c <= 0x000D); + } + if (c > 0x3000) return false; + if (c == 0x0085 || c == 0x00a0 || c == 0x1680) return true; + if (c >= 0x2000 && c <= 0x200a) return true; + return c == 0x2028 || c == 0x2029 || c == 0x202f || c == 0x205f || + c == 0x3000; +} + +std::pair TrimImpl(absl::string_view string) { + absl::string_view temp_string = string; + size_t left_trim_bytes = 0; + while (!temp_string.empty()) { + char32_t c; + size_t char_len = cel::internal::Utf8Decode(temp_string, &c); + if (!IsUnicodeWhitespace(c)) { + break; + } + temp_string.remove_prefix(char_len); + left_trim_bytes += char_len; + } + + if (left_trim_bytes == string.size()) { + return {left_trim_bytes, 0}; + } + + size_t last_non_ws_end_bytes = 0; + size_t current_pos_bytes = 0; + temp_string = string; + while (!temp_string.empty()) { + char32_t c; + size_t char_len = cel::internal::Utf8Decode(temp_string, &c); + if (!IsUnicodeWhitespace(c)) { + last_non_ws_end_bytes = current_pos_bytes + char_len; + } + current_pos_bytes += char_len; + temp_string.remove_prefix(char_len); + } + + return {left_trim_bytes, string.size() - last_non_ws_end_bytes}; +} + +absl::Cord TrimImpl(const absl::Cord& cord) { + size_t left_trim_bytes = 0; + { + absl::Cord::CharIterator begin = cord.char_begin(); + const absl::Cord::CharIterator end = cord.char_end(); + while (begin != end) { + char32_t c; + size_t char_len; + std::tie(c, char_len) = cel::internal::Utf8Decode(begin); + if (!IsUnicodeWhitespace(c)) { + break; + } + absl::Cord::Advance(&begin, char_len); + left_trim_bytes += char_len; + } + } + + if (left_trim_bytes == cord.size()) { + return absl::Cord(); + } + + absl::Cord ltrimmed = + cord.Subcord(left_trim_bytes, cord.size() - left_trim_bytes); + + size_t last_non_ws_end_bytes = 0; + size_t current_pos_bytes = 0; + { + absl::Cord::CharIterator begin = ltrimmed.char_begin(); + const absl::Cord::CharIterator end = ltrimmed.char_end(); + while (begin != end) { + char32_t c; + size_t char_len; + std::tie(c, char_len) = cel::internal::Utf8Decode(begin); + if (!IsUnicodeWhitespace(c)) { + last_non_ws_end_bytes = current_pos_bytes + char_len; + } + absl::Cord::Advance(&begin, char_len); + current_pos_bytes += char_len; + } + } + return ltrimmed.Subcord(0, last_non_ws_end_bytes); +} + +} // namespace + +StringValue StringValue::Trim() const { + switch (value_.GetKind()) { + case common_internal::ByteStringKind::kSmall: { + std::pair trims = (TrimImpl)(value_.GetSmall()); + StringValue result; + result.value_.rep_.header.kind = common_internal::ByteStringKind::kSmall; + result.value_.rep_.small.size = + value_.rep_.small.size - trims.first - trims.second; + std::memcpy(result.value_.rep_.small.data, + value_.rep_.small.data + trims.first, + result.value_.rep_.small.size); + result.value_.rep_.small.arena = value_.GetSmallArena(); + return result; + } + case common_internal::ByteStringKind::kMedium: { + std::pair trims = (TrimImpl)(value_.GetMedium()); + StringValue result; + result.value_.rep_.header.kind = common_internal::ByteStringKind::kMedium; + result.value_.rep_.medium.size = + value_.rep_.medium.size - trims.first - trims.second; + result.value_.rep_.medium.data = value_.rep_.medium.data + trims.first; + result.value_.rep_.medium.owner = value_.rep_.medium.owner; + common_internal::StrongRef(result.value_.GetMediumReferenceCount()); + return result; + } + case common_internal::ByteStringKind::kLarge: { + return StringValue::Wrap((TrimImpl)(value_.GetLarge())); + } + } +} + +namespace { + +void AppendQuoteCodePoint(char32_t code_point, std::string& dst) { + switch (code_point) { + case '\a': + dst.append("\\a"); + break; + case '\b': + dst.append("\\b"); + break; + case '\f': + dst.append("\\f"); + break; + case '\n': + dst.append("\\n"); + break; + case '\r': + dst.append("\\r"); + break; + case '\t': + dst.append("\\t"); + break; + case '\v': + dst.append("\\v"); + break; + case '\\': + dst.append("\\\\"); + break; + case '\"': + dst.append("\\\""); + break; + default: + cel::internal::Utf8Encode(code_point, &dst); + break; + } +} + +} // namespace + +StringValue StringValue::Quote(google::protobuf::Arena* absl_nonnull arena) const { + return value_.Visit(absl::Overload( + [&](absl::string_view rep) -> StringValue { + std::string result; + result.push_back('\"'); + while (!rep.empty()) { + char32_t code_point; + size_t code_units; + std::tie(code_point, code_units) = cel::internal::Utf8Decode(rep); + AppendQuoteCodePoint(code_point, result); + rep.remove_prefix(code_units); + } + result.push_back('\"'); + return StringValue::From(std::move(result), arena); + }, + [&](const absl::Cord& rep) -> StringValue { + absl::Cord::CharIterator begin = rep.char_begin(); + absl::Cord::CharIterator end = rep.char_end(); + std::string result; + result.push_back('\"'); + while (begin != end) { + char32_t code_point; + size_t code_units; + std::tie(code_point, code_units) = cel::internal::Utf8Decode(begin); + AppendQuoteCodePoint(code_point, result); + absl::Cord::Advance(&begin, code_units); + } + result.push_back('\"'); + return StringValue::From(std::move(result), arena); + })); +} + +StringValue StringValue::Reverse(google::protobuf::Arena* absl_nonnull arena) const { + return value_.Visit(absl::Overload( + [arena](absl::string_view string) -> StringValue { + if (string.empty()) { + return StringValue(); + } + std::string reversed; + reversed.reserve(string.size()); + const char* ptr = string.data() + string.size(); + const char* begin = string.data(); + while (ptr > begin) { + const char* char_end = ptr; + --ptr; + // Back up to beginning of encoded UTF-8 code point. + while (ptr > begin && (*ptr & 0xC0) == 0x80) { + --ptr; + } + reversed.append(ptr, char_end - ptr); + } + return StringValue::From(std::move(reversed), arena); + }, + [arena](const absl::Cord& cord) -> StringValue { + if (cord.empty()) { + return StringValue(); + } + std::vector code_points; + absl::Cord::CharIterator char_begin = cord.char_begin(); + absl::Cord::CharIterator char_end = cord.char_end(); + while (char_begin != char_end) { + char32_t code_point; + size_t code_units = + cel::internal::Utf8Decode(char_begin, &code_point); + code_points.push_back(code_point); + absl::Cord::Advance(&char_begin, code_units); + } + std::string reversed; + reversed.reserve(cord.size()); + for (auto it = code_points.rbegin(); it != code_points.rend(); ++it) { + cel::internal::Utf8Encode(*it, &reversed); + } + return StringValue::From(std::move(reversed), arena); + })); +} + +absl::StatusOr StringValue::Join( + const ListValue& list, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const { + Value result; + CEL_RETURN_IF_ERROR( + Join(list, descriptor_pool, message_factory, arena, &result)); + return result; +} + +absl::Status StringValue::Join( + const ListValue& list, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + std::string joined; + + CEL_ASSIGN_OR_RETURN(auto iterator, list.NewIterator()); + + CEL_ASSIGN_OR_RETURN( + absl::optional element, + iterator->Next1(descriptor_pool, message_factory, arena)); + if (element) { + if (auto string_element = element->AsString(); string_element) { + string_element->AppendToString(&joined); + } else { + ABSL_DCHECK(!element->Is()); + *result = + ErrorValue(runtime_internal::CreateNoMatchingOverloadError("join")); + return absl::OkStatus(); + } + while (true) { + CEL_ASSIGN_OR_RETURN( + element, iterator->Next1(descriptor_pool, message_factory, arena)); + if (!element) { + break; + } + AppendToString(&joined); + if (auto string_element = element->AsString(); string_element) { + string_element->AppendToString(&joined); + } else { + ABSL_DCHECK(!element->Is()); + *result = + ErrorValue(runtime_internal::CreateNoMatchingOverloadError("join")); + return absl::OkStatus(); + } + } + } + + if (joined.size() > common_internal::kSmallByteStringCapacity) { + joined.shrink_to_fit(); + } + + *result = StringValue::From(std::move(joined), arena); + return absl::OkStatus(); +} + +absl::StatusOr StringValue::Split( + const StringValue& delimiter, int64_t limit, + google::protobuf::Arena* absl_nonnull arena) const { + Value result; + CEL_RETURN_IF_ERROR(Split(delimiter, limit, arena, &result)); + return result; +} + +absl::Status StringValue::Split(const StringValue& delimiter, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const { + return Split(delimiter, -1, arena, result); +} + +absl::StatusOr StringValue::Split( + const StringValue& delimiter, google::protobuf::Arena* absl_nonnull arena) const { + Value result; + CEL_RETURN_IF_ERROR(Split(delimiter, -1, arena, &result)); + return result; +} + +absl::Status StringValue::Split(const StringValue& delimiter, int64_t limit, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const { + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + if (limit == 0) { + // Per spec, when limit is 0 return an empty list. + *result = ListValue(); + return absl::OkStatus(); + } + if (limit < 0) { + // Per spec, when limit is negative treat it as unlimited splits. + limit = std::numeric_limits::max(); + } + + std::vector> splits; + size_t pos = 0; + const size_t len = value_.size(); + + if (delimiter.IsEmpty()) { + value_.Visit(absl::Overload( + [&](absl::string_view s) { + while (pos < len && limit > 1) { + size_t char_len = cel::internal::Utf8Decode(s.substr(pos), nullptr); + splits.push_back({pos, pos + char_len}); + pos += char_len; + --limit; + } + }, + [&](const absl::Cord& s) { + while (pos < len && limit > 1) { + size_t char_len = cel::internal::Utf8Decode( + s.Subcord(pos, len - pos).char_begin(), nullptr); + splits.push_back({pos, pos + char_len}); + pos += char_len; + --limit; + } + })); + } else { + while (pos < len && limit > 1) { + absl::optional next = value_.Find(delimiter.value_, pos); + if (!next) { + break; + } + splits.push_back(std::pair{pos, *next}); + pos = *next + delimiter.value_.size(); + --limit; + ABSL_DCHECK_LE(pos, len); + } + } + + if (splits.empty() || !delimiter.IsEmpty() || pos < len) { + splits.push_back(std::pair{pos, len}); + } + + auto builder = NewListValueBuilder(arena); + builder->Reserve(splits.size()); + for (const std::pair& split : splits) { + builder->UnsafeAdd( + StringValue(value_.Substring(split.first, split.second))); + } + *result = std::move(*builder).Build(); + return absl::OkStatus(); +} + +absl::StatusOr StringValue::Replace( + const StringValue& needle, const StringValue& replacement, int64_t limit, + google::protobuf::Arena* absl_nonnull arena) const { + Value result; + CEL_RETURN_IF_ERROR(Replace(needle, replacement, limit, arena, &result)); + return result; +} + +absl::Status StringValue::Replace(const StringValue& needle, + const StringValue& replacement, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const { + return Replace(needle, replacement, -1, arena, result); +} + +absl::StatusOr StringValue::Replace( + const StringValue& needle, const StringValue& replacement, + google::protobuf::Arena* absl_nonnull arena) const { + Value result; + CEL_RETURN_IF_ERROR(Replace(needle, replacement, -1, arena, &result)); + return result; +} + +absl::Status StringValue::Replace(const StringValue& needle, + const StringValue& replacement, int64_t limit, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const { + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + if (limit == 0) { + // Per spec, when limit is 0 return the original string. + *result = *this; + return absl::OkStatus(); + } + if (limit < 0) { + // Per spec, when limit is negative treat it as unlimited replacements. + limit = std::numeric_limits::max(); + } + + size_t pos = 0; + const size_t len = value_.size(); + const size_t needle_len = needle.value_.size(); + std::string res_str; + + if (needle.IsEmpty()) { + value_.Visit(absl::Overload( + [&](absl::string_view s) { + while (pos < len && limit > 0) { + replacement.AppendToString(&res_str); + size_t char_len = cel::internal::Utf8Decode(s.substr(pos), nullptr); + value_.Substring(pos, pos + char_len).AppendToString(&res_str); + pos += char_len; + --limit; + } + }, + [&](const absl::Cord& s) { + while (pos < len && limit > 0) { + replacement.AppendToString(&res_str); + size_t char_len = cel::internal::Utf8Decode( + s.Subcord(pos, len - pos).char_begin(), nullptr); + value_.Substring(pos, pos + char_len).AppendToString(&res_str); + pos += char_len; + --limit; + } + })); + if (limit > 0) { + replacement.AppendToString(&res_str); + } + } else { + while (pos < len && limit > 0) { + absl::optional next = value_.Find(needle.value_, pos); + if (!next) { + break; + } + + value_.Substring(pos, *next).AppendToString(&res_str); + replacement.AppendToString(&res_str); + + pos = *next + needle_len; + --limit; + } + } + + if (pos < len) { + value_.Substring(pos, len).AppendToString(&res_str); + } + + if (res_str.size() > common_internal::kSmallByteStringCapacity) { + res_str.shrink_to_fit(); + } + + *result = StringValue::From(std::move(res_str), arena); + return absl::OkStatus(); +} + +Value StringValue::CharAt(int64_t pos) const { + if (pos < 0) { + return ErrorValue(absl::InvalidArgumentError( + ".charAt(): is less than 0")); + } + return value_.Visit(absl::Overload( + [this, pos](absl::string_view rep) mutable -> Value { + while (!rep.empty()) { + char32_t code_point; + size_t code_units; + std::tie(code_point, code_units) = cel::internal::Utf8Decode(rep); + if (pos == 0) { + StringValue result; + result.value_.rep_.header.kind = + common_internal::ByteStringKind::kSmall; + result.value_.rep_.small.size = cel::internal::Utf8Encode( + code_point, result.value_.rep_.small.data); + result.value_.rep_.small.arena = value_.GetArena(); + return result; + } + rep.remove_prefix(code_units); + --pos; + } + // If we exit the loop, we iterated through all the code points in + // `rep`. `pos == 0` means we were looking for a character at index + // `size()`, which is defined to return an empty string. + if (pos == 0) { + return StringValue(); + } + return ErrorValue(absl::InvalidArgumentError( + ".charAt(): is greater than .size()")); + }, + [pos](const absl::Cord& rep) mutable -> Value { + absl::Cord::CharIterator begin = rep.char_begin(); + absl::Cord::CharIterator end = rep.char_end(); + while (begin != end) { + char32_t code_point; + size_t code_units; + std::tie(code_point, code_units) = cel::internal::Utf8Decode(begin); + if (pos == 0) { + StringValue result; + result.value_.rep_.header.kind = + common_internal::ByteStringKind::kSmall; + result.value_.rep_.small.size = cel::internal::Utf8Encode( + code_point, result.value_.rep_.small.data); + result.value_.rep_.small.arena = nullptr; + return result; + } + absl::Cord::Advance(&begin, code_units); + --pos; + } + // If we exit the loop, we iterated through all the code points in + // `rep`. `pos == 0` means we were looking for a character at index + // `size()`, which is defined to return an empty string. + if (pos == 0) { + return StringValue(); + } + return ErrorValue(absl::InvalidArgumentError( + ".charAt(): is greater than .size()")); + })); +} + +} // namespace cel diff --git a/common/values/string_value.h b/common/values/string_value.h new file mode 100644 index 000000000..8045e4b3f --- /dev/null +++ b/common/values/string_value.h @@ -0,0 +1,489 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/value.h" +// IWYU pragma: friend "common/value.h" + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_STRING_VALUE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_STRING_VALUE_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "common/allocator.h" +#include "common/arena.h" +#include "common/internal/byte_string.h" +#include "common/memory.h" +#include "common/type.h" +#include "common/value_kind.h" +#include "common/values/values.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/message.h" + +namespace cel { + +class Value; +class ListValue; +class StringValue; + +namespace common_internal { +absl::string_view LegacyStringValue(const StringValue& value, bool stable, + google::protobuf::Arena* absl_nonnull arena); +} // namespace common_internal + +// `StringValue` represents values of the primitive `string` type. +class StringValue final : private common_internal::ValueMixin { + public: + static constexpr ValueKind kKind = ValueKind::kString; + + static StringValue From(const char* absl_nullable value, + google::protobuf::Arena* absl_nonnull arena + ABSL_ATTRIBUTE_LIFETIME_BOUND); + static StringValue From(absl::string_view value, + google::protobuf::Arena* absl_nonnull arena + ABSL_ATTRIBUTE_LIFETIME_BOUND); + static StringValue From(const absl::Cord& value); + static StringValue From(std::string&& value, + google::protobuf::Arena* absl_nonnull arena + ABSL_ATTRIBUTE_LIFETIME_BOUND); + + static StringValue Wrap(absl::string_view value, + google::protobuf::Arena* absl_nullable arena + ABSL_ATTRIBUTE_LIFETIME_BOUND); + static StringValue Wrap(absl::string_view value) = delete; + static StringValue Wrap(const absl::Cord& value); + static StringValue Wrap(std::string&& value) = delete; + static StringValue Wrap(std::string&& value, + google::protobuf::Arena* absl_nullable arena + ABSL_ATTRIBUTE_LIFETIME_BOUND) = delete; + + // Returns a StringValue that aliases the provided string. Caller must ensure + // the provided string outlives the use of the returned StringValue. + static StringValue WrapUnsafe(absl::string_view value); + + static StringValue Concat(const StringValue& lhs, const StringValue& rhs, + google::protobuf::Arena* absl_nonnull arena + ABSL_ATTRIBUTE_LIFETIME_BOUND); + + ABSL_DEPRECATED("Use From") + explicit StringValue(const char* absl_nullable value) : value_(value) {} + + ABSL_DEPRECATED("Use From") + explicit StringValue(absl::string_view value) : value_(value) {} + + ABSL_DEPRECATED("Use From") + explicit StringValue(const absl::Cord& value) : value_(value) {} + + ABSL_DEPRECATED("Use From") + explicit StringValue(std::string&& value) : value_(std::move(value)) {} + + ABSL_DEPRECATED("Use From") + StringValue(Allocator<> allocator, const char* absl_nullable value) + : value_(allocator, value) {} + + ABSL_DEPRECATED("Use From") + StringValue(Allocator<> allocator, absl::string_view value) + : value_(allocator, value) {} + + ABSL_DEPRECATED("Use From") + StringValue(Allocator<> allocator, const absl::Cord& value) + : value_(allocator, value) {} + + ABSL_DEPRECATED("Use From") + StringValue(Allocator<> allocator, std::string&& value) + : value_(allocator, std::move(value)) {} + + ABSL_DEPRECATED("Use Wrap") + StringValue(Borrower borrower, absl::string_view value) + : value_(borrower, value) {} + + ABSL_DEPRECATED("Use Wrap") + StringValue(Borrower borrower, const absl::Cord& value) + : value_(borrower, value) {} + + StringValue() = default; + StringValue(const StringValue&) = default; + StringValue(StringValue&&) = default; + StringValue& operator=(const StringValue&) = default; + StringValue& operator=(StringValue&&) = default; + + constexpr ValueKind kind() const { return kKind; } + + absl::string_view GetTypeName() const { return StringType::kName; } + + std::string DebugString() const; + + // See Value::SerializeTo(). + absl::Status SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const; + + // See Value::ConvertToJson(). + absl::Status ConvertToJson( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const; + + absl::Status Equal(const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const; + using ValueMixin::Equal; + + StringValue Clone(google::protobuf::Arena* absl_nonnull arena) const; + + bool IsZeroValue() const { + return NativeValue([](const auto& value) -> bool { return value.empty(); }); + } + + ABSL_DEPRECATED("Use ToString()") + std::string NativeString() const { return value_.ToString(); } + + ABSL_DEPRECATED("Use ToStringView()") + absl::string_view NativeString( + std::string& scratch + ABSL_ATTRIBUTE_LIFETIME_BOUND) const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return value_.ToStringView(&scratch); + } + + ABSL_DEPRECATED("Use ToCord()") + absl::Cord NativeCord() const { return value_.ToCord(); } + + template + ABSL_DEPRECATED("Use TryFlat()") + std::common_type_t< + std::invoke_result_t, + std::invoke_result_t> NativeValue(Visitor&& + visitor) + const { + return value_.Visit(std::forward(visitor)); + } + + void swap(StringValue& other) noexcept { + using std::swap; + swap(value_, other.value_); + } + + size_t Size() const; + + bool IsEmpty() const; + + bool Equals(absl::string_view string) const; + bool Equals(const absl::Cord& string) const; + bool Equals(const StringValue& string) const; + + int Compare(absl::string_view string) const; + int Compare(const absl::Cord& string) const; + int Compare(const StringValue& string) const; + + bool StartsWith(absl::string_view string) const; + bool StartsWith(const absl::Cord& string) const; + bool StartsWith(const StringValue& string) const; + + bool EndsWith(absl::string_view string) const; + bool EndsWith(const absl::Cord& string) const; + bool EndsWith(const StringValue& string) const; + + bool Contains(absl::string_view string) const; + bool Contains(const absl::Cord& string) const; + bool Contains(const StringValue& string) const; + + // Returns the 0-based index of the first occurrence of `string` in this + // string, or `absl::nullopt` if `string` is not found. + absl::optional IndexOf(absl::string_view string) const; + absl::optional IndexOf(const absl::Cord& string) const; + absl::optional IndexOf(const StringValue& string) const; + // Returns the 0-based index of the first occurrence of `string` in this + // string at or after `pos`, or `absl::nullopt` if `string` is not found. + absl::optional IndexOf(absl::string_view string, int64_t pos) const; + absl::optional IndexOf(const absl::Cord& string, int64_t pos) const; + absl::optional IndexOf(const StringValue& string, int64_t pos) const; + + // Returns the 0-based index of the last occurrence of `string` in this + // string, or `absl::nullopt` if `string` is not found. + absl::optional LastIndexOf(absl::string_view string) const; + absl::optional LastIndexOf(const absl::Cord& string) const; + absl::optional LastIndexOf(const StringValue& string) const; + // Returns the 0-based index of the last occurrence of `string` in this + // string at or before `pos`, or `absl::nullopt` if `string` is not found. + absl::optional LastIndexOf(absl::string_view string, + int64_t pos) const; + absl::optional LastIndexOf(const absl::Cord& string, + int64_t pos) const; + absl::optional LastIndexOf(const StringValue& string, + int64_t pos) const; + + Value Substring(int64_t start) const; + + Value Substring(int64_t start, int64_t end) const; + + // Returns a new `StringValue` with all lowercase ASCII characters + // converted to lowercase. + StringValue LowerAscii(google::protobuf::Arena* absl_nonnull arena) const; + + // Returns a new `StringValue` with all lowercase ASCII characters + // converted to uppercase. + StringValue UpperAscii(google::protobuf::Arena* absl_nonnull arena) const; + + StringValue Trim() const; + + // Returns a new `StringValue` with the string surrounded by double quotes. + StringValue Quote(google::protobuf::Arena* absl_nonnull arena) const; + + // Returns a new `StringValue` with the characters in reverse order. + StringValue Reverse(google::protobuf::Arena* absl_nonnull arena) const; + + // Joins the elements of `list` with this string using `separator` as the + // separator. + absl::Status Join(const ListValue& list, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const; + absl::StatusOr Join( + const ListValue& list, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const; + + // Splits this string on `delimiter`, returning a list of strings. If `limit` + // is provided and non-negative, the string is split into at most `limit` + // substrings. + absl::Status Split(const StringValue& delimiter, int64_t limit, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const; + absl::StatusOr Split(const StringValue& delimiter, int64_t limit, + google::protobuf::Arena* absl_nonnull arena) const; + absl::Status Split(const StringValue& delimiter, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const; + absl::StatusOr Split(const StringValue& delimiter, + google::protobuf::Arena* absl_nonnull arena) const; + + // Replaces occurrences of `needle` with `replacement`. If `limit` is provided + // and non-negative, only the first `limit` occurrences are replaced. + absl::Status Replace(const StringValue& needle, + const StringValue& replacement, int64_t limit, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const; + absl::StatusOr Replace(const StringValue& needle, + const StringValue& replacement, int64_t limit, + google::protobuf::Arena* absl_nonnull arena) const; + absl::Status Replace(const StringValue& needle, + const StringValue& replacement, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const; + absl::StatusOr Replace(const StringValue& needle, + const StringValue& replacement, + google::protobuf::Arena* absl_nonnull arena) const; + + // Returns the character at `pos` as a new `StringValue`. `pos` is a + // 0-based index based on Unicode code points. Returns `ErrorValue` if `pos` + // is out of range. + Value CharAt(int64_t pos) const; + + absl::optional TryFlat() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return value_.TryFlat(); + } + + std::string ToString() const { return value_.ToString(); } + + void CopyToString(std::string* absl_nonnull out) const { + value_.CopyToString(out); + } + + void AppendToString(std::string* absl_nonnull out) const { + value_.AppendToString(out); + } + + absl::Cord ToCord() const { return value_.ToCord(); } + + void CopyToCord(absl::Cord* absl_nonnull out) const { + value_.CopyToCord(out); + } + + void AppendToCord(absl::Cord* absl_nonnull out) const { + value_.AppendToCord(out); + } + + absl::string_view ToStringView( + std::string* absl_nonnull scratch + ABSL_ATTRIBUTE_LIFETIME_BOUND) const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return value_.ToStringView(scratch); + } + + template + friend H AbslHashValue(H state, const StringValue& string) { + return H::combine(std::move(state), string.value_); + } + + friend bool operator==(const StringValue& lhs, const StringValue& rhs) { + return lhs.value_ == rhs.value_; + } + + friend bool operator<(const StringValue& lhs, const StringValue& rhs) { + return lhs.value_ < rhs.value_; + } + + private: + friend class common_internal::ValueMixin; + friend absl::string_view common_internal::LegacyStringValue( + const StringValue& value, bool stable, google::protobuf::Arena* absl_nonnull arena); + friend struct ArenaTraits; + + explicit StringValue(common_internal::ByteString value) noexcept + : value_(std::move(value)) {} + + common_internal::ByteString value_; +}; + +inline void swap(StringValue& lhs, StringValue& rhs) noexcept { lhs.swap(rhs); } + +inline bool operator==(const StringValue& lhs, absl::string_view rhs) { + return lhs.Equals(rhs); +} + +inline bool operator==(absl::string_view lhs, const StringValue& rhs) { + return rhs == lhs; +} + +inline bool operator==(const StringValue& lhs, const absl::Cord& rhs) { + return lhs.Equals(rhs); +} + +inline bool operator==(const absl::Cord& lhs, const StringValue& rhs) { + return rhs == lhs; +} + +inline bool operator!=(const StringValue& lhs, absl::string_view rhs) { + return !operator==(lhs, rhs); +} + +inline bool operator!=(absl::string_view lhs, const StringValue& rhs) { + return !operator==(lhs, rhs); +} + +inline bool operator!=(const StringValue& lhs, const absl::Cord& rhs) { + return !operator==(lhs, rhs); +} + +inline bool operator!=(const absl::Cord& lhs, const StringValue& rhs) { + return !operator==(lhs, rhs); +} + +inline bool operator!=(const StringValue& lhs, const StringValue& rhs) { + return !operator==(lhs, rhs); +} + +inline bool operator<(const StringValue& lhs, absl::string_view rhs) { + return lhs.Compare(rhs) < 0; +} + +inline bool operator<(absl::string_view lhs, const StringValue& rhs) { + return rhs.Compare(lhs) > 0; +} + +inline bool operator<(const StringValue& lhs, const absl::Cord& rhs) { + return lhs.Compare(rhs) < 0; +} + +inline bool operator<(const absl::Cord& lhs, const StringValue& rhs) { + return rhs.Compare(lhs) > 0; +} + +inline std::ostream& operator<<(std::ostream& out, const StringValue& value) { + return out << value.DebugString(); +} + +inline StringValue StringValue::From(const char* absl_nullable value, + google::protobuf::Arena* absl_nonnull arena + ABSL_ATTRIBUTE_LIFETIME_BOUND) { + return From(absl::NullSafeStringView(value), arena); +} + +inline StringValue StringValue::From(absl::string_view value, + google::protobuf::Arena* absl_nonnull arena + ABSL_ATTRIBUTE_LIFETIME_BOUND) { + ABSL_DCHECK(arena != nullptr); + + return StringValue(arena, value); +} + +inline StringValue StringValue::From(const absl::Cord& value) { + return StringValue(value); +} + +inline StringValue StringValue::From(std::string&& value, + google::protobuf::Arena* absl_nonnull arena + ABSL_ATTRIBUTE_LIFETIME_BOUND) { + ABSL_DCHECK(arena != nullptr); + + return StringValue(arena, std::move(value)); +} + +inline StringValue StringValue::Wrap(absl::string_view value, + google::protobuf::Arena* absl_nullable arena + ABSL_ATTRIBUTE_LIFETIME_BOUND) { + ABSL_DCHECK(arena != nullptr); + + return StringValue(Borrower::Arena(arena), value); +} + +inline StringValue StringValue::WrapUnsafe(absl::string_view value) { + return StringValue(common_internal::ByteString::FromExternal(value)); +} + +inline StringValue StringValue::Wrap(const absl::Cord& value) { + return StringValue(value); +} + +namespace common_internal { + +inline absl::string_view LegacyStringValue(const StringValue& value, + bool stable, + google::protobuf::Arena* absl_nonnull arena) { + return LegacyByteString(value.value_, stable, arena); +} + +} // namespace common_internal + +template <> +struct ArenaTraits { + using constructible = std::true_type; + + static bool trivially_destructible(const StringValue& value) { + return ArenaTraits<>::trivially_destructible(value.value_); + } +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_STRING_VALUE_H_ diff --git a/common/values/string_value_test.cc b/common/values/string_value_test.cc new file mode 100644 index 000000000..201724905 --- /dev/null +++ b/common/values/string_value_test.cc @@ -0,0 +1,494 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "absl/hash/hash.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/strings/cord.h" +#include "absl/strings/cord_test_helpers.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "common/native_type.h" +#include "common/value.h" +#include "common/value_testing.h" +#include "common/values/int_value.h" +#include "internal/testing.h" +#include "runtime/internal/errors.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::testing::Eq; +using ::testing::Optional; + +using StringValueTest = common_internal::ValueTest<>; + +TEST_F(StringValueTest, Kind) { + EXPECT_EQ(StringValue("foo").kind(), StringValue::kKind); + EXPECT_EQ(Value(StringValue(absl::Cord("foo"))).kind(), StringValue::kKind); +} + +TEST_F(StringValueTest, DebugString) { + { + std::ostringstream out; + out << StringValue("foo"); + EXPECT_EQ(out.str(), "\"foo\""); + } + { + std::ostringstream out; + out << StringValue(absl::MakeFragmentedCord({"f", "o", "o"})); + EXPECT_EQ(out.str(), "\"foo\""); + } + { + std::ostringstream out; + out << Value(StringValue(absl::Cord("foo"))); + EXPECT_EQ(out.str(), "\"foo\""); + } +} + +TEST_F(StringValueTest, ConvertToJson) { + auto* message = NewArenaValueMessage(); + EXPECT_THAT(StringValue("foo").ConvertToJson(descriptor_pool(), + message_factory(), message), + IsOk()); + EXPECT_THAT(*message, EqualsValueTextProto(R"pb(string_value: "foo")pb")); +} + +TEST_F(StringValueTest, NativeValue) { + std::string scratch; + EXPECT_EQ(StringValue("foo").NativeString(), "foo"); + EXPECT_EQ(StringValue("foo").NativeString(scratch), "foo"); + EXPECT_EQ(StringValue("foo").NativeCord(), "foo"); +} + +TEST_F(StringValueTest, TryFlat) { + EXPECT_THAT(StringValue("foo").TryFlat(), Optional(Eq("foo"))); + EXPECT_THAT( + StringValue(absl::MakeFragmentedCord({"Hello, World!", "World, Hello!"})) + .TryFlat(), + Eq(absl::nullopt)); +} + +TEST_F(StringValueTest, ToString) { + EXPECT_EQ(StringValue("foo").ToString(), "foo"); + EXPECT_EQ(StringValue(absl::MakeFragmentedCord({"f", "o", "o"})).ToString(), + "foo"); +} + +TEST_F(StringValueTest, CopyToString) { + std::string out; + StringValue("foo").CopyToString(&out); + EXPECT_EQ(out, "foo"); + StringValue(absl::MakeFragmentedCord({"f", "o", "o"})).CopyToString(&out); + EXPECT_EQ(out, "foo"); +} + +TEST_F(StringValueTest, AppendToString) { + std::string out; + StringValue("foo").AppendToString(&out); + EXPECT_EQ(out, "foo"); + StringValue(absl::MakeFragmentedCord({"f", "o", "o"})).AppendToString(&out); + EXPECT_EQ(out, "foofoo"); +} + +TEST_F(StringValueTest, ToCord) { + EXPECT_EQ(StringValue("foo").ToCord(), "foo"); + EXPECT_EQ(StringValue(absl::MakeFragmentedCord({"f", "o", "o"})).ToCord(), + "foo"); +} + +TEST_F(StringValueTest, CopyToCord) { + absl::Cord out; + StringValue("foo").CopyToCord(&out); + EXPECT_EQ(out, "foo"); + StringValue(absl::MakeFragmentedCord({"f", "o", "o"})).CopyToCord(&out); + EXPECT_EQ(out, "foo"); +} + +TEST_F(StringValueTest, AppendToCord) { + absl::Cord out; + StringValue("foo").AppendToCord(&out); + EXPECT_EQ(out, "foo"); + StringValue(absl::MakeFragmentedCord({"f", "o", "o"})).AppendToCord(&out); + EXPECT_EQ(out, "foofoo"); +} + +TEST_F(StringValueTest, NativeTypeId) { + EXPECT_EQ(NativeTypeId::Of(StringValue("foo")), + NativeTypeId::For()); + EXPECT_EQ(NativeTypeId::Of(Value(StringValue(absl::Cord("foo")))), + NativeTypeId::For()); +} + +TEST_F(StringValueTest, HashValue) { + EXPECT_EQ(absl::HashOf(StringValue("foo")), + absl::HashOf(absl::string_view("foo"))); + EXPECT_EQ(absl::HashOf(StringValue(absl::string_view("foo"))), + absl::HashOf(absl::string_view("foo"))); + EXPECT_EQ(absl::HashOf(StringValue(absl::Cord("foo"))), + absl::HashOf(absl::string_view("foo"))); +} + +TEST_F(StringValueTest, Equality) { + EXPECT_NE(StringValue("foo"), "bar"); + EXPECT_NE("bar", StringValue("foo")); + EXPECT_NE(StringValue("foo"), StringValue("bar")); + EXPECT_NE(StringValue("foo"), absl::Cord("bar")); + EXPECT_NE(absl::Cord("bar"), StringValue("foo")); +} + +TEST_F(StringValueTest, LessThan) { + EXPECT_LT(StringValue("bar"), "foo"); + EXPECT_LT("bar", StringValue("foo")); + EXPECT_LT(StringValue("bar"), StringValue("foo")); + EXPECT_LT(StringValue("bar"), absl::Cord("foo")); + EXPECT_LT(absl::Cord("bar"), StringValue("foo")); +} + +TEST_F(StringValueTest, StartsWith) { + EXPECT_TRUE( + StringValue("This string is large enough to not be stored inline!") + .StartsWith(StringValue("This string is large enough"))); + EXPECT_TRUE( + StringValue("This string is large enough to not be stored inline!") + .StartsWith(StringValue(absl::Cord("This string is large enough")))); + EXPECT_TRUE( + StringValue( + absl::Cord("This string is large enough to not be stored inline!")) + .StartsWith(StringValue("This string is large enough"))); + EXPECT_TRUE( + StringValue( + absl::Cord("This string is large enough to not be stored inline!")) + .StartsWith(StringValue(absl::Cord("This string is large enough")))); +} + +TEST_F(StringValueTest, EndsWith) { + EXPECT_TRUE( + StringValue("This string is large enough to not be stored inline!") + .EndsWith(StringValue("to not be stored inline!"))); + EXPECT_TRUE( + StringValue("This string is large enough to not be stored inline!") + .EndsWith(StringValue(absl::Cord("to not be stored inline!")))); + EXPECT_TRUE( + StringValue( + absl::Cord("This string is large enough to not be stored inline!")) + .EndsWith(StringValue("to not be stored inline!"))); + EXPECT_TRUE( + StringValue( + absl::Cord("This string is large enough to not be stored inline!")) + .EndsWith(StringValue(absl::Cord("to not be stored inline!")))); +} + +TEST_F(StringValueTest, Contains) { + EXPECT_TRUE( + StringValue("This string is large enough to not be stored inline!") + .Contains(StringValue("string is large enough"))); + EXPECT_TRUE( + StringValue("This string is large enough to not be stored inline!") + .Contains(StringValue(absl::Cord("string is large enough")))); + EXPECT_TRUE( + StringValue( + absl::Cord("This string is large enough to not be stored inline!")) + .Contains(StringValue("string is large enough"))); + EXPECT_TRUE( + StringValue( + absl::Cord("This string is large enough to not be stored inline!")) + .Contains(StringValue(absl::Cord("string is large enough")))); +} + +TEST_F(StringValueTest, IndexOf) { + StringValue big_string = + StringValue("This string is large enough to not be stored inline!"); + StringValue big_string_cord = StringValue( + absl::Cord("This string is large enough to not be stored inline!")); + StringValue small_string = StringValue("is"); + StringValue small_string_cord = StringValue(absl::Cord("is")); + + EXPECT_THAT(big_string.IndexOf(small_string), Optional(Eq(2))); + EXPECT_THAT(big_string.IndexOf(small_string_cord), Optional(Eq(2))); + EXPECT_THAT(big_string_cord.IndexOf(small_string), Optional(Eq(2))); + EXPECT_THAT(big_string_cord.IndexOf(small_string_cord), Optional(Eq(2))); + + EXPECT_THAT(big_string.IndexOf("is"), Optional(Eq(2))); + EXPECT_THAT(big_string_cord.IndexOf("is"), Optional(Eq(2))); + EXPECT_THAT(big_string_cord.IndexOf("not found"), Eq(absl::nullopt)); + + EXPECT_THAT(big_string.IndexOf(small_string, 4), Optional(Eq(12))); + EXPECT_THAT(big_string.IndexOf(small_string_cord, 4), Optional(Eq(12))); + EXPECT_THAT(big_string_cord.IndexOf(small_string, 4), Optional(Eq(12))); + EXPECT_THAT(big_string_cord.IndexOf(small_string_cord, 4), Optional(Eq(12))); + + EXPECT_THAT(big_string.IndexOf("is", 4), Optional(Eq(12))); + EXPECT_THAT(big_string_cord.IndexOf("is", 4), Optional(Eq(12))); + + EXPECT_THAT(big_string.IndexOf(small_string, 13), Eq(absl::nullopt)); + EXPECT_THAT(big_string.IndexOf(small_string_cord, 13), Eq(absl::nullopt)); + EXPECT_THAT(big_string_cord.IndexOf(small_string, 13), Eq(absl::nullopt)); + EXPECT_THAT(big_string_cord.IndexOf(small_string_cord, 13), + Eq(absl::nullopt)); + + EXPECT_THAT(big_string.IndexOf(absl::Cord("is"), 4), Optional(Eq(12))); + EXPECT_THAT(big_string_cord.IndexOf(absl::Cord("is"), 4), Optional(Eq(12))); + EXPECT_THAT(big_string.IndexOf(absl::Cord("is"), 13), Eq(absl::nullopt)); + EXPECT_THAT(big_string_cord.IndexOf(absl::Cord("is"), 13), Eq(absl::nullopt)); +} + +TEST_F(StringValueTest, LowerAscii) { + EXPECT_EQ(StringValue("UPPER lower").LowerAscii(arena()), "upper lower"); + EXPECT_EQ(StringValue(absl::Cord("UPPER lower")).LowerAscii(arena()), + "upper lower"); + EXPECT_EQ(StringValue("upper lower").LowerAscii(arena()), "upper lower"); + EXPECT_EQ(StringValue(absl::Cord("upper lower")).LowerAscii(arena()), + "upper lower"); + EXPECT_EQ(StringValue("").LowerAscii(arena()), ""); + EXPECT_EQ(StringValue(absl::Cord("")).LowerAscii(arena()), ""); + const std::string kLongMixed = + "A long STRING with MiXeD case to test conversion to lower case!"; + const std::string kLongLower = + "a long string with mixed case to test conversion to lower case!"; + EXPECT_EQ(StringValue(absl::Cord(kLongMixed)).LowerAscii(arena()), + kLongLower); + std::string very_long_mixed(10000, 'A'); + std::string very_long_lower(10000, 'a'); + EXPECT_EQ( + StringValue(absl::MakeFragmentedCord({very_long_mixed.substr(0, 5000), + very_long_mixed.substr(5000)})) + .LowerAscii(arena()), + very_long_lower); + EXPECT_EQ(StringValue(absl::MakeFragmentedCord({"hello", "WORLD"})) + .LowerAscii(arena()), + "helloworld"); +} + +TEST_F(StringValueTest, UpperAscii) { + EXPECT_EQ(StringValue("UPPER lower").UpperAscii(arena()), "UPPER LOWER"); + EXPECT_EQ(StringValue(absl::Cord("UPPER lower")).UpperAscii(arena()), + "UPPER LOWER"); + EXPECT_EQ(StringValue("UPPER LOWER").UpperAscii(arena()), "UPPER LOWER"); + EXPECT_EQ(StringValue(absl::Cord("UPPER LOWER")).UpperAscii(arena()), + "UPPER LOWER"); + EXPECT_EQ(StringValue("").UpperAscii(arena()), ""); + EXPECT_EQ(StringValue(absl::Cord("")).UpperAscii(arena()), ""); + const std::string kLongMixed = + "A long STRING with MiXeD case to test conversion to UPPER case!"; + const std::string kLongUpper = + "A LONG STRING WITH MIXED CASE TO TEST CONVERSION TO UPPER CASE!"; + EXPECT_EQ(StringValue(absl::Cord(kLongMixed)).UpperAscii(arena()), + kLongUpper); + std::string very_long_mixed(10000, 'a'); + std::string very_long_upper(10000, 'A'); + EXPECT_EQ( + StringValue(absl::MakeFragmentedCord({very_long_mixed.substr(0, 5000), + very_long_mixed.substr(5000)})) + .UpperAscii(arena()), + very_long_upper); + EXPECT_EQ(StringValue(absl::MakeFragmentedCord({"HELLO", "world"})) + .UpperAscii(arena()), + "HELLOWORLD"); +} + +TEST_F(StringValueTest, LastIndexOf) { + StringValue big_string = + StringValue("This string is large enough to not be stored inline!"); + StringValue big_string_cord = StringValue( + absl::Cord("This string is large enough to not be stored inline!")); + StringValue small_string = StringValue("is"); + StringValue small_string_cord = StringValue(absl::Cord("is")); + + EXPECT_THAT(big_string.LastIndexOf(small_string), Optional(Eq(12))); + EXPECT_THAT(big_string.LastIndexOf(small_string_cord), Optional(Eq(12))); + EXPECT_THAT(big_string_cord.LastIndexOf(small_string), Optional(Eq(12))); + EXPECT_THAT(big_string_cord.LastIndexOf(small_string_cord), Optional(Eq(12))); + + EXPECT_THAT(big_string.LastIndexOf("is"), Optional(Eq(12))); + EXPECT_THAT(big_string_cord.LastIndexOf("is"), Optional(Eq(12))); + EXPECT_THAT(big_string_cord.LastIndexOf("not found"), Eq(absl::nullopt)); + + EXPECT_THAT(big_string.LastIndexOf(small_string, 4), Optional(Eq(2))); + EXPECT_THAT(big_string.LastIndexOf(small_string_cord, 4), Optional(Eq(2))); + EXPECT_THAT(big_string_cord.LastIndexOf(small_string, 4), Optional(Eq(2))); + EXPECT_THAT(big_string_cord.LastIndexOf(small_string_cord, 4), + Optional(Eq(2))); + + EXPECT_THAT(big_string.LastIndexOf("is", 4), Optional(Eq(2))); + EXPECT_THAT(big_string_cord.LastIndexOf("is", 4), Optional(Eq(2))); + + EXPECT_THAT(big_string.LastIndexOf(small_string, 100), Optional(Eq(12))); + EXPECT_THAT(big_string.LastIndexOf(small_string_cord, 100), Optional(Eq(12))); + EXPECT_THAT(big_string_cord.LastIndexOf(small_string, 100), Optional(Eq(12))); + EXPECT_THAT(big_string_cord.LastIndexOf(small_string_cord, 100), + Optional(Eq(12))); + EXPECT_THAT(big_string.LastIndexOf(absl::Cord("is"), 4), Optional(Eq(2))); + EXPECT_THAT(big_string_cord.LastIndexOf(absl::Cord("is"), 4), + Optional(Eq(2))); + EXPECT_THAT(big_string.LastIndexOf(absl::Cord("is"), 100), Optional(Eq(12))); + EXPECT_THAT(big_string_cord.LastIndexOf(absl::Cord("is"), 100), + Optional(Eq(12))); + EXPECT_THAT(big_string.LastIndexOf(absl::Cord(""), 100), Optional(Eq(52))); + EXPECT_THAT(big_string_cord.LastIndexOf(absl::Cord(""), 100), + Optional(Eq(52))); +} + +TEST_F(StringValueTest, Trim) { + using ::cel::test::StringValueIs; + StringValue unpadded = StringValue("no padding"); + StringValue front_padded = StringValue(" \t\r\nno padding"); + StringValue back_padded = StringValue("no padding \t\r\n"); + StringValue both_padded = StringValue(" \t\r\nno padding \t\r\n"); + StringValue whitespace = StringValue(" \t\r\n"); + StringValue empty = StringValue(""); + + EXPECT_THAT(unpadded.Trim(), StringValueIs("no padding")); + EXPECT_THAT(front_padded.Trim(), StringValueIs("no padding")); + EXPECT_THAT(back_padded.Trim(), StringValueIs("no padding")); + EXPECT_THAT(both_padded.Trim(), StringValueIs("no padding")); + EXPECT_THAT(whitespace.Trim(), StringValueIs("")); + EXPECT_THAT(empty.Trim(), StringValueIs("")); + + StringValue unpadded_cord = StringValue(absl::Cord("no padding")); + StringValue front_padded_cord = StringValue(absl::Cord(" \t\r\nno padding")); + StringValue back_padded_cord = StringValue(absl::Cord("no padding \t\r\n")); + StringValue both_padded_cord = + StringValue(absl::Cord(" \t\r\nno padding \t\r\n")); + StringValue whitespace_cord = StringValue(absl::Cord(" \t\r\n")); + StringValue empty_cord = StringValue(absl::Cord("")); + + EXPECT_THAT(unpadded_cord.Trim(), StringValueIs("no padding")); + EXPECT_THAT(front_padded_cord.Trim(), StringValueIs("no padding")); + EXPECT_THAT(back_padded_cord.Trim(), StringValueIs("no padding")); + EXPECT_THAT(both_padded_cord.Trim(), StringValueIs("no padding")); + EXPECT_THAT(whitespace_cord.Trim(), StringValueIs("")); + EXPECT_THAT(empty_cord.Trim(), StringValueIs("")); +} + +TEST_F(StringValueTest, CharAt) { + using ::cel::test::ErrorValueIs; + using ::cel::test::StringValueIs; + StringValue big_string = + StringValue("This string is large enough to not be stored inline!"); + StringValue big_string_cord = StringValue( + absl::Cord("This string is large enough to not be stored inline!")); + StringValue small_string = StringValue("abc"); + StringValue small_string_cord = StringValue(absl::Cord("abc")); + StringValue unicode_string = StringValue("aμc"); + StringValue unicode_string_cord = StringValue(absl::Cord("aμc")); + + EXPECT_THAT(big_string.CharAt(0), StringValueIs("T")); + EXPECT_THAT(big_string_cord.CharAt(0), StringValueIs("T")); + EXPECT_THAT(small_string.CharAt(1), StringValueIs("b")); + EXPECT_THAT(small_string_cord.CharAt(1), StringValueIs("b")); + EXPECT_THAT(unicode_string.CharAt(1), StringValueIs("μ")); + EXPECT_THAT(unicode_string_cord.CharAt(1), StringValueIs("μ")); + + EXPECT_THAT( + big_string.CharAt(100), + ErrorValueIs(absl::InvalidArgumentError( + ".charAt(): is greater than .size()"))); + EXPECT_THAT( + big_string_cord.CharAt(100), + ErrorValueIs(absl::InvalidArgumentError( + ".charAt(): is greater than .size()"))); + EXPECT_THAT(big_string.CharAt(-1), + ErrorValueIs(absl::InvalidArgumentError( + ".charAt(): is less than 0"))); + EXPECT_THAT(big_string_cord.CharAt(-1), + ErrorValueIs(absl::InvalidArgumentError( + ".charAt(): is less than 0"))); +} + +TEST_F(StringValueTest, Join) { + using ::cel::runtime_internal::CreateNoMatchingOverloadError; + using ::cel::test::ErrorValueIs; + using ::cel::test::StringValueIs; + + StringValue separator(","); + Value result; + + // Empty list. + auto list_builder0 = NewListValueBuilder(arena()); + auto list0 = std::move(*list_builder0).Build(); + EXPECT_THAT(separator.Join(list0, descriptor_pool(), message_factory(), + arena(), &result), + IsOk()); + EXPECT_THAT(result, StringValueIs("")); + + // Single element list. + auto list_builder1 = NewListValueBuilder(arena()); + ASSERT_THAT(list_builder1->Add(StringValue("foo")), IsOk()); + auto list1 = std::move(*list_builder1).Build(); + EXPECT_THAT(separator.Join(list1, descriptor_pool(), message_factory(), + arena(), &result), + IsOk()); + EXPECT_THAT(result, StringValueIs("foo")); + + // Multi element list. + auto list_builder2 = NewListValueBuilder(arena()); + ASSERT_THAT(list_builder2->Add(StringValue("foo")), IsOk()); + ASSERT_THAT(list_builder2->Add(StringValue("bar")), IsOk()); + ASSERT_THAT(list_builder2->Add(StringValue("baz")), IsOk()); + auto list2 = std::move(*list_builder2).Build(); + EXPECT_THAT(separator.Join(list2, descriptor_pool(), message_factory(), + arena(), &result), + IsOk()); + EXPECT_THAT(result, StringValueIs("foo,bar,baz")); + + // List with non-string. + auto list_builder3 = NewListValueBuilder(arena()); + ASSERT_THAT(list_builder3->Add(IntValue(1)), IsOk()); + auto list3 = std::move(*list_builder3).Build(); + EXPECT_THAT(separator.Join(list3, descriptor_pool(), message_factory(), + arena(), &result), + IsOk()); + EXPECT_THAT(result, ErrorValueIs(CreateNoMatchingOverloadError("join"))); + + // List with string and non-string. + auto list_builder4 = NewListValueBuilder(arena()); + ASSERT_THAT(list_builder4->Add(StringValue("foo")), IsOk()); + ASSERT_THAT(list_builder4->Add(IntValue(1)), IsOk()); + auto list4 = std::move(*list_builder4).Build(); + EXPECT_THAT(separator.Join(list4, descriptor_pool(), message_factory(), + arena(), &result), + IsOk()); + EXPECT_THAT(result, ErrorValueIs(CreateNoMatchingOverloadError("join"))); +} + +TEST_F(StringValueTest, Reverse) { + using ::cel::test::StringValueIs; + + EXPECT_THAT(StringValue().Reverse(arena()), StringValueIs("")); + EXPECT_THAT(StringValue("").Reverse(arena()), StringValueIs("")); + EXPECT_THAT(StringValue("hello").Reverse(arena()), StringValueIs("olleh")); + EXPECT_THAT(StringValue("aμc").Reverse(arena()), StringValueIs("cμa")); + EXPECT_THAT( + StringValue("This string is large enough to not be stored inline!") + .Reverse(arena()), + StringValueIs("!enilni derots eb ton ot hguone egral si gnirts sihT")); + EXPECT_THAT(StringValue(absl::Cord("hello")).Reverse(arena()), + StringValueIs("olleh")); + EXPECT_THAT(StringValue(absl::Cord("aμc")).Reverse(arena()), + StringValueIs("cμa")); + EXPECT_THAT( + StringValue( + absl::Cord("This string is large enough to not be stored inline!")) + .Reverse(arena()), + StringValueIs("!enilni derots eb ton ot hguone egral si gnirts sihT")); +} + +} // namespace +} // namespace cel diff --git a/common/values/struct_value.cc b/common/values/struct_value.cc new file mode 100644 index 000000000..10238a670 --- /dev/null +++ b/common/values/struct_value.cc @@ -0,0 +1,390 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "base/attribute.h" +#include "common/native_type.h" +#include "common/optional_ref.h" +#include "common/type.h" +#include "common/value.h" +#include "common/values/value_variant.h" +#include "internal/status_macros.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/message.h" + +namespace cel { + +StructType StructValue::GetRuntimeType() const { + return variant_.Visit([](const auto& alternative) -> StructType { + return alternative.GetRuntimeType(); + }); +} + +absl::string_view StructValue::GetTypeName() const { + return variant_.Visit([](const auto& alternative) -> absl::string_view { + return alternative.GetTypeName(); + }); +} + +NativeTypeId StructValue::GetTypeId() const { + return variant_.Visit([](const auto& alternative) -> NativeTypeId { + return NativeTypeId::Of(alternative); + }); +} + +std::string StructValue::DebugString() const { + return variant_.Visit([](const auto& alternative) -> std::string { + return alternative.DebugString(); + }); +} + +absl::Status StructValue::SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(output != nullptr); + + return variant_.Visit([&](const auto& alternative) -> absl::Status { + return alternative.SerializeTo(descriptor_pool, message_factory, output); + }); +} + +absl::Status StructValue::ConvertToJson( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(json != nullptr); + ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE); + + return variant_.Visit([&](const auto& alternative) -> absl::Status { + return alternative.ConvertToJson(descriptor_pool, message_factory, json); + }); +} + +absl::Status StructValue::ConvertToJsonObject( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(json != nullptr); + ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT); + + return variant_.Visit([&](const auto& alternative) -> absl::Status { + return alternative.ConvertToJsonObject(descriptor_pool, message_factory, + json); + }); +} + +absl::Status StructValue::Equal( + const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + return variant_.Visit([&](const auto& alternative) -> absl::Status { + return alternative.Equal(other, descriptor_pool, message_factory, arena, + result); + }); +} + +bool StructValue::IsZeroValue() const { + return variant_.Visit([](const auto& alternative) -> bool { + return alternative.IsZeroValue(); + }); +} + +absl::StatusOr StructValue::HasFieldByName(absl::string_view name) const { + return variant_.Visit( + [name](const auto& alternative) -> absl::StatusOr { + return alternative.HasFieldByName(name); + }); +} + +absl::StatusOr StructValue::HasFieldByNumber(int64_t number) const { + return variant_.Visit( + [number](const auto& alternative) -> absl::StatusOr { + return alternative.HasFieldByNumber(number); + }); +} + +absl::Status StructValue::GetFieldByName( + absl::string_view name, ProtoWrapperTypeOptions unboxing_options, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + return variant_.Visit([&](const auto& alternative) -> absl::Status { + return alternative.GetFieldByName(name, unboxing_options, descriptor_pool, + message_factory, arena, result); + }); +} + +absl::Status StructValue::GetFieldByNumber( + int64_t number, ProtoWrapperTypeOptions unboxing_options, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + return variant_.Visit([&](const auto& alternative) -> absl::Status { + return alternative.GetFieldByNumber(number, unboxing_options, + descriptor_pool, message_factory, arena, + result); + }); +} + +absl::Status StructValue::ForEachField( + ForEachFieldCallback callback, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + + return variant_.Visit([&](const auto& alternative) -> absl::Status { + return alternative.ForEachField(callback, descriptor_pool, message_factory, + arena); + }); +} + +absl::Status StructValue::Qualify( + absl::Span qualifiers, bool presence_test, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result, + int* absl_nonnull count) const { + ABSL_DCHECK(!qualifiers.empty()); + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + ABSL_DCHECK(count != nullptr); + + return variant_.Visit([&](const auto& alternative) -> absl::Status { + return alternative.Qualify(qualifiers, presence_test, descriptor_pool, + message_factory, arena, result, count); + }); +} + +namespace common_internal { + +absl::Status StructValueEqual( + const StructValue& lhs, const StructValue& rhs, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + if (lhs.GetTypeName() != rhs.GetTypeName()) { + *result = FalseValue(); + return absl::OkStatus(); + } + absl::flat_hash_map lhs_fields; + CEL_RETURN_IF_ERROR(lhs.ForEachField( + [&lhs_fields](absl::string_view name, + const Value& lhs_value) -> absl::StatusOr { + lhs_fields.insert_or_assign(std::string(name), Value(lhs_value)); + return true; + }, + descriptor_pool, message_factory, arena)); + bool equal = true; + size_t rhs_fields_count = 0; + CEL_RETURN_IF_ERROR(rhs.ForEachField( + [&](absl::string_view name, + const Value& rhs_value) -> absl::StatusOr { + auto lhs_field = lhs_fields.find(name); + if (lhs_field == lhs_fields.end()) { + equal = false; + return false; + } + CEL_RETURN_IF_ERROR(lhs_field->second.Equal( + rhs_value, descriptor_pool, message_factory, arena, result)); + if (result->IsFalse()) { + equal = false; + return false; + } + ++rhs_fields_count; + return true; + }, + descriptor_pool, message_factory, arena)); + if (!equal || rhs_fields_count != lhs_fields.size()) { + *result = FalseValue(); + return absl::OkStatus(); + } + *result = TrueValue(); + return absl::OkStatus(); +} + +absl::Status StructValueEqual( + const CustomStructValueInterface& lhs, const StructValue& rhs, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + if (lhs.GetTypeName() != rhs.GetTypeName()) { + *result = FalseValue(); + return absl::OkStatus(); + } + absl::flat_hash_map lhs_fields; + CEL_RETURN_IF_ERROR(lhs.ForEachField( + [&lhs_fields](absl::string_view name, + const Value& lhs_value) -> absl::StatusOr { + lhs_fields.insert_or_assign(std::string(name), Value(lhs_value)); + return true; + }, + descriptor_pool, message_factory, arena)); + bool equal = true; + size_t rhs_fields_count = 0; + CEL_RETURN_IF_ERROR(rhs.ForEachField( + [&](absl::string_view name, + const Value& rhs_value) -> absl::StatusOr { + auto lhs_field = lhs_fields.find(name); + if (lhs_field == lhs_fields.end()) { + equal = false; + return false; + } + CEL_RETURN_IF_ERROR(lhs_field->second.Equal( + rhs_value, descriptor_pool, message_factory, arena, result)); + if (result->IsFalse()) { + equal = false; + return false; + } + ++rhs_fields_count; + return true; + }, + descriptor_pool, message_factory, arena)); + if (!equal || rhs_fields_count != lhs_fields.size()) { + *result = FalseValue(); + return absl::OkStatus(); + } + *result = TrueValue(); + return absl::OkStatus(); +} + +} // namespace common_internal + +absl::optional StructValue::AsMessage() const& { + if (const auto* alternative = variant_.As(); + alternative != nullptr) { + return *alternative; + } + return absl::nullopt; +} + +absl::optional StructValue::AsMessage() && { + if (auto* alternative = variant_.As(); + alternative != nullptr) { + return std::move(*alternative); + } + return absl::nullopt; +} + +optional_ref StructValue::AsParsedMessage() const& { + if (const auto* alternative = variant_.As(); + alternative != nullptr) { + return *alternative; + } + return absl::nullopt; +} + +absl::optional StructValue::AsParsedMessage() && { + if (auto* alternative = variant_.As(); + alternative != nullptr) { + return std::move(*alternative); + } + return absl::nullopt; +} + +MessageValue StructValue::GetMessage() const& { + ABSL_DCHECK(IsMessage()) << *this; + + return variant_.Get(); +} + +MessageValue StructValue::GetMessage() && { + ABSL_DCHECK(IsMessage()) << *this; + + return std::move(variant_).Get(); +} + +const ParsedMessageValue& StructValue::GetParsedMessage() const& { + ABSL_DCHECK(IsParsedMessage()) << *this; + + return variant_.Get(); +} + +ParsedMessageValue StructValue::GetParsedMessage() && { + ABSL_DCHECK(IsParsedMessage()) << *this; + + return std::move(variant_).Get(); +} + +common_internal::ValueVariant StructValue::ToValueVariant() const& { + return variant_.Visit( + [](const auto& alternative) -> common_internal::ValueVariant { + return common_internal::ValueVariant(alternative); + }); +} + +common_internal::ValueVariant StructValue::ToValueVariant() && { + return std::move(variant_).Visit( + [](auto&& alternative) -> common_internal::ValueVariant { + // NOLINTNEXTLINE(bugprone-move-forwarding-reference) + return common_internal::ValueVariant(std::move(alternative)); + }); +} + +} // namespace cel diff --git a/common/values/struct_value.h b/common/values/struct_value.h new file mode 100644 index 000000000..d096356c7 --- /dev/null +++ b/common/values/struct_value.h @@ -0,0 +1,373 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/value.h" +// IWYU pragma: friend "common/value.h" + +// `StructValue` is the value representation of `StructType`. `StructValue` +// itself is a composed type of more specific runtime representations. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_STRUCT_VALUE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_STRUCT_VALUE_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/meta/type_traits.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "absl/utility/utility.h" +#include "base/attribute.h" +#include "common/native_type.h" +#include "common/optional_ref.h" +#include "common/type.h" +#include "common/value_kind.h" +#include "common/values/custom_struct_value.h" +#include "common/values/legacy_struct_value.h" +#include "common/values/message_value.h" +#include "common/values/parsed_message_value.h" +#include "common/values/struct_value_variant.h" +#include "common/values/values.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/message.h" + +namespace cel { + +class StructValue; +class Value; + +class StructValue final + : private common_internal::StructValueMixin { + public: + static constexpr ValueKind kKind = ValueKind::kStruct; + + template < + typename T, + typename = std::enable_if_t< + common_internal::IsStructValueAlternativeV>>> + // NOLINTNEXTLINE(google-explicit-constructor) + StructValue(T&& value) + : variant_(absl::in_place_type>, + std::forward(value)) {} + + // NOLINTNEXTLINE(google-explicit-constructor) + StructValue(const MessageValue& other) + : variant_(other.ToStructValueVariant()) {} + + // NOLINTNEXTLINE(google-explicit-constructor) + StructValue(MessageValue&& other) + : variant_(std::move(other).ToStructValueVariant()) {} + + StructValue() = default; + StructValue(const StructValue&) = default; + StructValue(StructValue&& other) = default; + StructValue& operator=(const StructValue&) = default; + StructValue& operator=(StructValue&&) = default; + + constexpr ValueKind kind() const { return kKind; } + + StructType GetRuntimeType() const; + + absl::string_view GetTypeName() const; + + NativeTypeId GetTypeId() const; + + std::string DebugString() const; + + // See Value::SerializeTo(). + absl::Status SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const; + + // See Value::ConvertToJson(). + absl::Status ConvertToJson( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const; + + // Like ConvertToJson(), except `json` **MUST** be an instance of + // `google.protobuf.Struct`. + absl::Status ConvertToJsonObject( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const; + + absl::Status Equal(const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const; + using StructValueMixin::Equal; + + bool IsZeroValue() const; + + absl::Status GetFieldByName( + absl::string_view name, ProtoWrapperTypeOptions unboxing_options, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; + using StructValueMixin::GetFieldByName; + + absl::Status GetFieldByNumber( + int64_t number, ProtoWrapperTypeOptions unboxing_options, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; + using StructValueMixin::GetFieldByNumber; + + absl::StatusOr HasFieldByName(absl::string_view name) const; + + absl::StatusOr HasFieldByNumber(int64_t number) const; + + using ForEachFieldCallback = CustomStructValueInterface::ForEachFieldCallback; + + absl::Status ForEachField( + ForEachFieldCallback callback, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const; + + absl::Status Qualify( + absl::Span qualifiers, bool presence_test, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result, + int* absl_nonnull count) const; + using StructValueMixin::Qualify; + + // Returns `true` if this value is an instance of a message value. If `true` + // is returned, it is implied that `IsOpaque()` would also return true. + bool IsMessage() const { return IsParsedMessage(); } + + // Returns `true` if this value is an instance of a parsed message value. If + // `true` is returned, it is implied that `IsMessage()` would also return + // true. + bool IsParsedMessage() const { return variant_.Is(); } + + // Convenience method for use with template metaprogramming. See + // `IsMessage()`. + template + std::enable_if_t, bool> Is() const { + return IsMessage(); + } + + // Convenience method for use with template metaprogramming. See + // `IsParsedMessage()`. + template + std::enable_if_t, bool> Is() const { + return IsParsedMessage(); + } + + // Performs a checked cast from a value to a message value, + // returning a non-empty optional with either a value or reference to the + // message value. Otherwise an empty optional is returned. + absl::optional AsMessage() & { + return std::as_const(*this).AsMessage(); + } + absl::optional AsMessage() const&; + absl::optional AsMessage() &&; + absl::optional AsMessage() const&& { return AsMessage(); } + + // Performs a checked cast from a value to a parsed message value, + // returning a non-empty optional with either a value or reference to the + // parsed message value. Otherwise an empty optional is returned. + optional_ref AsParsedMessage() & + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return std::as_const(*this).AsParsedMessage(); + } + optional_ref AsParsedMessage() + const& ABSL_ATTRIBUTE_LIFETIME_BOUND; + absl::optional AsParsedMessage() &&; + absl::optional AsParsedMessage() const&& { + return common_internal::AsOptional(AsParsedMessage()); + } + + // Convenience method for use with template metaprogramming. See + // `AsMessage()`. + template + std::enable_if_t, + absl::optional> + As() & { + return AsMessage(); + } + template + std::enable_if_t, + absl::optional> + As() const& { + return AsMessage(); + } + template + std::enable_if_t, + absl::optional> + As() && { + return std::move(*this).AsMessage(); + } + template + std::enable_if_t, + absl::optional> + As() const&& { + return std::move(*this).AsMessage(); + } + + // Convenience method for use with template metaprogramming. See + // `AsParsedMessage()`. + template + std::enable_if_t, + optional_ref> + As() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + return AsParsedMessage(); + } + template + std::enable_if_t, + optional_ref> + As() const& ABSL_ATTRIBUTE_LIFETIME_BOUND { + return AsParsedMessage(); + } + template + std::enable_if_t, + absl::optional> + As() && { + return std::move(*this).AsParsedMessage(); + } + template + std::enable_if_t, + absl::optional> + As() const&& { + return std::move(*this).AsParsedMessage(); + } + + // Performs an unchecked cast from a value to a message value. In + // debug builds a best effort is made to crash. If `IsMessage()` would return + // false, calling this method is undefined behavior. + MessageValue GetMessage() & { return std::as_const(*this).GetMessage(); } + MessageValue GetMessage() const&; + MessageValue GetMessage() &&; + MessageValue GetMessage() const&& { return GetMessage(); } + + // Performs an unchecked cast from a value to a parsed message value. In + // debug builds a best effort is made to crash. If `IsParsedMessage()` would + // return false, calling this method is undefined behavior. + const ParsedMessageValue& GetParsedMessage() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + return std::as_const(*this).GetParsedMessage(); + } + const ParsedMessageValue& GetParsedMessage() + const& ABSL_ATTRIBUTE_LIFETIME_BOUND; + ParsedMessageValue GetParsedMessage() &&; + ParsedMessageValue GetParsedMessage() const&& { return GetParsedMessage(); } + + // Convenience method for use with template metaprogramming. See + // `GetMessage()`. + template + std::enable_if_t, MessageValue> Get() & { + return GetMessage(); + } + template + std::enable_if_t, MessageValue> Get() const& { + return GetMessage(); + } + template + std::enable_if_t, MessageValue> Get() && { + return std::move(*this).GetMessage(); + } + template + std::enable_if_t, MessageValue> Get() + const&& { + return std::move(*this).GetMessage(); + } + + // Convenience method for use with template metaprogramming. See + // `GetParsedMessage()`. + template + std::enable_if_t, + const ParsedMessageValue&> + Get() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + return GetParsedMessage(); + } + template + std::enable_if_t, + const ParsedMessageValue&> + Get() const& ABSL_ATTRIBUTE_LIFETIME_BOUND { + return GetParsedMessage(); + } + template + std::enable_if_t, ParsedMessageValue> + Get() && { + return std::move(*this).GetParsedMessage(); + } + template + std::enable_if_t, ParsedMessageValue> + Get() const&& { + return std::move(*this).GetParsedMessage(); + } + + friend void swap(StructValue& lhs, StructValue& rhs) noexcept { + using std::swap; + swap(lhs.variant_, rhs.variant_); + } + + private: + friend class Value; + friend class common_internal::ValueMixin; + friend class common_internal::StructValueMixin; + + common_internal::ValueVariant ToValueVariant() const&; + common_internal::ValueVariant ToValueVariant() &&; + + // Unlike many of the other derived values, `StructValue` is itself a composed + // type. This is to avoid making `StructValue` too big and by extension + // `Value` too big. Instead we store the derived `StructValue` values in + // `Value` and not `StructValue` itself. + common_internal::StructValueVariant variant_; +}; + +inline std::ostream& operator<<(std::ostream& out, const StructValue& value) { + return out << value.DebugString(); +} + +template <> +struct NativeTypeTraits final { + static NativeTypeId Id(const StructValue& value) { return value.GetTypeId(); } +}; + +class StructValueBuilder { + public: + virtual ~StructValueBuilder() = default; + + virtual absl::StatusOr> SetFieldByName( + absl::string_view name, Value value) = 0; + + virtual absl::StatusOr> SetFieldByNumber( + int64_t number, Value value) = 0; + + virtual absl::StatusOr Build() && = 0; +}; + +using StructValueBuilderPtr = std::unique_ptr; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_STRUCT_VALUE_H_ diff --git a/common/values/struct_value_builder.cc b/common/values/struct_value_builder.cc new file mode 100644 index 000000000..446b18421 --- /dev/null +++ b/common/values/struct_value_builder.cc @@ -0,0 +1,1552 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/values/struct_value_builder.h" + +#include +#include +#include +#include +#include + +#include "google/protobuf/struct.pb.h" +#include "absl/base/nullability.h" +#include "absl/base/optimization.h" +#include "absl/functional/overload.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "common/allocator.h" +#include "common/any.h" +#include "common/memory.h" +#include "common/value.h" +#include "common/value_kind.h" +#include "common/values/value_builder.h" +#include "extensions/protobuf/internal/map_reflection.h" +#include "internal/status_macros.h" +#include "internal/well_known_types.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/zero_copy_stream_impl_lite.h" +#include "google/protobuf/message.h" + +// TODO(uncreated-issue/82): Improve test coverage for struct value builder + +// TODO(uncreated-issue/76): improve test coverage for JSON/Any + +namespace cel::common_internal { + +namespace { + +absl::StatusOr GetDescriptor( + const google::protobuf::Message& message) { + const auto* desc = message.GetDescriptor(); + if (ABSL_PREDICT_FALSE(desc == nullptr)) { + return absl::InvalidArgumentError( + absl::StrCat(message.GetTypeName(), " is missing descriptor")); + } + return desc; +} + +absl::StatusOr> ProtoMessageCopyUsingSerialization( + google::protobuf::MessageLite* to, const google::protobuf::MessageLite* from) { + ABSL_DCHECK_EQ(to->GetTypeName(), from->GetTypeName()); + absl::Cord serialized; + if (!from->SerializePartialToString(&serialized)) { + return absl::UnknownError( + absl::StrCat("failed to serialize `", from->GetTypeName(), "`")); + } + if (!to->ParsePartialFromString(serialized)) { + return absl::UnknownError( + absl::StrCat("failed to parse `", to->GetTypeName(), "`")); + } + return absl::nullopt; +} + +absl::StatusOr> ProtoMessageCopy( + google::protobuf::Message* absl_nonnull to_message, + const google::protobuf::Descriptor* absl_nonnull to_descriptor, + const google::protobuf::Message* absl_nonnull from_message) { + CEL_ASSIGN_OR_RETURN(const auto* from_descriptor, + GetDescriptor(*from_message)); + if (to_descriptor == from_descriptor) { + // Same. + to_message->CopyFrom(*from_message); + return absl::nullopt; + } + if (to_descriptor->full_name() == from_descriptor->full_name()) { + // Same type, different descriptors. + return ProtoMessageCopyUsingSerialization(to_message, from_message); + } + return TypeConversionError(from_descriptor->full_name(), + to_descriptor->full_name()); +} + +absl::StatusOr> ProtoMessageFromValueImpl( + const Value& value, const google::protobuf::DescriptorPool* absl_nonnull pool, + google::protobuf::MessageFactory* absl_nonnull factory, + well_known_types::Reflection* absl_nonnull well_known_types, + google::protobuf::Message* absl_nonnull message) { + CEL_ASSIGN_OR_RETURN(const auto* to_desc, GetDescriptor(*message)); + switch (to_desc->well_known_type()) { + case google::protobuf::Descriptor::WELLKNOWNTYPE_FLOATVALUE: { + if (auto double_value = value.AsDouble(); double_value) { + CEL_RETURN_IF_ERROR(well_known_types->FloatValue().Initialize( + message->GetDescriptor())); + well_known_types->FloatValue().SetValue( + message, static_cast(double_value->NativeValue())); + return absl::nullopt; + } + return TypeConversionError(value.GetTypeName(), to_desc->full_name()); + } + case google::protobuf::Descriptor::WELLKNOWNTYPE_DOUBLEVALUE: { + if (auto double_value = value.AsDouble(); double_value) { + CEL_RETURN_IF_ERROR(well_known_types->DoubleValue().Initialize( + message->GetDescriptor())); + well_known_types->DoubleValue().SetValue(message, + double_value->NativeValue()); + return absl::nullopt; + } + return TypeConversionError(value.GetTypeName(), to_desc->full_name()); + } + case google::protobuf::Descriptor::WELLKNOWNTYPE_INT32VALUE: { + if (auto int_value = value.AsInt(); int_value) { + if (int_value->NativeValue() < std::numeric_limits::min() || + int_value->NativeValue() > std::numeric_limits::max()) { + return ErrorValue(absl::OutOfRangeError("int64 to int32 overflow")); + } + CEL_RETURN_IF_ERROR(well_known_types->Int32Value().Initialize( + message->GetDescriptor())); + well_known_types->Int32Value().SetValue( + message, static_cast(int_value->NativeValue())); + return absl::nullopt; + } + return TypeConversionError(value.GetTypeName(), to_desc->full_name()); + } + case google::protobuf::Descriptor::WELLKNOWNTYPE_INT64VALUE: { + if (auto int_value = value.AsInt(); int_value) { + CEL_RETURN_IF_ERROR(well_known_types->Int64Value().Initialize( + message->GetDescriptor())); + well_known_types->Int64Value().SetValue(message, + int_value->NativeValue()); + return absl::nullopt; + } + return TypeConversionError(value.GetTypeName(), to_desc->full_name()); + } + case google::protobuf::Descriptor::WELLKNOWNTYPE_UINT32VALUE: { + if (auto uint_value = value.AsUint(); uint_value) { + if (uint_value->NativeValue() > std::numeric_limits::max()) { + return ErrorValue(absl::OutOfRangeError("uint64 to uint32 overflow")); + } + CEL_RETURN_IF_ERROR(well_known_types->UInt32Value().Initialize( + message->GetDescriptor())); + well_known_types->UInt32Value().SetValue( + message, static_cast(uint_value->NativeValue())); + return absl::nullopt; + } + return TypeConversionError(value.GetTypeName(), to_desc->full_name()); + } + case google::protobuf::Descriptor::WELLKNOWNTYPE_UINT64VALUE: { + if (auto uint_value = value.AsUint(); uint_value) { + CEL_RETURN_IF_ERROR(well_known_types->UInt64Value().Initialize( + message->GetDescriptor())); + well_known_types->UInt64Value().SetValue(message, + uint_value->NativeValue()); + return absl::nullopt; + } + return TypeConversionError(value.GetTypeName(), to_desc->full_name()); + } + case google::protobuf::Descriptor::WELLKNOWNTYPE_STRINGVALUE: { + if (auto string_value = value.AsString(); string_value) { + CEL_RETURN_IF_ERROR(well_known_types->StringValue().Initialize( + message->GetDescriptor())); + well_known_types->StringValue().SetValue(message, + string_value->NativeCord()); + return absl::nullopt; + } + return TypeConversionError(value.GetTypeName(), to_desc->full_name()); + } + case google::protobuf::Descriptor::WELLKNOWNTYPE_BYTESVALUE: { + if (auto bytes_value = value.AsBytes(); bytes_value) { + CEL_RETURN_IF_ERROR(well_known_types->BytesValue().Initialize( + message->GetDescriptor())); + well_known_types->BytesValue().SetValue(message, + bytes_value->NativeCord()); + return absl::nullopt; + } + return TypeConversionError(value.GetTypeName(), to_desc->full_name()); + } + case google::protobuf::Descriptor::WELLKNOWNTYPE_BOOLVALUE: { + if (auto bool_value = value.AsBool(); bool_value) { + CEL_RETURN_IF_ERROR( + well_known_types->BoolValue().Initialize(message->GetDescriptor())); + well_known_types->BoolValue().SetValue(message, + bool_value->NativeValue()); + return absl::nullopt; + } + return TypeConversionError(value.GetTypeName(), to_desc->full_name()); + } + case google::protobuf::Descriptor::WELLKNOWNTYPE_ANY: { + google::protobuf::io::CordOutputStream serialized; + CEL_RETURN_IF_ERROR(value.SerializeTo(pool, factory, &serialized)); + std::string type_url; + switch (value.kind()) { + case ValueKind::kNull: + type_url = MakeTypeUrl("google.protobuf.Value"); + break; + case ValueKind::kBool: + type_url = MakeTypeUrl("google.protobuf.BoolValue"); + break; + case ValueKind::kInt: + type_url = MakeTypeUrl("google.protobuf.Int64Value"); + break; + case ValueKind::kUint: + type_url = MakeTypeUrl("google.protobuf.UInt64Value"); + break; + case ValueKind::kDouble: + type_url = MakeTypeUrl("google.protobuf.DoubleValue"); + break; + case ValueKind::kBytes: + type_url = MakeTypeUrl("google.protobuf.BytesValue"); + break; + case ValueKind::kString: + type_url = MakeTypeUrl("google.protobuf.StringValue"); + break; + case ValueKind::kList: + type_url = MakeTypeUrl("google.protobuf.ListValue"); + break; + case ValueKind::kMap: + type_url = MakeTypeUrl("google.protobuf.Struct"); + break; + case ValueKind::kDuration: + type_url = MakeTypeUrl("google.protobuf.Duration"); + break; + case ValueKind::kTimestamp: + type_url = MakeTypeUrl("google.protobuf.Timestamp"); + break; + default: + type_url = MakeTypeUrl(value.GetTypeName()); + break; + } + CEL_RETURN_IF_ERROR( + well_known_types->Any().Initialize(message->GetDescriptor())); + well_known_types->Any().SetTypeUrl(message, type_url); + well_known_types->Any().SetValue(message, + std::move(serialized).Consume()); + return absl::nullopt; + } + case google::protobuf::Descriptor::WELLKNOWNTYPE_DURATION: { + if (auto duration_value = value.AsDuration(); duration_value) { + CEL_RETURN_IF_ERROR( + well_known_types->Duration().Initialize(message->GetDescriptor())); + CEL_RETURN_IF_ERROR(well_known_types->Duration().SetFromAbslDuration( + message, duration_value->NativeValue())); + return absl::nullopt; + } + return TypeConversionError(value.GetTypeName(), to_desc->full_name()); + } + case google::protobuf::Descriptor::WELLKNOWNTYPE_TIMESTAMP: { + if (auto timestamp_value = value.AsTimestamp(); timestamp_value) { + CEL_RETURN_IF_ERROR( + well_known_types->Timestamp().Initialize(message->GetDescriptor())); + CEL_RETURN_IF_ERROR(well_known_types->Timestamp().SetFromAbslTime( + message, timestamp_value->NativeValue())); + return absl::nullopt; + } + return TypeConversionError(value.GetTypeName(), to_desc->full_name()); + } + case google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE: { + CEL_RETURN_IF_ERROR(value.ConvertToJson(pool, factory, message)); + return absl::nullopt; + } + case google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE: { + CEL_RETURN_IF_ERROR(value.ConvertToJsonArray(pool, factory, message)); + return absl::nullopt; + } + case google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT: { + CEL_RETURN_IF_ERROR(value.ConvertToJsonObject(pool, factory, message)); + return absl::nullopt; + } + default: + break; + } + + // Not a well known type. + + // Deal with legacy values. + if (auto legacy_value = common_internal::AsLegacyStructValue(value); + legacy_value) { + const auto* from_message = legacy_value->message_ptr(); + return ProtoMessageCopy(message, to_desc, from_message); + } + + // Deal with modern values. + if (auto parsed_message_value = value.AsParsedMessage(); + parsed_message_value) { + return ProtoMessageCopy(message, to_desc, + cel::to_address(*parsed_message_value)); + } + + return TypeConversionError(value.GetTypeName(), message->GetTypeName()); +} + +// Converts a value to a specific protocol buffer map key. +using ProtoMapKeyFromValueConverter = + absl::StatusOr> (*)(const Value&, + google::protobuf::MapKey&, + std::string&); + +absl::StatusOr> ProtoBoolMapKeyFromValueConverter( + const Value& value, google::protobuf::MapKey& key, std::string&) { + if (auto bool_value = value.AsBool(); bool_value) { + key.SetBoolValue(bool_value->NativeValue()); + return absl::nullopt; + } + return TypeConversionError(value.GetTypeName(), "bool"); +} + +absl::StatusOr> ProtoInt32MapKeyFromValueConverter( + const Value& value, google::protobuf::MapKey& key, std::string&) { + if (auto int_value = value.AsInt(); int_value) { + if (int_value->NativeValue() < std::numeric_limits::min() || + int_value->NativeValue() > std::numeric_limits::max()) { + return ErrorValue(absl::OutOfRangeError("int64 to int32 overflow")); + } + key.SetInt32Value(static_cast(int_value->NativeValue())); + return absl::nullopt; + } + return TypeConversionError(value.GetTypeName(), "int"); +} + +absl::StatusOr> ProtoInt64MapKeyFromValueConverter( + const Value& value, google::protobuf::MapKey& key, std::string&) { + if (auto int_value = value.AsInt(); int_value) { + key.SetInt64Value(int_value->NativeValue()); + return absl::nullopt; + } + return TypeConversionError(value.GetTypeName(), "int"); +} + +absl::StatusOr> ProtoUInt32MapKeyFromValueConverter( + const Value& value, google::protobuf::MapKey& key, std::string&) { + if (auto uint_value = value.AsUint(); uint_value) { + if (uint_value->NativeValue() > std::numeric_limits::max()) { + return ErrorValue(absl::OutOfRangeError("uint64 to uint32 overflow")); + } + key.SetUInt32Value(static_cast(uint_value->NativeValue())); + return absl::nullopt; + } + return TypeConversionError(value.GetTypeName(), "uint"); +} + +absl::StatusOr> ProtoUInt64MapKeyFromValueConverter( + const Value& value, google::protobuf::MapKey& key, std::string&) { + if (auto uint_value = value.AsUint(); uint_value) { + key.SetUInt64Value(uint_value->NativeValue()); + return absl::nullopt; + } + return TypeConversionError(value.GetTypeName(), "uint"); +} + +absl::StatusOr> ProtoStringMapKeyFromValueConverter( + const Value& value, google::protobuf::MapKey& key, std::string& key_string) { + if (auto string_value = value.AsString(); string_value) { + key_string = string_value->NativeString(); + key.SetStringValue(key_string); + return absl::nullopt; + } + return TypeConversionError(value.GetTypeName(), "string"); +} + +// Gets the converter for converting from values to protocol buffer map key. +absl::StatusOr GetProtoMapKeyFromValueConverter( + google::protobuf::FieldDescriptor::CppType cpp_type) { + switch (cpp_type) { + case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: + return ProtoBoolMapKeyFromValueConverter; + case google::protobuf::FieldDescriptor::CPPTYPE_INT32: + return ProtoInt32MapKeyFromValueConverter; + case google::protobuf::FieldDescriptor::CPPTYPE_INT64: + return ProtoInt64MapKeyFromValueConverter; + case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: + return ProtoUInt32MapKeyFromValueConverter; + case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: + return ProtoUInt64MapKeyFromValueConverter; + case google::protobuf::FieldDescriptor::CPPTYPE_STRING: + return ProtoStringMapKeyFromValueConverter; + default: + return absl::InvalidArgumentError( + absl::StrCat("unexpected protocol buffer map key type: ", + google::protobuf::FieldDescriptor::CppTypeName(cpp_type))); + } +} + +// Converts a value to a specific protocol buffer map value. +using ProtoMapValueFromValueConverter = + absl::StatusOr> (*)( + const Value&, const google::protobuf::FieldDescriptor* absl_nonnull, + const google::protobuf::DescriptorPool* absl_nonnull, + google::protobuf::MessageFactory* absl_nonnull, + well_known_types::Reflection* absl_nonnull, google::protobuf::MapValueRef&); + +absl::StatusOr> ProtoBoolMapValueFromValueConverter( + const Value& value, const google::protobuf::FieldDescriptor* absl_nonnull, + const google::protobuf::DescriptorPool* absl_nonnull, + google::protobuf::MessageFactory* absl_nonnull, + well_known_types::Reflection* absl_nonnull, + google::protobuf::MapValueRef& value_ref) { + if (auto bool_value = value.AsBool(); bool_value) { + value_ref.SetBoolValue(bool_value->NativeValue()); + return absl::nullopt; + } + return TypeConversionError(value.GetTypeName(), "bool"); +} + +absl::StatusOr> ProtoInt32MapValueFromValueConverter( + const Value& value, const google::protobuf::FieldDescriptor* absl_nonnull, + const google::protobuf::DescriptorPool* absl_nonnull, + google::protobuf::MessageFactory* absl_nonnull, + well_known_types::Reflection* absl_nonnull, + google::protobuf::MapValueRef& value_ref) { + if (auto int_value = value.AsInt(); int_value) { + if (int_value->NativeValue() < std::numeric_limits::min() || + int_value->NativeValue() > std::numeric_limits::max()) { + return ErrorValue(absl::OutOfRangeError("int64 to int32 overflow")); + } + value_ref.SetInt32Value(static_cast(int_value->NativeValue())); + return absl::nullopt; + } + return TypeConversionError(value.GetTypeName(), "int"); +} + +absl::StatusOr> ProtoInt64MapValueFromValueConverter( + const Value& value, const google::protobuf::FieldDescriptor* absl_nonnull, + const google::protobuf::DescriptorPool* absl_nonnull, + google::protobuf::MessageFactory* absl_nonnull, + well_known_types::Reflection* absl_nonnull, + google::protobuf::MapValueRef& value_ref) { + if (auto int_value = value.AsInt(); int_value) { + value_ref.SetInt64Value(int_value->NativeValue()); + return absl::nullopt; + } + return TypeConversionError(value.GetTypeName(), "int"); +} + +absl::StatusOr> +ProtoUInt32MapValueFromValueConverter( + const Value& value, const google::protobuf::FieldDescriptor* absl_nonnull, + const google::protobuf::DescriptorPool* absl_nonnull, + google::protobuf::MessageFactory* absl_nonnull, + well_known_types::Reflection* absl_nonnull, + google::protobuf::MapValueRef& value_ref) { + if (auto uint_value = value.AsUint(); uint_value) { + if (uint_value->NativeValue() > std::numeric_limits::max()) { + return ErrorValue(absl::OutOfRangeError("uint64 to uint32 overflow")); + } + value_ref.SetUInt32Value(static_cast(uint_value->NativeValue())); + return absl::nullopt; + } + return TypeConversionError(value.GetTypeName(), "uint"); +} + +absl::StatusOr> +ProtoUInt64MapValueFromValueConverter( + const Value& value, const google::protobuf::FieldDescriptor* absl_nonnull, + const google::protobuf::DescriptorPool* absl_nonnull, + google::protobuf::MessageFactory* absl_nonnull, + well_known_types::Reflection* absl_nonnull, + google::protobuf::MapValueRef& value_ref) { + if (auto uint_value = value.AsUint(); uint_value) { + value_ref.SetUInt64Value(uint_value->NativeValue()); + return absl::nullopt; + } + return TypeConversionError(value.GetTypeName(), "uint"); +} + +absl::StatusOr> ProtoFloatMapValueFromValueConverter( + const Value& value, const google::protobuf::FieldDescriptor* absl_nonnull, + const google::protobuf::DescriptorPool* absl_nonnull, + google::protobuf::MessageFactory* absl_nonnull, + well_known_types::Reflection* absl_nonnull, + google::protobuf::MapValueRef& value_ref) { + if (auto double_value = value.AsDouble(); double_value) { + value_ref.SetFloatValue(double_value->NativeValue()); + return absl::nullopt; + } + return TypeConversionError(value.GetTypeName(), "double"); +} + +absl::StatusOr> +ProtoDoubleMapValueFromValueConverter( + const Value& value, const google::protobuf::FieldDescriptor* absl_nonnull, + const google::protobuf::DescriptorPool* absl_nonnull, + google::protobuf::MessageFactory* absl_nonnull, + well_known_types::Reflection* absl_nonnull, + google::protobuf::MapValueRef& value_ref) { + if (auto double_value = value.AsDouble(); double_value) { + value_ref.SetDoubleValue(double_value->NativeValue()); + return absl::nullopt; + } + return TypeConversionError(value.GetTypeName(), "double"); +} + +absl::StatusOr> ProtoBytesMapValueFromValueConverter( + const Value& value, const google::protobuf::FieldDescriptor* absl_nonnull, + const google::protobuf::DescriptorPool* absl_nonnull, + google::protobuf::MessageFactory* absl_nonnull, + well_known_types::Reflection* absl_nonnull, + google::protobuf::MapValueRef& value_ref) { + if (auto bytes_value = value.AsBytes(); bytes_value) { + value_ref.SetStringValue(bytes_value->NativeString()); + return absl::nullopt; + } + return TypeConversionError(value.GetTypeName(), "bytes"); +} + +absl::StatusOr> +ProtoStringMapValueFromValueConverter( + const Value& value, const google::protobuf::FieldDescriptor* absl_nonnull, + const google::protobuf::DescriptorPool* absl_nonnull, + google::protobuf::MessageFactory* absl_nonnull, + well_known_types::Reflection* absl_nonnull, + google::protobuf::MapValueRef& value_ref) { + if (auto string_value = value.AsString(); string_value) { + value_ref.SetStringValue(string_value->NativeString()); + return absl::nullopt; + } + return TypeConversionError(value.GetTypeName(), "string"); +} + +absl::StatusOr> ProtoNullMapValueFromValueConverter( + const Value& value, const google::protobuf::FieldDescriptor* absl_nonnull, + const google::protobuf::DescriptorPool* absl_nonnull, + google::protobuf::MessageFactory* absl_nonnull, + well_known_types::Reflection* absl_nonnull, + google::protobuf::MapValueRef& value_ref) { + if (value.IsNull() || value.IsInt()) { + value_ref.SetEnumValue(0); + return absl::nullopt; + } + return TypeConversionError(value.GetTypeName(), "google.protobuf.NullValue"); +} + +absl::StatusOr> ProtoEnumMapValueFromValueConverter( + const Value& value, const google::protobuf::FieldDescriptor* absl_nonnull field, + const google::protobuf::DescriptorPool* absl_nonnull, + google::protobuf::MessageFactory* absl_nonnull, + well_known_types::Reflection* absl_nonnull, + google::protobuf::MapValueRef& value_ref) { + if (auto int_value = value.AsInt(); int_value) { + if (int_value->NativeValue() < std::numeric_limits::min() || + int_value->NativeValue() > std::numeric_limits::max()) { + return ErrorValue(absl::OutOfRangeError("int64 to int32 overflow")); + } + value_ref.SetEnumValue(static_cast(int_value->NativeValue())); + return absl::nullopt; + } + return TypeConversionError(value.GetTypeName(), "enum"); +} + +absl::StatusOr> +ProtoMessageMapValueFromValueConverter( + const Value& value, const google::protobuf::FieldDescriptor* absl_nonnull, + const google::protobuf::DescriptorPool* absl_nonnull pool, + google::protobuf::MessageFactory* absl_nonnull factory, + well_known_types::Reflection* absl_nonnull well_known_types, + google::protobuf::MapValueRef& value_ref) { + return ProtoMessageFromValueImpl(value, pool, factory, well_known_types, + value_ref.MutableMessageValue()); +} + +// Gets the converter for converting from values to protocol buffer map value. +absl::StatusOr +GetProtoMapValueFromValueConverter( + const google::protobuf::FieldDescriptor* absl_nonnull field) { + ABSL_DCHECK(field->is_map()); + const auto* value_field = field->message_type()->map_value(); + switch (value_field->cpp_type()) { + case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: + return ProtoBoolMapValueFromValueConverter; + case google::protobuf::FieldDescriptor::CPPTYPE_INT32: + return ProtoInt32MapValueFromValueConverter; + case google::protobuf::FieldDescriptor::CPPTYPE_INT64: + return ProtoInt64MapValueFromValueConverter; + case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: + return ProtoUInt32MapValueFromValueConverter; + case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: + return ProtoUInt64MapValueFromValueConverter; + case google::protobuf::FieldDescriptor::CPPTYPE_FLOAT: + return ProtoFloatMapValueFromValueConverter; + case google::protobuf::FieldDescriptor::CPPTYPE_DOUBLE: + return ProtoDoubleMapValueFromValueConverter; + case google::protobuf::FieldDescriptor::CPPTYPE_STRING: + if (value_field->type() == google::protobuf::FieldDescriptor::TYPE_BYTES) { + return ProtoBytesMapValueFromValueConverter; + } + return ProtoStringMapValueFromValueConverter; + case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: + if (value_field->enum_type()->full_name() == + "google.protobuf.NullValue") { + return ProtoNullMapValueFromValueConverter; + } + return ProtoEnumMapValueFromValueConverter; + case google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE: + return ProtoMessageMapValueFromValueConverter; + default: + return absl::InvalidArgumentError(absl::StrCat( + "unexpected protocol buffer map value type: ", + google::protobuf::FieldDescriptor::CppTypeName(value_field->cpp_type()))); + } +} + +using ProtoRepeatedFieldFromValueMutator = + absl::StatusOr> (*)( + const google::protobuf::DescriptorPool* absl_nonnull, + google::protobuf::MessageFactory* absl_nonnull, + well_known_types::Reflection* absl_nonnull, + const google::protobuf::Reflection* absl_nonnull, google::protobuf::Message* absl_nonnull, + const google::protobuf::FieldDescriptor* absl_nonnull, const Value&); + +absl::StatusOr> +ProtoBoolRepeatedFieldFromValueMutator( + const google::protobuf::DescriptorPool* absl_nonnull, + google::protobuf::MessageFactory* absl_nonnull, + well_known_types::Reflection* absl_nonnull, + const google::protobuf::Reflection* absl_nonnull reflection, + google::protobuf::Message* absl_nonnull message, + const google::protobuf::FieldDescriptor* absl_nonnull field, const Value& value) { + if (auto bool_value = value.AsBool(); bool_value) { + reflection->AddBool(message, field, bool_value->NativeValue()); + return absl::nullopt; + } + return TypeConversionError(value.GetTypeName(), "bool"); +} + +absl::StatusOr> +ProtoInt32RepeatedFieldFromValueMutator( + const google::protobuf::DescriptorPool* absl_nonnull, + google::protobuf::MessageFactory* absl_nonnull, + well_known_types::Reflection* absl_nonnull, + const google::protobuf::Reflection* absl_nonnull reflection, + google::protobuf::Message* absl_nonnull message, + const google::protobuf::FieldDescriptor* absl_nonnull field, const Value& value) { + if (auto int_value = value.AsInt(); int_value) { + if (int_value->NativeValue() < std::numeric_limits::min() || + int_value->NativeValue() > std::numeric_limits::max()) { + return ErrorValue(absl::OutOfRangeError("int64 to int32 overflow")); + } + reflection->AddInt32(message, field, + static_cast(int_value->NativeValue())); + return absl::nullopt; + } + return TypeConversionError(value.GetTypeName(), "int"); +} + +absl::StatusOr> +ProtoInt64RepeatedFieldFromValueMutator( + const google::protobuf::DescriptorPool* absl_nonnull, + google::protobuf::MessageFactory* absl_nonnull, + well_known_types::Reflection* absl_nonnull, + const google::protobuf::Reflection* absl_nonnull reflection, + google::protobuf::Message* absl_nonnull message, + const google::protobuf::FieldDescriptor* absl_nonnull field, const Value& value) { + if (auto int_value = value.AsInt(); int_value) { + reflection->AddInt64(message, field, int_value->NativeValue()); + return absl::nullopt; + } + return TypeConversionError(value.GetTypeName(), "int"); +} + +absl::StatusOr> +ProtoUInt32RepeatedFieldFromValueMutator( + const google::protobuf::DescriptorPool* absl_nonnull, + google::protobuf::MessageFactory* absl_nonnull, + well_known_types::Reflection* absl_nonnull, + const google::protobuf::Reflection* absl_nonnull reflection, + google::protobuf::Message* absl_nonnull message, + const google::protobuf::FieldDescriptor* absl_nonnull field, const Value& value) { + if (auto uint_value = value.AsUint(); uint_value) { + if (uint_value->NativeValue() > std::numeric_limits::max()) { + return ErrorValue(absl::OutOfRangeError("uint64 to uint32 overflow")); + } + reflection->AddUInt32(message, field, + static_cast(uint_value->NativeValue())); + return absl::nullopt; + } + return TypeConversionError(value.GetTypeName(), "uint"); +} + +absl::StatusOr> +ProtoUInt64RepeatedFieldFromValueMutator( + const google::protobuf::DescriptorPool* absl_nonnull, + google::protobuf::MessageFactory* absl_nonnull, + well_known_types::Reflection* absl_nonnull, + const google::protobuf::Reflection* absl_nonnull reflection, + google::protobuf::Message* absl_nonnull message, + const google::protobuf::FieldDescriptor* absl_nonnull field, const Value& value) { + if (auto uint_value = value.AsUint(); uint_value) { + reflection->AddUInt64(message, field, uint_value->NativeValue()); + return absl::nullopt; + } + return TypeConversionError(value.GetTypeName(), "uint"); +} + +absl::StatusOr> +ProtoFloatRepeatedFieldFromValueMutator( + const google::protobuf::DescriptorPool* absl_nonnull, + google::protobuf::MessageFactory* absl_nonnull, + well_known_types::Reflection* absl_nonnull, + const google::protobuf::Reflection* absl_nonnull reflection, + google::protobuf::Message* absl_nonnull message, + const google::protobuf::FieldDescriptor* absl_nonnull field, const Value& value) { + if (auto double_value = value.AsDouble(); double_value) { + reflection->AddFloat(message, field, + static_cast(double_value->NativeValue())); + return absl::nullopt; + } + return TypeConversionError(value.GetTypeName(), "double"); +} + +absl::StatusOr> +ProtoDoubleRepeatedFieldFromValueMutator( + const google::protobuf::DescriptorPool* absl_nonnull, + google::protobuf::MessageFactory* absl_nonnull, + well_known_types::Reflection* absl_nonnull, + const google::protobuf::Reflection* absl_nonnull reflection, + google::protobuf::Message* absl_nonnull message, + const google::protobuf::FieldDescriptor* absl_nonnull field, const Value& value) { + if (auto double_value = value.AsDouble(); double_value) { + reflection->AddDouble(message, field, double_value->NativeValue()); + return absl::nullopt; + } + return TypeConversionError(value.GetTypeName(), "double"); +} + +absl::StatusOr> +ProtoBytesRepeatedFieldFromValueMutator( + const google::protobuf::DescriptorPool* absl_nonnull, + google::protobuf::MessageFactory* absl_nonnull, + well_known_types::Reflection* absl_nonnull, + const google::protobuf::Reflection* absl_nonnull reflection, + google::protobuf::Message* absl_nonnull message, + const google::protobuf::FieldDescriptor* absl_nonnull field, const Value& value) { + if (auto bytes_value = value.AsBytes(); bytes_value) { + reflection->AddString(message, field, bytes_value->NativeString()); + return absl::nullopt; + } + return TypeConversionError(value.GetTypeName(), "bytes"); +} + +absl::StatusOr> +ProtoStringRepeatedFieldFromValueMutator( + const google::protobuf::DescriptorPool* absl_nonnull, + google::protobuf::MessageFactory* absl_nonnull, + well_known_types::Reflection* absl_nonnull, + const google::protobuf::Reflection* absl_nonnull reflection, + google::protobuf::Message* absl_nonnull message, + const google::protobuf::FieldDescriptor* absl_nonnull field, const Value& value) { + if (auto string_value = value.AsString(); string_value) { + reflection->AddString(message, field, string_value->NativeString()); + return absl::nullopt; + } + return TypeConversionError(value.GetTypeName(), "string"); +} + +absl::StatusOr> +ProtoNullRepeatedFieldFromValueMutator( + const google::protobuf::DescriptorPool* absl_nonnull, + google::protobuf::MessageFactory* absl_nonnull, + well_known_types::Reflection* absl_nonnull, + const google::protobuf::Reflection* absl_nonnull reflection, + google::protobuf::Message* absl_nonnull message, + const google::protobuf::FieldDescriptor* absl_nonnull field, const Value& value) { + if (value.IsNull() || value.IsInt()) { + reflection->AddEnumValue(message, field, 0); + return absl::nullopt; + } + return TypeConversionError(value.GetTypeName(), "null_type"); +} + +absl::StatusOr> +ProtoEnumRepeatedFieldFromValueMutator( + const google::protobuf::DescriptorPool* absl_nonnull, + google::protobuf::MessageFactory* absl_nonnull, + well_known_types::Reflection* absl_nonnull, + const google::protobuf::Reflection* absl_nonnull reflection, + google::protobuf::Message* absl_nonnull message, + const google::protobuf::FieldDescriptor* absl_nonnull field, const Value& value) { + const auto* enum_descriptor = field->enum_type(); + if (auto int_value = value.AsInt(); int_value) { + if (int_value->NativeValue() < std::numeric_limits::min() || + int_value->NativeValue() > std::numeric_limits::max()) { + return TypeConversionError(value.GetTypeName(), + enum_descriptor->full_name()); + } + reflection->AddEnumValue(message, field, + static_cast(int_value->NativeValue())); + return absl::nullopt; + } + return TypeConversionError(value.GetTypeName(), enum_descriptor->full_name()); +} + +absl::StatusOr> +ProtoMessageRepeatedFieldFromValueMutator( + const google::protobuf::DescriptorPool* absl_nonnull pool, + google::protobuf::MessageFactory* absl_nonnull factory, + well_known_types::Reflection* absl_nonnull well_known_types, + const google::protobuf::Reflection* absl_nonnull reflection, + google::protobuf::Message* absl_nonnull message, + const google::protobuf::FieldDescriptor* absl_nonnull field, const Value& value) { + // If the value is null and the target repeated field is anything except + // google.protobuf.{Any,ListValue,Struct,Value}, it should be pruned. + if (value.IsNull()) { + const auto well_known_type = field->message_type()->well_known_type(); + if (well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE && + well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE && + well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT && + well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_ANY) { + return absl::nullopt; + } + } + auto* element = reflection->AddMessage(message, field, factory); + auto result = ProtoMessageFromValueImpl(value, pool, factory, + well_known_types, element); + if (!result.ok() || result->has_value()) { + reflection->RemoveLast(message, field); + } + return result; +} + +absl::StatusOr +GetProtoRepeatedFieldFromValueMutator( + const google::protobuf::FieldDescriptor* absl_nonnull field) { + ABSL_DCHECK(!field->is_map()); + ABSL_DCHECK(field->is_repeated()); + switch (field->cpp_type()) { + case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: + return ProtoBoolRepeatedFieldFromValueMutator; + case google::protobuf::FieldDescriptor::CPPTYPE_INT32: + return ProtoInt32RepeatedFieldFromValueMutator; + case google::protobuf::FieldDescriptor::CPPTYPE_INT64: + return ProtoInt64RepeatedFieldFromValueMutator; + case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: + return ProtoUInt32RepeatedFieldFromValueMutator; + case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: + return ProtoUInt64RepeatedFieldFromValueMutator; + case google::protobuf::FieldDescriptor::CPPTYPE_FLOAT: + return ProtoFloatRepeatedFieldFromValueMutator; + case google::protobuf::FieldDescriptor::CPPTYPE_DOUBLE: + return ProtoDoubleRepeatedFieldFromValueMutator; + case google::protobuf::FieldDescriptor::CPPTYPE_STRING: + if (field->type() == google::protobuf::FieldDescriptor::TYPE_BYTES) { + return ProtoBytesRepeatedFieldFromValueMutator; + } + return ProtoStringRepeatedFieldFromValueMutator; + case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: + if (field->enum_type()->full_name() == "google.protobuf.NullValue") { + return ProtoNullRepeatedFieldFromValueMutator; + } + return ProtoEnumRepeatedFieldFromValueMutator; + case google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE: + return ProtoMessageRepeatedFieldFromValueMutator; + default: + return absl::InvalidArgumentError(absl::StrCat( + "unexpected protocol buffer repeated field type: ", + google::protobuf::FieldDescriptor::CppTypeName(field->cpp_type()))); + } +} + +class MessageValueBuilderImpl { + public: + MessageValueBuilderImpl( + google::protobuf::Arena* absl_nullable arena, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull message) + : arena_(arena), + descriptor_pool_(descriptor_pool), + message_factory_(message_factory), + message_(message), + descriptor_(message_->GetDescriptor()), + reflection_(message_->GetReflection()) {} + + ~MessageValueBuilderImpl() { + if (arena_ == nullptr && message_ != nullptr) { + delete message_; + } + } + + absl::StatusOr> SetFieldByName( + absl::string_view name, Value value) { + const auto* field = descriptor_->FindFieldByName(name); + if (field == nullptr) { + field = descriptor_pool_->FindExtensionByPrintableName(descriptor_, name); + if (field == nullptr) { + return NoSuchFieldError(name); + } + } + return SetField(field, std::move(value)); + } + + absl::StatusOr> SetFieldByNumber(int64_t number, + Value value) { + if (number < std::numeric_limits::min() || + number > std::numeric_limits::max()) { + return NoSuchFieldError(absl::StrCat(number)); + } + const auto* field = + descriptor_->FindFieldByNumber(static_cast(number)); + if (field == nullptr) { + return NoSuchFieldError(absl::StrCat(number)); + } + return SetField(field, std::move(value)); + } + + absl::StatusOr Build() && { + return Value::WrapMessage(std::exchange(message_, nullptr), + descriptor_pool_, message_factory_, arena_); + } + + absl::StatusOr BuildStruct() && { + return ParsedMessageValue(std::exchange(message_, nullptr), arena_); + } + + private: + absl::StatusOr> SetMapField( + const google::protobuf::FieldDescriptor* absl_nonnull field, Value value) { + auto map_value = value.AsMap(); + if (!map_value) { + return TypeConversionError(value.GetTypeName(), "map"); + } + CEL_ASSIGN_OR_RETURN(auto key_converter, + GetProtoMapKeyFromValueConverter( + field->message_type()->map_key()->cpp_type())); + CEL_ASSIGN_OR_RETURN(auto value_converter, + GetProtoMapValueFromValueConverter(field)); + reflection_->ClearField(message_, field); + const auto* map_value_field = field->message_type()->map_value(); + absl::optional error_value; + // Don't replace this pattern with a status macro; nested macro invocations + // have the same __LINE__ on MSVC, causing CEL_ASSIGN_OR_RETURN invocations + // to conflict with each-other. + auto status = map_value->ForEach( + [this, field, key_converter, map_value_field, value_converter, + &error_value](const Value& entry_key, + const Value& entry_value) -> absl::StatusOr { + std::string proto_key_string; + google::protobuf::MapKey proto_key; + CEL_ASSIGN_OR_RETURN( + error_value, + (*key_converter)(entry_key, proto_key, proto_key_string)); + if (error_value) { + return false; + } + if (map_value_field->cpp_type() == + google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE && + entry_value.IsNull()) { + auto well_known_type = + map_value_field->message_type()->well_known_type(); + if (well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_ANY && + well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE && + well_known_type != + google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE && + well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT) { + return true; + } + } + google::protobuf::MapValueRef proto_value; + extensions::protobuf_internal::InsertOrLookupMapValue( + *reflection_, message_, *field, proto_key, &proto_value); + CEL_ASSIGN_OR_RETURN( + error_value, + (*value_converter)(entry_value, map_value_field, descriptor_pool_, + message_factory_, &well_known_types_, + proto_value)); + if (error_value) { + return false; + } + return true; + }, + descriptor_pool_, message_factory_, arena_); + if (!status.ok()) { + return status; + } + return error_value; + } + + absl::StatusOr> SetRepeatedField( + const google::protobuf::FieldDescriptor* absl_nonnull field, Value value) { + auto list_value = value.AsList(); + if (!list_value) { + return TypeConversionError(value.GetTypeName(), "list").NativeValue(); + } + CEL_ASSIGN_OR_RETURN(auto accessor, + GetProtoRepeatedFieldFromValueMutator(field)); + reflection_->ClearField(message_, field); + absl::optional error_value; + CEL_RETURN_IF_ERROR(list_value->ForEach( + [this, field, accessor, + &error_value](const Value& element) -> absl::StatusOr { + if (field->message_type() != nullptr && element.IsNull()) { + auto well_known_type = field->message_type()->well_known_type(); + if (well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_ANY && + well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE && + well_known_type != + google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE && + well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT) { + return true; + } + } + CEL_ASSIGN_OR_RETURN(error_value, + (*accessor)(descriptor_pool_, message_factory_, + &well_known_types_, reflection_, + message_, field, element)); + return !error_value; + }, + descriptor_pool_, message_factory_, arena_)); + return error_value; + } + + absl::StatusOr> SetSingularField( + const google::protobuf::FieldDescriptor* absl_nonnull field, Value value) { + switch (field->cpp_type()) { + case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: { + if (auto bool_value = value.AsBool(); bool_value) { + reflection_->SetBool(message_, field, bool_value->NativeValue()); + return absl::nullopt; + } + return TypeConversionError(value.GetTypeName(), "bool"); + } + case google::protobuf::FieldDescriptor::CPPTYPE_INT32: { + if (auto int_value = value.AsInt(); int_value) { + if (int_value->NativeValue() < std::numeric_limits::min() || + int_value->NativeValue() > std::numeric_limits::max()) { + return ErrorValue(absl::OutOfRangeError("int64 to int32 overflow")); + } + reflection_->SetInt32(message_, field, + static_cast(int_value->NativeValue())); + return absl::nullopt; + } + return TypeConversionError(value.GetTypeName(), "int"); + } + case google::protobuf::FieldDescriptor::CPPTYPE_INT64: { + if (auto int_value = value.AsInt(); int_value) { + reflection_->SetInt64(message_, field, int_value->NativeValue()); + return absl::nullopt; + } + return TypeConversionError(value.GetTypeName(), "int"); + } + case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: { + if (auto uint_value = value.AsUint(); uint_value) { + if (uint_value->NativeValue() > + std::numeric_limits::max()) { + return ErrorValue( + absl::OutOfRangeError("uint64 to uint32 overflow")); + } + reflection_->SetUInt32( + message_, field, + static_cast(uint_value->NativeValue())); + return absl::nullopt; + } + return TypeConversionError(value.GetTypeName(), "uint"); + } + case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: { + if (auto uint_value = value.AsUint(); uint_value) { + reflection_->SetUInt64(message_, field, uint_value->NativeValue()); + return absl::nullopt; + } + return TypeConversionError(value.GetTypeName(), "uint"); + } + case google::protobuf::FieldDescriptor::CPPTYPE_FLOAT: { + if (auto double_value = value.AsDouble(); double_value) { + reflection_->SetFloat(message_, field, double_value->NativeValue()); + return absl::nullopt; + } + return TypeConversionError(value.GetTypeName(), "double"); + } + case google::protobuf::FieldDescriptor::CPPTYPE_DOUBLE: { + if (auto double_value = value.AsDouble(); double_value) { + reflection_->SetDouble(message_, field, double_value->NativeValue()); + return absl::nullopt; + } + return TypeConversionError(value.GetTypeName(), "double"); + } + case google::protobuf::FieldDescriptor::CPPTYPE_STRING: { + if (field->type() == google::protobuf::FieldDescriptor::TYPE_BYTES) { + if (auto bytes_value = value.AsBytes(); bytes_value) { + bytes_value->NativeValue(absl::Overload( + [this, field](absl::string_view string) { + reflection_->SetString(message_, field, std::string(string)); + }, + [this, field](const absl::Cord& cord) { + reflection_->SetString(message_, field, cord); + })); + return absl::nullopt; + } + return TypeConversionError(value.GetTypeName(), "bytes"); + } + if (auto string_value = value.AsString(); string_value) { + string_value->NativeValue(absl::Overload( + [this, field](absl::string_view string) { + reflection_->SetString(message_, field, std::string(string)); + }, + [this, field](const absl::Cord& cord) { + reflection_->SetString(message_, field, cord); + })); + return absl::nullopt; + } + return TypeConversionError(value.GetTypeName(), "string"); + } + case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: { + if (field->enum_type()->full_name() == "google.protobuf.NullValue") { + if (value.IsNull() || value.IsInt()) { + reflection_->SetEnumValue(message_, field, 0); + return absl::nullopt; + } + return TypeConversionError(value.GetTypeName(), "null_type"); + } + if (auto int_value = value.AsInt(); int_value) { + if (int_value->NativeValue() >= std::numeric_limits::min() && + int_value->NativeValue() <= std::numeric_limits::max()) { + reflection_->SetEnumValue( + message_, field, static_cast(int_value->NativeValue())); + return absl::nullopt; + } + } + return TypeConversionError(value.GetTypeName(), + field->enum_type()->full_name()); + } + case google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE: { + switch (field->message_type()->well_known_type()) { + case google::protobuf::Descriptor::WELLKNOWNTYPE_BOOLVALUE: { + if (value.IsNull()) { + // Allowing assigning `null` to message fields. + return absl::nullopt; + } + if (auto bool_value = value.AsBool(); bool_value) { + CEL_RETURN_IF_ERROR(well_known_types_.BoolValue().Initialize( + field->message_type())); + well_known_types_.BoolValue().SetValue( + reflection_->MutableMessage(message_, field, + message_factory_), + bool_value->NativeValue()); + return absl::nullopt; + } + return TypeConversionError(value.GetTypeName(), + field->message_type()->full_name()); + } + case google::protobuf::Descriptor::WELLKNOWNTYPE_INT32VALUE: { + if (value.IsNull()) { + // Allowing assigning `null` to message fields. + return absl::nullopt; + } + if (auto int_value = value.AsInt(); int_value) { + if (int_value->NativeValue() < + std::numeric_limits::min() || + int_value->NativeValue() > + std::numeric_limits::max()) { + return absl::OutOfRangeError("int64 to int32 overflow"); + } + CEL_RETURN_IF_ERROR(well_known_types_.Int32Value().Initialize( + field->message_type())); + well_known_types_.Int32Value().SetValue( + reflection_->MutableMessage(message_, field, + message_factory_), + static_cast(int_value->NativeValue())); + return absl::nullopt; + } + return TypeConversionError(value.GetTypeName(), + field->message_type()->full_name()); + } + case google::protobuf::Descriptor::WELLKNOWNTYPE_INT64VALUE: { + if (value.IsNull()) { + // Allowing assigning `null` to message fields. + return absl::nullopt; + } + if (auto int_value = value.AsInt(); int_value) { + CEL_RETURN_IF_ERROR(well_known_types_.Int64Value().Initialize( + field->message_type())); + well_known_types_.Int64Value().SetValue( + reflection_->MutableMessage(message_, field, + message_factory_), + int_value->NativeValue()); + return absl::nullopt; + } + return TypeConversionError(value.GetTypeName(), + field->message_type()->full_name()); + } + case google::protobuf::Descriptor::WELLKNOWNTYPE_UINT32VALUE: { + if (value.IsNull()) { + // Allowing assigning `null` to message fields. + return absl::nullopt; + } + if (auto uint_value = value.AsUint(); uint_value) { + if (uint_value->NativeValue() > + std::numeric_limits::max()) { + return absl::OutOfRangeError("uint64 to uint32 overflow"); + } + CEL_RETURN_IF_ERROR(well_known_types_.UInt32Value().Initialize( + field->message_type())); + well_known_types_.UInt32Value().SetValue( + reflection_->MutableMessage(message_, field, + message_factory_), + static_cast(uint_value->NativeValue())); + return absl::nullopt; + } + return TypeConversionError(value.GetTypeName(), + field->message_type()->full_name()); + } + case google::protobuf::Descriptor::WELLKNOWNTYPE_UINT64VALUE: { + if (value.IsNull()) { + // Allowing assigning `null` to message fields. + return absl::nullopt; + } + if (auto uint_value = value.AsUint(); uint_value) { + CEL_RETURN_IF_ERROR(well_known_types_.UInt64Value().Initialize( + field->message_type())); + well_known_types_.UInt64Value().SetValue( + reflection_->MutableMessage(message_, field, + message_factory_), + uint_value->NativeValue()); + return absl::nullopt; + } + return TypeConversionError(value.GetTypeName(), + field->message_type()->full_name()); + } + case google::protobuf::Descriptor::WELLKNOWNTYPE_FLOATVALUE: { + if (value.IsNull()) { + // Allowing assigning `null` to message fields. + return absl::nullopt; + } + if (auto double_value = value.AsDouble(); double_value) { + CEL_RETURN_IF_ERROR(well_known_types_.FloatValue().Initialize( + field->message_type())); + well_known_types_.FloatValue().SetValue( + reflection_->MutableMessage(message_, field, + message_factory_), + static_cast(double_value->NativeValue())); + return absl::nullopt; + } + return TypeConversionError(value.GetTypeName(), + field->message_type()->full_name()); + } + case google::protobuf::Descriptor::WELLKNOWNTYPE_DOUBLEVALUE: { + if (value.IsNull()) { + // Allowing assigning `null` to message fields. + return absl::nullopt; + } + if (auto double_value = value.AsDouble(); double_value) { + CEL_RETURN_IF_ERROR(well_known_types_.DoubleValue().Initialize( + field->message_type())); + well_known_types_.DoubleValue().SetValue( + reflection_->MutableMessage(message_, field, + message_factory_), + double_value->NativeValue()); + return absl::nullopt; + } + return TypeConversionError(value.GetTypeName(), + field->message_type()->full_name()); + } + case google::protobuf::Descriptor::WELLKNOWNTYPE_BYTESVALUE: { + if (value.IsNull()) { + // Allowing assigning `null` to message fields. + return absl::nullopt; + } + if (auto bytes_value = value.AsBytes(); bytes_value) { + CEL_RETURN_IF_ERROR(well_known_types_.BytesValue().Initialize( + field->message_type())); + well_known_types_.BytesValue().SetValue( + reflection_->MutableMessage(message_, field, + message_factory_), + bytes_value->NativeCord()); + return absl::nullopt; + } + return TypeConversionError(value.GetTypeName(), + field->message_type()->full_name()); + } + case google::protobuf::Descriptor::WELLKNOWNTYPE_STRINGVALUE: { + if (value.IsNull()) { + // Allowing assigning `null` to message fields. + return absl::nullopt; + } + if (auto string_value = value.AsString(); string_value) { + CEL_RETURN_IF_ERROR(well_known_types_.StringValue().Initialize( + field->message_type())); + well_known_types_.StringValue().SetValue( + reflection_->MutableMessage(message_, field, + message_factory_), + string_value->NativeCord()); + return absl::nullopt; + } + return TypeConversionError(value.GetTypeName(), + field->message_type()->full_name()); + } + case google::protobuf::Descriptor::WELLKNOWNTYPE_DURATION: { + if (value.IsNull()) { + // Allowing assigning `null` to message fields. + return absl::nullopt; + } + if (auto duration_value = value.AsDuration(); duration_value) { + CEL_RETURN_IF_ERROR(well_known_types_.Duration().Initialize( + field->message_type())); + CEL_RETURN_IF_ERROR( + well_known_types_.Duration().SetFromAbslDuration( + reflection_->MutableMessage(message_, field, + message_factory_), + duration_value->NativeValue())); + return absl::nullopt; + } + return TypeConversionError(value.GetTypeName(), + field->message_type()->full_name()); + } + case google::protobuf::Descriptor::WELLKNOWNTYPE_TIMESTAMP: { + if (value.IsNull()) { + // Allowing assigning `null` to message fields. + return absl::nullopt; + } + if (auto timestamp_value = value.AsTimestamp(); timestamp_value) { + CEL_RETURN_IF_ERROR(well_known_types_.Timestamp().Initialize( + field->message_type())); + CEL_RETURN_IF_ERROR(well_known_types_.Timestamp().SetFromAbslTime( + reflection_->MutableMessage(message_, field, + message_factory_), + timestamp_value->NativeValue())); + return absl::nullopt; + } + return TypeConversionError(value.GetTypeName(), + field->message_type()->full_name()); + } + case google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE: { + CEL_RETURN_IF_ERROR( + value.ConvertToJson(descriptor_pool_, message_factory_, + reflection_->MutableMessage( + message_, field, message_factory_))); + return absl::nullopt; + } + case google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE: { + CEL_RETURN_IF_ERROR(value.ConvertToJsonArray( + descriptor_pool_, message_factory_, + reflection_->MutableMessage(message_, field, + message_factory_))); + return absl::nullopt; + } + case google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT: { + CEL_RETURN_IF_ERROR(value.ConvertToJsonObject( + descriptor_pool_, message_factory_, + reflection_->MutableMessage(message_, field, + message_factory_))); + return absl::nullopt; + } + case google::protobuf::Descriptor::WELLKNOWNTYPE_ANY: { + // Probably not correct, need to use the parent/common one. + google::protobuf::io::CordOutputStream serialized; + CEL_RETURN_IF_ERROR(value.SerializeTo( + descriptor_pool_, message_factory_, &serialized)); + std::string type_url; + switch (value.kind()) { + case ValueKind::kNull: + type_url = MakeTypeUrl("google.protobuf.Value"); + break; + case ValueKind::kBool: + type_url = MakeTypeUrl("google.protobuf.BoolValue"); + break; + case ValueKind::kInt: + type_url = MakeTypeUrl("google.protobuf.Int64Value"); + break; + case ValueKind::kUint: + type_url = MakeTypeUrl("google.protobuf.UInt64Value"); + break; + case ValueKind::kDouble: + type_url = MakeTypeUrl("google.protobuf.DoubleValue"); + break; + case ValueKind::kBytes: + type_url = MakeTypeUrl("google.protobuf.BytesValue"); + break; + case ValueKind::kString: + type_url = MakeTypeUrl("google.protobuf.StringValue"); + break; + case ValueKind::kList: + type_url = MakeTypeUrl("google.protobuf.ListValue"); + break; + case ValueKind::kMap: + type_url = MakeTypeUrl("google.protobuf.Struct"); + break; + case ValueKind::kDuration: + type_url = MakeTypeUrl("google.protobuf.Duration"); + break; + case ValueKind::kTimestamp: + type_url = MakeTypeUrl("google.protobuf.Timestamp"); + break; + default: + type_url = MakeTypeUrl(value.GetTypeName()); + break; + } + CEL_RETURN_IF_ERROR( + well_known_types_.Any().Initialize(field->message_type())); + well_known_types_.Any().SetTypeUrl( + reflection_->MutableMessage(message_, field, message_factory_), + type_url); + well_known_types_.Any().SetValue( + reflection_->MutableMessage(message_, field, message_factory_), + std::move(serialized).Consume()); + return absl::nullopt; + } + default: + if (value.IsNull()) { + // Allowing assigning `null` to message fields. + return absl::nullopt; + } + break; + } + return ProtoMessageFromValueImpl( + value, descriptor_pool_, message_factory_, &well_known_types_, + reflection_->MutableMessage(message_, field, message_factory_)); + } + default: + return absl::InternalError( + absl::StrCat("unexpected protocol buffer message field type: ", + field->cpp_type_name())); + } + } + + absl::StatusOr> SetField( + const google::protobuf::FieldDescriptor* absl_nonnull field, Value value) { + if (field->is_map()) { + return SetMapField(field, std::move(value)); + } + if (field->is_repeated()) { + return SetRepeatedField(field, std::move(value)); + } + return SetSingularField(field, std::move(value)); + } + + google::protobuf::Arena* absl_nullable const arena_; + const google::protobuf::DescriptorPool* absl_nonnull const descriptor_pool_; + google::protobuf::MessageFactory* absl_nonnull const message_factory_; + google::protobuf::Message* absl_nullable message_; + const google::protobuf::Descriptor* absl_nonnull const descriptor_; + const google::protobuf::Reflection* absl_nonnull const reflection_; + well_known_types::Reflection well_known_types_; +}; + +class ValueBuilderImpl final : public ValueBuilder { + public: + ValueBuilderImpl(google::protobuf::Arena* absl_nullable arena, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull message) + : builder_(arena, descriptor_pool, message_factory, message) {} + + absl::StatusOr> SetFieldByName( + absl::string_view name, Value value) override { + return builder_.SetFieldByName(name, std::move(value)); + } + + absl::StatusOr> SetFieldByNumber( + int64_t number, Value value) override { + return builder_.SetFieldByNumber(number, std::move(value)); + } + + absl::StatusOr Build() && override { + return std::move(builder_).Build(); + } + + private: + MessageValueBuilderImpl builder_; +}; + +class StructValueBuilderImpl final : public StructValueBuilder { + public: + StructValueBuilderImpl( + google::protobuf::Arena* absl_nullable arena, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull message) + : builder_(arena, descriptor_pool, message_factory, message) {} + + absl::StatusOr> SetFieldByName( + absl::string_view name, Value value) override { + return builder_.SetFieldByName(name, std::move(value)); + } + + absl::StatusOr> SetFieldByNumber( + int64_t number, Value value) override { + return builder_.SetFieldByNumber(number, std::move(value)); + } + + absl::StatusOr Build() && override { + return std::move(builder_).BuildStruct(); + } + + private: + MessageValueBuilderImpl builder_; +}; + +} // namespace + +absl_nullable cel::ValueBuilderPtr NewValueBuilder( + Allocator<> allocator, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + absl::string_view name) { + const google::protobuf::Descriptor* absl_nullable descriptor = + descriptor_pool->FindMessageTypeByName(name); + if (descriptor == nullptr) { + return nullptr; + } + const google::protobuf::Message* absl_nullable prototype = + message_factory->GetPrototype(descriptor); + ABSL_DCHECK(prototype != nullptr) + << "failed to get message prototype from factory, did you pass a dynamic " + "descriptor to the generated message factory? we consider this to be " + "a logic error and not a runtime error: " + << descriptor->full_name(); + if (ABSL_PREDICT_FALSE(prototype == nullptr)) { + return nullptr; + } + return std::make_unique(allocator.arena(), descriptor_pool, + message_factory, + prototype->New(allocator.arena())); +} + +absl_nullable cel::StructValueBuilderPtr NewStructValueBuilder( + Allocator<> allocator, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + absl::string_view name) { + const google::protobuf::Descriptor* absl_nullable descriptor = + descriptor_pool->FindMessageTypeByName(name); + if (descriptor == nullptr) { + return nullptr; + } + const google::protobuf::Message* absl_nullable prototype = + message_factory->GetPrototype(descriptor); + ABSL_DCHECK(prototype != nullptr) + << "failed to get message prototype from factory, did you pass a dynamic " + "descriptor to the generated message factory? we consider this to be " + "a logic error and not a runtime error: " + << descriptor->full_name(); + if (ABSL_PREDICT_FALSE(prototype == nullptr)) { + return nullptr; + } + return std::make_unique( + allocator.arena(), descriptor_pool, message_factory, + prototype->New(allocator.arena())); +} + +} // namespace cel::common_internal diff --git a/common/values/struct_value_builder.h b/common/values/struct_value_builder.h new file mode 100644 index 000000000..ab4fdcd87 --- /dev/null +++ b/common/values/struct_value_builder.h @@ -0,0 +1,35 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_STRUCT_VALUE_BUILDER_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_STRUCT_VALUE_BUILDER_H_ + +#include "absl/base/nullability.h" +#include "absl/strings/string_view.h" +#include "common/allocator.h" +#include "common/value.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel::common_internal { + +absl_nullable cel::StructValueBuilderPtr NewStructValueBuilder( + Allocator<> allocator, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + absl::string_view name); + +} // namespace cel::common_internal + +#endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_STRUCT_VALUE_BUILDER_H_ diff --git a/common/values/struct_value_test.cc b/common/values/struct_value_test.cc new file mode 100644 index 000000000..275acf70a --- /dev/null +++ b/common/values/struct_value_test.cc @@ -0,0 +1,144 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/base/attributes.h" +#include "common/value.h" +#include "internal/parse_text_proto.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "internal/testing_message_factory.h" +#include "cel/expr/conformance/proto3/test_all_types.pb.h" +#include "google/protobuf/arena.h" + +namespace cel { +namespace { + +using ::cel::internal::DynamicParseTextProto; +using ::cel::internal::GetTestingDescriptorPool; +using ::cel::internal::GetTestingMessageFactory; +using ::testing::An; +using ::testing::Optional; + +using TestAllTypesProto3 = ::cel::expr::conformance::proto3::TestAllTypes; + +TEST(StructValue, Is) { + EXPECT_TRUE(StructValue(ParsedMessageValue()).Is()); + EXPECT_TRUE(StructValue(ParsedMessageValue()).Is()); +} + +template +constexpr T& AsLValueRef(T& t ABSL_ATTRIBUTE_LIFETIME_BOUND) { + return t; +} + +template +constexpr const T& AsConstLValueRef(T& t ABSL_ATTRIBUTE_LIFETIME_BOUND) { + return t; +} + +template +constexpr T&& AsRValueRef(T& t ABSL_ATTRIBUTE_LIFETIME_BOUND) { + return static_cast(t); +} + +template +constexpr const T&& AsConstRValueRef(T& t ABSL_ATTRIBUTE_LIFETIME_BOUND) { + return static_cast(t); +} + +TEST(StructValue, As) { + google::protobuf::Arena arena; + + { + StructValue value(ParsedMessageValue{ + DynamicParseTextProto(&arena, R"pb()pb", + GetTestingDescriptorPool(), + GetTestingMessageFactory()), + &arena}); + StructValue other_value = value; + EXPECT_THAT(AsLValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsConstLValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsRValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsConstRValueRef(other_value).As(), + Optional(An())); + } + + { + StructValue value(ParsedMessageValue{ + DynamicParseTextProto(&arena, R"pb()pb", + GetTestingDescriptorPool(), + GetTestingMessageFactory()), + &arena}); + StructValue other_value = value; + EXPECT_THAT(AsLValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsConstLValueRef(value).As(), + Optional(An())); + EXPECT_THAT(AsRValueRef(value).As(), + Optional(An())); + EXPECT_THAT( + AsConstRValueRef(other_value).As(), + Optional(An())); + } +} + +template +decltype(auto) DoGet(From&& from) { + return std::forward(from).template Get(); +} + +TEST(StructValue, Get) { + google::protobuf::Arena arena; + + { + StructValue value(ParsedMessageValue{ + DynamicParseTextProto(&arena, R"pb()pb", + GetTestingDescriptorPool(), + GetTestingMessageFactory()), + &arena}); + StructValue other_value = value; + EXPECT_THAT(DoGet(AsLValueRef(value)), + An()); + EXPECT_THAT(DoGet(AsConstLValueRef(value)), + An()); + EXPECT_THAT(DoGet(AsRValueRef(value)), + An()); + EXPECT_THAT(DoGet(AsConstRValueRef(other_value)), + An()); + } + + { + StructValue value(ParsedMessageValue{ + DynamicParseTextProto(&arena, R"pb()pb", + GetTestingDescriptorPool(), + GetTestingMessageFactory()), + &arena}); + StructValue other_value = value; + EXPECT_THAT(DoGet(AsLValueRef(value)), + An()); + EXPECT_THAT(DoGet(AsConstLValueRef(value)), + An()); + EXPECT_THAT(DoGet(AsRValueRef(value)), + An()); + EXPECT_THAT( + DoGet(AsConstRValueRef(other_value)), + An()); + } +} + +} // namespace +} // namespace cel diff --git a/common/values/struct_value_variant.h b/common/values/struct_value_variant.h new file mode 100644 index 000000000..45a809b84 --- /dev/null +++ b/common/values/struct_value_variant.h @@ -0,0 +1,205 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_STRUCT_VALUE_VARIANT_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_STRUCT_VALUE_VARIANT_H_ + +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/meta/type_traits.h" +#include "absl/utility/utility.h" +#include "common/values/custom_struct_value.h" +#include "common/values/legacy_struct_value.h" +#include "common/values/parsed_message_value.h" + +namespace cel::common_internal { + +enum class StructValueIndex : uint16_t { + kParsedMessage = 0, + kCustom, + kLegacy, +}; + +template +struct StructValueAlternative; + +template <> +struct StructValueAlternative { + static constexpr StructValueIndex kIndex = StructValueIndex::kCustom; +}; + +template <> +struct StructValueAlternative { + static constexpr StructValueIndex kIndex = StructValueIndex::kParsedMessage; +}; + +template <> +struct StructValueAlternative { + static constexpr StructValueIndex kIndex = StructValueIndex::kLegacy; +}; + +template +struct IsStructValueAlternative : std::false_type {}; + +template +struct IsStructValueAlternative< + T, std::void_t{})>> : std::true_type {}; + +template +inline constexpr bool IsStructValueAlternativeV = + IsStructValueAlternative::value; + +inline constexpr size_t kStructValueVariantAlign = 8; +inline constexpr size_t kStructValueVariantSize = 24; + +// StructValueVariant is a subset of alternatives from the main ValueVariant +// that is only structs. It is not stored directly in ValueVariant. +class alignas(kStructValueVariantAlign) StructValueVariant final { + public: + StructValueVariant() + : StructValueVariant(absl::in_place_type) {} + + StructValueVariant(const StructValueVariant&) = default; + StructValueVariant(StructValueVariant&&) = default; + StructValueVariant& operator=(const StructValueVariant&) = default; + StructValueVariant& operator=(StructValueVariant&&) = default; + + template + explicit StructValueVariant(absl::in_place_type_t, Args&&... args) + : index_(StructValueAlternative::kIndex) { + static_assert(alignof(T) <= kStructValueVariantAlign); + static_assert(sizeof(T) <= kStructValueVariantSize); + static_assert(std::is_trivially_copyable_v); + + ::new (static_cast(&raw_[0])) T(std::forward(args)...); + } + + template >>> + explicit StructValueVariant(T&& value) + : StructValueVariant(absl::in_place_type>, + std::forward(value)) {} + + template + void Assign(T&& value) { + using U = absl::remove_cvref_t; + + static_assert(alignof(U) <= kStructValueVariantAlign); + static_assert(sizeof(U) <= kStructValueVariantSize); + static_assert(std::is_trivially_copyable_v); + + index_ = StructValueAlternative::kIndex; + ::new (static_cast(&raw_[0])) U(std::forward(value)); + } + + template + bool Is() const { + return index_ == StructValueAlternative::kIndex; + } + + template + T& Get() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + ABSL_DCHECK(Is()); + + return *At(); + } + + template + const T& Get() const& ABSL_ATTRIBUTE_LIFETIME_BOUND { + ABSL_DCHECK(Is()); + + return *At(); + } + + template + T&& Get() && ABSL_ATTRIBUTE_LIFETIME_BOUND { + ABSL_DCHECK(Is()); + + return std::move(*At()); + } + + template + const T&& Get() const&& ABSL_ATTRIBUTE_LIFETIME_BOUND { + ABSL_DCHECK(Is()); + + return std::move(*At()); + } + + template + T* absl_nullable As() ABSL_ATTRIBUTE_LIFETIME_BOUND { + if (Is()) { + return At(); + } + return nullptr; + } + + template + const T* absl_nullable As() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + if (Is()) { + return At(); + } + return nullptr; + } + + template + decltype(auto) Visit(Visitor&& visitor) const { + switch (index_) { + case StructValueIndex::kCustom: + return std::forward(visitor)(Get()); + case StructValueIndex::kParsedMessage: + return std::forward(visitor)(Get()); + case StructValueIndex::kLegacy: + return std::forward(visitor)(Get()); + } + } + + friend void swap(StructValueVariant& lhs, StructValueVariant& rhs) noexcept { + using std::swap; + swap(lhs.index_, rhs.index_); + swap(lhs.raw_, rhs.raw_); + } + + private: + template + ABSL_ATTRIBUTE_ALWAYS_INLINE T* absl_nonnull At() + ABSL_ATTRIBUTE_LIFETIME_BOUND { + static_assert(alignof(T) <= kStructValueVariantAlign); + static_assert(sizeof(T) <= kStructValueVariantSize); + static_assert(std::is_trivially_copyable_v); + + return std::launder(reinterpret_cast(&raw_[0])); + } + + template + ABSL_ATTRIBUTE_ALWAYS_INLINE const T* absl_nonnull At() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + static_assert(alignof(T) <= kStructValueVariantAlign); + static_assert(sizeof(T) <= kStructValueVariantSize); + static_assert(std::is_trivially_copyable_v); + + return std::launder(reinterpret_cast(&raw_[0])); + } + + StructValueIndex index_ = StructValueIndex::kCustom; + alignas(8) std::byte raw_[kStructValueVariantSize]; +}; + +} // namespace cel::common_internal + +#endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_STRUCT_VALUE_VARIANT_H_ diff --git a/common/values/timestamp_value.cc b/common/values/timestamp_value.cc new file mode 100644 index 000000000..7d3a347e8 --- /dev/null +++ b/common/values/timestamp_value.cc @@ -0,0 +1,103 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "google/protobuf/timestamp.pb.h" +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/time/time.h" +#include "common/value.h" +#include "internal/status_macros.h" +#include "internal/time.h" +#include "internal/well_known_types.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/message.h" + +namespace cel { + +namespace { + +using ::cel::well_known_types::TimestampReflection; +using ::cel::well_known_types::ValueReflection; + +std::string TimestampDebugString(absl::Time value) { + return internal::DebugStringTimestamp(value); +} + +} // namespace + +std::string TimestampValue::DebugString() const { + return TimestampDebugString(NativeValue()); +} + +absl::Status TimestampValue::SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(output != nullptr); + + google::protobuf::Timestamp message; + CEL_RETURN_IF_ERROR( + TimestampReflection::SetFromAbslTime(&message, NativeValue())); + if (!message.SerializePartialToZeroCopyStream(output)) { + return absl::UnknownError( + absl::StrCat("failed to serialize message: ", message.GetTypeName())); + } + + return absl::OkStatus(); +} + +absl::Status TimestampValue::ConvertToJson( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(json != nullptr); + ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE); + + ValueReflection value_reflection; + CEL_RETURN_IF_ERROR(value_reflection.Initialize(json->GetDescriptor())); + value_reflection.SetStringValueFromTimestamp(json, NativeValue()); + + return absl::OkStatus(); +} + +absl::Status TimestampValue::Equal( + const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + if (auto other_value = other.AsTimestamp(); other_value.has_value()) { + *result = BoolValue{NativeValue() == other_value->NativeValue()}; + return absl::OkStatus(); + } + *result = FalseValue(); + return absl::OkStatus(); +} + +} // namespace cel diff --git a/common/values/timestamp_value.h b/common/values/timestamp_value.h new file mode 100644 index 000000000..acc202300 --- /dev/null +++ b/common/values/timestamp_value.h @@ -0,0 +1,146 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/value.h" +// IWYU pragma: friend "common/value.h" + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_TIMESTAMP_VALUE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_TIMESTAMP_VALUE_H_ + +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "absl/utility/utility.h" +#include "common/type.h" +#include "common/value_kind.h" +#include "common/values/values.h" +#include "internal/time.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/message.h" + +namespace cel { + +class Value; +class TimestampValue; + +TimestampValue UnsafeTimestampValue(absl::Time value); +absl::StatusOr SafeTimestampValue(absl::Time value); + +// `TimestampValue` represents values of the primitive `timestamp` type. +class TimestampValue final + : private common_internal::ValueMixin { + public: + static constexpr ValueKind kKind = ValueKind::kTimestamp; + + explicit TimestampValue(absl::Time value) noexcept + : TimestampValue(absl::in_place, value) { + ABSL_DCHECK_OK(internal::ValidateTimestamp(value)); + } + + TimestampValue() = default; + TimestampValue(const TimestampValue&) = default; + TimestampValue(TimestampValue&&) = default; + TimestampValue& operator=(const TimestampValue&) = default; + TimestampValue& operator=(TimestampValue&&) = default; + + ValueKind kind() const { return kKind; } + + absl::string_view GetTypeName() const { return TimestampType::kName; } + + std::string DebugString() const; + + // See Value::SerializeTo(). + absl::Status SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const; + + // See Value::ConvertToJson(). + absl::Status ConvertToJson( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const; + + absl::Status Equal(const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const; + using ValueMixin::Equal; + + bool IsZeroValue() const { return ToTime() == absl::UnixEpoch(); } + + ABSL_DEPRECATED("Use ToTime()") + absl::Time NativeValue() const { return static_cast(*this); } + + ABSL_DEPRECATED("Use ToTime()") + // NOLINTNEXTLINE(google-explicit-constructor) + operator absl::Time() const noexcept { return value_; } + + absl::Time ToTime() const { return value_; } + + friend void swap(TimestampValue& lhs, TimestampValue& rhs) noexcept { + using std::swap; + swap(lhs.value_, rhs.value_); + } + + friend bool operator==(TimestampValue lhs, TimestampValue rhs) { + return lhs.value_ == rhs.value_; + } + + friend bool operator<(const TimestampValue& lhs, const TimestampValue& rhs) { + return lhs.value_ < rhs.value_; + } + + private: + friend class common_internal::ValueMixin; + friend TimestampValue UnsafeTimestampValue(absl::Time value); + + TimestampValue(absl::in_place_t, absl::Time value) : value_(value) {} + + absl::Time value_ = absl::UnixEpoch(); +}; + +inline TimestampValue UnsafeTimestampValue(absl::Time value) { + return TimestampValue(absl::in_place, value); +} + +inline absl::StatusOr SafeTimestampValue(absl::Time value) { + absl::Status status = internal::ValidateTimestamp(value); + if (!status.ok()) { + return status; + } + return UnsafeTimestampValue(value); +} + +inline bool operator!=(TimestampValue lhs, TimestampValue rhs) { + return !operator==(lhs, rhs); +} + +inline std::ostream& operator<<(std::ostream& out, TimestampValue value) { + return out << value.DebugString(); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_TIMESTAMP_VALUE_H_ diff --git a/common/values/timestamp_value_test.cc b/common/values/timestamp_value_test.cc new file mode 100644 index 000000000..142e6511d --- /dev/null +++ b/common/values/timestamp_value_test.cc @@ -0,0 +1,87 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "absl/status/status_matchers.h" +#include "absl/time/time.h" +#include "common/native_type.h" +#include "common/value.h" +#include "common/value_testing.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; + +using TimestampValueTest = common_internal::ValueTest<>; + +TEST_F(TimestampValueTest, Kind) { + EXPECT_EQ(TimestampValue().kind(), TimestampValue::kKind); + EXPECT_EQ(Value(TimestampValue(absl::UnixEpoch() + absl::Seconds(1))).kind(), + TimestampValue::kKind); +} + +TEST_F(TimestampValueTest, DebugString) { + { + std::ostringstream out; + out << TimestampValue(absl::UnixEpoch() + absl::Seconds(1)); + EXPECT_EQ(out.str(), "1970-01-01T00:00:01Z"); + } + { + std::ostringstream out; + out << Value(TimestampValue(absl::UnixEpoch() + absl::Seconds(1))); + EXPECT_EQ(out.str(), "1970-01-01T00:00:01Z"); + } +} + +TEST_F(TimestampValueTest, ConvertToJson) { + auto* message = NewArenaValueMessage(); + EXPECT_THAT(TimestampValue().ConvertToJson(descriptor_pool(), + message_factory(), message), + IsOk()); + EXPECT_THAT(*message, EqualsValueTextProto( + R"pb(string_value: "1970-01-01T00:00:00Z")pb")); +} + +TEST_F(TimestampValueTest, NativeTypeId) { + EXPECT_EQ( + NativeTypeId::Of(TimestampValue(absl::UnixEpoch() + absl::Seconds(1))), + NativeTypeId::For()); + EXPECT_EQ(NativeTypeId::Of( + Value(TimestampValue(absl::UnixEpoch() + absl::Seconds(1)))), + NativeTypeId::For()); +} + +TEST_F(TimestampValueTest, Equality) { + EXPECT_NE(TimestampValue(absl::UnixEpoch()), + absl::UnixEpoch() + absl::Seconds(1)); + EXPECT_NE(absl::UnixEpoch() + absl::Seconds(1), + TimestampValue(absl::UnixEpoch())); + EXPECT_NE(TimestampValue(absl::UnixEpoch()), + TimestampValue(absl::UnixEpoch() + absl::Seconds(1))); +} + +TEST_F(TimestampValueTest, Comparison) { + EXPECT_LT(TimestampValue(absl::UnixEpoch()), + TimestampValue(absl::UnixEpoch() + absl::Seconds(1))); + EXPECT_FALSE(TimestampValue(absl::UnixEpoch() + absl::Seconds(1)) < + TimestampValue(absl::UnixEpoch() + absl::Seconds(1))); + EXPECT_FALSE(TimestampValue(absl::UnixEpoch() + absl::Seconds(2)) < + TimestampValue(absl::UnixEpoch() + absl::Seconds(1))); +} + +} // namespace +} // namespace cel diff --git a/common/values/type_value.cc b/common/values/type_value.cc new file mode 100644 index 000000000..add099d0a --- /dev/null +++ b/common/values/type_value.cc @@ -0,0 +1,72 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "common/type.h" +#include "common/value.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/message.h" + +namespace cel { + +absl::Status TypeValue::SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(output != nullptr); + + return absl::FailedPreconditionError( + absl::StrCat(GetTypeName(), " is unserializable")); +} + +absl::Status TypeValue::ConvertToJson( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(json != nullptr); + ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE); + + return absl::FailedPreconditionError( + absl::StrCat(GetTypeName(), " is not convertable to JSON")); +} + +absl::Status TypeValue::Equal( + const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + if (auto other_value = other.AsType(); other_value.has_value()) { + *result = BoolValue{NativeValue() == other_value->NativeValue()}; + return absl::OkStatus(); + } + *result = FalseValue(); + return absl::OkStatus(); +} + +} // namespace cel diff --git a/common/values/type_value.h b/common/values/type_value.h new file mode 100644 index 000000000..cfc2056dd --- /dev/null +++ b/common/values/type_value.h @@ -0,0 +1,108 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/value.h" +// IWYU pragma: friend "common/value.h" + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_TYPE_VALUE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_TYPE_VALUE_H_ + +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "common/type.h" +#include "common/value_kind.h" +#include "common/values/values.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/message.h" + +namespace cel { + +class Value; +class TypeValue; + +// `TypeValue` represents values of the primitive `type` type. +class TypeValue final : private common_internal::ValueMixin { + public: + static constexpr ValueKind kKind = ValueKind::kType; + + explicit TypeValue(Type value) : value_(value) {} + + TypeValue() = default; + TypeValue(const TypeValue&) = default; + TypeValue(TypeValue&&) = default; + TypeValue& operator=(const TypeValue&) = default; + TypeValue& operator=(TypeValue&&) = default; + + static constexpr ValueKind kind() { return kKind; } + + static absl::string_view GetTypeName() { return TypeType::kName; } + + std::string DebugString() const { return type().DebugString(); } + + // See Value::SerializeTo(). + absl::Status SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const; + + // See Value::ConvertToJson(). + absl::Status ConvertToJson( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const; + + absl::Status Equal(const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const; + using ValueMixin::Equal; + + bool IsZeroValue() const { return false; } + + ABSL_DEPRECATED(("Use type()")) + const Type& NativeValue() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return type(); + } + + const Type& type() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return value_; } + + absl::string_view name() const { return type().name(); } + + friend void swap(TypeValue& lhs, TypeValue& rhs) noexcept { + using std::swap; + swap(lhs.value_, rhs.value_); + } + + private: + friend class common_internal::ValueMixin; + + Type value_; +}; + +inline std::ostream& operator<<(std::ostream& out, const TypeValue& value) { + return out << value.DebugString(); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_TYPE_VALUE_H_ diff --git a/common/values/type_value_test.cc b/common/values/type_value_test.cc new file mode 100644 index 000000000..ef9ec1ad9 --- /dev/null +++ b/common/values/type_value_test.cc @@ -0,0 +1,72 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "absl/status/status.h" +#include "common/native_type.h" +#include "common/type.h" +#include "common/value.h" +#include "common/value_testing.h" +#include "internal/testing.h" +#include "google/protobuf/io/zero_copy_stream_impl_lite.h" + +namespace cel { +namespace { + +using ::absl_testing::StatusIs; + +using TypeValueTest = common_internal::ValueTest<>; + +TEST_F(TypeValueTest, Kind) { + EXPECT_EQ(TypeValue(AnyType()).kind(), TypeValue::kKind); + EXPECT_EQ(Value(TypeValue(AnyType())).kind(), TypeValue::kKind); +} + +TEST_F(TypeValueTest, DebugString) { + { + std::ostringstream out; + out << TypeValue(AnyType()); + EXPECT_EQ(out.str(), "google.protobuf.Any"); + } + { + std::ostringstream out; + out << Value(TypeValue(AnyType())); + EXPECT_EQ(out.str(), "google.protobuf.Any"); + } +} + +TEST_F(TypeValueTest, SerializeTo) { + google::protobuf::io::CordOutputStream output; + EXPECT_THAT(TypeValue(AnyType()).SerializeTo(descriptor_pool(), + message_factory(), &output), + StatusIs(absl::StatusCode::kFailedPrecondition)); +} + +TEST_F(TypeValueTest, ConvertToJson) { + auto* message = NewArenaValueMessage(); + EXPECT_THAT(TypeValue(AnyType()).ConvertToJson(descriptor_pool(), + message_factory(), message), + StatusIs(absl::StatusCode::kFailedPrecondition)); +} + +TEST_F(TypeValueTest, NativeTypeId) { + EXPECT_EQ(NativeTypeId::Of(TypeValue(AnyType())), + NativeTypeId::For()); + EXPECT_EQ(NativeTypeId::Of(Value(TypeValue(AnyType()))), + NativeTypeId::For()); +} + +} // namespace +} // namespace cel diff --git a/common/values/uint_value.cc b/common/values/uint_value.cc new file mode 100644 index 000000000..1c296fb39 --- /dev/null +++ b/common/values/uint_value.cc @@ -0,0 +1,110 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "common/value.h" +#include "internal/number.h" +#include "internal/status_macros.h" +#include "internal/well_known_types.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/message.h" + +namespace cel { + +namespace { + +using ::cel::well_known_types::ValueReflection; + +std::string UintDebugString(int64_t value) { return absl::StrCat(value, "u"); } + +} // namespace + +std::string UintValue::DebugString() const { + return UintDebugString(NativeValue()); +} + +absl::Status UintValue::SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(output != nullptr); + + google::protobuf::UInt64Value message; + message.set_value(NativeValue()); + if (!message.SerializePartialToZeroCopyStream(output)) { + return absl::UnknownError( + absl::StrCat("failed to serialize message: ", message.GetTypeName())); + } + + return absl::OkStatus(); +} + +absl::Status UintValue::ConvertToJson( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(json != nullptr); + ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE); + + ValueReflection value_reflection; + CEL_RETURN_IF_ERROR(value_reflection.Initialize(json->GetDescriptor())); + value_reflection.SetNumberValue(json, NativeValue()); + + return absl::OkStatus(); +} + +absl::Status UintValue::Equal( + const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + if (auto other_value = other.AsUint(); other_value.has_value()) { + *result = BoolValue{NativeValue() == other_value->NativeValue()}; + return absl::OkStatus(); + } + if (auto other_value = other.AsDouble(); other_value.has_value()) { + *result = + BoolValue{internal::Number::FromUint64(NativeValue()) == + internal::Number::FromDouble(other_value->NativeValue())}; + return absl::OkStatus(); + } + if (auto other_value = other.AsInt(); other_value.has_value()) { + *result = + BoolValue{internal::Number::FromUint64(NativeValue()) == + internal::Number::FromInt64(other_value->NativeValue())}; + return absl::OkStatus(); + } + *result = FalseValue(); + return absl::OkStatus(); +} + +} // namespace cel diff --git a/common/values/uint_value.h b/common/values/uint_value.h new file mode 100644 index 000000000..f263bb7c9 --- /dev/null +++ b/common/values/uint_value.h @@ -0,0 +1,119 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/value.h" +// IWYU pragma: friend "common/value.h" + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_UINT_VALUE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_UINT_VALUE_H_ + +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/strings/cord.h" +#include "absl/strings/string_view.h" +#include "common/type.h" +#include "common/value_kind.h" +#include "common/values/values.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/message.h" + +namespace cel { + +class Value; +class UintValue; + +// `UintValue` represents values of the primitive `uint` type. +class UintValue final : private common_internal::ValueMixin { + public: + static constexpr ValueKind kKind = ValueKind::kUint; + + explicit UintValue(uint64_t value) noexcept : value_(value) {} + + UintValue() = default; + UintValue(const UintValue&) = default; + UintValue(UintValue&&) = default; + UintValue& operator=(const UintValue&) = default; + UintValue& operator=(UintValue&&) = default; + + constexpr ValueKind kind() const { return kKind; } + + absl::string_view GetTypeName() const { return UintType::kName; } + + std::string DebugString() const; + + // See Value::SerializeTo(). + absl::Status SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const; + + // See Value::ConvertToJson(). + absl::Status ConvertToJson( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const; + + absl::Status Equal(const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const; + using ValueMixin::Equal; + + bool IsZeroValue() const { return NativeValue() == 0; } + + constexpr uint64_t NativeValue() const { + return static_cast(*this); + } + + // NOLINTNEXTLINE(google-explicit-constructor) + constexpr operator uint64_t() const noexcept { return value_; } + + friend void swap(UintValue& lhs, UintValue& rhs) noexcept { + using std::swap; + swap(lhs.value_, rhs.value_); + } + + private: + friend class common_internal::ValueMixin; + + uint64_t value_ = 0; +}; + +template +H AbslHashValue(H state, UintValue value) { + return H::combine(std::move(state), value.NativeValue()); +} + +constexpr bool operator==(UintValue lhs, UintValue rhs) { + return lhs.NativeValue() == rhs.NativeValue(); +} + +constexpr bool operator!=(UintValue lhs, UintValue rhs) { + return !operator==(lhs, rhs); +} + +inline std::ostream& operator<<(std::ostream& out, UintValue value) { + return out << value.DebugString(); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_UINT_VALUE_H_ diff --git a/common/values/uint_value_test.cc b/common/values/uint_value_test.cc new file mode 100644 index 000000000..75552184d --- /dev/null +++ b/common/values/uint_value_test.cc @@ -0,0 +1,81 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "absl/hash/hash.h" +#include "absl/status/status_matchers.h" +#include "common/native_type.h" +#include "common/value.h" +#include "common/value_testing.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; + +using UintValueTest = common_internal::ValueTest<>; + +TEST_F(UintValueTest, Kind) { + EXPECT_EQ(UintValue(1).kind(), UintValue::kKind); + EXPECT_EQ(Value(UintValue(1)).kind(), UintValue::kKind); +} + +TEST_F(UintValueTest, DebugString) { + { + std::ostringstream out; + out << UintValue(1); + EXPECT_EQ(out.str(), "1u"); + } + { + std::ostringstream out; + out << Value(UintValue(1)); + EXPECT_EQ(out.str(), "1u"); + } +} + +TEST_F(UintValueTest, ConvertToJson) { + auto* message = NewArenaValueMessage(); + EXPECT_THAT( + UintValue(1).ConvertToJson(descriptor_pool(), message_factory(), message), + IsOk()); + EXPECT_THAT(*message, EqualsValueTextProto(R"pb(number_value: 1)pb")); +} + +TEST_F(UintValueTest, NativeTypeId) { + EXPECT_EQ(NativeTypeId::Of(UintValue(1)), NativeTypeId::For()); + EXPECT_EQ(NativeTypeId::Of(Value(UintValue(1))), + NativeTypeId::For()); +} + +TEST_F(UintValueTest, HashValue) { + EXPECT_EQ(absl::HashOf(UintValue(1)), absl::HashOf(uint64_t{1})); +} + +TEST_F(UintValueTest, Equality) { + EXPECT_NE(UintValue(0u), 1u); + EXPECT_NE(1u, UintValue(0u)); + EXPECT_NE(UintValue(0u), UintValue(1u)); +} + +TEST_F(UintValueTest, LessThan) { + EXPECT_LT(UintValue(0), 1); + EXPECT_LT(0, UintValue(1)); + EXPECT_LT(UintValue(0), UintValue(1)); +} + +} // namespace +} // namespace cel diff --git a/common/values/unknown_value.cc b/common/values/unknown_value.cc new file mode 100644 index 000000000..1cb8a7674 --- /dev/null +++ b/common/values/unknown_value.cc @@ -0,0 +1,66 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "common/value.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/message.h" + +namespace cel { + +absl::Status UnknownValue::SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(output != nullptr); + + return absl::FailedPreconditionError( + absl::StrCat(GetTypeName(), " is unserializable")); +} + +absl::Status UnknownValue::ConvertToJson( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(json != nullptr); + ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE); + + return absl::FailedPreconditionError( + absl::StrCat(GetTypeName(), " is not convertable to JSON")); +} + +absl::Status UnknownValue::Equal( + const Value&, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + *result = FalseValue(); + return absl::OkStatus(); +} + +} // namespace cel diff --git a/common/values/unknown_value.h b/common/values/unknown_value.h new file mode 100644 index 000000000..9e8ddaae0 --- /dev/null +++ b/common/values/unknown_value.h @@ -0,0 +1,121 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/value.h" +// IWYU pragma: friend "common/value.h" + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_UNKNOWN_VALUE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_UNKNOWN_VALUE_H_ + +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "common/type.h" +#include "common/unknown.h" +#include "common/value_kind.h" +#include "common/values/values.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/message.h" + +namespace cel { + +class Value; +class UnknownValue; + +// `UnknownValue` represents values of the primitive `duration` type. +class UnknownValue final : private common_internal::ValueMixin { + public: + static constexpr ValueKind kKind = ValueKind::kUnknown; + + explicit UnknownValue(Unknown unknown) : unknown_(std::move(unknown)) {} + + UnknownValue() = default; + UnknownValue(const UnknownValue&) = default; + UnknownValue(UnknownValue&&) = default; + UnknownValue& operator=(const UnknownValue&) = default; + UnknownValue& operator=(UnknownValue&&) = default; + + constexpr ValueKind kind() const { return kKind; } + + absl::string_view GetTypeName() const { return UnknownType::kName; } + + std::string DebugString() const { return ""; } + + // See Value::SerializeTo(). + absl::Status SerializeTo( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const; + + // See Value::ConvertToJson(). + absl::Status ConvertToJson( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const; + + absl::Status Equal(const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const; + using ValueMixin::Equal; + + bool IsZeroValue() const { return false; } + + void swap(UnknownValue& other) noexcept { + using std::swap; + swap(unknown_, other.unknown_); + } + + const Unknown& NativeValue() const& ABSL_ATTRIBUTE_LIFETIME_BOUND { + return unknown_; + } + + Unknown NativeValue() && { + Unknown unknown = std::move(unknown_); + return unknown; + } + + const AttributeSet& attribute_set() const { + return unknown_.unknown_attributes(); + } + + const FunctionResultSet& function_result_set() const { + return unknown_.unknown_function_results(); + } + + private: + friend class common_internal::ValueMixin; + + Unknown unknown_; +}; + +inline void swap(UnknownValue& lhs, UnknownValue& rhs) noexcept { + lhs.swap(rhs); +} + +inline std::ostream& operator<<(std::ostream& out, const UnknownValue& value) { + return out << value.DebugString(); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_UNKNOWN_VALUE_H_ diff --git a/common/values/unknown_value_test.cc b/common/values/unknown_value_test.cc new file mode 100644 index 000000000..4618574b7 --- /dev/null +++ b/common/values/unknown_value_test.cc @@ -0,0 +1,71 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "absl/status/status.h" +#include "common/native_type.h" +#include "common/value.h" +#include "common/value_testing.h" +#include "internal/testing.h" +#include "google/protobuf/io/zero_copy_stream_impl_lite.h" + +namespace cel { +namespace { + +using ::absl_testing::StatusIs; + +using UnknownValueTest = common_internal::ValueTest<>; + +TEST_F(UnknownValueTest, Kind) { + EXPECT_EQ(UnknownValue().kind(), UnknownValue::kKind); + EXPECT_EQ(Value(UnknownValue()).kind(), UnknownValue::kKind); +} + +TEST_F(UnknownValueTest, DebugString) { + { + std::ostringstream out; + out << UnknownValue(); + EXPECT_EQ(out.str(), ""); + } + { + std::ostringstream out; + out << Value(UnknownValue()); + EXPECT_EQ(out.str(), ""); + } +} + +TEST_F(UnknownValueTest, SerializeTo) { + google::protobuf::io::CordOutputStream output; + EXPECT_THAT( + UnknownValue().SerializeTo(descriptor_pool(), message_factory(), &output), + StatusIs(absl::StatusCode::kFailedPrecondition)); +} + +TEST_F(UnknownValueTest, ConvertToJson) { + auto* message = NewArenaValueMessage(); + EXPECT_THAT(UnknownValue().ConvertToJson(descriptor_pool(), message_factory(), + message), + StatusIs(absl::StatusCode::kFailedPrecondition)); +} + +TEST_F(UnknownValueTest, NativeTypeId) { + EXPECT_EQ(NativeTypeId::Of(UnknownValue()), + NativeTypeId::For()); + EXPECT_EQ(NativeTypeId::Of(Value(UnknownValue())), + NativeTypeId::For()); +} + +} // namespace +} // namespace cel diff --git a/common/values/value_builder.cc b/common/values/value_builder.cc new file mode 100644 index 000000000..979837411 --- /dev/null +++ b/common/values/value_builder.cc @@ -0,0 +1,1432 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/call_once.h" +#include "absl/base/casts.h" +#include "absl/base/nullability.h" +#include "absl/base/optimization.h" +#include "absl/container/flat_hash_map.h" +#include "absl/hash/hash.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "common/allocator.h" +#include "common/arena.h" +#include "common/legacy_value.h" +#include "common/native_type.h" +#include "common/type.h" +#include "common/value.h" +#include "common/value_kind.h" +#include "common/values/list_value_builder.h" +#include "common/values/map_value_builder.h" +#include "eval/public/cel_value.h" +#include "internal/casts.h" +#include "internal/manual.h" +#include "internal/status_macros.h" +#include "internal/well_known_types.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel { + +namespace common_internal { + +namespace { + +using ::cel::well_known_types::ListValueReflection; +using ::cel::well_known_types::StructReflection; +using ::cel::well_known_types::ValueReflection; +using ::google::api::expr::runtime::CelValue; + +using ValueVector = std::vector>; + +absl::Status CheckListElement(const Value& value) { + if (auto error_value = value.AsError(); ABSL_PREDICT_FALSE(error_value)) { + return error_value->ToStatus(); + } + if (auto unknown_value = value.AsUnknown(); + ABSL_PREDICT_FALSE(unknown_value)) { + return absl::InvalidArgumentError("cannot add unknown value to list"); + } + return absl::OkStatus(); +} + +template +absl::Status ListValueToJsonArray( + const Vector& vector, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(json != nullptr); + ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE); + + ListValueReflection reflection; + CEL_RETURN_IF_ERROR(reflection.Initialize(json->GetDescriptor())); + + json->Clear(); + + if (vector.empty()) { + return absl::OkStatus(); + } + + for (const auto& element : vector) { + CEL_RETURN_IF_ERROR(element->ConvertToJson(descriptor_pool, message_factory, + reflection.AddValues(json))); + } + return absl::OkStatus(); +} + +template +absl::Status ListValueToJson( + const Vector& vector, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(json != nullptr); + ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE); + + ValueReflection reflection; + CEL_RETURN_IF_ERROR(reflection.Initialize(json->GetDescriptor())); + return ListValueToJsonArray(vector, descriptor_pool, message_factory, + reflection.MutableListValue(json)); +} + +class CompatListValueImplIterator final : public ValueIterator { + public: + explicit CompatListValueImplIterator(absl::Span elements) + : elements_(elements) {} + + bool HasNext() override { return index_ < elements_.size(); } + + absl::Status Next(const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) override { + if (ABSL_PREDICT_FALSE(index_ >= elements_.size())) { + return absl::FailedPreconditionError( + "ValueManager::Next called after ValueManager::HasNext returned " + "false"); + } + *result = elements_[index_++]; + return absl::OkStatus(); + } + + absl::StatusOr Next1( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull key_or_value) override { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(key_or_value != nullptr); + + if (index_ >= elements_.size()) { + return false; + } + *key_or_value = elements_[index_]; + ++index_; + return true; + } + + absl::StatusOr Next2( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull key, + Value* absl_nullable value) override { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(key != nullptr); + + if (index_ >= elements_.size()) { + return false; + } + if (value != nullptr) { + *value = elements_[index_]; + } + *key = IntValue(index_++); + return true; + } + + private: + const absl::Span elements_; + size_t index_ = 0; +}; + +struct ValueFormatter { + void operator()(std::string* out, + const std::pair& value) const { + (*this)(out, value.first); + out->append(": "); + (*this)(out, value.second); + } + + void operator()(std::string* out, const Value& value) const { + out->append(value.DebugString()); + } +}; + +class ListValueBuilderImpl final : public ListValueBuilder { + public: + explicit ListValueBuilderImpl(google::protobuf::Arena* absl_nonnull arena) + : arena_(arena) { + elements_.Construct(arena); + } + + ~ListValueBuilderImpl() override { + if (!elements_trivially_destructible_) { + elements_.Destruct(); + } + } + + absl::Status Add(Value value) override { + CEL_RETURN_IF_ERROR(CheckListElement(value)); + UnsafeAdd(std::move(value)); + return absl::OkStatus(); + } + + void UnsafeAdd(Value value) override { + ABSL_DCHECK_OK(CheckListElement(value)); + elements_->emplace_back(std::move(value)); + if (elements_trivially_destructible_) { + elements_trivially_destructible_ = + ArenaTraits<>::trivially_destructible(elements_->back()); + } + } + + size_t Size() const override { return elements_->size(); } + + void Reserve(size_t capacity) override { elements_->reserve(capacity); } + + ListValue Build() && override; + + CustomListValue BuildCustom() &&; + + const CompatListValue* absl_nonnull BuildCompat() &&; + + const CompatListValue* absl_nonnull BuildCompatAt( + void* absl_nonnull address) &&; + + private: + google::protobuf::Arena* absl_nonnull const arena_; + internal::Manual elements_; + bool elements_trivially_destructible_ = true; +}; + +class CompatListValueImpl final : public CompatListValue { + public: + explicit CompatListValueImpl(ValueVector&& elements) + : elements_(std::move(elements)) {} + + std::string DebugString() const override { + return absl::StrCat("[", absl::StrJoin(elements_, ", ", ValueFormatter{}), + "]"); + } + + absl::Status ConvertToJsonArray( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const override { + return ListValueToJsonArray(elements_, descriptor_pool, message_factory, + json); + } + + CustomListValue Clone(google::protobuf::Arena* absl_nonnull arena) const override { + ABSL_DCHECK(arena != nullptr); + + ListValueBuilderImpl builder(arena); + builder.Reserve(elements_.size()); + for (const auto& element : elements_) { + builder.UnsafeAdd(element.Clone(arena)); + } + return std::move(builder).BuildCustom(); + } + + size_t Size() const override { return elements_.size(); } + + absl::Status ForEach( + ForEachWithIndexCallback callback, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const override { + const size_t size = elements_.size(); + for (size_t i = 0; i < size; ++i) { + CEL_ASSIGN_OR_RETURN(auto ok, callback(i, elements_[i])); + if (!ok) { + break; + } + } + return absl::OkStatus(); + } + + absl::StatusOr NewIterator() const override { + return std::make_unique( + absl::MakeConstSpan(elements_)); + } + + CelValue operator[](int index) const override { + return Get(elements_.get_allocator().arena(), index); + } + + // Like `operator[](int)` above, but also accepts an arena. Prefer calling + // this variant if the arena is known. + CelValue Get(google::protobuf::Arena* arena, int index) const override { + if (arena == nullptr) { + arena = elements_.get_allocator().arena(); + } + if (ABSL_PREDICT_FALSE(index < 0 || index >= size())) { + return CelValue::CreateError(google::protobuf::Arena::Create( + arena, IndexOutOfBoundsError(index).ToStatus())); + } + return common_internal::UnsafeLegacyValue( + elements_[index], + /*stable=*/true, + arena != nullptr ? arena : elements_.get_allocator().arena()); + } + + int size() const override { return static_cast(Size()); } + + protected: + absl::Status Get(size_t index, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const override { + if (index >= elements_.size()) { + *result = IndexOutOfBoundsError(index); + } else { + *result = elements_[index]; + } + return absl::OkStatus(); + } + + private: + const ValueVector elements_; +}; + +} // namespace + +} // namespace common_internal + +template <> +struct ArenaTraits { + using always_trivially_destructible = std::true_type; +}; + +namespace common_internal { + +namespace { + +ListValue ListValueBuilderImpl::Build() && { + if (elements_->empty()) { + return ListValue(); + } + return std::move(*this).BuildCustom(); +} + +CustomListValue ListValueBuilderImpl::BuildCustom() && { + if (elements_->empty()) { + return CustomListValue(EmptyCompatListValue(), arena_); + } + return CustomListValue(std::move(*this).BuildCompat(), arena_); +} + +const CompatListValue* absl_nonnull ListValueBuilderImpl::BuildCompat() && { + if (elements_->empty()) { + return EmptyCompatListValue(); + } + return std::move(*this).BuildCompatAt(arena_->AllocateAligned( + sizeof(CompatListValueImpl), alignof(CompatListValueImpl))); +} + +const CompatListValue* absl_nonnull ListValueBuilderImpl::BuildCompatAt( + void* absl_nonnull address) && { + CompatListValueImpl* absl_nonnull impl = + ::new (address) CompatListValueImpl(std::move(*elements_)); + if (!elements_trivially_destructible_) { + arena_->OwnDestructor(impl); + elements_trivially_destructible_ = true; + } + return impl; +} + +class MutableCompatListValueImpl final : public MutableCompatListValue { + public: + explicit MutableCompatListValueImpl(google::protobuf::Arena* absl_nonnull arena) + : elements_(arena) {} + + std::string DebugString() const override { + return absl::StrCat("[", absl::StrJoin(elements_, ", ", ValueFormatter{}), + "]"); + } + + absl::Status ConvertToJsonArray( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const override { + return ListValueToJsonArray(elements_, descriptor_pool, message_factory, + json); + } + + CustomListValue Clone(google::protobuf::Arena* absl_nonnull arena) const override { + ABSL_DCHECK(arena != nullptr); + + ListValueBuilderImpl builder(arena); + builder.Reserve(elements_.size()); + for (const auto& element : elements_) { + builder.UnsafeAdd(element.Clone(arena)); + } + return std::move(builder).BuildCustom(); + } + + size_t Size() const override { return elements_.size(); } + + absl::Status ForEach( + ForEachWithIndexCallback callback, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const override { + const size_t size = elements_.size(); + for (size_t i = 0; i < size; ++i) { + CEL_ASSIGN_OR_RETURN(auto ok, callback(i, elements_[i])); + if (!ok) { + break; + } + } + return absl::OkStatus(); + } + + absl::StatusOr NewIterator() const override { + return std::make_unique( + absl::MakeConstSpan(elements_)); + } + + CelValue operator[](int index) const override { + return Get(elements_.get_allocator().arena(), index); + } + + // Like `operator[](int)` above, but also accepts an arena. Prefer calling + // this variant if the arena is known. + CelValue Get(google::protobuf::Arena* arena, int index) const override { + if (arena == nullptr) { + arena = elements_.get_allocator().arena(); + } + if (ABSL_PREDICT_FALSE(index < 0 || index >= size())) { + return CelValue::CreateError(google::protobuf::Arena::Create( + arena, IndexOutOfBoundsError(index).ToStatus())); + } + return common_internal::UnsafeLegacyValue( + elements_[index], /*stable=*/false, + arena != nullptr ? arena : elements_.get_allocator().arena()); + } + + int size() const override { return static_cast(Size()); } + + absl::Status Append(Value value) const override { + CEL_RETURN_IF_ERROR(CheckListElement(value)); + elements_.emplace_back(std::move(value)); + if (elements_trivially_destructible_) { + elements_trivially_destructible_ = + ArenaTraits<>::trivially_destructible(elements_.back()); + if (!elements_trivially_destructible_) { + elements_.get_allocator().arena()->OwnDestructor( + const_cast(this)); + } + } + return absl::OkStatus(); + } + + void Reserve(size_t capacity) const override { elements_.reserve(capacity); } + + protected: + absl::Status Get(size_t index, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const override { + if (index >= elements_.size()) { + *result = IndexOutOfBoundsError(index); + } else { + *result = elements_[index]; + } + return absl::OkStatus(); + } + + private: + mutable ValueVector elements_; + mutable bool elements_trivially_destructible_ = true; +}; + +} // namespace + +} // namespace common_internal + +template <> +struct ArenaTraits { + using constructible = std::true_type; + + using always_trivially_destructible = std::true_type; +}; + +namespace common_internal { + +namespace {} // namespace + +absl::StatusOr MakeCompatListValue( + const CustomListValue& value, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + ListValueBuilderImpl builder(arena); + builder.Reserve(value.Size()); + + CEL_RETURN_IF_ERROR(value.ForEach( + [&](const Value& element) -> absl::StatusOr { + CEL_RETURN_IF_ERROR(builder.Add(element)); + return true; + }, + descriptor_pool, message_factory, arena)); + + return std::move(builder).BuildCompat(); +} + +MutableListValue* absl_nonnull NewMutableListValue( + google::protobuf::Arena* absl_nonnull arena) { + return ::new (arena->AllocateAligned(sizeof(MutableCompatListValueImpl), + alignof(MutableCompatListValueImpl))) + MutableCompatListValueImpl(arena); +} + +bool IsMutableListValue(const Value& value) { + if (auto custom_list_value = value.AsCustomList(); custom_list_value) { + NativeTypeId native_type_id = custom_list_value->GetTypeId(); + if (native_type_id == NativeTypeId::For() || + native_type_id == NativeTypeId::For()) { + return true; + } + } + return false; +} + +bool IsMutableListValue(const ListValue& value) { + if (auto custom_list_value = value.AsCustom(); custom_list_value) { + NativeTypeId native_type_id = custom_list_value->GetTypeId(); + if (native_type_id == NativeTypeId::For() || + native_type_id == NativeTypeId::For()) { + return true; + } + } + return false; +} + +const MutableListValue* absl_nullable AsMutableListValue(const Value& value) { + if (auto custom_list_value = value.AsCustomList(); custom_list_value) { + NativeTypeId native_type_id = custom_list_value->GetTypeId(); + if (native_type_id == NativeTypeId::For()) { + return cel::internal::down_cast( + custom_list_value->interface()); + } + if (native_type_id == NativeTypeId::For()) { + return cel::internal::down_cast( + custom_list_value->interface()); + } + } + return nullptr; +} + +const MutableListValue* absl_nullable AsMutableListValue( + const ListValue& value) { + if (auto custom_list_value = value.AsCustom(); custom_list_value) { + NativeTypeId native_type_id = custom_list_value->GetTypeId(); + if (native_type_id == NativeTypeId::For()) { + return cel::internal::down_cast( + custom_list_value->interface()); + } + if (native_type_id == NativeTypeId::For()) { + return cel::internal::down_cast( + custom_list_value->interface()); + } + } + return nullptr; +} + +const MutableListValue& GetMutableListValue(const Value& value) { + ABSL_DCHECK(IsMutableListValue(value)) << value; + const auto& custom_list_value = value.GetCustomList(); + NativeTypeId native_type_id = custom_list_value.GetTypeId(); + if (native_type_id == NativeTypeId::For()) { + return cel::internal::down_cast( + *custom_list_value.interface()); + } + if (native_type_id == NativeTypeId::For()) { + return cel::internal::down_cast( + *custom_list_value.interface()); + } + ABSL_UNREACHABLE(); +} + +const MutableListValue& GetMutableListValue(const ListValue& value) { + ABSL_DCHECK(IsMutableListValue(value)) << value; + const auto& custom_list_value = value.GetCustom(); + NativeTypeId native_type_id = custom_list_value.GetTypeId(); + if (native_type_id == NativeTypeId::For()) { + return cel::internal::down_cast( + *custom_list_value.interface()); + } + if (native_type_id == NativeTypeId::For()) { + return cel::internal::down_cast( + *custom_list_value.interface()); + } + ABSL_UNREACHABLE(); +} + +absl_nonnull cel::ListValueBuilderPtr NewListValueBuilder( + google::protobuf::Arena* absl_nonnull arena) { + return std::make_unique(arena); +} + +} // namespace common_internal + +} // namespace cel + +namespace cel { + +namespace common_internal { + +namespace { + +using ::google::api::expr::runtime::CelList; +using ::google::api::expr::runtime::CelValue; + +absl::Status CheckMapValue(const Value& value) { + if (auto error_value = value.AsError(); ABSL_PREDICT_FALSE(error_value)) { + return error_value->ToStatus(); + } + if (auto unknown_value = value.AsUnknown(); + ABSL_PREDICT_FALSE(unknown_value)) { + return absl::InvalidArgumentError("cannot add unknown value to list"); + } + return absl::OkStatus(); +} + +size_t ValueHash(const Value& value) { + switch (value.kind()) { + case ValueKind::kBool: + return absl::HashOf(value.kind(), value.GetBool()); + case ValueKind::kInt: + return absl::HashOf(ValueKind::kInt, + absl::implicit_cast(value.GetInt())); + case ValueKind::kUint: + return absl::HashOf(ValueKind::kUint, + absl::implicit_cast(value.GetUint())); + case ValueKind::kString: + return absl::HashOf(value.kind(), value.GetString()); + default: + ABSL_UNREACHABLE(); + } +} + +size_t ValueHash(const CelValue& value) { + switch (value.type()) { + case CelValue::Type::kBool: + return absl::HashOf(ValueKind::kBool, value.BoolOrDie()); + case CelValue::Type::kInt: + return absl::HashOf(ValueKind::kInt, value.Int64OrDie()); + case CelValue::Type::kUint: + return absl::HashOf(ValueKind::kUint, value.Uint64OrDie()); + case CelValue::Type::kString: + return absl::HashOf(ValueKind::kString, value.StringOrDie().value()); + default: + ABSL_UNREACHABLE(); + } +} + +bool ValueEquals(const Value& lhs, const Value& rhs) { + switch (lhs.kind()) { + case ValueKind::kBool: + switch (rhs.kind()) { + case ValueKind::kBool: + return lhs.GetBool() == rhs.GetBool(); + case ValueKind::kInt: + return false; + case ValueKind::kUint: + return false; + case ValueKind::kString: + return false; + default: + ABSL_UNREACHABLE(); + } + case ValueKind::kInt: + switch (rhs.kind()) { + case ValueKind::kBool: + return false; + case ValueKind::kInt: + return lhs.GetInt() == rhs.GetInt(); + case ValueKind::kUint: + return false; + case ValueKind::kString: + return false; + default: + ABSL_UNREACHABLE(); + } + case ValueKind::kUint: + switch (rhs.kind()) { + case ValueKind::kBool: + return false; + case ValueKind::kInt: + return false; + case ValueKind::kUint: + return lhs.GetUint() == rhs.GetUint(); + case ValueKind::kString: + return false; + default: + ABSL_UNREACHABLE(); + } + case ValueKind::kString: + switch (rhs.kind()) { + case ValueKind::kBool: + return false; + case ValueKind::kInt: + return false; + case ValueKind::kUint: + return false; + case ValueKind::kString: + return lhs.GetString() == rhs.GetString(); + default: + ABSL_UNREACHABLE(); + } + default: + ABSL_UNREACHABLE(); + } +} + +bool CelValueEquals(const CelValue& lhs, const Value& rhs) { + switch (lhs.type()) { + case CelValue::Type::kBool: + switch (rhs.kind()) { + case ValueKind::kBool: + return BoolValue(lhs.BoolOrDie()) == rhs.GetBool(); + case ValueKind::kInt: + return false; + case ValueKind::kUint: + return false; + case ValueKind::kString: + return false; + default: + ABSL_UNREACHABLE(); + } + case CelValue::Type::kInt: + switch (rhs.kind()) { + case ValueKind::kBool: + return false; + case ValueKind::kInt: + return IntValue(lhs.Int64OrDie()) == rhs.GetInt(); + case ValueKind::kUint: + return false; + case ValueKind::kString: + return false; + default: + ABSL_UNREACHABLE(); + } + case CelValue::Type::kUint: + switch (rhs.kind()) { + case ValueKind::kBool: + return false; + case ValueKind::kInt: + return false; + case ValueKind::kUint: + return UintValue(lhs.Uint64OrDie()) == rhs.GetUint(); + case ValueKind::kString: + return false; + default: + ABSL_UNREACHABLE(); + } + case CelValue::Type::kString: + switch (rhs.kind()) { + case ValueKind::kBool: + return false; + case ValueKind::kInt: + return false; + case ValueKind::kUint: + return false; + case ValueKind::kString: + return rhs.GetString().Equals(lhs.StringOrDie().value()); + default: + ABSL_UNREACHABLE(); + } + default: + ABSL_UNREACHABLE(); + } +} + +absl::StatusOr ValueToJsonString(const Value& value) { + switch (value.kind()) { + case ValueKind::kString: + return value.GetString().NativeString(); + default: + return TypeConversionError(value.GetRuntimeType(), StringType()) + .ToStatus(); + } +} + +template +absl::Status MapValueToJsonObject( + const Map& map, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(json != nullptr); + ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT); + + StructReflection reflection; + CEL_RETURN_IF_ERROR(reflection.Initialize(json->GetDescriptor())); + + json->Clear(); + + if (map.empty()) { + return absl::OkStatus(); + } + + for (const auto& entry : map) { + CEL_ASSIGN_OR_RETURN(auto key, ValueToJsonString(entry.first)); + CEL_RETURN_IF_ERROR(entry.second.ConvertToJson( + descriptor_pool, message_factory, reflection.InsertField(json, key))); + } + return absl::OkStatus(); +} + +template +absl::Status MapValueToJson( + const Map& map, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(json != nullptr); + ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE); + + ValueReflection reflection; + CEL_RETURN_IF_ERROR(reflection.Initialize(json->GetDescriptor())); + return MapValueToJsonObject(map, descriptor_pool, message_factory, + reflection.MutableStructValue(json)); +} + +struct ValueHasher { + using is_transparent = void; + + size_t operator()(const Value& value) const { return (ValueHash)(value); } + + size_t operator()(const CelValue& value) const { return (ValueHash)(value); } +}; + +struct ValueEqualer { + using is_transparent = void; + + bool operator()(const Value& lhs, const CelValue& rhs) const { + return (*this)(rhs, lhs); + } + + bool operator()(const CelValue& lhs, const Value& rhs) const { + return (CelValueEquals)(lhs, rhs); + } + + bool operator()(const Value& lhs, const Value& rhs) const { + return (ValueEquals)(lhs, rhs); + } +}; + +using ValueFlatHashMapAllocator = ArenaAllocator>; + +using ValueFlatHashMap = + absl::flat_hash_map; + +class CompatMapValueImplIterator final : public ValueIterator { + public: + explicit CompatMapValueImplIterator(const ValueFlatHashMap* absl_nonnull map) + : begin_(map->begin()), end_(map->end()) {} + + bool HasNext() override { return begin_ != end_; } + + absl::Status Next(const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) override { + if (ABSL_PREDICT_FALSE(begin_ == end_)) { + return absl::FailedPreconditionError( + "ValueManager::Next called after ValueManager::HasNext returned " + "false"); + } + *result = begin_->first; + ++begin_; + return absl::OkStatus(); + } + + absl::StatusOr Next1( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull key_or_value) override { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(key_or_value != nullptr); + + if (begin_ == end_) { + return false; + } + *key_or_value = begin_->first; + ++begin_; + return true; + } + + absl::StatusOr Next2( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull key, + Value* absl_nullable value) override { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(key != nullptr); + + if (begin_ == end_) { + return false; + } + *key = begin_->first; + if (value != nullptr) { + *value = begin_->second; + } + ++begin_; + return true; + } + + private: + typename ValueFlatHashMap::const_iterator begin_; + const typename ValueFlatHashMap::const_iterator end_; +}; + +class MapValueBuilderImpl final : public MapValueBuilder { + public: + explicit MapValueBuilderImpl(google::protobuf::Arena* absl_nonnull arena) + : arena_(arena) { + map_.Construct(arena_); + } + + ~MapValueBuilderImpl() override { + if (!entries_trivially_destructible_) { + map_.Destruct(); + } + } + + absl::Status Put(Value key, Value value) override { + CEL_RETURN_IF_ERROR(CheckMapKey(key)); + CEL_RETURN_IF_ERROR(CheckMapValue(value)); + if (auto it = map_->find(key); ABSL_PREDICT_FALSE(it != map_->end())) { + return DuplicateKeyError().ToStatus(); + } + UnsafePut(std::move(key), std::move(value)); + return absl::OkStatus(); + } + + void UnsafePut(Value key, Value value) override { + auto insertion = map_->insert({std::move(key), std::move(value)}); + ABSL_DCHECK(insertion.second); + if (entries_trivially_destructible_) { + entries_trivially_destructible_ = + ArenaTraits<>::trivially_destructible(insertion.first->first) && + ArenaTraits<>::trivially_destructible(insertion.first->second); + } + } + + size_t Size() const override { return map_->size(); } + + void Reserve(size_t capacity) override { map_->reserve(capacity); } + + MapValue Build() && override; + + CustomMapValue BuildCustom() &&; + + const CompatMapValue* absl_nonnull BuildCompat() &&; + + private: + google::protobuf::Arena* absl_nonnull const arena_; + internal::Manual map_; + bool entries_trivially_destructible_ = true; +}; + +class CompatMapValueImpl final : public CompatMapValue { + public: + explicit CompatMapValueImpl(ValueFlatHashMap&& map) : map_(std::move(map)) {} + + std::string DebugString() const override { + return absl::StrCat("{", absl::StrJoin(map_, ", ", ValueFormatter{}), "}"); + } + + absl::Status ConvertToJsonObject( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const override { + return MapValueToJsonObject(map_, descriptor_pool, message_factory, json); + } + + CustomMapValue Clone(google::protobuf::Arena* absl_nonnull arena) const override { + ABSL_DCHECK(arena != nullptr); + + MapValueBuilderImpl builder(arena); + builder.Reserve(map_.size()); + for (const auto& entry : map_) { + builder.UnsafePut(entry.first.Clone(arena), entry.second.Clone(arena)); + } + return std::move(builder).BuildCustom(); + } + + size_t Size() const override { return map_.size(); } + + absl::Status ListKeys( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + ListValue* absl_nonnull result) const override { + *result = CustomListValue(ProjectKeys(), map_.get_allocator().arena()); + return absl::OkStatus(); + } + + absl::Status ForEach( + ForEachCallback callback, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const override { + for (const auto& entry : map_) { + CEL_ASSIGN_OR_RETURN(auto ok, callback(entry.first, entry.second)); + if (!ok) { + break; + } + } + return absl::OkStatus(); + } + + absl::StatusOr NewIterator() const override { + return std::make_unique(&map_); + } + + absl::optional operator[](CelValue key) const override { + return Get(map_.get_allocator().arena(), key); + } + + using CompatMapValue::Get; + absl::optional Get(google::protobuf::Arena* arena, + CelValue key) const override { + if (auto status = CelValue::CheckMapKeyType(key); !status.ok()) { + status.IgnoreError(); + return absl::nullopt; + } + if (auto it = map_.find(key); it != map_.end()) { + return common_internal::UnsafeLegacyValue( + it->second, /*stable=*/true, + arena != nullptr ? arena : map_.get_allocator().arena()); + } + return absl::nullopt; + } + + absl::StatusOr Has(const CelValue& key) const override { + // This check safeguards against issues with invalid key types such as NaN. + CEL_RETURN_IF_ERROR(CelValue::CheckMapKeyType(key)); + return map_.find(key) != map_.end(); + } + + int size() const override { return static_cast(Size()); } + + absl::StatusOr ListKeys() const override { + return ProjectKeys(); + } + + absl::StatusOr ListKeys(google::protobuf::Arena* arena) const override { + return ProjectKeys(); + } + + protected: + absl::StatusOr Find( + const Value& key, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const override { + CEL_RETURN_IF_ERROR(CheckMapKey(key)); + if (auto it = map_.find(key); it != map_.end()) { + *result = it->second; + return true; + } + return false; + } + + absl::StatusOr Has( + const Value& key, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const override { + CEL_RETURN_IF_ERROR(CheckMapKey(key)); + return map_.find(key) != map_.end(); + } + + private: + const CompatListValue* absl_nonnull ProjectKeys() const { + absl::call_once(keys_once_, [this]() { + ListValueBuilderImpl builder(map_.get_allocator().arena()); + builder.Reserve(map_.size()); + + for (const auto& entry : map_) { + builder.UnsafeAdd(entry.first); + } + + std::move(builder).BuildCompatAt(&keys_[0]); + }); + return std::launder( + reinterpret_cast(&keys_[0])); + } + + const ValueFlatHashMap map_; + mutable absl::once_flag keys_once_; + alignas(CompatListValueImpl) mutable char keys_[sizeof(CompatListValueImpl)]; +}; + +MapValue MapValueBuilderImpl::Build() && { + if (map_->empty()) { + return MapValue(); + } + return std::move(*this).BuildCustom(); +} + +CustomMapValue MapValueBuilderImpl::BuildCustom() && { + if (map_->empty()) { + return CustomMapValue(EmptyCompatMapValue(), arena_); + } + return CustomMapValue(std::move(*this).BuildCompat(), arena_); +} + +const CompatMapValue* absl_nonnull MapValueBuilderImpl::BuildCompat() && { + if (map_->empty()) { + return EmptyCompatMapValue(); + } + CompatMapValueImpl* absl_nonnull impl = ::new (arena_->AllocateAligned( + sizeof(CompatMapValueImpl), alignof(CompatMapValueImpl))) + CompatMapValueImpl(std::move(*map_)); + if (!entries_trivially_destructible_) { + arena_->OwnDestructor(impl); + entries_trivially_destructible_ = true; + } + return impl; +} + +class TrivialMutableMapValueImpl final : public MutableCompatMapValue { + public: + explicit TrivialMutableMapValueImpl(google::protobuf::Arena* absl_nonnull arena) + : map_(arena) {} + + std::string DebugString() const override { + return absl::StrCat("{", absl::StrJoin(map_, ", ", ValueFormatter{}), "}"); + } + + absl::Status ConvertToJsonObject( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull json) const override { + return MapValueToJsonObject(map_, descriptor_pool, message_factory, json); + } + + CustomMapValue Clone(google::protobuf::Arena* absl_nonnull arena) const override { + ABSL_DCHECK(arena != nullptr); + + MapValueBuilderImpl builder(arena); + builder.Reserve(map_.size()); + for (const auto& entry : map_) { + builder.UnsafePut(entry.first.Clone(arena), entry.second.Clone(arena)); + } + return std::move(builder).BuildCustom(); + } + + size_t Size() const override { return map_.size(); } + + absl::Status ListKeys( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + ListValue* absl_nonnull result) const override { + *result = CustomListValue(ProjectKeys(), map_.get_allocator().arena()); + return absl::OkStatus(); + } + + absl::Status ForEach( + ForEachCallback callback, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const override { + for (const auto& entry : map_) { + CEL_ASSIGN_OR_RETURN(auto ok, callback(entry.first, entry.second)); + if (!ok) { + break; + } + } + return absl::OkStatus(); + } + + absl::StatusOr NewIterator() const override { + return std::make_unique(&map_); + } + + absl::optional operator[](CelValue key) const override { + return Get(map_.get_allocator().arena(), key); + } + + using MutableCompatMapValue::Get; + absl::optional Get(google::protobuf::Arena* arena, + CelValue key) const override { + if (auto status = CelValue::CheckMapKeyType(key); !status.ok()) { + status.IgnoreError(); + return absl::nullopt; + } + if (auto it = map_.find(key); it != map_.end()) { + return common_internal::UnsafeLegacyValue( + it->second, /*stable=*/false, + arena != nullptr ? arena : map_.get_allocator().arena()); + } + return absl::nullopt; + } + + absl::StatusOr Has(const CelValue& key) const override { + // This check safeguards against issues with invalid key types such as NaN. + CEL_RETURN_IF_ERROR(CelValue::CheckMapKeyType(key)); + return map_.find(key) != map_.end(); + } + + int size() const override { return static_cast(Size()); } + + absl::StatusOr ListKeys() const override { + return ProjectKeys(); + } + + absl::StatusOr ListKeys(google::protobuf::Arena* arena) const override { + return ProjectKeys(); + } + + absl::Status Put(Value key, Value value) const override { + CEL_RETURN_IF_ERROR(CheckMapKey(key)); + CEL_RETURN_IF_ERROR(CheckMapValue(value)); + if (auto it = map_.find(key); ABSL_PREDICT_FALSE(it != map_.end())) { + return DuplicateKeyError().ToStatus(); + } + auto insertion = map_.insert({std::move(key), std::move(value)}); + ABSL_DCHECK(insertion.second); + if (entries_trivially_destructible_) { + entries_trivially_destructible_ = + ArenaTraits<>::trivially_destructible(insertion.first->first) && + ArenaTraits<>::trivially_destructible(insertion.first->second); + if (!entries_trivially_destructible_) { + map_.get_allocator().arena()->OwnDestructor( + const_cast(this)); + } + } + return absl::OkStatus(); + } + + void Reserve(size_t capacity) const override { map_.reserve(capacity); } + + protected: + absl::StatusOr Find( + const Value& key, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const override { + CEL_RETURN_IF_ERROR(CheckMapKey(key)); + if (auto it = map_.find(key); it != map_.end()) { + *result = it->second; + return true; + } + return false; + } + + absl::StatusOr Has( + const Value& key, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const override { + CEL_RETURN_IF_ERROR(CheckMapKey(key)); + return map_.find(key) != map_.end(); + } + + private: + const CompatListValue* absl_nonnull ProjectKeys() const { + absl::call_once(keys_once_, [this]() { + ListValueBuilderImpl builder(map_.get_allocator().arena()); + builder.Reserve(map_.size()); + + for (const auto& entry : map_) { + builder.UnsafeAdd(entry.first); + } + + std::move(builder).BuildCompatAt(&keys_[0]); + }); + return std::launder( + reinterpret_cast(&keys_[0])); + } + + mutable ValueFlatHashMap map_; + mutable bool entries_trivially_destructible_ = true; + mutable absl::once_flag keys_once_; + alignas(CompatListValueImpl) mutable char keys_[sizeof(CompatListValueImpl)]; +}; + +} // namespace + +absl::StatusOr MakeCompatMapValue( + const CustomMapValue& value, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + MapValueBuilderImpl builder(arena); + builder.Reserve(value.Size()); + + CEL_RETURN_IF_ERROR(value.ForEach( + [&](const Value& key, const Value& value) -> absl::StatusOr { + CEL_RETURN_IF_ERROR(builder.Put(key, value)); + return true; + }, + descriptor_pool, message_factory, arena)); + + return std::move(builder).BuildCompat(); +} + +MutableMapValue* absl_nonnull NewMutableMapValue( + google::protobuf::Arena* absl_nonnull arena) { + return ::new (arena->AllocateAligned(sizeof(TrivialMutableMapValueImpl), + alignof(TrivialMutableMapValueImpl))) + TrivialMutableMapValueImpl(arena); +} + +bool IsMutableMapValue(const Value& value) { + if (auto custom_map_value = value.AsCustomMap(); custom_map_value) { + NativeTypeId native_type_id = custom_map_value->GetTypeId(); + if (native_type_id == NativeTypeId::For() || + native_type_id == NativeTypeId::For()) { + return true; + } + } + return false; +} + +bool IsMutableMapValue(const MapValue& value) { + if (auto custom_map_value = value.AsCustom(); custom_map_value) { + NativeTypeId native_type_id = custom_map_value->GetTypeId(); + if (native_type_id == NativeTypeId::For() || + native_type_id == NativeTypeId::For()) { + return true; + } + } + return false; +} + +const MutableMapValue* absl_nullable AsMutableMapValue(const Value& value) { + if (auto custom_map_value = value.AsCustomMap(); custom_map_value) { + NativeTypeId native_type_id = custom_map_value->GetTypeId(); + if (native_type_id == NativeTypeId::For()) { + return cel::internal::down_cast( + custom_map_value->interface()); + } + if (native_type_id == NativeTypeId::For()) { + return cel::internal::down_cast( + custom_map_value->interface()); + } + } + return nullptr; +} + +const MutableMapValue* absl_nullable AsMutableMapValue(const MapValue& value) { + if (auto custom_map_value = value.AsCustom(); custom_map_value) { + NativeTypeId native_type_id = custom_map_value->GetTypeId(); + if (native_type_id == NativeTypeId::For()) { + return cel::internal::down_cast( + custom_map_value->interface()); + } + if (native_type_id == NativeTypeId::For()) { + return cel::internal::down_cast( + custom_map_value->interface()); + } + } + return nullptr; +} + +const MutableMapValue& GetMutableMapValue(const Value& value) { + ABSL_DCHECK(IsMutableMapValue(value)) << value; + const auto& custom_map_value = value.GetCustomMap(); + NativeTypeId native_type_id = custom_map_value.GetTypeId(); + if (native_type_id == NativeTypeId::For()) { + return cel::internal::down_cast( + *custom_map_value.interface()); + } + if (native_type_id == NativeTypeId::For()) { + return cel::internal::down_cast( + *custom_map_value.interface()); + } + ABSL_UNREACHABLE(); +} + +const MutableMapValue& GetMutableMapValue(const MapValue& value) { + ABSL_DCHECK(IsMutableMapValue(value)) << value; + const auto& custom_map_value = value.GetCustom(); + NativeTypeId native_type_id = custom_map_value.GetTypeId(); + if (native_type_id == NativeTypeId::For()) { + return cel::internal::down_cast( + *custom_map_value.interface()); + } + if (native_type_id == NativeTypeId::For()) { + return cel::internal::down_cast( + *custom_map_value.interface()); + } + ABSL_UNREACHABLE(); +} + +absl_nonnull cel::MapValueBuilderPtr NewMapValueBuilder( + google::protobuf::Arena* absl_nonnull arena) { + return std::make_unique(arena); +} + +} // namespace common_internal + +} // namespace cel diff --git a/common/values/value_builder.h b/common/values/value_builder.h new file mode 100644 index 000000000..685b13dd8 --- /dev/null +++ b/common/values/value_builder.h @@ -0,0 +1,36 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_VALUE_BUILDER_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_VALUE_BUILDER_H_ + +#include "absl/base/nullability.h" +#include "absl/strings/string_view.h" +#include "common/allocator.h" +#include "common/value.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel::common_internal { + +// Like NewStructValueBuilder, but deals with well known types. +absl_nullable cel::ValueBuilderPtr NewValueBuilder( + Allocator<> allocator, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + absl::string_view name); + +} // namespace cel::common_internal + +#endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_VALUE_BUILDER_H_ diff --git a/common/values/value_variant.cc b/common/values/value_variant.cc new file mode 100644 index 000000000..1c287239c --- /dev/null +++ b/common/values/value_variant.cc @@ -0,0 +1,537 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/values/value_variant.h" + +#include +#include +#include +#include + +#include "absl/base/optimization.h" +#include "absl/log/absl_check.h" +#include "common/values/bytes_value.h" +#include "common/values/error_value.h" +#include "common/values/string_value.h" +#include "common/values/unknown_value.h" +#include "common/values/values.h" + +namespace cel::common_internal { + +void ValueVariant::SlowCopyConstruct(const ValueVariant& other) noexcept { + ABSL_DCHECK((flags_ & ValueFlags::kNonTrivial) == ValueFlags::kNonTrivial); + + switch (index_) { + case ValueIndex::kBytes: + ::new (static_cast(&raw_[0])) BytesValue(*other.At()); + break; + case ValueIndex::kString: + ::new (static_cast(&raw_[0])) + StringValue(*other.At()); + break; + case ValueIndex::kError: + ::new (static_cast(&raw_[0])) ErrorValue(*other.At()); + break; + case ValueIndex::kUnknown: + ::new (static_cast(&raw_[0])) + UnknownValue(*other.At()); + break; + default: + ABSL_UNREACHABLE(); + } +} + +void ValueVariant::SlowMoveConstruct(ValueVariant& other) noexcept { + ABSL_DCHECK((flags_ & ValueFlags::kNonTrivial) == ValueFlags::kNonTrivial); + + switch (index_) { + case ValueIndex::kBytes: + ::new (static_cast(&raw_[0])) + BytesValue(std::move(*other.At())); + break; + case ValueIndex::kString: + ::new (static_cast(&raw_[0])) + StringValue(std::move(*other.At())); + break; + case ValueIndex::kError: + ::new (static_cast(&raw_[0])) + ErrorValue(std::move(*other.At())); + break; + case ValueIndex::kUnknown: + ::new (static_cast(&raw_[0])) + UnknownValue(std::move(*other.At())); + break; + default: + ABSL_UNREACHABLE(); + } +} + +void ValueVariant::SlowDestruct() noexcept { + ABSL_DCHECK((flags_ & ValueFlags::kNonTrivial) == ValueFlags::kNonTrivial); + + switch (index_) { + case ValueIndex::kBytes: + At()->~BytesValue(); + break; + case ValueIndex::kString: + At()->~StringValue(); + break; + case ValueIndex::kError: + At()->~ErrorValue(); + break; + case ValueIndex::kUnknown: + At()->~UnknownValue(); + break; + default: + ABSL_UNREACHABLE(); + } +} + +void ValueVariant::SlowCopyAssign(const ValueVariant& other, bool trivial, + bool other_trivial) noexcept { + ABSL_DCHECK(!trivial || !other_trivial); + + if (trivial) { + switch (other.index_) { + case ValueIndex::kBytes: + ::new (static_cast(&raw_[0])) + BytesValue(*other.At()); + break; + case ValueIndex::kString: + ::new (static_cast(&raw_[0])) + StringValue(*other.At()); + break; + case ValueIndex::kError: + ::new (static_cast(&raw_[0])) + ErrorValue(*other.At()); + break; + case ValueIndex::kUnknown: + ::new (static_cast(&raw_[0])) + UnknownValue(*other.At()); + break; + default: + ABSL_UNREACHABLE(); + } + index_ = other.index_; + kind_ = other.kind_; + flags_ = other.flags_; + } else if (other_trivial) { + switch (index_) { + case ValueIndex::kBytes: + At()->~BytesValue(); + break; + case ValueIndex::kString: + At()->~StringValue(); + break; + case ValueIndex::kError: + At()->~ErrorValue(); + break; + case ValueIndex::kUnknown: + At()->~UnknownValue(); + break; + default: + ABSL_UNREACHABLE(); + } + FastCopyAssign(other); + } else { + switch (index_) { + case ValueIndex::kBytes: + switch (other.index_) { + case ValueIndex::kBytes: + *At() = *other.At(); + break; + case ValueIndex::kString: + At()->~BytesValue(); + ::new (static_cast(&raw_[0])) + StringValue(*other.At()); + index_ = other.index_; + kind_ = other.kind_; + break; + case ValueIndex::kError: + At()->~BytesValue(); + ::new (static_cast(&raw_[0])) + ErrorValue(*other.At()); + index_ = other.index_; + kind_ = other.kind_; + break; + case ValueIndex::kUnknown: + At()->~BytesValue(); + ::new (static_cast(&raw_[0])) + UnknownValue(*other.At()); + index_ = other.index_; + kind_ = other.kind_; + break; + default: + ABSL_UNREACHABLE(); + } + break; + case ValueIndex::kString: + switch (other.index_) { + case ValueIndex::kBytes: + At()->~StringValue(); + ::new (static_cast(&raw_[0])) + BytesValue(*other.At()); + index_ = other.index_; + kind_ = other.kind_; + break; + case ValueIndex::kString: + *At() = *other.At(); + break; + case ValueIndex::kError: + At()->~StringValue(); + ::new (static_cast(&raw_[0])) + ErrorValue(*other.At()); + index_ = other.index_; + kind_ = other.kind_; + break; + case ValueIndex::kUnknown: + At()->~StringValue(); + ::new (static_cast(&raw_[0])) + UnknownValue(*other.At()); + index_ = other.index_; + kind_ = other.kind_; + break; + default: + ABSL_UNREACHABLE(); + } + break; + case ValueIndex::kError: + switch (other.index_) { + case ValueIndex::kBytes: + At()->~ErrorValue(); + ::new (static_cast(&raw_[0])) + BytesValue(*other.At()); + index_ = other.index_; + kind_ = other.kind_; + break; + case ValueIndex::kString: + At()->~ErrorValue(); + ::new (static_cast(&raw_[0])) + StringValue(*other.At()); + index_ = other.index_; + kind_ = other.kind_; + break; + case ValueIndex::kError: + *At() = *other.At(); + break; + case ValueIndex::kUnknown: + At()->~ErrorValue(); + ::new (static_cast(&raw_[0])) + UnknownValue(*other.At()); + index_ = other.index_; + kind_ = other.kind_; + break; + default: + ABSL_UNREACHABLE(); + } + break; + case ValueIndex::kUnknown: + switch (other.index_) { + case ValueIndex::kBytes: + At()->~UnknownValue(); + ::new (static_cast(&raw_[0])) + BytesValue(*other.At()); + index_ = other.index_; + kind_ = other.kind_; + break; + case ValueIndex::kString: + At()->~UnknownValue(); + ::new (static_cast(&raw_[0])) + StringValue(*other.At()); + index_ = other.index_; + kind_ = other.kind_; + break; + case ValueIndex::kError: + At()->~UnknownValue(); + ::new (static_cast(&raw_[0])) + ErrorValue(*other.At()); + index_ = other.index_; + kind_ = other.kind_; + break; + case ValueIndex::kUnknown: + At()->~UnknownValue(); + ::new (static_cast(&raw_[0])) + UnknownValue(*other.At()); + index_ = other.index_; + kind_ = other.kind_; + break; + default: + ABSL_UNREACHABLE(); + } + break; + default: + ABSL_UNREACHABLE(); + } + flags_ = other.flags_; + } +} + +void ValueVariant::SlowMoveAssign(ValueVariant& other, bool trivial, + bool other_trivial) noexcept { + ABSL_DCHECK(!trivial || !other_trivial); + + if (trivial) { + switch (other.index_) { + case ValueIndex::kBytes: + ::new (static_cast(&raw_[0])) + BytesValue(std::move(*other.At())); + break; + case ValueIndex::kString: + ::new (static_cast(&raw_[0])) + StringValue(std::move(*other.At())); + break; + case ValueIndex::kError: + ::new (static_cast(&raw_[0])) + ErrorValue(std::move(*other.At())); + break; + case ValueIndex::kUnknown: + ::new (static_cast(&raw_[0])) + UnknownValue(std::move(*other.At())); + break; + default: + ABSL_UNREACHABLE(); + } + index_ = other.index_; + kind_ = other.kind_; + flags_ = other.flags_; + } else if (other_trivial) { + switch (index_) { + case ValueIndex::kBytes: + At()->~BytesValue(); + break; + case ValueIndex::kString: + At()->~StringValue(); + break; + case ValueIndex::kError: + At()->~ErrorValue(); + break; + case ValueIndex::kUnknown: + At()->~UnknownValue(); + break; + default: + ABSL_UNREACHABLE(); + } + FastMoveAssign(other); + } else { + switch (index_) { + case ValueIndex::kBytes: + switch (other.index_) { + case ValueIndex::kBytes: + *At() = std::move(*other.At()); + break; + case ValueIndex::kString: + At()->~BytesValue(); + ::new (static_cast(&raw_[0])) + StringValue(std::move(*other.At())); + index_ = other.index_; + kind_ = other.kind_; + break; + case ValueIndex::kError: + At()->~BytesValue(); + ::new (static_cast(&raw_[0])) + ErrorValue(std::move(*other.At())); + index_ = other.index_; + kind_ = other.kind_; + break; + case ValueIndex::kUnknown: + At()->~BytesValue(); + ::new (static_cast(&raw_[0])) + UnknownValue(std::move(*other.At())); + index_ = other.index_; + kind_ = other.kind_; + break; + default: + ABSL_UNREACHABLE(); + } + break; + case ValueIndex::kString: + switch (other.index_) { + case ValueIndex::kBytes: + At()->~StringValue(); + ::new (static_cast(&raw_[0])) + BytesValue(std::move(*other.At())); + index_ = other.index_; + kind_ = other.kind_; + break; + case ValueIndex::kString: + *At() = std::move(*other.At()); + break; + case ValueIndex::kError: + At()->~StringValue(); + ::new (static_cast(&raw_[0])) + ErrorValue(std::move(*other.At())); + index_ = other.index_; + kind_ = other.kind_; + break; + case ValueIndex::kUnknown: + At()->~StringValue(); + ::new (static_cast(&raw_[0])) + UnknownValue(std::move(*other.At())); + index_ = other.index_; + kind_ = other.kind_; + break; + default: + ABSL_UNREACHABLE(); + } + break; + case ValueIndex::kError: + switch (other.index_) { + case ValueIndex::kBytes: + At()->~ErrorValue(); + ::new (static_cast(&raw_[0])) + BytesValue(std::move(*other.At())); + index_ = other.index_; + kind_ = other.kind_; + break; + case ValueIndex::kString: + At()->~ErrorValue(); + ::new (static_cast(&raw_[0])) + StringValue(std::move(*other.At())); + index_ = other.index_; + kind_ = other.kind_; + break; + case ValueIndex::kError: + *At() = std::move(*other.At()); + break; + case ValueIndex::kUnknown: + At()->~ErrorValue(); + ::new (static_cast(&raw_[0])) + UnknownValue(std::move(*other.At())); + index_ = other.index_; + kind_ = other.kind_; + break; + default: + ABSL_UNREACHABLE(); + } + break; + case ValueIndex::kUnknown: + switch (other.index_) { + case ValueIndex::kBytes: + At()->~UnknownValue(); + ::new (static_cast(&raw_[0])) + BytesValue(std::move(*other.At())); + index_ = other.index_; + kind_ = other.kind_; + break; + case ValueIndex::kString: + At()->~UnknownValue(); + ::new (static_cast(&raw_[0])) + StringValue(std::move(*other.At())); + index_ = other.index_; + kind_ = other.kind_; + break; + case ValueIndex::kError: + At()->~UnknownValue(); + ::new (static_cast(&raw_[0])) + ErrorValue(std::move(*other.At())); + index_ = other.index_; + kind_ = other.kind_; + break; + case ValueIndex::kUnknown: + *At() = std::move(*other.At()); + break; + default: + ABSL_UNREACHABLE(); + } + break; + default: + ABSL_UNREACHABLE(); + } + flags_ = other.flags_; + } +} + +void ValueVariant::SlowSwap(ValueVariant& lhs, ValueVariant& rhs, + bool lhs_trivial, bool rhs_trivial) noexcept { + using std::swap; + ABSL_DCHECK(!lhs_trivial || !rhs_trivial); + + if (lhs_trivial) { + alignas(ValueVariant) std::byte tmp[sizeof(ValueVariant)]; + // This is acceptable. We know that both are trivially copyable at runtime. + // NOLINTNEXTLINE(bugprone-undefined-memory-manipulation) + std::memcpy(tmp, std::addressof(lhs), sizeof(ValueVariant)); + switch (rhs.index_) { + case ValueIndex::kBytes: + ::new (static_cast(&lhs.raw_[0])) + BytesValue(*rhs.At()); + rhs.At()->~BytesValue(); + break; + case ValueIndex::kString: + ::new (static_cast(&lhs.raw_[0])) + StringValue(*rhs.At()); + rhs.At()->~StringValue(); + break; + case ValueIndex::kError: + ::new (static_cast(&lhs.raw_[0])) + ErrorValue(*rhs.At()); + rhs.At()->~ErrorValue(); + break; + case ValueIndex::kUnknown: + ::new (static_cast(&lhs.raw_[0])) + UnknownValue(*rhs.At()); + rhs.At()->~UnknownValue(); + break; + default: + ABSL_UNREACHABLE(); + } + lhs.index_ = rhs.index_; + lhs.kind_ = rhs.kind_; + lhs.flags_ = rhs.flags_; + // This is acceptable. We know that both are trivially copyable at runtime. + // NOLINTNEXTLINE(bugprone-undefined-memory-manipulation) + std::memcpy(std::addressof(rhs), tmp, sizeof(ValueVariant)); + } else if (rhs_trivial) { + alignas(ValueVariant) std::byte tmp[sizeof(ValueVariant)]; + // This is acceptable. We know that both are trivially copyable at runtime. + // NOLINTNEXTLINE(bugprone-undefined-memory-manipulation) + std::memcpy(tmp, std::addressof(rhs), sizeof(ValueVariant)); + switch (lhs.index_) { + case ValueIndex::kBytes: + ::new (static_cast(&rhs.raw_[0])) + BytesValue(*lhs.At()); + lhs.At()->~BytesValue(); + break; + case ValueIndex::kString: + ::new (static_cast(&rhs.raw_[0])) + StringValue(*lhs.At()); + lhs.At()->~StringValue(); + break; + case ValueIndex::kError: + ::new (static_cast(&rhs.raw_[0])) + ErrorValue(*lhs.At()); + lhs.At()->~ErrorValue(); + break; + case ValueIndex::kUnknown: + ::new (static_cast(&rhs.raw_[0])) + UnknownValue(*lhs.At()); + lhs.At()->~UnknownValue(); + break; + default: + ABSL_UNREACHABLE(); + } + rhs.index_ = lhs.index_; + rhs.kind_ = lhs.kind_; + rhs.flags_ = lhs.flags_; + // This is acceptable. We know that both are trivially copyable at runtime. + // NOLINTNEXTLINE(bugprone-undefined-memory-manipulation) + std::memcpy(std::addressof(lhs), tmp, sizeof(ValueVariant)); + } else { + ValueVariant tmp = std::move(lhs); + lhs = std::move(rhs); + rhs = std::move(tmp); + } +} + +} // namespace cel::common_internal diff --git a/common/values/value_variant.h b/common/values/value_variant.h new file mode 100644 index 000000000..b05511e3c --- /dev/null +++ b/common/values/value_variant.h @@ -0,0 +1,831 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_VALUE_VARIANT_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_VALUE_VARIANT_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/meta/type_traits.h" +#include "absl/utility/utility.h" +#include "common/arena.h" +#include "common/value_kind.h" +#include "common/values/bool_value.h" +#include "common/values/bytes_value.h" +#include "common/values/custom_list_value.h" +#include "common/values/custom_map_value.h" +#include "common/values/custom_struct_value.h" +#include "common/values/double_value.h" +#include "common/values/duration_value.h" +#include "common/values/error_value.h" +#include "common/values/int_value.h" +#include "common/values/legacy_list_value.h" +#include "common/values/legacy_map_value.h" +#include "common/values/legacy_struct_value.h" +#include "common/values/list_value.h" +#include "common/values/map_value.h" +#include "common/values/null_value.h" +#include "common/values/opaque_value.h" +#include "common/values/parsed_json_list_value.h" +#include "common/values/parsed_json_map_value.h" +#include "common/values/parsed_map_field_value.h" +#include "common/values/parsed_message_value.h" +#include "common/values/parsed_repeated_field_value.h" +#include "common/values/string_value.h" +#include "common/values/timestamp_value.h" +#include "common/values/type_value.h" +#include "common/values/uint_value.h" +#include "common/values/unknown_value.h" +#include "common/values/values.h" + +namespace cel { + +class Value; + +namespace common_internal { + +// Used by ValueVariant to indicate the active alternative. +enum class ValueIndex : uint8_t { + kNull = 0, + kBool, + kInt, + kUint, + kDouble, + kDuration, + kTimestamp, + kType, + kLegacyList, + kParsedJsonList, + kParsedRepeatedField, + kCustomList, + kLegacyMap, + kParsedJsonMap, + kParsedMapField, + kCustomMap, + kLegacyStruct, + kParsedMessage, + kCustomStruct, + kOpaque, + + // Keep non-trivial alternatives together to aid in compiling optimizations. + kBytes, + kString, + kError, + kUnknown, +}; + +// Used by ValueVariant to indicate pre-computed behaviors. +enum class ValueFlags : uint32_t { + kNone = 0, + kNonTrivial = 1, +}; + +ABSL_ATTRIBUTE_ALWAYS_INLINE inline constexpr ValueFlags operator&( + ValueFlags lhs, ValueFlags rhs) { + return static_cast( + static_cast>(lhs) & + static_cast>(rhs)); +} + +// Traits specialized by each alternative. +// +// ValueIndex ValueAlternative::kIndex +// +// Indicates the alternative index corresponding to T. +// +// ValueKind ValueAlternative::kKind +// +// Indicatates the kind corresponding to T. +// +// bool ValueAlternative::kAlwaysTrivial +// +// True if T is trivially_copyable, false otherwise. +// +// ValueFlags ValueAlternative::Flags(const T* absl_nonnull ) +// +// Returns the flags for the corresponding instance of T. +template +struct ValueAlternative; + +template <> +struct ValueAlternative { + static constexpr ValueIndex kIndex = ValueIndex::kNull; + static constexpr ValueKind kKind = NullValue::kKind; + static constexpr bool kAlwaysTrivial = true; + + static constexpr ValueFlags Flags(const NullValue* absl_nonnull) { + return ValueFlags::kNone; + } +}; + +template <> +struct ValueAlternative { + static constexpr ValueIndex kIndex = ValueIndex::kBool; + static constexpr ValueKind kKind = BoolValue::kKind; + static constexpr bool kAlwaysTrivial = true; + + static constexpr ValueFlags Flags(const BoolValue* absl_nonnull) { + return ValueFlags::kNone; + } +}; + +template <> +struct ValueAlternative { + static constexpr ValueIndex kIndex = ValueIndex::kInt; + static constexpr ValueKind kKind = IntValue::kKind; + static constexpr bool kAlwaysTrivial = true; + + static constexpr ValueFlags Flags(const IntValue* absl_nonnull) { + return ValueFlags::kNone; + } +}; + +template <> +struct ValueAlternative { + static constexpr ValueIndex kIndex = ValueIndex::kUint; + static constexpr ValueKind kKind = UintValue::kKind; + static constexpr bool kAlwaysTrivial = true; + + static constexpr ValueFlags Flags(const UintValue* absl_nonnull) { + return ValueFlags::kNone; + } +}; + +template <> +struct ValueAlternative { + static constexpr ValueIndex kIndex = ValueIndex::kDouble; + static constexpr ValueKind kKind = DoubleValue::kKind; + static constexpr bool kAlwaysTrivial = true; + + static constexpr ValueFlags Flags(const DoubleValue* absl_nonnull) { + return ValueFlags::kNone; + } +}; + +template <> +struct ValueAlternative { + static constexpr ValueIndex kIndex = ValueIndex::kDuration; + static constexpr ValueKind kKind = DurationValue::kKind; + static constexpr bool kAlwaysTrivial = true; + + static constexpr ValueFlags Flags(const DurationValue* absl_nonnull) { + return ValueFlags::kNone; + } +}; + +template <> +struct ValueAlternative { + static constexpr ValueIndex kIndex = ValueIndex::kTimestamp; + static constexpr ValueKind kKind = TimestampValue::kKind; + static constexpr bool kAlwaysTrivial = true; + + static constexpr ValueFlags Flags(const TimestampValue* absl_nonnull) { + return ValueFlags::kNone; + } +}; + +template <> +struct ValueAlternative { + static constexpr ValueIndex kIndex = ValueIndex::kType; + static constexpr ValueKind kKind = TypeValue::kKind; + static constexpr bool kAlwaysTrivial = true; + + static constexpr ValueFlags Flags(const TypeValue* absl_nonnull) { + return ValueFlags::kNone; + } +}; + +template <> +struct ValueAlternative { + static constexpr ValueIndex kIndex = ValueIndex::kLegacyList; + static constexpr ValueKind kKind = LegacyListValue::kKind; + static constexpr bool kAlwaysTrivial = true; + + static constexpr ValueFlags Flags(const LegacyListValue* absl_nonnull) { + return ValueFlags::kNone; + } +}; + +template <> +struct ValueAlternative { + static constexpr ValueIndex kIndex = ValueIndex::kParsedJsonList; + static constexpr ValueKind kKind = ParsedJsonListValue::kKind; + static constexpr bool kAlwaysTrivial = true; + + static constexpr ValueFlags Flags(const ParsedJsonListValue* absl_nonnull) { + return ValueFlags::kNone; + } +}; + +template <> +struct ValueAlternative { + static constexpr ValueIndex kIndex = ValueIndex::kParsedRepeatedField; + static constexpr ValueKind kKind = ParsedRepeatedFieldValue::kKind; + static constexpr bool kAlwaysTrivial = true; + + static constexpr ValueFlags Flags( + const ParsedRepeatedFieldValue* absl_nonnull) { + return ValueFlags::kNone; + } +}; + +template <> +struct ValueAlternative { + static constexpr ValueIndex kIndex = ValueIndex::kCustomList; + static constexpr ValueKind kKind = CustomListValue::kKind; + static constexpr bool kAlwaysTrivial = true; + + static constexpr ValueFlags Flags(const CustomListValue* absl_nonnull) { + return ValueFlags::kNone; + } +}; + +template <> +struct ValueAlternative { + static constexpr ValueIndex kIndex = ValueIndex::kLegacyMap; + static constexpr ValueKind kKind = LegacyMapValue::kKind; + static constexpr bool kAlwaysTrivial = true; + + static constexpr ValueFlags Flags(const LegacyMapValue* absl_nonnull) { + return ValueFlags::kNone; + } +}; + +template <> +struct ValueAlternative { + static constexpr ValueIndex kIndex = ValueIndex::kParsedJsonMap; + static constexpr ValueKind kKind = ParsedJsonMapValue::kKind; + static constexpr bool kAlwaysTrivial = true; + + static constexpr ValueFlags Flags(const ParsedJsonMapValue* absl_nonnull) { + return ValueFlags::kNone; + } +}; + +template <> +struct ValueAlternative { + static constexpr ValueIndex kIndex = ValueIndex::kParsedMapField; + static constexpr ValueKind kKind = ParsedMapFieldValue::kKind; + static constexpr bool kAlwaysTrivial = true; + + static constexpr ValueFlags Flags(const ParsedMapFieldValue* absl_nonnull) { + return ValueFlags::kNone; + } +}; + +template <> +struct ValueAlternative { + static constexpr ValueIndex kIndex = ValueIndex::kCustomMap; + static constexpr ValueKind kKind = CustomMapValue::kKind; + static constexpr bool kAlwaysTrivial = true; + + static constexpr ValueFlags Flags(const CustomMapValue* absl_nonnull) { + return ValueFlags::kNone; + } +}; + +template <> +struct ValueAlternative { + static constexpr ValueIndex kIndex = ValueIndex::kLegacyStruct; + static constexpr ValueKind kKind = LegacyStructValue::kKind; + static constexpr bool kAlwaysTrivial = true; + + static constexpr ValueFlags Flags(const LegacyStructValue* absl_nonnull) { + return ValueFlags::kNone; + } +}; + +template <> +struct ValueAlternative { + static constexpr ValueIndex kIndex = ValueIndex::kParsedMessage; + static constexpr ValueKind kKind = ParsedMessageValue::kKind; + static constexpr bool kAlwaysTrivial = true; + + static constexpr ValueFlags Flags(const ParsedMessageValue* absl_nonnull) { + return ValueFlags::kNone; + } +}; + +template <> +struct ValueAlternative { + static constexpr ValueIndex kIndex = ValueIndex::kCustomStruct; + static constexpr ValueKind kKind = CustomStructValue::kKind; + static constexpr bool kAlwaysTrivial = true; + + static constexpr ValueFlags Flags(const CustomStructValue* absl_nonnull) { + return ValueFlags::kNone; + } +}; + +template <> +struct ValueAlternative { + static constexpr ValueIndex kIndex = ValueIndex::kOpaque; + static constexpr ValueKind kKind = OpaqueValue::kKind; + static constexpr bool kAlwaysTrivial = true; + + static constexpr ValueFlags Flags(const OpaqueValue* absl_nonnull) { + return ValueFlags::kNone; + } +}; + +template <> +struct ValueAlternative { + static constexpr ValueIndex kIndex = ValueIndex::kBytes; + static constexpr ValueKind kKind = BytesValue::kKind; + static constexpr bool kAlwaysTrivial = false; + + static ValueFlags Flags(const BytesValue* absl_nonnull alternative) { + return ArenaTraits::trivially_destructible(*alternative) + ? ValueFlags::kNone + : ValueFlags::kNonTrivial; + } +}; + +template <> +struct ValueAlternative { + static constexpr ValueIndex kIndex = ValueIndex::kString; + static constexpr ValueKind kKind = StringValue::kKind; + static constexpr bool kAlwaysTrivial = false; + + static ValueFlags Flags(const StringValue* absl_nonnull alternative) { + return ArenaTraits::trivially_destructible(*alternative) + ? ValueFlags::kNone + : ValueFlags::kNonTrivial; + } +}; + +template <> +struct ValueAlternative { + static constexpr ValueIndex kIndex = ValueIndex::kError; + static constexpr ValueKind kKind = ErrorValue::kKind; + static constexpr bool kAlwaysTrivial = false; + + static ValueFlags Flags(const ErrorValue* absl_nonnull alternative) { + return ArenaTraits::trivially_destructible(*alternative) + ? ValueFlags::kNone + : ValueFlags::kNonTrivial; + } +}; + +template <> +struct ValueAlternative { + static constexpr ValueIndex kIndex = ValueIndex::kUnknown; + static constexpr ValueKind kKind = UnknownValue::kKind; + static constexpr bool kAlwaysTrivial = false; + + static constexpr ValueFlags Flags(const UnknownValue* absl_nonnull) { + return ValueFlags::kNonTrivial; + } +}; + +template +struct IsValueAlternative : std::false_type {}; + +template +struct IsValueAlternative{})>> + : std::true_type {}; + +template +inline constexpr bool IsValueAlternativeV = IsValueAlternative::value; + +// Alignment and size of the storage inside ValueVariant, not for ValueVariant +// itself. +inline constexpr size_t kValueVariantAlign = 8; +inline constexpr size_t kValueVariantSize = 24; + +// Hand-rolled variant used by cel::Value which exhibits up to a 25% performance +// improvement compared to using std::variant. +// +// The implementation abuses the fact that most alternatives are trivially +// copyable and some are conditionally trivially copyable at runtime. For the +// fast path, we perform raw byte copying. For the slow path, we fallback to a +// non-inlined function. The compiler is typically smart enough to inline the +// fast path and emit efficient instructions for the raw byte copying (usually +// two instructions). It also uses switch for visiting, which most compilers can +// optimize better compared to a function pointer table (which libc++ currently +// uses and Clang currently does not optimize well). +class alignas(kValueVariantAlign) CEL_COMMON_INTERNAL_VALUE_VARIANT_TRIVIAL_ABI + ValueVariant final { + public: + ValueVariant() = default; + + ValueVariant(const ValueVariant& other) noexcept + : index_(other.index_), kind_(other.kind_), flags_(other.flags_) { + if ((flags_ & ValueFlags::kNonTrivial) == ValueFlags::kNone) { + std::memcpy(raw_, other.raw_, sizeof(raw_)); + } else { + SlowCopyConstruct(other); + } + } + + ValueVariant(ValueVariant&& other) noexcept + : index_(other.index_), kind_(other.kind_), flags_(other.flags_) { + if ((flags_ & ValueFlags::kNonTrivial) == ValueFlags::kNone) { + std::memcpy(raw_, other.raw_, sizeof(raw_)); + } else { + SlowMoveConstruct(other); + } + } + + ~ValueVariant() { + if ((flags_ & ValueFlags::kNonTrivial) == ValueFlags::kNonTrivial) { + SlowDestruct(); + } + } + + ValueVariant& operator=(const ValueVariant& other) noexcept { + if (this != &other) { + const bool trivial = + (flags_ & ValueFlags::kNonTrivial) == ValueFlags::kNone; + const bool other_trivial = + (other.flags_ & ValueFlags::kNonTrivial) == ValueFlags::kNone; + if (trivial && other_trivial) { + FastCopyAssign(other); + } else { + SlowCopyAssign(other, trivial, other_trivial); + } + } + return *this; + } + + ValueVariant& operator=(ValueVariant&& other) noexcept { + if (this != &other) { + const bool trivial = + (flags_ & ValueFlags::kNonTrivial) == ValueFlags::kNone; + const bool other_trivial = + (other.flags_ & ValueFlags::kNonTrivial) == ValueFlags::kNone; + if (trivial && other_trivial) { + FastMoveAssign(other); + } else { + SlowMoveAssign(other, trivial, other_trivial); + } + } + return *this; + } + + template + explicit ValueVariant(absl::in_place_type_t, Args&&... args) + : index_(ValueAlternative::kIndex), kind_(ValueAlternative::kKind) { + static_assert(alignof(T) <= kValueVariantAlign); + static_assert(sizeof(T) <= kValueVariantSize); + + flags_ = ValueAlternative::Flags(::new (static_cast(&raw_[0])) + T(std::forward(args)...)); + } + + template >>> + explicit ValueVariant(T&& value) + : ValueVariant(absl::in_place_type>, + std::forward(value)) {} + + ValueKind kind() const { return kind_; } + + template + void Assign(T&& value) { + using U = absl::remove_cvref_t; + + static_assert(alignof(U) <= kValueVariantAlign); + static_assert(sizeof(U) <= kValueVariantSize); + + if constexpr (ValueAlternative::kAlwaysTrivial) { + if ((flags_ & ValueFlags::kNonTrivial) != ValueFlags::kNone) { + SlowDestruct(); + } + index_ = ValueAlternative::kIndex; + kind_ = ValueAlternative::kKind; + flags_ = ValueAlternative::Flags(::new (static_cast(&raw_[0])) + U(std::forward(value))); + } else { + // U is not always trivial. See if the current active alternative is U. If + // it is, we can just do a simple assignment without having to destruct + // first. Otherwise fallback to destruct and construct. + if (index_ == ValueAlternative::kIndex) { + *At() = std::forward(value); + flags_ = ValueAlternative::Flags(At()); + } else { + if ((flags_ & ValueFlags::kNonTrivial) != ValueFlags::kNone) { + SlowDestruct(); + } + index_ = ValueAlternative::kIndex; + kind_ = ValueAlternative::kKind; + flags_ = ValueAlternative::Flags(::new (static_cast(&raw_[0])) + U(std::forward(value))); + } + } + } + + template + bool Is() const { + return index_ == ValueAlternative::kIndex; + } + + template + T& Get() & ABSL_ATTRIBUTE_LIFETIME_BOUND { + ABSL_DCHECK(Is()); + + return *At(); + } + + template + const T& Get() const& ABSL_ATTRIBUTE_LIFETIME_BOUND { + ABSL_DCHECK(Is()); + + return *At(); + } + + template + T&& Get() && ABSL_ATTRIBUTE_LIFETIME_BOUND { + ABSL_DCHECK(Is()); + + return std::move(*At()); + } + + template + const T&& Get() const&& ABSL_ATTRIBUTE_LIFETIME_BOUND { + ABSL_DCHECK(Is()); + + return std::move(*At()); + } + + template + T* absl_nullable As() ABSL_ATTRIBUTE_LIFETIME_BOUND { + if (Is()) { + return At(); + } + return nullptr; + } + + template + const T* absl_nullable As() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + if (Is()) { + return At(); + } + return nullptr; + } + + template + ABSL_ATTRIBUTE_ALWAYS_INLINE decltype(auto) Visit(Visitor&& visitor) & { + return std::as_const(*this).Visit(std::forward(visitor)); + } + + template + decltype(auto) Visit(Visitor&& visitor) const& { + switch (index_) { + case ValueIndex::kNull: + return std::forward(visitor)(Get()); + case ValueIndex::kBool: + return std::forward(visitor)(Get()); + case ValueIndex::kInt: + return std::forward(visitor)(Get()); + case ValueIndex::kUint: + return std::forward(visitor)(Get()); + case ValueIndex::kDouble: + return std::forward(visitor)(Get()); + case ValueIndex::kDuration: + return std::forward(visitor)(Get()); + case ValueIndex::kTimestamp: + return std::forward(visitor)(Get()); + case ValueIndex::kType: + return std::forward(visitor)(Get()); + case ValueIndex::kLegacyList: + return std::forward(visitor)(Get()); + case ValueIndex::kParsedJsonList: + return std::forward(visitor)(Get()); + case ValueIndex::kParsedRepeatedField: + return std::forward(visitor)(Get()); + case ValueIndex::kCustomList: + return std::forward(visitor)(Get()); + case ValueIndex::kLegacyMap: + return std::forward(visitor)(Get()); + case ValueIndex::kParsedJsonMap: + return std::forward(visitor)(Get()); + case ValueIndex::kParsedMapField: + return std::forward(visitor)(Get()); + case ValueIndex::kCustomMap: + return std::forward(visitor)(Get()); + case ValueIndex::kLegacyStruct: + return std::forward(visitor)(Get()); + case ValueIndex::kParsedMessage: + return std::forward(visitor)(Get()); + case ValueIndex::kCustomStruct: + return std::forward(visitor)(Get()); + case ValueIndex::kOpaque: + return std::forward(visitor)(Get()); + case ValueIndex::kBytes: + return std::forward(visitor)(Get()); + case ValueIndex::kString: + return std::forward(visitor)(Get()); + case ValueIndex::kError: + return std::forward(visitor)(Get()); + case ValueIndex::kUnknown: + return std::forward(visitor)(Get()); + } + } + + template + decltype(auto) Visit(Visitor&& visitor) && { + switch (index_) { + case ValueIndex::kNull: + return std::forward(visitor)( + std::move(*this).Get()); + case ValueIndex::kBool: + return std::forward(visitor)( + std::move(*this).Get()); + case ValueIndex::kInt: + return std::forward(visitor)(std::move(*this).Get()); + case ValueIndex::kUint: + return std::forward(visitor)( + std::move(*this).Get()); + case ValueIndex::kDouble: + return std::forward(visitor)( + std::move(*this).Get()); + case ValueIndex::kDuration: + return std::forward(visitor)( + std::move(*this).Get()); + case ValueIndex::kTimestamp: + return std::forward(visitor)( + std::move(*this).Get()); + case ValueIndex::kType: + return std::forward(visitor)( + std::move(*this).Get()); + case ValueIndex::kLegacyList: + return std::forward(visitor)( + std::move(*this).Get()); + case ValueIndex::kParsedJsonList: + return std::forward(visitor)( + std::move(*this).Get()); + case ValueIndex::kParsedRepeatedField: + return std::forward(visitor)( + std::move(*this).Get()); + case ValueIndex::kCustomList: + return std::forward(visitor)( + std::move(*this).Get()); + case ValueIndex::kLegacyMap: + return std::forward(visitor)( + std::move(*this).Get()); + case ValueIndex::kParsedJsonMap: + return std::forward(visitor)( + std::move(*this).Get()); + case ValueIndex::kParsedMapField: + return std::forward(visitor)( + std::move(*this).Get()); + case ValueIndex::kCustomMap: + return std::forward(visitor)( + std::move(*this).Get()); + case ValueIndex::kLegacyStruct: + return std::forward(visitor)( + std::move(*this).Get()); + case ValueIndex::kParsedMessage: + return std::forward(visitor)( + std::move(*this).Get()); + case ValueIndex::kCustomStruct: + return std::forward(visitor)( + std::move(*this).Get()); + case ValueIndex::kOpaque: + return std::forward(visitor)( + std::move(*this).Get()); + case ValueIndex::kBytes: + return std::forward(visitor)( + std::move(*this).Get()); + case ValueIndex::kString: + return std::forward(visitor)( + std::move(*this).Get()); + case ValueIndex::kError: + return std::forward(visitor)( + std::move(*this).Get()); + case ValueIndex::kUnknown: + return std::forward(visitor)( + std::move(*this).Get()); + } + } + + template + ABSL_ATTRIBUTE_ALWAYS_INLINE decltype(auto) Visit(Visitor&& visitor) const&& { + return Visit(std::forward(visitor)); + } + + friend void swap(ValueVariant& lhs, ValueVariant& rhs) noexcept { + if (&lhs != &rhs) { + const bool lhs_trivial = + (lhs.flags_ & ValueFlags::kNonTrivial) == ValueFlags::kNone; + const bool rhs_trivial = + (rhs.flags_ & ValueFlags::kNonTrivial) == ValueFlags::kNone; + if (lhs_trivial && rhs_trivial) { +// We validated the instances can be copied byte-wise at runtime, but compilers +// warn since this is not safe in the general case. +#if defined(__GNUC__) && !defined(__clang__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wclass-memaccess" +#elif defined(__clang__) && __clang_major__ >= 20 +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wnontrivial-memcall" +#endif + alignas(ValueVariant) std::byte tmp[sizeof(ValueVariant)]; + // NOLINTNEXTLINE(bugprone-undefined-memory-manipulation) + std::memcpy(tmp, std::addressof(lhs), sizeof(ValueVariant)); + // NOLINTNEXTLINE(bugprone-undefined-memory-manipulation) + std::memcpy(std::addressof(lhs), std::addressof(rhs), + sizeof(ValueVariant)); + // NOLINTNEXTLINE(bugprone-undefined-memory-manipulation) + std::memcpy(std::addressof(rhs), tmp, sizeof(ValueVariant)); +#if defined(__GNUC__) && !defined(__clang__) +#pragma GCC diagnostic pop +#elif defined(__clang__) && __clang_major__ >= 20 +#pragma clang diagnostic pop +#endif + } else { + SlowSwap(lhs, rhs, lhs_trivial, rhs_trivial); + } + } + } + + private: + friend struct cel::ArenaTraits; + + template + ABSL_ATTRIBUTE_ALWAYS_INLINE T* absl_nonnull At() + ABSL_ATTRIBUTE_LIFETIME_BOUND { + static_assert(alignof(T) <= kValueVariantAlign); + static_assert(sizeof(T) <= kValueVariantSize); + + return std::launder(reinterpret_cast(&raw_[0])); + } + + template + ABSL_ATTRIBUTE_ALWAYS_INLINE const T* absl_nonnull At() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + static_assert(alignof(T) <= kValueVariantAlign); + static_assert(sizeof(T) <= kValueVariantSize); + + return std::launder(reinterpret_cast(&raw_[0])); + } + + ABSL_ATTRIBUTE_ALWAYS_INLINE void FastCopyAssign( + const ValueVariant& other) noexcept { + index_ = other.index_; + kind_ = other.kind_; + flags_ = other.flags_; + std::memcpy(raw_, other.raw_, sizeof(raw_)); + } + + ABSL_ATTRIBUTE_ALWAYS_INLINE void FastMoveAssign( + ValueVariant& other) noexcept { + FastCopyAssign(other); + } + + void SlowCopyConstruct(const ValueVariant& other) noexcept; + + void SlowMoveConstruct(ValueVariant& other) noexcept; + + void SlowDestruct() noexcept; + + void SlowCopyAssign(const ValueVariant& other, bool trivial, + bool other_trivial) noexcept; + + void SlowMoveAssign(ValueVariant& other, bool ntrivial, + bool other_trivial) noexcept; + + static void SlowSwap(ValueVariant& lhs, ValueVariant& rhs, bool lhs_trivial, + bool rhs_trivial) noexcept; + + ValueIndex index_ = ValueIndex::kNull; + ValueKind kind_ = ValueKind::kNull; + ValueFlags flags_ = ValueFlags::kNone; + alignas(kValueVariantAlign) std::byte raw_[kValueVariantSize]; +}; + +} // namespace common_internal + +template <> +struct ArenaTraits { + static bool trivially_destructible( + const common_internal::ValueVariant& value) { + return (value.flags_ & common_internal::ValueFlags::kNonTrivial) == + common_internal::ValueFlags::kNone; + } +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_VALUE_VARIANT_H_ diff --git a/common/values/value_variant_test.cc b/common/values/value_variant_test.cc new file mode 100644 index 000000000..1fd3629aa --- /dev/null +++ b/common/values/value_variant_test.cc @@ -0,0 +1,126 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "absl/strings/cord.h" +#include "common/value.h" +#include "internal/testing.h" + +namespace cel::common_internal { +namespace { + +template +class ValueVariantTest : public ::testing::Test {}; + +#define VALUE_VARIANT_TYPES(T) \ + std::pair, std::pair, std::pair, \ + std::pair, std::pair, \ + std::pair, std::pair, \ + std::pair, std::pair, \ + std::pair, \ + std::pair, std::pair, \ + std::pair, std::pair, \ + std::pair, std::pair, \ + std::pair, std::pair, \ + std::pair, std::pair, \ + std::pair, std::pair, \ + std::pair, std::pair + +using ValueVariantTypes = ::testing::Types< + VALUE_VARIANT_TYPES(NullValue), VALUE_VARIANT_TYPES(BoolValue), + VALUE_VARIANT_TYPES(IntValue), VALUE_VARIANT_TYPES(UintValue), + VALUE_VARIANT_TYPES(DoubleValue), VALUE_VARIANT_TYPES(DurationValue), + VALUE_VARIANT_TYPES(TimestampValue), VALUE_VARIANT_TYPES(TypeValue), + VALUE_VARIANT_TYPES(LegacyListValue), + VALUE_VARIANT_TYPES(ParsedJsonListValue), + VALUE_VARIANT_TYPES(ParsedRepeatedFieldValue), + VALUE_VARIANT_TYPES(CustomListValue), VALUE_VARIANT_TYPES(LegacyMapValue), + VALUE_VARIANT_TYPES(ParsedJsonMapValue), + VALUE_VARIANT_TYPES(ParsedMapFieldValue), + VALUE_VARIANT_TYPES(CustomMapValue), VALUE_VARIANT_TYPES(LegacyStructValue), + VALUE_VARIANT_TYPES(ParsedMessageValue), + VALUE_VARIANT_TYPES(CustomStructValue), VALUE_VARIANT_TYPES(OpaqueValue), + VALUE_VARIANT_TYPES(BytesValue), VALUE_VARIANT_TYPES(StringValue), + VALUE_VARIANT_TYPES(ErrorValue), VALUE_VARIANT_TYPES(UnknownValue)>; + +template +struct DefaultValue { + T operator()() const { return T(); } +}; + +template <> +struct DefaultValue { + BytesValue operator()() const { + return BytesValue( + absl::Cord("Some somewhat large string that is not storable inline!")); + } +}; + +template <> +struct DefaultValue { + StringValue operator()() const { + return StringValue( + absl::Cord("Some somewhat large string that is not storable inline!")); + } +}; + +#undef VALUE_VARIANT_TYPES + +TYPED_TEST_SUITE(ValueVariantTest, ValueVariantTypes); + +TYPED_TEST(ValueVariantTest, CopyAssign) { + using Left = typename TypeParam::first_type; + using Right = typename TypeParam::second_type; + + ValueVariant lhs(DefaultValue{}()); + ValueVariant rhs(DefaultValue{}()); + + EXPECT_TRUE(lhs.Is()); + + lhs = rhs; + + EXPECT_TRUE(lhs.Is()); + EXPECT_TRUE(rhs.Is()); +} + +TYPED_TEST(ValueVariantTest, MoveAssign) { + using Left = typename TypeParam::first_type; + using Right = typename TypeParam::second_type; + + ValueVariant lhs(DefaultValue{}()); + ValueVariant rhs(DefaultValue{}()); + + EXPECT_TRUE(lhs.Is()); + + lhs = std::move(rhs); + + EXPECT_TRUE(lhs.Is()); +} + +TYPED_TEST(ValueVariantTest, Swap) { + using Left = typename TypeParam::first_type; + using Right = typename TypeParam::second_type; + + ValueVariant lhs(DefaultValue{}()); + ValueVariant rhs(DefaultValue{}()); + + swap(lhs, rhs); + + EXPECT_TRUE(lhs.Is()); + EXPECT_TRUE(rhs.Is()); +} + +} // namespace +} // namespace cel::common_internal diff --git a/common/values/values.h b/common/values/values.h new file mode 100644 index 000000000..aaa6f8659 --- /dev/null +++ b/common/values/values.h @@ -0,0 +1,351 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_VALUES_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_VALUES_H_ + +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/functional/function_ref.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "base/attribute.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +// absl::Cord is trivially relocatable IFF we are not using ASan or MSan. When +// using ASan or MSan absl::Cord will poison/unpoison its inline storage. +#if defined(ABSL_HAVE_ADDRESS_SANITIZER) || defined(ABSL_HAVE_MEMORY_SANITIZER) +#define CEL_COMMON_INTERNAL_VALUE_VARIANT_TRIVIAL_ABI +#else +#define CEL_COMMON_INTERNAL_VALUE_VARIANT_TRIVIAL_ABI ABSL_ATTRIBUTE_TRIVIAL_ABI +#endif + +namespace cel { + +class ValueInterface; +class ListValueInterface; +class StructValueInterface; + +class Value; +class BoolValue; +class BytesValue; +class DoubleValue; +class DurationValue; +class ABSL_ATTRIBUTE_TRIVIAL_ABI ErrorValue; +class IntValue; +class ListValue; +class MapValue; +class NullValue; +class OpaqueValue; +class OptionalValue; +class StringValue; +class StructValue; +class TimestampValue; +class TypeValue; +class UintValue; +class UnknownValue; +class ParsedMessageValue; +class ParsedMapFieldValue; +class ParsedRepeatedFieldValue; +class ParsedJsonListValue; +class ParsedJsonMapValue; + +class CustomListValue; +class CustomListValueInterface; + +class CustomMapValue; +class CustomMapValueInterface; + +class CustomStructValue; +class CustomStructValueInterface; + +class ValueIterator; +using ValueIteratorPtr = std::unique_ptr; + +class ValueIterator { + public: + virtual ~ValueIterator() = default; + + virtual bool HasNext() = 0; + + // Returns a view of the next value. If the underlying implementation cannot + // directly return a view of a value, the value will be stored in `scratch`, + // and the returned view will be that of `scratch`. + virtual absl::Status Next( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) = 0; + + absl::StatusOr Next( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena); + + // Next1 returns values for lists and keys for maps. + virtual absl::StatusOr Next1( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull key_or_value); + + absl::StatusOr> Next1( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena); + + // Next2 returns indices (in ascending order) and values for lists and keys + // (in any order) and values for maps. + virtual absl::StatusOr Next2( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nullable key, + Value* absl_nullable value) = 0; + + absl::StatusOr>> Next2( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena); +}; + +namespace common_internal { + +class SharedByteString; +class SharedByteStringView; + +class LegacyListValue; + +class LegacyMapValue; + +class LegacyStructValue; + +class ListValueVariant; + +class MapValueVariant; + +class StructValueVariant; + +class CEL_COMMON_INTERNAL_VALUE_VARIANT_TRIVIAL_ABI ValueVariant; + +ErrorValue GetDefaultErrorValue(); + +CustomListValue GetEmptyDynListValue(); + +CustomMapValue GetEmptyDynDynMapValue(); + +OptionalValue GetEmptyDynOptionalValue(); + +absl::Status ListValueEqual( + const ListValue& lhs, const ListValue& rhs, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result); + +absl::Status ListValueEqual( + const CustomListValueInterface& lhs, const ListValue& rhs, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result); + +absl::Status MapValueEqual( + const MapValue& lhs, const MapValue& rhs, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result); + +absl::Status MapValueEqual( + const CustomMapValueInterface& lhs, const MapValue& rhs, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result); + +absl::Status StructValueEqual( + const StructValue& lhs, const StructValue& rhs, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result); + +absl::Status StructValueEqual( + const CustomStructValueInterface& lhs, const StructValue& rhs, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result); + +const SharedByteString& AsSharedByteString(const BytesValue& value); + +const SharedByteString& AsSharedByteString(const StringValue& value); + +using ListValueForEachCallback = + absl::FunctionRef(const Value&)>; +using ListValueForEach2Callback = + absl::FunctionRef(size_t, const Value&)>; + +template +class ValueMixin { + public: + absl::StatusOr Equal( + const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const; + + friend Base; +}; + +template +class ListValueMixin : public ValueMixin { + public: + using ValueMixin::Equal; + + absl::StatusOr Get( + size_t index, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const; + + using ForEachCallback = absl::FunctionRef(const Value&)>; + + absl::Status ForEach( + ForEachCallback callback, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const { + return static_cast(this)->ForEach( + [callback](size_t, const Value& value) -> absl::StatusOr { + return callback(value); + }, + descriptor_pool, message_factory, arena); + } + + absl::StatusOr Contains( + const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const; + + friend Base; +}; + +template +class MapValueMixin : public ValueMixin { + public: + using ValueMixin::Equal; + + absl::StatusOr Get( + const Value& key, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const; + + absl::StatusOr> Find( + const Value& other, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const; + + absl::StatusOr Has( + const Value& key, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const; + + absl::StatusOr ListKeys( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const; + + friend Base; +}; + +template +class StructValueMixin : public ValueMixin { + public: + using ValueMixin::Equal; + + absl::StatusOr GetFieldByName( + absl::string_view name, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const; + + absl::Status GetFieldByName( + absl::string_view name, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + return static_cast(this)->GetFieldByName( + name, ProtoWrapperTypeOptions::kUnsetNull, descriptor_pool, + message_factory, arena, result); + } + + absl::StatusOr GetFieldByName( + absl::string_view name, ProtoWrapperTypeOptions unboxing_options, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const; + + absl::StatusOr GetFieldByNumber( + int64_t number, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const; + + absl::Status GetFieldByNumber( + int64_t number, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + return static_cast(this)->GetFieldByNumber( + number, ProtoWrapperTypeOptions::kUnsetNull, descriptor_pool, + message_factory, arena, result); + } + + absl::StatusOr GetFieldByNumber( + int64_t number, ProtoWrapperTypeOptions unboxing_options, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const; + + absl::StatusOr> Qualify( + absl::Span qualifiers, bool presence_test, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const; + + friend Base; +}; + +template +class OpaqueValueMixin : public ValueMixin { + public: + using ValueMixin::Equal; + + friend Base; +}; + +} // namespace common_internal + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_VALUES_H_ diff --git a/compiler/BUILD b/compiler/BUILD new file mode 100644 index 000000000..d4a0ab4ac --- /dev/null +++ b/compiler/BUILD @@ -0,0 +1,181 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "compiler", + hdrs = ["compiler.h"], + deps = [ + "//checker:checker_options", + "//checker:type_checker", + "//checker:type_checker_builder", + "//checker:validation_result", + "//parser:options", + "//parser:parser_interface", + "//validator", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "compiler_factory", + srcs = ["compiler_factory.cc"], + hdrs = ["compiler_factory.h"], + deps = [ + ":compiler", + "//checker:type_check_issue", + "//checker:type_checker", + "//checker:type_checker_builder", + "//checker:type_checker_builder_factory", + "//checker:validation_result", + "//common:ast", + "//common:source", + "//internal:noop_delete", + "//internal:status_macros", + "//parser", + "//parser:parser_interface", + "//validator", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "compiler_factory_test", + srcs = ["compiler_factory_test.cc"], + deps = [ + ":compiler", + ":compiler_factory", + ":optional", + ":standard_library", + "//checker:optional", + "//checker:standard_library", + "//checker:type_check_issue", + "//checker:type_checker", + "//checker:validation_result", + "//common:decl", + "//common:source", + "//common:type", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//parser:macro", + "//parser:parser_interface", + "//testutil:baseline_tests", + "//validator:timestamp_literal_validator", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "optional", + srcs = ["optional.cc"], + hdrs = ["optional.h"], + deps = [ + ":compiler", + "//checker:optional", + "//parser:macro", + "//parser:parser_interface", + "@com_google_absl//absl/status", + ], +) + +cc_test( + name = "optional_test", + srcs = ["optional_test.cc"], + deps = [ + ":compiler", + ":compiler_factory", + ":optional", + ":standard_library", + "//checker:optional", + "//checker:standard_library", + "//checker:type_check_issue", + "//checker:validation_result", + "//common:decl", + "//common:source", + "//common:type", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//testutil:baseline_tests", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto", + ], +) + +cc_library( + name = "standard_library", + srcs = ["standard_library.cc"], + hdrs = ["standard_library.h"], + deps = [ + ":compiler", + "//checker:standard_library", + "//internal:status_macros", + "//parser:macro", + "//parser:parser_interface", + "@com_google_absl//absl/status", + ], +) + +cc_library( + name = "compiler_library_subset_factory", + srcs = ["compiler_library_subset_factory.cc"], + hdrs = ["compiler_library_subset_factory.h"], + deps = [ + ":compiler", + "//checker:type_checker_subset_factory", + "//parser:parser_subset_factory", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + ], +) + +cc_test( + name = "compiler_library_subset_factory_test", + srcs = ["compiler_library_subset_factory_test.cc"], + deps = [ + ":compiler", + ":compiler_factory", + ":compiler_library_subset_factory", + ":standard_library", + "//checker:validation_result", + "//common:standard_definitions", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + ], +) diff --git a/compiler/compiler.h b/compiler/compiler.h new file mode 100644 index 000000000..27237df60 --- /dev/null +++ b/compiler/compiler.h @@ -0,0 +1,166 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMPILER_COMPILER_INTERFACE_H_ +#define THIRD_PARTY_CEL_CPP_COMPILER_COMPILER_INTERFACE_H_ + +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "checker/checker_options.h" +#include "checker/type_checker.h" +#include "checker/type_checker_builder.h" +#include "checker/validation_result.h" +#include "parser/options.h" +#include "parser/parser_interface.h" +#include "validator/validator.h" +#include "google/protobuf/arena.h" + +namespace cel { + +class Compiler; +class CompilerBuilder; + +// A CompilerLibrary represents a package of CEL configuration that can be +// added to a Compiler. +// +// It may contain either or both of a Parser configuration and a +// TypeChecker configuration. +struct CompilerLibrary { + // Optional identifier to avoid collisions re-adding the same library. + // If id is empty, it is not considered. + std::string id; + // Optional callback for configuring the parser. + ParserBuilderConfigurer configure_parser; + // Optional callback for configuring the type checker. + TypeCheckerBuilderConfigurer configure_checker; + + CompilerLibrary(std::string id, ParserBuilderConfigurer configure_parser, + TypeCheckerBuilderConfigurer configure_checker = nullptr) + : id(std::move(id)), + configure_parser(std::move(configure_parser)), + configure_checker(std::move(configure_checker)) {} + + CompilerLibrary(std::string id, + TypeCheckerBuilderConfigurer configure_checker) + : id(std::move(id)), + configure_parser(std::move(nullptr)), + configure_checker(std::move(configure_checker)) {} + + // Convenience conversion from the CheckerLibrary type. + // + // Note: if a related CompilerLibrary exists, prefer to use that to + // include expected parser configuration. + static CompilerLibrary FromCheckerLibrary(CheckerLibrary checker_library) { + return CompilerLibrary(std::move(checker_library.id), + /*configure_parser=*/nullptr, + std::move(checker_library.configure)); + } + + // For backwards compatibility. To be removed. + // NOLINTNEXTLINE(google-explicit-constructor) + CompilerLibrary(CheckerLibrary checker_library) + : id(std::move(checker_library.id)), + configure_parser(nullptr), + configure_checker(std::move(checker_library.configure)) {} +}; + +struct CompilerLibrarySubset { + // The id of the library to subset. Only one subset can be applied per + // library id. + // + // Must be non-empty. + std::string library_id; + ParserLibrarySubset::MacroPredicate should_include_macro; + TypeCheckerSubset::FunctionPredicate should_include_overload; + // TODO(uncreated-issue/71): to faithfully report the subset back, we need to track + // the default (include or exclude) behavior for each of the predicates. +}; + +// General options for configuring the underlying parser and checker. +struct CompilerOptions { + ParserOptions parser_options; + CheckerOptions checker_options; + // If true, parse errors will be adapted to issues where possible. + bool adapt_parser_errors = false; +}; + +// Interface for CEL CompilerBuilder objects. +// +// Builder implementations do not provide any synchronization themselves, +// but create thread-compatible Compiler instances. +class CompilerBuilder { + public: + virtual ~CompilerBuilder() = default; + + virtual absl::Status AddLibrary(CompilerLibrary library) = 0; + virtual absl::Status AddLibrarySubset(CompilerLibrarySubset subset) = 0; + + virtual TypeCheckerBuilder& GetCheckerBuilder() = 0; + virtual ParserBuilder& GetParserBuilder() = 0; + virtual Validator& GetValidator() = 0; + + virtual absl::StatusOr> Build() = 0; +}; + +// Interface for CEL Compiler objects. +// +// For CEL, compilation is the process of bundling the parse and type-check +// passes. +// +// Compiler instances should be thread-compatible. +class Compiler { + public: + virtual ~Compiler() = default; + + virtual absl::StatusOr Compile( + absl::string_view source, absl::string_view description, + google::protobuf::Arena* absl_nullable arena) const = 0; + + absl::StatusOr Compile(absl::string_view source) const { + return Compile(source, "", nullptr); + } + + absl::StatusOr Compile( + absl::string_view source, absl::string_view description) const { + return Compile(source, description, nullptr); + } + + // Accessor for the underlying type checker. + virtual const TypeChecker& GetTypeChecker() const = 0; + + // Accessor for the underlying parser. + virtual const Parser& GetParser() const = 0; + + // Accessor for the underlying validator. + virtual const Validator& GetValidator() const = 0; + + // Returns a builder initialized with the configuration of this compiler. + // + // The returned builder is a copy of the validated environment and may + // behave differently than the builder that created this compiler. + // + // The returned builder does not share state with the compiler and may be + // modified independently. + virtual std::unique_ptr ToBuilder() const = 0; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMPILER_COMPILER_INTERFACE_H_ diff --git a/compiler/compiler_factory.cc b/compiler/compiler_factory.cc new file mode 100644 index 000000000..ed22c5630 --- /dev/null +++ b/compiler/compiler_factory.cc @@ -0,0 +1,210 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/compiler_factory.h" + +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "checker/type_check_issue.h" +#include "checker/type_checker.h" +#include "checker/type_checker_builder.h" +#include "checker/type_checker_builder_factory.h" +#include "checker/validation_result.h" +#include "common/ast.h" +#include "common/source.h" +#include "compiler/compiler.h" +#include "internal/status_macros.h" +#include "parser/parser.h" +#include "parser/parser_interface.h" +#include "validator/validator.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +namespace { + +class CompilerImpl : public Compiler { + public: + CompilerImpl(std::unique_ptr type_checker, + std::unique_ptr parser, + // Copy the validator in case builder is reused. + Validator validator, CompilerOptions options) + : type_checker_(std::move(type_checker)), + parser_(std::move(parser)), + validator_(std::move(validator)), + options_(options) {} + + absl::StatusOr Compile( + absl::string_view expression, absl::string_view description, + google::protobuf::Arena* arena) const override { + CEL_ASSIGN_OR_RETURN(auto source, + cel::NewSource(expression, std::string(description))); + std::vector parse_issues; + absl::StatusOr> ast = + parser_->Parse(*source, &parse_issues); + if (!ast.ok()) { + if (!options_.adapt_parser_errors || + ast.status().code() != absl::StatusCode::kInvalidArgument || + parse_issues.empty()) { + return ast.status(); + } + std::vector check_issues; + check_issues.reserve(parse_issues.size()); + for (const auto& issue : parse_issues) { + check_issues.push_back(TypeCheckIssue::CreateError( + issue.location(), std::string(issue.message()))); + } + ValidationResult result(std::move(check_issues)); + result.SetSource(std::move(source)); + return result; + } + CEL_ASSIGN_OR_RETURN(ValidationResult result, + type_checker_->Check(*std::move(ast), arena)); + + result.SetSource(std::move(source)); + if (!validator_.validations().empty()) { + validator_.UpdateValidationResult(result); + } + return result; + } + + std::unique_ptr ToBuilder() const override; + + const TypeChecker& GetTypeChecker() const override { return *type_checker_; } + const Parser& GetParser() const override { return *parser_; } + const Validator& GetValidator() const override { return validator_; } + + private: + std::unique_ptr type_checker_; + std::unique_ptr parser_; + Validator validator_; + CompilerOptions options_; +}; + +class CompilerBuilderImpl : public CompilerBuilder { + public: + CompilerBuilderImpl(std::unique_ptr type_checker_builder, + std::unique_ptr parser_builder, + Validator validator, CompilerOptions options) + : type_checker_builder_(std::move(type_checker_builder)), + parser_builder_(std::move(parser_builder)), + validator_(std::move(validator)), + options_(options) {} + + absl::Status AddLibrary(CompilerLibrary library) override { + if (!library.id.empty()) { + auto [it, inserted] = library_ids_.insert(library.id); + + if (!inserted) { + return absl::AlreadyExistsError( + absl::StrCat("library already exists: ", library.id)); + } + } + + if (library.configure_checker) { + CEL_RETURN_IF_ERROR(type_checker_builder_->AddLibrary({ + .id = library.id, + .configure = std::move(library.configure_checker), + })); + } + if (library.configure_parser) { + CEL_RETURN_IF_ERROR(parser_builder_->AddLibrary({ + .id = library.id, + .configure = std::move(library.configure_parser), + })); + } + return absl::OkStatus(); + } + + absl::Status AddLibrarySubset(CompilerLibrarySubset subset) override { + if (subset.library_id.empty()) { + return absl::InvalidArgumentError("library id must not be empty"); + } + std::string library_id = subset.library_id; + + auto [it, inserted] = subsets_.insert(library_id); + if (!inserted) { + return absl::AlreadyExistsError( + absl::StrCat("library subset already exists for: ", library_id)); + } + + if (subset.should_include_macro) { + CEL_RETURN_IF_ERROR(parser_builder_->AddLibrarySubset({ + library_id, + std::move(subset.should_include_macro), + })); + } + if (subset.should_include_overload) { + CEL_RETURN_IF_ERROR(type_checker_builder_->AddLibrarySubset( + {library_id, std::move(subset.should_include_overload)})); + } + return absl::OkStatus(); + } + + ParserBuilder& GetParserBuilder() override { return *parser_builder_; } + TypeCheckerBuilder& GetCheckerBuilder() override { + return *type_checker_builder_; + } + Validator& GetValidator() override { return validator_; } + + absl::StatusOr> Build() override { + CEL_ASSIGN_OR_RETURN(auto parser, parser_builder_->Build()); + CEL_ASSIGN_OR_RETURN(auto type_checker, type_checker_builder_->Build()); + return std::make_unique( + std::move(type_checker), std::move(parser), validator_, options_); + } + + private: + std::unique_ptr type_checker_builder_; + std::unique_ptr parser_builder_; + Validator validator_; + CompilerOptions options_; + + absl::flat_hash_set library_ids_; + absl::flat_hash_set subsets_; +}; + +std::unique_ptr CompilerImpl::ToBuilder() const { + return std::make_unique( + type_checker_->ToBuilder(), parser_->ToBuilder(), validator_, options_); +} + +} // namespace + +absl::StatusOr> NewCompilerBuilder( + std::shared_ptr descriptor_pool, + CompilerOptions options) { + if (descriptor_pool == nullptr) { + return absl::InvalidArgumentError("descriptor_pool must not be null"); + } + CEL_ASSIGN_OR_RETURN(auto type_checker_builder, + CreateTypeCheckerBuilder(std::move(descriptor_pool), + options.checker_options)); + auto parser_builder = NewParserBuilder(options.parser_options); + + return std::make_unique(std::move(type_checker_builder), + std::move(parser_builder), + Validator(), options); +} + +} // namespace cel diff --git a/compiler/compiler_factory.h b/compiler/compiler_factory.h new file mode 100644 index 000000000..03930b40d --- /dev/null +++ b/compiler/compiler_factory.h @@ -0,0 +1,70 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMPILER_COMPILER_FACTORY_H_ +#define THIRD_PARTY_CEL_CPP_COMPILER_COMPILER_FACTORY_H_ + +#include +#include + +#include "absl/base/nullability.h" +#include "absl/status/statusor.h" +#include "compiler/compiler.h" +#include "internal/noop_delete.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +// Creates a new unconfigured CompilerBuilder for creating a new CEL Compiler +// instance. +// +// The builder is thread-hostile and intended to be configured by a single +// thread, but the created Compiler instances are thread-compatible (and +// effectively immutable). +// +// The descriptor pool must include the standard definitions for the protobuf +// well-known types: +// - google.protobuf.NullValue +// - google.protobuf.BoolValue +// - google.protobuf.Int32Value +// - google.protobuf.Int64Value +// - google.protobuf.UInt32Value +// - google.protobuf.UInt64Value +// - google.protobuf.FloatValue +// - google.protobuf.DoubleValue +// - google.protobuf.BytesValue +// - google.protobuf.StringValue +// - google.protobuf.Any +// - google.protobuf.Duration +// - google.protobuf.Timestamp +absl::StatusOr> NewCompilerBuilder( + std::shared_ptr descriptor_pool, + CompilerOptions options = {}); + +// Convenience overload for non-owning pointers (such as the generated pool). +// The descriptor pool must outlive the compiler builder and any compiler +// instances it builds. +inline absl::StatusOr> NewCompilerBuilder( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + CompilerOptions options = {}) { + return NewCompilerBuilder( + std::shared_ptr( + descriptor_pool, + internal::NoopDeleteFor()), + std::move(options)); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMPILER_COMPILER_FACTORY_H_ diff --git a/compiler/compiler_factory_test.cc b/compiler/compiler_factory_test.cc new file mode 100644 index 000000000..035fd8aa6 --- /dev/null +++ b/compiler/compiler_factory_test.cc @@ -0,0 +1,431 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/compiler_factory.h" + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/strings/match.h" +#include "checker/optional.h" +#include "checker/standard_library.h" +#include "checker/type_check_issue.h" +#include "checker/type_checker.h" +#include "checker/validation_result.h" +#include "common/decl.h" +#include "common/source.h" +#include "common/type.h" +#include "compiler/compiler.h" +#include "compiler/optional.h" +#include "compiler/standard_library.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "parser/macro.h" +#include "parser/parser_interface.h" +#include "testutil/baseline_tests.h" +#include "validator/timestamp_literal_validator.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::cel::test::FormatBaselineAst; +using ::testing::Contains; +using ::testing::HasSubstr; +using ::testing::Property; +using ::testing::Truly; + +TEST(CompilerFactoryTest, Works) { + ASSERT_OK_AND_ASSIGN( + auto builder, + NewCompilerBuilder(cel::internal::GetSharedTestingDescriptorPool())); + + ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk()); + ASSERT_OK_AND_ASSIGN(auto compiler, builder->Build()); + + ASSERT_OK_AND_ASSIGN( + ValidationResult result, + compiler->Compile("['a', 'b', 'c'].exists(x, x in ['c', 'd', 'e']) && 10 " + "< (5 % 3 * 2 + 1 - 2)")); + + ASSERT_TRUE(result.IsValid()); + + EXPECT_EQ(FormatBaselineAst(*result.GetAst()), + R"(_&&_( + __comprehension__( + // Variable + x, + // Target + [ + "a"~string, + "b"~string, + "c"~string + ]~list(string), + // Accumulator + @result, + // Init + false~bool, + // LoopCondition + @not_strictly_false( + !_( + @result~bool^@result + )~bool^logical_not + )~bool^not_strictly_false, + // LoopStep + _||_( + @result~bool^@result, + @in( + x~string^x, + [ + "c"~string, + "d"~string, + "e"~string + ]~list(string) + )~bool^in_list + )~bool^logical_or, + // Result + @result~bool^@result)~bool, + _<_( + 10~int, + _-_( + _+_( + _*_( + _%_( + 5~int, + 3~int + )~int^modulo_int64, + 2~int + )~int^multiply_int64, + 1~int + )~int^add_int64, + 2~int + )~int^subtract_int64 + )~bool^less_int64 +)~bool^logical_and)"); +} + +TEST(CompilerFactoryTest, ParserLibrary) { + ASSERT_OK_AND_ASSIGN( + auto builder, + NewCompilerBuilder(cel::internal::GetSharedTestingDescriptorPool())); + + ASSERT_THAT( + builder->AddLibrary({"test", + [](ParserBuilder& builder) -> absl::Status { + builder.GetOptions().disable_standard_macros = + true; + return builder.AddMacro(cel::HasMacro()); + }}), + IsOk()); + + ASSERT_THAT(builder->GetCheckerBuilder().AddVariable( + MakeVariableDecl("a", MapType())), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto compiler, builder->Build()); + + ASSERT_THAT(compiler->Compile("has(a.b)"), IsOk()); + + ASSERT_OK_AND_ASSIGN(ValidationResult result, + compiler->Compile("[].map(x, x)")); + + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.GetIssues(), + Contains(Property(&TypeCheckIssue::message, + HasSubstr("undeclared reference to 'map'")))) + << result.GetIssues()[2].message(); +} + +TEST(CompilerFactoryTest, ParserOptions) { + ASSERT_OK_AND_ASSIGN( + auto builder, + NewCompilerBuilder(cel::internal::GetSharedTestingDescriptorPool())); + + builder->GetParserBuilder().GetOptions().enable_optional_syntax = true; + ASSERT_THAT(builder->AddLibrary(OptionalCheckerLibrary()), IsOk()); + + ASSERT_THAT(builder->GetCheckerBuilder().AddVariable( + MakeVariableDecl("a", MapType())), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto compiler, builder->Build()); + + ASSERT_THAT(compiler->Compile("a.?b.orValue('foo')"), IsOk()); +} + +TEST(CompilerFactoryTest, GetParser) { + ASSERT_OK_AND_ASSIGN( + auto builder, + NewCompilerBuilder(cel::internal::GetSharedTestingDescriptorPool())); + + ASSERT_OK_AND_ASSIGN(auto compiler, builder->Build()); + + const cel::Parser& parser = compiler->GetParser(); + + ASSERT_OK_AND_ASSIGN(auto source, cel::NewSource("Or(a, b)")); + ASSERT_OK_AND_ASSIGN(auto ast, parser.Parse(*source)); +} + +TEST(CompilerFactoryTest, GetTypeChecker) { + ASSERT_OK_AND_ASSIGN( + auto builder, + NewCompilerBuilder(cel::internal::GetSharedTestingDescriptorPool())); + + absl::Status s; + s.Update(builder->GetCheckerBuilder().AddVariable( + MakeVariableDecl("a", BoolType()))); + + s.Update(builder->GetCheckerBuilder().AddVariable( + MakeVariableDecl("b", BoolType()))); + + ASSERT_OK_AND_ASSIGN( + auto or_decl, + MakeFunctionDecl("Or", MakeOverloadDecl("Or_bool_bool", BoolType(), + BoolType(), BoolType()))); + s.Update(builder->GetCheckerBuilder().AddFunction(std::move(or_decl))); + + ASSERT_THAT(s, IsOk()); + ASSERT_OK_AND_ASSIGN(auto compiler, builder->Build()); + + const cel::Parser& parser = compiler->GetParser(); + + ASSERT_OK_AND_ASSIGN(auto source, cel::NewSource("Or(a, b)")); + ASSERT_OK_AND_ASSIGN(auto ast, parser.Parse(*source)); + + const cel::TypeChecker& checker = compiler->GetTypeChecker(); + ASSERT_OK_AND_ASSIGN(cel::ValidationResult result, + checker.Check(std::move(ast))); + EXPECT_TRUE(result.IsValid()); +} + +TEST(CompilerFactoryTest, DisableStandardMacros) { + CompilerOptions options; + options.parser_options.disable_standard_macros = true; + + ASSERT_OK_AND_ASSIGN( + auto builder, + NewCompilerBuilder(cel::internal::GetSharedTestingDescriptorPool(), + options)); + // Add the type checker library, but not the parser library for CEL standard. + ASSERT_THAT(builder->AddLibrary(CompilerLibrary::FromCheckerLibrary( + StandardCheckerLibrary())), + IsOk()); + ASSERT_THAT(builder->GetParserBuilder().AddMacro(cel::ExistsMacro()), IsOk()); + + // a: map(dyn, dyn) + ASSERT_THAT(builder->GetCheckerBuilder().AddVariable( + MakeVariableDecl("a", MapType())), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto compiler, builder->Build()); + + ASSERT_OK_AND_ASSIGN(ValidationResult result, compiler->Compile("a.b")); + + EXPECT_TRUE(result.IsValid()); + + // The has macro is disabled, so looks like a function call. + ASSERT_OK_AND_ASSIGN(result, compiler->Compile("has(a.b)")); + + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.GetIssues(), + Contains(Truly([](const TypeCheckIssue& issue) { + return absl::StrContains(issue.message(), + "undeclared reference to 'has'"); + }))); + + ASSERT_OK_AND_ASSIGN(result, compiler->Compile("a.exists(x, x == 'foo')")); + EXPECT_TRUE(result.IsValid()); +} + +TEST(CompilerFactoryTest, DisableStandardMacrosWithStdlib) { + CompilerOptions options; + options.parser_options.disable_standard_macros = true; + + ASSERT_OK_AND_ASSIGN( + auto builder, + NewCompilerBuilder(cel::internal::GetSharedTestingDescriptorPool(), + options)); + + ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk()); + ASSERT_THAT(builder->GetParserBuilder().AddMacro(cel::ExistsMacro()), IsOk()); + + // a: map(dyn, dyn) + ASSERT_THAT(builder->GetCheckerBuilder().AddVariable( + MakeVariableDecl("a", MapType())), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto compiler, builder->Build()); + + ASSERT_OK_AND_ASSIGN(ValidationResult result, compiler->Compile("a.b")); + + EXPECT_TRUE(result.IsValid()); + + // The has macro is disabled, so looks like a function call. + ASSERT_OK_AND_ASSIGN(result, compiler->Compile("has(a.b)")); + + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.GetIssues(), + Contains(Truly([](const TypeCheckIssue& issue) { + return absl::StrContains(issue.message(), + "undeclared reference to 'has'"); + }))); + + ASSERT_OK_AND_ASSIGN(result, compiler->Compile("a.exists(x, x == 'foo')")); + EXPECT_TRUE(result.IsValid()); +} + +TEST(CompilerFactoryTest, AddValidator) { + ASSERT_OK_AND_ASSIGN( + auto builder, + NewCompilerBuilder(cel::internal::GetSharedTestingDescriptorPool())); + + ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk()); + builder->GetValidator().AddValidation(TimestampLiteralValidator()); + + ASSERT_OK_AND_ASSIGN(auto compiler, builder->Build()); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + compiler->Compile("timestamp('invalid')")); + EXPECT_FALSE(result.IsValid()); + ASSERT_OK_AND_ASSIGN(result, + compiler->Compile("timestamp('2024-01-01T00:00:00Z')")); + EXPECT_TRUE(result.IsValid()); +} + +TEST(CompilerFactoryTest, FailsIfLibraryAddedTwice) { + ASSERT_OK_AND_ASSIGN( + auto builder, + NewCompilerBuilder(cel::internal::GetSharedTestingDescriptorPool())); + + ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk()); + ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), + StatusIs(absl::StatusCode::kAlreadyExists, + HasSubstr("library already exists: stdlib"))); +} + +TEST(CompilerFactoryTest, FailsIfLibrarySubsetAddedTwice) { + ASSERT_OK_AND_ASSIGN( + auto builder, + NewCompilerBuilder(cel::internal::GetSharedTestingDescriptorPool())); + + ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk()); + + ASSERT_THAT(builder->AddLibrarySubset({ + .library_id = "stdlib", + .should_include_macro = nullptr, + .should_include_overload = nullptr, + }), + IsOk()); + + ASSERT_THAT(builder->AddLibrarySubset({ + .library_id = "stdlib", + .should_include_macro = nullptr, + .should_include_overload = nullptr, + }), + StatusIs(absl::StatusCode::kAlreadyExists, + HasSubstr("library subset already exists for: stdlib"))); +} + +TEST(CompilerFactoryTest, FailsIfLibrarySubsetHasNoId) { + ASSERT_OK_AND_ASSIGN( + auto builder, + NewCompilerBuilder(cel::internal::GetSharedTestingDescriptorPool())); + + ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk()); + ASSERT_THAT(builder->AddLibrarySubset({ + .library_id = "", + .should_include_macro = nullptr, + .should_include_overload = nullptr, + }), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("library id must not be empty"))); +} + +TEST(CompilerFactoryTest, FailsIfNullDescriptorPool) { + std::shared_ptr pool = + internal::GetSharedTestingDescriptorPool(); + pool.reset(); + ASSERT_THAT( + NewCompilerBuilder(std::move(pool)), + absl_testing::StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("descriptor_pool must not be null"))); +} + +TEST(CompilerFactoryTest, ToBuilderWorks) { + ASSERT_OK_AND_ASSIGN( + auto builder, + NewCompilerBuilder(cel::internal::GetSharedTestingDescriptorPool())); + + ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk()); + + ASSERT_THAT(builder->GetCheckerBuilder().AddVariable( + MakeVariableDecl("a", MapType())), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto compiler, builder->Build()); + + auto derived_builder = compiler->ToBuilder(); + + ASSERT_THAT(derived_builder->AddLibrary(OptionalCompilerLibrary()), IsOk()); + + ASSERT_OK_AND_ASSIGN(auto derived_compiler, derived_builder->Build()); + + ASSERT_OK_AND_ASSIGN( + ValidationResult result, + derived_compiler->Compile("has(a.b) && a.?b.orValue('foo') == 'foo'")); + EXPECT_TRUE(result.IsValid()); +} + +TEST(CompilerFactoryTest, SpecifyArenaKeepsResolvedTypes) { + ASSERT_OK_AND_ASSIGN( + auto builder, + NewCompilerBuilder(cel::internal::GetSharedTestingDescriptorPool())); + + ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk()); + ASSERT_THAT(builder->AddLibrary(OptionalCompilerLibrary()), IsOk()); + + ASSERT_OK_AND_ASSIGN(auto compiler, builder->Build()); + + google::protobuf::Arena arena; + ASSERT_OK_AND_ASSIGN(ValidationResult result, + compiler->Compile("[[1, 2, 3]][?0]", "", &arena)); + ASSERT_OK_AND_ASSIGN(auto ast, result.ReleaseAst()); + auto it = result.GetResolvedTypeMap().find(ast->root_expr().id()); + ASSERT_TRUE(it != result.GetResolvedTypeMap().end()); + EXPECT_TRUE( + it->second.IsOptional() && + it->second.GetOptional().GetParameter().IsList() && + it->second.GetOptional().GetParameter().GetList().GetElement().IsInt()); +} + +TEST(CompilerFactoryTest, ReturnsIssuesFromParser) { + CompilerOptions opts; + opts.adapt_parser_errors = true; + ASSERT_OK_AND_ASSIGN( + auto builder, NewCompilerBuilder( + cel::internal::GetSharedTestingDescriptorPool(), opts)); + + ASSERT_OK_AND_ASSIGN(auto compiler, builder->Build()); + + ASSERT_OK_AND_ASSIGN(ValidationResult result, compiler->Compile("a +")); + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.GetIssues(), testing::Not(testing::IsEmpty())); +} + +} // namespace +} // namespace cel diff --git a/compiler/compiler_library_subset_factory.cc b/compiler/compiler_library_subset_factory.cc new file mode 100644 index 000000000..8098ceb67 --- /dev/null +++ b/compiler/compiler_library_subset_factory.cc @@ -0,0 +1,91 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/compiler_library_subset_factory.h" + +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "checker/type_checker_subset_factory.h" +#include "compiler/compiler.h" +#include "parser/parser_subset_factory.h" + +namespace cel { + +CompilerLibrarySubset MakeStdlibSubset( + absl::flat_hash_set macro_names, + absl::flat_hash_set function_overload_ids, + StdlibSubsetOptions options) { + CompilerLibrarySubset subset; + subset.library_id = "stdlib"; + switch (options.macro_list) { + case cel::StdlibSubsetOptions::ListKind::kInclude: + subset.should_include_macro = + IncludeMacrosByNamePredicate(std::move(macro_names)); + break; + case cel::StdlibSubsetOptions::ListKind::kExclude: + subset.should_include_macro = + ExcludeMacrosByNamePredicate(std::move(macro_names)); + break; + case cel::StdlibSubsetOptions::ListKind::kIgnore: + subset.should_include_macro = nullptr; + break; + } + + switch (options.function_list) { + case cel::StdlibSubsetOptions::ListKind::kInclude: + subset.should_include_overload = + IncludeOverloadsByIdPredicate(std::move(function_overload_ids)); + break; + case cel::StdlibSubsetOptions::ListKind::kExclude: + subset.should_include_overload = + ExcludeOverloadsByIdPredicate(std::move(function_overload_ids)); + break; + case cel::StdlibSubsetOptions::ListKind::kIgnore: + subset.should_include_overload = nullptr; + break; + } + + return subset; +} + +CompilerLibrarySubset MakeStdlibSubset( + absl::Span macro_names, + absl::Span function_overload_ids, + StdlibSubsetOptions options) { + return MakeStdlibSubset( + absl::flat_hash_set(macro_names.begin(), macro_names.end()), + absl::flat_hash_set(function_overload_ids.begin(), + function_overload_ids.end()), + options); +} + +CompilerLibrarySubset MakeStdlibSubsetByOverloadId( + absl::Span function_overload_ids, + StdlibSubsetOptions options) { + options.macro_list = StdlibSubsetOptions::ListKind::kIgnore; + return MakeStdlibSubset({}, function_overload_ids, options); +} + +CompilerLibrarySubset MakeStdlibSubsetByMacroName( + absl::Span macro_names, + StdlibSubsetOptions options) { + options.function_list = StdlibSubsetOptions::ListKind::kIgnore; + return MakeStdlibSubset(macro_names, {}, options); +} + +} // namespace cel diff --git a/compiler/compiler_library_subset_factory.h b/compiler/compiler_library_subset_factory.h new file mode 100644 index 000000000..982f4e18c --- /dev/null +++ b/compiler/compiler_library_subset_factory.h @@ -0,0 +1,80 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMPILER_COMPILER_LIBRARY_SUBSET_FACTORY_H_ +#define THIRD_PARTY_CEL_CPP_COMPILER_COMPILER_LIBRARY_SUBSET_FACTORY_H_ + +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "compiler/compiler.h" + +namespace cel { + +struct StdlibSubsetOptions { + enum class ListKind { + // Include the given list of macros or functions, default to exclude. + kInclude, + // Exclude the given list of macros or functions, default to include. + kExclude, + // Ignore the given list of macros or functions. This is used to clarify + // intent of an empty list. + kIgnore + }; + ListKind macro_list = ListKind::kInclude; + ListKind function_list = ListKind::kInclude; +}; + +// Creates a subset of the CEL standard library. +// +// Example usage: +// // Include only the core boolean operators, and exists/all. +// // std::unique_ptr builder = ...; +// builder->AddLibrary(StandardCompilerLibrary()); +// // Add the subset. +// builder->AddLibrarySubset(MakeStdlibSubset( +// {"exists", "all"}, +// {"logical_and", "logical_or", "logical_not", "not_strictly_false", +// "equal", "inequal"}); +// +// // Exclude list concatenation and map macros. +// builder->AddLibrarySubset(MakeStdlibSubset( +// {"map"}, +// {"add_list"}, +// { .macro_list = StdlibSubsetOptions::ListKind::kExclude, +// .function_list = StdlibSubsetOptions::ListKind::kExclude +// })); +CompilerLibrarySubset MakeStdlibSubset( + absl::flat_hash_set macro_names, + absl::flat_hash_set function_overload_ids, + StdlibSubsetOptions options = {}); + +CompilerLibrarySubset MakeStdlibSubset( + absl::Span macro_names, + absl::Span function_overload_ids, + StdlibSubsetOptions options = {}); + +CompilerLibrarySubset MakeStdlibSubsetByOverloadId( + absl::Span function_overload_ids, + StdlibSubsetOptions options = {}); + +CompilerLibrarySubset MakeStdlibSubsetByMacroName( + absl::Span macro_names, + StdlibSubsetOptions options = {}); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMPILER_COMPILER_LIBRARY_SUBSET_FACTORY_H_ diff --git a/compiler/compiler_library_subset_factory_test.cc b/compiler/compiler_library_subset_factory_test.cc new file mode 100644 index 000000000..8a6a0ff5b --- /dev/null +++ b/compiler/compiler_library_subset_factory_test.cc @@ -0,0 +1,147 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/compiler_library_subset_factory.h" + +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "checker/validation_result.h" +#include "common/standard_definitions.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "compiler/standard_library.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" + +using ::absl_testing::IsOk; +using ::testing::Not; + +namespace cel { +namespace { + +MATCHER(IsValid, "") { + const absl::StatusOr& result = arg; + if (!result.ok()) { + (*result_listener) << "compilation failed: " << result.status(); + return false; + } + if (!result->GetIssues().empty()) { + (*result_listener) << "compilation issues: \n" << result->FormatError(); + } + return result->IsValid(); +} + +TEST(CompilerLibrarySubsetFactoryTest, MakeStdlibSubsetInclude) { + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + NewCompilerBuilder(internal::GetSharedTestingDescriptorPool())); + ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk()); + + ASSERT_THAT( + builder->AddLibrarySubset(MakeStdlibSubset( + {"exists", "all"}, + {StandardOverloadIds::kAnd, StandardOverloadIds::kOr, + StandardOverloadIds::kNot, StandardOverloadIds::kNotStrictlyFalse, + StandardOverloadIds::kEquals, StandardOverloadIds::kNotEquals})), + IsOk()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, builder->Build()); + + EXPECT_THAT( + compiler->Compile( + "[1, 2, 3].exists(x, x != 1 || x == 2 && !(x == 4 || x == 5) )"), + IsValid()); + EXPECT_THAT(compiler->Compile("1+2"), Not(IsValid())); + EXPECT_THAT(compiler->Compile("[1, 2, 3].map(x, x)"), Not(IsValid())); +} + +TEST(CompilerLibrarySubsetFactoryTest, MakeStdlibSubsetExclude) { + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + NewCompilerBuilder(internal::GetSharedTestingDescriptorPool())); + ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk()); + + ASSERT_THAT(builder->AddLibrarySubset(MakeStdlibSubset( + absl::flat_hash_set({"map"}), {"add_list"}, + {.macro_list = StdlibSubsetOptions::ListKind::kExclude, + .function_list = StdlibSubsetOptions::ListKind::kExclude})), + IsOk()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, builder->Build()); + + EXPECT_THAT( + compiler->Compile( + "[1, 2, 3].exists(x, x != 1 || x == 2 && !(x == 4 || x == 5) )"), + IsValid()); + EXPECT_THAT(compiler->Compile("1+2"), IsValid()); + EXPECT_THAT(compiler->Compile("[1, 2, 3].map(x, x)"), Not(IsValid())); + EXPECT_THAT(compiler->Compile("[2] + [1]"), Not(IsValid())); +} + +TEST(CompilerLibrarySubsetFactoryTest, MakeStdlibSubsetByMacroName) { + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + NewCompilerBuilder(internal::GetSharedTestingDescriptorPool())); + ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk()); + + absl::string_view kMacroNames[] = {"map"}; + ASSERT_THAT(builder->AddLibrarySubset(MakeStdlibSubsetByMacroName( + kMacroNames, + {.macro_list = StdlibSubsetOptions::ListKind::kExclude})), + IsOk()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, builder->Build()); + + EXPECT_THAT( + compiler->Compile( + "[1, 2, 3].exists(x, x != 1 || x == 2 && !(x == 4 || x == 5) )"), + IsValid()); + EXPECT_THAT(compiler->Compile("1+2"), IsValid()); + EXPECT_THAT(compiler->Compile("[1, 2, 3].map(x, x)"), Not(IsValid())); + EXPECT_THAT(compiler->Compile("[2] + [1]"), IsValid()); +} + +TEST(CompilerLibrarySubsetFactoryTest, MakeStdlibSubsetByOverloadId) { + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + NewCompilerBuilder(internal::GetSharedTestingDescriptorPool())); + ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk()); + + absl::string_view kOverloadIds[] = {"add_list", "add_string"}; + ASSERT_THAT(builder->AddLibrarySubset(MakeStdlibSubsetByOverloadId( + kOverloadIds, + {// unused + .macro_list = StdlibSubsetOptions::ListKind::kInclude, + .function_list = StdlibSubsetOptions::ListKind::kExclude})), + IsOk()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, builder->Build()); + + EXPECT_THAT( + compiler->Compile( + "[1, 2, 3].exists(x, x != 1 || x == 2 && !(x == 4 || x == 5) )"), + IsValid()); + EXPECT_THAT(compiler->Compile("1+2"), IsValid()); + EXPECT_THAT(compiler->Compile("[1, 2, 3].map(x, x)"), Not(IsValid())); + EXPECT_THAT(compiler->Compile("[2] + [1]"), Not(IsValid())); +} + +} // namespace +} // namespace cel diff --git a/compiler/optional.cc b/compiler/optional.cc new file mode 100644 index 000000000..077635bf3 --- /dev/null +++ b/compiler/optional.cc @@ -0,0 +1,43 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/optional.h" + +#include "absl/status/status.h" +#include "checker/optional.h" +#include "compiler/compiler.h" +#include "parser/macro.h" +#include "parser/parser_interface.h" + +namespace cel { + +CompilerLibrary OptionalCompilerLibrary(int version) { + CompilerLibrary library = + CompilerLibrary::FromCheckerLibrary(OptionalCheckerLibrary(version)); + + library.configure_parser = [version](ParserBuilder& builder) { + builder.GetOptions().enable_optional_syntax = true; + absl::Status status; + status.Update(builder.AddMacro(OptMapMacro())); + if (version == 0) { + return status; + } + status.Update(builder.AddMacro(OptFlatMapMacro())); + return status; + }; + + return library; +} + +} // namespace cel diff --git a/compiler/optional.h b/compiler/optional.h new file mode 100644 index 000000000..21e798339 --- /dev/null +++ b/compiler/optional.h @@ -0,0 +1,28 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#ifndef THIRD_PARTY_CEL_CPP_COMPILER_OPTIONALS_H_ +#define THIRD_PARTY_CEL_CPP_COMPILER_OPTIONALS_H_ + +#include "checker/optional.h" +#include "compiler/compiler.h" + +namespace cel { + +// CompilerLibrary that enables support for CEL optional types. +CompilerLibrary OptionalCompilerLibrary( + int version = kOptionalExtensionLatestVersion); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMPILER_OPTIONALS_H_ diff --git a/compiler/optional_test.cc b/compiler/optional_test.cc new file mode 100644 index 000000000..699c69f76 --- /dev/null +++ b/compiler/optional_test.cc @@ -0,0 +1,384 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "compiler/optional.h" + +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/ascii.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "checker/optional.h" +#include "checker/standard_library.h" +#include "checker/type_check_issue.h" +#include "checker/validation_result.h" +#include "common/decl.h" +#include "common/source.h" +#include "common/type.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "compiler/standard_library.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "testutil/baseline_tests.h" +#include "cel/expr/conformance/proto3/test_all_types.pb.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::cel::expr::conformance::proto3::TestAllTypes; +using ::cel::test::FormatBaselineAst; +using ::testing::HasSubstr; +using ::testing::IsEmpty; +using ::testing::ValuesIn; + +struct TestCase { + std::string expr; + std::string expected_ast; +}; + +class OptionalTest : public testing::TestWithParam {}; + +std::string FormatIssues(const ValidationResult& result) { + const Source* source = result.GetSource(); + return absl::StrJoin( + result.GetIssues(), "\n", + [=](std::string* out, const TypeCheckIssue& issue) { + absl::StrAppend( + out, (source) ? issue.ToDisplayString(*source) : issue.message()); + }); +} + +TEST_P(OptionalTest, OptionalsEnabled) { + const TestCase& test_case = GetParam(); + + ASSERT_OK_AND_ASSIGN( + auto builder, + NewCompilerBuilder(cel::internal::GetSharedTestingDescriptorPool())); + ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); + ASSERT_THAT(builder->AddLibrary(OptionalCompilerLibrary()), IsOk()); + ASSERT_THAT(builder->GetCheckerBuilder().AddVariable(MakeVariableDecl( + "msg", MessageType(TestAllTypes::descriptor()))), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto compiler, std::move(*builder).Build()); + + absl::StatusOr maybe_result = + compiler->Compile(test_case.expr); + + ASSERT_OK_AND_ASSIGN(ValidationResult result, std::move(maybe_result)); + ASSERT_TRUE(result.IsValid()) << FormatIssues(result); + EXPECT_EQ(FormatBaselineAst(*result.GetAst()), + absl::StripAsciiWhitespace(test_case.expected_ast)) + << test_case.expr; +} + +INSTANTIATE_TEST_SUITE_P( + OptionalTest, OptionalTest, + ::testing::Values( + TestCase{ + .expr = "msg.?single_int64", + .expected_ast = R"( +_?._( + msg~cel.expr.conformance.proto3.TestAllTypes^msg, + "single_int64" +)~optional_type(int)^select_optional_field)", + }, + TestCase{ + .expr = "optional.of('foo')", + .expected_ast = R"( +optional.of( + "foo"~string +)~optional_type(string)^optional_of)", + }, + TestCase{ + .expr = "optional.of('foo').optMap(x, x)", + .expected_ast = R"( +_?_:_( + optional.of( + "foo"~string + )~optional_type(string)^optional_of.hasValue()~bool^optional_hasValue, + optional.of( + __comprehension__( + // Variable + #unused, + // Target + []~list(dyn), + // Accumulator + x, + // Init + optional.of( + "foo"~string + )~optional_type(string)^optional_of.value()~string^optional_value, + // LoopCondition + false~bool, + // LoopStep + x~string^x, + // Result + x~string^x)~string + )~optional_type(string)^optional_of, + optional.none()~optional_type(string)^optional_none +)~optional_type(string)^conditional +)", + }, + TestCase{ + .expr = "optional.of('foo').optFlatMap(x, optional.of(x))", + .expected_ast = R"( +_?_:_( + optional.of( + "foo"~string + )~optional_type(string)^optional_of.hasValue()~bool^optional_hasValue, + __comprehension__( + // Variable + #unused, + // Target + []~list(dyn), + // Accumulator + x, + // Init + optional.of( + "foo"~string + )~optional_type(string)^optional_of.value()~string^optional_value, + // LoopCondition + false~bool, + // LoopStep + x~string^x, + // Result + optional.of( + x~string^x + )~optional_type(string)^optional_of)~optional_type(string), + optional.none()~optional_type(string)^optional_none +)~optional_type(string)^conditional +)", + }, + TestCase{ + .expr = "optional.ofNonZeroValue(1)", + .expected_ast = R"( +optional.ofNonZeroValue( + 1~int +)~optional_type(int)^optional_ofNonZeroValue +)", + }, + TestCase{ + .expr = "[0][?1]", + .expected_ast = R"( +_[?_]( + [ + 0~int + ]~list(int), + 1~int +)~optional_type(int)^list_optindex_optional_int +)", + }, + TestCase{ + .expr = "{0: 2}[?1]", + .expected_ast = R"( +_[?_]( + { + 0~int:2~int + }~map(int, int), + 1~int +)~optional_type(int)^map_optindex_optional_value +)", + }, + TestCase{ + .expr = "msg.?repeated_int64[1]", + .expected_ast = R"( +_[_]( + _?._( + msg~cel.expr.conformance.proto3.TestAllTypes^msg, + "repeated_int64" + )~optional_type(list(int))^select_optional_field, + 1~int +)~optional_type(int)^optional_list_index_int +)", + }, + TestCase{ + .expr = "msg.?map_int64_int64[1]", + .expected_ast = R"( +_[_]( + _?._( + msg~cel.expr.conformance.proto3.TestAllTypes^msg, + "map_int64_int64" + )~optional_type(map(int, int))^select_optional_field, + 1~int +)~optional_type(int)^optional_map_index_value +)", + }, + TestCase{ + .expr = "optional.of(1).or(optional.of(2))", + .expected_ast = R"( +optional.of( + 1~int +)~optional_type(int)^optional_of.or( + optional.of( + 2~int + )~optional_type(int)^optional_of +)~optional_type(int)^optional_or_optional)", + }, + TestCase{ + .expr = "optional.of(1).orValue(2)", + .expected_ast = R"( +optional.of( + 1~int +)~optional_type(int)^optional_of.orValue( + 2~int +)~int^optional_orValue_value +)", + }, + TestCase{ + .expr = "optional.of(1).value()", + .expected_ast = R"( +optional.of( + 1~int +)~optional_type(int)^optional_of.value()~int^optional_value +)", + }, + TestCase{ + .expr = "optional.of(1).hasValue()", + .expected_ast = R"( +optional.of( + 1~int +)~optional_type(int)^optional_of.hasValue()~bool^optional_hasValue +)", + })); + +TEST(OptionalTest, NotEnabled) { + ASSERT_OK_AND_ASSIGN( + auto builder, + NewCompilerBuilder(cel::internal::GetSharedTestingDescriptorPool())); + ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); + ASSERT_THAT(builder->GetCheckerBuilder().AddVariable(MakeVariableDecl( + "msg", MessageType(TestAllTypes::descriptor()))), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto compiler, std::move(*builder).Build()); + + ASSERT_OK_AND_ASSIGN(auto result, compiler->Compile("optional.of(1)")); + + EXPECT_THAT(FormatIssues(result), + HasSubstr("undeclared reference to 'optional'")); +} + +struct OptionalExtensionVersionTestCase { + std::string expr; + std::vector expected_supported_versions; +}; + +class OptionalExtensionVersionTest + : public ::testing::TestWithParam {}; + +TEST_P(OptionalExtensionVersionTest, OptionalExtensionVersions) { + const OptionalExtensionVersionTestCase& test_case = GetParam(); + for (int version = 0; version <= cel::kOptionalExtensionLatestVersion; + ++version) { + CompilerLibrary compiler_library = OptionalCompilerLibrary(version); + + CompilerOptions compiler_options; + compiler_options.parser_options.enable_optional_syntax = true; + + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + cel::NewCompilerBuilder(internal::GetTestingDescriptorPool(), + compiler_options)); + ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk()); + ASSERT_THAT(builder->AddLibrary(std::move(compiler_library)), IsOk()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, builder->Build()); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + compiler->Compile(test_case.expr)); + if (absl::c_contains(test_case.expected_supported_versions, version)) { + EXPECT_THAT(result.GetIssues(), IsEmpty()) + << "Expected no issues for expr: " << test_case.expr + << " at version: " << version << " but got: " << result.FormatError(); + } else { + EXPECT_THAT(result.GetIssues(), + Contains(Property(&TypeCheckIssue::message, + HasSubstr("undeclared reference")))) + << "Expected undeclared reference for expr: " << test_case.expr + << " at version: " << version; + } + } +}; + +std::vector +CreateOptionalExtensionVersionParams() { + return { + OptionalExtensionVersionTestCase{ + .expr = "optional_type", + .expected_supported_versions = {0, 1, 2}, + }, + OptionalExtensionVersionTestCase{ + .expr = "optional.of('foo').optMap(x, x)", + .expected_supported_versions = {0, 1, 2}, + }, + OptionalExtensionVersionTestCase{ + .expr = "optional.of('foo')", + .expected_supported_versions = {0, 1, 2}, + }, + OptionalExtensionVersionTestCase{ + .expr = "optional.ofNonZeroValue(1)", + .expected_supported_versions = {0, 1, 2}, + }, + OptionalExtensionVersionTestCase{ + .expr = "optional.of('foo').value()", + .expected_supported_versions = {0, 1, 2}, + }, + OptionalExtensionVersionTestCase{ + .expr = "optional.of('foo').hasValue()", + .expected_supported_versions = {0, 1, 2}, + }, + OptionalExtensionVersionTestCase{ + .expr = "optional.of(1).or(optional.of(2))", + .expected_supported_versions = {0, 1, 2}, + }, + OptionalExtensionVersionTestCase{ + .expr = "optional.of(1).orValue(2)", + .expected_supported_versions = {0, 1, 2}, + }, + OptionalExtensionVersionTestCase{ + .expr = "[1, 2, 3][?5]", + .expected_supported_versions = {0, 1, 2}, + }, + OptionalExtensionVersionTestCase{ + .expr = "dyn(1).?bar", + .expected_supported_versions = {0, 1, 2}, + }, + OptionalExtensionVersionTestCase{ + .expr = "optional.of('foo').optFlatMap(x, optional.of(x))", + .expected_supported_versions = {1, 2}, + }, + OptionalExtensionVersionTestCase{ + .expr = "[1, 2, 3].first()", + .expected_supported_versions = {2}, + }, + OptionalExtensionVersionTestCase{ + .expr = "[1, 2, 3].last()", + .expected_supported_versions = {2}, + }, + }; +} + +INSTANTIATE_TEST_SUITE_P(OptionalExtensionVersionTest, + OptionalExtensionVersionTest, + ValuesIn(CreateOptionalExtensionVersionParams())); + +} // namespace +} // namespace cel diff --git a/compiler/standard_library.cc b/compiler/standard_library.cc new file mode 100644 index 000000000..a178996ed --- /dev/null +++ b/compiler/standard_library.cc @@ -0,0 +1,49 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/standard_library.h" + +#include "absl/status/status.h" +#include "checker/standard_library.h" +#include "compiler/compiler.h" +#include "internal/status_macros.h" +#include "parser/macro.h" +#include "parser/parser_interface.h" + +namespace cel { + +namespace { + +absl::Status AddStandardLibraryMacros(ParserBuilder& builder) { + // For consistency with the Parse free functions, follow the convenience + // option to disable all the standard macros. + if (builder.GetOptions().disable_standard_macros) { + return absl::OkStatus(); + } + for (const auto& macro : Macro::AllMacros()) { + CEL_RETURN_IF_ERROR(builder.AddMacro(macro)); + } + return absl::OkStatus(); +} + +} // namespace + +CompilerLibrary StandardCompilerLibrary() { + CompilerLibrary library = + CompilerLibrary::FromCheckerLibrary(StandardCheckerLibrary()); + library.configure_parser = AddStandardLibraryMacros; + return library; +} + +} // namespace cel diff --git a/compiler/standard_library.h b/compiler/standard_library.h new file mode 100644 index 000000000..c19029b12 --- /dev/null +++ b/compiler/standard_library.h @@ -0,0 +1,27 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMPILER_STANDARD_LIBRARY_H_ +#define THIRD_PARTY_CEL_CPP_COMPILER_STANDARD_LIBRARY_H_ + +#include "compiler/compiler.h" + +namespace cel { + +// Returns a CompilerLibrary containing all of the standard CEL declarations +// and macros. +CompilerLibrary StandardCompilerLibrary(); + +} // namespace cel +#endif // THIRD_PARTY_CEL_CPP_COMPILER_STANDARD_LIBRARY_H_ diff --git a/conformance/BUILD b/conformance/BUILD index 24db6ce08..35d554c7b 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -1,16 +1,142 @@ -# Description -# Implementation of the conformance test server +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("//conformance:run.bzl", "gen_conformance_tests") package(default_visibility = ["//visibility:public"]) -licenses(["notice"]) # Apache 2.0 +licenses(["notice"]) -# TODO(issues/77): Add support for the proto2, proto3 conformance tests. -ALL_TESTS = [ +cc_library( + name = "service", + testonly = True, + srcs = ["service.cc"], + hdrs = ["service.h"], + deps = [ + "//checker:optional", + "//checker:standard_library", + "//checker:type_checker_builder", + "//checker:type_checker_builder_factory", + "//common:ast", + "//common:ast_proto", + "//common:decl_proto_v1alpha1", + "//common:source", + "//common:value", + "//common/internal:value_conversion", + "//eval/public:activation", + "//eval/public:builtin_func_registrar", + "//eval/public:cel_expr_builder_factory", + "//eval/public:cel_expression", + "//eval/public:cel_options", + "//eval/public:cel_value", + "//eval/public:transform_utility", + "//extensions:bindings_ext", + "//extensions:comprehensions_v2", + "//extensions:comprehensions_v2_functions", + "//extensions:comprehensions_v2_macros", + "//extensions:encoders", + "//extensions:math_ext", + "//extensions:math_ext_decls", + "//extensions:math_ext_macros", + "//extensions:proto_ext", + "//extensions:select_optimization", + "//extensions:strings", + "//extensions/protobuf:enum_adapter", + "//internal:status_macros", + "//parser", + "//parser:macro_registry", + "//parser:options", + "//parser:standard_macros", + "//runtime", + "//runtime:activation", + "//runtime:constant_folding", + "//runtime:optional_types", + "//runtime:reference_resolver", + "//runtime:runtime_options", + "//runtime:standard_runtime_builder_factory", + "//testutil:test_macros", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr/conformance/proto2:test_all_types_cc_proto", + "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto", + "@com_google_googleapis//google/api/expr/conformance/v1alpha1:conformance_cc_proto", + "@com_google_googleapis//google/api/expr/v1alpha1:checked_cc_proto", + "@com_google_googleapis//google/rpc:code_cc_proto", + "@com_google_googleapis//google/rpc:status_cc_proto", + "@com_google_protobuf//:duration_cc_proto", + "@com_google_protobuf//:empty_cc_proto", + "@com_google_protobuf//:protobuf", + "@com_google_protobuf//:struct_cc_proto", + "@com_google_protobuf//:timestamp_cc_proto", + ], +) + +cc_library( + name = "run", + testonly = True, + srcs = ["run.cc"], + deps = [ + ":service", + ":utils", + "//internal:runfiles", + "//internal:testing_no_main", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", + "@com_google_cel_spec//proto/cel/expr:value_cc_proto", + "@com_google_cel_spec//proto/cel/expr/conformance/test:simple_cc_proto", + "@com_google_googleapis//google/api/expr/conformance/v1alpha1:conformance_cc_proto", + "@com_google_googleapis//google/api/expr/v1alpha1:checked_cc_proto", + "@com_google_googleapis//google/rpc:code_cc_proto", + "@com_google_protobuf//:protobuf", + "@com_google_protobuf//src/google/protobuf/io", + ], + alwayslink = True, +) + +cc_library( + name = "utils", + testonly = True, + hdrs = ["utils.h"], + deps = [ + "//internal:testing_no_main", + "@com_google_absl//absl/log:absl_check", + "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", + "@com_google_cel_spec//proto/cel/expr:value_cc_proto", + "@com_google_googleapis//google/api/expr/v1alpha1:checked_cc_proto", + "@com_google_protobuf//:differencer", + "@com_google_protobuf//:protobuf", + ], +) + +_ALL_TESTS = [ "@com_google_cel_spec//tests/simple:testdata/basic.textproto", + "@com_google_cel_spec//tests/simple:testdata/bindings_ext.textproto", "@com_google_cel_spec//tests/simple:testdata/comparisons.textproto", "@com_google_cel_spec//tests/simple:testdata/conversions.textproto", "@com_google_cel_spec//tests/simple:testdata/dynamic.textproto", + "@com_google_cel_spec//tests/simple:testdata/encoders_ext.textproto", "@com_google_cel_spec//tests/simple:testdata/enums.textproto", "@com_google_cel_spec//tests/simple:testdata/fields.textproto", "@com_google_cel_spec//tests/simple:testdata/fp_math.textproto", @@ -18,124 +144,212 @@ ALL_TESTS = [ "@com_google_cel_spec//tests/simple:testdata/lists.textproto", "@com_google_cel_spec//tests/simple:testdata/logic.textproto", "@com_google_cel_spec//tests/simple:testdata/macros.textproto", + "@com_google_cel_spec//tests/simple:testdata/macros2.textproto", + "@com_google_cel_spec//tests/simple:testdata/math_ext.textproto", "@com_google_cel_spec//tests/simple:testdata/namespace.textproto", + "@com_google_cel_spec//tests/simple:testdata/optionals.textproto", + "@com_google_cel_spec//tests/simple:testdata/parse.textproto", "@com_google_cel_spec//tests/simple:testdata/plumbing.textproto", + "@com_google_cel_spec//tests/simple:testdata/proto2.textproto", + "@com_google_cel_spec//tests/simple:testdata/proto2_ext.textproto", + "@com_google_cel_spec//tests/simple:testdata/proto3.textproto", "@com_google_cel_spec//tests/simple:testdata/string.textproto", + "@com_google_cel_spec//tests/simple:testdata/string_ext.textproto", + "@com_google_cel_spec//tests/simple:testdata/timestamps.textproto", "@com_google_cel_spec//tests/simple:testdata/unknowns.textproto", + "@com_google_cel_spec//tests/simple:testdata/wrappers.textproto", + "@com_google_cel_spec//tests/simple:testdata/block_ext.textproto", + "@com_google_cel_spec//tests/simple:testdata/type_deduction.textproto", ] -DASHBOARD_TESTS = [ - "@com_google_cel_spec//tests/simple:testdata/basic.textproto", - "@com_google_cel_spec//tests/simple:testdata/comparisons.textproto", - "@com_google_cel_spec//tests/simple:testdata/conversions.textproto", - "@com_google_cel_spec//tests/simple:testdata/dynamic.textproto", - "@com_google_cel_spec//tests/simple:testdata/enums.textproto", - "@com_google_cel_spec//tests/simple:testdata/fields.textproto", - "@com_google_cel_spec//tests/simple:testdata/fp_math.textproto", - "@com_google_cel_spec//tests/simple:testdata/integer_math.textproto", - "@com_google_cel_spec//tests/simple:testdata/lists.textproto", - "@com_google_cel_spec//tests/simple:testdata/logic.textproto", - "@com_google_cel_spec//tests/simple:testdata/macros.textproto", - "@com_google_cel_spec//tests/simple:testdata/namespace.textproto", - "@com_google_cel_spec//tests/simple:testdata/plumbing.textproto", - "@com_google_cel_spec//tests/simple:testdata/string.textproto", - "@com_google_cel_spec//tests/simple:testdata/unknowns.textproto", +_TESTS_TO_SKIP = [ + # Tests which require spec changes. + # TODO(issues/93): Deprecate Duration.getMilliseconds. + "timestamps/duration_converters/get_milliseconds", + + # Broken test cases which should be supported. + # TODO(issues/112): Unbound functions result in empty eval response. + "basic/functions/unbound", + "basic/functions/unbound_is_runtime_error", + + # TODO(issues/97): Parse-only qualified variable lookup "x.y" with binding "x.y" or "y" within container "x" fails + "fields/qualified_identifier_resolution/qualified_ident,map_field_select,ident_with_longest_prefix_check,qualified_identifier_resolution_unchecked", + "namespace/qualified/self_eval_qualified_lookup", + "namespace/namespace/self_eval_container_lookup,self_eval_container_lookup_unchecked", + # TODO(issues/117): Integer overflow on enum assignments should error. + "enums/legacy_proto2/select_big,select_neg", + + # Skip until fixed. + "wrappers/field_mask/to_json", + "wrappers/empty/to_json", + "fields/qualified_identifier_resolution/map_value_repeat_key_heterogeneous", + "parse/receiver_function_names", + + # Future features for CEL 1.0 + # TODO(issues/119): Strong typing support for enums, specified but not implemented. + "enums/strong_proto2", + "enums/strong_proto3", + + # These depend on legacy US/ timezones. It's spotty if these are included with a normally + # configured timezone database. + "timestamps/timestamp_selectors_tz/getDayOfMonth_name_pos", + "timestamps/timestamp_selectors_tz/getDayOfYear", + # These depend on using charconv (or equivalent) to format doubles with shortest possible + # precision to preserve value. Not available on older compilers where we just use absl::Format. + # We should probably update the spec to allow different formats that parse to the same value. + "conversions/string/double_hard", + + # Recent changes + "namespace/namespace_shadowing/basic", + "namespace/namespace_shadowing/comprehension_shadowing_namespaced_selector_disambiguation", ] -cc_binary( - name = "server", - testonly = 1, - srcs = ["server.cc"], - deps = [ - "//eval/public:builtin_func_registrar", - "//eval/public:cel_expr_builder_factory", - "//eval/public:transform_utility", - "//eval/public/containers:container_backed_list_impl", - "//eval/public/containers:container_backed_map_impl", - "//internal:proto_util", - "//parser", - "@com_github_grpc_grpc//:grpc++", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_googleapis//google/api/expr/v1alpha1:checked_cc_proto", - "@com_google_googleapis//google/api/expr/v1alpha1:conformance_service_cc_grpc", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", - "@com_google_googleapis//google/rpc:code_cc_proto", - "@com_google_protobuf//:protobuf", +_TESTS_TO_SKIP_MODERN = _TESTS_TO_SKIP + +_TESTS_TO_SKIP_MODERN_DASHBOARD = [ + # Future features for CEL 1.0 + # TODO(issues/119): Strong typing support for enums, specified but not implemented. + "enums/strong_proto2", + "enums/strong_proto3", +] + +_TESTS_TO_SKIP_LEGACY = _TESTS_TO_SKIP + [ + # Legacy value does not support optional_type. + "optionals/optionals", + + # TODO(uncreated-issue/81): Fix null assignment to a field + "proto2/set_null/list_value", + "proto2/set_null/single_struct", + "proto3/set_null/list_value", + "proto3/set_null/single_struct", + + # no optional support for legacy types + "block_ext/basic/optional_list", + "block_ext/basic/optional_map", + "block_ext/basic/optional_map_chained", + "block_ext/basic/optional_message", +] + +_TESTS_TO_SKIP_CHECKED = [ + # block is a post-check optimization that inserts internal variables. The C++ type checker + # needs support for a proper optimizer for this to work. + # "block_ext", +] + +_TESTS_TO_SKIP_LEGACY_DASHBOARD = [ + # Future features for CEL 1.0 + # TODO(issues/119): Strong typing support for enums, specified but not implemented. + "enums/strong_proto2", + "enums/strong_proto3", + + # Legacy value does not support optional_type. + "optionals/optionals", +] + +# Generates a bunch of `cc_test` whose names follow the pattern +# `conformance_(...)_{arena|refcount}_{optimized|unoptimized}_{recursive|iterative}`. +gen_conformance_tests( + name = "conformance_parse_only", + data = _ALL_TESTS, + modern = True, + skip_tests = _TESTS_TO_SKIP_MODERN + ["type_deductions"], +) + +gen_conformance_tests( + name = "conformance_legacy_parse_only", + data = _ALL_TESTS, + modern = False, + skip_tests = _TESTS_TO_SKIP_LEGACY + ["type_deductions"], +) + +gen_conformance_tests( + name = "conformance_checked", + checked = True, + data = _ALL_TESTS, + modern = True, + skip_tests = _TESTS_TO_SKIP_MODERN + _TESTS_TO_SKIP_CHECKED, +) + +gen_conformance_tests( + name = "conformance_legacy_checked", + checked = True, + data = _ALL_TESTS, + modern = False, + skip_tests = _TESTS_TO_SKIP_LEGACY + _TESTS_TO_SKIP_CHECKED, +) + +# select optimization is only supported for checked expressions. +gen_conformance_tests( + name = "conformance_legacy_select_opt", + checked = True, + data = _ALL_TESTS, + modern = False, + select_opt = True, + skip_tests = _TESTS_TO_SKIP_LEGACY + _TESTS_TO_SKIP_CHECKED, +) + +gen_conformance_tests( + name = "conformance_select_opt", + checked = True, + data = _ALL_TESTS, + modern = True, + select_opt = True, + skip_tests = _TESTS_TO_SKIP_MODERN + _TESTS_TO_SKIP_CHECKED, +) + +gen_conformance_tests( + name = "conformance_variadic", + checked = True, + data = _ALL_TESTS, + enable_variadic_logical_operators = True, + modern = True, + skip_tests = _TESTS_TO_SKIP_MODERN + _TESTS_TO_SKIP_CHECKED, +) + +gen_conformance_tests( + name = "conformance_legacy_variadic", + checked = True, + data = _ALL_TESTS, + enable_variadic_logical_operators = True, + modern = False, + skip_tests = _TESTS_TO_SKIP_LEGACY + _TESTS_TO_SKIP_CHECKED, +) + +# Generates a bunch of `cc_test` whose names follow the pattern +# `conformance_dashboard_..._{arena|refcount}_{optimized|unoptimized}_{recursive|iterative}`. +gen_conformance_tests( + name = "conformance_dashboard_parse_only", + dashboard = True, + data = _ALL_TESTS, + modern = True, + skip_tests = _TESTS_TO_SKIP_MODERN_DASHBOARD + ["type_deductions"], + tags = [ + "guitar", + "notap", ], ) -[ - sh_test( - name = "simple-" + driver, - srcs = [":" + driver], - args = [ - "$(location @com_google_cel_spec//tests/simple:simple_test)", - "--server=$(location :server)", - "--check_server=$(location @com_google_cel_go//server/main:cel_server)", - # TODO(issues/78): Missing bytes() conversion functions - "--skip_test=conversions/bytes", - # TODO(issues/79): Missing double() conversion functions - "--skip_test=conversions/double", - # TODO(issues/80): Missing dyn() conversion functions - "--skip_test=conversions/dyn/dyn_heterogeneous_list", - # TODO(issues/81): Conversion functions for int() which can be - # uncommented when the spec changes to truncation rather than - # rounding. - "--skip_test=conversions/int/double_nearest,double_nearest_neg,double_half_away_neg,double_half_away_pos", - # TODO(issues/82): Unexpected behavior when converting invalid bytes to string. - "--skip_test=conversions/string/bytes_invalid", - # TODO(issues/83): Missing type() conversion functions - "--skip_test=conversions/type", - # TODO(issues/84): Missing uint() conversion functions - "--skip_test=conversions/uint", - # Requires container support - "--skip_test=namespace/namespace/self_eval_container_lookup_unchecked", - "--skip_test=basic/namespace/self_eval_container_lookup,self_eval_container_lookup_unchecked", - "--skip_test=basic/self_eval_nonzeroish/self_eval_bytes_invalid_utf8", - # Requires heteregenous equality spec clarification - "--skip_test=comparisons/eq_literal/eq_bytes", - "--skip_test=comparisons/ne_literal/not_ne_bytes", - "--skip_test=comparisons/in_list_literal/elem_in_mixed_type_list_error", - "--skip_test=comparisons/in_map_literal/key_in_mixed_key_type_map_error", - "--skip_test=fields/in/singleton", - # Requires qualified bindings error message relaxation - "--skip_test=fields/qualified_identifier_resolution/ident_with_longest_prefix_check,int64_field_select_unsupported,list_field_select_unsupported,map_key_null,qualified_identifier_resolution_unchecked", - "--skip_test=integer_math/int64_math/int64_overflow_positive,int64_overflow_negative,uint64_overflow_positive,uint64_overflow_negative", - "--skip_test=string/size/one_unicode,unicode", - "--skip_test=string/bytes_concat/left_unit", - # TODO(issues/85): The exists one macro should not short-circuit false. - "--skip_test=macros/exists_one/list_no_shortcircuit", - # TODO(issues/86): Map macro may produce incorrect results on error. - "--skip_test=macros/map/list_error", - ] + ["$(location " + test + ")" for test in ALL_TESTS], - data = [ - ":server", - "@com_google_cel_go//server/main:cel_server", - "@com_google_cel_spec//tests/simple:simple_test", - ] + ALL_TESTS, - ) - for driver in [ - "test.sh", - "opt-test.sh", - ] -] +gen_conformance_tests( + name = "conformance_dashboard_checked", + checked = True, + dashboard = True, + data = _ALL_TESTS, + modern = True, + skip_tests = _TESTS_TO_SKIP_MODERN_DASHBOARD, + tags = [ + "guitar", + "notap", + ], +) -sh_test( - name = "simple-dashboard-test.sh", - srcs = ["@com_google_cel_spec//tests:conftest-nofail.sh"], - args = [ - "$(location @com_google_cel_spec//tests/simple:simple_test)", - "--server=$(location :server)", - "--check_server=$(location @com_google_cel_go//server/main:cel_server)", - ] + ["$(location " + test + ")" for test in DASHBOARD_TESTS], - data = [ - ":server", - "@com_google_cel_go//server/main:cel_server", - "@com_google_cel_spec//tests/simple:simple_test", - ] + DASHBOARD_TESTS, - visibility = [ - "//:__subpackages__", - "//third_party/cel:__pkg__", +gen_conformance_tests( + name = "conformance_dashboard_legacy_parse_only", + dashboard = True, + data = _ALL_TESTS, + modern = False, + skip_tests = _TESTS_TO_SKIP_LEGACY_DASHBOARD + ["type_deductions"], + tags = [ + "guitar", + "notap", ], ) diff --git a/conformance/opt-test.sh b/conformance/opt-test.sh deleted file mode 100755 index 087fb576d..000000000 --- a/conformance/opt-test.sh +++ /dev/null @@ -1,3 +0,0 @@ -#!/bin/bash -export CEL_CPP_ENABLE_CONSTANT_FOLDING=true -exec "$@" diff --git a/conformance/policy/BUILD b/conformance/policy/BUILD new file mode 100644 index 000000000..29210e02d --- /dev/null +++ b/conformance/policy/BUILD @@ -0,0 +1,78 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load( + "//conformance/policy:policy_conformance_test.bzl", + "cel_policy_conformance_test", +) + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "policy_conformance_test_lib", + testonly = True, + srcs = ["policy_conformance_test.cc"], + deps = [ + "//common:ast", + "//common:source", + "//common:value", + "//common/internal:value_conversion", + "//compiler", + "//env", + "//env:config", + "//env:env_runtime", + "//env:env_std_extensions", + "//env:env_yaml", + "//env:runtime_std_extensions", + "//extensions/protobuf:bind_proto_to_activation", + "//extensions/protobuf:enum_adapter", + "//internal:runfiles", + "//internal:status_macros", + "//internal:testing_descriptor_pool", + "//internal:testing_no_main", + "//policy:cel_policy", + "//policy:cel_policy_parser", + "//policy:cel_policy_validation_result", + "//policy:compiler", + "//policy:test_util", + "//policy:yaml_policy_parser", + "//runtime", + "//runtime:activation", + "//runtime:function_adapter", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_cel_spec//proto/cel/expr:value_cc_proto", + "@com_google_cel_spec//proto/cel/expr/conformance/test:suite_cc_proto", + "@com_google_protobuf//:protobuf", + ], + alwayslink = True, +) + +cel_policy_conformance_test( + name = "policy_conformance_test", + example = "@cel_policy//conformance:testdata/nested_rule/policy.yaml", + skip_tests = [ + # TODO(b/506179116): Fix these. + # Need to add k8s custom yaml parser and mock runtime. + "k8s", + ], + test_files = [ + "@cel_policy//conformance:testdata", + ], +) diff --git a/conformance/policy/policy_conformance_test.bzl b/conformance/policy/policy_conformance_test.bzl new file mode 100644 index 000000000..0b4d1a4c6 --- /dev/null +++ b/conformance/policy/policy_conformance_test.bzl @@ -0,0 +1,46 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This module contains build rules for generating policy conformance test targets. +""" + +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +def cel_policy_conformance_test(name, test_files, example, skip_tests = [], **kwargs): + """Generates a policy conformance test target. + + Args: + name: Name of the test target. + test_files: List of targets or files representing the test data. + example: A specific example file from test_files used for runfiles resolution. + skip_tests: List of test cases to skip. + testdata_dir: Path to testdata directory under runfiles. + **kwargs: Additional arguments passed to the underlying cc_test. + """ + args = ["--gunit_fail_if_no_test_linked"] + args.append("--testdata_example='$(rlocationpath {})'".format(example)) + + if skip_tests: + args.append("--skip_tests=" + ",".join(skip_tests)) + + cc_test( + name = name, + data = test_files + [example], + deps = [ + "//conformance/policy:policy_conformance_test_lib", + ], + args = args, + **kwargs + ) diff --git a/conformance/policy/policy_conformance_test.cc b/conformance/policy/policy_conformance_test.cc new file mode 100644 index 000000000..0d68f8abf --- /dev/null +++ b/conformance/policy/policy_conformance_test.cc @@ -0,0 +1,659 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +// NOLINTNEXTLINE(build/c++17) for OSS compatibility +#include + +#include "cel/expr/eval.pb.h" +#include "absl/flags/flag.h" +#include "absl/log/absl_check.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_replace.h" +#include "absl/strings/string_view.h" +#include "common/ast.h" +#include "common/internal/value_conversion.h" +#include "common/source.h" +#include "common/value.h" +#include "compiler/compiler.h" +#include "env/config.h" +#include "env/env.h" +#include "env/env_runtime.h" +#include "env/env_std_extensions.h" +#include "env/env_yaml.h" +#include "env/runtime_std_extensions.h" +#include "extensions/protobuf/bind_proto_to_activation.h" +#include "extensions/protobuf/enum_adapter.h" +#include "internal/runfiles.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "policy/cel_policy.h" +#include "policy/cel_policy_parse_result.h" +#include "policy/cel_policy_validation_result.h" +#include "policy/compiler.h" +#include "policy/test_util.h" +#include "policy/yaml_policy_parser.h" +#include "runtime/activation.h" +#include "runtime/function_adapter.h" +#include "runtime/runtime.h" +#include "cel/expr/conformance/test/suite.pb.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/dynamic_message.h" +#include "google/protobuf/message.h" +#include "google/protobuf/text_format.h" + +// Use a specific file to handle bazel runfiles resolution correctly. We find +// parent directory named 'testdata' to use as the root of the test cases. +ABSL_FLAG(std::string, testdata_example, "", + "Path to a specific example file."); +ABSL_FLAG(std::vector, skip_tests, {}, + "Comma-separated list of tests to skip."); + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::cel::expr::conformance::test::TestSuite; +using ::cel::internal::GetSharedTestingDescriptorPool; +using ::testing::HasSubstr; + +// Implementations for extension functions referenced in conformance tests. +cel::Value LocationCode(const cel::StringValue& ip, + const google::protobuf::DescriptorPool* pool, + google::protobuf::MessageFactory* factory, google::protobuf::Arena* arena) { + std::string ip_str = ip.ToString(); + if (ip_str == "10.0.0.1") return cel::StringValue(arena, "us"); + if (ip_str == "10.0.0.2") return cel::StringValue(arena, "de"); + return cel::StringValue(arena, "ir"); +} + +// TODO(uncreated-issue/92): This should be migrated to use the testrunner utility +// after adding support for reading the yaml specification for envs/tests. +class InputEvaluator { + public: + static absl::StatusOr> Create( + const std::shared_ptr& pool) { + cel::Env env; + env.SetDescriptorPool(pool); + cel::RegisterStandardExtensions(env); + + cel::EnvRuntime env_runtime; + env_runtime.SetDescriptorPool(pool); + cel::RegisterStandardExtensions(env_runtime); + env_runtime.mutable_runtime_options().enable_qualified_type_identifiers = + true; + + // Enable default extensions (optional, bindings) + cel::Config config; + CEL_RETURN_IF_ERROR(config.AddExtensionConfig( + "optional", cel::Config::ExtensionConfig::kLatest)); + CEL_RETURN_IF_ERROR(config.AddExtensionConfig( + "bindings", cel::Config::ExtensionConfig::kLatest)); + env.SetConfig(config); + env_runtime.SetConfig(config); + + auto compiler_builder_or = env.NewCompilerBuilder(); + CEL_ASSIGN_OR_RETURN(auto compiler_builder, std::move(compiler_builder_or)); + compiler_builder->GetParserBuilder().GetOptions().enable_optional_syntax = + true; + CEL_ASSIGN_OR_RETURN(auto compiler, compiler_builder->Build()); + + auto runtime_builder_or = env_runtime.CreateRuntimeBuilder(); + CEL_ASSIGN_OR_RETURN(auto runtime_builder, std::move(runtime_builder_or)); + + // Register conformance enums + for (const auto& enum_name : + {"cel.expr.conformance.proto2.GlobalEnum", + "cel.expr.conformance.proto3.GlobalEnum", + "cel.expr.conformance.proto2.TestAllTypes.NestedEnum", + "cel.expr.conformance.proto3.TestAllTypes.NestedEnum"}) { + auto* enum_desc = pool->FindEnumTypeByName(enum_name); + if (enum_desc != nullptr) { + CEL_RETURN_IF_ERROR(cel::extensions::RegisterProtobufEnum( + runtime_builder.type_registry(), enum_desc)); + } + } + + CEL_ASSIGN_OR_RETURN(auto runtime, std::move(runtime_builder).Build()); + + return absl::WrapUnique( + new InputEvaluator(std::move(compiler), std::move(runtime))); + } + + absl::StatusOr Evaluate( + absl::string_view expr_str, google::protobuf::Arena* arena, + google::protobuf::MessageFactory* message_factory) const { + CEL_ASSIGN_OR_RETURN(auto validation_result, compiler_->Compile(expr_str)); + if (!validation_result.IsValid()) { + return absl::InvalidArgumentError( + absl::StrCat("Failed to compile input expr: ", expr_str)); + } + CEL_ASSIGN_OR_RETURN(auto ast, validation_result.ReleaseAst()); + CEL_ASSIGN_OR_RETURN( + auto program, + runtime_->CreateProgram(std::make_unique(std::move(*ast)))); + cel::Activation activation; + EvaluateOptions options; + options.message_factory = message_factory; + return program->Evaluate(arena, activation, options); + } + + private: + InputEvaluator(std::unique_ptr compiler, + std::unique_ptr runtime) + : compiler_(std::move(compiler)), runtime_(std::move(runtime)) {} + + std::unique_ptr compiler_; + std::unique_ptr runtime_; +}; + +absl::StatusOr EvaluateInputValue( + const cel::expr::conformance::test::InputValue& input_val, + const InputEvaluator& evaluator, + const google::protobuf::DescriptorPool* descriptor_pool, + google::protobuf::MessageFactory* message_factory, google::protobuf::Arena* arena) { + if (input_val.has_expr()) { + return evaluator.Evaluate(input_val.expr(), arena, message_factory); + } + if (input_val.has_value()) { + return cel::test::FromExprValue(input_val.value(), descriptor_pool, + message_factory, arena); + } + return absl::InvalidArgumentError("Empty InputValue"); +} + +class CelValueMatcherImpl + : public testing::MatcherInterface { + public: + CelValueMatcherImpl(cel::Value expected_val, + const google::protobuf::DescriptorPool* pool, + google::protobuf::MessageFactory* message_factory, + google::protobuf::Arena* arena) + : expected_val_(std::move(expected_val)), + pool_(pool), + message_factory_(message_factory), + arena_(arena) {} + + bool MatchAndExplain(const cel::Value& actual_val, + testing::MatchResultListener* listener) const override { + cel::Value actual = actual_val; + if (actual.IsOptional() && !expected_val_.IsOptional()) { + auto opt_val = actual.AsOptional(); + if (opt_val->HasValue()) { + actual = opt_val->Value(); + } + } + cel::Value eq_result; + auto eq_status = actual.Equal(expected_val_, pool_, message_factory_, + arena_, &eq_result); + if (!eq_status.ok()) { + *listener << "equality check failed with status: " << eq_status; + return false; + } + if (!eq_result.IsTrue()) { + *listener << "expected: " << expected_val_.DebugString() + << "\nactual: " << actual.DebugString(); + return false; + } + return true; + } + + void DescribeTo(std::ostream* os) const override { + *os << "is equal to " << expected_val_.DebugString(); + } + + void DescribeNegationTo(std::ostream* os) const override { + *os << "is not equal to " << expected_val_.DebugString(); + } + + private: + cel::Value expected_val_; + const google::protobuf::DescriptorPool* pool_; + google::protobuf::MessageFactory* message_factory_; + google::protobuf::Arena* arena_; +}; + +absl::StatusOr> MakeExpectedValueMatcher( + const cel::expr::conformance::test::TestOutput& output, + const InputEvaluator& input_evaluator, const google::protobuf::DescriptorPool* pool, + google::protobuf::MessageFactory* message_factory, google::protobuf::Arena* arena) { + cel::Value expected_val; + if (output.has_result_expr()) { + CEL_ASSIGN_OR_RETURN( + expected_val, + input_evaluator.Evaluate(output.result_expr(), arena, message_factory)); + } else if (output.has_result_value()) { + CEL_ASSIGN_OR_RETURN(expected_val, + cel::test::FromExprValue(output.result_value(), pool, + message_factory, arena)); + } else { + return absl::InvalidArgumentError("Unsupported output kind"); + } + return testing::Matcher( + new CelValueMatcherImpl(expected_val, pool, message_factory, arena)); +} + +bool ShouldRunTest(absl::string_view test_name, + const std::vector& skip_tests) { + for (const std::string& skip : skip_tests) { + if (absl::StartsWith(test_name, skip)) { + return false; + } + } + return true; +} + +absl::Status PopulateActivation( + const cel::expr::conformance::test::TestCase& test, + const InputEvaluator& input_evaluator, + const google::protobuf::DescriptorPool* descriptor_pool, + google::protobuf::MessageFactory* message_factory, + absl::string_view context_msg_type_name, google::protobuf::Arena* arena, + Activation& activation) { + if (!test.has_input_context()) { + for (const auto& [var_name, input_val] : test.input()) { + CEL_ASSIGN_OR_RETURN( + auto val, + EvaluateInputValue(input_val, input_evaluator, descriptor_pool, + message_factory, arena)); + activation.InsertOrAssignValue(var_name, std::move(val)); + } + return absl::OkStatus(); + } + + const auto& input_context = test.input_context(); + const google::protobuf::Message* context_message = nullptr; + + if (input_context.has_context_message()) { + const google::protobuf::Any& any_msg = input_context.context_message(); + const google::protobuf::Descriptor* msg_descriptor = + descriptor_pool->FindMessageTypeByName(context_msg_type_name); + if (msg_descriptor == nullptr) { + return absl::NotFoundError(absl::StrCat( + "Failed to find message descriptor for: ", context_msg_type_name)); + } + const google::protobuf::Message* prototype = + message_factory->GetPrototype(msg_descriptor); + if (prototype == nullptr) { + return absl::NotFoundError( + absl::StrCat("Failed to get prototype for: ", context_msg_type_name)); + } + auto* buf = prototype->New(arena); + if (!any_msg.UnpackTo(buf)) { + return absl::InvalidArgumentError(absl::StrCat( + "Failed to unpack context message to ", context_msg_type_name)); + } + context_message = buf; + } else if (input_context.has_context_expr() && + !context_msg_type_name.empty()) { + CEL_ASSIGN_OR_RETURN(cel::Value evaluated_val, + input_evaluator.Evaluate(input_context.context_expr(), + arena, message_factory)); + + if (!evaluated_val.IsParsedMessage()) { + return absl::InvalidArgumentError( + absl::StrCat("Context expression did not evaluate to a message: ", + input_context.context_expr())); + } + if (evaluated_val.GetParsedMessage().GetDescriptor()->full_name() != + context_msg_type_name) { + return absl::InvalidArgumentError(absl::StrCat( + "Context expression evaluated to a message of type ", + evaluated_val.GetParsedMessage().GetDescriptor()->full_name(), + " which does not match the expected type ", context_msg_type_name)); + } + context_message = static_cast( + evaluated_val.GetParsedMessage().operator->()); + } + if (context_message == nullptr) { + return absl::InvalidArgumentError( + "Failed to resolve context message for test case"); + } + + return cel::extensions::BindProtoToActivation( + *context_message, + cel::extensions::BindProtoUnsetFieldBehavior::kBindDefaultValue, + descriptor_pool, message_factory, arena, &activation); +} + +class PolicyTestSuiteRunner { + public: + PolicyTestSuiteRunner(std::string suite_name, + std::unique_ptr compiler, + std::unique_ptr runtime, + std::shared_ptr policy_source, + CelPolicyValidationResult compile_result, + std::shared_ptr pool, + std::shared_ptr message_factory, + std::shared_ptr input_evaluator, + std::string context_msg_type_name, + bool expect_compile_fail = false) + : suite_name_(std::move(suite_name)), + compiler_(std::move(compiler)), + runtime_(std::move(runtime)), + policy_source_(std::move(policy_source)), + compile_result_(std::move(compile_result)), + pool_(std::move(pool)), + message_factory_(std::move(message_factory)), + input_evaluator_(std::move(input_evaluator)), + context_msg_type_name_(std::move(context_msg_type_name)), + expect_compile_fail_(expect_compile_fail) {} + + void RunTest(const cel::expr::conformance::test::TestCase& test, + absl::string_view full_test_name) { + const auto& output = test.output(); + + if (expect_compile_fail_) { + ASSERT_FALSE(compile_result_.IsValid()) + << "Expected compilation to fail in " << full_test_name; + ASSERT_TRUE(output.has_eval_error()) + << "Expected eval_error to be present in compile error test " + << full_test_name; + std::string err_msg = compile_result_.FormatIssues(); + for (const auto& expected_err : output.eval_error().errors()) { + EXPECT_THAT(err_msg, HasSubstr(expected_err.message())) + << "Did not find expected compile time error"; + } + return; + } + + // Compilation should have succeeded for evaluation tests + ASSERT_TRUE(compile_result_.IsValid()) + << "Compilation has validation errors in " << full_test_name << ": " + << compile_result_.FormatIssues(); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr program, + runtime_->CreateProgram(std::make_unique( + *compile_result_.GetAst()))); + + // Parse Inputs and evaluate them + google::protobuf::Arena arena; + Activation activation; + ASSERT_THAT(PopulateActivation(test, *input_evaluator_, pool_.get(), + message_factory_.get(), + context_msg_type_name_, &arena, activation), + IsOk()); + + // Evaluate Policy + auto eval_result_or = program->Evaluate(&arena, activation); + ASSERT_THAT(eval_result_or.status(), IsOk()) + << "Evaluation failed in " << full_test_name; + cel::Value actual_val = *eval_result_or; + + ASSERT_OK_AND_ASSIGN( + auto matcher, + MakeExpectedValueMatcher(output, *input_evaluator_, pool_.get(), + message_factory_.get(), &arena)); + + // Apply matcher to the output of evaluation + EXPECT_THAT(actual_val, matcher) << "Test failed: " << full_test_name; + } + + private: + std::string suite_name_; + std::unique_ptr compiler_; + std::unique_ptr runtime_; + std::shared_ptr policy_source_; + CelPolicyValidationResult compile_result_; + std::shared_ptr pool_; + std::shared_ptr message_factory_; + std::shared_ptr input_evaluator_; + std::string context_msg_type_name_; + bool expect_compile_fail_; +}; + +class CelPolicyTest : public testing::Test { + public: + explicit CelPolicyTest(std::shared_ptr runner, + cel::expr::conformance::test::TestCase test_case, + std::string full_test_name, bool skip) + : runner_(std::move(runner)), + test_case_(std::move(test_case)), + full_test_name_(std::move(full_test_name)), + skip_(skip) {} + + void TestBody() override { + if (skip_) { + GTEST_SKIP() << "Skipping test: " << full_test_name_; + } + EXPECT_NO_FATAL_FAILURE(runner_->RunTest(test_case_, full_test_name_)); + } + + private: + std::shared_ptr runner_; + cel::expr::conformance::test::TestCase test_case_; + std::string full_test_name_; + bool skip_; +}; + + +absl::Status RegisterTestSuite( + const std::filesystem::path& dir_path, const std::string& suite_name, + const std::shared_ptr& input_evaluator, + const std::shared_ptr& pool, + const std::shared_ptr& message_factory, + const std::vector& skip_tests) { + // Check if the entire suite should be skipped (prefix match) + for (const auto& skip : skip_tests) { + if (suite_name == skip || + absl::StartsWith(suite_name, absl::StrCat(skip, "/"))) { + std::cout << "[ SKIPPED SUITE ] " << suite_name << std::endl; + return absl::OkStatus(); + } + } + + std::filesystem::path policy_path = dir_path / "policy.yaml"; + std::filesystem::path tests_path = dir_path / "tests.yaml"; + bool is_yaml = true; + if (!std::filesystem::exists(tests_path)) { + tests_path = dir_path / "tests.textproto"; + is_yaml = false; + } + std::filesystem::path config_path = dir_path / "config.yaml"; + + if (!std::filesystem::exists(policy_path) || + !std::filesystem::exists(tests_path)) { + // Not a valid test suite, assume it's a directory we don't care about. + return absl::OkStatus(); + } + + // Parse Environment Config + cel::Config config; + if (std::filesystem::exists(config_path)) { + std::string config_content; + CEL_RETURN_IF_ERROR( + cel::internal::GetFileContents(config_path.string(), &config_content)); + CEL_ASSIGN_OR_RETURN(config, cel::EnvConfigFromYaml(config_content)); + } + + // Enable default extensions (optional, bindings) in the config + CEL_RETURN_IF_ERROR(config.AddExtensionConfig( + "optional", cel::Config::ExtensionConfig::kLatest)); + CEL_RETURN_IF_ERROR(config.AddExtensionConfig( + "bindings", cel::Config::ExtensionConfig::kLatest)); + + // Set up compiler & runtime environments + cel::Env env; + env.SetDescriptorPool(pool); + cel::RegisterStandardExtensions(env); + env.SetConfig(config); + + cel::EnvRuntime env_runtime; + env_runtime.SetDescriptorPool(pool); + cel::RegisterStandardExtensions(env_runtime); + env_runtime.SetConfig(config); + env_runtime.mutable_runtime_options().enable_qualified_type_identifiers = + true; + + CEL_ASSIGN_OR_RETURN(auto compiler_builder, env.NewCompilerBuilder()); + compiler_builder->GetParserBuilder().GetOptions().enable_optional_syntax = + true; + + CEL_ASSIGN_OR_RETURN(auto compiler, compiler_builder->Build()); + + CEL_ASSIGN_OR_RETURN(auto runtime_builder, + env_runtime.CreateRuntimeBuilder()); + + // Register conformance enums + for (const auto& enum_name : + {"cel.expr.conformance.proto2.GlobalEnum", + "cel.expr.conformance.proto3.GlobalEnum", + "cel.expr.conformance.proto2.TestAllTypes.NestedEnum", + "cel.expr.conformance.proto3.TestAllTypes.NestedEnum"}) { + auto* enum_desc = pool->FindEnumTypeByName(enum_name); + if (enum_desc != nullptr) { + CEL_RETURN_IF_ERROR(cel::extensions::RegisterProtobufEnum( + runtime_builder.type_registry(), enum_desc)); + } + } + + // Register locationCode in runtime + CEL_RETURN_IF_ERROR( + (cel::UnaryFunctionAdapter:: + RegisterGlobalOverload("locationCode", LocationCode, + runtime_builder.function_registry()))); + + CEL_ASSIGN_OR_RETURN(auto runtime, std::move(runtime_builder).Build()); + + // Parse Policy + std::string policy_content; + CEL_RETURN_IF_ERROR( + cel::internal::GetFileContents(policy_path.string(), &policy_content)); + CEL_ASSIGN_OR_RETURN(auto source, + cel::NewSource(policy_content, "policy.yaml")); + auto policy_source = std::make_shared(std::move(source)); + CEL_ASSIGN_OR_RETURN(CelPolicyParseResult parse_result, + cel::ParseYamlCelPolicy(policy_source)); + if (!parse_result.IsValid()) { + return absl::InvalidArgumentError( + absl::StrCat("Failed to parse policy.yaml in ", suite_name, + "\nIssues:\n", parse_result.FormattedIssues())); + } + const CelPolicy* policy = parse_result.GetPolicy(); + + // Compile Policy (unexpected non-ok status represents a bug) + CEL_ASSIGN_OR_RETURN(CelPolicyValidationResult compile_result, + CompilePolicy(*compiler, *policy)); + + std::string tests_content; + CEL_RETURN_IF_ERROR( + cel::internal::GetFileContents(tests_path.string(), &tests_content)); + TestSuite test_suite; + if (is_yaml) { + CEL_ASSIGN_OR_RETURN(test_suite, + cel::test::ParsePolicyTestSuiteYaml(tests_content)); + } else { + if (!google::protobuf::TextFormat::ParseFromString(tests_content, &test_suite)) { + return absl::InvalidArgumentError( + absl::StrCat("Failed to parse text proto in ", tests_path.string())); + } + } + + auto runner = std::make_shared( + suite_name, std::move(compiler), std::move(runtime), + std::move(policy_source), std::move(compile_result), pool, + message_factory, input_evaluator, config.GetContextType(), + /*expect_compile_fail=*/absl::StrContains(suite_name, "compile_errors")); + + for (const auto& section : test_suite.sections()) { + std::string section_name = section.name(); + for (const auto& test : section.tests()) { + std::string test_name = test.name(); + std::string full_test_name = + absl::StrCat(suite_name, "/", section_name, "/", test_name); + + bool skip = !ShouldRunTest(full_test_name, skip_tests); + + testing::RegisterTest( + suite_name.c_str(), + absl::StrCat(section_name, "/", test_name).c_str(), nullptr, + test_name.c_str(), __FILE__, __LINE__, + [runner, test, full_test_name, skip]() -> CelPolicyTest* { + return new CelPolicyTest(runner, test, full_test_name, skip); + }); + } + } + return absl::OkStatus(); +} + +void RegisterAllTests() { + // cel::google3-end + std::string testdata_example_flag = absl::GetFlag(FLAGS_testdata_example); + std::vector skip_tests = absl::GetFlag(FLAGS_skip_tests); + + std::string abs_testdata_example = + cel::internal::ResolveRunfilesPath(testdata_example_flag); + ABSL_CHECK(!abs_testdata_example.empty()) + << "Could not find testdata directory: " << testdata_example_flag; + + std::shared_ptr pool = + GetSharedTestingDescriptorPool(); + auto message_factory = + std::make_shared(pool.get()); + message_factory->SetDelegateToGeneratedFactory(true); + auto evaluator_or = InputEvaluator::Create(pool); + ABSL_CHECK_OK(evaluator_or.status()) << "Failed to create input evaluator"; + std::shared_ptr evaluator = std::move(evaluator_or.value()); + + std::filesystem::path testdata_path(abs_testdata_example); + ABSL_CHECK(std::filesystem::exists(testdata_path)) + << "Testdata path does not exist: " << testdata_path; + // walk up to find 'testdata' parent. A work around to portably + // get the expected directory from bazel. + while (!absl::EndsWith(testdata_path.string(), "testdata")) { + testdata_path = testdata_path.parent_path(); + ABSL_CHECK(testdata_path.string().size() > sizeof("testdata")) + << "could not resolve testdata directory"; + } + + for (const auto& entry : + std::filesystem::recursive_directory_iterator(testdata_path)) { + if (!entry.is_directory()) { + continue; + } + std::filesystem::path dir_path = entry.path(); + // Check if this directory has policy.yaml and tests.yaml (or + // tests.textproto) + if (std::filesystem::exists(dir_path / "policy.yaml") && + (std::filesystem::exists(dir_path / "tests.yaml") || + std::filesystem::exists(dir_path / "tests.textproto"))) { + std::string suite_name = absl::StrReplaceAll( + std::filesystem::relative(dir_path, testdata_path).string(), + {{"\\", "/"}}); + + ABSL_CHECK_OK(RegisterTestSuite(dir_path, suite_name, evaluator, pool, + message_factory, skip_tests)); + } + } +} + +} // namespace +} // namespace cel + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + + cel::RegisterAllTests(); + return RUN_ALL_TESTS(); +} diff --git a/conformance/run.bzl b/conformance/run.bzl new file mode 100644 index 000000000..8faeb6c16 --- /dev/null +++ b/conformance/run.bzl @@ -0,0 +1,132 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This module contains build rules for generating the conformance test targets. +""" + +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +_TESTS_TO_SKIP_WINDOWS = [ + # These tests depend on configuring a timezone database which isn't available in our windows + # test environment. + "timestamps/timestamp_selectors_tz/getDate", + "timestamps/timestamp_selectors_tz/getDayOfMonth_name_pos", + "timestamps/timestamp_selectors_tz/getDayOfMonth_name_neg", + "timestamps/timestamp_selectors_tz/getDayOfYear", + "timestamps/timestamp_selectors_tz/getMinutes", +] + +# Converts the list of tests to skip from the format used by the original Go test runner to a single +# flag value where each test is separated by a comma. It also performs expansion, for example +# `foo/bar,baz` becomes two entries which are `foo/bar` and `foo/baz`. +def _expand_tests_to_skip(tests_to_skip): + result = [] + for test_to_skip in tests_to_skip: + comma = test_to_skip.find(",") + if comma == -1: + result.append(test_to_skip) + continue + slash = test_to_skip.rfind("/", 0, comma) + if slash == -1: + slash = 0 + else: + slash = slash + 1 + for part in test_to_skip[slash:].split(","): + result.append(test_to_skip[0:slash] + part) + return result + +def _conformance_test_name(name, optimize, recursive): + return "_".join( + [ + name, + "optimized" if optimize else "unoptimized", + "recursive" if recursive else "iterative", + ], + ) + +def _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, dashboard, enable_variadic_logical_operators): + args = [] + if modern: + args.append("--modern") + if optimize: + args.append("--opt") + if select_opt: + args.append("--select_optimization") + if recursive: + args.append("--recursive") + if skip_check: + args.append("--skip_check") + else: + args.append("--noskip_check") + if dashboard: + args.append("--dashboard") + if enable_variadic_logical_operators: + args.append("--enable_variadic_logical_operators") + return args + +def _conformance_test(name, data, modern, optimize, recursive, select_opt, skip_check, skip_tests, tags, dashboard, enable_variadic_logical_operators): + cc_test( + name = _conformance_test_name(name, optimize, recursive), + args = _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, dashboard, enable_variadic_logical_operators) + ["$(rlocationpath {})".format(test) for test in data], + env = select( + { + "@platforms//os:windows": {"CEL_SKIP_TESTS": ",".join(skip_tests + _TESTS_TO_SKIP_WINDOWS)}, + "//conditions:default": {"CEL_SKIP_TESTS": ",".join(skip_tests)}, + }, + ), + data = data, + deps = ["//conformance:run"], + tags = tags, + ) + +def gen_conformance_tests(name, data, modern = False, checked = False, select_opt = False, dashboard = False, skip_tests = [], tags = [], enable_variadic_logical_operators = False): + """Generates conformance tests. + + Args: + name: prefix for all tests + data: textproto targets describing conformance tests + modern: run using modern APIs + checked: whether to apply type checking + select_opt: enable select optimization + dashboard: enable dashboard mode + skip_tests: tests to skip in the format of the cel-spec test runner. See documentation + in github.com/google/cel-spec/tests/simple/simple_test.go + tags: tags added to the generated targets + enable_variadic_logical_operators: enable variadic logical operators + """ + skip_check = not checked + tests = [] + for optimize in (True, False): + for recursive in (True, False): + test_name = _conformance_test_name(name, optimize, recursive) + tests.append(test_name) + _conformance_test( + name, + data, + modern = modern, + optimize = optimize, + recursive = recursive, + select_opt = select_opt, + skip_check = skip_check, + skip_tests = _expand_tests_to_skip(skip_tests), + tags = tags, + dashboard = dashboard, + enable_variadic_logical_operators = enable_variadic_logical_operators, + ) + native.test_suite( + name = name, + tests = tests, + tags = tags, + ) diff --git a/conformance/run.cc b/conformance/run.cc new file mode 100644 index 000000000..1be16ba60 --- /dev/null +++ b/conformance/run.cc @@ -0,0 +1,300 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file is a native C++ implementation of the original Go conformance test +// runner located at +// https://github.com/google/cel-spec/tree/master/tests/simple. It was ported to +// C++ to avoid having to pull in Go, gRPC, and others just to run C++ +// conformance tests; as well as integrating better with C++ testing +// infrastructure. + +#include +#include +#include +#include +#include +#include +#include + +#include "cel/expr/checked.pb.h" +#include "google/api/expr/conformance/v1alpha1/conformance_service.pb.h" +#include "cel/expr/eval.pb.h" +#include "google/api/expr/v1alpha1/checked.pb.h" // IWYU pragma: keep +#include "google/api/expr/v1alpha1/eval.pb.h" +#include "google/api/expr/v1alpha1/syntax.pb.h" // IWYU pragma: keep +#include "google/api/expr/v1alpha1/value.pb.h" +#include "cel/expr/value.pb.h" +#include "google/rpc/code.pb.h" +#include "absl/flags/flag.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/strings/cord.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "absl/strings/strip.h" +#include "absl/types/span.h" +#include "conformance/service.h" +#include "conformance/utils.h" +#include "internal/runfiles.h" +#include "internal/testing.h" +#include "cel/expr/conformance/test/simple.pb.h" +#include "google/protobuf/io/zero_copy_stream_impl.h" +#include "google/protobuf/message.h" +#include "google/protobuf/text_format.h" + +ABSL_FLAG(bool, opt, false, "Enable optimizations (constant folding)"); +ABSL_FLAG( + bool, modern, false, + "Use modern cel::Value APIs implementation of the conformance service."); +ABSL_FLAG(bool, recursive, false, + "Enable recursive plans. Depth limited to slightly more than the " + "default nesting limit."); +ABSL_FLAG(std::vector, skip_tests, {}, "Tests to skip"); +ABSL_FLAG(bool, dashboard, false, "Dashboard mode, ignore test failures"); +ABSL_FLAG(bool, skip_check, true, "Skip type checking the expressions"); +ABSL_FLAG(bool, select_optimization, false, "Enable select optimization."); +ABSL_FLAG(bool, enable_variadic_logical_operators, false, + "Enable parsing logical AND & OR operators as a single flat variadic " + "call."); + +namespace { + +using cel::expr::conformance::test::SimpleTest; +using cel::expr::conformance::test::SimpleTestFile; +using google::api::expr::conformance::v1alpha1::CheckRequest; +using google::api::expr::conformance::v1alpha1::CheckResponse; +using google::api::expr::conformance::v1alpha1::EvalRequest; +using google::api::expr::conformance::v1alpha1::EvalResponse; +using google::api::expr::conformance::v1alpha1::ParseRequest; +using google::api::expr::conformance::v1alpha1::ParseResponse; +using ::testing::IsEmpty; + +google::rpc::Code ToGrpcCode(absl::StatusCode code) { + return static_cast(code); +} + +bool ShouldSkipTest(absl::Span tests_to_skip, + absl::string_view name) { + for (absl::string_view test_to_skip : tests_to_skip) { + auto consumed_name = name; + if (absl::ConsumePrefix(&consumed_name, test_to_skip) && + (consumed_name.empty() || absl::StartsWith(consumed_name, "/"))) { + return true; + } + } + return false; +} + +SimpleTest DefaultTestMatcherToTrueIfUnset(const SimpleTest& test) { + auto test_copy = test; + if (test_copy.result_matcher_case() == SimpleTest::RESULT_MATCHER_NOT_SET) { + test_copy.mutable_value()->set_bool_value(true); + } + return test_copy; +} + +class ConformanceTest : public testing::Test { + public: + explicit ConformanceTest( + std::shared_ptr service, + const SimpleTest& test, bool skip) + : service_(std::move(service)), + test_(DefaultTestMatcherToTrueIfUnset(test)), + skip_(skip) {} + + void TestBody() override { + if (skip_) { + GTEST_SKIP(); + } + ParseRequest parse_request; + parse_request.set_cel_source(test_.expr()); + parse_request.set_source_location(test_.name()); + parse_request.set_disable_macros(test_.disable_macros()); + ParseResponse parse_response; + service_->Parse(parse_request, parse_response); + ASSERT_THAT(parse_response.issues(), IsEmpty()); + + EvalRequest eval_request; + if (!test_.container().empty()) { + eval_request.set_container(test_.container()); + } + if (!test_.bindings().empty()) { + for (const auto& binding : test_.bindings()) { + absl::Cord serialized; + ABSL_CHECK(binding.second.SerializePartialToString(&serialized)); + ABSL_CHECK((*eval_request.mutable_bindings())[binding.first] + .ParsePartialFromString(serialized)); + } + } + + if (absl::GetFlag(FLAGS_skip_check) || test_.disable_check()) { + eval_request.set_allocated_parsed_expr( + parse_response.release_parsed_expr()); + } else { + CheckRequest check_request; + check_request.set_allocated_parsed_expr( + parse_response.release_parsed_expr()); + check_request.set_container(test_.container()); + for (const auto& type_env : test_.type_env()) { + absl::Cord serialized; + ABSL_CHECK(type_env.SerializePartialToString(&serialized)); + ABSL_CHECK( + check_request.add_type_env()->ParsePartialFromString(serialized)); + } + CheckResponse check_response; + service_->Check(check_request, check_response); + ASSERT_THAT(check_response.issues(), IsEmpty()) << absl::StrCat( + "unexpected type check issues for: '", test_.expr(), "'\n"); + eval_request.set_allocated_checked_expr( + check_response.release_checked_expr()); + } + + if (test_.check_only()) { + ASSERT_TRUE(test_.has_typed_result()) + << "test must specify a typed result if check_only is set"; + EXPECT_THAT(eval_request.checked_expr(), + cel_conformance::ResultTypeMatches( + test_.typed_result().deduced_type())); + return; + } + + EvalResponse eval_response; + if (auto status = service_->Eval(eval_request, eval_response); + !status.ok()) { + auto* issue = eval_response.add_issues(); + issue->set_message(status.message()); + issue->set_code(ToGrpcCode(status.code())); + } + ASSERT_TRUE(eval_response.has_result()) << eval_response; + switch (test_.result_matcher_case()) { + case SimpleTest::kValue: { + absl::Cord serialized; + ABSL_CHECK( + eval_response.result().SerializePartialToString(&serialized)); + cel::expr::ExprValue test_value; + ABSL_CHECK(test_value.ParsePartialFromString(serialized)); + EXPECT_THAT(test_value, + cel_conformance::MatchesConformanceValue(test_.value())); + break; + } + case SimpleTest::kTypedResult: { + ASSERT_TRUE(eval_request.has_checked_expr()) + << "expression was not type checked"; + absl::Cord serialized; + ABSL_CHECK( + eval_response.result().SerializePartialToString(&serialized)); + cel::expr::ExprValue test_value; + ABSL_CHECK(test_value.ParsePartialFromString(serialized)); + EXPECT_THAT(test_value, cel_conformance::MatchesConformanceValue( + test_.typed_result().result())); + EXPECT_THAT(eval_request.checked_expr(), + cel_conformance::ResultTypeMatches( + test_.typed_result().deduced_type())); + break; + } + case SimpleTest::kEvalError: + EXPECT_TRUE(eval_response.result().has_error()) + << eval_response.result(); + break; + default: + ADD_FAILURE() << "unexpected matcher kind: " + << test_.result_matcher_case(); + break; + } + } + + private: + const std::shared_ptr service_; + const SimpleTest test_; + const bool skip_; +}; + +absl::Status RegisterTestsFromFile( + const std::shared_ptr& + service, + absl::Span tests_to_skip, absl::string_view path) { + SimpleTestFile file; + { + std::ifstream in; + in.open(std::string(path), std::ios_base::in | std::ios_base::binary); + if (!in.is_open()) { + return absl::UnknownError(absl::StrCat("failed to open file: ", path)); + } + google::protobuf::io::IstreamInputStream stream(&in); + if (!google::protobuf::TextFormat::Parse(&stream, &file)) { + return absl::UnknownError(absl::StrCat("failed to parse file: ", path)); + } + } + for (const auto& section : file.section()) { + for (const auto& test : section.test()) { + const bool skip = ShouldSkipTest( + tests_to_skip, + absl::StrCat(file.name(), "/", section.name(), "/", test.name())); + testing::RegisterTest( + file.name().c_str(), + absl::StrCat(section.name(), "/", test.name()).c_str(), nullptr, + nullptr, __FILE__, __LINE__, [=]() -> ConformanceTest* { + return new ConformanceTest(service, test, skip); + }); + } + } + return absl::OkStatus(); +} + +// We could push this do be done per test or suite, but to avoid changing more +// than necessary we do it once to mimic the previous runner. +std::shared_ptr +NewConformanceServiceFromFlags() { + auto status_or_service = cel_conformance::NewConformanceService( + cel_conformance::ConformanceServiceOptions{ + .optimize = absl::GetFlag(FLAGS_opt), + .modern = absl::GetFlag(FLAGS_modern), + .recursive = absl::GetFlag(FLAGS_recursive), + .select_optimization = absl::GetFlag(FLAGS_select_optimization), + .enable_variadic_logical_operators = + absl::GetFlag(FLAGS_enable_variadic_logical_operators), + }); + ABSL_CHECK_OK(status_or_service); + return std::shared_ptr( + std::move(*status_or_service)); +} + +} // namespace + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + { + auto service = NewConformanceServiceFromFlags(); + auto tests_to_skip = absl::GetFlag(FLAGS_skip_tests); + if (const char* env_skip = std::getenv("CEL_SKIP_TESTS"); + env_skip != nullptr) { + for (absl::string_view test : + absl::StrSplit(env_skip, ',', absl::SkipEmpty())) { + tests_to_skip.push_back(std::string(test)); + } + } + for (int argi = 1; argi < argc; argi++) { + std::string path = cel::internal::ResolveRunfilesPath(argv[argi]); + ABSL_CHECK_OK(RegisterTestsFromFile(service, tests_to_skip, + absl::string_view(path))); + } + } + int exit_code = RUN_ALL_TESTS(); + if (absl::GetFlag(FLAGS_dashboard)) { + exit_code = EXIT_SUCCESS; + } + return exit_code; +} diff --git a/conformance/server.cc b/conformance/server.cc deleted file mode 100644 index cd3e22aee..000000000 --- a/conformance/server.cc +++ /dev/null @@ -1,162 +0,0 @@ -#include "google/api/expr/v1alpha1/syntax.pb.h" -#include "google/api/expr/v1alpha1/checked.pb.h" -#include "google/api/expr/v1alpha1/conformance_service.grpc.pb.h" -#include "google/api/expr/v1alpha1/eval.pb.h" -#include "google/api/expr/v1alpha1/syntax.pb.h" -#include "google/api/expr/v1alpha1/value.pb.h" -#include "google/protobuf/duration.pb.h" -#include "google/protobuf/struct.pb.h" -#include "google/protobuf/timestamp.pb.h" -#include "google/rpc/code.pb.h" -#include "grpcpp/grpcpp.h" -#include "absl/strings/str_split.h" -#include "eval/public/builtin_func_registrar.h" -#include "eval/public/cel_expr_builder_factory.h" -#include "eval/public/containers/container_backed_list_impl.h" -#include "eval/public/containers/container_backed_map_impl.h" -#include "eval/public/transform_utility.h" -#include "internal/proto_util.h" -#include "parser/parser.h" -#include "absl/status/statusor.h" - - -using ::grpc::Status; -using ::grpc::StatusCode; -using ::google::protobuf::Arena; - -namespace google { -namespace api { -namespace expr { -namespace runtime { - -class ConformanceServiceImpl final - : public v1alpha1::ConformanceService::Service { - public: - ConformanceServiceImpl(std::unique_ptr builder) - : builder_(std::move(builder)) {} - Status Parse(grpc::ServerContext* context, - const v1alpha1::ParseRequest* request, - v1alpha1::ParseResponse* response) override { - if (request->cel_source().empty()) { - return Status(StatusCode::INVALID_ARGUMENT, "No source code."); - } - auto parse_status = parser::Parse(request->cel_source(), ""); - if (!parse_status.ok()) { - auto issue = response->add_issues(); - *issue->mutable_message() = std::string(parse_status.status().message()); - issue->set_code(google::rpc::Code::INVALID_ARGUMENT); - } else { - google::api::expr::v1alpha1::ParsedExpr out; - (out).MergeFrom(parse_status.value()); - response->mutable_parsed_expr()->CopyFrom(out); - } - return Status::OK; - } - Status Check(grpc::ServerContext* context, - const v1alpha1::CheckRequest* request, - v1alpha1::CheckResponse* response) override { - return Status(StatusCode::UNIMPLEMENTED, "Check is not supported"); - } - Status Eval(grpc::ServerContext* context, - const v1alpha1::EvalRequest* request, - v1alpha1::EvalResponse* response) override { - const v1alpha1::Expr* expr = nullptr; - if (request->has_parsed_expr()) { - expr = &request->parsed_expr().expr(); - } else if (request->has_checked_expr()) { - expr = &request->checked_expr().expr(); - } - - Arena arena; - google::api::expr::v1alpha1::SourceInfo source_info; - google::api::expr::v1alpha1::Expr out; - (out).MergeFrom(*expr); - auto cel_expression_status = builder_->CreateExpression(&out, &source_info); - - if (!cel_expression_status.ok()) { - return Status(StatusCode::INTERNAL, - std::string(cel_expression_status.status().message())); - } - - auto cel_expression = std::move(cel_expression_status.value()); - Activation activation; - - for (const auto& pair : request->bindings()) { - auto* import_value = - Arena::CreateMessage(&arena); - (*import_value).MergeFrom(pair.second.value()); - auto import_status = ValueToCelValue(*import_value, &arena); - if (!import_status.ok()) { - return Status(StatusCode::INTERNAL, import_status.status().ToString()); - } - activation.InsertValue(pair.first, import_status.value()); - } - - auto eval_status = cel_expression->Evaluate(activation, &arena); - if (!eval_status.ok()) { - return Status(StatusCode::INTERNAL, - std::string(eval_status.status().message())); - } - - CelValue result = eval_status.value(); - if (result.IsError()) { - *response->mutable_result() - ->mutable_error() - ->add_errors() - ->mutable_message() = std::string(result.ErrorOrDie()->message()); - } else { - google::api::expr::v1alpha1::Value export_value; - auto export_status = CelValueToValue(result, &export_value); - if (!export_status.ok()) { - return Status(StatusCode::INTERNAL, export_status.ToString()); - } - auto* result_value = response->mutable_result()->mutable_value(); - (*result_value).MergeFrom(export_value); - } - return Status::OK; - } - - private: - std::unique_ptr builder_; -}; - -int RunServer(std::string server_address) { - google::protobuf::Arena arena; - InterpreterOptions options; - - const char* enable_constant_folding = - getenv("CEL_CPP_ENABLE_CONSTANT_FOLDING"); - if (enable_constant_folding != nullptr) { - options.constant_folding = true; - options.constant_arena = &arena; - } - - std::unique_ptr builder = - CreateCelExpressionBuilder(options); - auto register_status = RegisterBuiltinFunctions(builder->GetRegistry()); - if (!register_status.ok()) { - return 1; - } - - ConformanceServiceImpl service(std::move(builder)); - grpc::ServerBuilder grpc_builder; - int port; - grpc_builder.AddListeningPort(server_address, - grpc::InsecureServerCredentials(), &port); - grpc_builder.RegisterService(&service); - std::unique_ptr server(grpc_builder.BuildAndStart()); - std::cout << "Listening on 127.0.0.1:" << port << std::endl; - fflush(stdout); - server->Wait(); - return 0; -} - -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google - -int main(int argc, char** argv) { - std::string server_address = "127.0.0.1:0"; - return google::api::expr::runtime::RunServer(server_address); -} diff --git a/conformance/service.cc b/conformance/service.cc new file mode 100644 index 000000000..d81200cad --- /dev/null +++ b/conformance/service.cc @@ -0,0 +1,684 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "conformance/service.h" + +#include +#include +#include +#include + +#include "google/api/expr/conformance/v1alpha1/conformance_service.pb.h" +#include "cel/expr/syntax.pb.h" +#include "google/api/expr/v1alpha1/checked.pb.h" +#include "google/api/expr/v1alpha1/eval.pb.h" +#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "google/api/expr/v1alpha1/value.pb.h" +#include "google/protobuf/duration.pb.h" +#include "google/protobuf/empty.pb.h" +#include "google/protobuf/struct.pb.h" +#include "google/protobuf/timestamp.pb.h" +#include "google/rpc/code.pb.h" +#include "google/rpc/status.pb.h" +#include "absl/log/absl_check.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/match.h" +#include "absl/strings/string_view.h" +#include "absl/strings/strip.h" +#include "checker/optional.h" +#include "checker/standard_library.h" +#include "checker/type_checker_builder.h" +#include "checker/type_checker_builder_factory.h" +#include "common/ast.h" +#include "common/ast_proto.h" +#include "common/decl_proto_v1alpha1.h" +#include "common/internal/value_conversion.h" +#include "common/source.h" +#include "common/value.h" +#include "eval/public/activation.h" +#include "eval/public/builtin_func_registrar.h" +#include "eval/public/cel_expr_builder_factory.h" +#include "eval/public/cel_expression.h" +#include "eval/public/cel_options.h" +#include "eval/public/cel_value.h" +#include "eval/public/transform_utility.h" +#include "extensions/bindings_ext.h" +#include "extensions/comprehensions_v2.h" +#include "extensions/comprehensions_v2_functions.h" +#include "extensions/comprehensions_v2_macros.h" +#include "extensions/encoders.h" +#include "extensions/math_ext.h" +#include "extensions/math_ext_decls.h" +#include "extensions/math_ext_macros.h" +#include "extensions/proto_ext.h" +#include "extensions/protobuf/enum_adapter.h" +#include "extensions/select_optimization.h" +#include "extensions/strings.h" +#include "internal/status_macros.h" +#include "parser/macro_registry.h" +#include "parser/options.h" +#include "parser/parser.h" +#include "parser/standard_macros.h" +#include "runtime/activation.h" +#include "runtime/constant_folding.h" +#include "runtime/optional_types.h" +#include "runtime/reference_resolver.h" +#include "runtime/runtime.h" +#include "runtime/runtime_options.h" +#include "runtime/standard_runtime_builder_factory.h" +#include "testutil/test_macros.h" +#include "cel/expr/conformance/proto2/test_all_types.pb.h" +#include "cel/expr/conformance/proto2/test_all_types_extensions.pb.h" +#include "cel/expr/conformance/proto3/test_all_types.pb.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +using ::cel::CreateStandardRuntimeBuilder; +using ::cel::Runtime; +using ::cel::RuntimeOptions; +using ::cel::extensions::RegisterProtobufEnum; +using ::cel::test::ConvertWireCompatProto; +using ::cel::test::FromExprValue; +using ::cel::test::ToExprValue; + +using ::google::protobuf::Arena; + +namespace google::api::expr::runtime { + +namespace { + +google::rpc::Code ToGrpcCode(absl::StatusCode code) { + return static_cast(code); +} + +using ConformanceServiceInterface = + ::cel_conformance::ConformanceServiceInterface; + +// Return a normalized raw expr for evaluation. +cel::expr::Expr ExtractExpr( + const conformance::v1alpha1::EvalRequest& request) { + const v1alpha1::Expr* expr = nullptr; + + // For now, discard type-check information if any. + if (request.has_parsed_expr()) { + expr = &request.parsed_expr().expr(); + } else if (request.has_checked_expr()) { + expr = &request.checked_expr().expr(); + } + cel::expr::Expr out; + if (expr != nullptr) { + ABSL_CHECK(ConvertWireCompatProto(*expr, &out)); // Crash OK + } + return out; +} + +absl::Status LegacyParse(const conformance::v1alpha1::ParseRequest& request, + conformance::v1alpha1::ParseResponse& response, + bool enable_optional_syntax, + bool enable_variadic_logical_operators) { + if (request.cel_source().empty()) { + return absl::InvalidArgumentError("no source code"); + } + cel::ParserOptions options; + options.enable_optional_syntax = enable_optional_syntax; + options.enable_quoted_identifiers = true; + options.enable_variadic_logical_operators = enable_variadic_logical_operators; + cel::MacroRegistry macros; + CEL_RETURN_IF_ERROR(cel::RegisterStandardMacros(macros, options)); + CEL_RETURN_IF_ERROR( + cel::extensions::RegisterComprehensionsV2Macros(macros, options)); + CEL_RETURN_IF_ERROR(cel::extensions::RegisterBindingsMacros(macros, options)); + CEL_RETURN_IF_ERROR(cel::extensions::RegisterMathMacros(macros, options)); + CEL_RETURN_IF_ERROR(cel::extensions::RegisterProtoMacros(macros, options)); + CEL_RETURN_IF_ERROR(cel::test::RegisterTestMacros(macros)); + CEL_ASSIGN_OR_RETURN(auto source, cel::NewSource(request.cel_source(), + request.source_location())); + CEL_ASSIGN_OR_RETURN(auto parsed_expr, + parser::Parse(*source, macros, options)); + ABSL_CHECK( // Crash OK + ConvertWireCompatProto(parsed_expr, response.mutable_parsed_expr())); + return absl::OkStatus(); +} + +absl::Status CheckImpl(google::protobuf::Arena* arena, + const conformance::v1alpha1::CheckRequest& request, + conformance::v1alpha1::CheckResponse& response) { + cel::expr::ParsedExpr parsed_expr; + + ABSL_CHECK(ConvertWireCompatProto(request.parsed_expr(), // Crash OK + &parsed_expr)); + + CEL_ASSIGN_OR_RETURN(std::unique_ptr ast, + cel::CreateAstFromParsedExpr(parsed_expr)); + + absl::string_view location = parsed_expr.source_info().location(); + std::unique_ptr source; + if (absl::StartsWith(location, "Source: ")) { + location = absl::StripPrefix(location, "Source: "); + CEL_ASSIGN_OR_RETURN(source, cel::NewSource(location)); + } + + CEL_ASSIGN_OR_RETURN( + std::unique_ptr builder, + cel::CreateTypeCheckerBuilder(google::protobuf::DescriptorPool::generated_pool())); + + if (!request.no_std_env()) { + CEL_RETURN_IF_ERROR(builder->AddLibrary(cel::StandardCheckerLibrary())); + CEL_RETURN_IF_ERROR(builder->AddLibrary(cel::OptionalCheckerLibrary())); + CEL_RETURN_IF_ERROR( + builder->AddLibrary(cel::extensions::BindingsCheckerLibrary())); + CEL_RETURN_IF_ERROR( + builder->AddLibrary(cel::extensions::StringsCheckerLibrary())); + CEL_RETURN_IF_ERROR( + builder->AddLibrary(cel::extensions::MathCheckerLibrary())); + CEL_RETURN_IF_ERROR( + builder->AddLibrary(cel::extensions::EncodersCheckerLibrary())); + CEL_RETURN_IF_ERROR( + builder->AddLibrary(cel::extensions::ComprehensionsV2CheckerLibrary())); + } + + for (const auto& decl : request.type_env()) { + const auto& name = decl.name(); + if (decl.has_function()) { + CEL_ASSIGN_OR_RETURN( + auto fn_decl, cel::FunctionDeclFromV1Alpha1Proto( + name, decl.function(), + google::protobuf::DescriptorPool::generated_pool(), arena)); + CEL_RETURN_IF_ERROR(builder->AddFunction(std::move(fn_decl))); + } else if (decl.has_ident()) { + CEL_ASSIGN_OR_RETURN( + auto var_decl, cel::VariableDeclFromV1Alpha1Proto( + name, decl.ident(), + google::protobuf::DescriptorPool::generated_pool(), arena)); + CEL_RETURN_IF_ERROR(builder->AddVariable(std::move(var_decl))); + } + } + builder->set_container(request.container()); + + CEL_ASSIGN_OR_RETURN(auto checker, std::move(*builder).Build()); + + CEL_ASSIGN_OR_RETURN(auto validation_result, checker->Check(std::move(ast))); + + for (const auto& checker_issue : validation_result.GetIssues()) { + auto* issue = response.add_issues(); + issue->set_code(ToGrpcCode(absl::StatusCode::kInvalidArgument)); + if (source) { + issue->set_message(checker_issue.ToDisplayString(*source)); + } else { + issue->set_message(checker_issue.message()); + } + } + + const cel::Ast* checked_ast = validation_result.GetAst(); + if (!validation_result.IsValid() || checked_ast == nullptr) { + return absl::OkStatus(); + } + cel::expr::CheckedExpr pb_checked_ast; + CEL_RETURN_IF_ERROR( + cel::AstToCheckedExpr(*validation_result.GetAst(), &pb_checked_ast)); + ABSL_CHECK(ConvertWireCompatProto(pb_checked_ast, // Crash OK + response.mutable_checked_expr())); + return absl::OkStatus(); +} + +class LegacyConformanceServiceImpl : public ConformanceServiceInterface { + public: + static absl::StatusOr> Create( + bool optimize, bool recursive, bool select_optimization, + bool enable_variadic_logical_operators) { + static auto* constant_arena = new Arena(); + + google::protobuf::LinkMessageReflection< + cel::expr::conformance::proto3::TestAllTypes>(); + google::protobuf::LinkMessageReflection< + cel::expr::conformance::proto2::TestAllTypes>(); + google::protobuf::LinkMessageReflection< + cel::expr::conformance::proto3::NestedTestAllTypes>(); + google::protobuf::LinkMessageReflection< + cel::expr::conformance::proto2::NestedTestAllTypes>(); + google::protobuf::LinkExtensionReflection(cel::expr::conformance::proto2::int32_ext); + google::protobuf::LinkExtensionReflection(cel::expr::conformance::proto2::nested_ext); + google::protobuf::LinkExtensionReflection( + cel::expr::conformance::proto2::test_all_types_ext); + google::protobuf::LinkExtensionReflection( + cel::expr::conformance::proto2::nested_enum_ext); + google::protobuf::LinkExtensionReflection( + cel::expr::conformance::proto2::repeated_test_all_types); + google::protobuf::LinkExtensionReflection( + cel::expr::conformance::proto2::Proto2ExtensionScopedMessage:: + int64_ext); + google::protobuf::LinkExtensionReflection( + cel::expr::conformance::proto2::Proto2ExtensionScopedMessage:: + message_scoped_nested_ext); + google::protobuf::LinkExtensionReflection( + cel::expr::conformance::proto2::Proto2ExtensionScopedMessage:: + nested_enum_ext); + google::protobuf::LinkExtensionReflection( + cel::expr::conformance::proto2::Proto2ExtensionScopedMessage:: + message_scoped_repeated_test_all_types); + + InterpreterOptions options; + options.enable_qualified_type_identifiers = true; + options.enable_timestamp_duration_overflow_errors = true; + options.enable_heterogeneous_equality = true; + options.enable_empty_wrapper_null_unboxing = true; + options.enable_qualified_identifier_rewrites = true; + options.fail_on_warnings = false; + + if (optimize) { + std::cerr << "Enabling optimizations" << std::endl; + options.constant_folding = true; + options.constant_arena = constant_arena; + } + + if (select_optimization) { + std::cerr << "Enabling select optimizations" << std::endl; + options.enable_select_optimization = true; + } + + if (recursive) { + options.max_recursion_depth = 48; + } + + std::unique_ptr builder = + CreateCelExpressionBuilder(options); + auto type_registry = builder->GetTypeRegistry(); + type_registry->Register( + cel::expr::conformance::proto2::GlobalEnum_descriptor()); + type_registry->Register( + cel::expr::conformance::proto3::GlobalEnum_descriptor()); + type_registry->Register( + cel::expr::conformance::proto2::TestAllTypes::NestedEnum_descriptor()); + type_registry->Register( + cel::expr::conformance::proto3::TestAllTypes::NestedEnum_descriptor()); + CEL_RETURN_IF_ERROR( + RegisterBuiltinFunctions(builder->GetRegistry(), options)); + CEL_RETURN_IF_ERROR(cel::extensions::RegisterComprehensionsV2Functions( + builder->GetRegistry(), options)); + CEL_RETURN_IF_ERROR(cel::extensions::RegisterEncodersFunctions( + builder->GetRegistry(), options)); + CEL_RETURN_IF_ERROR(cel::extensions::RegisterStringsFunctions( + builder->GetRegistry(), options)); + CEL_RETURN_IF_ERROR(cel::extensions::RegisterMathExtensionFunctions( + builder->GetRegistry(), options)); + + return absl::WrapUnique(new LegacyConformanceServiceImpl( + std::move(builder), enable_variadic_logical_operators)); + } + + void Parse(const conformance::v1alpha1::ParseRequest& request, + conformance::v1alpha1::ParseResponse& response) override { + auto status = + LegacyParse(request, response, /*enable_optional_syntax=*/false, + enable_variadic_logical_operators_); + if (!status.ok()) { + auto* issue = response.add_issues(); + issue->set_code(ToGrpcCode(status.code())); + issue->set_message(status.message()); + } + } + + void Check(const conformance::v1alpha1::CheckRequest& request, + conformance::v1alpha1::CheckResponse& response) override { + google::protobuf::Arena arena; + auto status = CheckImpl(&arena, request, response); + if (!status.ok()) { + auto* issue = response.add_issues(); + issue->set_code(ToGrpcCode(status.code())); + issue->set_message(status.message()); + } + } + + absl::Status Eval(const conformance::v1alpha1::EvalRequest& request, + conformance::v1alpha1::EvalResponse& response) override { + Arena arena; + cel::expr::SourceInfo source_info; + cel::expr::Expr expr = ExtractExpr(request); + builder_->set_container(request.container()); + absl::StatusOr> cel_expression_status = + absl::InternalError( + "no expression provided in ConformanceService::Eval"); + + if (request.has_parsed_expr()) { + cel::expr::ParsedExpr parsed_expr; + if (!ConvertWireCompatProto(request.parsed_expr(), &parsed_expr)) { + return absl::InternalError( + "failed to convert versioned ParsedExpr to unversioned"); + } + cel_expression_status = builder_->CreateExpression( + &parsed_expr.expr(), &parsed_expr.source_info()); + } else if (request.has_checked_expr()) { + cel::expr::CheckedExpr checked_expr; + if (!ConvertWireCompatProto(request.checked_expr(), &checked_expr)) { + return absl::InternalError( + "failed to convert versioned CheckedExpr to unversioned"); + } + cel_expression_status = builder_->CreateExpression(&checked_expr); + } + + if (!cel_expression_status.ok()) { + return absl::InternalError(cel_expression_status.status().ToString( + absl::StatusToStringMode::kWithEverything)); + } + + auto cel_expression = std::move(cel_expression_status.value()); + Activation activation; + + for (const auto& pair : request.bindings()) { + auto* import_value = Arena::Create(&arena); + ABSL_CHECK(ConvertWireCompatProto(pair.second.value(), // Crash OK + import_value)); + auto import_status = ValueToCelValue(*import_value, &arena); + if (!import_status.ok()) { + return absl::InternalError(import_status.status().ToString( + absl::StatusToStringMode::kWithEverything)); + } + activation.InsertValue(pair.first, import_status.value()); + } + + auto eval_status = cel_expression->Evaluate(activation, &arena); + if (!eval_status.ok()) { + *response.mutable_result() + ->mutable_error() + ->add_errors() + ->mutable_message() = eval_status.status().ToString( + absl::StatusToStringMode::kWithEverything); + return absl::OkStatus(); + } + + CelValue result = eval_status.value(); + if (result.IsError()) { + *response.mutable_result() + ->mutable_error() + ->add_errors() + ->mutable_message() = std::string(result.ErrorOrDie()->ToString( + absl::StatusToStringMode::kWithEverything)); + } else { + cel::expr::Value export_value; + auto export_status = CelValueToValue(result, &export_value); + if (!export_status.ok()) { + return absl::InternalError( + export_status.ToString(absl::StatusToStringMode::kWithEverything)); + } + auto* result_value = response.mutable_result()->mutable_value(); + ABSL_CHECK( // Crash OK + ConvertWireCompatProto(export_value, result_value)); + } + return absl::OkStatus(); + } + + private: + LegacyConformanceServiceImpl(std::unique_ptr builder, + bool enable_variadic_logical_operators) + : builder_(std::move(builder)), + enable_variadic_logical_operators_(enable_variadic_logical_operators) {} + + std::unique_ptr builder_; + bool enable_variadic_logical_operators_; +}; + +class ModernConformanceServiceImpl : public ConformanceServiceInterface { + public: + static absl::StatusOr> Create( + bool optimize, bool recursive, bool select_optimization, + bool enable_variadic_logical_operators) { + google::protobuf::LinkMessageReflection< + cel::expr::conformance::proto3::TestAllTypes>(); + google::protobuf::LinkMessageReflection< + cel::expr::conformance::proto2::TestAllTypes>(); + google::protobuf::LinkMessageReflection< + cel::expr::conformance::proto3::NestedTestAllTypes>(); + google::protobuf::LinkMessageReflection< + cel::expr::conformance::proto2::NestedTestAllTypes>(); + google::protobuf::LinkExtensionReflection(cel::expr::conformance::proto2::int32_ext); + google::protobuf::LinkExtensionReflection(cel::expr::conformance::proto2::nested_ext); + google::protobuf::LinkExtensionReflection( + cel::expr::conformance::proto2::test_all_types_ext); + google::protobuf::LinkExtensionReflection( + cel::expr::conformance::proto2::nested_enum_ext); + google::protobuf::LinkExtensionReflection( + cel::expr::conformance::proto2::repeated_test_all_types); + google::protobuf::LinkExtensionReflection( + cel::expr::conformance::proto2::Proto2ExtensionScopedMessage:: + int64_ext); + google::protobuf::LinkExtensionReflection( + cel::expr::conformance::proto2::Proto2ExtensionScopedMessage:: + message_scoped_nested_ext); + google::protobuf::LinkExtensionReflection( + cel::expr::conformance::proto2::Proto2ExtensionScopedMessage:: + nested_enum_ext); + google::protobuf::LinkExtensionReflection( + cel::expr::conformance::proto2::Proto2ExtensionScopedMessage:: + message_scoped_repeated_test_all_types); + + RuntimeOptions options; + options.enable_qualified_type_identifiers = true; + options.enable_timestamp_duration_overflow_errors = true; + options.enable_heterogeneous_equality = true; + options.enable_empty_wrapper_null_unboxing = true; + // Planning warnings are expected in conformance tests, but the test expects + // failure to happen at evaluation time so we ignore them. + options.fail_on_warnings = false; + if (recursive) { + options.max_recursion_depth = 48; + } + + return absl::WrapUnique( + new ModernConformanceServiceImpl(options, optimize, select_optimization, + enable_variadic_logical_operators)); + } + + absl::StatusOr> Setup( + absl::string_view container) { + RuntimeOptions options(options_); + options.container = std::string(container); + CEL_ASSIGN_OR_RETURN( + auto builder, CreateStandardRuntimeBuilder( + google::protobuf::DescriptorPool::generated_pool(), options)); + + if (enable_optimizations_) { + CEL_RETURN_IF_ERROR(cel::extensions::EnableConstantFolding( + builder, google::protobuf::MessageFactory::generated_factory())); + } + CEL_RETURN_IF_ERROR(cel::EnableReferenceResolver( + builder, cel::ReferenceResolverEnabled::kAlways)); + if (enable_select_optimization_) { + CEL_RETURN_IF_ERROR(cel::extensions::EnableSelectOptimization(builder)); + } + + auto& type_registry = builder.type_registry(); + // Use linked pbs in the generated descriptor pool. + CEL_RETURN_IF_ERROR(RegisterProtobufEnum( + type_registry, + cel::expr::conformance::proto2::GlobalEnum_descriptor())); + CEL_RETURN_IF_ERROR(RegisterProtobufEnum( + type_registry, + cel::expr::conformance::proto3::GlobalEnum_descriptor())); + CEL_RETURN_IF_ERROR(RegisterProtobufEnum( + type_registry, + cel::expr::conformance::proto2::TestAllTypes::NestedEnum_descriptor())); + CEL_RETURN_IF_ERROR(RegisterProtobufEnum( + type_registry, + cel::expr::conformance::proto3::TestAllTypes::NestedEnum_descriptor())); + + CEL_RETURN_IF_ERROR(cel::extensions::RegisterComprehensionsV2Functions( + builder.function_registry(), options)); + CEL_RETURN_IF_ERROR(cel::extensions::EnableOptionalTypes(builder)); + CEL_RETURN_IF_ERROR(cel::extensions::RegisterEncodersFunctions( + builder.function_registry(), options)); + CEL_RETURN_IF_ERROR(cel::extensions::RegisterStringsFunctions( + builder.function_registry(), options)); + CEL_RETURN_IF_ERROR(cel::extensions::RegisterMathExtensionFunctions( + builder.function_registry(), options)); + + return std::move(builder).Build(); + } + + void Parse(const conformance::v1alpha1::ParseRequest& request, + conformance::v1alpha1::ParseResponse& response) override { + auto status = + LegacyParse(request, response, /*enable_optional_syntax=*/true, + enable_variadic_logical_operators_); + if (!status.ok()) { + auto* issue = response.add_issues(); + issue->set_code(ToGrpcCode(status.code())); + issue->set_message(status.message()); + } + } + + void Check(const conformance::v1alpha1::CheckRequest& request, + conformance::v1alpha1::CheckResponse& response) override { + google::protobuf::Arena arena; + auto status = CheckImpl(&arena, request, response); + if (!status.ok()) { + auto* issue = response.add_issues(); + issue->set_code(ToGrpcCode(status.code())); + issue->set_message(status.message()); + } + } + + absl::Status Eval(const conformance::v1alpha1::EvalRequest& request, + conformance::v1alpha1::EvalResponse& response) override { + google::protobuf::Arena arena; + + auto runtime_status = Setup(request.container()); + if (!runtime_status.ok()) { + return absl::InternalError(runtime_status.status().ToString( + absl::StatusToStringMode::kWithEverything)); + } + std::unique_ptr runtime = + std::move(runtime_status).value(); + + auto program_status = Plan(*runtime, request); + if (!program_status.ok()) { + return absl::InternalError(program_status.status().ToString( + absl::StatusToStringMode::kWithEverything)); + } + std::unique_ptr program = + std::move(program_status).value(); + cel::Activation activation; + + for (const auto& pair : request.bindings()) { + cel::expr::Value import_value; + ABSL_CHECK(ConvertWireCompatProto(pair.second.value(), // Crash OK + &import_value)); + auto import_status = + FromExprValue(import_value, runtime->GetDescriptorPool(), + runtime->GetMessageFactory(), &arena); + if (!import_status.ok()) { + return absl::InternalError(import_status.status().ToString( + absl::StatusToStringMode::kWithEverything)); + } + + activation.InsertOrAssignValue(pair.first, + std::move(import_status).value()); + } + + auto eval_status = program->Evaluate(&arena, activation); + if (!eval_status.ok()) { + *response.mutable_result() + ->mutable_error() + ->add_errors() + ->mutable_message() = eval_status.status().ToString( + absl::StatusToStringMode::kWithEverything); + return absl::OkStatus(); + } + + cel::Value result = eval_status.value(); + if (result->Is()) { + const absl::Status& error = result.GetError().NativeValue(); + *response.mutable_result() + ->mutable_error() + ->add_errors() + ->mutable_message() = std::string( + error.ToString(absl::StatusToStringMode::kWithEverything)); + } else { + auto export_status = ToExprValue(result, runtime->GetDescriptorPool(), + runtime->GetMessageFactory(), &arena); + if (!export_status.ok()) { + return absl::InternalError(export_status.status().ToString( + absl::StatusToStringMode::kWithEverything)); + } + auto* result_value = response.mutable_result()->mutable_value(); + ABSL_CHECK( // Crash OK + ConvertWireCompatProto(*export_status, result_value)); + } + return absl::OkStatus(); + } + + private: + ModernConformanceServiceImpl(const RuntimeOptions& options, + bool enable_optimizations, + bool enable_select_optimization, + bool enable_variadic_logical_operators) + : options_(options), + enable_optimizations_(enable_optimizations), + enable_select_optimization_(enable_select_optimization), + enable_variadic_logical_operators_(enable_variadic_logical_operators) {} + + static absl::StatusOr> Plan( + const cel::Runtime& runtime, + const conformance::v1alpha1::EvalRequest& request) { + std::unique_ptr ast; + if (request.has_parsed_expr()) { + cel::expr::ParsedExpr unversioned; + ABSL_CHECK(ConvertWireCompatProto(request.parsed_expr(), // Crash OK + &unversioned)); + + CEL_ASSIGN_OR_RETURN( + ast, cel::CreateAstFromParsedExpr(std::move(unversioned))); + + } else if (request.has_checked_expr()) { + cel::expr::CheckedExpr unversioned; + ABSL_CHECK(ConvertWireCompatProto(request.checked_expr(), // Crash OK + &unversioned)); + CEL_ASSIGN_OR_RETURN( + ast, cel::CreateAstFromCheckedExpr(std::move(unversioned))); + } + if (ast == nullptr) { + return absl::InternalError("no expression provided"); + } + + return runtime.CreateTraceableProgram(std::move(ast)); + } + + RuntimeOptions options_; + bool enable_optimizations_; + bool enable_select_optimization_; + bool enable_variadic_logical_operators_; +}; + +} // namespace + +} // namespace google::api::expr::runtime + +namespace cel_conformance { + +absl::StatusOr> +NewConformanceService(const ConformanceServiceOptions& options) { + if (options.modern) { + return google::api::expr::runtime::ModernConformanceServiceImpl::Create( + options.optimize, options.recursive, options.select_optimization, + options.enable_variadic_logical_operators); + } else { + return google::api::expr::runtime::LegacyConformanceServiceImpl::Create( + options.optimize, options.recursive, options.select_optimization, + options.enable_variadic_logical_operators); + } +} + +} // namespace cel_conformance diff --git a/conformance/service.h b/conformance/service.h new file mode 100644 index 000000000..8eb97296e --- /dev/null +++ b/conformance/service.h @@ -0,0 +1,57 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_CONFORMANCE_SERVICE_H_ +#define THIRD_PARTY_CEL_CPP_CONFORMANCE_SERVICE_H_ + +#include + +#include "google/api/expr/conformance/v1alpha1/conformance_service.pb.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" + +namespace cel_conformance { + +class ConformanceServiceInterface { + public: + virtual ~ConformanceServiceInterface() = default; + + virtual void Parse( + const google::api::expr::conformance::v1alpha1::ParseRequest& request, + google::api::expr::conformance::v1alpha1::ParseResponse& response) = 0; + + virtual void Check( + const google::api::expr::conformance::v1alpha1::CheckRequest& request, + google::api::expr::conformance::v1alpha1::CheckResponse& response) = 0; + + virtual absl::Status Eval( + const google::api::expr::conformance::v1alpha1::EvalRequest& request, + google::api::expr::conformance::v1alpha1::EvalResponse& response) = 0; +}; + +struct ConformanceServiceOptions { + bool optimize; + bool modern; + bool arena; + bool recursive; + bool select_optimization; + bool enable_variadic_logical_operators = false; +}; + +absl::StatusOr> +NewConformanceService(const ConformanceServiceOptions&); + +} // namespace cel_conformance + +#endif // THIRD_PARTY_CEL_CPP_CONFORMANCE_SERVICE_H_ diff --git a/conformance/test.sh b/conformance/test.sh deleted file mode 100755 index 9d8ad4d3c..000000000 --- a/conformance/test.sh +++ /dev/null @@ -1,2 +0,0 @@ -#!/bin/bash -exec "$@" diff --git a/conformance/utils.h b/conformance/utils.h new file mode 100644 index 000000000..e01114125 --- /dev/null +++ b/conformance/utils.h @@ -0,0 +1,118 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_CONFORMANCE_UTILS_H_ +#define THIRD_PARTY_CEL_CPP_CONFORMANCE_UTILS_H_ + +#include +#include + +#include "cel/expr/checked.pb.h" +#include "cel/expr/eval.pb.h" +#include "google/api/expr/v1alpha1/checked.pb.h" +#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/value.pb.h" +#include "absl/log/absl_check.h" +#include "internal/testing.h" +#include "google/protobuf/message.h" +#include "google/protobuf/text_format.h" +#include "google/protobuf/util/field_comparator.h" +#include "google/protobuf/util/message_differencer.h" + +namespace cel_conformance { + +inline std::string DescribeMessage(const google::protobuf::Message& message) { + std::string string; + ABSL_CHECK(google::protobuf::TextFormat::PrintToString(message, &string)); + if (string.empty()) { + string = "\"\"\n"; + } + return string; +} + +MATCHER_P(MatchesConformanceValue, expected, "") { + static auto* kFieldComparator = []() { + auto* field_comparator = new google::protobuf::util::DefaultFieldComparator(); + field_comparator->set_treat_nan_as_equal(true); + return field_comparator; + }(); + static auto* kDifferencer = []() { + auto* differencer = new google::protobuf::util::MessageDifferencer(); + differencer->set_message_field_comparison( + google::protobuf::util::MessageDifferencer::EQUIVALENT); + differencer->set_field_comparator(kFieldComparator); + const auto* descriptor = cel::expr::MapValue::descriptor(); + const auto* entries_field = descriptor->FindFieldByName("entries"); + const auto* key_field = + entries_field->message_type()->FindFieldByName("key"); + differencer->TreatAsMap(entries_field, key_field); + return differencer; + }(); + + const cel::expr::ExprValue& got = arg; + const cel::expr::Value& want = expected; + + cel::expr::ExprValue test_value; + (*test_value.mutable_value()) = want; + + if (kDifferencer->Compare(got, test_value)) { + return true; + } + (*result_listener) << "got: " << DescribeMessage(got); + (*result_listener) << "\n"; + (*result_listener) << "wanted: " << DescribeMessage(test_value); + return false; +} + +MATCHER_P(ResultTypeMatches, expected, "") { + static auto* kDifferencer = []() { + auto* differencer = new google::protobuf::util::MessageDifferencer(); + differencer->set_message_field_comparison( + google::protobuf::util::MessageDifferencer::EQUIVALENT); + return differencer; + }(); + + const cel::expr::Type& want = expected; + const google::api::expr::v1alpha1::CheckedExpr& checked_expr = arg; + + int64_t root_id = checked_expr.expr().id(); + auto it = checked_expr.type_map().find(root_id); + + if (it == checked_expr.type_map().end()) { + (*result_listener) << "type map does not contain root id: " << root_id; + return false; + } + + auto got_versioned = it->second; + std::string serialized; + cel::expr::Type got; + if (!got_versioned.SerializeToString(&serialized) || + !got.ParseFromString(serialized)) { + (*result_listener) << "type cannot be converted from versioned type: " + << DescribeMessage(got_versioned); + return false; + } + + if (kDifferencer->Compare(got, want)) { + return true; + } + (*result_listener) << "got: " << DescribeMessage(got); + (*result_listener) << "\n"; + (*result_listener) << "wanted: " << DescribeMessage(want); + return false; +} + +} // namespace cel_conformance + +#endif // THIRD_PARTY_CEL_CPP_CONFORMANCE_UTILS_H_ diff --git a/env/BUILD b/env/BUILD new file mode 100644 index 000000000..0c17d6305 --- /dev/null +++ b/env/BUILD @@ -0,0 +1,320 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "config", + srcs = [ + "config.cc", + "type_info.cc", + ], + hdrs = [ + "config.h", + "type_info.h", + ], + deps = [ + "//common:ast", + "//common:constant", + "//common:type", + "//common:type_kind", + "//internal:status_macros", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/functional:overload", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "env", + srcs = ["env.cc"], + hdrs = ["env.h"], + deps = [ + ":config", + "//checker:type_checker_builder", + "//common:constant", + "//common:container", + "//common:decl", + "//common:signature", + "//common:type", + "//compiler", + "//compiler:compiler_factory", + "//compiler:standard_library", + "//env/internal:ext_registry", + "//internal:status_macros", + "//parser:macro", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "env_runtime", + srcs = ["env_runtime.cc"], + hdrs = ["env_runtime.h"], + deps = [ + ":config", + "//env/internal:runtime_ext_registry", + "//internal:status_macros", + "//runtime", + "//runtime:runtime_builder", + "//runtime:runtime_builder_factory", + "//runtime:runtime_options", + "//runtime:standard_functions", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "env_std_extensions", + srcs = ["env_std_extensions.cc"], + hdrs = ["env_std_extensions.h"], + deps = [ + ":env", + "//checker:optional", + "//compiler:optional", + "//extensions:bindings_ext", + "//extensions:comprehensions_v2", + "//extensions:encoders", + "//extensions:lists_functions", + "//extensions:math_ext_decls", + "//extensions:proto_ext", + "//extensions:regex_ext", + "//extensions:sets_functions", + "//extensions:strings", + ], +) + +cc_library( + name = "env_yaml", + srcs = ["env_yaml.cc"], + hdrs = ["env_yaml.h"], + copts = [ + "-fexceptions", + ], + features = ["-use_header_modules"], + deps = [ + ":config", + "//common:ast", + "//common:constant", + "//common:signature", + "//internal:status_macros", + "//internal:strings", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/time", + "@yaml-cpp", + ], +) + +cc_library( + name = "runtime_std_extensions", + srcs = ["runtime_std_extensions.cc"], + hdrs = ["runtime_std_extensions.h"], + deps = [ + ":env_runtime", + "//checker:optional", + "//env/internal:runtime_ext_registry", + "//extensions:encoders", + "//extensions:lists_functions", + "//extensions:math_ext", + "//extensions:math_ext_decls", + "//extensions:regex_ext", + "//extensions:sets_functions", + "//extensions:strings", + "//runtime:optional_types", + "//runtime:runtime_builder", + "//runtime:runtime_options", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "config_test", + srcs = ["config_test.cc"], + deps = [ + ":config", + "//common:constant", + "//internal:testing", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + ], +) + +cc_test( + name = "type_info_test", + srcs = ["type_info_test.cc"], + deps = [ + ":config", + "//common:type", + "//common:type_proto", + "//common/ast:metadata", + "//internal:proto_matchers", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/status", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "env_test", + srcs = ["env_test.cc"], + deps = [ + ":config", + ":env", + "//checker:type_check_issue", + "//checker:type_checker_builder", + "//checker:validation_result", + "//common:ast", + "//common:constant", + "//common:decl", + "//common:expr", + "//common:type", + "//common:value", + "//compiler", + "//internal:status_macros", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//parser:macro", + "//parser:macro_expr_factory", + "//parser:parser_interface", + "//runtime", + "//runtime:activation", + "//runtime:reference_resolver", + "//runtime:runtime_builder", + "//runtime:runtime_options", + "//runtime:standard_runtime_builder_factory", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "env_runtime_test", + srcs = ["env_runtime_test.cc"], + deps = [ + ":config", + ":env", + ":env_runtime", + ":env_std_extensions", + ":env_yaml", + ":runtime_std_extensions", + "//checker:validation_result", + "//common:ast", + "//common:source", + "//common:value", + "//compiler", + "//extensions:math_ext", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//runtime", + "//runtime:activation", + "//runtime:runtime_builder", + "//runtime:runtime_options", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "env_std_extensions_test", + srcs = ["env_std_extensions_test.cc"], + deps = [ + ":config", + ":env", + ":env_std_extensions", + "//compiler", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/strings:string_view", + ], +) + +cc_test( + name = "env_yaml_test", + srcs = ["env_yaml_test.cc"], + deps = [ + ":config", + ":env_yaml", + "//common:constant", + "//internal:status_macros", + "//internal:testing", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/time", + ], +) + +cc_test( + name = "runtime_std_extensions_test", + srcs = ["runtime_std_extensions_test.cc"], + deps = [ + ":config", + ":env", + ":env_runtime", + ":env_std_extensions", + ":runtime_std_extensions", + "//checker:optional", + "//checker:validation_result", + "//common:ast", + "//common:value", + "//compiler", + "//extensions:lists_functions", + "//extensions:math_ext_decls", + "//extensions:strings", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//runtime", + "//runtime:activation", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_protobuf//:protobuf", + ], +) diff --git a/env/config.cc b/env/config.cc new file mode 100644 index 000000000..202a607bf --- /dev/null +++ b/env/config.cc @@ -0,0 +1,196 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/config.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/functional/overload.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "common/constant.h" +#include "internal/status_macros.h" + +namespace cel { + +namespace { + +const char* ConstantKindToTypeName(const ConstantKind& kind) { + return std::visit(absl::Overload{ + [](const std::monostate& arg) { return "dyn"; }, + [](const std::nullptr_t& arg) { return "null"; }, + [](bool arg) { return "bool"; }, + [](int64_t arg) { return "int"; }, + [](uint64_t arg) { return "uint"; }, + [](double arg) { return "double"; }, + [](const BytesConstant& arg) { return "bytes"; }, + [](const StringConstant& arg) { return "string"; }, + [](absl::Duration arg) { return "duration"; }, + [](absl::Time arg) { return "timestamp"; }, + }, + kind); +} +} // namespace + +absl::Status Config::AddExtensionConfig(std::string name, int version) { + for (const ExtensionConfig& extension_config : extension_configs_) { + if (extension_config.name == name) { + if (extension_config.version == version) { + return absl::OkStatus(); + } + std::string version_str; + if (version == ExtensionConfig::kLatest) { + version_str = "'latest'"; + } else { + version_str = absl::StrCat(version); + } + return absl::AlreadyExistsError(absl::StrCat( + "Extension '", name, "' version ", extension_config.version, + " is already included. Cannot also include version ", version_str)); + } + } + extension_configs_.push_back( + ExtensionConfig{.name = std::move(name), .version = version}); + return absl::OkStatus(); +} + +absl::Status Config::SetStandardLibraryConfig( + const Config::StandardLibraryConfig& standard_library_config) { + if (!standard_library_config.included_macros.empty() && + !standard_library_config.excluded_macros.empty()) { + return absl::InvalidArgumentError( + "Cannot set both included and excluded macros."); + } + + if (!standard_library_config.included_functions.empty() && + !standard_library_config.excluded_functions.empty()) { + return absl::InvalidArgumentError( + "Cannot set both included and excluded functions."); + } + + absl::flat_hash_set included_function_names; + for (const auto& function : standard_library_config.included_functions) { + if (function.second.empty()) { + included_function_names.insert(function.first); + } + } + for (const auto& function : standard_library_config.included_functions) { + if (included_function_names.contains(function.first) && + !function.second.empty()) { + return absl::InvalidArgumentError(absl::StrCat( + "Cannot include function '", function.first, + "' and also its specific overload '", function.second, "'")); + } + } + + absl::flat_hash_set excluded_function_names; + for (const auto& function : standard_library_config.excluded_functions) { + if (function.second.empty()) { + excluded_function_names.insert(function.first); + } + } + for (const auto& function : standard_library_config.excluded_functions) { + if (excluded_function_names.contains(function.first) && + !function.second.empty()) { + return absl::InvalidArgumentError(absl::StrCat( + "Cannot exclude function '", function.first, + "' and also its specific overload '", function.second, "'")); + } + } + + standard_library_config_ = standard_library_config; + return absl::OkStatus(); +} + +absl::Status Config::AddVariableConfig(const VariableConfig& variable_config) { + for (const VariableConfig& existing_variable_config : variable_configs_) { + if (existing_variable_config.name == variable_config.name) { + return absl::AlreadyExistsError(absl::StrCat( + "Variable '", variable_config.name, "' is already included.")); + } + } + if (variable_config.value.has_value()) { + absl::string_view constant_type_name = + ConstantKindToTypeName(variable_config.value.kind()); + if (constant_type_name != variable_config.type_info.name) { + return absl::InvalidArgumentError( + absl::StrCat("Variable '", variable_config.name, "' has type ", + variable_config.type_info.name, + " but is assigned a constant value of type ", + constant_type_name, ".")); + } + } + variable_configs_.push_back(variable_config); + return absl::OkStatus(); +} + +absl::Status Config::ValidateFunctionConfig( + const FunctionConfig& function_config) { + for (const auto& overload : function_config.overload_configs) { + if (overload.is_member_function && overload.parameters.empty()) { + return absl::InvalidArgumentError(absl::StrCat( + "Function '", function_config.name, "' overload '", + overload.overload_id, + "' is marked as a member function but has no parameters. Member " + "functions must have at least one parameter (target).")); + } + } + return absl::OkStatus(); +} + +absl::Status Config::AddFunctionConfig(const FunctionConfig& function_config) { + CEL_RETURN_IF_ERROR(ValidateFunctionConfig(function_config)); + function_configs_.push_back(function_config); + return absl::OkStatus(); +} + +std::ostream& operator<<(std::ostream& os, + const Config::StandardLibraryConfig& config) { + os << "StandardLibraryConfig("; + if (!config.included_macros.empty()) { + os << "\n included_macros=" << absl::StrJoin(config.included_macros, ", "); + } + if (!config.excluded_macros.empty()) { + os << "\n excluded_macros=" << absl::StrJoin(config.excluded_macros, ", "); + } + if (!config.included_functions.empty()) { + os << "\n included_functions=" + << absl::StrJoin(config.included_functions, ", ", + [](std::string* out, + const std::pair& p) { + absl::StrAppend(out, p.first, ":", p.second); + }); + } + if (!config.excluded_functions.empty()) { + os << "\n excluded_functions=" + << absl::StrJoin(config.excluded_functions, ", ", + [](std::string* out, + const std::pair& p) { + absl::StrAppend(out, p.first, ":", p.second); + }); + } + os << "\n)"; + return os; +} + +} // namespace cel diff --git a/env/config.h b/env/config.h new file mode 100644 index 000000000..68e4a1dd9 --- /dev/null +++ b/env/config.h @@ -0,0 +1,173 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_ENV_CONFIG_H_ +#define THIRD_PARTY_CEL_CPP_ENV_CONFIG_H_ + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "common/constant.h" + +namespace cel { + +class Config { + public: + void SetName(std::string name) { name_ = std::move(name); } + std::string GetName() const { return name_; } + + void SetContextType(std::string context_type) { + context_type_ = std::move(context_type); + } + std::string GetContextType() const { return context_type_; } + + struct ContainerConfig { + std::string name; + std::vector abbreviations; + struct Alias { + std::string alias; + std::string qualified_name; + }; + std::vector aliases; + + bool IsEmpty() const { + return name.empty() && abbreviations.empty() && aliases.empty(); + } + }; + + void SetContainerConfig(ContainerConfig container_config) { + container_config_ = std::move(container_config); + } + + const ContainerConfig& GetContainerConfig() const { + return container_config_; + } + + struct ExtensionConfig { + static constexpr int kLatest = std::numeric_limits::max(); + + std::string name; + int version = kLatest; + }; + + absl::Status AddExtensionConfig(std::string name, + int version = ExtensionConfig::kLatest); + + const std::vector& GetExtensionConfigs() const { + return extension_configs_; + } + + struct StandardLibraryConfig { + // Exclude the entire standard library. + bool disable = false; + + // Exclude all standard library macros. + bool disable_macros = false; + + // Either included or excluded macros can be set, not both. If neither are + // set, all standard library macros are included. + absl::flat_hash_set included_macros; + absl::flat_hash_set excluded_macros; + + // Sets of pairs of function name and overload id to include or exclude. + // Either included or excluded functions can be set, not both. If neither + // are set, all standard library functions are included. + // If an overload is specified, only that overload is included or excluded. + // If no overload is specified (empty second element of pair), all overloads + // are included or excluded. + absl::flat_hash_set> included_functions; + absl::flat_hash_set> excluded_functions; + + bool IsEmpty() const { + return !disable && !disable_macros && included_macros.empty() && + excluded_macros.empty() && included_functions.empty() && + excluded_functions.empty(); + } + }; + + absl::Status SetStandardLibraryConfig( + const StandardLibraryConfig& standard_library_config); + + const StandardLibraryConfig& GetStandardLibraryConfig() const { + return standard_library_config_; + } + + struct TypeInfo { + std::string name; + std::vector params; + bool is_type_param = false; + }; + + struct VariableConfig { + std::string name; + std::string description; + TypeInfo type_info; + Constant value; + }; + + // Adds a variable config to the environment. The variable name and type + // are used by the CEL type checker to validate expressions. The variable + // value is used as an input value at runtime. + // + // Returns an error if a variable with the same name already exists, or if the + // type of the constant value does not match the specified type. + absl::Status AddVariableConfig(const VariableConfig& variable_config); + + const std::vector& GetVariableConfigs() const { + return variable_configs_; + } + + struct FunctionOverloadConfig { + std::string overload_id; + std::vector examples; + bool is_member_function = false; + std::vector parameters; + TypeInfo return_type; + }; + + struct FunctionConfig { + std::string name; + std::string description; + std::vector overload_configs; + }; + + absl::Status AddFunctionConfig(const FunctionConfig& function_config); + + const std::vector& GetFunctionConfigs() const { + return function_configs_; + } + + private: + std::string name_; + std::string context_type_; + ContainerConfig container_config_; + std::vector extension_configs_; + StandardLibraryConfig standard_library_config_; + std::vector variable_configs_; + std::vector function_configs_; + + absl::Status ValidateFunctionConfig(const FunctionConfig& function_config); +}; + +std::ostream& operator<<(std::ostream& os, + const Config::StandardLibraryConfig& config); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_ENV_CONFIG_H_ diff --git a/env/config_test.cc b/env/config_test.cc new file mode 100644 index 000000000..8cfc3cf7f --- /dev/null +++ b/env/config_test.cc @@ -0,0 +1,277 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/config.h" + +#include + +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "common/constant.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::testing::AllOf; +using ::testing::ElementsAre; +using ::testing::Field; +using ::testing::HasSubstr; +using ::testing::UnorderedElementsAre; + +TEST(EnvConfigTest, ExtensionConfigs) { + Config config; + ASSERT_THAT( + config.AddExtensionConfig("math", Config::ExtensionConfig::kLatest), + IsOk()); + ASSERT_THAT(config.AddExtensionConfig("optional", 2), IsOk()); + ASSERT_THAT(config.AddExtensionConfig("strings"), IsOk()); + + EXPECT_THAT(config.GetExtensionConfigs(), + UnorderedElementsAre( + AllOf(Field(&Config::ExtensionConfig::name, "math"), + Field(&Config::ExtensionConfig::version, + Config::ExtensionConfig::kLatest)), + AllOf(Field(&Config::ExtensionConfig::name, "optional"), + Field(&Config::ExtensionConfig::version, 2)), + AllOf(Field(&Config::ExtensionConfig::name, "strings"), + Field(&Config::ExtensionConfig::version, + Config::ExtensionConfig::kLatest)))); +} + +TEST(EnvConfigTest, ExtensionConfigConflict) { + Config config; + ASSERT_THAT(config.AddExtensionConfig("math", 2), IsOk()); + ASSERT_THAT(config.AddExtensionConfig("math", 2), IsOk()); + ASSERT_THAT(config.AddExtensionConfig("math", 3), + StatusIs(absl::StatusCode::kAlreadyExists)); +} + +struct StandardLibraryConfigTestCase { + Config::StandardLibraryConfig standard_library_config; + std::string expected_error; // Empty if no error is expected. +}; + +class StandardLibraryConfigTest + : public testing::TestWithParam {}; + +TEST_P(StandardLibraryConfigTest, StandardLibraryConfig) { + const StandardLibraryConfigTestCase& param = GetParam(); + + Config config; + absl::Status status = + config.SetStandardLibraryConfig(param.standard_library_config); + if (param.expected_error.empty()) { + EXPECT_THAT(status, IsOk()); + } else { + EXPECT_THAT(status, StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr(param.expected_error))); + } +} + +INSTANTIATE_TEST_SUITE_P( + StandardLibraryConfigTest, StandardLibraryConfigTest, + ::testing::Values( + StandardLibraryConfigTestCase{ + .standard_library_config = {}, + }, + StandardLibraryConfigTestCase{ + .standard_library_config = + { + .included_functions = {{"_+_", "add_int64"}, + {"_+_", "add_list"}}, + }, + }, + StandardLibraryConfigTestCase{ + .standard_library_config = + { + .included_functions = {{"_+_", "add(int,int)"}, + {"_+_", "add(list,list)"}}, + }, + }, + StandardLibraryConfigTestCase{ + .standard_library_config = + { + .excluded_functions = {{"_+_", "add_int64"}, + {"_+_", "add_list"}}, + }, + }, + StandardLibraryConfigTestCase{ + .standard_library_config = + { + .excluded_functions = {{"_+_", "add(int,int)"}, + {"_+_", "add(list,list)"}}, + }, + }, + StandardLibraryConfigTestCase{ + .standard_library_config = + { + .included_macros = {"all", "exists"}, + .excluded_macros = {"map", "filter"}, + }, + .expected_error = "Cannot set both included and excluded macros.", + }, + StandardLibraryConfigTestCase{ + .standard_library_config = + { + .included_functions = {{"_+_", "add_int64"}, + {"_+_", "add_list"}}, + .excluded_functions = {{"_-_", ""}}, + }, + .expected_error = + "Cannot set both included and excluded functions.", + }, + StandardLibraryConfigTestCase{ + .standard_library_config = + { + .included_functions = {{"_+_", "add(int,int)"}}, + .excluded_functions = {{"_-_", ""}}, + }, + .expected_error = + "Cannot set both included and excluded functions.", + }, + StandardLibraryConfigTestCase{ + .standard_library_config = + { + .included_functions = {{"_+_", ""}, {"_+_", "add_list"}}, + }, + .expected_error = "Cannot include function '_+_' and also its " + "specific overload 'add_list'", + }, + StandardLibraryConfigTestCase{ + .standard_library_config = + { + .included_functions = {{"_+_", ""}, + {"_+_", "add(int,int)"}}, + }, + .expected_error = "Cannot include function '_+_' and also its " + "specific overload 'add(int,int)'", + }, + StandardLibraryConfigTestCase{ + .standard_library_config = + { + .excluded_functions = {{"_+_", ""}, {"_+_", "add_list"}}, + }, + .expected_error = "Cannot exclude function '_+_' and also its " + "specific overload 'add_list'", + }, + StandardLibraryConfigTestCase{ + .standard_library_config = + { + .excluded_functions = {{"_+_", ""}, + {"_+_", "add(int,int)"}}, + }, + .expected_error = "Cannot exclude function '_+_' and also its " + "specific overload 'add(int,int)'", + })); + +TEST(VariableConfigTest, VariableConfig) { + Config config; + Config::VariableConfig variable_config{ + .name = "test", + .type_info = + { + .name = "mytype", + .params = {{.name = "int"}, {.name = "A", .is_type_param = true}}, + }, + }; + ASSERT_THAT(config.AddVariableConfig(variable_config), IsOk()); + + ASSERT_EQ(config.GetVariableConfigs().size(), 1); + const auto& added_config = config.GetVariableConfigs()[0]; + EXPECT_EQ(added_config.type_info.name, "mytype"); + ASSERT_THAT(added_config.type_info.params.size(), 2); + EXPECT_EQ(added_config.type_info.params[0].name, "int"); + EXPECT_FALSE(added_config.type_info.params[0].is_type_param); + EXPECT_EQ(added_config.type_info.params[1].name, "A"); + EXPECT_TRUE(added_config.type_info.params[1].is_type_param); +} + +TEST(VariableConfigTest, VariableConfigConflict) { + Config config; + Config::VariableConfig variable_config{ + .name = "test", + .type_info = {.name = "int"}, + }; + EXPECT_THAT(config.AddVariableConfig(variable_config), IsOk()); + EXPECT_THAT(config.AddVariableConfig(variable_config), + StatusIs(absl::StatusCode::kAlreadyExists)); +} + +TEST(VariableConfigTest, VariableConfigValueTypeMismatch) { + Config config; + Config::VariableConfig variable_config{ + .name = "test", + .type_info = {.name = "int"}, + .value = Constant(StringConstant("hello")), + }; + EXPECT_THAT(config.AddVariableConfig(variable_config), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Variable 'test' has type int but is assigned " + "a constant value of type string."))); +} + +TEST(FunctionConfigTest, FunctionConfig) { + Config config; + Config::FunctionConfig function_config; + function_config.name = "test"; + function_config.description = "Ultimate test"; + function_config.overload_configs.push_back(Config::FunctionOverloadConfig{ + .overload_id = "test_with_pill", + .examples = {"oracle.isTheOne('Neo', RED)"}, + .is_member_function = true, + .parameters = {{.name = "string"}, {.name = "Choice"}}, + .return_type = {.name = "bool"}, + }); + ASSERT_THAT(config.AddFunctionConfig(function_config), IsOk()); + ASSERT_EQ(config.GetFunctionConfigs().size(), 1); + const auto& added_config = config.GetFunctionConfigs()[0]; + EXPECT_EQ(added_config.name, "test"); + EXPECT_EQ(added_config.description, "Ultimate test"); + EXPECT_EQ(added_config.overload_configs.size(), 1); + + const auto& overload_config = added_config.overload_configs[0]; + EXPECT_EQ(overload_config.overload_id, "test_with_pill"); + EXPECT_THAT(overload_config.examples, + ElementsAre("oracle.isTheOne('Neo', RED)")); + EXPECT_TRUE(overload_config.is_member_function); + EXPECT_THAT( + overload_config.parameters, + ElementsAre(AllOf(Field(&Config::TypeInfo::name, "string"), + Field(&Config::TypeInfo::is_type_param, false)), + AllOf(Field(&Config::TypeInfo::name, "Choice"), + Field(&Config::TypeInfo::is_type_param, false)))); + EXPECT_THAT(overload_config.return_type, + Field(&Config::TypeInfo::name, "bool")); +} + +TEST(FunctionConfigTest, FunctionConfigInvalidMember) { + Config config; + Config::FunctionConfig function_config; + function_config.name = "test"; + function_config.overload_configs.push_back(Config::FunctionOverloadConfig{ + .overload_id = "test_member_no_params", + .is_member_function = true, + .parameters = {}, + }); + EXPECT_THAT(config.AddFunctionConfig(function_config), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("is marked as a member function but has no " + "parameters"))); +} + +} // namespace +} // namespace cel diff --git a/env/env.cc b/env/env.cc new file mode 100644 index 000000000..85c5139da --- /dev/null +++ b/env/env.cc @@ -0,0 +1,222 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/env.h" + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "checker/type_checker_builder.h" +#include "common/constant.h" +#include "common/container.h" +#include "common/decl.h" +#include "common/signature.h" +#include "common/type.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "compiler/standard_library.h" +#include "env/config.h" +#include "env/type_info.h" +#include "internal/status_macros.h" +#include "parser/macro.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" + +namespace cel { +namespace { + +bool ShouldIncludeMacro(const Config::StandardLibraryConfig& config, + absl::string_view macro) { + if (config.disable_macros) { + return false; + } + if (config.excluded_macros.contains(macro)) { + return false; + } + if (!config.included_macros.empty() && + !config.included_macros.contains(macro)) { + return false; + } + return true; +} + +bool ShouldIncludeFunction(const Config::StandardLibraryConfig& config, + absl::string_view function, + const OverloadDecl& overload) { + if (config.excluded_functions.empty() && config.included_functions.empty()) { + return true; + } + + if (!config.excluded_functions.empty()) { + if (config.excluded_functions.contains(std::make_pair( + std::string(function), std::string(overload.id()))) || + config.excluded_functions.contains( + std::make_pair(std::string(function), ""))) { + return false; + } + absl::StatusOr signature = + MakeOverloadSignature(function, overload.args(), overload.member()); + if (signature.ok() && config.excluded_functions.contains(std::make_pair( + std::string(function), *std::move(signature)))) { + return false; + } + } + + if (!config.included_functions.empty()) { + if (config.included_functions.contains(std::make_pair( + std::string(function), std::string(overload.id()))) || + config.included_functions.contains( + std::make_pair(std::string(function), ""))) { + return true; + } + // Ok to call MakeOverloadSignature() again, because in practice either + // included or excluded functions may be specified, but not both. + absl::StatusOr signature = + MakeOverloadSignature(function, overload.args(), overload.member()); + if (signature.ok() && config.included_functions.contains(std::make_pair( + std::string(function), *std::move(signature)))) { + return true; + } + return false; + } + + return true; // Never reached +} + +absl::StatusOr MakeStdlibSubset( + const Config::StandardLibraryConfig& standard_library_config) { + CompilerLibrarySubset subset; + subset.library_id = "stdlib"; + // Capturing by reference is safe. The returned CompilerLibrarySubset's + // callbacks are only used during CompilerBuilder::Build() to configure + // contributed functions and macros. They are not retained by the constructed + // Compiler instance. The referenced config outlives the Build() call. + subset.should_include_macro = [&standard_library_config](const Macro& macro) { + return ShouldIncludeMacro(standard_library_config, macro.function()); + }; + subset.should_include_overload = [&standard_library_config]( + absl::string_view function, + const OverloadDecl& overload) { + return ShouldIncludeFunction(standard_library_config, function, overload); + }; + return subset; +} + +absl::StatusOr FunctionConfigToFunctionDecl( + const Config::FunctionConfig& function_config, google::protobuf::Arena* arena, + const google::protobuf::DescriptorPool* descriptor_pool) { + FunctionDecl function_decl; + function_decl.set_name(function_config.name); + for (const Config::FunctionOverloadConfig& overload_config : + function_config.overload_configs) { + OverloadDecl overload_decl; + overload_decl.set_id(overload_config.overload_id); + overload_decl.set_member(overload_config.is_member_function); + for (const Config::TypeInfo& parameter : overload_config.parameters) { + CEL_ASSIGN_OR_RETURN(Type parameter_type, + TypeInfoToType(parameter, descriptor_pool, arena)); + overload_decl.mutable_args().push_back(parameter_type); + } + CEL_ASSIGN_OR_RETURN( + Type return_type, + TypeInfoToType(overload_config.return_type, descriptor_pool, arena)); + overload_decl.set_result(return_type); + CEL_RETURN_IF_ERROR(function_decl.AddOverload(overload_decl)); + } + return function_decl; +} + +} // namespace + +Env::Env() { + compiler_options_.parser_options.enable_quoted_identifiers = true; + compiler_options_.adapt_parser_errors = true; +} + +absl::StatusOr> Env::NewCompilerBuilder() { + CEL_ASSIGN_OR_RETURN( + std::unique_ptr compiler_builder, + cel::NewCompilerBuilder(descriptor_pool_, compiler_options_)); + cel::TypeCheckerBuilder& checker_builder = + compiler_builder->GetCheckerBuilder(); + + ExpressionContainer container; + CEL_RETURN_IF_ERROR( + container.SetContainer(config_.GetContainerConfig().name)); + for (const auto& abbr : config_.GetContainerConfig().abbreviations) { + CEL_RETURN_IF_ERROR(container.AddAbbreviation(abbr)); + } + + if (!config_.GetContextType().empty()) { + CEL_RETURN_IF_ERROR( + checker_builder.AddContextDeclaration(config_.GetContextType())); + } + for (const auto& alias : config_.GetContainerConfig().aliases) { + CEL_RETURN_IF_ERROR(container.AddAlias(alias.alias, alias.qualified_name)); + } + checker_builder.SetExpressionContainer(std::move(container)); + + if (!config_.GetStandardLibraryConfig().disable) { + CEL_RETURN_IF_ERROR( + compiler_builder->AddLibrary(StandardCompilerLibrary())); + CEL_ASSIGN_OR_RETURN(CompilerLibrarySubset standard_library_subset, + MakeStdlibSubset(config_.GetStandardLibraryConfig())); + CEL_RETURN_IF_ERROR( + compiler_builder->AddLibrarySubset(std::move(standard_library_subset))); + } + for (const Config::ExtensionConfig& extension_config : + config_.GetExtensionConfigs()) { + CEL_ASSIGN_OR_RETURN(CompilerLibrary library, + extension_registry_.GetCompilerLibrary( + extension_config.name, extension_config.version)); + CEL_RETURN_IF_ERROR(compiler_builder->AddLibrary(std::move(library))); + } + + google::protobuf::Arena* arena = checker_builder.arena(); + for (const Config::VariableConfig& variable_config : + config_.GetVariableConfigs()) { + VariableDecl variable_decl; + variable_decl.set_name(variable_config.name); + CEL_ASSIGN_OR_RETURN(Type type, + TypeInfoToType(variable_config.type_info, + descriptor_pool_.get(), arena)); + variable_decl.set_type(type); + if (variable_config.value.has_value()) { + variable_decl.set_value(variable_config.value); + } + CEL_RETURN_IF_ERROR(checker_builder.AddVariable(variable_decl)); + } + + for (const Config::FunctionConfig& function_config : + config_.GetFunctionConfigs()) { + CEL_ASSIGN_OR_RETURN(FunctionDecl function_decl, + FunctionConfigToFunctionDecl(function_config, arena, + descriptor_pool_.get())); + CEL_RETURN_IF_ERROR(checker_builder.AddFunction(function_decl)); + } + + return compiler_builder; +} + +absl::StatusOr> Env::NewCompiler() { + CEL_ASSIGN_OR_RETURN(std::unique_ptr compiler_builder, + NewCompilerBuilder()); + return compiler_builder->Build(); +} +} // namespace cel diff --git a/env/env.h b/env/env.h new file mode 100644 index 000000000..9830b67d7 --- /dev/null +++ b/env/env.h @@ -0,0 +1,76 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_ENV_ENV_H_ +#define THIRD_PARTY_CEL_CPP_ENV_ENV_H_ + +#include + +#include +#include + +#include "absl/functional/any_invocable.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "compiler/compiler.h" +#include "env/config.h" +#include "env/internal/ext_registry.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +// Env class establishes the environment for compiling CEL expressions. +// +// It is used to configure compiler options, extension functions, and other +// customizable CEL features. +class Env { + public: + Env(); + + // Registers a `CompilerLibrary` with the environment. Note that the library + // does not automatically get added to a `Compiler`. `NewCompiler` relies + // on `Config` to determine which libraries to load. + void RegisterCompilerLibrary( + absl::string_view name, absl::string_view alias, int version, + absl::AnyInvocable library_factory) { + extension_registry_.RegisterCompilerLibrary(name, alias, version, + std::move(library_factory)); + } + + void SetDescriptorPool( + std::shared_ptr descriptor_pool) { + descriptor_pool_ = std::move(descriptor_pool); + } + + const google::protobuf::DescriptorPool* GetDescriptorPool() const { + return descriptor_pool_.get(); + } + + void SetConfig(const Config& config) { config_ = config; } + + absl::StatusOr> NewCompilerBuilder(); + + // Shortcut for NewCompilerBuilder() followed by Build(). + absl::StatusOr> NewCompiler(); + + private: + cel::env_internal::ExtensionRegistry extension_registry_; + std::shared_ptr descriptor_pool_; + CompilerOptions compiler_options_; + Config config_; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_ENV_ENV_H_ diff --git a/env/env_runtime.cc b/env/env_runtime.cc new file mode 100644 index 000000000..33e0747cc --- /dev/null +++ b/env/env_runtime.cc @@ -0,0 +1,89 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/env_runtime.h" + +#include +#include +#include + +#include "absl/functional/any_invocable.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "env/config.h" +#include "internal/status_macros.h" +#include "runtime/runtime.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_builder_factory.h" +#include "runtime/runtime_options.h" +#include "runtime/standard_functions.h" + +namespace cel { + +void EnvRuntime::RegisterExtensionFunctions( + absl::string_view name, absl::string_view alias, int version, + absl::AnyInvocable + function_registration_callback) { + extension_registry_.AddFunctionRegistration( + name, alias, version, std::move(function_registration_callback)); +} + +absl::StatusOr EnvRuntime::CreateRuntimeBuilder() { + const std::vector& extension_configs = + config_.GetExtensionConfigs(); + const Config::ExtensionConfig* optional_extension_config = nullptr; + for (const Config::ExtensionConfig& extension_config : extension_configs) { + if (extension_config.name == "optional") { + optional_extension_config = &extension_config; + runtime_options_.enable_qualified_type_identifiers = true; + break; + } + } + + CEL_ASSIGN_OR_RETURN( + RuntimeBuilder runtime_builder, + cel::CreateRuntimeBuilder(descriptor_pool_, runtime_options_)); + + if (!config_.GetStandardLibraryConfig().disable) { + CEL_RETURN_IF_ERROR(RegisterStandardFunctions( + runtime_builder.function_registry(), runtime_options_)); + } + + // Register optional extension functions first, because other extensions + // depend on it (e.g. regex). + if (optional_extension_config != nullptr) { + CEL_RETURN_IF_ERROR(extension_registry_.RegisterExtensionFunctions( + runtime_builder, runtime_options_, optional_extension_config->name, + optional_extension_config->version)); + } + + for (const Config::ExtensionConfig& extension_config : extension_configs) { + if (&extension_config == optional_extension_config) { + continue; + } + CEL_RETURN_IF_ERROR(extension_registry_.RegisterExtensionFunctions( + runtime_builder, runtime_options_, extension_config.name, + extension_config.version)); + } + return runtime_builder; +} + +absl::StatusOr> EnvRuntime::NewRuntime() { + CEL_ASSIGN_OR_RETURN(RuntimeBuilder runtime_builder, CreateRuntimeBuilder()); + return std::move(runtime_builder).Build(); +} + +} // namespace cel diff --git a/env/env_runtime.h b/env/env_runtime.h new file mode 100644 index 000000000..63473c295 --- /dev/null +++ b/env/env_runtime.h @@ -0,0 +1,85 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_ENV_ENV_RUNTIME_H_ +#define THIRD_PARTY_CEL_CPP_ENV_ENV_RUNTIME_H_ + +#include +#include + +#include "absl/functional/any_invocable.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "env/config.h" +#include "env/internal/runtime_ext_registry.h" +#include "runtime/runtime.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +// EnvRuntime class establishes the environment for creating CEL runtimes. +// +// It is used to configure runtime options, extension functions, and other +// customizable CEL runtime features. +// +// EnvRuntime is separate from Env to avoid a dependency on the compiler for +// binaries that only use the runtime. +// +// Even though EnvRuntime is separate from Env, the Config and DescriptorPool +// passed to EnvRuntime are expected to be the same as those passed to Env for +// compilation. This ensures consistency between compilation and runtime. +class EnvRuntime { + public: + // Registers a function registration callback for an extension. The callback + // is invoked when a runtime is created, if the corresponding functions are + // enabled in the runtime config. + void RegisterExtensionFunctions( + absl::string_view name, absl::string_view alias, int version, + absl::AnyInvocable + function_registration_callback); + + void SetDescriptorPool( + std::shared_ptr descriptor_pool) { + descriptor_pool_ = std::move(descriptor_pool); + } + + void SetConfig(const Config& config) { config_ = config; } + + RuntimeOptions& mutable_runtime_options() { return runtime_options_; } + + absl::StatusOr CreateRuntimeBuilder(); + + // Shortcut for CreateRuntimeBuilder() followed by Build(). + absl::StatusOr> NewRuntime(); + + private: + cel::env_internal::RuntimeExtensionRegistry& GetRuntimeExtensionRegistry() { + return extension_registry_; + } + + friend void RegisterStandardExtensions(EnvRuntime& env_runtime); + + cel::env_internal::RuntimeExtensionRegistry extension_registry_; + std::shared_ptr descriptor_pool_; + Config config_; + RuntimeOptions runtime_options_; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_ENV_ENV_RUNTIME_H_ diff --git a/env/env_runtime_test.cc b/env/env_runtime_test.cc new file mode 100644 index 000000000..47892772c --- /dev/null +++ b/env/env_runtime_test.cc @@ -0,0 +1,199 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/env_runtime.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "checker/validation_result.h" +#include "common/ast.h" +#include "common/source.h" +#include "common/value.h" +#include "compiler/compiler.h" +#include "env/config.h" +#include "env/env.h" +#include "env/env_std_extensions.h" +#include "env/env_yaml.h" +#include "env/runtime_std_extensions.h" +#include "extensions/math_ext.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "runtime/activation.h" +#include "runtime/runtime.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::testing::IsEmpty; +using ::testing::ValuesIn; + +struct TestCase { + std::string config_yaml; + std::string expr; + bool expected_to_fail = false; +}; + +class EnvRuntimeTest : public testing::TestWithParam {}; + +TEST_P(EnvRuntimeTest, EndToEnd) { + const TestCase& param = GetParam(); + auto descriptor_pool = cel::internal::GetSharedTestingDescriptorPool(); + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(param.config_yaml)); + + Env env; + env.SetDescriptorPool(descriptor_pool); + RegisterStandardExtensions(env); + env.SetConfig(config); + + EnvRuntime env_runtime; + env_runtime.SetDescriptorPool(descriptor_pool); + RegisterStandardExtensions(env_runtime); + env_runtime.SetConfig(config); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler()); + std::unique_ptr ast; + if (!param.expected_to_fail) { + ASSERT_OK_AND_ASSIGN(ValidationResult result, + compiler->Compile(param.expr)); + EXPECT_THAT(result.GetIssues(), IsEmpty()) << result.FormatError(); + ASSERT_OK_AND_ASSIGN(ast, result.ReleaseAst()); + } else { + // Bypass type checking to allow compilation to succeed since we expect the + // runtime to fail. + ASSERT_OK_AND_ASSIGN(std::unique_ptr source, + NewSource(param.expr, "")); + ASSERT_OK_AND_ASSIGN(ast, compiler->GetParser().Parse(*source)); + } + ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, + env_runtime.NewRuntime()); + + absl::StatusOr> program_or = + runtime->CreateProgram(std::move(ast)); + if (param.expected_to_fail) { + EXPECT_THAT(program_or, StatusIs(absl::StatusCode::kInvalidArgument)) + << " expr: " << param.expr; + return; + } + + ASSERT_THAT(program_or, IsOk()) << " expr: " << param.expr; + + std::unique_ptr program = *std::move(program_or); + ASSERT_NE(program, nullptr); + + google::protobuf::Arena arena; + Activation activation; + ASSERT_OK_AND_ASSIGN(Value value, program->Evaluate(&arena, activation)); + EXPECT_TRUE(value.GetBool()) << " expr: " << param.expr; +} + +std::vector GetEnvRuntimeTestCases() { + return { + TestCase{ + .config_yaml = R"yaml( + extensions: + - name: "encoders" + )yaml", + .expr = "base64.encode(b'hello') == 'aGVsbG8='", + }, + TestCase{ + .config_yaml = R"yaml( + extensions: + - name: "encoders" + - name: "optional" + )yaml", + .expr = "base64.encode(b'hello') == 'aGVsbG8=' && " + "optional.of(1).hasValue()", + }, + TestCase{ + .config_yaml = R"yaml( + extensions: + - name: "encoders" + )yaml", + .expr = "base64.encode(b'hello') == 'aGVsbG8=' && " + "optional.of(1).hasValue()", + .expected_to_fail = true, + }, + TestCase{ + .config_yaml = R"yaml( + stdlib: + disable: true + )yaml", + .expr = "1 + 2 == 3", + .expected_to_fail = true, + }, + TestCase{ + .config_yaml = R"yaml( + stdlib: + disable: true + extensions: + - name: "encoders" + )yaml", + .expr = "base64.encode(b'hello') == 'aGVsbG8=' && " + "1 + 2 == 3", + .expected_to_fail = true, + }, + }; +} + +INSTANTIATE_TEST_SUITE_P(EnvRuntimeTest, EnvRuntimeTest, + ValuesIn(GetEnvRuntimeTestCases())); + +TEST(EnvRuntimeTest, RegisterExtensionFunctions) { + auto descriptor_pool = cel::internal::GetSharedTestingDescriptorPool(); + Config config; + ASSERT_THAT(config.AddExtensionConfig("math", 2), IsOk()); + + Env env; + env.SetDescriptorPool(descriptor_pool); + RegisterStandardExtensions(env); + env.SetConfig(config); + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler()); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + compiler->Compile("math.sqrt(4) == 2.0")); + EXPECT_THAT(result.GetIssues(), IsEmpty()) << result.FormatError(); + ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, result.ReleaseAst()); + + EnvRuntime env_runtime; + env_runtime.SetDescriptorPool(descriptor_pool); + env_runtime.RegisterExtensionFunctions( + "cel.lib.math", "math", 2, + [](cel::RuntimeBuilder& runtime_builder, + const cel::RuntimeOptions& opts) -> absl::Status { + return cel::extensions::RegisterMathExtensionFunctions( + runtime_builder.function_registry(), opts, 2); + }); + env_runtime.SetConfig(config); + ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, + env_runtime.NewRuntime()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr program, + runtime->CreateProgram(std::move(ast))); + ASSERT_NE(program, nullptr); + + google::protobuf::Arena arena; + Activation activation; + ASSERT_OK_AND_ASSIGN(Value value, program->Evaluate(&arena, activation)); + EXPECT_TRUE(value.GetBool()); +} +} // namespace +} // namespace cel diff --git a/env/env_std_extensions.cc b/env/env_std_extensions.cc new file mode 100644 index 000000000..f2041b979 --- /dev/null +++ b/env/env_std_extensions.cc @@ -0,0 +1,76 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/env_std_extensions.h" + +#include "checker/optional.h" +#include "compiler/optional.h" +#include "env/env.h" +#include "extensions/bindings_ext.h" +#include "extensions/comprehensions_v2.h" +#include "extensions/encoders.h" +#include "extensions/lists_functions.h" +#include "extensions/math_ext_decls.h" +#include "extensions/proto_ext.h" +#include "extensions/regex_ext.h" +#include "extensions/sets_functions.h" +#include "extensions/strings.h" + +namespace cel { + +void RegisterStandardExtensions(Env& env) { + env.RegisterCompilerLibrary("cel.lib.ext.bindings", "bindings", 0, []() { + return extensions::BindingsCompilerLibrary(); + }); + env.RegisterCompilerLibrary("cel.lib.ext.encoders", "encoders", 0, []() { + return extensions::EncodersCompilerLibrary(); + }); + for (int version = 0; version <= extensions::kListsExtensionLatestVersion; + ++version) { + env.RegisterCompilerLibrary( + "cel.lib.ext.lists", "lists", version, + [version]() { return extensions::ListsCompilerLibrary(version); }); + } + for (int version = 0; version <= extensions::kMathExtensionLatestVersion; + ++version) { + env.RegisterCompilerLibrary( + "cel.lib.ext.math", "math", version, + [version]() { return extensions::MathCompilerLibrary(version); }); + } + for (int version = 0; version <= kOptionalExtensionLatestVersion; ++version) { + env.RegisterCompilerLibrary("optional", "", version, [version]() { + return OptionalCompilerLibrary(version); + }); + } + env.RegisterCompilerLibrary("cel.lib.ext.protos", "protos", 0, []() { + return extensions::ProtoExtCompilerLibrary(); + }); + env.RegisterCompilerLibrary("cel.lib.ext.sets", "sets", 0, []() { + return extensions::SetsCompilerLibrary(); + }); + for (int version = 0; version <= extensions::kStringsExtensionLatestVersion; + ++version) { + env.RegisterCompilerLibrary( + "cel.lib.ext.strings", "strings", version, + [version]() { return extensions::StringsCompilerLibrary(version); }); + } + env.RegisterCompilerLibrary( + "cel.lib.ext.comprev2", "two-var-comprehensions", 0, + []() { return extensions::ComprehensionsV2CompilerLibrary(); }); + env.RegisterCompilerLibrary("cel.lib.ext.regex", "regex", 0, []() { + return extensions::RegexExtCompilerLibrary(); + }); +} + +} // namespace cel diff --git a/env/env_std_extensions.h b/env/env_std_extensions.h new file mode 100644 index 000000000..79cf37dbf --- /dev/null +++ b/env/env_std_extensions.h @@ -0,0 +1,42 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_ENV_STD_EXTENSIONS_H_ +#define THIRD_PARTY_CEL_CPP_ENV_STD_EXTENSIONS_H_ + +#include "env/env.h" + +namespace cel { + +// Registers the standard CEL extensions with the given environment. This makes +// them available, but does not enable them. See Env::Config for how to enable +// extensions. +// +// Extensions are registered under the following names: +// +// - cel.lib.ext.bindings (alias: "bindings") +// - cel.lib.ext.encoders (alias: "encoders") +// - cel.lib.ext.lists (alias: "lists") +// - cel.lib.ext.math (alias: "math") +// - optional +// - cel.lib.ext.protos (alias: "protos") +// - cel.lib.ext.sets (alias: "sets") +// - cel.lib.ext.strings (alias: "strings") +// - cel.lib.ext.comprev2 (alias: "two-var-comprehensions") +// - cel.lib.ext.regex (alias: "regex") +void RegisterStandardExtensions(Env& env); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_ENV_STD_EXTENSIONS_H_ diff --git a/env/env_std_extensions_test.cc b/env/env_std_extensions_test.cc new file mode 100644 index 000000000..7d9572cc0 --- /dev/null +++ b/env/env_std_extensions_test.cc @@ -0,0 +1,116 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/env_std_extensions.h" + +#include +#include + +#include "absl/strings/string_view.h" +#include "compiler/compiler.h" +#include "env/config.h" +#include "env/env.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::testing::TestWithParam; + +struct TestCase { + std::string extension; + std::string expr; +}; + +class EnvStdExtensions : public testing::TestWithParam {}; + +TEST_P(EnvStdExtensions, RegistrationTest) { + const TestCase& param = GetParam(); + + Env env; + RegisterStandardExtensions(env); + env.SetDescriptorPool(internal::GetSharedTestingDescriptorPool()); + + Config config; + ASSERT_THAT(config.AddExtensionConfig(param.extension), IsOk()); + env.SetConfig(config); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler()); + + ASSERT_OK_AND_ASSIGN(auto result, compiler->Compile(param.expr)); + ASSERT_TRUE(result.IsValid()) << "Expected no issues for expr: " << param.expr + << " but got: " << result.FormatError(); +} + +INSTANTIATE_TEST_SUITE_P( + RegistrationTest, EnvStdExtensions, + ::testing::Values( + TestCase{ + .extension = "cel.lib.ext.bindings", // official name + .expr = "cel.bind(t, true, t)", + }, + TestCase{ + .extension = "bindings", // alias + .expr = "cel.bind(t, true, t)", + }, + TestCase{ + .extension = "encoders", + .expr = "base64.encode(b'hello')", + }, + TestCase{ + .extension = "lists", + .expr = "[1, 2, 3].sort()", + }, + TestCase{ + .extension = "lists", + .expr = "['a'].sortBy(e, e)", + }, + TestCase{ + .extension = "math", + .expr = "math.sqrt(-1)", + }, + TestCase{ + .extension = "optional", + .expr = "[1, 2].first()", + }, + TestCase{ + .extension = "optional", + .expr = "[0][?1]", // optional syntax auto-enabled + }, + TestCase{ + .extension = "protos", + .expr = "!proto.hasExt(cel.expr.conformance.proto2.TestAllTypes{}, " + "cel.expr.conformance.proto2.nested_ext)", + }, + TestCase{ + .extension = "sets", + .expr = "sets.contains([1], [1])", + }, + TestCase{ + .extension = "strings", + .expr = "'foo'.reverse()", + }, + TestCase{ + .extension = "two-var-comprehensions", + .expr = "[1, 2, 3, 4].all(i, v, i < v)", + }, + TestCase{ + .extension = "regex", + .expr = "regex.replace('abc', '$', '_end')", + })); + +} // namespace +} // namespace cel diff --git a/env/env_test.cc b/env/env_test.cc new file mode 100644 index 000000000..00143a857 --- /dev/null +++ b/env/env_test.cc @@ -0,0 +1,666 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/env.h" + +#include +#include +#include +#include + +#include "absl/functional/any_invocable.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "absl/types/span.h" +#include "checker/type_check_issue.h" +#include "checker/type_checker_builder.h" +#include "checker/validation_result.h" +#include "common/ast.h" +#include "common/constant.h" +#include "common/decl.h" +#include "common/expr.h" +#include "common/type.h" +#include "common/value.h" +#include "compiler/compiler.h" +#include "env/config.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "parser/macro.h" +#include "parser/macro_expr_factory.h" +#include "parser/parser_interface.h" +#include "runtime/activation.h" +#include "runtime/reference_resolver.h" +#include "runtime/runtime.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_options.h" +#include "runtime/standard_runtime_builder_factory.h" +#include "google/protobuf/arena.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::testing::HasSubstr; +using ::testing::IsEmpty; +using ::testing::Property; +using ::testing::UnorderedElementsAre; +using ::testing::Values; +using ::testing::ValuesIn; + +Expr TestMacroExpander(MacroExprFactory& factory, absl::Span args) { + return factory.NewStringConst("Hello"); +} + +class TestLibrary : public CompilerLibrary { + public: + explicit TestLibrary(int version) + : CompilerLibrary( + "testlib", + [version](ParserBuilder& builder) { + absl::Status status; + CEL_ASSIGN_OR_RETURN( + auto macro1, + cel::Macro::Global("testMacro1", 0, TestMacroExpander)); + status.Update(builder.AddMacro(macro1)); + if (version == 2) { + CEL_ASSIGN_OR_RETURN( + auto macro2, + cel::Macro::Global("testMacro2", 0, TestMacroExpander)); + status.Update(builder.AddMacro(macro2)); + } + return status; + }, + [version](TypeCheckerBuilder& builder) { + absl::Status status; + CEL_ASSIGN_OR_RETURN( + auto func1, cel::MakeFunctionDecl( + "testFunc1", MakeOverloadDecl(StringType()))); + status.Update(builder.AddFunction(func1)); + if (version == 2) { + CEL_ASSIGN_OR_RETURN( + auto func2, + cel::MakeFunctionDecl("testFunc2", + MakeOverloadDecl(StringType()))); + status.Update(builder.AddFunction(func2)); + } + return status; + }) {}; +}; + +absl::StatusOr CompileAndEvalExpr( + Env& env, absl::string_view expr, + const Activation& activation = Activation()) { + CEL_ASSIGN_OR_RETURN(std::unique_ptr compiler, env.NewCompiler()); + if (compiler == nullptr) { + return absl::InternalError("Failed to create compiler"); + } + CEL_ASSIGN_OR_RETURN(ValidationResult result, compiler->Compile(expr)); + if (!result.GetIssues().empty()) { + return absl::InvalidArgumentError(result.FormatError()); + } + + cel::RuntimeOptions opts; + CEL_ASSIGN_OR_RETURN( + cel::RuntimeBuilder rt_builder, + cel::CreateStandardRuntimeBuilder(env.GetDescriptorPool(), opts)); + CEL_RETURN_IF_ERROR(cel::EnableReferenceResolver( + rt_builder, cel::ReferenceResolverEnabled::kAlways)); + CEL_ASSIGN_OR_RETURN(std::unique_ptr runtime, + std::move(rt_builder).Build()); + if (runtime == nullptr) { + return absl::InternalError("Failed to create runtime"); + } + + CEL_ASSIGN_OR_RETURN(std::unique_ptr ast, result.ReleaseAst()); + if (ast == nullptr) { + return absl::InternalError("Failed to create AST"); + } + google::protobuf::Arena arena; + CEL_ASSIGN_OR_RETURN(std::unique_ptr program, + runtime->CreateProgram(std::move(ast))); + if (program == nullptr) { + return absl::InternalError("Failed to create program"); + } + CEL_ASSIGN_OR_RETURN(Value value, program->Evaluate(&arena, activation)); + return value; +} + +absl::StatusOr CompileAndEvalBooleanExpr( + Env& env, absl::string_view expr, + const Activation& activation = Activation()) { + CEL_ASSIGN_OR_RETURN(auto value, CompileAndEvalExpr(env, expr, activation)); + return value.GetBool(); +} + +class LibraryConfigTest : public testing::Test { + protected: + void SetUp() override { + env_.RegisterCompilerLibrary("testlib", "ml", 1, + []() { return TestLibrary(1); }); + env_.RegisterCompilerLibrary("testlib", "ml", 2, + []() { return TestLibrary(2); }); + env_.SetDescriptorPool(internal::GetSharedTestingDescriptorPool()); + } + + Env env_; +}; + +TEST_F(LibraryConfigTest, DefaultVersion) { + Config config; + ASSERT_THAT(config.AddExtensionConfig("testlib"), IsOk()); + + env_.SetConfig(config); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env_.NewCompiler()); + ASSERT_OK_AND_ASSIGN(auto result1, compiler->Compile("testMacro1()")); + ASSERT_OK_AND_ASSIGN(auto result2, compiler->Compile("testFunc1()")); + ASSERT_OK_AND_ASSIGN(auto result3, compiler->Compile("testMacro2()")); + ASSERT_OK_AND_ASSIGN(auto result4, compiler->Compile("testFunc2()")); + + EXPECT_THAT(result1.GetIssues(), IsEmpty()); + EXPECT_THAT(result2.GetIssues(), IsEmpty()); + EXPECT_THAT(result3.GetIssues(), IsEmpty()); + EXPECT_THAT(result4.GetIssues(), IsEmpty()); +} + +TEST_F(LibraryConfigTest, SpecificVersion) { + Config config; + ASSERT_THAT(config.AddExtensionConfig("testlib", 1), IsOk()); + + env_.SetConfig(config); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env_.NewCompiler()); + ASSERT_OK_AND_ASSIGN(auto result1, compiler->Compile("testMacro1()")); + ASSERT_OK_AND_ASSIGN(auto result2, compiler->Compile("testFunc1()")); + ASSERT_OK_AND_ASSIGN(auto result3, compiler->Compile("testMacro2()")); + ASSERT_OK_AND_ASSIGN(auto result4, compiler->Compile("testFunc2()")); + + EXPECT_THAT(result1.GetIssues(), IsEmpty()); + EXPECT_THAT(result2.GetIssues(), IsEmpty()); + EXPECT_THAT(result3.GetIssues(), + UnorderedElementsAre( + Property(&TypeCheckIssue::message, + HasSubstr("undeclared reference to 'testMacro2'")))); + EXPECT_THAT(result4.GetIssues(), + UnorderedElementsAre( + Property(&TypeCheckIssue::message, + HasSubstr("undeclared reference to 'testFunc2'")))); +} + +struct StandardLibraryConfigTestCase { + Config::StandardLibraryConfig standard_library_config; + std::vector expected_valid_expressions; + std::vector expected_invalid_expressions; +}; + +class StandardLibraryConfigTest + : public testing::TestWithParam {}; + +TEST_P(StandardLibraryConfigTest, StandardLibraryConfig) { + const StandardLibraryConfigTestCase& param = GetParam(); + Env env; + env.SetDescriptorPool(internal::GetSharedTestingDescriptorPool()); + + Config config; + ASSERT_THAT(config.SetStandardLibraryConfig(param.standard_library_config), + IsOk()); + env.SetConfig(config); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler()); + + for (const std::string& expr : param.expected_valid_expressions) { + ASSERT_OK_AND_ASSIGN(auto result1, compiler->Compile(expr)); + EXPECT_THAT(result1.GetIssues(), IsEmpty()) + << "With config: " << param.standard_library_config + << ", expected no issues for expr: " << expr + << " but got: " << result1.FormatError(); + } + for (const std::string& expr : param.expected_invalid_expressions) { + ASSERT_OK_AND_ASSIGN(auto result1, compiler->Compile(expr)); + EXPECT_THAT(result1.GetIssues(), Not(IsEmpty())) + << "With config: " << param.standard_library_config + << ", expected compilation error for expr: " << expr << " but got: \'" + << result1.FormatError() << "\'"; + } +} + +INSTANTIATE_TEST_SUITE_P( + StandardLibraryConfigTest, StandardLibraryConfigTest, + Values( + StandardLibraryConfigTestCase{ + .standard_library_config = {}, + .expected_valid_expressions = {"1 + 2", + "[1, 2, 3].exists(x, x == 1)", + "[1, 2, 3].all(x, x == 1)", + "[1, 2, 3].map(x, x)"}, + }, + StandardLibraryConfigTestCase{ + .standard_library_config = {.disable = true}, + .expected_invalid_expressions = {"1 + 2", + "[1, 2, 3].exists(x, x == 1)", + "[1, 2, 3].all(x, x == 1)", + "[1, 2, 3].map(x, x)"}, + }, + StandardLibraryConfigTestCase{ + .standard_library_config = {.disable_macros = true}, + .expected_valid_expressions = {"1 + 2"}, + .expected_invalid_expressions = {"[1, 2, 3].exists(x, x == 1)", + "[1, 2, 3].all(x, x == 1)", + "[1, 2, 3].map(x, x)"}, + }, + StandardLibraryConfigTestCase{ + .standard_library_config = {.excluded_macros = {"map", "all"}}, + .expected_valid_expressions = {"[1, 2, 3].exists(x, x == 1)"}, + .expected_invalid_expressions = {"[1, 2, 3].all(x, x == 1)", + "[1, 2, 3].map(x, x)"}, + }, + StandardLibraryConfigTestCase{ + .standard_library_config = {.included_macros = {"map", "all"}}, + .expected_valid_expressions = {"[1, 2, 3].all(x, x == 1)", + "[1, 2, 3].map(x, x)"}, + .expected_invalid_expressions = {"[1, 2, 3].exists(x, x == 1)"}, + }, + StandardLibraryConfigTestCase{ + .standard_library_config = {.excluded_functions = {{"_+_", ""}}}, + .expected_invalid_expressions = {"1 + 2", "[1, 2, 3] + [4, 5, 6]", + "'hello' + 'world'"}, + }, + StandardLibraryConfigTestCase{ + .standard_library_config = + {.excluded_functions = {{"_+_", "_+_(bytes,bytes)"}, + {"_+_", "_+_(list<~A>,list<~A>)"}, + {"_+_", "_+_(string,string)"}}}, + .expected_valid_expressions = {"1 + 2"}, + .expected_invalid_expressions = {"[1, 2, 3] + [4, 5, 6]", + "'hello' + 'world'"}, + }, + StandardLibraryConfigTestCase{ + .standard_library_config = + {.excluded_functions = {{"_+_", "add_bytes"}, + {"_+_", "add_list"}, + {"_+_", "add_string"}}}, + .expected_valid_expressions = {"1 + 2"}, + .expected_invalid_expressions = {"[1, 2, 3] + [4, 5, 6]", + "'hello' + 'world'"}, + }, + StandardLibraryConfigTestCase{ + .standard_library_config = {.included_functions = {{"_+_", ""}}}, + .expected_valid_expressions = {"1 + 2", "[1, 2, 3] + [4, 5, 6]", + "'hello' + 'world'"}, + }, + StandardLibraryConfigTestCase{ + .standard_library_config = + {.included_functions = {{"_+_", "_+_(int,int)"}, + {"_+_", "_+_(list<~A>,list<~A>)"}}}, + .expected_valid_expressions = {"1 + 2", "[1, 2, 3] + [4, 5, 6]"}, + .expected_invalid_expressions = {"'hello' + 'world'"}, + }, + StandardLibraryConfigTestCase{ + .standard_library_config = + {.included_functions = {{"_+_", "add_int64"}, + {"_+_", "add_list"}}}, + .expected_valid_expressions = {"1 + 2", "[1, 2, 3] + [4, 5, 6]"}, + .expected_invalid_expressions = {"'hello' + 'world'"}, + })); + +TEST(ContainerConfigTest, ContainerConfig) { + Env env; + env.SetDescriptorPool(internal::GetSharedTestingDescriptorPool()); + Config config; + config.SetContainerConfig({.name = "cel.expr.conformance.proto2"}); + env.SetConfig(config); + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, compiler->Compile("TestAllTypes{}")); + + EXPECT_THAT(result.GetIssues(), IsEmpty()) << result.FormatError(); +} + +TEST(ContainerConfigTest, ContainerConfigWithAbbreviations) { + Env env; + env.SetDescriptorPool(internal::GetSharedTestingDescriptorPool()); + Config config; + config.SetContainerConfig( + {.name = "cel.expr.conformance", + .abbreviations = {"cel.expr.conformance.proto2.TestAllTypes"}}); + env.SetConfig(config); + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, compiler->Compile("TestAllTypes{}")); + + EXPECT_THAT(result.GetIssues(), IsEmpty()) << result.FormatError(); +} + +TEST(ContainerConfigTest, ContainerConfigWithAliases) { + Env env; + env.SetDescriptorPool(internal::GetSharedTestingDescriptorPool()); + Config config; + config.SetContainerConfig( + {.name = "cel.expr.conformance", + .aliases = { + {.alias = "MyTestType", + .qualified_name = "cel.expr.conformance.proto2.TestAllTypes"}}}); + env.SetConfig(config); + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, compiler->Compile("MyTestType{}")); + + EXPECT_THAT(result.GetIssues(), IsEmpty()) << result.FormatError(); +} + +TEST(ContextVariableConfigTest, Basic) { + Env env; + env.SetDescriptorPool(internal::GetSharedTestingDescriptorPool()); + Config config; + config.SetContextType("cel.expr.conformance.proto3.TestAllTypes"); + env.SetConfig(config); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler()); + + // Top-level fields of TestAllTypes like "single_int32" should resolve + // successfully. + ASSERT_OK_AND_ASSIGN(auto result, compiler->Compile("single_int32 > 10")); + EXPECT_THAT(result.GetIssues(), IsEmpty()); + + ASSERT_OK_AND_ASSIGN(auto result_invalid, + compiler->Compile("non_existent_field > 10")); + EXPECT_THAT(result_invalid.GetIssues(), Not(IsEmpty())); +} + +struct VariableConfigWithValueTestCase { + Config::VariableConfig variable_config; + std::string validate_type_expr; + std::string validate_value_expr; +}; + +class VariableConfigWithValueTest + : public testing::TestWithParam {}; + +TEST_P(VariableConfigWithValueTest, VariableConfigWithValue) { + const VariableConfigWithValueTestCase& param = GetParam(); + + Env env; + env.SetDescriptorPool(internal::GetSharedTestingDescriptorPool()); + Config config; + ASSERT_THAT(config.AddVariableConfig(param.variable_config), IsOk()); + env.SetConfig(config); + ASSERT_OK_AND_ASSIGN( + bool type_as_expected, + CompileAndEvalBooleanExpr(env, param.validate_type_expr)); + ASSERT_TRUE(type_as_expected) << " expr: " << param.validate_type_expr; + if (!param.validate_value_expr.empty()) { + ASSERT_OK_AND_ASSIGN( + bool value_as_expected, + CompileAndEvalBooleanExpr(env, param.validate_value_expr)); + ASSERT_TRUE(value_as_expected) << " expr: " << param.validate_value_expr; + } +} + +Config::VariableConfig MakeConstant( + absl::string_view variable_name, absl::string_view type_name, + absl::AnyInvocable setter) { + Config::VariableConfig variable_config; + variable_config.name = variable_name; + Constant c; + setter(c); + variable_config.type_info.name = type_name; + variable_config.value = c; + return variable_config; +} + +std::vector +GetVariableConfigWithValueTestCases() { + return { + VariableConfigWithValueTestCase{ + .variable_config = MakeConstant( + "x", "null", [](auto& c) { c.set_null_value(nullptr); }), + .validate_type_expr = "type(x) == type(null)", + }, + VariableConfigWithValueTestCase{ + .variable_config = MakeConstant( + "x", "bool", [](auto& c) { c.set_bool_value(true); }), + .validate_type_expr = "type(x) == bool", + .validate_value_expr = "x == true", + }, + VariableConfigWithValueTestCase{ + .variable_config = MakeConstant( + "x", "int", [](Constant& c) { c.set_int_value(42); }), + .validate_type_expr = "type(x) == int", + .validate_value_expr = "x == 42", + }, + VariableConfigWithValueTestCase{ + .variable_config = MakeConstant( + "x", "uint", [](Constant& c) { c.set_uint_value(777); }), + .validate_type_expr = "type(x) == uint", + .validate_value_expr = "x == 777u", + }, + VariableConfigWithValueTestCase{ + .variable_config = + MakeConstant("x", "double", + [](Constant& c) { c.set_double_value(1.0 / 3.0); }), + .validate_type_expr = "type(x) == double", + .validate_value_expr = "x > 0.333 && x < 0.334", + }, + VariableConfigWithValueTestCase{ + .variable_config = MakeConstant("x", "bytes", + [](Constant& c) { + c.set_bytes_value(absl::string_view( + "\xff\x00\x01", 3)); + }), + .validate_type_expr = "type(x) == bytes", + .validate_value_expr = "x == b'\\xff\\x00\\x01'", + }, + VariableConfigWithValueTestCase{ + .variable_config = MakeConstant( + "x", "string", [](Constant& c) { c.set_string_value("hello"); }), + .validate_type_expr = "type(x) == string", + .validate_value_expr = "x == 'hello'", + }, + VariableConfigWithValueTestCase{ + .variable_config = MakeConstant( + "x", "timestamp", + [](Constant& c) { + // NOLINTNEXTLINE(clang-diagnostic-deprecated-declarations) + c.set_timestamp_value(absl::FromUnixSeconds(1767323045)); + }), + .validate_type_expr = + "type(x) == type(timestamp('2026-01-02T03:04:05Z'))", + .validate_value_expr = "x == timestamp('2026-01-02T03:04:05Z')", + }, + VariableConfigWithValueTestCase{ + .variable_config = MakeConstant( + "x", "duration", + [](Constant& c) { + // NOLINTNEXTLINE(clang-diagnostic-deprecated-declarations) + c.set_duration_value(absl::Hours(1) + absl::Minutes(2) + + absl::Seconds(3)); + }), + .validate_type_expr = "type(x) == type(duration('1h2m3s'))", + .validate_value_expr = "x == duration('1h2m3s')", + }, + }; +} + +INSTANTIATE_TEST_SUITE_P(VariableConfigTest, VariableConfigWithValueTest, + ValuesIn(GetVariableConfigWithValueTestCases())); + +struct FunctionConfigTestCase { + Config::FunctionConfig function_config; + std::vector variable_configs; + std::string expr; + std::string expected_error; +}; + +class FunctionConfigTest + : public testing::TestWithParam {}; + +TEST_P(FunctionConfigTest, FunctionConfig) { + const FunctionConfigTestCase& param = GetParam(); + + Env env; + env.SetDescriptorPool(internal::GetSharedTestingDescriptorPool()); + Config config; + for (const Config::VariableConfig& variable_config : param.variable_configs) { + ASSERT_THAT(config.AddVariableConfig(variable_config), IsOk()); + } + ASSERT_THAT(config.AddFunctionConfig(param.function_config), IsOk()); + env.SetConfig(config); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler()); + ASSERT_OK_AND_ASSIGN(ValidationResult result, compiler->Compile(param.expr)); + if (param.expected_error.empty()) { + EXPECT_TRUE(result.GetIssues().empty()) + << " expr: " << param.expr << " error: " << result.FormatError(); + } else { + EXPECT_THAT(result.GetIssues(), + UnorderedElementsAre(Property(&TypeCheckIssue::message, + HasSubstr(param.expected_error)))) + << " expr: " << param.expr << " error: " << result.FormatError(); + } +} + +std::vector GetFunctionConfigTestCases() { + return {{ + FunctionConfigTestCase{ + .function_config = + { + .name = "add", + .overload_configs = + { + { + .overload_id = "plus(int,int)", + .examples = {"add(1, 2) -> 3"}, + .parameters = {{.name = "int"}, {.name = "int"}}, + .return_type = {.name = "int"}, + }, + }, + }, + .expr = "add(1, 2)", + }, + FunctionConfigTestCase{ + .function_config = + { + .name = "add", + .overload_configs = + { + { + .overload_id = "int.plus(int)", + .examples = {"1.add(2) -> 3"}, + .is_member_function = true, + .parameters = {{.name = "int"}, {.name = "int"}}, + .return_type = {.name = "int"}, + }, + }, + }, + .expr = "1.add(2) == 3", + }, + FunctionConfigTestCase{ + .function_config = + { + .name = "add", + .overload_configs = + { + { + .overload_id = "plus(string,string)", + .examples = + {"add('hello', 'world') -> 'hello world'"}, + .parameters = {{.name = "int"}, {.name = "int"}}, + .return_type = {.name = "string"}, + }, + }, + }, + .expr = "add('hello', 'world')", + .expected_error = "found no matching overload for 'add' applied to " + "'(string, string)'", + }, + FunctionConfigTestCase{ + .function_config = + { + .name = "add", + .overload_configs = + { + { + .overload_id = "int.plus(int)", + .examples = {"1.add(2) -> 'three'"}, + .is_member_function = true, + .parameters = {{.name = "int"}, {.name = "int"}}, + .return_type = {.name = "string"}, + }, + }, + }, + .expr = "1.add(2) == 3", + .expected_error = "found no matching overload for '_==_' applied to " + "'(string, int)'", + }, + FunctionConfigTestCase{ + .function_config = + { + .name = "sum", + .description = "Sum a collection, which is an opaque type.", + .overload_configs = + { + { + .overload_id = "sum(collection)", + .examples = {"sum(my_collection) -> 100"}, + .parameters = {{.name = "collection", + .params = {{.name = "double"}}}}, + .return_type = {.name = "double"}, + }, + }, + }, + .variable_configs = + { + {.name = "my_collection", + .description = "Matching opaque type.", + .type_info = {.name = "collection", + .params = {{.name = "double"}}}}, + }, + .expr = "sum(my_collection) / 3.0", + }, + FunctionConfigTestCase{ + .function_config = + { + .name = "sum", + .description = "Sum a collection, which is an opaque type.", + .overload_configs = + { + { + .overload_id = "sum(collection)", + .examples = {"sum(my_collection) -> 100"}, + .parameters = {{.name = "collection", + .params = {{.name = "int"}}}}, + .return_type = {.name = "double"}, + }, + }, + }, + .variable_configs = + { + {.name = "my_collection", + .description = "Mismatched opaque type.", + .type_info = {.name = "collection", + .params = {{.name = "double"}}}}, + }, + .expr = "sum(my_collection) / 3.0", + .expected_error = "found no matching overload for 'sum' applied to " + "'(collection(double))'", + }, + }}; +} + +INSTANTIATE_TEST_SUITE_P(FunctionConfigTest, FunctionConfigTest, + ::testing::ValuesIn(GetFunctionConfigTestCases())); + +} // namespace +} // namespace cel diff --git a/env/env_yaml.cc b/env/env_yaml.cc new file mode 100644 index 000000000..281cf3ff1 --- /dev/null +++ b/env/env_yaml.cc @@ -0,0 +1,1322 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/env_yaml.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/no_destructor.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/escaping.h" +#include "absl/strings/match.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "common/ast.h" +#include "common/constant.h" +#include "common/signature.h" +#include "env/config.h" +#include "env/type_info.h" +#include "internal/status_macros.h" +#include "internal/strings.h" +#include "yaml-cpp/emitter.h" +#include "yaml-cpp/emittermanip.h" +#include "yaml-cpp/exceptions.h" +#include "yaml-cpp/mark.h" +#include "yaml-cpp/node/node.h" +#include "yaml-cpp/node/parse.h" +#include "yaml-cpp/null.h" +#include "yaml-cpp/yaml.h" // IWYU pragma: keep + +namespace cel { + +namespace { + +std::string FormatYamlErrorMessage(absl::string_view yaml, + absl::string_view error, + const YAML::Mark& mark) { + if (mark.is_null()) { + return std::string(error); + } + std::string message; + absl::StrAppend(&message, mark.line + 1, ":", mark.column + 1, ": ", error, + "\n|"); + size_t start = mark.pos - mark.column; + size_t end = yaml.find('\n', mark.pos); + if (end == std::string::npos) { + end = yaml.size(); + } + + absl::StrAppend(&message, yaml.substr(start, end - start), "\n|", + std::string(mark.column, ' '), "^"); + + return message; +} + +absl::StatusOr LoadYaml(const std::string& yaml) { + try { + return YAML::Load(yaml); + } catch (YAML::ParserException& e) { + return absl::InvalidArgumentError( + FormatYamlErrorMessage(yaml, e.msg, e.mark)); + } +} + +absl::Status YamlError(absl::string_view yaml, const YAML::Node& node, + absl::string_view error) { + return absl::InvalidArgumentError( + FormatYamlErrorMessage(yaml, error, node.Mark())); +} + +std::string GetString(absl::string_view yaml, const YAML::Node& node) { + if (!node.IsDefined() || !node.IsScalar()) { + return ""; + } + try { + return node.as(); + } catch (YAML::Exception& e) { + // This should never happen since we already checked that the node is a + // scalar and all scalars can be converted to strings. + return ""; + } +} + +bool IsBinary(const YAML::Node& node) { + return node.Tag() == "!!binary" || node.Tag() == "tag:yaml.org,2002:binary"; +} + +absl::StatusOr GetBinary(absl::string_view yaml, + const YAML::Node& node) { + if (!node.IsDefined() || !node.IsScalar() || !IsBinary(node)) { + return ""; + } + std::string binary; + // Instead of using the YAML::Binary type, we use absl::Base64Unescape + // because YAML::Binary is lenient to Base64 decoding errors. + if (absl::Base64Unescape(GetString(yaml, node), &binary)) { + return binary; + } else { + return YamlError(yaml, node, + absl::StrCat("Node '", GetString(yaml, node), + "' is not a valid Base64 encoded binary")); + } +} + +absl::StatusOr GetBool(absl::string_view yaml, absl::string_view key, + const YAML::Node& node) { + if (!node.IsDefined() || !node.IsScalar()) { + return false; + } + try { + return node.as(); + } catch (YAML::Exception& e) { + return YamlError(yaml, node, + absl::StrCat("Node '", key, "' is not a boolean")); + } +} + +// Returns the key in the map `node` that has the given `value_node` as its +// value. If no such key exists, returns `value_node` itself. +YAML::Node GetContextNodeForKeyValue(const YAML::Node& node, + const YAML::Node& value_node) { + for (const auto& kv : node) { + if (kv.second.IsDefined() && kv.second.is(value_node)) { + return kv.first; + } + } + return value_node; +} + +absl::Status ParseName(Config& config, absl::string_view yaml, + const YAML::Node& root) { + const YAML::Node name = root["name"]; + if (name.IsDefined()) { + if (!name.IsScalar()) { + return YamlError(yaml, name, "Node 'name' is not a string"); + } + config.SetName(GetString(yaml, name)); + } + return absl::OkStatus(); +} + +absl::Status ParseContainerConfig(Config& config, absl::string_view yaml, + const YAML::Node& root) { + const YAML::Node container = root["container"]; + if (!container.IsDefined()) { + return absl::OkStatus(); + } + + if (container.IsScalar()) { + config.SetContainerConfig({.name = GetString(yaml, container)}); + return absl::OkStatus(); + } + + if (!container.IsMap()) { + return YamlError(yaml, container, + "Node 'container' is neither a string nor a map"); + } + + Config::ContainerConfig container_config; + + const YAML::Node name = container["name"]; + if (name.IsDefined()) { + if (!name.IsScalar()) { + return YamlError(yaml, name, "Node 'name' in container is not a string"); + } + container_config.name = GetString(yaml, name); + } + + const YAML::Node abbreviations = container["abbreviations"]; + if (abbreviations.IsDefined()) { + if (!abbreviations.IsSequence()) { + return YamlError(yaml, abbreviations, + "Node 'abbreviations' is not a sequence"); + } + for (const YAML::Node& abbr : abbreviations) { + if (!abbr.IsScalar()) { + return YamlError(yaml, abbr, "Abbreviation is not a string"); + } + container_config.abbreviations.push_back(GetString(yaml, abbr)); + } + } + + const YAML::Node aliases = container["aliases"]; + if (aliases.IsDefined()) { + if (!aliases.IsSequence()) { + return YamlError(yaml, aliases, "Node 'aliases' is not a sequence"); + } + for (const YAML::Node& alias_node : aliases) { + if (!alias_node.IsMap()) { + return YamlError(yaml, alias_node, "Alias entry is not a map"); + } + const YAML::Node alias_key = alias_node["alias"]; + const YAML::Node qualified_name_key = alias_node["qualified_name"]; + + if (!alias_key.IsDefined() || !alias_key.IsScalar()) { + return YamlError(yaml, alias_node, + "Alias entry missing 'alias' string"); + } + if (!qualified_name_key.IsDefined() || !qualified_name_key.IsScalar()) { + return YamlError(yaml, alias_node, + "Alias entry missing 'qualified_name' string"); + } + + container_config.aliases.push_back( + {.alias = GetString(yaml, alias_key), + .qualified_name = GetString(yaml, qualified_name_key)}); + } + } + + config.SetContainerConfig(std::move(container_config)); + return absl::OkStatus(); +} + +absl::Status ParseExtensionConfigs(Config& config, absl::string_view yaml, + const YAML::Node& root) { + const YAML::Node extensions = root["extensions"]; + if (!extensions.IsDefined()) { + return absl::OkStatus(); + } + if (!extensions.IsSequence()) { + return YamlError(yaml, extensions, "Node 'extensions' is not a sequence"); + } + + for (const YAML::Node& extension : extensions) { + if (!extension || !extension.IsMap()) { + return YamlError(yaml, extension, "Extension is not a map"); + } + const YAML::Node name = extension["name"]; + if (!name || !name.IsScalar()) { + return YamlError(yaml, name, "Extension name is not a string"); + } + std::string name_str = GetString(yaml, name); + + const YAML::Node version = extension["version"]; + std::string version_str = GetString(yaml, version); + int extension_version; + if (version.IsDefined()) { + bool is_valid_version = false; + if (version.IsScalar()) { + if (version_str == "latest") { + extension_version = Config::ExtensionConfig::kLatest; + is_valid_version = true; + } else { + if (absl::SimpleAtoi(version_str, &extension_version) && + extension_version >= 0) { + is_valid_version = true; + } + } + } + if (!is_valid_version) { + return YamlError( + yaml, version, + absl::StrCat("Extension '", name_str, + "' version is not a valid number or 'latest'")); + } + } else { + extension_version = Config::ExtensionConfig::kLatest; + } + absl::Status add_status = + config.AddExtensionConfig(name_str, extension_version); + if (!add_status.ok()) { + return YamlError(yaml, extension, add_status.message()); + } + } + return absl::OkStatus(); +} + +absl::StatusOr> ParseMacroList( + absl::string_view yaml, const YAML::Node& standard_library, + absl::string_view key) { + absl::flat_hash_set macro_set; + const YAML::Node macros = standard_library[std::string(key)]; + if (!macros.IsDefined()) { + return macro_set; + } + if (!macros.IsSequence()) { + return YamlError(yaml, macros, + absl::StrCat("Node '", key, "' is not a sequence")); + } + for (const YAML::Node& macro : macros) { + if (!macro.IsScalar()) { + return YamlError(yaml, macro, + absl::StrCat("Entry in '", key, "' is not a string")); + } + macro_set.insert(GetString(yaml, macro)); + } + return macro_set; +} + +absl::StatusOr>> +ParseFunctionList(absl::string_view yaml, const YAML::Node& standard_library, + absl::string_view key) { + absl::flat_hash_set> function_set; + const YAML::Node functions = standard_library[std::string(key)]; + if (!functions.IsDefined()) { + return function_set; + } + if (!functions.IsSequence()) { + return YamlError(yaml, functions, + absl::StrCat("Node '", key, "' is not a sequence")); + } + for (const YAML::Node& function : functions) { + if (!function.IsMap()) { + return YamlError(yaml, function, + absl::StrCat("Entry in '", key, "' is not a map")); + } + const YAML::Node name = function["name"]; + if (!name.IsDefined()) { + return YamlError( + yaml, function, + absl::StrCat("Function name in not specified in '", key, "'")); + } + if (!name.IsScalar()) { + return YamlError( + yaml, name, + absl::StrCat("Function name in '", key, "' entry is not a string")); + } + std::string name_str = GetString(yaml, name); + const YAML::Node overloads = function["overloads"]; + if (!overloads.IsDefined()) { + function_set.insert(std::make_pair(name_str, "")); + } else { + if (!overloads.IsSequence()) { + return YamlError( + yaml, overloads, + absl::StrCat("Overloads in '", key, "' entry is not a sequence")); + } + for (const YAML::Node& overload : overloads) { + if (!overload.IsMap()) { + return YamlError( + yaml, overload, + absl::StrCat("Overload in '", key, "' entry is not a map")); + } + const YAML::Node id = overload["id"]; + if (!id || !id.IsScalar()) { + return YamlError( + yaml, id, + absl::StrCat("Overload id in '", key, "' entry is not a string")); + } + function_set.insert(std::make_pair(name_str, GetString(yaml, id))); + } + } + } + return function_set; +} + +absl::Status ParseStandardLibraryConfig(Config& config, absl::string_view yaml, + const YAML::Node& root) { + const YAML::Node standard_library = root["stdlib"]; + if (!standard_library.IsDefined()) { + return absl::OkStatus(); + } + + if (!standard_library.IsMap()) { + return YamlError(yaml, standard_library, + "Standard library config ('stdlib') is not a map"); + } + + Config::StandardLibraryConfig standard_library_config; + + const YAML::Node disable = standard_library["disable"]; + if (disable.IsDefined()) { + if (!disable.IsScalar()) { + return YamlError(yaml, disable, "Node 'disable' is not a boolean"); + } + CEL_ASSIGN_OR_RETURN(standard_library_config.disable, + GetBool(yaml, "disable", disable)); + } + + const YAML::Node disable_macros = standard_library["disable_macros"]; + if (disable_macros.IsDefined()) { + if (!disable_macros.IsScalar()) { + return YamlError(yaml, disable_macros, + "Node 'disable_macros' is not a boolean"); + } + CEL_ASSIGN_OR_RETURN(standard_library_config.disable_macros, + GetBool(yaml, "disable_macros", disable_macros)); + } + + CEL_ASSIGN_OR_RETURN( + standard_library_config.included_macros, + ParseMacroList(yaml, standard_library, "include_macros")); + + CEL_ASSIGN_OR_RETURN( + standard_library_config.excluded_macros, + ParseMacroList(yaml, standard_library, "exclude_macros")); + + CEL_ASSIGN_OR_RETURN( + standard_library_config.included_functions, + ParseFunctionList(yaml, standard_library, "include_functions")); + + CEL_ASSIGN_OR_RETURN( + standard_library_config.excluded_functions, + ParseFunctionList(yaml, standard_library, "exclude_functions")); + + return config.SetStandardLibraryConfig(standard_library_config); +} + +absl::StatusOr ParseTypeInfo(const YAML::Node& node, + absl::string_view yaml) { + Config::TypeInfo type_config; + const YAML::Node type = node["type"]; + const YAML::Node type_name = node["type_name"]; + if (type.IsDefined() && type_name.IsDefined()) { + return YamlError(yaml, GetContextNodeForKeyValue(node, type_name), + "Node 'type' and 'type_name' are mutually exclusive"); + } + + if (type.IsDefined()) { + if (!type.IsScalar()) { + return YamlError(yaml, type, "Node 'type' is not a string"); + } + CEL_ASSIGN_OR_RETURN(auto type_spec, ParseTypeSpec(GetString(yaml, type))); + CEL_ASSIGN_OR_RETURN(auto type_config, TypeSpecToTypeInfo(type_spec)); + return type_config; + } + + if (!type_name.IsDefined()) { + return type_config; + } + if (!type_name || !type_name.IsScalar()) { + return YamlError(yaml, type_name, "Node 'type_name' is not a string"); + } + type_config.name = GetString(yaml, type_name); + + const YAML::Node is_type_param = node["is_type_param"]; + if (is_type_param.IsDefined()) { + if (!is_type_param.IsScalar()) { + return YamlError(yaml, is_type_param, + "Node 'is_type_param' is not a boolean"); + } + CEL_ASSIGN_OR_RETURN(type_config.is_type_param, + GetBool(yaml, "is_type_param", is_type_param)); + } + + const YAML::Node params = node["params"]; + if (!params.IsDefined()) { + return type_config; + } + if (!params.IsSequence()) { + return YamlError(yaml, params, "Node 'params' is not a sequence"); + } + for (const YAML::Node& param : params) { + CEL_ASSIGN_OR_RETURN(Config::TypeInfo param_config, + ParseTypeInfo(param, yaml)); + type_config.params.push_back(param_config); + } + + return type_config; +} + +bool CompareTypeInfo(const Config::TypeInfo& a, const Config::TypeInfo& b) { + if (a.name != b.name) { + return a.name < b.name; + } + if (a.params.size() != b.params.size()) { + return a.params.size() < b.params.size(); + } + for (size_t i = 0; i < a.params.size(); ++i) { + if (CompareTypeInfo(a.params[i], b.params[i])) { + return true; + } + if (CompareTypeInfo(b.params[i], a.params[i])) { + return false; + } + } + return false; // They are equal +} + +ConstantKindCase GetConstantKindCase(absl::string_view type_name) { + static const auto kTypeNameToConstantKindCase = + absl::NoDestructor>({ + {"null", ConstantKindCase::kNull}, + {"bool", ConstantKindCase::kBool}, + {"int", ConstantKindCase::kInt}, + {"uint", ConstantKindCase::kUint}, + {"double", ConstantKindCase::kDouble}, + {"string", ConstantKindCase::kString}, + {"bytes", ConstantKindCase::kBytes}, + {"duration", ConstantKindCase::kDuration}, + {"timestamp", ConstantKindCase::kTimestamp}, + }); + if (auto it = kTypeNameToConstantKindCase->find(type_name); + it != kTypeNameToConstantKindCase->end()) { + return it->second; + } + return ConstantKindCase::kUnspecified; +} + +absl::StatusOr ParseConstantValue(absl::string_view yaml, + const YAML::Node& node, + ConstantKindCase constant_kind_case, + absl::string_view value) { + switch (constant_kind_case) { + case ConstantKindCase::kNull: + if (!value.empty()) { + return YamlError(yaml, node, "Failed to parse null constant"); + } + return Constant(nullptr); + case ConstantKindCase::kBool: + if (absl::EqualsIgnoreCase(value, "true")) { + return Constant(true); + } else if (absl::EqualsIgnoreCase(value, "false")) { + return Constant(false); + } else { + return YamlError(yaml, node, "Failed to parse bool constant"); + } + case ConstantKindCase::kInt: + int64_t int_value; + if (!absl::SimpleAtoi(value, &int_value)) { + return YamlError(yaml, node, "Failed to parse int constant"); + } + return Constant(int_value); + case ConstantKindCase::kUint: + uint64_t uint_value; + if (absl::EndsWith(value, "u")) { + value = value.substr(0, value.size() - 1); + } + if (!absl::SimpleAtoi(value, &uint_value)) { + return YamlError(yaml, node, "Failed to parse uint constant"); + } + return Constant(uint_value); + case ConstantKindCase::kDouble: + double double_value; + if (!absl::SimpleAtod(value, &double_value)) { + return YamlError(yaml, node, "Failed to parse double constant"); + } + return Constant(double_value); + case ConstantKindCase::kBytes: { + if (!IsBinary(node)) { + absl::StatusOr bytes_literal = + internal::ParseBytesLiteral(value); + if (bytes_literal.ok()) { + return Constant(BytesConstant(*bytes_literal)); + } + } + return Constant(BytesConstant(value)); + } + case ConstantKindCase::kString: + return Constant(StringConstant(value)); + case ConstantKindCase::kDuration: { + // Duration is deprecated as a builtin type, but still supported for + // compatibility. + absl::Duration duration_value; + if (!absl::ParseDuration(value, &duration_value)) { + return YamlError(yaml, node, "Failed to parse duration constant"); + } + return Constant(duration_value); + } + case ConstantKindCase::kTimestamp: { + // Timestamp is deprecated as a builtin type, but still supported for + // compatibility. + absl::Time timestamp_value; + std::string error; + // Format: YYYY-MM-DDThh:mm:ssZ + if (!absl::ParseTime("%Y-%m-%d%ET%H:%M:%E*SZ", value, ×tamp_value, + &error)) { + return YamlError( + yaml, node, + absl::StrCat("Failed to parse timestamp constant: ", error, + " supported format: YYYY-MM-DDThh:mm:ssZ")); + } + return Constant(timestamp_value); + } + default: + // This should never happen. + return YamlError(yaml, node, "Constant type is not supported"); + } +} + +absl::Status ParseVariableConfigs(Config& config, absl::string_view yaml, + const YAML::Node& root) { + const YAML::Node variables = root["variables"]; + if (!variables.IsDefined()) { + return absl::OkStatus(); + } + if (!variables.IsSequence()) { + return YamlError(yaml, variables, "Node 'variables' is not a sequence"); + } + + for (const YAML::Node& variable : variables) { + Config::VariableConfig variable_config; + if (!variable || !variable.IsMap()) { + return YamlError(yaml, variable, "Variable is not a map"); + } + const YAML::Node name = variable["name"]; + if (!name || !name.IsScalar()) { + return YamlError(yaml, name, "Variable name is not a string"); + } + variable_config.name = GetString(yaml, name); + const YAML::Node description = variable["description"]; + if (description.IsDefined()) { + if (!description.IsScalar()) { + return YamlError(yaml, description, + "Variable description is not a string"); + } + variable_config.description = GetString(yaml, description); + } + const YAML::Node type = variable["type"]; + Config::TypeInfo type_info; + if (type.IsDefined() && !type.IsScalar()) { + // Old format, type spec is in 'type' instead of directly embedded. + CEL_ASSIGN_OR_RETURN(type_info, ParseTypeInfo(variable["type"], yaml)); + } else { + CEL_ASSIGN_OR_RETURN(type_info, ParseTypeInfo(variable, yaml)); + } + ConstantKindCase constant_kind_case = GetConstantKindCase(type_info.name); + std::string value_str; + YAML::Node value = variable["value"]; + if (value.IsDefined()) { + if (constant_kind_case == ConstantKindCase::kUnspecified) { + return YamlError(yaml, value, + absl::StrCat("Constant type '", type_info.name, + "' is not supported")); + } + if (!value.IsScalar()) { + return YamlError(yaml, value, "Variable value is not a scalar"); + } + if (IsBinary(value)) { + CEL_ASSIGN_OR_RETURN(value_str, GetBinary(yaml, value)); + } else { + value_str = GetString(yaml, value); + } + } + + variable_config.type_info = type_info; + + if (constant_kind_case != ConstantKindCase::kUnspecified && + !value_str.empty()) { + CEL_ASSIGN_OR_RETURN( + variable_config.value, + ParseConstantValue(yaml, value, constant_kind_case, value_str)); + } else if (constant_kind_case == ConstantKindCase::kNull) { + variable_config.value = Constant(nullptr); + } + + CEL_RETURN_IF_ERROR(config.AddVariableConfig(variable_config)); + } + return absl::OkStatus(); +} + +absl::StatusOr ParseFunctionOverloadConfig( + absl::string_view yaml, const YAML::Node& overload, + absl::string_view function_name) { + Config::FunctionOverloadConfig overload_config; + if (!overload || !overload.IsMap()) { + return YamlError(yaml, overload, "Function overload is not a map"); + } + const YAML::Node id = overload["id"]; + if (id.IsDefined()) { + if (!id.IsScalar()) { + return YamlError(yaml, id, "Function overload id is not a string"); + } + overload_config.overload_id = GetString(yaml, id); + } + const YAML::Node examples = overload["examples"]; + if (examples.IsDefined()) { + if (!examples.IsSequence()) { + return YamlError(yaml, examples, + "Function overload examples is not a sequence"); + } + for (const YAML::Node& example : examples) { + if (!example.IsScalar()) { + return YamlError(yaml, example, + "Function overload example is not a string"); + } + overload_config.examples.push_back(GetString(yaml, example)); + } + } + + const YAML::Node signature_node = overload["signature"]; + const YAML::Node target = overload["target"]; + const YAML::Node args = overload["args"]; + if (signature_node.IsDefined()) { + if (!signature_node.IsScalar()) { + return YamlError(yaml, signature_node, + "Function overload signature is not a string"); + } + + if (target.IsDefined()) { + return YamlError(yaml, GetContextNodeForKeyValue(overload, target), + "Function overload signature and target are mutually " + "exclusive"); + } + if (args.IsDefined()) { + return YamlError(yaml, GetContextNodeForKeyValue(overload, args), + "Function overload signature and args are mutually " + "exclusive"); + } + + std::string signature = GetString(yaml, signature_node); + CEL_ASSIGN_OR_RETURN(ParsedFunctionOverload parsed_signature, + ParseFunctionSignature(signature)); + if (parsed_signature.function_name != function_name) { + return YamlError(yaml, signature_node, + absl::StrCat("Function overload name \"", + parsed_signature.function_name, + "\" does not match function name \"", + function_name, "\"")); + } + overload_config.is_member_function = parsed_signature.is_member; + if (overload_config.overload_id.empty()) { + overload_config.overload_id = signature; + } + if (!parsed_signature.signature_type.has_function()) { + return absl::InternalError(absl::StrCat( + "Function overload signature has no function type: ", signature)); + } + const FunctionTypeSpec& function_type_spec = + parsed_signature.signature_type.function(); + for (const auto& arg : function_type_spec.arg_types()) { + CEL_ASSIGN_OR_RETURN(auto type_info, TypeSpecToTypeInfo(arg)); + overload_config.parameters.push_back(std::move(type_info)); + } + } else { + if (target.IsDefined()) { + if (!target.IsMap()) { + return YamlError(yaml, target, "Function overload target is not a map"); + } + CEL_ASSIGN_OR_RETURN(Config::TypeInfo type_info, + ParseTypeInfo(target, yaml)); + overload_config.is_member_function = true; + overload_config.parameters.push_back(type_info); + } + + if (args.IsDefined()) { + if (!args.IsSequence()) { + return YamlError(yaml, args, + "Function overload args is not a sequence"); + } + for (const YAML::Node& arg : args) { + if (!arg.IsMap()) { + return YamlError(yaml, arg, "Function overload arg is not a map"); + } + CEL_ASSIGN_OR_RETURN(Config::TypeInfo type_info, + ParseTypeInfo(arg, yaml)); + overload_config.parameters.push_back(type_info); + } + } + } + const YAML::Node return_type = overload["return"]; + if (return_type.IsDefined()) { + if (return_type.IsScalar()) { + CEL_ASSIGN_OR_RETURN(auto type_spec, + ParseTypeSpec(GetString(yaml, return_type))); + CEL_ASSIGN_OR_RETURN(overload_config.return_type, + TypeSpecToTypeInfo(type_spec)); + } else if (return_type.IsMap()) { + CEL_ASSIGN_OR_RETURN(overload_config.return_type, + ParseTypeInfo(return_type, yaml)); + } else { + return YamlError( + yaml, return_type, + "Function overload return type is neither a string nor a map"); + } + } + return overload_config; +} + +absl::Status ParseFunctionConfigs(Config& config, absl::string_view yaml, + const YAML::Node& root) { + const YAML::Node functions = root["functions"]; + if (!functions.IsDefined()) { + return absl::OkStatus(); + } + if (!functions.IsSequence()) { + return YamlError(yaml, functions, "Node 'functions' is not a sequence"); + } + + for (const YAML::Node& function : functions) { + Config::FunctionConfig function_config; + if (!function || !function.IsMap()) { + return YamlError(yaml, function, "Function is not a map"); + } + const YAML::Node name = function["name"]; + if (!name || !name.IsScalar()) { + return YamlError(yaml, name, "Function name is not a string"); + } + function_config.name = GetString(yaml, name); + const YAML::Node description = function["description"]; + if (description.IsDefined()) { + if (!description.IsScalar()) { + return YamlError(yaml, description, + "Function description is not a string"); + } + function_config.description = GetString(yaml, description); + } + const YAML::Node overloads = function["overloads"]; + if (overloads.IsDefined()) { + if (!overloads.IsSequence()) { + return YamlError(yaml, overloads, + "Function 'overloads' item is not a sequence"); + } + + for (const YAML::Node& overload : overloads) { + CEL_ASSIGN_OR_RETURN( + Config::FunctionOverloadConfig overload_config, + ParseFunctionOverloadConfig(yaml, overload, function_config.name)); + function_config.overload_configs.push_back(std::move(overload_config)); + } + } + + CEL_RETURN_IF_ERROR(config.AddFunctionConfig(function_config)); + } + return absl::OkStatus(); +} + +void EmitContainerConfig(const Config& env_config, YAML::Emitter& out) { + const auto& container_config = env_config.GetContainerConfig(); + if (container_config.IsEmpty()) { + return; + } + + out << YAML::Key << "container"; + if (container_config.abbreviations.empty() && + container_config.aliases.empty()) { + out << YAML::Value << YAML::DoubleQuoted << container_config.name; + } else { + out << YAML::Value << YAML::BeginMap; + if (!container_config.name.empty()) { + out << YAML::Key << "name" << YAML::Value << YAML::DoubleQuoted + << container_config.name; + } + if (!container_config.abbreviations.empty()) { + std::vector sorted_abbrs = container_config.abbreviations; + absl::c_sort(sorted_abbrs); + out << YAML::Key << "abbreviations" << YAML::Value << YAML::BeginSeq; + for (const auto& abbr : sorted_abbrs) { + out << YAML::Value << YAML::DoubleQuoted << abbr; + } + out << YAML::EndSeq; + } + if (!container_config.aliases.empty()) { + std::vector sorted_aliases = + container_config.aliases; + absl::c_sort(sorted_aliases, [](const Config::ContainerConfig::Alias& a, + const Config::ContainerConfig::Alias& b) { + return a.alias < b.alias; + }); + out << YAML::Key << "aliases" << YAML::Value << YAML::BeginSeq; + for (const auto& alias : sorted_aliases) { + out << YAML::BeginMap; + out << YAML::Key << "alias" << YAML::Value << YAML::DoubleQuoted + << alias.alias; + out << YAML::Key << "qualified_name" << YAML::Value + << YAML::DoubleQuoted << alias.qualified_name; + out << YAML::EndMap; + } + out << YAML::EndSeq; + } + out << YAML::EndMap; + } +} + +void EmitExtensionConfigs(const Config& env_config, YAML::Emitter& out) { + if (env_config.GetExtensionConfigs().empty()) { + return; + } + + // Sort the extensions to make the output deterministic. + std::vector sorted_extensions = + env_config.GetExtensionConfigs(); + absl::c_sort(sorted_extensions, [](const Config::ExtensionConfig& a, + const Config::ExtensionConfig& b) { + return a.name < b.name; + }); + out << YAML::Key << "extensions"; + out << YAML::Value << YAML::BeginSeq; + for (const Config::ExtensionConfig& extension_config : sorted_extensions) { + out << YAML::BeginMap; + out << YAML::Key << "name"; + out << YAML::Value << YAML::DoubleQuoted << extension_config.name; + if (extension_config.version != Config::ExtensionConfig::kLatest) { + out << YAML::Key << "version"; + out << YAML::Value << extension_config.version; + } + out << YAML::EndMap; + } + out << YAML::EndSeq; +} + +void EmitMacroList(YAML::Emitter& out, absl::string_view key, + const absl::flat_hash_set& macros) { + if (macros.empty()) { + return; + } + out << YAML::Key << std::string(key); + out << YAML::Value << YAML::BeginSeq; + std::vector sorted_macros(macros.begin(), macros.end()); + absl::c_sort(sorted_macros); + for (const std::string& macro : sorted_macros) { + out << YAML::Value << YAML::DoubleQuoted << macro; + } + out << YAML::EndSeq; +} + +void EmitFunctionList( + YAML::Emitter& out, absl::string_view key, + const absl::flat_hash_set>& functions) { + if (functions.empty()) { + return; + } + + // Build a map from function name to a vector of overload ids. + // Using std::map ensures function names are sorted. + std::map> function_overloads; + for (const auto& pair : functions) { + function_overloads[pair.first].push_back(pair.second); + } + + out << YAML::Key << std::string(key) << YAML::Value << YAML::BeginSeq; + for (auto const& [name, overloads] : function_overloads) { + out << YAML::BeginMap; + out << YAML::Key << "name"; + out << YAML::Value << YAML::DoubleQuoted << name; + + // If the only overload is the empty string, it signifies that all overloads + // of the function are included/excluded. In this case, we don't emit the + // "overloads" key. Otherwise, emit the specific overloads. + if (!(overloads.size() == 1 && overloads[0].empty())) { + // Sort overloads for deterministic output. + std::vector sorted_overloads = overloads; + absl::c_sort(sorted_overloads); + + out << YAML::Key << "overloads" << YAML::Value << YAML::BeginSeq; + for (const std::string& overload : sorted_overloads) { + out << YAML::BeginMap; + out << YAML::Key << "id"; + out << YAML::Value << YAML::DoubleQuoted << overload; + out << YAML::EndMap; + } + out << YAML::EndSeq; + } + out << YAML::EndMap; + } + out << YAML::EndSeq; +} + +void EmitStandardLibraryConfig(const Config& env_config, YAML::Emitter& out) { + const Config::StandardLibraryConfig& standard_library_config = + env_config.GetStandardLibraryConfig(); + if (standard_library_config.IsEmpty()) { + return; + } + + out << YAML::Key << "stdlib" << YAML::Value << YAML::BeginMap; + if (standard_library_config.disable) { + out << YAML::Key << "disable" << YAML::Value << true; + } + if (standard_library_config.disable_macros) { + out << YAML::Key << "disable_macros" << YAML::Value << true; + } + EmitMacroList(out, "include_macros", standard_library_config.included_macros); + EmitMacroList(out, "exclude_macros", standard_library_config.excluded_macros); + EmitFunctionList(out, "include_functions", + standard_library_config.included_functions); + EmitFunctionList(out, "exclude_functions", + standard_library_config.excluded_functions); + out << YAML::EndMap; +} + +void EmitTypeInfo(const Config::TypeInfo& type_info, YAML::Emitter& out, + const EnvConfigToYamlOptions& options) { + // Note: the map is already started when this is called, so we don't emit + // BeginMap here or EndMap at the end. + bool signature_generated = false; + if (options.use_type_signatures) { + absl::StatusOr type_spec = TypeInfoToTypeSpec(type_info); + if (type_spec.ok()) { + absl::StatusOr signature = MakeTypeSpecSignature(*type_spec); + if (signature.ok()) { + out << YAML::Key << "type"; + out << YAML::Value << YAML::DoubleQuoted << *signature; + signature_generated = true; + } + } + } + if (!signature_generated) { + out << YAML::Key << "type_name"; + out << YAML::Value << YAML::DoubleQuoted << type_info.name; + if (type_info.is_type_param) { + out << YAML::Key << "is_type_param" << YAML::Value << true; + } + if (!type_info.params.empty()) { + out << YAML::Key << "params" << YAML::Value << YAML::BeginSeq; + for (const Config::TypeInfo& param : type_info.params) { + out << YAML::BeginMap; + EmitTypeInfo(param, out, options); + out << YAML::EndMap; + } + out << YAML::EndSeq; + } + } +} + +void EmitVariableConfigs(const Config& env_config, YAML::Emitter& out, + const EnvConfigToYamlOptions& options) { + const auto& variable_configs = env_config.GetVariableConfigs(); + if (variable_configs.empty()) { + return; + } + + // Sort variable_configs by name to ensure deterministic output. + std::vector sorted_variable_configs = + variable_configs; + absl::c_sort(sorted_variable_configs, + [](const Config::VariableConfig& a, + const Config::VariableConfig& b) { return a.name < b.name; }); + + out << YAML::Key << "variables"; + out << YAML::Value << YAML::BeginSeq; + for (const Config::VariableConfig& variable_config : + sorted_variable_configs) { + out << YAML::BeginMap; + out << YAML::Key << "name"; + out << YAML::Value << YAML::DoubleQuoted << variable_config.name; + if (!variable_config.description.empty()) { + out << YAML::Key << "description"; + out << YAML::Value << YAML::DoubleQuoted << variable_config.description; + } + EmitTypeInfo(variable_config.type_info, out, options); + if (variable_config.value.has_value()) { + const Constant& constant = variable_config.value; + switch (constant.kind_case()) { + case ConstantKindCase::kUnspecified: + case ConstantKindCase::kNull: + break; + case ConstantKindCase::kBool: + out << YAML::Key << "value" << YAML::Value << constant.bool_value(); + break; + case ConstantKindCase::kInt: + out << YAML::Key << "value" << YAML::Value << constant.int_value(); + break; + case ConstantKindCase::kUint: + out << YAML::Key << "value" << YAML::Value << constant.uint_value(); + break; + case ConstantKindCase::kDouble: + out << YAML::Key << "value" << YAML::Value << constant.double_value(); + break; + case ConstantKindCase::kBytes: { + out << YAML::Key << "value"; + const std::string& bytes_value = constant.bytes_value(); + std::string hex_escaped = "b\""; + for (unsigned char byte : bytes_value) { + absl::StrAppend(&hex_escaped, "\\x"); + absl::StrAppendFormat(&hex_escaped, "%02x", byte); + } + absl::StrAppend(&hex_escaped, "\""); + out << YAML::Value << hex_escaped; + break; + } + case ConstantKindCase::kString: + out << YAML::Key << "value"; + out << YAML::Value << YAML::DoubleQuoted << constant.string_value(); + break; + case ConstantKindCase::kDuration: + out << YAML::Key << "value" << YAML::Value; + // NOLINTNEXTLINE(clang-diagnostic-deprecated-declarations) + out << absl::FormatDuration(constant.duration_value()); + break; + case ConstantKindCase::kTimestamp: + out << YAML::Key << "value" << YAML::Value; + out << absl::FormatTime( + "%Y-%m-%d%ET%H:%M:%E*SZ", + // NOLINTNEXTLINE(clang-diagnostic-deprecated-declarations) + constant.timestamp_value(), absl::UTCTimeZone()); + break; + } + } + out << YAML::EndMap; + } + out << YAML::EndSeq; +} + +void EmitFunctionOverloadConfig( + absl::string_view function_name, + const Config::FunctionOverloadConfig& overload_config, YAML::Emitter& out, + const EnvConfigToYamlOptions& options) { + out << YAML::BeginMap; + bool signature_generated = false; + std::string signature_str; + if (options.use_type_signatures) { + bool param_type_spec_generated = true; + std::vector params; + params.reserve(overload_config.parameters.size()); + for (const auto& parameter : overload_config.parameters) { + absl::StatusOr type_spec = TypeInfoToTypeSpec(parameter); + if (!type_spec.ok()) { + param_type_spec_generated = false; + break; + } + params.push_back(std::move(*type_spec)); + } + if (param_type_spec_generated) { + absl::StatusOr signature = MakeOverloadSignature( + function_name, params, overload_config.is_member_function); + if (signature.ok()) { + signature_str = std::move(*signature); + signature_generated = true; + } + } + } + if (!overload_config.overload_id.empty()) { + if (!signature_generated || overload_config.overload_id != signature_str) { + out << YAML::Key << "id"; + out << YAML::Value << YAML::DoubleQuoted << overload_config.overload_id; + } + } + if (signature_generated) { + out << YAML::Key << "signature"; + out << YAML::Value << YAML::DoubleQuoted << signature_str; + } + if (!signature_generated) { + if (overload_config.is_member_function) { + out << YAML::Key << "target" << YAML::Value; + out << YAML::BeginMap; + if (overload_config.parameters.empty()) { + // This should never happen, but if it does, emit a dynamic type. + EmitTypeInfo({.name = "dyn"}, out, options); + } else { + EmitTypeInfo(overload_config.parameters[0], out, options); + } + out << YAML::EndMap; + if (overload_config.parameters.size() > 1) { + out << YAML::Key << "args"; + out << YAML::Value << YAML::BeginSeq; + for (size_t i = 1; i < overload_config.parameters.size(); ++i) { + out << YAML::BeginMap; + EmitTypeInfo(overload_config.parameters[i], out, options); + out << YAML::EndMap; + } + out << YAML::EndSeq; + } + } else { + if (!overload_config.parameters.empty()) { + out << YAML::Key << "args"; + out << YAML::Value << YAML::BeginSeq; + for (const Config::TypeInfo& parameter : overload_config.parameters) { + out << YAML::BeginMap; + EmitTypeInfo(parameter, out, options); + out << YAML::EndMap; + } + out << YAML::EndSeq; + } + } + } + bool return_type_signature_generated = false; + if (options.use_type_signatures) { + absl::StatusOr type_spec = + TypeInfoToTypeSpec(overload_config.return_type); + if (type_spec.ok()) { + absl::StatusOr signature = MakeTypeSpecSignature(*type_spec); + if (signature.ok()) { + out << YAML::Key << "return"; + out << YAML::Value << YAML::DoubleQuoted << *signature; + return_type_signature_generated = true; + } + } + } + if (!return_type_signature_generated) { + out << YAML::Key << "return"; + out << YAML::Value << YAML::BeginMap; + EmitTypeInfo(overload_config.return_type, out, options); + out << YAML::EndMap; + } + out << YAML::EndMap; +} + +void EmitFunctionConfigs(const Config& env_config, YAML::Emitter& out, + const EnvConfigToYamlOptions& options) { + const std::vector& function_configs = + env_config.GetFunctionConfigs(); + if (function_configs.empty()) { + return; + } + + // Sort function_configs by name to ensure deterministic output. + std::vector sorted_function_configs = + function_configs; + absl::c_sort(sorted_function_configs, + [](const Config::FunctionConfig& a, + const Config::FunctionConfig& b) { return a.name < b.name; }); + + out << YAML::Key << "functions"; + out << YAML::Value << YAML::BeginSeq; + for (const Config::FunctionConfig& function_config : + sorted_function_configs) { + out << YAML::BeginMap; + out << YAML::Key << "name"; + out << YAML::Value << YAML::DoubleQuoted << function_config.name; + if (!function_config.description.empty()) { + out << YAML::Key << "description"; + out << YAML::Value << YAML::DoubleQuoted << function_config.description; + } + if (!function_config.overload_configs.empty()) { + // Sort overloads for deterministic output. + std::vector sorted_overloads = + function_config.overload_configs; + absl::c_sort(sorted_overloads, + [](const Config::FunctionOverloadConfig& a, + const Config::FunctionOverloadConfig& b) { + for (size_t i = 0; i < a.parameters.size(); ++i) { + // Order like this: foo(a), foo(a, b) + if (i >= b.parameters.size()) { + return false; + } + if (CompareTypeInfo(a.parameters[i], b.parameters[i])) { + return true; + } + if (CompareTypeInfo(b.parameters[i], a.parameters[i])) { + return false; + } + } + return false; + }); + + out << YAML::Key << "overloads" << YAML::Value << YAML::BeginSeq; + for (const Config::FunctionOverloadConfig& overload_config : + sorted_overloads) { + EmitFunctionOverloadConfig(function_config.name, overload_config, out, + options); + } + out << YAML::EndSeq; + } + out << YAML::EndMap; + } + out << YAML::EndSeq; +} + +absl::Status ParseContextVariableConfig(Config& config, absl::string_view yaml, + const YAML::Node& root) { + const YAML::Node context_variable = root["context_variable"]; + if (!context_variable.IsDefined()) { + return absl::OkStatus(); + } + if (!context_variable.IsMap()) { + return YamlError(yaml, context_variable, + "Node 'context_variable' is not a map"); + } + + const YAML::Node type_name = context_variable["type_name"]; + const YAML::Node type = context_variable["type"]; + const YAML::Node* type_node = nullptr; + if (type.IsDefined() && type.IsScalar()) { + type_node = &type; + } else if (type_name.IsDefined() && type_name.IsScalar()) { + type_node = &type_name; + } else { + return YamlError(yaml, context_variable, + "Node 'context_variable' does not have a valid type"); + } + ABSL_DCHECK(type_node != nullptr); + config.SetContextType(GetString(yaml, *type_node)); + return absl::OkStatus(); +} + +} // namespace + +absl::StatusOr EnvConfigFromYaml(const std::string& yaml) { + Config config; + CEL_ASSIGN_OR_RETURN(YAML::Node root, LoadYaml(yaml)); + if (!root.IsDefined() || root.IsNull()) { + return config; + } + + if (!root.IsMap()) { + return absl::InvalidArgumentError(FormatYamlErrorMessage( + yaml, "Invalid CEL environment config YAML", root.Mark())); + } + + CEL_RETURN_IF_ERROR(ParseName(config, yaml, root)); + CEL_RETURN_IF_ERROR(ParseContainerConfig(config, yaml, root)); + CEL_RETURN_IF_ERROR(ParseExtensionConfigs(config, yaml, root)); + CEL_RETURN_IF_ERROR(ParseStandardLibraryConfig(config, yaml, root)); + CEL_RETURN_IF_ERROR(ParseContextVariableConfig(config, yaml, root)); + CEL_RETURN_IF_ERROR(ParseVariableConfigs(config, yaml, root)); + CEL_RETURN_IF_ERROR(ParseFunctionConfigs(config, yaml, root)); + return config; +} + +void EnvConfigToYaml(const Config& env_config, std::ostream& os, + const EnvConfigToYamlOptions& options) { + YAML::Emitter out(os); + out.SetIndent(2); + out << YAML::BeginMap; + if (!env_config.GetName().empty()) { + out << YAML::Key << "name"; + out << YAML::Value << YAML::DoubleQuoted << env_config.GetName(); + } + EmitContainerConfig(env_config, out); + EmitExtensionConfigs(env_config, out); + EmitStandardLibraryConfig(env_config, out); + EmitVariableConfigs(env_config, out, options); + EmitFunctionConfigs(env_config, out, options); + out << YAML::EndMap; +} + +} // namespace cel diff --git a/env/env_yaml.h b/env/env_yaml.h new file mode 100644 index 000000000..7bf7bf6b4 --- /dev/null +++ b/env/env_yaml.h @@ -0,0 +1,74 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_ENV_ENV_YAML_H_ +#define THIRD_PARTY_CEL_CPP_ENV_ENV_YAML_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "env/config.h" + +namespace cel { + +// EnvConfigFromYaml creates an environment configuration from a YAML string. +// +// To ensure safety, only pass trusted YAML input. yaml-cpp has some fuzz +// coverage, but its security model is unclear. Additionally, callers should be +// aware that improper CEL configuration can lead to unsafe or unpredictably +// expensive expressions. +absl::StatusOr EnvConfigFromYaml(const std::string& yaml); + +struct EnvConfigToYamlOptions { + // Whether to use type and overload signatures instead of arg/return types in + // the output YAML. + // Example of type signature: "map>" vs + // type_name: "map" + // params: + // - type_name: "int" + // - type_name: "A" + // params: + // - type_name: "B" + // is_type_param: true + // + // Example of overload signature config: + // name: "foo" + // overloads: + // - signature: "timestamp.foo(A<~B>)" + // return: "int" + // vs + // name: "foo" + // overloads: + // - id: "foo_id" + // target: + // type_name: "timestamp" + // args: + // - type_name: "A" + // params: + // - type_name: "B" + // is_type_param: true + // return: + // type_name: "int" + // TODO(uncreated-issue/91): default to true after all dependencies are updated + bool use_type_signatures = false; +}; + +// EnvConfigToYaml serializes an environment configuration as a YAML string. +void EnvConfigToYaml(const Config& env_config, std::ostream& os, + const EnvConfigToYamlOptions& options = {}); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_ENV_ENV_YAML_H_ diff --git a/env/env_yaml_test.cc b/env/env_yaml_test.cc new file mode 100644 index 000000000..c5bd1b787 --- /dev/null +++ b/env/env_yaml_test.cc @@ -0,0 +1,1949 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/env_yaml.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "common/constant.h" +#include "env/config.h" +#include "internal/status_macros.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +using ::absl_testing::StatusIs; +using ::testing::AllOf; +using ::testing::ElementsAreArray; +using ::testing::Field; +using ::testing::HasSubstr; +using ::testing::IsEmpty; +using ::testing::SizeIs; +using ::testing::UnorderedElementsAre; + +TEST(EnvYamlTest, ParseContainerConfig) { + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( + container: "test.container" + )yaml")); + + EXPECT_THAT(config.GetContainerConfig(), + Field(&Config::ContainerConfig::name, "test.container")); +} + +TEST(EnvYamlTest, ParseContainerConfig_AlternativeSyntax) { + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( + container: + name: test.container + abbreviations: + - abbr1.Abbr1 + - abbr2.Abbr2 + aliases: + - alias: alias1 + qualified_name: qual.name1 + - alias: alias2 + qualified_name: qual.name2 + )yaml")); + + const auto& container_config = config.GetContainerConfig(); + EXPECT_EQ(container_config.name, "test.container"); + EXPECT_THAT(container_config.abbreviations, + UnorderedElementsAre("abbr1.Abbr1", "abbr2.Abbr2")); + ASSERT_THAT(container_config.aliases, SizeIs(2)); + EXPECT_EQ(container_config.aliases[0].alias, "alias1"); + EXPECT_EQ(container_config.aliases[0].qualified_name, "qual.name1"); + EXPECT_EQ(container_config.aliases[1].alias, "alias2"); + EXPECT_EQ(container_config.aliases[1].qualified_name, "qual.name2"); +} + +TEST(EnvYamlTest, ParseExtensionConfigs) { + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( + extensions: + - name: "math" + version: latest + - name: "optional" + version: 2 + - name: "strings" + )yaml")); + + EXPECT_THAT(config.GetExtensionConfigs(), + UnorderedElementsAre( + AllOf(Field(&Config::ExtensionConfig::name, "math"), + Field(&Config::ExtensionConfig::version, + Config::ExtensionConfig::kLatest)), + AllOf(Field(&Config::ExtensionConfig::name, "optional"), + Field(&Config::ExtensionConfig::version, 2)), + AllOf(Field(&Config::ExtensionConfig::name, "strings"), + Field(&Config::ExtensionConfig::version, + Config::ExtensionConfig::kLatest)))); +} + +TEST(EnvYamlTest, DefaultExtensionConfigs) { + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( + )yaml")); + + EXPECT_THAT(config.GetExtensionConfigs(), IsEmpty()); +} + +TEST(EnvYamlTest, ParseStdlibConfig_ExclusionStyle) { + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( + stdlib: + disable: true + disable_macros: true + exclude_macros: + - map + - filter + exclude_functions: + - name: "_+_" + overloads: + - id: add_bytes + - id: add_list + - name: "matches" + - name: "timestamp" + overloads: + - id: "string_to_timestamp" + )yaml")); + + const auto& stdlib_config = config.GetStandardLibraryConfig(); + EXPECT_TRUE(stdlib_config.disable); + EXPECT_TRUE(stdlib_config.disable_macros); + EXPECT_THAT(stdlib_config.excluded_macros, + UnorderedElementsAre("map", "filter")); + EXPECT_THAT(stdlib_config.included_macros, IsEmpty()); + EXPECT_THAT( + stdlib_config.excluded_functions, + UnorderedElementsAre(std::make_pair("_+_", "add_bytes"), + std::make_pair("_+_", "add_list"), + std::make_pair("matches", ""), + std::make_pair("timestamp", "string_to_timestamp"))) + << " Actual stdlib config: " << stdlib_config; +} + +TEST(EnvYamlTest, ParseStdlibConfig_InclusionStyle) { + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( + stdlib: + include_macros: + - map + - filter + include_functions: + - name: "_+_" + overloads: + - id: add_bytes + - id: "_+_(list<~A>,list<~A>)" + - name: "matches" + - name: "timestamp" + overloads: + - id: "string_to_timestamp" + )yaml")); + + const auto& stdlib_config = config.GetStandardLibraryConfig(); + EXPECT_THAT(stdlib_config.included_macros, + UnorderedElementsAre("map", "filter")); + EXPECT_THAT( + stdlib_config.included_functions, + UnorderedElementsAre(std::make_pair("_+_", "add_bytes"), + std::make_pair("_+_", "_+_(list<~A>,list<~A>)"), + std::make_pair("matches", ""), + std::make_pair("timestamp", "string_to_timestamp"))) + << " Actual stdlib config: " << stdlib_config; +} + +TEST(EnvYamlTest, ParseVariableConfigs) { + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( + variables: + - name: "msg" + type_name: "google.expr.proto3.test.TestAllTypes" + description: >- + msg represents all possible type permutation which + CEL understands from a proto perspective + )yaml")); + + const Config::VariableConfig& variable_config = + config.GetVariableConfigs()[0]; + EXPECT_EQ(variable_config.name, "msg"); + const auto& type_info = variable_config.type_info; + EXPECT_EQ(type_info.name, "google.expr.proto3.test.TestAllTypes"); + EXPECT_FALSE(type_info.is_type_param); + EXPECT_THAT(type_info.params, IsEmpty()); + EXPECT_EQ(variable_config.description, + "msg represents all possible type permutation which CEL " + "understands from a proto perspective"); +} + +TEST(EnvYamlTest, ParseVariableConfigWithTypeParams) { + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( + variables: + - name: "dict" + type: "map" + )yaml")); + + const Config::VariableConfig& variable_config = + config.GetVariableConfigs()[0]; + EXPECT_EQ(variable_config.name, "dict"); + const auto& type_info = variable_config.type_info; + EXPECT_EQ(type_info.name, "map"); + EXPECT_FALSE(type_info.is_type_param); + EXPECT_THAT(type_info.params, SizeIs(2)); + EXPECT_EQ(type_info.params[0].name, "string"); + EXPECT_FALSE(type_info.params[0].is_type_param); + EXPECT_THAT(type_info.params[0].params, IsEmpty()); + EXPECT_EQ(type_info.params[1].name, "A"); + EXPECT_TRUE(type_info.params[1].is_type_param); + EXPECT_THAT(type_info.params[1].params, IsEmpty()); +} + +TEST(EnvYamlTest, ParseContextVariableConfig) { + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( + context_variable: + type_name: "cel.expr.conformance.proto3.TestAllTypes" + )yaml")); + + EXPECT_EQ(config.GetContextType(), + "cel.expr.conformance.proto3.TestAllTypes"); +} + +TEST(EnvYamlTest, ParseContextVariableConfigAlternativeSyntax) { + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( + context_variable: + type: "cel.expr.conformance.proto3.TestAllTypes" + )yaml")); + + EXPECT_EQ(config.GetContextType(), + "cel.expr.conformance.proto3.TestAllTypes"); +} + +TEST(EnvYamlTest, ParseContextVariableMalformedContextVariable) { + EXPECT_THAT(EnvConfigFromYaml(R"yaml( + context_variable: 123 + + )yaml"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Node 'context_variable' is not a map"))); +} + +TEST(EnvYamlTest, ParseContextVariableMalformedContextVariable2) { + EXPECT_THAT( + EnvConfigFromYaml(R"yaml( + context_variable: + type: + foo: bar + )yaml"), + StatusIs( + absl::StatusCode::kInvalidArgument, + HasSubstr("Node 'context_variable' does not have a valid type"))); +} + +TEST(EnvYamlTest, ParseVariableConfigWithTypeParamsLegacySyntax) { + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( + variables: + - name: "dict" + type_name: "map" + params: + - type_name: "string" + - type_name: "A" + is_type_param: true + )yaml")); + + const Config::VariableConfig& variable_config = + config.GetVariableConfigs()[0]; + EXPECT_EQ(variable_config.name, "dict"); + const auto& type_info = variable_config.type_info; + EXPECT_EQ(type_info.name, "map"); + EXPECT_FALSE(type_info.is_type_param); + EXPECT_THAT(type_info.params, SizeIs(2)); + EXPECT_EQ(type_info.params[0].name, "string"); + EXPECT_FALSE(type_info.params[0].is_type_param); + EXPECT_THAT(type_info.params[0].params, IsEmpty()); + EXPECT_EQ(type_info.params[1].name, "A"); + EXPECT_TRUE(type_info.params[1].is_type_param); + EXPECT_THAT(type_info.params[1].params, IsEmpty()); +} + +TEST(EnvYamlTest, ParseVariableConfigWithNestedRuleOldFormat) { + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( + variables: + - name: "x" + type: + type_name: "int" + )yaml")); + + ASSERT_THAT(config.GetVariableConfigs(), SizeIs(1)); + const Config::VariableConfig& variable_config = + config.GetVariableConfigs()[0]; + EXPECT_EQ(variable_config.name, "x"); + const auto& type_info = variable_config.type_info; + EXPECT_EQ(type_info.name, "int"); + EXPECT_FALSE(type_info.is_type_param); + EXPECT_THAT(type_info.params, IsEmpty()); +} + +struct ParseConstantTestCase { + std::string type; + std::string value; + std::string expected_error; // Empty if no error. + Constant expected_constant; +}; + +class EnvYamlParseConstantTest + : public testing::TestWithParam {}; + +TEST_P(EnvYamlParseConstantTest, EnvYamlParseConstant) { + const ParseConstantTestCase& param = GetParam(); + const std::string yaml = absl::StrFormat( + R"yaml( + variables: + - name: "const" + type: "%s" + value: %s + )yaml", + param.type, param.value); + absl::StatusOr status_or_config = EnvConfigFromYaml(yaml); + if (!param.expected_error.empty()) { + EXPECT_THAT(status_or_config, StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr(param.expected_error))); + return; + } + ASSERT_OK_AND_ASSIGN(Config config, status_or_config); + + const Config::VariableConfig& variable_config = + config.GetVariableConfigs()[0]; + EXPECT_EQ(variable_config.name, "const"); + EXPECT_EQ(variable_config.type_info.name, param.type) << " yaml: " << yaml; + EXPECT_EQ(variable_config.value, param.expected_constant) + << " yaml: " << yaml; +} + +std::vector GetParseConstantTestCases() { + return { + ParseConstantTestCase{ + .type = "null", + .value = "\"\"", + .expected_constant = Constant(nullptr), + }, + ParseConstantTestCase{ + .type = "null", + .value = "anything", + .expected_error = "Failed to parse null constant", + }, + ParseConstantTestCase{ + .type = "bool", + .value = "TRUE", + .expected_constant = Constant(true), + }, + ParseConstantTestCase{ + .type = "bool", + .value = "false", + .expected_constant = Constant(false), + }, + ParseConstantTestCase{ + .type = "bool", + .value = "yes", + .expected_error = "Failed to parse bool constant", + }, + ParseConstantTestCase{ + .type = "int", + .value = "42", + .expected_constant = Constant(int64_t{42}), + }, + ParseConstantTestCase{ + .type = "int", + .value = "41.999", + .expected_error = "Failed to parse int constant", + }, + ParseConstantTestCase{ + .type = "uint", + .value = "42", + .expected_constant = Constant(uint64_t{42}), + }, + ParseConstantTestCase{ + .type = "uint", + .value = "42u", + .expected_constant = Constant(uint64_t{42}), + }, + ParseConstantTestCase{ + .type = "uint", + .value = "-1", + .expected_error = "Failed to parse uint constant", + }, + ParseConstantTestCase{ + .type = "double", + .value = "42.42", + .expected_constant = Constant(42.42), + }, + ParseConstantTestCase{ + .type = "double", + .value = "abc", + .expected_error = "Failed to parse double constant", + }, + ParseConstantTestCase{ + .type = "bytes", + .value = "abc", + .expected_constant = Constant(BytesConstant("abc")), + }, + ParseConstantTestCase{ + .type = "bytes", + .value = "b\"\\xFF\\x00\\x01\"", + .expected_constant = + Constant(BytesConstant(absl::string_view("\xff\x00\x01", 3))), + }, + ParseConstantTestCase{ + .type = "bytes", + .value = "!!binary /wAB", + .expected_constant = + Constant(BytesConstant(absl::string_view("\xff\x00\x01", 3))), + }, + ParseConstantTestCase{ + .type = "bytes", + .value = "!!binary YWJj=", + .expected_error = "Node 'YWJj=' is not a valid Base64 encoded binary", + }, + ParseConstantTestCase{ + .type = "bytes", + .value = "abc", + .expected_constant = Constant(BytesConstant("abc")), + }, + ParseConstantTestCase{ + .type = "string", + .value = "abc", + .expected_constant = Constant(StringConstant("abc")), + }, + ParseConstantTestCase{ + .type = "string", + .value = "\"\\\"abc\\\"\"", + .expected_constant = Constant(StringConstant("\"abc\"")), + }, + ParseConstantTestCase{ + .type = "duration", + .value = "1s", + .expected_constant = Constant(absl::Seconds(1)), + }, + ParseConstantTestCase{ + .type = "duration", + .value = "abc", + .expected_error = "Failed to parse duration constant", + }, + ParseConstantTestCase{ + .type = "timestamp", + .value = "2023-01-01T00:00:00Z", + .expected_constant = Constant(absl::FromUnixSeconds(1672531200)), + }, + ParseConstantTestCase{ + .type = "timestamp", + .value = "abc", + .expected_error = "Failed to parse timestamp constant", + }, + }; +} + +INSTANTIATE_TEST_SUITE_P(EnvYamlParseConstantTest, EnvYamlParseConstantTest, + ::testing::ValuesIn(GetParseConstantTestCases())); + +struct ParseFunctionTestCase { + std::string yaml; + Config::FunctionConfig expected_function_config; +}; + +class EnvYamlParseFunctionTest + : public testing::TestWithParam {}; + +void ExpectTypeInfoEqual(const Config::TypeInfo& actual, + const Config::TypeInfo& expected) { + EXPECT_EQ(actual.name, expected.name); + EXPECT_EQ(actual.is_type_param, expected.is_type_param); + ASSERT_THAT(actual.params, SizeIs(expected.params.size())); + for (size_t i = 0; i < expected.params.size(); ++i) { + ExpectTypeInfoEqual(actual.params[i], expected.params[i]); + } +} + +TEST_P(EnvYamlParseFunctionTest, EnvYamlParseFunction) { + const ParseFunctionTestCase& param = GetParam(); + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(param.yaml)); + + ASSERT_THAT(config.GetFunctionConfigs(), SizeIs(1)); + const Config::FunctionConfig& function_config = + config.GetFunctionConfigs()[0]; + const Config::FunctionConfig& expected = param.expected_function_config; + + EXPECT_EQ(function_config.name, expected.name); + EXPECT_EQ(function_config.description, expected.description); + + ASSERT_THAT(function_config.overload_configs, + SizeIs(expected.overload_configs.size())); + + for (size_t i = 0; i < expected.overload_configs.size(); ++i) { + const auto& actual_overload = function_config.overload_configs[i]; + const auto& expected_overload = expected.overload_configs[i]; + + EXPECT_EQ(actual_overload.overload_id, expected_overload.overload_id); + EXPECT_THAT(actual_overload.examples, + ElementsAreArray(expected_overload.examples)); + EXPECT_EQ(actual_overload.is_member_function, + expected_overload.is_member_function); + + ASSERT_THAT(actual_overload.parameters, + SizeIs(expected_overload.parameters.size())); + for (size_t j = 0; j < expected_overload.parameters.size(); ++j) { + ExpectTypeInfoEqual(actual_overload.parameters[j], + expected_overload.parameters[j]); + } + + ExpectTypeInfoEqual(actual_overload.return_type, + expected_overload.return_type); + } +} + +std::vector GetParseFunctionTestCases() { + return { + ParseFunctionTestCase{ + .yaml = R"yaml( + functions: + - name: "isEmpty" + description: |- + determines whether a list is empty, + or a string has no characters + overloads: + - signature: "google.protobuf.StringValue.isEmpty()" + examples: + - "''.isEmpty() // true" + return: "bool" + - signature: "list<~T>.isEmpty()" + examples: + - "[].isEmpty() // true" + - "[1].isEmpty() // false" + return: "bool" + )yaml", + .expected_function_config = + { + .name = "isEmpty", + .description = "determines whether a list is empty,\nor a " + "string has no characters", + .overload_configs = + { + Config::FunctionOverloadConfig{ + .overload_id = + "google.protobuf.StringValue.isEmpty()", + .examples = {"''.isEmpty() // true"}, + .is_member_function = true, + .parameters = {{.name = "string_wrapper"}}, + .return_type = {.name = "bool"}, + }, + Config::FunctionOverloadConfig{ + .overload_id = "list<~T>.isEmpty()", + .examples = {"[].isEmpty() // true", + "[1].isEmpty() // false"}, + .is_member_function = true, + .parameters = {{.name = "list", + .params = {{.name = "T", + .is_type_param = + true}}}}, + .return_type = {.name = "bool"}, + }, + }, + }, + }, + ParseFunctionTestCase{ + .yaml = R"yaml( + functions: + - name: "isEmpty" + description: |- + determines whether a list is empty, + or a string has no characters + overloads: + - id: "wrapper_string_isEmpty" + examples: + - "''.isEmpty() // true" + target: + type_name: "google.protobuf.StringValue" + return: + type_name: "bool" + - id: "list_isEmpty" + examples: + - "[].isEmpty() // true" + - "[1].isEmpty() // false" + target: + type_name: "list" + params: + - type_name: "T" + is_type_param: true + return: + type_name: "bool" + )yaml", + .expected_function_config = + { + .name = "isEmpty", + .description = "determines whether a list is empty,\nor a " + "string has no characters", + .overload_configs = + { + Config::FunctionOverloadConfig{ + .overload_id = "wrapper_string_isEmpty", + .examples = {"''.isEmpty() // true"}, + .is_member_function = true, + .parameters = + {{.name = "google.protobuf.StringValue"}}, + .return_type = {.name = "bool"}, + }, + Config::FunctionOverloadConfig{ + .overload_id = "list_isEmpty", + .examples = {"[].isEmpty() // true", + "[1].isEmpty() // false"}, + .is_member_function = true, + .parameters = {{.name = "list", + .params = {{.name = "T", + .is_type_param = + true}}}}, + .return_type = {.name = "bool"}, + }, + }, + }, + }, + ParseFunctionTestCase{ + .yaml = R"yaml( + functions: + - name: "contains" + overloads: + - signature: "contains(list<~T>, ~T)" + examples: + - "contains([1, 2, 3], 2) // true" + return: "bool" + )yaml", + .expected_function_config = + { + .name = "contains", + .overload_configs = + { + Config::FunctionOverloadConfig{ + .overload_id = "contains(list<~T>, ~T)", + .examples = {"contains([1, 2, 3], 2) // true"}, + .is_member_function = false, + .parameters = + {{.name = "list", + .params = {{.name = "T", + .is_type_param = true}}}, + {.name = "T", .is_type_param = true}}, + .return_type = {.name = "bool"}, + }, + }, + }, + }, + ParseFunctionTestCase{ + .yaml = R"yaml( + functions: + - name: "contains" + overloads: + - id: "global_contains" + examples: + - "contains([1, 2, 3], 2) // true" + args: + - type_name: "list" + params: + - type_name: "T" + is_type_param: true + - type_name: "T" + is_type_param: true + return: + type_name: "bool" + )yaml", + .expected_function_config = + { + .name = "contains", + .overload_configs = + { + Config::FunctionOverloadConfig{ + .overload_id = "global_contains", + .examples = {"contains([1, 2, 3], 2) // true"}, + .is_member_function = false, + .parameters = + {{.name = "list", + .params = {{.name = "T", + .is_type_param = true}}}, + {.name = "T", .is_type_param = true}}, + .return_type = {.name = "bool"}, + }, + }, + }, + }, + }; +} + +INSTANTIATE_TEST_SUITE_P(EnvYamlParseFunctionTest, EnvYamlParseFunctionTest, + ::testing::ValuesIn(GetParseFunctionTestCases())); + +struct ParseTestCase { + std::string yaml; + std::string expected_error; +}; + +class EnvYamlParseTest : public testing::TestWithParam {}; + +TEST_P(EnvYamlParseTest, EnvYamlSyntaxError) { + const ParseTestCase& param = GetParam(); + absl::StatusOr config = EnvConfigFromYaml(param.yaml); + EXPECT_THAT(config, StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr(param.expected_error))); +} + +INSTANTIATE_TEST_SUITE_P( + EnvYamlParseTest, EnvYamlParseTest, + ::testing::Values( + ParseTestCase{ + .yaml = R"yaml( invalid yaml )yaml", + .expected_error = "1:2: Invalid CEL environment config YAML\n" + "| invalid yaml \n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + name: + - error: "error" + )yaml", + .expected_error = "3:19: Node 'name' is not a string\n" + "| - error: \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + container: + - error: "error" + )yaml", + .expected_error = + "3:19: Node 'container' is neither a string nor a map\n" + "| - error: \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + container: + name: [] + )yaml", + .expected_error = "3:25: Node 'name' in container is not a string\n" + "| name: []\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + container: + abbreviations: "abbr" + )yaml", + .expected_error = "3:34: Node 'abbreviations' is not a sequence\n" + "| abbreviations: \"abbr\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + container: + abbreviations: + - [] + )yaml", + .expected_error = "4:21: Abbreviation is not a string\n" + "| - []\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + container: + aliases: "not a sequence" + )yaml", + .expected_error = "3:28: Node 'aliases' is not a sequence\n" + "| aliases: \"not a sequence\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + container: + aliases: + - "not a map" + )yaml", + .expected_error = "4:21: Alias entry is not a map\n" + "| - \"not a map\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + container: + aliases: + - qualified_name: "qual" + )yaml", + .expected_error = "4:21: Alias entry missing 'alias' string\n" + "| - qualified_name: \"qual\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + container: + aliases: + - alias: "my_alias" + )yaml", + .expected_error = "4:21: Alias entry missing" + " 'qualified_name' string\n" + "| - alias: \"my_alias\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + extensions: + - name: "math" + -name: "optional" + - name: "other" + )yaml", + .expected_error = "5:21: end of map not found\n" + "| - name: \"other\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + extensions: "bar" + )yaml", + .expected_error = "2:27: Node 'extensions' is not a sequence\n" + "| extensions: \"bar\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + extensions: + - name: + - something: "bar" + )yaml", + .expected_error = "4:19: Extension name is not a string\n" + "| - something: \"bar\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + extensions: + - name: "math" + version: last + )yaml", + .expected_error = "4:28: Extension 'math' version is not a valid " + "number or 'latest'\n" + "| version: last\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + extensions: + - name: "math" + version: -15 + )yaml", + .expected_error = "4:28: Extension 'math' version is not a valid " + "number or 'latest'\n" + "| version: -15\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + extensions: + - name: "math" + version: 1 + - name: "math" + version: 2 + )yaml", + .expected_error = "5:19: Extension 'math' version 1 is already " + "included. Cannot also include version 2\n" + "| - name: \"math\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + stdlib: "error" + )yaml", + .expected_error = "2:23: Standard library config ('stdlib') " + "is not a map\n" + "| stdlib: \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + stdlib: + disable: "error" + )yaml", + .expected_error = "3:26: Node 'disable' is not a boolean\n" + "| disable: \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + stdlib: + disable_macros: "error" + )yaml", + .expected_error = "3:33: Node 'disable_macros' is not a boolean\n" + "| disable_macros: \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + stdlib: + exclude_macros: "error" + )yaml", + .expected_error = "3:33: Node 'exclude_macros' is not a sequence\n" + "| exclude_macros: \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + stdlib: + exclude_macros: + - foo: "error" + )yaml", + .expected_error = "4:19: Entry in 'exclude_macros' " + "is not a string\n" + "| - foo: \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + stdlib: + include_functions: "error" + )yaml", + .expected_error = "3:36: Node 'include_functions' " + "is not a sequence\n" + "| include_functions: \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + stdlib: + include_functions: + - "error" + )yaml", + .expected_error = "4:19: Entry in 'include_functions' " + "is not a map\n" + "| - \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + stdlib: + include_functions: + - foo: "error" + )yaml", + .expected_error = "4:19: Function name in not specified in " + "'include_functions'\n" + "| - foo: \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + stdlib: + include_functions: + - name: "foo" + overloads: "error" + )yaml", + .expected_error = "5:30: Overloads in 'include_functions' entry " + "is not a sequence\n" + "| overloads: \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + stdlib: + include_functions: + - name: "foo" + overloads: + - foo_string + )yaml", + .expected_error = "6:21: Overload in 'include_functions' entry " + "is not a map\n" + "| - foo_string\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + stdlib: + include_functions: + - name: "foo" + overloads: + - id: + - foo_int64 + )yaml", + .expected_error = "7:21: Overload id in 'include_functions' entry " + "is not a string\n" + "| - foo_int64\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + variables: + - name: + - type_name: "opaque" + )yaml", + .expected_error = "4:19: Variable name is not a string\n" + "| - type_name: \"opaque\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + variables: + - name: "foo" + type_name: + - params: + )yaml", + .expected_error = "5:21: Node 'type_name' is not a string\n" + "| - params:\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + variables: + - name: "foo" + type_name: "opaque" + params: + - type_name: "int" + - type_name: "A" + is_type_param: maybe + )yaml", + .expected_error = "8:38: Node 'is_type_param' is not a boolean\n" + "| is_type_param: maybe\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + variables: + - name: "foo" + type_name: "opaque" + type: "opaque" + )yaml", + .expected_error = "4:19: Node 'type' and 'type_name'" + " are mutually exclusive\n" + "| type_name: \"opaque\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + variables: + - name: "foo" + type_name: "uint" + value: -1 + )yaml", + .expected_error = "5:26: Failed to parse uint constant\n" + "| value: -1\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + functions: many + )yaml", + .expected_error = "2:26: Node 'functions' is not a sequence\n" + "| functions: many\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + functions: + - name: + - overloads: + )yaml", + .expected_error = "4:19: Function name is not a string\n" + "| - overloads:\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + functions: + - name: "foo" + overloads: "error" + )yaml", + .expected_error = "4:30: Function 'overloads' item " + "is not a sequence\n" + "| overloads: \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + functions: + - name: "foo" + overloads: + - id: + - "error" + )yaml", + .expected_error = "6:25: Function overload id is not a string\n" + "| - \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + functions: + - name: "foo" + overloads: + - id: "foo_int64" + target: + - "error" + )yaml", + .expected_error = "7:25: Function overload target is not a map\n" + "| - \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + functions: + - name: "foo" + overloads: + - id: "foo_int64" + target: + type_name: "Foo" + params: + - type_name: + - is_type_param: true + )yaml", + .expected_error = "10:31: Node 'type_name' is not a string\n" + "| " + "- is_type_param: true\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + functions: + - name: "foo" + overloads: + - id: "foo_int64" + args: "a bunch" + )yaml", + .expected_error = "6:29: Function overload args is not a sequence\n" + "| args: \"a bunch\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + functions: + - name: "foo" + overloads: + - id: "foo_int64" + return: [1] + )yaml", + .expected_error = "6:31: Function overload return type" + " is neither a string nor a map\n" + "| return: [1]\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + functions: + - name: "foo" + overloads: + - id: "foo_int64" + signature: "bar()" + )yaml", + .expected_error = "6:34: Function overload name \"bar\" " + "does not match function name \"foo\"\n" + "| signature: \"bar()\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + functions: + - name: "foo" + overloads: + - signature: [ "foo()" ] + )yaml", + .expected_error = + "5:34: Function overload signature is not a string\n" + "| - signature: [ \"foo()\" ]\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + functions: + - name: "foo" + overloads: + - signature: "foo()" + target: + type_name: "int" + )yaml", + .expected_error = "6:23: Function overload signature and target " + "are mutually exclusive\n" + "| target:\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + functions: + - name: "foo" + overloads: + - signature: "foo()" + args: + - type_name: "int" + )yaml", + .expected_error = "6:23: Function overload signature and args are " + "mutually exclusive\n" + "| args:\n" + "| ^", + })); + +std::string Unindent(std::string_view yaml) { + absl::string_view yaml_view = yaml; + std::vector lines = absl::StrSplit(yaml_view, '\n'); + int indent = -1; + std::vector unindented_lines; + for (auto& line : lines) { + std::size_t pos = line.find_first_not_of(" \t"); + if (pos == std::string::npos) { + // Skip blank lines. + continue; + } + if (indent == -1) { + indent = pos; + } + if (pos >= indent) { + unindented_lines.push_back(line.substr(indent)); + } else { + unindented_lines.push_back(line); + } + } + return absl::StrJoin(unindented_lines, "\n"); +} + +struct ExportTestCase { + absl::StatusOr config; + std::string expected_yaml; + std::string expected_alt_yaml; +}; + +class EnvYamlExportTest : public testing::TestWithParam {}; + +TEST_P(EnvYamlExportTest, EnvYamlExport) { + const ExportTestCase& param = GetParam(); + ASSERT_OK_AND_ASSIGN(Config config, param.config); + std::stringstream ss; + EnvConfigToYaml(config, ss, {.use_type_signatures = true}); + std::string yaml_output = Unindent(ss.str()); + std::string expected_yaml = Unindent(param.expected_yaml); + EXPECT_EQ(yaml_output, expected_yaml); + + if (!param.expected_alt_yaml.empty()) { + std::stringstream alt_ss; + EnvConfigToYaml(config, alt_ss, {.use_type_signatures = false}); + std::string alt_yaml_output = Unindent(alt_ss.str()); + std::string expected_alt_yaml = Unindent(param.expected_alt_yaml); + EXPECT_EQ(alt_yaml_output, expected_alt_yaml); + } +} + +std::vector GetExportTestCases() { + return { + ExportTestCase{ + .config = + []() { + Config config; + config.SetName("test.env"); + config.SetContainerConfig({.name = "test.container"}); + return config; + }(), + .expected_yaml = R"yaml( + name: "test.env" + container: "test.container" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + config.SetName("test.env"); + config.SetContainerConfig({.name = "test.container"}); + return config; + }(), + .expected_yaml = R"yaml( + name: "test.env" + container: "test.container" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + config.SetName("test.env"); + config.SetContainerConfig( + {.name = "test.container", + .abbreviations = {"foo", "bar"}, + .aliases = { + {.alias = "foo", .qualified_name = "test.foo"}, + {.alias = "bar", .qualified_name = "test.bar"}, + }}); + return config; + }(), + .expected_yaml = R"yaml( + name: "test.env" + container: + name: "test.container" + abbreviations: + - "bar" + - "foo" + aliases: + - alias: "bar" + qualified_name: "test.bar" + - alias: "foo" + qualified_name: "test.foo" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR(config.AddExtensionConfig("math")); + CEL_RETURN_IF_ERROR(config.AddExtensionConfig("optional", 2)); + CEL_RETURN_IF_ERROR(config.AddExtensionConfig("bindings")); + return config; + }(), + .expected_yaml = R"yaml( + extensions: + - name: "bindings" + - name: "math" + - name: "optional" + version: 2 + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR( + config.SetStandardLibraryConfig(Config::StandardLibraryConfig{ + .disable = true, + })); + return config; + }(), + .expected_yaml = R"yaml( + stdlib: + disable: true + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR( + config.SetStandardLibraryConfig(Config::StandardLibraryConfig{ + .disable_macros = true, + })); + return config; + }(), + .expected_yaml = R"yaml( + stdlib: + disable_macros: true + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR( + config.SetStandardLibraryConfig(Config::StandardLibraryConfig{ + .excluded_macros = {"map", "filter"}, + })); + return config; + }(), + .expected_yaml = R"yaml( + stdlib: + exclude_macros: + - "filter" + - "map" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR( + config.SetStandardLibraryConfig(Config::StandardLibraryConfig{ + .included_macros = {"map", "filter"}, + })); + return config; + }(), + .expected_yaml = R"yaml( + stdlib: + include_macros: + - "filter" + - "map" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR( + config.SetStandardLibraryConfig(Config::StandardLibraryConfig{ + .excluded_functions = + { + std::make_pair("timestamp", "string_to_timestamp"), + std::make_pair("_+_", "add_list"), + std::make_pair("matches", ""), + std::make_pair("_+_", "add_bytes"), + }, + })); + return config; + }(), + .expected_yaml = R"yaml( + stdlib: + exclude_functions: + - name: "_+_" + overloads: + - id: "add_bytes" + - id: "add_list" + - name: "matches" + - name: "timestamp" + overloads: + - id: "string_to_timestamp" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR( + config.SetStandardLibraryConfig(Config::StandardLibraryConfig{ + .included_functions = + { + std::make_pair("timestamp", "string_to_timestamp"), + std::make_pair("_+_", "_+_(list<~A>,list<~A>)"), + std::make_pair("matches", ""), + std::make_pair("_+_", "_+_(bytes,bytes)"), + }, + })); + return config; + }(), + .expected_yaml = R"yaml( + stdlib: + include_functions: + - name: "_+_" + overloads: + - id: "_+_(bytes,bytes)" + - id: "_+_(list<~A>,list<~A>)" + - name: "matches" + - name: "timestamp" + overloads: + - id: "string_to_timestamp" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR( + config.AddVariableConfig({.name = "foo", + .type_info = {.name = "null"}, + .value = Constant(nullptr)})); + return config; + }(), + .expected_yaml = R"yaml( + variables: + - name: "foo" + type: "null" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR( + config.AddVariableConfig({.name = "foo", + .type_info = {.name = "bool"}, + .value = Constant(true)})); + return config; + }(), + .expected_yaml = R"yaml( + variables: + - name: "foo" + type: "bool" + value: true + )yaml", + .expected_alt_yaml = R"yaml( + variables: + - name: "foo" + type_name: "bool" + value: true + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR( + config.AddVariableConfig({.name = "foo", + .type_info = {.name = "int"}, + .value = Constant(int64_t{42})})); + return config; + }(), + .expected_yaml = R"yaml( + variables: + - name: "foo" + type: "int" + value: 42 + )yaml", + .expected_alt_yaml = R"yaml( + variables: + - name: "foo" + type_name: "int" + value: 42 + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR( + config.AddVariableConfig({.name = "foo", + .type_info = {.name = "uint"}, + .value = Constant(uint64_t{777})})); + return config; + }(), + .expected_yaml = R"yaml( + variables: + - name: "foo" + type: "uint" + value: 777 + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR( + config.AddVariableConfig({.name = "foo", + .type_info = {.name = "double"}, + .value = Constant(0.75)})); + return config; + }(), + .expected_yaml = R"yaml( + variables: + - name: "foo" + type: "double" + value: 0.75 + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR(config.AddVariableConfig( + {.name = "foo", + .type_info = {.name = "bytes"}, + .value = Constant( + BytesConstant(absl::string_view("\xff\x00\x01", 3)))})); + return config; + }(), + .expected_yaml = R"yaml( + variables: + - name: "foo" + type: "bytes" + value: b"\xff\x00\x01" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + Constant c; + c.set_string_value("'single' \"double\""); + CEL_RETURN_IF_ERROR(config.AddVariableConfig( + {.name = "foo", + .type_info = {.name = "string"}, + .value = Constant(StringConstant("'single' \"double\""))})); + return config; + }(), + .expected_yaml = R"yaml( + variables: + - name: "foo" + type: "string" + value: "'single' \"double\"" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR(config.AddVariableConfig( + {.name = "foo", + .type_info = {.name = "duration"}, + .value = Constant(absl::Hours(1) + absl::Minutes(2) + + absl::Seconds(3))})); + return config; + }(), + .expected_yaml = R"yaml( + variables: + - name: "foo" + type: "duration" + value: 1h2m3s + )yaml", + .expected_alt_yaml = R"yaml( + variables: + - name: "foo" + type_name: "duration" + value: 1h2m3s + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR(config.AddVariableConfig( + {.name = "foo", + .type_info = {.name = "timestamp"}, + .value = Constant(absl::FromUnixSeconds(1767323045))})); + return config; + }(), + .expected_yaml = R"yaml( + variables: + - name: "foo" + type: "timestamp" + value: 2026-01-02T03:04:05Z + )yaml", + .expected_alt_yaml = R"yaml( + variables: + - name: "foo" + type_name: "timestamp" + value: 2026-01-02T03:04:05Z + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR(config.AddVariableConfig( + {.name = "foo", + .type_info = {.name = + "google.expr.proto3.test.TestAllTypes"}})); + return config; + }(), + .expected_yaml = R"yaml( + variables: + - name: "foo" + type: "google.expr.proto3.test.TestAllTypes" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR(config.AddVariableConfig( + {.name = "foo", + .type_info = { + .name = "A", + .params = {{.name = "int"}, + {.name = "B", .is_type_param = true}}}})); + return config; + }(), + .expected_yaml = R"yaml( + variables: + - name: "foo" + type: "A" + )yaml", + .expected_alt_yaml = R"yaml( + variables: + - name: "foo" + type_name: "A" + params: + - type_name: "int" + - type_name: "B" + is_type_param: true + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR(config.AddFunctionConfig({.name = "foo"})); + return config; + }(), + .expected_yaml = R"yaml( + functions: + - name: "foo" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR(config.AddFunctionConfig( + {.name = "foo", + .overload_configs = { + {.overload_id = "foo_overload_id", + .is_member_function = true, + .parameters = {{.name = "timestamp"}, + {.name = "A", + .params = {{.name = "B", + .is_type_param = true}}}}, + .return_type = {.name = "int"}}, + }})); + return config; + }(), + .expected_yaml = R"yaml( + functions: + - name: "foo" + overloads: + - id: "foo_overload_id" + signature: "timestamp.foo(A<~B>)" + return: "int" + )yaml", + .expected_alt_yaml = R"yaml( + functions: + - name: "foo" + overloads: + - id: "foo_overload_id" + target: + type_name: "timestamp" + args: + - type_name: "A" + params: + - type_name: "B" + is_type_param: true + return: + type_name: "int" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR(config.AddFunctionConfig( + {.name = "foo", + .description = "my desc", + .overload_configs = { + {.overload_id = "foo_overload_a", + .parameters = {{.name = "timestamp"}}, + .return_type = {.name = "list", + .params = {{.name = "int"}}}}, + {.overload_id = "foo_overload_b", + .parameters = {{.name = "double"}, + {.name = "A", .params = {{.name = "B"}}}}, + .return_type = {.name = "string"}}, + }})); + return config; + }(), + .expected_yaml = R"yaml( + functions: + - name: "foo" + description: "my desc" + overloads: + - id: "foo_overload_b" + signature: "foo(double,A)" + return: "string" + - id: "foo_overload_a" + signature: "foo(timestamp)" + return: "list" + )yaml", + .expected_alt_yaml = R"yaml( + functions: + - name: "foo" + description: "my desc" + overloads: + - id: "foo_overload_b" + args: + - type_name: "double" + - type_name: "A" + params: + - type_name: "B" + return: + type_name: "string" + - id: "foo_overload_a" + args: + - type_name: "timestamp" + return: + type_name: "list" + params: + - type_name: "int" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR(config.AddFunctionConfig( + {.name = "foo", + .overload_configs = { + {.overload_id = "timestamp.foo(A<~B>)", + .is_member_function = true, + .parameters = {{.name = "timestamp"}, + {.name = "A", + .params = {{.name = "B", + .is_type_param = true}}}}, + .return_type = {.name = "int"}}, + }})); + return config; + }(), + .expected_yaml = R"yaml( + functions: + - name: "foo" + overloads: + - signature: "timestamp.foo(A<~B>)" + return: "int" + )yaml", + .expected_alt_yaml = R"yaml( + functions: + - name: "foo" + overloads: + - id: "timestamp.foo(A<~B>)" + target: + type_name: "timestamp" + args: + - type_name: "A" + params: + - type_name: "B" + is_type_param: true + return: + type_name: "int" + )yaml", + }, + }; +}; + +INSTANTIATE_TEST_SUITE_P(EnvYamlExportTest, EnvYamlExportTest, + ::testing::ValuesIn(GetExportTestCases())); + +class EnvYamlStructuredRoundTripTest + : public testing::TestWithParam {}; + +TEST_P(EnvYamlStructuredRoundTripTest, EnvYamlRoundTrip) { + const std::string& yaml = Unindent(GetParam()); + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(yaml)); + + std::stringstream ss; + EnvConfigToYaml(config, ss); + EXPECT_EQ(ss.str(), yaml); +} + +std::vector GetStructuredRoundTripTestCases() { + return { + R"yaml( + stdlib: + disable: true + disable_macros: true + )yaml", + R"yaml( + name: "test.env" + container: "common.proto.prefix" + extensions: + - name: "math" + version: 0 + - name: "optional" + version: 2 + stdlib: + include_macros: + - "filter" + - "map" + include_functions: + - name: "_+_" + overloads: + - id: "add_bytes" + - id: "add_list" + - name: "matches" + - name: "timestamp" + overloads: + - id: "string_to_timestamp" + )yaml", + R"yaml( + container: + name: "test.container" + abbreviations: + - "abbr1.Abbr1" + - "abbr2.Abbr2" + aliases: + - alias: "alias1" + qualified_name: "qual.name1" + - alias: "alias2" + qualified_name: "qual.name2" + )yaml", + R"yaml( + extensions: + - name: "bindings" + - name: "math" + stdlib: + exclude_macros: + - "filter" + - "map" + exclude_functions: + - name: "_+_" + overloads: + - id: "add_bytes" + - id: "add_list" + - name: "matches" + - name: "timestamp" + overloads: + - id: "string_to_timestamp" + )yaml", + R"yaml( + functions: + - name: "bar" + - name: "foo" + )yaml", + }; +} + +INSTANTIATE_TEST_SUITE_P( + EnvYamlStructuredRoundTripTest, EnvYamlStructuredRoundTripTest, + ::testing::ValuesIn(GetStructuredRoundTripTestCases())); + +class EnvYamlSignatureRoundTripTest + : public testing::TestWithParam {}; + +TEST_P(EnvYamlSignatureRoundTripTest, EnvYamlRoundTrip) { + const std::string& yaml = Unindent(GetParam()); + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(yaml)); + + std::stringstream ss; + EnvConfigToYaml(config, ss, {.use_type_signatures = true}); + EXPECT_EQ(ss.str(), yaml); +} + +std::vector GetSignatureRoundTripTestCases() { + return { + R"yaml( + variables: + - name: "a" + type: "null" + - name: "b" + type: "bool" + value: true + - name: "c" + type: "int" + value: 42 + - name: "d" + type: "uint" + value: 777 + - name: "e" + type: "double" + value: 0.75 + - name: "f" + type: "bytes" + value: b"\xff\x00\x01" + - name: "g" + type: "string" + value: "plain 'single' \"double\"" + - name: "h" + type: "duration" + value: 1h2m3s + - name: "i" + type: "timestamp" + value: 2026-01-02T03:04:05Z + )yaml", + R"yaml( + functions: + - name: "foo" + overloads: + - id: "foo_overload_id" + signature: "timestamp.foo(A<~B>)" + return: "int" + )yaml", + R"yaml( + functions: + - name: "foo" + overloads: + - id: "foo_overload_id" + signature: "foo(timestamp,A<~B>)" + return: "list" + )yaml", + R"yaml( + functions: + - name: "foo" + overloads: + - signature: "timestamp.foo(A<~B>)" + return: "int" + )yaml", + }; +} + +INSTANTIATE_TEST_SUITE_P(EnvYamlSignatureRoundTripTest, + EnvYamlSignatureRoundTripTest, + ::testing::ValuesIn(GetSignatureRoundTripTestCases())); + +} // namespace +} // namespace cel diff --git a/env/internal/BUILD b/env/internal/BUILD new file mode 100644 index 000000000..ec4a0b15c --- /dev/null +++ b/env/internal/BUILD @@ -0,0 +1,87 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "ext_registry", + srcs = ["ext_registry.cc"], + hdrs = ["ext_registry.h"], + deps = [ + "//compiler", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "runtime_ext_registry", + srcs = ["runtime_ext_registry.cc"], + hdrs = ["runtime_ext_registry.h"], + deps = [ + "//runtime:runtime_builder", + "//runtime:runtime_options", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "ext_registry_test", + srcs = ["ext_registry_test.cc"], + deps = [ + ":ext_registry", + "//checker:type_checker_builder", + "//compiler", + "//internal:testing", + "//parser:parser_interface", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + ], +) + +cc_test( + name = "runtime_ext_registry_test", + srcs = ["runtime_ext_registry_test.cc"], + deps = [ + ":runtime_ext_registry", + "//common:ast", + "//common:source", + "//common:value", + "//common:value_testing", + "//internal:status_macros", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//parser", + "//parser:options", + "//parser:parser_interface", + "//runtime", + "//runtime:activation", + "//runtime:function", + "//runtime:function_adapter", + "//runtime:function_registry", + "//runtime:runtime_builder", + "//runtime:runtime_builder_factory", + "//runtime:runtime_options", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_protobuf//:protobuf", + ], +) diff --git a/env/internal/ext_registry.cc b/env/internal/ext_registry.cc new file mode 100644 index 000000000..b32239ac3 --- /dev/null +++ b/env/internal/ext_registry.cc @@ -0,0 +1,63 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/internal/ext_registry.h" + +#include + +#include "absl/functional/any_invocable.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "compiler/compiler.h" + +namespace cel { +namespace env_internal { + +void ExtensionRegistry::RegisterCompilerLibrary( + absl::string_view name, absl::string_view alias, int version, + absl::AnyInvocable library_factory) { + library_registry_.push_back( + LibraryRegistration(name, alias, version, std::move(library_factory))); +} + +absl::StatusOr ExtensionRegistry::GetCompilerLibrary( + absl::string_view name, int version) const { + if (version == kLatest) { + int max_version = -1; + for (const auto& registration : library_registry_) { + if ((registration.name_ == name || registration.alias_ == name) && + registration.version_ > max_version) { + max_version = registration.version_; + } + } + if (max_version == -1) { + return absl::NotFoundError( + absl::StrCat("CompilerLibrary not registered: ", name)); + } + version = max_version; + } + for (const auto& registration : library_registry_) { + if ((registration.name_ == name || registration.alias_ == name) && + registration.version_ == version) { + return registration.GetLibrary(); + } + } + + return absl::NotFoundError( + absl::StrCat("CompilerLibrary not registered: ", name, "#", version)); +} +} // namespace env_internal +} // namespace cel diff --git a/env/internal/ext_registry.h b/env/internal/ext_registry.h new file mode 100644 index 000000000..ab5b67a24 --- /dev/null +++ b/env/internal/ext_registry.h @@ -0,0 +1,74 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_ENV_INTERNAL_EXT_REGISTRY_H_ +#define THIRD_PARTY_CEL_CPP_ENV_INTERNAL_EXT_REGISTRY_H_ + +#include +#include +#include +#include + +#include "absl/functional/any_invocable.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "compiler/compiler.h" + +namespace cel { +namespace env_internal { + +// A registry for CEL compiler extension libraries. +// +// Used to register and retrieve CompilerLibraries by name (or alias) and +// version. +class ExtensionRegistry { + public: + static constexpr int kLatest = std::numeric_limits::max(); + + void RegisterCompilerLibrary( + absl::string_view name, absl::string_view alias, int version, + absl::AnyInvocable library_factory); + + absl::StatusOr GetCompilerLibrary(absl::string_view name, + int version) const; + + private: + class LibraryRegistration final { + public: + LibraryRegistration( + absl::string_view name, absl::string_view alias, int version, + absl::AnyInvocable library_factory) + : name_(name), + alias_(!alias.empty() ? alias : name), + version_(version), + factory_(std::move(library_factory)) {} + + CompilerLibrary GetLibrary() const { return factory_(); } + + private: + std::string name_; + std::string alias_; + int version_; + absl::AnyInvocable factory_; + + friend class ExtensionRegistry; + }; + + std::vector library_registry_; +}; + +} // namespace env_internal +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_ENV_INTERNAL_EXT_REGISTRY_H_ diff --git a/env/internal/ext_registry_test.cc b/env/internal/ext_registry_test.cc new file mode 100644 index 000000000..9e345c781 --- /dev/null +++ b/env/internal/ext_registry_test.cc @@ -0,0 +1,73 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/internal/ext_registry.h" + +#include + +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "checker/type_checker_builder.h" +#include "compiler/compiler.h" +#include "internal/testing.h" +#include "parser/parser_interface.h" + +namespace cel::env_internal { +namespace { + +using ::absl_testing::IsOkAndHolds; +using ::absl_testing::StatusIs; +using ::testing::Field; +using ::testing::HasSubstr; + +TEST(ExtensionRegistryTest, GetCompilerLibrary) { + ExtensionRegistry registry; + registry.RegisterCompilerLibrary("foo1", "f", 1, []() { + return CompilerLibrary("foo1_1", nullptr, nullptr); + }); + registry.RegisterCompilerLibrary("foo1", "f", 2, []() { + return CompilerLibrary("foo1_2", nullptr, nullptr); + }); + registry.RegisterCompilerLibrary("foo2", "", 1, []() { + return CompilerLibrary("foo2_1", nullptr, nullptr); + }); + + EXPECT_THAT(registry.GetCompilerLibrary("foo1", 1), + IsOkAndHolds(Field(&CompilerLibrary::id, "foo1_1"))); + EXPECT_THAT(registry.GetCompilerLibrary("f", 1), + IsOkAndHolds(Field(&CompilerLibrary::id, "foo1_1"))); + EXPECT_THAT(registry.GetCompilerLibrary("foo1", 2), + IsOkAndHolds(Field(&CompilerLibrary::id, "foo1_2"))); + EXPECT_THAT(registry.GetCompilerLibrary("foo1", ExtensionRegistry::kLatest), + IsOkAndHolds(Field(&CompilerLibrary::id, "foo1_2"))); + EXPECT_THAT(registry.GetCompilerLibrary("f", ExtensionRegistry::kLatest), + IsOkAndHolds(Field(&CompilerLibrary::id, "foo1_2"))); + EXPECT_THAT(registry.GetCompilerLibrary("foo2", 1), + IsOkAndHolds(Field(&CompilerLibrary::id, "foo2_1"))); + EXPECT_THAT(registry.GetCompilerLibrary("foo2", ExtensionRegistry::kLatest), + IsOkAndHolds(Field(&CompilerLibrary::id, "foo2_1"))); + + EXPECT_THAT(registry.GetCompilerLibrary("foo1", 3), + StatusIs(absl::StatusCode::kNotFound, + HasSubstr("CompilerLibrary not registered: foo1#3"))); + EXPECT_THAT(registry.GetCompilerLibrary("foo3", 1), + StatusIs(absl::StatusCode::kNotFound, + HasSubstr("CompilerLibrary not registered: foo3"))); + EXPECT_THAT(registry.GetCompilerLibrary("foo3", ExtensionRegistry::kLatest), + StatusIs(absl::StatusCode::kNotFound, + HasSubstr("CompilerLibrary not registered: foo3"))); +} + +} // namespace +} // namespace cel::env_internal diff --git a/env/internal/runtime_ext_registry.cc b/env/internal/runtime_ext_registry.cc new file mode 100644 index 000000000..dc78a38e3 --- /dev/null +++ b/env/internal/runtime_ext_registry.cc @@ -0,0 +1,64 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/internal/runtime_ext_registry.h" + +#include + +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_options.h" + +namespace cel { +namespace env_internal { + +void RuntimeExtensionRegistry::AddFunctionRegistration( + absl::string_view name, absl::string_view alias, int version, + FunctionRegistrationCallback function_registration_callback) { + registry_.push_back(Registration(name, alias, version, + std::move(function_registration_callback))); +} + +absl::Status RuntimeExtensionRegistry::RegisterExtensionFunctions( + RuntimeBuilder& runtime_builder, const RuntimeOptions& runtime_options, + absl::string_view name, int version) const { + if (version == kLatest) { + int max_version = -1; + for (const Registration& registration : registry_) { + if ((registration.name_ == name || registration.alias_ == name) && + registration.version_ > max_version) { + max_version = registration.version_; + } + } + if (max_version == -1) { + return absl::NotFoundError(absl::StrCat( + "Runtime functions are not registered for extension: ", name)); + } + version = max_version; + } + for (const Registration& registration : registry_) { + if ((registration.name_ == name || registration.alias_ == name) && + registration.version_ == version) { + return registration.RegisterExtensionFunctions(runtime_builder, + runtime_options); + } + } + + return absl::NotFoundError(absl::StrCat( + "Runtime functions are not registered for extension: ", name)); +} +} // namespace env_internal +} // namespace cel diff --git a/env/internal/runtime_ext_registry.h b/env/internal/runtime_ext_registry.h new file mode 100644 index 000000000..67838519f --- /dev/null +++ b/env/internal/runtime_ext_registry.h @@ -0,0 +1,84 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_ENV_RUNTIME_EXT_REGISTRY_H_ +#define THIRD_PARTY_CEL_CPP_ENV_RUNTIME_EXT_REGISTRY_H_ + +#include +#include +#include +#include + +#include "absl/functional/any_invocable.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_options.h" + +namespace cel { +namespace env_internal { + +using FunctionRegistrationCallback = absl::AnyInvocable; + +// A registry for CEL runtime extension functions. +// +// Used to register runtime functions for extensions by name (or alias) and +// version. +class RuntimeExtensionRegistry { + public: + static constexpr int kLatest = std::numeric_limits::max(); + + void AddFunctionRegistration( + absl::string_view name, absl::string_view alias, int version, + FunctionRegistrationCallback function_registration_callback); + + absl::Status RegisterExtensionFunctions(RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options, + absl::string_view name, + int version) const; + + private: + class Registration final { + public: + Registration(absl::string_view name, absl::string_view alias, int version, + FunctionRegistrationCallback function_registration_callback) + : name_(name), + alias_(!alias.empty() ? alias : name), + version_(version), + function_registration_callback_( + std::move(function_registration_callback)) {} + + absl::Status RegisterExtensionFunctions( + RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) const { + return function_registration_callback_(runtime_builder, runtime_options); + } + + private: + std::string name_; + std::string alias_; + int version_; + FunctionRegistrationCallback function_registration_callback_; + + friend class RuntimeExtensionRegistry; + }; + + std::vector registry_; +}; + +} // namespace env_internal +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_ENV_RUNTIME_EXT_REGISTRY_H_ diff --git a/env/internal/runtime_ext_registry_test.cc b/env/internal/runtime_ext_registry_test.cc new file mode 100644 index 000000000..c6125d20f --- /dev/null +++ b/env/internal/runtime_ext_registry_test.cc @@ -0,0 +1,126 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/internal/runtime_ext_registry.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "common/ast.h" +#include "common/source.h" +#include "common/value.h" +#include "common/value_testing.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "parser/options.h" +#include "parser/parser.h" +#include "parser/parser_interface.h" +#include "runtime/activation.h" +#include "runtime/function.h" +#include "runtime/function_adapter.h" +#include "runtime/runtime.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_builder_factory.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" + +namespace cel::env_internal { +namespace { + +using ::absl_testing::IsOkAndHolds; +using ::cel::test::StringValueIs; + +Value Hello1(const StringValue& input, const Function::InvokeContext& context) { + return StringValue::From("Hello, old " + input.ToString() + "!", + context.arena()); +} + +Value Hello2(const StringValue& input, const Function::InvokeContext& context) { + return StringValue::From("Hello, new " + input.ToString() + "!", + context.arena()); +} + +RuntimeExtensionRegistry GetRuntimeExtensionRegistry() { + RuntimeExtensionRegistry registry; + registry.AddFunctionRegistration( + "hello_extension", "hello_extension_alias", 1, + [](RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) -> absl::Status { + return cel::UnaryFunctionAdapter:: + RegisterGlobalOverload("hello", &Hello1, + runtime_builder.function_registry()); + }); + registry.AddFunctionRegistration( + "hello_extension", "hello_extension_alias", 2, + [](RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) -> absl::Status { + return cel::UnaryFunctionAdapter:: + RegisterMemberOverload("hello", &Hello2, + runtime_builder.function_registry()); + }); + return registry; +} + +class RuntimeExtensionRegistryTest : public testing::Test { + protected: + absl::StatusOr Run(std::string_view extension_name, int version, + std::string_view expr) { + const RuntimeExtensionRegistry registry = GetRuntimeExtensionRegistry(); + + CEL_ASSIGN_OR_RETURN(std::unique_ptr parser, + NewParserBuilder(ParserOptions())->Build()); + + CEL_ASSIGN_OR_RETURN(std::unique_ptr source, NewSource(expr, "")); + CEL_ASSIGN_OR_RETURN(std::unique_ptr ast, parser->Parse(*source)); + + auto descriptor_pool = cel::internal::GetSharedTestingDescriptorPool(); + cel::RuntimeOptions runtime_options; + CEL_ASSIGN_OR_RETURN( + cel::RuntimeBuilder runtime_builder, + cel::CreateRuntimeBuilder(descriptor_pool, runtime_options)); + + CEL_RETURN_IF_ERROR(registry.RegisterExtensionFunctions( + runtime_builder, runtime_options, extension_name, version)); + + CEL_ASSIGN_OR_RETURN(std::unique_ptr runtime, + std::move(runtime_builder).Build()); + CEL_ASSIGN_OR_RETURN(std::unique_ptr program, + runtime->CreateProgram(std::move(ast))); + + Activation activation; + return program->Evaluate(&arena_, activation); + } + + private: + google::protobuf::Arena arena_; +}; + +TEST_F(RuntimeExtensionRegistryTest, SpecificExtensionVersion) { + EXPECT_THAT(Run("hello_extension", 1, "hello('world')"), + IsOkAndHolds(StringValueIs("Hello, old world!"))); +} + +TEST_F(RuntimeExtensionRegistryTest, LatestExtensionVersion) { + EXPECT_THAT(Run("hello_extension_alias", RuntimeExtensionRegistry::kLatest, + "'world'.hello()"), + IsOkAndHolds(StringValueIs("Hello, new world!"))); +} + +} // namespace +} // namespace cel::env_internal diff --git a/env/runtime_std_extensions.cc b/env/runtime_std_extensions.cc new file mode 100644 index 000000000..b866a5965 --- /dev/null +++ b/env/runtime_std_extensions.cc @@ -0,0 +1,133 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/runtime_std_extensions.h" + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "checker/optional.h" +#include "env/env_runtime.h" +#include "env/internal/runtime_ext_registry.h" +#include "extensions/encoders.h" +#include "extensions/lists_functions.h" +#include "extensions/math_ext.h" +#include "extensions/math_ext_decls.h" +#include "extensions/regex_ext.h" +#include "extensions/sets_functions.h" +#include "extensions/strings.h" +#include "runtime/optional_types.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_options.h" + +namespace cel { + +void RegisterStandardExtensions(EnvRuntime& env_runtime) { + env_internal::RuntimeExtensionRegistry& registry = + env_runtime.GetRuntimeExtensionRegistry(); + registry.AddFunctionRegistration( + "cel.lib.ext.bindings", "bindings", 0, + [](RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) -> absl::Status { + // No runtime functions to register. + return absl::OkStatus(); + }); + + registry.AddFunctionRegistration( + "cel.lib.ext.encoders", "encoders", 0, + [](RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) -> absl::Status { + return cel::extensions::RegisterEncodersFunctions( + runtime_builder.function_registry(), runtime_options); + }); + + for (int version = 0; version <= extensions::kListsExtensionLatestVersion; + ++version) { + registry.AddFunctionRegistration( + "cel.lib.ext.lists", "lists", version, + [version](RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) -> absl::Status { + return cel::extensions::RegisterListsFunctions( + runtime_builder.function_registry(), runtime_options, version); + }); + } + + for (int version = 0; version <= extensions::kMathExtensionLatestVersion; + ++version) { + registry.AddFunctionRegistration( + "cel.lib.ext.math", "math", version, + [version](RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) -> absl::Status { + return cel::extensions::RegisterMathExtensionFunctions( + runtime_builder.function_registry(), runtime_options, version); + }); + } + + for (int version = 0; version <= cel::kOptionalExtensionLatestVersion; + ++version) { + registry.AddFunctionRegistration( + "cel.lib.ext.optional", "optional", version, + [](RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) -> absl::Status { + return cel::extensions::EnableOptionalTypes(runtime_builder); + }); + } + + registry.AddFunctionRegistration( + "cel.lib.ext.protos", "protos", 0, + [](RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) -> absl::Status { + // No runtime functions to register. + return absl::OkStatus(); + }); + + registry.AddFunctionRegistration( + "cel.lib.ext.sets", "sets", 0, + [](RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) -> absl::Status { + return cel::extensions::RegisterSetsFunctions( + runtime_builder.function_registry(), runtime_options); + }); + + for (int version = 0; version <= extensions::kStringsExtensionLatestVersion; + ++version) { + registry.AddFunctionRegistration( + "cel.lib.ext.strings", "strings", version, + [version](RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) -> absl::Status { + cel::extensions::StringsExtensionOptions strings_options; + strings_options.version = version; + return cel::extensions::RegisterStringsFunctions( + runtime_builder.function_registry(), runtime_options, + strings_options); + }); + } + + registry.AddFunctionRegistration( + "cel.lib.ext.comprev2", "two-var-comprehensions", 0, + [](RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) -> absl::Status { + // No runtime functions to register. + return absl::OkStatus(); + }); + + registry.AddFunctionRegistration( + "cel.lib.ext.regex", "regex", 0, + [](RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) -> absl::Status { + return cel::extensions::RegisterRegexExtensionFunctions( + runtime_builder); + }); +} + +} // namespace cel diff --git a/env/runtime_std_extensions.h b/env/runtime_std_extensions.h new file mode 100644 index 000000000..d7f714226 --- /dev/null +++ b/env/runtime_std_extensions.h @@ -0,0 +1,46 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_ENV_RUNTIME_STD_EXTENSIONS_H_ +#define THIRD_PARTY_CEL_CPP_ENV_RUNTIME_STD_EXTENSIONS_H_ + +#include "env/env_runtime.h" + +namespace cel { + +// Registers the standard CEL extension functions with the given environment +// runtime. This makes them available, but does not enable them. See Env::Config +// for how to enable extensions. +// +// Included in the standard runtime environment: +// +// - cel.lib.ext.bindings (alias: "bindings") +// - cel.lib.ext.encoders (alias: "encoders") +// - cel.lib.ext.lists (alias: "lists") +// - cel.lib.ext.math (alias: "math") +// - optional +// - cel.lib.ext.protos (alias: "protos") +// - cel.lib.ext.sets (alias: "sets") +// - cel.lib.ext.strings (alias: "strings") +// - cel.lib.ext.comprev2 (alias: "two-var-comprehensions") +// +// NOTE: Not included in the standard runtime environment yet - include manually +// if needed: +// - cel.lib.ext.regex (alias: "regex") +// +void RegisterStandardExtensions(EnvRuntime& env_runtime); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_ENV_RUNTIME_STD_EXTENSIONS_H_ diff --git a/env/runtime_std_extensions_test.cc b/env/runtime_std_extensions_test.cc new file mode 100644 index 000000000..4c7cb9829 --- /dev/null +++ b/env/runtime_std_extensions_test.cc @@ -0,0 +1,229 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/runtime_std_extensions.h" + +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "checker/optional.h" +#include "checker/validation_result.h" +#include "common/ast.h" +#include "common/value.h" +#include "compiler/compiler.h" +#include "env/config.h" +#include "env/env.h" +#include "env/env_runtime.h" +#include "env/env_std_extensions.h" +#include "extensions/lists_functions.h" +#include "extensions/math_ext_decls.h" +#include "extensions/strings.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "runtime/activation.h" +#include "runtime/runtime.h" +#include "google/protobuf/arena.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::testing::IsEmpty; +using ::testing::ValuesIn; + +struct TestCase { + std::string extension_name; + std::vector extension_versions = {0}; + int latest_extension_version = 0; + std::string expr; + bool requires_optional_extension = false; +}; + +using RuntimeStdExtensionTest = testing::TestWithParam; + +TEST_P(RuntimeStdExtensionTest, RegisterStandardExtensions) { + const TestCase& param = GetParam(); + Env env; + env.SetDescriptorPool(cel::internal::GetSharedTestingDescriptorPool()); + RegisterStandardExtensions(env); + + Config compiler_config; + // For the compilation step, assume latest version of the extension to ensure + // a successful compilation. Later, we will test the runtime with different + // extension versions. + ASSERT_THAT(compiler_config.AddExtensionConfig( + param.extension_name, Config::ExtensionConfig::kLatest), + IsOk()); + env.SetConfig(compiler_config); + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler()); + ASSERT_OK_AND_ASSIGN(ValidationResult result, compiler->Compile(param.expr)); + EXPECT_THAT(result.GetIssues(), IsEmpty()) << result.FormatError(); + ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, result.ReleaseAst()); + + for (int version = 0; version <= param.latest_extension_version; ++version) { + Config runtime_config; + // Request a specific version of the extension to be configured in the + // runtime. + ASSERT_THAT( + runtime_config.AddExtensionConfig(param.extension_name, version), + IsOk()); + if (param.requires_optional_extension) { + ASSERT_THAT(runtime_config.AddExtensionConfig("optional"), IsOk()); + } + + EnvRuntime env_runtime; + env_runtime.SetDescriptorPool( + cel::internal::GetSharedTestingDescriptorPool()); + RegisterStandardExtensions(env_runtime); + env_runtime.SetConfig(runtime_config); + ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, + env_runtime.NewRuntime()); + absl::StatusOr> program_or = + runtime->CreateProgram(std::make_unique(*ast)); + + // If the function is not supported in this extension version, check that + // the program creation returned an error. + if (!absl::c_contains(param.extension_versions, version)) { + EXPECT_THAT(program_or, StatusIs(absl::StatusCode::kInvalidArgument)) + << " expr: " << param.expr << " version: " << version; + continue; + } + + ASSERT_THAT(program_or, IsOk()) + << " expr: " << param.expr << " version: " << version; + std::unique_ptr program = *std::move(program_or); + google::protobuf::Arena arena; + Activation activation; + ASSERT_OK_AND_ASSIGN(Value value, program->Evaluate(&arena, activation)); + EXPECT_TRUE(value.GetBool()) + << " expr: " << param.expr << " version: " << version; + } +} + +std::vector GetRuntimeStdExtensionTestCases() { + return { + TestCase{ + // The "bindings" extension does not register any runtime functions - + // only macros. + .extension_name = "bindings", + .expr = "cel.bind(t, 42, t + 1) == 43", + }, + TestCase{ + .extension_name = "encoders", + .expr = "base64.encode(b'hello') == 'aGVsbG8='", + }, + TestCase{ + .extension_name = "lists", + .extension_versions = {0, 1, 2}, + .latest_extension_version = extensions::kListsExtensionLatestVersion, + .expr = "[3, 2, 1].slice(0, 1) == [3]", + }, + TestCase{ + .extension_name = "lists", + .extension_versions = {1, 2}, + .latest_extension_version = extensions::kListsExtensionLatestVersion, + .expr = "[[1, 2], 3].flatten() == [1, 2, 3]", + }, + TestCase{ + .extension_name = "lists", + .extension_versions = {2}, + .latest_extension_version = extensions::kListsExtensionLatestVersion, + .expr = "[3, 2, 1].sort() == [1, 2, 3]", + }, + TestCase{ + .extension_name = "math", + .extension_versions = {0, 1, 2}, + .latest_extension_version = extensions::kMathExtensionLatestVersion, + .expr = "math.least([1, -2, 3]) == -2", + }, + TestCase{ + .extension_name = "math", + .extension_versions = {1, 2}, + .latest_extension_version = extensions::kMathExtensionLatestVersion, + .expr = "math.floor(42.9) == 42.0", + }, + TestCase{ + .extension_name = "math", + .extension_versions = {2}, + .latest_extension_version = extensions::kMathExtensionLatestVersion, + .expr = "math.sqrt(4) == 2.0", + }, + TestCase{ + .extension_name = "optional", + .extension_versions = {0, 1, 2}, + .latest_extension_version = kOptionalExtensionLatestVersion, + .expr = "optional.of(1).hasValue()", + }, + TestCase{ + // No runtime functions. + .extension_name = "protos", + .expr = "!proto.hasExt(cel.expr.conformance.proto2.TestAllTypes{}, " + "cel.expr.conformance.proto2.nested_ext)", + }, + TestCase{ + .extension_name = "sets", + .expr = "sets.contains([1], [1])", + }, + TestCase{ + .extension_name = "strings", + .extension_versions = {0, 1, 2, 3, 4}, + .latest_extension_version = + extensions::kStringsExtensionLatestVersion, + .expr = "'Hello, who!'.replace('who', 'World') == 'Hello, World!'", + }, + TestCase{ + .extension_name = "strings", + .extension_versions = {1, 2, 3, 4}, + .latest_extension_version = + extensions::kStringsExtensionLatestVersion, + .expr = "strings.quote('hello') == '\"hello\"'", + }, + TestCase{ + .extension_name = "strings", + .extension_versions = {2, 3, 4}, + .latest_extension_version = + extensions::kStringsExtensionLatestVersion, + .expr = "['hello', 'world'].join(', ') == 'hello, world'", + }, + TestCase{ + .extension_name = "strings", + .extension_versions = {3, 4}, + .latest_extension_version = + extensions::kStringsExtensionLatestVersion, + .expr = "'stressed'.reverse() == 'desserts'", + }, + TestCase{ + // No runtime functions. + .extension_name = "cel.lib.ext.comprev2", + .expr = "[1, 2, 3].map(i, i * 2) == [2, 4, 6]", + }, + TestCase{ + .extension_name = "cel.lib.ext.regex", + .expr = "regex.replace('abc', '$', '_end') == 'abc_end'", + .requires_optional_extension = true, + }, + }; +} + +INSTANTIATE_TEST_SUITE_P(RuntimeStdExtensionTest, RuntimeStdExtensionTest, + ValuesIn(GetRuntimeStdExtensionTestCases())); + +} // namespace +} // namespace cel diff --git a/env/type_info.cc b/env/type_info.cc new file mode 100644 index 000000000..f49fab9f4 --- /dev/null +++ b/env/type_info.cc @@ -0,0 +1,410 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/type_info.h" + +#include +#include +#include +#include + +#include "absl/base/no_destructor.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "common/ast.h" +#include "common/type.h" +#include "common/type_kind.h" +#include "env/config.h" +#include "internal/status_macros.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" + +namespace cel { +namespace { + +std::optional TypeNameToTypeKind(absl::string_view type_name) { + // Excluded types: + // kUnknown + // kError + // kTypeParam + // kFunction + // kEnum + + static const absl::NoDestructor< + absl::flat_hash_map> + kTypeNameToTypeKind({ + {"null", TypeKind::kNull}, + {"bool", TypeKind::kBool}, + {"int", TypeKind::kInt}, + {"uint", TypeKind::kUint}, + {"double", TypeKind::kDouble}, + {"string", TypeKind::kString}, + {"bytes", TypeKind::kBytes}, + {"timestamp", TypeKind::kTimestamp}, + {TimestampType::kName, TypeKind::kTimestamp}, + {"duration", TypeKind::kDuration}, + {DurationType::kName, TypeKind::kDuration}, + {"list", TypeKind::kList}, + {"map", TypeKind::kMap}, + {"", TypeKind::kDyn}, + {"any", TypeKind::kAny}, + {"dyn", TypeKind::kDyn}, + {BoolWrapperType::kName, TypeKind::kBoolWrapper}, + {"bool_wrapper", TypeKind::kBoolWrapper}, + {IntWrapperType::kName, TypeKind::kIntWrapper}, + {"int_wrapper", TypeKind::kIntWrapper}, + {UintWrapperType::kName, TypeKind::kUintWrapper}, + {"uint_wrapper", TypeKind::kUintWrapper}, + {DoubleWrapperType::kName, TypeKind::kDoubleWrapper}, + {"double_wrapper", TypeKind::kDoubleWrapper}, + {StringWrapperType::kName, TypeKind::kStringWrapper}, + {"string_wrapper", TypeKind::kStringWrapper}, + {BytesWrapperType::kName, TypeKind::kBytesWrapper}, + {"bytes_wrapper", TypeKind::kBytesWrapper}, + {"type", TypeKind::kType}, + }); + if (auto it = kTypeNameToTypeKind->find(type_name); + it != kTypeNameToTypeKind->end()) { + return it->second; + } + + return std::nullopt; +} +} // namespace + +absl::StatusOr TypeInfoToType( + const Config::TypeInfo& type_info, + const google::protobuf::DescriptorPool* descriptor_pool, google::protobuf::Arena* arena) { + if (type_info.is_type_param) { + return TypeParamType(type_info.name); + } + + std::optional type_kind = TypeNameToTypeKind(type_info.name); + if (!type_kind.has_value()) { + if (type_info.params.empty() && descriptor_pool != nullptr) { + const google::protobuf::Descriptor* type = + descriptor_pool->FindMessageTypeByName(type_info.name); + if (type != nullptr) { + return Type::Message(type); + } + } + // TODO(uncreated-issue/88): use a TypeIntrospector to validate opaque types + std::vector parameter_types; + for (const Config::TypeInfo& param : type_info.params) { + CEL_ASSIGN_OR_RETURN(Type parameter_type, + TypeInfoToType(param, descriptor_pool, arena)); + parameter_types.push_back(parameter_type); + } + + return OpaqueType(arena, type_info.name, parameter_types); + } + + switch (*type_kind) { + case TypeKind::kNull: + return NullType(); + case TypeKind::kBool: + return BoolType(); + case TypeKind::kInt: + return IntType(); + case TypeKind::kUint: + return UintType(); + case TypeKind::kDouble: + return DoubleType(); + case TypeKind::kString: + return StringType(); + case TypeKind::kBytes: + return BytesType(); + case TypeKind::kDuration: + return DurationType(); + case TypeKind::kTimestamp: + return TimestampType(); + case TypeKind::kList: { + Type element_type; + if (!type_info.params.empty()) { + CEL_ASSIGN_OR_RETURN( + element_type, + TypeInfoToType(type_info.params[0], descriptor_pool, arena)); + } else { + element_type = DynType(); + } + return ListType(arena, element_type); + } + case TypeKind::kMap: { + Type key_type = DynType(); + Type value_type = DynType(); + if (!type_info.params.empty()) { + CEL_ASSIGN_OR_RETURN(key_type, TypeInfoToType(type_info.params[0], + descriptor_pool, arena)); + } + if (type_info.params.size() > 1) { + CEL_ASSIGN_OR_RETURN( + value_type, + TypeInfoToType(type_info.params[1], descriptor_pool, arena)); + } + return MapType(arena, key_type, value_type); + } + case TypeKind::kDyn: + return DynType(); + case TypeKind::kAny: + return AnyType(); + case TypeKind::kBoolWrapper: + return BoolWrapperType(); + case TypeKind::kIntWrapper: + return IntWrapperType(); + case TypeKind::kUintWrapper: + return UintWrapperType(); + case TypeKind::kDoubleWrapper: + return DoubleWrapperType(); + case TypeKind::kStringWrapper: + return StringWrapperType(); + case TypeKind::kBytesWrapper: + return BytesWrapperType(); + case TypeKind::kType: { + if (type_info.params.empty()) { + return TypeType(arena, DynType()); + } + CEL_ASSIGN_OR_RETURN(Type type, TypeInfoToType(type_info.params[0], + descriptor_pool, arena)); + return TypeType(arena, type); + } + default: + return DynType(); + } +} +absl::StatusOr TypeInfoToTypeSpec(const Config::TypeInfo& type_info) { + if (type_info.is_type_param) { + return TypeSpec(ParamTypeSpec(type_info.name)); + } + + std::optional type_kind = TypeNameToTypeKind(type_info.name); + if (!type_kind.has_value()) { + if (type_info.params.empty()) { + return TypeSpec(MessageTypeSpec(type_info.name)); + } else { + std::vector param_specs; + param_specs.reserve(type_info.params.size()); + for (const Config::TypeInfo& param : type_info.params) { + CEL_ASSIGN_OR_RETURN(TypeSpec param_spec, TypeInfoToTypeSpec(param)); + param_specs.push_back(std::move(param_spec)); + } + return TypeSpec(AbstractType(type_info.name, std::move(param_specs))); + } + } + + switch (*type_kind) { + case TypeKind::kNull: + return TypeSpec(NullTypeSpec()); + case TypeKind::kBool: + return TypeSpec(PrimitiveType::kBool); + case TypeKind::kInt: + return TypeSpec(PrimitiveType::kInt64); + case TypeKind::kUint: + return TypeSpec(PrimitiveType::kUint64); + case TypeKind::kDouble: + return TypeSpec(PrimitiveType::kDouble); + case TypeKind::kString: + return TypeSpec(PrimitiveType::kString); + case TypeKind::kBytes: + return TypeSpec(PrimitiveType::kBytes); + case TypeKind::kTimestamp: + return TypeSpec(WellKnownTypeSpec::kTimestamp); + case TypeKind::kDuration: + return TypeSpec(WellKnownTypeSpec::kDuration); + case TypeKind::kList: { + if (!type_info.params.empty()) { + CEL_ASSIGN_OR_RETURN(TypeSpec elem_type, + TypeInfoToTypeSpec(type_info.params[0])); + return TypeSpec( + ListTypeSpec(std::make_unique(std::move(elem_type)))); + } else { + return TypeSpec(ListTypeSpec()); + } + } + case TypeKind::kMap: { + if (type_info.params.empty()) { + return TypeSpec(MapTypeSpec()); + } + CEL_ASSIGN_OR_RETURN(TypeSpec key_type, + TypeInfoToTypeSpec(type_info.params[0])); + if (type_info.params.size() > 1) { + CEL_ASSIGN_OR_RETURN(TypeSpec value_type, + TypeInfoToTypeSpec(type_info.params[1])); + return TypeSpec( + MapTypeSpec(std::make_unique(std::move(key_type)), + std::make_unique(std::move(value_type)))); + } + return TypeSpec(MapTypeSpec( + std::make_unique(std::move(key_type)), nullptr)); + } + case TypeKind::kDyn: + return TypeSpec(DynTypeSpec()); + case TypeKind::kAny: + return TypeSpec(WellKnownTypeSpec::kAny); + case TypeKind::kBoolWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBool)); + case TypeKind::kIntWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kInt64)); + case TypeKind::kUintWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kUint64)); + case TypeKind::kDoubleWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kDouble)); + case TypeKind::kStringWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kString)); + case TypeKind::kBytesWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBytes)); + case TypeKind::kType: { + if (type_info.params.empty()) { + return TypeSpec(std::make_unique(DynTypeSpec())); + } + CEL_ASSIGN_OR_RETURN(TypeSpec type_param, + TypeInfoToTypeSpec(type_info.params[0])); + return TypeSpec(std::make_unique(std::move(type_param))); + } + default: + return TypeSpec(DynTypeSpec()); + } +} + +absl::StatusOr TypeSpecToTypeInfo(const TypeSpec& type_spec) { + Config::TypeInfo type_info; + + if (type_spec.has_dyn()) { + type_info.name = "dyn"; + } else if (type_spec.has_null()) { + type_info.name = "null"; + } else if (type_spec.has_primitive()) { + switch (type_spec.primitive()) { + case PrimitiveType::kBool: + type_info.name = "bool"; + break; + case PrimitiveType::kInt64: + type_info.name = "int"; + break; + case PrimitiveType::kUint64: + type_info.name = "uint"; + break; + case PrimitiveType::kDouble: + type_info.name = "double"; + break; + case PrimitiveType::kString: + type_info.name = "string"; + break; + case PrimitiveType::kBytes: + type_info.name = "bytes"; + break; + default: + return absl::InvalidArgumentError("Unspecified primitive type"); + } + } else if (type_spec.has_wrapper()) { + switch (type_spec.wrapper()) { + case PrimitiveType::kBool: + type_info.name = "bool_wrapper"; + break; + case PrimitiveType::kInt64: + type_info.name = "int_wrapper"; + break; + case PrimitiveType::kUint64: + type_info.name = "uint_wrapper"; + break; + case PrimitiveType::kDouble: + type_info.name = "double_wrapper"; + break; + case PrimitiveType::kString: + type_info.name = "string_wrapper"; + break; + case PrimitiveType::kBytes: + type_info.name = "bytes_wrapper"; + break; + default: + return absl::InvalidArgumentError("Unspecified wrapper type"); + } + } else if (type_spec.has_well_known()) { + switch (type_spec.well_known()) { + case WellKnownTypeSpec::kAny: + type_info.name = "any"; + break; + case WellKnownTypeSpec::kTimestamp: + type_info.name = "timestamp"; + break; + case WellKnownTypeSpec::kDuration: + type_info.name = "duration"; + break; + default: + return absl::InvalidArgumentError("Unspecified well known type"); + } + } else if (type_spec.has_list_type()) { + type_info.name = "list"; + const ListTypeSpec& list_type = type_spec.list_type(); + if (list_type.has_elem_type() && list_type.elem_type().is_specified()) { + CEL_ASSIGN_OR_RETURN(Config::TypeInfo param, + TypeSpecToTypeInfo(list_type.elem_type())); + type_info.params.push_back(std::move(param)); + } + } else if (type_spec.has_map_type()) { + type_info.name = "map"; + const MapTypeSpec& map_type = type_spec.map_type(); + bool has_key = + map_type.has_key_type() && map_type.key_type().is_specified(); + bool has_value = + map_type.has_value_type() && map_type.value_type().is_specified(); + if (has_key || has_value) { + if (has_key) { + CEL_ASSIGN_OR_RETURN(Config::TypeInfo param, + TypeSpecToTypeInfo(map_type.key_type())); + type_info.params.push_back(std::move(param)); + } else { + type_info.params.push_back(Config::TypeInfo{.name = "dyn"}); + } + if (has_value) { + CEL_ASSIGN_OR_RETURN(Config::TypeInfo param_value, + TypeSpecToTypeInfo(map_type.value_type())); + type_info.params.push_back(std::move(param_value)); + } else { + type_info.params.push_back(Config::TypeInfo{.name = "dyn"}); + } + } + } else if (type_spec.has_message_type()) { + type_info.name = type_spec.message_type().type(); + } else if (type_spec.has_type_param()) { + type_info.name = type_spec.type_param().type(); + type_info.is_type_param = true; + } else if (type_spec.has_type()) { + type_info.name = "type"; + CEL_ASSIGN_OR_RETURN(Config::TypeInfo param, + TypeSpecToTypeInfo(type_spec.type())); + type_info.params.push_back(std::move(param)); + } else if (type_spec.has_abstract_type()) { + type_info.name = type_spec.abstract_type().name(); + for (const TypeSpec& param_spec : + type_spec.abstract_type().parameter_types()) { + CEL_ASSIGN_OR_RETURN(Config::TypeInfo param, + TypeSpecToTypeInfo(param_spec)); + type_info.params.push_back(std::move(param)); + } + } else if (type_spec.has_error()) { + return absl::InvalidArgumentError( + "ErrorType cannot be converted to TypeInfo"); + } else if (type_spec.has_function()) { + return absl::InvalidArgumentError( + "FunctionType cannot be converted to TypeInfo"); + } else { + return absl::InvalidArgumentError("Unknown TypeSpec kind"); + } + + return type_info; +} + +} // namespace cel diff --git a/env/type_info.h b/env/type_info.h new file mode 100644 index 000000000..3f802ce1a --- /dev/null +++ b/env/type_info.h @@ -0,0 +1,42 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_ENV_TYPE_INFO_H_ +#define THIRD_PARTY_CEL_CPP_ENV_TYPE_INFO_H_ + +#include "absl/status/statusor.h" +#include "common/ast.h" +#include "common/type.h" +#include "env/config.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +// Converts a Config::TypeInfo to a cel::Type. Returns an error if the type_info +// cannot be converted to a known cel::Type, a list configured with more than +// one parameter. +absl::StatusOr TypeInfoToType( + const Config::TypeInfo& type_info, + const google::protobuf::DescriptorPool* descriptor_pool, google::protobuf::Arena* arena); + +// Converts a Config::TypeInfo to a cel::TypeSpec. +absl::StatusOr TypeInfoToTypeSpec(const Config::TypeInfo& type_info); + +// Converts a cel::TypeSpec to a Config::TypeInfo. +absl::StatusOr TypeSpecToTypeInfo(const TypeSpec& type_spec); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_ENV_TYPE_INFO_H_ diff --git a/env/type_info_test.cc b/env/type_info_test.cc new file mode 100644 index 000000000..f9d46f9a9 --- /dev/null +++ b/env/type_info_test.cc @@ -0,0 +1,300 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/type_info.h" + +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "common/ast/metadata.h" +#include "common/type.h" +#include "common/type_proto.h" +#include "env/config.h" +#include "internal/proto_matchers.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/text_format.h" + +namespace cel { + +std::ostream& operator<<(std::ostream& os, const Config::TypeInfo& type_info) { + if (type_info.is_type_param) { + os << "?"; + } + os << type_info.name; + if (!type_info.params.empty()) { + os << "<"; + for (size_t i = 0; i < type_info.params.size(); ++i) { + if (i > 0) os << ", "; + os << type_info.params[i]; + } + os << ">"; + } + return os; +} + +namespace { + +using absl_testing::IsOk; +using absl_testing::StatusIs; +using testing::ValuesIn; + +struct TestCase { + Config::TypeInfo type_info; + std::string expected_type_pb; +}; + +using TypeInfoTest = testing::TestWithParam; + +TEST_P(TypeInfoTest, TypeInfo) { + const TestCase& param = GetParam(); + cel::expr::Type expected_type_pb; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(param.expected_type_pb, + &expected_type_pb)); + + google::protobuf::Arena arena; + const google::protobuf::DescriptorPool* descriptor_pool = + cel::internal::GetTestingDescriptorPool(); + ASSERT_OK_AND_ASSIGN( + cel::Type actual_type, + cel::TypeInfoToType(param.type_info, descriptor_pool, &arena)); + + cel::expr::Type actual_type_pb; + ASSERT_THAT(cel::TypeToProto(actual_type, &actual_type_pb), IsOk()); + EXPECT_THAT(actual_type_pb, + cel::internal::test::EqualsProto(expected_type_pb)); +} + +std::vector GetTestCases() { + return { + TestCase{ + .type_info = {.name = "int"}, + .expected_type_pb = "primitive: INT64", + }, + TestCase{ + .type_info = {.name = "list", + .params = {Config::TypeInfo{.name = "int"}}}, + .expected_type_pb = "list_type { elem_type { primitive: INT64 } }", + }, + TestCase{ + .type_info = {.name = "list"}, + .expected_type_pb = "list_type { elem_type { dyn {} }}", + }, + TestCase{ + .type_info = {.name = "map", + .params = {Config::TypeInfo{.name = "string"}, + Config::TypeInfo{.name = "int"}}}, + .expected_type_pb = "map_type { key_type { primitive: STRING } " + "value_type { primitive: INT64 }}", + }, + TestCase{ + .type_info = {.name = "cel.expr.conformance.proto2.TestAllTypes"}, + .expected_type_pb = + "message_type: 'cel.expr.conformance.proto2.TestAllTypes'", + }, + TestCase{ + .type_info = {.name = "A", + .params = {Config::TypeInfo{.name = "B", + .is_type_param = true}}}, + .expected_type_pb = + "abstract_type { name: 'A' parameter_types { type_param: 'B' } }", + }, + TestCase{ + .type_info = {.name = "any"}, + .expected_type_pb = "well_known: ANY", + }, + TestCase{ + .type_info = {.name = "timestamp"}, + .expected_type_pb = "well_known: TIMESTAMP", + }, + TestCase{ + .type_info = {.name = "google.protobuf.DoubleValue"}, + .expected_type_pb = "wrapper: DOUBLE", + }, + TestCase{ + .type_info = {.name = "double_wrapper"}, + .expected_type_pb = "wrapper: DOUBLE", + }, + TestCase{ + .type_info = {.name = "type", + .params = {Config::TypeInfo{.name = "duration"}}}, + .expected_type_pb = "type: { well_known: DURATION }", + }, + TestCase{ + .type_info = {.name = "parameterized", + .params = {{.name = "A", .is_type_param = true}, + {.name = "double"}}}, + .expected_type_pb = "abstract_type { name: 'parameterized' " + "parameter_types { type_param: 'A' } " + "parameter_types { primitive: DOUBLE } }", + }, + }; +} + +INSTANTIATE_TEST_SUITE_P(TypeInfoTest, TypeInfoTest, ValuesIn(GetTestCases())); + +bool TypeInfoEqImpl(const Config::TypeInfo& actual, + const Config::TypeInfo& expected) { + if (actual.name != expected.name) return false; + if (actual.is_type_param != expected.is_type_param) return false; + if (actual.params.size() != expected.params.size()) return false; + for (size_t i = 0; i < actual.params.size(); ++i) { + if (!TypeInfoEqImpl(actual.params[i], expected.params[i])) return false; + } + return true; +} + +MATCHER_P(TypeInfoEq, expected, "") { return TypeInfoEqImpl(arg, expected); } + +struct TypeSpecTestCase { + TypeSpec type_spec; + Config::TypeInfo expected_type_info; +}; + +using TypeSpecToTypeInfoTest = testing::TestWithParam; + +TEST_P(TypeSpecToTypeInfoTest, Convert) { + const TypeSpecTestCase& param = GetParam(); + ASSERT_OK_AND_ASSIGN(Config::TypeInfo actual_type_info, + TypeSpecToTypeInfo(param.type_spec)); + EXPECT_THAT(actual_type_info, TypeInfoEq(param.expected_type_info)); +} + +std::vector GetTypeSpecTestCases() { + return { + TypeSpecTestCase{ + .type_spec = TypeSpec(PrimitiveType::kInt64), + .expected_type_info = {.name = "int"}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec( + ListTypeSpec(std::make_unique(PrimitiveType::kInt64))), + .expected_type_info = {.name = "list", + .params = {Config::TypeInfo{.name = "int"}}}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec(ListTypeSpec()), + .expected_type_info = {.name = "list"}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec( + MapTypeSpec(std::make_unique(PrimitiveType::kString), + std::make_unique(PrimitiveType::kInt64))), + .expected_type_info = {.name = "map", + .params = {Config::TypeInfo{.name = "string"}, + Config::TypeInfo{.name = "int"}}}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec(MapTypeSpec()), + .expected_type_info = {.name = "map"}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec( + MessageTypeSpec("cel.expr.conformance.proto2.TestAllTypes")), + .expected_type_info = + {.name = "cel.expr.conformance.proto2.TestAllTypes"}, + }, + TypeSpecTestCase{ + .type_spec = + TypeSpec(AbstractType("A", {TypeSpec(ParamTypeSpec("B"))})), + .expected_type_info = {.name = "A", + .params = {Config::TypeInfo{ + .name = "B", .is_type_param = true}}}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec(WellKnownTypeSpec::kAny), + .expected_type_info = {.name = "any"}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec(WellKnownTypeSpec::kTimestamp), + .expected_type_info = {.name = "timestamp"}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kDouble)), + .expected_type_info = {.name = "double_wrapper"}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec( + std::make_unique(WellKnownTypeSpec::kDuration)), + .expected_type_info = {.name = "type", + .params = {Config::TypeInfo{.name = + "duration"}}}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec(std::make_unique(DynTypeSpec())), + .expected_type_info = {.name = "type", + .params = {Config::TypeInfo{.name = "dyn"}}}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec(DynTypeSpec{}), + .expected_type_info = {.name = "dyn"}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec(NullTypeSpec{}), + .expected_type_info = {.name = "null"}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec( + MapTypeSpec(std::make_unique(PrimitiveType::kString), + std::make_unique(DynTypeSpec()))), + .expected_type_info = {.name = "map", + .params = {Config::TypeInfo{.name = "string"}, + Config::TypeInfo{.name = "dyn"}}}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec( + MapTypeSpec(std::make_unique(DynTypeSpec()), + std::make_unique(PrimitiveType::kInt64))), + .expected_type_info = {.name = "map", + .params = {Config::TypeInfo{.name = "dyn"}, + Config::TypeInfo{.name = "int"}}}, + }, + }; +} + +INSTANTIATE_TEST_SUITE_P(TypeSpecToTypeInfoTest, TypeSpecToTypeInfoTest, + ValuesIn(GetTypeSpecTestCases())); + +using TypeInfoToTypeSpecTest = testing::TestWithParam; + +TEST_P(TypeInfoToTypeSpecTest, Convert) { + const TypeSpecTestCase& param = GetParam(); + ASSERT_OK_AND_ASSIGN(TypeSpec actual_type_spec, + TypeInfoToTypeSpec(param.expected_type_info)); + EXPECT_EQ(actual_type_spec, param.type_spec); +} + +INSTANTIATE_TEST_SUITE_P(TypeInfoToTypeSpecTest, TypeInfoToTypeSpecTest, + ValuesIn(GetTypeSpecTestCases())); + +TEST(TypeSpecToTypeInfoTest, ErrorConversions) { + EXPECT_THAT(TypeSpecToTypeInfo(TypeSpec(ErrorTypeSpec::kValue)), + StatusIs(absl::StatusCode::kInvalidArgument, + "ErrorType cannot be converted to TypeInfo")); + EXPECT_THAT(TypeSpecToTypeInfo(TypeSpec(FunctionTypeSpec())), + StatusIs(absl::StatusCode::kInvalidArgument, + "FunctionType cannot be converted to TypeInfo")); + EXPECT_THAT( + TypeSpecToTypeInfo(TypeSpec(UnsetTypeSpec())), + StatusIs(absl::StatusCode::kInvalidArgument, "Unknown TypeSpec kind")); +} + +} // namespace +} // namespace cel diff --git a/eval/README.md b/eval/README.md index ee6fd0798..32fa4bda4 100644 --- a/eval/README.md +++ b/eval/README.md @@ -3,4 +3,4 @@ A C++ implementation of a [Common Expression Language][1] evaluator. -[1]: https://github.com/google/cel-spec +[1]: https://github.com/cel-expr/cel-spec diff --git a/eval/compiler/BUILD b/eval/compiler/BUILD index ce0a4847b..f7300cb58 100644 --- a/eval/compiler/BUILD +++ b/eval/compiler/BUILD @@ -1,11 +1,91 @@ +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +DEFAULT_VISIBILITY = [ + "//eval:__subpackages__", + "//runtime:__subpackages__", + "//extensions:__subpackages__", + "//testing:__subpackages__", +] + # This package contains code # that compiles Expr object into evaluatable CelExpression package(default_visibility = ["//visibility:public"]) -licenses(["notice"]) # Apache 2.0 +licenses(["notice"]) exports_files(["LICENSE"]) +package_group( + name = "coverage_visibility", + packages = [ + "//tools/...", + ], +) + +cc_library( + name = "flat_expr_builder_extensions", + srcs = ["flat_expr_builder_extensions.cc"], + hdrs = ["flat_expr_builder_extensions.h"], + deps = [ + ":resolver", + "//base:ast", + "//base:data", + "//common:expr", + "//common:native_type", + "//common:value", + "//eval/eval:direct_expression_step", + "//eval/eval:evaluator_core", + "//eval/eval:trace_step", + "//internal:casts", + "//runtime:runtime_options", + "//runtime/internal:issue_collector", + "//runtime/internal:runtime_env", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:variant", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "flat_expr_builder_extensions_test", + srcs = ["flat_expr_builder_extensions_test.cc"], + deps = [ + ":flat_expr_builder_extensions", + ":resolver", + "//common:expr", + "//common:native_type", + "//common:value", + "//eval/eval:const_value_step", + "//eval/eval:direct_expression_step", + "//eval/eval:evaluator_core", + "//eval/eval:function_step", + "//internal:status_macros", + "//internal:testing", + "//runtime:function_registry", + "//runtime:runtime_issue", + "//runtime:runtime_options", + "//runtime:type_registry", + "//runtime/internal:issue_collector", + "//runtime/internal:runtime_env", + "//runtime/internal:runtime_env_testing", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_protobuf//:protobuf", + ], +) + cc_library( name = "flat_expr_builder", srcs = [ @@ -15,30 +95,64 @@ cc_library( "flat_expr_builder.h", ], deps = [ - ":constant_folding", - "//base:status_macros", + ":check_ast_extensions", + ":flat_expr_builder_extensions", + ":resolver", + "//base:ast", + "//base:builtins", + "//base:data", + "//common:allocator", + "//common:ast", + "//common:ast_traverse", + "//common:ast_visitor", + "//common:constant", + "//common:expr", + "//common:kind", + "//common:type", + "//common:value", "//eval/eval:comprehension_step", "//eval/eval:const_value_step", "//eval/eval:container_access_step", "//eval/eval:create_list_step", + "//eval/eval:create_map_step", "//eval/eval:create_struct_step", + "//eval/eval:direct_expression_step", + "//eval/eval:equality_steps", "//eval/eval:evaluator_core", - "//eval/eval:expression_build_warning", "//eval/eval:function_step", "//eval/eval:ident_step", "//eval/eval:jump_step", + "//eval/eval:lazy_init_step", "//eval/eval:logic_step", + "//eval/eval:optional_or_step", "//eval/eval:select_step", + "//eval/eval:shadowable_value_step", "//eval/eval:ternary_step", - "//eval/public:ast_traverse", - "//eval/public:ast_visitor", - "//eval/public:cel_builtins", - "//eval/public:cel_expression", - "//eval/public:cel_function_registry", + "//eval/eval:trace_step", + "//internal:status_macros", + "//runtime:function_registry", + "//runtime:runtime_issue", + "//runtime:runtime_options", + "//runtime:type_registry", + "//runtime/internal:convert_constant", + "//runtime/internal:issue_collector", + "//runtime/internal:runtime_env", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + "@com_google_absl//absl/types:variant", + "@com_google_protobuf//:protobuf", ], ) @@ -48,22 +162,56 @@ cc_test( "flat_expr_builder_test.cc", ], deps = [ + ":cel_expression_builder_flat_impl", + ":constant_folding", ":flat_expr_builder", - "//base:status_macros", + ":qualified_reference_resolver", + "//base:builtins", + "//common:function_descriptor", + "//common:kind", + "//common:value", + "//eval/public:activation", "//eval/public:builtin_func_registrar", "//eval/public:cel_attribute", "//eval/public:cel_builtins", + "//eval/public:cel_expr_builder_factory", "//eval/public:cel_expression", + "//eval/public:cel_function", + "//eval/public:cel_function_adapter", + "//eval/public:cel_function_registry", "//eval/public:cel_options", "//eval/public:cel_value", + "//eval/public:portable_cel_function_adapter", "//eval/public:unknown_attribute_set", "//eval/public:unknown_set", + "//eval/public/containers:container_backed_map_impl", + "//eval/public/structs:cel_proto_descriptor_pool_builder", "//eval/public/structs:cel_proto_wrapper", + "//eval/public/testing:matchers", "//eval/testutil:test_message_cc_proto", + "//internal:proto_matchers", + "//internal:status_macros", + "//internal:testing", + "//parser", + "//parser:options", + "//runtime:function", + "//runtime:function_adapter", + "//runtime:runtime_options", + "//runtime:standard_functions", + "//runtime/internal:runtime_env_testing", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", - "@com_google_googletest//:gtest_main", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto", + "@com_google_protobuf//:field_mask_cc_proto", "@com_google_protobuf//:protobuf", ], ) @@ -74,21 +222,99 @@ cc_test( "flat_expr_builder_comprehensions_test.cc", ], deps = [ + ":cel_expression_builder_flat_impl", + ":comprehension_vulnerability_check", ":flat_expr_builder", - "//base:status_macros", + "//eval/public:activation", "//eval/public:builtin_func_registrar", "//eval/public:cel_attribute", - "//eval/public:cel_builtins", "//eval/public:cel_expression", "//eval/public:cel_options", "//eval/public:cel_value", - "//eval/public:unknown_attribute_set", - "//eval/public:unknown_set", + "//eval/public/containers:container_backed_list_impl", + "//eval/public/testing:matchers", "//eval/testutil:test_message_cc_proto", + "//internal:testing", + "//parser", + "//runtime:runtime_options", + "//runtime/internal:runtime_env_testing", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", - "@com_google_googletest//:gtest_main", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_protobuf//:field_mask_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "cel_expression_builder_flat_impl", + srcs = [ + "cel_expression_builder_flat_impl.cc", + ], + hdrs = [ + "cel_expression_builder_flat_impl.h", + ], + deps = [ + ":flat_expr_builder", + "//base:ast", + "//common:native_type", + "//eval/eval:cel_expression_flat_impl", + "//eval/eval:direct_expression_step", + "//eval/eval:evaluator_core", + "//eval/public:cel_expression", + "//eval/public:cel_function_registry", + "//eval/public:cel_type_registry", + "//extensions/protobuf:ast_converters", + "//internal:status_macros", + "//runtime:runtime_issue", + "//runtime:runtime_options", + "//runtime/internal:runtime_env", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + ], +) + +cc_test( + name = "cel_expression_builder_flat_impl_test", + srcs = [ + "cel_expression_builder_flat_impl_test.cc", + ], + deps = [ + ":cel_expression_builder_flat_impl", + ":constant_folding", + ":regex_precompilation_optimization", + "//eval/eval:cel_expression_flat_impl", + "//eval/public:activation", + "//eval/public:builtin_func_registrar", + "//eval/public:cel_expression", + "//eval/public:cel_function", + "//eval/public:cel_value", + "//eval/public:portable_cel_function_adapter", + "//eval/public/containers:container_backed_map_impl", + "//eval/public/structs:cel_proto_wrapper", + "//eval/public/structs:protobuf_descriptor_type_provider", + "//eval/public/testing:matchers", + "//extensions:bindings_ext", + "//internal:status_macros", + "//internal:testing", + "//parser", + "//parser:macro", + "//runtime:runtime_options", + "//runtime/internal:runtime_env_testing", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto", "@com_google_protobuf//:protobuf", ], ) @@ -102,15 +328,24 @@ cc_library( "constant_folding.h", ], deps = [ + ":flat_expr_builder_extensions", + ":resolver", + "//base:builtins", + "//base:data", + "//common:ast", + "//common:constant", + "//common:expr", + "//common:value", "//eval/eval:const_value_step", - "//eval/public:cel_builtins", - "//eval/public:cel_function", - "//eval/public:cel_function_registry", - "//eval/public:cel_value", - "//eval/public/containers:container_backed_list_impl", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/strings", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "//eval/eval:evaluator_core", + "//internal:status_macros", + "//runtime:activation", + "//runtime/internal:convert_constant", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_protobuf//:protobuf", ], ) @@ -121,13 +356,32 @@ cc_test( ], deps = [ ":constant_folding", - "//base:status_macros", - "//eval/public:builtin_func_registrar", - "//eval/public:cel_function_registry", - "//eval/testutil:test_message_cc_proto", + ":flat_expr_builder_extensions", + ":resolver", + "//base:ast", + "//common:expr", + "//common:value", + "//eval/eval:const_value_step", + "//eval/eval:create_list_step", + "//eval/eval:create_map_step", + "//eval/eval:evaluator_core", + "//extensions/protobuf:ast_converters", + "//internal:status_macros", + "//internal:testing", + "//parser", + "//runtime:function_registry", + "//runtime:runtime_issue", + "//runtime:runtime_options", + "//runtime:type_registry", + "//runtime/internal:issue_collector", + "//runtime/internal:runtime_env", + "//runtime/internal:runtime_env_testing", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", - "@com_google_googletest//:gtest_main", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], ) @@ -141,18 +395,70 @@ cc_library( "qualified_reference_resolver.h", ], deps = [ - "//base:status_macros", - "//eval/eval:const_value_step", - "//eval/eval:expression_build_warning", - "//eval/public:cel_builtins", - "//eval/public:cel_function_registry", + ":flat_expr_builder_extensions", + ":resolver", + "//base:ast", + "//base:builtins", + "//common:ast", + "//common:ast_rewrite", + "//common:expr", + "//common:kind", + "//runtime:runtime_issue", + "//runtime/internal:issue_collector", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", - "@com_google_googleapis//google/api/expr/v1alpha1:checked_cc_proto", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", - "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "check_ast_extensions", + srcs = ["check_ast_extensions.cc"], + hdrs = ["check_ast_extensions.h"], + deps = [ + "//common:ast", + "//common/ast:metadata", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "check_ast_extensions_test", + srcs = ["check_ast_extensions_test.cc"], + deps = [ + ":check_ast_extensions", + "//common:ast", + "//common:expr", + "//common/ast:metadata", + "//internal:testing", + "@com_google_absl//absl/status", + ], +) + +cc_library( + name = "resolver", + srcs = ["resolver.cc"], + hdrs = ["resolver.h"], + deps = [ + "//common:kind", + "//common:type", + "//common:value", + "//internal:status_macros", + "//runtime:function_overload_reference", + "//runtime:function_registry", + "//runtime:type_registry", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", ], ) @@ -163,11 +469,27 @@ cc_test( ], deps = [ ":qualified_reference_resolver", - "//base:status_macros", - "//testutil:util", + ":resolver", + "//base:ast", + "//base:builtins", + "//common:ast", + "//common:expr", + "//common/ast:expr_proto", + "//eval/public:builtin_func_registrar", + "//eval/public:cel_function", + "//eval/public:cel_function_registry", + "//eval/public:cel_value", + "//extensions/protobuf:ast_converters", + "//internal:proto_matchers", + "//internal:testing", + "//runtime:runtime_issue", + "//runtime:type_registry", + "//runtime/internal:issue_collector", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status", - "@com_google_absl//absl/types:optional", - "@com_google_googletest//:gtest_main", + "@com_google_absl//absl/strings", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], ) @@ -178,18 +500,163 @@ cc_test( "flat_expr_builder_short_circuiting_conformance_test.cc", ], deps = [ - ":flat_expr_builder", - "//base:status_macros", + ":cel_expression_builder_flat_impl", + "//base:builtins", "//eval/public:activation", "//eval/public:cel_attribute", - "//eval/public:cel_builtins", "//eval/public:cel_expression", - "//eval/public:cel_options", + "//eval/public:cel_value", "//eval/public:unknown_attribute_set", "//eval/public:unknown_set", + "//internal:testing", + "//runtime:runtime_options", + "//runtime/internal:runtime_env_testing", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "resolver_test", + size = "small", + srcs = ["resolver_test.cc"], + deps = [ + ":resolver", + "//common:value", + "//eval/public:cel_function", + "//eval/public:cel_function_registry", + "//eval/public:cel_type_registry", + "//eval/public:cel_value", + "//eval/testutil:test_message_cc_proto", + "//internal:testing", + "@com_google_absl//absl/status", + "@com_google_absl//absl/types:span", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "regex_precompilation_optimization", + srcs = ["regex_precompilation_optimization.cc"], + hdrs = ["regex_precompilation_optimization.h"], + deps = [ + ":flat_expr_builder_extensions", + "//base:builtins", + "//common:ast", + "//common:casting", + "//common:expr", + "//common:native_type", + "//common:value", + "//eval/eval:compiler_constant_step", + "//eval/eval:direct_expression_step", + "//eval/eval:evaluator_core", + "//eval/eval:regex_match_step", + "//internal:casts", + "//internal:re2_options", + "//internal:status_macros", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest_main", + "@com_google_absl//absl/types:optional", + "@com_googlesource_code_re2//:re2", + ], +) + +cc_test( + name = "regex_precompilation_optimization_test", + srcs = ["regex_precompilation_optimization_test.cc"], + deps = [ + ":cel_expression_builder_flat_impl", + ":constant_folding", + ":flat_expr_builder", + ":flat_expr_builder_extensions", + ":regex_precompilation_optimization", + ":resolver", + "//common:ast", + "//eval/eval:evaluator_core", + "//eval/public:activation", + "//eval/public:builtin_func_registrar", + "//eval/public:cel_expression", + "//eval/public:cel_function_registry", + "//eval/public:cel_options", + "//eval/public:cel_type_registry", + "//eval/public:cel_value", + "//internal:testing", + "//parser", + "//runtime:runtime_issue", + "//runtime:runtime_options", + "//runtime/internal:issue_collector", + "//runtime/internal:runtime_env", + "//runtime/internal:runtime_env_testing", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/status", + "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "comprehension_vulnerability_check", + srcs = ["comprehension_vulnerability_check.cc"], + hdrs = ["comprehension_vulnerability_check.h"], + deps = [ + ":flat_expr_builder_extensions", + "//base:builtins", + "//common:ast", + "//common:constant", + "//common:expr", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:variant", + ], +) + +cc_library( + name = "instrumentation", + srcs = ["instrumentation.cc"], + hdrs = ["instrumentation.h"], + deps = [ + ":flat_expr_builder_extensions", + "//common:ast", + "//common:expr", + "//common:value", + "//eval/eval:evaluator_core", + "//eval/eval:expression_step_base", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], +) + +cc_test( + name = "instrumentation_test", + srcs = ["instrumentation_test.cc"], + deps = [ + ":constant_folding", + ":flat_expr_builder", + ":instrumentation", + ":regex_precompilation_optimization", + "//common:ast", + "//common:value", + "//eval/eval:evaluator_core", + "//extensions/protobuf:ast_converters", + "//internal:testing", + "//parser", + "//runtime:activation", + "//runtime:function_registry", + "//runtime:runtime_options", + "//runtime:standard_functions", + "//runtime:type_registry", + "//runtime/internal:runtime_env", + "//runtime/internal:runtime_env_testing", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], ) diff --git a/eval/compiler/cel_expression_builder_flat_impl.cc b/eval/compiler/cel_expression_builder_flat_impl.cc new file mode 100644 index 000000000..98ecc6aae --- /dev/null +++ b/eval/compiler/cel_expression_builder_flat_impl.cc @@ -0,0 +1,111 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "eval/compiler/cel_expression_builder_flat_impl.h" + +#include +#include +#include + +#include "cel/expr/checked.pb.h" +#include "cel/expr/syntax.pb.h" +#include "absl/base/macros.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "base/ast.h" +#include "common/native_type.h" +#include "eval/eval/cel_expression_flat_impl.h" +#include "eval/eval/direct_expression_step.h" +#include "eval/eval/evaluator_core.h" +#include "eval/public/cel_expression.h" +#include "extensions/protobuf/ast_converters.h" +#include "internal/status_macros.h" +#include "runtime/runtime_issue.h" + +namespace google::api::expr::runtime { + +using ::cel::Ast; +using ::cel::RuntimeIssue; +using ::cel::expr::CheckedExpr; +using ::cel::expr::Expr; // NOLINT: adjusted in OSS +using ::cel::expr::SourceInfo; + +absl::StatusOr> +CelExpressionBuilderFlatImpl::CreateExpression( + const Expr* expr, const SourceInfo* source_info, + std::vector* warnings) const { + ABSL_ASSERT(expr != nullptr); + CEL_ASSIGN_OR_RETURN( + std::unique_ptr converted_ast, + cel::extensions::CreateAstFromParsedExpr(*expr, source_info)); + return CreateExpressionImpl(std::move(converted_ast), warnings); +} + +absl::StatusOr> +CelExpressionBuilderFlatImpl::CreateExpression( + const Expr* expr, const SourceInfo* source_info) const { + return CreateExpression(expr, source_info, + /*warnings=*/nullptr); +} + +absl::StatusOr> +CelExpressionBuilderFlatImpl::CreateExpression( + const CheckedExpr* checked_expr, + std::vector* warnings) const { + ABSL_ASSERT(checked_expr != nullptr); + CEL_ASSIGN_OR_RETURN( + std::unique_ptr converted_ast, + cel::extensions::CreateAstFromCheckedExpr(*checked_expr)); + + return CreateExpressionImpl(std::move(converted_ast), warnings); +} + +absl::StatusOr> +CelExpressionBuilderFlatImpl::CreateExpression( + const CheckedExpr* checked_expr) const { + return CreateExpression(checked_expr, /*warnings=*/nullptr); +} + +absl::StatusOr> +CelExpressionBuilderFlatImpl::CreateExpressionImpl( + std::unique_ptr converted_ast, + std::vector* warnings) const { + std::vector issues; + auto* issues_ptr = (warnings != nullptr) ? &issues : nullptr; + + CEL_ASSIGN_OR_RETURN(FlatExpression impl, + flat_expr_builder_.CreateExpressionImpl( + std::move(converted_ast), issues_ptr)); + + if (issues_ptr != nullptr) { + for (const auto& issue : issues) { + warnings->push_back(issue.ToStatus()); + } + } + if (flat_expr_builder_.options().max_recursion_depth != 0 && + !impl.subexpressions().empty() && + // mainline expression is exactly one recursive step. + impl.subexpressions().front().size() == 1 && + impl.subexpressions().front().front()->GetNativeTypeId() == + cel::NativeTypeId::For()) { + return CelExpressionRecursiveImpl::Create(env_, std::move(impl)); + } + + return std::make_unique(env_, std::move(impl)); +} + +} // namespace google::api::expr::runtime diff --git a/eval/compiler/cel_expression_builder_flat_impl.h b/eval/compiler/cel_expression_builder_flat_impl.h new file mode 100644 index 000000000..6f47f4ec3 --- /dev/null +++ b/eval/compiler/cel_expression_builder_flat_impl.h @@ -0,0 +1,108 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_COMPILER_CEL_EXPRESSION_BUILDER_FLAT_IMPL_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_COMPILER_CEL_EXPRESSION_BUILDER_FLAT_IMPL_H_ + +#include +#include +#include +#include + +#include "cel/expr/checked.pb.h" +#include "cel/expr/syntax.pb.h" +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "base/ast.h" +#include "eval/compiler/flat_expr_builder.h" +#include "eval/public/cel_expression.h" +#include "eval/public/cel_function_registry.h" +#include "eval/public/cel_type_registry.h" +#include "runtime/internal/runtime_env.h" +#include "runtime/runtime_options.h" + +namespace google::api::expr::runtime { + +// CelExpressionBuilder implementation. +// Builds instances of CelExpressionFlatImpl. +class CelExpressionBuilderFlatImpl : public CelExpressionBuilder { + public: + CelExpressionBuilderFlatImpl( + absl_nonnull std::shared_ptr env, + const cel::RuntimeOptions& options) + : env_(std::move(env)), + flat_expr_builder_(env_, options, /*use_legacy_type_provider=*/true) { + ABSL_DCHECK(env_->IsInitialized()); + } + + explicit CelExpressionBuilderFlatImpl( + absl_nonnull std::shared_ptr env) + : CelExpressionBuilderFlatImpl(std::move(env), cel::RuntimeOptions()) {} + + absl::StatusOr> CreateExpression( + const cel::expr::Expr* expr, + const cel::expr::SourceInfo* source_info) const override; + + absl::StatusOr> CreateExpression( + const cel::expr::Expr* expr, + const cel::expr::SourceInfo* source_info, + std::vector* warnings) const override; + + absl::StatusOr> CreateExpression( + const cel::expr::CheckedExpr* checked_expr) const override; + + absl::StatusOr> CreateExpression( + const cel::expr::CheckedExpr* checked_expr, + std::vector* warnings) const override; + + FlatExprBuilder& flat_expr_builder() { return flat_expr_builder_; } + + void set_container(std::string container) override { + flat_expr_builder_.set_container(std::move(container)); + } + + // CelFunction registry. Extension function should be registered with it + // prior to expression creation. + CelFunctionRegistry* GetRegistry() const override { + return &env_->legacy_function_registry; + } + + // CEL Type registry. Provides a means to resolve the CEL built-in types to + // CelValue instances, and to extend the set of types and enums known to + // expressions by registering them ahead of time. + CelTypeRegistry* GetTypeRegistry() const override { + return &env_->legacy_type_registry; + } + + absl::string_view container() const override { + return flat_expr_builder_.container(); + } + + private: + absl::StatusOr> CreateExpressionImpl( + std::unique_ptr converted_ast, + std::vector* warnings) const; + + absl_nonnull std::shared_ptr env_; + FlatExprBuilder flat_expr_builder_; +}; + +} // namespace google::api::expr::runtime + +#endif // THIRD_PARTY_CEL_CPP_EVAL_COMPILER_CEL_EXPRESSION_BUILDER_FLAT_IMPL_H_ diff --git a/eval/compiler/cel_expression_builder_flat_impl_test.cc b/eval/compiler/cel_expression_builder_flat_impl_test.cc new file mode 100644 index 000000000..9802d2a05 --- /dev/null +++ b/eval/compiler/cel_expression_builder_flat_impl_test.cc @@ -0,0 +1,657 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Smoke tests for CelExpressionBuilderFlatImpl. This class is a thin wrapper +// over FlatExprBuilder, so most of the tests are just covering the conversion +// code from the legacy APIs to the implementation. See +// flat_expr_builder_test.cc for additional tests. +#include "eval/compiler/cel_expression_builder_flat_impl.h" + +#include +#include +#include +#include +#include + +#include "cel/expr/checked.pb.h" +#include "cel/expr/syntax.pb.h" +#include "absl/algorithm/container.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "eval/compiler/constant_folding.h" +#include "eval/compiler/regex_precompilation_optimization.h" +#include "eval/eval/cel_expression_flat_impl.h" +#include "eval/public/activation.h" +#include "eval/public/builtin_func_registrar.h" +#include "eval/public/cel_expression.h" +#include "eval/public/cel_function.h" +#include "eval/public/cel_value.h" +#include "eval/public/containers/container_backed_map_impl.h" +#include "eval/public/portable_cel_function_adapter.h" +#include "eval/public/structs/cel_proto_wrapper.h" +#include "eval/public/testing/matchers.h" +#include "extensions/bindings_ext.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "parser/macro.h" +#include "parser/parser.h" +#include "runtime/internal/runtime_env_testing.h" +#include "runtime/runtime_options.h" +#include "cel/expr/conformance/proto3/test_all_types.pb.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/message.h" +#include "google/protobuf/text_format.h" + +namespace google::api::expr::runtime { + +namespace { + +using ::absl_testing::StatusIs; +using ::cel::expr::conformance::proto3::NestedTestAllTypes; +using ::cel::expr::conformance::proto3::TestAllTypes; +using ::cel::runtime_internal::NewTestingRuntimeEnv; +using ::cel::expr::CheckedExpr; +using ::cel::expr::Expr; +using ::cel::expr::ParsedExpr; +using ::cel::expr::SourceInfo; +using ::google::api::expr::parser::Macro; +using ::google::api::expr::parser::Parse; +using ::google::api::expr::parser::ParseWithMacros; +using ::testing::_; +using ::testing::Contains; +using ::testing::HasSubstr; +using ::testing::IsNull; +using ::testing::NotNull; + +TEST(CelExpressionBuilderFlatImplTest, Error) { + Expr expr; + SourceInfo source_info; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid empty expression"))); +} + +TEST(CelExpressionBuilderFlatImplTest, ParsedExpr) { + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse("1 + 2")); + + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr plan, + builder.CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + Activation activation; + google::protobuf::Arena arena; + ASSERT_OK_AND_ASSIGN(CelValue result, plan->Evaluate(activation, &arena)); + EXPECT_THAT(result, test::IsCelInt64(3)); +} + +struct RecursiveTestCase { + std::string test_name; + std::string expr; + test::CelValueMatcher matcher; + std::string pb_expr; +}; + +class RecursivePlanTest : public ::testing::TestWithParam { + protected: + absl::Status SetupBuilder(CelExpressionBuilderFlatImpl& builder) { + builder.GetTypeRegistry()->RegisterEnum("TestEnum", + {{"FOO", 1}, {"BAR", 2}}); + + CEL_RETURN_IF_ERROR(RegisterBuiltinFunctions(builder.GetRegistry())); + return builder.GetRegistry()->RegisterLazyFunction(CelFunctionDescriptor( + "LazilyBoundMult", false, + {CelValue::Type::kInt64, CelValue::Type::kInt64})); + } + + absl::Status SetupActivation(Activation& activation, google::protobuf::Arena* arena) { + activation.InsertValue("int_1", CelValue::CreateInt64(1)); + activation.InsertValue("string_abc", CelValue::CreateStringView("abc")); + activation.InsertValue("string_def", CelValue::CreateStringView("def")); + auto* map = google::protobuf::Arena::Create(arena); + CEL_RETURN_IF_ERROR( + map->Add(CelValue::CreateStringView("a"), CelValue::CreateInt64(1))); + CEL_RETURN_IF_ERROR( + map->Add(CelValue::CreateStringView("b"), CelValue::CreateInt64(2))); + activation.InsertValue("map_var", CelValue::CreateMap(map)); + auto* msg = google::protobuf::Arena::Create(arena); + msg->mutable_child()->mutable_payload()->set_single_int64(42); + activation.InsertValue("struct_var", + CelProtoWrapper::CreateMessage(msg, arena)); + activation.InsertValue("TestEnum.BAR", CelValue::CreateInt64(-1)); + + CEL_RETURN_IF_ERROR(activation.InsertFunction( + PortableBinaryFunctionAdapter::Create( + "LazilyBoundMult", false, + [](google::protobuf::Arena*, int64_t lhs, int64_t rhs) -> int64_t { + return lhs * rhs; + }))); + + return absl::OkStatus(); + } +}; + +absl::StatusOr ParseTestCase(const RecursiveTestCase& test_case) { + static const std::vector* kMacros = []() { + auto* result = new std::vector(Macro::AllMacros()); + absl::c_copy(cel::extensions::bindings_macros(), + std::back_inserter(*result)); + return result; + }(); + + if (!test_case.expr.empty()) { + return ParseWithMacros(test_case.expr, *kMacros, ""); + } else if (!test_case.pb_expr.empty()) { + ParsedExpr result; + if (!google::protobuf::TextFormat::ParseFromString(test_case.pb_expr, &result)) { + return absl::InvalidArgumentError("Failed to parse proto"); + } + return result; + } + return absl::InvalidArgumentError("No expression provided"); +} + +TEST_P(RecursivePlanTest, ParsedExprRecursiveImpl) { + const RecursiveTestCase& test_case = GetParam(); + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, ParseTestCase(test_case)); + cel::RuntimeOptions options; + options.container = "cel.expr.conformance.proto3"; + google::protobuf::Arena arena; + // Unbounded. + options.max_recursion_depth = -1; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + + ASSERT_OK(SetupBuilder(builder)); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr plan, + builder.CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + EXPECT_THAT(dynamic_cast(plan.get()), + NotNull()); + + Activation activation; + + ASSERT_OK(SetupActivation(activation, &arena)); + + ASSERT_OK_AND_ASSIGN(CelValue result, plan->Evaluate(activation, &arena)); + EXPECT_THAT(result, test_case.matcher); +} + +TEST_P(RecursivePlanTest, ParsedExprRecursiveOptimizedImpl) { + const RecursiveTestCase& test_case = GetParam(); + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, ParseTestCase(test_case)); + cel::RuntimeOptions options; + options.container = "cel.expr.conformance.proto3"; + google::protobuf::Arena arena; + // Unbounded. + options.max_recursion_depth = -1; + options.enable_comprehension_list_append = true; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + + ASSERT_OK(SetupBuilder(builder)); + + builder.flat_expr_builder().AddProgramOptimizer( + cel::runtime_internal::CreateConstantFoldingOptimizer()); + builder.flat_expr_builder().AddProgramOptimizer( + CreateRegexPrecompilationExtension(options.regex_max_program_size)); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr plan, + builder.CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + EXPECT_THAT(dynamic_cast(plan.get()), + NotNull()); + + Activation activation; + + ASSERT_OK(SetupActivation(activation, &arena)); + + ASSERT_OK_AND_ASSIGN(CelValue result, plan->Evaluate(activation, &arena)); + EXPECT_THAT(result, test_case.matcher); +} + +TEST_P(RecursivePlanTest, ParsedExprRecursiveTraceSupport) { + const RecursiveTestCase& test_case = GetParam(); + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, ParseTestCase(test_case)); + cel::RuntimeOptions options; + options.container = "cel.expr.conformance.proto3"; + google::protobuf::Arena arena; + auto cb = [](int64_t id, const CelValue& value, google::protobuf::Arena* arena) { + return absl::OkStatus(); + }; + // Unbounded. + options.max_recursion_depth = -1; + options.enable_recursive_tracing = true; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + + ASSERT_OK(SetupBuilder(builder)); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr plan, + builder.CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + EXPECT_THAT(dynamic_cast(plan.get()), + NotNull()); + + Activation activation; + + ASSERT_OK(SetupActivation(activation, &arena)); + + ASSERT_OK_AND_ASSIGN(CelValue result, plan->Trace(activation, &arena, cb)); + EXPECT_THAT(result, test_case.matcher); +} + +TEST_P(RecursivePlanTest, Disabled) { + google::protobuf::LinkMessageReflection(); + + const RecursiveTestCase& test_case = GetParam(); + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, ParseTestCase(test_case)); + cel::RuntimeOptions options; + options.container = "cel.expr.conformance.proto3"; + google::protobuf::Arena arena; + // disabled. + options.max_recursion_depth = 0; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + + ASSERT_OK(SetupBuilder(builder)); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr plan, + builder.CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + EXPECT_THAT(dynamic_cast(plan.get()), + IsNull()); + + Activation activation; + + ASSERT_OK(SetupActivation(activation, &arena)); + + ASSERT_OK_AND_ASSIGN(CelValue result, plan->Evaluate(activation, &arena)); + EXPECT_THAT(result, test_case.matcher); +} + +INSTANTIATE_TEST_SUITE_P( + RecursivePlanTest, RecursivePlanTest, + testing::ValuesIn(std::vector{ + {"constant", "'abc'", test::IsCelString("abc")}, + {"call", "1 + 2", test::IsCelInt64(3)}, + {"nested_call", "1 + 1 + 1 + 1", test::IsCelInt64(4)}, + {"and", "true && false", test::IsCelBool(false)}, + {"or", "true || false", test::IsCelBool(true)}, + {"ternary", "(true || false) ? 2 + 2 : 3 + 3", test::IsCelInt64(4)}, + {"create_list", "3 in [1, 2, 3]", test::IsCelBool(true)}, + {"create_list_complex", "3 in [2 / 2, 4 / 2, 6 / 2]", + test::IsCelBool(true)}, + {"ident", "int_1 == 1", test::IsCelBool(true)}, + {"ident_complex", "int_1 + 2 > 4 ? string_abc : string_def", + test::IsCelString("def")}, + {"select", "struct_var.child.payload.single_int64", + test::IsCelInt64(42)}, + {"nested_select", "[map_var.a, map_var.b].size() == 2", + test::IsCelBool(true)}, + {"map_index", "map_var['b']", test::IsCelInt64(2)}, + {"list_index", "[1, 2, 3][1]", test::IsCelInt64(2)}, + {"compre_exists", "[1, 2, 3, 4].exists(x, x == 3)", + test::IsCelBool(true)}, + {"compre_map", "8 in [1, 2, 3, 4].map(x, x * 2)", + test::IsCelBool(true)}, + {"map_var_compre_exists", "map_var.exists(key, key == 'b')", + test::IsCelBool(true)}, + {"map_compre_exists", "{'a': 1, 'b': 2}.exists(k, k == 'b')", + test::IsCelBool(true)}, + {"create_map", "{'a': 42, 'b': 0, 'c': 0}.size()", test::IsCelInt64(3)}, + {"create_struct", + "NestedTestAllTypes{payload: TestAllTypes{single_int64: " + "-42}}.payload.single_int64", + test::IsCelInt64(-42)}, + {"bind", R"(cel.bind(x, "1", x + x + x + x))", + test::IsCelString("1111")}, + {"nested_bind", R"(cel.bind(x, 20, cel.bind(y, 30, x + y)))", + test::IsCelInt64(50)}, + {"bind_with_comprehensions", + R"(cel.bind(x, [1, 2], cel.bind(y, x.map(z, z * 2), y.exists(z, z == 4))))", + test::IsCelBool(true)}, + {"shadowable_value_default", R"(TestEnum.FOO == 1)", + test::IsCelBool(true)}, + {"shadowable_value_shadowed", R"(TestEnum.BAR == -1)", + test::IsCelBool(true)}, + {"lazily_resolved_function", "LazilyBoundMult(123, 2) == 246", + test::IsCelBool(true)}, + {"re_matches", "matches(string_abc, '[ad][be][cf]')", + test::IsCelBool(true)}, + {"re_matches_receiver", + "(string_abc + string_def).matches(r'(123)?' + r'abc' + r'def')", + test::IsCelBool(true)}, + {"block", "", test::IsCelBool(true), + R"pb( + expr { + id: 1 + call_expr { + function: "cel.@block" + args { + id: 2 + list_expr { + elements { const_expr { int64_value: 8 } } + elements { const_expr { int64_value: 10 } } + } + } + args { + id: 3 + call_expr { + function: "_<_" + args { ident_expr { name: "@index0" } } + args { ident_expr { name: "@index1" } } + } + } + } + })pb"}, + {"block_with_comprehensions", "", test::IsCelBool(true), + // Something like: + // variables: + // - users: {'bob': ['bar'], 'alice': ['foo', 'bar']} + // - somone_has_bar: users.exists(u, 'bar' in users[u]) + // policy: + // - someone_has_bar && !users.exists(u, u == 'eve')) + // + R"pb( + expr { + call_expr { + function: "cel.@block" + args { + list_expr { + elements { + struct_expr: { + entries: { + map_key: { const_expr: { string_value: "bob" } } + value: { + list_expr: { + elements: { const_expr: { string_value: "bar" } } + } + } + } + entries: { + map_key: { const_expr: { string_value: "alice" } } + value: { + list_expr: { + elements: { const_expr: { string_value: "bar" } } + elements: { const_expr: { string_value: "foo" } } + } + } + } + } + } + elements { + id: 16 + comprehension_expr: { + iter_var: "u" + iter_range: { + id: 1 + ident_expr: { name: "@index0" } + } + accu_var: "__result__" + accu_init: { + id: 9 + const_expr: { bool_value: false } + } + loop_condition: { + id: 12 + call_expr: { + function: "@not_strictly_false" + args: { + id: 11 + call_expr: { + function: "!_" + args: { + id: 10 + ident_expr: { name: "__result__" } + } + } + } + } + } + loop_step: { + id: 14 + call_expr: { + function: "_||_" + args: { + id: 13 + ident_expr: { name: "__result__" } + } + args: { + id: 5 + call_expr: { + function: "@in" + args: { + id: 4 + const_expr: { string_value: "bar" } + } + args: { + id: 7 + call_expr: { + function: "_[_]" + args: { + id: 6 + ident_expr: { name: "@index0" } + } + args: { + id: 8 + ident_expr: { name: "u" } + } + } + } + } + } + } + } + result: { + id: 15 + ident_expr: { name: "__result__" } + } + } + } + } + } + args { + id: 17 + call_expr: { + function: "_&&_" + args: { + id: 1 + ident_expr: { name: "@index1" } + } + args: { + id: 2 + call_expr: { + function: "!_" + args: { + id: 16 + comprehension_expr: { + iter_var: "u" + iter_range: { + id: 3 + ident_expr: { name: "@index0" } + } + accu_var: "__result__" + accu_init: { + id: 9 + const_expr: { bool_value: false } + } + loop_condition: { + id: 12 + call_expr: { + function: "@not_strictly_false" + args: { + id: 11 + call_expr: { + function: "!_" + args: { + id: 10 + ident_expr: { name: "__result__" } + } + } + } + } + } + loop_step: { + id: 14 + call_expr: { + function: "_||_" + args: { + id: 13 + ident_expr: { name: "__result__" } + } + args: { + id: 7 + call_expr: { + function: "_==_" + args: { + id: 6 + ident_expr: { name: "u" } + } + args: { + id: 8 + const_expr: { string_value: "eve" } + } + } + } + } + } + result: { + id: 15 + ident_expr: { name: "__result__" } + } + } + } + } + } + } + } + } + })pb"}}), + + [](const testing::TestParamInfo& info) -> std::string { + return info.param.test_name; + }); + +TEST(CelExpressionBuilderFlatImplTest, ParsedExprWithWarnings) { + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse("1 + 2")); + cel::RuntimeOptions options; + options.fail_on_warnings = false; + + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + std::vector warnings; + + ASSERT_OK_AND_ASSIGN( + std::unique_ptr plan, + builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info(), + &warnings)); + + EXPECT_THAT(warnings, Contains(StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("No overloads")))); + + Activation activation; + google::protobuf::Arena arena; + ASSERT_OK_AND_ASSIGN(CelValue result, plan->Evaluate(activation, &arena)); + EXPECT_THAT(result, test::IsCelError( + StatusIs(_, HasSubstr("No matching overloads")))); +} + +TEST(CelExpressionBuilderFlatImplTest, EmptyLegacyTypeViewUnsupported) { + // Creating type values directly (instead of using the builtin functions and + // identifiers from the type registry) is not recommended for CEL users. The + // name is expected to be non-empty. + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse("x")); + cel::RuntimeOptions options; + + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr plan, + builder.CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + Activation activation; + activation.InsertValue("x", CelValue::CreateCelTypeView("")); + google::protobuf::Arena arena; + ASSERT_THAT(plan->Evaluate(activation, &arena), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(CelExpressionBuilderFlatImplTest, LegacyTypeViewSupported) { + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse("x")); + cel::RuntimeOptions options; + + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr plan, + builder.CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + Activation activation; + activation.InsertValue("x", CelValue::CreateCelTypeView("MyType")); + google::protobuf::Arena arena; + ASSERT_OK_AND_ASSIGN(CelValue result, plan->Evaluate(activation, &arena)); + ASSERT_TRUE(result.IsCelType()); + EXPECT_EQ(result.CelTypeOrDie().value(), "MyType"); +} + +TEST(CelExpressionBuilderFlatImplTest, CheckedExpr) { + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse("1 + 2")); + CheckedExpr checked_expr; + checked_expr.mutable_expr()->Swap(parsed_expr.mutable_expr()); + checked_expr.mutable_source_info()->Swap(parsed_expr.mutable_source_info()); + + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr plan, + builder.CreateExpression(&checked_expr)); + + Activation activation; + google::protobuf::Arena arena; + ASSERT_OK_AND_ASSIGN(CelValue result, plan->Evaluate(activation, &arena)); + EXPECT_THAT(result, test::IsCelInt64(3)); +} + +TEST(CelExpressionBuilderFlatImplTest, CheckedExprWithWarnings) { + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse("1 + 2")); + CheckedExpr checked_expr; + checked_expr.mutable_expr()->Swap(parsed_expr.mutable_expr()); + checked_expr.mutable_source_info()->Swap(parsed_expr.mutable_source_info()); + cel::RuntimeOptions options; + options.fail_on_warnings = false; + + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + std::vector warnings; + + ASSERT_OK_AND_ASSIGN(std::unique_ptr plan, + builder.CreateExpression(&checked_expr, &warnings)); + + EXPECT_THAT(warnings, Contains(StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("No overloads")))); + + Activation activation; + google::protobuf::Arena arena; + ASSERT_OK_AND_ASSIGN(CelValue result, plan->Evaluate(activation, &arena)); + EXPECT_THAT(result, test::IsCelError( + StatusIs(_, HasSubstr("No matching overloads")))); +} + +} // namespace + +} // namespace google::api::expr::runtime diff --git a/eval/compiler/check_ast_extensions.cc b/eval/compiler/check_ast_extensions.cc new file mode 100644 index 000000000..37181b535 --- /dev/null +++ b/eval/compiler/check_ast_extensions.cc @@ -0,0 +1,58 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/compiler/check_ast_extensions.h" + +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "common/ast.h" +#include "common/ast/metadata.h" + +namespace google::api::expr::runtime { + +absl::StatusOr> +ExtractAndValidateRuntimeExtensions(const cel::Ast& ast) { + std::vector runtime_extensions; + absl::flat_hash_set seen_extension_ids; + + for (const cel::ExtensionSpec& extension : ast.source_info().extensions()) { + bool is_runtime = false; + for (const cel::ExtensionSpec::Component& component : + extension.affected_components()) { + if (component == cel::ExtensionSpec::Component::kRuntime) { + is_runtime = true; + break; + } + } + + if (!is_runtime) { + continue; + } + + if (!seen_extension_ids.insert(extension.id()).second) { + return absl::InvalidArgumentError( + absl::StrCat("duplicate extension ID: ", extension.id())); + } + runtime_extensions.push_back(extension); + } + + return runtime_extensions; +} + +} // namespace google::api::expr::runtime diff --git a/eval/compiler/check_ast_extensions.h b/eval/compiler/check_ast_extensions.h new file mode 100644 index 000000000..443c6ac09 --- /dev/null +++ b/eval/compiler/check_ast_extensions.h @@ -0,0 +1,34 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_COMPILER_CHECK_AST_EXTENSIONS_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_COMPILER_CHECK_AST_EXTENSIONS_H_ + +#include + +#include "absl/status/statusor.h" +#include "common/ast.h" +#include "common/ast/metadata.h" + +namespace google::api::expr::runtime { + +// Extracts and validates extension tags from the AST `ast` that affect the +// runtime component. Returns the validated list of runtime extensions, or an +// error if there are multiple runtime extensions with the same ID. +absl::StatusOr> +ExtractAndValidateRuntimeExtensions(const cel::Ast& ast); + +} // namespace google::api::expr::runtime + +#endif // THIRD_PARTY_CEL_CPP_EVAL_COMPILER_CHECK_AST_EXTENSIONS_H_ diff --git a/eval/compiler/check_ast_extensions_test.cc b/eval/compiler/check_ast_extensions_test.cc new file mode 100644 index 000000000..9e5838905 --- /dev/null +++ b/eval/compiler/check_ast_extensions_test.cc @@ -0,0 +1,110 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/compiler/check_ast_extensions.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "common/ast.h" +#include "common/ast/metadata.h" +#include "common/expr.h" +#include "internal/testing.h" + +namespace google::api::expr::runtime { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::IsOkAndHolds; +using ::absl_testing::StatusIs; +using ::cel::Ast; +using ::cel::Expr; +using ::cel::ExtensionSpec; +using ::cel::SourceInfo; +using ::testing::ElementsAre; +using ::testing::Eq; +using ::testing::Property; +using ::testing::SizeIs; + +TEST(ExtractAndValidateRuntimeExtensionsTest, EmptyExtensions) { + Ast ast(Expr{}, SourceInfo{}); + EXPECT_THAT(ExtractAndValidateRuntimeExtensions(ast), + IsOkAndHolds(SizeIs(0))); +} + +TEST(ExtractAndValidateRuntimeExtensionsTest, FiltersNonRuntimeExtensions) { + SourceInfo source_info; + source_info.mutable_extensions().push_back( + ExtensionSpec("ext1", nullptr, {ExtensionSpec::Component::kParser})); + source_info.mutable_extensions().push_back( + ExtensionSpec("ext2", nullptr, {ExtensionSpec::Component::kTypeChecker})); + + Ast ast(Expr(), std::move(source_info)); + + EXPECT_THAT(ExtractAndValidateRuntimeExtensions(ast), + IsOkAndHolds(SizeIs(0))); +} + +TEST(ExtractAndValidateRuntimeExtensionsTest, ExtractsRuntimeExtensions) { + SourceInfo source_info; + source_info.mutable_extensions().push_back( + ExtensionSpec("ext1", nullptr, {ExtensionSpec::Component::kRuntime})); + source_info.mutable_extensions().push_back(ExtensionSpec( + "ext2", nullptr, + {ExtensionSpec::Component::kParser, ExtensionSpec::Component::kRuntime})); + source_info.mutable_extensions().push_back( + ExtensionSpec("ext3", nullptr, {ExtensionSpec::Component::kParser})); + + Ast ast(Expr(), std::move(source_info)); + + auto result = ExtractAndValidateRuntimeExtensions(ast); + ASSERT_THAT(result, IsOk()); + EXPECT_THAT(*result, ElementsAre(Property(&ExtensionSpec::id, Eq("ext1")), + Property(&ExtensionSpec::id, Eq("ext2")))); +} + +TEST(ExtractAndValidateRuntimeExtensionsTest, FailsOnDuplicateRuntimeID) { + SourceInfo source_info; + source_info.mutable_extensions().push_back( + ExtensionSpec("ext1", nullptr, {ExtensionSpec::Component::kRuntime})); + source_info.mutable_extensions().push_back(ExtensionSpec( + "ext1", nullptr, + {ExtensionSpec::Component::kParser, ExtensionSpec::Component::kRuntime})); + + Ast ast(Expr(), std::move(source_info)); + + EXPECT_THAT(ExtractAndValidateRuntimeExtensions(ast), + StatusIs(absl::StatusCode::kInvalidArgument, + "duplicate extension ID: ext1")); +} + +TEST(ExtractAndValidateRuntimeExtensionsTest, IgnoresDuplicateNonRuntimeID) { + SourceInfo source_info; + source_info.mutable_extensions().push_back( + ExtensionSpec("ext1", nullptr, {ExtensionSpec::Component::kRuntime})); + source_info.mutable_extensions().push_back( + ExtensionSpec("ext1", nullptr, {ExtensionSpec::Component::kParser})); + + Ast ast(Expr(), std::move(source_info)); + + auto result = ExtractAndValidateRuntimeExtensions(ast); + ASSERT_THAT(result, IsOk()); + EXPECT_THAT(*result, ElementsAre(Property(&ExtensionSpec::id, Eq("ext1")))); +} + +} // namespace +} // namespace google::api::expr::runtime diff --git a/eval/compiler/comprehension_vulnerability_check.cc b/eval/compiler/comprehension_vulnerability_check.cc new file mode 100644 index 000000000..ca3905024 --- /dev/null +++ b/eval/compiler/comprehension_vulnerability_check.cc @@ -0,0 +1,275 @@ +// +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "eval/compiler/comprehension_vulnerability_check.h" + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "absl/types/variant.h" +#include "base/builtins.h" +#include "common/ast.h" +#include "common/constant.h" +#include "common/expr.h" +#include "eval/compiler/flat_expr_builder_extensions.h" + +namespace google::api::expr::runtime { + +namespace { + +using ::cel::CallExpr; +using ::cel::ComprehensionExpr; +using ::cel::Constant; +using ::cel::Expr; +using ::cel::IdentExpr; +using ::cel::ListExpr; +using ::cel::MapExpr; +using ::cel::SelectExpr; +using ::cel::StructExpr; +using ::cel::UnspecifiedExpr; + +// ComprehensionAccumulationReferences recursively walks an expression to count +// the locations where the given accumulation var_name is referenced. +// +// The purpose of this function is to detect cases where the accumulation +// variable might be used in hand-rolled ASTs that cause exponential memory +// consumption. The var_name is generally not accessible by CEL expression +// writers, only by macro authors. However, a hand-rolled AST makes it possible +// to misuse the accumulation variable. +// +// Limitations: +// - This check only covers standard operators and functions. +// Extension functions may cause the same issue if they allocate an amount of +// memory that is dependent on the size of the inputs. +// +// - This check is not exhaustive. There may be ways to construct an AST to +// trigger exponential memory growth not captured by this check. +// +// The algorithm for reference counting is as follows: +// +// * Calls - If the call is a concatenation operator, sum the number of places +// where the variable appears within the call, as this could result +// in memory explosion if the accumulation variable type is a list +// or string. Otherwise, return 0. +// +// accu: ["hello"] +// expr: accu + accu // memory grows exponentionally +// +// * CreateList - If the accumulation var_name appears within multiple elements +// of a CreateList call, this means that the accumulation is +// generating an ever-expanding tree of values that will likely +// exhaust memory. +// +// accu: ["hello"] +// expr: [accu, accu] // memory grows exponentially +// +// * CreateStruct - If the accumulation var_name as an entry within the +// creation of a map or message value, then it's possible that the +// comprehension is accumulating an ever-expanding tree of values. +// +// accu: {"key": "val"} +// expr: {1: accu, 2: accu} +// +// * Comprehension - If the accumulation var_name is not shadowed by a nested +// iter_var or accu_var, then it may be accmulating memory within a +// nested context. The accumulation may occur on either the +// comprehension loop_step or result step. +// +// Since this behavior generally only occurs within hand-rolled ASTs, it is +// very reasonable to opt-in to this check only when using human authored ASTs. +int ComprehensionAccumulationReferences(const cel::Expr& expr, + absl::string_view var_name) { + struct Handler { + const Expr& expr; + absl::string_view var_name; + + int operator()(const CallExpr& call) { + int references = 0; + absl::string_view function = call.function(); + // Return the maximum reference count of each side of the ternary branch. + if (function == cel::builtin::kTernary && call.args().size() == 3) { + return std::max( + ComprehensionAccumulationReferences(call.args()[1], var_name), + ComprehensionAccumulationReferences(call.args()[2], var_name)); + } + // Return the number of times the accumulator var_name appears in the add + // expression. There's no arg size check on the add as it may become a + // variadic add at a future date. + if (function == cel::builtin::kAdd) { + for (int i = 0; i < call.args().size(); i++) { + references += + ComprehensionAccumulationReferences(call.args()[i], var_name); + } + + return references; + } + // Return whether the accumulator var_name is used as the operand in an + // index expression or in the identity `dyn` function. + if ((function == cel::builtin::kIndex && call.args().size() == 2) || + (function == cel::builtin::kDyn && call.args().size() == 1)) { + return ComprehensionAccumulationReferences(call.args()[0], var_name); + } + return 0; + } + int operator()(const ComprehensionExpr& comprehension) { + absl::string_view accu_var = comprehension.accu_var(); + absl::string_view iter_var = comprehension.iter_var(); + + int result_references = 0; + int loop_step_references = 0; + int sum_of_accumulator_references = 0; + + // The accumulation or iteration variable shadows the var_name and so will + // not manipulate the target var_name in a nested comprehension scope. + if (accu_var != var_name && iter_var != var_name) { + loop_step_references = ComprehensionAccumulationReferences( + comprehension.loop_step(), var_name); + } + + // Accumulator variable (but not necessarily iter var) can shadow an + // outer accumulator variable in the result sub-expression. + if (accu_var != var_name) { + result_references = ComprehensionAccumulationReferences( + comprehension.result(), var_name); + } + + // Count the raw number of times the accumulator variable was referenced. + // This is to account for cases where the outer accumulator is shadowed by + // the inner accumulator, while the inner accumulator is being used as the + // iterable range. + // + // An equivalent expression to this problem: + // + // outer_accu := outer_accu + // for y in outer_accu: + // outer_accu += input + // return outer_accu + + // If this is overly restrictive (Ex: when generalized reducers is + // implemented), we may need to revisit this solution + + sum_of_accumulator_references = ComprehensionAccumulationReferences( + comprehension.accu_init(), var_name); + + sum_of_accumulator_references += ComprehensionAccumulationReferences( + comprehension.iter_range(), var_name); + + // Count the number of times the accumulator var_name within the loop_step + // or the nested comprehension result. + // + // This doesn't cover cases where the inner accumulator accumulates the + // outer accumulator then is returned in the inner comprehension result. + return std::max({loop_step_references, result_references, + sum_of_accumulator_references}); + } + + int operator()(const ListExpr& list) { + // Count the number of times the accumulator var_name appears within a + // create list expression's elements. + int references = 0; + for (int i = 0; i < list.elements().size(); i++) { + references += ComprehensionAccumulationReferences( + list.elements()[i].expr(), var_name); + } + return references; + } + + int operator()(const StructExpr& map) { + // Count the number of times the accumulation variable occurs within + // entry values. + int references = 0; + for (int i = 0; i < map.fields().size(); i++) { + const auto& entry = map.fields()[i]; + if (entry.has_value()) { + references += + ComprehensionAccumulationReferences(entry.value(), var_name); + } + } + return references; + } + + int operator()(const MapExpr& map) { + // Count the number of times the accumulation variable occurs within + // entry values. + int references = 0; + for (int i = 0; i < map.entries().size(); i++) { + const auto& entry = map.entries()[i]; + if (entry.has_value()) { + references += + ComprehensionAccumulationReferences(entry.value(), var_name); + } + } + return references; + } + + int operator()(const SelectExpr& select) { + // Test only expressions have a boolean return and thus cannot easily + // allocate large amounts of memory. + if (select.test_only()) { + return 0; + } + // Return whether the accumulator var_name appears within a non-test + // select operand. + return ComprehensionAccumulationReferences(select.operand(), var_name); + } + + int operator()(const IdentExpr& ident) { + // Return whether the identifier name equals the accumulator var_name. + return ident.name() == var_name ? 1 : 0; + } + + int operator()(const Constant& constant) { return 0; } + + int operator()(const UnspecifiedExpr&) { return 0; } + } handler{expr, var_name}; + return absl::visit(handler, expr.kind()); +} + +bool ComprehensionHasMemoryExhaustionVulnerability( + const ComprehensionExpr& comprehension) { + absl::string_view accu_var = comprehension.accu_var(); + const auto& loop_step = comprehension.loop_step(); + return ComprehensionAccumulationReferences(loop_step, accu_var) >= 2; +} + +class ComprehensionVulnerabilityCheck : public ProgramOptimizer { + public: + absl::Status OnPreVisit(PlannerContext& context, const Expr& node) override { + if (node.has_comprehension_expr() && + ComprehensionHasMemoryExhaustionVulnerability( + node.comprehension_expr())) { + return absl::InvalidArgumentError( + "Comprehension contains memory exhaustion vulnerability"); + } + return absl::OkStatus(); + } + + absl::Status OnPostVisit(PlannerContext& context, + const cel::Expr& node) override { + return absl::OkStatus(); + } +}; + +} // namespace + +ProgramOptimizerFactory CreateComprehensionVulnerabilityCheck() { + return [](PlannerContext&, const cel::Ast& ast) { + return std::make_unique(); + }; +} + +} // namespace google::api::expr::runtime diff --git a/eval/compiler/comprehension_vulnerability_check.h b/eval/compiler/comprehension_vulnerability_check.h new file mode 100644 index 000000000..5dd6615ac --- /dev/null +++ b/eval/compiler/comprehension_vulnerability_check.h @@ -0,0 +1,51 @@ +// +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_COMPILER_COMPREHENSION_VULNERABILITY_CHECK_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_COMPILER_COMPREHENSION_VULNERABILITY_CHECK_H_ + +#include "eval/compiler/flat_expr_builder_extensions.h" + +namespace google::api::expr::runtime { + +// Create a program optimizer that checks for memory consumption vulnerability +// in comprehensions. +// +// Hand-rolled ASTs or custom Macro implementations can reference the implicit +// accumulator variable in comprehensions to generate objects exponential in the +// size of the inputs. Type checked expressions using the built-in macros and +// functions are not susceptible to this. +// +// This check is not exhaustive, but will catch most accidental triggers of +// this behavior in the standard env. It does not consider custom extension +// functions. +// +// This implementation recursively traverses the AST, so it is not safe for +// deeply nested ASTs or in environments with smaller stack limits. +// +// conceptual example with a generalized reducer macro: +// [1, 2, 3, 4] +// .reduce( +// /*iter_var=*/ unused, +// /*accu_var=*/ accu, +// /*accu_init=*/ [1], +// /*loop_step=*/ accu + accu, +// /*result=*/ accu) +// resulting list sizes per iteration: 2, 4, 8, 16. +ProgramOptimizerFactory CreateComprehensionVulnerabilityCheck(); + +} // namespace google::api::expr::runtime + +#endif // THIRD_PARTY_CEL_CPP_EVAL_COMPILER_COMPREHENSION_VULNERABILITY_CHECK_H_ diff --git a/eval/compiler/constant_folding.cc b/eval/compiler/constant_folding.cc index c93c8a750..118fc94c5 100644 --- a/eval/compiler/constant_folding.cc +++ b/eval/compiler/constant_folding.cc @@ -1,230 +1,280 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #include "eval/compiler/constant_folding.h" -#include "absl/strings/str_cat.h" +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "base/builtins.h" +#include "base/type_provider.h" +#include "common/ast.h" +#include "common/constant.h" +#include "common/expr.h" +#include "common/value.h" +#include "eval/compiler/flat_expr_builder_extensions.h" +#include "eval/compiler/resolver.h" #include "eval/eval/const_value_step.h" -#include "eval/public/cel_builtins.h" -#include "eval/public/cel_function_registry.h" -#include "eval/public/containers/container_backed_list_impl.h" +#include "eval/eval/evaluator_core.h" +#include "internal/status_macros.h" +#include "runtime/activation.h" +#include "runtime/internal/convert_constant.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace cel::runtime_internal { namespace { -using google::api::expr::v1alpha1::Expr; +using ::cel::Expr; +using ::cel::builtin::kAnd; +using ::cel::builtin::kOr; +using ::cel::builtin::kTernary; +using ::cel::runtime_internal::ConvertConstant; +using ::google::api::expr::runtime::CreateConstValueDirectStep; +using ::google::api::expr::runtime::CreateConstValueStep; +using ::google::api::expr::runtime::EvaluationListener; +using ::google::api::expr::runtime::ExecutionFrame; +using ::google::api::expr::runtime::ExecutionPath; +using ::google::api::expr::runtime::ExecutionPathView; +using ::google::api::expr::runtime::FlatExpressionEvaluatorState; +using ::google::api::expr::runtime::PlannerContext; +using ::google::api::expr::runtime::ProgramOptimizer; +using ::google::api::expr::runtime::ProgramOptimizerFactory; +using ::google::api::expr::runtime::Resolver; + +enum class IsConst { + kConditional, + kNonConst, +}; -class ConstantFoldingTransform { +class ConstantFoldingExtension : public ProgramOptimizer { public: - ConstantFoldingTransform( - const CelFunctionRegistry& registry, google::protobuf::Arena* arena, - absl::flat_hash_map& constant_idents) - : registry_(registry), - arena_(arena), - constant_idents_(constant_idents), - counter_(0) {} - - // Copies the expression by pulling out constant sub-expressions into - // CelValue idents. Returns true if the expression is a constant. - bool Transform(const Expr& expr, Expr* out) { - out->set_id(expr.id()); - switch (expr.expr_kind_case()) { - case Expr::kConstExpr: { - // create a constant that references the input expression data - // since the output expression is temporary - auto value = ConvertConstant(&expr.const_expr()); - if (value.has_value()) { - makeConstant(value.value(), out); - return true; - } else { - out->mutable_const_expr()->MergeFrom(expr.const_expr()); - return false; - } - } - case Expr::kIdentExpr: - out->mutable_ident_expr()->set_name(expr.ident_expr().name()); - return false; - case Expr::kSelectExpr: { - auto select_expr = out->mutable_select_expr(); - Transform(expr.select_expr().operand(), select_expr->mutable_operand()); - select_expr->set_field(expr.select_expr().field()); - select_expr->set_test_only(expr.select_expr().test_only()); - return false; - } - case Expr::kCallExpr: { - auto call_expr = out->mutable_call_expr(); - const bool receiver_style = expr.call_expr().has_target(); - const int arg_num = expr.call_expr().args_size(); - bool all_constant = true; - if (receiver_style) { - all_constant = Transform(expr.call_expr().target(), - call_expr->mutable_target()) && - all_constant; - } - call_expr->set_function(expr.call_expr().function()); - for (int i = 0; i < arg_num; i++) { - all_constant = - Transform(expr.call_expr().args(i), call_expr->add_args()) && - all_constant; - } - // short-circuiting affects evaluation of logic combinators, so we do - // not fold them here - if (!all_constant || call_expr->function() == builtin::kAnd || - call_expr->function() == builtin::kOr || - call_expr->function() == builtin::kTernary) { - return false; - } + ConstantFoldingExtension( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + absl_nullable std::shared_ptr shared_arena, + google::protobuf::Arena* absl_nonnull arena, + absl_nullable std::shared_ptr + shared_message_factory, + google::protobuf::MessageFactory* absl_nonnull message_factory, + const TypeProvider& type_provider) + : shared_arena_(std::move(shared_arena)), + shared_message_factory_(std::move(shared_message_factory)), + state_(kDefaultStackLimit, kComprehensionSlotCount, type_provider, + descriptor_pool, message_factory, arena) {} - // compute argument list - const int arg_size = arg_num + (receiver_style ? 1 : 0); - std::vector arg_types(arg_size, CelValue::Type::kAny); - auto overloads = registry_.FindOverloads(call_expr->function(), - receiver_style, arg_types); + absl::Status OnPreVisit(google::api::expr::runtime::PlannerContext& context, + const Expr& node) override; + absl::Status OnPostVisit(google::api::expr::runtime::PlannerContext& context, + const Expr& node) override; - // do not proceed if there are no overloads registered - if (overloads.empty()) { - return false; - } + private: + // Most constant folding evaluations are simple + // binary operators. + static constexpr size_t kDefaultStackLimit = 4; - std::vector arg_values; - arg_values.reserve(arg_size); - if (receiver_style) { - arg_values.push_back(removeConstant(call_expr->target())); - } - for (int i = 0; i < arg_num; i++) { - arg_values.push_back(removeConstant(call_expr->args(i))); - } + // Comprehensions are not evaluated -- the current implementation can't detect + // if the comprehension variables are only used in a const way. + static constexpr size_t kComprehensionSlotCount = 0; - // compute function overload - // consider consolidating the logic with FunctionStep - const CelFunction* matched_function = nullptr; - for (auto overload : overloads) { - if (overload->MatchArguments(arg_values)) { - matched_function = overload; - } - } - if (matched_function == nullptr) { - // propagate argument errors up the expression - for (const CelValue& arg : arg_values) { - if (arg.IsError()) { - makeConstant(arg, out); - return true; - } - } - makeConstant( - CreateNoMatchingOverloadError(arena_, call_expr->function()), - out); - return true; - } - CelValue result; - auto status = matched_function->Evaluate(arg_values, &result, arena_); - if (status.ok()) { - makeConstant(result, out); - } else { - makeConstant( - CreateErrorValue(arena_, status.message(), status.code()), out); - } - return true; + absl_nullable std::shared_ptr shared_arena_; + ABSL_ATTRIBUTE_UNUSED + absl_nullable std::shared_ptr shared_message_factory_; + Activation empty_; + FlatExpressionEvaluatorState state_; + + std::vector is_const_; +}; + +IsConst IsConstExpr(const Expr& expr, const Resolver& resolver) { + switch (expr.kind_case()) { + case ExprKindCase::kConstant: + return IsConst::kConditional; + case ExprKindCase::kIdentExpr: + return IsConst::kNonConst; + case ExprKindCase::kComprehensionExpr: + // Not yet supported, need to identify whether range and + // iter vars are compatible with const folding. + return IsConst::kNonConst; + case ExprKindCase::kStructExpr: + return IsConst::kNonConst; + case ExprKindCase::kMapExpr: + // Empty maps are rare and not currently supported as they may eventually + // have similar issues to empty list when used within comprehensions or + // macros. + if (expr.map_expr().entries().empty()) { + return IsConst::kNonConst; + } + return IsConst::kConditional; + case ExprKindCase::kListExpr: + if (expr.list_expr().elements().empty()) { + // Don't fold for empty list to allow comprehension + // list append optimization. + return IsConst::kNonConst; + } + return IsConst::kConditional; + case ExprKindCase::kSelectExpr: + return IsConst::kConditional; + case ExprKindCase::kCallExpr: { + const auto& call = expr.call_expr(); + // Short Circuiting operators not yet supported. + if (call.function() == kAnd || call.function() == kOr || + call.function() == kTernary) { + return IsConst::kNonConst; + } + // For now we skip constant folding for cel.@block. We do not yet setup + // slots. When we enable constant folding for comprehensions (like + // cel.bind), we can address cel.@block. + if (call.function() == "cel.@block") { + return IsConst::kNonConst; } - case Expr::kListExpr: { - auto list_expr = out->mutable_list_expr(); - int list_size = expr.list_expr().elements_size(); - bool all_constant = true; - for (int i = 0; i < list_size; i++) { - auto elt = list_expr->add_elements(); - all_constant = - Transform(expr.list_expr().elements(i), elt) && all_constant; - } - if (!all_constant) { - return false; - } - // create a constant list value - std::vector values(list_size); - for (int i = 0; i < list_size; i++) { - values[i] = removeConstant(list_expr->elements(i)); - } - CelList* cel_list = google::protobuf::Arena::Create( - arena_, std::move(values)); - makeConstant(CelValue::CreateList(cel_list), out); - return true; + int arg_len = call.args().size() + (call.has_target() ? 1 : 0); + // Check for any lazy overloads (activation dependant) + if (!resolver + .FindLazyOverloads(call.function(), call.has_target(), arg_len) + .empty()) { + return IsConst::kNonConst; } - case Expr::kStructExpr: { - auto struct_expr = out->mutable_struct_expr(); - struct_expr->set_message_name(expr.struct_expr().message_name()); - int entries_size = expr.struct_expr().entries_size(); - for (int i = 0; i < entries_size; i++) { - auto& entry = expr.struct_expr().entries(i); - auto new_entry = struct_expr->add_entries(); - new_entry->set_id(entry.id()); - switch (entry.key_kind_case()) { - case Expr::CreateStruct::Entry::kFieldKey: - new_entry->set_field_key(entry.field_key()); - break; - case Expr::CreateStruct::Entry::kMapKey: - Transform(entry.map_key(), new_entry->mutable_map_key()); - break; - default: - GOOGLE_LOG(ERROR) << "Unsupported Entry kind: " << entry.key_kind_case(); - break; - } - Transform(entry.value(), new_entry->mutable_value()); + + auto overloads = + resolver.FindOverloads(call.function(), call.has_target(), arg_len); + // Check for any contextual overloads. If there are any, we cowardly + // avoid constant folding instead of trying to check if one of the + // overloads would be safe to use. + for (const auto& overload : overloads) { + if (overload.descriptor.is_contextual()) { + return IsConst::kNonConst; } - return false; - } - case Expr::kComprehensionExpr: { - // do not fold comprehensions for now: would require significal - // factoring out of comprehension semantics from the evaluator - auto& input_expr = expr.comprehension_expr(); - auto out_expr = out->mutable_comprehension_expr(); - out_expr->set_iter_var(input_expr.iter_var()); - Transform(input_expr.accu_init(), out_expr->mutable_accu_init()); - Transform(input_expr.iter_range(), out_expr->mutable_iter_range()); - out_expr->set_accu_var(input_expr.accu_var()); - Transform(input_expr.loop_condition(), - out_expr->mutable_loop_condition()); - Transform(input_expr.loop_step(), out_expr->mutable_loop_step()); - Transform(input_expr.result(), out_expr->mutable_result()); - return false; } - default: - GOOGLE_LOG(ERROR) << "Unsupported Expr kind: " << expr.expr_kind_case(); - return false; + + return IsConst::kConditional; } + case ExprKindCase::kUnspecifiedExpr: + default: + return IsConst::kNonConst; } +} - private: - void makeConstant(CelValue value, Expr* out) { - auto ident = absl::StrCat("$v", counter_++); - constant_idents_.emplace(ident, value); - out->mutable_ident_expr()->set_name(ident); +absl::Status ConstantFoldingExtension::OnPreVisit(PlannerContext& context, + const Expr& node) { + IsConst is_const = IsConstExpr(node, context.resolver()); + is_const_.push_back(is_const); + + return absl::OkStatus(); +} + +absl::Status ConstantFoldingExtension::OnPostVisit(PlannerContext& context, + const Expr& node) { + if (is_const_.empty()) { + return absl::InternalError("ConstantFoldingExtension called out of order."); } - CelValue removeConstant(const Expr& ident) { - return constant_idents_.extract(ident.ident_expr().name()).mapped(); + IsConst is_const = is_const_.back(); + is_const_.pop_back(); + + if (is_const == IsConst::kNonConst) { + // update parent + if (!is_const_.empty()) { + is_const_.back() = IsConst::kNonConst; + } + return absl::OkStatus(); } + ExecutionPathView subplan = context.GetSubplan(node); + if (subplan.empty()) { + // This subexpression is already optimized out or suppressed. + return absl::OkStatus(); + } + // copy string to managed handle if backed by the original program. + Value value; + if (node.has_const_expr()) { + CEL_ASSIGN_OR_RETURN(value, + ConvertConstant(node.const_expr(), state_.arena())); + } else { + ExecutionFrame frame(subplan, empty_, context.options(), state_); + state_.Reset(); + // Update stack size to accommodate sub expression. + // This only results in a vector resize if the new maxsize is greater than + // the current capacity. + state_.value_stack().SetMaxSize(subplan.size()); - const CelFunctionRegistry& registry_; + auto result = frame.Evaluate(); + // If this would be a runtime error, then don't adjust the program plan, but + // rather allow the error to occur at runtime to preserve the evaluation + // contract with non-constant folding use cases. + if (!result.ok()) { + return absl::OkStatus(); + } + value = *result; + if (value->Is()) { + return absl::OkStatus(); + } + } - // Owns constant values created during folding - google::protobuf::Arena* arena_; - absl::flat_hash_map& constant_idents_; + // If recursive planning enabled (recursion limit unbounded or at least 1), + // use a recursive (direct) step for the folded constant. + // + // Constant folding is applied leaf to root based on the program plan so far, + // so the planner will have an opportunity to validate that the recursion + // limit is being followed when visiting parent nodes in the AST. + if (context.options().max_recursion_depth != 0) { + return context.ReplaceSubplan( + node, CreateConstValueDirectStep(std::move(value), node.id()), 1); + } - int counter_; -}; + // Otherwise make a stack machine plan. + ExecutionPath new_plan; + CEL_ASSIGN_OR_RETURN( + new_plan.emplace_back(), + CreateConstValueStep(std::move(value), node.id(), false)); + + return context.ReplaceSubplan(node, std::move(new_plan)); +} } // namespace -void FoldConstants(const Expr& expr, const CelFunctionRegistry& registry, - google::protobuf::Arena* arena, - absl::flat_hash_map& constant_idents, - Expr* out) { - ConstantFoldingTransform constant_folder(registry, arena, constant_idents); - constant_folder.Transform(expr, out); +ProgramOptimizerFactory CreateConstantFoldingOptimizer( + absl_nullable std::shared_ptr arena, + absl_nullable std::shared_ptr message_factory) { + return + [shared_arena = std::move(arena), + shared_message_factory = std::move(message_factory)]( + PlannerContext& context, + const Ast&) -> absl::StatusOr> { + // If one was explicitly provided during planning or none was explicitly + // provided during configuration, request one from the planning context. + // Otherwise use the one provided during configuration. + google::protobuf::Arena* absl_nonnull arena = + context.HasExplicitArena() || shared_arena == nullptr + ? context.MutableArena() + : shared_arena.get(); + google::protobuf::MessageFactory* absl_nonnull message_factory = + context.HasExplicitMessageFactory() || + shared_message_factory == nullptr + ? context.MutableMessageFactory() + : shared_message_factory.get(); + return std::make_unique( + context.descriptor_pool(), shared_arena, arena, + shared_message_factory, message_factory, context.type_reflector()); + }; } -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace cel::runtime_internal diff --git a/eval/compiler/constant_folding.h b/eval/compiler/constant_folding.h index 20a1627bb..c871cd2c9 100644 --- a/eval/compiler/constant_folding.h +++ b/eval/compiler/constant_folding.h @@ -1,28 +1,42 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #ifndef THIRD_PARTY_CEL_CPP_EVAL_COMPILER_CONSTANT_FOLDING_H_ #define THIRD_PARTY_CEL_CPP_EVAL_COMPILER_CONSTANT_FOLDING_H_ -#include "google/api/expr/v1alpha1/syntax.pb.h" -#include "absl/container/flat_hash_map.h" -#include "eval/public/cel_function.h" -#include "eval/public/cel_function_registry.h" -#include "eval/public/cel_value.h" +#include + +#include "absl/base/nullability.h" +#include "eval/compiler/flat_expr_builder_extensions.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/message.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace cel::runtime_internal { -// A transformation over input expression that produces a new expression with -// constant sub-expressions replaced by generated idents in the constant_idents -// map. This transformation preserves the IDs of the input sub-expressions. -void FoldConstants(const google::api::expr::v1alpha1::Expr& expr, - const CelFunctionRegistry& registry, google::protobuf::Arena* arena, - absl::flat_hash_map& constant_idents, - google::api::expr::v1alpha1::Expr* out); +// Create a new constant folding extension. +// Eagerly evaluates sub expressions with all constant inputs, and replaces said +// sub expression with the result. +// +// Note: the precomputed values may be allocated using the provided +// MemoryManager so it must outlive any programs created with this +// extension. +google::api::expr::runtime::ProgramOptimizerFactory +CreateConstantFoldingOptimizer( + absl_nullable std::shared_ptr arena = nullptr, + absl_nullable std::shared_ptr message_factory = + nullptr); -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace cel::runtime_internal #endif // THIRD_PARTY_CEL_CPP_EVAL_COMPILER_CONSTANT_FOLDING_H_ diff --git a/eval/compiler/constant_folding_test.cc b/eval/compiler/constant_folding_test.cc index f7e24e7e4..d1c0c31e0 100644 --- a/eval/compiler/constant_folding_test.cc +++ b/eval/compiler/constant_folding_test.cc @@ -1,452 +1,573 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #include "eval/compiler/constant_folding.h" -#include "google/api/expr/v1alpha1/syntax.pb.h" -#include "google/protobuf/text_format.h" -#include "google/protobuf/util/message_differencer.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "eval/public/builtin_func_registrar.h" -#include "eval/public/cel_function_registry.h" -#include "eval/testutil/test_message.pb.h" -#include "base/status_macros.h" - -namespace google { -namespace api { -namespace expr { -namespace runtime { +#include +#include +#include + +#include "cel/expr/syntax.pb.h" +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "base/ast.h" +#include "common/expr.h" +#include "common/value.h" +#include "eval/compiler/flat_expr_builder_extensions.h" +#include "eval/compiler/resolver.h" +#include "eval/eval/const_value_step.h" +#include "eval/eval/create_list_step.h" +#include "eval/eval/create_map_step.h" +#include "eval/eval/evaluator_core.h" +#include "extensions/protobuf/ast_converters.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "parser/parser.h" +#include "runtime/function_registry.h" +#include "runtime/internal/issue_collector.h" +#include "runtime/internal/runtime_env.h" +#include "runtime/internal/runtime_env_testing.h" +#include "runtime/runtime_issue.h" +#include "runtime/runtime_options.h" +#include "runtime/type_registry.h" +#include "google/protobuf/arena.h" + +namespace cel::runtime_internal { namespace { -using google::api::expr::v1alpha1::Expr; - -// Validate select is preserved as-is -TEST(ConstantFoldingTest, Select) { - Expr expr; - // has(x.y) - google::protobuf::TextFormat::ParseFromString(R"( - id: 1 - select_expr { - operand { - id: 2 - ident_expr { name: "x" } - } - field: "y" - test_only: true - })", - &expr); - - google::protobuf::Arena arena; - CelFunctionRegistry registry; - absl::flat_hash_map idents; - Expr out; - FoldConstants(expr, registry, &arena, idents, &out); - google::protobuf::util::MessageDifferencer md; - EXPECT_TRUE(md.Compare(out, expr)) << out.DebugString(); - EXPECT_TRUE(idents.empty()); +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::cel::Expr; +using ::cel::RuntimeIssue; +using ::cel::runtime_internal::IssueCollector; +using ::cel::runtime_internal::NewTestingRuntimeEnv; +using ::cel::expr::ParsedExpr; +using ::google::api::expr::parser::Parse; +using ::google::api::expr::runtime::CreateConstValueStep; +using ::google::api::expr::runtime::CreateCreateListStep; +using ::google::api::expr::runtime::CreateCreateStructStepForMap; +using ::google::api::expr::runtime::ExecutionPath; +using ::google::api::expr::runtime::PlannerContext; +using ::google::api::expr::runtime::ProgramBuilder; +using ::google::api::expr::runtime::ProgramOptimizer; +using ::google::api::expr::runtime::ProgramOptimizerFactory; +using ::google::api::expr::runtime::Resolver; +using ::testing::SizeIs; + +class UpdatedConstantFoldingTest : public testing::Test { + public: + UpdatedConstantFoldingTest() + : env_(NewTestingRuntimeEnv()), + function_registry_(env_->function_registry), + type_registry_(env_->type_registry), + issue_collector_(RuntimeIssue::Severity::kError), + resolver_("", function_registry_, type_registry_, + type_registry_.GetComposedTypeProvider()) {} + + protected: + absl_nonnull std::shared_ptr env_; + google::protobuf::Arena arena_; + cel::FunctionRegistry& function_registry_; + cel::TypeRegistry& type_registry_; + cel::RuntimeOptions options_; + IssueCollector issue_collector_; + Resolver resolver_; +}; + +absl::StatusOr> ParseFromCel( + absl::string_view expression) { + CEL_ASSIGN_OR_RETURN(ParsedExpr expr, Parse(expression)); + return cel::extensions::CreateAstFromParsedExpr(expr); } -// Validate struct message creation -TEST(ConstantFoldingTest, StructMessage) { - Expr expr; - // {"field1": "y", "field2": "t"} - google::protobuf::TextFormat::ParseFromString( - R"pb( - id: 5 - struct_expr { - entries { - id: 11 - field_key: "field1" - value { const_expr { string_value: "value1" } } - } - entries { - id: 7 - field_key: "field2" - value { const_expr { int64_value: 12 } } - } - message_name: "MyProto" - })pb", - &expr); - - google::protobuf::Arena arena; - CelFunctionRegistry registry; - - absl::flat_hash_map idents; - Expr out; - FoldConstants(expr, registry, &arena, idents, &out); - - Expr expected; - google::protobuf::TextFormat::ParseFromString(R"( - id: 5 - struct_expr { - entries { - id: 11 - field_key: "field1" - value { ident_expr { name: "$v0" } } - } - entries { - id: 7 - field_key: "field2" - value { ident_expr { name: "$v1" } } - } - message_name: "MyProto" - })", - &expected); - google::protobuf::util::MessageDifferencer md; - EXPECT_TRUE(md.Compare(out, expected)) << out.DebugString(); - - EXPECT_EQ(idents.size(), 2); - EXPECT_TRUE(idents["$v0"].IsString()); - EXPECT_EQ(idents["$v0"].StringOrDie().value(), "value1"); - EXPECT_TRUE(idents["$v1"].IsInt64()); - EXPECT_EQ(idents["$v1"].Int64OrDie(), 12); +// While CEL doesn't provide execution order guarantees per se, short circuiting +// operators are treated specially to evaluate to user expectations. +// +// These behaviors aren't easily observable since the flat expression doesn't +// expose any details about the program after building, so a lot of setup is +// needed to simulate what the expression builder does. +TEST_F(UpdatedConstantFoldingTest, SkipsTernary) { + // Arrange + ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, + ParseFromCel("true ? true : false")); + + const Expr& call = ast->root_expr(); + const Expr& condition = call.call_expr().args()[0]; + const Expr& true_branch = call.call_expr().args()[1]; + const Expr& false_branch = call.call_expr().args()[2]; + + ProgramBuilder program_builder; + program_builder.EnterSubexpression(&call); + // condition + program_builder.EnterSubexpression(&condition); + ASSERT_OK_AND_ASSIGN(auto step, + CreateConstValueStep(cel::BoolValue(true), -1)); + program_builder.AddStep(std::move(step)); + program_builder.ExitSubexpression(&condition); + + // true + program_builder.EnterSubexpression(&true_branch); + ASSERT_OK_AND_ASSIGN(step, CreateConstValueStep(cel::BoolValue(true), -1)); + program_builder.AddStep(std::move(step)); + program_builder.ExitSubexpression(&true_branch); + + // false + program_builder.EnterSubexpression(&false_branch); + ASSERT_OK_AND_ASSIGN(step, CreateConstValueStep(cel::BoolValue(true), -1)); + program_builder.AddStep(std::move(step)); + program_builder.ExitSubexpression(&false_branch); + + // ternary. + ASSERT_OK_AND_ASSIGN(step, CreateConstValueStep(cel::NullValue(), -1)); + program_builder.AddStep(std::move(step)); + program_builder.ExitSubexpression(&call); + + std::shared_ptr arena; + PlannerContext context(env_, resolver_, options_, + type_registry_.GetComposedTypeProvider(), + issue_collector_, program_builder, arena); + + ProgramOptimizerFactory constant_folder_factory = + CreateConstantFoldingOptimizer(); + + // Act + // Issue the visitation calls. + ASSERT_OK_AND_ASSIGN(std::unique_ptr constant_folder, + constant_folder_factory(context, *ast)); + ASSERT_OK(constant_folder->OnPreVisit(context, call)); + ASSERT_OK(constant_folder->OnPreVisit(context, condition)); + ASSERT_OK(constant_folder->OnPostVisit(context, condition)); + ASSERT_OK(constant_folder->OnPreVisit(context, true_branch)); + ASSERT_OK(constant_folder->OnPostVisit(context, true_branch)); + ASSERT_OK(constant_folder->OnPreVisit(context, false_branch)); + ASSERT_OK(constant_folder->OnPostVisit(context, false_branch)); + ASSERT_OK(constant_folder->OnPostVisit(context, call)); + + // Assert + // No changes attempted. + auto path = std::move(program_builder).FlattenMain(); + EXPECT_THAT(path, SizeIs(4)); } -// Validate struct creation is not folded but recursed into -TEST(ConstantFoldingTest, StructComprehension) { - Expr expr; - // {"x": "y", "z": "t"} - google::protobuf::TextFormat::ParseFromString(R"( - id: 5 - struct_expr { - entries { - id: 11 - field_key: "x" - value { const_expr { string_value: "y" } } - } - entries { - id: 7 - map_key { const_expr { string_value: "z" } } - value { const_expr { string_value: "t" } } - } - })", - &expr); - - google::protobuf::Arena arena; - CelFunctionRegistry registry; - - absl::flat_hash_map idents; - Expr out; - FoldConstants(expr, registry, &arena, idents, &out); - - Expr expected; - google::protobuf::TextFormat::ParseFromString(R"( - id: 5 - struct_expr { - entries { - id: 11 - field_key: "x" - value { ident_expr { name: "$v0" } } - } - entries { - id: 7 - map_key { ident_expr { name: "$v1" } } - value { ident_expr { name: "$v2" } } - } - })", - &expected); - google::protobuf::util::MessageDifferencer md; - EXPECT_TRUE(md.Compare(out, expected)) << out.DebugString(); - - EXPECT_EQ(idents.size(), 3); - EXPECT_TRUE(idents["$v0"].IsString()); - EXPECT_EQ(idents["$v0"].StringOrDie().value(), "y"); - EXPECT_TRUE(idents["$v1"].IsString()); - EXPECT_TRUE(idents["$v2"].IsString()); +TEST_F(UpdatedConstantFoldingTest, SkipsOr) { + // Arrange + ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, + ParseFromCel("false || true")); + + const Expr& call = ast->root_expr(); + const Expr& left_condition = call.call_expr().args()[0]; + const Expr& right_condition = call.call_expr().args()[1]; + + ProgramBuilder program_builder; + + program_builder.EnterSubexpression(&call); + + // left + program_builder.EnterSubexpression(&left_condition); + ASSERT_OK_AND_ASSIGN(auto step, + CreateConstValueStep(cel::BoolValue(false), -1)); + program_builder.AddStep(std::move(step)); + program_builder.ExitSubexpression(&left_condition); + + // right + program_builder.EnterSubexpression(&right_condition); + ASSERT_OK_AND_ASSIGN(step, CreateConstValueStep(cel::BoolValue(true), -1)); + program_builder.AddStep(std::move(step)); + program_builder.ExitSubexpression(&right_condition); + + // op + // Just a placeholder. + ASSERT_OK_AND_ASSIGN(step, CreateConstValueStep(cel::NullValue(), -1)); + program_builder.AddStep(std::move(step)); + program_builder.ExitSubexpression(&call); + + std::shared_ptr arena; + PlannerContext context(env_, resolver_, options_, + type_registry_.GetComposedTypeProvider(), + issue_collector_, program_builder, arena); + + ProgramOptimizerFactory constant_folder_factory = + CreateConstantFoldingOptimizer(); + + // Act + // Issue the visitation calls. + ASSERT_OK_AND_ASSIGN(std::unique_ptr constant_folder, + constant_folder_factory(context, *ast)); + ASSERT_OK(constant_folder->OnPreVisit(context, call)); + ASSERT_OK(constant_folder->OnPreVisit(context, left_condition)); + ASSERT_OK(constant_folder->OnPostVisit(context, left_condition)); + ASSERT_OK(constant_folder->OnPreVisit(context, right_condition)); + ASSERT_OK(constant_folder->OnPostVisit(context, right_condition)); + ASSERT_OK(constant_folder->OnPostVisit(context, call)); + + // Assert + // No changes attempted. + auto path = std::move(program_builder).FlattenMain(); + EXPECT_THAT(path, SizeIs(3)); } -TEST(ConstantFoldingTest, ListComprehension) { - Expr expr; - // [1, [2, 3]] - google::protobuf::TextFormat::ParseFromString(R"( - id: 45 - list_expr { - elements { const_expr { int64_value: 1 } } - elements { - list_expr { - elements { const_expr { int64_value: 2 } } - elements { const_expr { int64_value: 3 } } - } - } - })", - &expr); - - google::protobuf::Arena arena; - CelFunctionRegistry registry; - - absl::flat_hash_map idents; - Expr out; - FoldConstants(expr, registry, &arena, idents, &out); - - ASSERT_EQ(out.id(), 45); - ASSERT_TRUE(out.has_ident_expr()) << out.DebugString(); - ASSERT_EQ(idents.size(), 1); - auto value = idents[out.ident_expr().name()]; - ASSERT_TRUE(value.IsList()); - const auto& list = *value.ListOrDie(); - ASSERT_EQ(list.size(), 2); - ASSERT_TRUE(list[0].IsInt64()); - ASSERT_EQ(list[0].Int64OrDie(), 1); - ASSERT_TRUE(list[1].IsList()); - ASSERT_EQ(list[1].ListOrDie()->size(), 2); +TEST_F(UpdatedConstantFoldingTest, SkipsAnd) { + // Arrange + ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, + ParseFromCel("true && false")); + + const Expr& call = ast->root_expr(); + const Expr& left_condition = call.call_expr().args()[0]; + const Expr& right_condition = call.call_expr().args()[1]; + + ProgramBuilder program_builder; + program_builder.EnterSubexpression(&call); + + // left + program_builder.EnterSubexpression(&left_condition); + ASSERT_OK_AND_ASSIGN(auto step, + CreateConstValueStep(cel::BoolValue(true), -1)); + program_builder.AddStep(std::move(step)); + program_builder.ExitSubexpression(&left_condition); + + // right + program_builder.EnterSubexpression(&right_condition); + ASSERT_OK_AND_ASSIGN(step, CreateConstValueStep(cel::BoolValue(false), -1)); + program_builder.AddStep(std::move(step)); + program_builder.ExitSubexpression(&right_condition); + + // op + // Just a placeholder. + ASSERT_OK_AND_ASSIGN(step, CreateConstValueStep(cel::NullValue(), -1)); + program_builder.AddStep(std::move(step)); + program_builder.ExitSubexpression(&call); + + std::shared_ptr arena; + PlannerContext context(env_, resolver_, options_, + type_registry_.GetComposedTypeProvider(), + issue_collector_, program_builder, arena); + + ProgramOptimizerFactory constant_folder_factory = + CreateConstantFoldingOptimizer(); + + // Act + // Issue the visitation calls. + ASSERT_OK_AND_ASSIGN(std::unique_ptr constant_folder, + constant_folder_factory(context, *ast)); + ASSERT_OK(constant_folder->OnPreVisit(context, call)); + ASSERT_OK(constant_folder->OnPreVisit(context, left_condition)); + ASSERT_OK(constant_folder->OnPostVisit(context, left_condition)); + ASSERT_OK(constant_folder->OnPreVisit(context, right_condition)); + ASSERT_OK(constant_folder->OnPostVisit(context, right_condition)); + ASSERT_OK(constant_folder->OnPostVisit(context, call)); + + // Assert + // No changes attempted. + ExecutionPath path = std::move(program_builder).FlattenMain(); + EXPECT_THAT(path, SizeIs(3)); } -// Validate that logic function application are not folded -TEST(ConstantFoldingTest, LogicApplication) { - Expr expr; - // true && false - google::protobuf::TextFormat::ParseFromString(R"( - id: 105 - call_expr { - function: "_&&_" - args { - const_expr { bool_value: true } - } - args { - const_expr { bool_value: false } - } - })", - &expr); - - google::protobuf::Arena arena; - CelFunctionRegistry registry; - ASSERT_OK(RegisterBuiltinFunctions(®istry)); - - absl::flat_hash_map idents; - Expr out; - FoldConstants(expr, registry, &arena, idents, &out); - - ASSERT_EQ(out.id(), 105); - ASSERT_TRUE(out.has_call_expr()) << out.DebugString(); - ASSERT_EQ(idents.size(), 2); +TEST_F(UpdatedConstantFoldingTest, CreatesList) { + // Arrange + ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, ParseFromCel("[1, 2]")); + + const Expr& create_list = ast->root_expr(); + const Expr& elem_one = create_list.list_expr().elements()[0].expr(); + const Expr& elem_two = create_list.list_expr().elements()[1].expr(); + + ProgramBuilder program_builder; + // Simulate the visitor order. + program_builder.EnterSubexpression(&create_list); + + // elem one + program_builder.EnterSubexpression(&elem_one); + ASSERT_OK_AND_ASSIGN(auto step, CreateConstValueStep(cel::IntValue(1L), 1)); + program_builder.AddStep(std::move(step)); + program_builder.ExitSubexpression(&elem_one); + + // elem two + program_builder.EnterSubexpression(&elem_two); + ASSERT_OK_AND_ASSIGN(step, CreateConstValueStep(cel::IntValue(2L), 2)); + program_builder.AddStep(std::move(step)); + program_builder.ExitSubexpression(&elem_two); + + // createlist + ASSERT_OK_AND_ASSIGN(step, CreateCreateListStep(create_list.list_expr(), 3)); + program_builder.AddStep(std::move(step)); + program_builder.ExitSubexpression(&create_list); + + std::shared_ptr arena; + PlannerContext context(env_, resolver_, options_, + type_registry_.GetComposedTypeProvider(), + issue_collector_, program_builder, arena); + + ProgramOptimizerFactory constant_folder_factory = + CreateConstantFoldingOptimizer(); + + // Act + // Issue the visitation calls. + ASSERT_OK_AND_ASSIGN(std::unique_ptr constant_folder, + constant_folder_factory(context, *ast)); + ASSERT_OK(constant_folder->OnPreVisit(context, create_list)); + ASSERT_OK(constant_folder->OnPreVisit(context, elem_one)); + ASSERT_OK(constant_folder->OnPostVisit(context, elem_one)); + ASSERT_OK(constant_folder->OnPreVisit(context, elem_two)); + ASSERT_OK(constant_folder->OnPostVisit(context, elem_two)); + ASSERT_OK(constant_folder->OnPostVisit(context, create_list)); + + // Assert + // Single constant value for the two element list. + ExecutionPath path = std::move(program_builder).FlattenMain(); + EXPECT_THAT(path, SizeIs(1)); } -TEST(ConstantFoldingTest, FunctionApplication) { - Expr expr; - // [1] + [2] - google::protobuf::TextFormat::ParseFromString(R"( - id: 15 - call_expr { - function: "_+_" - args { - list_expr { - elements { const_expr { int64_value: 1 } } - } - } - args { - list_expr { - elements { const_expr { int64_value: 2 } } - } - } - })", - &expr); - - google::protobuf::Arena arena; - CelFunctionRegistry registry; - ASSERT_OK(RegisterBuiltinFunctions(®istry)); - - absl::flat_hash_map idents; - Expr out; - FoldConstants(expr, registry, &arena, idents, &out); - - ASSERT_EQ(out.id(), 15); - ASSERT_TRUE(out.has_ident_expr()) << out.DebugString(); - ASSERT_EQ(idents.size(), 1); - ASSERT_TRUE(idents[out.ident_expr().name()].IsList()); - - const auto& list = *idents[out.ident_expr().name()].ListOrDie(); - ASSERT_EQ(list.size(), 2); - ASSERT_EQ(list[0].Int64OrDie(), 1); - ASSERT_EQ(list[1].Int64OrDie(), 2); +TEST_F(UpdatedConstantFoldingTest, CreatesLargeList) { + // Arrange + ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, + ParseFromCel("[1, 2, 3, 4, 5]")); + + const Expr& create_list = ast->root_expr(); + const Expr& elem0 = create_list.list_expr().elements()[0].expr(); + const Expr& elem1 = create_list.list_expr().elements()[1].expr(); + const Expr& elem2 = create_list.list_expr().elements()[2].expr(); + const Expr& elem3 = create_list.list_expr().elements()[3].expr(); + const Expr& elem4 = create_list.list_expr().elements()[4].expr(); + + ProgramBuilder program_builder; + // Simulate the visitor order. + ASSERT_TRUE(program_builder.EnterSubexpression(&create_list) != nullptr); + + // 0 + ASSERT_TRUE(program_builder.EnterSubexpression(&elem0) != nullptr); + ASSERT_OK_AND_ASSIGN(auto step, CreateConstValueStep(cel::IntValue(1L), 1)); + program_builder.AddStep(std::move(step)); + program_builder.ExitSubexpression(&elem0); + + // 1 + ASSERT_TRUE(program_builder.EnterSubexpression(&elem1)); + ASSERT_OK_AND_ASSIGN(step, CreateConstValueStep(cel::IntValue(2L), 2)); + program_builder.AddStep(std::move(step)); + program_builder.ExitSubexpression(&elem1); + + // 2 + ASSERT_TRUE(program_builder.EnterSubexpression(&elem2) != nullptr); + ASSERT_OK_AND_ASSIGN(step, CreateConstValueStep(cel::IntValue(3L), 3)); + program_builder.AddStep(std::move(step)); + program_builder.ExitSubexpression(&elem2); + + // 3 + ASSERT_TRUE(program_builder.EnterSubexpression(&elem3) != nullptr); + ASSERT_OK_AND_ASSIGN(step, CreateConstValueStep(cel::IntValue(4L), 4)); + program_builder.AddStep(std::move(step)); + program_builder.ExitSubexpression(&elem3); + + // 4 + ASSERT_TRUE(program_builder.EnterSubexpression(&elem4) != nullptr); + ASSERT_OK_AND_ASSIGN(step, CreateConstValueStep(cel::IntValue(5L), 5)); + program_builder.AddStep(std::move(step)); + program_builder.ExitSubexpression(&elem4); + + // createlist + ASSERT_OK_AND_ASSIGN(step, CreateCreateListStep(create_list.list_expr(), 6)); + program_builder.AddStep(std::move(step)); + program_builder.ExitSubexpression(&create_list); + + std::shared_ptr arena; + PlannerContext context(env_, resolver_, options_, + type_registry_.GetComposedTypeProvider(), + issue_collector_, program_builder, arena); + + ProgramOptimizerFactory constant_folder_factory = + CreateConstantFoldingOptimizer(); + + // Act + // Issue the visitation calls. + ASSERT_OK_AND_ASSIGN(std::unique_ptr constant_folder, + constant_folder_factory(context, *ast)); + ASSERT_THAT(constant_folder->OnPreVisit(context, create_list), IsOk()); + ASSERT_THAT(constant_folder->OnPreVisit(context, elem0), IsOk()); + ASSERT_THAT(constant_folder->OnPostVisit(context, elem0), IsOk()); + ASSERT_THAT(constant_folder->OnPreVisit(context, elem1), IsOk()); + ASSERT_THAT(constant_folder->OnPostVisit(context, elem1), IsOk()); + ASSERT_THAT(constant_folder->OnPreVisit(context, elem2), IsOk()); + ASSERT_THAT(constant_folder->OnPostVisit(context, elem2), IsOk()); + ASSERT_THAT(constant_folder->OnPreVisit(context, elem3), IsOk()); + ASSERT_THAT(constant_folder->OnPostVisit(context, elem3), IsOk()); + ASSERT_THAT(constant_folder->OnPreVisit(context, elem4), IsOk()); + ASSERT_THAT(constant_folder->OnPostVisit(context, elem4), IsOk()); + ASSERT_THAT(constant_folder->OnPostVisit(context, create_list), IsOk()); + + // Assert + // Single constant value for the two element list. + ExecutionPath path = std::move(program_builder).FlattenMain(); + EXPECT_THAT(path, SizeIs(1)); } -TEST(ConstantFoldingTest, FunctionApplicationWithReceiver) { - Expr expr; - // [1, 1].size() - google::protobuf::TextFormat::ParseFromString(R"( - id: 10 - call_expr { - function: "size" - target { - list_expr { - elements { const_expr { int64_value: 1 } } - elements { const_expr { int64_value: 1 } } - } - })", - &expr); - - google::protobuf::Arena arena; - CelFunctionRegistry registry; - ASSERT_OK(RegisterBuiltinFunctions(®istry)); - - absl::flat_hash_map idents; - Expr out; - FoldConstants(expr, registry, &arena, idents, &out); - - ASSERT_EQ(out.id(), 10); - ASSERT_TRUE(out.has_ident_expr()) << out.DebugString(); - ASSERT_EQ(idents.size(), 1); - ASSERT_TRUE(idents[out.ident_expr().name()].IsInt64()); - ASSERT_EQ(idents[out.ident_expr().name()].Int64OrDie(), 2); +TEST_F(UpdatedConstantFoldingTest, CreatesMap) { + // Arrange + ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, ParseFromCel("{1: 2}")); + + const Expr& create_map = ast->root_expr(); + const Expr& key = create_map.map_expr().entries()[0].key(); + const Expr& value = create_map.map_expr().entries()[0].value(); + + ProgramBuilder program_builder; + program_builder.EnterSubexpression(&create_map); + + // key + program_builder.EnterSubexpression(&key); + ASSERT_OK_AND_ASSIGN(auto step, CreateConstValueStep(cel::IntValue(1L), 1)); + program_builder.AddStep(std::move(step)); + program_builder.ExitSubexpression(&key); + + // value + program_builder.EnterSubexpression(&value); + ASSERT_OK_AND_ASSIGN(step, CreateConstValueStep(cel::IntValue(2L), 2)); + program_builder.AddStep(std::move(step)); + program_builder.ExitSubexpression(&value); + + // create map + ASSERT_OK_AND_ASSIGN( + step, CreateCreateStructStepForMap(create_map.map_expr().entries().size(), + {}, 3)); + program_builder.AddStep(std::move(step)); + program_builder.ExitSubexpression(&create_map); + + std::shared_ptr arena; + PlannerContext context(env_, resolver_, options_, + type_registry_.GetComposedTypeProvider(), + issue_collector_, program_builder, arena); + + ProgramOptimizerFactory constant_folder_factory = + CreateConstantFoldingOptimizer(); + + // Act + // Issue the visitation calls. + ASSERT_OK_AND_ASSIGN(std::unique_ptr constant_folder, + constant_folder_factory(context, *ast)); + ASSERT_OK(constant_folder->OnPreVisit(context, create_map)); + ASSERT_OK(constant_folder->OnPreVisit(context, key)); + ASSERT_OK(constant_folder->OnPostVisit(context, key)); + ASSERT_OK(constant_folder->OnPreVisit(context, value)); + ASSERT_OK(constant_folder->OnPostVisit(context, value)); + ASSERT_OK(constant_folder->OnPostVisit(context, create_map)); + + // Assert + // Single constant value for the map. + ExecutionPath path = std::move(program_builder).FlattenMain(); + EXPECT_THAT(path, SizeIs(1)); } -TEST(ConstantFoldingTest, FunctionApplicationNoOverload) { - Expr expr; - // 1 + [2] - google::protobuf::TextFormat::ParseFromString(R"( - id: 16 - call_expr { - function: "_+_" - args { - const_expr { int64_value: 1 } - } - args { - list_expr { - elements { const_expr { int64_value: 2 } } - } - } - })", - &expr); - - google::protobuf::Arena arena; - CelFunctionRegistry registry; - ASSERT_OK(RegisterBuiltinFunctions(®istry)); - - absl::flat_hash_map idents; - Expr out; - FoldConstants(expr, registry, &arena, idents, &out); - - ASSERT_EQ(out.id(), 16); - ASSERT_TRUE(out.has_ident_expr()) << out.DebugString(); - ASSERT_EQ(idents.size(), 1); - ASSERT_TRUE(CheckNoMatchingOverloadError(idents[out.ident_expr().name()])); +TEST_F(UpdatedConstantFoldingTest, CreatesInvalidMap) { + // Arrange + ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, ParseFromCel("{1.0: 2}")); + + const Expr& create_map = ast->root_expr(); + const Expr& key = create_map.map_expr().entries()[0].key(); + const Expr& value = create_map.map_expr().entries()[0].value(); + + ProgramBuilder program_builder; + program_builder.EnterSubexpression(&create_map); + + // key + program_builder.EnterSubexpression(&key); + ASSERT_OK_AND_ASSIGN(auto step, + CreateConstValueStep(cel::DoubleValue(1.0), 1)); + program_builder.AddStep(std::move(step)); + program_builder.ExitSubexpression(&key); + + // value + program_builder.EnterSubexpression(&value); + ASSERT_OK_AND_ASSIGN(step, CreateConstValueStep(cel::IntValue(2L), 2)); + program_builder.AddStep(std::move(step)); + program_builder.ExitSubexpression(&value); + + // create map + ASSERT_OK_AND_ASSIGN( + step, CreateCreateStructStepForMap(create_map.map_expr().entries().size(), + {}, 3)); + program_builder.AddStep(std::move(step)); + program_builder.ExitSubexpression(&create_map); + + std::shared_ptr arena; + PlannerContext context(env_, resolver_, options_, + type_registry_.GetComposedTypeProvider(), + issue_collector_, program_builder, arena); + + ProgramOptimizerFactory constant_folder_factory = + CreateConstantFoldingOptimizer(); + + // Act + // Issue the visitation calls. + ASSERT_OK_AND_ASSIGN(std::unique_ptr constant_folder, + constant_folder_factory(context, *ast)); + ASSERT_OK(constant_folder->OnPreVisit(context, create_map)); + ASSERT_OK(constant_folder->OnPreVisit(context, key)); + ASSERT_OK(constant_folder->OnPostVisit(context, key)); + ASSERT_OK(constant_folder->OnPreVisit(context, value)); + ASSERT_OK(constant_folder->OnPostVisit(context, value)); + ASSERT_OK(constant_folder->OnPostVisit(context, create_map)); + + ExecutionPath path = std::move(program_builder).FlattenMain(); + EXPECT_THAT(path, SizeIs(1)); } -// Validate that comprehension is recursed into -TEST(ConstantFoldingTest, MapComprehension) { - Expr expr; - // {1: "", 2: ""}.all(x, x > 0) - google::protobuf::TextFormat::ParseFromString(R"( - id: 1 - comprehension_expr { - iter_var: "k" - accu_var: "accu" - accu_init { - id: 2 - const_expr { bool_value: true } - } - loop_condition { - id: 3 - ident_expr { name: "accu" } - } - result { - id: 4 - ident_expr { name: "accu" } - } - loop_step { - id: 5 - call_expr { - function: "_&&_" - args { - ident_expr { name: "accu" } - } - args { - call_expr { - function: "_>_" - args { ident_expr { name: "k" } } - args { const_expr { int64_value: 0 } } - } - } - } - } - iter_range { - id: 6 - struct_expr { - entries { - map_key { const_expr { int64_value: 1 } } - value { const_expr { string_value: "" } } - } - entries { - id: 7 - map_key { const_expr { int64_value: 2 } } - value { const_expr { string_value: "" } } - } - } - } - })", - &expr); - - google::protobuf::Arena arena; - CelFunctionRegistry registry; - - absl::flat_hash_map idents; - Expr out; - FoldConstants(expr, registry, &arena, idents, &out); - - Expr expected; - google::protobuf::TextFormat::ParseFromString(R"( - id: 1 - comprehension_expr { - iter_var: "k" - accu_var: "accu" - accu_init { - id: 2 - ident_expr { name: "$v0" } - } - loop_condition { - id: 3 - ident_expr { name: "accu" } - } - result { - id: 4 - ident_expr { name: "accu" } - } - loop_step { - id: 5 - call_expr { - function: "_&&_" - args { - ident_expr { name: "accu" } - } - args { - call_expr { - function: "_>_" - args { ident_expr { name: "k" } } - args { ident_expr { name: "$v5" } } - } - } - } - } - iter_range { - id: 6 - struct_expr { - entries { - map_key { ident_expr { name: "$v1" } } - value { ident_expr { name: "$v2" } } - } - entries { - id: 7 - map_key { ident_expr { name: "$v3" } } - value { ident_expr { name: "$v4" } } - } - } - } - })", - &expected); - google::protobuf::util::MessageDifferencer md; - EXPECT_TRUE(md.Compare(out, expected)) << out.DebugString(); - - EXPECT_EQ(idents.size(), 6); - EXPECT_TRUE(idents["$v0"].IsBool()); - EXPECT_TRUE(idents["$v1"].IsInt64()); - EXPECT_TRUE(idents["$v2"].IsString()); - EXPECT_TRUE(idents["$v3"].IsInt64()); - EXPECT_TRUE(idents["$v4"].IsString()); - EXPECT_TRUE(idents["$v5"].IsInt64()); +TEST_F(UpdatedConstantFoldingTest, ErrorsOnUnexpectedOrder) { + // Arrange + ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, + ParseFromCel("true && false")); + + const Expr& call = ast->root_expr(); + const Expr& left_condition = call.call_expr().args()[0]; + const Expr& right_condition = call.call_expr().args()[1]; + + ProgramBuilder program_builder; + + program_builder.EnterSubexpression(&call); + // left + program_builder.EnterSubexpression(&left_condition); + ASSERT_OK_AND_ASSIGN(auto step, + CreateConstValueStep(cel::BoolValue(true), -1)); + program_builder.AddStep(std::move(step)); + program_builder.ExitSubexpression(&left_condition); + + // right + program_builder.EnterSubexpression(&right_condition); + ASSERT_OK_AND_ASSIGN(step, CreateConstValueStep(cel::BoolValue(false), -1)); + program_builder.AddStep(std::move(step)); + program_builder.ExitSubexpression(&right_condition); + + // op + // Just a placeholder. + ASSERT_OK_AND_ASSIGN(step, CreateConstValueStep(cel::NullValue(), -1)); + program_builder.AddStep(std::move(step)); + program_builder.ExitSubexpression(&call); + + std::shared_ptr arena; + PlannerContext context(env_, resolver_, options_, + type_registry_.GetComposedTypeProvider(), + issue_collector_, program_builder, arena); + + ProgramOptimizerFactory constant_folder_factory = + CreateConstantFoldingOptimizer(); + + // Act / Assert + ASSERT_OK_AND_ASSIGN(std::unique_ptr constant_folder, + constant_folder_factory(context, *ast)); + EXPECT_THAT(constant_folder->OnPostVisit(context, left_condition), + StatusIs(absl::StatusCode::kInternal)); } } // namespace -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace cel::runtime_internal diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index 61df4fb01..aa9a8858c 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -1,92 +1,238 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + #include "eval/compiler/flat_expr_builder.h" -#include "stack" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/attributes.h" +#include "absl/base/optimization.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/container/node_hash_map.h" +#include "absl/functional/any_invocable.h" +#include "absl/log/absl_check.h" +#include "absl/log/check.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/match.h" -#include "absl/strings/str_split.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" -#include "eval/compiler/constant_folding.h" +#include "absl/strings/strip.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "absl/types/variant.h" +#include "base/ast.h" +#include "base/builtins.h" +#include "base/type_provider.h" +#include "common/allocator.h" +#include "common/ast.h" +#include "common/ast_traverse.h" +#include "common/ast_visitor.h" +#include "common/constant.h" +#include "common/expr.h" +#include "common/kind.h" +#include "common/type.h" +#include "common/value.h" +#include "eval/compiler/check_ast_extensions.h" +#include "eval/compiler/flat_expr_builder_extensions.h" +#include "eval/compiler/resolver.h" #include "eval/eval/comprehension_step.h" #include "eval/eval/const_value_step.h" #include "eval/eval/container_access_step.h" #include "eval/eval/create_list_step.h" +#include "eval/eval/create_map_step.h" #include "eval/eval/create_struct_step.h" +#include "eval/eval/direct_expression_step.h" +#include "eval/eval/equality_steps.h" #include "eval/eval/evaluator_core.h" -#include "eval/eval/expression_build_warning.h" #include "eval/eval/function_step.h" #include "eval/eval/ident_step.h" #include "eval/eval/jump_step.h" +#include "eval/eval/lazy_init_step.h" #include "eval/eval/logic_step.h" +#include "eval/eval/optional_or_step.h" #include "eval/eval/select_step.h" +#include "eval/eval/shadowable_value_step.h" #include "eval/eval/ternary_step.h" -#include "eval/public/ast_traverse.h" -#include "eval/public/ast_visitor.h" -#include "eval/public/cel_builtins.h" -#include "eval/public/cel_function_registry.h" +#include "eval/eval/trace_step.h" +#include "internal/status_macros.h" +#include "runtime/internal/convert_constant.h" +#include "runtime/internal/issue_collector.h" +#include "runtime/runtime_issue.h" +#include "runtime/runtime_options.h" +#include "runtime/type_registry.h" +#include "google/protobuf/arena.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace google::api::expr::runtime { namespace { -using google::api::expr::v1alpha1::Constant; -using google::api::expr::v1alpha1::Expr; -using google::api::expr::v1alpha1::SourceInfo; -using Ident = google::api::expr::v1alpha1::Expr::Ident; -using Select = google::api::expr::v1alpha1::Expr::Select; -using Call = google::api::expr::v1alpha1::Expr::Call; -using CreateList = google::api::expr::v1alpha1::Expr::CreateList; -using CreateStruct = google::api::expr::v1alpha1::Expr::CreateStruct; -using Comprehension = google::api::expr::v1alpha1::Expr::Comprehension; +using ::cel::Ast; +using ::cel::AstTraverse; +using ::cel::RuntimeIssue; +using ::cel::StringValue; +using ::cel::Value; +using ::cel::runtime_internal::ConvertConstant; +using ::cel::runtime_internal::GetLegacyRuntimeTypeProvider; +using ::cel::runtime_internal::GetRuntimeTypeProvider; +using ::cel::runtime_internal::IssueCollector; + +constexpr absl::string_view kOptionalOrFn = "or"; +constexpr absl::string_view kOptionalOrValueFn = "orValue"; +constexpr absl::string_view kBlock = "cel.@block"; // Forward declare to resolve circular dependency for short_circuiting visitors. class FlatExprVisitor; +// Error code for failed recursive program building. Generally indicates an +// optimization doesn't support recursive programs. +absl::Status FailedRecursivePlanning() { + return absl::InternalError( + "failed to build recursive program. check for unsupported optimizations"); +} + +// Helper for bookkeeping variables mapped to indexes. +class IndexManager { + public: + IndexManager() : next_free_slot_(0), max_slot_count_(0) {} + + size_t ReserveSlots(size_t n) { + size_t result = next_free_slot_; + next_free_slot_ += n; + if (next_free_slot_ > max_slot_count_) { + max_slot_count_ = next_free_slot_; + } + return result; + } + + size_t ReleaseSlots(size_t n) { + next_free_slot_ -= n; + return next_free_slot_; + } + + size_t max_slot_count() const { return max_slot_count_; } + + private: + size_t next_free_slot_; + size_t max_slot_count_; +}; + +// Helper for computing jump offsets. +// +// Jumps should be self-contained to a single expression node -- jumping +// outside that range is a bug. +struct ProgramStepIndex { + int index; + ProgramBuilder::Subexpression* subexpression; +}; + // A convenience wrapper for offset-calculating logic. class Jump { public: - explicit Jump() : self_index_(-1), jump_step_(nullptr) {} - explicit Jump(int self_index, JumpStepBase* jump_step) + // Default constructor for empty jump. + // + // Users must check that jump is non-empty before calling member functions. + explicit Jump() : self_index_{-1, nullptr}, jump_step_(nullptr) {} + Jump(ProgramStepIndex self_index, JumpStepBase* jump_step) : self_index_(self_index), jump_step_(jump_step) {} - void set_target(int index) { - // 0 offset means no-op. - jump_step_->set_jump_offset(index - self_index_ - 1); + + static absl::StatusOr CalculateOffset(ProgramStepIndex base, + ProgramStepIndex target) { + if (target.subexpression != base.subexpression) { + return absl::InternalError( + "Jump target must be contained in the parent" + "subexpression"); + } + + int offset = base.subexpression->CalculateOffset(base.index, target.index); + return offset; } + + absl::Status set_target(ProgramStepIndex target) { + CEL_ASSIGN_OR_RETURN(int offset, CalculateOffset(self_index_, target)); + + jump_step_->set_jump_offset(offset); + return absl::OkStatus(); + } + bool exists() { return jump_step_ != nullptr; } private: - int self_index_; + ProgramStepIndex self_index_; JumpStepBase* jump_step_; }; class CondVisitor { public: - virtual ~CondVisitor() {} - virtual void PreVisit(const Expr* expr) = 0; - virtual void PostVisitArg(int arg_num, const Expr* expr) = 0; - virtual void PostVisit(const Expr* expr) = 0; + virtual ~CondVisitor() = default; + virtual void PreVisit(const cel::Expr* expr) = 0; + virtual void PostVisitArg(int arg_num, const cel::Expr* expr) = 0; + virtual void PostVisit(const cel::Expr* expr) = 0; + virtual void PostVisitTarget(const cel::Expr* expr) {} +}; + +enum class BinaryCond { + kAnd = 0, + kOr, + kOptionalOr, + kOptionalOrValue, }; // Visitor managing the "&&" and "||" operatiions. +// Implements short-circuiting if enabled. +// +// With short-circuiting enabled, generates a program like: +// +-------------+------------------------+-----------------------+ +// | PC | Step | Stack | +// +-------------+------------------------+-----------------------+ +// | i + 0 | | arg1 | +// | i + 1 | ConditionalJump i + 4 | arg1 | +// | i + 2 | | arg1, arg2 | +// | i + 3 | BooleanOperator | Op(arg1, arg2) | +// | i + 4 | | arg1 | Op(arg1, arg2) | +// +-------------+------------------------+------------------------+ class BinaryCondVisitor : public CondVisitor { public: - explicit BinaryCondVisitor(FlatExprVisitor* visitor, bool cond_value, + explicit BinaryCondVisitor(FlatExprVisitor* visitor, BinaryCond cond, bool short_circuiting) - : visitor_(visitor), - cond_value_(cond_value), - short_circuiting_(short_circuiting) {} + : visitor_(visitor), cond_(cond), short_circuiting_(short_circuiting) {} - void PreVisit(const Expr* expr) override; - void PostVisitArg(int arg_num, const Expr* expr) override; - void PostVisit(const Expr* expr) override; + void PreVisit(const cel::Expr* expr) override; + void PostVisitArg(int arg_num, const cel::Expr* expr) override; + void PostVisit(const cel::Expr* expr) override; + void PostVisitTarget(const cel::Expr* expr) override; private: FlatExprVisitor* visitor_; - const bool cond_value_; - Jump jump_step_; + const BinaryCond cond_; + std::vector jump_steps_; bool short_circuiting_; }; @@ -94,9 +240,9 @@ class TernaryCondVisitor : public CondVisitor { public: explicit TernaryCondVisitor(FlatExprVisitor* visitor) : visitor_(visitor) {} - void PreVisit(const Expr* expr) override; - void PostVisitArg(int arg_num, const Expr* expr) override; - void PostVisit(const Expr* expr) override; + void PreVisit(const cel::Expr* expr) override; + void PostVisitArg(int arg_num, const cel::Expr* expr) override; + void PostVisit(const cel::Expr* expr) override; private: FlatExprVisitor* visitor_; @@ -110,170 +256,688 @@ class ExhaustiveTernaryCondVisitor : public CondVisitor { explicit ExhaustiveTernaryCondVisitor(FlatExprVisitor* visitor) : visitor_(visitor) {} - void PreVisit(const Expr* expr) override {} - void PostVisitArg(int arg_num, const Expr* expr) override {} - void PostVisit(const Expr* expr) override; + void PreVisit(const cel::Expr* expr) override; + void PostVisitArg(int arg_num, const cel::Expr* expr) override {} + void PostVisit(const cel::Expr* expr) override; private: FlatExprVisitor* visitor_; }; -// Visitor Comprehension expression. -class ComprehensionVisitor : public CondVisitor { +// Returns a hint for the number of program nodes (steps or subexpressions) that +// will be created for this expr. +size_t SizeHint(const cel::Expr& expr) { + switch (expr.kind_case()) { + case cel::ExprKindCase::kConstant: + return 1; + case cel::ExprKindCase::kIdentExpr: + return 1; + case cel::ExprKindCase::kSelectExpr: + return 2; + case cel::ExprKindCase::kCallExpr: + return expr.call_expr().args().size() + + (expr.call_expr().has_target() ? 2 : 1); + case cel::ExprKindCase::kListExpr: + return expr.list_expr().elements().size() + 1; + case cel::ExprKindCase::kStructExpr: + return expr.struct_expr().fields().size() + 1; + case cel::ExprKindCase::kMapExpr: + return 2 * expr.struct_expr().fields().size() + 1; + default: + return 1; + } + return 0; +} + +// Returns whether this comprehension appears to be a standard map/filter +// macro implementation. It is not exhaustive, so it is unsafe to use with +// custom comprehensions outside of the standard macros or hand crafted ASTs. +bool IsOptimizableListAppend(const cel::ComprehensionExpr* comprehension, + bool enable_comprehension_list_append) { + if (!enable_comprehension_list_append) { + return false; + } + absl::string_view accu_var = comprehension->accu_var(); + if (accu_var.empty() || + comprehension->result().ident_expr().name() != accu_var) { + return false; + } + if (!comprehension->accu_init().has_list_expr() || + !comprehension->accu_init().list_expr().elements().empty()) { + return false; + } + + if (!comprehension->loop_step().has_call_expr()) { + return false; + } + + // Macro loop_step for a filter() will contain a ternary: + // filter ? accu_var + [elem] : accu_var + // Macro loop_step for a map() will contain a list concat operation: + // accu_var + [elem] + const auto* call_expr = &comprehension->loop_step().call_expr(); + + if (call_expr->function() == cel::builtin::kTernary && + call_expr->args().size() == 3) { + if (!call_expr->args()[1].has_call_expr()) { + return false; + } + call_expr = &(call_expr->args()[1].call_expr()); + } + + return call_expr->function() == cel::builtin::kAdd && + call_expr->args().size() == 2 && + call_expr->args()[0].has_ident_expr() && + call_expr->args()[0].ident_expr().name() == accu_var && + call_expr->args()[1].has_list_expr() && + call_expr->args()[1].list_expr().elements().size() == 1; +} + +// Assuming `IsOptimizableListAppend()` return true, return a pointer to the +// call `accu_var + [elem]`. +const cel::CallExpr* GetOptimizableListAppendCall( + const cel::ComprehensionExpr* comprehension) { + ABSL_DCHECK(IsOptimizableListAppend( + comprehension, /*enable_comprehension_list_append=*/true)); + + // Macro loop_step for a filter() will contain a ternary: + // filter ? accu_var + [elem] : accu_var + // Macro loop_step for a map() will contain a list concat operation: + // accu_var + [elem] + const auto* call_expr = &comprehension->loop_step().call_expr(); + + if (call_expr->function() == cel::builtin::kTernary && + call_expr->args().size() == 3) { + call_expr = &(call_expr->args()[1].call_expr()); + } + return call_expr; +} + +// Assuming `IsOptimizableListAppend()` return true, return a pointer to the +// node `[elem]`. +const cel::Expr* GetOptimizableListAppendOperand( + const cel::ComprehensionExpr* comprehension) { + return &GetOptimizableListAppendCall(comprehension)->args()[1]; +} + +// Returns whether this comprehension appears to be a macro implementation for +// map transformations. It is not exhaustive, so it is unsafe to use with custom +// comprehensions outside of the standard macros or hand crafted ASTs. +bool IsOptimizableMapInsert(const cel::ComprehensionExpr* comprehension, + bool enable_comprehension_mutable_map) { + if (!enable_comprehension_mutable_map) { + return false; + } + if (comprehension->iter_var().empty() || comprehension->iter_var2().empty()) { + return false; + } + absl::string_view accu_var = comprehension->accu_var(); + if (accu_var.empty() || !comprehension->has_result() || + !comprehension->result().has_ident_expr() || + comprehension->result().ident_expr().name() != accu_var) { + return false; + } + if (!comprehension->accu_init().has_map_expr()) { + return false; + } + if (!comprehension->loop_step().has_call_expr()) { + return false; + } + const auto* call_expr = &comprehension->loop_step().call_expr(); + + if (call_expr->function() == cel::builtin::kTernary && + call_expr->args().size() == 3) { + if (!call_expr->args()[1].has_call_expr()) { + return false; + } + call_expr = &(call_expr->args()[1].call_expr()); + } + return call_expr->function() == "cel.@mapInsert" && + (call_expr->args().size() == 2 || call_expr->args().size() == 3) && + call_expr->args()[0].has_ident_expr() && + call_expr->args()[0].ident_expr().name() == accu_var; +} + +bool IsBind(const cel::ComprehensionExpr* comprehension) { + static constexpr absl::string_view kUnusedIterVar = "#unused"; + + return comprehension->loop_condition().const_expr().has_bool_value() && + comprehension->loop_condition().const_expr().bool_value() == false && + comprehension->iter_var() == kUnusedIterVar && + comprehension->iter_var2().empty() && + comprehension->iter_range().has_list_expr() && + comprehension->iter_range().list_expr().elements().empty(); +} + +bool IsBlock(const cel::CallExpr* call) { return call->function() == kBlock; } + +// Visitor for Comprehension expressions. +class ComprehensionVisitor { public: - explicit ComprehensionVisitor(FlatExprVisitor* visitor, bool short_circuiting) + explicit ComprehensionVisitor(FlatExprVisitor* visitor, bool short_circuiting, + bool is_trivial, size_t iter_slot, + size_t iter2_slot, size_t accu_slot) : visitor_(visitor), next_step_(nullptr), cond_step_(nullptr), - short_circuiting_(short_circuiting) {} + short_circuiting_(short_circuiting), + is_trivial_(is_trivial), + accu_init_extracted_(false), + iter_slot_(iter_slot), + iter2_slot_(iter2_slot), + accu_slot_(accu_slot) {} + + void PreVisit(const cel::Expr* expr); + absl::Status PostVisitArg(cel::ComprehensionArg arg_num, + const cel::Expr* comprehension_expr) { + if (is_trivial_) { + PostVisitArgTrivial(arg_num, comprehension_expr); + return absl::OkStatus(); + } else { + return PostVisitArgDefault(arg_num, comprehension_expr); + } + } + void PostVisit(const cel::Expr* expr); - void PreVisit(const Expr* expr) override; - void PostVisitArg(int arg_num, const Expr* expr) override; - void PostVisit(const Expr* expr) override; + void MarkAccuInitExtracted() { accu_init_extracted_ = true; } private: + void PostVisitArgTrivial(cel::ComprehensionArg arg_num, + const cel::Expr* comprehension_expr); + + absl::Status PostVisitArgDefault(cel::ComprehensionArg arg_num, + const cel::Expr* comprehension_expr); + FlatExprVisitor* visitor_; + ComprehensionInitStep* init_step_; ComprehensionNextStep* next_step_; ComprehensionCondStep* cond_step_; - int next_step_pos_; - int cond_step_pos_; + ProgramStepIndex init_step_pos_; + ProgramStepIndex next_step_pos_; + ProgramStepIndex cond_step_pos_; bool short_circuiting_; + bool is_trivial_; + bool accu_init_extracted_; + size_t iter_slot_; + size_t iter2_slot_; + size_t accu_slot_; }; -class FlatExprVisitor : public AstVisitor { +absl::flat_hash_set MakeOptionalIndicesSet( + const cel::ListExpr& create_list_expr) { + absl::flat_hash_set optional_indices; + for (size_t i = 0; i < create_list_expr.elements().size(); ++i) { + if (create_list_expr.elements()[i].optional()) { + optional_indices.insert(static_cast(i)); + } + } + return optional_indices; +} + +absl::flat_hash_set MakeOptionalIndicesSet( + const cel::StructExpr& create_struct_expr) { + absl::flat_hash_set optional_indices; + for (size_t i = 0; i < create_struct_expr.fields().size(); ++i) { + if (create_struct_expr.fields()[i].optional()) { + optional_indices.insert(static_cast(i)); + } + } + return optional_indices; +} + +absl::flat_hash_set MakeOptionalIndicesSet( + const cel::MapExpr& map_expr) { + absl::flat_hash_set optional_indices; + for (size_t i = 0; i < map_expr.entries().size(); ++i) { + if (map_expr.entries()[i].optional()) { + optional_indices.insert(static_cast(i)); + } + } + return optional_indices; +} + +class FlatExprVisitor : public cel::AstVisitor { public: + enum class CallHandlerResult { + // The call was intercepted, no additional processing is needed. + kIntercepted, + // The call was not intercepted, continue with the default processing. + kNotIntercepted, + }; + + // Handler for functions with builtin implementations. + // This is used to replace the usual dispatcher step that applies + // the arguments to a candidate function from the function registry. + using CallHandler = absl::AnyInvocable; + FlatExprVisitor( - const CelFunctionRegistry* function_registry, ExecutionPath* path, - bool short_circuiting, - const std::set& enums, - absl::string_view container, - const absl::flat_hash_map& constant_idents, - bool enable_comprehension, BuilderWarnings* warnings, - std::set* iter_variable_names) - : flattened_path_(path), + const Resolver& resolver, const cel::RuntimeOptions& options, + std::vector> program_optimizers, + const absl::flat_hash_map& reference_map, + const cel::TypeProvider& type_provider, IssueCollector& issue_collector, + ProgramBuilder& program_builder, PlannerContext& extension_context, + bool enable_optional_types) + : resolver_(resolver), + type_provider_(type_provider), progress_status_(absl::OkStatus()), resolved_select_expr_(nullptr), - function_registry_(function_registry), - short_circuiting_(short_circuiting), - constant_idents_(constant_idents), - enable_comprehension_(enable_comprehension), - builder_warnings_(warnings), - iter_variable_names_(iter_variable_names) { - GOOGLE_CHECK(iter_variable_names_); - - auto container_elements = absl::StrSplit(container, '.'); - - // Build list of prefixes from container. Non-empty prefixes must end with - // ".", otherwise prefix "abc.xy" will match "abc.xyz.EnumName". - std::string prefix = ""; - std::vector prefixes; - prefixes.push_back(prefix); - for (const auto& elem : container_elements) { - absl::StrAppend(&prefix, elem, "."); - prefixes.push_back(prefix); - } - - for (const auto& prefix : prefixes) { - for (auto enum_desc : enums) { - absl::string_view enum_name = enum_desc->full_name(); - if (!absl::StartsWith(enum_name, prefix)) { - continue; + options_(options), + program_optimizers_(std::move(program_optimizers)), + issue_collector_(issue_collector), + program_builder_(program_builder), + extension_context_(extension_context), + enable_optional_types_(enable_optional_types) { + constexpr size_t kCallHandlerSizeHint = 11; + call_handlers_.reserve(kCallHandlerSizeHint); + call_handlers_[cel::builtin::kIndex] = [this](const cel::Expr& expr, + const cel::CallExpr& call) { + return HandleIndex(expr, call); + }; + call_handlers_[kBlock] = [this](const cel::Expr& expr, + const cel::CallExpr& call) { + return HandleBlock(expr, call); + }; + call_handlers_[cel::builtin::kAdd] = [this](const cel::Expr& expr, + const cel::CallExpr& call) { + return HandleListAppend(expr, call); + }; + if (options_.enable_fast_builtins) { + call_handlers_[cel::builtin::kNotStrictlyFalse] = + [this](const cel::Expr& expr, const cel::CallExpr& call) { + return HandleNotStrictlyFalse(expr, call); + }; + call_handlers_[cel::builtin::kNotStrictlyFalseDeprecated] = + [this](const cel::Expr& expr, const cel::CallExpr& call) { + return HandleNotStrictlyFalse(expr, call); + }; + call_handlers_[cel::builtin::kNot] = [this](const cel::Expr& expr, + const cel::CallExpr& call) { + return HandleNot(expr, call); + }; + if (options_.enable_heterogeneous_equality) { + for (const auto& in_op : + {cel::builtin::kIn, cel::builtin::kInDeprecated, + cel::builtin::kInFunction}) { + call_handlers_[in_op] = [this](const cel::Expr& expr, + const cel::CallExpr& call) { + return HandleHeterogeneousEqualityIn(expr, call); + }; + } + // Try to detect if the environment is setup with a custom equality + // implementation. + if (resolver_ + .FindOverloads(cel::builtin::kEqual, + /*receiver_style=*/false, + {cel::Kind::kAny, cel::Kind::kAny}) + .empty()) { + call_handlers_[cel::builtin::kEqual] = + [this](const cel::Expr& expr, const cel::CallExpr& call) { + return HandleHeterogeneousEquality(expr, call, + /*inequality=*/false); + }; + call_handlers_[cel::builtin::kInequal] = + [this](const cel::Expr& expr, const cel::CallExpr& call) { + return HandleHeterogeneousEquality(expr, call, + /*inequality=*/true); + }; } + } + } + } - auto remainder = absl::StripPrefix(enum_name, prefix); - for (int i = 0; i < enum_desc->value_count(); i++) { - auto value_desc = enum_desc->value(i); - if (value_desc) { - // "prefixes" container is ascending-ordered. As such, we will be - // assigning enum reference to the deepest available. - // E.g. if both a.b.c.Name and a.b.Name are available, and - // we try to reference "Name" with the scope of "a.b.c", - // it will be resolved to "a.b.c.Name". - auto key = absl::StrCat(remainder, !remainder.empty() ? "." : "", - value_desc->name()); - enum_map_[key] = value_desc; - } + void SetMaxRecursionDepth(int max_recursion_depth) { + max_recursion_depth_ = max_recursion_depth; + } + + bool PlanRecursiveProgram() const { return max_recursion_depth_ > 0; } + + void PreVisitExpr(const cel::Expr& expr) override { + ValidateOrError(!absl::holds_alternative(expr.kind()), + "Invalid empty expression"); + if (!progress_status_.ok()) { + return; + } + if (resume_from_suppressed_branch_ == nullptr && + suppressed_branches_.find(&expr) != suppressed_branches_.end()) { + resume_from_suppressed_branch_ = &expr; + } + + if (block_.has_value()) { + BlockInfo& block = *block_; + if (block.in && block.bindings_set.contains(&expr)) { + block.current_binding = &expr; + } + } + + auto* subexpression = + program_builder_.EnterSubexpression(&expr, SizeHint(expr)); + if (subexpression == nullptr) { + progress_status_.Update( + absl::InternalError("same CEL expr visited twice")); + return; + } + + for (const std::unique_ptr& optimizer : + program_optimizers_) { + absl::Status status = optimizer->OnPreVisit(extension_context_, expr); + if (!status.ok()) { + SetProgressStatusIfError(status); + } + } + } + + void PostVisitExpr(const cel::Expr& expr) override { + if (!progress_status_.ok()) { + return; + } + if (&expr == resume_from_suppressed_branch_) { + resume_from_suppressed_branch_ = nullptr; + } + + for (const std::unique_ptr& optimizer : + program_optimizers_) { + absl::Status status = optimizer->OnPostVisit(extension_context_, expr); + if (!status.ok()) { + SetProgressStatusIfError(status); + return; + } + } + + auto* subexpression = program_builder_.current(); + if (subexpression != nullptr && options_.enable_recursive_tracing && + subexpression->IsRecursive()) { + auto program = subexpression->ExtractRecursiveProgram(); + subexpression->set_recursive_program( + std::make_unique(std::move(program.step)), program.depth); + } + + program_builder_.ExitSubexpression(&expr); + + if (!comprehension_stack_.empty() && + comprehension_stack_.back().is_optimizable_bind && + (&comprehension_stack_.back().comprehension->accu_init() == &expr)) { + SetProgressStatusIfError( + MaybeExtractSubexpression(&expr, comprehension_stack_.back())); + } + + if (block_.has_value()) { + BlockInfo& block = *block_; + if (block.current_binding == &expr) { + int index = program_builder_.ExtractSubexpression(&expr); + if (index == -1) { + SetProgressStatusIfError( + absl::InvalidArgumentError("failed to extract subexpression")); + return; } + block.subexpressions[block.current_index++] = index; + block.current_binding = nullptr; } } } - void PostVisitConst(const Constant* const_expr, const Expr* expr, - const SourcePosition*) override { + void PostVisitConst(const cel::Expr& expr, + const cel::Constant& const_expr) override { if (!progress_status_.ok()) { return; } - auto value = ConvertConstant(const_expr); - if (value.has_value()) { - AddStep(CreateConstValueStep(value.value(), expr->id())); - } else { - SetProgressStatusError(absl::Status(absl::StatusCode::kInvalidArgument, - "Unsupported constant type")); + absl::StatusOr converted_value = + ConvertConstant(const_expr, cel::NewDeleteAllocator()); + + if (!converted_value.ok()) { + SetProgressStatusIfError(converted_value.status()); + return; + } + + if (options_.max_recursion_depth > 0 || options_.max_recursion_depth < 0) { + SetRecursiveStep(CreateConstValueDirectStep( + std::move(converted_value).value(), expr.id()), + 1); + return; + } + + AddStep( + CreateConstValueStep(std::move(converted_value).value(), expr.id())); + } + + struct SlotLookupResult { + int slot; + int subexpression; + }; + + // Helper to lookup a variable mapped to a slot. + // + // If lazy evaluation enabled and ided as a lazy expression, + // subexpression and slot will be set. + SlotLookupResult LookupSlot(absl::string_view path) { + // If there's a leading dot, it cannot resolve to a local variable. + if (absl::StartsWith(path, ".")) { + return {-1, -1}; + } + if (block_.has_value()) { + const BlockInfo& block = *block_; + if (block.in) { + absl::string_view index_suffix = path; + if (absl::ConsumePrefix(&index_suffix, "@index")) { + size_t index; + if (!absl::SimpleAtoi(index_suffix, &index)) { + SetProgressStatusIfError( + issue_collector_.AddIssue(RuntimeIssue::CreateError( + absl::InvalidArgumentError("bad @index")))); + return {-1, -1}; + } + if (index >= block.size) { + SetProgressStatusIfError( + issue_collector_.AddIssue(RuntimeIssue::CreateError( + absl::InvalidArgumentError(absl::StrCat( + "invalid @index greater than number of bindings: ", + index, " >= ", block.size))))); + return {-1, -1}; + } + if (index >= block.current_index) { + SetProgressStatusIfError( + issue_collector_.AddIssue(RuntimeIssue::CreateError( + absl::InvalidArgumentError(absl::StrCat( + "@index references current or future binding: ", index, + " >= ", block.current_index))))); + return {-1, -1}; + } + return {static_cast(block.index + index), + block.subexpressions[index]}; + } + } + } + if (!comprehension_stack_.empty()) { + for (int i = comprehension_stack_.size() - 1; i >= 0; i--) { + const ComprehensionStackRecord& record = comprehension_stack_[i]; + if (record.iter_var_in_scope && + record.comprehension->iter_var() == path) { + if (record.is_optimizable_bind) { + SetProgressStatusIfError(issue_collector_.AddIssue( + RuntimeIssue::CreateWarning(absl::InvalidArgumentError( + "Unexpected iter_var access in trivial comprehension")))); + return {-1, -1}; + } + return {static_cast(record.iter_slot), -1}; + } + if (record.iter_var2_in_scope && + record.comprehension->iter_var2() == path) { + return {static_cast(record.iter2_slot), -1}; + } + if (record.accu_var_in_scope && + record.comprehension->accu_var() == path) { + int slot = record.accu_slot; + int subexpression = -1; + if (record.is_optimizable_bind) { + subexpression = record.subexpression; + } + return {slot, subexpression}; + } + } } + if (absl::StartsWith(path, "@it:") || absl::StartsWith(path, "@it2:") || + absl::StartsWith(path, "@ac:")) { + // If we see a CSE generated comprehension variable that was not + // resolvable through the normal comprehension scope resolution, reject it + // now rather than surfacing errors at activation time. + SetProgressStatusIfError( + issue_collector_.AddIssue(RuntimeIssue::CreateError( + absl::InvalidArgumentError("out of scope reference to CSE " + "generated comprehension variable")))); + } + return {-1, -1}; } // Ident node handler. // Invoked after child nodes are processed. - void PostVisitIdent(const Ident* ident_expr, const Expr* expr, - const SourcePosition*) override { + void PostVisitIdent(const cel::Expr& expr, + const cel::IdentExpr& ident_expr) override { if (!progress_status_.ok()) { return; } + absl::string_view path = ident_expr.name(); + if (!ValidateOrError( + !path.empty(), + "Invalid expression: identifier 'name' must not be empty")) { + return; + } - std::string path(ident_expr->name()); + // Check if this is a local variable first (since it should shadow most + // other interpretations). + SlotLookupResult slot = LookupSlot(path); - // Automatically replace constant idents with the backing CEL values. - auto constant = constant_idents_.find(path); - if (constant != constant_idents_.end()) { - AddStep(CreateConstValueStep(constant->second, expr->id(), false)); + if (slot.subexpression >= 0) { + auto* subexpression = + program_builder_.GetExtractedSubexpression(slot.subexpression); + if (subexpression == nullptr) { + SetProgressStatusIfError( + absl::InternalError("bad subexpression reference")); + return; + } + if (subexpression->IsRecursive()) { + const auto& program = subexpression->recursive_program(); + SetRecursiveStep( + CreateDirectLazyInitStep(slot.slot, program.step.get(), expr.id()), + program.depth + 1); + } else { + // Off by one since mainline expression will be index 0. + AddStep( + CreateLazyInitStep(slot.slot, slot.subexpression + 1, expr.id())); + } + return; + } else if (slot.slot >= 0) { + if (options_.max_recursion_depth != 0) { + SetRecursiveStep( + CreateDirectSlotIdentStep(ident_expr.name(), slot.slot, expr.id()), + 1); + } else { + AddStep( + CreateIdentStepForSlot(ident_expr.name(), slot.slot, expr.id())); + } return; } - // Generate namespace map + // Attempt to resolve a select expression as a namespaced identifier for an + // enum or type constant value. + std::optional const_value; + int64_t select_root_id = -1; + std::string path_candidate; - const google::protobuf::EnumValueDescriptor* value_desc = nullptr; - - // Fill out namespace map for wrapping Select's while (!namespace_stack_.empty()) { - const auto& select_node = namespace_stack_.back(); + const auto& select_node = namespace_stack_.front(); // Generate path in format ".....". - absl::StrAppend(&path, ".", select_node.second); - namespace_map_[select_node.first] = path; - - // Attempt to match namespace - auto it = enum_map_.find(path); - if (it != enum_map_.end()) { - resolved_select_expr_ = select_node.first; - value_desc = it->second; + const cel::Expr* select_expr = select_node.first; + path_candidate = absl::StrCat(path, ".", select_node.second); + + // Attempt to find a constant enum or type value which matches the + // qualified path present in the expression. Whether the identifier + // can be resolved to a type instance depends on whether the option to + // 'enable_qualified_type_identifiers' is set to true. + const_value = resolver_.FindConstant(path_candidate, select_expr->id()); + if (const_value) { + resolved_select_expr_ = select_expr; + select_root_id = select_expr->id(); + path = path_candidate; + namespace_stack_.clear(); + break; } + namespace_stack_.pop_front(); + } - namespace_stack_.pop_back(); + if (!const_value) { + // Attempt to resolve a simple identifier as an enum or type constant + // value. + const_value = resolver_.FindConstant(path, expr.id()); + select_root_id = expr.id(); } - if (resolved_select_expr_) { - if (!resolved_select_expr_->has_select_expr()) { - progress_status_ = absl::InternalError("Unexpected Expr type"); + // TODO(issues/97): Need to add support for resolving packaged names at + // runtime if Parse-only. For checked, checker should have reported the + // expected interpretation. + if (const_value) { + // If the path starts with a dot, strip it. + absl::string_view name = absl::StripPrefix(path, "."); + if (options_.max_recursion_depth != 0) { + SetRecursiveStep( + CreateDirectShadowableValueStep( + name, std::move(const_value).value(), select_root_id), + 1); return; } - AddStep(CreateConstValueStep(value_desc, resolved_select_expr_->id())); + AddStep(CreateShadowableValueStep(name, std::move(const_value).value(), + select_root_id)); return; } - AddStep(CreateIdentStep(ident_expr, expr->id())); + + absl::string_view ident_name = absl::StripPrefix(ident_expr.name(), "."); + if (options_.max_recursion_depth != 0) { + SetRecursiveStep(CreateDirectIdentStep(ident_name, expr.id()), 1); + } else { + AddStep(CreateIdentStep(ident_name, expr.id())); + } } - void PreVisitSelect(const Select* select_expr, const Expr* expr, - const SourcePosition*) override { + void PreVisitSelect(const cel::Expr& expr, + const cel::SelectExpr& select_expr) override { if (!progress_status_.ok()) { return; } + if (!ValidateOrError( + !select_expr.field().empty(), + "invalid expression: select 'field' must not be empty")) { + return; + } + if (!ValidateOrError( + select_expr.has_operand() && + select_expr.operand().kind_case() != + cel::ExprKindCase::kUnspecifiedExpr, + "invalid expression: select must specify an operand")) { + return; + } + // Not exactly the cleanest solution - we peek into child of // select_expr. // Chain of multiple SELECT ending with IDENT can represent namespaced // entity. - if (select_expr->operand().has_ident_expr() || - select_expr->operand().has_select_expr()) { - namespace_stack_.push_back({expr, select_expr->field()}); + if (!select_expr.test_only() && (select_expr.operand().has_ident_expr() || + select_expr.operand().has_select_expr())) { + // select expressions are pushed in reverse order: + // google.type.Expr is pushed as: + // - field: 'Expr' + // - field: 'type' + // - id: 'google' + // + // The search order though is as follows: + // - id: 'google.type.Expr' + // - id: 'google.type', field: 'Expr' + // - id: 'google', field: 'type', field: 'Expr' + for (size_t i = 0; i < namespace_stack_.size(); i++) { + auto ns = namespace_stack_[i]; + namespace_stack_[i] = { + ns.first, absl::StrCat(select_expr.field(), ".", ns.second)}; + } + namespace_stack_.push_back({&expr, select_expr.field()}); } else { namespace_stack_.clear(); } @@ -281,8 +945,8 @@ class FlatExprVisitor : public AstVisitor { // Select node handler. // Invoked after child nodes are processed. - void PostVisitSelect(const Select* select_expr, const Expr* expr, - const SourcePosition*) override { + void PostVisitSelect(const cel::Expr& expr, + const cel::SelectExpr& select_expr) override { if (!progress_status_.ok()) { return; } @@ -292,174 +956,796 @@ class FlatExprVisitor : public AstVisitor { // to resolved enum value has been already created, thus preceding chain // of selects is no longer relevant. if (resolved_select_expr_) { - if (expr == resolved_select_expr_) { + if (&expr == resolved_select_expr_) { resolved_select_expr_ = nullptr; } return; } - std::string select_path = ""; - - auto it = namespace_map_.find(expr); - if (it != namespace_map_.end()) { - select_path = it->second; + if (auto depth = RecursionEligible(); depth.has_value()) { + auto deps = ExtractRecursiveDependencies(); + if (deps.size() != 1) { + SetProgressStatusIfError(absl::InternalError( + "unexpected number of dependencies for select operation.")); + return; + } + StringValue field = cel::StringValue(select_expr.field()); + + SetRecursiveStep( + CreateDirectSelectStep(std::move(deps[0]), std::move(field), + select_expr.test_only(), expr.id(), + options_.enable_empty_wrapper_null_unboxing, + enable_optional_types_), + *depth + 1); + return; } - AddStep(CreateSelectStep(select_expr, expr->id(), select_path)); + AddStep(CreateSelectStep(select_expr, expr.id(), + options_.enable_empty_wrapper_null_unboxing, + enable_optional_types_)); } // Call node handler group. // We provide finer granularity for Call node callbacks to allow special // handling for short-circuiting // PreVisitCall is invoked before child nodes are processed. - void PreVisitCall(const Call* call_expr, const Expr* expr, - const SourcePosition*) override { + void PreVisitCall(const cel::Expr& expr, + const cel::CallExpr& call_expr) override { if (!progress_status_.ok()) { return; } std::unique_ptr cond_visitor; - if (call_expr->function() == builtin::kAnd) { - cond_visitor = absl::make_unique( - this, /* cond_value= */ false, short_circuiting_); - } else if (call_expr->function() == builtin::kOr) { - cond_visitor = absl::make_unique( - this, /* cond_value= */ true, short_circuiting_); - } else if (call_expr->function() == builtin::kTernary) { - if (short_circuiting_) { - cond_visitor = absl::make_unique(this); + if (call_expr.function() == cel::builtin::kAnd) { + cond_visitor = std::make_unique( + this, BinaryCond::kAnd, options_.short_circuiting); + } else if (call_expr.function() == cel::builtin::kOr) { + cond_visitor = std::make_unique( + this, BinaryCond::kOr, options_.short_circuiting); + } else if (call_expr.function() == cel::builtin::kTernary) { + if (options_.short_circuiting) { + cond_visitor = std::make_unique(this); } else { - cond_visitor = absl::make_unique(this); + cond_visitor = std::make_unique(this); + } + } else if (enable_optional_types_ && + call_expr.function() == kOptionalOrFn && + call_expr.has_target() && call_expr.args().size() == 1) { + cond_visitor = std::make_unique( + this, BinaryCond::kOptionalOr, options_.short_circuiting); + } else if (enable_optional_types_ && + call_expr.function() == kOptionalOrValueFn && + call_expr.has_target() && call_expr.args().size() == 1) { + cond_visitor = std::make_unique( + this, BinaryCond::kOptionalOrValue, options_.short_circuiting); + } else if (IsBlock(&call_expr)) { + // cel.@block + if (block_.has_value()) { + // There can only be one for now. + SetProgressStatusIfError( + absl::InvalidArgumentError("multiple cel.@block are not allowed")); + return; + } + block_ = BlockInfo(); + BlockInfo& block = *block_; + block.in = true; + if (call_expr.args().empty()) { + SetProgressStatusIfError(absl::InvalidArgumentError( + "malformed cel.@block: missing list of bound expressions")); + return; + } + if (call_expr.args().size() != 2) { + SetProgressStatusIfError(absl::InvalidArgumentError( + "malformed cel.@block: missing bound expression")); + return; + } + if (!call_expr.args()[0].has_list_expr()) { + SetProgressStatusIfError( + absl::InvalidArgumentError("malformed cel.@block: first argument " + "is not a list of bound expressions")); + return; + } + const auto& list_expr = call_expr.args().front().list_expr(); + block.size = list_expr.elements().size(); + + block.bindings_set.reserve(block.size); + for (const auto& list_expr_element : list_expr.elements()) { + if (list_expr_element.optional()) { + SetProgressStatusIfError( + absl::InvalidArgumentError("malformed cel.@block: list of bound " + "expressions contains an optional")); + return; + } + block.bindings_set.insert(&list_expr_element.expr()); } + block.index = index_manager().ReserveSlots(block.size); + block.slot_count = block.size; + block.expr = &expr; + block.bindings = &call_expr.args()[0]; + block.bound = &call_expr.args()[1]; + block.subexpressions.resize(block.size, -1); } else { return; } if (cond_visitor) { - cond_visitor->PreVisit(expr); - cond_visitor_stack_.emplace(expr, std::move(cond_visitor)); + cond_visitor->PreVisit(&expr); + cond_visitor_stack_.push({&expr, std::move(cond_visitor)}); + } + } + + // Returns the maximum recursion depth of the current program if it is + // eligible for recursion, or nullopt if it is not. + std::optional RecursionEligible() { + if (!PlanRecursiveProgram() || program_builder_.current() == nullptr) { + return std::nullopt; } + return program_builder_.current()->RecursiveDependencyDepth(); + } + + std::vector> + ExtractRecursiveDependencies() { + // Must check recursion eligibility before calling. + ABSL_DCHECK(program_builder_.current() != nullptr); + + return program_builder_.current()->ExtractRecursiveDependencies(); + } + + void MakeTernaryRecursive(const cel::Expr* expr) { + if (expr->call_expr().args().size() != 3) { + SetProgressStatusIfError(absl::InvalidArgumentError( + "unexpected number of args for builtin ternary")); + return; + } + + const cel::Expr* condition_expr = &expr->call_expr().args()[0]; + const cel::Expr* left_expr = &expr->call_expr().args()[1]; + const cel::Expr* right_expr = &expr->call_expr().args()[2]; + + auto* condition_plan = program_builder_.GetSubexpression(condition_expr); + auto* left_plan = program_builder_.GetSubexpression(left_expr); + auto* right_plan = program_builder_.GetSubexpression(right_expr); + + if (condition_plan == nullptr || !condition_plan->IsRecursive() || + left_plan == nullptr || !left_plan->IsRecursive() || + right_plan == nullptr || !right_plan->IsRecursive()) { + SetProgressStatusIfError(FailedRecursivePlanning()); + return; + } + + int max_depth = std::max({0, condition_plan->recursive_program().depth, + left_plan->recursive_program().depth, + right_plan->recursive_program().depth}); + + SetRecursiveStep( + CreateDirectTernaryStep(condition_plan->ExtractRecursiveProgram().step, + left_plan->ExtractRecursiveProgram().step, + right_plan->ExtractRecursiveProgram().step, + expr->id(), options_.short_circuiting), + max_depth + 1); + } + + void MakeShortcircuitRecursive(const cel::Expr* expr, bool is_or) { + int args_size = expr->call_expr().args().size(); + if (args_size < 2) { + SetProgressStatusIfError(absl::InvalidArgumentError( + "unexpected number of args for builtin boolean operator &&/||")); + return; + } + + auto* current_plan = + program_builder_.GetSubexpression(&expr->call_expr().args()[0]); + if (current_plan == nullptr || !current_plan->IsRecursive()) { + SetProgressStatusIfError(FailedRecursivePlanning()); + return; + } + int current_depth = current_plan->recursive_program().depth; + std::unique_ptr current_step = + current_plan->ExtractRecursiveProgram().step; + + for (int i = 1; i < args_size; ++i) { + auto* next_plan = + program_builder_.GetSubexpression(&expr->call_expr().args()[i]); + if (next_plan == nullptr || !next_plan->IsRecursive()) { + SetProgressStatusIfError(FailedRecursivePlanning()); + return; + } + current_depth = + std::max(current_depth, next_plan->recursive_program().depth); + std::unique_ptr next_step = + next_plan->ExtractRecursiveProgram().step; + if (is_or) { + current_step = + CreateDirectOrStep(std::move(current_step), std::move(next_step), + expr->id(), options_.short_circuiting); + } else { + current_step = + CreateDirectAndStep(std::move(current_step), std::move(next_step), + expr->id(), options_.short_circuiting); + } + current_depth++; + } + SetRecursiveStep(std::move(current_step), current_depth); + } + + void MakeOptionalShortcircuit(const cel::Expr* expr, bool is_or_value) { + if (!expr->call_expr().has_target() || + expr->call_expr().args().size() != 1) { + SetProgressStatusIfError(absl::InvalidArgumentError( + "unexpected number of args for optional.or{Value}")); + return; + } + const cel::Expr* left_expr = &expr->call_expr().target(); + const cel::Expr* right_expr = &expr->call_expr().args()[0]; + + auto* left_plan = program_builder_.GetSubexpression(left_expr); + auto* right_plan = program_builder_.GetSubexpression(right_expr); + + if (left_plan == nullptr || !left_plan->IsRecursive() || + right_plan == nullptr || !right_plan->IsRecursive()) { + SetProgressStatusIfError(FailedRecursivePlanning()); + return; + } + int max_depth = std::max({0, left_plan->recursive_program().depth, + right_plan->recursive_program().depth}); + + SetRecursiveStep(CreateDirectOptionalOrStep( + expr->id(), left_plan->ExtractRecursiveProgram().step, + right_plan->ExtractRecursiveProgram().step, + is_or_value, options_.short_circuiting), + max_depth + 1); + } + + void MaybeMakeBindRecursive(const cel::Expr* expr, + const cel::ComprehensionExpr* comprehension, + size_t accu_slot) { + if (!PlanRecursiveProgram()) { + return; + } + + auto* result_plan = + program_builder_.GetSubexpression(&comprehension->result()); + + if (result_plan == nullptr || !result_plan->IsRecursive()) { + SetProgressStatusIfError(FailedRecursivePlanning()); + return; + } + + int result_depth = result_plan->recursive_program().depth; + + auto program = result_plan->ExtractRecursiveProgram(); + SetRecursiveStep( + CreateDirectBindStep(accu_slot, std::move(program.step), expr->id()), + result_depth + 1); + } + + void MaybeMakeComprehensionRecursive( + const cel::Expr* expr, const cel::ComprehensionExpr* comprehension, + size_t iter_slot, size_t iter2_slot, size_t accu_slot) { + if (!PlanRecursiveProgram()) { + return; + } + + auto* accu_plan = + program_builder_.GetSubexpression(&comprehension->accu_init()); + auto* range_plan = + program_builder_.GetSubexpression(&comprehension->iter_range()); + auto* loop_plan = + program_builder_.GetSubexpression(&comprehension->loop_step()); + auto* condition_plan = + program_builder_.GetSubexpression(&comprehension->loop_condition()); + auto* result_plan = + program_builder_.GetSubexpression(&comprehension->result()); + if (accu_plan == nullptr || !accu_plan->IsRecursive() || + range_plan == nullptr || !range_plan->IsRecursive() || + loop_plan == nullptr || !loop_plan->IsRecursive() || + condition_plan == nullptr || !condition_plan->IsRecursive() || + result_plan == nullptr || !result_plan->IsRecursive()) { + SetProgressStatusIfError(FailedRecursivePlanning()); + return; + } + + int max_depth = 0; + max_depth = std::max(max_depth, accu_plan->recursive_program().depth); + max_depth = std::max(max_depth, range_plan->recursive_program().depth); + max_depth = std::max(max_depth, loop_plan->recursive_program().depth); + max_depth = std::max(max_depth, condition_plan->recursive_program().depth); + max_depth = std::max(max_depth, result_plan->recursive_program().depth); + + auto step = CreateDirectComprehensionStep( + iter_slot, iter2_slot, accu_slot, + range_plan->ExtractRecursiveProgram().step, + accu_plan->ExtractRecursiveProgram().step, + loop_plan->ExtractRecursiveProgram().step, + condition_plan->ExtractRecursiveProgram().step, + result_plan->ExtractRecursiveProgram().step, options_.short_circuiting, + expr->id()); + + SetRecursiveStep(std::move(step), max_depth + 1); } // Invoked after all child nodes are processed. - void PostVisitCall(const Call* call_expr, const Expr* expr, - const SourcePosition*) override { + void PostVisitCall(const cel::Expr& expr, + const cel::CallExpr& call_expr) override { if (!progress_status_.ok()) { return; } - auto cond_visitor = FindCondVisitor(expr); + auto cond_visitor = FindCondVisitor(&expr); if (cond_visitor) { - cond_visitor->PostVisit(expr); + cond_visitor->PostVisit(&expr); cond_visitor_stack_.pop(); - } else { - // Special case for "_[_]". - if (call_expr->function() == builtin::kIndex) { - AddStep(CreateContainerAccessStep(call_expr, expr->id())); + return; + } + + // Check if the call is intercepted by a custom handler. + if (auto handler = call_handlers_.find(call_expr.function()); + handler != call_handlers_.end()) { + CallHandlerResult result = handler->second(expr, call_expr); + if (result == CallHandlerResult::kIntercepted) { return; - } - // For regular functions, just create one based on registry. - AddStep(CreateFunctionStep(call_expr, expr->id(), *function_registry_, - builder_warnings_)); + } // otherwise, apply default function handling. } + + AddResolvedFunctionStep(&call_expr, &expr, call_expr.function()); } - void PreVisitComprehension(const Comprehension*, const Expr* expr, - const SourcePosition*) override { + void PreVisitComprehension( + const cel::Expr& expr, + const cel::ComprehensionExpr& comprehension) override { if (!progress_status_.ok()) { return; } - if (!enable_comprehension_) { - SetProgressStatusError(absl::Status(absl::StatusCode::kInvalidArgument, - "Comprehension support is disabled")); + if (!ValidateOrError(options_.enable_comprehension, + "Comprehension support is disabled")) { + return; + } + const auto& accu_var = comprehension.accu_var(); + const auto& iter_var = comprehension.iter_var(); + const auto& iter_var2 = comprehension.iter_var2(); + ValidateOrError(!accu_var.empty(), + "Invalid comprehension: 'accu_var' must not be empty"); + ValidateOrError(!iter_var.empty(), + "Invalid comprehension: 'iter_var' must not be empty"); + ValidateOrError( + accu_var != iter_var, + "Invalid comprehension: 'accu_var' must not be the same as 'iter_var'"); + ValidateOrError(accu_var != iter_var2, + "Invalid comprehension: 'accu_var' must not be the same as " + "'iter_var2'"); + ValidateOrError(iter_var2 != iter_var, + "Invalid comprehension: 'iter_var2' must not be the same " + "as 'iter_var'"); + ValidateOrError(comprehension.has_accu_init(), + "Invalid comprehension: 'accu_init' must be set"); + ValidateOrError(comprehension.has_loop_condition(), + "Invalid comprehension: 'loop_condition' must be set"); + ValidateOrError(comprehension.has_loop_step(), + "Invalid comprehension: 'loop_step' must be set"); + ValidateOrError(comprehension.has_result(), + "Invalid comprehension: 'result' must be set"); + + size_t iter_slot, iter2_slot, accu_slot, slot_count; + bool is_bind = IsBind(&comprehension); + + if (is_bind) { + accu_slot = iter_slot = iter2_slot = index_manager_.ReserveSlots(1); + slot_count = 1; + } else if (comprehension.iter_var2().empty()) { + iter_slot = iter2_slot = index_manager_.ReserveSlots(2); + accu_slot = iter_slot + 1; + slot_count = 2; + } else { + iter_slot = index_manager_.ReserveSlots(3); + iter2_slot = iter_slot + 1; + accu_slot = iter2_slot + 1; + slot_count = 3; + } + + if (block_.has_value()) { + BlockInfo& block = *block_; + if (block.in) { + block.slot_count += slot_count; + slot_count = 0; + } + } + // If this is in the scope of an optimized bind accu-init, account the slots + // to the outermost bind-init scope. + // + // The init expression is effectively inlined at the first usage in the + // critical path (which is unknown at plan time), so the used slots need to + // be dedicated for the entire scope of that bind. + for (ComprehensionStackRecord& record : comprehension_stack_) { + if (record.in_accu_init && record.is_optimizable_bind) { + record.slot_count += slot_count; + slot_count = 0; + break; + } + // If no bind init subexpression, account normally. } - cond_visitor_stack_.emplace( - expr, absl::make_unique(this, short_circuiting_)); - auto cond_visitor = FindCondVisitor(expr); - cond_visitor->PreVisit(expr); + + comprehension_stack_.push_back( + {&expr, &comprehension, iter_slot, iter2_slot, accu_slot, slot_count, + /*subexpression=*/-1, + /*.is_optimizable_list_append=*/ + IsOptimizableListAppend(&comprehension, + options_.enable_comprehension_list_append), + /*.is_optimizable_map_insert=*/ + IsOptimizableMapInsert(&comprehension, + options_.enable_comprehension_mutable_map), + /*.is_optimizable_bind=*/is_bind, + /*.iter_var_in_scope=*/false, + /*.iter_var2_in_scope=*/false, + /*.accu_var_in_scope=*/false, + /*.in_accu_init=*/false, + std::make_unique(this, options_.short_circuiting, + is_bind, iter_slot, iter2_slot, + accu_slot)}); + comprehension_stack_.back().visitor->PreVisit(&expr); } // Invoked after all child nodes are processed. - void PostVisitComprehension(const Comprehension* comprehension_expr, - const Expr* expr, - const SourcePosition*) override { + void PostVisitComprehension( + const cel::Expr& expr, + const cel::ComprehensionExpr& comprehension_expr) override { if (!progress_status_.ok()) { return; } - auto cond_visitor = FindCondVisitor(expr); - cond_visitor->PostVisit(expr); - cond_visitor_stack_.pop(); - // Save off the names of the variables we're using, such that we have a - // full set of the names from the entire evaluation tree at the end. - if (!comprehension_expr->accu_var().empty()) { - iter_variable_names_->insert(comprehension_expr->accu_var()); + ComprehensionStackRecord& record = comprehension_stack_.back(); + if (comprehension_stack_.empty() || + record.comprehension != &comprehension_expr) { + return; } - if (!comprehension_expr->iter_var().empty()) { - iter_variable_names_->insert(comprehension_expr->iter_var()); + + record.visitor->PostVisit(&expr); + + index_manager_.ReleaseSlots(record.slot_count); + comprehension_stack_.pop_back(); + } + + void PreVisitComprehensionSubexpression( + const cel::Expr& expr, const cel::ComprehensionExpr& compr, + cel::ComprehensionArg comprehension_arg) override { + if (!progress_status_.ok()) { + return; + } + + if (comprehension_stack_.empty() || + comprehension_stack_.back().comprehension != &compr) { + return; + } + + ComprehensionStackRecord& record = comprehension_stack_.back(); + + switch (comprehension_arg) { + case cel::ITER_RANGE: { + record.in_accu_init = false; + record.iter_var_in_scope = false; + record.iter_var2_in_scope = false; + record.accu_var_in_scope = false; + break; + } + case cel::ACCU_INIT: { + record.in_accu_init = true; + record.iter_var_in_scope = false; + record.iter_var2_in_scope = false; + record.accu_var_in_scope = false; + break; + } + case cel::LOOP_CONDITION: { + record.in_accu_init = false; + record.iter_var_in_scope = true; + record.iter_var2_in_scope = true; + record.accu_var_in_scope = true; + break; + } + case cel::LOOP_STEP: { + record.in_accu_init = false; + record.iter_var_in_scope = true; + record.iter_var2_in_scope = true; + record.accu_var_in_scope = true; + break; + } + case cel::RESULT: { + record.in_accu_init = false; + record.iter_var_in_scope = false; + record.iter_var2_in_scope = false; + record.accu_var_in_scope = true; + break; + } + } + } + + void PostVisitComprehensionSubexpression( + const cel::Expr& expr, const cel::ComprehensionExpr& compr, + cel::ComprehensionArg comprehension_arg) override { + if (!progress_status_.ok()) { + return; + } + + if (comprehension_stack_.empty() || + comprehension_stack_.back().comprehension != &compr) { + return; } + + SetProgressStatusIfError(comprehension_stack_.back().visitor->PostVisitArg( + comprehension_arg, comprehension_stack_.back().expr)); } // Invoked after each argument node processed. - void PostVisitArg(int arg_num, const Expr* expr, - const SourcePosition*) override { + void PostVisitArg(const cel::Expr& expr, int arg_num) override { + if (!progress_status_.ok()) { + return; + } + auto cond_visitor = FindCondVisitor(&expr); + if (cond_visitor) { + cond_visitor->PostVisitArg(arg_num, &expr); + } + } + + void PostVisitTarget(const cel::Expr& expr) override { if (!progress_status_.ok()) { return; } - auto cond_visitor = FindCondVisitor(expr); + auto cond_visitor = FindCondVisitor(&expr); if (cond_visitor) { - cond_visitor->PostVisitArg(arg_num, expr); + cond_visitor->PostVisitTarget(&expr); } } // CreateList node handler. // Invoked after child nodes are processed. - void PostVisitCreateList(const CreateList* list_expr, const Expr* expr, - const SourcePosition*) override { + void PostVisitList(const cel::Expr& expr, + const cel::ListExpr& list_expr) override { if (!progress_status_.ok()) { return; } - AddStep(CreateCreateListStep(list_expr, expr->id())); + if (block_.has_value()) { + BlockInfo& block = *block_; + if (block.bindings == &expr) { + // Do nothing, this is the cel.@block bindings list. + return; + } + } + + if (!comprehension_stack_.empty()) { + const ComprehensionStackRecord& comprehension = + comprehension_stack_.back(); + if (comprehension.is_optimizable_list_append) { + if (&(comprehension.comprehension->accu_init()) == &expr) { + if (PlanRecursiveProgram()) { + SetRecursiveStep(CreateDirectMutableListStep(expr.id()), 1); + return; + } + AddStep(CreateMutableListStep(expr.id())); + return; + } + if (GetOptimizableListAppendOperand(comprehension.comprehension) == + &expr) { + return; + } + } + } + if (std::optional depth = RecursionEligible(); depth.has_value()) { + auto deps = ExtractRecursiveDependencies(); + if (deps.size() != list_expr.elements().size()) { + SetProgressStatusIfError(absl::InternalError( + "Unexpected number of plan elements for CreateList expr")); + return; + } + auto step = CreateDirectListStep( + std::move(deps), MakeOptionalIndicesSet(list_expr), expr.id()); + SetRecursiveStep(std::move(step), *depth + 1); + return; + } + AddStep(CreateCreateListStep(list_expr, expr.id())); } // CreateStruct node handler. // Invoked after child nodes are processed. - void PostVisitCreateStruct(const CreateStruct* struct_expr, const Expr* expr, - const SourcePosition*) override { + void PostVisitStruct(const cel::Expr& expr, + const cel::StructExpr& struct_expr) override { if (!progress_status_.ok()) { return; } - AddStep(CreateCreateStructStep(struct_expr, expr->id())); + auto status_or_resolved_fields = + ResolveCreateStructFields(struct_expr, expr.id()); + if (!status_or_resolved_fields.ok()) { + SetProgressStatusIfError(status_or_resolved_fields.status()); + return; + } + std::string resolved_name = + std::move(status_or_resolved_fields.value().first); + std::vector fields = + std::move(status_or_resolved_fields.value().second); + + if (auto depth = RecursionEligible(); depth.has_value()) { + auto deps = ExtractRecursiveDependencies(); + if (deps.size() != struct_expr.fields().size()) { + SetProgressStatusIfError(absl::InternalError( + "Unexpected number of plan elements for CreateStruct expr")); + return; + } + auto step = CreateDirectCreateStructStep( + std::move(resolved_name), std::move(fields), std::move(deps), + MakeOptionalIndicesSet(struct_expr), expr.id()); + SetRecursiveStep(std::move(step), *depth + 1); + return; + } + + AddStep(CreateCreateStructStep(std::move(resolved_name), std::move(fields), + MakeOptionalIndicesSet(struct_expr), + expr.id())); + } + + void PostVisitMap(const cel::Expr& expr, + const cel::MapExpr& map_expr) override { + for (const auto& entry : map_expr.entries()) { + ValidateOrError(entry.has_key(), "Map entry missing key"); + ValidateOrError(entry.has_value(), "Map entry missing value"); + } + + if (!comprehension_stack_.empty()) { + const ComprehensionStackRecord& comprehension = + comprehension_stack_.back(); + if (comprehension.is_optimizable_map_insert) { + if (&(comprehension.comprehension->accu_init()) == &expr) { + if (PlanRecursiveProgram()) { + SetRecursiveStep(CreateDirectMutableMapStep(expr.id()), 1); + return; + } + AddStep(CreateMutableMapStep(expr.id())); + return; + } + } + } + + if (auto depth = RecursionEligible(); depth.has_value()) { + auto deps = ExtractRecursiveDependencies(); + if (deps.size() != 2 * map_expr.entries().size()) { + SetProgressStatusIfError(absl::InternalError( + "Unexpected number of plan elements for CreateStruct expr")); + return; + } + auto step = CreateDirectCreateMapStep( + std::move(deps), MakeOptionalIndicesSet(map_expr), expr.id()); + SetRecursiveStep(std::move(step), *depth + 1); + return; + } + AddStep(CreateCreateStructStepForMap(map_expr.entries().size(), + MakeOptionalIndicesSet(map_expr), + expr.id())); } absl::Status progress_status() const { return progress_status_; } - void AddStep(absl::StatusOr> step_status) { - if (step_status.ok() && progress_status_.ok()) { - flattened_path_->push_back(std::move(step_status.value())); + // Mark a branch as suppressed. The visitor will continue as normal, but + // any emitted program steps are ignored. + // + // Only applies to branches that have not yet been visited (pre-order). + void SuppressBranch(const cel::Expr* expr) { + suppressed_branches_.insert(expr); + } + + void AddResolvedFunctionStep(const cel::CallExpr* call_expr, + const cel::Expr* expr, + absl::string_view function) { + // Establish the search criteria for a given function. + bool receiver_style = call_expr->has_target(); + size_t num_args = call_expr->args().size() + (receiver_style ? 1 : 0); + + // First, search for lazily defined function overloads. + // Lazy functions shadow eager functions with the same signature. + auto lazy_overloads = resolver_.FindLazyOverloads( + function, call_expr->has_target(), num_args, expr->id()); + if (!lazy_overloads.empty()) { + if (auto depth = RecursionEligible(); depth.has_value()) { + auto args = program_builder_.current()->ExtractRecursiveDependencies(); + SetRecursiveStep(CreateDirectLazyFunctionStep( + expr->id(), *call_expr, std::move(args), + std::move(lazy_overloads)), + *depth + 1); + return; + } + AddStep(CreateFunctionStep(*call_expr, expr->id(), + std::move(lazy_overloads))); + return; + } + + // Second, search for eagerly defined function overloads. + auto overloads = + resolver_.FindOverloads(function, receiver_style, num_args, expr->id()); + if (overloads.empty()) { + // Create a warning that the overload could not be found. Depending on the + // builder_warnings configuration, this could result in termination of the + // CelExpression creation or an inspectable warning for use within runtime + // logging. + auto status = issue_collector_.AddIssue(RuntimeIssue::CreateWarning( + absl::InvalidArgumentError( + "No overloads provided for FunctionStep creation"), + RuntimeIssue::ErrorCode::kNoMatchingOverload)); + if (!status.ok()) { + SetProgressStatusIfError(status); + return; + } + } + + if (auto recursion_depth = RecursionEligible(); + recursion_depth.has_value()) { + // Nonnull while active -- nullptr indicates logic error elsewhere in the + // builder. + ABSL_DCHECK(program_builder_.current() != nullptr); + auto args = program_builder_.current()->ExtractRecursiveDependencies(); + SetRecursiveStep( + CreateDirectFunctionStep(expr->id(), *call_expr, std::move(args), + std::move(overloads)), + *recursion_depth + 1); + return; + } + AddStep(CreateFunctionStep(*call_expr, expr->id(), std::move(overloads))); + } + + // Add a step to the program, taking ownership. If successful, returns the + // pointer to the step. Otherwise, returns nullptr. + // + // Note: the pointer is only guaranteed to stay valid until the parent + // subexpression is finalized. Optimizers may modify the program plan which + // may free the step at that point. + ExpressionStep* AddStep( + absl::StatusOr> step) { + if (step.ok()) { + return AddStep(*std::move(step)); } else { - SetProgressStatusError(step_status.status()); + SetProgressStatusIfError(step.status()); } + return nullptr; } - void AddStep(std::unique_ptr step) { - if (progress_status_.ok()) { - flattened_path_->push_back(std::move(step)); + template + std::enable_if_t, T*> AddStep( + std::unique_ptr step) { + if (progress_status_.ok() && !PlanningSuppressed()) { + return static_cast(program_builder_.AddStep(std::move(step))); + } + return nullptr; + } + + void SetRecursiveStep(std::unique_ptr step, int depth) { + if (!progress_status_.ok() || PlanningSuppressed()) { + return; + } + if (program_builder_.current() == nullptr) { + SetProgressStatusIfError(absl::InternalError( + "CEL AST traversal out of order in flat_expr_builder.")); + return; + } + program_builder_.current()->set_recursive_program(std::move(step), depth); + if (depth > max_recursion_depth_) { + SetProgressStatusIfError(absl::InvalidArgumentError( + absl::StrCat("Maximum recursion depth of ", + options_.max_recursion_depth, " exceeded"))); } } - void SetProgressStatusError(const absl::Status& status) { + void SetProgressStatusIfError(const absl::Status& status) { if (progress_status_.ok() && !status.ok()) { progress_status_ = status; } } - // Index of the next step to be inserted. - int GetCurrentIndex() const { return flattened_path_->size(); } + // Index of the next step to be inserted, in terms of the current + // subexpression + ProgramStepIndex GetCurrentIndex() const { + // Nonnull while active -- nullptr indicates logic error in the builder. + ABSL_DCHECK(program_builder_.current() != nullptr); + return {static_cast(program_builder_.current()->elements().size()), + program_builder_.current()}; + } - CondVisitor* FindCondVisitor(const Expr* expr) const { + CondVisitor* FindCondVisitor(const cel::Expr* expr) const { if (cond_visitor_stack_.empty()) { return nullptr; } @@ -469,75 +1755,547 @@ class FlatExprVisitor : public AstVisitor { return (latest.first == expr) ? latest.second.get() : nullptr; } + IndexManager& index_manager() { return index_manager_; } + + size_t slot_count() const { return index_manager_.max_slot_count(); } + + void AddOptimizer(std::unique_ptr optimizer) { + program_optimizers_.push_back(std::move(optimizer)); + } + + // Tests the boolean predicate, and if false produces an InvalidArgumentError + // which concatenates the error_message and any optional message_parts as the + // error status message. + template + bool ValidateOrError(bool valid_expression, absl::string_view error_message, + MP... message_parts) { + if (valid_expression) { + return true; + } + SetProgressStatusIfError(absl::InvalidArgumentError( + absl::StrCat(error_message, message_parts...))); + return false; + } + private: - ExecutionPath* flattened_path_; + struct ComprehensionStackRecord { + const cel::Expr* expr; + const cel::ComprehensionExpr* comprehension; + size_t iter_slot; + size_t iter2_slot; + size_t accu_slot; + size_t slot_count; + // -1 indicates this shouldn't be used. + int subexpression; + bool is_optimizable_list_append; + bool is_optimizable_map_insert; + bool is_optimizable_bind; + bool iter_var_in_scope; + bool iter_var2_in_scope; + bool accu_var_in_scope; + bool in_accu_init; + std::unique_ptr visitor; + }; + + struct BlockInfo { + // True if we are currently visiting the `cel.@block` node or any of its + // children. + bool in = false; + // Pointer to the `cel.@block` node. + const cel::Expr* expr = nullptr; + // Pointer to the `cel.@block` bindings, that is the first argument to the + // function. + const cel::Expr* bindings = nullptr; + // Set of pointers to the elements of `bindings` above. + absl::flat_hash_set bindings_set; + // Pointer to the `cel.@block` bound expression, that is the second argument + // to the function. + const cel::Expr* bound = nullptr; + // The number of entries in the `cel.@block`. + size_t size = 0; + // Starting slot index for `cel.@block`. We occupy he slot indices `index` + // through `index + size + (var_size * 2)`. + size_t index = 0; + // The total number of slots needed for evaluating the bound expressions. + size_t slot_count = 0; + // The current slot index we are processing, any index references must be + // less than this to be valid. + size_t current_index = 0; + // Pointer to the current `cel.@block` being processed, that is one of the + // elements within the first argument. + const cel::Expr* current_binding = nullptr; + // Mapping between block indices and their subexpressions, fixed size with + // exactly `size` elements. Unprocessed indices are set to `-1`. + std::vector subexpressions; + }; + + bool PlanningSuppressed() const { + return resume_from_suppressed_branch_ != nullptr; + } + + absl::Status MaybeExtractSubexpression(const cel::Expr* expr, + ComprehensionStackRecord& record) { + if (!record.is_optimizable_bind) { + return absl::OkStatus(); + } + + int index = program_builder_.ExtractSubexpression(expr); + if (index == -1) { + return absl::InternalError("Failed to extract subexpression"); + } + + record.subexpression = index; + + record.visitor->MarkAccuInitExtracted(); + + return absl::OkStatus(); + } + + // Resolve the name of the message type being created and the names of set + // fields. + absl::StatusOr>> + ResolveCreateStructFields(const cel::StructExpr& create_struct_expr, + int64_t expr_id) { + absl::string_view ast_name = create_struct_expr.name(); + + std::optional> type; + CEL_ASSIGN_OR_RETURN(type, resolver_.FindType(ast_name, expr_id)); + + if (!type.has_value()) { + return absl::InvalidArgumentError(absl::StrCat( + "Invalid struct creation: missing type info for '", ast_name, "'")); + } + + std::string resolved_name = std::move(type).value().first; + + std::vector fields; + fields.reserve(create_struct_expr.fields().size()); + for (const auto& entry : create_struct_expr.fields()) { + if (entry.name().empty()) { + return absl::InvalidArgumentError("Struct field missing name"); + } + if (!entry.has_value()) { + return absl::InvalidArgumentError("Struct field missing value"); + } + CEL_ASSIGN_OR_RETURN(auto field, type_provider_.FindStructTypeFieldByName( + resolved_name, entry.name())); + if (!field.has_value()) { + return absl::InvalidArgumentError( + absl::StrCat("Invalid message creation: field '", entry.name(), + "' not found in '", resolved_name, "'")); + } + fields.push_back(entry.name()); + } + + return std::make_pair(std::move(resolved_name), std::move(fields)); + } + + CallHandlerResult HandleIndex(const cel::Expr& expr, + const cel::CallExpr& call); + CallHandlerResult HandleBlock(const cel::Expr& expr, + const cel::CallExpr& call); + CallHandlerResult HandleListAppend(const cel::Expr& expr, + const cel::CallExpr& call); + CallHandlerResult HandleNot(const cel::Expr& expr, const cel::CallExpr& call); + CallHandlerResult HandleNotStrictlyFalse(const cel::Expr& expr, + const cel::CallExpr& call); + + CallHandlerResult HandleHeterogeneousEquality(const cel::Expr& expr, + const cel::CallExpr& call, + bool inequality); + + CallHandlerResult HandleHeterogeneousEqualityIn(const cel::Expr& expr, + const cel::CallExpr& call); + + const Resolver& resolver_; + const cel::TypeProvider& type_provider_; absl::Status progress_status_; + absl::flat_hash_map call_handlers_; - std::stack>> + std::stack>> cond_visitor_stack_; - // Maps effective namespace names to Expr objects (IDENTs/SELECTs) that - // define scopes for those namespaces. - std::unordered_map namespace_map_; // Tracks SELECT-...SELECT-IDENT chains. - std::deque> namespace_stack_; + std::deque> namespace_stack_; // When multiple SELECT-...SELECT-IDENT chain is resolved as namespace, this // field is used as marker suppressing CelExpression creation for SELECTs. - const Expr* resolved_select_expr_; + const cel::Expr* resolved_select_expr_; - // Fully resolved enum value names. - absl::node_hash_map - enum_map_; + const cel::RuntimeOptions& options_; - const CelFunctionRegistry* function_registry_; + std::vector comprehension_stack_; + absl::flat_hash_set suppressed_branches_; + const cel::Expr* resume_from_suppressed_branch_ = nullptr; + std::vector> program_optimizers_; + IssueCollector& issue_collector_; - bool short_circuiting_; + ProgramBuilder& program_builder_; + PlannerContext& extension_context_; + IndexManager index_manager_; - const absl::flat_hash_map& constant_idents_; + bool enable_optional_types_; + std::optional block_; + int max_recursion_depth_ = 0; +}; - bool enable_comprehension_; +FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleIndex( + const cel::Expr& expr, const cel::CallExpr& call_expr) { + ABSL_DCHECK(call_expr.function() == cel::builtin::kIndex); + if (!ValidateOrError( + (call_expr.args().size() == 2 && !call_expr.has_target()) || + // TODO(uncreated-issue/79): A few clients use the index operator with a + // target in custom ASTs. + (call_expr.args().size() == 1 && call_expr.has_target()), + "unexpected number of args for builtin index operator")) { + return CallHandlerResult::kIntercepted; + } - BuilderWarnings* builder_warnings_; + if (auto depth = RecursionEligible(); depth.has_value()) { + auto args = ExtractRecursiveDependencies(); + if (args.size() != 2) { + SetProgressStatusIfError(absl::InvalidArgumentError( + "unexpected number of args for builtin index operator")); + return CallHandlerResult::kIntercepted; + } + SetRecursiveStep( + CreateDirectContainerAccessStep(std::move(args[0]), std::move(args[1]), + enable_optional_types_, expr.id()), + *depth + 1); + return CallHandlerResult::kIntercepted; + } + AddStep( + CreateContainerAccessStep(call_expr, expr.id(), enable_optional_types_)); + return CallHandlerResult::kIntercepted; +} - std::set* iter_variable_names_; -}; +FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleNot( + const cel::Expr& expr, const cel::CallExpr& call_expr) { + ABSL_DCHECK(call_expr.function() == cel::builtin::kNot); -void BinaryCondVisitor::PreVisit(const Expr* expr) { - if (expr->call_expr().args().size() != 2) { - visitor_->SetProgressStatusError(absl::InvalidArgumentError( - "Unexpected number of arguments in a binary function call.")); + if (!ValidateOrError(call_expr.args().size() == 1 && !call_expr.has_target(), + "unexpected number of args for builtin not operator")) { + return CallHandlerResult::kIntercepted; } + + if (auto depth = RecursionEligible(); depth.has_value()) { + auto args = ExtractRecursiveDependencies(); + if (args.size() != 1) { + SetProgressStatusIfError(absl::InvalidArgumentError( + "unexpected number of args for builtin not operator")); + return CallHandlerResult::kIntercepted; + } + SetRecursiveStep(CreateDirectNotStep(std::move(args[0]), expr.id()), + *depth + 1); + return CallHandlerResult::kIntercepted; + } + AddStep(CreateNotStep(expr.id())); + return CallHandlerResult::kIntercepted; } -void BinaryCondVisitor::PostVisitArg(int arg_num, const Expr* expr) { - if (!short_circuiting_) { - // nothing to do. +FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleNotStrictlyFalse( + const cel::Expr& expr, const cel::CallExpr& call_expr) { + if (!ValidateOrError(call_expr.args().size() == 1 && !call_expr.has_target(), + "unexpected number of args for builtin " + "not_strictly_false operator")) { + return CallHandlerResult::kIntercepted; + } + + if (auto depth = RecursionEligible(); depth.has_value()) { + auto args = ExtractRecursiveDependencies(); + if (args.size() != 1) { + SetProgressStatusIfError( + absl::InvalidArgumentError("unexpected number of args for builtin " + "@not_strictly_false operator")); + return CallHandlerResult::kIntercepted; + } + SetRecursiveStep( + CreateDirectNotStrictlyFalseStep(std::move(args[0]), expr.id()), + *depth + 1); + return CallHandlerResult::kIntercepted; + } + AddStep(CreateNotStrictlyFalseStep(expr.id())); + return CallHandlerResult::kIntercepted; +} + +FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleBlock( + const cel::Expr& expr, const cel::CallExpr& call_expr) { + ABSL_DCHECK(call_expr.function() == kBlock); + if (!block_.has_value() || block_->expr != &expr || + call_expr.args().size() != 2 || call_expr.has_target()) { + SetProgressStatusIfError( + absl::InvalidArgumentError("unexpected call to internal cel.@block")); + return CallHandlerResult::kIntercepted; + } + + BlockInfo& block = *block_; + block.in = false; + index_manager().ReleaseSlots(block.slot_count); + + // Check if eligible for recursion and update the plan if so. + // + // The first argument to @block is the list of initializers. These don't + // generate a plan in the main program (they are tracked separately to support + // lazy evaluation) so we only need to extract the second argument -- the body + // of the block that uses the initializers. + ProgramBuilder::Subexpression* body_subexpression = + program_builder_.GetSubexpression(&call_expr.args()[1]); + + if (options_.max_recursion_depth != 0 && body_subexpression != nullptr && + body_subexpression->IsRecursive() && + (options_.max_recursion_depth < 0 || + body_subexpression->recursive_program().depth < + options_.max_recursion_depth)) { + auto recursive_program = body_subexpression->ExtractRecursiveProgram(); + SetRecursiveStep( + CreateDirectBlockStep(block.index, block.slot_count, + std::move(recursive_program.step), expr.id()), + recursive_program.depth + 1); + return CallHandlerResult::kIntercepted; + } + + // Otherwise, iterative plan. + if (block.slot_count > 0) { + AddStep(CreateClearSlotsStep(block.index, block.slot_count, expr.id())); + } + + return CallHandlerResult::kIntercepted; +} + +FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleListAppend( + const cel::Expr& expr, const cel::CallExpr& call_expr) { + ABSL_DCHECK(call_expr.function() == cel::builtin::kAdd); + + // Check to see if this is a special case of add that should really be + // treated as a list append + if (!comprehension_stack_.empty() && + comprehension_stack_.back().is_optimizable_list_append) { + // Already checked that this is an optimizeable comprehension, + // check that this is the correct list append node. + const cel::ComprehensionExpr* comprehension = + comprehension_stack_.back().comprehension; + const cel::Expr& loop_step = comprehension->loop_step(); + // Macro loop_step for a map() will contain a list concat operation: + // accu_var + [elem] + if (&loop_step == &expr) { + AddResolvedFunctionStep(&call_expr, &expr, + cel::builtin::kRuntimeListAppend); + return CallHandlerResult::kIntercepted; + } + // Macro loop_step for a filter() will contain a ternary: + // filter ? accu_var + [elem] : accu_var + if (loop_step.has_call_expr() && + loop_step.call_expr().function() == cel::builtin::kTernary && + loop_step.call_expr().args().size() == 3 && + &(loop_step.call_expr().args()[1]) == &expr) { + AddResolvedFunctionStep(&call_expr, &expr, + cel::builtin::kRuntimeListAppend); + return CallHandlerResult::kIntercepted; + } + } + + return CallHandlerResult::kNotIntercepted; +} + +FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleHeterogeneousEquality( + const cel::Expr& expr, const cel::CallExpr& call, bool inequality) { + if (!ValidateOrError( + call.args().size() == 2 && !call.has_target(), + "unexpected number of args for builtin equality operator")) { + return CallHandlerResult::kIntercepted; + } + + if (auto depth = RecursionEligible(); depth.has_value()) { + auto args = ExtractRecursiveDependencies(); + if (args.size() != 2) { + SetProgressStatusIfError(absl::InvalidArgumentError( + "unexpected number of args for builtin equality operator")); + return CallHandlerResult::kIntercepted; + } + SetRecursiveStep( + CreateDirectEqualityStep(std::move(args[0]), std::move(args[1]), + inequality, expr.id()), + *depth + 1); + return CallHandlerResult::kIntercepted; + } + AddStep(CreateEqualityStep(inequality, expr.id())); + return CallHandlerResult::kIntercepted; +} + +FlatExprVisitor::CallHandlerResult +FlatExprVisitor::HandleHeterogeneousEqualityIn(const cel::Expr& expr, + const cel::CallExpr& call) { + if (!ValidateOrError(call.args().size() == 2 && !call.has_target(), + "unexpected number of args for builtin 'in' operator")) { + return CallHandlerResult::kIntercepted; + } + + if (auto depth = RecursionEligible(); depth.has_value()) { + auto args = ExtractRecursiveDependencies(); + if (args.size() != 2) { + SetProgressStatusIfError(absl::InvalidArgumentError( + "unexpected number of args for builtin 'in' operator")); + return CallHandlerResult::kIntercepted; + } + SetRecursiveStep( + CreateDirectInStep(std::move(args[0]), std::move(args[1]), expr.id()), + *depth + 1); + return CallHandlerResult::kIntercepted; + } + + AddStep(CreateInStep(expr.id())); + return CallHandlerResult::kIntercepted; +} + +void BinaryCondVisitor::PreVisit(const cel::Expr* expr) { + switch (cond_) { + case BinaryCond::kAnd: + ABSL_FALLTHROUGH_INTENDED; + case BinaryCond::kOr: + visitor_->ValidateOrError( + !expr->call_expr().has_target() && + expr->call_expr().args().size() >= 2, + "Invalid argument count for a binary function call."); + break; + case BinaryCond::kOptionalOr: + ABSL_FALLTHROUGH_INTENDED; + case BinaryCond::kOptionalOrValue: + visitor_->ValidateOrError(expr->call_expr().has_target() && + expr->call_expr().args().size() == 1, + "Invalid argument count for or/orValue call."); + break; + } +} + +void BinaryCondVisitor::PostVisitArg(int arg_num, const cel::Expr* expr) { + if (visitor_->PlanRecursiveProgram()) { return; } - if (arg_num == 0) { + const int last_arg_index = expr->call_expr().args().size() - 1; + if (cond_ == BinaryCond::kAnd || cond_ == BinaryCond::kOr) { + if (arg_num > 0) { + switch (cond_) { + case BinaryCond::kAnd: + visitor_->AddStep(CreateAndStep(expr->id())); + break; + case BinaryCond::kOr: + visitor_->AddStep(CreateOrStep(expr->id())); + break; + default: + break; + } + if (short_circuiting_ && !jump_steps_.empty()) { + visitor_->SetProgressStatusIfError( + jump_steps_.back().set_target(visitor_->GetCurrentIndex())); + } + } + if (short_circuiting_ && arg_num < last_arg_index) { + std::unique_ptr jump_step; + switch (cond_) { + case BinaryCond::kAnd: + jump_step = CreateCondJumpStep(false, true, {}, expr->id()); + break; + case BinaryCond::kOr: + jump_step = CreateCondJumpStep(true, true, {}, expr->id()); + break; + default: + ABSL_UNREACHABLE(); + } + ProgramStepIndex index = visitor_->GetCurrentIndex(); + if (JumpStepBase* jump_step_ptr = visitor_->AddStep(std::move(jump_step)); + jump_step_ptr) { + jump_steps_.push_back(Jump(index, jump_step_ptr)); + } + } + } +} + +void BinaryCondVisitor::PostVisitTarget(const cel::Expr* expr) { + if (visitor_->PlanRecursiveProgram()) { + return; + } + if (short_circuiting_ && (cond_ == BinaryCond::kOptionalOr || + cond_ == BinaryCond::kOptionalOrValue)) { // If first branch evaluation result is enough to determine output, - // jump over the second branch and provide result as final output. - auto jump_step_status = - CreateCondJumpStep(cond_value_, true, {}, expr->id()); - if (jump_step_status.ok()) { - jump_step_ = - Jump(visitor_->GetCurrentIndex(), jump_step_status.value().get()); + // jump over the second branch and provide result of the first argument as + // final output. + // Retain a pointer to the jump step so we can update the target after + // planning the second argument. + std::unique_ptr jump_step; + switch (cond_) { + case BinaryCond::kOptionalOr: + jump_step = CreateOptionalHasValueJumpStep(false, expr->id()); + break; + case BinaryCond::kOptionalOrValue: + jump_step = CreateOptionalHasValueJumpStep(true, expr->id()); + break; + default: + ABSL_UNREACHABLE(); + } + ProgramStepIndex index = visitor_->GetCurrentIndex(); + if (JumpStepBase* jump_step_ptr = visitor_->AddStep(std::move(jump_step)); + jump_step_ptr) { + jump_steps_.push_back(Jump(index, jump_step_ptr)); } - visitor_->AddStep(std::move(jump_step_status)); } } -void BinaryCondVisitor::PostVisit(const Expr* expr) { - visitor_->AddStep((cond_value_) ? CreateOrStep(expr->id()) - : CreateAndStep(expr->id())); - if (short_circuiting_) { - jump_step_.set_target(visitor_->GetCurrentIndex()); + +void BinaryCondVisitor::PostVisit(const cel::Expr* expr) { + if (visitor_->PlanRecursiveProgram()) { + switch (cond_) { + case BinaryCond::kAnd: + visitor_->MakeShortcircuitRecursive(expr, /*is_or=*/false); + break; + case BinaryCond::kOr: + visitor_->MakeShortcircuitRecursive(expr, /*is_or=*/true); + break; + case BinaryCond::kOptionalOr: + visitor_->MakeOptionalShortcircuit(expr, + /*is_or_value=*/false); + break; + case BinaryCond::kOptionalOrValue: + visitor_->MakeOptionalShortcircuit(expr, + /*is_or_value=*/true); + break; + default: + ABSL_UNREACHABLE(); + } + return; + } + + if (cond_ == BinaryCond::kOptionalOr || + cond_ == BinaryCond::kOptionalOrValue) { + switch (cond_) { + case BinaryCond::kOptionalOr: + visitor_->AddStep( + CreateOptionalOrStep(/*is_or_value=*/false, expr->id())); + break; + case BinaryCond::kOptionalOrValue: + visitor_->AddStep( + CreateOptionalOrStep(/*is_or_value=*/true, expr->id())); + break; + default: + ABSL_UNREACHABLE(); + } + if (short_circuiting_) { + for (auto& jump : jump_steps_) { + visitor_->SetProgressStatusIfError( + jump.set_target(visitor_->GetCurrentIndex())); + } + } } } -void TernaryCondVisitor::PreVisit(const Expr*) {} +void TernaryCondVisitor::PreVisit(const cel::Expr* expr) { + visitor_->ValidateOrError( + !expr->call_expr().has_target() && expr->call_expr().args().size() == 3, + "Invalid argument count for a ternary function call."); +} -void TernaryCondVisitor::PostVisitArg(int arg_num, const Expr* expr) { +void TernaryCondVisitor::PostVisitArg(int arg_num, const cel::Expr* expr) { + if (visitor_->PlanRecursiveProgram()) { + return; + } // Ternary operator "_?_:_" requires a special handing. // In contrary to regular function call, its execution affects the control // flow of the overall CEL expression. @@ -552,37 +2310,37 @@ void TernaryCondVisitor::PostVisitArg(int arg_num, const Expr* expr) { // condition argument for ternary operator if (arg_num == 0) { // Jump in case of error or non-bool - auto error_jump_status = CreateBoolCheckJumpStep({}, expr->id()); - if (error_jump_status.ok()) { - error_jump_ = - Jump(visitor_->GetCurrentIndex(), error_jump_status.value().get()); + ProgramStepIndex error_jump_pos = visitor_->GetCurrentIndex(); + auto* error_jump = + visitor_->AddStep(CreateBoolCheckJumpStep({}, expr->id())); + if (error_jump) { + error_jump_ = Jump(error_jump_pos, error_jump); } - visitor_->AddStep(std::move(error_jump_status)); // Jump to the second branch of execution // Value is to be removed from the stack. - auto jump_to_second_status = - CreateCondJumpStep(false, false, {}, expr->id()); - if (jump_to_second_status.ok()) { - jump_to_second_ = Jump(visitor_->GetCurrentIndex(), - jump_to_second_status.value().get()); + ProgramStepIndex cond_jump_pos = visitor_->GetCurrentIndex(); + auto* jump_to_second = + visitor_->AddStep(CreateCondJumpStep(false, false, {}, expr->id())); + if (jump_to_second) { + jump_to_second_ = + Jump(cond_jump_pos, static_cast(jump_to_second)); } - visitor_->AddStep(std::move(jump_to_second_status)); } else if (arg_num == 1) { // Jump after the first and over the second branch of execution. // Value is to be removed from the stack. - auto jump_after_first_status = CreateJumpStep({}, expr->id()); - if (jump_after_first_status.ok()) { - jump_after_first_ = Jump(visitor_->GetCurrentIndex(), - jump_after_first_status.value().get()); + ProgramStepIndex jump_pos = visitor_->GetCurrentIndex(); + auto* jump_after_first = visitor_->AddStep(CreateJumpStep({}, expr->id())); + if (!jump_after_first) { + return; } - visitor_->AddStep(std::move(jump_after_first_status)); + jump_after_first_ = Jump(jump_pos, jump_after_first); - if (jump_to_second_.exists()) { - jump_to_second_.set_target(visitor_->GetCurrentIndex()); - } else { - visitor_->SetProgressStatusError(absl::InvalidArgumentError( - "Error configuring ternary operator: jump_to_second_ is null")); + if (visitor_->ValidateOrError( + jump_to_second_.exists(), + "Error configuring ternary operator: jump_to_second_ is null")) { + visitor_->SetProgressStatusIfError( + jump_to_second_.set_target(visitor_->GetCurrentIndex())); } } // Code executed after traversing the final branch of execution @@ -590,167 +2348,293 @@ void TernaryCondVisitor::PostVisitArg(int arg_num, const Expr* expr) { // clattered. } -void TernaryCondVisitor::PostVisit(const Expr*) { - // Determine and set jump offset in jump instruction. - if (error_jump_.exists()) { - error_jump_.set_target(visitor_->GetCurrentIndex()); - } else { - visitor_->SetProgressStatusError(absl::InvalidArgumentError( - "Error configuring ternary operator: error_jump_ is null")); +void TernaryCondVisitor::PostVisit(const cel::Expr* expr) { + if (visitor_->PlanRecursiveProgram()) { + visitor_->MakeTernaryRecursive(expr); return; } - if (jump_after_first_.exists()) { - jump_after_first_.set_target(visitor_->GetCurrentIndex()); - } else { - visitor_->SetProgressStatusError(absl::InvalidArgumentError( - "Error configuring ternary operator: jump_after_first_ is null")); - return; + // Determine and set jump offset in jump instruction. + if (visitor_->ValidateOrError( + error_jump_.exists(), + "Error configuring ternary operator: error_jump_ is null")) { + visitor_->SetProgressStatusIfError( + error_jump_.set_target(visitor_->GetCurrentIndex())); + } + if (visitor_->ValidateOrError( + jump_after_first_.exists(), + "Error configuring ternary operator: jump_after_first_ is null")) { + visitor_->SetProgressStatusIfError( + jump_after_first_.set_target(visitor_->GetCurrentIndex())); } } -void ExhaustiveTernaryCondVisitor::PostVisit(const Expr* expr) { - visitor_->AddStep(CreateTernaryStep(expr->id())); +void ExhaustiveTernaryCondVisitor::PreVisit(const cel::Expr* expr) { + visitor_->ValidateOrError( + !expr->call_expr().has_target() && expr->call_expr().args().size() == 3, + "Invalid argument count for a ternary function call."); } -const Expr* Int64ConstImpl(int64_t value) { - Constant* constant = new Constant; - constant->set_int64_value(value); - Expr* expr = new Expr; - expr->set_allocated_const_expr(constant); - return expr; +void ExhaustiveTernaryCondVisitor::PostVisit(const cel::Expr* expr) { + if (visitor_->PlanRecursiveProgram()) { + visitor_->MakeTernaryRecursive(expr); + return; + } + visitor_->AddStep(CreateTernaryStep(expr->id())); } -const Expr* MinusOne() { - static const Expr* expr = Int64ConstImpl(-1); - return expr; +void ComprehensionVisitor::PreVisit(const cel::Expr* expr) { + if (is_trivial_) { + visitor_->SuppressBranch(&expr->comprehension_expr().iter_range()); + visitor_->SuppressBranch(&expr->comprehension_expr().loop_condition()); + visitor_->SuppressBranch(&expr->comprehension_expr().loop_step()); + } } -const Expr* LoopStepDummy() { - static const Expr* expr = Int64ConstImpl(-10); - return expr; -} +absl::Status ComprehensionVisitor::PostVisitArgDefault( + cel::ComprehensionArg arg_num, const cel::Expr* expr) { + if (visitor_->PlanRecursiveProgram()) { + return absl::OkStatus(); + } + switch (arg_num) { + case cel::ITER_RANGE: { + init_step_pos_ = visitor_->GetCurrentIndex(); + init_step_ = visitor_->AddStep( + std::make_unique(expr->id())); + break; + } + case cel::ACCU_INIT: { + next_step_pos_ = visitor_->GetCurrentIndex(); + next_step_ = visitor_->AddStep(std::make_unique( + iter_slot_, iter2_slot_, accu_slot_, expr->id())); + break; + } + case cel::LOOP_CONDITION: { + cond_step_pos_ = visitor_->GetCurrentIndex(); + cond_step_ = visitor_->AddStep(std::make_unique( + iter_slot_, iter2_slot_, accu_slot_, short_circuiting_, expr->id())); + break; + } + case cel::LOOP_STEP: { + ProgramStepIndex index = visitor_->GetCurrentIndex(); + auto* jump_to_next = visitor_->AddStep(CreateJumpStep({}, expr->id())); + if (!jump_to_next) { + break; + } + Jump jump_helper(index, jump_to_next); + visitor_->SetProgressStatusIfError( + jump_helper.set_target(next_step_pos_)); + + // Set offsets jumping to the result step. + if (cond_step_) { + CEL_ASSIGN_OR_RETURN( + int jump_from_cond, + Jump::CalculateOffset(cond_step_pos_, visitor_->GetCurrentIndex())); + cond_step_->set_jump_offset(jump_from_cond); + } -const Expr* CurrentValueDummy() { - static const Expr* expr = Int64ConstImpl(-20); - return expr; -} + if (next_step_) { + CEL_ASSIGN_OR_RETURN( + int jump_from_next, + Jump::CalculateOffset(next_step_pos_, visitor_->GetCurrentIndex())); -void ComprehensionVisitor::PreVisit(const Expr*) { - const Expr* dummy = LoopStepDummy(); - visitor_->AddStep(CreateConstValueStep( - ConvertConstant(&dummy->const_expr()).value(), dummy->id(), false)); + next_step_->set_jump_offset(jump_from_next); + } + break; + } + case cel::RESULT: { + if (!init_step_ || !next_step_ || !cond_step_) { + // Encountered an error earlier. Can't determine where to jump. + break; + } + visitor_->AddStep(CreateComprehensionFinishStep(accu_slot_, expr->id())); + // Set offsets jumping past the result step in case of errors. + CEL_ASSIGN_OR_RETURN( + int jump_from_init, + Jump::CalculateOffset(init_step_pos_, visitor_->GetCurrentIndex())); + init_step_->set_error_jump_offset(jump_from_init); + + CEL_ASSIGN_OR_RETURN( + int jump_from_next, + Jump::CalculateOffset(next_step_pos_, visitor_->GetCurrentIndex())); + next_step_->set_error_jump_offset(jump_from_next); + + CEL_ASSIGN_OR_RETURN( + int jump_from_cond, + Jump::CalculateOffset(cond_step_pos_, visitor_->GetCurrentIndex())); + cond_step_->set_error_jump_offset(jump_from_cond); + break; + } + } + return absl::OkStatus(); } -void ComprehensionVisitor::PostVisitArg(int arg_num, const Expr* expr) { - const Comprehension* comprehension = &expr->comprehension_expr(); - const auto accu_var = comprehension->accu_var(); - const auto iter_var = comprehension->iter_var(); - // TODO(issues/20): Consider refactoring the comprehension prologue step. +void ComprehensionVisitor::PostVisitArgTrivial(cel::ComprehensionArg arg_num, + const cel::Expr* expr) { + if (visitor_->PlanRecursiveProgram()) { + return; + } switch (arg_num) { - case ITER_RANGE: { - // Post-process iter_range to list its keys if it's a map. - visitor_->AddStep(CreateListKeysStep(expr->id())); - const Expr* minus1 = MinusOne(); - visitor_->AddStep(CreateConstValueStep( - ConvertConstant(&minus1->const_expr()).value(), minus1->id(), false)); - const Expr* dummy = CurrentValueDummy(); - visitor_->AddStep(CreateConstValueStep( - ConvertConstant(&dummy->const_expr()).value(), dummy->id(), false)); + case cel::ITER_RANGE: { break; } - case ACCU_INIT: { - next_step_pos_ = visitor_->GetCurrentIndex(); - next_step_ = new ComprehensionNextStep(accu_var, iter_var, expr->id()); - visitor_->AddStep(std::unique_ptr(next_step_)); + case cel::ACCU_INIT: { + if (!accu_init_extracted_) { + visitor_->AddStep(CreateAssignSlotAndPopStep(accu_slot_)); + } break; } - case LOOP_CONDITION: { - cond_step_pos_ = visitor_->GetCurrentIndex(); - cond_step_ = new ComprehensionCondStep(accu_var, iter_var, - short_circuiting_, expr->id()); - visitor_->AddStep(std::unique_ptr(cond_step_)); + case cel::LOOP_CONDITION: { break; } - case LOOP_STEP: { - auto jump_to_next = CreateJumpStep( - next_step_pos_ - visitor_->GetCurrentIndex() - 1, expr->id()); - if (jump_to_next.ok()) { - visitor_->AddStep(std::move(jump_to_next)); - } - // Set offsets. - cond_step_->set_jump_offset(visitor_->GetCurrentIndex() - cond_step_pos_ - - 1); - next_step_->set_jump_offset(visitor_->GetCurrentIndex() - next_step_pos_ - - 1); + case cel::LOOP_STEP: { break; } - case RESULT: { - visitor_->AddStep(std::unique_ptr( - new ComprehensionFinish(accu_var, iter_var, expr->id()))); - next_step_->set_error_jump_offset(visitor_->GetCurrentIndex() - - next_step_pos_ - 1); - cond_step_->set_error_jump_offset(visitor_->GetCurrentIndex() - - cond_step_pos_ - 1); + case cel::RESULT: { + visitor_->AddStep(CreateClearSlotStep(accu_slot_, expr->id())); break; } } } -void ComprehensionVisitor::PostVisit(const Expr*) {} +void ComprehensionVisitor::PostVisit(const cel::Expr* expr) { + if (is_trivial_) { + visitor_->MaybeMakeBindRecursive(expr, &expr->comprehension_expr(), + accu_slot_); + return; + } + visitor_->MaybeMakeComprehensionRecursive( + expr, &expr->comprehension_expr(), iter_slot_, iter2_slot_, accu_slot_); +} -} // namespace +// Flattens the expression table into the end of the mainline expression vector +// and returns an index to the individual sub expressions. +std::vector FlattenExpressionTable( + ProgramBuilder& program_builder, ExecutionPath& main) { + std::vector> ranges; + main = program_builder.FlattenMain(); + ranges.push_back(std::make_pair(0, main.size())); + + std::vector subexpressions = + program_builder.FlattenSubexpressions(); + for (auto& subexpression : subexpressions) { + ranges.push_back(std::make_pair(main.size(), subexpression.size())); + absl::c_move(subexpression, std::back_inserter(main)); + } -absl::StatusOr> -FlatExprBuilder::CreateExpression(const Expr* expr, - const SourceInfo* source_info, - std::vector* warnings) const { - ExecutionPath execution_path; - BuilderWarnings warnings_builder(fail_on_warnings_); + std::vector subexpression_indexes; + subexpression_indexes.reserve(ranges.size()); + for (const auto& range : ranges) { + subexpression_indexes.push_back( + absl::MakeSpan(main).subspan(range.first, range.second)); + } + return subexpression_indexes; +} - if (absl::StartsWith(container(), ".") || absl::EndsWith(container(), ".")) { +absl::Status CheckAstExtensions( + const std::vector& extensions) { + for (const cel::ExtensionSpec& extension : extensions) { + if (extension.id() == "cel_block" && extension.version().major() == 1) { + // cel_block v1 is always supported. + continue; + } + + // TODO(uncreated-issue/89): Add support for json field names. + return absl::InvalidArgumentError(absl::StrCat( + "unsupported CEL extension: ", extension.id(), "@", + extension.version().major(), ".", extension.version().minor())); + } + return absl::OkStatus(); +} + +} // namespace + +absl::StatusOr FlatExprBuilder::CreateExpressionImpl( + std::unique_ptr ast, std::vector* issues) const { + if (absl::StartsWith(container_, ".") || absl::EndsWith(container_, ".")) { return absl::InvalidArgumentError( - absl::StrCat("Invalid expression container:", container())); + absl::StrCat("Invalid expression container: '", container_, "'")); } - absl::flat_hash_map idents; + RuntimeIssue::Severity max_severity = options_.fail_on_warnings + ? RuntimeIssue::Severity::kWarning + : RuntimeIssue::Severity::kError; + IssueCollector issue_collector(max_severity); + + absl::StatusOr> runtime_extensions = + ExtractAndValidateRuntimeExtensions(*ast); - // transformed expression preserving expression IDs - Expr out; - if (constant_folding_) { - FoldConstants(*expr, *this->GetRegistry(), constant_arena_, idents, &out); + if (!runtime_extensions.ok()) { + CEL_RETURN_IF_ERROR(issue_collector.AddIssue( + RuntimeIssue::CreateError(runtime_extensions.status()))); } - std::set iter_variable_names; - FlatExprVisitor visitor(this->GetRegistry(), &execution_path, - shortcircuiting_, resolvable_enums(), container(), - idents, enable_comprehension_, &warnings_builder, - &iter_variable_names); + auto status = CheckAstExtensions(*runtime_extensions); + if (!status.ok()) { + CEL_RETURN_IF_ERROR( + issue_collector.AddIssue(RuntimeIssue::CreateError(status))); + } + + Resolver resolver(container_, function_registry_, type_registry_, + GetTypeProvider(), + options_.enable_qualified_type_identifiers); + + std::shared_ptr arena; + ProgramBuilder program_builder; + PlannerContext extension_context(env_, resolver, options_, GetTypeProvider(), + issue_collector, program_builder, arena); + + for (const std::unique_ptr& transform : ast_transforms_) { + CEL_RETURN_IF_ERROR(transform->UpdateAst(extension_context, *ast)); + } - AstTraverse(constant_folding_ ? &out : expr, source_info, &visitor); + std::vector> optimizers; + for (const ProgramOptimizerFactory& optimizer_factory : program_optimizers_) { + CEL_ASSIGN_OR_RETURN(auto optimizer, + optimizer_factory(extension_context, *ast)); + if (optimizer != nullptr) { + optimizers.push_back(std::move(optimizer)); + } + } + + // These objects are expected to remain scoped to one build call -- references + // to them shouldn't be persisted in any part of the result expression. + FlatExprVisitor visitor(resolver, options_, std::move(optimizers), + ast->reference_map(), GetTypeProvider(), + issue_collector, program_builder, extension_context, + enable_optional_types_); + + if (options_.max_recursion_depth == -1 || options_.max_recursion_depth > 0) { + int depth_limit = options_.max_recursion_depth == -1 + ? std::numeric_limits::max() + : options_.max_recursion_depth; + visitor.SetMaxRecursionDepth(depth_limit); + } + + cel::TraversalOptions opts; + opts.use_comprehension_callbacks = true; + AstTraverse(ast->root_expr(), visitor, opts); if (!visitor.progress_status().ok()) { return visitor.progress_status(); } - std::unique_ptr expression_impl = - absl::make_unique( - expr, std::move(execution_path), comprehension_max_iterations_, - std::move(iter_variable_names), enable_unknowns_, - enable_unknown_function_results_, enable_missing_attribute_errors_); - - if (warnings != nullptr) { - *warnings = std::move(warnings_builder).warnings(); + if (issues != nullptr) { + (*issues) = issue_collector.ExtractIssues(); } - return std::move(expression_impl); -} -absl::StatusOr> -FlatExprBuilder::CreateExpression(const Expr* expr, - const SourceInfo* source_info) const { - return CreateExpression(expr, source_info, nullptr); + ExecutionPath execution_path; + std::vector subexpressions = + FlattenExpressionTable(program_builder, execution_path); + + return FlatExpression(std::move(execution_path), std::move(subexpressions), + visitor.slot_count(), GetTypeProvider(), options_, + std::move(arena)); +} +const cel::TypeProvider& FlatExprBuilder::GetTypeProvider() const { + return use_legacy_type_provider_ + ? static_cast( + *GetLegacyRuntimeTypeProvider(type_registry_)) + : GetRuntimeTypeProvider(type_registry_); } -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime diff --git a/eval/compiler/flat_expr_builder.h b/eval/compiler/flat_expr_builder.h index a9591b857..aa4d0b4e5 100644 --- a/eval/compiler/flat_expr_builder.h +++ b/eval/compiler/flat_expr_builder.h @@ -1,93 +1,104 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + #ifndef THIRD_PARTY_CEL_CPP_EVAL_COMPILER_FLAT_EXPR_BUILDER_H_ #define THIRD_PARTY_CEL_CPP_EVAL_COMPILER_FLAT_EXPR_BUILDER_H_ -#include "google/api/expr/v1alpha1/syntax.pb.h" -#include "eval/public/cel_expression.h" - -namespace google { -namespace api { -namespace expr { -namespace runtime { +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "base/ast.h" +#include "base/type_provider.h" +#include "eval/compiler/flat_expr_builder_extensions.h" +#include "eval/eval/evaluator_core.h" +#include "runtime/function_registry.h" +#include "runtime/internal/runtime_env.h" +#include "runtime/runtime_issue.h" +#include "runtime/runtime_options.h" +#include "runtime/type_registry.h" + +namespace google::api::expr::runtime { // CelExpressionBuilder implementation. // Builds instances of CelExpressionFlatImpl. -class FlatExprBuilder : public CelExpressionBuilder { +class FlatExprBuilder { public: - FlatExprBuilder() - : enable_unknowns_(false), - enable_unknown_function_results_(false), - enable_missing_attribute_errors_(false), - shortcircuiting_(true), - constant_folding_(false), - constant_arena_(nullptr), - enable_comprehension_(true), - comprehension_max_iterations_(0), - fail_on_warnings_(true) {} - - // set_enable_unknowns controls support for unknowns in expressions created. - void set_enable_unknowns(bool enabled) { enable_unknowns_ = enabled; } - - // set_enable_missing_attribute_errors support for error injection in - // expressions created. - void set_enable_missing_attribute_errors(bool enabled) { - enable_missing_attribute_errors_ = enabled; + FlatExprBuilder( + absl_nonnull std::shared_ptr env, + const cel::RuntimeOptions& options, bool use_legacy_type_provider = false) + : env_(std::move(env)), + options_(options), + container_(options.container), + function_registry_(env_->function_registry), + type_registry_(env_->type_registry), + use_legacy_type_provider_(use_legacy_type_provider) {} + + void AddAstTransform(std::unique_ptr transform) { + ast_transforms_.push_back(std::move(transform)); } - // set_enable_unknown_function_results controls support for unknown function - // results. - void set_enable_unknown_function_results(bool enabled) { - enable_unknown_function_results_ = enabled; + void AddProgramOptimizer(ProgramOptimizerFactory optimizer) { + program_optimizers_.push_back(std::move(optimizer)); } - // set_shortcircuiting regulates shortcircuiting of some expressions. - // Be default shortcircuiting is enabled. - void set_shortcircuiting(bool enabled) { shortcircuiting_ = enabled; } - - // Toggle constant folding optimization. By default it is not enabled. - // The provided arena is used to hold the generated constants. - void set_constant_folding(bool enabled, google::protobuf::Arena* arena) { - constant_folding_ = enabled; - constant_arena_ = arena; + void set_container(std::string container) { + container_ = std::move(container); } - void set_enable_comprehension(bool enabled) { - enable_comprehension_ = enabled; - } + absl::string_view container() const { return container_; } - void set_comprehension_max_iterations(int max_iterations) { - comprehension_max_iterations_ = max_iterations; - } + // TODO(uncreated-issue/45): Add overload for cref AST. At the moment, all the users + // can pass ownership of a freshly converted AST. + absl::StatusOr CreateExpressionImpl( + std::unique_ptr ast, + std::vector* issues) const; - // Warnings (e.g. no function bound) fail immediately. - void set_fail_on_warnings(bool should_fail) { - fail_on_warnings_ = should_fail; - } + const cel::runtime_internal::RuntimeEnv& env() const { return *env_; } + + const cel::RuntimeOptions& options() const { return options_; } - absl::StatusOr> CreateExpression( - const google::api::expr::v1alpha1::Expr* expr, - const google::api::expr::v1alpha1::SourceInfo* source_info) const override; + // Called by `cel::extensions::EnableOptionalTypes` to indicate that special + // `optional_type` handling is needed. + void enable_optional_types() { enable_optional_types_ = true; } - absl::StatusOr> CreateExpression( - const google::api::expr::v1alpha1::Expr* expr, - const google::api::expr::v1alpha1::SourceInfo* source_info, - std::vector* warnings) const override; + bool optional_types_enabled() const { return enable_optional_types_; } private: - bool enable_unknowns_; - bool enable_unknown_function_results_; - bool enable_missing_attribute_errors_; - bool shortcircuiting_; - - bool constant_folding_; - google::protobuf::Arena* constant_arena_; - bool enable_comprehension_; - int comprehension_max_iterations_; - bool fail_on_warnings_; + const cel::TypeProvider& GetTypeProvider() const; + + const absl_nonnull std::shared_ptr + env_; + + cel::RuntimeOptions options_; + std::string container_; + bool enable_optional_types_ = false; + // TODO(uncreated-issue/45): evaluate whether we should use a shared_ptr here to + // allow built expressions to keep the registries alive. + const cel::FunctionRegistry& function_registry_; + const cel::TypeRegistry& type_registry_; + bool use_legacy_type_provider_; + std::vector> ast_transforms_; + std::vector program_optimizers_; }; -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_COMPILER_FLAT_EXPR_BUILDER_H_ diff --git a/eval/compiler/flat_expr_builder_comprehensions_test.cc b/eval/compiler/flat_expr_builder_comprehensions_test.cc index 5aae97f5c..9d46d8dd8 100644 --- a/eval/compiler/flat_expr_builder_comprehensions_test.cc +++ b/eval/compiler/flat_expr_builder_comprehensions_test.cc @@ -1,180 +1,638 @@ -#include "google/api/expr/v1alpha1/syntax.pb.h" +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "cel/expr/syntax.pb.h" #include "google/protobuf/field_mask.pb.h" -#include "google/protobuf/text_format.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" #include "absl/status/status.h" -#include "absl/strings/str_split.h" #include "absl/strings/string_view.h" +#include "eval/compiler/cel_expression_builder_flat_impl.h" +#include "eval/compiler/comprehension_vulnerability_check.h" #include "eval/compiler/flat_expr_builder.h" +#include "eval/public/activation.h" #include "eval/public/builtin_func_registrar.h" #include "eval/public/cel_attribute.h" -#include "eval/public/cel_builtins.h" #include "eval/public/cel_expression.h" #include "eval/public/cel_options.h" #include "eval/public/cel_value.h" -#include "eval/public/unknown_attribute_set.h" -#include "eval/public/unknown_set.h" +#include "eval/public/containers/container_backed_list_impl.h" +#include "eval/public/testing/matchers.h" #include "eval/testutil/test_message.pb.h" -#include "base/status_macros.h" +#include "internal/testing.h" +#include "parser/parser.h" +#include "runtime/internal/runtime_env_testing.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/text_format.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace google::api::expr::runtime { namespace { -using google::api::expr::v1alpha1::Expr; -using google::api::expr::v1alpha1::SourceInfo; - -// [1, 2].filter(x, [3, 4].all(y, x < y)) -const char kNestedComprehension[] = R"pb( - id: 27 - comprehension_expr { - iter_var: "x" - iter_range { - id: 1 - list_expr { - elements { - id: 2 - const_expr { int64_value: 1 } - } - elements { - id: 3 - const_expr { int64_value: 2 } - } - } - } - accu_var: "__result__" - accu_init { - id: 22 - list_expr {} - } - loop_condition { - id: 23 - const_expr { bool_value: true } +using ::absl_testing::StatusIs; +using ::cel::runtime_internal::NewTestingRuntimeEnv; +using ::cel::expr::CheckedExpr; +using ::cel::expr::ParsedExpr; +using ::testing::HasSubstr; + +class CelExpressionBuilderFlatImplComprehensionsTest + : public testing::TestWithParam { + public: + CelExpressionBuilderFlatImplComprehensionsTest() = default; + + bool enable_recursive_planning() { return GetParam(); } + + cel::RuntimeOptions GetRuntimeOptions() { + cel::RuntimeOptions options; + if (enable_recursive_planning()) { + options.max_recursion_depth = -1; } - loop_step { - id: 26 - call_expr { - function: "_?_:_" - args { - id: 20 + options.enable_comprehension_list_append = true; + return options; + } +}; + +TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, NestedComp) { + cel::RuntimeOptions options = GetRuntimeOptions(); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + + ASSERT_OK_AND_ASSIGN(auto parsed_expr, + parser::Parse("[1, 2].filter(x, [3, 4].all(y, x < y))")); + ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder.CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + Activation activation; + google::protobuf::Arena arena; + ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(result.IsList()); + EXPECT_THAT(*result.ListOrDie(), testing::SizeIs(2)); +} + +TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, MapComp) { + cel::RuntimeOptions options = GetRuntimeOptions(); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + + ASSERT_OK_AND_ASSIGN(auto parsed_expr, parser::Parse("[1, 2].map(x, x * 2)")); + ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder.CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + Activation activation; + google::protobuf::Arena arena; + ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(result.IsList()); + EXPECT_THAT(*result.ListOrDie(), testing::SizeIs(2)); + EXPECT_THAT((*result.ListOrDie())[0], + test::EqualsCelValue(CelValue::CreateInt64(2))); + EXPECT_THAT((*result.ListOrDie())[1], + test::EqualsCelValue(CelValue::CreateInt64(4))); +} + +TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, ExistsOneTrue) { + cel::RuntimeOptions options = GetRuntimeOptions(); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + + ASSERT_OK_AND_ASSIGN(auto parsed_expr, + parser::Parse("[7].exists_one(a, a == 7)")); + ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder.CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + Activation activation; + google::protobuf::Arena arena; + ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); + EXPECT_THAT(result, test::IsCelBool(true)); +} + +TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, ExistsOneFalse) { + cel::RuntimeOptions options = GetRuntimeOptions(); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + + ASSERT_OK_AND_ASSIGN(auto parsed_expr, + parser::Parse("[7, 7].exists_one(a, a == 7)")); + ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder.CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + Activation activation; + google::protobuf::Arena arena; + ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); + EXPECT_THAT(result, test::IsCelBool(false)); +} + +TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, ListCompWithUnknowns) { + cel::RuntimeOptions options = GetRuntimeOptions(); + options.unknown_processing = UnknownProcessingOptions::kAttributeAndFunction; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + + ASSERT_OK_AND_ASSIGN(auto parsed_expr, + parser::Parse("items.exists(i, i < 0)")); + ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder.CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + Activation activation; + activation.set_unknown_attribute_patterns({CelAttributePattern{ + "items", + {CreateCelAttributeQualifierPattern(CelValue::CreateInt64(1))}}}); + ContainerBackedListImpl list_impl = ContainerBackedListImpl({ + CelValue::CreateInt64(1), + // element items[1] is marked unknown, so the computation should produce + // and unknown set. + CelValue::CreateInt64(-1), + CelValue::CreateInt64(2), + }); + activation.InsertValue("items", CelValue::CreateList(&list_impl)); + + google::protobuf::Arena arena; + ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(result.IsUnknownSet()) << result.DebugString(); + + const auto& attrs = result.UnknownSetOrDie()->unknown_attributes(); + EXPECT_THAT(attrs, testing::SizeIs(1)); + EXPECT_THAT(attrs.begin()->variable_name(), testing::Eq("items")); + EXPECT_THAT(attrs.begin()->qualifier_path(), testing::SizeIs(1)); + EXPECT_THAT(attrs.begin()->qualifier_path().at(0).GetInt64Key().value(), + testing::Eq(1)); +} + +TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, + InvalidComprehensionWithRewrite) { + CheckedExpr expr; + // The rewrite step which occurs when an identifier gets a more qualified name + // from the reference map has the potential to make invalid comprehensions + // appear valid, by populating missing fields with default values. + // var.(x, ) + google::protobuf::TextFormat::ParseFromString( + R"pb( + reference_map { + key: 1 + value { name: "qualified.var" } + } + expr { comprehension_expr { - iter_var: "y" + iter_var: "x" iter_range { - id: 6 - list_expr { - elements { - id: 7 - const_expr { int64_value: 3 } - } - elements { - id: 8 - const_expr { int64_value: 4 } - } - } + id: 1 + ident_expr { name: "var" } } - accu_var: "__result__" + accu_var: "y" accu_init { - id: 14 + id: 1 const_expr { bool_value: true } } - loop_condition { - id: 16 + } + })pb", + &expr); + cel::RuntimeOptions options = GetRuntimeOptions(); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); + EXPECT_THAT(builder.CreateExpression(&expr).status(), + StatusIs(absl::StatusCode::kInvalidArgument, + testing::AnyOf(HasSubstr("Invalid comprehension"), + HasSubstr("Invalid empty expression")))); +} + +TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, + ComprehensionWithConcatVulernability) { + CheckedExpr expr; + // The comprehension loop step performs an unsafe concatenation of the + // accumulation variable with itself or one of its children. + google::protobuf::TextFormat::ParseFromString( + R"pb( + expr { + comprehension_expr { + iter_var: "x" + iter_range { ident_expr { name: "var" } } + accu_var: "y" + accu_init { list_expr {} } + result { ident_expr { name: "y" } } + loop_condition { const_expr { bool_value: true } } + loop_step { call_expr { - function: "@not_strictly_false" + function: "_?_:_" + args { const_expr { bool_value: true } } + args { ident_expr { name: "y" } } args { - id: 15 - ident_expr { name: "__result__" } + call_expr { + function: "_+_" + args { + call_expr { + function: "dyn" + args { ident_expr { name: "y" } } + } + } + args { + call_expr { + function: "_[_]" + args { ident_expr { name: "y" } } + args { const_expr { int64_value: 0 } } + } + } + } } } } + } + })pb", + &expr); + + cel::RuntimeOptions options = GetRuntimeOptions(); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + builder.flat_expr_builder().AddProgramOptimizer( + CreateComprehensionVulnerabilityCheck()); + ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); + EXPECT_THAT(builder.CreateExpression(&expr).status(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("memory exhaustion vulnerability"))); +} + +TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, + ComprehensionWithListVulernability) { + CheckedExpr expr; + // The comprehension + google::protobuf::TextFormat::ParseFromString( + R"pb( + expr { + comprehension_expr { + iter_var: "x" + iter_range { ident_expr { name: "var" } } + accu_var: "y" + accu_init { list_expr {} } + result { ident_expr { name: "y" } } + loop_condition { const_expr { bool_value: true } } loop_step { - id: 18 - call_expr { - function: "_&&_" - args { - id: 17 - ident_expr { name: "__result__" } + list_expr { + elements { ident_expr { name: "y" } } + elements { + list_expr { + elements { + select_expr { + operand { ident_expr { name: "y" } } + field: "z" + } + } + } } - args { - id: 12 - call_expr { - function: "_<_" - args { - id: 11 - ident_expr { name: "x" } + } + } + } + } + )pb", + &expr); + + cel::RuntimeOptions options = GetRuntimeOptions(); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + builder.flat_expr_builder().AddProgramOptimizer( + CreateComprehensionVulnerabilityCheck()); + ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); + EXPECT_THAT(builder.CreateExpression(&expr).status(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("memory exhaustion vulnerability"))); +} + +TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, + ComprehensionWithStructVulernability) { + CheckedExpr expr; + // The comprehension loop step builds a deeply nested struct which expands + // exponentially. + google::protobuf::TextFormat::ParseFromString( + R"pb( + expr { + comprehension_expr { + iter_var: "x" + iter_range { ident_expr { name: "var" } } + accu_var: "y" + accu_init { list_expr {} } + result { ident_expr { name: "y" } } + loop_condition { const_expr { bool_value: true } } + loop_step { + struct_expr { + entries { + map_key { const_expr { string_value: "key" } } + value { ident_expr { name: "y" } } + } + entries { + map_key { const_expr { string_value: "present" } } + value { + select_expr { + test_only: true + operand { ident_expr { name: "y" } } + field: "z" } - args { - id: 13 - ident_expr { name: "y" } + } + } + entries { + map_key { const_expr { string_value: "key_subset" } } + value { + select_expr { + operand { ident_expr { name: "y" } } + field: "z" } } } } } - result { - id: 19 - ident_expr { name: "__result__" } + } + } + )pb", + &expr); + + cel::RuntimeOptions options = GetRuntimeOptions(); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + builder.flat_expr_builder().AddProgramOptimizer( + CreateComprehensionVulnerabilityCheck()); + ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); + EXPECT_THAT(builder.CreateExpression(&expr).status(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("memory exhaustion vulnerability"))); +} + +TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, + ComprehensionWithNestedComprehensionResultVulernability) { + CheckedExpr expr; + // The nested comprehension performs an unsafe concatenation on the parent + // accumulator variable within its 'result' expression. + // + // The inner-most comprehension shadows its parent, but still refers to its + // oldest ancestor. It, however, does not do anything unsafe. + google::protobuf::TextFormat::ParseFromString( + R"pb( + expr { comprehension_expr { + iter_var: "x" + iter_range { ident_expr { name: "var" } } + accu_var: "y" + accu_init { list_expr {} } + result { ident_expr { name: "y" } } + loop_condition { const_expr { bool_value: true } } + loop_step { + comprehension_expr { + iter_var: "x" + iter_range { ident_expr { name: "y" } } + accu_var: "z" + accu_init { list_expr {} } + result { + call_expr { + function: "_+_" + args { ident_expr { name: "y" } } + args { ident_expr { name: "y" } } + } + } + loop_condition { const_expr { bool_value: true } } + loop_step { + comprehension_expr { + iter_var: "x" + iter_range { ident_expr { name: "y" } } + accu_var: "z" + accu_init { list_expr {} } + result { + call_expr { + function: "dyn" + args { ident_expr { name: "y" } } + } + } + loop_condition { const_expr { bool_value: true } } + loop_step { + call_expr { + function: "dyn" + args { ident_expr { name: "y" } } + } + } + } + } } } } - args { - id: 25 - call_expr { - function: "_+_" - args { - id: 21 - ident_expr { name: "__result__" } + )pb", + &expr); + + cel::RuntimeOptions options = GetRuntimeOptions(); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + builder.flat_expr_builder().AddProgramOptimizer( + CreateComprehensionVulnerabilityCheck()); + ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); + EXPECT_THAT(builder.CreateExpression(&expr).status(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("memory exhaustion vulnerability"))); +} + +TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, + ComprehensionWithNestedComprehensionLoopStepVulernability) { + CheckedExpr expr; + // The nested comprehension performs an unsafe concatenation on the parent + // accumulator variable within its 'loop_step'. + google::protobuf::TextFormat::ParseFromString( + R"pb( + expr { + comprehension_expr { + iter_var: "x" + iter_range { ident_expr { name: "var" } } + accu_var: "y" + accu_init { list_expr {} } + result { ident_expr { name: "y" } } + loop_condition { const_expr { bool_value: true } } + loop_step { + comprehension_expr { + iter_var: "x" + iter_range { ident_expr { name: "y" } } + accu_var: "z" + accu_init { list_expr {} } + result { ident_expr { name: "z" } } + loop_condition { const_expr { bool_value: true } } + loop_step { + call_expr { + function: "_+_" + args { ident_expr { name: "y" } } + args { ident_expr { name: "y" } } + } + } + } } - args { - id: 24 - list_expr { - elements { - id: 5 - ident_expr { name: "x" } + } + } + )pb", + &expr); + + cel::RuntimeOptions options = GetRuntimeOptions(); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + builder.flat_expr_builder().AddProgramOptimizer( + CreateComprehensionVulnerabilityCheck()); + ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); + EXPECT_THAT(builder.CreateExpression(&expr).status(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("memory exhaustion vulnerability"))); +} + +TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, + ComprehensionWithNestedComprehensionLoopStepVulernabilityResult) { + CheckedExpr expr; + // The nested comprehension performs an unsafe concatenation on the parent + // accumulator. + google::protobuf::TextFormat::ParseFromString( + R"pb( + expr { + comprehension_expr { + iter_var: "outer_iter" + iter_range { ident_expr { name: "input_list" } } + accu_var: "outer_accu" + accu_init { ident_expr { name: "input_list" } } + loop_condition { + id: 3 + const_expr { bool_value: true } + } + loop_step { + comprehension_expr { + # the iter_var shadows the outer accumulator on the loop step + # but not the result step. + iter_var: "outer_accu" + iter_range { list_expr {} } + accu_var: "inner_accu" + accu_init { list_expr {} } + loop_condition { const_expr { bool_value: true } } + loop_step { list_expr {} } + result { + call_expr { + function: "_+_" + args { ident_expr { name: "outer_accu" } } + args { ident_expr { name: "outer_accu" } } + } } } } + result { list_expr {} } } } - args { - id: 21 - ident_expr { name: "__result__" } + )pb", + &expr); + + cel::RuntimeOptions options = GetRuntimeOptions(); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + builder.flat_expr_builder().AddProgramOptimizer( + CreateComprehensionVulnerabilityCheck()); + ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); + EXPECT_THAT(builder.CreateExpression(&expr).status(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("memory exhaustion vulnerability"))); +} + +TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, + ComprehensionWithNestedComprehensionLoopStepIterRangeVulnerability) { + CheckedExpr expr; + // The nested comprehension unsafely modifies the parent accumulator + // (outer_accu) being used as a iterable range + google::protobuf::TextFormat::ParseFromString( + R"pb( + expr { + comprehension_expr { + iter_var: "x" + iter_range { ident_expr { name: "input_list" } } + accu_var: "outer_accu" + accu_init { ident_expr { name: "input_list" } } + loop_condition { const_expr { bool_value: true } } + loop_step { + comprehension_expr { + iter_var: "y" + iter_range { ident_expr { name: "outer_accu" } } + accu_var: "inner_accu" + accu_init { ident_expr { name: "outer_accu" } } + loop_condition { const_expr { bool_value: true } } + loop_step { + call_expr { + function: "_+_" + args { ident_expr { name: "inner_accu" } } + args { const_expr { string_value: "12345" } } + } + } + result { ident_expr { name: "inner_accu" } } + } + } + result { ident_expr { name: "outer_accu" } } + } } - } - } - result { - id: 21 - ident_expr { name: "__result__" } - } - })pb"; + )pb", + &expr); -TEST(FlatExprBuilderComprehensionsTest, NestedComp) { - FlatExprBuilder builder; - Expr expr; - ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kNestedComprehension, &expr)); + cel::RuntimeOptions options = GetRuntimeOptions(); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + builder.flat_expr_builder().AddProgramOptimizer( + CreateComprehensionVulnerabilityCheck()); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); - SourceInfo source_info; - auto build_status = builder.CreateExpression(&expr, &source_info); - ASSERT_OK(build_status); + EXPECT_THAT(builder.CreateExpression(&expr).status(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("memory exhaustion vulnerability"))); +} + +TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, + InvalidBindComprehension) { + ParsedExpr expr; + // Trivial comprehensions (such as cel.bind), are optimized by skipping the + // planning for the loop step, however the planner will still warn if the + // loop step references the unused var. + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb( + expr { + comprehension_expr { + iter_var: "#unused" + iter_range { + id: 1 + list_expr {} + } + accu_var: "bind_var" + accu_init { + id: 1 + const_expr { bool_value: true } + } + loop_step { + call_expr { + function: "_&&_" + args { ident_expr { name: "#unused" } } + args { ident_expr { name: "bind_var" } } + } + } + loop_condition { const_expr { bool_value: false } } + result { ident_expr { name: "bind_var" } } + } + })pb", + &expr)); - auto cel_expr = std::move(build_status.value()); + cel::RuntimeOptions options = GetRuntimeOptions(); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); - Activation activation; - google::protobuf::Arena arena; - auto result_or = cel_expr->Evaluate(activation, &arena); - ASSERT_OK(result_or); - CelValue result = result_or.value(); - ASSERT_TRUE(result.IsList()); - EXPECT_THAT(*result.ListOrDie(), testing::SizeIs(2)); + EXPECT_THAT( + builder.CreateExpression(&(expr.expr()), nullptr).status(), + StatusIs( + absl::StatusCode::kInvalidArgument, + HasSubstr("Unexpected iter_var access in trivial comprehension"))); } +INSTANTIATE_TEST_SUITE_P(TestSuite, + CelExpressionBuilderFlatImplComprehensionsTest, + testing::Bool(), + [](const testing::TestParamInfo& info) { + return info.param ? "recursive" : "default"; + }); + } // namespace -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime diff --git a/eval/compiler/flat_expr_builder_extensions.cc b/eval/compiler/flat_expr_builder_extensions.cc new file mode 100644 index 000000000..ee106ff4a --- /dev/null +++ b/eval/compiler/flat_expr_builder_extensions.cc @@ -0,0 +1,474 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "eval/compiler/flat_expr_builder_extensions.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/nullability.h" +#include "absl/base/optimization.h" +#include "absl/log/absl_check.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/optional.h" +#include "absl/types/variant.h" +#include "common/expr.h" +#include "eval/eval/direct_expression_step.h" +#include "eval/eval/evaluator_core.h" + +namespace google::api::expr::runtime { + +namespace { + +using Subexpression = google::api::expr::runtime::ProgramBuilder::Subexpression; + +// Remap a recursive program to its parent if the parent is a transparent +// wrapper. +void MaybeReassignChildRecursiveProgram(Subexpression* parent) { + if (parent->IsFlattened() || parent->IsRecursive()) { + return; + } + if (parent->elements().size() != 1) { + return; + } + auto* child_alternative = + absl::get_if(&parent->elements()[0]); + if (child_alternative == nullptr) { + return; + } + + auto& child_subexpression = *child_alternative; + if (!child_subexpression->IsRecursive()) { + return; + } + + auto child_program = child_subexpression->ExtractRecursiveProgram(); + parent->set_recursive_program(std::move(child_program.step), + child_program.depth); +} + +} // namespace + +Subexpression::Subexpression(const cel::Expr* self, ProgramBuilder* owner) + : self_(self), parent_(nullptr), owner_(owner) {} + +size_t Subexpression::ComputeSize() const { + if (IsFlattened()) { + return flattened_elements().size(); + } else if (IsRecursive()) { + return 1; + } + std::vector to_expand{this}; + size_t size = 0; + while (!to_expand.empty()) { + const auto* expr = to_expand.back(); + to_expand.pop_back(); + if (expr->IsFlattened()) { + size += expr->flattened_elements().size(); + continue; + } else if (expr->IsRecursive()) { + size += 1; + continue; + } + for (const auto& elem : expr->elements()) { + if (auto* child = absl::get_if(&elem); child != nullptr) { + to_expand.push_back(*child); + } else { + size += 1; + } + } + } + return size; +} + +std::optional Subexpression::RecursiveDependencyDepth() const { + auto* tree = absl::get_if(&program_); + int depth = 0; + if (tree == nullptr) { + return std::nullopt; + } + for (const auto& element : *tree) { + auto* subexpression = absl::get_if(&element); + if (subexpression == nullptr) { + return std::nullopt; + } + if (!(*subexpression)->IsRecursive()) { + return std::nullopt; + } + depth = std::max(depth, (*subexpression)->recursive_program().depth); + } + return depth; +} + +std::vector> +Subexpression::ExtractRecursiveDependencies() const { + auto* tree = absl::get_if(&program_); + std::vector> dependencies; + if (tree == nullptr) { + return {}; + } + for (const auto& element : *tree) { + auto* subexpression = absl::get_if(&element); + if (subexpression == nullptr) { + return {}; + } + if (!(*subexpression)->IsRecursive()) { + return {}; + } + dependencies.push_back((*subexpression)->ExtractRecursiveProgram().step); + } + return dependencies; +} + +Subexpression* absl_nullable Subexpression::ExtractChild(Subexpression* child) { + ABSL_DCHECK(child != nullptr); + if (IsFlattened()) { + return nullptr; + } + for (auto iter = elements().begin(); iter != elements().end(); ++iter) { + Subexpression::Element& element = *iter; + if (!absl::holds_alternative(element)) { + continue; + } + Subexpression* candidate = absl::get(element); + if (candidate != child) { + continue; + } + elements().erase(iter); + return candidate; + } + return nullptr; +} + +// Compute the offset for moving the pc from after the base step to before the +// target step. +int Subexpression::CalculateOffset(int base, int target) const { + ABSL_DCHECK(!IsFlattened()); + ABSL_DCHECK(!IsRecursive()); + + int sign = 1; + int start = base + 1; + int end = target; + + if (end <= start) { + // When target is before base we have to consider the size of the base step + // and target (offset is from after base to before target). + start = target; + end = base + 1; + sign = -1; + } + + ABSL_DCHECK_GE(start, 0); + ABSL_DCHECK_GE(end, 0); + ABSL_DCHECK_LE(start, elements().size()); + ABSL_DCHECK_LE(end, elements().size()); + + int sum = 0; + for (int i = start; i < end; ++i) { + const auto& element = elements()[i]; + if (auto* subexpr = absl::get_if(&element); + subexpr != nullptr) { + sum += (*subexpr)->ComputeSize(); + } else { + // Individual step or wrapped recursive program. + sum += 1; + } + } + + return sign * sum; +} + +void Subexpression::Flatten() { + struct Record { + Subexpression* subexpr; + size_t offset; + }; + + if (IsFlattened()) { + return; + } + + std::vector> flat; + + std::vector flatten_stack; + + flatten_stack.push_back({this, 0}); + while (!flatten_stack.empty()) { + Record top = flatten_stack.back(); + flatten_stack.pop_back(); + size_t offset = top.offset; + auto* subexpr = top.subexpr; + if (subexpr->IsFlattened()) { + auto& elements = subexpr->flattened_elements(); + absl::c_move(elements, std::back_inserter(flat)); + elements.clear(); + continue; + } else if (subexpr->IsRecursive()) { + flat.push_back(std::make_unique( + std::move(subexpr->ExtractRecursiveProgram().step), + subexpr->self_->id())); + continue; + } + auto& elements = subexpr->elements(); + size_t size = elements.size(); + size_t i = offset; + for (; i < size; ++i) { + auto& element = elements[i]; + if (auto* child = absl::get_if(&element); + child != nullptr) { + // push resume then child so child elements are processed first. + flatten_stack.push_back({subexpr, i + 1}); + flatten_stack.push_back({*child, 0}); + break; + } else if (auto* step = + absl::get_if>(&element); + step != nullptr) { + flat.push_back(std::move(*step)); + } else { + ABSL_UNREACHABLE(); + } + } + if (i == size) { + elements.clear(); + } + } + program_ = std::move(flat); +} + +Subexpression::RecursiveProgram Subexpression::ExtractRecursiveProgram() { + ABSL_DCHECK(IsRecursive()); + auto result = std::move(absl::get(program_)); + program_.emplace>(); + return result; +} + +bool Subexpression::ExtractTo( + std::vector>& out) { + if (!IsFlattened()) { + return false; + } + + out.reserve(out.size() + flattened_elements().size()); + absl::c_move(flattened_elements(), std::back_inserter(out)); + program_.emplace>(); + + return true; +} + +std::vector> +ProgramBuilder::FlattenSubexpression(Subexpression* expr) { + std::vector> out; + + if (!expr) { + return out; + } + + expr->Flatten(); + expr->ExtractTo(out); + return out; +} + +ProgramBuilder::ProgramBuilder() + : root_(nullptr), current_(nullptr), subprogram_map_() {} + +ExecutionPath ProgramBuilder::FlattenMain() { + auto out = FlattenSubexpression(root_); + root_ = nullptr; + return out; +} + +std::vector ProgramBuilder::FlattenSubexpressions() { + std::vector out; + out.reserve(extracted_subexpressions_.size()); + for (auto& subexpression : extracted_subexpressions_) { + out.push_back(FlattenSubexpression(subexpression)); + } + extracted_subexpressions_.clear(); + return out; +} + +Subexpression* absl_nullable ProgramBuilder::EnterSubexpression( + const cel::Expr* expr, size_t size_hint) { + Subexpression* subexpr = MakeSubexpression(expr); + if (subexpr == nullptr) { + return subexpr; + } + + subexpr->elements().reserve(size_hint); + if (current_ == nullptr) { + root_ = subexpr; + current_ = subexpr; + return subexpr; + } + + current_->AddSubexpression(subexpr); + subexpr->parent_ = current_->self_; + current_ = subexpr; + return subexpr; +} + +Subexpression* absl_nullable ProgramBuilder::ExitSubexpression( + const cel::Expr* expr) { + ABSL_DCHECK(expr == current_->self_); + ABSL_DCHECK(GetSubexpression(expr) == current_); + + MaybeReassignChildRecursiveProgram(current_); + + Subexpression* result = GetSubexpression(current_->parent_); + ABSL_DCHECK(result != nullptr || current_ == root_); + current_ = result; + return result; +} + +Subexpression* absl_nullable ProgramBuilder::GetSubexpression( + const cel::Expr* expr) { + auto it = subprogram_map_.find(expr); + if (it == subprogram_map_.end()) { + return nullptr; + } + + return it->second.get(); +} + +ExpressionStep* absl_nullable ProgramBuilder::AddStep( + std::unique_ptr step) { + if (current_ == nullptr) { + return nullptr; + } + auto* step_ptr = step.get(); + return current_->AddStep(std::move(step)) ? step_ptr : nullptr; +} + +int ProgramBuilder::ExtractSubexpression(const cel::Expr* expr) { + auto it = subprogram_map_.find(expr); + if (it == subprogram_map_.end()) { + return -1; + } + auto* subexpression = it->second.get(); + auto parent_it = subprogram_map_.find(subexpression->parent_); + if (parent_it == subprogram_map_.end()) { + return -1; + } + + auto* parent = parent_it->second.get(); + + auto* child = parent->ExtractChild(subexpression); + + if (child == nullptr) { + return -1; + } + + extracted_subexpressions_.push_back(child); + return extracted_subexpressions_.size() - 1; +} + +Subexpression* absl_nullable ProgramBuilder::MakeSubexpression( + const cel::Expr* expr) { + auto [it, inserted] = subprogram_map_.try_emplace( + expr, absl::WrapUnique(new Subexpression(expr, this))); + if (!inserted) { + return nullptr; + } + + return it->second.get(); +} + +bool PlannerContext::IsSubplanInspectable(const cel::Expr& node) const { + return program_builder_.GetSubexpression(&node) != nullptr; +} + +ExecutionPathView PlannerContext::GetSubplan(const cel::Expr& node) { + auto* subexpression = program_builder_.GetSubexpression(&node); + if (subexpression == nullptr) { + return ExecutionPathView(); + } + subexpression->Flatten(); + return subexpression->flattened_elements(); +} + +absl::StatusOr PlannerContext::ExtractSubplan( + const cel::Expr& node) { + auto* subexpression = program_builder_.GetSubexpression(&node); + if (subexpression == nullptr) { + return absl::InternalError( + "attempted to update program step for untracked expr node"); + } + + subexpression->Flatten(); + + ExecutionPath out; + subexpression->ExtractTo(out); + + return out; +} + +absl::Status PlannerContext::ReplaceSubplan(const cel::Expr& node, + ExecutionPath path) { + auto* subexpression = program_builder_.GetSubexpression(&node); + if (subexpression == nullptr) { + return absl::InternalError( + "attempted to update program step for untracked expr node"); + } + + // Make sure structure for descendents is erased. + if (!subexpression->IsFlattened()) { + subexpression->Flatten(); + } + + subexpression->flattened_elements() = std::move(path); + + return absl::OkStatus(); +} + +void ProgramBuilder::Reset() { + root_ = nullptr; + current_ = nullptr; + extracted_subexpressions_.clear(); + subprogram_map_.clear(); +} + +absl::Status PlannerContext::ReplaceSubplan( + const cel::Expr& node, std::unique_ptr step, + int depth) { + auto* subexpression = program_builder_.GetSubexpression(&node); + if (subexpression == nullptr) { + return absl::InternalError( + "attempted to update program step for untracked expr node"); + } + + subexpression->set_recursive_program(std::move(step), depth); + return absl::OkStatus(); +} + +absl::Status PlannerContext::AddSubplanStep( + const cel::Expr& node, std::unique_ptr step) { + auto* subexpression = program_builder_.GetSubexpression(&node); + + if (subexpression == nullptr) { + return absl::InternalError( + "attempted to update program step for untracked expr node"); + } + + subexpression->AddStep(std::move(step)); + + return absl::OkStatus(); +} + +} // namespace google::api::expr::runtime diff --git a/eval/compiler/flat_expr_builder_extensions.h b/eval/compiler/flat_expr_builder_extensions.h new file mode 100644 index 000000000..21e37b2a8 --- /dev/null +++ b/eval/compiler/flat_expr_builder_extensions.h @@ -0,0 +1,481 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// API definitions for planner extensions. +// +// These are provided to indirect build dependencies for optional features and +// require detailed understanding of how the flat expression builder works and +// its assumptions. +// +// These interfaces should not be implemented directly by CEL users. +#ifndef THIRD_PARTY_CEL_CPP_EVAL_COMPILER_FLAT_EXPR_BUILDER_EXTENSIONS_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_COMPILER_FLAT_EXPR_BUILDER_EXTENSIONS_H_ + +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" +#include "absl/functional/any_invocable.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/optional.h" +#include "absl/types/variant.h" +#include "base/ast.h" +#include "base/type_provider.h" +#include "common/expr.h" +#include "common/native_type.h" +#include "common/type_reflector.h" +#include "eval/compiler/resolver.h" +#include "eval/eval/direct_expression_step.h" +#include "eval/eval/evaluator_core.h" +#include "eval/eval/trace_step.h" +#include "internal/casts.h" +#include "runtime/internal/issue_collector.h" +#include "runtime/internal/runtime_env.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace google::api::expr::runtime { + +// Class representing a CEL program being built. +// +// Maintains tree structure and mapping from the AST representation to +// subexpressions. Maintains an insertion point for new steps and +// subexpressions. +// +// This class is thread-hostile and not intended for direct access outside of +// the Expression builder. Extensions should interact with this through the +// the PlannerContext member functions. +class ProgramBuilder { + public: + class Subexpression; + + private: + using SubprogramMap = + absl::flat_hash_map>; + + public: + // Represents a subexpression. + // + // Steps apply operations on the stack machine for the C++ runtime. + // For most expression types, this maps to a post order traversal -- for all + // nodes, evaluate dependencies (pushing their results to stack) then evaluate + // self. + // + // Must be tied to a ProgramBuilder to coordinate relationships. + class Subexpression { + private: + using Element = absl::variant, + Subexpression* absl_nonnull>; + + using TreePlan = std::vector; + using FlattenedPlan = std::vector>; + + public: + struct RecursiveProgram { + std::unique_ptr step; + int depth; + }; + + ~Subexpression() = default; + + // Not copyable or movable. + Subexpression(const Subexpression&) = delete; + Subexpression& operator=(const Subexpression&) = delete; + Subexpression(Subexpression&&) = delete; + Subexpression& operator=(Subexpression&&) = delete; + + // Add a program step at the current end of the subexpression. + bool AddStep(std::unique_ptr step) { + if (IsRecursive()) { + return false; + } + + if (IsFlattened()) { + flattened_elements().push_back(std::move(step)); + return true; + } + + elements().push_back({std::move(step)}); + return true; + } + + void AddSubexpression(Subexpression* absl_nonnull expr) { + ABSL_DCHECK(absl::holds_alternative(program_)); + ABSL_DCHECK(owner_ == expr->owner_); + elements().push_back(expr); + } + + // Accessor for elements (either simple steps or subexpressions). + // + // Value is undefined if in the expression has already been flattened. + std::vector& elements() { + ABSL_DCHECK(absl::holds_alternative(program_)); + return absl::get(program_); + } + + const std::vector& elements() const { + ABSL_DCHECK(absl::holds_alternative(program_)); + return absl::get(program_); + } + + // Accessor for program steps. + // + // Value is undefined if in the expression has not yet been flattened. + std::vector>& flattened_elements() { + ABSL_DCHECK(IsFlattened()); + return absl::get(program_); + } + + const std::vector>& + flattened_elements() const { + ABSL_DCHECK(IsFlattened()); + return absl::get(program_); + } + + void set_recursive_program(std::unique_ptr step, + int depth) { + program_ = RecursiveProgram{std::move(step), depth}; + } + + const RecursiveProgram& recursive_program() const { + ABSL_DCHECK(IsRecursive()); + return absl::get(program_); + } + + absl::optional RecursiveDependencyDepth() const; + + std::vector> + ExtractRecursiveDependencies() const; + + RecursiveProgram ExtractRecursiveProgram(); + + bool IsRecursive() const { + return absl::holds_alternative(program_); + } + + // Compute the current number of program steps in this subexpression and + // its dependencies. + size_t ComputeSize() const; + + // Calculate the number of steps from the end of base to before target, + // (including negative offsets). + int CalculateOffset(int base, int target) const; + + // Extract a child subexpression. + // + // The expression is removed from the elements array. + // + // Returns nullptr if child is not an element of this subexpression. + Subexpression* absl_nullable ExtractChild(Subexpression* child); + + // Flatten the subexpression. + // + // This removes the structure tracking for subexpressions, but makes the + // subprogram evaluable on the runtime's stack machine. + void Flatten(); + + bool IsFlattened() const { + return absl::holds_alternative(program_); + } + + // Extract a flattened subexpression into the given vector. Transferring + // ownership of the given steps. + // + // Returns false if the subexpression is not currently flattened. + bool ExtractTo(std::vector>& out); + + private: + Subexpression(const cel::Expr* self, ProgramBuilder* owner); + + friend class ProgramBuilder; + + // Some extensions expect the program plan to be contiguous mid-planning. + // + // This adds complexity, but supports swapping to a flat representation as + // needed. + absl::variant program_; + + const cel::Expr* self_; + const cel::Expr* absl_nullable parent_; + ProgramBuilder* owner_; + }; + + ProgramBuilder(); + + // Flatten the main subexpression and return its value. + // + // This transfers ownership of the program, returning the builder to starting + // state. (See FlattenSubexpressions). + ExecutionPath FlattenMain(); + + // Flatten extracted subprograms. + // + // This transfers ownership of the subprograms, returning the extracted + // programs table to starting state. + std::vector FlattenSubexpressions(); + + // Returns the current subexpression where steps and new subexpressions are + // added. + // + // May return null if the builder is not currently planning an expression. + Subexpression* absl_nullable current() { return current_; } + + // Enter a subexpression context. + // + // Adds a subexpression at the current insertion point and move insertion + // to the subexpression. + // + // Returns the new current() value. + // + // May return nullptr if the expression is already indexed in the program + // builder. + Subexpression* absl_nullable EnterSubexpression(const cel::Expr* expr, + size_t size_hint = 0); + + // Exit a subexpression context. + // + // Sets insertion point to parent. + // + // Returns the new current() value or nullptr if called out of order. + Subexpression* absl_nullable ExitSubexpression(const cel::Expr* expr); + + // Return the subexpression mapped to the given expression. + // + // Returns nullptr if the mapping doesn't exist either due to the + // program being overwritten or not encountering the expression. + Subexpression* absl_nullable GetSubexpression(const cel::Expr* expr); + + // Return the extracted subexpression mapped to the given index. + // + // Returns nullptr if the mapping doesn't exist + Subexpression* absl_nullable GetExtractedSubexpression(size_t index) { + if (index >= extracted_subexpressions_.size()) { + return nullptr; + } + + return extracted_subexpressions_[index]; + } + + // Return index to the extracted subexpression. + // + // Returns -1 if the subexpression is not found. + int ExtractSubexpression(const cel::Expr* expr); + + // Add a program step to the current subexpression. + // If successful, returns the step pointer. + // + // Note: If successful, the pointer should remain valid until the parent + // expression is finalized. Optimizers may modify the program plan which may + // free the step at that point. + ExpressionStep* absl_nullable AddStep(std::unique_ptr step); + + void Reset(); + + private: + static std::vector> + FlattenSubexpression(Subexpression* absl_nonnull expr); + + Subexpression* absl_nullable MakeSubexpression(const cel::Expr* expr); + + Subexpression* absl_nullable root_; + std::vector extracted_subexpressions_; + Subexpression* absl_nullable current_; + SubprogramMap subprogram_map_; +}; + +// Attempt to downcast a specific type of recursive step. +template +const Subclass* TryDowncastDirectStep(const DirectExpressionStep* step) { + if (step == nullptr) { + return nullptr; + } + + auto type_id = step->GetNativeTypeId(); + if (type_id == cel::NativeTypeId::For()) { + const auto* trace_step = cel::internal::down_cast(step); + auto deps = trace_step->GetDependencies(); + if (!deps.has_value() || deps->size() != 1) { + return nullptr; + } + step = deps->at(0); + type_id = step->GetNativeTypeId(); + } + + if (type_id == cel::NativeTypeId::For()) { + return cel::internal::down_cast(step); + } + + return nullptr; +} + +// Class representing FlatExpr internals exposed to extensions. +class PlannerContext { + public: + PlannerContext( + std::shared_ptr environment, + const Resolver& resolver, const cel::RuntimeOptions& options, + const cel::TypeReflector& type_reflector, + cel::runtime_internal::IssueCollector& issue_collector, + ProgramBuilder& program_builder, + std::shared_ptr& arena ABSL_ATTRIBUTE_LIFETIME_BOUND, + std::shared_ptr message_factory = nullptr) + : environment_(std::move(environment)), + resolver_(resolver), + type_reflector_(type_reflector), + options_(options), + issue_collector_(issue_collector), + program_builder_(program_builder), + arena_(arena), + explicit_arena_(arena_ != nullptr), + message_factory_(std::move(message_factory)) {} + + ProgramBuilder& program_builder() { return program_builder_; } + + // Returns true if the subplan is inspectable. + // + // If false, the node is not mapped to a subexpression in the program builder. + bool IsSubplanInspectable(const cel::Expr& node) const; + + // Return a view to the current subplan representing node. + // + // Note: this is invalidated after a sibling or parent is updated. + // + // This operation forces the subexpression to flatten which removes the + // expr->program mapping for any descendants. + ExecutionPathView GetSubplan(const cel::Expr& node); + + // Extract the plan steps for the given expr. + // + // After successful extraction, the subexpression is still inspectable, but + // empty. + absl::StatusOr ExtractSubplan(const cel::Expr& node); + + // Replace the subplan associated with node with a new subplan. + // + // This operation forces the subexpression to flatten which removes the + // expr->program mapping for any descendants. + absl::Status ReplaceSubplan(const cel::Expr& node, ExecutionPath path); + + // Replace the subplan associated with node with a new recursive subplan. + // + // This operation clears any existing plan to which removes the + // expr->program mapping for any descendants. + absl::Status ReplaceSubplan(const cel::Expr& node, + std::unique_ptr step, + int depth); + + // Extend the current subplan with the given expression step. + absl::Status AddSubplanStep(const cel::Expr& node, + std::unique_ptr step); + + const Resolver& resolver() const { return resolver_; } + const cel::TypeReflector& type_reflector() const { return type_reflector_; } + const cel::RuntimeOptions& options() const { return options_; } + cel::runtime_internal::IssueCollector& issue_collector() { + return issue_collector_; + } + + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool() const { + return environment_->descriptor_pool.get(); + } + + // Returns `true` if an arena was explicitly provided during planning. + bool HasExplicitArena() const { return explicit_arena_; } + + google::protobuf::Arena* absl_nonnull MutableArena() { + if (!explicit_arena_ && arena_ == nullptr) { + arena_ = std::make_shared(); + } + ABSL_DCHECK(arena_ != nullptr); + return arena_.get(); + } + + // Returns `true` if a message factory was explicitly provided during + // planning. + bool HasExplicitMessageFactory() const { return message_factory_ != nullptr; } + + google::protobuf::MessageFactory* absl_nonnull MutableMessageFactory() { + return HasExplicitMessageFactory() ? message_factory_.get() + : environment_->MutableMessageFactory(); + } + + private: + const std::shared_ptr environment_; + const Resolver& resolver_; + const cel::TypeReflector& type_reflector_; + const cel::RuntimeOptions& options_; + cel::runtime_internal::IssueCollector& issue_collector_; + ProgramBuilder& program_builder_; + std::shared_ptr& arena_; + const bool explicit_arena_; + const std::shared_ptr message_factory_; +}; + +// Interface for Ast Transforms. +// If any are present, the FlatExprBuilder will apply the Ast Transforms in +// order on a copy of the relevant input expressions before planning the +// program. +class AstTransform { + public: + virtual ~AstTransform() = default; + + virtual absl::Status UpdateAst(PlannerContext& context, + cel::Ast& ast) const = 0; +}; + +// Interface for program optimizers. +// +// If any are present, the FlatExprBuilder will notify the implementations in +// order as it traverses the input ast. +// +// Note: implementations must correctly check that subprograms are available +// before accessing (i.e. they have not already been edited). +class ProgramOptimizer { + public: + virtual ~ProgramOptimizer() = default; + + // Called before planning the given expr node. + virtual absl::Status OnPreVisit(PlannerContext& context, + const cel::Expr& node) = 0; + + // Called after planning the given expr node. + virtual absl::Status OnPostVisit(PlannerContext& context, + const cel::Expr& node) = 0; +}; + +// Type definition for ProgramOptimizer factories. +// +// The expression builder must remain thread compatible, but ProgramOptimizers +// are often stateful for a given expression. To avoid requiring the optimizer +// implementation to handle concurrent planning, the builder creates a new +// instance per expression planned. +// +// The factory must be thread safe, but the returned instance may assume +// it is called from a synchronous context. +using ProgramOptimizerFactory = + absl::AnyInvocable>( + PlannerContext&, const cel::Ast&) const>; + +} // namespace google::api::expr::runtime + +#endif // THIRD_PARTY_CEL_CPP_EVAL_COMPILER_FLAT_EXPR_BUILDER_EXTENSIONS_H_ diff --git a/eval/compiler/flat_expr_builder_extensions_test.cc b/eval/compiler/flat_expr_builder_extensions_test.cc new file mode 100644 index 000000000..45913e61b --- /dev/null +++ b/eval/compiler/flat_expr_builder_extensions_test.cc @@ -0,0 +1,571 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "eval/compiler/flat_expr_builder_extensions.h" + +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "common/expr.h" +#include "common/native_type.h" +#include "common/value.h" +#include "eval/compiler/resolver.h" +#include "eval/eval/const_value_step.h" +#include "eval/eval/direct_expression_step.h" +#include "eval/eval/evaluator_core.h" +#include "eval/eval/function_step.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "runtime/function_registry.h" +#include "runtime/internal/issue_collector.h" +#include "runtime/internal/runtime_env.h" +#include "runtime/internal/runtime_env_testing.h" +#include "runtime/runtime_issue.h" +#include "runtime/runtime_options.h" +#include "runtime/type_registry.h" +#include "google/protobuf/arena.h" + +namespace google::api::expr::runtime { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::IsOkAndHolds; +using ::absl_testing::StatusIs; +using ::cel::Expr; +using ::cel::RuntimeIssue; +using ::cel::runtime_internal::IssueCollector; +using ::cel::runtime_internal::NewTestingRuntimeEnv; +using ::cel::runtime_internal::RuntimeEnv; +using ::testing::ElementsAre; +using ::testing::IsEmpty; +using ::testing::Optional; + +using Subexpression = ProgramBuilder::Subexpression; + +class PlannerContextTest : public testing::Test { + public: + PlannerContextTest() + : env_(NewTestingRuntimeEnv()), + type_registry_(env_->type_registry), + function_registry_(env_->function_registry), + resolver_("", function_registry_, type_registry_, + type_registry_.GetComposedTypeProvider()), + issue_collector_(RuntimeIssue::Severity::kError) {} + + protected: + absl_nonnull std::shared_ptr env_; + cel::TypeRegistry& type_registry_; + cel::FunctionRegistry& function_registry_; + cel::RuntimeOptions options_; + Resolver resolver_; + IssueCollector issue_collector_; +}; + +MATCHER_P(UniquePtrHolds, ptr, "") { + const auto& got = arg; + return ptr == got.get(); +} + +struct SimpleTreeSteps { + const ExpressionStep* a; + const ExpressionStep* b; + const ExpressionStep* c; +}; + +// simulate a program of: +// a +// / \ +// b c +absl::StatusOr InitSimpleTree( + const Expr& a, const Expr& b, const Expr& c, + ProgramBuilder& program_builder) { + CEL_ASSIGN_OR_RETURN(auto a_step, CreateConstValueStep(cel::NullValue(), -1)); + CEL_ASSIGN_OR_RETURN(auto b_step, CreateConstValueStep(cel::NullValue(), -1)); + CEL_ASSIGN_OR_RETURN(auto c_step, CreateConstValueStep(cel::NullValue(), -1)); + + SimpleTreeSteps result{a_step.get(), b_step.get(), c_step.get()}; + + program_builder.EnterSubexpression(&a); + program_builder.EnterSubexpression(&b); + program_builder.AddStep(std::move(b_step)); + program_builder.ExitSubexpression(&b); + program_builder.EnterSubexpression(&c); + program_builder.AddStep(std::move(c_step)); + program_builder.ExitSubexpression(&c); + program_builder.AddStep(std::move(a_step)); + program_builder.ExitSubexpression(&a); + + return result; +} + +TEST_F(PlannerContextTest, GetPlan) { + Expr a; + Expr b; + Expr c; + ProgramBuilder program_builder; + + ASSERT_OK_AND_ASSIGN(auto step_ptrs, + InitSimpleTree(a, b, c, program_builder)); + + std::shared_ptr arena; + PlannerContext context(env_, resolver_, options_, + type_registry_.GetComposedTypeProvider(), + issue_collector_, program_builder, arena); + + EXPECT_THAT(context.GetSubplan(b), ElementsAre(UniquePtrHolds(step_ptrs.b))); + + EXPECT_THAT(context.GetSubplan(c), ElementsAre(UniquePtrHolds(step_ptrs.c))); + + EXPECT_THAT(context.GetSubplan(a), ElementsAre(UniquePtrHolds(step_ptrs.b), + UniquePtrHolds(step_ptrs.c), + UniquePtrHolds(step_ptrs.a))); + + Expr d; + EXPECT_FALSE(context.IsSubplanInspectable(d)); + EXPECT_THAT(context.GetSubplan(d), IsEmpty()); +} + +TEST_F(PlannerContextTest, ReplacePlan) { + Expr a; + Expr b; + Expr c; + ProgramBuilder program_builder; + + ASSERT_OK_AND_ASSIGN(auto step_ptrs, + InitSimpleTree(a, b, c, program_builder)); + + std::shared_ptr arena; + PlannerContext context(env_, resolver_, options_, + type_registry_.GetComposedTypeProvider(), + issue_collector_, program_builder, arena); + + EXPECT_THAT(context.GetSubplan(a), ElementsAre(UniquePtrHolds(step_ptrs.b), + UniquePtrHolds(step_ptrs.c), + UniquePtrHolds(step_ptrs.a))); + + ExecutionPath new_a; + + ASSERT_OK_AND_ASSIGN(auto new_a_step, + CreateConstValueStep(cel::NullValue(), -1)); + const ExpressionStep* new_a_step_ptr = new_a_step.get(); + new_a.push_back(std::move(new_a_step)); + + ASSERT_THAT(context.ReplaceSubplan(a, std::move(new_a)), IsOk()); + + EXPECT_THAT(context.GetSubplan(a), + ElementsAre(UniquePtrHolds(new_a_step_ptr))); + EXPECT_THAT(context.GetSubplan(b), IsEmpty()); +} + +TEST_F(PlannerContextTest, ExtractPlan) { + Expr a; + Expr b; + Expr c; + ProgramBuilder program_builder; + + ASSERT_OK_AND_ASSIGN(auto plan_steps, + InitSimpleTree(a, b, c, program_builder)); + + std::shared_ptr arena; + PlannerContext context(env_, resolver_, options_, + type_registry_.GetComposedTypeProvider(), + issue_collector_, program_builder, arena); + + EXPECT_TRUE(context.IsSubplanInspectable(a)); + EXPECT_TRUE(context.IsSubplanInspectable(b)); + + ASSERT_OK_AND_ASSIGN(ExecutionPath extracted, context.ExtractSubplan(b)); + + EXPECT_THAT(extracted, ElementsAre(UniquePtrHolds(plan_steps.b))); +} + +TEST_F(PlannerContextTest, ExtractFailsOnReplacedNode) { + Expr a; + Expr b; + Expr c; + ProgramBuilder program_builder; + + ASSERT_THAT(InitSimpleTree(a, b, c, program_builder).status(), IsOk()); + + std::shared_ptr arena; + PlannerContext context(env_, resolver_, options_, + type_registry_.GetComposedTypeProvider(), + issue_collector_, program_builder, arena); + + ASSERT_THAT(context.ReplaceSubplan(a, {}), IsOk()); + + EXPECT_THAT(context.ExtractSubplan(b), IsOkAndHolds(IsEmpty())); +} + +TEST_F(PlannerContextTest, ReplacePlanUpdatesParent) { + Expr a; + Expr b; + Expr c; + ProgramBuilder program_builder; + + ASSERT_OK_AND_ASSIGN(auto plan_steps, + InitSimpleTree(a, b, c, program_builder)); + + std::shared_ptr arena; + PlannerContext context(env_, resolver_, options_, + type_registry_.GetComposedTypeProvider(), + issue_collector_, program_builder, arena); + + EXPECT_TRUE(context.IsSubplanInspectable(a)); + + ASSERT_THAT(context.ReplaceSubplan(c, {}), IsOk()); + + EXPECT_THAT(context.GetSubplan(a), ElementsAre(UniquePtrHolds(plan_steps.b), + UniquePtrHolds(plan_steps.a))); + EXPECT_THAT(context.GetSubplan(c), IsEmpty()); +} + +TEST_F(PlannerContextTest, ReplacePlanUpdatesSibling) { + Expr a; + Expr b; + Expr c; + ProgramBuilder program_builder; + + ASSERT_OK_AND_ASSIGN(auto plan_steps, + InitSimpleTree(a, b, c, program_builder)); + + std::shared_ptr arena; + PlannerContext context(env_, resolver_, options_, + type_registry_.GetComposedTypeProvider(), + issue_collector_, program_builder, arena); + + ExecutionPath new_b; + + ASSERT_OK_AND_ASSIGN(auto b1_step, + CreateConstValueStep(cel::NullValue(), -1)); + const ExpressionStep* b1_step_ptr = b1_step.get(); + new_b.push_back(std::move(b1_step)); + ASSERT_OK_AND_ASSIGN(auto b2_step, + CreateConstValueStep(cel::NullValue(), -1)); + const ExpressionStep* b2_step_ptr = b2_step.get(); + new_b.push_back(std::move(b2_step)); + + ASSERT_THAT(context.ReplaceSubplan(b, std::move(new_b)), IsOk()); + + EXPECT_THAT(context.GetSubplan(c), ElementsAre(UniquePtrHolds(plan_steps.c))); + EXPECT_THAT(context.GetSubplan(b), ElementsAre(UniquePtrHolds(b1_step_ptr), + UniquePtrHolds(b2_step_ptr))); + EXPECT_THAT( + context.GetSubplan(a), + ElementsAre(UniquePtrHolds(b1_step_ptr), UniquePtrHolds(b2_step_ptr), + UniquePtrHolds(plan_steps.c), UniquePtrHolds(plan_steps.a))); +} + +TEST_F(PlannerContextTest, ReplacePlanFailsOnUpdatedNode) { + Expr a; + Expr b; + Expr c; + ProgramBuilder program_builder; + + ASSERT_OK_AND_ASSIGN(auto plan_steps, + InitSimpleTree(a, b, c, program_builder)); + + std::shared_ptr arena; + PlannerContext context(env_, resolver_, options_, + type_registry_.GetComposedTypeProvider(), + issue_collector_, program_builder, arena); + + EXPECT_THAT(context.GetSubplan(a), ElementsAre(UniquePtrHolds(plan_steps.b), + UniquePtrHolds(plan_steps.c), + UniquePtrHolds(plan_steps.a))); + + ASSERT_THAT(context.ReplaceSubplan(a, {}), IsOk()); + EXPECT_THAT(context.ReplaceSubplan(b, {}), IsOk()); +} + +TEST_F(PlannerContextTest, AddSubplanStep) { + Expr a; + Expr b; + Expr c; + ProgramBuilder program_builder; + + ASSERT_OK_AND_ASSIGN(auto plan_steps, + InitSimpleTree(a, b, c, program_builder)); + + ASSERT_OK_AND_ASSIGN(auto b2_step, + CreateConstValueStep(cel::NullValue(), -1)); + + const ExpressionStep* b2_step_ptr = b2_step.get(); + + std::shared_ptr arena; + PlannerContext context(env_, resolver_, options_, + type_registry_.GetComposedTypeProvider(), + issue_collector_, program_builder, arena); + + ASSERT_THAT(context.AddSubplanStep(b, std::move(b2_step)), IsOk()); + + EXPECT_THAT(context.GetSubplan(b), ElementsAre(UniquePtrHolds(plan_steps.b), + UniquePtrHolds(b2_step_ptr))); + EXPECT_THAT(context.GetSubplan(c), ElementsAre(UniquePtrHolds(plan_steps.c))); + EXPECT_THAT( + context.GetSubplan(a), + ElementsAre(UniquePtrHolds(plan_steps.b), UniquePtrHolds(b2_step_ptr), + UniquePtrHolds(plan_steps.c), UniquePtrHolds(plan_steps.a))); +} + +TEST_F(PlannerContextTest, AddSubplanStepFailsOnUnknownNode) { + Expr a; + Expr b; + Expr c; + Expr d; + ProgramBuilder program_builder; + + ASSERT_THAT(InitSimpleTree(a, b, c, program_builder).status(), IsOk()); + + ASSERT_OK_AND_ASSIGN(auto b2_step, + CreateConstValueStep(cel::NullValue(), -1)); + + std::shared_ptr arena; + PlannerContext context(env_, resolver_, options_, + type_registry_.GetComposedTypeProvider(), + issue_collector_, program_builder, arena); + + EXPECT_THAT(context.GetSubplan(d), IsEmpty()); + + EXPECT_THAT(context.AddSubplanStep(d, std::move(b2_step)), + StatusIs(absl::StatusCode::kInternal)); +} + +class ProgramBuilderTest : public testing::Test { + public: + ProgramBuilderTest() : type_registry_(), function_registry_() {} + + protected: + cel::TypeRegistry type_registry_; + cel::FunctionRegistry function_registry_; +}; + +TEST_F(ProgramBuilderTest, ExtractSubexpression) { + Expr a; + Expr b; + Expr c; + ProgramBuilder program_builder; + + ASSERT_OK_AND_ASSIGN(SimpleTreeSteps step_ptrs, + InitSimpleTree(a, b, c, program_builder)); + EXPECT_EQ(program_builder.ExtractSubexpression(&c), 0); + EXPECT_EQ(program_builder.ExtractSubexpression(&b), 1); + + EXPECT_THAT(program_builder.FlattenMain(), + ElementsAre(UniquePtrHolds(step_ptrs.a))); + EXPECT_THAT(program_builder.FlattenSubexpressions(), + ElementsAre(ElementsAre(UniquePtrHolds(step_ptrs.c)), + ElementsAre(UniquePtrHolds(step_ptrs.b)))); +} + +TEST_F(ProgramBuilderTest, FlattenRemovesChildrenReferences) { + Expr a; + Expr b; + Expr c; + ProgramBuilder program_builder; + + program_builder.EnterSubexpression(&a); + program_builder.EnterSubexpression(&b); + program_builder.EnterSubexpression(&c); + program_builder.ExitSubexpression(&c); + program_builder.ExitSubexpression(&b); + program_builder.ExitSubexpression(&a); + + auto subexpr_b = program_builder.GetSubexpression(&b); + ASSERT_TRUE(subexpr_b != nullptr); + subexpr_b->Flatten(); + + auto* subexpr_c = program_builder.GetSubexpression(&c); + EXPECT_EQ(subexpr_b->ExtractChild(subexpr_c), nullptr); +} + +TEST_F(ProgramBuilderTest, ExtractReturnsNullOnFlattendExpr) { + Expr a; + Expr b; + ProgramBuilder program_builder; + + program_builder.EnterSubexpression(&a); + program_builder.EnterSubexpression(&b); + program_builder.ExitSubexpression(&b); + program_builder.ExitSubexpression(&a); + + auto* subexpr_a = program_builder.GetSubexpression(&a); + auto* subexpr_b = program_builder.GetSubexpression(&b); + + ASSERT_TRUE(subexpr_a != nullptr); + ASSERT_TRUE(subexpr_b != nullptr); + + subexpr_a->Flatten(); + // subexpr_b is now freed. + + EXPECT_EQ(subexpr_a->ExtractChild(subexpr_b), nullptr); + EXPECT_EQ(program_builder.ExtractSubexpression(&b), -1); +} + +TEST_F(ProgramBuilderTest, ExtractReturnsNullOnNonChildren) { + Expr a; + Expr b; + Expr c; + + ProgramBuilder program_builder; + + program_builder.EnterSubexpression(&a); + program_builder.EnterSubexpression(&b); + program_builder.EnterSubexpression(&c); + program_builder.ExitSubexpression(&c); + program_builder.ExitSubexpression(&b); + program_builder.ExitSubexpression(&a); + + auto* subexpr_a = program_builder.GetSubexpression(&a); + auto* subexpr_c = program_builder.GetSubexpression(&c); + + ASSERT_TRUE(subexpr_a != nullptr); + ASSERT_TRUE(subexpr_c != nullptr); + + EXPECT_EQ(subexpr_a->ExtractChild(subexpr_c), nullptr); +} + +TEST_F(ProgramBuilderTest, ResetWorks) { + Expr a; + Expr b; + Expr c; + + ProgramBuilder program_builder; + + program_builder.EnterSubexpression(&a); + program_builder.EnterSubexpression(&b); + program_builder.EnterSubexpression(&c); + program_builder.ExitSubexpression(&c); + program_builder.ExitSubexpression(&b); + program_builder.ExitSubexpression(&a); + + auto* subexpr_a = program_builder.GetSubexpression(&a); + auto* subexpr_c = program_builder.GetSubexpression(&c); + + ASSERT_TRUE(subexpr_a != nullptr); + ASSERT_TRUE(subexpr_c != nullptr); + + program_builder.Reset(); + + subexpr_a = program_builder.GetSubexpression(&a); + subexpr_c = program_builder.GetSubexpression(&c); + + ASSERT_TRUE(subexpr_a == nullptr); + ASSERT_TRUE(subexpr_c == nullptr); +} + +TEST_F(ProgramBuilderTest, ExtractWorks) { + Expr a; + Expr b; + Expr c; + + ProgramBuilder program_builder; + + program_builder.EnterSubexpression(&a); + program_builder.EnterSubexpression(&b); + program_builder.ExitSubexpression(&b); + + ASSERT_OK_AND_ASSIGN(auto a_step, CreateConstValueStep(cel::NullValue(), -1)); + program_builder.AddStep(std::move(a_step)); + program_builder.EnterSubexpression(&c); + program_builder.ExitSubexpression(&c); + program_builder.ExitSubexpression(&a); + + auto* subexpr_a = program_builder.GetSubexpression(&a); + auto* subexpr_c = program_builder.GetSubexpression(&c); + + ASSERT_TRUE(subexpr_a != nullptr); + ASSERT_TRUE(subexpr_c != nullptr); + + EXPECT_EQ(subexpr_a->ExtractChild(subexpr_c), subexpr_c); +} + +TEST_F(ProgramBuilderTest, ExtractToRequiresFlatten) { + Expr a; + Expr b; + Expr c; + + ProgramBuilder program_builder; + + ASSERT_OK_AND_ASSIGN(SimpleTreeSteps step_ptrs, + InitSimpleTree(a, b, c, program_builder)); + + auto* subexpr_a = program_builder.GetSubexpression(&a); + ExecutionPath path; + + EXPECT_FALSE(subexpr_a->ExtractTo(path)); + + subexpr_a->Flatten(); + EXPECT_TRUE(subexpr_a->ExtractTo(path)); + + EXPECT_THAT(path, ElementsAre(UniquePtrHolds(step_ptrs.b), + UniquePtrHolds(step_ptrs.c), + UniquePtrHolds(step_ptrs.a))); +} + +TEST_F(ProgramBuilderTest, Recursive) { + Expr a; + Expr b; + Expr c; + + ProgramBuilder program_builder; + + program_builder.EnterSubexpression(&a); + program_builder.EnterSubexpression(&b); + program_builder.current()->set_recursive_program( + CreateConstValueDirectStep(cel::NullValue()), 1); + program_builder.ExitSubexpression(&b); + program_builder.EnterSubexpression(&c); + program_builder.current()->set_recursive_program( + CreateConstValueDirectStep(cel::NullValue()), 1); + program_builder.ExitSubexpression(&c); + + ASSERT_FALSE(program_builder.current()->IsFlattened()); + ASSERT_FALSE(program_builder.current()->IsRecursive()); + ASSERT_TRUE(program_builder.GetSubexpression(&b)->IsRecursive()); + ASSERT_TRUE(program_builder.GetSubexpression(&c)->IsRecursive()); + + EXPECT_EQ(program_builder.GetSubexpression(&b)->recursive_program().depth, 1); + EXPECT_EQ(program_builder.GetSubexpression(&c)->recursive_program().depth, 1); + + cel::CallExpr call_expr; + call_expr.set_function("_==_"); + call_expr.mutable_args().emplace_back(); + call_expr.mutable_args().emplace_back(); + + auto max_depth = program_builder.current()->RecursiveDependencyDepth(); + + EXPECT_THAT(max_depth, Optional(1)); + + auto deps = program_builder.current()->ExtractRecursiveDependencies(); + + program_builder.current()->set_recursive_program( + CreateDirectFunctionStep(-1, call_expr, std::move(deps), {}), + *max_depth + 1); + + program_builder.ExitSubexpression(&a); + + auto path = program_builder.FlattenMain(); + + ASSERT_THAT(path, testing::SizeIs(1)); + EXPECT_TRUE(path[0]->GetNativeTypeId() == + cel::NativeTypeId::For()); +} + +} // namespace +} // namespace google::api::expr::runtime diff --git a/eval/compiler/flat_expr_builder_short_circuiting_conformance_test.cc b/eval/compiler/flat_expr_builder_short_circuiting_conformance_test.cc index f00e200eb..afe7c5f9f 100644 --- a/eval/compiler/flat_expr_builder_short_circuiting_conformance_test.cc +++ b/eval/compiler/flat_expr_builder_short_circuiting_conformance_test.cc @@ -2,31 +2,31 @@ // produce expressions with the same outputs. #include -#include "google/protobuf/text_format.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "absl/strings/substitute.h" -#include "eval/compiler/flat_expr_builder.h" +#include "base/builtins.h" +#include "eval/compiler/cel_expression_builder_flat_impl.h" #include "eval/public/activation.h" #include "eval/public/cel_attribute.h" -#include "eval/public/cel_builtins.h" #include "eval/public/cel_expression.h" -#include "eval/public/cel_options.h" +#include "eval/public/cel_value.h" #include "eval/public/unknown_attribute_set.h" #include "eval/public/unknown_set.h" -#include "base/status_macros.h" +#include "internal/testing.h" +#include "runtime/internal/runtime_env_testing.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/text_format.h" + +namespace google::api::expr::runtime { -namespace google { -namespace api { -namespace expr { -namespace runtime { namespace { -using google::api::expr::v1alpha1::Expr; -using testing::Eq; -using testing::SizeIs; +using ::cel::runtime_internal::NewTestingRuntimeEnv; +using ::cel::expr::Expr; +using ::testing::Eq; +using ::testing::SizeIs; constexpr char kTwoLogicalOp[] = R"cel( id: 1 @@ -86,26 +86,27 @@ call_expr { void BuildAndEval(CelExpressionBuilder* builder, const Expr& expr, const Activation& activation, google::protobuf::Arena* arena, CelValue* result) { - auto expression_status = builder->CreateExpression(&expr, nullptr); - ASSERT_OK(expression_status.status()); + ASSERT_OK_AND_ASSIGN(auto expression, + builder->CreateExpression(&expr, nullptr)); - auto result_status = expression_status.value()->Evaluate(activation, arena); - ASSERT_OK(result_status.status()); + auto value = expression->Evaluate(activation, arena); + ASSERT_OK(value); - *result = result_status.value(); + *result = *value; } class ShortCircuitingTest : public testing::TestWithParam { public: - ShortCircuitingTest() {} std::unique_ptr GetBuilder( bool enable_unknowns = false) { - auto result = std::make_unique(); - result->set_shortcircuiting(GetParam()); + cel::RuntimeOptions options; + options.short_circuiting = GetParam(); if (enable_unknowns) { - result->set_enable_unknown_function_results(true); - result->set_enable_unknowns(true); + options.unknown_processing = + cel::UnknownProcessingOptions::kAttributeAndFunction; } + auto result = std::make_unique( + NewTestingRuntimeEnv(), options); return result; } }; @@ -115,7 +116,7 @@ TEST_P(ShortCircuitingTest, BasicAnd) { Activation activation; google::protobuf::Arena arena; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( - absl::Substitute(kTwoLogicalOp, builtin::kAnd), &expr)); + absl::Substitute(kTwoLogicalOp, ::cel::builtin::kAnd), &expr)); auto builder = GetBuilder(); activation.InsertValue("var1", CelValue::CreateBool(true)); @@ -143,7 +144,7 @@ TEST_P(ShortCircuitingTest, BasicOr) { Activation activation; google::protobuf::Arena arena; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( - absl::Substitute(kTwoLogicalOp, builtin::kOr), &expr)); + absl::Substitute(kTwoLogicalOp, ::cel::builtin::kOr), &expr)); auto builder = GetBuilder(); activation.InsertValue("var1", CelValue::CreateBool(false)); @@ -171,7 +172,7 @@ TEST_P(ShortCircuitingTest, ErrorAnd) { Activation activation; google::protobuf::Arena arena; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( - absl::Substitute(kTwoLogicalOp, builtin::kAnd), &expr)); + absl::Substitute(kTwoLogicalOp, ::cel::builtin::kAnd), &expr)); auto builder = GetBuilder(); absl::Status error = absl::InternalError("error"); @@ -201,7 +202,7 @@ TEST_P(ShortCircuitingTest, ErrorOr) { Activation activation; google::protobuf::Arena arena; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( - absl::Substitute(kTwoLogicalOp, builtin::kOr), &expr)); + absl::Substitute(kTwoLogicalOp, ::cel::builtin::kOr), &expr)); auto builder = GetBuilder(); absl::Status error = absl::InternalError("error"); @@ -231,7 +232,7 @@ TEST_P(ShortCircuitingTest, UnknownAnd) { Activation activation; google::protobuf::Arena arena; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( - absl::Substitute(kTwoLogicalOp, builtin::kAnd), &expr)); + absl::Substitute(kTwoLogicalOp, ::cel::builtin::kAnd), &expr)); auto builder = GetBuilder(/* enable_unknowns=*/true); absl::Status error = absl::InternalError("error"); @@ -254,9 +255,8 @@ TEST_P(ShortCircuitingTest, UnknownAnd) { ASSERT_TRUE(result.IsUnknownSet()); const UnknownAttributeSet& attrs = result.UnknownSetOrDie()->unknown_attributes(); - ASSERT_THAT(attrs.attributes(), testing::SizeIs(1)); - EXPECT_THAT(attrs.attributes()[0]->variable().ident_expr().name(), - testing::Eq("var1")); + ASSERT_THAT(attrs, testing::SizeIs(1)); + EXPECT_THAT(attrs.begin()->variable_name(), testing::Eq("var1")); } TEST_P(ShortCircuitingTest, UnknownOr) { @@ -264,7 +264,7 @@ TEST_P(ShortCircuitingTest, UnknownOr) { Activation activation; google::protobuf::Arena arena; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( - absl::Substitute(kTwoLogicalOp, builtin::kOr), &expr)); + absl::Substitute(kTwoLogicalOp, ::cel::builtin::kOr), &expr)); auto builder = GetBuilder(/* enable_unknowns=*/true); absl::Status error = absl::InternalError("error"); @@ -287,9 +287,8 @@ TEST_P(ShortCircuitingTest, UnknownOr) { ASSERT_TRUE(result.IsUnknownSet()); const UnknownAttributeSet& attrs = result.UnknownSetOrDie()->unknown_attributes(); - ASSERT_THAT(attrs.attributes(), testing::SizeIs(1)); - EXPECT_THAT(attrs.attributes()[0]->variable().ident_expr().name(), - testing::Eq("var1")); + ASSERT_THAT(attrs, testing::SizeIs(1)); + EXPECT_THAT(attrs.begin()->variable_name(), testing::Eq("var1")); } TEST_P(ShortCircuitingTest, BasicTernary) { @@ -338,7 +337,7 @@ TEST_P(ShortCircuitingTest, TernaryErrorHandling) { BuildAndEval(builder.get(), expr, activation, &arena, &result)); ASSERT_TRUE(result.IsError()); - EXPECT_EQ(result.ErrorOrDie(), &error1); + EXPECT_EQ(*result.ErrorOrDie(), error1); ASSERT_TRUE(activation.RemoveValueEntry("cond")); activation.InsertValue("cond", CelValue::CreateBool(false)); @@ -369,10 +368,9 @@ TEST_P(ShortCircuitingTest, TernaryUnknownCondHandling) { BuildAndEval(builder.get(), expr, activation, &arena, &result)); ASSERT_TRUE(result.IsUnknownSet()); - const auto& attrs = - result.UnknownSetOrDie()->unknown_attributes().attributes(); + const auto& attrs = result.UnknownSetOrDie()->unknown_attributes(); ASSERT_THAT(attrs, SizeIs(1)); - EXPECT_THAT(attrs[0]->variable().ident_expr().name(), Eq("cond")); + EXPECT_THAT(attrs.begin()->variable_name(), Eq("cond")); // Unknown branches are discarded if condition is unknown activation.set_unknown_attribute_patterns({CelAttributePattern("cond", {}), @@ -382,10 +380,9 @@ TEST_P(ShortCircuitingTest, TernaryUnknownCondHandling) { ASSERT_NO_FATAL_FAILURE( BuildAndEval(builder.get(), expr, activation, &arena, &result)); ASSERT_TRUE(result.IsUnknownSet()); - const auto& attrs2 = - result.UnknownSetOrDie()->unknown_attributes().attributes(); + const auto& attrs2 = result.UnknownSetOrDie()->unknown_attributes(); ASSERT_THAT(attrs2, SizeIs(1)); - EXPECT_THAT(attrs2[0]->variable().ident_expr().name(), Eq("cond")); + EXPECT_THAT(attrs2.begin()->variable_name(), Eq("cond")); } TEST_P(ShortCircuitingTest, TernaryUnknownArgsHandling) { @@ -418,10 +415,9 @@ TEST_P(ShortCircuitingTest, TernaryUnknownArgsHandling) { ASSERT_NO_FATAL_FAILURE( BuildAndEval(builder.get(), expr, activation, &arena, &result)); ASSERT_TRUE(result.IsUnknownSet()); - const auto& attrs3 = - result.UnknownSetOrDie()->unknown_attributes().attributes(); + const auto& attrs3 = result.UnknownSetOrDie()->unknown_attributes(); ASSERT_THAT(attrs3, SizeIs(1)); - EXPECT_EQ(attrs3[0]->variable().ident_expr().name(), "arg2"); + EXPECT_EQ(attrs3.begin()->variable_name(), "arg2"); } TEST_P(ShortCircuitingTest, TernaryUnknownAndErrorHandling) { @@ -446,7 +442,7 @@ TEST_P(ShortCircuitingTest, TernaryUnknownAndErrorHandling) { ASSERT_NO_FATAL_FAILURE( BuildAndEval(builder.get(), expr, activation, &arena, &result)); ASSERT_TRUE(result.IsError()); - EXPECT_EQ(result.ErrorOrDie(), &error); + EXPECT_EQ(*result.ErrorOrDie(), error); // Error arg discarded if condition unknown activation.set_unknown_attribute_patterns({CelAttributePattern("cond", {})}); @@ -456,10 +452,9 @@ TEST_P(ShortCircuitingTest, TernaryUnknownAndErrorHandling) { ASSERT_NO_FATAL_FAILURE( BuildAndEval(builder.get(), expr, activation, &arena, &result)); ASSERT_TRUE(result.IsUnknownSet()); - const auto& attrs = - result.UnknownSetOrDie()->unknown_attributes().attributes(); + const auto& attrs = result.UnknownSetOrDie()->unknown_attributes(); ASSERT_THAT(attrs, SizeIs(1)); - EXPECT_EQ(attrs[0]->variable().ident_expr().name(), "cond"); + EXPECT_EQ(attrs.begin()->variable_name(), "cond"); } const char* TestName(testing::TestParamInfo info) { @@ -474,7 +469,5 @@ INSTANTIATE_TEST_SUITE_P(Test, ShortCircuitingTest, testing::Values(false, true), &TestName); } // namespace -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google + +} // namespace google::api::expr::runtime diff --git a/eval/compiler/flat_expr_builder_test.cc b/eval/compiler/flat_expr_builder_test.cc index f263e57e0..105060282 100644 --- a/eval/compiler/flat_expr_builder_test.cc +++ b/eval/compiler/flat_expr_builder_test.cc @@ -1,38 +1,103 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #include "eval/compiler/flat_expr_builder.h" -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include +#include +#include +#include +#include +#include + +#include "cel/expr/checked.pb.h" +#include "cel/expr/syntax.pb.h" #include "google/protobuf/field_mask.pb.h" -#include "google/protobuf/text_format.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" +#include "google/protobuf/descriptor.pb.h" +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" +#include "absl/log/absl_check.h" #include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "absl/types/span.h" +#include "base/builtins.h" +#include "common/function_descriptor.h" +#include "common/kind.h" +#include "common/value.h" +#include "eval/compiler/cel_expression_builder_flat_impl.h" +#include "eval/compiler/constant_folding.h" +#include "eval/compiler/qualified_reference_resolver.h" +#include "eval/public/activation.h" #include "eval/public/builtin_func_registrar.h" #include "eval/public/cel_attribute.h" #include "eval/public/cel_builtins.h" +#include "eval/public/cel_expr_builder_factory.h" #include "eval/public/cel_expression.h" +#include "eval/public/cel_function.h" +#include "eval/public/cel_function_adapter.h" +#include "eval/public/cel_function_registry.h" #include "eval/public/cel_options.h" #include "eval/public/cel_value.h" +#include "eval/public/containers/container_backed_map_impl.h" +#include "eval/public/portable_cel_function_adapter.h" +#include "eval/public/structs/cel_proto_descriptor_pool_builder.h" #include "eval/public/structs/cel_proto_wrapper.h" +#include "eval/public/testing/matchers.h" #include "eval/public/unknown_attribute_set.h" #include "eval/public/unknown_set.h" #include "eval/testutil/test_message.pb.h" -#include "base/status_macros.h" +#include "internal/proto_matchers.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "parser/options.h" +#include "parser/parser.h" +#include "runtime/function.h" +#include "runtime/function_adapter.h" +#include "runtime/internal/runtime_env_testing.h" +#include "runtime/runtime_options.h" +#include "runtime/standard_functions.h" +#include "cel/expr/conformance/proto3/test_all_types.pb.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/dynamic_message.h" +#include "google/protobuf/message.h" +#include "google/protobuf/text_format.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace google::api::expr::runtime { namespace { -using google::api::expr::v1alpha1::Expr; -using google::api::expr::v1alpha1::SourceInfo; - -using google::protobuf::FieldMask; -using testing::Eq; -using testing::Not; +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::cel::BytesValue; +using ::cel::Value; +using ::cel::expr::conformance::proto3::TestAllTypes; +using ::cel::internal::test::EqualsProto; +using ::cel::runtime_internal::NewTestingRuntimeEnv; +using ::cel::expr::CheckedExpr; +using ::cel::expr::Expr; +using ::cel::expr::ParsedExpr; +using ::cel::expr::SourceInfo; +using ::testing::_; +using ::testing::Eq; +using ::testing::HasSubstr; +using ::testing::SizeIs; +using ::testing::Truly; class ConcatFunction : public CelFunction { public: @@ -46,8 +111,7 @@ class ConcatFunction : public CelFunction { absl::Status Evaluate(absl::Span args, CelValue* result, google::protobuf::Arena* arena) const override { if (args.size() != 2) { - return absl::Status(absl::StatusCode::kInvalidArgument, - "Bad arguments number"); + return absl::InvalidArgumentError("Bad arguments number"); } std::string concat = std::string(args[0].StringOrDie().value()) + @@ -62,6 +126,25 @@ class ConcatFunction : public CelFunction { } }; +class RecorderFunction : public CelFunction { + public: + explicit RecorderFunction(const std::string& name, int* count) + : CelFunction(CelFunctionDescriptor{name, false, {}}), count_(count) {} + + absl::Status Evaluate(absl::Span args, CelValue* result, + google::protobuf::Arena* arena) const override { + if (!args.empty()) { + return absl::Status(absl::StatusCode::kInvalidArgument, + "Bad arguments number"); + } + (*count_)++; + *result = CelValue::CreateBool(true); + return absl::OkStatus(); + } + + int* count_; +}; + TEST(FlatExprBuilderTest, SimpleEndToEnd) { Expr expr; SourceInfo source_info; @@ -74,16 +157,13 @@ TEST(FlatExprBuilderTest, SimpleEndToEnd) { auto arg2 = call_expr->add_args(); arg2->mutable_ident_expr()->set_name("value"); - FlatExprBuilder builder; - - auto register_status = - builder.GetRegistry()->Register(absl::make_unique()); - ASSERT_OK(register_status); - - auto build_status = builder.CreateExpression(&expr, &source_info); - ASSERT_OK(build_status); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); - auto cel_expr = std::move(build_status.value()); + ASSERT_THAT( + builder.GetRegistry()->Register(std::make_unique()), + IsOk()); + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder.CreateExpression(&expr, &source_info)); std::string variable = "test"; @@ -92,14 +172,130 @@ TEST(FlatExprBuilderTest, SimpleEndToEnd) { google::protobuf::Arena arena; - auto eval_status = cel_expr->Evaluate(activation, &arena); - ASSERT_OK(eval_status); + ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(result.IsString()); + EXPECT_THAT(result.StringOrDie().value(), Eq("prefixtest")); +} + +TEST(FlatExprBuilderTest, ExprUnset) { + Expr expr; + SourceInfo source_info; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid empty expression"))); +} - CelValue result = eval_status.value(); +TEST(FlatExprBuilderTest, RuntimeExtensionsError) { + Expr expr; + SourceInfo source_info; + auto* ext = source_info.add_extensions(); + ext->set_id("ext1"); + ext->add_affected_components( + cel::expr::SourceInfo_Extension_Component_COMPONENT_RUNTIME); + + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("unsupported CEL extension: ext1"))); +} - ASSERT_TRUE(result.IsString()); +TEST(FlatExprBuilderTest, ConstValueUnset) { + Expr expr; + SourceInfo source_info; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + // Create an empty constant expression to ensure that it triggers an error. + expr.mutable_const_expr(); - EXPECT_THAT(result.StringOrDie().value(), Eq("prefixtest")); + EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("unspecified constant"))); +} + +TEST(FlatExprBuilderTest, MapKeyValueUnset) { + Expr expr; + SourceInfo source_info; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + + // Don't set either the key or the value for the map creation step. + auto* entry = expr.mutable_struct_expr()->add_entries(); + EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Map entry missing key"))); + + // Set the entry key, but not the value. + entry->mutable_map_key()->mutable_const_expr()->set_bool_value(true); + EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Map entry missing value"))); +} + +TEST(FlatExprBuilderTest, MessageFieldValueUnset) { + Expr expr; + SourceInfo source_info; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + + // Don't set either the field or the value for the message creation step. + auto* create_message = expr.mutable_struct_expr(); + create_message->set_message_name("google.protobuf.Value"); + auto* entry = create_message->add_entries(); + EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Struct field missing name"))); + + // Set the entry field, but not the value. + entry->set_field_key("bool_value"); + EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Struct field missing value"))); +} + +TEST(FlatExprBuilderTest, BinaryCallTooManyArguments) { + Expr expr; + SourceInfo source_info; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + + auto* call = expr.mutable_call_expr(); + call->set_function(builtin::kAnd); + call->mutable_target()->mutable_const_expr()->set_string_value("random"); + call->add_args()->mutable_const_expr()->set_bool_value(false); + call->add_args()->mutable_const_expr()->set_bool_value(true); + + EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid argument count"))); +} + +TEST(FlatExprBuilderTest, TernaryCallTooManyArguments) { + Expr expr; + SourceInfo source_info; + auto* call = expr.mutable_call_expr(); + call->set_function(builtin::kTernary); + call->mutable_target()->mutable_const_expr()->set_string_value("random"); + call->add_args()->mutable_const_expr()->set_bool_value(false); + call->add_args()->mutable_const_expr()->set_int64_value(1); + call->add_args()->mutable_const_expr()->set_int64_value(2); + + { + cel::RuntimeOptions options; + options.short_circuiting = true; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + + EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid argument count"))); + } + + // Disable short-circuiting to ensure that a different visitor is used. + { + cel::RuntimeOptions options; + options.short_circuiting = false; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + + EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid argument count"))); + } } TEST(FlatExprBuilderTest, DelayedFunctionResolutionErrors) { @@ -114,33 +310,26 @@ TEST(FlatExprBuilderTest, DelayedFunctionResolutionErrors) { auto arg2 = call_expr->add_args(); arg2->mutable_ident_expr()->set_name("value"); - FlatExprBuilder builder; - builder.set_fail_on_warnings(false); + cel::RuntimeOptions options; + options.fail_on_warnings = false; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); std::vector warnings; // Concat function not registered. - auto build_status = builder.CreateExpression(&expr, &source_info, &warnings); - ASSERT_OK(build_status); - - auto cel_expr = std::move(build_status.value()); + ASSERT_OK_AND_ASSIGN( + auto cel_expr, builder.CreateExpression(&expr, &source_info, &warnings)); std::string variable = "test"; - Activation activation; activation.InsertValue("value", CelValue::CreateString(&variable)); google::protobuf::Arena arena; - auto eval_status = cel_expr->Evaluate(activation, &arena); - ASSERT_OK(eval_status); - - CelValue result = eval_status.value(); - + ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsError()); - EXPECT_THAT(result.ErrorOrDie()->message(), - Eq("No matching overloads found")); + Eq("No matching overloads found : concat(string, string)")); ASSERT_THAT(warnings, testing::SizeIs(1)); EXPECT_EQ(warnings[0].code(), absl::StatusCode::kInvalidArgument); @@ -148,25 +337,6 @@ TEST(FlatExprBuilderTest, DelayedFunctionResolutionErrors) { testing::HasSubstr("No overloads provided")); } -class RecorderFunction : public CelFunction { - public: - explicit RecorderFunction(const std::string& name, int* count) - : CelFunction(CelFunctionDescriptor{name, false, {}}), count_(count) {} - - absl::Status Evaluate(absl::Span args, CelValue* result, - google::protobuf::Arena* arena) const override { - if (!args.empty()) { - return absl::Status(absl::StatusCode::kInvalidArgument, - "Bad arguments number"); - } - (*count_)++; - *result = CelValue::CreateBool(true); - return absl::OkStatus(); - } - - int* count_; -}; - TEST(FlatExprBuilderTest, Shortcircuiting) { Expr expr; SourceInfo source_info; @@ -179,48 +349,58 @@ TEST(FlatExprBuilderTest, Shortcircuiting) { auto arg2 = call_expr->add_args(); arg2->mutable_call_expr()->set_function("recorder2"); - FlatExprBuilder builder; - auto builtin_status = RegisterBuiltinFunctions(builder.GetRegistry()); - - int count1 = 0; - int count2 = 0; - - auto register_status1 = builder.GetRegistry()->Register( - absl::make_unique("recorder1", &count1)); - ASSERT_OK(register_status1); - auto register_status2 = builder.GetRegistry()->Register( - absl::make_unique("recorder2", &count2)); - ASSERT_OK(register_status2); - - // Shortcircuiting on. - auto build_status_on = builder.CreateExpression(&expr, &source_info); - ASSERT_OK(build_status_on); - - auto cel_expr_on = std::move(build_status_on.value()); - Activation activation; google::protobuf::Arena arena; - auto eval_status_on = cel_expr_on->Evaluate(activation, &arena); - ASSERT_OK(eval_status_on); - EXPECT_THAT(count1, Eq(1)); - EXPECT_THAT(count2, Eq(0)); + // Shortcircuiting on + { + cel::RuntimeOptions options; + options.short_circuiting = true; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + auto builtin = RegisterBuiltinFunctions(builder.GetRegistry()); + + int count1 = 0; + int count2 = 0; + + ASSERT_THAT(builder.GetRegistry()->Register( + std::make_unique("recorder1", &count1)), + IsOk()); + ASSERT_THAT(builder.GetRegistry()->Register( + std::make_unique("recorder2", &count2)), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto cel_expr_on, + builder.CreateExpression(&expr, &source_info)); + ASSERT_THAT(cel_expr_on->Evaluate(activation, &arena), IsOk()); + + EXPECT_THAT(count1, Eq(1)); + EXPECT_THAT(count2, Eq(0)); + } // Shortcircuiting off. - builder.set_shortcircuiting(false); - auto build_status_off = builder.CreateExpression(&expr, &source_info); - ASSERT_OK(build_status_off); - - auto cel_expr_off = std::move(build_status_off.value()); - - count1 = 0; - count2 = 0; - - auto eval_status_off = cel_expr_off->Evaluate(activation, &arena); - ASSERT_OK(eval_status_off); - - EXPECT_THAT(count1, Eq(1)); - EXPECT_THAT(count2, Eq(1)); + { + cel::RuntimeOptions options; + options.short_circuiting = false; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + auto builtin = RegisterBuiltinFunctions(builder.GetRegistry()); + + int count1 = 0; + int count2 = 0; + + ASSERT_THAT(builder.GetRegistry()->Register( + std::make_unique("recorder1", &count1)), + IsOk()); + ASSERT_THAT(builder.GetRegistry()->Register( + std::make_unique("recorder2", &count2)), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto cel_expr_off, + builder.CreateExpression(&expr, &source_info)); + + ASSERT_THAT(cel_expr_off->Evaluate(activation, &arena), IsOk()); + EXPECT_THAT(count1, Eq(1)); + EXPECT_THAT(count2, Eq(1)); + } } TEST(FlatExprBuilderTest, ShortcircuitingComprehension) { @@ -240,50 +420,217 @@ TEST(FlatExprBuilderTest, ShortcircuitingComprehension) { ->mutable_const_expr() ->set_bool_value(false); comprehension_expr->mutable_loop_step()->mutable_call_expr()->set_function( - "loop_step"); + "recorder_function1"); comprehension_expr->mutable_result()->mutable_const_expr()->set_bool_value( false); - FlatExprBuilder builder; - auto builtin_status = RegisterBuiltinFunctions(builder.GetRegistry()); + Activation activation; + google::protobuf::Arena arena; + + // shortcircuiting on + { + cel::RuntimeOptions options; + options.short_circuiting = true; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + auto builtin = RegisterBuiltinFunctions(builder.GetRegistry()); + + int count = 0; + ASSERT_THAT( + builder.GetRegistry()->Register( + std::make_unique("recorder_function1", &count)), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto cel_expr_on, + builder.CreateExpression(&expr, &source_info)); + + ASSERT_THAT(cel_expr_on->Evaluate(activation, &arena), IsOk()); + EXPECT_THAT(count, Eq(0)); + } + + // shortcircuiting off + { + cel::RuntimeOptions options; + options.short_circuiting = false; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + auto builtin = RegisterBuiltinFunctions(builder.GetRegistry()); + + int count = 0; + ASSERT_THAT( + builder.GetRegistry()->Register( + std::make_unique("recorder_function1", &count)), + IsOk()); + ASSERT_OK_AND_ASSIGN(auto cel_expr_off, + builder.CreateExpression(&expr, &source_info)); + ASSERT_THAT(cel_expr_off->Evaluate(activation, &arena), IsOk()); + EXPECT_THAT(count, Eq(3)); + } +} - int count = 0; - auto register_status = builder.GetRegistry()->Register( - absl::make_unique("loop_step", &count)); - ASSERT_OK(register_status); +TEST(FlatExprBuilderTest, IdentExprUnsetName) { + Expr expr; + SourceInfo source_info; + // An empty ident without the name set should error. + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"(ident_expr {})", &expr)); + + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); + EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("'name' must not be empty"))); +} - // Shortcircuiting on. - auto build_status_on = builder.CreateExpression(&expr, &source_info); - ASSERT_OK(build_status_on); +TEST(FlatExprBuilderTest, SelectExprUnsetField) { + Expr expr; + SourceInfo source_info; + // An empty ident without the name set should error. + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"(select_expr{ + operand{ ident_expr {name: 'var'} } + })", + &expr)); - auto cel_expr_on = std::move(build_status_on.value()); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); + EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("'field' must not be empty"))); +} - Activation activation; - google::protobuf::Arena arena; - auto eval_status_on = cel_expr_on->Evaluate(activation, &arena); - ASSERT_OK(eval_status_on); +TEST(FlatExprBuilderTest, SelectExprUnsetOperand) { + Expr expr; + SourceInfo source_info; + // An empty ident without the name set should error. + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"(select_expr{ + field: 'field' + operand { id: 1 } + })", + &expr)); - EXPECT_THAT(count, Eq(0)); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); + EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("must specify an operand"))); +} - // Shortcircuiting off. - builder.set_shortcircuiting(false); - auto build_status_off = builder.CreateExpression(&expr, &source_info); - ASSERT_OK(build_status_off); +TEST(FlatExprBuilderTest, ComprehensionExprUnsetAccuVar) { + Expr expr; + SourceInfo source_info; + // An empty ident without the name set should error. + ASSERT_TRUE( + google::protobuf::TextFormat::ParseFromString(R"(comprehension_expr{})", &expr)); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); + EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("'accu_var' must not be empty"))); +} + +TEST(FlatExprBuilderTest, ComprehensionExprUnsetIterVar) { + Expr expr; + SourceInfo source_info; + // An empty ident without the name set should error. + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( + comprehension_expr{accu_var: "a"} + )", + &expr)); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); + EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("'iter_var' must not be empty"))); +} - auto cel_expr_off = std::move(build_status_off.value()); +TEST(FlatExprBuilderTest, ComprehensionExprUnsetAccuInit) { + Expr expr; + SourceInfo source_info; + // An empty ident without the name set should error. + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( + comprehension_expr{ + accu_var: "a" + iter_var: "b"} + )", + &expr)); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); + EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("'accu_init' must be set"))); +} - count = 0; +TEST(FlatExprBuilderTest, ComprehensionExprUnsetLoopCondition) { + Expr expr; + SourceInfo source_info; + // An empty ident without the name set should error. + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( + comprehension_expr{ + accu_var: 'a' + iter_var: 'b' + accu_init { + const_expr {bool_value: true} + }} + )", + &expr)); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); + EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("'loop_condition' must be set"))); +} - auto eval_status_off = cel_expr_off->Evaluate(activation, &arena); - ASSERT_OK(eval_status_off); +TEST(FlatExprBuilderTest, ComprehensionExprUnsetLoopStep) { + Expr expr; + SourceInfo source_info; + // An empty ident without the name set should error. + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( + comprehension_expr{ + accu_var: 'a' + iter_var: 'b' + accu_init { + const_expr {bool_value: true} + } + loop_condition { + const_expr {bool_value: true} + }} + )", + &expr)); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); + EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("'loop_step' must be set"))); +} - EXPECT_THAT(count, Eq(3)); +TEST(FlatExprBuilderTest, ComprehensionExprUnsetResult) { + Expr expr; + SourceInfo source_info; + // An empty ident without the name set should error. + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( + comprehension_expr{ + accu_var: 'a' + iter_var: 'b' + accu_init { + const_expr {bool_value: true} + } + loop_condition { + const_expr {bool_value: true} + } + loop_step { + const_expr {bool_value: false} + }} + )", + &expr)); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); + EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("'result' must be set"))); } TEST(FlatExprBuilderTest, MapComprehension) { Expr expr; + SourceInfo source_info; // {1: "", 2: ""}.all(x, x > 0) - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( comprehension_expr { iter_var: "k" accu_var: "accu" @@ -320,115 +667,637 @@ TEST(FlatExprBuilderTest, MapComprehension) { } } })", - &expr); + &expr)); - FlatExprBuilder builder; - ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); - SourceInfo source_info; - auto build_status = builder.CreateExpression(&expr, &source_info); - ASSERT_OK(build_status); - - auto cel_expr = std::move(build_status.value()); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder.CreateExpression(&expr, &source_info)); Activation activation; google::protobuf::Arena arena; - auto result_or = cel_expr->Evaluate(activation, &arena); - ASSERT_OK(result_or); - CelValue result = result_or.value(); + ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsBool()); EXPECT_TRUE(result.BoolOrDie()); } -TEST(FlatExprBuilderTest, ComprehensionWorksForError) { +TEST(FlatExprBuilderTest, InvalidContainer) { Expr expr; - // {}[0].all(x, x) should evaluate OK but return an error value - google::protobuf::TextFormat::ParseFromString(R"( - id: 4 - comprehension_expr { - iter_var: "x" - iter_range { - id: 2 - call_expr { - function: "_[_]" - args { - id: 1 - struct_expr { - } - } - args { - id: 3 - const_expr { - int64_value: 0 - } - } + SourceInfo source_info; + // foo && bar + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( + call_expr { + function: "_&&_" + args { + ident_expr { + name: "foo" } } - accu_var: "__result__" - accu_init { - id: 7 - const_expr { - bool_value: true + args { + ident_expr { + name: "bar" } } - loop_condition { - id: 8 - call_expr { - function: "__not_strictly_false__" - args { - id: 9 - ident_expr { - name: "__result__" - } + })", + &expr)); + + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); + + builder.set_container(".bad"); + EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("container: '.bad'"))); + + builder.set_container("bad."); + EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("container: 'bad.'"))); +} + +TEST(FlatExprBuilderTest, ParsedNamespacedFunctionSupport) { + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("ext.XOr(a, b)")); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + builder.flat_expr_builder().AddAstTransform( + NewReferenceResolverExtension(ReferenceResolverOption::kAlways)); + using FunctionAdapterT = FunctionAdapter; + + ASSERT_OK(FunctionAdapterT::CreateAndRegister( + "ext.XOr", /*receiver_style=*/false, + [](google::protobuf::Arena*, bool a, bool b) { return a != b; }, + builder.GetRegistry())); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression( + &expr.expr(), &expr.source_info())); + + google::protobuf::Arena arena; + Activation act1; + act1.InsertValue("a", CelValue::CreateBool(false)); + act1.InsertValue("b", CelValue::CreateBool(true)); + + ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(act1, &arena)); + EXPECT_THAT(result, test::IsCelBool(true)); + + Activation act2; + act2.InsertValue("a", CelValue::CreateBool(true)); + act2.InsertValue("b", CelValue::CreateBool(true)); + + ASSERT_OK_AND_ASSIGN(result, cel_expr->Evaluate(act2, &arena)); + EXPECT_THAT(result, test::IsCelBool(false)); +} + +TEST(FlatExprBuilderTest, ParsedNamespacedFunctionSupportWithContainer) { + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("XOr(a, b)")); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + builder.flat_expr_builder().AddAstTransform( + NewReferenceResolverExtension(ReferenceResolverOption::kAlways)); + builder.set_container("ext"); + using FunctionAdapterT = FunctionAdapter; + + ASSERT_OK(FunctionAdapterT::CreateAndRegister( + "ext.XOr", /*receiver_style=*/false, + [](google::protobuf::Arena*, bool a, bool b) { return a != b; }, + builder.GetRegistry())); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression( + &expr.expr(), &expr.source_info())); + google::protobuf::Arena arena; + Activation act1; + act1.InsertValue("a", CelValue::CreateBool(false)); + act1.InsertValue("b", CelValue::CreateBool(true)); + + ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(act1, &arena)); + EXPECT_THAT(result, test::IsCelBool(true)); + + Activation act2; + act2.InsertValue("a", CelValue::CreateBool(true)); + act2.InsertValue("b", CelValue::CreateBool(true)); + + ASSERT_OK_AND_ASSIGN(result, cel_expr->Evaluate(act2, &arena)); + EXPECT_THAT(result, test::IsCelBool(false)); +} + +TEST(FlatExprBuilderTest, ParsedNamespacedFunctionResolutionOrder) { + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("c.d.Get()")); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + builder.flat_expr_builder().AddAstTransform( + NewReferenceResolverExtension(ReferenceResolverOption::kAlways)); + builder.set_container("a.b"); + using FunctionAdapterT = FunctionAdapter; + + ASSERT_OK(FunctionAdapterT::CreateAndRegister( + "a.b.c.d.Get", /*receiver_style=*/false, + [](google::protobuf::Arena*) { return true; }, builder.GetRegistry())); + ASSERT_OK(FunctionAdapterT::CreateAndRegister( + "c.d.Get", /*receiver_style=*/false, [](google::protobuf::Arena*) { return false; }, + builder.GetRegistry())); + ASSERT_OK((FunctionAdapter::CreateAndRegister( + "Get", + /*receiver_style=*/true, [](google::protobuf::Arena*, bool) { return false; }, + builder.GetRegistry()))); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression( + &expr.expr(), &expr.source_info())); + google::protobuf::Arena arena; + Activation act1; + ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(act1, &arena)); + EXPECT_THAT(result, test::IsCelBool(true)); +} + +TEST(FlatExprBuilderTest, + ParsedNamespacedFunctionResolutionOrderParentContainer) { + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("c.d.Get()")); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + builder.flat_expr_builder().AddAstTransform( + NewReferenceResolverExtension(ReferenceResolverOption::kAlways)); + builder.set_container("a.b"); + using FunctionAdapterT = FunctionAdapter; + + ASSERT_OK(FunctionAdapterT::CreateAndRegister( + "a.c.d.Get", /*receiver_style=*/false, + [](google::protobuf::Arena*) { return true; }, builder.GetRegistry())); + ASSERT_OK(FunctionAdapterT::CreateAndRegister( + "c.d.Get", /*receiver_style=*/false, [](google::protobuf::Arena*) { return false; }, + builder.GetRegistry())); + ASSERT_OK((FunctionAdapter::CreateAndRegister( + "Get", + /*receiver_style=*/true, [](google::protobuf::Arena*, bool) { return false; }, + builder.GetRegistry()))); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression( + &expr.expr(), &expr.source_info())); + google::protobuf::Arena arena; + Activation act1; + ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(act1, &arena)); + EXPECT_THAT(result, test::IsCelBool(true)); +} + +TEST(FlatExprBuilderTest, + ParsedNamespacedFunctionResolutionOrderExplicitGlobal) { + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse(".c.d.Get()")); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + builder.flat_expr_builder().AddAstTransform( + NewReferenceResolverExtension(ReferenceResolverOption::kAlways)); + builder.set_container("a.b"); + using FunctionAdapterT = FunctionAdapter; + + ASSERT_OK(FunctionAdapterT::CreateAndRegister( + "a.c.d.Get", /*receiver_style=*/false, + [](google::protobuf::Arena*) { return false; }, builder.GetRegistry())); + ASSERT_OK(FunctionAdapterT::CreateAndRegister( + "c.d.Get", /*receiver_style=*/false, [](google::protobuf::Arena*) { return true; }, + builder.GetRegistry())); + ASSERT_OK((FunctionAdapter::CreateAndRegister( + "Get", + /*receiver_style=*/true, [](google::protobuf::Arena*, bool) { return false; }, + builder.GetRegistry()))); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression( + &expr.expr(), &expr.source_info())); + google::protobuf::Arena arena; + Activation act1; + ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(act1, &arena)); + EXPECT_THAT(result, test::IsCelBool(true)); +} + +TEST(FlatExprBuilderTest, ParsedNamespacedFunctionResolutionOrderReceiverCall) { + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("e.Get()")); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + builder.flat_expr_builder().AddAstTransform( + NewReferenceResolverExtension(ReferenceResolverOption::kAlways)); + builder.set_container("a.b"); + using FunctionAdapterT = FunctionAdapter; + + ASSERT_OK(FunctionAdapterT::CreateAndRegister( + "a.c.d.Get", /*receiver_style=*/false, + [](google::protobuf::Arena*) { return false; }, builder.GetRegistry())); + ASSERT_OK(FunctionAdapterT::CreateAndRegister( + "c.d.Get", /*receiver_style=*/false, [](google::protobuf::Arena*) { return false; }, + builder.GetRegistry())); + ASSERT_OK((FunctionAdapter::CreateAndRegister( + "Get", + /*receiver_style=*/true, [](google::protobuf::Arena*, bool) { return true; }, + builder.GetRegistry()))); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression( + &expr.expr(), &expr.source_info())); + google::protobuf::Arena arena; + Activation act1; + act1.InsertValue("e", CelValue::CreateBool(false)); + ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(act1, &arena)); + EXPECT_THAT(result, test::IsCelBool(true)); +} + +TEST(FlatExprBuilderTest, ParsedNamespacedFunctionSupportDisabled) { + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("ext.XOr(a, b)")); + cel::RuntimeOptions options; + options.fail_on_warnings = false; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + std::vector build_warnings; + builder.set_container("ext"); + using FunctionAdapterT = FunctionAdapter; + + ASSERT_OK(FunctionAdapterT::CreateAndRegister( + "ext.XOr", /*receiver_style=*/false, + [](google::protobuf::Arena*, bool a, bool b) { return a != b; }, + builder.GetRegistry())); + ASSERT_OK_AND_ASSIGN( + auto cel_expr, builder.CreateExpression(&expr.expr(), &expr.source_info(), + &build_warnings)); + google::protobuf::Arena arena; + Activation act1; + act1.InsertValue("a", CelValue::CreateBool(false)); + act1.InsertValue("b", CelValue::CreateBool(true)); + + ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(act1, &arena)); + EXPECT_THAT(result, test::IsCelError(StatusIs(absl::StatusCode::kUnknown, + HasSubstr("ext")))); +} + +TEST(FlatExprBuilderTest, BasicCheckedExprSupport) { + CheckedExpr expr; + // foo && bar + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( + expr { + id: 1 + call_expr { + function: "_&&_" + args { + id: 2 + ident_expr { + name: "foo" } } - } - loop_step { - id: 10 - call_expr { - function: "_&&_" - args { - id: 11 - ident_expr { - name: "__result__" - } - } - args { - id: 6 - ident_expr { - name: "x" - } + args { + id: 3 + ident_expr { + name: "bar" } } } - result { - id: 12 - ident_expr { - name: "__result__" - } - } })", - &expr); - - FlatExprBuilder builder; - ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); - SourceInfo source_info; - auto build_status = builder.CreateExpression(&expr, &source_info); - ASSERT_OK(build_status); + &expr)); - auto cel_expr = std::move(build_status.value()); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(&expr)); Activation activation; + activation.InsertValue("foo", CelValue::CreateBool(true)); + activation.InsertValue("bar", CelValue::CreateBool(true)); google::protobuf::Arena arena; - auto result_or = cel_expr->Evaluate(activation, &arena); - ASSERT_OK(result_or); - CelValue result = result_or.value(); - ASSERT_TRUE(result.IsError()); + ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(result.IsBool()); + EXPECT_TRUE(result.BoolOrDie()); +} + +TEST(FlatExprBuilderTest, CheckedExprWithReferenceMap) { + CheckedExpr expr; + // `foo.var1` && `bar.var2` + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( + reference_map { + key: 2 + value { + name: "foo.var1" + } + } + reference_map { + key: 4 + value { + name: "bar.var2" + } + } + expr { + id: 1 + call_expr { + function: "_&&_" + args { + id: 2 + select_expr { + field: "var1" + operand { + id: 3 + ident_expr { + name: "foo" + } + } + } + } + args { + id: 4 + select_expr { + field: "var2" + operand { + ident_expr { + name: "bar" + } + } + } + } + } + })", + &expr)); + + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + builder.flat_expr_builder().AddAstTransform( + NewReferenceResolverExtension(ReferenceResolverOption::kCheckedOnly)); + ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(&expr)); + + Activation activation; + activation.InsertValue("foo.var1", CelValue::CreateBool(true)); + activation.InsertValue("bar.var2", CelValue::CreateBool(true)); + google::protobuf::Arena arena; + ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(result.IsBool()); + EXPECT_TRUE(result.BoolOrDie()); +} + +TEST(FlatExprBuilderTest, CheckedExprWithReferenceMapFunction) { + CheckedExpr expr; + // ext.and(var1, bar.var2) + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( + reference_map { + key: 1 + value { + overload_id: "com.foo.ext.and" + } + } + reference_map { + key: 3 + value { + name: "com.foo.var1" + } + } + reference_map { + key: 4 + value { + name: "bar.var2" + } + } + expr { + id: 1 + call_expr { + function: "and" + target { + id: 2 + ident_expr { + name: "ext" + } + } + args { + id: 3 + ident_expr { + name: "var1" + } + } + args { + id: 4 + select_expr { + field: "var2" + operand { + id: 5 + ident_expr { + name: "bar" + } + } + } + } + } + })", + &expr)); + + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + builder.flat_expr_builder().AddAstTransform( + NewReferenceResolverExtension(ReferenceResolverOption::kCheckedOnly)); + builder.set_container("com.foo"); + ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); + ASSERT_OK((FunctionAdapter::CreateAndRegister( + "com.foo.ext.and", false, + [](google::protobuf::Arena*, bool lhs, bool rhs) { return lhs && rhs; }, + builder.GetRegistry()))); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(&expr)); + + Activation activation; + activation.InsertValue("com.foo.var1", CelValue::CreateBool(true)); + activation.InsertValue("bar.var2", CelValue::CreateBool(true)); + google::protobuf::Arena arena; + ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(result.IsBool()); + EXPECT_TRUE(result.BoolOrDie()); +} + +TEST(FlatExprBuilderTest, CheckedExprActivationMissesReferences) { + CheckedExpr expr; + // && . + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( + reference_map { + key: 2 + value { + name: "foo.var1" + } + } + reference_map { + key: 5 + value { + name: "bar" + } + } + expr { + id: 1 + call_expr { + function: "_&&_" + args { + id: 2 + select_expr { + field: "var1" + operand { + id: 3 + ident_expr { + name: "foo" + } + } + } + } + args { + id: 4 + select_expr { + field: "var2" + operand { + id: 5 + ident_expr { + name: "bar" + } + } + } + } + } + })", + &expr)); + + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + builder.flat_expr_builder().AddAstTransform( + NewReferenceResolverExtension(ReferenceResolverOption::kCheckedOnly)); + ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(&expr)); + + Activation activation; + activation.InsertValue("foo.var1", CelValue::CreateBool(true)); + // Activation tries to bind a namespaced variable but the reference map refers + // to the container 'bar'. + activation.InsertValue("bar.var2", CelValue::CreateBool(true)); + google::protobuf::Arena arena; + ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(result.IsError()); + EXPECT_THAT(*(result.ErrorOrDie()), + StatusIs(absl::StatusCode::kUnknown, + HasSubstr("No value with name \"bar\" found"))); + + // Re-run with the expected interpretation of `bar`.`var2` + std::vector> map_pairs{ + {CelValue::CreateStringView("var2"), CelValue::CreateBool(false)}}; + + std::unique_ptr map_value = + *CreateContainerBackedMap(absl::MakeSpan(map_pairs)); + activation.InsertValue("bar", CelValue::CreateMap(map_value.get())); + ASSERT_OK_AND_ASSIGN(result, cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(result.IsBool()); + EXPECT_FALSE(result.BoolOrDie()); +} + +TEST(FlatExprBuilderTest, CheckedExprWithReferenceMapAndConstantFolding) { + CheckedExpr expr; + // {`var1`: 'hello'} + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( + reference_map { + key: 3 + value { + name: "var1" + value { + int64_value: 1 + } + } + } + expr { + id: 1 + struct_expr { + entries { + id: 2 + map_key { + id: 3 + ident_expr { + name: "var1" + } + } + value { + id: 4 + const_expr { + string_value: "hello" + } + } + } + } + })", + &expr)); + + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + builder.flat_expr_builder().AddAstTransform( + NewReferenceResolverExtension(ReferenceResolverOption::kCheckedOnly)); + google::protobuf::Arena arena; + builder.flat_expr_builder().AddProgramOptimizer( + cel::runtime_internal::CreateConstantFoldingOptimizer()); + ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(&expr)); + + Activation activation; + ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(result.IsMap()); + auto m = result.MapOrDie(); + auto v = m->Get(&arena, CelValue::CreateInt64(1L)); + EXPECT_THAT(v->StringOrDie().value(), Eq("hello")); +} + +TEST(FlatExprBuilderTest, ComprehensionWorksForError) { + Expr expr; + SourceInfo source_info; + // {}[0].all(x, x) should evaluate OK but return an error value + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( + id: 4 + comprehension_expr { + iter_var: "x" + iter_range { + id: 2 + call_expr { + function: "_[_]" + args { + id: 1 + struct_expr { + } + } + args { + id: 3 + const_expr { + int64_value: 0 + } + } + } + } + accu_var: "__result__" + accu_init { + id: 7 + const_expr { + bool_value: true + } + } + loop_condition { + id: 8 + call_expr { + function: "__not_strictly_false__" + args { + id: 9 + ident_expr { + name: "__result__" + } + } + } + } + loop_step { + id: 10 + call_expr { + function: "_&&_" + args { + id: 11 + ident_expr { + name: "__result__" + } + } + args { + id: 6 + ident_expr { + name: "x" + } + } + } + } + result { + id: 12 + ident_expr { + name: "__result__" + } + } + })", + &expr)); + + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder.CreateExpression(&expr, &source_info)); + + Activation activation; + google::protobuf::Arena arena; + ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(result.IsError()); } TEST(FlatExprBuilderTest, ComprehensionWorksForNonContainer) { Expr expr; + SourceInfo source_info; // 0.all(x, x) should evaluate OK but return an error value. - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( id: 4 comprehension_expr { iter_var: "x" @@ -482,30 +1351,26 @@ TEST(FlatExprBuilderTest, ComprehensionWorksForNonContainer) { } } })", - &expr); - - FlatExprBuilder builder; - ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); - SourceInfo source_info; - auto build_status = builder.CreateExpression(&expr, &source_info); - ASSERT_OK(build_status); + &expr)); - auto cel_expr = std::move(build_status.value()); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder.CreateExpression(&expr, &source_info)); Activation activation; google::protobuf::Arena arena; - auto result_or = cel_expr->Evaluate(activation, &arena); - ASSERT_OK(result_or); - CelValue result = result_or.value(); + ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsError()); EXPECT_THAT(result.ErrorOrDie()->message(), - Eq("No matching overloads found ")); + Eq("No matching overloads found : ")); } TEST(FlatExprBuilderTest, ComprehensionBudget) { Expr expr; + SourceInfo source_info; // [1, 2].all(x, x > 0) - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( comprehension_expr { iter_var: "k" accu_var: "accu" @@ -531,126 +1396,75 @@ TEST(FlatExprBuilderTest, ComprehensionBudget) { } iter_range { list_expr { - { const_expr { int64_value: 1 } } - { const_expr { int64_value: 2 } } + elements { const_expr { int64_value: 1 } } + elements { const_expr { int64_value: 2 } } } } })", - &expr); - - FlatExprBuilder builder; - builder.set_comprehension_max_iterations(1); - ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); - SourceInfo source_info; - auto build_status = builder.CreateExpression(&expr, &source_info); - ASSERT_OK(build_status); + &expr)); - auto cel_expr = std::move(build_status.value()); + cel::RuntimeOptions options; + options.comprehension_max_iterations = 1; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder.CreateExpression(&expr, &source_info)); Activation activation; google::protobuf::Arena arena; - auto result_or = cel_expr->Evaluate(activation, &arena); - ASSERT_FALSE(result_or.ok()); - EXPECT_THAT(result_or.status().message(), Eq("Iteration budget exceeded")); + EXPECT_THAT(cel_expr->Evaluate(activation, &arena).status(), + StatusIs(absl::StatusCode::kInternal, + HasSubstr("Iteration budget exceeded"))); } -TEST(FlatExprBuilderTest, UnknownSupportTest) { +TEST(FlatExprBuilderTest, SimpleEnumTest) { TestMessage message; - Expr expr; SourceInfo source_info; + constexpr char enum_name[] = + "google.api.expr.runtime.TestMessage.TestEnum.TEST_ENUM_1"; - auto select_expr = expr.mutable_select_expr(); - select_expr->set_field("int32_value"); - - auto operand1 = select_expr->mutable_operand(); - auto select_expr1 = operand1->mutable_select_expr(); - - select_expr1->set_field("message_value"); - auto operand2 = select_expr1->mutable_operand(); - - operand2->mutable_ident_expr()->set_name("message"); - - FlatExprBuilder builder; + std::vector enum_name_parts = absl::StrSplit(enum_name, '.'); + Expr* cur_expr = &expr; - auto build_status = builder.CreateExpression(&expr, &source_info); - ASSERT_OK(build_status); + for (int i = enum_name_parts.size() - 1; i > 0; i--) { + auto select_expr = cur_expr->mutable_select_expr(); + select_expr->set_field(enum_name_parts[i]); + cur_expr = select_expr->mutable_operand(); + } - auto cel_expr = std::move(build_status.value()); + cur_expr->mutable_ident_expr()->set_name(enum_name_parts[0]); - message.mutable_message_value()->set_int32_value(1); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + builder.GetTypeRegistry()->Register(TestMessage::TestEnum_descriptor()); + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder.CreateExpression(&expr, &source_info)); google::protobuf::Arena arena; Activation activation; - activation.InsertValue("message", - CelProtoWrapper::CreateMessage(&message, &arena)); - - auto eval_status = cel_expr->Evaluate(activation, &arena); - - ASSERT_OK(eval_status); - CelValue result = eval_status.value(); - + ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsInt64()); - EXPECT_THAT(result.Int64OrDie(), Eq(1)); - - FieldMask mask; - mask.add_paths("message.message_value.int32_value"); - activation.set_unknown_paths(mask); - eval_status = cel_expr->Evaluate(activation, &arena); - ASSERT_OK(eval_status); - result = eval_status.value(); - ASSERT_TRUE(result.IsError()); - ASSERT_TRUE(IsUnknownValueError(result)); - EXPECT_THAT(GetUnknownPathsSetOrDie(result), - Eq(std::set({"message.message_value.int32_value"}))); - - mask.clear_paths(); - mask.add_paths("message.message_value"); - activation.set_unknown_paths(mask); - eval_status = cel_expr->Evaluate(activation, &arena); - ASSERT_OK(eval_status); - result = eval_status.value(); - ASSERT_TRUE(result.IsError()); - ASSERT_TRUE(IsUnknownValueError(result)); - EXPECT_THAT(GetUnknownPathsSetOrDie(result), - Eq(std::set({"message.message_value"}))); + EXPECT_THAT(result.Int64OrDie(), Eq(TestMessage::TEST_ENUM_1)); } -TEST(FlatExprBuilderTest, SimpleEnumTest) { +TEST(FlatExprBuilderTest, SimpleEnumIdentTest) { TestMessage message; - Expr expr; SourceInfo source_info; - constexpr char enum_name[] = "google.api.expr.runtime.TestMessage.TestEnum.TEST_ENUM_1"; - std::vector enum_name_parts = absl::StrSplit(enum_name, '.'); Expr* cur_expr = &expr; + cur_expr->mutable_ident_expr()->set_name(enum_name); - for (int i = enum_name_parts.size() - 1; i > 0; i--) { - auto select_expr = cur_expr->mutable_select_expr(); - select_expr->set_field(enum_name_parts[i]); - cur_expr = select_expr->mutable_operand(); - } - - cur_expr->mutable_ident_expr()->set_name(enum_name_parts[0]); - - FlatExprBuilder builder; - builder.AddResolvableEnum(TestMessage::TestEnum_descriptor()); - - auto build_status = builder.CreateExpression(&expr, &source_info); - ASSERT_OK(build_status); - - auto cel_expr = std::move(build_status.value()); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + builder.GetTypeRegistry()->Register(TestMessage::TestEnum_descriptor()); + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder.CreateExpression(&expr, &source_info)); google::protobuf::Arena arena; Activation activation; - auto eval_status = cel_expr->Evaluate(activation, &arena); - - ASSERT_OK(eval_status); - CelValue result = eval_status.value(); - + ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsInt64()); EXPECT_THAT(result.Int64OrDie(), Eq(TestMessage::TEST_ENUM_1)); } @@ -658,126 +1472,341 @@ TEST(FlatExprBuilderTest, SimpleEnumTest) { TEST(FlatExprBuilderTest, ContainerStringFormat) { Expr expr; SourceInfo source_info; - expr.mutable_ident_expr()->set_name("ident"); - FlatExprBuilder builder; - builder.set_container(""); { - auto build_status = builder.CreateExpression(&expr, &source_info); - ASSERT_OK(build_status); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + builder.set_container(""); + ASSERT_THAT(builder.CreateExpression(&expr, &source_info), IsOk()); } - builder.set_container("random.namespace"); { - auto build_status = builder.CreateExpression(&expr, &source_info); - ASSERT_OK(build_status); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + builder.set_container("random.namespace"); + ASSERT_THAT(builder.CreateExpression(&expr, &source_info), IsOk()); } - - // Leading '.' - builder.set_container(".random.namespace"); { - auto build_status = builder.CreateExpression(&expr, &source_info); - ASSERT_FALSE(build_status.status().ok()); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + // Leading '.' + builder.set_container(".random.namespace"); + EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid expression container"))); } - - // Trailing '.' - builder.set_container("random.namespace."); { - auto build_status = builder.CreateExpression(&expr, &source_info); - ASSERT_FALSE(build_status.status().ok()); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + // Trailing '.' + builder.set_container("random.namespace."); + EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid expression container"))); } } -void EvalExpressionWithEnum(absl::string_view enum_name, - absl::string_view container, CelValue* result) { - TestMessage message; +// Builder with google.api.expr.runtime.TestMessage and TestEnum types +// linked in and the standard functions registered. +CelExpressionBuilderFlatImpl BuilderForNameResolutionTest( + absl::string_view container) { + cel::RuntimeOptions options; + options.enable_qualified_type_identifiers = true; - Expr expr; - SourceInfo source_info; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + builder.GetTypeRegistry()->Register(TestMessage::TestEnum_descriptor()); + builder.GetTypeRegistry()->Register(TestEnum_descriptor()); + builder.set_container(std::string(container)); + ABSL_CHECK_OK(cel::RegisterStandardFunctions( + builder.GetRegistry()->InternalGetRegistry(), options)); + return builder; +} - std::vector enum_name_parts = absl::StrSplit(enum_name, '.'); - Expr* cur_expr = &expr; +TEST(FlatExprBuilderTest, ShortEnumResolution) { + google::protobuf::Arena arena; + CelExpressionBuilderFlatImpl builder = + BuilderForNameResolutionTest("google.api.expr.runtime.TestMessage"); - for (int i = enum_name_parts.size() - 1; i > 0; i--) { - auto select_expr = cur_expr->mutable_select_expr(); - select_expr->set_field(enum_name_parts[i]); - cur_expr = select_expr->mutable_operand(); - } + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, + parser::Parse("TestMessage.TestEnum.TEST_ENUM_1")); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression( + &expr.expr(), &expr.source_info())); - cur_expr->mutable_ident_expr()->set_name(enum_name_parts[0]); + Activation activation; - FlatExprBuilder builder; - builder.AddResolvableEnum(TestMessage::TestEnum_descriptor()); - builder.AddResolvableEnum(TestEnum_descriptor()); - builder.set_container(std::string(container)); + ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); - auto build_status = builder.CreateExpression(&expr, &source_info); - ASSERT_OK(build_status); - auto cel_expr = std::move(build_status.value()); + ASSERT_TRUE(result.IsInt64()); + EXPECT_THAT(result.Int64OrDie(), Eq(TestMessage::TEST_ENUM_1)); +} +TEST(FlatExprBuilderTest, EnumResolutionHonorsLeadingDot) { google::protobuf::Arena arena; + CelExpressionBuilderFlatImpl builder = + BuilderForNameResolutionTest("google.api.expr.runtime"); + + // Leading dot disables container resolution. + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, + parser::Parse(".TestMessage.TestEnum.TEST_ENUM_1")); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression( + &expr.expr(), &expr.source_info())); + Activation activation; - auto eval_status = cel_expr->Evaluate(activation, &arena); + ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(result.IsError()); + EXPECT_THAT( + result.ErrorOrDie()->message(), + HasSubstr("No value with name \"TestMessage\" found in Activation")); +} + +TEST(FlatExprBuilderTest, EnumResolutionComprehensionShadowing) { + google::protobuf::Arena arena; + CelExpressionBuilderFlatImpl builder = + BuilderForNameResolutionTest("google.api.expr.runtime"); + + // Prefer the interpretation that it's a comprehension var if there's a + // collision. + ASSERT_OK_AND_ASSIGN( + ParsedExpr expr, + parser::Parse("[{'TestEnum': {'TEST_ENUM_1': 42}}].map(TestMessage, " + "TestMessage.TestEnum.TEST_ENUM_1)[0] == 42")); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression( + &expr.expr(), &expr.source_info())); - ASSERT_OK(eval_status); - *result = eval_status.value(); + Activation activation; + ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(result.IsBool()); + EXPECT_TRUE(result.BoolOrDie()); } -TEST(FlatExprBuilderTest, ShortEnumResolution) { - CelValue result; - // Test resolution of ".". - ASSERT_NO_FATAL_FAILURE(EvalExpressionWithEnum( - "TestEnum.TEST_ENUM_1", "google.api.expr.runtime.TestMessage", &result)); - ASSERT_TRUE(result.IsInt64()); - EXPECT_THAT(result.Int64OrDie(), Eq(TestMessage::TEST_ENUM_1)); +TEST(FlatExprBuilderTest, EnumResolutionComprehensionShadowingLeadingDot) { + google::protobuf::Arena arena; + CelExpressionBuilderFlatImpl builder = + BuilderForNameResolutionTest("google.api.expr.runtime"); + + // Prefer the interpretation that it's a comprehension var if there's a + // collision. + ASSERT_OK_AND_ASSIGN( + ParsedExpr expr, + parser::Parse("[0].map(google, " + ".google.api.expr.runtime.TestMessage.TestEnum.TEST_ENUM_1)" + "[0] == TestMessage.TestEnum.TEST_ENUM_1")); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression( + &expr.expr(), &expr.source_info())); + + Activation activation; + ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(result.IsBool()); + EXPECT_TRUE(result.BoolOrDie()); } TEST(FlatExprBuilderTest, FullEnumNameWithContainerResolution) { - CelValue result; + google::protobuf::Arena arena; + CelExpressionBuilderFlatImpl builder = + BuilderForNameResolutionTest("very.random.Namespace"); + // Fully qualified name should work. - ASSERT_NO_FATAL_FAILURE(EvalExpressionWithEnum( - "google.api.expr.runtime.TestMessage.TestEnum.TEST_ENUM_1", - "very.random.Namespace", &result)); + ASSERT_OK_AND_ASSIGN( + ParsedExpr expr, + parser::Parse( + "google.api.expr.runtime.TestMessage.TestEnum.TEST_ENUM_1")); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression( + &expr.expr(), &expr.source_info())); + + Activation activation; + ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsInt64()); EXPECT_THAT(result.Int64OrDie(), Eq(TestMessage::TEST_ENUM_1)); } TEST(FlatExprBuilderTest, SameShortNameEnumResolution) { - CelValue result; + google::protobuf::Arena arena; // This precondition validates that // TestMessage::TestEnum::TEST_ENUM1 and TestEnum::TEST_ENUM1 are compiled and // linked in and their values are different. ASSERT_TRUE(static_cast(TestEnum::TEST_ENUM_1) != static_cast(TestMessage::TEST_ENUM_1)); - ASSERT_NO_FATAL_FAILURE(EvalExpressionWithEnum( - "TestEnum.TEST_ENUM_1", "google.api.expr.runtime.TestMessage", &result)); - ASSERT_TRUE(result.IsInt64()); - EXPECT_THAT(result.Int64OrDie(), Eq(TestMessage::TEST_ENUM_1)); - // TEST_ENUM3 is present in google.api.expr.runtime.TestEnum, is absent in - // google.api.expr.runtime.TestMessage.TestEnum. - ASSERT_NO_FATAL_FAILURE(EvalExpressionWithEnum( - "TestEnum.TEST_ENUM_3", "google.api.expr.runtime.TestMessage", &result)); - ASSERT_TRUE(result.IsInt64()); - EXPECT_THAT(result.Int64OrDie(), Eq(TestEnum::TEST_ENUM_3)); + { + CelExpressionBuilderFlatImpl builder = + BuilderForNameResolutionTest("google.api.expr.runtime.TestMessage"); + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, + parser::Parse("TestEnum.TEST_ENUM_1")); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression( + &expr.expr(), &expr.source_info())); + Activation activation; + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(result.IsInt64()); + EXPECT_THAT(result.Int64OrDie(), Eq(TestMessage::TEST_ENUM_1)); + } - ASSERT_NO_FATAL_FAILURE(EvalExpressionWithEnum( - "TestEnum.TEST_ENUM_1", "google.api.expr.runtime", &result)); - ASSERT_TRUE(result.IsInt64()); - EXPECT_THAT(result.Int64OrDie(), Eq(TestEnum::TEST_ENUM_1)); + // TEST_ENUM3 is present in google.api.expr.runtime.TestEnum, is absent in + // google.api.expr.runtime.TestMessage.TestEnum. + { + CelExpressionBuilderFlatImpl builder = + BuilderForNameResolutionTest("google.api.expr.runtime.TestMessage"); + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, + parser::Parse("TestEnum.TEST_ENUM_3")); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression( + &expr.expr(), &expr.source_info())); + Activation activation; + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(result.IsInt64()); + EXPECT_THAT(result.Int64OrDie(), Eq(TestEnum::TEST_ENUM_3)); + } + + { + CelExpressionBuilderFlatImpl builder = + BuilderForNameResolutionTest("google.api.expr.runtime"); + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, + parser::Parse("TestEnum.TEST_ENUM_1")); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression( + &expr.expr(), &expr.source_info())); + Activation activation; + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(result.IsInt64()); + EXPECT_THAT(result.Int64OrDie(), Eq(TestEnum::TEST_ENUM_1)); + } } TEST(FlatExprBuilderTest, PartialQualifiedEnumResolution) { - CelValue result; - ASSERT_NO_FATAL_FAILURE(EvalExpressionWithEnum( - "runtime.TestMessage.TestEnum.TEST_ENUM_1", "google.api.expr", &result)); + google::protobuf::Arena arena; + CelExpressionBuilderFlatImpl builder = + BuilderForNameResolutionTest("google.api.expr"); + + ASSERT_OK_AND_ASSIGN( + ParsedExpr expr, + parser::Parse("runtime.TestMessage.TestEnum.TEST_ENUM_1")); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression( + &expr.expr(), &expr.source_info())); + + Activation activation; + ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsInt64()); EXPECT_THAT(result.Int64OrDie(), Eq(TestMessage::TEST_ENUM_1)); } +TEST(FlatExprBuilderTest, NameCollisionWithComprehensionVar) { + google::protobuf::Arena arena; + CelExpressionBuilderFlatImpl builder = BuilderForNameResolutionTest("google"); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("[0].map(x, x)[0]")); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression( + &expr.expr(), &expr.source_info())); + + Activation activation; + activation.InsertValue("x", CelValue::CreateInt64(1)); + ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); + + ASSERT_TRUE(result.IsInt64()); + EXPECT_THAT(result.Int64OrDie(), Eq(0)); +} + +TEST(FlatExprBuilderTest, NameCollisionWithComprehensionVarLeadingDot) { + google::protobuf::Arena arena; + CelExpressionBuilderFlatImpl builder = BuilderForNameResolutionTest("google"); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("[0].map(x, .x)[0]")); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression( + &expr.expr(), &expr.source_info())); + + Activation activation; + activation.InsertValue("x", CelValue::CreateInt64(1)); + ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); + + ASSERT_TRUE(result.IsInt64()); + EXPECT_THAT(result.Int64OrDie(), Eq(1)); +} + +TEST(FlatExprBuilderTest, MapFieldPresence) { + Expr expr; + SourceInfo source_info; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( + id: 1, + select_expr{ + operand { + id: 2 + ident_expr{ name: "msg" } + } + field: "string_int32_map" + test_only: true + })", + &expr)); + + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder.CreateExpression(&expr, &source_info)); + + google::protobuf::Arena arena; + { + TestMessage message; + auto strMap = message.mutable_string_int32_map(); + strMap->insert({"key", 1}); + Activation activation; + activation.InsertValue("msg", + CelProtoWrapper::CreateMessage(&message, &arena)); + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(result.IsBool()); + ASSERT_TRUE(result.BoolOrDie()); + } + { + TestMessage message; + Activation activation; + activation.InsertValue("msg", + CelProtoWrapper::CreateMessage(&message, &arena)); + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(result.IsBool()); + ASSERT_FALSE(result.BoolOrDie()); + } +} + +TEST(FlatExprBuilderTest, RepeatedFieldPresence) { + Expr expr; + SourceInfo source_info; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( + id: 1, + select_expr{ + operand { + id: 2 + ident_expr{ name: "msg" } + } + field: "int32_list" + test_only: true + })", + &expr)); + + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder.CreateExpression(&expr, &source_info)); + + google::protobuf::Arena arena; + { + TestMessage message; + message.add_int32_list(1); + Activation activation; + activation.InsertValue("msg", + CelProtoWrapper::CreateMessage(&message, &arena)); + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(result.IsBool()); + ASSERT_TRUE(result.BoolOrDie()); + } + { + TestMessage message; + Activation activation; + activation.InsertValue("msg", + CelProtoWrapper::CreateMessage(&message, &arena)); + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(result.IsBool()); + ASSERT_FALSE(result.BoolOrDie()); + } +} + absl::Status RunTernaryExpression(CelValue selector, CelValue value1, CelValue value2, google::protobuf::Arena* arena, CelValue* result) { @@ -793,13 +1822,9 @@ absl::Status RunTernaryExpression(CelValue selector, CelValue value1, auto arg2 = call_expr->add_args(); arg2->mutable_ident_expr()->set_name("value2"); - FlatExprBuilder builder; - auto build_status = builder.CreateExpression(&expr, &source_info); - if (!build_status.ok()) { - return build_status.status(); - } - - auto cel_expr = std::move(build_status.value()); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + CEL_ASSIGN_OR_RETURN(auto cel_expr, + builder.CreateExpression(&expr, &source_info)); std::string variable = "test"; @@ -808,13 +1833,9 @@ absl::Status RunTernaryExpression(CelValue selector, CelValue value1, activation.InsertValue("value1", value1); activation.InsertValue("value2", value2); - auto eval_status = cel_expr->Evaluate(activation, arena); - if (!eval_status.ok()) { - return eval_status.status(); - } - - *result = eval_status.value(); - return eval_status.status(); + CEL_ASSIGN_OR_RETURN(auto eval, cel_expr->Evaluate(activation, arena)); + *result = eval; + return absl::OkStatus(); } TEST(FlatExprBuilderTest, Ternary) { @@ -830,34 +1851,34 @@ TEST(FlatExprBuilderTest, Ternary) { auto arg2 = call_expr->add_args(); arg2->mutable_ident_expr()->set_name("value1"); - FlatExprBuilder builder; - // builder.set_enable_unknowns(true); - auto build_status = builder.CreateExpression(&expr, &source_info); - ASSERT_OK(build_status); - - auto cel_expr = std::move(build_status.value()); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder.CreateExpression(&expr, &source_info)); google::protobuf::Arena arena; // On True, value 1 { CelValue result; - ASSERT_OK(RunTernaryExpression(CelValue::CreateBool(true), - CelValue::CreateInt64(1), - CelValue::CreateInt64(2), &arena, &result)); + ASSERT_THAT(RunTernaryExpression(CelValue::CreateBool(true), + CelValue::CreateInt64(1), + CelValue::CreateInt64(2), &arena, &result), + IsOk()); ASSERT_TRUE(result.IsInt64()); EXPECT_THAT(result.Int64OrDie(), Eq(1)); // Unknown handling UnknownSet unknown_set; - ASSERT_OK(RunTernaryExpression(CelValue::CreateBool(true), - CelValue::CreateUnknownSet(&unknown_set), - CelValue::CreateInt64(2), &arena, &result)); + ASSERT_THAT(RunTernaryExpression(CelValue::CreateBool(true), + CelValue::CreateUnknownSet(&unknown_set), + CelValue::CreateInt64(2), &arena, &result), + IsOk()); ASSERT_TRUE(result.IsUnknownSet()); - ASSERT_OK(RunTernaryExpression( - CelValue::CreateBool(true), CelValue::CreateInt64(1), - CelValue::CreateUnknownSet(&unknown_set), &arena, &result)); + ASSERT_THAT(RunTernaryExpression( + CelValue::CreateBool(true), CelValue::CreateInt64(1), + CelValue::CreateUnknownSet(&unknown_set), &arena, &result), + IsOk()); ASSERT_TRUE(result.IsInt64()); EXPECT_THAT(result.Int64OrDie(), Eq(1)); } @@ -865,73 +1886,70 @@ TEST(FlatExprBuilderTest, Ternary) { // On False, value 2 { CelValue result; - ASSERT_OK(RunTernaryExpression(CelValue::CreateBool(false), - CelValue::CreateInt64(1), - CelValue::CreateInt64(2), &arena, &result)); + ASSERT_THAT(RunTernaryExpression(CelValue::CreateBool(false), + CelValue::CreateInt64(1), + CelValue::CreateInt64(2), &arena, &result), + IsOk()); ASSERT_TRUE(result.IsInt64()); EXPECT_THAT(result.Int64OrDie(), Eq(2)); // Unknown handling UnknownSet unknown_set; - ASSERT_OK(RunTernaryExpression(CelValue::CreateBool(false), - CelValue::CreateUnknownSet(&unknown_set), - CelValue::CreateInt64(2), &arena, &result)); + ASSERT_THAT(RunTernaryExpression(CelValue::CreateBool(false), + CelValue::CreateUnknownSet(&unknown_set), + CelValue::CreateInt64(2), &arena, &result), + IsOk()); ASSERT_TRUE(result.IsInt64()); EXPECT_THAT(result.Int64OrDie(), Eq(2)); - ASSERT_OK(RunTernaryExpression( - CelValue::CreateBool(false), CelValue::CreateInt64(1), - CelValue::CreateUnknownSet(&unknown_set), &arena, &result)); + ASSERT_THAT(RunTernaryExpression( + CelValue::CreateBool(false), CelValue::CreateInt64(1), + CelValue::CreateUnknownSet(&unknown_set), &arena, &result), + IsOk()); ASSERT_TRUE(result.IsUnknownSet()); } // On Error, surface error { CelValue result; - ASSERT_OK(RunTernaryExpression(CreateErrorValue(&arena, "error"), - CelValue::CreateInt64(1), - CelValue::CreateInt64(2), &arena, &result)); + ASSERT_THAT(RunTernaryExpression(CreateErrorValue(&arena, "error"), + CelValue::CreateInt64(1), + CelValue::CreateInt64(2), &arena, &result), + IsOk()); ASSERT_TRUE(result.IsError()); } // On Unknown, surface Unknown { UnknownSet unknown_set; CelValue result; - ASSERT_OK(RunTernaryExpression(CelValue::CreateUnknownSet(&unknown_set), - CelValue::CreateInt64(1), - CelValue::CreateInt64(2), &arena, &result)); + ASSERT_THAT(RunTernaryExpression(CelValue::CreateUnknownSet(&unknown_set), + CelValue::CreateInt64(1), + CelValue::CreateInt64(2), &arena, &result), + IsOk()); ASSERT_TRUE(result.IsUnknownSet()); - EXPECT_THAT(&unknown_set, Eq(result.UnknownSetOrDie())); + EXPECT_THAT(unknown_set, Eq(*result.UnknownSetOrDie())); } // We should not merge unknowns { - Expr selector; - selector.mutable_ident_expr()->set_name("selector"); - CelAttribute selector_attr(selector, {}); + CelAttribute selector_attr("selector", {}); - Expr value1; - value1.mutable_ident_expr()->set_name("value1"); - CelAttribute value1_attr(value1, {}); + CelAttribute value1_attr("value1", {}); - Expr value2; - value2.mutable_ident_expr()->set_name("value2"); - CelAttribute value2_attr(value2, {}); + CelAttribute value2_attr("value2", {}); - UnknownSet unknown_selector(UnknownAttributeSet({&selector_attr})); - UnknownSet unknown_value1(UnknownAttributeSet({&value1_attr})); - UnknownSet unknown_value2(UnknownAttributeSet({&value2_attr})); + UnknownSet unknown_selector(UnknownAttributeSet({selector_attr})); + UnknownSet unknown_value1(UnknownAttributeSet({value1_attr})); + UnknownSet unknown_value2(UnknownAttributeSet({value2_attr})); CelValue result; - ASSERT_OK(RunTernaryExpression( - CelValue::CreateUnknownSet(&unknown_selector), - CelValue::CreateUnknownSet(&unknown_value1), - CelValue::CreateUnknownSet(&unknown_value2), &arena, &result)); + ASSERT_THAT( + RunTernaryExpression(CelValue::CreateUnknownSet(&unknown_selector), + CelValue::CreateUnknownSet(&unknown_value1), + CelValue::CreateUnknownSet(&unknown_value2), + &arena, &result), + IsOk()); ASSERT_TRUE(result.IsUnknownSet()); const UnknownSet* result_set = result.UnknownSetOrDie(); - EXPECT_THAT(result_set->unknown_attributes().attributes().size(), Eq(1)); - EXPECT_THAT(result_set->unknown_attributes() - .attributes()[0] - ->variable() - .ident_expr() - .name(), + EXPECT_THAT(result_set->unknown_attributes().size(), Eq(1)); + EXPECT_THAT(result_set->unknown_attributes().begin()->variable_name(), Eq("selector")); } } @@ -943,16 +1961,1193 @@ TEST(FlatExprBuilderTest, EmptyCallList) { SourceInfo source_info; auto call_expr = expr.mutable_call_expr(); call_expr->set_function(op); - FlatExprBuilder builder; - ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); - auto build_status = builder.CreateExpression(&expr, &source_info); - ASSERT_FALSE(build_status.ok()); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); + auto build = builder.CreateExpression(&expr, &source_info); + ASSERT_FALSE(build.ok()); } } +// Note: this should not be allowed by default, but updating is a breaking +// change. +TEST(FlatExprBuilderTest, HeterogeneousListsAllowed) { + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, + parser::Parse("[17, 'seventeen']")); + + cel::RuntimeOptions options; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + + ASSERT_OK_AND_ASSIGN(auto expression, + builder.CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + Activation activation; + google::protobuf::Arena arena; + + ASSERT_OK_AND_ASSIGN(CelValue result, + expression->Evaluate(activation, &arena)); + + ASSERT_TRUE(result.IsList()) << result.DebugString(); + + const auto& list = *result.ListOrDie(); + ASSERT_EQ(list.size(), 2); + + CelValue elem0 = list.Get(&arena, 0); + CelValue elem1 = list.Get(&arena, 1); + + EXPECT_THAT(elem0, test::IsCelInt64(17)); + EXPECT_THAT(elem1, test::IsCelString("seventeen")); +} + +TEST(FlatExprBuilderTest, NullUnboxingEnabled) { + TestMessage message; + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, + parser::Parse("message.int32_wrapper_value")); + cel::RuntimeOptions options; + options.enable_empty_wrapper_null_unboxing = true; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + ASSERT_OK_AND_ASSIGN(auto expression, + builder.CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + Activation activation; + google::protobuf::Arena arena; + activation.InsertValue("message", + CelProtoWrapper::CreateMessage(&message, &arena)); + ASSERT_OK_AND_ASSIGN(CelValue result, + expression->Evaluate(activation, &arena)); + + EXPECT_TRUE(result.IsNull()); +} + +TEST(FlatExprBuilderTest, TypeResolve) { + TestMessage message; + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, + parser::Parse("type(message) == runtime.TestMessage")); + cel::RuntimeOptions options; + options.enable_qualified_type_identifiers = true; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + builder.set_container("google.api.expr"); + ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); + ASSERT_OK_AND_ASSIGN(auto expression, + builder.CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + Activation activation; + google::protobuf::Arena arena; + activation.InsertValue("message", + CelProtoWrapper::CreateMessage(&message, &arena)); + ASSERT_OK_AND_ASSIGN(CelValue result, + expression->Evaluate(activation, &arena)); + + ASSERT_TRUE(result.IsBool()) << result.DebugString(); + EXPECT_TRUE(result.BoolOrDie()); +} + +TEST(FlatExprBuilderTest, FastEquality) { + TestMessage message; + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse("'foo' == 'bar'")); + cel::RuntimeOptions options; + options.enable_fast_builtins = true; + InterpreterOptions legacy_options; + legacy_options.enable_fast_builtins = true; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry(), legacy_options), + IsOk()); + ASSERT_OK_AND_ASSIGN(auto expression, + builder.CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + Activation activation; + google::protobuf::Arena arena; + ASSERT_OK_AND_ASSIGN(CelValue result, + expression->Evaluate(activation, &arena)); + + ASSERT_TRUE(result.IsBool()) << result.DebugString(); + EXPECT_FALSE(result.BoolOrDie()); +} + +TEST(FlatExprBuilderTest, FastEqualityFiltersBadCalls) { + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse("'foo' == 'bar'")); + parsed_expr.mutable_expr() + ->mutable_call_expr() + ->mutable_target() + ->mutable_const_expr() + ->set_string_value("foo"); + cel::RuntimeOptions options; + options.enable_fast_builtins = true; + InterpreterOptions legacy_options; + legacy_options.enable_fast_builtins = true; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry(), legacy_options), + IsOk()); + ASSERT_THAT( + builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr( + "unexpected number of args for builtin equality operator"))); +} + +TEST(FlatExprBuilderTest, FastInequalityFiltersBadCalls) { + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse("'foo' != 'bar'")); + parsed_expr.mutable_expr() + ->mutable_call_expr() + ->mutable_target() + ->mutable_const_expr() + ->set_string_value("foo"); + cel::RuntimeOptions options; + options.enable_fast_builtins = true; + InterpreterOptions legacy_options; + legacy_options.enable_fast_builtins = true; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry(), legacy_options), + IsOk()); + ASSERT_THAT( + builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr( + "unexpected number of args for builtin equality operator"))); +} + +TEST(FlatExprBuilderTest, FastInFiltersBadCalls) { + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse("a in b")); + parsed_expr.mutable_expr() + ->mutable_call_expr() + ->mutable_target() + ->mutable_const_expr() + ->set_string_value("foo"); + cel::RuntimeOptions options; + options.enable_fast_builtins = true; + InterpreterOptions legacy_options; + legacy_options.enable_fast_builtins = true; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry(), legacy_options), + IsOk()); + ASSERT_THAT( + builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), + StatusIs( + absl::StatusCode::kInvalidArgument, + HasSubstr("unexpected number of args for builtin 'in' operator"))); +} + +TEST(FlatExprBuilderTest, IndexFiltersBadCalls) { + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse("a[b]")); + parsed_expr.mutable_expr() + ->mutable_call_expr() + ->mutable_target() + ->mutable_const_expr() + ->set_string_value("foo"); + cel::RuntimeOptions options; + options.enable_fast_builtins = true; + InterpreterOptions legacy_options; + legacy_options.enable_fast_builtins = true; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry(), legacy_options), + IsOk()); + ASSERT_THAT( + builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), + StatusIs( + absl::StatusCode::kInvalidArgument, + HasSubstr("unexpected number of args for builtin index operator"))); +} + +// TODO(uncreated-issue/79): temporarily allow index operator with a target. +TEST(FlatExprBuilderTest, IndexWithTarget) { + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse("a[b]")); + parsed_expr.mutable_expr() + ->mutable_call_expr() + ->mutable_target() + ->mutable_ident_expr() + ->set_name("a"); + parsed_expr.mutable_expr() + ->mutable_call_expr() + ->mutable_args() + ->DeleteSubrange(0, 1); + + cel::RuntimeOptions options; + options.enable_fast_builtins = true; + InterpreterOptions legacy_options; + legacy_options.enable_fast_builtins = true; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry(), legacy_options), + IsOk()); + ASSERT_THAT( + builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), + IsOk()); +} + +TEST(FlatExprBuilderTest, NotFiltersBadCalls) { + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse("!a")); + parsed_expr.mutable_expr() + ->mutable_call_expr() + ->mutable_target() + ->mutable_const_expr() + ->set_string_value("foo"); + cel::RuntimeOptions options; + options.enable_fast_builtins = true; + InterpreterOptions legacy_options; + legacy_options.enable_fast_builtins = true; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry(), legacy_options), + IsOk()); + ASSERT_THAT( + builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), + StatusIs( + absl::StatusCode::kInvalidArgument, + HasSubstr("unexpected number of args for builtin not operator"))); +} + +TEST(FlatExprBuilderTest, NotStrictlyFalseFiltersBadCalls) { + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse("!a")); + auto* call = parsed_expr.mutable_expr()->mutable_call_expr(); + call->mutable_target()->mutable_const_expr()->set_string_value("foo"); + call->set_function("@not_strictly_false"); + cel::RuntimeOptions options; + options.enable_fast_builtins = true; + InterpreterOptions legacy_options; + legacy_options.enable_fast_builtins = true; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry(), legacy_options), + IsOk()); + ASSERT_THAT( + builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("unexpected number of args for builtin " + "not_strictly_false operator"))); +} + +TEST(FlatExprBuilderTest, FastEqualityDisabledWithCustomEquality) { + TestMessage message; + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse("1 == b'\001'")); + cel::RuntimeOptions options; + options.enable_fast_builtins = true; + InterpreterOptions legacy_options; + legacy_options.enable_fast_builtins = true; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry(), legacy_options), + IsOk()); + + auto& registry = builder.GetRegistry()->InternalGetRegistry(); + + auto status = cel::BinaryFunctionAdapter:: + RegisterGlobalOverload( + "_==_", + [](int64_t lhs, const cel::BytesValue& rhs) -> bool { return true; }, + registry); + ASSERT_THAT(status, IsOk()); + + ASSERT_OK_AND_ASSIGN(auto expression, + builder.CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + Activation activation; + google::protobuf::Arena arena; + ASSERT_OK_AND_ASSIGN(CelValue result, + expression->Evaluate(activation, &arena)); + + ASSERT_TRUE(result.IsBool()) << result.DebugString(); + EXPECT_TRUE(result.BoolOrDie()); +} + +TEST(FlatExprBuilderTest, AnyPackingList) { + google::protobuf::LinkMessageReflection(); + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, + parser::Parse("TestAllTypes{single_any: [1, 2, 3]}")); + + cel::RuntimeOptions options; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + builder.set_container("cel.expr.conformance.proto3"); + + ASSERT_OK_AND_ASSIGN(auto expression, + builder.CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + Activation activation; + google::protobuf::Arena arena; + + ASSERT_OK_AND_ASSIGN(CelValue result, + expression->Evaluate(activation, &arena)); + + EXPECT_THAT(result, + test::IsCelMessage(EqualsProto( + R"pb(single_any { + [type.googleapis.com/google.protobuf.ListValue] { + values { number_value: 1 } + values { number_value: 2 } + values { number_value: 3 } + } + })pb"))) + << result.DebugString(); +} + +TEST(FlatExprBuilderTest, AnyPackingNestedNumbers) { + google::protobuf::LinkMessageReflection(); + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, + parser::Parse("TestAllTypes{single_any: [1, 2.3]}")); + + cel::RuntimeOptions options; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + builder.set_container("cel.expr.conformance.proto3"); + + ASSERT_OK_AND_ASSIGN(auto expression, + builder.CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + Activation activation; + google::protobuf::Arena arena; + + ASSERT_OK_AND_ASSIGN(CelValue result, + expression->Evaluate(activation, &arena)); + + EXPECT_THAT(result, + test::IsCelMessage(EqualsProto( + R"pb(single_any { + [type.googleapis.com/google.protobuf.ListValue] { + values { number_value: 1 } + values { number_value: 2.3 } + } + })pb"))) + << result.DebugString(); +} + +TEST(FlatExprBuilderTest, AnyPackingInt) { + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, + parser::Parse("TestAllTypes{single_any: 1}")); + + cel::RuntimeOptions options; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + builder.set_container("cel.expr.conformance.proto3"); + + ASSERT_OK_AND_ASSIGN(auto expression, + builder.CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + Activation activation; + google::protobuf::Arena arena; + + ASSERT_OK_AND_ASSIGN(CelValue result, + expression->Evaluate(activation, &arena)); + + EXPECT_THAT( + result, + test::IsCelMessage(EqualsProto( + R"pb(single_any { + [type.googleapis.com/google.protobuf.Int64Value] { value: 1 } + })pb"))) + << result.DebugString(); +} + +TEST(FlatExprBuilderTest, AnyPackingMap) { + ASSERT_OK_AND_ASSIGN( + ParsedExpr parsed_expr, + parser::Parse("TestAllTypes{single_any: {'key': 'value'}}")); + + cel::RuntimeOptions options; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + builder.set_container("cel.expr.conformance.proto3"); + + ASSERT_OK_AND_ASSIGN(auto expression, + builder.CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + Activation activation; + google::protobuf::Arena arena; + + ASSERT_OK_AND_ASSIGN(CelValue result, + expression->Evaluate(activation, &arena)); + + EXPECT_THAT(result, test::IsCelMessage(EqualsProto( + R"pb(single_any { + [type.googleapis.com/google.protobuf.Struct] { + fields { + key: "key" + value { string_value: "value" } + } + } + })pb"))) + << result.DebugString(); +} + +TEST(FlatExprBuilderTest, NullUnboxingDisabled) { + TestMessage message; + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, + parser::Parse("message.int32_wrapper_value")); + cel::RuntimeOptions options; + options.enable_empty_wrapper_null_unboxing = false; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + ASSERT_OK_AND_ASSIGN(auto expression, + builder.CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + Activation activation; + google::protobuf::Arena arena; + activation.InsertValue("message", + CelProtoWrapper::CreateMessage(&message, &arena)); + ASSERT_OK_AND_ASSIGN(CelValue result, + expression->Evaluate(activation, &arena)); + + EXPECT_THAT(result, test::IsCelInt64(0)); +} + +TEST(FlatExprBuilderTest, HeterogeneousEqualityEnabled) { + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, + parser::Parse("{1: 2, 2u: 3}[1.0]")); + cel::RuntimeOptions options; + options.enable_heterogeneous_equality = true; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + ASSERT_OK_AND_ASSIGN(auto expression, + builder.CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + Activation activation; + google::protobuf::Arena arena; + ASSERT_OK_AND_ASSIGN(CelValue result, + expression->Evaluate(activation, &arena)); + + EXPECT_THAT(result, test::IsCelInt64(2)); +} + +TEST(FlatExprBuilderTest, HeterogeneousEqualityDisabled) { + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, + parser::Parse("{1: 2, 2u: 3}[1.0]")); + cel::RuntimeOptions options; + options.enable_heterogeneous_equality = false; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + ASSERT_OK_AND_ASSIGN(auto expression, + builder.CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + Activation activation; + google::protobuf::Arena arena; + ASSERT_OK_AND_ASSIGN(CelValue result, + expression->Evaluate(activation, &arena)); + + EXPECT_THAT(result, + test::IsCelError(StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid map key type")))); +} + +std::pair CreateTestMessage( + const google::protobuf::DescriptorPool& descriptor_pool, + google::protobuf::MessageFactory& message_factory, absl::string_view name) { + const google::protobuf::Descriptor* desc = descriptor_pool.FindMessageTypeByName(name); + const google::protobuf::Message* message_prototype = message_factory.GetPrototype(desc); + google::protobuf::Message* message = message_prototype->New(); + const google::protobuf::Reflection* refl = message->GetReflection(); + return std::make_pair(message, refl); +} + +struct CustomDescriptorPoolTestParam final { + using SetterFunction = + std::function; + std::string message_type; + std::string field_name; + SetterFunction setter; + test::CelValueMatcher matcher; +}; + +class CustomDescriptorPoolTest + : public ::testing::TestWithParam {}; + +// This test in particular checks for conversion errors in cel_proto_wrapper.cc. +TEST_P(CustomDescriptorPoolTest, TestType) { + const CustomDescriptorPoolTestParam& p = GetParam(); + + google::protobuf::DescriptorPool descriptor_pool; + google::protobuf::Arena arena; + + // Setup descriptor pool and builder + ASSERT_THAT(AddStandardMessageTypesToDescriptorPool(descriptor_pool), IsOk()); + google::protobuf::DynamicMessageFactory message_factory(&descriptor_pool); + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse("m")); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); + + // Create test subject, invoke custom setter for message + auto [message, reflection] = + CreateTestMessage(descriptor_pool, message_factory, p.message_type); + const google::protobuf::FieldDescriptor* field = + message->GetDescriptor()->FindFieldByName(p.field_name); + + p.setter(message, reflection, field); + ASSERT_OK_AND_ASSIGN(std::unique_ptr expression, + builder.CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + // Evaluate expression, verify expectation with custom matcher + Activation activation; + activation.InsertValue("m", CelProtoWrapper::CreateMessage(message, &arena)); + ASSERT_OK_AND_ASSIGN(CelValue result, + expression->Evaluate(activation, &arena)); + EXPECT_THAT(result, p.matcher); + + delete message; +} + +INSTANTIATE_TEST_SUITE_P( + ValueTypes, CustomDescriptorPoolTest, + ::testing::ValuesIn(std::vector{ + {"google.protobuf.Duration", "seconds", + [](google::protobuf::Message* message, const google::protobuf::Reflection* reflection, + const google::protobuf::FieldDescriptor* field) { + reflection->SetInt64(message, field, 10); + }, + test::IsCelDuration(absl::Seconds(10))}, + {"google.protobuf.DoubleValue", "value", + [](google::protobuf::Message* message, const google::protobuf::Reflection* reflection, + const google::protobuf::FieldDescriptor* field) { + reflection->SetDouble(message, field, 1.2); + }, + test::IsCelDouble(1.2)}, + {"google.protobuf.Int64Value", "value", + [](google::protobuf::Message* message, const google::protobuf::Reflection* reflection, + const google::protobuf::FieldDescriptor* field) { + reflection->SetInt64(message, field, -23); + }, + test::IsCelInt64(-23)}, + {"google.protobuf.UInt64Value", "value", + [](google::protobuf::Message* message, const google::protobuf::Reflection* reflection, + const google::protobuf::FieldDescriptor* field) { + reflection->SetUInt64(message, field, 42); + }, + test::IsCelUint64(42)}, + {"google.protobuf.BoolValue", "value", + [](google::protobuf::Message* message, const google::protobuf::Reflection* reflection, + const google::protobuf::FieldDescriptor* field) { + reflection->SetBool(message, field, true); + }, + test::IsCelBool(true)}, + {"google.protobuf.StringValue", "value", + [](google::protobuf::Message* message, const google::protobuf::Reflection* reflection, + const google::protobuf::FieldDescriptor* field) { + reflection->SetString(message, field, "foo"); + }, + test::IsCelString("foo")}, + {"google.protobuf.BytesValue", "value", + [](google::protobuf::Message* message, const google::protobuf::Reflection* reflection, + const google::protobuf::FieldDescriptor* field) { + reflection->SetString(message, field, "bar"); + }, + test::IsCelBytes("bar")}, + {"google.protobuf.Timestamp", "seconds", + [](google::protobuf::Message* message, const google::protobuf::Reflection* reflection, + const google::protobuf::FieldDescriptor* field) { + reflection->SetInt64(message, field, 20); + }, + test::IsCelTimestamp(absl::FromUnixSeconds(20))}})); + +struct ConstantFoldingTestCase { + std::string test_name; + std::string expr; + test::CelValueMatcher matcher; + absl::flat_hash_map values; +}; + +class UnknownFunctionImpl : public cel::Function { + absl::StatusOr Invoke(absl::Span args, + const InvokeContext& context) const override { + return cel::UnknownValue(); + } +}; + +absl::StatusOr> +CreateConstantFoldingConformanceTestExprBuilder( + const InterpreterOptions& options) { + auto builder = + google::api::expr::runtime::CreateCelExpressionBuilder(options); + CEL_RETURN_IF_ERROR( + RegisterBuiltinFunctions(builder->GetRegistry(), options)); + CEL_RETURN_IF_ERROR(builder->GetRegistry()->RegisterLazyFunction( + cel::FunctionDescriptor("LazyFunction", false, {}))); + CEL_RETURN_IF_ERROR(builder->GetRegistry()->RegisterLazyFunction( + cel::FunctionDescriptor("LazyFunction", false, {cel::Kind::kBool}))); + CEL_RETURN_IF_ERROR(builder->GetRegistry()->Register( + cel::FunctionDescriptor("UnknownFunction", false, {}), + std::make_unique())); + return builder; +} + +class ConstantFoldingConformanceTest + : public ::testing::TestWithParam { + protected: + google::protobuf::Arena arena_; +}; + +TEST_P(ConstantFoldingConformanceTest, Updated) { + InterpreterOptions options; + options.constant_folding = true; + options.constant_arena = &arena_; + // Check interaction between const folding and list append optimizations. + options.enable_comprehension_list_append = true; + + const ConstantFoldingTestCase& p = GetParam(); + ASSERT_OK_AND_ASSIGN( + auto builder, CreateConstantFoldingConformanceTestExprBuilder(options)); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse(p.expr)); + + ASSERT_OK_AND_ASSIGN( + auto plan, builder->CreateExpression(&expr.expr(), &expr.source_info())); + + Activation activation; + ASSERT_OK(activation.InsertFunction( + PortableUnaryFunctionAdapter::Create( + "LazyFunction", false, + [](google::protobuf::Arena* arena, bool val) { return val; }))); + + for (auto iter = p.values.begin(); iter != p.values.end(); ++iter) { + activation.InsertValue(iter->first, CelValue::CreateInt64(iter->second)); + } + ASSERT_OK_AND_ASSIGN(CelValue result, plan->Evaluate(activation, &arena_)); + // Check that none of the memoized constants are being mutated. + ASSERT_OK_AND_ASSIGN(result, plan->Evaluate(activation, &arena_)); + EXPECT_THAT(result, p.matcher); +} + +INSTANTIATE_TEST_SUITE_P( + Exprs, ConstantFoldingConformanceTest, + ::testing::ValuesIn(std::vector{ + {"simple_add", "1 + 2 + 3", test::IsCelInt64(6)}, + {"add_with_var", + "1 + (2 + (3 + id))", + test::IsCelInt64(10), + {{"id", 4}}}, + {"const_list", "[1, 2, 3, 4]", test::IsCelList(_)}, + {"mixed_const_list", + "[1, 2, 3, 4] + [id]", + test::IsCelList(_), + {{"id", 5}}}, + {"create_struct", "{'abc': 'def', 'def': 'efg', 'efg': 'hij'}", + Truly([](const CelValue& v) { return v.IsMap(); })}, + {"field_selection", "{'abc': 123}.abc == 123", test::IsCelBool(true)}, + {"type_coverage", + // coverage for constant literals, type() is used to make the list + // homogenous. + R"cel( + [type(bool), + type(123), + type(123u), + type(12.3), + type(b'123'), + type('123'), + type(null), + type(timestamp(0)), + type(duration('1h')) + ])cel", + test::IsCelList(SizeIs(9))}, + {"lazy_function", "true || LazyFunction()", test::IsCelBool(true)}, + {"lazy_function_called", "LazyFunction(true) || false", + test::IsCelBool(true)}, + {"unknown_function", "UnknownFunction() && false", + test::IsCelBool(false)}, + {"nested_comprehension", + "[1, 2, 3, 4].all(x, [5, 6, 7, 8].all(y, x < y))", + test::IsCelBool(true)}, + // Implementation detail: map and filter use replace the accu_init + // expr with a special mutable list to avoid quadratic memory usage + // building the projected list. + {"map", "[1, 2, 3, 4].map(x, x * 2).size() == 4", + test::IsCelBool(true)}, + {"str_cat", + "'1234567890' + '1234567890' + '1234567890' + '1234567890' + " + "'1234567890'", + test::IsCelString( + "12345678901234567890123456789012345678901234567890")}})); + +// Check that list literals are pre-computed +TEST(UpdatedConstantFolding, FoldsLists) { + InterpreterOptions options; + google::protobuf::Arena arena; + options.constant_folding = true; + options.constant_arena = &arena; + + ASSERT_OK_AND_ASSIGN( + auto builder, CreateConstantFoldingConformanceTestExprBuilder(options)); + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, + parser::Parse("[1] + [2] + [3] + [4] + [5] + [6] + [7] " + "+ [8] + [9] + [10] + [11] + [12]")); + + ASSERT_OK_AND_ASSIGN( + auto plan, builder->CreateExpression(&expr.expr(), &expr.source_info())); + Activation activation; + int before_size = arena.SpaceUsed(); + ASSERT_OK_AND_ASSIGN(CelValue result, plan->Evaluate(activation, &arena)); + // Some incidental allocations are expected related to interop. + // 128 is less than the expected allocations for allocating the list terms and + // any intermediates in the unoptimized case. + EXPECT_LE(arena.SpaceUsed() - before_size, 512); + EXPECT_THAT(result, test::IsCelList(SizeIs(12))); +} + +TEST(FlatExprBuilderTest, BlockBadIndex) { + ParsedExpr parsed_expr; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb( + expr: { + call_expr: { + function: "cel.@block" + args { + list_expr: { elements { const_expr: { string_value: "foo" } } } + } + args { ident_expr: { name: "@index-1" } } + } + } + )pb", + &parsed_expr)); + + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + EXPECT_THAT( + builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("bad @index"))); +} + +TEST(FlatExprBuilderTest, OutOfRangeBlockIndex) { + ParsedExpr parsed_expr; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb( + expr: { + call_expr: { + function: "cel.@block" + args { + list_expr: { elements { const_expr: { string_value: "foo" } } } + } + args { ident_expr: { name: "@index1" } } + } + } + )pb", + &parsed_expr)); + + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + EXPECT_THAT( + builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("invalid @index greater than number of bindings:"))); +} + +TEST(FlatExprBuilderTest, EarlyBlockIndex) { + ParsedExpr parsed_expr; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb( + expr: { + call_expr: { + function: "cel.@block" + args { list_expr: { elements { ident_expr: { name: "@index0" } } } } + args { ident_expr: { name: "@index0" } } + } + } + )pb", + &parsed_expr)); + + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + EXPECT_THAT( + builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("@index references current or future binding:"))); +} + +TEST(FlatExprBuilderTest, OutOfScopeCSE) { + ParsedExpr parsed_expr; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb( + expr: { ident_expr: { name: "@ac:0:0" } } + )pb", + &parsed_expr)); + + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + EXPECT_THAT( + builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("out of scope reference to CSE generated " + "comprehension variable"))); +} + +TEST(FlatExprBuilderTest, BlockMissingBindings) { + ParsedExpr parsed_expr; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb( + expr: { call_expr: { function: "cel.@block" } } + )pb", + &parsed_expr)); + + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + EXPECT_THAT( + builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr( + "malformed cel.@block: missing list of bound expressions"))); +} + +TEST(FlatExprBuilderTest, BlockMissingExpression) { + ParsedExpr parsed_expr; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb( + expr: { + call_expr: { + function: "cel.@block" + args { list_expr: {} } + } + } + )pb", + &parsed_expr)); + + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + EXPECT_THAT( + builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("malformed cel.@block: missing bound expression"))); +} + +TEST(FlatExprBuilderTest, BlockNotListOfBoundExpressions) { + ParsedExpr parsed_expr; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb( + expr: { + call_expr: { + function: "cel.@block" + args { ident_expr: { name: "@index0" } } + args { ident_expr: { name: "@index0" } } + } + } + )pb", + &parsed_expr)); + + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + EXPECT_THAT( + builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("malformed cel.@block: first argument is not a list " + "of bound expressions"))); +} + +TEST(FlatExprBuilderTest, BlockEmptyListOfBoundExpressions) { + ParsedExpr parsed_expr; + // Allowed, but degenerate case. + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb( + expr: { + call_expr: { + function: "cel.@block" + args { list_expr: {} } + args { ident_expr: { name: "@index0" } } + } + } + )pb", + &parsed_expr)); + + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + EXPECT_THAT( + builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("invalid @index greater than number of bindings:"))); +} + +TEST(FlatExprBuilderTest, BlockOptionalListOfBoundExpressions) { + ParsedExpr parsed_expr; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb( + expr: { + call_expr: { + function: "cel.@block" + args { + list_expr: { + elements { const_expr: { string_value: "foo" } } + optional_indices: [ 0 ] + } + } + args { ident_expr: { name: "@index0" } } + } + } + )pb", + &parsed_expr)); + + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + EXPECT_THAT( + builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("malformed cel.@block: list of bound expressions " + "contains an optional"))); +} + +TEST(FlatExprBuilderTest, BlockNested) { + ParsedExpr parsed_expr; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb( + expr: { + call_expr: { + function: "cel.@block" + args { + list_expr: { elements { const_expr: { string_value: "foo" } } } + } + args { + call_expr: { + function: "cel.@block" + args { + list_expr: { + elements { const_expr: { string_value: "foo" } } + } + } + args { ident_expr: { name: "@index1" } } + } + } + } + } + )pb", + &parsed_expr)); + + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + EXPECT_THAT( + builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("multiple cel.@block are not allowed"))); +} + +struct VariadicLogicalEvalTestCase { + std::string label; + std::string expr; + std::string a_val; + std::string b_val; + std::string c_val; + std::string expected_type; // "bool", "error", "unknown" + bool expected_bool = false; +}; + +class FlatExprBuilderVariadicLogicalTest + : public testing::TestWithParam {}; + +TEST_P(FlatExprBuilderVariadicLogicalTest, Evaluate) { + const auto& test_case = GetParam(); + parser::ParserOptions parser_options; + parser_options.enable_variadic_logical_operators = true; + ASSERT_OK_AND_ASSIGN( + ParsedExpr parsed_expr, + parser::Parse(test_case.expr, test_case.label, parser_options)); + + cel::RuntimeOptions options; + options.unknown_processing = + cel::UnknownProcessingOptions::kAttributeAndFunction; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder.CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + Activation activation; + google::protobuf::Arena arena; + std::vector unknown_patterns; + + // Set up variables: + auto insert_value = [&](absl::string_view name, const std::string& val) { + if (val == "true") { + activation.InsertValue(name, CelValue::CreateBool(true)); + } else if (val == "false") { + activation.InsertValue(name, CelValue::CreateBool(false)); + } else if (val == "error") { + activation.InsertValue(name, CreateErrorValue(&arena, "test error")); + } else if (val == "unknown1" || val == "unknown2") { + activation.InsertValue(name, CelValue::CreateBool(true)); + unknown_patterns.push_back(CreateCelAttributePattern(name, {})); + } + }; + + insert_value("a", test_case.a_val); + insert_value("b", test_case.b_val); + insert_value("c", test_case.c_val); + + if (!unknown_patterns.empty()) { + activation.set_unknown_attribute_patterns(std::move(unknown_patterns)); + } + + ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); + + if (test_case.expected_type == "bool") { + ASSERT_TRUE(result.IsBool()) << result.DebugString(); + EXPECT_EQ(result.BoolOrDie(), test_case.expected_bool); + } else if (test_case.expected_type == "error") { + EXPECT_TRUE(result.IsError()) << result.DebugString(); + } else if (test_case.expected_type == "unknown") { + EXPECT_TRUE(result.IsUnknownSet()) << result.DebugString(); + } +} + +INSTANTIATE_TEST_SUITE_P( + FlatExprBuilderVariadicLogicalTest, FlatExprBuilderVariadicLogicalTest, + testing::Values( + VariadicLogicalEvalTestCase{"AND_AllTrue", "a && b && c", "true", + "true", "true", "bool", true}, + VariadicLogicalEvalTestCase{"AND_ShortCircuitFalse", "a && b && c", + "true", "false", "unset", "bool", false}, + VariadicLogicalEvalTestCase{"AND_ShortCircuitFirstFalse", "a && b && c", + "false", "unset", "unset", "bool", false}, + VariadicLogicalEvalTestCase{"OR_AllFalse", "a || b || c", "false", + "false", "false", "bool", false}, + VariadicLogicalEvalTestCase{"OR_ShortCircuitTrue", "a || b || c", + "false", "true", "unset", "bool", true}, + VariadicLogicalEvalTestCase{"OR_ShortCircuitFirstTrue", "a || b || c", + "true", "unset", "unset", "bool", true}, + VariadicLogicalEvalTestCase{"AND_Error", "a && b && c", "true", "error", + "true", "error"}, + VariadicLogicalEvalTestCase{"AND_ShortCircuitBeforeError", + "a && b && c", "false", "error", "unset", + "bool", false}, + VariadicLogicalEvalTestCase{"OR_Error", "a || b || c", "false", "error", + "false", "error"}, + VariadicLogicalEvalTestCase{"OR_ShortCircuitBeforeError", "a || b || c", + "true", "error", "unset", "bool", true}, + VariadicLogicalEvalTestCase{"AND_Unknown", "a && b && c", "true", + "unknown1", "true", "unknown"}, + VariadicLogicalEvalTestCase{"AND_ShortCircuitBeforeUnknown", + "a && b && c", "false", "unknown1", "unset", + "bool", false}, + VariadicLogicalEvalTestCase{"OR_Unknown", "a || b || c", "false", + "unknown1", "false", "unknown"}, + VariadicLogicalEvalTestCase{"OR_ShortCircuitBeforeUnknown", + "a || b || c", "true", "unknown1", "unset", + "bool", true}, + VariadicLogicalEvalTestCase{"AND_UnknownAggregation", "a && b && c", + "unknown1", "unknown2", "true", "unknown"}, + VariadicLogicalEvalTestCase{"OR_UnknownAggregation", "a || b || c", + "unknown1", "unknown2", "false", "unknown"}, + VariadicLogicalEvalTestCase{"Exists_True", "[a, b, c].exists(x, x)", + "false", "false", "true", "bool", true}, + VariadicLogicalEvalTestCase{"Exists_Unknown", "[a, b, c].exists(x, x)", + "false", "unknown1", "false", "unknown"}, + VariadicLogicalEvalTestCase{"All_False", "[a, b, c].all(x, x)", "true", + "true", "false", "bool", false}, + VariadicLogicalEvalTestCase{"All_Unknown", "[a, b, c].all(x, x)", + "true", "unknown1", "true", "unknown"})); + +struct RecursionDepthTestCase { + std::string label; + std::string expr; + int max_recursion_depth; + absl::StatusCode expected_status_code; + std::string expected_error_msg; +}; + +class FlatExprBuilderRecursionDepthTest + : public testing::TestWithParam {}; + +TEST_P(FlatExprBuilderRecursionDepthTest, CheckRecursionLimit) { + const auto& test_case = GetParam(); + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse(test_case.expr)); + + cel::RuntimeOptions options; + options.max_recursion_depth = test_case.max_recursion_depth; + options.fail_on_warnings = false; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + + auto result = + builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()); + if (test_case.expected_status_code == absl::StatusCode::kOk) { + EXPECT_THAT(result, IsOk()); + } else { + EXPECT_THAT(result, StatusIs(test_case.expected_status_code, + HasSubstr(test_case.expected_error_msg))); + } +} + +INSTANTIATE_TEST_SUITE_P( + FlatExprBuilderRecursionDepthTest, FlatExprBuilderRecursionDepthTest, + testing::Values( + RecursionDepthTestCase{"AndChildLimitExceeded", "(1 + 1) && true", 1, + absl::StatusCode::kInvalidArgument, + "Maximum recursion depth of 1 exceeded"}, + RecursionDepthTestCase{"AndParentLimitExceeded", "(1 + 1) && true", 2, + absl::StatusCode::kInvalidArgument, + "Maximum recursion depth of 2 exceeded"}, + RecursionDepthTestCase{"AndLimitSuccess", "(1 + 1) && true", 3, + absl::StatusCode::kOk, ""}, + RecursionDepthTestCase{"AndLimitSuccessGenerous", "(1 + 1) && true", 10, + absl::StatusCode::kOk, ""}, + RecursionDepthTestCase{"AndLimitSuccessUnlimited", "(1 + 1) && true", + -1, absl::StatusCode::kOk, ""}, + RecursionDepthTestCase{"OrChildLimitExceeded", "(1 + 1) || true", 1, + absl::StatusCode::kInvalidArgument, + "Maximum recursion depth of 1 exceeded"}, + RecursionDepthTestCase{"OrParentLimitExceeded", "(1 + 1) || true", 2, + absl::StatusCode::kInvalidArgument, + "Maximum recursion depth of 2 exceeded"}, + RecursionDepthTestCase{"OrLimitSuccess", "(1 + 1) || true", 3, + absl::StatusCode::kOk, ""}, + RecursionDepthTestCase{"OrLimitSuccessGenerous", + "(1 + 1) || false || false || false || false || " + "(true && true && true && true && false)", + 10, absl::StatusCode::kOk, ""}, + RecursionDepthTestCase{"OrLimitSuccessUnlimited", "(1 + 1) || true", -1, + absl::StatusCode::kOk, ""}, + RecursionDepthTestCase{"AndDepthUpdateFromSubsequentArg", + "true && (1 + 1 + 1 + 1)", 4, + absl::StatusCode::kInvalidArgument, + "Maximum recursion depth of 4 exceeded"}, + RecursionDepthTestCase{"OrDepthUpdateFromSubsequentArg", + "true || (1 + 1 + 1 + 1)", 4, + absl::StatusCode::kInvalidArgument, + "Maximum recursion depth of 4 exceeded"})); + +TEST(FlatExprBuilderTest, NonRecursiveChildBlockAndError) { + ParsedExpr parsed_expr; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb( + expr: { + call_expr: { + function: "_&&_" + args { const_expr: { bool_value: true } } + args { + call_expr: { + function: "cel.@block" + args { + list_expr { elements { const_expr: { int64_value: 1 } } } + } + args { ident_expr: { name: "@index0" } } + } + } + } + } + )pb", + &parsed_expr)); + + cel::RuntimeOptions options; + options.max_recursion_depth = 2; + options.fail_on_warnings = false; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + EXPECT_THAT( + builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), + StatusIs(absl::StatusCode::kInternal, + HasSubstr("failed to build recursive program"))); +} + +TEST(FlatExprBuilderTest, NonRecursiveChildBlockOrError) { + ParsedExpr parsed_expr; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb( + expr: { + call_expr: { + function: "_||_" + args { const_expr: { bool_value: true } } + args { + call_expr: { + function: "cel.@block" + args { + list_expr { elements { const_expr: { int64_value: 1 } } } + } + args { ident_expr: { name: "@index0" } } + } + } + } + } + )pb", + &parsed_expr)); + + cel::RuntimeOptions options; + options.max_recursion_depth = 2; + options.fail_on_warnings = false; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + EXPECT_THAT( + builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), + StatusIs(absl::StatusCode::kInternal, + HasSubstr("failed to build recursive program"))); +} + } // namespace -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime diff --git a/eval/compiler/instrumentation.cc b/eval/compiler/instrumentation.cc new file mode 100644 index 000000000..3e37bdb45 --- /dev/null +++ b/eval/compiler/instrumentation.cc @@ -0,0 +1,93 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/compiler/instrumentation.h" + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "common/ast.h" +#include "common/expr.h" +#include "eval/compiler/flat_expr_builder_extensions.h" +#include "eval/eval/evaluator_core.h" +#include "eval/eval/expression_step_base.h" + +namespace google::api::expr::runtime { + +namespace { + +class InstrumentStep : public ExpressionStepBase { + public: + explicit InstrumentStep(int64_t expr_id, Instrumentation instrumentation) + : ExpressionStepBase(/*expr_id=*/expr_id, /*comes_from_ast=*/false), + expr_id_(expr_id), + instrumentation_(std::move(instrumentation)) {} + + absl::Status Evaluate(ExecutionFrame* frame) const override { + if (!frame->value_stack().HasEnough(1)) { + return absl::InternalError("stack underflow in instrument step."); + } + + return instrumentation_(expr_id_, frame->value_stack().Peek()); + + return absl::OkStatus(); + } + + private: + int64_t expr_id_; + Instrumentation instrumentation_; +}; + +class InstrumentOptimizer : public ProgramOptimizer { + public: + explicit InstrumentOptimizer(Instrumentation instrumentation) + : instrumentation_(std::move(instrumentation)) {} + + absl::Status OnPreVisit(PlannerContext& context, + const cel::Expr& node) override { + return absl::OkStatus(); + } + + absl::Status OnPostVisit(PlannerContext& context, + const cel::Expr& node) override { + if (context.GetSubplan(node).empty()) { + return absl::OkStatus(); + } + + return context.AddSubplanStep( + node, std::make_unique(node.id(), instrumentation_)); + } + + private: + Instrumentation instrumentation_; +}; + +} // namespace + +ProgramOptimizerFactory CreateInstrumentationExtension( + InstrumentationFactory factory) { + return [fac = std::move(factory)](PlannerContext&, const cel::Ast& ast) + -> absl::StatusOr> { + Instrumentation ins = fac(ast); + if (ins) { + return std::make_unique(std::move(ins)); + } + return nullptr; + }; +} + +} // namespace google::api::expr::runtime diff --git a/eval/compiler/instrumentation.h b/eval/compiler/instrumentation.h new file mode 100644 index 000000000..9096830a0 --- /dev/null +++ b/eval/compiler/instrumentation.h @@ -0,0 +1,60 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Definitions for instrumenting a CEL expression at the planner level. +// +// CEL users should not use this directly. +#ifndef THIRD_PARTY_CEL_CPP_EVAL_COMPILER_INSTRUMENTATION_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_COMPILER_INSTRUMENTATION_H_ + +#include +#include + +#include "absl/functional/any_invocable.h" +#include "absl/status/status.h" +#include "common/ast.h" +#include "common/value.h" +#include "eval/compiler/flat_expr_builder_extensions.h" + +namespace google::api::expr::runtime { + +// Instrumentation inspects intermediate values after the evaluation of an +// expression node. +// +// Unlike traceable expressions, this callback is applied across all +// evaluations of an expression. Implementations must be thread safe if the +// expression is evaluated concurrently. +using Instrumentation = + std::function; + +// A factory for creating Instrumentation instances. +// +// This allows the extension implementations to map from a given ast to a +// specific instrumentation instance. +// +// An empty function object may be returned to skip instrumenting the given +// expression. +using InstrumentationFactory = + absl::AnyInvocable; + +// Create a new Instrumentation extension. +// +// These should typically be added last if any program optimizations are +// applied. +ProgramOptimizerFactory CreateInstrumentationExtension( + InstrumentationFactory factory); + +} // namespace google::api::expr::runtime + +#endif // THIRD_PARTY_CEL_CPP_EVAL_COMPILER_INSTRUMENTATION_H_ diff --git a/eval/compiler/instrumentation_test.cc b/eval/compiler/instrumentation_test.cc new file mode 100644 index 000000000..630f398d1 --- /dev/null +++ b/eval/compiler/instrumentation_test.cc @@ -0,0 +1,364 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/compiler/instrumentation.h" + +#include +#include +#include +#include + +#include "cel/expr/syntax.pb.h" +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "common/ast.h" +#include "common/value.h" +#include "eval/compiler/constant_folding.h" +#include "eval/compiler/flat_expr_builder.h" +#include "eval/compiler/regex_precompilation_optimization.h" +#include "eval/eval/evaluator_core.h" +#include "extensions/protobuf/ast_converters.h" +#include "internal/testing.h" +#include "parser/parser.h" +#include "runtime/activation.h" +#include "runtime/function_registry.h" +#include "runtime/internal/runtime_env.h" +#include "runtime/internal/runtime_env_testing.h" +#include "runtime/runtime_options.h" +#include "runtime/standard_functions.h" +#include "runtime/type_registry.h" +#include "google/protobuf/arena.h" + +namespace google::api::expr::runtime { +namespace { + +using ::cel::IntValue; +using ::cel::Value; +using ::cel::runtime_internal::NewTestingRuntimeEnv; +using ::cel::runtime_internal::RuntimeEnv; +using ::cel::expr::ParsedExpr; +using ::google::api::expr::parser::Parse; +using ::testing::ElementsAre; +using ::testing::Pair; +using ::testing::UnorderedElementsAre; + +class InstrumentationTest : public ::testing::Test { + public: + InstrumentationTest() + : env_(NewTestingRuntimeEnv()), + function_registry_(env_->function_registry), + type_registry_(env_->type_registry) {} + void SetUp() override { + ASSERT_OK(cel::RegisterStandardFunctions(function_registry_, options_)); + } + + protected: + absl_nonnull std::shared_ptr env_; + cel::RuntimeOptions options_; + cel::FunctionRegistry& function_registry_; + cel::TypeRegistry& type_registry_; + google::protobuf::Arena arena_; +}; + +MATCHER_P(IsIntValue, expected, "") { + const Value& got = arg; + + return got.Is() && got.GetInt().NativeValue() == expected; +} + +TEST_F(InstrumentationTest, Basic) { + FlatExprBuilder builder(env_, options_); + + std::vector expr_ids; + Instrumentation expr_id_recorder = + [&expr_ids](int64_t expr_id, const cel::Value&) -> absl::Status { + expr_ids.push_back(expr_id); + return absl::OkStatus(); + }; + + builder.AddProgramOptimizer(CreateInstrumentationExtension( + [=](const cel::Ast&) -> Instrumentation { return expr_id_recorder; })); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, Parse("1 + 2 + 3")); + ASSERT_OK_AND_ASSIGN(auto ast, + cel::extensions::CreateAstFromParsedExpr(expr)); + ASSERT_OK_AND_ASSIGN(auto plan, + builder.CreateExpressionImpl(std::move(ast), + /*issues=*/nullptr)); + + auto state = plan.MakeEvaluatorState(env_->descriptor_pool.get(), + env_->MutableMessageFactory(), &arena_); + cel::Activation activation; + + ASSERT_OK_AND_ASSIGN(auto value, plan.EvaluateWithCallback( + activation, /*embedder_context=*/nullptr, + EvaluationListener(), state)); + + // AST for the test expression: + // + <4> + // / \ + // +<2> 3<5> + // / \ + // 1<1> 2<3> + EXPECT_THAT(expr_ids, ElementsAre(1, 3, 2, 5, 4)); +} + +TEST_F(InstrumentationTest, BasicWithConstFolding) { + FlatExprBuilder builder(env_, options_); + + absl::flat_hash_map expr_id_to_value; + Instrumentation expr_id_recorder = [&expr_id_to_value]( + int64_t expr_id, + const cel::Value& v) -> absl::Status { + expr_id_to_value[expr_id] = v; + return absl::OkStatus(); + }; + builder.AddProgramOptimizer( + cel::runtime_internal::CreateConstantFoldingOptimizer()); + builder.AddProgramOptimizer(CreateInstrumentationExtension( + [=](const cel::Ast&) -> Instrumentation { return expr_id_recorder; })); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, Parse("1 + 2 + 3")); + ASSERT_OK_AND_ASSIGN(auto ast, + cel::extensions::CreateAstFromParsedExpr(expr)); + ASSERT_OK_AND_ASSIGN(auto plan, + builder.CreateExpressionImpl(std::move(ast), + /*issues=*/nullptr)); + + EXPECT_THAT( + expr_id_to_value, + UnorderedElementsAre(Pair(1, IsIntValue(1)), Pair(3, IsIntValue(2)), + Pair(2, IsIntValue(3)), Pair(5, IsIntValue(3)))); + expr_id_to_value.clear(); + + auto state = plan.MakeEvaluatorState(env_->descriptor_pool.get(), + env_->MutableMessageFactory(), &arena_); + cel::Activation activation; + + ASSERT_OK_AND_ASSIGN(auto value, plan.EvaluateWithCallback( + activation, /*embedder_context=*/nullptr, + EvaluationListener(), state)); + + // AST for the test expression: + // + <4> + // / \ + // +<2> 3<5> + // / \ + // 1<1> 2<3> + EXPECT_THAT(expr_id_to_value, UnorderedElementsAre(Pair(4, IsIntValue(6)))); +} + +TEST_F(InstrumentationTest, AndShortCircuit) { + FlatExprBuilder builder(env_, options_); + + std::vector expr_ids; + Instrumentation expr_id_recorder = + [&expr_ids](int64_t expr_id, const cel::Value&) -> absl::Status { + expr_ids.push_back(expr_id); + return absl::OkStatus(); + }; + + builder.AddProgramOptimizer(CreateInstrumentationExtension( + [=](const cel::Ast&) -> Instrumentation { return expr_id_recorder; })); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, Parse("a && b")); + ASSERT_OK_AND_ASSIGN(auto ast, + cel::extensions::CreateAstFromParsedExpr(expr)); + ASSERT_OK_AND_ASSIGN(auto plan, + builder.CreateExpressionImpl(std::move(ast), + /*issues=*/nullptr)); + + auto state = plan.MakeEvaluatorState(env_->descriptor_pool.get(), + env_->MutableMessageFactory(), &arena_); + cel::Activation activation; + + activation.InsertOrAssignValue("a", cel::BoolValue(true)); + activation.InsertOrAssignValue("b", cel::BoolValue(false)); + + ASSERT_OK_AND_ASSIGN(auto value, plan.EvaluateWithCallback( + activation, /*embedder_context=*/nullptr, + EvaluationListener(), state)); + + EXPECT_THAT(expr_ids, ElementsAre(1, 2, 3)); + + activation.InsertOrAssignValue("a", cel::BoolValue(false)); + + ASSERT_OK_AND_ASSIGN( + value, plan.EvaluateWithCallback(activation, /*embedder_context=*/nullptr, + EvaluationListener(), state)); + + EXPECT_THAT(expr_ids, ElementsAre(1, 2, 3, 1, 3)); +} + +TEST_F(InstrumentationTest, OrShortCircuit) { + FlatExprBuilder builder(env_, options_); + + std::vector expr_ids; + Instrumentation expr_id_recorder = + [&expr_ids](int64_t expr_id, const cel::Value&) -> absl::Status { + expr_ids.push_back(expr_id); + return absl::OkStatus(); + }; + + builder.AddProgramOptimizer(CreateInstrumentationExtension( + [=](const cel::Ast&) -> Instrumentation { return expr_id_recorder; })); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, Parse("a || b")); + ASSERT_OK_AND_ASSIGN(auto ast, + cel::extensions::CreateAstFromParsedExpr(expr)); + ASSERT_OK_AND_ASSIGN(auto plan, + builder.CreateExpressionImpl(std::move(ast), + /*issues=*/nullptr)); + + auto state = plan.MakeEvaluatorState(env_->descriptor_pool.get(), + env_->MutableMessageFactory(), &arena_); + cel::Activation activation; + + activation.InsertOrAssignValue("a", cel::BoolValue(false)); + activation.InsertOrAssignValue("b", cel::BoolValue(true)); + + ASSERT_OK_AND_ASSIGN(auto value, plan.EvaluateWithCallback( + activation, /*embedder_context=*/nullptr, + EvaluationListener(), state)); + + EXPECT_THAT(expr_ids, ElementsAre(1, 2, 3)); + expr_ids.clear(); + activation.InsertOrAssignValue("a", cel::BoolValue(true)); + + ASSERT_OK_AND_ASSIGN( + value, plan.EvaluateWithCallback(activation, /*embedder_context=*/nullptr, + EvaluationListener(), state)); + + EXPECT_THAT(expr_ids, ElementsAre(1, 3)); +} + +TEST_F(InstrumentationTest, Ternary) { + FlatExprBuilder builder(env_, options_); + + std::vector expr_ids; + Instrumentation expr_id_recorder = + [&expr_ids](int64_t expr_id, const cel::Value&) -> absl::Status { + expr_ids.push_back(expr_id); + return absl::OkStatus(); + }; + + builder.AddProgramOptimizer(CreateInstrumentationExtension( + [=](const cel::Ast&) -> Instrumentation { return expr_id_recorder; })); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, Parse("(c)? a : b")); + ASSERT_OK_AND_ASSIGN(auto ast, + cel::extensions::CreateAstFromParsedExpr(expr)); + ASSERT_OK_AND_ASSIGN(auto plan, + builder.CreateExpressionImpl(std::move(ast), + /*issues=*/nullptr)); + + auto state = plan.MakeEvaluatorState(env_->descriptor_pool.get(), + env_->MutableMessageFactory(), &arena_); + cel::Activation activation; + + activation.InsertOrAssignValue("c", cel::BoolValue(true)); + activation.InsertOrAssignValue("a", cel::IntValue(1)); + activation.InsertOrAssignValue("b", cel::IntValue(2)); + + ASSERT_OK_AND_ASSIGN(auto value, plan.EvaluateWithCallback( + activation, /*embedder_context=*/nullptr, + EvaluationListener(), state)); + + // AST + // ?:() <2> + // / | \ + // c <1> a <3> b <4> + EXPECT_THAT(expr_ids, ElementsAre(1, 3, 2)); + expr_ids.clear(); + + activation.InsertOrAssignValue("c", cel::BoolValue(false)); + + ASSERT_OK_AND_ASSIGN( + value, plan.EvaluateWithCallback(activation, /*embedder_context=*/nullptr, + EvaluationListener(), state)); + + EXPECT_THAT(expr_ids, ElementsAre(1, 4, 2)); + expr_ids.clear(); +} + +TEST_F(InstrumentationTest, OptimizedStepsNotEvaluated) { + FlatExprBuilder builder(env_, options_); + + builder.AddProgramOptimizer(CreateRegexPrecompilationExtension(0)); + + std::vector expr_ids; + Instrumentation expr_id_recorder = + [&expr_ids](int64_t expr_id, const cel::Value&) -> absl::Status { + expr_ids.push_back(expr_id); + return absl::OkStatus(); + }; + + builder.AddProgramOptimizer(CreateInstrumentationExtension( + [=](const cel::Ast&) -> Instrumentation { return expr_id_recorder; })); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, + Parse("r'test_string'.matches(r'[a-z_]+')")); + ASSERT_OK_AND_ASSIGN(auto ast, + cel::extensions::CreateAstFromParsedExpr(expr)); + ASSERT_OK_AND_ASSIGN(auto plan, + builder.CreateExpressionImpl(std::move(ast), + /*issues=*/nullptr)); + + auto state = plan.MakeEvaluatorState(env_->descriptor_pool.get(), + env_->MutableMessageFactory(), &arena_); + cel::Activation activation; + + ASSERT_OK_AND_ASSIGN(auto value, plan.EvaluateWithCallback( + activation, /*embedder_context=*/nullptr, + EvaluationListener(), state)); + + EXPECT_THAT(expr_ids, ElementsAre(1, 2)); + EXPECT_TRUE(value.Is() && value.GetBool().NativeValue()); +} + +TEST_F(InstrumentationTest, NoopSkipped) { + FlatExprBuilder builder(env_, options_); + + builder.AddProgramOptimizer(CreateInstrumentationExtension( + [=](const cel::Ast&) -> Instrumentation { return Instrumentation(); })); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, Parse("(c)? a : b")); + ASSERT_OK_AND_ASSIGN(auto ast, + cel::extensions::CreateAstFromParsedExpr(expr)); + ASSERT_OK_AND_ASSIGN(auto plan, + builder.CreateExpressionImpl(std::move(ast), + /*issues=*/nullptr)); + + auto state = plan.MakeEvaluatorState(env_->descriptor_pool.get(), + env_->MutableMessageFactory(), &arena_); + cel::Activation activation; + + activation.InsertOrAssignValue("c", cel::BoolValue(true)); + activation.InsertOrAssignValue("a", cel::IntValue(1)); + activation.InsertOrAssignValue("b", cel::IntValue(2)); + + ASSERT_OK_AND_ASSIGN(auto value, plan.EvaluateWithCallback( + activation, /*embedder_context=*/nullptr, + EvaluationListener(), state)); + + // AST + // ?:() <2> + // / | \ + // c <1> a <3> b <4> + EXPECT_THAT(value, IsIntValue(1)); +} + +} // namespace +} // namespace google::api::expr::runtime diff --git a/eval/compiler/qualified_reference_resolver.cc b/eval/compiler/qualified_reference_resolver.cc index da137df0d..158e492be 100644 --- a/eval/compiler/qualified_reference_resolver.cc +++ b/eval/compiler/qualified_reference_resolver.cc @@ -1,209 +1,359 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #include "eval/compiler/qualified_reference_resolver.h" +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "absl/types/optional.h" -#include "eval/eval/const_value_step.h" -#include "eval/eval/expression_build_warning.h" -#include "eval/public/cel_builtins.h" -#include "eval/public/cel_function_registry.h" -#include "base/status_macros.h" +#include "base/ast.h" +#include "base/builtins.h" +#include "common/ast.h" +#include "common/ast_rewrite.h" +#include "common/expr.h" +#include "common/kind.h" +#include "eval/compiler/flat_expr_builder_extensions.h" +#include "eval/compiler/resolver.h" +#include "runtime/internal/issue_collector.h" +#include "runtime/runtime_issue.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace google::api::expr::runtime { namespace { -using google::api::expr::v1alpha1::Expr; -using google::api::expr::v1alpha1::Reference; +using ::cel::Expr; +using ::cel::Reference; +using ::cel::RuntimeIssue; +using ::cel::runtime_internal::IssueCollector; + +// Optional types are opt-in but require special handling in the evaluator. +constexpr absl::string_view kOptionalOr = "or"; +constexpr absl::string_view kOptionalOrValue = "orValue"; + +// Determines if function is implemented with custom evaluation step instead of +// registered. +bool IsSpecialFunction(absl::string_view function_name) { + return function_name == cel::builtin::kAnd || + function_name == cel::builtin::kOr || + function_name == cel::builtin::kIndex || + function_name == cel::builtin::kTernary || + function_name == kOptionalOr || function_name == kOptionalOrValue || + function_name == cel::builtin::kEqual || + function_name == cel::builtin::kInequal || + function_name == cel::builtin::kNot || + function_name == cel::builtin::kNotStrictlyFalse || + function_name == cel::builtin::kNotStrictlyFalseDeprecated || + function_name == cel::builtin::kIn || + function_name == cel::builtin::kInDeprecated || + function_name == cel::builtin::kInFunction || + function_name == "cel.@block"; +} + +bool OverloadExists(const Resolver& resolver, absl::string_view name, + const std::vector& arguments_matcher, + bool receiver_style = false) { + return !resolver.FindOverloads(name, receiver_style, arguments_matcher) + .empty() || + !resolver.FindLazyOverloads(name, receiver_style, arguments_matcher) + .empty(); +} + +// Return the qualified name of the most qualified matching overload, or +// nullopt if no matches are found. +std::optional BestOverloadMatch(const Resolver& resolver, + absl::string_view base_name, + int argument_count) { + if (IsSpecialFunction(base_name)) { + return std::string(base_name); + } + auto arguments_matcher = ArgumentsMatcher(argument_count); + // Check from most qualified to least qualified for a matching overload. + auto names = resolver.FullyQualifiedNames(base_name); + for (auto name = names.begin(); name != names.end(); ++name) { + if (OverloadExists(resolver, *name, arguments_matcher)) { + if (base_name[0] == '.') { + // Preserve leading '.' to prevent re-resolving at plan time. + return std::string(base_name); + } + return *name; + } + } + return std::nullopt; +} -class ReferenceResolver { +// Rewriter visitor for resolving references. +// +// On previsit pass, replace (possibly qualified) identifier branches with the +// canonical name in the reference map (most qualified references considered +// first). +// +// On post visit pass, update function calls to determine whether the function +// target is a namespace for the function or a receiver for the call. +class ReferenceResolver : public cel::AstRewriterBase { public: - ReferenceResolver(const google::protobuf::Map& reference_map, - BuilderWarnings* warnings) - : reference_map_(reference_map), warnings_(warnings) {} + ReferenceResolver( + const absl::flat_hash_map& reference_map, + const Resolver& resolver, IssueCollector& issue_collector) + : reference_map_(reference_map), + resolver_(resolver), + issues_(issue_collector), + progress_status_(absl::OkStatus()) {} // Attempt to resolve references in expr. Return true if part of the // expression was rewritten. - absl::StatusOr Rewrite(Expr* out) { - absl::StatusOr maybe_rewrite_result = MaybeRewriteReferences(out); - RETURN_IF_ERROR(maybe_rewrite_result.status()); + // TODO(issues/95): If possible, it would be nice to write a general utility + // for running the preprocess steps when traversing the AST instead of having + // one pass per transform. + bool PreVisitRewrite(Expr& expr) override { + const Reference* reference = GetReferenceForId(expr.id()); - if (maybe_rewrite_result.value()) { - return true; - } - // If we don't have a rewrite rule, continue traversing the AST. - switch (out->expr_kind_case()) { - case Expr::kConstExpr: { + // Fold compile time constant (e.g. enum values) + if (reference != nullptr && reference->has_value()) { + if (reference->value().has_int64_value()) { + // Replace enum idents with const reference value. + expr.mutable_const_expr().set_int64_value( + reference->value().int64_value()); + return true; + } else if (expr.has_ident_expr()) { + // "google.protobuf.NullValue.NULL_VALUE" is a special case: sometimes + // it is interpreted as null value and sometimes as an enum constant. + if (reference->value().has_null_value() && + expr.ident_expr().name() == + "google.protobuf.NullValue.NULL_VALUE") { + return false; + } + expr.set_const_expr(reference->value()); + return true; + } else { return false; } - case Expr::kIdentExpr: + } + + if (reference != nullptr) { + if (expr.has_ident_expr()) { + return MaybeUpdateIdentNode(&expr, *reference); + } else if (expr.has_select_expr()) { + return MaybeUpdateSelectNode(&expr, *reference); + } else { + // Call nodes are updated on post visit so they will see any select + // path rewrites. return false; - case Expr::kSelectExpr: { - return Rewrite(out->mutable_select_expr()->mutable_operand()); - } - case Expr::kCallExpr: { - auto* call_expr = out->mutable_call_expr(); - const bool receiver_style = call_expr->has_target(); - const int arg_num = call_expr->args_size(); - bool args_updated = false; - if (receiver_style) { - absl::StatusOr rewrite_result = - Rewrite(call_expr->mutable_target()); - RETURN_IF_ERROR(rewrite_result.status()); - args_updated = args_updated || rewrite_result.value(); - } - for (int i = 0; i < arg_num; i++) { - absl::StatusOr rewrite_result = - Rewrite(call_expr->mutable_args(i)); - RETURN_IF_ERROR(rewrite_result.status()); - args_updated = args_updated || rewrite_result.value(); - } - return args_updated; - } - case Expr::kListExpr: { - auto* list_expr = out->mutable_list_expr(); - int list_size = list_expr->elements_size(); - bool args_updated = false; - for (int i = 0; i < list_size; i++) { - absl::StatusOr rewrite_result = - Rewrite(list_expr->mutable_elements(i)); - RETURN_IF_ERROR(rewrite_result.status()); - args_updated = args_updated || rewrite_result.value(); - } - return args_updated; - } - case Expr::kStructExpr: { - auto* struct_expr = out->mutable_struct_expr(); - int entries_size = struct_expr->entries_size(); - bool args_updated = false; - for (int i = 0; i < entries_size; i++) { - auto* new_entry = struct_expr->mutable_entries(i); - switch (new_entry->key_kind_case()) { - case Expr::CreateStruct::Entry::kFieldKey: - // Nothing to do. - break; - case Expr::CreateStruct::Entry::kMapKey: - args_updated = - Rewrite(new_entry->mutable_map_key()).ok() || args_updated; - break; - default: - GOOGLE_LOG(ERROR) << "Unsupported Entry kind: " - << new_entry->key_kind_case(); - break; - } - args_updated = - Rewrite(new_entry->mutable_value()).ok() || args_updated; - } - return args_updated; } - case Expr::kComprehensionExpr: { - auto* out_expr = out->mutable_comprehension_expr(); - bool args_updated = false; - absl::StatusOr rewrite_result; + } + return false; + } - rewrite_result = Rewrite(out_expr->mutable_accu_init()); - RETURN_IF_ERROR(rewrite_result.status()); - args_updated = args_updated || rewrite_result.value(); + bool PostVisitRewrite(Expr& expr) override { + const Reference* reference = GetReferenceForId(expr.id()); + if (expr.has_call_expr()) { + return MaybeUpdateCallNode(&expr, reference); + } + return false; + } - rewrite_result = Rewrite(out_expr->mutable_iter_range()); - RETURN_IF_ERROR(rewrite_result.status()); - args_updated = args_updated || rewrite_result.value(); + const absl::Status& GetProgressStatus() const { return progress_status_; } - rewrite_result = Rewrite(out_expr->mutable_loop_condition()); - RETURN_IF_ERROR(rewrite_result.status()); - args_updated = args_updated || rewrite_result.value(); + private: + // Attempt to update a function call node. This disambiguates + // receiver call verses namespaced names in parse if possible. + // + // TODO(issues/95): This duplicates some of the overload matching behavior + // for parsed expressions. We should refactor to consolidate the code. + bool MaybeUpdateCallNode(Expr* out, const Reference* reference) { + auto& call_expr = out->mutable_call_expr(); + const std::string& function = call_expr.function(); + if (reference != nullptr && reference->overload_id().empty()) { + UpdateStatus(issues_.AddIssue( + RuntimeIssue::CreateWarning(absl::InvalidArgumentError( + absl::StrCat("Reference map doesn't provide overloads for ", + out->call_expr().function()))))); + } + bool receiver_style = call_expr.has_target(); + int arg_num = call_expr.args().size(); + if (receiver_style) { + auto maybe_namespace = ToNamespace(call_expr.target()); + if (maybe_namespace.has_value()) { + std::string resolved_name = + absl::StrCat(*maybe_namespace, ".", function); + auto resolved_function = + BestOverloadMatch(resolver_, resolved_name, arg_num); + if (resolved_function.has_value()) { + call_expr.set_function(*resolved_function); + call_expr.set_target(nullptr); + return true; + } + } + } else { + // Not a receiver style function call. Check to see if it is a namespaced + // function using a shorthand inside the expression container. + auto maybe_resolved_function = + BestOverloadMatch(resolver_, function, arg_num); + if (!maybe_resolved_function.has_value()) { + UpdateStatus(issues_.AddIssue(RuntimeIssue::CreateWarning( + absl::InvalidArgumentError(absl::StrCat( + "No overload found in reference resolve step for ", function)), + RuntimeIssue::ErrorCode::kNoMatchingOverload))); + } else if (maybe_resolved_function.value() != function) { + call_expr.set_function(maybe_resolved_function.value()); + return true; + } + } + // For parity, if we didn't rewrite the receiver call style function, + // check that an overload is provided in the builder. + if (call_expr.has_target() && !IsSpecialFunction(function) && + !OverloadExists(resolver_, function, ArgumentsMatcher(arg_num + 1), + /* receiver_style= */ true)) { + UpdateStatus(issues_.AddIssue(RuntimeIssue::CreateWarning( + absl::InvalidArgumentError(absl::StrCat( + "No overload found in reference resolve step for ", function)), + RuntimeIssue::ErrorCode::kNoMatchingOverload))); + } + return false; + } - rewrite_result = Rewrite(out_expr->mutable_loop_step()); - RETURN_IF_ERROR(rewrite_result.status()); - args_updated = args_updated || rewrite_result.value(); + // Attempt to resolve a select node. If reference is valid, + // replace the select node with the fully qualified ident node. + bool MaybeUpdateSelectNode(Expr* out, const Reference& reference) { + if (out->select_expr().test_only()) { + UpdateStatus(issues_.AddIssue(RuntimeIssue::CreateWarning( + absl::InvalidArgumentError("Reference map points to a presence " + "test -- has(container.attr)")))); + } else if (!reference.name().empty()) { + out->mutable_ident_expr().set_name(reference.name()); + rewritten_reference_.insert(out->id()); + return true; + } + return false; + } - rewrite_result = Rewrite(out_expr->mutable_result()); - RETURN_IF_ERROR(rewrite_result.status()); - args_updated = args_updated || rewrite_result.value(); + // Attempt to resolve an ident node. If reference is valid, + // replace the node with the fully qualified ident node. + bool MaybeUpdateIdentNode(Expr* out, const Reference& reference) { + if (!reference.name().empty() && + reference.name() != out->ident_expr().name()) { + out->mutable_ident_expr().set_name(reference.name()); + rewritten_reference_.insert(out->id()); + return true; + } + return false; + } - return args_updated; + // Convert a select expr sub tree into a namespace name if possible. + // If any operand of the top element is a not a select or an ident node, + // return nullopt. + std::optional ToNamespace(const Expr& expr) { + std::optional maybe_parent_namespace; + if (rewritten_reference_.find(expr.id()) != rewritten_reference_.end()) { + // The target expr matches a reference (resolved to an ident decl). + // This should not be treated as a function qualifier. + return std::nullopt; + } + if (expr.has_ident_expr()) { + return expr.ident_expr().name(); + } else if (expr.has_select_expr()) { + if (expr.select_expr().test_only()) { + return std::nullopt; } - default: - GOOGLE_LOG(ERROR) << "Unsupported Expr kind: " << out->expr_kind_case(); - return false; + maybe_parent_namespace = ToNamespace(expr.select_expr().operand()); + if (!maybe_parent_namespace.has_value()) { + return std::nullopt; + } + return absl::StrCat(*maybe_parent_namespace, ".", + expr.select_expr().field()); + } else { + return std::nullopt; } } - private: - // Attempts to apply rewrites for reference map. Returns true if rewrites - // occur. - absl::StatusOr MaybeRewriteReferences(Expr* expr) { - const auto iter = reference_map_.find(expr->id()); + // Find a reference for the given expr id. + // + // Returns nullptr if no reference is available. + const Reference* GetReferenceForId(int64_t expr_id) { + auto iter = reference_map_.find(expr_id); if (iter == reference_map_.end()) { - return false; + return nullptr; } - const Reference& reference = iter->second; - if (reference.has_value() || !reference.overload_id().empty()) { - // TODO(issues/71): Add support for functions and compile time - // constants. - return false; + if (expr_id == 0) { + UpdateStatus(issues_.AddIssue( + RuntimeIssue::CreateWarning(absl::InvalidArgumentError( + "reference map entries for expression id 0 are not supported")))); + return nullptr; } + return &iter->second; + } - switch (expr->expr_kind_case()) { - case Expr::ExprKindCase::kIdentExpr: - if (reference.name() != expr->ident_expr().name()) { - // Possibly shorthand for a namespaced name. - expr->clear_ident_expr(); - expr->mutable_ident_expr()->set_name(reference.name()); - return true; - } else { - return false; - } - case Expr::ExprKindCase::kStructExpr: - // reference to a create struct message type. nothing to do. - // TODO(issues/72): annotating the execution plan with this may help - // identify problems with the environment setup. This will probably - // also require the type map information from a checked expression. - return false; - case Expr::ExprKindCase::kSelectExpr: - if (expr->select_expr().test_only()) { - RETURN_IF_ERROR(warnings_->AddWarning( - absl::InvalidArgumentError("Reference map points to a presence " - "test -- has(container.attr)"))); - return false; - } - expr->clear_select_expr(); - expr->mutable_ident_expr()->set_name(reference.name()); - return true; - default: - RETURN_IF_ERROR( - warnings_->AddWarning(absl::InvalidArgumentError(absl::StrCat( - "Unsupported reference kind: ", expr->expr_kind_case())))); - return false; + void UpdateStatus(absl::Status status) { + if (progress_status_.ok() && !status.ok()) { + progress_status_ = std::move(status); + return; } + status.IgnoreError(); } - const google::protobuf::Map& reference_map_; - BuilderWarnings* warnings_; + const absl::flat_hash_map& reference_map_; + const Resolver& resolver_; + IssueCollector& issues_; + absl::Status progress_status_; + absl::flat_hash_set rewritten_reference_; +}; + +class ReferenceResolverExtension : public AstTransform { + public: + explicit ReferenceResolverExtension(ReferenceResolverOption opt) + : opt_(opt) {} + absl::Status UpdateAst(PlannerContext& context, + cel::Ast& ast) const override { + if (opt_ == ReferenceResolverOption::kCheckedOnly && + ast.reference_map().empty()) { + return absl::OkStatus(); + } + return ResolveReferences(context.resolver(), context.issue_collector(), ast) + .status(); + } + + private: + ReferenceResolverOption opt_; }; } // namespace -absl::StatusOr> ResolveReferences( - const Expr& expr, const google::protobuf::Map& reference_map, - BuilderWarnings* warnings) { - Expr out(expr); - ReferenceResolver resolver(reference_map, warnings); - absl::StatusOr rewrite_result = resolver.Rewrite(&out); - if (!rewrite_result.ok()) { - return rewrite_result.status(); - } else if (rewrite_result.value()) { - return absl::optional(out); - } else { - return absl::optional(); +absl::StatusOr ResolveReferences(const Resolver& resolver, + IssueCollector& issues, cel::Ast& ast) { + ReferenceResolver ref_resolver(ast.reference_map(), resolver, issues); + + // Rewriting interface doesn't support failing mid traverse propagate first + // error encountered if fail fast enabled. + bool was_rewritten = cel::AstRewrite(ast.mutable_root_expr(), ref_resolver); + if (!ref_resolver.GetProgressStatus().ok()) { + return ref_resolver.GetProgressStatus(); } + return was_rewritten; +} + +std::unique_ptr NewReferenceResolverExtension( + ReferenceResolverOption option) { + return std::make_unique(option); } -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime diff --git a/eval/compiler/qualified_reference_resolver.h b/eval/compiler/qualified_reference_resolver.h index 98e1e630b..673273084 100644 --- a/eval/compiler/qualified_reference_resolver.h +++ b/eval/compiler/qualified_reference_resolver.h @@ -1,29 +1,53 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #ifndef THIRD_PARTY_CEL_CPP_EVAL_COMPILER_QUALIFIED_REFERENCE_RESOLVER_H_ #define THIRD_PARTY_CEL_CPP_EVAL_COMPILER_QUALIFIED_REFERENCE_RESOLVER_H_ -#include "google/api/expr/v1alpha1/checked.pb.h" -#include "google/api/expr/v1alpha1/syntax.pb.h" -#include "google/protobuf/map.h" -#include "absl/status/status.h" +#include + #include "absl/status/statusor.h" -#include "eval/eval/expression_build_warning.h" - -namespace google { -namespace api { -namespace expr { -namespace runtime { - -// A transformation over input expression that produces a new expression with -// select subexpressions replaced by idents referring to the fully-qualified -// variable name. Returns modified expr if updates found. Otherwise, returns -// nullopt. -absl::StatusOr> ResolveReferences( - const google::api::expr::v1alpha1::Expr& expr, - const google::protobuf::Map& reference_map, - BuilderWarnings* warnings); - -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +#include "common/ast.h" +#include "eval/compiler/flat_expr_builder_extensions.h" +#include "eval/compiler/resolver.h" +#include "runtime/internal/issue_collector.h" + +namespace google::api::expr::runtime { + +// Resolves possibly qualified names in the provided expression, updating +// subexpressions with to use the fully qualified name, or a constant +// expressions in the case of enums. +// +// Returns true if updates were applied. +// +// Will warn or return a non-ok status if references can't be resolved (no +// function overload could match a call) or are inconsistent (reference map +// points to an expr node that isn't a reference). +absl::StatusOr ResolveReferences( + const Resolver& resolver, cel::runtime_internal::IssueCollector& issues, + cel::Ast& ast); + +enum class ReferenceResolverOption { + // Always attempt to resolve references based on runtime types and functions. + kAlways, + // Only attempt to resolve for checked expressions with reference metadata. + kCheckedOnly, +}; + +std::unique_ptr NewReferenceResolverExtension( + ReferenceResolverOption option); + +} // namespace google::api::expr::runtime + #endif // THIRD_PARTY_CEL_CPP_EVAL_COMPILER_QUALIFIED_REFERENCE_RESOLVER_H_ diff --git a/eval/compiler/qualified_reference_resolver_test.cc b/eval/compiler/qualified_reference_resolver_test.cc index 19a96e38a..3fa7fca21 100644 --- a/eval/compiler/qualified_reference_resolver_test.cc +++ b/eval/compiler/qualified_reference_resolver_test.cc @@ -1,25 +1,65 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #include "eval/compiler/qualified_reference_resolver.h" -#include "google/protobuf/text_format.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" +#include +#include +#include + +#include "cel/expr/syntax.pb.h" +#include "absl/container/flat_hash_map.h" +#include "absl/log/absl_check.h" #include "absl/status/status.h" -#include "absl/types/optional.h" -#include "testutil/util.h" -#include "base/status_macros.h" - -namespace google { -namespace api { -namespace expr { -namespace runtime { +#include "absl/strings/str_cat.h" +#include "base/ast.h" +#include "base/builtins.h" +#include "common/ast.h" +#include "common/ast/expr_proto.h" +#include "common/expr.h" +#include "eval/compiler/resolver.h" +#include "eval/public/builtin_func_registrar.h" +#include "eval/public/cel_function.h" +#include "eval/public/cel_function_registry.h" +#include "eval/public/cel_value.h" +#include "extensions/protobuf/ast_converters.h" +#include "internal/proto_matchers.h" +#include "internal/testing.h" +#include "runtime/internal/issue_collector.h" +#include "runtime/runtime_issue.h" +#include "runtime/type_registry.h" +#include "google/protobuf/text_format.h" + +namespace google::api::expr::runtime { + namespace { -using google::api::expr::v1alpha1::Expr; -using google::api::expr::v1alpha1::Reference; -using testing::ElementsAre; -using testing::Eq; -using testing::Optional; -using testutil::EqualsProto; +using ::absl_testing::IsOk; +using ::absl_testing::IsOkAndHolds; +using ::absl_testing::StatusIs; +using ::cel::Ast; +using ::cel::Expr; +using ::cel::RuntimeIssue; +using ::cel::SourceInfo; +using ::cel::ast_internal::ExprToProto; +using ::cel::internal::test::EqualsProto; +using ::cel::runtime_internal::IssueCollector; +using ::testing::Contains; +using ::testing::ElementsAre; +using ::testing::Eq; +using ::testing::IsEmpty; +using ::testing::UnorderedElementsAre; // foo.bar.var1 && bar.foo.var2 constexpr char kExpr[] = R"( @@ -61,21 +101,45 @@ constexpr char kExpr[] = R"( } )"; -Expr ParseTestProto(const std::string& pb) { - Expr expr; +MATCHER_P(StatusCodeIs, x, "") { + const absl::Status& status = arg; + return status.code() == x; +} + +std::unique_ptr ParseTestProto(const std::string& pb) { + cel::expr::Expr expr; EXPECT_TRUE(google::protobuf::TextFormat::ParseFromString(pb, &expr)); - return expr; + return cel::extensions::CreateAstFromParsedExpr(expr).value(); +} + +std::vector ExtractIssuesStatus(const IssueCollector& issues) { + std::vector issues_status; + for (const auto& issue : issues.issues()) { + issues_status.push_back(issue.ToStatus()); + } + return issues_status; +} + +cel::expr::Expr ExprToProtoOrDie(const Expr& expr) { + cel::expr::Expr expr_proto; + ABSL_CHECK_OK(ExprToProto(expr, &expr_proto)); + return expr_proto; } TEST(ResolveReferences, Basic) { - Expr expr = ParseTestProto(kExpr); - google::protobuf::Map reference_map; - reference_map[2].set_name("foo.bar.var1"); - reference_map[5].set_name("bar.foo.var2"); - BuilderWarnings warnings; - auto result = ResolveReferences(expr, reference_map, &warnings); - ASSERT_OK(result); - EXPECT_THAT(result.value(), Optional(EqualsProto(R"( + std::unique_ptr expr_ast = ParseTestProto(kExpr); + expr_ast->mutable_reference_map()[2].set_name("foo.bar.var1"); + expr_ast->mutable_reference_map()[5].set_name("bar.foo.var2"); + IssueCollector issues(RuntimeIssue::Severity::kError); + CelFunctionRegistry func_registry; + cel::TypeRegistry type_registry; + Resolver registry("", func_registry.InternalGetRegistry(), type_registry, + type_registry.GetComposedTypeProvider()); + + auto result = ResolveReferences(registry, issues, *expr_ast); + ASSERT_THAT(result, IsOkAndHolds(true)); + + EXPECT_THAT(ExprToProtoOrDie(expr_ast->root_expr()), EqualsProto(R"pb( id: 1 call_expr { function: "_&&_" @@ -87,27 +151,43 @@ TEST(ResolveReferences, Basic) { id: 5 ident_expr { name: "bar.foo.var2" } } - })"))); + })pb")); } -TEST(ResolveReferences, ReturnsNulloptIfNoChanges) { - Expr expr = ParseTestProto(kExpr); - google::protobuf::Map reference_map; - BuilderWarnings warnings; - auto result = ResolveReferences(expr, reference_map, &warnings); - ASSERT_OK(result); - EXPECT_THAT(result.value(), Eq(absl::nullopt)); +TEST(ResolveReferences, ReturnsFalseIfNoChanges) { + std::unique_ptr expr_ast = ParseTestProto(kExpr); + IssueCollector issues(RuntimeIssue::Severity::kError); + CelFunctionRegistry func_registry; + cel::TypeRegistry type_registry; + Resolver registry("", func_registry.InternalGetRegistry(), type_registry, + type_registry.GetComposedTypeProvider()); + + auto result = ResolveReferences(registry, issues, *expr_ast); + ASSERT_THAT(result, IsOkAndHolds(false)); + + // reference to the same name also doesn't count as a rewrite. + expr_ast->mutable_reference_map()[4].set_name("foo"); + expr_ast->mutable_reference_map()[7].set_name("bar"); + + result = ResolveReferences(registry, issues, *expr_ast); + ASSERT_THAT(result, IsOkAndHolds(false)); } TEST(ResolveReferences, NamespacedIdent) { - Expr expr = ParseTestProto(kExpr); - google::protobuf::Map reference_map; - BuilderWarnings warnings; - reference_map[2].set_name("foo.bar.var1"); - reference_map[7].set_name("namespace_x.bar"); - auto result = ResolveReferences(expr, reference_map, &warnings); - ASSERT_OK(result); - EXPECT_THAT(result.value(), Optional(EqualsProto(R"( + std::unique_ptr expr_ast = ParseTestProto(kExpr); + SourceInfo source_info; + IssueCollector issues(RuntimeIssue::Severity::kError); + CelFunctionRegistry func_registry; + cel::TypeRegistry type_registry; + Resolver registry("", func_registry.InternalGetRegistry(), type_registry, + type_registry.GetComposedTypeProvider()); + expr_ast->mutable_reference_map()[2].set_name("foo.bar.var1"); + expr_ast->mutable_reference_map()[7].set_name("namespace_x.bar"); + + auto result = ResolveReferences(registry, issues, *expr_ast); + ASSERT_THAT(result, IsOkAndHolds(true)); + + EXPECT_THAT(ExprToProtoOrDie(expr_ast->root_expr()), EqualsProto(R"pb( id: 1 call_expr { function: "_&&_" @@ -131,65 +211,11 @@ TEST(ResolveReferences, NamespacedIdent) { } } } - })"))); -} - -TEST(ResolveReferences, WarningOnUnsupportedExprKind) { - Expr expr = ParseTestProto( - R"( - id: 1 - comprehension_expr { - accu_init { - id: 2 - const_expr { int64_value: 1 } - } - accu_var: "__result__" - iter_var: "x" - iter_range { - id: 3 - list_expr { - elements { - id: 4 - const_expr { int64_value: 1 } - } - } - } - result { - id: 6 - ident_expr { name: "__result__" } - } - loop_condition { - id: 5 - const_expr { bool_value: true } - } - loop_step { - id: 7 - call_expr { - function: "_+_" - args { - id: 8 - ident_expr { name: "x" } - } - args { - id: 9 - ident_expr { name: "__result__" } - } - } - } - })"); - google::protobuf::Map reference_map; - BuilderWarnings warnings; - reference_map[1].set_name("foo"); - auto result = ResolveReferences(expr, reference_map, &warnings); - ASSERT_OK(result); - EXPECT_THAT(result.value(), Eq(absl::nullopt)); - EXPECT_THAT(warnings.warnings(), - ElementsAre(Eq(absl::Status(absl::StatusCode::kInvalidArgument, - "Unsupported reference kind: 9")))); + })pb")); } TEST(ResolveReferences, WarningOnPresenceTest) { - Expr expr = ParseTestProto(R"( + std::unique_ptr expr_ast = ParseTestProto(R"pb( id: 1 select_expr { field: "var1" @@ -204,30 +230,193 @@ TEST(ResolveReferences, WarningOnPresenceTest) { } } } - })"); - google::protobuf::Map reference_map; - BuilderWarnings warnings; - reference_map[1].set_name("foo.bar.var1"); - auto result = ResolveReferences(expr, reference_map, &warnings); - ASSERT_OK(result); - EXPECT_THAT(result.value(), Eq(absl::nullopt)); + })pb"); + SourceInfo source_info; + + IssueCollector issues(RuntimeIssue::Severity::kError); + CelFunctionRegistry func_registry; + cel::TypeRegistry type_registry; + Resolver registry("", func_registry.InternalGetRegistry(), type_registry, + type_registry.GetComposedTypeProvider()); + expr_ast->mutable_reference_map()[1].set_name("foo.bar.var1"); + + auto result = ResolveReferences(registry, issues, *expr_ast); + + ASSERT_THAT(result, IsOkAndHolds(false)); EXPECT_THAT( - warnings.warnings(), + ExtractIssuesStatus(issues), testing::ElementsAre(Eq(absl::Status( absl::StatusCode::kInvalidArgument, "Reference map points to a presence test -- has(container.attr)")))); } +// foo.bar.var1 == bar.foo.Enum.ENUM_VAL1 +constexpr char kEnumExpr[] = R"( + id: 1 + call_expr { + function: "_==_" + args { + id: 2 + select_expr { + field: "var1" + operand { + id: 3 + select_expr { + field: "bar" + operand { + id: 4 + ident_expr { name: "foo" } + } + } + } + } + } + args { + id: 5 + ident_expr { name: "bar.foo.Enum.ENUM_VAL1" } + } + } +)"; + +TEST(ResolveReferences, EnumConstReferenceUsed) { + std::unique_ptr expr_ast = ParseTestProto(kEnumExpr); + SourceInfo source_info; + + CelFunctionRegistry func_registry; + ASSERT_OK(RegisterBuiltinFunctions(&func_registry)); + cel::TypeRegistry type_registry; + Resolver registry("", func_registry.InternalGetRegistry(), type_registry, + type_registry.GetComposedTypeProvider()); + expr_ast->mutable_reference_map()[2].set_name("foo.bar.var1"); + expr_ast->mutable_reference_map()[5].set_name("bar.foo.Enum.ENUM_VAL1"); + expr_ast->mutable_reference_map()[5].mutable_value().set_int64_value(9); + IssueCollector issues(RuntimeIssue::Severity::kError); + + auto result = ResolveReferences(registry, issues, *expr_ast); + + ASSERT_THAT(result, IsOkAndHolds(true)); + + EXPECT_THAT(ExprToProtoOrDie(expr_ast->root_expr()), EqualsProto(R"pb( + id: 1 + call_expr { + function: "_==_" + args { + id: 2 + ident_expr { name: "foo.bar.var1" } + } + args { + id: 5 + const_expr { int64_value: 9 } + } + })pb")); +} + +TEST(ResolveReferences, EnumConstReferenceUsedSelect) { + std::unique_ptr expr_ast = ParseTestProto(kEnumExpr); + SourceInfo source_info; + + CelFunctionRegistry func_registry; + ASSERT_OK(RegisterBuiltinFunctions(&func_registry)); + cel::TypeRegistry type_registry; + Resolver registry("", func_registry.InternalGetRegistry(), type_registry, + type_registry.GetComposedTypeProvider()); + expr_ast->mutable_reference_map()[2].set_name("foo.bar.var1"); + expr_ast->mutable_reference_map()[2].mutable_value().set_int64_value(2); + expr_ast->mutable_reference_map()[5].set_name("bar.foo.Enum.ENUM_VAL1"); + expr_ast->mutable_reference_map()[5].mutable_value().set_int64_value(9); + IssueCollector issues(RuntimeIssue::Severity::kError); + + auto result = ResolveReferences(registry, issues, *expr_ast); + + ASSERT_THAT(result, IsOkAndHolds(true)); + EXPECT_THAT(ExprToProtoOrDie(expr_ast->root_expr()), EqualsProto(R"pb( + id: 1 + call_expr { + function: "_==_" + args { + id: 2 + const_expr { int64_value: 2 } + } + args { + id: 5 + const_expr { int64_value: 9 } + } + })pb")); +} + +// foo && bar +constexpr char kConstReferenceExpr[] = R"( + id: 1 + call_expr { + function: "_&&_" + args { + id: 2 + ident_expr { + name: "foo" + } + } + args { + id: 5 + ident_expr { + name: "bar" + } + } + } +)"; + +TEST(ResolveReferences, ConstReferenceFolded) { + std::unique_ptr expr_ast = ParseTestProto(kConstReferenceExpr); + SourceInfo source_info; + + CelFunctionRegistry func_registry; + ASSERT_THAT(RegisterBuiltinFunctions(&func_registry), IsOk()); + cel::TypeRegistry type_registry; + Resolver registry("", func_registry.InternalGetRegistry(), type_registry, + type_registry.GetComposedTypeProvider()); + expr_ast->mutable_reference_map()[2].set_name("foo"); + expr_ast->mutable_reference_map()[2].mutable_value().set_bool_value(true); + expr_ast->mutable_reference_map()[5].set_name("bar"); + expr_ast->mutable_reference_map()[5].mutable_value().set_bool_value(false); + IssueCollector issues(RuntimeIssue::Severity::kError); + + auto result = ResolveReferences(registry, issues, *expr_ast); + + ASSERT_THAT(result, IsOkAndHolds(true)); + + EXPECT_THAT(ExprToProtoOrDie(expr_ast->root_expr()), EqualsProto(R"pb( + id: 1 + call_expr { + function: "_&&_" + args { + id: 2 + const_expr { bool_value: true } + } + args { + id: 5 + const_expr { bool_value: false } + } + })pb")); +} + TEST(ResolveReferences, ConstReferenceSkipped) { - Expr expr = ParseTestProto(kExpr); - google::protobuf::Map reference_map; - reference_map[2].set_name("foo.bar.var1"); - reference_map[2].mutable_value()->set_bool_value(true); - reference_map[5].set_name("bar.foo.var2"); - BuilderWarnings warnings; - auto result = ResolveReferences(expr, reference_map, &warnings); - ASSERT_OK(result); - EXPECT_THAT(result.value(), Optional(EqualsProto(R"( + std::unique_ptr expr_ast = ParseTestProto(kExpr); + SourceInfo source_info; + + CelFunctionRegistry func_registry; + ASSERT_OK(RegisterBuiltinFunctions(&func_registry)); + cel::TypeRegistry type_registry; + Resolver registry("", func_registry.InternalGetRegistry(), type_registry, + type_registry.GetComposedTypeProvider()); + expr_ast->mutable_reference_map()[2].set_name("foo.bar.var1"); + expr_ast->mutable_reference_map()[2].mutable_value().set_bool_value(true); + expr_ast->mutable_reference_map()[5].set_name("bar.foo.var2"); + IssueCollector issues(RuntimeIssue::Severity::kError); + + auto result = ResolveReferences(registry, issues, *expr_ast); + + ASSERT_THAT(result, IsOkAndHolds(true)); + + EXPECT_THAT(ExprToProtoOrDie(expr_ast->root_expr()), EqualsProto(R"pb( id: 1 call_expr { function: "_&&_" @@ -251,23 +440,574 @@ TEST(ResolveReferences, ConstReferenceSkipped) { id: 5 ident_expr { name: "bar.foo.var2" } } - })"))); + })pb")); +} + +constexpr char kNullValueReferenceExpr[] = R"( + id: 1 + call_expr { + function: "_+_" + args { + id: 2 + ident_expr { + name: "google.protobuf.NullValue.NULL_VALUE" + } + } + args { + id: 5 + const_expr { int64_value: 1 } + } + } +)"; + +TEST(ResolveReferences, NullValueReferenceSkipped) { + std::unique_ptr expr_ast = ParseTestProto(kNullValueReferenceExpr); + SourceInfo source_info; + + CelFunctionRegistry func_registry; + ASSERT_THAT(RegisterBuiltinFunctions(&func_registry), IsOk()); + cel::TypeRegistry type_registry; + Resolver registry("", func_registry.InternalGetRegistry(), type_registry, + type_registry.GetComposedTypeProvider()); + expr_ast->mutable_reference_map()[2].set_name( + "google.protobuf.NullValue.NULL_VALUE"); + expr_ast->mutable_reference_map()[2].mutable_value().set_null_value(nullptr); + IssueCollector issues(RuntimeIssue::Severity::kError); + + auto result = ResolveReferences(registry, issues, *expr_ast); + + ASSERT_THAT(result, IsOkAndHolds(/*was_rewritten=*/false)); +} + +constexpr char kExtensionAndExpr[] = R"( +id: 1 +call_expr { + function: "boolean_and" + args { + id: 2 + const_expr { + bool_value: true + } + } + args { + id: 3 + const_expr { + bool_value: false + } + } +})"; + +TEST(ResolveReferences, FunctionReferenceBasic) { + std::unique_ptr expr_ast = ParseTestProto(kExtensionAndExpr); + SourceInfo source_info; + + CelFunctionRegistry func_registry; + ASSERT_OK(func_registry.RegisterLazyFunction( + CelFunctionDescriptor("boolean_and", false, + { + CelValue::Type::kBool, + CelValue::Type::kBool, + }))); + cel::TypeRegistry type_registry; + Resolver registry("", func_registry.InternalGetRegistry(), type_registry, + type_registry.GetComposedTypeProvider()); + IssueCollector issues(RuntimeIssue::Severity::kError); + expr_ast->mutable_reference_map()[1].mutable_overload_id().push_back( + "udf_boolean_and"); + + auto result = ResolveReferences(registry, issues, *expr_ast); + + ASSERT_THAT(result, IsOkAndHolds(false)); +} + +TEST(ResolveReferences, FunctionReferenceMissingOverloadDetected) { + std::unique_ptr expr_ast = ParseTestProto(kExtensionAndExpr); + SourceInfo source_info; + + CelFunctionRegistry func_registry; + cel::TypeRegistry type_registry; + Resolver registry("", func_registry.InternalGetRegistry(), type_registry, + type_registry.GetComposedTypeProvider()); + IssueCollector issues(RuntimeIssue::Severity::kError); + expr_ast->mutable_reference_map()[1].mutable_overload_id().push_back( + "udf_boolean_and"); + + auto result = ResolveReferences(registry, issues, *expr_ast); + + ASSERT_THAT(result, IsOkAndHolds(false)); + EXPECT_THAT(ExtractIssuesStatus(issues), + ElementsAre(StatusCodeIs(absl::StatusCode::kInvalidArgument))); +} + +TEST(ResolveReferences, SpecialBuiltinsNotWarned) { + std::unique_ptr expr_ast = ParseTestProto(R"pb( + id: 1 + call_expr { + function: "*" + args { + id: 2 + const_expr { bool_value: true } + } + args { + id: 3 + const_expr { bool_value: false } + } + })pb"); + SourceInfo source_info; + + std::vector special_builtins{ + cel::builtin::kAnd, cel::builtin::kOr, cel::builtin::kTernary, + cel::builtin::kIndex}; + for (const char* builtin_fn : special_builtins) { + // Builtins aren't in the function registry. + CelFunctionRegistry func_registry; + cel::TypeRegistry type_registry; + Resolver registry("", func_registry.InternalGetRegistry(), type_registry, + type_registry.GetComposedTypeProvider()); + IssueCollector issues(RuntimeIssue::Severity::kError); + expr_ast->mutable_reference_map()[1].mutable_overload_id().push_back( + absl::StrCat("builtin.", builtin_fn)); + expr_ast->mutable_root_expr().mutable_call_expr().set_function(builtin_fn); + + auto result = ResolveReferences(registry, issues, *expr_ast); + + ASSERT_THAT(result, IsOkAndHolds(false)); + EXPECT_THAT(ExtractIssuesStatus(issues), IsEmpty()); + } +} + +TEST(ResolveReferences, + FunctionReferenceMissingOverloadDetectedAndMissingReference) { + std::unique_ptr expr_ast = ParseTestProto(kExtensionAndExpr); + SourceInfo source_info; + + CelFunctionRegistry func_registry; + cel::TypeRegistry type_registry; + Resolver registry("", func_registry.InternalGetRegistry(), type_registry, + type_registry.GetComposedTypeProvider()); + IssueCollector issues(RuntimeIssue::Severity::kError); + expr_ast->mutable_reference_map()[1].set_name("udf_boolean_and"); + + auto result = ResolveReferences(registry, issues, *expr_ast); + + ASSERT_THAT(result, IsOkAndHolds(false)); + EXPECT_THAT( + ExtractIssuesStatus(issues), + UnorderedElementsAre( + Eq(absl::InvalidArgumentError( + "No overload found in reference resolve step for boolean_and")), + Eq(absl::InvalidArgumentError( + "Reference map doesn't provide overloads for boolean_and")))); +} + +TEST(ResolveReferences, EmulatesEagerFailing) { + std::unique_ptr expr_ast = ParseTestProto(kExtensionAndExpr); + SourceInfo source_info; + + CelFunctionRegistry func_registry; + cel::TypeRegistry type_registry; + Resolver registry("", func_registry.InternalGetRegistry(), type_registry, + type_registry.GetComposedTypeProvider()); + IssueCollector issues(RuntimeIssue::Severity::kWarning); + expr_ast->mutable_reference_map()[1].set_name("udf_boolean_and"); + + EXPECT_THAT( + ResolveReferences(registry, issues, *expr_ast), + StatusIs(absl::StatusCode::kInvalidArgument, + "Reference map doesn't provide overloads for boolean_and")); +} + +TEST(ResolveReferences, FunctionReferenceToWrongExprKind) { + std::unique_ptr expr_ast = ParseTestProto(kExtensionAndExpr); + SourceInfo source_info; + + IssueCollector issues(RuntimeIssue::Severity::kError); + CelFunctionRegistry func_registry; + cel::TypeRegistry type_registry; + Resolver registry("", func_registry.InternalGetRegistry(), type_registry, + type_registry.GetComposedTypeProvider()); + expr_ast->mutable_reference_map()[2].mutable_overload_id().push_back( + "udf_boolean_and"); + + auto result = ResolveReferences(registry, issues, *expr_ast); + + ASSERT_THAT(result, IsOkAndHolds(false)); + EXPECT_THAT(ExtractIssuesStatus(issues), + ElementsAre(StatusCodeIs(absl::StatusCode::kInvalidArgument))); +} + +constexpr char kReceiverCallExtensionAndExpr[] = R"( +id: 1 +call_expr { + function: "boolean_and" + target { + id: 2 + ident_expr { + name: "ext" + } + } + args { + id: 3 + const_expr { + bool_value: false + } + } +})"; + +TEST(ResolveReferences, FunctionReferenceWithTargetNoChange) { + std::unique_ptr expr_ast = ParseTestProto(kReceiverCallExtensionAndExpr); + SourceInfo source_info; + + IssueCollector issues(RuntimeIssue::Severity::kError); + CelFunctionRegistry func_registry; + ASSERT_OK(func_registry.RegisterLazyFunction(CelFunctionDescriptor( + "boolean_and", true, {CelValue::Type::kBool, CelValue::Type::kBool}))); + cel::TypeRegistry type_registry; + Resolver registry("", func_registry.InternalGetRegistry(), type_registry, + type_registry.GetComposedTypeProvider()); + expr_ast->mutable_reference_map()[1].mutable_overload_id().push_back( + "udf_boolean_and"); + + auto result = ResolveReferences(registry, issues, *expr_ast); + + ASSERT_THAT(result, IsOkAndHolds(false)); + EXPECT_THAT(ExtractIssuesStatus(issues), IsEmpty()); +} + +TEST(ResolveReferences, + FunctionReferenceWithTargetNoChangeMissingOverloadDetected) { + std::unique_ptr expr_ast = ParseTestProto(kReceiverCallExtensionAndExpr); + SourceInfo source_info; + + IssueCollector issues(RuntimeIssue::Severity::kError); + CelFunctionRegistry func_registry; + cel::TypeRegistry type_registry; + Resolver registry("", func_registry.InternalGetRegistry(), type_registry, + type_registry.GetComposedTypeProvider()); + expr_ast->mutable_reference_map()[1].mutable_overload_id().push_back( + "udf_boolean_and"); + + auto result = ResolveReferences(registry, issues, *expr_ast); + + ASSERT_THAT(result, IsOkAndHolds(false)); + EXPECT_THAT(ExtractIssuesStatus(issues), + ElementsAre(StatusCodeIs(absl::StatusCode::kInvalidArgument))); +} + +TEST(ResolveReferences, FunctionReferenceWithTargetToNamespacedFunction) { + std::unique_ptr expr_ast = ParseTestProto(kReceiverCallExtensionAndExpr); + SourceInfo source_info; + + IssueCollector issues(RuntimeIssue::Severity::kError); + CelFunctionRegistry func_registry; + ASSERT_OK(func_registry.RegisterLazyFunction(CelFunctionDescriptor( + "ext.boolean_and", false, {CelValue::Type::kBool}))); + cel::TypeRegistry type_registry; + Resolver registry("", func_registry.InternalGetRegistry(), type_registry, + type_registry.GetComposedTypeProvider()); + expr_ast->mutable_reference_map()[1].mutable_overload_id().push_back( + "udf_boolean_and"); + + auto result = ResolveReferences(registry, issues, *expr_ast); + + ASSERT_THAT(result, IsOkAndHolds(true)); + EXPECT_THAT(ExprToProtoOrDie(expr_ast->root_expr()), EqualsProto(R"pb( + id: 1 + call_expr { + function: "ext.boolean_and" + args { + id: 3 + const_expr { bool_value: false } + } + } + )pb")); + EXPECT_THAT(ExtractIssuesStatus(issues), IsEmpty()); } -TEST(ResolveReferences, FunctionReferenceSkipped) { - Expr expr = ParseTestProto(kExpr); - google::protobuf::Map reference_map; - BuilderWarnings warnings; - reference_map[1].set_name("@user_defined_boolean_and"); - reference_map[1].add_overload_id("@user_defined_boolean_and_overload1"); - auto result = ResolveReferences(expr, reference_map, &warnings); - ASSERT_OK(result); - EXPECT_THAT(result.value(), Eq(absl::nullopt)); +TEST(ResolveReferences, + FunctionReferenceWithTargetToNamespacedFunctionInContainer) { + std::unique_ptr expr_ast = ParseTestProto(kReceiverCallExtensionAndExpr); + SourceInfo source_info; + + expr_ast->mutable_reference_map()[1].mutable_overload_id().push_back( + "udf_boolean_and"); + IssueCollector issues(RuntimeIssue::Severity::kError); + CelFunctionRegistry func_registry; + ASSERT_OK(func_registry.RegisterLazyFunction(CelFunctionDescriptor( + "com.google.ext.boolean_and", false, {CelValue::Type::kBool}))); + cel::TypeRegistry type_registry; + std::vector namespace_prefixes{"com.google.", "google.", ""}; + Resolver registry("com.google", func_registry.InternalGetRegistry(), + type_registry, type_registry.GetComposedTypeProvider()); + auto result = ResolveReferences(registry, issues, *expr_ast); + + ASSERT_THAT(result, IsOkAndHolds(true)); + + EXPECT_THAT(ExprToProtoOrDie(expr_ast->root_expr()), EqualsProto(R"pb( + id: 1 + call_expr { + function: "com.google.ext.boolean_and" + args { + id: 3 + const_expr { bool_value: false } + } + } + )pb")); + EXPECT_THAT(ExtractIssuesStatus(issues), IsEmpty()); } +// has(ext.option).boolean_and(false) +constexpr char kReceiverCallHasExtensionAndExpr[] = R"( +id: 1 +call_expr { + function: "boolean_and" + target { + id: 2 + select_expr { + test_only: true + field: "option" + operand { + id: 3 + ident_expr { + name: "ext" + } + } + } + } + args { + id: 4 + const_expr { + bool_value: false + } + } +})"; + +TEST(ResolveReferences, FunctionReferenceWithHasTargetNoChange) { + std::unique_ptr expr_ast = + ParseTestProto(kReceiverCallHasExtensionAndExpr); + SourceInfo source_info; + + IssueCollector issues(RuntimeIssue::Severity::kError); + CelFunctionRegistry func_registry; + ASSERT_OK(func_registry.RegisterLazyFunction(CelFunctionDescriptor( + "boolean_and", true, {CelValue::Type::kBool, CelValue::Type::kBool}))); + ASSERT_OK(func_registry.RegisterLazyFunction(CelFunctionDescriptor( + "ext.option.boolean_and", true, {CelValue::Type::kBool}))); + cel::TypeRegistry type_registry; + Resolver registry("", func_registry.InternalGetRegistry(), type_registry, + type_registry.GetComposedTypeProvider()); + expr_ast->mutable_reference_map()[1].mutable_overload_id().push_back( + "udf_boolean_and"); + + auto result = ResolveReferences(registry, issues, *expr_ast); + + ASSERT_THAT(result, IsOkAndHolds(false)); + // The target is unchanged because it is a test_only select. + EXPECT_THAT(ExprToProtoOrDie(expr_ast->root_expr()), + EqualsProto(kReceiverCallHasExtensionAndExpr)); + EXPECT_THAT(ExtractIssuesStatus(issues), IsEmpty()); +} + +constexpr char kComprehensionExpr[] = R"( +id:17 +comprehension_expr: { + iter_var:"i" + iter_range:{ + id:1 + list_expr:{ + elements:{ + id:2 + const_expr:{int64_value:1} + } + elements:{ + id:3 + ident_expr:{name:"ENUM"} + } + elements:{ + id:4 + const_expr:{int64_value:3} + } + } + } + accu_var:"__result__" + accu_init: { + id:10 + const_expr:{bool_value:false} + } + loop_condition:{ + id:13 + call_expr:{ + function:"@not_strictly_false" + args:{ + id:12 + call_expr:{ + function:"!_" + args:{ + id:11 + ident_expr:{name:"__result__"} + } + } + } + } + } + loop_step:{ + id:15 + call_expr: { + function:"_||_" + args:{ + id:14 + ident_expr: {name:"__result__"} + } + args:{ + id:8 + call_expr:{ + function:"_==_" + args:{ + id:7 ident_expr:{name:"ENUM"} + } + args:{ + id:9 ident_expr:{name:"i"} + } + } + } + } + } + result:{id:16 ident_expr:{name:"__result__"}} +} +)"; +TEST(ResolveReferences, EnumConstReferenceUsedInComprehension) { + std::unique_ptr expr_ast = ParseTestProto(kComprehensionExpr); + + SourceInfo source_info; + + CelFunctionRegistry func_registry; + ASSERT_OK(RegisterBuiltinFunctions(&func_registry)); + cel::TypeRegistry type_registry; + Resolver registry("", func_registry.InternalGetRegistry(), type_registry, + type_registry.GetComposedTypeProvider()); + expr_ast->mutable_reference_map()[3].set_name("ENUM"); + expr_ast->mutable_reference_map()[3].mutable_value().set_int64_value(2); + expr_ast->mutable_reference_map()[7].set_name("ENUM"); + expr_ast->mutable_reference_map()[7].mutable_value().set_int64_value(2); + IssueCollector issues(RuntimeIssue::Severity::kError); + + auto result = ResolveReferences(registry, issues, *expr_ast); + + ASSERT_THAT(result, IsOkAndHolds(true)); + EXPECT_THAT(ExprToProtoOrDie(expr_ast->root_expr()), EqualsProto(R"pb( + id: 17 + comprehension_expr { + iter_var: "i" + iter_range { + id: 1 + list_expr { + elements { + id: 2 + const_expr { int64_value: 1 } + } + elements { + id: 3 + const_expr { int64_value: 2 } + } + elements { + id: 4 + const_expr { int64_value: 3 } + } + } + } + accu_var: "__result__" + accu_init { + id: 10 + const_expr { bool_value: false } + } + loop_condition { + id: 13 + call_expr { + function: "@not_strictly_false" + args { + id: 12 + call_expr { + function: "!_" + args { + id: 11 + ident_expr { name: "__result__" } + } + } + } + } + } + loop_step { + id: 15 + call_expr { + function: "_||_" + args { + id: 14 + ident_expr { name: "__result__" } + } + args { + id: 8 + call_expr { + function: "_==_" + args { + id: 7 + const_expr { int64_value: 2 } + } + args { + id: 9 + ident_expr { name: "i" } + } + } + } + } + } + result { + id: 16 + ident_expr { name: "__result__" } + } + })pb")); +} + +TEST(ResolveReferences, ReferenceToId0Warns) { + // ID 0 is unsupported since it is not normally used by parsers and is + // ambiguous as an intentional ID or default for unset field. + std::unique_ptr expr_ast = ParseTestProto(R"pb( + id: 0 + select_expr { + operand { + id: 1 + ident_expr { name: "pkg" } + } + field: "var" + })pb"); + + SourceInfo source_info; + + CelFunctionRegistry func_registry; + ASSERT_OK(RegisterBuiltinFunctions(&func_registry)); + cel::TypeRegistry type_registry; + Resolver registry("", func_registry.InternalGetRegistry(), type_registry, + type_registry.GetComposedTypeProvider()); + expr_ast->mutable_reference_map()[0].set_name("pkg.var"); + IssueCollector issues(RuntimeIssue::Severity::kError); + + auto result = ResolveReferences(registry, issues, *expr_ast); + + ASSERT_THAT(result, IsOkAndHolds(false)); + EXPECT_THAT(ExprToProtoOrDie(expr_ast->root_expr()), EqualsProto(R"pb( + id: 0 + select_expr { + operand { + id: 1 + ident_expr { name: "pkg" } + } + field: "var" + })pb")); + EXPECT_THAT( + ExtractIssuesStatus(issues), + Contains(StatusIs( + absl::StatusCode::kInvalidArgument, + "reference map entries for expression id 0 are not supported"))); +} } // namespace -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime diff --git a/eval/compiler/regex_precompilation_optimization.cc b/eval/compiler/regex_precompilation_optimization.cc new file mode 100644 index 000000000..38ef842b9 --- /dev/null +++ b/eval/compiler/regex_precompilation_optimization.cc @@ -0,0 +1,274 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/compiler/regex_precompilation_optimization.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "base/builtins.h" +#include "common/ast.h" +#include "common/casting.h" +#include "common/expr.h" +#include "common/native_type.h" +#include "common/value.h" +#include "eval/compiler/flat_expr_builder_extensions.h" +#include "eval/eval/compiler_constant_step.h" +#include "eval/eval/direct_expression_step.h" +#include "eval/eval/evaluator_core.h" +#include "eval/eval/regex_match_step.h" +#include "internal/casts.h" +#include "internal/re2_options.h" +#include "internal/status_macros.h" +#include "re2/re2.h" + +namespace google::api::expr::runtime { +namespace { + +using ::cel::Ast; +using ::cel::CallExpr; +using ::cel::Cast; +using ::cel::Expr; +using ::cel::InstanceOf; +using ::cel::NativeTypeId; +using ::cel::Reference; +using ::cel::StringValue; +using ::cel::Value; +using ::cel::internal::down_cast; + +using ReferenceMap = absl::flat_hash_map; + +bool IsFunctionOverload(const Expr& expr, absl::string_view function, + absl::string_view overload, size_t arity, + const ReferenceMap& reference_map) { + if (!expr.has_call_expr()) { + return false; + } + const auto& call_expr = expr.call_expr(); + if (call_expr.function() != function) { + return false; + } + if (call_expr.args().size() + (call_expr.has_target() ? 1 : 0) != arity) { + return false; + } + + // If parse-only and opted in to the optimization, assume this is the intended + // overload. This will still only change the evaluation plan if the second arg + // is a constant string. + if (reference_map.empty()) { + return true; + } + + auto reference = reference_map.find(expr.id()); + if (reference != reference_map.end() && + reference->second.overload_id().size() == 1 && + reference->second.overload_id().front() == overload) { + return true; + } + return false; +} + +// Abstraction for deduplicating regular expressions over the course of a single +// create expression call. Should not be used during evaluation. Uses +// std::shared_ptr and std::weak_ptr. +class RegexProgramBuilder final { + public: + explicit RegexProgramBuilder(int max_program_size) + : max_program_size_(max_program_size) {} + + absl::StatusOr> BuildRegexProgram( + std::string pattern) { + auto existing = programs_.find(pattern); + if (existing != programs_.end()) { + if (auto program = existing->second.lock(); program) { + return program; + } + programs_.erase(existing); + } + auto program = + std::make_shared(pattern, cel::internal::MakeRE2Options()); + CEL_RETURN_IF_ERROR(cel::internal::CheckRE2(*program, max_program_size_)); + programs_.insert({std::move(pattern), program}); + return program; + } + + private: + const int max_program_size_; + absl::flat_hash_map> programs_; +}; + +class RegexPrecompilationOptimization : public ProgramOptimizer { + public: + explicit RegexPrecompilationOptimization(const ReferenceMap& reference_map, + int regex_max_program_size) + : reference_map_(reference_map), + regex_program_builder_(regex_max_program_size) {} + + absl::Status OnPreVisit(PlannerContext& context, const Expr& node) override { + return absl::OkStatus(); + } + + absl::Status OnPostVisit(PlannerContext& context, const Expr& node) override { + // Check that this is the correct matches overload instead of a user defined + // overload. + if (!IsFunctionOverload(node, cel::builtin::kRegexMatch, "matches_string", + 2, reference_map_)) { + return absl::OkStatus(); + } + + ProgramBuilder::Subexpression* subexpression = + context.program_builder().GetSubexpression(&node); + + const CallExpr& call_expr = node.call_expr(); + const Expr& pattern_expr = call_expr.args().back(); + + // Try to check if the regex is valid, whether or not we can actually update + // the plan. + std::optional pattern = + GetConstantString(context, subexpression, node, pattern_expr); + if (!pattern.has_value()) { + return absl::OkStatus(); + } + + CEL_ASSIGN_OR_RETURN( + std::shared_ptr regex_program, + regex_program_builder_.BuildRegexProgram(std::move(pattern).value())); + + if (subexpression == nullptr || subexpression->IsFlattened()) { + // Already modified, can't update further. + return absl::OkStatus(); + } + + const Expr& subject_expr = + call_expr.has_target() ? call_expr.target() : call_expr.args().front(); + + return RewritePlan(context, subexpression, node, subject_expr, + std::move(regex_program)); + } + + private: + std::optional GetConstantString( + PlannerContext& context, + ProgramBuilder::Subexpression* absl_nullable subexpression, + const Expr& call_expr, const Expr& re_expr) const { + if (re_expr.has_const_expr() && re_expr.const_expr().has_string_value()) { + return re_expr.const_expr().string_value(); + } + + if (subexpression == nullptr || subexpression->IsFlattened()) { + // Already modified, can't recover the input pattern. + return std::nullopt; + } + std::optional constant; + if (subexpression->IsRecursive()) { + const auto& program = subexpression->recursive_program(); + auto deps = program.step->GetDependencies(); + if (deps.has_value() && deps->size() == 2) { + const auto* re_plan = + TryDowncastDirectStep(deps->at(1)); + if (re_plan != nullptr) { + constant = re_plan->value(); + } + } + } else { + // otherwise stack-machine program. + ExecutionPathView re_plan = context.GetSubplan(re_expr); + if (re_plan.size() == 1 && + re_plan[0]->GetNativeTypeId() == + NativeTypeId::For()) { + constant = + down_cast(re_plan[0].get())->value(); + } + } + + if (constant.has_value() && InstanceOf(*constant)) { + return Cast(*constant).ToString(); + } + + return std::nullopt; + } + + absl::Status RewritePlan( + PlannerContext& context, + ProgramBuilder::Subexpression* absl_nonnull subexpression, + const Expr& call, const Expr& subject, + std::shared_ptr regex_program) { + if (subexpression->IsRecursive()) { + return RewriteRecursivePlan(subexpression, call, subject, + std::move(regex_program)); + } + return RewriteStackMachinePlan(context, call, subject, + std::move(regex_program)); + } + + absl::Status RewriteRecursivePlan( + ProgramBuilder::Subexpression* absl_nonnull subexpression, + const Expr& call, const Expr& subject, + std::shared_ptr regex_program) { + auto program = subexpression->ExtractRecursiveProgram(); + auto deps = program.step->ExtractDependencies(); + if (!deps.has_value() || deps->size() != 2) { + // Possibly already const-folded, put the plan back. + subexpression->set_recursive_program(std::move(program.step), + program.depth); + return absl::OkStatus(); + } + subexpression->set_recursive_program( + CreateDirectRegexMatchStep(call.id(), std::move(deps->at(0)), + std::move(regex_program)), + program.depth); + return absl::OkStatus(); + } + + absl::Status RewriteStackMachinePlan( + PlannerContext& context, const Expr& call, const Expr& subject, + std::shared_ptr regex_program) { + if (context.GetSubplan(subject).empty()) { + // This subexpression was already optimized, nothing to do. + return absl::OkStatus(); + } + + CEL_ASSIGN_OR_RETURN(ExecutionPath new_plan, + context.ExtractSubplan(subject)); + CEL_ASSIGN_OR_RETURN( + new_plan.emplace_back(), + CreateRegexMatchStep(std::move(regex_program), call.id())); + + return context.ReplaceSubplan(call, std::move(new_plan)); + } + + const ReferenceMap& reference_map_; + RegexProgramBuilder regex_program_builder_; +}; + +} // namespace + +ProgramOptimizerFactory CreateRegexPrecompilationExtension( + int regex_max_program_size) { + return [=](PlannerContext& context, const Ast& ast) { + return std::make_unique( + ast.reference_map(), regex_max_program_size); + }; +} +} // namespace google::api::expr::runtime diff --git a/eval/compiler/regex_precompilation_optimization.h b/eval/compiler/regex_precompilation_optimization.h new file mode 100644 index 000000000..7b15d9aae --- /dev/null +++ b/eval/compiler/regex_precompilation_optimization.h @@ -0,0 +1,29 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_COMPILER_REGEX_PRECOMPILATION_OPTIMIZATION_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_COMPILER_REGEX_PRECOMPILATION_OPTIMIZATION_H_ + +#include "eval/compiler/flat_expr_builder_extensions.h" + +namespace google::api::expr::runtime { + +// Create a new extension for the FlatExprBuilder that precompiles constant +// regular expressions used in the standard 'Match' function. +ProgramOptimizerFactory CreateRegexPrecompilationExtension( + int regex_max_program_size); + +} // namespace google::api::expr::runtime + +#endif // THIRD_PARTY_CEL_CPP_EVAL_COMPILER_REGEX_PRECOMPILATION_OPTIMIZATION_H_ diff --git a/eval/compiler/regex_precompilation_optimization_test.cc b/eval/compiler/regex_precompilation_optimization_test.cc new file mode 100644 index 000000000..9666144b2 --- /dev/null +++ b/eval/compiler/regex_precompilation_optimization_test.cc @@ -0,0 +1,285 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/compiler/regex_precompilation_optimization.h" + +#include +#include +#include +#include + +#include "cel/expr/checked.pb.h" +#include "cel/expr/syntax.pb.h" +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "common/ast.h" +#include "eval/compiler/cel_expression_builder_flat_impl.h" +#include "eval/compiler/constant_folding.h" +#include "eval/compiler/flat_expr_builder.h" +#include "eval/compiler/flat_expr_builder_extensions.h" +#include "eval/compiler/resolver.h" +#include "eval/eval/evaluator_core.h" +#include "eval/public/activation.h" +#include "eval/public/builtin_func_registrar.h" +#include "eval/public/cel_expression.h" +#include "eval/public/cel_function_registry.h" +#include "eval/public/cel_options.h" +#include "eval/public/cel_type_registry.h" +#include "eval/public/cel_value.h" +#include "internal/testing.h" +#include "parser/parser.h" +#include "runtime/internal/issue_collector.h" +#include "runtime/internal/runtime_env.h" +#include "runtime/internal/runtime_env_testing.h" +#include "runtime/runtime_issue.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" + +namespace google::api::expr::runtime { +namespace { + +using ::cel::RuntimeIssue; +using ::cel::runtime_internal::IssueCollector; +using ::cel::runtime_internal::NewTestingRuntimeEnv; +using ::cel::runtime_internal::RuntimeEnv; +using ::google::api::expr::parser::Parse; +using ::testing::ElementsAre; + +namespace exprpb = cel::expr; + +class RegexPrecompilationExtensionTest : public testing::TestWithParam { + public: + RegexPrecompilationExtensionTest() + : env_(NewTestingRuntimeEnv()), + builder_(env_), + type_registry_(*builder_.GetTypeRegistry()), + function_registry_(*builder_.GetRegistry()), + resolver_("", function_registry_.InternalGetRegistry(), + type_registry_.InternalGetModernRegistry(), + type_registry_.GetTypeProvider()), + issue_collector_(RuntimeIssue::Severity::kError) { + if (EnableRecursivePlanning()) { + options_.max_recursion_depth = -1; + options_.enable_recursive_tracing = true; + } + options_.enable_regex = true; + options_.regex_max_program_size = 100; + options_.enable_regex_precompilation = true; + runtime_options_ = ConvertToRuntimeOptions(options_); + } + + void SetUp() override { + ASSERT_OK(RegisterBuiltinFunctions(&function_registry_, options_)); + } + + bool EnableRecursivePlanning() { return GetParam(); } + + protected: + CelEvaluationListener RecordStringValues() { + return [this](int64_t, const CelValue& value, google::protobuf::Arena*) { + if (value.IsString()) { + string_values_.push_back(std::string(value.StringOrDie().value())); + } + return absl::OkStatus(); + }; + } + + absl_nonnull std::shared_ptr env_; + CelExpressionBuilderFlatImpl builder_; + CelTypeRegistry& type_registry_; + CelFunctionRegistry& function_registry_; + InterpreterOptions options_; + cel::RuntimeOptions runtime_options_; + Resolver resolver_; + IssueCollector issue_collector_; + std::vector string_values_; +}; + +TEST_P(RegexPrecompilationExtensionTest, SmokeTest) { + ProgramOptimizerFactory factory = + CreateRegexPrecompilationExtension(options_.regex_max_program_size); + ExecutionPath path; + ProgramBuilder program_builder; + cel::Ast ast_impl; + ast_impl.set_is_checked(true); + std::shared_ptr arena; + PlannerContext context(env_, resolver_, runtime_options_, + type_registry_.GetTypeProvider(), issue_collector_, + program_builder, arena); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr optimizer, + factory(context, ast_impl)); +} + +TEST_P(RegexPrecompilationExtensionTest, OptimizeableExpression) { + builder_.flat_expr_builder().AddProgramOptimizer( + CreateRegexPrecompilationExtension(options_.regex_max_program_size)); + + ASSERT_OK_AND_ASSIGN(exprpb::ParsedExpr parsed_expr, + Parse("input.matches(r'[a-zA-Z]+[0-9]*')")); + + // Fake reference information for the matches call. + exprpb::CheckedExpr expr; + expr.mutable_expr()->Swap(parsed_expr.mutable_expr()); + expr.mutable_source_info()->Swap(parsed_expr.mutable_source_info()); + (*expr.mutable_reference_map())[2].add_overload_id("matches_string"); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr plan, + builder_.CreateExpression(&expr)); + + Activation activation; + google::protobuf::Arena arena; + activation.InsertValue("input", CelValue::CreateStringView("input123")); + + ASSERT_OK(plan->Trace(activation, &arena, RecordStringValues())); + EXPECT_THAT(string_values_, ElementsAre("input123")); +} + +TEST_P(RegexPrecompilationExtensionTest, OptimizeParsedExpr) { + builder_.flat_expr_builder().AddProgramOptimizer( + CreateRegexPrecompilationExtension(options_.regex_max_program_size)); + + ASSERT_OK_AND_ASSIGN(exprpb::ParsedExpr expr, + Parse("input.matches(r'[a-zA-Z]+[0-9]*')")); + + ASSERT_OK_AND_ASSIGN( + std::unique_ptr plan, + builder_.CreateExpression(&expr.expr(), &expr.source_info())); + + Activation activation; + google::protobuf::Arena arena; + activation.InsertValue("input", CelValue::CreateStringView("input123")); + + ASSERT_OK(plan->Trace(activation, &arena, RecordStringValues())); + EXPECT_THAT(string_values_, ElementsAre("input123")); +} + +TEST_P(RegexPrecompilationExtensionTest, DoesNotOptimizeNonConstRegex) { + builder_.flat_expr_builder().AddProgramOptimizer( + CreateRegexPrecompilationExtension(options_.regex_max_program_size)); + + ASSERT_OK_AND_ASSIGN(exprpb::ParsedExpr parsed_expr, + Parse("input.matches(input_re)")); + + // Fake reference information for the matches call. + exprpb::CheckedExpr expr; + expr.mutable_expr()->Swap(parsed_expr.mutable_expr()); + expr.mutable_source_info()->Swap(parsed_expr.mutable_source_info()); + (*expr.mutable_reference_map())[2].add_overload_id("matches_string"); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr plan, + builder_.CreateExpression(&expr)); + + Activation activation; + google::protobuf::Arena arena; + activation.InsertValue("input", CelValue::CreateStringView("input123")); + activation.InsertValue("input_re", CelValue::CreateStringView("input_re")); + + ASSERT_OK(plan->Trace(activation, &arena, RecordStringValues())); + EXPECT_THAT(string_values_, ElementsAre("input123", "input_re")); +} + +TEST_P(RegexPrecompilationExtensionTest, DoesNotOptimizeCompoundExpr) { + builder_.flat_expr_builder().AddProgramOptimizer( + CreateRegexPrecompilationExtension(options_.regex_max_program_size)); + + ASSERT_OK_AND_ASSIGN(exprpb::ParsedExpr parsed_expr, + Parse("input.matches('abc' + 'def')")); + + // Fake reference information for the matches call. + exprpb::CheckedExpr expr; + expr.mutable_expr()->Swap(parsed_expr.mutable_expr()); + expr.mutable_source_info()->Swap(parsed_expr.mutable_source_info()); + (*expr.mutable_reference_map())[2].add_overload_id("matches_string"); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr plan, + builder_.CreateExpression(&expr)); + + Activation activation; + google::protobuf::Arena arena; + activation.InsertValue("input", CelValue::CreateStringView("input123")); + + ASSERT_OK(plan->Trace(activation, &arena, RecordStringValues())); + EXPECT_THAT(string_values_, ElementsAre("input123", "abc", "def", "abcdef")); +} + +class RegexConstFoldInteropTest : public RegexPrecompilationExtensionTest { + public: + RegexConstFoldInteropTest() : RegexPrecompilationExtensionTest() { + builder_.flat_expr_builder().AddProgramOptimizer( + cel::runtime_internal::CreateConstantFoldingOptimizer()); + } + + protected: + google::protobuf::Arena arena_; +}; + +TEST_P(RegexConstFoldInteropTest, StringConstantOptimizeable) { + builder_.flat_expr_builder().AddProgramOptimizer( + CreateRegexPrecompilationExtension(options_.regex_max_program_size)); + + ASSERT_OK_AND_ASSIGN(exprpb::ParsedExpr parsed_expr, + Parse("input.matches('abc' + 'def')")); + + // Fake reference information for the matches call. + exprpb::CheckedExpr expr; + expr.mutable_expr()->Swap(parsed_expr.mutable_expr()); + expr.mutable_source_info()->Swap(parsed_expr.mutable_source_info()); + (*expr.mutable_reference_map())[2].add_overload_id("matches_string"); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr plan, + builder_.CreateExpression(&expr)); + Activation activation; + google::protobuf::Arena arena; + activation.InsertValue("input", CelValue::CreateStringView("input123")); + + ASSERT_OK(plan->Trace(activation, &arena, RecordStringValues())); + EXPECT_THAT(string_values_, ElementsAre("input123")); +} + +TEST_P(RegexConstFoldInteropTest, WrongTypeNotOptimized) { + builder_.flat_expr_builder().AddProgramOptimizer( + CreateRegexPrecompilationExtension(options_.regex_max_program_size)); + + ASSERT_OK_AND_ASSIGN(exprpb::ParsedExpr parsed_expr, + Parse("input.matches(123 + 456)")); + + // Fake reference information for the matches call. + exprpb::CheckedExpr expr; + expr.mutable_expr()->Swap(parsed_expr.mutable_expr()); + expr.mutable_source_info()->Swap(parsed_expr.mutable_source_info()); + (*expr.mutable_reference_map())[2].add_overload_id("matches_string"); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr plan, + builder_.CreateExpression(&expr)); + + Activation activation; + google::protobuf::Arena arena; + activation.InsertValue("input", CelValue::CreateStringView("input123")); + + ASSERT_OK_AND_ASSIGN(CelValue result, + plan->Trace(activation, &arena, RecordStringValues())); + EXPECT_THAT(string_values_, ElementsAre("input123")); + EXPECT_TRUE(result.IsError()); + EXPECT_TRUE(CheckNoMatchingOverloadError(result)); +} + +INSTANTIATE_TEST_SUITE_P(RegexPrecompilationExtensionTest, + RegexPrecompilationExtensionTest, testing::Bool()); + +INSTANTIATE_TEST_SUITE_P(RegexConstFoldInteropTest, RegexConstFoldInteropTest, + testing::Bool()); + +} // namespace +} // namespace google::api::expr::runtime diff --git a/eval/compiler/resolver.cc b/eval/compiler/resolver.cc new file mode 100644 index 000000000..cca72964a --- /dev/null +++ b/eval/compiler/resolver.cc @@ -0,0 +1,222 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/compiler/resolver.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/base/no_destructor.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "common/kind.h" +#include "common/type.h" +#include "common/type_reflector.h" +#include "common/value.h" +#include "internal/status_macros.h" +#include "runtime/function_overload_reference.h" +#include "runtime/function_registry.h" +#include "runtime/type_registry.h" + +namespace google::api::expr::runtime { +namespace { + +using ::cel::TypeValue; +using ::cel::Value; +using ::cel::runtime_internal::GetEnumValueTable; + +std::vector MakeNamespaceCandidates(absl::string_view container) { + std::vector namespace_prefixes; + std::string prefix = ""; + namespace_prefixes.push_back(prefix); + auto container_elements = absl::StrSplit(container, '.'); + for (const auto& elem : container_elements) { + // Tolerate trailing / leading '.'. + if (elem.empty()) { + continue; + } + absl::StrAppend(&prefix, elem, "."); + // longest prefix first. + namespace_prefixes.insert(namespace_prefixes.begin(), prefix); + } + return namespace_prefixes; +} + +} // namespace + +Resolver::Resolver(absl::string_view container, + const cel::FunctionRegistry& function_registry, + const cel::TypeRegistry& type_registry, + const cel::TypeReflector& type_reflector, + bool resolve_qualified_type_identifiers) + : namespace_prefixes_(MakeNamespaceCandidates(container)), + enum_value_map_(GetEnumValueTable(type_registry)), + function_registry_(function_registry), + type_reflector_(type_reflector), + resolve_qualified_type_identifiers_(resolve_qualified_type_identifiers) {} + +std::vector Resolver::FullyQualifiedNames(absl::string_view name, + int64_t expr_id) const { + // TODO(issues/105): refactor the reference resolution into this method. + // and handle the case where this id is in the reference map as either a + // function name or identifier name. + std::vector names; + + auto prefixes = GetPrefixesFor(name); + names.reserve(prefixes.size()); + for (const auto& prefix : prefixes) { + std::string fully_qualified_name = absl::StrCat(prefix, name); + names.push_back(fully_qualified_name); + } + return names; +} + +absl::Span Resolver::GetPrefixesFor( + absl::string_view& name) const { + static const absl::NoDestructor kEmptyPrefix(""); + if (absl::StartsWith(name, ".")) { + name = name.substr(1); + return absl::MakeConstSpan(kEmptyPrefix.get(), 1); + } + return namespace_prefixes_; +} + +std::optional Resolver::FindConstant(absl::string_view name, + int64_t expr_id) const { + auto prefixes = GetPrefixesFor(name); + for (const auto& prefix : prefixes) { + std::string qualified_name = absl::StrCat(prefix, name); + // Attempt to resolve the fully qualified name to a known enum. + auto enum_entry = enum_value_map_->find(qualified_name); + if (enum_entry != enum_value_map_->end()) { + return enum_entry->second; + } + // Attempt to resolve the fully qualified name to a known type. + if (resolve_qualified_type_identifiers_) { + auto type_value = type_reflector_.FindType(qualified_name); + if (type_value.ok() && type_value->has_value()) { + return TypeValue(**type_value); + } + } + } + + if (!resolve_qualified_type_identifiers_ && !absl::StrContains(name, '.')) { + auto type_value = type_reflector_.FindType(name); + + if (type_value.ok() && type_value->has_value()) { + return TypeValue(**type_value); + } + } + return std::nullopt; +} + +std::vector Resolver::FindOverloads( + absl::string_view name, bool receiver_style, + const std::vector& types, int64_t expr_id) const { + // Resolve the fully qualified names and then search the function registry + // for possible matches. + std::vector funcs; + auto names = FullyQualifiedNames(name, expr_id); + for (auto it = names.begin(); it != names.end(); it++) { + // Only one set of overloads is returned along the namespace hierarchy as + // the function name resolution follows the same behavior as variable name + // resolution, meaning the most specific definition wins. This is different + // from how C++ namespaces work, as they will accumulate the overload set + // over the namespace hierarchy. + funcs = function_registry_.FindStaticOverloads(*it, receiver_style, types); + if (!funcs.empty()) { + return funcs; + } + } + return funcs; +} + +std::vector Resolver::FindOverloads( + absl::string_view name, bool receiver_style, size_t arity, + int64_t expr_id) const { + std::vector funcs; + auto prefixes = GetPrefixesFor(name); + for (const auto& prefix : prefixes) { + std::string qualified_name = absl::StrCat(prefix, name); + // Only one set of overloads is returned along the namespace hierarchy as + // the function name resolution follows the same behavior as variable name + // resolution, meaning the most specific definition wins. This is different + // from how C++ namespaces work, as they will accumulate the overload set + // over the namespace hierarchy. + funcs = function_registry_.FindStaticOverloadsByArity( + qualified_name, receiver_style, arity); + if (!funcs.empty()) { + return funcs; + } + } + return funcs; +} + +std::vector Resolver::FindLazyOverloads( + absl::string_view name, bool receiver_style, + const std::vector& types, int64_t expr_id) const { + // Resolve the fully qualified names and then search the function registry + // for possible matches. + std::vector funcs; + auto names = FullyQualifiedNames(name, expr_id); + for (const auto& name : names) { + funcs = function_registry_.FindLazyOverloads(name, receiver_style, types); + if (!funcs.empty()) { + return funcs; + } + } + return funcs; +} + +std::vector Resolver::FindLazyOverloads( + absl::string_view name, bool receiver_style, size_t arity, + int64_t expr_id) const { + std::vector funcs; + auto prefixes = GetPrefixesFor(name); + for (const auto& prefix : prefixes) { + std::string qualified_name = absl::StrCat(prefix, name); + funcs = function_registry_.FindLazyOverloadsByArity(name, receiver_style, + arity); + if (!funcs.empty()) { + return funcs; + } + } + return funcs; +} + +absl::StatusOr>> +Resolver::FindType(absl::string_view name, int64_t expr_id) const { + auto prefixes = GetPrefixesFor(name); + for (auto& prefix : prefixes) { + std::string qualified_name = absl::StrCat(prefix, name); + CEL_ASSIGN_OR_RETURN(auto maybe_type, + type_reflector_.FindType(qualified_name)); + if (maybe_type.has_value()) { + return std::make_pair(std::move(qualified_name), std::move(*maybe_type)); + } + } + return std::nullopt; +} + +} // namespace google::api::expr::runtime diff --git a/eval/compiler/resolver.h b/eval/compiler/resolver.h new file mode 100644 index 000000000..de7b22f26 --- /dev/null +++ b/eval/compiler/resolver.h @@ -0,0 +1,127 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_COMPILER_RESOLVER_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_COMPILER_RESOLVER_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "common/kind.h" +#include "common/type.h" +#include "common/type_reflector.h" +#include "common/value.h" +#include "runtime/function_overload_reference.h" +#include "runtime/function_registry.h" +#include "runtime/type_registry.h" + +namespace google::api::expr::runtime { + +// Resolver assists with finding functions and types from the associated +// registries within a container. +// +// container is used to construct the namespace lookup candidates. +// e.g. for "cel.dev" -> {"cel.dev.", "cel.", ""} +class Resolver { + public: + Resolver(absl::string_view container, + const cel::FunctionRegistry& function_registry, + const cel::TypeRegistry& type_registry, + const cel::TypeReflector& type_reflector, + bool resolve_qualified_type_identifiers = true); + + Resolver(const Resolver&) = delete; + Resolver& operator=(const Resolver&) = delete; + Resolver(Resolver&&) = delete; + Resolver& operator=(Resolver&&) = delete; + + ~Resolver() = default; + + // FindConstant will return an enum constant value or a type value if one + // exists for the given name. An empty handle will be returned if none exists. + // + // Since enums and type identifiers are specified as (potentially) qualified + // names within an expression, there is the chance that the name provided + // is a variable name which happens to collide with an existing enum or proto + // based type name. For this reason, within parsed only expressions, the + // constant should be treated as a value that can be shadowed by a runtime + // provided value. + absl::optional FindConstant(absl::string_view name, + int64_t expr_id) const; + + absl::StatusOr>> FindType( + absl::string_view name, int64_t expr_id) const; + + // FindLazyOverloads returns the set, possibly empty, of lazy overloads + // matching the given function signature. + std::vector FindLazyOverloads( + absl::string_view name, bool receiver_style, + const std::vector& types, int64_t expr_id = -1) const; + + std::vector FindLazyOverloads( + absl::string_view name, bool receiver_style, size_t arity, + int64_t expr_id = -1) const; + + // FindOverloads returns the set, possibly empty, of eager function overloads + // matching the given function signature. + std::vector FindOverloads( + absl::string_view name, bool receiver_style, + const std::vector& types, int64_t expr_id = -1) const; + + std::vector FindOverloads( + absl::string_view name, bool receiver_style, size_t arity, + int64_t expr_id = -1) const; + + // FullyQualifiedNames returns the set of fully qualified names which may be + // derived from the base_name within the specified expression container. + std::vector FullyQualifiedNames(absl::string_view base_name, + int64_t expr_id = -1) const; + + private: + absl::Span GetPrefixesFor(absl::string_view& name) const; + + std::vector namespace_prefixes_; + std::shared_ptr> + enum_value_map_; + const cel::FunctionRegistry& function_registry_; + const cel::TypeReflector& type_reflector_; + + bool resolve_qualified_type_identifiers_; +}; + +// ArgumentMatcher generates a function signature matcher for CelFunctions. +// TODO(issues/91): this is the same behavior as parsed exprs in the CPP +// evaluator (just check the right call style and number of arguments), but we +// should have enough type information in a checked expr to find a more +// specific candidate list. +inline std::vector ArgumentsMatcher(int argument_count) { + std::vector argument_matcher(argument_count); + for (int i = 0; i < argument_count; i++) { + argument_matcher[i] = cel::Kind::kAny; + } + return argument_matcher; +} + +} // namespace google::api::expr::runtime + +#endif // THIRD_PARTY_CEL_CPP_EVAL_COMPILER_RESOLVER_H_ diff --git a/eval/compiler/resolver_test.cc b/eval/compiler/resolver_test.cc new file mode 100644 index 000000000..212790b22 --- /dev/null +++ b/eval/compiler/resolver_test.cc @@ -0,0 +1,239 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/compiler/resolver.h" + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/types/span.h" +#include "common/value.h" +#include "eval/public/cel_function.h" +#include "eval/public/cel_function_registry.h" +#include "eval/public/cel_type_registry.h" +#include "eval/public/cel_value.h" +#include "eval/testutil/test_message.pb.h" +#include "internal/testing.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/message.h" + +namespace google::api::expr::runtime { + +namespace { + +using ::cel::IntValue; +using ::cel::TypeValue; +using ::testing::Eq; + +class FakeFunction : public CelFunction { + public: + explicit FakeFunction(const std::string& name) + : CelFunction(CelFunctionDescriptor{name, false, {}}) {} + + absl::Status Evaluate(absl::Span args, CelValue* result, + google::protobuf::Arena* arena) const override { + return absl::OkStatus(); + } +}; + +class ResolverTest : public testing::Test { + public: + ResolverTest() = default; + + protected: + CelTypeRegistry type_registry_; +}; + +TEST_F(ResolverTest, TestFullyQualifiedNames) { + CelFunctionRegistry func_registry; + Resolver resolver("google.api.expr", func_registry.InternalGetRegistry(), + type_registry_.InternalGetModernRegistry(), + type_registry_.GetTypeProvider()); + + auto names = resolver.FullyQualifiedNames("simple_name"); + std::vector expected_names( + {"google.api.expr.simple_name", "google.api.simple_name", + "google.simple_name", "simple_name"}); + EXPECT_THAT(names, Eq(expected_names)); +} + +TEST_F(ResolverTest, TestFullyQualifiedNamesPartiallyQualifiedName) { + CelFunctionRegistry func_registry; + Resolver resolver("google.api.expr", func_registry.InternalGetRegistry(), + type_registry_.InternalGetModernRegistry(), + type_registry_.GetTypeProvider()); + + auto names = resolver.FullyQualifiedNames("expr.simple_name"); + std::vector expected_names( + {"google.api.expr.expr.simple_name", "google.api.expr.simple_name", + "google.expr.simple_name", "expr.simple_name"}); + EXPECT_THAT(names, Eq(expected_names)); +} + +TEST_F(ResolverTest, TestFullyQualifiedNamesAbsoluteName) { + CelFunctionRegistry func_registry; + Resolver resolver("google.api.expr", func_registry.InternalGetRegistry(), + type_registry_.InternalGetModernRegistry(), + type_registry_.GetTypeProvider()); + + auto names = resolver.FullyQualifiedNames(".google.api.expr.absolute_name"); + EXPECT_THAT(names.size(), Eq(1)); + EXPECT_THAT(names[0], Eq("google.api.expr.absolute_name")); +} + +TEST_F(ResolverTest, TestFindConstantEnum) { + CelFunctionRegistry func_registry; + type_registry_.Register(TestMessage::TestEnum_descriptor()); + + Resolver resolver("google.api.expr.runtime.TestMessage", + func_registry.InternalGetRegistry(), + type_registry_.InternalGetModernRegistry(), + type_registry_.GetTypeProvider()); + + auto enum_value = resolver.FindConstant("TestEnum.TEST_ENUM_1", -1); + ASSERT_TRUE(enum_value); + ASSERT_TRUE(enum_value->Is()); + EXPECT_THAT(enum_value->GetInt().NativeValue(), Eq(1L)); + + enum_value = resolver.FindConstant( + ".google.api.expr.runtime.TestMessage.TestEnum.TEST_ENUM_2", -1); + ASSERT_TRUE(enum_value); + ASSERT_TRUE(enum_value->Is()); + EXPECT_THAT(enum_value->GetInt().NativeValue(), Eq(2L)); +} + +TEST_F(ResolverTest, TestFindConstantUnqualifiedType) { + CelFunctionRegistry func_registry; + Resolver resolver("cel", func_registry.InternalGetRegistry(), + type_registry_.InternalGetModernRegistry(), + type_registry_.GetTypeProvider()); + + auto type_value = resolver.FindConstant("int", -1); + EXPECT_TRUE(type_value); + EXPECT_TRUE(type_value->Is()); + EXPECT_THAT(type_value->GetType().name(), Eq("int")); +} + +TEST_F(ResolverTest, TestFindConstantFullyQualifiedType) { + google::protobuf::LinkMessageReflection(); + CelFunctionRegistry func_registry; + Resolver resolver("cel", func_registry.InternalGetRegistry(), + type_registry_.InternalGetModernRegistry(), + type_registry_.GetTypeProvider()); + + auto type_value = + resolver.FindConstant(".google.api.expr.runtime.TestMessage", -1); + ASSERT_TRUE(type_value); + ASSERT_TRUE(type_value->Is()); + EXPECT_THAT(type_value->GetType().name(), + Eq("google.api.expr.runtime.TestMessage")); +} + +TEST_F(ResolverTest, TestFindConstantQualifiedTypeDisabled) { + CelFunctionRegistry func_registry; + Resolver resolver("", func_registry.InternalGetRegistry(), + type_registry_.InternalGetModernRegistry(), + type_registry_.GetTypeProvider(), false); + auto type_value = + resolver.FindConstant(".google.api.expr.runtime.TestMessage", -1); + EXPECT_FALSE(type_value); +} + +TEST_F(ResolverTest, FindTypeBySimpleName) { + CelFunctionRegistry func_registry; + Resolver resolver("google.api.expr.runtime", + func_registry.InternalGetRegistry(), + type_registry_.InternalGetModernRegistry(), + type_registry_.GetTypeProvider()); + + ASSERT_OK_AND_ASSIGN(auto type, resolver.FindType("TestMessage", -1)); + EXPECT_TRUE(type.has_value()); + EXPECT_EQ(type->second.name(), "google.api.expr.runtime.TestMessage"); +} + +TEST_F(ResolverTest, FindTypeByQualifiedName) { + CelFunctionRegistry func_registry; + Resolver resolver("google.api.expr.runtime", + func_registry.InternalGetRegistry(), + type_registry_.InternalGetModernRegistry(), + type_registry_.GetTypeProvider()); + + ASSERT_OK_AND_ASSIGN( + auto type, resolver.FindType(".google.api.expr.runtime.TestMessage", -1)); + ASSERT_TRUE(type.has_value()); + EXPECT_EQ(type->second.name(), "google.api.expr.runtime.TestMessage"); +} + +TEST_F(ResolverTest, TestFindDescriptorNotFound) { + CelFunctionRegistry func_registry; + Resolver resolver("google.api.expr.runtime", + func_registry.InternalGetRegistry(), + type_registry_.InternalGetModernRegistry(), + type_registry_.GetTypeProvider()); + + ASSERT_OK_AND_ASSIGN(auto type, resolver.FindType("UndefinedMessage", -1)); + EXPECT_FALSE(type.has_value()) << type->second; +} + +TEST_F(ResolverTest, TestFindOverloads) { + CelFunctionRegistry func_registry; + auto status = + func_registry.Register(std::make_unique("fake_func")); + ASSERT_OK(status); + status = func_registry.Register( + std::make_unique("cel.fake_ns_func")); + ASSERT_OK(status); + + Resolver resolver("cel", func_registry.InternalGetRegistry(), + type_registry_.InternalGetModernRegistry(), + type_registry_.GetTypeProvider()); + + auto overloads = + resolver.FindOverloads("fake_func", false, ArgumentsMatcher(0)); + EXPECT_THAT(overloads.size(), Eq(1)); + EXPECT_THAT(overloads[0].descriptor.name(), Eq("fake_func")); + + overloads = + resolver.FindOverloads("fake_ns_func", false, ArgumentsMatcher(0)); + EXPECT_THAT(overloads.size(), Eq(1)); + EXPECT_THAT(overloads[0].descriptor.name(), Eq("cel.fake_ns_func")); +} + +TEST_F(ResolverTest, TestFindLazyOverloads) { + CelFunctionRegistry func_registry; + auto status = func_registry.RegisterLazyFunction( + CelFunctionDescriptor{"fake_lazy_func", false, {}}); + ASSERT_OK(status); + status = func_registry.RegisterLazyFunction( + CelFunctionDescriptor{"cel.fake_lazy_ns_func", false, {}}); + ASSERT_OK(status); + + Resolver resolver("cel", func_registry.InternalGetRegistry(), + type_registry_.InternalGetModernRegistry(), + type_registry_.GetTypeProvider()); + + auto overloads = + resolver.FindLazyOverloads("fake_lazy_func", false, ArgumentsMatcher(0)); + EXPECT_THAT(overloads.size(), Eq(1)); + + overloads = resolver.FindLazyOverloads("fake_lazy_ns_func", false, + ArgumentsMatcher(0)); + EXPECT_THAT(overloads.size(), Eq(1)); +} + +} // namespace + +} // namespace google::api::expr::runtime diff --git a/eval/eval/BUILD b/eval/eval/BUILD index 409a52f00..44c7ded79 100644 --- a/eval/eval/BUILD +++ b/eval/eval/BUILD @@ -1,11 +1,37 @@ +# Copyright 2017 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + # This package contains implementation of expression evaluator # internals. package(default_visibility = ["//visibility:public"]) -licenses(["notice"]) # Apache 2.0 +licenses(["notice"]) exports_files(["LICENSE"]) +package_group( + name = "internal_eval_visibility", + packages = [ + "//eval/...", + "//extensions", + "//runtime/internal", + ], +) + cc_library( name = "evaluator_core", srcs = [ @@ -15,59 +41,150 @@ cc_library( "evaluator_core.h", ], deps = [ - ":attribute_trail", ":attribute_utility", - "//base:status_macros", - "//eval/public:activation", - "//eval/public:cel_attribute", - "//eval/public:cel_expression", - "//eval/public:cel_value", - "//eval/public:unknown_attribute_set", + ":comprehension_slots", + ":evaluator_stack", + ":iterator_stack", + "//base:data", + "//common:native_type", + "//common:value", + "//runtime", + "//runtime:activation_interface", + "//runtime:runtime_options", + "//runtime/internal:activation_attribute_matcher_access", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], ) cc_library( - name = "expression_step_base", + name = "cel_expression_flat_impl", + srcs = [ + "cel_expression_flat_impl.cc", + ], hdrs = [ - "expression_step_base.h", + "cel_expression_flat_impl.h", ], deps = [ + ":attribute_trail", + ":comprehension_slots", + ":direct_expression_step", ":evaluator_core", - "//eval/public:activation", + "//common:native_type", + "//common:value", + "//eval/internal:adapter_activation_impl", + "//eval/internal:interop", + "//eval/public:base_activation", "//eval/public:cel_expression", "//eval/public:cel_value", + "//internal:casts", + "//internal:status_macros", + "//runtime/internal:runtime_env", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], ) cc_library( - name = "const_value_step", + name = "comprehension_slots", + hdrs = [ + "comprehension_slots.h", + ], + deps = [ + ":attribute_trail", + "//common:value", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:fixed_array", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/types:optional", + ], +) + +cc_test( + name = "comprehension_slots_test", srcs = [ - "const_value_step.cc", + "comprehension_slots_test.cc", + ], + deps = [ + ":attribute_trail", + ":comprehension_slots", + "//base:attributes", + "//base:data", + "//common:memory", + "//common:value", + "//internal:testing", + ], +) + +cc_library( + name = "evaluator_stack", + srcs = [ + "evaluator_stack.cc", ], + hdrs = [ + "evaluator_stack.h", + ], + deps = [ + ":attribute_trail", + "//common:value", + "//internal:align", + "//internal:new", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:dynamic_annotations", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", + "@com_google_absl//absl/meta:type_traits", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + ], +) + +cc_test( + name = "evaluator_stack_test", + srcs = [ + "evaluator_stack_test.cc", + ], + deps = [ + ":evaluator_stack", + "//base:attributes", + "//common:value", + "//internal:testing", + ], +) + +cc_library( + name = "expression_step_base", + hdrs = [ + "expression_step_base.h", + ], + deps = [":evaluator_core"], +) + +cc_library( + name = "const_value_step", hdrs = [ "const_value_step.h", ], deps = [ + ":compiler_constant_step", + ":direct_expression_step", ":evaluator_core", - ":expression_step_base", - "//eval/public:activation", - "//eval/public:cel_expression", - "//eval/public:cel_value", - "//eval/public/structs:cel_proto_wrapper", + "//common:value", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", - "@com_google_protobuf//:protobuf", ], ) @@ -80,15 +197,46 @@ cc_library( "container_access_step.h", ], deps = [ + ":attribute_trail", + ":attribute_utility", + ":direct_expression_step", ":evaluator_core", ":expression_step_base", - "//eval/public:activation", - "//eval/public:cel_value", - "//eval/public:unknown_attribute_set", + "//base:attributes", + "//common:casting", + "//common:expr", + "//common:kind", + "//common:value", + "//common:value_kind", + "//eval/internal:errors", + "//internal:number", + "//internal:status_macros", + "//runtime/internal:errors", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_protobuf//:protobuf", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "regex_match_step", + srcs = ["regex_match_step.cc"], + hdrs = ["regex_match_step.h"], + deps = [ + ":attribute_trail", + ":direct_expression_step", + ":evaluator_core", + ":expression_step_base", + "//common:value", + "//internal:status_macros", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:cord", + "@com_google_absl//absl/strings:string_view", + "@com_googlesource_code_re2//:re2", ], ) @@ -102,14 +250,18 @@ cc_library( ], deps = [ ":attribute_trail", + ":comprehension_slots", + ":direct_expression_step", ":evaluator_core", ":expression_step_base", - "//eval/public:activation", - "//eval/public:cel_value", - "//eval/public:unknown_attribute_set", + "//common:expr", + "//common:value", + "//eval/internal:errors", + "//internal:status_macros", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_protobuf//:protobuf", ], ) @@ -123,25 +275,29 @@ cc_library( ], deps = [ ":attribute_trail", + ":direct_expression_step", ":evaluator_core", - ":expression_build_warning", ":expression_step_base", - "//base:status_macros", - "//eval/public:activation", - "//eval/public:cel_builtins", - "//eval/public:cel_function", - "//eval/public:cel_function_provider", - "//eval/public:cel_function_registry", - "//eval/public:cel_value", - "//eval/public:unknown_attribute_set", - "//eval/public:unknown_function_result_set", - "//eval/public:unknown_set", + "//common:casting", + "//common:expr", + "//common:function_descriptor", + "//common:kind", + "//common:value", + "//common:value_kind", + "//eval/internal:errors", + "//internal:status_macros", + "//runtime:activation_interface", + "//runtime:function", + "//runtime:function_overload_reference", + "//runtime:function_provider", + "//runtime:function_registry", + "//runtime/internal:errors", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", - "@com_google_protobuf//:protobuf", ], ) @@ -154,15 +310,24 @@ cc_library( "select_step.h", ], deps = [ + ":attribute_trail", + ":direct_expression_step", ":evaluator_core", ":expression_step_base", - "//eval/public:activation", - "//eval/public:cel_value", - "//eval/public/containers:field_access", - "//eval/public/containers:field_backed_list_impl", - "//eval/public/containers:field_backed_map_impl", + "//common:expr", + "//common:value", + "//common:value_kind", + "//eval/internal:errors", + "//internal:status_macros", + "//runtime:runtime_options", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_protobuf//:protobuf", ], ) @@ -175,15 +340,19 @@ cc_library( "create_list_step.h", ], deps = [ + ":attribute_trail", + ":attribute_utility", + ":direct_expression_step", ":evaluator_core", ":expression_step_base", - "//eval/public:activation", - "//eval/public:cel_value", - "//eval/public/containers:container_backed_list_impl", + "//common:casting", + "//common:expr", + "//common:value", + "//internal:status_macros", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_absl//absl/types:optional", ], ) @@ -196,15 +365,42 @@ cc_library( "create_struct_step.h", ], deps = [ + ":attribute_trail", + ":direct_expression_step", ":evaluator_core", ":expression_step_base", - "//eval/public/containers:container_backed_map_impl", - "//eval/public/containers:field_access", - "//eval/public/structs:cel_proto_wrapper", + "//common:casting", + "//common:value", + "//internal:status_macros", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_absl//absl/types:optional", + ], +) + +cc_library( + name = "create_map_step", + srcs = [ + "create_map_step.cc", + ], + hdrs = [ + "create_map_step.h", + ], + deps = [ + ":attribute_trail", + ":direct_expression_step", + ":evaluator_core", + ":expression_step_base", + "//common:casting", + "//common:value", + "//internal:status_macros", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", ], ) @@ -219,11 +415,12 @@ cc_library( deps = [ ":evaluator_core", ":expression_step_base", - "//eval/public:activation", - "//eval/public:cel_value", + "//common:value", + "//eval/internal:errors", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_absl//absl/types:optional", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], ) @@ -236,16 +433,74 @@ cc_library( "logic_step.h", ], deps = [ + ":attribute_trail", + ":direct_expression_step", ":evaluator_core", ":expression_step_base", - "//eval/public:activation", - "//eval/public:cel_builtins", - "//eval/public:cel_function", - "//eval/public:cel_value", - "//eval/public:unknown_attribute_set", + "//base:builtins", + "//common:casting", + "//common:value", + "//common:value_kind", + "//eval/internal:errors", + "//internal:status_macros", + "//runtime/internal:errors", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "equality_steps", + srcs = [ + "equality_steps.cc", + ], + hdrs = [ + "equality_steps.h", + ], + deps = [ + ":attribute_trail", + ":direct_expression_step", + ":evaluator_core", + ":expression_step_base", + "//base:builtins", + "//common:value", + "//common:value_kind", + "//internal:number", + "//internal:status_macros", + "//runtime/internal:errors", + "//runtime/standard:equality_functions", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], +) + +cc_test( + name = "equality_steps_test", + srcs = [ + "equality_steps_test.cc", + ], + deps = [ + ":attribute_trail", + ":direct_expression_step", + ":equality_steps", + ":evaluator_core", + "//base:attributes", + "//common:value", + "//common:value_kind", + "//common:value_testing", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//internal:testing_message_factory", + "//runtime:activation", + "//runtime:runtime_options", + "//runtime/internal:runtime_type_provider", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_protobuf//:protobuf", ], ) @@ -259,16 +514,22 @@ cc_library( ], deps = [ ":attribute_trail", + ":comprehension_slots", + ":direct_expression_step", ":evaluator_core", ":expression_step_base", - "//base:status_macros", - "//eval/public:activation", - "//eval/public:cel_attribute", - "//eval/public:cel_function", - "//eval/public:cel_value", + "//base:attributes", + "//common:casting", + "//common:value", + "//common:value_kind", + "//eval/internal:errors", + "//internal:status_macros", + "@com_google_absl//absl/base", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_absl//absl/status:statusor", ], ) @@ -279,19 +540,38 @@ cc_test( "comprehension_step_test.cc", ], deps = [ + ":attribute_trail", + ":cel_expression_flat_impl", + ":comprehension_slots", ":comprehension_step", + ":const_value_step", + ":direct_expression_step", ":evaluator_core", + ":expression_step_base", ":ident_step", - "//base:status_macros", + "//base:data", + "//common:expr", + "//common:value", + "//common:value_testing", + "//eval/public:activation", "//eval/public:cel_attribute", - "//eval/public:cel_options", "//eval/public:cel_value", "//eval/public/structs:cel_proto_wrapper", + "//internal:status_macros", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//internal:testing_message_factory", + "//runtime:activation", + "//runtime:runtime_options", + "//runtime/internal:runtime_env_testing", + "//runtime/internal:runtime_type_provider", + "@com_google_absl//absl/memory", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", - "@com_google_googletest//:gtest_main", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", + "@com_google_protobuf//:struct_cc_proto", ], ) @@ -302,31 +582,25 @@ cc_test( "evaluator_core_test.cc", ], deps = [ - ":attribute_trail", + ":cel_expression_flat_impl", ":evaluator_core", - "//base:status_macros", - "//eval/compiler:flat_expr_builder", + "//base:data", + "//common:value", + "//eval/compiler:cel_expression_builder_flat_impl", + "//eval/internal:interop", + "//eval/public:activation", "//eval/public:builtin_func_registrar", - "//eval/public:cel_attribute", "//eval/public:cel_value", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", - "@com_google_googletest//:gtest_main", - ], -) - -cc_test( - name = "const_value_step_test", - size = "small", - srcs = [ - "const_value_step_test.cc", - ], - deps = [ - ":const_value_step", - ":evaluator_core", - "//base:status_macros", - "@com_google_absl//absl/status:statusor", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", - "@com_google_googletest//:gtest_main", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//internal:testing_message_factory", + "//runtime:activation", + "//runtime:runtime_options", + "//runtime/internal:runtime_env_testing", + "//runtime/internal:runtime_type_provider", + "@com_google_absl//absl/status", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_protobuf//:protobuf", ], ) @@ -337,16 +611,56 @@ cc_test( "container_access_step_test.cc", ], deps = [ + ":cel_expression_flat_impl", ":container_access_step", + ":direct_expression_step", + ":evaluator_core", ":ident_step", - "//base:status_macros", + "//base:builtins", + "//base:data", + "//common:ast", + "//common:expr", + "//eval/public:activation", "//eval/public:cel_attribute", - "//eval/public:cel_builtins", + "//eval/public:cel_expr_builder_factory", + "//eval/public:cel_expression", + "//eval/public:cel_options", "//eval/public:cel_value", + "//eval/public:unknown_set", "//eval/public/containers:container_backed_list_impl", "//eval/public/containers:container_backed_map_impl", "//eval/public/structs:cel_proto_wrapper", - "@com_google_googletest//:gtest_main", + "//eval/public/testing:matchers", + "//internal:testing", + "//parser", + "//runtime/internal:runtime_env", + "//runtime/internal:runtime_env_testing", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/status", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_protobuf//:protobuf", + "@com_google_protobuf//:struct_cc_proto", + ], +) + +cc_test( + name = "regex_match_step_test", + size = "small", + srcs = [ + "regex_match_step_test.cc", + ], + deps = [ + ":regex_match_step", + "//eval/public:activation", + "//eval/public:builtin_func_registrar", + "//eval/public:cel_expr_builder_factory", + "//eval/public:cel_options", + "//internal:testing", + "//parser", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], ) @@ -358,11 +672,26 @@ cc_test( "ident_step_test.cc", ], deps = [ + ":attribute_trail", + ":cel_expression_flat_impl", ":evaluator_core", ":ident_step", - "//base:status_macros", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", - "@com_google_googletest//:gtest_main", + "//base:data", + "//common:casting", + "//common:memory", + "//common:value", + "//eval/public:activation", + "//eval/public:cel_attribute", + "//eval/public:cel_value", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//internal:testing_message_factory", + "//runtime:activation", + "//runtime:runtime_options", + "//runtime/internal:runtime_env_testing", + "//runtime/internal:runtime_type_provider", + "@com_google_absl//absl/status", + "@com_google_protobuf//:protobuf", ], ) @@ -373,22 +702,41 @@ cc_test( "function_step_test.cc", ], deps = [ + ":cel_expression_flat_impl", + ":const_value_step", + ":direct_expression_step", ":evaluator_core", - ":expression_build_warning", ":function_step", ":ident_step", - "//base:status_macros", + "//base:builtins", + "//base:data", + "//common:constant", + "//common:expr", + "//common:kind", + "//common:value", + "//eval/internal:interop", + "//eval/public:activation", "//eval/public:cel_attribute", "//eval/public:cel_function", + "//eval/public:cel_function_registry", "//eval/public:cel_options", "//eval/public:cel_value", - "//eval/public:unknown_function_result_set", + "//eval/public:portable_cel_function_adapter", "//eval/public/structs:cel_proto_wrapper", + "//eval/public/testing:matchers", "//eval/testutil:test_message_cc_proto", - "@com_google_absl//absl/memory", + "//internal:testing", + "//runtime:function_overload_reference", + "//runtime:function_registry", + "//runtime:runtime_options", + "//runtime:standard_functions", + "//runtime/internal:runtime_env_testing", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", - "@com_google_googletest//:gtest_main", + "@com_google_absl//absl/types:span", + "@com_google_protobuf//:protobuf", ], ) @@ -399,12 +747,39 @@ cc_test( "logic_step_test.cc", ], deps = [ + ":attribute_trail", + ":cel_expression_flat_impl", + ":const_value_step", + ":direct_expression_step", + ":evaluator_core", ":ident_step", ":logic_step", - "//base:status_macros", + "//base:attributes", + "//base:data", + "//common:casting", + "//common:expr", + "//common:unknown", + "//common:value", + "//eval/public:activation", + "//eval/public:cel_attribute", + "//eval/public:cel_value", "//eval/public:unknown_attribute_set", "//eval/public:unknown_set", - "@com_google_googletest//:gtest_main", + "//internal:status_macros", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//internal:testing_message_factory", + "//runtime:activation", + "//runtime:runtime_options", + "//runtime/internal:runtime_env", + "//runtime/internal:runtime_env_testing", + "//runtime/internal:runtime_type_provider", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_protobuf//:protobuf", ], ) @@ -415,18 +790,49 @@ cc_test( "select_step_test.cc", ], deps = [ + ":attribute_trail", + ":cel_expression_flat_impl", + ":const_value_step", + ":evaluator_core", ":ident_step", ":select_step", - "//base:status_macros", + "//base:attributes", + "//base:data", + "//common:casting", + "//common:expr", + "//common:legacy_value", + "//common:value", + "//common:value_testing", + "//eval/public:activation", "//eval/public:cel_attribute", - "//eval/public:unknown_attribute_set", + "//eval/public:cel_value", "//eval/public/containers:container_backed_map_impl", "//eval/public/structs:cel_proto_wrapper", + "//eval/public/structs:legacy_type_adapter", + "//eval/public/structs:trivial_legacy_type_info", + "//eval/public/testing:matchers", + "//eval/testutil:test_extensions_cc_proto", "//eval/testutil:test_message_cc_proto", - "//testutil:util", + "//extensions/protobuf:value", + "//internal:proto_matchers", + "//internal:status_macros", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//internal:testing_message_factory", + "//runtime:activation", + "//runtime:runtime_options", + "//runtime/internal:runtime_env", + "//runtime/internal:runtime_env_testing", + "//runtime/internal:runtime_type_provider", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/status:statusor", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", - "@com_google_googletest//:gtest_main", + "@com_google_absl//absl/strings", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto", + "@com_google_protobuf//:wrappers_cc_proto", ], ) @@ -437,16 +843,40 @@ cc_test( "create_list_step_test.cc", ], deps = [ + ":attribute_trail", + ":cel_expression_flat_impl", ":const_value_step", ":create_list_step", + ":direct_expression_step", + ":evaluator_core", ":ident_step", - "//base:status_macros", + "//base:attributes", + "//base:data", + "//common:casting", + "//common:expr", + "//common:value", + "//common:value_testing", + "//eval/internal:interop", "//eval/public:activation", "//eval/public:cel_attribute", + "//eval/public:cel_value", "//eval/public:unknown_attribute_set", + "//eval/public/testing:matchers", + "//internal:status_macros", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//internal:testing_message_factory", + "//runtime:activation", + "//runtime:runtime_options", + "//runtime/internal:runtime_env", + "//runtime/internal:runtime_env_testing", + "//runtime/internal:runtime_type_provider", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest_main", + "@com_google_protobuf//:protobuf", ], ) @@ -457,44 +887,68 @@ cc_test( "create_struct_step_test.cc", ], deps = [ + ":cel_expression_flat_impl", ":create_struct_step", + ":direct_expression_step", + ":evaluator_core", ":ident_step", - "//base:status_macros", + "//base:data", + "//common:expr", + "//eval/public:activation", + "//eval/public:cel_type_registry", + "//eval/public:cel_value", + "//eval/public:unknown_set", "//eval/public/containers:container_backed_list_impl", "//eval/public/containers:container_backed_map_impl", "//eval/public/structs:cel_proto_wrapper", "//eval/testutil:test_message_cc_proto", - "//testutil:util", + "//internal:proto_matchers", + "//internal:status_macros", + "//internal:testing", + "//runtime:runtime_options", + "//runtime/internal:runtime_env", + "//runtime/internal:runtime_env_testing", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "expression_build_warning", - srcs = [ - "expression_build_warning.cc", - ], - hdrs = [ - "expression_build_warning.h", - ], - deps = [ - "@com_google_absl//absl/status", + "@com_google_absl//absl/types:span", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_protobuf//:protobuf", ], ) cc_test( - name = "expression_build_warning_test", + name = "create_map_step_test", size = "small", srcs = [ - "expression_build_warning_test.cc", + "create_map_step_test.cc", ], deps = [ - ":expression_build_warning", + ":cel_expression_flat_impl", + ":create_map_step", + ":direct_expression_step", + ":evaluator_core", + ":ident_step", + "//base:data", + "//common:expr", + "//eval/public:activation", + "//eval/public:cel_value", + "//eval/public:unknown_set", + "//eval/testutil:test_message_cc_proto", + "//internal:status_macros", + "//internal:testing", + "//runtime:runtime_options", + "//runtime/internal:runtime_env", + "//runtime/internal:runtime_env_testing", + "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/status", - "@com_google_googletest//:gtest_main", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_protobuf//:protobuf", ], ) @@ -503,16 +957,9 @@ cc_library( srcs = ["attribute_trail.cc"], hdrs = ["attribute_trail.h"], deps = [ - "//eval/public:activation", - "//eval/public:cel_attribute", - "//eval/public:cel_expression", - "//eval/public:cel_value", - "//eval/public:unknown_attribute_set", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", + "//base:attributes", "@com_google_absl//absl/types:optional", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", - "@com_google_protobuf//:protobuf", + "@com_google_absl//absl/utility", ], ) @@ -526,8 +973,8 @@ cc_test( ":attribute_trail", "//eval/public:cel_attribute", "//eval/public:cel_value", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", - "@com_google_googletest//:gtest_main", + "//internal:testing", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], ) @@ -537,17 +984,21 @@ cc_library( hdrs = ["attribute_utility.h"], deps = [ ":attribute_trail", - "//eval/public:activation", - "//eval/public:cel_attribute", - "//eval/public:cel_expression", - "//eval/public:cel_value", - "//eval/public:unknown_attribute_set", - "//eval/public:unknown_set", - "@com_google_absl//absl/status", + "//base:attributes", + "//base:function_result", + "//base:function_result_set", + "//base/internal:unknown_set", + "//common:casting", + "//common:function_descriptor", + "//common:unknown", + "//common:value", + "//eval/internal:errors", + "//internal:status_macros", + "//runtime/internal:attribute_matcher", + "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", - "@com_google_protobuf//:protobuf", ], ) @@ -558,13 +1009,19 @@ cc_test( "attribute_utility_test.cc", ], deps = [ + ":attribute_trail", ":attribute_utility", + "//base:attributes", + "//common:unknown", + "//common:value", "//eval/public:cel_attribute", "//eval/public:cel_value", "//eval/public:unknown_attribute_set", "//eval/public:unknown_set", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", - "@com_google_googletest//:gtest_main", + "//internal:testing", + "//runtime/internal:attribute_matcher", + "@com_google_absl//absl/types:span", + "@com_google_protobuf//:protobuf", ], ) @@ -577,16 +1034,16 @@ cc_library( "ternary_step.h", ], deps = [ + ":attribute_trail", + ":direct_expression_step", ":evaluator_core", ":expression_step_base", - "//eval/public:activation", - "//eval/public:cel_builtins", - "//eval/public:cel_function", - "//eval/public:cel_value", - "//eval/public:unknown_attribute_set", + "//base:builtins", + "//common:value", + "//eval/internal:errors", + "//internal:status_macros", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", ], ) @@ -597,11 +1054,232 @@ cc_test( "ternary_step_test.cc", ], deps = [ + ":attribute_trail", + ":cel_expression_flat_impl", + ":const_value_step", + ":direct_expression_step", + ":evaluator_core", ":ident_step", ":ternary_step", - "//base:status_macros", + "//base:attributes", + "//base:data", + "//common:casting", + "//common:expr", + "//common:value", + "//eval/public:activation", + "//eval/public:cel_value", "//eval/public:unknown_attribute_set", "//eval/public:unknown_set", - "@com_google_googletest//:gtest_main", + "//internal:status_macros", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//internal:testing_message_factory", + "//runtime:activation", + "//runtime:runtime_options", + "//runtime/internal:runtime_env", + "//runtime/internal:runtime_env_testing", + "//runtime/internal:runtime_type_provider", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/status", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "shadowable_value_step", + srcs = ["shadowable_value_step.cc"], + hdrs = ["shadowable_value_step.h"], + deps = [ + ":attribute_trail", + ":direct_expression_step", + ":evaluator_core", + ":expression_step_base", + "//common:value", + "//internal:status_macros", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + ], +) + +cc_test( + name = "shadowable_value_step_test", + size = "small", + srcs = ["shadowable_value_step_test.cc"], + deps = [ + ":cel_expression_flat_impl", + ":evaluator_core", + ":shadowable_value_step", + "//base:data", + "//common:value", + "//eval/internal:interop", + "//eval/public:activation", + "//eval/public:cel_value", + "//internal:status_macros", + "//internal:testing", + "//runtime:runtime_options", + "//runtime/internal:runtime_env", + "//runtime/internal:runtime_env_testing", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/status:statusor", + ], +) + +cc_library( + name = "compiler_constant_step", + srcs = ["compiler_constant_step.cc"], + hdrs = ["compiler_constant_step.h"], + deps = [ + ":attribute_trail", + ":direct_expression_step", + ":evaluator_core", + ":expression_step_base", + "//common:native_type", + "//common:value", + "@com_google_absl//absl/status", + ], +) + +cc_test( + name = "compiler_constant_step_test", + srcs = ["compiler_constant_step_test.cc"], + deps = [ + ":compiler_constant_step", + ":evaluator_core", + "//common:native_type", + "//common:value", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//internal:testing_message_factory", + "//runtime:activation", + "//runtime:runtime_options", + "//runtime/internal:runtime_type_provider", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "lazy_init_step", + srcs = ["lazy_init_step.cc"], + hdrs = ["lazy_init_step.h"], + deps = [ + ":attribute_trail", + ":comprehension_slots", + ":direct_expression_step", + ":evaluator_core", + ":expression_step_base", + "//common:value", + "//internal:status_macros", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/status", + "@com_google_cel_spec//proto/cel/expr:value_cc_proto", + ], +) + +cc_test( + name = "lazy_init_step_test", + srcs = ["lazy_init_step_test.cc"], + deps = [ + ":const_value_step", + ":evaluator_core", + ":lazy_init_step", + "//base:data", + "//common:value", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//internal:testing_message_factory", + "//runtime:activation", + "//runtime:runtime_options", + "//runtime/internal:runtime_type_provider", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "direct_expression_step", + srcs = ["direct_expression_step.cc"], + hdrs = ["direct_expression_step.h"], + deps = [ + ":attribute_trail", + ":evaluator_core", + "//common:native_type", + "//common:value", + "//internal:status_macros", + "@com_google_absl//absl/status", + "@com_google_absl//absl/types:optional", + ], +) + +cc_library( + name = "trace_step", + hdrs = ["trace_step.h"], + deps = [ + ":attribute_trail", + ":direct_expression_step", + ":evaluator_core", + "//common:native_type", + "//common:value", + "//internal:status_macros", + "@com_google_absl//absl/status", + "@com_google_absl//absl/types:optional", + ], +) + +cc_library( + name = "optional_or_step", + srcs = ["optional_or_step.cc"], + hdrs = ["optional_or_step.h"], + deps = [ + ":attribute_trail", + ":direct_expression_step", + ":evaluator_core", + ":expression_step_base", + ":jump_step", + "//common:casting", + "//common:value", + "//internal:status_macros", + "//runtime/internal:errors", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + ], +) + +cc_test( + name = "optional_or_step_test", + srcs = ["optional_or_step_test.cc"], + deps = [ + ":attribute_trail", + ":const_value_step", + ":direct_expression_step", + ":evaluator_core", + ":optional_or_step", + "//common:casting", + "//common:value", + "//common:value_kind", + "//common:value_testing", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//internal:testing_message_factory", + "//runtime:activation", + "//runtime:runtime_options", + "//runtime/internal:errors", + "//runtime/internal:runtime_type_provider", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "iterator_stack", + hdrs = ["iterator_stack.h"], + deps = [ + "//common:value", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/log:absl_check", ], ) diff --git a/eval/eval/attribute_trail.cc b/eval/eval/attribute_trail.cc index b019f83d2..6b5db896e 100644 --- a/eval/eval/attribute_trail.cc +++ b/eval/eval/attribute_trail.cc @@ -1,25 +1,28 @@ #include "eval/eval/attribute_trail.h" -#include "absl/status/status.h" -#include "eval/public/cel_value.h" -#include "absl/status/statusor.h" +#include +#include +#include +#include +#include + +#include "base/attribute.h" + +namespace google::api::expr::runtime { -namespace google { -namespace api { -namespace expr { -namespace runtime { // Creates AttributeTrail with attribute path incremented by "qualifier". -AttributeTrail AttributeTrail::Step(CelAttributeQualifier qualifier, - google::protobuf::Arena* arena) const { +AttributeTrail AttributeTrail::Step(cel::AttributeQualifier qualifier) const { // Cannot continue void trail if (empty()) return AttributeTrail(); - std::vector qualifiers = attribute_->qualifier_path(); - qualifiers.push_back(qualifier); - return AttributeTrail(google::protobuf::Arena::Create( - arena, attribute_->variable(), std::move(qualifiers))); + std::vector qualifiers; + qualifiers.reserve(attribute_->qualifier_path().size() + 1); + std::copy_n(attribute_->qualifier_path().begin(), + attribute_->qualifier_path().size(), + std::back_inserter(qualifiers)); + qualifiers.push_back(std::move(qualifier)); + return AttributeTrail(cel::Attribute(std::string(attribute_->variable_name()), + std::move(qualifiers))); } -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google + +} // namespace google::api::expr::runtime diff --git a/eval/eval/attribute_trail.h b/eval/eval/attribute_trail.h index c2aefc6cb..576d0be34 100644 --- a/eval/eval/attribute_trail.h +++ b/eval/eval/attribute_trail.h @@ -2,62 +2,63 @@ #define THIRD_PARTY_CEL_CPP_EVAL_EVAL_ATTRIBUTE_TRAIL_H_ #include -#include +#include -#include "google/api/expr/v1alpha1/syntax.pb.h" -#include "google/protobuf/arena.h" #include "absl/types/optional.h" -#include "eval/public/activation.h" -#include "eval/public/cel_attribute.h" -#include "eval/public/cel_expression.h" -#include "eval/public/cel_value.h" -#include "eval/public/unknown_attribute_set.h" +#include "absl/utility/utility.h" +#include "base/attribute.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace google::api::expr::runtime { // AttributeTrail reflects current attribute path. -// It is functionally similar to CelAttribute, yet intended to have better +// It is functionally similar to cel::Attribute, yet intended to have better // complexity on attribute path increment operations. // TODO(issues/41) Current AttributeTrail implementation is equivalent to -// CelAttribute - improve it. -// Intended to be used in conjunction with CelValue, describing the attribute +// cel::Attribute - improve it. +// Intended to be used in conjunction with cel::Value, describing the attribute // value originated from. // Empty AttributeTrail denotes object with attribute path not defined // or supported. class AttributeTrail { public: - AttributeTrail() : attribute_(nullptr) {} - AttributeTrail(google::api::expr::v1alpha1::Expr root, google::protobuf::Arena* arena) - : AttributeTrail(google::protobuf::Arena::Create( - arena, root, std::vector())) {} + AttributeTrail() : attribute_(absl::nullopt) {} + + explicit AttributeTrail(std::string variable_name) + : attribute_(absl::in_place, std::move(variable_name)) {} + + explicit AttributeTrail(cel::Attribute attribute) + : attribute_(std::move(attribute)) {} + + // NOLINTNEXTLINE(google-explicit-constructor) + AttributeTrail(absl::nullopt_t) : AttributeTrail() {} + + AttributeTrail(const AttributeTrail&) = default; + AttributeTrail& operator=(const AttributeTrail&) = default; + AttributeTrail(AttributeTrail&&) = default; + AttributeTrail& operator=(AttributeTrail&&) = default; + + AttributeTrail& operator=(absl::nullopt_t) { + attribute_.reset(); + return *this; + } // Creates AttributeTrail with attribute path incremented by "qualifier". - AttributeTrail Step(CelAttributeQualifier qualifier, - google::protobuf::Arena* arena) const; + AttributeTrail Step(cel::AttributeQualifier qualifier) const; // Creates AttributeTrail with attribute path incremented by "qualifier". - AttributeTrail Step(const std::string* qualifier, - google::protobuf::Arena* arena) const { - return Step( - CelAttributeQualifier::Create(CelValue::CreateString(qualifier)), - arena); + AttributeTrail Step(const std::string* qualifier) const { + return Step(cel::AttributeQualifier::OfString(*qualifier)); } // Returns CelAttribute that corresponds to content of AttributeTrail. - const CelAttribute* attribute() const { return attribute_; } + const cel::Attribute& attribute() const { return attribute_.value(); } - bool empty() const { return !attribute_; } + bool empty() const { return !attribute_.has_value(); } private: - AttributeTrail(const CelAttribute* attribute) : attribute_(attribute) {} - const CelAttribute* attribute_; + absl::optional attribute_; }; -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google + +} // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_ATTRIBUTE_TRAIL_H_ diff --git a/eval/eval/attribute_trail_test.cc b/eval/eval/attribute_trail_test.cc index d8a03d53e..3143b9ed4 100644 --- a/eval/eval/attribute_trail_test.cc +++ b/eval/eval/attribute_trail_test.cc @@ -1,43 +1,31 @@ #include "eval/eval/attribute_trail.h" -#include "google/api/expr/v1alpha1/syntax.pb.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" +#include + +#include "cel/expr/syntax.pb.h" #include "eval/public/cel_attribute.h" #include "eval/public/cel_value.h" +#include "internal/testing.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { - -using google::api::expr::v1alpha1::Expr; +namespace google::api::expr::runtime { // Attribute Trail behavior TEST(AttributeTrailTest, AttributeTrailEmptyStep) { - google::protobuf::Arena arena; std::string step = "step"; CelValue step_value = CelValue::CreateString(&step); AttributeTrail trail; - ASSERT_TRUE(trail.Step(&step, &arena).empty()); - ASSERT_TRUE( - trail.Step(CelAttributeQualifier::Create(step_value), &arena).empty()); + ASSERT_TRUE(trail.Step(&step).empty()); + ASSERT_TRUE(trail.Step(CreateCelAttributeQualifier(step_value)).empty()); } TEST(AttributeTrailTest, AttributeTrailStep) { - google::protobuf::Arena arena; std::string step = "step"; CelValue step_value = CelValue::CreateString(&step); - Expr root; - root.mutable_ident_expr()->set_name("ident"); - AttributeTrail trail = AttributeTrail(root, &arena).Step(&step, &arena); - ASSERT_TRUE(trail.attribute() != nullptr); - ASSERT_EQ(*trail.attribute(), - CelAttribute(root, {CelAttributeQualifier::Create(step_value)})); + AttributeTrail trail = AttributeTrail("ident").Step(&step); + + ASSERT_EQ(trail.attribute(), + CelAttribute("ident", {CreateCelAttributeQualifier(step_value)})); } -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime diff --git a/eval/eval/attribute_utility.cc b/eval/eval/attribute_utility.cc index 70fea5398..1e044627e 100644 --- a/eval/eval/attribute_utility.cc +++ b/eval/eval/attribute_utility.cc @@ -1,34 +1,88 @@ #include "eval/eval/attribute_utility.h" -#include "absl/status/status.h" -#include "eval/public/cel_value.h" -#include "eval/public/unknown_attribute_set.h" -#include "eval/public/unknown_set.h" +#include +#include +#include + #include "absl/status/statusor.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "base/attribute.h" +#include "base/attribute_set.h" +#include "base/function_result.h" +#include "base/function_result_set.h" +#include "base/internal/unknown_set.h" +#include "common/casting.h" +#include "common/function_descriptor.h" +#include "common/unknown.h" +#include "common/value.h" +#include "eval/eval/attribute_trail.h" +#include "eval/internal/errors.h" +#include "internal/status_macros.h" +#include "runtime/internal/attribute_matcher.h" + +namespace google::api::expr::runtime { + +using ::cel::Attribute; +using ::cel::AttributePattern; +using ::cel::AttributeSet; +using ::cel::Cast; +using ::cel::ErrorValue; +using ::cel::FunctionResult; +using ::cel::FunctionResultSet; +using ::cel::InstanceOf; +using ::cel::UnknownValue; +using ::cel::Value; +using ::cel::base_internal::UnknownSet; +using ::cel::runtime_internal::AttributeMatcher; + +using Accumulator = AttributeUtility::Accumulator; +using MatchResult = AttributeMatcher::MatchResult; -namespace google { -namespace api { -namespace expr { -namespace runtime { +DefaultAttributeMatcher::DefaultAttributeMatcher( + absl::Span unknown_patterns, + absl::Span missing_patterns) + : unknown_patterns_(unknown_patterns), + missing_patterns_(missing_patterns) {} -using google::protobuf::Arena; +DefaultAttributeMatcher::DefaultAttributeMatcher() = default; + +AttributeMatcher::MatchResult MatchAgainstPatterns( + absl::Span patterns, const Attribute& attr) { + MatchResult result = MatchResult::NONE; + for (const auto& pattern : patterns) { + auto current_match = pattern.IsMatch(attr); + if (current_match == cel::AttributePattern::MatchType::FULL) { + return MatchResult::FULL; + } + if (current_match == cel::AttributePattern::MatchType::PARTIAL) { + result = MatchResult::PARTIAL; + } + } + return result; +} + +DefaultAttributeMatcher::MatchResult DefaultAttributeMatcher::CheckForUnknown( + const Attribute& attr) const { + return MatchAgainstPatterns(unknown_patterns_, attr); +} + +DefaultAttributeMatcher::MatchResult DefaultAttributeMatcher::CheckForMissing( + const Attribute& attr) const { + return MatchAgainstPatterns(missing_patterns_, attr); +} bool AttributeUtility::CheckForMissingAttribute( const AttributeTrail& trail) const { if (trail.empty()) { return false; } - - for (const auto& pattern : *missing_attribute_patterns_) { - // (b/161297249) Preserving existing behavior for now, will add a streamz - // for partial match, follow up with tightening up which fields are exposed - // to the condition (w/ ajay and jim) - if (pattern.IsMatch(*trail.attribute()) == - CelAttributePattern::MatchType::FULL) { - return true; - } - } - return false; + // Missing attributes are only treated as errors if the attribute exactly + // matches (so no guard against passing partial state to a function as with + // unknowns). This was initially a design oversight, but is difficult to + // change now. + return matcher_->CheckForMissing(trail.attribute()) == + AttributeMatcher::MatchResult::FULL; } // Checks whether particular corresponds to any patterns that define unknowns. @@ -37,13 +91,11 @@ bool AttributeUtility::CheckForUnknown(const AttributeTrail& trail, if (trail.empty()) { return false; } - for (const auto& pattern : *unknown_patterns_) { - auto current_match = pattern.IsMatch(*trail.attribute()); - if (current_match == CelAttributePattern::MatchType::FULL || - (use_partial && - current_match == CelAttributePattern::MatchType::PARTIAL)) { - return true; - } + MatchResult result = matcher_->CheckForUnknown(trail.attribute()); + + if (result == MatchResult::FULL || + (use_partial && result == MatchResult::PARTIAL)) { + return true; } return false; } @@ -52,22 +104,45 @@ bool AttributeUtility::CheckForUnknown(const AttributeTrail& trail, // Scans over the args collection, merges any UnknownSets found in // it together with initial_set (if initial_set is not null). // Returns pointer to merged set or nullptr, if there were no sets to merge. -const UnknownSet* AttributeUtility::MergeUnknowns( - absl::Span args, const UnknownSet* initial_set) const { - const UnknownSet* result = initial_set; +absl::optional AttributeUtility::MergeUnknowns( + absl::Span args) const { + // Empty unknown value may be used as a sentinel in some tests so need to + // distinguish unset (nullopt) and empty(engaged empty value). + absl::optional result_set; for (const auto& value : args) { - if (!value.IsUnknownSet()) continue; - - auto current_set = value.UnknownSetOrDie(); - if (result == nullptr) { - result = current_set; - } else { - result = Arena::Create(arena_, *result, *current_set); + if (!value->Is()) continue; + if (!result_set.has_value()) { + result_set.emplace(); } + const auto& current_set = value.GetUnknown(); + + cel::base_internal::UnknownSetAccess::Add( + *result_set, UnknownSet(current_set.attribute_set(), + current_set.function_result_set())); } - return result; + if (!result_set.has_value()) { + return std::nullopt; + } + + return UnknownValue(cel::Unknown(result_set->unknown_attributes(), + result_set->unknown_function_results())); +} + +UnknownValue AttributeUtility::MergeUnknownValues( + const UnknownValue& left, const UnknownValue& right) const { + // Empty unknown value may be used as a sentinel in some tests so need to + // distinguish unset (nullopt) and empty(engaged empty value). + AttributeSet attributes; + FunctionResultSet function_results; + attributes.Add(left.attribute_set()); + function_results.Add(left.function_result_set()); + attributes.Add(right.attribute_set()); + function_results.Add(right.function_result_set()); + + return UnknownValue( + cel::Unknown(std::move(attributes), std::move(function_results))); } // Creates merged UnknownAttributeSet. @@ -75,17 +150,17 @@ const UnknownSet* AttributeUtility::MergeUnknowns( // patterns, merges attributes together with those from initial_set // (if initial_set is not null). // Returns pointer to merged set or nullptr, if there were no sets to merge. -UnknownAttributeSet AttributeUtility::CheckForUnknowns( +AttributeSet AttributeUtility::CheckForUnknowns( absl::Span args, bool use_partial) const { - std::vector unknown_attrs; + AttributeSet attribute_set; - for (auto trail : args) { + for (const auto& trail : args) { if (CheckForUnknown(trail, use_partial)) { - unknown_attrs.push_back(trail.attribute()); + attribute_set.Add(trail.attribute()); } } - return UnknownAttributeSet(unknown_attrs); + return attribute_set; } // Creates merged UnknownAttributeSet. @@ -94,21 +169,92 @@ UnknownAttributeSet AttributeUtility::CheckForUnknowns( // patterns, and attributes from initial_set // (if initial_set is not null). // Returns pointer to merged set or nullptr, if there were no sets to merge. -const UnknownSet* AttributeUtility::MergeUnknowns( - absl::Span args, absl::Span attrs, - const UnknownSet* initial_set, bool use_partial) const { - UnknownAttributeSet attr_set = CheckForUnknowns(attrs, use_partial); - if (!attr_set.attributes().empty()) { - if (initial_set != nullptr) { - initial_set = - Arena::Create(arena_, *initial_set, UnknownSet(attr_set)); - } else { - initial_set = Arena::Create(arena_, attr_set); - } +absl::optional AttributeUtility::IdentifyAndMergeUnknowns( + absl::Span args, absl::Span attrs, + bool use_partial) const { + absl::optional result_set; + + // Identify new unknowns by attribute patterns. + cel::AttributeSet attr_set = CheckForUnknowns(attrs, use_partial); + if (!attr_set.empty()) { + result_set.emplace(std::move(attr_set)); + } + + // merge down existing unknown sets + absl::optional arg_unknowns = MergeUnknowns(args); + + if (!result_set.has_value()) { + // No new unknowns so no need to check for presence of existing unknowns -- + // just forward. + return arg_unknowns; + } + + if (arg_unknowns.has_value()) { + cel::base_internal::UnknownSetAccess::Add( + *result_set, UnknownSet((*arg_unknowns).attribute_set(), + (*arg_unknowns).function_result_set())); + } + + return UnknownValue(cel::Unknown(result_set->unknown_attributes(), + result_set->unknown_function_results())); +} + +UnknownValue AttributeUtility::CreateUnknownSet(cel::Attribute attr) const { + return UnknownValue(cel::Unknown(AttributeSet({std::move(attr)}))); +} + +absl::StatusOr AttributeUtility::CreateMissingAttributeError( + const cel::Attribute& attr) const { + CEL_ASSIGN_OR_RETURN(std::string message, attr.AsString()); + return cel::ErrorValue( + cel::runtime_internal::CreateMissingAttributeError(message)); +} + +UnknownValue AttributeUtility::CreateUnknownSet( + const cel::FunctionDescriptor& fn_descriptor, int64_t expr_id, + absl::Span args) const { + return UnknownValue( + cel::Unknown(FunctionResultSet(FunctionResult(fn_descriptor, expr_id)))); +} + +void AttributeUtility::Add(Accumulator& a, const cel::UnknownValue& v) const { + a.attribute_set_.Add(v.attribute_set()); + a.function_result_set_.Add(v.function_result_set()); +} + +void AttributeUtility::Add(Accumulator& a, const AttributeTrail& attr) const { + a.attribute_set_.Add(attr.attribute()); +} + +void Accumulator::Add(const UnknownValue& value) { + unknown_present_ = true; + parent_.Add(*this, value); +} + +void Accumulator::Add(const AttributeTrail& attr) { parent_.Add(*this, attr); } + +void Accumulator::MaybeAdd(const Value& v) { + if (v.IsUnknown()) { + Add(v.GetUnknown()); } - return MergeUnknowns(args, initial_set); } -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google + +void Accumulator::MaybeAdd(const Value& v, const AttributeTrail& attr) { + if (v.IsUnknown()) { + Add(v.GetUnknown()); + } else if (parent_.CheckForUnknown(attr, /*use_partial=*/true)) { + Add(attr); + } +} + +bool Accumulator::IsEmpty() const { + return !unknown_present_ && attribute_set_.empty() && + function_result_set_.empty(); +} + +cel::UnknownValue Accumulator::Build() && { + return cel::UnknownValue( + cel::Unknown(std::move(attribute_set_), std::move(function_result_set_))); +} + +} // namespace google::api::expr::runtime diff --git a/eval/eval/attribute_utility.h b/eval/eval/attribute_utility.h index a49b282d9..94a5158f0 100644 --- a/eval/eval/attribute_utility.h +++ b/eval/eval/attribute_utility.h @@ -1,77 +1,179 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_UNKNOWNS_UTILITY_H_ #define THIRD_PARTY_CEL_CPP_EVAL_EVAL_UNKNOWNS_UTILITY_H_ -#include +#include -#include "google/protobuf/arena.h" +#include "absl/base/nullability.h" +#include "absl/status/statusor.h" #include "absl/types/optional.h" #include "absl/types/span.h" +#include "base/attribute.h" +#include "base/attribute_set.h" +#include "base/function_result_set.h" +#include "common/function_descriptor.h" +#include "common/value.h" #include "eval/eval/attribute_trail.h" -#include "eval/public/activation.h" -#include "eval/public/cel_attribute.h" -#include "eval/public/cel_expression.h" -#include "eval/public/cel_value.h" -#include "eval/public/unknown_attribute_set.h" -#include "eval/public/unknown_set.h" - -namespace google { -namespace api { -namespace expr { -namespace runtime { +#include "runtime/internal/attribute_matcher.h" + +namespace google::api::expr::runtime { + +// Default implementation of the attribute matcher. +// Scans the attribute trail against a list of unknown or missing patterns. +class DefaultAttributeMatcher : public cel::runtime_internal::AttributeMatcher { + private: + using MatchResult = cel::runtime_internal::AttributeMatcher::MatchResult; + + public: + DefaultAttributeMatcher( + absl::Span unknown_patterns, + absl::Span missing_patterns); + + DefaultAttributeMatcher(); + + MatchResult CheckForUnknown(const cel::Attribute& attr) const override; + MatchResult CheckForMissing(const cel::Attribute& attr) const override; + + private: + absl::Span unknown_patterns_; + absl::Span missing_patterns_; +}; // Helper class for handling unknowns and missing attribute logic. Provides // helpers for merging unknown sets from arguments on the stack and for // identifying unknown/missing attributes based on the patterns for a given // Evaluation. +// Neither moveable nor copyable. class AttributeUtility { public: - AttributeUtility( - const std::vector* unknown_patterns, - const std::vector* missing_attribute_patterns, - google::protobuf::Arena* arena) - : unknown_patterns_(unknown_patterns), - missing_attribute_patterns_(missing_attribute_patterns), - arena_(arena) {} + class Accumulator { + public: + Accumulator(const Accumulator&) = delete; + Accumulator& operator=(const Accumulator&) = delete; + Accumulator(Accumulator&&) = delete; + Accumulator& operator=(Accumulator&&) = delete; + + // Add to the accumulated unknown attributes and functions. + void Add(const cel::UnknownValue& v); + void Add(const AttributeTrail& attr); + + // Add to the accumulated set of unknowns if value is UnknownValue. + void MaybeAdd(const cel::Value& v); + + // Add to the accumulated set of unknowns if value is UnknownValue or + // the attribute trail is (partially) unknown. This version prefers + // preserving an already present unknown value over a new one matching the + // attribute trail. + // + // Uses partial matching (a pattern matches the attribute or any + // sub-attribute). + void MaybeAdd(const cel::Value& v, const AttributeTrail& attr); + + bool IsEmpty() const; + + cel::UnknownValue Build() &&; + + private: + explicit Accumulator(const AttributeUtility& parent) + : parent_(parent), unknown_present_(false) {} + + friend class AttributeUtility; + const AttributeUtility& parent_; + + cel::AttributeSet attribute_set_; + cel::FunctionResultSet function_result_set_; + + // Some tests will use an empty unknown set as a sentinel. + // Preserve forwarding behavior. + bool unknown_present_; + }; + + AttributeUtility(absl::Span unknown_patterns, + absl::Span missing_patterns) + : default_matcher_(unknown_patterns, missing_patterns), + matcher_(&default_matcher_) {} + + explicit AttributeUtility( + const cel::runtime_internal::AttributeMatcher* absl_nonnull matcher) + : matcher_(matcher) {} + + AttributeUtility(const AttributeUtility&) = delete; + AttributeUtility& operator=(const AttributeUtility&) = delete; + AttributeUtility(AttributeUtility&&) = delete; + AttributeUtility& operator=(AttributeUtility&&) = delete; // Checks whether particular corresponds to any patterns that define missing // attribute. bool CheckForMissingAttribute(const AttributeTrail& trail) const; - // Checks whether particular corresponds to any patterns that define unknowns. + // Checks whether trail corresponds to any patterns that define unknowns. bool CheckForUnknown(const AttributeTrail& trail, bool use_partial) const; + // Checks whether trail corresponds to any patterns that identify + // unknowns. Only matches exactly (exact attribute match for self or parent). + bool CheckForUnknownExact(const AttributeTrail& trail) const { + return CheckForUnknown(trail, false); + } + + // Checks whether trail corresponds to any patterns that define unknowns. + // Matches if a parent or any descendant (select or index of) the attribute. + bool CheckForUnknownPartial(const AttributeTrail& trail) const { + return CheckForUnknown(trail, true); + } + // Creates merged UnknownAttributeSet. // Scans over the args collection, determines if there matches to unknown // patterns and returns the (possibly empty) collection. - UnknownAttributeSet CheckForUnknowns(absl::Span args, - bool use_partial) const; - - // Creates merged UnknownSet. - // Scans over the args collection, merges any UnknownAttributeSets found in - // it together with initial_set (if initial_set is not null). - // Returns pointer to merged set or nullptr, if there were no sets to merge. - const UnknownSet* MergeUnknowns(absl::Span args, - const UnknownSet* initial_set) const; - - // Creates merged UnknownSet. - // Merges together attributes from UnknownSets found in the args - // collection, attributes from attr that match unknown pattern - // patterns, and attributes from initial_set - // (if initial_set is not null). - // Returns pointer to merged set or nullptr, if there were no sets to merge. - const UnknownSet* MergeUnknowns(absl::Span args, - absl::Span attrs, - const UnknownSet* initial_set, - bool use_partial) const; + cel::AttributeSet CheckForUnknowns(absl::Span args, + bool use_partial) const; + + // Creates merged UnknownValue. + // Scans over the args collection, merges any UnknownValues found. + // Returns the merged UnknownValue or nullopt if not found. + absl::optional MergeUnknowns( + absl::Span args) const; + + // Creates a merged UnknownValue from two unknown values. + cel::UnknownValue MergeUnknownValues(const cel::UnknownValue& left, + const cel::UnknownValue& right) const; + + // Creates merged UnknownValue. + // Merges together UnknownValues found in the args + // along with attributes from attr that match the configured unknown patterns + // Returns returns the merged UnknownValue if available or nullopt. + absl::optional IdentifyAndMergeUnknowns( + absl::Span args, absl::Span attrs, + bool use_partial) const; + + // Create an initial UnknownSet from a single attribute. + cel::UnknownValue CreateUnknownSet(cel::Attribute attr) const; + + // Factory function for missing attribute errors. + absl::StatusOr CreateMissingAttributeError( + const cel::Attribute& attr) const; + + // Create an initial UnknownSet from a single missing function call. + cel::UnknownValue CreateUnknownSet( + const cel::FunctionDescriptor& fn_descriptor, int64_t expr_id, + absl::Span args) const; + + Accumulator CreateAccumulator() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return Accumulator(*this); + } + + void set_matcher( + const cel::runtime_internal::AttributeMatcher* absl_nonnull matcher) { + matcher_ = matcher; + } private: - const std::vector* unknown_patterns_; - const std::vector* missing_attribute_patterns_; - google::protobuf::Arena* arena_; + // Workaround friend visibility. + void Add(Accumulator& a, const cel::UnknownValue& v) const; + void Add(Accumulator& a, const AttributeTrail& attr) const; + + DefaultAttributeMatcher default_matcher_; + const cel::runtime_internal::AttributeMatcher* absl_nonnull matcher_; }; -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google + +} // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_UNKNOWNS_UTILITY_H_ diff --git a/eval/eval/attribute_utility_test.cc b/eval/eval/attribute_utility_test.cc index b7a09d4a8..f3dbc0d06 100644 --- a/eval/eval/attribute_utility_test.cc +++ b/eval/eval/attribute_utility_test.cc @@ -1,31 +1,47 @@ #include "eval/eval/attribute_utility.h" -#include "google/api/expr/v1alpha1/syntax.pb.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" +#include +#include + +#include "absl/types/span.h" +#include "base/attribute.h" +#include "base/attribute_set.h" +#include "common/unknown.h" +#include "common/value.h" +#include "eval/eval/attribute_trail.h" #include "eval/public/cel_attribute.h" #include "eval/public/cel_value.h" #include "eval/public/unknown_attribute_set.h" #include "eval/public/unknown_set.h" +#include "internal/testing.h" +#include "runtime/internal/attribute_matcher.h" +#include "google/protobuf/arena.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace google::api::expr::runtime { -using google::api::expr::v1alpha1::Expr; -using testing::Eq; -using testing::IsNull; -using testing::NotNull; -using testing::SizeIs; -using testing::UnorderedPointwise; +using ::cel::AttributeSet; -TEST(UnknownsUtilityTest, UnknownsUtilityCheckUnknowns) { - google::protobuf::Arena arena; +using ::cel::UnknownValue; +using ::cel::Value; +using ::testing::Eq; +using ::testing::SizeIs; +using ::testing::UnorderedPointwise; + +class AttributeUtilityTest : public ::testing::Test { + public: + AttributeUtilityTest() = default; + + protected: + google::protobuf::Arena arena_; +}; + +absl::Span NoPatterns() { return {}; } + +TEST_F(AttributeUtilityTest, UnknownsUtilityCheckUnknowns) { std::vector unknown_patterns = { - CelAttributePattern("unknown0", {CelAttributeQualifierPattern::Create( + CelAttributePattern("unknown0", {CreateCelAttributeQualifierPattern( CelValue::CreateInt64(1))}), - CelAttributePattern("unknown0", {CelAttributeQualifierPattern::Create( + CelAttributePattern("unknown0", {CreateCelAttributeQualifierPattern( CelValue::CreateInt64(2))}), CelAttributePattern("unknown1", {}), CelAttributePattern("unknown2", {}), @@ -33,16 +49,12 @@ TEST(UnknownsUtilityTest, UnknownsUtilityCheckUnknowns) { std::vector missing_attribute_patterns; - AttributeUtility utility(&unknown_patterns, &missing_attribute_patterns, - &arena); + AttributeUtility utility(unknown_patterns, missing_attribute_patterns); // no match for void trail ASSERT_FALSE(utility.CheckForUnknown(AttributeTrail(), true)); ASSERT_FALSE(utility.CheckForUnknown(AttributeTrail(), false)); - google::api::expr::v1alpha1::Expr unknown_expr0; - unknown_expr0.mutable_ident_expr()->set_name("unknown0"); - - AttributeTrail unknown_trail0(unknown_expr0, &arena); + AttributeTrail unknown_trail0("unknown0"); { ASSERT_FALSE(utility.CheckForUnknown(unknown_trail0, false)); } @@ -51,68 +63,48 @@ TEST(UnknownsUtilityTest, UnknownsUtilityCheckUnknowns) { { ASSERT_TRUE(utility.CheckForUnknown( unknown_trail0.Step( - CelAttributeQualifier::Create(CelValue::CreateInt64(1)), &arena), + CreateCelAttributeQualifier(CelValue::CreateInt64(1))), false)); } { ASSERT_TRUE(utility.CheckForUnknown( unknown_trail0.Step( - CelAttributeQualifier::Create(CelValue::CreateInt64(1)), &arena), + CreateCelAttributeQualifier(CelValue::CreateInt64(1))), true)); } } -TEST(UnknownsUtilityTest, UnknownsUtilityMergeUnknownsFromValues) { - google::protobuf::Arena arena; +TEST_F(AttributeUtilityTest, UnknownsUtilityMergeUnknownsFromValues) { + std::vector unknown_patterns; - google::api::expr::v1alpha1::Expr unknown_expr0; - unknown_expr0.mutable_ident_expr()->set_name("unknown0"); + std::vector missing_attribute_patterns; - google::api::expr::v1alpha1::Expr unknown_expr1; - unknown_expr1.mutable_ident_expr()->set_name("unknown1"); + CelAttribute attribute0("unknown0", {}); + CelAttribute attribute1("unknown1", {}); - google::api::expr::v1alpha1::Expr unknown_expr2; - unknown_expr2.mutable_ident_expr()->set_name("unknown2"); + AttributeUtility utility(unknown_patterns, missing_attribute_patterns); - std::vector unknown_patterns; + UnknownValue unknown_set0 = + cel::UnknownValue(cel::Unknown(AttributeSet({attribute0}))); + UnknownValue unknown_set1 = + cel::UnknownValue(cel::Unknown(AttributeSet({attribute1}))); - std::vector missing_attribute_patterns; - - CelAttribute attribute0(unknown_expr0, {}); - CelAttribute attribute1(unknown_expr1, {}); - CelAttribute attribute2(unknown_expr2, {}); - - AttributeUtility utility(&unknown_patterns, &missing_attribute_patterns, - &arena); - - UnknownSet unknown_set0(UnknownAttributeSet({&attribute0})); - UnknownSet unknown_set1(UnknownAttributeSet({&attribute1})); - UnknownSet unknown_set2(UnknownAttributeSet({&attribute1, &attribute2})); - std::vector values = { - CelValue::CreateUnknownSet(&unknown_set0), - CelValue::CreateUnknownSet(&unknown_set1), - CelValue::CreateBool(true), - CelValue::CreateInt64(1), + std::vector values = { + unknown_set0, + unknown_set1, + cel::BoolValue(true), + cel::IntValue(1), }; - const UnknownSet* unknown_set = utility.MergeUnknowns(values, nullptr); - ASSERT_THAT(unknown_set, NotNull()); - ASSERT_THAT(unknown_set->unknown_attributes().attributes(), - UnorderedPointwise(Eq(), std::vector{ - &attribute0, &attribute1})); - - unknown_set = utility.MergeUnknowns(values, &unknown_set2); - ASSERT_THAT(unknown_set, NotNull()); - ASSERT_THAT( - unknown_set->unknown_attributes().attributes(), - UnorderedPointwise(Eq(), std::vector{ - &attribute0, &attribute1, &attribute2})); + absl::optional unknown_set = utility.MergeUnknowns(values); + ASSERT_TRUE(unknown_set.has_value()); + EXPECT_THAT((*unknown_set).attribute_set(), + UnorderedPointwise( + Eq(), std::vector{attribute0, attribute1})); } -TEST(UnknownsUtilityTest, UnknownsUtilityCheckForUnknownsFromAttributes) { - google::protobuf::Arena arena; - +TEST_F(AttributeUtilityTest, UnknownsUtilityCheckForUnknownsFromAttributes) { std::vector unknown_patterns = { CelAttributePattern("unknown0", {CelAttributeQualifierPattern::CreateWildcard()}), @@ -120,68 +112,105 @@ TEST(UnknownsUtilityTest, UnknownsUtilityCheckForUnknownsFromAttributes) { std::vector missing_attribute_patterns; - google::api::expr::v1alpha1::Expr unknown_expr0; - unknown_expr0.mutable_ident_expr()->set_name("unknown0"); - - google::api::expr::v1alpha1::Expr unknown_expr1; - unknown_expr1.mutable_ident_expr()->set_name("unknown1"); + AttributeTrail trail0("unknown0"); + AttributeTrail trail1("unknown1"); - AttributeTrail trail0(unknown_expr0, &arena); - AttributeTrail trail1(unknown_expr1, &arena); + CelAttribute attribute1("unknown1", {}); + UnknownSet unknown_set1(UnknownAttributeSet({attribute1})); - CelAttribute attribute1(unknown_expr1, {}); - UnknownSet unknown_set1(UnknownAttributeSet({&attribute1})); - - AttributeUtility utility(&unknown_patterns, &missing_attribute_patterns, - &arena); + AttributeUtility utility(unknown_patterns, missing_attribute_patterns); UnknownSet unknown_attr_set(utility.CheckForUnknowns( { AttributeTrail(), // To make sure we handle empty trail gracefully. - trail0.Step(CelAttributeQualifier::Create(CelValue::CreateInt64(1)), - &arena), - trail0.Step(CelAttributeQualifier::Create(CelValue::CreateInt64(2)), - &arena), + trail0.Step(CreateCelAttributeQualifier(CelValue::CreateInt64(1))), + trail0.Step(CreateCelAttributeQualifier(CelValue::CreateInt64(2))), }, false)); UnknownSet unknown_set(unknown_set1, unknown_attr_set); - ASSERT_THAT(unknown_set.unknown_attributes().attributes(), SizeIs(3)); + ASSERT_THAT(unknown_set.unknown_attributes(), SizeIs(3)); } -TEST(UnknownsUtilityTest, UnknownsUtilityCheckForMissingAttributes) { - google::protobuf::Arena arena; - +TEST_F(AttributeUtilityTest, UnknownsUtilityCheckForMissingAttributes) { std::vector unknown_patterns; std::vector missing_attribute_patterns; - Expr expr; - auto* select_expr = expr.mutable_select_expr(); - select_expr->set_field("ip"); - - Expr* ident_expr = select_expr->mutable_operand(); - ident_expr->mutable_ident_expr()->set_name("destination"); + AttributeTrail trail("destination"); + trail = + trail.Step(CreateCelAttributeQualifier(CelValue::CreateStringView("ip"))); - AttributeTrail trail(*ident_expr, &arena); - trail = trail.Step( - CelAttributeQualifier::Create(CelValue::CreateStringView("ip")), &arena); - - AttributeUtility utility0(&unknown_patterns, &missing_attribute_patterns, - &arena); + AttributeUtility utility0(unknown_patterns, missing_attribute_patterns); EXPECT_FALSE(utility0.CheckForMissingAttribute(trail)); missing_attribute_patterns.push_back(CelAttributePattern( - "destination", {CelAttributeQualifierPattern::Create( - CelValue::CreateStringView("ip"))})); + "destination", + {CreateCelAttributeQualifierPattern(CelValue::CreateStringView("ip"))})); - AttributeUtility utility1(&unknown_patterns, &missing_attribute_patterns, - &arena); + AttributeUtility utility1(unknown_patterns, missing_attribute_patterns); EXPECT_TRUE(utility1.CheckForMissingAttribute(trail)); } -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +TEST_F(AttributeUtilityTest, CreateUnknownSet) { + AttributeTrail trail("destination"); + trail = + trail.Step(CreateCelAttributeQualifier(CelValue::CreateStringView("ip"))); + + std::vector empty_patterns; + AttributeUtility utility(empty_patterns, empty_patterns); + + UnknownValue set = utility.CreateUnknownSet(trail.attribute()); + ASSERT_THAT(set.attribute_set(), SizeIs(1)); + ASSERT_OK_AND_ASSIGN(auto elem, set.attribute_set().begin()->AsString()); + EXPECT_EQ(elem, "destination.ip"); +} + +class FakeMatcher : public cel::runtime_internal::AttributeMatcher { + private: + using MatchResult = cel::runtime_internal::AttributeMatcher::MatchResult; + + public: + MatchResult CheckForUnknown(const cel::Attribute& attr) const override { + std::string attr_str = attr.AsString().value_or(""); + if (attr_str == "device.foo") { + return MatchResult::FULL; + } else if (attr_str == "device") { + return MatchResult::PARTIAL; + } + return MatchResult::NONE; + } + + MatchResult CheckForMissing(const cel::Attribute& attr) const override { + std::string attr_str = attr.AsString().value_or(""); + + if (attr_str == "device2.foo") { + return MatchResult::FULL; + } else if (attr_str == "device2") { + return MatchResult::PARTIAL; + } + return MatchResult::NONE; + } +}; + +TEST_F(AttributeUtilityTest, CustomMatcher) { + AttributeTrail trail("device"); + + AttributeUtility utility(NoPatterns(), NoPatterns()); + FakeMatcher matcher; + utility.set_matcher(&matcher); + EXPECT_TRUE(utility.CheckForUnknownPartial(trail)); + EXPECT_FALSE(utility.CheckForUnknownExact(trail)); + + trail = trail.Step(cel::AttributeQualifier::OfString("foo")); + EXPECT_TRUE(utility.CheckForUnknownExact(trail)); + EXPECT_TRUE(utility.CheckForUnknownPartial(trail)); + + trail = AttributeTrail("device2"); + EXPECT_FALSE(utility.CheckForMissingAttribute(trail)); + trail = trail.Step(cel::AttributeQualifier::OfString("foo")); + EXPECT_TRUE(utility.CheckForMissingAttribute(trail)); +} + +} // namespace google::api::expr::runtime diff --git a/eval/eval/cel_expression_flat_impl.cc b/eval/eval/cel_expression_flat_impl.cc new file mode 100644 index 000000000..9e35b41ad --- /dev/null +++ b/eval/eval/cel_expression_flat_impl.cc @@ -0,0 +1,147 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/eval/cel_expression_flat_impl.h" + +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "common/native_type.h" +#include "common/value.h" +#include "eval/eval/attribute_trail.h" +#include "eval/eval/comprehension_slots.h" +#include "eval/eval/direct_expression_step.h" +#include "eval/eval/evaluator_core.h" +#include "eval/internal/adapter_activation_impl.h" +#include "eval/internal/interop.h" +#include "eval/public/base_activation.h" +#include "eval/public/cel_expression.h" +#include "eval/public/cel_value.h" +#include "internal/casts.h" +#include "internal/status_macros.h" +#include "runtime/internal/runtime_env.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace google::api::expr::runtime { +namespace { + +using ::cel::Value; +using ::cel::runtime_internal::RuntimeEnv; + +EvaluationListener AdaptListener(const CelEvaluationListener& listener) { + if (!listener) return nullptr; + return [&](int64_t expr_id, const Value& value, + const google::protobuf::DescriptorPool* absl_nonnull, + google::protobuf::MessageFactory* absl_nonnull, + google::protobuf::Arena* absl_nonnull arena) -> absl::Status { + if (value->Is()) { + // Opaque types are used to implement some optimized operations. + // These aren't representable as legacy values and shouldn't be + // inspectable by clients. + return absl::OkStatus(); + } + CelValue legacy_value = + cel::interop_internal::ModernValueToLegacyValueOrDie(arena, value); + return listener(expr_id, legacy_value, arena); + }; +} +} // namespace + +CelExpressionFlatEvaluationState::CelExpressionFlatEvaluationState( + google::protobuf::Arena* arena, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + const FlatExpression& expression) + : state_(expression.MakeEvaluatorState(descriptor_pool, message_factory, + arena)) {} + +absl::StatusOr CelExpressionFlatImpl::Trace( + const BaseActivation& activation, CelEvaluationState* _state, + CelEvaluationListener callback) const { + auto state = + ::cel::internal::down_cast(_state); + state->state().Reset(); + cel::interop_internal::AdapterActivationImpl modern_activation(activation); + + CEL_ASSIGN_OR_RETURN(cel::Value value, + flat_expression_.EvaluateWithCallback( + modern_activation, + /*embedder_context=*/nullptr, + AdaptListener(callback), state->state())); + + return cel::interop_internal::ModernValueToLegacyValueOrDie(state->arena(), + value); +} + +std::unique_ptr CelExpressionFlatImpl::InitializeState( + google::protobuf::Arena* arena) const { + return std::make_unique( + arena, env_->descriptor_pool.get(), env_->MutableMessageFactory(), + flat_expression_); +} + +absl::StatusOr CelExpressionFlatImpl::Evaluate( + const BaseActivation& activation, CelEvaluationState* state) const { + return Trace(activation, state, CelEvaluationListener()); +} + +absl::StatusOr> +CelExpressionRecursiveImpl::Create( + absl_nonnull std::shared_ptr env, + FlatExpression flat_expr) { + if (flat_expr.path().empty() || + flat_expr.path().front()->GetNativeTypeId() != + cel::NativeTypeId::For()) { + return absl::InvalidArgumentError(absl::StrCat( + "Expected a recursive program step", flat_expr.path().size())); + } + + auto* instance = + new CelExpressionRecursiveImpl(std::move(env), std::move(flat_expr)); + + return absl::WrapUnique(instance); +} + +absl::StatusOr CelExpressionRecursiveImpl::Trace( + const BaseActivation& activation, google::protobuf::Arena* arena, + CelEvaluationListener callback) const { + cel::interop_internal::AdapterActivationImpl modern_activation(activation); + ComprehensionSlots slots(flat_expression_.comprehension_slots_size()); + ExecutionFrameBase execution_frame( + modern_activation, AdaptListener(callback), flat_expression_.options(), + flat_expression_.type_provider(), env_->descriptor_pool.get(), + env_->MutableMessageFactory(), arena, + /*embedder_context=*/nullptr, slots); + + cel::Value result; + AttributeTrail trail; + CEL_RETURN_IF_ERROR(root_->Evaluate(execution_frame, result, trail)); + + return cel::interop_internal::ModernValueToLegacyValueOrDie(arena, result); +} + +absl::StatusOr CelExpressionRecursiveImpl::Evaluate( + const BaseActivation& activation, google::protobuf::Arena* arena) const { + return Trace(activation, arena, /*callback=*/nullptr); +} + +} // namespace google::api::expr::runtime diff --git a/eval/eval/cel_expression_flat_impl.h b/eval/eval/cel_expression_flat_impl.h new file mode 100644 index 000000000..7faf6856a --- /dev/null +++ b/eval/eval/cel_expression_flat_impl.h @@ -0,0 +1,175 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_CEL_EXPRESSION_FLAT_IMPL_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_EVAL_CEL_EXPRESSION_FLAT_IMPL_H_ + +#include +#include + +#include "absl/base/nullability.h" +#include "absl/status/statusor.h" +#include "eval/eval/direct_expression_step.h" +#include "eval/eval/evaluator_core.h" +#include "eval/public/base_activation.h" +#include "eval/public/cel_expression.h" +#include "eval/public/cel_value.h" +#include "internal/casts.h" +#include "runtime/internal/runtime_env.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace google::api::expr::runtime { + +// Wrapper for FlatExpressionEvaluationState used to implement CelExpression. +class CelExpressionFlatEvaluationState : public CelEvaluationState { + public: + CelExpressionFlatEvaluationState( + google::protobuf::Arena* arena, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + const FlatExpression& expr); + + google::protobuf::Arena* arena() { return state_.arena(); } + FlatExpressionEvaluatorState& state() { return state_; } + + private: + FlatExpressionEvaluatorState state_; +}; + +// Implementation of the CelExpression that evaluates a flattened representation +// of the AST. +// +// This class adapts FlatExpression to implement the CelExpression interface. +class CelExpressionFlatImpl : public CelExpression { + public: + CelExpressionFlatImpl( + absl_nonnull std::shared_ptr env, + FlatExpression flat_expression) + : env_(std::move(env)), flat_expression_(std::move(flat_expression)) {} + + // Move-only + CelExpressionFlatImpl(const CelExpressionFlatImpl&) = delete; + CelExpressionFlatImpl& operator=(const CelExpressionFlatImpl&) = delete; + CelExpressionFlatImpl(CelExpressionFlatImpl&&) = default; + CelExpressionFlatImpl& operator=(CelExpressionFlatImpl&&) = delete; + + // Implement CelExpression. + std::unique_ptr InitializeState( + google::protobuf::Arena* arena) const override; + + absl::StatusOr Evaluate(const BaseActivation& activation, + google::protobuf::Arena* arena) const override { + return Evaluate(activation, InitializeState(arena).get()); + } + + absl::StatusOr Evaluate(const BaseActivation& activation, + CelEvaluationState* state) const override; + absl::StatusOr Trace( + const BaseActivation& activation, google::protobuf::Arena* arena, + CelEvaluationListener callback) const override { + return Trace(activation, InitializeState(arena).get(), callback); + } + + absl::StatusOr Trace(const BaseActivation& activation, + CelEvaluationState* state, + CelEvaluationListener callback) const override; + + // Exposed for inspection in tests. + const FlatExpression& flat_expression() const { return flat_expression_; } + + private: + absl_nonnull std::shared_ptr env_; + FlatExpression flat_expression_; +}; + +// Implementation of the CelExpression that evaluates a recursive representation +// of the AST. +// +// This class adapts FlatExpression to implement the CelExpression interface. +// +// Assumes that the flat expression is wrapping a simple recursive program. +class CelExpressionRecursiveImpl : public CelExpression { + private: + class EvaluationState : public CelEvaluationState { + public: + explicit EvaluationState(google::protobuf::Arena* arena) : arena_(arena) {} + google::protobuf::Arena* arena() { return arena_; } + + private: + google::protobuf::Arena* arena_; + }; + + public: + static absl::StatusOr> Create( + absl_nonnull std::shared_ptr env, + FlatExpression flat_expression); + + // Move-only + CelExpressionRecursiveImpl(const CelExpressionRecursiveImpl&) = delete; + CelExpressionRecursiveImpl& operator=(const CelExpressionRecursiveImpl&) = + delete; + CelExpressionRecursiveImpl(CelExpressionRecursiveImpl&&) = default; + CelExpressionRecursiveImpl& operator=(CelExpressionRecursiveImpl&&) = delete; + + // Implement CelExpression. + std::unique_ptr InitializeState( + google::protobuf::Arena* arena) const override { + return std::make_unique(arena); + } + + absl::StatusOr Evaluate(const BaseActivation& activation, + google::protobuf::Arena* arena) const override; + + absl::StatusOr Evaluate(const BaseActivation& activation, + CelEvaluationState* state) const override { + auto* state_impl = cel::internal::down_cast(state); + return Evaluate(activation, state_impl->arena()); + } + + absl::StatusOr Trace(const BaseActivation& activation, + google::protobuf::Arena* arena, + CelEvaluationListener callback) const override; + + absl::StatusOr Trace( + const BaseActivation& activation, CelEvaluationState* state, + CelEvaluationListener callback) const override { + auto* state_impl = cel::internal::down_cast(state); + return Trace(activation, state_impl->arena(), callback); + } + + // Exposed for inspection in tests. + const FlatExpression& flat_expression() const { return flat_expression_; } + + const DirectExpressionStep* root() const { return root_; } + + private: + explicit CelExpressionRecursiveImpl( + absl_nonnull std::shared_ptr env, + FlatExpression flat_expression) + : env_(std::move(env)), + flat_expression_(std::move(flat_expression)), + root_(cel::internal::down_cast( + flat_expression_.path()[0].get()) + ->wrapped()) {} + + absl_nonnull std::shared_ptr env_; + FlatExpression flat_expression_; + const DirectExpressionStep* root_; +}; + +} // namespace google::api::expr::runtime + +#endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_CEL_EXPRESSION_FLAT_IMPL_H_ diff --git a/eval/eval/compiler_constant_step.cc b/eval/eval/compiler_constant_step.cc new file mode 100644 index 000000000..44a03cecd --- /dev/null +++ b/eval/eval/compiler_constant_step.cc @@ -0,0 +1,37 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "eval/eval/compiler_constant_step.h" + +#include "absl/status/status.h" +#include "common/value.h" +#include "eval/eval/attribute_trail.h" +#include "eval/eval/evaluator_core.h" + +namespace google::api::expr::runtime { + +using ::cel::Value; + +absl::Status DirectCompilerConstantStep::Evaluate( + ExecutionFrameBase& frame, Value& result, AttributeTrail& attribute) const { + result = value_; + return absl::OkStatus(); +} + +absl::Status CompilerConstantStep::Evaluate(ExecutionFrame* frame) const { + frame->value_stack().Push(value_); + + return absl::OkStatus(); +} + +} // namespace google::api::expr::runtime diff --git a/eval/eval/compiler_constant_step.h b/eval/eval/compiler_constant_step.h new file mode 100644 index 000000000..bd514a036 --- /dev/null +++ b/eval/eval/compiler_constant_step.h @@ -0,0 +1,76 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_COMPILER_CONSTANT_STEP_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_EVAL_COMPILER_CONSTANT_STEP_H_ + +#include +#include + +#include "absl/status/status.h" +#include "common/native_type.h" +#include "common/value.h" +#include "eval/eval/attribute_trail.h" +#include "eval/eval/direct_expression_step.h" +#include "eval/eval/evaluator_core.h" +#include "eval/eval/expression_step_base.h" + +namespace google::api::expr::runtime { + +// DirectExpressionStep implementation that simply assigns a constant value. +// +// Overrides NativeTypeId() allow the FlatExprBuilder and extensions to +// inspect the underlying value. +class DirectCompilerConstantStep : public DirectExpressionStep { + public: + DirectCompilerConstantStep(cel::Value value, int64_t expr_id) + : DirectExpressionStep(expr_id), value_(std::move(value)) {} + + absl::Status Evaluate(ExecutionFrameBase& frame, cel::Value& result, + AttributeTrail& attribute) const override; + + cel::NativeTypeId GetNativeTypeId() const override { + return cel::NativeTypeId::For(); + } + + const cel::Value& value() const { return value_; } + + private: + cel::Value value_; +}; + +// ExpressionStep implementation that simply pushes a constant value on the +// stack. +// +// Overrides NativeTypeId ()o allow the FlatExprBuilder and extensions to +// inspect the underlying value. +class CompilerConstantStep : public ExpressionStepBase { + public: + CompilerConstantStep(cel::Value value, int64_t expr_id, bool comes_from_ast) + : ExpressionStepBase(expr_id, comes_from_ast), value_(std::move(value)) {} + + absl::Status Evaluate(ExecutionFrame* frame) const override; + + cel::NativeTypeId GetNativeTypeId() const override { + return cel::NativeTypeId::For(); + } + + const cel::Value& value() const { return value_; } + + private: + cel::Value value_; +}; + +} // namespace google::api::expr::runtime + +#endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_COMPILER_CONSTANT_STEP_H_ diff --git a/eval/eval/compiler_constant_step_test.cc b/eval/eval/compiler_constant_step_test.cc new file mode 100644 index 000000000..856ca30e0 --- /dev/null +++ b/eval/eval/compiler_constant_step_test.cc @@ -0,0 +1,75 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "eval/eval/compiler_constant_step.h" + +#include + +#include "common/native_type.h" +#include "common/value.h" +#include "eval/eval/evaluator_core.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "internal/testing_message_factory.h" +#include "runtime/activation.h" +#include "runtime/internal/runtime_type_provider.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" + +namespace google::api::expr::runtime { + +namespace { + +class CompilerConstantStepTest : public testing::Test { + public: + CompilerConstantStepTest() + : type_provider_(cel::internal::GetTestingDescriptorPool()), + state_(2, 0, type_provider_, cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena_) {} + + protected: + google::protobuf::Arena arena_; + cel::runtime_internal::RuntimeTypeProvider type_provider_; + FlatExpressionEvaluatorState state_; + cel::Activation empty_activation_; + cel::RuntimeOptions options_; +}; + +TEST_F(CompilerConstantStepTest, Evaluate) { + ExecutionPath path; + path.push_back( + std::make_unique(cel::IntValue(42), -1, false)); + + ExecutionFrame frame(path, empty_activation_, options_, state_); + + ASSERT_OK_AND_ASSIGN(cel::Value result, frame.Evaluate()); + + EXPECT_EQ(result.GetInt().NativeValue(), 42); +} + +TEST_F(CompilerConstantStepTest, TypeId) { + CompilerConstantStep step(cel::IntValue(42), -1, false); + + ExpressionStep& abstract_step = step; + EXPECT_EQ(abstract_step.GetNativeTypeId(), + cel::NativeTypeId::For()); +} + +TEST_F(CompilerConstantStepTest, Value) { + CompilerConstantStep step(cel::IntValue(42), -1, false); + + EXPECT_EQ(step.value().GetInt().NativeValue(), 42); +} + +} // namespace +} // namespace google::api::expr::runtime diff --git a/eval/eval/comprehension_slots.h b/eval/eval/comprehension_slots.h new file mode 100644 index 000000000..795cca7f7 --- /dev/null +++ b/eval/eval/comprehension_slots.h @@ -0,0 +1,153 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_COMPREHENSION_SLOTS_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_EVAL_COMPREHENSION_SLOTS_H_ + +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/no_destructor.h" +#include "absl/base/nullability.h" +#include "absl/container/fixed_array.h" +#include "absl/log/absl_check.h" +#include "absl/types/optional.h" +#include "common/value.h" +#include "eval/eval/attribute_trail.h" + +namespace google::api::expr::runtime { + +class ComprehensionSlot final { + public: + ComprehensionSlot() = default; + ComprehensionSlot(const ComprehensionSlot&) = delete; + ComprehensionSlot(ComprehensionSlot&&) = delete; + ComprehensionSlot& operator=(const ComprehensionSlot&) = delete; + ComprehensionSlot& operator=(ComprehensionSlot&&) = delete; + + const cel::Value& value() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + ABSL_DCHECK(Has()); + + return value_; + } + + cel::Value* absl_nonnull mutable_value() ABSL_ATTRIBUTE_LIFETIME_BOUND { + ABSL_DCHECK(Has()); + + return &value_; + } + + const AttributeTrail& attribute() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + ABSL_DCHECK(Has()); + + return attribute_; + } + + AttributeTrail* absl_nonnull mutable_attribute() + ABSL_ATTRIBUTE_LIFETIME_BOUND { + ABSL_DCHECK(Has()); + + return &attribute_; + } + + bool Has() const { return has_; } + + void Set() { Set(cel::NullValue(), absl::nullopt); } + + template + void Set(V&& value) { + Set(std::forward(value), absl::nullopt); + } + + template + void Set(V&& value, A&& attribute) { + value_ = std::forward(value); + attribute_ = std::forward(attribute); + has_ = true; + } + + void Clear() { + if (has_) { + value_ = cel::NullValue(); + attribute_ = absl::nullopt; + has_ = false; + } + } + + private: + cel::Value value_; + AttributeTrail attribute_; + bool has_ = false; +}; + +// Simple manager for comprehension variables. +// +// At plan time, each comprehension variable is assigned a slot by index. +// This is used instead of looking up the variable identifier by name in a +// runtime stack. +// +// Callers must handle range checking. +class ComprehensionSlots final { + public: + using Slot = ComprehensionSlot; + + // Trivial instance if no slots are needed. + // Trivially thread safe since no effective state. + static ComprehensionSlots& GetEmptyInstance() { + static absl::NoDestructor instance(0); + return *instance; + } + + explicit ComprehensionSlots(size_t size) : slots_(size) {} + + ComprehensionSlots(const ComprehensionSlots&) = delete; + ComprehensionSlots& operator=(const ComprehensionSlots&) = delete; + + ComprehensionSlots(ComprehensionSlots&&) = delete; + ComprehensionSlots& operator=(ComprehensionSlots&&) = delete; + + Slot* absl_nonnull Get(size_t index) ABSL_ATTRIBUTE_LIFETIME_BOUND { + ABSL_DCHECK_LT(index, size()); + + return &slots_[index]; + } + + void Reset() { + for (Slot& slot : slots_) { + slot.Clear(); + } + } + + void ClearSlot(size_t index) { Get(index)->Clear(); } + + template + void Set(size_t index, V&& value) { + Set(index, std::forward(value), absl::nullopt); + } + + template + void Set(size_t index, V&& value, A&& attribute) { + Get(index)->Set(std::forward(value), std::forward(attribute)); + } + + size_t size() const { return slots_.size(); } + + private: + absl::FixedArray slots_; +}; + +} // namespace google::api::expr::runtime + +#endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_COMPREHENSION_SLOTS_H_ diff --git a/eval/eval/comprehension_slots_test.cc b/eval/eval/comprehension_slots_test.cc new file mode 100644 index 000000000..5f869d7cb --- /dev/null +++ b/eval/eval/comprehension_slots_test.cc @@ -0,0 +1,75 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/eval/comprehension_slots.h" + +#include "base/attribute.h" +#include "base/type_provider.h" +#include "common/memory.h" +#include "common/value.h" +#include "eval/eval/attribute_trail.h" +#include "internal/testing.h" + +namespace google::api::expr::runtime { + +using ::cel::Attribute; + +using ::absl_testing::IsOkAndHolds; +using ::cel::MemoryManagerRef; +using ::cel::StringValue; +using ::cel::TypeProvider; +using ::cel::Value; +using ::testing::Truly; + +TEST(ComprehensionSlots, Basic) { + ComprehensionSlots slots(4); + + ComprehensionSlots::Slot* slot0 = slots.Get(0); + EXPECT_FALSE(slot0->Has()); + + slots.Set(0, cel::StringValue("abcd"), + AttributeTrail(Attribute("fake_attr"))); + + ASSERT_TRUE(slot0->Has()); + + EXPECT_THAT(slot0->value(), Truly([](const Value& v) { + return v.Is() && + v.GetString().ToString() == "abcd"; + })) + << "value is 'abcd'"; + + EXPECT_THAT(slot0->attribute().attribute().AsString(), + IsOkAndHolds("fake_attr")); + + slots.ClearSlot(0); + EXPECT_FALSE(slot0->Has()); + + slots.Set(3, cel::StringValue("abcd"), + AttributeTrail(Attribute("fake_attr"))); + + auto* slot3 = slots.Get(3); + + ASSERT_TRUE(slot3->Has()); + EXPECT_THAT(slot3->value(), Truly([](const Value& v) { + return v.Is() && + v.GetString().ToString() == "abcd"; + })) + << "value is 'abcd'"; + + slots.Reset(); + EXPECT_FALSE(slot0->Has()); + EXPECT_FALSE(slot3->Has()); +} + +} // namespace google::api::expr::runtime diff --git a/eval/eval/comprehension_step.cc b/eval/eval/comprehension_step.cc index a42cf822a..5e741d805 100644 --- a/eval/eval/comprehension_step.cc +++ b/eval/eval/comprehension_step.cc @@ -1,257 +1,685 @@ #include "eval/eval/comprehension_step.h" +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/casts.h" +#include "absl/base/nullability.h" +#include "absl/base/optimization.h" +#include "absl/log/absl_check.h" #include "absl/status/status.h" -#include "absl/strings/str_cat.h" +#include "absl/status/statusor.h" +#include "base/attribute.h" +#include "common/casting.h" +#include "common/value.h" +#include "common/value_kind.h" #include "eval/eval/attribute_trail.h" +#include "eval/eval/comprehension_slots.h" +#include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" -#include "eval/public/cel_attribute.h" -#include "base/status_macros.h" - -namespace google { -namespace api { -namespace expr { -namespace runtime { - -// Stack variables during comprehension evaluation: -// 0. accu_init, then loop_step (any), available through accu_var -// 1. iter_range (list) -// 2. current index in iter_range (int64_t) -// 3. current_value from iter_range (any), available through iter_var -// 4. loop_condition (bool) OR loop_step (any) - -// What to put on ExecutionPath: stack size -// 0. (dummy) 1 -// 1. iter_range (dep) 2 -// 2. -1 3 -// 3. (dummy) 4 -// 4. accu_init (dep) 5 -// 5. ComprehensionNextStep 4 -// 6. loop_condition (dep) 5 -// 7. ComprehensionCondStep 4 -// 8. loop_step (dep) 5 -// 9. goto 5. 5 -// 10. result (dep) 2 -// 11. ComprehensionFinish 1 - -ComprehensionNextStep::ComprehensionNextStep(const std::string& accu_var, - const std::string& iter_var, - int64_t expr_id) - : ExpressionStepBase(expr_id, false), - accu_var_(accu_var), - iter_var_(iter_var) {} - -void ComprehensionNextStep::set_jump_offset(int offset) { - jump_offset_ = offset; -} +#include "eval/eval/expression_step_base.h" +#include "eval/internal/errors.h" +#include "internal/status_macros.h" + +namespace google::api::expr::runtime { +namespace { + +enum class IterableKind { + kList = 1, + kMap, +}; -void ComprehensionNextStep::set_error_jump_offset(int offset) { - error_jump_offset_ = offset; +using ::cel::AttributeQualifier; +using ::cel::Cast; +using ::cel::InstanceOf; +using ::cel::UnknownValue; +using ::cel::Value; +using ::cel::ValueIterator; +using ::cel::ValueIteratorPtr; +using ::cel::ValueKind; +using ::cel::runtime_internal::CreateNoMatchingOverloadError; + +AttributeQualifier AttributeQualifierFromValue(const Value& v) { + switch (v.kind()) { + case ValueKind::kString: + return AttributeQualifier::OfString(v.GetString().ToString()); + case ValueKind::kInt64: + return AttributeQualifier::OfInt(v.GetInt().NativeValue()); + case ValueKind::kUint64: + return AttributeQualifier::OfUint(v.GetUint().NativeValue()); + case ValueKind::kBool: + return AttributeQualifier::OfBool(v.GetBool().NativeValue()); + default: + // Non-matching qualifier. + return AttributeQualifier(); + } } -// Stack changes of ComprehensionNextStep. -// -// Stack before: -// 0. previous accu_init or "" on the first iteration -// 1. iter_range (list) -// 2. old current_index in iter_range (int64_t) -// 3. old current_value or "" on the first iteration -// 4. loop_step or accu_init (any) -// -// Stack after: -// 0. loop_step or accu_init (any) -// 1. iter_range (list) -// 2. new current_index in iter_range (int64_t) -// 3. new current_value -// -// Stack on break: -// 0. loop_step or accu_init (any) -// -// When iter_range is not a list, this step jumps to error_jump_offset_ that is -// controlled by set_error_jump_offset. In that case the stack is cleared -// from values related to this comprehension and an error is put on the stack. -// -// Stack on error: -// 0. error -absl::Status ComprehensionNextStep::Evaluate(ExecutionFrame* frame) const { - enum { - POS_PREVIOUS_LOOP_STEP, - POS_ITER_RANGE, - POS_CURRENT_INDEX, - POS_CURRENT_VALUE, - POS_LOOP_STEP, - }; - if (!frame->value_stack().HasEnough(5)) { - return absl::Status(absl::StatusCode::kInternal, "Value stack underflow"); +class ComprehensionFinishStep final : public ExpressionStepBase { + public: + ComprehensionFinishStep(size_t accu_slot, int64_t expr_id) + : ExpressionStepBase(expr_id), accu_slot_(accu_slot) {} + + absl::Status Evaluate(ExecutionFrame* frame) const override { + if (!frame->value_stack().HasEnough(2)) { + return absl::Status(absl::StatusCode::kInternal, "Value stack underflow"); + } + frame->value_stack().SwapAndPop(2, 1); + frame->comprehension_slots().ClearSlot(accu_slot_); + frame->iterator_stack().Pop(); + return absl::OkStatus(); } - auto state = frame->value_stack().GetSpan(5); - auto attr = frame->value_stack().GetAttributeSpan(5); - // Get range from the stack. - CelValue iter_range = state[POS_ITER_RANGE]; - if (!iter_range.IsList()) { - frame->value_stack().Pop(5); - if (iter_range.IsError() || iter_range.IsUnknownSet()) { - frame->value_stack().Push(iter_range); - return frame->JumpTo(error_jump_offset_); + private: + const size_t accu_slot_; +}; + +class ComprehensionDirectStep final : public DirectExpressionStep { + public: + explicit ComprehensionDirectStep( + size_t iter_slot, size_t iter2_slot, size_t accu_slot, + std::unique_ptr range, + std::unique_ptr accu_init, + std::unique_ptr loop_step, + std::unique_ptr condition_step, + std::unique_ptr result_step, bool shortcircuiting, + int64_t expr_id) + : DirectExpressionStep(expr_id), + iter_slot_(iter_slot), + iter2_slot_(iter2_slot), + accu_slot_(accu_slot), + range_(std::move(range)), + accu_init_(std::move(accu_init)), + loop_step_(std::move(loop_step)), + condition_(std::move(condition_step)), + result_step_(std::move(result_step)), + shortcircuiting_(shortcircuiting) {} + + absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, + AttributeTrail& trail) const override { + return iter_slot_ == iter2_slot_ ? Evaluate1(frame, result, trail) + : Evaluate2(frame, result, trail); + } + + private: + absl::Status Evaluate1(ExecutionFrameBase& frame, Value& result, + AttributeTrail& trail) const; + + absl::StatusOr Evaluate1Unknown( + ExecutionFrameBase& frame, IterableKind range_iter_kind, + const AttributeTrail& range_iter_attr, + ValueIterator* absl_nonnull range_iter, + ComprehensionSlots::Slot* absl_nonnull accu_slot, + ComprehensionSlots::Slot* absl_nonnull iter_slot, Value& result, + AttributeTrail& trail) const; + + absl::StatusOr Evaluate1Known( + ExecutionFrameBase& frame, ValueIterator* absl_nonnull range_iter, + ComprehensionSlots::Slot* absl_nonnull accu_slot, + ComprehensionSlots::Slot* absl_nonnull iter_slot, Value& result, + AttributeTrail& trail) const; + + absl::Status Evaluate2(ExecutionFrameBase& frame, Value& result, + AttributeTrail& trail) const; + + const size_t iter_slot_; + const size_t iter2_slot_; + const size_t accu_slot_; + const std::unique_ptr range_; + const std::unique_ptr accu_init_; + const std::unique_ptr loop_step_; + const std::unique_ptr condition_; + const std::unique_ptr result_step_; + const bool shortcircuiting_; +}; + +absl::Status ComprehensionDirectStep::Evaluate1(ExecutionFrameBase& frame, + Value& result, + AttributeTrail& trail) const { + Value range; + AttributeTrail range_attr; + CEL_RETURN_IF_ERROR(range_->Evaluate(frame, range, range_attr)); + + if (frame.unknown_processing_enabled() && range.IsMap()) { + if (frame.attribute_utility().CheckForUnknownPartial(range_attr)) { + result = + frame.attribute_utility().CreateUnknownSet(range_attr.attribute()); + return absl::OkStatus(); } - frame->value_stack().Push( - CreateNoMatchingOverloadError(frame->arena(), "")); - return frame->JumpTo(error_jump_offset_); } - const CelList* cel_list = iter_range.ListOrDie(); - const AttributeTrail iter_range_attr = attr[POS_ITER_RANGE]; - - // Get the current index off the stack. - CelValue current_index_value = state[POS_CURRENT_INDEX]; - if (!current_index_value.IsInt64()) { - auto message = absl::StrCat( - "ComprehensionNextStep: want int64_t, got ", - CelValue::TypeName(current_index_value.type()) - ); - return absl::Status(absl::StatusCode::kInternal, message); - } - auto increment_status = frame->IncrementIterations(); - if (!increment_status.ok()) { - return increment_status; - } - int64_t current_index = current_index_value.Int64OrDie(); - if (current_index == -1) { - RETURN_IF_ERROR(frame->PushIterFrame()); - } - - // Update stack for breaking out of loop or next round. - CelValue loop_step = state[POS_LOOP_STEP]; - frame->value_stack().Pop(5); - frame->value_stack().Push(loop_step); - RETURN_IF_ERROR(frame->SetIterVar(accu_var_, loop_step)); - if (current_index >= cel_list->size() - 1) { - RETURN_IF_ERROR(frame->ClearIterVar(iter_var_)); - return frame->JumpTo(jump_offset_); + + absl_nullability_unknown ValueIteratorPtr range_iter; + IterableKind iterable_kind; + switch (range.kind()) { + case ValueKind::kList: { + CEL_ASSIGN_OR_RETURN(range_iter, range.GetList().NewIterator()); + iterable_kind = IterableKind::kList; + } break; + case ValueKind::kMap: { + CEL_ASSIGN_OR_RETURN(range_iter, range.GetMap().NewIterator()); + iterable_kind = IterableKind::kMap; + } break; + case ValueKind::kError: + ABSL_FALLTHROUGH_INTENDED; + case ValueKind::kUnknown: + result = std::move(range); + return absl::OkStatus(); + default: + result = cel::ErrorValue(CreateNoMatchingOverloadError("")); + return absl::OkStatus(); } - frame->value_stack().Push(iter_range, iter_range_attr); - current_index += 1; - CelValue current_value = (*cel_list)[current_index]; - frame->value_stack().Push(CelValue::CreateInt64(current_index)); - auto iter_trail = iter_range_attr.Step( - CelAttributeQualifier::Create(CelValue::CreateInt64(current_index)), - frame->arena()); - frame->value_stack().Push(current_value, iter_trail); - RETURN_IF_ERROR(frame->SetIterVar(iter_var_, current_value, iter_trail)); + ABSL_DCHECK(range_iter != nullptr); + + ComprehensionSlots::Slot* accu_slot = + frame.comprehension_slots().Get(accu_slot_); + ABSL_DCHECK(accu_slot != nullptr); + + { + Value accu_init; + AttributeTrail accu_init_attr; + CEL_RETURN_IF_ERROR(accu_init_->Evaluate(frame, accu_init, accu_init_attr)); + accu_slot->Set(std::move(accu_init), std::move(accu_init_attr)); + } + + ComprehensionSlots::Slot* iter_slot = + frame.comprehension_slots().Get(iter_slot_); + ABSL_DCHECK(iter_slot != nullptr); + iter_slot->Set(); + + bool should_skip_result; + if (frame.unknown_processing_enabled()) { + CEL_ASSIGN_OR_RETURN( + should_skip_result, + Evaluate1Unknown(frame, iterable_kind, range_attr, range_iter.get(), + accu_slot, iter_slot, result, trail)); + } else { + CEL_ASSIGN_OR_RETURN(should_skip_result, + Evaluate1Known(frame, range_iter.get(), accu_slot, + iter_slot, result, trail)); + } + + frame.comprehension_slots().ClearSlot(iter_slot_); + if (!should_skip_result) { + CEL_RETURN_IF_ERROR(result_step_->Evaluate(frame, result, trail)); + } + frame.comprehension_slots().ClearSlot(accu_slot_); return absl::OkStatus(); } -ComprehensionCondStep::ComprehensionCondStep(const std::string&, - const std::string& iter_var, - bool shortcircuiting, - int64_t expr_id) - : ExpressionStepBase(expr_id, false), - iter_var_(iter_var), - shortcircuiting_(shortcircuiting) {} +absl::StatusOr ComprehensionDirectStep::Evaluate1Unknown( + ExecutionFrameBase& frame, IterableKind range_iter_kind, + const AttributeTrail& range_iter_attr, + ValueIterator* absl_nonnull range_iter, + ComprehensionSlots::Slot* absl_nonnull accu_slot, + ComprehensionSlots::Slot* absl_nonnull iter_slot, Value& result, + AttributeTrail& trail) const { + Value condition; + AttributeTrail condition_attr; + Value key_or_value; + Value* key; + Value* value; + + switch (range_iter_kind) { + case IterableKind::kList: + key = &key_or_value; + value = iter_slot->mutable_value(); + break; + case IterableKind::kMap: + key = iter_slot->mutable_value(); + value = nullptr; + break; + default: + ABSL_UNREACHABLE(); + } + while (true) { + CEL_ASSIGN_OR_RETURN(bool ok, range_iter->Next2(frame.descriptor_pool(), + frame.message_factory(), + frame.arena(), key, value)); + if (!ok) { + break; + } + CEL_RETURN_IF_ERROR(frame.IncrementIterations()); + *iter_slot->mutable_attribute() = + range_iter_attr.Step(AttributeQualifierFromValue(*key)); + if (frame.attribute_utility().CheckForUnknownExact( + iter_slot->attribute())) { + *iter_slot->mutable_value() = frame.attribute_utility().CreateUnknownSet( + iter_slot->attribute().attribute()); + } + + // Evaluate the loop condition. + CEL_RETURN_IF_ERROR(condition_->Evaluate(frame, condition, condition_attr)); + + switch (condition.kind()) { + case ValueKind::kBool: + break; + case ValueKind::kError: + ABSL_FALLTHROUGH_INTENDED; + case ValueKind::kUnknown: + result = std::move(condition); + return true; + default: + result = + cel::ErrorValue(CreateNoMatchingOverloadError("")); + return true; + } -void ComprehensionCondStep::set_jump_offset(int offset) { - jump_offset_ = offset; + if (shortcircuiting_ && !absl::implicit_cast(condition.GetBool())) { + break; + } + + // Evaluate the loop step. + CEL_RETURN_IF_ERROR(loop_step_->Evaluate(frame, *accu_slot->mutable_value(), + *accu_slot->mutable_attribute())); + } + return false; } -void ComprehensionCondStep::set_error_jump_offset(int offset) { - error_jump_offset_ = offset; +absl::StatusOr ComprehensionDirectStep::Evaluate1Known( + ExecutionFrameBase& frame, ValueIterator* absl_nonnull range_iter, + ComprehensionSlots::Slot* absl_nonnull accu_slot, + ComprehensionSlots::Slot* absl_nonnull iter_slot, Value& result, + AttributeTrail& trail) const { + Value condition; + AttributeTrail condition_attr; + + while (true) { + CEL_ASSIGN_OR_RETURN( + bool ok, + range_iter->Next1(frame.descriptor_pool(), frame.message_factory(), + frame.arena(), iter_slot->mutable_value())); + if (!ok) { + break; + } + CEL_RETURN_IF_ERROR(frame.IncrementIterations()); + + // Evaluate the loop condition. + CEL_RETURN_IF_ERROR(condition_->Evaluate(frame, condition, condition_attr)); + + switch (condition.kind()) { + case ValueKind::kBool: + break; + case ValueKind::kError: + ABSL_FALLTHROUGH_INTENDED; + case ValueKind::kUnknown: + result = std::move(condition); + return true; + default: + result = + cel::ErrorValue(CreateNoMatchingOverloadError("")); + return true; + } + + if (shortcircuiting_ && !absl::implicit_cast(condition.GetBool())) { + break; + } + + // Evaluate the loop step. + CEL_RETURN_IF_ERROR(loop_step_->Evaluate(frame, *accu_slot->mutable_value(), + *accu_slot->mutable_attribute())); + } + return false; } -// Stack changes by ComprehensionCondStep. -// -// Stack size before: 5. -// Stack size after: 4. -// Stack size on break: 1. -absl::Status ComprehensionCondStep::Evaluate(ExecutionFrame* frame) const { - if (!frame->value_stack().HasEnough(5)) { - return absl::Status(absl::StatusCode::kInternal, "Value stack underflow"); +absl::Status ComprehensionDirectStep::Evaluate2(ExecutionFrameBase& frame, + Value& result, + AttributeTrail& trail) const { + Value range; + AttributeTrail range_attr; + CEL_RETURN_IF_ERROR(range_->Evaluate(frame, range, range_attr)); + + if (frame.unknown_processing_enabled() && range.IsMap()) { + if (frame.attribute_utility().CheckForUnknownPartial(range_attr)) { + result = + frame.attribute_utility().CreateUnknownSet(range_attr.attribute()); + return absl::OkStatus(); + } + } + + absl_nullability_unknown ValueIteratorPtr range_iter; + switch (range.kind()) { + case ValueKind::kList: { + CEL_ASSIGN_OR_RETURN(range_iter, range.GetList().NewIterator()); + } break; + case ValueKind::kMap: { + CEL_ASSIGN_OR_RETURN(range_iter, range.GetMap().NewIterator()); + } break; + case ValueKind::kError: + ABSL_FALLTHROUGH_INTENDED; + case ValueKind::kUnknown: + result = std::move(range); + return absl::OkStatus(); + default: + result = cel::ErrorValue(CreateNoMatchingOverloadError("")); + return absl::OkStatus(); + } + ABSL_DCHECK(range_iter != nullptr); + + ComprehensionSlots::Slot* accu_slot = + frame.comprehension_slots().Get(accu_slot_); + ABSL_DCHECK(accu_slot != nullptr); + + { + Value accu_init; + AttributeTrail accu_init_attr; + CEL_RETURN_IF_ERROR(accu_init_->Evaluate(frame, accu_init, accu_init_attr)); + accu_slot->Set(std::move(accu_init), std::move(accu_init_attr)); } - CelValue loop_condition_value = frame->value_stack().Peek(); - if (!loop_condition_value.IsBool()) { - frame->value_stack().Pop(5); - if (loop_condition_value.IsError() || loop_condition_value.IsUnknownSet()) { - frame->value_stack().Push(loop_condition_value); - } else { - frame->value_stack().Push( - CreateNoMatchingOverloadError(frame->arena(), "")); + + ComprehensionSlots::Slot* iter_slot = + frame.comprehension_slots().Get(iter_slot_); + ABSL_DCHECK(iter_slot != nullptr); + iter_slot->Set(); + + ComprehensionSlots::Slot* iter2_slot = + frame.comprehension_slots().Get(iter2_slot_); + ABSL_DCHECK(iter2_slot != nullptr); + iter2_slot->Set(); + + Value condition; + AttributeTrail condition_attr; + bool should_skip_result = false; + + while (true) { + CEL_ASSIGN_OR_RETURN( + bool ok, + range_iter->Next2(frame.descriptor_pool(), frame.message_factory(), + frame.arena(), iter_slot->mutable_value(), + iter2_slot->mutable_value())); + if (!ok) { + break; } - // The error jump skips the ComprehensionFinish clean-up step, so we - // need to update the iteration variable stack here. - RETURN_IF_ERROR(frame->PopIterFrame()); - return frame->JumpTo(error_jump_offset_); + CEL_RETURN_IF_ERROR(frame.IncrementIterations()); + if (frame.unknown_processing_enabled()) { + *iter_slot->mutable_attribute() = *iter2_slot->mutable_attribute() = + range_attr.Step(AttributeQualifierFromValue(iter_slot->value())); + if (frame.attribute_utility().CheckForUnknownExact( + iter_slot->attribute())) { + *iter2_slot->mutable_value() = + frame.attribute_utility().CreateUnknownSet( + iter_slot->attribute().attribute()); + } + } + + // Evaluate the loop condition. + CEL_RETURN_IF_ERROR(condition_->Evaluate(frame, condition, condition_attr)); + + switch (condition.kind()) { + case ValueKind::kBool: + break; + case ValueKind::kError: + ABSL_FALLTHROUGH_INTENDED; + case ValueKind::kUnknown: + result = std::move(condition); + should_skip_result = true; + goto finish; + default: + result = + cel::ErrorValue(CreateNoMatchingOverloadError("")); + should_skip_result = true; + goto finish; + } + + if (shortcircuiting_ && !absl::implicit_cast(condition.GetBool())) { + break; + } + + // Evaluate the loop step. + CEL_RETURN_IF_ERROR(loop_step_->Evaluate(frame, *accu_slot->mutable_value(), + *accu_slot->mutable_attribute())); } - bool loop_condition = loop_condition_value.BoolOrDie(); - frame->value_stack().Pop(1); // loop_condition - if (!loop_condition && shortcircuiting_) { - frame->value_stack().Pop(3); // current_value, current_index, iter_range - return frame->JumpTo(jump_offset_); + +finish: + iter_slot->Clear(); + iter2_slot->Clear(); + if (!should_skip_result) { + CEL_RETURN_IF_ERROR(result_step_->Evaluate(frame, result, trail)); } + accu_slot->Clear(); return absl::OkStatus(); } -ComprehensionFinish::ComprehensionFinish(const std::string& accu_var, - const std::string&, int64_t expr_id) - : ExpressionStepBase(expr_id), accu_var_(accu_var) {} +} // namespace -// Stack changes of ComprehensionFinish. -// -// Stack size before: 2. -// Stack size after: 1. -absl::Status ComprehensionFinish::Evaluate(ExecutionFrame* frame) const { - if (!frame->value_stack().HasEnough(2)) { +absl::Status ComprehensionInitStep::Evaluate(ExecutionFrame* frame) const { + if (!frame->value_stack().HasEnough(1)) { return absl::Status(absl::StatusCode::kInternal, "Value stack underflow"); } - CelValue result = frame->value_stack().Peek(); - frame->value_stack().Pop(1); // result - frame->value_stack().PopAndPush(result); - RETURN_IF_ERROR(frame->PopIterFrame()); + + const Value& top = frame->value_stack().Peek(); + if (top.IsError() || top.IsUnknown()) { + return frame->JumpTo(error_jump_offset_); + } + + if (frame->enable_unknowns() && top.IsMap()) { + const AttributeTrail& top_attr = frame->value_stack().PeekAttribute(); + if (frame->attribute_utility().CheckForUnknownPartial(top_attr)) { + frame->value_stack().PopAndPush( + frame->attribute_utility().CreateUnknownSet(top_attr.attribute())); + return frame->JumpTo(error_jump_offset_); + } + } + + switch (top.kind()) { + case ValueKind::kList: { + CEL_ASSIGN_OR_RETURN(auto iterator, top.GetList().NewIterator()); + frame->iterator_stack().Push(std::move(iterator)); + } break; + case ValueKind::kMap: { + CEL_ASSIGN_OR_RETURN(auto iterator, top.GetMap().NewIterator()); + frame->iterator_stack().Push(std::move(iterator)); + } break; + default: + // Replace with an error and jump past + // ComprehensionFinishStep. + frame->value_stack().PopAndPush( + cel::ErrorValue(CreateNoMatchingOverloadError(""))); + return frame->JumpTo(error_jump_offset_); + } + return absl::OkStatus(); } -class ListKeysStep : public ExpressionStepBase { - public: - ListKeysStep(int64_t expr_id) : ExpressionStepBase(expr_id, false) {} - absl::Status Evaluate(ExecutionFrame* frame) const override; +absl::Status ComprehensionNextStep::Evaluate1(ExecutionFrame* frame) const { + if (!frame->value_stack().HasEnough(2)) { + return absl::Status(absl::StatusCode::kInternal, "Value stack underflow"); + } - private: - absl::Status ProjectKeys(ExecutionFrame* frame) const; -}; + { + Value& accu_var = frame->value_stack().Peek(); + AttributeTrail& accu_var_attr = frame->value_stack().PeekAttribute(); + frame->comprehension_slots().Set(accu_slot_, std::move(accu_var), + std::move(accu_var_attr)); + frame->value_stack().Pop(1); + } -std::unique_ptr CreateListKeysStep(int64_t expr_id) { - return absl::make_unique(expr_id); + ComprehensionSlots::Slot* iter_slot = + frame->comprehension_slots().Get(iter_slot_); + ABSL_DCHECK(iter_slot != nullptr); + iter_slot->Set(); + + if (frame->enable_unknowns()) { + Value key_or_value; + Value* key; + Value* value; + switch (frame->value_stack().Peek().kind()) { + case ValueKind::kList: + key = &key_or_value; + value = iter_slot->mutable_value(); + break; + case ValueKind::kMap: + key = iter_slot->mutable_value(); + value = nullptr; + break; + default: + ABSL_UNREACHABLE(); + } + CEL_ASSIGN_OR_RETURN(bool ok, + frame->iterator_stack().Peek()->Next2( + frame->descriptor_pool(), frame->message_factory(), + frame->arena(), key, value)); + if (!ok) { + iter_slot->Clear(); + return frame->JumpTo(jump_offset_); + } + CEL_RETURN_IF_ERROR(frame->IncrementIterations()); + *iter_slot->mutable_attribute() = frame->value_stack().PeekAttribute().Step( + AttributeQualifierFromValue(*key)); + if (frame->attribute_utility().CheckForUnknownExact( + iter_slot->attribute())) { + *iter_slot->mutable_value() = frame->attribute_utility().CreateUnknownSet( + iter_slot->attribute().attribute()); + } + } else { + CEL_ASSIGN_OR_RETURN(bool ok, + frame->iterator_stack().Peek()->Next1( + frame->descriptor_pool(), frame->message_factory(), + frame->arena(), iter_slot->mutable_value())); + if (!ok) { + iter_slot->Clear(); + return frame->JumpTo(jump_offset_); + } + CEL_RETURN_IF_ERROR(frame->IncrementIterations()); + } + return absl::OkStatus(); } -absl::Status ListKeysStep::ProjectKeys(ExecutionFrame* frame) const { - // Top of stack is map, but could be partially unknown. To tolerate cases when - // keys are not set for declared unknown values, convert to an unknown set. +absl::Status ComprehensionNextStep::Evaluate2(ExecutionFrame* frame) const { + if (!frame->value_stack().HasEnough(2)) { + return absl::Status(absl::StatusCode::kInternal, "Value stack underflow"); + } + + { + Value& accu_var = frame->value_stack().Peek(); + AttributeTrail& accu_var_attr = frame->value_stack().PeekAttribute(); + frame->comprehension_slots().Set(accu_slot_, std::move(accu_var), + std::move(accu_var_attr)); + frame->value_stack().Pop(1); + } + + ComprehensionSlots::Slot* iter_slot = + frame->comprehension_slots().Get(iter_slot_); + ABSL_DCHECK(iter_slot != nullptr); + iter_slot->Set(); + + ComprehensionSlots::Slot* iter2_slot = + frame->comprehension_slots().Get(iter2_slot_); + ABSL_DCHECK(iter2_slot != nullptr); + iter2_slot->Set(); + + CEL_ASSIGN_OR_RETURN( + bool ok, + frame->iterator_stack().Peek()->Next2( + frame->descriptor_pool(), frame->message_factory(), frame->arena(), + iter_slot->mutable_value(), iter2_slot->mutable_value())); + if (!ok) { + iter_slot->Clear(); + iter2_slot->Clear(); + return frame->JumpTo(jump_offset_); + } + CEL_RETURN_IF_ERROR(frame->IncrementIterations()); if (frame->enable_unknowns()) { - const UnknownSet* unknown = frame->attribute_utility().MergeUnknowns( - frame->value_stack().GetSpan(1), - frame->value_stack().GetAttributeSpan(1), nullptr, - /*use_partial=*/true); - if (unknown) { - frame->value_stack().PopAndPush(CelValue::CreateUnknownSet(unknown)); - return absl::OkStatus(); + *iter_slot->mutable_attribute() = *iter2_slot->mutable_attribute() = + frame->value_stack().PeekAttribute().Step( + AttributeQualifierFromValue(iter_slot->value())); + if (frame->attribute_utility().CheckForUnknownExact( + iter2_slot->attribute())) { + *iter2_slot->mutable_value() = + frame->attribute_utility().CreateUnknownSet( + iter2_slot->attribute().attribute()); } } + return absl::OkStatus(); +} - const CelValue& map = frame->value_stack().Peek(); - frame->value_stack().PopAndPush( - CelValue::CreateList(map.MapOrDie()->ListKeys())); +absl::Status ComprehensionCondStep::Evaluate1(ExecutionFrame* frame) const { + if (!frame->value_stack().HasEnough(2)) { + return absl::Status(absl::StatusCode::kInternal, "Value stack underflow"); + } + const Value& top = frame->value_stack().Peek(); + switch (top.kind()) { + case ValueKind::kBool: + break; + case ValueKind::kError: + ABSL_FALLTHROUGH_INTENDED; + case ValueKind::kUnknown: + frame->value_stack().SwapAndPop(2, 1); + frame->comprehension_slots().ClearSlot(iter_slot_); + frame->comprehension_slots().ClearSlot(accu_slot_); + frame->iterator_stack().Pop(); + return frame->JumpTo(error_jump_offset_); + default: + frame->value_stack().PopAndPush( + 2, + cel::ErrorValue(CreateNoMatchingOverloadError(""))); + frame->comprehension_slots().ClearSlot(iter_slot_); + frame->comprehension_slots().ClearSlot(accu_slot_); + frame->iterator_stack().Pop(); + return frame->JumpTo(error_jump_offset_); + } + const bool loop_condition = absl::implicit_cast(top.GetBool()); + frame->value_stack().Pop(1); // loop_condition + if (!loop_condition && shortcircuiting_) { + return frame->JumpTo(jump_offset_); + } return absl::OkStatus(); } -absl::Status ListKeysStep::Evaluate(ExecutionFrame* frame) const { - if (!frame->value_stack().HasEnough(1)) { +absl::Status ComprehensionCondStep::Evaluate2(ExecutionFrame* frame) const { + if (!frame->value_stack().HasEnough(2)) { return absl::Status(absl::StatusCode::kInternal, "Value stack underflow"); } - const CelValue& map_value = frame->value_stack().Peek(); - if (map_value.IsMap()) { - return ProjectKeys(frame); + const Value& top = frame->value_stack().Peek(); + switch (top.kind()) { + case ValueKind::kBool: + break; + case ValueKind::kError: + ABSL_FALLTHROUGH_INTENDED; + case ValueKind::kUnknown: + frame->value_stack().SwapAndPop(2, 1); + frame->comprehension_slots().ClearSlot(iter_slot_); + frame->comprehension_slots().ClearSlot(iter2_slot_); + frame->comprehension_slots().ClearSlot(accu_slot_); + frame->iterator_stack().Pop(); + return frame->JumpTo(error_jump_offset_); + default: + frame->value_stack().PopAndPush( + 2, + cel::ErrorValue(CreateNoMatchingOverloadError(""))); + frame->comprehension_slots().ClearSlot(iter_slot_); + frame->comprehension_slots().ClearSlot(iter2_slot_); + frame->comprehension_slots().ClearSlot(accu_slot_); + frame->iterator_stack().Pop(); + return frame->JumpTo(error_jump_offset_); + } + const bool loop_condition = absl::implicit_cast(top.GetBool()); + frame->value_stack().Pop(1); // loop_condition + if (!loop_condition && shortcircuiting_) { + return frame->JumpTo(jump_offset_); } return absl::OkStatus(); } -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +std::unique_ptr CreateDirectComprehensionStep( + size_t iter_slot, size_t iter2_slot, size_t accu_slot, + std::unique_ptr range, + std::unique_ptr accu_init, + std::unique_ptr loop_step, + std::unique_ptr condition_step, + std::unique_ptr result_step, bool shortcircuiting, + int64_t expr_id) { + return std::make_unique( + iter_slot, iter2_slot, accu_slot, std::move(range), std::move(accu_init), + std::move(loop_step), std::move(condition_step), std::move(result_step), + shortcircuiting, expr_id); +} + +std::unique_ptr CreateComprehensionFinishStep(size_t accu_slot, + int64_t expr_id) { + return std::make_unique(accu_slot, expr_id); +} + +} // namespace google::api::expr::runtime diff --git a/eval/eval/comprehension_step.h b/eval/eval/comprehension_step.h index 93ebcb091..34a6afc19 100644 --- a/eval/eval/comprehension_step.h +++ b/eval/eval/comprehension_step.h @@ -1,71 +1,119 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_COMPREHENSION_STEP_H_ #define THIRD_PARTY_CEL_CPP_EVAL_EVAL_COMPREHENSION_STEP_H_ +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" #include "eval/eval/expression_step_base.h" -#include "eval/public/activation.h" -#include "eval/public/cel_function.h" -#include "eval/public/cel_value.h" -#include "google/api/expr/v1alpha1/syntax.pb.h" - -namespace google { -namespace api { -namespace expr { -namespace runtime { -class ComprehensionNextStep : public ExpressionStepBase { +namespace google::api::expr::runtime { + +// Comprehension Evaluation +// +// 0: 1 -> 1 +// 1: ComprehensionInitStep 1 -> 1 +// 2: 1 -> 2 +// 3: ComprehensionNextStep 2 -> 1 +// 4: 1 -> 2 +// 5: ComprehensionCondStep 2 -> 1 +// 6: 1 -> 2 +// 8: 1 -> 2 +// 9: ComprehensionFinishStep 2 -> 1 + +class ComprehensionInitStep final : public ExpressionStepBase { public: - ComprehensionNextStep(const std::string& accu_var, - const std::string& iter_var, int64_t expr_id); + explicit ComprehensionInitStep(int64_t expr_id) + : ExpressionStepBase(expr_id, /*comes_from_ast=*/false) {} - void set_jump_offset(int offset); - void set_error_jump_offset(int offset); + void set_error_jump_offset(int offset) { error_jump_offset_ = offset; } absl::Status Evaluate(ExecutionFrame* frame) const override; private: - std::string accu_var_; - std::string iter_var_; - int jump_offset_; - int error_jump_offset_; + int error_jump_offset_ = std::numeric_limits::max(); }; -class ComprehensionCondStep : public ExpressionStepBase { +class ComprehensionNextStep final : public ExpressionStepBase { public: - ComprehensionCondStep(const std::string& accu_var, - const std::string& iter_var, bool shortcircuiting, - int64_t expr_id); + ComprehensionNextStep(size_t iter_slot, size_t iter2_slot, size_t accu_slot, + int64_t expr_id) + : ExpressionStepBase(expr_id, /*comes_from_ast=*/false), + iter_slot_(iter_slot), + iter2_slot_(iter2_slot), + accu_slot_(accu_slot) {} - void set_jump_offset(int offset); - void set_error_jump_offset(int offset); + void set_jump_offset(int offset) { jump_offset_ = offset; } - absl::Status Evaluate(ExecutionFrame* frame) const override; + void set_error_jump_offset(int offset) { error_jump_offset_ = offset; } + + absl::Status Evaluate(ExecutionFrame* frame) const override { + return iter_slot_ == iter2_slot_ ? Evaluate1(frame) : Evaluate2(frame); + } private: - std::string iter_var_; - int jump_offset_; - int error_jump_offset_; - bool shortcircuiting_; + absl::Status Evaluate1(ExecutionFrame* frame) const; + + absl::Status Evaluate2(ExecutionFrame* frame) const; + + const size_t iter_slot_; + const size_t iter2_slot_; + const size_t accu_slot_; + int jump_offset_ = std::numeric_limits::max(); + int error_jump_offset_ = std::numeric_limits::max(); }; -class ComprehensionFinish : public ExpressionStepBase { +class ComprehensionCondStep final : public ExpressionStepBase { public: - ComprehensionFinish(const std::string& accu_var, const std::string& iter_var, - int64_t expr_id); + ComprehensionCondStep(size_t iter_slot, size_t iter2_slot, size_t accu_slot, + bool shortcircuiting, int64_t expr_id) + : ExpressionStepBase(expr_id, /*comes_from_ast=*/false), + iter_slot_(iter_slot), + iter2_slot_(iter2_slot), + accu_slot_(accu_slot), + shortcircuiting_(shortcircuiting) {} - absl::Status Evaluate(ExecutionFrame* frame) const override; + void set_jump_offset(int offset) { jump_offset_ = offset; } + + void set_error_jump_offset(int offset) { error_jump_offset_ = offset; } + + absl::Status Evaluate(ExecutionFrame* frame) const override { + return iter_slot_ == iter2_slot_ ? Evaluate1(frame) : Evaluate2(frame); + } private: - std::string accu_var_; -}; + absl::Status Evaluate1(ExecutionFrame* frame) const; + + absl::Status Evaluate2(ExecutionFrame* frame) const; -// Creates a step that lists the map keys if the top of the stack is a map, -// otherwise it's a no-op. -std::unique_ptr CreateListKeysStep(int64_t expr_id); + const size_t iter_slot_; + const size_t iter2_slot_; + const size_t accu_slot_; + int jump_offset_ = std::numeric_limits::max(); + int error_jump_offset_ = std::numeric_limits::max(); + const bool shortcircuiting_; +}; -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +// Creates a step for executing a comprehension. +std::unique_ptr CreateDirectComprehensionStep( + size_t iter_slot, size_t iter2_slot, size_t accu_slot, + std::unique_ptr range, + std::unique_ptr accu_init, + std::unique_ptr loop_step, + std::unique_ptr condition_step, + std::unique_ptr result_step, bool shortcircuiting, + int64_t expr_id); + +// Creates a cleanup step for the comprehension. +// Removes the comprehension context then pushes the 'result' sub expression to +// the top of the stack. +std::unique_ptr CreateComprehensionFinishStep(size_t accu_slot, + int64_t expr_id); + +} // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_COMPREHENSION_STEP_H_ diff --git a/eval/eval/comprehension_step_test.cc b/eval/eval/comprehension_step_test.cc index 5c6f6769a..681f8af4f 100644 --- a/eval/eval/comprehension_step_test.cc +++ b/eval/eval/comprehension_step_test.cc @@ -1,36 +1,60 @@ #include "eval/eval/comprehension_step.h" -#include +#include +#include +#include +#include -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "google/protobuf/struct.pb.h" -#include "google/protobuf/wrappers.pb.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" +#include "absl/memory/memory.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "base/type_provider.h" +#include "common/expr.h" +#include "common/value.h" +#include "common/value_testing.h" +#include "eval/eval/attribute_trail.h" +#include "eval/eval/cel_expression_flat_impl.h" +#include "eval/eval/comprehension_slots.h" +#include "eval/eval/const_value_step.h" +#include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" +#include "eval/eval/expression_step_base.h" #include "eval/eval/ident_step.h" +#include "eval/public/activation.h" #include "eval/public/cel_attribute.h" -#include "eval/public/cel_options.h" #include "eval/public/cel_value.h" #include "eval/public/structs/cel_proto_wrapper.h" -#include "base/status_macros.h" - -namespace google { -namespace api { -namespace expr { -namespace runtime { +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "internal/testing_message_factory.h" +#include "runtime/activation.h" +#include "runtime/internal/runtime_env_testing.h" +#include "runtime/internal/runtime_type_provider.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" + +namespace google::api::expr::runtime { namespace { -using ::google::protobuf::ListValue; +using ::absl_testing::StatusIs; +using ::cel::BoolValue; +using ::cel::Expr; +using ::cel::IdentExpr; +using ::cel::IntValue; +using ::cel::TypeProvider; +using ::cel::Value; +using ::cel::runtime_internal::NewTestingRuntimeEnv; +using ::cel::test::BoolValueIs; using ::google::protobuf::Struct; using ::google::protobuf::Arena; -using testing::Eq; -using testing::SizeIs; - -using IdentExpr = google::api::expr::v1alpha1::Expr::Ident; -using Expr = google::api::expr::v1alpha1::Expr; +using ::testing::_; +using ::testing::Eq; +using ::testing::Return; +using ::testing::SizeIs; IdentExpr CreateIdent(const std::string& var) { IdentExpr expr; @@ -40,97 +64,51 @@ IdentExpr CreateIdent(const std::string& var) { class ListKeysStepTest : public testing::Test { public: - ListKeysStepTest() {} + ListKeysStepTest() = default; std::unique_ptr MakeExpression( ExecutionPath&& path, bool unknown_attributes = false) { + cel::RuntimeOptions options; + if (unknown_attributes) { + options.unknown_processing = + cel::UnknownProcessingOptions::kAttributeAndFunction; + } + auto env = NewTestingRuntimeEnv(); return std::make_unique( - &dummy_expr_, std::move(path), 0, std::set(), - unknown_attributes, unknown_attributes); + env, + FlatExpression(std::move(path), /*comprehension_slot_count=*/0, + env->type_registry.GetComposedTypeProvider(), options)); } private: Expr dummy_expr_; }; +class GetListKeysResultStep : public ExpressionStepBase { + public: + GetListKeysResultStep() : ExpressionStepBase(-1, false) {} + + absl::Status Evaluate(ExecutionFrame* frame) const override { + frame->value_stack().Pop(1); + return absl::OkStatus(); + } +}; + MATCHER_P(CelStringValue, val, "") { const CelValue& to_match = arg; absl::string_view value = val; return to_match.IsString() && to_match.StringOrDie().value() == value; } -TEST_F(ListKeysStepTest, ListPassedThrough) { - ExecutionPath path; - IdentExpr ident = CreateIdent("var"); - auto result = CreateIdentStep(&ident, 0); - ASSERT_OK(result); - path.push_back(std::move(result.value())); - result = CreateListKeysStep(1); - ASSERT_OK(result); - path.push_back(std::move(result.value())); - - auto expression = MakeExpression(std::move(path)); - - Activation activation; - Arena arena; - ListValue value; - value.add_values()->set_number_value(1.0); - value.add_values()->set_number_value(2.0); - value.add_values()->set_number_value(3.0); - activation.InsertValue("var", CelProtoWrapper::CreateMessage(&value, &arena)); - - auto eval_result = expression->Evaluate(activation, &arena); - - ASSERT_OK(eval_result); - ASSERT_TRUE(eval_result->IsList()); - EXPECT_THAT(*eval_result->ListOrDie(), SizeIs(3)); -} - -TEST_F(ListKeysStepTest, MapToKeyList) { - ExecutionPath path; - IdentExpr ident = CreateIdent("var"); - auto result = CreateIdentStep(&ident, 0); - ASSERT_OK(result); - path.push_back(std::move(result.value())); - result = CreateListKeysStep(1); - ASSERT_OK(result); - path.push_back(std::move(result.value())); - - auto expression = MakeExpression(std::move(path)); - - Activation activation; - Arena arena; - Struct value; - (*value.mutable_fields())["key1"].set_number_value(1.0); - (*value.mutable_fields())["key2"].set_number_value(2.0); - (*value.mutable_fields())["key3"].set_number_value(3.0); - - activation.InsertValue("var", CelProtoWrapper::CreateMessage(&value, &arena)); - - auto eval_result = expression->Evaluate(activation, &arena); - - ASSERT_OK(eval_result); - ASSERT_TRUE(eval_result->IsList()); - EXPECT_THAT(*eval_result->ListOrDie(), SizeIs(3)); - std::vector keys; - keys.reserve(eval_result->ListOrDie()->size()); - for (int i = 0; i < eval_result->ListOrDie()->size(); i++) { - keys.push_back(eval_result->ListOrDie()->operator[](i)); - } - EXPECT_THAT(keys, testing::UnorderedElementsAre(CelStringValue("key1"), - CelStringValue("key2"), - CelStringValue("key3"))); -} - TEST_F(ListKeysStepTest, MapPartiallyUnknown) { ExecutionPath path; - IdentExpr ident = CreateIdent("var"); - auto result = CreateIdentStep(&ident, 0); + auto result = CreateIdentStep("var", 0); ASSERT_OK(result); - path.push_back(std::move(result.value())); - result = CreateListKeysStep(1); - ASSERT_OK(result); - path.push_back(std::move(result.value())); + path.push_back(*std::move(result)); + ComprehensionInitStep* init_step = new ComprehensionInitStep(1); + init_step->set_error_jump_offset(1); + path.push_back(absl::WrapUnique(init_step)); + path.push_back(std::make_unique()); auto expression = MakeExpression(std::move(path), /*unknown_attributes=*/true); @@ -145,31 +123,30 @@ TEST_F(ListKeysStepTest, MapPartiallyUnknown) { activation.InsertValue("var", CelProtoWrapper::CreateMessage(&value, &arena)); activation.set_unknown_attribute_patterns({CelAttributePattern( "var", - {CelAttributeQualifierPattern::Create(CelValue::CreateStringView("key2")), - CelAttributeQualifierPattern::Create(CelValue::CreateStringView("foo")), + {CreateCelAttributeQualifierPattern(CelValue::CreateStringView("key2")), + CreateCelAttributeQualifierPattern(CelValue::CreateStringView("foo")), CelAttributeQualifierPattern::CreateWildcard()})}); auto eval_result = expression->Evaluate(activation, &arena); ASSERT_OK(eval_result); ASSERT_TRUE(eval_result->IsUnknownSet()); - const auto& attrs = - eval_result->UnknownSetOrDie()->unknown_attributes().attributes(); + const auto& attrs = eval_result->UnknownSetOrDie()->unknown_attributes(); EXPECT_THAT(attrs, SizeIs(1)); - EXPECT_THAT(attrs.at(0)->variable().ident_expr().name(), Eq("var")); - EXPECT_THAT(attrs.at(0)->qualifier_path(), SizeIs(0)); + EXPECT_THAT(attrs.begin()->variable_name(), Eq("var")); + EXPECT_THAT(attrs.begin()->qualifier_path(), SizeIs(0)); } TEST_F(ListKeysStepTest, ErrorPassedThrough) { ExecutionPath path; - IdentExpr ident = CreateIdent("var"); - auto result = CreateIdentStep(&ident, 0); - ASSERT_OK(result); - path.push_back(std::move(result.value())); - result = CreateListKeysStep(1); + auto result = CreateIdentStep("var", 0); ASSERT_OK(result); - path.push_back(std::move(result.value())); + path.push_back(*std::move(result)); + ComprehensionInitStep* init_step = new ComprehensionInitStep(1); + init_step->set_error_jump_offset(1); + path.push_back(absl::WrapUnique(init_step)); + path.push_back(std::make_unique()); auto expression = MakeExpression(std::move(path)); @@ -188,13 +165,13 @@ TEST_F(ListKeysStepTest, ErrorPassedThrough) { TEST_F(ListKeysStepTest, UnknownSetPassedThrough) { ExecutionPath path; - IdentExpr ident = CreateIdent("var"); - auto result = CreateIdentStep(&ident, 0); - ASSERT_OK(result); - path.push_back(std::move(result.value())); - result = CreateListKeysStep(1); + auto result = CreateIdentStep("var", 0); ASSERT_OK(result); - path.push_back(std::move(result.value())); + path.push_back(*std::move(result)); + ComprehensionInitStep* init_step = new ComprehensionInitStep(1); + init_step->set_error_jump_offset(1); + path.push_back(absl::WrapUnique(init_step)); + path.push_back(std::make_unique()); auto expression = MakeExpression(std::move(path), /*unknown_attributes=*/true); @@ -208,12 +185,308 @@ TEST_F(ListKeysStepTest, UnknownSetPassedThrough) { ASSERT_OK(eval_result); ASSERT_TRUE(eval_result->IsUnknownSet()); - EXPECT_THAT(eval_result->UnknownSetOrDie()->unknown_attributes().attributes(), - SizeIs(1)); + EXPECT_THAT(eval_result->UnknownSetOrDie()->unknown_attributes(), SizeIs(1)); +} + +class MockDirectStep : public DirectExpressionStep { + public: + MockDirectStep() : DirectExpressionStep(-1) {} + + MOCK_METHOD(absl::Status, Evaluate, + (ExecutionFrameBase&, Value&, AttributeTrail&), + (const, override)); +}; + +// Test fixture for comprehensions. +// +// Comprehensions are quite involved so tests here focus on edge cases that are +// hard to exercise normally in functional-style tests for the planner. +class DirectComprehensionTest : public testing::Test { + public: + DirectComprehensionTest() + : type_provider_(cel::internal::GetTestingDescriptorPool()), slots_(2) {} + + // returns a two element list for testing [1, 2]. + absl::StatusOr MakeList() { + auto builder = cel::NewListValueBuilder(&arena_); + + CEL_RETURN_IF_ERROR(builder->Add(IntValue(1))); + CEL_RETURN_IF_ERROR(builder->Add(IntValue(2))); + return std::move(*builder).Build(); + } + + protected: + google::protobuf::Arena arena_; + cel::runtime_internal::RuntimeTypeProvider type_provider_; + ComprehensionSlots slots_; + cel::Activation empty_activation_; +}; + +TEST_F(DirectComprehensionTest, PropagateRangeNonOkStatus) { + cel::RuntimeOptions options; + + ExecutionFrameBase frame(empty_activation_, /*callback=*/nullptr, options, + type_provider_, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena_, + /*embedder_context=*/nullptr, slots_); + + auto range_step = std::make_unique(); + MockDirectStep* mock = range_step.get(); + + ON_CALL(*mock, Evaluate(_, _, _)) + .WillByDefault(Return(absl::InternalError("test range error"))); + + auto compre_step = CreateDirectComprehensionStep( + 0, 0, 1, + /*range_step=*/std::move(range_step), + /*accu_init=*/CreateConstValueDirectStep(BoolValue(false)), + /*loop_step=*/CreateConstValueDirectStep(BoolValue(false)), + /*condition_step=*/CreateConstValueDirectStep(BoolValue(true)), + /*result_step=*/CreateDirectSlotIdentStep("__result__", 1, -1), + /*shortcircuiting=*/true, -1); + + Value result; + AttributeTrail trail; + EXPECT_THAT(compre_step->Evaluate(frame, result, trail), + StatusIs(absl::StatusCode::kInternal, "test range error")); +} + +TEST_F(DirectComprehensionTest, PropagateAccuInitNonOkStatus) { + cel::RuntimeOptions options; + + ExecutionFrameBase frame(empty_activation_, /*callback=*/nullptr, options, + type_provider_, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena_, + /*embedder_context=*/nullptr, slots_); + + auto accu_init = std::make_unique(); + MockDirectStep* mock = accu_init.get(); + + ON_CALL(*mock, Evaluate(_, _, _)) + .WillByDefault(Return(absl::InternalError("test accu init error"))); + + ASSERT_OK_AND_ASSIGN(auto list, MakeList()); + + auto compre_step = CreateDirectComprehensionStep( + 0, 0, 1, + /*range_step=*/CreateConstValueDirectStep(std::move(list)), + /*accu_init=*/std::move(accu_init), + /*loop_step=*/CreateConstValueDirectStep(BoolValue(false)), + /*condition_step=*/CreateConstValueDirectStep(BoolValue(true)), + /*result_step=*/CreateDirectSlotIdentStep("__result__", 1, -1), + /*shortcircuiting=*/true, -1); + + Value result; + AttributeTrail trail; + EXPECT_THAT(compre_step->Evaluate(frame, result, trail), + StatusIs(absl::StatusCode::kInternal, "test accu init error")); +} + +TEST_F(DirectComprehensionTest, PropagateLoopNonOkStatus) { + cel::RuntimeOptions options; + + ExecutionFrameBase frame(empty_activation_, /*callback=*/nullptr, options, + type_provider_, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena_, + /*embedder_context=*/nullptr, slots_); + + auto loop_step = std::make_unique(); + MockDirectStep* mock = loop_step.get(); + + ON_CALL(*mock, Evaluate(_, _, _)) + .WillByDefault(Return(absl::InternalError("test loop error"))); + + ASSERT_OK_AND_ASSIGN(auto list, MakeList()); + + auto compre_step = CreateDirectComprehensionStep( + 0, 0, 1, + /*range_step=*/CreateConstValueDirectStep(std::move(list)), + /*accu_init=*/CreateConstValueDirectStep(BoolValue(false)), + /*loop_step=*/std::move(loop_step), + /*condition_step=*/CreateConstValueDirectStep(BoolValue(true)), + /*result_step=*/CreateDirectSlotIdentStep("__result__", 1, -1), + /*shortcircuiting=*/true, -1); + + Value result; + AttributeTrail trail; + EXPECT_THAT(compre_step->Evaluate(frame, result, trail), + StatusIs(absl::StatusCode::kInternal, "test loop error")); +} + +TEST_F(DirectComprehensionTest, PropagateConditionNonOkStatus) { + cel::RuntimeOptions options; + + ExecutionFrameBase frame(empty_activation_, /*callback=*/nullptr, options, + type_provider_, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena_, + /*embedder_context=*/nullptr, slots_); + + auto condition = std::make_unique(); + MockDirectStep* mock = condition.get(); + + ON_CALL(*mock, Evaluate(_, _, _)) + .WillByDefault(Return(absl::InternalError("test condition error"))); + + ASSERT_OK_AND_ASSIGN(auto list, MakeList()); + + auto compre_step = CreateDirectComprehensionStep( + 0, 0, 1, + /*range_step=*/CreateConstValueDirectStep(std::move(list)), + /*accu_init=*/CreateConstValueDirectStep(BoolValue(false)), + /*loop_step=*/CreateConstValueDirectStep(BoolValue(false)), + /*condition_step=*/std::move(condition), + /*result_step=*/CreateDirectSlotIdentStep("__result__", 1, -1), + /*shortcircuiting=*/true, -1); + + Value result; + AttributeTrail trail; + EXPECT_THAT(compre_step->Evaluate(frame, result, trail), + StatusIs(absl::StatusCode::kInternal, "test condition error")); +} + +TEST_F(DirectComprehensionTest, PropagateResultNonOkStatus) { + cel::RuntimeOptions options; + + ExecutionFrameBase frame(empty_activation_, /*callback=*/nullptr, options, + type_provider_, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena_, + /*embedder_context=*/nullptr, slots_); + + auto result_step = std::make_unique(); + MockDirectStep* mock = result_step.get(); + + ON_CALL(*mock, Evaluate(_, _, _)) + .WillByDefault(Return(absl::InternalError("test result error"))); + + ASSERT_OK_AND_ASSIGN(auto list, MakeList()); + + auto compre_step = CreateDirectComprehensionStep( + 0, 0, 1, + /*range_step=*/CreateConstValueDirectStep(std::move(list)), + /*accu_init=*/CreateConstValueDirectStep(BoolValue(false)), + /*loop_step=*/CreateConstValueDirectStep(BoolValue(false)), + /*condition_step=*/CreateConstValueDirectStep(BoolValue(true)), + /*result_step=*/std::move(result_step), + /*shortcircuiting=*/true, -1); + + Value result; + AttributeTrail trail; + EXPECT_THAT(compre_step->Evaluate(frame, result, trail), + StatusIs(absl::StatusCode::kInternal, "test result error")); +} + +TEST_F(DirectComprehensionTest, Shortcircuit) { + cel::RuntimeOptions options; + + ExecutionFrameBase frame(empty_activation_, /*callback=*/nullptr, options, + type_provider_, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena_, + /*embedder_context=*/nullptr, slots_); + + auto loop_step = std::make_unique(); + MockDirectStep* mock = loop_step.get(); + + EXPECT_CALL(*mock, Evaluate(_, _, _)) + .Times(0) + .WillRepeatedly([](ExecutionFrameBase&, Value& result, AttributeTrail&) { + result = BoolValue(false); + return absl::OkStatus(); + }); + + ASSERT_OK_AND_ASSIGN(auto list, MakeList()); + + auto compre_step = CreateDirectComprehensionStep( + 0, 0, 1, + /*range_step=*/CreateConstValueDirectStep(std::move(list)), + /*accu_init=*/CreateConstValueDirectStep(BoolValue(false)), + /*loop_step=*/std::move(loop_step), + /*condition_step=*/CreateConstValueDirectStep(BoolValue(false)), + /*result_step=*/CreateDirectSlotIdentStep("__result__", 1, -1), + /*shortcircuiting=*/true, -1); + + Value result; + AttributeTrail trail; + ASSERT_OK(compre_step->Evaluate(frame, result, trail)); + EXPECT_THAT(result, BoolValueIs(false)); +} + +TEST_F(DirectComprehensionTest, IterationLimit) { + cel::RuntimeOptions options; + options.comprehension_max_iterations = 2; + ExecutionFrameBase frame(empty_activation_, /*callback=*/nullptr, options, + type_provider_, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena_, + /*embedder_context=*/nullptr, slots_); + + auto loop_step = std::make_unique(); + MockDirectStep* mock = loop_step.get(); + + EXPECT_CALL(*mock, Evaluate(_, _, _)) + .Times(1) + .WillRepeatedly([](ExecutionFrameBase&, Value& result, AttributeTrail&) { + result = BoolValue(false); + return absl::OkStatus(); + }); + + ASSERT_OK_AND_ASSIGN(auto list, MakeList()); + + auto compre_step = CreateDirectComprehensionStep( + 0, 0, 1, + /*range_step=*/CreateConstValueDirectStep(std::move(list)), + /*accu_init=*/CreateConstValueDirectStep(BoolValue(false)), + /*loop_step=*/std::move(loop_step), + /*condition_step=*/CreateConstValueDirectStep(BoolValue(true)), + /*result_step=*/CreateDirectSlotIdentStep("__result__", 1, -1), + /*shortcircuiting=*/true, -1); + + Value result; + AttributeTrail trail; + EXPECT_THAT(compre_step->Evaluate(frame, result, trail), + StatusIs(absl::StatusCode::kInternal)); +} + +TEST_F(DirectComprehensionTest, Exhaustive) { + cel::RuntimeOptions options; + + ExecutionFrameBase frame(empty_activation_, /*callback=*/nullptr, options, + type_provider_, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena_, + /*embedder_context=*/nullptr, slots_); + + auto loop_step = std::make_unique(); + MockDirectStep* mock = loop_step.get(); + + EXPECT_CALL(*mock, Evaluate(_, _, _)) + .Times(2) + .WillRepeatedly([](ExecutionFrameBase&, Value& result, AttributeTrail&) { + result = BoolValue(false); + return absl::OkStatus(); + }); + + ASSERT_OK_AND_ASSIGN(auto list, MakeList()); + + auto compre_step = CreateDirectComprehensionStep( + 0, 0, 1, + /*range_step=*/CreateConstValueDirectStep(std::move(list)), + /*accu_init=*/CreateConstValueDirectStep(BoolValue(false)), + /*loop_step=*/std::move(loop_step), + /*condition_step=*/CreateConstValueDirectStep(BoolValue(false)), + /*result_step=*/CreateDirectSlotIdentStep("__result__", 1, -1), + /*shortcircuiting=*/false, -1); + + Value result; + AttributeTrail trail; + ASSERT_OK(compre_step->Evaluate(frame, result, trail)); + EXPECT_THAT(result, BoolValueIs(false)); } } // namespace -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime diff --git a/eval/eval/const_value_step.cc b/eval/eval/const_value_step.cc deleted file mode 100644 index 5cdc216c6..000000000 --- a/eval/eval/const_value_step.cc +++ /dev/null @@ -1,96 +0,0 @@ -#include "eval/eval/const_value_step.h" - -#include "google/protobuf/duration.pb.h" -#include "google/protobuf/timestamp.pb.h" -#include "absl/status/statusor.h" -#include "eval/eval/expression_step_base.h" -#include "eval/public/structs/cel_proto_wrapper.h" - -namespace google { -namespace api { -namespace expr { -namespace runtime { - -using google::api::expr::v1alpha1::Constant; -using google::api::expr::v1alpha1::Expr; - -namespace { - -class ConstValueStep : public ExpressionStepBase { - public: - ConstValueStep(const CelValue& value, int64_t expr_id, bool comes_from_ast) - : ExpressionStepBase(expr_id, comes_from_ast), value_(value) {} - - absl::Status Evaluate(ExecutionFrame* frame) const override; - - private: - CelValue value_; -}; - -absl::Status ConstValueStep::Evaluate(ExecutionFrame* frame) const { - frame->value_stack().Push(value_); - - return absl::OkStatus(); -} - -} // namespace - -absl::optional ConvertConstant(const Constant* const_expr) { - CelValue value = CelValue::CreateNull(); - switch (const_expr->constant_kind_case()) { - case Constant::kNullValue: - value = CelValue::CreateNull(); - break; - case Constant::kBoolValue: - value = CelValue::CreateBool(const_expr->bool_value()); - break; - case Constant::kInt64Value: - value = CelValue::CreateInt64(const_expr->int64_value()); - break; - case Constant::kUint64Value: - value = CelValue::CreateUint64(const_expr->uint64_value()); - break; - case Constant::kDoubleValue: - value = CelValue::CreateDouble(const_expr->double_value()); - break; - case Constant::kStringValue: - value = CelValue::CreateString(&const_expr->string_value()); - break; - case Constant::kBytesValue: - value = CelValue::CreateBytes(&const_expr->bytes_value()); - break; - case Constant::kDurationValue: - value = CelProtoWrapper::CreateDuration(&const_expr->duration_value()); - break; - case Constant::kTimestampValue: - value = CelProtoWrapper::CreateTimestamp(&const_expr->timestamp_value()); - break; - default: - // constant with no kind specified - return {}; - break; - } - return value; -} - -absl::StatusOr> CreateConstValueStep( - CelValue value, int64_t expr_id, bool comes_from_ast) { - std::unique_ptr step = - absl::make_unique(value, expr_id, comes_from_ast); - return std::move(step); -} - -// Factory method for Constant(Enum value) - based Execution step -absl::StatusOr> CreateConstValueStep( - const google::protobuf::EnumValueDescriptor* value_descriptor, int64_t expr_id) { - CelValue value = CelValue::CreateInt64(value_descriptor->number()); - - std::unique_ptr step = - absl::make_unique(value, expr_id, false); - return std::move(step); -} - -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google diff --git a/eval/eval/const_value_step.h b/eval/eval/const_value_step.h index 267dea7b9..c3cf6a424 100644 --- a/eval/eval/const_value_step.h +++ b/eval/eval/const_value_step.h @@ -1,29 +1,31 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_CONST_VALUE_STEP_H_ #define THIRD_PARTY_CEL_CPP_EVAL_EVAL_CONST_VALUE_STEP_H_ -#include "eval/eval/evaluator_core.h" -#include "eval/public/activation.h" -#include "eval/public/cel_value.h" +#include +#include +#include -namespace google { -namespace api { -namespace expr { -namespace runtime { +#include "absl/status/statusor.h" +#include "common/value.h" +#include "eval/eval/compiler_constant_step.h" +#include "eval/eval/direct_expression_step.h" +#include "eval/eval/evaluator_core.h" -absl::optional ConvertConstant( - const google::api::expr::v1alpha1::Constant* const_expr); +namespace google::api::expr::runtime { -// Factory method for Constant - based Execution step -absl::StatusOr> CreateConstValueStep( - CelValue value, int64_t expr_id, bool comes_from_ast = true); +// Factory method for Constant AST node expression recursive step. +inline std::unique_ptr CreateConstValueDirectStep( + cel::Value value, int64_t id = -1) { + return std::make_unique(std::move(value), id); +} -// Factory method for Constant(Enum value) - based Execution step -absl::StatusOr> CreateConstValueStep( - const google::protobuf::EnumValueDescriptor* value_descriptor, int64_t expr_id); +// Factory method for Constant AST node expression step. +inline absl::StatusOr> CreateConstValueStep( + cel::Value value, int64_t expr_id, bool comes_from_ast = true) { + return std::make_unique(std::move(value), expr_id, + comes_from_ast); +} -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_CONST_VALUE_STEP_H_ diff --git a/eval/eval/const_value_step_test.cc b/eval/eval/const_value_step_test.cc deleted file mode 100644 index 58d71f42d..000000000 --- a/eval/eval/const_value_step_test.cc +++ /dev/null @@ -1,168 +0,0 @@ -#include "eval/eval/const_value_step.h" - -#include "google/api/expr/v1alpha1/syntax.pb.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "absl/status/statusor.h" -#include "eval/eval/evaluator_core.h" -#include "base/status_macros.h" - -namespace google { -namespace api { -namespace expr { -namespace runtime { - -namespace { - -using testing::Eq; - -using google::api::expr::v1alpha1::Constant; -using google::api::expr::v1alpha1::Expr; - -using google::protobuf::Arena; - -absl::StatusOr RunConstantExpression(const Expr* expr, - const Constant* const_expr, - Arena* arena) { - auto step_status = - CreateConstValueStep(ConvertConstant(const_expr).value(), expr->id()); - if (!step_status.ok()) return step_status.status(); - - ExecutionPath path; - path.push_back(std::move(step_status.value())); - - google::api::expr::v1alpha1::Expr dummy_expr; - - CelExpressionFlatImpl impl(&dummy_expr, std::move(path), 0, {}); - - Activation activation; - - return impl.Evaluate(activation, arena); -} - -TEST(ConstValueStepTest, TestEvaluationConstInt64) { - Expr expr; - auto const_expr = expr.mutable_const_expr(); - const_expr->set_int64_value(1); - - google::protobuf::Arena arena; - - auto status = RunConstantExpression(&expr, const_expr, &arena); - - ASSERT_OK(status); - - auto value = status.value(); - - ASSERT_TRUE(value.IsInt64()); - EXPECT_THAT(value.Int64OrDie(), Eq(1)); -} - -TEST(ConstValueStepTest, TestEvaluationConstUint64) { - Expr expr; - auto const_expr = expr.mutable_const_expr(); - const_expr->set_uint64_value(1); - - google::protobuf::Arena arena; - - auto status = RunConstantExpression(&expr, const_expr, &arena); - - ASSERT_OK(status); - - auto value = status.value(); - - ASSERT_TRUE(value.IsUint64()); - EXPECT_THAT(value.Uint64OrDie(), Eq(1)); -} - -TEST(ConstValueStepTest, TestEvaluationConstBool) { - Expr expr; - auto const_expr = expr.mutable_const_expr(); - const_expr->set_bool_value(true); - - google::protobuf::Arena arena; - - auto status = RunConstantExpression(&expr, const_expr, &arena); - - ASSERT_OK(status); - - auto value = status.value(); - - ASSERT_TRUE(value.IsBool()); - EXPECT_THAT(value.BoolOrDie(), Eq(true)); -} - -TEST(ConstValueStepTest, TestEvaluationConstNull) { - Expr expr; - auto const_expr = expr.mutable_const_expr(); - const_expr->set_null_value(google::protobuf::NullValue(0)); - - google::protobuf::Arena arena; - - auto status = RunConstantExpression(&expr, const_expr, &arena); - - ASSERT_OK(status); - - auto value = status.value(); - - EXPECT_TRUE(value.IsNull()); -} - -TEST(ConstValueStepTest, TestEvaluationConstString) { - Expr expr; - auto const_expr = expr.mutable_const_expr(); - const_expr->set_string_value("test"); - - google::protobuf::Arena arena; - - auto status = RunConstantExpression(&expr, const_expr, &arena); - - ASSERT_OK(status); - - auto value = status.value(); - - ASSERT_TRUE(value.IsString()); - EXPECT_THAT(value.StringOrDie().value(), Eq("test")); -} - -TEST(ConstValueStepTest, TestEvaluationConstDouble) { - Expr expr; - auto const_expr = expr.mutable_const_expr(); - const_expr->set_double_value(1.0); - - google::protobuf::Arena arena; - - auto status = RunConstantExpression(&expr, const_expr, &arena); - - ASSERT_OK(status); - - auto value = status.value(); - - ASSERT_TRUE(value.IsDouble()); - EXPECT_THAT(value.DoubleOrDie(), testing::DoubleEq(1.0)); -} - -// Test Bytes constant -// For now, bytes are equivalent to string. -TEST(ConstValueStepTest, TestEvaluationConstBytes) { - Expr expr; - auto const_expr = expr.mutable_const_expr(); - const_expr->set_bytes_value("test"); - - google::protobuf::Arena arena; - - auto status = RunConstantExpression(&expr, const_expr, &arena); - - ASSERT_OK(status); - - auto value = status.value(); - - ASSERT_TRUE(value.IsBytes()); - EXPECT_THAT(value.BytesOrDie().value(), Eq("test")); -} - -} // namespace - -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google diff --git a/eval/eval/container_access_step.cc b/eval/eval/container_access_step.cc index aeb2499f9..4cf4ebf4d 100644 --- a/eval/eval/container_access_step.cc +++ b/eval/eval/container_access_step.cc @@ -1,167 +1,370 @@ #include "eval/eval/container_access_step.h" -#include "google/protobuf/arena.h" +#include +#include +#include + +#include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "base/attribute.h" +#include "common/casting.h" +#include "common/expr.h" +#include "common/kind.h" +#include "common/value.h" +#include "common/value_kind.h" +#include "eval/eval/attribute_trail.h" +#include "eval/eval/attribute_utility.h" +#include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" #include "eval/eval/expression_step_base.h" -#include "eval/public/cel_value.h" -#include "eval/public/unknown_attribute_set.h" +#include "eval/internal/errors.h" +#include "internal/number.h" +#include "internal/status_macros.h" +#include "runtime/internal/errors.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace google::api::expr::runtime { namespace { -constexpr int NUM_CONTAINER_ACCESS_ARGUMENTS = 2; +using ::cel::AttributeQualifier; +using ::cel::Cast; +using ::cel::ErrorValue; +using ::cel::InstanceOf; +using ::cel::IntValue; +using ::cel::ListValue; +using ::cel::MapValue; +using ::cel::UintValue; +using ::cel::Value; +using ::cel::ValueKind; +using ::cel::ValueKindToString; +using ::cel::internal::Number; +using ::cel::runtime_internal::CreateNoSuchKeyError; -// ContainerAccessStep performs message field access specified by Expr::Select -// message. -class ContainerAccessStep : public ExpressionStepBase { - public: - ContainerAccessStep(int64_t expr_id) : ExpressionStepBase(expr_id) {} +inline constexpr int kNumContainerAccessArguments = 2; - absl::Status Evaluate(ExecutionFrame* frame) const override; +absl::optional CelNumberFromValue(const Value& value) { + switch (value->kind()) { + case ValueKind::kInt64: + return Number::FromInt64(value.GetInt().NativeValue()); + case ValueKind::kUint64: + return Number::FromUint64(value.GetUint().NativeValue()); + case ValueKind::kDouble: + return Number::FromDouble(value.GetDouble().NativeValue()); + default: + return std::nullopt; + } +} - private: - using ValueAttributePair = std::pair; +absl::Status CheckMapKeyType(const Value& key) { + ValueKind kind = key->kind(); + switch (kind) { + case ValueKind::kString: + case ValueKind::kInt64: + case ValueKind::kUint64: + case ValueKind::kBool: + return absl::OkStatus(); + default: + return absl::InvalidArgumentError(absl::StrCat( + "Invalid map key type: '", ValueKindToString(kind), "'")); + } +} - ValueAttributePair PerformLookup(ExecutionFrame* frame) const; - CelValue LookupInMap(const CelMap* cel_map, const CelValue& key, - google::protobuf::Arena* arena) const; - CelValue LookupInList(const CelList* cel_list, const CelValue& key, - google::protobuf::Arena* arena) const; -}; +AttributeQualifier AttributeQualifierFromValue(const Value& v) { + switch (v->kind()) { + case ValueKind::kString: + return AttributeQualifier::OfString(v.GetString().ToString()); + case ValueKind::kInt64: + return AttributeQualifier::OfInt(v.GetInt().NativeValue()); + case ValueKind::kUint64: + return AttributeQualifier::OfUint(v.GetUint().NativeValue()); + case ValueKind::kBool: + return AttributeQualifier::OfBool(v.GetBool().NativeValue()); + default: + // Non-matching qualifier. + return AttributeQualifier(); + } +} -inline CelValue ContainerAccessStep::LookupInMap(const CelMap* cel_map, - const CelValue& key, - google::protobuf::Arena* arena) const { - switch (key.type()) { - case CelValue::Type::kBool: - case CelValue::Type::kInt64: - case CelValue::Type::kUint64: - case CelValue::Type::kString: { - absl::optional maybe_value = (*cel_map)[key]; - if (maybe_value.has_value()) { - return maybe_value.value(); +void LookupInMap(const MapValue& cel_map, const Value& key, + ExecutionFrameBase& frame, Value& result) { + if (frame.options().enable_heterogeneous_equality) { + // Double isn't a supported key type but may be convertible to an integer. + absl::optional number = CelNumberFromValue(key); + if (number.has_value()) { + // Consider uint as uint first then try coercion (prefer matching the + // original type of the key value). + if (key->Is()) { + auto lookup = + cel_map.Find(key, frame.descriptor_pool(), frame.message_factory(), + frame.arena(), &result); + if (!lookup.ok()) { + result = cel::ErrorValue(std::move(lookup).status()); + return; + } + if (*lookup) { + ABSL_DCHECK(!result.IsUnknown()); + return; + } } - break; - } - default: { - break; + // double / int / uint -> int + if (number->LosslessConvertibleToInt()) { + auto lookup = + cel_map.Find(IntValue(number->AsInt()), frame.descriptor_pool(), + frame.message_factory(), frame.arena(), &result); + if (!lookup.ok()) { + result = cel::ErrorValue(std::move(lookup).status()); + return; + } + if (*lookup) { + ABSL_DCHECK(!result.IsUnknown()); + return; + } + } + // double / int -> uint + if (number->LosslessConvertibleToUint()) { + auto lookup = + cel_map.Find(UintValue(number->AsUint()), frame.descriptor_pool(), + frame.message_factory(), frame.arena(), &result); + if (!lookup.ok()) { + result = cel::ErrorValue(std::move(lookup).status()); + return; + } + if (*lookup) { + ABSL_DCHECK(!result.IsUnknown()); + return; + } + } + result = cel::ErrorValue(CreateNoSuchKeyError(key->DebugString())); + return; } } - return CreateNoSuchKeyError(arena, absl::StrCat("Key not found in map")); + + absl::Status status = CheckMapKeyType(key); + if (!status.ok()) { + result = cel::ErrorValue(std::move(status)); + return; + } + + absl::Status lookup = + cel_map.Get(key, frame.descriptor_pool(), frame.message_factory(), + frame.arena(), &result); + if (!lookup.ok()) { + result = cel::ErrorValue(std::move(lookup)); + } + ABSL_DCHECK(!result.IsUnknown()); } -inline CelValue ContainerAccessStep::LookupInList(const CelList* cel_list, - const CelValue& key, - google::protobuf::Arena* arena) const { - switch (key.type()) { - case CelValue::Type::kInt64: { - int64_t idx = key.Int64OrDie(); - if (idx < 0 || idx >= cel_list->size()) { - return CreateErrorValue(arena, - absl::StrCat("Index error: index=", idx, - " size=", cel_list->size())); - } - return (*cel_list)[idx]; - } - default: { - return CreateErrorValue( - arena, absl::StrCat("Index error: expected integer type, got ", - CelValue::TypeName(key.type()))); +void LookupInList(const ListValue& cel_list, const Value& key, + ExecutionFrameBase& frame, Value& result) { + absl::optional maybe_idx; + if (frame.options().enable_heterogeneous_equality) { + auto number = CelNumberFromValue(key); + if (number.has_value() && number->LosslessConvertibleToInt()) { + maybe_idx = number->AsInt(); } + } else if (InstanceOf(key)) { + maybe_idx = key.GetInt().NativeValue(); } -} -ContainerAccessStep::ValueAttributePair ContainerAccessStep::PerformLookup( - ExecutionFrame* frame) const { - auto input_args = - frame->value_stack().GetSpan(NUM_CONTAINER_ACCESS_ARGUMENTS); - AttributeTrail trail; + if (!maybe_idx.has_value()) { + result = cel::ErrorValue(absl::UnknownError( + absl::StrCat("Index error: expected integer type, got ", + cel::KindToString(ValueKindToKind(key->kind()))))); + return; + } - const CelValue& container = input_args[0]; - const CelValue& key = input_args[1]; + int64_t idx = *maybe_idx; + auto size = cel_list.Size(); + if (!size.ok()) { + result = cel::ErrorValue(size.status()); + return; + } + if (idx < 0 || idx >= *size) { + result = cel::ErrorValue(absl::UnknownError( + absl::StrCat("Index error: index=", idx, " size=", *size))); + return; + } + + absl::Status lookup = + cel_list.Get(idx, frame.descriptor_pool(), frame.message_factory(), + frame.arena(), &result); - if (frame->enable_unknowns()) { - auto unknown_set = - frame->attribute_utility().MergeUnknowns(input_args, nullptr); + if (!lookup.ok()) { + result = cel::ErrorValue(std::move(lookup)); + } + ABSL_DCHECK(!result.IsUnknown()); +} - if (unknown_set) { - return {CelValue::CreateUnknownSet(unknown_set), trail}; +void LookupInContainer(const Value& container, const Value& key, + ExecutionFrameBase& frame, Value& result) { + // Select steps can be applied to either maps or messages + switch (container.kind()) { + case ValueKind::kMap: { + LookupInMap(Cast(container), key, frame, result); + return; + } + case ValueKind::kList: { + LookupInList(Cast(container), key, frame, result); + return; } + default: + result = cel::ErrorValue(absl::InvalidArgumentError( + absl::StrCat("Invalid container type: '", + ValueKindToString(container->kind()), "'"))); + return; + } +} - // We guarantee that GetAttributeSpan can aquire this number of arguments - // by calling HasEnough() at the beginning of Execute() method. - auto input_attrs = - frame->value_stack().GetAttributeSpan(NUM_CONTAINER_ACCESS_ARGUMENTS); - auto container_trail = input_attrs[0]; - trail = container_trail.Step(CelAttributeQualifier::Create(key), - frame->arena()); +void PerformLookup(ExecutionFrameBase& frame, const Value& container, + const Value& key, const AttributeTrail& container_trail, + bool enable_optional_types, Value& result, + AttributeTrail& trail) { + if (frame.unknown_processing_enabled()) { + AttributeUtility::Accumulator unknowns = + frame.attribute_utility().CreateAccumulator(); + unknowns.MaybeAdd(container); + unknowns.MaybeAdd(key); - if (frame->attribute_utility().CheckForUnknown(trail, - /*use_partial=*/false)) { - auto unknown_set = google::protobuf::Arena::Create( - frame->arena(), UnknownAttributeSet({trail.attribute()})); + if (!unknowns.IsEmpty()) { + result = std::move(unknowns).Build(); + return; + } + + trail = container_trail.Step(AttributeQualifierFromValue(key)); - return {CelValue::CreateUnknownSet(unknown_set), trail}; + if (frame.attribute_utility().CheckForUnknownExact(trail)) { + result = frame.attribute_utility().CreateUnknownSet(trail.attribute()); + return; } } - for (const auto& value : input_args) { - if (value.IsError()) { - return {value, trail}; - } + if (InstanceOf(container)) { + result = container; + return; + } + if (InstanceOf(key)) { + result = key; + return; } - // Select steps can be applied to either maps or messages - switch (container.type()) { - case CelValue::Type::kMap: { - const CelMap* cel_map = container.MapOrDie(); - return {LookupInMap(cel_map, key, frame->arena()), trail}; + if (enable_optional_types && container.IsOptional()) { + const auto& optional_value = container.GetOptional(); + if (!optional_value.HasValue()) { + result = cel::OptionalValue::None(); + return; } - case CelValue::Type::kList: { - const CelList* cel_list = container.ListOrDie(); - return {LookupInList(cel_list, key, frame->arena()), trail}; - } - default: { - auto error = CreateErrorValue( - frame->arena(), - absl::StrCat("Unexpected container type for [] operation: ", - CelValue::TypeName(key.type()))); - return {error, trail}; + Value value; + optional_value.Value(&value); + LookupInContainer(value, key, frame, result); + if (auto error_value = cel::As(result); + error_value && cel::IsNoSuchKey(*error_value)) { + result = cel::OptionalValue::None(); + return; } + result = cel::OptionalValue::Of(std::move(result), frame.arena()); + return; } + + LookupInContainer(container, key, frame, result); } +// ContainerAccessStep performs message field access specified by Expr::Select +// message. +class ContainerAccessStep : public ExpressionStepBase { + public: + ContainerAccessStep(int64_t expr_id, bool enable_optional_types) + : ExpressionStepBase(expr_id), + enable_optional_types_(enable_optional_types) {} + + absl::Status Evaluate(ExecutionFrame* frame) const override; + + private: + bool enable_optional_types_; +}; + absl::Status ContainerAccessStep::Evaluate(ExecutionFrame* frame) const { - if (!frame->value_stack().HasEnough(NUM_CONTAINER_ACCESS_ARGUMENTS)) { + if (!frame->value_stack().HasEnough(kNumContainerAccessArguments)) { return absl::Status( absl::StatusCode::kInternal, "Insufficient arguments supplied for ContainerAccess-type expression"); } - auto result = PerformLookup(frame); - frame->value_stack().Pop(NUM_CONTAINER_ACCESS_ARGUMENTS); - frame->value_stack().Push(result.first, result.second); + Value result; + AttributeTrail result_trail; + auto args = frame->value_stack().GetSpan(kNumContainerAccessArguments); + const AttributeTrail& container_trail = + frame->value_stack().GetAttributeSpan(kNumContainerAccessArguments)[0]; + + PerformLookup(*frame, args[0], args[1], container_trail, + enable_optional_types_, result, result_trail); + frame->value_stack().PopAndPush(kNumContainerAccessArguments, + std::move(result), std::move(result_trail)); return absl::OkStatus(); } + +class DirectContainerAccessStep : public DirectExpressionStep { + public: + DirectContainerAccessStep( + std::unique_ptr container_step, + std::unique_ptr key_step, + bool enable_optional_types, int64_t expr_id) + : DirectExpressionStep(expr_id), + container_step_(std::move(container_step)), + key_step_(std::move(key_step)), + enable_optional_types_(enable_optional_types) {} + + absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, + AttributeTrail& trail) const override; + + private: + std::unique_ptr container_step_; + std::unique_ptr key_step_; + bool enable_optional_types_; +}; + +absl::Status DirectContainerAccessStep::Evaluate(ExecutionFrameBase& frame, + Value& result, + AttributeTrail& trail) const { + Value container; + Value key; + AttributeTrail container_trail; + AttributeTrail key_trail; + + CEL_RETURN_IF_ERROR( + container_step_->Evaluate(frame, container, container_trail)); + CEL_RETURN_IF_ERROR(key_step_->Evaluate(frame, key, key_trail)); + + PerformLookup(frame, container, key, container_trail, enable_optional_types_, + result, trail); + + return absl::OkStatus(); +} + } // namespace +std::unique_ptr CreateDirectContainerAccessStep( + std::unique_ptr container_step, + std::unique_ptr key_step, bool enable_optional_types, + int64_t expr_id) { + return std::make_unique( + std::move(container_step), std::move(key_step), enable_optional_types, + expr_id); +} + // Factory method for Select - based Execution step absl::StatusOr> CreateContainerAccessStep( - const google::api::expr::v1alpha1::Expr::Call*, int64_t expr_id) { - std::unique_ptr step = - absl::make_unique(expr_id); - return std::move(step); + const cel::CallExpr& call, int64_t expr_id, bool enable_optional_types) { + int arg_count = call.args().size() + (call.has_target() ? 1 : 0); + if (arg_count != kNumContainerAccessArguments) { + return absl::InvalidArgumentError(absl::StrCat( + "Invalid argument count for index operation: ", arg_count)); + } + return std::make_unique(expr_id, enable_optional_types); } -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime diff --git a/eval/eval/container_access_step.h b/eval/eval/container_access_step.h index 72a285e5c..b7af5e895 100644 --- a/eval/eval/container_access_step.h +++ b/eval/eval/container_access_step.h @@ -1,22 +1,26 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_CONTAINER_ACCESS_STEP_H_ #define THIRD_PARTY_CEL_CPP_EVAL_EVAL_CONTAINER_ACCESS_STEP_H_ +#include +#include + +#include "absl/status/statusor.h" +#include "common/expr.h" +#include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" -#include "eval/public/activation.h" -#include "eval/public/cel_value.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace google::api::expr::runtime { + +std::unique_ptr CreateDirectContainerAccessStep( + std::unique_ptr container_step, + std::unique_ptr key_step, bool enable_optional_types, + int64_t expr_id); // Factory method for Select - based Execution step absl::StatusOr> CreateContainerAccessStep( - const google::api::expr::v1alpha1::Expr::Call* call, int64_t expr_id); + const cel::CallExpr& call, int64_t expr_id, + bool enable_optional_types = false); -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_CONTAINER_ACCESS_STEP_H_ diff --git a/eval/eval/container_access_step_test.cc b/eval/eval/container_access_step_test.cc index 6f73510da..25bf72223 100644 --- a/eval/eval/container_access_step_test.cc +++ b/eval/eval/container_access_step_test.cc @@ -1,104 +1,158 @@ #include "eval/eval/container_access_step.h" +#include #include +#include #include #include +#include "cel/expr/syntax.pb.h" #include "google/protobuf/struct.pb.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "base/builtins.h" +#include "base/type_provider.h" +#include "common/ast.h" +#include "common/expr.h" +#include "eval/eval/cel_expression_flat_impl.h" +#include "eval/eval/direct_expression_step.h" +#include "eval/eval/evaluator_core.h" #include "eval/eval/ident_step.h" +#include "eval/public/activation.h" #include "eval/public/cel_attribute.h" -#include "eval/public/cel_builtins.h" +#include "eval/public/cel_expr_builder_factory.h" +#include "eval/public/cel_expression.h" +#include "eval/public/cel_options.h" #include "eval/public/cel_value.h" #include "eval/public/containers/container_backed_list_impl.h" #include "eval/public/containers/container_backed_map_impl.h" #include "eval/public/structs/cel_proto_wrapper.h" -#include "base/status_macros.h" +#include "eval/public/testing/matchers.h" +#include "eval/public/unknown_set.h" +#include "internal/testing.h" +#include "parser/parser.h" +#include "runtime/internal/runtime_env.h" +#include "runtime/internal/runtime_env_testing.h" +#include "google/protobuf/arena.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace google::api::expr::runtime { namespace { +using ::absl_testing::StatusIs; +using ::cel::Expr; +using ::cel::SourceInfo; +using ::cel::TypeProvider; +using ::cel::runtime_internal::NewTestingRuntimeEnv; +using ::cel::runtime_internal::RuntimeEnv; +using ::cel::expr::ParsedExpr; using ::google::protobuf::Struct; +using ::testing::_; +using ::testing::AllOf; +using ::testing::HasSubstr; -using google::api::expr::v1alpha1::Expr; -using google::api::expr::v1alpha1::SourceInfo; +using TestParamType = std::tuple; -using TestParamType = std::tuple; - -// Helper method. Looks up in registry and tests comparison operation. CelValue EvaluateAttributeHelper( - google::protobuf::Arena* arena, CelValue container, CelValue key, bool receiver_style, - bool enable_unknown, const std::vector& patterns) { + const absl_nonnull std::shared_ptr& env, + google::protobuf::Arena* arena, CelValue container, CelValue key, + bool use_recursive_impl, bool receiver_style, bool enable_unknown, + const std::vector& patterns) { ExecutionPath path; Expr expr; SourceInfo source_info; - auto call = expr.mutable_call_expr(); - - call->set_function(builtin::kIndex); - - Expr* container_expr = - (receiver_style) ? call->mutable_target() : call->add_args(); - Expr* key_expr = call->add_args(); - - container_expr->mutable_ident_expr()->set_name("container"); - key_expr->mutable_ident_expr()->set_name("key"); - - path.push_back( - std::move(CreateIdentStep(&container_expr->ident_expr(), 1).value())); - path.push_back( - std::move(CreateIdentStep(&key_expr->ident_expr(), 2).value())); - path.push_back(std::move(CreateContainerAccessStep(call, 3).value())); + auto& call = expr.mutable_call_expr(); + + call.set_function(cel::builtin::kIndex); + + call.mutable_args().reserve(2); + Expr& container_expr = (receiver_style) ? call.mutable_target() + : call.mutable_args().emplace_back(); + Expr& key_expr = call.mutable_args().emplace_back(); + + container_expr.mutable_ident_expr().set_name("container"); + key_expr.mutable_ident_expr().set_name("key"); + + if (use_recursive_impl) { + path.push_back(std::make_unique( + CreateDirectContainerAccessStep(CreateDirectIdentStep("container", 1), + CreateDirectIdentStep("key", 2), + /*enable_optional_types=*/false, 3), + 3)); + } else { + path.push_back(std::move(CreateIdentStep("container", 1).value())); + path.push_back(std::move(CreateIdentStep("key", 2).value())); + path.push_back(std::move(CreateContainerAccessStep(call, 3).value())); + } - CelExpressionFlatImpl cel_expr(&expr, std::move(path), 0, {}, enable_unknown); + cel::RuntimeOptions options; + options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; + options.enable_heterogeneous_equality = false; + CelExpressionFlatImpl cel_expr( + env, + FlatExpression(std::move(path), /*comprehension_slot_count=*/0, + env->type_registry.GetComposedTypeProvider(), options)); Activation activation; activation.InsertValue("container", container); activation.InsertValue("key", key); activation.set_unknown_attribute_patterns(patterns); - auto eval_status = cel_expr.Evaluate(activation, arena); - - EXPECT_OK(eval_status); - return eval_status.value(); + auto result = cel_expr.Evaluate(activation, arena); + return *result; } class ContainerAccessStepTest : public ::testing::Test { protected: - ContainerAccessStepTest() {} + ContainerAccessStepTest() = default; - void SetUp() override {} + void SetUp() override { env_ = NewTestingRuntimeEnv(); } CelValue EvaluateAttribute( CelValue container, CelValue key, bool receiver_style, - bool enable_unknown, + bool enable_unknown, bool use_recursive_impl = false, const std::vector& patterns = {}) { - return EvaluateAttributeHelper(&arena_, container, key, receiver_style, - enable_unknown, patterns); + return EvaluateAttributeHelper(env_, &arena_, container, key, + receiver_style, enable_unknown, + use_recursive_impl, patterns); } + absl_nonnull std::shared_ptr env_; google::protobuf::Arena arena_; }; class ContainerAccessStepUniformityTest : public ::testing::TestWithParam { protected: - ContainerAccessStepUniformityTest() {} + ContainerAccessStepUniformityTest() = default; + + void SetUp() override { env_ = NewTestingRuntimeEnv(); } + + bool receiver_style() { + TestParamType params = GetParam(); + return std::get<0>(params); + } - void SetUp() override {} + bool enable_unknown() { + TestParamType params = GetParam(); + return std::get<1>(params); + } + + bool use_recursive_impl() { + TestParamType params = GetParam(); + return std::get<2>(params); + } // Helper method. Looks up in registry and tests comparison operation. CelValue EvaluateAttribute( CelValue container, CelValue key, bool receiver_style, - bool enable_unknown, + bool enable_unknown, bool use_recursive_impl = false, const std::vector& patterns = {}) { - return EvaluateAttributeHelper(&arena_, container, key, receiver_style, - enable_unknown, patterns); + return EvaluateAttributeHelper(env_, &arena_, container, key, + receiver_style, enable_unknown, + use_recursive_impl, patterns); } + absl_nonnull std::shared_ptr env_; google::protobuf::Arena arena_; }; @@ -107,10 +161,9 @@ TEST_P(ContainerAccessStepUniformityTest, TestListIndexAccess) { CelValue::CreateInt64(2), CelValue::CreateInt64(3)}); - TestParamType param = GetParam(); CelValue result = EvaluateAttribute(CelValue::CreateList(&cel_list), CelValue::CreateInt64(1), - std::get<0>(param), std::get<1>(param)); + receiver_style(), enable_unknown()); ASSERT_TRUE(result.IsInt64()); ASSERT_EQ(result.Int64OrDie(), 2); @@ -121,26 +174,24 @@ TEST_P(ContainerAccessStepUniformityTest, TestListIndexAccessOutOfBounds) { CelValue::CreateInt64(2), CelValue::CreateInt64(3)}); - TestParamType param = GetParam(); - CelValue result = EvaluateAttribute(CelValue::CreateList(&cel_list), CelValue::CreateInt64(0), - std::get<0>(param), std::get<1>(param)); + receiver_style(), enable_unknown()); ASSERT_TRUE(result.IsInt64()); result = EvaluateAttribute(CelValue::CreateList(&cel_list), - CelValue::CreateInt64(2), std::get<0>(param), - std::get<1>(param)); + CelValue::CreateInt64(2), receiver_style(), + enable_unknown()); ASSERT_TRUE(result.IsInt64()); result = EvaluateAttribute(CelValue::CreateList(&cel_list), - CelValue::CreateInt64(-1), std::get<0>(param), - std::get<1>(param)); + CelValue::CreateInt64(-1), receiver_style(), + enable_unknown()); ASSERT_TRUE(result.IsError()); result = EvaluateAttribute(CelValue::CreateList(&cel_list), - CelValue::CreateInt64(3), std::get<0>(param), - std::get<1>(param)); + CelValue::CreateInt64(3), receiver_style(), + enable_unknown()); ASSERT_TRUE(result.IsError()); } @@ -150,18 +201,14 @@ TEST_P(ContainerAccessStepUniformityTest, TestListIndexAccessNotAnInt) { CelValue::CreateInt64(2), CelValue::CreateInt64(3)}); - TestParamType param = GetParam(); - CelValue result = EvaluateAttribute(CelValue::CreateList(&cel_list), CelValue::CreateUint64(1), - std::get<0>(param), std::get<1>(param)); + receiver_style(), enable_unknown()); ASSERT_TRUE(result.IsError()); } TEST_P(ContainerAccessStepUniformityTest, TestMapKeyAccess) { - TestParamType param = GetParam(); - const std::string kKey0 = "testkey0"; const std::string kKey1 = "testkey1"; const std::string kKey2 = "testkey2"; @@ -172,15 +219,25 @@ TEST_P(ContainerAccessStepUniformityTest, TestMapKeyAccess) { CelValue result = EvaluateAttribute( CelProtoWrapper::CreateMessage(&cel_struct, &arena_), - CelValue::CreateString(&kKey0), std::get<0>(param), std::get<1>(param)); + CelValue::CreateString(&kKey0), receiver_style(), enable_unknown()); ASSERT_TRUE(result.IsString()); ASSERT_EQ(result.StringOrDie().value(), "value0"); } -TEST_P(ContainerAccessStepUniformityTest, TestMapKeyAccessNotFound) { - TestParamType param = GetParam(); +TEST_P(ContainerAccessStepUniformityTest, TestBoolKeyType) { + CelMapBuilder cel_map; + ASSERT_OK(cel_map.Add(CelValue::CreateBool(true), + CelValue::CreateStringView("value_true"))); + + CelValue result = EvaluateAttribute(CelValue::CreateMap(&cel_map), + CelValue::CreateBool(true), + receiver_style(), enable_unknown()); + ASSERT_THAT(result, test::IsCelString("value_true")); +} + +TEST_P(ContainerAccessStepUniformityTest, TestMapKeyAccessNotFound) { const std::string kKey0 = "testkey0"; const std::string kKey1 = "testkey1"; Struct cel_struct; @@ -188,9 +245,49 @@ TEST_P(ContainerAccessStepUniformityTest, TestMapKeyAccessNotFound) { CelValue result = EvaluateAttribute( CelProtoWrapper::CreateMessage(&cel_struct, &arena_), - CelValue::CreateString(&kKey1), std::get<0>(param), std::get<1>(param)); + CelValue::CreateString(&kKey1), receiver_style(), enable_unknown()); ASSERT_TRUE(result.IsError()); + EXPECT_THAT(*result.ErrorOrDie(), + StatusIs(absl::StatusCode::kNotFound, + AllOf(HasSubstr("Key not found in map : "), + HasSubstr("testkey1")))); +} + +TEST_F(ContainerAccessStepTest, TestInvalidReceiverCreateContainerAccessStep) { + Expr expr; + auto& call = expr.mutable_call_expr(); + call.set_function(cel::builtin::kIndex); + Expr& container_expr = call.mutable_target(); + container_expr.mutable_ident_expr().set_name("container"); + + call.mutable_args().reserve(2); + Expr& key_expr = call.mutable_args().emplace_back(); + key_expr.mutable_ident_expr().set_name("key"); + + Expr& extra_arg = call.mutable_args().emplace_back(); + extra_arg.mutable_const_expr().set_bool_value(true); + EXPECT_THAT(CreateContainerAccessStep(call, 0).status(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid argument count"))); +} + +TEST_F(ContainerAccessStepTest, TestInvalidGlobalCreateContainerAccessStep) { + Expr expr; + auto& call = expr.mutable_call_expr(); + call.set_function(cel::builtin::kIndex); + call.mutable_args().reserve(3); + Expr& container_expr = call.mutable_args().emplace_back(); + container_expr.mutable_ident_expr().set_name("container"); + + Expr& key_expr = call.mutable_args().emplace_back(); + key_expr.mutable_ident_expr().set_name("key"); + + Expr& extra_arg = call.mutable_args().emplace_back(); + extra_arg.mutable_const_expr().set_bool_value(true); + EXPECT_THAT(CreateContainerAccessStep(call, 0).status(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid argument count"))); } TEST_F(ContainerAccessStepTest, TestListIndexAccessUnknown) { @@ -206,10 +303,11 @@ TEST_F(ContainerAccessStepTest, TestListIndexAccessUnknown) { std::vector patterns = {CelAttributePattern( "container", - {CelAttributeQualifierPattern::Create(CelValue::CreateInt64(1))})}; + {CreateCelAttributeQualifierPattern(CelValue::CreateInt64(1))})}; - result = EvaluateAttribute(CelValue::CreateList(&cel_list), - CelValue::CreateInt64(1), true, true, patterns); + result = + EvaluateAttribute(CelValue::CreateList(&cel_list), + CelValue::CreateInt64(1), true, true, false, patterns); ASSERT_TRUE(result.IsUnknownSet()); } @@ -227,6 +325,25 @@ TEST_F(ContainerAccessStepTest, TestListUnknownKey) { ASSERT_TRUE(result.IsUnknownSet()); } +TEST_F(ContainerAccessStepTest, TestMapInvalidKey) { + const std::string kKey0 = "testkey0"; + const std::string kKey1 = "testkey1"; + const std::string kKey2 = "testkey2"; + Struct cel_struct; + (*cel_struct.mutable_fields())[kKey0].set_string_value("value0"); + (*cel_struct.mutable_fields())[kKey1].set_string_value("value1"); + (*cel_struct.mutable_fields())[kKey2].set_string_value("value2"); + + CelValue result = + EvaluateAttribute(CelProtoWrapper::CreateMessage(&cel_struct, &arena_), + CelValue::CreateDouble(1.0), true, true); + + ASSERT_TRUE(result.IsError()); + EXPECT_THAT(*result.ErrorOrDie(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid map key type: 'double'"))); +} + TEST_F(ContainerAccessStepTest, TestMapUnknownKey) { const std::string kKey0 = "testkey0"; const std::string kKey1 = "testkey1"; @@ -252,13 +369,280 @@ TEST_F(ContainerAccessStepTest, TestUnknownContainer) { ASSERT_TRUE(result.IsUnknownSet()); } -INSTANTIATE_TEST_SUITE_P(CombinedContainerTest, - ContainerAccessStepUniformityTest, - testing::Combine(testing::Bool(), testing::Bool())); +TEST_F(ContainerAccessStepTest, TestInvalidContainerType) { + CelValue result = EvaluateAttribute(CelValue::CreateInt64(1), + CelValue::CreateInt64(0), true, true); + + ASSERT_TRUE(result.IsError()); + EXPECT_THAT(*result.ErrorOrDie(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid container type: 'int"))); +} + +INSTANTIATE_TEST_SUITE_P( + CombinedContainerTest, ContainerAccessStepUniformityTest, + testing::Combine(/*receiver_style*/ testing::Bool(), + /*unknown_enabled*/ testing::Bool(), + /*use_recursive_impl*/ testing::Bool())); + +class ContainerAccessHeterogeneousLookupsTest : public testing::Test { + public: + ContainerAccessHeterogeneousLookupsTest() { + options_.enable_heterogeneous_equality = true; + builder_ = CreateCelExpressionBuilder(options_); + } + + protected: + InterpreterOptions options_; + std::unique_ptr builder_; + google::protobuf::Arena arena_; + Activation activation_; +}; + +TEST_F(ContainerAccessHeterogeneousLookupsTest, DoubleMapKeyInt) { + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("{1: 2}[1.0]")); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression( + &expr.expr(), &expr.source_info())); + + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation_, &arena_)); + + EXPECT_THAT(result, test::IsCelInt64(2)); +} + +TEST_F(ContainerAccessHeterogeneousLookupsTest, DoubleMapKeyNotAnInt) { + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("{1: 2}[1.1]")); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression( + &expr.expr(), &expr.source_info())); + + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation_, &arena_)); + + EXPECT_THAT(result, test::IsCelError(_)); +} + +TEST_F(ContainerAccessHeterogeneousLookupsTest, DoubleMapKeyUint) { + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("{1u: 2u}[1.0]")); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression( + &expr.expr(), &expr.source_info())); + + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation_, &arena_)); + + EXPECT_THAT(result, test::IsCelUint64(2)); +} + +TEST_F(ContainerAccessHeterogeneousLookupsTest, DoubleListIndex) { + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("[1, 2, 3][1.0]")); + + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression( + &expr.expr(), &expr.source_info())); + + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation_, &arena_)); + + EXPECT_THAT(result, test::IsCelInt64(2)); +} + +TEST_F(ContainerAccessHeterogeneousLookupsTest, DoubleListIndexNotAnInt) { + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("[1, 2, 3][1.1]")); + + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression( + &expr.expr(), &expr.source_info())); + + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation_, &arena_)); + + EXPECT_THAT(result, test::IsCelError(_)); +} + +// treat uint as uint before trying coercion to signed int. +TEST_F(ContainerAccessHeterogeneousLookupsTest, UintKeyAsUint) { + // TODO(uncreated-issue/4): Map creation should error here instead of permitting + // mixed key types with equivalent values. + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("{1u: 2u, 1: 2}[1u]")); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression( + &expr.expr(), &expr.source_info())); + + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation_, &arena_)); + + EXPECT_THAT(result, test::IsCelUint64(2)); +} + +TEST_F(ContainerAccessHeterogeneousLookupsTest, UintKeyAsInt) { + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("{1: 2}[1u]")); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression( + &expr.expr(), &expr.source_info())); + + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation_, &arena_)); + + EXPECT_THAT(result, test::IsCelInt64(2)); +} + +TEST_F(ContainerAccessHeterogeneousLookupsTest, IntKeyAsUint) { + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("{1u: 2u}[1]")); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression( + &expr.expr(), &expr.source_info())); + + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation_, &arena_)); + + EXPECT_THAT(result, test::IsCelUint64(2)); +} + +TEST_F(ContainerAccessHeterogeneousLookupsTest, UintListIndex) { + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("[1, 2, 3][2u]")); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression( + &expr.expr(), &expr.source_info())); + + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation_, &arena_)); + + EXPECT_THAT(result, test::IsCelInt64(3)); +} + +TEST_F(ContainerAccessHeterogeneousLookupsTest, StringKeyUnaffected) { + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("{1: 2, '1': 3}['1']")); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression( + &expr.expr(), &expr.source_info())); + + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation_, &arena_)); + + EXPECT_THAT(result, test::IsCelInt64(3)); +} + +class ContainerAccessHeterogeneousLookupsDisabledTest : public testing::Test { + public: + ContainerAccessHeterogeneousLookupsDisabledTest() { + options_.enable_heterogeneous_equality = false; + builder_ = CreateCelExpressionBuilder(options_); + } + + protected: + InterpreterOptions options_; + std::unique_ptr builder_; + google::protobuf::Arena arena_; + Activation activation_; +}; + +TEST_F(ContainerAccessHeterogeneousLookupsDisabledTest, DoubleMapKeyInt) { + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("{1: 2}[1.0]")); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression( + &expr.expr(), &expr.source_info())); + + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation_, &arena_)); + + EXPECT_THAT(result, test::IsCelError(_)); +} + +TEST_F(ContainerAccessHeterogeneousLookupsDisabledTest, DoubleMapKeyNotAnInt) { + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("{1: 2}[1.1]")); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression( + &expr.expr(), &expr.source_info())); + + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation_, &arena_)); + + EXPECT_THAT(result, test::IsCelError(_)); +} + +TEST_F(ContainerAccessHeterogeneousLookupsDisabledTest, DoubleMapKeyUint) { + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("{1u: 2u}[1.0]")); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression( + &expr.expr(), &expr.source_info())); + + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation_, &arena_)); + + EXPECT_THAT(result, test::IsCelError(_)); +} + +TEST_F(ContainerAccessHeterogeneousLookupsDisabledTest, DoubleListIndex) { + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("[1, 2, 3][1.0]")); + + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression( + &expr.expr(), &expr.source_info())); + + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation_, &arena_)); + + EXPECT_THAT(result, test::IsCelError(_)); +} + +TEST_F(ContainerAccessHeterogeneousLookupsDisabledTest, + DoubleListIndexNotAnInt) { + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("[1, 2, 3][1.1]")); + + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression( + &expr.expr(), &expr.source_info())); + + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation_, &arena_)); + + EXPECT_THAT(result, test::IsCelError(_)); +} + +TEST_F(ContainerAccessHeterogeneousLookupsDisabledTest, UintKeyAsUint) { + // TODO(uncreated-issue/4): Map creation should error here instead of permitting + // mixed key types with equivalent values. + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("{1u: 2u, 1: 2}[1u]")); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression( + &expr.expr(), &expr.source_info())); + + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation_, &arena_)); + + EXPECT_THAT(result, test::IsCelUint64(2)); +} + +TEST_F(ContainerAccessHeterogeneousLookupsDisabledTest, UintKeyAsInt) { + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("{1: 2}[1u]")); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression( + &expr.expr(), &expr.source_info())); + + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation_, &arena_)); + + EXPECT_THAT(result, test::IsCelError(_)); +} + +TEST_F(ContainerAccessHeterogeneousLookupsDisabledTest, IntKeyAsUint) { + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("{1u: 2u}[1]")); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression( + &expr.expr(), &expr.source_info())); + + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation_, &arena_)); + + EXPECT_THAT(result, test::IsCelError(_)); +} + +TEST_F(ContainerAccessHeterogeneousLookupsDisabledTest, UintListIndex) { + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("[1, 2, 3][2u]")); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression( + &expr.expr(), &expr.source_info())); + + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation_, &arena_)); + + EXPECT_THAT(result, test::IsCelError(_)); +} + +TEST_F(ContainerAccessHeterogeneousLookupsDisabledTest, StringKeyUnaffected) { + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("{1: 2, '1': 3}['1']")); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression( + &expr.expr(), &expr.source_info())); + + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation_, &arena_)); + + EXPECT_THAT(result, test::IsCelInt64(3)); +} } // namespace -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime diff --git a/eval/eval/create_list_step.cc b/eval/eval/create_list_step.cc index f9da18357..bb977ce94 100644 --- a/eval/eval/create_list_step.cc +++ b/eval/eval/create_list_step.cc @@ -1,24 +1,53 @@ #include "eval/eval/create_list_step.h" +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" -#include "eval/public/containers/container_backed_list_impl.h" +#include "absl/types/optional.h" +#include "common/casting.h" +#include "common/expr.h" +#include "common/value.h" +#include "common/values/list_value_builder.h" +#include "eval/eval/attribute_trail.h" +#include "eval/eval/attribute_utility.h" +#include "eval/eval/direct_expression_step.h" +#include "eval/eval/evaluator_core.h" +#include "eval/eval/expression_step_base.h" +#include "internal/status_macros.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace google::api::expr::runtime { namespace { +using ::cel::Cast; +using ::cel::ErrorValue; +using ::cel::InstanceOf; +using ::cel::ListValueBuilderPtr; +using ::cel::UnknownValue; +using ::cel::Value; +using ::cel::common_internal::NewListValueBuilder; + class CreateListStep : public ExpressionStepBase { public: - CreateListStep(int64_t expr_id, int list_size) - : ExpressionStepBase(expr_id), list_size_(list_size) {} + CreateListStep(int64_t expr_id, int list_size, + absl::flat_hash_set optional_indices) + : ExpressionStepBase(expr_id), + list_size_(list_size), + optional_indices_(std::move(optional_indices)) {} absl::Status Evaluate(ExecutionFrame* frame) const override; private: + absl::Status DoEvaluate(ExecutionFrame* frame, Value* result) const; + int list_size_; + absl::flat_hash_set optional_indices_; }; absl::Status CreateListStep::Evaluate(ExecutionFrame* frame) const { @@ -32,45 +61,223 @@ absl::Status CreateListStep::Evaluate(ExecutionFrame* frame) const { "CreateListStep: stack underflow"); } + Value result; + CEL_RETURN_IF_ERROR(DoEvaluate(frame, &result)); + + frame->value_stack().PopAndPush(list_size_, std::move(result)); + return absl::OkStatus(); +} + +absl::Status CreateListStep::DoEvaluate(ExecutionFrame* frame, + Value* result) const { auto args = frame->value_stack().GetSpan(list_size_); - CelValue result; + for (const auto& arg : args) { + if (arg.IsError()) { + *result = arg; + return absl::OkStatus(); + } + } - const UnknownSet* unknown_set = nullptr; if (frame->enable_unknowns()) { - unknown_set = frame->attribute_utility().MergeUnknowns( - args, frame->value_stack().GetAttributeSpan(list_size_), - /*initial_set=*/nullptr, - /*use_partial=*/true); - if (unknown_set != nullptr) { - result = CelValue::CreateUnknownSet(unknown_set); + absl::optional unknown_set = + frame->attribute_utility().IdentifyAndMergeUnknowns( + args, frame->value_stack().GetAttributeSpan(list_size_), + /*use_partial=*/true); + if (unknown_set.has_value()) { + *result = std::move(*unknown_set); + return absl::OkStatus(); } } - if (unknown_set == nullptr) { - CelList* cel_list = google::protobuf::Arena::Create( - frame->arena(), std::vector(args.begin(), args.end())); - result = CelValue::CreateList(cel_list); + ListValueBuilderPtr builder = NewListValueBuilder(frame->arena()); + builder->Reserve(args.size()); + + for (size_t i = 0; i < args.size(); ++i) { + const auto& arg = args[i]; + if (optional_indices_.contains(static_cast(i))) { + if (auto optional_arg = arg.AsOptional(); optional_arg) { + if (!optional_arg->HasValue()) { + continue; + } + Value optional_arg_value; + optional_arg->Value(&optional_arg_value); + if (optional_arg_value.IsError()) { + // Error should never be in optional, but better safe than sorry. + *result = std::move(optional_arg_value); + return absl::OkStatus(); + } + CEL_RETURN_IF_ERROR(builder->Add(std::move(optional_arg_value))); + } else { + *result = cel::TypeConversionError(arg.GetTypeName(), "optional_type"); + return absl::OkStatus(); + } + } else { + CEL_RETURN_IF_ERROR(builder->Add(arg)); + } } - frame->value_stack().Pop(list_size_); - frame->value_stack().Push(result); + *result = std::move(*builder).Build(); + return absl::OkStatus(); +} +absl::flat_hash_set MakeOptionalIndicesSet( + const cel::ListExpr& create_list_expr) { + absl::flat_hash_set optional_indices; + for (size_t i = 0; i < create_list_expr.elements().size(); ++i) { + if (create_list_expr.elements()[i].optional()) { + optional_indices.insert(static_cast(i)); + } + } + return optional_indices; +} + +class CreateListDirectStep : public DirectExpressionStep { + public: + CreateListDirectStep( + std::vector> elements, + absl::flat_hash_set optional_indices, int64_t expr_id) + : DirectExpressionStep(expr_id), + elements_(std::move(elements)), + optional_indices_(std::move(optional_indices)) {} + + absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, + AttributeTrail& attribute_trail) const override { + ListValueBuilderPtr builder = NewListValueBuilder(frame.arena()); + builder->Reserve(elements_.size()); + + AttributeUtility::Accumulator unknowns = + frame.attribute_utility().CreateAccumulator(); + AttributeTrail tmp_attr; + + for (size_t i = 0; i < elements_.size(); ++i) { + const auto& element = elements_[i]; + CEL_RETURN_IF_ERROR(element->Evaluate(frame, result, tmp_attr)); + + if (result.IsError()) { + return absl::OkStatus(); + } + + if (frame.attribute_tracking_enabled()) { + if (frame.missing_attribute_errors_enabled()) { + if (frame.attribute_utility().CheckForMissingAttribute(tmp_attr)) { + CEL_ASSIGN_OR_RETURN( + result, frame.attribute_utility().CreateMissingAttributeError( + tmp_attr.attribute())); + return absl::OkStatus(); + } + } + if (frame.unknown_processing_enabled()) { + if (result.IsUnknown()) { + unknowns.Add(result.GetUnknown()); + } + if (frame.attribute_utility().CheckForUnknown(tmp_attr, + /*use_partial=*/true)) { + unknowns.Add(tmp_attr); + } + } + } + + if (!unknowns.IsEmpty()) { + // We found an unknown, there is no point in attempting to create a + // list. Instead iterate through the remaining elements and look for + // more unknowns. + continue; + } + + // Conditionally add if optional. + if (optional_indices_.contains(static_cast(i))) { + if (auto optional_arg = result.AsOptional(); optional_arg) { + if (!optional_arg->HasValue()) { + continue; + } + Value optional_arg_value; + optional_arg->Value(&optional_arg_value); + if (optional_arg_value.IsError()) { + // Error should never be in optional, but better safe than sorry. + result = std::move(optional_arg_value); + return absl::OkStatus(); + } + CEL_RETURN_IF_ERROR(builder->Add(std::move(optional_arg_value))); + continue; + } + result = + cel::TypeConversionError(result.GetTypeName(), "optional_type"); + return absl::OkStatus(); + } + + // Otherwise just add. + CEL_RETURN_IF_ERROR(builder->Add(std::move(result))); + } + + if (!unknowns.IsEmpty()) { + result = std::move(unknowns).Build(); + return absl::OkStatus(); + } + result = std::move(*builder).Build(); + + return absl::OkStatus(); + } + + private: + std::vector> elements_; + absl::flat_hash_set optional_indices_; +}; + +class MutableListStep : public ExpressionStepBase { + public: + explicit MutableListStep(int64_t expr_id) : ExpressionStepBase(expr_id) {} + + absl::Status Evaluate(ExecutionFrame* frame) const override; +}; + +absl::Status MutableListStep::Evaluate(ExecutionFrame* frame) const { + frame->value_stack().Push(cel::CustomListValue( + cel::common_internal::NewMutableListValue(frame->arena()), + frame->arena())); + return absl::OkStatus(); +} + +class DirectMutableListStep : public DirectExpressionStep { + public: + explicit DirectMutableListStep(int64_t expr_id) + : DirectExpressionStep(expr_id) {} + + absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, + AttributeTrail& attribute) const override; +}; + +absl::Status DirectMutableListStep::Evaluate( + ExecutionFrameBase& frame, Value& result, + AttributeTrail& attribute_trail) const { + result = cel::CustomListValue( + cel::common_internal::NewMutableListValue(frame.arena()), frame.arena()); return absl::OkStatus(); } } // namespace -// Factory method for CreateList - based Execution step +std::unique_ptr CreateDirectListStep( + std::vector> deps, + absl::flat_hash_set optional_indices, int64_t expr_id) { + return std::make_unique( + std::move(deps), std::move(optional_indices), expr_id); +} + absl::StatusOr> CreateCreateListStep( - const google::api::expr::v1alpha1::Expr::CreateList* create_list_expr, + const cel::ListExpr& create_list_expr, int64_t expr_id) { + return std::make_unique( + expr_id, create_list_expr.elements().size(), + MakeOptionalIndicesSet(create_list_expr)); +} + +std::unique_ptr CreateMutableListStep(int64_t expr_id) { + return std::make_unique(expr_id); +} + +std::unique_ptr CreateDirectMutableListStep( int64_t expr_id) { - std::unique_ptr step = absl::make_unique( - expr_id, create_list_expr->elements_size()); - return std::move(step); + return std::make_unique(expr_id); } -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime diff --git a/eval/eval/create_list_step.h b/eval/eval/create_list_step.h index 11bd38eb3..b60a5e9c8 100644 --- a/eval/eval/create_list_step.h +++ b/eval/eval/create_list_step.h @@ -1,23 +1,40 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_CREATE_LIST_STEP_H_ #define THIRD_PARTY_CEL_CPP_EVAL_EVAL_CREATE_LIST_STEP_H_ +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "common/expr.h" +#include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" -#include "eval/eval/expression_step_base.h" -#include "absl/types/span.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace google::api::expr::runtime { + +// Factory method for CreateList that evaluates recursively. +std::unique_ptr CreateDirectListStep( + std::vector> deps, + absl::flat_hash_set optional_indices, int64_t expr_id); -// Factory method for CreateList - based Execution step +// Factory method for CreateList which constructs an immutable list. absl::StatusOr> CreateCreateListStep( - const google::api::expr::v1alpha1::Expr::CreateList* create_list_expr, + const cel::ListExpr& create_list_expr, int64_t expr_id); + +// Factory method for CreateList which constructs a mutable list. +// +// This is intended for the list construction step is generated for a +// list-building comprehension (rather than a user authored expression). +std::unique_ptr CreateMutableListStep(int64_t expr_id); + +// Factory method for CreateList which constructs a mutable list. +// +// This is intended for the list construction step is generated for a +// list-building comprehension (rather than a user authored expression). +std::unique_ptr CreateDirectMutableListStep( int64_t expr_id); -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_CREATE_LIST_STEP_H_ diff --git a/eval/eval/create_list_step_test.cc b/eval/eval/create_list_step_test.cc index 7cdd569b9..990003823 100644 --- a/eval/eval/create_list_step_test.cc +++ b/eval/eval/create_list_step_test.cc @@ -1,57 +1,105 @@ #include "eval/eval/create_list_step.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" +#include +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "base/attribute.h" +#include "base/attribute_set.h" +#include "base/type_provider.h" +#include "common/casting.h" +#include "common/expr.h" +#include "common/value.h" +#include "common/value_testing.h" +#include "eval/eval/attribute_trail.h" +#include "eval/eval/cel_expression_flat_impl.h" #include "eval/eval/const_value_step.h" +#include "eval/eval/direct_expression_step.h" +#include "eval/eval/evaluator_core.h" #include "eval/eval/ident_step.h" +#include "eval/internal/interop.h" #include "eval/public/activation.h" #include "eval/public/cel_attribute.h" +#include "eval/public/cel_value.h" +#include "eval/public/testing/matchers.h" #include "eval/public/unknown_attribute_set.h" -#include "base/status_macros.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "internal/testing_message_factory.h" +#include "runtime/activation.h" +#include "runtime/internal/runtime_env.h" +#include "runtime/internal/runtime_env_testing.h" +#include "runtime/internal/runtime_type_provider.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" + +namespace google::api::expr::runtime { -namespace google { -namespace api { -namespace expr { -namespace runtime { namespace { -using testing::Eq; -using testing::Not; - -using google::api::expr::v1alpha1::Expr; +using ::absl_testing::IsOk; +using ::absl_testing::IsOkAndHolds; +using ::absl_testing::StatusIs; +using ::cel::Attribute; +using ::cel::AttributeQualifier; +using ::cel::AttributeSet; +using ::cel::Cast; +using ::cel::ErrorValue; +using ::cel::Expr; +using ::cel::InstanceOf; +using ::cel::IntValue; +using ::cel::ListValue; +using ::cel::TypeProvider; +using ::cel::UnknownValue; +using ::cel::Value; +using ::cel::runtime_internal::NewTestingRuntimeEnv; +using ::cel::runtime_internal::RuntimeEnv; +using ::cel::test::IntValueIs; +using ::testing::Eq; +using ::testing::HasSubstr; +using ::testing::Not; +using ::testing::UnorderedElementsAre; // Helper method. Creates simple pipeline containing Select step and runs it. -absl::StatusOr RunExpression(const std::vector& values, - google::protobuf::Arena* arena, - bool enable_unknowns) { +absl::StatusOr RunExpression( + const absl_nonnull std::shared_ptr& env, + const std::vector& values, google::protobuf::Arena* arena, + bool enable_unknowns) { ExecutionPath path; Expr dummy_expr; - auto create_list = dummy_expr.mutable_list_expr(); + auto& create_list = dummy_expr.mutable_list_expr(); for (auto value : values) { - auto expr0 = create_list->add_elements(); - expr0->mutable_const_expr()->set_int64_value(value); - auto const_step_status = CreateConstValueStep( - ConvertConstant(&expr0->const_expr()).value(), expr0->id()); - if (!const_step_status.ok()) { - return const_step_status.status(); - } - - path.push_back(std::move(const_step_status.value())); + auto& expr0 = create_list.mutable_elements().emplace_back().mutable_expr(); + expr0.mutable_const_expr().set_int64_value(value); + CEL_ASSIGN_OR_RETURN( + auto const_step, + CreateConstValueStep(cel::interop_internal::CreateIntValue(value), + /*expr_id=*/-1)); + path.push_back(std::move(const_step)); } - auto step0_status = CreateCreateListStep(create_list, dummy_expr.id()); - - if (!step0_status.ok()) { - return step0_status.status(); + CEL_ASSIGN_OR_RETURN(auto step, + CreateCreateListStep(create_list, dummy_expr.id())); + path.push_back(std::move(step)); + cel::RuntimeOptions options; + if (enable_unknowns) { + options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; } + CelExpressionFlatImpl cel_expr( + env, - path.push_back(std::move(step0_status.value())); - - CelExpressionFlatImpl cel_expr(&dummy_expr, std::move(path), 0, {}, - enable_unknowns); + FlatExpression(std::move(path), + /*comprehension_slot_count=*/0, + env->type_registry.GetComposedTypeProvider(), options)); Activation activation; return cel_expr.Evaluate(activation, arena); @@ -59,142 +107,444 @@ absl::StatusOr RunExpression(const std::vector& values, // Helper method. Creates simple pipeline containing Select step and runs it. absl::StatusOr RunExpressionWithCelValues( + const absl_nonnull std::shared_ptr& env, const std::vector& values, google::protobuf::Arena* arena, bool enable_unknowns) { ExecutionPath path; Expr dummy_expr; Activation activation; - auto create_list = dummy_expr.mutable_list_expr(); + auto& create_list = dummy_expr.mutable_list_expr(); int ind = 0; for (auto value : values) { std::string var_name = absl::StrCat("name_", ind++); - auto expr0 = create_list->add_elements(); - expr0->set_id(ind); - expr0->mutable_ident_expr()->set_name(var_name); + auto& expr0 = create_list.mutable_elements().emplace_back().mutable_expr(); + expr0.set_id(ind); + expr0.mutable_ident_expr().set_name(var_name); - auto ident_step_status = CreateIdentStep(&expr0->ident_expr(), expr0->id()); - if (!ident_step_status.ok()) { - return ident_step_status.status(); - } - - path.push_back(std::move(ident_step_status.value())); + CEL_ASSIGN_OR_RETURN(auto ident_step, + CreateIdentStep(var_name, /*expr_id=*/-1)); + path.push_back(std::move(ident_step)); activation.InsertValue(var_name, value); } - auto step0_status = CreateCreateListStep(create_list, dummy_expr.id()); + CEL_ASSIGN_OR_RETURN(auto step0, + CreateCreateListStep(create_list, dummy_expr.id())); + path.push_back(std::move(step0)); - if (!step0_status.ok()) { - return step0_status.status(); + cel::RuntimeOptions options; + if (enable_unknowns) { + options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; } - path.push_back(std::move(step0_status.value())); - - CelExpressionFlatImpl cel_expr(&dummy_expr, std::move(path), 0, {}, - enable_unknowns); + CelExpressionFlatImpl cel_expr( + env, + FlatExpression(std::move(path), /*comprehension_slot_count=*/0, + env->type_registry.GetComposedTypeProvider(), options)); return cel_expr.Evaluate(activation, arena); } -class CreateListStepTest : public testing::TestWithParam {}; +class CreateListStepTest : public testing::TestWithParam { + public: + CreateListStepTest() : env_(NewTestingRuntimeEnv()) {} + + protected: + absl_nonnull std::shared_ptr env_; + google::protobuf::Arena arena_; +}; // Tests error when not enough list elements are on the stack during list // creation. -TEST(CreateListStepTest, TestCreateListStackUndeflow) { +TEST(CreateListStepTest, TestCreateListStackUnderflow) { ExecutionPath path; Expr dummy_expr; - auto create_list = dummy_expr.mutable_list_expr(); - auto expr0 = create_list->add_elements(); - expr0->mutable_const_expr()->set_int64_value(1); + auto& create_list = dummy_expr.mutable_list_expr(); + auto& expr0 = create_list.mutable_elements().emplace_back().mutable_expr(); + expr0.mutable_const_expr().set_int64_value(1); - auto step0_status = CreateCreateListStep(create_list, dummy_expr.id()); + ASSERT_OK_AND_ASSIGN(auto step0, + CreateCreateListStep(create_list, dummy_expr.id())); + path.push_back(std::move(step0)); - ASSERT_OK(step0_status); - - path.push_back(std::move(step0_status.value())); - - CelExpressionFlatImpl cel_expr(&dummy_expr, std::move(path), 0, {}); + auto env = NewTestingRuntimeEnv(); + CelExpressionFlatImpl cel_expr( + env, FlatExpression(std::move(path), /*comprehension_slot_count=*/0, + env->type_registry.GetComposedTypeProvider(), + cel::RuntimeOptions{})); Activation activation; google::protobuf::Arena arena; auto status = cel_expr.Evaluate(activation, &arena); - ASSERT_FALSE(status.ok()); + ASSERT_THAT(status, Not(IsOk())); } TEST_P(CreateListStepTest, CreateListEmpty) { - google::protobuf::Arena arena; - auto eval_result = RunExpression({}, &arena, GetParam()); - - ASSERT_OK(eval_result); - const CelValue result_value = eval_result.value(); - ASSERT_TRUE(result_value.IsList()); - EXPECT_THAT(result_value.ListOrDie()->size(), Eq(0)); + ASSERT_OK_AND_ASSIGN(CelValue result, + RunExpression(env_, {}, &arena_, GetParam())); + ASSERT_TRUE(result.IsList()); + EXPECT_THAT(result.ListOrDie()->size(), Eq(0)); } TEST_P(CreateListStepTest, CreateListOne) { - google::protobuf::Arena arena; - auto eval_result = RunExpression({100}, &arena, GetParam()); + ASSERT_OK_AND_ASSIGN(CelValue result, + RunExpression(env_, {100}, &arena_, GetParam())); + ASSERT_TRUE(result.IsList()); + const auto& list = *result.ListOrDie(); + ASSERT_THAT(list.size(), Eq(1)); + const CelValue& value = list.Get(&arena_, 0); + EXPECT_THAT(value, test::IsCelInt64(100)); +} - ASSERT_OK(eval_result); - const CelValue result_value = eval_result.value(); - ASSERT_TRUE(result_value.IsList()); - EXPECT_THAT(result_value.ListOrDie()->size(), Eq(1)); - EXPECT_THAT((*result_value.ListOrDie())[0].Int64OrDie(), Eq(100)); +TEST_P(CreateListStepTest, CreateListWithError) { + std::vector values; + CelError error = absl::InvalidArgumentError("bad arg"); + values.push_back(CelValue::CreateError(&error)); + ASSERT_OK_AND_ASSIGN(CelValue result, RunExpressionWithCelValues( + env_, values, &arena_, GetParam())); + + ASSERT_TRUE(result.IsError()); + EXPECT_THAT(*result.ErrorOrDie(), Eq(absl::InvalidArgumentError("bad arg"))); +} + +TEST_P(CreateListStepTest, CreateListWithErrorAndUnknown) { + // list composition is: {unknown, error} + std::vector values; + Expr expr0; + expr0.mutable_ident_expr().set_name("name0"); + CelAttribute attr0(expr0.ident_expr().name(), {}); + UnknownSet unknown_set0(UnknownAttributeSet({attr0})); + values.push_back(CelValue::CreateUnknownSet(&unknown_set0)); + CelError error = absl::InvalidArgumentError("bad arg"); + values.push_back(CelValue::CreateError(&error)); + + ASSERT_OK_AND_ASSIGN(CelValue result, RunExpressionWithCelValues( + env_, values, &arena_, GetParam())); + + // The bad arg should win. + ASSERT_TRUE(result.IsError()); + EXPECT_THAT(*result.ErrorOrDie(), Eq(absl::InvalidArgumentError("bad arg"))); } TEST_P(CreateListStepTest, CreateListHundred) { - google::protobuf::Arena arena; std::vector values; for (size_t i = 0; i < 100; i++) { values.push_back(i); } - auto eval_result = RunExpression(values, &arena, GetParam()); - - ASSERT_OK(eval_result); - const CelValue result_value = eval_result.value(); - ASSERT_TRUE(result_value.IsList()); - EXPECT_THAT(result_value.ListOrDie()->size(), - Eq(static_cast(values.size()))); + ASSERT_OK_AND_ASSIGN(CelValue result, + RunExpression(env_, values, &arena_, GetParam())); + ASSERT_TRUE(result.IsList()); + const auto& list = *result.ListOrDie(); + EXPECT_THAT(list.size(), Eq(static_cast(values.size()))); for (size_t i = 0; i < values.size(); i++) { - EXPECT_THAT((*result_value.ListOrDie())[i].Int64OrDie(), Eq(values[i])); + EXPECT_THAT(list.Get(&arena_, i), test::IsCelInt64(values[i])); } } +INSTANTIATE_TEST_SUITE_P(CombinedCreateListTest, CreateListStepTest, + testing::Bool()); + TEST(CreateListStepTest, CreateListHundredAnd2Unknowns) { google::protobuf::Arena arena; std::vector values; Expr expr0; - expr0.mutable_ident_expr()->set_name("name0"); - CelAttribute attr0(expr0, {}); + expr0.mutable_ident_expr().set_name("name0"); + CelAttribute attr0(expr0.ident_expr().name(), {}); Expr expr1; - expr1.mutable_ident_expr()->set_name("name1"); - CelAttribute attr1(expr1, {}); - UnknownSet unknown_set0(UnknownAttributeSet({&attr0})); - UnknownSet unknown_set1(UnknownAttributeSet({&attr1})); + expr1.mutable_ident_expr().set_name("name1"); + CelAttribute attr1(expr1.ident_expr().name(), {}); + UnknownSet unknown_set0(UnknownAttributeSet({attr0})); + UnknownSet unknown_set1(UnknownAttributeSet({attr1})); for (size_t i = 0; i < 100; i++) { values.push_back(CelValue::CreateInt64(i)); } values.push_back(CelValue::CreateUnknownSet(&unknown_set0)); values.push_back(CelValue::CreateUnknownSet(&unknown_set1)); - auto eval_result = RunExpressionWithCelValues(values, &arena, true); + ASSERT_OK_AND_ASSIGN( + CelValue result, + RunExpressionWithCelValues(NewTestingRuntimeEnv(), values, &arena, true)); + ASSERT_TRUE(result.IsUnknownSet()); + const UnknownSet* result_set = result.UnknownSetOrDie(); + EXPECT_THAT(result_set->unknown_attributes().size(), Eq(2)); +} + +TEST(CreateDirectListStep, Basic) { + google::protobuf::Arena arena; + cel::runtime_internal::RuntimeTypeProvider type_provider( + cel::internal::GetTestingDescriptorPool()); + + cel::Activation activation; + cel::RuntimeOptions options; + + ExecutionFrameBase frame(activation, options, type_provider, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena); + + std::vector> deps; + deps.push_back(CreateConstValueDirectStep(IntValue(1), -1)); + deps.push_back(CreateConstValueDirectStep(IntValue(2), -1)); + auto step = CreateDirectListStep(std::move(deps), {}, -1); - ASSERT_OK(eval_result); - const CelValue result_value = eval_result.value(); - ASSERT_TRUE(result_value.IsUnknownSet()); - const UnknownSet* result_set = result_value.UnknownSetOrDie(); - EXPECT_THAT(result_set->unknown_attributes().attributes().size(), Eq(2)); + cel::Value result; + AttributeTrail attr; + + ASSERT_OK(step->Evaluate(frame, result, attr)); + + ASSERT_TRUE(InstanceOf(result)); + EXPECT_THAT(Cast(result).Size(), IsOkAndHolds(2)); } -INSTANTIATE_TEST_SUITE_P(CombinedCreateListTest, CreateListStepTest, - testing::Bool()); +TEST(CreateDirectListStep, ForwardFirstError) { + google::protobuf::Arena arena; + cel::runtime_internal::RuntimeTypeProvider type_provider( + cel::internal::GetTestingDescriptorPool()); + + cel::Activation activation; + cel::RuntimeOptions options; + + ExecutionFrameBase frame(activation, options, type_provider, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena); + + std::vector> deps; + deps.push_back(CreateConstValueDirectStep( + cel::ErrorValue(absl::InternalError("test1")), -1)); + deps.push_back(CreateConstValueDirectStep( + cel::ErrorValue(absl::InternalError("test2")), -1)); + auto step = CreateDirectListStep(std::move(deps), {}, -1); + + cel::Value result; + AttributeTrail attr; + + ASSERT_OK(step->Evaluate(frame, result, attr)); + + ASSERT_TRUE(InstanceOf(result)); + EXPECT_THAT(Cast(result).NativeValue(), + StatusIs(absl::StatusCode::kInternal, "test1")); +} + +std::vector UnknownAttrNames(const UnknownValue& v) { + std::vector names; + names.reserve(v.attribute_set().size()); + + for (const auto& attr : v.attribute_set()) { + EXPECT_OK(attr.AsString().status()); + names.push_back(attr.AsString().value_or("")); + } + return names; +} + +TEST(CreateDirectListStep, MergeUnknowns) { + google::protobuf::Arena arena; + cel::runtime_internal::RuntimeTypeProvider type_provider( + cel::internal::GetTestingDescriptorPool()); + + cel::Activation activation; + cel::RuntimeOptions options; + options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; + + ExecutionFrameBase frame(activation, options, type_provider, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena); + + AttributeSet attr_set1({Attribute("var1")}); + AttributeSet attr_set2({Attribute("var2")}); + + std::vector> deps; + deps.push_back(CreateConstValueDirectStep( + cel::UnknownValue(cel::Unknown(std::move(attr_set1))), -1)); + deps.push_back(CreateConstValueDirectStep( + cel::UnknownValue(cel::Unknown(std::move(attr_set2))), -1)); + auto step = CreateDirectListStep(std::move(deps), {}, -1); + + cel::Value result; + AttributeTrail attr; + + ASSERT_OK(step->Evaluate(frame, result, attr)); + + ASSERT_TRUE(InstanceOf(result)); + EXPECT_THAT(UnknownAttrNames(Cast(result)), + UnorderedElementsAre("var1", "var2")); +} + +TEST(CreateDirectListStep, ErrorBeforeUnknown) { + google::protobuf::Arena arena; + cel::runtime_internal::RuntimeTypeProvider type_provider( + cel::internal::GetTestingDescriptorPool()); + + cel::Activation activation; + cel::RuntimeOptions options; + + ExecutionFrameBase frame(activation, options, type_provider, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena); + + AttributeSet attr_set1({Attribute("var1")}); + + std::vector> deps; + deps.push_back(CreateConstValueDirectStep( + cel::ErrorValue(absl::InternalError("test1")), -1)); + deps.push_back(CreateConstValueDirectStep( + cel::ErrorValue(absl::InternalError("test2")), -1)); + auto step = CreateDirectListStep(std::move(deps), {}, -1); + + cel::Value result; + AttributeTrail attr; + + ASSERT_OK(step->Evaluate(frame, result, attr)); + + ASSERT_TRUE(InstanceOf(result)); + EXPECT_THAT(Cast(result).NativeValue(), + StatusIs(absl::StatusCode::kInternal, "test1")); +} + +class SetAttrDirectStep : public DirectExpressionStep { + public: + explicit SetAttrDirectStep(Attribute attr) + : DirectExpressionStep(-1), attr_(std::move(attr)) {} + + absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, + AttributeTrail& attr) const override { + result = cel::NullValue(); + attr = AttributeTrail(attr_); + return absl::OkStatus(); + } + + private: + cel::Attribute attr_; +}; + +TEST(CreateDirectListStep, MissingAttribute) { + google::protobuf::Arena arena; + cel::runtime_internal::RuntimeTypeProvider type_provider( + cel::internal::GetTestingDescriptorPool()); + + cel::Activation activation; + cel::RuntimeOptions options; + options.enable_missing_attribute_errors = true; + + activation.SetMissingPatterns({cel::AttributePattern( + "var1", {cel::AttributeQualifierPattern::OfString("field1")})}); + + ExecutionFrameBase frame(activation, options, type_provider, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena); + + std::vector> deps; + deps.push_back(CreateConstValueDirectStep(cel::NullValue(), -1)); + deps.push_back(std::make_unique( + Attribute("var1", {AttributeQualifier::OfString("field1")}))); + auto step = CreateDirectListStep(std::move(deps), {}, -1); + + cel::Value result; + AttributeTrail attr; + + ASSERT_OK(step->Evaluate(frame, result, attr)); + + ASSERT_TRUE(InstanceOf(result)); + EXPECT_THAT( + Cast(result).NativeValue(), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("var1.field1"))); +} + +TEST(CreateDirectListStep, OptionalPresentSet) { + google::protobuf::Arena arena; + cel::runtime_internal::RuntimeTypeProvider type_provider( + cel::internal::GetTestingDescriptorPool()); + + cel::Activation activation; + cel::RuntimeOptions options; + + ExecutionFrameBase frame(activation, options, type_provider, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena); + + std::vector> deps; + deps.push_back(CreateConstValueDirectStep(IntValue(1), -1)); + deps.push_back(CreateConstValueDirectStep( + cel::OptionalValue::Of(IntValue(2), &arena), -1)); + auto step = CreateDirectListStep(std::move(deps), {1}, -1); + + cel::Value result; + AttributeTrail attr; + + ASSERT_OK(step->Evaluate(frame, result, attr)); + + ASSERT_TRUE(InstanceOf(result)); + auto list = Cast(result); + EXPECT_THAT(list.Size(), IsOkAndHolds(2)); + EXPECT_THAT(list.Get(0, cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena), + IsOkAndHolds(IntValueIs(1))); + EXPECT_THAT(list.Get(1, cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena), + IsOkAndHolds(IntValueIs(2))); +} + +TEST(CreateDirectListStep, OptionalAbsentNotSet) { + google::protobuf::Arena arena; + cel::runtime_internal::RuntimeTypeProvider type_provider( + cel::internal::GetTestingDescriptorPool()); + + cel::Activation activation; + cel::RuntimeOptions options; + + ExecutionFrameBase frame(activation, options, type_provider, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena); + + std::vector> deps; + deps.push_back(CreateConstValueDirectStep(IntValue(1), -1)); + deps.push_back(CreateConstValueDirectStep(cel::OptionalValue::None(), -1)); + auto step = CreateDirectListStep(std::move(deps), {1}, -1); + + cel::Value result; + AttributeTrail attr; + + ASSERT_OK(step->Evaluate(frame, result, attr)); + + ASSERT_TRUE(InstanceOf(result)); + auto list = Cast(result); + EXPECT_THAT(list.Size(), IsOkAndHolds(1)); + EXPECT_THAT(list.Get(0, cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena), + IsOkAndHolds(IntValueIs(1))); +} + +TEST(CreateDirectListStep, PartialUnknown) { + google::protobuf::Arena arena; + cel::runtime_internal::RuntimeTypeProvider type_provider( + cel::internal::GetTestingDescriptorPool()); + + cel::Activation activation; + cel::RuntimeOptions options; + options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; + activation.SetUnknownPatterns({cel::AttributePattern( + "var1", {cel::AttributeQualifierPattern::OfString("field1")})}); + + ExecutionFrameBase frame(activation, options, type_provider, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena); + + std::vector> deps; + deps.push_back(CreateConstValueDirectStep(cel::IntValue(1), -1)); + deps.push_back(std::make_unique(Attribute("var1", {}))); + auto step = CreateDirectListStep(std::move(deps), {}, -1); + + cel::Value result; + AttributeTrail attr; + + ASSERT_OK(step->Evaluate(frame, result, attr)); + + ASSERT_TRUE(InstanceOf(result)); + EXPECT_THAT(UnknownAttrNames(Cast(result)), + UnorderedElementsAre("var1")); +} } // namespace -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime diff --git a/eval/eval/create_map_step.cc b/eval/eval/create_map_step.cc new file mode 100644 index 000000000..451181e75 --- /dev/null +++ b/eval/eval/create_map_step.cc @@ -0,0 +1,289 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/eval/create_map_step.h" + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "common/casting.h" +#include "common/value.h" +#include "common/values/map_value_builder.h" +#include "eval/eval/attribute_trail.h" +#include "eval/eval/direct_expression_step.h" +#include "eval/eval/evaluator_core.h" +#include "eval/eval/expression_step_base.h" +#include "internal/status_macros.h" + +namespace google::api::expr::runtime { + +namespace { + +using ::cel::Cast; +using ::cel::ErrorValue; +using ::cel::ErrorValueAssign; +using ::cel::ErrorValueReturn; +using ::cel::InstanceOf; +using ::cel::MapValueBuilderPtr; +using ::cel::UnknownValue; +using ::cel::Value; +using ::cel::common_internal::NewMapValueBuilder; +using ::cel::common_internal::NewMutableMapValue; + +// `CreateStruct` implementation for map. +class CreateStructStepForMap final : public ExpressionStepBase { + public: + CreateStructStepForMap(int64_t expr_id, size_t entry_count, + absl::flat_hash_set optional_indices) + : ExpressionStepBase(expr_id), + entry_count_(entry_count), + optional_indices_(std::move(optional_indices)) {} + + absl::Status Evaluate(ExecutionFrame* frame) const override; + + private: + absl::StatusOr DoEvaluate(ExecutionFrame* frame) const; + + size_t entry_count_; + absl::flat_hash_set optional_indices_; +}; + +absl::StatusOr CreateStructStepForMap::DoEvaluate( + ExecutionFrame* frame) const { + auto args = frame->value_stack().GetSpan(2 * entry_count_); + + for (const auto& arg : args) { + if (arg.IsError()) { + return arg; + } + } + + if (frame->enable_unknowns()) { + absl::optional unknown_set = + frame->attribute_utility().IdentifyAndMergeUnknowns( + args, frame->value_stack().GetAttributeSpan(args.size()), true); + if (unknown_set.has_value()) { + return *unknown_set; + } + } + + MapValueBuilderPtr builder = NewMapValueBuilder(frame->arena()); + builder->Reserve(entry_count_); + + for (size_t i = 0; i < entry_count_; i += 1) { + const auto& map_key = args[2 * i]; + CEL_RETURN_IF_ERROR(cel::CheckMapKey(map_key)).With(ErrorValueReturn()); + const auto& map_value = args[(2 * i) + 1]; + if (optional_indices_.contains(static_cast(i))) { + if (auto optional_map_value = map_value.AsOptional(); + optional_map_value) { + if (!optional_map_value->HasValue()) { + continue; + } + Value optional_map_value_value; + optional_map_value->Value(&optional_map_value_value); + if (optional_map_value_value.IsError()) { + // Error should never be in optional, but better safe than sorry. + return optional_map_value_value; + } + CEL_RETURN_IF_ERROR( + builder->Put(map_key, std::move(optional_map_value_value))); + } else { + return cel::TypeConversionError(map_value.DebugString(), + "optional_type"); + } + } else { + CEL_RETURN_IF_ERROR(builder->Put(map_key, map_value)); + } + } + + return std::move(*builder).Build(); +} + +absl::Status CreateStructStepForMap::Evaluate(ExecutionFrame* frame) const { + if (frame->value_stack().size() < 2 * entry_count_) { + return absl::InternalError("CreateStructStepForMap: stack underflow"); + } + + CEL_ASSIGN_OR_RETURN(auto result, DoEvaluate(frame)); + + frame->value_stack().PopAndPush(2 * entry_count_, std::move(result)); + + return absl::OkStatus(); +} + +class DirectCreateMapStep : public DirectExpressionStep { + public: + DirectCreateMapStep(std::vector> deps, + absl::flat_hash_set optional_indices, + int64_t expr_id) + : DirectExpressionStep(expr_id), + deps_(std::move(deps)), + optional_indices_(std::move(optional_indices)), + entry_count_(deps_.size() / 2) {} + + absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, + AttributeTrail& attribute_trail) const override; + + private: + std::vector> deps_; + absl::flat_hash_set optional_indices_; + size_t entry_count_; +}; + +absl::Status DirectCreateMapStep::Evaluate( + ExecutionFrameBase& frame, Value& result, + AttributeTrail& attribute_trail) const { + auto unknowns = frame.attribute_utility().CreateAccumulator(); + + MapValueBuilderPtr builder = NewMapValueBuilder(frame.arena()); + builder->Reserve(entry_count_); + + for (size_t i = 0; i < entry_count_; i += 1) { + Value key; + Value value; + AttributeTrail tmp_attr; + int map_key_index = 2 * i; + int map_value_index = map_key_index + 1; + CEL_RETURN_IF_ERROR(deps_[map_key_index]->Evaluate(frame, key, tmp_attr)); + + if (key.IsError()) { + result = std::move(key); + return absl::OkStatus(); + } + + if (frame.unknown_processing_enabled()) { + if (key.IsUnknown()) { + unknowns.Add(key.GetUnknown()); + } else if (frame.attribute_utility().CheckForUnknownPartial(tmp_attr)) { + unknowns.Add(tmp_attr); + } + } + + CEL_RETURN_IF_ERROR(cel::CheckMapKey(key)).With(ErrorValueAssign(result)); + + CEL_RETURN_IF_ERROR( + deps_[map_value_index]->Evaluate(frame, value, tmp_attr)); + + if (value.IsError()) { + result = std::move(value); + return absl::OkStatus(); + } + + if (frame.unknown_processing_enabled()) { + if (value.IsUnknown()) { + unknowns.Add(value.GetUnknown()); + } else if (frame.attribute_utility().CheckForUnknownPartial(tmp_attr)) { + unknowns.Add(tmp_attr); + } + } + + // Preserve the stack machine behavior of forwarding unknowns before + // errors. + if (!unknowns.IsEmpty()) { + continue; + } + + if (optional_indices_.contains(static_cast(i))) { + if (auto optional_map_value = value.AsOptional(); optional_map_value) { + if (!optional_map_value->HasValue()) { + continue; + } + Value optional_map_value_value; + optional_map_value->Value(&optional_map_value_value); + if (optional_map_value_value.IsError()) { + // Error should never be in optional, but better safe than sorry. + result = optional_map_value_value; + return absl::OkStatus(); + } + CEL_RETURN_IF_ERROR( + builder->Put(std::move(key), std::move(optional_map_value_value))); + continue; + } + result = cel::TypeConversionError(value.DebugString(), "optional_type"); + return absl::OkStatus(); + } + + CEL_RETURN_IF_ERROR(builder->Put(std::move(key), std::move(value))); + } + + if (!unknowns.IsEmpty()) { + result = std::move(unknowns).Build(); + return absl::OkStatus(); + } + + result = std::move(*builder).Build(); + return absl::OkStatus(); +} + +class MutableMapStep final : public ExpressionStep { + public: + explicit MutableMapStep(int64_t expr_id) : ExpressionStep(expr_id) {} + + absl::Status Evaluate(ExecutionFrame* frame) const override { + frame->value_stack().Push(cel::CustomMapValue( + NewMutableMapValue(frame->arena()), frame->arena())); + return absl::OkStatus(); + } +}; + +class DirectMutableMapStep final : public DirectExpressionStep { + public: + explicit DirectMutableMapStep(int64_t expr_id) + : DirectExpressionStep(expr_id) {} + + absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, + AttributeTrail& attribute) const override { + result = + cel::CustomMapValue(NewMutableMapValue(frame.arena()), frame.arena()); + return absl::OkStatus(); + } +}; + +} // namespace + +std::unique_ptr CreateDirectCreateMapStep( + std::vector> deps, + absl::flat_hash_set optional_indices, int64_t expr_id) { + return std::make_unique( + std::move(deps), std::move(optional_indices), expr_id); +} + +absl::StatusOr> CreateCreateStructStepForMap( + size_t entry_count, absl::flat_hash_set optional_indices, + int64_t expr_id) { + // Make map-creating step. + return std::make_unique(expr_id, entry_count, + std::move(optional_indices)); +} + +absl::StatusOr> CreateMutableMapStep( + int64_t expr_id) { + return std::make_unique(expr_id); +} + +std::unique_ptr CreateDirectMutableMapStep( + int64_t expr_id) { + return std::make_unique(expr_id); +} + +} // namespace google::api::expr::runtime diff --git a/eval/eval/create_map_step.h b/eval/eval/create_map_step.h new file mode 100644 index 000000000..cf5e94644 --- /dev/null +++ b/eval/eval/create_map_step.h @@ -0,0 +1,59 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_CREATE_MAP_STEP_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_EVAL_CREATE_MAP_STEP_H_ + +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "eval/eval/direct_expression_step.h" +#include "eval/eval/evaluator_core.h" + +namespace google::api::expr::runtime { + +// Creates an expression step that evaluates a create map expression. +// +// Deps must have an even number of elements, that alternate key, value pairs. +// (key1, value1, key2, value2...). +std::unique_ptr CreateDirectCreateMapStep( + std::vector> deps, + absl::flat_hash_set optional_indices, int64_t expr_id); + +// Creates an `ExpressionStep` which performs `CreateStruct` for a map. +absl::StatusOr> CreateCreateStructStepForMap( + size_t entry_count, absl::flat_hash_set optional_indices, + int64_t expr_id); + +// Factory method for CreateMap which constructs a mutable map. +// +// This is intended for the map construction step is generated for a +// map-building comprehension (rather than a user authored expression). +absl::StatusOr> CreateMutableMapStep( + int64_t expr_id); + +// Factory method for CreateMap which constructs a mutable map. +// +// This is intended for the map construction step is generated for a +// map-building comprehension (rather than a user authored expression). +std::unique_ptr CreateDirectMutableMapStep( + int64_t expr_id); + +} // namespace google::api::expr::runtime + +#endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_CREATE_MAP_STEP_H_ diff --git a/eval/eval/create_map_step_test.cc b/eval/eval/create_map_step_test.cc new file mode 100644 index 000000000..dbc9adb5a --- /dev/null +++ b/eval/eval/create_map_step_test.cc @@ -0,0 +1,283 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/eval/create_map_step.h" + +#include +#include +#include +#include +#include + +#include "cel/expr/syntax.pb.h" +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "base/type_provider.h" +#include "common/expr.h" +#include "eval/eval/cel_expression_flat_impl.h" +#include "eval/eval/direct_expression_step.h" +#include "eval/eval/evaluator_core.h" +#include "eval/eval/ident_step.h" +#include "eval/public/activation.h" +#include "eval/public/cel_value.h" +#include "eval/public/unknown_set.h" +#include "eval/testutil/test_message.pb.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "runtime/internal/runtime_env.h" +#include "runtime/internal/runtime_env_testing.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" + +namespace google::api::expr::runtime { + +namespace { + +using ::absl_testing::StatusIs; +using ::cel::Expr; +using ::cel::TypeProvider; +using ::cel::runtime_internal::NewTestingRuntimeEnv; +using ::cel::runtime_internal::RuntimeEnv; +using ::google::protobuf::Arena; + +absl::StatusOr CreateStackMachineProgram( + const std::vector>& values, + Activation& activation) { + ExecutionPath path; + + Expr expr1; + Expr expr0; + + std::vector exprs; + exprs.reserve(values.size() * 2); + int index = 0; + + auto& create_struct = expr1.mutable_struct_expr(); + for (const auto& item : values) { + std::string key_name = absl::StrCat("key", index); + std::string value_name = absl::StrCat("value", index); + + CEL_ASSIGN_OR_RETURN(auto step_key, + CreateIdentStep(key_name, /*expr_id=*/-1)); + + CEL_ASSIGN_OR_RETURN(auto step_value, + CreateIdentStep(value_name, /*expr _id=*/-1)); + + path.push_back(std::move(step_key)); + path.push_back(std::move(step_value)); + + activation.InsertValue(key_name, item.first); + activation.InsertValue(value_name, item.second); + + create_struct.mutable_fields().emplace_back(); + index++; + } + + CEL_ASSIGN_OR_RETURN( + auto step1, CreateCreateStructStepForMap(values.size(), {}, expr1.id())); + path.push_back(std::move(step1)); + return path; +} + +absl::StatusOr CreateRecursiveProgram( + const std::vector>& values, + Activation& activation) { + ExecutionPath path; + + int index = 0; + std::vector> deps; + for (const auto& item : values) { + std::string key_name = absl::StrCat("key", index); + std::string value_name = absl::StrCat("value", index); + + deps.push_back(CreateDirectIdentStep(key_name, -1)); + + deps.push_back(CreateDirectIdentStep(value_name, -1)); + + activation.InsertValue(key_name, item.first); + activation.InsertValue(value_name, item.second); + + index++; + } + path.push_back(std::make_unique( + CreateDirectCreateMapStep(std::move(deps), {}, -1), -1)); + + return path; +} + +// Helper method. Creates simple pipeline containing CreateStruct step that +// builds Map and runs it. +// Equivalent to {key0: value0, ...} +absl::StatusOr RunCreateMapExpression( + const absl_nonnull std::shared_ptr& env, + const std::vector>& values, + google::protobuf::Arena* arena, bool enable_unknowns, bool enable_recursive_program) { + Activation activation; + + ExecutionPath path; + if (enable_recursive_program) { + CEL_ASSIGN_OR_RETURN(path, CreateRecursiveProgram(values, activation)); + } else { + CEL_ASSIGN_OR_RETURN(path, CreateStackMachineProgram(values, activation)); + } + cel::RuntimeOptions options; + if (enable_unknowns) { + options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; + } + + CelExpressionFlatImpl cel_expr( + env, + FlatExpression(std::move(path), /*comprehension_slot_count=*/0, + env->type_registry.GetComposedTypeProvider(), options)); + return cel_expr.Evaluate(activation, arena); +} + +class CreateMapStepTest + : public testing::TestWithParam> { + public: + CreateMapStepTest() : env_(NewTestingRuntimeEnv()) {} + + bool enable_unknowns() { return std::get<0>(GetParam()); } + bool enable_recursive_program() { return std::get<1>(GetParam()); } + + absl::StatusOr RunMapExpression( + const std::vector>& values) { + return RunCreateMapExpression(env_, values, &arena_, enable_unknowns(), + enable_recursive_program()); + } + + protected: + absl_nonnull std::shared_ptr env_; + google::protobuf::Arena arena_; +}; + +// Test that Empty Map is created successfully. +TEST_P(CreateMapStepTest, TestCreateEmptyMap) { + ASSERT_OK_AND_ASSIGN(CelValue result, RunMapExpression({})); + ASSERT_TRUE(result.IsMap()); + + const CelMap* cel_map = result.MapOrDie(); + ASSERT_EQ(cel_map->size(), 0); +} + +// Test message creation if unknown argument is passed +TEST(CreateMapStepTest, TestMapCreateWithUnknown) { + absl_nonnull std::shared_ptr env = NewTestingRuntimeEnv(); + Arena arena; + UnknownSet unknown_set; + std::vector> entries; + + std::vector kKeys = {"test2", "test1"}; + + entries.push_back( + {CelValue::CreateString(&kKeys[0]), CelValue::CreateInt64(2)}); + entries.push_back({CelValue::CreateString(&kKeys[1]), + CelValue::CreateUnknownSet(&unknown_set)}); + + ASSERT_OK_AND_ASSIGN(CelValue result, RunCreateMapExpression( + env, entries, &arena, true, false)); + ASSERT_TRUE(result.IsUnknownSet()); +} + +TEST(CreateMapStepTest, TestMapCreateWithError) { + absl_nonnull std::shared_ptr env = NewTestingRuntimeEnv(); + Arena arena; + UnknownSet unknown_set; + absl::Status error = absl::CancelledError(); + std::vector> entries; + entries.push_back({CelValue::CreateStringView("foo"), + CelValue::CreateUnknownSet(&unknown_set)}); + entries.push_back( + {CelValue::CreateStringView("bar"), CelValue::CreateError(&error)}); + + ASSERT_OK_AND_ASSIGN(CelValue result, RunCreateMapExpression( + env, entries, &arena, true, false)); + ASSERT_TRUE(result.IsError()); + EXPECT_THAT(*result.ErrorOrDie(), StatusIs(absl::StatusCode::kCancelled)); +} + +TEST(CreateMapStepTest, TestMapCreateWithErrorRecursiveProgram) { + absl_nonnull std::shared_ptr env = NewTestingRuntimeEnv(); + Arena arena; + UnknownSet unknown_set; + absl::Status error = absl::CancelledError(); + std::vector> entries; + entries.push_back({CelValue::CreateStringView("foo"), + CelValue::CreateUnknownSet(&unknown_set)}); + entries.push_back( + {CelValue::CreateStringView("bar"), CelValue::CreateError(&error)}); + + ASSERT_OK_AND_ASSIGN(CelValue result, RunCreateMapExpression( + env, entries, &arena, true, true)); + ASSERT_TRUE(result.IsError()); + EXPECT_THAT(*result.ErrorOrDie(), StatusIs(absl::StatusCode::kCancelled)); +} + +TEST(CreateMapStepTest, TestMapCreateWithUnknownRecursiveProgram) { + absl_nonnull std::shared_ptr env = NewTestingRuntimeEnv(); + Arena arena; + UnknownSet unknown_set; + std::vector> entries; + + std::vector kKeys = {"test2", "test1"}; + + entries.push_back( + {CelValue::CreateString(&kKeys[0]), CelValue::CreateInt64(2)}); + entries.push_back({CelValue::CreateString(&kKeys[1]), + CelValue::CreateUnknownSet(&unknown_set)}); + + ASSERT_OK_AND_ASSIGN(CelValue result, RunCreateMapExpression( + env, entries, &arena, true, true)); + ASSERT_TRUE(result.IsUnknownSet()); +} + +// Test that String Map is created successfully. +TEST_P(CreateMapStepTest, TestCreateStringMap) { + Arena arena; + + std::vector> entries; + + std::vector kKeys = {"test2", "test1"}; + + entries.push_back( + {CelValue::CreateString(&kKeys[0]), CelValue::CreateInt64(2)}); + entries.push_back( + {CelValue::CreateString(&kKeys[1]), CelValue::CreateInt64(1)}); + + ASSERT_OK_AND_ASSIGN(CelValue result, RunMapExpression(entries)); + ASSERT_TRUE(result.IsMap()); + + const CelMap* cel_map = result.MapOrDie(); + ASSERT_EQ(cel_map->size(), 2); + + auto lookup0 = cel_map->Get(&arena, CelValue::CreateString(&kKeys[0])); + ASSERT_TRUE(lookup0.has_value()); + ASSERT_TRUE(lookup0->IsInt64()) << lookup0->DebugString(); + EXPECT_EQ(lookup0->Int64OrDie(), 2); + + auto lookup1 = cel_map->Get(&arena, CelValue::CreateString(&kKeys[1])); + ASSERT_TRUE(lookup1.has_value()); + ASSERT_TRUE(lookup1->IsInt64()); + EXPECT_EQ(lookup1->Int64OrDie(), 1); +} + +INSTANTIATE_TEST_SUITE_P(CreateMapStep, CreateMapStepTest, + testing::Combine(testing::Bool(), testing::Bool())); + +} // namespace + +} // namespace google::api::expr::runtime diff --git a/eval/eval/create_struct_step.cc b/eval/eval/create_struct_step.cc index 987ec77cf..5d042baf5 100644 --- a/eval/eval/create_struct_step.cc +++ b/eval/eval/create_struct_step.cc @@ -1,316 +1,270 @@ +// Copyright 2017 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #include "eval/eval/create_struct_step.h" -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/substitute.h" -#include "eval/public/containers/container_backed_map_impl.h" -#include "eval/public/containers/field_access.h" -#include "eval/public/structs/cel_proto_wrapper.h" - -namespace google { -namespace api { -namespace expr { -namespace runtime { +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "common/casting.h" +#include "common/value.h" +#include "eval/eval/attribute_trail.h" +#include "eval/eval/direct_expression_step.h" +#include "eval/eval/evaluator_core.h" +#include "eval/eval/expression_step_base.h" +#include "internal/status_macros.h" + +namespace google::api::expr::runtime { namespace { -using ::google::protobuf::Arena; -using ::google::protobuf::Descriptor; -using ::google::protobuf::DescriptorPool; -using ::google::protobuf::FieldDescriptor; -using ::google::protobuf::Message; -using ::google::protobuf::MessageFactory; +using ::cel::Cast; +using ::cel::ErrorValue; +using ::cel::InstanceOf; +using ::cel::StructValueBuilderInterface; +using ::cel::UnknownValue; +using ::cel::Value; -class CreateStructStepForMessage : public ExpressionStepBase { +// `CreateStruct` implementation for message/struct. +class CreateStructStepForStruct final : public ExpressionStepBase { public: - struct FieldEntry { - const FieldDescriptor* field; - }; - - CreateStructStepForMessage(int64_t expr_id, const Descriptor* descriptor, - std::vector entries) + CreateStructStepForStruct(int64_t expr_id, std::string name, + std::vector entries, + absl::flat_hash_set optional_indices) : ExpressionStepBase(expr_id), - descriptor_(descriptor), - entries_(std::move(entries)) {} - - absl::Status Evaluate(ExecutionFrame* frame) const override; - - private: - absl::Status DoEvaluate(ExecutionFrame* frame, CelValue* result) const; - - const Descriptor* descriptor_; - std::vector entries_; -}; - -class CreateStructStepForMap : public ExpressionStepBase { - public: - CreateStructStepForMap(int64_t expr_id, size_t entry_count) - : ExpressionStepBase(expr_id), entry_count_(entry_count) {} + name_(std::move(name)), + entries_(std::move(entries)), + optional_indices_(std::move(optional_indices)) {} absl::Status Evaluate(ExecutionFrame* frame) const override; private: - absl::Status DoEvaluate(ExecutionFrame* frame, CelValue* result) const; + absl::StatusOr DoEvaluate(ExecutionFrame* frame) const; - size_t entry_count_; + std::string name_; + std::vector entries_; + absl::flat_hash_set optional_indices_; }; -absl::Status CreateStructStepForMessage::DoEvaluate(ExecutionFrame* frame, - CelValue* result) const { +absl::StatusOr CreateStructStepForStruct::DoEvaluate( + ExecutionFrame* frame) const { int entries_size = entries_.size(); - absl::Span args = frame->value_stack().GetSpan(entries_size); + auto args = frame->value_stack().GetSpan(entries_size); - if (frame->enable_unknowns()) { - auto unknown_set = frame->attribute_utility().MergeUnknowns( - args, frame->value_stack().GetAttributeSpan(entries_size), - /*initial_set=*/nullptr, - /*use_partial=*/true); - if (unknown_set != nullptr) { - *result = CelValue::CreateUnknownSet(unknown_set); - return absl::OkStatus(); + for (const auto& arg : args) { + if (arg.IsError()) { + return arg; } } - const Message* prototype = - MessageFactory::generated_factory()->GetPrototype(descriptor_); - - Message* msg = - (prototype != nullptr) ? prototype->New(frame->arena()) : nullptr; - - if (msg == nullptr) { - *result = CreateErrorValue( - frame->arena(), - absl::Substitute("Failed to create message $0", descriptor_->name())); - return absl::OkStatus(); + if (frame->enable_unknowns()) { + absl::optional unknown_set = + frame->attribute_utility().IdentifyAndMergeUnknowns( + args, frame->value_stack().GetAttributeSpan(entries_size), + /*use_partial=*/true); + if (unknown_set.has_value()) { + return *unknown_set; + } } - int index = 0; - for (const auto& entry : entries_) { - const CelValue& arg = args[index++]; - - absl::Status status = absl::OkStatus(); - - if (entry.field->is_map()) { - constexpr int kKeyField = 1; - constexpr int kValueField = 2; - - const CelMap* cel_map; - if (!arg.GetValue(&cel_map) || cel_map == nullptr) { - status = absl::InvalidArgumentError(absl::Substitute( - "Failed to create message $0, field $1: value is not CelMap", - descriptor_->name(), entry.field->name())); - break; - } - - auto entry_descriptor = entry.field->message_type(); - - if (entry_descriptor == nullptr) { - status = absl::InvalidArgumentError( - absl::Substitute("Failed to create message $0, field $1: failed to " - "find map entry descriptor", - descriptor_->name(), entry.field->name())); - break; - } - - auto key_field_descriptor = - entry_descriptor->FindFieldByNumber(kKeyField); - auto value_field_descriptor = - entry_descriptor->FindFieldByNumber(kValueField); - - if (key_field_descriptor == nullptr) { - status = absl::InvalidArgumentError( - absl::Substitute("Failed to create message $0, field $1: failed to " - "find key field descriptor", - descriptor_->name(), entry.field->name())); - break; - } - if (value_field_descriptor == nullptr) { - status = absl::InvalidArgumentError( - absl::Substitute("Failed to create message $0, field $1: failed to " - "find value field descriptor", - descriptor_->name(), entry.field->name())); - break; - } - - const CelList* key_list = cel_map->ListKeys(); - for (int i = 0; i < key_list->size(); i++) { - CelValue key = (*key_list)[i]; + CEL_ASSIGN_OR_RETURN(auto builder, + frame->type_provider().NewValueBuilder( + name_, frame->message_factory(), frame->arena())); + if (builder == nullptr) { + return ErrorValue( + absl::NotFoundError(absl::StrCat("Unable to find builder: ", name_))); + } - auto value = (*cel_map)[key]; - if (!value.has_value()) { - status = absl::InvalidArgumentError(absl::Substitute( - "Failed to create message $0, field $1: Error serializing CelMap", - descriptor_->name(), entry.field->name())); - break; + for (int i = 0; i < entries_size; ++i) { + const auto& entry = entries_[i]; + const auto& arg = args[i]; + if (optional_indices_.contains(static_cast(i))) { + if (auto optional_arg = arg.AsOptional(); optional_arg) { + if (!optional_arg->HasValue()) { + continue; } - - Message* entry_msg = msg->GetReflection()->AddMessage(msg, entry.field); - status = SetValueToSingleField(key, key_field_descriptor, entry_msg); - if (!status.ok()) { - break; + Value optional_arg_value; + optional_arg->Value(&optional_arg_value); + if (optional_arg_value.IsError()) { + // Error should never be in optional, but better safe than sorry. + return optional_arg_value; } - status = SetValueToSingleField(value.value(), value_field_descriptor, - entry_msg); - if (!status.ok()) { - break; + CEL_ASSIGN_OR_RETURN( + absl::optional error_value, + builder->SetFieldByName(entry, std::move(optional_arg_value))); + if (error_value) { + return std::move(*error_value); } - } - - } else if (entry.field->is_repeated()) { - const CelList* cel_list; - if (!arg.GetValue(&cel_list) || cel_list == nullptr) { - *result = CreateErrorValue( - frame->arena(), - absl::Substitute( - "Failed to create message $0: value $1 is not CelList", - descriptor_->name(), entry.field->name())); - return absl::OkStatus(); - } - - for (int i = 0; i < cel_list->size(); i++) { - status = AddValueToRepeatedField((*cel_list)[i], entry.field, msg); - if (!status.ok()) break; + } else { + return cel::TypeConversionError(arg.DebugString(), "optional_type"); } } else { - status = SetValueToSingleField(arg, entry.field, msg); - } - - if (!status.ok()) { - *result = CreateErrorValue( - frame->arena(), - absl::Substitute("Failed to create message $0: reason $1", - descriptor_->name(), status.ToString())); - return absl::OkStatus(); + CEL_ASSIGN_OR_RETURN(absl::optional error_value, + builder->SetFieldByName(entry, arg)); + if (error_value) { + return std::move(*error_value); + } } } - *result = CelProtoWrapper::CreateMessage(msg, frame->arena()); - - return absl::OkStatus(); + return std::move(*builder).Build(); } -absl::Status CreateStructStepForMessage::Evaluate(ExecutionFrame* frame) const { +absl::Status CreateStructStepForStruct::Evaluate(ExecutionFrame* frame) const { if (frame->value_stack().size() < entries_.size()) { - return absl::Status(absl::StatusCode::kInternal, - "CreateStructStepForMessage: stack undeflow"); + return absl::InternalError("CreateStructStepForStruct: stack underflow"); } - - CelValue result; - - absl::Status status = DoEvaluate(frame, &result); - if (!status.ok()) { - return status; - } - - frame->value_stack().Pop(entries_.size()); - frame->value_stack().Push(result); + CEL_ASSIGN_OR_RETURN(Value result, DoEvaluate(frame)); + frame->value_stack().PopAndPush(entries_.size(), std::move(result)); return absl::OkStatus(); } -absl::Status CreateStructStepForMap::DoEvaluate(ExecutionFrame* frame, - CelValue* result) const { - absl::Span args = - frame->value_stack().GetSpan(2 * entry_count_); - - if (frame->enable_unknowns()) { - const UnknownSet* unknown_set = frame->attribute_utility().MergeUnknowns( - args, frame->value_stack().GetAttributeSpan(args.size()), - /*initial_set=*/nullptr, true); - if (unknown_set != nullptr) { - *result = CelValue::CreateUnknownSet(unknown_set); - return absl::OkStatus(); - } - } - - std::vector> map_entries; - map_entries.reserve(entry_count_); - for (size_t i = 0; i < entry_count_; i += 1) { - map_entries.push_back({args[2 * i], args[2 * i + 1]}); - } - - auto cel_map = - CreateContainerBackedMap(absl::Span>( - map_entries.data(), map_entries.size())); +class DirectCreateStructStep : public DirectExpressionStep { + public: + DirectCreateStructStep( + int64_t expr_id, std::string name, std::vector field_keys, + std::vector> deps, + absl::flat_hash_set optional_indices) + : DirectExpressionStep(expr_id), + name_(std::move(name)), + field_keys_(std::move(field_keys)), + deps_(std::move(deps)), + optional_indices_(std::move(optional_indices)) {} + + absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, + AttributeTrail& trail) const override; - if (cel_map == nullptr) { - *result = CreateErrorValue(frame->arena(), "Failed to create map"); + private: + std::string name_; + std::vector field_keys_; + std::vector> deps_; + absl::flat_hash_set optional_indices_; +}; +absl::Status DirectCreateStructStep::Evaluate(ExecutionFrameBase& frame, + Value& result, + AttributeTrail& trail) const { + Value field_value; + AttributeTrail field_attr; + auto unknowns = frame.attribute_utility().CreateAccumulator(); + + CEL_ASSIGN_OR_RETURN(auto builder, + frame.type_provider().NewValueBuilder( + name_, frame.message_factory(), frame.arena())); + if (builder == nullptr) { + result = cel::ErrorValue( + absl::NotFoundError(absl::StrCat("Unable to find builder: ", name_))); return absl::OkStatus(); } - *result = CelValue::CreateMap(cel_map.get()); + for (int i = 0; i < field_keys_.size(); i++) { + CEL_RETURN_IF_ERROR(deps_[i]->Evaluate(frame, field_value, field_attr)); - // Pass object ownership to Arena. - frame->arena()->Own(cel_map.release()); + // TODO(uncreated-issue/67): if the value is an error, we should be able to return + // early, however some client tests depend on the error message the struct + // impl returns in the stack machine version. + if (field_value.IsError()) { + result = std::move(field_value); + return absl::OkStatus(); + } - return absl::OkStatus(); -} + if (frame.unknown_processing_enabled()) { + if (field_value.IsUnknown()) { + unknowns.Add(field_value.GetUnknown()); + } else if (frame.attribute_utility().CheckForUnknownPartial(field_attr)) { + unknowns.Add(field_attr); + } + } -absl::Status CreateStructStepForMap::Evaluate(ExecutionFrame* frame) const { - if (frame->value_stack().size() < 2 * entry_count_) { - return absl::Status(absl::StatusCode::kInternal, - "CreateStructStepForMap: stack undeflow"); - } + if (!unknowns.IsEmpty()) { + continue; + } - CelValue result; + if (optional_indices_.contains(static_cast(i))) { + if (auto optional_arg = field_value.AsOptional(); optional_arg) { + if (!optional_arg->HasValue()) { + continue; + } + Value optional_arg_value; + optional_arg->Value(&optional_arg_value); + if (optional_arg_value.IsError()) { + // Error should never be in optional, but better safe than sorry. + result = std::move(optional_arg_value); + return absl::OkStatus(); + } + CEL_ASSIGN_OR_RETURN( + absl::optional error_value, + builder->SetFieldByName(field_keys_[i], + std::move(optional_arg_value))); + if (error_value) { + result = std::move(*error_value); + return absl::OkStatus(); + } + continue; + } else { + result = cel::TypeConversionError(field_value.DebugString(), + "optional_type"); + return absl::OkStatus(); + } + } - absl::Status status = DoEvaluate(frame, &result); - if (!status.ok()) { - return status; + CEL_ASSIGN_OR_RETURN( + absl::optional error_value, + builder->SetFieldByName(field_keys_[i], std::move(field_value))); + if (error_value) { + result = std::move(*error_value); + return absl::OkStatus(); + } } - frame->value_stack().Pop(2 * entry_count_); - frame->value_stack().Push(result); + if (!unknowns.IsEmpty()) { + result = std::move(unknowns).Build(); + return absl::OkStatus(); + } + CEL_ASSIGN_OR_RETURN(result, std::move(*builder).Build()); return absl::OkStatus(); } } // namespace -absl::StatusOr> CreateCreateStructStep( - const google::api::expr::v1alpha1::Expr::CreateStruct* create_struct_expr, - int64_t expr_id) { - if (!create_struct_expr->message_name().empty()) { - // Make message-creating step. - std::vector entries; - - const Descriptor* desc = - DescriptorPool::generated_pool()->FindMessageTypeByName( - create_struct_expr->message_name()); - - if (desc == nullptr) { - return absl::InvalidArgumentError( - "Error configuring message creation: message descriptor not found"); - } - - for (const auto& entry : create_struct_expr->entries()) { - if (entry.field_key().empty()) { - return absl::InvalidArgumentError( - "Error configuring message creation: field name missing"); - } - - const FieldDescriptor* field_desc = - desc->FindFieldByName(entry.field_key()); - if (field_desc == nullptr) { - return absl::InvalidArgumentError( - "Error configuring message creation: field name not found"); - } - entries.push_back({field_desc}); - } - - return absl::WrapUnique( - new CreateStructStepForMessage(expr_id, desc, std::move(entries))); - } else { - // Make map-creating step. - return absl::WrapUnique(new CreateStructStepForMap( - expr_id, create_struct_expr->entries_size())); - } +std::unique_ptr CreateDirectCreateStructStep( + std::string resolved_name, std::vector field_keys, + std::vector> deps, + absl::flat_hash_set optional_indices, int64_t expr_id) { + return std::make_unique( + expr_id, std::move(resolved_name), std::move(field_keys), std::move(deps), + std::move(optional_indices)); } -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +std::unique_ptr CreateCreateStructStep( + std::string name, std::vector field_keys, + absl::flat_hash_set optional_indices, int64_t expr_id) { + // MakeOptionalIndicesSet(create_struct_expr) + return std::make_unique( + expr_id, std::move(name), std::move(field_keys), + std::move(optional_indices)); +} +} // namespace google::api::expr::runtime diff --git a/eval/eval/create_struct_step.h b/eval/eval/create_struct_step.h index 93ce3c9b3..eb80634f8 100644 --- a/eval/eval/create_struct_step.h +++ b/eval/eval/create_struct_step.h @@ -1,22 +1,44 @@ +// Copyright 2017 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_CREATE_STRUCT_STEP_H_ #define THIRD_PARTY_CEL_CPP_EVAL_EVAL_CREATE_STRUCT_STEP_H_ +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" -#include "eval/eval/expression_step_base.h" - -namespace google { -namespace api { -namespace expr { -namespace runtime { - -// Factory method for CreateList - based Execution step -absl::StatusOr> CreateCreateStructStep( - const google::api::expr::v1alpha1::Expr::CreateStruct* create_struct_expr, - int64_t expr_id); - -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google + +namespace google::api::expr::runtime { + +// Creates an `ExpressionStep` which performs `CreateStruct` for a +// message/struct. +std::unique_ptr CreateDirectCreateStructStep( + std::string name, std::vector field_keys, + std::vector> deps, + absl::flat_hash_set optional_indices, int64_t expr_id); + +// Creates an `ExpressionStep` which performs `CreateStruct` for a +// message/struct. +std::unique_ptr CreateCreateStructStep( + std::string name, std::vector field_keys, + absl::flat_hash_set optional_indices, int64_t expr_id); + +} // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_CREATE_STRUCT_STEP_H_ diff --git a/eval/eval/create_struct_step_test.cc b/eval/eval/create_struct_step_test.cc index d7eac46d8..cd9db9bd9 100644 --- a/eval/eval/create_struct_step_test.cc +++ b/eval/eval/create_struct_step_test.cc @@ -1,296 +1,375 @@ +// Copyright 2017 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #include "eval/eval/create_struct_step.h" -#include "google/api/expr/v1alpha1/syntax.pb.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" +#include +#include +#include +#include +#include +#include + +#include "cel/expr/syntax.pb.h" +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" #include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "base/type_provider.h" +#include "common/expr.h" +#include "eval/eval/cel_expression_flat_impl.h" +#include "eval/eval/direct_expression_step.h" +#include "eval/eval/evaluator_core.h" #include "eval/eval/ident_step.h" +#include "eval/public/activation.h" +#include "eval/public/cel_type_registry.h" +#include "eval/public/cel_value.h" #include "eval/public/containers/container_backed_list_impl.h" #include "eval/public/containers/container_backed_map_impl.h" #include "eval/public/structs/cel_proto_wrapper.h" +#include "eval/public/unknown_set.h" #include "eval/testutil/test_message.pb.h" -#include "testutil/util.h" -#include "base/status_macros.h" +#include "internal/proto_matchers.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "runtime/internal/runtime_env.h" +#include "runtime/internal/runtime_env_testing.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace google::api::expr::runtime { -namespace google { -namespace api { -namespace expr { -namespace runtime { namespace { +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::cel::Expr; +using ::cel::TypeProvider; +using ::cel::internal::test::EqualsProto; +using ::cel::runtime_internal::NewTestingRuntimeEnv; +using ::cel::runtime_internal::RuntimeEnv; using ::google::protobuf::Arena; using ::google::protobuf::Message; +using ::testing::Eq; +using ::testing::IsNull; +using ::testing::Not; +using ::testing::Pointwise; -using testing::Eq; -using testing::IsNull; -using testing::Not; -using testing::Pointwise; +absl::StatusOr MakeStackMachinePath(absl::string_view field) { + ExecutionPath path; -using testutil::EqualsProto; + CEL_ASSIGN_OR_RETURN(auto step0, CreateIdentStep("message", /*expr_id=*/-1)); -using google::api::expr::v1alpha1::Expr; + auto step1 = CreateCreateStructStep("google.api.expr.runtime.TestMessage", + {std::string(field)}, + /*optional_indices=*/{}, -// Helper method. Creates simple pipeline containing CreateStruct step that -// builds message and runs it. -absl::StatusOr RunExpression(absl::string_view field, - const CelValue& value, - google::protobuf::Arena* arena, - bool enable_unknowns) { + /*id=*/-1); + + path.push_back(std::move(step0)); + path.push_back(std::move(step1)); + + return path; +} + +absl::StatusOr MakeRecursivePath(absl::string_view field) { ExecutionPath path; - Expr expr0; - Expr expr1; + std::vector> deps; + deps.push_back(CreateDirectIdentStep("message", -1)); - auto ident = expr0.mutable_ident_expr(); - ident->set_name("message"); - auto step0_status = CreateIdentStep(ident, expr0.id()); + auto step1 = + CreateDirectCreateStructStep("google.api.expr.runtime.TestMessage", + {std::string(field)}, std::move(deps), + /*optional_indices=*/{}, - auto create_struct = expr1.mutable_struct_expr(); - create_struct->set_message_name("google.api.expr.runtime.TestMessage"); + /*id=*/-1); - auto entry = create_struct->add_entries(); - entry->set_field_key(field.data()); + path.push_back(std::make_unique(std::move(step1), -1)); - if (!step0_status.ok()) { - return step0_status.status(); - } + return path; +} - auto step1_status = CreateCreateStructStep(create_struct, expr1.id()); +// Helper method. Creates simple pipeline containing CreateStruct step that +// builds message and runs it. +absl::StatusOr RunExpression( + const absl_nonnull std::shared_ptr& env, + absl::string_view field, const CelValue& value, google::protobuf::Arena* arena, + bool enable_unknowns, bool enable_recursive_planning) { + google::protobuf::LinkMessageReflection(); + CEL_ASSIGN_OR_RETURN(auto maybe_type, + env->type_registry.GetComposedTypeProvider().FindType( + "google.api.expr.runtime.TestMessage")); + if (!maybe_type.has_value()) { + return absl::Status(absl::StatusCode::kFailedPrecondition, + "missing proto message type"); + } - if (!step1_status.ok()) { - return step1_status.status(); + cel::RuntimeOptions options; + if (enable_unknowns) { + options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; } + ExecutionPath path; - path.push_back(std::move(step0_status.value())); - path.push_back(std::move(step1_status.value())); + if (enable_recursive_planning) { + CEL_ASSIGN_OR_RETURN(path, MakeRecursivePath(field)); + } else { + CEL_ASSIGN_OR_RETURN(path, MakeStackMachinePath(field)); + } - CelExpressionFlatImpl cel_expr(&expr1, std::move(path), 0, {}, - enable_unknowns); + CelExpressionFlatImpl cel_expr( + env, + FlatExpression(std::move(path), /*comprehension_slot_count=*/0, + env->type_registry.GetComposedTypeProvider(), options)); Activation activation; activation.InsertValue("message", value); return cel_expr.Evaluate(activation, arena); } -void RunExpressionAndGetMessage(absl::string_view field, const CelValue& value, - google::protobuf::Arena* arena, TestMessage* test_msg, - bool enable_unknowns) { - auto result_status = RunExpression(field, value, arena, enable_unknowns); - ASSERT_OK(result_status); - auto result = result_status.value(); - - ASSERT_TRUE(result.IsMessage()); +void RunExpressionAndGetMessage( + const absl_nonnull std::shared_ptr& env, + absl::string_view field, const CelValue& value, google::protobuf::Arena* arena, + TestMessage* test_msg, bool enable_unknowns, + bool enable_recursive_planning) { + ASSERT_OK_AND_ASSIGN(auto result, + RunExpression(env, field, value, arena, enable_unknowns, + enable_recursive_planning)); + ASSERT_TRUE(result.IsMessage()) << result.DebugString(); const Message* msg = result.MessageOrDie(); ASSERT_THAT(msg, Not(IsNull())); - ASSERT_EQ(msg->GetDescriptor(), TestMessage::descriptor()); - test_msg->MergeFrom(*msg); + ASSERT_EQ(msg->GetDescriptor()->full_name(), + "google.api.expr.runtime.TestMessage"); + test_msg->MergePartialFromString(msg->SerializePartialAsCord()); } -void RunExpressionAndGetMessage(absl::string_view field, - std::vector values, - google::protobuf::Arena* arena, TestMessage* test_msg, - bool enable_unknowns) { +void RunExpressionAndGetMessage( + const absl_nonnull std::shared_ptr& env, + absl::string_view field, std::vector values, google::protobuf::Arena* arena, + TestMessage* test_msg, bool enable_unknowns, + bool enable_recursive_planning) { ContainerBackedListImpl cel_list(std::move(values)); CelValue value = CelValue::CreateList(&cel_list); - auto result_status = RunExpression(field, value, arena, enable_unknowns); - ASSERT_OK(result_status); - auto result = result_status.value(); - - ASSERT_TRUE(result.IsMessage()); + ASSERT_OK_AND_ASSIGN(auto result, + RunExpression(env, field, value, arena, enable_unknowns, + enable_recursive_planning)); + ASSERT_TRUE(result.IsMessage()) << result.DebugString(); const Message* msg = result.MessageOrDie(); ASSERT_THAT(msg, Not(IsNull())); - ASSERT_EQ(msg->GetDescriptor(), TestMessage::descriptor()); - test_msg->MergeFrom(*msg); + ASSERT_EQ(msg->GetDescriptor()->full_name(), + "google.api.expr.runtime.TestMessage"); + test_msg->MergePartialFromString(msg->SerializePartialAsCord()); } -// Helper method. Creates simple pipeline containing CreateStruct step that -// builds Map and runs it. -absl::StatusOr RunCreateMapExpression( - const std::vector> values, - google::protobuf::Arena* arena, bool enable_unknowns) { - ExecutionPath path; - Activation activation; - - Expr expr0; - Expr expr1; - - std::vector exprs; - int index = 0; - - auto create_struct = expr1.mutable_struct_expr(); - for (const auto& item : values) { - Expr expr; - std::string key_name = absl::StrCat("key", index); - std::string value_name = absl::StrCat("value", index); - - auto key_ident = expr.mutable_ident_expr(); - key_ident->set_name(key_name); - exprs.push_back(expr); - auto step_key_status = CreateIdentStep(key_ident, exprs.back().id()); - if (!step_key_status.ok()) { - return step_key_status.status(); - } - - expr.Clear(); - auto value_ident = expr.mutable_ident_expr(); - value_ident->set_name(value_name); - exprs.push_back(expr); - auto step_value_status = CreateIdentStep(value_ident, exprs.back().id()); - if (!step_value_status.ok()) { - return step_value_status.status(); - } - - path.push_back(std::move(step_key_status.value())); - path.push_back(std::move(step_value_status.value())); - - activation.InsertValue(key_name, item.first); - activation.InsertValue(value_name, item.second); - - create_struct->add_entries(); - index++; - } - - auto step1_status = CreateCreateStructStep(create_struct, expr1.id()); - - if (!step1_status.ok()) { - return step1_status.status(); - } - - path.push_back(std::move(step1_status.value())); +class CreateCreateStructStepTest + : public testing::TestWithParam> { + public: + CreateCreateStructStepTest() : env_(NewTestingRuntimeEnv()) {} - CelExpressionFlatImpl cel_expr(&expr1, std::move(path), 0, {}, - enable_unknowns); - return cel_expr.Evaluate(activation, arena); -} + bool enable_unknowns() { return std::get<0>(GetParam()); } + bool enable_recursive_planning() { return std::get<1>(GetParam()); } -class CreateCreateStructStepTest : public testing::TestWithParam {}; + protected: + absl_nonnull std::shared_ptr env_; + google::protobuf::Arena arena_; +}; TEST_P(CreateCreateStructStepTest, TestEmptyMessageCreation) { ExecutionPath path; - Expr expr1; - - auto create_struct = expr1.mutable_struct_expr(); - create_struct->set_message_name("google.api.expr.runtime.TestMessage"); - - auto step_status = CreateCreateStructStep(create_struct, expr1.id()); - - ASSERT_OK(step_status); - - path.push_back(std::move(step_status.value())); + auto adapter = env_->legacy_type_registry.FindTypeAdapter( + "google.api.expr.runtime.TestMessage"); + ASSERT_TRUE(adapter.has_value() && adapter->mutation_apis() != nullptr); + + ASSERT_OK_AND_ASSIGN(auto maybe_type, + env_->type_registry.GetComposedTypeProvider().FindType( + "google.api.expr.runtime.TestMessage")); + ASSERT_TRUE(maybe_type.has_value()); + if (enable_recursive_planning()) { + auto step = + CreateDirectCreateStructStep("google.api.expr.runtime.TestMessage", + /*fields=*/{}, + /*deps=*/{}, + /*optional_indices=*/{}, + /*id=*/-1); + path.push_back( + std::make_unique(std::move(step), /*id=*/-1)); + } else { + auto step = CreateCreateStructStep("google.api.expr.runtime.TestMessage", + /*fields=*/{}, + /*optional_indices=*/{}, + /*id=*/-1); + path.push_back(std::move(step)); + } - CelExpressionFlatImpl cel_expr(&expr1, std::move(path), 0, {}, GetParam()); + cel::RuntimeOptions options; + if (enable_unknowns(), enable_recursive_planning()) { + options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; + } + CelExpressionFlatImpl cel_expr( + env_, + FlatExpression(std::move(path), /*comprehension_slot_count=*/0, + env_->type_registry.GetComposedTypeProvider(), options)); Activation activation; - google::protobuf::Arena arena; + ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr.Evaluate(activation, &arena_)); + ASSERT_TRUE(result.IsMessage()) << result.DebugString(); + const Message* msg = result.MessageOrDie(); + ASSERT_THAT(msg, Not(IsNull())); - auto status = cel_expr.Evaluate(activation, &arena); - ASSERT_OK(status); + ASSERT_EQ(msg->GetDescriptor()->full_name(), + "google.api.expr.runtime.TestMessage"); +} - CelValue result = status.value(); - ASSERT_TRUE(result.IsMessage()); +TEST(CreateCreateStructStepTest, TestMessageCreateError) { + absl_nonnull std::shared_ptr env = NewTestingRuntimeEnv(); + Arena arena; + TestMessage test_msg; + absl::Status error = absl::CancelledError(); - const Message* msg = result.MessageOrDie(); - ASSERT_THAT(msg, Not(IsNull())); + auto eval_status = + RunExpression(env, "bool_value", CelValue::CreateError(&error), &arena, + true, /*enable_recursive_planning=*/false); + ASSERT_THAT(eval_status, IsOk()); + EXPECT_THAT(*eval_status->ErrorOrDie(), + StatusIs(absl::StatusCode::kCancelled)); +} + +TEST(CreateCreateStructStepTest, TestMessageCreateErrorRecursive) { + absl_nonnull std::shared_ptr env = NewTestingRuntimeEnv(); + Arena arena; + TestMessage test_msg; + absl::Status error = absl::CancelledError(); - ASSERT_EQ(msg->GetDescriptor(), TestMessage::descriptor()); + auto eval_status = + RunExpression(env, "bool_value", CelValue::CreateError(&error), &arena, + true, /*enable_recursive_planning=*/true); + ASSERT_THAT(eval_status, IsOk()); + EXPECT_THAT(*eval_status->ErrorOrDie(), + StatusIs(absl::StatusCode::kCancelled)); } // Test message creation if unknown argument is passed TEST(CreateCreateStructStepTest, TestMessageCreateWithUnknown) { + absl_nonnull std::shared_ptr env = NewTestingRuntimeEnv(); Arena arena; TestMessage test_msg; UnknownSet unknown_set; - auto eval_status = RunExpression( - "bool_value", CelValue::CreateUnknownSet(&unknown_set), &arena, true); + auto eval_status = + RunExpression(env, "bool_value", CelValue::CreateUnknownSet(&unknown_set), + &arena, true, /*enable_recursive_planning=*/false); ASSERT_OK(eval_status); ASSERT_TRUE(eval_status->IsUnknownSet()); } +// Test message creation if unknown argument is passed +TEST(CreateCreateStructStepTest, TestMessageCreateWithUnknownRecursive) { + absl_nonnull std::shared_ptr env = NewTestingRuntimeEnv(); + Arena arena; + TestMessage test_msg; + UnknownSet unknown_set; + + auto eval_status = + RunExpression(env, "bool_value", CelValue::CreateUnknownSet(&unknown_set), + &arena, true, /*enable_recursive_planning=*/true); + ASSERT_OK(eval_status); + ASSERT_TRUE(eval_status->IsUnknownSet()) << eval_status->DebugString(); +} + // Test that fields of type bool are set correctly TEST_P(CreateCreateStructStepTest, TestSetBoolField) { - Arena arena; TestMessage test_msg; ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "bool_value", CelValue::CreateBool(true), &arena, &test_msg, GetParam())); + env_, "bool_value", CelValue::CreateBool(true), &arena_, &test_msg, + enable_unknowns(), enable_recursive_planning())); ASSERT_EQ(test_msg.bool_value(), true); } -// Test that fields of type int32_t are set correctly +// Test that fields of type int32 are set correctly TEST_P(CreateCreateStructStepTest, TestSetInt32Field) { - Arena arena; TestMessage test_msg; ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "int32_value", CelValue::CreateInt64(1), &arena, &test_msg, GetParam())); + env_, "int32_value", CelValue::CreateInt64(1), &arena_, &test_msg, + enable_unknowns(), enable_recursive_planning())); ASSERT_EQ(test_msg.int32_value(), 1); } -// Test that fields of type uint32_t are set correctly. +// Test that fields of type uint32 are set correctly. TEST_P(CreateCreateStructStepTest, TestSetUInt32Field) { - Arena arena; TestMessage test_msg; - ASSERT_NO_FATAL_FAILURE( - RunExpressionAndGetMessage("uint32_value", CelValue::CreateUint64(1), - &arena, &test_msg, GetParam())); + ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( + env_, "uint32_value", CelValue::CreateUint64(1), &arena_, &test_msg, + enable_unknowns(), enable_recursive_planning())); ASSERT_EQ(test_msg.uint32_value(), 1); } -// Test that fields of type int64_t are set correctly. +// Test that fields of type int64 are set correctly. TEST_P(CreateCreateStructStepTest, TestSetInt64Field) { - Arena arena; TestMessage test_msg; ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "int64_value", CelValue::CreateInt64(1), &arena, &test_msg, GetParam())); + env_, "int64_value", CelValue::CreateInt64(1), &arena_, &test_msg, + enable_unknowns(), enable_recursive_planning())); EXPECT_EQ(test_msg.int64_value(), 1); } -// Test that fields of type uint64_t are set correctly. +// Test that fields of type uint64 are set correctly. TEST_P(CreateCreateStructStepTest, TestSetUInt64Field) { - Arena arena; TestMessage test_msg; - ASSERT_NO_FATAL_FAILURE( - RunExpressionAndGetMessage("uint64_value", CelValue::CreateUint64(1), - &arena, &test_msg, GetParam())); + ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( + env_, "uint64_value", CelValue::CreateUint64(1), &arena_, &test_msg, + enable_unknowns(), enable_recursive_planning())); EXPECT_EQ(test_msg.uint64_value(), 1); } // Test that fields of type float are set correctly TEST_P(CreateCreateStructStepTest, TestSetFloatField) { - Arena arena; TestMessage test_msg; - ASSERT_NO_FATAL_FAILURE( - RunExpressionAndGetMessage("float_value", CelValue::CreateDouble(2.0), - &arena, &test_msg, GetParam())); + ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( + env_, "float_value", CelValue::CreateDouble(2.0), &arena_, &test_msg, + enable_unknowns(), enable_recursive_planning())); EXPECT_DOUBLE_EQ(test_msg.float_value(), 2.0); } // Test that fields of type double are set correctly TEST_P(CreateCreateStructStepTest, TestSetDoubleField) { - Arena arena; TestMessage test_msg; - ASSERT_NO_FATAL_FAILURE( - RunExpressionAndGetMessage("double_value", CelValue::CreateDouble(2.0), - &arena, &test_msg, GetParam())); + ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( + env_, "double_value", CelValue::CreateDouble(2.0), &arena_, &test_msg, + enable_unknowns(), enable_recursive_planning())); EXPECT_DOUBLE_EQ(test_msg.double_value(), 2.0); } @@ -298,63 +377,54 @@ TEST_P(CreateCreateStructStepTest, TestSetDoubleField) { TEST_P(CreateCreateStructStepTest, TestSetStringField) { const std::string kTestStr = "test"; - Arena arena; TestMessage test_msg; ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "string_value", CelValue::CreateString(&kTestStr), &arena, &test_msg, - GetParam())); + env_, "string_value", CelValue::CreateString(&kTestStr), &arena_, + &test_msg, enable_unknowns(), enable_recursive_planning())); EXPECT_EQ(test_msg.string_value(), kTestStr); } - // Test that fields of type bytes are set correctly. TEST_P(CreateCreateStructStepTest, TestSetBytesField) { - Arena arena; - const std::string kTestStr = "test"; TestMessage test_msg; ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "bytes_value", CelValue::CreateBytes(&kTestStr), &arena, &test_msg, - GetParam())); + env_, "bytes_value", CelValue::CreateBytes(&kTestStr), &arena_, &test_msg, + enable_unknowns(), enable_recursive_planning())); EXPECT_EQ(test_msg.bytes_value(), kTestStr); } // Test that fields of type duration are set correctly. TEST_P(CreateCreateStructStepTest, TestSetDurationField) { - Arena arena; - google::protobuf::Duration test_duration; test_duration.set_seconds(2); test_duration.set_nanos(3); TestMessage test_msg; ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "duration_value", CelProtoWrapper::CreateDuration(&test_duration), &arena, - &test_msg, GetParam())); + env_, "duration_value", CelProtoWrapper::CreateDuration(&test_duration), + &arena_, &test_msg, enable_unknowns(), enable_recursive_planning())); EXPECT_THAT(test_msg.duration_value(), EqualsProto(test_duration)); } // Test that fields of type timestamp are set correctly. TEST_P(CreateCreateStructStepTest, TestSetTimestampField) { - Arena arena; - google::protobuf::Timestamp test_timestamp; test_timestamp.set_seconds(2); test_timestamp.set_nanos(3); TestMessage test_msg; ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "timestamp_value", CelProtoWrapper::CreateTimestamp(&test_timestamp), - &arena, &test_msg, GetParam())); + env_, "timestamp_value", + CelProtoWrapper::CreateTimestamp(&test_timestamp), &arena_, &test_msg, + enable_unknowns(), enable_recursive_planning())); EXPECT_THAT(test_msg.timestamp_value(), EqualsProto(test_timestamp)); } // Test that fields of type Message are set correctly. TEST_P(CreateCreateStructStepTest, TestSetMessageField) { - Arena arena; - // Create payload message and set some fields. TestMessage orig_msg; orig_msg.set_bool_value(true); @@ -363,15 +433,13 @@ TEST_P(CreateCreateStructStepTest, TestSetMessageField) { TestMessage test_msg; ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "message_value", CelProtoWrapper::CreateMessage(&orig_msg, &arena), - &arena, &test_msg, GetParam())); + env_, "message_value", CelProtoWrapper::CreateMessage(&orig_msg, &arena_), + &arena_, &test_msg, enable_unknowns(), enable_recursive_planning())); EXPECT_THAT(test_msg.message_value(), EqualsProto(orig_msg)); } // Test that fields of type Any are set correctly. TEST_P(CreateCreateStructStepTest, TestSetAnyField) { - Arena arena; - // Create payload message and set some fields. TestMessage orig_embedded_msg; orig_embedded_msg.set_bool_value(true); @@ -383,8 +451,9 @@ TEST_P(CreateCreateStructStepTest, TestSetAnyField) { TestMessage test_msg; ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "any_value", CelProtoWrapper::CreateMessage(&orig_embedded_msg, &arena), - &arena, &test_msg, GetParam())); + env_, "any_value", + CelProtoWrapper::CreateMessage(&orig_embedded_msg, &arena_), &arena_, + &test_msg, enable_unknowns(), enable_recursive_planning())); EXPECT_THAT(test_msg, EqualsProto(orig_msg)); TestMessage test_embedded_msg; @@ -394,18 +463,16 @@ TEST_P(CreateCreateStructStepTest, TestSetAnyField) { // Test that fields of type Message are set correctly. TEST_P(CreateCreateStructStepTest, TestSetEnumField) { - Arena arena; TestMessage test_msg; ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "enum_value", CelValue::CreateInt64(TestMessage::TEST_ENUM_2), &arena, - &test_msg, GetParam())); + env_, "enum_value", CelValue::CreateInt64(TestMessage::TEST_ENUM_2), + &arena_, &test_msg, enable_unknowns(), enable_recursive_planning())); EXPECT_EQ(test_msg.enum_value(), TestMessage::TEST_ENUM_2); } // Test that fields of type bool are set correctly TEST_P(CreateCreateStructStepTest, TestSetRepeatedBoolField) { - Arena arena; TestMessage test_msg; std::vector kValues = {true, false}; @@ -415,13 +482,13 @@ TEST_P(CreateCreateStructStepTest, TestSetRepeatedBoolField) { } ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "bool_list", values, &arena, &test_msg, GetParam())); + env_, "bool_list", values, &arena_, &test_msg, enable_unknowns(), + enable_recursive_planning())); ASSERT_THAT(test_msg.bool_list(), Pointwise(Eq(), kValues)); } -// Test that repeated fields of type int32_t are set correctly +// Test that repeated fields of type int32 are set correctly TEST_P(CreateCreateStructStepTest, TestSetRepeatedInt32Field) { - Arena arena; TestMessage test_msg; std::vector kValues = {23, 12}; @@ -431,13 +498,13 @@ TEST_P(CreateCreateStructStepTest, TestSetRepeatedInt32Field) { } ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "int32_list", values, &arena, &test_msg, GetParam())); + env_, "int32_list", values, &arena_, &test_msg, enable_unknowns(), + enable_recursive_planning())); ASSERT_THAT(test_msg.int32_list(), Pointwise(Eq(), kValues)); } -// Test that repeated fields of type uint32_t are set correctly +// Test that repeated fields of type uint32 are set correctly TEST_P(CreateCreateStructStepTest, TestSetRepeatedUInt32Field) { - Arena arena; TestMessage test_msg; std::vector kValues = {23, 12}; @@ -447,13 +514,13 @@ TEST_P(CreateCreateStructStepTest, TestSetRepeatedUInt32Field) { } ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "uint32_list", values, &arena, &test_msg, GetParam())); + env_, "uint32_list", values, &arena_, &test_msg, enable_unknowns(), + enable_recursive_planning())); ASSERT_THAT(test_msg.uint32_list(), Pointwise(Eq(), kValues)); } -// Test that repeated fields of type int64_t are set correctly +// Test that repeated fields of type int64 are set correctly TEST_P(CreateCreateStructStepTest, TestSetRepeatedInt64Field) { - Arena arena; TestMessage test_msg; std::vector kValues = {23, 12}; @@ -463,13 +530,13 @@ TEST_P(CreateCreateStructStepTest, TestSetRepeatedInt64Field) { } ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "int64_list", values, &arena, &test_msg, GetParam())); + env_, "int64_list", values, &arena_, &test_msg, enable_unknowns(), + enable_recursive_planning())); ASSERT_THAT(test_msg.int64_list(), Pointwise(Eq(), kValues)); } -// Test that repeated fields of type uint64_t are set correctly +// Test that repeated fields of type uint64 are set correctly TEST_P(CreateCreateStructStepTest, TestSetRepeatedUInt64Field) { - Arena arena; TestMessage test_msg; std::vector kValues = {23, 12}; @@ -479,13 +546,13 @@ TEST_P(CreateCreateStructStepTest, TestSetRepeatedUInt64Field) { } ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "uint64_list", values, &arena, &test_msg, GetParam())); + env_, "uint64_list", values, &arena_, &test_msg, enable_unknowns(), + enable_recursive_planning())); ASSERT_THAT(test_msg.uint64_list(), Pointwise(Eq(), kValues)); } // Test that repeated fields of type float are set correctly TEST_P(CreateCreateStructStepTest, TestSetRepeatedFloatField) { - Arena arena; TestMessage test_msg; std::vector kValues = {23, 12}; @@ -495,13 +562,13 @@ TEST_P(CreateCreateStructStepTest, TestSetRepeatedFloatField) { } ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "float_list", values, &arena, &test_msg, GetParam())); + env_, "float_list", values, &arena_, &test_msg, enable_unknowns(), + enable_recursive_planning())); ASSERT_THAT(test_msg.float_list(), Pointwise(Eq(), kValues)); } -// Test that repeated fields of type uint32_t are set correctly +// Test that repeated fields of type uint32 are set correctly TEST_P(CreateCreateStructStepTest, TestSetRepeatedDoubleField) { - Arena arena; TestMessage test_msg; std::vector kValues = {23, 12}; @@ -511,13 +578,13 @@ TEST_P(CreateCreateStructStepTest, TestSetRepeatedDoubleField) { } ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "double_list", values, &arena, &test_msg, GetParam())); + env_, "double_list", values, &arena_, &test_msg, enable_unknowns(), + enable_recursive_planning())); ASSERT_THAT(test_msg.double_list(), Pointwise(Eq(), kValues)); } // Test that repeated fields of type String are set correctly TEST_P(CreateCreateStructStepTest, TestSetRepeatedStringField) { - Arena arena; TestMessage test_msg; std::vector kValues = {"test1", "test2"}; @@ -527,13 +594,13 @@ TEST_P(CreateCreateStructStepTest, TestSetRepeatedStringField) { } ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "string_list", values, &arena, &test_msg, GetParam())); + env_, "string_list", values, &arena_, &test_msg, enable_unknowns(), + enable_recursive_planning())); ASSERT_THAT(test_msg.string_list(), Pointwise(Eq(), kValues)); } // Test that repeated fields of type String are set correctly TEST_P(CreateCreateStructStepTest, TestSetRepeatedBytesField) { - Arena arena; TestMessage test_msg; std::vector kValues = {"test1", "test2"}; @@ -543,14 +610,13 @@ TEST_P(CreateCreateStructStepTest, TestSetRepeatedBytesField) { } ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "bytes_list", values, &arena, &test_msg, GetParam())); + env_, "bytes_list", values, &arena_, &test_msg, enable_unknowns(), + enable_recursive_planning())); ASSERT_THAT(test_msg.bytes_list(), Pointwise(Eq(), kValues)); } - // Test that repeated fields of type Message are set correctly TEST_P(CreateCreateStructStepTest, TestSetRepeatedMessageField) { - Arena arena; TestMessage test_msg; std::vector kValues(2); @@ -558,19 +624,18 @@ TEST_P(CreateCreateStructStepTest, TestSetRepeatedMessageField) { kValues[1].set_string_value("test2"); std::vector values; for (const auto& value : kValues) { - values.push_back(CelProtoWrapper::CreateMessage(&value, &arena)); + values.push_back(CelProtoWrapper::CreateMessage(&value, &arena_)); } ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "message_list", values, &arena, &test_msg, GetParam())); + env_, "message_list", values, &arena_, &test_msg, enable_unknowns(), + enable_recursive_planning())); ASSERT_THAT(test_msg.message_list()[0], EqualsProto(kValues[0])); ASSERT_THAT(test_msg.message_list()[1], EqualsProto(kValues[1])); } - // Test that fields of type map are set correctly TEST_P(CreateCreateStructStepTest, TestSetStringMapField) { - Arena arena; TestMessage test_msg; std::vector> entries; @@ -583,21 +648,20 @@ TEST_P(CreateCreateStructStepTest, TestSetStringMapField) { {CelValue::CreateString(&kKeys[1]), CelValue::CreateInt64(1)}); auto cel_map = - CreateContainerBackedMap(absl::Span>( + *CreateContainerBackedMap(absl::Span>( entries.data(), entries.size())); ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "string_int32_map", CelValue::CreateMap(cel_map.get()), &arena, &test_msg, - GetParam())); + env_, "string_int32_map", CelValue::CreateMap(cel_map.get()), &arena_, + &test_msg, enable_unknowns(), enable_recursive_planning())); ASSERT_EQ(test_msg.string_int32_map().size(), 2); ASSERT_EQ(test_msg.string_int32_map().at(kKeys[0]), 2); ASSERT_EQ(test_msg.string_int32_map().at(kKeys[1]), 1); } -// Test that fields of type map are set correctly +// Test that fields of type map are set correctly TEST_P(CreateCreateStructStepTest, TestSetInt64MapField) { - Arena arena; TestMessage test_msg; std::vector> entries; @@ -610,21 +674,20 @@ TEST_P(CreateCreateStructStepTest, TestSetInt64MapField) { {CelValue::CreateInt64(kKeys[1]), CelValue::CreateInt64(2)}); auto cel_map = - CreateContainerBackedMap(absl::Span>( + *CreateContainerBackedMap(absl::Span>( entries.data(), entries.size())); ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "int64_int32_map", CelValue::CreateMap(cel_map.get()), &arena, &test_msg, - GetParam())); + env_, "int64_int32_map", CelValue::CreateMap(cel_map.get()), &arena_, + &test_msg, enable_unknowns(), enable_recursive_planning())); ASSERT_EQ(test_msg.int64_int32_map().size(), 2); ASSERT_EQ(test_msg.int64_int32_map().at(kKeys[0]), 1); ASSERT_EQ(test_msg.int64_int32_map().at(kKeys[1]), 2); } -// Test that fields of type map are set correctly +// Test that fields of type map are set correctly TEST_P(CreateCreateStructStepTest, TestSetUInt64MapField) { - Arena arena; TestMessage test_msg; std::vector> entries; @@ -637,93 +700,21 @@ TEST_P(CreateCreateStructStepTest, TestSetUInt64MapField) { {CelValue::CreateUint64(kKeys[1]), CelValue::CreateInt64(2)}); auto cel_map = - CreateContainerBackedMap(absl::Span>( + *CreateContainerBackedMap(absl::Span>( entries.data(), entries.size())); ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "uint64_int32_map", CelValue::CreateMap(cel_map.get()), &arena, &test_msg, - GetParam())); + env_, "uint64_int32_map", CelValue::CreateMap(cel_map.get()), &arena_, + &test_msg, enable_unknowns(), enable_recursive_planning())); ASSERT_EQ(test_msg.uint64_int32_map().size(), 2); ASSERT_EQ(test_msg.uint64_int32_map().at(kKeys[0]), 1); ASSERT_EQ(test_msg.uint64_int32_map().at(kKeys[1]), 2); } -// Test that Empty Map is created successfully. -TEST_P(CreateCreateStructStepTest, TestCreateEmptyMap) { - Arena arena; - auto status = RunCreateMapExpression({}, &arena, GetParam()); - - ASSERT_OK(status); - - CelValue result_value = status.value(); - ASSERT_TRUE(result_value.IsMap()); - - const CelMap* cel_map = result_value.MapOrDie(); - ASSERT_EQ(cel_map->size(), 0); -} - -// Test message creation if unknown argument is passed -TEST(CreateCreateStructStepTest, TestMapCreateWithUnknown) { - Arena arena; - UnknownSet unknown_set; - std::vector> entries; - - std::vector kKeys = {"test2", "test1"}; - - entries.push_back( - {CelValue::CreateString(&kKeys[0]), CelValue::CreateInt64(2)}); - entries.push_back({CelValue::CreateString(&kKeys[1]), - CelValue::CreateUnknownSet(&unknown_set)}); - - auto status = RunCreateMapExpression(entries, &arena, true); - - ASSERT_OK(status); - - CelValue result_value = status.value(); - ASSERT_TRUE(result_value.IsUnknownSet()); -} - -// Test that String Map is created successfully. -TEST_P(CreateCreateStructStepTest, TestCreateStringMap) { - Arena arena; - - std::vector> entries; - - std::vector kKeys = {"test2", "test1"}; - - entries.push_back( - {CelValue::CreateString(&kKeys[0]), CelValue::CreateInt64(2)}); - entries.push_back( - {CelValue::CreateString(&kKeys[1]), CelValue::CreateInt64(1)}); - - auto status = RunCreateMapExpression(entries, &arena, GetParam()); - - ASSERT_OK(status); - - CelValue result_value = status.value(); - ASSERT_TRUE(result_value.IsMap()); - - const CelMap* cel_map = result_value.MapOrDie(); - ASSERT_EQ(cel_map->size(), 2); - - auto lookup0 = (*cel_map)[CelValue::CreateString(&kKeys[0])]; - ASSERT_TRUE(lookup0.has_value()); - ASSERT_TRUE(lookup0.value().IsInt64()); - EXPECT_EQ(lookup0.value().Int64OrDie(), 2); - - auto lookup1 = (*cel_map)[CelValue::CreateString(&kKeys[1])]; - ASSERT_TRUE(lookup1.has_value()); - ASSERT_TRUE(lookup1.value().IsInt64()); - EXPECT_EQ(lookup1.value().Int64OrDie(), 1); -} - INSTANTIATE_TEST_SUITE_P(CombinedCreateStructTest, CreateCreateStructStepTest, - testing::Bool()); + testing::Combine(testing::Bool(), testing::Bool())); } // namespace -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime diff --git a/eval/eval/direct_expression_step.cc b/eval/eval/direct_expression_step.cc new file mode 100644 index 000000000..2d7fc6fc0 --- /dev/null +++ b/eval/eval/direct_expression_step.cc @@ -0,0 +1,34 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "eval/eval/direct_expression_step.h" + +#include + +#include "absl/status/status.h" +#include "common/value.h" +#include "eval/eval/attribute_trail.h" +#include "eval/eval/evaluator_core.h" +#include "internal/status_macros.h" + +namespace google::api::expr::runtime { + +absl::Status WrappedDirectStep::Evaluate(ExecutionFrame* frame) const { + cel::Value result; + AttributeTrail attribute_trail; + CEL_RETURN_IF_ERROR(impl_->Evaluate(*frame, result, attribute_trail)); + frame->value_stack().Push(std::move(result), std::move(attribute_trail)); + return absl::OkStatus(); +} + +} // namespace google::api::expr::runtime diff --git a/eval/eval/direct_expression_step.h b/eval/eval/direct_expression_step.h new file mode 100644 index 000000000..f11479065 --- /dev/null +++ b/eval/eval/direct_expression_step.h @@ -0,0 +1,99 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_DIRECT_EXPRESSION_STEP_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_EVAL_DIRECT_EXPRESSION_STEP_H_ + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/types/optional.h" +#include "common/native_type.h" +#include "common/value.h" +#include "eval/eval/attribute_trail.h" +#include "eval/eval/evaluator_core.h" + +namespace google::api::expr::runtime { + +// Represents a directly evaluated CEL expression. +// +// Subexpressions assign to values on the C++ program stack and call their +// dependencies directly. +// +// This reduces the setup overhead for evaluation and minimizes value churn +// to / from a heap based value stack managed by the CEL runtime, but can't be +// used for arbitrarily nested expressions. +class DirectExpressionStep { + public: + explicit DirectExpressionStep(int64_t expr_id) : expr_id_(expr_id) {} + DirectExpressionStep() : expr_id_(-1) {} + + virtual ~DirectExpressionStep() = default; + + int64_t expr_id() const { return expr_id_; } + bool comes_from_ast() const { return expr_id_ >= 0; } + + virtual absl::Status Evaluate(ExecutionFrameBase& frame, cel::Value& result, + AttributeTrail& attribute) const = 0; + + // Return a type id for this node. + // + // Users must not make any assumptions about the type if the default value is + // returned. + virtual cel::NativeTypeId GetNativeTypeId() const { + return cel::NativeTypeId(); + } + + // Implementations optionally support inspecting the program tree. + virtual absl::optional> + GetDependencies() const { + return absl::nullopt; + } + + // Implementations optionally support extracting the program tree. + // + // Extract prevents the callee from functioning, and is only intended for use + // when replacing a given expression step. + virtual absl::optional>> + ExtractDependencies() { + return absl::nullopt; + }; + + protected: + int64_t expr_id_; +}; + +// Wrapper for direct steps to work with the stack machine impl. +class WrappedDirectStep : public ExpressionStep { + public: + WrappedDirectStep(std::unique_ptr impl, int64_t expr_id) + : ExpressionStep(expr_id, false), impl_(std::move(impl)) {} + + absl::Status Evaluate(ExecutionFrame* frame) const override; + + cel::NativeTypeId GetNativeTypeId() const override { + return cel::NativeTypeId::For(); + } + + const DirectExpressionStep* wrapped() const { return impl_.get(); } + + private: + std::unique_ptr impl_; +}; + +} // namespace google::api::expr::runtime + +#endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_DIRECT_EXPRESSION_STEP_H_ diff --git a/eval/eval/equality_steps.cc b/eval/eval/equality_steps.cc new file mode 100644 index 000000000..d720302e4 --- /dev/null +++ b/eval/eval/equality_steps.cc @@ -0,0 +1,293 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "eval/eval/equality_steps.h" + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "base/builtins.h" +#include "common/value.h" +#include "common/value_kind.h" +#include "eval/eval/attribute_trail.h" +#include "eval/eval/direct_expression_step.h" +#include "eval/eval/evaluator_core.h" +#include "eval/eval/expression_step_base.h" +#include "internal/number.h" +#include "internal/status_macros.h" +#include "runtime/internal/errors.h" +#include "runtime/standard/equality_functions.h" + +namespace google::api::expr::runtime { + +namespace { + +using ::cel::BoolValue; +using ::cel::IntValue; +using ::cel::MapValue; +using ::cel::UintValue; +using ::cel::Value; + +using ::cel::ValueKind; +using ::cel::internal::Number; +using ::cel::runtime_internal::ValueEqualImpl; + +absl::StatusOr EvaluateEquality( + ExecutionFrameBase& frame, const Value& lhs, const AttributeTrail& lhs_attr, + const Value& rhs, const AttributeTrail& rhs_attr, bool negation) { + if (lhs.IsError()) { + return lhs; + } + + if (rhs.IsError()) { + return rhs; + } + + if (frame.unknown_processing_enabled()) { + auto accu = frame.attribute_utility().CreateAccumulator(); + accu.MaybeAdd(lhs, lhs_attr); + accu.MaybeAdd(rhs, rhs_attr); + if (!accu.IsEmpty()) { + return std::move(accu).Build(); + } + } + + CEL_ASSIGN_OR_RETURN(auto is_equal, + ValueEqualImpl(lhs, rhs, frame.descriptor_pool(), + frame.message_factory(), frame.arena())); + if (!is_equal.has_value()) { + return cel::ErrorValue(cel::runtime_internal::CreateNoMatchingOverloadError( + negation ? cel::builtin::kInequal : cel::builtin::kEqual)); + } + return negation ? BoolValue(!*is_equal) : BoolValue(*is_equal); +} + +class DirectEqualityStep : public DirectExpressionStep { + public: + explicit DirectEqualityStep(std::unique_ptr lhs, + std::unique_ptr rhs, + bool negation, int64_t expr_id) + : DirectExpressionStep(expr_id), + lhs_(std::move(lhs)), + rhs_(std::move(rhs)), + negation_(negation) {} + + absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, + AttributeTrail& attribute_trail) const override { + AttributeTrail lhs_attr; + CEL_RETURN_IF_ERROR(lhs_->Evaluate(frame, result, lhs_attr)); + + Value rhs_result; + AttributeTrail rhs_attr; + CEL_RETURN_IF_ERROR(rhs_->Evaluate(frame, rhs_result, rhs_attr)); + CEL_ASSIGN_OR_RETURN( + result, EvaluateEquality(frame, result, lhs_attr, rhs_result, rhs_attr, + negation_)); + return absl::OkStatus(); + } + + private: + std::unique_ptr lhs_; + std::unique_ptr rhs_; + bool negation_; +}; + +class IterativeEqualityStep : public ExpressionStepBase { + public: + explicit IterativeEqualityStep(bool negation, int64_t expr_id) + : ExpressionStepBase(expr_id), negation_(negation) {} + + absl::Status Evaluate(ExecutionFrame* frame) const override { + if (!frame->value_stack().HasEnough(2)) { + return absl::Status(absl::StatusCode::kInternal, "Value stack underflow"); + } + auto args = frame->value_stack().GetSpan(2); + auto attrs = frame->value_stack().GetAttributeSpan(2); + + CEL_ASSIGN_OR_RETURN(Value result, + EvaluateEquality(*frame, args[0], attrs[0], args[1], + attrs[1], negation_)); + + frame->value_stack().PopAndPush(2, std::move(result)); + return absl::OkStatus(); + } + + private: + bool negation_; +}; + +absl::StatusOr EvaluateInMap(ExecutionFrameBase& frame, + const Value& item, + const MapValue& container) { + switch (item.kind()) { + case ValueKind::kBool: + case ValueKind::kString: + case ValueKind::kInt: + case ValueKind::kUint: + case ValueKind::kDouble: + break; + default: + return cel::ErrorValue( + cel::runtime_internal::CreateNoMatchingOverloadError( + cel::builtin::kIn)); + } + Value result; + CEL_RETURN_IF_ERROR(container.Has(item, frame.descriptor_pool(), + frame.message_factory(), frame.arena(), + &result)); + + if (result.IsTrue()) { + return result; + } + + if (item.IsDouble() || item.IsUint()) { + Number number = item.IsDouble() + ? Number::FromDouble(item.GetDouble().NativeValue()) + : Number::FromUint64(item.GetUint().NativeValue()); + if (number.LosslessConvertibleToInt()) { + CEL_RETURN_IF_ERROR( + container.Has(IntValue(number.AsInt()), frame.descriptor_pool(), + frame.message_factory(), frame.arena(), &result)); + if (result.IsTrue()) { + return result; + } + } + } + + if (item.IsDouble() || item.IsInt()) { + Number number = item.IsDouble() + ? Number::FromDouble(item.GetDouble().NativeValue()) + : Number::FromInt64(item.GetInt().NativeValue()); + if (number.LosslessConvertibleToUint()) { + CEL_RETURN_IF_ERROR( + container.Has(UintValue(number.AsUint()), frame.descriptor_pool(), + frame.message_factory(), frame.arena(), &result)); + if (result.IsTrue()) { + return result; + } + } + } + + return BoolValue(false); +} + +absl::StatusOr EvaluateIn(ExecutionFrameBase& frame, const Value& item, + const AttributeTrail& item_attr, + const Value& container, + const AttributeTrail& container_attr) { + if (item.IsError()) { + return item; + } + if (container.IsError()) { + return container; + } + + if (frame.unknown_processing_enabled()) { + auto accu = frame.attribute_utility().CreateAccumulator(); + accu.MaybeAdd(item, item_attr); + accu.MaybeAdd(container, container_attr); + if (!accu.IsEmpty()) { + return std::move(accu).Build(); + } + } + if (container.IsList()) { + return container.GetList().Contains(item, frame.descriptor_pool(), + frame.message_factory(), frame.arena()); + } + if (container.IsMap()) { + return EvaluateInMap(frame, item, container.GetMap()); + } + return cel::ErrorValue( + cel::runtime_internal::CreateNoMatchingOverloadError(cel::builtin::kIn)); +} + +class DirectInStep : public DirectExpressionStep { + public: + explicit DirectInStep(std::unique_ptr item, + std::unique_ptr container, + int64_t expr_id) + : DirectExpressionStep(expr_id), + item_(std::move(item)), + container_(std::move(container)) {} + + absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, + AttributeTrail& attribute_trail) const override { + AttributeTrail item_attr; + CEL_RETURN_IF_ERROR(item_->Evaluate(frame, result, item_attr)); + + Value container_result; + AttributeTrail container_attr; + CEL_RETURN_IF_ERROR( + container_->Evaluate(frame, container_result, container_attr)); + CEL_ASSIGN_OR_RETURN(result, EvaluateIn(frame, result, item_attr, + container_result, container_attr)); + return absl::OkStatus(); + } + + private: + std::unique_ptr item_; + std::unique_ptr container_; +}; + +class IterativeInStep : public ExpressionStepBase { + public: + explicit IterativeInStep(int64_t expr_id) : ExpressionStepBase(expr_id) {} + + absl::Status Evaluate(ExecutionFrame* frame) const override { + if (!frame->value_stack().HasEnough(2)) { + return absl::Status(absl::StatusCode::kInternal, "Value stack underflow"); + } + + auto args = frame->value_stack().GetSpan(2); + auto attrs = frame->value_stack().GetAttributeSpan(2); + + CEL_ASSIGN_OR_RETURN( + Value result, EvaluateIn(*frame, args[0], attrs[0], args[1], attrs[1])); + frame->value_stack().PopAndPush(2, std::move(result)); + return absl::OkStatus(); + } +}; + +} // namespace + +// Factory method for recursive _==_ and _!=_ Execution step +std::unique_ptr CreateDirectEqualityStep( + std::unique_ptr lhs, + std::unique_ptr rhs, bool negation, int64_t expr_id) { + return std::make_unique(std::move(lhs), std::move(rhs), + negation, expr_id); +} + +// Factory method for iterative _==_ and _!=_ Execution step +std::unique_ptr CreateEqualityStep(bool negation, + int64_t expr_id) { + return std::make_unique(negation, expr_id); +} + +// Factory method for recursive @in Execution step +std::unique_ptr CreateDirectInStep( + std::unique_ptr item, + std::unique_ptr container, int64_t expr_id) { + return std::make_unique(std::move(item), std::move(container), + expr_id); +} + +// Factory method for iterative @in Execution step +std::unique_ptr CreateInStep(int64_t expr_id) { + return std::make_unique(expr_id); +} + +} // namespace google::api::expr::runtime diff --git a/eval/eval/equality_steps.h b/eval/eval/equality_steps.h new file mode 100644 index 000000000..eb3bec4ca --- /dev/null +++ b/eval/eval/equality_steps.h @@ -0,0 +1,45 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_EQUALITY_STEPS_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_EVAL_EQUALITY_STEPS_H_ + +#include +#include + +#include "eval/eval/direct_expression_step.h" +#include "eval/eval/evaluator_core.h" + +namespace google::api::expr::runtime { + +// Factory method for recursive _==_/_!=_ Execution step +std::unique_ptr CreateDirectEqualityStep( + std::unique_ptr lhs, + std::unique_ptr rhs, bool negation, int64_t expr_id); + +// Factory method for iterative _==_/_!=_ Execution step +std::unique_ptr CreateEqualityStep(bool negation, + int64_t expr_id); + +// Factory method for recursive @in Execution step +std::unique_ptr CreateDirectInStep( + std::unique_ptr item, + std::unique_ptr container, int64_t expr_id); + +// Factory method for iterative @in Execution step +std::unique_ptr CreateInStep(int64_t expr_id); + +} // namespace google::api::expr::runtime + +#endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_EQUALITY_STEPS_H_ diff --git a/eval/eval/equality_steps_test.cc b/eval/eval/equality_steps_test.cc new file mode 100644 index 000000000..168ce7603 --- /dev/null +++ b/eval/eval/equality_steps_test.cc @@ -0,0 +1,569 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/eval/equality_steps.h" + +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "base/attribute.h" +#include "common/value.h" +#include "common/value_kind.h" +#include "common/value_testing.h" +#include "eval/eval/attribute_trail.h" +#include "eval/eval/direct_expression_step.h" +#include "eval/eval/evaluator_core.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "internal/testing_message_factory.h" +#include "runtime/activation.h" +#include "runtime/internal/runtime_type_provider.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" + +namespace google::api::expr::runtime { +namespace { + +using ::absl_testing::IsOk; +using ::cel::Attribute; +using ::cel::DoubleValue; +using ::cel::ErrorValue; +using ::cel::IntValue; +using ::cel::UnknownValue; +using ::cel::Value; +using ::cel::ValueKind; +using ::cel::test::BoolValueIs; +using ::cel::test::ValueKindIs; + +class ValueStep : public ExpressionStep, public DirectExpressionStep { + public: + ValueStep(Value value, Attribute attr) + : ExpressionStep(-1), + DirectExpressionStep(-1), + value_(std::move(value)), + attr_(std::move(attr)) {} + explicit ValueStep(Value value) + : ExpressionStep(-1), + DirectExpressionStep(-1), + value_(std::move(value)), + attr_() {} + + absl::Status Evaluate(ExecutionFrame* frame) const override { + frame->value_stack().Push(value_, attr_); + return absl::OkStatus(); + } + + absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, + AttributeTrail& attribute_trail) const override { + result = value_; + attribute_trail = attr_; + return absl::OkStatus(); + } + + private: + Value value_; + AttributeTrail attr_; +}; + +TEST(RecursiveTest, PartialAttrUnknown) { + cel::Activation activation; + google::protobuf::Arena arena; + cel::RuntimeOptions opts; + opts.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; + cel::runtime_internal::RuntimeTypeProvider type_provider( + cel::internal::GetTestingDescriptorPool()); + + // A little contrived for simplicity, but this is for cases where e.g. + // `msg == Msg{}` but msg.foo is unknown. + auto plan = CreateDirectEqualityStep( + std::make_unique(IntValue(1), cel::Attribute("foo")), + std::make_unique(IntValue(2)), false, -1); + + activation.SetUnknownPatterns({cel::AttributePattern( + "foo", {cel::AttributeQualifierPattern::OfString("bar")})}); + + ExecutionFrameBase frame(activation, opts, type_provider, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena); + + cel::Value result; + AttributeTrail attribute_trail; + ASSERT_THAT(plan->Evaluate(frame, result, attribute_trail), IsOk()); + + EXPECT_THAT(result, ValueKindIs(ValueKind::kUnknown)); +} + +TEST(RecursiveTest, PartialAttrUnknownDisabled) { + cel::Activation activation; + google::protobuf::Arena arena; + cel::RuntimeOptions opts; + opts.unknown_processing = cel::UnknownProcessingOptions::kDisabled; + cel::runtime_internal::RuntimeTypeProvider type_provider( + cel::internal::GetTestingDescriptorPool()); + + auto plan = CreateDirectEqualityStep( + std::make_unique(IntValue(1), cel::Attribute("foo")), + std::make_unique(IntValue(2)), false, -1); + + activation.SetUnknownPatterns({cel::AttributePattern( + "foo", {cel::AttributeQualifierPattern::OfString("bar")})}); + ExecutionFrameBase frame(activation, opts, type_provider, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena); + + cel::Value result; + AttributeTrail attribute_trail; + ASSERT_THAT(plan->Evaluate(frame, result, attribute_trail), IsOk()); + + EXPECT_THAT(result, BoolValueIs(false)); +} + +TEST(IterativeTest, PartialAttrUnknown) { + cel::Activation activation; + google::protobuf::Arena arena; + cel::RuntimeOptions opts; + opts.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; + cel::runtime_internal::RuntimeTypeProvider type_provider( + cel::internal::GetTestingDescriptorPool()); + + FlatExpressionEvaluatorState state( + /*value_stack_size=*/5, + /*comprehension_slot_count=*/0, type_provider, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena); + + std::vector> steps; + steps.push_back( + std::make_unique(IntValue(1), cel::Attribute("foo"))); + steps.push_back(std::make_unique(IntValue(2))); + steps.push_back(CreateEqualityStep(false, -1)); + + activation.SetUnknownPatterns({cel::AttributePattern( + "foo", {cel::AttributeQualifierPattern::OfString("bar")})}); + + ExecutionFrame frame(steps, activation, opts, state); + + ASSERT_OK_AND_ASSIGN(Value result, frame.Evaluate()); + + EXPECT_THAT(result, ValueKindIs(ValueKind::kUnknown)); +} + +TEST(IterativeTest, PartialAttrUnknownDisabled) { + cel::Activation activation; + google::protobuf::Arena arena; + cel::RuntimeOptions opts; + opts.unknown_processing = cel::UnknownProcessingOptions::kDisabled; + cel::runtime_internal::RuntimeTypeProvider type_provider( + cel::internal::GetTestingDescriptorPool()); + + FlatExpressionEvaluatorState state( + /*value_stack_size=*/5, + /*comprehension_slot_count=*/0, type_provider, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena); + + std::vector> steps; + steps.push_back( + std::make_unique(IntValue(1), cel::Attribute("foo"))); + steps.push_back(std::make_unique(IntValue(2))); + steps.push_back(CreateEqualityStep(false, -1)); + + activation.SetUnknownPatterns({cel::AttributePattern( + "foo", {cel::AttributeQualifierPattern::OfString("bar")})}); + ExecutionFrame frame(steps, activation, opts, state); + + ASSERT_OK_AND_ASSIGN(Value result, frame.Evaluate()); + + EXPECT_THAT(result, BoolValueIs(false)); +} + +enum class InputType { kInt1, kInt2, kDouble1, kList, kMap, kError, kUnknown }; +enum class OutputType { kBoolTrue, kBoolFalse, kError, kUnknown }; + +struct EqualsTestCase { + InputType lhs; + InputType rhs; + bool negation; + OutputType expected_result; +}; + +class EqualsTest : public ::testing::TestWithParam {}; + +Value MakeValue(InputType type, google::protobuf::Arena* absl_nonnull arena) { + switch (type) { + case InputType::kInt1: + return IntValue(1); + case InputType::kInt2: + return IntValue(2); + case InputType::kDouble1: + return DoubleValue(1.0); + case InputType::kUnknown: + return UnknownValue(); + case InputType::kList: { + auto builder = cel::NewListValueBuilder(arena); + ABSL_CHECK_OK((builder)->Add(IntValue(1))); + return (std::move(*builder)).Build(); + } + case InputType::kMap: { + auto builder = cel::NewMapValueBuilder(arena); + ABSL_CHECK_OK((builder)->Put(IntValue(1), IntValue(2))); + return (std::move(*builder)).Build(); + } + case InputType::kError: + default: + return ErrorValue(absl::InternalError("error")); + } +} + +TEST_P(EqualsTest, Recursive) { + const EqualsTestCase& test_case = GetParam(); + cel::Activation activation; + google::protobuf::Arena arena; + cel::RuntimeOptions opts; + opts.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; + cel::runtime_internal::RuntimeTypeProvider type_provider( + cel::internal::GetTestingDescriptorPool()); + + auto plan = CreateDirectEqualityStep( + std::make_unique(MakeValue(test_case.lhs, &arena)), + std::make_unique(MakeValue(test_case.rhs, &arena)), + test_case.negation, -1); + + ExecutionFrameBase frame(activation, opts, type_provider, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena); + + cel::Value result; + AttributeTrail attribute_trail; + ASSERT_THAT(plan->Evaluate(frame, result, attribute_trail), IsOk()); + + switch (test_case.expected_result) { + case OutputType::kBoolTrue: + EXPECT_THAT(result, BoolValueIs(true)); + break; + case OutputType::kBoolFalse: + EXPECT_THAT(result, BoolValueIs(false)); + break; + case OutputType::kError: + EXPECT_THAT(result, ValueKindIs(ValueKind::kError)); + break; + case OutputType::kUnknown: + EXPECT_THAT(result, ValueKindIs(ValueKind::kUnknown)); + break; + } +} + +TEST_P(EqualsTest, Iterative) { + const EqualsTestCase& test_case = GetParam(); + cel::Activation activation; + google::protobuf::Arena arena; + cel::RuntimeOptions opts; + opts.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; + cel::runtime_internal::RuntimeTypeProvider type_provider( + cel::internal::GetTestingDescriptorPool()); + + FlatExpressionEvaluatorState state( + /*value_stack_size=*/5, + /*comprehension_slot_count=*/0, type_provider, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena); + + std::vector> steps; + steps.push_back( + std::make_unique(MakeValue(test_case.lhs, &arena))); + steps.push_back( + std::make_unique(MakeValue(test_case.rhs, &arena))); + steps.push_back(CreateEqualityStep(test_case.negation, -1)); + + ExecutionFrame frame(steps, activation, opts, state); + + ASSERT_OK_AND_ASSIGN(Value result, frame.Evaluate()); + + switch (test_case.expected_result) { + case OutputType::kBoolTrue: + EXPECT_THAT(result, BoolValueIs(true)); + break; + case OutputType::kBoolFalse: + EXPECT_THAT(result, BoolValueIs(false)); + break; + case OutputType::kError: + EXPECT_THAT(result, ValueKindIs(ValueKind::kError)); + break; + case OutputType::kUnknown: + EXPECT_THAT(result, ValueKindIs(ValueKind::kUnknown)); + break; + } +} + +INSTANTIATE_TEST_SUITE_P(EqualsTest, EqualsTest, + testing::Values( + EqualsTestCase{ + InputType::kInt1, + InputType::kInt2, + false, + OutputType::kBoolFalse, + }, + EqualsTestCase{ + InputType::kInt1, + InputType::kInt1, + false, + OutputType::kBoolTrue, + }, + EqualsTestCase{ + InputType::kInt1, + InputType::kList, + false, + OutputType::kBoolFalse, + }, + EqualsTestCase{ + InputType::kInt1, + InputType::kDouble1, + false, + OutputType::kBoolTrue, + }, + EqualsTestCase{ + InputType::kInt2, + InputType::kDouble1, + false, + OutputType::kBoolFalse, + }, + EqualsTestCase{ + InputType::kInt1, + InputType::kError, + false, + OutputType::kError, + }, + EqualsTestCase{ + InputType::kError, + InputType::kInt1, + false, + OutputType::kError, + }, + EqualsTestCase{ + InputType::kInt1, + InputType::kUnknown, + false, + OutputType::kUnknown, + }, + EqualsTestCase{ + InputType::kUnknown, + InputType::kInt1, + false, + OutputType::kUnknown, + }, + EqualsTestCase{ + InputType::kError, + InputType::kUnknown, + false, + OutputType::kError, + }, + EqualsTestCase{ + InputType::kUnknown, + InputType::kError, + false, + OutputType::kError, + }, + // != + EqualsTestCase{ + InputType::kInt1, + InputType::kInt2, + true, + OutputType::kBoolTrue, + }, + EqualsTestCase{ + InputType::kError, + InputType::kInt1, + true, + OutputType::kError, + }, + EqualsTestCase{ + InputType::kUnknown, + InputType::kInt1, + true, + OutputType::kUnknown, + }, + EqualsTestCase{ + InputType::kInt1, + InputType::kDouble1, + true, + OutputType::kBoolFalse, + })); + +struct InTestCase { + InputType lhs; + InputType rhs; + OutputType expected_result; +}; + +class InTest : public ::testing::TestWithParam {}; + +TEST_P(InTest, Recursive) { + const InTestCase& test_case = GetParam(); + cel::Activation activation; + google::protobuf::Arena arena; + cel::RuntimeOptions opts; + opts.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; + cel::runtime_internal::RuntimeTypeProvider type_provider( + cel::internal::GetTestingDescriptorPool()); + + auto plan = CreateDirectInStep( + std::make_unique(MakeValue(test_case.lhs, &arena)), + std::make_unique(MakeValue(test_case.rhs, &arena)), -1); + + ExecutionFrameBase frame(activation, opts, type_provider, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena); + + cel::Value result; + AttributeTrail attribute_trail; + ASSERT_THAT(plan->Evaluate(frame, result, attribute_trail), IsOk()); + + switch (test_case.expected_result) { + case OutputType::kBoolTrue: + EXPECT_THAT(result, BoolValueIs(true)); + break; + case OutputType::kBoolFalse: + EXPECT_THAT(result, BoolValueIs(false)); + break; + case OutputType::kError: + EXPECT_THAT(result, ValueKindIs(ValueKind::kError)); + break; + case OutputType::kUnknown: + EXPECT_THAT(result, ValueKindIs(ValueKind::kUnknown)); + break; + } +} + +TEST_P(InTest, Iterative) { + const InTestCase& test_case = GetParam(); + cel::Activation activation; + google::protobuf::Arena arena; + cel::RuntimeOptions opts; + opts.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; + cel::runtime_internal::RuntimeTypeProvider type_provider( + cel::internal::GetTestingDescriptorPool()); + + FlatExpressionEvaluatorState state( + /*value_stack_size=*/5, + /*comprehension_slot_count=*/0, type_provider, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena); + + std::vector> steps; + steps.push_back( + std::make_unique(MakeValue(test_case.lhs, &arena))); + steps.push_back( + std::make_unique(MakeValue(test_case.rhs, &arena))); + steps.push_back(CreateInStep(-1)); + + ExecutionFrame frame(steps, activation, opts, state); + + ASSERT_OK_AND_ASSIGN(Value result, frame.Evaluate()); + + switch (test_case.expected_result) { + case OutputType::kBoolTrue: + EXPECT_THAT(result, BoolValueIs(true)); + break; + case OutputType::kBoolFalse: + EXPECT_THAT(result, BoolValueIs(false)); + break; + case OutputType::kError: + EXPECT_THAT(result, ValueKindIs(ValueKind::kError)); + break; + case OutputType::kUnknown: + EXPECT_THAT(result, ValueKindIs(ValueKind::kUnknown)); + break; + } +} + +INSTANTIATE_TEST_SUITE_P(InTest, InTest, + testing::Values( + InTestCase{ + InputType::kInt1, + InputType::kInt2, + OutputType::kError, + }, + InTestCase{ + InputType::kInt1, + InputType::kList, + OutputType::kBoolTrue, + }, + InTestCase{ + InputType::kInt1, + InputType::kMap, + OutputType::kBoolTrue, + }, + InTestCase{ + InputType::kDouble1, + InputType::kList, + OutputType::kBoolTrue, + }, + InTestCase{ + InputType::kInt2, + InputType::kList, + OutputType::kBoolFalse, + }, + InTestCase{ + InputType::kDouble1, + InputType::kMap, + OutputType::kBoolTrue, + }, + InTestCase{ + InputType::kInt2, + InputType::kMap, + OutputType::kBoolFalse, + }, + InTestCase{ + InputType::kList, + InputType::kMap, + OutputType::kError, + }, + InTestCase{ + InputType::kList, + InputType::kList, + OutputType::kBoolFalse, + }, + InTestCase{ + InputType::kError, + InputType::kList, + OutputType::kError, + }, + InTestCase{ + InputType::kInt1, + InputType::kError, + OutputType::kError, + }, + InTestCase{ + InputType::kUnknown, + InputType::kList, + OutputType::kUnknown, + }, + InTestCase{ + InputType::kInt1, + InputType::kUnknown, + OutputType::kUnknown, + }, + InTestCase{ + InputType::kUnknown, + InputType::kError, + OutputType::kError, + })); + +} // namespace +} // namespace google::api::expr::runtime diff --git a/eval/eval/evaluator_core.cc b/eval/eval/evaluator_core.cc index 45ccfc9eb..05dbed854 100644 --- a/eval/eval/evaluator_core.cc +++ b/eval/eval/evaluator_core.cc @@ -1,215 +1,178 @@ +// Copyright 2017 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #include "eval/eval/evaluator_core.h" +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/base/optimization.h" +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/types/optional.h" -#include "eval/eval/attribute_trail.h" -#include "eval/public/cel_value.h" -#include "base/status_macros.h" -#include "absl/status/statusor.h" - -namespace google { -namespace api { -namespace expr { -namespace runtime { -namespace { - -absl::Status CheckIterAccess(CelExpressionFlatEvaluationState* state, - const std::string& name) { - if (state->iter_stack().empty()) { - return absl::Status( - absl::StatusCode::kInternal, - absl::StrCat( - "Attempted to update iteration variable outside of comprehension.'", - name, "'")); - } - auto iter = state->iter_variable_names().find(name); - if (iter == state->iter_variable_names().end()) { - return absl::Status( - absl::StatusCode::kInternal, - absl::StrCat("Attempted to set unknown variable '", name, "'")); - } - - return absl::OkStatus(); -} - -} // namespace +#include "absl/strings/str_cat.h" +#include "common/value.h" +#include "runtime/activation_interface.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" -void ValueStack::Clear() { - for (auto& v : stack_) { - v = CelValue(); - } - for (auto& attr : attribute_stack_) { - attr = AttributeTrail(); - } - - current_size_ = 0; -} - -CelExpressionFlatEvaluationState::CelExpressionFlatEvaluationState( - size_t value_stack_size, const std::set& iter_variable_names, - google::protobuf::Arena* arena) - : value_stack_(value_stack_size), - iter_variable_names_(iter_variable_names), - arena_(arena) {} +namespace google::api::expr::runtime { -void CelExpressionFlatEvaluationState::Reset() { - iter_stack_.clear(); +void FlatExpressionEvaluatorState::Reset() { value_stack_.Clear(); + iterator_stack_.Clear(); + comprehension_slots_.Reset(); } const ExpressionStep* ExecutionFrame::Next() { - size_t end_pos = execution_path_.size(); + while (true) { + const size_t end_pos = execution_path_.size(); - if (pc_ < end_pos) return execution_path_[pc_++].get(); - if (pc_ > end_pos) { - GOOGLE_LOG(ERROR) << "Attempting to step beyond the end of execution path."; + if (ABSL_PREDICT_TRUE(pc_ < end_pos)) { + const auto* step = execution_path_[pc_++].get(); + ABSL_ASSUME(step != nullptr); + return step; + } + if (ABSL_PREDICT_TRUE(pc_ == end_pos)) { + if (!call_stack_.empty()) { + SubFrame& subframe = call_stack_.back(); + pc_ = subframe.return_pc; + execution_path_ = subframe.return_expression; + ABSL_DCHECK_EQ(value_stack().size(), subframe.expected_stack_size); + comprehension_slots().Set(subframe.slot_index, value_stack().Peek(), + value_stack().PeekAttribute()); + call_stack_.pop_back(); + continue; + } + } else { + ABSL_LOG(ERROR) << "Attempting to step beyond the end of execution path."; + } + return nullptr; } - return nullptr; } -absl::Status ExecutionFrame::PushIterFrame() { - state_->iter_stack().push_back({}); - return absl::OkStatus(); -} +namespace { -absl::Status ExecutionFrame::PopIterFrame() { - if (state_->iter_stack().empty()) { - return absl::InternalError("Loop stack underflow."); +// This class abuses the fact that `absl::Status` is trivially destructible when +// `absl::Status::ok()` is `true`. If the implementation of `absl::Status` every +// changes, LSan and ASan should catch it. We cannot deal with the cost of extra +// move assignment and destructor calls. +// +// This is useful only in the evaluation loop and is a direct replacement for +// `RETURN_IF_ERROR`. It yields the most improvements on benchmarks with lots of +// steps which never return non-OK `absl::Status`. +class EvaluationStatus final { + public: + explicit EvaluationStatus(absl::Status&& status) { + ::new (static_cast(&status_[0])) absl::Status(std::move(status)); } - state_->iter_stack().pop_back(); - return absl::OkStatus(); -} - -absl::Status ExecutionFrame::SetIterVar(const std::string& name, - const CelValue& val, - AttributeTrail trail) { - RETURN_IF_ERROR(CheckIterAccess(state_, name)); - state_->IterStackTop()[name] = {val, trail}; - return absl::OkStatus(); -} + EvaluationStatus() = delete; + EvaluationStatus(const EvaluationStatus&) = delete; + EvaluationStatus(EvaluationStatus&&) = delete; + EvaluationStatus& operator=(const EvaluationStatus&) = delete; + EvaluationStatus& operator=(EvaluationStatus&&) = delete; -absl::Status ExecutionFrame::SetIterVar(const std::string& name, - const CelValue& val) { - return SetIterVar(name, val, AttributeTrail()); -} - -absl::Status ExecutionFrame::ClearIterVar(const std::string& name) { - RETURN_IF_ERROR(CheckIterAccess(state_, name)); - state_->IterStackTop().erase(name); - return absl::OkStatus(); -} - -bool ExecutionFrame::GetIterVar(const std::string& name, CelValue* val) const { - absl::Status status = CheckIterAccess(state_, name); - if (!status.ok()) { - return false; + absl::Status Consume() && { + return std::move(*reinterpret_cast(&status_[0])); } - for (auto iter = state_->iter_stack().rbegin(); - iter != state_->iter_stack().rend(); ++iter) { - auto& frame = *iter; - auto frame_iter = frame.find(name); - if (frame_iter != frame.end()) { - const auto& entry = frame_iter->second; - *val = entry.value; - return true; - } + bool ok() const { + return ABSL_PREDICT_TRUE( + reinterpret_cast(&status_[0])->ok()); } - return false; -} + private: + alignas(absl::Status) char status_[sizeof(absl::Status)]; +}; -bool ExecutionFrame::GetIterAttr(const std::string& name, - const AttributeTrail** val) const { - absl::Status status = CheckIterAccess(state_, name); - if (!status.ok()) { - return false; - } +} // namespace + +absl::StatusOr ExecutionFrame::Evaluate( + EvaluationListener& listener) { + const size_t initial_stack_size = value_stack().size(); - for (auto iter = state_->iter_stack().rbegin(); - iter != state_->iter_stack().rend(); ++iter) { - auto& frame = *iter; - auto frame_iter = frame.find(name); - if (frame_iter != frame.end()) { - const auto& entry = frame_iter->second; - *val = &entry.attr_trail; - return true; + if (!listener) { + for (const ExpressionStep* expr = Next(); + ABSL_PREDICT_TRUE(expr != nullptr); expr = Next()) { + if (EvaluationStatus status(expr->Evaluate(this)); !status.ok()) { + return std::move(status).Consume(); + } + } + } else { + for (const ExpressionStep* expr = Next(); + ABSL_PREDICT_TRUE(expr != nullptr); expr = Next()) { + if (EvaluationStatus status(expr->Evaluate(this)); !status.ok()) { + return std::move(status).Consume(); + } + + if (pc_ == 0 || !expr->comes_from_ast()) { + // Skip if we just started a Call or if the step doesn't map to an + // AST id. + continue; + } + + if (ABSL_PREDICT_FALSE(value_stack().empty())) { + ABSL_LOG(ERROR) << "Stack is empty after a ExpressionStep.Evaluate. " + "Try to disable short-circuiting."; + continue; + } + if (EvaluationStatus status(listener(expr->id(), value_stack().Peek(), + descriptor_pool(), message_factory(), + arena())); + !status.ok()) { + return std::move(status).Consume(); + } } } - return false; -} + const size_t final_stack_size = value_stack().size(); + if (ABSL_PREDICT_FALSE(final_stack_size != initial_stack_size + 1 || + final_stack_size == 0)) { + return absl::InternalError(absl::StrCat( + "Stack error during evaluation: expected=", initial_stack_size + 1, + ", actual=", final_stack_size)); + } -std::unique_ptr CelExpressionFlatImpl::InitializeState( - google::protobuf::Arena* arena) const { - return absl::make_unique( - path_.size(), iter_variable_names_, arena); + cel::Value value = std::move(value_stack().Peek()); + value_stack().Pop(1); + return value; } -absl::StatusOr CelExpressionFlatImpl::Evaluate( - const BaseActivation& activation, CelEvaluationState* state) const { - return Trace(activation, state, CelEvaluationListener()); +FlatExpressionEvaluatorState FlatExpression::MakeEvaluatorState( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const { + return FlatExpressionEvaluatorState(path_.size(), comprehension_slots_size_, + type_provider_, descriptor_pool, + message_factory, arena); } -absl::StatusOr CelExpressionFlatImpl::Trace( - const BaseActivation& activation, CelEvaluationState* _state, - CelEvaluationListener callback) const { - auto state = down_cast(_state); - state->Reset(); - - // Using both unknown attribute patterns and unknown paths via FieldMask is - // not allowed. - if (activation.unknown_paths().paths_size() != 0 && - !activation.unknown_attribute_patterns().empty()) { - return absl::InvalidArgumentError( - "Attempting to evaluate expression with both unknown_paths and " - "unknown_attribute_patterns set in the Activation"); - } +absl::StatusOr FlatExpression::EvaluateWithCallback( + const cel::ActivationInterface& activation, + const cel::EmbedderContext* absl_nullable embedder_context, + EvaluationListener listener, FlatExpressionEvaluatorState& state) const { + state.Reset(); - ExecutionFrame frame(path_, activation, max_iterations_, state, - enable_unknowns_, enable_unknown_function_results_, - enable_missing_attribute_errors_); - - ValueStack* stack = &frame.value_stack(); - size_t initial_stack_size = stack->size(); - const ExpressionStep* expr; - while ((expr = frame.Next()) != nullptr) { - auto status = expr->Evaluate(&frame); - if (!status.ok()) { - return status; - } - if (!callback) { - continue; - } - if (!expr->ComesFromAst()) { - // This step was added during compilation (e.g. Int64ConstImpl). - continue; - } + ExecutionFrame frame(subexpressions_, activation, options_, state, + std::move(listener), embedder_context); - if (stack->empty()) { - GOOGLE_LOG(ERROR) << "Stack is empty after a ExpressionStep.Evaluate. " - "Try to disable short-circuiting."; - continue; - } - auto status2 = callback(expr->id(), stack->Peek(), state->arena()); - if (!status2.ok()) { - return status2; - } - } - - size_t final_stack_size = stack->size(); - if (initial_stack_size + 1 != final_stack_size || final_stack_size == 0) { - return absl::Status(absl::StatusCode::kInternal, - "Stack error during evaluation"); - } - CelValue value = stack->Peek(); - stack->Pop(1); - return value; + return frame.Evaluate(frame.callback()); } -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime diff --git a/eval/eval/evaluator_core.h b/eval/eval/evaluator_core.h index fa37733df..575abfa05 100644 --- a/eval/eval/evaluator_core.h +++ b/eval/eval/evaluator_core.h @@ -1,46 +1,72 @@ +// Copyright 2017 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_EVALUATOR_CORE_H_ #define THIRD_PARTY_CEL_CPP_EVAL_EVAL_EVALUATOR_CORE_H_ -#include -#include - -#include +#include +#include #include -#include -#include #include #include -#include "google/api/expr/v1alpha1/syntax.pb.h" -#include "google/protobuf/arena.h" +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" -#include "absl/types/optional.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" -#include "eval/eval/attribute_trail.h" +#include "base/type_provider.h" +#include "common/native_type.h" +#include "common/value.h" #include "eval/eval/attribute_utility.h" -#include "eval/public/activation.h" -#include "eval/public/cel_attribute.h" -#include "eval/public/cel_expression.h" -#include "eval/public/cel_value.h" -#include "eval/public/unknown_attribute_set.h" -#include "absl/status/statusor.h" +#include "eval/eval/comprehension_slots.h" +#include "eval/eval/evaluator_stack.h" +#include "eval/eval/iterator_stack.h" +#include "runtime/activation_interface.h" +#include "runtime/internal/activation_attribute_matcher_access.h" +#include "runtime/runtime.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel { +class EmbedderContext; +} // namespace cel -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace google::api::expr::runtime { // Forward declaration of ExecutionFrame, to resolve circular dependency. class ExecutionFrame; +using EvaluationListener = cel::TraceableProgram::EvaluationListener; + // Class Expression represents single execution step. class ExpressionStep { public: - virtual ~ExpressionStep() {} + explicit ExpressionStep(int64_t id, bool comes_from_ast = true) + : id_(id), comes_from_ast_(comes_from_ast) {} + + ExpressionStep(const ExpressionStep&) = delete; + ExpressionStep& operator=(const ExpressionStep&) = delete; + + virtual ~ExpressionStep() = default; // Performs actual evaluation. - // Values are passed between Expression objects via ValueStack, which is + // Values are passed between Expression objects via EvaluatorStack, which is // supplied with context. // Also, Expression gets values supplied by caller though Activation // interface. @@ -54,335 +80,437 @@ class ExpressionStep { // expression associated (e.g. a jump step), or if there is no ID assigned to // the corresponding expression. Useful for error scenarios where information // from Expr object is needed to create CelError. - virtual int64_t id() const = 0; + int64_t id() const { return id_; } // Returns if the execution step comes from AST. - virtual bool ComesFromAst() const = 0; + bool comes_from_ast() const { return comes_from_ast_; } + + // Return the type of the underlying expression step for special handling in + // the planning phase. This should only be overridden by special cases, and + // callers must not make any assumptions about the default case. + virtual cel::NativeTypeId GetNativeTypeId() const { + return cel::NativeTypeId(); + } + + private: + const int64_t id_; + const bool comes_from_ast_; }; using ExecutionPath = std::vector>; +using ExecutionPathView = + absl::Span>; -// CelValue stack. -// Implementation is based on vector to allow passing parameters from -// stack as Span<>. -class ValueStack { +// Class that wraps the state that needs to be allocated for expression +// evaluation. This can be reused to save on allocations. +class FlatExpressionEvaluatorState { public: - ValueStack(size_t max_size) : current_size_(0) { - stack_.resize(max_size); - attribute_stack_.resize(max_size); - } + FlatExpressionEvaluatorState( + size_t value_stack_size, size_t comprehension_slot_count, + const cel::TypeProvider& type_provider, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) + : value_stack_(value_stack_size), + // We currently use comprehension_slot_count because it is less of an + // over estimate than value_stack_size. In future we should just + // calculate the correct capacity. + iterator_stack_(comprehension_slot_count), + comprehension_slots_(comprehension_slot_count), + type_provider_(type_provider), + descriptor_pool_(descriptor_pool), + message_factory_(message_factory), + arena_(arena) {} - // Return the current stack size. - size_t size() const { return current_size_; } - - // Return the maximum size of the stack. - size_t max_size() const { return stack_.size(); } + void Reset(); - // Returns true if stack is empty. - bool empty() const { return current_size_ == 0; } + EvaluatorStack& value_stack() { return value_stack_; } - // Attributes stack size. - size_t attribute_size() const { return current_size_; } + cel::runtime_internal::IteratorStack& iterator_stack() { + return iterator_stack_; + } - // Check that stack has enough elements. - bool HasEnough(size_t size) const { return current_size_ >= size; } + ComprehensionSlots& comprehension_slots() { return comprehension_slots_; } - // Dumps the entire stack state as is. - void Clear(); + const cel::TypeProvider& type_provider() { return type_provider_; } - // Gets the last size elements of the stack. - // Checking that stack has enough elements is caller's responsibility. - // Please note that calls to Push may invalidate returned Span object. - absl::Span GetSpan(size_t size) const { - if (!HasEnough(size)) { - GOOGLE_LOG(ERROR) << "Requested span size (" << size - << ") exceeds current stack size: " << current_size_; - } - return absl::Span(stack_.data() + current_size_ - size, - size); + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool() { + return descriptor_pool_; } - // Gets the last size attribute trails of the stack. - // Checking that stack has enough elements is caller's responsibility. - // Please note that calls to Push may invalidate returned Span object. - absl::Span GetAttributeSpan(size_t size) const { - return absl::Span( - attribute_stack_.data() + current_size_ - size, size); + google::protobuf::MessageFactory* absl_nonnull message_factory() { + return message_factory_; } - // Peeks the last element of the stack. - // Checking that stack is not empty is caller's responsibility. - const CelValue& Peek() const { - if (empty()) { - GOOGLE_LOG(ERROR) << "Peeking on empty ValueStack"; - } - return stack_[current_size_ - 1]; - } + google::protobuf::Arena* absl_nonnull arena() { return arena_; } + + private: + EvaluatorStack value_stack_; + cel::runtime_internal::IteratorStack iterator_stack_; + ComprehensionSlots comprehension_slots_; + const cel::TypeProvider& type_provider_; + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool_; + google::protobuf::MessageFactory* absl_nonnull message_factory_; + google::protobuf::Arena* absl_nonnull arena_; +}; - // Peeks the last element of the attribute stack. - // Checking that stack is not empty is caller's responsibility. - const AttributeTrail& PeekAttribute() const { - if (empty()) { - GOOGLE_LOG(ERROR) << "Peeking on empty ValueStack"; +// Context needed for evaluation. This is sufficient for supporting +// recursive evaluation, but stack machine programs require an +// ExecutionFrame instance for managing a heap-backed stack. +class ExecutionFrameBase { + public: + // Overload for test usages. + ExecutionFrameBase(const cel::ActivationInterface& activation, + const cel::RuntimeOptions& options, + const cel::TypeProvider& type_provider, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) + : activation_(&activation), + callback_(), + options_(&options), + type_provider_(type_provider), + descriptor_pool_(descriptor_pool), + message_factory_(message_factory), + arena_(arena), + embedder_context_(nullptr), + attribute_utility_(activation.GetUnknownAttributes(), + activation.GetMissingAttributes()), + slots_(&ComprehensionSlots::GetEmptyInstance()), + max_iterations_(options.comprehension_max_iterations), + iterations_(0) { + if (unknown_processing_enabled()) { + if (auto matcher = cel::runtime_internal:: + ActivationAttributeMatcherAccess::GetAttributeMatcher(activation); + matcher != nullptr) { + attribute_utility_.set_matcher(matcher); + } } - return attribute_stack_[current_size_ - 1]; } - // Clears the last size elements of the stack. - // Checking that stack has enough elements is caller's responsibility. - void Pop(size_t size) { - if (!HasEnough(size)) { - GOOGLE_LOG(ERROR) << "Trying to pop more elements (" << size - << ") than the current stack size: " << current_size_; + ExecutionFrameBase(const cel::ActivationInterface& activation, + EvaluationListener callback, + const cel::RuntimeOptions& options, + const cel::TypeProvider& type_provider, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + const cel::EmbedderContext* absl_nullable embedder_context, + ComprehensionSlots& slots) + : activation_(&activation), + callback_(std::move(callback)), + options_(&options), + type_provider_(type_provider), + descriptor_pool_(descriptor_pool), + message_factory_(message_factory), + arena_(arena), + embedder_context_(embedder_context), + attribute_utility_(activation.GetUnknownAttributes(), + activation.GetMissingAttributes()), + slots_(&slots), + max_iterations_(options.comprehension_max_iterations), + iterations_(0) { + if (unknown_processing_enabled()) { + if (auto matcher = cel::runtime_internal:: + ActivationAttributeMatcherAccess::GetAttributeMatcher(activation); + matcher != nullptr) { + attribute_utility_.set_matcher(matcher); + } } - current_size_ -= size; } - // Put element on the top of the stack. - void Push(const CelValue& value) { Push(value, AttributeTrail()); } + const cel::ActivationInterface& activation() const { return *activation_; } - void Push(const CelValue& value, AttributeTrail attribute) { - if (current_size_ >= stack_.size()) { - GOOGLE_LOG(ERROR) << "No room to push more elements on to ValueStack"; - } - stack_[current_size_] = value; - attribute_stack_[current_size_] = attribute; - current_size_++; - } + EvaluationListener& callback() { return callback_; } - // Replace element on the top of the stack. - // Checking that stack is not empty is caller's responsibility. - void PopAndPush(const CelValue& value) { - PopAndPush(value, AttributeTrail()); - } + const cel::RuntimeOptions& options() const { return *options_; } - // Replace element on the top of the stack. - // Checking that stack is not empty is caller's responsibility. - void PopAndPush(const CelValue& value, AttributeTrail attribute) { - if (empty()) { - GOOGLE_LOG(ERROR) << "Cannot PopAndPush on empty stack."; - } - stack_[current_size_ - 1] = value; - attribute_stack_[current_size_ - 1] = attribute; + const cel::TypeProvider& type_provider() { return type_provider_; } + + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool() const { + return descriptor_pool_; } - // Preallocate stack. - void Reserve(size_t size) { - stack_.reserve(size); - attribute_stack_.reserve(size); + google::protobuf::MessageFactory* absl_nonnull message_factory() const { + return message_factory_; } - private: - std::vector stack_; - std::vector attribute_stack_; - size_t current_size_; -}; + google::protobuf::Arena* absl_nonnull arena() const { return arena_; } -class CelExpressionFlatEvaluationState : public CelEvaluationState { - public: - CelExpressionFlatEvaluationState( - size_t value_stack_size, const std::set& iter_variable_names, - google::protobuf::Arena* arena); - - struct IterVarEntry { - CelValue value; - AttributeTrail attr_trail; - }; + const cel::EmbedderContext* absl_nullable embedder_context() const { + return embedder_context_; + } - // Need pointer stability to avoid copying the attr trail lookups. - using IterVarFrame = absl::node_hash_map; + const AttributeUtility& attribute_utility() const { + return attribute_utility_; + } - void Reset(); + bool attribute_tracking_enabled() const { + return options_->unknown_processing != + cel::UnknownProcessingOptions::kDisabled || + options_->enable_missing_attribute_errors; + } - ValueStack& value_stack() { return value_stack_; } + bool missing_attribute_errors_enabled() const { + return options_->enable_missing_attribute_errors; + } - std::vector& iter_stack() { return iter_stack_; } + bool unknown_processing_enabled() const { + return options_->unknown_processing != + cel::UnknownProcessingOptions::kDisabled; + } - IterVarFrame& IterStackTop() { return iter_stack_[iter_stack().size() - 1]; } + bool unknown_function_results_enabled() const { + return options_->unknown_processing == + cel::UnknownProcessingOptions::kAttributeAndFunction; + } - std::set& iter_variable_names() { return iter_variable_names_; } + ComprehensionSlots& comprehension_slots() { return *slots_; } - google::protobuf::Arena* arena() { return arena_; } + // Increment iterations and return an error if the iteration budget is + // exceeded + absl::Status IncrementIterations() { + if (max_iterations_ == 0) { + return absl::OkStatus(); + } + iterations_++; + if (iterations_ >= max_iterations_) { + return absl::Status(absl::StatusCode::kInternal, + "Iteration budget exceeded"); + } + return absl::OkStatus(); + } - private: - ValueStack value_stack_; - std::set iter_variable_names_; - std::vector iter_stack_; - google::protobuf::Arena* arena_; + protected: + const cel::ActivationInterface* absl_nonnull activation_; + EvaluationListener callback_; + const cel::RuntimeOptions* absl_nonnull options_; + const cel::TypeProvider& type_provider_; + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool_; + google::protobuf::MessageFactory* absl_nonnull message_factory_; + google::protobuf::Arena* absl_nonnull arena_; + const cel::EmbedderContext* absl_nullable embedder_context_; + AttributeUtility attribute_utility_; + ComprehensionSlots* absl_nonnull slots_; + const int max_iterations_; + int iterations_; }; -// ExecutionFrame provides context for expression evaluation. -// The lifecycle of the object is bound to CelExpression Evaluate(...) call. -class ExecutionFrame { +// ExecutionFrame manages the context needed for expression evaluation. +// The lifecycle of the object is bound to a FlateExpression::Evaluate*(...) +// call. +class ExecutionFrame : public ExecutionFrameBase { public: // flat is the flattened sequence of execution steps that will be evaluated. // activation provides bindings between parameter names and values. - // arena serves as allocation manager during the expression evaluation. - - ExecutionFrame(const ExecutionPath& flat, const BaseActivation& activation, - int max_iterations, CelExpressionFlatEvaluationState* state, - bool enable_unknowns, bool enable_unknown_function_results, - bool enable_missing_attribute_errors) - : pc_(0UL), + // state contains the value factory for evaluation and the allocated data + // structures needed for evaluation. + ExecutionFrame( + ExecutionPathView flat, const cel::ActivationInterface& activation, + const cel::RuntimeOptions& options, FlatExpressionEvaluatorState& state, + EvaluationListener callback = EvaluationListener(), + const cel::EmbedderContext* absl_nullable embedder_context = nullptr) + : ExecutionFrameBase(activation, std::move(callback), options, + state.type_provider(), state.descriptor_pool(), + state.message_factory(), state.arena(), + embedder_context, state.comprehension_slots()), + pc_(0UL), execution_path_(flat), - activation_(activation), - enable_unknowns_(enable_unknowns), - enable_unknown_function_results_(enable_unknown_function_results), - enable_missing_attribute_errors_(enable_missing_attribute_errors), - attribute_utility_(&activation.unknown_attribute_patterns(), - &activation.missing_attribute_patterns(), - state->arena()), - max_iterations_(max_iterations), - iterations_(0), - state_(state) {} + value_stack_(&state.value_stack()), + iterator_stack_(&state.iterator_stack()), + subexpressions_() {} + + ExecutionFrame( + absl::Span subexpressions, + const cel::ActivationInterface& activation, + const cel::RuntimeOptions& options, FlatExpressionEvaluatorState& state, + EvaluationListener callback = EvaluationListener(), + const cel::EmbedderContext* absl_nullable embedder_context = nullptr) + : ExecutionFrameBase(activation, std::move(callback), options, + state.type_provider(), state.descriptor_pool(), + state.message_factory(), state.arena(), + embedder_context, state.comprehension_slots()), + pc_(0UL), + execution_path_(subexpressions[0]), + value_stack_(&state.value_stack()), + iterator_stack_(&state.iterator_stack()), + subexpressions_(subexpressions) { + ABSL_DCHECK(!subexpressions.empty()); + } // Returns next expression to evaluate. const ExpressionStep* Next(); - // Intended for use only in conditionals. + // Evaluate the execution frame to completion. + absl::StatusOr Evaluate(EvaluationListener& listener); + // Evaluate the execution frame to completion. + absl::StatusOr Evaluate() { return Evaluate(callback()); } + + // Intended for use in builtin shortcutting operations. + // + // Offset applies after normal pc increment. For example, JumpTo(0) is a + // no-op, JumpTo(1) skips the expected next step. absl::Status JumpTo(int offset) { + ABSL_DCHECK_LE(offset, static_cast(execution_path_.size())); + ABSL_DCHECK_GE(offset, -static_cast(pc_)); + int new_pc = static_cast(pc_) + offset; if (new_pc < 0 || new_pc > static_cast(execution_path_.size())) { return absl::Status(absl::StatusCode::kInternal, absl::StrCat("Jump address out of range: position: ", - pc_, ",offset: ", offset, + pc_, ", offset: ", offset, ", range: ", execution_path_.size())); } pc_ = static_cast(new_pc); return absl::OkStatus(); } - ValueStack& value_stack() { return state_->value_stack(); } - bool enable_unknowns() const { return enable_unknowns_; } - bool enable_unknown_function_results() const { - return enable_unknown_function_results_; - } - bool enable_missing_attribute_errors() const { - return enable_missing_attribute_errors_; + // Move pc to a subexpression. + // + // Unlike a `Call` in a programming language, the subexpression is evaluated + // in the same context as the caller (e.g. no stack isolation or scope change) + // + // Only intended for use in built-in notion of lazily evaluated + // subexpressions. + void Call(size_t slot_index, size_t subexpression_index) { + ABSL_DCHECK_LT(subexpression_index, subexpressions_.size()); + ExecutionPathView subexpression = subexpressions_[subexpression_index]; + ABSL_DCHECK(subexpression != execution_path_); + size_t return_pc = pc_; + // return pc == size() is supported (a tail call). + ABSL_DCHECK_LE(return_pc, execution_path_.size()); + call_stack_.push_back(SubFrame{return_pc, slot_index, execution_path_, + value_stack().size() + 1}); + pc_ = 0UL; + execution_path_ = subexpression; } - google::protobuf::Arena* arena() { return state_->arena(); } - const AttributeUtility& attribute_utility() const { - return attribute_utility_; - } + EvaluatorStack& value_stack() { return *value_stack_; } - // Returns reference to Activation - const BaseActivation& activation() const { return activation_; } - - // Creates a new frame for iteration variables. - absl::Status PushIterFrame(); + cel::runtime_internal::IteratorStack& iterator_stack() { + return *iterator_stack_; + } - // Discards the top frame for iteration variables. - absl::Status PopIterFrame(); + bool enable_attribute_tracking() const { + return attribute_tracking_enabled(); + } - // Sets the value of an iteration variable - absl::Status SetIterVar(const std::string& name, const CelValue& val); + bool enable_unknowns() const { return unknown_processing_enabled(); } - // Sets the value of an iteration variable - absl::Status SetIterVar(const std::string& name, const CelValue& val, - AttributeTrail trail); + bool enable_unknown_function_results() const { + return unknown_function_results_enabled(); + } - // Clears the value of an iteration variable - absl::Status ClearIterVar(const std::string& name); + bool enable_missing_attribute_errors() const { + return missing_attribute_errors_enabled(); + } - // Gets the current value of an iteration variable. - // Returns false if the variable is not currently in use (SetIterVar has been - // called since init or last clear). - bool GetIterVar(const std::string& name, CelValue* val) const; + bool enable_heterogeneous_numeric_lookups() const { + return options().enable_heterogeneous_equality; + } - // Gets the current value of an iteration variable. - // Returns false if the variable is not currently in use (SetIterVar has not - // been called since init or last clear). - bool GetIterAttr(const std::string& name, const AttributeTrail** val) const; + bool enable_comprehension_list_append() const { + return options().enable_comprehension_list_append; + } - // Increment iterations and return an error if the iteration budget is - // exceeded - absl::Status IncrementIterations() { - if (max_iterations_ == 0) { - return absl::OkStatus(); - } - iterations_++; - if (iterations_ >= max_iterations_) { - return absl::Status(absl::StatusCode::kInternal, - "Iteration budget exceeded"); - } - return absl::OkStatus(); + // Returns reference to the modern API activation. + const cel::ActivationInterface& modern_activation() const { + return *activation_; } private: + struct SubFrame { + size_t return_pc; + size_t slot_index; + ExecutionPathView return_expression; + size_t expected_stack_size; + }; + size_t pc_; // pc_ - Program Counter. Current position on execution path. - const ExecutionPath& execution_path_; - const BaseActivation& activation_; - bool enable_unknowns_; - bool enable_unknown_function_results_; - bool enable_missing_attribute_errors_; - AttributeUtility attribute_utility_; - const int max_iterations_; - int iterations_; - CelExpressionFlatEvaluationState* state_; + ExecutionPathView execution_path_; + EvaluatorStack* absl_nonnull const value_stack_; + cel::runtime_internal::IteratorStack* absl_nonnull const iterator_stack_; + absl::Span subexpressions_; + std::vector call_stack_; }; -// Implementation of the CelExpression that utilizes flattening -// of the expression tree. -class CelExpressionFlatImpl : public CelExpression { +// A flattened representation of the input CEL AST. +class FlatExpression { public: - // Constructs CelExpressionFlatImpl instance. - // path is flat execution path that is based upon - // flattened AST tree. Max iterations dictates the maximum number of - // iterations in the comprehension expressions (use 0 to disable the upper - // bound). - CelExpressionFlatImpl(const google::api::expr::v1alpha1::Expr* root_expr, - ExecutionPath path, int max_iterations, - std::set iter_variable_names, - bool enable_unknowns = false, - bool enable_unknown_function_results = false, - bool enable_missing_attribute_errors = false) + // path is flat execution path that is based upon the flattened AST tree + // type_provider is the configured type system that should be used for + // value creation in evaluation + FlatExpression(ExecutionPath path, size_t comprehension_slots_size, + const cel::TypeProvider& type_provider, + const cel::RuntimeOptions& options, + absl_nullable std::shared_ptr arena = nullptr) + : path_(std::move(path)), + subexpressions_({path_}), + comprehension_slots_size_(comprehension_slots_size), + type_provider_(type_provider), + options_(options), + arena_(std::move(arena)) {} + + FlatExpression(ExecutionPath path, + std::vector subexpressions, + size_t comprehension_slots_size, + const cel::TypeProvider& type_provider, + const cel::RuntimeOptions& options, + absl_nullable std::shared_ptr arena = nullptr) : path_(std::move(path)), - max_iterations_(max_iterations), - iter_variable_names_(std::move(iter_variable_names)), - enable_unknowns_(enable_unknowns), - enable_unknown_function_results_(enable_unknown_function_results), - enable_missing_attribute_errors_(enable_missing_attribute_errors) {} + subexpressions_(std::move(subexpressions)), + comprehension_slots_size_(comprehension_slots_size), + type_provider_(type_provider), + options_(options), + arena_(std::move(arena)) {} // Move-only - CelExpressionFlatImpl(const CelExpressionFlatImpl&) = delete; - CelExpressionFlatImpl& operator=(const CelExpressionFlatImpl&) = delete; - - std::unique_ptr InitializeState( - google::protobuf::Arena* arena) const override; - - // Implementation of CelExpression evaluate method. - absl::StatusOr Evaluate(const BaseActivation& activation, - google::protobuf::Arena* arena) const override { - return Evaluate(activation, InitializeState(arena).get()); + FlatExpression(FlatExpression&&) = default; + FlatExpression& operator=(FlatExpression&&) = delete; + + // Create new evaluator state instance with the configured options and type + // provider. + FlatExpressionEvaluatorState MakeEvaluatorState( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const; + + // Evaluate the expression. + // + // A status may be returned if an unexpected error occurs. Recoverable errors + // will be represented as a cel::ErrorValue result. + // + // If the listener is not empty, it will be called after each evaluation step + // that correlates to an AST node. The value passed to the will be the top of + // the evaluation stack, corresponding to the result of the subexpression. + absl::StatusOr EvaluateWithCallback( + const cel::ActivationInterface& activation, + const cel::EmbedderContext* absl_nullable embedder_context, + EvaluationListener listener, FlatExpressionEvaluatorState& state) const; + + const ExecutionPath& path() const { return path_; } + + absl::Span subexpressions() const { + return subexpressions_; } - absl::StatusOr Evaluate(const BaseActivation& activation, - CelEvaluationState* state) const override; + const cel::RuntimeOptions& options() const { return options_; } - // Implementation of CelExpression trace method. - absl::StatusOr Trace( - const BaseActivation& activation, google::protobuf::Arena* arena, - CelEvaluationListener callback) const override { - return Trace(activation, InitializeState(arena).get(), callback); - } + size_t comprehension_slots_size() const { return comprehension_slots_size_; } - absl::StatusOr Trace(const BaseActivation& activation, - CelEvaluationState* state, - CelEvaluationListener callback) const override; + const cel::TypeProvider& type_provider() const { return type_provider_; } private: - const ExecutionPath path_; - const int max_iterations_; - const std::set iter_variable_names_; - bool enable_unknowns_; - bool enable_unknown_function_results_; - bool enable_missing_attribute_errors_; + ExecutionPath path_; + std::vector subexpressions_; + size_t comprehension_slots_size_; + const cel::TypeProvider& type_provider_; + cel::RuntimeOptions options_; + // Arena used during planning phase, may hold constant values so should be + // kept alive. + absl_nullable std::shared_ptr arena_; }; -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_EVALUATOR_CORE_H_ diff --git a/eval/eval/evaluator_core_test.cc b/eval/eval/evaluator_core_test.cc index 5d2d7d6cb..8d61c4659 100644 --- a/eval/eval/evaluator_core_test.cc +++ b/eval/eval/evaluator_core_test.cc @@ -1,119 +1,91 @@ #include "eval/eval/evaluator_core.h" -#include "google/api/expr/v1alpha1/syntax.pb.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "eval/compiler/flat_expr_builder.h" -#include "eval/eval/attribute_trail.h" +#include +#include +#include + +#include "cel/expr/syntax.pb.h" +#include "absl/status/status.h" +#include "base/type_provider.h" +#include "common/value.h" +#include "eval/compiler/cel_expression_builder_flat_impl.h" +#include "eval/eval/cel_expression_flat_impl.h" +#include "eval/internal/interop.h" +#include "eval/public/activation.h" #include "eval/public/builtin_func_registrar.h" -#include "eval/public/cel_attribute.h" #include "eval/public/cel_value.h" -#include "base/status_macros.h" - -namespace google { -namespace api { -namespace expr { -namespace runtime { - -using google::api::expr::v1alpha1::Expr; +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "internal/testing_message_factory.h" +#include "runtime/activation.h" +#include "runtime/internal/runtime_env_testing.h" +#include "runtime/internal/runtime_type_provider.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" + +namespace google::api::expr::runtime { + +using ::cel::IntValue; +using ::cel::TypeProvider; +using ::cel::interop_internal::CreateIntValue; +using ::cel::runtime_internal::NewTestingRuntimeEnv; +using ::cel::expr::Expr; using ::google::api::expr::runtime::RegisterBuiltinFunctions; -using testing::_; -using testing::Eq; -using testing::NotNull; +using ::testing::_; +using ::testing::Eq; // Fake expression implementation -// Pushes int64_t(0) on top of value stack. +// Pushes int64(0) on top of value stack. class FakeConstExpressionStep : public ExpressionStep { public: + FakeConstExpressionStep() : ExpressionStep(0, true) {} + absl::Status Evaluate(ExecutionFrame* frame) const override { - frame->value_stack().Push(CelValue::CreateInt64(0)); + frame->value_stack().Push(CreateIntValue(0)); return absl::OkStatus(); } - - int64_t id() const override { return 0; } - - bool ComesFromAst() const override { return true; } }; // Fake expression implementation // Increments argument on top of the stack. class FakeIncrementExpressionStep : public ExpressionStep { public: + FakeIncrementExpressionStep() : ExpressionStep(0, true) {} + absl::Status Evaluate(ExecutionFrame* frame) const override { - CelValue value = frame->value_stack().Peek(); + auto value = frame->value_stack().Peek(); frame->value_stack().Pop(1); - EXPECT_TRUE(value.IsInt64()); - int64_t val = value.Int64OrDie(); - frame->value_stack().Push(CelValue::CreateInt64(val + 1)); + EXPECT_TRUE(value->Is()); + int64_t val = value.GetInt().NativeValue(); + frame->value_stack().Push(CreateIntValue(val + 1)); return absl::OkStatus(); } - - int64_t id() const override { return 0; } - - bool ComesFromAst() const override { return true; } }; -// Test Value Stack Push/Pop operation -TEST(EvaluatorCoreTest, ValueStackPushPop) { - google::protobuf::Arena arena; - google::api::expr::v1alpha1::Expr expr; - expr.mutable_ident_expr()->set_name("name"); - CelAttribute attribute(expr, {}); - ValueStack stack(10); - stack.Push(CelValue::CreateInt64(1)); - stack.Push(CelValue::CreateInt64(2), AttributeTrail()); - stack.Push(CelValue::CreateInt64(3), AttributeTrail(expr, &arena)); - - ASSERT_EQ(stack.Peek().Int64OrDie(), 3); - ASSERT_THAT(stack.PeekAttribute().attribute(), NotNull()); - ASSERT_EQ(*stack.PeekAttribute().attribute(), attribute); - - stack.Pop(1); - - ASSERT_EQ(stack.Peek().Int64OrDie(), 2); - ASSERT_EQ(stack.PeekAttribute().attribute(), nullptr); - - stack.Pop(1); - - ASSERT_EQ(stack.Peek().Int64OrDie(), 1); - ASSERT_EQ(stack.PeekAttribute().attribute(), nullptr); -} - -// Test that inner stacks within value stack retain the equality of their sizes. -TEST(EvaluatorCoreTest, ValueStackBalanced) { - ValueStack stack(10); - ASSERT_EQ(stack.size(), stack.attribute_size()); - - stack.Push(CelValue::CreateInt64(1)); - ASSERT_EQ(stack.size(), stack.attribute_size()); - stack.Push(CelValue::CreateInt64(2), AttributeTrail()); - stack.Push(CelValue::CreateInt64(3), AttributeTrail()); - ASSERT_EQ(stack.size(), stack.attribute_size()); - - stack.PopAndPush(CelValue::CreateInt64(4), AttributeTrail()); - ASSERT_EQ(stack.size(), stack.attribute_size()); - stack.PopAndPush(CelValue::CreateInt64(5)); - ASSERT_EQ(stack.size(), stack.attribute_size()); - - stack.Pop(3); - ASSERT_EQ(stack.size(), stack.attribute_size()); -} - TEST(EvaluatorCoreTest, ExecutionFrameNext) { ExecutionPath path; - auto const_step = absl::make_unique(); - auto incr_step1 = absl::make_unique(); - auto incr_step2 = absl::make_unique(); + google::protobuf::Arena arena; + cel::runtime_internal::RuntimeTypeProvider type_provider( + cel::internal::GetTestingDescriptorPool()); + auto const_step = std::make_unique(); + auto incr_step1 = std::make_unique(); + auto incr_step2 = std::make_unique(); path.push_back(std::move(const_step)); path.push_back(std::move(incr_step1)); path.push_back(std::move(incr_step2)); - auto dummy_expr = absl::make_unique(); + auto dummy_expr = std::make_unique(); - Activation activation; - CelExpressionFlatEvaluationState state(path.size(), {}, nullptr); - ExecutionFrame frame(path, activation, 0, &state, false, false, false); + cel::RuntimeOptions options; + options.unknown_processing = cel::UnknownProcessingOptions::kDisabled; + cel::Activation activation; + FlatExpressionEvaluatorState state( + path.size(), + /*comprehension_slots_size=*/0, type_provider, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena); + ExecutionFrame frame(path, activation, options, state); EXPECT_THAT(frame.Next(), Eq(path[0].get())); EXPECT_THAT(frame.Next(), Eq(path[1].get())); @@ -121,81 +93,21 @@ TEST(EvaluatorCoreTest, ExecutionFrameNext) { EXPECT_THAT(frame.Next(), Eq(nullptr)); } -// Test the set, get, and clear functions for "IterVar" on ExecutionFrame -TEST(EvaluatorCoreTest, ExecutionFrameSetGetClearVar) { - const std::string test_key = "test_key"; - const int64_t test_value = 0xF00F00; - - Activation activation; - google::protobuf::Arena arena; - ExecutionPath path; - CelExpressionFlatEvaluationState state(path.size(), {test_key}, nullptr); - ExecutionFrame frame(path, activation, 0, &state, false, false, false); - - CelValue original = CelValue::CreateInt64(test_value); - Expr ident; - ident.mutable_ident_expr()->set_name("var"); - - AttributeTrail original_trail = - AttributeTrail(ident, &arena) - .Step(CelAttributeQualifier::Create(CelValue::CreateInt64(1)), - &arena); - CelValue result; - const AttributeTrail* trail; - - ASSERT_OK(frame.PushIterFrame()); - - // Nothing is there yet - ASSERT_FALSE(frame.GetIterVar(test_key, &result)); - ASSERT_OK(frame.SetIterVar(test_key, original, original_trail)); - - // Make sure its now there - ASSERT_TRUE(frame.GetIterVar(test_key, &result)); - ASSERT_TRUE(frame.GetIterAttr(test_key, &trail)); - - int64_t result_value; - ASSERT_TRUE(result.GetValue(&result_value)); - EXPECT_EQ(test_value, result_value); - ASSERT_TRUE(trail->attribute()->variable().has_ident_expr()); - ASSERT_EQ(trail->attribute()->variable().ident_expr().name(), "var"); - - // Test that it goes away properly - ASSERT_OK(frame.ClearIterVar(test_key)); - ASSERT_FALSE(frame.GetIterVar(test_key, &result)); - ASSERT_FALSE(frame.GetIterAttr(test_key, &trail)); - - // Test that bogus names return the right thing - ASSERT_FALSE(frame.SetIterVar("foo", original).ok()); - ASSERT_FALSE(frame.ClearIterVar("bar").ok()); - - // Test error conditions for accesses outside of comprehension. - ASSERT_OK(frame.SetIterVar(test_key, original)); - ASSERT_OK(frame.PopIterFrame()); - - // Access on empty stack ok, but no value. - ASSERT_FALSE(frame.GetIterVar(test_key, &result)); - - // Pop empty stack - ASSERT_FALSE(frame.PopIterFrame().ok()); - - // Updates on empty stack not ok. - ASSERT_FALSE(frame.SetIterVar(test_key, original).ok()); - ASSERT_FALSE(frame.ClearIterVar(test_key).ok()); -} - TEST(EvaluatorCoreTest, SimpleEvaluatorTest) { ExecutionPath path; - auto const_step = absl::make_unique(); - auto incr_step1 = absl::make_unique(); - auto incr_step2 = absl::make_unique(); + auto const_step = std::make_unique(); + auto incr_step1 = std::make_unique(); + auto incr_step2 = std::make_unique(); path.push_back(std::move(const_step)); path.push_back(std::move(incr_step1)); path.push_back(std::move(incr_step2)); - auto dummy_expr = absl::make_unique(); - - CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), 0, {}); + auto env = NewTestingRuntimeEnv(); + CelExpressionFlatImpl impl( + env, FlatExpression(std::move(path), 0, + env->type_registry.GetComposedTypeProvider(), + cel::RuntimeOptions{})); Activation activation; google::protobuf::Arena arena; @@ -216,7 +128,7 @@ class MockTraceCallback { TEST(EvaluatorCoreTest, TraceTest) { Expr expr; - google::api::expr::v1alpha1::SourceInfo source_info; + cel::expr::SourceInfo source_info; // 1 && [1,2,3].all(x, x > 0) @@ -271,14 +183,12 @@ TEST(EvaluatorCoreTest, TraceTest) { result_expr->set_id(25); result_expr->mutable_const_expr()->set_bool_value(true); - FlatExprBuilder builder; - auto builtin_status = RegisterBuiltinFunctions(builder.GetRegistry()); - ASSERT_OK(builtin_status); - builder.set_shortcircuiting(false); - auto build_status = builder.CreateExpression(&expr, &source_info); - ASSERT_OK(build_status); - - auto cel_expr = std::move(build_status.value()); + cel::RuntimeOptions options; + options.short_circuiting = false; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder.CreateExpression(&expr, &source_info)); Activation activation; google::protobuf::Arena arena; @@ -311,7 +221,4 @@ TEST(EvaluatorCoreTest, TraceTest) { ASSERT_OK(eval_status); } -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime diff --git a/eval/eval/evaluator_stack.cc b/eval/eval/evaluator_stack.cc new file mode 100644 index 000000000..47c625dac --- /dev/null +++ b/eval/eval/evaluator_stack.cc @@ -0,0 +1,92 @@ +#include "eval/eval/evaluator_stack.h" + +#include +#include +#include +#include +#include + +#include "absl/base/dynamic_annotations.h" +#include "absl/base/nullability.h" +#include "absl/log/absl_log.h" +#include "common/value.h" +#include "eval/eval/attribute_trail.h" +#include "internal/new.h" + +namespace google::api::expr::runtime { + +void EvaluatorStack::Grow() { + const size_t new_max_size = std::max(max_size() * 2, size_t{1}); + ABSL_LOG(ERROR) << "evaluation stack is unexpectedly full: growing from " + << max_size() << " to " << new_max_size + << " as a last resort to avoid crashing: this should not " + "have happened so there must be a bug somewhere in " + "the planner or evaluator"; + Reserve(new_max_size); +} + +void EvaluatorStack::Reserve(size_t size) { + static_assert(alignof(cel::Value) <= __STDCPP_DEFAULT_NEW_ALIGNMENT__); + static_assert(alignof(AttributeTrail) <= __STDCPP_DEFAULT_NEW_ALIGNMENT__); + + if (max_size_ >= size) { + return; + } + + void* absl_nullability_unknown data = cel::internal::New(SizeBytes(size)); + + cel::Value* absl_nullability_unknown values_begin = + reinterpret_cast(data); + cel::Value* absl_nullability_unknown values = values_begin; + + AttributeTrail* absl_nullability_unknown attributes_begin = + reinterpret_cast(reinterpret_cast(data) + + AttributesBytesOffset(size)); + AttributeTrail* absl_nullability_unknown attributes = attributes_begin; + + if (max_size_ > 0) { + const size_t n = this->size(); + const size_t m = std::min(n, size); + + ABSL_ANNOTATE_CONTIGUOUS_CONTAINER(values_begin, values_begin + size, + values_begin + size, values + m); + ABSL_ANNOTATE_CONTIGUOUS_CONTAINER(attributes_begin, + attributes_begin + size, + attributes_begin + size, attributes + m); + + for (size_t i = 0; i < m; ++i) { + ::new (static_cast(values++)) + cel::Value(std::move(values_begin_[i])); + ::new (static_cast(attributes++)) + AttributeTrail(std::move(attributes_begin_[i])); + } + std::destroy_n(values_begin_, n); + std::destroy_n(attributes_begin_, n); + + ABSL_ANNOTATE_CONTIGUOUS_CONTAINER(values_begin_, values_begin_ + max_size_, + values_, values_begin_ + max_size_); + ABSL_ANNOTATE_CONTIGUOUS_CONTAINER( + attributes_begin_, attributes_begin_ + max_size_, attributes_, + attributes_begin_ + max_size_); + + cel::internal::SizedDelete(data_, SizeBytes(max_size_)); + } else { + ABSL_ANNOTATE_CONTIGUOUS_CONTAINER(values_begin, values_begin + size, + values_begin + size, values); + ABSL_ANNOTATE_CONTIGUOUS_CONTAINER(attributes_begin, + attributes_begin + size, + attributes_begin + size, attributes); + } + + values_ = values; + values_begin_ = values_begin; + values_end_ = values_begin + size; + + attributes_ = attributes; + attributes_begin_ = attributes_begin; + + data_ = data; + max_size_ = size; +} + +} // namespace google::api::expr::runtime diff --git a/eval/eval/evaluator_stack.h b/eval/eval/evaluator_stack.h new file mode 100644 index 000000000..b6abd1f76 --- /dev/null +++ b/eval/eval/evaluator_stack.h @@ -0,0 +1,327 @@ +#ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_EVALUATOR_STACK_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_EVAL_EVALUATOR_STACK_H_ + +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/dynamic_annotations.h" +#include "absl/base/nullability.h" +#include "absl/base/optimization.h" +#include "absl/log/absl_check.h" +#include "absl/meta/type_traits.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "common/value.h" +#include "eval/eval/attribute_trail.h" +#include "internal/align.h" +#include "internal/new.h" + +namespace google::api::expr::runtime { + +// CelValue stack. +// Implementation is based on vector to allow passing parameters from +// stack as Span<>. +class EvaluatorStack { + public: + explicit EvaluatorStack(size_t max_size) { Reserve(max_size); } + + EvaluatorStack(const EvaluatorStack&) = delete; + EvaluatorStack(EvaluatorStack&&) = delete; + + ~EvaluatorStack() { + if (max_size() > 0) { + const size_t n = size(); + std::destroy_n(values_begin_, n); + std::destroy_n(attributes_begin_, n); + cel::internal::SizedDelete(data_, SizeBytes(max_size_)); + } + } + + EvaluatorStack& operator=(const EvaluatorStack&) = delete; + EvaluatorStack& operator=(EvaluatorStack&&) = delete; + + // Return the current stack size. + size_t size() const { + ABSL_DCHECK_GE(values_, values_begin_); + ABSL_DCHECK_LE(values_, values_begin_ + max_size_); + ABSL_DCHECK_GE(attributes_, attributes_begin_); + ABSL_DCHECK_LE(attributes_, attributes_begin_ + max_size_); + ABSL_DCHECK_EQ(values_ - values_begin_, attributes_ - attributes_begin_); + + return values_ - values_begin_; + } + + // Return the maximum size of the stack. + size_t max_size() const { + ABSL_DCHECK_GE(values_, values_begin_); + ABSL_DCHECK_LE(values_, values_begin_ + max_size_); + ABSL_DCHECK_GE(attributes_, attributes_begin_); + ABSL_DCHECK_LE(attributes_, attributes_begin_ + max_size_); + ABSL_DCHECK_EQ(values_ - values_begin_, attributes_ - attributes_begin_); + + return max_size_; + } + + // Returns true if stack is empty. + bool empty() const { + ABSL_DCHECK_GE(values_, values_begin_); + ABSL_DCHECK_LE(values_, values_begin_ + max_size_); + ABSL_DCHECK_GE(attributes_, attributes_begin_); + ABSL_DCHECK_LE(attributes_, attributes_begin_ + max_size_); + ABSL_DCHECK_EQ(values_ - values_begin_, attributes_ - attributes_begin_); + + return values_ == values_begin_; + } + + bool full() const { + ABSL_DCHECK_GE(values_, values_begin_); + ABSL_DCHECK_LE(values_, values_begin_ + max_size_); + ABSL_DCHECK_GE(attributes_, attributes_begin_); + ABSL_DCHECK_LE(attributes_, attributes_begin_ + max_size_); + ABSL_DCHECK_EQ(values_ - values_begin_, attributes_ - attributes_begin_); + + return values_ == values_end_; + } + + // Attributes stack size. + ABSL_DEPRECATED("Use size()") + size_t attribute_size() const { return size(); } + + // Check that stack has enough elements. + bool HasEnough(size_t size) const { return this->size() >= size; } + + // Dumps the entire stack state as is. + void Clear() { + if (max_size() > 0) { + const size_t n = size(); + std::destroy_n(values_begin_, n); + std::destroy_n(attributes_begin_, n); + + ABSL_ANNOTATE_CONTIGUOUS_CONTAINER( + values_begin_, values_begin_ + max_size_, values_, values_begin_); + ABSL_ANNOTATE_CONTIGUOUS_CONTAINER(attributes_begin_, + attributes_begin_ + max_size_, + attributes_, attributes_begin_); + + values_ = values_begin_; + attributes_ = attributes_begin_; + } + } + + // Gets the last size elements of the stack. + // Checking that stack has enough elements is caller's responsibility. + // Please note that calls to Push may invalidate returned Span object. + absl::Span GetSpan(size_t size) const { + ABSL_DCHECK(HasEnough(size)); + + return absl::Span(values_ - size, size); + } + + // Gets the last size attribute trails of the stack. + // Checking that stack has enough elements is caller's responsibility. + // Please note that calls to Push may invalidate returned Span object. + absl::Span GetAttributeSpan(size_t size) const { + ABSL_DCHECK(HasEnough(size)); + + return absl::Span(attributes_ - size, size); + } + + // Peeks the last element of the stack. + // Checking that stack is not empty is caller's responsibility. + cel::Value& Peek() { + ABSL_DCHECK(HasEnough(1)); + + return *(values_ - 1); + } + + // Peeks the last element of the stack. + // Checking that stack is not empty is caller's responsibility. + const cel::Value& Peek() const { + ABSL_DCHECK(HasEnough(1)); + + return *(values_ - 1); + } + + // Peeks the last element of the attribute stack. + // Checking that stack is not empty is caller's responsibility. + const AttributeTrail& PeekAttribute() const { + ABSL_DCHECK(HasEnough(1)); + + return *(attributes_ - 1); + } + + // Peeks the last element of the attribute stack. + // Checking that stack is not empty is caller's responsibility. + AttributeTrail& PeekAttribute() { + ABSL_DCHECK(HasEnough(1)); + + return *(attributes_ - 1); + } + + void Pop() { + ABSL_DCHECK(!empty()); + + --values_; + values_->~Value(); + --attributes_; + attributes_->~AttributeTrail(); + + ABSL_ANNOTATE_CONTIGUOUS_CONTAINER(values_begin_, values_begin_ + max_size_, + values_ + 1, values_); + ABSL_ANNOTATE_CONTIGUOUS_CONTAINER(attributes_begin_, + attributes_begin_ + max_size_, + attributes_ + 1, attributes_); + } + + // Clears the last size elements of the stack. + // Checking that stack has enough elements is caller's responsibility. + void Pop(size_t size) { + ABSL_DCHECK(HasEnough(size)); + + for (; size > 0; --size) { + Pop(); + } + } + + template , + std::is_convertible>>> + void Push(V&& value, A&& attribute) { + ABSL_DCHECK(!full()); + + if (ABSL_PREDICT_FALSE(full())) { + Grow(); + } + + ABSL_ANNOTATE_CONTIGUOUS_CONTAINER(values_begin_, values_begin_ + max_size_, + values_, values_ + 1); + ABSL_ANNOTATE_CONTIGUOUS_CONTAINER(attributes_begin_, + attributes_begin_ + max_size_, + attributes_, attributes_ + 1); + + ::new (static_cast(values_++)) cel::Value(std::forward(value)); + ::new (static_cast(attributes_++)) + AttributeTrail(std::forward(attribute)); + } + + template >> + void Push(V&& value) { + ABSL_DCHECK(!full()); + + Push(std::forward(value), absl::nullopt); + } + + // Equivalent to `PopAndPush(1, ...)`. + template , + std::is_convertible>>> + void PopAndPush(V&& value, A&& attribute) { + ABSL_DCHECK(!empty()); + + *(values_ - 1) = std::forward(value); + *(attributes_ - 1) = std::forward(attribute); + } + + // Equivalent to `PopAndPush(1, ...)`. + template >> + void PopAndPush(V&& value) { + ABSL_DCHECK(!empty()); + + PopAndPush(std::forward(value), absl::nullopt); + } + + // Equivalent to `Pop(n)` followed by `Push(...)`. Both `V` and `A` MUST NOT + // be located on the stack. If this is the case, use SwapAndPop instead. + template , + std::is_convertible>>> + void PopAndPush(size_t n, V&& value, A&& attribute) { + if (n > 0) { + if constexpr (std::is_same_v>) { + ABSL_DCHECK(&value < values_begin_ || + &value >= values_begin_ + max_size_) + << "Attmpting to push a value about to be popped, use PopAndSwap " + "instead."; + } + if constexpr (std::is_same_v>) { + ABSL_DCHECK(&attribute < attributes_begin_ || + &attribute >= attributes_begin_ + max_size_) + << "Attmpting to push an attribute about to be popped, use " + "PopAndSwap instead."; + } + + Pop(n - 1); + + ABSL_DCHECK(!empty()); + + *(values_ - 1) = std::forward(value); + *(attributes_ - 1) = std::forward(attribute); + } else { + Push(std::forward(value), std::forward(attribute)); + } + } + + // Equivalent to `Pop(n)` followed by `Push(...)`. `V` MUST NOT be located on + // the stack. If this is the case, use SwapAndPop instead. + template >> + void PopAndPush(size_t n, V&& value) { + PopAndPush(n, std::forward(value), absl::nullopt); + } + + // Swaps the `n - i` element (from the top of the stack) with the `n` element, + // and pops `n - 1` elements. This results in the `n - i` element being at the + // top of the stack. + void SwapAndPop(size_t n, size_t i) { + ABSL_DCHECK_GT(n, 0); + ABSL_DCHECK_LT(i, n); + ABSL_DCHECK(HasEnough(n - 1)); + + using std::swap; + + if (i > 0) { + swap(*(values_ - n), *(values_ - n + i)); + swap(*(attributes_ - n), *(attributes_ - n + i)); + } + Pop(n - 1); + } + + // Update the max size of the stack and update capacity if needed. + void SetMaxSize(size_t size) { Reserve(size); } + + private: + static size_t AttributesBytesOffset(size_t size) { + return cel::internal::AlignUp(sizeof(cel::Value) * size, + __STDCPP_DEFAULT_NEW_ALIGNMENT__); + } + + static size_t SizeBytes(size_t size) { + return AttributesBytesOffset(size) + (sizeof(AttributeTrail) * size); + } + + void Grow(); + + // Preallocate stack. + void Reserve(size_t size); + + cel::Value* absl_nullability_unknown values_ = nullptr; + cel::Value* absl_nullability_unknown values_begin_ = nullptr; + AttributeTrail* absl_nullability_unknown attributes_ = nullptr; + AttributeTrail* absl_nullability_unknown attributes_begin_ = nullptr; + cel::Value* absl_nullability_unknown values_end_ = nullptr; + void* absl_nullability_unknown data_ = nullptr; + size_t max_size_ = 0; +}; + +} // namespace google::api::expr::runtime + +#endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_EVALUATOR_STACK_H_ diff --git a/eval/eval/evaluator_stack_test.cc b/eval/eval/evaluator_stack_test.cc new file mode 100644 index 000000000..9ce862d8a --- /dev/null +++ b/eval/eval/evaluator_stack_test.cc @@ -0,0 +1,70 @@ +#include "eval/eval/evaluator_stack.h" + +#include "base/attribute.h" +#include "common/value.h" +#include "internal/testing.h" + +namespace google::api::expr::runtime { + +namespace { + +// Test Value Stack Push/Pop operation +TEST(EvaluatorStackTest, StackPushPop) { + cel::Attribute attribute("name", {}); + EvaluatorStack stack(10); + stack.Push(cel::IntValue(1)); + stack.Push(cel::IntValue(2), AttributeTrail()); + stack.Push(cel::IntValue(3), AttributeTrail("name")); + + ASSERT_EQ(stack.Peek().GetInt().NativeValue(), 3); + ASSERT_FALSE(stack.PeekAttribute().empty()); + ASSERT_EQ(stack.PeekAttribute().attribute(), attribute); + + stack.Pop(1); + + ASSERT_EQ(stack.Peek().GetInt().NativeValue(), 2); + ASSERT_TRUE(stack.PeekAttribute().empty()); + + stack.Pop(1); + + ASSERT_EQ(stack.Peek().GetInt().NativeValue(), 1); + ASSERT_TRUE(stack.PeekAttribute().empty()); +} + +// Test that inner stacks within value stack retain the equality of their sizes. +TEST(EvaluatorStackTest, StackBalanced) { + EvaluatorStack stack(10); + ASSERT_EQ(stack.size(), stack.attribute_size()); + + stack.Push(cel::IntValue(1)); + ASSERT_EQ(stack.size(), stack.attribute_size()); + stack.Push(cel::IntValue(2), AttributeTrail()); + stack.Push(cel::IntValue(3), AttributeTrail()); + ASSERT_EQ(stack.size(), stack.attribute_size()); + + stack.PopAndPush(cel::IntValue(4), AttributeTrail()); + ASSERT_EQ(stack.size(), stack.attribute_size()); + stack.PopAndPush(cel::IntValue(5)); + ASSERT_EQ(stack.size(), stack.attribute_size()); + + stack.Pop(3); + ASSERT_EQ(stack.size(), stack.attribute_size()); +} + +TEST(EvaluatorStackTest, Clear) { + EvaluatorStack stack(10); + ASSERT_EQ(stack.size(), stack.attribute_size()); + + stack.Push(cel::IntValue(1)); + stack.Push(cel::IntValue(2), AttributeTrail()); + stack.Push(cel::IntValue(3), AttributeTrail()); + ASSERT_EQ(stack.size(), 3); + + stack.Clear(); + ASSERT_EQ(stack.size(), 0); + ASSERT_TRUE(stack.empty()); +} + +} // namespace + +} // namespace google::api::expr::runtime diff --git a/eval/eval/expression_build_warning.cc b/eval/eval/expression_build_warning.cc deleted file mode 100644 index 59a54651a..000000000 --- a/eval/eval/expression_build_warning.cc +++ /dev/null @@ -1,19 +0,0 @@ -#include "eval/eval/expression_build_warning.h" - -namespace google { -namespace api { -namespace expr { -namespace runtime { - -absl::Status BuilderWarnings::AddWarning(const absl::Status& warning) { - if (fail_immediately_) { - return warning; - } - warnings_.push_back(warning); - return absl::OkStatus(); -} - -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google diff --git a/eval/eval/expression_build_warning.h b/eval/eval/expression_build_warning.h deleted file mode 100644 index 20575abe9..000000000 --- a/eval/eval/expression_build_warning.h +++ /dev/null @@ -1,36 +0,0 @@ -#ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_EXPRESSION_BUILD_WARNING_H_ -#define THIRD_PARTY_CEL_CPP_EVAL_EVAL_EXPRESSION_BUILD_WARNING_H_ - -#include - -#include "absl/status/status.h" - -namespace google { -namespace api { -namespace expr { -namespace runtime { - -// Container for recording warnings. -class BuilderWarnings { - public: - explicit BuilderWarnings(bool fail_immediately = false) - : fail_immediately_(fail_immediately) {} - - // Add a warning. Returns the util:Status immediately if fail on warning is - // set. - absl::Status AddWarning(const absl::Status& warning); - - // Return the list of recorded warnings. - const std::vector& warnings() const { return warnings_; } - - private: - std::vector warnings_; - bool fail_immediately_; -}; - -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google - -#endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_EXPRESSION_BUILD_WARNING_H_ diff --git a/eval/eval/expression_build_warning_test.cc b/eval/eval/expression_build_warning_test.cc deleted file mode 100644 index 212b2e5ae..000000000 --- a/eval/eval/expression_build_warning_test.cc +++ /dev/null @@ -1,36 +0,0 @@ -#include "eval/eval/expression_build_warning.h" - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "absl/status/status.h" - -namespace google { -namespace api { -namespace expr { -namespace runtime { -namespace { - - -TEST(BuilderWarnings, NoFailCollects) { - BuilderWarnings warnings(false); - - auto status = warnings.AddWarning(absl::InternalError("internal")); - EXPECT_TRUE(status.ok()); - auto status2 = warnings.AddWarning(absl::InternalError("internal error 2")); - EXPECT_TRUE(status2.ok()); - - EXPECT_THAT(warnings.warnings(), testing::SizeIs(2)); -} - -TEST(BuilderWarnings, FailReturnsStatus) { - BuilderWarnings warnings(true); - - EXPECT_EQ(warnings.AddWarning(absl::InternalError("internal")).code(), - absl::StatusCode::kInternal); -} - -} // namespace -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google diff --git a/eval/eval/expression_step_base.h b/eval/eval/expression_step_base.h index 0fac4d147..5b2f72f8e 100644 --- a/eval/eval/expression_step_base.h +++ b/eval/eval/expression_step_base.h @@ -1,38 +1,12 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_EXPRESSION_STEP_BASE_H_ #define THIRD_PARTY_CEL_CPP_EVAL_EVAL_EXPRESSION_STEP_BASE_H_ -#include - #include "eval/eval/evaluator_core.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { - -class ExpressionStepBase : public ExpressionStep { - public: - explicit ExpressionStepBase(int64_t expr_id, bool comes_from_ast = true) - : id_(expr_id), comes_from_ast_(comes_from_ast) {} - - // Non-copyable - ExpressionStepBase(const ExpressionStepBase&) = delete; - ExpressionStepBase& operator=(const ExpressionStepBase&) = delete; - - // Returns corresponding expression object ID. - int64_t id() const override { return id_; } - - // Returns if the execution step comes from AST. - bool ComesFromAst() const override { return comes_from_ast_; } +namespace google::api::expr::runtime { - private: - int64_t id_; - bool comes_from_ast_; -}; +using ExpressionStepBase = ExpressionStep; -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_EXPRESSION_STEP_BASE_H_ diff --git a/eval/eval/function_step.cc b/eval/eval/function_step.cc index 35a2f47e2..12c5af8a7 100644 --- a/eval/eval/function_step.cc +++ b/eval/eval/function_step.cc @@ -8,74 +8,117 @@ #include #include -#include "google/api/expr/v1alpha1/syntax.pb.h" -#include "google/protobuf/arena.h" +#include "absl/container/inlined_vector.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "absl/types/span.h" +#include "common/casting.h" +#include "common/expr.h" +#include "common/function_descriptor.h" +#include "common/kind.h" +#include "common/value.h" +#include "common/value_kind.h" #include "eval/eval/attribute_trail.h" +#include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" -#include "eval/eval/expression_build_warning.h" #include "eval/eval/expression_step_base.h" -#include "eval/public/activation.h" -#include "eval/public/cel_builtins.h" -#include "eval/public/cel_function.h" -#include "eval/public/cel_function_provider.h" -#include "eval/public/cel_function_registry.h" -#include "eval/public/cel_value.h" -#include "eval/public/unknown_attribute_set.h" -#include "eval/public/unknown_function_result_set.h" -#include "eval/public/unknown_set.h" -#include "base/status_macros.h" -#include "absl/status/statusor.h" +#include "eval/internal/errors.h" +#include "internal/status_macros.h" +#include "runtime/activation_interface.h" +#include "runtime/function.h" +#include "runtime/function_overload_reference.h" +#include "runtime/function_provider.h" +#include "runtime/function_registry.h" +#include "runtime/internal/errors.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace google::api::expr::runtime { namespace { -// Non-strict functions are allowed to consume errors and UnknownSets. Currently -// only the special function "@not_strictly_false" is allowed to do this. -bool IsNonStrict(const std::string& name) { - return (name == builtin::kNotStrictlyFalse || - name == builtin::kNotStrictlyFalseDeprecated); -} +using ::cel::ErrorValue; +using ::cel::UnknownValue; +using ::cel::Value; +using ::cel::ValueKindToKind; +using ::cel::runtime_internal::CreateNoMatchingOverloadError; // Determine if the overload should be considered. Overloads that can consume // errors or unknown sets must be allowed as a non-strict function. -bool ShouldAcceptOverload(const CelFunction* function, - absl::Span arguments) { - if (function == nullptr) { +bool ShouldAcceptOverload(const cel::FunctionDescriptor& descriptor, + absl::Span arguments) { + for (size_t i = 0; i < arguments.size(); i++) { + if (arguments[i]->Is() || + arguments[i]->Is()) { + return !descriptor.is_strict(); + } + } + return true; +} + +bool ArgumentKindsMatch(const cel::FunctionDescriptor& descriptor, + absl::Span arguments) { + auto types_size = descriptor.types().size(); + + if (types_size != arguments.size()) { return false; } - for (size_t i = 0; i < arguments.size(); i++) { - if (arguments[i].IsUnknownSet() || arguments[i].IsError()) { - return IsNonStrict(function->descriptor().name()); + + for (size_t i = 0; i < types_size; i++) { + const auto& arg = arguments[i]; + cel::Kind param_kind = descriptor.types()[i]; + if (arg->kind() != param_kind && param_kind != cel::Kind::kAny) { + return false; } } + return true; } +// Adjust new type names to legacy equivalent. int -> int64. +// Temporary fix to migrate value types without breaking clients. +// TODO(uncreated-issue/46): Update client tests that depend on this value. +std::string ToLegacyKindName(absl::string_view type_name) { + if (type_name == "int" || type_name == "uint") { + return absl::StrCat(type_name, "64"); + } + + return std::string(type_name); +} + +std::string CallArgTypeString(absl::Span args) { + std::string call_sig_string = ""; + + for (size_t i = 0; i < args.size(); i++) { + const auto& arg = args[i]; + if (!call_sig_string.empty()) { + absl::StrAppend(&call_sig_string, ", "); + } + absl::StrAppend( + &call_sig_string, + ToLegacyKindName(cel::KindToString(ValueKindToKind(arg->kind())))); + } + return absl::StrCat("(", call_sig_string, ")"); +} + // Convert partially unknown arguments to unknowns before passing to the // function. // TODO(issues/52): See if this can be refactored to remove the eager // arguments copy. // Argument and attribute spans are expected to be equal length. -std::vector CheckForPartialUnknowns( - ExecutionFrame* frame, absl::Span args, +std::vector CheckForPartialUnknowns( + ExecutionFrame* frame, absl::Span args, absl::Span attrs) { - std::vector result; + std::vector result; result.reserve(args.size()); for (size_t i = 0; i < args.size(); i++) { - auto attr_set = frame->attribute_utility().CheckForUnknowns( - attrs.subspan(i, 1), /*use_partial=*/true); - if (!attr_set.attributes().empty()) { - auto unknown_set = google::protobuf::Arena::Create(frame->arena(), - std::move(attr_set)); - result.push_back(CelValue::CreateUnknownSet(unknown_set)); + const AttributeTrail& trail = attrs.subspan(i, 1)[0]; + + if (frame->attribute_utility().CheckForUnknown(trail, + /*use_partial=*/true)) { + result.push_back( + frame->attribute_utility().CreateUnknownSet(trail.attribute())); } else { result.push_back(args.at(i)); } @@ -84,6 +127,25 @@ std::vector CheckForPartialUnknowns( return result; } +bool IsUnknownFunctionResultError(const Value& result) { + if (!result->Is()) { + return false; + } + + const auto& status = result.GetError().NativeValue(); + + if (status.code() != absl::StatusCode::kUnavailable) { + return false; + } + auto payload = status.GetPayload( + cel::runtime_internal::kPayloadUrlUnknownFunctionResult); + return payload.has_value() && payload.value() == "true"; +} + +// Simple wrapper around a function resolution result. A function call should +// resolve to a single function implementation and a descriptor or none. +using ResolveResult = absl::optional; + // Implementation of ExpressionStep that finds suitable CelFunction overload and // invokes it. Abstract base class standardizes behavior between lazy and eager // function bindings. Derived classes provide ResolveFunction behavior. @@ -91,29 +153,100 @@ class AbstractFunctionStep : public ExpressionStepBase { public: // Constructs FunctionStep that uses overloads specified. AbstractFunctionStep(const std::string& name, size_t num_arguments, - int64_t expr_id) + bool receiver_style, int64_t expr_id) : ExpressionStepBase(expr_id), name_(name), - num_arguments_(num_arguments) {} + num_arguments_(num_arguments), + receiver_style_(receiver_style) {} absl::Status Evaluate(ExecutionFrame* frame) const override; - absl::Status DoEvaluate(ExecutionFrame* frame, CelValue* result) const; + // Handles overload resolution and updating result appropriately. + // Shouldn't update frame state. + // + // A non-ok result is an unrecoverable error, either from an illegal + // evaluation state or forwarded from an extension function. Errors where + // evaluation can reasonably condition are returned in the result as a + // cel::ErrorValue. + absl::StatusOr DoEvaluate(ExecutionFrame* frame) const; - virtual absl::StatusOr ResolveFunction( - absl::Span args, const ExecutionFrame* frame) const = 0; + virtual absl::StatusOr ResolveFunction( + absl::Span args, const ExecutionFrame* frame) const = 0; protected: std::string name_; size_t num_arguments_; + bool receiver_style_; }; -absl::Status AbstractFunctionStep::DoEvaluate(ExecutionFrame* frame, - CelValue* result) const { +inline absl::StatusOr Invoke( + const cel::FunctionOverloadReference& overload, int64_t expr_id, + absl::Span args, ExecutionFrameBase& frame) { + cel::Function::InvokeContext context(frame.descriptor_pool(), + frame.message_factory(), frame.arena()); + if (overload.descriptor.is_contextual()) { + context.set_embedder_context(frame.embedder_context()); + } + + CEL_ASSIGN_OR_RETURN(Value result, + overload.implementation.Invoke(args, context)); + + if (frame.unknown_function_results_enabled() && + IsUnknownFunctionResultError(result)) { + return frame.attribute_utility().CreateUnknownSet(overload.descriptor, + expr_id, args); + } + return result; +} + +Value NoOverloadResult(absl::string_view name, + absl::Span args, bool receiver_style, + ExecutionFrameBase& frame) { + // No matching overloads. + // Such absence can be caused by presence of CelError in arguments. + // To enable behavior of functions that accept CelError( &&, || ), CelErrors + // should be propagated along execution path. + for (size_t i = 0; i < args.size(); i++) { + const auto& arg = args[i]; + if (cel::InstanceOf(arg)) { + return arg; + } + } + + if (frame.unknown_processing_enabled()) { + // Already converted partial unknowns to unknown sets so just merge. + absl::optional unknown_set = + frame.attribute_utility().MergeUnknowns(args); + if (unknown_set.has_value()) { + return *unknown_set; + } + } + + // If no errors or unknowns in input args, create new CelError for missing + // overload. + std::string signature; + if (receiver_style) { + if (args.empty()) { + // Should not be possible, but return a sensible error in case of logic + // error. + return ErrorValue( + CreateNoMatchingOverloadError(absl::StrCat("().", name, "()"))); + } + return ErrorValue(CreateNoMatchingOverloadError(absl::StrCat( + "(", + ToLegacyKindName(cel::KindToString(ValueKindToKind(args[0].kind()))), + ").", name, CallArgTypeString(args.subspan(1))))); + } + return cel::ErrorValue(CreateNoMatchingOverloadError( + absl::StrCat(name, CallArgTypeString(args)))); +} + +absl::StatusOr AbstractFunctionStep::DoEvaluate( + ExecutionFrame* frame) const { // Create Span object that contains input arguments to the function. auto input_args = frame->value_stack().GetSpan(num_arguments_); - std::vector unknowns_args; + std::vector unknowns_args; // Preprocess args. If an argument is partially unknown, convert it to an // unknown attribute set. if (frame->enable_unknowns()) { @@ -123,57 +256,16 @@ absl::Status AbstractFunctionStep::DoEvaluate(ExecutionFrame* frame, } // Derived class resolves to a single function overload or none. - auto status = ResolveFunction(input_args, frame); - if (!status.ok()) { - return status.status(); - } - const CelFunction* matched_function = status.value(); + CEL_ASSIGN_OR_RETURN(ResolveResult matched_function, + ResolveFunction(input_args, frame)); // Overload found and is allowed to consume the arguments. - if (ShouldAcceptOverload(matched_function, input_args)) { - absl::Status status = - matched_function->Evaluate(input_args, result, frame->arena()); - if (!status.ok()) { - return status; - } - if (frame->enable_unknown_function_results() && - IsUnknownFunctionResult(*result)) { - const auto* function_result = - google::protobuf::Arena::Create( - frame->arena(), matched_function->descriptor(), id(), - std::vector(input_args.begin(), input_args.end())); - const auto* unknown_set = google::protobuf::Arena::Create( - frame->arena(), UnknownFunctionResultSet(function_result)); - *result = CelValue::CreateUnknownSet(unknown_set); - } - } else { - // No matching overloads. - // We should not treat absense of overloads as non-recoverable error. - // Such absence can be caused by presence of CelError in arguments. - // To enable behavior of functions that accept CelError( &&, || ), CelErrors - // should be propagated along execution path. - for (const CelValue& arg : input_args) { - if (arg.IsError()) { - *result = arg; - return absl::OkStatus(); - } - } - - if (frame->enable_unknowns()) { - // Already converted partial unknowns to unknown sets so just merge. - auto unknown_set = - frame->attribute_utility().MergeUnknowns(input_args, nullptr); - if (unknown_set != nullptr) { - *result = CelValue::CreateUnknownSet(unknown_set); - return absl::OkStatus(); - } - } - - // If no errors or unknowns in input args, create new CelError. - *result = CreateNoMatchingOverloadError(frame->arena()); + if (matched_function.has_value() && + ShouldAcceptOverload(matched_function->descriptor, input_args)) { + return Invoke(*matched_function, id(), input_args, *frame); } - return absl::OkStatus(); + return NoOverloadResult(name_, input_args, receiver_style_, *frame); } absl::Status AbstractFunctionStep::Evaluate(ExecutionFrame* frame) const { @@ -181,139 +273,257 @@ absl::Status AbstractFunctionStep::Evaluate(ExecutionFrame* frame) const { return absl::Status(absl::StatusCode::kInternal, "Value stack underflow"); } - CelValue result; - auto status = DoEvaluate(frame, &result); - if (!status.ok()) { - return status; - } + // DoEvaluate may return a status for non-recoverable errors (e.g. + // unexpected typing, illegal expression state). Application errors that can + // reasonably be handled as a cel error will appear in the result value. + CEL_ASSIGN_OR_RETURN(auto result, DoEvaluate(frame)); - frame->value_stack().Pop(num_arguments_); - frame->value_stack().Push(result); + frame->value_stack().PopAndPush(num_arguments_, std::move(result)); return absl::OkStatus(); } -class EagerFunctionStep : public AbstractFunctionStep { - public: - EagerFunctionStep(std::vector&& overloads, - const std::string& name, size_t num_args, int64_t expr_id) - : AbstractFunctionStep(name, num_args, expr_id), overloads_(overloads) {} +absl::StatusOr ResolveStatic( + absl::Span input_args, + absl::Span overloads) { + for (const auto& overload : overloads) { + if (ArgumentKindsMatch(overload.descriptor, input_args)) { + return overload; + } + } + return std::nullopt; +} - absl::StatusOr ResolveFunction( - absl::Span input_args, - const ExecutionFrame* frame) const override; +absl::StatusOr ResolveLazy( + absl::Span input_args, absl::string_view name, + bool receiver_style, + absl::Span providers, + const ExecutionFrameBase& frame) { + ResolveResult result = std::nullopt; - private: - std::vector overloads_; -}; + std::vector arg_types(input_args.size()); -absl::StatusOr EagerFunctionStep::ResolveFunction( - absl::Span input_args, const ExecutionFrame* frame) const { - const CelFunction* matched_function = nullptr; + std::transform( + input_args.begin(), input_args.end(), arg_types.begin(), + [](const cel::Value& value) { return ValueKindToKind(value->kind()); }); - for (auto overload : overloads_) { - if (overload->MatchArguments(input_args)) { + cel::FunctionDescriptor matcher{name, receiver_style, std::move(arg_types)}; + + const cel::ActivationInterface& activation = frame.activation(); + for (auto provider : providers) { + // The LazyFunctionStep has so far only resolved by function shape, check + // that the runtime argument kinds agree with the specific descriptor for + // the provider candidates. + if (!ArgumentKindsMatch(provider.descriptor, input_args)) { + continue; + } + + CEL_ASSIGN_OR_RETURN(auto overload, + provider.provider.GetFunction(matcher, activation)); + if (overload.has_value()) { // More than one overload matches our arguments. - if (matched_function != nullptr) { + if (result.has_value()) { return absl::Status(absl::StatusCode::kInternal, "Cannot resolve overloads"); } - matched_function = overload; + result.emplace(overload.value()); } } - return matched_function; + + return result; } +class EagerFunctionStep : public AbstractFunctionStep { + public: + EagerFunctionStep(std::vector overloads, + const std::string& name, size_t num_args, + bool receiver_style, int64_t expr_id) + : AbstractFunctionStep(name, num_args, receiver_style, expr_id), + overloads_(std::move(overloads)) {} + + absl::StatusOr ResolveFunction( + absl::Span input_args, + const ExecutionFrame* frame) const override { + return ResolveStatic(input_args, overloads_); + } + + private: + std::vector overloads_; +}; + class LazyFunctionStep : public AbstractFunctionStep { public: // Constructs LazyFunctionStep that attempts to lookup function implementation // at runtime. LazyFunctionStep(const std::string& name, size_t num_args, bool receiver_style, - const std::vector& providers, + std::vector providers, int64_t expr_id) - : AbstractFunctionStep(name, num_args, expr_id), - receiver_style_(receiver_style), - providers_(providers) {} + : AbstractFunctionStep(name, num_args, receiver_style, expr_id), + providers_(std::move(providers)) {} - absl::StatusOr ResolveFunction( - absl::Span input_args, + absl::StatusOr ResolveFunction( + absl::Span input_args, const ExecutionFrame* frame) const override; private: - bool receiver_style_; - std::vector providers_; + std::vector providers_; }; -absl::StatusOr LazyFunctionStep::ResolveFunction( - absl::Span input_args, const ExecutionFrame* frame) const { - const CelFunction* matched_function = nullptr; +absl::StatusOr LazyFunctionStep::ResolveFunction( + absl::Span input_args, + const ExecutionFrame* frame) const { + return ResolveLazy(input_args, name_, receiver_style_, providers_, *frame); +} + +class StaticResolver { + public: + explicit StaticResolver(std::vector overloads) + : overloads_(std::move(overloads)) {} + + absl::StatusOr Resolve(ExecutionFrameBase& frame, + absl::Span input) const { + return ResolveStatic(input, overloads_); + } - std::vector arg_types(num_arguments_); + private: + std::vector overloads_; +}; - std::transform(input_args.begin(), input_args.end(), arg_types.begin(), - [](const CelValue& value) { return value.type(); }); +class LazyResolver { + public: + explicit LazyResolver( + std::vector providers, + std::string name, bool receiver_style) + : providers_(std::move(providers)), + name_(std::move(name)), + receiver_style_(receiver_style) {} + + absl::StatusOr Resolve(ExecutionFrameBase& frame, + absl::Span input) const { + return ResolveLazy(input, name_, receiver_style_, providers_, frame); + } - CelFunctionDescriptor matcher{name_, receiver_style_, arg_types}; + private: + std::vector providers_; + std::string name_; + bool receiver_style_; +}; - const BaseActivation& activation = frame->activation(); - for (auto provider : providers_) { - auto status = provider->GetFunction(matcher, activation); - if (!status.ok()) { - return status; +template +class DirectFunctionStepImpl : public DirectExpressionStep { + public: + DirectFunctionStepImpl( + int64_t expr_id, const std::string& name, + std::vector> arg_steps, + bool receiver_style, Resolver&& resolver) + : DirectExpressionStep(expr_id), + name_(name), + arg_steps_(std::move(arg_steps)), + receiver_style_(receiver_style), + resolver_(std::forward(resolver)) {} + + absl::Status Evaluate(ExecutionFrameBase& frame, cel::Value& result, + AttributeTrail& trail) const override { + absl::InlinedVector args; + absl::InlinedVector arg_trails; + + args.resize(arg_steps_.size()); + arg_trails.resize(arg_steps_.size()); + + for (size_t i = 0; i < arg_steps_.size(); i++) { + CEL_RETURN_IF_ERROR( + arg_steps_[i]->Evaluate(frame, args[i], arg_trails[i])); } - auto overload = status.value(); - if (overload != nullptr && overload->MatchArguments(input_args)) { - // More than one overload matches our arguments. - if (matched_function != nullptr) { - return absl::Status(absl::StatusCode::kInternal, - "Cannot resolve overloads"); + + if (frame.unknown_processing_enabled()) { + for (size_t i = 0; i < arg_trails.size(); i++) { + if (frame.attribute_utility().CheckForUnknown(arg_trails[i], + /*use_partial=*/true)) { + args[i] = frame.attribute_utility().CreateUnknownSet( + arg_trails[i].attribute()); + } } + } + + CEL_ASSIGN_OR_RETURN(ResolveResult resolved_function, + resolver_.Resolve(frame, args)); + + if (resolved_function.has_value() && + ShouldAcceptOverload(resolved_function->descriptor, args)) { + CEL_ASSIGN_OR_RETURN(result, + Invoke(*resolved_function, expr_id_, args, frame)); - matched_function = overload; + return absl::OkStatus(); } + + result = NoOverloadResult(name_, args, receiver_style_, frame); + + return absl::OkStatus(); } - return matched_function; -} + absl::optional> GetDependencies() + const override { + std::vector dependencies; + dependencies.reserve(arg_steps_.size()); + for (const auto& arg_step : arg_steps_) { + dependencies.push_back(arg_step.get()); + } + return dependencies; + } + + absl::optional>> + ExtractDependencies() override { + return std::move(arg_steps_); + } + + private: + friend Resolver; + std::string name_; + std::vector> arg_steps_; + bool receiver_style_; + Resolver resolver_; +}; } // namespace -absl::StatusOr> CreateFunctionStep( - const google::api::expr::v1alpha1::Expr::Call* call_expr, int64_t expr_id, - const CelFunctionRegistry& function_registry, - BuilderWarnings* builder_warnings) { - bool receiver_style = call_expr->has_target(); - size_t num_args = call_expr->args_size() + (receiver_style ? 1 : 0); - const std::string& name = call_expr->function(); - - std::vector args(num_args, CelValue::Type::kAny); - - std::vector lazy_overloads = - function_registry.FindLazyOverloads(name, receiver_style, args); - - if (!lazy_overloads.empty()) { - std::unique_ptr step = absl::make_unique( - name, num_args, receiver_style, lazy_overloads, expr_id); - return std::move(step); - } +std::unique_ptr CreateDirectFunctionStep( + int64_t expr_id, const cel::CallExpr& call, + std::vector> deps, + std::vector overloads) { + return std::make_unique>( + expr_id, call.function(), std::move(deps), call.has_target(), + StaticResolver(std::move(overloads))); +} - auto overloads = function_registry.FindOverloads(name, receiver_style, args); +std::unique_ptr CreateDirectLazyFunctionStep( + int64_t expr_id, const cel::CallExpr& call, + std::vector> deps, + std::vector providers) { + return std::make_unique>( + expr_id, call.function(), std::move(deps), call.has_target(), + LazyResolver(std::move(providers), call.function(), call.has_target())); +} - // No overloads found. - if (overloads.empty()) { - RETURN_IF_ERROR(builder_warnings->AddWarning( - absl::Status(absl::StatusCode::kInvalidArgument, - "No overloads provided for FunctionStep creation"))); - } +absl::StatusOr> CreateFunctionStep( + const cel::CallExpr& call_expr, int64_t expr_id, + std::vector lazy_overloads) { + bool receiver_style = call_expr.has_target(); + size_t num_args = call_expr.args().size() + (receiver_style ? 1 : 0); + const std::string& name = call_expr.function(); + return std::make_unique(name, num_args, receiver_style, + std::move(lazy_overloads), expr_id); +} - std::unique_ptr step = absl::make_unique( - std::move(overloads), name, num_args, expr_id); - return std::move(step); +absl::StatusOr> CreateFunctionStep( + const cel::CallExpr& call_expr, int64_t expr_id, + std::vector overloads) { + bool receiver_style = call_expr.has_target(); + size_t num_args = call_expr.args().size() + (receiver_style ? 1 : 0); + const std::string& name = call_expr.function(); + return std::make_unique(std::move(overloads), name, + num_args, receiver_style, expr_id); } -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime diff --git a/eval/eval/function_step.h b/eval/eval/function_step.h index 5d02fc15c..9f664dc09 100644 --- a/eval/eval/function_step.h +++ b/eval/eval/function_step.h @@ -1,34 +1,48 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_FUNCTION_STEP_H_ #define THIRD_PARTY_CEL_CPP_EVAL_EVAL_FUNCTION_STEP_H_ -#include - +#include #include +#include -#include "google/api/expr/v1alpha1/syntax.pb.h" -#include "eval/eval/evaluator_core.h" -#include "eval/eval/expression_build_warning.h" -#include "eval/public/activation.h" -#include "eval/public/cel_function.h" -#include "eval/public/cel_function_registry.h" -#include "eval/public/cel_value.h" #include "absl/status/statusor.h" +#include "common/expr.h" +#include "eval/eval/direct_expression_step.h" +#include "eval/eval/evaluator_core.h" +#include "runtime/function_overload_reference.h" +#include "runtime/function_registry.h" + +namespace google::api::expr::runtime { -namespace google { -namespace api { -namespace expr { -namespace runtime { +// Factory method for Call-based execution step where the function has been +// statically resolved from a set of eagerly functions configured in the +// CelFunctionRegistry. +std::unique_ptr CreateDirectFunctionStep( + int64_t expr_id, const cel::CallExpr& call, + std::vector> deps, + std::vector overloads); -// Factory method for Call - based Execution step -// Looks up function registry using data provided through Call parameter. +// Factory method for Call-based execution step where the function has been +// statically resolved from a set of lazy functions configured in the +// CelFunctionRegistry. +std::unique_ptr CreateDirectLazyFunctionStep( + int64_t expr_id, const cel::CallExpr& call, + std::vector> deps, + std::vector providers); + +// Factory method for Call-based execution step where the function will be +// resolved at runtime (lazily) from an input Activation. +absl::StatusOr> CreateFunctionStep( + const cel::CallExpr& call, int64_t expr_id, + std::vector lazy_overloads); + +// Factory method for Call-based execution step where the function has been +// statically resolved from a set of eagerly functions configured in the +// CelFunctionRegistry. absl::StatusOr> CreateFunctionStep( - const google::api::expr::v1alpha1::Expr::Call* call, int64_t expr_id, - const CelFunctionRegistry& function_registry, - BuilderWarnings* builder_warnings); - -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google + const cel::CallExpr& call, int64_t expr_id, + std::vector overloads); + +} // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_FUNCTION_STEP_H_ diff --git a/eval/eval/function_step_test.cc b/eval/eval/function_step_test.cc index 296ecdd04..3d3bae34d 100644 --- a/eval/eval/function_step_test.cc +++ b/eval/eval/function_step_test.cc @@ -1,35 +1,61 @@ #include "eval/eval/function_step.h" -#include "google/api/expr/v1alpha1/syntax.pb.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "absl/memory/memory.h" +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "base/builtins.h" +#include "base/type_provider.h" +#include "common/constant.h" +#include "common/expr.h" +#include "common/kind.h" +#include "common/value.h" +#include "eval/eval/cel_expression_flat_impl.h" +#include "eval/eval/const_value_step.h" +#include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" -#include "eval/eval/expression_build_warning.h" #include "eval/eval/ident_step.h" +#include "eval/internal/interop.h" +#include "eval/public/activation.h" #include "eval/public/cel_attribute.h" #include "eval/public/cel_function.h" +#include "eval/public/cel_function_registry.h" #include "eval/public/cel_options.h" #include "eval/public/cel_value.h" +#include "eval/public/portable_cel_function_adapter.h" #include "eval/public/structs/cel_proto_wrapper.h" -#include "eval/public/unknown_function_result_set.h" +#include "eval/public/testing/matchers.h" #include "eval/testutil/test_message.pb.h" -#include "base/status_macros.h" +#include "internal/testing.h" +#include "runtime/function_overload_reference.h" +#include "runtime/function_registry.h" +#include "runtime/internal/runtime_env_testing.h" +#include "runtime/runtime_options.h" +#include "runtime/standard_functions.h" +#include "google/protobuf/arena.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace google::api::expr::runtime { namespace { -using testing::ElementsAre; -using testing::Eq; -using testing::Not; -using testing::UnorderedElementsAre; - -using google::api::expr::v1alpha1::Expr; +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::cel::CallExpr; +using ::cel::Expr; +using ::cel::IdentExpr; +using ::cel::TypeProvider; +using ::cel::runtime_internal::NewTestingRuntimeEnv; +using ::testing::Eq; +using ::testing::Not; +using ::testing::Truly; int GetExprId() { static int id = 0; @@ -44,13 +70,13 @@ class ConstFunction : public CelFunction { : CelFunction(CreateDescriptor(name)), value_(value) {} static CelFunctionDescriptor CreateDescriptor(absl::string_view name) { - return CelFunctionDescriptor{std::string(name), false, {}}; + return CelFunctionDescriptor{name, false, {}}; } - static Expr::Call MakeCall(absl::string_view name) { - Expr::Call call; - call.set_function(name.data()); - call.clear_target(); + static CallExpr MakeCall(absl::string_view name) { + CallExpr call; + call.set_function(std::string(name)); + call.set_target(nullptr); return call; } @@ -84,12 +110,12 @@ class AddFunction : public CelFunction { "_+_", false, {CelValue::Type::kInt64, CelValue::Type::kInt64}}; } - static Expr::Call MakeCall() { - Expr::Call call; + static CallExpr MakeCall() { + CallExpr call; call.set_function("_+_"); - call.add_args(); - call.add_args(); - call.clear_target(); + call.mutable_args().emplace_back(); + call.mutable_args().emplace_back(); + call.set_target(nullptr); return call; } @@ -118,17 +144,19 @@ class AddFunction : public CelFunction { class SinkFunction : public CelFunction { public: - SinkFunction(CelValue::Type type) : CelFunction(CreateDescriptor(type)) {} + explicit SinkFunction(CelValue::Type type, bool is_strict = true) + : CelFunction(CreateDescriptor(type, is_strict)) {} - static CelFunctionDescriptor CreateDescriptor(CelValue::Type type) { - return CelFunctionDescriptor{"Sink", false, {type}}; + static CelFunctionDescriptor CreateDescriptor(CelValue::Type type, + bool is_strict = true) { + return CelFunctionDescriptor{"Sink", false, {type}, is_strict}; } - static Expr::Call MakeCall() { - Expr::Call call; + static CallExpr MakeCall() { + CallExpr call; call.set_function("Sink"); - call.add_args(); - call.clear_target(); + call.mutable_args().emplace_back(); + call.set_target(nullptr); return call; } @@ -144,223 +172,252 @@ class SinkFunction : public CelFunction { void AddDefaults(CelFunctionRegistry& registry) { static UnknownSet* unknown_set = new UnknownSet(); EXPECT_TRUE(registry - .Register(absl::make_unique( + .Register(std::make_unique( CelValue::CreateInt64(3), "Const3")) .ok()); EXPECT_TRUE(registry - .Register(absl::make_unique( + .Register(std::make_unique( CelValue::CreateInt64(2), "Const2")) .ok()); EXPECT_TRUE(registry - .Register(absl::make_unique( + .Register(std::make_unique( CelValue::CreateUnknownSet(unknown_set), "ConstUnknown")) .ok()); - EXPECT_TRUE(registry.Register(absl::make_unique()).ok()); + EXPECT_TRUE(registry.Register(std::make_unique()).ok()); EXPECT_TRUE( - registry.Register(absl::make_unique(CelValue::Type::kList)) + registry.Register(std::make_unique(CelValue::Type::kList)) .ok()); EXPECT_TRUE( - registry.Register(absl::make_unique(CelValue::Type::kMap)) + registry.Register(std::make_unique(CelValue::Type::kMap)) .ok()); EXPECT_TRUE( registry - .Register(absl::make_unique(CelValue::Type::kMessage)) + .Register(std::make_unique(CelValue::Type::kMessage)) .ok()); } +std::vector ArgumentMatcher(int argument_count) { + std::vector argument_matcher(argument_count); + for (int i = 0; i < argument_count; i++) { + argument_matcher[i] = CelValue::Type::kAny; + } + return argument_matcher; +} + +std::vector ArgumentMatcher(const CallExpr& call) { + return ArgumentMatcher(call.has_target() ? call.args().size() + 1 + : call.args().size()); +} + +std::unique_ptr CreateExpressionImpl( + const cel::RuntimeOptions& options, + std::unique_ptr expr) { + ExecutionPath path; + path.push_back(std::make_unique(std::move(expr), -1)); + + auto env = NewTestingRuntimeEnv(); + return std::make_unique( + env, + FlatExpression(std::move(path), /*comprehension_slot_count=*/0, + env->type_registry.GetComposedTypeProvider(), options)); +} + +absl::StatusOr> MakeTestFunctionStep( + const CallExpr& call, const CelFunctionRegistry& registry) { + auto argument_matcher = ArgumentMatcher(call); + auto lazy_overloads = registry.ModernFindLazyOverloads( + call.function(), call.has_target(), argument_matcher); + if (!lazy_overloads.empty()) { + return CreateFunctionStep(call, GetExprId(), lazy_overloads); + } + auto overloads = registry.FindStaticOverloads( + call.function(), call.has_target(), argument_matcher); + return CreateFunctionStep(call, GetExprId(), overloads); +} + // Test common functions with varying levels of unknown support. class FunctionStepTest : public testing::TestWithParam { public: // underlying expression impl moves path std::unique_ptr GetExpression(ExecutionPath&& path) { - bool unknowns; - bool unknown_function_results; - switch (GetParam()) { - case UnknownProcessingOptions::kAttributeAndFunction: - unknowns = true; - unknown_function_results = true; - break; - case UnknownProcessingOptions::kAttributeOnly: - unknowns = true; - unknown_function_results = false; - break; - case UnknownProcessingOptions::kDisabled: - unknowns = false; - unknown_function_results = false; - break; - } - return absl::make_unique( - &dummy_expr_, std::move(path), 0, std::set(), unknowns, - unknown_function_results); + cel::RuntimeOptions options; + options.unknown_processing = GetParam(); + + auto env = NewTestingRuntimeEnv(); + return std::make_unique( + env, + FlatExpression(std::move(path), /*comprehension_slot_count=*/0, + env->type_registry.GetComposedTypeProvider(), options)); } - - private: - Expr dummy_expr_; }; TEST_P(FunctionStepTest, SimpleFunctionTest) { ExecutionPath path; - BuilderWarnings warnings; CelFunctionRegistry registry; AddDefaults(registry); - Expr::Call call1 = ConstFunction::MakeCall("Const3"); - Expr::Call call2 = ConstFunction::MakeCall("Const2"); - Expr::Call add_call = AddFunction::MakeCall(); + CallExpr call1 = ConstFunction::MakeCall("Const3"); + CallExpr call2 = ConstFunction::MakeCall("Const2"); + CallExpr add_call = AddFunction::MakeCall(); - auto step0_status = - CreateFunctionStep(&call1, GetExprId(), registry, &warnings); - auto step1_status = - CreateFunctionStep(&call2, GetExprId(), registry, &warnings); - auto step2_status = - CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); + ASSERT_OK_AND_ASSIGN(auto step0, MakeTestFunctionStep(call1, registry)); + ASSERT_OK_AND_ASSIGN(auto step1, MakeTestFunctionStep(call2, registry)); + ASSERT_OK_AND_ASSIGN(auto step2, MakeTestFunctionStep(add_call, registry)); - ASSERT_OK(step0_status); - ASSERT_OK(step1_status); - ASSERT_OK(step2_status); - - path.push_back(std::move(step0_status.value())); - path.push_back(std::move(step1_status.value())); - path.push_back(std::move(step2_status.value())); + path.push_back(std::move(step0)); + path.push_back(std::move(step1)); + path.push_back(std::move(step2)); std::unique_ptr impl = GetExpression(std::move(path)); Activation activation; google::protobuf::Arena arena; - auto status = impl->Evaluate(activation, &arena); - ASSERT_OK(status); - - auto value = status.value(); - + ASSERT_OK_AND_ASSIGN(CelValue value, impl->Evaluate(activation, &arena)); ASSERT_TRUE(value.IsInt64()); EXPECT_THAT(value.Int64OrDie(), Eq(5)); } TEST_P(FunctionStepTest, TestStackUnderflow) { ExecutionPath path; - BuilderWarnings warnings; CelFunctionRegistry registry; AddDefaults(registry); AddFunction add_func; - Expr::Call call1 = ConstFunction::MakeCall("Const3"); - Expr::Call add_call = AddFunction::MakeCall(); + CallExpr call1 = ConstFunction::MakeCall("Const3"); + CallExpr add_call = AddFunction::MakeCall(); - auto step0_status = - CreateFunctionStep(&call1, GetExprId(), registry, &warnings); - auto step2_status = - CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); + ASSERT_OK_AND_ASSIGN(auto step0, MakeTestFunctionStep(call1, registry)); + ASSERT_OK_AND_ASSIGN(auto step2, MakeTestFunctionStep(add_call, registry)); - ASSERT_OK(step0_status); - ASSERT_OK(step2_status); - - path.push_back(std::move(step0_status.value())); - path.push_back(std::move(step2_status.value())); + path.push_back(std::move(step0)); + path.push_back(std::move(step2)); std::unique_ptr impl = GetExpression(std::move(path)); Activation activation; google::protobuf::Arena arena; - auto status = impl->Evaluate(activation, &arena); - EXPECT_FALSE(status.ok()); + EXPECT_THAT(impl->Evaluate(activation, &arena), Not(IsOk())); } -// Test that creation fails if fail on warnings is set in the warnings -// collection. -TEST(FunctionStepTest, TestNoOverloadsOnCreation) { +// Test situation when no overloads match input arguments during evaluation. +TEST_P(FunctionStepTest, TestNoMatchingOverloadsDuringEvaluation) { + ExecutionPath path; + CelFunctionRegistry registry; - BuilderWarnings warnings(true); + AddDefaults(registry); - Expr::Call call = ConstFunction::MakeCall("Const0"); + ASSERT_TRUE(registry + .Register(std::make_unique( + CelValue::CreateUint64(4), "Const4")) + .ok()); + + CallExpr call1 = ConstFunction::MakeCall("Const3"); + CallExpr call2 = ConstFunction::MakeCall("Const4"); + // Add expects {int64, int64} but it's {int64, uint64}. + CallExpr add_call = AddFunction::MakeCall(); - // function step with empty overloads - auto step0_status = - CreateFunctionStep(&call, GetExprId(), registry, &warnings); + ASSERT_OK_AND_ASSIGN(auto step0, MakeTestFunctionStep(call1, registry)); + ASSERT_OK_AND_ASSIGN(auto step1, MakeTestFunctionStep(call2, registry)); + ASSERT_OK_AND_ASSIGN(auto step2, MakeTestFunctionStep(add_call, registry)); + + path.push_back(std::move(step0)); + path.push_back(std::move(step1)); + path.push_back(std::move(step2)); + + std::unique_ptr impl = GetExpression(std::move(path)); - EXPECT_FALSE(step0_status.ok()); + Activation activation; + google::protobuf::Arena arena; + + ASSERT_OK_AND_ASSIGN(CelValue value, impl->Evaluate(activation, &arena)); + ASSERT_TRUE(value.IsError()); + EXPECT_THAT(*value.ErrorOrDie(), + StatusIs(absl::StatusCode::kUnknown, + testing::HasSubstr("_+_(int64, uint64)"))); } -// Test that no overloads error is warned, actual error delayed to runtime by -// default. -TEST_P(FunctionStepTest, TestNoOverloadsOnCreationDelayedError) { - CelFunctionRegistry registry; +TEST_P(FunctionStepTest, TestNoMatchingOverloadsDuringEvaluationReceiver) { ExecutionPath path; - Expr::Call call = ConstFunction::MakeCall("Const0"); - BuilderWarnings warnings; - // function step with empty overloads - auto step0_status = - CreateFunctionStep(&call, GetExprId(), registry, &warnings); + CelFunctionRegistry registry; + AddDefaults(registry); + + CallExpr call1 = ConstFunction::MakeCall("Const3"); + CallExpr call2 = ConstFunction::MakeCall("Const3"); + // Add expects {int64, int64} but it's {int64, uint64}. + CallExpr add_call; + add_call.add_args(); + add_call.set_target(Expr()); + add_call.set_function("_+_"); - EXPECT_TRUE(step0_status.ok()); - EXPECT_THAT(warnings.warnings(), testing::SizeIs(1)); + ASSERT_OK_AND_ASSIGN(auto step0, MakeTestFunctionStep(call1, registry)); + ASSERT_OK_AND_ASSIGN(auto step1, MakeTestFunctionStep(call2, registry)); + ASSERT_OK_AND_ASSIGN(auto step2, MakeTestFunctionStep(add_call, registry)); - path.push_back(std::move(step0_status.value())); + path.push_back(std::move(step0)); + path.push_back(std::move(step1)); + path.push_back(std::move(step2)); std::unique_ptr impl = GetExpression(std::move(path)); Activation activation; google::protobuf::Arena arena; - auto status = impl->Evaluate(activation, &arena); - ASSERT_OK(status); - - auto value = status.value(); + ASSERT_OK_AND_ASSIGN(CelValue value, impl->Evaluate(activation, &arena)); ASSERT_TRUE(value.IsError()); + EXPECT_THAT(*value.ErrorOrDie(), + StatusIs(absl::StatusCode::kUnknown, + testing::HasSubstr("(int64)._+_(int64)"))); } // Test situation when no overloads match input arguments during evaluation. -TEST_P(FunctionStepTest, TestNoMatchingOverloadsDuringEvaluation) { +TEST_P(FunctionStepTest, TestNoMatchingOverloadsUnexpectedArgCount) { ExecutionPath path; - BuilderWarnings warnings; CelFunctionRegistry registry; AddDefaults(registry); - ASSERT_TRUE(registry - .Register(absl::make_unique( - CelValue::CreateUint64(4), "Const4")) - .ok()); + CallExpr call1 = ConstFunction::MakeCall("Const3"); - Expr::Call call1 = ConstFunction::MakeCall("Const3"); - Expr::Call call2 = ConstFunction::MakeCall("Const4"); - // Add expects {int64_t, int64_t} but it's {int64_t, uint64_t}. - Expr::Call add_call = AddFunction::MakeCall(); + // expect overloads for {int64, int64} but get call for {int64, int64, int64}. + CallExpr add_call = AddFunction::MakeCall(); + add_call.mutable_args().emplace_back(); - auto step0_status = - CreateFunctionStep(&call1, GetExprId(), registry, &warnings); - auto step1_status = - CreateFunctionStep(&call2, GetExprId(), registry, &warnings); - auto step2_status = - CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); + ASSERT_OK_AND_ASSIGN(auto step0, MakeTestFunctionStep(call1, registry)); + ASSERT_OK_AND_ASSIGN(auto step1, MakeTestFunctionStep(call1, registry)); + ASSERT_OK_AND_ASSIGN(auto step2, MakeTestFunctionStep(call1, registry)); - ASSERT_OK(step0_status); - ASSERT_OK(step1_status); - ASSERT_OK(step2_status); + ASSERT_OK_AND_ASSIGN( + auto step3, + CreateFunctionStep(add_call, -1, + registry.FindStaticOverloads( + add_call.function(), false, + {cel::Kind::kInt64, cel::Kind::kInt64}))); - path.push_back(std::move(step0_status.value())); - path.push_back(std::move(step1_status.value())); - path.push_back(std::move(step2_status.value())); + path.push_back(std::move(step0)); + path.push_back(std::move(step1)); + path.push_back(std::move(step2)); + path.push_back(std::move(step3)); std::unique_ptr impl = GetExpression(std::move(path)); Activation activation; google::protobuf::Arena arena; - auto status = impl->Evaluate(activation, &arena); - ASSERT_OK(status); - - auto value = status.value(); + ASSERT_OK_AND_ASSIGN(CelValue value, impl->Evaluate(activation, &arena)); ASSERT_TRUE(value.IsError()); + EXPECT_THAT(*value.ErrorOrDie(), + StatusIs(absl::StatusCode::kUnknown, + testing::HasSubstr("_+_(int64, int64, int64)"))); } // Test situation when no overloads match input arguments during evaluation @@ -368,109 +425,141 @@ TEST_P(FunctionStepTest, TestNoMatchingOverloadsDuringEvaluation) { TEST_P(FunctionStepTest, TestNoMatchingOverloadsDuringEvaluationErrorForwarding) { ExecutionPath path; - BuilderWarnings warnings; - CelFunctionRegistry registry; AddDefaults(registry); - CelError error0; - CelError error1; + CelError error0 = absl::CancelledError(); + CelError error1 = absl::CancelledError(); // Constants have ERROR type, while AddFunction expects INT. ASSERT_TRUE(registry - .Register(absl::make_unique( + .Register(std::make_unique( CelValue::CreateError(&error0), "ConstError1")) .ok()); ASSERT_TRUE(registry - .Register(absl::make_unique( + .Register(std::make_unique( CelValue::CreateError(&error1), "ConstError2")) .ok()); - Expr::Call call1 = ConstFunction::MakeCall("ConstError1"); - Expr::Call call2 = ConstFunction::MakeCall("ConstError2"); - Expr::Call add_call = AddFunction::MakeCall(); + CallExpr call1 = ConstFunction::MakeCall("ConstError1"); + CallExpr call2 = ConstFunction::MakeCall("ConstError2"); + CallExpr add_call = AddFunction::MakeCall(); - auto step0_status = - CreateFunctionStep(&call1, GetExprId(), registry, &warnings); - auto step1_status = - CreateFunctionStep(&call2, GetExprId(), registry, &warnings); - auto step2_status = - CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); + ASSERT_OK_AND_ASSIGN(auto step0, MakeTestFunctionStep(call1, registry)); + ASSERT_OK_AND_ASSIGN(auto step1, MakeTestFunctionStep(call2, registry)); + ASSERT_OK_AND_ASSIGN(auto step2, MakeTestFunctionStep(add_call, registry)); - ASSERT_OK(step0_status); - ASSERT_OK(step1_status); - ASSERT_OK(step2_status); - - path.push_back(std::move(step0_status.value())); - path.push_back(std::move(step1_status.value())); - path.push_back(std::move(step2_status.value())); + path.push_back(std::move(step0)); + path.push_back(std::move(step1)); + path.push_back(std::move(step2)); std::unique_ptr impl = GetExpression(std::move(path)); Activation activation; google::protobuf::Arena arena; - auto status = impl->Evaluate(activation, &arena); - ASSERT_OK(status); - - auto value = status.value(); - + ASSERT_OK_AND_ASSIGN(CelValue value, impl->Evaluate(activation, &arena)); ASSERT_TRUE(value.IsError()); - EXPECT_THAT(value.ErrorOrDie(), Eq(&error0)); + EXPECT_THAT(*value.ErrorOrDie(), Eq(error0)); } TEST_P(FunctionStepTest, LazyFunctionTest) { ExecutionPath path; Activation activation; CelFunctionRegistry registry; - BuilderWarnings warnings; - - auto register0_status = - registry.RegisterLazyFunction(ConstFunction::CreateDescriptor("Const3")); - ASSERT_OK(register0_status); - auto insert0_status = activation.InsertFunction( - absl::make_unique(CelValue::CreateInt64(3), "Const3")); - ASSERT_OK(insert0_status); - auto register1_status = - registry.RegisterLazyFunction(ConstFunction::CreateDescriptor("Const2")); - ASSERT_OK(register1_status); - auto insert1_status = activation.InsertFunction( - absl::make_unique(CelValue::CreateInt64(2), "Const2")); - ASSERT_OK(insert1_status); - ASSERT_OK(registry.Register(absl::make_unique())); - - Expr::Call call1 = ConstFunction::MakeCall("Const3"); - Expr::Call call2 = ConstFunction::MakeCall("Const2"); - Expr::Call add_call = AddFunction::MakeCall(); - - auto step0_status = - CreateFunctionStep(&call1, GetExprId(), registry, &warnings); - auto step1_status = - CreateFunctionStep(&call2, GetExprId(), registry, &warnings); - auto step2_status = - CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); - - ASSERT_OK(step0_status); - ASSERT_OK(step1_status); - ASSERT_OK(step2_status); - - path.push_back(std::move(step0_status.value())); - path.push_back(std::move(step1_status.value())); - path.push_back(std::move(step2_status.value())); + ASSERT_OK( + registry.RegisterLazyFunction(ConstFunction::CreateDescriptor("Const3"))); + ASSERT_OK(activation.InsertFunction( + std::make_unique(CelValue::CreateInt64(3), "Const3"))); + ASSERT_OK( + registry.RegisterLazyFunction(ConstFunction::CreateDescriptor("Const2"))); + ASSERT_OK(activation.InsertFunction( + std::make_unique(CelValue::CreateInt64(2), "Const2"))); + ASSERT_OK(registry.Register(std::make_unique())); + + CallExpr call1 = ConstFunction::MakeCall("Const3"); + CallExpr call2 = ConstFunction::MakeCall("Const2"); + CallExpr add_call = AddFunction::MakeCall(); + + ASSERT_OK_AND_ASSIGN(auto step0, MakeTestFunctionStep(call1, registry)); + ASSERT_OK_AND_ASSIGN(auto step1, MakeTestFunctionStep(call2, registry)); + ASSERT_OK_AND_ASSIGN(auto step2, MakeTestFunctionStep(add_call, registry)); + + path.push_back(std::move(step0)); + path.push_back(std::move(step1)); + path.push_back(std::move(step2)); std::unique_ptr impl = GetExpression(std::move(path)); google::protobuf::Arena arena; - auto status = impl->Evaluate(activation, &arena); - ASSERT_OK(status); - - auto value = status.value(); - + ASSERT_OK_AND_ASSIGN(CelValue value, impl->Evaluate(activation, &arena)); ASSERT_TRUE(value.IsInt64()); EXPECT_THAT(value.Int64OrDie(), Eq(5)); } +TEST_P(FunctionStepTest, LazyFunctionOverloadingTest) { + ExecutionPath path; + Activation activation; + CelFunctionRegistry registry; + auto floor_int = PortableUnaryFunctionAdapter::Create( + "Floor", false, [](google::protobuf::Arena*, int64_t val) { return val; }); + auto floor_double = PortableUnaryFunctionAdapter::Create( + "Floor", false, + [](google::protobuf::Arena*, double val) { return std::floor(val); }); + + ASSERT_OK(registry.RegisterLazyFunction(floor_int->descriptor())); + ASSERT_OK(activation.InsertFunction(std::move(floor_int))); + ASSERT_OK(registry.RegisterLazyFunction(floor_double->descriptor())); + ASSERT_OK(activation.InsertFunction(std::move(floor_double))); + ASSERT_OK(registry.Register( + PortableBinaryFunctionAdapter::Create( + "_<_", false, [](google::protobuf::Arena*, int64_t lhs, int64_t rhs) -> bool { + return lhs < rhs; + }))); + + cel::Constant lhs; + lhs.set_int64_value(20); + cel::Constant rhs; + rhs.set_double_value(21.9); + + CallExpr call1; + call1.mutable_args().emplace_back(); + call1.set_function("Floor"); + CallExpr call2; + call2.mutable_args().emplace_back(); + call2.set_function("Floor"); + + CallExpr lt_call; + lt_call.mutable_args().emplace_back(); + lt_call.mutable_args().emplace_back(); + lt_call.set_function("_<_"); + + ASSERT_OK_AND_ASSIGN( + auto step0, + CreateConstValueStep(cel::interop_internal::CreateIntValue(20), -1)); + ASSERT_OK_AND_ASSIGN(auto step1, MakeTestFunctionStep(call1, registry)); + ASSERT_OK_AND_ASSIGN( + auto step2, + CreateConstValueStep(cel::interop_internal::CreateDoubleValue(21.9), -1)); + ASSERT_OK_AND_ASSIGN(auto step3, MakeTestFunctionStep(call2, registry)); + ASSERT_OK_AND_ASSIGN(auto step4, MakeTestFunctionStep(lt_call, registry)); + + path.push_back(std::move(step0)); + path.push_back(std::move(step1)); + path.push_back(std::move(step2)); + path.push_back(std::move(step3)); + path.push_back(std::move(step4)); + + std::unique_ptr impl = GetExpression(std::move(path)); + + google::protobuf::Arena arena; + + ASSERT_OK_AND_ASSIGN(CelValue value, impl->Evaluate(activation, &arena)); + ASSERT_TRUE(value.IsBool()); + EXPECT_TRUE(value.BoolOrDie()); +} + // Test situation when no overloads match input arguments during evaluation // and at least one of arguments is error. TEST_P(FunctionStepTest, @@ -479,57 +568,39 @@ TEST_P(FunctionStepTest, Activation activation; google::protobuf::Arena arena; CelFunctionRegistry registry; - BuilderWarnings warnings; AddDefaults(registry); - CelError error0; - CelError error1; + CelError error0 = absl::CancelledError(); + CelError error1 = absl::CancelledError(); // Constants have ERROR type, while AddFunction expects INT. - auto register0_status = registry.RegisterLazyFunction( - ConstFunction::CreateDescriptor("ConstError1")); - ASSERT_OK(register0_status); - auto insert0_status = - activation.InsertFunction(absl::make_unique( - CelValue::CreateError(&error0), "ConstError1")); - ASSERT_OK(insert0_status); - auto register1_status = registry.RegisterLazyFunction( - ConstFunction::CreateDescriptor("ConstError2")); - ASSERT_OK(register1_status); - auto insert1_status = - activation.InsertFunction(absl::make_unique( - CelValue::CreateError(&error1), "ConstError2")); - ASSERT_OK(insert1_status); - - Expr::Call call1 = ConstFunction::MakeCall("ConstError1"); - Expr::Call call2 = ConstFunction::MakeCall("ConstError2"); - Expr::Call add_call = AddFunction::MakeCall(); - - auto step0_status = - CreateFunctionStep(&call1, GetExprId(), registry, &warnings); - auto step1_status = - CreateFunctionStep(&call2, GetExprId(), registry, &warnings); - auto step2_status = - CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); - - ASSERT_OK(step0_status); - ASSERT_OK(step1_status); - ASSERT_OK(step2_status); - - path.push_back(std::move(step0_status.value())); - path.push_back(std::move(step1_status.value())); - path.push_back(std::move(step2_status.value())); + ASSERT_OK(registry.RegisterLazyFunction( + ConstFunction::CreateDescriptor("ConstError1"))); + ASSERT_OK(activation.InsertFunction(std::make_unique( + CelValue::CreateError(&error0), "ConstError1"))); + ASSERT_OK(registry.RegisterLazyFunction( + ConstFunction::CreateDescriptor("ConstError2"))); + ASSERT_OK(activation.InsertFunction(std::make_unique( + CelValue::CreateError(&error1), "ConstError2"))); + + CallExpr call1 = ConstFunction::MakeCall("ConstError1"); + CallExpr call2 = ConstFunction::MakeCall("ConstError2"); + CallExpr add_call = AddFunction::MakeCall(); + + ASSERT_OK_AND_ASSIGN(auto step0, MakeTestFunctionStep(call1, registry)); + ASSERT_OK_AND_ASSIGN(auto step1, MakeTestFunctionStep(call2, registry)); + ASSERT_OK_AND_ASSIGN(auto step2, MakeTestFunctionStep(add_call, registry)); + + path.push_back(std::move(step0)); + path.push_back(std::move(step1)); + path.push_back(std::move(step2)); std::unique_ptr impl = GetExpression(std::move(path)); - auto status = impl->Evaluate(activation, &arena); - ASSERT_OK(status); - - auto value = status.value(); - + ASSERT_OK_AND_ASSIGN(CelValue value, impl->Evaluate(activation, &arena)); ASSERT_TRUE(value.IsError()); - EXPECT_THAT(value.ErrorOrDie(), Eq(&error0)); + EXPECT_THAT(*value.ErrorOrDie(), Eq(error0)); } std::string TestNameFn(testing::TestParamInfo opt) { @@ -555,85 +626,61 @@ class FunctionStepTestUnknowns : public testing::TestWithParam { public: std::unique_ptr GetExpression(ExecutionPath&& path) { - bool unknown_functions; - switch (GetParam()) { - case UnknownProcessingOptions::kAttributeAndFunction: - unknown_functions = true; - break; - default: - unknown_functions = false; - break; - } - return absl::make_unique(&expr, std::move(path), 0, - std::set(), - true, unknown_functions); + cel::RuntimeOptions options; + options.unknown_processing = GetParam(); + + auto env = NewTestingRuntimeEnv(); + return std::make_unique( + env, + FlatExpression(std::move(path), /*comprehension_slot_count=*/0, + env->type_registry.GetComposedTypeProvider(), options)); } - - private: - Expr expr; }; TEST_P(FunctionStepTestUnknowns, PassedUnknownTest) { ExecutionPath path; - BuilderWarnings warnings; CelFunctionRegistry registry; AddDefaults(registry); - Expr::Call call1 = ConstFunction::MakeCall("Const3"); - Expr::Call call2 = ConstFunction::MakeCall("ConstUnknown"); - Expr::Call add_call = AddFunction::MakeCall(); - - auto step0_status = - CreateFunctionStep(&call1, GetExprId(), registry, &warnings); - auto step1_status = - CreateFunctionStep(&call2, GetExprId(), registry, &warnings); - auto step2_status = - CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); + CallExpr call1 = ConstFunction::MakeCall("Const3"); + CallExpr call2 = ConstFunction::MakeCall("ConstUnknown"); + CallExpr add_call = AddFunction::MakeCall(); - ASSERT_OK(step0_status); - ASSERT_OK(step1_status); - ASSERT_OK(step2_status); + ASSERT_OK_AND_ASSIGN(auto step0, MakeTestFunctionStep(call1, registry)); + ASSERT_OK_AND_ASSIGN(auto step1, MakeTestFunctionStep(call2, registry)); + ASSERT_OK_AND_ASSIGN(auto step2, MakeTestFunctionStep(add_call, registry)); - path.push_back(std::move(step0_status.value())); - path.push_back(std::move(step1_status.value())); - path.push_back(std::move(step2_status.value())); + path.push_back(std::move(step0)); + path.push_back(std::move(step1)); + path.push_back(std::move(step2)); std::unique_ptr impl = GetExpression(std::move(path)); Activation activation; google::protobuf::Arena arena; - auto status = impl->Evaluate(activation, &arena); - ASSERT_OK(status); - - auto value = status.value(); - + ASSERT_OK_AND_ASSIGN(CelValue value, impl->Evaluate(activation, &arena)); ASSERT_TRUE(value.IsUnknownSet()); } TEST_P(FunctionStepTestUnknowns, PartialUnknownHandlingTest) { ExecutionPath path; - BuilderWarnings warnings; CelFunctionRegistry registry; AddDefaults(registry); // Build the expression path that corresponds to CEL expression // "sink(param)". - Expr::Ident ident1; + IdentExpr ident1; ident1.set_name("param"); - Expr::Call call1 = SinkFunction::MakeCall(); - - auto step0_status = CreateIdentStep(&ident1, GetExprId()); - auto step1_status = - CreateFunctionStep(&call1, GetExprId(), registry, &warnings); + CallExpr call1 = SinkFunction::MakeCall(); - ASSERT_OK(step0_status); - ASSERT_OK(step1_status); + ASSERT_OK_AND_ASSIGN(auto step0, CreateIdentStep("param", GetExprId())); + ASSERT_OK_AND_ASSIGN(auto step1, MakeTestFunctionStep(call1, registry)); - path.push_back(std::move(step0_status.value())); - path.push_back(std::move(step1_status.value())); + path.push_back(std::move(step0)); + path.push_back(std::move(step1)); std::unique_ptr impl = GetExpression(std::move(path)); @@ -643,68 +690,51 @@ TEST_P(FunctionStepTestUnknowns, PartialUnknownHandlingTest) { activation.InsertValue("param", CelProtoWrapper::CreateMessage(&msg, &arena)); CelAttributePattern pattern( "param", - {CelAttributeQualifierPattern::Create(CelValue::CreateBool(true))}); + {CreateCelAttributeQualifierPattern(CelValue::CreateBool(true))}); // Set attribute pattern that marks attribute "param[true]" as unknown. // It should result in "param" being handled as partially unknown, which is // is handled as fully unknown when used as function input argument. activation.set_unknown_attribute_patterns({pattern}); - auto status = impl->Evaluate(activation, &arena); - ASSERT_OK(status); - - auto value = status.value(); - + ASSERT_OK_AND_ASSIGN(CelValue value, impl->Evaluate(activation, &arena)); ASSERT_TRUE(value.IsUnknownSet()); } TEST_P(FunctionStepTestUnknowns, UnknownVsErrorPrecedenceTest) { ExecutionPath path; - BuilderWarnings warnings; - CelFunctionRegistry registry; AddDefaults(registry); - CelError error0; + CelError error0 = absl::CancelledError(); CelValue error_value = CelValue::CreateError(&error0); ASSERT_TRUE( registry - .Register(absl::make_unique(error_value, "ConstError")) + .Register(std::make_unique(error_value, "ConstError")) .ok()); - Expr::Call call1 = ConstFunction::MakeCall("ConstError"); - Expr::Call call2 = ConstFunction::MakeCall("ConstUnknown"); - Expr::Call add_call = AddFunction::MakeCall(); + CallExpr call1 = ConstFunction::MakeCall("ConstError"); + CallExpr call2 = ConstFunction::MakeCall("ConstUnknown"); + CallExpr add_call = AddFunction::MakeCall(); - auto step0_status = - CreateFunctionStep(&call1, GetExprId(), registry, &warnings); - auto step1_status = - CreateFunctionStep(&call2, GetExprId(), registry, &warnings); - auto step2_status = - CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); + ASSERT_OK_AND_ASSIGN(auto step0, MakeTestFunctionStep(call1, registry)); + ASSERT_OK_AND_ASSIGN(auto step1, MakeTestFunctionStep(call2, registry)); + ASSERT_OK_AND_ASSIGN(auto step2, MakeTestFunctionStep(add_call, registry)); - ASSERT_OK(step0_status); - ASSERT_OK(step1_status); - ASSERT_OK(step2_status); - - path.push_back(std::move(step0_status.value())); - path.push_back(std::move(step1_status.value())); - path.push_back(std::move(step2_status.value())); + path.push_back(std::move(step0)); + path.push_back(std::move(step1)); + path.push_back(std::move(step2)); std::unique_ptr impl = GetExpression(std::move(path)); Activation activation; google::protobuf::Arena arena; - auto status = impl->Evaluate(activation, &arena); - ASSERT_OK(status); - - auto value = status.value(); - + ASSERT_OK_AND_ASSIGN(CelValue value, impl->Evaluate(activation, &arena)); ASSERT_TRUE(value.IsError()); // Making sure we propagate the error. - ASSERT_EQ(value.ErrorOrDie(), error_value.ErrorOrDie()); + ASSERT_EQ(*value.ErrorOrDie(), *error_value.ErrorOrDie()); } INSTANTIATE_TEST_SUITE_P( @@ -713,265 +743,479 @@ INSTANTIATE_TEST_SUITE_P( UnknownProcessingOptions::kAttributeAndFunction), &TestNameFn); -MATCHER_P2(IsAdd, a, b, "") { - const UnknownFunctionResult* result = arg; - return result->arguments().size() == 2 && - result->arguments().at(0).IsInt64() && - result->arguments().at(1).IsInt64() && - result->arguments().at(0).Int64OrDie() == a && - result->arguments().at(1).Int64OrDie() == b && - result->descriptor().name() == "_+_"; -} - TEST(FunctionStepTestUnknownFunctionResults, CaptureArgs) { ExecutionPath path; - BuilderWarnings warnings; - CelFunctionRegistry registry; ASSERT_OK(registry.Register( - absl::make_unique(CelValue::CreateInt64(2), "Const2"))); + std::make_unique(CelValue::CreateInt64(2), "Const2"))); ASSERT_OK(registry.Register( - absl::make_unique(CelValue::CreateInt64(3), "Const3"))); + std::make_unique(CelValue::CreateInt64(3), "Const3"))); ASSERT_OK(registry.Register( - absl::make_unique(ShouldReturnUnknown::kYes))); - - Expr::Call call1 = ConstFunction::MakeCall("Const2"); - Expr::Call call2 = ConstFunction::MakeCall("Const3"); - Expr::Call add_call = AddFunction::MakeCall(); - - auto step0_status = - CreateFunctionStep(&call1, GetExprId(), registry, &warnings); - auto step1_status = - CreateFunctionStep(&call2, GetExprId(), registry, &warnings); - auto step2_status = - CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); - - ASSERT_OK(step0_status); - ASSERT_OK(step1_status); - ASSERT_OK(step2_status); - - path.push_back(std::move(step0_status.value())); - path.push_back(std::move(step1_status.value())); - path.push_back(std::move(step2_status.value())); - - Expr dummy_expr; - - CelExpressionFlatImpl impl(&dummy_expr, std::move(path), 0, {}, true, true); + std::make_unique(ShouldReturnUnknown::kYes))); + + CallExpr call1 = ConstFunction::MakeCall("Const2"); + CallExpr call2 = ConstFunction::MakeCall("Const3"); + CallExpr add_call = AddFunction::MakeCall(); + + ASSERT_OK_AND_ASSIGN(auto step0, MakeTestFunctionStep(call1, registry)); + ASSERT_OK_AND_ASSIGN(auto step1, MakeTestFunctionStep(call2, registry)); + ASSERT_OK_AND_ASSIGN(auto step2, MakeTestFunctionStep(add_call, registry)); + + path.push_back(std::move(step0)); + path.push_back(std::move(step1)); + path.push_back(std::move(step2)); + cel::RuntimeOptions options; + options.unknown_processing = + cel::UnknownProcessingOptions::kAttributeAndFunction; + auto env = NewTestingRuntimeEnv(); + CelExpressionFlatImpl impl( + env, + FlatExpression(std::move(path), + /*comprehension_slot_count=*/0, + env->type_registry.GetComposedTypeProvider(), options)); Activation activation; google::protobuf::Arena arena; - auto status = impl.Evaluate(activation, &arena); - ASSERT_OK(status); - - auto value = status.value(); - + ASSERT_OK_AND_ASSIGN(CelValue value, impl.Evaluate(activation, &arena)); ASSERT_TRUE(value.IsUnknownSet()); - // Arguments captured. - EXPECT_THAT(value.UnknownSetOrDie() - ->unknown_function_results() - .unknown_function_results(), - ElementsAre(IsAdd(2, 3))); } TEST(FunctionStepTestUnknownFunctionResults, MergeDownCaptureArgs) { ExecutionPath path; - BuilderWarnings warnings; - CelFunctionRegistry registry; ASSERT_OK(registry.Register( - absl::make_unique(CelValue::CreateInt64(2), "Const2"))); + std::make_unique(CelValue::CreateInt64(2), "Const2"))); ASSERT_OK(registry.Register( - absl::make_unique(CelValue::CreateInt64(3), "Const3"))); + std::make_unique(CelValue::CreateInt64(3), "Const3"))); ASSERT_OK(registry.Register( - absl::make_unique(ShouldReturnUnknown::kYes))); + std::make_unique(ShouldReturnUnknown::kYes))); // Add(Add(2, 3), Add(2, 3)) - Expr::Call call1 = ConstFunction::MakeCall("Const2"); - Expr::Call call2 = ConstFunction::MakeCall("Const3"); - Expr::Call add_call = AddFunction::MakeCall(); - - auto step0_status = - CreateFunctionStep(&call1, GetExprId(), registry, &warnings); - auto step1_status = - CreateFunctionStep(&call2, GetExprId(), registry, &warnings); - auto step2_status = - CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); - auto step3_status = - CreateFunctionStep(&call1, GetExprId(), registry, &warnings); - auto step4_status = - CreateFunctionStep(&call2, GetExprId(), registry, &warnings); - auto step5_status = - CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); - auto step6_status = - CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); - - ASSERT_OK(step0_status); - ASSERT_OK(step1_status); - ASSERT_OK(step2_status); - ASSERT_OK(step3_status); - ASSERT_OK(step4_status); - ASSERT_OK(step5_status); - ASSERT_OK(step6_status); - - path.push_back(std::move(step0_status.value())); - path.push_back(std::move(step1_status.value())); - path.push_back(std::move(step2_status.value())); - path.push_back(std::move(step3_status.value())); - path.push_back(std::move(step4_status.value())); - path.push_back(std::move(step5_status.value())); - path.push_back(std::move(step6_status.value())); - - Expr dummy_expr; - - CelExpressionFlatImpl impl(&dummy_expr, std::move(path), 0, {}, true, true); + CallExpr call1 = ConstFunction::MakeCall("Const2"); + CallExpr call2 = ConstFunction::MakeCall("Const3"); + CallExpr add_call = AddFunction::MakeCall(); + + ASSERT_OK_AND_ASSIGN(auto step0, MakeTestFunctionStep(call1, registry)); + ASSERT_OK_AND_ASSIGN(auto step1, MakeTestFunctionStep(call2, registry)); + ASSERT_OK_AND_ASSIGN(auto step2, MakeTestFunctionStep(add_call, registry)); + ASSERT_OK_AND_ASSIGN(auto step3, MakeTestFunctionStep(call1, registry)); + ASSERT_OK_AND_ASSIGN(auto step4, MakeTestFunctionStep(call2, registry)); + ASSERT_OK_AND_ASSIGN(auto step5, MakeTestFunctionStep(add_call, registry)); + ASSERT_OK_AND_ASSIGN(auto step6, MakeTestFunctionStep(add_call, registry)); + + path.push_back(std::move(step0)); + path.push_back(std::move(step1)); + path.push_back(std::move(step2)); + path.push_back(std::move(step3)); + path.push_back(std::move(step4)); + path.push_back(std::move(step5)); + path.push_back(std::move(step6)); + + cel::RuntimeOptions options; + options.unknown_processing = + cel::UnknownProcessingOptions::kAttributeAndFunction; + auto env = NewTestingRuntimeEnv(); + CelExpressionFlatImpl impl( + env, + FlatExpression(std::move(path), + /*comprehension_slot_count=*/0, + env->type_registry.GetComposedTypeProvider(), options)); Activation activation; google::protobuf::Arena arena; - auto status = impl.Evaluate(activation, &arena); - ASSERT_OK(status); - - auto value = status.value(); - + ASSERT_OK_AND_ASSIGN(CelValue value, impl.Evaluate(activation, &arena)); ASSERT_TRUE(value.IsUnknownSet()); - // Arguments captured. - EXPECT_THAT(value.UnknownSetOrDie() - ->unknown_function_results() - .unknown_function_results(), - ElementsAre(IsAdd(2, 3))); } TEST(FunctionStepTestUnknownFunctionResults, MergeCaptureArgs) { ExecutionPath path; - BuilderWarnings warnings; - CelFunctionRegistry registry; ASSERT_OK(registry.Register( - absl::make_unique(CelValue::CreateInt64(2), "Const2"))); + std::make_unique(CelValue::CreateInt64(2), "Const2"))); ASSERT_OK(registry.Register( - absl::make_unique(CelValue::CreateInt64(3), "Const3"))); + std::make_unique(CelValue::CreateInt64(3), "Const3"))); ASSERT_OK(registry.Register( - absl::make_unique(ShouldReturnUnknown::kYes))); + std::make_unique(ShouldReturnUnknown::kYes))); // Add(Add(2, 3), Add(3, 2)) - Expr::Call call1 = ConstFunction::MakeCall("Const2"); - Expr::Call call2 = ConstFunction::MakeCall("Const3"); - Expr::Call add_call = AddFunction::MakeCall(); - - auto step0_status = - CreateFunctionStep(&call1, GetExprId(), registry, &warnings); - auto step1_status = - CreateFunctionStep(&call2, GetExprId(), registry, &warnings); - auto step2_status = - CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); - auto step3_status = - CreateFunctionStep(&call2, GetExprId(), registry, &warnings); - auto step4_status = - CreateFunctionStep(&call1, GetExprId(), registry, &warnings); - auto step5_status = - CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); - auto step6_status = - CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); - - ASSERT_OK(step0_status); - ASSERT_OK(step1_status); - ASSERT_OK(step2_status); - ASSERT_OK(step3_status); - ASSERT_OK(step4_status); - ASSERT_OK(step5_status); - ASSERT_OK(step6_status); - - path.push_back(std::move(step0_status.value())); - path.push_back(std::move(step1_status.value())); - path.push_back(std::move(step2_status.value())); - path.push_back(std::move(step3_status.value())); - path.push_back(std::move(step4_status.value())); - path.push_back(std::move(step5_status.value())); - path.push_back(std::move(step6_status.value())); - - Expr dummy_expr; - - CelExpressionFlatImpl impl(&dummy_expr, std::move(path), 0, {}, true, true); + CallExpr call1 = ConstFunction::MakeCall("Const2"); + CallExpr call2 = ConstFunction::MakeCall("Const3"); + CallExpr add_call = AddFunction::MakeCall(); + + ASSERT_OK_AND_ASSIGN(auto step0, MakeTestFunctionStep(call1, registry)); + ASSERT_OK_AND_ASSIGN(auto step1, MakeTestFunctionStep(call2, registry)); + ASSERT_OK_AND_ASSIGN(auto step2, MakeTestFunctionStep(add_call, registry)); + ASSERT_OK_AND_ASSIGN(auto step3, MakeTestFunctionStep(call2, registry)); + ASSERT_OK_AND_ASSIGN(auto step4, MakeTestFunctionStep(call1, registry)); + ASSERT_OK_AND_ASSIGN(auto step5, MakeTestFunctionStep(add_call, registry)); + ASSERT_OK_AND_ASSIGN(auto step6, MakeTestFunctionStep(add_call, registry)); + + path.push_back(std::move(step0)); + path.push_back(std::move(step1)); + path.push_back(std::move(step2)); + path.push_back(std::move(step3)); + path.push_back(std::move(step4)); + path.push_back(std::move(step5)); + path.push_back(std::move(step6)); + + cel::RuntimeOptions options; + options.unknown_processing = + cel::UnknownProcessingOptions::kAttributeAndFunction; + auto env = NewTestingRuntimeEnv(); + CelExpressionFlatImpl impl( + env, + FlatExpression(std::move(path), + /*comprehension_slot_count=*/0, + env->type_registry.GetComposedTypeProvider(), options)); Activation activation; google::protobuf::Arena arena; - auto status = impl.Evaluate(activation, &arena); - ASSERT_OK(status); - - auto value = status.value(); - - ASSERT_TRUE(value.IsUnknownSet()) << value.ErrorOrDie()->ToString(); - // Arguments captured. - EXPECT_THAT(value.UnknownSetOrDie() - ->unknown_function_results() - .unknown_function_results(), - UnorderedElementsAre(IsAdd(2, 3), IsAdd(3, 2))); + ASSERT_OK_AND_ASSIGN(CelValue value, impl.Evaluate(activation, &arena)); + ASSERT_TRUE(value.IsUnknownSet()) << *(value.ErrorOrDie()); } TEST(FunctionStepTestUnknownFunctionResults, UnknownVsErrorPrecedenceTest) { ExecutionPath path; - BuilderWarnings warnings; - CelFunctionRegistry registry; - CelError error0; + CelError error0 = absl::CancelledError(); CelValue error_value = CelValue::CreateError(&error0); UnknownSet unknown_set; CelValue unknown_value = CelValue::CreateUnknownSet(&unknown_set); ASSERT_OK(registry.Register( - absl::make_unique(error_value, "ConstError"))); + std::make_unique(error_value, "ConstError"))); ASSERT_OK(registry.Register( - absl::make_unique(unknown_value, "ConstUnknown"))); + std::make_unique(unknown_value, "ConstUnknown"))); ASSERT_OK(registry.Register( - absl::make_unique(ShouldReturnUnknown::kYes))); + std::make_unique(ShouldReturnUnknown::kYes))); + + CallExpr call1 = ConstFunction::MakeCall("ConstError"); + CallExpr call2 = ConstFunction::MakeCall("ConstUnknown"); + CallExpr add_call = AddFunction::MakeCall(); + + ASSERT_OK_AND_ASSIGN(auto step0, MakeTestFunctionStep(call1, registry)); + ASSERT_OK_AND_ASSIGN(auto step1, MakeTestFunctionStep(call2, registry)); + ASSERT_OK_AND_ASSIGN(auto step2, MakeTestFunctionStep(add_call, registry)); + + path.push_back(std::move(step0)); + path.push_back(std::move(step1)); + path.push_back(std::move(step2)); + + cel::RuntimeOptions options; + options.unknown_processing = + cel::UnknownProcessingOptions::kAttributeAndFunction; + auto env = NewTestingRuntimeEnv(); + CelExpressionFlatImpl impl( + env, + FlatExpression(std::move(path), + /*comprehension_slot_count=*/0, + env->type_registry.GetComposedTypeProvider(), options)); - Expr::Call call1 = ConstFunction::MakeCall("ConstError"); - Expr::Call call2 = ConstFunction::MakeCall("ConstUnknown"); - Expr::Call add_call = AddFunction::MakeCall(); + Activation activation; + google::protobuf::Arena arena; - auto step0_status = - CreateFunctionStep(&call1, GetExprId(), registry, &warnings); - auto step1_status = - CreateFunctionStep(&call2, GetExprId(), registry, &warnings); - auto step2_status = - CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); + ASSERT_OK_AND_ASSIGN(CelValue value, impl.Evaluate(activation, &arena)); + ASSERT_TRUE(value.IsError()); + // Making sure we propagate the error. + ASSERT_EQ(*value.ErrorOrDie(), *error_value.ErrorOrDie()); +} + +class MessageFunction : public CelFunction { + public: + MessageFunction() + : CelFunction( + CelFunctionDescriptor("Fn", false, {CelValue::Type::kMessage})) {} + + absl::Status Evaluate(absl::Span args, CelValue* result, + google::protobuf::Arena* arena) const override { + if (args.size() != 1 || !args.at(0).IsMessage()) { + return absl::Status(absl::StatusCode::kInvalidArgument, + "Bad arguments number"); + } + + *result = CelValue::CreateStringView("message"); + return absl::OkStatus(); + } +}; + +class MessageIdentityFunction : public CelFunction { + public: + MessageIdentityFunction() + : CelFunction( + CelFunctionDescriptor("Fn", false, {CelValue::Type::kMessage})) {} + + absl::Status Evaluate(absl::Span args, CelValue* result, + google::protobuf::Arena* arena) const override { + if (args.size() != 1 || !args.at(0).IsMessage()) { + return absl::Status(absl::StatusCode::kInvalidArgument, + "Bad arguments number"); + } + + *result = args.at(0); + return absl::OkStatus(); + } +}; - ASSERT_OK(step0_status); - ASSERT_OK(step1_status); - ASSERT_OK(step2_status); +class NullFunction : public CelFunction { + public: + NullFunction() + : CelFunction( + CelFunctionDescriptor("Fn", false, {CelValue::Type::kNullType})) {} - path.push_back(std::move(step0_status.value())); - path.push_back(std::move(step1_status.value())); - path.push_back(std::move(step2_status.value())); + absl::Status Evaluate(absl::Span args, CelValue* result, + google::protobuf::Arena* arena) const override { + if (args.size() != 1 || args.at(0).type() != CelValue::Type::kNullType) { + return absl::Status(absl::StatusCode::kInvalidArgument, + "Bad arguments number"); + } - Expr dummy_expr; + *result = CelValue::CreateStringView("null"); + return absl::OkStatus(); + } +}; - CelExpressionFlatImpl impl(&dummy_expr, std::move(path), 0, {}, true, true); +TEST(FunctionStepStrictnessTest, + IfFunctionStrictAndGivenUnknownSkipsInvocation) { + UnknownSet unknown_set; + CelFunctionRegistry registry; + ASSERT_OK(registry.Register(std::make_unique( + CelValue::CreateUnknownSet(&unknown_set), "ConstUnknown"))); + ASSERT_OK(registry.Register(std::make_unique( + CelValue::Type::kUnknownSet, /*is_strict=*/true))); + ExecutionPath path; + CallExpr call0 = ConstFunction::MakeCall("ConstUnknown"); + CallExpr call1 = SinkFunction::MakeCall(); + ASSERT_OK_AND_ASSIGN(std::unique_ptr step0, + MakeTestFunctionStep(call0, registry)); + ASSERT_OK_AND_ASSIGN(std::unique_ptr step1, + MakeTestFunctionStep(call1, registry)); + path.push_back(std::move(step0)); + path.push_back(std::move(step1)); + cel::RuntimeOptions options; + options.unknown_processing = + cel::UnknownProcessingOptions::kAttributeAndFunction; + auto env = NewTestingRuntimeEnv(); + CelExpressionFlatImpl impl( + env, + FlatExpression(std::move(path), + /*comprehension_slot_count=*/0, + env->type_registry.GetComposedTypeProvider(), options)); + Activation activation; + google::protobuf::Arena arena; + ASSERT_OK_AND_ASSIGN(CelValue value, impl.Evaluate(activation, &arena)); + ASSERT_TRUE(value.IsUnknownSet()); +} +TEST(FunctionStepStrictnessTest, IfFunctionNonStrictAndGivenUnknownInvokesIt) { + UnknownSet unknown_set; + CelFunctionRegistry registry; + ASSERT_OK(registry.Register(std::make_unique( + CelValue::CreateUnknownSet(&unknown_set), "ConstUnknown"))); + ASSERT_OK(registry.Register(std::make_unique( + CelValue::Type::kUnknownSet, /*is_strict=*/false))); + ExecutionPath path; + CallExpr call0 = ConstFunction::MakeCall("ConstUnknown"); + CallExpr call1 = SinkFunction::MakeCall(); + ASSERT_OK_AND_ASSIGN(std::unique_ptr step0, + MakeTestFunctionStep(call0, registry)); + ASSERT_OK_AND_ASSIGN(std::unique_ptr step1, + MakeTestFunctionStep(call1, registry)); + path.push_back(std::move(step0)); + path.push_back(std::move(step1)); + Expr placeholder_expr; + cel::RuntimeOptions options; + options.unknown_processing = + cel::UnknownProcessingOptions::kAttributeAndFunction; + auto env = NewTestingRuntimeEnv(); + CelExpressionFlatImpl impl( + env, + FlatExpression(std::move(path), + /*comprehension_slot_count=*/0, + env->type_registry.GetComposedTypeProvider(), options)); Activation activation; google::protobuf::Arena arena; + ASSERT_OK_AND_ASSIGN(CelValue value, impl.Evaluate(activation, &arena)); + ASSERT_THAT(value, test::IsCelInt64(Eq(0))); +} - auto status = impl.Evaluate(activation, &arena); - ASSERT_OK(status); +class DirectFunctionStepTest : public testing::Test { + public: + DirectFunctionStepTest() = default; - auto value = status.value(); + void SetUp() override { + ASSERT_OK(cel::RegisterStandardFunctions(registry_, options_)); + } - ASSERT_TRUE(value.IsError()); - // Making sure we propagate the error. - ASSERT_EQ(value.ErrorOrDie(), error_value.ErrorOrDie()); + std::vector GetOverloads( + absl::string_view name, int64_t arguments_size) { + std::vector matcher; + matcher.resize(arguments_size, cel::Kind::kAny); + return registry_.FindStaticOverloads(name, false, matcher); + } + + // Helper for shorthand constructing direct expr deps. + // + // Works around copies in init-list construction. + std::vector> MakeDeps( + std::unique_ptr dep, + std::unique_ptr dep2) { + std::vector> result; + result.reserve(2); + result.push_back(std::move(dep)); + result.push_back(std::move(dep2)); + return result; + }; + + protected: + cel::FunctionRegistry registry_; + cel::RuntimeOptions options_; + google::protobuf::Arena arena_; +}; + +TEST_F(DirectFunctionStepTest, SimpleCall) { + cel::IntValue(1); + + CallExpr call; + call.set_function(cel::builtin::kAdd); + call.mutable_args().emplace_back(); + call.mutable_args().emplace_back(); + + std::vector> deps; + deps.push_back(CreateConstValueDirectStep(cel::IntValue(1))); + deps.push_back(CreateConstValueDirectStep(cel::IntValue(1))); + + auto expr = CreateDirectFunctionStep(-1, call, std::move(deps), + GetOverloads(cel::builtin::kAdd, 2)); + + auto plan = CreateExpressionImpl(options_, std::move(expr)); + + Activation activation; + ASSERT_OK_AND_ASSIGN(auto value, plan->Evaluate(activation, &arena_)); + + EXPECT_THAT(value, test::IsCelInt64(2)); } -} // namespace +TEST_F(DirectFunctionStepTest, RecursiveCall) { + cel::IntValue(1); + + CallExpr call; + call.set_function(cel::builtin::kAdd); + call.mutable_args().emplace_back(); + call.mutable_args().emplace_back(); + + auto overloads = GetOverloads(cel::builtin::kAdd, 2); + + auto MakeLeaf = [&]() { + return CreateDirectFunctionStep( + -1, call, + MakeDeps(CreateConstValueDirectStep(cel::IntValue(1)), + CreateConstValueDirectStep(cel::IntValue(1))), + overloads); + }; + + auto expr = CreateDirectFunctionStep( + -1, call, + MakeDeps(CreateDirectFunctionStep( + -1, call, MakeDeps(MakeLeaf(), MakeLeaf()), overloads), + CreateDirectFunctionStep( + -1, call, MakeDeps(MakeLeaf(), MakeLeaf()), overloads)), + overloads); + + auto plan = CreateExpressionImpl(options_, std::move(expr)); + + Activation activation; + ASSERT_OK_AND_ASSIGN(auto value, plan->Evaluate(activation, &arena_)); + + EXPECT_THAT(value, test::IsCelInt64(8)); +} + +TEST_F(DirectFunctionStepTest, ErrorHandlingCall) { + cel::IntValue(1); + + CallExpr add_call; + add_call.set_function(cel::builtin::kAdd); + add_call.mutable_args().emplace_back(); + add_call.mutable_args().emplace_back(); + + CallExpr div_call; + div_call.set_function(cel::builtin::kDivide); + div_call.mutable_args().emplace_back(); + div_call.mutable_args().emplace_back(); + + auto add_overloads = GetOverloads(cel::builtin::kAdd, 2); + auto div_overloads = GetOverloads(cel::builtin::kDivide, 2); + + auto error_expr = CreateDirectFunctionStep( + -1, div_call, + MakeDeps(CreateConstValueDirectStep(cel::IntValue(1)), + CreateConstValueDirectStep(cel::IntValue(0))), + div_overloads); -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google + auto expr = CreateDirectFunctionStep( + -1, add_call, + MakeDeps(std::move(error_expr), + CreateConstValueDirectStep(cel::IntValue(1))), + add_overloads); + + auto plan = CreateExpressionImpl(options_, std::move(expr)); + + Activation activation; + ASSERT_OK_AND_ASSIGN(auto value, plan->Evaluate(activation, &arena_)); + + EXPECT_THAT(value, + test::IsCelError(StatusIs(absl::StatusCode::kInvalidArgument, + testing::HasSubstr("divide by zero")))); +} + +TEST_F(DirectFunctionStepTest, NoOverload) { + cel::IntValue(1); + + CallExpr call; + call.set_function(cel::builtin::kAdd); + call.mutable_args().emplace_back(); + call.mutable_args().emplace_back(); + + std::vector> deps; + deps.push_back(CreateConstValueDirectStep(cel::IntValue(1))); + deps.push_back(CreateConstValueDirectStep(cel::StringValue("2"))); + + auto expr = CreateDirectFunctionStep(-1, call, std::move(deps), + GetOverloads(cel::builtin::kAdd, 2)); + + auto plan = CreateExpressionImpl(options_, std::move(expr)); + + Activation activation; + ASSERT_OK_AND_ASSIGN(auto value, plan->Evaluate(activation, &arena_)); + + EXPECT_THAT(value, Truly(CheckNoMatchingOverloadError)); +} + +TEST_F(DirectFunctionStepTest, NoOverload0Args) { + cel::IntValue(1); + + CallExpr call; + call.set_function(cel::builtin::kAdd); + + std::vector> deps; + auto expr = CreateDirectFunctionStep(-1, call, std::move(deps), + GetOverloads(cel::builtin::kAdd, 2)); + + auto plan = CreateExpressionImpl(options_, std::move(expr)); + + Activation activation; + ASSERT_OK_AND_ASSIGN(auto value, plan->Evaluate(activation, &arena_)); + + EXPECT_THAT(value, Truly(CheckNoMatchingOverloadError)); +} + +} // namespace +} // namespace google::api::expr::runtime diff --git a/eval/eval/ident_step.cc b/eval/eval/ident_step.cc index 9c99621bd..7ec1a3031 100644 --- a/eval/eval/ident_step.cc +++ b/eval/eval/ident_step.cc @@ -1,19 +1,32 @@ #include "eval/eval/ident_step.h" -#include "google/protobuf/arena.h" +#include +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/substitute.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "common/value.h" #include "eval/eval/attribute_trail.h" +#include "eval/eval/comprehension_slots.h" +#include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" #include "eval/eval/expression_step_base.h" -#include "eval/public/unknown_attribute_set.h" +#include "eval/internal/errors.h" +#include "internal/status_macros.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace google::api::expr::runtime { namespace { + +using ::cel::Value; +using ::cel::runtime_internal::CreateError; + class IdentStep : public ExpressionStepBase { public: IdentStep(absl::string_view name, int64_t expr_id) @@ -22,92 +35,140 @@ class IdentStep : public ExpressionStepBase { absl::Status Evaluate(ExecutionFrame* frame) const override; private: - void DoEvaluate(ExecutionFrame* frame, CelValue* result, - AttributeTrail* trail) const; - std::string name_; }; -void IdentStep::DoEvaluate(ExecutionFrame* frame, CelValue* result, - AttributeTrail* trail) const { - // Special case - iterator looked up in - if (frame->GetIterVar(name_, result)) { - const AttributeTrail* iter_trail; - if (frame->GetIterAttr(name_, &iter_trail)) { - *trail = *iter_trail; +absl::Status LookupIdent(absl::string_view name, ExecutionFrameBase& frame, + Value& result, AttributeTrail& attribute) { + if (frame.attribute_tracking_enabled()) { + attribute = AttributeTrail(std::string(name)); + if (frame.missing_attribute_errors_enabled() && + frame.attribute_utility().CheckForMissingAttribute(attribute)) { + CEL_ASSIGN_OR_RETURN( + result, frame.attribute_utility().CreateMissingAttributeError( + attribute.attribute())); + return absl::OkStatus(); + } + if (frame.unknown_processing_enabled() && + frame.attribute_utility().CheckForUnknownExact(attribute)) { + result = + frame.attribute_utility().CreateUnknownSet(attribute.attribute()); + return absl::OkStatus(); } - return; } - auto value = frame->activation().FindValue(name_, frame->arena()); + CEL_ASSIGN_OR_RETURN( + auto found, frame.activation().FindVariable(name, frame.descriptor_pool(), + frame.message_factory(), + frame.arena(), &result)); - // Populate trails if either MissingAttributeError or UnknownPattern - // is enabled. - if (frame->enable_missing_attribute_errors() || frame->enable_unknowns()) { - google::api::expr::v1alpha1::Expr expr; - expr.mutable_ident_expr()->set_name(name_); - *trail = AttributeTrail(expr, frame->arena()); + if (found) { + return absl::OkStatus(); } - if (frame->enable_missing_attribute_errors() && !name_.empty() && - frame->attribute_utility().CheckForMissingAttribute(*trail)) { - *result = CreateMissingAttributeError(frame->arena(), name_); - return; - } + result = cel::ErrorValue(CreateError( + absl::StrCat("No value with name \"", name, "\" found in Activation"))); - { - // We handle masked unknown paths for the sake of uniformity, although it is - // better not to bind unknown values to activation in first place. - // TODO(issues/41) Deprecate this style of unknowns handling after - // Unknowns are properly supported. - bool unknown_value = frame->activation().IsPathUnknown(name_); + return absl::OkStatus(); +} - if (unknown_value) { - *result = CreateUnknownValueError(frame->arena(), name_); - return; - } +absl::Status IdentStep::Evaluate(ExecutionFrame* frame) const { + Value value; + AttributeTrail attribute; + + CEL_RETURN_IF_ERROR(LookupIdent(name_, *frame, value, attribute)); + + frame->value_stack().Push(std::move(value), std::move(attribute)); + + return absl::OkStatus(); +} + +absl::StatusOr LookupSlot( + absl::string_view name, size_t slot_index, ExecutionFrameBase& frame) { + ComprehensionSlots::Slot* slot = frame.comprehension_slots().Get(slot_index); + if (!slot->Has()) { + return absl::InternalError( + absl::StrCat("Comprehension variable accessed out of scope: ", name)); } + return slot; +} - if (frame->enable_unknowns()) { - if (frame->attribute_utility().CheckForUnknown(*trail, false)) { - auto unknown_set = google::protobuf::Arena::Create( - frame->arena(), UnknownAttributeSet({trail->attribute()})); - *result = CelValue::CreateUnknownSet(unknown_set); - return; - } +class SlotStep : public ExpressionStepBase { + public: + SlotStep(absl::string_view name, size_t slot_index, int64_t expr_id) + : ExpressionStepBase(expr_id), name_(name), slot_index_(slot_index) {} + + absl::Status Evaluate(ExecutionFrame* frame) const override { + CEL_ASSIGN_OR_RETURN(const ComprehensionSlots::Slot* slot, + LookupSlot(name_, slot_index_, *frame)); + frame->value_stack().Push(slot->value(), slot->attribute()); + return absl::OkStatus(); } - if (value.has_value()) { - *result = value.value(); - } else { - *result = CreateErrorValue( - frame->arena(), - absl::Substitute("No value with name \"$0\" found in Activation", - name_)); + private: + std::string name_; + + size_t slot_index_; +}; + +class DirectIdentStep : public DirectExpressionStep { + public: + DirectIdentStep(absl::string_view name, int64_t expr_id) + : DirectExpressionStep(expr_id), name_(name) {} + + absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, + AttributeTrail& attribute) const override { + return LookupIdent(name_, frame, result, attribute); } -} -absl::Status IdentStep::Evaluate(ExecutionFrame* frame) const { - CelValue result; - AttributeTrail trail; + private: + std::string name_; +}; - DoEvaluate(frame, &result, &trail); +class DirectSlotStep : public DirectExpressionStep { + public: + DirectSlotStep(absl::string_view name, size_t slot_index, int64_t expr_id) + : DirectExpressionStep(expr_id), + name_(std::move(name)), + slot_index_(slot_index) {} + + absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, + AttributeTrail& attribute) const override { + CEL_ASSIGN_OR_RETURN(const ComprehensionSlots::Slot* slot, + LookupSlot(name_, slot_index_, frame)); + + if (frame.attribute_tracking_enabled()) { + attribute = slot->attribute(); + } + result = slot->value(); + return absl::OkStatus(); + } - frame->value_stack().Push(result, trail); + private: + std::string name_; + size_t slot_index_; +}; - return absl::OkStatus(); +} // namespace + +std::unique_ptr CreateDirectIdentStep( + absl::string_view identifier, int64_t expr_id) { + return std::make_unique(identifier, expr_id); } -} // namespace +std::unique_ptr CreateDirectSlotIdentStep( + absl::string_view identifier, size_t slot_index, int64_t expr_id) { + return std::make_unique(identifier, slot_index, expr_id); +} absl::StatusOr> CreateIdentStep( - const google::api::expr::v1alpha1::Expr::Ident* ident_expr, int64_t expr_id) { - std::unique_ptr step = - absl::make_unique(ident_expr->name(), expr_id); - return std::move(step); + const absl::string_view name, int64_t expr_id) { + return std::make_unique(name, expr_id); +} + +absl::StatusOr> CreateIdentStepForSlot( + const absl::string_view name, size_t slot_index, int64_t expr_id) { + return std::make_unique(name, slot_index, expr_id); } -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime diff --git a/eval/eval/ident_step.h b/eval/eval/ident_step.h index 8e18fa637..d1bdde388 100644 --- a/eval/eval/ident_step.h +++ b/eval/eval/ident_step.h @@ -1,22 +1,31 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_IDENT_STEP_H_ #define THIRD_PARTY_CEL_CPP_EVAL_EVAL_IDENT_STEP_H_ +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" -#include "eval/public/activation.h" -#include "eval/public/cel_value.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace google::api::expr::runtime { + +std::unique_ptr CreateDirectIdentStep( + absl::string_view identifier, int64_t expr_id); + +std::unique_ptr CreateDirectSlotIdentStep( + absl::string_view identifier, size_t slot_index, int64_t expr_id); // Factory method for Ident - based Execution step absl::StatusOr> CreateIdentStep( - const google::api::expr::v1alpha1::Expr::Ident* ident, int64_t expr_id); + absl::string_view name, int64_t expr_id); + +// Factory method for identifier that has been assigned to a slot. +absl::StatusOr> CreateIdentStepForSlot( + absl::string_view name, size_t slot_index, int64_t expr_id); -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_IDENT_STEP_H_ diff --git a/eval/eval/ident_step_test.cc b/eval/eval/ident_step_test.cc index df208f035..ce10d7d98 100644 --- a/eval/eval/ident_step_test.cc +++ b/eval/eval/ident_step_test.cc @@ -1,38 +1,61 @@ #include "eval/eval/ident_step.h" -#include "google/api/expr/v1alpha1/syntax.pb.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "base/type_provider.h" +#include "common/casting.h" +#include "common/memory.h" +#include "common/value.h" +#include "eval/eval/attribute_trail.h" +#include "eval/eval/cel_expression_flat_impl.h" #include "eval/eval/evaluator_core.h" -#include "base/status_macros.h" - -namespace google { -namespace api { -namespace expr { -namespace runtime { +#include "eval/public/activation.h" +#include "eval/public/cel_attribute.h" +#include "eval/public/cel_value.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "internal/testing_message_factory.h" +#include "runtime/activation.h" +#include "runtime/internal/runtime_env_testing.h" +#include "runtime/internal/runtime_type_provider.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" + +namespace google::api::expr::runtime { namespace { -using google::api::expr::v1alpha1::Expr; -using google::protobuf::FieldMask; -using testing::Eq; - -using google::protobuf::Arena; +using ::absl_testing::StatusIs; +using ::cel::Cast; +using ::cel::ErrorValue; +using ::cel::InstanceOf; +using ::cel::IntValue; +using ::cel::MemoryManagerRef; +using ::cel::RuntimeOptions; +using ::cel::TypeProvider; +using ::cel::UnknownValue; +using ::cel::Value; +using ::cel::runtime_internal::NewTestingRuntimeEnv; +using ::google::protobuf::Arena; +using ::testing::Eq; +using ::testing::HasSubstr; +using ::testing::SizeIs; TEST(IdentStepTest, TestIdentStep) { - Expr expr; - auto ident_expr = expr.mutable_ident_expr(); - ident_expr->set_name("name0"); - - auto step_status = CreateIdentStep(ident_expr, expr.id()); - ASSERT_OK(step_status); + ASSERT_OK_AND_ASSIGN(auto step, CreateIdentStep("name0", /*id=*/-1)); ExecutionPath path; - path.push_back(std::move(step_status.value())); - - auto dummy_expr = absl::make_unique(); + path.push_back(std::move(step)); - CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), 0, {}); + auto env = NewTestingRuntimeEnv(); + CelExpressionFlatImpl impl( + env, FlatExpression(std::move(path), /*comprehension_slot_count=*/0, + env->type_registry.GetComposedTypeProvider(), + cel::RuntimeOptions{})); Activation activation; Arena arena; @@ -49,19 +72,16 @@ TEST(IdentStepTest, TestIdentStep) { } TEST(IdentStepTest, TestIdentStepNameNotFound) { - Expr expr; - auto ident_expr = expr.mutable_ident_expr(); - ident_expr->set_name("name0"); - - auto step_status = CreateIdentStep(ident_expr, expr.id()); - ASSERT_OK(step_status); + ASSERT_OK_AND_ASSIGN(auto step, CreateIdentStep("name0", /*id=*/-1)); ExecutionPath path; - path.push_back(std::move(step_status.value())); - - auto dummy_expr = absl::make_unique(); + path.push_back(std::move(step)); - CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), 0, {}); + auto env = NewTestingRuntimeEnv(); + CelExpressionFlatImpl impl( + env, FlatExpression(std::move(path), /*comprehension_slot_count=*/0, + env->type_registry.GetComposedTypeProvider(), + cel::RuntimeOptions{})); Activation activation; Arena arena; @@ -75,20 +95,18 @@ TEST(IdentStepTest, TestIdentStepNameNotFound) { } TEST(IdentStepTest, DisableMissingAttributeErrorsOK) { - Expr expr; - auto ident_expr = expr.mutable_ident_expr(); - ident_expr->set_name("name0"); - - auto step_status = CreateIdentStep(ident_expr, expr.id()); - ASSERT_OK(step_status); + ASSERT_OK_AND_ASSIGN(auto step, CreateIdentStep("name0", /*id=*/-1)); ExecutionPath path; - path.push_back(std::move(step_status.value())); - - auto dummy_expr = absl::make_unique(); - - CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), 0, {}, - /*enable_unknowns=*/false); + path.push_back(std::move(step)); + cel::RuntimeOptions options; + options.unknown_processing = cel::UnknownProcessingOptions::kDisabled; + auto env = NewTestingRuntimeEnv(); + CelExpressionFlatImpl impl( + env, + FlatExpression(std::move(path), + /*comprehension_slot_count=*/0, + env->type_registry.GetComposedTypeProvider(), options)); Activation activation; Arena arena; @@ -109,24 +127,25 @@ TEST(IdentStepTest, DisableMissingAttributeErrorsOK) { status0 = impl.Evaluate(activation, &arena); ASSERT_OK(status0); - EXPECT_THAT(status0.value().StringOrDie().value(), Eq("test")); + EXPECT_THAT(status0->StringOrDie().value(), Eq("test")); } TEST(IdentStepTest, TestIdentStepMissingAttributeErrors) { - Expr expr; - auto ident_expr = expr.mutable_ident_expr(); - ident_expr->set_name("name0"); - - auto step_status = CreateIdentStep(ident_expr, expr.id()); - ASSERT_OK(step_status); + ASSERT_OK_AND_ASSIGN(auto step, CreateIdentStep("name0", /*expr_id=*/1)); ExecutionPath path; - path.push_back(std::move(step_status.value())); + path.push_back(std::move(step)); - auto dummy_expr = absl::make_unique(); + cel::RuntimeOptions options; + options.unknown_processing = cel::UnknownProcessingOptions::kDisabled; + options.enable_missing_attribute_errors = true; - CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), 0, {}, false, - false, /*enable_missing_attribute_errors=*/true); + auto env = NewTestingRuntimeEnv(); + CelExpressionFlatImpl impl( + env, + FlatExpression(std::move(path), + /*comprehension_slot_count=*/0, + env->type_registry.GetComposedTypeProvider(), options)); Activation activation; Arena arena; @@ -147,32 +166,35 @@ TEST(IdentStepTest, TestIdentStepMissingAttributeErrors) { status0 = impl.Evaluate(activation, &arena); ASSERT_OK(status0); - EXPECT_EQ(status0.value().ErrorOrDie()->code(), - absl::StatusCode::kInvalidArgument); - EXPECT_EQ(status0.value().ErrorOrDie()->message(), - "MissingAttributeError: name0"); + EXPECT_EQ(status0->ErrorOrDie()->code(), absl::StatusCode::kInvalidArgument); + EXPECT_EQ(status0->ErrorOrDie()->message(), "MissingAttributeError: name0"); } -TEST(IdentStepTest, TestIdentStepUnknownValueError) { - Expr expr; - auto ident_expr = expr.mutable_ident_expr(); - ident_expr->set_name("name0"); - - auto step_status = CreateIdentStep(ident_expr, expr.id()); - ASSERT_OK(step_status); +TEST(IdentStepTest, TestIdentStepUnknownAttribute) { + ASSERT_OK_AND_ASSIGN(auto step, CreateIdentStep("name0", /*expr_id=*/1)); ExecutionPath path; - path.push_back(std::move(step_status.value())); - - auto dummy_expr = absl::make_unique(); + path.push_back(std::move(step)); - CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), 0, {}); + // Expression with unknowns enabled. + cel::RuntimeOptions options; + options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; + auto env = NewTestingRuntimeEnv(); + CelExpressionFlatImpl impl( + env, + FlatExpression(std::move(path), + /*comprehension_slot_count=*/0, + env->type_registry.GetComposedTypeProvider(), options)); Activation activation; Arena arena; std::string value("test"); activation.InsertValue("name0", CelValue::CreateString(&value)); + std::vector unknown_patterns; + unknown_patterns.push_back(CelAttributePattern("name_bad", {})); + + activation.set_unknown_attribute_patterns(unknown_patterns); auto status0 = impl.Evaluate(activation, &arena); ASSERT_OK(status0); @@ -181,68 +203,114 @@ TEST(IdentStepTest, TestIdentStepUnknownValueError) { ASSERT_TRUE(result.IsString()); EXPECT_THAT(result.StringOrDie().value(), Eq("test")); - FieldMask unknown_mask; - unknown_mask.add_paths("name0"); + unknown_patterns.push_back(CelAttributePattern("name0", {})); - activation.set_unknown_paths(unknown_mask); + activation.set_unknown_attribute_patterns(unknown_patterns); status0 = impl.Evaluate(activation, &arena); ASSERT_OK(status0); result = status0.value(); - ASSERT_TRUE(result.IsError()); - ASSERT_TRUE(IsUnknownValueError(result)); - EXPECT_THAT(GetUnknownPathsSetOrDie(result), - Eq(std::set({"name0"}))); + ASSERT_TRUE(result.IsUnknownSet()); } -TEST(IdentStepTest, TestIdentStepUnknownAttribute) { - Expr expr; - auto ident_expr = expr.mutable_ident_expr(); - ident_expr->set_name("name0"); +TEST(DirectIdentStepTest, Basic) { + google::protobuf::Arena arena; + cel::runtime_internal::RuntimeTypeProvider type_provider( + cel::internal::GetTestingDescriptorPool()); + cel::Activation activation; + RuntimeOptions options; - auto step_status = CreateIdentStep(ident_expr, expr.id()); - ASSERT_OK(step_status); + activation.InsertOrAssignValue("var1", IntValue(42)); - ExecutionPath path; - path.push_back(std::move(step_status.value())); + ExecutionFrameBase frame(activation, options, type_provider, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena); + Value result; + AttributeTrail trail; - auto dummy_expr = absl::make_unique(); + auto step = CreateDirectIdentStep("var1", -1); - // Expression with unknowns enabled. - CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), 0, {}, true); + ASSERT_OK(step->Evaluate(frame, result, trail)); - Activation activation; - Arena arena; - std::string value("test"); + ASSERT_TRUE(InstanceOf(result)); + EXPECT_THAT(Cast(result).NativeValue(), Eq(42)); +} - activation.InsertValue("name0", CelValue::CreateString(&value)); - std::vector unknown_patterns; - unknown_patterns.push_back(CelAttributePattern("name_bad", {})); +TEST(DirectIdentStepTest, UnknownAttribute) { + google::protobuf::Arena arena; + cel::runtime_internal::RuntimeTypeProvider type_provider( + cel::internal::GetTestingDescriptorPool()); + cel::Activation activation; + RuntimeOptions options; + options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; - activation.set_unknown_attribute_patterns(unknown_patterns); - auto status0 = impl.Evaluate(activation, &arena); - ASSERT_OK(status0); + activation.InsertOrAssignValue("var1", IntValue(42)); + activation.SetUnknownPatterns({CreateCelAttributePattern("var1", {})}); - CelValue result = status0.value(); + ExecutionFrameBase frame(activation, options, type_provider, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena); + Value result; + AttributeTrail trail; - ASSERT_TRUE(result.IsString()); - EXPECT_THAT(result.StringOrDie().value(), Eq("test")); + auto step = CreateDirectIdentStep("var1", -1); - unknown_patterns.push_back(CelAttributePattern("name0", {})); + ASSERT_OK(step->Evaluate(frame, result, trail)); - activation.set_unknown_attribute_patterns(unknown_patterns); - status0 = impl.Evaluate(activation, &arena); - ASSERT_OK(status0); + ASSERT_TRUE(InstanceOf(result)); + EXPECT_THAT(Cast(result).attribute_set(), SizeIs(1)); +} - result = status0.value(); +TEST(DirectIdentStepTest, MissingAttribute) { + google::protobuf::Arena arena; + cel::runtime_internal::RuntimeTypeProvider type_provider( + cel::internal::GetTestingDescriptorPool()); + cel::Activation activation; + RuntimeOptions options; + options.enable_missing_attribute_errors = true; - ASSERT_TRUE(result.IsUnknownSet()); + activation.InsertOrAssignValue("var1", IntValue(42)); + activation.SetMissingPatterns({CreateCelAttributePattern("var1", {})}); + + ExecutionFrameBase frame(activation, options, type_provider, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena); + Value result; + AttributeTrail trail; + + auto step = CreateDirectIdentStep("var1", -1); + + ASSERT_OK(step->Evaluate(frame, result, trail)); + + ASSERT_TRUE(InstanceOf(result)); + EXPECT_THAT(Cast(result).NativeValue(), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("var1"))); +} + +TEST(DirectIdentStepTest, NotFound) { + google::protobuf::Arena arena; + cel::runtime_internal::RuntimeTypeProvider type_provider( + cel::internal::GetTestingDescriptorPool()); + cel::Activation activation; + RuntimeOptions options; + + ExecutionFrameBase frame(activation, options, type_provider, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena); + Value result; + AttributeTrail trail; + + auto step = CreateDirectIdentStep("var1", -1); + + ASSERT_OK(step->Evaluate(frame, result, trail)); + + ASSERT_TRUE(InstanceOf(result)); + EXPECT_THAT(Cast(result).NativeValue(), + StatusIs(absl::StatusCode::kUnknown, + HasSubstr("\"var1\" found in Activation"))); } } // namespace -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime diff --git a/eval/eval/iterator_stack.h b/eval/eval/iterator_stack.h new file mode 100644 index 000000000..9b5daa889 --- /dev/null +++ b/eval/eval/iterator_stack.h @@ -0,0 +1,77 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_ITERATOR_STACK_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_EVAL_ITERATOR_STACK_H_ + +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "common/value.h" + +namespace cel::runtime_internal { + +class IteratorStack final { + public: + explicit IteratorStack(size_t max_size) : max_size_(max_size) { + iterators_.reserve(max_size_); + } + + IteratorStack(const IteratorStack&) = delete; + IteratorStack(IteratorStack&&) = delete; + + IteratorStack& operator=(const IteratorStack&) = delete; + IteratorStack& operator=(IteratorStack&&) = delete; + + size_t size() const { return iterators_.size(); } + + bool empty() const { return iterators_.empty(); } + + bool full() const { return iterators_.size() == max_size_; } + + size_t max_size() const { return max_size_; } + + void Clear() { iterators_.clear(); } + + void Push(absl_nonnull ValueIteratorPtr iterator) { + ABSL_DCHECK(!full()); + ABSL_DCHECK(iterator != nullptr); + + iterators_.push_back(std::move(iterator)); + } + + ValueIterator* absl_nonnull Peek() { + ABSL_DCHECK(!empty()); + ABSL_DCHECK(iterators_.back() != nullptr); + + return iterators_.back().get(); + } + + void Pop() { + ABSL_DCHECK(!empty()); + + iterators_.pop_back(); + } + + private: + std::vector iterators_; + size_t max_size_; +}; + +} // namespace cel::runtime_internal + +#endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_ITERATOR_STACK_H_ diff --git a/eval/eval/jump_step.cc b/eval/eval/jump_step.cc index 5c85a645b..a65789841 100644 --- a/eval/eval/jump_step.cc +++ b/eval/eval/jump_step.cc @@ -1,15 +1,39 @@ +// Copyright 2017 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #include "eval/eval/jump_step.h" +#include +#include +#include + +#include "absl/status/status.h" #include "absl/status/statusor.h" -#include "eval/eval/expression_step_base.h" +#include "absl/types/optional.h" +#include "common/value.h" +#include "eval/internal/errors.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace google::api::expr::runtime { namespace { +using ::cel::BoolValue; +using ::cel::ErrorValue; +using ::cel::UnknownValue; +using ::cel::Value; +using ::cel::runtime_internal::CreateNoMatchingOverloadError; + class JumpStep : public JumpStepBase { public: // Constructs FunctionStep that uses overloads specified. @@ -36,13 +60,15 @@ class CondJumpStep : public JumpStepBase { return absl::Status(absl::StatusCode::kInternal, "Value stack underflow"); } - CelValue value = frame->value_stack().Peek(); + const auto& value = frame->value_stack().Peek(); + const auto should_jump = value.Is() && + jump_condition_ == value.GetBool().NativeValue(); if (!leave_on_stack_) { frame->value_stack().Pop(1); } - if (value.IsBool() && jump_condition_ == value.BoolOrDie()) { + if (should_jump) { return Jump(frame); } @@ -71,22 +97,22 @@ class BoolCheckJumpStep : public JumpStepBase { return absl::Status(absl::StatusCode::kInternal, "Value stack underflow"); } - CelValue value = frame->value_stack().Peek(); + const Value& value = frame->value_stack().Peek(); - if (value.IsError()) { - return Jump(frame); + if (value->Is()) { + return absl::OkStatus(); } - if (value.IsUnknownSet()) { + if (value->Is() || value->Is()) { return Jump(frame); } - if (!value.IsBool()) { - CelValue error_value = - CreateNoMatchingOverloadError(frame->arena(), ""); - frame->value_stack().PopAndPush(error_value); - return Jump(frame); - } + // Neither bool, error, nor unknown set. + Value error_value = + cel::ErrorValue(CreateNoMatchingOverloadError("")); + + frame->value_stack().PopAndPush(std::move(error_value)); + return Jump(frame); return absl::OkStatus(); } @@ -97,39 +123,25 @@ class BoolCheckJumpStep : public JumpStepBase { // Factory method for Conditional Jump step. // Conditional Jump requires a boolean value to sit on the stack. // It is compared to jump_condition, and if matched, jump is performed. -absl::StatusOr> CreateCondJumpStep( +std::unique_ptr CreateCondJumpStep( bool jump_condition, bool leave_on_stack, absl::optional jump_offset, int64_t expr_id) { - std::unique_ptr step = absl::make_unique( - jump_condition, leave_on_stack, jump_offset, expr_id); - - return std::move(step); + return std::make_unique(jump_condition, leave_on_stack, + jump_offset, expr_id); } // Factory method for Jump step. -absl::StatusOr> CreateJumpStep( - absl::optional jump_offset, int64_t expr_id) { - std::unique_ptr step = - absl::make_unique(jump_offset, expr_id); - - return std::move(step); +std::unique_ptr CreateJumpStep(absl::optional jump_offset, + int64_t expr_id) { + return std::make_unique(jump_offset, expr_id); } // Factory method for Conditional Jump step. // Conditional Jump requires a value to sit on the stack. // If this value is an error or unknown, a jump is performed. -absl::StatusOr> CreateBoolCheckJumpStep( +std::unique_ptr CreateBoolCheckJumpStep( absl::optional jump_offset, int64_t expr_id) { - std::unique_ptr step = - absl::make_unique(jump_offset, expr_id); - - return std::move(step); + return std::make_unique(jump_offset, expr_id); } -// TODO(issues/41) Make sure Unknowns are properly supported by ternary -// operation. - -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime diff --git a/eval/eval/jump_step.h b/eval/eval/jump_step.h index 3dca34d27..55147da5f 100644 --- a/eval/eval/jump_step.h +++ b/eval/eval/jump_step.h @@ -1,16 +1,30 @@ +// Copyright 2017 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_JUMP_STEP_H_ #define THIRD_PARTY_CEL_CPP_EVAL_EVAL_JUMP_STEP_H_ +#include +#include + +#include "cel/expr/syntax.pb.h" +#include "absl/status/status.h" +#include "absl/types/optional.h" #include "eval/eval/evaluator_core.h" #include "eval/eval/expression_step_base.h" -#include "eval/public/activation.h" -#include "eval/public/cel_value.h" -#include "google/api/expr/v1alpha1/syntax.pb.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace google::api::expr::runtime { class JumpStepBase : public ExpressionStepBase { public: @@ -31,27 +45,24 @@ class JumpStepBase : public ExpressionStepBase { }; // Factory method for Jump step. -absl::StatusOr> CreateJumpStep( - absl::optional jump_offset, int64_t expr_id); +std::unique_ptr CreateJumpStep(absl::optional jump_offset, + int64_t expr_id); // Factory method for Conditional Jump step. // Conditional Jump requires a boolean value to sit on the stack. // It is compared to jump_condition, and if matched, jump is performed. // leave on stack indicates whether value should be kept on top of the stack or // removed. -absl::StatusOr> CreateCondJumpStep( +std::unique_ptr CreateCondJumpStep( bool jump_condition, bool leave_on_stack, absl::optional jump_offset, int64_t expr_id); // Factory method for ErrorJump step. // This step performs a Jump when an Error is on the top of the stack. // Value is left on stack if it is a bool or an error. -absl::StatusOr> CreateBoolCheckJumpStep( +std::unique_ptr CreateBoolCheckJumpStep( absl::optional jump_offset, int64_t expr_id); -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_JUMP_STEP_H_ diff --git a/eval/eval/lazy_init_step.cc b/eval/eval/lazy_init_step.cc new file mode 100644 index 000000000..eb9be7796 --- /dev/null +++ b/eval/eval/lazy_init_step.cc @@ -0,0 +1,236 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/eval/lazy_init_step.h" + +#include +#include +#include +#include + +#include "cel/expr/value.pb.h" +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "common/value.h" +#include "eval/eval/attribute_trail.h" +#include "eval/eval/comprehension_slots.h" +#include "eval/eval/direct_expression_step.h" +#include "eval/eval/evaluator_core.h" +#include "eval/eval/expression_step_base.h" +#include "internal/status_macros.h" + +namespace google::api::expr::runtime { + +namespace { + +using ::cel::Value; + +class LazyInitStep final : public ExpressionStepBase { + public: + LazyInitStep(size_t slot_index, size_t subexpression_index, int64_t expr_id) + : ExpressionStepBase(expr_id), + slot_index_(slot_index), + subexpression_index_(subexpression_index) {} + + absl::Status Evaluate(ExecutionFrame* frame) const override { + ComprehensionSlot* slot = frame->comprehension_slots().Get(slot_index_); + if (slot->Has()) { + frame->value_stack().Push(slot->value(), slot->attribute()); + } else { + frame->Call(slot_index_, subexpression_index_); + } + return absl::OkStatus(); + } + + private: + const size_t slot_index_; + const size_t subexpression_index_; +}; + +class DirectLazyInitStep final : public DirectExpressionStep { + public: + DirectLazyInitStep(size_t slot_index, + const DirectExpressionStep* subexpression, int64_t expr_id) + : DirectExpressionStep(expr_id), + slot_index_(slot_index), + subexpression_(subexpression) {} + + absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, + AttributeTrail& attribute) const override { + ComprehensionSlot* slot = frame.comprehension_slots().Get(slot_index_); + if (slot->Has()) { + result = slot->value(); + attribute = slot->attribute(); + } else { + CEL_RETURN_IF_ERROR(subexpression_->Evaluate(frame, result, attribute)); + slot->Set(result, attribute); + } + return absl::OkStatus(); + } + + private: + const size_t slot_index_; + const DirectExpressionStep* absl_nonnull const subexpression_; +}; + +class BindStep : public DirectExpressionStep { + public: + BindStep(size_t slot_index, + std::unique_ptr subexpression, int64_t expr_id) + : DirectExpressionStep(expr_id), + slot_index_(slot_index), + subexpression_(std::move(subexpression)) {} + + absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, + AttributeTrail& attribute) const override { + CEL_RETURN_IF_ERROR(subexpression_->Evaluate(frame, result, attribute)); + + frame.comprehension_slots().ClearSlot(slot_index_); + + return absl::OkStatus(); + } + + private: + size_t slot_index_; + std::unique_ptr subexpression_; +}; + +class AssignSlotAndPopStepStep final : public ExpressionStepBase { + public: + explicit AssignSlotAndPopStepStep(size_t slot_index) + : ExpressionStepBase(/*expr_id=*/-1, /*comes_from_ast=*/false), + slot_index_(slot_index) {} + + absl::Status Evaluate(ExecutionFrame* frame) const override { + if (!frame->value_stack().HasEnough(1)) { + return absl::InternalError("Stack underflow assigning lazy value"); + } + + frame->comprehension_slots().Set(slot_index_, frame->value_stack().Peek(), + frame->value_stack().PeekAttribute()); + frame->value_stack().Pop(1); + + return absl::OkStatus(); + } + + private: + const size_t slot_index_; +}; + +class ClearSlotStep : public ExpressionStepBase { + public: + explicit ClearSlotStep(size_t slot_index, int64_t expr_id) + : ExpressionStepBase(expr_id), slot_index_(slot_index) {} + + absl::Status Evaluate(ExecutionFrame* frame) const override { + frame->comprehension_slots().ClearSlot(slot_index_); + return absl::OkStatus(); + } + + private: + size_t slot_index_; +}; + +class ClearSlotsStep final : public ExpressionStepBase { + public: + explicit ClearSlotsStep(size_t slot_index, size_t slot_count, int64_t expr_id) + : ExpressionStepBase(expr_id), + slot_index_(slot_index), + slot_count_(slot_count) {} + + absl::Status Evaluate(ExecutionFrame* frame) const override { + for (size_t i = 0; i < slot_count_; ++i) { + frame->comprehension_slots().ClearSlot(slot_index_ + i); + } + return absl::OkStatus(); + } + + private: + const size_t slot_index_; + const size_t slot_count_; +}; + +class BlockStep : public DirectExpressionStep { + public: + BlockStep(size_t slot_index, size_t slot_count, + std::unique_ptr subexpression, + int64_t expr_id) + : DirectExpressionStep(expr_id), + slot_index_(slot_index), + slot_count_(slot_count), + subexpression_(std::move(subexpression)) {} + + absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, + AttributeTrail& attribute) const override { + CEL_RETURN_IF_ERROR(subexpression_->Evaluate(frame, result, attribute)); + + for (size_t i = 0; i < slot_count_; ++i) { + frame.comprehension_slots().ClearSlot(slot_index_ + i); + } + + return absl::OkStatus(); + } + + private: + size_t slot_index_; + size_t slot_count_; + std::unique_ptr subexpression_; +}; + +} // namespace + +std::unique_ptr CreateDirectBindStep( + size_t slot_index, std::unique_ptr expression, + int64_t expr_id) { + return std::make_unique(slot_index, std::move(expression), expr_id); +} + +std::unique_ptr CreateDirectBlockStep( + size_t slot_index, size_t slot_count, + std::unique_ptr expression, int64_t expr_id) { + return std::make_unique(slot_index, slot_count, + std::move(expression), expr_id); +} + +std::unique_ptr CreateDirectLazyInitStep( + size_t slot_index, const DirectExpressionStep* absl_nonnull subexpression, + int64_t expr_id) { + return std::make_unique(slot_index, subexpression, + expr_id); +} + +std::unique_ptr CreateLazyInitStep(size_t slot_index, + size_t subexpression_index, + int64_t expr_id) { + return std::make_unique(slot_index, subexpression_index, + expr_id); +} + +std::unique_ptr CreateAssignSlotAndPopStep(size_t slot_index) { + return std::make_unique(slot_index); +} + +std::unique_ptr CreateClearSlotStep(size_t slot_index, + int64_t expr_id) { + return std::make_unique(slot_index, expr_id); +} + +std::unique_ptr CreateClearSlotsStep(size_t slot_index, + size_t slot_count, + int64_t expr_id) { + ABSL_DCHECK_GT(slot_count, 0); + return std::make_unique(slot_index, slot_count, expr_id); +} + +} // namespace google::api::expr::runtime diff --git a/eval/eval/lazy_init_step.h b/eval/eval/lazy_init_step.h new file mode 100644 index 000000000..714308dfd --- /dev/null +++ b/eval/eval/lazy_init_step.h @@ -0,0 +1,87 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Program steps for lazily initialized aliases (e.g. cel.bind). +// +// When used, any reference to variable should be replaced with a conditional +// step that either runs the initialization routine or pushes the already +// initialized variable to the stack. +// +// All references to the variable should be replaced with: +// +// +-----------------+-------------------+--------------------+ +// | stack | pc | step | +// +-----------------+-------------------+--------------------+ +// | {} | 0 | check init slot(i) | +// +-----------------+-------------------+--------------------+ +// | {value} | 1 | assign slot(i) | +// +-----------------+-------------------+--------------------+ +// | {value} | 2 | | +// +-----------------+-------------------+--------------------+ +// | .... | +// +-----------------+-------------------+--------------------+ +// | {...} | n (end of scope) | clear slot(i) | +// +-----------------+-------------------+--------------------+ + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_LAZY_INIT_STEP_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_EVAL_LAZY_INIT_STEP_H_ + +#include +#include +#include + +#include "absl/base/nullability.h" +#include "eval/eval/direct_expression_step.h" +#include "eval/eval/evaluator_core.h" + +namespace google::api::expr::runtime { + +// Creates a step representing a Bind expression. +std::unique_ptr CreateDirectBindStep( + size_t slot_index, std::unique_ptr expression, + int64_t expr_id); + +// Creates a step representing a cel.@block expression. +std::unique_ptr CreateDirectBlockStep( + size_t slot_index, size_t slot_count, + std::unique_ptr expression, int64_t expr_id); + +// Creates a direct step representing accessing a lazily evaluated alias from +// a bind or block. +std::unique_ptr CreateDirectLazyInitStep( + size_t slot_index, const DirectExpressionStep* absl_nonnull subexpression, + int64_t expr_id); + +// Creates a step representing accessing a lazily evaluated alias from +// a bind or block. +std::unique_ptr CreateLazyInitStep(size_t slot_index, + size_t subexpression_index, + int64_t expr_id); + +// Helper step to assign a slot value from the top of stack on initialization. +std::unique_ptr CreateAssignSlotAndPopStep(size_t slot_index); + +// Helper step to clear a slot. +// Slots may be reused in different contexts so need to be cleared after a +// context is done. +std::unique_ptr CreateClearSlotStep(size_t slot_index, + int64_t expr_id); + +std::unique_ptr CreateClearSlotsStep(size_t slot_index, + size_t slot_count, + int64_t expr_id); + +} // namespace google::api::expr::runtime + +#endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_LAZY_INIT_STEP_H_ diff --git a/eval/eval/lazy_init_step_test.cc b/eval/eval/lazy_init_step_test.cc new file mode 100644 index 000000000..b9bef90a1 --- /dev/null +++ b/eval/eval/lazy_init_step_test.cc @@ -0,0 +1,154 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/eval/lazy_init_step.h" + +#include +#include + +#include "base/type_provider.h" +#include "common/value.h" +#include "eval/eval/const_value_step.h" +#include "eval/eval/evaluator_core.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "internal/testing_message_factory.h" +#include "runtime/activation.h" +#include "runtime/internal/runtime_type_provider.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" + +namespace google::api::expr::runtime { +namespace { + +using ::cel::Activation; +using ::cel::IntValue; +using ::cel::RuntimeOptions; +using ::cel::TypeProvider; + +class LazyInitStepTest : public testing::Test { + private: + // arbitrary numbers enough for basic tests. + static constexpr size_t kValueStack = 5; + static constexpr size_t kComprehensionSlotCount = 3; + + public: + LazyInitStepTest() + : type_provider_(cel::internal::GetTestingDescriptorPool()), + evaluator_state_(kValueStack, kComprehensionSlotCount, type_provider_, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena_) {} + + protected: + google::protobuf::Arena arena_; + cel::runtime_internal::RuntimeTypeProvider type_provider_; + FlatExpressionEvaluatorState evaluator_state_; + RuntimeOptions runtime_options_; + Activation activation_; +}; + +TEST_F(LazyInitStepTest, CreateCheckInitStepDoesInit) { + ExecutionPath path; + ExecutionPath subpath; + + path.push_back(CreateLazyInitStep(/*slot_index=*/0, + /*subexpression_index=*/1, -1)); + + ASSERT_OK_AND_ASSIGN(subpath.emplace_back(), + CreateConstValueStep(cel::IntValue(42), -1, false)); + + std::vector expression_table{path, subpath}; + + ExecutionFrame frame(expression_table, activation_, runtime_options_, + evaluator_state_); + ASSERT_OK_AND_ASSIGN(auto value, frame.Evaluate()); + + EXPECT_TRUE(value->Is() && value.GetInt().NativeValue() == 42); +} + +TEST_F(LazyInitStepTest, CreateCheckInitStepSkipInit) { + ExecutionPath path; + ExecutionPath subpath; + + // This is the expected usage, but in this test we are just depending on the + // fact that these don't change the stack and fit the program layout + // requirements. + path.push_back(CreateLazyInitStep(/*slot_index=*/0, -1, -1)); + + ASSERT_OK_AND_ASSIGN(subpath.emplace_back(), + CreateConstValueStep(cel::IntValue(42), -1, false)); + + std::vector expression_table{path, subpath}; + + ExecutionFrame frame(expression_table, activation_, runtime_options_, + evaluator_state_); + frame.comprehension_slots().Set(0, cel::IntValue(42)); + ASSERT_OK_AND_ASSIGN(auto value, frame.Evaluate()); + + EXPECT_TRUE(value->Is() && value.GetInt().NativeValue() == 42); +} + +TEST_F(LazyInitStepTest, CreateAssignSlotAndPopStepBasic) { + ExecutionPath path; + + path.push_back(CreateAssignSlotAndPopStep(0)); + + ExecutionFrame frame(path, activation_, runtime_options_, evaluator_state_); + frame.comprehension_slots().ClearSlot(0); + + frame.value_stack().Push(cel::IntValue(42)); + + // This will error because no return value, step will still evaluate. + frame.Evaluate().IgnoreError(); + + auto* slot = frame.comprehension_slots().Get(0); + ASSERT_TRUE(slot->Has()); + EXPECT_TRUE(slot->value()->Is() && + slot->value().GetInt().NativeValue() == 42); + EXPECT_TRUE(frame.value_stack().empty()); +} + +TEST_F(LazyInitStepTest, CreateClearSlotStepBasic) { + ExecutionPath path; + + path.push_back(CreateClearSlotStep(0, -1)); + + ExecutionFrame frame(path, activation_, runtime_options_, evaluator_state_); + frame.comprehension_slots().Set(0, cel::IntValue(42)); + + // This will error because no return value, step will still evaluate. + frame.Evaluate().IgnoreError(); + + auto* slot = frame.comprehension_slots().Get(0); + ASSERT_FALSE(slot->Has()); +} + +TEST_F(LazyInitStepTest, CreateClearSlotsStepBasic) { + ExecutionPath path; + + path.push_back(CreateClearSlotsStep(0, 2, -1)); + + ExecutionFrame frame(path, activation_, runtime_options_, evaluator_state_); + frame.comprehension_slots().Set(0, cel::IntValue(42)); + frame.comprehension_slots().Set(1, cel::IntValue(42)); + + // This will error because no return value, step will still evaluate. + frame.Evaluate().IgnoreError(); + + EXPECT_FALSE(frame.comprehension_slots().Get(0)->Has()); + EXPECT_FALSE(frame.comprehension_slots().Get(1)->Has()); +} + +} // namespace +} // namespace google::api::expr::runtime diff --git a/eval/eval/logic_step.cc b/eval/eval/logic_step.cc index ed4da4700..f844d8c05 100644 --- a/eval/eval/logic_step.cc +++ b/eval/eval/logic_step.cc @@ -1,86 +1,253 @@ #include "eval/eval/logic_step.h" +#include +#include +#include +#include + +#include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "base/builtins.h" +#include "common/casting.h" +#include "common/value.h" +#include "common/value_kind.h" +#include "eval/eval/attribute_trail.h" +#include "eval/eval/direct_expression_step.h" +#include "eval/eval/evaluator_core.h" #include "eval/eval/expression_step_base.h" -#include "eval/public/cel_builtins.h" -#include "eval/public/cel_value.h" -#include "eval/public/unknown_attribute_set.h" +#include "eval/internal/errors.h" +#include "internal/status_macros.h" +#include "runtime/internal/errors.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace google::api::expr::runtime { namespace { -class LogicalOpStep : public ExpressionStepBase { +using ::cel::BoolValue; +using ::cel::Cast; +using ::cel::ErrorValue; +using ::cel::InstanceOf; +using ::cel::UnknownValue; +using ::cel::Value; +using ::cel::ValueKind; +using ::cel::runtime_internal::CreateNoMatchingOverloadError; + +enum class OpType { kAnd, kOr }; + +// Shared logic for the fall through case (we didn't see the shortcircuit +// value). +absl::Status ReturnLogicResult(ExecutionFrameBase& frame, OpType op_type, + Value& lhs_result, Value& rhs_result, + AttributeTrail& attribute_trail, + AttributeTrail& rhs_attr) { + ValueKind lhs_kind = lhs_result.kind(); + ValueKind rhs_kind = rhs_result.kind(); + + if (frame.unknown_processing_enabled()) { + if (lhs_kind == ValueKind::kUnknown && rhs_kind == ValueKind::kUnknown) { + lhs_result = frame.attribute_utility().MergeUnknownValues( + Cast(lhs_result), Cast(rhs_result)); + // Clear attribute trail so this doesn't get re-identified as a new + // unknown and reset the accumulated attributes. + attribute_trail = AttributeTrail(); + return absl::OkStatus(); + } else if (lhs_kind == ValueKind::kUnknown) { + return absl::OkStatus(); + } else if (rhs_kind == ValueKind::kUnknown) { + lhs_result = std::move(rhs_result); + attribute_trail = std::move(rhs_attr); + return absl::OkStatus(); + } + } + + if (lhs_kind == ValueKind::kError) { + return absl::OkStatus(); + } else if (rhs_kind == ValueKind::kError) { + lhs_result = std::move(rhs_result); + attribute_trail = std::move(rhs_attr); + return absl::OkStatus(); + } + + if (lhs_kind == ValueKind::kBool && rhs_kind == ValueKind::kBool) { + return absl::OkStatus(); + } + + // Otherwise, add a no overload error. + attribute_trail = AttributeTrail(); + lhs_result = cel::ErrorValue(CreateNoMatchingOverloadError( + op_type == OpType::kOr ? cel::builtin::kOr : cel::builtin::kAnd)); + return absl::OkStatus(); +} + +class ExhaustiveDirectLogicStep : public DirectExpressionStep { + public: + explicit ExhaustiveDirectLogicStep(std::unique_ptr lhs, + std::unique_ptr rhs, + OpType op_type, int64_t expr_id) + : DirectExpressionStep(expr_id), + lhs_(std::move(lhs)), + rhs_(std::move(rhs)), + op_type_(op_type) {} + + absl::Status Evaluate(ExecutionFrameBase& frame, cel::Value& result, + AttributeTrail& attribute_trail) const override; + + private: + std::unique_ptr lhs_; + std::unique_ptr rhs_; + OpType op_type_; +}; + +absl::Status ExhaustiveDirectLogicStep::Evaluate( + ExecutionFrameBase& frame, cel::Value& result, + AttributeTrail& attribute_trail) const { + CEL_RETURN_IF_ERROR(lhs_->Evaluate(frame, result, attribute_trail)); + ValueKind lhs_kind = result.kind(); + + Value rhs_result; + AttributeTrail rhs_attr; + CEL_RETURN_IF_ERROR(rhs_->Evaluate(frame, rhs_result, attribute_trail)); + + ValueKind rhs_kind = rhs_result.kind(); + if (lhs_kind == ValueKind::kBool) { + bool lhs_bool = Cast(result).NativeValue(); + if ((op_type_ == OpType::kOr && lhs_bool) || + (op_type_ == OpType::kAnd && !lhs_bool)) { + return absl::OkStatus(); + } + } + + if (rhs_kind == ValueKind::kBool) { + bool rhs_bool = Cast(rhs_result).NativeValue(); + if ((op_type_ == OpType::kOr && rhs_bool) || + (op_type_ == OpType::kAnd && !rhs_bool)) { + result = std::move(rhs_result); + attribute_trail = std::move(rhs_attr); + return absl::OkStatus(); + } + } + + return ReturnLogicResult(frame, op_type_, result, rhs_result, attribute_trail, + rhs_attr); +} + +class DirectLogicStep : public DirectExpressionStep { public: - enum class OpType { AND, OR }; + explicit DirectLogicStep(std::unique_ptr lhs, + std::unique_ptr rhs, + OpType op_type, int64_t expr_id) + : DirectExpressionStep(expr_id), + lhs_(std::move(lhs)), + rhs_(std::move(rhs)), + op_type_(op_type) {} + + absl::Status Evaluate(ExecutionFrameBase& frame, cel::Value& result, + AttributeTrail& attribute_trail) const override; + + private: + std::unique_ptr lhs_; + std::unique_ptr rhs_; + OpType op_type_; +}; + +absl::Status DirectLogicStep::Evaluate(ExecutionFrameBase& frame, Value& result, + AttributeTrail& attribute_trail) const { + CEL_RETURN_IF_ERROR(lhs_->Evaluate(frame, result, attribute_trail)); + ValueKind lhs_kind = result.kind(); + if (lhs_kind == ValueKind::kBool) { + bool lhs_bool = Cast(result).NativeValue(); + if ((op_type_ == OpType::kOr && lhs_bool) || + (op_type_ == OpType::kAnd && !lhs_bool)) { + return absl::OkStatus(); + } + } + + Value rhs_result; + AttributeTrail rhs_attr; + CEL_RETURN_IF_ERROR(rhs_->Evaluate(frame, rhs_result, attribute_trail)); + + ValueKind rhs_kind = rhs_result.kind(); + + if (rhs_kind == ValueKind::kBool) { + bool rhs_bool = Cast(rhs_result).NativeValue(); + if ((op_type_ == OpType::kOr && rhs_bool) || + (op_type_ == OpType::kAnd && !rhs_bool)) { + result = std::move(rhs_result); + attribute_trail = std::move(rhs_attr); + return absl::OkStatus(); + } + } + + return ReturnLogicResult(frame, op_type_, result, rhs_result, attribute_trail, + rhs_attr); +} + +class LogicalOpStep : public ExpressionStepBase { + public: // Constructs FunctionStep that uses overloads specified. LogicalOpStep(OpType op_type, int64_t expr_id) : ExpressionStepBase(expr_id), op_type_(op_type) { - shortcircuit_ = (op_type_ == OpType::OR); + shortcircuit_ = (op_type_ == OpType::kOr); } absl::Status Evaluate(ExecutionFrame* frame) const override; private: - absl::Status Calculate(ExecutionFrame* frame, absl::Span args, - CelValue* result) const { + void Calculate(ExecutionFrame* frame, absl::Span args, + Value& result) const { bool bool_args[2]; bool has_bool_args[2]; for (size_t i = 0; i < args.size(); i++) { - has_bool_args[i] = args[i].GetValue(bool_args + i); - if (has_bool_args[i] && shortcircuit_ == bool_args[i]) { - *result = CelValue::CreateBool(bool_args[i]); - return absl::OkStatus(); + has_bool_args[i] = args[i]->Is(); + if (has_bool_args[i]) { + bool_args[i] = args[i].GetBool().NativeValue(); + if (bool_args[i] == shortcircuit_) { + result = BoolValue{bool_args[i]}; + return; + } } } if (has_bool_args[0] && has_bool_args[1]) { switch (op_type_) { - case OpType::AND: - *result = CelValue::CreateBool(bool_args[0] && bool_args[1]); - return absl::OkStatus(); - break; - case OpType::OR: - *result = CelValue::CreateBool(bool_args[0] || bool_args[1]); - return absl::OkStatus(); - break; + case OpType::kAnd: + result = BoolValue{bool_args[0] && bool_args[1]}; + return; + case OpType::kOr: + result = BoolValue{bool_args[0] || bool_args[1]}; + return; } } // As opposed to regular function, logical operation treat Unknowns with // higher precedence than error. This is due to the fact that after Unknown - // is resolved to actual value, it may shortcircuit and thus hide the error. + // is resolved to actual value, it may short-circuit and thus hide the + // error. if (frame->enable_unknowns()) { // Check if unknown? - const UnknownSet* unknown_set = - frame->attribute_utility().MergeUnknowns(args, - /*initial_set=*/nullptr); - - if (unknown_set) { - *result = CelValue::CreateUnknownSet(unknown_set); - return absl::OkStatus(); + absl::optional unknown_set = + frame->attribute_utility().MergeUnknowns(args); + if (unknown_set.has_value()) { + result = std::move(*unknown_set); + return; } } - if (args[0].IsError()) { - *result = args[0]; - return absl::OkStatus(); - } else if (args[1].IsError()) { - *result = args[1]; - return absl::OkStatus(); + if (args[0]->Is()) { + result = args[0]; + return; + } else if (args[1]->Is()) { + result = args[1]; + return; } // Fallback. - *result = CreateNoMatchingOverloadError( - frame->arena(), - (op_type_ == OpType::OR) ? builtin::kOr : builtin::kAnd); - return absl::OkStatus(); + result = cel::ErrorValue(CreateNoMatchingOverloadError( + (op_type_ == OpType::kOr) ? cel::builtin::kOr : cel::builtin::kAnd)); } const OpType op_type_; @@ -95,40 +262,226 @@ absl::Status LogicalOpStep::Evaluate(ExecutionFrame* frame) const { // Create Span object that contains input arguments to the function. auto args = frame->value_stack().GetSpan(2); + Value result; + Calculate(frame, args, result); + frame->value_stack().PopAndPush(args.size(), std::move(result)); - CelValue value; + return absl::OkStatus(); +} - auto status = Calculate(frame, args, &value); - if (!status.ok()) { - return status; +std::unique_ptr CreateDirectLogicStep( + std::unique_ptr lhs, + std::unique_ptr rhs, int64_t expr_id, OpType op_type, + bool shortcircuiting) { + if (shortcircuiting) { + return std::make_unique(std::move(lhs), std::move(rhs), + op_type, expr_id); + } else { + return std::make_unique( + std::move(lhs), std::move(rhs), op_type, expr_id); + } +} + +class DirectNotStep : public DirectExpressionStep { + public: + explicit DirectNotStep(std::unique_ptr operand, + int64_t expr_id) + : DirectExpressionStep(expr_id), operand_(std::move(operand)) {} + absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, + AttributeTrail& attribute_trail) const override; + + private: + std::unique_ptr operand_; +}; + +absl::Status DirectNotStep::Evaluate(ExecutionFrameBase& frame, Value& result, + AttributeTrail& attribute_trail) const { + CEL_RETURN_IF_ERROR(operand_->Evaluate(frame, result, attribute_trail)); + + if (frame.unknown_processing_enabled()) { + if (frame.attribute_utility().CheckForUnknownPartial(attribute_trail)) { + result = frame.attribute_utility().CreateUnknownSet( + attribute_trail.attribute()); + return absl::OkStatus(); + } } - frame->value_stack().Pop(args.size()); - frame->value_stack().Push(value); + switch (result.kind()) { + case ValueKind::kBool: + result = BoolValue{!result.GetBool().NativeValue()}; + break; + case ValueKind::kUnknown: + case ValueKind::kError: + // just forward. + break; + default: + result = + cel::ErrorValue(CreateNoMatchingOverloadError(cel::builtin::kNot)); + break; + } - return status; + return absl::OkStatus(); +} + +class IterativeNotStep : public ExpressionStepBase { + public: + explicit IterativeNotStep(int64_t expr_id) : ExpressionStepBase(expr_id) {} + + absl::Status Evaluate(ExecutionFrame* frame) const override; +}; + +absl::Status IterativeNotStep::Evaluate(ExecutionFrame* frame) const { + if (!frame->value_stack().HasEnough(1)) { + return absl::InternalError("Value stack underflow"); + } + const Value& operand = frame->value_stack().Peek(); + + if (frame->unknown_processing_enabled()) { + const AttributeTrail& attribute_trail = + frame->value_stack().PeekAttribute(); + if (frame->attribute_utility().CheckForUnknownPartial(attribute_trail)) { + frame->value_stack().PopAndPush( + frame->attribute_utility().CreateUnknownSet( + attribute_trail.attribute())); + return absl::OkStatus(); + } + } + + switch (operand.kind()) { + case ValueKind::kBool: + frame->value_stack().PopAndPush( + BoolValue{!operand.GetBool().NativeValue()}); + break; + case ValueKind::kUnknown: + case ValueKind::kError: + // just forward. + break; + default: + frame->value_stack().PopAndPush( + cel::ErrorValue(CreateNoMatchingOverloadError(cel::builtin::kNot))); + break; + } + + return absl::OkStatus(); +} + +class DirectNotStrictlyFalseStep : public DirectExpressionStep { + public: + explicit DirectNotStrictlyFalseStep( + std::unique_ptr operand, int64_t expr_id) + : DirectExpressionStep(expr_id), operand_(std::move(operand)) {} + absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, + AttributeTrail& attribute_trail) const override; + + private: + std::unique_ptr operand_; +}; + +absl::Status DirectNotStrictlyFalseStep::Evaluate( + ExecutionFrameBase& frame, Value& result, + AttributeTrail& attribute_trail) const { + CEL_RETURN_IF_ERROR(operand_->Evaluate(frame, result, attribute_trail)); + + switch (result.kind()) { + case ValueKind::kBool: + // just forward. + break; + case ValueKind::kUnknown: + case ValueKind::kError: + result = BoolValue(true); + break; + default: + result = + cel::ErrorValue(CreateNoMatchingOverloadError(cel::builtin::kNot)); + break; + } + + return absl::OkStatus(); +} + +class IterativeNotStrictlyFalseStep : public ExpressionStepBase { + public: + explicit IterativeNotStrictlyFalseStep(int64_t expr_id) + : ExpressionStepBase(expr_id) {} + + absl::Status Evaluate(ExecutionFrame* frame) const override; +}; + +absl::Status IterativeNotStrictlyFalseStep::Evaluate( + ExecutionFrame* frame) const { + if (!frame->value_stack().HasEnough(1)) { + return absl::InternalError("Value stack underflow"); + } + const Value& operand = frame->value_stack().Peek(); + + switch (operand.kind()) { + case ValueKind::kBool: + // just forward. + break; + case ValueKind::kUnknown: + case ValueKind::kError: + frame->value_stack().PopAndPush(BoolValue(true)); + break; + default: + frame->value_stack().PopAndPush( + cel::ErrorValue(CreateNoMatchingOverloadError(cel::builtin::kNot))); + break; + } + + return absl::OkStatus(); } } // namespace +// Factory method for "And" Execution step +std::unique_ptr CreateDirectAndStep( + std::unique_ptr lhs, + std::unique_ptr rhs, int64_t expr_id, + bool shortcircuiting) { + return CreateDirectLogicStep(std::move(lhs), std::move(rhs), expr_id, + OpType::kAnd, shortcircuiting); +} + +// Factory method for "Or" Execution step +std::unique_ptr CreateDirectOrStep( + std::unique_ptr lhs, + std::unique_ptr rhs, int64_t expr_id, + bool shortcircuiting) { + return CreateDirectLogicStep(std::move(lhs), std::move(rhs), expr_id, + OpType::kOr, shortcircuiting); +} // Factory method for "And" Execution step absl::StatusOr> CreateAndStep(int64_t expr_id) { - std::unique_ptr step = - absl::make_unique(LogicalOpStep::OpType::AND, expr_id); - - return std::move(step); + return std::make_unique(OpType::kAnd, expr_id); } // Factory method for "Or" Execution step absl::StatusOr> CreateOrStep(int64_t expr_id) { - std::unique_ptr step = - absl::make_unique(LogicalOpStep::OpType::OR, expr_id); + return std::make_unique(OpType::kOr, expr_id); +} + +// Factory method for recursive logical not "!" Execution step +std::unique_ptr CreateDirectNotStep( + std::unique_ptr operand, int64_t expr_id) { + return std::make_unique(std::move(operand), expr_id); +} + +// Factory method for iterative logical not "!" Execution step +std::unique_ptr CreateNotStep(int64_t expr_id) { + return std::make_unique(expr_id); +} + +// Factory method for recursive logical "@not_strictly_false" Execution step. +std::unique_ptr CreateDirectNotStrictlyFalseStep( + std::unique_ptr operand, int64_t expr_id) { + return std::make_unique(std::move(operand), + expr_id); +} - return std::move(step); +// Factory method for iterative logical "@not_strictly_false" Execution step. +std::unique_ptr CreateNotStrictlyFalseStep(int64_t expr_id) { + return std::make_unique(expr_id); } -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime diff --git a/eval/eval/logic_step.h b/eval/eval/logic_step.h index 397e90061..d75ed3715 100644 --- a/eval/eval/logic_step.h +++ b/eval/eval/logic_step.h @@ -1,16 +1,26 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_LOGIC_STEP_H_ #define THIRD_PARTY_CEL_CPP_EVAL_EVAL_LOGIC_STEP_H_ +#include +#include + +#include "absl/status/statusor.h" +#include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" -#include "eval/public/activation.h" -#include "eval/public/cel_function.h" -#include "eval/public/cel_value.h" -#include "google/api/expr/v1alpha1/syntax.pb.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace google::api::expr::runtime { + +// Factory method for "And" Execution step +std::unique_ptr CreateDirectAndStep( + std::unique_ptr lhs, + std::unique_ptr rhs, int64_t expr_id, + bool shortcircuiting); + +// Factory method for "Or" Execution step +std::unique_ptr CreateDirectOrStep( + std::unique_ptr lhs, + std::unique_ptr rhs, int64_t expr_id, + bool shortcircuiting); // Factory method for "And" Execution step absl::StatusOr> CreateAndStep(int64_t expr_id); @@ -18,9 +28,20 @@ absl::StatusOr> CreateAndStep(int64_t expr_id); // Factory method for "Or" Execution step absl::StatusOr> CreateOrStep(int64_t expr_id); -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +// Factory method for recursive logical not "!" Execution step +std::unique_ptr CreateDirectNotStep( + std::unique_ptr operand, int64_t expr_id); + +// Factory method for iterative logical not "!" Execution step +std::unique_ptr CreateNotStep(int64_t expr_id); + +// Factory method for recursive logical "@not_strictly_false" Execution step. +std::unique_ptr CreateDirectNotStrictlyFalseStep( + std::unique_ptr operand, int64_t expr_id); + +// Factory method for iterative logical "@not_strictly_false" Execution step. +std::unique_ptr CreateNotStrictlyFalseStep(int64_t expr_id); + +} // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_LOGIC_STEP_H_ diff --git a/eval/eval/logic_step_test.cc b/eval/eval/logic_step_test.cc index 9aa95fe3a..17ca8ba0d 100644 --- a/eval/eval/logic_step_test.cc +++ b/eval/eval/logic_step_test.cc @@ -1,70 +1,102 @@ #include "eval/eval/logic_step.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" +#include +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "base/attribute.h" +#include "base/attribute_set.h" +#include "base/type_provider.h" +#include "common/casting.h" +#include "common/expr.h" +#include "common/unknown.h" +#include "common/value.h" +#include "eval/eval/attribute_trail.h" +#include "eval/eval/cel_expression_flat_impl.h" +#include "eval/eval/const_value_step.h" +#include "eval/eval/direct_expression_step.h" +#include "eval/eval/evaluator_core.h" #include "eval/eval/ident_step.h" +#include "eval/public/activation.h" +#include "eval/public/cel_attribute.h" +#include "eval/public/cel_value.h" #include "eval/public/unknown_attribute_set.h" #include "eval/public/unknown_set.h" -#include "base/status_macros.h" - -namespace google { -namespace api { -namespace expr { -namespace runtime { +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "internal/testing_message_factory.h" +#include "runtime/activation.h" +#include "runtime/internal/runtime_env.h" +#include "runtime/internal/runtime_env_testing.h" +#include "runtime/internal/runtime_type_provider.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" + +namespace google::api::expr::runtime { namespace { -using google::api::expr::v1alpha1::Expr; +using ::absl_testing::IsOk; +using ::cel::Attribute; +using ::cel::AttributeSet; +using ::cel::BoolValue; +using ::cel::Cast; +using ::cel::Expr; +using ::cel::InstanceOf; +using ::cel::IntValue; +using ::cel::TypeProvider; +using ::cel::UnknownValue; +using ::cel::Value; +using ::cel::runtime_internal::NewTestingRuntimeEnv; +using ::cel::runtime_internal::RuntimeEnv; +using ::google::protobuf::Arena; +using ::testing::Eq; -using google::protobuf::Arena; -using testing::Eq; class LogicStepTest : public testing::TestWithParam { public: + LogicStepTest() : env_(NewTestingRuntimeEnv()) {} + absl::Status EvaluateLogic(CelValue arg0, CelValue arg1, bool is_or, CelValue* result, bool enable_unknown) { - Expr expr0; - auto ident_expr0 = expr0.mutable_ident_expr(); - ident_expr0->set_name("name0"); - - Expr expr1; - auto ident_expr1 = expr1.mutable_ident_expr(); - ident_expr1->set_name("name1"); - ExecutionPath path; - - auto step_status = CreateIdentStep(ident_expr0, expr0.id()); - if (!step_status.ok()) return step_status.status(); - - path.push_back(std::move(step_status).value()); - - step_status = CreateIdentStep(ident_expr1, expr1.id()); - if (!step_status.ok()) return step_status.status(); - - path.push_back(std::move(step_status).value()); - - step_status = (is_or) ? CreateOrStep(2) : CreateAndStep(2); - if (!step_status.ok()) return step_status.status(); - - path.push_back(std::move(step_status).value()); - - auto dummy_expr = absl::make_unique(); - - CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), 0, {}, - enable_unknown); + CEL_ASSIGN_OR_RETURN(auto step, CreateIdentStep("name0", /*expr_id=*/-1)); + path.push_back(std::move(step)); + + CEL_ASSIGN_OR_RETURN(step, CreateIdentStep("name1", /*expr_id=*/-1)); + path.push_back(std::move(step)); + + CEL_ASSIGN_OR_RETURN(step, (is_or) ? CreateOrStep(2) : CreateAndStep(2)); + path.push_back(std::move(step)); + + auto dummy_expr = std::make_unique(); + cel::RuntimeOptions options; + if (enable_unknown) { + options.unknown_processing = + cel::UnknownProcessingOptions::kAttributeOnly; + } + CelExpressionFlatImpl impl( + env_, + FlatExpression(std::move(path), /*comprehension_slot_count=*/0, + env_->type_registry.GetComposedTypeProvider(), options)); Activation activation; - std::string value("test"); - activation.InsertValue("name0", arg0); activation.InsertValue("name1", arg1); - auto status0 = impl.Evaluate(activation, &arena_); - if (!status0.ok()) return status0.status(); - - *result = status0.value(); + CEL_ASSIGN_OR_RETURN(CelValue value, impl.Evaluate(activation, &arena_)); + *result = value; return absl::OkStatus(); } private: + absl_nonnull std::shared_ptr env_; Arena arena_; }; @@ -73,28 +105,28 @@ TEST_P(LogicStepTest, TestAndLogic) { absl::Status status = EvaluateLogic(CelValue::CreateBool(true), CelValue::CreateBool(true), false, &result, GetParam()); - ASSERT_OK(status); + ASSERT_THAT(status, IsOk()); ASSERT_TRUE(result.IsBool()); ASSERT_TRUE(result.BoolOrDie()); status = EvaluateLogic(CelValue::CreateBool(true), CelValue::CreateBool(false), false, &result, GetParam()); - ASSERT_OK(status); + ASSERT_THAT(status, IsOk()); ASSERT_TRUE(result.IsBool()); ASSERT_FALSE(result.BoolOrDie()); status = EvaluateLogic(CelValue::CreateBool(false), CelValue::CreateBool(true), false, &result, GetParam()); - ASSERT_OK(status); + ASSERT_THAT(status, IsOk()); ASSERT_TRUE(result.IsBool()); ASSERT_FALSE(result.BoolOrDie()); status = EvaluateLogic(CelValue::CreateBool(false), CelValue::CreateBool(false), false, &result, GetParam()); - ASSERT_OK(status); + ASSERT_THAT(status, IsOk()); ASSERT_TRUE(result.IsBool()); ASSERT_FALSE(result.BoolOrDie()); } @@ -104,81 +136,81 @@ TEST_P(LogicStepTest, TestOrLogic) { absl::Status status = EvaluateLogic(CelValue::CreateBool(true), CelValue::CreateBool(true), true, &result, GetParam()); - ASSERT_OK(status); + ASSERT_THAT(status, IsOk()); ASSERT_TRUE(result.IsBool()); ASSERT_TRUE(result.BoolOrDie()); status = EvaluateLogic(CelValue::CreateBool(true), CelValue::CreateBool(false), true, &result, GetParam()); - ASSERT_OK(status); + ASSERT_THAT(status, IsOk()); ASSERT_TRUE(result.IsBool()); ASSERT_TRUE(result.BoolOrDie()); status = EvaluateLogic(CelValue::CreateBool(false), CelValue::CreateBool(true), true, &result, GetParam()); - ASSERT_OK(status); + ASSERT_THAT(status, IsOk()); ASSERT_TRUE(result.IsBool()); ASSERT_TRUE(result.BoolOrDie()); status = EvaluateLogic(CelValue::CreateBool(false), CelValue::CreateBool(false), true, &result, GetParam()); - ASSERT_OK(status); + ASSERT_THAT(status, IsOk()); ASSERT_TRUE(result.IsBool()); ASSERT_FALSE(result.BoolOrDie()); } TEST_P(LogicStepTest, TestAndLogicErrorHandling) { CelValue result; - CelError error; + CelError error = absl::CancelledError(); CelValue error_value = CelValue::CreateError(&error); absl::Status status = EvaluateLogic(error_value, CelValue::CreateBool(true), false, &result, GetParam()); - ASSERT_OK(status); + ASSERT_THAT(status, IsOk()); ASSERT_TRUE(result.IsError()); status = EvaluateLogic(CelValue::CreateBool(true), error_value, false, &result, GetParam()); - ASSERT_OK(status); + ASSERT_THAT(status, IsOk()); ASSERT_TRUE(result.IsError()); status = EvaluateLogic(CelValue::CreateBool(false), error_value, false, &result, GetParam()); - ASSERT_OK(status); + ASSERT_THAT(status, IsOk()); ASSERT_TRUE(result.IsBool()); ASSERT_FALSE(result.BoolOrDie()); status = EvaluateLogic(error_value, CelValue::CreateBool(false), false, &result, GetParam()); - ASSERT_OK(status); + ASSERT_THAT(status, IsOk()); ASSERT_TRUE(result.IsBool()); ASSERT_FALSE(result.BoolOrDie()); } TEST_P(LogicStepTest, TestOrLogicErrorHandling) { CelValue result; - CelError error; + CelError error = absl::CancelledError(); CelValue error_value = CelValue::CreateError(&error); absl::Status status = EvaluateLogic(error_value, CelValue::CreateBool(false), true, &result, GetParam()); - ASSERT_OK(status); + ASSERT_THAT(status, IsOk()); ASSERT_TRUE(result.IsError()); status = EvaluateLogic(CelValue::CreateBool(false), error_value, true, &result, GetParam()); - ASSERT_OK(status); + ASSERT_THAT(status, IsOk()); ASSERT_TRUE(result.IsError()); status = EvaluateLogic(CelValue::CreateBool(true), error_value, true, &result, GetParam()); - ASSERT_OK(status); + ASSERT_THAT(status, IsOk()); ASSERT_TRUE(result.IsBool()); ASSERT_TRUE(result.BoolOrDie()); status = EvaluateLogic(error_value, CelValue::CreateBool(true), true, &result, GetParam()); - ASSERT_OK(status); + ASSERT_THAT(status, IsOk()); ASSERT_TRUE(result.IsBool()); ASSERT_TRUE(result.BoolOrDie()); } @@ -186,134 +218,432 @@ TEST_P(LogicStepTest, TestOrLogicErrorHandling) { TEST_F(LogicStepTest, TestAndLogicUnknownHandling) { CelValue result; UnknownSet unknown_set; - CelError cel_error; + CelError cel_error = absl::CancelledError(); CelValue unknown_value = CelValue::CreateUnknownSet(&unknown_set); CelValue error_value = CelValue::CreateError(&cel_error); absl::Status status = EvaluateLogic(unknown_value, CelValue::CreateBool(true), false, &result, true); - ASSERT_OK(status); + ASSERT_THAT(status, IsOk()); ASSERT_TRUE(result.IsUnknownSet()); status = EvaluateLogic(CelValue::CreateBool(true), unknown_value, false, &result, true); - ASSERT_OK(status); + ASSERT_THAT(status, IsOk()); ASSERT_TRUE(result.IsUnknownSet()); status = EvaluateLogic(CelValue::CreateBool(false), unknown_value, false, &result, true); - ASSERT_OK(status); + ASSERT_THAT(status, IsOk()); ASSERT_TRUE(result.IsBool()); ASSERT_FALSE(result.BoolOrDie()); status = EvaluateLogic(unknown_value, CelValue::CreateBool(false), false, &result, true); - ASSERT_OK(status); + ASSERT_THAT(status, IsOk()); ASSERT_TRUE(result.IsBool()); ASSERT_FALSE(result.BoolOrDie()); status = EvaluateLogic(error_value, unknown_value, false, &result, true); - ASSERT_OK(status); + ASSERT_THAT(status, IsOk()); ASSERT_TRUE(result.IsUnknownSet()); status = EvaluateLogic(unknown_value, error_value, false, &result, true); - ASSERT_OK(status); + ASSERT_THAT(status, IsOk()); ASSERT_TRUE(result.IsUnknownSet()); - Expr expr0; - auto ident_expr0 = expr0.mutable_ident_expr(); - ident_expr0->set_name("name0"); - - Expr expr1; - auto ident_expr1 = expr1.mutable_ident_expr(); - ident_expr1->set_name("name1"); - - CelAttribute attr0(expr0, {}), attr1(expr1, {}); - UnknownAttributeSet unknown_attr_set0({&attr0}); - UnknownAttributeSet unknown_attr_set1({&attr1}); + CelAttribute attr0("name0", {}), attr1("name1", {}); + UnknownAttributeSet unknown_attr_set0({attr0}); + UnknownAttributeSet unknown_attr_set1({attr1}); UnknownSet unknown_set0(unknown_attr_set0); UnknownSet unknown_set1(unknown_attr_set1); - EXPECT_THAT(unknown_attr_set0.attributes().size(), Eq(1)); - EXPECT_THAT(unknown_attr_set1.attributes().size(), Eq(1)); + EXPECT_THAT(unknown_attr_set0.size(), Eq(1)); + EXPECT_THAT(unknown_attr_set1.size(), Eq(1)); status = EvaluateLogic(CelValue::CreateUnknownSet(&unknown_set0), CelValue::CreateUnknownSet(&unknown_set1), false, &result, true); - ASSERT_OK(status); + ASSERT_THAT(status, IsOk()); ASSERT_TRUE(result.IsUnknownSet()); - ASSERT_THAT( - result.UnknownSetOrDie()->unknown_attributes().attributes().size(), - Eq(2)); + ASSERT_THAT(result.UnknownSetOrDie()->unknown_attributes().size(), Eq(2)); } TEST_F(LogicStepTest, TestOrLogicUnknownHandling) { CelValue result; UnknownSet unknown_set; - CelError cel_error; + CelError cel_error = absl::CancelledError(); CelValue unknown_value = CelValue::CreateUnknownSet(&unknown_set); CelValue error_value = CelValue::CreateError(&cel_error); absl::Status status = EvaluateLogic( unknown_value, CelValue::CreateBool(false), true, &result, true); - ASSERT_OK(status); + ASSERT_THAT(status, IsOk()); ASSERT_TRUE(result.IsUnknownSet()); status = EvaluateLogic(CelValue::CreateBool(false), unknown_value, true, &result, true); - ASSERT_OK(status); + ASSERT_THAT(status, IsOk()); ASSERT_TRUE(result.IsUnknownSet()); status = EvaluateLogic(CelValue::CreateBool(true), unknown_value, true, &result, true); - ASSERT_OK(status); + ASSERT_THAT(status, IsOk()); ASSERT_TRUE(result.IsBool()); ASSERT_TRUE(result.BoolOrDie()); status = EvaluateLogic(unknown_value, CelValue::CreateBool(true), true, &result, true); - ASSERT_OK(status); + ASSERT_THAT(status, IsOk()); ASSERT_TRUE(result.IsBool()); ASSERT_TRUE(result.BoolOrDie()); status = EvaluateLogic(unknown_value, error_value, true, &result, true); - ASSERT_OK(status); + ASSERT_THAT(status, IsOk()); ASSERT_TRUE(result.IsUnknownSet()); status = EvaluateLogic(error_value, unknown_value, true, &result, true); - ASSERT_OK(status); + ASSERT_THAT(status, IsOk()); ASSERT_TRUE(result.IsUnknownSet()); - Expr expr0; - auto ident_expr0 = expr0.mutable_ident_expr(); - ident_expr0->set_name("name0"); - - Expr expr1; - auto ident_expr1 = expr1.mutable_ident_expr(); - ident_expr1->set_name("name1"); - - CelAttribute attr0(expr0, {}), attr1(expr1, {}); - UnknownAttributeSet unknown_attr_set0({&attr0}); - UnknownAttributeSet unknown_attr_set1({&attr1}); + CelAttribute attr0("name0", {}), attr1("name1", {}); + UnknownAttributeSet unknown_attr_set0({attr0}); + UnknownAttributeSet unknown_attr_set1({attr1}); UnknownSet unknown_set0(unknown_attr_set0); UnknownSet unknown_set1(unknown_attr_set1); - EXPECT_THAT(unknown_attr_set0.attributes().size(), Eq(1)); - EXPECT_THAT(unknown_attr_set1.attributes().size(), Eq(1)); + EXPECT_THAT(unknown_attr_set0.size(), Eq(1)); + EXPECT_THAT(unknown_attr_set1.size(), Eq(1)); status = EvaluateLogic(CelValue::CreateUnknownSet(&unknown_set0), CelValue::CreateUnknownSet(&unknown_set1), true, &result, true); - ASSERT_OK(status); + ASSERT_THAT(status, IsOk()); ASSERT_TRUE(result.IsUnknownSet()); - ASSERT_THAT( - result.UnknownSetOrDie()->unknown_attributes().attributes().size(), - Eq(2)); + ASSERT_THAT(result.UnknownSetOrDie()->unknown_attributes().size(), Eq(2)); } INSTANTIATE_TEST_SUITE_P(LogicStepTest, LogicStepTest, testing::Bool()); + +enum class BinaryOp { kAnd, kOr }; +enum class UnaryOp { kNot, kNotStrictlyFalse }; + +enum class OpArg { + kTrue, + kFalse, + kUnknown, + kError, + // Arbitrary incorrect type + kInt +}; + +enum class OpResult { + kTrue, + kFalse, + kUnknown, + kError, +}; + +struct BinaryTestCase { + std::string name; + BinaryOp op; + OpArg arg0; + OpArg arg1; + OpResult result; +}; + +UnknownValue MakeUnknownValue(std::string attr) { + std::vector attrs; + attrs.push_back(Attribute(std::move(attr))); + return cel::UnknownValue(cel::Unknown(AttributeSet(attrs))); +} + +std::unique_ptr MakeArgStep(OpArg arg, + absl::string_view name) { + switch (arg) { + case OpArg::kTrue: + return CreateConstValueDirectStep(BoolValue(true)); + case OpArg::kFalse: + return CreateConstValueDirectStep(BoolValue(false)); + case OpArg::kUnknown: + return CreateConstValueDirectStep(MakeUnknownValue(std::string(name))); + case OpArg::kError: + return CreateConstValueDirectStep( + cel::ErrorValue(absl::InternalError(name))); + case OpArg::kInt: + return CreateConstValueDirectStep(IntValue(42)); + } +}; + +class DirectBinaryLogicStepTest + : public testing::TestWithParam> { + public: + DirectBinaryLogicStepTest() = default; + + bool ShortcircuitingEnabled() { return std::get<0>(GetParam()); } + const BinaryTestCase& GetTestCase() { return std::get<1>(GetParam()); } + + protected: + Arena arena_; +}; + +TEST_P(DirectBinaryLogicStepTest, TestCases) { + const BinaryTestCase& test_case = GetTestCase(); + + std::unique_ptr lhs = + MakeArgStep(test_case.arg0, "lhs"); + std::unique_ptr rhs = + MakeArgStep(test_case.arg1, "rhs"); + + std::unique_ptr op = + (test_case.op == BinaryOp::kAnd) + ? CreateDirectAndStep(std::move(lhs), std::move(rhs), -1, + ShortcircuitingEnabled()) + : CreateDirectOrStep(std::move(lhs), std::move(rhs), -1, + ShortcircuitingEnabled()); + + cel::Activation activation; + cel::RuntimeOptions options; + options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; + cel::runtime_internal::RuntimeTypeProvider type_provider( + cel::internal::GetTestingDescriptorPool()); + ExecutionFrameBase frame(activation, options, type_provider, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena_); + + Value value; + AttributeTrail attr; + ASSERT_THAT(op->Evaluate(frame, value, attr), IsOk()); + + switch (test_case.result) { + case OpResult::kTrue: + ASSERT_TRUE(value.IsBool()); + EXPECT_TRUE(value.GetBool().NativeValue()); + break; + case OpResult::kFalse: + ASSERT_TRUE(value.IsBool()); + EXPECT_FALSE(value.GetBool().NativeValue()); + break; + case OpResult::kUnknown: + EXPECT_TRUE(value.IsUnknown()); + break; + case OpResult::kError: + EXPECT_TRUE(value.IsError()); + break; + } +} + +INSTANTIATE_TEST_SUITE_P( + DirectBinaryLogicStepTest, DirectBinaryLogicStepTest, + testing::Combine(testing::Bool(), + testing::ValuesIn>({ + { + "AndFalseFalse", + BinaryOp::kAnd, + OpArg::kFalse, + OpArg::kFalse, + OpResult::kFalse, + }, + { + "AndFalseTrue", + BinaryOp::kAnd, + OpArg::kFalse, + OpArg::kTrue, + OpResult::kFalse, + }, + { + "AndTrueFalse", + BinaryOp::kAnd, + OpArg::kTrue, + OpArg::kFalse, + OpResult::kFalse, + }, + { + "AndTrueTrue", + BinaryOp::kAnd, + OpArg::kTrue, + OpArg::kTrue, + OpResult::kTrue, + }, + + { + "AndTrueError", + BinaryOp::kAnd, + OpArg::kTrue, + OpArg::kError, + OpResult::kError, + }, + { + "AndErrorTrue", + BinaryOp::kAnd, + OpArg::kError, + OpArg::kTrue, + OpResult::kError, + }, + { + "AndFalseError", + BinaryOp::kAnd, + OpArg::kFalse, + OpArg::kError, + OpResult::kFalse, + }, + { + "AndErrorFalse", + BinaryOp::kAnd, + OpArg::kError, + OpArg::kFalse, + OpResult::kFalse, + }, + { + "AndErrorError", + BinaryOp::kAnd, + OpArg::kError, + OpArg::kError, + OpResult::kError, + }, + + { + "AndTrueUnknown", + BinaryOp::kAnd, + OpArg::kTrue, + OpArg::kUnknown, + OpResult::kUnknown, + }, + { + "AndUnknownTrue", + BinaryOp::kAnd, + OpArg::kUnknown, + OpArg::kTrue, + OpResult::kUnknown, + }, + { + "AndFalseUnknown", + BinaryOp::kAnd, + OpArg::kFalse, + OpArg::kUnknown, + OpResult::kFalse, + }, + { + "AndUnknownFalse", + BinaryOp::kAnd, + OpArg::kUnknown, + OpArg::kFalse, + OpResult::kFalse, + }, + { + "AndUnknownUnknown", + BinaryOp::kAnd, + OpArg::kUnknown, + OpArg::kUnknown, + OpResult::kUnknown, + }, + { + "AndUnknownError", + BinaryOp::kAnd, + OpArg::kUnknown, + OpArg::kError, + OpResult::kUnknown, + }, + { + "AndErrorUnknown", + BinaryOp::kAnd, + OpArg::kError, + OpArg::kUnknown, + OpResult::kUnknown, + }, + // Or cases are simplified since the logic generalizes + // and is covered by and cases. + })), + [](const testing::TestParamInfo& info) + -> std::string { + bool shortcircuiting_enabled = std::get<0>(info.param); + absl::string_view name = std::get<1>(info.param).name; + return absl::StrCat( + name, (shortcircuiting_enabled ? "ShortcircuitingEnabled" : "")); + }); + +struct UnaryTestCase { + std::string name; + UnaryOp op; + OpArg arg; + OpResult result; +}; + +class DirectUnaryLogicStepTest : public testing::TestWithParam { + public: + DirectUnaryLogicStepTest() = default; + + const UnaryTestCase& GetTestCase() { return GetParam(); } + + protected: + Arena arena_; +}; + +TEST_P(DirectUnaryLogicStepTest, TestCases) { + const UnaryTestCase& test_case = GetTestCase(); + + std::unique_ptr arg = MakeArgStep(test_case.arg, "arg"); + + std::unique_ptr op = + (test_case.op == UnaryOp::kNot) + ? CreateDirectNotStep(std::move(arg), -1) + : CreateDirectNotStrictlyFalseStep(std::move(arg), -1); + + cel::Activation activation; + cel::RuntimeOptions options; + options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; + cel::runtime_internal::RuntimeTypeProvider type_provider( + cel::internal::GetTestingDescriptorPool()); + ExecutionFrameBase frame(activation, options, type_provider, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena_); + + Value value; + AttributeTrail attr; + ASSERT_THAT(op->Evaluate(frame, value, attr), IsOk()); + + switch (test_case.result) { + case OpResult::kTrue: + ASSERT_TRUE(value.IsBool()); + EXPECT_TRUE(value.GetBool().NativeValue()); + break; + case OpResult::kFalse: + ASSERT_TRUE(value.IsBool()); + EXPECT_FALSE(value.GetBool().NativeValue()); + break; + case OpResult::kUnknown: + EXPECT_TRUE(value.IsUnknown()); + break; + case OpResult::kError: + EXPECT_TRUE(value.IsError()); + break; + } +} + +INSTANTIATE_TEST_SUITE_P( + DirectUnaryLogicStepTest, DirectUnaryLogicStepTest, + testing::ValuesIn>( + {UnaryTestCase{"NotTrue", UnaryOp::kNot, OpArg::kTrue, + OpResult::kFalse}, + UnaryTestCase{"NotError", UnaryOp::kNot, OpArg::kError, + OpResult::kError}, + UnaryTestCase{"NotUnknown", UnaryOp::kNot, OpArg::kUnknown, + OpResult::kUnknown}, + UnaryTestCase{"NotInt", UnaryOp::kNot, OpArg::kInt, OpResult::kError}, + UnaryTestCase{"NotFalse", UnaryOp::kNot, OpArg::kFalse, + OpResult::kTrue}, + UnaryTestCase{"NotStrictlyFalseTrue", UnaryOp::kNotStrictlyFalse, + OpArg::kTrue, OpResult::kTrue}, + UnaryTestCase{"NotStrictlyFalseError", UnaryOp::kNotStrictlyFalse, + OpArg::kError, OpResult::kTrue}, + UnaryTestCase{"NotStrictlyFalseUnknown", UnaryOp::kNotStrictlyFalse, + OpArg::kUnknown, OpResult::kTrue}, + UnaryTestCase{"NotStrictlyFalseInt", UnaryOp::kNotStrictlyFalse, + OpArg::kInt, OpResult::kError}, + UnaryTestCase{"NotStrictlyFalseFalse", UnaryOp::kNotStrictlyFalse, + OpArg::kFalse, OpResult::kFalse}}), + [](const testing::TestParamInfo& info) + -> std::string { return info.param.name; }); + } // namespace -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime diff --git a/eval/eval/optional_or_step.cc b/eval/eval/optional_or_step.cc new file mode 100644 index 000000000..1c52d91b6 --- /dev/null +++ b/eval/eval/optional_or_step.cc @@ -0,0 +1,305 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/eval/optional_or_step.h" + +#include +#include +#include + +#include "absl/base/optimization.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "common/casting.h" +#include "common/value.h" +#include "eval/eval/attribute_trail.h" +#include "eval/eval/direct_expression_step.h" +#include "eval/eval/evaluator_core.h" +#include "eval/eval/expression_step_base.h" +#include "eval/eval/jump_step.h" +#include "internal/status_macros.h" +#include "runtime/internal/errors.h" + +namespace google::api::expr::runtime { + +namespace { + +using ::cel::As; +using ::cel::ErrorValue; +using ::cel::InstanceOf; +using ::cel::OptionalValue; +using ::cel::UnknownValue; +using ::cel::Value; +using ::cel::runtime_internal::CreateNoMatchingOverloadError; + +enum class OptionalOrKind { kOrOptional, kOrValue }; + +ErrorValue MakeNoOverloadError(OptionalOrKind kind) { + switch (kind) { + case OptionalOrKind::kOrOptional: + return ErrorValue(CreateNoMatchingOverloadError("or")); + case OptionalOrKind::kOrValue: + return ErrorValue(CreateNoMatchingOverloadError("orValue")); + } + + ABSL_UNREACHABLE(); +} + +// Implements short-circuiting for optional.or. +// Expected layout if short-circuiting enabled: +// +// +--------+-----------------------+-------------------------------+ +// | idx | Step | Stack After | +// +--------+-----------------------+-------------------------------+ +// | 1 | | OptionalValue | +// +--------+-----------------------+-------------------------------+ +// | 2 | Jump to 5 if present | OptionalValue | +// +--------+-----------------------+-------------------------------+ +// | 3 | | OptionalValue, OptionalValue | +// +--------+-----------------------+-------------------------------+ +// | 4 | optional.or | OptionalValue | +// +--------+-----------------------+-------------------------------+ +// | 5 | | ... | +// +--------------------------------+-------------------------------+ +// +// If implementing the orValue variant, the jump step handles unwrapping ( +// getting the result of optional.value()) +class OptionalHasValueJumpStep final : public JumpStepBase { + public: + OptionalHasValueJumpStep(int64_t expr_id, OptionalOrKind kind) + : JumpStepBase({}, expr_id), kind_(kind) {} + + absl::Status Evaluate(ExecutionFrame* frame) const override { + if (!frame->value_stack().HasEnough(1)) { + return absl::Status(absl::StatusCode::kInternal, "Value stack underflow"); + } + const auto& value = frame->value_stack().Peek(); + auto optional_value = As(value); + // We jump if the receiver is `optional_type` which has a value or the + // receiver is an error/unknown. Unlike `_||_` we are not commutative. If + // we run into an error/unknown, we skip the `else` branch. + const bool should_jump = + (optional_value.has_value() && optional_value->HasValue()) || + (!optional_value.has_value() && (cel::InstanceOf(value) || + cel::InstanceOf(value))); + if (should_jump) { + if (kind_ == OptionalOrKind::kOrValue && optional_value.has_value()) { + frame->value_stack().PopAndPush(optional_value->Value()); + } + return Jump(frame); + } + return absl::OkStatus(); + } + + private: + const OptionalOrKind kind_; +}; + +class OptionalOrStep : public ExpressionStepBase { + public: + explicit OptionalOrStep(int64_t expr_id, OptionalOrKind kind) + : ExpressionStepBase(expr_id), kind_(kind) {} + + absl::Status Evaluate(ExecutionFrame* frame) const override; + + private: + const OptionalOrKind kind_; +}; + +// Shared implementation for optional or. +// +// If return value is Ok, the result is assigned to the result reference +// argument. +absl::Status EvalOptionalOr(OptionalOrKind kind, const Value& lhs, + const Value& rhs, const AttributeTrail& lhs_attr, + const AttributeTrail& rhs_attr, Value& result, + AttributeTrail& result_attr) { + if (InstanceOf(lhs) || InstanceOf(lhs)) { + result = lhs; + result_attr = lhs_attr; + return absl::OkStatus(); + } + + auto lhs_optional_value = As(lhs); + if (!lhs_optional_value.has_value()) { + result = MakeNoOverloadError(kind); + result_attr = AttributeTrail(); + return absl::OkStatus(); + } + + if (lhs_optional_value->HasValue()) { + if (kind == OptionalOrKind::kOrValue) { + result = lhs_optional_value->Value(); + } else { + result = lhs; + } + result_attr = lhs_attr; + return absl::OkStatus(); + } + + if (kind == OptionalOrKind::kOrOptional && !InstanceOf(rhs) && + !InstanceOf(rhs) && !InstanceOf(rhs)) { + result = MakeNoOverloadError(kind); + result_attr = AttributeTrail(); + return absl::OkStatus(); + } + + result = rhs; + result_attr = rhs_attr; + return absl::OkStatus(); +} + +absl::Status OptionalOrStep::Evaluate(ExecutionFrame* frame) const { + if (!frame->value_stack().HasEnough(2)) { + return absl::InternalError("Value stack underflow"); + } + + absl::Span args = frame->value_stack().GetSpan(2); + absl::Span args_attr = + frame->value_stack().GetAttributeSpan(2); + + Value result; + AttributeTrail result_attr; + CEL_RETURN_IF_ERROR(EvalOptionalOr(kind_, args[0], args[1], args_attr[0], + args_attr[1], result, result_attr)); + + frame->value_stack().PopAndPush(2, std::move(result), std::move(result_attr)); + return absl::OkStatus(); +} + +class ExhaustiveDirectOptionalOrStep : public DirectExpressionStep { + public: + ExhaustiveDirectOptionalOrStep( + int64_t expr_id, std::unique_ptr optional, + std::unique_ptr alternative, OptionalOrKind kind) + + : DirectExpressionStep(expr_id), + kind_(kind), + optional_(std::move(optional)), + alternative_(std::move(alternative)) {} + + absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, + AttributeTrail& attribute) const override; + + private: + OptionalOrKind kind_; + std::unique_ptr optional_; + std::unique_ptr alternative_; +}; + +absl::Status ExhaustiveDirectOptionalOrStep::Evaluate( + ExecutionFrameBase& frame, Value& result, AttributeTrail& attribute) const { + CEL_RETURN_IF_ERROR(optional_->Evaluate(frame, result, attribute)); + Value rhs; + AttributeTrail rhs_attr; + CEL_RETURN_IF_ERROR(alternative_->Evaluate(frame, rhs, rhs_attr)); + CEL_RETURN_IF_ERROR(EvalOptionalOr(kind_, result, rhs, attribute, rhs_attr, + result, attribute)); + return absl::OkStatus(); +} + +class DirectOptionalOrStep : public DirectExpressionStep { + public: + DirectOptionalOrStep(int64_t expr_id, + std::unique_ptr optional, + std::unique_ptr alternative, + OptionalOrKind kind) + + : DirectExpressionStep(expr_id), + kind_(kind), + optional_(std::move(optional)), + alternative_(std::move(alternative)) {} + + absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, + AttributeTrail& attribute) const override; + + private: + OptionalOrKind kind_; + std::unique_ptr optional_; + std::unique_ptr alternative_; +}; + +absl::Status DirectOptionalOrStep::Evaluate(ExecutionFrameBase& frame, + Value& result, + AttributeTrail& attribute) const { + CEL_RETURN_IF_ERROR(optional_->Evaluate(frame, result, attribute)); + + if (InstanceOf(result) || InstanceOf(result)) { + // Forward the lhs error instead of attempting to evaluate the alternative + // (unlike CEL's commutative logic operators). + return absl::OkStatus(); + } + + auto optional_value = As(static_cast(result)); + if (!optional_value.has_value()) { + result = MakeNoOverloadError(kind_); + return absl::OkStatus(); + } + + if (optional_value->HasValue()) { + if (kind_ == OptionalOrKind::kOrValue) { + result = optional_value->Value(); + } + return absl::OkStatus(); + } + + CEL_RETURN_IF_ERROR(alternative_->Evaluate(frame, result, attribute)); + + // If optional.or check that rhs is an optional. + // + // Otherwise, we don't know what type to expect so can't check anything. + if (kind_ == OptionalOrKind::kOrOptional) { + if (!InstanceOf(result) && !InstanceOf(result) && + !InstanceOf(result)) { + result = MakeNoOverloadError(kind_); + } + } + + return absl::OkStatus(); +} + +} // namespace + +std::unique_ptr CreateOptionalHasValueJumpStep(bool or_value, + int64_t expr_id) { + return std::make_unique( + expr_id, + or_value ? OptionalOrKind::kOrValue : OptionalOrKind::kOrOptional); +} + +std::unique_ptr CreateOptionalOrStep(bool is_or_value, + int64_t expr_id) { + return std::make_unique( + expr_id, + is_or_value ? OptionalOrKind::kOrValue : OptionalOrKind::kOrOptional); +} + +std::unique_ptr CreateDirectOptionalOrStep( + int64_t expr_id, std::unique_ptr optional, + std::unique_ptr alternative, bool is_or_value, + bool short_circuiting) { + auto kind = + is_or_value ? OptionalOrKind::kOrValue : OptionalOrKind::kOrOptional; + if (short_circuiting) { + return std::make_unique(expr_id, std::move(optional), + std::move(alternative), kind); + } else { + return std::make_unique( + expr_id, std::move(optional), std::move(alternative), kind); + } +} + +} // namespace google::api::expr::runtime diff --git a/eval/eval/optional_or_step.h b/eval/eval/optional_or_step.h new file mode 100644 index 000000000..59977c857 --- /dev/null +++ b/eval/eval/optional_or_step.h @@ -0,0 +1,51 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_OPTIONAL_OR_STEP_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_EVAL_OPTIONAL_OR_STEP_H_ + +#include +#include + +#include "eval/eval/direct_expression_step.h" +#include "eval/eval/evaluator_core.h" +#include "eval/eval/jump_step.h" + +namespace google::api::expr::runtime { + +// Factory method for OptionalHasValueJump step, used to implement +// short-circuiting optional.or and optional.orValue. +// +// Requires that the top of the stack is an optional. If `optional.hasValue` is +// true, performs a jump. If `or_value` is true and we are jumping, +// `optional.value` is called and the result replaces the optional at the top of +// the stack. +std::unique_ptr CreateOptionalHasValueJumpStep(bool or_value, + int64_t expr_id); + +// Factory method for OptionalOr step, used to implement optional.or and +// optional.orValue. +std::unique_ptr CreateOptionalOrStep(bool is_or_value, + int64_t expr_id); + +// Creates a step implementing the short-circuiting optional.or or +// optional.orValue step. +std::unique_ptr CreateDirectOptionalOrStep( + int64_t expr_id, std::unique_ptr optional, + std::unique_ptr alternative, bool is_or_value, + bool short_circuiting); + +} // namespace google::api::expr::runtime + +#endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_OPTIONAL_OR_STEP_H_ diff --git a/eval/eval/optional_or_step_test.cc b/eval/eval/optional_or_step_test.cc new file mode 100644 index 000000000..14f1c3bd9 --- /dev/null +++ b/eval/eval/optional_or_step_test.cc @@ -0,0 +1,382 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/eval/optional_or_step.h" + +#include + +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "common/casting.h" +#include "common/value.h" +#include "common/value_kind.h" +#include "common/value_testing.h" +#include "eval/eval/attribute_trail.h" +#include "eval/eval/const_value_step.h" +#include "eval/eval/direct_expression_step.h" +#include "eval/eval/evaluator_core.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "internal/testing_message_factory.h" +#include "runtime/activation.h" +#include "runtime/internal/errors.h" +#include "runtime/internal/runtime_type_provider.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" + +namespace google::api::expr::runtime { +namespace { + +using ::absl_testing::StatusIs; +using ::cel::Activation; +using ::cel::As; +using ::cel::ErrorValue; +using ::cel::InstanceOf; +using ::cel::IntValue; +using ::cel::OptionalValue; +using ::cel::RuntimeOptions; +using ::cel::UnknownValue; +using ::cel::Value; +using ::cel::ValueKind; +using ::cel::test::ErrorValueIs; +using ::cel::test::IntValueIs; +using ::cel::test::OptionalValueIs; +using ::cel::test::ValueKindIs; +using ::testing::HasSubstr; +using ::testing::NiceMock; + +class MockDirectStep : public DirectExpressionStep { + public: + MOCK_METHOD(absl::Status, Evaluate, + (ExecutionFrameBase & frame, Value& result, + AttributeTrail& scratch), + (const, override)); +}; + +std::unique_ptr MockNeverCalledDirectStep() { + auto* mock = new NiceMock(); + EXPECT_CALL(*mock, Evaluate).Times(0); + return absl::WrapUnique(mock); +} + +std::unique_ptr MockExpectCallDirectStep() { + auto* mock = new NiceMock(); + EXPECT_CALL(*mock, Evaluate) + .Times(1) + .WillRepeatedly( + [](ExecutionFrameBase& frame, Value& result, AttributeTrail& attr) { + result = ErrorValue(absl::InternalError("expected to be unused")); + return absl::OkStatus(); + }); + return absl::WrapUnique(mock); +} + +class OptionalOrTest : public testing::Test { + public: + OptionalOrTest() + : type_provider_(cel::internal::GetTestingDescriptorPool()) {} + + protected: + google::protobuf::Arena arena_; + cel::runtime_internal::RuntimeTypeProvider type_provider_; + Activation empty_activation_; +}; + +TEST_F(OptionalOrTest, OptionalOrLeftPresentShortcutRight) { + RuntimeOptions options; + ExecutionFrameBase frame(empty_activation_, options, type_provider_, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena_); + + std::unique_ptr step = CreateDirectOptionalOrStep( + /*expr_id=*/-1, + CreateConstValueDirectStep(OptionalValue::Of(IntValue(42), &arena_)), + MockNeverCalledDirectStep(), + /*is_or_value=*/false, + /*short_circuiting=*/true); + + Value result; + AttributeTrail scratch; + + ASSERT_OK(step->Evaluate(frame, result, scratch)); + + EXPECT_THAT(result, OptionalValueIs(IntValueIs(42))); +} + +TEST_F(OptionalOrTest, OptionalOrLeftErrorShortcutsRight) { + RuntimeOptions options; + ExecutionFrameBase frame(empty_activation_, options, type_provider_, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena_); + + std::unique_ptr step = CreateDirectOptionalOrStep( + /*expr_id=*/-1, + CreateConstValueDirectStep(ErrorValue(absl::InternalError("error"))), + MockNeverCalledDirectStep(), + /*is_or_value=*/false, + /*short_circuiting=*/true); + + Value result; + AttributeTrail scratch; + + ASSERT_OK(step->Evaluate(frame, result, scratch)); + EXPECT_THAT(result, ValueKindIs(ValueKind::kError)); +} + +TEST_F(OptionalOrTest, OptionalOrLeftErrorExhaustiveRight) { + RuntimeOptions options; + ExecutionFrameBase frame(empty_activation_, options, type_provider_, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena_); + + std::unique_ptr step = CreateDirectOptionalOrStep( + /*expr_id=*/-1, + CreateConstValueDirectStep(ErrorValue(absl::InternalError("error"))), + MockExpectCallDirectStep(), + /*is_or_value=*/false, + /*short_circuiting=*/false); + + Value result; + AttributeTrail scratch; + + ASSERT_OK(step->Evaluate(frame, result, scratch)); + EXPECT_THAT(result, ValueKindIs(ValueKind::kError)); +} + +TEST_F(OptionalOrTest, OptionalOrLeftUnknownShortcutsRight) { + RuntimeOptions options; + ExecutionFrameBase frame(empty_activation_, options, type_provider_, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena_); + + std::unique_ptr step = CreateDirectOptionalOrStep( + /*expr_id=*/-1, CreateConstValueDirectStep(UnknownValue()), + MockNeverCalledDirectStep(), + /*is_or_value=*/false, + /*short_circuiting=*/true); + + Value result; + AttributeTrail scratch; + + ASSERT_OK(step->Evaluate(frame, result, scratch)); + EXPECT_THAT(result, ValueKindIs(ValueKind::kUnknown)); +} + +TEST_F(OptionalOrTest, OptionalOrLeftUnknownExhaustiveRight) { + RuntimeOptions options; + ExecutionFrameBase frame(empty_activation_, options, type_provider_, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena_); + + std::unique_ptr step = CreateDirectOptionalOrStep( + /*expr_id=*/-1, CreateConstValueDirectStep(UnknownValue()), + MockExpectCallDirectStep(), + /*is_or_value=*/false, + /*short_circuiting=*/false); + + Value result; + AttributeTrail scratch; + + ASSERT_OK(step->Evaluate(frame, result, scratch)); + EXPECT_THAT(result, ValueKindIs(ValueKind::kUnknown)); +} + +TEST_F(OptionalOrTest, OptionalOrLeftAbsentReturnRight) { + RuntimeOptions options; + ExecutionFrameBase frame(empty_activation_, options, type_provider_, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena_); + + std::unique_ptr step = CreateDirectOptionalOrStep( + /*expr_id=*/-1, CreateConstValueDirectStep(OptionalValue::None()), + CreateConstValueDirectStep(OptionalValue::Of(IntValue(42), &arena_)), + /*is_or_value=*/false, + /*short_circuiting=*/true); + + Value result; + AttributeTrail scratch; + + ASSERT_OK(step->Evaluate(frame, result, scratch)); + + EXPECT_THAT(result, OptionalValueIs(IntValueIs(42))); +} + +TEST_F(OptionalOrTest, OptionalOrLeftWrongType) { + RuntimeOptions options; + ExecutionFrameBase frame(empty_activation_, options, type_provider_, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena_); + + std::unique_ptr step = CreateDirectOptionalOrStep( + /*expr_id=*/-1, CreateConstValueDirectStep(IntValue(42)), + MockNeverCalledDirectStep(), + /*is_or_value=*/false, + /*short_circuiting=*/true); + + Value result; + AttributeTrail scratch; + + ASSERT_OK(step->Evaluate(frame, result, scratch)); + + EXPECT_THAT(result, + ErrorValueIs(StatusIs( + absl::StatusCode::kUnknown, + HasSubstr(cel::runtime_internal::kErrNoMatchingOverload)))); +} + +TEST_F(OptionalOrTest, OptionalOrRightWrongType) { + RuntimeOptions options; + ExecutionFrameBase frame(empty_activation_, options, type_provider_, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena_); + + std::unique_ptr step = CreateDirectOptionalOrStep( + /*expr_id=*/-1, CreateConstValueDirectStep(OptionalValue::None()), + CreateConstValueDirectStep(IntValue(42)), + /*is_or_value=*/false, + /*short_circuiting=*/true); + + Value result; + AttributeTrail scratch; + + ASSERT_OK(step->Evaluate(frame, result, scratch)); + + EXPECT_THAT(result, + ErrorValueIs(StatusIs( + absl::StatusCode::kUnknown, + HasSubstr(cel::runtime_internal::kErrNoMatchingOverload)))); +} + +TEST_F(OptionalOrTest, OptionalOrValueLeftPresentShortcutRight) { + RuntimeOptions options; + ExecutionFrameBase frame(empty_activation_, options, type_provider_, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena_); + + std::unique_ptr step = CreateDirectOptionalOrStep( + /*expr_id=*/-1, + CreateConstValueDirectStep(OptionalValue::Of(IntValue(42), &arena_)), + MockNeverCalledDirectStep(), + /*is_or_value=*/true, + /*short_circuiting=*/true); + + Value result; + AttributeTrail scratch; + + ASSERT_OK(step->Evaluate(frame, result, scratch)); + + EXPECT_THAT(result, IntValueIs(42)); +} + +TEST_F(OptionalOrTest, OptionalOrValueLeftPresentExhaustiveRight) { + RuntimeOptions options; + ExecutionFrameBase frame(empty_activation_, options, type_provider_, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena_); + + std::unique_ptr step = CreateDirectOptionalOrStep( + /*expr_id=*/-1, + CreateConstValueDirectStep(OptionalValue::Of(IntValue(42), &arena_)), + MockExpectCallDirectStep(), + /*is_or_value=*/true, + /*short_circuiting=*/false); + + Value result; + AttributeTrail scratch; + + ASSERT_OK(step->Evaluate(frame, result, scratch)); + + EXPECT_THAT(result, IntValueIs(42)); +} + +TEST_F(OptionalOrTest, OptionalOrValueLeftErrorShortcutsRight) { + RuntimeOptions options; + ExecutionFrameBase frame(empty_activation_, options, type_provider_, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena_); + + std::unique_ptr step = CreateDirectOptionalOrStep( + /*expr_id=*/-1, + CreateConstValueDirectStep(ErrorValue(absl::InternalError("error"))), + MockNeverCalledDirectStep(), + /*is_or_value=*/true, + /*short_circuiting=*/true); + + Value result; + AttributeTrail scratch; + + ASSERT_OK(step->Evaluate(frame, result, scratch)); + EXPECT_THAT(result, ValueKindIs(ValueKind::kError)); +} + +TEST_F(OptionalOrTest, OptionalOrValueLeftUnknownShortcutsRight) { + RuntimeOptions options; + ExecutionFrameBase frame(empty_activation_, options, type_provider_, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena_); + + std::unique_ptr step = CreateDirectOptionalOrStep( + /*expr_id=*/-1, CreateConstValueDirectStep(UnknownValue()), + MockNeverCalledDirectStep(), true, true); + + Value result; + AttributeTrail scratch; + + ASSERT_OK(step->Evaluate(frame, result, scratch)); + EXPECT_THAT(result, ValueKindIs(ValueKind::kUnknown)); +} + +TEST_F(OptionalOrTest, OptionalOrValueLeftAbsentReturnRight) { + RuntimeOptions options; + ExecutionFrameBase frame(empty_activation_, options, type_provider_, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena_); + + std::unique_ptr step = CreateDirectOptionalOrStep( + /*expr_id=*/-1, CreateConstValueDirectStep(OptionalValue::None()), + CreateConstValueDirectStep(IntValue(42)), + /*is_or_value=*/true, + /*short_circuiting=*/true); + + Value result; + AttributeTrail scratch; + + ASSERT_OK(step->Evaluate(frame, result, scratch)); + + EXPECT_THAT(result, IntValueIs(42)); +} + +TEST_F(OptionalOrTest, OptionalOrValueLeftWrongType) { + RuntimeOptions options; + ExecutionFrameBase frame(empty_activation_, options, type_provider_, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena_); + + std::unique_ptr step = CreateDirectOptionalOrStep( + /*expr_id=*/-1, CreateConstValueDirectStep(IntValue(42)), + MockNeverCalledDirectStep(), true, true); + + Value result; + AttributeTrail scratch; + + ASSERT_OK(step->Evaluate(frame, result, scratch)); + + EXPECT_THAT(result, + ErrorValueIs(StatusIs( + absl::StatusCode::kUnknown, + HasSubstr(cel::runtime_internal::kErrNoMatchingOverload)))); +} + +} // namespace +} // namespace google::api::expr::runtime diff --git a/eval/eval/regex_match_step.cc b/eval/eval/regex_match_step.cc new file mode 100644 index 000000000..2a06de1b8 --- /dev/null +++ b/eval/eval/regex_match_step.cc @@ -0,0 +1,135 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/eval/regex_match_step.h" + +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/cord.h" +#include "absl/strings/string_view.h" +#include "common/value.h" +#include "eval/eval/attribute_trail.h" +#include "eval/eval/direct_expression_step.h" +#include "eval/eval/evaluator_core.h" +#include "eval/eval/expression_step_base.h" +#include "internal/status_macros.h" +#include "re2/re2.h" + +namespace google::api::expr::runtime { + +namespace { + +using ::cel::BoolValue; +using ::cel::StringValue; +using ::cel::Value; + +inline constexpr int kNumRegexMatchArguments = 1; +inline constexpr size_t kRegexMatchStepSubject = 0; + +struct MatchesVisitor final { + const RE2& re; + + bool operator()(const absl::Cord& value) const { + if (auto flat = value.TryFlat(); flat.has_value()) { + return RE2::PartialMatch(*flat, re); + } + return RE2::PartialMatch(static_cast(value), re); + } + + bool operator()(absl::string_view value) const { + return RE2::PartialMatch(value, re); + } +}; + +class RegexMatchStep final : public ExpressionStepBase { + public: + RegexMatchStep(int64_t expr_id, std::shared_ptr re2) + : ExpressionStepBase(expr_id, /*comes_from_ast=*/true), + re2_(std::move(re2)) {} + + absl::Status Evaluate(ExecutionFrame* frame) const override { + if (!frame->value_stack().HasEnough(kNumRegexMatchArguments)) { + return absl::Status(absl::StatusCode::kInternal, + "Insufficient arguments supplied for regular " + "expression match"); + } + auto input_args = frame->value_stack().GetSpan(kNumRegexMatchArguments); + const auto& subject = input_args[kRegexMatchStepSubject]; + if (!subject->Is()) { + return absl::Status(absl::StatusCode::kInternal, + "First argument for regular " + "expression match must be a string"); + } + bool match = subject.GetString().NativeValue(MatchesVisitor{*re2_}); + frame->value_stack().Pop(kNumRegexMatchArguments); + frame->value_stack().Push(cel::BoolValue(match)); + return absl::OkStatus(); + } + + private: + const std::shared_ptr re2_; +}; + +class RegexMatchDirectStep final : public DirectExpressionStep { + public: + RegexMatchDirectStep(int64_t expr_id, + std::unique_ptr subject, + std::shared_ptr re2) + : DirectExpressionStep(expr_id), + subject_(std::move(subject)), + re2_(std::move(re2)) {} + + absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, + AttributeTrail& attribute) const override { + AttributeTrail subject_attr; + CEL_RETURN_IF_ERROR(subject_->Evaluate(frame, result, subject_attr)); + if (result.IsError() || result.IsUnknown()) { + return absl::OkStatus(); + } + + if (!result.IsString()) { + return absl::Status(absl::StatusCode::kInternal, + "First argument for regular " + "expression match must be a string"); + } + bool match = result.GetString().NativeValue(MatchesVisitor{*re2_}); + result = BoolValue(match); + return absl::OkStatus(); + } + + private: + std::unique_ptr subject_; + const std::shared_ptr re2_; +}; + +} // namespace + +std::unique_ptr CreateDirectRegexMatchStep( + int64_t expr_id, std::unique_ptr subject, + std::shared_ptr re2) { + return std::make_unique(expr_id, std::move(subject), + std::move(re2)); +} + +absl::StatusOr> CreateRegexMatchStep( + std::shared_ptr re2, int64_t expr_id) { + return std::make_unique(expr_id, std::move(re2)); +} + +} // namespace google::api::expr::runtime diff --git a/eval/eval/regex_match_step.h b/eval/eval/regex_match_step.h new file mode 100644 index 000000000..1d8a09118 --- /dev/null +++ b/eval/eval/regex_match_step.h @@ -0,0 +1,37 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_REGEX_MATCH_STEP_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_EVAL_REGEX_MATCH_STEP_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "eval/eval/direct_expression_step.h" +#include "eval/eval/evaluator_core.h" +#include "re2/re2.h" + +namespace google::api::expr::runtime { + +std::unique_ptr CreateDirectRegexMatchStep( + int64_t expr_id, std::unique_ptr subject, + std::shared_ptr re2); + +absl::StatusOr> CreateRegexMatchStep( + std::shared_ptr re2, int64_t expr_id); + +} + +#endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_REGEX_MATCH_STEP_H_ diff --git a/eval/eval/regex_match_step_test.cc b/eval/eval/regex_match_step_test.cc new file mode 100644 index 000000000..53b955b25 --- /dev/null +++ b/eval/eval/regex_match_step_test.cc @@ -0,0 +1,101 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/eval/regex_match_step.h" + +#include "cel/expr/checked.pb.h" +#include "cel/expr/syntax.pb.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "eval/public/activation.h" +#include "eval/public/builtin_func_registrar.h" +#include "eval/public/cel_expr_builder_factory.h" +#include "eval/public/cel_options.h" +#include "internal/testing.h" +#include "parser/parser.h" +#include "google/protobuf/arena.h" + +namespace google::api::expr::runtime { +namespace { + +using ::absl_testing::StatusIs; +using cel::expr::CheckedExpr; +using cel::expr::Reference; +using ::testing::Eq; +using ::testing::HasSubstr; + +Reference MakeMatchesStringOverload() { + Reference reference; + reference.add_overload_id("matches_string"); + return reference; +} + +TEST(RegexMatchStep, Precompiled) { + google::protobuf::Arena arena; + Activation activation; + ASSERT_OK_AND_ASSIGN(auto parsed_expr, parser::Parse("foo.matches('hello')")); + CheckedExpr checked_expr; + *checked_expr.mutable_expr() = parsed_expr.expr(); + *checked_expr.mutable_source_info() = parsed_expr.source_info(); + checked_expr.mutable_reference_map()->insert( + {checked_expr.expr().id(), MakeMatchesStringOverload()}); + InterpreterOptions options; + options.enable_regex_precompilation = true; + auto expr_builder = CreateCelExpressionBuilder(options); + ASSERT_OK(RegisterBuiltinFunctions(expr_builder->GetRegistry(), options)); + ASSERT_OK_AND_ASSIGN(auto expr, + expr_builder->CreateExpression(&checked_expr)); + activation.InsertValue("foo", CelValue::CreateStringView("hello world!")); + ASSERT_OK_AND_ASSIGN(auto result, expr->Evaluate(activation, &arena)); + EXPECT_TRUE(result.IsBool()); + EXPECT_TRUE(result.BoolOrDie()); +} + +TEST(RegexMatchStep, PrecompiledInvalidRegex) { + Activation activation; + ASSERT_OK_AND_ASSIGN(auto parsed_expr, parser::Parse("foo.matches('(')")); + CheckedExpr checked_expr; + *checked_expr.mutable_expr() = parsed_expr.expr(); + *checked_expr.mutable_source_info() = parsed_expr.source_info(); + checked_expr.mutable_reference_map()->insert( + {checked_expr.expr().id(), MakeMatchesStringOverload()}); + InterpreterOptions options; + options.enable_regex_precompilation = true; + auto expr_builder = CreateCelExpressionBuilder(options); + ASSERT_OK(RegisterBuiltinFunctions(expr_builder->GetRegistry(), options)); + EXPECT_THAT(expr_builder->CreateExpression(&checked_expr), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("invalid regular expression"))); +} + +TEST(RegexMatchStep, PrecompiledInvalidProgramTooLarge) { + Activation activation; + ASSERT_OK_AND_ASSIGN(auto parsed_expr, parser::Parse("foo.matches('hello')")); + CheckedExpr checked_expr; + *checked_expr.mutable_expr() = parsed_expr.expr(); + *checked_expr.mutable_source_info() = parsed_expr.source_info(); + checked_expr.mutable_reference_map()->insert( + {checked_expr.expr().id(), MakeMatchesStringOverload()}); + InterpreterOptions options; + options.regex_max_program_size = 1; + options.enable_regex_precompilation = true; + auto expr_builder = CreateCelExpressionBuilder(options); + ASSERT_OK(RegisterBuiltinFunctions(expr_builder->GetRegistry(), options)); + EXPECT_THAT(expr_builder->CreateExpression(&checked_expr), + StatusIs(absl::StatusCode::kInvalidArgument, + Eq("regular expression exceeds max allowed size"))); +} + +} // namespace +} // namespace google::api::expr::runtime diff --git a/eval/eval/select_step.cc b/eval/eval/select_step.cc index 23d3e6432..b815f5d87 100644 --- a/eval/eval/select_step.cc +++ b/eval/eval/select_step.cc @@ -1,220 +1,511 @@ #include "eval/eval/select_step.h" +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "common/expr.h" +#include "common/value.h" +#include "common/value_kind.h" +#include "eval/eval/attribute_trail.h" +#include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" #include "eval/eval/expression_step_base.h" -#include "eval/public/cel_value.h" -#include "eval/public/containers/field_access.h" -#include "eval/public/containers/field_backed_list_impl.h" -#include "eval/public/containers/field_backed_map_impl.h" +#include "internal/status_macros.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace google::api::expr::runtime { namespace { -using google::protobuf::Descriptor; -using google::protobuf::FieldDescriptor; -using google::protobuf::Reflection; +using ::cel::BoolValue; +using ::cel::ErrorValue; +using ::cel::MapValue; +using ::cel::NullValue; +using ::cel::OptionalValue; +using ::cel::ProtoWrapperTypeOptions; +using ::cel::StringValue; +using ::cel::StructValue; +using ::cel::Value; +using ::cel::ValueKind; + +// Common error for cases where evaluation attempts to perform select operations +// on an unsupported type. +// +// This should not happen under normal usage of the evaluator, but useful for +// troubleshooting broken invariants. +absl::Status InvalidSelectTargetError() { + return absl::Status(absl::StatusCode::kInvalidArgument, + "Applying SELECT to non-message type"); +} + +absl::optional CheckForMarkedAttributes(const AttributeTrail& trail, + ExecutionFrameBase& frame) { + if (frame.unknown_processing_enabled() && + frame.attribute_utility().CheckForUnknownExact(trail)) { + return frame.attribute_utility().CreateUnknownSet(trail.attribute()); + } + + if (frame.missing_attribute_errors_enabled() && + frame.attribute_utility().CheckForMissingAttribute(trail)) { + auto result = frame.attribute_utility().CreateMissingAttributeError( + trail.attribute()); + + if (result.ok()) { + return std::move(result).value(); + } + // Invariant broken (an invalid CEL Attribute shouldn't match anything). + // Log and return a CelError. + ABSL_LOG(ERROR) << "Invalid attribute pattern matched select path: " + << result.status().ToString(); // NOLINT: OSS compatibility + return cel::ErrorValue(std::move(result).status()); + } + + return std::nullopt; +} + +void TestOnlySelect(const StructValue& msg, const std::string& field, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) { + absl::StatusOr has_field = msg.HasFieldByName(field); + + if (!has_field.ok()) { + *result = ErrorValue(std::move(has_field).status()); + return; + } + *result = BoolValue{*has_field}; +} + +void TestOnlySelect(const MapValue& map, const StringValue& field_name, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) { + // Field presence only supports string keys containing valid identifier + // characters. + absl::Status presence = + map.Has(field_name, descriptor_pool, message_factory, arena, result); + + if (!presence.ok()) { + *result = ErrorValue(std::move(presence)); + return; + } + ABSL_DCHECK(!result->IsUnknown()); +} // SelectStep performs message field access specified by Expr::Select // message. class SelectStep : public ExpressionStepBase { public: - SelectStep(absl::string_view field, bool test_field_presence, int64_t expr_id, - absl::string_view select_path) + SelectStep(StringValue value, bool test_field_presence, int64_t expr_id, + bool enable_wrapper_type_null_unboxing, bool enable_optional_types) : ExpressionStepBase(expr_id), - field_(field), + field_value_(std::move(value)), + field_(field_value_.ToString()), test_field_presence_(test_field_presence), - select_path_(select_path) {} + unboxing_option_(enable_wrapper_type_null_unboxing + ? ProtoWrapperTypeOptions::kUnsetNull + : ProtoWrapperTypeOptions::kUnsetProtoDefault), + enable_optional_types_(enable_optional_types) {} absl::Status Evaluate(ExecutionFrame* frame) const override; private: - absl::Status CreateValueFromField(const google::protobuf::Message* msg, - google::protobuf::Arena* arena, - CelValue* result) const; + absl::Status PerformTestOnlySelect(ExecutionFrame* frame, + const Value& arg) const; + absl::StatusOr PerformSelect(ExecutionFrame* frame, const Value& arg, + Value& result) const; + cel::StringValue field_value_; std::string field_; bool test_field_presence_; - std::string select_path_; + ProtoWrapperTypeOptions unboxing_option_; + bool enable_optional_types_; }; -absl::Status SelectStep::CreateValueFromField(const google::protobuf::Message* msg, - google::protobuf::Arena* arena, - CelValue* result) const { - const Reflection* reflection = msg->GetReflection(); - const Descriptor* desc = msg->GetDescriptor(); - const FieldDescriptor* field_desc = desc->FindFieldByName(field_); - - if (field_desc == nullptr) { - *result = CreateNoSuchFieldError(arena); - return absl::OkStatus(); - } - - if (field_desc->is_map()) { - *result = CelValue::CreateMap(google::protobuf::Arena::Create( - arena, msg, field_desc, arena)); - return absl::OkStatus(); - } - if (field_desc->is_repeated()) { - *result = CelValue::CreateList(google::protobuf::Arena::Create( - arena, msg, field_desc, arena)); - return absl::OkStatus(); - } - if (test_field_presence_) { - *result = CelValue::CreateBool(reflection->HasField(*msg, field_desc)); - return absl::OkStatus(); - } - return CreateValueFromSingleField(msg, field_desc, arena, result); -} - absl::Status SelectStep::Evaluate(ExecutionFrame* frame) const { if (!frame->value_stack().HasEnough(1)) { return absl::Status(absl::StatusCode::kInternal, "No arguments supplied for Select-type expression"); } - const CelValue& arg = frame->value_stack().Peek(); + const Value& arg = frame->value_stack().Peek(); const AttributeTrail& trail = frame->value_stack().PeekAttribute(); - CelValue result; + if (arg.IsUnknown() || arg.IsError()) { + // Bubble up unknowns and errors. + return absl::OkStatus(); + } + AttributeTrail result_trail; - // Non-empty select path - check if value mapped to unknown or error. - bool unknown_value = false; - // TODO(issues/41) deprecate this path after proper support of unknown is - // implemented - if (!select_path_.empty()) { - unknown_value = frame->activation().IsPathUnknown(select_path_); + // Handle unknown resolution. + if (frame->enable_unknowns() || frame->enable_missing_attribute_errors()) { + result_trail = trail.Step(&field_); } - // Select steps can be applied to either maps or messages - switch (arg.type()) { - case CelValue::Type::kMessage: { - const google::protobuf::Message* msg = arg.MessageOrDie(); + absl::optional optional_arg; - if (frame->enable_unknowns() || - frame->enable_missing_attribute_errors()) { - result_trail = trail.Step(&field_, frame->arena()); - } + if (enable_optional_types_ && arg.IsOptional()) { + optional_arg = arg.GetOptional(); + } - if (frame->enable_missing_attribute_errors() && - frame->attribute_utility().CheckForMissingAttribute(result_trail)) { - CelValue error_value = - CreateMissingAttributeError(frame->arena(), select_path_); - frame->value_stack().PopAndPush(error_value, result_trail); - return absl::OkStatus(); - } + if (!(optional_arg || arg->Is() || arg->Is())) { + frame->value_stack().PopAndPush(cel::ErrorValue(InvalidSelectTargetError()), + std::move(result_trail)); + return absl::OkStatus(); + } - if (frame->enable_unknowns() && - frame->attribute_utility().CheckForUnknown(result_trail, - /*use_partial=*/false)) { - auto unknown_set = google::protobuf::Arena::Create( - frame->arena(), UnknownAttributeSet({result_trail.attribute()})); - result = CelValue::CreateUnknownSet(unknown_set); - frame->value_stack().PopAndPush(result, result_trail); - return absl::OkStatus(); - } + absl::optional marked_attribute_check = + CheckForMarkedAttributes(result_trail, *frame); + if (marked_attribute_check.has_value()) { + frame->value_stack().PopAndPush(std::move(marked_attribute_check).value(), + std::move(result_trail)); + return absl::OkStatus(); + } - if (msg == nullptr) { - CelValue error_value = - CreateErrorValue(frame->arena(), "Message is NULL"); - frame->value_stack().PopAndPush(error_value, result_trail); + // Handle test only Select. + if (test_field_presence_) { + if (optional_arg) { + if (!optional_arg->HasValue()) { + frame->value_stack().PopAndPush(cel::BoolValue{false}); return absl::OkStatus(); } + Value value; + optional_arg->Value(&value); + return PerformTestOnlySelect(frame, value); + } + return PerformTestOnlySelect(frame, arg); + } - if (unknown_value) { - CelValue error_value = - CreateUnknownValueError(frame->arena(), select_path_); - frame->value_stack().PopAndPush(error_value, result_trail); - return absl::OkStatus(); + // Normal select path. + // Select steps can be applied to either maps or messages + if (optional_arg) { + if (!optional_arg->HasValue()) { + // Leave optional_arg at the top of the stack. Its empty. + return absl::OkStatus(); + } + Value value; + Value result; + bool ok; + optional_arg->Value(&value); + CEL_ASSIGN_OR_RETURN(ok, PerformSelect(frame, value, result)); + if (!ok) { + frame->value_stack().PopAndPush(cel::OptionalValue::None(), + std::move(result_trail)); + return absl::OkStatus(); + } + frame->value_stack().PopAndPush( + cel::OptionalValue::Of(std::move(result), frame->arena()), + std::move(result_trail)); + return absl::OkStatus(); + } + + // Normal select path. + // Select steps can be applied to either maps or messages + switch (arg.kind()) { + case ValueKind::kStruct: { + Value result; + auto status = arg.GetStruct().GetFieldByName( + field_, unboxing_option_, frame->descriptor_pool(), + frame->message_factory(), frame->arena(), &result); + if (!status.ok()) { + result = ErrorValue(std::move(status)); } + frame->value_stack().PopAndPush(std::move(result), + std::move(result_trail)); + return absl::OkStatus(); + } + case ValueKind::kMap: { + Value result; + auto status = + arg.GetMap().Get(field_value_, frame->descriptor_pool(), + frame->message_factory(), frame->arena(), &result); + if (!status.ok()) { + result = ErrorValue(std::move(status)); + } + frame->value_stack().PopAndPush(std::move(result), + std::move(result_trail)); + return absl::OkStatus(); + } + default: + // Control flow should have returned earlier. + return InvalidSelectTargetError(); + } +} - absl::Status status = CreateValueFromField(msg, frame->arena(), &result); +absl::Status SelectStep::PerformTestOnlySelect(ExecutionFrame* frame, + const Value& arg) const { + switch (arg.kind()) { + case ValueKind::kMap: { + Value result; + TestOnlySelect(arg.GetMap(), field_value_, frame->descriptor_pool(), + frame->message_factory(), frame->arena(), &result); + frame->value_stack().PopAndPush(std::move(result)); + return absl::OkStatus(); + } + case ValueKind::kMessage: { + Value result; + TestOnlySelect(arg.GetStruct(), field_, frame->descriptor_pool(), + frame->message_factory(), frame->arena(), &result); + frame->value_stack().PopAndPush(std::move(result)); + return absl::OkStatus(); + } + default: + // Control flow should have returned earlier. + return InvalidSelectTargetError(); + } +} - if (status.ok()) { - frame->value_stack().PopAndPush(result, result_trail); +absl::StatusOr SelectStep::PerformSelect(ExecutionFrame* frame, + const Value& arg, + Value& result) const { + switch (arg->kind()) { + case ValueKind::kStruct: { + const auto& struct_value = arg.GetStruct(); + CEL_ASSIGN_OR_RETURN(auto ok, struct_value.HasFieldByName(field_)); + if (!ok) { + result = NullValue{}; + return false; } + CEL_RETURN_IF_ERROR(struct_value.GetFieldByName( + field_, unboxing_option_, frame->descriptor_pool(), + frame->message_factory(), frame->arena(), &result)); + ABSL_DCHECK(!result.IsUnknown()); + return true; + } + case ValueKind::kMap: { + CEL_ASSIGN_OR_RETURN( + auto found, + arg.GetMap().Find(field_value_, frame->descriptor_pool(), + frame->message_factory(), frame->arena(), &result)); + ABSL_DCHECK(!found || !result.IsUnknown()); + return found; + } + default: + // Control flow should have returned earlier. + return InvalidSelectTargetError(); + } +} - return status; +class DirectSelectStep : public DirectExpressionStep { + public: + DirectSelectStep(int64_t expr_id, + std::unique_ptr operand, + StringValue field, bool test_only, + bool enable_wrapper_type_null_unboxing, + bool enable_optional_types) + : DirectExpressionStep(expr_id), + operand_(std::move(operand)), + field_value_(std::move(field)), + field_(field_value_.ToString()), + test_only_(test_only), + unboxing_option_(enable_wrapper_type_null_unboxing + ? ProtoWrapperTypeOptions::kUnsetNull + : ProtoWrapperTypeOptions::kUnsetProtoDefault), + enable_optional_types_(enable_optional_types) {} + + absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, + AttributeTrail& attribute) const override { + CEL_RETURN_IF_ERROR(operand_->Evaluate(frame, result, attribute)); + + if (result.IsError() || result.IsUnknown()) { + // Just forward. + return absl::OkStatus(); } - case CelValue::Type::kMap: { - const CelMap* cel_map = arg.MapOrDie(); - if (cel_map == nullptr) { - CelValue error_value = CreateErrorValue(frame->arena(), "Map is NULL"); - frame->value_stack().PopAndPush(error_value); + if (frame.attribute_tracking_enabled()) { + attribute = attribute.Step(&field_); + absl::optional value = CheckForMarkedAttributes(attribute, frame); + if (value.has_value()) { + result = std::move(value).value(); return absl::OkStatus(); } + } - if (unknown_value) { - CelValue error_value = CreateErrorValue( - frame->arena(), absl::StrCat("Unknown value ", select_path_)); - frame->value_stack().PopAndPush(error_value); - return absl::OkStatus(); - } + absl::optional optional_arg; - auto lookup_result = (*cel_map)[CelValue::CreateString(&field_)]; + if (enable_optional_types_ && result.IsOptional()) { + optional_arg = result.GetOptional(); + } - // Test only Select expression. - if (test_field_presence_) { - result = CelValue::CreateBool(lookup_result.has_value()); - frame->value_stack().PopAndPush(result); + switch (result.kind()) { + case ValueKind::kStruct: + case ValueKind::kMap: + break; + default: + if (optional_arg) { + break; + } + result = cel::ErrorValue(InvalidSelectTargetError()); return absl::OkStatus(); - } + } - if (frame->enable_unknowns()) { - result_trail = trail.Step(&field_, frame->arena()); - if (frame->attribute_utility().CheckForUnknown(result_trail, false)) { - auto unknown_set = google::protobuf::Arena::Create( - frame->arena(), UnknownAttributeSet({result_trail.attribute()})); - result = CelValue::CreateUnknownSet(unknown_set); - frame->value_stack().PopAndPush(result, result_trail); + if (test_only_) { + if (optional_arg) { + if (!optional_arg->HasValue()) { + result = cel::BoolValue{false}; return absl::OkStatus(); } + Value value; + optional_arg->Value(&value); + PerformTestOnlySelect(frame, value, result); + return absl::OkStatus(); } + PerformTestOnlySelect(frame, result, result); + return absl::OkStatus(); + } - // If object is not found, we return Error, per CEL specification. - if (lookup_result) { - result = lookup_result.value(); - } else { - result = CreateNoSuchKeyError(frame->arena(), field_); + if (optional_arg) { + if (!optional_arg->HasValue()) { + // result is still buffer for the container. just return. + return absl::OkStatus(); } - frame->value_stack().PopAndPush(result, result_trail); + Value value; + optional_arg->Value(&value); + return PerformOptionalSelect(frame, value, result); + } - return absl::OkStatus(); + auto status = PerformSelect(frame, result, result); + if (!status.ok()) { + result = ErrorValue(std::move(status)); } - case CelValue::Type::kUnknownSet: { - // Parent is unknown already, bubble it up. + return absl::OkStatus(); + } + + private: + std::unique_ptr operand_; + + void PerformTestOnlySelect(ExecutionFrameBase& frame, const Value& value, + Value& result) const; + absl::Status PerformOptionalSelect(ExecutionFrameBase& frame, + const Value& value, Value& result) const; + absl::Status PerformSelect(ExecutionFrameBase& frame, const Value& value, + Value& result) const; + + // Field name in formats supported by each of the map and struct field access + // APIs. + // + // ToString or ValueManager::CreateString may force a copy so we do this at + // plan time. + StringValue field_value_; + std::string field_; + + // whether this is a has() expression. + bool test_only_; + ProtoWrapperTypeOptions unboxing_option_; + bool enable_optional_types_; +}; + +void DirectSelectStep::PerformTestOnlySelect(ExecutionFrameBase& frame, + const cel::Value& value, + Value& result) const { + switch (value.kind()) { + case ValueKind::kMap: + TestOnlySelect(value.GetMap(), field_value_, frame.descriptor_pool(), + frame.message_factory(), frame.arena(), &result); + return; + case ValueKind::kMessage: + TestOnlySelect(value.GetStruct(), field_, frame.descriptor_pool(), + frame.message_factory(), frame.arena(), &result); + return; + default: + // Control flow should have returned earlier. + result = cel::ErrorValue(InvalidSelectTargetError()); + return; + } +} + +absl::Status DirectSelectStep::PerformOptionalSelect(ExecutionFrameBase& frame, + const Value& value, + Value& result) const { + switch (value.kind()) { + case ValueKind::kStruct: { + auto struct_value = value.GetStruct(); + CEL_ASSIGN_OR_RETURN(auto ok, struct_value.HasFieldByName(field_)); + if (!ok) { + result = OptionalValue::None(); + return absl::OkStatus(); + } + CEL_RETURN_IF_ERROR(struct_value.GetFieldByName( + field_, unboxing_option_, frame.descriptor_pool(), + frame.message_factory(), frame.arena(), &result)); + ABSL_DCHECK(!result.IsUnknown()); + result = OptionalValue::Of(std::move(result), frame.arena()); return absl::OkStatus(); } - case CelValue::Type::kError: { - // If argument is CelError, we propagate it forward. - // It is already on the top of the stack. + case ValueKind::kMap: { + CEL_ASSIGN_OR_RETURN( + auto found, + value.GetMap().Find(field_value_, frame.descriptor_pool(), + frame.message_factory(), frame.arena(), &result)); + if (!found) { + result = OptionalValue::None(); + return absl::OkStatus(); + } + ABSL_DCHECK(!result.IsUnknown()); + result = OptionalValue::Of(std::move(result), frame.arena()); return absl::OkStatus(); } default: - return absl::Status(absl::StatusCode::kInvalidArgument, - "Applying SELECT to non-message type"); + // Control flow should have returned earlier. + return InvalidSelectTargetError(); + } +} + +absl::Status DirectSelectStep::PerformSelect(ExecutionFrameBase& frame, + const cel::Value& value, + Value& result) const { + switch (value.kind()) { + case ValueKind::kStruct: + CEL_RETURN_IF_ERROR(value.GetStruct().GetFieldByName( + field_, unboxing_option_, frame.descriptor_pool(), + frame.message_factory(), frame.arena(), &result)); + ABSL_DCHECK(!result.IsUnknown()); + return absl::OkStatus(); + case ValueKind::kMap: + CEL_RETURN_IF_ERROR( + value.GetMap().Get(field_value_, frame.descriptor_pool(), + frame.message_factory(), frame.arena(), &result)); + ABSL_DCHECK(!result.IsUnknown()); + return absl::OkStatus(); + default: + // Control flow should have returned earlier. + return InvalidSelectTargetError(); } } } // namespace +std::unique_ptr CreateDirectSelectStep( + std::unique_ptr operand, StringValue field, + bool test_only, int64_t expr_id, bool enable_wrapper_type_null_unboxing, + bool enable_optional_types) { + return std::make_unique( + expr_id, std::move(operand), std::move(field), test_only, + enable_wrapper_type_null_unboxing, enable_optional_types); +} + // Factory method for Select - based Execution step absl::StatusOr> CreateSelectStep( - const google::api::expr::v1alpha1::Expr::Select* select_expr, int64_t expr_id, - absl::string_view select_path) { - std::unique_ptr step = absl::make_unique( - select_expr->field(), select_expr->test_only(), expr_id, select_path); - return std::move(step); + const cel::SelectExpr& select_expr, int64_t expr_id, + bool enable_wrapper_type_null_unboxing, bool enable_optional_types) { + return std::make_unique( + cel::StringValue(select_expr.field()), select_expr.test_only(), expr_id, + enable_wrapper_type_null_unboxing, enable_optional_types); } -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime diff --git a/eval/eval/select_step.h b/eval/eval/select_step.h index d288c1654..6eaaf9487 100644 --- a/eval/eval/select_step.h +++ b/eval/eval/select_step.h @@ -1,27 +1,28 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_SELECT_STEP_H_ #define THIRD_PARTY_CEL_CPP_EVAL_EVAL_SELECT_STEP_H_ +#include +#include + +#include "absl/status/statusor.h" +#include "common/expr.h" +#include "common/value.h" +#include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" -#include "eval/public/activation.h" -#include "eval/public/cel_value.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace google::api::expr::runtime { -// Factory method for Select - based Execution step -absl::StatusOr> CreateSelectStep( - const google::api::expr::v1alpha1::Expr::Select* select_expr, int64_t expr_id, - absl::string_view select_path); +// Factory method for recursively evaluated select step. +std::unique_ptr CreateDirectSelectStep( + std::unique_ptr operand, cel::StringValue field, + bool test_only, int64_t expr_id, bool enable_wrapper_type_null_unboxing, + bool enable_optional_types = false); // Factory method for Select - based Execution step absl::StatusOr> CreateSelectStep( - const google::api::expr::v1alpha1::Expr::Select* select_expr, int64_t expr_id); + const cel::SelectExpr& select_expr, int64_t expr_id, + bool enable_wrapper_type_null_unboxing, bool enable_optional_types = false); -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_SELECT_STEP_H_ diff --git a/eval/eval/select_step_test.cc b/eval/eval/select_step_test.cc index ba6f0954b..ce532eabd 100644 --- a/eval/eval/select_step_test.cc +++ b/eval/eval/select_step_test.cc @@ -1,183 +1,360 @@ #include "eval/eval/select_step.h" -#include "google/api/expr/v1alpha1/syntax.pb.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" +#include +#include +#include +#include + +#include "cel/expr/syntax.pb.h" +#include "google/protobuf/wrappers.pb.h" +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "base/attribute.h" +#include "base/attribute_set.h" +#include "base/type_provider.h" +#include "common/casting.h" +#include "common/expr.h" +#include "common/legacy_value.h" +#include "common/value.h" +#include "common/value_testing.h" +#include "eval/eval/attribute_trail.h" +#include "eval/eval/cel_expression_flat_impl.h" +#include "eval/eval/const_value_step.h" +#include "eval/eval/evaluator_core.h" #include "eval/eval/ident_step.h" +#include "eval/public/activation.h" #include "eval/public/cel_attribute.h" +#include "eval/public/cel_value.h" #include "eval/public/containers/container_backed_map_impl.h" #include "eval/public/structs/cel_proto_wrapper.h" -#include "eval/public/unknown_attribute_set.h" +#include "eval/public/structs/legacy_type_adapter.h" +#include "eval/public/structs/trivial_legacy_type_info.h" +#include "eval/public/testing/matchers.h" +#include "eval/testutil/test_extensions.pb.h" #include "eval/testutil/test_message.pb.h" -#include "testutil/util.h" -#include "base/status_macros.h" +#include "extensions/protobuf/value.h" +#include "internal/proto_matchers.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "internal/testing_message_factory.h" +#include "runtime/activation.h" +#include "runtime/internal/runtime_env.h" +#include "runtime/internal/runtime_env_testing.h" +#include "runtime/internal/runtime_type_provider.h" +#include "runtime/runtime_options.h" +#include "cel/expr/conformance/proto3/test_all_types.pb.h" + +namespace google::api::expr::runtime { -namespace google { -namespace api { -namespace expr { -namespace runtime { namespace { -using testing::Eq; - -using testutil::EqualsProto; - -using google::api::expr::v1alpha1::Expr; - -// Helper method. Creates simple pipeline containing Select step and runs it. -absl::StatusOr RunExpression(const CelValue target, - absl::string_view field, bool test, - google::protobuf::Arena* arena, - absl::string_view unknown_path, - bool enable_unknowns) { - ExecutionPath path; - - Expr dummy_expr; +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::cel::Attribute; +using ::cel::AttributeQualifier; +using ::cel::AttributeSet; +using ::cel::BoolValue; +using ::cel::Cast; +using ::cel::ErrorValue; +using ::cel::Expr; +using ::cel::InstanceOf; +using ::cel::IntValue; +using ::cel::OptionalValue; +using ::cel::RuntimeOptions; +using ::cel::TypeProvider; +using ::cel::UnknownValue; +using ::cel::Value; +using ::cel::expr::conformance::proto3::TestAllTypes; +using ::cel::extensions::ProtoMessageToValue; +using ::cel::internal::test::EqualsProto; +using ::cel::runtime_internal::NewTestingRuntimeEnv; +using ::cel::runtime_internal::RuntimeEnv; +using ::cel::test::IntValueIs; +using ::testing::_; +using ::testing::Eq; +using ::testing::HasSubstr; +using ::testing::Return; +using ::testing::UnorderedElementsAre; + +struct RunExpressionOptions { + bool enable_unknowns = false; + bool enable_wrapper_type_null_unboxing = false; +}; + +// Simple implementation LegacyTypeAccessApis / LegacyTypeInfoApis that allows +// mocking for getters/setters. +class MockAccessor : public LegacyTypeAccessApis, public LegacyTypeInfoApis { + public: + MOCK_METHOD(absl::StatusOr, HasField, + (absl::string_view field_name, + const CelValue::MessageWrapper& value), + (const, override)); + MOCK_METHOD(absl::StatusOr, GetField, + (absl::string_view field_name, + const CelValue::MessageWrapper& instance, + ProtoWrapperTypeOptions unboxing_option, + cel::MemoryManagerRef memory_manager), + (const, override)); + MOCK_METHOD(absl::string_view, GetTypename, + (const CelValue::MessageWrapper& instance), (const, override)); + MOCK_METHOD(std::string, DebugString, + (const CelValue::MessageWrapper& instance), (const, override)); + MOCK_METHOD(std::vector, ListFields, + (const CelValue::MessageWrapper& value), (const, override)); + const LegacyTypeAccessApis* GetAccessApis( + const CelValue::MessageWrapper& instance) const override { + return this; + } +}; + +class SelectStepTest : public testing::Test { + public: + SelectStepTest() : env_(NewTestingRuntimeEnv()) {} + // Helper method. Creates simple pipeline containing Select step and runs it. + absl::StatusOr RunExpression(const CelValue target, + absl::string_view field, bool test, + absl::string_view unknown_path, + RunExpressionOptions options) { + ExecutionPath path; + + Expr expr; + auto& select = expr.mutable_select_expr(); + select.set_field(std::string(field)); + select.set_test_only(test); + Expr& expr0 = select.mutable_operand(); + + auto& ident = expr0.mutable_ident_expr(); + ident.set_name("target"); + CEL_ASSIGN_OR_RETURN(auto step0, CreateIdentStep(ident.name(), expr0.id())); + CEL_ASSIGN_OR_RETURN( + auto step1, + CreateSelectStep(select, expr.id(), + options.enable_wrapper_type_null_unboxing)); + + path.push_back(std::move(step0)); + path.push_back(std::move(step1)); + + cel::RuntimeOptions runtime_options; + if (options.enable_unknowns) { + runtime_options.unknown_processing = + cel::UnknownProcessingOptions::kAttributeOnly; + } + CelExpressionFlatImpl cel_expr( + env_, FlatExpression(std::move(path), /*comprehension_slot_count=*/0, + env_->type_registry.GetComposedTypeProvider(), + runtime_options)); + Activation activation; + activation.InsertValue("target", target); - auto select = dummy_expr.mutable_select_expr(); - select->set_field(field.data()); - select->set_test_only(test); - Expr* expr0 = select->mutable_operand(); + return cel_expr.Evaluate(activation, &arena_); + } - auto ident = expr0->mutable_ident_expr(); - ident->set_name("target"); - auto step0_status = CreateIdentStep(ident, expr0->id()); - auto step1_status = CreateSelectStep(select, dummy_expr.id(), unknown_path); + absl::StatusOr RunExpression(const TestExtensions* message, + absl::string_view field, bool test, + RunExpressionOptions options) { + return RunExpression(CelProtoWrapper::CreateMessage(message, &arena_), + field, test, "", options); + } - if (!step0_status.ok()) { - return step0_status.status(); + absl::StatusOr RunExpression(const TestMessage* message, + absl::string_view field, bool test, + absl::string_view unknown_path, + RunExpressionOptions options) { + return RunExpression(CelProtoWrapper::CreateMessage(message, &arena_), + field, test, unknown_path, options); } - if (!step1_status.ok()) { - return step1_status.status(); + absl::StatusOr RunExpression(const TestMessage* message, + absl::string_view field, bool test, + RunExpressionOptions options) { + return RunExpression(message, field, test, "", options); } - path.push_back(std::move(step0_status.value())); - path.push_back(std::move(step1_status.value())); + absl::StatusOr RunExpression(const CelMap* map_value, + absl::string_view field, bool test, + absl::string_view unknown_path, + RunExpressionOptions options) { + return RunExpression(CelValue::CreateMap(map_value), field, test, + unknown_path, options); + } - CelExpressionFlatImpl cel_expr(&dummy_expr, std::move(path), 0, {}, - enable_unknowns); - Activation activation; - activation.InsertValue("target", target); + absl::StatusOr RunExpression(const CelMap* map_value, + absl::string_view field, bool test, + RunExpressionOptions options) { + return RunExpression(map_value, field, test, "", options); + } - return cel_expr.Evaluate(activation, arena); -} + protected: + absl_nonnull std::shared_ptr env_; + google::protobuf::Arena arena_; +}; -absl::StatusOr RunExpression(const TestMessage* message, - absl::string_view field, bool test, - google::protobuf::Arena* arena, - absl::string_view unknown_path, - bool enable_unknowns) { - return RunExpression(CelProtoWrapper::CreateMessage(message, arena), field, - test, arena, unknown_path, enable_unknowns); -} +class SelectStepConformanceTest : public SelectStepTest, + public testing::WithParamInterface {}; -absl::StatusOr RunExpression(const TestMessage* message, - absl::string_view field, bool test, - google::protobuf::Arena* arena, - bool enable_unknowns) { - return RunExpression(message, field, test, arena, "", enable_unknowns); -} +TEST_P(SelectStepConformanceTest, SelectMessageIsNull) { + RunExpressionOptions options; + options.enable_unknowns = GetParam(); -absl::StatusOr RunExpression(const CelMap* map_value, - absl::string_view field, bool test, - google::protobuf::Arena* arena, - absl::string_view unknown_path, - bool enable_unknowns) { - return RunExpression(CelValue::CreateMap(map_value), field, test, arena, - unknown_path, enable_unknowns); -} + ASSERT_OK_AND_ASSIGN(CelValue result, + RunExpression(static_cast(nullptr), + "bool_value", true, options)); -absl::StatusOr RunExpression(const CelMap* map_value, - absl::string_view field, bool test, - google::protobuf::Arena* arena, - bool enable_unknowns) { - return RunExpression(map_value, field, test, arena, "", enable_unknowns); + ASSERT_TRUE(result.IsError()); } -class SelectStepTest : public testing::TestWithParam {}; - -TEST_P(SelectStepTest, SelectMessageIsNull) { - google::protobuf::Arena arena; +TEST_P(SelectStepConformanceTest, SelectTargetNotStructOrMap) { + RunExpressionOptions options; + options.enable_unknowns = GetParam(); - auto run_status = RunExpression(static_cast(nullptr), - "bool_value", true, &arena, GetParam()); - ASSERT_OK(run_status); - - CelValue result = run_status.value(); + ASSERT_OK_AND_ASSIGN( + CelValue result, + RunExpression(CelValue::CreateStringView("some_value"), "some_field", + /*test=*/false, + /*unknown_path=*/"", options)); ASSERT_TRUE(result.IsError()); + EXPECT_THAT(*result.ErrorOrDie(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Applying SELECT to non-message type"))); } -TEST_P(SelectStepTest, PresenseIsFalseTest) { +TEST_P(SelectStepConformanceTest, PresenseIsFalseTest) { TestMessage message; + RunExpressionOptions options; + options.enable_unknowns = GetParam(); - google::protobuf::Arena arena; - - auto run_status = - RunExpression(&message, "bool_value", true, &arena, GetParam()); - ASSERT_OK(run_status); - - CelValue result = run_status.value(); + ASSERT_OK_AND_ASSIGN(CelValue result, + RunExpression(&message, "bool_value", true, options)); ASSERT_TRUE(result.IsBool()); EXPECT_EQ(result.BoolOrDie(), false); } -TEST_P(SelectStepTest, PresenseIsTrueTest) { +TEST_P(SelectStepConformanceTest, PresenseIsTrueTest) { + RunExpressionOptions options; + options.enable_unknowns = GetParam(); TestMessage message; message.set_bool_value(true); - google::protobuf::Arena arena; + ASSERT_OK_AND_ASSIGN(CelValue result, + RunExpression(&message, "bool_value", true, options)); + ASSERT_TRUE(result.IsBool()); + EXPECT_EQ(result.BoolOrDie(), true); +} - auto run_status = - RunExpression(&message, "bool_value", true, &arena, GetParam()); - ASSERT_OK(run_status); +TEST_P(SelectStepConformanceTest, ExtensionsPresenceIsTrueTest) { + TestExtensions exts; + TestExtensions* nested = exts.MutableExtension(nested_ext); + nested->set_name("nested"); + RunExpressionOptions options; + options.enable_unknowns = GetParam(); - CelValue result = run_status.value(); + ASSERT_OK_AND_ASSIGN( + CelValue result, + RunExpression(&exts, "google.api.expr.runtime.nested_ext", true, + options)); ASSERT_TRUE(result.IsBool()); - EXPECT_EQ(result.BoolOrDie(), true); + EXPECT_TRUE(result.BoolOrDie()); +} + +TEST_P(SelectStepConformanceTest, ExtensionsPresenceIsFalseTest) { + TestExtensions exts; + RunExpressionOptions options; + options.enable_unknowns = GetParam(); + + ASSERT_OK_AND_ASSIGN( + CelValue result, + RunExpression(&exts, "google.api.expr.runtime.nested_ext", true, + options)); + + ASSERT_TRUE(result.IsBool()); + EXPECT_FALSE(result.BoolOrDie()); } -TEST_P(SelectStepTest, MapPresenseIsFalseTest) { +TEST_P(SelectStepConformanceTest, MapPresenseIsFalseTest) { + RunExpressionOptions options; + options.enable_unknowns = GetParam(); std::string key1 = "key1"; std::vector> key_values{ {CelValue::CreateString(&key1), CelValue::CreateInt64(1)}}; auto map_value = CreateContainerBackedMap( - absl::Span>(key_values)); - - google::protobuf::Arena arena; - - auto run_status = - RunExpression(map_value.get(), "key2", true, &arena, GetParam()); - CelValue result = run_status.value(); + absl::Span>(key_values)) + .value(); + ASSERT_OK_AND_ASSIGN(CelValue result, + RunExpression(map_value.get(), "key2", true, options)); ASSERT_TRUE(result.IsBool()); EXPECT_EQ(result.BoolOrDie(), false); } -TEST_P(SelectStepTest, MapPresenseIsTrueTest) { +TEST_P(SelectStepConformanceTest, MapPresenseIsTrueTest) { + RunExpressionOptions options; + options.enable_unknowns = GetParam(); std::string key1 = "key1"; std::vector> key_values{ {CelValue::CreateString(&key1), CelValue::CreateInt64(1)}}; auto map_value = CreateContainerBackedMap( - absl::Span>(key_values)); - - google::protobuf::Arena arena; + absl::Span>(key_values)) + .value(); - auto run_status = - RunExpression(map_value.get(), "key1", true, &arena, GetParam()); - CelValue result = run_status.value(); + ASSERT_OK_AND_ASSIGN(CelValue result, + RunExpression(map_value.get(), "key1", true, options)); ASSERT_TRUE(result.IsBool()); EXPECT_EQ(result.BoolOrDie(), true); } -TEST(SelectStepTest, MapPresenseIsTrueWithUnknownTest) { +TEST_F(SelectStepTest, MapPresenseIsErrorTest) { + TestMessage message; + + Expr select_expr; + auto& select = select_expr.mutable_select_expr(); + select.set_field("1"); + select.set_test_only(true); + Expr& expr1 = select.mutable_operand(); + auto& select_map = expr1.mutable_select_expr(); + select_map.set_field("int32_int32_map"); + Expr& expr0 = select_map.mutable_operand(); + auto& ident = expr0.mutable_ident_expr(); + ident.set_name("target"); + + ASSERT_OK_AND_ASSIGN(auto step0, CreateIdentStep(ident.name(), expr0.id())); + ASSERT_OK_AND_ASSIGN( + auto step1, + CreateSelectStep(select_map, expr1.id(), + /*enable_wrapper_type_null_unboxing=*/false)); + ASSERT_OK_AND_ASSIGN( + auto step2, + CreateSelectStep(select, select_expr.id(), + /*enable_wrapper_type_null_unboxing=*/false)); + + ExecutionPath path; + path.push_back(std::move(step0)); + path.push_back(std::move(step1)); + path.push_back(std::move(step2)); + CelExpressionFlatImpl cel_expr( + env_, FlatExpression(std::move(path), /*comprehension_slot_count=*/0, + env_->type_registry.GetComposedTypeProvider(), + cel::RuntimeOptions{})); + Activation activation; + activation.InsertValue("target", + CelProtoWrapper::CreateMessage(&message, &arena_)); + + ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr.Evaluate(activation, &arena_)); + EXPECT_TRUE(result.IsError()); + EXPECT_EQ(result.ErrorOrDie()->code(), absl::StatusCode::kInvalidArgument); +} + +TEST_F(SelectStepTest, MapPresenseIsTrueWithUnknownTest) { UnknownSet unknown_set; std::string key1 = "key1"; std::vector> key_values{ @@ -185,495 +362,635 @@ TEST(SelectStepTest, MapPresenseIsTrueWithUnknownTest) { CelValue::CreateUnknownSet(&unknown_set)}}; auto map_value = CreateContainerBackedMap( - absl::Span>(key_values)); - - google::protobuf::Arena arena; + absl::Span>(key_values)) + .value(); - auto run_status = RunExpression(map_value.get(), "key1", true, &arena, true); - CelValue result = run_status.value(); + RunExpressionOptions options; + options.enable_unknowns = true; + ASSERT_OK_AND_ASSIGN(CelValue result, + RunExpression(map_value.get(), "key1", true, options)); ASSERT_TRUE(result.IsBool()); EXPECT_EQ(result.BoolOrDie(), true); } -TEST_P(SelectStepTest, FieldIsNotPresentInProtoTest) { +TEST_P(SelectStepConformanceTest, FieldIsNotPresentInProtoTest) { TestMessage message; - google::protobuf::Arena arena; - - auto run_status = - RunExpression(&message, "fake_field", false, &arena, GetParam()); - ASSERT_OK(run_status); - - CelValue result = run_status.value(); + RunExpressionOptions options; + options.enable_unknowns = GetParam(); + ASSERT_OK_AND_ASSIGN(CelValue result, + RunExpression(&message, "fake_field", false, options)); ASSERT_TRUE(result.IsError()); - EXPECT_THAT(result.ErrorOrDie()->code(), Eq(absl::StatusCode::kNotFound)); } -TEST_P(SelectStepTest, FieldIsNotSetTest) { +TEST_P(SelectStepConformanceTest, FieldIsNotSetTest) { TestMessage message; + RunExpressionOptions options; + options.enable_unknowns = GetParam(); - google::protobuf::Arena arena; - - auto run_status = - RunExpression(&message, "bool_value", false, &arena, GetParam()); - ASSERT_OK(run_status); - - CelValue result = run_status.value(); + ASSERT_OK_AND_ASSIGN(CelValue result, + RunExpression(&message, "bool_value", false, options)); ASSERT_TRUE(result.IsBool()); EXPECT_EQ(result.BoolOrDie(), false); } -TEST_P(SelectStepTest, SimpleBoolTest) { +TEST_P(SelectStepConformanceTest, SimpleBoolTest) { TestMessage message; message.set_bool_value(true); + RunExpressionOptions options; + options.enable_unknowns = GetParam(); - google::protobuf::Arena arena; - - auto run_status = - RunExpression(&message, "bool_value", false, &arena, GetParam()); - ASSERT_OK(run_status); - - CelValue result = run_status.value(); + ASSERT_OK_AND_ASSIGN(CelValue result, + RunExpression(&message, "bool_value", false, options)); ASSERT_TRUE(result.IsBool()); EXPECT_EQ(result.BoolOrDie(), true); } -TEST_P(SelectStepTest, SimpleInt32Test) { +TEST_P(SelectStepConformanceTest, SimpleInt32Test) { TestMessage message; message.set_int32_value(1); + RunExpressionOptions options; + options.enable_unknowns = GetParam(); - google::protobuf::Arena arena; - - auto run_status = - RunExpression(&message, "int32_value", false, &arena, GetParam()); - ASSERT_OK(run_status); - - CelValue result = run_status.value(); + ASSERT_OK_AND_ASSIGN(CelValue result, + RunExpression(&message, "int32_value", false, options)); ASSERT_TRUE(result.IsInt64()); EXPECT_EQ(result.Int64OrDie(), 1); } -TEST_P(SelectStepTest, SimpleInt64Test) { +TEST_P(SelectStepConformanceTest, SimpleInt64Test) { TestMessage message; message.set_int64_value(1); + RunExpressionOptions options; + options.enable_unknowns = GetParam(); - google::protobuf::Arena arena; - - auto run_status = - RunExpression(&message, "int64_value", false, &arena, GetParam()); - ASSERT_OK(run_status); - - CelValue result = run_status.value(); - + ASSERT_OK_AND_ASSIGN(CelValue result, + RunExpression(&message, "int64_value", false, options)); ASSERT_TRUE(result.IsInt64()); EXPECT_EQ(result.Int64OrDie(), 1); } -TEST_P(SelectStepTest, SimpleUInt32Test) { +TEST_P(SelectStepConformanceTest, SimpleUInt32Test) { TestMessage message; message.set_uint32_value(1); + RunExpressionOptions options; + options.enable_unknowns = GetParam(); - google::protobuf::Arena arena; - - auto run_status = - RunExpression(&message, "uint32_value", false, &arena, GetParam()); - ASSERT_OK(run_status); - - CelValue result = run_status.value(); + ASSERT_OK_AND_ASSIGN(CelValue result, + RunExpression(&message, "uint32_value", false, options)); ASSERT_TRUE(result.IsUint64()); EXPECT_EQ(result.Uint64OrDie(), 1); } -TEST_P(SelectStepTest, SimpleUint64Test) { +TEST_P(SelectStepConformanceTest, SimpleUint64Test) { TestMessage message; message.set_uint64_value(1); + RunExpressionOptions options; + options.enable_unknowns = GetParam(); - google::protobuf::Arena arena; - - auto run_status = - RunExpression(&message, "uint64_value", false, &arena, GetParam()); - ASSERT_OK(run_status); - - CelValue result = run_status.value(); + ASSERT_OK_AND_ASSIGN(CelValue result, + RunExpression(&message, "uint64_value", false, options)); ASSERT_TRUE(result.IsUint64()); EXPECT_EQ(result.Uint64OrDie(), 1); } -TEST_P(SelectStepTest, SimpleStringTest) { +TEST_P(SelectStepConformanceTest, SimpleStringTest) { TestMessage message; std::string value = "test"; message.set_string_value(value); + RunExpressionOptions options; + options.enable_unknowns = GetParam(); - google::protobuf::Arena arena; + ASSERT_OK_AND_ASSIGN(CelValue result, + RunExpression(&message, "string_value", false, options)); - auto run_status = - RunExpression(&message, "string_value", false, &arena, GetParam()); - ASSERT_OK(run_status); + ASSERT_TRUE(result.IsString()); + EXPECT_EQ(result.StringOrDie().value(), "test"); +} - CelValue result = run_status.value(); +TEST_P(SelectStepConformanceTest, WrapperTypeNullUnboxingEnabledTest) { + TestMessage message; + message.mutable_string_wrapper_value()->set_value("test"); + RunExpressionOptions options; + options.enable_unknowns = GetParam(); + options.enable_wrapper_type_null_unboxing = true; + + ASSERT_OK_AND_ASSIGN( + CelValue result, + RunExpression(&message, "string_wrapper_value", false, options)); ASSERT_TRUE(result.IsString()); EXPECT_EQ(result.StringOrDie().value(), "test"); + ASSERT_OK_AND_ASSIGN( + result, RunExpression(&message, "int32_wrapper_value", false, options)); + EXPECT_TRUE(result.IsNull()); } +TEST_P(SelectStepConformanceTest, WrapperTypeNullUnboxingDisabledTest) { + TestMessage message; + message.mutable_string_wrapper_value()->set_value("test"); + RunExpressionOptions options; + options.enable_unknowns = GetParam(); + options.enable_wrapper_type_null_unboxing = false; + + ASSERT_OK_AND_ASSIGN( + CelValue result, + RunExpression(&message, "string_wrapper_value", false, options)); + + ASSERT_TRUE(result.IsString()); + EXPECT_EQ(result.StringOrDie().value(), "test"); + ASSERT_OK_AND_ASSIGN( + result, RunExpression(&message, "int32_wrapper_value", false, options)); + EXPECT_TRUE(result.IsInt64()); +} -TEST_P(SelectStepTest, SimpleBytesTest) { +TEST_P(SelectStepConformanceTest, SimpleBytesTest) { TestMessage message; std::string value = "test"; message.set_bytes_value(value); + RunExpressionOptions options; + options.enable_unknowns = GetParam(); - google::protobuf::Arena arena; - - auto run_status = - RunExpression(&message, "bytes_value", false, &arena, GetParam()); - ASSERT_OK(run_status); - - CelValue result = run_status.value(); + ASSERT_OK_AND_ASSIGN(CelValue result, + RunExpression(&message, "bytes_value", false, options)); ASSERT_TRUE(result.IsBytes()); EXPECT_EQ(result.BytesOrDie().value(), "test"); } -TEST_P(SelectStepTest, SimpleMessageTest) { +TEST_P(SelectStepConformanceTest, SimpleMessageTest) { TestMessage message; - TestMessage* message2 = message.mutable_message_value(); message2->set_int32_value(1); message2->set_string_value("test"); + RunExpressionOptions options; + options.enable_unknowns = GetParam(); + + ASSERT_OK_AND_ASSIGN(CelValue result, RunExpression(&message, "message_value", + false, options)); + + ASSERT_TRUE(result.IsMessage()); + EXPECT_THAT(*message2, EqualsProto(*result.MessageOrDie())); +} + +TEST_P(SelectStepConformanceTest, GlobalExtensionsIntTest) { + TestExtensions exts; + exts.SetExtension(int32_ext, 42); + RunExpressionOptions options; + options.enable_unknowns = GetParam(); + + ASSERT_OK_AND_ASSIGN(CelValue result, + RunExpression(&exts, "google.api.expr.runtime.int32_ext", + false, options)); - google::protobuf::Arena arena; + ASSERT_TRUE(result.IsInt64()); + EXPECT_EQ(result.Int64OrDie(), 42L); +} - auto run_status = - RunExpression(&message, "message_value", false, &arena, GetParam()); - ASSERT_OK(run_status); +TEST_P(SelectStepConformanceTest, GlobalExtensionsMessageTest) { + TestExtensions exts; + TestExtensions* nested = exts.MutableExtension(nested_ext); + nested->set_name("nested"); + RunExpressionOptions options; + options.enable_unknowns = GetParam(); - CelValue result = run_status.value(); + ASSERT_OK_AND_ASSIGN( + CelValue result, + RunExpression(&exts, "google.api.expr.runtime.nested_ext", false, + options)); ASSERT_TRUE(result.IsMessage()); - EXPECT_THAT(*message2, EqualsProto(*result.MessageOrDie())); + EXPECT_THAT(result.MessageOrDie(), Eq(nested)); } -TEST_P(SelectStepTest, SimpleEnumTest) { - TestMessage message; +TEST_P(SelectStepConformanceTest, GlobalExtensionsMessageUnsetTest) { + TestExtensions exts; + RunExpressionOptions options; + options.enable_unknowns = GetParam(); - message.set_enum_value(TestMessage::TEST_ENUM_1); + ASSERT_OK_AND_ASSIGN( + CelValue result, + RunExpression(&exts, "google.api.expr.runtime.nested_ext", false, + options)); - google::protobuf::Arena arena; + ASSERT_TRUE(result.IsMessage()); + EXPECT_THAT(result.MessageOrDie(), Eq(&TestExtensions::default_instance())); +} - auto run_status = - RunExpression(&message, "enum_value", false, &arena, GetParam()); - ASSERT_OK(run_status); +TEST_P(SelectStepConformanceTest, GlobalExtensionsWrapperTest) { + TestExtensions exts; + google::protobuf::Int32Value* wrapper = + exts.MutableExtension(int32_wrapper_ext); + wrapper->set_value(42); + RunExpressionOptions options; + options.enable_unknowns = GetParam(); - CelValue result = run_status.value(); + ASSERT_OK_AND_ASSIGN( + CelValue result, + RunExpression(&exts, "google.api.expr.runtime.int32_wrapper_ext", false, + options)); ASSERT_TRUE(result.IsInt64()); - EXPECT_THAT(result.Int64OrDie(), Eq(TestMessage::TEST_ENUM_1)); + EXPECT_THAT(result.Int64OrDie(), Eq(42L)); } -TEST_P(SelectStepTest, SimpleListTest) { - TestMessage message; +TEST_P(SelectStepConformanceTest, GlobalExtensionsWrapperUnsetTest) { + TestExtensions exts; + RunExpressionOptions options; + options.enable_wrapper_type_null_unboxing = true; + options.enable_unknowns = GetParam(); - message.add_int32_list(1); - message.add_int32_list(2); + ASSERT_OK_AND_ASSIGN( + CelValue result, + RunExpression(&exts, "google.api.expr.runtime.int32_wrapper_ext", false, + options)); - google::protobuf::Arena arena; + ASSERT_TRUE(result.IsNull()); +} - auto run_status = - RunExpression(&message, "int32_list", false, &arena, GetParam()); - ASSERT_OK(run_status); +TEST_P(SelectStepConformanceTest, MessageExtensionsEnumTest) { + TestExtensions exts; + exts.SetExtension(TestMessageExtensions::enum_ext, TestExtEnum::TEST_EXT_1); + RunExpressionOptions options; + options.enable_unknowns = GetParam(); - CelValue result = run_status.value(); + ASSERT_OK_AND_ASSIGN( + CelValue result, + RunExpression(&exts, + "google.api.expr.runtime.TestMessageExtensions.enum_ext", + false, options)); + + ASSERT_TRUE(result.IsInt64()); + EXPECT_THAT(result.Int64OrDie(), Eq(TestExtEnum::TEST_EXT_1)); +} + +TEST_P(SelectStepConformanceTest, MessageExtensionsRepeatedStringTest) { + TestExtensions exts; + exts.AddExtension(TestMessageExtensions::repeated_string_exts, "test1"); + exts.AddExtension(TestMessageExtensions::repeated_string_exts, "test2"); + RunExpressionOptions options; + options.enable_unknowns = GetParam(); + + ASSERT_OK_AND_ASSIGN( + CelValue result, + RunExpression( + &exts, + "google.api.expr.runtime.TestMessageExtensions.repeated_string_exts", + false, options)); ASSERT_TRUE(result.IsList()); + const CelList* cel_list = result.ListOrDie(); + EXPECT_THAT(cel_list->size(), Eq(2)); +} + +TEST_P(SelectStepConformanceTest, MessageExtensionsRepeatedStringUnsetTest) { + TestExtensions exts; + RunExpressionOptions options; + options.enable_unknowns = GetParam(); + ASSERT_OK_AND_ASSIGN( + CelValue result, + RunExpression( + &exts, + "google.api.expr.runtime.TestMessageExtensions.repeated_string_exts", + false, options)); + + ASSERT_TRUE(result.IsList()); const CelList* cel_list = result.ListOrDie(); + EXPECT_THAT(cel_list->size(), Eq(0)); +} + +TEST_P(SelectStepConformanceTest, NullMessageAccessor) { + TestMessage message; + TestMessage* message2 = message.mutable_message_value(); + message2->set_int32_value(1); + message2->set_string_value("test"); + RunExpressionOptions options; + options.enable_unknowns = GetParam(); + CelValue value = CelValue::CreateMessageWrapper( + CelValue::MessageWrapper(&message, TrivialTypeInfo::GetInstance())); + + ASSERT_OK_AND_ASSIGN(CelValue result, + RunExpression(value, "message_value", + /*test=*/false, + /*unknown_path=*/"", options)); + + ASSERT_TRUE(result.IsError()); + EXPECT_THAT(*result.ErrorOrDie(), StatusIs(absl::StatusCode::kNotFound)); + + // same for has + ASSERT_OK_AND_ASSIGN(result, RunExpression(value, "message_value", + /*test=*/true, + /*unknown_path=*/"", options)); + + ASSERT_TRUE(result.IsError()); + EXPECT_THAT(*result.ErrorOrDie(), StatusIs(absl::StatusCode::kNotFound)); +} + +TEST_P(SelectStepConformanceTest, CustomAccessor) { + TestMessage message; + TestMessage* message2 = message.mutable_message_value(); + message2->set_int32_value(1); + message2->set_string_value("test"); + RunExpressionOptions options; + options.enable_unknowns = GetParam(); + testing::NiceMock accessor; + CelValue value = CelValue::CreateMessageWrapper( + CelValue::MessageWrapper(&message, &accessor)); + + ON_CALL(accessor, GetField(_, _, _, _)) + .WillByDefault(Return(CelValue::CreateInt64(2))); + ON_CALL(accessor, HasField(_, _)).WillByDefault(Return(false)); + + ASSERT_OK_AND_ASSIGN(CelValue result, + RunExpression(value, "message_value", + /*test=*/false, + /*unknown_path=*/"", options)); + EXPECT_THAT(result, test::IsCelInt64(2)); + + // testonly select (has) + ASSERT_OK_AND_ASSIGN(result, RunExpression(value, "message_value", + /*test=*/true, + /*unknown_path=*/"", options)); + + EXPECT_THAT(result, test::IsCelBool(false)); +} + +TEST_P(SelectStepConformanceTest, CustomAccessorErrorHandling) { + TestMessage message; + TestMessage* message2 = message.mutable_message_value(); + message2->set_int32_value(1); + message2->set_string_value("test"); + RunExpressionOptions options; + options.enable_unknowns = GetParam(); + testing::NiceMock accessor; + CelValue value = CelValue::CreateMessageWrapper( + CelValue::MessageWrapper(&message, &accessor)); + + ON_CALL(accessor, GetField(_, _, _, _)) + .WillByDefault(Return(absl::InternalError("bad data"))); + ON_CALL(accessor, HasField(_, _)) + .WillByDefault(Return(absl::NotFoundError("not found"))); + + // For get field, implementation may return an error-type cel value or a + // status (e.g. broken assumption using a core type). + ASSERT_OK_AND_ASSIGN(CelValue result, + RunExpression(value, "message_value", + /*test=*/false, + /*unknown_path=*/"", options)); + EXPECT_THAT(result, test::IsCelError(StatusIs(absl::StatusCode::kInternal))); + + // testonly select (has) errors are coerced to CelError. + ASSERT_OK_AND_ASSIGN(result, RunExpression(value, "message_value", + /*test=*/true, + /*unknown_path=*/"", options)); + + EXPECT_THAT(result, test::IsCelError(StatusIs(absl::StatusCode::kNotFound))); +} + +TEST_P(SelectStepConformanceTest, SimpleEnumTest) { + TestMessage message; + message.set_enum_value(TestMessage::TEST_ENUM_1); + RunExpressionOptions options; + options.enable_unknowns = GetParam(); + + ASSERT_OK_AND_ASSIGN(CelValue result, + RunExpression(&message, "enum_value", false, options)); + + ASSERT_TRUE(result.IsInt64()); + EXPECT_THAT(result.Int64OrDie(), Eq(TestMessage::TEST_ENUM_1)); +} + +TEST_P(SelectStepConformanceTest, SimpleListTest) { + TestMessage message; + message.add_int32_list(1); + message.add_int32_list(2); + RunExpressionOptions options; + options.enable_unknowns = GetParam(); + + ASSERT_OK_AND_ASSIGN(CelValue result, + RunExpression(&message, "int32_list", false, options)); + + ASSERT_TRUE(result.IsList()); + const CelList* cel_list = result.ListOrDie(); EXPECT_THAT(cel_list->size(), Eq(2)); } -TEST_P(SelectStepTest, SimpleMapTest) { +TEST_P(SelectStepConformanceTest, SimpleMapTest) { TestMessage message; auto map_field = message.mutable_string_int32_map(); (*map_field)["test0"] = 1; (*map_field)["test1"] = 2; + RunExpressionOptions options; + options.enable_unknowns = GetParam(); - google::protobuf::Arena arena; - - auto run_status = - RunExpression(&message, "string_int32_map", false, &arena, GetParam()); - ASSERT_OK(run_status); - - CelValue result = run_status.value(); - + ASSERT_OK_AND_ASSIGN( + CelValue result, + RunExpression(&message, "string_int32_map", false, options)); ASSERT_TRUE(result.IsMap()); const CelMap* cel_map = result.MapOrDie(); - EXPECT_THAT(cel_map->size(), Eq(2)); } -TEST_P(SelectStepTest, MapSimpleInt32Test) { +TEST_P(SelectStepConformanceTest, MapSimpleInt32Test) { std::string key1 = "key1"; std::string key2 = "key2"; std::vector> key_values{ {CelValue::CreateString(&key1), CelValue::CreateInt64(1)}, {CelValue::CreateString(&key2), CelValue::CreateInt64(2)}}; - auto map_value = CreateContainerBackedMap( - absl::Span>(key_values)); + absl::Span>(key_values)) + .value(); + RunExpressionOptions options; + options.enable_unknowns = GetParam(); - google::protobuf::Arena arena; - - auto run_status = - RunExpression(map_value.get(), "key1", false, &arena, GetParam()); - ASSERT_OK(run_status); - - CelValue result = run_status.value(); + ASSERT_OK_AND_ASSIGN(CelValue result, + RunExpression(map_value.get(), "key1", false, options)); ASSERT_TRUE(result.IsInt64()); EXPECT_EQ(result.Int64OrDie(), 1); } // Test Select behavior, when expression to select from is an Error. -TEST_P(SelectStepTest, CelErrorAsArgument) { +TEST_P(SelectStepConformanceTest, CelErrorAsArgument) { ExecutionPath path; Expr dummy_expr; - auto select = dummy_expr.mutable_select_expr(); - select->set_field("position"); - select->set_test_only(false); - Expr* expr0 = select->mutable_operand(); - - auto ident = expr0->mutable_ident_expr(); - ident->set_name("message"); - auto step0_status = CreateIdentStep(ident, expr0->id()); - auto step1_status = CreateSelectStep(select, dummy_expr.id(), ""); + auto& select = dummy_expr.mutable_select_expr(); + select.set_field("position"); + select.set_test_only(false); + Expr& expr0 = select.mutable_operand(); - ASSERT_TRUE(step0_status.ok()); - ASSERT_TRUE(step1_status.ok()); + auto& ident = expr0.mutable_ident_expr(); + ident.set_name("message"); + ASSERT_OK_AND_ASSIGN(auto step0, CreateIdentStep(ident.name(), expr0.id())); + ASSERT_OK_AND_ASSIGN( + auto step1, + CreateSelectStep(select, dummy_expr.id(), + /*enable_wrapper_type_null_unboxing=*/false)); - path.push_back(std::move(step0_status.value())); - path.push_back(std::move(step1_status.value())); + path.push_back(std::move(step0)); + path.push_back(std::move(step1)); - CelError error; + CelError error = absl::CancelledError(); - google::protobuf::Arena arena; - CelExpressionFlatImpl cel_expr(&dummy_expr, std::move(path), 0, {}, - GetParam()); + cel::RuntimeOptions options; + if (GetParam()) { + options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; + } + CelExpressionFlatImpl cel_expr( + env_, + FlatExpression(std::move(path), /*comprehension_slot_count=*/0, + env_->type_registry.GetComposedTypeProvider(), options)); Activation activation; activation.InsertValue("message", CelValue::CreateError(&error)); - auto status = cel_expr.Evaluate(activation, &arena); - ASSERT_OK(status); - - auto result = status.value(); + ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr.Evaluate(activation, &arena_)); ASSERT_TRUE(result.IsError()); - EXPECT_THAT(result.ErrorOrDie(), Eq(&error)); + EXPECT_THAT(*result.ErrorOrDie(), Eq(error)); } -TEST(SelectStepTest, DisableMissingAttributeOK) { +TEST_F(SelectStepTest, DisableMissingAttributeOK) { TestMessage message; message.set_bool_value(true); - google::protobuf::Arena arena; ExecutionPath path; Expr dummy_expr; - auto select = dummy_expr.mutable_select_expr(); - select->set_field("bool_value"); - select->set_test_only(false); - Expr* expr0 = select->mutable_operand(); - - auto ident = expr0->mutable_ident_expr(); - ident->set_name("message"); - auto step0_status = CreateIdentStep(ident, expr0->id()); - auto step1_status = - CreateSelectStep(select, dummy_expr.id(), "message.bool_value"); - - ASSERT_OK(step0_status); - ASSERT_OK(step1_status); - - path.push_back(std::move(step0_status.value())); - path.push_back(std::move(step1_status.value())); - - CelExpressionFlatImpl cel_expr(&dummy_expr, std::move(path), 0, {}, - /*enable_unknowns=*/false); + auto& select = dummy_expr.mutable_select_expr(); + select.set_field("bool_value"); + select.set_test_only(false); + Expr& expr0 = select.mutable_operand(); + + auto& ident = expr0.mutable_ident_expr(); + ident.set_name("message"); + ASSERT_OK_AND_ASSIGN(auto step0, CreateIdentStep(ident.name(), expr0.id())); + ASSERT_OK_AND_ASSIGN( + auto step1, + CreateSelectStep(select, dummy_expr.id(), + /*enable_wrapper_type_null_unboxing=*/false)); + + path.push_back(std::move(step0)); + path.push_back(std::move(step1)); + + CelExpressionFlatImpl cel_expr( + env_, FlatExpression(std::move(path), /*comprehension_slot_count=*/0, + env_->type_registry.GetComposedTypeProvider(), + cel::RuntimeOptions{})); Activation activation; activation.InsertValue("message", - CelProtoWrapper::CreateMessage(&message, &arena)); - - auto eval_status0 = cel_expr.Evaluate(activation, &arena); - ASSERT_OK(eval_status0); - - CelValue result = eval_status0.value(); + CelProtoWrapper::CreateMessage(&message, &arena_)); + ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr.Evaluate(activation, &arena_)); ASSERT_TRUE(result.IsBool()); - EXPECT_EQ(true, result.BoolOrDie()); + EXPECT_EQ(result.BoolOrDie(), true); CelAttributePattern pattern("message", {}); activation.set_missing_attribute_patterns({pattern}); - auto eval_status1 = cel_expr.Evaluate(activation, &arena); - ASSERT_OK(eval_status1); - EXPECT_EQ(true, eval_status1.value().BoolOrDie()); + ASSERT_OK_AND_ASSIGN(result, cel_expr.Evaluate(activation, &arena_)); + EXPECT_EQ(result.BoolOrDie(), true); } -TEST(SelectStepTest, UnrecoverableUnknownValueProducesError) { +TEST_F(SelectStepTest, UnrecoverableUnknownValueProducesError) { TestMessage message; message.set_bool_value(true); - google::protobuf::Arena arena; ExecutionPath path; Expr dummy_expr; - auto select = dummy_expr.mutable_select_expr(); - select->set_field("bool_value"); - select->set_test_only(false); - Expr* expr0 = select->mutable_operand(); - - auto ident = expr0->mutable_ident_expr(); - ident->set_name("message"); - auto step0_status = CreateIdentStep(ident, expr0->id()); - auto step1_status = - CreateSelectStep(select, dummy_expr.id(), "message.bool_value"); - - ASSERT_OK(step0_status); - ASSERT_OK(step1_status); - - path.push_back(std::move(step0_status.value())); - path.push_back(std::move(step1_status.value())); - - CelExpressionFlatImpl cel_expr(&dummy_expr, std::move(path), 0, {}, false, - false, - /*enable_missing_attribute_errors=*/true); + auto& select = dummy_expr.mutable_select_expr(); + select.set_field("bool_value"); + select.set_test_only(false); + Expr& expr0 = select.mutable_operand(); + + auto& ident = expr0.mutable_ident_expr(); + ident.set_name("message"); + ASSERT_OK_AND_ASSIGN(auto step0, CreateIdentStep(ident.name(), expr0.id())); + ASSERT_OK_AND_ASSIGN( + auto step1, + CreateSelectStep(select, dummy_expr.id(), + /*enable_wrapper_type_null_unboxing=*/false)); + + path.push_back(std::move(step0)); + path.push_back(std::move(step1)); + + cel::RuntimeOptions options; + options.enable_missing_attribute_errors = true; + CelExpressionFlatImpl cel_expr( + env_, + FlatExpression(std::move(path), /*comprehension_slot_count=*/0, + env_->type_registry.GetComposedTypeProvider(), options)); Activation activation; activation.InsertValue("message", - CelProtoWrapper::CreateMessage(&message, &arena)); - - auto eval_status0 = cel_expr.Evaluate(activation, &arena); - ASSERT_OK(eval_status0); - - CelValue result = eval_status0.value(); + CelProtoWrapper::CreateMessage(&message, &arena_)); + ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr.Evaluate(activation, &arena_)); ASSERT_TRUE(result.IsBool()); EXPECT_EQ(result.BoolOrDie(), true); CelAttributePattern pattern("message", - {CelAttributeQualifierPattern::Create( + {CreateCelAttributeQualifierPattern( CelValue::CreateStringView("bool_value"))}); activation.set_missing_attribute_patterns({pattern}); - auto eval_status1 = cel_expr.Evaluate(activation, &arena); - ASSERT_OK(eval_status1); - - EXPECT_EQ(eval_status1.value().ErrorOrDie()->code(), - absl::StatusCode::kInvalidArgument); - EXPECT_EQ(eval_status1.value().ErrorOrDie()->message(), - "MissingAttributeError: message.bool_value"); + ASSERT_OK_AND_ASSIGN(result, cel_expr.Evaluate(activation, &arena_)); + EXPECT_THAT(*result.ErrorOrDie(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("MissingAttributeError: message.bool_value"))); } -TEST_P(SelectStepTest, UnknownValueProducesError) { +TEST_F(SelectStepTest, UnknownPatternResolvesToUnknown) { TestMessage message; message.set_bool_value(true); - google::protobuf::Arena arena; ExecutionPath path; Expr dummy_expr; - auto select = dummy_expr.mutable_select_expr(); - select->set_field("bool_value"); - select->set_test_only(false); - Expr* expr0 = select->mutable_operand(); + auto& select = dummy_expr.mutable_select_expr(); + select.set_field("bool_value"); + select.set_test_only(false); + Expr& expr0 = select.mutable_operand(); - auto ident = expr0->mutable_ident_expr(); - ident->set_name("message"); - auto step0_status = CreateIdentStep(ident, expr0->id()); + auto& ident = expr0.mutable_ident_expr(); + ident.set_name("message"); + auto step0_status = CreateIdentStep(ident.name(), expr0.id()); auto step1_status = - CreateSelectStep(select, dummy_expr.id(), "message.bool_value"); - - ASSERT_OK(step0_status); - ASSERT_OK(step1_status); - - path.push_back(std::move(step0_status.value())); - path.push_back(std::move(step1_status.value())); - - CelExpressionFlatImpl cel_expr(&dummy_expr, std::move(path), 0, {}, - GetParam()); - Activation activation; - activation.InsertValue("message", - CelProtoWrapper::CreateMessage(&message, &arena)); - - auto eval_status0 = cel_expr.Evaluate(activation, &arena); - ASSERT_OK(eval_status0); - - CelValue result = eval_status0.value(); + CreateSelectStep(select, dummy_expr.id(), + /*enable_wrapper_type_null_unboxing=*/false); - ASSERT_TRUE(result.IsBool()); - EXPECT_EQ(result.BoolOrDie(), true); + ASSERT_THAT(step0_status, IsOk()); + ASSERT_THAT(step1_status, IsOk()); - google::protobuf::FieldMask mask; - mask.add_paths("message.bool_value"); + path.push_back(*std::move(step0_status)); + path.push_back(*std::move(step1_status)); - activation.set_unknown_paths(mask); - - auto eval_status1 = cel_expr.Evaluate(activation, &arena); - ASSERT_OK(eval_status1); - - result = eval_status1.value(); - - ASSERT_TRUE(result.IsError()); - ASSERT_TRUE(IsUnknownValueError(result)); - EXPECT_THAT(GetUnknownPathsSetOrDie(result), - Eq(std::set({"message.bool_value"}))); -} - -TEST(SelectStepTest, UnknownPatternResolvesToUnknown) { - TestMessage message; - message.set_bool_value(true); - google::protobuf::Arena arena; - ExecutionPath path; - - Expr dummy_expr; - - auto select = dummy_expr.mutable_select_expr(); - select->set_field("bool_value"); - select->set_test_only(false); - Expr* expr0 = select->mutable_operand(); - - auto ident = expr0->mutable_ident_expr(); - ident->set_name("message"); - auto step0_status = CreateIdentStep(ident, expr0->id()); - auto step1_status = - CreateSelectStep(select, dummy_expr.id(), "message.bool_value"); - - ASSERT_OK(step0_status); - ASSERT_OK(step1_status); - - path.push_back(std::move(step0_status.value())); - path.push_back(std::move(step1_status.value())); - - CelExpressionFlatImpl cel_expr(&dummy_expr, std::move(path), 0, {}, true); + cel::RuntimeOptions options; + options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; + CelExpressionFlatImpl cel_expr( + env_, + FlatExpression(std::move(path), /*comprehension_slot_count=*/0, + env_->type_registry.GetComposedTypeProvider(), options)); { std::vector unknown_patterns; Activation activation; activation.InsertValue("message", - CelProtoWrapper::CreateMessage(&message, &arena)); + CelProtoWrapper::CreateMessage(&message, &arena_)); activation.set_unknown_attribute_patterns(unknown_patterns); - auto eval_status0 = cel_expr.Evaluate(activation, &arena); - ASSERT_OK(eval_status0); - - CelValue result = eval_status0.value(); - + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr.Evaluate(activation, &arena_)); ASSERT_TRUE(result.IsBool()); EXPECT_EQ(result.BoolOrDie(), true); } @@ -686,32 +1003,26 @@ TEST(SelectStepTest, UnknownPatternResolvesToUnknown) { unknown_patterns.push_back(CelAttributePattern("message", {})); Activation activation; activation.InsertValue("message", - CelProtoWrapper::CreateMessage(&message, &arena)); + CelProtoWrapper::CreateMessage(&message, &arena_)); activation.set_unknown_attribute_patterns(unknown_patterns); - auto eval_status0 = cel_expr.Evaluate(activation, &arena); - ASSERT_OK(eval_status0); - - CelValue result = eval_status0.value(); - + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr.Evaluate(activation, &arena_)); ASSERT_TRUE(result.IsUnknownSet()); } { std::vector unknown_patterns; unknown_patterns.push_back(CelAttributePattern( - "message", {CelAttributeQualifierPattern::Create( + "message", {CreateCelAttributeQualifierPattern( CelValue::CreateString(&kSegmentCorrect1))})); Activation activation; activation.InsertValue("message", - CelProtoWrapper::CreateMessage(&message, &arena)); + CelProtoWrapper::CreateMessage(&message, &arena_)); activation.set_unknown_attribute_patterns(unknown_patterns); - auto eval_status0 = cel_expr.Evaluate(activation, &arena); - ASSERT_OK(eval_status0); - - CelValue result = eval_status0.value(); - + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr.Evaluate(activation, &arena_)); ASSERT_TRUE(result.IsUnknownSet()); } @@ -721,40 +1032,552 @@ TEST(SelectStepTest, UnknownPatternResolvesToUnknown) { "message", {CelAttributeQualifierPattern::CreateWildcard()})); Activation activation; activation.InsertValue("message", - CelProtoWrapper::CreateMessage(&message, &arena)); + CelProtoWrapper::CreateMessage(&message, &arena_)); activation.set_unknown_attribute_patterns(unknown_patterns); - auto eval_status0 = cel_expr.Evaluate(activation, &arena); - ASSERT_OK(eval_status0); - - CelValue result = eval_status0.value(); - + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr.Evaluate(activation, &arena_)); ASSERT_TRUE(result.IsUnknownSet()); } { std::vector unknown_patterns; unknown_patterns.push_back(CelAttributePattern( - "message", {CelAttributeQualifierPattern::Create( + "message", {CreateCelAttributeQualifierPattern( CelValue::CreateString(&kSegmentIncorrect))})); Activation activation; activation.InsertValue("message", - CelProtoWrapper::CreateMessage(&message, &arena)); + CelProtoWrapper::CreateMessage(&message, &arena_)); activation.set_unknown_attribute_patterns(unknown_patterns); - auto eval_status0 = cel_expr.Evaluate(activation, &arena); - ASSERT_OK(eval_status0); - - CelValue result = eval_status0.value(); - + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr.Evaluate(activation, &arena_)); ASSERT_TRUE(result.IsBool()); EXPECT_EQ(result.BoolOrDie(), true); } } -INSTANTIATE_TEST_SUITE_P(SelectStepTest, SelectStepTest, testing::Bool()); +INSTANTIATE_TEST_SUITE_P(UnknownsEnabled, SelectStepConformanceTest, + testing::Bool()); + +class DirectSelectStepTest : public testing::Test { + public: + DirectSelectStepTest() + : type_provider_(cel::internal::GetTestingDescriptorPool()) {} + + cel::Value TestWrapMessage(const google::protobuf::Message* message) { + CelValue value = CelProtoWrapper::CreateMessage(message, &arena_); + auto result = cel::interop_internal::FromLegacyValue(&arena_, value); + ABSL_DCHECK_OK(result.status()); + return std::move(result).value(); + } + + std::vector AttributeStrings(const UnknownValue& v) { + std::vector result; + for (const Attribute& attr : v.attribute_set()) { + auto attr_str = attr.AsString(); + ABSL_DCHECK_OK(attr_str.status()); + result.push_back(std::move(attr_str).value()); + } + return result; + } + + protected: + google::protobuf::Arena arena_; + cel::runtime_internal::RuntimeTypeProvider type_provider_; +}; + +TEST_F(DirectSelectStepTest, SelectFromMap) { + cel::Activation activation; + RuntimeOptions options; + + auto step = CreateDirectSelectStep( + CreateDirectIdentStep("map_val", -1), cel::StringValue("one"), + /*test_only=*/false, -1, + /*enable_wrapper_type_null_unboxing=*/true); + + auto map_builder = cel::NewMapValueBuilder(&arena_); + ASSERT_THAT(map_builder->Put(cel::StringValue("one"), IntValue(1)), IsOk()); + ASSERT_THAT(map_builder->Put(cel::StringValue("two"), IntValue(2)), IsOk()); + activation.InsertOrAssignValue("map_val", std::move(*map_builder).Build()); + + ExecutionFrameBase frame(activation, options, type_provider_, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena_); + + Value result; + AttributeTrail attr; + ASSERT_THAT(step->Evaluate(frame, result, attr), IsOk()); + + ASSERT_TRUE(InstanceOf(result)); + + EXPECT_EQ(Cast(result).NativeValue(), 1); +} + +TEST_F(DirectSelectStepTest, HasMap) { + cel::Activation activation; + RuntimeOptions options; + + auto step = CreateDirectSelectStep( + CreateDirectIdentStep("map_val", -1), cel::StringValue("two"), + /*test_only=*/true, -1, + /*enable_wrapper_type_null_unboxing=*/true); + + auto map_builder = cel::NewMapValueBuilder(&arena_); + ASSERT_THAT(map_builder->Put(cel::StringValue("one"), IntValue(1)), IsOk()); + ASSERT_THAT(map_builder->Put(cel::StringValue("two"), IntValue(2)), IsOk()); + activation.InsertOrAssignValue("map_val", std::move(*map_builder).Build()); + + ExecutionFrameBase frame(activation, options, type_provider_, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena_); + + Value result; + AttributeTrail attr; + ASSERT_THAT(step->Evaluate(frame, result, attr), IsOk()); + + ASSERT_TRUE(InstanceOf(result)); + + EXPECT_TRUE(Cast(result).NativeValue()); +} + +TEST_F(DirectSelectStepTest, SelectFromOptionalMap) { + cel::Activation activation; + RuntimeOptions options; + + auto step = CreateDirectSelectStep(CreateDirectIdentStep("map_val", -1), + cel::StringValue("one"), + /*test_only=*/false, -1, + /*enable_wrapper_type_null_unboxing=*/true, + /*enable_optional_types=*/true); + + auto map_builder = cel::NewMapValueBuilder(&arena_); + ASSERT_THAT(map_builder->Put(cel::StringValue("one"), IntValue(1)), IsOk()); + ASSERT_THAT(map_builder->Put(cel::StringValue("two"), IntValue(2)), IsOk()); + activation.InsertOrAssignValue( + "map_val", OptionalValue::Of(std::move(*map_builder).Build(), &arena_)); + + ExecutionFrameBase frame(activation, options, type_provider_, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena_); + + Value result; + AttributeTrail attr; + ASSERT_THAT(step->Evaluate(frame, result, attr), IsOk()); + + ASSERT_TRUE(InstanceOf(result)); + EXPECT_THAT(Cast(static_cast(result)).Value(), + IntValueIs(1)); +} + +TEST_F(DirectSelectStepTest, SelectFromOptionalMapAbsent) { + cel::Activation activation; + RuntimeOptions options; + + auto step = CreateDirectSelectStep(CreateDirectIdentStep("map_val", -1), + cel::StringValue("three"), + /*test_only=*/false, -1, + /*enable_wrapper_type_null_unboxing=*/true, + /*enable_optional_types=*/true); + + auto map_builder = cel::NewMapValueBuilder(&arena_); + ASSERT_THAT(map_builder->Put(cel::StringValue("one"), IntValue(1)), IsOk()); + ASSERT_THAT(map_builder->Put(cel::StringValue("two"), IntValue(2)), IsOk()); + activation.InsertOrAssignValue( + "map_val", OptionalValue::Of(std::move(*map_builder).Build(), &arena_)); + + ExecutionFrameBase frame(activation, options, type_provider_, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena_); + + Value result; + AttributeTrail attr; + ASSERT_THAT(step->Evaluate(frame, result, attr), IsOk()); + + ASSERT_TRUE(InstanceOf(result)); + EXPECT_FALSE( + Cast(static_cast(result)).HasValue()); +} + +TEST_F(DirectSelectStepTest, SelectFromOptionalStruct) { + cel::Activation activation; + RuntimeOptions options; + + auto step = CreateDirectSelectStep(CreateDirectIdentStep("struct_val", -1), + cel::StringValue("single_int64"), + /*test_only=*/false, -1, + /*enable_wrapper_type_null_unboxing=*/true, + /*enable_optional_types=*/true); + + TestAllTypes message; + message.set_single_int64(1); + + ASSERT_OK_AND_ASSIGN( + Value struct_val, + ProtoMessageToValue(std::move(message), + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena_)); + + activation.InsertOrAssignValue("struct_val", + OptionalValue::Of(struct_val, &arena_)); + + ExecutionFrameBase frame(activation, options, type_provider_, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena_); + + Value result; + AttributeTrail attr; + ASSERT_THAT(step->Evaluate(frame, result, attr), IsOk()); + + ASSERT_TRUE(InstanceOf(result)); + EXPECT_THAT(Cast(static_cast(result)).Value(), + IntValueIs(1)); +} + +TEST_F(DirectSelectStepTest, SelectFromOptionalStructFieldNotSet) { + cel::Activation activation; + RuntimeOptions options; + + auto step = CreateDirectSelectStep(CreateDirectIdentStep("struct_val", -1), + cel::StringValue("single_string"), + /*test_only=*/false, -1, + /*enable_wrapper_type_null_unboxing=*/true, + /*enable_optional_types=*/true); + + TestAllTypes message; + message.set_single_int64(1); + + ASSERT_OK_AND_ASSIGN( + Value struct_val, + ProtoMessageToValue(std::move(message), + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena_)); + + activation.InsertOrAssignValue("struct_val", + OptionalValue::Of(struct_val, &arena_)); + + ExecutionFrameBase frame(activation, options, type_provider_, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena_); + + Value result; + AttributeTrail attr; + ASSERT_THAT(step->Evaluate(frame, result, attr), IsOk()); + + ASSERT_TRUE(InstanceOf(result)); + EXPECT_FALSE( + Cast(static_cast(result)).HasValue()); +} + +TEST_F(DirectSelectStepTest, SelectFromEmptyOptional) { + cel::Activation activation; + RuntimeOptions options; + + auto step = CreateDirectSelectStep(CreateDirectIdentStep("map_val", -1), + cel::StringValue("one"), + /*test_only=*/false, -1, + /*enable_wrapper_type_null_unboxing=*/true, + /*enable_optional_types=*/true); + + activation.InsertOrAssignValue("map_val", OptionalValue::None()); + + ExecutionFrameBase frame(activation, options, type_provider_, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena_); + + Value result; + AttributeTrail attr; + ASSERT_THAT(step->Evaluate(frame, result, attr), IsOk()); + + ASSERT_TRUE(InstanceOf(result)); + EXPECT_FALSE( + cel::Cast(static_cast(result)).HasValue()); +} + +TEST_F(DirectSelectStepTest, HasOptional) { + cel::Activation activation; + RuntimeOptions options; + + auto step = CreateDirectSelectStep(CreateDirectIdentStep("map_val", -1), + cel::StringValue("two"), + /*test_only=*/true, -1, + /*enable_wrapper_type_null_unboxing=*/true, + /*enable_optional_types=*/true); + + auto map_builder = cel::NewMapValueBuilder(&arena_); + ASSERT_THAT(map_builder->Put(cel::StringValue("one"), IntValue(1)), IsOk()); + ASSERT_THAT(map_builder->Put(cel::StringValue("two"), IntValue(2)), IsOk()); + activation.InsertOrAssignValue( + "map_val", OptionalValue::Of(std::move(*map_builder).Build(), &arena_)); + + ExecutionFrameBase frame(activation, options, type_provider_, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena_); + + Value result; + AttributeTrail attr; + ASSERT_THAT(step->Evaluate(frame, result, attr), IsOk()); + + ASSERT_TRUE(InstanceOf(result)); + + EXPECT_TRUE(Cast(result).NativeValue()); +} + +TEST_F(DirectSelectStepTest, HasEmptyOptional) { + cel::Activation activation; + RuntimeOptions options; + + auto step = CreateDirectSelectStep(CreateDirectIdentStep("map_val", -1), + cel::StringValue("two"), + /*test_only=*/true, -1, + /*enable_wrapper_type_null_unboxing=*/true, + /*enable_optional_types=*/true); + + activation.InsertOrAssignValue("map_val", OptionalValue::None()); + + ExecutionFrameBase frame(activation, options, type_provider_, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena_); + + Value result; + AttributeTrail attr; + ASSERT_THAT(step->Evaluate(frame, result, attr), IsOk()); + + ASSERT_TRUE(InstanceOf(result)); + + EXPECT_FALSE(Cast(result).NativeValue()); +} + +TEST_F(DirectSelectStepTest, SelectFromStruct) { + cel::Activation activation; + RuntimeOptions options; + + auto step = + CreateDirectSelectStep(CreateDirectIdentStep("test_all_types", -1), + cel::StringValue("single_int64"), + /*test_only=*/false, -1, + /*enable_wrapper_type_null_unboxing=*/true); + + TestAllTypes message; + message.set_single_int64(1); + activation.InsertOrAssignValue("test_all_types", TestWrapMessage(&message)); + + ExecutionFrameBase frame(activation, options, type_provider_, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena_); + + Value result; + AttributeTrail attr; + ASSERT_THAT(step->Evaluate(frame, result, attr), IsOk()); + + ASSERT_TRUE(InstanceOf(result)); + + EXPECT_EQ(Cast(result).NativeValue(), 1); +} + +TEST_F(DirectSelectStepTest, HasStruct) { + cel::Activation activation; + RuntimeOptions options; + + auto step = + CreateDirectSelectStep(CreateDirectIdentStep("test_all_types", -1), + cel::StringValue("single_string"), + /*test_only=*/true, -1, + /*enable_wrapper_type_null_unboxing=*/true); + + TestAllTypes message; + message.set_single_int64(1); + activation.InsertOrAssignValue("test_all_types", TestWrapMessage(&message)); + + ExecutionFrameBase frame(activation, options, type_provider_, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena_); + + Value result; + AttributeTrail attr; + + // has(test_all_types.single_string) + ASSERT_THAT(step->Evaluate(frame, result, attr), IsOk()); + + ASSERT_TRUE(InstanceOf(result)); + EXPECT_FALSE(Cast(result).NativeValue()); +} + +TEST_F(DirectSelectStepTest, SelectFromUnsupportedType) { + cel::Activation activation; + RuntimeOptions options; + + auto step = CreateDirectSelectStep( + CreateDirectIdentStep("bool_val", -1), cel::StringValue("one"), + /*test_only=*/false, -1, + /*enable_wrapper_type_null_unboxing=*/true); + + activation.InsertOrAssignValue("bool_val", BoolValue(false)); + + ExecutionFrameBase frame(activation, options, type_provider_, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena_); + + Value result; + AttributeTrail attr; + ASSERT_THAT(step->Evaluate(frame, result, attr), IsOk()); + + ASSERT_TRUE(InstanceOf(result)); + + EXPECT_THAT(Cast(result).NativeValue(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Applying SELECT to non-message type"))); +} + +TEST_F(DirectSelectStepTest, AttributeUpdatedIfRequested) { + cel::Activation activation; + RuntimeOptions options; + options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; + + auto step = + CreateDirectSelectStep(CreateDirectIdentStep("test_all_types", -1), + cel::StringValue("single_int64"), + /*test_only=*/false, -1, + /*enable_wrapper_type_null_unboxing=*/true); + + TestAllTypes message; + message.set_single_int64(1); + activation.InsertOrAssignValue("test_all_types", TestWrapMessage(&message)); + + ExecutionFrameBase frame(activation, options, type_provider_, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena_); + + Value result; + AttributeTrail attr; + ASSERT_THAT(step->Evaluate(frame, result, attr), IsOk()); + + ASSERT_TRUE(InstanceOf(result)); + EXPECT_EQ(Cast(result).NativeValue(), 1); + + ASSERT_OK_AND_ASSIGN(std::string attr_str, attr.attribute().AsString()); + EXPECT_EQ(attr_str, "test_all_types.single_int64"); +} + +TEST_F(DirectSelectStepTest, MissingAttributesToErrors) { + cel::Activation activation; + RuntimeOptions options; + options.enable_missing_attribute_errors = true; + + auto step = + CreateDirectSelectStep(CreateDirectIdentStep("test_all_types", -1), + cel::StringValue("single_int64"), + /*test_only=*/false, -1, + /*enable_wrapper_type_null_unboxing=*/true); + + TestAllTypes message; + message.set_single_int64(1); + activation.InsertOrAssignValue("test_all_types", TestWrapMessage(&message)); + activation.SetMissingPatterns({cel::AttributePattern( + "test_all_types", + {cel::AttributeQualifierPattern::OfString("single_int64")})}); + + ExecutionFrameBase frame(activation, options, type_provider_, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena_); + + Value result; + AttributeTrail attr; + ASSERT_THAT(step->Evaluate(frame, result, attr), IsOk()); + + ASSERT_TRUE(InstanceOf(result)); + EXPECT_THAT(Cast(result).NativeValue(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("test_all_types.single_int64"))); +} + +TEST_F(DirectSelectStepTest, IdentifiesUnknowns) { + cel::Activation activation; + RuntimeOptions options; + options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; + + auto step = + CreateDirectSelectStep(CreateDirectIdentStep("test_all_types", -1), + cel::StringValue("single_int64"), + /*test_only=*/false, -1, + /*enable_wrapper_type_null_unboxing=*/true); + + TestAllTypes message; + message.set_single_int64(1); + activation.InsertOrAssignValue("test_all_types", TestWrapMessage(&message)); + activation.SetUnknownPatterns({cel::AttributePattern( + "test_all_types", + {cel::AttributeQualifierPattern::OfString("single_int64")})}); + + ExecutionFrameBase frame(activation, options, type_provider_, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena_); + + Value result; + AttributeTrail attr; + ASSERT_THAT(step->Evaluate(frame, result, attr), IsOk()); + + ASSERT_TRUE(InstanceOf(result)); + + EXPECT_THAT(AttributeStrings(Cast(result)), + UnorderedElementsAre("test_all_types.single_int64")); +} + +TEST_F(DirectSelectStepTest, ForwardErrorValue) { + cel::Activation activation; + RuntimeOptions options; + options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; + + auto step = CreateDirectSelectStep( + CreateConstValueDirectStep(cel::ErrorValue(absl::InternalError("test1")), + -1), + cel::StringValue("single_int64"), + /*test_only=*/false, -1, + /*enable_wrapper_type_null_unboxing=*/true); + + ExecutionFrameBase frame(activation, options, type_provider_, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena_); + + Value result; + AttributeTrail attr; + ASSERT_THAT(step->Evaluate(frame, result, attr), IsOk()); + + ASSERT_TRUE(InstanceOf(result)); + EXPECT_THAT(Cast(result).NativeValue(), + StatusIs(absl::StatusCode::kInternal, HasSubstr("test1"))); +} + +TEST_F(DirectSelectStepTest, ForwardUnknownOperand) { + cel::Activation activation; + RuntimeOptions options; + options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; + + AttributeSet attr_set({Attribute("attr", {AttributeQualifier::OfInt(0)})}); + auto step = CreateDirectSelectStep( + CreateConstValueDirectStep( + cel::UnknownValue(cel::Unknown(std::move(attr_set))), -1), + cel::StringValue("single_int64"), + /*test_only=*/false, -1, + /*enable_wrapper_type_null_unboxing=*/true); + + TestAllTypes message; + message.set_single_int64(1); + activation.InsertOrAssignValue("test_all_types", TestWrapMessage(&message)); + + ExecutionFrameBase frame(activation, options, type_provider_, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena_); + + Value result; + AttributeTrail attr; + ASSERT_THAT(step->Evaluate(frame, result, attr), IsOk()); + + ASSERT_TRUE(InstanceOf(result)); + EXPECT_THAT(AttributeStrings(Cast(result)), + UnorderedElementsAre("attr[0]")); +} + } // namespace -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google + +} // namespace google::api::expr::runtime diff --git a/eval/eval/shadowable_value_step.cc b/eval/eval/shadowable_value_step.cc new file mode 100644 index 000000000..1ebab2f1e --- /dev/null +++ b/eval/eval/shadowable_value_step.cc @@ -0,0 +1,98 @@ +#include "eval/eval/shadowable_value_step.h" + +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "common/value.h" +#include "eval/eval/attribute_trail.h" +#include "eval/eval/direct_expression_step.h" +#include "eval/eval/evaluator_core.h" +#include "eval/eval/expression_step_base.h" +#include "internal/status_macros.h" + +namespace google::api::expr::runtime { + +namespace { + +using ::cel::Value; + +class ShadowableValueStep : public ExpressionStepBase { + public: + ShadowableValueStep(std::string identifier, cel::Value value, int64_t expr_id) + : ExpressionStepBase(expr_id), + identifier_(std::move(identifier)), + value_(std::move(value)) {} + + absl::Status Evaluate(ExecutionFrame* frame) const override; + + private: + std::string identifier_; + Value value_; +}; + +absl::Status ShadowableValueStep::Evaluate(ExecutionFrame* frame) const { + cel::Value result; + CEL_ASSIGN_OR_RETURN(auto found, + frame->modern_activation().FindVariable( + identifier_, frame->descriptor_pool(), + frame->message_factory(), frame->arena(), &result)); + if (found) { + frame->value_stack().Push(std::move(result)); + } else { + frame->value_stack().Push(value_); + } + return absl::OkStatus(); +} + +class DirectShadowableValueStep : public DirectExpressionStep { + public: + DirectShadowableValueStep(std::string identifier, cel::Value value, + int64_t expr_id) + : DirectExpressionStep(expr_id), + identifier_(std::move(identifier)), + value_(std::move(value)) {} + + absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, + AttributeTrail& attribute) const override; + + private: + std::string identifier_; + Value value_; +}; + +// TODO(uncreated-issue/67): Attribute tracking is skipped for the shadowed case. May +// cause problems for users with unknown tracking and variables named like +// 'list' etc, but follows the current behavior of the stack machine version. +absl::Status DirectShadowableValueStep::Evaluate( + ExecutionFrameBase& frame, Value& result, AttributeTrail& attribute) const { + CEL_ASSIGN_OR_RETURN(auto found, + frame.activation().FindVariable( + identifier_, frame.descriptor_pool(), + frame.message_factory(), frame.arena(), &result)); + if (!found) { + result = value_; + } + return absl::OkStatus(); +} + +} // namespace + +absl::StatusOr> CreateShadowableValueStep( + absl::string_view name, cel::Value value, int64_t expr_id) { + return absl::make_unique(std::string(name), + std::move(value), expr_id); +} + +std::unique_ptr CreateDirectShadowableValueStep( + absl::string_view name, cel::Value value, int64_t expr_id) { + return std::make_unique(std::string(name), + std::move(value), expr_id); +} + +} // namespace google::api::expr::runtime diff --git a/eval/eval/shadowable_value_step.h b/eval/eval/shadowable_value_step.h new file mode 100644 index 000000000..9c386f02d --- /dev/null +++ b/eval/eval/shadowable_value_step.h @@ -0,0 +1,26 @@ +#ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_SHADOWABLE_VALUE_STEP_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_EVAL_SHADOWABLE_VALUE_STEP_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "common/value.h" +#include "eval/eval/direct_expression_step.h" +#include "eval/eval/evaluator_core.h" + +namespace google::api::expr::runtime { + +// Create an identifier resolution step with a default value that may be +// shadowed by an identifier of the same name within the runtime-provided +// Activation. +absl::StatusOr> CreateShadowableValueStep( + absl::string_view name, cel::Value value, int64_t expr_id); + +std::unique_ptr CreateDirectShadowableValueStep( + absl::string_view name, cel::Value value, int64_t expr_id); + +} // namespace google::api::expr::runtime + +#endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_SHADOWABLE_VALUE_STEP_H_ diff --git a/eval/eval/shadowable_value_step_test.cc b/eval/eval/shadowable_value_step_test.cc new file mode 100644 index 000000000..4a7cabea1 --- /dev/null +++ b/eval/eval/shadowable_value_step_test.cc @@ -0,0 +1,88 @@ +#include "eval/eval/shadowable_value_step.h" + +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/status/statusor.h" +#include "base/type_provider.h" +#include "common/value.h" +#include "eval/eval/cel_expression_flat_impl.h" +#include "eval/eval/evaluator_core.h" +#include "eval/internal/interop.h" +#include "eval/public/activation.h" +#include "eval/public/cel_value.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "runtime/internal/runtime_env.h" +#include "runtime/internal/runtime_env_testing.h" +#include "runtime/runtime_options.h" + +namespace google::api::expr::runtime { + +namespace { + +using ::cel::TypeProvider; +using ::cel::interop_internal::CreateTypeValueFromView; +using ::cel::runtime_internal::NewTestingRuntimeEnv; +using ::cel::runtime_internal::RuntimeEnv; +using ::google::protobuf::Arena; +using ::testing::Eq; + +absl::StatusOr RunShadowableExpression( + const absl_nonnull std::shared_ptr& env, + std::string identifier, cel::Value value, const Activation& activation, + Arena* arena) { + CEL_ASSIGN_OR_RETURN( + auto step, + CreateShadowableValueStep(std::move(identifier), std::move(value), 1)); + ExecutionPath path; + path.push_back(std::move(step)); + + CelExpressionFlatImpl impl( + env, FlatExpression(std::move(path), /*comprehension_slot_count=*/0, + env->type_registry.GetComposedTypeProvider(), + cel::RuntimeOptions{})); + return impl.Evaluate(activation, arena); +} + +TEST(ShadowableValueStepTest, TestEvaluateNoShadowing) { + absl_nonnull std::shared_ptr env = NewTestingRuntimeEnv(); + std::string type_name = "google.api.expr.runtime.TestMessage"; + + Activation activation; + Arena arena; + + auto type_value = CreateTypeValueFromView(&arena, type_name); + auto status = + RunShadowableExpression(env, type_name, type_value, activation, &arena); + ASSERT_OK(status); + + auto value = status.value(); + ASSERT_TRUE(value.IsCelType()); + EXPECT_THAT(value.CelTypeOrDie().value(), Eq(type_name)); +} + +TEST(ShadowableValueStepTest, TestEvaluateShadowedIdentifier) { + absl_nonnull std::shared_ptr env = NewTestingRuntimeEnv(); + std::string type_name = "int"; + auto shadow_value = CelValue::CreateInt64(1024L); + + Activation activation; + activation.InsertValue(type_name, shadow_value); + Arena arena; + + auto type_value = CreateTypeValueFromView(&arena, type_name); + auto status = + RunShadowableExpression(env, type_name, type_value, activation, &arena); + ASSERT_OK(status); + + auto value = status.value(); + ASSERT_TRUE(value.IsInt64()); + EXPECT_THAT(value.Int64OrDie(), Eq(1024L)); +} + +} // namespace + +} // namespace google::api::expr::runtime diff --git a/eval/eval/ternary_step.cc b/eval/eval/ternary_step.cc index a52430ad3..a12d6863e 100644 --- a/eval/eval/ternary_step.cc +++ b/eval/eval/ternary_step.cc @@ -1,23 +1,130 @@ #include "eval/eval/ternary_step.h" +#include +#include +#include +#include + +#include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" +#include "base/builtins.h" +#include "common/value.h" +#include "eval/eval/attribute_trail.h" +#include "eval/eval/direct_expression_step.h" +#include "eval/eval/evaluator_core.h" #include "eval/eval/expression_step_base.h" -#include "eval/public/cel_builtins.h" -#include "eval/public/cel_value.h" -#include "eval/public/unknown_attribute_set.h" +#include "eval/internal/errors.h" +#include "internal/status_macros.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace google::api::expr::runtime { namespace { +using ::cel::builtin::kTernary; +using ::cel::runtime_internal::CreateNoMatchingOverloadError; + +inline constexpr size_t kTernaryStepCondition = 0; +inline constexpr size_t kTernaryStepTrue = 1; +inline constexpr size_t kTernaryStepFalse = 2; + +class ExhaustiveDirectTernaryStep : public DirectExpressionStep { + public: + ExhaustiveDirectTernaryStep(std::unique_ptr condition, + std::unique_ptr left, + std::unique_ptr right, + int64_t expr_id) + : DirectExpressionStep(expr_id), + condition_(std::move(condition)), + left_(std::move(left)), + right_(std::move(right)) {} + + absl::Status Evaluate(ExecutionFrameBase& frame, cel::Value& result, + AttributeTrail& attribute) const override { + cel::Value condition; + cel::Value lhs; + cel::Value rhs; + + AttributeTrail condition_attr; + AttributeTrail lhs_attr; + AttributeTrail rhs_attr; + + CEL_RETURN_IF_ERROR(condition_->Evaluate(frame, condition, condition_attr)); + CEL_RETURN_IF_ERROR(left_->Evaluate(frame, lhs, lhs_attr)); + CEL_RETURN_IF_ERROR(right_->Evaluate(frame, rhs, rhs_attr)); + + if (condition.IsError() || condition.IsUnknown()) { + result = std::move(condition); + attribute = std::move(condition_attr); + return absl::OkStatus(); + } + + if (!condition.IsBool()) { + result = cel::ErrorValue(CreateNoMatchingOverloadError(kTernary)); + return absl::OkStatus(); + } + + if (condition.GetBool().NativeValue()) { + result = std::move(lhs); + attribute = std::move(lhs_attr); + } else { + result = std::move(rhs); + attribute = std::move(rhs_attr); + } + return absl::OkStatus(); + } + + private: + std::unique_ptr condition_; + std::unique_ptr left_; + std::unique_ptr right_; +}; + +class ShortcircuitingDirectTernaryStep : public DirectExpressionStep { + public: + ShortcircuitingDirectTernaryStep( + std::unique_ptr condition, + std::unique_ptr left, + std::unique_ptr right, int64_t expr_id) + : DirectExpressionStep(expr_id), + condition_(std::move(condition)), + left_(std::move(left)), + right_(std::move(right)) {} + + absl::Status Evaluate(ExecutionFrameBase& frame, cel::Value& result, + AttributeTrail& attribute) const override { + cel::Value condition; + + AttributeTrail condition_attr; + + CEL_RETURN_IF_ERROR(condition_->Evaluate(frame, condition, condition_attr)); + + if (condition.IsError() || condition.IsUnknown()) { + result = std::move(condition); + attribute = std::move(condition_attr); + return absl::OkStatus(); + } + + if (!condition.IsBool()) { + result = cel::ErrorValue(CreateNoMatchingOverloadError(kTernary)); + return absl::OkStatus(); + } + + if (condition.GetBool().NativeValue()) { + return left_->Evaluate(frame, result, attribute); + } + return right_->Evaluate(frame, result, attribute); + } + + private: + std::unique_ptr condition_; + std::unique_ptr left_; + std::unique_ptr right_; +}; + class TernaryStep : public ExpressionStepBase { public: // Constructs FunctionStep that uses overloads specified. - TernaryStep(int64_t expr_id) : ExpressionStepBase(expr_id) {} + explicit TernaryStep(int64_t expr_id) : ExpressionStepBase(expr_id) {} absl::Status Evaluate(ExecutionFrame* frame) const override; }; @@ -31,15 +138,13 @@ absl::Status TernaryStep::Evaluate(ExecutionFrame* frame) const { // Create Span object that contains input arguments to the function. auto args = frame->value_stack().GetSpan(3); - CelValue value; - - const CelValue& condition = args.at(0); + const auto& condition = args[kTernaryStepCondition]; // As opposed to regular functions, ternary treats unknowns or errors on the // condition (arg0) as blocking. If we get an error or unknown then we // ignore the other arguments and forward the condition as the result. if (frame->enable_unknowns()) { // Check if unknown? - if (condition.IsUnknownSet()) { + if (condition.IsUnknown()) { frame->value_stack().Pop(2); return absl::OkStatus(); } @@ -50,32 +155,40 @@ absl::Status TernaryStep::Evaluate(ExecutionFrame* frame) const { return absl::OkStatus(); } - CelValue result; + cel::Value result; if (!condition.IsBool()) { - result = CreateNoMatchingOverloadError(frame->arena(), builtin::kTernary); - } else if (condition.BoolOrDie()) { - result = args.at(1); + result = cel::ErrorValue(CreateNoMatchingOverloadError(kTernary)); + } else if (condition.GetBool().NativeValue()) { + result = args[kTernaryStepTrue]; } else { - result = args.at(2); + result = args[kTernaryStepFalse]; } - frame->value_stack().Pop(args.size()); - frame->value_stack().Push(result); + frame->value_stack().PopAndPush(args.size(), std::move(result)); return absl::OkStatus(); } } // namespace +// Factory method for ternary (_?_:_) recursive execution step +std::unique_ptr CreateDirectTernaryStep( + std::unique_ptr condition, + std::unique_ptr left, + std::unique_ptr right, int64_t expr_id, + bool shortcircuiting) { + if (shortcircuiting) { + return std::make_unique( + std::move(condition), std::move(left), std::move(right), expr_id); + } + + return std::make_unique( + std::move(condition), std::move(left), std::move(right), expr_id); +} + absl::StatusOr> CreateTernaryStep( int64_t expr_id) { - std::unique_ptr step = - absl::make_unique(expr_id); - - return step; + return std::make_unique(expr_id); } -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime diff --git a/eval/eval/ternary_step.h b/eval/eval/ternary_step.h index f86747644..2b51e95ea 100644 --- a/eval/eval/ternary_step.h +++ b/eval/eval/ternary_step.h @@ -1,24 +1,26 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_TERNARY_STEP_H_ #define THIRD_PARTY_CEL_CPP_EVAL_EVAL_TERNARY_STEP_H_ +#include +#include + +#include "absl/status/statusor.h" +#include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" -#include "eval/public/activation.h" -#include "eval/public/cel_function.h" -#include "eval/public/cel_value.h" -#include "google/api/expr/v1alpha1/syntax.pb.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace google::api::expr::runtime { + +// Factory method for ternary (_?_:_) recursive execution step +std::unique_ptr CreateDirectTernaryStep( + std::unique_ptr condition, + std::unique_ptr left, + std::unique_ptr right, int64_t expr_id, + bool shortcircuiting = true); // Factory method for ternary (_?_:_) execution step absl::StatusOr> CreateTernaryStep( int64_t expr_id); -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_TERNARY_STEP_H_ diff --git a/eval/eval/ternary_step_test.cc b/eval/eval/ternary_step_test.cc index 5372e992a..ff66c0308 100644 --- a/eval/eval/ternary_step_test.cc +++ b/eval/eval/ternary_step_test.cc @@ -1,76 +1,90 @@ #include "eval/eval/ternary_step.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "base/attribute.h" +#include "base/attribute_set.h" +#include "base/type_provider.h" +#include "common/casting.h" +#include "common/expr.h" +#include "common/value.h" +#include "eval/eval/attribute_trail.h" +#include "eval/eval/cel_expression_flat_impl.h" +#include "eval/eval/const_value_step.h" +#include "eval/eval/direct_expression_step.h" +#include "eval/eval/evaluator_core.h" #include "eval/eval/ident_step.h" +#include "eval/public/activation.h" +#include "eval/public/cel_value.h" #include "eval/public/unknown_attribute_set.h" #include "eval/public/unknown_set.h" -#include "base/status_macros.h" - -namespace google { -namespace api { -namespace expr { -namespace runtime { +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "internal/testing_message_factory.h" +#include "runtime/activation.h" +#include "runtime/internal/runtime_env.h" +#include "runtime/internal/runtime_env_testing.h" +#include "runtime/internal/runtime_type_provider.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" + +namespace google::api::expr::runtime { namespace { -using google::api::expr::v1alpha1::Expr; +using ::absl_testing::StatusIs; +using ::cel::BoolValue; +using ::cel::Cast; +using ::cel::ErrorValue; +using ::cel::Expr; +using ::cel::InstanceOf; +using ::cel::IntValue; +using ::cel::RuntimeOptions; +using ::cel::TypeProvider; +using ::cel::UnknownValue; +using ::cel::runtime_internal::NewTestingRuntimeEnv; +using ::cel::runtime_internal::RuntimeEnv; +using ::google::protobuf::Arena; +using ::testing::ElementsAre; +using ::testing::Eq; +using ::testing::HasSubstr; +using ::testing::Truly; -using google::protobuf::Arena; -using testing::Eq; class LogicStepTest : public testing::TestWithParam { public: + LogicStepTest() : env_(NewTestingRuntimeEnv()) {} + absl::Status EvaluateLogic(CelValue arg0, CelValue arg1, CelValue arg2, CelValue* result, bool enable_unknown) { - Expr expr0; - expr0.set_id(1); - auto ident_expr0 = expr0.mutable_ident_expr(); - ident_expr0->set_name("name0"); - - Expr expr1; - expr1.set_id(2); - auto ident_expr1 = expr1.mutable_ident_expr(); - ident_expr1->set_name("name1"); - - Expr expr2; - expr2.set_id(3); - auto ident_expr2 = expr2.mutable_ident_expr(); - ident_expr2->set_name("name2"); - ExecutionPath path; - auto step_status = CreateIdentStep(ident_expr0, expr0.id()); - if (!step_status.ok()) { - return step_status.status(); - } - - path.push_back(std::move(step_status).value()); - - step_status = CreateIdentStep(ident_expr1, expr1.id()); - if (!step_status.ok()) { - return step_status.status(); - } + CEL_ASSIGN_OR_RETURN(auto step, CreateIdentStep("name0", /*expr_id=*/-1)); + path.push_back(std::move(step)); - path.push_back(std::move(step_status).value()); + CEL_ASSIGN_OR_RETURN(step, CreateIdentStep("name1", /*expr_id=*/-1)); + path.push_back(std::move(step)); - step_status = CreateIdentStep(ident_expr2, expr2.id()); - if (!step_status.ok()) { - return step_status.status(); - } + CEL_ASSIGN_OR_RETURN(step, CreateIdentStep("name2", /*expr_id=*/-1)); + path.push_back(std::move(step)); - path.push_back(std::move(step_status).value()); + CEL_ASSIGN_OR_RETURN(step, CreateTernaryStep(4)); + path.push_back(std::move(step)); - step_status = CreateTernaryStep(4); - if (!step_status.ok()) { - return step_status.status(); + cel::RuntimeOptions options; + if (enable_unknown) { + options.unknown_processing = + cel::UnknownProcessingOptions::kAttributeOnly; } - - path.push_back(std::move(step_status).value()); - - auto dummy_expr = absl::make_unique(); - - CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), 0, {}, - enable_unknown); + CelExpressionFlatImpl impl( + env_, + FlatExpression(std::move(path), /*comprehension_slot_count=*/0, + env_->type_registry.GetComposedTypeProvider(), options)); Activation activation; std::string value("test"); @@ -86,6 +100,7 @@ class LogicStepTest : public testing::TestWithParam { } private: + absl_nonnull std::shared_ptr env_; Arena arena_; }; @@ -108,22 +123,18 @@ TEST_P(LogicStepTest, TestBoolCond) { TEST_P(LogicStepTest, TestErrorHandling) { CelValue result; - CelError error; + CelError error = absl::CancelledError(); CelValue error_value = CelValue::CreateError(&error); - absl::Status status = - EvaluateLogic(error_value, CelValue::CreateBool(true), - CelValue::CreateBool(false), &result, GetParam()); - ASSERT_OK(status); + ASSERT_OK(EvaluateLogic(error_value, CelValue::CreateBool(true), + CelValue::CreateBool(false), &result, GetParam())); ASSERT_TRUE(result.IsError()); - status = EvaluateLogic(CelValue::CreateBool(true), error_value, - CelValue::CreateBool(false), &result, GetParam()); - ASSERT_OK(status); + ASSERT_OK(EvaluateLogic(CelValue::CreateBool(true), error_value, + CelValue::CreateBool(false), &result, GetParam())); ASSERT_TRUE(result.IsError()); - status = EvaluateLogic(CelValue::CreateBool(false), error_value, - CelValue::CreateBool(false), &result, GetParam()); - ASSERT_OK(status); + ASSERT_OK(EvaluateLogic(CelValue::CreateBool(false), error_value, + CelValue::CreateBool(false), &result, GetParam())); ASSERT_TRUE(result.IsBool()); ASSERT_FALSE(result.BoolOrDie()); } @@ -131,68 +142,234 @@ TEST_P(LogicStepTest, TestErrorHandling) { TEST_F(LogicStepTest, TestUnknownHandling) { CelValue result; UnknownSet unknown_set; - CelError cel_error; + CelError cel_error = absl::CancelledError(); CelValue unknown_value = CelValue::CreateUnknownSet(&unknown_set); CelValue error_value = CelValue::CreateError(&cel_error); - absl::Status status = - EvaluateLogic(unknown_value, CelValue::CreateBool(true), - CelValue::CreateBool(false), &result, true); - ASSERT_OK(status); + ASSERT_OK(EvaluateLogic(unknown_value, CelValue::CreateBool(true), + CelValue::CreateBool(false), &result, true)); ASSERT_TRUE(result.IsUnknownSet()); - status = EvaluateLogic(CelValue::CreateBool(true), unknown_value, - CelValue::CreateBool(false), &result, true); - ASSERT_OK(status); + ASSERT_OK(EvaluateLogic(CelValue::CreateBool(true), unknown_value, + CelValue::CreateBool(false), &result, true)); ASSERT_TRUE(result.IsUnknownSet()); - status = EvaluateLogic(CelValue::CreateBool(false), unknown_value, - CelValue::CreateBool(false), &result, true); - ASSERT_OK(status); + ASSERT_OK(EvaluateLogic(CelValue::CreateBool(false), unknown_value, + CelValue::CreateBool(false), &result, true)); ASSERT_TRUE(result.IsBool()); ASSERT_FALSE(result.BoolOrDie()); - status = EvaluateLogic(error_value, unknown_value, - CelValue::CreateBool(false), &result, true); - ASSERT_OK(status); + ASSERT_OK(EvaluateLogic(error_value, unknown_value, + CelValue::CreateBool(false), &result, true)); ASSERT_TRUE(result.IsError()); - status = EvaluateLogic(unknown_value, error_value, - CelValue::CreateBool(false), &result, true); - ASSERT_OK(status); + ASSERT_OK(EvaluateLogic(unknown_value, error_value, + CelValue::CreateBool(false), &result, true)); ASSERT_TRUE(result.IsUnknownSet()); Expr expr0; - auto ident_expr0 = expr0.mutable_ident_expr(); - ident_expr0->set_name("name0"); + auto& ident_expr0 = expr0.mutable_ident_expr(); + ident_expr0.set_name("name0"); Expr expr1; - auto ident_expr1 = expr1.mutable_ident_expr(); - ident_expr1->set_name("name1"); + auto& ident_expr1 = expr1.mutable_ident_expr(); + ident_expr1.set_name("name1"); - CelAttribute attr0(expr0, {}), attr1(expr1, {}); - UnknownAttributeSet unknown_attr_set0({&attr0}); - UnknownAttributeSet unknown_attr_set1({&attr1}); + CelAttribute attr0(expr0.ident_expr().name(), {}), + attr1(expr1.ident_expr().name(), {}); + UnknownAttributeSet unknown_attr_set0({attr0}); + UnknownAttributeSet unknown_attr_set1({attr1}); UnknownSet unknown_set0(unknown_attr_set0); UnknownSet unknown_set1(unknown_attr_set1); - EXPECT_THAT(unknown_attr_set0.attributes().size(), Eq(1)); - EXPECT_THAT(unknown_attr_set1.attributes().size(), Eq(1)); + EXPECT_THAT(unknown_attr_set0.size(), Eq(1)); + EXPECT_THAT(unknown_attr_set1.size(), Eq(1)); - status = EvaluateLogic(CelValue::CreateUnknownSet(&unknown_set0), - CelValue::CreateUnknownSet(&unknown_set1), - CelValue::CreateBool(false), &result, true); - ASSERT_OK(status); + ASSERT_OK(EvaluateLogic(CelValue::CreateUnknownSet(&unknown_set0), + CelValue::CreateUnknownSet(&unknown_set1), + CelValue::CreateBool(false), &result, true)); ASSERT_TRUE(result.IsUnknownSet()); - const auto& attrs = - result.UnknownSetOrDie()->unknown_attributes().attributes(); + const auto& attrs = result.UnknownSetOrDie()->unknown_attributes(); ASSERT_THAT(attrs, testing::SizeIs(1)); - EXPECT_THAT(attrs[0]->variable().ident_expr().name(), Eq("name0")); + EXPECT_THAT(attrs.begin()->variable_name(), Eq("name0")); } INSTANTIATE_TEST_SUITE_P(LogicStepTest, LogicStepTest, testing::Bool()); + +class TernaryStepDirectTest : public testing::TestWithParam { + public: + TernaryStepDirectTest() + : type_provider_(cel::internal::GetTestingDescriptorPool()) {} + + bool Shortcircuiting() { return GetParam(); } + + protected: + Arena arena_; + cel::runtime_internal::RuntimeTypeProvider type_provider_; +}; + +TEST_P(TernaryStepDirectTest, ReturnLhs) { + cel::Activation activation; + RuntimeOptions opts; + ExecutionFrameBase frame(activation, opts, type_provider_, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena_); + + std::unique_ptr step = CreateDirectTernaryStep( + CreateConstValueDirectStep(BoolValue(true), -1), + CreateConstValueDirectStep(IntValue(1), -1), + CreateConstValueDirectStep(IntValue(2), -1), -1, Shortcircuiting()); + + cel::Value result; + AttributeTrail attr_unused; + + ASSERT_OK(step->Evaluate(frame, result, attr_unused)); + + ASSERT_TRUE(InstanceOf(result)); + EXPECT_EQ(Cast(result).NativeValue(), 1); +} + +TEST_P(TernaryStepDirectTest, ReturnRhs) { + cel::Activation activation; + RuntimeOptions opts; + ExecutionFrameBase frame(activation, opts, type_provider_, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena_); + + std::unique_ptr step = CreateDirectTernaryStep( + CreateConstValueDirectStep(BoolValue(false), -1), + CreateConstValueDirectStep(IntValue(1), -1), + CreateConstValueDirectStep(IntValue(2), -1), -1, Shortcircuiting()); + + cel::Value result; + AttributeTrail attr_unused; + + ASSERT_OK(step->Evaluate(frame, result, attr_unused)); + + ASSERT_TRUE(InstanceOf(result)); + EXPECT_EQ(Cast(result).NativeValue(), 2); +} + +TEST_P(TernaryStepDirectTest, ForwardError) { + cel::Activation activation; + RuntimeOptions opts; + ExecutionFrameBase frame(activation, opts, type_provider_, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena_); + + cel::Value error_value = cel::ErrorValue(absl::InternalError("test error")); + + std::unique_ptr step = CreateDirectTernaryStep( + CreateConstValueDirectStep(error_value, -1), + CreateConstValueDirectStep(IntValue(1), -1), + CreateConstValueDirectStep(IntValue(2), -1), -1, Shortcircuiting()); + + cel::Value result; + AttributeTrail attr_unused; + + ASSERT_OK(step->Evaluate(frame, result, attr_unused)); + + ASSERT_TRUE(InstanceOf(result)); + EXPECT_THAT(Cast(result).NativeValue(), + StatusIs(absl::StatusCode::kInternal, "test error")); +} + +TEST_P(TernaryStepDirectTest, ForwardUnknown) { + cel::Activation activation; + RuntimeOptions opts; + opts.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; + ExecutionFrameBase frame(activation, opts, type_provider_, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena_); + + std::vector attrs{{cel::Attribute("var")}}; + + cel::UnknownValue unknown_value = + cel::UnknownValue(cel::Unknown(cel::AttributeSet(attrs))); + + std::unique_ptr step = CreateDirectTernaryStep( + CreateConstValueDirectStep(unknown_value, -1), + CreateConstValueDirectStep(IntValue(2), -1), + CreateConstValueDirectStep(IntValue(3), -1), -1, Shortcircuiting()); + + cel::Value result; + AttributeTrail attr_unused; + + ASSERT_OK(step->Evaluate(frame, result, attr_unused)); + ASSERT_TRUE(InstanceOf(result)); + EXPECT_THAT(Cast(result).NativeValue().unknown_attributes(), + ElementsAre(Truly([](const cel::Attribute& attr) { + return attr.variable_name() == "var"; + }))); +} + +TEST_P(TernaryStepDirectTest, UnexpectedCondtionKind) { + cel::Activation activation; + RuntimeOptions opts; + ExecutionFrameBase frame(activation, opts, type_provider_, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena_); + + std::unique_ptr step = CreateDirectTernaryStep( + CreateConstValueDirectStep(IntValue(-1), -1), + CreateConstValueDirectStep(IntValue(1), -1), + CreateConstValueDirectStep(IntValue(2), -1), -1, Shortcircuiting()); + + cel::Value result; + AttributeTrail attr_unused; + + ASSERT_OK(step->Evaluate(frame, result, attr_unused)); + + ASSERT_TRUE(InstanceOf(result)); + EXPECT_THAT(Cast(result).NativeValue(), + StatusIs(absl::StatusCode::kUnknown, + HasSubstr("No matching overloads found"))); +} + +TEST_P(TernaryStepDirectTest, Shortcircuiting) { + class RecordCallStep : public DirectExpressionStep { + public: + explicit RecordCallStep(bool& was_called) + : DirectExpressionStep(-1), was_called_(&was_called) {} + absl::Status Evaluate(ExecutionFrameBase& frame, cel::Value& result, + AttributeTrail& trail) const override { + *was_called_ = true; + result = IntValue(1); + return absl::OkStatus(); + } + + private: + bool* absl_nonnull was_called_; + }; + + bool lhs_was_called = false; + bool rhs_was_called = false; + + cel::Activation activation; + RuntimeOptions opts; + ExecutionFrameBase frame(activation, opts, type_provider_, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena_); + + std::unique_ptr step = CreateDirectTernaryStep( + CreateConstValueDirectStep(BoolValue(false), -1), + std::make_unique(lhs_was_called), + std::make_unique(rhs_was_called), -1, Shortcircuiting()); + + cel::Value result; + AttributeTrail attr_unused; + + ASSERT_OK(step->Evaluate(frame, result, attr_unused)); + + ASSERT_TRUE(InstanceOf(result)); + EXPECT_THAT(Cast(result).NativeValue(), Eq(1)); + bool expect_eager_eval = !Shortcircuiting(); + EXPECT_EQ(lhs_was_called, expect_eager_eval); + EXPECT_TRUE(rhs_was_called); +} + +INSTANTIATE_TEST_SUITE_P(TernaryStepDirectTest, TernaryStepDirectTest, + testing::Bool()); + } // namespace -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime diff --git a/eval/eval/trace_step.h b/eval/eval/trace_step.h new file mode 100644 index 000000000..cf4240248 --- /dev/null +++ b/eval/eval/trace_step.h @@ -0,0 +1,73 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_TRACE_STEP_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_EVAL_TRACE_STEP_H_ + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/types/optional.h" +#include "common/native_type.h" +#include "common/value.h" +#include "eval/eval/attribute_trail.h" +#include "eval/eval/direct_expression_step.h" +#include "eval/eval/evaluator_core.h" +#include "internal/status_macros.h" +namespace google::api::expr::runtime { + +// A decorator that implements tracing for recursively evaluated CEL +// expressions. +// +// Allows inspection for extensions to extract the wrapped expression. +class TraceStep : public DirectExpressionStep { + public: + explicit TraceStep(std::unique_ptr expression) + : DirectExpressionStep(-1), expression_(std::move(expression)) {} + + absl::Status Evaluate(ExecutionFrameBase& frame, cel::Value& result, + AttributeTrail& trail) const override { + CEL_RETURN_IF_ERROR(expression_->Evaluate(frame, result, trail)); + if (!frame.callback()) { + return absl::OkStatus(); + } + return frame.callback()(expression_->expr_id(), result, + frame.descriptor_pool(), frame.message_factory(), + frame.arena()); + } + + cel::NativeTypeId GetNativeTypeId() const override { + return cel::NativeTypeId::For(); + } + + absl::optional> GetDependencies() + const override { + return {{expression_.get()}}; + } + + absl::optional>> + ExtractDependencies() override { + std::vector> dependencies; + dependencies.push_back(std::move(expression_)); + return dependencies; + }; + + private: + std::unique_ptr expression_; +}; + +} // namespace google::api::expr::runtime + +#endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_TRACE_STEP_H_ diff --git a/eval/internal/BUILD b/eval/internal/BUILD new file mode 100644 index 000000000..d6f31493e --- /dev/null +++ b/eval/internal/BUILD @@ -0,0 +1,104 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) + +cc_library( + name = "interop", + hdrs = ["interop.h"], + deps = ["//common:legacy_value"], +) + +cc_library( + name = "cel_value_equal", + srcs = ["cel_value_equal.cc"], + hdrs = ["cel_value_equal.h"], + deps = [ + "//common:kind", + "//eval/public:cel_number", + "//eval/public:cel_value", + "//eval/public:message_wrapper", + "//eval/public/structs:legacy_type_adapter", + "//eval/public/structs:legacy_type_info_apis", + "//internal:number", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:optional", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "cel_value_equal_test", + srcs = ["cel_value_equal_test.cc"], + deps = [ + ":cel_value_equal", + "//eval/public:cel_value", + "//eval/public:message_wrapper", + "//eval/public/containers:container_backed_list_impl", + "//eval/public/containers:container_backed_map_impl", + "//eval/public/structs:cel_proto_wrapper", + "//eval/public/structs:trivial_legacy_type_info", + "//eval/testutil:test_message_cc_proto", + "//internal:testing", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + "@com_google_absl//absl/types:variant", + "@com_google_googleapis//google/rpc/context:attribute_context_cc_proto", + "@com_google_protobuf//:any_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "errors", + srcs = ["errors.cc"], + hdrs = ["errors.h"], + deps = [ + "//runtime/internal:errors", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "adapter_activation_impl", + srcs = ["adapter_activation_impl.cc"], + hdrs = ["adapter_activation_impl.h"], + deps = [ + ":interop", + "//base:attributes", + "//common:value", + "//eval/public:base_activation", + "//eval/public:cel_value", + "//internal:status_macros", + "//runtime:activation_interface", + "//runtime:function_overload_reference", + "//runtime/internal:activation_attribute_matcher_access", + "//runtime/internal:attribute_matcher", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + "@com_google_protobuf//:protobuf", + ], +) diff --git a/eval/internal/adapter_activation_impl.cc b/eval/internal/adapter_activation_impl.cc new file mode 100644 index 000000000..c88fe8145 --- /dev/null +++ b/eval/internal/adapter_activation_impl.cc @@ -0,0 +1,87 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/internal/adapter_activation_impl.h" + +#include + +#include "absl/base/nullability.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "common/value.h" +#include "eval/internal/interop.h" +#include "eval/public/cel_value.h" +#include "internal/status_macros.h" +#include "runtime/function_overload_reference.h" +#include "runtime/internal/activation_attribute_matcher_access.h" +#include "runtime/internal/attribute_matcher.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel::interop_internal { + +using ::google::api::expr::runtime::CelFunction; + +absl::StatusOr AdapterActivationImpl::FindVariable( + absl::string_view name, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + // This implementation should only be used during interop, when we can + // always assume the memory manager is backed by a protobuf arena. + + absl::optional legacy_value = + legacy_activation_.FindValue(name, arena); + if (!legacy_value.has_value()) { + return false; + } + CEL_RETURN_IF_ERROR(ModernValue(arena, *legacy_value, *result)); + return true; +} + +std::vector +AdapterActivationImpl::FindFunctionOverloads(absl::string_view name) const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + std::vector legacy_candidates = + legacy_activation_.FindFunctionOverloads(name); + std::vector result; + result.reserve(legacy_candidates.size()); + for (const auto* candidate : legacy_candidates) { + if (candidate == nullptr) { + continue; + } + result.push_back({candidate->descriptor(), *candidate}); + } + return result; +} + +absl::Span AdapterActivationImpl::GetUnknownAttributes() + const { + return legacy_activation_.unknown_attribute_patterns(); +} + +absl::Span AdapterActivationImpl::GetMissingAttributes() + const { + return legacy_activation_.missing_attribute_patterns(); +} + +const runtime_internal::AttributeMatcher* absl_nullable +AdapterActivationImpl::GetAttributeMatcher() const { + return runtime_internal::ActivationAttributeMatcherAccess:: + GetAttributeMatcher(legacy_activation_); +} + +} // namespace cel::interop_internal diff --git a/eval/internal/adapter_activation_impl.h b/eval/internal/adapter_activation_impl.h new file mode 100644 index 000000000..ebf3156aa --- /dev/null +++ b/eval/internal/adapter_activation_impl.h @@ -0,0 +1,68 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_INTERNAL_ADAPTER_ACTIVATION_IMPL_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_INTERNAL_ADAPTER_ACTIVATION_IMPL_H_ + +#include + +#include "absl/base/nullability.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "base/attribute.h" +#include "common/value.h" +#include "eval/public/base_activation.h" +#include "runtime/activation_interface.h" +#include "runtime/function_overload_reference.h" +#include "runtime/internal/attribute_matcher.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel::interop_internal { + +// An Activation implementation that adapts the legacy version (based on +// expr::CelValue) to the new cel::Handle based version. This implementation +// must be scoped to an evaluation. +class AdapterActivationImpl : public ActivationInterface { + public: + explicit AdapterActivationImpl( + const google::api::expr::runtime::BaseActivation& legacy_activation) + : legacy_activation_(legacy_activation) {} + + absl::StatusOr FindVariable( + absl::string_view name, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const override; + + std::vector FindFunctionOverloads( + absl::string_view name) const override; + + absl::Span GetUnknownAttributes() const override; + + absl::Span GetMissingAttributes() const override; + + private: + const runtime_internal::AttributeMatcher* absl_nullable GetAttributeMatcher() + const override; + + const google::api::expr::runtime::BaseActivation& legacy_activation_; +}; + +} // namespace cel::interop_internal + +#endif // THIRD_PARTY_CEL_CPP_EVAL_INTERNAL_ADAPTER_ACTIVATION_IMPL_H_ diff --git a/eval/internal/cel_value_equal.cc b/eval/internal/cel_value_equal.cc new file mode 100644 index 000000000..f61f93ca4 --- /dev/null +++ b/eval/internal/cel_value_equal.cc @@ -0,0 +1,242 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/internal/cel_value_equal.h" + +#include + +#include "absl/time/time.h" +#include "absl/types/optional.h" +#include "common/kind.h" +#include "eval/public/cel_number.h" +#include "eval/public/cel_value.h" +#include "eval/public/message_wrapper.h" +#include "eval/public/structs/legacy_type_adapter.h" +#include "eval/public/structs/legacy_type_info_apis.h" +#include "internal/number.h" +#include "google/protobuf/arena.h" + +namespace cel::interop_internal { + +namespace { + +using ::cel::internal::Number; +using ::google::api::expr::runtime::CelList; +using ::google::api::expr::runtime::CelMap; +using ::google::api::expr::runtime::CelValue; +using ::google::api::expr::runtime::GetNumberFromCelValue; +using ::google::api::expr::runtime::LegacyTypeAccessApis; +using ::google::api::expr::runtime::LegacyTypeInfoApis; + +// Forward declaration of the functors for generic equality operator. +// Equal defined between compatible types. +struct HeterogeneousEqualProvider { + absl::optional operator()(const CelValue& lhs, + const CelValue& rhs) const; +}; + +// Comparison template functions +template +absl::optional Inequal(Type lhs, Type rhs) { + return lhs != rhs; +} + +template +absl::optional Equal(Type lhs, Type rhs) { + return lhs == rhs; +} + +// Equality for lists. Template parameter provides either heterogeneous or +// homogenous equality for comparing members. +template +absl::optional ListEqual(const CelList* t1, const CelList* t2) { + if (t1 == t2) { + return true; + } + int index_size = t1->size(); + if (t2->size() != index_size) { + return false; + } + + google::protobuf::Arena arena; + for (int i = 0; i < index_size; i++) { + CelValue e1 = (*t1).Get(&arena, i); + CelValue e2 = (*t2).Get(&arena, i); + absl::optional eq = EqualsProvider()(e1, e2); + if (eq.has_value()) { + if (!(*eq)) { + return false; + } + } else { + // Propagate that the equality is undefined. + return eq; + } + } + + return true; +} + +// Equality for maps. Template parameter provides either heterogeneous or +// homogenous equality for comparing values. +template +absl::optional MapEqual(const CelMap* t1, const CelMap* t2) { + if (t1 == t2) { + return true; + } + if (t1->size() != t2->size()) { + return false; + } + + google::protobuf::Arena arena; + auto list_keys = t1->ListKeys(&arena); + if (!list_keys.ok()) { + return absl::nullopt; + } + const CelList* keys = *list_keys; + for (int i = 0; i < keys->size(); i++) { + CelValue key = (*keys).Get(&arena, i); + CelValue v1 = (*t1).Get(&arena, key).value(); + absl::optional v2 = (*t2).Get(&arena, key); + if (!v2.has_value()) { + auto number = GetNumberFromCelValue(key); + if (!number.has_value()) { + return false; + } + if (!key.IsInt64() && number->LosslessConvertibleToInt()) { + CelValue int_key = CelValue::CreateInt64(number->AsInt()); + absl::optional eq = EqualsProvider()(key, int_key); + if (eq.has_value() && *eq) { + v2 = (*t2).Get(&arena, int_key); + } + } + if (!key.IsUint64() && !v2.has_value() && + number->LosslessConvertibleToUint()) { + CelValue uint_key = CelValue::CreateUint64(number->AsUint()); + absl::optional eq = EqualsProvider()(key, uint_key); + if (eq.has_value() && *eq) { + v2 = (*t2).Get(&arena, uint_key); + } + } + } + if (!v2.has_value()) { + return false; + } + absl::optional eq = EqualsProvider()(v1, *v2); + if (!eq.has_value() || !*eq) { + // Shortcircuit on value comparison errors and 'false' results. + return eq; + } + } + + return true; +} + +bool MessageEqual(const CelValue::MessageWrapper& m1, + const CelValue::MessageWrapper& m2) { + const LegacyTypeInfoApis* lhs_type_info = m1.legacy_type_info(); + const LegacyTypeInfoApis* rhs_type_info = m2.legacy_type_info(); + + if (lhs_type_info->GetTypename(m1) != rhs_type_info->GetTypename(m2)) { + return false; + } + + const LegacyTypeAccessApis* accessor = lhs_type_info->GetAccessApis(m1); + + if (accessor == nullptr) { + return false; + } + + return accessor->IsEqualTo(m1, m2); +} + +// Generic equality for CEL values of the same type. +// EqualityProvider is used for equality among members of container types. +template +absl::optional HomogenousCelValueEqual(const CelValue& t1, + const CelValue& t2) { + if (t1.type() != t2.type()) { + return absl::nullopt; + } + switch (t1.type()) { + case Kind::kNullType: + return Equal(CelValue::NullType(), + CelValue::NullType()); + case Kind::kBool: + return Equal(t1.BoolOrDie(), t2.BoolOrDie()); + case Kind::kInt64: + return Equal(t1.Int64OrDie(), t2.Int64OrDie()); + case Kind::kUint64: + return Equal(t1.Uint64OrDie(), t2.Uint64OrDie()); + case Kind::kDouble: + return Equal(t1.DoubleOrDie(), t2.DoubleOrDie()); + case Kind::kString: + return Equal(t1.StringOrDie(), t2.StringOrDie()); + case Kind::kBytes: + return Equal(t1.BytesOrDie(), t2.BytesOrDie()); + case Kind::kDuration: + return Equal(t1.DurationOrDie(), t2.DurationOrDie()); + case Kind::kTimestamp: + return Equal(t1.TimestampOrDie(), t2.TimestampOrDie()); + case Kind::kList: + return ListEqual(t1.ListOrDie(), t2.ListOrDie()); + case Kind::kMap: + return MapEqual(t1.MapOrDie(), t2.MapOrDie()); + case Kind::kCelType: + return Equal(t1.CelTypeOrDie(), + t2.CelTypeOrDie()); + default: + break; + } + return absl::nullopt; +} + +absl::optional HeterogeneousEqualProvider::operator()( + const CelValue& lhs, const CelValue& rhs) const { + return CelValueEqualImpl(lhs, rhs); +} + +} // namespace + +// Equal operator is defined for all types at plan time. Runtime delegates to +// the correct implementation for types or returns nullopt if the comparison +// isn't defined. +absl::optional CelValueEqualImpl(const CelValue& v1, const CelValue& v2) { + if (v1.type() == v2.type()) { + // Message equality is only defined if heterogeneous comparisons are enabled + // to preserve the legacy behavior for equality. + if (CelValue::MessageWrapper lhs, rhs; + v1.GetValue(&lhs) && v2.GetValue(&rhs)) { + return MessageEqual(lhs, rhs); + } + return HomogenousCelValueEqual(v1, v2); + } + + absl::optional lhs = GetNumberFromCelValue(v1); + absl::optional rhs = GetNumberFromCelValue(v2); + + if (rhs.has_value() && lhs.has_value()) { + return *lhs == *rhs; + } + + // TODO(uncreated-issue/6): It's currently possible for the interpreter to create a + // map containing an Error. Return no matching overload to propagate an error + // instead of a false result. + if (v1.IsError() || v1.IsUnknownSet() || v2.IsError() || v2.IsUnknownSet()) { + return absl::nullopt; + } + + return false; +} + +} // namespace cel::interop_internal diff --git a/eval/internal/cel_value_equal.h b/eval/internal/cel_value_equal.h new file mode 100644 index 000000000..7eb38beb1 --- /dev/null +++ b/eval/internal/cel_value_equal.h @@ -0,0 +1,34 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_INTERNAL_CEL_VALUE_EQUAL_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_INTERNAL_CEL_VALUE_EQUAL_H_ + +#include "absl/types/optional.h" +#include "eval/public/cel_value.h" + +namespace cel::interop_internal { + +// Implementation for general equality between CELValues. Exposed for +// consistent behavior in set membership functions. +// +// Returns nullopt if the comparison is undefined between differently typed +// values. +absl::optional CelValueEqualImpl( + const google::api::expr::runtime::CelValue& v1, + const google::api::expr::runtime::CelValue& v2); + +} // namespace cel::interop_internal + +#endif // THIRD_PARTY_CEL_CPP_EVAL_INTERNAL_CEL_VALUE_EQUAL_H_ diff --git a/eval/internal/cel_value_equal_test.cc b/eval/internal/cel_value_equal_test.cc new file mode 100644 index 000000000..109a63795 --- /dev/null +++ b/eval/internal/cel_value_equal_test.cc @@ -0,0 +1,537 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "eval/internal/cel_value_equal.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "google/protobuf/any.pb.h" +#include "google/rpc/context/attribute_context.pb.h" +#include "google/protobuf/descriptor.pb.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "absl/types/span.h" +#include "absl/types/variant.h" +#include "eval/public/cel_value.h" +#include "eval/public/containers/container_backed_list_impl.h" +#include "eval/public/containers/container_backed_map_impl.h" +#include "eval/public/message_wrapper.h" +#include "eval/public/structs/cel_proto_wrapper.h" +#include "eval/public/structs/trivial_legacy_type_info.h" +#include "eval/testutil/test_message.pb.h" +#include "internal/testing.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/dynamic_message.h" +#include "google/protobuf/message.h" +#include "google/protobuf/text_format.h" + +namespace cel::interop_internal { +namespace { + +using ::google::api::expr::runtime::CelList; +using ::google::api::expr::runtime::CelMap; +using ::google::api::expr::runtime::CelProtoWrapper; +using ::google::api::expr::runtime::CelValue; +using ::google::api::expr::runtime::ContainerBackedListImpl; +using ::google::api::expr::runtime::CreateContainerBackedMap; +using ::google::api::expr::runtime::MessageWrapper; +using ::google::api::expr::runtime::TestMessage; +using ::google::api::expr::runtime::TrivialTypeInfo; +using ::testing::_; +using ::testing::Combine; +using ::testing::Optional; +using ::testing::Values; +using ::testing::ValuesIn; + +struct EqualityTestCase { + enum class ErrorKind { kMissingOverload, kMissingIdentifier }; + absl::string_view expr; + std::variant result; + CelValue lhs = CelValue::CreateNull(); + CelValue rhs = CelValue::CreateNull(); +}; + +bool IsNumeric(CelValue::Type type) { + return type == CelValue::Type::kDouble || type == CelValue::Type::kInt64 || + type == CelValue::Type::kUint64; +} + +const CelList& CelListExample1() { + static ContainerBackedListImpl* example = + new ContainerBackedListImpl({CelValue::CreateInt64(1)}); + return *example; +} + +const CelList& CelListExample2() { + static ContainerBackedListImpl* example = + new ContainerBackedListImpl({CelValue::CreateInt64(2)}); + return *example; +} + +const CelMap& CelMapExample1() { + static CelMap* example = []() { + std::vector> values{ + {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}}; + // Implementation copies values into a hash map. + auto map = CreateContainerBackedMap(absl::MakeSpan(values)); + return map->release(); + }(); + return *example; +} + +const CelMap& CelMapExample2() { + static CelMap* example = []() { + std::vector> values{ + {CelValue::CreateInt64(2), CelValue::CreateInt64(4)}}; + auto map = CreateContainerBackedMap(absl::MakeSpan(values)); + return map->release(); + }(); + return *example; +} + +const std::vector& ValueExamples1() { + static std::vector* examples = []() { + google::protobuf::Arena arena; + auto result = std::make_unique>(); + + result->push_back(CelValue::CreateNull()); + result->push_back(CelValue::CreateBool(false)); + result->push_back(CelValue::CreateInt64(1)); + result->push_back(CelValue::CreateUint64(1)); + result->push_back(CelValue::CreateDouble(1.0)); + result->push_back(CelValue::CreateStringView("string")); + result->push_back(CelValue::CreateBytesView("bytes")); + // No arena allocs expected in this example. + result->push_back(CelProtoWrapper::CreateMessage( + std::make_unique().release(), &arena)); + result->push_back(CelValue::CreateDuration(absl::Seconds(1))); + result->push_back(CelValue::CreateTimestamp(absl::FromUnixSeconds(1))); + result->push_back(CelValue::CreateList(&CelListExample1())); + result->push_back(CelValue::CreateMap(&CelMapExample1())); + result->push_back(CelValue::CreateCelTypeView("type")); + + return result.release(); + }(); + return *examples; +} + +const std::vector& ValueExamples2() { + static std::vector* examples = []() { + google::protobuf::Arena arena; + auto result = std::make_unique>(); + auto message2 = std::make_unique(); + message2->set_int64_value(2); + + result->push_back(CelValue::CreateNull()); + result->push_back(CelValue::CreateBool(true)); + result->push_back(CelValue::CreateInt64(2)); + result->push_back(CelValue::CreateUint64(2)); + result->push_back(CelValue::CreateDouble(2.0)); + result->push_back(CelValue::CreateStringView("string2")); + result->push_back(CelValue::CreateBytesView("bytes2")); + // No arena allocs expected in this example. + result->push_back( + CelProtoWrapper::CreateMessage(message2.release(), &arena)); + result->push_back(CelValue::CreateDuration(absl::Seconds(2))); + result->push_back(CelValue::CreateTimestamp(absl::FromUnixSeconds(2))); + result->push_back(CelValue::CreateList(&CelListExample2())); + result->push_back(CelValue::CreateMap(&CelMapExample2())); + result->push_back(CelValue::CreateCelTypeView("type2")); + + return result.release(); + }(); + return *examples; +} + +class CelValueEqualImplTypesTest + : public testing::TestWithParam> { + public: + CelValueEqualImplTypesTest() = default; + + const CelValue& lhs() { return std::get<0>(GetParam()); } + + const CelValue& rhs() { return std::get<1>(GetParam()); } + + bool should_be_equal() { return std::get<2>(GetParam()); } +}; + +std::string CelValueEqualTestName( + const testing::TestParamInfo>& + test_case) { + return absl::StrCat(CelValue::TypeName(std::get<0>(test_case.param).type()), + CelValue::TypeName(std::get<1>(test_case.param).type()), + (std::get<2>(test_case.param)) ? "Equal" : "Inequal"); +} + +TEST_P(CelValueEqualImplTypesTest, Basic) { + absl::optional result = CelValueEqualImpl(lhs(), rhs()); + + if (lhs().IsNull() || rhs().IsNull()) { + if (lhs().IsNull() && rhs().IsNull()) { + EXPECT_THAT(result, Optional(true)); + } else { + EXPECT_THAT(result, Optional(false)); + } + } else if (lhs().type() == rhs().type() || + (IsNumeric(lhs().type()) && IsNumeric(rhs().type()))) { + EXPECT_THAT(result, Optional(should_be_equal())); + } else { + EXPECT_THAT(result, Optional(false)); + } +} + +INSTANTIATE_TEST_SUITE_P(EqualityBetweenTypes, CelValueEqualImplTypesTest, + Combine(ValuesIn(ValueExamples1()), + ValuesIn(ValueExamples1()), Values(true)), + &CelValueEqualTestName); + +INSTANTIATE_TEST_SUITE_P(InequalityBetweenTypes, CelValueEqualImplTypesTest, + Combine(ValuesIn(ValueExamples1()), + ValuesIn(ValueExamples2()), Values(false)), + &CelValueEqualTestName); + +struct NumericInequalityTestCase { + std::string name; + CelValue a; + CelValue b; +}; + +const std::vector& NumericValuesNotEqualExample() { + static std::vector* examples = []() { + auto result = std::make_unique>(); + result->push_back({"NegativeIntAndUint", CelValue::CreateInt64(-1), + CelValue::CreateUint64(2)}); + result->push_back( + {"IntAndLargeUint", CelValue::CreateInt64(1), + CelValue::CreateUint64( + static_cast(std::numeric_limits::max()) + 1)}); + result->push_back( + {"IntAndLargeDouble", CelValue::CreateInt64(2), + CelValue::CreateDouble( + static_cast(std::numeric_limits::max()) + 1025)}); + result->push_back( + {"IntAndSmallDouble", CelValue::CreateInt64(2), + CelValue::CreateDouble( + static_cast(std::numeric_limits::lowest()) - + 1025)}); + result->push_back( + {"UintAndLargeDouble", CelValue::CreateUint64(2), + CelValue::CreateDouble( + static_cast(std::numeric_limits::max()) + + 2049)}); + result->push_back({"NegativeDoubleAndUint", CelValue::CreateDouble(-2.0), + CelValue::CreateUint64(123)}); + + // NaN tests. + result->push_back({"NanAndDouble", CelValue::CreateDouble(NAN), + CelValue::CreateDouble(1.0)}); + result->push_back({"NanAndNan", CelValue::CreateDouble(NAN), + CelValue::CreateDouble(NAN)}); + result->push_back({"DoubleAndNan", CelValue::CreateDouble(1.0), + CelValue::CreateDouble(NAN)}); + result->push_back( + {"IntAndNan", CelValue::CreateInt64(1), CelValue::CreateDouble(NAN)}); + result->push_back( + {"NanAndInt", CelValue::CreateDouble(NAN), CelValue::CreateInt64(1)}); + result->push_back( + {"UintAndNan", CelValue::CreateUint64(1), CelValue::CreateDouble(NAN)}); + result->push_back( + {"NanAndUint", CelValue::CreateDouble(NAN), CelValue::CreateUint64(1)}); + + return result.release(); + }(); + return *examples; +} + +using NumericInequalityTest = testing::TestWithParam; +TEST_P(NumericInequalityTest, NumericValues) { + NumericInequalityTestCase test_case = GetParam(); + absl::optional result = CelValueEqualImpl(test_case.a, test_case.b); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(*result, false); +} + +INSTANTIATE_TEST_SUITE_P( + InequalityBetweenNumericTypesTest, NumericInequalityTest, + ValuesIn(NumericValuesNotEqualExample()), + [](const testing::TestParamInfo& info) { + return info.param.name; + }); + +TEST(CelValueEqualImplTest, LossyNumericEquality) { + absl::optional result = CelValueEqualImpl( + CelValue::CreateDouble( + static_cast(std::numeric_limits::max()) - 1), + CelValue::CreateInt64(std::numeric_limits::max())); + EXPECT_TRUE(result.has_value()); + EXPECT_TRUE(*result); +} + +TEST(CelValueEqualImplTest, ListMixedTypesInequal) { + ContainerBackedListImpl lhs({CelValue::CreateInt64(1)}); + ContainerBackedListImpl rhs({CelValue::CreateStringView("abc")}); + + EXPECT_THAT( + CelValueEqualImpl(CelValue::CreateList(&lhs), CelValue::CreateList(&rhs)), + Optional(false)); +} + +TEST(CelValueEqualImplTest, NestedList) { + ContainerBackedListImpl inner_lhs({CelValue::CreateInt64(1)}); + ContainerBackedListImpl lhs({CelValue::CreateList(&inner_lhs)}); + ContainerBackedListImpl inner_rhs({CelValue::CreateNull()}); + ContainerBackedListImpl rhs({CelValue::CreateList(&inner_rhs)}); + + EXPECT_THAT( + CelValueEqualImpl(CelValue::CreateList(&lhs), CelValue::CreateList(&rhs)), + Optional(false)); +} + +TEST(CelValueEqualImplTest, MapMixedValueTypesInequal) { + std::vector> lhs_data{ + {CelValue::CreateInt64(1), CelValue::CreateStringView("abc")}}; + std::vector> rhs_data{ + {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}}; + + ASSERT_OK_AND_ASSIGN(std::unique_ptr lhs, + CreateContainerBackedMap(absl::MakeSpan(lhs_data))); + ASSERT_OK_AND_ASSIGN(std::unique_ptr rhs, + CreateContainerBackedMap(absl::MakeSpan(rhs_data))); + + EXPECT_THAT(CelValueEqualImpl(CelValue::CreateMap(lhs.get()), + CelValue::CreateMap(rhs.get())), + Optional(false)); +} + +TEST(CelValueEqualImplTest, MapMixedKeyTypesEqual) { + std::vector> lhs_data{ + {CelValue::CreateUint64(1), CelValue::CreateStringView("abc")}}; + std::vector> rhs_data{ + {CelValue::CreateInt64(1), CelValue::CreateStringView("abc")}}; + + ASSERT_OK_AND_ASSIGN(std::unique_ptr lhs, + CreateContainerBackedMap(absl::MakeSpan(lhs_data))); + ASSERT_OK_AND_ASSIGN(std::unique_ptr rhs, + CreateContainerBackedMap(absl::MakeSpan(rhs_data))); + + EXPECT_THAT(CelValueEqualImpl(CelValue::CreateMap(lhs.get()), + CelValue::CreateMap(rhs.get())), + Optional(true)); +} + +TEST(CelValueEqualImplTest, MapMixedKeyTypesInequal) { + std::vector> lhs_data{ + {CelValue::CreateInt64(1), CelValue::CreateStringView("abc")}}; + std::vector> rhs_data{ + {CelValue::CreateInt64(2), CelValue::CreateInt64(2)}}; + + ASSERT_OK_AND_ASSIGN(std::unique_ptr lhs, + CreateContainerBackedMap(absl::MakeSpan(lhs_data))); + ASSERT_OK_AND_ASSIGN(std::unique_ptr rhs, + CreateContainerBackedMap(absl::MakeSpan(rhs_data))); + + EXPECT_THAT(CelValueEqualImpl(CelValue::CreateMap(lhs.get()), + CelValue::CreateMap(rhs.get())), + Optional(false)); +} + +TEST(CelValueEqualImplTest, NestedMaps) { + std::vector> inner_lhs_data{ + {CelValue::CreateInt64(2), CelValue::CreateStringView("abc")}}; + ASSERT_OK_AND_ASSIGN( + std::unique_ptr inner_lhs, + CreateContainerBackedMap(absl::MakeSpan(inner_lhs_data))); + std::vector> lhs_data{ + {CelValue::CreateInt64(1), CelValue::CreateMap(inner_lhs.get())}}; + + std::vector> inner_rhs_data{ + {CelValue::CreateInt64(2), CelValue::CreateNull()}}; + ASSERT_OK_AND_ASSIGN( + std::unique_ptr inner_rhs, + CreateContainerBackedMap(absl::MakeSpan(inner_rhs_data))); + std::vector> rhs_data{ + {CelValue::CreateInt64(1), CelValue::CreateMap(inner_rhs.get())}}; + + ASSERT_OK_AND_ASSIGN(std::unique_ptr lhs, + CreateContainerBackedMap(absl::MakeSpan(lhs_data))); + ASSERT_OK_AND_ASSIGN(std::unique_ptr rhs, + CreateContainerBackedMap(absl::MakeSpan(rhs_data))); + + EXPECT_THAT(CelValueEqualImpl(CelValue::CreateMap(lhs.get()), + CelValue::CreateMap(rhs.get())), + Optional(false)); +} + +TEST(CelValueEqualImplTest, ProtoEqualityDifferingTypenameInequal) { + // If message wrappers report a different typename, treat as inequal without + // calling into the provided equal implementation. + google::protobuf::Arena arena; + TestMessage example; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( + int32_value: 1 + uint32_value: 2 + string_value: "test" + )", + &example)); + + CelValue lhs = CelProtoWrapper::CreateMessage(&example, &arena); + CelValue rhs = CelValue::CreateMessageWrapper( + MessageWrapper(&example, TrivialTypeInfo::GetInstance())); + + EXPECT_THAT(CelValueEqualImpl(lhs, rhs), Optional(false)); +} + +TEST(CelValueEqualImplTest, ProtoEqualityNoAccessorInequal) { + // If message wrappers report no access apis, then treat as inequal. + TestMessage example; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( + int32_value: 1 + uint32_value: 2 + string_value: "test" + )", + &example)); + + CelValue lhs = CelValue::CreateMessageWrapper( + MessageWrapper(&example, TrivialTypeInfo::GetInstance())); + CelValue rhs = CelValue::CreateMessageWrapper( + MessageWrapper(&example, TrivialTypeInfo::GetInstance())); + + EXPECT_THAT(CelValueEqualImpl(lhs, rhs), Optional(false)); +} + +TEST(CelValueEqualImplTest, ProtoEqualityAny) { + google::protobuf::Arena arena; + TestMessage packed_value; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( + int32_value: 1 + uint32_value: 2 + string_value: "test" + )", + &packed_value)); + + TestMessage lhs; + lhs.mutable_any_value()->PackFrom(packed_value); + + TestMessage rhs; + rhs.mutable_any_value()->PackFrom(packed_value); + + EXPECT_THAT(CelValueEqualImpl(CelProtoWrapper::CreateMessage(&lhs, &arena), + CelProtoWrapper::CreateMessage(&rhs, &arena)), + Optional(true)); + + // Equality falls back to bytewise comparison if type is missing. + lhs.mutable_any_value()->clear_type_url(); + rhs.mutable_any_value()->clear_type_url(); + EXPECT_THAT(CelValueEqualImpl(CelProtoWrapper::CreateMessage(&lhs, &arena), + CelProtoWrapper::CreateMessage(&rhs, &arena)), + Optional(true)); +} + +// Add transitive dependencies in appropriate order for the dynamic descriptor +// pool. +// Return false if the dependencies could not be added to the pool. +bool AddDepsToPool(const google::protobuf::FileDescriptor* descriptor, + google::protobuf::DescriptorPool& pool) { + for (int i = 0; i < descriptor->dependency_count(); i++) { + if (!AddDepsToPool(descriptor->dependency(i), pool)) { + return false; + } + } + google::protobuf::FileDescriptorProto descriptor_proto; + descriptor->CopyTo(&descriptor_proto); + return pool.BuildFile(descriptor_proto) != nullptr; +} + +// Equivalent descriptors managed by separate descriptor pools are not equal, so +// the underlying messages are not considered equal. +TEST(CelValueEqualImplTest, DynamicDescriptorAndGeneratedInequal) { + // Simulate a dynamically loaded descriptor that happens to match the + // compiled version. + google::protobuf::DescriptorPool pool; + google::protobuf::DynamicMessageFactory factory; + google::protobuf::Arena arena; + factory.SetDelegateToGeneratedFactory(false); + + ASSERT_TRUE(AddDepsToPool(TestMessage::descriptor()->file(), pool)); + + TestMessage example_message; + ASSERT_TRUE( + google::protobuf::TextFormat::ParseFromString(R"pb( + int64_value: 12345 + bool_list: false + bool_list: true + message_value { float_value: 1.0 } + )pb", + &example_message)); + + // Messages from a loaded descriptor and generated versions can't be compared + // via MessageDifferencer, so return false. + std::unique_ptr example_dynamic_message( + factory + .GetPrototype(pool.FindMessageTypeByName( + TestMessage::descriptor()->full_name())) + ->New()); + + ASSERT_TRUE(example_dynamic_message->ParseFromString( + example_message.SerializeAsString())); + + EXPECT_THAT(CelValueEqualImpl( + CelProtoWrapper::CreateMessage(&example_message, &arena), + CelProtoWrapper::CreateMessage(example_dynamic_message.get(), + &arena)), + Optional(false)); +} + +TEST(CelValueEqualImplTest, DynamicMessageAndMessageEqual) { + google::protobuf::DynamicMessageFactory factory; + google::protobuf::Arena arena; + factory.SetDelegateToGeneratedFactory(false); + + TestMessage example_message; + ASSERT_TRUE( + google::protobuf::TextFormat::ParseFromString(R"pb( + int64_value: 12345 + bool_list: false + bool_list: true + message_value { float_value: 1.0 } + )pb", + &example_message)); + + // Dynamic message and generated Message subclass with the same generated + // descriptor are comparable. + std::unique_ptr example_dynamic_message( + factory.GetPrototype(TestMessage::descriptor())->New()); + + ASSERT_TRUE(example_dynamic_message->ParseFromString( + example_message.SerializeAsString())); + + EXPECT_THAT(CelValueEqualImpl( + CelProtoWrapper::CreateMessage(&example_message, &arena), + CelProtoWrapper::CreateMessage(example_dynamic_message.get(), + &arena)), + Optional(true)); +} + +} // namespace +} // namespace cel::interop_internal diff --git a/eval/internal/errors.cc b/eval/internal/errors.cc new file mode 100644 index 000000000..99e962588 --- /dev/null +++ b/eval/internal/errors.cc @@ -0,0 +1,64 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/internal/errors.h" + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "runtime/internal/errors.h" +#include "google/protobuf/arena.h" + +namespace cel { +namespace interop_internal { + +using ::google::protobuf::Arena; + +const absl::Status* CreateNoMatchingOverloadError(google::protobuf::Arena* arena, + absl::string_view fn) { + return Arena::Create( + arena, runtime_internal::CreateNoMatchingOverloadError(fn)); +} + +const absl::Status* CreateNoSuchFieldError(google::protobuf::Arena* arena, + absl::string_view field) { + return Arena::Create( + arena, runtime_internal::CreateNoSuchFieldError(field)); +} + +const absl::Status* CreateNoSuchKeyError(google::protobuf::Arena* arena, + absl::string_view key) { + return Arena::Create( + arena, runtime_internal::CreateNoSuchKeyError(key)); +} + +const absl::Status* CreateMissingAttributeError( + google::protobuf::Arena* arena, absl::string_view missing_attribute_path) { + return Arena::Create( + arena, + runtime_internal::CreateMissingAttributeError(missing_attribute_path)); +} + +const absl::Status* CreateUnknownFunctionResultError( + google::protobuf::Arena* arena, absl::string_view help_message) { + return Arena::Create( + arena, runtime_internal::CreateUnknownFunctionResultError(help_message)); +} + +const absl::Status* CreateError(google::protobuf::Arena* arena, absl::string_view message, + absl::StatusCode code) { + return Arena::Create(arena, code, message); +} + +} // namespace interop_internal +} // namespace cel diff --git a/eval/internal/errors.h b/eval/internal/errors.h new file mode 100644 index 000000000..6487e7c40 --- /dev/null +++ b/eval/internal/errors.h @@ -0,0 +1,54 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Factories and constants for well-known CEL errors. +#ifndef THIRD_PARTY_CEL_CPP_EVAL_INTERNAL_ERRORS_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_INTERNAL_ERRORS_H_ + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "runtime/internal/errors.h" // IWYU pragma: export +#include "google/protobuf/arena.h" + +namespace cel { +namespace interop_internal { +// Factories for interop error values. +// const pointer Results are arena allocated to support interop with cel::Handle +// and expr::runtime::CelValue. +const absl::Status* CreateNoMatchingOverloadError(google::protobuf::Arena* arena, + absl::string_view fn); + +const absl::Status* CreateNoSuchFieldError(google::protobuf::Arena* arena, + absl::string_view field); + +const absl::Status* CreateNoSuchKeyError(google::protobuf::Arena* arena, + absl::string_view key); + +const absl::Status* CreateUnknownValueError(google::protobuf::Arena* arena, + absl::string_view unknown_path); + +const absl::Status* CreateMissingAttributeError( + google::protobuf::Arena* arena, absl::string_view missing_attribute_path); + +const absl::Status* CreateUnknownFunctionResultError( + google::protobuf::Arena* arena, absl::string_view help_message); + +const absl::Status* CreateError( + google::protobuf::Arena* arena, absl::string_view message, + absl::StatusCode code = absl::StatusCode::kUnknown); + +} // namespace interop_internal +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_EVAL_INTERNAL_ERRORS_H_ diff --git a/eval/internal/interop.h b/eval/internal/interop.h new file mode 100644 index 000000000..906a0fb61 --- /dev/null +++ b/eval/internal/interop.h @@ -0,0 +1,20 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_INTERNAL_INTEROP_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_INTERNAL_INTEROP_H_ + +#include "common/legacy_value.h" // IWYU pragma: export + +#endif // THIRD_PARTY_CEL_CPP_EVAL_INTERNAL_INTEROP_H_ diff --git a/eval/public/BUILD b/eval/public/BUILD index 747aa9fe3..31ad2d480 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -1,16 +1,74 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + package(default_visibility = ["//visibility:public"]) -licenses(["notice"]) # Apache 2.0 +package_group( + name = "ast_visibility", + packages = [ + "//eval/compiler", + "//extensions", + ], +) + +licenses(["notice"]) exports_files(["LICENSE"]) +cc_library( + name = "message_wrapper", + hdrs = [ + "message_wrapper.h", + ], + deps = [ + "//base/internal:message_wrapper", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/numeric:bits", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "message_wrapper_test", + srcs = [ + "message_wrapper_test.cc", + ], + deps = [ + ":message_wrapper", + "//eval/public/structs:trivial_legacy_type_info", + "//eval/testutil:test_message_cc_proto", + "//internal:casts", + "//internal:testing", + "@com_google_protobuf//:protobuf", + ], +) + cc_library( name = "cel_value_internal", hdrs = [ "cel_value_internal.h", ], deps = [ + ":message_wrapper", + "//internal:casts", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/numeric:bits", "@com_google_absl//absl/types:variant", + "@com_google_protobuf//:protobuf", ], ) @@ -24,12 +82,27 @@ cc_library( ], deps = [ ":cel_value_internal", + ":message_wrapper", + ":unknown_set", + "//common:kind", + "//common:memory", + "//common:native_type", + "//eval/internal:errors", + "//eval/public/structs:legacy_type_info_apis", + "//extensions/protobuf:memory_manager", + "//internal:casts", + "//internal:status_macros", + "//internal:utf8", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/time", "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:variant", "@com_google_protobuf//:protobuf", ], ) @@ -44,11 +117,12 @@ cc_library( ], deps = [ ":cel_value", - ":cel_value_internal", + "//base:attributes", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:variant", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], ) @@ -57,23 +131,15 @@ cc_library( hdrs = [ "cel_value_producer.h", ], - deps = [ - ":cel_value", - "@com_google_absl//absl/strings", - ], + deps = [":cel_value"], ) cc_library( name = "unknown_attribute_set", - srcs = [ - ], hdrs = [ "unknown_attribute_set.h", ], - deps = [ - ":cel_attribute", - "@com_google_absl//absl/container:flat_hash_set", - ], + deps = ["//base:attributes"], ) cc_library( @@ -85,13 +151,17 @@ cc_library( "activation.h", ], deps = [ + ":base_activation", ":cel_attribute", ":cel_function", ":cel_value", ":cel_value_producer", + "//runtime/internal:attribute_matcher", + "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", "@com_google_protobuf//:protobuf", ], ) @@ -109,7 +179,7 @@ cc_library( "//eval/public/containers:field_access", "//eval/public/containers:field_backed_list_impl", "//eval/public/containers:field_backed_map_impl", - "@com_google_absl//absl/strings", + "@com_google_absl//absl/status", ], ) @@ -123,22 +193,28 @@ cc_library( ], deps = [ ":cel_value", + "//common:function_descriptor", + "//common:value", + "//eval/internal:interop", + "//internal:status_macros", + "//runtime:function", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", + "@com_google_protobuf//:protobuf", ], ) cc_library( - name = "cel_function_adapter", - srcs = [ - "cel_function_adapter.cc", - ], + name = "cel_function_adapter_impl", hdrs = [ - "cel_function_adapter.h", + "cel_function_adapter_impl.h", ], deps = [ ":cel_function", ":cel_function_registry", - "//eval/public/structs:cel_proto_wrapper", + ":cel_value", + "//internal:status_macros", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -146,25 +222,35 @@ cc_library( ) cc_library( - name = "cel_function_provider", - srcs = [ - "cel_function_provider.cc", - ], + name = "cel_function_adapter", hdrs = [ - "cel_function_provider.h", + "cel_function_adapter.h", ], deps = [ - ":activation", - ":cel_function", - "@com_google_absl//absl/status:statusor", + ":cel_function_adapter_impl", + ":cel_value", + "//eval/public/structs:cel_proto_wrapper", + "@com_google_absl//absl/status", + "@com_google_protobuf//:protobuf", ], ) +cc_library( + name = "portable_cel_function_adapter", + hdrs = [ + "portable_cel_function_adapter.h", + ], + deps = [":cel_function_adapter"], +) + cc_library( name = "cel_builtins", hdrs = [ "cel_builtins.h", ], + deps = [ + "//base:builtins", + ], ) cc_library( @@ -176,16 +262,204 @@ cc_library( "builtin_func_registrar.h", ], deps = [ + ":cel_function_registry", + ":cel_options", + "//internal:status_macros", + "//runtime:function_registry", + "//runtime:runtime_options", + "//runtime/standard:arithmetic_functions", + "//runtime/standard:comparison_functions", + "//runtime/standard:container_functions", + "//runtime/standard:container_membership_functions", + "//runtime/standard:equality_functions", + "//runtime/standard:logical_functions", + "//runtime/standard:regex_functions", + "//runtime/standard:string_functions", + "//runtime/standard:time_functions", + "//runtime/standard:type_conversion_functions", + "@com_google_absl//absl/status", + ], +) + +cc_library( + name = "comparison_functions", + srcs = [ + "comparison_functions.cc", + ], + hdrs = [ + "comparison_functions.h", + ], + deps = [ + ":cel_function_registry", + ":cel_options", + "//runtime:function_registry", + "//runtime:runtime_options", + "//runtime/standard:comparison_functions", + "@com_google_absl//absl/status", + ], +) + +cc_test( + name = "comparison_functions_test", + size = "small", + srcs = [ + "comparison_functions_test.cc", + ], + deps = [ + ":activation", + ":cel_expr_builder_factory", + ":cel_expression", + ":cel_function_registry", + ":cel_options", + ":cel_value", + ":comparison_functions", + "//eval/public/testing:matchers", + "//internal:status_macros", + "//internal:testing", + "//parser", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_googleapis//google/rpc/context:attribute_context_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "equality_function_registrar", + srcs = [ + "equality_function_registrar.cc", + ], + hdrs = [ + "equality_function_registrar.h", + ], + deps = [ + ":cel_function_registry", + ":cel_options", + "//eval/internal:cel_value_equal", + "//runtime:runtime_options", + "//runtime/standard:equality_functions", + "@com_google_absl//absl/status", + ], +) + +cc_test( + name = "equality_function_registrar_test", + size = "small", + srcs = [ + "equality_function_registrar_test.cc", + ], + deps = [ + ":activation", ":cel_builtins", - ":cel_function", - ":cel_function_adapter", + ":cel_expr_builder_factory", + ":cel_expression", ":cel_function_registry", ":cel_options", + ":cel_value", + ":equality_function_registrar", + ":message_wrapper", "//eval/public/containers:container_backed_list_impl", + "//eval/public/containers:container_backed_map_impl", + "//eval/public/structs:cel_proto_wrapper", + "//eval/public/structs:trivial_legacy_type_info", + "//eval/public/testing:matchers", + "//eval/testutil:test_message_cc_proto", + "//internal:benchmark", + "//internal:status_macros", + "//internal:testing", + "//parser", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + "@com_google_absl//absl/types:variant", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_googleapis//google/rpc/context:attribute_context_cc_proto", + "@com_google_protobuf//:any_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "container_function_registrar", + srcs = [ + "container_function_registrar.cc", + ], + hdrs = [ + "container_function_registrar.h", + ], + deps = [ + ":cel_function_registry", + ":cel_options", + "//runtime:runtime_options", + "//runtime/standard:container_functions", + "@com_google_absl//absl/status", + ], +) + +cc_test( + name = "container_function_registrar_test", + size = "small", + srcs = [ + "container_function_registrar_test.cc", + ], + deps = [ + ":activation", + ":cel_expr_builder_factory", + ":cel_expression", + ":cel_value", + ":container_function_registrar", + ":equality_function_registrar", + "//eval/public/containers:container_backed_list_impl", + "//eval/public/testing:matchers", + "//internal:testing", + "//parser", + ], +) + +cc_library( + name = "logical_function_registrar", + srcs = [ + "logical_function_registrar.cc", + ], + hdrs = [ + "logical_function_registrar.h", + ], + deps = [ + ":cel_function_registry", + ":cel_options", + "//runtime/standard:logical_functions", + "@com_google_absl//absl/status", + ], +) + +cc_test( + name = "logical_function_registrar_test", + size = "small", + srcs = [ + "logical_function_registrar_test.cc", + ], + deps = [ + ":activation", + ":cel_expr_builder_factory", + ":cel_expression", + ":cel_options", + ":cel_value", + ":logical_function_registrar", + ":portable_cel_function_adapter", + "//eval/public/testing:matchers", + "//internal:testing", + "//parser", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", - "@com_googlesource_code_re2//:re2", ], ) @@ -202,7 +476,10 @@ cc_library( ":cel_function_adapter", ":cel_function_registry", ":cel_value", + "//eval/public/structs:cel_proto_wrapper", "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + "@com_google_googleapis//google/type:timeofday_cc_proto", "@com_google_protobuf//:protobuf", ], ) @@ -213,12 +490,15 @@ cc_library( "cel_expression.h", ], deps = [ - ":activation", - ":cel_function", + ":base_activation", ":cel_function_registry", + ":cel_type_registry", ":cel_value", + "//common:legacy_value", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], ) @@ -227,8 +507,7 @@ cc_library( srcs = ["source_position.cc"], hdrs = ["source_position.h"], deps = [ - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", - "@com_google_protobuf//:protobuf_lite", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], ) @@ -239,8 +518,7 @@ cc_library( ], deps = [ ":source_position", - "@com_google_absl//absl/strings", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], ) @@ -251,7 +529,7 @@ cc_library( ], deps = [ ":ast_visitor", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], ) @@ -266,17 +544,23 @@ cc_library( deps = [ ":ast_visitor", ":source_position", - "@com_google_absl//absl/strings", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_absl//absl/log:absl_log", + "@com_google_absl//absl/types:variant", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], ) cc_library( name = "cel_options", + srcs = [ + "cel_options.cc", + ], hdrs = [ "cel_options.h", ], deps = [ + "//runtime:runtime_options", + "@com_google_absl//absl/base:core_headers", "@com_google_protobuf//:protobuf", ], ) @@ -291,8 +575,25 @@ cc_library( ], deps = [ ":cel_expression", + ":cel_function", ":cel_options", + "//common:kind", + "//common:memory", + "//eval/compiler:cel_expression_builder_flat_impl", + "//eval/compiler:comprehension_vulnerability_check", + "//eval/compiler:constant_folding", "//eval/compiler:flat_expr_builder", + "//eval/compiler:qualified_reference_resolver", + "//eval/compiler:regex_precompilation_optimization", + "//extensions:select_optimization", + "//internal:noop_delete", + "//runtime:runtime_options", + "//runtime/internal:runtime_env", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/log:absl_log", + "@com_google_absl//absl/status", + "@com_google_protobuf//:protobuf", ], ) @@ -306,10 +607,13 @@ cc_library( ], deps = [ ":cel_value", - "//internal:proto_util", + "//internal:proto_time_encoding", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@com_google_protobuf//:json_util", "@com_google_protobuf//:protobuf", + "@com_google_protobuf//:struct_cc_proto", + "@com_google_protobuf//:time_util", ], ) @@ -319,11 +623,24 @@ cc_library( hdrs = ["cel_function_registry.h"], deps = [ ":cel_function", - ":cel_function_provider", ":cel_options", ":cel_value", + "//common:function_descriptor", + "//common:kind", + "//common:value", + "//eval/internal:interop", + "//internal:status_macros", + "//runtime:function", + "//runtime:function_overload_reference", + "//runtime:function_registry", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", + "@com_google_protobuf//:protobuf", ], ) @@ -335,13 +652,21 @@ cc_test( ], deps = [ ":cel_value", - ":unknown_attribute_set", ":unknown_set", - "//base:status_macros", + "//common:memory", + "//eval/internal:errors", + "//eval/public/structs:trivial_legacy_type_info", + "//eval/public/testing:matchers", + "//eval/testutil:test_message_cc_proto", + "//extensions/protobuf:memory_manager", + "//internal:testing", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", "@com_google_absl//absl/time", - "@com_google_googletest//:gtest_main", + "@com_google_absl//absl/types:optional", + "@com_google_protobuf//:protobuf", ], ) @@ -355,8 +680,12 @@ cc_test( ":cel_attribute", ":cel_value", "//eval/public/structs:cel_proto_wrapper", + "//internal:testing", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest_main", + "@com_google_absl//absl/types:optional", "@com_google_protobuf//:protobuf", ], ) @@ -371,11 +700,12 @@ cc_test( ":activation", ":cel_attribute", ":cel_function", - "//base:status_macros", "//eval/eval:attribute_trail", "//eval/eval:ident_step", + "//extensions/protobuf:memory_manager", + "//internal:status_macros", + "//internal:testing", "//parser", - "@com_google_googletest//:gtest_main", ], ) @@ -386,37 +716,59 @@ cc_test( ], deps = [ ":ast_traverse", - "@com_google_googletest//:gtest_main", + ":ast_visitor", + "//internal:testing", + ], +) + +cc_library( + name = "ast_rewrite", + srcs = [ + "ast_rewrite.cc", + ], + hdrs = [ + "ast_rewrite.h", + ], + deps = [ + ":ast_visitor", + ":source_position", + "@com_google_absl//absl/log:absl_log", + "@com_google_absl//absl/types:span", + "@com_google_absl//absl/types:variant", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], ) cc_test( - name = "activation_bind_helper_test", - size = "small", + name = "ast_rewrite_test", srcs = [ - "activation_bind_helper_test.cc", + "ast_rewrite_test.cc", ], deps = [ - ":activation", - ":activation_bind_helper", - "//base:status_macros", - "//eval/testutil:test_message_cc_proto", + ":ast_rewrite", + ":ast_visitor", + ":source_position", + "//internal:testing", + "//parser", "//testutil:util", - "@com_google_googletest//:gtest_main", - "@com_google_protobuf//:protobuf", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], ) cc_test( - name = "cel_function_provider_test", + name = "activation_bind_helper_test", + size = "small", srcs = [ - "cel_function_provider_test.cc", + "activation_bind_helper_test.cc", ], deps = [ - ":cel_function", - ":cel_function_provider", - "//base:status_macros", - "@com_google_googletest//:gtest_main", + ":activation", + ":activation_bind_helper", + "//eval/testutil:test_message_cc_proto", + "//internal:status_macros", + "//internal:testing", + "//testutil:util", + "@com_google_absl//absl/status", ], ) @@ -426,12 +778,15 @@ cc_test( "cel_function_registry_test.cc", ], deps = [ + ":activation", ":cel_function", - ":cel_function_provider", ":cel_function_registry", - "//base:status_macros", + "//common:kind", + "//eval/internal:adapter_activation_impl", + "//internal:testing", + "//runtime:function_overload_reference", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_googletest//:gtest_main", ], ) @@ -442,11 +797,59 @@ cc_test( "cel_function_adapter_test.cc", ], deps = [ - ":cel_function", ":cel_function_adapter", - ":cel_value", - "//base:status_macros", - "@com_google_googletest//:gtest_main", + "//internal:status_macros", + "//internal:testing", + ], +) + +cc_library( + name = "cel_type_registry", + srcs = ["cel_type_registry.cc"], + hdrs = ["cel_type_registry.h"], + deps = [ + "//base:data", + "//eval/public/structs:legacy_type_adapter", + "//eval/public/structs:legacy_type_provider", + "//eval/public/structs:protobuf_descriptor_type_provider", + "//runtime:type_registry", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "cel_type_registry_test", + srcs = ["cel_type_registry_test.cc"], + deps = [ + ":cel_type_registry", + "//base:data", + "//common:memory", + "//common:type", + "//eval/public/structs:legacy_type_adapter", + "//eval/public/structs:legacy_type_provider", + "//internal:testing", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + ], +) + +cc_test( + name = "cel_type_registry_protobuf_reflection_test", + srcs = ["cel_type_registry_protobuf_reflection_test.cc"], + deps = [ + ":cel_type_registry", + "//common:memory", + "//common:type", + "//eval/testutil:test_message_cc_proto", + "//internal:testing", + "@com_google_absl//absl/types:optional", + "@com_google_protobuf//:protobuf", + "@com_google_protobuf//:struct_cc_proto", ], ) @@ -462,12 +865,17 @@ cc_test( ":cel_builtins", ":cel_expr_builder_factory", ":cel_function_registry", - "//base:status_macros", + ":cel_options", + ":cel_value", "//eval/public/structs:cel_proto_wrapper", + "//internal:status_macros", + "//internal:testing", + "//internal:time", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", - "@com_google_googletest//:gtest_main", - "@com_google_protobuf//:protobuf", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:optional", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], ) @@ -479,13 +887,20 @@ cc_test( ], deps = [ ":builtin_func_registrar", + ":cel_expr_builder_factory", + ":cel_expression", ":cel_function_registry", ":cel_value", ":extension_func_registrar", - "//base:status_macros", + "//eval/public/structs:cel_proto_wrapper", + "//internal:status_macros", + "//internal:testing", "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest_main", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + "@com_google_googleapis//google/type:timeofday_cc_proto", "@com_google_protobuf//:protobuf", + "@com_google_protobuf//:time_util", ], ) @@ -497,8 +912,8 @@ cc_test( ], deps = [ ":source_position", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", - "@com_google_googletest//:gtest_main", + "//internal:testing", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], ) @@ -512,7 +927,7 @@ cc_test( ":cel_attribute", ":cel_value", ":unknown_attribute_set", - "@com_google_googletest//:gtest_main", + "//internal:testing", ], ) @@ -524,14 +939,14 @@ cc_test( ], deps = [ ":value_export_util", - "//base:status_macros", "//eval/public/containers:container_backed_list_impl", "//eval/public/containers:container_backed_map_impl", "//eval/public/structs:cel_proto_wrapper", "//eval/testutil:test_message_cc_proto", + "//internal:status_macros", + "//internal:testing", "//testutil:util", "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest_main", ], ) @@ -540,12 +955,8 @@ cc_library( srcs = ["unknown_function_result_set.cc"], hdrs = ["unknown_function_result_set.h"], deps = [ - ":cel_function", - ":cel_options", - ":cel_value", - ":set_util", - "@com_google_absl//absl/container:btree", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "//base:function_result", + "//base:function_result_set", ], ) @@ -562,10 +973,14 @@ cc_test( "//eval/public/containers:container_backed_list_impl", "//eval/public/containers:container_backed_map_impl", "//eval/public/structs:cel_proto_wrapper", + "//internal:testing", "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", + "@com_google_protobuf//:duration_cc_proto", + "@com_google_protobuf//:empty_cc_proto", "@com_google_protobuf//:protobuf", + "@com_google_protobuf//:struct_cc_proto", + "@com_google_protobuf//:timestamp_cc_proto", ], ) @@ -575,6 +990,7 @@ cc_library( deps = [ ":unknown_attribute_set", ":unknown_function_result_set", + "//base/internal:unknown_set", ], ) @@ -583,11 +999,12 @@ cc_test( srcs = ["unknown_set_test.cc"], deps = [ ":cel_attribute", + ":cel_function", ":unknown_attribute_set", ":unknown_function_result_set", ":unknown_set", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", - "@com_google_googletest//:gtest_main", + "//internal:testing", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], ) @@ -602,15 +1019,19 @@ cc_library( ], deps = [ ":cel_value", - "//base:status_macros", "//eval/public/containers:container_backed_list_impl", "//eval/public/containers:container_backed_map_impl", "//eval/public/structs:cel_proto_wrapper", - "//internal:proto_util", + "//internal:proto_time_encoding", + "//internal:status_macros", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_googleapis//google/api/expr/v1alpha1:value_cc_proto", + "@com_google_cel_spec//proto/cel/expr:value_cc_proto", + "@com_google_protobuf//:any_cc_proto", + "@com_google_protobuf//:differencer", "@com_google_protobuf//:protobuf", + "@com_google_protobuf//:struct_cc_proto", ], ) @@ -618,7 +1039,21 @@ cc_library( name = "set_util", srcs = ["set_util.cc"], hdrs = ["set_util.h"], - deps = ["//eval/public:cel_value"], + deps = [":cel_value"], +) + +cc_library( + name = "base_activation", + hdrs = ["base_activation.h"], + deps = [ + ":cel_attribute", + ":cel_function", + ":cel_value", + "//runtime/internal:attribute_matcher", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/strings", + "@com_google_protobuf//:field_mask_cc_proto", + ], ) cc_test( @@ -634,9 +1069,85 @@ cc_test( "//eval/public/containers:container_backed_list_impl", "//eval/public/containers:container_backed_map_impl", "//eval/public/structs:cel_proto_wrapper", + "//internal:testing", + "@com_google_absl//absl/status", + "@com_google_absl//absl/time", + "@com_google_protobuf//:empty_cc_proto", + "@com_google_protobuf//:protobuf", + "@com_google_protobuf//:struct_cc_proto", + ], +) + +cc_test( + name = "builtin_func_registrar_test", + srcs = ["builtin_func_registrar_test.cc"], + deps = [ + ":activation", + ":builtin_func_registrar", + ":cel_expr_builder_factory", + ":cel_expression", + ":cel_options", + ":cel_value", + "//eval/public/testing:matchers", + "//internal:testing", + "//internal:time", + "//parser", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@com_google_absl//absl/time", - "@com_google_googletest//:gtest_main", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "cel_number", + srcs = ["cel_number.cc"], + hdrs = ["cel_number.h"], + deps = [ + ":cel_value", + "//internal:number", + "@com_google_absl//absl/types:optional", + ], +) + +cc_test( + name = "cel_number_test", + srcs = ["cel_number_test.cc"], + deps = [ + ":cel_number", + ":cel_value", + "//internal:testing", + "@com_google_absl//absl/types:optional", + ], +) + +cc_library( + name = "string_extension_func_registrar", + srcs = ["string_extension_func_registrar.cc"], + hdrs = ["string_extension_func_registrar.h"], + deps = [ + ":cel_function_registry", + ":cel_options", + "//extensions:strings", + "@com_google_absl//absl/status", + ], +) + +cc_test( + name = "string_extension_func_registrar_test", + srcs = ["string_extension_func_registrar_test.cc"], + deps = [ + ":builtin_func_registrar", + ":cel_function_registry", + ":cel_value", + ":string_extension_func_registrar", + "//eval/public/containers:container_backed_list_impl", + "//internal:testing", + "@com_google_absl//absl/types:span", + "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", "@com_google_protobuf//:protobuf", ], ) diff --git a/eval/public/activation.cc b/eval/public/activation.cc index ecd95ee13..95a1c2a4c 100644 --- a/eval/public/activation.cc +++ b/eval/public/activation.cc @@ -2,6 +2,7 @@ #include #include +#include #include #include "absl/status/status.h" diff --git a/eval/public/activation.h b/eval/public/activation.h index a6346699e..6f2bb59c1 100644 --- a/eval/public/activation.h +++ b/eval/public/activation.h @@ -1,69 +1,28 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_ACTIVATION_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_ACTIVATION_H_ -#include #include +#include +#include #include -#include "google/protobuf/field_mask.pb.h" -#include "google/protobuf/util/field_mask_util.h" +#include "absl/base/nullability.h" #include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "eval/public/base_activation.h" #include "eval/public/cel_attribute.h" #include "eval/public/cel_function.h" #include "eval/public/cel_value.h" #include "eval/public/cel_value_producer.h" +#include "runtime/internal/attribute_matcher.h" +#include "google/protobuf/arena.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace cel::runtime_internal { +class ActivationAttributeMatcherAccess; +} -// Base class for an activation. -class BaseActivation { - public: - BaseActivation() = default; - - // Non-copyable/non-assignable - BaseActivation(const BaseActivation&) = delete; - BaseActivation& operator=(const BaseActivation&) = delete; - - // Return a list of function overloads for the given name. - virtual std::vector FindFunctionOverloads( - absl::string_view) const = 0; - - // Provide the value that is bound to the name, if found. - // arena parameter is provided to support the case when we want to pass the - // ownership of returned object ( Message/List/Map ) to Evaluator. - virtual absl::optional FindValue(absl::string_view, - google::protobuf::Arena*) const = 0; - - // Check whether a select path is unknown. - virtual bool IsPathUnknown(absl::string_view) const { return false; } - - // Return FieldMask defining the list of unknown paths. - virtual const google::protobuf::FieldMask& unknown_paths() const { - return google::protobuf::FieldMask::default_instance(); - } - - // Return the collection of attribute patterns that determine missing - // attributes. - virtual const std::vector& missing_attribute_patterns() - const { - static const std::vector empty; - return empty; - } - - // Return the collection of attribute patterns that determine "unknown" - // values. - virtual const std::vector& unknown_attribute_patterns() - const { - static const std::vector empty; - return empty; - } - - virtual ~BaseActivation() {} -}; +namespace google::api::expr::runtime { // Instance of Activation class is used by evaluator. // It provides binding between references used in expressions @@ -76,6 +35,10 @@ class Activation : public BaseActivation { Activation(const Activation&) = delete; Activation& operator=(const Activation&) = delete; + // Move-constructible/move-assignable + Activation(Activation&& other) = default; + Activation& operator=(Activation&& other) = default; + // BaseActivation std::vector FindFunctionOverloads( absl::string_view name) const override; @@ -83,11 +46,6 @@ class Activation : public BaseActivation { absl::optional FindValue(absl::string_view name, google::protobuf::Arena* arena) const override; - bool IsPathUnknown(absl::string_view path) const override { - return google::protobuf::util::FieldMaskUtil::IsPathInFieldMask(std::string(path), - unknown_paths_); - } - // Insert a function into the activation (ie a lazily bound function). Returns // a status if the name and shape of the function matches another one that has // already been bound. @@ -115,23 +73,15 @@ class Activation : public BaseActivation { // cleared. int ClearCachedValues(); - // Set unknown value paths through FieldMask - void set_unknown_paths(google::protobuf::FieldMask mask) { - unknown_paths_ = std::move(mask); - } - - // Set error paths through FieldMask + // Set missing attribute patterns for evaluation. + // + // If a field access is found to match any of the provided patterns, the + // result is treated as a missing attribute error. void set_missing_attribute_patterns( std::vector missing_attribute_patterns) { missing_attribute_patterns_ = std::move(missing_attribute_patterns); } - // Return FieldMask defining the list of unknown paths. - const google::protobuf::FieldMask& unknown_paths() const override { - return unknown_paths_; - } - - // Return FieldMask defining the list of unknown paths. const std::vector& missing_attribute_patterns() const override { return missing_attribute_patterns_; @@ -185,19 +135,36 @@ class Activation : public BaseActivation { std::unique_ptr producer_; }; + friend class cel::runtime_internal::ActivationAttributeMatcherAccess; + + void SetAttributeMatcher( + const cel::runtime_internal::AttributeMatcher* matcher) { + attribute_matcher_ = matcher; + } + + void SetAttributeMatcher( + std::unique_ptr matcher) { + owned_attribute_matcher_ = std::move(matcher); + attribute_matcher_ = owned_attribute_matcher_.get(); + } + + const cel::runtime_internal::AttributeMatcher* absl_nullable + GetAttributeMatcher() const override { + return attribute_matcher_; + } + absl::flat_hash_map value_map_; absl::flat_hash_map>> function_map_; - // TODO(issues/41) deprecate when unknowns support is done. - google::protobuf::FieldMask unknown_paths_; std::vector missing_attribute_patterns_; std::vector unknown_attribute_patterns_; + + const cel::runtime_internal::AttributeMatcher* attribute_matcher_ = nullptr; + std::unique_ptr + owned_attribute_matcher_; }; -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_ACTIVATION_H_ diff --git a/eval/public/activation_bind_helper.cc b/eval/public/activation_bind_helper.cc index a5bede00d..1e8004003 100644 --- a/eval/public/activation_bind_helper.cc +++ b/eval/public/activation_bind_helper.cc @@ -1,5 +1,6 @@ #include "eval/public/activation_bind_helper.h" +#include "absl/status/status.h" #include "eval/public/containers/field_access.h" #include "eval/public/containers/field_backed_list_impl.h" #include "eval/public/containers/field_backed_map_impl.h" @@ -37,6 +38,13 @@ absl::Status CreateValueFromField(const google::protobuf::Message* msg, absl::Status BindProtoToActivation(const Message* message, Arena* arena, Activation* activation, ProtoUnsetFieldOptions options) { + // If we need to bind any types that are backed by an arena allocation, we + // will cause a memory leak. + if (arena == nullptr) { + return absl::InvalidArgumentError( + "arena must not be null for BindProtoToActivation."); + } + // TODO(issues/24): Improve the utilities to bind dynamic values as well. const Descriptor* desc = message->GetDescriptor(); const google::protobuf::Reflection* reflection = message->GetReflection(); diff --git a/eval/public/activation_bind_helper.h b/eval/public/activation_bind_helper.h index 2154b91ec..fe5828f12 100644 --- a/eval/public/activation_bind_helper.h +++ b/eval/public/activation_bind_helper.h @@ -17,7 +17,8 @@ enum class ProtoUnsetFieldOptions { }; // Utility method, that takes a protobuf Message and interprets it as a -// namespace, binding its fields to Activation. +// namespace, binding its fields to Activation. |arena| must be non-null. +// // Field names and values become respective names and values of parameters // bound to the Activation object. // Example: @@ -33,7 +34,7 @@ enum class ProtoUnsetFieldOptions { // person.set_name("John Doe"); // person.age(42); // -// RETURN_IF_ERROR(BindProtoToActivation(&person, &arena, &activation)); +// CEL_RETURN_IF_ERROR(BindProtoToActivation(&person, &arena, &activation)); // // After this snippet, activation will have two parameters bound: // "name", with string value of "John Doe" diff --git a/eval/public/activation_bind_helper_test.cc b/eval/public/activation_bind_helper_test.cc index d28653ac2..f644f1669 100644 --- a/eval/public/activation_bind_helper_test.cc +++ b/eval/public/activation_bind_helper_test.cc @@ -1,11 +1,11 @@ #include "eval/public/activation_bind_helper.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" +#include "absl/status/status.h" #include "eval/public/activation.h" #include "eval/testutil/test_message.pb.h" +#include "internal/status_macros.h" +#include "internal/testing.h" #include "testutil/util.h" -#include "base/status_macros.h" namespace google { namespace api { @@ -122,9 +122,20 @@ TEST(ActivationBindHelperTest, TestBindDefaultFields) { result = activation.FindValue("message_value", &arena); ASSERT_TRUE(result.has_value()); - EXPECT_NE(nullptr, result.value().MessageOrDie()); + EXPECT_NE(nullptr, result->MessageOrDie()); EXPECT_THAT(TestMessage::default_instance(), - EqualsProto(*result.value().MessageOrDie())); + EqualsProto(*result->MessageOrDie())); +} + +TEST(ActivationBindHelperTest, RejectsNullArena) { + TestMessage message; + message.set_bool_value(true); + + Activation activation; + + ASSERT_EQ(BindProtoToActivation(&message, /*arena=*/nullptr, &activation), + absl::InvalidArgumentError( + "arena must not be null for BindProtoToActivation.")); } } // namespace diff --git a/eval/public/activation_test.cc b/eval/public/activation_test.cc index 97b21dd88..238caf45e 100644 --- a/eval/public/activation_test.cc +++ b/eval/public/activation_test.cc @@ -1,13 +1,17 @@ #include "eval/public/activation.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" +#include +#include +#include + #include "eval/eval/attribute_trail.h" #include "eval/eval/ident_step.h" #include "eval/public/cel_attribute.h" #include "eval/public/cel_function.h" +#include "extensions/protobuf/memory_manager.h" +#include "internal/status_macros.h" +#include "internal/testing.h" #include "parser/parser.h" -#include "base/status_macros.h" namespace google { namespace api { @@ -16,21 +20,23 @@ namespace runtime { namespace { -using google::api::expr::v1alpha1::Expr; +using ::absl_testing::StatusIs; +using ::cel::extensions::ProtoMemoryManager; +using ::cel::expr::Expr; using ::google::protobuf::Arena; -using testing::ElementsAre; -using testing::Eq; -using testing::HasSubstr; -using testing::IsEmpty; -using testing::Property; -using testing::Return; +using ::testing::ElementsAre; +using ::testing::Eq; +using ::testing::HasSubstr; +using ::testing::IsEmpty; +using ::testing::Property; +using ::testing::Return; class MockValueProducer : public CelValueProducer { public: MOCK_METHOD(CelValue, Produce, (Arena*), (override)); }; -// Simple function that takes no args and returns an int64_t. +// Simple function that takes no args and returns an int64. class ConstCelFunction : public CelFunction { public: explicit ConstCelFunction(absl::string_view name) @@ -76,7 +82,7 @@ TEST(ActivationTest, CheckValueInsertFindAndRemove) { TEST(ActivationTest, CheckValueProducerInsertFindAndRemove) { const std::string kValue = "42"; - auto producer = absl::make_unique(); + auto producer = std::make_unique(); google::protobuf::Arena arena; @@ -113,9 +119,8 @@ TEST(ActivationTest, CheckValueProducerInsertFindAndRemove) { TEST(ActivationTest, CheckInsertFunction) { Activation activation; - auto insert_status = activation.InsertFunction( - std::make_unique("ConstFunc")); - EXPECT_OK(insert_status); + ASSERT_OK(activation.InsertFunction( + std::make_unique("ConstFunc"))); auto overloads = activation.FindFunctionOverloads("ConstFunc"); EXPECT_THAT(overloads, @@ -123,26 +128,20 @@ TEST(ActivationTest, CheckInsertFunction) { &CelFunction::descriptor, Property(&CelFunctionDescriptor::name, Eq("ConstFunc"))))); - absl::Status status = activation.InsertFunction( - std::make_unique("ConstFunc")); - - EXPECT_THAT(std::string(status.message()), - HasSubstr("Function with same shape")); + EXPECT_THAT(activation.InsertFunction( + std::make_unique("ConstFunc")), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Function with same shape"))); - overloads = activation.FindFunctionOverloads("ConstFunc0"); - - EXPECT_THAT(overloads, IsEmpty()); + EXPECT_THAT(activation.FindFunctionOverloads("ConstFunc0"), IsEmpty()); } TEST(ActivationTest, CheckRemoveFunction) { Activation activation; - auto insert_status = - activation.InsertFunction(std::make_unique( - CelFunctionDescriptor{"ConstFunc", false, {CelValue::Type::kInt64}})); - EXPECT_OK(insert_status); - insert_status = activation.InsertFunction(std::make_unique( - CelFunctionDescriptor{"ConstFunc", false, {CelValue::Type::kUint64}})); - EXPECT_OK(insert_status); + ASSERT_OK(activation.InsertFunction(std::make_unique( + CelFunctionDescriptor{"ConstFunc", false, {CelValue::Type::kInt64}}))); + EXPECT_OK(activation.InsertFunction(std::make_unique( + CelFunctionDescriptor{"ConstFunc", false, {CelValue::Type::kUint64}}))); auto overloads = activation.FindFunctionOverloads("ConstFunc"); EXPECT_THAT( @@ -156,16 +155,15 @@ TEST(ActivationTest, CheckRemoveFunction) { EXPECT_TRUE(activation.RemoveFunctionEntries( {"ConstFunc", false, {CelValue::Type::kAny}})); - overloads = activation.FindFunctionOverloads("ConstFunc"); - EXPECT_THAT(overloads, IsEmpty()); + EXPECT_THAT(activation.FindFunctionOverloads("ConstFunc"), IsEmpty()); } TEST(ActivationTest, CheckValueProducerClear) { const std::string kValue1 = "42"; const std::string kValue2 = "43"; - auto producer1 = absl::make_unique(); - auto producer2 = absl::make_unique(); + auto producer1 = std::make_unique(); + auto producer2 = std::make_unique(); google::protobuf::Arena arena; @@ -183,7 +181,7 @@ TEST(ActivationTest, CheckValueProducerClear) { // Produce first value auto opt_value = activation.FindValue("value42", &arena); EXPECT_TRUE(opt_value.has_value()); - EXPECT_THAT(opt_value.value().StringOrDie().value(), Eq(kValue1)); + EXPECT_THAT(opt_value->StringOrDie().value(), Eq(kValue1)); // Test clearing bound value EXPECT_TRUE(activation.ClearValueEntry("value42")); @@ -192,7 +190,7 @@ TEST(ActivationTest, CheckValueProducerClear) { // Produce second value auto opt_value2 = activation.FindValue("value43", &arena); EXPECT_TRUE(opt_value2.has_value()); - EXPECT_THAT(opt_value2.value().StringOrDie().value(), Eq(kValue2)); + EXPECT_THAT(opt_value2->StringOrDie().value(), Eq(kValue2)); // Clear all values EXPECT_EQ(1, activation.ClearCachedValues()); @@ -202,13 +200,12 @@ TEST(ActivationTest, CheckValueProducerClear) { // Produce first value again auto opt_value3 = activation.FindValue("value42", &arena); EXPECT_TRUE(opt_value3.has_value()); - EXPECT_THAT(opt_value3.value().StringOrDie().value(), Eq(kValue1)); + EXPECT_THAT(opt_value3->StringOrDie().value(), Eq(kValue1)); EXPECT_EQ(1, activation.ClearCachedValues()); } TEST(ActivationTest, ErrorPathTest) { Activation activation; - Arena arena; Expr expr; auto* select_expr = expr.mutable_select_expr(); @@ -219,19 +216,19 @@ TEST(ActivationTest, ErrorPathTest) { const CelAttributePattern destination_ip_pattern( "destination", - {CelAttributeQualifierPattern::Create(CelValue::CreateStringView("ip"))}); + {CreateCelAttributeQualifierPattern(CelValue::CreateStringView("ip"))}); - AttributeTrail trail(*ident_expr, &arena); - trail = trail.Step( - CelAttributeQualifier::Create(CelValue::CreateStringView("ip")), &arena); + AttributeTrail trail("destination"); + trail = + trail.Step(CreateCelAttributeQualifier(CelValue::CreateStringView("ip"))); - ASSERT_EQ(destination_ip_pattern.IsMatch(*trail.attribute()), + ASSERT_EQ(destination_ip_pattern.IsMatch(trail.attribute()), CelAttributePattern::MatchType::FULL); EXPECT_TRUE(activation.missing_attribute_patterns().empty()); activation.set_missing_attribute_patterns({destination_ip_pattern}); EXPECT_EQ( - activation.missing_attribute_patterns()[0].IsMatch(*trail.attribute()), + activation.missing_attribute_patterns()[0].IsMatch(trail.attribute()), CelAttributePattern::MatchType::FULL); } diff --git a/eval/public/ast_rewrite.cc b/eval/public/ast_rewrite.cc new file mode 100644 index 000000000..87c667eb5 --- /dev/null +++ b/eval/public/ast_rewrite.cc @@ -0,0 +1,391 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/ast_rewrite.h" + +#include +#include + +#include "cel/expr/syntax.pb.h" +#include "absl/log/absl_log.h" +#include "absl/types/variant.h" +#include "eval/public/ast_visitor.h" +#include "eval/public/source_position.h" + +namespace google::api::expr::runtime { + +using cel::expr::Expr; +using cel::expr::SourceInfo; +using Ident = cel::expr::Expr::Ident; +using Select = cel::expr::Expr::Select; +using Call = cel::expr::Expr::Call; +using CreateList = cel::expr::Expr::CreateList; +using CreateStruct = cel::expr::Expr::CreateStruct; +using Comprehension = cel::expr::Expr::Comprehension; + +namespace { + +struct ArgRecord { + // Not null. + Expr* expr; + // Not null. + const SourceInfo* source_info; + + // For records that are direct arguments to call, we need to call + // the CallArg visitor immediately after the argument is evaluated. + const Expr* calling_expr; + int call_arg; +}; + +struct ComprehensionRecord { + // Not null. + Expr* expr; + // Not null. + const SourceInfo* source_info; + + const Comprehension* comprehension; + const Expr* comprehension_expr; + ComprehensionArg comprehension_arg; + bool use_comprehension_callbacks; +}; + +struct ExprRecord { + // Not null. + Expr* expr; + // Not null. + const SourceInfo* source_info; +}; + +using StackRecordKind = + std::variant; + +struct StackRecord { + public: + ABSL_ATTRIBUTE_UNUSED static constexpr int kNotCallArg = -1; + static constexpr int kTarget = -2; + + StackRecord(Expr* e, const SourceInfo* info) { + ExprRecord record; + record.expr = e; + record.source_info = info; + record_variant = record; + } + + StackRecord(Expr* e, const SourceInfo* info, Comprehension* comprehension, + Expr* comprehension_expr, ComprehensionArg comprehension_arg, + bool use_comprehension_callbacks) { + if (use_comprehension_callbacks) { + ComprehensionRecord record; + record.expr = e; + record.source_info = info; + record.comprehension = comprehension; + record.comprehension_expr = comprehension_expr; + record.comprehension_arg = comprehension_arg; + record.use_comprehension_callbacks = use_comprehension_callbacks; + record_variant = record; + return; + } + ArgRecord record; + record.expr = e; + record.source_info = info; + record.calling_expr = comprehension_expr; + record.call_arg = comprehension_arg; + record_variant = record; + } + + StackRecord(Expr* e, const SourceInfo* info, const Expr* call, int argnum) { + ArgRecord record; + record.expr = e; + record.source_info = info; + record.calling_expr = call; + record.call_arg = argnum; + record_variant = record; + } + + Expr* expr() const { return absl::get(record_variant).expr; } + + const SourceInfo* source_info() const { + return absl::get(record_variant).source_info; + } + + bool IsExprRecord() const { + return absl::holds_alternative(record_variant); + } + + StackRecordKind record_variant; + bool visited = false; +}; + +struct PreVisitor { + void operator()(const ExprRecord& record) { + Expr* expr = record.expr; + const SourcePosition position(expr->id(), record.source_info); + visitor->PreVisitExpr(expr, &position); + switch (expr->expr_kind_case()) { + case Expr::kSelectExpr: + visitor->PreVisitSelect(&expr->select_expr(), expr, &position); + break; + case Expr::kCallExpr: + visitor->PreVisitCall(&expr->call_expr(), expr, &position); + break; + case Expr::kComprehensionExpr: + visitor->PreVisitComprehension(&expr->comprehension_expr(), expr, + &position); + break; + default: + // No pre-visit action. + break; + } + } + + // Do nothing for Arg variant. + void operator()(const ArgRecord&) {} + + void operator()(const ComprehensionRecord& record) { + Expr* expr = record.expr; + const SourcePosition position(expr->id(), record.source_info); + visitor->PreVisitComprehensionSubexpression( + expr, record.comprehension, record.comprehension_arg, &position); + } + + AstVisitor* visitor; +}; + +void PreVisit(const StackRecord& record, AstVisitor* visitor) { + absl::visit(PreVisitor{visitor}, record.record_variant); +} + +struct PostVisitor { + void operator()(const ExprRecord& record) { + Expr* expr = record.expr; + const SourcePosition position(expr->id(), record.source_info); + switch (expr->expr_kind_case()) { + case Expr::kConstExpr: + visitor->PostVisitConst(&expr->const_expr(), expr, &position); + break; + case Expr::kIdentExpr: + visitor->PostVisitIdent(&expr->ident_expr(), expr, &position); + break; + case Expr::kSelectExpr: + visitor->PostVisitSelect(&expr->select_expr(), expr, &position); + break; + case Expr::kCallExpr: + visitor->PostVisitCall(&expr->call_expr(), expr, &position); + break; + case Expr::kListExpr: + visitor->PostVisitCreateList(&expr->list_expr(), expr, &position); + break; + case Expr::kStructExpr: + visitor->PostVisitCreateStruct(&expr->struct_expr(), expr, &position); + break; + case Expr::kComprehensionExpr: + visitor->PostVisitComprehension(&expr->comprehension_expr(), expr, + &position); + break; + case Expr::EXPR_KIND_NOT_SET: + break; + default: + ABSL_LOG(ERROR) << "Unsupported Expr kind: " << expr->expr_kind_case(); + } + + visitor->PostVisitExpr(expr, &position); + } + + void operator()(const ArgRecord& record) { + Expr* expr = record.expr; + const SourcePosition position(expr->id(), record.source_info); + if (record.call_arg == StackRecord::kTarget) { + visitor->PostVisitTarget(record.calling_expr, &position); + } else { + visitor->PostVisitArg(record.call_arg, record.calling_expr, &position); + } + } + + void operator()(const ComprehensionRecord& record) { + Expr* expr = record.expr; + const SourcePosition position(expr->id(), record.source_info); + visitor->PostVisitComprehensionSubexpression( + expr, record.comprehension, record.comprehension_arg, &position); + } + + AstVisitor* visitor; +}; + +void PostVisit(const StackRecord& record, AstVisitor* visitor) { + absl::visit(PostVisitor{visitor}, record.record_variant); +} + +void PushSelectDeps(Select* select_expr, const SourceInfo* source_info, + std::stack* stack) { + if (select_expr->has_operand()) { + stack->push(StackRecord(select_expr->mutable_operand(), source_info)); + } +} + +void PushCallDeps(Call* call_expr, Expr* expr, const SourceInfo* source_info, + std::stack* stack) { + const int arg_size = call_expr->args_size(); + // Our contract is that we visit arguments in order. To do that, we need + // to push them onto the stack in reverse order. + for (int i = arg_size - 1; i >= 0; --i) { + stack->push(StackRecord(call_expr->mutable_args(i), source_info, expr, i)); + } + // Are we receiver-style? + if (call_expr->has_target()) { + stack->push(StackRecord(call_expr->mutable_target(), source_info, expr, + StackRecord::kTarget)); + } +} + +void PushListDeps(CreateList* list_expr, const SourceInfo* source_info, + std::stack* stack) { + auto& elements = *list_expr->mutable_elements(); + for (auto it = elements.rbegin(); it != elements.rend(); ++it) { + auto& element = *it; + stack->push(StackRecord(&element, source_info)); + } +} + +void PushStructDeps(CreateStruct* struct_expr, const SourceInfo* source_info, + std::stack* stack) { + auto& entries = *struct_expr->mutable_entries(); + for (auto it = entries.rbegin(); it != entries.rend(); ++it) { + auto& entry = *it; + // The contract is to visit key, then value. So put them on the stack + // in the opposite order. + if (entry.has_value()) { + stack->push(StackRecord(entry.mutable_value(), source_info)); + } + + if (entry.has_map_key()) { + stack->push(StackRecord(entry.mutable_map_key(), source_info)); + } + } +} + +void PushComprehensionDeps(Comprehension* c, Expr* expr, + const SourceInfo* source_info, + std::stack* stack, + bool use_comprehension_callbacks) { + StackRecord iter_range(c->mutable_iter_range(), source_info, c, expr, + ITER_RANGE, use_comprehension_callbacks); + StackRecord accu_init(c->mutable_accu_init(), source_info, c, expr, ACCU_INIT, + use_comprehension_callbacks); + StackRecord loop_condition(c->mutable_loop_condition(), source_info, c, expr, + LOOP_CONDITION, use_comprehension_callbacks); + StackRecord loop_step(c->mutable_loop_step(), source_info, c, expr, LOOP_STEP, + use_comprehension_callbacks); + StackRecord result(c->mutable_result(), source_info, c, expr, RESULT, + use_comprehension_callbacks); + // Push them in reverse order. + stack->push(result); + stack->push(loop_step); + stack->push(loop_condition); + stack->push(accu_init); + stack->push(iter_range); +} + +struct PushDepsVisitor { + void operator()(const ExprRecord& record) { + Expr* expr = record.expr; + switch (expr->expr_kind_case()) { + case Expr::kSelectExpr: + PushSelectDeps(expr->mutable_select_expr(), record.source_info, &stack); + break; + case Expr::kCallExpr: + PushCallDeps(expr->mutable_call_expr(), expr, record.source_info, + &stack); + break; + case Expr::kListExpr: + PushListDeps(expr->mutable_list_expr(), record.source_info, &stack); + break; + case Expr::kStructExpr: + PushStructDeps(expr->mutable_struct_expr(), record.source_info, &stack); + break; + case Expr::kComprehensionExpr: + PushComprehensionDeps(expr->mutable_comprehension_expr(), expr, + record.source_info, &stack, + options.use_comprehension_callbacks); + break; + default: + break; + } + } + + void operator()(const ArgRecord& record) { + stack.push(StackRecord(record.expr, record.source_info)); + } + + void operator()(const ComprehensionRecord& record) { + stack.push(StackRecord(record.expr, record.source_info)); + } + + std::stack& stack; + const RewriteTraversalOptions& options; +}; + +void PushDependencies(const StackRecord& record, std::stack& stack, + const RewriteTraversalOptions& options) { + absl::visit(PushDepsVisitor{stack, options}, record.record_variant); +} + +} // namespace + +bool AstRewrite(Expr* expr, const SourceInfo* source_info, + AstRewriter* visitor) { + return AstRewrite(expr, source_info, visitor, RewriteTraversalOptions{}); +} + +bool AstRewrite(Expr* expr, const SourceInfo* source_info, AstRewriter* visitor, + RewriteTraversalOptions options) { + std::stack stack; + std::vector traversal_path; + + stack.push(StackRecord(expr, source_info)); + bool rewritten = false; + + while (!stack.empty()) { + StackRecord& record = stack.top(); + if (!record.visited) { + if (record.IsExprRecord()) { + traversal_path.push_back(record.expr()); + visitor->TraversalStackUpdate(absl::MakeSpan(traversal_path)); + + SourcePosition pos(record.expr()->id(), record.source_info()); + if (visitor->PreVisitRewrite(record.expr(), &pos)) { + rewritten = true; + } + } + PreVisit(record, visitor); + PushDependencies(record, stack, options); + record.visited = true; + } else { + PostVisit(record, visitor); + if (record.IsExprRecord()) { + SourcePosition pos(record.expr()->id(), record.source_info()); + if (visitor->PostVisitRewrite(record.expr(), &pos)) { + rewritten = true; + } + + traversal_path.pop_back(); + visitor->TraversalStackUpdate(absl::MakeSpan(traversal_path)); + } + stack.pop(); + } + } + + return rewritten; +} + +} // namespace google::api::expr::runtime diff --git a/eval/public/ast_rewrite.h b/eval/public/ast_rewrite.h new file mode 100644 index 000000000..791778c4f --- /dev/null +++ b/eval/public/ast_rewrite.h @@ -0,0 +1,175 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_AST_REWRITE_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_AST_REWRITE_H_ + +#include "cel/expr/syntax.pb.h" +#include "absl/types/span.h" +#include "eval/public/ast_visitor.h" + +namespace google::api::expr::runtime { + +// Traversal options for AstRewrite. +struct RewriteTraversalOptions { + // If enabled, use comprehension specific callbacks instead of the general + // arguments callbacks. + bool use_comprehension_callbacks; + + RewriteTraversalOptions() : use_comprehension_callbacks(false) {} +}; + +// Interface for AST rewriters. +// Extends AstVisitor interface with update methods. +// see AstRewrite for more details on usage. +class AstRewriter : public AstVisitor { + public: + ~AstRewriter() override {} + + // Rewrite a sub expression before visiting. + // Occurs before visiting Expr. If expr is modified, the new value will be + // visited. + virtual bool PreVisitRewrite(cel::expr::Expr* expr, + const SourcePosition* position) = 0; + + // Rewrite a sub expression after visiting. + // Occurs after visiting expr and it's children. If expr is modified, the old + // sub expression is visited. + virtual bool PostVisitRewrite(cel::expr::Expr* expr, + const SourcePosition* position) = 0; + + // Notify the visitor of updates to the traversal stack. + virtual void TraversalStackUpdate( + absl::Span path) = 0; +}; + +// Trivial implementation for AST rewriters. +// Virtual methods are overridden with no-op callbacks. +class AstRewriterBase : public AstRewriter { + public: + ~AstRewriterBase() override {} + + void PreVisitExpr(const cel::expr::Expr*, + const SourcePosition*) override {} + + void PostVisitExpr(const cel::expr::Expr*, + const SourcePosition*) override {} + + void PostVisitConst(const cel::expr::Constant*, + const cel::expr::Expr*, + const SourcePosition*) override {} + + void PostVisitIdent(const cel::expr::Expr::Ident*, + const cel::expr::Expr*, + const SourcePosition*) override {} + + void PostVisitSelect(const cel::expr::Expr::Select*, + const cel::expr::Expr*, + const SourcePosition*) override {} + + void PreVisitCall(const cel::expr::Expr::Call*, + const cel::expr::Expr*, + const SourcePosition*) override {} + + void PostVisitCall(const cel::expr::Expr::Call*, + const cel::expr::Expr*, + const SourcePosition*) override {} + + void PreVisitComprehension(const cel::expr::Expr::Comprehension*, + const cel::expr::Expr*, + const SourcePosition*) override {} + + void PostVisitComprehension(const cel::expr::Expr::Comprehension*, + const cel::expr::Expr*, + const SourcePosition*) override {} + + void PostVisitArg(int, const cel::expr::Expr*, + const SourcePosition*) override {} + + void PostVisitTarget(const cel::expr::Expr*, + const SourcePosition*) override {} + + void PostVisitCreateList(const cel::expr::Expr::CreateList*, + const cel::expr::Expr*, + const SourcePosition*) override {} + + void PostVisitCreateStruct(const cel::expr::Expr::CreateStruct*, + const cel::expr::Expr*, + const SourcePosition*) override {} + + bool PreVisitRewrite(cel::expr::Expr* expr, + const SourcePosition* position) override { + return false; + } + + bool PostVisitRewrite(cel::expr::Expr* expr, + const SourcePosition* position) override { + return false; + } + + void TraversalStackUpdate( + absl::Span path) override {} +}; + +// Traverses the AST representation in an expr proto. Returns true if any +// rewrites occur. +// +// Rewrites may happen before and/or after visiting an expr subtree. If a +// change happens during the pre-visit rewrite, the updated subtree will be +// visited. If a change happens during the post-visit rewrite, the old subtree +// will be visited. +// +// expr: root node of the tree. +// source_info: optional additional parse information about the expression +// visitor: the callback object that receives the visitation notifications +// options: options for traversal. see RewriteTraversalOptions. Defaults are +// used if not sepecified. +// +// Traversal order follows the pattern: +// PreVisitRewrite +// PreVisitExpr +// ..PreVisit{ExprKind} +// ....PreVisit{ArgumentIndex} +// .......PreVisitExpr (subtree) +// .......PostVisitExpr (subtree) +// ....PostVisit{ArgumentIndex} +// ..PostVisit{ExprKind} +// PostVisitExpr +// PostVisitRewrite +// +// Example callback order for fn(1, var): +// PreVisitExpr +// ..PreVisitCall(fn) +// ......PreVisitExpr +// ........PostVisitConst(1) +// ......PostVisitExpr +// ....PostVisitArg(fn, 0) +// ......PreVisitExpr +// ........PostVisitIdent(var) +// ......PostVisitExpr +// ....PostVisitArg(fn, 1) +// ..PostVisitCall(fn) +// PostVisitExpr + +bool AstRewrite(cel::expr::Expr* expr, + const cel::expr::SourceInfo* source_info, + AstRewriter* visitor); + +bool AstRewrite(cel::expr::Expr* expr, + const cel::expr::SourceInfo* source_info, + AstRewriter* visitor, RewriteTraversalOptions options); + +} // namespace google::api::expr::runtime + +#endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_AST_REWRITE_H_ diff --git a/eval/public/ast_rewrite_test.cc b/eval/public/ast_rewrite_test.cc new file mode 100644 index 000000000..b2ee8d13c --- /dev/null +++ b/eval/public/ast_rewrite_test.cc @@ -0,0 +1,600 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/ast_rewrite.h" + +#include +#include + +#include "cel/expr/syntax.pb.h" +#include "eval/public/ast_visitor.h" +#include "eval/public/source_position.h" +#include "internal/testing.h" +#include "parser/parser.h" +#include "testutil/util.h" + +namespace google::api::expr::runtime { + +namespace { + +using ::cel::expr::Constant; +using ::cel::expr::Expr; +using ::cel::expr::ParsedExpr; +using ::cel::expr::SourceInfo; +using ::testing::_; +using ::testing::ElementsAre; +using ::testing::InSequence; + +using Ident = cel::expr::Expr::Ident; +using Select = cel::expr::Expr::Select; +using Call = cel::expr::Expr::Call; +using CreateList = cel::expr::Expr::CreateList; +using CreateStruct = cel::expr::Expr::CreateStruct; +using Comprehension = cel::expr::Expr::Comprehension; + +class MockAstRewriter : public AstRewriter { + public: + // Expr handler. + MOCK_METHOD(void, PreVisitExpr, + (const Expr* expr, const SourcePosition* position), (override)); + + // Expr handler. + MOCK_METHOD(void, PostVisitExpr, + (const Expr* expr, const SourcePosition* position), (override)); + + MOCK_METHOD(void, PostVisitConst, + (const Constant* const_expr, const Expr* expr, + const SourcePosition* position), + (override)); + + // Ident node handler. + MOCK_METHOD(void, PostVisitIdent, + (const Ident* ident_expr, const Expr* expr, + const SourcePosition* position), + (override)); + + // Select node handler group + MOCK_METHOD(void, PreVisitSelect, + (const Select* select_expr, const Expr* expr, + const SourcePosition* position), + (override)); + + MOCK_METHOD(void, PostVisitSelect, + (const Select* select_expr, const Expr* expr, + const SourcePosition* position), + (override)); + + // Call node handler group + MOCK_METHOD(void, PreVisitCall, + (const Call* call_expr, const Expr* expr, + const SourcePosition* position), + (override)); + MOCK_METHOD(void, PostVisitCall, + (const Call* call_expr, const Expr* expr, + const SourcePosition* position), + (override)); + + // Comprehension node handler group + MOCK_METHOD(void, PreVisitComprehension, + (const Comprehension* comprehension_expr, const Expr* expr, + const SourcePosition* position), + (override)); + MOCK_METHOD(void, PostVisitComprehension, + (const Comprehension* comprehension_expr, const Expr* expr, + const SourcePosition* position), + (override)); + + // Comprehension node handler group + MOCK_METHOD(void, PreVisitComprehensionSubexpression, + (const Expr* expr, const Comprehension* comprehension_expr, + ComprehensionArg comprehension_arg, + const SourcePosition* position), + (override)); + MOCK_METHOD(void, PostVisitComprehensionSubexpression, + (const Expr* expr, const Comprehension* comprehension_expr, + ComprehensionArg comprehension_arg, + const SourcePosition* position), + (override)); + + // We provide finer granularity for Call and Comprehension node callbacks + // to allow special handling for short-circuiting. + MOCK_METHOD(void, PostVisitTarget, + (const Expr* expr, const SourcePosition* position), (override)); + MOCK_METHOD(void, PostVisitArg, + (int arg_num, const Expr* expr, const SourcePosition* position), + (override)); + + // CreateList node handler group + MOCK_METHOD(void, PostVisitCreateList, + (const CreateList* list_expr, const Expr* expr, + const SourcePosition* position), + (override)); + + // CreateStruct node handler group + MOCK_METHOD(void, PostVisitCreateStruct, + (const CreateStruct* struct_expr, const Expr* expr, + const SourcePosition* position), + (override)); + + MOCK_METHOD(bool, PreVisitRewrite, + (Expr * expr, const SourcePosition* position), (override)); + + MOCK_METHOD(bool, PostVisitRewrite, + (Expr * expr, const SourcePosition* position), (override)); + + MOCK_METHOD(void, TraversalStackUpdate, (absl::Span path), + (override)); +}; + +TEST(AstCrawlerTest, CheckCrawlConstant) { + SourceInfo source_info; + MockAstRewriter handler; + + Expr expr; + auto const_expr = expr.mutable_const_expr(); + + EXPECT_CALL(handler, PostVisitConst(const_expr, &expr, _)).Times(1); + + AstRewrite(&expr, &source_info, &handler); +} + +TEST(AstCrawlerTest, CheckCrawlIdent) { + SourceInfo source_info; + MockAstRewriter handler; + + Expr expr; + auto ident_expr = expr.mutable_ident_expr(); + + EXPECT_CALL(handler, PostVisitIdent(ident_expr, &expr, _)).Times(1); + + AstRewrite(&expr, &source_info, &handler); +} + +// Test handling of Select node when operand is not set. +TEST(AstCrawlerTest, CheckCrawlSelectNotCrashingPostVisitAbsentOperand) { + SourceInfo source_info; + MockAstRewriter handler; + + Expr expr; + auto select_expr = expr.mutable_select_expr(); + + // Lowest level entry will be called first + EXPECT_CALL(handler, PostVisitSelect(select_expr, &expr, _)).Times(1); + + AstRewrite(&expr, &source_info, &handler); +} + +// Test handling of Select node +TEST(AstCrawlerTest, CheckCrawlSelect) { + SourceInfo source_info; + MockAstRewriter handler; + + Expr expr; + auto select_expr = expr.mutable_select_expr(); + auto operand = select_expr->mutable_operand(); + auto ident_expr = operand->mutable_ident_expr(); + + testing::InSequence seq; + + // Lowest level entry will be called first + EXPECT_CALL(handler, PostVisitIdent(ident_expr, operand, _)).Times(1); + EXPECT_CALL(handler, PostVisitSelect(select_expr, &expr, _)).Times(1); + + AstRewrite(&expr, &source_info, &handler); +} + +// Test handling of Call node without receiver +TEST(AstCrawlerTest, CheckCrawlCallNoReceiver) { + SourceInfo source_info; + MockAstRewriter handler; + + // (, ) + Expr expr; + auto* call_expr = expr.mutable_call_expr(); + Expr* arg0 = call_expr->add_args(); + auto* const_expr = arg0->mutable_const_expr(); + Expr* arg1 = call_expr->add_args(); + auto* ident_expr = arg1->mutable_ident_expr(); + + testing::InSequence seq; + + // Lowest level entry will be called first + EXPECT_CALL(handler, PreVisitCall(call_expr, &expr, _)).Times(1); + EXPECT_CALL(handler, PostVisitTarget(_, _)).Times(0); + + // Arg0 + EXPECT_CALL(handler, PostVisitConst(const_expr, arg0, _)).Times(1); + EXPECT_CALL(handler, PostVisitExpr(arg0, _)).Times(1); + EXPECT_CALL(handler, PostVisitArg(0, &expr, _)).Times(1); + + // Arg1 + EXPECT_CALL(handler, PostVisitIdent(ident_expr, arg1, _)).Times(1); + EXPECT_CALL(handler, PostVisitExpr(arg1, _)).Times(1); + EXPECT_CALL(handler, PostVisitArg(1, &expr, _)).Times(1); + + // Back to call + EXPECT_CALL(handler, PostVisitCall(call_expr, &expr, _)).Times(1); + EXPECT_CALL(handler, PostVisitExpr(&expr, _)).Times(1); + + AstRewrite(&expr, &source_info, &handler); +} + +// Test handling of Call node with receiver +TEST(AstCrawlerTest, CheckCrawlCallReceiver) { + SourceInfo source_info; + MockAstRewriter handler; + + // .(, ) + Expr expr; + auto* call_expr = expr.mutable_call_expr(); + Expr* target = call_expr->mutable_target(); + auto* target_ident = target->mutable_ident_expr(); + Expr* arg0 = call_expr->add_args(); + auto* const_expr = arg0->mutable_const_expr(); + Expr* arg1 = call_expr->add_args(); + auto* ident_expr = arg1->mutable_ident_expr(); + + testing::InSequence seq; + + // Lowest level entry will be called first + EXPECT_CALL(handler, PreVisitCall(call_expr, &expr, _)).Times(1); + + // Target + EXPECT_CALL(handler, PostVisitIdent(target_ident, target, _)).Times(1); + EXPECT_CALL(handler, PostVisitExpr(target, _)).Times(1); + EXPECT_CALL(handler, PostVisitTarget(&expr, _)).Times(1); + + // Arg0 + EXPECT_CALL(handler, PostVisitConst(const_expr, arg0, _)).Times(1); + EXPECT_CALL(handler, PostVisitExpr(arg0, _)).Times(1); + EXPECT_CALL(handler, PostVisitArg(0, &expr, _)).Times(1); + + // Arg1 + EXPECT_CALL(handler, PostVisitIdent(ident_expr, arg1, _)).Times(1); + EXPECT_CALL(handler, PostVisitExpr(arg1, _)).Times(1); + EXPECT_CALL(handler, PostVisitArg(1, &expr, _)).Times(1); + + // Back to call + EXPECT_CALL(handler, PostVisitCall(call_expr, &expr, _)).Times(1); + EXPECT_CALL(handler, PostVisitExpr(&expr, _)).Times(1); + + AstRewrite(&expr, &source_info, &handler); +} + +// Test handling of Comprehension node +TEST(AstCrawlerTest, CheckCrawlComprehension) { + SourceInfo source_info; + MockAstRewriter handler; + + Expr expr; + auto c = expr.mutable_comprehension_expr(); + auto iter_range = c->mutable_iter_range(); + auto iter_range_expr = iter_range->mutable_const_expr(); + auto accu_init = c->mutable_accu_init(); + auto accu_init_expr = accu_init->mutable_ident_expr(); + auto loop_condition = c->mutable_loop_condition(); + auto loop_condition_expr = loop_condition->mutable_const_expr(); + auto loop_step = c->mutable_loop_step(); + auto loop_step_expr = loop_step->mutable_ident_expr(); + auto result = c->mutable_result(); + auto result_expr = result->mutable_const_expr(); + + testing::InSequence seq; + + // Lowest level entry will be called first + EXPECT_CALL(handler, PreVisitComprehension(c, &expr, _)).Times(1); + + EXPECT_CALL(handler, + PreVisitComprehensionSubexpression(iter_range, c, ITER_RANGE, _)) + .Times(1); + EXPECT_CALL(handler, PostVisitConst(iter_range_expr, iter_range, _)).Times(1); + EXPECT_CALL(handler, + PostVisitComprehensionSubexpression(iter_range, c, ITER_RANGE, _)) + .Times(1); + + // ACCU_INIT + EXPECT_CALL(handler, + PreVisitComprehensionSubexpression(accu_init, c, ACCU_INIT, _)) + .Times(1); + EXPECT_CALL(handler, PostVisitIdent(accu_init_expr, accu_init, _)).Times(1); + EXPECT_CALL(handler, + PostVisitComprehensionSubexpression(accu_init, c, ACCU_INIT, _)) + .Times(1); + + // LOOP CONDITION + EXPECT_CALL(handler, PreVisitComprehensionSubexpression(loop_condition, c, + LOOP_CONDITION, _)) + .Times(1); + EXPECT_CALL(handler, PostVisitConst(loop_condition_expr, loop_condition, _)) + .Times(1); + EXPECT_CALL(handler, PostVisitComprehensionSubexpression(loop_condition, c, + LOOP_CONDITION, _)) + .Times(1); + + // LOOP STEP + EXPECT_CALL(handler, + PreVisitComprehensionSubexpression(loop_step, c, LOOP_STEP, _)) + .Times(1); + EXPECT_CALL(handler, PostVisitIdent(loop_step_expr, loop_step, _)).Times(1); + EXPECT_CALL(handler, + PostVisitComprehensionSubexpression(loop_step, c, LOOP_STEP, _)) + .Times(1); + + // RESULT + EXPECT_CALL(handler, PreVisitComprehensionSubexpression(result, c, RESULT, _)) + .Times(1); + + EXPECT_CALL(handler, PostVisitConst(result_expr, result, _)).Times(1); + + EXPECT_CALL(handler, + PostVisitComprehensionSubexpression(result, c, RESULT, _)) + .Times(1); + + EXPECT_CALL(handler, PostVisitComprehension(c, &expr, _)).Times(1); + + RewriteTraversalOptions opts; + opts.use_comprehension_callbacks = true; + AstRewrite(&expr, &source_info, &handler, opts); +} + +// Test handling of Comprehension node +TEST(AstCrawlerTest, CheckCrawlComprehensionLegacyCallbacks) { + SourceInfo source_info; + MockAstRewriter handler; + + Expr expr; + auto c = expr.mutable_comprehension_expr(); + auto iter_range = c->mutable_iter_range(); + auto iter_range_expr = iter_range->mutable_const_expr(); + auto accu_init = c->mutable_accu_init(); + auto accu_init_expr = accu_init->mutable_ident_expr(); + auto loop_condition = c->mutable_loop_condition(); + auto loop_condition_expr = loop_condition->mutable_const_expr(); + auto loop_step = c->mutable_loop_step(); + auto loop_step_expr = loop_step->mutable_ident_expr(); + auto result = c->mutable_result(); + auto result_expr = result->mutable_const_expr(); + + testing::InSequence seq; + + // Lowest level entry will be called first + EXPECT_CALL(handler, PreVisitComprehension(c, &expr, _)).Times(1); + + EXPECT_CALL(handler, PostVisitConst(iter_range_expr, iter_range, _)).Times(1); + EXPECT_CALL(handler, PostVisitArg(ITER_RANGE, &expr, _)).Times(1); + + // ACCU_INIT + EXPECT_CALL(handler, PostVisitIdent(accu_init_expr, accu_init, _)).Times(1); + EXPECT_CALL(handler, PostVisitArg(ACCU_INIT, &expr, _)).Times(1); + + // LOOP CONDITION + EXPECT_CALL(handler, PostVisitConst(loop_condition_expr, loop_condition, _)) + .Times(1); + EXPECT_CALL(handler, PostVisitArg(LOOP_CONDITION, &expr, _)).Times(1); + + // LOOP STEP + EXPECT_CALL(handler, PostVisitIdent(loop_step_expr, loop_step, _)).Times(1); + EXPECT_CALL(handler, PostVisitArg(LOOP_STEP, &expr, _)).Times(1); + + // RESULT + EXPECT_CALL(handler, PostVisitConst(result_expr, result, _)).Times(1); + EXPECT_CALL(handler, PostVisitArg(RESULT, &expr, _)).Times(1); + + EXPECT_CALL(handler, PostVisitComprehension(c, &expr, _)).Times(1); + + AstRewrite(&expr, &source_info, &handler); +} + +// Test handling of CreateList node. +TEST(AstCrawlerTest, CheckCreateList) { + SourceInfo source_info; + MockAstRewriter handler; + + Expr expr; + auto list_expr = expr.mutable_list_expr(); + auto arg0 = list_expr->add_elements(); + auto const_expr = arg0->mutable_const_expr(); + auto arg1 = list_expr->add_elements(); + auto ident_expr = arg1->mutable_ident_expr(); + + testing::InSequence seq; + + EXPECT_CALL(handler, PostVisitConst(const_expr, arg0, _)).Times(1); + EXPECT_CALL(handler, PostVisitIdent(ident_expr, arg1, _)).Times(1); + EXPECT_CALL(handler, PostVisitCreateList(list_expr, &expr, _)).Times(1); + + AstRewrite(&expr, &source_info, &handler); +} + +// Test handling of CreateStruct node. +TEST(AstCrawlerTest, CheckCreateStruct) { + SourceInfo source_info; + MockAstRewriter handler; + + Expr expr; + auto struct_expr = expr.mutable_struct_expr(); + auto entry0 = struct_expr->add_entries(); + + auto key = entry0->mutable_map_key()->mutable_const_expr(); + auto value = entry0->mutable_value()->mutable_ident_expr(); + + testing::InSequence seq; + + EXPECT_CALL(handler, PostVisitConst(key, &entry0->map_key(), _)).Times(1); + EXPECT_CALL(handler, PostVisitIdent(value, &entry0->value(), _)).Times(1); + EXPECT_CALL(handler, PostVisitCreateStruct(struct_expr, &expr, _)).Times(1); + + AstRewrite(&expr, &source_info, &handler); +} + +// Test generic Expr handlers. +TEST(AstCrawlerTest, CheckExprHandlers) { + SourceInfo source_info; + MockAstRewriter handler; + + Expr expr; + auto struct_expr = expr.mutable_struct_expr(); + auto entry0 = struct_expr->add_entries(); + + entry0->mutable_map_key()->mutable_const_expr(); + entry0->mutable_value()->mutable_ident_expr(); + + EXPECT_CALL(handler, PreVisitExpr(_, _)).Times(3); + EXPECT_CALL(handler, PostVisitExpr(_, _)).Times(3); + + AstRewrite(&expr, &source_info, &handler); +} + +// Test generic Expr handlers. +TEST(AstCrawlerTest, CheckExprRewriteHandlers) { + SourceInfo source_info; + MockAstRewriter handler; + + Expr select_expr; + select_expr.mutable_select_expr()->set_field("var"); + auto* inner_select_expr = + select_expr.mutable_select_expr()->mutable_operand(); + inner_select_expr->mutable_select_expr()->set_field("mid"); + auto* ident = inner_select_expr->mutable_select_expr()->mutable_operand(); + ident->mutable_ident_expr()->set_name("top"); + + { + InSequence sequence; + EXPECT_CALL(handler, + TraversalStackUpdate(testing::ElementsAre(&select_expr))); + EXPECT_CALL(handler, PreVisitRewrite(&select_expr, _)); + + EXPECT_CALL(handler, TraversalStackUpdate(testing::ElementsAre( + &select_expr, inner_select_expr))); + EXPECT_CALL(handler, PreVisitRewrite(inner_select_expr, _)); + + EXPECT_CALL(handler, TraversalStackUpdate(testing::ElementsAre( + &select_expr, inner_select_expr, ident))); + EXPECT_CALL(handler, PreVisitRewrite(ident, _)); + + EXPECT_CALL(handler, PostVisitRewrite(ident, _)); + EXPECT_CALL(handler, TraversalStackUpdate(testing::ElementsAre( + &select_expr, inner_select_expr))); + + EXPECT_CALL(handler, PostVisitRewrite(inner_select_expr, _)); + EXPECT_CALL(handler, + TraversalStackUpdate(testing::ElementsAre(&select_expr))); + + EXPECT_CALL(handler, PostVisitRewrite(&select_expr, _)); + EXPECT_CALL(handler, TraversalStackUpdate(testing::IsEmpty())); + } + + EXPECT_FALSE(AstRewrite(&select_expr, &source_info, &handler)); +} + +// Simple rewrite that replaces a select path with a dot-qualified identifier. +class RewriterExample : public AstRewriterBase { + public: + RewriterExample() {} + bool PostVisitRewrite(Expr* expr, const SourcePosition* info) override { + if (target_.has_value() && expr->id() == *target_) { + expr->mutable_ident_expr()->set_name("com.google.Identifier"); + return true; + } + return false; + } + + void PostVisitIdent(const Ident* ident, const Expr* expr, + const SourcePosition* pos) override { + if (path_.size() >= 3) { + if (ident->name() == "com") { + const Expr* p1 = path_.at(path_.size() - 2); + const Expr* p2 = path_.at(path_.size() - 3); + + if (p1->has_select_expr() && p1->select_expr().field() == "google" && + p2->has_select_expr() && + p2->select_expr().field() == "Identifier") { + target_ = p2->id(); + } + } + } + } + + void TraversalStackUpdate(absl::Span path) override { + path_ = path; + } + + private: + absl::Span path_; + absl::optional target_; +}; + +TEST(AstRewrite, SelectRewriteExample) { + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed, + parser::Parse("com.google.Identifier")); + RewriterExample example; + ASSERT_TRUE( + AstRewrite(parsed.mutable_expr(), &parsed.source_info(), &example)); + + EXPECT_THAT(parsed.expr(), testutil::EqualsProto(R"pb( + id: 3 + ident_expr { name: "com.google.Identifier" } + )pb")); +} + +// Rewrites x -> y -> z to demonstrate traversal when a node is rewritten on +// both passes. +class PreRewriterExample : public AstRewriterBase { + public: + PreRewriterExample() {} + bool PreVisitRewrite(Expr* expr, const SourcePosition* info) override { + if (expr->ident_expr().name() == "x") { + expr->mutable_ident_expr()->set_name("y"); + return true; + } + return false; + } + + bool PostVisitRewrite(Expr* expr, const SourcePosition* info) override { + if (expr->ident_expr().name() == "y") { + expr->mutable_ident_expr()->set_name("z"); + return true; + } + return false; + } + + void PostVisitIdent(const Ident* ident, const Expr* expr, + const SourcePosition* pos) override { + visited_idents_.push_back(ident->name()); + } + + const std::vector& visited_idents() const { + return visited_idents_; + } + + private: + std::vector visited_idents_; +}; + +TEST(AstRewrite, PreAndPostVisitExpample) { + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed, parser::Parse("x")); + PreRewriterExample visitor; + ASSERT_TRUE( + AstRewrite(parsed.mutable_expr(), &parsed.source_info(), &visitor)); + + EXPECT_THAT(parsed.expr(), testutil::EqualsProto(R"pb( + id: 1 + ident_expr { name: "z" } + )pb")); + EXPECT_THAT(visitor.visited_idents(), ElementsAre("y")); +} + +} // namespace + +} // namespace google::api::expr::runtime diff --git a/eval/public/ast_traverse.cc b/eval/public/ast_traverse.cc index c87e5fc88..c18b806b9 100644 --- a/eval/public/ast_traverse.cc +++ b/eval/public/ast_traverse.cc @@ -12,172 +12,285 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "eval/public/ast_traverse.h" + #include -#include "google/api/expr/v1alpha1/syntax.pb.h" -#include "eval/public/ast_traverse.h" +#include "cel/expr/syntax.pb.h" +#include "absl/log/absl_log.h" +#include "absl/types/variant.h" +#include "eval/public/ast_visitor.h" #include "eval/public/source_position.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace google::api::expr::runtime { -using google::api::expr::v1alpha1::Expr; -using google::api::expr::v1alpha1::SourceInfo; -using Ident = google::api::expr::v1alpha1::Expr::Ident; -using Select = google::api::expr::v1alpha1::Expr::Select; -using Call = google::api::expr::v1alpha1::Expr::Call; -using CreateList = google::api::expr::v1alpha1::Expr::CreateList; -using CreateStruct = google::api::expr::v1alpha1::Expr::CreateStruct; -using Comprehension = google::api::expr::v1alpha1::Expr::Comprehension; +using cel::expr::Expr; +using cel::expr::SourceInfo; +using Ident = cel::expr::Expr::Ident; +using Select = cel::expr::Expr::Select; +using Call = cel::expr::Expr::Call; +using CreateList = cel::expr::Expr::CreateList; +using CreateStruct = cel::expr::Expr::CreateStruct; +using Comprehension = cel::expr::Expr::Comprehension; namespace { -struct StackRecord { - public: - static constexpr int kNotCallArg = -1; - - StackRecord(const Expr *e, const SourceInfo *info) - : expr(e), - source_info(info), - visited(false), - calling_expr(nullptr), - call_arg(kNotCallArg) {} - - StackRecord(const Expr *e, const SourceInfo *info, const Expr *call, - int argnum) - : expr(e), - source_info(info), - visited(false), - calling_expr(call), - call_arg(argnum) {} - - const Expr *expr; - const SourceInfo *source_info; - bool visited; +struct ArgRecord { + // Not null. + const Expr* expr; + // Not null. + const SourceInfo* source_info; // For records that are direct arguments to call, we need to call // the CallArg visitor immediately after the argument is evaluated. - const Expr *calling_expr; + const Expr* calling_expr; int call_arg; }; -void PreVisit(const StackRecord &record, AstVisitor *visitor) { - const Expr *expr = record.expr; - const SourcePosition position(expr->id(), record.source_info); - visitor->PreVisitExpr(expr, &position); - switch (expr->expr_kind_case()) { - case Expr::kSelectExpr: - visitor->PreVisitSelect(&expr->select_expr(), expr, &position); - break; - case Expr::kCallExpr: - visitor->PreVisitCall(&expr->call_expr(), expr, &position); - break; - case Expr::kComprehensionExpr: - visitor->PreVisitComprehension(&expr->comprehension_expr(), expr, - &position); - break; - default: - // No pre-visit action. - break; +struct ComprehensionRecord { + // Not null. + const Expr* expr; + // Not null. + const SourceInfo* source_info; + + const Comprehension* comprehension; + const Expr* comprehension_expr; + ComprehensionArg comprehension_arg; + bool use_comprehension_callbacks; +}; + +struct ExprRecord { + // Not null. + const Expr* expr; + // Not null. + const SourceInfo* source_info; +}; + +using StackRecordKind = + std::variant; + +struct StackRecord { + public: + ABSL_ATTRIBUTE_UNUSED static constexpr int kNotCallArg = -1; + static constexpr int kTarget = -2; + + StackRecord(const Expr* e, const SourceInfo* info) { + ExprRecord record; + record.expr = e; + record.source_info = info; + record_variant = record; + } + + StackRecord(const Expr* e, const SourceInfo* info, + const Comprehension* comprehension, + const Expr* comprehension_expr, + ComprehensionArg comprehension_arg, + bool use_comprehension_callbacks) { + if (use_comprehension_callbacks) { + ComprehensionRecord record; + record.expr = e; + record.source_info = info; + record.comprehension = comprehension; + record.comprehension_expr = comprehension_expr; + record.comprehension_arg = comprehension_arg; + record.use_comprehension_callbacks = use_comprehension_callbacks; + record_variant = record; + return; + } + ArgRecord record; + record.expr = e; + record.source_info = info; + record.calling_expr = comprehension_expr; + record.call_arg = comprehension_arg; + record_variant = record; + } + + StackRecord(const Expr* e, const SourceInfo* info, const Expr* call, + int argnum) { + ArgRecord record; + record.expr = e; + record.source_info = info; + record.calling_expr = call; + record.call_arg = argnum; + record_variant = record; } + StackRecordKind record_variant; + bool visited = false; +}; + +struct PreVisitor { + void operator()(const ExprRecord& record) { + const Expr* expr = record.expr; + const SourcePosition position(expr->id(), record.source_info); + visitor->PreVisitExpr(expr, &position); + switch (expr->expr_kind_case()) { + case Expr::kConstExpr: + visitor->PreVisitConst(&expr->const_expr(), expr, &position); + break; + case Expr::kIdentExpr: + visitor->PreVisitIdent(&expr->ident_expr(), expr, &position); + break; + case Expr::kSelectExpr: + visitor->PreVisitSelect(&expr->select_expr(), expr, &position); + break; + case Expr::kCallExpr: + visitor->PreVisitCall(&expr->call_expr(), expr, &position); + break; + case Expr::kListExpr: + visitor->PreVisitCreateList(&expr->list_expr(), expr, &position); + break; + case Expr::kStructExpr: + visitor->PreVisitCreateStruct(&expr->struct_expr(), expr, &position); + break; + case Expr::kComprehensionExpr: + visitor->PreVisitComprehension(&expr->comprehension_expr(), expr, + &position); + break; + default: + // No pre-visit action. + break; + } + } + + // Do nothing for Arg variant. + void operator()(const ArgRecord&) {} + + void operator()(const ComprehensionRecord& record) { + const Expr* expr = record.expr; + const SourcePosition position(expr->id(), record.source_info); + visitor->PreVisitComprehensionSubexpression( + expr, record.comprehension, record.comprehension_arg, &position); + } + + AstVisitor* visitor; +}; + +void PreVisit(const StackRecord& record, AstVisitor* visitor) { + absl::visit(PreVisitor{visitor}, record.record_variant); } -void PostVisit(const StackRecord &record, AstVisitor *visitor) { - const Expr *expr = record.expr; - const SourcePosition position(expr->id(), record.source_info); - switch (expr->expr_kind_case()) { - case Expr::kConstExpr: - visitor->PostVisitConst(&expr->const_expr(), expr, &position); - break; - case Expr::kIdentExpr: - visitor->PostVisitIdent(&expr->ident_expr(), expr, &position); - break; - case Expr::kSelectExpr: - visitor->PostVisitSelect(&expr->select_expr(), expr, &position); - break; - case Expr::kCallExpr: - visitor->PostVisitCall(&expr->call_expr(), expr, &position); - break; - case Expr::kListExpr: - visitor->PostVisitCreateList(&expr->list_expr(), expr, &position); - break; - case Expr::kStructExpr: - visitor->PostVisitCreateStruct(&expr->struct_expr(), expr, &position); - break; - case Expr::kComprehensionExpr: - visitor->PostVisitComprehension(&expr->comprehension_expr(), expr, - &position); - break; - default: - GOOGLE_LOG(ERROR) << "Unsupported Expr kind: " << expr->expr_kind_case(); - } - - if (record.call_arg != StackRecord::kNotCallArg && - record.calling_expr != nullptr) { - visitor->PostVisitArg(record.call_arg, record.calling_expr, &position); - } - visitor->PostVisitExpr(expr, &position); +struct PostVisitor { + void operator()(const ExprRecord& record) { + const Expr* expr = record.expr; + const SourcePosition position(expr->id(), record.source_info); + switch (expr->expr_kind_case()) { + case Expr::kConstExpr: + visitor->PostVisitConst(&expr->const_expr(), expr, &position); + break; + case Expr::kIdentExpr: + visitor->PostVisitIdent(&expr->ident_expr(), expr, &position); + break; + case Expr::kSelectExpr: + visitor->PostVisitSelect(&expr->select_expr(), expr, &position); + break; + case Expr::kCallExpr: + visitor->PostVisitCall(&expr->call_expr(), expr, &position); + break; + case Expr::kListExpr: + visitor->PostVisitCreateList(&expr->list_expr(), expr, &position); + break; + case Expr::kStructExpr: + visitor->PostVisitCreateStruct(&expr->struct_expr(), expr, &position); + break; + case Expr::kComprehensionExpr: + visitor->PostVisitComprehension(&expr->comprehension_expr(), expr, + &position); + break; + default: + ABSL_LOG(ERROR) << "Unsupported Expr kind: " << expr->expr_kind_case(); + } + + visitor->PostVisitExpr(expr, &position); + } + + void operator()(const ArgRecord& record) { + const Expr* expr = record.expr; + const SourcePosition position(expr->id(), record.source_info); + if (record.call_arg == StackRecord::kTarget) { + visitor->PostVisitTarget(record.calling_expr, &position); + } else { + visitor->PostVisitArg(record.call_arg, record.calling_expr, &position); + } + } + + void operator()(const ComprehensionRecord& record) { + const Expr* expr = record.expr; + const SourcePosition position(expr->id(), record.source_info); + visitor->PostVisitComprehensionSubexpression( + expr, record.comprehension, record.comprehension_arg, &position); + } + + AstVisitor* visitor; +}; + +void PostVisit(const StackRecord& record, AstVisitor* visitor) { + absl::visit(PostVisitor{visitor}, record.record_variant); } -void PushSelectDeps(const Select *select_expr, const StackRecord &record, - std::stack *stack) { +void PushSelectDeps(const Select* select_expr, const SourceInfo* source_info, + std::stack* stack) { if (select_expr->has_operand()) { - stack->push(StackRecord(&select_expr->operand(), record.source_info)); + stack->push(StackRecord(&select_expr->operand(), source_info)); } } -void PushCallDeps(const Call *call_expr, const Expr *expr, - const StackRecord &record, std::stack *stack) { +void PushCallDeps(const Call* call_expr, const Expr* expr, + const SourceInfo* source_info, + std::stack* stack) { const int arg_size = call_expr->args_size(); // Our contract is that we visit arguments in order. To do that, we need // to push them onto the stack in reverse order. for (int i = arg_size - 1; i >= 0; --i) { - stack->push(StackRecord(&call_expr->args(i), record.source_info, expr, i)); + stack->push(StackRecord(&call_expr->args(i), source_info, expr, i)); } // Are we receiver-style? if (call_expr->has_target()) { - stack->push(StackRecord(&call_expr->target(), record.source_info)); + stack->push(StackRecord(&call_expr->target(), source_info, expr, + StackRecord::kTarget)); } } -void PushListDeps(const CreateList *list_expr, const StackRecord &record, - std::stack *stack) { - const auto &elements = list_expr->elements(); +void PushListDeps(const CreateList* list_expr, const SourceInfo* source_info, + std::stack* stack) { + const auto& elements = list_expr->elements(); for (auto it = elements.rbegin(); it != elements.rend(); ++it) { - const auto &element = *it; - stack->push(StackRecord(&element, record.source_info)); + const auto& element = *it; + stack->push(StackRecord(&element, source_info)); } } -void PushStructDeps(const CreateStruct *struct_expr, const StackRecord &record, - std::stack *stack) { - const auto &entries = struct_expr->entries(); +void PushStructDeps(const CreateStruct* struct_expr, + const SourceInfo* source_info, + std::stack* stack) { + const auto& entries = struct_expr->entries(); for (auto it = entries.rbegin(); it != entries.rend(); ++it) { - const auto &entry = *it; + const auto& entry = *it; // The contract is to visit key, then value. So put them on the stack // in the opposite order. if (entry.has_value()) { - stack->push(StackRecord(&entry.value(), record.source_info)); + stack->push(StackRecord(&entry.value(), source_info)); } if (entry.has_map_key()) { - stack->push(StackRecord(&entry.map_key(), record.source_info)); + stack->push(StackRecord(&entry.map_key(), source_info)); } } } -void PushComprehensionDeps(const Comprehension *c, const Expr *expr, - const StackRecord &record, - std::stack *stack) { - const SourceInfo *source_info = record.source_info; - StackRecord iter_range(&c->iter_range(), source_info, expr, ITER_RANGE); - StackRecord accu_init(&c->accu_init(), source_info, expr, ACCU_INIT); - StackRecord loop_condition(&c->loop_condition(), source_info, expr, - LOOP_CONDITION); - StackRecord loop_step(&c->loop_step(), source_info, expr, LOOP_STEP); - StackRecord result(&c->result(), source_info, expr, RESULT); +void PushComprehensionDeps(const Comprehension* c, const Expr* expr, + const SourceInfo* source_info, + std::stack* stack, + bool use_comprehension_callbacks) { + StackRecord iter_range(&c->iter_range(), source_info, c, expr, ITER_RANGE, + use_comprehension_callbacks); + StackRecord accu_init(&c->accu_init(), source_info, c, expr, ACCU_INIT, + use_comprehension_callbacks); + StackRecord loop_condition(&c->loop_condition(), source_info, c, expr, + LOOP_CONDITION, use_comprehension_callbacks); + StackRecord loop_step(&c->loop_step(), source_info, c, expr, LOOP_STEP, + use_comprehension_callbacks); + StackRecord result(&c->result(), source_info, c, expr, RESULT, + use_comprehension_callbacks); // Push them in reverse order. stack->push(result); stack->push(loop_step); @@ -186,42 +299,61 @@ void PushComprehensionDeps(const Comprehension *c, const Expr *expr, stack->push(iter_range); } -void PushDependencies(const StackRecord &record, - std::stack *stack) { - const Expr *expr = record.expr; - switch (expr->expr_kind_case()) { - case Expr::kSelectExpr: - PushSelectDeps(&expr->select_expr(), record, stack); - break; - case Expr::kCallExpr: - PushCallDeps(&expr->call_expr(), expr, record, stack); - break; - case Expr::kListExpr: - PushListDeps(&expr->list_expr(), record, stack); - break; - case Expr::kStructExpr: - PushStructDeps(&expr->struct_expr(), record, stack); - break; - case Expr::kComprehensionExpr: - PushComprehensionDeps(&expr->comprehension_expr(), expr, record, stack); - break; - default: - break; +struct PushDepsVisitor { + void operator()(const ExprRecord& record) { + const Expr* expr = record.expr; + switch (expr->expr_kind_case()) { + case Expr::kSelectExpr: + PushSelectDeps(&expr->select_expr(), record.source_info, &stack); + break; + case Expr::kCallExpr: + PushCallDeps(&expr->call_expr(), expr, record.source_info, &stack); + break; + case Expr::kListExpr: + PushListDeps(&expr->list_expr(), record.source_info, &stack); + break; + case Expr::kStructExpr: + PushStructDeps(&expr->struct_expr(), record.source_info, &stack); + break; + case Expr::kComprehensionExpr: + PushComprehensionDeps(&expr->comprehension_expr(), expr, + record.source_info, &stack, + options.use_comprehension_callbacks); + break; + default: + break; + } + } + + void operator()(const ArgRecord& record) { + stack.push(StackRecord(record.expr, record.source_info)); + } + + void operator()(const ComprehensionRecord& record) { + stack.push(StackRecord(record.expr, record.source_info)); } + + std::stack& stack; + const TraversalOptions& options; +}; + +void PushDependencies(const StackRecord& record, std::stack& stack, + const TraversalOptions& options) { + absl::visit(PushDepsVisitor{stack, options}, record.record_variant); } } // namespace -void AstTraverse(const Expr *expr, const SourceInfo *source_info, - AstVisitor *visitor) { +void AstTraverse(const Expr* expr, const SourceInfo* source_info, + AstVisitor* visitor, TraversalOptions options) { std::stack stack; stack.push(StackRecord(expr, source_info)); while (!stack.empty()) { - StackRecord &record = stack.top(); + StackRecord& record = stack.top(); if (!record.visited) { PreVisit(record, visitor); - PushDependencies(record, &stack); + PushDependencies(record, stack, options); record.visited = true; } else { PostVisit(record, visitor); @@ -230,7 +362,4 @@ void AstTraverse(const Expr *expr, const SourceInfo *source_info, } } -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime diff --git a/eval/public/ast_traverse.h b/eval/public/ast_traverse.h index fe77541bf..f81c6f47a 100644 --- a/eval/public/ast_traverse.h +++ b/eval/public/ast_traverse.h @@ -17,24 +17,51 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_AST_TRAVERSE_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_AST_TRAVERSE_H_ +#include "cel/expr/syntax.pb.h" #include "eval/public/ast_visitor.h" -#include "google/api/expr/v1alpha1/syntax.pb.h" - -namespace google { -namespace api { -namespace expr { -namespace runtime { - -// This method performs traversal of AST. -// expr is root node of the tree. -// handler is callback object. -void AstTraverse(const google::api::expr::v1alpha1::Expr *expr, - const google::api::expr::v1alpha1::SourceInfo *source_info, - AstVisitor *visitor); - -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google + +namespace google::api::expr::runtime { + +struct TraversalOptions { + bool use_comprehension_callbacks; + + TraversalOptions() : use_comprehension_callbacks(false) {} +}; + +// Traverses the AST representation in an expr proto. +// +// expr: root node of the tree. +// source_info: optional additional parse information about the expression +// visitor: the callback object that receives the visitation notifications +// +// Traversal order follows the pattern: +// PreVisitExpr +// ..PreVisit{ExprKind} +// ....PreVisit{ArgumentIndex} +// .......PreVisitExpr (subtree) +// .......PostVisitExpr (subtree) +// ....PostVisit{ArgumentIndex} +// ..PostVisit{ExprKind} +// PostVisitExpr +// +// Example callback order for fn(1, var): +// PreVisitExpr +// ..PreVisitCall(fn) +// ......PreVisitExpr +// ........PostVisitConst(1) +// ......PostVisitExpr +// ....PostVisitArg(fn, 0) +// ......PreVisitExpr +// ........PostVisitIdent(var) +// ......PostVisitExpr +// ....PostVisitArg(fn, 1) +// ..PostVisitCall(fn) +// PostVisitExpr +void AstTraverse(const cel::expr::Expr* expr, + const cel::expr::SourceInfo* source_info, + AstVisitor* visitor, + TraversalOptions options = TraversalOptions()); + +} // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_AST_TRAVERSE_H_ diff --git a/eval/public/ast_traverse_test.cc b/eval/public/ast_traverse_test.cc index da71ea80f..ca6d81b72 100644 --- a/eval/public/ast_traverse_test.cc +++ b/eval/public/ast_traverse_test.cc @@ -14,26 +14,23 @@ #include "eval/public/ast_traverse.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" +#include "eval/public/ast_visitor.h" +#include "internal/testing.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace google::api::expr::runtime { namespace { -using google::api::expr::v1alpha1::Expr; -using google::api::expr::v1alpha1::Constant; -using google::api::expr::v1alpha1::SourceInfo; +using cel::expr::Constant; +using cel::expr::Expr; +using cel::expr::SourceInfo; using testing::_; -using Ident = google::api::expr::v1alpha1::Expr::Ident; -using Select = google::api::expr::v1alpha1::Expr::Select; -using Call = google::api::expr::v1alpha1::Expr::Call; -using CreateList = google::api::expr::v1alpha1::Expr::CreateList; -using CreateStruct = google::api::expr::v1alpha1::Expr::CreateStruct; -using Comprehension = google::api::expr::v1alpha1::Expr::Comprehension; +using Ident = cel::expr::Expr::Ident; +using Select = cel::expr::Expr::Select; +using Call = cel::expr::Expr::Call; +using CreateList = cel::expr::Expr::CreateList; +using CreateStruct = cel::expr::Expr::CreateStruct; +using Comprehension = cel::expr::Expr::Comprehension; class MockAstVisitor : public AstVisitor { public: @@ -45,11 +42,24 @@ class MockAstVisitor : public AstVisitor { MOCK_METHOD(void, PostVisitExpr, (const Expr* expr, const SourcePosition* position), (override)); + // Constant node handler. + MOCK_METHOD(void, PreVisitConst, + (const Constant* const_expr, const Expr* expr, + const SourcePosition* position), + (override)); + + // Constant node handler. MOCK_METHOD(void, PostVisitConst, (const Constant* const_expr, const Expr* expr, const SourcePosition* position), (override)); + // Ident node handler. + MOCK_METHOD(void, PreVisitIdent, + (const Ident* ident_expr, const Expr* expr, + const SourcePosition* position), + (override)); + // Ident node handler. MOCK_METHOD(void, PostVisitIdent, (const Ident* ident_expr, const Expr* expr, @@ -87,18 +97,44 @@ class MockAstVisitor : public AstVisitor { const SourcePosition* position), (override)); + // Comprehension node handler group + MOCK_METHOD(void, PreVisitComprehensionSubexpression, + (const Expr* expr, const Comprehension* comprehension_expr, + ComprehensionArg comprehension_arg, + const SourcePosition* position), + (override)); + MOCK_METHOD(void, PostVisitComprehensionSubexpression, + (const Expr* expr, const Comprehension* comprehension_expr, + ComprehensionArg comprehension_arg, + const SourcePosition* position), + (override)); + // We provide finer granularity for Call and Comprehension node callbacks // to allow special handling for short-circuiting. + MOCK_METHOD(void, PostVisitTarget, + (const Expr* expr, const SourcePosition* position), (override)); MOCK_METHOD(void, PostVisitArg, (int arg_num, const Expr* expr, const SourcePosition* position), (override)); + // CreateList node handler group + MOCK_METHOD(void, PreVisitCreateList, + (const CreateList* list_expr, const Expr* expr, + const SourcePosition* position), + (override)); + // CreateList node handler group MOCK_METHOD(void, PostVisitCreateList, (const CreateList* list_expr, const Expr* expr, const SourcePosition* position), (override)); + // CreateStruct node handler group + MOCK_METHOD(void, PreVisitCreateStruct, + (const CreateStruct* struct_expr, const Expr* expr, + const SourcePosition* position), + (override)); + // CreateStruct node handler group MOCK_METHOD(void, PostVisitCreateStruct, (const CreateStruct* struct_expr, const Expr* expr, @@ -113,6 +149,7 @@ TEST(AstCrawlerTest, CheckCrawlConstant) { Expr expr; auto const_expr = expr.mutable_const_expr(); + EXPECT_CALL(handler, PreVisitConst(const_expr, &expr, _)).Times(1); EXPECT_CALL(handler, PostVisitConst(const_expr, &expr, _)).Times(1); AstTraverse(&expr, &source_info, &handler); @@ -125,6 +162,7 @@ TEST(AstCrawlerTest, CheckCrawlIdent) { Expr expr; auto ident_expr = expr.mutable_ident_expr(); + EXPECT_CALL(handler, PreVisitIdent(ident_expr, &expr, _)).Times(1); EXPECT_CALL(handler, PostVisitIdent(ident_expr, &expr, _)).Times(1); AstTraverse(&expr, &source_info, &handler); @@ -163,27 +201,80 @@ TEST(AstCrawlerTest, CheckCrawlSelect) { AstTraverse(&expr, &source_info, &handler); } -// Test handling of Call node -TEST(AstCrawlerTest, CheckCrawlCall) { +// Test handling of Call node without receiver +TEST(AstCrawlerTest, CheckCrawlCallNoReceiver) { SourceInfo source_info; MockAstVisitor handler; + // (, ) Expr expr; - auto call_expr = expr.mutable_call_expr(); - auto arg0 = call_expr->add_args(); - auto const_expr = arg0->mutable_const_expr(); - auto arg1 = call_expr->add_args(); - auto ident_expr = arg1->mutable_ident_expr(); + auto* call_expr = expr.mutable_call_expr(); + Expr* arg0 = call_expr->add_args(); + auto* const_expr = arg0->mutable_const_expr(); + Expr* arg1 = call_expr->add_args(); + auto* ident_expr = arg1->mutable_ident_expr(); + + testing::InSequence seq; + + // Lowest level entry will be called first + EXPECT_CALL(handler, PreVisitCall(call_expr, &expr, _)).Times(1); + EXPECT_CALL(handler, PostVisitTarget(_, _)).Times(0); + + // Arg0 + EXPECT_CALL(handler, PostVisitConst(const_expr, arg0, _)).Times(1); + EXPECT_CALL(handler, PostVisitExpr(arg0, _)).Times(1); + EXPECT_CALL(handler, PostVisitArg(0, &expr, _)).Times(1); + + // Arg1 + EXPECT_CALL(handler, PostVisitIdent(ident_expr, arg1, _)).Times(1); + EXPECT_CALL(handler, PostVisitExpr(arg1, _)).Times(1); + EXPECT_CALL(handler, PostVisitArg(1, &expr, _)).Times(1); + + // Back to call + EXPECT_CALL(handler, PostVisitCall(call_expr, &expr, _)).Times(1); + EXPECT_CALL(handler, PostVisitExpr(&expr, _)).Times(1); + + AstTraverse(&expr, &source_info, &handler); +} + +// Test handling of Call node with receiver +TEST(AstCrawlerTest, CheckCrawlCallReceiver) { + SourceInfo source_info; + MockAstVisitor handler; + + // .(, ) + Expr expr; + auto* call_expr = expr.mutable_call_expr(); + Expr* target = call_expr->mutable_target(); + auto* target_ident = target->mutable_ident_expr(); + Expr* arg0 = call_expr->add_args(); + auto* const_expr = arg0->mutable_const_expr(); + Expr* arg1 = call_expr->add_args(); + auto* ident_expr = arg1->mutable_ident_expr(); testing::InSequence seq; // Lowest level entry will be called first EXPECT_CALL(handler, PreVisitCall(call_expr, &expr, _)).Times(1); + + // Target + EXPECT_CALL(handler, PostVisitIdent(target_ident, target, _)).Times(1); + EXPECT_CALL(handler, PostVisitExpr(target, _)).Times(1); + EXPECT_CALL(handler, PostVisitTarget(&expr, _)).Times(1); + + // Arg0 EXPECT_CALL(handler, PostVisitConst(const_expr, arg0, _)).Times(1); + EXPECT_CALL(handler, PostVisitExpr(arg0, _)).Times(1); EXPECT_CALL(handler, PostVisitArg(0, &expr, _)).Times(1); + + // Arg1 EXPECT_CALL(handler, PostVisitIdent(ident_expr, arg1, _)).Times(1); + EXPECT_CALL(handler, PostVisitExpr(arg1, _)).Times(1); EXPECT_CALL(handler, PostVisitArg(1, &expr, _)).Times(1); + + // Back to call EXPECT_CALL(handler, PostVisitCall(call_expr, &expr, _)).Times(1); + EXPECT_CALL(handler, PostVisitExpr(&expr, _)).Times(1); AstTraverse(&expr, &source_info, &handler); } @@ -211,15 +302,99 @@ TEST(AstCrawlerTest, CheckCrawlComprehension) { // Lowest level entry will be called first EXPECT_CALL(handler, PreVisitComprehension(c, &expr, _)).Times(1); + EXPECT_CALL(handler, + PreVisitComprehensionSubexpression(iter_range, c, ITER_RANGE, _)) + .Times(1); + EXPECT_CALL(handler, PostVisitConst(iter_range_expr, iter_range, _)).Times(1); + EXPECT_CALL(handler, + PostVisitComprehensionSubexpression(iter_range, c, ITER_RANGE, _)) + .Times(1); + + // ACCU_INIT + EXPECT_CALL(handler, + PreVisitComprehensionSubexpression(accu_init, c, ACCU_INIT, _)) + .Times(1); + EXPECT_CALL(handler, PostVisitIdent(accu_init_expr, accu_init, _)).Times(1); + EXPECT_CALL(handler, + PostVisitComprehensionSubexpression(accu_init, c, ACCU_INIT, _)) + .Times(1); + + // LOOP CONDITION + EXPECT_CALL(handler, PreVisitComprehensionSubexpression(loop_condition, c, + LOOP_CONDITION, _)) + .Times(1); + EXPECT_CALL(handler, PostVisitConst(loop_condition_expr, loop_condition, _)) + .Times(1); + EXPECT_CALL(handler, PostVisitComprehensionSubexpression(loop_condition, c, + LOOP_CONDITION, _)) + .Times(1); + + // LOOP STEP + EXPECT_CALL(handler, + PreVisitComprehensionSubexpression(loop_step, c, LOOP_STEP, _)) + .Times(1); + EXPECT_CALL(handler, PostVisitIdent(loop_step_expr, loop_step, _)).Times(1); + EXPECT_CALL(handler, + PostVisitComprehensionSubexpression(loop_step, c, LOOP_STEP, _)) + .Times(1); + + // RESULT + EXPECT_CALL(handler, PreVisitComprehensionSubexpression(result, c, RESULT, _)) + .Times(1); + + EXPECT_CALL(handler, PostVisitConst(result_expr, result, _)).Times(1); + + EXPECT_CALL(handler, + PostVisitComprehensionSubexpression(result, c, RESULT, _)) + .Times(1); + + EXPECT_CALL(handler, PostVisitComprehension(c, &expr, _)).Times(1); + + TraversalOptions opts; + opts.use_comprehension_callbacks = true; + AstTraverse(&expr, &source_info, &handler, opts); +} + +// Test handling of Comprehension node +TEST(AstCrawlerTest, CheckCrawlComprehensionLegacyCallbacks) { + SourceInfo source_info; + MockAstVisitor handler; + + Expr expr; + auto c = expr.mutable_comprehension_expr(); + auto iter_range = c->mutable_iter_range(); + auto iter_range_expr = iter_range->mutable_const_expr(); + auto accu_init = c->mutable_accu_init(); + auto accu_init_expr = accu_init->mutable_ident_expr(); + auto loop_condition = c->mutable_loop_condition(); + auto loop_condition_expr = loop_condition->mutable_const_expr(); + auto loop_step = c->mutable_loop_step(); + auto loop_step_expr = loop_step->mutable_ident_expr(); + auto result = c->mutable_result(); + auto result_expr = result->mutable_const_expr(); + + testing::InSequence seq; + + // Lowest level entry will be called first + EXPECT_CALL(handler, PreVisitComprehension(c, &expr, _)).Times(1); + EXPECT_CALL(handler, PostVisitConst(iter_range_expr, iter_range, _)).Times(1); EXPECT_CALL(handler, PostVisitArg(ITER_RANGE, &expr, _)).Times(1); + + // ACCU_INIT EXPECT_CALL(handler, PostVisitIdent(accu_init_expr, accu_init, _)).Times(1); EXPECT_CALL(handler, PostVisitArg(ACCU_INIT, &expr, _)).Times(1); + + // LOOP CONDITION EXPECT_CALL(handler, PostVisitConst(loop_condition_expr, loop_condition, _)) .Times(1); EXPECT_CALL(handler, PostVisitArg(LOOP_CONDITION, &expr, _)).Times(1); + + // LOOP STEP EXPECT_CALL(handler, PostVisitIdent(loop_step_expr, loop_step, _)).Times(1); EXPECT_CALL(handler, PostVisitArg(LOOP_STEP, &expr, _)).Times(1); + + // RESULT EXPECT_CALL(handler, PostVisitConst(result_expr, result, _)).Times(1); EXPECT_CALL(handler, PostVisitArg(RESULT, &expr, _)).Times(1); @@ -242,6 +417,7 @@ TEST(AstCrawlerTest, CheckCreateList) { testing::InSequence seq; + EXPECT_CALL(handler, PreVisitCreateList(list_expr, &expr, _)).Times(1); EXPECT_CALL(handler, PostVisitConst(const_expr, arg0, _)).Times(1); EXPECT_CALL(handler, PostVisitIdent(ident_expr, arg1, _)).Times(1); EXPECT_CALL(handler, PostVisitCreateList(list_expr, &expr, _)).Times(1); @@ -263,6 +439,7 @@ TEST(AstCrawlerTest, CheckCreateStruct) { testing::InSequence seq; + EXPECT_CALL(handler, PreVisitCreateStruct(struct_expr, &expr, _)).Times(1); EXPECT_CALL(handler, PostVisitConst(key, &entry0->map_key(), _)).Times(1); EXPECT_CALL(handler, PostVisitIdent(value, &entry0->value(), _)).Times(1); EXPECT_CALL(handler, PostVisitCreateStruct(struct_expr, &expr, _)).Times(1); @@ -290,7 +467,4 @@ TEST(AstCrawlerTest, CheckExprHandlers) { } // namespace -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime diff --git a/eval/public/ast_visitor.h b/eval/public/ast_visitor.h index af143c666..f8185a576 100644 --- a/eval/public/ast_visitor.h +++ b/eval/public/ast_visitor.h @@ -17,8 +17,8 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_AST_VISITOR_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_AST_VISITOR_H_ +#include "cel/expr/syntax.pb.h" #include "eval/public/source_position.h" -#include "google/api/expr/v1alpha1/syntax.pb.h" namespace google { namespace api { @@ -49,82 +49,132 @@ class AstVisitor { // Is invoked before child Expr nodes being processed. // TODO(issues/22): this method is not pure virtual to avoid dependencies // breakage. Change it in subsequent CLs. - virtual void PreVisitExpr(const google::api::expr::v1alpha1::Expr*, + virtual void PreVisitExpr(const cel::expr::Expr*, const SourcePosition*) {} // Expr node handler method. Called for all Expr nodes. // Is invoked after child Expr nodes are processed. // TODO(issues/22): this method is not pure virtual to avoid dependencies // breakage. Change it in subsequent CLs. - virtual void PostVisitExpr(const google::api::expr::v1alpha1::Expr*, + virtual void PostVisitExpr(const cel::expr::Expr*, + const SourcePosition*) {} + + // Const node handler. + // Invoked before child nodes are processed. + // TODO(issues/22): this method is not pure virtual to avoid dependencies + // breakage. Change it in subsequent CLs. + virtual void PreVisitConst(const cel::expr::Constant*, + const cel::expr::Expr*, const SourcePosition*) {} // Const node handler. // Invoked after child nodes are processed. - virtual void PostVisitConst(const google::api::expr::v1alpha1::Constant*, - const google::api::expr::v1alpha1::Expr*, + virtual void PostVisitConst(const cel::expr::Constant*, + const cel::expr::Expr*, const SourcePosition*) = 0; + // Ident node handler. + // Invoked before child nodes are processed. + // TODO(issues/22): this method is not pure virtual to avoid dependencies + // breakage. Change it in subsequent CLs. + virtual void PreVisitIdent(const cel::expr::Expr::Ident*, + const cel::expr::Expr*, + const SourcePosition*) {} + // Ident node handler. // Invoked after child nodes are processed. - virtual void PostVisitIdent(const google::api::expr::v1alpha1::Expr::Ident*, - const google::api::expr::v1alpha1::Expr*, + virtual void PostVisitIdent(const cel::expr::Expr::Ident*, + const cel::expr::Expr*, const SourcePosition*) = 0; // Select node handler // Invoked before child nodes are processed. // TODO(issues/22): this method is not pure virtual to avoid dependencies // breakage. Change it in subsequent CLs. - virtual void PreVisitSelect(const google::api::expr::v1alpha1::Expr::Select*, - const google::api::expr::v1alpha1::Expr*, + virtual void PreVisitSelect(const cel::expr::Expr::Select*, + const cel::expr::Expr*, const SourcePosition*) {} // Select node handler // Invoked after child nodes are processed. - virtual void PostVisitSelect(const google::api::expr::v1alpha1::Expr::Select*, - const google::api::expr::v1alpha1::Expr*, + virtual void PostVisitSelect(const cel::expr::Expr::Select*, + const cel::expr::Expr*, const SourcePosition*) = 0; // Call node handler group // We provide finer granularity for Call node callbacks to allow special // handling for short-circuiting // PreVisitCall is invoked before child nodes are processed. - virtual void PreVisitCall(const google::api::expr::v1alpha1::Expr::Call*, - const google::api::expr::v1alpha1::Expr*, + virtual void PreVisitCall(const cel::expr::Expr::Call*, + const cel::expr::Expr*, const SourcePosition*) = 0; // Invoked after all child nodes are processed. - virtual void PostVisitCall(const google::api::expr::v1alpha1::Expr::Call*, - const google::api::expr::v1alpha1::Expr*, + virtual void PostVisitCall(const cel::expr::Expr::Call*, + const cel::expr::Expr*, const SourcePosition*) = 0; + // Invoked after target node is processed. + // Expr is the call expression. + virtual void PostVisitTarget(const cel::expr::Expr*, + const SourcePosition*) = 0; + // Invoked before all child nodes are processed. virtual void PreVisitComprehension( - const google::api::expr::v1alpha1::Expr::Comprehension*, - const google::api::expr::v1alpha1::Expr*, const SourcePosition*) = 0; + const cel::expr::Expr::Comprehension*, + const cel::expr::Expr*, const SourcePosition*) = 0; + + // Invoked before comprehension child node is processed. + virtual void PreVisitComprehensionSubexpression( + const cel::expr::Expr* subexpr, + const cel::expr::Expr::Comprehension* compr, + ComprehensionArg comprehension_arg, const SourcePosition*) {} + + // Invoked after comprehension child node is processed. + virtual void PostVisitComprehensionSubexpression( + const cel::expr::Expr* subexpr, + const cel::expr::Expr::Comprehension* compr, + ComprehensionArg comprehension_arg, const SourcePosition*) {} // Invoked after all child nodes are processed. virtual void PostVisitComprehension( - const google::api::expr::v1alpha1::Expr::Comprehension*, - const google::api::expr::v1alpha1::Expr*, const SourcePosition*) = 0; + const cel::expr::Expr::Comprehension*, + const cel::expr::Expr*, const SourcePosition*) = 0; // Invoked after each argument node processed. // For Call arg_num is the index of the argument. // For Comprehension arg_num is specified by ComprehensionArg. - virtual void PostVisitArg(int arg_num, const google::api::expr::v1alpha1::Expr*, + // Expr is the call expression. + virtual void PostVisitArg(int arg_num, const cel::expr::Expr*, const SourcePosition*) = 0; + // CreateList node handler + // Invoked before child nodes are processed. + // TODO(issues/22): this method is not pure virtual to avoid dependencies + // breakage. Change it in subsequent CLs. + virtual void PreVisitCreateList(const cel::expr::Expr::CreateList*, + const cel::expr::Expr*, + const SourcePosition*) {} + // CreateList node handler // Invoked after child nodes are processed. - virtual void PostVisitCreateList(const google::api::expr::v1alpha1::Expr::CreateList*, - const google::api::expr::v1alpha1::Expr*, + virtual void PostVisitCreateList(const cel::expr::Expr::CreateList*, + const cel::expr::Expr*, const SourcePosition*) = 0; + // CreateStruct node handler + // Invoked before child nodes are processed. + // TODO(issues/22): this method is not pure virtual to avoid dependencies + // breakage. Change it in subsequent CLs. + virtual void PreVisitCreateStruct( + const cel::expr::Expr::CreateStruct*, + const cel::expr::Expr*, const SourcePosition*) {} + // CreateStruct node handler // Invoked after child nodes are processed. virtual void PostVisitCreateStruct( - const google::api::expr::v1alpha1::Expr::CreateStruct*, - const google::api::expr::v1alpha1::Expr*, const SourcePosition*) = 0; + const cel::expr::Expr::CreateStruct*, + const cel::expr::Expr*, const SourcePosition*) = 0; }; } // namespace runtime diff --git a/eval/public/ast_visitor_base.h b/eval/public/ast_visitor_base.h index d41458e15..df8d8a926 100644 --- a/eval/public/ast_visitor_base.h +++ b/eval/public/ast_visitor_base.h @@ -18,7 +18,7 @@ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_AST_VISITOR_BASE_H_ #include "eval/public/ast_visitor.h" -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" namespace google { namespace api { @@ -38,61 +38,66 @@ class AstVisitorBase : public AstVisitor { // Const node handler. // Invoked after child nodes are processed. - void PostVisitConst(const google::api::expr::v1alpha1::Constant*, - const google::api::expr::v1alpha1::Expr*, + void PostVisitConst(const cel::expr::Constant*, + const cel::expr::Expr*, const SourcePosition*) override {} // Ident node handler. // Invoked after child nodes are processed. - void PostVisitIdent(const google::api::expr::v1alpha1::Expr::Ident*, - const google::api::expr::v1alpha1::Expr*, + void PostVisitIdent(const cel::expr::Expr::Ident*, + const cel::expr::Expr*, const SourcePosition*) override {} // Select node handler // Invoked after child nodes are processed. - void PostVisitSelect(const google::api::expr::v1alpha1::Expr::Select*, - const google::api::expr::v1alpha1::Expr*, + void PostVisitSelect(const cel::expr::Expr::Select*, + const cel::expr::Expr*, const SourcePosition*) override {} // Call node handler group // We provide finer granularity for Call node callbacks to allow special // handling for short-circuiting // PreVisitCall is invoked before child nodes are processed. - void PreVisitCall(const google::api::expr::v1alpha1::Expr::Call*, - const google::api::expr::v1alpha1::Expr*, + void PreVisitCall(const cel::expr::Expr::Call*, + const cel::expr::Expr*, const SourcePosition*) override {} // Invoked after all child nodes are processed. - void PostVisitCall(const google::api::expr::v1alpha1::Expr::Call*, - const google::api::expr::v1alpha1::Expr*, + void PostVisitCall(const cel::expr::Expr::Call*, + const cel::expr::Expr*, const SourcePosition*) override {} // Invoked before all child nodes are processed. - void PreVisitComprehension(const google::api::expr::v1alpha1::Expr::Comprehension*, - const google::api::expr::v1alpha1::Expr*, + void PreVisitComprehension(const cel::expr::Expr::Comprehension*, + const cel::expr::Expr*, const SourcePosition*) override {} // Invoked after all child nodes are processed. - void PostVisitComprehension(const google::api::expr::v1alpha1::Expr::Comprehension*, - const google::api::expr::v1alpha1::Expr*, + void PostVisitComprehension(const cel::expr::Expr::Comprehension*, + const cel::expr::Expr*, const SourcePosition*) override {} // Invoked after each argument node processed. // For Call arg_num is the index of the argument. // For Comprehension arg_num is specified by ComprehensionArg. - void PostVisitArg(int, const google::api::expr::v1alpha1::Expr*, + // Expr is the call expression. + void PostVisitArg(int, const cel::expr::Expr*, const SourcePosition*) override {} + // Invoked after target node processed. + void PostVisitTarget(const cel::expr::Expr*, + const SourcePosition*) override {} + // CreateList node handler // Invoked after child nodes are processed. - void PostVisitCreateList(const google::api::expr::v1alpha1::Expr::CreateList*, - const google::api::expr::v1alpha1::Expr*, + void PostVisitCreateList(const cel::expr::Expr::CreateList*, + const cel::expr::Expr*, const SourcePosition*) override {} // CreateStruct node handler // Invoked after child nodes are processed. - void PostVisitCreateStruct(const google::api::expr::v1alpha1::Expr::CreateStruct*, - const google::api::expr::v1alpha1::Expr*, + void PostVisitCreateStruct(const cel::expr::Expr::CreateStruct*, + const cel::expr::Expr*, const SourcePosition*) override {} }; diff --git a/eval/public/base_activation.h b/eval/public/base_activation.h new file mode 100644 index 000000000..7d9e0a51c --- /dev/null +++ b/eval/public/base_activation.h @@ -0,0 +1,75 @@ +#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_BASE_ACTIVATION_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_BASE_ACTIVATION_H_ + +#include + +#include "google/protobuf/field_mask.pb.h" +#include "absl/base/nullability.h" +#include "absl/strings/string_view.h" +#include "eval/public/cel_attribute.h" +#include "eval/public/cel_function.h" +#include "eval/public/cel_value.h" +#include "runtime/internal/attribute_matcher.h" + +namespace cel::runtime_internal { +class ActivationAttributeMatcherAccess; +} + +namespace google::api::expr::runtime { + +// Base class for an activation. +class BaseActivation { + public: + BaseActivation() = default; + + // Non-copyable/non-assignable + BaseActivation(const BaseActivation&) = delete; + BaseActivation& operator=(const BaseActivation&) = delete; + + // Move-constructible/move-assignable + BaseActivation(BaseActivation&& other) = default; + BaseActivation& operator=(BaseActivation&& other) = default; + + // Return a list of function overloads for the given name. + virtual std::vector FindFunctionOverloads( + absl::string_view) const = 0; + + // Provide the value that is bound to the name, if found. + // arena parameter is provided to support the case when we want to pass the + // ownership of returned object ( Message/List/Map ) to Evaluator. + virtual absl::optional FindValue(absl::string_view, + google::protobuf::Arena*) const = 0; + + // Return the collection of attribute patterns that determine missing + // attributes. + virtual const std::vector& missing_attribute_patterns() + const { + static const std::vector* empty = + new std::vector({}); + return *empty; + } + + // Return the collection of attribute patterns that determine "unknown" + // values. + virtual const std::vector& unknown_attribute_patterns() + const { + static const std::vector* empty = + new std::vector({}); + return *empty; + } + + virtual ~BaseActivation() = default; + + private: + friend class cel::runtime_internal::ActivationAttributeMatcherAccess; + + // Internal getter for overriding the attribute matching behavior. + virtual const cel::runtime_internal::AttributeMatcher* absl_nullable + GetAttributeMatcher() const { + return nullptr; + } +}; + +} // namespace google::api::expr::runtime + +#endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_BASE_ACTIVATION_H_ diff --git a/eval/public/builtin_func_registrar.cc b/eval/public/builtin_func_registrar.cc index 8b8becb14..52bb07c01 100644 --- a/eval/public/builtin_func_registrar.cc +++ b/eval/public/builtin_func_registrar.cc @@ -1,1442 +1,65 @@ -#include "eval/public/builtin_func_registrar.h" +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. -#include -#include +#include "eval/public/builtin_func_registrar.h" -#include "google/protobuf/util/time_util.h" #include "absl/status/status.h" -#include "absl/strings/match.h" -#include "absl/strings/numbers.h" -#include "absl/strings/str_cat.h" -#include "eval/public/cel_builtins.h" -#include "eval/public/cel_function_adapter.h" #include "eval/public/cel_function_registry.h" -#include "eval/public/containers/container_backed_list_impl.h" -#include "re2/re2.h" - -namespace google { -namespace api { -namespace expr { -namespace runtime { - -using google::protobuf::Arena; - -namespace { - -// Comparison template functions -template -CelValue Inequal(Arena*, Type t1, Type t2) { - return CelValue::CreateBool(t1 != t2); -} - -template -CelValue Equal(Arena*, Type t1, Type t2) { - return CelValue::CreateBool(t1 == t2); -} - -// Forward declaration of the generic equality operator -template <> -CelValue Equal(Arena*, CelValue t1, CelValue t2); - -template -bool LessThan(Arena*, Type t1, Type t2) { - return (t1 < t2); -} - -template -bool LessThanOrEqual(Arena*, Type t1, Type t2) { - return (t1 <= t2); -} - -template -bool GreaterThan(Arena* arena, Type t1, Type t2) { - return LessThan(arena, t2, t1); -} - -template -bool GreaterThanOrEqual(Arena* arena, Type t1, Type t2) { - return LessThanOrEqual(arena, t2, t1); -} - -// Duration comparison specializations -template <> -CelValue Inequal(Arena*, absl::Duration t1, absl::Duration t2) { - return CelValue::CreateBool(operator!=(t1, t2)); -} - -template <> -CelValue Equal(Arena*, absl::Duration t1, absl::Duration t2) { - return CelValue::CreateBool(operator==(t1, t2)); -} - -template <> -bool LessThan(Arena*, absl::Duration t1, absl::Duration t2) { - return operator<(t1, t2); -} - -template <> -bool LessThanOrEqual(Arena*, absl::Duration t1, absl::Duration t2) { - return operator<=(t1, t2); -} - -template <> -bool GreaterThan(Arena*, absl::Duration t1, absl::Duration t2) { - return operator>(t1, t2); -} - -template <> -bool GreaterThanOrEqual(Arena*, absl::Duration t1, absl::Duration t2) { - return operator>=(t1, t2); -} - -// Timestamp comparison specializations -template <> -CelValue Inequal(Arena*, absl::Time t1, absl::Time t2) { - return CelValue::CreateBool(operator!=(t1, t2)); -} - -template <> -CelValue Equal(Arena*, absl::Time t1, absl::Time t2) { - return CelValue::CreateBool(operator==(t1, t2)); -} - -template <> -bool LessThan(Arena*, absl::Time t1, absl::Time t2) { - return operator<(t1, t2); -} - -template <> -bool LessThanOrEqual(Arena*, absl::Time t1, absl::Time t2) { - return operator<=(t1, t2); -} - -template <> -bool GreaterThan(Arena*, absl::Time t1, absl::Time t2) { - return operator>(t1, t2); -} - -template <> -bool GreaterThanOrEqual(Arena*, absl::Time t1, absl::Time t2) { - return operator>=(t1, t2); -} - -// Message specializations -template <> -CelValue Inequal(Arena* arena, const google::protobuf::Message* t1, - const google::protobuf::Message* t2) { - if (t1 == nullptr) { - return CelValue::CreateBool(t2 != nullptr); - } - if (t2 == nullptr) { - return CelValue::CreateBool(true); - } - return CreateNoMatchingOverloadError(arena, builtin::kInequal); -} - -template <> -CelValue Equal(Arena* arena, const google::protobuf::Message* t1, - const google::protobuf::Message* t2) { - if (t1 == nullptr) { - return CelValue::CreateBool(t2 == nullptr); - } - if (t2 == nullptr) { - return CelValue::CreateBool(false); - } - return CreateNoMatchingOverloadError(arena, builtin::kEqual); -} - -// Equality specialization for lists -template <> -CelValue Equal(Arena* arena, const CelList* t1, const CelList* t2) { - int index_size = t1->size(); - if (t2->size() != index_size) { - return CelValue::CreateBool(false); - } - - for (int i = 0; i < index_size; i++) { - CelValue e1 = (*t1)[i]; - CelValue e2 = (*t2)[i]; - const CelValue eq = Equal(arena, e1, e2); - if (eq.IsBool()) { - if (!eq.BoolOrDie()) { - return CelValue::CreateBool(false); - } - } else { - // propagate errors - return eq; - } - } - - return CelValue::CreateBool(true); -} - -template <> -CelValue Inequal(Arena* arena, const CelList* t1, const CelList* t2) { - const CelValue eq = Equal(arena, t1, t2); - if (eq.IsBool()) { - return CelValue::CreateBool(!eq.BoolOrDie()); - } - return eq; -} - -// Equality specialization for maps -template <> -CelValue Equal(Arena* arena, const CelMap* t1, const CelMap* t2) { - if (t1->size() != t2->size()) { - return CelValue::CreateBool(false); - } - - const CelList* keys = t1->ListKeys(); - for (int i = 0; i < keys->size(); i++) { - CelValue key = (*keys)[i]; - CelValue v1 = (*t1)[key].value(); - absl::optional v2 = (*t2)[key]; - if (!v2.has_value()) { - return CelValue::CreateBool(false); - } - const CelValue eq = Equal(arena, v1, v2.value()); - if (eq.IsBool()) { - if (!eq.BoolOrDie()) { - return CelValue::CreateBool(false); - } - } else { - // propagate errors - return eq; - } - } - - return CelValue::CreateBool(true); -} - -template <> -CelValue Inequal(Arena* arena, const CelMap* t1, const CelMap* t2) { - const CelValue eq = Equal(arena, t1, t2); - if (eq.IsBool()) { - return CelValue::CreateBool(!eq.BoolOrDie()); - } - return eq; -} - -// Generic equality for CEL values -template <> -CelValue Equal(Arena* arena, CelValue t1, CelValue t2) { - if (t1.type() != t2.type()) { - // This is used to implement inequal for some types so we can't determine - // the function. - return CreateNoMatchingOverloadError(arena); - } - switch (t1.type()) { - case CelValue::Type::kBool: - return Equal(arena, t1.BoolOrDie(), t2.BoolOrDie()); - case CelValue::Type::kInt64: - return Equal(arena, t1.Int64OrDie(), t2.Int64OrDie()); - case CelValue::Type::kUint64: - return Equal(arena, t1.Uint64OrDie(), t2.Uint64OrDie()); - case CelValue::Type::kDouble: - return Equal(arena, t1.DoubleOrDie(), t2.DoubleOrDie()); - case CelValue::Type::kString: - return Equal(arena, t1.StringOrDie(), - t2.StringOrDie()); - case CelValue::Type::kBytes: - return Equal(arena, t1.BytesOrDie(), - t2.BytesOrDie()); - case CelValue::Type::kMessage: - return Equal(arena, t1.MessageOrDie(), - t2.MessageOrDie()); - case CelValue::Type::kDuration: - return Equal(arena, t1.DurationOrDie(), - t2.DurationOrDie()); - case CelValue::Type::kTimestamp: - return Equal(arena, t1.TimestampOrDie(), t2.TimestampOrDie()); - case CelValue::Type::kList: - return Equal(arena, t1.ListOrDie(), t2.ListOrDie()); - case CelValue::Type::kMap: - return Equal(arena, t1.MapOrDie(), t2.MapOrDie()); - default: - break; - } - return CreateNoMatchingOverloadError(arena); -} - -// Helper method -// -// Registers all equality functions for template parameters type. -template -absl::Status RegisterEqualityFunctionsForType(CelFunctionRegistry* registry) { - // Inequality - absl::Status status = - FunctionAdapter::CreateAndRegister( - builtin::kInequal, false, Inequal, registry); - if (!status.ok()) return status; - - // Equality - status = FunctionAdapter::CreateAndRegister( - builtin::kEqual, false, Equal, registry); - return status; -} - -// Registers all comparison functions for template parameter type. -template -absl::Status RegisterComparisonFunctionsForType(CelFunctionRegistry* registry) { - absl::Status status = RegisterEqualityFunctionsForType(registry); - if (!status.ok()) return status; - - // Less than - status = FunctionAdapter::CreateAndRegister( - builtin::kLess, false, LessThan, registry); - if (!status.ok()) return status; - - // Less than or Equal - status = FunctionAdapter::CreateAndRegister( - builtin::kLessOrEqual, false, LessThanOrEqual, registry); - if (!status.ok()) return status; - - // Greater than - status = FunctionAdapter::CreateAndRegister( - builtin::kGreater, false, GreaterThan, registry); - if (!status.ok()) return status; - - // Greater than or Equal - status = FunctionAdapter::CreateAndRegister( - builtin::kGreaterOrEqual, false, GreaterThanOrEqual, registry); - if (!status.ok()) return status; - - return absl::OkStatus(); -} - -// Template functions providing arithmetic operations -template -Type Add(Arena*, Type v0, Type v1) { - return v0 + v1; -} - -template -Type Sub(Arena*, Type v0, Type v1) { - return v0 - v1; -} - -template -Type Mul(Arena*, Type v0, Type v1) { - return v0 * v1; -} - -template -CelValue Div(Arena* arena, Type v0, Type v1); - -// Division operations for integer types should check for -// division by 0 -template <> -CelValue Div(Arena* arena, int64_t v0, int64_t v1) { - // For integral types, zero check is essential, to avoid - // floating pointer exception. - if (v1 == 0) { - // TODO(issues/25) Which code? - return CreateErrorValue(arena, "Division by 0"); - } - return CelValue::CreateInt64(v0 / v1); -} - -// Division operations for integer types should check for -// division by 0 -template <> -CelValue Div(Arena* arena, uint64_t v0, uint64_t v1) { - // For integral types, zero check is essential, to avoid - // floating pointer exception. - if (v1 == 0) { - // TODO(issues/25) Which code? - return CreateErrorValue(arena, "Division by 0"); - } - return CelValue::CreateUint64(v0 / v1); -} - -template <> -CelValue Div(Arena*, double v0, double v1) { - // For double, division will result in +/- inf - return CelValue::CreateDouble(v0 / v1); -} - -// Modulo operation -template -CelValue Modulo(Arena* arena, Type value, Type value2); - -// Modulo operations for integer types should check for -// division by 0 -template <> -CelValue Modulo(Arena* arena, int64_t value, int64_t value2) { - if (value2 == 0) { - return CreateErrorValue(arena, "Modulo by 0"); - } - - return CelValue::CreateInt64(value % value2); -} - -template <> -CelValue Modulo(Arena* arena, uint64_t value, uint64_t value2) { - if (value2 == 0) { - return CreateErrorValue(arena, "Modulo by 0"); - } - - return CelValue::CreateUint64(value % value2); -} - -// Helper method -// Registers all arithmetic functions for template parameter type. -template -absl::Status RegisterArithmeticFunctionsForType(CelFunctionRegistry* registry) { - absl::Status status = FunctionAdapter::CreateAndRegister( - builtin::kAdd, false, Add, registry); - if (!status.ok()) return status; - - status = FunctionAdapter::CreateAndRegister( - builtin::kSubtract, false, Sub, registry); - if (!status.ok()) return status; - - status = FunctionAdapter::CreateAndRegister( - builtin::kMultiply, false, Mul, registry); - if (!status.ok()) return status; - - status = FunctionAdapter::CreateAndRegister( - builtin::kDivide, false, Div, registry); - return status; -} - -template -bool ValueEquals(const CelValue& value, T other); - -template <> -bool ValueEquals(const CelValue& value, bool other) { - return value.IsBool() && (value.BoolOrDie() == other); -} - -template <> -bool ValueEquals(const CelValue& value, int64_t other) { - return value.IsInt64() && (value.Int64OrDie() == other); -} - -template <> -bool ValueEquals(const CelValue& value, uint64_t other) { - return value.IsUint64() && (value.Uint64OrDie() == other); -} - -template <> -bool ValueEquals(const CelValue& value, double other) { - return value.IsDouble() && (value.DoubleOrDie() == other); -} - -template <> -bool ValueEquals(const CelValue& value, CelValue::StringHolder other) { - return value.IsString() && (value.StringOrDie() == other); -} - -template <> -bool ValueEquals(const CelValue& value, CelValue::BytesHolder other) { - return value.IsBytes() && (value.BytesOrDie() == other); -} - -// Template function implementing CEL in() function -template -bool In(Arena*, T value, const CelList* list) { - int index_size = list->size(); - - for (int i = 0; i < index_size; i++) { - CelValue element = (*list)[i]; - - if (ValueEquals(element, value)) { - return true; - } - } - - return false; -} - -// Concatenation for StringHolder type. -CelValue::StringHolder ConcatString(Arena* arena, CelValue::StringHolder value1, - CelValue::StringHolder value2) { - auto concatenated = Arena::Create( - arena, absl::StrCat(value1.value(), value2.value())); - return CelValue::StringHolder(concatenated); -} - -// Concatenation for BytesHolder type. -CelValue::BytesHolder ConcatBytes(Arena* arena, CelValue::BytesHolder value1, - CelValue::BytesHolder value2) { - auto concatenated = Arena::Create( - arena, absl::StrCat(value1.value(), value2.value())); - return CelValue::BytesHolder(concatenated); -} - -// Concatenation for CelList type. -const CelList* ConcatList(Arena* arena, const CelList* value1, - const CelList* value2) { - std::vector joined_values; - - int size1 = value1->size(); - int size2 = value2->size(); - joined_values.reserve(size1 + size2); - - for (int i = 0; i < size1; i++) { - joined_values.push_back((*value1)[i]); - } - for (int i = 0; i < size2; i++) { - joined_values.push_back((*value2)[i]); - } - - auto concatenated = - Arena::Create(arena, joined_values); - return concatenated; -} - -// Timestamp -const absl::Status FindTimeBreakdown(absl::Time timestamp, absl::string_view tz, - absl::TimeZone::CivilInfo* breakdown) { - absl::TimeZone time_zone; - - if (!tz.empty()) { - bool found = absl::LoadTimeZone(std::string(tz), &time_zone); - if (!found) { - return absl::InvalidArgumentError("Invalid timezone"); - } - } - - *breakdown = time_zone.At(timestamp); - return absl::OkStatus(); -} - -CelValue GetTimeBreakdownPart( - Arena* arena, absl::Time timestamp, absl::string_view tz, - const std::function& - extractor_func) { - absl::TimeZone::CivilInfo breakdown; - auto status = FindTimeBreakdown(timestamp, tz, &breakdown); - - if (!status.ok()) { - return CreateErrorValue(arena, status.message()); - } - - return extractor_func(breakdown); -} - -CelValue CreateTimestampFromString(Arena* arena, - CelValue::StringHolder time_str) { - absl::Time ts; - if (!absl::ParseTime(absl::RFC3339_full, std::string(time_str.value()), &ts, - nullptr)) { - return CreateErrorValue(arena, "String to Timestamp conversion failed", - absl::StatusCode::kInvalidArgument); - } - return CelValue::CreateTimestamp(ts); -} - -CelValue GetFullYear(Arena* arena, absl::Time timestamp, absl::string_view tz) { - return GetTimeBreakdownPart( - arena, timestamp, tz, [](const absl::TimeZone::CivilInfo& breakdown) { - return CelValue::CreateInt64(breakdown.cs.year()); - }); -} - -CelValue GetMonth(Arena* arena, absl::Time timestamp, absl::string_view tz) { - return GetTimeBreakdownPart( - arena, timestamp, tz, [](const absl::TimeZone::CivilInfo& breakdown) { - return CelValue::CreateInt64(breakdown.cs.month() - 1); - }); -} - -CelValue GetDayOfYear(Arena* arena, absl::Time timestamp, - absl::string_view tz) { - return GetTimeBreakdownPart( - arena, timestamp, tz, [](const absl::TimeZone::CivilInfo& breakdown) { - return CelValue::CreateInt64( - absl::GetYearDay(absl::CivilDay(breakdown.cs)) - 1); - }); -} - -CelValue GetDayOfMonth(Arena* arena, absl::Time timestamp, - absl::string_view tz) { - return GetTimeBreakdownPart( - arena, timestamp, tz, [](const absl::TimeZone::CivilInfo& breakdown) { - return CelValue::CreateInt64(breakdown.cs.day() - 1); - }); -} - -CelValue GetDate(Arena* arena, absl::Time timestamp, absl::string_view tz) { - return GetTimeBreakdownPart( - arena, timestamp, tz, [](const absl::TimeZone::CivilInfo& breakdown) { - return CelValue::CreateInt64(breakdown.cs.day()); - }); -} - -CelValue GetDayOfWeek(Arena* arena, absl::Time timestamp, - absl::string_view tz) { - return GetTimeBreakdownPart( - arena, timestamp, tz, [](const absl::TimeZone::CivilInfo& breakdown) { - absl::Weekday weekday = absl::GetWeekday(breakdown.cs); - - // get day of week from the date in UTC, zero-based, zero for Sunday, - // based on GetDayOfWeek CEL function definition. - int weekday_num = static_cast(weekday); - weekday_num = (weekday_num == 6) ? 0 : weekday_num + 1; - return CelValue::CreateInt64(weekday_num); - }); -} - -CelValue GetHours(Arena* arena, absl::Time timestamp, absl::string_view tz) { - return GetTimeBreakdownPart( - arena, timestamp, tz, [](const absl::TimeZone::CivilInfo& breakdown) { - return CelValue::CreateInt64(breakdown.cs.hour()); - }); -} - -CelValue GetMinutes(Arena* arena, absl::Time timestamp, absl::string_view tz) { - return GetTimeBreakdownPart( - arena, timestamp, tz, [](const absl::TimeZone::CivilInfo& breakdown) { - return CelValue::CreateInt64(breakdown.cs.minute()); - }); -} - -CelValue GetSeconds(Arena* arena, absl::Time timestamp, absl::string_view tz) { - return GetTimeBreakdownPart( - arena, timestamp, tz, [](const absl::TimeZone::CivilInfo& breakdown) { - return CelValue::CreateInt64(breakdown.cs.second()); - }); -} - -CelValue GetMilliseconds(Arena* arena, absl::Time timestamp, - absl::string_view tz) { - return GetTimeBreakdownPart( - arena, timestamp, tz, [](const absl::TimeZone::CivilInfo& breakdown) { - return CelValue::CreateInt64( - absl::ToInt64Milliseconds(breakdown.subsecond)); - }); -} - -CelValue CreateDurationFromString(Arena* arena, - CelValue::StringHolder dur_str) { - absl::Duration d; - if (!absl::ParseDuration(std::string(dur_str.value()), &d)) { - return CreateErrorValue(arena, "String to Duration conversion failed", - absl::StatusCode::kInvalidArgument); - } - - return CelValue::CreateDuration(d); -} - -CelValue GetHours(Arena*, absl::Duration duration) { - return CelValue::CreateInt64(absl::ToInt64Hours(duration)); -} - -CelValue GetMinutes(Arena*, absl::Duration duration) { - return CelValue::CreateInt64(absl::ToInt64Minutes(duration)); -} - -CelValue GetSeconds(Arena*, absl::Duration duration) { - return CelValue::CreateInt64(absl::ToInt64Seconds(duration)); -} - -CelValue GetMilliseconds(Arena*, absl::Duration duration) { - int64_t millis_per_second = 1000L; - return CelValue::CreateInt64(absl::ToInt64Milliseconds(duration) % - millis_per_second); -} - -bool StringContains(Arena*, CelValue::StringHolder value, - CelValue::StringHolder substr) { - return absl::StrContains(value.value(), substr.value()); -} - -bool StringEndsWith(Arena*, CelValue::StringHolder value, - CelValue::StringHolder suffix) { - return absl::EndsWith(value.value(), suffix.value()); -} - -bool StringStartsWith(Arena*, CelValue::StringHolder value, - CelValue::StringHolder prefix) { - return absl::StartsWith(value.value(), prefix.value()); -} - -absl::Status RegisterComparisonFunctions(CelFunctionRegistry* registry, - const InterpreterOptions& options) { - auto status = RegisterComparisonFunctionsForType(registry); - if (!status.ok()) return status; - - status = RegisterComparisonFunctionsForType(registry); - if (!status.ok()) return status; - - status = RegisterComparisonFunctionsForType(registry); - if (!status.ok()) return status; - - status = RegisterComparisonFunctionsForType(registry); - if (!status.ok()) return status; - - status = RegisterComparisonFunctionsForType(registry); - if (!status.ok()) return status; - - status = RegisterComparisonFunctionsForType(registry); - if (!status.ok()) return status; - - status = RegisterComparisonFunctionsForType(registry); - if (!status.ok()) return status; - - status = RegisterComparisonFunctionsForType(registry); - if (!status.ok()) return status; - - status = RegisterEqualityFunctionsForType(registry); - if (!status.ok()) return status; - - status = RegisterEqualityFunctionsForType(registry); - if (!status.ok()) return status; - - status = RegisterEqualityFunctionsForType(registry); - if (!status.ok()) return status; - - return absl::OkStatus(); -} - -absl::Status RegisterStringFunctions(CelFunctionRegistry* registry, - const InterpreterOptions& options) { - auto status = - FunctionAdapter:: - CreateAndRegister(builtin::kStringContains, false, StringContains, - registry); - if (!status.ok()) return status; - - status = - FunctionAdapter:: - CreateAndRegister(builtin::kStringContains, true, StringContains, - registry); - if (!status.ok()) return status; - - status = - FunctionAdapter:: - CreateAndRegister(builtin::kStringEndsWith, false, StringEndsWith, - registry); - if (!status.ok()) return status; - - status = - FunctionAdapter:: - CreateAndRegister(builtin::kStringEndsWith, true, StringEndsWith, - registry); - if (!status.ok()) return status; - - status = - FunctionAdapter:: - CreateAndRegister(builtin::kStringStartsWith, false, StringStartsWith, - registry); - if (!status.ok()) return status; - - status = - FunctionAdapter:: - CreateAndRegister(builtin::kStringStartsWith, true, StringStartsWith, - registry); - if (!status.ok()) return status; - - return absl::OkStatus(); -} - -absl::Status RegisterTimestampFunctions(CelFunctionRegistry* registry, - const InterpreterOptions& options) { - // Timestamp - // - // timestamp() conversion from string.. - auto status = - FunctionAdapter::CreateAndRegister( - builtin::kTimestamp, false, CreateTimestampFromString, registry); - if (!status.ok()) return status; - - status = FunctionAdapter:: - CreateAndRegister( - builtin::kFullYear, true, - [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) - -> CelValue { return GetFullYear(arena, ts, tz.value()); }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter::CreateAndRegister( - builtin::kFullYear, true, - [](Arena* arena, absl::Time ts) -> CelValue { - return GetFullYear(arena, ts, ""); - }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter:: - CreateAndRegister( - builtin::kMonth, true, - [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) - -> CelValue { return GetMonth(arena, ts, tz.value()); }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter::CreateAndRegister( - builtin::kMonth, true, - [](Arena* arena, absl::Time ts) -> CelValue { - return GetMonth(arena, ts, ""); - }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter:: - CreateAndRegister( - builtin::kDayOfYear, true, - [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) - -> CelValue { return GetDayOfYear(arena, ts, tz.value()); }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter::CreateAndRegister( - builtin::kDayOfYear, true, - [](Arena* arena, absl::Time ts) -> CelValue { - return GetDayOfYear(arena, ts, ""); - }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter:: - CreateAndRegister( - builtin::kDayOfMonth, true, - [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) - -> CelValue { return GetDayOfMonth(arena, ts, tz.value()); }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter::CreateAndRegister( - builtin::kDayOfMonth, true, - [](Arena* arena, absl::Time ts) -> CelValue { - return GetDayOfMonth(arena, ts, ""); - }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter:: - CreateAndRegister( - builtin::kDate, true, - [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) - -> CelValue { return GetDate(arena, ts, tz.value()); }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter::CreateAndRegister( - builtin::kDate, true, - [](Arena* arena, absl::Time ts) -> CelValue { - return GetDate(arena, ts, ""); - }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter:: - CreateAndRegister( - builtin::kDayOfWeek, true, - [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) - -> CelValue { return GetDayOfWeek(arena, ts, tz.value()); }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter::CreateAndRegister( - builtin::kDayOfWeek, true, - [](Arena* arena, absl::Time ts) -> CelValue { - return GetDayOfWeek(arena, ts, ""); - }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter:: - CreateAndRegister( - builtin::kHours, true, - [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) - -> CelValue { return GetHours(arena, ts, tz.value()); }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter::CreateAndRegister( - builtin::kHours, true, - [](Arena* arena, absl::Time ts) -> CelValue { - return GetHours(arena, ts, ""); - }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter:: - CreateAndRegister( - builtin::kMinutes, true, - [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) - -> CelValue { return GetMinutes(arena, ts, tz.value()); }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter::CreateAndRegister( - builtin::kMinutes, true, - [](Arena* arena, absl::Time ts) -> CelValue { - return GetMinutes(arena, ts, ""); - }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter:: - CreateAndRegister( - builtin::kSeconds, true, - [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) - -> CelValue { return GetSeconds(arena, ts, tz.value()); }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter::CreateAndRegister( - builtin::kSeconds, true, - [](Arena* arena, absl::Time ts) -> CelValue { - return GetSeconds(arena, ts, ""); - }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter:: - CreateAndRegister( - builtin::kMilliseconds, true, - [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) - -> CelValue { return GetMilliseconds(arena, ts, tz.value()); }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter::CreateAndRegister( - builtin::kMilliseconds, true, - [](Arena* arena, absl::Time ts) -> CelValue { - return GetMilliseconds(arena, ts, ""); - }, - registry); - if (!status.ok()) return status; - - return absl::OkStatus(); -} - -} // namespace +#include "eval/public/cel_options.h" +#include "internal/status_macros.h" +#include "runtime/function_registry.h" +#include "runtime/runtime_options.h" +#include "runtime/standard/arithmetic_functions.h" +#include "runtime/standard/comparison_functions.h" +#include "runtime/standard/container_functions.h" +#include "runtime/standard/container_membership_functions.h" +#include "runtime/standard/equality_functions.h" +#include "runtime/standard/logical_functions.h" +#include "runtime/standard/regex_functions.h" +#include "runtime/standard/string_functions.h" +#include "runtime/standard/time_functions.h" +#include "runtime/standard/type_conversion_functions.h" + +namespace google::api::expr::runtime { absl::Status RegisterBuiltinFunctions(CelFunctionRegistry* registry, const InterpreterOptions& options) { - // logical NOT - absl::Status status = FunctionAdapter::CreateAndRegister( - builtin::kNot, false, [](Arena*, bool value) -> bool { return !value; }, - registry); - if (!status.ok()) return status; - - // Negation group - status = FunctionAdapter::CreateAndRegister( - builtin::kNeg, false, [](Arena*, int64_t value) -> int64_t { return -value; }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter::CreateAndRegister( - builtin::kNeg, false, - [](Arena*, double value) -> double { return -value; }, registry); - if (!status.ok()) return status; - - status = RegisterComparisonFunctions(registry, options); - if (!status.ok()) return status; - - // Strictness - status = FunctionAdapter::CreateAndRegister( - builtin::kNotStrictlyFalse, false, - [](Arena*, bool value) -> bool { return value; }, registry); - if (!status.ok()) return status; - - status = FunctionAdapter::CreateAndRegister( - builtin::kNotStrictlyFalse, false, - [](Arena*, const CelError*) -> bool { return true; }, registry); - if (!status.ok()) return status; - - status = FunctionAdapter::CreateAndRegister( - builtin::kNotStrictlyFalse, false, - [](Arena*, const UnknownSet*) -> bool { return true; }, registry); - if (!status.ok()) return status; - - status = FunctionAdapter::CreateAndRegister( - builtin::kNotStrictlyFalseDeprecated, false, - [](Arena*, bool value) -> bool { return value; }, registry); - if (!status.ok()) return status; - - status = FunctionAdapter::CreateAndRegister( - builtin::kNotStrictlyFalseDeprecated, false, - [](Arena*, const CelError*) -> bool { return true; }, registry); - if (!status.ok()) return status; - - // String size - auto string_size_func = [](Arena*, CelValue::StringHolder value) -> int64_t { - return value.value().size(); - }; - // receiver style = true/false - // Support global and receiver style size() operations on strings. - status = FunctionAdapter::CreateAndRegister( - builtin::kSize, true, string_size_func, registry); - if (!status.ok()) return status; - status = FunctionAdapter::CreateAndRegister( - builtin::kSize, false, string_size_func, registry); - if (!status.ok()) return status; - - // Bytes size - auto bytes_size_func = [](Arena*, CelValue::BytesHolder value) -> int64_t { - return value.value().size(); - }; - // receiver style = true/false - // Support global and receiver style size() operations on bytes. - status = FunctionAdapter::CreateAndRegister( - builtin::kSize, true, bytes_size_func, registry); - if (!status.ok()) return status; - status = FunctionAdapter::CreateAndRegister( - builtin::kSize, false, bytes_size_func, registry); - if (!status.ok()) return status; - - // List size - auto list_size_func = [](Arena*, const CelList* cel_list) -> int64_t { - return (*cel_list).size(); - }; - // receiver style = true/false - // Support both the global and receiver style size() for lists. - status = FunctionAdapter::CreateAndRegister( - builtin::kSize, true, list_size_func, registry); - if (!status.ok()) return status; - status = FunctionAdapter::CreateAndRegister( - builtin::kSize, false, list_size_func, registry); - if (!status.ok()) return status; - - // List in operator: @in - if (options.enable_list_contains) { - status = FunctionAdapter::CreateAndRegister( - builtin::kIn, false, In, registry); - if (!status.ok()) return status; - status = FunctionAdapter::CreateAndRegister( - builtin::kIn, false, In, registry); - if (!status.ok()) return status; - status = FunctionAdapter::CreateAndRegister( - builtin::kIn, false, In, registry); - if (!status.ok()) return status; - status = FunctionAdapter::CreateAndRegister( - builtin::kIn, false, In, registry); - if (!status.ok()) return status; - status = FunctionAdapter:: - CreateAndRegister(builtin::kIn, false, In, - registry); - if (!status.ok()) return status; - status = FunctionAdapter:: - CreateAndRegister(builtin::kIn, false, In, - registry); - if (!status.ok()) return status; - - // List in operator: _in_ (deprecated) - // Bindings preserved for backward compatibility with stored expressions. - status = FunctionAdapter::CreateAndRegister( - builtin::kInDeprecated, false, In, registry); - if (!status.ok()) return status; - status = FunctionAdapter::CreateAndRegister( - builtin::kInDeprecated, false, In, registry); - if (!status.ok()) return status; - status = FunctionAdapter::CreateAndRegister( - builtin::kInDeprecated, false, In, registry); - if (!status.ok()) return status; - status = FunctionAdapter::CreateAndRegister( - builtin::kInDeprecated, false, In, registry); - if (!status.ok()) return status; - status = FunctionAdapter:: - CreateAndRegister(builtin::kInDeprecated, false, - In, registry); - if (!status.ok()) return status; - status = FunctionAdapter:: - CreateAndRegister(builtin::kInDeprecated, false, - In, registry); - if (!status.ok()) return status; - - // List in() function (deprecated) - // Bindings preserved for backward compatibility with stored expressions. - status = FunctionAdapter::CreateAndRegister( - builtin::kInFunction, false, In, registry); - if (!status.ok()) return status; - status = FunctionAdapter::CreateAndRegister( - builtin::kInFunction, false, In, registry); - if (!status.ok()) return status; - status = FunctionAdapter::CreateAndRegister( - builtin::kInFunction, false, In, registry); - if (!status.ok()) return status; - status = FunctionAdapter::CreateAndRegister( - builtin::kInFunction, false, In, registry); - if (!status.ok()) return status; - status = FunctionAdapter:: - CreateAndRegister(builtin::kInFunction, false, - In, registry); - if (!status.ok()) return status; - status = FunctionAdapter:: - CreateAndRegister(builtin::kInFunction, false, - In, registry); - if (!status.ok()) return status; - } - - // Map size - auto map_size_func = [](Arena*, const CelMap* cel_map) -> int64_t { - return (*cel_map).size(); - }; - // receiver style = true/false - status = FunctionAdapter::CreateAndRegister( - builtin::kSize, true, map_size_func, registry); - if (!status.ok()) return status; - status = FunctionAdapter::CreateAndRegister( - builtin::kSize, false, map_size_func, registry); - if (!status.ok()) return status; - - // Map in operator: @in - status = FunctionAdapter:: - CreateAndRegister( - builtin::kIn, false, - [](Arena*, CelValue::StringHolder key, - const CelMap* cel_map) -> bool { - return (*cel_map)[CelValue::CreateString(key)].has_value(); - }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter::CreateAndRegister( - builtin::kIn, false, - [](Arena*, int64_t key, const CelMap* cel_map) -> bool { - return (*cel_map)[CelValue::CreateInt64(key)].has_value(); - }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter::CreateAndRegister( - builtin::kIn, false, - [](Arena*, uint64_t key, const CelMap* cel_map) -> bool { - return (*cel_map)[CelValue::CreateUint64(key)].has_value(); - }, - registry); - if (!status.ok()) return status; - - // Map in operators: _in_ (deprecated). - // Bindings preserved for backward compatibility with stored expressions. - status = FunctionAdapter::CreateAndRegister( - builtin::kInDeprecated, false, - [](Arena*, int64_t key, const CelMap* cel_map) -> bool { - return (*cel_map)[CelValue::CreateInt64(key)].has_value(); - }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter::CreateAndRegister( - builtin::kInDeprecated, false, - [](Arena*, uint64_t key, const CelMap* cel_map) -> bool { - return (*cel_map)[CelValue::CreateUint64(key)].has_value(); - }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter:: - CreateAndRegister( - builtin::kInDeprecated, false, - [](Arena*, CelValue::StringHolder key, - const CelMap* cel_map) -> bool { - return (*cel_map)[CelValue::CreateString(key)].has_value(); - }, - registry); - if (!status.ok()) return status; - - // Map in() function (deprecated) - status = FunctionAdapter:: - CreateAndRegister( - builtin::kInFunction, false, - [](Arena*, CelValue::StringHolder key, - const CelMap* cel_map) -> bool { - return (*cel_map)[CelValue::CreateString(key)].has_value(); - }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter::CreateAndRegister( - builtin::kInFunction, false, - [](Arena*, int64_t key, const CelMap* cel_map) -> bool { - return (*cel_map)[CelValue::CreateInt64(key)].has_value(); - }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter::CreateAndRegister( - builtin::kInFunction, false, - [](Arena*, uint64_t key, const CelMap* cel_map) -> bool { - return (*cel_map)[CelValue::CreateUint64(key)].has_value(); - }, - registry); - if (!status.ok()) return status; - - // basic Arithmetic functions for numeric types - status = RegisterArithmeticFunctionsForType(registry); - if (!status.ok()) return status; - - status = RegisterArithmeticFunctionsForType(registry); - if (!status.ok()) return status; - - status = RegisterArithmeticFunctionsForType(registry); - if (!status.ok()) return status; - - // Special arithmetic operators for Timestamp and Duration - status = - FunctionAdapter::CreateAndRegister( - builtin::kAdd, false, - [](Arena*, absl::Time t1, absl::Duration d2) -> CelValue { - return CelValue::CreateTimestamp(t1 + d2); - }, - registry); - if (!status.ok()) return status; - - status = - FunctionAdapter::CreateAndRegister( - builtin::kAdd, false, - [](Arena*, absl::Duration d2, absl::Time t1) -> CelValue { - return CelValue::CreateTimestamp(t1 + d2); - }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter:: - CreateAndRegister( - builtin::kAdd, false, - [](Arena*, absl::Duration d1, absl::Duration d2) -> CelValue { - return CelValue::CreateDuration(d1 + d2); - }, - registry); - if (!status.ok()) return status; - - status = - FunctionAdapter::CreateAndRegister( - builtin::kSubtract, false, - [](Arena*, absl::Time t1, absl::Duration d2) -> CelValue { - return CelValue::CreateTimestamp(t1 - d2); - }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter::CreateAndRegister( - builtin::kSubtract, false, - [](Arena*, absl::Time t1, absl::Time t2) -> CelValue { - return CelValue::CreateDuration(t1 - t2); - }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter:: - CreateAndRegister( - builtin::kSubtract, false, - [](Arena*, absl::Duration d1, absl::Duration d2) -> CelValue { - return CelValue::CreateDuration(d1 - d2); - }, - registry); - if (!status.ok()) return status; - - // Concat group - if (options.enable_string_concat) { - status = FunctionAdapter< - CelValue::StringHolder, CelValue::StringHolder, - CelValue::StringHolder>::CreateAndRegister(builtin::kAdd, false, - ConcatString, registry); - if (!status.ok()) return status; - - status = - FunctionAdapter::CreateAndRegister(builtin::kAdd, - false, - ConcatBytes, - registry); - if (!status.ok()) return status; - } - - if (options.enable_list_concat) { - status = - FunctionAdapter::CreateAndRegister(builtin::kAdd, false, - ConcatList, - registry); - if (!status.ok()) return status; - } - - // Global matches function. - if (options.enable_regex) { - auto regex_matches = [max_size = options.regex_max_program_size]( - Arena* arena, CelValue::StringHolder target, - CelValue::StringHolder regex) -> CelValue { - RE2 re2(regex.value().data()); - if (max_size > 0 && re2.ProgramSize() > max_size) { - return CreateErrorValue(arena, "exceeded RE2 max program size", - absl::StatusCode::kInvalidArgument); - } - if (!re2.ok()) { - return CreateErrorValue(arena, "invalid_argument", - absl::StatusCode::kInvalidArgument); - } - return CelValue::CreateBool(RE2::PartialMatch(re2::StringPiece(target.value().data(), target.value().size()), re2)); - }; - - status = FunctionAdapter< - CelValue, CelValue::StringHolder, - CelValue::StringHolder>::CreateAndRegister(builtin::kRegexMatch, false, - regex_matches, registry); - if (!status.ok()) return status; - - // Receiver-style matches function. - status = FunctionAdapter< - CelValue, CelValue::StringHolder, - CelValue::StringHolder>::CreateAndRegister(builtin::kRegexMatch, true, - regex_matches, registry); - if (!status.ok()) return status; - } - - status = RegisterStringFunctions(registry, options); - if (!status.ok()) return status; - - // Modulo - status = FunctionAdapter::CreateAndRegister( - builtin::kModulo, false, Modulo, registry); - if (!status.ok()) return status; - - status = FunctionAdapter::CreateAndRegister( - builtin::kModulo, false, Modulo, registry); - if (!status.ok()) return status; - - status = RegisterTimestampFunctions(registry, options); - if (!status.ok()) return status; - - // type conversion to int - // TODO(issues/26): To return errors on loss of precision - // (overflow/underflow) by returning StatusOr. - status = FunctionAdapter::CreateAndRegister( - builtin::kInt, false, - [](Arena*, absl::Time t) { return absl::ToUnixSeconds(t); }, registry); - if (!status.ok()) return status; - - status = FunctionAdapter::CreateAndRegister( - builtin::kInt, false, - [](Arena* arena, double v) { - if ((v > (double)std::numeric_limits::max()) || - (v < (double)std::numeric_limits::min())) { - return CreateErrorValue(arena, "double out of int range", - absl::StatusCode::kInvalidArgument); - } - return CelValue::CreateInt64((int64_t)v); - }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter::CreateAndRegister( - builtin::kInt, false, [](Arena*, bool v) { return (int64_t)v; }, registry); - if (!status.ok()) return status; - - status = FunctionAdapter::CreateAndRegister( - builtin::kInt, false, - [](Arena* arena, uint64_t v) { - if (v > (uint64_t)std::numeric_limits::max()) { - return CreateErrorValue(arena, "uint out of int range", - absl::StatusCode::kInvalidArgument); - } - return CelValue::CreateInt64((int64_t)v); - }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter::CreateAndRegister( - builtin::kInt, false, - [](Arena* arena, CelValue::StringHolder s) { - int64_t result; - if (absl::SimpleAtoi(s.value(), &result)) { - return CelValue::CreateInt64(result); - } else { - return CreateErrorValue(arena, "doesn't convert to a string", - absl::StatusCode::kInvalidArgument); - } - }, - registry); - if (!status.ok()) return status; - - // duration - - // duration() conversion from string.. - status = FunctionAdapter::CreateAndRegister( - builtin::kDuration, false, CreateDurationFromString, registry); - if (!status.ok()) return status; - - status = FunctionAdapter::CreateAndRegister( - builtin::kHours, true, - [](Arena* arena, absl::Duration d) -> CelValue { - return GetHours(arena, d); - }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter::CreateAndRegister( - builtin::kMinutes, true, - [](Arena* arena, absl::Duration d) -> CelValue { - return GetMinutes(arena, d); - }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter::CreateAndRegister( - builtin::kSeconds, true, - [](Arena* arena, absl::Duration d) -> CelValue { - return GetSeconds(arena, d); - }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter::CreateAndRegister( - builtin::kMilliseconds, true, - [](Arena* arena, absl::Duration d) -> CelValue { - return GetMilliseconds(arena, d); - }, - registry); - if (!status.ok()) return status; - - if (options.enable_string_conversion) { - status = FunctionAdapter::CreateAndRegister( - builtin::kString, false, - [](Arena* arena, int64_t value) -> CelValue::StringHolder { - return CelValue::StringHolder( - Arena::Create(arena, absl::StrCat(value))); - }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter::CreateAndRegister( - builtin::kString, false, - [](Arena* arena, uint64_t value) -> CelValue::StringHolder { - return CelValue::StringHolder( - Arena::Create(arena, absl::StrCat(value))); - }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter::CreateAndRegister( - builtin::kString, false, - [](Arena* arena, double value) -> CelValue::StringHolder { - return CelValue::StringHolder( - Arena::Create(arena, absl::StrCat(value))); - }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter:: - CreateAndRegister( - builtin::kString, false, - [](Arena* arena, - CelValue::BytesHolder value) -> CelValue::StringHolder { - return CelValue::StringHolder(Arena::Create( - arena, std::string(value.value()))); - }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter:: - CreateAndRegister( - builtin::kString, false, - [](Arena*, CelValue::StringHolder value) -> CelValue::StringHolder { - return value; - }, - registry); - if (!status.ok()) return status; - } + cel::FunctionRegistry& modern_registry = registry->InternalGetRegistry(); + cel::RuntimeOptions runtime_options = ConvertToRuntimeOptions(options); + + CEL_RETURN_IF_ERROR( + cel::RegisterLogicalFunctions(modern_registry, runtime_options)); + CEL_RETURN_IF_ERROR( + cel::RegisterComparisonFunctions(modern_registry, runtime_options)); + CEL_RETURN_IF_ERROR( + cel::RegisterContainerFunctions(modern_registry, runtime_options)); + CEL_RETURN_IF_ERROR(cel::RegisterContainerMembershipFunctions( + modern_registry, runtime_options)); + CEL_RETURN_IF_ERROR( + cel::RegisterTypeConversionFunctions(modern_registry, runtime_options)); + CEL_RETURN_IF_ERROR( + cel::RegisterArithmeticFunctions(modern_registry, runtime_options)); + CEL_RETURN_IF_ERROR( + cel::RegisterTimeFunctions(modern_registry, runtime_options)); + CEL_RETURN_IF_ERROR( + cel::RegisterStringFunctions(modern_registry, runtime_options)); + CEL_RETURN_IF_ERROR( + cel::RegisterRegexFunctions(modern_registry, runtime_options)); + CEL_RETURN_IF_ERROR( + cel::RegisterEqualityFunctions(modern_registry, runtime_options)); return absl::OkStatus(); } -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime diff --git a/eval/public/builtin_func_registrar.h b/eval/public/builtin_func_registrar.h index 2cf906857..afa9d12fe 100644 --- a/eval/public/builtin_func_registrar.h +++ b/eval/public/builtin_func_registrar.h @@ -1,22 +1,30 @@ +// Copyright 2017 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_BUILTIN_FUNC_REGISTRAR_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_BUILTIN_FUNC_REGISTRAR_H_ -#include "eval/public/cel_function.h" +#include "absl/status/status.h" #include "eval/public/cel_function_registry.h" #include "eval/public/cel_options.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace google::api::expr::runtime { absl::Status RegisterBuiltinFunctions( CelFunctionRegistry* registry, const InterpreterOptions& options = InterpreterOptions()); -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_BUILTIN_FUNC_REGISTRAR_H_ diff --git a/eval/public/builtin_func_registrar_test.cc b/eval/public/builtin_func_registrar_test.cc new file mode 100644 index 000000000..a11676a48 --- /dev/null +++ b/eval/public/builtin_func_registrar_test.cc @@ -0,0 +1,278 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/builtin_func_registrar.h" + +#include +#include +#include +#include + +#include "cel/expr/syntax.pb.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/time/time.h" +#include "eval/public/activation.h" +#include "eval/public/cel_expr_builder_factory.h" +#include "eval/public/cel_expression.h" +#include "eval/public/cel_options.h" +#include "eval/public/cel_value.h" +#include "eval/public/testing/matchers.h" +#include "internal/testing.h" +#include "internal/time.h" +#include "parser/parser.h" +#include "google/protobuf/arena.h" + +namespace google::api::expr::runtime { +namespace { + +using cel::expr::Expr; +using cel::expr::SourceInfo; + +using ::absl_testing::StatusIs; +using ::cel::internal::MaxDuration; +using ::cel::internal::MinDuration; +using ::testing::HasSubstr; + +struct TestCase { + std::string test_name; + std::string expr; + absl::flat_hash_map vars; + absl::StatusOr result = CelValue::CreateBool(true); + InterpreterOptions options; +}; + +InterpreterOptions OverflowChecksEnabled() { + static InterpreterOptions options; + options.enable_timestamp_duration_overflow_errors = true; + return options; +} + +void ExpectResult(const TestCase& test_case) { + auto parsed_expr = parser::Parse(test_case.expr); + ASSERT_OK(parsed_expr); + const Expr& expr_ast = parsed_expr->expr(); + const SourceInfo& source_info = parsed_expr->source_info(); + + std::unique_ptr builder = + CreateCelExpressionBuilder(test_case.options); + ASSERT_OK( + RegisterBuiltinFunctions(builder->GetRegistry(), test_case.options)); + ASSERT_OK_AND_ASSIGN(auto cel_expression, + builder->CreateExpression(&expr_ast, &source_info)); + + Activation activation; + for (auto var : test_case.vars) { + activation.InsertValue(var.first, var.second); + } + + google::protobuf::Arena arena; + ASSERT_OK_AND_ASSIGN(auto value, + cel_expression->Evaluate(activation, &arena)); + if (!test_case.result.ok()) { + EXPECT_TRUE(value.IsError()) << value.DebugString(); + EXPECT_THAT(*value.ErrorOrDie(), + StatusIs(test_case.result.status().code(), + HasSubstr(test_case.result.status().message()))); + return; + } + EXPECT_THAT(value, test::EqualsCelValue(*test_case.result)); +} + +using BuiltinFuncParamsTest = testing::TestWithParam; +TEST_P(BuiltinFuncParamsTest, StandardFunctions) { ExpectResult(GetParam()); } + +INSTANTIATE_TEST_SUITE_P( + BuiltinFuncParamsTest, BuiltinFuncParamsTest, + testing::ValuesIn({ + // Legacy duration and timestamp arithmetic tests. + {"TimeSubTimeLegacy", + "t0 - t1 == duration('90s90ns')", + { + {"t0", CelValue::CreateTimestamp(absl::FromUnixSeconds(100) + + absl::Nanoseconds(100))}, + {"t1", CelValue::CreateTimestamp(absl::FromUnixSeconds(10) + + absl::Nanoseconds(10))}, + }}, + + {"TimeSubDurationLegacy", + "t0 - duration('90s90ns')", + { + {"t0", CelValue::CreateTimestamp(absl::FromUnixSeconds(100) + + absl::Nanoseconds(100))}, + }, + CelValue::CreateTimestamp(absl::FromUnixSeconds(10) + + absl::Nanoseconds(10))}, + + {"TimeAddDurationLegacy", + "t + duration('90s90ns')", + {{"t", CelValue::CreateTimestamp(absl::FromUnixSeconds(10) + + absl::Nanoseconds(10))}}, + CelValue::CreateTimestamp(absl::FromUnixSeconds(100) + + absl::Nanoseconds(100))}, + {"DurationAddTimeLegacy", + "duration('90s90ns') + t == t + duration('90s90ns')", + {{"t", CelValue::CreateTimestamp(absl::FromUnixSeconds(10) + + absl::Nanoseconds(10))}}}, + + {"DurationAddDurationLegacy", + "duration('80s80ns') + duration('10s10ns') == duration('90s90ns')"}, + + {"DurationSubDurationLegacy", + "duration('90s90ns') - duration('80s80ns') == duration('10s10ns')"}, + + {"MinDurationSubDurationLegacy", + "min - duration('1ns') < duration('-87660000h')", + {{"min", CelValue::CreateDuration(MinDuration())}}}, + + {"MaxDurationAddDurationLegacy", + "max + duration('1ns') > duration('87660000h')", + {{"max", CelValue::CreateDuration(MaxDuration())}}}, + + {"TimestampConversionFromStringLegacy", + "timestamp('10000-01-02T00:00:00Z') > " + "timestamp('9999-01-01T00:00:00Z')"}, + + {"TimestampFromUnixEpochSeconds", + "timestamp(123) > timestamp('1970-01-01T00:02:02.999999999Z') && " + "timestamp(123) == timestamp('1970-01-01T00:02:03Z') && " + "timestamp(123) < timestamp('1970-01-01T00:02:03.000000001Z')"}, + + // Timestamp duration tests with fixes enabled for overflow checking. + {"TimeSubTime", + "t0 - t1 == duration('90s90ns')", + { + {"t0", CelValue::CreateTimestamp(absl::FromUnixSeconds(100) + + absl::Nanoseconds(100))}, + {"t1", CelValue::CreateTimestamp(absl::FromUnixSeconds(10) + + absl::Nanoseconds(10))}, + }, + CelValue::CreateBool(true), + OverflowChecksEnabled()}, + + {"TimeSubDuration", + "t0 - duration('90s90ns')", + { + {"t0", CelValue::CreateTimestamp(absl::FromUnixSeconds(100) + + absl::Nanoseconds(100))}, + }, + CelValue::CreateTimestamp(absl::FromUnixSeconds(10) + + absl::Nanoseconds(10)), + OverflowChecksEnabled()}, + + {"TimeSubtractDurationOverflow", + "timestamp('0001-01-01T00:00:00Z') - duration('1ns')", + {}, + absl::OutOfRangeError("timestamp overflow"), + OverflowChecksEnabled()}, + + {"TimeAddDuration", + "t + duration('90s90ns')", + {{"t", CelValue::CreateTimestamp(absl::FromUnixSeconds(10) + + absl::Nanoseconds(10))}}, + CelValue::CreateTimestamp(absl::FromUnixSeconds(100) + + absl::Nanoseconds(100)), + OverflowChecksEnabled()}, + + {"TimeAddDurationOverflow", + "timestamp('9999-12-31T23:59:59.999999999Z') + duration('1ns')", + {}, + absl::OutOfRangeError("timestamp overflow"), + OverflowChecksEnabled()}, + + {"DurationAddTime", + "duration('90s90ns') + t == t + duration('90s90ns')", + {{"t", CelValue::CreateTimestamp(absl::FromUnixSeconds(10) + + absl::Nanoseconds(10))}}, + CelValue::CreateBool(true), + OverflowChecksEnabled()}, + + {"DurationAddTimeOverflow", + "duration('1ns') + timestamp('9999-12-31T23:59:59.999999999Z')", + {}, + absl::OutOfRangeError("timestamp overflow"), + OverflowChecksEnabled()}, + + {"DurationAddDuration", + "duration('80s80ns') + duration('10s10ns') == duration('90s90ns')", + {}, + CelValue::CreateBool(true), + OverflowChecksEnabled()}, + + {"DurationSubDuration", + "duration('90s90ns') - duration('80s80ns') == duration('10s10ns')", + {}, + CelValue::CreateBool(true), + OverflowChecksEnabled()}, + + {"MinDurationSubDuration", + "min - duration('1ns')", + {{"min", CelValue::CreateDuration(MinDuration())}}, + absl::OutOfRangeError("overflow"), + OverflowChecksEnabled()}, + + {"MaxDurationAddDuration", + "max + duration('1ns')", + {{"max", CelValue::CreateDuration(MaxDuration())}}, + absl::OutOfRangeError("overflow"), + OverflowChecksEnabled()}, + + // Timestamp conversion overflow checks. + {"TimestampConversionFromStringOverflow", + "timestamp('10000-01-02T00:00:00Z')", + {}, + absl::OutOfRangeError("timestamp overflow"), + OverflowChecksEnabled()}, + + {"TimestampConversionFromStringUnderflow", + "timestamp('0000-01-01T00:00:00Z')", + {}, + absl::OutOfRangeError("timestamp overflow"), + OverflowChecksEnabled()}, + + // List concatenation tests. + {"ListConcatEmptyInputs", + "[] + [] == []", + {}, + CelValue::CreateBool(true), + OverflowChecksEnabled()}, + {"ListConcatRightEmpty", + "[1] + [] == [1]", + {}, + CelValue::CreateBool(true), + OverflowChecksEnabled()}, + {"ListConcatLeftEmpty", + "[] + [1] == [1]", + {}, + CelValue::CreateBool(true), + OverflowChecksEnabled()}, + {"ListConcat", + "[2] + [1] == [2, 1]", + {}, + CelValue::CreateBool(true), + OverflowChecksEnabled()}, + {"StringToBool", + "string(true) + string(false)", + {}, + CelValue::CreateStringView("truefalse"), + OverflowChecksEnabled()}, + }), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + +} // namespace +} // namespace google::api::expr::runtime diff --git a/eval/public/builtin_func_test.cc b/eval/public/builtin_func_test.cc index b4b27400d..b73a2dc55 100644 --- a/eval/public/builtin_func_test.cc +++ b/eval/public/builtin_func_test.cc @@ -1,50 +1,78 @@ -#include "google/api/expr/v1alpha1/syntax.pb.h" -#include "google/protobuf/util/time_util.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include + +#include "cel/expr/syntax.pb.h" +#include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "absl/time/time.h" +#include "absl/types/optional.h" #include "eval/public/activation.h" #include "eval/public/builtin_func_registrar.h" #include "eval/public/cel_builtins.h" #include "eval/public/cel_expr_builder_factory.h" #include "eval/public/cel_function_registry.h" +#include "eval/public/cel_options.h" +#include "eval/public/cel_value.h" #include "eval/public/structs/cel_proto_wrapper.h" -#include "base/status_macros.h" - -namespace google { -namespace api { -namespace expr { -namespace runtime { +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "internal/time.h" +namespace google::api::expr::runtime { namespace { using google::protobuf::Duration; using google::protobuf::Timestamp; -using google::api::expr::v1alpha1::Expr; -using google::api::expr::v1alpha1::SourceInfo; +using cel::expr::Expr; +using cel::expr::SourceInfo; using google::protobuf::Arena; -using google::protobuf::util::TimeUtil; -using testing::Eq; +using ::cel::internal::MaxDuration; +using ::cel::internal::MinDuration; +using ::cel::internal::MinTimestamp; +using ::testing::Eq; class BuiltinsTest : public ::testing::Test { protected: BuiltinsTest() {} - void SetUp() override { ASSERT_OK(RegisterBuiltinFunctions(®istry_)); } + // Helper method. Looks up in registry and tests comparison operation. + void PerformRun(absl::string_view operation, absl::optional target, + const std::vector& values, CelValue* result) { + PerformRun(operation, target, values, result, options_); + } // Helper method. Looks up in registry and tests comparison operation. void PerformRun(absl::string_view operation, absl::optional target, const std::vector& values, CelValue* result, - const InterpreterOptions& options = InterpreterOptions()) { + const InterpreterOptions& options) { Activation activation; Expr expr; SourceInfo source_info; auto call = expr.mutable_call_expr(); - call->set_function(operation.data()); + call->set_function(operation); if (target.has_value()) { std::string param_name = "target"; @@ -73,17 +101,11 @@ class BuiltinsTest : public ::testing::Test { ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); // Create CelExpression from AST (Expr object). - auto cel_expression_status = builder->CreateExpression(&expr, &source_info); - - ASSERT_OK(cel_expression_status); - - auto cel_expression = std::move(cel_expression_status.value()); - - auto eval_status = cel_expression->Evaluate(activation, &arena_); - - ASSERT_OK(eval_status); - - *result = eval_status.value(); + ASSERT_OK_AND_ASSIGN(auto cel_expression, + builder->CreateExpression(&expr, &source_info)); + ASSERT_OK_AND_ASSIGN(auto value, + cel_expression->Evaluate(activation, &arena_)); + *result = value; } // Helper method. Looks up in registry and tests comparison operation. @@ -94,17 +116,21 @@ class BuiltinsTest : public ::testing::Test { ASSERT_NO_FATAL_FAILURE( PerformRun(operation, {}, {ref, other}, &result_value)); - ASSERT_EQ(result_value.IsBool(), true); + ASSERT_EQ(result_value.IsBool(), true) + << absl::StrCat(CelValue::TypeName(ref.type()), " ", operation, " ", + CelValue::TypeName(other.type())); ASSERT_EQ(result_value.BoolOrDie(), result) - << operation << " for " << CelValue::TypeName(ref.type()); + << operation << " for " << ref.DebugString() << " with " + << other.DebugString(); } // Helper method. Looks up in registry and tests for no matching equality // overload. void TestNoMatchingEqualOverload(const CelValue& ref, const CelValue& other) { + options_.enable_heterogeneous_equality = false; CelValue eq_value; ASSERT_NO_FATAL_FAILURE( - PerformRun(builtin::kEqual, {}, {ref, other}, &eq_value)); + PerformRun(builtin::kEqual, {}, {ref, other}, &eq_value, options_)); ASSERT_TRUE(eq_value.IsError()) << " for " << CelValue::TypeName(ref.type()) << " and " << CelValue::TypeName(other.type()); @@ -112,7 +138,7 @@ class BuiltinsTest : public ::testing::Test { CelValue ineq_value; ASSERT_NO_FATAL_FAILURE( - PerformRun(builtin::kInequal, {}, {ref, other}, &ineq_value)); + PerformRun(builtin::kInequal, {}, {ref, other}, &ineq_value, options_)); ASSERT_TRUE(ineq_value.IsError()) << " for " << CelValue::TypeName(ref.type()) << " and " << CelValue::TypeName(other.type()); @@ -120,6 +146,40 @@ class BuiltinsTest : public ::testing::Test { } // Helper method. Looks up in registry and tests Type conversions. + void TestTypeConverts(absl::string_view operation, const CelValue& ref, + CelValue::BytesHolder result) { + CelValue result_value; + + ASSERT_NO_FATAL_FAILURE(PerformRun(operation, {}, {ref}, &result_value)); + + ASSERT_EQ(result_value.IsBytes(), true); + ASSERT_EQ(result_value.BytesOrDie(), result) + << operation << " for " << CelValue::TypeName(ref.type()); + } + + // Helper method. Looks up in registry and tests Type conversions. + void TestTypeConverts(absl::string_view operation, const CelValue& ref, + CelValue::StringHolder result) { + CelValue result_value; + + ASSERT_NO_FATAL_FAILURE(PerformRun(operation, {}, {ref}, &result_value)); + + ASSERT_EQ(result_value.IsString(), true); + ASSERT_EQ(result_value.StringOrDie().value(), result.value()) + << operation << " for " << CelValue::TypeName(ref.type()); + } + + void TestTypeConverts(absl::string_view operation, const CelValue& ref, + double result) { + CelValue result_value; + + ASSERT_NO_FATAL_FAILURE(PerformRun(operation, {}, {ref}, &result_value)); + + ASSERT_EQ(result_value.IsDouble(), true); + ASSERT_EQ(result_value.DoubleOrDie(), result) + << operation << " for " << CelValue::TypeName(ref.type()); + } + void TestTypeConverts(absl::string_view operation, const CelValue& ref, int64_t result) { CelValue result_value; @@ -131,6 +191,29 @@ class BuiltinsTest : public ::testing::Test { << operation << " for " << CelValue::TypeName(ref.type()); } + void TestTypeConverts(absl::string_view operation, const CelValue& ref, + uint64_t result) { + CelValue result_value; + + ASSERT_NO_FATAL_FAILURE(PerformRun(operation, {}, {ref}, &result_value)); + + ASSERT_EQ(result_value.IsUint64(), true); + ASSERT_EQ(result_value.Uint64OrDie(), result) + << operation << " for " << CelValue::TypeName(ref.type()); + } + + void TestTypeConverts(absl::string_view operation, const CelValue& ref, + Duration& result) { + CelValue result_value; + + ASSERT_NO_FATAL_FAILURE(PerformRun(operation, {}, {ref}, &result_value)); + + ASSERT_EQ(result_value.IsDuration(), true); + ASSERT_EQ(result_value.DurationOrDie(), + CelProtoWrapper::CreateDuration(&result).DurationOrDie()) + << operation << " for " << CelValue::TypeName(ref.type()); + } + // Helper method. Attempts to perform a type conversion and expects an error // as the result. void TestTypeConversionError(absl::string_view operation, @@ -139,7 +222,7 @@ class BuiltinsTest : public ::testing::Test { ASSERT_NO_FATAL_FAILURE(PerformRun(operation, {}, {ref}, &result_value)); - ASSERT_EQ(result_value.IsError(), true); + ASSERT_EQ(result_value.IsError(), true) << result_value.DebugString(); } // Helper method. Looks up in registry and tests functions without params. @@ -221,6 +304,17 @@ class BuiltinsTest : public ::testing::Test { ASSERT_EQ(result_value.Int64OrDie(), result) << operation; } + // Helper for testing invalid signed integer arithmetic operations. + void TestArithmeticalErrorInt64(absl::string_view operation, int64_t v1, + int64_t v2, absl::StatusCode code) { + CelValue result_value; + ASSERT_NO_FATAL_FAILURE(PerformRun( + operation, {}, {CelValue::CreateInt64(v1), CelValue::CreateInt64(v2)}, + &result_value)); + ASSERT_EQ(result_value.IsError(), true); + ASSERT_EQ(result_value.ErrorOrDie()->code(), code) << operation; + } + // Helper method to test arithmetical operations for Uint64 void TestArithmeticalOperationUint64(absl::string_view operation, uint64_t v1, uint64_t v2, uint64_t result) { @@ -232,6 +326,17 @@ class BuiltinsTest : public ::testing::Test { ASSERT_EQ(result_value.Uint64OrDie(), result) << operation; } + // Helper for testing invalid unsigned integer arithmetic operations. + void TestArithmeticalErrorUint64(absl::string_view operation, uint64_t v1, + uint64_t v2, absl::StatusCode code) { + CelValue result_value; + ASSERT_NO_FATAL_FAILURE(PerformRun( + operation, {}, {CelValue::CreateUint64(v1), CelValue::CreateUint64(v2)}, + &result_value)); + ASSERT_EQ(result_value.IsError(), true); + ASSERT_EQ(result_value.ErrorOrDie()->code(), code) << operation; + } + // Helper method to test arithmetical operations for Double void TestArithmeticalOperationDouble(absl::string_view operation, double v1, double v2, double result) { @@ -250,7 +355,8 @@ class BuiltinsTest : public ::testing::Test { {value, CelValue::CreateList(cel_list)}, &result_value)); - ASSERT_EQ(result_value.IsBool(), true); + ASSERT_EQ(result_value.IsBool(), true) + << result_value.DebugString() << " argument: " << value.DebugString(); ASSERT_EQ(result_value.BoolOrDie(), result) << " for " << CelValue::TypeName(value.type()); } @@ -283,11 +389,11 @@ class BuiltinsTest : public ::testing::Test { CelValue result_value; ASSERT_NO_FATAL_FAILURE(PerformRun(builtin::kIn, {}, {value, CelValue::CreateMap(cel_map)}, - &result_value)); + &result_value, options_)); ASSERT_EQ(result_value.IsBool(), true); ASSERT_EQ(result_value.BoolOrDie(), result) - << " for " << CelValue::TypeName(value.type()); + << " for " << value.DebugString(); } void TestInDeprecatedMap(const CelMap* cel_map, const CelValue& value, @@ -295,7 +401,7 @@ class BuiltinsTest : public ::testing::Test { CelValue result_value; ASSERT_NO_FATAL_FAILURE(PerformRun(builtin::kInDeprecated, {}, {value, CelValue::CreateMap(cel_map)}, - &result_value)); + &result_value, options_)); ASSERT_EQ(result_value.IsBool(), true); ASSERT_EQ(result_value.BoolOrDie(), result) @@ -307,20 +413,24 @@ class BuiltinsTest : public ::testing::Test { CelValue result_value; ASSERT_NO_FATAL_FAILURE(PerformRun(builtin::kInFunction, {}, {value, CelValue::CreateMap(cel_map)}, - &result_value)); + &result_value, options_)); ASSERT_EQ(result_value.IsBool(), true); ASSERT_EQ(result_value.BoolOrDie(), result) << " for " << CelValue::TypeName(value.type()); } - // Function registry object - CelFunctionRegistry registry_; + InterpreterOptions options_; // Arena Arena arena_; }; +class HeterogeneousEqualityTest : public BuiltinsTest { + public: + HeterogeneousEqualityTest() { options_.enable_heterogeneous_equality = true; } +}; + // Test Not() operation for Bool TEST_F(BuiltinsTest, TestNotOp) { CelValue result; @@ -331,6 +441,30 @@ TEST_F(BuiltinsTest, TestNotOp) { EXPECT_EQ(result.BoolOrDie(), false); } +// Test negation operation for numeric types. +TEST_F(BuiltinsTest, TestNegOp) { + CelValue result; + ASSERT_NO_FATAL_FAILURE( + PerformRun(builtin::kNeg, {}, {CelValue::CreateInt64(-1)}, &result)); + ASSERT_TRUE(result.IsInt64()); + EXPECT_EQ(result.Int64OrDie(), 1); + + ASSERT_NO_FATAL_FAILURE( + PerformRun(builtin::kNeg, {}, {CelValue::CreateDouble(-1.0)}, &result)); + ASSERT_TRUE(result.IsDouble()); + EXPECT_EQ(result.DoubleOrDie(), 1.0); +} + +// Test integer negation overflow. +TEST_F(BuiltinsTest, TestNegIntOverflow) { + CelValue result; + ASSERT_NO_FATAL_FAILURE(PerformRun( + builtin::kNeg, {}, + {CelValue::CreateInt64(std::numeric_limits::lowest())}, + &result)); + ASSERT_TRUE(result.IsError()); +} + // Test Equality/Non-Equality operation for Bool TEST_F(BuiltinsTest, TestBoolEqual) { CelValue ref = CelValue::CreateBool(true); @@ -393,9 +527,8 @@ TEST_F(BuiltinsTest, TestDurationComparisons) { // Test Equality/Non-Equality operation for messages TEST_F(BuiltinsTest, TestNullMessageEqual) { CelValue ref = CelValue::CreateNull(); - Expr call; - call.mutable_call_expr()->set_function("test"); - CelValue value = CelProtoWrapper::CreateMessage(&call, &arena_); + Expr dummy; + CelValue value = CelProtoWrapper::CreateMessage(&dummy, &arena_); TestComparison(builtin::kEqual, ref, ref, true); TestComparison(builtin::kInequal, ref, ref, false); TestComparison(builtin::kEqual, value, ref, false); @@ -404,73 +537,6 @@ TEST_F(BuiltinsTest, TestNullMessageEqual) { TestComparison(builtin::kInequal, ref, value, true); } -// Test Arithmetical operations for Timestamp and Duration -TEST_F(BuiltinsTest, TestTimestampDurationArithmeticalOperation) { - CelValue result_value, cel_ts0, cel_ts1, cel_d0, cel_d1, cel_d2; - Timestamp ts0, ts1; - Duration d0, d1, d2; - - ts0.set_seconds(100); - ts0.set_nanos(100); - ts1.set_seconds(10); - ts1.set_nanos(10); - - d0.set_seconds(90); - d0.set_nanos(90); - d1.set_seconds(80); - d1.set_nanos(80); - d2.set_seconds(10); - d2.set_nanos(10); - - cel_d0 = CelProtoWrapper::CreateDuration(&d0); - cel_d1 = CelProtoWrapper::CreateDuration(&d1); - cel_d2 = CelProtoWrapper::CreateDuration(&d2); - cel_ts0 = CelProtoWrapper::CreateTimestamp(&ts0); - cel_ts1 = CelProtoWrapper::CreateTimestamp(&ts1); - - // ts0 - ts1 = d0 - ASSERT_NO_FATAL_FAILURE( - PerformRun(builtin::kSubtract, {}, {cel_ts0, cel_ts1}, &result_value)); - ASSERT_EQ(result_value.IsDuration(), true); - ASSERT_EQ(absl::ToInt64Nanoseconds(result_value.DurationOrDie()), - TimeUtil::DurationToNanoseconds(d0)); - - // ts0 - d0 = ts1 - ASSERT_NO_FATAL_FAILURE( - PerformRun(builtin::kSubtract, {}, {cel_ts0, cel_d0}, &result_value)); - ASSERT_EQ(result_value.IsTimestamp(), true); - ASSERT_EQ(absl::ToUnixNanos(result_value.TimestampOrDie()), - TimeUtil::TimestampToNanoseconds(ts1)); - - // ts1 + d0 = ts0 - ASSERT_NO_FATAL_FAILURE( - PerformRun(builtin::kAdd, {}, {cel_ts1, cel_d0}, &result_value)); - ASSERT_EQ(result_value.IsTimestamp(), true); - ASSERT_EQ(absl::ToUnixNanos(result_value.TimestampOrDie()), - TimeUtil::TimestampToNanoseconds(ts0)); - - // d0 + ts1 = ts0 - ASSERT_NO_FATAL_FAILURE( - PerformRun(builtin::kAdd, {}, {cel_d0, cel_ts1}, &result_value)); - ASSERT_EQ(result_value.IsTimestamp(), true); - ASSERT_EQ(absl::ToUnixNanos(result_value.TimestampOrDie()), - TimeUtil::TimestampToNanoseconds(ts0)); - - // d0 - d1 = d2 - ASSERT_NO_FATAL_FAILURE( - PerformRun(builtin::kSubtract, {}, {cel_d0, cel_d1}, &result_value)); - ASSERT_EQ(result_value.IsDuration(), true); - ASSERT_EQ(absl::ToInt64Nanoseconds(result_value.DurationOrDie()), - TimeUtil::DurationToNanoseconds(d2)); - - // d1 + d2 = d0 - ASSERT_NO_FATAL_FAILURE( - PerformRun(builtin::kAdd, {}, {cel_d2, cel_d1}, &result_value)); - ASSERT_EQ(result_value.IsDuration(), true); - ASSERT_EQ(absl::ToInt64Nanoseconds(result_value.DurationOrDie()), - TimeUtil::DurationToNanoseconds(d0)); -} - // Test functions for Duration TEST_F(BuiltinsTest, TestDurationFunctions) { Duration ref; @@ -478,24 +544,49 @@ TEST_F(BuiltinsTest, TestDurationFunctions) { ref.set_seconds(93541L); ref.set_nanos(11000000L); - TestFunctions(builtin::kHours, CelProtoWrapper::CreateDuration(&ref), 25L); + TestFunctions(builtin::kHours, CelProtoWrapper::CreateDuration(&ref), + int64_t{25L}); TestFunctions(builtin::kMinutes, CelProtoWrapper::CreateDuration(&ref), - 1559L); + int64_t{1559L}); TestFunctions(builtin::kSeconds, CelProtoWrapper::CreateDuration(&ref), - 93541L); + int64_t{93541L}); TestFunctions(builtin::kMilliseconds, CelProtoWrapper::CreateDuration(&ref), - 11L); + int64_t{11L}); + + std::string result = "93541.011s"; + TestTypeConverts(builtin::kString, CelProtoWrapper::CreateDuration(&ref), + CelValue::StringHolder(&result)); + TestTypeConverts(builtin::kDuration, CelValue::CreateString(&result), ref); ref.set_seconds(-93541L); ref.set_nanos(-11000000L); - TestFunctions(builtin::kHours, CelProtoWrapper::CreateDuration(&ref), -25L); + TestFunctions(builtin::kHours, CelProtoWrapper::CreateDuration(&ref), + int64_t{-25L}); TestFunctions(builtin::kMinutes, CelProtoWrapper::CreateDuration(&ref), - -1559L); + int64_t{-1559L}); TestFunctions(builtin::kSeconds, CelProtoWrapper::CreateDuration(&ref), - -93541L); + int64_t{-93541L}); TestFunctions(builtin::kMilliseconds, CelProtoWrapper::CreateDuration(&ref), - -11L); + int64_t{-11L}); + + result = "-93541.011s"; + TestTypeConverts(builtin::kString, CelProtoWrapper::CreateDuration(&ref), + CelValue::StringHolder(&result)); + TestTypeConverts(builtin::kDuration, CelValue::CreateString(&result), ref); + + absl::Duration d = MinDuration() + absl::Seconds(-1); + result = absl::FormatDuration(d); + TestTypeConversionError(builtin::kDuration, CelValue::CreateString(&result)); + + d = MaxDuration() + absl::Seconds(1); + result = absl::FormatDuration(d); + TestTypeConversionError(builtin::kDuration, CelValue::CreateString(&result)); + + std::string inf = "inf"; + std::string ninf = "-inf"; + TestTypeConversionError(builtin::kDuration, CelValue::CreateString(&inf)); + TestTypeConversionError(builtin::kDuration, CelValue::CreateString(&ninf)); } // Test functions for Timestamp @@ -506,25 +597,48 @@ TEST_F(BuiltinsTest, TestTimestampFunctions) { ref.set_seconds(1L); ref.set_nanos(11000000L); TestFunctions(builtin::kFullYear, CelProtoWrapper::CreateTimestamp(&ref), - 1970L); - TestFunctions(builtin::kMonth, CelProtoWrapper::CreateTimestamp(&ref), 0L); + int64_t{1970L}); + TestFunctions(builtin::kMonth, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{0L}); TestFunctions(builtin::kDayOfYear, CelProtoWrapper::CreateTimestamp(&ref), - 0L); + int64_t{0L}); TestFunctions(builtin::kDayOfMonth, CelProtoWrapper::CreateTimestamp(&ref), - 0L); - TestFunctions(builtin::kDate, CelProtoWrapper::CreateTimestamp(&ref), 1L); - TestFunctions(builtin::kHours, CelProtoWrapper::CreateTimestamp(&ref), 0L); - TestFunctions(builtin::kMinutes, CelProtoWrapper::CreateTimestamp(&ref), 0L); - TestFunctions(builtin::kSeconds, CelProtoWrapper::CreateTimestamp(&ref), 1L); + int64_t{0L}); + TestFunctions(builtin::kDate, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{1L}); + TestFunctions(builtin::kHours, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{0L}); + TestFunctions(builtin::kMinutes, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{0L}); + TestFunctions(builtin::kSeconds, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{1L}); TestFunctions(builtin::kMilliseconds, CelProtoWrapper::CreateTimestamp(&ref), - 11L); + int64_t{11L}); ref.set_seconds(259200L); ref.set_nanos(0L); TestFunctions(builtin::kDayOfWeek, CelProtoWrapper::CreateTimestamp(&ref), - 0L); + int64_t{0L}); +} + +TEST_F(BuiltinsTest, TestTimestampConversionToString) { + Timestamp ref; + ref.set_seconds(1L); + ref.set_nanos(11000000L); + std::string result = "1970-01-01T00:00:01.011Z"; + TestTypeConverts(builtin::kString, CelProtoWrapper::CreateTimestamp(&ref), + CelValue::StringHolder(&result)); - // Test timestamp functions w/ timezone + ref.set_seconds(259200L); + ref.set_nanos(0L); + result = "1970-01-04T00:00:00Z"; + TestTypeConverts(builtin::kString, CelProtoWrapper::CreateTimestamp(&ref), + CelValue::StringHolder(&result)); +} + +TEST_F(BuiltinsTest, TestTimestampFunctionsWithTimeZone) { + // Test timestamp functions w/ IANA timezone + Timestamp ref; ref.set_seconds(1L); ref.set_nanos(11000000L); std::vector params; @@ -533,79 +647,313 @@ TEST_F(BuiltinsTest, TestTimestampFunctions) { TestFunctionsWithParams(builtin::kFullYear, CelProtoWrapper::CreateTimestamp(&ref), params, - 1969L); + int64_t{1969L}); + TestFunctionsWithParams(builtin::kMonth, + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{11L}); + TestFunctionsWithParams(builtin::kDayOfYear, + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{364L}); + TestFunctionsWithParams(builtin::kDayOfMonth, + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{30L}); + TestFunctionsWithParams(builtin::kDate, + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{31L}); + TestFunctionsWithParams(builtin::kHours, + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{16L}); + TestFunctionsWithParams(builtin::kMinutes, + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{0L}); + TestFunctionsWithParams(builtin::kSeconds, + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{1L}); + TestFunctionsWithParams(builtin::kMilliseconds, + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{11L}); + + ref.set_seconds(259200L); + ref.set_nanos(0L); + TestFunctionsWithParams(builtin::kDayOfWeek, + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{6L}); + + // Test timestamp functions with negative value + ref.set_seconds(-1L); + ref.set_nanos(0L); + + TestFunctions(builtin::kFullYear, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{1969L}); + TestFunctions(builtin::kMonth, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{11L}); + TestFunctions(builtin::kDayOfYear, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{364L}); + TestFunctions(builtin::kDayOfMonth, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{30L}); + TestFunctions(builtin::kDate, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{31L}); + TestFunctions(builtin::kHours, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{23L}); + TestFunctions(builtin::kMinutes, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{59L}); + TestFunctions(builtin::kSeconds, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{59L}); + TestFunctions(builtin::kDayOfWeek, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{3L}); + + // Test timestamp functions w/ fixed timezone + ref.set_seconds(1L); + ref.set_nanos(11000000L); + const std::string fixedzone = "-08:00"; + params.clear(); + params.push_back(CelValue::CreateString(&fixedzone)); + + TestFunctionsWithParams(builtin::kFullYear, + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{1969L}); TestFunctionsWithParams(builtin::kMonth, - CelProtoWrapper::CreateTimestamp(&ref), params, 11L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{11L}); TestFunctionsWithParams(builtin::kDayOfYear, - CelProtoWrapper::CreateTimestamp(&ref), params, 364L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{364L}); TestFunctionsWithParams(builtin::kDayOfMonth, - CelProtoWrapper::CreateTimestamp(&ref), params, 30L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{30L}); TestFunctionsWithParams(builtin::kDate, - CelProtoWrapper::CreateTimestamp(&ref), params, 31L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{31L}); TestFunctionsWithParams(builtin::kHours, - CelProtoWrapper::CreateTimestamp(&ref), params, 16L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{16L}); TestFunctionsWithParams(builtin::kMinutes, - CelProtoWrapper::CreateTimestamp(&ref), params, 0L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{0L}); TestFunctionsWithParams(builtin::kSeconds, - CelProtoWrapper::CreateTimestamp(&ref), params, 1L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{1L}); TestFunctionsWithParams(builtin::kMilliseconds, - CelProtoWrapper::CreateTimestamp(&ref), params, 11L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{11L}); ref.set_seconds(259200L); ref.set_nanos(0L); TestFunctionsWithParams(builtin::kDayOfWeek, - CelProtoWrapper::CreateTimestamp(&ref), params, 6L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{6L}); // Test timestamp functions with negative value ref.set_seconds(-1L); ref.set_nanos(0L); TestFunctions(builtin::kFullYear, CelProtoWrapper::CreateTimestamp(&ref), - 1969L); - TestFunctions(builtin::kMonth, CelProtoWrapper::CreateTimestamp(&ref), 11L); + int64_t{1969L}); + TestFunctions(builtin::kMonth, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{11L}); TestFunctions(builtin::kDayOfYear, CelProtoWrapper::CreateTimestamp(&ref), - 364L); + int64_t{364L}); TestFunctions(builtin::kDayOfMonth, CelProtoWrapper::CreateTimestamp(&ref), - 30L); - TestFunctions(builtin::kDate, CelProtoWrapper::CreateTimestamp(&ref), 31L); - TestFunctions(builtin::kHours, CelProtoWrapper::CreateTimestamp(&ref), 23L); - TestFunctions(builtin::kMinutes, CelProtoWrapper::CreateTimestamp(&ref), 59L); - TestFunctions(builtin::kSeconds, CelProtoWrapper::CreateTimestamp(&ref), 59L); + int64_t{30L}); + TestFunctions(builtin::kDate, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{31L}); + TestFunctions(builtin::kHours, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{23L}); + TestFunctions(builtin::kMinutes, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{59L}); + TestFunctions(builtin::kSeconds, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{59L}); TestFunctions(builtin::kDayOfWeek, CelProtoWrapper::CreateTimestamp(&ref), - 3L); + int64_t{3L}); + + TestTypeConversionError( + builtin::kString, + CelValue::CreateTimestamp(MinTimestamp() + absl::Seconds(-1))); +} + +TEST_F(BuiltinsTest, TestBytesConversions_bytes) { + std::string txt = "hello"; + CelValue::BytesHolder result = CelValue::BytesHolder(&txt); + TestTypeConverts(builtin::kBytes, CelValue::CreateBytes(&txt), result); +} + +TEST_F(BuiltinsTest, TestBytesConversions_string) { + std::string txt = "hello"; + CelValue::BytesHolder result = CelValue::BytesHolder(&txt); + TestTypeConverts(builtin::kBytes, CelValue::CreateString(&txt), result); +} + +TEST_F(BuiltinsTest, TestDoubleConversions_double) { + double ref = 100.1; + TestTypeConverts(builtin::kDouble, CelValue::CreateDouble(ref), + double{100.1}); } -TEST_F(BuiltinsTest, TestTypeConversions_Timestamp) { +TEST_F(BuiltinsTest, TestDoubleConversions_int) { + int64_t ref = 100L; + TestTypeConverts(builtin::kDouble, CelValue::CreateInt64(ref), double{100.0}); +} + +TEST_F(BuiltinsTest, TestDoubleConversions_string) { + std::string ref = "-100.1"; + TestTypeConverts(builtin::kDouble, CelValue::CreateString(&ref), + double{-100.1}); +} + +TEST_F(BuiltinsTest, TestDoubleConversions_uint) { + uint64_t ref = 100UL; + TestTypeConverts(builtin::kDouble, CelValue::CreateUint64(ref), + double{100.0}); +} + +TEST_F(BuiltinsTest, TestDoubleConversionError_stringInvalid) { + std::string invalid = "-100e-10.0"; + TestTypeConversionError(builtin::kDouble, CelValue::CreateString(&invalid)); +} + +TEST_F(BuiltinsTest, TestDynConversions) { + TestTypeConverts(builtin::kDyn, CelValue::CreateDouble(100.1), double{100.1}); + TestTypeConverts(builtin::kDyn, CelValue::CreateInt64(100L), int64_t{100L}); + TestTypeConverts(builtin::kDyn, CelValue::CreateUint64(100UL), + uint64_t{100UL}); +} + +TEST_F(BuiltinsTest, TestIntConversions_int) { + TestTypeConverts(builtin::kInt, CelValue::CreateInt64(100L), int64_t{100L}); +} + +TEST_F(BuiltinsTest, TestIntConversions_Timestamp) { Timestamp ref; ref.set_seconds(100); - TestTypeConverts(builtin::kInt, CelProtoWrapper::CreateTimestamp(&ref), 100L); + TestTypeConverts(builtin::kInt, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{100L}); } -TEST_F(BuiltinsTest, TestTypeConversions_double) { +TEST_F(BuiltinsTest, TestIntConversions_double) { double ref = 100.1; - TestTypeConverts(builtin::kInt, CelValue::CreateDouble(ref), 100L); + TestTypeConverts(builtin::kInt, CelValue::CreateDouble(ref), int64_t{100L}); +} + +TEST_F(BuiltinsTest, TestIntConversions_string) { + std::string ref = "100"; + TestTypeConverts(builtin::kInt, CelValue::CreateString(&ref), int64_t{100L}); } -TEST_F(BuiltinsTest, TestTypeConversions_uint64) { +TEST_F(BuiltinsTest, TestIntConversions_uint) { uint64_t ref = 100; - TestTypeConverts(builtin::kInt, CelValue::CreateUint64(ref), 100L); + TestTypeConverts(builtin::kInt, CelValue::CreateUint64(ref), int64_t{100L}); +} + +TEST_F(BuiltinsTest, TestIntConversions_doubleIntMin) { + // Converting int64_t min may or may not roundtrip properly without overflow + // depending on compiler flags, so the conservative approach is to treat this + // case as overflow. + double range = std::numeric_limits::lowest(); + TestTypeConversionError(builtin::kInt, CelValue::CreateDouble(range)); +} + +TEST_F(BuiltinsTest, TestIntConversions_doubleIntMinMinus1024) { + // Converting values between [int64_t::lowest(), (int64_t::lowest() - 1024)] + // will result in an int64_t representable value, in some cases, but not all + // as the conversion depends on + double range = std::numeric_limits::lowest(); + range -= 1024L; + TestTypeConversionError(builtin::kInt, CelValue::CreateDouble(range)); +} + +TEST_F(BuiltinsTest, TestIntConversionError_doubleIntMaxMinus512) { + // Converting int64_t max - 512 to a double will not roundtrip to the original + // value, but it will roundtrip to a valid 64-bit integer. + double range = std::numeric_limits::max() - 512; + TestTypeConverts(builtin::kInt, CelValue::CreateDouble(range), + int64_t{std::numeric_limits::max() - 1023}); } -TEST_F(BuiltinsTest, TestTypeConversionError_doubleNegRange) { +TEST_F(BuiltinsTest, TestIntConversionError_doubleNegRange) { double range = -1.0e99; TestTypeConversionError(builtin::kInt, CelValue::CreateDouble(range)); } -TEST_F(BuiltinsTest, TestTypeConversionError_doublePosRange) { +TEST_F(BuiltinsTest, TestIntConversionError_doublePosRange) { double range = 1.0e99; TestTypeConversionError(builtin::kInt, CelValue::CreateDouble(range)); } -TEST_F(BuiltinsTest, TestTypeConversionError_uint64Range) { +TEST_F(BuiltinsTest, TestIntConversionError_doubleIntMax) { + // Converting int64_t max to a double results in a double value of int64_t max + // + 1 which should cause the overflow testing to trip. + double range = std::numeric_limits::max(); + TestTypeConversionError(builtin::kInt, CelValue::CreateDouble(range)); +} +TEST_F(BuiltinsTest, TestIntConversionError_doubleIntMaxMinus1) { + // Converting values between int64_t::max() and int64_t::max() - 511 will + // result in overflow errors during round-tripping. + double range = std::numeric_limits::max() - 1; + TestTypeConversionError(builtin::kInt, CelValue::CreateDouble(range)); +} + +TEST_F(BuiltinsTest, TestIntConversionError_doubleIntMaxMinus511) { + // Converting values between int64_t::max() and int64_t::max() - 511 will + // result in overflow errors during round-tripping. + double range = std::numeric_limits::max() - 511; + TestTypeConversionError(builtin::kInt, CelValue::CreateDouble(range)); +} + +TEST_F(BuiltinsTest, TestIntConversionError_doubleIntMinMinus1025) { + // Converting double values lower than int64_t::lowest() - 1025 will result in + // an overflow error. + double range = std::numeric_limits::lowest(); + range -= 1025L; + TestTypeConversionError(builtin::kInt, CelValue::CreateDouble(range)); +} + +TEST_F(BuiltinsTest, TestIntConversionError_uintRange) { uint64_t range = 18446744073709551615UL; TestTypeConversionError(builtin::kInt, CelValue::CreateUint64(range)); } +TEST_F(BuiltinsTest, TestUintConversions_double) { + double ref = 100.1; + TestTypeConverts(builtin::kUint, CelValue::CreateDouble(ref), + uint64_t{100UL}); +} + +TEST_F(BuiltinsTest, TestUintConversions_int) { + int64_t ref = 100L; + TestTypeConverts(builtin::kUint, CelValue::CreateInt64(ref), uint64_t{100UL}); +} + +TEST_F(BuiltinsTest, TestUintConversions_string) { + std::string ref = "100"; + TestTypeConverts(builtin::kUint, CelValue::CreateString(&ref), + uint64_t{100UL}); +} + +TEST_F(BuiltinsTest, TestUintConversions_uint) { + TestTypeConverts(builtin::kUint, CelValue::CreateUint64(uint64_t{100UL}), + uint64_t{100UL}); +} + +TEST_F(BuiltinsTest, TestUintConversionError_doubleNegRange) { + double range = -1.0e99; + TestTypeConversionError(builtin::kUint, CelValue::CreateDouble(range)); +} + +TEST_F(BuiltinsTest, TestUintConversionError_doublePosRange) { + double range = 1.0e99; + TestTypeConversionError(builtin::kUint, CelValue::CreateDouble(range)); +} + +TEST_F(BuiltinsTest, TestUintConversionError_intRange) { + int64_t range = -1L; + TestTypeConversionError(builtin::kUint, CelValue::CreateInt64(range)); +} + +TEST_F(BuiltinsTest, TestUintConversionError_stringInvalid) { + std::string invalid = "-100"; + TestTypeConversionError(builtin::kUint, CelValue::CreateString(&invalid)); +} + TEST_F(BuiltinsTest, TestTimestampComparisons) { Timestamp ref; Timestamp lesser; @@ -628,7 +976,7 @@ TEST_F(BuiltinsTest, TestLogicalOr) { TestLogicalOperation(op_name, true, false, true); TestLogicalOperation(op_name, false, false, false); - CelError error; + CelError error = absl::CancelledError(); // Test special cases - mix of bool and error // true || error CelValue result; @@ -679,7 +1027,7 @@ TEST_F(BuiltinsTest, TestLogicalAnd) { TestLogicalOperation(op_name, true, false, false); TestLogicalOperation(op_name, false, false, false); - CelError error; + CelError error = absl::CancelledError(); // Test special cases - mix of bool and error // true && error CelValue result; @@ -736,7 +1084,7 @@ TEST_F(BuiltinsTest, TestTernary) { } TEST_F(BuiltinsTest, TestTernaryErrorAsCondition) { - CelError cel_error; + CelError cel_error = absl::CancelledError(); std::vector args = {CelValue::CreateError(&cel_error), CelValue::CreateInt64(1), CelValue::CreateInt64(2)}; @@ -746,7 +1094,7 @@ TEST_F(BuiltinsTest, TestTernaryErrorAsCondition) { PerformRun(builtin::kTernary, {}, args, &result_value)); ASSERT_EQ(result_value.IsError(), true); - ASSERT_EQ(result_value.ErrorOrDie(), &cel_error); + ASSERT_EQ(*result_value.ErrorOrDie(), cel_error); } TEST_F(BuiltinsTest, TestTernaryStringAsCondition) { @@ -773,6 +1121,25 @@ class FakeList : public CelList { std::vector values_; }; +class FakeErrorMap : public CelMap { + public: + FakeErrorMap() {} + + int size() const override { return 0; } + + absl::StatusOr Has(const CelValue& key) const override { + return absl::InvalidArgumentError("bad key type"); + } + + absl::optional operator[](CelValue key) const override { + return absl::nullopt; + } + + absl::StatusOr ListKeys() const override { + return absl::UnimplementedError("CelMap::ListKeys is not implemented"); + } +}; + template class FakeMap : public CelMap { public: @@ -785,7 +1152,7 @@ class FakeMap : public CelMap { for (auto kv : data) { keys.push_back(create_cel_value(kv.first)); } - keys_ = absl::make_unique(keys); + keys_ = std::make_unique(keys); } int size() const override { return data_.size(); } @@ -802,7 +1169,9 @@ class FakeMap : public CelMap { return it->second; } - const CelList* ListKeys() const override { return keys_.get(); } + absl::StatusOr ListKeys() const override { + return keys_.get(); + } private: std::map data_; @@ -810,6 +1179,18 @@ class FakeMap : public CelMap { std::function(CelValue)> get_cel_value_; }; +class FakeBoolMap : public FakeMap { + public: + explicit FakeBoolMap(const std::map& data) + : FakeMap(data, CelValue::CreateBool, + [](CelValue v) -> absl::optional { + if (!v.IsBool()) { + return absl::nullopt; + } + return v.BoolOrDie(); + }) {} +}; + class FakeInt64Map : public FakeMap { public: explicit FakeInt64Map(const std::map& data) @@ -848,18 +1229,6 @@ class FakeStringMap : public FakeMap { }) {} }; -class FakeBoolMap : public FakeMap { - public: - explicit FakeBoolMap(const std::map& data) - : FakeMap(data, CelValue::CreateBool, - [](CelValue v) -> absl::optional { - if (!v.IsBool()) { - return absl::nullopt; - } - return v.BoolOrDie(); - }) {} -}; - // Test list index access function TEST_F(BuiltinsTest, ListIndex) { constexpr int64_t kValues[] = {3, 4, 5, 6}; @@ -1107,8 +1476,16 @@ TEST_F(BuiltinsTest, StringSize) { builtin::kSize, {}, {CelValue::CreateString(&test)}, &result_value)); ASSERT_EQ(result_value.IsInt64(), true); + ASSERT_EQ(result_value.Int64OrDie(), 9); +} - ASSERT_EQ(result_value.Int64OrDie(), test.size()); +TEST_F(BuiltinsTest, StringUnicodeSize) { + std::string test = "πέντε"; + CelValue result_value; + ASSERT_NO_FATAL_FAILURE(PerformRun( + builtin::kSize, {}, {CelValue::CreateString(&test)}, &result_value)); + ASSERT_EQ(result_value.IsInt64(), true); + ASSERT_EQ(result_value.Int64OrDie(), 5); } TEST_F(BuiltinsTest, BytesSize) { @@ -1155,8 +1532,6 @@ TEST_F(BuiltinsTest, MapSize) { } TEST_F(BuiltinsTest, TestBoolListIn) { - std::vector values; - FakeList cel_list({CelValue::CreateBool(false), CelValue::CreateBool(false)}); TestInList(&cel_list, CelValue::CreateBool(false), true); @@ -1164,8 +1539,6 @@ TEST_F(BuiltinsTest, TestBoolListIn) { } TEST_F(BuiltinsTest, TestInt64ListIn) { - std::vector values; - FakeList cel_list({CelValue::CreateInt64(1), CelValue::CreateInt64(2)}); TestInList(&cel_list, CelValue::CreateInt64(2), true); @@ -1173,8 +1546,6 @@ TEST_F(BuiltinsTest, TestInt64ListIn) { } TEST_F(BuiltinsTest, TestUint64ListIn) { - std::vector values; - FakeList cel_list({CelValue::CreateUint64(1), CelValue::CreateUint64(2)}); TestInList(&cel_list, CelValue::CreateUint64(2), true); @@ -1182,8 +1553,6 @@ TEST_F(BuiltinsTest, TestUint64ListIn) { } TEST_F(BuiltinsTest, TestDoubleListIn) { - std::vector values; - FakeList cel_list({CelValue::CreateDouble(1), CelValue::CreateDouble(2)}); TestInList(&cel_list, CelValue::CreateDouble(2), true); @@ -1191,8 +1560,6 @@ TEST_F(BuiltinsTest, TestDoubleListIn) { } TEST_F(BuiltinsTest, TestStringListIn) { - std::vector values; - std::string v0 = "test0"; std::string v1 = "test1"; std::string v2 = "test2"; @@ -1216,6 +1583,85 @@ TEST_F(BuiltinsTest, TestBytesListIn) { TestInList(&cel_list, CelValue::CreateBytes(&v2), false); } +TEST_F(HeterogeneousEqualityTest, MixedTypes) { + FakeList cel_list({CelValue::CreateDuration(absl::Seconds(1)), + CelValue::CreateNull(), CelValue::CreateInt64(1)}); + + ASSERT_NO_FATAL_FAILURE( + TestInList(&cel_list, CelValue::CreateDuration(absl::Seconds(1)), true)); + ASSERT_NO_FATAL_FAILURE( + TestInList(&cel_list, CelValue::CreateInt64(1), true)); + + ASSERT_NO_FATAL_FAILURE( + TestInList(&cel_list, CelValue::CreateUint64(1), true)); + + ASSERT_NO_FATAL_FAILURE( + TestInList(&cel_list, CelValue::CreateInt64(2), false)); + ASSERT_NO_FATAL_FAILURE( + TestInList(&cel_list, CelValue::CreateStringView("abc"), false)); +} + +TEST_F(HeterogeneousEqualityTest, NullIn) { + FakeList cel_list({CelValue::CreateInt64(0), CelValue::CreateNull(), + CelValue::CreateInt64(1)}); + + ASSERT_NO_FATAL_FAILURE( + TestInList(&cel_list, CelValue::CreateInt64(1), true)); + ASSERT_NO_FATAL_FAILURE(TestInList(&cel_list, CelValue::CreateNull(), true)); + ASSERT_NO_FATAL_FAILURE( + TestInList(&cel_list, CelValue::CreateInt64(2), false)); +} + +TEST_F(HeterogeneousEqualityTest, NullNotIn) { + FakeList cel_list({CelValue::CreateInt64(0), CelValue::CreateInt64(1)}); + + ASSERT_NO_FATAL_FAILURE(TestInList(&cel_list, CelValue::CreateNull(), false)); +} + +TEST_F(BuiltinsTest, TestMapInError) { + FakeErrorMap cel_map; + std::vector kValues = { + CelValue::CreateBool(true), + CelValue::CreateInt64(1), + CelValue::CreateStringView("hello"), + CelValue::CreateUint64(2), + }; + + options_.enable_heterogeneous_equality = true; + for (auto key : kValues) { + CelValue result_value; + ASSERT_NO_FATAL_FAILURE(PerformRun( + builtin::kIn, {}, {key, CelValue::CreateMap(&cel_map)}, &result_value)); + ASSERT_TRUE(result_value.IsBool()) + << key.DebugString() << " : " << result_value.DebugString(); + EXPECT_FALSE(result_value.BoolOrDie()); + } + + options_.enable_heterogeneous_equality = false; + for (auto key : kValues) { + CelValue result_value; + ASSERT_NO_FATAL_FAILURE(PerformRun( + builtin::kIn, {}, {key, CelValue::CreateMap(&cel_map)}, &result_value)); + + EXPECT_TRUE(result_value.IsError()); + EXPECT_EQ(result_value.ErrorOrDie()->message(), "bad key type"); + EXPECT_EQ(result_value.ErrorOrDie()->code(), + absl::StatusCode::kInvalidArgument); + } +} + +TEST_F(BuiltinsTest, TestBoolMapIn) { + constexpr bool kValues[] = {true, true}; + std::map data; + for (auto value : kValues) { + data[value] = CelValue::CreateBool(value); + } + FakeBoolMap cel_map(data); + TestInMap(&cel_map, CelValue::CreateBool(true), true); + TestInMap(&cel_map, CelValue::CreateBool(false), false); + TestInMap(&cel_map, CelValue::CreateUint64(3), false); +} + TEST_F(BuiltinsTest, TestInt64MapIn) { constexpr int64_t kValues[] = {3, -4, 5, -6}; std::map data; @@ -1223,9 +1669,21 @@ TEST_F(BuiltinsTest, TestInt64MapIn) { data[value] = CelValue::CreateInt64(value * value); } FakeInt64Map cel_map(data); + options_.enable_heterogeneous_equality = false; TestInMap(&cel_map, CelValue::CreateInt64(-4), true); TestInMap(&cel_map, CelValue::CreateInt64(4), false); TestInMap(&cel_map, CelValue::CreateUint64(3), false); + TestInMap(&cel_map, CelValue::CreateUint64(4), false); + + options_.enable_heterogeneous_equality = true; + TestInMap(&cel_map, CelValue::CreateUint64(3), true); + TestInMap(&cel_map, CelValue::CreateUint64(4), false); + TestInMap(&cel_map, CelValue::CreateDouble(NAN), false); + TestInMap(&cel_map, CelValue::CreateDouble(-4.0), true); + TestInMap(&cel_map, CelValue::CreateDouble(-4.1), false); + TestInMap(&cel_map, + CelValue::CreateDouble(std::numeric_limits::max()), + false); } TEST_F(BuiltinsTest, TestUint64MapIn) { @@ -1235,9 +1693,17 @@ TEST_F(BuiltinsTest, TestUint64MapIn) { data[value] = CelValue::CreateUint64(value * value); } FakeUint64Map cel_map(data); + options_.enable_heterogeneous_equality = false; TestInMap(&cel_map, CelValue::CreateUint64(4), true); TestInMap(&cel_map, CelValue::CreateUint64(44), false); TestInMap(&cel_map, CelValue::CreateInt64(4), false); + + options_.enable_heterogeneous_equality = true; + TestInMap(&cel_map, CelValue::CreateInt64(-1), false); + TestInMap(&cel_map, CelValue::CreateInt64(4), true); + TestInMap(&cel_map, CelValue::CreateDouble(4.0), true); + TestInMap(&cel_map, CelValue::CreateDouble(-4.0), false); + TestInMap(&cel_map, CelValue::CreateDouble(7.0), false); } TEST_F(BuiltinsTest, TestStringMapIn) { @@ -1259,32 +1725,41 @@ TEST_F(BuiltinsTest, TestInt64Arithmetics) { TestArithmeticalOperationInt64(builtin::kDivide, 10, 5, 2); } -TEST_F(BuiltinsTest, TestInt64DivisionByZero) { - CelValue result_value; - - ASSERT_NO_FATAL_FAILURE(PerformRun( - builtin::kDivide, {}, - {CelValue::CreateInt64(1), CelValue::CreateInt64(0)}, &result_value)); - - ASSERT_TRUE(result_value.IsError()); +TEST_F(BuiltinsTest, TestInt64ArithmeticOverflow) { + int64_t min = std::numeric_limits::lowest(); + int64_t max = std::numeric_limits::max(); + TestArithmeticalErrorInt64(builtin::kAdd, max, 1, + absl::StatusCode::kOutOfRange); + TestArithmeticalErrorInt64(builtin::kSubtract, min, max, + absl::StatusCode::kOutOfRange); + TestArithmeticalErrorInt64(builtin::kMultiply, max, 2, + absl::StatusCode::kOutOfRange); + TestArithmeticalErrorInt64(builtin::kModulo, min, -1, + absl::StatusCode::kOutOfRange); + TestArithmeticalErrorInt64(builtin::kDivide, min, -1, + absl::StatusCode::kOutOfRange); + TestArithmeticalErrorInt64(builtin::kDivide, min, 0, + absl::StatusCode::kInvalidArgument); } TEST_F(BuiltinsTest, TestUint64Arithmetics) { TestArithmeticalOperationUint64(builtin::kAdd, 2, 3, 5); - TestArithmeticalOperationUint64(builtin::kSubtract, 2, 3, - static_cast(-1)); + TestArithmeticalOperationUint64(builtin::kSubtract, 3, 2, 1); TestArithmeticalOperationUint64(builtin::kMultiply, 2, 3, 6); TestArithmeticalOperationUint64(builtin::kDivide, 10, 5, 2); } -TEST_F(BuiltinsTest, TestUint64DivisionByZero) { +TEST_F(BuiltinsTest, TestUint64ArithmeticOverflow) { CelValue result_value; - - ASSERT_NO_FATAL_FAILURE(PerformRun( - builtin::kDivide, {}, - {CelValue::CreateUint64(1), CelValue::CreateUint64(0)}, &result_value)); - - ASSERT_TRUE(result_value.IsError()); + uint64_t max = std::numeric_limits::max(); + TestArithmeticalErrorUint64(builtin::kAdd, max, 1, + absl::StatusCode::kOutOfRange); + TestArithmeticalErrorUint64(builtin::kSubtract, 2, 3, + absl::StatusCode::kOutOfRange); + TestArithmeticalErrorUint64(builtin::kMultiply, max, 2, + absl::StatusCode::kOutOfRange); + TestArithmeticalErrorUint64(builtin::kDivide, 1, 0, + absl::StatusCode::kInvalidArgument); } TEST_F(BuiltinsTest, TestDoubleArithmetics) { @@ -1412,15 +1887,6 @@ TEST_F(BuiltinsTest, MatchesMaxSize) { EXPECT_TRUE(result_value.IsError()); } -TEST_F(BuiltinsTest, StringToInt) { - std::string target = "-42"; - std::vector args = {CelValue::CreateString(&target)}; - CelValue result_value; - ASSERT_NO_FATAL_FAILURE(PerformRun(builtin::kInt, {}, args, &result_value)); - ASSERT_TRUE(result_value.IsInt64()); - EXPECT_EQ(result_value.Int64OrDie(), -42); -} - TEST_F(BuiltinsTest, StringToIntNonInt) { std::string target = "not_a_number"; std::vector args = {CelValue::CreateString(&target)}; @@ -1466,6 +1932,15 @@ TEST_F(BuiltinsTest, BytesToString) { EXPECT_EQ(result_value.StringOrDie().value(), "abcd"); } +TEST_F(BuiltinsTest, BytesToStringInvalid) { + std::string input = "\xFF"; + std::vector args = {CelValue::CreateBytes(&input)}; + CelValue result_value; + ASSERT_NO_FATAL_FAILURE( + PerformRun(builtin::kString, {}, args, &result_value)); + ASSERT_TRUE(result_value.IsError()); +} + TEST_F(BuiltinsTest, StringToString) { std::string input = "abcd"; std::vector args = {CelValue::CreateString(&input)}; @@ -1476,9 +1951,58 @@ TEST_F(BuiltinsTest, StringToString) { EXPECT_EQ(result_value.StringOrDie().value(), "abcd"); } -} // namespace +// Type operations +TEST_F(BuiltinsTest, TypeComparisons) { + std::vector> paired_values; + + paired_values.push_back( + {CelValue::CreateBool(false), CelValue::CreateBool(true)}); + paired_values.push_back( + {CelValue::CreateInt64(-1), CelValue::CreateInt64(1)}); + paired_values.push_back( + {CelValue::CreateUint64(1), CelValue::CreateUint64(2)}); + paired_values.push_back( + {CelValue::CreateDouble(1.), CelValue::CreateDouble(2.)}); + + std::string str1 = "test1"; + std::string str2 = "test2"; + paired_values.push_back( + {CelValue::CreateString(&str1), CelValue::CreateString(&str2)}); + paired_values.push_back( + {CelValue::CreateBytes(&str1), CelValue::CreateBytes(&str2)}); + + FakeList cel_list1({CelValue::CreateBool(false)}); + FakeList cel_list2({CelValue::CreateBool(true)}); + paired_values.push_back( + {CelValue::CreateList(&cel_list1), CelValue::CreateList(&cel_list2)}); + + std::map data1; + std::map data2; + FakeInt64Map cel_map1(data1); + FakeInt64Map cel_map2(data2); + paired_values.push_back( + {CelValue::CreateMap(&cel_map1), CelValue::CreateMap(&cel_map2)}); + + for (size_t i = 0; i < paired_values.size(); i++) { + for (size_t j = 0; j < paired_values.size(); j++) { + CelValue result1; + CelValue result2; + + PerformRun(builtin::kType, {}, {paired_values[i].first}, &result1); + PerformRun(builtin::kType, {}, {paired_values[j].second}, &result2); + + ASSERT_TRUE(result1.IsCelType()) << "Unexpected result for value at index" + << i << ":" << result1.DebugString(); + ASSERT_TRUE(result2.IsCelType()) << "Unexpected result for value at index" + << j << ":" << result2.DebugString(); + if (i == j) { + ASSERT_EQ(result1.CelTypeOrDie(), result2.CelTypeOrDie()); + } else { + ASSERT_TRUE(result1.CelTypeOrDie() != result2.CelTypeOrDie()); + } + } + } +} -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace +} // namespace google::api::expr::runtime diff --git a/eval/public/cel_attribute.cc b/eval/public/cel_attribute.cc index 3be2885a4..70525a04d 100644 --- a/eval/public/cel_attribute.cc +++ b/eval/public/cel_attribute.cc @@ -1,34 +1,37 @@ #include "eval/public/cel_attribute.h" #include +#include +#include +#include +#include #include "absl/strings/string_view.h" -#include "absl/types/variant.h" +#include "eval/public/cel_value.h" + +namespace google::api::expr::runtime { -namespace google { -namespace api { -namespace expr { -namespace runtime { namespace { +// Visitation for attribute qualifier kinds struct QualifierVisitor { CelAttributeQualifierPattern operator()(absl::string_view v) { if (v == "*") { return CelAttributeQualifierPattern::CreateWildcard(); } - return CelAttributeQualifierPattern::Create(CelValue::CreateStringView(v)); + return CelAttributeQualifierPattern::OfString(std::string(v)); } CelAttributeQualifierPattern operator()(int64_t v) { - return CelAttributeQualifierPattern::Create(CelValue::CreateInt64(v)); + return CelAttributeQualifierPattern::OfInt(v); } CelAttributeQualifierPattern operator()(uint64_t v) { - return CelAttributeQualifierPattern::Create(CelValue::CreateUint64(v)); + return CelAttributeQualifierPattern::OfUint(v); } CelAttributeQualifierPattern operator()(bool v) { - return CelAttributeQualifierPattern::Create(CelValue::CreateBool(v)); + return CelAttributeQualifierPattern::OfBool(v); } CelAttributeQualifierPattern operator()(CelAttributeQualifierPattern v) { @@ -38,10 +41,43 @@ struct QualifierVisitor { } // namespace +CelAttributeQualifierPattern CreateCelAttributeQualifierPattern( + const CelValue& value) { + switch (value.type()) { + case cel::Kind::kInt64: + return CelAttributeQualifierPattern::OfInt(value.Int64OrDie()); + case cel::Kind::kUint64: + return CelAttributeQualifierPattern::OfUint(value.Uint64OrDie()); + case cel::Kind::kString: + return CelAttributeQualifierPattern::OfString( + std::string(value.StringOrDie().value())); + case cel::Kind::kBool: + return CelAttributeQualifierPattern::OfBool(value.BoolOrDie()); + default: + return CelAttributeQualifierPattern(CelAttributeQualifier()); + } +} + +CelAttributeQualifier CreateCelAttributeQualifier(const CelValue& value) { + switch (value.type()) { + case cel::Kind::kInt64: + return CelAttributeQualifier::OfInt(value.Int64OrDie()); + case cel::Kind::kUint64: + return CelAttributeQualifier::OfUint(value.Uint64OrDie()); + case cel::Kind::kString: + return CelAttributeQualifier::OfString( + std::string(value.StringOrDie().value())); + case cel::Kind::kBool: + return CelAttributeQualifier::OfBool(value.BoolOrDie()); + default: + return CelAttributeQualifier(); + } +} + CelAttributePattern CreateCelAttributePattern( absl::string_view variable, - std::initializer_list> + std::initializer_list> path_spec) { std::vector path; path.reserve(path_spec.size()); @@ -51,7 +87,4 @@ CelAttributePattern CreateCelAttributePattern( return CelAttributePattern(std::string(variable), std::move(path)); } -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime diff --git a/eval/public/cel_attribute.h b/eval/public/cel_attribute.h index bc3f064e7..959fff75e 100644 --- a/eval/public/cel_attribute.h +++ b/eval/public/cel_attribute.h @@ -1,267 +1,64 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_ATTRIBUTE_PATTERN_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_ATTRIBUTE_PATTERN_H_ +#include + #include #include +#include #include - -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include +#include +#include +#include +#include +#include + +#include "cel/expr/syntax.pb.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" #include "absl/types/optional.h" #include "absl/types/variant.h" +#include "base/attribute.h" #include "eval/public/cel_value.h" -#include "eval/public/cel_value_internal.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace google::api::expr::runtime { // CelAttributeQualifier represents a segment in // attribute resolutuion path. A segment can be qualified by values of -// following types: string/int64_t/uint64/bool. -class CelAttributeQualifier { - private: - // Helper class, used to implement CelAttributeQualifier::operator==. - class EqualVisitor { - public: - template - class NestedEqualVisitor { - public: - explicit NestedEqualVisitor(const T& arg) : arg_(arg) {} - - template - bool operator()(const U&) const { - return false; - } - - bool operator()(const T& other) const { return other == arg_; } - - private: - const T& arg_; - }; - - explicit EqualVisitor(const CelValue& other) : other_(other) {} - - template - bool operator()(const Type& arg) { - return other_.template Visit(NestedEqualVisitor(arg)); - } - - private: - const CelValue& other_; - }; - - CelValue value_; - - explicit CelAttributeQualifier(CelValue value) : value_(value) {} - - public: - // Factory method. - static CelAttributeQualifier Create(CelValue value) { - return CelAttributeQualifier(value); - } +// following types: string/int64_t/uint64_t/bool. +using CelAttributeQualifier = ::cel::AttributeQualifier; - template - T Visit(Op&& operation) const { - return value_.Visit(operation); - } - - // Family of Get... methods. Return values if requested type matches the - // stored one. - absl::optional GetInt64Key() const { - return (value_.IsInt64()) ? absl::optional(value_.Int64OrDie()) - : absl::nullopt; - } - - absl::optional GetUint64Key() const { - return (value_.IsUint64()) ? absl::optional(value_.Uint64OrDie()) - : absl::nullopt; - } - - absl::optional GetStringKey() const { - return (value_.IsString()) - ? absl::optional(value_.StringOrDie().value()) - : absl::nullopt; - } - - absl::optional GetBoolKey() const { - return (value_.IsBool()) ? absl::optional(value_.BoolOrDie()) - : absl::nullopt; - } - - bool operator==(const CelAttributeQualifier& other) const { - return IsMatch(other.value_); - } - - bool IsMatch(const CelValue& cel_value) const { - return value_.template Visit(EqualVisitor(cel_value)); - } - - bool IsMatch(absl::string_view other_key) { - absl::optional key = GetStringKey(); - return (key.has_value() && key.value() == other_key); - } -}; +// CelAttribute represents resolved attribute path. +using CelAttribute = ::cel::Attribute; // CelAttributeQualifierPattern matches a segment in // attribute resolutuion path. CelAttributeQualifierPattern is capable of -// matching path elements of types string/int64_t/uint64/bool. -class CelAttributeQualifierPattern { - private: - // Qualifier value. If not set, treated as wildcard. - absl::optional value_; - - CelAttributeQualifierPattern(absl::optional value) - : value_(value) {} - - public: - // Factory method. - static CelAttributeQualifierPattern Create(CelValue value) { - return CelAttributeQualifierPattern(CelAttributeQualifier::Create(value)); - } - - static CelAttributeQualifierPattern CreateWildcard() { - return CelAttributeQualifierPattern(absl::nullopt); - } - - bool IsWildcard() const { return !value_.has_value(); } - - bool IsMatch(const CelAttributeQualifier& qualifier) const { - if (IsWildcard()) return true; - return value_.value() == qualifier; - } - - bool IsMatch(const CelValue& cel_value) const { - if (!value_.has_value()) { - switch (cel_value.type()) { - case CelValue::Type::kInt64: - case CelValue::Type::kUint64: - case CelValue::Type::kString: - case CelValue::Type::kBool: { - return true; - } - default: { - return false; - } - } - } - return value_.value().IsMatch(cel_value); - } - - bool IsMatch(absl::string_view other_key) { - if (!value_.has_value()) return true; - return value_.value().IsMatch(other_key); - } -}; - -// CelAttribute represents resolved attribute path. -class CelAttribute { - public: - CelAttribute(google::api::expr::v1alpha1::Expr variable, - std::vector qualifier_path) - : variable_(std::move(variable)), - qualifier_path_(std::move(qualifier_path)) {} - - const google::api::expr::v1alpha1::Expr& variable() const { return variable_; } - - const std::vector& qualifier_path() const { - return qualifier_path_; - } - - bool operator==(const CelAttribute& other) const { - // TODO(issues/41) we only support Ident-rooted attributes at the moment. - if (!variable().has_ident_expr() || !other.variable().has_ident_expr()) { - return false; - } - - if (variable().ident_expr().name() != - other.variable().ident_expr().name()) { - return false; - } - - if (qualifier_path().size() != other.qualifier_path().size()) { - return false; - } - - for (size_t i = 0; i < qualifier_path().size(); i++) { - if (!(qualifier_path()[i] == other.qualifier_path()[i])) { - return false; - } - } - - return true; - } - - private: - google::api::expr::v1alpha1::Expr variable_; - std::vector qualifier_path_; -}; +// matching path elements of types string/int64_t/uint64_t/bool. +using CelAttributeQualifierPattern = ::cel::AttributeQualifierPattern; // CelAttributePattern is a fully-qualified absolute attribute path pattern. // Supported segments steps in the path are: // - field selection; // - map lookup by key; // - list access by index. -class CelAttributePattern { - public: - // MatchType enum specifies how closely pattern is matching the attribute: - enum class MatchType { - NONE, // Pattern does not match attribute itself nor its children - PARTIAL, // Pattern matches an entity nested within attribute; - FULL // Pattern matches an attribute itself. - }; - - CelAttributePattern(std::string variable, - std::vector qualifier_path) - : variable_(std::move(variable)), - qualifier_path_(std::move(qualifier_path)) {} - - absl::string_view variable() const { return variable_; } - - const std::vector& qualifier_path() const { - return qualifier_path_; - } - - // Matches the pattern to an attribute. - // Distinguishes between no-match, partial match and full match cases. - MatchType IsMatch(const CelAttribute& attribute) const { - MatchType result = MatchType::NONE; - if (attribute.variable().ident_expr().name() != variable_) { - return result; - } - - auto max_index = qualifier_path().size(); - result = MatchType::FULL; - if (qualifier_path().size() > attribute.qualifier_path().size()) { - max_index = attribute.qualifier_path().size(); - result = MatchType::PARTIAL; - } +using CelAttributePattern = ::cel::AttributePattern; - for (size_t i = 0; i < max_index; i++) { - if (!(qualifier_path()[i].IsMatch(attribute.qualifier_path()[i]))) { - return MatchType::NONE; - } - } - return result; - } +CelAttributeQualifierPattern CreateCelAttributeQualifierPattern( + const CelValue& value); - private: - std::string variable_; - std::vector qualifier_path_; -}; +CelAttributeQualifier CreateCelAttributeQualifier(const CelValue& value); // Short-hand helper for creating |CelAttributePattern|s. string_view arguments // must outlive the returned pattern. CelAttributePattern CreateCelAttributePattern( absl::string_view variable, - std::initializer_list> + std::initializer_list> path_spec = {}); -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_ATTRIBUTE_PATTERN_H_ diff --git a/eval/public/cel_attribute_test.cc b/eval/public/cel_attribute_test.cc index 0bba473b6..b72189332 100644 --- a/eval/public/cel_attribute_test.cc +++ b/eval/public/cel_attribute_test.cc @@ -1,32 +1,36 @@ #include "eval/public/cel_attribute.h" -#include "google/protobuf/arena.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" +#include + +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "absl/types/optional.h" #include "eval/public/cel_value.h" #include "eval/public/structs/cel_proto_wrapper.h" +#include "internal/testing.h" +#include "google/protobuf/arena.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace google::api::expr::runtime { +namespace { -using ::google::protobuf::Duration; -using ::google::protobuf::Timestamp; +using cel::expr::Expr; -using testing::Eq; -using testing::IsEmpty; -using testing::SizeIs; - -namespace { +using ::absl_testing::IsOkAndHolds; +using ::absl_testing::StatusIs; +using ::testing::Eq; +using ::testing::IsEmpty; +using ::testing::SizeIs; class DummyMap : public CelMap { public: absl::optional operator[](CelValue value) const override { return CelValue::CreateNull(); } - const CelList* ListKeys() const override { return nullptr; } + absl::StatusOr ListKeys() const override { + return absl::UnimplementedError("CelMap::ListKeys is not implemented"); + } int size() const override { return 0; } }; @@ -41,28 +45,30 @@ class DummyList : public CelList { }; TEST(CelAttributeQualifierTest, TestBoolAccess) { - auto qualifier = CelAttributeQualifier::Create(CelValue::CreateBool(true)); + auto qualifier = CreateCelAttributeQualifier(CelValue::CreateBool(true)); EXPECT_FALSE(qualifier.GetStringKey().has_value()); EXPECT_FALSE(qualifier.GetInt64Key().has_value()); EXPECT_FALSE(qualifier.GetUint64Key().has_value()); EXPECT_TRUE(qualifier.GetBoolKey().has_value()); EXPECT_THAT(qualifier.GetBoolKey().value(), Eq(true)); + EXPECT_THAT(qualifier.AsString(), IsOkAndHolds("true")); } TEST(CelAttributeQualifierTest, TestInt64Access) { - auto qualifier = CelAttributeQualifier::Create(CelValue::CreateInt64(1)); + auto qualifier = CreateCelAttributeQualifier(CelValue::CreateInt64(-1)); EXPECT_FALSE(qualifier.GetBoolKey().has_value()); EXPECT_FALSE(qualifier.GetStringKey().has_value()); EXPECT_FALSE(qualifier.GetUint64Key().has_value()); EXPECT_TRUE(qualifier.GetInt64Key().has_value()); - EXPECT_THAT(qualifier.GetInt64Key().value(), Eq(1)); + EXPECT_THAT(qualifier.GetInt64Key().value(), Eq(-1)); + EXPECT_THAT(qualifier.AsString(), IsOkAndHolds("-1")); } TEST(CelAttributeQualifierTest, TestUint64Access) { - auto qualifier = CelAttributeQualifier::Create(CelValue::CreateUint64(1)); + auto qualifier = CreateCelAttributeQualifier(CelValue::CreateUint64(1)); EXPECT_FALSE(qualifier.GetBoolKey().has_value()); EXPECT_FALSE(qualifier.GetStringKey().has_value()); @@ -70,11 +76,12 @@ TEST(CelAttributeQualifierTest, TestUint64Access) { EXPECT_TRUE(qualifier.GetUint64Key().has_value()); EXPECT_THAT(qualifier.GetUint64Key().value(), Eq(1UL)); + EXPECT_THAT(qualifier.AsString(), IsOkAndHolds("1")); } TEST(CelAttributeQualifierTest, TestStringAccess) { const std::string test = "test"; - auto qualifier = CelAttributeQualifier::Create(CelValue::CreateString(&test)); + auto qualifier = CreateCelAttributeQualifier(CelValue::CreateString(&test)); EXPECT_FALSE(qualifier.GetBoolKey().has_value()); EXPECT_FALSE(qualifier.GetInt64Key().has_value()); @@ -82,201 +89,122 @@ TEST(CelAttributeQualifierTest, TestStringAccess) { EXPECT_TRUE(qualifier.GetStringKey().has_value()); EXPECT_THAT(qualifier.GetStringKey().value(), Eq("test")); + EXPECT_THAT(qualifier.AsString(), IsOkAndHolds("test")); } void TestAllInequalities(const CelAttributeQualifier& qualifier) { EXPECT_FALSE(qualifier == - CelAttributeQualifier::Create(CelValue::CreateBool(false))); + CreateCelAttributeQualifier(CelValue::CreateBool(false))); EXPECT_FALSE(qualifier == - CelAttributeQualifier::Create(CelValue::CreateInt64(0))); + CreateCelAttributeQualifier(CelValue::CreateInt64(0))); EXPECT_FALSE(qualifier == - CelAttributeQualifier::Create(CelValue::CreateUint64(0))); + CreateCelAttributeQualifier(CelValue::CreateUint64(0))); const std::string test = "Those are not the droids you are looking for."; EXPECT_FALSE(qualifier == - CelAttributeQualifier::Create(CelValue::CreateString(&test))); + CreateCelAttributeQualifier(CelValue::CreateString(&test))); } TEST(CelAttributeQualifierTest, TestBoolComparison) { - auto qualifier = CelAttributeQualifier::Create(CelValue::CreateBool(true)); + auto qualifier = CreateCelAttributeQualifier(CelValue::CreateBool(true)); TestAllInequalities(qualifier); EXPECT_TRUE(qualifier == - CelAttributeQualifier::Create(CelValue::CreateBool(true))); + CreateCelAttributeQualifier(CelValue::CreateBool(true))); } TEST(CelAttributeQualifierTest, TestInt64Comparison) { - auto qualifier = CelAttributeQualifier::Create(CelValue::CreateInt64(true)); + auto qualifier = CreateCelAttributeQualifier(CelValue::CreateInt64(true)); TestAllInequalities(qualifier); EXPECT_TRUE(qualifier == - CelAttributeQualifier::Create(CelValue::CreateInt64(true))); + CreateCelAttributeQualifier(CelValue::CreateInt64(true))); } TEST(CelAttributeQualifierTest, TestUint64Comparison) { - auto qualifier = CelAttributeQualifier::Create(CelValue::CreateUint64(true)); + auto qualifier = CreateCelAttributeQualifier(CelValue::CreateUint64(true)); TestAllInequalities(qualifier); EXPECT_TRUE(qualifier == - CelAttributeQualifier::Create(CelValue::CreateUint64(true))); + CreateCelAttributeQualifier(CelValue::CreateUint64(true))); } TEST(CelAttributeQualifierTest, TestStringComparison) { const std::string kTest = "test"; - auto qualifier = - CelAttributeQualifier::Create(CelValue::CreateString(&kTest)); + auto qualifier = CreateCelAttributeQualifier(CelValue::CreateString(&kTest)); TestAllInequalities(qualifier); EXPECT_TRUE(qualifier == - CelAttributeQualifier::Create(CelValue::CreateString(&kTest))); -} - -void TestAllCelValueMismatches(const CelAttributeQualifierPattern& qualifier) { - EXPECT_FALSE(qualifier.IsMatch(CelValue::CreateNull())); - EXPECT_FALSE(qualifier.IsMatch(CelValue::CreateBool(false))); - EXPECT_FALSE(qualifier.IsMatch(CelValue::CreateInt64(0))); - EXPECT_FALSE(qualifier.IsMatch(CelValue::CreateUint64(0))); - EXPECT_FALSE(qualifier.IsMatch(CelValue::CreateDouble(0.))); - - const std::string kStr = "Those are not the droids you are looking for."; - EXPECT_FALSE(qualifier.IsMatch(CelValue::CreateString(&kStr))); - EXPECT_FALSE(qualifier.IsMatch(CelValue::CreateBytes(&kStr))); - - Duration msg_duration; - msg_duration.set_seconds(0); - msg_duration.set_nanos(0); - EXPECT_FALSE( - qualifier.IsMatch(CelProtoWrapper::CreateDuration(&msg_duration))); - - Timestamp msg_timestamp; - msg_timestamp.set_seconds(0); - msg_timestamp.set_nanos(0); - EXPECT_FALSE( - qualifier.IsMatch(CelProtoWrapper::CreateTimestamp(&msg_timestamp))); - - DummyList dummy_list; - EXPECT_FALSE(qualifier.IsMatch(CelValue::CreateList(&dummy_list))); - - DummyMap dummy_map; - EXPECT_FALSE(qualifier.IsMatch(CelValue::CreateMap(&dummy_map))); - - google::protobuf::Arena arena; - EXPECT_FALSE(qualifier.IsMatch(CreateErrorValue(&arena, kStr))); + CreateCelAttributeQualifier(CelValue::CreateString(&kTest))); } void TestAllQualifierMismatches(const CelAttributeQualifierPattern& qualifier) { const std::string test = "Those are not the droids you are looking for."; EXPECT_FALSE(qualifier.IsMatch( - CelAttributeQualifier::Create(CelValue::CreateBool(false)))); - EXPECT_FALSE(qualifier.IsMatch( - CelAttributeQualifier::Create(CelValue::CreateInt64(0)))); + CreateCelAttributeQualifier(CelValue::CreateBool(false)))); + EXPECT_FALSE( + qualifier.IsMatch(CreateCelAttributeQualifier(CelValue::CreateInt64(0)))); EXPECT_FALSE(qualifier.IsMatch( - CelAttributeQualifier::Create(CelValue::CreateUint64(0)))); + CreateCelAttributeQualifier(CelValue::CreateUint64(0)))); EXPECT_FALSE(qualifier.IsMatch( - CelAttributeQualifier::Create(CelValue::CreateString(&test)))); -} - -TEST(CelAttributeQualifierPatternTest, TestCelValueBoolMatch) { - auto qualifier = - CelAttributeQualifierPattern::Create(CelValue::CreateBool(true)); - - TestAllCelValueMismatches(qualifier); - - CelValue value_match = CelValue::CreateBool(true); - - EXPECT_TRUE(qualifier.IsMatch(value_match)); -} - -TEST(CelAttributeQualifierPatternTest, TestCelValueInt64Match) { - auto qualifier = - CelAttributeQualifierPattern::Create(CelValue::CreateInt64(1)); - - TestAllCelValueMismatches(qualifier); - - CelValue value_match = CelValue::CreateInt64(1); - - EXPECT_TRUE(qualifier.IsMatch(value_match)); -} - -TEST(CelAttributeQualifierPatternTest, TestCelValueUint64Match) { - auto qualifier = - CelAttributeQualifierPattern::Create(CelValue::CreateUint64(1)); - - TestAllCelValueMismatches(qualifier); - - CelValue value_match = CelValue::CreateUint64(1); - - EXPECT_TRUE(qualifier.IsMatch(value_match)); -} - -TEST(CelAttributeQualifierPatternTest, TestCelValueStringMatch) { - std::string kTest = "test"; - auto qualifier = - CelAttributeQualifierPattern::Create(CelValue::CreateString(&kTest)); - - TestAllCelValueMismatches(qualifier); - - CelValue value_match = CelValue::CreateString(&kTest); - - EXPECT_TRUE(qualifier.IsMatch(value_match)); + CreateCelAttributeQualifier(CelValue::CreateString(&test)))); } TEST(CelAttributeQualifierPatternTest, TestQualifierBoolMatch) { auto qualifier = - CelAttributeQualifierPattern::Create(CelValue::CreateBool(true)); + CreateCelAttributeQualifierPattern(CelValue::CreateBool(true)); TestAllQualifierMismatches(qualifier); EXPECT_TRUE(qualifier.IsMatch( - CelAttributeQualifier::Create(CelValue::CreateBool(true)))); + CreateCelAttributeQualifier(CelValue::CreateBool(true)))); } TEST(CelAttributeQualifierPatternTest, TestQualifierInt64Match) { - auto qualifier = - CelAttributeQualifierPattern::Create(CelValue::CreateInt64(1)); + auto qualifier = CreateCelAttributeQualifierPattern(CelValue::CreateInt64(1)); TestAllQualifierMismatches(qualifier); - EXPECT_TRUE(qualifier.IsMatch( - CelAttributeQualifier::Create(CelValue::CreateInt64(1)))); + EXPECT_TRUE( + qualifier.IsMatch(CreateCelAttributeQualifier(CelValue::CreateInt64(1)))); } TEST(CelAttributeQualifierPatternTest, TestQualifierUint64Match) { auto qualifier = - CelAttributeQualifierPattern::Create(CelValue::CreateUint64(1)); + CreateCelAttributeQualifierPattern(CelValue::CreateUint64(1)); TestAllQualifierMismatches(qualifier); EXPECT_TRUE(qualifier.IsMatch( - CelAttributeQualifier::Create(CelValue::CreateUint64(1)))); + CreateCelAttributeQualifier(CelValue::CreateUint64(1)))); } TEST(CelAttributeQualifierPatternTest, TestQualifierStringMatch) { const std::string test = "test"; auto qualifier = - CelAttributeQualifierPattern::Create(CelValue::CreateString(&test)); + CreateCelAttributeQualifierPattern(CelValue::CreateString(&test)); TestAllQualifierMismatches(qualifier); EXPECT_TRUE(qualifier.IsMatch( - CelAttributeQualifier::Create(CelValue::CreateString(&test)))); + CreateCelAttributeQualifier(CelValue::CreateString(&test)))); } TEST(CelAttributeQualifierPatternTest, TestQualifierWildcardMatch) { auto qualifier = CelAttributeQualifierPattern::CreateWildcard(); EXPECT_TRUE(qualifier.IsMatch( - CelAttributeQualifier::Create(CelValue::CreateBool(false)))); + CreateCelAttributeQualifier(CelValue::CreateBool(false)))); EXPECT_TRUE(qualifier.IsMatch( - CelAttributeQualifier::Create(CelValue::CreateBool(true)))); - EXPECT_TRUE(qualifier.IsMatch( - CelAttributeQualifier::Create(CelValue::CreateInt64(0)))); - EXPECT_TRUE(qualifier.IsMatch( - CelAttributeQualifier::Create(CelValue::CreateInt64(1)))); + CreateCelAttributeQualifier(CelValue::CreateBool(true)))); + EXPECT_TRUE( + qualifier.IsMatch(CreateCelAttributeQualifier(CelValue::CreateInt64(0)))); + EXPECT_TRUE( + qualifier.IsMatch(CreateCelAttributeQualifier(CelValue::CreateInt64(1)))); EXPECT_TRUE(qualifier.IsMatch( - CelAttributeQualifier::Create(CelValue::CreateUint64(0)))); + CreateCelAttributeQualifier(CelValue::CreateUint64(0)))); EXPECT_TRUE(qualifier.IsMatch( - CelAttributeQualifier::Create(CelValue::CreateUint64(1)))); + CreateCelAttributeQualifier(CelValue::CreateUint64(1)))); const std::string kTest1 = "test1"; const std::string kTest2 = "test2"; EXPECT_TRUE(qualifier.IsMatch( - CelAttributeQualifier::Create(CelValue::CreateString(&kTest1)))); + CreateCelAttributeQualifier(CelValue::CreateString(&kTest1)))); EXPECT_TRUE(qualifier.IsMatch( - CelAttributeQualifier::Create(CelValue::CreateString(&kTest2)))); + CreateCelAttributeQualifier(CelValue::CreateString(&kTest2)))); } TEST(CreateCelAttributePattern, Basic) { @@ -287,11 +215,6 @@ TEST(CreateCelAttributePattern, Basic) { EXPECT_THAT(pattern.variable(), Eq("abc")); ASSERT_THAT(pattern.qualifier_path(), SizeIs(5)); - EXPECT_TRUE( - pattern.qualifier_path()[0].IsMatch(CelValue::CreateStringView(kTest))); - EXPECT_TRUE(pattern.qualifier_path()[1].IsMatch(CelValue::CreateUint64(1))); - EXPECT_TRUE(pattern.qualifier_path()[2].IsMatch(CelValue::CreateInt64(-1))); - EXPECT_TRUE(pattern.qualifier_path()[3].IsMatch(CelValue::CreateBool(false))); EXPECT_TRUE(pattern.qualifier_path()[4].IsWildcard()); } @@ -314,9 +237,82 @@ TEST(CreateCelAttributePattern, Wildcards) { EXPECT_TRUE(pattern.qualifier_path()[2].IsWildcard()); } -} // namespace +TEST(CelAttribute, AsStringBasic) { + CelAttribute attr( + "var", + { + CreateCelAttributeQualifier(CelValue::CreateStringView("qual1")), + CreateCelAttributeQualifier(CelValue::CreateStringView("qual2")), + CreateCelAttributeQualifier(CelValue::CreateStringView("qual3")), + }); + + ASSERT_OK_AND_ASSIGN(std::string string_format, attr.AsString()); + + EXPECT_EQ(string_format, "var.qual1.qual2.qual3"); +} -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +TEST(CelAttribute, AsStringInvalidRoot) { + CelAttribute attr( + "", { + CreateCelAttributeQualifier(CelValue::CreateStringView("qual1")), + CreateCelAttributeQualifier(CelValue::CreateStringView("qual2")), + CreateCelAttributeQualifier(CelValue::CreateStringView("qual3")), + }); + + EXPECT_EQ(attr.AsString().status().code(), + absl::StatusCode::kInvalidArgument); +} + +TEST(CelAttribute, InvalidQualifiers) { + Expr expr; + expr.mutable_ident_expr()->set_name("var"); + google::protobuf::Arena arena; + + CelAttribute attr1("var", { + CreateCelAttributeQualifier( + CelValue::CreateDuration(absl::Minutes(2))), + }); + CelAttribute attr2("var", + { + CreateCelAttributeQualifier( + CelProtoWrapper::CreateMessage(&expr, &arena)), + }); + CelAttribute attr3( + "var", { + CreateCelAttributeQualifier(CelValue::CreateBool(false)), + }); + + // Implementation detail: Messages as attribute qualifiers are unsupported, + // so the implementation treats them inequal to any other. This is included + // for coverage. + EXPECT_FALSE(attr1 == attr2); + EXPECT_FALSE(attr2 == attr1); + EXPECT_FALSE(attr2 == attr2); + EXPECT_FALSE(attr1 == attr3); + EXPECT_FALSE(attr3 == attr1); + EXPECT_FALSE(attr2 == attr3); + EXPECT_FALSE(attr3 == attr2); + + // If the attribute includes an unsupported qualifier, return invalid argument + // error. + EXPECT_THAT(attr1.AsString(), StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(attr2.AsString(), StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(CelAttribute, AsStringQualiferTypes) { + CelAttribute attr( + "var", + { + CreateCelAttributeQualifier(CelValue::CreateStringView("qual1")), + CreateCelAttributeQualifier(CelValue::CreateUint64(1)), + CreateCelAttributeQualifier(CelValue::CreateInt64(-1)), + CreateCelAttributeQualifier(CelValue::CreateBool(false)), + }); + + ASSERT_OK_AND_ASSIGN(std::string string_format, attr.AsString()); + + EXPECT_EQ(string_format, "var.qual1[1][-1][false]"); +} + +} // namespace +} // namespace google::api::expr::runtime diff --git a/eval/public/cel_builtins.h b/eval/public/cel_builtins.h index 259c801f7..f03e02f8c 100644 --- a/eval/public/cel_builtins.h +++ b/eval/public/cel_builtins.h @@ -1,80 +1,15 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_BUILTINS_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_BUILTINS_H_ +#include "base/builtins.h" + namespace google { namespace api { namespace expr { namespace runtime { -// Constants specifying names for CEL builtins. -namespace builtin { - -// Comparison -constexpr char kEqual[] = "_==_"; -constexpr char kInequal[] = "_!=_"; -constexpr char kLess[] = "_<_"; -constexpr char kLessOrEqual[] = "_<=_"; -constexpr char kGreater[] = "_>_"; -constexpr char kGreaterOrEqual[] = "_>=_"; - -// Logical -constexpr char kAnd[] = "_&&_"; -constexpr char kOr[] = "_||_"; -constexpr char kNot[] = "!_"; - -// Strictness -constexpr char kNotStrictlyFalse[] = "@not_strictly_false"; -// Deprecated '__not_strictly_false__' function. Preserved for backwards -// compatibility with stored expressions. -constexpr char kNotStrictlyFalseDeprecated[] = "__not_strictly_false__"; - -// Arithmetical -constexpr char kAdd[] = "_+_"; -constexpr char kSubtract[] = "_-_"; -constexpr char kNeg[] = "-_"; -constexpr char kMultiply[] = "_*_"; -constexpr char kDivide[] = "_/_"; -constexpr char kModulo[] = "_%_"; - -// String operations -constexpr char kRegexMatch[] = "matches"; -constexpr char kStringContains[] = "contains"; -constexpr char kStringEndsWith[] = "endsWith"; -constexpr char kStringStartsWith[] = "startsWith"; - -// Container operations -constexpr char kIn[] = "@in"; -// Deprecated '_in_' operator. Preserved for backwards compatibility with stored -// expressions. -constexpr char kInDeprecated[] = "_in_"; -// Deprecated 'in()' function. Preserved for backwards compatibility with stored -// expressions. -constexpr char kInFunction[] = "in"; -constexpr char kIndex[] = "_[_]"; -constexpr char kSize[] = "size"; - -constexpr char kTernary[] = "_?_:_"; - -// Timestamp and Duration -constexpr char kDuration[] = "duration"; -constexpr char kTimestamp[] = "timestamp"; -constexpr char kFullYear[] = "getFullYear"; -constexpr char kMonth[] = "getMonth"; -constexpr char kDayOfYear[] = "getDayOfYear"; -constexpr char kDayOfMonth[] = "getDayOfMonth"; -constexpr char kDate[] = "getDate"; -constexpr char kDayOfWeek[] = "getDayOfWeek"; -constexpr char kHours[] = "getHours"; -constexpr char kMinutes[] = "getMinutes"; -constexpr char kSeconds[] = "getSeconds"; -constexpr char kMilliseconds[] = "getMilliseconds"; - -// Type conversions -// TODO(issues/23): Add other type conversion methods. -constexpr char kInt[] = "int"; -constexpr char kString[] = "string"; - -} // namespace builtin +// Alias new namespace until external CEL users can be updated. +namespace builtin = cel::builtin; } // namespace runtime } // namespace expr diff --git a/eval/public/cel_expr_builder_factory.cc b/eval/public/cel_expr_builder_factory.cc index 5e23df1ca..a56c450b0 100644 --- a/eval/public/cel_expr_builder_factory.cc +++ b/eval/public/cel_expr_builder_factory.cc @@ -1,43 +1,146 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + #include "eval/public/cel_expr_builder_factory.h" +#include +#include + +#include "absl/base/nullability.h" +#include "absl/log/absl_log.h" +#include "absl/status/status.h" +#include "common/kind.h" +#include "common/memory.h" +#include "eval/compiler/cel_expression_builder_flat_impl.h" +#include "eval/compiler/comprehension_vulnerability_check.h" +#include "eval/compiler/constant_folding.h" #include "eval/compiler/flat_expr_builder.h" +#include "eval/compiler/qualified_reference_resolver.h" +#include "eval/compiler/regex_precompilation_optimization.h" +#include "eval/public/cel_expression.h" +#include "eval/public/cel_function.h" #include "eval/public/cel_options.h" +#include "extensions/select_optimization.h" +#include "internal/noop_delete.h" +#include "runtime/internal/runtime_env.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace google::api::expr::runtime { + +namespace { + +using ::cel::MemoryManagerRef; +using ::cel::extensions::CreateSelectOptimizationProgramOptimizer; +using ::cel::extensions::kCelAttribute; +using ::cel::extensions::kCelHasField; +using ::cel::extensions::SelectOptimizationAstUpdater; +using ::cel::runtime_internal::CreateConstantFoldingOptimizer; +using ::cel::runtime_internal::RuntimeEnv; -namespace google { -namespace api { -namespace expr { -namespace runtime { +} // namespace std::unique_ptr CreateCelExpressionBuilder( + const google::protobuf::DescriptorPool* descriptor_pool, + google::protobuf::MessageFactory* message_factory, const InterpreterOptions& options) { - auto builder = absl::make_unique(); - builder->set_shortcircuiting(options.short_circuiting); - builder->set_constant_folding(options.constant_folding, - options.constant_arena); - builder->set_enable_comprehension(options.enable_comprehension); - builder->set_comprehension_max_iterations( - options.comprehension_max_iterations); - builder->set_fail_on_warnings(options.fail_on_warnings); - - switch (options.unknown_processing) { - case UnknownProcessingOptions::kAttributeAndFunction: - builder->set_enable_unknown_function_results(true); - builder->set_enable_unknowns(true); - break; - case UnknownProcessingOptions::kAttributeOnly: - builder->set_enable_unknowns(true); - break; - case UnknownProcessingOptions::kDisabled: - break; + if (descriptor_pool == nullptr) { + ABSL_LOG(ERROR) << "Cannot pass nullptr as descriptor pool to " + "CreateCelExpressionBuilder"; + return nullptr; } - builder->set_enable_missing_attribute_errors( - options.enable_missing_attribute_errors); + cel::RuntimeOptions runtime_options = ConvertToRuntimeOptions(options); + absl_nullable std::shared_ptr shared_message_factory; + if (message_factory != nullptr) { + shared_message_factory = std::shared_ptr( + message_factory, + cel::internal::NoopDeleteFor()); + } + auto env = std::make_shared( + std::shared_ptr( + descriptor_pool, + cel::internal::NoopDeleteFor()), + shared_message_factory); + if (auto status = env->Initialize(); !status.ok()) { + ABSL_LOG(ERROR) << "Failed to validate standard message types: " + << status.ToString(); // NOLINT: OSS compatibility + return nullptr; + } + auto builder = std::make_unique( + std::move(env), runtime_options); + + FlatExprBuilder& flat_expr_builder = builder->flat_expr_builder(); + + flat_expr_builder.AddAstTransform(NewReferenceResolverExtension( + (options.enable_qualified_identifier_rewrites) + ? ReferenceResolverOption::kAlways + : ReferenceResolverOption::kCheckedOnly)); + + if (options.enable_comprehension_vulnerability_check) { + builder->flat_expr_builder().AddProgramOptimizer( + CreateComprehensionVulnerabilityCheck()); + } + + if (options.constant_folding) { + std::shared_ptr shared_arena; + if (options.constant_arena != nullptr) { + shared_arena = std::shared_ptr( + options.constant_arena, + cel::internal::NoopDeleteFor()); + } + builder->flat_expr_builder().AddProgramOptimizer( + CreateConstantFoldingOptimizer(std::move(shared_arena), + std::move(shared_message_factory))); + } + + if (options.enable_regex_precompilation) { + flat_expr_builder.AddProgramOptimizer( + CreateRegexPrecompilationExtension(options.regex_max_program_size)); + } + + if (options.enable_select_optimization) { + // Add AST transform to update select branches on a stored + // CheckedExpression. This may already be performed by a type checker. + flat_expr_builder.AddAstTransform( + std::make_unique()); + // Add overloads for select optimization signature. + // These are never bound, only used to prevent the builder from failing on + // the overloads check. + absl::Status status = + builder->GetRegistry()->RegisterLazyFunction(CelFunctionDescriptor( + kCelAttribute, false, {cel::Kind::kAny, cel::Kind::kList})); + if (!status.ok()) { + ABSL_LOG(ERROR) << "Failed to register " << kCelAttribute << ": " + << status; + } + status = builder->GetRegistry()->RegisterLazyFunction(CelFunctionDescriptor( + kCelHasField, false, {cel::Kind::kAny, cel::Kind::kList})); + if (!status.ok()) { + ABSL_LOG(ERROR) << "Failed to register " << kCelHasField << ": " + << status; + } + // Add runtime implementation. + flat_expr_builder.AddProgramOptimizer( + CreateSelectOptimizationProgramOptimizer()); + } - return std::move(builder); + return builder; } -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime diff --git a/eval/public/cel_expr_builder_factory.h b/eval/public/cel_expr_builder_factory.h index f3f08d991..61450069f 100644 --- a/eval/public/cel_expr_builder_factory.h +++ b/eval/public/cel_expr_builder_factory.h @@ -1,8 +1,13 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_EXPR_BUILDER_FACTORY_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_EXPR_BUILDER_FACTORY_H_ +#include + +#include "absl/base/attributes.h" #include "eval/public/cel_expression.h" #include "eval/public/cel_options.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" namespace google { namespace api { @@ -11,8 +16,25 @@ namespace runtime { // Factory creates CelExpressionBuilder implementation for public use. std::unique_ptr CreateCelExpressionBuilder( + const google::protobuf::DescriptorPool* descriptor_pool, + google::protobuf::MessageFactory* message_factory, const InterpreterOptions& options = InterpreterOptions()); +ABSL_DEPRECATED( + "This overload uses the generated descriptor pool, which allows " + "expressions to create any messages linked into the binary. This is not " + "hermetic and potentially dangerous, you should select the descriptor pool " + "carefully. Use the other overload and explicitly pass your descriptor " + "pool. It can still be the generated descriptor pool, but the choice " + "should be explicit. If you do not need struct creation, use " + "`cel::GetMinimalDescriptorPool()`.") +inline std::unique_ptr CreateCelExpressionBuilder( + const InterpreterOptions& options = InterpreterOptions()) { + return CreateCelExpressionBuilder(google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), + options); +} + } // namespace runtime } // namespace expr } // namespace api diff --git a/eval/public/cel_expression.h b/eval/public/cel_expression.h index cf189b281..4cf029e89 100644 --- a/eval/public/cel_expression.h +++ b/eval/public/cel_expression.h @@ -1,22 +1,24 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_EXPRESSION_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_EXPRESSION_H_ +#include #include +#include +#include -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/checked.pb.h" +#include "cel/expr/syntax.pb.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" -#include "eval/public/activation.h" -#include "eval/public/cel_function.h" +#include "eval/public/base_activation.h" #include "eval/public/cel_function_registry.h" +#include "eval/public/cel_type_registry.h" #include "eval/public/cel_value.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace google::api::expr::runtime { // CelEvaluationListener is the callback that is passed to (and called by) -// CelEvaluation::Trace. It gets an expression node ID from the original +// CelExpression::Trace. It gets an expression node ID from the original // expression, its value and the arena object. If an expression node // is evaluated multiple times (e.g. as a part of Comprehension.loop_step) // then the order of the callback invocations is guaranteed to correspond @@ -25,7 +27,7 @@ namespace runtime { using CelEvaluationListener = std::function; -// An opaque state used for evaluation of a cell expression. +// An opaque state used for evaluation of a CEL expression. class CelEvaluationState { public: virtual ~CelEvaluationState() = default; @@ -73,59 +75,70 @@ class CelExpression { // it built. class CelExpressionBuilder { public: - CelExpressionBuilder() - : registry_(absl::make_unique()) {} + CelExpressionBuilder() = default; - virtual ~CelExpressionBuilder() {} + virtual ~CelExpressionBuilder() = default; // Creates CelExpression object from AST tree. - // expr specifies root of AST tree + // expr specifies root of AST tree. + // Method implementation is expected to create copies of expr and source_info, + // so that the returned CelExpression is not dependent on the lifetime of + // the input arguments. virtual absl::StatusOr> CreateExpression( - const google::api::expr::v1alpha1::Expr* expr, - const google::api::expr::v1alpha1::SourceInfo* source_info) const = 0; + const cel::expr::Expr* expr, + const cel::expr::SourceInfo* source_info) const = 0; // Creates CelExpression object from AST tree. // expr specifies root of AST tree. // non-fatal build warnings are written to warnings if encountered. + // Method implementation is expected to create copies of expr and source_info, + // so that the returned CelExpression is not dependent on the lifetime of + // the input arguments. virtual absl::StatusOr> CreateExpression( - const google::api::expr::v1alpha1::Expr* expr, - const google::api::expr::v1alpha1::SourceInfo* source_info, + const cel::expr::Expr* expr, + const cel::expr::SourceInfo* source_info, std::vector* warnings) const = 0; - // CelFunction registry. Extension function should be registered with it - // prior to expression creation. - CelFunctionRegistry* GetRegistry() const { return registry_.get(); } - - // Enums registered with the builder. - const std::set& resolvable_enums() const { - return resolvable_enums_; + // Creates CelExpression object from a checked expression. + // This includes an AST, source info, type hints and ident hints. + // Method implementation is expected to create copy of checked_expr, + // so that the returned CelExpression is not dependent on the lifetime of + // the input arguments. + virtual absl::StatusOr> CreateExpression( + const cel::expr::CheckedExpr* checked_expr) const { + // Default implementation just passes through the expr and source info. + return CreateExpression(&checked_expr->expr(), + &checked_expr->source_info()); } - // Add Enum to the list of resolvable by the builder. - void AddResolvableEnum(const google::protobuf::EnumDescriptor* enum_descriptor) { - resolvable_enums_.emplace(enum_descriptor); + // Creates CelExpression object from a checked expression. + // This includes an AST, source info, type hints and ident hints. + // non-fatal build warnings are written to warnings if encountered. + // Method implementation is expected to create copy of checked_expr, + // so that the returned CelExpression is not dependent on the lifetime of + // the input arguments. + virtual absl::StatusOr> CreateExpression( + const cel::expr::CheckedExpr* checked_expr, + std::vector* warnings) const { + // Default implementation just passes through the expr and source_info. + return CreateExpression(&checked_expr->expr(), &checked_expr->source_info(), + warnings); } - // Remove Enum from the list of resolvable by the builder. - void RemoveResolvableEnum(const google::protobuf::EnumDescriptor* enum_descriptor) { - resolvable_enums_.erase(enum_descriptor); - } + // CelFunction registry. Extension function should be registered with it + // prior to expression creation. + virtual CelFunctionRegistry* GetRegistry() const = 0; - void set_container(std::string container) { - container_ = std::move(container); - } + // CEL Type registry. Provides a means to resolve the CEL built-in types to + // CelValue instances, and to extend the set of types and enums known to + // expressions by registering them ahead of time. + virtual CelTypeRegistry* GetTypeRegistry() const = 0; - absl::string_view container() const { return container_; } + virtual void set_container(std::string container) = 0; - private: - std::unique_ptr registry_; - std::set resolvable_enums_; - std::string container_; + virtual absl::string_view container() const = 0; }; -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_EXPRESSION_H_ diff --git a/eval/public/cel_function.cc b/eval/public/cel_function.cc index d97fd8ce5..9b760d1ec 100644 --- a/eval/public/cel_function.cc +++ b/eval/public/cel_function.cc @@ -1,42 +1,48 @@ #include "eval/public/cel_function.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { - -bool CelFunctionDescriptor::ShapeMatches( - bool receiver_style, const std::vector& types) const { - if (receiver_style_ != receiver_style) { - return false; - } +#include +#include + +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "common/value.h" +#include "eval/internal/interop.h" +#include "eval/public/cel_value.h" +#include "internal/status_macros.h" +#include "runtime/function.h" + +namespace google::api::expr::runtime { - if (types_.size() != types.size()) { +using ::cel::Value; +using ::cel::interop_internal::ToLegacyValue; + +bool CelFunction::MatchArguments(absl::Span arguments) const { + auto types_size = descriptor().types().size(); + + if (types_size != arguments.size()) { return false; } - - for (size_t i = 0; i < types_.size(); i++) { - CelValue::Type this_type = types_[i]; - CelValue::Type other_type = types[i]; - if (this_type != CelValue::Type::kAny && - other_type != CelValue::Type::kAny && this_type != other_type) { + for (size_t i = 0; i < types_size; i++) { + const auto& value = arguments[i]; + CelValue::Type arg_type = descriptor().types()[i]; + if (value.type() != arg_type && arg_type != CelValue::Type::kAny) { return false; } } + return true; } -bool CelFunction::MatchArguments(absl::Span arguments) const { +bool CelFunction::MatchArguments(absl::Span arguments) const { auto types_size = descriptor().types().size(); if (types_size != arguments.size()) { return false; } - for (size_t i = 0; i < types_size; i++) { const auto& value = arguments[i]; CelValue::Type arg_type = descriptor().types()[i]; - if (value.type() != arg_type && arg_type != CelValue::Type::kAny) { + if (value->kind() != arg_type && arg_type != CelValue::Type::kAny) { return false; } } @@ -44,7 +50,27 @@ bool CelFunction::MatchArguments(absl::Span arguments) const { return true; } -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +absl::StatusOr CelFunction::Invoke( + absl::Span arguments, + const cel::Function::InvokeContext& context) const { + std::vector legacy_args; + legacy_args.reserve(arguments.size()); + + // Users shouldn't be able to create expressions that call registered + // functions with unconvertible types, but it's possible to create an AST that + // can trigger this by making an unexpected call on a value that the + // interpreter expects to only be used with internal program steps. + for (const auto& arg : arguments) { + CEL_ASSIGN_OR_RETURN(legacy_args.emplace_back(), + ToLegacyValue(context.arena(), arg, true)); + } + + CelValue legacy_result; + + CEL_RETURN_IF_ERROR(Evaluate(legacy_args, &legacy_result, context.arena())); + + return cel::interop_internal::LegacyValueToModernValueOrDie( + context.arena(), legacy_result, /*unchecked=*/true); +} + +} // namespace google::api::expr::runtime diff --git a/eval/public/cel_function.h b/eval/public/cel_function.h index 7e1dc0275..6c9ff2e7a 100644 --- a/eval/public/cel_function.h +++ b/eval/public/cel_function.h @@ -1,47 +1,22 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_FUNCTION_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_FUNCTION_H_ -#include +#include +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/types/span.h" +#include "common/function_descriptor.h" +#include "common/value.h" #include "eval/public/cel_value.h" +#include "runtime/function.h" +#include "google/protobuf/arena.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace google::api::expr::runtime { // Type that describes CelFunction. // This complex structure is needed for overloads support. -class CelFunctionDescriptor { - public: - CelFunctionDescriptor(const std::string& name, bool receiver_style, - const std::vector types) - : name_(name), receiver_style_(receiver_style), types_(types) {} - - // Function name. - const std::string& name() const { return name_; } - - // Whether function is receiver style i.e. true means arg0.name(args[1:]...). - bool receiver_style() const { return receiver_style_; } - - // The argmument types the function accepts. - const std::vector& types() const { return types_; } - - // Helper for matching a descriptor. This tests that the shape is the same -- - // |other| accepts the same number and types of arguments and is the same call - // style). - bool ShapeMatches(const CelFunctionDescriptor& other) const { - return ShapeMatches(other.receiver_style(), other.types()); - } - bool ShapeMatches(bool receiver_style, - const std::vector& types) const; - - private: - std::string name_; - bool receiver_style_; - std::vector types_; -}; +using CelFunctionDescriptor = ::cel::FunctionDescriptor; // CelFunction is a handler that represents single // CEL function. @@ -53,17 +28,17 @@ class CelFunctionDescriptor { // - amount of arguments and their types. // Function overloads are resolved based on their arguments and // receiver style. -class CelFunction { +class CelFunction : public ::cel::Function { public: // Build CelFunction from descriptor - explicit CelFunction(const CelFunctionDescriptor& descriptor) - : descriptor_(descriptor) {} + explicit CelFunction(CelFunctionDescriptor descriptor) + : descriptor_(std::move(descriptor)) {} // Non-copyable CelFunction(const CelFunction& other) = delete; CelFunction& operator=(const CelFunction& other) = delete; - virtual ~CelFunction() {} + ~CelFunction() override = default; // Evaluates CelValue based on arguments supplied. // If result content is to be allocated (e.g. string concatenation), @@ -84,6 +59,15 @@ class CelFunction { // Method is called during runtime. bool MatchArguments(absl::Span arguments) const; + bool MatchArguments(absl::Span arguments) const; + + // Implements cel::Function. + using cel::Function::Invoke; + + absl::StatusOr Invoke( + absl::Span arguments, + const cel::Function::InvokeContext& context) const final; + // CelFunction descriptor const CelFunctionDescriptor& descriptor() const { return descriptor_; } @@ -91,9 +75,6 @@ class CelFunction { CelFunctionDescriptor descriptor_; }; -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_FUNCTION_H_ diff --git a/eval/public/cel_function_adapter.cc b/eval/public/cel_function_adapter.cc deleted file mode 100644 index ee82673c8..000000000 --- a/eval/public/cel_function_adapter.cc +++ /dev/null @@ -1,21 +0,0 @@ -#include "eval/public/cel_function_adapter.h" - -namespace google { -namespace api { -namespace expr { -namespace runtime { - -namespace internal { - -template <> -absl::optional TypeCodeMatch() { - return CelValue::Type::kAny; -} - - -} // namespace internal - -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google diff --git a/eval/public/cel_function_adapter.h b/eval/public/cel_function_adapter.h index 74ab5848b..01b07045d 100644 --- a/eval/public/cel_function_adapter.h +++ b/eval/public/cel_function_adapter.h @@ -1,302 +1,121 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_FUNCTION_ADAPTER_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_FUNCTION_ADAPTER_H_ + +#include #include +#include +#include +#include #include "absl/status/status.h" -#include "absl/strings/str_cat.h" -#include "eval/public/cel_function.h" -#include "eval/public/cel_function_registry.h" +#include "eval/public/cel_function_adapter_impl.h" +#include "eval/public/cel_value.h" #include "eval/public/structs/cel_proto_wrapper.h" -#include "absl/status/statusor.h" +#include "google/protobuf/message.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace google::api::expr::runtime { namespace internal { -// TypeCodeMatch template function family -// Used for CEL type deduction based on C++ native -// type. -template -absl::optional TypeCodeMatch() { - int index = CelValue::IndexOf::value; - if (index < 0) return {}; - CelValue::Type arg_type = static_cast(index); - if (arg_type >= CelValue::Type::kAny) { - return {}; +// A type code matcher that adds support for google::protobuf::Message. +struct ProtoAdapterTypeCodeMatcher { + template + constexpr static std::optional type_code() { + if constexpr (std::is_same_v) { + return CelValue::Type::kMessage; + } else { + return internal::TypeCodeMatcher().type_code(); + } } - return arg_type; -} - -// A bit of a trick - to pass Any kind of value, we use generic -// CelValue parameters. -template <> -absl::optional TypeCodeMatch(); +}; -template -bool AddType(std::vector*) { - return true; -} +// A value converter that handles wrapping google::protobuf::Messages as CelValues. +struct ProtoAdapterValueConverter + : public internal::ValueConverterBase { + using BaseType = internal::ValueConverterBase; + using BaseType::NativeToValue; + using BaseType::ValueToNative; -// AddType template method -// Appends CEL type constant deduced from C++ type Type to descriptor -template -bool AddType(std::vector* arg_types) { - auto kind = TypeCodeMatch(); - if (!kind) { - return false; + absl::Status NativeToValue(const ::google::protobuf::Message* value, + ::google::protobuf::Arena* arena, CelValue* result) { + if (value == nullptr) { + return absl::Status(absl::StatusCode::kInvalidArgument, + "Null Message pointer returned"); + } + *result = CelProtoWrapper::CreateMessage(value, arena); + return absl::OkStatus(); } - - arg_types->push_back(kind.value()); - - return AddType(arg_types); - - return true; -} - +}; } // namespace internal // FunctionAdapter is a helper class that simplifies creation of CelFunction // implementations. -// It accepts method implementations as std::function, allowing -// them to be lambdas/regular C++ functions. CEL method descriptors are -// deduced based on C++ function signatures. +// +// The static Create member function accepts CelFunction::Evalaute method +// implementations as std::function, allowing them to be lambdas/regular C++ +// functions. CEL method descriptors ddeduced based on C++ function signatures. +// +// The adapted CelFunction::Evaluate implementation will set result to the +// value returned by the handler. To handle errors, choose CelValue as the +// return type, and use the CreateError/Create* helpers in cel_value.h. +// +// The wrapped std::function may return absl::StatusOr. If the wrapped +// function returns the absl::Status variant, the generated CelFunction +// implementation will return a non-ok status code, rather than a CelError +// wrapping an absl::Status value. A returned non-ok status indicates a hard +// error, meaning the interpreter cannot reasonably continue evaluation (e.g. +// data corruption or broken invariant). To create a CelError that follows +// logical pruning rules, the extension function implementation should return a +// CelError or an error-typed CelValue. +// +// FunctionAdapter +// ReturnType: the C++ return type of the function implementation +// Arguments: the C++ Argument type of the function implementation +// +// Static Methods: +// +// Create(absl::string_view function_name, bool receiver_style, +// FunctionType func) -> absl::StatusOr> +// +// Usage example: +// +// auto func = [](::google::protobuf::Arena* arena, int64_t i, int64_t j) -> bool { +// return i < j; +// }; +// +// CEL_ASSIGN_OR_RETURN(auto cel_func, +// FunctionAdapter::Create("<", false, func)); +// +// CreateAndRegister(absl::string_view function_name, bool receiver_style, +// FunctionType func, CelFunctionRegisry registry) +// -> absl::Status // // Usage example: // -// auto func = [](google::protobuf::google::protobuf::Arena* arena, int64_t i, int64_t j) -> bool { +// auto func = [](::google::protobuf::Arena* arena, int64_t i, int64_t j) -> bool { // return i < j; // }; // -// auto func_status = -// FunctionAdapter::Create("<", false, func); +// CEL_RETURN_IF_ERROR(( +// FunctionAdapter::CreateAndRegister("<", false, +// func, cel_expression_builder->GetRegistry())); // -// if(func_status.ok()) { -// auto func = func_status.value(); -// } template -class FunctionAdapter : public CelFunction { - public: - using FuncType = std::function; - - FunctionAdapter(const CelFunctionDescriptor& descriptor, FuncType handler) - : CelFunction(descriptor), handler_(std::move(handler)) {} - - static absl::StatusOr> Create( - absl::string_view name, bool receiver_type, - std::function handler) { - std::vector arg_types; - - if (!internal::AddType<0, Arguments...>(&arg_types)) { - return absl::Status( - absl::StatusCode::kInternal, - absl::StrCat("Failed to create adapter for ", name, - ": failed to determine input parameter type")); - } - - std::unique_ptr cel_func = absl::make_unique( - CelFunctionDescriptor(std::string(name), receiver_type, arg_types), - std::move(handler)); - return std::move(cel_func); - } - - // Creates function handler and attempts to register it with - // supplied function registry. - static absl::Status CreateAndRegister( - absl::string_view name, bool receiver_type, - std::function handler, - CelFunctionRegistry* registry) { - auto status = Create(name, receiver_type, std::move(handler)); - if (!status.ok()) { - return status.status(); - } - - return registry->Register(std::move(status.value())); - } - -#if defined(__clang_major_version__) && __clang_major_version__ >= 8 && !defined(__APPLE__) - template - inline absl::Status RunWrap(absl::Span arguments, - std::tuple<::google::protobuf::Arena*, Arguments...> input, - CelValue* result, ::google::protobuf::Arena* arena) const { - if (!ConvertFromValue(arguments[arg_index], - &std::get(input))) { - return absl::Status(absl::StatusCode::kInvalidArgument, - "Type conversion failed"); - } - return RunWrap(arguments, input, result, arena); - } - - template <> - inline absl::Status RunWrap( - absl::Span, - std::tuple<::google::protobuf::Arena*, Arguments...> input, CelValue* result, - ::google::protobuf::Arena* arena) const { - return CreateReturnValue(absl::apply(handler_, input), arena, result); - } -#else - inline absl::Status RunWrap(std::function func, - const absl::Span argset, - ::google::protobuf::Arena* arena, CelValue* result, - int arg_index) const { - return CreateReturnValue(func(), arena, result); - } - - template - inline absl::Status RunWrap(std::function func, - const absl::Span argset, - ::google::protobuf::Arena* arena, CelValue* result, - int arg_index) const { - Arg argument; - if (!ConvertFromValue(argset[arg_index], &argument)) { - return absl::Status(absl::StatusCode::kInvalidArgument, - "Type conversion failed"); - } - - std::function wrapped_func = - [func, argument](Args... args) -> ReturnType { - return func(argument, args...); - }; - - return RunWrap(std::move(wrapped_func), argset, arena, result, - arg_index + 1); - } -#endif - - absl::Status Evaluate(absl::Span arguments, CelValue* result, - ::google::protobuf::Arena* arena) const override { - if (arguments.size() != sizeof...(Arguments)) { - return absl::Status(absl::StatusCode::kInternal, - "Argument number mismatch"); - } - -#if defined(__clang_major_version__) && __clang_major_version__ >= 8 && !defined(__APPLE__) - std::tuple<::google::protobuf::Arena*, Arguments...> input; - std::get<0>(input) = arena; - return RunWrap<0>(arguments, input, result, arena); -#else - const auto* handler = &handler_; - std::function wrapped_handler = - [handler, arena](Arguments... args) -> ReturnType { - return (*handler)(arena, args...); - }; - return RunWrap(std::move(wrapped_handler), arguments, arena, result, 0); -#endif - } - - private: - template - static bool ConvertFromValue(CelValue value, ArgType* result) { - return value.GetValue(result); - } - - // Special conversion - from CelValue to CelValue - plain copy - static bool ConvertFromValue(CelValue value, CelValue* result) { - *result = std::move(value); - return true; - } - - // CreateReturnValue method wraps evaluation result with CelValue. - static absl::Status CreateReturnValue(bool value, ::google::protobuf::Arena*, - CelValue* result) { - *result = CelValue::CreateBool(value); - return absl::OkStatus(); - } - - static absl::Status CreateReturnValue(int64_t value, ::google::protobuf::Arena*, - CelValue* result) { - *result = CelValue::CreateInt64(value); - return absl::OkStatus(); - } - - static absl::Status CreateReturnValue(uint64_t value, ::google::protobuf::Arena*, - CelValue* result) { - *result = CelValue::CreateUint64(value); - return absl::OkStatus(); - } - - static absl::Status CreateReturnValue(double value, ::google::protobuf::Arena*, - CelValue* result) { - *result = CelValue::CreateDouble(value); - return absl::OkStatus(); - } - - static absl::Status CreateReturnValue(CelValue::StringHolder value, - ::google::protobuf::Arena*, CelValue* result) { - *result = CelValue::CreateString(value); - return absl::OkStatus(); - } - - static absl::Status CreateReturnValue(CelValue::BytesHolder value, - ::google::protobuf::Arena*, CelValue* result) { - *result = CelValue::CreateBytes(value); - return absl::OkStatus(); - } - - static absl::Status CreateReturnValue(const ::google::protobuf::Message* value, - ::google::protobuf::Arena* arena, - CelValue* result) { - if (value == nullptr) { - return absl::Status(absl::StatusCode::kInvalidArgument, - "Null Message pointer returned"); - } - *result = CelProtoWrapper::CreateMessage(value, arena); - return absl::OkStatus(); - } - - static absl::Status CreateReturnValue(const CelList* value, ::google::protobuf::Arena*, - CelValue* result) { - if (value == nullptr) { - return absl::Status(absl::StatusCode::kInvalidArgument, - "Null CelList pointer returned"); - } - *result = CelValue::CreateList(value); - return absl::OkStatus(); - } - - static absl::Status CreateReturnValue(const CelMap* value, ::google::protobuf::Arena*, - CelValue* result) { - if (value == nullptr) { - return absl::Status(absl::StatusCode::kInvalidArgument, - "Null CelMap pointer returned"); - } - *result = CelValue::CreateMap(value); - return absl::OkStatus(); - } - - static absl::Status CreateReturnValue(const CelError* value, ::google::protobuf::Arena*, - CelValue* result) { - if (value == nullptr) { - return absl::Status(absl::StatusCode::kInvalidArgument, - "Null CelError pointer returned"); - } - *result = CelValue::CreateError(value); - return absl::OkStatus(); - } - - static absl::Status CreateReturnValue(const CelValue& value, ::google::protobuf::Arena*, - CelValue* result) { - *result = value; - return absl::OkStatus(); - } - - template - static absl::Status CreateReturnValue(const absl::StatusOr& value, - ::google::protobuf::Arena*, CelValue*) { - if (!value) { - return value.status(); - } - return CreateReturnValue(value.value()); - } - - FuncType handler_; -}; - -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +using FunctionAdapter = + internal::FunctionAdapterImpl:: + FunctionAdapter; + +template +using UnaryFunctionAdapter = internal::FunctionAdapterImpl< + internal::ProtoAdapterTypeCodeMatcher, + internal::ProtoAdapterValueConverter>::UnaryFunction; + +template +using BinaryFunctionAdapter = internal::FunctionAdapterImpl< + internal::ProtoAdapterTypeCodeMatcher, + internal::ProtoAdapterValueConverter>::BinaryFunction; + +} // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_FUNCTION_ADAPTER_H_ diff --git a/eval/public/cel_function_adapter_impl.h b/eval/public/cel_function_adapter_impl.h new file mode 100644 index 000000000..6cd661c10 --- /dev/null +++ b/eval/public/cel_function_adapter_impl.h @@ -0,0 +1,407 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_FUNCTION_ADAPTER_IMPL_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_FUNCTION_ADAPTER_IMPL_H_ + +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "eval/public/cel_function.h" +#include "eval/public/cel_function_registry.h" +#include "eval/public/cel_value.h" +#include "internal/status_macros.h" + +#if defined(__clang__) || !defined(__GNUC__) +// Do not disable. +#else +#define CEL_CPP_DISABLE_PARTIAL_SPECIALIZATION 1 +#endif + +namespace google::api::expr::runtime { + +namespace internal { +// TypeCodeMatch template helper. +// Used for CEL type deduction based on C++ native type. +struct TypeCodeMatcher { + template + constexpr static std::optional type_code() { + if constexpr (std::is_same_v) { + // A bit of a trick - to pass Any kind of value, we use generic CelValue + // parameters. + return CelValue::Type::kAny; + } else { + int index = CelValue::IndexOf::value; + if (index < 0) return {}; + CelValue::Type arg_type = static_cast(index); + if (arg_type >= CelValue::Type::kAny) { + return {}; + } + return arg_type; + } + } +}; + +// Template helper to construct an argument list for a CelFunctionDescriptor. +template +struct TypeAdder { + template + bool AddType(std::vector* arg_types) const { + auto kind = TypeCodeMatcher().template type_code(); + if (!kind) { + return false; + } + + arg_types->push_back(*kind); + + return AddType(arg_types); + + return true; + } + + template + bool AddType(std::vector* arg_types) const { + return true; + } +}; + +// Template helper for C++ types to CEL conversions. +// Uses CRTP to dispatch to derived class overloads in the StatusOr helper. +template +struct ValueConverterBase { + // Value to native uwraps a CelValue to a native type. + template + bool ValueToNative(CelValue value, T* result) { + if constexpr (std::is_same_v) { + *result = std::move(value); + return true; + } else { + return value.GetValue(result); + } + } + + // Native to value wraps a native return type to a CelValue. + absl::Status NativeToValue(bool value, ::google::protobuf::Arena*, CelValue* result) { + *result = CelValue::CreateBool(value); + return absl::OkStatus(); + } + + absl::Status NativeToValue(int64_t value, ::google::protobuf::Arena*, + CelValue* result) { + *result = CelValue::CreateInt64(value); + return absl::OkStatus(); + } + + absl::Status NativeToValue(uint64_t value, ::google::protobuf::Arena*, + CelValue* result) { + *result = CelValue::CreateUint64(value); + return absl::OkStatus(); + } + + absl::Status NativeToValue(double value, ::google::protobuf::Arena*, CelValue* result) { + *result = CelValue::CreateDouble(value); + return absl::OkStatus(); + } + + absl::Status NativeToValue(CelValue::StringHolder value, ::google::protobuf::Arena*, + CelValue* result) { + *result = CelValue::CreateString(value); + return absl::OkStatus(); + } + + absl::Status NativeToValue(CelValue::BytesHolder value, ::google::protobuf::Arena*, + CelValue* result) { + *result = CelValue::CreateBytes(value); + return absl::OkStatus(); + } + + absl::Status NativeToValue(const CelList* value, ::google::protobuf::Arena*, + CelValue* result) { + if (value == nullptr) { + return absl::Status(absl::StatusCode::kInvalidArgument, + "Null CelList pointer returned"); + } + *result = CelValue::CreateList(value); + return absl::OkStatus(); + } + + absl::Status NativeToValue(const CelMap* value, ::google::protobuf::Arena*, + CelValue* result) { + if (value == nullptr) { + return absl::Status(absl::StatusCode::kInvalidArgument, + "Null CelMap pointer returned"); + } + *result = CelValue::CreateMap(value); + return absl::OkStatus(); + } + + absl::Status NativeToValue(CelValue::CelTypeHolder value, ::google::protobuf::Arena*, + CelValue* result) { + *result = CelValue::CreateCelType(value); + return absl::OkStatus(); + } + + absl::Status NativeToValue(const CelError* value, ::google::protobuf::Arena*, + CelValue* result) { + if (value == nullptr) { + return absl::Status(absl::StatusCode::kInvalidArgument, + "Null CelError pointer returned"); + } + *result = CelValue::CreateError(value); + return absl::OkStatus(); + } + + // Special case -- just forward a CelValue. + absl::Status NativeToValue(const CelValue& value, ::google::protobuf::Arena*, + CelValue* result) { + *result = value; + return absl::OkStatus(); + } + + template + absl::Status NativeToValue(absl::StatusOr value, ::google::protobuf::Arena* arena, + CelValue* result) { + CEL_ASSIGN_OR_RETURN(auto held_value, value); + return Derived().NativeToValue(held_value, arena, result); + } +}; + +struct ValueConverter : public ValueConverterBase {}; + +// Generalized implementation for function adapters. See comments on +// instantiated versions for details on usage. +// +// TypeCodeMatcher provides the mapping from C++ type to CEL type. +// ValueConverter provides value conversions from native to CEL and vice versa. +// ReturnType and Arguments types are instantiated for the particular shape of +// the adapted functions. +template +class FunctionAdapterImpl { + public: + // Implementations for the common cases of unary and binary functions. + // This reduces the binary size substantially over the generic templated + // versions. + template + class BinaryFunction : public CelFunction { + public: + using FuncType = std::function; + + static std::unique_ptr Create(absl::string_view name, + bool receiver_style, + FuncType handler) { + constexpr auto arg1_type = TypeCodeMatcher::template type_code(); + static_assert(arg1_type.has_value(), "T does not map to a CEL type."); + constexpr auto arg2_type = TypeCodeMatcher::template type_code(); + static_assert(arg2_type.has_value(), "U does not map to a CEL type."); + std::vector arg_types{*arg1_type, *arg2_type}; + + return absl::WrapUnique(new BinaryFunction( + CelFunctionDescriptor(name, receiver_style, std::move(arg_types)), + std::move(handler))); + } + + absl::Status Evaluate(absl::Span arguments, + CelValue* result, + google::protobuf::Arena* arena) const override { + if (arguments.size() != 2) { + return absl::InternalError("Argument number mismatch, expected 2"); + } + T arg; + if (!ValueConverter().ValueToNative(arguments[0], &arg)) { + return absl::InternalError("C++ to CEL type conversion failed"); + } + U arg2; + if (!ValueConverter().ValueToNative(arguments[1], &arg2)) { + return absl::InternalError("C++ to CEL type conversion failed"); + } + ReturnType handlerResult = handler_(arena, arg, arg2); + return ValueConverter().NativeToValue(handlerResult, arena, result); + } + + private: + BinaryFunction(CelFunctionDescriptor descriptor, FuncType handler) + : CelFunction(descriptor), handler_(std::move(handler)) {} + + FuncType handler_; + }; + + template + class UnaryFunction : public CelFunction { + public: + using FuncType = std::function; + + static std::unique_ptr Create(absl::string_view name, + bool receiver_style, + FuncType handler) { + constexpr auto arg_type = TypeCodeMatcher::template type_code(); + static_assert(arg_type.has_value(), "T does not map to a CEL type."); + std::vector arg_types{*arg_type}; + + return absl::WrapUnique(new UnaryFunction( + CelFunctionDescriptor(name, receiver_style, std::move(arg_types)), + std::move(handler))); + } + + absl::Status Evaluate(absl::Span arguments, + CelValue* result, + google::protobuf::Arena* arena) const override { + if (arguments.size() != 1) { + return absl::InternalError("Argument number mismatch, expected 1"); + } + T arg; + if (!ValueConverter().ValueToNative(arguments[0], &arg)) { + return absl::InternalError("C++ to CEL type conversion failed"); + } + ReturnType handlerResult = handler_(arena, arg); + return ValueConverter().NativeToValue(handlerResult, arena, result); + } + + private: + UnaryFunction(CelFunctionDescriptor descriptor, FuncType handler) + : CelFunction(descriptor), handler_(std::move(handler)) {} + + FuncType handler_; + }; + + // Generalized implementation. + template + class FunctionAdapter : public CelFunction { + public: + using FuncType = std::function; + using TypeAdder = internal::TypeAdder; + + FunctionAdapter(CelFunctionDescriptor descriptor, FuncType handler) + : CelFunction(std::move(descriptor)), handler_(std::move(handler)) {} + + static absl::StatusOr> Create( + absl::string_view name, bool receiver_type, + std::function handler) { + std::vector arg_types; + arg_types.reserve(sizeof...(Arguments)); + + if (!TypeAdder().template AddType<0, Arguments...>(&arg_types)) { + return absl::Status( + absl::StatusCode::kInternal, + absl::StrCat("Failed to create adapter for ", name, + ": failed to determine input parameter type")); + } + + return std::make_unique( + CelFunctionDescriptor(name, receiver_type, std::move(arg_types)), + std::move(handler)); + } + + // Creates function handler and attempts to register it with + // supplied function registry. + static absl::Status CreateAndRegister( + absl::string_view name, bool receiver_type, + std::function handler, + CelFunctionRegistry* registry) { + CEL_ASSIGN_OR_RETURN(auto cel_function, + Create(name, receiver_type, std::move(handler))); + + return registry->Register(std::move(cel_function)); + } + +#if !defined(CEL_CPP_DISABLE_PARTIAL_SPECIALIZATION) + template + inline absl::Status RunWrap( + absl::Span arguments, + std::tuple<::google::protobuf::Arena*, Arguments...> input, CelValue* result, + ::google::protobuf::Arena* arena) const { + if (!ValueConverter().ValueToNative(arguments[arg_index], + &std::get(input))) { + return absl::Status(absl::StatusCode::kInvalidArgument, + "Type conversion failed"); + } + return RunWrap(arguments, input, result, arena); + } + + template <> + inline absl::Status RunWrap( + absl::Span, + std::tuple<::google::protobuf::Arena*, Arguments...> input, CelValue* result, + ::google::protobuf::Arena* arena) const { + return ValueConverter().NativeToValue(absl::apply(handler_, input), arena, + result); + } +#else + inline absl::Status RunWrap( + std::function func, + ABSL_ATTRIBUTE_UNUSED const absl::Span argset, + ::google::protobuf::Arena* arena, CelValue* result, + ABSL_ATTRIBUTE_UNUSED int arg_index) const { + return ValueConverter().NativeToValue(func(), arena, result); + } + + template + inline absl::Status RunWrap(std::function func, + const absl::Span argset, + ::google::protobuf::Arena* arena, CelValue* result, + int arg_index) const { + Arg argument; + if (!ValueConverter().ValueToNative(argset[arg_index], &argument)) { + return absl::Status(absl::StatusCode::kInvalidArgument, + "Type conversion failed"); + } + + std::function wrapped_func = + [func, argument](Args... args) -> ReturnType { + return func(argument, args...); + }; + + return RunWrap(std::move(wrapped_func), argset, arena, result, + arg_index + 1); + } +#endif + + absl::Status Evaluate(absl::Span arguments, + CelValue* result, + ::google::protobuf::Arena* arena) const override { + if (arguments.size() != sizeof...(Arguments)) { + return absl::Status(absl::StatusCode::kInternal, + "Argument number mismatch"); + } + +#if !defined(CEL_CPP_DISABLE_PARTIAL_SPECIALIZATION) + std::tuple<::google::protobuf::Arena*, Arguments...> input; + std::get<0>(input) = arena; + return RunWrap<0>(arguments, input, result, arena); +#else + const auto* handler = &handler_; + std::function wrapped_handler = + [handler, arena](Arguments... args) -> ReturnType { + return (*handler)(arena, args...); + }; + return RunWrap(std::move(wrapped_handler), arguments, arena, result, 0); +#endif + } + + private: + FuncType handler_; + }; +}; + +} // namespace internal + +} // namespace google::api::expr::runtime + +#endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_FUNCTION_ADAPTER_IMPL_H_ diff --git a/eval/public/cel_function_adapter_test.cc b/eval/public/cel_function_adapter_test.cc index 102b171de..29d27e5af 100644 --- a/eval/public/cel_function_adapter_test.cc +++ b/eval/public/cel_function_adapter_test.cc @@ -1,8 +1,12 @@ #include "eval/public/cel_function_adapter.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "base/status_macros.h" +#include +#include +#include +#include + +#include "internal/status_macros.h" +#include "internal/testing.h" namespace google { namespace api { @@ -13,34 +17,22 @@ namespace { TEST(CelFunctionAdapterTest, TestAdapterNoArg) { auto func = [](google::protobuf::Arena*) -> int64_t { return 100; }; - - auto func_status = FunctionAdapter::Create("const", false, func); - - ASSERT_OK(func_status); - - auto cel_func = std::move(func_status.value()); + ASSERT_OK_AND_ASSIGN( + auto cel_func, (FunctionAdapter::Create("const", false, func))); absl::Span args; - CelValue result = CelValue::CreateNull(); google::protobuf::Arena arena; - auto eval_status = cel_func->Evaluate(args, &result, &arena); - - ASSERT_OK(eval_status); - - ASSERT_TRUE( - result.IsInt64()); // Obvious failure, for educational purposes only. + ASSERT_OK(cel_func->Evaluate(args, &result, &arena)); + ASSERT_TRUE(result.IsInt64()); } TEST(CelFunctionAdapterTest, TestAdapterOneArg) { std::function func = [](google::protobuf::Arena* arena, int64_t i) -> int64_t { return i + 1; }; - - auto func_status = FunctionAdapter::Create("_++_", false, func); - - ASSERT_OK(func_status); - - auto cel_func = std::move(func_status.value()); + ASSERT_OK_AND_ASSIGN( + auto cel_func, + (FunctionAdapter::Create("_++_", false, func))); std::vector args_vec; args_vec.push_back(CelValue::CreateInt64(99)); @@ -49,11 +41,7 @@ TEST(CelFunctionAdapterTest, TestAdapterOneArg) { google::protobuf::Arena arena; absl::Span args(&args_vec[0], args_vec.size()); - - auto eval_status = cel_func->Evaluate(args, &result, &arena); - - ASSERT_OK(eval_status); - + ASSERT_OK(cel_func->Evaluate(args, &result, &arena)); ASSERT_TRUE(result.IsInt64()); EXPECT_EQ(result.Int64OrDie(), 100); } @@ -62,13 +50,9 @@ TEST(CelFunctionAdapterTest, TestAdapterTwoArgs) { auto func = [](google::protobuf::Arena* arena, int64_t i, int64_t j) -> int64_t { return i + j; }; - - auto func_status = - FunctionAdapter::Create("_++_", false, func); - - ASSERT_OK(func_status); - - auto cel_func = std::move(func_status.value()); + ASSERT_OK_AND_ASSIGN(auto cel_func, + (FunctionAdapter::Create( + "_++_", false, func))); std::vector args_vec; args_vec.push_back(CelValue::CreateInt64(20)); @@ -78,11 +62,7 @@ TEST(CelFunctionAdapterTest, TestAdapterTwoArgs) { google::protobuf::Arena arena; absl::Span args(&args_vec[0], args_vec.size()); - - auto eval_status = cel_func->Evaluate(args, &result, &arena); - - ASSERT_OK(eval_status); - + ASSERT_OK(cel_func->Evaluate(args, &result, &arena)); ASSERT_TRUE(result.IsInt64()); EXPECT_EQ(result.Int64OrDie(), 42); } @@ -97,14 +77,10 @@ TEST(CelFunctionAdapterTest, TestAdapterThreeArgs) { return StringHolder( google::protobuf::Arena::Create(arena, std::move(value))); }; - - auto func_status = - FunctionAdapter::Create("concat", false, func); - - ASSERT_OK(func_status); - - auto cel_func = std::move(func_status.value()); + ASSERT_OK_AND_ASSIGN( + auto cel_func, + (FunctionAdapter::Create("concat", false, func))); std::string test1 = "1"; std::string test2 = "2"; @@ -119,11 +95,7 @@ TEST(CelFunctionAdapterTest, TestAdapterThreeArgs) { google::protobuf::Arena arena; absl::Span args(&args_vec[0], args_vec.size()); - - auto eval_status = cel_func->Evaluate(args, &result, &arena); - - ASSERT_OK(eval_status); - + ASSERT_OK(cel_func->Evaluate(args, &result, &arena)); ASSERT_TRUE(result.IsString()); EXPECT_EQ(result.StringOrDie().value(), "123"); } @@ -134,17 +106,13 @@ TEST(CelFunctionAdapterTest, TestTypeDeductionForCelValueBasicTypes) { const google::protobuf::Message*, absl::Duration, absl::Time, const CelList*, const CelMap*, const CelError*) -> bool { return false; }; - - auto func_status = - FunctionAdapter::Create("dummy_func", false, func); - - ASSERT_OK(func_status); - - auto cel_func = std::move(func_status.value()); - + ASSERT_OK_AND_ASSIGN( + auto cel_func, + (FunctionAdapter::Create("dummy_func", false, func))); auto descriptor = cel_func->descriptor(); EXPECT_EQ(descriptor.receiver_style(), false); @@ -165,6 +133,27 @@ TEST(CelFunctionAdapterTest, TestTypeDeductionForCelValueBasicTypes) { ASSERT_EQ(descriptor.types()[pos++], CelValue::Type::kError); } +TEST(CelFunctionAdapterTest, TestAdapterStatusOrMessage) { + auto func = + [](google::protobuf::Arena* arena) -> absl::StatusOr { + auto* ret = google::protobuf::Arena::Create(arena); + ret->set_seconds(123); + return ret; + }; + ASSERT_OK_AND_ASSIGN( + auto cel_func, + (FunctionAdapter>::Create( + "const", false, func))); + + absl::Span args; + + CelValue result = CelValue::CreateNull(); + google::protobuf::Arena arena; + ASSERT_OK(cel_func->Evaluate(args, &result, &arena)); + ASSERT_TRUE(result.IsTimestamp()); + EXPECT_EQ(result.TimestampOrDie(), absl::FromUnixSeconds(123)); +} + } // namespace } // namespace runtime diff --git a/eval/public/cel_function_provider.cc b/eval/public/cel_function_provider.cc deleted file mode 100644 index 135a5b8f7..000000000 --- a/eval/public/cel_function_provider.cc +++ /dev/null @@ -1,47 +0,0 @@ -#include "eval/public/cel_function_provider.h" - -#include "absl/status/statusor.h" - -namespace google { -namespace api { -namespace expr { -namespace runtime { - -namespace { -// Impl for simple provider that looks up functions in an activation function -// registry. -class ActivationFunctionProviderImpl : public CelFunctionProvider { - public: - ActivationFunctionProviderImpl() {} - absl::StatusOr GetFunction( - const CelFunctionDescriptor& descriptor, - const BaseActivation& activation) const override { - std::vector overloads = - activation.FindFunctionOverloads(descriptor.name()); - - const CelFunction* matching_overload = nullptr; - - for (const CelFunction* overload : overloads) { - if (overload->descriptor().ShapeMatches(descriptor)) { - if (matching_overload != nullptr) { - return absl::Status(absl::StatusCode::kInvalidArgument, - "Couldn't resolve function."); - } - matching_overload = overload; - } - } - - return matching_overload; - } -}; - -} // namespace - -std::unique_ptr CreateActivationFunctionProvider() { - return std::make_unique(); -} - -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google diff --git a/eval/public/cel_function_provider.h b/eval/public/cel_function_provider.h deleted file mode 100644 index 3f13d2b33..000000000 --- a/eval/public/cel_function_provider.h +++ /dev/null @@ -1,37 +0,0 @@ -#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_FUNCTION_PROVIDER_H_ -#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_FUNCTION_PROVIDER_H_ - -#include "eval/public/activation.h" -#include "eval/public/cel_function.h" - -namespace google { -namespace api { -namespace expr { -namespace runtime { - -// CelFunctionProvider is an interface for providers of lazy CelFunctions (i.e. -// implementation isn't available until evaluation time based on the -// activation). -class CelFunctionProvider { - public: - // Returns a ptr to a |CelFunction| based on the provided |Activation|. Given - // the same activation, this should return the same CelFunction. The - // CelFunction ptr is assumed to be stable for the life of the Activation. - // nullptr is interpreted as no funtion overload matches the descriptor. - virtual absl::StatusOr GetFunction( - const CelFunctionDescriptor& descriptor, - const BaseActivation& activation) const = 0; - virtual ~CelFunctionProvider() {} -}; - -// Create a CelFunctionProvider that just looks up the functions inserted in the -// Activation. This is a convenience implementation for a simple, common -// use-case. -std::unique_ptr CreateActivationFunctionProvider(); - -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google - -#endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_FUNCTION_PROVIDER_H_ diff --git a/eval/public/cel_function_provider_test.cc b/eval/public/cel_function_provider_test.cc deleted file mode 100644 index 8a6ff3e42..000000000 --- a/eval/public/cel_function_provider_test.cc +++ /dev/null @@ -1,78 +0,0 @@ -#include "eval/public/cel_function_provider.h" - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "base/status_macros.h" - -namespace google { -namespace api { -namespace expr { -namespace runtime { -namespace { - -using testing::_; -using testing::Eq; -using testing::HasSubstr; -using testing::Ne; - -class ConstCelFunction : public CelFunction { - public: - ConstCelFunction() : CelFunction({"ConstFunction", false, {}}) {} - explicit ConstCelFunction(const CelFunctionDescriptor& desc) - : CelFunction(desc) {} - absl::Status Evaluate(absl::Span args, CelValue* output, - google::protobuf::Arena* arena) const override { - return absl::Status(absl::StatusCode::kUnimplemented, "Not Implemented"); - } -}; - -TEST(CreateActivationFunctionProviderTest, NoOverloadFound) { - Activation activation; - auto provider = CreateActivationFunctionProvider(); - - auto func = provider->GetFunction({"LazyFunc", false, {}}, activation); - - ASSERT_OK(func.status()); - EXPECT_THAT(func.value(), Eq(nullptr)); -} - -TEST(CreateActivationFunctionProviderTest, OverloadFound) { - Activation activation; - CelFunctionDescriptor desc{"LazyFunc", false, {}}; - auto provider = CreateActivationFunctionProvider(); - - auto status = - activation.InsertFunction(std::make_unique(desc)); - EXPECT_OK(status); - - auto func = provider->GetFunction(desc, activation); - - ASSERT_OK(func.status()); - EXPECT_THAT(func.value(), Ne(nullptr)); -} - -TEST(CreateActivationFunctionProviderTest, AmbiguousLookup) { - Activation activation; - CelFunctionDescriptor desc1{"LazyFunc", false, {CelValue::Type::kInt64}}; - CelFunctionDescriptor desc2{"LazyFunc", false, {CelValue::Type::kUint64}}; - CelFunctionDescriptor match_desc{"LazyFunc", false, {CelValue::Type::kAny}}; - - auto provider = CreateActivationFunctionProvider(); - - auto status = - activation.InsertFunction(std::make_unique(desc1)); - EXPECT_OK(status); - status = activation.InsertFunction(std::make_unique(desc2)); - EXPECT_OK(status); - - auto func = provider->GetFunction(match_desc, activation); - - EXPECT_THAT(std::string(func.status().message()), - HasSubstr("Couldn't resolve function")); -} - -} // namespace -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google diff --git a/eval/public/cel_function_registry.cc b/eval/public/cel_function_registry.cc index 34202afe4..d96510ab6 100644 --- a/eval/public/cel_function_registry.cc +++ b/eval/public/cel_function_registry.cc @@ -1,112 +1,122 @@ #include "eval/public/cel_function_registry.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { - -absl::Status CelFunctionRegistry::Register( - std::unique_ptr function) { - const CelFunctionDescriptor& descriptor = function->descriptor(); - - if (DescriptorRegistered(descriptor)) { - return absl::Status( - absl::StatusCode::kAlreadyExists, - "CelFunction with specified parameters already registered"); +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "common/function_descriptor.h" +#include "common/value.h" +#include "eval/internal/interop.h" +#include "eval/public/cel_function.h" +#include "eval/public/cel_options.h" +#include "eval/public/cel_value.h" +#include "internal/status_macros.h" +#include "runtime/function.h" +#include "runtime/function_overload_reference.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace google::api::expr::runtime { +namespace { + +// Legacy cel function that proxies to the modern cel::Function interface. +// +// This is used to wrap new-style cel::Functions for clients consuming +// legacy CelFunction-based APIs. The evaluate implementation on this class +// should not be called by the CEL evaluator, but a sensible result is returned +// for unit tests that haven't been migrated to the new APIs yet. +class ProxyToModernCelFunction : public CelFunction { + public: + ProxyToModernCelFunction(const cel::FunctionDescriptor& descriptor, + const cel::Function& implementation) + : CelFunction(descriptor), implementation_(&implementation) {} + + absl::Status Evaluate(absl::Span args, CelValue* result, + google::protobuf::Arena* arena) const override { + // This is only safe for use during interop where the MemoryManager is + // assumed to always be backed by a google::protobuf::Arena instance. After all + // dependencies on legacy CelFunction are removed, we can remove this + // implementation. + + std::vector modern_args = + cel::interop_internal::LegacyValueToModernValueOrDie(arena, args); + + CEL_ASSIGN_OR_RETURN( + auto modern_result, + implementation_->Invoke( + modern_args, google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), arena)); + + *result = cel::interop_internal::ModernValueToLegacyValueOrDie( + arena, modern_result); + + return absl::OkStatus(); } - auto& overloads = functions_[descriptor.name()]; - overloads.static_overloads.push_back(std::move(function)); - return absl::OkStatus(); -} + private: + // owned by the registry + const cel::Function* implementation_; +}; -absl::Status CelFunctionRegistry::RegisterLazyFunction( - const CelFunctionDescriptor& descriptor, - std::unique_ptr factory) { - if (DescriptorRegistered(descriptor)) { - return absl::Status( - absl::StatusCode::kAlreadyExists, - "CelFunction with specified parameters already registered"); - } - auto& overloads = functions_[descriptor.name()]; - LazyFunctionEntry entry = std::make_unique( - descriptor, std::move(factory)); - overloads.lazy_overloads.push_back(std::move(entry)); +} // namespace +absl::Status CelFunctionRegistry::RegisterAll( + std::initializer_list registrars, + const InterpreterOptions& opts) { + for (Registrar registrar : registrars) { + CEL_RETURN_IF_ERROR(registrar(this, opts)); + } return absl::OkStatus(); } std::vector CelFunctionRegistry::FindOverloads( absl::string_view name, bool receiver_style, const std::vector& types) const { - std::vector matched_funcs; - - auto overloads = functions_.find(std::string(name)); - if (overloads == functions_.end()) { - return matched_funcs; - } - - for (const auto& func_ptr : overloads->second.static_overloads) { - if (func_ptr->descriptor().ShapeMatches(receiver_style, types)) { - matched_funcs.push_back(func_ptr.get()); + std::vector matched_funcs = + modern_registry_.FindStaticOverloads(name, receiver_style, types); + + // For backwards compatibility, lazily initialize a legacy CEL function + // if required. + // The registry should remain add-only until migration to the new type is + // complete, so this should work whether the function was introduced via + // the modern registry or the old registry wrapping a modern instance. + std::vector results; + results.reserve(matched_funcs.size()); + + { + absl::MutexLock lock(mu_); + for (cel::FunctionOverloadReference entry : matched_funcs) { + std::unique_ptr& legacy_impl = + functions_[&entry.implementation]; + + if (legacy_impl == nullptr) { + legacy_impl = std::make_unique( + entry.descriptor, entry.implementation); + } + results.push_back(legacy_impl.get()); } } - - return matched_funcs; + return results; } -std::vector CelFunctionRegistry::FindLazyOverloads( +std::vector +CelFunctionRegistry::FindLazyOverloads( absl::string_view name, bool receiver_style, const std::vector& types) const { - std::vector matched_funcs; - - auto overloads = functions_.find(std::string(name)); - if (overloads == functions_.end()) { - return matched_funcs; - } + std::vector lazy_overloads = + modern_registry_.FindLazyOverloads(name, receiver_style, types); + std::vector result; + result.reserve(lazy_overloads.size()); - for (const LazyFunctionEntry& entry : overloads->second.lazy_overloads) { - if (entry->first.ShapeMatches(receiver_style, types)) { - matched_funcs.push_back(entry->second.get()); - } - } - - return matched_funcs; -} - -absl::node_hash_map> -CelFunctionRegistry::ListFunctions() const { - absl::node_hash_map> - descriptor_map; - - for (const auto& entry : functions_) { - std::vector descriptors; - const RegistryEntry& function_entry = entry.second; - descriptors.reserve(function_entry.static_overloads.size() + - function_entry.lazy_overloads.size()); - for (const auto& func : function_entry.static_overloads) { - descriptors.push_back(&func->descriptor()); - } - for (const LazyFunctionEntry& func : function_entry.lazy_overloads) { - descriptors.push_back(&func->first); - } - descriptor_map[entry.first] = std::move(descriptors); + for (const LazyOverload& overload : lazy_overloads) { + result.push_back(&overload.descriptor); } - - return descriptor_map; -} - -bool CelFunctionRegistry::DescriptorRegistered( - const CelFunctionDescriptor& descriptor) const { - return !(FindOverloads(descriptor.name(), descriptor.receiver_style(), - descriptor.types()) - .empty()) || - !(FindLazyOverloads(descriptor.name(), descriptor.receiver_style(), - descriptor.types()) - .empty()); + return result; } -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime diff --git a/eval/public/cel_function_registry.h b/eval/public/cel_function_registry.h index de1d64bcc..d2274d83d 100644 --- a/eval/public/cel_function_registry.h +++ b/eval/public/cel_function_registry.h @@ -1,57 +1,89 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_FUNCTION_REGISTRY_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_FUNCTION_REGISTRY_H_ +#include +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" #include "absl/container/node_hash_map.h" -#include "absl/types/span.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "common/function_descriptor.h" +#include "common/kind.h" #include "eval/public/cel_function.h" -#include "eval/public/cel_function_provider.h" #include "eval/public/cel_options.h" #include "eval/public/cel_value.h" +#include "runtime/function.h" +#include "runtime/function_overload_reference.h" +#include "runtime/function_registry.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace google::api::expr::runtime { // CelFunctionRegistry class allows to register builtin or custom // CelFunction handlers with it and look them up when creating // CelExpression objects from Expr ASTs. class CelFunctionRegistry { public: - CelFunctionRegistry() {} + // Represents a single overload for a lazily provided function. + using LazyOverload = cel::FunctionRegistry::LazyOverload; + + CelFunctionRegistry() = default; + + ~CelFunctionRegistry() = default; - ~CelFunctionRegistry() {} + using Registrar = absl::Status (*)(CelFunctionRegistry*, + const InterpreterOptions&); // Register CelFunction object. Object ownership is // passed to registry. // Function registration should be performed prior to // CelExpression creation. - absl::Status Register(std::unique_ptr function); + absl::Status Register(std::unique_ptr function) { + // We need to copy the descriptor, otherwise there is no guarantee that the + // lvalue reference to the descriptor is valid as function may be destroyed. + auto descriptor = function->descriptor(); + return Register(descriptor, std::move(function)); + } + + absl::Status Register(const cel::FunctionDescriptor& descriptor, + std::unique_ptr implementation) { + return modern_registry_.Register(descriptor, std::move(implementation)); + } - // Register a lazily provided function. CelFunctionProvider is used to get - // a CelFunction ptr at evaluation time. The registry takes ownership of the - // factory. - absl::Status RegisterLazyFunction( - const CelFunctionDescriptor& descriptor, - std::unique_ptr factory); + absl::Status RegisterAll(std::initializer_list registrars, + const InterpreterOptions& opts); // Register a lazily provided function. This overload uses a default provider // that delegates to the activation at evaluation time. absl::Status RegisterLazyFunction(const CelFunctionDescriptor& descriptor) { - return RegisterLazyFunction(descriptor, CreateActivationFunctionProvider()); + return modern_registry_.RegisterLazyFunction(descriptor); } - // Find subset of CelFunction that match overload conditions + // Find a subset of CelFunction that match overload conditions // As types may not be available during expression compilation, // further narrowing of this subset will happen at evaluation stage. // name - the name of CelFunction; // receiver_style - indicates whether function has receiver style; // types - argument types. If type is not known during compilation, // DYN value should be passed. + // + // Results refer to underlying registry entries by pointer. Results are + // invalid after the registry is deleted. std::vector FindOverloads( absl::string_view name, bool receiver_style, const std::vector& types) const; + std::vector FindStaticOverloads( + absl::string_view name, bool receiver_style, + const std::vector& types) const { + return modern_registry_.FindStaticOverloads(name, receiver_style, types); + } + // Find subset of CelFunction providers that match overload conditions // As types may not be available during expression compilation, // further narrowing of this subset will happen at evaluation stage. @@ -59,32 +91,56 @@ class CelFunctionRegistry { // receiver_style - indicates whether function has receiver style; // types - argument types. If type is not known during compilation, // DYN value should be passed. - std::vector FindLazyOverloads( + std::vector FindLazyOverloads( absl::string_view name, bool receiver_style, const std::vector& types) const; + // Find subset of CelFunction providers that match overload conditions + // As types may not be available during expression compilation, + // further narrowing of this subset will happen at evaluation stage. + // name - the name of CelFunction; + // receiver_style - indicates whether function has receiver style; + // types - argument types. If type is not known during compilation, + // DYN value should be passed. + std::vector ModernFindLazyOverloads( + absl::string_view name, bool receiver_style, + const std::vector& types) const { + return modern_registry_.FindLazyOverloads(name, receiver_style, types); + } + // Retrieve list of registered function descriptors. This includes both // static and lazy functions. - absl::node_hash_map> - ListFunctions() const; + absl::node_hash_map> + ListFunctions() const { + return modern_registry_.ListFunctions(); + } + + // cel internal accessor for returning backing modern registry. + // + // This is intended to allow migrating the CEL evaluator internals while + // maintaining the existing CelRegistry API. + // + // CEL users should not use this. + const cel::FunctionRegistry& InternalGetRegistry() const { + return modern_registry_; + } + + cel::FunctionRegistry& InternalGetRegistry() { return modern_registry_; } private: - // Returns whether the descriptor is registered in either as a lazy funtion or - // in the static functions. - bool DescriptorRegistered(const CelFunctionDescriptor& descriptor) const; - using StaticFunctionEntry = std::unique_ptr; - using LazyFunctionEntry = std::unique_ptr< - std::pair>>; - struct RegistryEntry { - std::vector static_overloads; - std::vector lazy_overloads; - }; - absl::node_hash_map functions_; + cel::FunctionRegistry modern_registry_; + + // Maintain backwards compatibility for callers expecting CelFunction + // interface. + // This is not used internally, but some client tests check that a specific + // CelFunction overload is used. + // Lazily initialized. + mutable absl::Mutex mu_; + mutable absl::flat_hash_map> + functions_ ABSL_GUARDED_BY(mu_); }; -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_FUNCTION_REGISTRY_H_ diff --git a/eval/public/cel_function_registry_test.cc b/eval/public/cel_function_registry_test.cc index 323588213..75963cda7 100644 --- a/eval/public/cel_function_registry_test.cc +++ b/eval/public/cel_function_registry_test.cc @@ -1,34 +1,29 @@ #include "eval/public/cel_function_registry.h" #include +#include +#include -#include "gmock/gmock.h" -#include "gtest/gtest.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" +#include "common/kind.h" +#include "eval/internal/adapter_activation_impl.h" +#include "eval/public/activation.h" #include "eval/public/cel_function.h" -#include "eval/public/cel_function_provider.h" -#include "base/status_macros.h" +#include "internal/testing.h" +#include "runtime/function_overload_reference.h" + +namespace google::api::expr::runtime { -namespace google { -namespace api { -namespace expr { -namespace runtime { namespace { -using testing::Eq; -using testing::Property; -using testing::SizeIs; - -class NullLazyFunctionProvider : public virtual CelFunctionProvider { - public: - NullLazyFunctionProvider() {} - // Just return nullptr indicating no matching function. - absl::StatusOr GetFunction( - const CelFunctionDescriptor& desc, - const BaseActivation& activation) const override { - return nullptr; - } -}; +using ::absl_testing::StatusIs; +using ::testing::ElementsAre; +using ::testing::Eq; +using ::testing::HasSubstr; +using ::testing::Property; +using ::testing::SizeIs; +using ::testing::Truly; class ConstCelFunction : public CelFunction { public: @@ -52,15 +47,11 @@ TEST(CelFunctionRegistryTest, InsertAndRetrieveLazyFunction) { CelFunctionDescriptor lazy_function_desc{"LazyFunction", false, {}}; CelFunctionRegistry registry; Activation activation; - auto register_status = registry.RegisterLazyFunction( - lazy_function_desc, std::make_unique()); - EXPECT_OK(register_status); + ASSERT_OK(registry.RegisterLazyFunction(lazy_function_desc)); - const auto providers = registry.FindLazyOverloads("LazyFunction", false, {}); - EXPECT_THAT(providers, testing::SizeIs(1)); - auto func = providers[0]->GetFunction(lazy_function_desc, activation); - ASSERT_OK(func.status()); - EXPECT_THAT(func.value(), Eq(nullptr)); + const auto descriptors = + registry.FindLazyOverloads("LazyFunction", false, {}); + EXPECT_THAT(descriptors, testing::SizeIs(1)); } // Confirm that lazy and static functions share the same descriptor space: @@ -69,22 +60,39 @@ TEST(CelFunctionRegistryTest, InsertAndRetrieveLazyFunction) { TEST(CelFunctionRegistryTest, LazyAndStaticFunctionShareDescriptorSpace) { CelFunctionRegistry registry; CelFunctionDescriptor desc = ConstCelFunction::MakeDescriptor(); - auto register_status = registry.RegisterLazyFunction( - desc, std::make_unique()); - EXPECT_OK(register_status); + ASSERT_OK(registry.RegisterLazyFunction(desc)); - absl::Status status = registry.Register(std::make_unique()); + absl::Status status = registry.Register(ConstCelFunction::MakeDescriptor(), + std::make_unique()); EXPECT_FALSE(status.ok()); } +// Confirm that lazy and static functions share the same descriptor space: +// i.e. you can't insert both a lazy function and a static function for the same +// descriptors. +TEST(CelFunctionRegistryTest, FindStaticOverloadsReturns) { + CelFunctionRegistry registry; + CelFunctionDescriptor desc = ConstCelFunction::MakeDescriptor(); + ASSERT_OK(registry.Register(desc, std::make_unique(desc))); + + std::vector overloads = + registry.FindStaticOverloads(desc.name(), false, {}); + + EXPECT_THAT(overloads, + ElementsAre(Truly( + [](const cel::FunctionOverloadReference& overload) -> bool { + return overload.descriptor.name() == "ConstFunction"; + }))) + << "Expected single ConstFunction()"; +} + TEST(CelFunctionRegistryTest, ListFunctions) { CelFunctionDescriptor lazy_function_desc{"LazyFunction", false, {}}; CelFunctionRegistry registry; - auto register_status = registry.RegisterLazyFunction( - lazy_function_desc, std::make_unique()); - EXPECT_OK(register_status); - EXPECT_OK(registry.Register(std::make_unique())); + ASSERT_OK(registry.RegisterLazyFunction(lazy_function_desc)); + EXPECT_OK(registry.Register(ConstCelFunction::MakeDescriptor(), + std::make_unique())); auto registered_functions = registry.ListFunctions(); @@ -93,26 +101,203 @@ TEST(CelFunctionRegistryTest, ListFunctions) { EXPECT_THAT(registered_functions["ConstFunction"], SizeIs(1)); } +TEST(CelFunctionRegistryTest, LegacyFindLazyOverloads) { + CelFunctionDescriptor lazy_function_desc{"LazyFunction", false, {}}; + CelFunctionRegistry registry; + + ASSERT_OK(registry.RegisterLazyFunction(lazy_function_desc)); + ASSERT_OK(registry.Register(ConstCelFunction::MakeDescriptor(), + std::make_unique())); + + EXPECT_THAT(registry.FindLazyOverloads("LazyFunction", false, {}), + ElementsAre(Truly([](const CelFunctionDescriptor* descriptor) { + return descriptor->name() == "LazyFunction"; + }))) + << "Expected single lazy overload for LazyFunction()"; +} + TEST(CelFunctionRegistryTest, DefaultLazyProvider) { CelFunctionDescriptor lazy_function_desc{"LazyFunction", false, {}}; CelFunctionRegistry registry; Activation activation; + cel::interop_internal::AdapterActivationImpl modern_activation(activation); EXPECT_OK(registry.RegisterLazyFunction(lazy_function_desc)); - auto insert_status = activation.InsertFunction( - std::make_unique(lazy_function_desc)); - EXPECT_OK(insert_status); + EXPECT_OK(activation.InsertFunction( + std::make_unique(lazy_function_desc))); - const auto providers = registry.FindLazyOverloads("LazyFunction", false, {}); + auto providers = registry.ModernFindLazyOverloads("LazyFunction", false, {}); EXPECT_THAT(providers, testing::SizeIs(1)); - auto func = providers[0]->GetFunction(lazy_function_desc, activation); + ASSERT_OK_AND_ASSIGN(auto func, providers[0].provider.GetFunction( + lazy_function_desc, modern_activation)); + ASSERT_TRUE(func.has_value()); + EXPECT_THAT(func->descriptor, + Property(&cel::FunctionDescriptor::name, Eq("LazyFunction"))); +} + +TEST(CelFunctionRegistryTest, DefaultLazyProviderNoOverloadFound) { + CelFunctionRegistry registry; + Activation legacy_activation; + cel::interop_internal::AdapterActivationImpl activation(legacy_activation); + CelFunctionDescriptor lazy_function_desc{"LazyFunction", false, {}}; + EXPECT_OK(registry.RegisterLazyFunction(lazy_function_desc)); + EXPECT_OK(legacy_activation.InsertFunction( + std::make_unique(lazy_function_desc))); + + const auto providers = + registry.ModernFindLazyOverloads("LazyFunction", false, {}); + ASSERT_THAT(providers, testing::SizeIs(1)); + const auto& provider = providers[0].provider; + auto func = provider.GetFunction({"LazyFunc", false, {cel::Kind::kInt64}}, + activation); + ASSERT_OK(func.status()); - EXPECT_THAT(func.value(), Property(&CelFunction::descriptor, - Property(&CelFunctionDescriptor::name, - Eq("LazyFunction")))); + EXPECT_EQ(*func, absl::nullopt); +} + +TEST(CelFunctionRegistryTest, DefaultLazyProviderAmbiguousLookup) { + CelFunctionRegistry registry; + Activation legacy_activation; + cel::interop_internal::AdapterActivationImpl activation(legacy_activation); + CelFunctionDescriptor desc1{"LazyFunc", false, {CelValue::Type::kInt64}}; + CelFunctionDescriptor desc2{"LazyFunc", false, {CelValue::Type::kUint64}}; + CelFunctionDescriptor match_desc{"LazyFunc", false, {CelValue::Type::kAny}}; + ASSERT_OK(registry.RegisterLazyFunction(match_desc)); + ASSERT_OK(legacy_activation.InsertFunction( + std::make_unique(desc1))); + ASSERT_OK(legacy_activation.InsertFunction( + std::make_unique(desc2))); + + auto providers = + registry.ModernFindLazyOverloads("LazyFunc", false, {cel::Kind::kAny}); + ASSERT_THAT(providers, testing::SizeIs(1)); + const auto& provider = providers[0].provider; + auto func = provider.GetFunction(match_desc, activation); + + EXPECT_THAT(std::string(func.status().message()), + HasSubstr("Couldn't resolve function")); +} + +TEST(CelFunctionRegistryTest, CanRegisterNonStrictFunction) { + { + CelFunctionRegistry registry; + CelFunctionDescriptor descriptor("NonStrictFunction", + /*receiver_style=*/false, + {CelValue::Type::kAny}, + /*is_strict=*/false); + ASSERT_OK(registry.Register( + descriptor, std::make_unique(descriptor))); + EXPECT_THAT(registry.FindStaticOverloads("NonStrictFunction", false, + {CelValue::Type::kAny}), + SizeIs(1)); + } + { + CelFunctionRegistry registry; + CelFunctionDescriptor descriptor("NonStrictLazyFunction", + /*receiver_style=*/false, + {CelValue::Type::kAny}, + /*is_strict=*/false); + EXPECT_OK(registry.RegisterLazyFunction(descriptor)); + EXPECT_THAT(registry.FindLazyOverloads("NonStrictLazyFunction", false, + {CelValue::Type::kAny}), + SizeIs(1)); + } +} + +using NonStrictTestCase = std::tuple; +using NonStrictRegistrationFailTest = testing::TestWithParam; + +TEST_P(NonStrictRegistrationFailTest, + IfOtherOverloadExistsRegisteringNonStrictFails) { + bool existing_function_is_lazy, new_function_is_lazy; + std::tie(existing_function_is_lazy, new_function_is_lazy) = GetParam(); + CelFunctionRegistry registry; + CelFunctionDescriptor descriptor("OverloadedFunction", + /*receiver_style=*/false, + {CelValue::Type::kAny}, + /*is_strict=*/true); + if (existing_function_is_lazy) { + ASSERT_OK(registry.RegisterLazyFunction(descriptor)); + } else { + ASSERT_OK(registry.Register( + descriptor, std::make_unique(descriptor))); + } + CelFunctionDescriptor new_descriptor( + "OverloadedFunction", + /*receiver_style=*/false, {CelValue::Type::kAny, CelValue::Type::kAny}, + /*is_strict=*/false); + absl::Status status; + if (new_function_is_lazy) { + status = registry.RegisterLazyFunction(new_descriptor); + } else { + status = registry.Register( + new_descriptor, std::make_unique(new_descriptor)); + } + EXPECT_THAT(status, StatusIs(absl::StatusCode::kAlreadyExists, + HasSubstr("Only one overload"))); +} + +TEST_P(NonStrictRegistrationFailTest, + IfOtherNonStrictExistsRegisteringStrictFails) { + bool existing_function_is_lazy, new_function_is_lazy; + std::tie(existing_function_is_lazy, new_function_is_lazy) = GetParam(); + CelFunctionRegistry registry; + CelFunctionDescriptor descriptor("OverloadedFunction", + /*receiver_style=*/false, + {CelValue::Type::kAny}, + /*is_strict=*/false); + if (existing_function_is_lazy) { + ASSERT_OK(registry.RegisterLazyFunction(descriptor)); + } else { + ASSERT_OK(registry.Register( + descriptor, std::make_unique(descriptor))); + } + CelFunctionDescriptor new_descriptor( + "OverloadedFunction", + /*receiver_style=*/false, {CelValue::Type::kAny, CelValue::Type::kAny}, + /*is_strict=*/true); + absl::Status status; + if (new_function_is_lazy) { + status = registry.RegisterLazyFunction(new_descriptor); + } else { + status = registry.Register( + new_descriptor, std::make_unique(new_descriptor)); + } + EXPECT_THAT(status, StatusIs(absl::StatusCode::kAlreadyExists, + HasSubstr("Only one overload"))); } +TEST_P(NonStrictRegistrationFailTest, CanRegisterStrictFunctionsWithoutLimit) { + bool existing_function_is_lazy, new_function_is_lazy; + std::tie(existing_function_is_lazy, new_function_is_lazy) = GetParam(); + CelFunctionRegistry registry; + CelFunctionDescriptor descriptor("OverloadedFunction", + /*receiver_style=*/false, + {CelValue::Type::kAny}, + /*is_strict=*/true); + if (existing_function_is_lazy) { + ASSERT_OK(registry.RegisterLazyFunction(descriptor)); + } else { + ASSERT_OK(registry.Register( + descriptor, std::make_unique(descriptor))); + } + CelFunctionDescriptor new_descriptor( + "OverloadedFunction", + /*receiver_style=*/false, {CelValue::Type::kAny, CelValue::Type::kAny}, + /*is_strict=*/true); + absl::Status status; + if (new_function_is_lazy) { + status = registry.RegisterLazyFunction(new_descriptor); + } else { + status = registry.Register( + new_descriptor, std::make_unique(new_descriptor)); + } + EXPECT_OK(status); +} + +INSTANTIATE_TEST_SUITE_P(NonStrictRegistrationFailTest, + NonStrictRegistrationFailTest, + testing::Combine(testing::Bool(), testing::Bool())); + } // namespace -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google + +} // namespace google::api::expr::runtime diff --git a/eval/public/cel_number.cc b/eval/public/cel_number.cc new file mode 100644 index 000000000..e08afb6a3 --- /dev/null +++ b/eval/public/cel_number.cc @@ -0,0 +1,31 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/cel_number.h" + +#include "eval/public/cel_value.h" + +namespace google::api::expr::runtime { + +absl::optional GetNumberFromCelValue(const CelValue& value) { + if (int64_t val; value.GetValue(&val)) { + return CelNumber(val); + } else if (uint64_t val; value.GetValue(&val)) { + return CelNumber(val); + } else if (double val; value.GetValue(&val)) { + return CelNumber(val); + } + return absl::nullopt; +} +} // namespace google::api::expr::runtime diff --git a/eval/public/cel_number.h b/eval/public/cel_number.h new file mode 100644 index 000000000..1f66ce4f2 --- /dev/null +++ b/eval/public/cel_number.h @@ -0,0 +1,36 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_NUMERIC_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_NUMERIC_H_ + +#include +#include +#include + +#include "absl/types/optional.h" +#include "eval/public/cel_value.h" +#include "internal/number.h" + +namespace google::api::expr::runtime { + +using CelNumber = cel::internal::Number; + +// Return a CelNumber if the value holds a numeric type, otherwise return +// nullopt. +absl::optional GetNumberFromCelValue(const CelValue& value); + +} // namespace google::api::expr::runtime + +#endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_NUMERIC_H_ diff --git a/eval/public/cel_number_test.cc b/eval/public/cel_number_test.cc new file mode 100644 index 000000000..3c6f36e9b --- /dev/null +++ b/eval/public/cel_number_test.cc @@ -0,0 +1,45 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/cel_number.h" + +#include +#include + +#include "absl/types/optional.h" +#include "eval/public/cel_value.h" +#include "internal/testing.h" + +namespace google::api::expr::runtime { +namespace { + +using ::testing::Optional; + + +TEST(CelNumber, GetNumberFromCelValue) { + EXPECT_THAT(GetNumberFromCelValue(CelValue::CreateDouble(1.1)), + Optional(CelNumber::FromDouble(1.1))); + EXPECT_THAT(GetNumberFromCelValue(CelValue::CreateInt64(1)), + Optional(CelNumber::FromDouble(1.0))); + EXPECT_THAT(GetNumberFromCelValue(CelValue::CreateUint64(1)), + Optional(CelNumber::FromDouble(1.0))); + + EXPECT_EQ(GetNumberFromCelValue(CelValue::CreateDuration(absl::Seconds(1))), + absl::nullopt); +} + + + +} // namespace +} // namespace google::api::expr::runtime diff --git a/eval/public/cel_options.cc b/eval/public/cel_options.cc new file mode 100644 index 000000000..938b5e96f --- /dev/null +++ b/eval/public/cel_options.cc @@ -0,0 +1,47 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/cel_options.h" + +#include "runtime/runtime_options.h" + +namespace google::api::expr::runtime { + +cel::RuntimeOptions ConvertToRuntimeOptions(const InterpreterOptions& options) { + return cel::RuntimeOptions{/*.container=*/"", + options.unknown_processing, + options.enable_missing_attribute_errors, + options.enable_timestamp_duration_overflow_errors, + options.short_circuiting, + options.enable_comprehension, + options.comprehension_max_iterations, + options.enable_comprehension_list_append, + options.enable_comprehension_mutable_map, + options.enable_regex, + options.regex_max_program_size, + options.enable_string_conversion, + options.enable_string_concat, + options.enable_list_concat, + options.enable_list_contains, + options.fail_on_warnings, + options.enable_qualified_type_identifiers, + options.enable_heterogeneous_equality, + options.enable_empty_wrapper_null_unboxing, + options.enable_lazy_bind_initialization, + options.max_recursion_depth, + options.enable_recursive_tracing, + options.enable_fast_builtins}; +} + +} // namespace google::api::expr::runtime diff --git a/eval/public/cel_options.h b/eval/public/cel_options.h index de10230cb..4d81eb8a7 100644 --- a/eval/public/cel_options.h +++ b/eval/public/cel_options.h @@ -1,24 +1,33 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_OPTIONS_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_OPTIONS_H_ +#include "absl/base/attributes.h" +#include "runtime/runtime_options.h" #include "google/protobuf/arena.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { - -// Options for unknown processing. -enum class UnknownProcessingOptions { - // No unknown processing. - kDisabled, - // Only attributes supported. - kAttributeOnly, - // Attributes and functions supported. Function results are dependent on the - // logic for handling unknown_attributes, so clients must opt in to both. - kAttributeAndFunction -}; +namespace google::api::expr::runtime { + +using UnknownProcessingOptions = cel::UnknownProcessingOptions; +using ProtoWrapperTypeOptions = cel::ProtoWrapperTypeOptions; + +// LINT.IfChange // Interpreter options for controlling evaluation and builtin functions. struct InterpreterOptions { // Level of unknown support enabled. @@ -27,19 +36,28 @@ struct InterpreterOptions { bool enable_missing_attribute_errors = false; + // Enable timestamp duration overflow checks. + // + // The CEL-Spec indicates that overflow should occur outside the range of + // string-representable timestamps, and at the limit of durations which can be + // expressed with a single int64 value. + bool enable_timestamp_duration_overflow_errors = false; + // Enable short-circuiting of the logical operator evaluation. If enabled, // AND, OR, and TERNARY do not evaluate the entire expression once the the // resulting value is known from the left-hand side. bool short_circuiting = true; - // DEPRECATED. This option has no effect. - bool partial_string_match = true; - - // Enable constant folding during the expression creation. If enabled, - // an arena must be provided for constant generation. - // Note that expression tracing applies a modified expression if this option - // is enabled. + // Enable constant folding during the expression creation. + // + // Note that expression tracing will apply to a modified expression if this + // option is enabled. bool constant_folding = false; + + // Optionally specified arena for constant folding. If not specified, the + // builder will create one as needed per expression built. Any arena created + // by the builder will be destroyed when the corresponding expression is + // destroyed. google::protobuf::Arena* constant_arena = nullptr; // Enable comprehension expressions (e.g. exists, all) @@ -48,7 +66,15 @@ struct InterpreterOptions { // Set maximum number of iterations in the comprehension expressions if // comprehensions are enabled. The limit applies globally per an evaluation, // including the nested loops as well. Use value 0 to disable the upper bound. - int comprehension_max_iterations = 0; + int comprehension_max_iterations = 10000; + + // Enable list append within comprehensions. Note, this option is not safe + // with hand-rolled ASTs. + bool enable_comprehension_list_append = false; + + // Enable mutable map construction within comprehensions. Note, this option is + // not safe with hand-rolled ASTs. + bool enable_comprehension_mutable_map = false; // Enable RE2 match() overload. bool enable_regex = true; @@ -72,11 +98,131 @@ struct InterpreterOptions { // Treat builder warnings as fatal errors. bool fail_on_warnings = true; + + // Enable the resolution of qualified type identifiers as type values instead + // of field selections. + // + // This toggle may cause certain identifiers which overlap with CEL built-in + // type or with protobuf message types linked into the binary to be resolved + // as static type values rather than as per-eval variables. + bool enable_qualified_type_identifiers = false; + + // Enable a check for memory vulnerabilities within comprehension + // sub-expressions. + // + // Note: This flag is not necessary if you are only using Core CEL macros. + // + // Consider enabling this feature when using custom comprehensions, and + // absolutely enable the feature when using hand-written ASTs for + // comprehension expressions. + bool enable_comprehension_vulnerability_check = false; + + // Enable heterogeneous comparisons (e.g. support for cross-type comparisons). + ABSL_DEPRECATED( + "The ability to disable heterogeneous equality is being removed in the " + "near future") + bool enable_heterogeneous_equality = true; + + // Enables unwrapping proto wrapper types to null if unset. e.g. if an + // expression access a field of type google.protobuf.Int64Value that is unset, + // that will result in a Null cel value, as opposed to returning the + // cel representation of the proto defined default int64: 0. + bool enable_empty_wrapper_null_unboxing = false; + + // Enables expression rewrites to disambiguate namespace qualified identifiers + // from container access for variables and receiver-style calls for functions. + // + // Note: This makes an implicit copy of the input expression for lifetime + // safety. + bool enable_qualified_identifier_rewrites = false; + + // Historically regular expressions were compiled on each invocation to + // `matches` and not re-used, even if the regular expression is a constant. + // Enabling this option causes constant regular expressions to be compiled + // ahead-of-time and re-used for each invocation to `matches`. A side effect + // of this is that invalid regular expressions will result in errors when + // building an expression. + // + // It is recommended that this option be enabled in conjunction with + // enable_constant_folding. + // + // Note: In most cases enabling this option is safe, however to perform this + // optimization overloads are not consulted for applicable calls. If you have + // overridden the default `matches` function you should not enable this + // option. + bool enable_regex_precompilation = false; + + // Enable select optimization, replacing long select chains with a single + // operation. + // + // This assumes that the type information at check time agrees with the + // configured types at runtime. + // + // Important: The select optimization follows spec behavior for traversals. + // - `enable_empty_wrapper_null_unboxing` is ignored and optimized traversals + // always operates as though it is `true`. + // - `enable_heterogeneous_equality` is ignored and optimized traversals + // always operate as though it is `true`. + bool enable_select_optimization = false; + + // Enable lazy cel.bind alias initialization. + // + // This is now always enabled. Setting this option has no effect. It will be + // removed in a later update. + bool enable_lazy_bind_initialization = true; + + // Enable recursive planning with a maximum recursion depth for evaluable + // programs. + // + // This limit is proportional to the maximum number of recursive Evaluate + // calls that a single expression program might require while evaluating. This + // is coarse -- the actual C++ stack requirements will vary depending on the + // expression. + // + // This does not account for re-entrant evaluation in a client's extension + // function (i.e. a CEL function that calls Evaluate on another CEL program) + // + // If the limit is exceeded, the planner will return an error instead of + // planning the program. + // + // -1 means unbounded. + // 0 means disabled (using a heap-based stack machine instead), which is the + // default. + int max_recursion_depth = 0; + + // Enable tracing support for recursively planned programs. + // + // Unlike the stack machine implementation, supporting tracing can affect + // performance whether or not tracing is requested for a given evaluation. + bool enable_recursive_tracing = false; + + // Enable fast implementations for some CEL standard functions. + // + // Uses a custom implementation for some functions in the CEL standard, + // bypassing normal dispatching logic and safety checks for functions. + // + // This prevents extending or disabling these functions in most cases. The + // expression planner will make a best effort attempt to check if custom + // overloads have been added for these functions, and will attempt to use them + // if they exist. + // + // Currently applies to !_, @not_strictly_false, _==_, _!=_, @in + bool enable_fast_builtins = true; + + // When enabled, string(double) will format the double with enough precision + // to ensure that the original double value can be recovered exactly. + // + // If available, will use the `std::to_chars` standard library function to + // perform the conversion to generate the shortest representation. + // + // Otherwise, will fall back to formatting with the worst-case required + // precision. + bool enable_precision_preserving_double_format = true; }; +// LINT.ThenChange(//depot/google3/runtime/runtime_options.h) + +cel::RuntimeOptions ConvertToRuntimeOptions(const InterpreterOptions& options); -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_OPTIONS_H_ diff --git a/eval/public/cel_type_registry.cc b/eval/public/cel_type_registry.cc new file mode 100644 index 000000000..639a348dd --- /dev/null +++ b/eval/public/cel_type_registry.cc @@ -0,0 +1,64 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/cel_type_registry.h" + +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "eval/public/structs/legacy_type_adapter.h" +#include "google/protobuf/descriptor.h" + +namespace google::api::expr::runtime { + +namespace { + +void AddEnumFromDescriptor(const google::protobuf::EnumDescriptor* desc, + CelTypeRegistry& registry) { + std::vector enumerators; + enumerators.reserve(desc->value_count()); + for (int i = 0; i < desc->value_count(); i++) { + enumerators.push_back( + {std::string(desc->value(i)->name()), desc->value(i)->number()}); + } + registry.RegisterEnum(desc->full_name(), std::move(enumerators)); +} + +} // namespace + +void CelTypeRegistry::Register(const google::protobuf::EnumDescriptor* enum_descriptor) { + AddEnumFromDescriptor(enum_descriptor, *this); +} + +void CelTypeRegistry::RegisterEnum(absl::string_view enum_name, + std::vector enumerators) { + modern_type_registry_.RegisterEnum(enum_name, std::move(enumerators)); +} + +// Find a type's CelValue instance by its fully qualified name. +absl::optional CelTypeRegistry::FindTypeAdapter( + absl::string_view fully_qualified_type_name) const { + auto maybe_adapter = + GetFirstTypeProvider()->ProvideLegacyType(fully_qualified_type_name); + if (maybe_adapter.has_value()) { + return maybe_adapter; + } + return absl::nullopt; +} + +} // namespace google::api::expr::runtime diff --git a/eval/public/cel_type_registry.h b/eval/public/cel_type_registry.h new file mode 100644 index 000000000..3fb80bcea --- /dev/null +++ b/eval/public/cel_type_registry.h @@ -0,0 +1,145 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_TYPE_REGISTRY_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_TYPE_REGISTRY_H_ + +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "base/type_provider.h" +#include "eval/public/structs/legacy_type_adapter.h" +#include "eval/public/structs/legacy_type_provider.h" +#include "runtime/type_registry.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace google::api::expr::runtime { + +// CelTypeRegistry manages the set of registered types available for use within +// object literal construction, enum comparisons, and type testing. +// +// The CelTypeRegistry is intended to live for the duration of all CelExpression +// values created by a given CelExpressionBuilder and one is created by default +// within the standard CelExpressionBuilder. +// +// By default, all core CEL types and all linked protobuf message types are +// implicitly registered by way of the generated descriptor pool. A descriptor +// pool can be given to avoid accidentally exposing linked protobuf types to CEL +// which were intended to remain internal or to operate on hermetic descriptor +// pools. +class CelTypeRegistry { + public: + // Representation of an enum constant. + using Enumerator = cel::TypeRegistry::Enumerator; + + // Representation of an enum. + using Enumeration = cel::TypeRegistry::Enumeration; + + CelTypeRegistry() + : CelTypeRegistry(google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory()) {} + + CelTypeRegistry(const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nullable message_factory) + : modern_type_registry_(descriptor_pool, message_factory) {} + + ~CelTypeRegistry() = default; + + // Register an enum whose values may be used within CEL expressions. + // + // Enum registration must be performed prior to CelExpression creation. + void Register(const google::protobuf::EnumDescriptor* enum_descriptor); + + // Register an enum whose values may be used within CEL expressions. + // + // Enum registration must be performed prior to CelExpression creation. + void RegisterEnum(absl::string_view name, + std::vector enumerators); + + // Get the first registered type provider. + std::shared_ptr GetFirstTypeProvider() const { + return cel::runtime_internal::GetLegacyRuntimeTypeProvider( + modern_type_registry_); + } + + // Returns the effective type provider that has been configured with the + // registry. + // + // This is a composited type provider that should check in order: + // - builtins + // - custom enumerations + // - registered extension type providers in the order registered. + const cel::TypeProvider& GetTypeProvider() const { + return modern_type_registry_.GetComposedTypeProvider(); + } + + // Find a type adapter given a fully qualified type name. + // Adapter provides a generic interface for the reflection operations the + // interpreter needs to provide. + absl::optional FindTypeAdapter( + absl::string_view fully_qualified_type_name) const; + + // Return the registered enums configured within the type registry in the + // internal format that can be identified as int constants at plan time. + const absl::flat_hash_map& resolveable_enums() + const { + return modern_type_registry_.resolveable_enums(); + } + + // Return the registered enums configured within the type registry. + // + // This is provided for validating registry setup, it should not be used + // internally. + // + // Invalidated whenever registered enums are updated. + absl::flat_hash_set ListResolveableEnums() const { + const auto& enums = resolveable_enums(); + absl::flat_hash_set result; + result.reserve(enums.size()); + + for (const auto& entry : enums) { + result.insert(entry.first); + } + + return result; + } + + // Accessor for underlying modern registry. + // + // This is exposed for migrating runtime internals, CEL users should not call + // this. + cel::TypeRegistry& InternalGetModernRegistry() { + return modern_type_registry_; + } + + const cel::TypeRegistry& InternalGetModernRegistry() const { + return modern_type_registry_; + } + + private: + // Internal modern registry. + cel::TypeRegistry modern_type_registry_; +}; + +} // namespace google::api::expr::runtime + +#endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_TYPE_REGISTRY_H_ diff --git a/eval/public/cel_type_registry_protobuf_reflection_test.cc b/eval/public/cel_type_registry_protobuf_reflection_test.cc new file mode 100644 index 000000000..85d05f95a --- /dev/null +++ b/eval/public/cel_type_registry_protobuf_reflection_test.cc @@ -0,0 +1,109 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "google/protobuf/struct.pb.h" +#include "absl/types/optional.h" +#include "common/memory.h" +#include "common/type.h" +#include "eval/public/cel_type_registry.h" +#include "eval/testutil/test_message.pb.h" +#include "internal/testing.h" +#include "google/protobuf/message.h" + +namespace google::api::expr::runtime { +namespace { + +using ::cel::MemoryManagerRef; +using ::cel::StructType; +using ::cel::Type; +using ::google::protobuf::Struct; +using ::testing::AllOf; +using ::testing::Contains; +using ::testing::Eq; +using ::testing::Optional; +using ::testing::Pair; +using ::testing::UnorderedElementsAre; + +MATCHER_P(TypeNameIs, name, "") { + const Type& type = arg; + *result_listener << "got typename: " << type.name(); + return type.name() == name; +} + +MATCHER_P(MatchesEnumDescriptor, desc, "") { + const auto& enum_type = arg; + + if (enum_type.enumerators.size() != desc->value_count()) { + return false; + } + + for (int i = 0; i < desc->value_count(); i++) { + const auto& constant = enum_type.enumerators[i]; + + const auto* value_desc = desc->value(i); + + if (value_desc->name() != constant.name) { + return false; + } + if (value_desc->number() != constant.number) { + return false; + } + } + return true; +} + +TEST(CelTypeRegistryTest, RegisterEnumDescriptor) { + CelTypeRegistry registry; + registry.Register(google::protobuf::GetEnumDescriptor()); + + EXPECT_THAT( + registry.ListResolveableEnums(), + UnorderedElementsAre("google.protobuf.NullValue", + "google.api.expr.runtime.TestMessage.TestEnum")); + + EXPECT_THAT( + registry.resolveable_enums(), + AllOf(Contains(Pair( + "google.protobuf.NullValue", + MatchesEnumDescriptor( + google::protobuf::GetEnumDescriptor()))), + Contains(Pair( + "google.api.expr.runtime.TestMessage.TestEnum", + MatchesEnumDescriptor( + google::protobuf::GetEnumDescriptor()))))); +} + +TEST(CelTypeRegistryTypeProviderTest, StructTypes) { + CelTypeRegistry registry; + google::protobuf::LinkMessageReflection(); + google::protobuf::LinkMessageReflection(); + + ASSERT_OK_AND_ASSIGN(absl::optional struct_message_type, + registry.GetTypeProvider().FindType( + "google.api.expr.runtime.TestMessage")); + ASSERT_TRUE(struct_message_type.has_value()); + ASSERT_TRUE((*struct_message_type).Is()) + << (*struct_message_type).DebugString(); + EXPECT_THAT(struct_message_type->As()->name(), + Eq("google.api.expr.runtime.TestMessage")); + + // Can't override builtins. + ASSERT_OK_AND_ASSIGN( + absl::optional struct_type, + registry.GetTypeProvider().FindType("google.protobuf.Struct")); + EXPECT_THAT(struct_type, Optional(TypeNameIs("map"))); +} + +} // namespace +} // namespace google::api::expr::runtime diff --git a/eval/public/cel_type_registry_test.cc b/eval/public/cel_type_registry_test.cc new file mode 100644 index 000000000..9f3fde9be --- /dev/null +++ b/eval/public/cel_type_registry_test.cc @@ -0,0 +1,137 @@ +#include "eval/public/cel_type_registry.h" + +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "base/type_provider.h" +#include "common/memory.h" +#include "common/type.h" +#include "eval/public/structs/legacy_type_adapter.h" +#include "eval/public/structs/legacy_type_provider.h" +#include "internal/testing.h" + +namespace google::api::expr::runtime { + +namespace { + +using ::cel::MemoryManagerRef; +using ::cel::Type; +using ::cel::TypeProvider; +using ::testing::Contains; +using ::testing::Key; +using ::testing::Optional; + +class TestTypeProvider : public LegacyTypeProvider { + public: + explicit TestTypeProvider(std::vector types) + : types_(std::move(types)) {} + + // Return a type adapter for an opaque type + // (no reflection operations supported). + absl::optional ProvideLegacyType( + absl::string_view name) const override { + for (const auto& type : types_) { + if (name == type) { + return LegacyTypeAdapter(/*access=*/nullptr, /*mutation=*/nullptr); + } + } + return absl::nullopt; + } + + private: + std::vector types_; +}; + +TEST(CelTypeRegistryTest, RegisterEnum) { + CelTypeRegistry registry; + registry.RegisterEnum("google.api.expr.runtime.TestMessage.TestEnum", + { + {"TEST_ENUM_UNSPECIFIED", 0}, + {"TEST_ENUM_1", 10}, + {"TEST_ENUM_2", 20}, + {"TEST_ENUM_3", 30}, + }); + + EXPECT_THAT(registry.resolveable_enums(), + Contains(Key("google.api.expr.runtime.TestMessage.TestEnum"))); +} + +TEST(CelTypeRegistryTest, TestRegisterBuiltInEnum) { + CelTypeRegistry registry; + + ASSERT_THAT(registry.resolveable_enums(), + Contains(Key("google.protobuf.NullValue"))); +} + +TEST(CelTypeRegistryTest, TestGetFirstTypeProviderSuccess) { + CelTypeRegistry registry; + auto type_provider = registry.GetFirstTypeProvider(); + ASSERT_NE(type_provider, nullptr); + ASSERT_FALSE( + type_provider->ProvideLegacyType("google.protobuf.Int64").has_value()); + ASSERT_TRUE( + type_provider->ProvideLegacyType("google.protobuf.Any").has_value()); +} + +TEST(CelTypeRegistryTest, TestFindTypeAdapterFound) { + CelTypeRegistry registry; + auto desc = registry.FindTypeAdapter("google.protobuf.Any"); + ASSERT_TRUE(desc.has_value()); +} + +TEST(CelTypeRegistryTest, TestFindTypeAdapterFoundMultipleProviders) { + CelTypeRegistry registry; + auto desc = registry.FindTypeAdapter("google.protobuf.Any"); + ASSERT_TRUE(desc.has_value()); +} + +TEST(CelTypeRegistryTest, TestFindTypeAdapterNotFound) { + CelTypeRegistry registry; + auto desc = registry.FindTypeAdapter("missing.MessageType"); + EXPECT_FALSE(desc.has_value()); +} + +MATCHER_P(TypeNameIs, name, "") { + const Type& type = arg; + *result_listener << "got typename: " << type.name(); + return type.name() == name; +} + +TEST(CelTypeRegistryTypeProviderTest, Builtins) { + CelTypeRegistry registry; + + // simple + ASSERT_OK_AND_ASSIGN(absl::optional bool_type, + registry.GetTypeProvider().FindType("bool")); + EXPECT_THAT(bool_type, Optional(TypeNameIs("bool"))); + // opaque + ASSERT_OK_AND_ASSIGN( + absl::optional timestamp_type, + registry.GetTypeProvider().FindType("google.protobuf.Timestamp")); + EXPECT_THAT(timestamp_type, + Optional(TypeNameIs("google.protobuf.Timestamp"))); + // wrapper + ASSERT_OK_AND_ASSIGN( + absl::optional int_wrapper_type, + registry.GetTypeProvider().FindType("google.protobuf.Int64Value")); + EXPECT_THAT(int_wrapper_type, + Optional(TypeNameIs("google.protobuf.Int64Value"))); + // json + ASSERT_OK_AND_ASSIGN( + absl::optional json_struct_type, + registry.GetTypeProvider().FindType("google.protobuf.Struct")); + EXPECT_THAT(json_struct_type, Optional(TypeNameIs("map"))); + // special + ASSERT_OK_AND_ASSIGN( + absl::optional any_type, + registry.GetTypeProvider().FindType("google.protobuf.Any")); + EXPECT_THAT(any_type, Optional(TypeNameIs("google.protobuf.Any"))); +} + +} // namespace + +} // namespace google::api::expr::runtime diff --git a/eval/public/cel_value.cc b/eval/public/cel_value.cc index 414ace71e..25da7fe75 100644 --- a/eval/public/cel_value.cc +++ b/eval/public/cel_value.cc @@ -1,31 +1,31 @@ #include "eval/public/cel_value.h" +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/no_destructor.h" #include "absl/status/status.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "common/memory.h" +#include "eval/internal/errors.h" +#include "eval/public/structs/legacy_type_info_apis.h" +#include "extensions/protobuf/memory_manager.h" +#include "google/protobuf/arena.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace google::api::expr::runtime { namespace { -using google::protobuf::Arena; - -constexpr char kErrNoMatchingOverload[] = "No matching overloads found"; -constexpr char kErrNoSuchKey[] = "Key not found in map"; -constexpr absl::string_view kErrUnknownValue = "Unknown value "; -// Error name for MissingAttributeError indicating that evaluation has -// accessed an attribute whose value is undefined. go/terminal-unknown -constexpr absl::string_view kErrMissingAttribute = "MissingAttributeError: "; -constexpr absl::string_view kPayloadUrlUnknownPath = "unknown_path"; -constexpr absl::string_view kPayloadUrlMissingAttributePath = - "missing_attribute_path"; -constexpr absl::string_view kPayloadUrlUnknownFunctionResult = - "cel_is_unknown_function_result"; +using ::google::protobuf::Arena; +namespace interop = ::cel::interop_internal; constexpr absl::string_view kNullTypeName = "null_type"; constexpr absl::string_view kBoolTypeName = "bool"; @@ -41,10 +41,96 @@ constexpr absl::string_view kListTypeName = "list"; constexpr absl::string_view kMapTypeName = "map"; constexpr absl::string_view kCelTypeTypeName = "type"; +struct DebugStringVisitor { + google::protobuf::Arena* const arena; + + std::string operator()(bool arg) { return absl::StrFormat("%d", arg); } + std::string operator()(int64_t arg) { return absl::StrFormat("%lld", arg); } + std::string operator()(uint64_t arg) { return absl::StrFormat("%llu", arg); } + std::string operator()(double arg) { return absl::StrFormat("%f", arg); } + std::string operator()(CelValue::NullType) { return "null"; } + + std::string operator()(CelValue::StringHolder arg) { + return absl::StrFormat("%s", arg.value()); + } + + std::string operator()(CelValue::BytesHolder arg) { + return absl::StrFormat("%s", arg.value()); + } + + std::string operator()(const MessageWrapper& arg) { + return arg.message_ptr() == nullptr + ? "NULL" + : arg.legacy_type_info()->DebugString(arg); + } + + std::string operator()(absl::Duration arg) { + return absl::FormatDuration(arg); + } + + std::string operator()(absl::Time arg) { + return absl::FormatTime(arg, absl::UTCTimeZone()); + } + + std::string operator()(const CelList* arg) { + std::vector elements; + elements.reserve(arg->size()); + for (int i = 0; i < arg->size(); i++) { + elements.push_back(arg->Get(arena, i).DebugString()); + } + return absl::StrCat("[", absl::StrJoin(elements, ", "), "]"); + } + + std::string operator()(const CelMap* arg) { + auto keys_or_error = arg->ListKeys(arena); + if (!keys_or_error.status().ok()) { + return "invalid list keys"; + } + const CelList* keys = std::move(keys_or_error.value()); + std::vector elements; + elements.reserve(keys->size()); + for (int i = 0; i < keys->size(); i++) { + const auto& key = (*keys).Get(arena, i); + const auto& optional_value = arg->Get(arena, key); + elements.push_back(absl::StrCat("<", key.DebugString(), ">: <", + optional_value.has_value() + ? optional_value->DebugString() + : "nullopt", + ">")); + } + return absl::StrCat("{", absl::StrJoin(elements, ", "), "}"); + } + + std::string operator()(const UnknownSet* arg) { + return "?"; // Not implemented. + } + + std::string operator()(CelValue::CelTypeHolder arg) { + return absl::StrCat(arg.value()); + } + + std::string operator()(const CelError* arg) { return arg->ToString(); } +}; + } // namespace +ABSL_CONST_INIT const absl::string_view kPayloadUrlMissingAttributePath = + cel::runtime_internal::kPayloadUrlMissingAttributePath; + +CelValue CelValue::CreateDuration(absl::Duration value) { + if (value >= cel::runtime_internal::kDurationHigh || + value <= cel::runtime_internal::kDurationLow) { + return CelValue(cel::runtime_internal::DurationOverflowError()); + } + return CreateUncheckedDuration(value); +} + +// TODO(issues/136): These don't match the CEL runtime typenames. They should +// be updated where possible for consistency. std::string CelValue::TypeName(Type value_type) { switch (value_type) { + case Type::kNullType: + return "null_type"; case Type::kBool: return "bool"; case Type::kInt64: @@ -75,11 +161,28 @@ std::string CelValue::TypeName(Type value_type) { return "CelError"; case Type::kAny: return "Any type"; + default: + return "unknown"; + } +} + +absl::Status CelValue::CheckMapKeyType(const CelValue& key) { + switch (key.type()) { + case CelValue::Type::kString: + case CelValue::Type::kInt64: + case CelValue::Type::kUint64: + case CelValue::Type::kBool: + return absl::OkStatus(); + default: + return absl::InvalidArgumentError(absl::StrCat( + "Invalid map key type: '", CelValue::TypeName(key.type()), "'")); } } CelValue CelValue::ObtainCelType() const { switch (type()) { + case Type::kNullType: + return CreateCelType(CelTypeHolder(kNullTypeName)); case Type::kBool: return CreateCelType(CelTypeHolder(kBoolTypeName)); case Type::kInt64: @@ -93,13 +196,15 @@ CelValue CelValue::ObtainCelType() const { case Type::kBytes: return CreateCelType(CelTypeHolder(kBytesTypeName)); case Type::kMessage: { - auto msg = MessageOrDie(); - if (msg == nullptr) { + MessageWrapper wrapper; + CelValue::GetValue(&wrapper); + if (wrapper.message_ptr() == nullptr) { return CreateCelType(CelTypeHolder(kNullTypeName)); } // Descritptor::full_name() returns const reference, so using pointer // should be safe. - return CreateCelType(CelTypeHolder(msg->GetDescriptor()->full_name())); + return CreateCelType( + CelTypeHolder(wrapper.legacy_type_info()->GetTypename(wrapper))); } case Type::kDuration: return CreateCelType(CelTypeHolder(kDurationTypeName)); @@ -115,7 +220,7 @@ CelValue CelValue::ObtainCelType() const { return *this; case Type::kError: return *this; - case Type::kAny: { + default: { static const CelError* invalid_type_error = new CelError(absl::InvalidArgumentError("Unsupported CelValue type")); return CreateError(invalid_type_error); @@ -125,158 +230,183 @@ CelValue CelValue::ObtainCelType() const { // Returns debug string describing a value const std::string CelValue::DebugString() const { - switch (type()) { - case Type::kBool: - return absl::StrFormat("bool: %d", BoolOrDie()); - case Type::kInt64: - return absl::StrFormat("int64: %lld", Int64OrDie()); - case Type::kUint64: - return absl::StrFormat("uint64: %llu", Uint64OrDie()); - case Type::kDouble: - return absl::StrFormat("double: %f", DoubleOrDie()); - case Type::kString: - return absl::StrFormat("string: %s", StringOrDie().value()); - case Type::kBytes: - return absl::StrFormat("bytes: %s", BytesOrDie().value()); - case Type::kMessage: - return absl::StrFormat( - "Message: %s", - IsNull() ? "NULL" : MessageOrDie()->ShortDebugString()); - case Type::kDuration: - return absl::StrFormat("Duration: %s", - absl::FormatDuration(DurationOrDie())); - case Type::kTimestamp: - return absl::StrFormat( - "Time: %s", absl::FormatTime(TimestampOrDie(), absl::UTCTimeZone())); - case Type::kList: - return absl::StrFormat("List, size: %lld", ListOrDie()->size()); - case Type::kMap: - return absl::StrFormat("Map, size: %lld", MapOrDie()->size()); - case Type::kUnknownSet: - return "UnknownSet"; - case Type::kCelType: - return absl::StrFormat("CelType, %s", CelTypeOrDie().value()); - break; - case Type::kError: - return absl::StrFormat("Error: %s", ErrorOrDie()->ToString()); - case Type::kAny: - return "Any"; - default: - return "unknown_type"; + google::protobuf::Arena arena; + return absl::StrCat(CelValue::TypeName(type()), ": ", + InternalVisit(DebugStringVisitor{&arena})); +} + +namespace { + +class EmptyCelList final : public CelList { + public: + static const EmptyCelList* Get() { + static const absl::NoDestructor instance; + return &*instance; + } + + CelValue operator[](int index) const override { + static const CelError* invalid_argument = + new CelError(absl::InvalidArgumentError("index out of bounds")); + return CelValue::CreateError(invalid_argument); + } + + int size() const override { return 0; } + + bool empty() const override { return true; } +}; + +class EmptyCelMap final : public CelMap { + public: + static const EmptyCelMap* Get() { + static const absl::NoDestructor instance; + return &*instance; + } + + absl::optional operator[](CelValue key) const override { + return absl::nullopt; } + + absl::StatusOr Has(const CelValue& key) const override { + CEL_RETURN_IF_ERROR(CelValue::CheckMapKeyType(key)); + return false; + } + + int size() const override { return 0; } + + bool empty() const override { return true; } + + absl::StatusOr ListKeys() const override { + return EmptyCelList::Get(); + } +}; + +} // namespace + +CelValue CelValue::CreateList() { return CreateList(EmptyCelList::Get()); } + +CelValue CelValue::CreateMap() { return CreateMap(EmptyCelMap::Get()); } + +CelValue CreateErrorValue(cel::MemoryManagerRef manager, + absl::string_view message, + absl::StatusCode error_code) { + // TODO(uncreated-issue/1): assume arena-style allocator while migrating to new + // value type. + Arena* arena = cel::extensions::ProtoMemoryManagerArena(manager); + return CreateErrorValue(arena, message, error_code); +} + +CelValue CreateErrorValue(cel::MemoryManagerRef manager, + const absl::Status& status) { + // TODO(uncreated-issue/1): assume arena-style allocator while migrating to new + // value type. + Arena* arena = cel::extensions::ProtoMemoryManagerArena(manager); + return CreateErrorValue(arena, status); } CelValue CreateErrorValue(Arena* arena, absl::string_view message, - absl::StatusCode error_code, int) { + absl::StatusCode error_code) { CelError* error = Arena::Create(arena, error_code, message); return CelValue::CreateError(error); } -CelValue CreateNoMatchingOverloadError(google::protobuf::Arena* arena) { - return CreateErrorValue(arena, kErrNoMatchingOverload, - absl::StatusCode::kUnknown); +CelValue CreateErrorValue(Arena* arena, const absl::Status& status) { + CelError* error = Arena::Create(arena, status); + return CelValue::CreateError(error); +} + +CelValue CreateNoMatchingOverloadError(cel::MemoryManagerRef manager, + absl::string_view fn) { + return CelValue::CreateError(interop::CreateNoMatchingOverloadError( + cel::extensions::ProtoMemoryManagerArena(manager), fn)); } CelValue CreateNoMatchingOverloadError(google::protobuf::Arena* arena, absl::string_view fn) { - return CreateErrorValue(arena, absl::StrCat(kErrNoMatchingOverload, " ", fn), - absl::StatusCode::kUnknown); + return CelValue::CreateError( + interop::CreateNoMatchingOverloadError(arena, fn)); } bool CheckNoMatchingOverloadError(CelValue value) { return value.IsError() && value.ErrorOrDie()->code() == absl::StatusCode::kUnknown && absl::StrContains(value.ErrorOrDie()->message(), - kErrNoMatchingOverload); + cel::runtime_internal::kErrNoMatchingOverload); } -CelValue CreateNoSuchFieldError(google::protobuf::Arena* arena) { - return CreateErrorValue(arena, "no_such_field", absl::StatusCode::kNotFound); +CelValue CreateNoSuchFieldError(cel::MemoryManagerRef manager, + absl::string_view field) { + return CelValue::CreateError(interop::CreateNoSuchFieldError( + cel::extensions::ProtoMemoryManagerArena(manager), field)); } -CelValue CreateNoSuchKeyError(google::protobuf::Arena* arena, absl::string_view key) { - return CreateErrorValue(arena, absl::StrCat(kErrNoSuchKey, " : ", key), - absl::StatusCode::kNotFound); +CelValue CreateNoSuchFieldError(google::protobuf::Arena* arena, absl::string_view field) { + return CelValue::CreateError(interop::CreateNoSuchFieldError(arena, field)); } -bool CheckNoSuchKeyError(CelValue value) { - return value.IsError() && - absl::StartsWith(value.ErrorOrDie()->message(), kErrNoSuchKey); +CelValue CreateNoSuchKeyError(cel::MemoryManagerRef manager, + absl::string_view key) { + return CelValue::CreateError(interop::CreateNoSuchKeyError( + cel::extensions::ProtoMemoryManagerArena(manager), key)); } -CelValue CreateUnknownValueError(google::protobuf::Arena* arena, - absl::string_view unknown_path) { - CelError* error = - Arena::Create(arena, absl::StatusCode::kUnavailable, - absl::StrCat(kErrUnknownValue, unknown_path)); - error->SetPayload(kPayloadUrlUnknownPath, absl::Cord(unknown_path)); - return CelValue::CreateError(error); +CelValue CreateNoSuchKeyError(google::protobuf::Arena* arena, absl::string_view key) { + return CelValue::CreateError(interop::CreateNoSuchKeyError(arena, key)); } -bool IsUnknownValueError(const CelValue& value) { - // TODO(issues/41): replace with the implementation of go/cel-known-unknowns - if (!value.IsError()) return false; - const CelError* error = value.ErrorOrDie(); - if (error && error->code() == absl::StatusCode::kUnavailable) { - auto path = error->GetPayload(kPayloadUrlUnknownPath); - return path.has_value(); - } - return false; +bool CheckNoSuchKeyError(CelValue value) { + return value.IsError() && + absl::StartsWith(value.ErrorOrDie()->message(), + cel::runtime_internal::kErrNoSuchKey); } CelValue CreateMissingAttributeError(google::protobuf::Arena* arena, absl::string_view missing_attribute_path) { - CelError* error = Arena::Create( - arena, absl::StatusCode::kInvalidArgument, - absl::StrCat(kErrMissingAttribute, missing_attribute_path)); - error->SetPayload(kPayloadUrlMissingAttributePath, - absl::Cord(missing_attribute_path)); - return CelValue::CreateError(error); + return CelValue::CreateError( + interop::CreateMissingAttributeError(arena, missing_attribute_path)); +} + +CelValue CreateMissingAttributeError(cel::MemoryManagerRef manager, + absl::string_view missing_attribute_path) { + // TODO(uncreated-issue/1): assume arena-style allocator while migrating + // to new value type. + return CelValue::CreateError(interop::CreateMissingAttributeError( + cel::extensions::ProtoMemoryManagerArena(manager), + missing_attribute_path)); } bool IsMissingAttributeError(const CelValue& value) { - if (!value.IsError()) return false; - const CelError* error = value.ErrorOrDie(); // Crash ok + const CelError* error; + if (!value.GetValue(&error)) return false; if (error && error->code() == absl::StatusCode::kInvalidArgument) { - auto path = error->GetPayload(kPayloadUrlMissingAttributePath); + auto path = error->GetPayload( + cel::runtime_internal::kPayloadUrlMissingAttributePath); return path.has_value(); } return false; } -std::set GetUnknownPathsSetOrDie(const CelValue& value) { - // TODO(issues/41): replace with the implementation of go/cel-known-unknowns - const CelError* error = value.ErrorOrDie(); - if (error && error->code() == absl::StatusCode::kUnavailable) { - auto path = error->GetPayload(kPayloadUrlUnknownPath); - if (path.has_value()) return {std::string(path.value())}; - } - GOOGLE_LOG(FATAL) << "The value is not an unknown path error."; // Crash ok - return {}; +CelValue CreateUnknownFunctionResultError(cel::MemoryManagerRef manager, + absl::string_view help_message) { + return CelValue::CreateError(interop::CreateUnknownFunctionResultError( + cel::extensions::ProtoMemoryManagerArena(manager), help_message)); } CelValue CreateUnknownFunctionResultError(google::protobuf::Arena* arena, absl::string_view help_message) { - CelError* error = Arena::Create( - arena, absl::StatusCode::kUnavailable, - absl::StrCat("Unknown function result: ", help_message)); - error->SetPayload(kPayloadUrlUnknownFunctionResult, absl::Cord("true")); - return CelValue::CreateError(error); + return CelValue::CreateError( + interop::CreateUnknownFunctionResultError(arena, help_message)); } bool IsUnknownFunctionResult(const CelValue& value) { - if (!value.IsError()) { - return false; - } - const CelError* error = value.ErrorOrDie(); + const CelError* error; + if (!value.GetValue(&error)) return false; + if (error == nullptr || error->code() != absl::StatusCode::kUnavailable) { return false; } - auto payload = error->GetPayload(kPayloadUrlUnknownFunctionResult); + auto payload = error->GetPayload( + cel::runtime_internal::kPayloadUrlUnknownFunctionResult); return payload.has_value() && payload.value() == "true"; } -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime diff --git a/eval/public/cel_value.h b/eval/public/cel_value.h index 0c7057e7d..76b4d09bb 100644 --- a/eval/public/cel_value.h +++ b/eval/public/cel_value.h @@ -16,28 +16,46 @@ // string* msg = google::protobuf::Arena::Create(arena,"test"); // CelValue value = CelValue::CreateString(msg); // (c) For messages: -// const MyMessage * msg = google::protobuf::Arena::CreateMessage(arena); +// const MyMessage * msg = google::protobuf::Arena::Create(arena); // CelValue value = CelProtoWrapper::CreateMessage(msg, &arena); -#include "google/protobuf/message.h" +#include + +#include "absl/base/attributes.h" +#include "absl/base/macros.h" +#include "absl/base/optimization.h" +#include "absl/log/absl_log.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/time/time.h" #include "absl/types/optional.h" +#include "absl/types/variant.h" +#include "common/kind.h" +#include "common/memory.h" +#include "common/native_type.h" #include "eval/public/cel_value_internal.h" -#include "absl/status/statusor.h" +#include "eval/public/message_wrapper.h" +#include "eval/public/unknown_set.h" +#include "internal/casts.h" +#include "internal/status_macros.h" +#include "internal/utf8.h" +#include "google/protobuf/message.h" + +namespace cel::interop_internal { +struct CelListAccess; +struct CelMapAccess; +} // namespace cel::interop_internal -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace google::api::expr::runtime { using CelError = absl::Status; -// Break cyclic depdendencies for container types. +// Break cyclic dependencies for container types. class CelList; class CelMap; -class UnknownSet; +class LegacyTypeAdapter; class CelValue { public: @@ -49,14 +67,14 @@ class CelValue { public: StringHolderBase() : value_(absl::string_view()) {} - StringHolderBase(const StringHolderBase &) = default; - StringHolderBase &operator=(const StringHolderBase &) = default; + StringHolderBase(const StringHolderBase&) = default; + StringHolderBase& operator=(const StringHolderBase&) = default; // string parameter is passed through pointer to ensure string_view is not // initialized with string rvalue. Also, according to Google style guide, // passing pointers conveys the message that the reference to string is kept // in the constructed holder object. - explicit StringHolderBase(const std::string *str) : value_(*str) {} + explicit StringHolderBase(const std::string* str) : value_(*str) {} absl::string_view value() const { return value_; } @@ -102,13 +120,20 @@ class CelValue { // Helper structure for CelType datatype. using CelTypeHolder = StringHolderBase<2>; + // Type for CEL Null values. Implemented as a monostate to behave well in + // absl::variant. + using NullType = absl::monostate; + + // GCC: fully qualified to avoid change of meaning error. + using MessageWrapper = google::api::expr::runtime::MessageWrapper; + private: // CelError MUST BE the last in the declaration - it is a ceiling for Type // enum using ValueHolder = internal::ValueHolder< - bool, int64_t, uint64_t, double, StringHolder, BytesHolder, - const google::protobuf::Message *, absl::Duration, absl::Time, const CelList *, - const CelMap *, const UnknownSet *, CelTypeHolder, const CelError *>; + NullType, bool, int64_t, uint64_t, double, StringHolder, BytesHolder, + MessageWrapper, absl::Duration, absl::Time, const CelList*, const CelMap*, + const UnknownSet*, CelTypeHolder, const CelError*>; public: // Metafunction providing positions corresponding to specific @@ -117,30 +142,36 @@ class CelValue { using IndexOf = ValueHolder::IndexOf; // Enum for types supported. - enum class Type { + // This is not recommended for use in exhaustive switches in client code. + // Types may be updated over time. + using Type = ::cel::Kind; + + // Legacy enumeration that is here for testing purposes. Do not use. + enum class LegacyType { + kNullType = IndexOf::value, kBool = IndexOf::value, kInt64 = IndexOf::value, kUint64 = IndexOf::value, kDouble = IndexOf::value, kString = IndexOf::value, kBytes = IndexOf::value, - kMessage = IndexOf::value, + kMessage = IndexOf::value, kDuration = IndexOf::value, kTimestamp = IndexOf::value, - kList = IndexOf::value, - kMap = IndexOf::value, - kUnknownSet = IndexOf::value, + kList = IndexOf::value, + kMap = IndexOf::value, + kUnknownSet = IndexOf::value, kCelType = IndexOf::value, - kError = IndexOf::value, + kError = IndexOf::value, kAny // Special value. Used in function descriptors. }; // Default constructor. // Creates CelValue with null data type. - CelValue() : CelValue(static_cast(nullptr)) {} + CelValue() : CelValue(NullType()) {} // Returns Type that describes the type of value stored. - Type type() const { return Type(value_.index()); } + Type type() const { return static_cast(value_.index()); } // Returns debug string describing a value const std::string DebugString() const; @@ -149,9 +180,10 @@ class CelValue { // The reason for this is the high risk of implicit type conversions // between bool/int/pointer types. // We rely on copy elision to avoid extra copying. - static CelValue CreateNull() { - return CelValue(static_cast(nullptr)); - } + static CelValue CreateNull() { return CelValue(NullType()); } + + // Transitional factory for migrating to null types. + static CelValue CreateNullTypedValue() { return CelValue(NullType()); } static CelValue CreateBool(bool value) { return CelValue(value); } @@ -161,13 +193,19 @@ class CelValue { static CelValue CreateDouble(double value) { return CelValue(value); } - static CelValue CreateString(StringHolder holder) { return CelValue(holder); } + static CelValue CreateString(StringHolder holder) { + ABSL_ASSERT(::cel::internal::Utf8IsValid(holder.value())); + return CelValue(holder); + } + // Returns a string value from a string_view. Warning: the caller is + // responsible for the lifecycle of the backing string. Prefer CreateString + // instead. static CelValue CreateStringView(absl::string_view value) { return CelValue(StringHolder(value)); } - static CelValue CreateString(const std::string *str) { + static CelValue CreateString(const std::string* str) { return CelValue(StringHolder(str)); } @@ -177,27 +215,35 @@ class CelValue { return CelValue(BytesHolder(value)); } - static CelValue CreateBytes(const std::string *str) { + static CelValue CreateBytes(const std::string* str) { return CelValue(BytesHolder(str)); } - static CelValue CreateDuration(absl::Duration value) { + static CelValue CreateDuration(absl::Duration value); + + static CelValue CreateUncheckedDuration(absl::Duration value) { return CelValue(value); } static CelValue CreateTimestamp(absl::Time value) { return CelValue(value); } - static CelValue CreateList(const CelList *value) { + static CelValue CreateList(const CelList* value) { CheckNullPointer(value, Type::kList); return CelValue(value); } - static CelValue CreateMap(const CelMap *value) { + // Creates a CelValue backed by an empty immutable list. + static CelValue CreateList(); + + static CelValue CreateMap(const CelMap* value) { CheckNullPointer(value, Type::kMap); return CelValue(value); } - static CelValue CreateUnknownSet(const UnknownSet *value) { + // Creates a CelValue backed by an empty immutable map. + static CelValue CreateMap(); + + static CelValue CreateUnknownSet(const UnknownSet* value) { CheckNullPointer(value, Type::kUnknownSet); return CelValue(value); } @@ -206,11 +252,23 @@ class CelValue { return CelValue(holder); } - static CelValue CreateError(const CelError *value) { + static CelValue CreateCelTypeView(absl::string_view value) { + // This factory method is used for dealing with string references which + // come from protobuf objects or other containers which promise pointer + // stability. In general, this is a risky method to use and should not + // be invoked outside the core CEL library. + return CelValue(CelTypeHolder(value)); + } + + static CelValue CreateError(const CelError* value) { CheckNullPointer(value, Type::kError); return CelValue(value); } + // Returns an absl::OkStatus() when the key is a valid protobuf map type, + // meaning it is a scalar value that is neither floating point nor bytes. + static absl::Status CheckMapKeyType(const CelValue& key); + // Obtain the CelType of the value. CelValue ObtainCelType() const; @@ -223,20 +281,22 @@ class CelValue { // Fails if stored value type is not boolean. bool BoolOrDie() const { return GetValueOrDie(Type::kBool); } - // Returns stored int64_t value. - // Fails if stored value type is not int64_t. + // Returns stored int64 value. + // Fails if stored value type is not int64. int64_t Int64OrDie() const { return GetValueOrDie(Type::kInt64); } - // Returns stored uint64_t value. - // Fails if stored value type is not uint64_t. - uint64_t Uint64OrDie() const { return GetValueOrDie(Type::kUint64); } + // Returns stored uint64 value. + // Fails if stored value type is not uint64. + uint64_t Uint64OrDie() const { + return GetValueOrDie(Type::kUint64); + } // Returns stored double value. // Fails if stored value type is not double. double DoubleOrDie() const { return GetValueOrDie(Type::kDouble); } - // Returns stored const string * value. - // Fails if stored value type is not const string *. + // Returns stored const string* value. + // Fails if stored value type is not const string*. StringHolder StringOrDie() const { return GetValueOrDie(Type::kString); } @@ -245,10 +305,17 @@ class CelValue { return GetValueOrDie(Type::kBytes); } - // Returns stored const Message * value. - // Fails if stored value type is not const Message *. - const google::protobuf::Message *MessageOrDie() const { - return GetValueOrDie(Type::kMessage); + // Returns stored const Message* value. + // Fails if stored value type is not const Message*. + const google::protobuf::Message* MessageOrDie() const { + MessageWrapper wrapped = MessageWrapperOrDie(); + ABSL_ASSERT(wrapped.HasFullProto()); + return static_cast(wrapped.message_ptr()); + } + + ABSL_DEPRECATED("Use MessageOrDie") + MessageWrapper MessageWrapperOrDie() const { + return GetValueOrDie(Type::kMessage); } // Returns stored duration value. @@ -263,16 +330,16 @@ class CelValue { return GetValueOrDie(Type::kTimestamp); } - // Returns stored const CelList * value. - // Fails if stored value type is not const CelList *. - const CelList *ListOrDie() const { - return GetValueOrDie(Type::kList); + // Returns stored const CelList* value. + // Fails if stored value type is not const CelList*. + const CelList* ListOrDie() const { + return GetValueOrDie(Type::kList); } // Returns stored const CelMap * value. // Fails if stored value type is not const CelMap *. - const CelMap *MapOrDie() const { - return GetValueOrDie(Type::kMap); + const CelMap* MapOrDie() const { + return GetValueOrDie(Type::kMap); } // Returns stored const CelTypeHolder value. @@ -283,14 +350,14 @@ class CelValue { // Returns stored const UnknownAttributeSet * value. // Fails if stored value type is not const UnknownAttributeSet *. - const UnknownSet *UnknownSetOrDie() const { - return GetValueOrDie(Type::kUnknownSet); + const UnknownSet* UnknownSetOrDie() const { + return GetValueOrDie(Type::kUnknownSet); } // Returns stored const CelError * value. // Fails if stored value type is not const CelError *. - const CelError *ErrorOrDie() const { - return GetValueOrDie(Type::kError); + const CelError* ErrorOrDie() const { + return GetValueOrDie(Type::kError); } bool IsNull() const { return value_.template Visit(NullCheckOp()); } @@ -307,66 +374,118 @@ class CelValue { bool IsBytes() const { return value_.is(); } - bool IsMessage() const { return value_.is(); } + bool IsMessage() const { return value_.is(); } bool IsDuration() const { return value_.is(); } bool IsTimestamp() const { return value_.is(); } - bool IsList() const { return value_.is(); } + bool IsList() const { return value_.is(); } - bool IsMap() const { return value_.is(); } + bool IsMap() const { return value_.is(); } - bool IsUnknownSet() const { return value_.is(); } + bool IsUnknownSet() const { return value_.is(); } bool IsCelType() const { return value_.is(); } - bool IsError() const { return value_.is(); } + bool IsError() const { return value_.is(); } // Invokes op() with the active value, and returns the result. // All overloads of op() must have the same return type. + // Note: this depends on the internals of CelValue, so use with caution. template - ReturnType Visit(Op &&op) const { - return value_.template Visit(op); + ReturnType InternalVisit(Op&& op) const { + return value_.template Visit(std::forward(op)); + } + + // Invokes op() with the active value, and returns the result. + // All overloads of op() must have the same return type. + // TODO(uncreated-issue/2): Move to CelProtoWrapper to retain the assumed + // google::protobuf::Message variant version behavior for client code. + template + ReturnType Visit(Op&& op) const { + return value_.template Visit( + internal::MessageVisitAdapter(std::forward(op))); } // Template-style getter. // Returns true, if assignment successful template - bool GetValue(Arg *value) const { - return this->template Visit(AssignerOp(value)); + bool GetValue(Arg* value) const { + return this->template InternalVisit(AssignerOp(value)); } // Provides type names for internal logging. static std::string TypeName(Type value_type); + // Factory for message wrapper. This should only be used by internal + // libraries. + // TODO(uncreated-issue/2): exposed for testing while wiring adapter APIs. Should + // make private visibility after refactors are done. + ABSL_DEPRECATED("Use CelProtoWrapper::CreateMessage") + static CelValue CreateMessageWrapper(MessageWrapper value) { + CheckNullPointer(value.message_ptr(), Type::kMessage); + CheckNullPointer(value.legacy_type_info(), Type::kMessage); + return CelValue(value); + } + private: ValueHolder value_; - template + template struct AssignerOp { - explicit AssignerOp(T *val) : value(val) {} + explicit AssignerOp(T* val) : value(val) {} template - bool operator()(const U &) { + bool operator()(const U&) { return false; } - bool operator()(const T &arg) { + bool operator()(const T& arg) { *value = arg; return true; } - T *value; + T* value; + }; + + // Specialization for MessageWrapper to support legacy behavior while + // migrating off hard dependency on google::protobuf::Message. + // TODO(uncreated-issue/2): Move to CelProtoWrapper. + template + struct AssignerOp< + T, std::enable_if_t>> { + explicit AssignerOp(const google::protobuf::Message** val) : value(val) {} + + template + bool operator()(const U&) { + return false; + } + + bool operator()(const MessageWrapper& held_value) { + if (!held_value.HasFullProto()) { + return false; + } + + *value = static_cast(held_value.message_ptr()); + return true; + } + + const google::protobuf::Message** value; }; struct NullCheckOp { template - bool operator()(const T &) const { + bool operator()(const T&) const { return false; } - bool operator()(const google::protobuf::Message *arg) const { return arg == nullptr; } + bool operator()(NullType) const { return true; } + // Note: this is not typically possible, but is supported for allowing + // function resolution for null ptrs as Messages. + bool operator()(const MessageWrapper& arg) const { + return arg.message_ptr() == nullptr; + } }; // Constructs CelValue wrapping value supplied as argument. @@ -374,27 +493,43 @@ class CelValue { template explicit CelValue(T value) : value_(value) {} + // Crashes with a null pointer error. + static void CrashNullPointer(Type type) ABSL_ATTRIBUTE_COLD { + ABSL_LOG(FATAL) << "Null pointer supplied for " + << TypeName(type); // Crash ok + } + // Null pointer checker for pointer-based types. - static void CheckNullPointer(const void *ptr, Type type) { - if (ptr == nullptr) { - GOOGLE_LOG(FATAL) << "Null pointer supplied for " << TypeName(type); // Crash ok + static void CheckNullPointer(const void* ptr, Type type) { + if (ABSL_PREDICT_FALSE(ptr == nullptr)) { + CrashNullPointer(type); } } + // Crashes with a type mismatch error. + static void CrashTypeMismatch(Type requested_type, + Type actual_type) ABSL_ATTRIBUTE_COLD { + ABSL_LOG(FATAL) << "Type mismatch" // Crash ok + << ": expected " << TypeName(requested_type) // Crash ok + << ", encountered " << TypeName(actual_type); // Crash ok + } + // Gets value of type specified template T GetValueOrDie(Type requested_type) const { auto value_ptr = value_.get(); - if (value_ptr == nullptr) { - GOOGLE_LOG(FATAL) << "Type mismatch" // Crash ok - << ": expected " << TypeName(requested_type) // Crash ok - << ", encountered " << TypeName(type()); // Crash ok + if (ABSL_PREDICT_FALSE(value_ptr == nullptr)) { + CrashTypeMismatch(requested_type, type()); } return *value_ptr; } friend class CelProtoWrapper; + friend class ProtoMessageTypeAdapter; + friend class EvaluatorStack; + friend class TestOnly_FactoryAccessor; }; + static_assert(absl::is_trivially_destructible::value, "Non-trivially-destructible CelValue impacts " "performance"); @@ -402,25 +537,89 @@ static_assert(absl::is_trivially_destructible::value, // CelList is a base class for list adapting classes. class CelList { public: + ABSL_DEPRECATED( + "Unless you are sure of the underlying CelList implementation, call Get " + "and pass an arena instead") virtual CelValue operator[](int index) const = 0; + // Like `operator[](int)` above, but also accepts an arena. Prefer calling + // this variant if the arena is known. + virtual CelValue Get(google::protobuf::Arena* arena, int index) const { + static_cast(arena); + return (*this)[index]; + } + // List size virtual int size() const = 0; // Default empty check. Can be overridden in subclass for performance. virtual bool empty() const { return size() == 0; } virtual ~CelList() {} + + private: + friend struct cel::interop_internal::CelListAccess; + friend struct cel::NativeTypeTraits; + + virtual cel::NativeTypeId GetNativeTypeId() const { + return cel::NativeTypeId(); + } }; // CelMap is a base class for map accessors. class CelMap { public: - // Map lookup. If value found, - // returns CelValue in return type. - // Per Protobuffer specification, acceptable key types are - // int64_t,uint64,string. + // Map lookup. If value found, returns CelValue in return type. + // + // Per the protobuf specification, acceptable key types are bool, int64, + // uint64, string. Any key type that is not supported should result in valued + // response containing an absl::StatusCode::kInvalidArgument wrapped as a + // CelError. + // + // Type specializations are permitted since CEL supports such distinctions + // at type-check time. For example, the expression `1 in map_str` where the + // variable `map_str` is of type map(string, string) will yield a type-check + // error. To be consistent, the runtime should also yield an invalid argument + // error if the type does not agree with the expected key types held by the + // container. + // TODO(issues/122): Make this method const correct. + ABSL_DEPRECATED( + "Unless you are sure of the underlying CelMap implementation, call Get " + "and pass an arena instead") virtual absl::optional operator[](CelValue key) const = 0; + // Like `operator[](CelValue)` above, but also accepts an arena. Prefer + // calling this variant if the arena is known. + virtual absl::optional Get(google::protobuf::Arena* arena, + CelValue key) const { + static_cast(arena); + return (*this)[key]; + } + + // Return whether the key is present within the map. + // + // Typically, key resolution will be a simple boolean result; however, there + // are scenarios where the conversion of the input key to the underlying + // key-type will produce an absl::StatusCode::kInvalidArgument. + // + // Evaluators are responsible for handling non-OK results by propagating the + // error, as appropriate, up the evaluation stack either as a `StatusOr` or + // as a `CelError` value, depending on the context. + virtual absl::StatusOr Has(const CelValue& key) const { + // This check safeguards against issues with invalid key types such as NaN. + CEL_RETURN_IF_ERROR(CelValue::CheckMapKeyType(key)); + google::protobuf::Arena arena; + auto value = (*this).Get(&arena, key); + if (!value.has_value()) { + return false; + } + // This protects from issues that may occur when looking up a key value, + // such as a failure to convert an int64 to an int32 map key. + if (value->IsError()) { + return *value->ErrorOrDie(); + } + return true; + } + // Map size virtual int size() const = 0; // Default empty check. Can be overridden in subclass for performance. @@ -428,68 +627,144 @@ class CelMap { // Return list of keys. CelList is owned by Arena, so no // ownership is passed. - virtual const CelList *ListKeys() const = 0; + ABSL_DEPRECATED( + "Unless you are sure of the underlying CelMap implementation, call " + "ListKeys and pass an arena instead") + virtual absl::StatusOr ListKeys() const = 0; + + // Like `ListKeys()` above, but also accepts an arena. Prefer calling this + // variant if the arena is known. + virtual absl::StatusOr ListKeys(google::protobuf::Arena* arena) const { + static_cast(arena); + return ListKeys(); + } virtual ~CelMap() {} + + private: + friend struct cel::interop_internal::CelMapAccess; + friend struct cel::NativeTypeTraits; + + virtual cel::NativeTypeId GetNativeTypeId() const { + return cel::NativeTypeId(); + } }; // Utility method that generates CelValue containing CelError. // message an error message // error_code error code -// position location of the error source in CEL expression string the Expr was -// parsed from. -1, if the position can not be determined. CelValue CreateErrorValue( - google::protobuf::Arena *arena, absl::string_view message, - absl::StatusCode error_code = absl::StatusCode::kUnknown, - int position = -1); - -CelValue CreateNoMatchingOverloadError(google::protobuf::Arena *arena); -CelValue CreateNoMatchingOverloadError(google::protobuf::Arena *arena, - absl::string_view fn); + cel::MemoryManagerRef manager ABSL_ATTRIBUTE_LIFETIME_BOUND, + absl::string_view message, + absl::StatusCode error_code = absl::StatusCode::kUnknown); +CelValue CreateErrorValue( + google::protobuf::Arena* arena, absl::string_view message, + absl::StatusCode error_code = absl::StatusCode::kUnknown); + +// Utility method for generating a CelValue from an absl::Status. +CelValue CreateErrorValue(cel::MemoryManagerRef manager + ABSL_ATTRIBUTE_LIFETIME_BOUND, + const absl::Status& status); + +// Utility method for generating a CelValue from an absl::Status. +CelValue CreateErrorValue(google::protobuf::Arena* arena, const absl::Status& status); + +// Create an error for failed overload resolution, optionally including the name +// of the function. +CelValue CreateNoMatchingOverloadError(cel::MemoryManagerRef manager + ABSL_ATTRIBUTE_LIFETIME_BOUND, + absl::string_view fn = ""); +ABSL_DEPRECATED("Prefer using the generic MemoryManager overload") +CelValue CreateNoMatchingOverloadError(google::protobuf::Arena* arena, + absl::string_view fn = ""); bool CheckNoMatchingOverloadError(CelValue value); -CelValue CreateNoSuchFieldError(google::protobuf::Arena *arena); +CelValue CreateNoSuchFieldError(cel::MemoryManagerRef manager + ABSL_ATTRIBUTE_LIFETIME_BOUND, + absl::string_view field = ""); +ABSL_DEPRECATED("Prefer using the generic MemoryManager overload") +CelValue CreateNoSuchFieldError(google::protobuf::Arena* arena, + absl::string_view field = ""); -CelValue CreateNoSuchKeyError(google::protobuf::Arena *arena, absl::string_view key); -bool CheckNoSuchKeyError(CelValue value); - -// Returns the error indicating that evaluation encountered a value marked -// as unknown, was included in Activation unknown_paths. -CelValue CreateUnknownValueError(google::protobuf::Arena *arena, - absl::string_view unknown_path); +CelValue CreateNoSuchKeyError(cel::MemoryManagerRef manager + ABSL_ATTRIBUTE_LIFETIME_BOUND, + absl::string_view key); +ABSL_DEPRECATED("Prefer using the generic MemoryManager overload") +CelValue CreateNoSuchKeyError(google::protobuf::Arena* arena, absl::string_view key); -// Returns true if this is unknown value error indicating that evaluation -// encountered a value marked as unknown in Activation unknown_paths. -bool IsUnknownValueError(const CelValue &value); +bool CheckNoSuchKeyError(CelValue value); // Returns an error indicating that evaluation has accessed an attribute whose // value is undefined. For example, this may represent a field in a proto // message bound to the activation whose value can't be determined by the // hosting application. -CelValue CreateMissingAttributeError(google::protobuf::Arena *arena, +CelValue CreateMissingAttributeError(cel::MemoryManagerRef manager + ABSL_ATTRIBUTE_LIFETIME_BOUND, + absl::string_view missing_attribute_path); +ABSL_DEPRECATED("Prefer using the generic MemoryManager overload") +CelValue CreateMissingAttributeError(google::protobuf::Arena* arena, absl::string_view missing_attribute_path); -bool IsMissingAttributeError(const CelValue &value); +ABSL_CONST_INIT extern const absl::string_view kPayloadUrlMissingAttributePath; +bool IsMissingAttributeError(const CelValue& value); // Returns error indicating the result of the function is unknown. This is used // as a signal to create an unknown set if unknown function handling is opted // into. -CelValue CreateUnknownFunctionResultError(google::protobuf::Arena *arena, +CelValue CreateUnknownFunctionResultError(cel::MemoryManagerRef manager + ABSL_ATTRIBUTE_LIFETIME_BOUND, + absl::string_view help_message); +ABSL_DEPRECATED("Prefer using the generic MemoryManager overload") +CelValue CreateUnknownFunctionResultError(google::protobuf::Arena* arena, absl::string_view help_message); // Returns true if this is unknown value error indicating that evaluation // called an extension function whose value is unknown for the given args. // This is used as a signal to convert to an UnknownSet if the behavior is opted // into. -bool IsUnknownFunctionResult(const CelValue &value); +bool IsUnknownFunctionResult(const CelValue& value); + +} // namespace google::api::expr::runtime + +namespace cel { -// Returns set of unknown paths for unknown value error. The value must be -// unknown error, see IsUnknownValueError() above, or it dies. -std::set GetUnknownPathsSetOrDie(const CelValue &value); +template <> +struct NativeTypeTraits final { + static NativeTypeId Id(const google::api::expr::runtime::CelList& cel_list) { + return cel_list.GetNativeTypeId(); + } +}; + +template +struct NativeTypeTraits< + T, + std::enable_if_t, + std::negation>>>> + final { + static NativeTypeId Id(const google::api::expr::runtime::CelList& cel_list) { + return NativeTypeTraits::Id(cel_list); + } +}; + +template <> +struct NativeTypeTraits final { + static NativeTypeId Id(const google::api::expr::runtime::CelMap& cel_map) { + return cel_map.GetNativeTypeId(); + } +}; + +template +struct NativeTypeTraits< + T, std::enable_if_t, + std::negation>>>> + final { + static NativeTypeId Id(const google::api::expr::runtime::CelMap& cel_map) { + return NativeTypeTraits::Id(cel_map); + } +}; -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace cel #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_VALUE_H_ diff --git a/eval/public/cel_value_internal.h b/eval/public/cel_value_internal.h index e46f2b874..64b895ad7 100644 --- a/eval/public/cel_value_internal.h +++ b/eval/public/cel_value_internal.h @@ -17,14 +17,14 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_VALUE_INTERNAL_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_VALUE_INTERNAL_H_ -#include "absl/types/variant.h" +#include -namespace google { -namespace api { -namespace expr { -namespace runtime { +#include "absl/base/macros.h" +#include "absl/types/variant.h" +#include "eval/public/message_wrapper.h" +#include "google/protobuf/message.h" -namespace internal { +namespace google::api::expr::runtime::internal { // Helper classes needed for IndexOf metafunction implementation. template @@ -60,7 +60,7 @@ class ValueHolder { using IndexOf = TypeIndexer<0, sizeof...(Args), T, Args...>; template - const T *get() const { + const T* get() const { return absl::get_if(&value_); } @@ -72,7 +72,7 @@ class ValueHolder { int index() const { return value_.index(); } template - ReturnType Visit(Op &&op) const { + ReturnType Visit(Op&& op) const { return absl::visit(std::forward(op), value_); } @@ -80,11 +80,24 @@ class ValueHolder { absl::variant value_; }; -} // namespace internal +// Adapter for visitor clients that depend on google::protobuf::Message as a variant type. +template +struct MessageVisitAdapter { + explicit MessageVisitAdapter(Op&& op) : op(std::forward(op)) {} + + template + T operator()(const ArgT& arg) { + return op(arg); + } + + T operator()(const MessageWrapper& wrapper) { + ABSL_ASSERT(wrapper.HasFullProto()); + return op(static_cast(wrapper.message_ptr())); + } + + Op op; +}; -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime::internal #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_VALUE_INTERNAL_H_ diff --git a/eval/public/cel_value_producer.h b/eval/public/cel_value_producer.h index bbecd1b55..88ef185d4 100644 --- a/eval/public/cel_value_producer.h +++ b/eval/public/cel_value_producer.h @@ -3,10 +3,7 @@ #include "eval/public/cel_value.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace google::api::expr::runtime { // CelValueProducer produces CelValue during CEL Expression evaluation. // It is intended to be used with Activation, to provide on-demand CelValue @@ -23,9 +20,6 @@ class CelValueProducer { virtual CelValue Produce(google::protobuf::Arena* arena) = 0; }; -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_VALUE_PRODUCER_H_ diff --git a/eval/public/cel_value_test.cc b/eval/public/cel_value_test.cc index f9538965e..0af6eb9e7 100644 --- a/eval/public/cel_value_test.cc +++ b/eval/public/cel_value_test.cc @@ -1,28 +1,44 @@ #include "eval/public/cel_value.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" +#include +#include + #include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" #include "absl/strings/match.h" #include "absl/strings/string_view.h" #include "absl/time/time.h" -#include "eval/public/unknown_attribute_set.h" +#include "absl/types/optional.h" +#include "common/memory.h" +#include "eval/internal/errors.h" +#include "eval/public/structs/trivial_legacy_type_info.h" +#include "eval/public/testing/matchers.h" #include "eval/public/unknown_set.h" - -namespace google { -namespace api { -namespace expr { -namespace runtime { - -using testing::Eq; -using testing::UnorderedPointwise; +#include "eval/testutil/test_message.pb.h" +#include "extensions/protobuf/memory_manager.h" +#include "internal/testing.h" +#include "google/protobuf/arena.h" + +namespace google::api::expr::runtime { + +using ::absl_testing::IsOkAndHolds; +using ::absl_testing::StatusIs; +using ::cel::extensions::ProtoMemoryManagerRef; +using ::cel::runtime_internal::kDurationHigh; +using ::cel::runtime_internal::kDurationLow; +using ::testing::Eq; +using ::testing::HasSubstr; +using ::testing::NotNull; class DummyMap : public CelMap { public: absl::optional operator[](CelValue value) const override { return CelValue::CreateNull(); } - const CelList* ListKeys() const override { return nullptr; } + absl::StatusOr ListKeys() const override { + return absl::UnimplementedError("CelMap::ListKeys is not implemented"); + } int size() const override { return 0; } }; @@ -39,6 +55,9 @@ class DummyList : public CelList { TEST(CelValueTest, TestType) { ::google::protobuf::Arena arena; + CelValue value_null = CelValue::CreateNull(); + EXPECT_THAT(value_null.type(), Eq(CelValue::Type::kNullType)); + CelValue value_bool = CelValue::CreateBool(false); EXPECT_THAT(value_bool.type(), Eq(CelValue::Type::kBool)); @@ -124,7 +143,7 @@ TEST(CelValueTest, TestBool) { EXPECT_THAT(CountTypeMatch(value), Eq(1)); } -// This test verifies CelValue support of int64_t type. +// This test verifies CelValue support of int64 type. TEST(CelValueTest, TestInt64) { int64_t v = 1; CelValue value = CelValue::CreateInt64(v); @@ -138,7 +157,7 @@ TEST(CelValueTest, TestInt64) { EXPECT_THAT(CountTypeMatch(value), Eq(1)); } -// This test verifies CelValue support of uint64_t type. +// This test verifies CelValue support of uint64 type. TEST(CelValueTest, TestUint64) { uint64_t v = 1; CelValue value = CelValue::CreateUint64(v); @@ -152,7 +171,7 @@ TEST(CelValueTest, TestUint64) { EXPECT_THAT(CountTypeMatch(value), Eq(1)); } -// This test verifies CelValue support of int64_t type. +// This test verifies CelValue support of int64 type. TEST(CelValueTest, TestDouble) { double v0 = 1.; CelValue value = CelValue::CreateDouble(v0); @@ -166,6 +185,23 @@ TEST(CelValueTest, TestDouble) { EXPECT_THAT(CountTypeMatch(value), Eq(1)); } +TEST(CelValueTest, TestDurationRangeCheck) { + EXPECT_THAT(CelValue::CreateDuration(absl::Seconds(1)), + test::IsCelDuration(absl::Seconds(1))); + + EXPECT_THAT( + CelValue::CreateDuration(kDurationHigh), + test::IsCelError(StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Duration is out of range")))); + EXPECT_THAT( + CelValue::CreateDuration(kDurationLow), + test::IsCelError(StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Duration is out of range")))); + + EXPECT_THAT(CelValue::CreateDuration(kDurationLow + absl::Seconds(1)), + test::IsCelDuration(kDurationLow + absl::Seconds(1))); +} + // This test verifies CelValue support of string type. TEST(CelValueTest, TestString) { constexpr char kTestStr0[] = "test0"; @@ -217,6 +253,20 @@ TEST(CelValueTest, TestList) { EXPECT_THAT(CountTypeMatch(value), Eq(1)); } +TEST(CelValueTest, TestEmptyList) { + ::google::protobuf::Arena arena; + + CelValue value = CelValue::CreateList(); + EXPECT_TRUE(value.IsList()); + + const CelList* value2; + EXPECT_TRUE(value.GetValue(&value2)); + EXPECT_TRUE(value2->empty()); + EXPECT_EQ(value2->size(), 0); + EXPECT_THAT(value2->Get(&arena, 0), + test::IsCelError(StatusIs(absl::StatusCode::kInvalidArgument))); +} + // This test verifies CelValue support of Map type. TEST(CelValueTest, TestMap) { DummyMap dummy_map; @@ -232,9 +282,27 @@ TEST(CelValueTest, TestMap) { EXPECT_THAT(CountTypeMatch(value), Eq(1)); } -TEST(CelValueTest, TestCelType) { +TEST(CelValueTest, TestEmptyMap) { ::google::protobuf::Arena arena; + CelValue value = CelValue::CreateMap(); + EXPECT_TRUE(value.IsMap()); + + const CelMap* value2; + EXPECT_TRUE(value.GetValue(&value2)); + EXPECT_TRUE(value2->empty()); + EXPECT_EQ(value2->size(), 0); + EXPECT_THAT(value2->Has(CelValue::CreateBool(false)), IsOkAndHolds(false)); + EXPECT_THAT(value2->Get(&arena, CelValue::CreateBool(false)), + Eq(absl::nullopt)); + EXPECT_THAT(value2->ListKeys(&arena), IsOkAndHolds(NotNull())); +} + +TEST(CelValueTest, TestCelType) { + CelValue value_null = CelValue::CreateNullTypedValue(); + EXPECT_THAT(value_null.ObtainCelType().CelTypeOrDie().value(), + Eq("null_type")); + CelValue value_bool = CelValue::CreateBool(false); EXPECT_THAT(value_bool.ObtainCelType().CelTypeOrDie().value(), Eq("bool")); @@ -257,15 +325,17 @@ TEST(CelValueTest, TestCelType) { EXPECT_THAT(value_bytes.type(), Eq(CelValue::Type::kBytes)); EXPECT_THAT(value_bytes.ObtainCelType().CelTypeOrDie().value(), Eq("bytes")); + std::string msg_type_str = "google.api.expr.runtime.TestMessage"; + CelValue msg_type = CelValue::CreateCelTypeView(msg_type_str); + EXPECT_TRUE(msg_type.IsCelType()); + EXPECT_THAT(msg_type.CelTypeOrDie().value(), + Eq("google.api.expr.runtime.TestMessage")); + EXPECT_THAT(msg_type.type(), Eq(CelValue::Type::kCelType)); + UnknownSet unknown_set; CelValue value_unknown = CelValue::CreateUnknownSet(&unknown_set); EXPECT_THAT(value_unknown.type(), Eq(CelValue::Type::kUnknownSet)); EXPECT_TRUE(value_unknown.ObtainCelType().IsUnknownSet()); - - CelValue missing_attribute_error = - CreateMissingAttributeError(&arena, "destination.ip"); - EXPECT_TRUE(IsMissingAttributeError(missing_attribute_error)); - EXPECT_TRUE(missing_attribute_error.ObtainCelType().IsError()); } // This test verifies CelValue support of Unknown type. @@ -283,15 +353,69 @@ TEST(CelValueTest, TestUnknownSet) { EXPECT_THAT(CountTypeMatch(value), Eq(1)); } -TEST(CelValueTest, UnknownFunctionResultErrors) { - ::google::protobuf::Arena arena; +TEST(CelValueTest, SpecialErrorFactories) { + google::protobuf::Arena arena; + auto manager = ProtoMemoryManagerRef(&arena); + + CelValue error = CreateNoSuchKeyError(manager, "key"); + EXPECT_THAT(error, test::IsCelError(StatusIs(absl::StatusCode::kNotFound))); + EXPECT_TRUE(CheckNoSuchKeyError(error)); + + error = CreateNoSuchFieldError(manager, "field"); + EXPECT_THAT(error, test::IsCelError(StatusIs(absl::StatusCode::kNotFound))); + + error = CreateNoMatchingOverloadError(manager, "function"); + EXPECT_THAT(error, test::IsCelError(StatusIs(absl::StatusCode::kUnknown))); + EXPECT_TRUE(CheckNoMatchingOverloadError(error)); + + absl::Status error_status = absl::InternalError("internal error"); + error_status.SetPayload("CreateErrorValuePreservesFullStatusMessage", + absl::Cord("more information")); + error = CreateErrorValue(manager, error_status); + EXPECT_THAT(error, test::IsCelError(error_status)); + + error = CreateErrorValue(&arena, error_status); + EXPECT_THAT(error, test::IsCelError(error_status)); +} + +TEST(CelValueTest, MissingAttributeErrorsDeprecated) { + google::protobuf::Arena arena; + + CelValue missing_attribute_error = + CreateMissingAttributeError(&arena, "destination.ip"); + EXPECT_TRUE(IsMissingAttributeError(missing_attribute_error)); + EXPECT_TRUE(missing_attribute_error.ObtainCelType().IsError()); +} + +TEST(CelValueTest, MissingAttributeErrors) { + google::protobuf::Arena arena; + auto manager = ProtoMemoryManagerRef(&arena); + + CelValue missing_attribute_error = + CreateMissingAttributeError(manager, "destination.ip"); + EXPECT_TRUE(IsMissingAttributeError(missing_attribute_error)); + EXPECT_TRUE(missing_attribute_error.ObtainCelType().IsError()); +} + +TEST(CelValueTest, UnknownFunctionResultErrorsDeprecated) { + google::protobuf::Arena arena; CelValue value = CreateUnknownFunctionResultError(&arena, "message"); EXPECT_TRUE(value.IsError()); EXPECT_TRUE(IsUnknownFunctionResult(value)); } +TEST(CelValueTest, UnknownFunctionResultErrors) { + google::protobuf::Arena arena; + auto manager = ProtoMemoryManagerRef(&arena); + + CelValue value = CreateUnknownFunctionResultError(manager, "message"); + EXPECT_TRUE(value.IsError()); + EXPECT_TRUE(IsUnknownFunctionResult(value)); +} + TEST(CelValueTest, DebugString) { + EXPECT_EQ(CelValue::CreateNull().DebugString(), "null_type: null"); EXPECT_EQ(CelValue::CreateBool(true).DebugString(), "bool: 1"); EXPECT_EQ(CelValue::CreateInt64(-12345).DebugString(), "int64: -12345"); EXPECT_EQ(CelValue::CreateUint64(12345).DebugString(), "uint64: 12345"); @@ -306,18 +430,53 @@ TEST(CelValueTest, DebugString) { EXPECT_EQ( CelValue::CreateTimestamp(absl::FromUnixSeconds(86400)).DebugString(), - "Time: 1970-01-02T00:00:00+00:00"); + "Timestamp: 1970-01-02T00:00:00+00:00"); UnknownSet unknown_set; EXPECT_EQ(CelValue::CreateUnknownSet(&unknown_set).DebugString(), - "UnknownSet"); + "UnknownSet: ?"); absl::Status error = absl::InternalError("Blah..."); EXPECT_EQ(CelValue::CreateError(&error).DebugString(), - "Error: INTERNAL: Blah..."); + "CelError: INTERNAL: Blah..."); + + // List and map DebugString() test coverage is in cel_proto_wrapper_test.cc. +} + +TEST(CelValueTest, Message) { + TestMessage message; + auto value = CelValue::CreateMessageWrapper( + CelValue::MessageWrapper(&message, TrivialTypeInfo::GetInstance())); + EXPECT_TRUE(value.IsMessage()); + CelValue::MessageWrapper held; + ASSERT_TRUE(value.GetValue(&held)); + EXPECT_TRUE(held.HasFullProto()); + EXPECT_EQ(held.message_ptr(), + static_cast(&message)); + EXPECT_EQ(held.legacy_type_info(), TrivialTypeInfo::GetInstance()); + // TrivialTypeInfo doesn't provide any details about the specific message. + EXPECT_EQ(value.ObtainCelType().CelTypeOrDie().value(), "opaque"); + EXPECT_EQ(value.DebugString(), "Message: opaque"); +} + +TEST(CelValueTest, MessageLite) { + TestMessage message; + // Upcast to message lite. + const google::protobuf::MessageLite* ptr = &message; + auto value = CelValue::CreateMessageWrapper( + CelValue::MessageWrapper(ptr, TrivialTypeInfo::GetInstance())); + EXPECT_TRUE(value.IsMessage()); + CelValue::MessageWrapper held; + ASSERT_TRUE(value.GetValue(&held)); + EXPECT_FALSE(held.HasFullProto()); + EXPECT_EQ(held.message_ptr(), &message); + EXPECT_EQ(held.legacy_type_info(), TrivialTypeInfo::GetInstance()); + EXPECT_EQ(value.ObtainCelType().CelTypeOrDie().value(), "opaque"); + EXPECT_EQ(value.DebugString(), "Message: opaque"); } -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +TEST(CelValueTest, Size) { + // CelValue performance degrades when it becomes larger. + static_assert(sizeof(CelValue) <= 3 * sizeof(uintptr_t)); +} +} // namespace google::api::expr::runtime diff --git a/eval/public/comparison_functions.cc b/eval/public/comparison_functions.cc new file mode 100644 index 000000000..ec282704c --- /dev/null +++ b/eval/public/comparison_functions.cc @@ -0,0 +1,33 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/comparison_functions.h" + +#include "absl/status/status.h" +#include "eval/public/cel_function_registry.h" +#include "eval/public/cel_options.h" +#include "runtime/function_registry.h" +#include "runtime/runtime_options.h" +#include "runtime/standard/comparison_functions.h" + +namespace google::api::expr::runtime { + +absl::Status RegisterComparisonFunctions(CelFunctionRegistry* registry, + const InterpreterOptions& options) { + cel::RuntimeOptions modern_options = ConvertToRuntimeOptions(options); + cel::FunctionRegistry& modern_registry = registry->InternalGetRegistry(); + return cel::RegisterComparisonFunctions(modern_registry, modern_options); +} + +} // namespace google::api::expr::runtime diff --git a/eval/public/comparison_functions.h b/eval/public/comparison_functions.h new file mode 100644 index 000000000..61df888ac --- /dev/null +++ b/eval/public/comparison_functions.h @@ -0,0 +1,36 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_COMPARISON_FUNCTIONS_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_COMPARISON_FUNCTIONS_H_ + +#include "absl/status/status.h" +#include "eval/public/cel_function_registry.h" +#include "eval/public/cel_options.h" + +namespace google::api::expr::runtime { + +// Register built in comparison functions (<, <=, >, >=). +// +// Most users should prefer to use RegisterBuiltinFunctions. +// +// This is call is included in RegisterBuiltinFunctions -- calling both +// RegisterBuiltinFunctions and RegisterComparisonFunctions directly on the same +// registry will result in an error. +absl::Status RegisterComparisonFunctions(CelFunctionRegistry* registry, + const InterpreterOptions& options); + +} // namespace google::api::expr::runtime + +#endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_COMPARISON_FUNCTIONS_H_ diff --git a/eval/public/comparison_functions_test.cc b/eval/public/comparison_functions_test.cc new file mode 100644 index 000000000..78f347ec8 --- /dev/null +++ b/eval/public/comparison_functions_test.cc @@ -0,0 +1,248 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/comparison_functions.h" + +#include +#include + +#include "cel/expr/syntax.pb.h" +#include "google/rpc/context/attribute_context.pb.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "eval/public/activation.h" +#include "eval/public/cel_expr_builder_factory.h" +#include "eval/public/cel_expression.h" +#include "eval/public/cel_function_registry.h" +#include "eval/public/cel_options.h" +#include "eval/public/cel_value.h" +#include "eval/public/testing/matchers.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "parser/parser.h" +#include "google/protobuf/arena.h" + +namespace google::api::expr::runtime { +namespace { + +using ::cel::expr::ParsedExpr; +using ::google::rpc::context::AttributeContext; +using ::testing::Combine; +using ::testing::ValuesIn; + +MATCHER_P2(DefinesHomogenousOverload, name, argument_type, + absl::StrCat(name, " for ", CelValue::TypeName(argument_type))) { + const CelFunctionRegistry& registry = arg; + return !registry + .FindOverloads(name, /*receiver_style=*/false, + {argument_type, argument_type}) + .empty(); + return false; +} + +struct ComparisonTestCase { + absl::string_view expr; + bool result; + CelValue lhs = CelValue::CreateNull(); + CelValue rhs = CelValue::CreateNull(); +}; + +class ComparisonFunctionTest + : public testing::TestWithParam> { + public: + ComparisonFunctionTest() { + options_.enable_heterogeneous_equality = std::get<1>(GetParam()); + options_.enable_empty_wrapper_null_unboxing = true; + builder_ = CreateCelExpressionBuilder(options_); + } + + CelFunctionRegistry& registry() { return *builder_->GetRegistry(); } + + absl::StatusOr Evaluate(absl::string_view expr, const CelValue& lhs, + const CelValue& rhs) { + CEL_ASSIGN_OR_RETURN(ParsedExpr parsed_expr, parser::Parse(expr)); + Activation activation; + activation.InsertValue("lhs", lhs); + activation.InsertValue("rhs", rhs); + + CEL_ASSIGN_OR_RETURN(auto expression, + builder_->CreateExpression( + &parsed_expr.expr(), &parsed_expr.source_info())); + + return expression->Evaluate(activation, &arena_); + } + + protected: + std::unique_ptr builder_; + InterpreterOptions options_; + google::protobuf::Arena arena_; +}; + +TEST_P(ComparisonFunctionTest, SmokeTest) { + ComparisonTestCase test_case = std::get<0>(GetParam()); + google::protobuf::LinkMessageReflection(); + + ASSERT_OK(RegisterComparisonFunctions(®istry(), options_)); + ASSERT_OK_AND_ASSIGN(auto result, + Evaluate(test_case.expr, test_case.lhs, test_case.rhs)); + + EXPECT_THAT(result, test::IsCelBool(test_case.result)); +} + +INSTANTIATE_TEST_SUITE_P( + LessThan, ComparisonFunctionTest, + Combine(ValuesIn( + {// less than + {"false < true", true}, + {"1 < 2", true}, + {"-2 < -1", true}, + {"1.1 < 1.2", true}, + {"'a' < 'b'", true}, + {"lhs < rhs", true, CelValue::CreateBytesView("a"), + CelValue::CreateBytesView("b")}, + {"lhs < rhs", true, CelValue::CreateDuration(absl::Seconds(1)), + CelValue::CreateDuration(absl::Seconds(2))}, + {"lhs < rhs", true, + CelValue::CreateTimestamp(absl::FromUnixSeconds(20)), + CelValue::CreateTimestamp(absl::FromUnixSeconds(30))}}), + // heterogeneous equality enabled + testing::Bool())); + +INSTANTIATE_TEST_SUITE_P( + GreaterThan, ComparisonFunctionTest, + testing::Combine( + testing::ValuesIn( + {{"false > true", false}, + {"1 > 2", false}, + {"-2 > -1", false}, + {"1.1 > 1.2", false}, + {"'a' > 'b'", false}, + {"lhs > rhs", false, CelValue::CreateBytesView("a"), + CelValue::CreateBytesView("b")}, + {"lhs > rhs", false, CelValue::CreateDuration(absl::Seconds(1)), + CelValue::CreateDuration(absl::Seconds(2))}, + {"lhs > rhs", false, + CelValue::CreateTimestamp(absl::FromUnixSeconds(20)), + CelValue::CreateTimestamp(absl::FromUnixSeconds(30))}}), + // heterogeneous equality enabled + testing::Bool())); + +INSTANTIATE_TEST_SUITE_P( + GreaterOrEqual, ComparisonFunctionTest, + Combine(ValuesIn( + {{"false >= true", false}, + {"1 >= 2", false}, + {"-2 >= -1", false}, + {"1.1 >= 1.2", false}, + {"'a' >= 'b'", false}, + {"lhs >= rhs", false, CelValue::CreateBytesView("a"), + CelValue::CreateBytesView("b")}, + {"lhs >= rhs", false, + CelValue::CreateDuration(absl::Seconds(1)), + CelValue::CreateDuration(absl::Seconds(2))}, + {"lhs >= rhs", false, + CelValue::CreateTimestamp(absl::FromUnixSeconds(20)), + CelValue::CreateTimestamp(absl::FromUnixSeconds(30))}}), + // heterogeneous equality enabled + testing::Bool())); + +INSTANTIATE_TEST_SUITE_P( + LessOrEqual, ComparisonFunctionTest, + Combine(testing::ValuesIn( + {{"false <= true", true}, + {"1 <= 2", true}, + {"-2 <= -1", true}, + {"1.1 <= 1.2", true}, + {"'a' <= 'b'", true}, + {"lhs <= rhs", true, CelValue::CreateBytesView("a"), + CelValue::CreateBytesView("b")}, + {"lhs <= rhs", true, + CelValue::CreateDuration(absl::Seconds(1)), + CelValue::CreateDuration(absl::Seconds(2))}, + {"lhs <= rhs", true, + CelValue::CreateTimestamp(absl::FromUnixSeconds(20)), + CelValue::CreateTimestamp(absl::FromUnixSeconds(30))}}), + // heterogeneous equality enabled + testing::Bool())); + +INSTANTIATE_TEST_SUITE_P(HeterogeneousNumericComparisons, + ComparisonFunctionTest, + Combine(testing::ValuesIn( + { // less than + {"1 < 2u", true}, // int < uint + {"2 < 1u", false}, + {"1 < 2.1", true}, // int < double + {"3 < 2.1", false}, + {"1u < 2", true}, // uint < int + {"2u < 1", false}, + {"1u < -1.1", false}, // uint < double + {"1u < 2.1", true}, + {"1.1 < 2", true}, // double < int + {"1.1 < 1", false}, + {"1.0 < 1u", false}, // double < uint + {"1.0 < 3u", true}, + + // less than or equal + {"1 <= 2u", true}, // int <= uint + {"2 <= 1u", false}, + {"1 <= 2.1", true}, // int <= double + {"3 <= 2.1", false}, + {"1u <= 2", true}, // uint <= int + {"1u <= 0", false}, + {"1u <= -1.1", false}, // uint <= double + {"2u <= 1.0", false}, + {"1.1 <= 2", true}, // double <= int + {"2.1 <= 2", false}, + {"1.0 <= 1u", true}, // double <= uint + {"1.1 <= 1u", false}, + + // greater than + {"3 > 2u", true}, // int > uint + {"3 > 4u", false}, + {"3 > 2.1", true}, // int > double + {"3 > 4.1", false}, + {"3u > 2", true}, // uint > int + {"3u > 4", false}, + {"3u > -1.1", true}, // uint > double + {"3u > 4.1", false}, + {"3.1 > 2", true}, // double > int + {"3.1 > 4", false}, + {"3.0 > 1u", true}, // double > uint + {"3.0 > 4u", false}, + + // greater than or equal + {"3 >= 2u", true}, // int >= uint + {"3 >= 4u", false}, + {"3 >= 2.1", true}, // int >= double + {"3 >= 4.1", false}, + {"3u >= 2", true}, // uint >= int + {"3u >= 4", false}, + {"3u >= -1.1", true}, // uint >= double + {"3u >= 4.1", false}, + {"3.1 >= 2", true}, // double >= int + {"3.1 >= 4", false}, + {"3.0 >= 1u", true}, // double >= uint + {"3.0 >= 4u", false}, + {"1u >= -1", true}, + {"1 >= 4u", false}, + + // edge cases + {"-1 < 1u", true}, + {"1 < 9223372036854775808u", true}}), + testing::Values(true))); + +} // namespace +} // namespace google::api::expr::runtime diff --git a/eval/public/container_function_registrar.cc b/eval/public/container_function_registrar.cc new file mode 100644 index 000000000..c61aa93c9 --- /dev/null +++ b/eval/public/container_function_registrar.cc @@ -0,0 +1,31 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/container_function_registrar.h" + +#include "eval/public/cel_options.h" +#include "runtime/runtime_options.h" +#include "runtime/standard/container_functions.h" + +namespace google::api::expr::runtime { + +absl::Status RegisterContainerFunctions(CelFunctionRegistry* registry, + const InterpreterOptions& options) { + cel::RuntimeOptions runtime_options = ConvertToRuntimeOptions(options); + + return cel::RegisterContainerFunctions(registry->InternalGetRegistry(), + runtime_options); +} + +} // namespace google::api::expr::runtime diff --git a/eval/public/container_function_registrar.h b/eval/public/container_function_registrar.h new file mode 100644 index 000000000..9ce268439 --- /dev/null +++ b/eval/public/container_function_registrar.h @@ -0,0 +1,36 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CONTAINER_FUNCTION_REGISTRAR_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CONTAINER_FUNCTION_REGISTRAR_H_ + +#include "absl/status/status.h" +#include "eval/public/cel_function_registry.h" +#include "eval/public/cel_options.h" + +namespace google::api::expr::runtime { + +// Register built in container functions. +// +// Most users should prefer to use RegisterBuiltinFunctions. +// +// This call is included in RegisterBuiltinFunctions -- calling both +// RegisterBuiltinFunctions and RegisterContainerFunctions directly on the same +// registry will result in an error. +absl::Status RegisterContainerFunctions(CelFunctionRegistry* registry, + const InterpreterOptions& options); + +} // namespace google::api::expr::runtime + +#endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CONTAINER_FUNCTION_REGISTRAR_H_ diff --git a/eval/public/container_function_registrar_test.cc b/eval/public/container_function_registrar_test.cc new file mode 100644 index 000000000..e6d5f93d8 --- /dev/null +++ b/eval/public/container_function_registrar_test.cc @@ -0,0 +1,95 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/container_function_registrar.h" + +#include +#include + +#include "eval/public/activation.h" +#include "eval/public/cel_expr_builder_factory.h" +#include "eval/public/cel_expression.h" +#include "eval/public/cel_value.h" +#include "eval/public/containers/container_backed_list_impl.h" +#include "eval/public/equality_function_registrar.h" +#include "eval/public/testing/matchers.h" +#include "internal/testing.h" +#include "parser/parser.h" + +namespace google::api::expr::runtime { +namespace { + +using cel::expr::Expr; +using cel::expr::SourceInfo; +using ::testing::ValuesIn; + +struct TestCase { + std::string test_name; + std::string expr; + absl::StatusOr result = CelValue::CreateBool(true); +}; + +const CelList& CelNumberListExample() { + static ContainerBackedListImpl* example = + new ContainerBackedListImpl({CelValue::CreateInt64(1)}); + return *example; +} + +void ExpectResult(const TestCase& test_case) { + auto parsed_expr = parser::Parse(test_case.expr); + ASSERT_OK(parsed_expr); + const Expr& expr_ast = parsed_expr->expr(); + const SourceInfo& source_info = parsed_expr->source_info(); + InterpreterOptions options; + options.enable_timestamp_duration_overflow_errors = true; + options.enable_comprehension_list_append = true; + std::unique_ptr builder = + CreateCelExpressionBuilder(options); + ASSERT_OK(RegisterContainerFunctions(builder->GetRegistry(), options)); + // Needed to avoid error - No overloads provided for FunctionStep creation. + ASSERT_OK(RegisterEqualityFunctions(builder->GetRegistry(), options)); + ASSERT_OK_AND_ASSIGN(auto cel_expression, + builder->CreateExpression(&expr_ast, &source_info)); + + Activation activation; + + google::protobuf::Arena arena; + ASSERT_OK_AND_ASSIGN(auto value, + cel_expression->Evaluate(activation, &arena)); + EXPECT_THAT(value, test::EqualsCelValue(*test_case.result)); +} + +using ContainerFunctionParamsTest = testing::TestWithParam; +TEST_P(ContainerFunctionParamsTest, StandardFunctions) { + ExpectResult(GetParam()); +} + +INSTANTIATE_TEST_SUITE_P( + ContainerFunctionParamsTest, ContainerFunctionParamsTest, + ValuesIn( + {{"FilterNumbers", "[1, 2, 3].filter(num, num == 1)", + CelValue::CreateList(&CelNumberListExample())}, + {"ListConcatEmptyInputs", "[] + [] == []", CelValue::CreateBool(true)}, + {"ListConcatRightEmpty", "[1] + [] == [1]", + CelValue::CreateBool(true)}, + {"ListConcatLeftEmpty", "[] + [1] == [1]", CelValue::CreateBool(true)}, + {"ListConcat", "[2] + [1] == [2, 1]", CelValue::CreateBool(true)}, + {"ListSize", "[1, 2, 3].size() == 3", CelValue::CreateBool(true)}, + {"MapSize", "{1: 2, 2: 4}.size() == 2", CelValue::CreateBool(true)}, + {"EmptyListSize", "size({}) == 0", CelValue::CreateBool(true)}}), + [](const testing::TestParamInfo& + info) { return info.param.test_name; }); + +} // namespace +} // namespace google::api::expr::runtime diff --git a/eval/public/containers/BUILD b/eval/public/containers/BUILD index 04e76b4e2..18ad48734 100644 --- a/eval/public/containers/BUILD +++ b/eval/public/containers/BUILD @@ -1,11 +1,30 @@ -# Container type implementations for use in the c++ CEL evaluator. +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. -package(default_visibility = ["//visibility:public"]) +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") -licenses(["notice"]) # Apache 2.0 +package(default_visibility = ["//visibility:public"]) +licenses(["notice"]) # TODO(issues/69): Expose this in a public API. +package_group( + name = "cel_internal", + packages = ["//eval/..."], +) + cc_library( name = "field_access", srcs = [ @@ -15,11 +34,12 @@ cc_library( "field_access.h", ], deps = [ + "//eval/public:cel_options", "//eval/public:cel_value", "//eval/public/structs:cel_proto_wrapper", - "//internal:proto_util", + "//eval/public/structs:field_access_impl", + "//internal:status_macros", "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", "@com_google_protobuf//:protobuf", ], ) @@ -33,8 +53,7 @@ cc_library( ], deps = [ "//eval/public:cel_value", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", + "@com_google_protobuf//:protobuf", ], ) @@ -49,37 +68,36 @@ cc_library( deps = [ "//eval/public:cel_value", "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], ) cc_library( name = "field_backed_list_impl", - srcs = [ - "field_backed_list_impl.cc", - ], hdrs = [ "field_backed_list_impl.h", ], deps = [ - ":field_access", + ":internal_field_backed_list_impl", "//eval/public:cel_value", - "@com_google_absl//absl/strings", + "//eval/public/structs:cel_proto_wrapper", ], ) cc_library( name = "field_backed_map_impl", - srcs = [ - "field_backed_map_impl.cc", - ], hdrs = [ "field_backed_map_impl.h", ], deps = [ - ":field_access", + ":internal_field_backed_map_impl", "//eval/public:cel_value", - "@com_google_absl//absl/strings", + "//eval/public/structs:cel_proto_wrapper", + "@com_google_absl//absl/status:statusor", "@com_google_protobuf//:protobuf", ], ) @@ -93,8 +111,8 @@ cc_test( deps = [ ":container_backed_map_impl", "//eval/public:cel_value", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest_main", + "//internal:testing", + "@com_google_absl//absl/status", ], ) @@ -106,10 +124,9 @@ cc_test( ], deps = [ ":field_backed_list_impl", - "//eval/eval:evaluator_core", "//eval/testutil:test_message_cc_proto", + "//internal:testing", "//testutil:util", - "@com_google_googletest//:gtest_main", ], ) @@ -121,9 +138,95 @@ cc_test( ], deps = [ ":field_backed_map_impl", - "//eval/eval:evaluator_core", "//eval/testutil:test_message_cc_proto", + "//internal:testing", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "field_access_test", + srcs = ["field_access_test.cc"], + deps = [ + ":field_access", + "//eval/public:cel_value", + "//eval/public/structs:cel_proto_wrapper", + "//eval/public/testing:matchers", + "//eval/testutil:test_message_cc_proto", + "//internal:testing", + "//internal:time", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "internal_field_backed_list_impl", + srcs = [ + "internal_field_backed_list_impl.cc", + ], + hdrs = [ + "internal_field_backed_list_impl.h", + ], + deps = [ + "//eval/public:cel_value", + "//eval/public/structs:field_access_impl", + "//eval/public/structs:protobuf_value_factory", + ], +) + +cc_test( + name = "internal_field_backed_list_impl_test", + size = "small", + srcs = [ + "internal_field_backed_list_impl_test.cc", + ], + deps = [ + ":internal_field_backed_list_impl", + "//eval/public/structs:cel_proto_wrapper", + "//eval/testutil:test_message_cc_proto", + "//internal:testing", + "//testutil:util", + ], +) + +cc_library( + name = "internal_field_backed_map_impl", + srcs = [ + "internal_field_backed_map_impl.cc", + ], + hdrs = [ + "internal_field_backed_map_impl.h", + ], + deps = [ + "//eval/public:cel_value", + "//eval/public/structs:field_access_impl", + "//eval/public/structs:protobuf_value_factory", + "//extensions/protobuf/internal:map_reflection", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "internal_field_backed_map_impl_test", + size = "small", + srcs = [ + "internal_field_backed_map_impl_test.cc", + ], + visibility = ["//visibility:private"], + deps = [ + ":internal_field_backed_map_impl", + "//eval/public/structs:cel_proto_wrapper", + "//eval/testutil:test_message_cc_proto", + "//internal:testing", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest_main", ], ) diff --git a/eval/public/containers/container_backed_list_impl.h b/eval/public/containers/container_backed_list_impl.h index 2e195051a..c0480c651 100644 --- a/eval/public/containers/container_backed_list_impl.h +++ b/eval/public/containers/container_backed_list_impl.h @@ -1,8 +1,11 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CONTAINERS_CONTAINER_BACKED_LIST_IMPL_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CONTAINERS_CONTAINER_BACKED_LIST_IMPL_H_ +#include +#include + #include "eval/public/cel_value.h" -#include "absl/types/span.h" +#include "google/protobuf/arena.h" namespace google { namespace api { @@ -24,6 +27,11 @@ class ContainerBackedListImpl : public CelList { // List element access operator. CelValue operator[](int index) const override { return values_[index]; } + // List element access operator. + CelValue Get(google::protobuf::Arena*, int index) const override { + return values_[index]; + } + private: std::vector values_; }; diff --git a/eval/public/containers/container_backed_map_impl.cc b/eval/public/containers/container_backed_map_impl.cc index 5e05abbde..5ac08af92 100644 --- a/eval/public/containers/container_backed_map_impl.cc +++ b/eval/public/containers/container_backed_map_impl.cc @@ -1,8 +1,13 @@ - - #include "eval/public/containers/container_backed_map_impl.h" +#include +#include + #include "absl/container/node_hash_map.h" +#include "absl/hash/hash.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/optional.h" #include "absl/types/span.h" #include "eval/public/cel_value.h" @@ -77,92 +82,47 @@ class CelValueEq { const CelValue& other_; }; -// CelValue hasher functor. -class Hasher { - public: - size_t operator()(const CelValue& key) const { - return key.template Visit(HasherOp()); - } -}; - -// CelValue equality functor. -class Equal { - public: - // - bool operator()(const CelValue& key1, const CelValue& key2) const { - if (key1.type() != key2.type()) { - return false; - } - return key1.template Visit(CelValueEq(key2)); - } -}; +} // namespace -// CelMap implementation that uses STL map container as backing storage. -// KeyType is the type of key values stored in CelValue, InnerKeyType is the -// type of key in STL map. -class ContainerBackedMapImpl : public CelMap { - public: - static std::unique_ptr Create( - absl::Span> key_values) { - auto cel_map = absl::WrapUnique(new ContainerBackedMapImpl()); - - if (!cel_map->AddItems(key_values)) { - return nullptr; - } - return std::move(cel_map); +// Map element access operator. +absl::optional CelMapBuilder::operator[](CelValue cel_key) const { + auto item = values_map_.find(cel_key); + if (item == values_map_.end()) { + return absl::nullopt; } + return item->second; +} - // Map size. - int size() const override { return values_map_.size(); } +absl::Status CelMapBuilder::Add(CelValue key, CelValue value) { + auto [unused, inserted] = values_map_.emplace(key, value); - // Map element access operator. - absl::optional operator[](CelValue cel_key) const override { - auto item = values_map_.find(cel_key); - if (item == values_map_.end()) { - return {}; - } - return item->second; + if (!inserted) { + return absl::InvalidArgumentError("duplicate map keys"); } + key_list_.Add(key); + return absl::OkStatus(); +} - const CelList* ListKeys() const override { return &key_list_; } - - private: - class KeyList : public CelList { - public: - int size() const override { return keys_.size(); } - - CelValue operator[](int index) const override { return keys_[index]; } - - void Add(const CelValue& key) { keys_.push_back(key); } - - private: - std::vector keys_; - }; - - ContainerBackedMapImpl() = default; - - bool AddItems(absl::Span> key_values) { - for (const auto& item : key_values) { - auto result = values_map_.emplace(item.first, item.second); +// CelValue hasher functor. +size_t CelMapBuilder::Hasher::operator()(const CelValue& key) const { + return key.template Visit(HasherOp()); +} - // Failed to insert pair into map - addition failed. - if (!result.second) { - return false; - } - key_list_.Add(item.first); - } - return true; +bool CelMapBuilder::Equal::operator()(const CelValue& key1, + const CelValue& key2) const { + if (key1.type() != key2.type()) { + return false; } + return key1.template Visit(CelValueEq(key2)); +} - absl::node_hash_map values_map_; - KeyList key_list_; -}; - -} // namespace - -std::unique_ptr CreateContainerBackedMap( - absl::Span> key_values) { - return ContainerBackedMapImpl::Create(key_values); +absl::StatusOr> CreateContainerBackedMap( + absl::Span> key_values) { + auto map = std::make_unique(); + for (const auto& key_value : key_values) { + CEL_RETURN_IF_ERROR(map->Add(key_value.first, key_value.second)); + } + return map; } } // namespace runtime diff --git a/eval/public/containers/container_backed_map_impl.h b/eval/public/containers/container_backed_map_impl.h index c89fd3d21..6092eefcf 100644 --- a/eval/public/containers/container_backed_map_impl.h +++ b/eval/public/containers/container_backed_map_impl.h @@ -1,21 +1,70 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CONTAINERS_CONTAINER_BACKED_MAP_IMPL_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CONTAINERS_CONTAINER_BACKED_MAP_IMPL_H_ -#include "eval/public/cel_value.h" +#include +#include + +#include "absl/container/node_hash_map.h" +#include "absl/status/statusor.h" #include "absl/types/span.h" +#include "eval/public/cel_value.h" + +namespace google::api::expr::runtime { + +// CelMap implementation that uses STL map container as backing storage. +// KeyType is the type of key values stored in CelValue. +// After building, upcast to CelMap to prevent further additions. +class CelMapBuilder : public CelMap { + public: + CelMapBuilder() {} + + // Try to insert a key value pair into the map. Returns a status if key + // already exists. + absl::Status Add(CelValue key, CelValue value); + + int size() const override { return values_map_.size(); } + + absl::optional operator[](CelValue cel_key) const override; + + absl::StatusOr Has(const CelValue& cel_key) const override { + return values_map_.contains(cel_key); + } + + absl::StatusOr ListKeys() const override { + return &key_list_; + } + + private: + // Custom CelList implementation for maintaining key list. + class KeyList : public CelList { + public: + KeyList() {} + + int size() const override { return keys_.size(); } + + CelValue operator[](int index) const override { return keys_[index]; } + + void Add(const CelValue& key) { keys_.push_back(key); } + + private: + std::vector keys_; + }; + + struct Hasher { + size_t operator()(const CelValue& key) const; + }; + struct Equal { + bool operator()(const CelValue& key1, const CelValue& key2) const; + }; -namespace google { -namespace api { -namespace expr { -namespace runtime { + absl::node_hash_map values_map_; + KeyList key_list_; +}; -// Template factory method creating container-backed CelMap. -std::unique_ptr CreateContainerBackedMap( - absl::Span> key_values); +// Factory method creating container-backed CelMap. +absl::StatusOr> CreateContainerBackedMap( + absl::Span> key_values); -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CONTAINERS_CONTAINER_BACKED_MAP_IMPL_H_ diff --git a/eval/public/containers/container_backed_map_impl_test.cc b/eval/public/containers/container_backed_map_impl_test.cc index d017ca893..59d38d235 100644 --- a/eval/public/containers/container_backed_map_impl_test.cc +++ b/eval/public/containers/container_backed_map_impl_test.cc @@ -4,28 +4,28 @@ #include #include -#include "gmock/gmock.h" -#include "gtest/gtest.h" +#include "absl/status/status.h" #include "eval/public/cel_value.h" +#include "internal/testing.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace google::api::expr::runtime { namespace { -using testing::Eq; -using testing::Not; -using testing::IsNull; +using ::absl_testing::StatusIs; +using ::testing::Eq; +using ::testing::IsNull; +using ::testing::Not; TEST(ContainerBackedMapImplTest, TestMapInt64) { std::vector> args = { {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}, {CelValue::CreateInt64(2), CelValue::CreateInt64(3)}}; - auto cel_map = CreateContainerBackedMap( - absl::Span>(args.data(), args.size())); + auto cel_map = + CreateContainerBackedMap( + absl::Span>(args.data(), args.size())) + .value(); ASSERT_THAT(cel_map, Not(IsNull())); @@ -56,8 +56,10 @@ TEST(ContainerBackedMapImplTest, TestMapUint64) { std::vector> args = { {CelValue::CreateUint64(1), CelValue::CreateInt64(2)}, {CelValue::CreateUint64(2), CelValue::CreateInt64(3)}}; - auto cel_map = CreateContainerBackedMap( - absl::Span>(args.data(), args.size())); + auto cel_map = + CreateContainerBackedMap( + absl::Span>(args.data(), args.size())) + .value(); ASSERT_THAT(cel_map, Not(IsNull())); @@ -92,8 +94,10 @@ TEST(ContainerBackedMapImplTest, TestMapString) { std::vector> args = { {CelValue::CreateString(&kKey1), CelValue::CreateInt64(2)}, {CelValue::CreateString(&kKey2), CelValue::CreateInt64(3)}}; - auto cel_map = CreateContainerBackedMap( - absl::Span>(args.data(), args.size())); + auto cel_map = + CreateContainerBackedMap( + absl::Span>(args.data(), args.size())) + .value(); ASSERT_THAT(cel_map, Not(IsNull())); @@ -120,9 +124,64 @@ TEST(ContainerBackedMapImplTest, TestMapString) { ASSERT_FALSE(lookup3); } +TEST(CelMapBuilder, TestMapString) { + const std::string kKey1 = "1"; + const std::string kKey2 = "2"; + const std::string kKey3 = "3"; + + std::vector> args = { + {CelValue::CreateString(&kKey1), CelValue::CreateInt64(2)}, + {CelValue::CreateString(&kKey2), CelValue::CreateInt64(3)}}; + CelMapBuilder builder; + ASSERT_OK( + builder.Add(CelValue::CreateString(&kKey1), CelValue::CreateInt64(2))); + ASSERT_OK( + builder.Add(CelValue::CreateString(&kKey2), CelValue::CreateInt64(3))); + + CelMap* cel_map = &builder; + + ASSERT_THAT(cel_map, Not(IsNull())); + + EXPECT_THAT(cel_map->size(), Eq(2)); + + // Test lookup with key == 1 ( should succeed ) + auto lookup1 = (*cel_map)[CelValue::CreateString(&kKey1)]; + + ASSERT_TRUE(lookup1); + + CelValue cel_value = lookup1.value(); + + ASSERT_TRUE(cel_value.IsInt64()); + EXPECT_THAT(cel_value.Int64OrDie(), 2); + + // Test lookup with different type ( should fail ) + auto lookup2 = (*cel_map)[CelValue::CreateInt64(1)]; + + ASSERT_FALSE(lookup2); + + // Test lookup with key3 ( should fail ) + auto lookup3 = (*cel_map)[CelValue::CreateString(&kKey3)]; + + ASSERT_FALSE(lookup3); +} + +TEST(CelMapBuilder, RepeatKeysFail) { + const std::string kKey1 = "1"; + const std::string kKey2 = "2"; + + std::vector> args = { + {CelValue::CreateString(&kKey1), CelValue::CreateInt64(2)}, + {CelValue::CreateString(&kKey2), CelValue::CreateInt64(3)}}; + CelMapBuilder builder; + ASSERT_OK( + builder.Add(CelValue::CreateString(&kKey1), CelValue::CreateInt64(2))); + ASSERT_OK( + builder.Add(CelValue::CreateString(&kKey2), CelValue::CreateInt64(3))); + EXPECT_THAT( + builder.Add(CelValue::CreateString(&kKey2), CelValue::CreateInt64(3)), + StatusIs(absl::StatusCode::kInvalidArgument, "duplicate map keys")); +} + } // namespace -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime diff --git a/eval/public/containers/field_access.cc b/eval/public/containers/field_access.cc index eab074b83..a3da18e40 100644 --- a/eval/public/containers/field_access.cc +++ b/eval/public/containers/field_access.cc @@ -1,774 +1,85 @@ -#include "eval/public/containers/field_access.h" +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. -#include +#include "eval/public/containers/field_access.h" -#include "google/protobuf/any.pb.h" -#include "google/protobuf/map_field.h" #include "absl/status/status.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/substitute.h" #include "eval/public/structs/cel_proto_wrapper.h" -#include "internal/proto_util.h" - -namespace google { -namespace api { -namespace expr { -namespace runtime { +#include "eval/public/structs/field_access_impl.h" +#include "internal/status_macros.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/map_field.h" -namespace { +namespace google::api::expr::runtime { using ::google::protobuf::Arena; using ::google::protobuf::FieldDescriptor; -using ::google::protobuf::MapValueRef; +using ::google::protobuf::MapValueConstRef; using ::google::protobuf::Message; -using ::google::protobuf::Reflection; - -// Well-known type protobuf type names which require special get / set behavior. -constexpr const char kProtobufDuration[] = "google.protobuf.Duration"; -constexpr const char kProtobufTimestamp[] = "google.protobuf.Timestamp"; -constexpr const char kProtobufAny[] = "google.protobuf.Any"; - -const char kTypeGoogleApisComPrefix[] = "type.googleapis.com/"; - -// Singular message fields and repeated message fields have similar access model -// To provide common approach, we implement accessor classes, based on CRTP. -// FieldAccessor is CRTP base class, specifying Get.. method family. -template -class FieldAccessor { - public: - bool GetBool() const { return static_cast(this)->GetBool(); } - - int64_t GetInt32() const { - return static_cast(this)->GetInt32(); - } - - uint64_t GetUInt32() const { - return static_cast(this)->GetUInt32(); - } - - int64_t GetInt64() const { - return static_cast(this)->GetInt64(); - } - - uint64_t GetUInt64() const { - return static_cast(this)->GetUInt64(); - } - - double GetFloat() const { - return static_cast(this)->GetFloat(); - } - - double GetDouble() const { - return static_cast(this)->GetDouble(); - } - - const std::string* GetString(std::string* buffer) const { - return static_cast(this)->GetString(buffer); - } - - const Message* GetMessage() const { - return static_cast(this)->GetMessage(); - } - - int64_t GetEnumValue() const { - return static_cast(this)->GetEnumValue(); - } - - // This method provides message field content, wrapped in CelValue. - // If value provided successfully, returns Ok. - // arena Arena to use for allocations if needed. - // result pointer to object to store value in. - absl::Status CreateValueFromFieldAccessor(Arena* arena, CelValue* result) { - switch (field_desc_->cpp_type()) { - case FieldDescriptor::CPPTYPE_BOOL: { - bool value = GetBool(); - *result = CelValue::CreateBool(value); - break; - } - case FieldDescriptor::CPPTYPE_INT32: { - int64_t value = GetInt32(); - *result = CelValue::CreateInt64(value); - break; - } - case FieldDescriptor::CPPTYPE_INT64: { - int64_t value = GetInt64(); - *result = CelValue::CreateInt64(value); - break; - } - case FieldDescriptor::CPPTYPE_UINT32: { - uint64_t value = GetUInt32(); - *result = CelValue::CreateUint64(value); - break; - } - case FieldDescriptor::CPPTYPE_UINT64: { - uint64_t value = GetUInt64(); - *result = CelValue::CreateUint64(value); - break; - } - case FieldDescriptor::CPPTYPE_FLOAT: { - double value = GetFloat(); - *result = CelValue::CreateDouble(value); - break; - } - case FieldDescriptor::CPPTYPE_DOUBLE: { - double value = GetDouble(); - *result = CelValue::CreateDouble(value); - break; - } - case FieldDescriptor::CPPTYPE_STRING: { - std::string buffer; - const std::string* value = GetString(&buffer); - if (value == &buffer) { - value = google::protobuf::Arena::Create(arena, std::move(buffer)); - } - switch (field_desc_->type()) { - case FieldDescriptor::TYPE_STRING: - *result = CelValue::CreateString(value); - break; - case FieldDescriptor::TYPE_BYTES: - *result = CelValue::CreateBytes(value); - break; - default: - return absl::Status(absl::StatusCode::kInvalidArgument, - "Error handling C++ string conversion"); - } - break; - } - case FieldDescriptor::CPPTYPE_MESSAGE: { - const google::protobuf::Message* msg_value = GetMessage(); - *result = CelProtoWrapper::CreateMessage(msg_value, arena); - break; - } - case FieldDescriptor::CPPTYPE_ENUM: { - int enum_value = GetEnumValue(); - *result = CelValue::CreateInt64(enum_value); - break; - } - default: - return absl::Status(absl::StatusCode::kInvalidArgument, - "Unhandled C++ type conversion"); - } - - return absl::OkStatus(); - } - - protected: - FieldAccessor(const Message* msg, const FieldDescriptor* field_desc) - : msg_(msg), field_desc_(field_desc) {} - - const Message* msg_; - const FieldDescriptor* field_desc_; -}; - -// Accessor class, to work with singular fields -class ScalarFieldAccessor : public FieldAccessor { - public: - ScalarFieldAccessor(const Message* msg, const FieldDescriptor* field_desc) - : FieldAccessor(msg, field_desc) {} - - bool GetBool() const { return GetReflection()->GetBool(*msg_, field_desc_); } - - int64_t GetInt32() const { - return GetReflection()->GetInt32(*msg_, field_desc_); - } - - uint64_t GetUInt32() const { - return GetReflection()->GetUInt32(*msg_, field_desc_); - } - - int64_t GetInt64() const { - return GetReflection()->GetInt64(*msg_, field_desc_); - } - - uint64_t GetUInt64() const { - return GetReflection()->GetUInt64(*msg_, field_desc_); - } - - double GetFloat() const { - return GetReflection()->GetFloat(*msg_, field_desc_); - } - - double GetDouble() const { - return GetReflection()->GetDouble(*msg_, field_desc_); - } - - const std::string* GetString(std::string* buffer) const { - return &GetReflection()->GetStringReference(*msg_, field_desc_, buffer); - } - - const Message* GetMessage() const { - return &GetReflection()->GetMessage(*msg_, field_desc_); - } - - int64_t GetEnumValue() const { - return GetReflection()->GetEnumValue(*msg_, field_desc_); - } - - const Reflection* GetReflection() const { return msg_->GetReflection(); } -}; - -// Accessor class, to work with repeated fields. -class RepeatedFieldAccessor : public FieldAccessor { - public: - RepeatedFieldAccessor(const Message* msg, const FieldDescriptor* field_desc, - int index) - : FieldAccessor(msg, field_desc), index_(index) {} - - bool GetBool() const { - return GetReflection()->GetRepeatedBool(*msg_, field_desc_, index_); - } - - int64_t GetInt32() const { - return GetReflection()->GetRepeatedInt32(*msg_, field_desc_, index_); - } - - uint64_t GetUInt32() const { - return GetReflection()->GetRepeatedUInt32(*msg_, field_desc_, index_); - } - - int64_t GetInt64() const { - return GetReflection()->GetRepeatedInt64(*msg_, field_desc_, index_); - } - - uint64_t GetUInt64() const { - return GetReflection()->GetRepeatedUInt64(*msg_, field_desc_, index_); - } - - double GetFloat() const { - return GetReflection()->GetRepeatedFloat(*msg_, field_desc_, index_); - } - - double GetDouble() const { - return GetReflection()->GetRepeatedDouble(*msg_, field_desc_, index_); - } - - const std::string* GetString(std::string* buffer) const { - return &GetReflection()->GetRepeatedStringReference(*msg_, field_desc_, - index_, buffer); - } - - const Message* GetMessage() const { - return &GetReflection()->GetRepeatedMessage(*msg_, field_desc_, index_); - } - - int64_t GetEnumValue() const { - return GetReflection()->GetRepeatedEnumValue(*msg_, field_desc_, index_); - } - - const Reflection* GetReflection() const { return msg_->GetReflection(); } - - private: - int index_; -}; - -// Accessor class, to work with map values -class MapValueAccessor : public FieldAccessor { - public: - MapValueAccessor(const Message* msg, const FieldDescriptor* field_desc, - const MapValueRef* value_ref) - : FieldAccessor(msg, field_desc), value_ref_(value_ref) {} - - bool GetBool() const { return value_ref_->GetBoolValue(); } - - int64_t GetInt32() const { return value_ref_->GetInt32Value(); } - - uint64_t GetUInt32() const { return value_ref_->GetUInt32Value(); } - - int64_t GetInt64() const { return value_ref_->GetInt64Value(); } - - uint64_t GetUInt64() const { return value_ref_->GetUInt64Value(); } - - double GetFloat() const { return value_ref_->GetFloatValue(); } - - double GetDouble() const { return value_ref_->GetDoubleValue(); } - - const std::string* GetString(std::string* /*buffer*/) const { - return &value_ref_->GetStringValue(); - } - - const Message* GetMessage() const { return &value_ref_->GetMessageValue(); } - - int64_t GetEnumValue() const { return value_ref_->GetEnumValue(); } - - const Reflection* GetReflection() const { return msg_->GetReflection(); } - - private: - const MapValueRef* value_ref_; -}; - -// Helper classes that should retrieve values from CelValue, -// when CelValue content inherits from Message. -template -class MessageRetriever { - public: - absl::optional operator()(const T&) const { return {}; } -}; - -// Partial specialization, valid when T is assignable to message -// -template -class MessageRetriever { - public: - absl::optional operator()(const T& arg) const { - const Message* msg = arg; - return msg; - } -}; - -class MessageRetrieverOp { - public: - template - absl::optional operator()(const T& arg) { - // Metaprogramming hacks... - return MessageRetriever::value>()( - arg); - } -}; -} // namespace +absl::Status CreateValueFromSingleField(const google::protobuf::Message* msg, + const FieldDescriptor* desc, + google::protobuf::Arena* arena, + CelValue* result) { + return CreateValueFromSingleField( + msg, desc, ProtoWrapperTypeOptions::kUnsetProtoDefault, arena, result); +} absl::Status CreateValueFromSingleField(const google::protobuf::Message* msg, const FieldDescriptor* desc, + ProtoWrapperTypeOptions options, google::protobuf::Arena* arena, CelValue* result) { - ScalarFieldAccessor accessor(msg, desc); - return accessor.CreateValueFromFieldAccessor(arena, result); + CEL_ASSIGN_OR_RETURN( + *result, + internal::CreateValueFromSingleField( + msg, desc, options, &CelProtoWrapper::InternalWrapMessage, arena)); + return absl::OkStatus(); } absl::Status CreateValueFromRepeatedField(const google::protobuf::Message* msg, const FieldDescriptor* desc, google::protobuf::Arena* arena, int index, CelValue* result) { - RepeatedFieldAccessor accessor(msg, desc, index); - return accessor.CreateValueFromFieldAccessor(arena, result); + CEL_ASSIGN_OR_RETURN( + *result, + internal::CreateValueFromRepeatedField( + msg, desc, index, &CelProtoWrapper::InternalWrapMessage, arena)); + return absl::OkStatus(); } absl::Status CreateValueFromMapValue(const google::protobuf::Message* msg, const FieldDescriptor* desc, - const MapValueRef* value_ref, + const MapValueConstRef* value_ref, google::protobuf::Arena* arena, CelValue* result) { - MapValueAccessor accessor(msg, desc, value_ref); - return accessor.CreateValueFromFieldAccessor(arena, result); + CEL_ASSIGN_OR_RETURN( + *result, + internal::CreateValueFromMapValue( + msg, desc, value_ref, &CelProtoWrapper::InternalWrapMessage, arena)); + return absl::OkStatus(); } -// Singular message fields and repeated message fields have similar access model -// To provide common approach, we implement field setter classes, based on CRTP. -// FieldAccessor is CRTP base class, specifying Get.. method family. -template -class FieldSetter { - public: - bool AssignBool(const CelValue& cel_value) const { - bool value; - - if (!cel_value.GetValue(&value)) { - return false; - } - static_cast(this)->SetBool(value); - return true; - } - - bool AssignInt32(const CelValue& cel_value) const { - int64_t value; - if (!cel_value.GetValue(&value)) { - return false; - } - static_cast(this)->SetInt32(value); - return true; - } - - bool AssignUInt32(const CelValue& cel_value) const { - uint64_t value; - if (!cel_value.GetValue(&value)) { - return false; - } - static_cast(this)->SetUInt32(value); - return true; - } - - bool AssignInt64(const CelValue& cel_value) const { - int64_t value; - if (!cel_value.GetValue(&value)) { - return false; - } - static_cast(this)->SetInt64(value); - return true; - } - - bool AssignUInt64(const CelValue& cel_value) const { - uint64_t value; - if (!cel_value.GetValue(&value)) { - return false; - } - static_cast(this)->SetUInt64(value); - return true; - } - - bool AssignFloat(const CelValue& cel_value) const { - double value; - if (!cel_value.GetValue(&value)) { - return false; - } - static_cast(this)->SetFloat(value); - return true; - } - - bool AssignDouble(const CelValue& cel_value) const { - double value; - if (!cel_value.GetValue(&value)) { - return false; - } - static_cast(this)->SetDouble(value); - return true; - } - - bool AssignString(const CelValue& cel_value) const { - CelValue::StringHolder value; - if (!cel_value.GetValue(&value)) { - return false; - } - static_cast(this)->SetString(value); - return true; - } - - bool AssignBytes(const CelValue& cel_value) const { - CelValue::BytesHolder value; - if (!cel_value.GetValue(&value)) { - return false; - } - static_cast(this)->SetBytes(value); - return true; - } - - bool AssignEnum(const CelValue& cel_value) const { - int64_t value; - if (!cel_value.GetValue(&value)) { - return false; - } - static_cast(this)->SetEnum(value); - return true; - } - - bool AssignMessage(const CelValue& cel_value) const { - // We attempt to retrieve value if it derives from google::protobuf::Message. - // That includes both generic Protobuf message types and specific - // message types stored in CelValue as separate entities. - auto value = cel_value.template Visit>( - MessageRetrieverOp()); - - if (!value.has_value()) { - GOOGLE_LOG(ERROR) << "Has No Value"; - return false; - } - - static_cast(this)->SetMessage(value.value()); - return true; - } - - bool AssignDuration(const CelValue& cel_value) const { - absl::Duration d; - if (!cel_value.GetValue(&d)) { - GOOGLE_LOG(ERROR) << "Unable to retrieve duration"; - return false; - } - google::protobuf::Duration duration; - google::api::expr::internal::EncodeDuration(d, &duration); - static_cast(this)->SetMessage(&duration); - return true; - } - - bool AssignTimestamp(const CelValue& cel_value) const { - absl::Time t; - if (!cel_value.GetValue(&t)) { - GOOGLE_LOG(ERROR) << "Unable to retrieve timestamp"; - return false; - } - google::protobuf::Timestamp timestamp; - google::api::expr::internal::EncodeTime(t, ×tamp); - static_cast(this)->SetMessage(×tamp); - return true; - } - - // This method provides message field content, wrapped in CelValue. - // If value provided successfully, returns Ok. - // arena Arena to use for allocations if needed. - // result pointer to object to store value in. - bool SetFieldFromCelValue(const CelValue& value) { - switch (field_desc_->cpp_type()) { - case FieldDescriptor::CPPTYPE_BOOL: { - return AssignBool(value); - } - case FieldDescriptor::CPPTYPE_INT32: { - return AssignInt32(value); - } - case FieldDescriptor::CPPTYPE_INT64: { - return AssignInt64(value); - } - case FieldDescriptor::CPPTYPE_UINT32: { - return AssignUInt32(value); - } - case FieldDescriptor::CPPTYPE_UINT64: { - return AssignUInt64(value); - } - case FieldDescriptor::CPPTYPE_FLOAT: { - return AssignFloat(value); - } - case FieldDescriptor::CPPTYPE_DOUBLE: { - return AssignDouble(value); - } - case FieldDescriptor::CPPTYPE_STRING: { - switch (field_desc_->type()) { - case FieldDescriptor::TYPE_STRING: - - return AssignString(value); - case FieldDescriptor::TYPE_BYTES: - return AssignBytes(value); - default: - return false; - } - break; - } - case FieldDescriptor::CPPTYPE_MESSAGE: { - const std::string& type_name = field_desc_->message_type()->full_name(); - // When the field is a message, it might be a well-known type with a - // non-proto representation that requires special handling before it - // can be set on the field. - if (type_name == kProtobufTimestamp) { - return AssignTimestamp(value); - } else if (type_name == kProtobufDuration) { - return AssignDuration(value); - } - return AssignMessage(value); - } - case FieldDescriptor::CPPTYPE_ENUM: { - return AssignEnum(value); - } - default: - return false; - } - - return true; - } - - protected: - FieldSetter(Message* msg, const FieldDescriptor* field_desc) - : msg_(msg), field_desc_(field_desc) {} - - Message* msg_; - const FieldDescriptor* field_desc_; -}; - -// Accessor class, to work with singular fields -class ScalarFieldSetter : public FieldSetter { - public: - ScalarFieldSetter(Message* msg, const FieldDescriptor* field_desc) - : FieldSetter(msg, field_desc) {} - - bool SetBool(bool value) const { - GetReflection()->SetBool(msg_, field_desc_, value); - return true; - } - - bool SetInt32(int32_t value) const { - GetReflection()->SetInt32(msg_, field_desc_, value); - return true; - } - - bool SetUInt32(uint32_t value) const { - GetReflection()->SetUInt32(msg_, field_desc_, value); - return true; - } - - bool SetInt64(int64_t value) const { - GetReflection()->SetInt64(msg_, field_desc_, value); - return true; - } - - bool SetUInt64(uint64_t value) const { - GetReflection()->SetUInt64(msg_, field_desc_, value); - return true; - } - - bool SetFloat(float value) const { - GetReflection()->SetFloat(msg_, field_desc_, value); - return true; - } - - bool SetDouble(double value) const { - GetReflection()->SetDouble(msg_, field_desc_, value); - return true; - } - - bool SetString(CelValue::StringHolder value) const { - GetReflection()->SetString(msg_, field_desc_, std::string(value.value())); - return true; - } - - bool SetBytes(CelValue::BytesHolder value) const { - GetReflection()->SetString(msg_, field_desc_, std::string(value.value())); - return true; - } - - bool SetMessage(const Message* value) const { - if (!value) { - GOOGLE_LOG(ERROR) << "Message is NULL"; - return true; - } - - if (value->GetDescriptor()->full_name() == - field_desc_->message_type()->full_name()) { - GetReflection()->MutableMessage(msg_, field_desc_)->MergeFrom(*value); - return true; - - } else if (field_desc_->message_type()->full_name() == kProtobufAny) { - auto any_msg = google::protobuf::DynamicCastToGenerated( - GetReflection()->MutableMessage(msg_, field_desc_)); - if (any_msg == nullptr) { - // TODO(issues/68): This is probably a dynamic message. We should - // implement this once we add support for dynamic protobuf types. - return false; - } - any_msg->set_type_url(absl::StrCat(kTypeGoogleApisComPrefix, - value->GetDescriptor()->full_name())); - return value->SerializeToString(any_msg->mutable_value()); - } - return false; - } - - bool SetEnum(const int64_t value) const { - GetReflection()->SetEnumValue(msg_, field_desc_, value); - return true; - } - - const Reflection* GetReflection() const { return msg_->GetReflection(); } -}; - -// Appender class, to work with repeated fields -class RepeatedFieldSetter : public FieldSetter { - public: - RepeatedFieldSetter(Message* msg, const FieldDescriptor* field_desc) - : FieldSetter(msg, field_desc) {} - - bool SetBool(bool value) const { - GetReflection()->AddBool(msg_, field_desc_, value); - return true; - } - - bool SetInt32(int32_t value) const { - GetReflection()->AddInt32(msg_, field_desc_, value); - return true; - } - - bool SetUInt32(uint32_t value) const { - GetReflection()->AddUInt32(msg_, field_desc_, value); - return true; - } - - bool SetInt64(int64_t value) const { - GetReflection()->AddInt64(msg_, field_desc_, value); - return true; - } - - bool SetUInt64(uint64_t value) const { - GetReflection()->AddUInt64(msg_, field_desc_, value); - return true; - } - - bool SetFloat(float value) const { - GetReflection()->AddFloat(msg_, field_desc_, value); - return true; - } - - bool SetDouble(double value) const { - GetReflection()->AddDouble(msg_, field_desc_, value); - return true; - } - - bool SetString(CelValue::StringHolder value) const { - GetReflection()->AddString(msg_, field_desc_, std::string(value.value())); - return true; - } - - bool SetBytes(CelValue::BytesHolder value) const { - GetReflection()->AddString(msg_, field_desc_, std::string(value.value())); - return true; - } - - bool SetMessage(const Message* value) const { - if (!value) return true; - if (value->GetDescriptor()->full_name() != - field_desc_->message_type()->full_name()) { - return false; - } - - GetReflection()->AddMessage(msg_, field_desc_)->MergeFrom(*value); - return true; - } - - bool SetEnum(const int64_t value) const { - GetReflection()->AddEnumValue(msg_, field_desc_, value); - return true; - } - - private: - const Reflection* GetReflection() const { return msg_->GetReflection(); } -}; - -// This method sets message field -// If value provided successfully, returns Ok. -// arena Arena to use for allocations if needed. -// result pointer to object to store value in. absl::Status SetValueToSingleField(const CelValue& value, - const FieldDescriptor* desc, Message* msg) { - ScalarFieldSetter setter(msg, desc); - return (setter.SetFieldFromCelValue(value)) - ? absl::OkStatus() - : absl::InvalidArgumentError(absl::Substitute( - "Could not assign supplied argument to message \"$0\" field " - "\"$1\" of type $2: value was \"$3\"", - msg->GetDescriptor()->name(), desc->name(), - desc->type_name(), value.DebugString())); + const FieldDescriptor* desc, Message* msg, + Arena* arena) { + return internal::SetValueToSingleField(value, desc, msg, arena); } absl::Status AddValueToRepeatedField(const CelValue& value, - const FieldDescriptor* desc, - Message* msg) { - RepeatedFieldSetter setter(msg, desc); - return (setter.SetFieldFromCelValue(value)) - ? absl::OkStatus() - : absl::InvalidArgumentError(absl::Substitute( - "Could not add supplied argument \"$2\" to message \"$0\" " - "field \"$1\".", - msg->GetDescriptor()->name(), desc->name(), - value.DebugString())); -} - -absl::Status AddValueToMapField(const CelValue& key, const CelValue& value, - const FieldDescriptor* desc, Message* msg) { - auto entry_msg = msg->GetReflection()->AddMessage(msg, desc); - auto key_field_desc = entry_msg->GetDescriptor()->FindFieldByNumber(1); - auto value_field_desc = entry_msg->GetDescriptor()->FindFieldByNumber(2); - - ScalarFieldSetter key_setter(entry_msg, key_field_desc); - ScalarFieldSetter value_setter(entry_msg, value_field_desc); - - if (!key_setter.SetFieldFromCelValue(key)) { - return absl::InvalidArgumentError(absl::Substitute( - "Could not assign supplied argument \"$2\" to message " - "\"$0\" field \"$1\" map key.", - msg->GetDescriptor()->name(), desc->name(), key.DebugString())); - } - - if (!value_setter.SetFieldFromCelValue(value)) { - return absl::InvalidArgumentError(absl::Substitute( - "Could not assign supplied argument \"$2\" to message \"$0\" " - "field \"$1\" map value.", - msg->GetDescriptor()->name(), desc->name(), value.DebugString())); - } - - return absl::OkStatus(); + const FieldDescriptor* desc, Message* msg, + Arena* arena) { + return internal::AddValueToRepeatedField(value, desc, msg, arena); } -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime diff --git a/eval/public/containers/field_access.h b/eval/public/containers/field_access.h index 63ed38369..69d3191dd 100644 --- a/eval/public/containers/field_access.h +++ b/eval/public/containers/field_access.h @@ -1,23 +1,27 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_FIELD_ACCESS_H_ #define THIRD_PARTY_CEL_CPP_EVAL_EVAL_FIELD_ACCESS_H_ +#include "eval/public/cel_options.h" #include "eval/public/cel_value.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace google::api::expr::runtime { // Creates CelValue from singular message field. // Returns status of the operation. // msg Message containing the field. // desc Descriptor of the field to access. +// options Option to enable treating unset wrapper type fields as null. // arena Arena object to allocate result on, if needed. // result pointer to CelValue to store the result in. absl::Status CreateValueFromSingleField(const google::protobuf::Message* msg, const google::protobuf::FieldDescriptor* desc, google::protobuf::Arena* arena, CelValue* result); +absl::Status CreateValueFromSingleField(const google::protobuf::Message* msg, + const google::protobuf::FieldDescriptor* desc, + ProtoWrapperTypeOptions options, + google::protobuf::Arena* arena, CelValue* result); + // Creates CelValue from repeated message field. // Returns status of the operation. // msg Message containing the field. @@ -39,35 +43,28 @@ absl::Status CreateValueFromRepeatedField(const google::protobuf::Message* msg, // result pointer to CelValue to store the result in. absl::Status CreateValueFromMapValue(const google::protobuf::Message* msg, const google::protobuf::FieldDescriptor* desc, - const google::protobuf::MapValueRef* value_ref, + const google::protobuf::MapValueConstRef* value_ref, google::protobuf::Arena* arena, CelValue* result); // Assigns content of CelValue to singular message field. // Returns status of the operation. // msg Message containing the field. // desc Descriptor of the field to access. +// arena Arena to perform allocations, if necessary, when setting the field. absl::Status SetValueToSingleField(const CelValue& value, const google::protobuf::FieldDescriptor* desc, - google::protobuf::Message* msg); + google::protobuf::Message* msg, google::protobuf::Arena* arena); + // Adds content of CelValue to repeated message field. // Returns status of the operation. // msg Message containing the field. // desc Descriptor of the field to access. +// arena Arena to perform allocations, if necessary, when adding the value. absl::Status AddValueToRepeatedField(const CelValue& value, const google::protobuf::FieldDescriptor* desc, - google::protobuf::Message* msg); - -// Adds content of CelValue to repeated message field. -// Returns status of the operation. -// msg Message containing the field. -// desc Descriptor of the field to access. + google::protobuf::Message* msg, + google::protobuf::Arena* arena); -absl::Status AddValueToMapField(const CelValue& key, const CelValue& value, - const google::protobuf::FieldDescriptor* desc, - google::protobuf::Message* msg); +} // namespace google::api::expr::runtime -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google #endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_FIELD_ACCESS_H_ diff --git a/eval/public/containers/field_access_test.cc b/eval/public/containers/field_access_test.cc new file mode 100644 index 000000000..8c0bc0037 --- /dev/null +++ b/eval/public/containers/field_access_test.cc @@ -0,0 +1,284 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/containers/field_access.h" + +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "eval/public/cel_value.h" +#include "eval/public/structs/cel_proto_wrapper.h" +#include "eval/public/testing/matchers.h" +#include "eval/testutil/test_message.pb.h" +#include "internal/testing.h" +#include "internal/time.h" +#include "cel/expr/conformance/proto3/test_all_types.pb.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/message.h" +#include "google/protobuf/text_format.h" + +namespace google::api::expr::runtime { + +namespace { + +using ::absl_testing::StatusIs; +using ::cel::expr::conformance::proto3::TestAllTypes; +using ::cel::internal::MaxDuration; +using ::cel::internal::MaxTimestamp; +using ::google::protobuf::Arena; +using ::google::protobuf::FieldDescriptor; +using ::testing::HasSubstr; + +TEST(FieldAccessTest, SetDuration) { + Arena arena; + TestAllTypes msg; + const FieldDescriptor* field = + TestAllTypes::descriptor()->FindFieldByName("single_duration"); + auto status = SetValueToSingleField(CelValue::CreateDuration(MaxDuration()), + field, &msg, &arena); + EXPECT_TRUE(status.ok()); +} + +TEST(FieldAccessTest, SetDurationBadDuration) { + Arena arena; + TestAllTypes msg; + const FieldDescriptor* field = + TestAllTypes::descriptor()->FindFieldByName("single_duration"); + auto status = SetValueToSingleField( + CelValue::CreateDuration(MaxDuration() + absl::Seconds(1)), field, &msg, + &arena); + EXPECT_FALSE(status.ok()); + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); +} + +TEST(FieldAccessTest, SetDurationBadInputType) { + Arena arena; + TestAllTypes msg; + const FieldDescriptor* field = + TestAllTypes::descriptor()->FindFieldByName("single_duration"); + auto status = + SetValueToSingleField(CelValue::CreateInt64(1), field, &msg, &arena); + EXPECT_FALSE(status.ok()); + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); +} + +TEST(FieldAccessTest, SetTimestamp) { + Arena arena; + TestAllTypes msg; + const FieldDescriptor* field = + TestAllTypes::descriptor()->FindFieldByName("single_timestamp"); + auto status = SetValueToSingleField(CelValue::CreateTimestamp(MaxTimestamp()), + field, &msg, &arena); + EXPECT_TRUE(status.ok()); +} + +TEST(FieldAccessTest, SetTimestampBadTime) { + Arena arena; + TestAllTypes msg; + const FieldDescriptor* field = + TestAllTypes::descriptor()->FindFieldByName("single_timestamp"); + auto status = SetValueToSingleField( + CelValue::CreateTimestamp(MaxTimestamp() + absl::Seconds(1)), field, &msg, + &arena); + EXPECT_FALSE(status.ok()); + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); +} + +TEST(FieldAccessTest, SetTimestampBadInputType) { + Arena arena; + TestAllTypes msg; + const FieldDescriptor* field = + TestAllTypes::descriptor()->FindFieldByName("single_timestamp"); + auto status = + SetValueToSingleField(CelValue::CreateInt64(1), field, &msg, &arena); + EXPECT_FALSE(status.ok()); + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); +} + +TEST(FieldAccessTest, SetInt32Overflow) { + Arena arena; + TestAllTypes msg; + const FieldDescriptor* field = + TestAllTypes::descriptor()->FindFieldByName("single_int32"); + EXPECT_THAT( + SetValueToSingleField( + CelValue::CreateInt64(std::numeric_limits::max() + 1L), + field, &msg, &arena), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Could not assign"))); +} + +TEST(FieldAccessTest, SetUint32Overflow) { + Arena arena; + TestAllTypes msg; + const FieldDescriptor* field = + TestAllTypes::descriptor()->FindFieldByName("single_uint32"); + EXPECT_THAT( + SetValueToSingleField( + CelValue::CreateUint64(std::numeric_limits::max() + 1L), + field, &msg, &arena), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Could not assign"))); +} + +TEST(FieldAccessTest, SetMessage) { + Arena arena; + TestAllTypes msg; + const FieldDescriptor* field = + TestAllTypes::descriptor()->FindFieldByName("standalone_message"); + TestAllTypes::NestedMessage* nested_msg = + google::protobuf::Arena::Create(&arena); + nested_msg->set_bb(1); + auto status = SetValueToSingleField( + CelProtoWrapper::CreateMessage(nested_msg, &arena), field, &msg, &arena); + EXPECT_TRUE(status.ok()); +} + +TEST(FieldAccessTest, SetMessageWithNul) { + Arena arena; + TestAllTypes msg; + const FieldDescriptor* field = + TestAllTypes::descriptor()->FindFieldByName("standalone_message"); + auto status = + SetValueToSingleField(CelValue::CreateNull(), field, &msg, &arena); + EXPECT_TRUE(status.ok()); +} + +constexpr std::array kWrapperFieldNames = { + "single_bool_wrapper", "single_int64_wrapper", "single_int32_wrapper", + "single_uint64_wrapper", "single_uint32_wrapper", "single_double_wrapper", + "single_float_wrapper", "single_string_wrapper", "single_bytes_wrapper"}; + +// Unset wrapper type fields are treated as null if accessed after option +// enabled. +TEST(CreateValueFromFieldTest, UnsetWrapperTypesNullIfEnabled) { + CelValue result; + TestAllTypes test_message; + google::protobuf::Arena arena; + + for (const auto& field : kWrapperFieldNames) { + ASSERT_OK(CreateValueFromSingleField( + &test_message, TestAllTypes::GetDescriptor()->FindFieldByName(field), + ProtoWrapperTypeOptions::kUnsetNull, &arena, &result)) + << field; + ASSERT_TRUE(result.IsNull()) << field << ": " << result.DebugString(); + } +} + +// Unset wrapper type fields are treated as proto default under old +// behavior. +TEST(CreateValueFromFieldTest, UnsetWrapperTypesDefaultValueIfDisabled) { + CelValue result; + TestAllTypes test_message; + google::protobuf::Arena arena; + + for (const auto& field : kWrapperFieldNames) { + ASSERT_OK(CreateValueFromSingleField( + &test_message, TestAllTypes::GetDescriptor()->FindFieldByName(field), + ProtoWrapperTypeOptions::kUnsetProtoDefault, &arena, &result)) + << field; + ASSERT_FALSE(result.IsNull()) << field << ": " << result.DebugString(); + } +} + +// If a wrapper type is set to default value, the corresponding CelValue is the +// proto default value. +TEST(CreateValueFromFieldTest, SetWrapperTypesDefaultValue) { + CelValue result; + TestAllTypes test_message; + google::protobuf::Arena arena; + + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb( + single_bool_wrapper {} + single_int64_wrapper {} + single_int32_wrapper {} + single_uint64_wrapper {} + single_uint32_wrapper {} + single_double_wrapper {} + single_float_wrapper {} + single_string_wrapper {} + single_bytes_wrapper {} + )pb", + &test_message)); + + ASSERT_OK(CreateValueFromSingleField( + &test_message, + TestAllTypes::GetDescriptor()->FindFieldByName("single_bool_wrapper"), + ProtoWrapperTypeOptions::kUnsetNull, &arena, &result)); + EXPECT_THAT(result, test::IsCelBool(false)); + + ASSERT_OK(CreateValueFromSingleField( + &test_message, + TestAllTypes::GetDescriptor()->FindFieldByName("single_int64_wrapper"), + ProtoWrapperTypeOptions::kUnsetNull, &arena, &result)); + EXPECT_THAT(result, test::IsCelInt64(0)); + + ASSERT_OK(CreateValueFromSingleField( + &test_message, + TestAllTypes::GetDescriptor()->FindFieldByName("single_int32_wrapper"), + ProtoWrapperTypeOptions::kUnsetNull, &arena, &result)); + EXPECT_THAT(result, test::IsCelInt64(0)); + + ASSERT_OK(CreateValueFromSingleField( + &test_message, + TestAllTypes::GetDescriptor()->FindFieldByName("single_uint64_wrapper"), + ProtoWrapperTypeOptions::kUnsetNull, &arena, &result)); + EXPECT_THAT(result, test::IsCelUint64(0)); + + ASSERT_OK(CreateValueFromSingleField( + &test_message, + TestAllTypes::GetDescriptor()->FindFieldByName("single_uint32_wrapper"), + ProtoWrapperTypeOptions::kUnsetNull, &arena, &result)); + EXPECT_THAT(result, test::IsCelUint64(0)); + + ASSERT_OK(CreateValueFromSingleField( + &test_message, + TestAllTypes::GetDescriptor()->FindFieldByName("single_double_wrapper"), + ProtoWrapperTypeOptions::kUnsetNull, + + &arena, &result)); + EXPECT_THAT(result, test::IsCelDouble(0.0f)); + + ASSERT_OK(CreateValueFromSingleField( + &test_message, + TestAllTypes::GetDescriptor()->FindFieldByName("single_float_wrapper"), + ProtoWrapperTypeOptions::kUnsetNull, + + &arena, &result)); + EXPECT_THAT(result, test::IsCelDouble(0.0f)); + + ASSERT_OK(CreateValueFromSingleField( + &test_message, + TestAllTypes::GetDescriptor()->FindFieldByName("single_string_wrapper"), + ProtoWrapperTypeOptions::kUnsetNull, + + &arena, &result)); + EXPECT_THAT(result, test::IsCelString("")); + + ASSERT_OK(CreateValueFromSingleField( + &test_message, + TestAllTypes::GetDescriptor()->FindFieldByName("single_bytes_wrapper"), + ProtoWrapperTypeOptions::kUnsetNull, + + &arena, &result)); + EXPECT_THAT(result, test::IsCelBytes("")); +} + +} // namespace + +} // namespace google::api::expr::runtime diff --git a/eval/public/containers/field_backed_list_impl.cc b/eval/public/containers/field_backed_list_impl.cc deleted file mode 100644 index 2fa86c272..000000000 --- a/eval/public/containers/field_backed_list_impl.cc +++ /dev/null @@ -1,30 +0,0 @@ - -#include "eval/public/containers/field_backed_list_impl.h" - -#include "eval/public/cel_value.h" -#include "eval/public/containers/field_access.h" - -namespace google { -namespace api { -namespace expr { -namespace runtime { - -int FieldBackedListImpl::size() const { - return reflection_->FieldSize(*message_, descriptor_); -} - -CelValue FieldBackedListImpl::operator[](int index) const { - CelValue result = CelValue::CreateNull(); - auto status = CreateValueFromRepeatedField(message_, descriptor_, arena_, - index, &result); - if (!status.ok()) { - result = CreateErrorValue(arena_, status.ToString()); - } - - return result; -} - -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google diff --git a/eval/public/containers/field_backed_list_impl.h b/eval/public/containers/field_backed_list_impl.h index ac330850c..39f654764 100644 --- a/eval/public/containers/field_backed_list_impl.h +++ b/eval/public/containers/field_backed_list_impl.h @@ -2,6 +2,8 @@ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CONTAINERS_FIELD_BACKED_LIST_IMPL_H_ #include "eval/public/cel_value.h" +#include "eval/public/containers/internal_field_backed_list_impl.h" +#include "eval/public/structs/cel_proto_wrapper.h" namespace google { namespace api { @@ -10,29 +12,17 @@ namespace runtime { // CelList implementation that uses "repeated" message field // as backing storage. -class FieldBackedListImpl : public CelList { +class FieldBackedListImpl : public internal::FieldBackedListImpl { public: // message contains the "repeated" field // descriptor FieldDescriptor for the field + // arena is used for incidental allocations when unwrapping the field. FieldBackedListImpl(const google::protobuf::Message* message, const google::protobuf::FieldDescriptor* descriptor, google::protobuf::Arena* arena) - : message_(message), - descriptor_(descriptor), - reflection_(message_->GetReflection()), - arena_(arena) {} - - // List size. - int size() const override; - - // List element access operator. - CelValue operator[](int index) const override; - - private: - const google::protobuf::Message* message_; - const google::protobuf::FieldDescriptor* descriptor_; - const google::protobuf::Reflection* reflection_; - google::protobuf::Arena* arena_; + : internal::FieldBackedListImpl( + message, descriptor, &CelProtoWrapper::InternalWrapMessage, arena) { + } }; } // namespace runtime diff --git a/eval/public/containers/field_backed_list_impl_test.cc b/eval/public/containers/field_backed_list_impl_test.cc index 6ad711019..10caa45de 100644 --- a/eval/public/containers/field_backed_list_impl_test.cc +++ b/eval/public/containers/field_backed_list_impl_test.cc @@ -1,8 +1,10 @@ #include "eval/public/containers/field_backed_list_impl.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" +#include +#include + #include "eval/testutil/test_message.pb.h" +#include "internal/testing.h" #include "testutil/util.h" namespace google { @@ -11,8 +13,8 @@ namespace expr { namespace runtime { namespace { -using testing::Eq; -using testing::DoubleEq; +using ::testing::Eq; +using ::testing::DoubleEq; using testutil::EqualsProto; @@ -23,7 +25,7 @@ std::unique_ptr CreateList(const TestMessage* message, const google::protobuf::FieldDescriptor* field_desc = message->GetDescriptor()->FindFieldByName(field); - return absl::make_unique(message, field_desc, arena); + return std::make_unique(message, field_desc, arena); } TEST(FieldBackedListImplTest, BoolDatatypeTest) { @@ -185,7 +187,6 @@ TEST(FieldBackedListImplTest, StringDatatypeTest) { EXPECT_EQ((*cel_list)[1].StringOrDie().value(), "2"); } - TEST(FieldBackedListImplTest, BytesDatatypeTest) { TestMessage message; message.add_bytes_list("1"); diff --git a/eval/public/containers/field_backed_map_impl.cc b/eval/public/containers/field_backed_map_impl.cc deleted file mode 100644 index 9cf6c5b12..000000000 --- a/eval/public/containers/field_backed_map_impl.cc +++ /dev/null @@ -1,238 +0,0 @@ -#include "eval/public/containers/field_backed_map_impl.h" - -#include "google/protobuf/map_field.h" -#include "eval/public/cel_value.h" -#include "eval/public/containers/field_access.h" - -#ifdef GOOGLE_PROTOBUF_HAS_CEL_MAP_REFLECTION_FRIEND - -namespace google { -namespace protobuf { -namespace expr { - -// CelMapReflectionFriend provides access to Reflection's private methods. The -// class is a friend of google::protobuf::Reflection. We do not add FieldBackedMapImpl as -// a friend directly, because it belongs to google:: namespace. The build of -// protobuf fails on MSVC if this namespace is used, probably because -// of macros usage. -class CelMapReflectionFriend { - public: - static bool ContainsMapKey(const Reflection* reflection, - const Message& message, - const FieldDescriptor* field, const MapKey& key) { - return reflection->ContainsMapKey(message, field, key); - } - - static bool InsertOrLookupMapValue(const Reflection* reflection, - Message* message, - const FieldDescriptor* field, - const MapKey& key, MapValueRef* val) { - return reflection->InsertOrLookupMapValue(message, field, key, val); - } -}; - -} // namespace expr -} // namespace protobuf -} // namespace google - -#endif // GOOGLE_PROTOBUF_HAS_CEL_MAP_REFLECTION_FRIEND - -namespace google { -namespace api { -namespace expr { -namespace runtime { - -namespace { -using google::protobuf::Arena; -using google::protobuf::Descriptor; -using google::protobuf::FieldDescriptor; -using google::protobuf::MapValueRef; -using google::protobuf::Message; - -// Map entries have two field tags -// 1 - for key -// 2 - for value -constexpr int kKeyTag = 1; -constexpr int kValueTag = 2; - -class KeyList : public CelList { - public: - // message contains the "repeated" field - // descriptor FieldDescriptor for the field - KeyList(const google::protobuf::Message* message, - const google::protobuf::FieldDescriptor* descriptor, google::protobuf::Arena* arena) - : message_(message), - descriptor_(descriptor), - reflection_(message_->GetReflection()), - arena_(arena) {} - - // List size. - int size() const override { - return reflection_->FieldSize(*message_, descriptor_); - } - - // List element access operator. - CelValue operator[](int index) const override { - CelValue key = CelValue::CreateNull(); - const Message* entry = - &reflection_->GetRepeatedMessage(*message_, descriptor_, index); - - if (entry == nullptr) { - return CelValue::CreateNull(); - } - - const Descriptor* entry_descriptor = entry->GetDescriptor(); - // Key Tag == 1 - const FieldDescriptor* key_desc = - entry_descriptor->FindFieldByNumber(kKeyTag); - - auto status = CreateValueFromSingleField(entry, key_desc, arena_, &key); - if (!status.ok()) { - return CreateErrorValue(arena_, status.message()); - } - return key; - } - - private: - const google::protobuf::Message* message_; - const google::protobuf::FieldDescriptor* descriptor_; - const google::protobuf::Reflection* reflection_; - google::protobuf::Arena* arena_; -}; - -} // namespace - -FieldBackedMapImpl::FieldBackedMapImpl( - const google::protobuf::Message* message, const google::protobuf::FieldDescriptor* descriptor, - google::protobuf::Arena* arena) - : message_(message), - descriptor_(descriptor), - reflection_(message_->GetReflection()), - arena_(arena), - key_list_(absl::make_unique(message, descriptor, arena)) {} - -int FieldBackedMapImpl::size() const { - return reflection_->FieldSize(*message_, descriptor_); -} - -const CelList* FieldBackedMapImpl::ListKeys() const { return key_list_.get(); } - -absl::optional FieldBackedMapImpl::operator[](CelValue key) const { -#ifdef GOOGLE_PROTOBUF_HAS_CEL_MAP_REFLECTION_FRIEND - // Fast implementation. - google::protobuf::MapKey inner_key; - switch (key.type()) { - case CelValue::Type::kInt64: { - inner_key.SetInt64Value(key.Int64OrDie()); - break; - } - case CelValue::Type::kUint64: { - inner_key.SetUInt64Value(key.Uint64OrDie()); - break; - } - case CelValue::Type::kString: { - auto str = key.StringOrDie().value(); - inner_key.SetStringValue(std::string(str.begin(), str.end())); - break; - } - default: { return {}; } - } - // Performance issue. Currently the only way to do a lookup is - // InsertOrLookupMapValue. This function will modify the map if the key - // doesn't exist, that is why we have to call ContainsMapKey first, which - // results in hashing the key more than once. - if (!google::protobuf::expr::CelMapReflectionFriend::ContainsMapKey( - reflection_, *message_, descriptor_, inner_key)) { - return {}; - } - MapValueRef value_ref; - // InsertOrLookupMapValue is not marked as const (but it is const in this - // scenario when ContainsMapKey returns true), so we use const_cast. - if (google::protobuf::expr::CelMapReflectionFriend::InsertOrLookupMapValue( - reflection_, const_cast(message_), descriptor_, - inner_key, &value_ref)) { - GOOGLE_LOG(ERROR) << "The map was expected to have the key, but it didn't."; - } - // Get value descriptor treating it as a repeated field. - // All values in protobuf map have the same type. - // The map is not empty, because ContainsMapKey returned true. - const Message* entry = - &reflection_->GetRepeatedMessage(*message_, descriptor_, 0); - if (entry == nullptr) { - return {}; - } - const Descriptor* entry_descriptor = entry->GetDescriptor(); - const FieldDescriptor* value_desc = - entry_descriptor->FindFieldByNumber(kValueTag); - - CelValue result = CelValue::CreateNull(); - auto status = CreateValueFromMapValue(message_, value_desc, &value_ref, - arena_, &result); - if (!status.ok()) { - return CreateErrorValue(arena_, status.message()); - } - return result; -#else // GOOGLE_PROTOBUF_HAS_CEL_MAP_REFLECTION_FRIEND - // Slow implementation. - CelValue result = CelValue::CreateNull(); - CelValue inner_key = CelValue::CreateNull(); - - int map_size = size(); - for (int i = 0; i < map_size; i++) { - const Message* entry = - &reflection_->GetRepeatedMessage(*message_, descriptor_, i); - - if (entry == nullptr) continue; - - const Descriptor* entry_descriptor = entry->GetDescriptor(); - // Key Tag == 1 - const FieldDescriptor* key_desc = - entry_descriptor->FindFieldByNumber(kKeyTag); - - auto status = - CreateValueFromSingleField(entry, key_desc, arena_, &inner_key); - if (!status.ok()) { - return CreateErrorValue(arena_, status.ToString()); - } - - if (key.type() != inner_key.type()) { - continue; - } - - bool match = false; - switch (key.type()) { - case CelValue::Type::kInt64: - match = key.Int64OrDie() == inner_key.Int64OrDie(); - break; - case CelValue::Type::kUint64: - match = key.Uint64OrDie() == inner_key.Uint64OrDie(); - break; - case CelValue::Type::kString: - match = key.StringOrDie() == inner_key.StringOrDie(); - break; - default: - match = false; - } - - if (match) { - const FieldDescriptor* value_desc = - entry_descriptor->FindFieldByNumber(kValueTag); - - auto status = - CreateValueFromSingleField(entry, value_desc, arena_, &result); - if (!status.ok()) { - return CreateErrorValue(arena_, status.message()); - } - - return result; - } - } - - return {}; -#endif // GOOGLE_PROTOBUF_HAS_CEL_MAP_REFLECTION_FRIEND -} - -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google diff --git a/eval/public/containers/field_backed_map_impl.h b/eval/public/containers/field_backed_map_impl.h index 2f1d7ad47..71452ef68 100644 --- a/eval/public/containers/field_backed_map_impl.h +++ b/eval/public/containers/field_backed_map_impl.h @@ -1,44 +1,35 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CONTAINERS_FIELD_BACKED_MAP_IMPL_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CONTAINERS_FIELD_BACKED_MAP_IMPL_H_ +#include "absl/status/statusor.h" #include "eval/public/cel_value.h" +#include "eval/public/containers/internal_field_backed_map_impl.h" +#include "eval/public/structs/cel_proto_wrapper.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace google::api::expr::runtime { // CelMap implementation that uses "map" message field // as backing storage. -class FieldBackedMapImpl : public CelMap { +// +// Trivial subclass of internal implementation to avoid API changes for clients +// that use this directly. +class FieldBackedMapImpl : public internal::FieldBackedMapImpl { public: // message contains the "map" field. Object stores the pointer // to the message, thus it is expected that message outlives the // object. // descriptor FieldDescriptor for the field + // arena is used for incidental allocations from unpacking the field. FieldBackedMapImpl(const google::protobuf::Message* message, const google::protobuf::FieldDescriptor* descriptor, - google::protobuf::Arena* arena); - - // Map size. - int size() const override; - - // Map element access operator. - absl::optional operator[](CelValue key) const override; - - const CelList* ListKeys() const override; - - private: - const google::protobuf::Message* message_; - const google::protobuf::FieldDescriptor* descriptor_; - const google::protobuf::Reflection* reflection_; - google::protobuf::Arena* arena_; - std::unique_ptr key_list_; + google::protobuf::Arena* arena) + : internal::FieldBackedMapImpl( + message, descriptor, &CelProtoWrapper::InternalWrapMessage, arena) { + } }; -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CONTAINERS_FIELD_BACKED_MAP_IMPL_H_ diff --git a/eval/public/containers/field_backed_map_impl_test.cc b/eval/public/containers/field_backed_map_impl_test.cc index 3e2f3f08e..4c75149ce 100644 --- a/eval/public/containers/field_backed_map_impl_test.cc +++ b/eval/public/containers/field_backed_map_impl_test.cc @@ -1,58 +1,172 @@ #include "eval/public/containers/field_backed_map_impl.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" +#include +#include +#include +#include +#include + +#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "eval/testutil/test_message.pb.h" +#include "internal/testing.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace google::api::expr::runtime { namespace { -using testing::Eq; -using testing::UnorderedPointwise; +using ::absl_testing::StatusIs; +using ::testing::Eq; +using ::testing::HasSubstr; +using ::testing::UnorderedPointwise; -// Helper method. Creates simple pipeline containing Select step and runs it. -std::unique_ptr CreateMap(const TestMessage* message, - const std::string& field, - google::protobuf::Arena* arena) { +// Test factory for FieldBackedMaps from message and field name. +std::unique_ptr CreateMap(const TestMessage* message, + const std::string& field, + google::protobuf::Arena* arena) { const google::protobuf::FieldDescriptor* field_desc = message->GetDescriptor()->FindFieldByName(field); - return absl::make_unique(message, field_desc, arena); + return std::make_unique(message, field_desc, arena); } -TEST(FieldBackedMapImplTest, IntKeyTest) { +TEST(FieldBackedMapImplTest, BadKeyTypeTest) { TestMessage message; - auto field_map = message.mutable_int64_int32_map(); + google::protobuf::Arena arena; + constexpr std::array map_types = { + "int64_int32_map", "uint64_int32_map", "string_int32_map", + "bool_int32_map", "int32_int32_map", "uint32_uint32_map", + }; + + for (auto map_type : map_types) { + auto cel_map = CreateMap(&message, std::string(map_type), &arena); + // Look up a boolean key. This should result in an error for both the + // presence test and the value lookup. + auto result = cel_map->Has(CelValue::CreateNull()); + EXPECT_FALSE(result.ok()); + EXPECT_THAT(result.status().code(), Eq(absl::StatusCode::kInvalidArgument)); + + EXPECT_FALSE(result.ok()); + EXPECT_THAT(result.status().code(), Eq(absl::StatusCode::kInvalidArgument)); + + auto lookup = (*cel_map)[CelValue::CreateNull()]; + EXPECT_TRUE(lookup.has_value()); + EXPECT_TRUE(lookup->IsError()); + EXPECT_THAT(lookup->ErrorOrDie()->code(), + Eq(absl::StatusCode::kInvalidArgument)); + } +} + +TEST(FieldBackedMapImplTest, Int32KeyTest) { + TestMessage message; + auto field_map = message.mutable_int32_int32_map(); (*field_map)[0] = 1; (*field_map)[1] = 2; google::protobuf::Arena arena; + auto cel_map = CreateMap(&message, "int32_int32_map", &arena); + + EXPECT_EQ((*cel_map)[CelValue::CreateInt64(0)]->Int64OrDie(), 1); + EXPECT_EQ((*cel_map)[CelValue::CreateInt64(1)]->Int64OrDie(), 2); + EXPECT_TRUE(cel_map->Has(CelValue::CreateInt64(1)).value_or(false)); + + // Look up nonexistent key + EXPECT_FALSE((*cel_map)[CelValue::CreateInt64(3)].has_value()); + EXPECT_FALSE(cel_map->Has(CelValue::CreateInt64(3)).value_or(true)); +} + +TEST(FieldBackedMapImplTest, Int32KeyOutOfRangeTest) { + TestMessage message; + google::protobuf::Arena arena; + auto cel_map = CreateMap(&message, "int32_int32_map", &arena); + + // Look up keys out of int32 range + auto result = cel_map->Has( + CelValue::CreateInt64(std::numeric_limits::max() + 1L)); + EXPECT_THAT(result.status(), + StatusIs(absl::StatusCode::kOutOfRange, HasSubstr("overflow"))); + + result = cel_map->Has( + CelValue::CreateInt64(std::numeric_limits::lowest() - 1L)); + EXPECT_FALSE(result.ok()); + EXPECT_THAT(result.status().code(), Eq(absl::StatusCode::kOutOfRange)); +} + +TEST(FieldBackedMapImplTest, Int64KeyTest) { + TestMessage message; + auto field_map = message.mutable_int64_int32_map(); + (*field_map)[0] = 1; + (*field_map)[1] = 2; + google::protobuf::Arena arena; auto cel_map = CreateMap(&message, "int64_int32_map", &arena); EXPECT_EQ((*cel_map)[CelValue::CreateInt64(0)]->Int64OrDie(), 1); EXPECT_EQ((*cel_map)[CelValue::CreateInt64(1)]->Int64OrDie(), 2); + EXPECT_TRUE(cel_map->Has(CelValue::CreateInt64(1)).value_or(false)); // Look up nonexistent key EXPECT_EQ((*cel_map)[CelValue::CreateInt64(3)].has_value(), false); } -TEST(FieldBackedMapImplTest, UintKeyTest) { +TEST(FieldBackedMapImplTest, BoolKeyTest) { + TestMessage message; + auto field_map = message.mutable_bool_int32_map(); + (*field_map)[false] = 1; + + google::protobuf::Arena arena; + auto cel_map = CreateMap(&message, "bool_int32_map", &arena); + + EXPECT_EQ((*cel_map)[CelValue::CreateBool(false)]->Int64OrDie(), 1); + EXPECT_TRUE(cel_map->Has(CelValue::CreateBool(false)).value_or(false)); + // Look up nonexistent key + EXPECT_EQ((*cel_map)[CelValue::CreateBool(true)].has_value(), false); + + (*field_map)[true] = 2; + EXPECT_EQ((*cel_map)[CelValue::CreateBool(true)]->Int64OrDie(), 2); +} + +TEST(FieldBackedMapImplTest, Uint32KeyTest) { + TestMessage message; + auto field_map = message.mutable_uint32_uint32_map(); + (*field_map)[0] = 1u; + (*field_map)[1] = 2u; + + google::protobuf::Arena arena; + auto cel_map = CreateMap(&message, "uint32_uint32_map", &arena); + + EXPECT_EQ((*cel_map)[CelValue::CreateUint64(0)]->Uint64OrDie(), 1UL); + EXPECT_EQ((*cel_map)[CelValue::CreateUint64(1)]->Uint64OrDie(), 2UL); + EXPECT_TRUE(cel_map->Has(CelValue::CreateUint64(1)).value_or(false)); + + // Look up nonexistent key + EXPECT_EQ((*cel_map)[CelValue::CreateUint64(3)].has_value(), false); + EXPECT_EQ(cel_map->Has(CelValue::CreateUint64(3)).value_or(true), false); +} + +TEST(FieldBackedMapImplTest, Uint32KeyOutOfRangeTest) { + TestMessage message; + google::protobuf::Arena arena; + auto cel_map = CreateMap(&message, "uint32_uint32_map", &arena); + + // Look up keys out of uint32 range + auto result = cel_map->Has( + CelValue::CreateUint64(std::numeric_limits::max() + 1UL)); + EXPECT_FALSE(result.ok()); + EXPECT_THAT(result.status().code(), Eq(absl::StatusCode::kOutOfRange)); +} + +TEST(FieldBackedMapImplTest, Uint64KeyTest) { TestMessage message; auto field_map = message.mutable_uint64_int32_map(); (*field_map)[0] = 1; (*field_map)[1] = 2; google::protobuf::Arena arena; - auto cel_map = CreateMap(&message, "uint64_int32_map", &arena); EXPECT_EQ((*cel_map)[CelValue::CreateUint64(0)]->Int64OrDie(), 1); EXPECT_EQ((*cel_map)[CelValue::CreateUint64(1)]->Int64OrDie(), 2); + EXPECT_TRUE(cel_map->Has(CelValue::CreateUint64(1)).value_or(false)); // Look up nonexistent key EXPECT_EQ((*cel_map)[CelValue::CreateUint64(3)].has_value(), false); @@ -65,7 +179,6 @@ TEST(FieldBackedMapImplTest, StringKeyTest) { (*field_map)["test1"] = 2; google::protobuf::Arena arena; - auto cel_map = CreateMap(&message, "string_int32_map", &arena); std::string test0 = "test0"; @@ -74,6 +187,7 @@ TEST(FieldBackedMapImplTest, StringKeyTest) { EXPECT_EQ((*cel_map)[CelValue::CreateString(&test0)]->Int64OrDie(), 1); EXPECT_EQ((*cel_map)[CelValue::CreateString(&test1)]->Int64OrDie(), 2); + EXPECT_TRUE(cel_map->Has(CelValue::CreateString(&test1)).value_or(false)); // Look up nonexistent key EXPECT_EQ((*cel_map)[CelValue::CreateString(&test_notfound)].has_value(), @@ -82,14 +196,8 @@ TEST(FieldBackedMapImplTest, StringKeyTest) { TEST(FieldBackedMapImplTest, EmptySizeTest) { TestMessage message; - google::protobuf::Arena arena; - auto cel_map = CreateMap(&message, "string_int32_map", &arena); - - std::string test0 = "test0"; - std::string test1 = "test1"; - EXPECT_EQ(cel_map->size(), 0); } @@ -101,7 +209,6 @@ TEST(FieldBackedMapImplTest, RepeatedAddTest) { (*field_map)["test0"] = 3; google::protobuf::Arena arena; - auto cel_map = CreateMap(&message, "string_int32_map", &arena); EXPECT_EQ(cel_map->size(), 2); @@ -118,10 +225,8 @@ TEST(FieldBackedMapImplTest, KeyListTest) { } google::protobuf::Arena arena; - auto cel_map = CreateMap(&message, "string_int32_map", &arena); - - const CelList* key_list = cel_map->ListKeys(); + const CelList* key_list = cel_map->ListKeys().value(); EXPECT_EQ(key_list->size(), 100); for (int i = 0; i < key_list->size(); i++) { @@ -132,7 +237,4 @@ TEST(FieldBackedMapImplTest, KeyListTest) { } } // namespace -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime diff --git a/eval/public/containers/internal_field_backed_list_impl.cc b/eval/public/containers/internal_field_backed_list_impl.cc new file mode 100644 index 000000000..6541db468 --- /dev/null +++ b/eval/public/containers/internal_field_backed_list_impl.cc @@ -0,0 +1,36 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/containers/internal_field_backed_list_impl.h" + +#include "eval/public/cel_value.h" +#include "eval/public/structs/field_access_impl.h" + +namespace google::api::expr::runtime::internal { + +int FieldBackedListImpl::size() const { + return reflection_->FieldSize(*message_, descriptor_); +} + +CelValue FieldBackedListImpl::operator[](int index) const { + auto result = CreateValueFromRepeatedField(message_, descriptor_, index, + factory_, arena_); + if (!result.ok()) { + CreateErrorValue(arena_, result.status().ToString()); + } + + return *result; +} + +} // namespace google::api::expr::runtime::internal diff --git a/eval/public/containers/internal_field_backed_list_impl.h b/eval/public/containers/internal_field_backed_list_impl.h new file mode 100644 index 000000000..95f8de425 --- /dev/null +++ b/eval/public/containers/internal_field_backed_list_impl.h @@ -0,0 +1,59 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CONTAINERS_INTERNAL_FIELD_BACKED_LIST_IMPL_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CONTAINERS_INTERNAL_FIELD_BACKED_LIST_IMPL_H_ + +#include + +#include "eval/public/cel_value.h" +#include "eval/public/structs/protobuf_value_factory.h" + +namespace google::api::expr::runtime::internal { + +// CelList implementation that uses "repeated" message field +// as backing storage. +// +// The internal implementation allows for interface updates without breaking +// clients that depend on this class for implementing custom CEL lists +class FieldBackedListImpl : public CelList { + public: + // message contains the "repeated" field + // descriptor FieldDescriptor for the field + FieldBackedListImpl(const google::protobuf::Message* message, + const google::protobuf::FieldDescriptor* descriptor, + ProtobufValueFactory factory, google::protobuf::Arena* arena) + : message_(message), + descriptor_(descriptor), + reflection_(message_->GetReflection()), + factory_(std::move(factory)), + arena_(arena) {} + + // List size. + int size() const override; + + // List element access operator. + CelValue operator[](int index) const override; + + private: + const google::protobuf::Message* message_; + const google::protobuf::FieldDescriptor* descriptor_; + const google::protobuf::Reflection* reflection_; + ProtobufValueFactory factory_; + google::protobuf::Arena* arena_; +}; + +} // namespace google::api::expr::runtime::internal + +#endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CONTAINERS_INTERNAL_FIELD_BACKED_LIST_IMPL_H_ diff --git a/eval/public/containers/internal_field_backed_list_impl_test.cc b/eval/public/containers/internal_field_backed_list_impl_test.cc new file mode 100644 index 000000000..409bad095 --- /dev/null +++ b/eval/public/containers/internal_field_backed_list_impl_test.cc @@ -0,0 +1,252 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/containers/internal_field_backed_list_impl.h" + +#include +#include + +#include "eval/public/structs/cel_proto_wrapper.h" +#include "eval/testutil/test_message.pb.h" +#include "internal/testing.h" +#include "testutil/util.h" + +namespace google::api::expr::runtime::internal { +namespace { + +using ::google::api::expr::testutil::EqualsProto; +using ::testing::DoubleEq; +using ::testing::Eq; + +// Helper method. Creates simple pipeline containing Select step and runs it. +std::unique_ptr CreateList(const TestMessage* message, + const std::string& field, + google::protobuf::Arena* arena) { + const google::protobuf::FieldDescriptor* field_desc = + message->GetDescriptor()->FindFieldByName(field); + + return std::make_unique( + message, field_desc, &CelProtoWrapper::InternalWrapMessage, arena); +} + +TEST(FieldBackedListImplTest, BoolDatatypeTest) { + TestMessage message; + message.add_bool_list(true); + message.add_bool_list(false); + + google::protobuf::Arena arena; + + auto cel_list = CreateList(&message, "bool_list", &arena); + + ASSERT_EQ(cel_list->size(), 2); + + EXPECT_EQ((*cel_list)[0].BoolOrDie(), true); + EXPECT_EQ((*cel_list)[1].BoolOrDie(), false); +} + +TEST(FieldBackedListImplTest, TestLength0) { + TestMessage message; + + google::protobuf::Arena arena; + + auto cel_list = CreateList(&message, "int32_list", &arena); + + ASSERT_EQ(cel_list->size(), 0); +} + +TEST(FieldBackedListImplTest, TestLength1) { + TestMessage message; + message.add_int32_list(1); + google::protobuf::Arena arena; + + auto cel_list = CreateList(&message, "int32_list", &arena); + + ASSERT_EQ(cel_list->size(), 1); + EXPECT_EQ((*cel_list)[0].Int64OrDie(), 1); +} + +TEST(FieldBackedListImplTest, TestLength100000) { + TestMessage message; + + const int kLen = 100000; + + for (int i = 0; i < kLen; i++) { + message.add_int32_list(i); + } + google::protobuf::Arena arena; + + auto cel_list = CreateList(&message, "int32_list", &arena); + + ASSERT_EQ(cel_list->size(), kLen); + for (int i = 0; i < kLen; i++) { + EXPECT_EQ((*cel_list)[i].Int64OrDie(), i); + } +} + +TEST(FieldBackedListImplTest, Int32DatatypeTest) { + TestMessage message; + message.add_int32_list(1); + message.add_int32_list(2); + + google::protobuf::Arena arena; + + auto cel_list = CreateList(&message, "int32_list", &arena); + + ASSERT_EQ(cel_list->size(), 2); + + EXPECT_EQ((*cel_list)[0].Int64OrDie(), 1); + EXPECT_EQ((*cel_list)[1].Int64OrDie(), 2); +} + +TEST(FieldBackedListImplTest, Int64DatatypeTest) { + TestMessage message; + message.add_int64_list(1); + message.add_int64_list(2); + + google::protobuf::Arena arena; + + auto cel_list = CreateList(&message, "int64_list", &arena); + + ASSERT_EQ(cel_list->size(), 2); + + EXPECT_EQ((*cel_list)[0].Int64OrDie(), 1); + EXPECT_EQ((*cel_list)[1].Int64OrDie(), 2); +} + +TEST(FieldBackedListImplTest, Uint32DatatypeTest) { + TestMessage message; + message.add_uint32_list(1); + message.add_uint32_list(2); + + google::protobuf::Arena arena; + + auto cel_list = CreateList(&message, "uint32_list", &arena); + + ASSERT_EQ(cel_list->size(), 2); + + EXPECT_EQ((*cel_list)[0].Uint64OrDie(), 1); + EXPECT_EQ((*cel_list)[1].Uint64OrDie(), 2); +} + +TEST(FieldBackedListImplTest, Uint64DatatypeTest) { + TestMessage message; + message.add_uint64_list(1); + message.add_uint64_list(2); + + google::protobuf::Arena arena; + + auto cel_list = CreateList(&message, "uint64_list", &arena); + + ASSERT_EQ(cel_list->size(), 2); + + EXPECT_EQ((*cel_list)[0].Uint64OrDie(), 1); + EXPECT_EQ((*cel_list)[1].Uint64OrDie(), 2); +} + +TEST(FieldBackedListImplTest, FloatDatatypeTest) { + TestMessage message; + message.add_float_list(1); + message.add_float_list(2); + + google::protobuf::Arena arena; + + auto cel_list = CreateList(&message, "float_list", &arena); + + ASSERT_EQ(cel_list->size(), 2); + + EXPECT_THAT((*cel_list)[0].DoubleOrDie(), DoubleEq(1)); + EXPECT_THAT((*cel_list)[1].DoubleOrDie(), DoubleEq(2)); +} + +TEST(FieldBackedListImplTest, DoubleDatatypeTest) { + TestMessage message; + message.add_double_list(1); + message.add_double_list(2); + + google::protobuf::Arena arena; + + auto cel_list = CreateList(&message, "double_list", &arena); + + ASSERT_EQ(cel_list->size(), 2); + + EXPECT_THAT((*cel_list)[0].DoubleOrDie(), DoubleEq(1)); + EXPECT_THAT((*cel_list)[1].DoubleOrDie(), DoubleEq(2)); +} + +TEST(FieldBackedListImplTest, StringDatatypeTest) { + TestMessage message; + message.add_string_list("1"); + message.add_string_list("2"); + + google::protobuf::Arena arena; + + auto cel_list = CreateList(&message, "string_list", &arena); + + ASSERT_EQ(cel_list->size(), 2); + + EXPECT_EQ((*cel_list)[0].StringOrDie().value(), "1"); + EXPECT_EQ((*cel_list)[1].StringOrDie().value(), "2"); +} + +TEST(FieldBackedListImplTest, BytesDatatypeTest) { + TestMessage message; + message.add_bytes_list("1"); + message.add_bytes_list("2"); + + google::protobuf::Arena arena; + + auto cel_list = CreateList(&message, "bytes_list", &arena); + + ASSERT_EQ(cel_list->size(), 2); + + EXPECT_EQ((*cel_list)[0].BytesOrDie().value(), "1"); + EXPECT_EQ((*cel_list)[1].BytesOrDie().value(), "2"); +} + +TEST(FieldBackedListImplTest, MessageDatatypeTest) { + TestMessage message; + TestMessage* msg1 = message.add_message_list(); + TestMessage* msg2 = message.add_message_list(); + + msg1->set_string_value("1"); + msg2->set_string_value("2"); + + google::protobuf::Arena arena; + + auto cel_list = CreateList(&message, "message_list", &arena); + + ASSERT_EQ(cel_list->size(), 2); + + EXPECT_THAT(*msg1, EqualsProto(*((*cel_list)[0].MessageOrDie()))); + EXPECT_THAT(*msg2, EqualsProto(*((*cel_list)[1].MessageOrDie()))); +} + +TEST(FieldBackedListImplTest, EnumDatatypeTest) { + TestMessage message; + + message.add_enum_list(TestMessage::TEST_ENUM_1); + message.add_enum_list(TestMessage::TEST_ENUM_2); + + google::protobuf::Arena arena; + + auto cel_list = CreateList(&message, "enum_list", &arena); + + ASSERT_EQ(cel_list->size(), 2); + + EXPECT_THAT((*cel_list)[0].Int64OrDie(), Eq(TestMessage::TEST_ENUM_1)); + EXPECT_THAT((*cel_list)[1].Int64OrDie(), Eq(TestMessage::TEST_ENUM_2)); +} + +} // namespace +} // namespace google::api::expr::runtime::internal diff --git a/eval/public/containers/internal_field_backed_map_impl.cc b/eval/public/containers/internal_field_backed_map_impl.cc new file mode 100644 index 000000000..a879955d1 --- /dev/null +++ b/eval/public/containers/internal_field_backed_map_impl.cc @@ -0,0 +1,298 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/containers/internal_field_backed_map_impl.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "eval/public/cel_value.h" +#include "eval/public/structs/field_access_impl.h" +#include "eval/public/structs/protobuf_value_factory.h" +#include "extensions/protobuf/internal/map_reflection.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/map_field.h" +#include "google/protobuf/message.h" + +namespace google::api::expr::runtime::internal { + +namespace { +using google::protobuf::Descriptor; +using google::protobuf::FieldDescriptor; +using google::protobuf::MapValueConstRef; +using google::protobuf::Message; + +// Map entries have two field tags +// 1 - for key +// 2 - for value +constexpr int kKeyTag = 1; +constexpr int kValueTag = 2; + +class KeyList : public CelList { + public: + // message contains the "repeated" field + // descriptor FieldDescriptor for the field + KeyList(const google::protobuf::Message* message, + const google::protobuf::FieldDescriptor* descriptor, + const ProtobufValueFactory& factory, google::protobuf::Arena* arena) + : message_(message), + descriptor_(descriptor), + reflection_(message_->GetReflection()), + factory_(factory), + arena_(arena) {} + + // List size. + int size() const override { + return reflection_->FieldSize(*message_, descriptor_); + } + + // List element access operator. + CelValue operator[](int index) const override { + const Message* entry = + &reflection_->GetRepeatedMessage(*message_, descriptor_, index); + + if (entry == nullptr) { + return CelValue::CreateNull(); + } + + const Descriptor* entry_descriptor = entry->GetDescriptor(); + // Key Tag == 1 + const FieldDescriptor* key_desc = + entry_descriptor->FindFieldByNumber(kKeyTag); + + absl::StatusOr key_value = CreateValueFromSingleField( + entry, key_desc, ProtoWrapperTypeOptions::kUnsetProtoDefault, factory_, + arena_); + if (!key_value.ok()) { + return CreateErrorValue(arena_, key_value.status()); + } + return *key_value; + } + + private: + const google::protobuf::Message* message_; + const google::protobuf::FieldDescriptor* descriptor_; + const google::protobuf::Reflection* reflection_; + const ProtobufValueFactory& factory_; + google::protobuf::Arena* arena_; +}; + +bool MatchesMapKeyType(const FieldDescriptor* key_desc, const CelValue& key) { + switch (key_desc->cpp_type()) { + case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: + return key.IsBool(); + case google::protobuf::FieldDescriptor::CPPTYPE_INT32: + // fall through + case google::protobuf::FieldDescriptor::CPPTYPE_INT64: + return key.IsInt64(); + case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: + // fall through + case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: + return key.IsUint64(); + case google::protobuf::FieldDescriptor::CPPTYPE_STRING: + return key.IsString(); + default: + return false; + } +} + +absl::Status InvalidMapKeyType(absl::string_view key_type) { + return absl::InvalidArgumentError( + absl::StrCat("Invalid map key type: '", key_type, "'")); +} + +} // namespace + +FieldBackedMapImpl::FieldBackedMapImpl( + const google::protobuf::Message* message, const google::protobuf::FieldDescriptor* descriptor, + ProtobufValueFactory factory, google::protobuf::Arena* arena) + : message_(message), + descriptor_(descriptor), + key_desc_(descriptor_->message_type()->FindFieldByNumber(kKeyTag)), + value_desc_(descriptor_->message_type()->FindFieldByNumber(kValueTag)), + reflection_(message_->GetReflection()), + factory_(std::move(factory)), + arena_(arena), + key_list_( + std::make_unique(message, descriptor, factory_, arena)) {} + +int FieldBackedMapImpl::size() const { + return reflection_->FieldSize(*message_, descriptor_); +} + +absl::StatusOr FieldBackedMapImpl::ListKeys() const { + return key_list_.get(); +} + +absl::StatusOr FieldBackedMapImpl::Has(const CelValue& key) const { + MapValueConstRef value_ref; + return LookupMapValue(key, &value_ref); +} + +absl::optional FieldBackedMapImpl::operator[](CelValue key) const { + // Fast implementation which uses a friend method to do a hash-based key + // lookup. + MapValueConstRef value_ref; + auto lookup_result = LookupMapValue(key, &value_ref); + if (!lookup_result.ok()) { + return CreateErrorValue(arena_, lookup_result.status()); + } + if (!*lookup_result) { + return absl::nullopt; + } + + // Get value descriptor treating it as a repeated field. + // All values in protobuf map have the same type. + // The map is not empty, because LookupMapValue returned true. + absl::StatusOr result = CreateValueFromMapValue( + message_, value_desc_, &value_ref, factory_, arena_); + if (!result.ok()) { + return CreateErrorValue(arena_, result.status()); + } + return *result; +} + +absl::StatusOr FieldBackedMapImpl::LookupMapValue( + const CelValue& key, MapValueConstRef* value_ref) const { + if (!MatchesMapKeyType(key_desc_, key)) { + return InvalidMapKeyType(key_desc_->cpp_type_name()); + } + + std::string map_key_string; + google::protobuf::MapKey proto_key; + switch (key_desc_->cpp_type()) { + case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: { + bool key_value; + key.GetValue(&key_value); + proto_key.SetBoolValue(key_value); + } break; + case google::protobuf::FieldDescriptor::CPPTYPE_INT32: { + int64_t key_value; + key.GetValue(&key_value); + if (key_value > std::numeric_limits::max() || + key_value < std::numeric_limits::lowest()) { + return absl::OutOfRangeError("integer overflow"); + } + proto_key.SetInt32Value(key_value); + } break; + case google::protobuf::FieldDescriptor::CPPTYPE_INT64: { + int64_t key_value; + key.GetValue(&key_value); + proto_key.SetInt64Value(key_value); + } break; + case google::protobuf::FieldDescriptor::CPPTYPE_STRING: { + CelValue::StringHolder key_value; + key.GetValue(&key_value); + map_key_string.assign(key_value.value().data(), key_value.value().size()); + proto_key.SetStringValue(map_key_string); + } break; + case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: { + uint64_t key_value; + key.GetValue(&key_value); + if (key_value > std::numeric_limits::max()) { + return absl::OutOfRangeError("unsigned integer overlow"); + } + proto_key.SetUInt32Value(key_value); + } break; + case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: { + uint64_t key_value; + key.GetValue(&key_value); + proto_key.SetUInt64Value(key_value); + } break; + default: + return InvalidMapKeyType(key_desc_->cpp_type_name()); + } + // Look the value up + return cel::extensions::protobuf_internal::LookupMapValue( + *reflection_, *message_, *descriptor_, proto_key, value_ref); +} + +absl::StatusOr FieldBackedMapImpl::LegacyHasMapValue( + const CelValue& key) const { + auto lookup_result = LegacyLookupMapValue(key); + if (!lookup_result.has_value()) { + return false; + } + auto result = *lookup_result; + if (result.IsError()) { + return *(result.ErrorOrDie()); + } + return true; +} + +absl::optional FieldBackedMapImpl::LegacyLookupMapValue( + const CelValue& key) const { + // Ensure that the key matches the key type. + if (!MatchesMapKeyType(key_desc_, key)) { + return CreateErrorValue(arena_, + InvalidMapKeyType(key_desc_->cpp_type_name())); + } + + int map_size = size(); + for (int i = 0; i < map_size; i++) { + const Message* entry = + &reflection_->GetRepeatedMessage(*message_, descriptor_, i); + if (entry == nullptr) continue; + + // Key Tag == 1 + absl::StatusOr key_value = CreateValueFromSingleField( + entry, key_desc_, ProtoWrapperTypeOptions::kUnsetProtoDefault, factory_, + arena_); + if (!key_value.ok()) { + return CreateErrorValue(arena_, key_value.status()); + } + + bool match = false; + switch (key_desc_->cpp_type()) { + case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: + match = key.BoolOrDie() == key_value->BoolOrDie(); + break; + case google::protobuf::FieldDescriptor::CPPTYPE_INT32: + // fall through + case google::protobuf::FieldDescriptor::CPPTYPE_INT64: + match = key.Int64OrDie() == key_value->Int64OrDie(); + break; + case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: + // fall through + case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: + match = key.Uint64OrDie() == key_value->Uint64OrDie(); + break; + case google::protobuf::FieldDescriptor::CPPTYPE_STRING: + match = key.StringOrDie() == key_value->StringOrDie(); + break; + default: + // this would normally indicate a bad key type, which should not be + // possible based on the earlier test. + break; + } + + if (match) { + absl::StatusOr value_cel_value = CreateValueFromSingleField( + entry, value_desc_, ProtoWrapperTypeOptions::kUnsetProtoDefault, + factory_, arena_); + if (!value_cel_value.ok()) { + return CreateErrorValue(arena_, value_cel_value.status()); + } + return *value_cel_value; + } + } + return {}; +} + +} // namespace google::api::expr::runtime::internal diff --git a/eval/public/containers/internal_field_backed_map_impl.h b/eval/public/containers/internal_field_backed_map_impl.h new file mode 100644 index 000000000..596343b75 --- /dev/null +++ b/eval/public/containers/internal_field_backed_map_impl.h @@ -0,0 +1,77 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CONTAINERS_INTERNAL_FIELD_BACKED_MAP_IMPL_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CONTAINERS_INTERNAL_FIELD_BACKED_MAP_IMPL_H_ + +#include "absl/status/statusor.h" +#include "eval/public/cel_value.h" +#include "eval/public/structs/protobuf_value_factory.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace google::api::expr::runtime::internal { +// CelMap implementation that uses "map" message field +// as backing storage. +class FieldBackedMapImpl : public CelMap { + public: + // message contains the "map" field. Object stores the pointer + // to the message, thus it is expected that message outlives the + // object. + // descriptor FieldDescriptor for the field + FieldBackedMapImpl(const google::protobuf::Message* message, + const google::protobuf::FieldDescriptor* descriptor, + ProtobufValueFactory factory, google::protobuf::Arena* arena); + + // Map size. + int size() const override; + + // Map element access operator. + absl::optional operator[](CelValue key) const override; + + // Presence test function. + absl::StatusOr Has(const CelValue& key) const override; + + absl::StatusOr ListKeys() const override; + + // Include base class definitions to avoid GCC warnings about hidden virtual + // overloads. + using CelMap::ListKeys; + + protected: + // These methods are exposed as protected methods for testing purposes since + // whether one or the other is used depends on build time flags, but each + // should be tested accordingly. + + absl::StatusOr LookupMapValue( + const CelValue& key, google::protobuf::MapValueConstRef* value_ref) const; + + absl::StatusOr LegacyHasMapValue(const CelValue& key) const; + + absl::optional LegacyLookupMapValue(const CelValue& key) const; + + private: + const google::protobuf::Message* message_; + const google::protobuf::FieldDescriptor* descriptor_; + const google::protobuf::FieldDescriptor* key_desc_; + const google::protobuf::FieldDescriptor* value_desc_; + const google::protobuf::Reflection* reflection_; + ProtobufValueFactory factory_; + google::protobuf::Arena* arena_; + std::unique_ptr key_list_; +}; + +} // namespace google::api::expr::runtime::internal + +#endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CONTAINERS_INTERNAL_FIELD_BACKED_MAP_IMPL_H_ diff --git a/eval/public/containers/internal_field_backed_map_impl_test.cc b/eval/public/containers/internal_field_backed_map_impl_test.cc new file mode 100644 index 000000000..7a666ef10 --- /dev/null +++ b/eval/public/containers/internal_field_backed_map_impl_test.cc @@ -0,0 +1,291 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "eval/public/containers/internal_field_backed_map_impl.h" + +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "eval/public/structs/cel_proto_wrapper.h" +#include "eval/testutil/test_message.pb.h" +#include "internal/testing.h" + +namespace google::api::expr::runtime::internal { +namespace { + +using ::absl_testing::StatusIs; +using ::testing::Eq; +using ::testing::HasSubstr; +using ::testing::UnorderedPointwise; + +class FieldBackedMapTestImpl : public FieldBackedMapImpl { + public: + FieldBackedMapTestImpl(const google::protobuf::Message* message, + const google::protobuf::FieldDescriptor* descriptor, + google::protobuf::Arena* arena) + : FieldBackedMapImpl(message, descriptor, + &CelProtoWrapper::InternalWrapMessage, arena) {} + + // For code coverage, expose fallback lookups used when not compiled with + // support for optimized versions. + using FieldBackedMapImpl::LegacyHasMapValue; + using FieldBackedMapImpl::LegacyLookupMapValue; +}; + +// Helper method. Creates simple pipeline containing Select step and runs it. +std::unique_ptr CreateMap(const TestMessage* message, + const std::string& field, + google::protobuf::Arena* arena) { + const google::protobuf::FieldDescriptor* field_desc = + message->GetDescriptor()->FindFieldByName(field); + + return std::make_unique(message, field_desc, arena); +} + +TEST(FieldBackedMapImplTest, BadKeyTypeTest) { + TestMessage message; + google::protobuf::Arena arena; + constexpr std::array map_types = { + "int64_int32_map", "uint64_int32_map", "string_int32_map", + "bool_int32_map", "int32_int32_map", "uint32_uint32_map", + }; + + for (auto map_type : map_types) { + auto cel_map = CreateMap(&message, std::string(map_type), &arena); + // Look up a boolean key. This should result in an error for both the + // presence test and the value lookup. + auto result = cel_map->Has(CelValue::CreateNull()); + EXPECT_FALSE(result.ok()); + EXPECT_THAT(result.status().code(), Eq(absl::StatusCode::kInvalidArgument)); + + result = cel_map->LegacyHasMapValue(CelValue::CreateNull()); + EXPECT_FALSE(result.ok()); + EXPECT_THAT(result.status().code(), Eq(absl::StatusCode::kInvalidArgument)); + + auto lookup = (*cel_map)[CelValue::CreateNull()]; + EXPECT_TRUE(lookup.has_value()); + EXPECT_TRUE(lookup->IsError()); + EXPECT_THAT(lookup->ErrorOrDie()->code(), + Eq(absl::StatusCode::kInvalidArgument)); + + lookup = cel_map->LegacyLookupMapValue(CelValue::CreateNull()); + EXPECT_TRUE(lookup.has_value()); + EXPECT_TRUE(lookup->IsError()); + EXPECT_THAT(lookup->ErrorOrDie()->code(), + Eq(absl::StatusCode::kInvalidArgument)); + } +} + +TEST(FieldBackedMapImplTest, Int32KeyTest) { + TestMessage message; + auto field_map = message.mutable_int32_int32_map(); + (*field_map)[0] = 1; + (*field_map)[1] = 2; + + google::protobuf::Arena arena; + auto cel_map = CreateMap(&message, "int32_int32_map", &arena); + + EXPECT_EQ((*cel_map)[CelValue::CreateInt64(0)]->Int64OrDie(), 1); + EXPECT_EQ((*cel_map)[CelValue::CreateInt64(1)]->Int64OrDie(), 2); + EXPECT_TRUE(cel_map->Has(CelValue::CreateInt64(1)).value_or(false)); + EXPECT_TRUE( + cel_map->LegacyHasMapValue(CelValue::CreateInt64(1)).value_or(false)); + + // Look up nonexistent key + EXPECT_FALSE((*cel_map)[CelValue::CreateInt64(3)].has_value()); + EXPECT_FALSE(cel_map->Has(CelValue::CreateInt64(3)).value_or(true)); + EXPECT_FALSE( + cel_map->LegacyHasMapValue(CelValue::CreateInt64(3)).value_or(true)); +} + +TEST(FieldBackedMapImplTest, Int32KeyOutOfRangeTest) { + TestMessage message; + google::protobuf::Arena arena; + auto cel_map = CreateMap(&message, "int32_int32_map", &arena); + + // Look up keys out of int32 range + auto result = cel_map->Has( + CelValue::CreateInt64(std::numeric_limits::max() + 1L)); + EXPECT_THAT(result.status(), + StatusIs(absl::StatusCode::kOutOfRange, HasSubstr("overflow"))); + + result = cel_map->Has( + CelValue::CreateInt64(std::numeric_limits::lowest() - 1L)); + EXPECT_FALSE(result.ok()); + EXPECT_THAT(result.status().code(), Eq(absl::StatusCode::kOutOfRange)); +} + +TEST(FieldBackedMapImplTest, Int64KeyTest) { + TestMessage message; + auto field_map = message.mutable_int64_int32_map(); + (*field_map)[0] = 1; + (*field_map)[1] = 2; + + google::protobuf::Arena arena; + auto cel_map = CreateMap(&message, "int64_int32_map", &arena); + + EXPECT_EQ((*cel_map)[CelValue::CreateInt64(0)]->Int64OrDie(), 1); + EXPECT_EQ((*cel_map)[CelValue::CreateInt64(1)]->Int64OrDie(), 2); + EXPECT_TRUE(cel_map->Has(CelValue::CreateInt64(1)).value_or(false)); + EXPECT_EQ( + cel_map->LegacyLookupMapValue(CelValue::CreateInt64(1))->Int64OrDie(), 2); + EXPECT_TRUE( + cel_map->LegacyHasMapValue(CelValue::CreateInt64(1)).value_or(false)); + + // Look up nonexistent key + EXPECT_EQ((*cel_map)[CelValue::CreateInt64(3)].has_value(), false); +} + +TEST(FieldBackedMapImplTest, BoolKeyTest) { + TestMessage message; + auto field_map = message.mutable_bool_int32_map(); + (*field_map)[false] = 1; + + google::protobuf::Arena arena; + auto cel_map = CreateMap(&message, "bool_int32_map", &arena); + + EXPECT_EQ((*cel_map)[CelValue::CreateBool(false)]->Int64OrDie(), 1); + EXPECT_TRUE(cel_map->Has(CelValue::CreateBool(false)).value_or(false)); + EXPECT_TRUE( + cel_map->LegacyHasMapValue(CelValue::CreateBool(false)).value_or(false)); + // Look up nonexistent key + EXPECT_EQ((*cel_map)[CelValue::CreateBool(true)].has_value(), false); + + (*field_map)[true] = 2; + EXPECT_EQ((*cel_map)[CelValue::CreateBool(true)]->Int64OrDie(), 2); +} + +TEST(FieldBackedMapImplTest, Uint32KeyTest) { + TestMessage message; + auto field_map = message.mutable_uint32_uint32_map(); + (*field_map)[0] = 1u; + (*field_map)[1] = 2u; + + google::protobuf::Arena arena; + auto cel_map = CreateMap(&message, "uint32_uint32_map", &arena); + + EXPECT_EQ((*cel_map)[CelValue::CreateUint64(0)]->Uint64OrDie(), 1UL); + EXPECT_EQ((*cel_map)[CelValue::CreateUint64(1)]->Uint64OrDie(), 2UL); + EXPECT_TRUE(cel_map->Has(CelValue::CreateUint64(1)).value_or(false)); + EXPECT_TRUE( + cel_map->LegacyHasMapValue(CelValue::CreateUint64(1)).value_or(false)); + + // Look up nonexistent key + EXPECT_EQ((*cel_map)[CelValue::CreateUint64(3)].has_value(), false); + EXPECT_EQ(cel_map->Has(CelValue::CreateUint64(3)).value_or(true), false); +} + +TEST(FieldBackedMapImplTest, Uint32KeyOutOfRangeTest) { + TestMessage message; + google::protobuf::Arena arena; + auto cel_map = CreateMap(&message, "uint32_uint32_map", &arena); + + // Look up keys out of uint32 range + auto result = cel_map->Has( + CelValue::CreateUint64(std::numeric_limits::max() + 1UL)); + EXPECT_FALSE(result.ok()); + EXPECT_THAT(result.status().code(), Eq(absl::StatusCode::kOutOfRange)); +} + +TEST(FieldBackedMapImplTest, Uint64KeyTest) { + TestMessage message; + auto field_map = message.mutable_uint64_int32_map(); + (*field_map)[0] = 1; + (*field_map)[1] = 2; + + google::protobuf::Arena arena; + auto cel_map = CreateMap(&message, "uint64_int32_map", &arena); + + EXPECT_EQ((*cel_map)[CelValue::CreateUint64(0)]->Int64OrDie(), 1); + EXPECT_EQ((*cel_map)[CelValue::CreateUint64(1)]->Int64OrDie(), 2); + EXPECT_TRUE(cel_map->Has(CelValue::CreateUint64(1)).value_or(false)); + EXPECT_TRUE( + cel_map->LegacyHasMapValue(CelValue::CreateUint64(1)).value_or(false)); + + // Look up nonexistent key + EXPECT_EQ((*cel_map)[CelValue::CreateUint64(3)].has_value(), false); +} + +TEST(FieldBackedMapImplTest, StringKeyTest) { + TestMessage message; + auto field_map = message.mutable_string_int32_map(); + (*field_map)["test0"] = 1; + (*field_map)["test1"] = 2; + + google::protobuf::Arena arena; + auto cel_map = CreateMap(&message, "string_int32_map", &arena); + + std::string test0 = "test0"; + std::string test1 = "test1"; + std::string test_notfound = "test_notfound"; + + EXPECT_EQ((*cel_map)[CelValue::CreateString(&test0)]->Int64OrDie(), 1); + EXPECT_EQ((*cel_map)[CelValue::CreateString(&test1)]->Int64OrDie(), 2); + EXPECT_TRUE(cel_map->Has(CelValue::CreateString(&test1)).value_or(false)); + EXPECT_TRUE(cel_map->LegacyHasMapValue(CelValue::CreateString(&test1)) + .value_or(false)); + + // Look up nonexistent key + EXPECT_EQ((*cel_map)[CelValue::CreateString(&test_notfound)].has_value(), + false); +} + +TEST(FieldBackedMapImplTest, EmptySizeTest) { + TestMessage message; + google::protobuf::Arena arena; + auto cel_map = CreateMap(&message, "string_int32_map", &arena); + EXPECT_EQ(cel_map->size(), 0); +} + +TEST(FieldBackedMapImplTest, RepeatedAddTest) { + TestMessage message; + auto field_map = message.mutable_string_int32_map(); + (*field_map)["test0"] = 1; + (*field_map)["test1"] = 2; + (*field_map)["test0"] = 3; + + google::protobuf::Arena arena; + auto cel_map = CreateMap(&message, "string_int32_map", &arena); + + EXPECT_EQ(cel_map->size(), 2); +} + +TEST(FieldBackedMapImplTest, KeyListTest) { + TestMessage message; + auto field_map = message.mutable_string_int32_map(); + std::vector keys; + std::vector keys1; + for (int i = 0; i < 100; i++) { + keys.push_back(absl::StrCat("test", i)); + (*field_map)[keys.back()] = i; + } + + google::protobuf::Arena arena; + auto cel_map = CreateMap(&message, "string_int32_map", &arena); + const CelList* key_list = cel_map->ListKeys().value(); + + EXPECT_EQ(key_list->size(), 100); + for (int i = 0; i < key_list->size(); i++) { + keys1.push_back(std::string((*key_list)[i].StringOrDie().value())); + } + + EXPECT_THAT(keys, UnorderedPointwise(Eq(), keys1)); +} + +} // namespace +} // namespace google::api::expr::runtime::internal diff --git a/eval/public/equality_function_registrar.cc b/eval/public/equality_function_registrar.cc new file mode 100644 index 000000000..f2ae3f22b --- /dev/null +++ b/eval/public/equality_function_registrar.cc @@ -0,0 +1,32 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/equality_function_registrar.h" + +#include "absl/status/status.h" +#include "eval/public/cel_function_registry.h" +#include "eval/public/cel_options.h" +#include "runtime/runtime_options.h" +#include "runtime/standard/equality_functions.h" + +namespace google::api::expr::runtime { + +absl::Status RegisterEqualityFunctions(CelFunctionRegistry* registry, + const InterpreterOptions& options) { + cel::RuntimeOptions runtime_options = ConvertToRuntimeOptions(options); + return cel::RegisterEqualityFunctions(registry->InternalGetRegistry(), + runtime_options); +} + +} // namespace google::api::expr::runtime diff --git a/eval/public/equality_function_registrar.h b/eval/public/equality_function_registrar.h new file mode 100644 index 000000000..bb859b5a0 --- /dev/null +++ b/eval/public/equality_function_registrar.h @@ -0,0 +1,44 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_EQUALITY_FUNCTION_REGISTRAR_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_EQUALITY_FUNCTION_REGISTRAR_H_ + +#include "absl/status/status.h" +#include "eval/internal/cel_value_equal.h" +#include "eval/public/cel_function_registry.h" +#include "eval/public/cel_options.h" + +namespace google::api::expr::runtime { + +// Implementation for general equality between CELValues. Exposed for +// consistent behavior in set membership functions. +// +// Returns nullopt if the comparison is undefined between differently typed +// values. +using cel::interop_internal::CelValueEqualImpl; + +// Register built in comparison functions (==, !=). +// +// Most users should prefer to use RegisterBuiltinFunctions. +// +// This call is included in RegisterBuiltinFunctions -- calling both +// RegisterBuiltinFunctions and RegisterComparisonFunctions directly on the same +// registry will result in an error. +absl::Status RegisterEqualityFunctions(CelFunctionRegistry* registry, + const InterpreterOptions& options); + +} // namespace google::api::expr::runtime + +#endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_EQUALITY_FUNCTION_REGISTRAR_H_ diff --git a/eval/public/equality_function_registrar_test.cc b/eval/public/equality_function_registrar_test.cc new file mode 100644 index 000000000..a77a92734 --- /dev/null +++ b/eval/public/equality_function_registrar_test.cc @@ -0,0 +1,933 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "eval/public/equality_function_registrar.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "cel/expr/syntax.pb.h" +#include "google/protobuf/any.pb.h" +#include "google/rpc/context/attribute_context.pb.h" +#include "google/protobuf/descriptor.pb.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "absl/types/span.h" +#include "absl/types/variant.h" +#include "eval/public/activation.h" +#include "eval/public/cel_builtins.h" +#include "eval/public/cel_expr_builder_factory.h" +#include "eval/public/cel_expression.h" +#include "eval/public/cel_function_registry.h" +#include "eval/public/cel_options.h" +#include "eval/public/cel_value.h" +#include "eval/public/containers/container_backed_list_impl.h" +#include "eval/public/containers/container_backed_map_impl.h" +#include "eval/public/message_wrapper.h" +#include "eval/public/structs/cel_proto_wrapper.h" +#include "eval/public/structs/trivial_legacy_type_info.h" +#include "eval/public/testing/matchers.h" +#include "eval/testutil/test_message.pb.h" // IWYU pragma: keep +#include "internal/benchmark.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "parser/parser.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/dynamic_message.h" +#include "google/protobuf/message.h" +#include "google/protobuf/text_format.h" + +namespace google::api::expr::runtime { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::cel::expr::ParsedExpr; +using ::google::rpc::context::AttributeContext; +using ::testing::_; +using ::testing::Combine; +using ::testing::HasSubstr; +using ::testing::Optional; +using ::testing::Values; +using ::testing::ValuesIn; + +MATCHER_P2(DefinesHomogenousOverload, name, argument_type, + absl::StrCat(name, " for ", CelValue::TypeName(argument_type))) { + const CelFunctionRegistry& registry = arg; + return !registry + .FindOverloads(name, /*receiver_style=*/false, + {argument_type, argument_type}) + .empty(); + return false; +} + +struct EqualityTestCase { + enum class ErrorKind { kMissingOverload, kMissingIdentifier }; + absl::string_view expr; + std::variant result; + CelValue lhs = CelValue::CreateNull(); + CelValue rhs = CelValue::CreateNull(); +}; + +bool IsNumeric(CelValue::Type type) { + return type == CelValue::Type::kDouble || type == CelValue::Type::kInt64 || + type == CelValue::Type::kUint64; +} + +const CelList& CelListExample1() { + static ContainerBackedListImpl* example = + new ContainerBackedListImpl({CelValue::CreateInt64(1)}); + return *example; +} + +const CelList& CelListExample2() { + static ContainerBackedListImpl* example = + new ContainerBackedListImpl({CelValue::CreateInt64(2)}); + return *example; +} + +const CelMap& CelMapExample1() { + static CelMap* example = []() { + std::vector> values{ + {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}}; + // Implementation copies values into a hash map. + auto map = CreateContainerBackedMap(absl::MakeSpan(values)); + return map->release(); + }(); + return *example; +} + +const CelMap& CelMapExample2() { + static CelMap* example = []() { + std::vector> values{ + {CelValue::CreateInt64(2), CelValue::CreateInt64(4)}}; + auto map = CreateContainerBackedMap(absl::MakeSpan(values)); + return map->release(); + }(); + return *example; +} + +const std::vector& ValueExamples1() { + static std::vector* examples = []() { + google::protobuf::Arena arena; + auto result = std::make_unique>(); + + result->push_back(CelValue::CreateNull()); + result->push_back(CelValue::CreateBool(false)); + result->push_back(CelValue::CreateInt64(1)); + result->push_back(CelValue::CreateUint64(1)); + result->push_back(CelValue::CreateDouble(1.0)); + result->push_back(CelValue::CreateStringView("string")); + result->push_back(CelValue::CreateBytesView("bytes")); + // No arena allocs expected in this example. + result->push_back(CelProtoWrapper::CreateMessage( + std::make_unique().release(), &arena)); + result->push_back(CelValue::CreateDuration(absl::Seconds(1))); + result->push_back(CelValue::CreateTimestamp(absl::FromUnixSeconds(1))); + result->push_back(CelValue::CreateList(&CelListExample1())); + result->push_back(CelValue::CreateMap(&CelMapExample1())); + result->push_back(CelValue::CreateCelTypeView("type")); + + return result.release(); + }(); + return *examples; +} + +const std::vector& ValueExamples2() { + static std::vector* examples = []() { + google::protobuf::Arena arena; + auto result = std::make_unique>(); + auto message2 = std::make_unique(); + message2->set_int64_value(2); + + result->push_back(CelValue::CreateNull()); + result->push_back(CelValue::CreateBool(true)); + result->push_back(CelValue::CreateInt64(2)); + result->push_back(CelValue::CreateUint64(2)); + result->push_back(CelValue::CreateDouble(2.0)); + result->push_back(CelValue::CreateStringView("string2")); + result->push_back(CelValue::CreateBytesView("bytes2")); + // No arena allocs expected in this example. + result->push_back( + CelProtoWrapper::CreateMessage(message2.release(), &arena)); + result->push_back(CelValue::CreateDuration(absl::Seconds(2))); + result->push_back(CelValue::CreateTimestamp(absl::FromUnixSeconds(2))); + result->push_back(CelValue::CreateList(&CelListExample2())); + result->push_back(CelValue::CreateMap(&CelMapExample2())); + result->push_back(CelValue::CreateCelTypeView("type2")); + + return result.release(); + }(); + return *examples; +} + +class CelValueEqualImplTypesTest + : public testing::TestWithParam> { + public: + CelValueEqualImplTypesTest() = default; + + const CelValue& lhs() { return std::get<0>(GetParam()); } + + const CelValue& rhs() { return std::get<1>(GetParam()); } + + bool should_be_equal() { return std::get<2>(GetParam()); } +}; + +std::string CelValueEqualTestName( + const testing::TestParamInfo>& + test_case) { + return absl::StrCat(CelValue::TypeName(std::get<0>(test_case.param).type()), + CelValue::TypeName(std::get<1>(test_case.param).type()), + (std::get<2>(test_case.param)) ? "Equal" : "Inequal"); +} + +TEST_P(CelValueEqualImplTypesTest, Basic) { + std::optional result = CelValueEqualImpl(lhs(), rhs()); + + if (lhs().IsNull() || rhs().IsNull()) { + if (lhs().IsNull() && rhs().IsNull()) { + EXPECT_THAT(result, Optional(true)); + } else { + EXPECT_THAT(result, Optional(false)); + } + } else if (lhs().type() == rhs().type() || + (IsNumeric(lhs().type()) && IsNumeric(rhs().type()))) { + EXPECT_THAT(result, Optional(should_be_equal())); + } else { + EXPECT_THAT(result, Optional(false)); + } +} + +INSTANTIATE_TEST_SUITE_P(EqualityBetweenTypes, CelValueEqualImplTypesTest, + Combine(ValuesIn(ValueExamples1()), + ValuesIn(ValueExamples1()), Values(true)), + &CelValueEqualTestName); + +INSTANTIATE_TEST_SUITE_P(InequalityBetweenTypes, CelValueEqualImplTypesTest, + Combine(ValuesIn(ValueExamples1()), + ValuesIn(ValueExamples2()), Values(false)), + &CelValueEqualTestName); + +struct NumericInequalityTestCase { + std::string name; + CelValue a; + CelValue b; +}; + +const std::vector& NumericValuesNotEqualExample() { + static std::vector* examples = []() { + auto result = std::make_unique>(); + result->push_back({"NegativeIntAndUint", CelValue::CreateInt64(-1), + CelValue::CreateUint64(2)}); + result->push_back( + {"IntAndLargeUint", CelValue::CreateInt64(1), + CelValue::CreateUint64( + static_cast(std::numeric_limits::max()) + 1)}); + result->push_back( + {"IntAndLargeDouble", CelValue::CreateInt64(2), + CelValue::CreateDouble( + static_cast(std::numeric_limits::max()) + 1025)}); + result->push_back( + {"IntAndSmallDouble", CelValue::CreateInt64(2), + CelValue::CreateDouble( + static_cast(std::numeric_limits::lowest()) - + 1025)}); + result->push_back( + {"UintAndLargeDouble", CelValue::CreateUint64(2), + CelValue::CreateDouble( + static_cast(std::numeric_limits::max()) + + 2049)}); + result->push_back({"NegativeDoubleAndUint", CelValue::CreateDouble(-2.0), + CelValue::CreateUint64(123)}); + + // NaN tests. + result->push_back({"NanAndDouble", CelValue::CreateDouble(NAN), + CelValue::CreateDouble(1.0)}); + result->push_back({"NanAndNan", CelValue::CreateDouble(NAN), + CelValue::CreateDouble(NAN)}); + result->push_back({"DoubleAndNan", CelValue::CreateDouble(1.0), + CelValue::CreateDouble(NAN)}); + result->push_back( + {"IntAndNan", CelValue::CreateInt64(1), CelValue::CreateDouble(NAN)}); + result->push_back( + {"NanAndInt", CelValue::CreateDouble(NAN), CelValue::CreateInt64(1)}); + result->push_back( + {"UintAndNan", CelValue::CreateUint64(1), CelValue::CreateDouble(NAN)}); + result->push_back( + {"NanAndUint", CelValue::CreateDouble(NAN), CelValue::CreateUint64(1)}); + + return result.release(); + }(); + return *examples; +} + +using NumericInequalityTest = testing::TestWithParam; +TEST_P(NumericInequalityTest, NumericValues) { + NumericInequalityTestCase test_case = GetParam(); + std::optional result = CelValueEqualImpl(test_case.a, test_case.b); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(*result, false); +} + +INSTANTIATE_TEST_SUITE_P( + InequalityBetweenNumericTypesTest, NumericInequalityTest, + ValuesIn(NumericValuesNotEqualExample()), + [](const testing::TestParamInfo& info) { + return info.param.name; + }); + +TEST(CelValueEqualImplTest, LossyNumericEquality) { + std::optional result = CelValueEqualImpl( + CelValue::CreateDouble( + static_cast(std::numeric_limits::max()) - 1), + CelValue::CreateInt64(std::numeric_limits::max())); + EXPECT_TRUE(result.has_value()); + EXPECT_TRUE(*result); +} + +TEST(CelValueEqualImplTest, ListMixedTypesInequal) { + ContainerBackedListImpl lhs({CelValue::CreateInt64(1)}); + ContainerBackedListImpl rhs({CelValue::CreateStringView("abc")}); + + EXPECT_THAT( + CelValueEqualImpl(CelValue::CreateList(&lhs), CelValue::CreateList(&rhs)), + Optional(false)); +} + +TEST(CelValueEqualImplTest, NestedList) { + ContainerBackedListImpl inner_lhs({CelValue::CreateInt64(1)}); + ContainerBackedListImpl lhs({CelValue::CreateList(&inner_lhs)}); + ContainerBackedListImpl inner_rhs({CelValue::CreateNull()}); + ContainerBackedListImpl rhs({CelValue::CreateList(&inner_rhs)}); + + EXPECT_THAT( + CelValueEqualImpl(CelValue::CreateList(&lhs), CelValue::CreateList(&rhs)), + Optional(false)); +} + +TEST(CelValueEqualImplTest, MapMixedValueTypesInequal) { + std::vector> lhs_data{ + {CelValue::CreateInt64(1), CelValue::CreateStringView("abc")}}; + std::vector> rhs_data{ + {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}}; + + ASSERT_OK_AND_ASSIGN(std::unique_ptr lhs, + CreateContainerBackedMap(absl::MakeSpan(lhs_data))); + ASSERT_OK_AND_ASSIGN(std::unique_ptr rhs, + CreateContainerBackedMap(absl::MakeSpan(rhs_data))); + + EXPECT_THAT(CelValueEqualImpl(CelValue::CreateMap(lhs.get()), + CelValue::CreateMap(rhs.get())), + Optional(false)); +} + +TEST(CelValueEqualImplTest, MapMixedKeyTypesEqual) { + std::vector> lhs_data{ + {CelValue::CreateUint64(1), CelValue::CreateStringView("abc")}}; + std::vector> rhs_data{ + {CelValue::CreateInt64(1), CelValue::CreateStringView("abc")}}; + + ASSERT_OK_AND_ASSIGN(std::unique_ptr lhs, + CreateContainerBackedMap(absl::MakeSpan(lhs_data))); + ASSERT_OK_AND_ASSIGN(std::unique_ptr rhs, + CreateContainerBackedMap(absl::MakeSpan(rhs_data))); + + EXPECT_THAT(CelValueEqualImpl(CelValue::CreateMap(lhs.get()), + CelValue::CreateMap(rhs.get())), + Optional(true)); +} + +TEST(CelValueEqualImplTest, MapMixedKeyTypesInequal) { + std::vector> lhs_data{ + {CelValue::CreateInt64(1), CelValue::CreateStringView("abc")}}; + std::vector> rhs_data{ + {CelValue::CreateInt64(2), CelValue::CreateInt64(2)}}; + + ASSERT_OK_AND_ASSIGN(std::unique_ptr lhs, + CreateContainerBackedMap(absl::MakeSpan(lhs_data))); + ASSERT_OK_AND_ASSIGN(std::unique_ptr rhs, + CreateContainerBackedMap(absl::MakeSpan(rhs_data))); + + EXPECT_THAT(CelValueEqualImpl(CelValue::CreateMap(lhs.get()), + CelValue::CreateMap(rhs.get())), + Optional(false)); +} + +TEST(CelValueEqualImplTest, NestedMaps) { + std::vector> inner_lhs_data{ + {CelValue::CreateInt64(2), CelValue::CreateStringView("abc")}}; + ASSERT_OK_AND_ASSIGN( + std::unique_ptr inner_lhs, + CreateContainerBackedMap(absl::MakeSpan(inner_lhs_data))); + std::vector> lhs_data{ + {CelValue::CreateInt64(1), CelValue::CreateMap(inner_lhs.get())}}; + + std::vector> inner_rhs_data{ + {CelValue::CreateInt64(2), CelValue::CreateNull()}}; + ASSERT_OK_AND_ASSIGN( + std::unique_ptr inner_rhs, + CreateContainerBackedMap(absl::MakeSpan(inner_rhs_data))); + std::vector> rhs_data{ + {CelValue::CreateInt64(1), CelValue::CreateMap(inner_rhs.get())}}; + + ASSERT_OK_AND_ASSIGN(std::unique_ptr lhs, + CreateContainerBackedMap(absl::MakeSpan(lhs_data))); + ASSERT_OK_AND_ASSIGN(std::unique_ptr rhs, + CreateContainerBackedMap(absl::MakeSpan(rhs_data))); + + EXPECT_THAT(CelValueEqualImpl(CelValue::CreateMap(lhs.get()), + CelValue::CreateMap(rhs.get())), + Optional(false)); +} + +TEST(CelValueEqualImplTest, ProtoEqualityDifferingTypenameInequal) { + // If message wrappers report a different typename, treat as inequal without + // calling into the provided equal implementation. + google::protobuf::Arena arena; + TestMessage example; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( + int32_value: 1 + uint32_value: 2 + string_value: "test" + )", + &example)); + + CelValue lhs = CelProtoWrapper::CreateMessage(&example, &arena); + CelValue rhs = CelValue::CreateMessageWrapper( + MessageWrapper(&example, TrivialTypeInfo::GetInstance())); + + EXPECT_THAT(CelValueEqualImpl(lhs, rhs), Optional(false)); +} + +TEST(CelValueEqualImplTest, ProtoEqualityNoAccessorInequal) { + // If message wrappers report no access apis, then treat as inequal. + TestMessage example; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( + int32_value: 1 + uint32_value: 2 + string_value: "test" + )", + &example)); + + CelValue lhs = CelValue::CreateMessageWrapper( + MessageWrapper(&example, TrivialTypeInfo::GetInstance())); + CelValue rhs = CelValue::CreateMessageWrapper( + MessageWrapper(&example, TrivialTypeInfo::GetInstance())); + + EXPECT_THAT(CelValueEqualImpl(lhs, rhs), Optional(false)); +} + +TEST(CelValueEqualImplTest, ProtoEqualityAny) { + google::protobuf::Arena arena; + TestMessage packed_value; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( + int32_value: 1 + uint32_value: 2 + string_value: "test" + )", + &packed_value)); + + TestMessage lhs; + lhs.mutable_any_value()->PackFrom(packed_value); + + TestMessage rhs; + rhs.mutable_any_value()->PackFrom(packed_value); + + EXPECT_THAT(CelValueEqualImpl(CelProtoWrapper::CreateMessage(&lhs, &arena), + CelProtoWrapper::CreateMessage(&rhs, &arena)), + Optional(true)); + + // Equality falls back to bytewise comparison if type is missing. + lhs.mutable_any_value()->clear_type_url(); + rhs.mutable_any_value()->clear_type_url(); + EXPECT_THAT(CelValueEqualImpl(CelProtoWrapper::CreateMessage(&lhs, &arena), + CelProtoWrapper::CreateMessage(&rhs, &arena)), + Optional(true)); +} + +// Add transitive dependencies in appropriate order for the dynamic descriptor +// pool. +// Return false if the dependencies could not be added to the pool. +bool AddDepsToPool(const google::protobuf::FileDescriptor* descriptor, + google::protobuf::DescriptorPool& pool) { + for (int i = 0; i < descriptor->dependency_count(); i++) { + if (!AddDepsToPool(descriptor->dependency(i), pool)) { + return false; + } + } + google::protobuf::FileDescriptorProto descriptor_proto; + descriptor->CopyTo(&descriptor_proto); + return pool.BuildFile(descriptor_proto) != nullptr; +} + +// Equivalent descriptors managed by separate descriptor pools are not equal, so +// the underlying messages are not considered equal. +TEST(CelValueEqualImplTest, DynamicDescriptorAndGeneratedInequal) { + // Simulate a dynamically loaded descriptor that happens to match the + // compiled version. + google::protobuf::DescriptorPool pool; + google::protobuf::DynamicMessageFactory factory; + google::protobuf::Arena arena; + factory.SetDelegateToGeneratedFactory(false); + + ASSERT_TRUE(AddDepsToPool(TestMessage::descriptor()->file(), pool)); + + TestMessage example_message; + ASSERT_TRUE( + google::protobuf::TextFormat::ParseFromString(R"pb( + int64_value: 12345 + bool_list: false + bool_list: true + message_value { float_value: 1.0 } + )pb", + &example_message)); + + // Messages from a loaded descriptor and generated versions can't be compared + // via MessageDifferencer, so return false. + std::unique_ptr example_dynamic_message( + factory + .GetPrototype(pool.FindMessageTypeByName( + TestMessage::descriptor()->full_name())) + ->New()); + + ASSERT_TRUE(example_dynamic_message->ParseFromString( + example_message.SerializeAsString())); + + EXPECT_THAT(CelValueEqualImpl( + CelProtoWrapper::CreateMessage(&example_message, &arena), + CelProtoWrapper::CreateMessage(example_dynamic_message.get(), + &arena)), + Optional(false)); +} + +TEST(CelValueEqualImplTest, DynamicMessageAndMessageEqual) { + google::protobuf::DynamicMessageFactory factory; + google::protobuf::Arena arena; + factory.SetDelegateToGeneratedFactory(false); + + TestMessage example_message; + ASSERT_TRUE( + google::protobuf::TextFormat::ParseFromString(R"pb( + int64_value: 12345 + bool_list: false + bool_list: true + message_value { float_value: 1.0 } + )pb", + &example_message)); + + // Dynamic message and generated Message subclass with the same generated + // descriptor are comparable. + std::unique_ptr example_dynamic_message( + factory.GetPrototype(TestMessage::descriptor())->New()); + + ASSERT_TRUE(example_dynamic_message->ParseFromString( + example_message.SerializeAsString())); + + EXPECT_THAT(CelValueEqualImpl( + CelProtoWrapper::CreateMessage(&example_message, &arena), + CelProtoWrapper::CreateMessage(example_dynamic_message.get(), + &arena)), + Optional(true)); +} + +class EqualityFunctionTest + : public testing::TestWithParam> { + public: + EqualityFunctionTest() { + options_.enable_heterogeneous_equality = std::get<1>(GetParam()); + options_.enable_empty_wrapper_null_unboxing = true; + builder_ = CreateCelExpressionBuilder(options_); + } + + CelFunctionRegistry& registry() { return *builder_->GetRegistry(); } + + absl::StatusOr Evaluate(absl::string_view expr, const CelValue& lhs, + const CelValue& rhs) { + CEL_ASSIGN_OR_RETURN(ParsedExpr parsed_expr, parser::Parse(expr)); + Activation activation; + activation.InsertValue("lhs", lhs); + activation.InsertValue("rhs", rhs); + + CEL_ASSIGN_OR_RETURN(auto expression, + builder_->CreateExpression( + &parsed_expr.expr(), &parsed_expr.source_info())); + + return expression->Evaluate(activation, &arena_); + } + + protected: + std::unique_ptr builder_; + InterpreterOptions options_; + google::protobuf::Arena arena_; +}; + +constexpr std::array kEqualableTypes = { + CelValue::Type::kInt64, CelValue::Type::kUint64, + CelValue::Type::kString, CelValue::Type::kDouble, + CelValue::Type::kBytes, CelValue::Type::kDuration, + CelValue::Type::kMap, CelValue::Type::kList, + CelValue::Type::kBool, CelValue::Type::kTimestamp}; + +TEST(RegisterEqualityFunctionsTest, EqualDefined) { + InterpreterOptions options; + options.enable_fast_builtins = false; + CelFunctionRegistry registry; + ASSERT_THAT(RegisterEqualityFunctions(®istry, options), IsOk()); + for (CelValue::Type type : kEqualableTypes) { + EXPECT_THAT(registry, DefinesHomogenousOverload(builtin::kEqual, type)); + } +} + +TEST(RegisterEqualityFunctionsTest, InequalDefined) { + InterpreterOptions options; + options.enable_fast_builtins = false; + CelFunctionRegistry registry; + ASSERT_THAT(RegisterEqualityFunctions(®istry, options), IsOk()); + for (CelValue::Type type : kEqualableTypes) { + EXPECT_THAT(registry, DefinesHomogenousOverload(builtin::kInequal, type)); + } +} + +TEST_P(EqualityFunctionTest, SmokeTest) { + EqualityTestCase test_case = std::get<0>(GetParam()); + google::protobuf::LinkMessageReflection(); + + ASSERT_THAT(RegisterEqualityFunctions(®istry(), options_), IsOk()); + ASSERT_OK_AND_ASSIGN(auto result, + Evaluate(test_case.expr, test_case.lhs, test_case.rhs)); + + if (absl::holds_alternative(test_case.result)) { + EXPECT_THAT(result, test::IsCelBool(absl::get(test_case.result))); + } else { + switch (absl::get(test_case.result)) { + case EqualityTestCase::ErrorKind::kMissingOverload: + EXPECT_THAT(result, test::IsCelError( + StatusIs(absl::StatusCode::kUnknown, + HasSubstr("No matching overloads")))) + << test_case.expr; + break; + case EqualityTestCase::ErrorKind::kMissingIdentifier: + EXPECT_THAT(result, test::IsCelError( + StatusIs(absl::StatusCode::kUnknown, + HasSubstr("found in Activation")))); + break; + default: + EXPECT_THAT(result, test::IsCelError(_)); + break; + } + } +} + +INSTANTIATE_TEST_SUITE_P( + Equality, EqualityFunctionTest, + Combine(testing::ValuesIn( + {{"null == null", true}, + {"true == false", false}, + {"1 == 1", true}, + {"-2 == -1", false}, + {"1.1 == 1.2", false}, + {"'a' == 'a'", true}, + {"lhs == rhs", false, CelValue::CreateBytesView("a"), + CelValue::CreateBytesView("b")}, + {"lhs == rhs", false, + CelValue::CreateDuration(absl::Seconds(1)), + CelValue::CreateDuration(absl::Seconds(2))}, + {"lhs == rhs", true, + CelValue::CreateTimestamp(absl::FromUnixSeconds(20)), + CelValue::CreateTimestamp(absl::FromUnixSeconds(20))}, + // This should fail before getting to the equal operator. + {"no_such_identifier == 1", + EqualityTestCase::ErrorKind::kMissingIdentifier}, + {"{1: no_such_identifier} == {1: 1}", + EqualityTestCase::ErrorKind::kMissingIdentifier}}), + // heterogeneous equality enabled + testing::Bool())); + +INSTANTIATE_TEST_SUITE_P( + Inequality, EqualityFunctionTest, + Combine(testing::ValuesIn( + {{"null != null", false}, + {"true != false", true}, + {"1 != 1", false}, + {"-2 != -1", true}, + {"1.1 != 1.2", true}, + {"'a' != 'a'", false}, + {"lhs != rhs", true, CelValue::CreateBytesView("a"), + CelValue::CreateBytesView("b")}, + {"lhs != rhs", true, + CelValue::CreateDuration(absl::Seconds(1)), + CelValue::CreateDuration(absl::Seconds(2))}, + {"lhs != rhs", true, + CelValue::CreateTimestamp(absl::FromUnixSeconds(20)), + CelValue::CreateTimestamp(absl::FromUnixSeconds(30))}, + // This should fail before getting to the equal operator. + {"no_such_identifier != 1", + EqualityTestCase::ErrorKind::kMissingIdentifier}, + {"{1: no_such_identifier} != {1: 1}", + EqualityTestCase::ErrorKind::kMissingIdentifier}}), + // heterogeneous equality enabled + testing::Bool())); + +INSTANTIATE_TEST_SUITE_P(HeterogeneousNumericContainers, EqualityFunctionTest, + Combine(testing::ValuesIn({ + {"{1: 2} == {1u: 2}", true}, + {"{1: 2} == {2u: 2}", false}, + {"{1: 2} == {true: 2}", false}, + {"{1: 2} != {1u: 2}", false}, + {"{1: 2} != {2u: 2}", true}, + {"{1: 2} != {true: 2}", true}, + {"[1u, 2u, 3.0] != [1, 2.0, 3]", false}, + {"[1u, 2u, 3.0] == [1, 2.0, 3]", true}, + {"[1u, 2u, 3.0] != [1, 2.1, 3]", true}, + {"[1u, 2u, 3.0] == [1, 2.1, 3]", false}, + }), + // heterogeneous equality enabled + testing::Values(true))); + +INSTANTIATE_TEST_SUITE_P( + HomogenousNumericContainers, EqualityFunctionTest, + Combine(testing::ValuesIn({ + {"{1: 2} == {1u: 2}", false}, + {"{1: 2} == {2u: 2}", false}, + {"{1: 2} == {true: 2}", false}, + {"{1: 2} != {1u: 2}", true}, + {"{1: 2} != {2u: 2}", true}, + {"{1: 2} != {true: 2}", true}, + {"[1u, 2u, 3.0] != [1, 2.0, 3]", + EqualityTestCase::ErrorKind::kMissingOverload}, + {"[1u, 2u, 3.0] == [1, 2.0, 3]", + EqualityTestCase::ErrorKind::kMissingOverload}, + {"[1u, 2u, 3.0] != [1, 2.1, 3]", + EqualityTestCase::ErrorKind::kMissingOverload}, + {"[1u, 2u, 3.0] == [1, 2.1, 3]", + EqualityTestCase::ErrorKind::kMissingOverload}, + }), + // heterogeneous equality enabled + testing::Values(false))); + +INSTANTIATE_TEST_SUITE_P( + NullInequalityLegacy, EqualityFunctionTest, + Combine(testing::ValuesIn( + {{"null != null", false}, + {"true != null", + EqualityTestCase::ErrorKind::kMissingOverload}, + {"1 != null", EqualityTestCase::ErrorKind::kMissingOverload}, + {"-2 != null", EqualityTestCase::ErrorKind::kMissingOverload}, + {"1.1 != null", EqualityTestCase::ErrorKind::kMissingOverload}, + {"'a' != null", EqualityTestCase::ErrorKind::kMissingOverload}, + {"lhs != null", EqualityTestCase::ErrorKind::kMissingOverload, + CelValue::CreateBytesView("a")}, + {"lhs != null", EqualityTestCase::ErrorKind::kMissingOverload, + CelValue::CreateDuration(absl::Seconds(1))}, + {"lhs != null", EqualityTestCase::ErrorKind::kMissingOverload, + CelValue::CreateTimestamp(absl::FromUnixSeconds(20))}}), + // heterogeneous equality enabled + testing::Values(false))); + +INSTANTIATE_TEST_SUITE_P( + NullEqualityLegacy, EqualityFunctionTest, + Combine(testing::ValuesIn( + {{"null == null", true}, + {"true == null", + EqualityTestCase::ErrorKind::kMissingOverload}, + {"1 == null", EqualityTestCase::ErrorKind::kMissingOverload}, + {"-2 == null", EqualityTestCase::ErrorKind::kMissingOverload}, + {"1.1 == null", EqualityTestCase::ErrorKind::kMissingOverload}, + {"'a' == null", EqualityTestCase::ErrorKind::kMissingOverload}, + {"lhs == null", EqualityTestCase::ErrorKind::kMissingOverload, + CelValue::CreateBytesView("a")}, + {"lhs == null", EqualityTestCase::ErrorKind::kMissingOverload, + CelValue::CreateDuration(absl::Seconds(1))}, + {"lhs == null", EqualityTestCase::ErrorKind::kMissingOverload, + CelValue::CreateTimestamp(absl::FromUnixSeconds(20))}}), + // heterogeneous equality enabled + testing::Values(false))); + +INSTANTIATE_TEST_SUITE_P( + NullInequality, EqualityFunctionTest, + Combine(testing::ValuesIn( + {{"null != null", false}, + {"true != null", true}, + {"null != false", true}, + {"1 != null", true}, + {"null != 1", true}, + {"-2 != null", true}, + {"null != -2", true}, + {"1.1 != null", true}, + {"null != 1.1", true}, + {"'a' != null", true}, + {"lhs != null", true, CelValue::CreateBytesView("a")}, + {"lhs != null", true, + CelValue::CreateDuration(absl::Seconds(1))}, + {"google.api.expr.runtime.TestMessage{} != null", true}, + {"google.api.expr.runtime.TestMessage{}.string_wrapper_value" + " != null", + false}, + {"google.api.expr.runtime.TestMessage{string_wrapper_value: " + "google.protobuf.StringValue{}}.string_wrapper_value != null", + true}, + {"{} != null", true}, + {"[] != null", true}}), + // heterogeneous equality enabled + testing::Values(true))); + +INSTANTIATE_TEST_SUITE_P( + NullEquality, EqualityFunctionTest, + Combine(testing::ValuesIn({ + {"null == null", true}, + {"true == null", false}, + {"null == false", false}, + {"1 == null", false}, + {"null == 1", false}, + {"-2 == null", false}, + {"null == -2", false}, + {"1.1 == null", false}, + {"null == 1.1", false}, + {"'a' == null", false}, + {"lhs == null", false, CelValue::CreateBytesView("a")}, + {"lhs == null", false, + CelValue::CreateDuration(absl::Seconds(1))}, + {"google.api.expr.runtime.TestMessage{} == null", false}, + + {"google.api.expr.runtime.TestMessage{}.string_wrapper_value" + " == null", + true}, + {"google.api.expr.runtime.TestMessage{string_wrapper_value: " + "google.protobuf.StringValue{}}.string_wrapper_value == null", + false}, + {"{} == null", false}, + {"[] == null", false}, + }), + // heterogeneous equality enabled + testing::Values(true))); + +INSTANTIATE_TEST_SUITE_P( + ProtoEquality, EqualityFunctionTest, + Combine(testing::ValuesIn({ + {"google.api.expr.runtime.TestMessage{} == null", false}, + {"google.api.expr.runtime.TestMessage{string_wrapper_value: " + "google.protobuf.StringValue{}}.string_wrapper_value == ''", + true}, + {"google.api.expr.runtime.TestMessage{" + "int64_wrapper_value: " + "google.protobuf.Int64Value{value: 1}," + "double_value: 1.1} == " + "google.api.expr.runtime.TestMessage{" + "int64_wrapper_value: " + "google.protobuf.Int64Value{value: 1}," + "double_value: 1.1}", + true}, + // ProtoDifferencer::Equals distinguishes set fields vs + // defaulted + {"google.api.expr.runtime.TestMessage{" + "string_wrapper_value: google.protobuf.StringValue{}} == " + "google.api.expr.runtime.TestMessage{}", + false}, + // Differently typed messages inequal. + {"google.api.expr.runtime.TestMessage{} == " + "google.rpc.context.AttributeContext{}", + false}, + }), + // heterogeneous equality enabled + testing::Values(true))); + +void RunBenchmark(absl::string_view expr, benchmark::State& benchmark) { + InterpreterOptions opts; + auto builder = CreateCelExpressionBuilder(opts); + ASSERT_THAT(RegisterEqualityFunctions(builder->GetRegistry(), opts), IsOk()); + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse(expr)); + google::protobuf::Arena arena; + Activation activation; + + ASSERT_OK_AND_ASSIGN(auto plan, + builder->CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + for (auto _ : benchmark) { + ASSERT_OK_AND_ASSIGN(auto result, plan->Evaluate(activation, &arena)); + ASSERT_TRUE(result.IsBool()); + } +} + +void RunIdentBenchmark(const CelValue& lhs, const CelValue& rhs, + benchmark::State& benchmark) { + InterpreterOptions opts; + auto builder = CreateCelExpressionBuilder(opts); + ASSERT_THAT(RegisterEqualityFunctions(builder->GetRegistry(), opts), IsOk()); + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse("lhs == rhs")); + google::protobuf::Arena arena; + Activation activation; + activation.InsertValue("lhs", lhs); + activation.InsertValue("rhs", rhs); + + ASSERT_OK_AND_ASSIGN(auto plan, + builder->CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + for (auto _ : benchmark) { + ASSERT_OK_AND_ASSIGN(auto result, plan->Evaluate(activation, &arena)); + ASSERT_TRUE(result.IsBool()); + } +} + +void BM_EqualsInt(benchmark::State& s) { RunBenchmark("42 == 43", s); } + +BENCHMARK(BM_EqualsInt); + +void BM_EqualsString(benchmark::State& s) { + RunBenchmark("'1234' == '1235'", s); +} + +BENCHMARK(BM_EqualsString); + +void BM_EqualsCreatedList(benchmark::State& s) { + RunBenchmark("[1, 2, 3, 4, 5] == [1, 2, 3, 4, 6]", s); +} + +BENCHMARK(BM_EqualsCreatedList); + +void BM_EqualsBoundLegacyList(benchmark::State& s) { + ContainerBackedListImpl lhs( + {CelValue::CreateInt64(1), CelValue::CreateInt64(2), + CelValue::CreateInt64(3), CelValue::CreateInt64(4), + CelValue::CreateInt64(5)}); + ContainerBackedListImpl rhs( + {CelValue::CreateInt64(1), CelValue::CreateInt64(2), + CelValue::CreateInt64(3), CelValue::CreateInt64(4), + CelValue::CreateInt64(6)}); + + RunIdentBenchmark(CelValue::CreateList(&lhs), CelValue::CreateList(&rhs), s); +} + +BENCHMARK(BM_EqualsBoundLegacyList); + +void BM_EqualsCreatedMap(benchmark::State& s) { + RunBenchmark("{1: 2, 2: 3, 3: 6} == {1: 2, 2: 3, 3: 6}", s); +} + +BENCHMARK(BM_EqualsCreatedMap); + +} // namespace +} // namespace google::api::expr::runtime diff --git a/eval/public/extension_func_registrar.cc b/eval/public/extension_func_registrar.cc index 1ac79e12c..d3411e9fc 100644 --- a/eval/public/extension_func_registrar.cc +++ b/eval/public/extension_func_registrar.cc @@ -1,14 +1,250 @@ #include "eval/public/extension_func_registrar.h" +#include +#include +#include + +#include "google/type/timeofday.pb.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/time/civil_time.h" +#include "absl/time/time.h" +#include "eval/public/cel_function_adapter.h" #include "eval/public/cel_function_registry.h" +#include "eval/public/cel_value.h" +#include "eval/public/structs/cel_proto_wrapper.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/message.h" namespace google { namespace api { namespace expr { namespace runtime { -absl::Status RegisterExtensionFunctions(CelFunctionRegistry*) { - return absl::OkStatus(); +using google::protobuf::Arena; + +CelValue BetweenTs(Arena* arena, absl::Time time_stamp, absl::Time start, + absl::Time stop) { + bool is_between = false; + // check if timestamp paremeter is between start and stop parameters + is_between = (start <= time_stamp) && (time_stamp < stop); + return CelValue::CreateBool(is_between); +} + +CelValue BetweenStr(Arena* arena, absl::Time time_stamp, + absl::string_view start, absl::string_view stop) { + // convert start and stop into timestamps + absl::Time start_ts; + absl::Time stop_ts; + // check if timestamp parameter is between start and stop -> call BetweenTs + if (!absl::ParseTime(absl::RFC3339_full, start, &start_ts, nullptr) || + !absl::ParseTime(absl::RFC3339_full, stop, &stop_ts, nullptr)) { + return CreateErrorValue(arena, "String to Timestamp conversion failed", + absl::StatusCode::kInvalidArgument); + } + + return BetweenTs(arena, time_stamp, start_ts, stop_ts); +} + +CelValue GetDateTz(Arena* arena, absl::Time time_stamp, + absl::TimeZone time_zone) { + absl::Time ret_date; + absl::CivilDay normalized_date; + // convert absl time to civil day, which normalizes to midnight time, + // convert the result to CivilSecond + // convert CivilSecond from previous step back into absl::Time + normalized_date = absl::ToCivilDay(time_stamp, time_zone); + absl::CivilSecond normalized_date_cs(normalized_date); + ret_date = absl::FromCivil(normalized_date_cs, time_zone); + return CelValue::CreateTimestamp(ret_date); +} + +CelValue GetDate(Arena* arena, absl::Time time_stamp, + absl::string_view time_zone) { + absl::TimeZone time_zone_tz; + // convert timezone from string to TimeZone + if (!absl::LoadTimeZone(time_zone, &time_zone_tz)) { + return CreateErrorValue(arena, "String to Timezone conversion failed", + absl::StatusCode::kInvalidArgument); + } + return GetDateTz(arena, time_stamp, time_zone_tz); +} + +CelValue GetDateUTC(Arena* arena, absl::Time time_stamp) { + absl::TimeZone time_zone = absl::UTCTimeZone(); + return GetDateTz(arena, time_stamp, time_zone); +} + +CelValue GetTimeOfDayTz(Arena* arena, absl::Time time_stamp, + absl::TimeZone time_zone) { + absl::CivilSecond date_civil_time = + absl::ToCivilSecond(time_stamp, time_zone); + google::type::TimeOfDay* tod_message = + Arena::Create(arena); + + tod_message->set_seconds(date_civil_time.second()); + tod_message->set_minutes(date_civil_time.minute()); + tod_message->set_hours(date_civil_time.hour()); + // transform into celvalue for return + + return CelProtoWrapper::CreateMessage(tod_message, arena); +} + +CelValue GetTimeOfDay(Arena* arena, absl::Time time_stamp, + absl::string_view time_zone) { + absl::TimeZone time_zonetz; + + if (!absl::LoadTimeZone(time_zone, &time_zonetz)) { + return CreateErrorValue(arena, "String to Timezone conversion failed", + absl::StatusCode::kInvalidArgument); + } + + return GetTimeOfDayTz(arena, time_stamp, time_zonetz); +} + +CelValue GetTimeOfDayUTC(Arena* arena, absl::Time time_stamp) { + absl::TimeZone utc = absl::UTCTimeZone(); + // call to helper function GetTimeOfDayTz + // return value from helper + return GetTimeOfDayTz(arena, time_stamp, utc); +} + +int ToSeconds(const google::type::TimeOfDay* time_of_day) { + int seconds = 0; + + seconds += time_of_day->hours() * 60 * 60; + seconds += time_of_day->minutes() * 60; + seconds += time_of_day->seconds(); + + return seconds; +} + +CelValue BetweenToD(Arena* arena, const google::protobuf::Message* time_of_day, + const google::protobuf::Message* start, const google::protobuf::Message* stop) { + bool is_between; + const google::type::TimeOfDay* time_of_day_tod = + google::protobuf::DynamicCastMessage(time_of_day); + const google::type::TimeOfDay* start_tod = + google::protobuf::DynamicCastMessage(start); + const google::type::TimeOfDay* stop_tod = + google::protobuf::DynamicCastMessage(stop); + + if ((time_of_day_tod == nullptr) || (start_tod == nullptr) || + (stop_tod == nullptr)) { + return CreateErrorValue(arena, "Message type downcast failed", + absl::StatusCode::kInvalidArgument); + } + // resolution for TimeOfDay in this function is 1 second + int start_time = ToSeconds(start_tod); + int stop_time = ToSeconds(stop_tod); + int tod_time = ToSeconds(time_of_day_tod); + + is_between = (tod_time >= start_time) && (tod_time < stop_time); + return CelValue::CreateBool(is_between); +} + +CelValue BetweenToDStr(Arena* arena, const google::protobuf::Message* time_of_day, + absl::string_view start, absl::string_view stop) { + std::string start_date_time = absl::StrCat("1970-01-01T", start, "+00:00"); + std::string stop_date_time = absl::StrCat("1970-01-01T", stop, "+00:00"); + absl::Time start_ts; + absl::Time stop_ts; + // format of time of day string: "HH:MM:SS" + // Below we prepend a generic date string and append a generic timezone string + // this generates a full timestamp string that can be parsed with ParseTime() + + if (!absl::ParseTime(absl::RFC3339_sec, start_date_time, absl::UTCTimeZone(), + &start_ts, nullptr) || + !absl::ParseTime(absl::RFC3339_sec, stop_date_time, absl::UTCTimeZone(), + &stop_ts, nullptr)) { + return CreateErrorValue(arena, "String to Timestamp conversion failed", + absl::StatusCode::kInvalidArgument); + } + + const google::protobuf::Message* start_msg = + GetTimeOfDayUTC(arena, start_ts).MessageOrDie(); + const google::protobuf::Message* stop_msg = + GetTimeOfDayUTC(arena, stop_ts).MessageOrDie(); + + return BetweenToD(arena, time_of_day, start_msg, stop_msg); +} + +absl::Status RegisterExtensionFunctions(CelFunctionRegistry* registry) { + auto status = FunctionAdapter:: + CreateAndRegister( + "between", true, + [](Arena* arena, absl::Time ts, absl::Time start, absl::Time stop) + -> CelValue { return BetweenTs(arena, ts, start, stop); }, + registry); + if (!status.ok()) return status; + + status = FunctionAdapter:: + CreateAndRegister( + "between", true, + [](Arena* arena, absl::Time ts, CelValue::StringHolder start, + CelValue::StringHolder stop) -> CelValue { + return BetweenStr(arena, ts, start.value(), stop.value()); + }, + registry); + if (!status.ok()) return status; + + status = FunctionAdapter:: + CreateAndRegister( + "date", true, + [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) + -> CelValue { return GetDate(arena, ts, tz.value()); }, + registry); + if (!status.ok()) return status; + + status = FunctionAdapter::CreateAndRegister( + "date", true, + [](Arena* arena, absl::Time ts) -> CelValue { + return GetDateUTC(arena, ts); + }, + registry); + if (!status.ok()) return status; + + status = FunctionAdapter:: + CreateAndRegister( + "timeOfDay", true, + [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) + -> CelValue { return GetTimeOfDay(arena, ts, tz.value()); }, + registry); + if (!status.ok()) return status; + + status = FunctionAdapter::CreateAndRegister( + "timeOfDay", true, + [](Arena* arena, absl::Time ts) -> CelValue { + return GetTimeOfDayUTC(arena, ts); + }, + registry); + if (!status.ok()) return status; + + status = FunctionAdapter:: + CreateAndRegister( + "between", true, + [](Arena* arena, const google::protobuf::Message* tod, + const google::protobuf::Message* start, + const google::protobuf::Message* stop) -> CelValue { + return BetweenToD(arena, tod, start, stop); + }, + registry); + if (!status.ok()) return status; + + status = FunctionAdapter:: + CreateAndRegister( + "between", true, + [](Arena* arena, const google::protobuf::Message* tod, + CelValue::StringHolder start, + CelValue::StringHolder stop) -> CelValue { + return BetweenToDStr(arena, tod, start.value(), stop.value()); + }, + registry); + + return status; } } // namespace runtime diff --git a/eval/public/extension_func_test.cc b/eval/public/extension_func_test.cc index aad518b7e..2e2497d7d 100644 --- a/eval/public/extension_func_test.cc +++ b/eval/public/extension_func_test.cc @@ -1,11 +1,24 @@ -#include "google/protobuf/util/time_util.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" +#include +#include + +#include "google/type/timeofday.pb.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/time/civil_time.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "absl/types/span.h" #include "eval/public/builtin_func_registrar.h" +#include "eval/public/cel_expr_builder_factory.h" +#include "eval/public/cel_expression.h" #include "eval/public/cel_function_registry.h" +#include "eval/public/cel_value.h" #include "eval/public/extension_func_registrar.h" -#include "base/status_macros.h" +#include "eval/public/structs/cel_proto_wrapper.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "google/protobuf/message.h" +#include "google/protobuf/util/time_util.h" namespace google { namespace api { @@ -13,8 +26,6 @@ namespace expr { namespace runtime { namespace { -using google::protobuf::Duration; -using google::protobuf::Timestamp; using google::protobuf::Arena; static const int kNanosPerSecond = 1000000000; @@ -68,7 +79,7 @@ class ExtensionTest : public ::testing::Test { } // Helper method to test timestamp() function - void PerformTimestampConversion(Arena* arena, std::string ts_str, + void PerformTimestampConversion(Arena* arena, const std::string& ts_str, CelValue* result) { auto functions = registry_.FindOverloads("timestamp", false, {CelValue::Type::kString}); @@ -83,8 +94,153 @@ class ExtensionTest : public ::testing::Test { ASSERT_OK(status); } + void PerformBetweenTest(Arena* arena, absl::Time time_stamp, + absl::Time start_ts, absl::Time stop_ts, + CelValue* result) { + auto functions = registry_.FindOverloads( + "between", true, + {CelValue::Type::kTimestamp, CelValue::Type::kTimestamp, + CelValue::Type::kTimestamp}); + ASSERT_EQ(functions.size(), 1); + + auto func = functions[0]; + + std::vector args = {CelValue::CreateTimestamp(time_stamp), + CelValue::CreateTimestamp(start_ts), + CelValue::CreateTimestamp(stop_ts)}; + absl::Span arg_span(&args[0], args.size()); + auto status = func->Evaluate(arg_span, result, arena); + + ASSERT_OK(status); + } + + void PerformBetweenStrTest(Arena* arena, absl::Time time_stamp, + std::string* start, std::string* stop, + CelValue* result) { + auto functions = registry_.FindOverloads( + "between", true, + {CelValue::Type::kTimestamp, CelValue::Type::kString, + CelValue::Type::kString}); + ASSERT_EQ(functions.size(), 1); + + auto func = functions[0]; + + std::vector args = {CelValue::CreateTimestamp(time_stamp), + CelValue::CreateString(start), + CelValue::CreateString(stop)}; + absl::Span arg_span(&args[0], args.size()); + auto status = func->Evaluate(arg_span, result, arena); + + ASSERT_OK(status); + } + + void PerformGetDateTest(Arena* arena, absl::Time time_stamp, + std::string* time_zone, CelValue* result) { + auto functions = registry_.FindOverloads( + "date", true, {CelValue::Type::kTimestamp, CelValue::Type::kString}); + ASSERT_EQ(functions.size(), 1); + + auto func = functions[0]; + + std::vector args = {CelValue::CreateTimestamp(time_stamp), + CelValue::CreateString(time_zone)}; + absl::Span arg_span(&args[0], args.size()); + auto status = func->Evaluate(arg_span, result, arena); + + ASSERT_OK(status); + } + + void PerformGetDateUTCTest(Arena* arena, absl::Time time_stamp, + CelValue* result) { + auto functions = + registry_.FindOverloads("date", true, {CelValue::Type::kTimestamp}); + ASSERT_EQ(functions.size(), 1); + + auto func = functions[0]; + + std::vector args = {CelValue::CreateTimestamp(time_stamp)}; + absl::Span arg_span(&args[0], args.size()); + auto status = func->Evaluate(arg_span, result, arena); + + ASSERT_OK(status); + } + + void PerformGetTimeOfDayTest(Arena* arena, absl::Time time_stamp, + std::string* time_zone, CelValue* result) { + auto functions = registry_.FindOverloads( + "timeOfDay", true, + {CelValue::Type::kTimestamp, CelValue::Type::kString}); + ASSERT_EQ(functions.size(), 1); + + auto func = functions[0]; + + std::vector args = {CelValue::CreateTimestamp(time_stamp), + CelValue::CreateString(time_zone)}; + absl::Span arg_span(&args[0], args.size()); + auto status = func->Evaluate(arg_span, result, arena); + + ASSERT_OK(status); + } + + void PerformGetTimeOfDayUTCTest(Arena* arena, absl::Time time_stamp, + CelValue* result) { + auto functions = registry_.FindOverloads("timeOfDay", true, + {CelValue::Type::kTimestamp}); + ASSERT_EQ(functions.size(), 1); + + auto func = functions[0]; + + std::vector args = {CelValue::CreateTimestamp(time_stamp)}; + absl::Span arg_span(&args[0], args.size()); + auto status = func->Evaluate(arg_span, result, arena); + + ASSERT_OK(status); + } + + void PerformBetweenToDTest(Arena* arena, const google::protobuf::Message* time_of_day, + const google::protobuf::Message* start, + const google::protobuf::Message* stop, CelValue* result) { + auto functions = registry_.FindOverloads( + "between", true, + {CelValue::Type::kMessage, CelValue::Type::kMessage, + CelValue::Type::kMessage}); + ASSERT_EQ(functions.size(), 1); + + auto func = functions[0]; + + std::vector args = { + CelProtoWrapper::CreateMessage(time_of_day, arena), + CelProtoWrapper::CreateMessage(start, arena), + CelProtoWrapper::CreateMessage(stop, arena)}; + absl::Span arg_span(&args[0], args.size()); + auto status = func->Evaluate(arg_span, result, arena); + + ASSERT_OK(status); + } + + void PerformBetweenToDStrTest(Arena* arena, + const google::protobuf::Message* time_of_day, + std::string* start, std::string* stop, + CelValue* result) { + auto functions = registry_.FindOverloads( + "between", true, + {CelValue::Type::kMessage, CelValue::Type::kString, + CelValue::Type::kString}); + ASSERT_EQ(functions.size(), 1); + + auto func = functions[0]; + + std::vector args = { + CelProtoWrapper::CreateMessage(time_of_day, arena), + CelValue::CreateString(start), CelValue::CreateString(stop)}; + absl::Span arg_span(&args[0], args.size()); + auto status = func->Evaluate(arg_span, result, arena); + + ASSERT_OK(status); + } + // Helper method to test duration() function - void PerformDurationConversion(Arena* arena, std::string ts_str, + void PerformDurationConversion(Arena* arena, const std::string& ts_str, CelValue* result) { auto functions = registry_.FindOverloads("duration", false, {CelValue::Type::kString}); @@ -101,6 +257,7 @@ class ExtensionTest : public ::testing::Test { // Function registry object CelFunctionRegistry registry_; + Arena arena_; }; // Test string startsWith() function. @@ -211,6 +368,227 @@ TEST_F(ExtensionTest, TestDurationFromString) { ASSERT_TRUE(result.IsError()); } +TEST_F(ExtensionTest, TestBetweenTs) { + absl::Time time_1; + absl::Time time_2; + absl::Time time_3; + std::string time_stampstr = "1997-07-16T19:50:30.45+01:00"; + std::string time_start = "1997-07-16T19:20:30.45+01:00"; + std::string time_stop = "1997-07-16T20:20:30.45+01:00"; + Arena arena; + CelValue result; + + absl::ParseTime(absl::RFC3339_full, time_stampstr, &time_2, nullptr); + absl::ParseTime(absl::RFC3339_full, time_start, &time_1, nullptr); + absl::ParseTime(absl::RFC3339_full, time_stop, &time_3, nullptr); + + PerformBetweenTest(&arena, time_2, time_1, time_3, &result); + ASSERT_EQ(result.BoolOrDie(), true); + PerformBetweenTest(&arena, time_1, time_2, time_3, &result); + ASSERT_EQ(result.BoolOrDie(), false); + PerformBetweenTest(&arena, time_1, time_1, time_3, &result); + ASSERT_EQ(result.BoolOrDie(), true); + PerformBetweenTest(&arena, time_3, time_1, time_2, &result); + ASSERT_EQ(result.BoolOrDie(), false); + PerformBetweenTest(&arena, time_3, time_1, time_3, &result); + ASSERT_EQ(result.BoolOrDie(), false); +} + +TEST_F(ExtensionTest, TestBetweenStr) { + Arena arena; + absl::Time time_stamp; + CelValue result; + std::string time_stampstr = "1997-07-16T19:50:30.45+01:00"; + std::string time_start = "1997-07-16T19:20:30.45+01:00"; + std::string time_stop = "1997-07-16T20:20:30.45+01:00"; + + absl::ParseTime(absl::RFC3339_full, time_stampstr, &time_stamp, nullptr); + PerformBetweenStrTest(&arena, time_stamp, &time_start, &time_stop, &result); + ASSERT_EQ(result.BoolOrDie(), true); + + absl::ParseTime(absl::RFC3339_full, time_start, &time_stamp, nullptr); + PerformBetweenStrTest(&arena, time_stamp, &time_start, &time_stop, &result); + ASSERT_EQ(result.BoolOrDie(), true); + + absl::ParseTime(absl::RFC3339_full, time_stop, &time_stamp, nullptr); + PerformBetweenStrTest(&arena, time_stamp, &time_start, &time_stop, &result); + ASSERT_EQ(result.BoolOrDie(), false); + + time_stampstr = "1997-07-16T18:20:30.45+01:00"; + absl::ParseTime(absl::RFC3339_full, time_stampstr, &time_stamp, nullptr); + PerformBetweenStrTest(&arena, time_stamp, &time_start, &time_stop, &result); + ASSERT_EQ(result.BoolOrDie(), false); + + time_stampstr = "1997-07-16T21:20:30.45+01:00"; + absl::ParseTime(absl::RFC3339_full, time_stampstr, &time_stamp, nullptr); + PerformBetweenStrTest(&arena, time_stamp, &time_start, &time_stop, &result); + ASSERT_EQ(result.BoolOrDie(), false); +} + +TEST_F(ExtensionTest, TestGetDate) { + Arena arena; + CelValue result; + absl::CivilSecond date(2015, 2, 3, 4, 5, 6); + absl::CivilSecond normal_date(2015, 2, 3); + absl::TimeZone time_zone; + std::string time_zonestr = "America/Los_Angeles"; + absl::LoadTimeZone(time_zonestr, &time_zone); + + absl::Time expected_val = absl::FromCivil(normal_date, time_zone); + absl::Time input_val = absl::FromCivil(date, time_zone); + + PerformGetDateTest(&arena, input_val, &time_zonestr, &result); + ASSERT_EQ(result.TimestampOrDie(), expected_val); +} + +TEST_F(ExtensionTest, TestGetDateUTC) { + Arena arena; + CelValue result; + absl::CivilSecond date(2015, 2, 3, 4, 5, 6); + absl::CivilSecond normal_date(2015, 2, 3); + absl::TimeZone time_zone = absl::UTCTimeZone(); + + absl::Time expected_val = absl::FromCivil(normal_date, time_zone); + absl::Time input_val = absl::FromCivil(date, time_zone); + + PerformGetDateUTCTest(&arena, input_val, &result); + ASSERT_EQ(result.TimestampOrDie(), expected_val); +} + +TEST_F(ExtensionTest, TestGetTimeOfDay) { + Arena arena; + CelValue result; + absl::CivilSecond date(2015, 2, 3, 4, 5, 6); + absl::TimeZone time_zone; + std::string time_zonestr = "America/Los_Angeles"; + google::type::TimeOfDay* tod_message = + Arena::Create(&arena); + + absl::LoadTimeZone(time_zonestr, &time_zone); + absl::Time input_val = absl::FromCivil(date, time_zone); + + tod_message->set_seconds(date.second()); + tod_message->set_minutes(date.minute()); + tod_message->set_hours(date.hour()); + + PerformGetTimeOfDayTest(&arena, input_val, &time_zonestr, &result); + const google::type::TimeOfDay* time_of_day_tod = + google::protobuf::DynamicCastMessage( + result.MessageOrDie()); + + ASSERT_EQ(time_of_day_tod->seconds(), tod_message->seconds()); + ASSERT_EQ(time_of_day_tod->minutes(), tod_message->minutes()); + ASSERT_EQ(time_of_day_tod->hours(), tod_message->hours()); +} + +TEST_F(ExtensionTest, TestGetTimeOfDayUTC) { + Arena arena; + CelValue result; + absl::TimeZone time_zone = absl::UTCTimeZone(); + absl::CivilSecond date(2015, 2, 3, 4, 5, 6); + absl::Time input_time = absl::FromCivil(date, time_zone); + google::type::TimeOfDay* tod_message = + Arena::Create(&arena); + + tod_message->set_seconds(date.second()); + tod_message->set_minutes(date.minute()); + tod_message->set_hours(date.hour()); + + PerformGetTimeOfDayUTCTest(&arena, input_time, &result); + const google::type::TimeOfDay* time_of_day_tod = + google::protobuf::DynamicCastMessage( + result.MessageOrDie()); + + ASSERT_EQ(time_of_day_tod->seconds(), tod_message->seconds()); + ASSERT_EQ(time_of_day_tod->minutes(), tod_message->minutes()); + ASSERT_EQ(time_of_day_tod->hours(), tod_message->hours()); +} + +TEST_F(ExtensionTest, TestBetweenToD) { + Arena arena; + CelValue result; + google::type::TimeOfDay* time_of_day = + Arena::Create(&arena); + google::type::TimeOfDay* start = + Arena::Create(&arena); + google::type::TimeOfDay* stop = + Arena::Create(&arena); + + start->set_hours(20); + start->set_minutes(0); + start->set_seconds(0); + stop->set_hours(21); + stop->set_minutes(0); + stop->set_seconds(0); + time_of_day->set_hours(20); + time_of_day->set_minutes(30); + time_of_day->set_seconds(0); + + PerformBetweenToDTest(&arena, time_of_day, start, stop, &result); + ASSERT_EQ(result.BoolOrDie(), true); + + time_of_day->set_minutes(0); + PerformBetweenToDTest(&arena, time_of_day, start, stop, &result); + ASSERT_EQ(result.BoolOrDie(), true); + + time_of_day->set_hours(19); + PerformBetweenToDTest(&arena, time_of_day, start, stop, &result); + ASSERT_EQ(result.BoolOrDie(), false); + + time_of_day->set_hours(21); + PerformBetweenToDTest(&arena, time_of_day, start, stop, &result); + ASSERT_EQ(result.BoolOrDie(), false); + + time_of_day->set_seconds(1); + PerformBetweenToDTest(&arena, time_of_day, start, stop, &result); + ASSERT_EQ(result.BoolOrDie(), false); +} + +TEST_F(ExtensionTest, TestBetweenTodStr) { + Arena arena; + CelValue result; + std::string start = "18:20:30"; + std::string stop = "19:20:30"; + google::type::TimeOfDay* time_of_day = + Arena::Create(&arena); + + time_of_day->set_hours(19); + time_of_day->set_minutes(0); + time_of_day->set_seconds(0); + + PerformBetweenToDStrTest(&arena, time_of_day, &start, &stop, &result); + ASSERT_EQ(result.BoolOrDie(), true); + + time_of_day->set_hours(18); + time_of_day->set_minutes(20); + time_of_day->set_seconds(30); + + PerformBetweenToDStrTest(&arena, time_of_day, &start, &stop, &result); + ASSERT_EQ(result.BoolOrDie(), true); + + time_of_day->set_seconds(29); + + PerformBetweenToDStrTest(&arena, time_of_day, &start, &stop, &result); + ASSERT_EQ(result.BoolOrDie(), false); + + time_of_day->set_hours(19); + time_of_day->set_minutes(20); + time_of_day->set_seconds(30); + + PerformBetweenToDStrTest(&arena, time_of_day, &start, &stop, &result); + ASSERT_EQ(result.BoolOrDie(), false); + + time_of_day->set_seconds(29); + + PerformBetweenToDStrTest(&arena, time_of_day, &start, &stop, &result); + ASSERT_EQ(result.BoolOrDie(), true); + + time_of_day->set_seconds(31); + + PerformBetweenToDStrTest(&arena, time_of_day, &start, &stop, &result); + ASSERT_EQ(result.BoolOrDie(), false); +} + } // namespace } // namespace runtime diff --git a/eval/public/logical_function_registrar.cc b/eval/public/logical_function_registrar.cc new file mode 100644 index 000000000..f84e9cb1e --- /dev/null +++ b/eval/public/logical_function_registrar.cc @@ -0,0 +1,30 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/logical_function_registrar.h" + +#include "absl/status/status.h" +#include "eval/public/cel_function_registry.h" +#include "eval/public/cel_options.h" +#include "runtime/standard/logical_functions.h" + +namespace google::api::expr::runtime { + +absl::Status RegisterLogicalFunctions(CelFunctionRegistry* registry, + const InterpreterOptions& options) { + return cel::RegisterLogicalFunctions(registry->InternalGetRegistry(), + ConvertToRuntimeOptions(options)); +} + +} // namespace google::api::expr::runtime diff --git a/eval/public/logical_function_registrar.h b/eval/public/logical_function_registrar.h new file mode 100644 index 000000000..9337e3dbb --- /dev/null +++ b/eval/public/logical_function_registrar.h @@ -0,0 +1,36 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_LOGICAL_FUNCTION_REGISTRAR_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_LOGICAL_FUNCTION_REGISTRAR_H_ + +#include "absl/status/status.h" +#include "eval/public/cel_function_registry.h" +#include "eval/public/cel_options.h" + +namespace google::api::expr::runtime { + +// Register logical operators ! and @not_strictly_false. +// +// &&, ||, ?: are special cased by the interpreter (not implemented via the +// function registry.) +// +// Most users should use RegisterBuiltinFunctions, which includes these +// definitions. +absl::Status RegisterLogicalFunctions(CelFunctionRegistry* registry, + const InterpreterOptions& options); + +} // namespace google::api::expr::runtime + +#endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_LOGICAL_FUNCTION_REGISTRAR_H_ diff --git a/eval/public/logical_function_registrar_test.cc b/eval/public/logical_function_registrar_test.cc new file mode 100644 index 000000000..6b7346498 --- /dev/null +++ b/eval/public/logical_function_registrar_test.cc @@ -0,0 +1,127 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/logical_function_registrar.h" + +#include +#include +#include +#include + +#include "cel/expr/syntax.pb.h" +#include "absl/base/no_destructor.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "eval/public/activation.h" +#include "eval/public/cel_expr_builder_factory.h" +#include "eval/public/cel_expression.h" +#include "eval/public/cel_options.h" +#include "eval/public/cel_value.h" +#include "eval/public/portable_cel_function_adapter.h" +#include "eval/public/testing/matchers.h" +#include "internal/testing.h" +#include "parser/parser.h" +#include "google/protobuf/arena.h" + +namespace google::api::expr::runtime { +namespace { + +using cel::expr::Expr; +using cel::expr::SourceInfo; + +using ::absl_testing::StatusIs; +using ::testing::HasSubstr; + +struct TestCase { + std::string test_name; + std::string expr; + absl::StatusOr result = CelValue::CreateBool(true); +}; + +const CelError* ExampleError() { + static absl::NoDestructor error( + absl::InternalError("test example error")); + + return &*error; +} + +void ExpectResult(const TestCase& test_case) { + auto parsed_expr = parser::Parse(test_case.expr); + ASSERT_OK(parsed_expr); + const Expr& expr_ast = parsed_expr->expr(); + const SourceInfo& source_info = parsed_expr->source_info(); + InterpreterOptions options; + options.short_circuiting = true; + std::unique_ptr builder = + CreateCelExpressionBuilder(options); + ASSERT_OK(RegisterLogicalFunctions(builder->GetRegistry(), options)); + ASSERT_OK(builder->GetRegistry()->Register( + PortableUnaryFunctionAdapter::Create( + "toBool", false, + [](google::protobuf::Arena*, CelValue::StringHolder holder) -> CelValue { + if (holder.value() == "true") { + return CelValue::CreateBool(true); + } else if (holder.value() == "false") { + return CelValue::CreateBool(false); + } + return CelValue::CreateError(ExampleError()); + }))); + ASSERT_OK_AND_ASSIGN(auto cel_expression, + builder->CreateExpression(&expr_ast, &source_info)); + + Activation activation; + + google::protobuf::Arena arena; + ASSERT_OK_AND_ASSIGN(auto value, + cel_expression->Evaluate(activation, &arena)); + if (!test_case.result.ok()) { + EXPECT_TRUE(value.IsError()); + EXPECT_THAT(*value.ErrorOrDie(), + StatusIs(test_case.result.status().code(), + HasSubstr(test_case.result.status().message()))); + return; + } + EXPECT_THAT(value, test::EqualsCelValue(*test_case.result)); +} + +using BuiltinFuncParamsTest = testing::TestWithParam; +TEST_P(BuiltinFuncParamsTest, StandardFunctions) { ExpectResult(GetParam()); } + +INSTANTIATE_TEST_SUITE_P( + BuiltinFuncParamsTest, BuiltinFuncParamsTest, + testing::ValuesIn({ + // Legacy duration and timestamp arithmetic tests. + {"LogicalNotOfTrue", "!true", CelValue::CreateBool(false)}, + {"LogicalNotOfFalse", "!false", CelValue::CreateBool(true)}, + // Not strictly false is an internal function for implementing logical + // shortcutting in comprehensions. + {"NotStrictlyFalseTrue", "[true, true, true].all(x, x)", + CelValue::CreateBool(true)}, + // List creation is eager so use an extension function to introduce an + // error. + {"NotStrictlyFalseErrorShortcircuit", + "['true', 'false', 'error'].all(x, toBool(x))", + CelValue::CreateBool(false)}, + {"NotStrictlyFalseError", "['true', 'true', 'error'].all(x, toBool(x))", + CelValue::CreateError(ExampleError())}, + {"NotStrictlyFalseFalse", "[false, false, false].all(x, x)", + CelValue::CreateBool(false)}, + }), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + +} // namespace +} // namespace google::api::expr::runtime diff --git a/eval/public/message_wrapper.h b/eval/public/message_wrapper.h new file mode 100644 index 000000000..698eff5bb --- /dev/null +++ b/eval/public/message_wrapper.h @@ -0,0 +1,139 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_MESSAGE_WRAPPER_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_MESSAGE_WRAPPER_H_ + +#include + +#include "absl/base/attributes.h" +#include "absl/base/macros.h" +#include "absl/numeric/bits.h" +#include "base/internal/message_wrapper.h" +#include "google/protobuf/message.h" +#include "google/protobuf/message_lite.h" + +namespace cel::interop_internal { +struct MessageWrapperAccess; +} // namespace cel::interop_internal + +namespace google::api::expr::runtime { + +// Forward declare to resolve cycle. +class LegacyTypeInfoApis; + +// Wrapper type for protobuf messages. This is used to limit internal usages of +// proto APIs and to support working with the proto lite runtime. +// +// Provides operations for checking if down-casting to Message is safe. +class ABSL_DEPRECATED("Use google::protobuf::Message directly") MessageWrapper { + public: + // Simple builder class. + // + // Wraps a tagged mutable message lite ptr. + class ABSL_DEPRECATED("Use google::protobuf::Message directly") Builder { + public: + explicit Builder(google::protobuf::MessageLite* message) + : message_ptr_(reinterpret_cast(message)) { + ABSL_ASSERT(absl::countr_zero(reinterpret_cast(message)) >= + kTagSize); + } + explicit Builder(google::protobuf::Message* message) + : message_ptr_(reinterpret_cast(message) | kMessageTag) { + ABSL_ASSERT(absl::countr_zero(reinterpret_cast(message)) >= + kTagSize); + } + + google::protobuf::MessageLite* message_ptr() const { + return reinterpret_cast(message_ptr_ & kPtrMask); + } + + bool HasFullProto() const { + return (message_ptr_ & kTagMask) == kMessageTag; + } + + MessageWrapper Build(const LegacyTypeInfoApis* type_info) { + return MessageWrapper(message_ptr_, type_info); + } + + private: + friend class MessageWrapper; + + explicit Builder(uintptr_t message_ptr) : message_ptr_(message_ptr) {} + + uintptr_t message_ptr_; + }; + + static_assert(alignof(google::protobuf::MessageLite) >= 2, + "Assume that valid MessageLite ptrs have a free low-order bit"); + MessageWrapper() : message_ptr_(0), legacy_type_info_(nullptr) {} + + MessageWrapper(const google::protobuf::MessageLite* message, + const LegacyTypeInfoApis* legacy_type_info) + : message_ptr_(reinterpret_cast(message)), + legacy_type_info_(legacy_type_info) { + ABSL_ASSERT(absl::countr_zero(reinterpret_cast(message)) >= + kTagSize); + } + + MessageWrapper(const google::protobuf::Message* message, + const LegacyTypeInfoApis* legacy_type_info) + : message_ptr_(reinterpret_cast(message) | kMessageTag), + legacy_type_info_(legacy_type_info) { + ABSL_ASSERT(absl::countr_zero(reinterpret_cast(message)) >= + kTagSize); + } + + // If true, the message is using the full proto runtime and downcasting to + // message should be safe. + bool HasFullProto() const { return (message_ptr_ & kTagMask) == kMessageTag; } + + // Returns the underlying message. + // + // Clients must check HasFullProto before downcasting to Message. + const google::protobuf::MessageLite* message_ptr() const { + return reinterpret_cast(message_ptr_ & + kPtrMask); + } + + // Type information associated with this message. + const LegacyTypeInfoApis* legacy_type_info() const { + return legacy_type_info_; + } + + private: + friend struct ::cel::interop_internal::MessageWrapperAccess; + + MessageWrapper(uintptr_t message_ptr, + const LegacyTypeInfoApis* legacy_type_info) + : message_ptr_(message_ptr), legacy_type_info_(legacy_type_info) {} + + Builder ToBuilder() { return Builder(message_ptr_); } + + static constexpr int kTagSize = ::cel::base_internal::kMessageWrapperTagSize; + static constexpr uintptr_t kTagMask = + ::cel::base_internal::kMessageWrapperTagMask; + static constexpr uintptr_t kPtrMask = + ::cel::base_internal::kMessageWrapperPtrMask; + static constexpr uintptr_t kMessageTag = + ::cel::base_internal::kMessageWrapperTagMessageValue; + uintptr_t message_ptr_; + const LegacyTypeInfoApis* legacy_type_info_; +}; + +static_assert(sizeof(MessageWrapper) <= 2 * sizeof(uintptr_t), + "MessageWrapper must not increase CelValue size."); + +} // namespace google::api::expr::runtime +#endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_MESSAGE_WRAPPER_H_ diff --git a/eval/public/message_wrapper_test.cc b/eval/public/message_wrapper_test.cc new file mode 100644 index 000000000..15e5e88da --- /dev/null +++ b/eval/public/message_wrapper_test.cc @@ -0,0 +1,85 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/message_wrapper.h" + +#include + +#include "eval/public/structs/trivial_legacy_type_info.h" +#include "eval/testutil/test_message.pb.h" +#include "internal/testing.h" +#include "google/protobuf/message.h" +#include "google/protobuf/message_lite.h" + +namespace google::api::expr::runtime { +namespace { + +TEST(MessageWrapper, Size) { + static_assert(sizeof(MessageWrapper) <= 2 * sizeof(uintptr_t), + "MessageWrapper must not increase CelValue size."); +} + +TEST(MessageWrapper, WrapsMessage) { + TestMessage test_message; + + test_message.set_int64_value(20); + test_message.set_double_value(12.3); + + MessageWrapper wrapped_message(&test_message, TrivialTypeInfo::GetInstance()); + + constexpr bool is_full_proto_runtime = + std::is_base_of_v; + + EXPECT_EQ(wrapped_message.message_ptr(), + static_cast(&test_message)); + ASSERT_EQ(wrapped_message.HasFullProto(), is_full_proto_runtime); +} + +TEST(MessageWrapperBuilder, Builder) { + TestMessage test_message; + + MessageWrapper::Builder builder(&test_message); + constexpr bool is_full_proto_runtime = + std::is_base_of_v; + + ASSERT_EQ(builder.HasFullProto(), is_full_proto_runtime); + + ASSERT_EQ(builder.message_ptr(), + static_cast(&test_message)); + + auto mutable_message = + google::protobuf::DownCastMessage(builder.message_ptr()); + mutable_message->set_int64_value(20); + mutable_message->set_double_value(12.3); + + MessageWrapper wrapped_message = + builder.Build(TrivialTypeInfo::GetInstance()); + + ASSERT_EQ(wrapped_message.message_ptr(), + static_cast(&test_message)); + ASSERT_EQ(wrapped_message.HasFullProto(), is_full_proto_runtime); + EXPECT_EQ(wrapped_message.message_ptr(), + static_cast(&test_message)); + EXPECT_EQ(test_message.int64_value(), 20); + EXPECT_EQ(test_message.double_value(), 12.3); +} + +TEST(MessageWrapper, DefaultNull) { + MessageWrapper wrapper; + EXPECT_EQ(wrapper.message_ptr(), nullptr); + EXPECT_EQ(wrapper.legacy_type_info(), nullptr); +} + +} // namespace +} // namespace google::api::expr::runtime diff --git a/eval/public/portable_cel_function_adapter.h b/eval/public/portable_cel_function_adapter.h new file mode 100644 index 000000000..86e5b1320 --- /dev/null +++ b/eval/public/portable_cel_function_adapter.h @@ -0,0 +1,72 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_PORTABLE_CEL_FUNCTION_ADAPTER_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_PORTABLE_CEL_FUNCTION_ADAPTER_H_ + +#include "eval/public/cel_function_adapter.h" + +namespace google::api::expr::runtime { + +// Portable version of the FunctionAdapter template utility. +// +// The PortableFunctionAdapter variation provides the same interface, +// but doesn't support unwrapping google::protobuf::Message values. See documentation on +// Function adapter for example usage. +// +// Most users should prefer using the standard FunctionAdapter. +template +using PortableFunctionAdapter = FunctionAdapter; + +// PortableUnaryFunctionAdapter provides a factory for adapting 1 argument +// functions to CEL extension functions. +// +// Static Methods: +// +// Create(absl::string_view function_name, bool receiver_style, +// FunctionType func) -> std::unique_ptr +// +// Usage example: +// +// auto func = [](::google::protobuf::Arena* arena, int64_t i) -> int64_t { +// return -i; +// }; +// +// auto cel_func = +// PortableUnaryFunctionAdapter::Create("negate", true, +// func); +template +using PortableUnaryFunctionAdapter = UnaryFunctionAdapter; + +// PortableBinaryFunctionAdapter provides a factory for adapting 2 argument +// functions to CEL extension functions. +// +// Create(absl::string_view function_name, bool receiver_style, +// FunctionType func) -> std::unique_ptr +// +// Usage example: +// +// auto func = [](::google::protobuf::Arena* arena, int64_t i, int64_t j) -> bool { +// return i < j; +// }; +// +// auto cel_func = +// PortableBinaryFunctionAdapter::Create("<", +// false, func); +template +using PortableBinaryFunctionAdapter = BinaryFunctionAdapter; + +} // namespace google::api::expr::runtime + +#endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_PORTABLE_CEL_FUNCTION_ADAPTER_H_ diff --git a/eval/public/set_util.cc b/eval/public/set_util.cc index fd85903f1..60594e5fa 100644 --- a/eval/public/set_util.cc +++ b/eval/public/set_util.cc @@ -1,11 +1,9 @@ #include "eval/public/set_util.h" #include +#include -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace google::api::expr::runtime { namespace { // Default implementation is operator<. @@ -21,6 +19,29 @@ int ComparisonImpl(T lhs, T rhs) { } } +template <> +int ComparisonImpl(const CelError* lhs, const CelError* rhs) { + if (*lhs == *rhs) { + return 0; + } + return lhs < rhs ? -1 : 1; +} + +// Message wrapper specialization +template <> +int ComparisonImpl(CelValue::MessageWrapper lhs_wrapper, + CelValue::MessageWrapper rhs_wrapper) { + auto* lhs = lhs_wrapper.message_ptr(); + auto* rhs = rhs_wrapper.message_ptr(); + if (lhs < rhs) { + return -1; + } else if (lhs > rhs) { + return 1; + } else { + return 0; + } +} + // List specialization -- compare size then elementwise compare. template <> int ComparisonImpl(const CelList* lhs, const CelList* rhs) { @@ -28,9 +49,10 @@ int ComparisonImpl(const CelList* lhs, const CelList* rhs) { if (size_comparison != 0) { return size_comparison; } + google::protobuf::Arena arena; for (int i = 0; i < lhs->size(); i++) { - CelValue lhs_i = lhs->operator[](i); - CelValue rhs_i = rhs->operator[](i); + CelValue lhs_i = lhs->Get(&arena, i); + CelValue rhs_i = rhs->Get(&arena, i); int value_comparison = CelValueCompare(lhs_i, rhs_i); if (value_comparison != 0) { return value_comparison; @@ -51,17 +73,19 @@ int ComparisonImpl(const CelMap* lhs, const CelMap* rhs) { return size_comparison; } + google::protobuf::Arena arena; + std::vector lhs_keys; std::vector rhs_keys; lhs_keys.reserve(lhs->size()); rhs_keys.reserve(lhs->size()); - const CelList* lhs_key_view = lhs->ListKeys(); - const CelList* rhs_key_view = rhs->ListKeys(); + const CelList* lhs_key_view = lhs->ListKeys(&arena).value(); + const CelList* rhs_key_view = rhs->ListKeys(&arena).value(); for (int i = 0; i < lhs->size(); i++) { - lhs_keys.push_back(lhs_key_view->operator[](i)); - rhs_keys.push_back(rhs_key_view->operator[](i)); + lhs_keys.push_back(lhs_key_view->Get(&arena, i)); + rhs_keys.push_back(rhs_key_view->Get(&arena, i)); } std::sort(lhs_keys.begin(), lhs_keys.end(), &CelValueLessThan); @@ -76,8 +100,8 @@ int ComparisonImpl(const CelMap* lhs, const CelMap* rhs) { } // keys equal, compare values. - auto lhs_value_i = lhs->operator[](lhs_key_i).value(); - auto rhs_value_i = rhs->operator[](rhs_key_i).value(); + auto lhs_value_i = lhs->Get(&arena, lhs_key_i).value(); + auto rhs_value_i = rhs->Get(&arena, rhs_key_i).value(); int value_comparison = CelValueCompare(lhs_value_i, rhs_value_i); if (value_comparison != 0) { return value_comparison; @@ -88,8 +112,7 @@ int ComparisonImpl(const CelMap* lhs, const CelMap* rhs) { } struct ComparisonVisitor { - CelValue rhs; - ComparisonVisitor(CelValue rhs) : rhs(rhs) {} + explicit ComparisonVisitor(CelValue rhs) : rhs(rhs) {} template int operator()(T lhs_value) { T rhs_value; @@ -99,27 +122,26 @@ struct ComparisonVisitor { } return ComparisonImpl(lhs_value, rhs_value); } + + CelValue rhs; }; } // namespace int CelValueCompare(CelValue lhs, CelValue rhs) { - return lhs.Visit(ComparisonVisitor(rhs)); + return lhs.InternalVisit(ComparisonVisitor(rhs)); } bool CelValueLessThan(CelValue lhs, CelValue rhs) { - return lhs.Visit(ComparisonVisitor(rhs)) < 0; + return lhs.InternalVisit(ComparisonVisitor(rhs)) < 0; } bool CelValueEqual(CelValue lhs, CelValue rhs) { - return lhs.Visit(ComparisonVisitor(rhs)) == 0; + return lhs.InternalVisit(ComparisonVisitor(rhs)) == 0; } bool CelValueGreaterThan(CelValue lhs, CelValue rhs) { - return lhs.Visit(ComparisonVisitor(rhs)) > 0; + return lhs.InternalVisit(ComparisonVisitor(rhs)) > 0; } -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime diff --git a/eval/public/set_util_test.cc b/eval/public/set_util_test.cc index 0845ac86a..5eeabafdd 100644 --- a/eval/public/set_util_test.cc +++ b/eval/public/set_util_test.cc @@ -1,13 +1,13 @@ #include "eval/public/set_util.h" -#include +#include +#include +#include +#include +#include #include "google/protobuf/empty.pb.h" #include "google/protobuf/struct.pb.h" -#include "google/protobuf/arena.h" -#include "google/protobuf/message.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" #include "absl/status/status.h" #include "absl/time/clock.h" #include "absl/time/time.h" @@ -16,6 +16,8 @@ #include "eval/public/containers/container_backed_map_impl.h" #include "eval/public/structs/cel_proto_wrapper.h" #include "eval/public/unknown_set.h" +#include "internal/testing.h" +#include "google/protobuf/arena.h" namespace google { namespace api { @@ -45,8 +47,8 @@ std::string* ExampleStr2() { // ordering in |CelValueLessThan|. Length 13 std::vector TypeExamples(Arena* arena) { Empty* empty = Arena::Create(arena); - Struct* proto_map = Arena::CreateMessage(arena); - ListValue* proto_list = Arena::CreateMessage(arena); + Struct* proto_map = Arena::Create(arena); + ListValue* proto_list = Arena::Create(arena); UnknownSet* unknown_set = Arena::Create(arena); return {CelValue::CreateBool(false), CelValue::CreateInt64(0), @@ -256,8 +258,8 @@ TEST(CelValueLessThan, PtrCmpUnknownSet) { TEST(CelValueLessThan, PtrCmpError) { Arena arena; - CelValue lhs = CreateErrorValue(&arena, "test", absl::StatusCode::kInternal); - CelValue rhs = CreateErrorValue(&arena, "test", absl::StatusCode::kInternal); + CelValue lhs = CreateErrorValue(&arena, "test1", absl::StatusCode::kInternal); + CelValue rhs = CreateErrorValue(&arena, "test2", absl::StatusCode::kInternal); if (lhs.ErrorOrDie() > rhs.ErrorOrDie()) { std::swap(lhs, rhs); @@ -331,19 +333,22 @@ TEST(CelValueLessThan, CelMapSameSize) { {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}, {CelValue::CreateInt64(3), CelValue::CreateInt64(6)}}; - auto cel_map_backing_1 = CreateContainerBackedMap(absl::MakeSpan(values)); + auto cel_map_backing_1 = + CreateContainerBackedMap(absl::MakeSpan(values)).value(); std::vector> values2{ {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}, {CelValue::CreateInt64(4), CelValue::CreateInt64(6)}}; - auto cel_map_backing_2 = CreateContainerBackedMap(absl::MakeSpan(values2)); + auto cel_map_backing_2 = + CreateContainerBackedMap(absl::MakeSpan(values2)).value(); std::vector> values3{ {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}, {CelValue::CreateInt64(3), CelValue::CreateInt64(8)}}; - auto cel_map_backing_3 = CreateContainerBackedMap(absl::MakeSpan(values3)); + auto cel_map_backing_3 = + CreateContainerBackedMap(absl::MakeSpan(values3)).value(); CelValue map1 = CelValue::CreateMap(cel_map_backing_1.get()); CelValue map2 = CelValue::CreateMap(cel_map_backing_2.get()); @@ -359,14 +364,14 @@ TEST(CelValueLessThan, CelMapDifferentSizes) { {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}, {CelValue::CreateInt64(2), CelValue::CreateInt64(4)}}; - auto cel_map_1 = CreateContainerBackedMap(absl::MakeSpan(values)); + auto cel_map_1 = CreateContainerBackedMap(absl::MakeSpan(values)).value(); std::vector> values2{ {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}, {CelValue::CreateInt64(2), CelValue::CreateInt64(4)}, {CelValue::CreateInt64(3), CelValue::CreateInt64(6)}}; - auto cel_map_2 = CreateContainerBackedMap(absl::MakeSpan(values2)); + auto cel_map_2 = CreateContainerBackedMap(absl::MakeSpan(values2)).value(); EXPECT_TRUE(CelValueLessThan(CelValue::CreateMap(cel_map_1.get()), CelValue::CreateMap(cel_map_2.get()))); @@ -378,14 +383,14 @@ TEST(CelValueLessThan, CelMapEqual) { {CelValue::CreateInt64(2), CelValue::CreateInt64(4)}, {CelValue::CreateInt64(3), CelValue::CreateInt64(6)}}; - auto cel_map_1 = CreateContainerBackedMap(absl::MakeSpan(values)); + auto cel_map_1 = CreateContainerBackedMap(absl::MakeSpan(values)).value(); std::vector> values2{ {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}, {CelValue::CreateInt64(2), CelValue::CreateInt64(4)}, {CelValue::CreateInt64(3), CelValue::CreateInt64(6)}}; - auto cel_map_2 = CreateContainerBackedMap(absl::MakeSpan(values2)); + auto cel_map_2 = CreateContainerBackedMap(absl::MakeSpan(values2)).value(); EXPECT_FALSE(CelValueLessThan(CelValue::CreateMap(cel_map_1.get()), CelValue::CreateMap(cel_map_2.get()))); @@ -418,7 +423,7 @@ TEST(CelValueLessThan, CelMapSupportProtoMapCompatible) { {CelValue::CreateStringView(kFields[1]), CelValue::CreateDouble(1.0)}, {CelValue::CreateStringView(kFields[0]), CelValue::CreateBool(true)}}; - auto backing_map = CreateContainerBackedMap(absl::MakeSpan(values)); + auto backing_map = CreateContainerBackedMap(absl::MakeSpan(values)).value(); CelValue cel_map = CelValue::CreateMap(backing_map.get()); @@ -451,7 +456,7 @@ TEST(CelValueLessThan, NestedMap) { std::vector> values{ {CelValue::CreateStringView("field"), cel_list}}; - auto backing_map = CreateContainerBackedMap(absl::MakeSpan(values)); + auto backing_map = CreateContainerBackedMap(absl::MakeSpan(values)).value(); CelValue cel_map = CelValue::CreateMap(backing_map.get()); CelValue proto_map = CelProtoWrapper::CreateMessage(&value_struct, &arena); diff --git a/eval/public/source_position.cc b/eval/public/source_position.cc index 350d0a30e..ac902fa0e 100644 --- a/eval/public/source_position.cc +++ b/eval/public/source_position.cc @@ -14,12 +14,14 @@ #include "eval/public/source_position.h" +#include + namespace google { namespace api { namespace expr { namespace runtime { -using google::api::expr::v1alpha1::SourceInfo; +using cel::expr::SourceInfo; namespace { diff --git a/eval/public/source_position.h b/eval/public/source_position.h index 739f501b4..c4b7f0f88 100644 --- a/eval/public/source_position.h +++ b/eval/public/source_position.h @@ -17,7 +17,7 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_SOURCE_POSITION_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_SOURCE_POSITION_H_ -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" namespace google { namespace api { @@ -31,7 +31,7 @@ class SourcePosition { // Constructor for a SourcePosition value. The source_info may be nullptr, // in which case line, column, and character_offset will return 0. SourcePosition(const int64_t expr_id, - const google::api::expr::v1alpha1::SourceInfo* source_info) + const cel::expr::SourceInfo* source_info) : expr_id_(expr_id), source_info_(source_info) {} // Non-copyable @@ -54,7 +54,7 @@ class SourcePosition { // The expression identifier. const int64_t expr_id_; // The source information reference generated during expression parsing. - const google::api::expr::v1alpha1::SourceInfo* source_info_; + const cel::expr::SourceInfo* source_info_; }; } // namespace runtime diff --git a/eval/public/source_position_test.cc b/eval/public/source_position_test.cc index 4d16c9259..16140d96f 100644 --- a/eval/public/source_position_test.cc +++ b/eval/public/source_position_test.cc @@ -12,11 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "google/api/expr/v1alpha1/syntax.pb.h" #include "eval/public/source_position.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" +#include "cel/expr/syntax.pb.h" +#include "internal/testing.h" + namespace google { namespace api { namespace expr { @@ -24,8 +24,8 @@ namespace runtime { namespace { -using testing::Eq; -using google::api::expr::v1alpha1::SourceInfo; +using ::testing::Eq; +using cel::expr::SourceInfo; class SourcePositionTest : public testing::Test { protected: diff --git a/eval/public/string_extension_func_registrar.cc b/eval/public/string_extension_func_registrar.cc new file mode 100644 index 000000000..9bccfe6d1 --- /dev/null +++ b/eval/public/string_extension_func_registrar.cc @@ -0,0 +1,29 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/string_extension_func_registrar.h" + +#include "absl/status/status.h" +#include "eval/public/cel_function_registry.h" +#include "eval/public/cel_options.h" +#include "extensions/strings.h" + +namespace google::api::expr::runtime { + +absl::Status RegisterStringExtensionFunctions( + CelFunctionRegistry* registry, const InterpreterOptions& options) { + return cel::extensions::RegisterStringsFunctions(registry, options); +} + +} // namespace google::api::expr::runtime diff --git a/eval/public/string_extension_func_registrar.h b/eval/public/string_extension_func_registrar.h new file mode 100644 index 000000000..98c296745 --- /dev/null +++ b/eval/public/string_extension_func_registrar.h @@ -0,0 +1,31 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRING_EXTENSION_FUNC_REGISTRAR_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRING_EXTENSION_FUNC_REGISTRAR_H_ + +#include "absl/status/status.h" +#include "eval/public/cel_function_registry.h" +#include "eval/public/cel_options.h" + +namespace google::api::expr::runtime { + +// Register string related widely used extension functions. +absl::Status RegisterStringExtensionFunctions( + CelFunctionRegistry* registry, + const InterpreterOptions& options = InterpreterOptions()); + +} // namespace google::api::expr::runtime + +#endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRING_EXTENSION_FUNC_REGISTRAR_H_ diff --git a/eval/public/string_extension_func_registrar_test.cc b/eval/public/string_extension_func_registrar_test.cc new file mode 100644 index 000000000..7fd6e746f --- /dev/null +++ b/eval/public/string_extension_func_registrar_test.cc @@ -0,0 +1,373 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/string_extension_func_registrar.h" + +#include +#include +#include + +#include "cel/expr/checked.pb.h" +#include "absl/types/span.h" +#include "eval/public/builtin_func_registrar.h" +#include "eval/public/cel_function_registry.h" +#include "eval/public/cel_value.h" +#include "eval/public/containers/container_backed_list_impl.h" +#include "internal/testing.h" +#include "google/protobuf/arena.h" + +namespace google::api::expr::runtime { +namespace { +using google::protobuf::Arena; + +class StringExtensionTest : public ::testing::Test { + protected: + StringExtensionTest() = default; + void SetUp() override { + ASSERT_OK(RegisterBuiltinFunctions(®istry_)); + ASSERT_OK(RegisterStringExtensionFunctions(®istry_)); + } + + void PerformSplitStringTest(Arena* arena, std::string* value, + std::string* delimiter, CelValue* result) { + auto function = registry_.FindOverloads( + "split", true, {CelValue::Type::kString, CelValue::Type::kString}); + ASSERT_EQ(function.size(), 1); + auto func = function[0]; + std::vector args = {CelValue::CreateString(value), + CelValue::CreateString(delimiter)}; + absl::Span arg_span(&args[0], args.size()); + auto status = func->Evaluate(arg_span, result, arena); + ASSERT_OK(status); + } + + void PerformSplitStringWithLimitTest(Arena* arena, std::string* value, + std::string* delimiter, int64_t limit, + CelValue* result) { + auto function = registry_.FindOverloads( + "split", true, + {CelValue::Type::kString, CelValue::Type::kString, + CelValue::Type::kInt64}); + ASSERT_EQ(function.size(), 1); + auto func = function[0]; + std::vector args = {CelValue::CreateString(value), + CelValue::CreateString(delimiter), + CelValue::CreateInt64(limit)}; + absl::Span arg_span(&args[0], args.size()); + auto status = func->Evaluate(arg_span, result, arena); + ASSERT_OK(status); + } + + void PerformJoinStringTest(Arena* arena, std::vector& values, + CelValue* result) { + auto function = + registry_.FindOverloads("join", true, {CelValue::Type::kList}); + ASSERT_EQ(function.size(), 1); + auto func = function[0]; + + std::vector cel_list; + cel_list.reserve(values.size()); + for (const std::string& value : values) { + cel_list.push_back( + CelValue::CreateString(Arena::Create(arena, value))); + } + + std::vector args = {CelValue::CreateList( + Arena::Create(arena, cel_list))}; + absl::Span arg_span(&args[0], args.size()); + auto status = func->Evaluate(arg_span, result, arena); + ASSERT_OK(status); + } + + void PerformJoinStringWithSeparatorTest(Arena* arena, + std::vector& values, + std::string* separator, + CelValue* result) { + auto function = registry_.FindOverloads( + "join", true, {CelValue::Type::kList, CelValue::Type::kString}); + ASSERT_EQ(function.size(), 1); + auto func = function[0]; + + std::vector cel_list; + cel_list.reserve(values.size()); + for (const std::string& value : values) { + cel_list.push_back( + CelValue::CreateString(Arena::Create(arena, value))); + } + std::vector args = { + CelValue::CreateList( + Arena::Create(arena, cel_list)), + CelValue::CreateString(separator)}; + absl::Span arg_span(&args[0], args.size()); + auto status = func->Evaluate(arg_span, result, arena); + ASSERT_OK(status); + } + + void PerformLowerAsciiTest(Arena* arena, std::string* value, + CelValue* result) { + auto function = + registry_.FindOverloads("lowerAscii", true, {CelValue::Type::kString}); + ASSERT_EQ(function.size(), 1); + auto func = function[0]; + std::vector args = {CelValue::CreateString(value)}; + absl::Span arg_span(&args[0], args.size()); + auto status = func->Evaluate(arg_span, result, arena); + ASSERT_OK(status); + } + + // Function registry + CelFunctionRegistry registry_; + Arena arena_; +}; + +TEST_F(StringExtensionTest, TestStringSplit) { + Arena arena; + CelValue result; + std::string value = "This!!Is!!Test"; + std::string delimiter = "!!"; + std::vector expected = {"This", "Is", "Test"}; + + ASSERT_NO_FATAL_FAILURE( + PerformSplitStringTest(&arena, &value, &delimiter, &result)); + ASSERT_EQ(result.type(), CelValue::Type::kList); + EXPECT_EQ(result.ListOrDie()->size(), 3); + for (int i = 0; i < expected.size(); ++i) { + EXPECT_EQ(result.ListOrDie()->Get(&arena, i).StringOrDie().value(), + expected[i]); + } +} + +TEST_F(StringExtensionTest, TestStringSplitEmptyDelimiter) { + Arena arena; + CelValue result; + std::string value = "TEST"; + std::string delimiter = ""; + std::vector expected = {"T", "E", "S", "T"}; + + ASSERT_NO_FATAL_FAILURE( + PerformSplitStringTest(&arena, &value, &delimiter, &result)); + ASSERT_EQ(result.type(), CelValue::Type::kList); + EXPECT_EQ(result.ListOrDie()->size(), 4); + for (int i = 0; i < expected.size(); ++i) { + EXPECT_EQ(result.ListOrDie()->Get(&arena, i).StringOrDie().value(), + expected[i]); + } +} + +TEST_F(StringExtensionTest, TestStringSplitWithLimitTwo) { + Arena arena; + CelValue result; + int64_t limit = 2; + std::string value = "This!!Is!!Test"; + std::string delimiter = "!!"; + std::vector expected = {"This", "Is!!Test"}; + + ASSERT_NO_FATAL_FAILURE(PerformSplitStringWithLimitTest( + &arena, &value, &delimiter, limit, &result)); + ASSERT_EQ(result.type(), CelValue::Type::kList); + EXPECT_EQ(result.ListOrDie()->size(), 2); + for (int i = 0; i < expected.size(); ++i) { + EXPECT_EQ(result.ListOrDie()->Get(&arena, i).StringOrDie().value(), + expected[i]); + } +} + +TEST_F(StringExtensionTest, TestStringSplitWithLimitOne) { + Arena arena; + CelValue result; + int64_t limit = 1; + std::string value = "This!!Is!!Test"; + std::string delimiter = "!!"; + ASSERT_NO_FATAL_FAILURE(PerformSplitStringWithLimitTest( + &arena, &value, &delimiter, limit, &result)); + ASSERT_EQ(result.type(), CelValue::Type::kList); + EXPECT_EQ(result.ListOrDie()->size(), 1); + EXPECT_EQ(result.ListOrDie()->Get(&arena, 0).StringOrDie().value(), value); +} + +TEST_F(StringExtensionTest, TestStringSplitWithLimitZero) { + Arena arena; + CelValue result; + int64_t limit = 0; + std::string value = "This!!Is!!Test"; + std::string delimiter = "!!"; + ASSERT_NO_FATAL_FAILURE(PerformSplitStringWithLimitTest( + &arena, &value, &delimiter, limit, &result)); + ASSERT_EQ(result.type(), CelValue::Type::kList); + EXPECT_EQ(result.ListOrDie()->size(), 0); +} + +TEST_F(StringExtensionTest, TestStringSplitWithLimitNegative) { + Arena arena; + CelValue result; + int64_t limit = -1; + std::string value = "This!!Is!!Test"; + std::string delimiter = "!!"; + std::vector expected = {"This", "Is", "Test"}; + ASSERT_NO_FATAL_FAILURE(PerformSplitStringWithLimitTest( + &arena, &value, &delimiter, limit, &result)); + ASSERT_EQ(result.type(), CelValue::Type::kList); + EXPECT_EQ(result.ListOrDie()->size(), 3); + for (int i = 0; i < expected.size(); ++i) { + EXPECT_EQ(result.ListOrDie()->Get(&arena, i).StringOrDie().value(), + expected[i]); + } +} + +TEST_F(StringExtensionTest, TestStringSplitWithLimitAsMaxPossibleSplits) { + Arena arena; + CelValue result; + int64_t limit = 3; + std::string value = "This!!Is!!Test"; + std::string delimiter = "!!"; + std::vector expected = {"This", "Is", "Test"}; + + ASSERT_NO_FATAL_FAILURE(PerformSplitStringWithLimitTest( + &arena, &value, &delimiter, limit, &result)); + ASSERT_EQ(result.type(), CelValue::Type::kList); + EXPECT_EQ(result.ListOrDie()->size(), 3); + for (int i = 0; i < expected.size(); ++i) { + EXPECT_EQ(result.ListOrDie()->Get(&arena, i).StringOrDie().value(), + expected[i]); + } +} + +TEST_F(StringExtensionTest, + TestStringSplitWithLimitGreaterThanMaxPossibleSplits) { + Arena arena; + CelValue result; + int64_t limit = 4; + std::string value = "This!!Is!!Test"; + std::string delimiter = "!!"; + std::vector expected = {"This", "Is", "Test"}; + + ASSERT_NO_FATAL_FAILURE(PerformSplitStringWithLimitTest( + &arena, &value, &delimiter, limit, &result)); + ASSERT_EQ(result.type(), CelValue::Type::kList); + EXPECT_EQ(result.ListOrDie()->size(), 3); + for (int i = 0; i < expected.size(); ++i) { + EXPECT_EQ(result.ListOrDie()->Get(&arena, i).StringOrDie().value(), + expected[i]); + } +} + +TEST_F(StringExtensionTest, TestStringJoin) { + Arena arena; + CelValue result; + std::vector value = {"This", "Is", "Test"}; + std::string expected = "ThisIsTest"; + + ASSERT_NO_FATAL_FAILURE(PerformJoinStringTest(&arena, value, &result)); + ASSERT_EQ(result.type(), CelValue::Type::kString); + EXPECT_EQ(result.StringOrDie().value(), expected); +} + +TEST_F(StringExtensionTest, TestStringJoinEmptyInput) { + Arena arena; + CelValue result; + std::vector value = {}; + std::string expected = ""; + + ASSERT_NO_FATAL_FAILURE(PerformJoinStringTest(&arena, value, &result)); + ASSERT_EQ(result.type(), CelValue::Type::kString); + EXPECT_EQ(result.StringOrDie().value(), expected); +} + +TEST_F(StringExtensionTest, TestStringJoinWithSeparator) { + Arena arena; + CelValue result; + std::vector value = {"This", "Is", "Test"}; + std::string separator = "-"; + std::string expected = "This-Is-Test"; + + ASSERT_NO_FATAL_FAILURE( + PerformJoinStringWithSeparatorTest(&arena, value, &separator, &result)); + ASSERT_EQ(result.type(), CelValue::Type::kString); + EXPECT_EQ(result.StringOrDie().value(), expected); +} + +TEST_F(StringExtensionTest, TestStringJoinWithMultiCharSeparator) { + Arena arena; + CelValue result; + std::vector value = {"This", "Is", "Test"}; + std::string separator = "--"; + std::string expected = "This--Is--Test"; + + ASSERT_NO_FATAL_FAILURE( + PerformJoinStringWithSeparatorTest(&arena, value, &separator, &result)); + ASSERT_EQ(result.type(), CelValue::Type::kString); + EXPECT_EQ(result.StringOrDie().value(), expected); +} + +TEST_F(StringExtensionTest, TestStringJoinWithEmptySeparator) { + Arena arena; + CelValue result; + std::vector value = {"This", "Is", "Test"}; + std::string separator = ""; + std::string expected = "ThisIsTest"; + + ASSERT_NO_FATAL_FAILURE( + PerformJoinStringWithSeparatorTest(&arena, value, &separator, &result)); + ASSERT_EQ(result.type(), CelValue::Type::kString); + EXPECT_EQ(result.StringOrDie().value(), expected); +} + +TEST_F(StringExtensionTest, TestStringJoinWithSeparatorEmptyInput) { + Arena arena; + CelValue result; + std::vector value = {}; + std::string separator = "-"; + std::string expected = ""; + + ASSERT_NO_FATAL_FAILURE( + PerformJoinStringWithSeparatorTest(&arena, value, &separator, &result)); + ASSERT_EQ(result.type(), CelValue::Type::kString); + EXPECT_EQ(result.StringOrDie().value(), expected); +} + +TEST_F(StringExtensionTest, TestLowerAscii) { + Arena arena; + CelValue result; + std::string value = "ThisIs@Test!-5"; + std::string expected = "thisis@test!-5"; + + ASSERT_NO_FATAL_FAILURE(PerformLowerAsciiTest(&arena, &value, &result)); + ASSERT_EQ(result.type(), CelValue::Type::kString); + EXPECT_EQ(result.StringOrDie().value(), expected); +} + +TEST_F(StringExtensionTest, TestLowerAsciiWithEmptyInput) { + Arena arena; + CelValue result; + std::string value = ""; + std::string expected = ""; + + ASSERT_NO_FATAL_FAILURE(PerformLowerAsciiTest(&arena, &value, &result)); + ASSERT_EQ(result.type(), CelValue::Type::kString); + EXPECT_EQ(result.StringOrDie().value(), expected); +} + +TEST_F(StringExtensionTest, TestLowerAsciiWithNonAsciiCharacter) { + Arena arena; + CelValue result; + std::string value = "TacoCÆt"; + std::string expected = "tacocÆt"; + + ASSERT_NO_FATAL_FAILURE(PerformLowerAsciiTest(&arena, &value, &result)); + ASSERT_EQ(result.type(), CelValue::Type::kString); + EXPECT_EQ(result.StringOrDie().value(), expected); +} + +} // namespace +} // namespace google::api::expr::runtime diff --git a/eval/public/structs/BUILD b/eval/public/structs/BUILD index 3660d1e31..d722559e3 100644 --- a/eval/public/structs/BUILD +++ b/eval/public/structs/BUILD @@ -1,6 +1,23 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + package(default_visibility = ["//visibility:public"]) -licenses(["notice"]) # Apache 2.0 +licenses(["notice"]) cc_library( name = "cel_proto_wrapper", @@ -11,13 +28,173 @@ cc_library( "cel_proto_wrapper.h", ], deps = [ + ":cel_proto_wrap_util", + ":proto_message_type_adapter", "//eval/public:cel_value", - "//internal:proto_util", - "@com_google_absl//absl/container:node_hash_map", + "//eval/public:message_wrapper", + "//internal:proto_time_encoding", + "@com_google_absl//absl/types:optional", + "@com_google_protobuf//:duration_cc_proto", + "@com_google_protobuf//:protobuf", + "@com_google_protobuf//:timestamp_cc_proto", + ], +) + +cc_library( + name = "protobuf_value_factory", + hdrs = [ + "protobuf_value_factory.h", + ], + deps = [ + "//eval/public:cel_value", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "cel_proto_wrap_util", + srcs = [ + "cel_proto_wrap_util.cc", + ], + hdrs = [ + "cel_proto_wrap_util.h", + ], + deps = [ + ":protobuf_value_factory", + "//eval/public:cel_value", + "//internal:overflow", + "//internal:proto_time_encoding", + "//internal:status_macros", + "//internal:time", + "//internal:well_known_types", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/functional:overload", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:variant", + "@com_google_protobuf//:any_cc_proto", + "@com_google_protobuf//:duration_cc_proto", "@com_google_protobuf//:protobuf", + "@com_google_protobuf//:struct_cc_proto", + "@com_google_protobuf//:timestamp_cc_proto", + "@com_google_protobuf//:wrappers_cc_proto", + ], +) + +cc_test( + name = "cel_proto_wrap_util_test", + size = "small", + srcs = [ + "cel_proto_wrap_util_test.cc", + ], + deps = [ + ":cel_proto_wrap_util", + ":protobuf_value_factory", + ":trivial_legacy_type_info", + "//eval/public:cel_value", + "//eval/public:message_wrapper", + "//eval/public/containers:container_backed_list_impl", + "//eval/public/containers:container_backed_map_impl", + "//eval/testutil:test_message_cc_proto", + "//internal:proto_time_encoding", + "//internal:status_macros", + "//internal:testing", + "//testutil:util", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + "@com_google_protobuf//:any_cc_proto", + "@com_google_protobuf//:duration_cc_proto", + "@com_google_protobuf//:empty_cc_proto", + "@com_google_protobuf//:protobuf", + "@com_google_protobuf//:struct_cc_proto", + "@com_google_protobuf//:wrappers_cc_proto", + ], +) + +cc_library( + name = "field_access_impl", + srcs = [ + "field_access_impl.cc", + ], + hdrs = [ + "field_access_impl.h", + ], + deps = [ + ":cel_proto_wrap_util", + ":protobuf_value_factory", + "//eval/public:cel_options", + "//eval/public:cel_value", + "//internal:casts", + "//internal:overflow", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_protobuf//:any_cc_proto", + "@com_google_protobuf//:protobuf", + "@com_google_protobuf//:struct_cc_proto", + "@com_google_protobuf//:wrappers_cc_proto", + ], +) + +cc_test( + name = "field_access_impl_test", + srcs = ["field_access_impl_test.cc"], + deps = [ + ":cel_proto_wrapper", + ":field_access_impl", + "//eval/public:cel_value", + "//eval/public/testing:matchers", + "//eval/testutil:test_message_cc_proto", + "//internal:testing", + "//internal:time", + "//testutil:util", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "cel_proto_descriptor_pool_builder", + srcs = ["cel_proto_descriptor_pool_builder.cc"], + hdrs = ["cel_proto_descriptor_pool_builder.h"], + deps = [ + "//internal:proto_util", + "//internal:status_macros", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_protobuf//:any_cc_proto", + "@com_google_protobuf//:duration_cc_proto", + "@com_google_protobuf//:empty_cc_proto", + "@com_google_protobuf//:field_mask_cc_proto", + "@com_google_protobuf//:protobuf", + "@com_google_protobuf//:struct_cc_proto", + "@com_google_protobuf//:timestamp_cc_proto", + "@com_google_protobuf//:wrappers_cc_proto", + ], +) + +cc_test( + name = "cel_proto_descriptor_pool_builder_test", + srcs = ["cel_proto_descriptor_pool_builder_test.cc"], + deps = [ + ":cel_proto_descriptor_pool_builder", + "//eval/testutil:test_message_cc_proto", + "//internal:testing", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_protobuf//:any_cc_proto", ], ) @@ -29,11 +206,259 @@ cc_test( ], deps = [ ":cel_proto_wrapper", + "//eval/public:cel_value", + "//eval/public/containers:container_backed_list_impl", + "//eval/public/containers:container_backed_map_impl", "//eval/testutil:test_message_cc_proto", - "//internal:proto_util", + "//internal:proto_time_encoding", + "//internal:status_macros", + "//internal:testing", "//testutil:util", "@com_google_absl//absl/status", - "@com_google_googletest//:gtest_main", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + "@com_google_protobuf//:any_cc_proto", + "@com_google_protobuf//:duration_cc_proto", + "@com_google_protobuf//:empty_cc_proto", + "@com_google_protobuf//:protobuf", + "@com_google_protobuf//:struct_cc_proto", + "@com_google_protobuf//:wrappers_cc_proto", + ], +) + +cc_library( + name = "legacy_type_provider", + srcs = ["legacy_type_provider.cc"], + hdrs = ["legacy_type_provider.h"], + deps = [ + ":legacy_type_adapter", + ":legacy_type_info_apis", + "//common:legacy_value", + "//common:memory", + "//common:type", + "//common:value", + "//eval/public:message_wrapper", + "//extensions/protobuf:memory_manager", + "//internal:status_macros", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:optional", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "legacy_type_adapter", + hdrs = ["legacy_type_adapter.h"], + deps = [ + "//base:attributes", + "//common:memory", + "//eval/public:cel_options", + "//eval/public:cel_value", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + +cc_test( + name = "legacy_type_adapter_test", + srcs = ["legacy_type_adapter_test.cc"], + deps = [ + ":legacy_type_adapter", + ":trivial_legacy_type_info", + "//eval/public:cel_value", + "//eval/public/testing:matchers", + "//eval/testutil:test_message_cc_proto", + "//extensions/protobuf:memory_manager", + "//internal:status_macros", + "//internal:testing", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "proto_message_type_adapter", + srcs = ["proto_message_type_adapter.cc"], + hdrs = ["proto_message_type_adapter.h"], + deps = [ + ":cel_proto_wrap_util", + ":field_access_impl", + ":legacy_type_adapter", + ":legacy_type_info_apis", + "//base:attributes", + "//common:memory", + "//eval/public:cel_options", + "//eval/public:cel_value", + "//eval/public:message_wrapper", + "//eval/public/containers:internal_field_backed_list_impl", + "//eval/public/containers:internal_field_backed_map_impl", + "//extensions/protobuf:memory_manager", + "//extensions/protobuf/internal:qualify", + "//internal:casts", + "//internal:status_macros", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + "@com_google_protobuf//:differencer", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "proto_message_type_adapter_test", + srcs = ["proto_message_type_adapter_test.cc"], + deps = [ + ":legacy_type_adapter", + ":legacy_type_info_apis", + ":proto_message_type_adapter", + "//base:attributes", + "//common:value", + "//eval/public:cel_value", + "//eval/public:message_wrapper", + "//eval/public/containers:container_backed_list_impl", + "//eval/public/containers:container_backed_map_impl", + "//eval/public/testing:matchers", + "//eval/testutil:test_message_cc_proto", + "//extensions/protobuf:memory_manager", + "//internal:proto_matchers", + "//internal:testing", + "//runtime:runtime_options", + "@com_google_absl//absl/status", + "@com_google_protobuf//:protobuf", + "@com_google_protobuf//:wrappers_cc_proto", + ], +) + +cc_library( + name = "protobuf_descriptor_type_provider", + srcs = ["protobuf_descriptor_type_provider.cc"], + hdrs = ["protobuf_descriptor_type_provider.h"], + deps = [ + ":legacy_type_adapter", + ":legacy_type_info_apis", + ":legacy_type_provider", + ":proto_message_type_adapter", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:optional", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "protobuf_descriptor_type_provider_test", + srcs = ["protobuf_descriptor_type_provider_test.cc"], + deps = [ + ":legacy_type_info_apis", + ":protobuf_descriptor_type_provider", + "//eval/public:cel_value", + "//eval/public/testing:matchers", + "//extensions/protobuf:memory_manager", + "//internal:testing", + "@com_google_protobuf//:protobuf", + "@com_google_protobuf//:wrappers_cc_proto", + ], +) + +cc_library( + name = "legacy_type_info_apis", + hdrs = ["legacy_type_info_apis.h"], + deps = [ + "//eval/public:message_wrapper", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "trivial_legacy_type_info", + testonly = True, + hdrs = ["trivial_legacy_type_info.h"], + deps = [ + ":legacy_type_info_apis", + "//eval/public:message_wrapper", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/strings:string_view", + ], +) + +cc_test( + name = "trivial_legacy_type_info_test", + srcs = ["trivial_legacy_type_info_test.cc"], + deps = [ + ":trivial_legacy_type_info", + "//eval/public:message_wrapper", + "//internal:testing", + ], +) + +cc_test( + name = "legacy_type_provider_test", + srcs = ["legacy_type_provider_test.cc"], + deps = [ + ":legacy_type_info_apis", + ":legacy_type_provider", + "//internal:testing", + "@com_google_absl//absl/strings:string_view", + ], +) + +cc_test( + name = "dynamic_descriptor_pool_end_to_end_test", + srcs = ["dynamic_descriptor_pool_end_to_end_test.cc"], + deps = [ + ":cel_proto_descriptor_pool_builder", + ":cel_proto_wrapper", + "//eval/public:activation", + "//eval/public:builtin_func_registrar", + "//eval/public:cel_expr_builder_factory", + "//eval/public:cel_expression", + "//eval/public:cel_options", + "//eval/public/testing:matchers", + "//internal:testing", + "//parser", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto", + "@com_google_protobuf//:differencer", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "field_access_impl_benchmark_test", + srcs = ["field_access_impl_benchmark_test.cc"], + tags = [ + "benchmark", + "manual", + ], + deps = [ + ":cel_proto_wrapper", + ":field_access_impl", + "//eval/public:cel_options", + "//eval/public:cel_value", + "//extensions/protobuf/internal:map_reflection", + "//internal:benchmark", + "//internal:testing", + "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto", "@com_google_protobuf//:protobuf", ], ) diff --git a/eval/public/structs/cel_proto_descriptor_pool_builder.cc b/eval/public/structs/cel_proto_descriptor_pool_builder.cc new file mode 100644 index 000000000..158fcb8de --- /dev/null +++ b/eval/public/structs/cel_proto_descriptor_pool_builder.cc @@ -0,0 +1,134 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "eval/public/structs/cel_proto_descriptor_pool_builder.h" + +#include + +#include "google/protobuf/any.pb.h" +#include "google/protobuf/duration.pb.h" +#include "google/protobuf/empty.pb.h" +#include "google/protobuf/field_mask.pb.h" +#include "google/protobuf/struct.pb.h" +#include "google/protobuf/timestamp.pb.h" +#include "google/protobuf/wrappers.pb.h" +#include "absl/container/flat_hash_map.h" +#include "internal/proto_util.h" +#include "internal/status_macros.h" + +namespace google::api::expr::runtime { +namespace { +template +absl::Status AddOrValidateMessageType(google::protobuf::DescriptorPool& descriptor_pool) { + const google::protobuf::Descriptor* descriptor = MessageType::descriptor(); + if (descriptor_pool.FindMessageTypeByName(descriptor->full_name()) != + nullptr) { + return internal::ValidateStandardMessageType(descriptor_pool); + } + google::protobuf::FileDescriptorProto file_descriptor_proto; + descriptor->file()->CopyTo(&file_descriptor_proto); + if (descriptor_pool.BuildFile(file_descriptor_proto) == nullptr) { + return absl::InternalError( + absl::StrFormat("Failed to add descriptor '%s' to descriptor pool", + descriptor->full_name())); + } + return absl::OkStatus(); +} + +template +void AddStandardMessageTypeToMap( + absl::flat_hash_map& fdmap) { + const google::protobuf::Descriptor* descriptor = MessageType::descriptor(); + + if (fdmap.contains(descriptor->file()->name())) return; + + descriptor->file()->CopyTo(&fdmap[descriptor->file()->name()]); +} + +} // namespace + +absl::Status AddStandardMessageTypesToDescriptorPool( + google::protobuf::DescriptorPool& descriptor_pool) { + // The types below do not depend on each other, hence we can add them in any + // order. Should that change with new messages add them in the proper order, + // i.e., dependencies first. + CEL_RETURN_IF_ERROR( + AddOrValidateMessageType(descriptor_pool)); + CEL_RETURN_IF_ERROR( + AddOrValidateMessageType(descriptor_pool)); + CEL_RETURN_IF_ERROR( + AddOrValidateMessageType(descriptor_pool)); + CEL_RETURN_IF_ERROR( + AddOrValidateMessageType(descriptor_pool)); + CEL_RETURN_IF_ERROR( + AddOrValidateMessageType(descriptor_pool)); + CEL_RETURN_IF_ERROR( + AddOrValidateMessageType(descriptor_pool)); + CEL_RETURN_IF_ERROR( + AddOrValidateMessageType(descriptor_pool)); + CEL_RETURN_IF_ERROR( + AddOrValidateMessageType(descriptor_pool)); + CEL_RETURN_IF_ERROR( + AddOrValidateMessageType(descriptor_pool)); + CEL_RETURN_IF_ERROR( + AddOrValidateMessageType(descriptor_pool)); + CEL_RETURN_IF_ERROR( + AddOrValidateMessageType(descriptor_pool)); + CEL_RETURN_IF_ERROR( + AddOrValidateMessageType(descriptor_pool)); + CEL_RETURN_IF_ERROR( + AddOrValidateMessageType(descriptor_pool)); + CEL_RETURN_IF_ERROR( + AddOrValidateMessageType(descriptor_pool)); + CEL_RETURN_IF_ERROR( + AddOrValidateMessageType(descriptor_pool)); + CEL_RETURN_IF_ERROR( + AddOrValidateMessageType(descriptor_pool)); + CEL_RETURN_IF_ERROR( + AddOrValidateMessageType(descriptor_pool)); + return absl::OkStatus(); +} + +google::protobuf::FileDescriptorSet GetStandardMessageTypesFileDescriptorSet() { + // The types below do not depend on each other, hence we can add them to + // an unordered map. Should that change with new messages being added here + // adapt this to a sorted data structure and add in the proper order. + absl::flat_hash_map files; + AddStandardMessageTypeToMap(files); + AddStandardMessageTypeToMap(files); + AddStandardMessageTypeToMap(files); + AddStandardMessageTypeToMap(files); + AddStandardMessageTypeToMap(files); + AddStandardMessageTypeToMap(files); + AddStandardMessageTypeToMap(files); + AddStandardMessageTypeToMap(files); + AddStandardMessageTypeToMap(files); + AddStandardMessageTypeToMap(files); + AddStandardMessageTypeToMap(files); + AddStandardMessageTypeToMap(files); + AddStandardMessageTypeToMap(files); + AddStandardMessageTypeToMap(files); + AddStandardMessageTypeToMap(files); + AddStandardMessageTypeToMap(files); + AddStandardMessageTypeToMap(files); + google::protobuf::FileDescriptorSet fdset; + for (const auto& [name, fdproto] : files) { + *fdset.add_file() = fdproto; + } + return fdset; +} + +} // namespace google::api::expr::runtime diff --git a/eval/public/structs/cel_proto_descriptor_pool_builder.h b/eval/public/structs/cel_proto_descriptor_pool_builder.h new file mode 100644 index 000000000..bb1357a6f --- /dev/null +++ b/eval/public/structs/cel_proto_descriptor_pool_builder.h @@ -0,0 +1,40 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_CEL_PROTO_DESCRIPTOR_POOL_BUILDER_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_CEL_PROTO_DESCRIPTOR_POOL_BUILDER_H_ + +#include "google/protobuf/descriptor.pb.h" +#include "absl/status/status.h" +#include "google/protobuf/descriptor.h" + +namespace google::api::expr::runtime { + +// Add standard message types required by CEL to given descriptor pool. +// This includes standard wrappers, timestamp, duration, any, etc. +// This does not work for descriptor pools that have a fallback database. +// Use GetStandardMessageTypesFileDescriptorSet() below instead to populate. +absl::Status AddStandardMessageTypesToDescriptorPool( + google::protobuf::DescriptorPool& descriptor_pool); + +// Get the standard message types required by CEL. +// This includes standard wrappers, timestamp, duration, any, etc. These can be +// used to, e.g., add them to a DescriptorDatabase backing a DescriptorPool. +google::protobuf::FileDescriptorSet GetStandardMessageTypesFileDescriptorSet(); + +} // namespace google::api::expr::runtime + +#endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_CEL_PROTO_DESCRIPTOR_POOL_BUILDER_H_ diff --git a/eval/public/structs/cel_proto_descriptor_pool_builder_test.cc b/eval/public/structs/cel_proto_descriptor_pool_builder_test.cc new file mode 100644 index 000000000..43c76386b --- /dev/null +++ b/eval/public/structs/cel_proto_descriptor_pool_builder_test.cc @@ -0,0 +1,188 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "eval/public/structs/cel_proto_descriptor_pool_builder.h" + +#include +#include + +#include "google/protobuf/any.pb.h" +#include "absl/container/flat_hash_map.h" +#include "eval/testutil/test_message.pb.h" +#include "internal/testing.h" + +namespace google::api::expr::runtime { + +namespace { + +using ::absl_testing::StatusIs; +using ::testing::HasSubstr; +using ::testing::UnorderedElementsAre; + +TEST(DescriptorPoolUtilsTest, PopulatesEmptyDescriptorPool) { + google::protobuf::DescriptorPool descriptor_pool; + + ASSERT_EQ(descriptor_pool.FindMessageTypeByName("google.protobuf.Any"), + nullptr); + ASSERT_EQ(descriptor_pool.FindMessageTypeByName("google.protobuf.BoolValue"), + nullptr); + ASSERT_EQ(descriptor_pool.FindMessageTypeByName("google.protobuf.BytesValue"), + nullptr); + ASSERT_EQ( + descriptor_pool.FindMessageTypeByName("google.protobuf.DoubleValue"), + nullptr); + ASSERT_EQ(descriptor_pool.FindMessageTypeByName("google.protobuf.Duration"), + nullptr); + ASSERT_EQ(descriptor_pool.FindMessageTypeByName("google.protobuf.FloatValue"), + nullptr); + ASSERT_EQ(descriptor_pool.FindMessageTypeByName("google.protobuf.Int32Value"), + nullptr); + ASSERT_EQ(descriptor_pool.FindMessageTypeByName("google.protobuf.Int64Value"), + nullptr); + ASSERT_EQ(descriptor_pool.FindMessageTypeByName("google.protobuf.ListValue"), + nullptr); + ASSERT_EQ( + descriptor_pool.FindMessageTypeByName("google.protobuf.StringValue"), + nullptr); + ASSERT_EQ(descriptor_pool.FindMessageTypeByName("google.protobuf.Struct"), + nullptr); + ASSERT_EQ(descriptor_pool.FindMessageTypeByName("google.protobuf.Timestamp"), + nullptr); + ASSERT_EQ( + descriptor_pool.FindMessageTypeByName("google.protobuf.UInt32Value"), + nullptr); + ASSERT_EQ( + descriptor_pool.FindMessageTypeByName("google.protobuf.UInt64Value"), + nullptr); + ASSERT_EQ(descriptor_pool.FindMessageTypeByName("google.protobuf.Value"), + nullptr); + ASSERT_EQ(descriptor_pool.FindMessageTypeByName("google.protobuf.FieldMask"), + nullptr); + + ASSERT_OK(AddStandardMessageTypesToDescriptorPool(descriptor_pool)); + + EXPECT_NE(descriptor_pool.FindMessageTypeByName("google.protobuf.Any"), + nullptr); + EXPECT_NE(descriptor_pool.FindMessageTypeByName("google.protobuf.BoolValue"), + nullptr); + EXPECT_NE(descriptor_pool.FindMessageTypeByName("google.protobuf.BytesValue"), + nullptr); + EXPECT_NE( + descriptor_pool.FindMessageTypeByName("google.protobuf.DoubleValue"), + nullptr); + EXPECT_NE(descriptor_pool.FindMessageTypeByName("google.protobuf.Duration"), + nullptr); + EXPECT_NE(descriptor_pool.FindMessageTypeByName("google.protobuf.FloatValue"), + nullptr); + EXPECT_NE(descriptor_pool.FindMessageTypeByName("google.protobuf.Int32Value"), + nullptr); + EXPECT_NE(descriptor_pool.FindMessageTypeByName("google.protobuf.Int64Value"), + nullptr); + EXPECT_NE(descriptor_pool.FindMessageTypeByName("google.protobuf.ListValue"), + nullptr); + EXPECT_NE( + descriptor_pool.FindMessageTypeByName("google.protobuf.StringValue"), + nullptr); + EXPECT_NE(descriptor_pool.FindMessageTypeByName("google.protobuf.Struct"), + nullptr); + EXPECT_NE(descriptor_pool.FindMessageTypeByName("google.protobuf.Timestamp"), + nullptr); + EXPECT_NE( + descriptor_pool.FindMessageTypeByName("google.protobuf.UInt32Value"), + nullptr); + EXPECT_NE( + descriptor_pool.FindMessageTypeByName("google.protobuf.UInt64Value"), + nullptr); + EXPECT_NE(descriptor_pool.FindMessageTypeByName("google.protobuf.Value"), + nullptr); + EXPECT_NE(descriptor_pool.FindMessageTypeByName("google.protobuf.FieldMask"), + nullptr); + EXPECT_NE(descriptor_pool.FindMessageTypeByName("google.protobuf.Empty"), + nullptr); +} + +TEST(DescriptorPoolUtilsTest, AcceptsPreAddedStandardTypes) { + google::protobuf::DescriptorPool descriptor_pool; + + for (auto proto_name : std::vector{ + "google.protobuf.Any", "google.protobuf.BoolValue", + "google.protobuf.BytesValue", "google.protobuf.DoubleValue", + "google.protobuf.Duration", "google.protobuf.FloatValue", + "google.protobuf.Int32Value", "google.protobuf.Int64Value", + "google.protobuf.ListValue", "google.protobuf.StringValue", + "google.protobuf.Struct", "google.protobuf.Timestamp", + "google.protobuf.UInt32Value", "google.protobuf.UInt64Value", + "google.protobuf.Value", "google.protobuf.FieldMask", + "google.protobuf.Empty"}) { + const google::protobuf::Descriptor* descriptor = + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + proto_name); + ASSERT_NE(descriptor, nullptr); + google::protobuf::FileDescriptorProto file_descriptor_proto; + descriptor->file()->CopyTo(&file_descriptor_proto); + ASSERT_NE(descriptor_pool.BuildFile(file_descriptor_proto), nullptr); + } + + EXPECT_OK(AddStandardMessageTypesToDescriptorPool(descriptor_pool)); +} + +TEST(DescriptorPoolUtilsTest, RejectsModifiedStandardType) { + google::protobuf::DescriptorPool descriptor_pool; + + const google::protobuf::Descriptor* descriptor = + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.protobuf.Duration"); + ASSERT_NE(descriptor, nullptr); + google::protobuf::FileDescriptorProto file_descriptor_proto; + descriptor->file()->CopyTo(&file_descriptor_proto); + // We emulate a modification by external code that replaced the nanos by a + // millis field. + google::protobuf::FieldDescriptorProto seconds_desc_proto; + google::protobuf::FieldDescriptorProto nanos_desc_proto; + descriptor->FindFieldByName("seconds")->CopyTo(&seconds_desc_proto); + descriptor->FindFieldByName("nanos")->CopyTo(&nanos_desc_proto); + nanos_desc_proto.set_name("millis"); + file_descriptor_proto.mutable_message_type(0)->clear_field(); + *file_descriptor_proto.mutable_message_type(0)->add_field() = + seconds_desc_proto; + *file_descriptor_proto.mutable_message_type(0)->add_field() = + nanos_desc_proto; + + descriptor_pool.BuildFile(file_descriptor_proto); + + EXPECT_THAT( + AddStandardMessageTypesToDescriptorPool(descriptor_pool), + StatusIs(absl::StatusCode::kFailedPrecondition, HasSubstr("differs"))); +} + +TEST(DescriptorPoolUtilsTest, GetStandardMessageTypesFileDescriptorSet) { + google::protobuf::FileDescriptorSet fdset = GetStandardMessageTypesFileDescriptorSet(); + std::vector file_names; + for (int i = 0; i < fdset.file_size(); ++i) { + file_names.push_back(fdset.file(i).name()); + } + EXPECT_THAT( + file_names, + UnorderedElementsAre( + "google/protobuf/any.proto", "google/protobuf/struct.proto", + "google/protobuf/wrappers.proto", "google/protobuf/timestamp.proto", + "google/protobuf/duration.proto", "google/protobuf/field_mask.proto", + "google/protobuf/empty.proto")); +} + +} // namespace + +} // namespace google::api::expr::runtime diff --git a/eval/public/structs/cel_proto_wrap_util.cc b/eval/public/structs/cel_proto_wrap_util.cc new file mode 100644 index 000000000..7bfe81fe6 --- /dev/null +++ b/eval/public/structs/cel_proto_wrap_util.cc @@ -0,0 +1,1479 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/structs/cel_proto_wrap_util.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "google/protobuf/any.pb.h" +#include "google/protobuf/duration.pb.h" +#include "google/protobuf/struct.pb.h" +#include "google/protobuf/timestamp.pb.h" +#include "google/protobuf/wrappers.pb.h" +#include "absl/base/nullability.h" +#include "absl/base/optimization.h" +#include "absl/functional/overload.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/escaping.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/time/time.h" +#include "absl/types/optional.h" +#include "absl/types/variant.h" +#include "eval/public/cel_value.h" +#include "eval/public/structs/protobuf_value_factory.h" +#include "internal/overflow.h" +#include "internal/proto_time_encoding.h" +#include "internal/status_macros.h" +#include "internal/time.h" +#include "internal/well_known_types.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" +#include "google/protobuf/message_lite.h" + +namespace google::api::expr::runtime::internal { + +namespace { + +using cel::internal::DecodeDuration; +using cel::internal::DecodeTime; +using google::protobuf::Any; +using google::protobuf::BoolValue; +using google::protobuf::BytesValue; +using google::protobuf::DoubleValue; +using google::protobuf::Duration; +using google::protobuf::FloatValue; +using google::protobuf::Int32Value; +using google::protobuf::Int64Value; +using google::protobuf::ListValue; +using google::protobuf::StringValue; +using google::protobuf::Struct; +using google::protobuf::Timestamp; +using google::protobuf::UInt32Value; +using google::protobuf::UInt64Value; +using google::protobuf::Value; +using google::protobuf::Arena; +using google::protobuf::Descriptor; +using google::protobuf::DescriptorPool; +using google::protobuf::Message; +using google::protobuf::MessageFactory; + +// kMaxIntJSON is defined as the Number.MAX_SAFE_INTEGER value per EcmaScript 6. +constexpr int64_t kMaxIntJSON = (1ll << 53) - 1; + +// kMinIntJSON is defined as the Number.MIN_SAFE_INTEGER value per EcmaScript 6. +constexpr int64_t kMinIntJSON = -kMaxIntJSON; + +// IsJSONSafe indicates whether the int is safely representable as a floating +// point value in JSON. +static bool IsJSONSafe(int64_t i) { + return i >= kMinIntJSON && i <= kMaxIntJSON; +} + +// IsJSONSafe indicates whether the uint is safely representable as a floating +// point value in JSON. +static bool IsJSONSafe(uint64_t i) { + return i <= static_cast(kMaxIntJSON); +} + +// Map implementation wrapping google.protobuf.ListValue +class DynamicList : public CelList { + public: + DynamicList(const ListValue* values, ProtobufValueFactory factory, + Arena* arena) + : arena_(arena), factory_(std::move(factory)), values_(values) {} + + CelValue operator[](int index) const override; + + // List size + int size() const override { return values_->values_size(); } + + private: + Arena* arena_; + ProtobufValueFactory factory_; + const ListValue* values_; +}; + +// Map implementation wrapping google.protobuf.Struct. +class DynamicMap : public CelMap { + public: + DynamicMap(const Struct* values, ProtobufValueFactory factory, Arena* arena) + : arena_(arena), + factory_(std::move(factory)), + values_(values), + key_list_(values) {} + + absl::StatusOr Has(const CelValue& key) const override { + CelValue::StringHolder str_key; + if (!key.GetValue(&str_key)) { + // Not a string key. + return absl::InvalidArgumentError(absl::StrCat( + "Invalid map key type: '", CelValue::TypeName(key.type()), "'")); + } + + return values_->fields().contains(std::string(str_key.value())); + } + + absl::optional operator[](CelValue key) const override; + + int size() const override { return values_->fields_size(); } + + absl::StatusOr ListKeys() const override { + return &key_list_; + } + + private: + // List of keys in Struct.fields map. + // It utilizes lazy initialization, to avoid performance penalties. + class DynamicMapKeyList : public CelList { + public: + explicit DynamicMapKeyList(const Struct* values) + : values_(values), keys_(), initialized_(false) {} + + // Index access + CelValue operator[](int index) const override { + CheckInit(); + return keys_[index]; + } + + // List size + int size() const override { + CheckInit(); + return values_->fields_size(); + } + + private: + void CheckInit() const { + absl::MutexLock lock(mutex_); + if (!initialized_) { + for (const auto& it : values_->fields()) { + keys_.push_back(CelValue::CreateString(&it.first)); + } + initialized_ = true; + } + } + + const Struct* values_; + mutable absl::Mutex mutex_; + mutable std::vector keys_; + mutable bool initialized_; + }; + + Arena* arena_; + ProtobufValueFactory factory_; + const Struct* values_; + const DynamicMapKeyList key_list_; +}; + +// Adapter for usage with CEL_RETURN_IF_ERROR and CEL_ASSIGN_OR_RETURN. +class ReturnCelValueError { + public: + explicit ReturnCelValueError(google::protobuf::Arena* absl_nonnull arena) + : arena_(arena) {} + + CelValue operator()(const absl::Status& status) const { + ABSL_DCHECK(!status.ok()); + return CelValue::CreateError( + google::protobuf::Arena::Create(arena_, status)); + } + + private: + google::protobuf::Arena* absl_nonnull arena_; +}; + +struct IgnoreErrorAndReturnNullptr { + std::nullptr_t operator()(const absl::Status& status) const { + status.IgnoreError(); + return nullptr; + } +}; + +// ValueManager provides ValueFromMessage(....) function family. +// Functions of this family create CelValue object from specific subtypes of +// protobuf message. +class ValueManager { + public: + ValueManager(const ProtobufValueFactory& value_factory, + const google::protobuf::DescriptorPool* descriptor_pool, + google::protobuf::Arena* arena, google::protobuf::MessageFactory* message_factory) + : value_factory_(value_factory), + descriptor_pool_(descriptor_pool), + arena_(arena), + message_factory_(message_factory) {} + + // Note: this overload should only be used in the context of accessing struct + // value members, which have already been adapted to the generated message + // types. + ValueManager(const ProtobufValueFactory& value_factory, google::protobuf::Arena* arena) + : value_factory_(value_factory), + descriptor_pool_(DescriptorPool::generated_pool()), + arena_(arena), + message_factory_(MessageFactory::generated_factory()) {} + + static CelValue ValueFromDuration(absl::Duration duration) { + return CelValue::CreateDuration(duration); + } + + CelValue ValueFromDuration(const google::protobuf::Message* message) { + CEL_ASSIGN_OR_RETURN( + auto reflection, + cel::well_known_types::GetDurationReflection(message->GetDescriptor()), + _.With(ReturnCelValueError(arena_))); + return ValueFromDuration(reflection.UnsafeToAbslDuration(*message)); + } + + CelValue ValueFromMessage(const Duration* duration) { + return ValueFromDuration(DecodeDuration(*duration)); + } + + CelValue ValueFromTimestamp(const google::protobuf::Message* message) { + CEL_ASSIGN_OR_RETURN( + auto reflection, + cel::well_known_types::GetTimestampReflection(message->GetDescriptor()), + _.With(ReturnCelValueError(arena_))); + return ValueFromTimestamp(reflection.UnsafeToAbslTime(*message)); + } + + static CelValue ValueFromTimestamp(absl::Time timestamp) { + return CelValue::CreateTimestamp(timestamp); + } + + CelValue ValueFromMessage(const Timestamp* timestamp) { + return ValueFromTimestamp(DecodeTime(*timestamp)); + } + + CelValue ValueFromMessage(const ListValue* list_values) { + return CelValue::CreateList(Arena::Create( + arena_, list_values, value_factory_, arena_)); + } + + CelValue ValueFromMessage(const Struct* struct_value) { + return CelValue::CreateMap(Arena::Create( + arena_, struct_value, value_factory_, arena_)); + } + + CelValue ValueFromAny(const google::protobuf::Message* message) { + CEL_ASSIGN_OR_RETURN( + auto reflection, + cel::well_known_types::GetAnyReflection(message->GetDescriptor()), + _.With(ReturnCelValueError(arena_))); + std::string type_url_scratch; + std::string value_scratch; + return ValueFromAny(reflection.GetTypeUrl(*message, type_url_scratch), + reflection.GetValue(*message, value_scratch), + descriptor_pool_, message_factory_); + } + + CelValue ValueFromAny(const cel::well_known_types::StringValue& type_url, + const cel::well_known_types::BytesValue& payload, + const DescriptorPool* descriptor_pool, + MessageFactory* message_factory) { + std::string type_url_string_scratch; + absl::string_view type_url_string = absl::visit( + absl::Overload([](absl::string_view string) + -> absl::string_view { return string; }, + [&type_url_string_scratch]( + const absl::Cord& cord) -> absl::string_view { + if (auto flat = cord.TryFlat(); flat) { + return *flat; + } + absl::CopyCordToString(cord, &type_url_string_scratch); + return absl::string_view(type_url_string_scratch); + }), + cel::well_known_types::AsVariant(type_url)); + auto pos = type_url_string.find_last_of('/'); + if (pos == type_url_string.npos) { + // TODO(issues/25) What error code? + // Malformed type_url + return CreateErrorValue(arena_, "Malformed type_url string"); + } + + absl::string_view full_name = type_url_string.substr(pos + 1); + const Descriptor* nested_descriptor = + descriptor_pool->FindMessageTypeByName(full_name); + + if (nested_descriptor == nullptr) { + // Descriptor not found for the type + // TODO(issues/25) What error code? + return CreateErrorValue(arena_, "Descriptor not found"); + } + + const Message* prototype = message_factory->GetPrototype(nested_descriptor); + if (prototype == nullptr) { + // Failed to obtain prototype for the descriptor + // TODO(issues/25) What error code? + return CreateErrorValue(arena_, "Prototype not found"); + } + + Message* nested_message = prototype->New(arena_); + bool ok = + absl::visit(absl::Overload( + [nested_message](absl::string_view string) -> bool { + return nested_message->ParsePartialFromString(string); + }, + [nested_message](const absl::Cord& cord) -> bool { + return nested_message->ParsePartialFromString(cord); + }), + cel::well_known_types::AsVariant(payload)); + if (!ok) { + // Failed to unpack. + // TODO(issues/25) What error code? + return CreateErrorValue(arena_, "Failed to unpack Any into message"); + } + + return UnwrapMessageToValue(nested_message, value_factory_, arena_); + } + + CelValue ValueFromMessage(const Any* any_value, + const DescriptorPool* descriptor_pool, + MessageFactory* message_factory) { + return ValueFromAny(any_value->type_url(), absl::Cord(any_value->value()), + descriptor_pool, message_factory); + } + + CelValue ValueFromMessage(const Any* any_value) { + return ValueFromMessage(any_value, descriptor_pool_, message_factory_); + } + + CelValue ValueFromBool(const google::protobuf::Message* message) { + CEL_ASSIGN_OR_RETURN( + auto reflection, + cel::well_known_types::GetBoolValueReflection(message->GetDescriptor()), + _.With(ReturnCelValueError(arena_))); + return ValueFromBool(reflection.GetValue(*message)); + } + + static CelValue ValueFromBool(bool value) { + return CelValue::CreateBool(value); + } + + CelValue ValueFromMessage(const BoolValue* wrapper) { + return ValueFromBool(wrapper->value()); + } + + CelValue ValueFromInt32(const google::protobuf::Message* message) { + CEL_ASSIGN_OR_RETURN(auto reflection, + cel::well_known_types::GetInt32ValueReflection( + message->GetDescriptor()), + _.With(ReturnCelValueError(arena_))); + return ValueFromInt32(reflection.GetValue(*message)); + } + + static CelValue ValueFromInt32(int32_t value) { + return CelValue::CreateInt64(value); + } + + CelValue ValueFromMessage(const Int32Value* wrapper) { + return ValueFromInt32(wrapper->value()); + } + + CelValue ValueFromUInt32(const google::protobuf::Message* message) { + CEL_ASSIGN_OR_RETURN(auto reflection, + cel::well_known_types::GetUInt32ValueReflection( + message->GetDescriptor()), + _.With(ReturnCelValueError(arena_))); + return ValueFromUInt32(reflection.GetValue(*message)); + } + + static CelValue ValueFromUInt32(uint32_t value) { + return CelValue::CreateUint64(value); + } + + CelValue ValueFromMessage(const UInt32Value* wrapper) { + return ValueFromUInt32(wrapper->value()); + } + + CelValue ValueFromInt64(const google::protobuf::Message* message) { + CEL_ASSIGN_OR_RETURN(auto reflection, + cel::well_known_types::GetInt64ValueReflection( + message->GetDescriptor()), + _.With(ReturnCelValueError(arena_))); + return ValueFromInt64(reflection.GetValue(*message)); + } + + static CelValue ValueFromInt64(int64_t value) { + return CelValue::CreateInt64(value); + } + + CelValue ValueFromMessage(const Int64Value* wrapper) { + return ValueFromInt64(wrapper->value()); + } + + CelValue ValueFromUInt64(const google::protobuf::Message* message) { + CEL_ASSIGN_OR_RETURN(auto reflection, + cel::well_known_types::GetUInt64ValueReflection( + message->GetDescriptor()), + _.With(ReturnCelValueError(arena_))); + return ValueFromUInt64(reflection.GetValue(*message)); + } + + static CelValue ValueFromUInt64(uint64_t value) { + return CelValue::CreateUint64(value); + } + + CelValue ValueFromMessage(const UInt64Value* wrapper) { + return ValueFromUInt64(wrapper->value()); + } + + CelValue ValueFromFloat(const google::protobuf::Message* message) { + CEL_ASSIGN_OR_RETURN(auto reflection, + cel::well_known_types::GetFloatValueReflection( + message->GetDescriptor()), + _.With(ReturnCelValueError(arena_))); + return ValueFromFloat(reflection.GetValue(*message)); + } + + static CelValue ValueFromFloat(float value) { + return CelValue::CreateDouble(value); + } + + CelValue ValueFromMessage(const FloatValue* wrapper) { + return ValueFromFloat(wrapper->value()); + } + + CelValue ValueFromDouble(const google::protobuf::Message* message) { + CEL_ASSIGN_OR_RETURN(auto reflection, + cel::well_known_types::GetDoubleValueReflection( + message->GetDescriptor()), + _.With(ReturnCelValueError(arena_))); + return ValueFromDouble(reflection.GetValue(*message)); + } + + static CelValue ValueFromDouble(double value) { + return CelValue::CreateDouble(value); + } + + CelValue ValueFromMessage(const DoubleValue* wrapper) { + return ValueFromDouble(wrapper->value()); + } + + CelValue ValueFromString(const google::protobuf::Message* message) { + CEL_ASSIGN_OR_RETURN(auto reflection, + cel::well_known_types::GetStringValueReflection( + message->GetDescriptor()), + _.With(ReturnCelValueError(arena_))); + std::string scratch; + return absl::visit( + absl::Overload( + [&](absl::string_view string) -> CelValue { + if (string.data() == scratch.data() && + string.size() == scratch.size()) { + return CelValue::CreateString( + google::protobuf::Arena::Create(arena_, + std::move(scratch))); + } + return CelValue::CreateString(google::protobuf::Arena::Create( + arena_, std::string(string))); + }, + [&](absl::Cord&& cord) -> CelValue { + auto* string = google::protobuf::Arena::Create(arena_); + absl::CopyCordToString(cord, string); + return CelValue::CreateString(string); + }), + cel::well_known_types::AsVariant( + reflection.GetValue(*message, scratch))); + } + + CelValue ValueFromString(const absl::Cord& value) { + return CelValue::CreateString( + Arena::Create(arena_, static_cast(value))); + } + + static CelValue ValueFromString(const std::string* value) { + return CelValue::CreateString(value); + } + + CelValue ValueFromMessage(const StringValue* wrapper) { + return ValueFromString(&wrapper->value()); + } + + CelValue ValueFromBytes(const google::protobuf::Message* message) { + CEL_ASSIGN_OR_RETURN(auto reflection, + cel::well_known_types::GetBytesValueReflection( + message->GetDescriptor()), + _.With(ReturnCelValueError(arena_))); + std::string scratch; + return absl::visit( + absl::Overload( + [&](absl::string_view string) -> CelValue { + if (string.data() == scratch.data() && + string.size() == scratch.size()) { + return CelValue::CreateBytes(google::protobuf::Arena::Create( + arena_, std::move(scratch))); + } + return CelValue::CreateBytes(google::protobuf::Arena::Create( + arena_, std::string(string))); + }, + [&](absl::Cord&& cord) -> CelValue { + auto* string = google::protobuf::Arena::Create(arena_); + absl::CopyCordToString(cord, string); + return CelValue::CreateBytes(string); + }), + cel::well_known_types::AsVariant( + reflection.GetValue(*message, scratch))); + } + + CelValue ValueFromBytes(const absl::Cord& value) { + return CelValue::CreateBytes( + Arena::Create(arena_, static_cast(value))); + } + + static CelValue ValueFromBytes(google::protobuf::Arena* arena, std::string value) { + return CelValue::CreateBytes( + Arena::Create(arena, std::move(value))); + } + + CelValue ValueFromMessage(const BytesValue* wrapper) { + // BytesValue stores value as Cord + return CelValue::CreateBytes( + Arena::Create(arena_, std::string(wrapper->value()))); + } + + CelValue ValueFromMessage(const Value* value) { + switch (value->kind_case()) { + case Value::KindCase::kNullValue: + return CelValue::CreateNull(); + case Value::KindCase::kNumberValue: + return CelValue::CreateDouble(value->number_value()); + case Value::KindCase::kStringValue: + return CelValue::CreateString(&value->string_value()); + case Value::KindCase::kBoolValue: + return CelValue::CreateBool(value->bool_value()); + case Value::KindCase::kStructValue: + return ValueFromMessage(&value->struct_value()); + case Value::KindCase::kListValue: + return ValueFromMessage(&value->list_value()); + default: + return CelValue::CreateNull(); + } + } + + template + CelValue ValueFromGeneratedMessageLite(const google::protobuf::Message* message) { + const auto* downcast_message = google::protobuf::DynamicCastToGenerated(message); + if (downcast_message != nullptr) { + return ValueFromMessage(downcast_message); + } + auto* value = google::protobuf::Arena::Create(arena_); + absl::Cord serialized; + if (!message->SerializeToString(&serialized)) { + return CreateErrorValue( + arena_, absl::UnknownError( + absl::StrCat("failed to serialize dynamic message: ", + message->GetTypeName()))); + } + if (!value->ParseFromCord(serialized)) { + return CreateErrorValue(arena_, absl::UnknownError(absl::StrCat( + "failed to parse generated message: ", + value->GetTypeName()))); + } + return ValueFromMessage(value); + } + + template + CelValue ValueFromMessage(const google::protobuf::Message* message) { + if constexpr (std::is_same_v) { + return ValueFromAny(message); + } else if constexpr (std::is_same_v) { + return ValueFromBool(message); + } else if constexpr (std::is_same_v) { + return ValueFromBytes(message); + } else if constexpr (std::is_same_v) { + return ValueFromDouble(message); + } else if constexpr (std::is_same_v) { + return ValueFromDuration(message); + } else if constexpr (std::is_same_v) { + return ValueFromFloat(message); + } else if constexpr (std::is_same_v) { + return ValueFromInt32(message); + } else if constexpr (std::is_same_v) { + return ValueFromInt64(message); + } else if constexpr (std::is_same_v) { + return ValueFromGeneratedMessageLite(message); + } else if constexpr (std::is_same_v) { + return ValueFromString(message); + } else if constexpr (std::is_same_v) { + return ValueFromGeneratedMessageLite(message); + } else if constexpr (std::is_same_v) { + return ValueFromTimestamp(message); + } else if constexpr (std::is_same_v) { + return ValueFromUInt32(message); + } else if constexpr (std::is_same_v) { + return ValueFromUInt64(message); + } else if constexpr (std::is_same_v) { + return ValueFromGeneratedMessageLite(message); + } else { + ABSL_UNREACHABLE(); + } + } + + private: + const ProtobufValueFactory& value_factory_; + const google::protobuf::DescriptorPool* descriptor_pool_; + google::protobuf::Arena* arena_; + MessageFactory* message_factory_; +}; + +// Class makes CelValue from generic protobuf Message. +// It holds a registry of CelValue factories for specific subtypes of Message. +// If message does not match any of types stored in registry, generic +// message-containing CelValue is created. +class ValueFromMessageMaker { + public: + template + static CelValue CreateWellknownTypeValue(const google::protobuf::Message* msg, + const ProtobufValueFactory& factory, + Arena* arena) { + // Copy the original descriptor pool and message factory for unpacking 'Any' + // values. + google::protobuf::MessageFactory* message_factory = + msg->GetReflection()->GetMessageFactory(); + const google::protobuf::DescriptorPool* pool = msg->GetDescriptor()->file()->pool(); + return ValueManager(factory, pool, arena, message_factory) + .ValueFromMessage(msg); + } + + static absl::optional CreateValue( + const google::protobuf::Message* message, const ProtobufValueFactory& factory, + Arena* arena) { + switch (message->GetDescriptor()->well_known_type()) { + case google::protobuf::Descriptor::WELLKNOWNTYPE_DOUBLEVALUE: + return CreateWellknownTypeValue(message, factory, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_FLOATVALUE: + return CreateWellknownTypeValue(message, factory, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_INT64VALUE: + return CreateWellknownTypeValue(message, factory, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_UINT64VALUE: + return CreateWellknownTypeValue(message, factory, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_INT32VALUE: + return CreateWellknownTypeValue(message, factory, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_UINT32VALUE: + return CreateWellknownTypeValue(message, factory, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_STRINGVALUE: + return CreateWellknownTypeValue(message, factory, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_BYTESVALUE: + return CreateWellknownTypeValue(message, factory, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_BOOLVALUE: + return CreateWellknownTypeValue(message, factory, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_ANY: + return CreateWellknownTypeValue(message, factory, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_DURATION: + return CreateWellknownTypeValue(message, factory, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_TIMESTAMP: + return CreateWellknownTypeValue(message, factory, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE: + return CreateWellknownTypeValue(message, factory, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE: + return CreateWellknownTypeValue(message, factory, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT: + return CreateWellknownTypeValue(message, factory, arena); + // WELLKNOWNTYPE_FIELDMASK has no special CelValue type + default: + return std::nullopt; + } + } + + // Non-copyable, non-assignable + ValueFromMessageMaker(const ValueFromMessageMaker&) = delete; + ValueFromMessageMaker& operator=(const ValueFromMessageMaker&) = delete; +}; + +CelValue DynamicList::operator[](int index) const { + return ValueManager(factory_, arena_) + .ValueFromMessage(&values_->values(index)); +} + +absl::optional DynamicMap::operator[](CelValue key) const { + CelValue::StringHolder str_key; + if (!key.GetValue(&str_key)) { + // Not a string key. + return CreateErrorValue(arena_, absl::InvalidArgumentError(absl::StrCat( + "Invalid map key type: '", + CelValue::TypeName(key.type()), "'"))); + } + + auto it = values_->fields().find(std::string(str_key.value())); + if (it == values_->fields().end()) { + return std::nullopt; + } + + return ValueManager(factory_, arena_).ValueFromMessage(&it->second); +} + +google::protobuf::Message* DurationFromValue(const google::protobuf::Message* prototype, + const CelValue& value, + google::protobuf::Arena* arena) { + absl::Duration val; + if (!value.GetValue(&val)) { + return nullptr; + } + if (!cel::internal::ValidateDuration(val).ok()) { + return nullptr; + } + auto* message = prototype->New(arena); + CEL_ASSIGN_OR_RETURN( + auto reflection, + cel::well_known_types::GetDurationReflection(message->GetDescriptor()), + _.With(IgnoreErrorAndReturnNullptr())); + reflection.UnsafeSetFromAbslDuration(message, val); + return message; +} + +google::protobuf::Message* BoolFromValue(const google::protobuf::Message* prototype, + const CelValue& value, google::protobuf::Arena* arena) { + bool val; + if (!value.GetValue(&val)) { + return nullptr; + } + auto* message = prototype->New(arena); + CEL_ASSIGN_OR_RETURN( + auto reflection, + cel::well_known_types::GetBoolValueReflection(message->GetDescriptor()), + _.With(IgnoreErrorAndReturnNullptr())); + reflection.SetValue(message, val); + return message; +} + +google::protobuf::Message* BytesFromValue(const google::protobuf::Message* prototype, + const CelValue& value, google::protobuf::Arena* arena) { + CelValue::BytesHolder view_val; + if (!value.GetValue(&view_val)) { + return nullptr; + } + auto* message = prototype->New(arena); + CEL_ASSIGN_OR_RETURN( + auto reflection, + cel::well_known_types::GetBytesValueReflection(message->GetDescriptor()), + _.With(IgnoreErrorAndReturnNullptr())); + reflection.SetValue(message, view_val.value()); + return message; +} + +google::protobuf::Message* DoubleFromValue(const google::protobuf::Message* prototype, + const CelValue& value, google::protobuf::Arena* arena) { + double val; + if (!value.GetValue(&val)) { + return nullptr; + } + auto* message = prototype->New(arena); + CEL_ASSIGN_OR_RETURN( + auto reflection, + cel::well_known_types::GetDoubleValueReflection(message->GetDescriptor()), + _.With(IgnoreErrorAndReturnNullptr())); + reflection.SetValue(message, val); + return message; +} + +google::protobuf::Message* FloatFromValue(const google::protobuf::Message* prototype, + const CelValue& value, google::protobuf::Arena* arena) { + double val; + if (!value.GetValue(&val)) { + return nullptr; + } + float fval = val; + // Abort the conversion if the value is outside the float range. + if (val > std::numeric_limits::max()) { + fval = std::numeric_limits::infinity(); + } else if (val < std::numeric_limits::lowest()) { + fval = -std::numeric_limits::infinity(); + } + auto* message = prototype->New(arena); + CEL_ASSIGN_OR_RETURN( + auto reflection, + cel::well_known_types::GetFloatValueReflection(message->GetDescriptor()), + _.With(IgnoreErrorAndReturnNullptr())); + reflection.SetValue(message, static_cast(fval)); + return message; +} + +google::protobuf::Message* Int32FromValue(const google::protobuf::Message* prototype, + const CelValue& value, google::protobuf::Arena* arena) { + int64_t val; + if (!value.GetValue(&val)) { + return nullptr; + } + if (!cel::internal::CheckedInt64ToInt32(val).ok()) { + return nullptr; + } + int32_t ival = static_cast(val); + auto* message = prototype->New(arena); + CEL_ASSIGN_OR_RETURN( + auto reflection, + cel::well_known_types::GetInt32ValueReflection(message->GetDescriptor()), + _.With(IgnoreErrorAndReturnNullptr())); + reflection.SetValue(message, ival); + return message; +} + +google::protobuf::Message* Int64FromValue(const google::protobuf::Message* prototype, + const CelValue& value, google::protobuf::Arena* arena) { + int64_t val; + if (!value.GetValue(&val)) { + return nullptr; + } + auto* message = prototype->New(arena); + CEL_ASSIGN_OR_RETURN( + auto reflection, + cel::well_known_types::GetInt64ValueReflection(message->GetDescriptor()), + _.With(IgnoreErrorAndReturnNullptr())); + reflection.SetValue(message, val); + return message; +} + +google::protobuf::Message* StringFromValue(const google::protobuf::Message* prototype, + const CelValue& value, google::protobuf::Arena* arena) { + CelValue::StringHolder view_val; + if (!value.GetValue(&view_val)) { + return nullptr; + } + auto* message = prototype->New(arena); + CEL_ASSIGN_OR_RETURN( + auto reflection, + cel::well_known_types::GetStringValueReflection(message->GetDescriptor()), + _.With(IgnoreErrorAndReturnNullptr())); + reflection.SetValue(message, view_val.value()); + return message; +} + +google::protobuf::Message* TimestampFromValue(const google::protobuf::Message* prototype, + const CelValue& value, + google::protobuf::Arena* arena) { + absl::Time val; + if (!value.GetValue(&val)) { + return nullptr; + } + if (!cel::internal::ValidateTimestamp(val).ok()) { + return nullptr; + } + auto* message = prototype->New(arena); + CEL_ASSIGN_OR_RETURN( + auto reflection, + cel::well_known_types::GetTimestampReflection(message->GetDescriptor()), + _.With(IgnoreErrorAndReturnNullptr())); + reflection.UnsafeSetFromAbslTime(message, val); + return message; +} + +google::protobuf::Message* UInt32FromValue(const google::protobuf::Message* prototype, + const CelValue& value, google::protobuf::Arena* arena) { + uint64_t val; + if (!value.GetValue(&val)) { + return nullptr; + } + if (!cel::internal::CheckedUint64ToUint32(val).ok()) { + return nullptr; + } + uint32_t ival = static_cast(val); + auto* message = prototype->New(arena); + CEL_ASSIGN_OR_RETURN( + auto reflection, + cel::well_known_types::GetUInt32ValueReflection(message->GetDescriptor()), + _.With(IgnoreErrorAndReturnNullptr())); + reflection.SetValue(message, ival); + return message; +} + +google::protobuf::Message* UInt64FromValue(const google::protobuf::Message* prototype, + const CelValue& value, google::protobuf::Arena* arena) { + uint64_t val; + if (!value.GetValue(&val)) { + return nullptr; + } + auto* message = prototype->New(arena); + CEL_ASSIGN_OR_RETURN( + auto reflection, + cel::well_known_types::GetUInt64ValueReflection(message->GetDescriptor()), + _.With(IgnoreErrorAndReturnNullptr())); + reflection.SetValue(message, val); + return message; +} + +google::protobuf::Message* ValueFromValue(google::protobuf::Message* message, const CelValue& value, + google::protobuf::Arena* arena); + +google::protobuf::Message* ValueFromValue(const google::protobuf::Message* prototype, + const CelValue& value, google::protobuf::Arena* arena) { + return ValueFromValue(prototype->New(arena), value, arena); +} + +google::protobuf::Message* ListFromValue(google::protobuf::Message* message, const CelValue& value, + google::protobuf::Arena* arena) { + if (!value.IsList()) { + return nullptr; + } + const CelList& list = *value.ListOrDie(); + CEL_ASSIGN_OR_RETURN( + auto reflection, + cel::well_known_types::GetListValueReflection(message->GetDescriptor()), + _.With(IgnoreErrorAndReturnNullptr())); + for (int i = 0; i < list.size(); i++) { + auto e = list.Get(arena, i); + auto* elem = reflection.AddValues(message); + if (ValueFromValue(elem, e, arena) == nullptr) { + return nullptr; + } + } + return message; +} + +google::protobuf::Message* ListFromValue(const google::protobuf::Message* prototype, + const CelValue& value, google::protobuf::Arena* arena) { + if (!value.IsList()) { + return nullptr; + } + return ListFromValue(prototype->New(arena), value, arena); +} + +google::protobuf::Message* StructFromValue(google::protobuf::Message* message, + const CelValue& value, google::protobuf::Arena* arena) { + if (!value.IsMap()) { + return nullptr; + } + const CelMap& map = *value.MapOrDie(); + absl::StatusOr keys_or = map.ListKeys(arena); + if (!keys_or.ok()) { + // If map doesn't support listing keys, it can't pack into a Struct value. + // This will surface as a CEL error when the object creation expression + // fails. + return nullptr; + } + const CelList& keys = **keys_or; + CEL_ASSIGN_OR_RETURN( + auto reflection, + cel::well_known_types::GetStructReflection(message->GetDescriptor()), + _.With(IgnoreErrorAndReturnNullptr())); + for (int i = 0; i < keys.size(); i++) { + auto k = keys.Get(arena, i); + // If the key is not a string type, abort the conversion. + if (!k.IsString()) { + return nullptr; + } + absl::string_view key = k.StringOrDie().value(); + + auto v = map.Get(arena, k); + if (!v.has_value()) { + return nullptr; + } + auto* field = reflection.InsertField(message, key); + if (ValueFromValue(field, *v, arena) == nullptr) { + return nullptr; + } + } + return message; +} + +google::protobuf::Message* StructFromValue(const google::protobuf::Message* prototype, + const CelValue& value, google::protobuf::Arena* arena) { + if (!value.IsMap()) { + return nullptr; + } + return StructFromValue(prototype->New(arena), value, arena); +} + +google::protobuf::Message* ValueFromValue(google::protobuf::Message* message, const CelValue& value, + google::protobuf::Arena* arena) { + CEL_ASSIGN_OR_RETURN( + auto reflection, + cel::well_known_types::GetValueReflection(message->GetDescriptor()), + _.With(IgnoreErrorAndReturnNullptr())); + switch (value.type()) { + case CelValue::Type::kBool: { + bool val; + if (value.GetValue(&val)) { + reflection.SetBoolValue(message, val); + return message; + } + } break; + case CelValue::Type::kBytes: { + // Base64 encode byte strings to ensure they can safely be transported + // in a JSON string. + CelValue::BytesHolder val; + if (value.GetValue(&val)) { + reflection.SetStringValueFromBytes(message, val.value()); + return message; + } + } break; + case CelValue::Type::kDouble: { + double val; + if (value.GetValue(&val)) { + reflection.SetNumberValue(message, val); + return message; + } + } break; + case CelValue::Type::kDuration: { + // Convert duration values to a protobuf JSON format. + absl::Duration val; + if (value.GetValue(&val)) { + CEL_RETURN_IF_ERROR(cel::internal::ValidateDuration(val)) + .With(IgnoreErrorAndReturnNullptr()); + reflection.SetStringValueFromDuration(message, val); + return message; + } + } break; + case CelValue::Type::kInt64: { + int64_t val; + // Convert int64_t values within the int53 range to doubles, otherwise + // serialize the value to a string. + if (value.GetValue(&val)) { + reflection.SetNumberValue(message, val); + return message; + } + } break; + case CelValue::Type::kString: { + CelValue::StringHolder val; + if (value.GetValue(&val)) { + reflection.SetStringValue(message, val.value()); + return message; + } + } break; + case CelValue::Type::kTimestamp: { + // Convert timestamp values to a protobuf JSON format. + absl::Time val; + if (value.GetValue(&val)) { + CEL_RETURN_IF_ERROR(cel::internal::ValidateTimestamp(val)) + .With(IgnoreErrorAndReturnNullptr()); + reflection.SetStringValueFromTimestamp(message, val); + return message; + } + } break; + case CelValue::Type::kUint64: { + uint64_t val; + // Convert uint64_t values within the int53 range to doubles, otherwise + // serialize the value to a string. + if (value.GetValue(&val)) { + reflection.SetNumberValue(message, val); + return message; + } + } break; + case CelValue::Type::kList: { + if (ListFromValue(reflection.MutableListValue(message), value, arena) != + nullptr) { + return message; + } + } break; + case CelValue::Type::kMap: { + if (StructFromValue(reflection.MutableStructValue(message), value, + arena) != nullptr) { + return message; + } + } break; + case CelValue::Type::kNullType: + reflection.SetNullValue(message); + return message; + break; + default: + return nullptr; + } + return nullptr; +} + +bool ValueFromValue(Value* json, const CelValue& value, google::protobuf::Arena* arena); + +bool ListFromValue(ListValue* json_list, const CelValue& value, + google::protobuf::Arena* arena) { + if (!value.IsList()) { + return false; + } + const CelList& list = *value.ListOrDie(); + for (int i = 0; i < list.size(); i++) { + auto e = list.Get(arena, i); + Value* elem = json_list->add_values(); + if (!ValueFromValue(elem, e, arena)) { + return false; + } + } + return true; +} + +bool StructFromValue(Struct* json_struct, const CelValue& value, + google::protobuf::Arena* arena) { + if (!value.IsMap()) { + return false; + } + const CelMap& map = *value.MapOrDie(); + absl::StatusOr keys_or = map.ListKeys(arena); + if (!keys_or.ok()) { + // If map doesn't support listing keys, it can't pack into a Struct value. + // This will surface as a CEL error when the object creation expression + // fails. + return false; + } + const CelList& keys = **keys_or; + auto fields = json_struct->mutable_fields(); + for (int i = 0; i < keys.size(); i++) { + auto k = keys.Get(arena, i); + // If the key is not a string type, abort the conversion. + if (!k.IsString()) { + return false; + } + absl::string_view key = k.StringOrDie().value(); + + auto v = map.Get(arena, k); + if (!v.has_value()) { + return false; + } + Value field_value; + if (!ValueFromValue(&field_value, *v, arena)) { + return false; + } + (*fields)[std::string(key)] = field_value; + } + return true; +} + +bool ValueFromValue(Value* json, const CelValue& value, google::protobuf::Arena* arena) { + switch (value.type()) { + case CelValue::Type::kBool: { + bool val; + if (value.GetValue(&val)) { + json->set_bool_value(val); + return true; + } + } break; + case CelValue::Type::kBytes: { + // Base64 encode byte strings to ensure they can safely be transported + // in a JSON string. + CelValue::BytesHolder val; + if (value.GetValue(&val)) { + json->set_string_value(absl::Base64Escape(val.value())); + return true; + } + } break; + case CelValue::Type::kDouble: { + double val; + if (value.GetValue(&val)) { + json->set_number_value(val); + return true; + } + } break; + case CelValue::Type::kDuration: { + // Convert duration values to a protobuf JSON format. + absl::Duration val; + if (value.GetValue(&val)) { + auto encode = cel::internal::EncodeDurationToString(val); + if (!encode.ok()) { + return false; + } + json->set_string_value(*encode); + return true; + } + } break; + case CelValue::Type::kInt64: { + int64_t val; + // Convert int64_t values within the int53 range to doubles, otherwise + // serialize the value to a string. + if (value.GetValue(&val)) { + if (IsJSONSafe(val)) { + json->set_number_value(val); + } else { + json->set_string_value(absl::StrCat(val)); + } + return true; + } + } break; + case CelValue::Type::kString: { + CelValue::StringHolder val; + if (value.GetValue(&val)) { + json->set_string_value(val.value()); + return true; + } + } break; + case CelValue::Type::kTimestamp: { + // Convert timestamp values to a protobuf JSON format. + absl::Time val; + if (value.GetValue(&val)) { + auto encode = cel::internal::EncodeTimeToString(val); + if (!encode.ok()) { + return false; + } + json->set_string_value(*encode); + return true; + } + } break; + case CelValue::Type::kUint64: { + uint64_t val; + // Convert uint64_t values within the int53 range to doubles, otherwise + // serialize the value to a string. + if (value.GetValue(&val)) { + if (IsJSONSafe(val)) { + json->set_number_value(val); + } else { + json->set_string_value(absl::StrCat(val)); + } + return true; + } + } break; + case CelValue::Type::kList: + return ListFromValue(json->mutable_list_value(), value, arena); + case CelValue::Type::kMap: + return StructFromValue(json->mutable_struct_value(), value, arena); + case CelValue::Type::kNullType: + json->set_null_value(protobuf::NULL_VALUE); + return true; + default: + return false; + } + return false; +} + +google::protobuf::Message* AnyFromValue(const google::protobuf::Message* prototype, + const CelValue& value, google::protobuf::Arena* arena) { + std::string type_name; + absl::Cord payload; + + // In open source, any->PackFrom() returns void rather than boolean. + switch (value.type()) { + case CelValue::Type::kBool: { + BoolValue v; + type_name = v.GetTypeName(); + v.set_value(value.BoolOrDie()); + payload = v.SerializeAsCord(); + } break; + case CelValue::Type::kBytes: { + BytesValue v; + type_name = v.GetTypeName(); + v.set_value(std::string(value.BytesOrDie().value())); + payload = v.SerializeAsCord(); + } break; + case CelValue::Type::kDouble: { + DoubleValue v; + type_name = v.GetTypeName(); + v.set_value(value.DoubleOrDie()); + payload = v.SerializeAsCord(); + } break; + case CelValue::Type::kDuration: { + Duration v; + if (!cel::internal::EncodeDuration(value.DurationOrDie(), &v).ok()) { + return nullptr; + } + type_name = v.GetTypeName(); + payload = v.SerializeAsCord(); + } break; + case CelValue::Type::kInt64: { + Int64Value v; + type_name = v.GetTypeName(); + v.set_value(value.Int64OrDie()); + payload = v.SerializeAsCord(); + } break; + case CelValue::Type::kString: { + StringValue v; + type_name = v.GetTypeName(); + v.set_value(std::string(value.StringOrDie().value())); + payload = v.SerializeAsCord(); + } break; + case CelValue::Type::kTimestamp: { + Timestamp v; + if (!cel::internal::EncodeTime(value.TimestampOrDie(), &v).ok()) { + return nullptr; + } + type_name = v.GetTypeName(); + payload = v.SerializeAsCord(); + } break; + case CelValue::Type::kUint64: { + UInt64Value v; + type_name = v.GetTypeName(); + v.set_value(value.Uint64OrDie()); + payload = v.SerializeAsCord(); + } break; + case CelValue::Type::kList: { + ListValue v; + if (!ListFromValue(&v, value, arena)) { + return nullptr; + } + type_name = v.GetTypeName(); + payload = v.SerializeAsCord(); + } break; + case CelValue::Type::kMap: { + Struct v; + if (!StructFromValue(&v, value, arena)) { + return nullptr; + } + type_name = v.GetTypeName(); + payload = v.SerializeAsCord(); + } break; + case CelValue::Type::kNullType: { + Value v; + type_name = v.GetTypeName(); + v.set_null_value(google::protobuf::NULL_VALUE); + payload = v.SerializeAsCord(); + } break; + case CelValue::Type::kMessage: { + type_name = value.MessageWrapperOrDie().message_ptr()->GetTypeName(); + payload = value.MessageWrapperOrDie().message_ptr()->SerializeAsCord(); + } break; + default: + return nullptr; + } + + auto* message = prototype->New(arena); + CEL_ASSIGN_OR_RETURN( + auto reflection, + cel::well_known_types::GetAnyReflection(message->GetDescriptor()), + _.With(IgnoreErrorAndReturnNullptr())); + reflection.SetTypeUrl(message, + absl::StrCat("type.googleapis.com/", type_name)); + reflection.SetValue(message, payload); + return message; +} + +bool IsAlreadyWrapped(google::protobuf::Descriptor::WellKnownType wkt, + const CelValue& value) { + if (value.IsMessage()) { + const auto* msg = value.MessageOrDie(); + if (wkt == msg->GetDescriptor()->well_known_type()) { + return true; + } + } + return false; +} + +// MessageFromValueMaker makes a specific protobuf Message instance based on +// the desired protobuf type name and an input CelValue. +// +// It holds a registry of CelValue factories for specific subtypes of Message. +// If message does not match any of types stored in registry, an the factory +// returns an absent value. +class MessageFromValueMaker { + public: + // Non-copyable, non-assignable + MessageFromValueMaker(const MessageFromValueMaker&) = delete; + MessageFromValueMaker& operator=(const MessageFromValueMaker&) = delete; + + static google::protobuf::Message* MaybeWrapMessage(const google::protobuf::Descriptor* descriptor, + google::protobuf::MessageFactory* factory, + const CelValue& value, + Arena* arena) { + switch (descriptor->well_known_type()) { + case google::protobuf::Descriptor::WELLKNOWNTYPE_DOUBLEVALUE: + if (IsAlreadyWrapped(descriptor->well_known_type(), value)) { + return nullptr; + } + return DoubleFromValue(factory->GetPrototype(descriptor), value, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_FLOATVALUE: + if (IsAlreadyWrapped(descriptor->well_known_type(), value)) { + return nullptr; + } + return FloatFromValue(factory->GetPrototype(descriptor), value, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_INT64VALUE: + if (IsAlreadyWrapped(descriptor->well_known_type(), value)) { + return nullptr; + } + return Int64FromValue(factory->GetPrototype(descriptor), value, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_UINT64VALUE: + if (IsAlreadyWrapped(descriptor->well_known_type(), value)) { + return nullptr; + } + return UInt64FromValue(factory->GetPrototype(descriptor), value, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_INT32VALUE: + if (IsAlreadyWrapped(descriptor->well_known_type(), value)) { + return nullptr; + } + return Int32FromValue(factory->GetPrototype(descriptor), value, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_UINT32VALUE: + if (IsAlreadyWrapped(descriptor->well_known_type(), value)) { + return nullptr; + } + return UInt32FromValue(factory->GetPrototype(descriptor), value, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_STRINGVALUE: + if (IsAlreadyWrapped(descriptor->well_known_type(), value)) { + return nullptr; + } + return StringFromValue(factory->GetPrototype(descriptor), value, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_BYTESVALUE: + if (IsAlreadyWrapped(descriptor->well_known_type(), value)) { + return nullptr; + } + return BytesFromValue(factory->GetPrototype(descriptor), value, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_BOOLVALUE: + if (IsAlreadyWrapped(descriptor->well_known_type(), value)) { + return nullptr; + } + return BoolFromValue(factory->GetPrototype(descriptor), value, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_ANY: + if (IsAlreadyWrapped(descriptor->well_known_type(), value)) { + return nullptr; + } + return AnyFromValue(factory->GetPrototype(descriptor), value, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_DURATION: + if (IsAlreadyWrapped(descriptor->well_known_type(), value)) { + return nullptr; + } + return DurationFromValue(factory->GetPrototype(descriptor), value, + arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_TIMESTAMP: + if (IsAlreadyWrapped(descriptor->well_known_type(), value)) { + return nullptr; + } + return TimestampFromValue(factory->GetPrototype(descriptor), value, + arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE: + if (IsAlreadyWrapped(descriptor->well_known_type(), value)) { + return nullptr; + } + return ValueFromValue(factory->GetPrototype(descriptor), value, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE: + if (IsAlreadyWrapped(descriptor->well_known_type(), value)) { + return nullptr; + } + return ListFromValue(factory->GetPrototype(descriptor), value, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT: + if (IsAlreadyWrapped(descriptor->well_known_type(), value)) { + return nullptr; + } + return StructFromValue(factory->GetPrototype(descriptor), value, arena); + // WELLKNOWNTYPE_FIELDMASK has no special CelValue type + default: + return nullptr; + } + } +}; + +} // namespace + +CelValue UnwrapMessageToValue(const google::protobuf::Message* value, + const ProtobufValueFactory& factory, + Arena* arena) { + // Messages are Nullable types + if (value == nullptr) { + return CelValue::CreateNull(); + } + + absl::optional special_value = + ValueFromMessageMaker::CreateValue(value, factory, arena); + if (special_value.has_value()) { + return *special_value; + } + return factory(value); +} + +const google::protobuf::Message* MaybeWrapValueToMessage( + const google::protobuf::Descriptor* descriptor, google::protobuf::MessageFactory* factory, + const CelValue& value, Arena* arena) { + google::protobuf::Message* msg = MessageFromValueMaker::MaybeWrapMessage( + descriptor, factory, value, arena); + return msg; +} + +} // namespace google::api::expr::runtime::internal diff --git a/eval/public/structs/cel_proto_wrap_util.h b/eval/public/structs/cel_proto_wrap_util.h new file mode 100644 index 000000000..508985209 --- /dev/null +++ b/eval/public/structs/cel_proto_wrap_util.h @@ -0,0 +1,45 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_CEL_PROTO_WRAP_UTIL_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_CEL_PROTO_WRAP_UTIL_H_ + +#include "eval/public/cel_value.h" +#include "eval/public/structs/protobuf_value_factory.h" +#include "google/protobuf/message.h" + +namespace google::api::expr::runtime::internal { + +// UnwrapValue creates CelValue from google::protobuf::Message. +// As some of CEL basic types are subclassing google::protobuf::Message, +// this method contains type checking and downcasts. +CelValue UnwrapMessageToValue(const google::protobuf::Message* value, + const ProtobufValueFactory& factory, + google::protobuf::Arena* arena); + +// MaybeWrapValue attempts to wrap the input value in a proto message with +// the given type_name. If the value can be wrapped, it is returned as a +// protobuf message. Otherwise, the result will be nullptr. +// +// This method is the complement to MaybeUnwrapValue which may unwrap a protobuf +// message to native CelValue representation during a protobuf field read. +// Just as CreateMessage should only be used when reading protobuf values, +// MaybeWrapValue should only be used when assigning protobuf fields. +const google::protobuf::Message* MaybeWrapValueToMessage( + const google::protobuf::Descriptor* descriptor, google::protobuf::MessageFactory* factory, + const CelValue& value, google::protobuf::Arena* arena); + +} // namespace google::api::expr::runtime::internal + +#endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_CEL_PROTO_WRAP_UTIL_H_ diff --git a/eval/public/structs/cel_proto_wrap_util_test.cc b/eval/public/structs/cel_proto_wrap_util_test.cc new file mode 100644 index 000000000..59597fe8f --- /dev/null +++ b/eval/public/structs/cel_proto_wrap_util_test.cc @@ -0,0 +1,921 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/structs/cel_proto_wrap_util.h" + +#include +#include +#include +#include +#include +#include + +#include "google/protobuf/any.pb.h" +#include "google/protobuf/duration.pb.h" +#include "google/protobuf/empty.pb.h" +#include "google/protobuf/struct.pb.h" +#include "google/protobuf/wrappers.pb.h" +#include "absl/base/no_destructor.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/time/time.h" +#include "eval/public/cel_value.h" +#include "eval/public/containers/container_backed_list_impl.h" +#include "eval/public/containers/container_backed_map_impl.h" +#include "eval/public/message_wrapper.h" +#include "eval/public/structs/protobuf_value_factory.h" +#include "eval/public/structs/trivial_legacy_type_info.h" +#include "eval/testutil/test_message.pb.h" +#include "internal/proto_time_encoding.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "testutil/util.h" +#include "google/protobuf/dynamic_message.h" +#include "google/protobuf/message.h" + +namespace google::api::expr::runtime::internal { + +namespace { + +using ::testing::Eq; +using ::testing::UnorderedPointwise; + +using google::protobuf::Duration; +using google::protobuf::ListValue; +using google::protobuf::Struct; +using google::protobuf::Timestamp; +using google::protobuf::Value; + +using google::protobuf::Any; +using google::protobuf::BoolValue; +using google::protobuf::BytesValue; +using google::protobuf::DoubleValue; +using google::protobuf::FloatValue; +using google::protobuf::Int32Value; +using google::protobuf::Int64Value; +using google::protobuf::StringValue; +using google::protobuf::UInt32Value; +using google::protobuf::UInt64Value; + +using google::protobuf::Arena; + +CelValue ProtobufValueFactoryImpl(const google::protobuf::Message* m) { + return CelValue::CreateMessageWrapper( + CelValue::MessageWrapper(m, TrivialTypeInfo::GetInstance())); +} + +class CelProtoWrapperTest : public ::testing::Test { + protected: + CelProtoWrapperTest() {} + + void ExpectWrappedMessage(const CelValue& value, + const google::protobuf::Message& message) { + // Test the input value wraps to the destination message type. + auto* result = MaybeWrapValueToMessage( + message.GetDescriptor(), message.GetReflection()->GetMessageFactory(), + value, arena()); + EXPECT_TRUE(result != nullptr); + EXPECT_THAT(result, testutil::EqualsProto(message)); + + // Ensure that double wrapping results in the object being wrapped once. + auto* identity = MaybeWrapValueToMessage( + message.GetDescriptor(), message.GetReflection()->GetMessageFactory(), + ProtobufValueFactoryImpl(result), arena()); + EXPECT_TRUE(identity == nullptr); + + // Check to make sure that even dynamic messages can be used as input to + // the wrapping call. + result = MaybeWrapValueToMessage( + ReflectedCopy(message)->GetDescriptor(), + ReflectedCopy(message)->GetReflection()->GetMessageFactory(), value, + arena()); + EXPECT_TRUE(result != nullptr); + EXPECT_THAT(result, testutil::EqualsProto(message)); + } + + void ExpectNotWrapped(const CelValue& value, const google::protobuf::Message& message) { + // Test the input value does not wrap by asserting value == result. + auto result = MaybeWrapValueToMessage( + message.GetDescriptor(), message.GetReflection()->GetMessageFactory(), + value, arena()); + EXPECT_TRUE(result == nullptr); + } + + template + void ExpectUnwrappedPrimitive(const google::protobuf::Message& message, T result) { + CelValue cel_value = + UnwrapMessageToValue(&message, &ProtobufValueFactoryImpl, arena()); + T value; + EXPECT_TRUE(cel_value.GetValue(&value)); + EXPECT_THAT(value, Eq(result)); + + T dyn_value; + CelValue cel_dyn_value = UnwrapMessageToValue( + ReflectedCopy(message).get(), &ProtobufValueFactoryImpl, arena()); + EXPECT_THAT(cel_dyn_value.type(), Eq(cel_value.type())); + EXPECT_TRUE(cel_dyn_value.GetValue(&dyn_value)); + EXPECT_THAT(value, Eq(dyn_value)); + } + + void ExpectUnwrappedMessage(const google::protobuf::Message& message, + google::protobuf::Message* result) { + CelValue cel_value = + UnwrapMessageToValue(&message, &ProtobufValueFactoryImpl, arena()); + if (result == nullptr) { + EXPECT_TRUE(cel_value.IsNull()); + return; + } + EXPECT_TRUE(cel_value.IsMessage()); + EXPECT_THAT(cel_value.MessageOrDie(), testutil::EqualsProto(*result)); + } + + std::unique_ptr ReflectedCopy( + const google::protobuf::Message& message) { + std::unique_ptr dynamic_value( + factory_.GetPrototype(message.GetDescriptor())->New()); + dynamic_value->CopyFrom(message); + return dynamic_value; + } + + Arena* arena() { return &arena_; } + + private: + Arena arena_; + google::protobuf::DynamicMessageFactory factory_; +}; + +TEST_F(CelProtoWrapperTest, TestType) { + Duration msg_duration; + msg_duration.set_seconds(2); + msg_duration.set_nanos(3); + + CelValue value_duration2 = + UnwrapMessageToValue(&msg_duration, &ProtobufValueFactoryImpl, arena()); + EXPECT_THAT(value_duration2.type(), Eq(CelValue::Type::kDuration)); + + Timestamp msg_timestamp; + msg_timestamp.set_seconds(2); + msg_timestamp.set_nanos(3); + + CelValue value_timestamp2 = + UnwrapMessageToValue(&msg_timestamp, &ProtobufValueFactoryImpl, arena()); + EXPECT_THAT(value_timestamp2.type(), Eq(CelValue::Type::kTimestamp)); +} + +// This test verifies CelValue support of Duration type. +TEST_F(CelProtoWrapperTest, TestDuration) { + Duration msg_duration; + msg_duration.set_seconds(2); + msg_duration.set_nanos(3); + CelValue value = + UnwrapMessageToValue(&msg_duration, &ProtobufValueFactoryImpl, arena()); + EXPECT_THAT(value.type(), Eq(CelValue::Type::kDuration)); + + Duration out; + auto status = cel::internal::EncodeDuration(value.DurationOrDie(), &out); + EXPECT_TRUE(status.ok()); + EXPECT_THAT(out, testutil::EqualsProto(msg_duration)); +} + +// This test verifies CelValue support of Timestamp type. +TEST_F(CelProtoWrapperTest, TestTimestamp) { + Timestamp msg_timestamp; + msg_timestamp.set_seconds(2); + msg_timestamp.set_nanos(3); + + CelValue value = + UnwrapMessageToValue(&msg_timestamp, &ProtobufValueFactoryImpl, arena()); + + EXPECT_TRUE(value.IsTimestamp()); + Timestamp out; + auto status = cel::internal::EncodeTime(value.TimestampOrDie(), &out); + EXPECT_TRUE(status.ok()); + EXPECT_THAT(out, testutil::EqualsProto(msg_timestamp)); +} + +// Dynamic Values test +// +TEST_F(CelProtoWrapperTest, UnwrapMessageToValueNull) { + Value json; + json.set_null_value(google::protobuf::NullValue::NULL_VALUE); + ExpectUnwrappedMessage(json, nullptr); +} + +// Test support for unwrapping a google::protobuf::Value to a CEL value. +TEST_F(CelProtoWrapperTest, UnwrapDynamicValueNull) { + Value value_msg; + value_msg.set_null_value(protobuf::NULL_VALUE); + + CelValue value = UnwrapMessageToValue(ReflectedCopy(value_msg).get(), + &ProtobufValueFactoryImpl, arena()); + EXPECT_TRUE(value.IsNull()); +} + +TEST_F(CelProtoWrapperTest, UnwrapMessageToValueBool) { + bool value = true; + + Value json; + json.set_bool_value(true); + ExpectUnwrappedPrimitive(json, value); +} + +TEST_F(CelProtoWrapperTest, UnwrapMessageToValueNumber) { + double value = 1.0; + + Value json; + json.set_number_value(value); + ExpectUnwrappedPrimitive(json, value); +} + +TEST_F(CelProtoWrapperTest, UnwrapMessageToValueString) { + const std::string test = "test"; + auto value = CelValue::StringHolder(&test); + + Value json; + json.set_string_value(test); + ExpectUnwrappedPrimitive(json, value); +} + +TEST_F(CelProtoWrapperTest, UnwrapMessageToValueStruct) { + const std::vector kFields = {"field1", "field2", "field3"}; + Struct value_struct; + + auto& value1 = (*value_struct.mutable_fields())[kFields[0]]; + value1.set_bool_value(true); + + auto& value2 = (*value_struct.mutable_fields())[kFields[1]]; + value2.set_number_value(1.0); + + auto& value3 = (*value_struct.mutable_fields())[kFields[2]]; + value3.set_string_value("test"); + + CelValue value = + UnwrapMessageToValue(&value_struct, &ProtobufValueFactoryImpl, arena()); + ASSERT_TRUE(value.IsMap()); + + const CelMap* cel_map = value.MapOrDie(); + + CelValue field1 = CelValue::CreateString(&kFields[0]); + auto field1_presence = cel_map->Has(field1); + ASSERT_OK(field1_presence); + EXPECT_TRUE(*field1_presence); + auto lookup1 = (*cel_map)[field1]; + ASSERT_TRUE(lookup1.has_value()); + ASSERT_TRUE(lookup1->IsBool()); + EXPECT_EQ(lookup1->BoolOrDie(), true); + + CelValue field2 = CelValue::CreateString(&kFields[1]); + auto field2_presence = cel_map->Has(field2); + ASSERT_OK(field2_presence); + EXPECT_TRUE(*field2_presence); + auto lookup2 = (*cel_map)[field2]; + ASSERT_TRUE(lookup2.has_value()); + ASSERT_TRUE(lookup2->IsDouble()); + EXPECT_DOUBLE_EQ(lookup2->DoubleOrDie(), 1.0); + + CelValue field3 = CelValue::CreateString(&kFields[2]); + auto field3_presence = cel_map->Has(field3); + ASSERT_OK(field3_presence); + EXPECT_TRUE(*field3_presence); + auto lookup3 = (*cel_map)[field3]; + ASSERT_TRUE(lookup3.has_value()); + ASSERT_TRUE(lookup3->IsString()); + EXPECT_EQ(lookup3->StringOrDie().value(), "test"); + + std::string missing = "missing_field"; + CelValue missing_field = CelValue::CreateString(&missing); + auto missing_field_presence = cel_map->Has(missing_field); + ASSERT_OK(missing_field_presence); + EXPECT_FALSE(*missing_field_presence); + + const CelList* key_list = cel_map->ListKeys().value(); + ASSERT_EQ(key_list->size(), kFields.size()); + + std::vector result_keys; + for (int i = 0; i < key_list->size(); i++) { + CelValue key = (*key_list)[i]; + ASSERT_TRUE(key.IsString()); + result_keys.push_back(std::string(key.StringOrDie().value())); + } + + EXPECT_THAT(result_keys, UnorderedPointwise(Eq(), kFields)); +} + +// Test support for google::protobuf::Struct when it is created as dynamic +// message +TEST_F(CelProtoWrapperTest, UnwrapDynamicStruct) { + Struct struct_msg; + const std::string kFieldInt = "field_int"; + const std::string kFieldBool = "field_bool"; + (*struct_msg.mutable_fields())[kFieldInt].set_number_value(1.); + (*struct_msg.mutable_fields())[kFieldBool].set_bool_value(true); + CelValue value = UnwrapMessageToValue(ReflectedCopy(struct_msg).get(), + &ProtobufValueFactoryImpl, arena()); + EXPECT_TRUE(value.IsMap()); + const CelMap* cel_map = value.MapOrDie(); + ASSERT_TRUE(cel_map != nullptr); + + { + auto lookup = (*cel_map)[CelValue::CreateString(&kFieldInt)]; + ASSERT_TRUE(lookup.has_value()); + auto v = lookup.value(); + ASSERT_TRUE(v.IsDouble()); + EXPECT_THAT(v.DoubleOrDie(), testing::DoubleEq(1.)); + } + { + auto lookup = (*cel_map)[CelValue::CreateString(&kFieldBool)]; + ASSERT_TRUE(lookup.has_value()); + auto v = lookup.value(); + ASSERT_TRUE(v.IsBool()); + EXPECT_EQ(v.BoolOrDie(), true); + } + { + auto presence = cel_map->Has(CelValue::CreateBool(true)); + ASSERT_FALSE(presence.ok()); + EXPECT_EQ(presence.status().code(), absl::StatusCode::kInvalidArgument); + auto lookup = (*cel_map)[CelValue::CreateBool(true)]; + ASSERT_TRUE(lookup.has_value()); + auto v = lookup.value(); + ASSERT_TRUE(v.IsError()); + } +} + +TEST_F(CelProtoWrapperTest, UnwrapDynamicValueStruct) { + const std::string kField1 = "field1"; + const std::string kField2 = "field2"; + Value value_msg; + (*value_msg.mutable_struct_value()->mutable_fields())[kField1] + .set_number_value(1); + (*value_msg.mutable_struct_value()->mutable_fields())[kField2] + .set_number_value(2); + + CelValue value = UnwrapMessageToValue(ReflectedCopy(value_msg).get(), + &ProtobufValueFactoryImpl, arena()); + EXPECT_TRUE(value.IsMap()); + EXPECT_TRUE( + (*value.MapOrDie())[CelValue::CreateString(&kField1)].has_value()); + EXPECT_TRUE( + (*value.MapOrDie())[CelValue::CreateString(&kField2)].has_value()); +} + +TEST_F(CelProtoWrapperTest, UnwrapMessageToValueList) { + const std::vector kFields = {"field1", "field2", "field3"}; + + ListValue list_value; + + list_value.add_values()->set_bool_value(true); + list_value.add_values()->set_number_value(1.0); + list_value.add_values()->set_string_value("test"); + + CelValue value = + UnwrapMessageToValue(&list_value, &ProtobufValueFactoryImpl, arena()); + ASSERT_TRUE(value.IsList()); + + const CelList* cel_list = value.ListOrDie(); + + ASSERT_EQ(cel_list->size(), 3); + + CelValue value1 = (*cel_list)[0]; + ASSERT_TRUE(value1.IsBool()); + EXPECT_EQ(value1.BoolOrDie(), true); + + auto value2 = (*cel_list)[1]; + ASSERT_TRUE(value2.IsDouble()); + EXPECT_DOUBLE_EQ(value2.DoubleOrDie(), 1.0); + + auto value3 = (*cel_list)[2]; + ASSERT_TRUE(value3.IsString()); + EXPECT_EQ(value3.StringOrDie().value(), "test"); +} + +TEST_F(CelProtoWrapperTest, UnwrapDynamicValueListValue) { + Value value_msg; + value_msg.mutable_list_value()->add_values()->set_number_value(1.); + value_msg.mutable_list_value()->add_values()->set_number_value(2.); + + CelValue value = UnwrapMessageToValue(ReflectedCopy(value_msg).get(), + &ProtobufValueFactoryImpl, arena()); + EXPECT_TRUE(value.IsList()); + EXPECT_THAT((*value.ListOrDie())[0].DoubleOrDie(), testing::DoubleEq(1)); + EXPECT_THAT((*value.ListOrDie())[1].DoubleOrDie(), testing::DoubleEq(2)); +} + +// Test support of google.protobuf.Any in CelValue. +TEST_F(CelProtoWrapperTest, UnwrapAnyValue) { + TestMessage test_message; + test_message.set_string_value("test"); + + Any any; + any.PackFrom(test_message); + ExpectUnwrappedMessage(any, &test_message); +} + +TEST_F(CelProtoWrapperTest, UnwrapInvalidAny) { + Any any; + CelValue value = + UnwrapMessageToValue(&any, &ProtobufValueFactoryImpl, arena()); + ASSERT_TRUE(value.IsError()); + + any.set_type_url("/"); + ASSERT_TRUE( + UnwrapMessageToValue(&any, &ProtobufValueFactoryImpl, arena()).IsError()); + + any.set_type_url("/invalid.proto.name"); + ASSERT_TRUE( + UnwrapMessageToValue(&any, &ProtobufValueFactoryImpl, arena()).IsError()); +} + +// Test support of google.protobuf.Value wrappers in CelValue. +TEST_F(CelProtoWrapperTest, UnwrapBoolWrapper) { + bool value = true; + + BoolValue wrapper; + wrapper.set_value(value); + ExpectUnwrappedPrimitive(wrapper, value); +} + +TEST_F(CelProtoWrapperTest, UnwrapInt32Wrapper) { + int64_t value = 12; + + Int32Value wrapper; + wrapper.set_value(value); + ExpectUnwrappedPrimitive(wrapper, value); +} + +TEST_F(CelProtoWrapperTest, UnwrapUInt32Wrapper) { + uint64_t value = 12; + + UInt32Value wrapper; + wrapper.set_value(value); + ExpectUnwrappedPrimitive(wrapper, value); +} + +TEST_F(CelProtoWrapperTest, UnwrapInt64Wrapper) { + int64_t value = 12; + + Int64Value wrapper; + wrapper.set_value(value); + ExpectUnwrappedPrimitive(wrapper, value); +} + +TEST_F(CelProtoWrapperTest, UnwrapUInt64Wrapper) { + uint64_t value = 12; + + UInt64Value wrapper; + wrapper.set_value(value); + ExpectUnwrappedPrimitive(wrapper, value); +} + +TEST_F(CelProtoWrapperTest, UnwrapFloatWrapper) { + double value = 42.5; + + FloatValue wrapper; + wrapper.set_value(value); + ExpectUnwrappedPrimitive(wrapper, value); +} + +TEST_F(CelProtoWrapperTest, UnwrapDoubleWrapper) { + double value = 42.5; + + DoubleValue wrapper; + wrapper.set_value(value); + ExpectUnwrappedPrimitive(wrapper, value); +} + +TEST_F(CelProtoWrapperTest, UnwrapStringWrapper) { + std::string text = "42"; + auto value = CelValue::StringHolder(&text); + + StringValue wrapper; + wrapper.set_value(text); + ExpectUnwrappedPrimitive(wrapper, value); +} + +TEST_F(CelProtoWrapperTest, UnwrapBytesWrapper) { + std::string text = "42"; + auto value = CelValue::BytesHolder(&text); + + BytesValue wrapper; + wrapper.set_value("42"); + ExpectUnwrappedPrimitive(wrapper, value); +} + +TEST_F(CelProtoWrapperTest, WrapNull) { + auto cel_value = CelValue::CreateNull(); + + Value json; + json.set_null_value(protobuf::NULL_VALUE); + ExpectWrappedMessage(cel_value, json); + + Any any; + any.PackFrom(json); + ExpectWrappedMessage(cel_value, any); +} + +TEST_F(CelProtoWrapperTest, WrapBool) { + auto cel_value = CelValue::CreateBool(true); + + Value json; + json.set_bool_value(true); + ExpectWrappedMessage(cel_value, json); + + BoolValue wrapper; + wrapper.set_value(true); + ExpectWrappedMessage(cel_value, wrapper); + + Any any; + any.PackFrom(wrapper); + ExpectWrappedMessage(cel_value, any); +} + +TEST_F(CelProtoWrapperTest, WrapBytes) { + std::string str = "hello world"; + auto cel_value = CelValue::CreateBytes(CelValue::BytesHolder(&str)); + + BytesValue wrapper; + wrapper.set_value(str); + ExpectWrappedMessage(cel_value, wrapper); + + Any any; + any.PackFrom(wrapper); + ExpectWrappedMessage(cel_value, any); +} + +TEST_F(CelProtoWrapperTest, WrapBytesToValue) { + std::string str = "hello world"; + auto cel_value = CelValue::CreateBytes(CelValue::BytesHolder(&str)); + + Value json; + json.set_string_value("aGVsbG8gd29ybGQ="); + ExpectWrappedMessage(cel_value, json); +} + +TEST_F(CelProtoWrapperTest, WrapDuration) { + auto cel_value = CelValue::CreateDuration(absl::Seconds(300)); + + Duration d; + d.set_seconds(300); + ExpectWrappedMessage(cel_value, d); + + Any any; + any.PackFrom(d); + ExpectWrappedMessage(cel_value, any); +} + +TEST_F(CelProtoWrapperTest, WrapDurationToValue) { + auto cel_value = CelValue::CreateDuration(absl::Seconds(300)); + + Value json; + json.set_string_value("300s"); + ExpectWrappedMessage(cel_value, json); +} + +TEST_F(CelProtoWrapperTest, WrapDouble) { + double num = 1.5; + auto cel_value = CelValue::CreateDouble(num); + + Value json; + json.set_number_value(num); + ExpectWrappedMessage(cel_value, json); + + DoubleValue wrapper; + wrapper.set_value(num); + ExpectWrappedMessage(cel_value, wrapper); + + Any any; + any.PackFrom(wrapper); + ExpectWrappedMessage(cel_value, any); +} + +TEST_F(CelProtoWrapperTest, WrapDoubleToFloatValue) { + double num = 1.5; + auto cel_value = CelValue::CreateDouble(num); + + FloatValue wrapper; + wrapper.set_value(num); + ExpectWrappedMessage(cel_value, wrapper); + + // Imprecise double -> float representation results in truncation. + double small_num = -9.9e-100; + wrapper.set_value(small_num); + cel_value = CelValue::CreateDouble(small_num); + ExpectWrappedMessage(cel_value, wrapper); +} + +TEST_F(CelProtoWrapperTest, WrapDoubleOverflow) { + double lowest_double = std::numeric_limits::lowest(); + auto cel_value = CelValue::CreateDouble(lowest_double); + + // Double exceeds float precision, overflow to -infinity. + FloatValue wrapper; + wrapper.set_value(-std::numeric_limits::infinity()); + ExpectWrappedMessage(cel_value, wrapper); + + double max_double = std::numeric_limits::max(); + cel_value = CelValue::CreateDouble(max_double); + + wrapper.set_value(std::numeric_limits::infinity()); + ExpectWrappedMessage(cel_value, wrapper); +} + +TEST_F(CelProtoWrapperTest, WrapInt64) { + int32_t num = std::numeric_limits::lowest(); + auto cel_value = CelValue::CreateInt64(num); + + Value json; + json.set_number_value(static_cast(num)); + ExpectWrappedMessage(cel_value, json); + + Int64Value wrapper; + wrapper.set_value(num); + ExpectWrappedMessage(cel_value, wrapper); + + Any any; + any.PackFrom(wrapper); + ExpectWrappedMessage(cel_value, any); +} + +TEST_F(CelProtoWrapperTest, WrapInt64ToInt32Value) { + int32_t num = std::numeric_limits::lowest(); + auto cel_value = CelValue::CreateInt64(num); + + Int32Value wrapper; + wrapper.set_value(num); + ExpectWrappedMessage(cel_value, wrapper); +} + +TEST_F(CelProtoWrapperTest, WrapFailureInt64ToInt32Value) { + int64_t num = std::numeric_limits::lowest(); + auto cel_value = CelValue::CreateInt64(num); + + Int32Value wrapper; + ExpectNotWrapped(cel_value, wrapper); +} + +TEST_F(CelProtoWrapperTest, WrapInt64ToValue) { + int64_t max = std::numeric_limits::max(); + auto cel_value = CelValue::CreateInt64(max); + + Value json; + json.set_string_value(absl::StrCat(max)); + ExpectWrappedMessage(cel_value, json); + + int64_t min = std::numeric_limits::min(); + cel_value = CelValue::CreateInt64(min); + + json.set_string_value(absl::StrCat(min)); + ExpectWrappedMessage(cel_value, json); +} + +TEST_F(CelProtoWrapperTest, WrapUint64) { + uint32_t num = std::numeric_limits::max(); + auto cel_value = CelValue::CreateUint64(num); + + Value json; + json.set_number_value(static_cast(num)); + ExpectWrappedMessage(cel_value, json); + + UInt64Value wrapper; + wrapper.set_value(num); + ExpectWrappedMessage(cel_value, wrapper); + + Any any; + any.PackFrom(wrapper); + ExpectWrappedMessage(cel_value, any); +} + +TEST_F(CelProtoWrapperTest, WrapUint64ToUint32Value) { + uint32_t num = std::numeric_limits::max(); + auto cel_value = CelValue::CreateUint64(num); + + UInt32Value wrapper; + wrapper.set_value(num); + ExpectWrappedMessage(cel_value, wrapper); +} + +TEST_F(CelProtoWrapperTest, WrapUint64ToValue) { + uint64_t num = std::numeric_limits::max(); + auto cel_value = CelValue::CreateUint64(num); + + Value json; + json.set_string_value(absl::StrCat(num)); + ExpectWrappedMessage(cel_value, json); +} + +TEST_F(CelProtoWrapperTest, WrapFailureUint64ToUint32Value) { + uint64_t num = std::numeric_limits::max(); + auto cel_value = CelValue::CreateUint64(num); + + UInt32Value wrapper; + ExpectNotWrapped(cel_value, wrapper); +} + +TEST_F(CelProtoWrapperTest, WrapString) { + std::string str = "test"; + auto cel_value = CelValue::CreateString(CelValue::StringHolder(&str)); + + Value json; + json.set_string_value(str); + ExpectWrappedMessage(cel_value, json); + + StringValue wrapper; + wrapper.set_value(str); + ExpectWrappedMessage(cel_value, wrapper); + + Any any; + any.PackFrom(wrapper); + ExpectWrappedMessage(cel_value, any); +} + +TEST_F(CelProtoWrapperTest, WrapTimestamp) { + absl::Time ts = absl::FromUnixSeconds(1615852799); + auto cel_value = CelValue::CreateTimestamp(ts); + + Timestamp t; + t.set_seconds(1615852799); + ExpectWrappedMessage(cel_value, t); + + Any any; + any.PackFrom(t); + ExpectWrappedMessage(cel_value, any); +} + +TEST_F(CelProtoWrapperTest, WrapTimestampToValue) { + absl::Time ts = absl::FromUnixSeconds(1615852799); + auto cel_value = CelValue::CreateTimestamp(ts); + + Value json; + json.set_string_value("2021-03-15T23:59:59Z"); + ExpectWrappedMessage(cel_value, json); +} + +TEST_F(CelProtoWrapperTest, WrapList) { + std::vector list_elems = { + CelValue::CreateDouble(1.5), + CelValue::CreateInt64(-2L), + }; + ContainerBackedListImpl list(std::move(list_elems)); + auto cel_value = CelValue::CreateList(&list); + + Value json; + json.mutable_list_value()->add_values()->set_number_value(1.5); + json.mutable_list_value()->add_values()->set_number_value(-2.); + ExpectWrappedMessage(cel_value, json); + ExpectWrappedMessage(cel_value, json.list_value()); + + Any any; + any.PackFrom(json.list_value()); + ExpectWrappedMessage(cel_value, any); +} + +TEST_F(CelProtoWrapperTest, WrapFailureListValueBadJSON) { + TestMessage message; + std::vector list_elems = { + CelValue::CreateDouble(1.5), + UnwrapMessageToValue(&message, &ProtobufValueFactoryImpl, arena()), + }; + ContainerBackedListImpl list(std::move(list_elems)); + auto cel_value = CelValue::CreateList(&list); + + Value json; + ExpectNotWrapped(cel_value, json); +} + +TEST_F(CelProtoWrapperTest, WrapStruct) { + const std::string kField1 = "field1"; + std::vector> args = { + {CelValue::CreateString(CelValue::StringHolder(&kField1)), + CelValue::CreateBool(true)}}; + auto cel_map = + CreateContainerBackedMap( + absl::Span>(args.data(), args.size())) + .value(); + auto cel_value = CelValue::CreateMap(cel_map.get()); + + Value json; + (*json.mutable_struct_value()->mutable_fields())[kField1].set_bool_value( + true); + ExpectWrappedMessage(cel_value, json); + ExpectWrappedMessage(cel_value, json.struct_value()); + + Any any; + any.PackFrom(json.struct_value()); + ExpectWrappedMessage(cel_value, any); +} + +TEST_F(CelProtoWrapperTest, WrapFailureStructBadKeyType) { + std::vector> args = { + {CelValue::CreateInt64(1L), CelValue::CreateBool(true)}}; + auto cel_map = + CreateContainerBackedMap( + absl::Span>(args.data(), args.size())) + .value(); + auto cel_value = CelValue::CreateMap(cel_map.get()); + + Value json; + ExpectNotWrapped(cel_value, json); +} + +TEST_F(CelProtoWrapperTest, WrapFailureStructBadValueType) { + const std::string kField1 = "field1"; + TestMessage bad_value; + std::vector> args = { + {CelValue::CreateString(CelValue::StringHolder(&kField1)), + UnwrapMessageToValue(&bad_value, &ProtobufValueFactoryImpl, arena())}}; + auto cel_map = + CreateContainerBackedMap( + absl::Span>(args.data(), args.size())) + .value(); + auto cel_value = CelValue::CreateMap(cel_map.get()); + Value json; + ExpectNotWrapped(cel_value, json); +} + +class TestMap : public CelMapBuilder { + public: + absl::StatusOr ListKeys() const override { + return absl::UnimplementedError("test"); + } +}; + +TEST_F(CelProtoWrapperTest, WrapFailureStructListKeysUnimplemented) { + const std::string kField1 = "field1"; + TestMap map; + ASSERT_OK(map.Add(CelValue::CreateString(CelValue::StringHolder(&kField1)), + CelValue::CreateString(CelValue::StringHolder(&kField1)))); + + auto cel_value = CelValue::CreateMap(&map); + Value json; + ExpectNotWrapped(cel_value, json); +} + +TEST_F(CelProtoWrapperTest, WrapFailureWrongType) { + auto cel_value = CelValue::CreateNull(); + std::vector wrong_types = { + &BoolValue::default_instance(), &BytesValue::default_instance(), + &DoubleValue::default_instance(), &Duration::default_instance(), + &FloatValue::default_instance(), &Int32Value::default_instance(), + &Int64Value::default_instance(), &ListValue::default_instance(), + &StringValue::default_instance(), &Struct::default_instance(), + &Timestamp::default_instance(), &UInt32Value::default_instance(), + &UInt64Value::default_instance(), + }; + for (const auto* wrong_type : wrong_types) { + ExpectNotWrapped(cel_value, *wrong_type); + } +} + +TEST_F(CelProtoWrapperTest, WrapFailureErrorToAny) { + auto cel_value = CreateNoSuchFieldError(arena(), "error_field"); + ExpectNotWrapped(cel_value, Any::default_instance()); +} + +TEST_F(CelProtoWrapperTest, DebugString) { + google::protobuf::Empty e; + // Note: the value factory is trivial so the debug string for a message-typed + // value is uninteresting. + EXPECT_EQ(UnwrapMessageToValue(&e, &ProtobufValueFactoryImpl, arena()) + .DebugString(), + "Message: opaque"); + + ListValue list_value; + list_value.add_values()->set_bool_value(true); + list_value.add_values()->set_number_value(1.0); + list_value.add_values()->set_string_value("test"); + CelValue value = + UnwrapMessageToValue(&list_value, &ProtobufValueFactoryImpl, arena()); + EXPECT_EQ(value.DebugString(), + "CelList: [bool: 1, double: 1.000000, string: test]"); + + Struct value_struct; + auto& value1 = (*value_struct.mutable_fields())["a"]; + value1.set_bool_value(true); + auto& value2 = (*value_struct.mutable_fields())["b"]; + value2.set_number_value(1.0); + auto& value3 = (*value_struct.mutable_fields())["c"]; + value3.set_string_value("test"); + + value = + UnwrapMessageToValue(&value_struct, &ProtobufValueFactoryImpl, arena()); + EXPECT_THAT( + value.DebugString(), + testing::AllOf(testing::StartsWith("CelMap: {"), + testing::HasSubstr(": "), + testing::HasSubstr(": : "))); +} + +} // namespace + +} // namespace google::api::expr::runtime::internal diff --git a/eval/public/structs/cel_proto_wrapper.cc b/eval/public/structs/cel_proto_wrapper.cc index 78c29f463..6fad6aee3 100644 --- a/eval/public/structs/cel_proto_wrapper.cc +++ b/eval/public/structs/cel_proto_wrapper.cc @@ -1,350 +1,60 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #include "eval/public/structs/cel_proto_wrapper.h" -#include "google/protobuf/any.pb.h" -#include "google/protobuf/struct.pb.h" -#include "google/protobuf/wrappers.pb.h" -#include "absl/container/node_hash_map.h" -#include "absl/strings/substitute.h" -#include "absl/synchronization/mutex.h" +#include "absl/types/optional.h" +#include "eval/public/cel_value.h" +#include "eval/public/message_wrapper.h" +#include "eval/public/structs/cel_proto_wrap_util.h" +#include "eval/public/structs/proto_message_type_adapter.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace google::api::expr::runtime { namespace { -using google::protobuf::Arena; -using google::protobuf::Descriptor; -using google::protobuf::DescriptorPool; -using google::protobuf::Message; - -using google::protobuf::Any; -using google::protobuf::BoolValue; -using google::protobuf::BytesValue; -using google::protobuf::DoubleValue; -using google::protobuf::Duration; -using google::protobuf::FloatValue; -using google::protobuf::Int32Value; -using google::protobuf::Int64Value; -using google::protobuf::ListValue; -using google::protobuf::StringValue; -using google::protobuf::Struct; -using google::protobuf::Timestamp; -using google::protobuf::UInt32Value; -using google::protobuf::UInt64Value; -using google::protobuf::Value; - -// Forward declaration for google.protobuf.Value -CelValue ValueFromMessage(const Value* value, Arena* arena); - -// Map implementation wrapping google.protobuf.ListValue -class DynamicList : public CelList { - public: - DynamicList(const ListValue* values, Arena* arena) - : arena_(arena), values_(values) {} - - CelValue operator[](int index) const override { - return ValueFromMessage(&values_->values(index), arena_); - } - - // List size - int size() const override { return values_->values_size(); } - - private: - Arena* arena_; - const ListValue* values_; -}; - -// Map implementation wrapping google.protobuf.Struct. -class DynamicMap : public CelMap { - public: - DynamicMap(const Struct* values, Arena* arena) - : arena_(arena), values_(values), key_list_(values) {} - - absl::optional operator[](CelValue key) const override { - CelValue::StringHolder str_key; - if (!key.GetValue(&str_key)) { - return {}; // Not a string key - } - - auto it = values_->fields().find(std::string(str_key.value())); - if (it == values_->fields().end()) { - return {}; - } - - return ValueFromMessage(&it->second, arena_); - } - - int size() const override { return values_->fields_size(); } - - const CelList* ListKeys() const override { return &key_list_; } - - private: - // List of keys in Struct.fields map. - // It utilizes lazy initialization, to avoid performance penalties. - class DynamicMapKeyList : public CelList { - public: - explicit DynamicMapKeyList(const Struct* values) - : values_(values), keys_(), initialized_(false) {} - - // Index access - CelValue operator[](int index) const override { - CheckInit(); - return keys_[index]; - } - - // List size - int size() const override { - CheckInit(); - return values_->fields_size(); - } - - private: - void CheckInit() const { - absl::MutexLock lock(&mutex_); - if (!initialized_) { - for (const auto& it : values_->fields()) { - keys_.push_back(CelValue::CreateString(&it.first)); - } - initialized_ = true; - } - } - - const Struct* values_; - mutable absl::Mutex mutex_; - mutable std::vector keys_; - mutable bool initialized_; - }; - - Arena* arena_; - const Struct* values_; - const DynamicMapKeyList key_list_; -}; - -// ValueFromMessage(....) function family. -// Functions of this family create CelValue object from specific subtypes of -// protobuf message. -CelValue ValueFromMessage(const Duration* duration, Arena*) { - return CelProtoWrapper::CreateDuration(duration); -} - -CelValue ValueFromMessage(const Timestamp* timestamp, Arena*) { - return CelProtoWrapper::CreateTimestamp(timestamp); -} - -CelValue ValueFromMessage(const ListValue* list_values, Arena* arena) { - return CelValue::CreateList( - Arena::Create(arena, list_values, arena)); -} - -CelValue ValueFromMessage(const Struct* struct_value, Arena* arena) { - return CelValue::CreateMap( - Arena::Create(arena, struct_value, arena)); -} - -CelValue ValueFromMessage(const Any* any_value, Arena* arena) { - auto type_url = any_value->type_url(); - - auto pos = type_url.find_last_of("/"); - if (pos == absl::string_view::npos) { - // TODO(issues/25) What error code? - // Malformed type_url - return CreateErrorValue(arena, "Malformed type_url string"); - } - - std::string full_name = std::string(type_url.substr(pos + 1)); - const Descriptor* nested_descriptor = - DescriptorPool::generated_pool()->FindMessageTypeByName(full_name); - - if (nested_descriptor == nullptr) { - // Descriptor not found for the type - // TODO(issues/25) What error code? - return CreateErrorValue(arena, "Descriptor not found"); - } - - const Message* prototype = - google::protobuf::MessageFactory::generated_factory()->GetPrototype( - nested_descriptor); - if (prototype == nullptr) { - // Failed to obtain prototype for the descriptor - // TODO(issues/25) What error code? - return CreateErrorValue(arena, "Prototype not found"); - } - - Message* nested_message = prototype->New(arena); - if (!any_value->UnpackTo(nested_message)) { - // Failed to unpack. - // TODO(issues/25) What error code? - return CreateErrorValue(arena, "Failed to unpack Any into message"); - } - - return CelProtoWrapper::CreateMessage(nested_message, arena); -} - -CelValue ValueFromMessage(const BoolValue* wrapper, Arena*) { - return CelValue::CreateBool(wrapper->value()); -} - -CelValue ValueFromMessage(const Int32Value* wrapper, Arena*) { - return CelValue::CreateInt64(wrapper->value()); -} - -CelValue ValueFromMessage(const UInt32Value* wrapper, Arena*) { - return CelValue::CreateUint64(wrapper->value()); -} - -CelValue ValueFromMessage(const Int64Value* wrapper, Arena*) { - return CelValue::CreateInt64(wrapper->value()); -} - -CelValue ValueFromMessage(const UInt64Value* wrapper, Arena*) { - return CelValue::CreateUint64(wrapper->value()); -} - -CelValue ValueFromMessage(const FloatValue* wrapper, Arena*) { - return CelValue::CreateDouble(wrapper->value()); -} - -CelValue ValueFromMessage(const DoubleValue* wrapper, Arena*) { - return CelValue::CreateDouble(wrapper->value()); -} +using ::google::protobuf::Arena; +using ::google::protobuf::Descriptor; +using ::google::protobuf::Message; -CelValue ValueFromMessage(const StringValue* wrapper, Arena*) { - return CelValue::CreateString(&wrapper->value()); -} - -CelValue ValueFromMessage(const BytesValue* wrapper, Arena* arena) { - // BytesValue stores value as Cord - return CelValue::CreateBytes( - Arena::Create(arena, std::string(wrapper->value()))); -} +} // namespace -CelValue ValueFromMessage(const Value* value, Arena* arena) { - switch (value->kind_case()) { - case Value::KindCase::kNullValue: - return CelValue::CreateNull(); - case Value::KindCase::kNumberValue: - return CelValue::CreateDouble(value->number_value()); - case Value::KindCase::kStringValue: - return CelValue::CreateString(&value->string_value()); - case Value::KindCase::kBoolValue: - return CelValue::CreateBool(value->bool_value()); - case Value::KindCase::kStructValue: - return CelProtoWrapper::CreateMessage(&value->struct_value(), arena); - case Value::KindCase::kListValue: - return CelProtoWrapper::CreateMessage(&value->list_value(), arena); - default: - return CreateErrorValue(arena, "No known fields set in Value message"); - } +CelValue CelProtoWrapper::InternalWrapMessage(const Message* message) { + return CelValue::CreateMessageWrapper( + MessageWrapper(message, &GetGenericProtoTypeInfoInstance())); } -// Factory class, responsible for creating CelValue object from Message of some -// fixed subtype. -class ValueFromMessageFactory { - public: - virtual ~ValueFromMessageFactory() {} - virtual const google::protobuf::Descriptor* GetDescriptor() const = 0; - virtual absl::optional CreateValue(const google::protobuf::Message* value, - Arena* arena) const = 0; -}; - -// This template class has a good performance, but performes downcast -// operations on google::protobuf::Message pointers. -template -class CastingValueFromMessageFactory : public ValueFromMessageFactory { - public: - const google::protobuf::Descriptor* GetDescriptor() const override { - return MessageType::descriptor(); - } - - absl::optional CreateValue(const google::protobuf::Message* msg, - Arena* arena) const override { - if (MessageType::descriptor() == msg->GetDescriptor()) { - const MessageType* message = - google::protobuf::DynamicCastToGenerated(msg); - if (message == nullptr) { - auto message_copy = Arena::CreateMessage(arena); - message_copy->CopyFrom(*msg); - message = message_copy; - } - return ValueFromMessage(message, arena); - } - return {}; - } -}; - -// Class makes CelValue from generic protobuf Message. -// It holds a registry of CelValue factories for specific subtypes of Message. -// If message does not match any of types stored in registry, generic -// message-containing CelValue is created. -class ValueFromMessageMaker { - public: - explicit ValueFromMessageMaker() { - Add(absl::make_unique>()); - Add(absl::make_unique>()); - Add(absl::make_unique>()); - Add(absl::make_unique>()); - Add(absl::make_unique>()); - - Add(absl::make_unique>()); - - Add(absl::make_unique>()); - Add(absl::make_unique>()); - Add(absl::make_unique>()); - Add(absl::make_unique>()); - Add(absl::make_unique>()); - Add(absl::make_unique>()); - Add(absl::make_unique>()); - Add(absl::make_unique>()); - Add(absl::make_unique>()); - } - - absl::optional CreateValue(const google::protobuf::Message* value, - Arena* arena) const { - auto it = factories_.find(value->GetDescriptor()); - if (it == factories_.end()) { - // Not found for value->GetDescriptor()->name() - return {}; - } - return (it->second)->CreateValue(value, arena); - } - - // Non-copyable, non-assignable - ValueFromMessageMaker(const ValueFromMessageMaker&) = delete; - ValueFromMessageMaker& operator=(const ValueFromMessageMaker&) = delete; - - private: - void Add(std::unique_ptr factory) { - const Descriptor* desc = factory->GetDescriptor(); - factories_.emplace(desc, std::move(factory)); - } - - absl::node_hash_map> - factories_; -}; - -} // namespace - // CreateMessage creates CelValue from google::protobuf::Message. // As some of CEL basic types are subclassing google::protobuf::Message, // this method contains type checking and downcasts. -CelValue CelProtoWrapper::CreateMessage(const google::protobuf::Message* value, - Arena* arena) { - static const ValueFromMessageMaker* maker = new ValueFromMessageMaker(); +CelValue CelProtoWrapper::CreateMessage(const Message* value, Arena* arena) { + return internal::UnwrapMessageToValue(value, &InternalWrapMessage, arena); +} - // Messages are Nullable types - if (value == nullptr) { - return CelValue(value); +absl::optional CelProtoWrapper::MaybeWrapValue( + const Descriptor* descriptor, google::protobuf::MessageFactory* factory, + const CelValue& value, Arena* arena) { + const Message* msg = + internal::MaybeWrapValueToMessage(descriptor, factory, value, arena); + if (msg != nullptr) { + return InternalWrapMessage(msg); + } else { + return std::nullopt; } - - auto special_value = maker->CreateValue(value, arena); - - return special_value.has_value() ? special_value.value() : CelValue(value); } -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime diff --git a/eval/public/structs/cel_proto_wrapper.h b/eval/public/structs/cel_proto_wrapper.h index 830bfd67f..73942c253 100644 --- a/eval/public/structs/cel_proto_wrapper.h +++ b/eval/public/structs/cel_proto_wrapper.h @@ -3,35 +3,51 @@ #include "google/protobuf/duration.pb.h" #include "google/protobuf/timestamp.pb.h" +#include "absl/types/optional.h" #include "eval/public/cel_value.h" -#include "internal/proto_util.h" +#include "internal/proto_time_encoding.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace google::api::expr::runtime { -namespace google { -namespace api { -namespace expr { -namespace runtime { class CelProtoWrapper { public: // CreateMessage creates CelValue from google::protobuf::Message. // As some of CEL basic types are subclassing google::protobuf::Message, // this method contains type checking and downcasts. - static CelValue CreateMessage(const google::protobuf::Message *value, - google::protobuf::Arena *arena); + static CelValue CreateMessage(const google::protobuf::Message* value, + google::protobuf::Arena* arena); + + // Internal utility for creating a CelValue wrapping a user defined type. + // Assumes that the message has been properly unpacked. + static CelValue InternalWrapMessage(const google::protobuf::Message* message); // CreateDuration creates CelValue from a non-null protobuf duration value. - static CelValue CreateDuration(const google::protobuf::Duration *value) { - return CelValue(expr::internal::DecodeDuration(*value)); + static CelValue CreateDuration(const google::protobuf::Duration* value) { + return CelValue(cel::internal::DecodeDuration(*value)); } // CreateTimestamp creates CelValue from a non-null protobuf timestamp value. - static CelValue CreateTimestamp(const google::protobuf::Timestamp *value) { - return CelValue(expr::internal::DecodeTime(*value)); + static CelValue CreateTimestamp(const google::protobuf::Timestamp* value) { + return CelValue(cel::internal::DecodeTime(*value)); } + + // MaybeWrapValue attempts to wrap the input value in a proto message with + // the given type_name. If the value can be wrapped, it is returned as a + // CelValue pointing to the protobuf message. Otherwise, the result will be + // empty. + // + // This method is the complement to CreateMessage which may unwrap a protobuf + // message to native CelValue representation during a protobuf field read. + // Just as CreateMessage should only be used when reading protobuf values, + // MaybeWrapValue should only be used when assigning protobuf fields. + static absl::optional MaybeWrapValue( + const google::protobuf::Descriptor* descriptor, google::protobuf::MessageFactory* factory, + const CelValue& value, google::protobuf::Arena* arena); }; -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_CEL_PROTO_WRAPPER_H_ diff --git a/eval/public/structs/cel_proto_wrapper_test.cc b/eval/public/structs/cel_proto_wrapper_test.cc index 4e797589a..b9fcd6b51 100644 --- a/eval/public/structs/cel_proto_wrapper_test.cc +++ b/eval/public/structs/cel_proto_wrapper_test.cc @@ -1,25 +1,38 @@ #include "eval/public/structs/cel_proto_wrapper.h" +#include +#include +#include +#include +#include +#include + #include "google/protobuf/any.pb.h" +#include "google/protobuf/duration.pb.h" #include "google/protobuf/empty.pb.h" #include "google/protobuf/struct.pb.h" #include "google/protobuf/wrappers.pb.h" -#include "google/protobuf/dynamic_message.h" -#include "google/protobuf/message.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/time/time.h" +#include "eval/public/cel_value.h" +#include "eval/public/containers/container_backed_list_impl.h" +#include "eval/public/containers/container_backed_map_impl.h" #include "eval/testutil/test_message.pb.h" -#include "internal/proto_util.h" +#include "internal/proto_time_encoding.h" +#include "internal/status_macros.h" +#include "internal/testing.h" #include "testutil/util.h" +#include "google/protobuf/dynamic_message.h" +#include "google/protobuf/message.h" + +namespace google::api::expr::runtime { -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace { -using testing::Eq; -using testing::UnorderedPointwise; +using ::testing::Eq; +using ::testing::UnorderedPointwise; using google::protobuf::Duration; using google::protobuf::ListValue; @@ -38,9 +51,89 @@ using google::protobuf::StringValue; using google::protobuf::UInt32Value; using google::protobuf::UInt64Value; -TEST(CelProtoWrapperTest, TestType) { - ::google::protobuf::Arena arena; +using google::protobuf::Arena; + +class CelProtoWrapperTest : public ::testing::Test { + protected: + CelProtoWrapperTest() {} + + void ExpectWrappedMessage(const CelValue& value, + const google::protobuf::Message& message) { + // Test the input value wraps to the destination message type. + auto result = CelProtoWrapper::MaybeWrapValue( + message.GetDescriptor(), message.GetReflection()->GetMessageFactory(), + value, arena()); + EXPECT_TRUE(result.has_value()); + EXPECT_TRUE((*result).IsMessage()); + EXPECT_THAT((*result).MessageOrDie(), testutil::EqualsProto(message)); + + // Ensure that double wrapping results in the object being wrapped once. + auto identity = CelProtoWrapper::MaybeWrapValue( + message.GetDescriptor(), message.GetReflection()->GetMessageFactory(), + *result, arena()); + EXPECT_FALSE(identity.has_value()); + + // Check to make sure that even dynamic messages can be used as input to + // the wrapping call. + result = CelProtoWrapper::MaybeWrapValue( + ReflectedCopy(message)->GetDescriptor(), + ReflectedCopy(message)->GetReflection()->GetMessageFactory(), value, + arena()); + EXPECT_TRUE(result.has_value()); + EXPECT_TRUE((*result).IsMessage()); + EXPECT_THAT((*result).MessageOrDie(), testutil::EqualsProto(message)); + } + + void ExpectNotWrapped(const CelValue& value, const google::protobuf::Message& message) { + // Test the input value does not wrap by asserting value == result. + auto result = CelProtoWrapper::MaybeWrapValue( + message.GetDescriptor(), message.GetReflection()->GetMessageFactory(), + value, arena()); + EXPECT_FALSE(result.has_value()); + } + + template + void ExpectUnwrappedPrimitive(const google::protobuf::Message& message, T result) { + CelValue cel_value = CelProtoWrapper::CreateMessage(&message, arena()); + T value; + EXPECT_TRUE(cel_value.GetValue(&value)); + EXPECT_THAT(value, Eq(result)); + + T dyn_value; + CelValue cel_dyn_value = + CelProtoWrapper::CreateMessage(ReflectedCopy(message).get(), arena()); + EXPECT_THAT(cel_dyn_value.type(), Eq(cel_value.type())); + EXPECT_TRUE(cel_dyn_value.GetValue(&dyn_value)); + EXPECT_THAT(value, Eq(dyn_value)); + } + void ExpectUnwrappedMessage(const google::protobuf::Message& message, + google::protobuf::Message* result) { + CelValue cel_value = CelProtoWrapper::CreateMessage(&message, arena()); + if (result == nullptr) { + EXPECT_TRUE(cel_value.IsNull()); + return; + } + EXPECT_TRUE(cel_value.IsMessage()); + EXPECT_THAT(cel_value.MessageOrDie(), testutil::EqualsProto(*result)); + } + + std::unique_ptr ReflectedCopy( + const google::protobuf::Message& message) { + std::unique_ptr dynamic_value( + factory_.GetPrototype(message.GetDescriptor())->New()); + dynamic_value->CopyFrom(message); + return dynamic_value; + } + + Arena* arena() { return &arena_; } + + private: + Arena arena_; + google::protobuf::DynamicMessageFactory factory_; +}; + +TEST_F(CelProtoWrapperTest, TestType) { Duration msg_duration; msg_duration.set_seconds(2); msg_duration.set_nanos(3); @@ -48,7 +141,7 @@ TEST(CelProtoWrapperTest, TestType) { EXPECT_THAT(value_duration1.type(), Eq(CelValue::Type::kDuration)); CelValue value_duration2 = - CelProtoWrapper::CreateMessage(&msg_duration, &arena); + CelProtoWrapper::CreateMessage(&msg_duration, arena()); EXPECT_THAT(value_duration2.type(), Eq(CelValue::Type::kDuration)); Timestamp msg_timestamp; @@ -58,14 +151,12 @@ TEST(CelProtoWrapperTest, TestType) { EXPECT_THAT(value_timestamp1.type(), Eq(CelValue::Type::kTimestamp)); CelValue value_timestamp2 = - CelProtoWrapper::CreateMessage(&msg_timestamp, &arena); + CelProtoWrapper::CreateMessage(&msg_timestamp, arena()); EXPECT_THAT(value_timestamp2.type(), Eq(CelValue::Type::kTimestamp)); } // This test verifies CelValue support of Duration type. -TEST(CelProtoWrapperTest, TestDuration) { - google::protobuf::Arena arena; - +TEST_F(CelProtoWrapperTest, TestDuration) { Duration msg_duration; msg_duration.set_seconds(2); msg_duration.set_nanos(3); @@ -73,21 +164,19 @@ TEST(CelProtoWrapperTest, TestDuration) { EXPECT_THAT(value_duration1.type(), Eq(CelValue::Type::kDuration)); CelValue value_duration2 = - CelProtoWrapper::CreateMessage(&msg_duration, &arena); + CelProtoWrapper::CreateMessage(&msg_duration, arena()); EXPECT_THAT(value_duration2.type(), Eq(CelValue::Type::kDuration)); CelValue value = CelProtoWrapper::CreateDuration(&msg_duration); - // CelValue value = CelValue::CreateString("test"); EXPECT_TRUE(value.IsDuration()); Duration out; - expr::internal::EncodeDuration(value.DurationOrDie(), &out); + auto status = cel::internal::EncodeDuration(value.DurationOrDie(), &out); + EXPECT_TRUE(status.ok()); EXPECT_THAT(out, testutil::EqualsProto(msg_duration)); } // This test verifies CelValue support of Timestamp type. -TEST(CelProtoWrapperTest, TestTimestamp) { - google::protobuf::Arena arena; - +TEST_F(CelProtoWrapperTest, TestTimestamp) { Timestamp msg_timestamp; msg_timestamp.set_seconds(2); msg_timestamp.set_nanos(3); @@ -95,70 +184,63 @@ TEST(CelProtoWrapperTest, TestTimestamp) { EXPECT_THAT(value_timestamp1.type(), Eq(CelValue::Type::kTimestamp)); CelValue value_timestamp2 = - CelProtoWrapper::CreateMessage(&msg_timestamp, &arena); + CelProtoWrapper::CreateMessage(&msg_timestamp, arena()); EXPECT_THAT(value_timestamp2.type(), Eq(CelValue::Type::kTimestamp)); CelValue value = CelProtoWrapper::CreateTimestamp(&msg_timestamp); // CelValue value = CelValue::CreateString("test"); EXPECT_TRUE(value.IsTimestamp()); Timestamp out; - expr::internal::EncodeTime(value.TimestampOrDie(), &out); + auto status = cel::internal::EncodeTime(value.TimestampOrDie(), &out); + EXPECT_TRUE(status.ok()); EXPECT_THAT(out, testutil::EqualsProto(msg_timestamp)); } // Dynamic Values test // - -TEST(CelProtoWrapperTest, TestValueFieldNull) { - ::google::protobuf::Arena arena; - - Value value1; - value1.set_null_value(google::protobuf::NullValue::NULL_VALUE); - - CelValue value = CelProtoWrapper::CreateMessage(&value1, &arena); - ASSERT_TRUE(value.IsNull()); +TEST_F(CelProtoWrapperTest, UnwrapValueNull) { + Value json; + json.set_null_value(google::protobuf::NullValue::NULL_VALUE); + ExpectUnwrappedMessage(json, nullptr); } -TEST(CelProtoWrapperTest, TestValueFieldBool) { - ::google::protobuf::Arena arena; - - Value value1; - value1.set_bool_value(true); +// Test support for unwrapping a google::protobuf::Value to a CEL value. +TEST_F(CelProtoWrapperTest, UnwrapDynamicValueNull) { + Value value_msg; + value_msg.set_null_value(protobuf::NULL_VALUE); - CelValue value = CelProtoWrapper::CreateMessage(&value1, &arena); - ASSERT_TRUE(value.IsBool()); - EXPECT_EQ(value.BoolOrDie(), true); + CelValue value = + CelProtoWrapper::CreateMessage(ReflectedCopy(value_msg).get(), arena()); + EXPECT_TRUE(value.IsNull()); } -TEST(CelProtoWrapperTest, TestValueFieldNumeric) { - ::google::protobuf::Arena arena; - - Value value1; - value1.set_number_value(1.0); +TEST_F(CelProtoWrapperTest, UnwrapValueBool) { + bool value = true; - CelValue value = CelProtoWrapper::CreateMessage(&value1, &arena); - ASSERT_TRUE(value.IsDouble()); - EXPECT_DOUBLE_EQ(value.DoubleOrDie(), 1.0); + Value json; + json.set_bool_value(true); + ExpectUnwrappedPrimitive(json, value); } -TEST(CelProtoWrapperTest, TestValueFieldString) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, UnwrapValueNumber) { + double value = 1.0; - const std::string kTest = "test"; + Value json; + json.set_number_value(value); + ExpectUnwrappedPrimitive(json, value); +} - Value value1; - value1.set_string_value(kTest); +TEST_F(CelProtoWrapperTest, UnwrapValueString) { + const std::string test = "test"; + auto value = CelValue::StringHolder(&test); - CelValue value = CelProtoWrapper::CreateMessage(&value1, &arena); - ASSERT_TRUE(value.IsString()); - EXPECT_EQ(value.StringOrDie().value(), kTest); + Value json; + json.set_string_value(test); + ExpectUnwrappedPrimitive(json, value); } -TEST(CelProtoWrapperTest, TestValueFieldStruct) { - ::google::protobuf::Arena arena; - +TEST_F(CelProtoWrapperTest, UnwrapValueStruct) { const std::vector kFields = {"field1", "field2", "field3"}; - Struct value_struct; auto& value1 = (*value_struct.mutable_fields())[kFields[0]]; @@ -170,27 +252,45 @@ TEST(CelProtoWrapperTest, TestValueFieldStruct) { auto& value3 = (*value_struct.mutable_fields())[kFields[2]]; value3.set_string_value("test"); - CelValue value = CelProtoWrapper::CreateMessage(&value_struct, &arena); + CelValue value = CelProtoWrapper::CreateMessage(&value_struct, arena()); ASSERT_TRUE(value.IsMap()); const CelMap* cel_map = value.MapOrDie(); - auto lookup1 = (*cel_map)[CelValue::CreateString(&kFields[0])]; + CelValue field1 = CelValue::CreateString(&kFields[0]); + auto field1_presence = cel_map->Has(field1); + ASSERT_OK(field1_presence); + EXPECT_TRUE(*field1_presence); + auto lookup1 = (*cel_map)[field1]; ASSERT_TRUE(lookup1.has_value()); - ASSERT_TRUE(lookup1.value().IsBool()); - EXPECT_EQ(lookup1.value().BoolOrDie(), true); - - auto lookup2 = (*cel_map)[CelValue::CreateString(&kFields[1])]; + ASSERT_TRUE(lookup1->IsBool()); + EXPECT_EQ(lookup1->BoolOrDie(), true); + + CelValue field2 = CelValue::CreateString(&kFields[1]); + auto field2_presence = cel_map->Has(field2); + ASSERT_OK(field2_presence); + EXPECT_TRUE(*field2_presence); + auto lookup2 = (*cel_map)[field2]; ASSERT_TRUE(lookup2.has_value()); - ASSERT_TRUE(lookup2.value().IsDouble()); - EXPECT_DOUBLE_EQ(lookup2.value().DoubleOrDie(), 1.0); - - auto lookup3 = (*cel_map)[CelValue::CreateString(&kFields[2])]; + ASSERT_TRUE(lookup2->IsDouble()); + EXPECT_DOUBLE_EQ(lookup2->DoubleOrDie(), 1.0); + + CelValue field3 = CelValue::CreateString(&kFields[2]); + auto field3_presence = cel_map->Has(field3); + ASSERT_OK(field3_presence); + EXPECT_TRUE(*field3_presence); + auto lookup3 = (*cel_map)[field3]; ASSERT_TRUE(lookup3.has_value()); - ASSERT_TRUE(lookup3.value().IsString()); - EXPECT_EQ(lookup3.value().StringOrDie().value(), "test"); + ASSERT_TRUE(lookup3->IsString()); + EXPECT_EQ(lookup3->StringOrDie().value(), "test"); + + std::string missing = "missing_field"; + CelValue missing_field = CelValue::CreateString(&missing); + auto missing_field_presence = cel_map->Has(missing_field); + ASSERT_OK(missing_field_presence); + EXPECT_FALSE(*missing_field_presence); - const CelList* key_list = cel_map->ListKeys(); + const CelList* key_list = cel_map->ListKeys().value(); ASSERT_EQ(key_list->size(), kFields.size()); std::vector result_keys; @@ -203,9 +303,64 @@ TEST(CelProtoWrapperTest, TestValueFieldStruct) { EXPECT_THAT(result_keys, UnorderedPointwise(Eq(), kFields)); } -TEST(CelProtoWrapperTest, TestListFieldStruct) { - ::google::protobuf::Arena arena; +// Test support for google::protobuf::Struct when it is created as dynamic +// message +TEST_F(CelProtoWrapperTest, UnwrapDynamicStruct) { + Struct struct_msg; + const std::string kFieldInt = "field_int"; + const std::string kFieldBool = "field_bool"; + (*struct_msg.mutable_fields())[kFieldInt].set_number_value(1.); + (*struct_msg.mutable_fields())[kFieldBool].set_bool_value(true); + CelValue value = + CelProtoWrapper::CreateMessage(ReflectedCopy(struct_msg).get(), arena()); + EXPECT_TRUE(value.IsMap()); + const CelMap* cel_map = value.MapOrDie(); + ASSERT_TRUE(cel_map != nullptr); + { + auto lookup = (*cel_map)[CelValue::CreateString(&kFieldInt)]; + ASSERT_TRUE(lookup.has_value()); + auto v = lookup.value(); + ASSERT_TRUE(v.IsDouble()); + EXPECT_THAT(v.DoubleOrDie(), testing::DoubleEq(1.)); + } + { + auto lookup = (*cel_map)[CelValue::CreateString(&kFieldBool)]; + ASSERT_TRUE(lookup.has_value()); + auto v = lookup.value(); + ASSERT_TRUE(v.IsBool()); + EXPECT_EQ(v.BoolOrDie(), true); + } + { + auto presence = cel_map->Has(CelValue::CreateBool(true)); + ASSERT_FALSE(presence.ok()); + EXPECT_EQ(presence.status().code(), absl::StatusCode::kInvalidArgument); + auto lookup = (*cel_map)[CelValue::CreateBool(true)]; + ASSERT_TRUE(lookup.has_value()); + auto v = lookup.value(); + ASSERT_TRUE(v.IsError()); + } +} + +TEST_F(CelProtoWrapperTest, UnwrapDynamicValueStruct) { + const std::string kField1 = "field1"; + const std::string kField2 = "field2"; + Value value_msg; + (*value_msg.mutable_struct_value()->mutable_fields())[kField1] + .set_number_value(1); + (*value_msg.mutable_struct_value()->mutable_fields())[kField2] + .set_number_value(2); + + CelValue value = + CelProtoWrapper::CreateMessage(ReflectedCopy(value_msg).get(), arena()); + EXPECT_TRUE(value.IsMap()); + EXPECT_TRUE( + (*value.MapOrDie())[CelValue::CreateString(&kField1)].has_value()); + EXPECT_TRUE( + (*value.MapOrDie())[CelValue::CreateString(&kField2)].has_value()); +} + +TEST_F(CelProtoWrapperTest, UnwrapValueList) { const std::vector kFields = {"field1", "field2", "field3"}; ListValue list_value; @@ -214,7 +369,7 @@ TEST(CelProtoWrapperTest, TestListFieldStruct) { list_value.add_values()->set_number_value(1.0); list_value.add_values()->set_string_value("test"); - CelValue value = CelProtoWrapper::CreateMessage(&list_value, &arena); + CelValue value = CelProtoWrapper::CreateMessage(&list_value, arena()); ASSERT_TRUE(value.IsList()); const CelList* cel_list = value.ListOrDie(); @@ -234,438 +389,487 @@ TEST(CelProtoWrapperTest, TestListFieldStruct) { EXPECT_EQ(value3.StringOrDie().value(), "test"); } -// Test support of google.protobuf.Any in CelValue. -TEST(CelProtoWrapperTest, TestAnyValue) { - ::google::protobuf::Arena arena; - Any any; +TEST_F(CelProtoWrapperTest, UnwrapDynamicValueListValue) { + Value value_msg; + value_msg.mutable_list_value()->add_values()->set_number_value(1.); + value_msg.mutable_list_value()->add_values()->set_number_value(2.); + + CelValue value = + CelProtoWrapper::CreateMessage(ReflectedCopy(value_msg).get(), arena()); + EXPECT_TRUE(value.IsList()); + EXPECT_THAT((*value.ListOrDie())[0].DoubleOrDie(), testing::DoubleEq(1)); + EXPECT_THAT((*value.ListOrDie())[1].DoubleOrDie(), testing::DoubleEq(2)); +} +// Test support of google.protobuf.Any in CelValue. +TEST_F(CelProtoWrapperTest, UnwrapAnyValue) { TestMessage test_message; test_message.set_string_value("test"); + Any any; any.PackFrom(test_message); - - CelValue value = CelProtoWrapper::CreateMessage(&any, &arena); - ASSERT_TRUE(value.IsMessage()); - - const google::protobuf::Message* unpacked_message = value.MessageOrDie(); - EXPECT_THAT(test_message, testutil::EqualsProto(*unpacked_message)); + ExpectUnwrappedMessage(any, &test_message); } -TEST(CelProtoWrapperTest, TestHandlingInvalidAnyValue) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, UnwrapInvalidAny) { Any any; - - CelValue value = CelProtoWrapper::CreateMessage(&any, &arena); + CelValue value = CelProtoWrapper::CreateMessage(&any, arena()); ASSERT_TRUE(value.IsError()); any.set_type_url("/"); - ASSERT_TRUE(CelProtoWrapper::CreateMessage(&any, &arena).IsError()); + ASSERT_TRUE(CelProtoWrapper::CreateMessage(&any, arena()).IsError()); any.set_type_url("/invalid.proto.name"); - ASSERT_TRUE(CelProtoWrapper::CreateMessage(&any, &arena).IsError()); + ASSERT_TRUE(CelProtoWrapper::CreateMessage(&any, arena()).IsError()); } // Test support of google.protobuf.Value wrappers in CelValue. -TEST(CelProtoWrapperTest, TestBoolWrapper) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, UnwrapBoolWrapper) { + bool value = true; BoolValue wrapper; - wrapper.set_value(true); - - CelValue value = CelProtoWrapper::CreateMessage(&wrapper, &arena); - ASSERT_TRUE(value.IsBool()); - - EXPECT_EQ(value.BoolOrDie(), wrapper.value()); + wrapper.set_value(value); + ExpectUnwrappedPrimitive(wrapper, value); } -TEST(CelProtoWrapperTest, TestInt32Wrapper) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, UnwrapInt32Wrapper) { + int64_t value = 12; Int32Value wrapper; - wrapper.set_value(12); - - CelValue value = CelProtoWrapper::CreateMessage(&wrapper, &arena); - ASSERT_TRUE(value.IsInt64()); - - EXPECT_EQ(value.Int64OrDie(), wrapper.value()); + wrapper.set_value(value); + ExpectUnwrappedPrimitive(wrapper, value); } -TEST(CelProtoWrapperTest, TestUInt32Wrapper) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, UnwrapUInt32Wrapper) { + uint64_t value = 12; UInt32Value wrapper; - wrapper.set_value(12); - - CelValue value = CelProtoWrapper::CreateMessage(&wrapper, &arena); - ASSERT_TRUE(value.IsUint64()); - - EXPECT_EQ(value.Uint64OrDie(), wrapper.value()); + wrapper.set_value(value); + ExpectUnwrappedPrimitive(wrapper, value); } -TEST(CelProtoWrapperTest, TestInt64Wrapper) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, UnwrapInt64Wrapper) { + int64_t value = 12; Int64Value wrapper; - wrapper.set_value(12); - - CelValue value = CelProtoWrapper::CreateMessage(&wrapper, &arena); - ASSERT_TRUE(value.IsInt64()); - - EXPECT_EQ(value.Int64OrDie(), wrapper.value()); + wrapper.set_value(value); + ExpectUnwrappedPrimitive(wrapper, value); } -TEST(CelProtoWrapperTest, TestUInt64Wrapper) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, UnwrapUInt64Wrapper) { + uint64_t value = 12; UInt64Value wrapper; - wrapper.set_value(12); + wrapper.set_value(value); + ExpectUnwrappedPrimitive(wrapper, value); +} - CelValue value = CelProtoWrapper::CreateMessage(&wrapper, &arena); - ASSERT_TRUE(value.IsUint64()); +TEST_F(CelProtoWrapperTest, UnwrapFloatWrapper) { + double value = 42.5; - EXPECT_EQ(value.Uint64OrDie(), wrapper.value()); + FloatValue wrapper; + wrapper.set_value(value); + ExpectUnwrappedPrimitive(wrapper, value); } -TEST(CelProtoWrapperTest, TestFloatWrapper) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, UnwrapDoubleWrapper) { + double value = 42.5; - FloatValue wrapper; - wrapper.set_value(42); + DoubleValue wrapper; + wrapper.set_value(value); + ExpectUnwrappedPrimitive(wrapper, value); +} - CelValue value = CelProtoWrapper::CreateMessage(&wrapper, &arena); - ASSERT_TRUE(value.IsDouble()); +TEST_F(CelProtoWrapperTest, UnwrapStringWrapper) { + std::string text = "42"; + auto value = CelValue::StringHolder(&text); - EXPECT_DOUBLE_EQ(value.DoubleOrDie(), wrapper.value()); + StringValue wrapper; + wrapper.set_value(text); + ExpectUnwrappedPrimitive(wrapper, value); } -TEST(CelProtoWrapperTest, TestDoubleWrapper) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, UnwrapBytesWrapper) { + std::string text = "42"; + auto value = CelValue::BytesHolder(&text); - DoubleValue wrapper; - wrapper.set_value(42); + BytesValue wrapper; + wrapper.set_value("42"); + ExpectUnwrappedPrimitive(wrapper, value); +} - CelValue value = CelProtoWrapper::CreateMessage(&wrapper, &arena); - ASSERT_TRUE(value.IsDouble()); +TEST_F(CelProtoWrapperTest, WrapNull) { + auto cel_value = CelValue::CreateNull(); - EXPECT_DOUBLE_EQ(value.DoubleOrDie(), wrapper.value()); + Value json; + json.set_null_value(protobuf::NULL_VALUE); + ExpectWrappedMessage(cel_value, json); + + Any any; + any.PackFrom(json); + ExpectWrappedMessage(cel_value, any); } -TEST(CelProtoWrapperTest, TestStringWrapper) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, WrapBool) { + auto cel_value = CelValue::CreateBool(true); - StringValue wrapper; - wrapper.set_value("42"); + Value json; + json.set_bool_value(true); + ExpectWrappedMessage(cel_value, json); - CelValue value = CelProtoWrapper::CreateMessage(&wrapper, &arena); - ASSERT_TRUE(value.IsString()); + BoolValue wrapper; + wrapper.set_value(true); + ExpectWrappedMessage(cel_value, wrapper); - EXPECT_EQ(value.StringOrDie().value(), wrapper.value()); + Any any; + any.PackFrom(wrapper); + ExpectWrappedMessage(cel_value, any); } -TEST(CelProtoWrapperTest, TestBytesWrapper) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, WrapBytes) { + std::string str = "hello world"; + auto cel_value = CelValue::CreateBytes(CelValue::BytesHolder(&str)); BytesValue wrapper; - wrapper.set_value("42"); + wrapper.set_value(str); + ExpectWrappedMessage(cel_value, wrapper); + + Any any; + any.PackFrom(wrapper); + ExpectWrappedMessage(cel_value, any); +} - CelValue value = CelProtoWrapper::CreateMessage(&wrapper, &arena); - ASSERT_TRUE(value.IsBytes()); +TEST_F(CelProtoWrapperTest, WrapBytesToValue) { + std::string str = "hello world"; + auto cel_value = CelValue::CreateBytes(CelValue::BytesHolder(&str)); - EXPECT_EQ(value.BytesOrDie().value(), wrapper.value()); + Value json; + json.set_string_value("aGVsbG8gd29ybGQ="); + ExpectWrappedMessage(cel_value, json); } -// Test support for google::protobuf::Struct when it is created as dynamic -// message -TEST(CelProtoWrapperTest, DynamicStructSupport) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, WrapDuration) { + auto cel_value = CelValue::CreateDuration(absl::Seconds(300)); - google::protobuf::DynamicMessageFactory factory; - { - Struct struct_msg; - - const std::string kFieldInt = "field_int"; - const std::string kFieldBool = "field_bool"; - - (*struct_msg.mutable_fields())[kFieldInt].set_number_value(1.); - (*struct_msg.mutable_fields())[kFieldBool].set_bool_value(true); - std::unique_ptr dynamic_struct( - factory.GetPrototype(Struct::descriptor())->New()); - dynamic_struct->CopyFrom(struct_msg); - CelValue value = - CelProtoWrapper::CreateMessage(dynamic_struct.get(), &arena); - EXPECT_TRUE(value.IsMap()); - const CelMap* cel_map = value.MapOrDie(); - ASSERT_TRUE(cel_map != nullptr); - - { - auto lookup = (*cel_map)[CelValue::CreateString(&kFieldInt)]; - ASSERT_TRUE(lookup.has_value()); - auto v = lookup.value(); - ASSERT_TRUE(v.IsDouble()); - EXPECT_THAT(v.DoubleOrDie(), testing::DoubleEq(1.)); - } - { - auto lookup = (*cel_map)[CelValue::CreateString(&kFieldBool)]; - ASSERT_TRUE(lookup.has_value()); - auto v = lookup.value(); - ASSERT_TRUE(v.IsBool()); - EXPECT_EQ(v.BoolOrDie(), true); - } - } + Duration d; + d.set_seconds(300); + ExpectWrappedMessage(cel_value, d); + + Any any; + any.PackFrom(d); + ExpectWrappedMessage(cel_value, any); } -// Test support for google::protobuf::Value when it is created as dynamic -// message -TEST(CelProtoWrapperTest, DynamicValueSupport) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, WrapDurationToValue) { + auto cel_value = CelValue::CreateDuration(absl::Seconds(300)); - google::protobuf::DynamicMessageFactory factory; - // Null - { - Value value_msg; - value_msg.set_null_value(protobuf::NULL_VALUE); - std::unique_ptr dynamic_value( - factory.GetPrototype(Value::descriptor())->New()); - dynamic_value->CopyFrom(value_msg); - CelValue value = - CelProtoWrapper::CreateMessage(dynamic_value.get(), &arena); - EXPECT_TRUE(value.IsNull()); - } - // Boolean - { - Value value_msg; - value_msg.set_bool_value(true); - std::unique_ptr dynamic_value( - factory.GetPrototype(Value::descriptor())->New()); - dynamic_value->CopyFrom(value_msg); - CelValue value = - CelProtoWrapper::CreateMessage(dynamic_value.get(), &arena); - EXPECT_TRUE(value.IsBool()); - EXPECT_TRUE(value.BoolOrDie()); - } - // Numeric - { - Value value_msg; - value_msg.set_number_value(1.0); - std::unique_ptr dynamic_value( - factory.GetPrototype(Value::descriptor())->New()); - dynamic_value->CopyFrom(value_msg); - CelValue value = - CelProtoWrapper::CreateMessage(dynamic_value.get(), &arena); - EXPECT_TRUE(value.IsDouble()); - EXPECT_THAT(value.DoubleOrDie(), testing::DoubleEq(1.)); - } - // String - { - Value value_msg; - value_msg.set_string_value("test"); - std::unique_ptr dynamic_value( - factory.GetPrototype(Value::descriptor())->New()); - dynamic_value->CopyFrom(value_msg); - CelValue value = - CelProtoWrapper::CreateMessage(dynamic_value.get(), &arena); - EXPECT_TRUE(value.IsString()); - EXPECT_THAT(value.StringOrDie().value(), Eq("test")); - } - // List - { - Value value_msg; - value_msg.mutable_list_value()->add_values()->set_number_value(1.); - value_msg.mutable_list_value()->add_values()->set_number_value(2.); - std::unique_ptr dynamic_value( - factory.GetPrototype(Value::descriptor())->New()); - dynamic_value->CopyFrom(value_msg); - CelValue value = - CelProtoWrapper::CreateMessage(dynamic_value.get(), &arena); - EXPECT_TRUE(value.IsList()); - EXPECT_THAT((*value.ListOrDie())[0].DoubleOrDie(), testing::DoubleEq(1)); - EXPECT_THAT((*value.ListOrDie())[1].DoubleOrDie(), testing::DoubleEq(2)); - } - // Struct - { - const std::string kField1 = "field1"; - const std::string kField2 = "field2"; - - Value value_msg; - (*value_msg.mutable_struct_value()->mutable_fields())[kField1] - .set_number_value(1); - (*value_msg.mutable_struct_value()->mutable_fields())[kField2] - .set_number_value(2); - std::unique_ptr dynamic_value( - factory.GetPrototype(Value::descriptor())->New()); - dynamic_value->CopyFrom(value_msg); - CelValue value = - CelProtoWrapper::CreateMessage(dynamic_value.get(), &arena); - EXPECT_TRUE(value.IsMap()); - EXPECT_TRUE( - (*value.MapOrDie())[CelValue::CreateString(&kField1)].has_value()); - EXPECT_TRUE( - (*value.MapOrDie())[CelValue::CreateString(&kField2)].has_value()); - } + Value json; + json.set_string_value("300s"); + ExpectWrappedMessage(cel_value, json); } -// Test support of google.protobuf.Value wrappers in CelValue. -TEST(CelProtoWrapperTest, DynamicBoolWrapper) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, WrapDouble) { + double num = 1.5; + auto cel_value = CelValue::CreateDouble(num); - BoolValue wrapper; - wrapper.set_value(true); - google::protobuf::DynamicMessageFactory factory; - std::unique_ptr dynamic_value( - factory.GetPrototype(BoolValue::descriptor())->New()); - dynamic_value->CopyFrom(wrapper); + Value json; + json.set_number_value(num); + ExpectWrappedMessage(cel_value, json); - CelValue value = CelProtoWrapper::CreateMessage(dynamic_value.get(), &arena); - ASSERT_TRUE(value.IsBool()); + DoubleValue wrapper; + wrapper.set_value(num); + ExpectWrappedMessage(cel_value, wrapper); - EXPECT_EQ(value.BoolOrDie(), wrapper.value()); + Any any; + any.PackFrom(wrapper); + ExpectWrappedMessage(cel_value, any); } -TEST(CelProtoWrapperTest, DynamicInt32Wrapper) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, WrapDoubleToFloatValue) { + double num = 1.5; + auto cel_value = CelValue::CreateDouble(num); - Int32Value wrapper; - wrapper.set_value(12); + FloatValue wrapper; + wrapper.set_value(num); + ExpectWrappedMessage(cel_value, wrapper); - google::protobuf::DynamicMessageFactory factory; - std::unique_ptr dynamic_value( - factory.GetPrototype(wrapper.descriptor())->New()); - dynamic_value->CopyFrom(wrapper); + // Imprecise double -> float representation results in truncation. + double small_num = -9.9e-100; + wrapper.set_value(small_num); + cel_value = CelValue::CreateDouble(small_num); + ExpectWrappedMessage(cel_value, wrapper); +} + +TEST_F(CelProtoWrapperTest, WrapDoubleOverflow) { + double lowest_double = std::numeric_limits::lowest(); + auto cel_value = CelValue::CreateDouble(lowest_double); - CelValue value = CelProtoWrapper::CreateMessage(dynamic_value.get(), &arena); + // Double exceeds float precision, overflow to -infinity. + FloatValue wrapper; + wrapper.set_value(-std::numeric_limits::infinity()); + ExpectWrappedMessage(cel_value, wrapper); - ASSERT_TRUE(value.IsInt64()); + double max_double = std::numeric_limits::max(); + cel_value = CelValue::CreateDouble(max_double); - EXPECT_EQ(value.Int64OrDie(), wrapper.value()); + wrapper.set_value(std::numeric_limits::infinity()); + ExpectWrappedMessage(cel_value, wrapper); } -TEST(CelProtoWrapperTest, DynamicUInt32Wrapper) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, WrapInt64) { + int32_t num = std::numeric_limits::lowest(); + auto cel_value = CelValue::CreateInt64(num); - UInt32Value wrapper; - wrapper.set_value(12); + Value json; + json.set_number_value(static_cast(num)); + ExpectWrappedMessage(cel_value, json); - google::protobuf::DynamicMessageFactory factory; - std::unique_ptr dynamic_value( - factory.GetPrototype(wrapper.descriptor())->New()); - dynamic_value->CopyFrom(wrapper); - CelValue value = CelProtoWrapper::CreateMessage(dynamic_value.get(), &arena); + Int64Value wrapper; + wrapper.set_value(num); + ExpectWrappedMessage(cel_value, wrapper); - ASSERT_TRUE(value.IsUint64()); - EXPECT_EQ(value.Uint64OrDie(), wrapper.value()); + Any any; + any.PackFrom(wrapper); + ExpectWrappedMessage(cel_value, any); } -TEST(CelProtoWrapperTest, DynamocInt64Wrapper) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, WrapInt64ToInt32Value) { + int32_t num = std::numeric_limits::lowest(); + auto cel_value = CelValue::CreateInt64(num); - Int64Value wrapper; - wrapper.set_value(12); + Int32Value wrapper; + wrapper.set_value(num); + ExpectWrappedMessage(cel_value, wrapper); +} - google::protobuf::DynamicMessageFactory factory; - std::unique_ptr dynamic_value( - factory.GetPrototype(wrapper.descriptor())->New()); - dynamic_value->CopyFrom(wrapper); - CelValue value = CelProtoWrapper::CreateMessage(dynamic_value.get(), &arena); +TEST_F(CelProtoWrapperTest, WrapFailureInt64ToInt32Value) { + int64_t num = std::numeric_limits::lowest(); + auto cel_value = CelValue::CreateInt64(num); - EXPECT_EQ(value.Int64OrDie(), wrapper.value()); + Int32Value wrapper; + ExpectNotWrapped(cel_value, wrapper); } -TEST(CelProtoWrapperTest, DynamicUInt64Wrapper) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, WrapInt64ToValue) { + int64_t max = std::numeric_limits::max(); + auto cel_value = CelValue::CreateInt64(max); - UInt64Value wrapper; - wrapper.set_value(12); + Value json; + json.set_string_value(absl::StrCat(max)); + ExpectWrappedMessage(cel_value, json); - google::protobuf::DynamicMessageFactory factory; - std::unique_ptr dynamic_value( - factory.GetPrototype(wrapper.descriptor())->New()); - dynamic_value->CopyFrom(wrapper); - CelValue value = CelProtoWrapper::CreateMessage(dynamic_value.get(), &arena); - ASSERT_TRUE(value.IsUint64()); + int64_t min = std::numeric_limits::min(); + cel_value = CelValue::CreateInt64(min); - EXPECT_EQ(value.Uint64OrDie(), wrapper.value()); + json.set_string_value(absl::StrCat(min)); + ExpectWrappedMessage(cel_value, json); } -TEST(CelProtoWrapperTest, DynamicFloatWrapper) { - ::google::protobuf::Arena arena; - - FloatValue wrapper; - wrapper.set_value(42); +TEST_F(CelProtoWrapperTest, WrapUint64) { + uint32_t num = std::numeric_limits::max(); + auto cel_value = CelValue::CreateUint64(num); - google::protobuf::DynamicMessageFactory factory; - std::unique_ptr dynamic_value( - factory.GetPrototype(wrapper.descriptor())->New()); - dynamic_value->CopyFrom(wrapper); - CelValue value = CelProtoWrapper::CreateMessage(dynamic_value.get(), &arena); + Value json; + json.set_number_value(static_cast(num)); + ExpectWrappedMessage(cel_value, json); - ASSERT_TRUE(value.IsDouble()); + UInt64Value wrapper; + wrapper.set_value(num); + ExpectWrappedMessage(cel_value, wrapper); - EXPECT_DOUBLE_EQ(value.DoubleOrDie(), wrapper.value()); + Any any; + any.PackFrom(wrapper); + ExpectWrappedMessage(cel_value, any); } -TEST(CelProtoWrapperTest, DynamicDoubleWrapper) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, WrapUint64ToUint32Value) { + uint32_t num = std::numeric_limits::max(); + auto cel_value = CelValue::CreateUint64(num); - DoubleValue wrapper; - wrapper.set_value(42); + UInt32Value wrapper; + wrapper.set_value(num); + ExpectWrappedMessage(cel_value, wrapper); +} - google::protobuf::DynamicMessageFactory factory; - std::unique_ptr dynamic_value( - factory.GetPrototype(wrapper.descriptor())->New()); - dynamic_value->CopyFrom(wrapper); - CelValue value = CelProtoWrapper::CreateMessage(dynamic_value.get(), &arena); +TEST_F(CelProtoWrapperTest, WrapUint64ToValue) { + uint64_t num = std::numeric_limits::max(); + auto cel_value = CelValue::CreateUint64(num); + + Value json; + json.set_string_value(absl::StrCat(num)); + ExpectWrappedMessage(cel_value, json); +} - ASSERT_TRUE(value.IsDouble()); +TEST_F(CelProtoWrapperTest, WrapFailureUint64ToUint32Value) { + uint64_t num = std::numeric_limits::max(); + auto cel_value = CelValue::CreateUint64(num); - EXPECT_DOUBLE_EQ(value.DoubleOrDie(), wrapper.value()); + UInt32Value wrapper; + ExpectNotWrapped(cel_value, wrapper); } -TEST(CelProtoWrapperTest, DynamicStringWrapper) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, WrapString) { + std::string str = "test"; + auto cel_value = CelValue::CreateString(CelValue::StringHolder(&str)); + + Value json; + json.set_string_value(str); + ExpectWrappedMessage(cel_value, json); StringValue wrapper; - wrapper.set_value("42"); + wrapper.set_value(str); + ExpectWrappedMessage(cel_value, wrapper); + + Any any; + any.PackFrom(wrapper); + ExpectWrappedMessage(cel_value, any); +} - google::protobuf::DynamicMessageFactory factory; - std::unique_ptr dynamic_value( - factory.GetPrototype(wrapper.descriptor())->New()); - dynamic_value->CopyFrom(wrapper); - CelValue value = CelProtoWrapper::CreateMessage(dynamic_value.get(), &arena); +TEST_F(CelProtoWrapperTest, WrapTimestamp) { + absl::Time ts = absl::FromUnixSeconds(1615852799); + auto cel_value = CelValue::CreateTimestamp(ts); - ASSERT_TRUE(value.IsString()); + Timestamp t; + t.set_seconds(1615852799); + ExpectWrappedMessage(cel_value, t); - EXPECT_EQ(value.StringOrDie().value(), wrapper.value()); + Any any; + any.PackFrom(t); + ExpectWrappedMessage(cel_value, any); } -TEST(CelProtoWrapperTest, DynamicBytesWrapper) { - ::google::protobuf::Arena arena; +TEST_F(CelProtoWrapperTest, WrapTimestampToValue) { + absl::Time ts = absl::FromUnixSeconds(1615852799); + auto cel_value = CelValue::CreateTimestamp(ts); - BytesValue wrapper; - wrapper.set_value("42"); + Value json; + json.set_string_value("2021-03-15T23:59:59Z"); + ExpectWrappedMessage(cel_value, json); +} - google::protobuf::DynamicMessageFactory factory; - std::unique_ptr dynamic_value( - factory.GetPrototype(wrapper.descriptor())->New()); - dynamic_value->CopyFrom(wrapper); - CelValue value = CelProtoWrapper::CreateMessage(dynamic_value.get(), &arena); +TEST_F(CelProtoWrapperTest, WrapList) { + std::vector list_elems = { + CelValue::CreateDouble(1.5), + CelValue::CreateInt64(-2L), + }; + ContainerBackedListImpl list(std::move(list_elems)); + auto cel_value = CelValue::CreateList(&list); - ASSERT_TRUE(value.IsBytes()); + Value json; + json.mutable_list_value()->add_values()->set_number_value(1.5); + json.mutable_list_value()->add_values()->set_number_value(-2.); + ExpectWrappedMessage(cel_value, json); + ExpectWrappedMessage(cel_value, json.list_value()); - EXPECT_EQ(value.BytesOrDie().value(), wrapper.value()); + Any any; + any.PackFrom(json.list_value()); + ExpectWrappedMessage(cel_value, any); +} + +TEST_F(CelProtoWrapperTest, WrapFailureListValueBadJSON) { + TestMessage message; + std::vector list_elems = { + CelValue::CreateDouble(1.5), + CelProtoWrapper::CreateMessage(&message, arena()), + }; + ContainerBackedListImpl list(std::move(list_elems)); + auto cel_value = CelValue::CreateList(&list); + + Value json; + ExpectNotWrapped(cel_value, json); +} + +TEST_F(CelProtoWrapperTest, WrapStruct) { + const std::string kField1 = "field1"; + std::vector> args = { + {CelValue::CreateString(CelValue::StringHolder(&kField1)), + CelValue::CreateBool(true)}}; + auto cel_map = + CreateContainerBackedMap( + absl::Span>(args.data(), args.size())) + .value(); + auto cel_value = CelValue::CreateMap(cel_map.get()); + + Value json; + (*json.mutable_struct_value()->mutable_fields())[kField1].set_bool_value( + true); + ExpectWrappedMessage(cel_value, json); + ExpectWrappedMessage(cel_value, json.struct_value()); + + Any any; + any.PackFrom(json.struct_value()); + ExpectWrappedMessage(cel_value, any); +} + +TEST_F(CelProtoWrapperTest, WrapFailureStructBadKeyType) { + std::vector> args = { + {CelValue::CreateInt64(1L), CelValue::CreateBool(true)}}; + auto cel_map = + CreateContainerBackedMap( + absl::Span>(args.data(), args.size())) + .value(); + auto cel_value = CelValue::CreateMap(cel_map.get()); + + Value json; + ExpectNotWrapped(cel_value, json); +} + +TEST_F(CelProtoWrapperTest, WrapFailureStructBadValueType) { + const std::string kField1 = "field1"; + TestMessage bad_value; + std::vector> args = { + {CelValue::CreateString(CelValue::StringHolder(&kField1)), + CelProtoWrapper::CreateMessage(&bad_value, arena())}}; + auto cel_map = + CreateContainerBackedMap( + absl::Span>(args.data(), args.size())) + .value(); + auto cel_value = CelValue::CreateMap(cel_map.get()); + Value json; + ExpectNotWrapped(cel_value, json); +} + +TEST_F(CelProtoWrapperTest, WrapFailureWrongType) { + auto cel_value = CelValue::CreateNull(); + std::vector wrong_types = { + &BoolValue::default_instance(), &BytesValue::default_instance(), + &DoubleValue::default_instance(), &Duration::default_instance(), + &FloatValue::default_instance(), &Int32Value::default_instance(), + &Int64Value::default_instance(), &ListValue::default_instance(), + &StringValue::default_instance(), &Struct::default_instance(), + &Timestamp::default_instance(), &UInt32Value::default_instance(), + &UInt64Value::default_instance(), + }; + for (const auto* wrong_type : wrong_types) { + ExpectNotWrapped(cel_value, *wrong_type); + } +} + +TEST_F(CelProtoWrapperTest, WrapFailureErrorToAny) { + auto cel_value = CreateNoSuchFieldError(arena(), "error_field"); + ExpectNotWrapped(cel_value, Any::default_instance()); } -TEST(CelProtoWrapperTest, DebugString) { +// A CelMap implementation that returns an error for the ListKeys() method. +class InvalidListKeysCelMapBuilder : public CelMapBuilder { + public: + absl::StatusOr ListKeys() const override { + return absl::InternalError("Error while invoking ListKeys()"); + } +}; + +TEST_F(CelProtoWrapperTest, DebugString) { google::protobuf::Empty e; - ::google::protobuf::Arena arena; - EXPECT_EQ(CelProtoWrapper::CreateMessage(&e, &arena).DebugString(), - "Message: "); + EXPECT_THAT(CelProtoWrapper::CreateMessage(&e, arena()).DebugString(), + testing::StartsWith("Message: ")); ListValue list_value; list_value.add_values()->set_bool_value(true); list_value.add_values()->set_number_value(1.0); list_value.add_values()->set_string_value("test"); - CelValue value = CelProtoWrapper::CreateMessage(&list_value, &arena); - EXPECT_EQ(value.DebugString(), "List, size: 3"); + CelValue value = CelProtoWrapper::CreateMessage(&list_value, arena()); + EXPECT_EQ(value.DebugString(), + "CelList: [bool: 1, double: 1.000000, string: test]"); Struct value_struct; auto& value1 = (*value_struct.mutable_fields())["a"]; @@ -675,11 +879,20 @@ TEST(CelProtoWrapperTest, DebugString) { auto& value3 = (*value_struct.mutable_fields())["c"]; value3.set_string_value("test"); - value = CelProtoWrapper::CreateMessage(&value_struct, &arena); - EXPECT_EQ(value.DebugString(), "Map, size: 3"); + value = CelProtoWrapper::CreateMessage(&value_struct, arena()); + EXPECT_THAT( + value.DebugString(), + testing::AllOf(testing::StartsWith("CelMap: {"), + testing::HasSubstr(": "), + testing::HasSubstr(": : "))); + + // DebugString of a CelMap with an invalid internal list. + InvalidListKeysCelMapBuilder invalid_cel_map; + auto cel_map_value = CelValue::CreateMap(&invalid_cel_map); + EXPECT_EQ(cel_map_value.DebugString(), "CelMap: invalid list keys"); } -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace + +} // namespace google::api::expr::runtime diff --git a/eval/public/structs/dynamic_descriptor_pool_end_to_end_test.cc b/eval/public/structs/dynamic_descriptor_pool_end_to_end_test.cc new file mode 100644 index 000000000..ae04cead5 --- /dev/null +++ b/eval/public/structs/dynamic_descriptor_pool_end_to_end_test.cc @@ -0,0 +1,351 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "cel/expr/syntax.pb.h" +#include "google/protobuf/descriptor.pb.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "eval/public/activation.h" +#include "eval/public/builtin_func_registrar.h" +#include "eval/public/cel_expr_builder_factory.h" +#include "eval/public/cel_expression.h" +#include "eval/public/cel_options.h" +#include "eval/public/structs/cel_proto_descriptor_pool_builder.h" +#include "eval/public/structs/cel_proto_wrapper.h" +#include "eval/public/testing/matchers.h" +#include "internal/testing.h" +#include "parser/parser.h" +#include "cel/expr/conformance/proto3/test_all_types.pb.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/dynamic_message.h" +#include "google/protobuf/message.h" +#include "google/protobuf/text_format.h" +#include "google/protobuf/util/message_differencer.h" + +namespace google::api::expr::runtime { +namespace { + +using ::cel::expr::conformance::proto3::TestAllTypes; +using ::cel::expr::ParsedExpr; +using ::google::api::expr::parser::Parse; +using ::google::protobuf::DescriptorPool; + +constexpr int32_t kStartingFieldNumber = 600; +constexpr int32_t kIntFieldNumber = kStartingFieldNumber; +constexpr int32_t kStringFieldNumber = kStartingFieldNumber + 1; +constexpr int32_t kMessageFieldNumber = kStartingFieldNumber + 2; + +MATCHER_P(CelEqualsProto, msg, + absl::StrCat("CEL Equals ", msg->ShortDebugString())) { + const google::protobuf::Message* got = arg; + const google::protobuf::Message* want = msg; + + return google::protobuf::util::MessageDifferencer::Equals(*got, *want); +} + +// Simulate a dynamic descriptor pool with an alternate definition for a linked +// type. +absl::Status AddTestTypes(DescriptorPool& pool) { + google::protobuf::FileDescriptorProto file_descriptor; + + TestAllTypes::descriptor()->file()->CopyTo(&file_descriptor); + auto* message_type_entry = file_descriptor.mutable_message_type(0); + + auto* dynamic_int_field = message_type_entry->add_field(); + dynamic_int_field->set_number(kIntFieldNumber); + dynamic_int_field->set_name("dynamic_int_field"); + dynamic_int_field->set_type(google::protobuf::FieldDescriptorProto::TYPE_INT64); + auto* dynamic_string_field = message_type_entry->add_field(); + dynamic_string_field->set_number(kStringFieldNumber); + dynamic_string_field->set_name("dynamic_string_field"); + dynamic_string_field->set_type(google::protobuf::FieldDescriptorProto::TYPE_STRING); + auto* dynamic_message_field = message_type_entry->add_field(); + dynamic_message_field->set_number(kMessageFieldNumber); + dynamic_message_field->set_name("dynamic_message_field"); + dynamic_message_field->set_type(google::protobuf::FieldDescriptorProto::TYPE_MESSAGE); + dynamic_message_field->set_type_name( + ".cel.expr.conformance.proto3.TestAllTypes"); + + CEL_RETURN_IF_ERROR(AddStandardMessageTypesToDescriptorPool(pool)); + if (!pool.BuildFile(file_descriptor)) { + return absl::InternalError( + "failed initializing custom descriptor pool for test."); + } + + return absl::OkStatus(); +} + +class DynamicDescriptorPoolTest : public ::testing::Test { + public: + DynamicDescriptorPoolTest() : factory_(&descriptor_pool_) {} + + void SetUp() override { ASSERT_OK(AddTestTypes(descriptor_pool_)); } + + protected: + absl::StatusOr> CreateMessageFromText( + absl::string_view text_format) { + const google::protobuf::Descriptor* dynamic_desc = + descriptor_pool_.FindMessageTypeByName( + "cel.expr.conformance.proto3.TestAllTypes"); + auto message = absl::WrapUnique(factory_.GetPrototype(dynamic_desc)->New()); + + if (!google::protobuf::TextFormat::ParseFromString(text_format, message.get())) { + return absl::InvalidArgumentError( + "invalid text format for dynamic message"); + } + + return message; + } + + DescriptorPool descriptor_pool_; + google::protobuf::DynamicMessageFactory factory_; + google::protobuf::Arena arena_; +}; + +TEST_F(DynamicDescriptorPoolTest, FieldAccess) { + InterpreterOptions options; + std::unique_ptr builder = + CreateCelExpressionBuilder(&descriptor_pool_, &factory_, options); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr message, + CreateMessageFromText("dynamic_int_field: 42")); + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, Parse("msg.dynamic_int_field < 50")); + ASSERT_OK_AND_ASSIGN( + auto plan, builder->CreateExpression(&expr.expr(), &expr.source_info())); + + Activation act; + CelValue val = CelProtoWrapper::CreateMessage(message.get(), &arena_); + act.InsertValue("msg", val); + ASSERT_OK_AND_ASSIGN(CelValue result, plan->Evaluate(act, &arena_)); + + EXPECT_THAT(result, test::IsCelBool(true)); +} + +TEST_F(DynamicDescriptorPoolTest, Create) { + InterpreterOptions options; + std::unique_ptr builder = + CreateCelExpressionBuilder(&descriptor_pool_, &factory_, options); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); + builder->set_container("cel.expr.conformance.proto3"); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, Parse( + R"cel( + TestAllTypes{ + dynamic_int_field: 42, + dynamic_string_field: "string", + dynamic_message_field: TestAllTypes{dynamic_int_field: 50 } + } + )cel")); + ASSERT_OK_AND_ASSIGN( + auto plan, builder->CreateExpression(&expr.expr(), &expr.source_info())); + + Activation act; + ASSERT_OK_AND_ASSIGN(CelValue result, plan->Evaluate(act, &arena_)); + + ASSERT_OK_AND_ASSIGN(auto expected, CreateMessageFromText(R"pb( + dynamic_int_field: 42 + dynamic_string_field: "string" + dynamic_message_field { dynamic_int_field: 50 } + )pb")); + + EXPECT_THAT(result, test::IsCelMessage(CelEqualsProto(expected.get()))); +} + +TEST_F(DynamicDescriptorPoolTest, AnyUnpack) { + InterpreterOptions options; + std::unique_ptr builder = + CreateCelExpressionBuilder(&descriptor_pool_, &factory_, options); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); + + ASSERT_OK_AND_ASSIGN( + auto message, CreateMessageFromText(R"pb( + single_any { + [type.googleapis.com/cel.expr.conformance.proto3.TestAllTypes] { + dynamic_int_field: 45 + } + } + )pb")); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, + Parse("msg.single_any.dynamic_int_field < 50")); + ASSERT_OK_AND_ASSIGN( + auto plan, builder->CreateExpression(&expr.expr(), &expr.source_info())); + + Activation act; + CelValue val = CelProtoWrapper::CreateMessage(message.get(), &arena_); + act.InsertValue("msg", val); + ASSERT_OK_AND_ASSIGN(CelValue result, plan->Evaluate(act, &arena_)); + + EXPECT_THAT(result, test::IsCelBool(true)); +} + +TEST_F(DynamicDescriptorPoolTest, AnyWrapperUnpack) { + InterpreterOptions options; + std::unique_ptr builder = + CreateCelExpressionBuilder(&descriptor_pool_, &factory_, options); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); + + ASSERT_OK_AND_ASSIGN( + auto message, CreateMessageFromText(R"pb( + single_any { + [type.googleapis.com/google.protobuf.Int64Value] { value: 45 } + } + )pb")); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, Parse("msg.single_any < 50")); + ASSERT_OK_AND_ASSIGN( + auto plan, builder->CreateExpression(&expr.expr(), &expr.source_info())); + + Activation act; + CelValue val = CelProtoWrapper::CreateMessage(message.get(), &arena_); + act.InsertValue("msg", val); + ASSERT_OK_AND_ASSIGN(CelValue result, plan->Evaluate(act, &arena_)); + + EXPECT_THAT(result, test::IsCelBool(true)); +} + +TEST_F(DynamicDescriptorPoolTest, AnyUnpackRepeated) { + InterpreterOptions options; + std::unique_ptr builder = + CreateCelExpressionBuilder(&descriptor_pool_, &factory_, options); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); + + ASSERT_OK_AND_ASSIGN( + auto message, CreateMessageFromText(R"pb( + repeated_any { + [type.googleapis.com/cel.expr.conformance.proto3.TestAllTypes] { + dynamic_int_field: 0 + } + } + repeated_any { + [type.googleapis.com/cel.expr.conformance.proto3.TestAllTypes] { + dynamic_int_field: 1 + } + } + )pb")); + + ASSERT_OK_AND_ASSIGN( + ParsedExpr expr, + Parse("msg.repeated_any.exists(x, x.dynamic_int_field > 2)")); + ASSERT_OK_AND_ASSIGN( + auto plan, builder->CreateExpression(&expr.expr(), &expr.source_info())); + + Activation act; + CelValue val = CelProtoWrapper::CreateMessage(message.get(), &arena_); + act.InsertValue("msg", val); + ASSERT_OK_AND_ASSIGN(CelValue result, plan->Evaluate(act, &arena_)); + + EXPECT_THAT(result, test::IsCelBool(false)); +} + +TEST_F(DynamicDescriptorPoolTest, AnyPack) { + InterpreterOptions options; + std::unique_ptr builder = + CreateCelExpressionBuilder(&descriptor_pool_, &factory_, options); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); + builder->set_container("cel.expr.conformance.proto3"); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, Parse(R"cel( + TestAllTypes{ + single_any: TestAllTypes{dynamic_int_field: 42} + })cel")); + ASSERT_OK_AND_ASSIGN( + auto plan, builder->CreateExpression(&expr.expr(), &expr.source_info())); + + Activation act; + ASSERT_OK_AND_ASSIGN(CelValue result, plan->Evaluate(act, &arena_)); + + ASSERT_OK_AND_ASSIGN( + auto expected_message, CreateMessageFromText(R"pb( + single_any { + [type.googleapis.com/cel.expr.conformance.proto3.TestAllTypes] { + dynamic_int_field: 42 + } + } + )pb")); + EXPECT_THAT(result, + test::IsCelMessage(CelEqualsProto(expected_message.get()))); +} + +TEST_F(DynamicDescriptorPoolTest, AnyWrapperPack) { + InterpreterOptions options; + std::unique_ptr builder = + CreateCelExpressionBuilder(&descriptor_pool_, &factory_, options); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); + builder->set_container("cel.expr.conformance.proto3"); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, Parse(R"cel( + TestAllTypes{ + single_any: 42 + })cel")); + ASSERT_OK_AND_ASSIGN( + auto plan, builder->CreateExpression(&expr.expr(), &expr.source_info())); + + Activation act; + ASSERT_OK_AND_ASSIGN(CelValue result, plan->Evaluate(act, &arena_)); + + ASSERT_OK_AND_ASSIGN( + auto expected_message, CreateMessageFromText(R"pb( + single_any { + [type.googleapis.com/google.protobuf.Int64Value] { value: 42 } + } + )pb")); + EXPECT_THAT(result, + test::IsCelMessage(CelEqualsProto(expected_message.get()))); +} + +TEST_F(DynamicDescriptorPoolTest, AnyPackRepeated) { + InterpreterOptions options; + std::unique_ptr builder = + CreateCelExpressionBuilder(&descriptor_pool_, &factory_, options); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); + builder->set_container("cel.expr.conformance.proto3"); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, Parse(R"cel( + TestAllTypes{ + repeated_any: [ + TestAllTypes{dynamic_int_field: 0}, + TestAllTypes{dynamic_int_field: 1}, + ] + })cel")); + ASSERT_OK_AND_ASSIGN( + auto plan, builder->CreateExpression(&expr.expr(), &expr.source_info())); + + Activation act; + ASSERT_OK_AND_ASSIGN(CelValue result, plan->Evaluate(act, &arena_)); + + ASSERT_OK_AND_ASSIGN( + auto expected_message, CreateMessageFromText(R"pb( + repeated_any { + [type.googleapis.com/cel.expr.conformance.proto3.TestAllTypes] { + dynamic_int_field: 0 + } + } + repeated_any { + [type.googleapis.com/cel.expr.conformance.proto3.TestAllTypes] { + dynamic_int_field: 1 + } + } + )pb")); + EXPECT_THAT(result, + test::IsCelMessage(CelEqualsProto(expected_message.get()))); +} + +} // namespace +} // namespace google::api::expr::runtime diff --git a/eval/public/structs/field_access_impl.cc b/eval/public/structs/field_access_impl.cc new file mode 100644 index 000000000..2bd9fff9d --- /dev/null +++ b/eval/public/structs/field_access_impl.cc @@ -0,0 +1,746 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/structs/field_access_impl.h" + +#include +#include +#include +#include + +#include "google/protobuf/any.pb.h" +#include "google/protobuf/struct.pb.h" +#include "google/protobuf/wrappers.pb.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" +#include "eval/public/structs/cel_proto_wrap_util.h" +#include "internal/casts.h" +#include "internal/overflow.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/map_field.h" + +#undef GetMessage + +namespace google::api::expr::runtime::internal { + +namespace { + +using ::google::protobuf::Arena; +using ::google::protobuf::FieldDescriptor; +using ::google::protobuf::MapValueConstRef; +using ::google::protobuf::Message; +using ::google::protobuf::Reflection; + +// Singular message fields and repeated message fields have similar access model +// To provide common approach, we implement accessor classes, based on CRTP. +// FieldAccessor is CRTP base class, specifying Get.. method family. +template +class FieldAccessor { + public: + bool GetBool() const { return static_cast(this)->GetBool(); } + + int64_t GetInt32() const { + return static_cast(this)->GetInt32(); + } + + uint64_t GetUInt32() const { + return static_cast(this)->GetUInt32(); + } + + int64_t GetInt64() const { + return static_cast(this)->GetInt64(); + } + + uint64_t GetUInt64() const { + return static_cast(this)->GetUInt64(); + } + + double GetFloat() const { + return static_cast(this)->GetFloat(); + } + + double GetDouble() const { + return static_cast(this)->GetDouble(); + } + + absl::string_view GetString(std::string* buffer) const { + return static_cast(this)->GetString(buffer); + } + + const Message* GetMessage() const { + return static_cast(this)->GetMessage(); + } + + int64_t GetEnumValue() const { + return static_cast(this)->GetEnumValue(); + } + + // This method provides message field content, wrapped in CelValue. + // If value provided successfully, return a CelValue, otherwise returns a + // status with non-ok status code. + // + // arena Arena to use for allocations if needed. + absl::StatusOr CreateValueFromFieldAccessor(Arena* arena) { + switch (field_desc_->cpp_type()) { + case FieldDescriptor::CPPTYPE_BOOL: { + bool value = GetBool(); + return CelValue::CreateBool(value); + } + case FieldDescriptor::CPPTYPE_INT32: { + int64_t value = GetInt32(); + return CelValue::CreateInt64(value); + } + case FieldDescriptor::CPPTYPE_INT64: { + int64_t value = GetInt64(); + return CelValue::CreateInt64(value); + } + case FieldDescriptor::CPPTYPE_UINT32: { + uint64_t value = GetUInt32(); + return CelValue::CreateUint64(value); + } + case FieldDescriptor::CPPTYPE_UINT64: { + uint64_t value = GetUInt64(); + return CelValue::CreateUint64(value); + } + case FieldDescriptor::CPPTYPE_FLOAT: { + double value = GetFloat(); + return CelValue::CreateDouble(value); + } + case FieldDescriptor::CPPTYPE_DOUBLE: { + double value = GetDouble(); + return CelValue::CreateDouble(value); + } + case FieldDescriptor::CPPTYPE_STRING: { + std::string buffer; + absl::string_view value = GetString(&buffer); + if (value.data() == buffer.data() && value.size() == buffer.size()) { + value = absl::string_view( + *google::protobuf::Arena::Create(arena, std::move(buffer))); + } + switch (field_desc_->type()) { + case FieldDescriptor::TYPE_STRING: + return CelValue::CreateStringView(value); + case FieldDescriptor::TYPE_BYTES: + return CelValue::CreateBytesView(value); + default: + break; + } + break; + } + case FieldDescriptor::CPPTYPE_MESSAGE: { + const google::protobuf::Message* msg_value = GetMessage(); + return UnwrapMessageToValue(msg_value, protobuf_value_factory_, arena); + } + case FieldDescriptor::CPPTYPE_ENUM: { + int enum_value = GetEnumValue(); + return CelValue::CreateInt64(enum_value); + } + default: + break; + } + return absl::Status(absl::StatusCode::kInvalidArgument, + "Unhandled C++ type conversion"); + } + + protected: + FieldAccessor(const Message* msg, const FieldDescriptor* field_desc, + const ProtobufValueFactory& protobuf_value_factory) + : msg_(msg), + field_desc_(field_desc), + protobuf_value_factory_(protobuf_value_factory) {} + + const Message* msg_; + const FieldDescriptor* field_desc_; + const ProtobufValueFactory& protobuf_value_factory_; +}; + +const absl::flat_hash_set& WellKnownWrapperTypes() { + static auto* wrapper_types = new absl::flat_hash_set{ + "google.protobuf.BoolValue", "google.protobuf.DoubleValue", + "google.protobuf.FloatValue", "google.protobuf.Int64Value", + "google.protobuf.Int32Value", "google.protobuf.UInt64Value", + "google.protobuf.UInt32Value", "google.protobuf.StringValue", + "google.protobuf.BytesValue", + }; + return *wrapper_types; +} + +bool IsWrapperType(const FieldDescriptor* field_descriptor) { + return WellKnownWrapperTypes().find( + field_descriptor->message_type()->full_name()) != + WellKnownWrapperTypes().end(); +} + +// Accessor class, to work with singular fields +class ScalarFieldAccessor : public FieldAccessor { + public: + ScalarFieldAccessor(const Message* msg, const FieldDescriptor* field_desc, + bool unset_wrapper_as_null, + const ProtobufValueFactory& factory) + : FieldAccessor(msg, field_desc, factory), + unset_wrapper_as_null_(unset_wrapper_as_null) {} + + bool GetBool() const { return GetReflection()->GetBool(*msg_, field_desc_); } + + int64_t GetInt32() const { + return GetReflection()->GetInt32(*msg_, field_desc_); + } + + uint64_t GetUInt32() const { + return GetReflection()->GetUInt32(*msg_, field_desc_); + } + + int64_t GetInt64() const { + return GetReflection()->GetInt64(*msg_, field_desc_); + } + + uint64_t GetUInt64() const { + return GetReflection()->GetUInt64(*msg_, field_desc_); + } + + double GetFloat() const { + return GetReflection()->GetFloat(*msg_, field_desc_); + } + + double GetDouble() const { + return GetReflection()->GetDouble(*msg_, field_desc_); + } + + absl::string_view GetString(std::string* buffer) const { + return GetReflection()->GetStringReference(*msg_, field_desc_, buffer); + } + + const Message* GetMessage() const { + // Unset wrapper types have special semantics. + // If set, return the unwrapped value, else return 'null'. + if (unset_wrapper_as_null_ && + !GetReflection()->HasField(*msg_, field_desc_) && + IsWrapperType(field_desc_)) { + return nullptr; + } + return &GetReflection()->GetMessage(*msg_, field_desc_); + } + + int64_t GetEnumValue() const { + return GetReflection()->GetEnumValue(*msg_, field_desc_); + } + + const Reflection* GetReflection() const { return msg_->GetReflection(); } + + private: + bool unset_wrapper_as_null_; +}; + +// Accessor class, to work with repeated fields. +class RepeatedFieldAccessor : public FieldAccessor { + public: + RepeatedFieldAccessor(const Message* msg, const FieldDescriptor* field_desc, + int index, const ProtobufValueFactory& factory) + : FieldAccessor(msg, field_desc, factory), index_(index) {} + + bool GetBool() const { + return GetReflection()->GetRepeatedBool(*msg_, field_desc_, index_); + } + + int64_t GetInt32() const { + return GetReflection()->GetRepeatedInt32(*msg_, field_desc_, index_); + } + + uint64_t GetUInt32() const { + return GetReflection()->GetRepeatedUInt32(*msg_, field_desc_, index_); + } + + int64_t GetInt64() const { + return GetReflection()->GetRepeatedInt64(*msg_, field_desc_, index_); + } + + uint64_t GetUInt64() const { + return GetReflection()->GetRepeatedUInt64(*msg_, field_desc_, index_); + } + + double GetFloat() const { + return GetReflection()->GetRepeatedFloat(*msg_, field_desc_, index_); + } + + double GetDouble() const { + return GetReflection()->GetRepeatedDouble(*msg_, field_desc_, index_); + } + + absl::string_view GetString(std::string* buffer) const { + return GetReflection()->GetRepeatedStringReference(*msg_, field_desc_, + index_, buffer); + } + + const Message* GetMessage() const { + return &GetReflection()->GetRepeatedMessage(*msg_, field_desc_, index_); + } + + int64_t GetEnumValue() const { + return GetReflection()->GetRepeatedEnumValue(*msg_, field_desc_, index_); + } + + const Reflection* GetReflection() const { return msg_->GetReflection(); } + + private: + int index_; +}; + +// Accessor class, to work with map values +class MapValueAccessor : public FieldAccessor { + public: + MapValueAccessor(const Message* msg, const FieldDescriptor* field_desc, + const MapValueConstRef* value_ref, + const ProtobufValueFactory& factory) + : FieldAccessor(msg, field_desc, factory), value_ref_(value_ref) {} + + bool GetBool() const { return value_ref_->GetBoolValue(); } + + int64_t GetInt32() const { return value_ref_->GetInt32Value(); } + + uint64_t GetUInt32() const { return value_ref_->GetUInt32Value(); } + + int64_t GetInt64() const { return value_ref_->GetInt64Value(); } + + uint64_t GetUInt64() const { return value_ref_->GetUInt64Value(); } + + double GetFloat() const { return value_ref_->GetFloatValue(); } + + double GetDouble() const { return value_ref_->GetDoubleValue(); } + + absl::string_view GetString(std::string* /*buffer*/) const { + return value_ref_->GetStringValue(); + } + + const Message* GetMessage() const { return &value_ref_->GetMessageValue(); } + + int64_t GetEnumValue() const { return value_ref_->GetEnumValue(); } + + const Reflection* GetReflection() const { return msg_->GetReflection(); } + + private: + const MapValueConstRef* value_ref_; +}; + +// Singular message fields and repeated message fields have similar access model +// To provide common approach, we implement field setter classes, based on CRTP. +// FieldAccessor is CRTP base class, specifying Get.. method family. +template +class FieldSetter { + public: + bool AssignBool(const CelValue& cel_value) const { + bool value; + + if (!cel_value.GetValue(&value)) { + return false; + } + static_cast(this)->SetBool(value); + return true; + } + + bool AssignInt32(const CelValue& cel_value) const { + int64_t value; + if (!cel_value.GetValue(&value)) { + return false; + } + absl::StatusOr checked_cast = + cel::internal::CheckedInt64ToInt32(value); + if (!checked_cast.ok()) { + return false; + } + static_cast(this)->SetInt32(*checked_cast); + return true; + } + + bool AssignUInt32(const CelValue& cel_value) const { + uint64_t value; + if (!cel_value.GetValue(&value)) { + return false; + } + if (!cel::internal::CheckedUint64ToUint32(value).ok()) { + return false; + } + static_cast(this)->SetUInt32(value); + return true; + } + + bool AssignInt64(const CelValue& cel_value) const { + int64_t value; + if (!cel_value.GetValue(&value)) { + return false; + } + static_cast(this)->SetInt64(value); + return true; + } + + bool AssignUInt64(const CelValue& cel_value) const { + uint64_t value; + if (!cel_value.GetValue(&value)) { + return false; + } + static_cast(this)->SetUInt64(value); + return true; + } + + bool AssignFloat(const CelValue& cel_value) const { + double value; + if (!cel_value.GetValue(&value)) { + return false; + } + static_cast(this)->SetFloat(value); + return true; + } + + bool AssignDouble(const CelValue& cel_value) const { + double value; + if (!cel_value.GetValue(&value)) { + return false; + } + static_cast(this)->SetDouble(value); + return true; + } + + bool AssignString(const CelValue& cel_value) const { + CelValue::StringHolder value; + if (!cel_value.GetValue(&value)) { + return false; + } + static_cast(this)->SetString(value); + return true; + } + + bool AssignBytes(const CelValue& cel_value) const { + CelValue::BytesHolder value; + if (!cel_value.GetValue(&value)) { + return false; + } + static_cast(this)->SetBytes(value); + return true; + } + + bool AssignEnum(const CelValue& cel_value) const { + int64_t value; + if (!cel_value.GetValue(&value)) { + return false; + } + if (!cel::internal::CheckedInt64ToInt32(value).ok()) { + return false; + } + static_cast(this)->SetEnum(value); + return true; + } + + bool AssignMessage(const google::protobuf::Message* message) const { + return static_cast(this)->SetMessage(message); + } + + // This method provides message field content, wrapped in CelValue. + // If value provided successfully, returns Ok. + // arena Arena to use for allocations if needed. + // result pointer to object to store value in. + bool SetFieldFromCelValue(const CelValue& value) { + switch (field_desc_->cpp_type()) { + case FieldDescriptor::CPPTYPE_BOOL: { + return AssignBool(value); + } + case FieldDescriptor::CPPTYPE_INT32: { + return AssignInt32(value); + } + case FieldDescriptor::CPPTYPE_INT64: { + return AssignInt64(value); + } + case FieldDescriptor::CPPTYPE_UINT32: { + return AssignUInt32(value); + } + case FieldDescriptor::CPPTYPE_UINT64: { + return AssignUInt64(value); + } + case FieldDescriptor::CPPTYPE_FLOAT: { + return AssignFloat(value); + } + case FieldDescriptor::CPPTYPE_DOUBLE: { + return AssignDouble(value); + } + case FieldDescriptor::CPPTYPE_STRING: { + switch (field_desc_->type()) { + case FieldDescriptor::TYPE_STRING: + + return AssignString(value); + case FieldDescriptor::TYPE_BYTES: + return AssignBytes(value); + default: + return false; + } + break; + } + case FieldDescriptor::CPPTYPE_MESSAGE: { + // When the field is a message, it might be a well-known type with a + // non-proto representation that requires special handling before it + // can be set on the field. + const google::protobuf::Message* wrapped_value = MaybeWrapValueToMessage( + field_desc_->message_type(), + msg_->GetReflection()->GetMessageFactory(), value, arena_); + if (wrapped_value == nullptr) { + // It we aren't unboxing to a protobuf null representation, setting a + // field to null is a no-op. + if (value.IsNull()) { + return true; + } + if (CelValue::MessageWrapper wrapper; + value.GetValue(&wrapper) && wrapper.HasFullProto()) { + wrapped_value = + static_cast(wrapper.message_ptr()); + } else { + return false; + } + } + + return AssignMessage(wrapped_value); + } + case FieldDescriptor::CPPTYPE_ENUM: { + return AssignEnum(value); + } + default: + return false; + } + + return true; + } + + protected: + FieldSetter(Message* msg, const FieldDescriptor* field_desc, Arena* arena) + : msg_(msg), field_desc_(field_desc), arena_(arena) {} + + Message* msg_; + const FieldDescriptor* field_desc_; + Arena* arena_; +}; + +bool MergeFromWithSerializeFallback(const google::protobuf::Message& value, + google::protobuf::Message& field) { + if (field.GetDescriptor() == value.GetDescriptor()) { + field.MergeFrom(value); + return true; + } + // TODO(uncreated-issue/26): this indicates means we're mixing dynamic messages with + // generated messages. This is expected for WKTs where CEL explicitly requires + // wire format compatibility, but this may not be the expected behavior for + // other types. + return field.MergeFromString(value.SerializeAsString()); +} + +// Accessor class, to work with singular fields +class ScalarFieldSetter : public FieldSetter { + public: + ScalarFieldSetter(Message* msg, const FieldDescriptor* field_desc, + Arena* arena) + : FieldSetter(msg, field_desc, arena) {} + + bool SetBool(bool value) const { + GetReflection()->SetBool(msg_, field_desc_, value); + return true; + } + + bool SetInt32(int32_t value) const { + GetReflection()->SetInt32(msg_, field_desc_, value); + return true; + } + + bool SetUInt32(uint32_t value) const { + GetReflection()->SetUInt32(msg_, field_desc_, value); + return true; + } + + bool SetInt64(int64_t value) const { + GetReflection()->SetInt64(msg_, field_desc_, value); + return true; + } + + bool SetUInt64(uint64_t value) const { + GetReflection()->SetUInt64(msg_, field_desc_, value); + return true; + } + + bool SetFloat(float value) const { + GetReflection()->SetFloat(msg_, field_desc_, value); + return true; + } + + bool SetDouble(double value) const { + GetReflection()->SetDouble(msg_, field_desc_, value); + return true; + } + + bool SetString(CelValue::StringHolder value) const { + GetReflection()->SetString(msg_, field_desc_, std::string(value.value())); + return true; + } + + bool SetBytes(CelValue::BytesHolder value) const { + GetReflection()->SetString(msg_, field_desc_, std::string(value.value())); + return true; + } + + bool SetMessage(const Message* value) const { + if (!value) { + ABSL_LOG(ERROR) << "Message is NULL"; + return true; + } + if (value->GetDescriptor()->full_name() == + field_desc_->message_type()->full_name()) { + auto* assignable_field_msg = + GetReflection()->MutableMessage(msg_, field_desc_); + return MergeFromWithSerializeFallback(*value, *assignable_field_msg); + } + + return false; + } + + bool SetEnum(const int64_t value) const { + GetReflection()->SetEnumValue(msg_, field_desc_, value); + return true; + } + + const Reflection* GetReflection() const { return msg_->GetReflection(); } +}; + +// Appender class, to work with repeated fields +class RepeatedFieldSetter : public FieldSetter { + public: + RepeatedFieldSetter(Message* msg, const FieldDescriptor* field_desc, + Arena* arena) + : FieldSetter(msg, field_desc, arena) {} + + bool SetBool(bool value) const { + GetReflection()->AddBool(msg_, field_desc_, value); + return true; + } + + bool SetInt32(int32_t value) const { + GetReflection()->AddInt32(msg_, field_desc_, value); + return true; + } + + bool SetUInt32(uint32_t value) const { + GetReflection()->AddUInt32(msg_, field_desc_, value); + return true; + } + + bool SetInt64(int64_t value) const { + GetReflection()->AddInt64(msg_, field_desc_, value); + return true; + } + + bool SetUInt64(uint64_t value) const { + GetReflection()->AddUInt64(msg_, field_desc_, value); + return true; + } + + bool SetFloat(float value) const { + GetReflection()->AddFloat(msg_, field_desc_, value); + return true; + } + + bool SetDouble(double value) const { + GetReflection()->AddDouble(msg_, field_desc_, value); + return true; + } + + bool SetString(CelValue::StringHolder value) const { + GetReflection()->AddString(msg_, field_desc_, std::string(value.value())); + return true; + } + + bool SetBytes(CelValue::BytesHolder value) const { + GetReflection()->AddString(msg_, field_desc_, std::string(value.value())); + return true; + } + + bool SetMessage(const Message* value) const { + if (!value) return true; + if (value->GetDescriptor()->full_name() != + field_desc_->message_type()->full_name()) { + return false; + } + + auto* assignable_message = GetReflection()->AddMessage(msg_, field_desc_); + return MergeFromWithSerializeFallback(*value, *assignable_message); + } + + bool SetEnum(const int64_t value) const { + GetReflection()->AddEnumValue(msg_, field_desc_, value); + return true; + } + + private: + const Reflection* GetReflection() const { return msg_->GetReflection(); } +}; + +} // namespace + +absl::StatusOr CreateValueFromSingleField( + const google::protobuf::Message* msg, const FieldDescriptor* desc, + ProtoWrapperTypeOptions options, const ProtobufValueFactory& factory, + google::protobuf::Arena* arena) { + ScalarFieldAccessor accessor( + msg, desc, (options == ProtoWrapperTypeOptions::kUnsetNull), factory); + return accessor.CreateValueFromFieldAccessor(arena); +} + +absl::StatusOr CreateValueFromRepeatedField( + const google::protobuf::Message* msg, const FieldDescriptor* desc, int index, + const ProtobufValueFactory& factory, google::protobuf::Arena* arena) { + RepeatedFieldAccessor accessor(msg, desc, index, factory); + return accessor.CreateValueFromFieldAccessor(arena); +} + +absl::StatusOr CreateValueFromMapValue( + const google::protobuf::Message* msg, const FieldDescriptor* desc, + const MapValueConstRef* value_ref, const ProtobufValueFactory& factory, + google::protobuf::Arena* arena) { + MapValueAccessor accessor(msg, desc, value_ref, factory); + return accessor.CreateValueFromFieldAccessor(arena); +} + +absl::Status SetValueToSingleField(const CelValue& value, + const FieldDescriptor* desc, Message* msg, + Arena* arena) { + ScalarFieldSetter setter(msg, desc, arena); + return (setter.SetFieldFromCelValue(value)) + ? absl::OkStatus() + : absl::InvalidArgumentError(absl::Substitute( + "Could not assign supplied argument to message \"$0\" field " + "\"$1\" of type $2: value type \"$3\"", + msg->GetDescriptor()->name(), desc->name(), + desc->type_name(), CelValue::TypeName(value.type()))); +} + +absl::Status AddValueToRepeatedField(const CelValue& value, + const FieldDescriptor* desc, Message* msg, + Arena* arena) { + RepeatedFieldSetter setter(msg, desc, arena); + return (setter.SetFieldFromCelValue(value)) + ? absl::OkStatus() + : absl::InvalidArgumentError(absl::Substitute( + "Could not add supplied argument to message \"$0\" field " + "\"$1\" of type $2: value type \"$3\"", + msg->GetDescriptor()->name(), desc->name(), + desc->type_name(), CelValue::TypeName(value.type()))); +} + +} // namespace google::api::expr::runtime::internal diff --git a/eval/public/structs/field_access_impl.h b/eval/public/structs/field_access_impl.h new file mode 100644 index 000000000..78e22e5ba --- /dev/null +++ b/eval/public/structs/field_access_impl.h @@ -0,0 +1,80 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_FIELD_ACCESS_IMPL_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_FIELD_ACCESS_IMPL_H_ + +#include "eval/public/cel_options.h" +#include "eval/public/cel_value.h" +#include "eval/public/structs/protobuf_value_factory.h" + +namespace google::api::expr::runtime::internal { + +// Creates CelValue from singular message field. +// Returns status of the operation. +// msg Message containing the field. +// desc Descriptor of the field to access. +// options Option to enable treating unset wrapper type fields as null. +// arena Arena object to allocate result on, if needed. +// result pointer to CelValue to store the result in. +absl::StatusOr CreateValueFromSingleField( + const google::protobuf::Message* msg, const google::protobuf::FieldDescriptor* desc, + ProtoWrapperTypeOptions options, const ProtobufValueFactory& factory, + google::protobuf::Arena* arena); + +// Creates CelValue from repeated message field. +// Returns status of the operation. +// msg Message containing the field. +// desc Descriptor of the field to access. +// arena Arena object to allocate result on, if needed. +// index position in the repeated field. +absl::StatusOr CreateValueFromRepeatedField( + const google::protobuf::Message* msg, const google::protobuf::FieldDescriptor* desc, int index, + const ProtobufValueFactory& factory, google::protobuf::Arena* arena); + +// Creates CelValue from map message field. +// Returns status of the operation. +// msg Message containing the field. +// desc Descriptor of the field to access. +// value_ref pointer to map value. +// arena Arena object to allocate result on, if needed. +// TODO(uncreated-issue/7): This should be inlined into the FieldBackedMap +// implementation. +absl::StatusOr CreateValueFromMapValue( + const google::protobuf::Message* msg, const google::protobuf::FieldDescriptor* desc, + const google::protobuf::MapValueConstRef* value_ref, + const ProtobufValueFactory& factory, google::protobuf::Arena* arena); + +// Assigns content of CelValue to singular message field. +// Returns status of the operation. +// msg Message containing the field. +// desc Descriptor of the field to access. +// arena Arena to perform allocations, if necessary, when setting the field. +absl::Status SetValueToSingleField(const CelValue& value, + const google::protobuf::FieldDescriptor* desc, + google::protobuf::Message* msg, google::protobuf::Arena* arena); + +// Adds content of CelValue to repeated message field. +// Returns status of the operation. +// msg Message containing the field. +// desc Descriptor of the field to access. +// arena Arena to perform allocations, if necessary, when adding the value. +absl::Status AddValueToRepeatedField(const CelValue& value, + const google::protobuf::FieldDescriptor* desc, + google::protobuf::Message* msg, + google::protobuf::Arena* arena); + +} // namespace google::api::expr::runtime::internal + +#endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_FIELD_ACCESS_IMPL_H_ diff --git a/eval/public/structs/field_access_impl_benchmark_test.cc b/eval/public/structs/field_access_impl_benchmark_test.cc new file mode 100644 index 000000000..888e424b1 --- /dev/null +++ b/eval/public/structs/field_access_impl_benchmark_test.cc @@ -0,0 +1,239 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/cel_options.h" +#include "eval/public/cel_value.h" +#include "eval/public/structs/cel_proto_wrapper.h" +#include "eval/public/structs/field_access_impl.h" +#include "extensions/protobuf/internal/map_reflection.h" +#include "internal/benchmark.h" +#include "cel/expr/conformance/proto3/test_all_types.pb.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/map_field.h" +#include "google/protobuf/message.h" + +namespace google::api::expr::runtime::internal { +namespace { + +using ::cel::expr::conformance::proto3::TestAllTypes; + +void BM_CreateValueFromSingleField_Int64(benchmark::State& state) { + google::protobuf::Arena arena; + TestAllTypes msg; + msg.set_single_int64(42); + const google::protobuf::FieldDescriptor* desc = + TestAllTypes::descriptor()->FindFieldByName("single_int64"); + + for (auto _ : state) { + auto value = CreateValueFromSingleField( + &msg, desc, ProtoWrapperTypeOptions::kUnsetProtoDefault, + &CelProtoWrapper::InternalWrapMessage, &arena); + benchmark::DoNotOptimize(value); + } +} +BENCHMARK(BM_CreateValueFromSingleField_Int64); + +void BM_CreateValueFromSingleField_String(benchmark::State& state) { + google::protobuf::Arena arena; + TestAllTypes msg; + msg.set_single_string("hello world"); + const google::protobuf::FieldDescriptor* desc = + TestAllTypes::descriptor()->FindFieldByName("single_string"); + + for (auto _ : state) { + auto value = CreateValueFromSingleField( + &msg, desc, ProtoWrapperTypeOptions::kUnsetProtoDefault, + &CelProtoWrapper::InternalWrapMessage, &arena); + benchmark::DoNotOptimize(value); + } +} +BENCHMARK(BM_CreateValueFromSingleField_String); + +void BM_CreateValueFromSingleField_Message(benchmark::State& state) { + google::protobuf::Arena arena; + TestAllTypes msg; + msg.mutable_standalone_message()->set_bb(123); + const google::protobuf::FieldDescriptor* desc = + TestAllTypes::descriptor()->FindFieldByName("standalone_message"); + + for (auto _ : state) { + auto value = CreateValueFromSingleField( + &msg, desc, ProtoWrapperTypeOptions::kUnsetProtoDefault, + &CelProtoWrapper::InternalWrapMessage, &arena); + benchmark::DoNotOptimize(value); + } +} +BENCHMARK(BM_CreateValueFromSingleField_Message); + +void BM_CreateValueFromRepeatedField_Int64(benchmark::State& state) { + google::protobuf::Arena arena; + TestAllTypes msg; + msg.add_repeated_int64(42); + const google::protobuf::FieldDescriptor* desc = + TestAllTypes::descriptor()->FindFieldByName("repeated_int64"); + + for (auto _ : state) { + auto value = CreateValueFromRepeatedField( + &msg, desc, 0, &CelProtoWrapper::InternalWrapMessage, &arena); + benchmark::DoNotOptimize(value); + } +} +BENCHMARK(BM_CreateValueFromRepeatedField_Int64); + +void BM_CreateValueFromRepeatedField_String(benchmark::State& state) { + google::protobuf::Arena arena; + TestAllTypes msg; + msg.add_repeated_string("hello world"); + const google::protobuf::FieldDescriptor* desc = + TestAllTypes::descriptor()->FindFieldByName("repeated_string"); + + for (auto _ : state) { + auto value = CreateValueFromRepeatedField( + &msg, desc, 0, &CelProtoWrapper::InternalWrapMessage, &arena); + benchmark::DoNotOptimize(value); + } +} +BENCHMARK(BM_CreateValueFromRepeatedField_String); + +void BM_CreateValueFromMapValue_Int64(benchmark::State& state) { + google::protobuf::Arena arena; + TestAllTypes msg; + (*msg.mutable_map_int64_int64())[42] = 100; + const google::protobuf::FieldDescriptor* map_desc = + TestAllTypes::descriptor()->FindFieldByName("map_int64_int64"); + const google::protobuf::FieldDescriptor* value_desc = + map_desc->message_type()->FindFieldByName("value"); + + google::protobuf::ConstMapIterator iter = + cel::extensions::protobuf_internal::ConstMapBegin(*msg.GetReflection(), + msg, *map_desc); + google::protobuf::MapValueConstRef value_ref = iter.GetValueRef(); + + for (auto _ : state) { + auto value = + CreateValueFromMapValue(&msg, value_desc, &value_ref, + &CelProtoWrapper::InternalWrapMessage, &arena); + benchmark::DoNotOptimize(value); + } +} +BENCHMARK(BM_CreateValueFromMapValue_Int64); + +void BM_SetValueToSingleField_Int64(benchmark::State& state) { + google::protobuf::Arena arena; + TestAllTypes msg; + const google::protobuf::FieldDescriptor* desc = + TestAllTypes::descriptor()->FindFieldByName("single_int64"); + CelValue val = CelValue::CreateInt64(42); + + for (auto _ : state) { + auto status = SetValueToSingleField(val, desc, &msg, &arena); + benchmark::DoNotOptimize(status); + } +} +BENCHMARK(BM_SetValueToSingleField_Int64); + +void BM_SetValueToSingleField_String(benchmark::State& state) { + google::protobuf::Arena arena; + TestAllTypes msg; + const google::protobuf::FieldDescriptor* desc = + TestAllTypes::descriptor()->FindFieldByName("single_string"); + CelValue val = CelValue::CreateStringView("hello world"); + + for (auto _ : state) { + auto status = SetValueToSingleField(val, desc, &msg, &arena); + benchmark::DoNotOptimize(status); + } +} +BENCHMARK(BM_SetValueToSingleField_String); + +void BM_SetValueToSingleField_Message(benchmark::State& state) { + google::protobuf::Arena arena; + TestAllTypes msg; + const google::protobuf::FieldDescriptor* desc = + TestAllTypes::descriptor()->FindFieldByName("standalone_message"); + + TestAllTypes::NestedMessage nested_msg; + nested_msg.set_bb(123); + CelValue val = CelProtoWrapper::CreateMessage(&nested_msg, &arena); + + for (auto _ : state) { + auto status = SetValueToSingleField(val, desc, &msg, &arena); + benchmark::DoNotOptimize(status); + } +} +BENCHMARK(BM_SetValueToSingleField_Message); + +void BM_AddValueToRepeatedField_Int64(benchmark::State& state) { + google::protobuf::Arena arena; + TestAllTypes msg; + const google::protobuf::FieldDescriptor* desc = + TestAllTypes::descriptor()->FindFieldByName("repeated_int64"); + CelValue val = CelValue::CreateInt64(42); + + for (auto _ : state) { + msg.clear_repeated_int64(); + auto status = AddValueToRepeatedField(val, desc, &msg, &arena); + benchmark::DoNotOptimize(status); + } +} +BENCHMARK(BM_AddValueToRepeatedField_Int64); + +void BM_AddValueToRepeatedField_String(benchmark::State& state) { + google::protobuf::Arena arena; + TestAllTypes msg; + const google::protobuf::FieldDescriptor* desc = + TestAllTypes::descriptor()->FindFieldByName("repeated_string"); + CelValue val = CelValue::CreateStringView("hello world"); + + for (auto _ : state) { + msg.clear_repeated_string(); + auto status = AddValueToRepeatedField(val, desc, &msg, &arena); + benchmark::DoNotOptimize(status); + } +} +BENCHMARK(BM_AddValueToRepeatedField_String); + +void BM_CreateValueFromRepeatedField_StringPiece(benchmark::State& state) { + google::protobuf::Arena arena; + TestAllTypes msg; + msg.add_repeated_string_piece("hello world"); + const google::protobuf::FieldDescriptor* desc = + TestAllTypes::descriptor()->FindFieldByName("repeated_string_piece"); + + for (auto _ : state) { + auto value = CreateValueFromRepeatedField( + &msg, desc, 0, &CelProtoWrapper::InternalWrapMessage, &arena); + benchmark::DoNotOptimize(value); + } +} +BENCHMARK(BM_CreateValueFromRepeatedField_StringPiece); + +void BM_AddValueToRepeatedField_StringPiece(benchmark::State& state) { + google::protobuf::Arena arena; + TestAllTypes msg; + const google::protobuf::FieldDescriptor* desc = + TestAllTypes::descriptor()->FindFieldByName("repeated_string_piece"); + CelValue val = CelValue::CreateStringView("hello world"); + + for (auto _ : state) { + msg.clear_repeated_string_piece(); + auto status = AddValueToRepeatedField(val, desc, &msg, &arena); + benchmark::DoNotOptimize(status); + } +} +BENCHMARK(BM_AddValueToRepeatedField_StringPiece); + +} // namespace +} // namespace google::api::expr::runtime::internal diff --git a/eval/public/structs/field_access_impl_test.cc b/eval/public/structs/field_access_impl_test.cc new file mode 100644 index 000000000..d7e6827c6 --- /dev/null +++ b/eval/public/structs/field_access_impl_test.cc @@ -0,0 +1,648 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/structs/field_access_impl.h" + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "eval/public/cel_value.h" +#include "eval/public/structs/cel_proto_wrapper.h" +#include "eval/public/testing/matchers.h" +#include "eval/testutil/test_message.pb.h" +#include "internal/testing.h" +#include "internal/time.h" +#include "testutil/util.h" +#include "cel/expr/conformance/proto3/test_all_types.pb.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" +#include "google/protobuf/text_format.h" + +namespace google::api::expr::runtime::internal { + +namespace { + +using ::absl_testing::StatusIs; +using ::cel::expr::conformance::proto3::TestAllTypes; +using ::cel::internal::MaxDuration; +using ::cel::internal::MaxTimestamp; +using ::google::protobuf::Arena; +using ::google::protobuf::FieldDescriptor; +using ::testing::HasSubstr; +using testutil::EqualsProto; + +TEST(FieldAccessTest, SetDuration) { + Arena arena; + TestAllTypes msg; + const FieldDescriptor* field = + TestAllTypes::descriptor()->FindFieldByName("single_duration"); + auto status = SetValueToSingleField(CelValue::CreateDuration(MaxDuration()), + field, &msg, &arena); + EXPECT_TRUE(status.ok()); +} + +TEST(FieldAccessTest, SetDurationBadDuration) { + Arena arena; + TestAllTypes msg; + const FieldDescriptor* field = + TestAllTypes::descriptor()->FindFieldByName("single_duration"); + auto status = SetValueToSingleField( + CelValue::CreateDuration(MaxDuration() + absl::Seconds(1)), field, &msg, + &arena); + EXPECT_FALSE(status.ok()); + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); +} + +TEST(FieldAccessTest, SetDurationBadInputType) { + Arena arena; + TestAllTypes msg; + const FieldDescriptor* field = + TestAllTypes::descriptor()->FindFieldByName("single_duration"); + auto status = + SetValueToSingleField(CelValue::CreateInt64(1), field, &msg, &arena); + EXPECT_FALSE(status.ok()); + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); +} + +TEST(FieldAccessTest, SetTimestamp) { + Arena arena; + TestAllTypes msg; + const FieldDescriptor* field = + TestAllTypes::descriptor()->FindFieldByName("single_timestamp"); + auto status = SetValueToSingleField(CelValue::CreateTimestamp(MaxTimestamp()), + field, &msg, &arena); + EXPECT_TRUE(status.ok()); +} + +TEST(FieldAccessTest, SetTimestampBadTime) { + Arena arena; + TestAllTypes msg; + const FieldDescriptor* field = + TestAllTypes::descriptor()->FindFieldByName("single_timestamp"); + auto status = SetValueToSingleField( + CelValue::CreateTimestamp(MaxTimestamp() + absl::Seconds(1)), field, &msg, + &arena); + EXPECT_FALSE(status.ok()); + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); +} + +TEST(FieldAccessTest, SetTimestampBadInputType) { + Arena arena; + TestAllTypes msg; + const FieldDescriptor* field = + TestAllTypes::descriptor()->FindFieldByName("single_timestamp"); + auto status = + SetValueToSingleField(CelValue::CreateInt64(1), field, &msg, &arena); + EXPECT_FALSE(status.ok()); + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); +} + +TEST(FieldAccessTest, SetInt32Overflow) { + Arena arena; + TestAllTypes msg; + const FieldDescriptor* field = + TestAllTypes::descriptor()->FindFieldByName("single_int32"); + EXPECT_THAT( + SetValueToSingleField( + CelValue::CreateInt64(std::numeric_limits::max() + 1L), + field, &msg, &arena), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Could not assign"))); +} + +TEST(FieldAccessTest, SetUint32Overflow) { + Arena arena; + TestAllTypes msg; + const FieldDescriptor* field = + TestAllTypes::descriptor()->FindFieldByName("single_uint32"); + EXPECT_THAT( + SetValueToSingleField( + CelValue::CreateUint64(std::numeric_limits::max() + 1L), + field, &msg, &arena), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Could not assign"))); +} + +TEST(FieldAccessTest, SetMessage) { + Arena arena; + TestAllTypes msg; + const FieldDescriptor* field = + TestAllTypes::descriptor()->FindFieldByName("standalone_message"); + TestAllTypes::NestedMessage* nested_msg = + google::protobuf::Arena::Create(&arena); + nested_msg->set_bb(1); + auto status = SetValueToSingleField( + CelProtoWrapper::CreateMessage(nested_msg, &arena), field, &msg, &arena); + EXPECT_TRUE(status.ok()); +} + +TEST(FieldAccessTest, SetMessageWithNull) { + Arena arena; + TestAllTypes msg; + const FieldDescriptor* field = + TestAllTypes::descriptor()->FindFieldByName("standalone_message"); + auto status = + SetValueToSingleField(CelValue::CreateNull(), field, &msg, &arena); + EXPECT_TRUE(status.ok()); +} + +struct AccessFieldTestParam { + absl::string_view field_name; + absl::string_view message_textproto; + CelValue cel_value; +}; + +std::string GetTestName( + const testing::TestParamInfo& info) { + return std::string(info.param.field_name); +} + +class SingleFieldTest : public testing::TestWithParam { + public: + absl::string_view field_name() const { return GetParam().field_name; } + absl::string_view message_textproto() const { + return GetParam().message_textproto; + } + CelValue cel_value() const { return GetParam().cel_value; } +}; + +TEST_P(SingleFieldTest, Getter) { + TestAllTypes test_message; + ASSERT_TRUE( + google::protobuf::TextFormat::ParseFromString(message_textproto(), &test_message)); + google::protobuf::Arena arena; + + ASSERT_OK_AND_ASSIGN( + CelValue accessed_value, + CreateValueFromSingleField( + &test_message, + test_message.GetDescriptor()->FindFieldByName(field_name()), + ProtoWrapperTypeOptions::kUnsetProtoDefault, + &CelProtoWrapper::InternalWrapMessage, &arena)); + + EXPECT_THAT(accessed_value, test::EqualsCelValue(cel_value())); +} + +TEST_P(SingleFieldTest, Setter) { + TestAllTypes test_message; + CelValue to_set = cel_value(); + google::protobuf::Arena arena; + + ASSERT_OK(SetValueToSingleField( + to_set, test_message.GetDescriptor()->FindFieldByName(field_name()), + &test_message, &arena)); + + EXPECT_THAT(test_message, EqualsProto(message_textproto())); +} + +INSTANTIATE_TEST_SUITE_P( + AllTypes, SingleFieldTest, + testing::ValuesIn({ + {"single_int32", "single_int32: 1", CelValue::CreateInt64(1)}, + {"single_int64", "single_int64: 1", CelValue::CreateInt64(1)}, + {"single_uint32", "single_uint32: 1", CelValue::CreateUint64(1)}, + {"single_uint64", "single_uint64: 1", CelValue::CreateUint64(1)}, + {"single_sint32", "single_sint32: 1", CelValue::CreateInt64(1)}, + {"single_sint64", "single_sint64: 1", CelValue::CreateInt64(1)}, + {"single_fixed32", "single_fixed32: 1", CelValue::CreateUint64(1)}, + {"single_fixed64", "single_fixed64: 1", CelValue::CreateUint64(1)}, + {"single_sfixed32", "single_sfixed32: 1", CelValue::CreateInt64(1)}, + {"single_sfixed64", "single_sfixed64: 1", CelValue::CreateInt64(1)}, + {"single_float", "single_float: 1.0", CelValue::CreateDouble(1.0)}, + {"single_double", "single_double: 1.0", CelValue::CreateDouble(1.0)}, + {"single_bool", "single_bool: true", CelValue::CreateBool(true)}, + {"single_string", "single_string: 'abcd'", + CelValue::CreateStringView("abcd")}, + {"single_bytes", "single_bytes: 'asdf'", + CelValue::CreateBytesView("asdf")}, + {"standalone_enum", "standalone_enum: BAZ", CelValue::CreateInt64(2)}, + // Basic coverage for unwrapping -- specifics are managed by the + // wrapping library. + {"single_int64_wrapper", "single_int64_wrapper { value: 20 }", + CelValue::CreateInt64(20)}, + {"single_value", "single_value { null_value: NULL_VALUE }", + CelValue::CreateNull()}, + }), + &GetTestName); + +TEST(CreateValueFromSingleFieldTest, GetMessage) { + TestAllTypes test_message; + google::protobuf::Arena arena; + + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + "standalone_message { bb: 10 }", &test_message)); + + ASSERT_OK_AND_ASSIGN( + CelValue accessed_value, + CreateValueFromSingleField( + &test_message, + test_message.GetDescriptor()->FindFieldByName("standalone_message"), + ProtoWrapperTypeOptions::kUnsetProtoDefault, + &CelProtoWrapper::InternalWrapMessage, &arena)); + + EXPECT_THAT(accessed_value, test::IsCelMessage(EqualsProto("bb: 10"))); +} + +TEST(SetValueToSingleFieldTest, WrongType) { + TestAllTypes test_message; + google::protobuf::Arena arena; + + EXPECT_THAT(SetValueToSingleField( + CelValue::CreateDouble(1.0), + test_message.GetDescriptor()->FindFieldByName("single_int32"), + &test_message, &arena), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(SetValueToSingleFieldTest, IntOutOfRange) { + CelValue out_of_range = CelValue::CreateInt64(1LL << 31); + TestAllTypes test_message; + const google::protobuf::Descriptor* descriptor = test_message.GetDescriptor(); + google::protobuf::Arena arena; + + EXPECT_THAT(SetValueToSingleField(out_of_range, + descriptor->FindFieldByName("single_int32"), + &test_message, &arena), + StatusIs(absl::StatusCode::kInvalidArgument)); + + // proto enums are are represented as int32, but CEL converts to/from int64. + EXPECT_THAT(SetValueToSingleField( + out_of_range, descriptor->FindFieldByName("standalone_enum"), + &test_message, &arena), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(SetValueToSingleFieldTest, UintOutOfRange) { + CelValue out_of_range = CelValue::CreateUint64(1LL << 32); + TestAllTypes test_message; + const google::protobuf::Descriptor* descriptor = test_message.GetDescriptor(); + google::protobuf::Arena arena; + + EXPECT_THAT(SetValueToSingleField( + out_of_range, descriptor->FindFieldByName("single_uint32"), + &test_message, &arena), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(SetValueToSingleFieldTest, SetMessage) { + TestAllTypes::NestedMessage nested_message; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( + bb: 42 + )", + &nested_message)); + google::protobuf::Arena arena; + CelValue nested_value = + CelProtoWrapper::CreateMessage(&nested_message, &arena); + TestAllTypes test_message; + const google::protobuf::Descriptor* descriptor = test_message.GetDescriptor(); + + ASSERT_OK(SetValueToSingleField( + nested_value, descriptor->FindFieldByName("standalone_message"), + &test_message, &arena)); + EXPECT_THAT(test_message, EqualsProto("standalone_message { bb: 42 }")); +} + +TEST(SetValueToSingleFieldTest, SetAnyMessage) { + TestAllTypes::NestedMessage nested_message; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( + bb: 42 + )", + &nested_message)); + google::protobuf::Arena arena; + CelValue nested_value = + CelProtoWrapper::CreateMessage(&nested_message, &arena); + TestAllTypes test_message; + const google::protobuf::Descriptor* descriptor = test_message.GetDescriptor(); + + ASSERT_OK(SetValueToSingleField(nested_value, + descriptor->FindFieldByName("single_any"), + &test_message, &arena)); + + TestAllTypes::NestedMessage unpacked; + test_message.single_any().UnpackTo(&unpacked); + EXPECT_THAT(unpacked, EqualsProto("bb: 42")); +} + +TEST(SetValueToSingleFieldTest, SetMessageToNullNoop) { + google::protobuf::Arena arena; + TestAllTypes test_message; + const google::protobuf::Descriptor* descriptor = test_message.GetDescriptor(); + + ASSERT_OK(SetValueToSingleField( + CelValue::CreateNull(), descriptor->FindFieldByName("standalone_message"), + &test_message, &arena)); + EXPECT_THAT(test_message, EqualsProto(test_message.default_instance())); +} + +class RepeatedFieldTest : public testing::TestWithParam { + public: + absl::string_view field_name() const { return GetParam().field_name; } + absl::string_view message_textproto() const { + return GetParam().message_textproto; + } + CelValue cel_value() const { return GetParam().cel_value; } +}; + +TEST_P(RepeatedFieldTest, GetFirstElem) { + TestAllTypes test_message; + ASSERT_TRUE( + google::protobuf::TextFormat::ParseFromString(message_textproto(), &test_message)); + google::protobuf::Arena arena; + + ASSERT_OK_AND_ASSIGN( + CelValue accessed_value, + CreateValueFromRepeatedField( + &test_message, + test_message.GetDescriptor()->FindFieldByName(field_name()), 0, + &CelProtoWrapper::InternalWrapMessage, &arena)); + + EXPECT_THAT(accessed_value, test::EqualsCelValue(cel_value())); +} + +TEST_P(RepeatedFieldTest, AppendElem) { + TestAllTypes test_message; + CelValue to_add = cel_value(); + google::protobuf::Arena arena; + + ASSERT_OK(AddValueToRepeatedField( + to_add, test_message.GetDescriptor()->FindFieldByName(field_name()), + &test_message, &arena)); + + EXPECT_THAT(test_message, EqualsProto(message_textproto())); +} + +INSTANTIATE_TEST_SUITE_P( + AllTypes, RepeatedFieldTest, + testing::ValuesIn( + {{"repeated_int32", "repeated_int32: 1", CelValue::CreateInt64(1)}, + {"repeated_int64", "repeated_int64: 1", CelValue::CreateInt64(1)}, + {"repeated_uint32", "repeated_uint32: 1", CelValue::CreateUint64(1)}, + {"repeated_uint64", "repeated_uint64: 1", CelValue::CreateUint64(1)}, + {"repeated_sint32", "repeated_sint32: 1", CelValue::CreateInt64(1)}, + {"repeated_sint64", "repeated_sint64: 1", CelValue::CreateInt64(1)}, + {"repeated_fixed32", "repeated_fixed32: 1", CelValue::CreateUint64(1)}, + {"repeated_fixed64", "repeated_fixed64: 1", CelValue::CreateUint64(1)}, + {"repeated_sfixed32", "repeated_sfixed32: 1", + CelValue::CreateInt64(1)}, + {"repeated_sfixed64", "repeated_sfixed64: 1", + CelValue::CreateInt64(1)}, + {"repeated_float", "repeated_float: 1.0", CelValue::CreateDouble(1.0)}, + {"repeated_double", "repeated_double: 1.0", + CelValue::CreateDouble(1.0)}, + {"repeated_bool", "repeated_bool: true", CelValue::CreateBool(true)}, + {"repeated_string", "repeated_string: 'abcd'", + CelValue::CreateStringView("abcd")}, + {"repeated_bytes", "repeated_bytes: 'asdf'", + CelValue::CreateBytesView("asdf")}, + {"repeated_nested_enum", "repeated_nested_enum: BAZ", + CelValue::CreateInt64(2)}}), + &GetTestName); + +TEST(RepeatedFieldTest, GetMessage) { + TestAllTypes test_message; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + "repeated_nested_message { bb: 30 }", &test_message)); + google::protobuf::Arena arena; + + ASSERT_OK_AND_ASSIGN(CelValue accessed_value, + CreateValueFromRepeatedField( + &test_message, + test_message.GetDescriptor()->FindFieldByName( + "repeated_nested_message"), + 0, &CelProtoWrapper::InternalWrapMessage, &arena)); + + EXPECT_THAT(accessed_value, test::IsCelMessage(EqualsProto("bb: 30"))); +} + +TEST(AddValueToRepeatedFieldTest, WrongType) { + TestAllTypes test_message; + google::protobuf::Arena arena; + + EXPECT_THAT( + AddValueToRepeatedField( + CelValue::CreateDouble(1.0), + test_message.GetDescriptor()->FindFieldByName("repeated_int32"), + &test_message, &arena), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(AddValueToRepeatedFieldTest, IntOutOfRange) { + CelValue out_of_range = CelValue::CreateInt64(1LL << 31); + TestAllTypes test_message; + const google::protobuf::Descriptor* descriptor = test_message.GetDescriptor(); + google::protobuf::Arena arena; + + EXPECT_THAT(AddValueToRepeatedField( + out_of_range, descriptor->FindFieldByName("repeated_int32"), + &test_message, &arena), + StatusIs(absl::StatusCode::kInvalidArgument)); + + // proto enums are are represented as int32, but CEL converts to/from int64. + EXPECT_THAT( + AddValueToRepeatedField( + out_of_range, descriptor->FindFieldByName("repeated_nested_enum"), + &test_message, &arena), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(AddValueToRepeatedFieldTest, UintOutOfRange) { + CelValue out_of_range = CelValue::CreateUint64(1LL << 32); + TestAllTypes test_message; + const google::protobuf::Descriptor* descriptor = test_message.GetDescriptor(); + google::protobuf::Arena arena; + + EXPECT_THAT(AddValueToRepeatedField( + out_of_range, descriptor->FindFieldByName("repeated_uint32"), + &test_message, &arena), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(AddValueToRepeatedFieldTest, AddMessage) { + TestAllTypes::NestedMessage nested_message; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( + bb: 42 + )", + &nested_message)); + google::protobuf::Arena arena; + CelValue nested_value = + CelProtoWrapper::CreateMessage(&nested_message, &arena); + TestAllTypes test_message; + const google::protobuf::Descriptor* descriptor = test_message.GetDescriptor(); + + ASSERT_OK(AddValueToRepeatedField( + nested_value, descriptor->FindFieldByName("repeated_nested_message"), + &test_message, &arena)); + EXPECT_THAT(test_message, EqualsProto("repeated_nested_message { bb: 42 }")); +} + +constexpr std::array kWrapperFieldNames = { + "single_bool_wrapper", "single_int64_wrapper", "single_int32_wrapper", + "single_uint64_wrapper", "single_uint32_wrapper", "single_double_wrapper", + "single_float_wrapper", "single_string_wrapper", "single_bytes_wrapper"}; + +// Unset wrapper type fields are treated as null if accessed after option +// enabled. +TEST(CreateValueFromFieldTest, UnsetWrapperTypesNullIfEnabled) { + CelValue result; + TestAllTypes test_message; + google::protobuf::Arena arena; + + for (const auto& field : kWrapperFieldNames) { + ASSERT_OK_AND_ASSIGN( + result, CreateValueFromSingleField( + &test_message, + TestAllTypes::GetDescriptor()->FindFieldByName(field), + ProtoWrapperTypeOptions::kUnsetNull, + &CelProtoWrapper::InternalWrapMessage, &arena)); + ASSERT_TRUE(result.IsNull()) << field << ": " << result.DebugString(); + } +} + +// Unset wrapper type fields are treated as proto default under old +// behavior. +TEST(CreateValueFromFieldTest, UnsetWrapperTypesDefaultValueIfDisabled) { + CelValue result; + TestAllTypes test_message; + google::protobuf::Arena arena; + + for (const auto& field : kWrapperFieldNames) { + ASSERT_OK_AND_ASSIGN( + result, CreateValueFromSingleField( + &test_message, + TestAllTypes::GetDescriptor()->FindFieldByName(field), + ProtoWrapperTypeOptions::kUnsetProtoDefault, + &CelProtoWrapper::InternalWrapMessage, &arena)); + ASSERT_FALSE(result.IsNull()) << field << ": " << result.DebugString(); + } +} + +// If a wrapper type is set to default value, the corresponding CelValue is the +// proto default value. +TEST(CreateValueFromFieldTest, SetWrapperTypesDefaultValue) { + CelValue result; + TestAllTypes test_message; + google::protobuf::Arena arena; + + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb( + single_bool_wrapper {} + single_int64_wrapper {} + single_int32_wrapper {} + single_uint64_wrapper {} + single_uint32_wrapper {} + single_double_wrapper {} + single_float_wrapper {} + single_string_wrapper {} + single_bytes_wrapper {} + )pb", + &test_message)); + + ASSERT_OK_AND_ASSIGN( + result, + CreateValueFromSingleField( + &test_message, + TestAllTypes::GetDescriptor()->FindFieldByName("single_bool_wrapper"), + ProtoWrapperTypeOptions::kUnsetNull, + &CelProtoWrapper::InternalWrapMessage, &arena)); + EXPECT_THAT(result, test::IsCelBool(false)); + + ASSERT_OK_AND_ASSIGN(result, + CreateValueFromSingleField( + &test_message, + TestAllTypes::GetDescriptor()->FindFieldByName( + "single_int64_wrapper"), + ProtoWrapperTypeOptions::kUnsetNull, + &CelProtoWrapper::InternalWrapMessage, &arena)); + EXPECT_THAT(result, test::IsCelInt64(0)); + + ASSERT_OK_AND_ASSIGN(result, + CreateValueFromSingleField( + &test_message, + TestAllTypes::GetDescriptor()->FindFieldByName( + "single_int32_wrapper"), + ProtoWrapperTypeOptions::kUnsetNull, + &CelProtoWrapper::InternalWrapMessage, &arena)); + EXPECT_THAT(result, test::IsCelInt64(0)); + + ASSERT_OK_AND_ASSIGN( + result, + CreateValueFromSingleField(&test_message, + TestAllTypes::GetDescriptor()->FindFieldByName( + "single_uint64_wrapper"), + ProtoWrapperTypeOptions::kUnsetNull, + &CelProtoWrapper::InternalWrapMessage, + + &arena)); + EXPECT_THAT(result, test::IsCelUint64(0)); + + ASSERT_OK_AND_ASSIGN( + result, + CreateValueFromSingleField(&test_message, + TestAllTypes::GetDescriptor()->FindFieldByName( + "single_uint32_wrapper"), + ProtoWrapperTypeOptions::kUnsetNull, + &CelProtoWrapper::InternalWrapMessage, + + &arena)); + EXPECT_THAT(result, test::IsCelUint64(0)); + + ASSERT_OK_AND_ASSIGN(result, + CreateValueFromSingleField( + &test_message, + TestAllTypes::GetDescriptor()->FindFieldByName( + "single_double_wrapper"), + ProtoWrapperTypeOptions::kUnsetNull, + + &CelProtoWrapper::InternalWrapMessage, &arena)); + EXPECT_THAT(result, test::IsCelDouble(0.0f)); + + ASSERT_OK_AND_ASSIGN(result, + CreateValueFromSingleField( + &test_message, + TestAllTypes::GetDescriptor()->FindFieldByName( + "single_float_wrapper"), + ProtoWrapperTypeOptions::kUnsetNull, + + &CelProtoWrapper::InternalWrapMessage, &arena)); + EXPECT_THAT(result, test::IsCelDouble(0.0f)); + + ASSERT_OK_AND_ASSIGN(result, + CreateValueFromSingleField( + &test_message, + TestAllTypes::GetDescriptor()->FindFieldByName( + "single_string_wrapper"), + ProtoWrapperTypeOptions::kUnsetNull, + + &CelProtoWrapper::InternalWrapMessage, &arena)); + EXPECT_THAT(result, test::IsCelString("")); + + ASSERT_OK_AND_ASSIGN(result, + CreateValueFromSingleField( + &test_message, + TestAllTypes::GetDescriptor()->FindFieldByName( + "single_bytes_wrapper"), + ProtoWrapperTypeOptions::kUnsetNull, + + &CelProtoWrapper::InternalWrapMessage, &arena)); + EXPECT_THAT(result, test::IsCelBytes("")); +} + +} // namespace + +} // namespace google::api::expr::runtime::internal diff --git a/eval/public/structs/legacy_type_adapter.h b/eval/public/structs/legacy_type_adapter.h new file mode 100644 index 000000000..dc7a3ab1b --- /dev/null +++ b/eval/public/structs/legacy_type_adapter.h @@ -0,0 +1,177 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Definitions for legacy type APIs to emulate the behavior of the new type +// system. + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_LEGACY_TYPE_ADPATER_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_LEGACY_TYPE_ADPATER_H_ + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "base/attribute.h" +#include "common/memory.h" +#include "eval/public/cel_options.h" +#include "eval/public/cel_value.h" + +namespace google::api::expr::runtime { + +// Interface for mutation apis. +// Note: in the new type system, a type provider represents this by returning +// a cel::Type and cel::ValueManager for the type. +class LegacyTypeMutationApis { + public: + virtual ~LegacyTypeMutationApis() = default; + + // Return whether the type defines the given field. + // TODO(uncreated-issue/3): This is only used to eagerly fail during the planning + // phase. Check if it's safe to remove this behavior and fail at runtime. + virtual bool DefinesField(absl::string_view field_name) const = 0; + + // Create a new empty instance of the type. + // May return a status if the type is not possible to create. + virtual absl::StatusOr NewInstance( + cel::MemoryManagerRef memory_manager) const = 0; + + // Normalize special types to a native CEL value after building. + // The interpreter guarantees that instance is uniquely owned by the + // interpreter, and can be safely mutated. + virtual absl::StatusOr AdaptFromWellKnownType( + cel::MemoryManagerRef memory_manager, + CelValue::MessageWrapper::Builder instance) const = 0; + + // Set field on instance to value. + // The interpreter guarantees that instance is uniquely owned by the + // interpreter, and can be safely mutated. + virtual absl::Status SetField( + absl::string_view field_name, const CelValue& value, + cel::MemoryManagerRef memory_manager, + CelValue::MessageWrapper::Builder& instance) const = 0; + + virtual absl::Status SetFieldByNumber( + int64_t field_number [[maybe_unused]], + const CelValue& value [[maybe_unused]], + cel::MemoryManagerRef memory_manager [[maybe_unused]], + CelValue::MessageWrapper::Builder& instance [[maybe_unused]]) const { + return absl::UnimplementedError("SetFieldByNumber is not yet implemented"); + } +}; + +// Interface for access apis. +// Note: in new type system this is integrated into the StructValue (via +// dynamic dispatch to concrete implementations). +class LegacyTypeAccessApis { + public: + struct LegacyQualifyResult { + // The possibly intermediate result of the select operation. + CelValue value; + // Number of qualifiers applied. + int qualifier_count; + }; + + virtual ~LegacyTypeAccessApis() = default; + + // Return whether an instance of the type has field set to a non-default + // value. + virtual absl::StatusOr HasField( + absl::string_view field_name, + const CelValue::MessageWrapper& value) const = 0; + + // Access field on instance. + virtual absl::StatusOr GetField( + absl::string_view field_name, const CelValue::MessageWrapper& instance, + ProtoWrapperTypeOptions unboxing_option, + cel::MemoryManagerRef memory_manager) const = 0; + + // Apply a series of select operations on the given instance. + // + // Each select qualifier may represent either a singular field access ( + // FieldSpecifier) or an index into a container (AttributeQualifier). + // + // The Qualify implementation should return an appropriate CelError when + // intermediate fields or indexes are not found, or the given qualifier + // doesn't apply to operand. + // + // A Status with a non-ok error code may be returned for other errors. + // absl::StatusCode::kUnimplemented signals that Qualify is unsupported and + // the evaluator should emulate the default select behavior. + // + // - presence_test controls whether to treat the call as a 'has' call, + // returning + // whether the leaf field is set to a non-default value. + virtual absl::StatusOr Qualify( + absl::Span, + const CelValue::MessageWrapper& instance [[maybe_unused]], + bool presence_test [[maybe_unused]], + cel::MemoryManagerRef memory_manager [[maybe_unused]]) const { + return absl::UnimplementedError("Qualify unsupported."); + } + + // Interface for equality operator. + // The interpreter will check that both instances report to be the same type, + // but implementations should confirm that both instances are actually of the + // same type. + // If the two instances are of different type, return false. Otherwise, + // return whether they are equal. + // To conform to the CEL spec, message equality should follow the behavior of + // MessageDifferencer::Equals. + virtual bool IsEqualTo(const CelValue::MessageWrapper&, + const CelValue::MessageWrapper&) const { + return false; + } + + virtual std::vector ListFields( + const CelValue::MessageWrapper& instance) const = 0; +}; + +// Type information about a legacy Struct type. +// Provides methods to the interpreter for interacting with a custom type. +// +// mutation_apis() provide equivalent behavior to a cel::Type and +// cel::ValueManager (resolved from a type name). +// +// access_apis() provide equivalent behavior to cel::StructValue accessors +// (virtual dispatch to a concrete implementation for accessing underlying +// values). +// +// This class is a simple wrapper around (nullable) pointers to the interface +// implementations. The underlying pointers are expected to be valid as long as +// the type provider that returned this object. +class LegacyTypeAdapter { + public: + LegacyTypeAdapter(const LegacyTypeAccessApis* access, + const LegacyTypeMutationApis* mutation) + : access_apis_(access), mutation_apis_(mutation) {} + + // Apis for access for the represented type. + // If null, access is not supported (this is an opaque type). + const LegacyTypeAccessApis* access_apis() { return access_apis_; } + + // Apis for mutation for the represented type. + // If null, mutation is not supported (this type cannot be created). + const LegacyTypeMutationApis* mutation_apis() { return mutation_apis_; } + + private: + const LegacyTypeAccessApis* access_apis_; + const LegacyTypeMutationApis* mutation_apis_; +}; + +} // namespace google::api::expr::runtime + +#endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_LEGACY_TYPE_ADPATER_H_ diff --git a/eval/public/structs/legacy_type_adapter_test.cc b/eval/public/structs/legacy_type_adapter_test.cc new file mode 100644 index 000000000..4c16a59ad --- /dev/null +++ b/eval/public/structs/legacy_type_adapter_test.cc @@ -0,0 +1,63 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/structs/legacy_type_adapter.h" + +#include + +#include "eval/public/cel_value.h" +#include "eval/public/structs/trivial_legacy_type_info.h" +#include "eval/public/testing/matchers.h" +#include "eval/testutil/test_message.pb.h" +#include "extensions/protobuf/memory_manager.h" +#include "internal/status_macros.h" +#include "internal/testing.h" + +namespace google::api::expr::runtime { +namespace { + +class TestAccessApiImpl : public LegacyTypeAccessApis { + public: + TestAccessApiImpl() {} + absl::StatusOr HasField( + absl::string_view field_name, + const CelValue::MessageWrapper& value) const override { + return absl::UnimplementedError("Not implemented"); + } + + absl::StatusOr GetField( + absl::string_view field_name, const CelValue::MessageWrapper& instance, + ProtoWrapperTypeOptions unboxing_option, + cel::MemoryManagerRef memory_manager) const override { + return absl::UnimplementedError("Not implemented"); + } + + std::vector ListFields( + const CelValue::MessageWrapper& instance) const override { + return std::vector(); + } +}; + +TEST(LegacyTypeAdapterAccessApis, DefaultAlwaysInequal) { + TestMessage message; + MessageWrapper wrapper(&message, nullptr); + MessageWrapper wrapper2(&message, nullptr); + + TestAccessApiImpl impl; + + EXPECT_FALSE(impl.IsEqualTo(wrapper, wrapper2)); +} + +} // namespace +} // namespace google::api::expr::runtime diff --git a/eval/public/structs/legacy_type_info_apis.h b/eval/public/structs/legacy_type_info_apis.h new file mode 100644 index 000000000..4f07470a1 --- /dev/null +++ b/eval/public/structs/legacy_type_info_apis.h @@ -0,0 +1,103 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_LEGACY_TYPE_INFO_APIS_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_LEGACY_TYPE_INFO_APIS_H_ + +#include + +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "eval/public/message_wrapper.h" +#include "google/protobuf/descriptor.h" + +namespace google::api::expr::runtime { + +// Forward declared to resolve cyclic dependency. +class LegacyTypeAccessApis; +class LegacyTypeMutationApis; + +// Interface for providing type info from a user defined type (represented as a +// message). +// +// Provides ability to obtain field access apis, type info, and debug +// representation of a message. +// +// The message parameter may wrap a nullptr to request generic accessors / +// mutators for the TypeInfo instance if it is available. +// +// This is implemented as a separate class from LegacyTypeAccessApis to resolve +// cyclic dependency between CelValue (which needs to access these apis to +// provide DebugString and ObtainCelTypename) and LegacyTypeAccessApis (which +// needs to return CelValue type for field access). +class LegacyTypeInfoApis { + public: + struct FieldDescription { + int number; + absl::string_view name; + }; + + virtual ~LegacyTypeInfoApis() = default; + + // Return a debug representation of the wrapped message. + virtual std::string DebugString( + const MessageWrapper& wrapped_message) const = 0; + + // Return a reference to the typename for the wrapped message's type. + // The CEL interpreter assumes that the typename is owned externally and will + // outlive any CelValues created by the interpreter. + virtual absl::string_view GetTypename( + const MessageWrapper& wrapped_message) const = 0; + + virtual const google::protobuf::Descriptor* absl_nullable GetDescriptor( + const MessageWrapper& wrapped_message [[maybe_unused]]) const { + return nullptr; + } + + // Return a pointer to the wrapped message's access api implementation. + // + // The CEL interpreter assumes that the returned pointer is owned externally + // and will outlive any CelValues created by the interpreter. + // + // Nullptr signals that the value does not provide access apis. For field + // access, the interpreter will treat this the same as accessing a field that + // is not defined for the type. + virtual const LegacyTypeAccessApis* GetAccessApis( + const MessageWrapper& wrapped_message) const = 0; + + // Return a pointer to the wrapped message's mutation api implementation. + // + // The CEL interpreter assumes that the returned pointer is owned externally + // and will outlive any CelValues created by the interpreter. + // + // Nullptr signals that the value does not provide mutation apis. + virtual const LegacyTypeMutationApis* GetMutationApis( + const MessageWrapper& wrapped_message [[maybe_unused]]) const { + return nullptr; + } + + // Return a description of the underlying field if defined. + // + // The underlying string is expected to remain valid as long as the + // LegacyTypeInfoApis instance. + virtual absl::optional FindFieldByName( + absl::string_view name [[maybe_unused]]) const { + return absl::nullopt; + } +}; + +} // namespace google::api::expr::runtime + +#endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_LEGACY_TYPE_INFO_APIS_H_ diff --git a/eval/public/structs/legacy_type_provider.cc b/eval/public/structs/legacy_type_provider.cc new file mode 100644 index 000000000..f8db92298 --- /dev/null +++ b/eval/public/structs/legacy_type_provider.cc @@ -0,0 +1,218 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/structs/legacy_type_provider.h" + +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "common/legacy_value.h" +#include "common/memory.h" +#include "common/type.h" +#include "common/type_introspector.h" +#include "common/value.h" +#include "eval/public/message_wrapper.h" +#include "eval/public/structs/legacy_type_adapter.h" +#include "eval/public/structs/legacy_type_info_apis.h" +#include "extensions/protobuf/memory_manager.h" +#include "internal/status_macros.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/message.h" + +namespace google::api::expr::runtime { + +namespace { + +using google::api::expr::runtime::LegacyTypeAdapter; +using google::api::expr::runtime::MessageWrapper; + +class LegacyStructValueBuilder final : public cel::StructValueBuilder { + public: + LegacyStructValueBuilder(cel::MemoryManagerRef memory_manager, + LegacyTypeAdapter adapter, + MessageWrapper::Builder builder) + : memory_manager_(memory_manager), + adapter_(adapter), + builder_(std::move(builder)) {} + + absl::StatusOr> SetFieldByName( + absl::string_view name, cel::Value value) override { + CEL_ASSIGN_OR_RETURN( + auto legacy_value, + LegacyValue(cel::extensions::ProtoMemoryManagerArena(memory_manager_), + value), + _.With(cel::ErrorValueReturn())); + CEL_RETURN_IF_ERROR(adapter_.mutation_apis()->SetField( + name, legacy_value, memory_manager_, builder_)) + .With(cel::ErrorValueReturn()); + return std::nullopt; + } + + absl::StatusOr> SetFieldByNumber( + int64_t number, cel::Value value) override { + CEL_ASSIGN_OR_RETURN( + auto legacy_value, + LegacyValue(cel::extensions::ProtoMemoryManagerArena(memory_manager_), + value), + _.With(cel::ErrorValueReturn())); + CEL_RETURN_IF_ERROR(adapter_.mutation_apis()->SetFieldByNumber( + number, legacy_value, memory_manager_, builder_)) + .With(cel::ErrorValueReturn()); + return std::nullopt; + } + + absl::StatusOr Build() && override { + CEL_ASSIGN_OR_RETURN(auto message, + adapter_.mutation_apis()->AdaptFromWellKnownType( + memory_manager_, std::move(builder_))); + if (!message.IsMessage()) { + return absl::FailedPreconditionError("expected MessageWrapper"); + } + auto message_wrapper = message.MessageWrapperOrDie(); + return cel::common_internal::LegacyStructValue( + google::protobuf::DownCastMessage(message_wrapper.message_ptr()), + message_wrapper.legacy_type_info()); + } + + private: + cel::MemoryManagerRef memory_manager_; + LegacyTypeAdapter adapter_; + MessageWrapper::Builder builder_; +}; + +class LegacyValueBuilder final : public cel::ValueBuilder { + public: + LegacyValueBuilder(cel::MemoryManagerRef memory_manager, + LegacyTypeAdapter adapter, MessageWrapper::Builder builder) + : memory_manager_(memory_manager), + adapter_(adapter), + builder_(std::move(builder)) {} + + absl::StatusOr> SetFieldByName( + absl::string_view name, cel::Value value) override { + CEL_ASSIGN_OR_RETURN( + auto legacy_value, + LegacyValue(cel::extensions::ProtoMemoryManagerArena(memory_manager_), + value), + _.With(cel::ErrorValueReturn())); + CEL_RETURN_IF_ERROR(adapter_.mutation_apis()->SetField( + name, legacy_value, memory_manager_, builder_)) + .With(cel::ErrorValueReturn()); + return std::nullopt; + } + + absl::StatusOr> SetFieldByNumber( + int64_t number, cel::Value value) override { + CEL_ASSIGN_OR_RETURN( + auto legacy_value, + LegacyValue(cel::extensions::ProtoMemoryManagerArena(memory_manager_), + value), + _.With(cel::ErrorValueReturn())); + CEL_RETURN_IF_ERROR(adapter_.mutation_apis()->SetFieldByNumber( + number, legacy_value, memory_manager_, builder_)) + .With(cel::ErrorValueReturn()); + return std::nullopt; + } + + absl::StatusOr Build() && override { + CEL_ASSIGN_OR_RETURN(auto value, + adapter_.mutation_apis()->AdaptFromWellKnownType( + memory_manager_, std::move(builder_)), + _.With(cel::ErrorValueReturn())); + CEL_ASSIGN_OR_RETURN( + auto result, + cel::ModernValue( + cel::extensions::ProtoMemoryManagerArena(memory_manager_), value), + _.With(cel::ErrorValueReturn())); + return result; + } + + private: + cel::MemoryManagerRef memory_manager_; + LegacyTypeAdapter adapter_; + MessageWrapper::Builder builder_; +}; + +} // namespace + +absl::StatusOr +LegacyTypeProvider::NewValueBuilder( + absl::string_view name, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const { + if (auto type_adapter = ProvideLegacyType(name); type_adapter.has_value()) { + const auto* mutation_apis = type_adapter->mutation_apis(); + if (mutation_apis == nullptr) { + return absl::FailedPreconditionError( + absl::StrCat("LegacyTypeMutationApis missing for type: ", name)); + } + CEL_ASSIGN_OR_RETURN( + auto builder, + mutation_apis->NewInstance(cel::MemoryManagerRef::Pooling(arena))); + return std::make_unique( + cel::MemoryManagerRef::Pooling(arena), *type_adapter, + std::move(builder)); + } + return nullptr; +} + +absl::StatusOr> LegacyTypeProvider::FindTypeImpl( + absl::string_view name) const { + if (auto type = cel::FindWellKnownType(name); type.has_value()) { + return type; + } + if (auto type_info = ProvideLegacyTypeInfo(name); type_info.has_value()) { + const auto* descriptor = (*type_info)->GetDescriptor(MessageWrapper()); + if (descriptor != nullptr) { + return cel::MessageType(descriptor); + } + return cel::common_internal::MakeBasicStructType( + (*type_info)->GetTypename(MessageWrapper())); + } + return std::nullopt; +} + +absl::StatusOr> +LegacyTypeProvider::FindStructTypeFieldByNameImpl( + absl::string_view type, absl::string_view name) const { + if (auto result = cel::FindWellKnownTypeFieldByName(type, name); + result.has_value()) { + return result; + } + if (auto type_info = ProvideLegacyTypeInfo(type); type_info.has_value()) { + if (auto field_desc = (*type_info)->FindFieldByName(name); + field_desc.has_value()) { + return cel::common_internal::BasicStructTypeField( + field_desc->name, field_desc->number, cel::DynType{}); + } else { + const auto* mutation_apis = + (*type_info)->GetMutationApis(MessageWrapper()); + if (mutation_apis == nullptr || !mutation_apis->DefinesField(name)) { + return std::nullopt; + } + return cel::common_internal::BasicStructTypeField(name, 0, + cel::DynType{}); + } + } + return std::nullopt; +} + +} // namespace google::api::expr::runtime diff --git a/eval/public/structs/legacy_type_provider.h b/eval/public/structs/legacy_type_provider.h new file mode 100644 index 000000000..e2e67411c --- /dev/null +++ b/eval/public/structs/legacy_type_provider.h @@ -0,0 +1,79 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_TYPE_PROVIDER_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_TYPE_PROVIDER_H_ + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "common/type.h" +#include "common/type_reflector.h" +#include "common/value.h" +#include "eval/public/structs/legacy_type_adapter.h" +#include "eval/public/structs/legacy_type_info_apis.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/message.h" + +namespace google::api::expr::runtime { + +// An internal extension of cel::TypeProvider that also deals with legacy types. +// +// Note: This API is not finalized. Consult the CEL team before introducing new +// implementations. +class LegacyTypeProvider : public cel::TypeReflector { + public: + virtual ~LegacyTypeProvider() = default; + + // Return LegacyTypeAdapter for the fully qualified type name if available. + // + // nullopt values are interpreted as not present. + // + // Returned non-null pointers from the adapter implemententation must remain + // valid as long as the type provider. + // TODO(uncreated-issue/3): add alternative for new type system. + virtual absl::optional ProvideLegacyType( + absl::string_view name) const = 0; + + // Return LegacyTypeInfoApis for the fully qualified type name if available. + // + // nullopt values are interpreted as not present. + // + // Since custom type providers should create values compatible with evaluator + // created ones, the TypeInfoApis returned from this method should be the same + // as the ones used in value creation. + virtual absl::optional ProvideLegacyTypeInfo( + ABSL_ATTRIBUTE_UNUSED absl::string_view name) const { + return absl::nullopt; + } + + absl::StatusOr NewValueBuilder( + absl::string_view name, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const final; + + protected: + absl::StatusOr> FindTypeImpl( + absl::string_view name) const final; + + absl::StatusOr> + FindStructTypeFieldByNameImpl(absl::string_view type, + absl::string_view name) const final; +}; + +} // namespace google::api::expr::runtime + +#endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_TYPE_PROVIDER_H_ diff --git a/eval/public/structs/legacy_type_provider_test.cc b/eval/public/structs/legacy_type_provider_test.cc new file mode 100644 index 000000000..8de2aba01 --- /dev/null +++ b/eval/public/structs/legacy_type_provider_test.cc @@ -0,0 +1,93 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/structs/legacy_type_provider.h" + +#include +#include + +#include "absl/strings/string_view.h" +#include "eval/public/structs/legacy_type_info_apis.h" +#include "internal/testing.h" + +namespace google::api::expr::runtime { +namespace { + +class LegacyTypeProviderTestEmpty : public LegacyTypeProvider { + public: + absl::optional ProvideLegacyType( + absl::string_view name) const override { + return std::nullopt; + } +}; + +class LegacyTypeInfoApisEmpty : public LegacyTypeInfoApis { + public: + std::string DebugString( + const MessageWrapper& wrapped_message) const override { + return ""; + } + absl::string_view GetTypename( + const MessageWrapper& wrapped_message) const override { + return test_string_; + } + const LegacyTypeAccessApis* GetAccessApis( + const MessageWrapper& wrapped_message) const override { + return nullptr; + } + + private: + const std::string test_string_ = "test"; +}; + +class LegacyTypeProviderTestImpl : public LegacyTypeProvider { + public: + explicit LegacyTypeProviderTestImpl(const LegacyTypeInfoApis* test_type_info) + : test_type_info_(test_type_info) {} + absl::optional ProvideLegacyType( + absl::string_view name) const override { + if (name == "test") { + return LegacyTypeAdapter(nullptr, nullptr); + } + return std::nullopt; + } + absl::optional ProvideLegacyTypeInfo( + absl::string_view name) const override { + if (name == "test") { + return test_type_info_; + } + return std::nullopt; + } + + private: + const LegacyTypeInfoApis* test_type_info_ = nullptr; +}; + +TEST(LegacyTypeProviderTest, EmptyTypeProviderHasProvideTypeInfo) { + LegacyTypeProviderTestEmpty provider; + EXPECT_EQ(provider.ProvideLegacyType("test"), std::nullopt); + EXPECT_EQ(provider.ProvideLegacyTypeInfo("test"), std::nullopt); +} + +TEST(LegacyTypeProviderTest, NonEmptyTypeProviderProvidesSomeTypes) { + LegacyTypeInfoApisEmpty test_type_info; + LegacyTypeProviderTestImpl provider(&test_type_info); + EXPECT_TRUE(provider.ProvideLegacyType("test").has_value()); + EXPECT_TRUE(provider.ProvideLegacyTypeInfo("test").has_value()); + EXPECT_EQ(provider.ProvideLegacyType("other"), std::nullopt); + EXPECT_EQ(provider.ProvideLegacyTypeInfo("other"), std::nullopt); +} + +} // namespace +} // namespace google::api::expr::runtime diff --git a/eval/public/structs/proto_message_type_adapter.cc b/eval/public/structs/proto_message_type_adapter.cc new file mode 100644 index 000000000..8c140c0c7 --- /dev/null +++ b/eval/public/structs/proto_message_type_adapter.cc @@ -0,0 +1,708 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/structs/proto_message_type_adapter.h" + +#include +#include +#include +#include +#include + +#include "absl/base/no_destructor.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "base/attribute.h" +#include "common/memory.h" +#include "eval/public/cel_options.h" +#include "eval/public/cel_value.h" +#include "eval/public/containers/internal_field_backed_list_impl.h" +#include "eval/public/containers/internal_field_backed_map_impl.h" +#include "eval/public/message_wrapper.h" +#include "eval/public/structs/cel_proto_wrap_util.h" +#include "eval/public/structs/field_access_impl.h" +#include "eval/public/structs/legacy_type_adapter.h" +#include "eval/public/structs/legacy_type_info_apis.h" +#include "extensions/protobuf/internal/qualify.h" +#include "extensions/protobuf/memory_manager.h" +#include "internal/casts.h" +#include "internal/status_macros.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/map_field.h" +#include "google/protobuf/message.h" +#include "google/protobuf/util/message_differencer.h" + +namespace google::api::expr::runtime { +namespace { + +using ::cel::extensions::ProtoMemoryManagerArena; +using ::cel::extensions::ProtoMemoryManagerRef; +using ::google::protobuf::FieldDescriptor; +using ::google::protobuf::Message; +using ::google::protobuf::Reflection; + +using LegacyQualifyResult = LegacyTypeAccessApis::LegacyQualifyResult; + +const std::string& UnsupportedTypeName() { + static absl::NoDestructor kUnsupportedTypeName( + ""); + return *kUnsupportedTypeName; +} + +CelValue MessageCelValueFactory(const google::protobuf::Message* message); + +inline absl::StatusOr UnwrapMessage( + const MessageWrapper& value, absl::string_view op) { + if (!value.HasFullProto() || value.message_ptr() == nullptr) { + return absl::InternalError( + absl::StrCat(op, " called on non-message type.")); + } + return static_cast(value.message_ptr()); +} + +inline absl::StatusOr UnwrapMessage( + const MessageWrapper::Builder& value, absl::string_view op) { + if (!value.HasFullProto() || value.message_ptr() == nullptr) { + return absl::InternalError( + absl::StrCat(op, " called on non-message type.")); + } + return static_cast(value.message_ptr()); +} + +bool ProtoEquals(const google::protobuf::Message& m1, const google::protobuf::Message& m2) { + // Equality behavior is undefined for message differencer if input messages + // have different descriptors. For CEL just return false. + if (m1.GetDescriptor() != m2.GetDescriptor()) { + return false; + } + return google::protobuf::util::MessageDifferencer::Equals(m1, m2); +} + +// Implements CEL's notion of field presence for protobuf. +// Assumes all arguments non-null. +bool CelFieldIsPresent(const google::protobuf::Message* message, + const google::protobuf::FieldDescriptor* field_desc, + const google::protobuf::Reflection* reflection) { + if (field_desc->is_map()) { + // When the map field appears in a has(msg.map_field) expression, the map + // is considered 'present' when it is non-empty. Since maps are repeated + // fields they don't participate with standard proto presence testing since + // the repeated field is always at least empty. + return reflection->FieldSize(*message, field_desc) != 0; + } + + if (field_desc->is_repeated()) { + // When the list field appears in a has(msg.list_field) expression, the list + // is considered 'present' when it is non-empty. + return reflection->FieldSize(*message, field_desc) != 0; + } + + // Standard proto presence test for non-repeated fields. + return reflection->HasField(*message, field_desc); +} + +// Shared implementation for HasField. +// Handles list or map specific behavior before calling reflection helpers. +absl::StatusOr HasFieldImpl(const google::protobuf::Message* message, + const google::protobuf::Descriptor* descriptor, + absl::string_view field_name) { + ABSL_ASSERT(descriptor == message->GetDescriptor()); + const Reflection* reflection = message->GetReflection(); + const FieldDescriptor* field_desc = descriptor->FindFieldByName(field_name); + if (field_desc == nullptr && reflection != nullptr) { + // Search to see whether the field name is referring to an extension. + field_desc = reflection->FindKnownExtensionByName(field_name); + } + if (field_desc == nullptr) { + return absl::NotFoundError(absl::StrCat("no_such_field : ", field_name)); + } + + if (reflection == nullptr) { + return absl::FailedPreconditionError( + "google::protobuf::Reflection unavailble in CEL field access."); + } + return CelFieldIsPresent(message, field_desc, reflection); +} + +absl::StatusOr CreateCelValueFromField( + const google::protobuf::Message* message, const google::protobuf::FieldDescriptor* field_desc, + ProtoWrapperTypeOptions unboxing_option, google::protobuf::Arena* arena) { + if (field_desc->is_map()) { + auto* map = google::protobuf::Arena::Create( + arena, message, field_desc, &MessageCelValueFactory, arena); + + return CelValue::CreateMap(map); + } + if (field_desc->is_repeated()) { + auto* list = google::protobuf::Arena::Create( + arena, message, field_desc, &MessageCelValueFactory, arena); + return CelValue::CreateList(list); + } + + CEL_ASSIGN_OR_RETURN( + CelValue result, + internal::CreateValueFromSingleField(message, field_desc, unboxing_option, + &MessageCelValueFactory, arena)); + return result; +} + +// Shared implementation for GetField. +// Handles list or map specific behavior before calling reflection helpers. +absl::StatusOr GetFieldImpl(const google::protobuf::Message* message, + const google::protobuf::Descriptor* descriptor, + absl::string_view field_name, + ProtoWrapperTypeOptions unboxing_option, + cel::MemoryManagerRef memory_manager) { + ABSL_ASSERT(descriptor == message->GetDescriptor()); + const Reflection* reflection = message->GetReflection(); + const FieldDescriptor* field_desc = descriptor->FindFieldByName(field_name); + if (field_desc == nullptr && reflection != nullptr) { + std::string ext_name(field_name); + field_desc = reflection->FindKnownExtensionByName(ext_name); + } + if (field_desc == nullptr) { + return CreateNoSuchFieldError(memory_manager, field_name); + } + + google::protobuf::Arena* arena = ProtoMemoryManagerArena(memory_manager); + + return CreateCelValueFromField(message, field_desc, unboxing_option, arena); +} + +// State machine for incrementally applying qualifiers. +// +// Reusing the state machine to represent intermediate states (as opposed to +// returning the intermediates) is more efficient for longer select chains while +// still allowing decomposition of the qualify routine. +class LegacyQualifyState final + : public cel::extensions::protobuf_internal::ProtoQualifyState { + public: + using ProtoQualifyState::ProtoQualifyState; + + LegacyQualifyState(const LegacyQualifyState&) = delete; + LegacyQualifyState& operator=(const LegacyQualifyState&) = delete; + + absl::optional& result() { return result_; } + + private: + void SetResultFromError(absl::Status status, + cel::MemoryManagerRef memory_manager) override { + result_ = CreateErrorValue(memory_manager, status); + } + + void SetResultFromBool(bool value) override { + result_ = CelValue::CreateBool(value); + } + + absl::Status SetResultFromField( + const google::protobuf::Message* message, const google::protobuf::FieldDescriptor* field, + ProtoWrapperTypeOptions unboxing_option, + cel::MemoryManagerRef memory_manager) override { + CEL_ASSIGN_OR_RETURN(result_, CreateCelValueFromField( + message, field, unboxing_option, + ProtoMemoryManagerArena(memory_manager))); + return absl::OkStatus(); + } + + absl::Status SetResultFromRepeatedField( + const google::protobuf::Message* message, const google::protobuf::FieldDescriptor* field, + int index, cel::MemoryManagerRef memory_manager) override { + CEL_ASSIGN_OR_RETURN(result_, + internal::CreateValueFromRepeatedField( + message, field, index, &MessageCelValueFactory, + ProtoMemoryManagerArena(memory_manager))); + return absl::OkStatus(); + } + + absl::Status SetResultFromMapField( + const google::protobuf::Message* message, const google::protobuf::FieldDescriptor* field, + const google::protobuf::MapValueConstRef& value, + cel::MemoryManagerRef memory_manager) override { + CEL_ASSIGN_OR_RETURN(result_, + internal::CreateValueFromMapValue( + message, field, &value, &MessageCelValueFactory, + ProtoMemoryManagerArena(memory_manager))); + return absl::OkStatus(); + } + + absl::optional result_; +}; + +absl::StatusOr QualifyImpl( + const google::protobuf::Message* message, const google::protobuf::Descriptor* descriptor, + absl::Span path, bool presence_test, + cel::MemoryManagerRef memory_manager) { + google::protobuf::Arena* arena = ProtoMemoryManagerArena(memory_manager); + ABSL_DCHECK(descriptor == message->GetDescriptor()); + LegacyQualifyState qualify_state(message, descriptor, + message->GetReflection()); + + for (int i = 0; i < path.size() - 1; i++) { + const auto& qualifier = path.at(i); + CEL_RETURN_IF_ERROR(qualify_state.ApplySelectQualifier( + qualifier, ProtoMemoryManagerRef(arena))); + if (qualify_state.result().has_value()) { + LegacyQualifyResult result; + result.value = std::move(qualify_state.result()).value(); + result.qualifier_count = result.value.IsError() ? -1 : i + 1; + return result; + } + } + + const auto& last_qualifier = path.back(); + LegacyQualifyResult result; + result.qualifier_count = -1; + + if (presence_test) { + CEL_RETURN_IF_ERROR(qualify_state.ApplyLastQualifierHas( + last_qualifier, ProtoMemoryManagerRef(arena))); + } else { + CEL_RETURN_IF_ERROR(qualify_state.ApplyLastQualifierGet( + last_qualifier, ProtoMemoryManagerRef(arena))); + } + result.value = *qualify_state.result(); + return result; +} + +std::vector ListFieldsImpl( + const CelValue::MessageWrapper& instance) { + if (instance.message_ptr() == nullptr) { + return std::vector(); + } + ABSL_ASSERT(instance.HasFullProto()); + const auto* message = + static_cast(instance.message_ptr()); + const auto* reflect = message->GetReflection(); + std::vector fields; + reflect->ListFields(*message, &fields); + std::vector field_names; + field_names.reserve(fields.size()); + for (const auto* field : fields) { + field_names.emplace_back(field->name()); + } + return field_names; +} + +class DucktypedMessageAdapter : public LegacyTypeAccessApis, + public LegacyTypeMutationApis, + public LegacyTypeInfoApis { + public: + // Implement field access APIs. + absl::StatusOr HasField( + absl::string_view field_name, + const CelValue::MessageWrapper& value) const override { + CEL_ASSIGN_OR_RETURN(const google::protobuf::Message* message, + UnwrapMessage(value, "HasField")); + return HasFieldImpl(message, message->GetDescriptor(), field_name); + } + + absl::StatusOr GetField( + absl::string_view field_name, const CelValue::MessageWrapper& instance, + ProtoWrapperTypeOptions unboxing_option, + cel::MemoryManagerRef memory_manager) const override { + CEL_ASSIGN_OR_RETURN(const google::protobuf::Message* message, + UnwrapMessage(instance, "GetField")); + return GetFieldImpl(message, message->GetDescriptor(), field_name, + unboxing_option, memory_manager); + } + + absl::StatusOr Qualify( + absl::Span qualifiers, + const CelValue::MessageWrapper& instance, bool presence_test, + cel::MemoryManagerRef memory_manager) const override { + CEL_ASSIGN_OR_RETURN(const google::protobuf::Message* message, + UnwrapMessage(instance, "Qualify")); + + return QualifyImpl(message, message->GetDescriptor(), qualifiers, + presence_test, memory_manager); + } + + bool IsEqualTo( + const CelValue::MessageWrapper& instance, + const CelValue::MessageWrapper& other_instance) const override { + absl::StatusOr lhs = + UnwrapMessage(instance, "IsEqualTo"); + absl::StatusOr rhs = + UnwrapMessage(other_instance, "IsEqualTo"); + if (!lhs.ok() || !rhs.ok()) { + // Treat this as though the underlying types are different, just return + // false. + return false; + } + return ProtoEquals(**lhs, **rhs); + } + + // Implement TypeInfo Apis + absl::string_view GetTypename( + const MessageWrapper& wrapped_message) const override { + if (!wrapped_message.HasFullProto() || + wrapped_message.message_ptr() == nullptr) { + return UnsupportedTypeName(); + } + auto* message = + static_cast(wrapped_message.message_ptr()); + return message->GetDescriptor()->full_name(); + } + + std::string DebugString( + const MessageWrapper& wrapped_message) const override { + if (!wrapped_message.HasFullProto() || + wrapped_message.message_ptr() == nullptr) { + return UnsupportedTypeName(); + } + auto* message = + static_cast(wrapped_message.message_ptr()); + return message->ShortDebugString(); + } + + bool DefinesField(absl::string_view field_name) const override { + // Pretend all our fields exist. Real errors will be returned from field + // getters and setters. + return true; + } + + absl::StatusOr NewInstance( + cel::MemoryManagerRef memory_manager) const override { + return absl::UnimplementedError("NewInstance is not implemented"); + } + + absl::StatusOr AdaptFromWellKnownType( + cel::MemoryManagerRef memory_manager, + CelValue::MessageWrapper::Builder instance) const override { + if (!instance.HasFullProto() || instance.message_ptr() == nullptr) { + return absl::UnimplementedError( + "MessageLite is not supported, descriptor is required"); + } + return ProtoMessageTypeAdapter( + static_cast(instance.message_ptr()) + ->GetDescriptor(), + nullptr) + .AdaptFromWellKnownType(memory_manager, instance); + } + + absl::Status SetField( + absl::string_view field_name, const CelValue& value, + cel::MemoryManagerRef memory_manager, + CelValue::MessageWrapper::Builder& instance) const override { + if (!instance.HasFullProto() || instance.message_ptr() == nullptr) { + return absl::UnimplementedError( + "MessageLite is not supported, descriptor is required"); + } + return ProtoMessageTypeAdapter( + static_cast(instance.message_ptr()) + ->GetDescriptor(), + nullptr) + .SetField(field_name, value, memory_manager, instance); + } + + std::vector ListFields( + const CelValue::MessageWrapper& instance) const override { + return ListFieldsImpl(instance); + } + + const LegacyTypeAccessApis* GetAccessApis( + const MessageWrapper& wrapped_message) const override { + return this; + } + + const LegacyTypeMutationApis* GetMutationApis( + const MessageWrapper& wrapped_message) const override { + return this; + } + + static const DucktypedMessageAdapter& GetSingleton() { + static absl::NoDestructor instance; + return *instance; + } +}; + +CelValue MessageCelValueFactory(const google::protobuf::Message* message) { + return CelValue::CreateMessageWrapper( + MessageWrapper(message, &DucktypedMessageAdapter::GetSingleton())); +} + +} // namespace + +std::string ProtoMessageTypeAdapter::DebugString( + const MessageWrapper& wrapped_message) const { + if (!wrapped_message.HasFullProto() || + wrapped_message.message_ptr() == nullptr) { + return UnsupportedTypeName(); + } + auto* message = + static_cast(wrapped_message.message_ptr()); + return message->ShortDebugString(); +} + +absl::string_view ProtoMessageTypeAdapter::GetTypename( + const MessageWrapper& wrapped_message) const { + return descriptor_->full_name(); +} + +const LegacyTypeMutationApis* ProtoMessageTypeAdapter::GetMutationApis( + const MessageWrapper& wrapped_message) const { + // Defer checks for misuse on wrong message kind in the accessor calls. + return this; +} + +const LegacyTypeAccessApis* ProtoMessageTypeAdapter::GetAccessApis( + const MessageWrapper& wrapped_message) const { + // Defer checks for misuse on wrong message kind in the builder calls. + return this; +} + +absl::optional +ProtoMessageTypeAdapter::FindFieldByName(absl::string_view field_name) const { + if (descriptor_ == nullptr) { + return std::nullopt; + } + + const google::protobuf::FieldDescriptor* field_descriptor = + descriptor_->FindFieldByName(field_name); + + if (field_descriptor == nullptr) { + return std::nullopt; + } + + return LegacyTypeInfoApis::FieldDescription{field_descriptor->number(), + field_descriptor->name()}; +} + +absl::Status ProtoMessageTypeAdapter::ValidateSetFieldOp( + bool assertion, absl::string_view field, absl::string_view detail) const { + if (!assertion) { + return absl::InvalidArgumentError( + absl::Substitute("SetField failed on message $0, field '$1': $2", + descriptor_->full_name(), field, detail)); + } + return absl::OkStatus(); +} + +absl::StatusOr +ProtoMessageTypeAdapter::NewInstance( + cel::MemoryManagerRef memory_manager) const { + if (message_factory_ == nullptr) { + return absl::UnimplementedError( + absl::StrCat("Cannot create message ", descriptor_->name())); + } + + // This implementation requires arena-backed memory manager. + google::protobuf::Arena* arena = ProtoMemoryManagerArena(memory_manager); + const Message* prototype = message_factory_->GetPrototype(descriptor_); + + Message* msg = (prototype != nullptr) ? prototype->New(arena) : nullptr; + + if (msg == nullptr) { + return absl::InvalidArgumentError( + absl::StrCat("Failed to create message ", descriptor_->name())); + } + return MessageWrapper::Builder(msg); +} + +bool ProtoMessageTypeAdapter::DefinesField(absl::string_view field_name) const { + return descriptor_->FindFieldByName(field_name) != nullptr; +} + +absl::StatusOr ProtoMessageTypeAdapter::HasField( + absl::string_view field_name, const CelValue::MessageWrapper& value) const { + CEL_ASSIGN_OR_RETURN(const google::protobuf::Message* message, + UnwrapMessage(value, "HasField")); + return HasFieldImpl(message, descriptor_, field_name); +} + +absl::StatusOr ProtoMessageTypeAdapter::GetField( + absl::string_view field_name, const CelValue::MessageWrapper& instance, + ProtoWrapperTypeOptions unboxing_option, + cel::MemoryManagerRef memory_manager) const { + CEL_ASSIGN_OR_RETURN(const google::protobuf::Message* message, + UnwrapMessage(instance, "GetField")); + + return GetFieldImpl(message, descriptor_, field_name, unboxing_option, + memory_manager); +} + +absl::StatusOr +ProtoMessageTypeAdapter::Qualify( + absl::Span qualifiers, + const CelValue::MessageWrapper& instance, bool presence_test, + cel::MemoryManagerRef memory_manager) const { + CEL_ASSIGN_OR_RETURN(const google::protobuf::Message* message, + UnwrapMessage(instance, "Qualify")); + + return QualifyImpl(message, descriptor_, qualifiers, presence_test, + memory_manager); +} + +absl::Status ProtoMessageTypeAdapter::SetField( + const google::protobuf::FieldDescriptor* field, const CelValue& value, + google::protobuf::Arena* arena, google::protobuf::Message* message) const { + if (field->is_map()) { + constexpr int kKeyField = 1; + constexpr int kValueField = 2; + + const CelMap* cel_map; + CEL_RETURN_IF_ERROR(ValidateSetFieldOp( + value.GetValue(&cel_map) && cel_map != nullptr, + field->name(), + absl::StrCat("value is not CelMap - value is ", + CelValue::TypeName(value.type())))); + + auto entry_descriptor = field->message_type(); + + CEL_RETURN_IF_ERROR( + ValidateSetFieldOp(entry_descriptor != nullptr, field->name(), + "failed to find map entry descriptor")); + auto key_field_descriptor = entry_descriptor->FindFieldByNumber(kKeyField); + auto value_field_descriptor = + entry_descriptor->FindFieldByNumber(kValueField); + + CEL_RETURN_IF_ERROR( + ValidateSetFieldOp(key_field_descriptor != nullptr, field->name(), + "failed to find key field descriptor")); + + CEL_RETURN_IF_ERROR( + ValidateSetFieldOp(value_field_descriptor != nullptr, field->name(), + "failed to find value field descriptor")); + + bool prune_when_null = false; + if (value_field_descriptor->cpp_type() == + google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { + auto well_known_type = + value_field_descriptor->message_type()->well_known_type(); + if (well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_ANY && + well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE && + well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE && + well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT) { + prune_when_null = true; + } + } + + CEL_ASSIGN_OR_RETURN(const CelList* key_list, cel_map->ListKeys(arena)); + for (int i = 0; i < key_list->size(); i++) { + CelValue key = (*key_list).Get(arena, i); + + auto value = (*cel_map).Get(arena, key); + CEL_RETURN_IF_ERROR(ValidateSetFieldOp(value.has_value(), field->name(), + "error serializing CelMap")); + if (prune_when_null && value->IsNull()) { + continue; + } + Message* entry_msg = message->GetReflection()->AddMessage(message, field); + CEL_RETURN_IF_ERROR(internal::SetValueToSingleField( + key, key_field_descriptor, entry_msg, arena)); + CEL_RETURN_IF_ERROR(internal::SetValueToSingleField( + value.value(), value_field_descriptor, entry_msg, arena)); + } + + } else if (field->is_repeated()) { + const CelList* cel_list; + CEL_RETURN_IF_ERROR(ValidateSetFieldOp( + value.GetValue(&cel_list) && cel_list != nullptr, + field->name(), + absl::StrCat("expected CelList value - value is", + CelValue::TypeName(value.type())))); + + for (int i = 0; i < cel_list->size(); i++) { + CEL_RETURN_IF_ERROR(internal::AddValueToRepeatedField( + (*cel_list).Get(arena, i), field, message, arena)); + } + } else { + CEL_RETURN_IF_ERROR( + internal::SetValueToSingleField(value, field, message, arena)); + } + return absl::OkStatus(); +} + +absl::Status ProtoMessageTypeAdapter::SetField( + absl::string_view field_name, const CelValue& value, + cel::MemoryManagerRef memory_manager, + CelValue::MessageWrapper::Builder& instance) const { + // Assume proto arena implementation if this provider is used. + google::protobuf::Arena* arena = + cel::extensions::ProtoMemoryManagerArena(memory_manager); + + CEL_ASSIGN_OR_RETURN(google::protobuf::Message * mutable_message, + UnwrapMessage(instance, "SetField")); + + const google::protobuf::FieldDescriptor* field_descriptor = + descriptor_->FindFieldByName(field_name); + CEL_RETURN_IF_ERROR( + ValidateSetFieldOp(field_descriptor != nullptr, field_name, "not found")); + + return SetField(field_descriptor, value, arena, mutable_message); +} + +absl::Status ProtoMessageTypeAdapter::SetFieldByNumber( + int64_t field_number, const CelValue& value, + cel::MemoryManagerRef memory_manager, + CelValue::MessageWrapper::Builder& instance) const { + // Assume proto arena implementation if this provider is used. + google::protobuf::Arena* arena = + cel::extensions::ProtoMemoryManagerArena(memory_manager); + + CEL_ASSIGN_OR_RETURN(google::protobuf::Message * mutable_message, + UnwrapMessage(instance, "SetField")); + + const google::protobuf::FieldDescriptor* field_descriptor = + descriptor_->FindFieldByNumber(field_number); + CEL_RETURN_IF_ERROR(ValidateSetFieldOp( + field_descriptor != nullptr, absl::StrCat(field_number), "not found")); + + return SetField(field_descriptor, value, arena, mutable_message); +} + +absl::StatusOr ProtoMessageTypeAdapter::AdaptFromWellKnownType( + cel::MemoryManagerRef memory_manager, + CelValue::MessageWrapper::Builder instance) const { + // Assume proto arena implementation if this provider is used. + google::protobuf::Arena* arena = + cel::extensions::ProtoMemoryManagerArena(memory_manager); + CEL_ASSIGN_OR_RETURN(google::protobuf::Message * message, + UnwrapMessage(instance, "AdaptFromWellKnownType")); + return internal::UnwrapMessageToValue(message, &MessageCelValueFactory, + arena); +} + +bool ProtoMessageTypeAdapter::IsEqualTo( + const CelValue::MessageWrapper& instance, + const CelValue::MessageWrapper& other_instance) const { + absl::StatusOr lhs = + UnwrapMessage(instance, "IsEqualTo"); + absl::StatusOr rhs = + UnwrapMessage(other_instance, "IsEqualTo"); + if (!lhs.ok() || !rhs.ok()) { + // Treat this as though the underlying types are different, just return + // false. + return false; + } + return ProtoEquals(**lhs, **rhs); +} + +std::vector ProtoMessageTypeAdapter::ListFields( + const CelValue::MessageWrapper& instance) const { + return ListFieldsImpl(instance); +} + +const LegacyTypeInfoApis& GetGenericProtoTypeInfoInstance() { + return DucktypedMessageAdapter::GetSingleton(); +} + +} // namespace google::api::expr::runtime diff --git a/eval/public/structs/proto_message_type_adapter.h b/eval/public/structs/proto_message_type_adapter.h new file mode 100644 index 000000000..f2fc43a8a --- /dev/null +++ b/eval/public/structs/proto_message_type_adapter.h @@ -0,0 +1,129 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_PROTO_MESSAGE_TYPE_ADAPTER_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_PROTO_MESSAGE_TYPE_ADAPTER_H_ + +#include +#include + +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "common/memory.h" +#include "eval/public/cel_options.h" +#include "eval/public/cel_value.h" +#include "eval/public/structs/legacy_type_adapter.h" +#include "eval/public/structs/legacy_type_info_apis.h" +#include "google/protobuf/descriptor.h" + +namespace google::api::expr::runtime { + +// Implementation for legacy struct (message) type apis using reflection. +// +// Note: The type info API implementation attached to message values is +// generally the duck-typed instance to support the default behavior of +// deferring to the protobuf reflection apis on the message instance. +class ProtoMessageTypeAdapter : public LegacyTypeInfoApis, + public LegacyTypeAccessApis, + public LegacyTypeMutationApis { + public: + ProtoMessageTypeAdapter(const google::protobuf::Descriptor* descriptor, + google::protobuf::MessageFactory* message_factory) + : message_factory_(message_factory), descriptor_(descriptor) {} + + ~ProtoMessageTypeAdapter() override = default; + + // Implement LegacyTypeInfoApis + std::string DebugString(const MessageWrapper& wrapped_message) const override; + + absl::string_view GetTypename( + const MessageWrapper& wrapped_message) const override; + + const google::protobuf::Descriptor* absl_nullable GetDescriptor( + const MessageWrapper& wrapped_message [[maybe_unused]]) const override { + return descriptor_; + } + + const LegacyTypeAccessApis* GetAccessApis( + const MessageWrapper& wrapped_message) const override; + + const LegacyTypeMutationApis* GetMutationApis( + const MessageWrapper& wrapped_message) const override; + + absl::optional FindFieldByName( + absl::string_view field_name) const override; + + // Implement LegacyTypeMutation APIs. + absl::StatusOr NewInstance( + cel::MemoryManagerRef memory_manager) const override; + + bool DefinesField(absl::string_view field_name) const override; + + absl::Status SetField( + absl::string_view field_name, const CelValue& value, + cel::MemoryManagerRef memory_manager, + CelValue::MessageWrapper::Builder& instance) const override; + + absl::Status SetFieldByNumber( + int64_t field_number, const CelValue& value, + cel::MemoryManagerRef memory_manager, + CelValue::MessageWrapper::Builder& instance) const override; + + absl::StatusOr AdaptFromWellKnownType( + cel::MemoryManagerRef memory_manager, + CelValue::MessageWrapper::Builder instance) const override; + + // Implement LegacyTypeAccessAPIs. + absl::StatusOr GetField( + absl::string_view field_name, const CelValue::MessageWrapper& instance, + ProtoWrapperTypeOptions unboxing_option, + cel::MemoryManagerRef memory_manager) const override; + + absl::StatusOr HasField( + absl::string_view field_name, + const CelValue::MessageWrapper& value) const override; + + absl::StatusOr Qualify( + absl::Span qualifiers, + const CelValue::MessageWrapper& instance, bool presence_test, + cel::MemoryManagerRef memory_manager) const override; + + bool IsEqualTo(const CelValue::MessageWrapper& instance, + const CelValue::MessageWrapper& other_instance) const override; + + std::vector ListFields( + const CelValue::MessageWrapper& instance) const override; + + private: + // Helper for standardizing error messages for SetField operation. + absl::Status ValidateSetFieldOp(bool assertion, absl::string_view field, + absl::string_view detail) const; + + absl::Status SetField(const google::protobuf::FieldDescriptor* field, + const CelValue& value, google::protobuf::Arena* arena, + google::protobuf::Message* message) const; + + google::protobuf::MessageFactory* message_factory_; + const google::protobuf::Descriptor* descriptor_; +}; + +// Returns a TypeInfo provider representing an arbitrary message. +// This allows for the legacy duck-typed behavior of messages on field access +// instead of expecting a particular message type given a TypeInfo. +const LegacyTypeInfoApis& GetGenericProtoTypeInfoInstance(); + +} // namespace google::api::expr::runtime + +#endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_PROTO_MESSAGE_TYPE_ADAPTER_H_ diff --git a/eval/public/structs/proto_message_type_adapter_test.cc b/eval/public/structs/proto_message_type_adapter_test.cc new file mode 100644 index 000000000..e28d76102 --- /dev/null +++ b/eval/public/structs/proto_message_type_adapter_test.cc @@ -0,0 +1,1411 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/structs/proto_message_type_adapter.h" + +#include + +#include "google/protobuf/wrappers.pb.h" +#include "google/protobuf/descriptor.pb.h" +#include "absl/status/status.h" +#include "base/attribute.h" +#include "common/value.h" +#include "eval/public/cel_value.h" +#include "eval/public/containers/container_backed_list_impl.h" +#include "eval/public/containers/container_backed_map_impl.h" +#include "eval/public/message_wrapper.h" +#include "eval/public/structs/legacy_type_adapter.h" +#include "eval/public/structs/legacy_type_info_apis.h" +#include "eval/public/testing/matchers.h" +#include "eval/testutil/test_message.pb.h" +#include "extensions/protobuf/memory_manager.h" +#include "internal/proto_matchers.h" +#include "internal/testing.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" +#include "google/protobuf/message_lite.h" + +namespace google::api::expr::runtime { +namespace { + +using ::absl_testing::IsOkAndHolds; +using ::absl_testing::StatusIs; +using ::cel::ProtoWrapperTypeOptions; +using ::cel::extensions::ProtoMemoryManagerRef; +using ::cel::internal::test::EqualsProto; +using ::google::protobuf::Int64Value; +using ::testing::_; +using ::testing::AllOf; +using ::testing::ElementsAre; +using ::testing::Eq; +using ::testing::Field; +using ::testing::HasSubstr; +using ::testing::Optional; +using ::testing::Truly; + +using LegacyQualifyResult = LegacyTypeAccessApis::LegacyQualifyResult; + +class ProtoMessageTypeAccessorTest : public testing::TestWithParam { + public: + ProtoMessageTypeAccessorTest() + : type_specific_instance_( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory()) {} + + const LegacyTypeAccessApis& GetAccessApis() { + bool use_generic_instance = GetParam(); + if (use_generic_instance) { + // implementation detail: in general, type info implementations may + // return a different accessor object based on the message instance, but + // this implementation returns the same one no matter the message. + return *GetGenericProtoTypeInfoInstance().GetAccessApis(dummy_); + + } else { + return type_specific_instance_; + } + } + + private: + ProtoMessageTypeAdapter type_specific_instance_; + CelValue::MessageWrapper dummy_; +}; + +TEST_P(ProtoMessageTypeAccessorTest, HasFieldSingular) { + const LegacyTypeAccessApis& accessor = GetAccessApis(); + TestMessage example; + + MessageWrapper value(&example, nullptr); + + EXPECT_THAT(accessor.HasField("int64_value", value), IsOkAndHolds(false)); + example.set_int64_value(10); + EXPECT_THAT(accessor.HasField("int64_value", value), IsOkAndHolds(true)); +} + +TEST_P(ProtoMessageTypeAccessorTest, HasFieldRepeated) { + const LegacyTypeAccessApis& accessor = GetAccessApis(); + + TestMessage example; + + MessageWrapper value(&example, nullptr); + + EXPECT_THAT(accessor.HasField("int64_list", value), IsOkAndHolds(false)); + example.add_int64_list(10); + EXPECT_THAT(accessor.HasField("int64_list", value), IsOkAndHolds(true)); +} + +TEST_P(ProtoMessageTypeAccessorTest, HasFieldMap) { + const LegacyTypeAccessApis& accessor = GetAccessApis(); + + TestMessage example; + example.set_int64_value(10); + + MessageWrapper value(&example, nullptr); + + EXPECT_THAT(accessor.HasField("int64_int32_map", value), IsOkAndHolds(false)); + (*example.mutable_int64_int32_map())[2] = 3; + EXPECT_THAT(accessor.HasField("int64_int32_map", value), IsOkAndHolds(true)); +} + +TEST_P(ProtoMessageTypeAccessorTest, HasFieldUnknownField) { + const LegacyTypeAccessApis& accessor = GetAccessApis(); + + TestMessage example; + example.set_int64_value(10); + + MessageWrapper value(&example, nullptr); + + EXPECT_THAT(accessor.HasField("unknown_field", value), + StatusIs(absl::StatusCode::kNotFound)); +} + +TEST_P(ProtoMessageTypeAccessorTest, HasFieldNonMessageType) { + const LegacyTypeAccessApis& accessor = GetAccessApis(); + + MessageWrapper value(static_cast(nullptr), + nullptr); + + EXPECT_THAT(accessor.HasField("unknown_field", value), + StatusIs(absl::StatusCode::kInternal)); +} + +TEST_P(ProtoMessageTypeAccessorTest, GetFieldSingular) { + google::protobuf::Arena arena; + const LegacyTypeAccessApis& accessor = GetAccessApis(); + + auto manager = ProtoMemoryManagerRef(&arena); + + TestMessage example; + example.set_int64_value(10); + + MessageWrapper value(&example, nullptr); + + EXPECT_THAT(accessor.GetField("int64_value", value, + ProtoWrapperTypeOptions::kUnsetNull, manager), + IsOkAndHolds(test::IsCelInt64(10))); +} + +TEST_P(ProtoMessageTypeAccessorTest, GetFieldNoSuchField) { + google::protobuf::Arena arena; + const LegacyTypeAccessApis& accessor = GetAccessApis(); + + auto manager = ProtoMemoryManagerRef(&arena); + + TestMessage example; + example.set_int64_value(10); + + MessageWrapper value(&example, nullptr); + + EXPECT_THAT(accessor.GetField("unknown_field", value, + ProtoWrapperTypeOptions::kUnsetNull, manager), + IsOkAndHolds(test::IsCelError(StatusIs( + absl::StatusCode::kNotFound, HasSubstr("unknown_field"))))); +} + +TEST_P(ProtoMessageTypeAccessorTest, GetFieldNotAMessage) { + google::protobuf::Arena arena; + const LegacyTypeAccessApis& accessor = GetAccessApis(); + + auto manager = ProtoMemoryManagerRef(&arena); + + MessageWrapper value(static_cast(nullptr), + nullptr); + + EXPECT_THAT(accessor.GetField("int64_value", value, + ProtoWrapperTypeOptions::kUnsetNull, manager), + StatusIs(absl::StatusCode::kInternal)); +} + +TEST_P(ProtoMessageTypeAccessorTest, GetFieldRepeated) { + google::protobuf::Arena arena; + const LegacyTypeAccessApis& accessor = GetAccessApis(); + + auto manager = ProtoMemoryManagerRef(&arena); + + TestMessage example; + example.add_int64_list(10); + example.add_int64_list(20); + + MessageWrapper value(&example, nullptr); + + ASSERT_OK_AND_ASSIGN( + CelValue result, + accessor.GetField("int64_list", value, + ProtoWrapperTypeOptions::kUnsetNull, manager)); + + const CelList* held_value; + ASSERT_TRUE(result.GetValue(&held_value)) << result.DebugString(); + + EXPECT_EQ(held_value->size(), 2); + EXPECT_THAT((*held_value)[0], test::IsCelInt64(10)); + EXPECT_THAT((*held_value)[1], test::IsCelInt64(20)); +} + +TEST_P(ProtoMessageTypeAccessorTest, GetFieldMap) { + google::protobuf::Arena arena; + const LegacyTypeAccessApis& accessor = GetAccessApis(); + + auto manager = ProtoMemoryManagerRef(&arena); + + TestMessage example; + (*example.mutable_int64_int32_map())[10] = 20; + + MessageWrapper value(&example, nullptr); + + ASSERT_OK_AND_ASSIGN( + CelValue result, + accessor.GetField("int64_int32_map", value, + ProtoWrapperTypeOptions::kUnsetNull, manager)); + + const CelMap* held_value; + ASSERT_TRUE(result.GetValue(&held_value)) << result.DebugString(); + + EXPECT_EQ(held_value->size(), 1); + EXPECT_THAT((*held_value)[CelValue::CreateInt64(10)], + Optional(test::IsCelInt64(20))); +} + +TEST_P(ProtoMessageTypeAccessorTest, GetFieldWrapperType) { + google::protobuf::Arena arena; + const LegacyTypeAccessApis& accessor = GetAccessApis(); + + auto manager = ProtoMemoryManagerRef(&arena); + + TestMessage example; + example.mutable_int64_wrapper_value()->set_value(10); + + MessageWrapper value(&example, nullptr); + + EXPECT_THAT(accessor.GetField("int64_wrapper_value", value, + ProtoWrapperTypeOptions::kUnsetNull, manager), + IsOkAndHolds(test::IsCelInt64(10))); +} + +TEST_P(ProtoMessageTypeAccessorTest, GetFieldWrapperTypeUnsetNullUnbox) { + google::protobuf::Arena arena; + const LegacyTypeAccessApis& accessor = GetAccessApis(); + + auto manager = ProtoMemoryManagerRef(&arena); + + TestMessage example; + + MessageWrapper value(&example, nullptr); + + EXPECT_THAT(accessor.GetField("int64_wrapper_value", value, + ProtoWrapperTypeOptions::kUnsetNull, manager), + IsOkAndHolds(test::IsCelNull())); + + // Wrapper field present, but default value. + example.mutable_int64_wrapper_value()->clear_value(); + EXPECT_THAT(accessor.GetField("int64_wrapper_value", value, + ProtoWrapperTypeOptions::kUnsetNull, manager), + IsOkAndHolds(test::IsCelInt64(_))); +} + +TEST_P(ProtoMessageTypeAccessorTest, + GetFieldWrapperTypeUnsetDefaultValueUnbox) { + google::protobuf::Arena arena; + const LegacyTypeAccessApis& accessor = GetAccessApis(); + + auto manager = ProtoMemoryManagerRef(&arena); + + TestMessage example; + + MessageWrapper value(&example, nullptr); + + EXPECT_THAT( + accessor.GetField("int64_wrapper_value", value, + ProtoWrapperTypeOptions::kUnsetProtoDefault, manager), + IsOkAndHolds(test::IsCelInt64(_))); + + // Wrapper field present with unset value is used to signal Null, but legacy + // behavior just returns the proto default value. + example.mutable_int64_wrapper_value()->clear_value(); + // Same behavior for this option. + EXPECT_THAT( + accessor.GetField("int64_wrapper_value", value, + ProtoWrapperTypeOptions::kUnsetProtoDefault, manager), + IsOkAndHolds(test::IsCelInt64(_))); +} + +TEST_P(ProtoMessageTypeAccessorTest, IsEqualTo) { + const LegacyTypeAccessApis& accessor = GetAccessApis(); + + TestMessage example; + example.mutable_int64_wrapper_value()->set_value(10); + TestMessage example2; + example2.mutable_int64_wrapper_value()->set_value(10); + + MessageWrapper value(&example, nullptr); + MessageWrapper value2(&example2, nullptr); + + EXPECT_TRUE(accessor.IsEqualTo(value, value2)); + EXPECT_TRUE(accessor.IsEqualTo(value2, value)); +} + +TEST_P(ProtoMessageTypeAccessorTest, IsEqualToSameTypeInequal) { + const LegacyTypeAccessApis& accessor = GetAccessApis(); + + TestMessage example; + example.mutable_int64_wrapper_value()->set_value(10); + TestMessage example2; + example2.mutable_int64_wrapper_value()->set_value(12); + + MessageWrapper value(&example, nullptr); + MessageWrapper value2(&example2, nullptr); + + EXPECT_FALSE(accessor.IsEqualTo(value, value2)); + EXPECT_FALSE(accessor.IsEqualTo(value2, value)); +} + +TEST_P(ProtoMessageTypeAccessorTest, IsEqualToDifferentTypeInequal) { + const LegacyTypeAccessApis& accessor = GetAccessApis(); + + TestMessage example; + example.mutable_int64_wrapper_value()->set_value(10); + Int64Value example2; + example2.set_value(10); + + MessageWrapper value(&example, nullptr); + MessageWrapper value2(&example2, nullptr); + + EXPECT_FALSE(accessor.IsEqualTo(value, value2)); + EXPECT_FALSE(accessor.IsEqualTo(value2, value)); +} + +TEST_P(ProtoMessageTypeAccessorTest, IsEqualToNonMessageInequal) { + const LegacyTypeAccessApis& accessor = GetAccessApis(); + + TestMessage example; + example.mutable_int64_wrapper_value()->set_value(10); + TestMessage example2; + example2.mutable_int64_wrapper_value()->set_value(10); + + MessageWrapper value(&example, nullptr); + // Upcast to message lite to prevent unwrapping to message. + MessageWrapper value2(static_cast(&example2), + nullptr); + + EXPECT_FALSE(accessor.IsEqualTo(value, value2)); + EXPECT_FALSE(accessor.IsEqualTo(value2, value)); +} + +INSTANTIATE_TEST_SUITE_P(GenericAndSpecific, ProtoMessageTypeAccessorTest, + testing::Bool()); + +TEST(GetGenericProtoTypeInfoInstance, GetTypeName) { + const LegacyTypeInfoApis& info_api = GetGenericProtoTypeInfoInstance(); + + TestMessage test_message; + CelValue::MessageWrapper wrapped_message(&test_message, nullptr); + + EXPECT_EQ(info_api.GetTypename(wrapped_message), test_message.GetTypeName()); +} + +TEST(GetGenericProtoTypeInfoInstance, DebugString) { + const LegacyTypeInfoApis& info_api = GetGenericProtoTypeInfoInstance(); + + TestMessage test_message; + test_message.set_string_value("abcd"); + CelValue::MessageWrapper wrapped_message(&test_message, nullptr); + + EXPECT_EQ(info_api.DebugString(wrapped_message), + test_message.ShortDebugString()); +} + +TEST(GetGenericProtoTypeInfoInstance, GetAccessApis) { + const LegacyTypeInfoApis& info_api = GetGenericProtoTypeInfoInstance(); + + TestMessage test_message; + test_message.set_string_value("abcd"); + CelValue::MessageWrapper wrapped_message(&test_message, nullptr); + + auto* accessor = info_api.GetAccessApis(wrapped_message); + google::protobuf::Arena arena; + auto manager = ProtoMemoryManagerRef(&arena); + + ASSERT_OK_AND_ASSIGN( + CelValue result, + accessor->GetField("string_value", wrapped_message, + ProtoWrapperTypeOptions::kUnsetNull, manager)); + EXPECT_THAT(result, test::IsCelString("abcd")); +} + +TEST(GetGenericProtoTypeInfoInstance, FallbackForNonMessage) { + const LegacyTypeInfoApis& info_api = GetGenericProtoTypeInfoInstance(); + + TestMessage test_message; + test_message.set_string_value("abcd"); + // Upcast to signal no google::protobuf::Message / reflection support. + CelValue::MessageWrapper wrapped_message( + static_cast(&test_message), nullptr); + + EXPECT_EQ(info_api.GetTypename(wrapped_message), ""); + EXPECT_EQ(info_api.DebugString(wrapped_message), ""); + + // Check for not-null. + CelValue::MessageWrapper null_message( + static_cast(nullptr), nullptr); + + EXPECT_EQ(info_api.GetTypename(null_message), ""); + EXPECT_EQ(info_api.DebugString(null_message), ""); +} + +TEST(ProtoMessageTypeAdapter, NewInstance) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory()); + auto manager = ProtoMemoryManagerRef(&arena); + + ASSERT_OK_AND_ASSIGN(CelValue::MessageWrapper::Builder result, + adapter.NewInstance(manager)); + EXPECT_EQ(result.message_ptr()->SerializeAsString(), ""); +} + +TEST(ProtoMessageTypeAdapter, NewInstanceUnsupportedDescriptor) { + google::protobuf::Arena arena; + + google::protobuf::DescriptorPool pool; + google::protobuf::FileDescriptorProto faked_file; + faked_file.set_name("faked.proto"); + faked_file.set_syntax("proto3"); + faked_file.set_package("google.api.expr.runtime"); + auto msg_descriptor = faked_file.add_message_type(); + msg_descriptor->set_name("FakeMessage"); + pool.BuildFile(faked_file); + + ProtoMessageTypeAdapter adapter( + pool.FindMessageTypeByName("google.api.expr.runtime.FakeMessage"), + google::protobuf::MessageFactory::generated_factory()); + auto manager = ProtoMemoryManagerRef(&arena); + + // Message factory doesn't know how to create our custom message, even though + // we provided a descriptor for it. + EXPECT_THAT( + adapter.NewInstance(manager), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("FakeMessage"))); +} + +TEST(ProtoMessageTypeAdapter, DefinesField) { + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory()); + + EXPECT_TRUE(adapter.DefinesField("int64_value")); + EXPECT_FALSE(adapter.DefinesField("not_a_field")); +} + +TEST(ProtoMessageTypeAdapter, SetFieldSingular) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory()); + auto manager = ProtoMemoryManagerRef(&arena); + + ASSERT_OK_AND_ASSIGN(CelValue::MessageWrapper::Builder value, + adapter.NewInstance(manager)); + + ASSERT_OK(adapter.SetField("int64_value", CelValue::CreateInt64(10), manager, + value)); + + TestMessage message; + message.set_int64_value(10); + EXPECT_EQ(value.message_ptr()->SerializeAsString(), + message.SerializeAsString()); + + ASSERT_THAT(adapter.SetField("not_a_field", CelValue::CreateInt64(10), + manager, value), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("field 'not_a_field': not found"))); +} + +TEST(ProtoMessageTypeAdapter, SetFieldRepeated) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory()); + auto manager = ProtoMemoryManagerRef(&arena); + + ContainerBackedListImpl list( + {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}); + CelValue value_to_set = CelValue::CreateList(&list); + ASSERT_OK_AND_ASSIGN(CelValue::MessageWrapper::Builder instance, + adapter.NewInstance(manager)); + + ASSERT_OK(adapter.SetField("int64_list", value_to_set, manager, instance)); + + TestMessage message; + message.add_int64_list(1); + message.add_int64_list(2); + + EXPECT_EQ(instance.message_ptr()->SerializeAsString(), + message.SerializeAsString()); +} + +TEST(ProtoMessageTypeAdapter, SetFieldNotAField) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory()); + auto manager = ProtoMemoryManagerRef(&arena); + + ASSERT_OK_AND_ASSIGN(CelValue::MessageWrapper::Builder instance, + adapter.NewInstance(manager)); + + ASSERT_THAT(adapter.SetField("not_a_field", CelValue::CreateInt64(10), + manager, instance), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("field 'not_a_field': not found"))); +} + +TEST(ProtoMesssageTypeAdapter, SetFieldWrongType) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory()); + auto manager = ProtoMemoryManagerRef(&arena); + + ContainerBackedListImpl list( + {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}); + CelValue list_value = CelValue::CreateList(&list); + + CelMapBuilder builder; + ASSERT_OK(builder.Add(CelValue::CreateInt64(1), CelValue::CreateInt64(2))); + ASSERT_OK(builder.Add(CelValue::CreateInt64(2), CelValue::CreateInt64(4))); + + CelValue map_value = CelValue::CreateMap(&builder); + + CelValue int_value = CelValue::CreateInt64(42); + + ASSERT_OK_AND_ASSIGN(CelValue::MessageWrapper::Builder instance, + adapter.NewInstance(manager)); + + EXPECT_THAT(adapter.SetField("int64_value", map_value, manager, instance), + StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(adapter.SetField("int64_value", list_value, manager, instance), + StatusIs(absl::StatusCode::kInvalidArgument)); + + EXPECT_THAT( + adapter.SetField("int64_int32_map", list_value, manager, instance), + StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(adapter.SetField("int64_int32_map", int_value, manager, instance), + StatusIs(absl::StatusCode::kInvalidArgument)); + + EXPECT_THAT(adapter.SetField("int64_list", int_value, manager, instance), + StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(adapter.SetField("int64_list", map_value, manager, instance), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(ProtoMesssageTypeAdapter, SetFieldNotAMessage) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory()); + auto manager = ProtoMemoryManagerRef(&arena); + + CelValue int_value = CelValue::CreateInt64(42); + CelValue::MessageWrapper::Builder instance( + static_cast(nullptr)); + + EXPECT_THAT(adapter.SetField("int64_value", int_value, manager, instance), + StatusIs(absl::StatusCode::kInternal)); +} + +TEST(ProtoMesssageTypeAdapter, SetFieldNullMessage) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory()); + auto manager = ProtoMemoryManagerRef(&arena); + + CelValue int_value = CelValue::CreateInt64(42); + CelValue::MessageWrapper::Builder instance( + static_cast(nullptr)); + + EXPECT_THAT(adapter.SetField("int64_value", int_value, manager, instance), + StatusIs(absl::StatusCode::kInternal)); +} + +TEST(ProtoMessageTypeAdapter, AdaptFromWellKnownType) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.protobuf.Int64Value"), + google::protobuf::MessageFactory::generated_factory()); + auto manager = ProtoMemoryManagerRef(&arena); + + ASSERT_OK_AND_ASSIGN(CelValue::MessageWrapper::Builder instance, + adapter.NewInstance(manager)); + ASSERT_OK( + adapter.SetField("value", CelValue::CreateInt64(42), manager, instance)); + + ASSERT_OK_AND_ASSIGN(CelValue value, + adapter.AdaptFromWellKnownType(manager, instance)); + + EXPECT_THAT(value, test::IsCelInt64(42)); +} + +TEST(ProtoMessageTypeAdapter, AdaptFromWellKnownTypeUnspecial) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory()); + auto manager = ProtoMemoryManagerRef(&arena); + + ASSERT_OK_AND_ASSIGN(CelValue::MessageWrapper::Builder instance, + adapter.NewInstance(manager)); + + ASSERT_OK(adapter.SetField("int64_value", CelValue::CreateInt64(42), manager, + instance)); + ASSERT_OK_AND_ASSIGN(CelValue value, + adapter.AdaptFromWellKnownType(manager, instance)); + + // TestMessage should not be converted to a CEL primitive type. + EXPECT_THAT(value, test::IsCelMessage(EqualsProto("int64_value: 42"))); +} + +TEST(ProtoMessageTypeAdapter, AdaptFromWellKnownTypeNotAMessageError) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory()); + auto manager = ProtoMemoryManagerRef(&arena); + + CelValue::MessageWrapper::Builder instance( + static_cast(nullptr)); + + // Interpreter guaranteed to call this with a message type, otherwise, + // something has broken. + EXPECT_THAT(adapter.AdaptFromWellKnownType(manager, instance), + StatusIs(absl::StatusCode::kInternal)); +} + +TEST(ProtoMesssageTypeAdapter, TypeInfoDebug) { + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory()); + + TestMessage message; + message.set_int64_value(42); + EXPECT_THAT(adapter.DebugString(MessageWrapper(&message, &adapter)), + HasSubstr(message.ShortDebugString())); + + EXPECT_THAT(adapter.DebugString(MessageWrapper()), + HasSubstr("")); +} + +TEST(ProtoMesssageTypeAdapter, TypeInfoName) { + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory()); + + EXPECT_EQ(adapter.GetTypename(MessageWrapper()), + "google.api.expr.runtime.TestMessage"); +} + +TEST(ProtoMesssageTypeAdapter, FindFieldFound) { + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory()); + + EXPECT_THAT( + adapter.FindFieldByName("int64_value"), + Optional(Truly([](const LegacyTypeInfoApis::FieldDescription& desc) { + return desc.name == "int64_value" && desc.number == 2; + }))) + << "expected field int64_value: 2"; +} + +TEST(ProtoMesssageTypeAdapter, FindFieldNotFound) { + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory()); + + EXPECT_EQ(adapter.FindFieldByName("foo_not_a_field"), std::nullopt); +} + +TEST(ProtoMesssageTypeAdapter, TypeInfoMutator) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory()); + auto manager = ProtoMemoryManagerRef(&arena); + + const LegacyTypeMutationApis* api = adapter.GetMutationApis(MessageWrapper()); + ASSERT_NE(api, nullptr); + + ASSERT_OK_AND_ASSIGN(MessageWrapper::Builder builder, + api->NewInstance(manager)); + EXPECT_NE(google::protobuf::DynamicCastMessage(builder.message_ptr()), + nullptr); +} + +TEST(ProtoMesssageTypeAdapter, TypeInfoAccesor) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory()); + auto manager = ProtoMemoryManagerRef(&arena); + + TestMessage message; + message.set_int64_value(42); + CelValue::MessageWrapper wrapped(&message, &adapter); + + const LegacyTypeAccessApis* api = adapter.GetAccessApis(MessageWrapper()); + ASSERT_NE(api, nullptr); + + EXPECT_THAT(api->GetField("int64_value", wrapped, + ProtoWrapperTypeOptions::kUnsetNull, manager), + IsOkAndHolds(test::IsCelInt64(42))); +} + +TEST(ProtoMesssageTypeAdapter, Qualify) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory()); + auto manager = ProtoMemoryManagerRef(&arena); + + TestMessage message; + message.mutable_message_value()->set_int64_value(42); + CelValue::MessageWrapper wrapped(&message, &adapter); + + const LegacyTypeAccessApis* api = adapter.GetAccessApis(MessageWrapper()); + ASSERT_NE(api, nullptr); + + std::vector qualfiers{ + cel::FieldSpecifier{12, "message_value"}, + cel::FieldSpecifier{2, "int64_value"}}; + EXPECT_THAT( + api->Qualify(qualfiers, wrapped, + /*presence_test=*/false, manager), + IsOkAndHolds(Field(&LegacyQualifyResult::value, test::IsCelInt64(42)))); +} + +TEST(ProtoMesssageTypeAdapter, QualifyDynamicFieldAccessUnsupported) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory()); + auto manager = ProtoMemoryManagerRef(&arena); + + TestMessage message; + message.mutable_message_value()->set_int64_value(42); + CelValue::MessageWrapper wrapped(&message, &adapter); + + const LegacyTypeAccessApis* api = adapter.GetAccessApis(MessageWrapper()); + ASSERT_NE(api, nullptr); + + std::vector qualfiers{ + cel::FieldSpecifier{12, "message_value"}, + cel::AttributeQualifier::OfString("int64_value")}; + EXPECT_THAT(api->Qualify(qualfiers, wrapped, + /*presence_test=*/false, manager), + StatusIs(absl::StatusCode::kUnimplemented)); +} + +TEST(ProtoMesssageTypeAdapter, QualifyNoSuchField) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory()); + auto manager = ProtoMemoryManagerRef(&arena); + + TestMessage message; + message.mutable_message_value()->set_int64_value(42); + CelValue::MessageWrapper wrapped(&message, &adapter); + + const LegacyTypeAccessApis* api = adapter.GetAccessApis(MessageWrapper()); + ASSERT_NE(api, nullptr); + + std::vector qualfiers{ + cel::FieldSpecifier{12, "message_value"}, + cel::FieldSpecifier{99, "not_a_field"}, + cel::FieldSpecifier{2, "int64_value"}}; + EXPECT_THAT(api->Qualify(qualfiers, wrapped, + /*presence_test=*/false, manager), + IsOkAndHolds(Field( + &LegacyQualifyResult::value, + test::IsCelError(StatusIs(absl::StatusCode::kNotFound, + HasSubstr("no_such_field")))))); +} + +TEST(ProtoMesssageTypeAdapter, QualifyHasNoSuchField) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory()); + auto manager = ProtoMemoryManagerRef(&arena); + + TestMessage message; + message.mutable_message_value()->set_int64_value(42); + CelValue::MessageWrapper wrapped(&message, &adapter); + + const LegacyTypeAccessApis* api = adapter.GetAccessApis(MessageWrapper()); + ASSERT_NE(api, nullptr); + + std::vector qualfiers{ + cel::FieldSpecifier{12, "message_value"}, + cel::FieldSpecifier{99, "not_a_field"}}; + EXPECT_THAT(api->Qualify(qualfiers, wrapped, + /*presence_test=*/true, manager), + IsOkAndHolds(Field( + &LegacyQualifyResult::value, + test::IsCelError(StatusIs(absl::StatusCode::kNotFound, + HasSubstr("no_such_field")))))); +} + +TEST(ProtoMesssageTypeAdapter, QualifyNoSuchFieldLeaf) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory()); + auto manager = ProtoMemoryManagerRef(&arena); + + TestMessage message; + message.mutable_message_value()->set_int64_value(42); + CelValue::MessageWrapper wrapped(&message, &adapter); + + const LegacyTypeAccessApis* api = adapter.GetAccessApis(MessageWrapper()); + ASSERT_NE(api, nullptr); + + std::vector qualfiers{ + cel::FieldSpecifier{12, "message_value"}, + cel::FieldSpecifier{99, "not_a_field"}}; + EXPECT_THAT(api->Qualify(qualfiers, wrapped, + /*presence_test=*/false, manager), + IsOkAndHolds(Field( + &LegacyQualifyResult::value, + test::IsCelError(StatusIs(absl::StatusCode::kNotFound, + HasSubstr("no_such_field")))))); +} + +TEST(ProtoMesssageTypeAdapter, QualifyMapTraversalSupport) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory()); + auto manager = ProtoMemoryManagerRef(&arena); + + TestMessage message; + (*message.mutable_string_message_map())["@key"].set_int64_value(42); + CelValue::MessageWrapper wrapped(&message, &adapter); + + const LegacyTypeAccessApis* api = adapter.GetAccessApis(MessageWrapper()); + ASSERT_NE(api, nullptr); + + std::vector qualfiers{ + cel::FieldSpecifier{210, "string_message_map"}, + cel::AttributeQualifier::OfString("@key"), + cel::FieldSpecifier{2, "int64_value"}}; + + EXPECT_THAT( + api->Qualify(qualfiers, wrapped, + /*presence_test=*/false, manager), + IsOkAndHolds(Field(&LegacyQualifyResult::value, test::IsCelInt64(42)))); +} + +TEST(ProtoMesssageTypeAdapter, TypedFieldAccessOnMapUnsupported) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory()); + auto manager = ProtoMemoryManagerRef(&arena); + + TestMessage message; + (*message.mutable_string_message_map())["@key"].set_int64_value(42); + CelValue::MessageWrapper wrapped(&message, &adapter); + + const LegacyTypeAccessApis* api = adapter.GetAccessApis(MessageWrapper()); + ASSERT_NE(api, nullptr); + + std::vector qualfiers{ + cel::FieldSpecifier{210, "string_message_map"}, + // This is probably a bug, but defer to evaluator for consistent handling. + cel::FieldSpecifier{2, "value"}, cel::FieldSpecifier{2, "int64_value"}}; + + EXPECT_THAT(api->Qualify(qualfiers, wrapped, + /*presence_test=*/false, manager), + StatusIs(absl::StatusCode::kUnimplemented)); +} + +TEST(ProtoMesssageTypeAdapter, QualifyMapTraversalWrongKeyType) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory()); + auto manager = ProtoMemoryManagerRef(&arena); + + TestMessage message; + (*message.mutable_string_message_map())["@key"].set_int64_value(42); + CelValue::MessageWrapper wrapped(&message, &adapter); + + const LegacyTypeAccessApis* api = adapter.GetAccessApis(MessageWrapper()); + ASSERT_NE(api, nullptr); + + std::vector qualfiers{ + cel::FieldSpecifier{210, "string_message_map"}, + cel::AttributeQualifier::OfInt(0), cel::FieldSpecifier{2, "int64_value"}}; + + EXPECT_THAT(api->Qualify(qualfiers, wrapped, + /*presence_test=*/false, manager), + IsOkAndHolds(Field(&LegacyQualifyResult::value, + test::IsCelError(StatusIs( + absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid map key type")))))); +} + +TEST(ProtoMesssageTypeAdapter, QualifyMapTraversalHasWrongKeyType) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory()); + auto manager = ProtoMemoryManagerRef(&arena); + + TestMessage message; + (*message.mutable_string_message_map())["@key"].set_int64_value(42); + CelValue::MessageWrapper wrapped(&message, &adapter); + + const LegacyTypeAccessApis* api = adapter.GetAccessApis(MessageWrapper()); + ASSERT_NE(api, nullptr); + + std::vector qualfiers{ + cel::FieldSpecifier{210, "string_message_map"}, + cel::AttributeQualifier::OfInt(0)}; + + EXPECT_THAT(api->Qualify(qualfiers, wrapped, + /*presence_test=*/true, manager), + IsOkAndHolds(Field(&LegacyQualifyResult::value, + test::IsCelError(StatusIs( + absl::StatusCode::kUnknown, + HasSubstr("No matching overloads")))))); +} + +TEST(ProtoMesssageTypeAdapter, QualifyMapTraversalSupportNoSuchKey) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory()); + auto manager = ProtoMemoryManagerRef(&arena); + + TestMessage message; + (*message.mutable_string_message_map())["@key"].set_int64_value(42); + CelValue::MessageWrapper wrapped(&message, &adapter); + + const LegacyTypeAccessApis* api = adapter.GetAccessApis(MessageWrapper()); + ASSERT_NE(api, nullptr); + + std::vector qualfiers{ + cel::FieldSpecifier{210, "string_message_map"}, + cel::AttributeQualifier::OfString("bad_key"), + cel::FieldSpecifier{2, "int64_value"}}; + + EXPECT_THAT(api->Qualify(qualfiers, wrapped, + /*presence_test=*/false, manager), + IsOkAndHolds(Field( + &LegacyQualifyResult::value, + test::IsCelError(StatusIs(absl::StatusCode::kNotFound, + HasSubstr("Key not found")))))); +} + +TEST(ProtoMesssageTypeAdapter, QualifyMapTraversalInt32Key) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory()); + auto manager = ProtoMemoryManagerRef(&arena); + + TestMessage message; + (*message.mutable_int32_int32_map())[0] = 42; + CelValue::MessageWrapper wrapped(&message, &adapter); + + const LegacyTypeAccessApis* api = adapter.GetAccessApis(MessageWrapper()); + ASSERT_NE(api, nullptr); + + std::vector qualfiers{ + cel::FieldSpecifier{205, "int32_int32_map"}, + cel::AttributeQualifier::OfInt(0)}; + + EXPECT_THAT( + api->Qualify(qualfiers, wrapped, + /*presence_test=*/false, manager), + IsOkAndHolds(Field(&LegacyQualifyResult::value, test::IsCelInt64(42)))); +} + +TEST(ProtoMesssageTypeAdapter, QualifyMapTraversalIntOutOfRange) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory()); + auto manager = ProtoMemoryManagerRef(&arena); + + TestMessage message; + (*message.mutable_int32_int32_map())[0] = 42; + CelValue::MessageWrapper wrapped(&message, &adapter); + + const LegacyTypeAccessApis* api = adapter.GetAccessApis(MessageWrapper()); + ASSERT_NE(api, nullptr); + + std::vector qualfiers{ + cel::FieldSpecifier{205, "int32_int32_map"}, + cel::AttributeQualifier::OfInt(1LL << 32)}; + + EXPECT_THAT(api->Qualify(qualfiers, wrapped, + /*presence_test=*/false, manager), + IsOkAndHolds(Field( + &LegacyQualifyResult::value, + test::IsCelError(StatusIs(absl::StatusCode::kOutOfRange, + HasSubstr("integer overflow")))))); +} + +TEST(ProtoMesssageTypeAdapter, QualifyMapTraversalUint32Key) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory()); + auto manager = ProtoMemoryManagerRef(&arena); + + TestMessage message; + (*message.mutable_uint32_uint32_map())[0] = 42; + CelValue::MessageWrapper wrapped(&message, &adapter); + + const LegacyTypeAccessApis* api = adapter.GetAccessApis(MessageWrapper()); + ASSERT_NE(api, nullptr); + + std::vector qualfiers{ + cel::FieldSpecifier{206, "uint32_uint32_map"}, + cel::AttributeQualifier::OfUint(0)}; + + EXPECT_THAT( + api->Qualify(qualfiers, wrapped, + /*presence_test=*/false, manager), + IsOkAndHolds(Field(&LegacyQualifyResult::value, test::IsCelUint64(42)))); +} + +TEST(ProtoMesssageTypeAdapter, QualifyMapTraversalUintOutOfRange) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory()); + auto manager = ProtoMemoryManagerRef(&arena); + + TestMessage message; + (*message.mutable_uint32_uint32_map())[0] = 42; + CelValue::MessageWrapper wrapped(&message, &adapter); + + const LegacyTypeAccessApis* api = adapter.GetAccessApis(MessageWrapper()); + ASSERT_NE(api, nullptr); + + std::vector qualfiers{ + cel::FieldSpecifier{206, "uint32_uint32_map"}, + cel::AttributeQualifier::OfUint(1LL << 32)}; + + EXPECT_THAT(api->Qualify(qualfiers, wrapped, + /*presence_test=*/false, manager), + IsOkAndHolds(Field( + &LegacyQualifyResult::value, + test::IsCelError(StatusIs(absl::StatusCode::kOutOfRange, + HasSubstr("integer overflow")))))); +} + +TEST(ProtoMesssageTypeAdapter, QualifyMapTraversalUnexpectedFieldAccess) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory()); + auto manager = ProtoMemoryManagerRef(&arena); + + TestMessage message; + (*message.mutable_string_message_map())["@key"].set_int64_value(42); + CelValue::MessageWrapper wrapped(&message, &adapter); + + const LegacyTypeAccessApis* api = adapter.GetAccessApis(MessageWrapper()); + ASSERT_NE(api, nullptr); + + std::vector qualfiers{ + cel::FieldSpecifier{210, "string_message_map"}, + // For coverage check that qualify gives up if there's a strong field + // access requested for a map. + cel::FieldSpecifier{0, "field_like_key"}}; + + auto result = api->Qualify(qualfiers, wrapped, + /*presence_test=*/false, manager); + + EXPECT_THAT(api->Qualify(qualfiers, wrapped, + /*presence_test=*/false, manager), + StatusIs(absl::StatusCode::kUnimplemented, _)); +} + +TEST(ProtoMesssageTypeAdapter, UntypedQualifiersNotYetSupported) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory()); + auto manager = ProtoMemoryManagerRef(&arena); + + TestMessage message; + (*message.mutable_string_message_map())["@key"].set_int64_value(42); + CelValue::MessageWrapper wrapped(&message, &adapter); + + const LegacyTypeAccessApis* api = adapter.GetAccessApis(MessageWrapper()); + ASSERT_NE(api, nullptr); + + std::vector qualfiers{ + cel::AttributeQualifier::OfString("string_message_map"), + cel::AttributeQualifier::OfString("@key"), + cel::AttributeQualifier::OfString("int64_value")}; + + EXPECT_THAT(api->Qualify(qualfiers, wrapped, + /*presence_test=*/false, manager), + StatusIs(absl::StatusCode::kUnimplemented, _)); +} + +TEST(ProtoMesssageTypeAdapter, QualifyRepeatedIndexWrongType) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory()); + auto manager = ProtoMemoryManagerRef(&arena); + + TestMessage message; + message.add_message_list()->add_int64_list(1); + message.add_message_list()->add_int64_list(2); + + CelValue::MessageWrapper wrapped(&message, &adapter); + + const LegacyTypeAccessApis* api = adapter.GetAccessApis(MessageWrapper()); + ASSERT_NE(api, nullptr); + + std::vector qualfiers{ + cel::FieldSpecifier{112, "message_list"}, + cel::AttributeQualifier::OfBool(false), + cel::FieldSpecifier{102, "int64_list"}, + cel::AttributeQualifier::OfInt(0)}; + + EXPECT_THAT( + api->Qualify(qualfiers, wrapped, + /*presence_test=*/false, manager), + IsOkAndHolds(Field(&LegacyQualifyResult::value, + test::IsCelError(StatusIs( + absl::StatusCode::kUnknown, + HasSubstr("No matching overloads found")))))); +} + +TEST(ProtoMesssageTypeAdapter, QualifyRepeatedTypeCheckError) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory()); + auto manager = ProtoMemoryManagerRef(&arena); + + TestMessage message; + message.add_int64_list(1); + message.add_int64_list(2); + + CelValue::MessageWrapper wrapped(&message, &adapter); + + const LegacyTypeAccessApis* api = adapter.GetAccessApis(MessageWrapper()); + ASSERT_NE(api, nullptr); + + std::vector qualfiers{ + cel::FieldSpecifier{102, "int64_list"}, cel::AttributeQualifier::OfInt(0), + // index on an int. + cel::AttributeQualifier::OfInt(1)}; + + EXPECT_THAT(api->Qualify(qualfiers, wrapped, + /*presence_test=*/false, manager), + + StatusIs(absl::StatusCode::kInternal, + HasSubstr("Unexpected qualify intermediate type"))); +} + +TEST(ProtoMesssageTypeAdapter, QualifyRepeatedLeaf) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory()); + auto manager = ProtoMemoryManagerRef(&arena); + + TestMessage message; + auto* nested = message.mutable_message_value(); + nested->add_int64_list(1); + nested->add_int64_list(2); + CelValue::MessageWrapper wrapped(&message, &adapter); + + const LegacyTypeAccessApis* api = adapter.GetAccessApis(MessageWrapper()); + ASSERT_NE(api, nullptr); + + std::vector qualfiers{ + cel::FieldSpecifier{12, "message_value"}, + cel::FieldSpecifier{102, "int64_list"}, + }; + + EXPECT_THAT( + api->Qualify(qualfiers, wrapped, + /*presence_test=*/false, manager), + IsOkAndHolds(Field(&LegacyQualifyResult::value, + test::IsCelList(ElementsAre(test::IsCelInt64(1), + test::IsCelInt64(2)))))); +} + +TEST(ProtoMesssageTypeAdapter, QualifyRepeatedIndexLeaf) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory()); + auto manager = ProtoMemoryManagerRef(&arena); + + TestMessage message; + auto* nested = message.mutable_message_value(); + nested->add_int64_list(1); + nested->add_int64_list(2); + CelValue::MessageWrapper wrapped(&message, &adapter); + + const LegacyTypeAccessApis* api = adapter.GetAccessApis(MessageWrapper()); + ASSERT_NE(api, nullptr); + + std::vector qualfiers{ + cel::FieldSpecifier{12, "message_value"}, + cel::FieldSpecifier{102, "int64_list"}, + cel::AttributeQualifier::OfInt(1)}; + + EXPECT_THAT( + api->Qualify(qualfiers, wrapped, + /*presence_test=*/false, manager), + IsOkAndHolds(Field(&LegacyQualifyResult::value, test::IsCelInt64(2)))); +} + +TEST(ProtoMesssageTypeAdapter, QualifyRepeatedIndexLeafOutOfBounds) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory()); + auto manager = ProtoMemoryManagerRef(&arena); + + TestMessage message; + auto* nested = message.mutable_message_value(); + nested->add_int64_list(1); + nested->add_int64_list(2); + CelValue::MessageWrapper wrapped(&message, &adapter); + + const LegacyTypeAccessApis* api = adapter.GetAccessApis(MessageWrapper()); + ASSERT_NE(api, nullptr); + + std::vector qualfiers{ + cel::FieldSpecifier{12, "message_value"}, + cel::FieldSpecifier{102, "int64_list"}, + cel::AttributeQualifier::OfInt(2)}; + + EXPECT_THAT(api->Qualify(qualfiers, wrapped, + /*presence_test=*/false, manager), + IsOkAndHolds(Field(&LegacyQualifyResult::value, + test::IsCelError(StatusIs( + absl::StatusCode::kInvalidArgument, + HasSubstr("index out of bounds")))))); +} + +TEST(ProtoMesssageTypeAdapter, QualifyMapLeaf) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory()); + auto manager = ProtoMemoryManagerRef(&arena); + + TestMessage message; + auto* nested_map = + message.mutable_message_value()->mutable_string_int32_map(); + (*nested_map)["@key"] = 42; + (*nested_map)["@key2"] = -42; + CelValue::MessageWrapper wrapped(&message, &adapter); + + const LegacyTypeAccessApis* api = adapter.GetAccessApis(MessageWrapper()); + ASSERT_NE(api, nullptr); + + std::vector qualfiers{ + cel::FieldSpecifier{12, "message_value"}, + cel::FieldSpecifier{203, "string_int32_map"}, + }; + + EXPECT_THAT(api->Qualify(qualfiers, wrapped, + /*presence_test=*/false, manager), + IsOkAndHolds(Field( + &LegacyQualifyResult::value, Truly([](const CelValue& v) { + return v.IsMap() && v.MapOrDie()->size() == 2; + })))); +} + +TEST(ProtoMesssageTypeAdapter, QualifyMapIndexLeaf) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory()); + auto manager = ProtoMemoryManagerRef(&arena); + + TestMessage message; + auto* nested_map = + message.mutable_message_value()->mutable_string_int32_map(); + (*nested_map)["@key"] = 42; + (*nested_map)["@key2"] = -42; + CelValue::MessageWrapper wrapped(&message, &adapter); + + const LegacyTypeAccessApis* api = adapter.GetAccessApis(MessageWrapper()); + ASSERT_NE(api, nullptr); + + std::vector qualfiers{ + cel::FieldSpecifier{12, "message_value"}, + cel::FieldSpecifier{203, "string_int32_map"}, + cel::AttributeQualifier::OfString("@key")}; + + EXPECT_THAT( + api->Qualify(qualfiers, wrapped, + /*presence_test=*/false, manager), + IsOkAndHolds(Field(&LegacyQualifyResult::value, test::IsCelInt64(42)))); +} + +TEST(ProtoMesssageTypeAdapter, QualifyMapIndexLeafWrongType) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory()); + auto manager = ProtoMemoryManagerRef(&arena); + + TestMessage message; + auto* nested_map = + message.mutable_message_value()->mutable_string_int32_map(); + (*nested_map)["@key"] = 42; + (*nested_map)["@key2"] = -42; + CelValue::MessageWrapper wrapped(&message, &adapter); + + const LegacyTypeAccessApis* api = adapter.GetAccessApis(MessageWrapper()); + ASSERT_NE(api, nullptr); + + std::vector qualfiers{ + cel::FieldSpecifier{12, "message_value"}, + cel::FieldSpecifier{203, "string_int32_map"}, + cel::AttributeQualifier::OfInt(0)}; + + EXPECT_THAT(api->Qualify(qualfiers, wrapped, + /*presence_test=*/false, manager), + IsOkAndHolds(Field(&LegacyQualifyResult::value, + test::IsCelError(StatusIs( + absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid map key type")))))); +} + +} // namespace +} // namespace google::api::expr::runtime diff --git a/eval/public/structs/protobuf_descriptor_type_provider.cc b/eval/public/structs/protobuf_descriptor_type_provider.cc new file mode 100644 index 000000000..b5746523e --- /dev/null +++ b/eval/public/structs/protobuf_descriptor_type_provider.cc @@ -0,0 +1,70 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/structs/protobuf_descriptor_type_provider.h" + +#include +#include + +#include "absl/synchronization/mutex.h" +#include "eval/public/structs/proto_message_type_adapter.h" +#include "google/protobuf/descriptor.h" + +namespace google::api::expr::runtime { + +absl::optional ProtobufDescriptorProvider::ProvideLegacyType( + absl::string_view name) const { + const ProtoMessageTypeAdapter* result = GetTypeAdapter(name); + if (result == nullptr) { + return std::nullopt; + } + // ProtoMessageTypeAdapter provides apis for both access and mutation. + return LegacyTypeAdapter(result, result); +} + +absl::optional +ProtobufDescriptorProvider::ProvideLegacyTypeInfo( + absl::string_view name) const { + const ProtoMessageTypeAdapter* result = GetTypeAdapter(name); + if (result == nullptr) { + return std::nullopt; + } + return result; +} + +std::unique_ptr +ProtobufDescriptorProvider::CreateTypeAdapter(absl::string_view name) const { + const google::protobuf::Descriptor* descriptor = + descriptor_pool_->FindMessageTypeByName(name); + if (descriptor == nullptr) { + return nullptr; + } + + return std::make_unique(descriptor, + message_factory_); +} + +const ProtoMessageTypeAdapter* ProtobufDescriptorProvider::GetTypeAdapter( + absl::string_view name) const { + absl::MutexLock lock(mu_); + auto it = type_cache_.find(name); + if (it != type_cache_.end()) { + return it->second.get(); + } + auto type_provider = CreateTypeAdapter(name); + const ProtoMessageTypeAdapter* result = type_provider.get(); + type_cache_[name] = std::move(type_provider); + return result; +} +} // namespace google::api::expr::runtime diff --git a/eval/public/structs/protobuf_descriptor_type_provider.h b/eval/public/structs/protobuf_descriptor_type_provider.h new file mode 100644 index 000000000..232e848b4 --- /dev/null +++ b/eval/public/structs/protobuf_descriptor_type_provider.h @@ -0,0 +1,67 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_PROTOBUF_DESCRIPTOR_TYPE_PROVIDER_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_PROTOBUF_DESCRIPTOR_TYPE_PROVIDER_H_ + +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/optional.h" +#include "eval/public/structs/legacy_type_adapter.h" +#include "eval/public/structs/legacy_type_info_apis.h" +#include "eval/public/structs/legacy_type_provider.h" +#include "eval/public/structs/proto_message_type_adapter.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace google::api::expr::runtime { + +// Implementation of a type provider that generates types from protocol buffer +// descriptors. +class ProtobufDescriptorProvider : public LegacyTypeProvider { + public: + ProtobufDescriptorProvider(const google::protobuf::DescriptorPool* pool, + google::protobuf::MessageFactory* factory) + : descriptor_pool_(pool), message_factory_(factory) {} + + absl::optional ProvideLegacyType( + absl::string_view name) const final; + + absl::optional ProvideLegacyTypeInfo( + absl::string_view name) const final; + + private: + // Create a new type instance if found in the registered descriptor pool. + // Otherwise, returns nullptr. + std::unique_ptr CreateTypeAdapter( + absl::string_view name) const; + + const ProtoMessageTypeAdapter* GetTypeAdapter(absl::string_view name) const; + + const google::protobuf::DescriptorPool* descriptor_pool_; + google::protobuf::MessageFactory* message_factory_; + mutable absl::flat_hash_map> + type_cache_ ABSL_GUARDED_BY(mu_); + mutable absl::Mutex mu_; +}; + +} // namespace google::api::expr::runtime + +#endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_PROTOBUF_DESCRIPTOR_TYPE_PROVIDER_H_ diff --git a/eval/public/structs/protobuf_descriptor_type_provider_test.cc b/eval/public/structs/protobuf_descriptor_type_provider_test.cc new file mode 100644 index 000000000..3a8fae26b --- /dev/null +++ b/eval/public/structs/protobuf_descriptor_type_provider_test.cc @@ -0,0 +1,95 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/structs/protobuf_descriptor_type_provider.h" + +#include + +#include "google/protobuf/wrappers.pb.h" +#include "eval/public/cel_value.h" +#include "eval/public/structs/legacy_type_info_apis.h" +#include "eval/public/testing/matchers.h" +#include "extensions/protobuf/memory_manager.h" +#include "internal/testing.h" +#include "google/protobuf/arena.h" + +namespace google::api::expr::runtime { +namespace { + +using ::cel::extensions::ProtoMemoryManager; + +TEST(ProtobufDescriptorProvider, Basic) { + ProtobufDescriptorProvider provider( + google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory()); + google::protobuf::Arena arena; + auto manager = ProtoMemoryManager(&arena); + auto type_adapter = provider.ProvideLegacyType("google.protobuf.Int64Value"); + absl::optional type_info = + provider.ProvideLegacyTypeInfo("google.protobuf.Int64Value"); + + ASSERT_TRUE(type_adapter.has_value()); + ASSERT_TRUE(type_adapter->mutation_apis() != nullptr); + ASSERT_TRUE(type_info.has_value()); + ASSERT_TRUE(type_info != nullptr); + + google::protobuf::Int64Value int64_value; + CelValue::MessageWrapper int64_cel_value(&int64_value, *type_info); + EXPECT_EQ((*type_info)->GetTypename(int64_cel_value), + "google.protobuf.Int64Value"); + + ASSERT_TRUE(type_adapter->mutation_apis()->DefinesField("value")); + ASSERT_OK_AND_ASSIGN(CelValue::MessageWrapper::Builder value, + type_adapter->mutation_apis()->NewInstance(manager)); + + ASSERT_OK(type_adapter->mutation_apis()->SetField( + "value", CelValue::CreateInt64(10), manager, value)); + + ASSERT_OK_AND_ASSIGN( + CelValue adapted, + type_adapter->mutation_apis()->AdaptFromWellKnownType(manager, value)); + + EXPECT_THAT(adapted, test::IsCelInt64(10)); +} + +// This is an implementation detail, but testing for coverage. +TEST(ProtobufDescriptorProvider, MemoizesAdapters) { + ProtobufDescriptorProvider provider( + google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory()); + auto type_adapter = provider.ProvideLegacyType("google.protobuf.Int64Value"); + + ASSERT_TRUE(type_adapter.has_value()); + ASSERT_TRUE(type_adapter->mutation_apis() != nullptr); + + auto type_adapter2 = provider.ProvideLegacyType("google.protobuf.Int64Value"); + ASSERT_TRUE(type_adapter2.has_value()); + + EXPECT_EQ(type_adapter->mutation_apis(), type_adapter2->mutation_apis()); + EXPECT_EQ(type_adapter->access_apis(), type_adapter2->access_apis()); +} + +TEST(ProtobufDescriptorProvider, NotFound) { + ProtobufDescriptorProvider provider( + google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory()); + auto type_adapter = provider.ProvideLegacyType("UnknownType"); + auto type_info = provider.ProvideLegacyTypeInfo("UnknownType"); + + ASSERT_FALSE(type_adapter.has_value()); + ASSERT_FALSE(type_info.has_value()); +} + +} // namespace +} // namespace google::api::expr::runtime diff --git a/eval/public/structs/protobuf_value_factory.h b/eval/public/structs/protobuf_value_factory.h new file mode 100644 index 000000000..8f4e3add9 --- /dev/null +++ b/eval/public/structs/protobuf_value_factory.h @@ -0,0 +1,36 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_PROTOBUF_VALUE_FACTORY_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_PROTOBUF_VALUE_FACTORY_H_ + +#include + +#include "eval/public/cel_value.h" +#include "google/protobuf/message.h" + +namespace google::api::expr::runtime::internal { + +// Definiton for factory producing a properly initialized message-typed +// CelValue. +// +// google::protobuf::Message is assumed adapted as possible, so this function just +// associates it with appropriate type information. +// +// Used to break cyclic dependency between field access and message wrapping -- +// not intended for general use. +using ProtobufValueFactory = CelValue (*)(const google::protobuf::Message*); +} // namespace google::api::expr::runtime::internal + +#endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_PROTOBUF_VALUE_FACTORY_H_ diff --git a/eval/public/structs/trivial_legacy_type_info.h b/eval/public/structs/trivial_legacy_type_info.h new file mode 100644 index 000000000..2189bd478 --- /dev/null +++ b/eval/public/structs/trivial_legacy_type_info.h @@ -0,0 +1,54 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_TRIVIAL_LEGACY_TYPE_INFO_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_TRIVIAL_LEGACY_TYPE_INFO_H_ + +#include + +#include "absl/base/no_destructor.h" +#include "absl/strings/string_view.h" +#include "eval/public/message_wrapper.h" +#include "eval/public/structs/legacy_type_info_apis.h" + +namespace google::api::expr::runtime { + +// Implementation of type info APIs suitable for testing where no message +// operations need to be supported. +class TrivialTypeInfo : public LegacyTypeInfoApis { + public: + absl::string_view GetTypename(const MessageWrapper& wrapper) const override { + return "opaque"; + } + + std::string DebugString(const MessageWrapper& wrapper) const override { + return "opaque"; + } + + const LegacyTypeAccessApis* GetAccessApis( + const MessageWrapper& wrapper) const override { + // Accessors unsupported -- caller should treat this as an opaque type (no + // fields defined, field access always results in a CEL error). + return nullptr; + } + + static const TrivialTypeInfo* GetInstance() { + static absl::NoDestructor kInstance; + return &*kInstance; + } +}; + +} // namespace google::api::expr::runtime + +#endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_TRIVIAL_LEGACY_TYPE_INFO_H_ diff --git a/eval/public/structs/trivial_legacy_type_info_test.cc b/eval/public/structs/trivial_legacy_type_info_test.cc new file mode 100644 index 000000000..9cc6e4916 --- /dev/null +++ b/eval/public/structs/trivial_legacy_type_info_test.cc @@ -0,0 +1,65 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/structs/trivial_legacy_type_info.h" + +#include "eval/public/message_wrapper.h" +#include "internal/testing.h" + +namespace google::api::expr::runtime { +namespace { + +TEST(TrivialTypeInfo, GetTypename) { + TrivialTypeInfo info; + MessageWrapper wrapper; + + EXPECT_EQ(info.GetTypename(wrapper), "opaque"); + EXPECT_EQ(TrivialTypeInfo::GetInstance()->GetTypename(wrapper), "opaque"); +} + +TEST(TrivialTypeInfo, DebugString) { + TrivialTypeInfo info; + MessageWrapper wrapper; + + EXPECT_EQ(info.DebugString(wrapper), "opaque"); + EXPECT_EQ(TrivialTypeInfo::GetInstance()->DebugString(wrapper), "opaque"); +} + +TEST(TrivialTypeInfo, GetAccessApis) { + TrivialTypeInfo info; + MessageWrapper wrapper; + + EXPECT_EQ(info.GetAccessApis(wrapper), nullptr); + EXPECT_EQ(TrivialTypeInfo::GetInstance()->GetAccessApis(wrapper), nullptr); +} + +TEST(TrivialTypeInfo, GetMutationApis) { + TrivialTypeInfo info; + MessageWrapper wrapper; + + EXPECT_EQ(info.GetMutationApis(wrapper), nullptr); + EXPECT_EQ(TrivialTypeInfo::GetInstance()->GetMutationApis(wrapper), nullptr); +} + +TEST(TrivialTypeInfo, FindFieldByName) { + TrivialTypeInfo info; + MessageWrapper wrapper; + + EXPECT_EQ(info.FindFieldByName("foo"), std::nullopt); + EXPECT_EQ(TrivialTypeInfo::GetInstance()->FindFieldByName("foo"), + std::nullopt); +} + +} // namespace +} // namespace google::api::expr::runtime diff --git a/eval/public/testing/BUILD b/eval/public/testing/BUILD index 61f75a421..f4529e931 100644 --- a/eval/public/testing/BUILD +++ b/eval/public/testing/BUILD @@ -1,58 +1,25 @@ +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + package( default_testonly = True, default_visibility = ["//visibility:public"], ) -licenses(["notice"]) # Apache 2.0 - -cc_library( - name = "debug_string", - srcs = ["debug_string.cc"], - hdrs = ["debug_string.h"], - deps = [ - "//eval/public:cel_attribute", - "//eval/public:cel_value", - "//eval/public:unknown_function_result_set", - "//eval/public:unknown_set", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/time", - "@com_google_protobuf//:protobuf", - ], -) - -cc_test( - name = "debug_string_test", - srcs = ["debug_string_test.cc"], - deps = [ - ":debug_string", - "//eval/public:cel_attribute", - "//eval/public:cel_function", - "//eval/public:cel_value", - "//eval/public:unknown_attribute_set", - "//eval/public:unknown_function_result_set", - "//eval/public:unknown_set", - "//eval/public/structs:cel_proto_wrapper", - "//eval/testutil:test_message_cc_proto", - "@com_google_absl//absl/status", - "@com_google_absl//absl/time", - "@com_google_googletest//:gtest_main", - "@com_google_protobuf//:protobuf", - ], -) +licenses(["notice"]) cc_library( name = "matchers", srcs = ["matchers.cc"], hdrs = ["matchers.h"], deps = [ - ":debug_string", "//eval/public:cel_value", "//eval/public:set_util", - "//eval/public:unknown_set", + "//internal:casts", + "//internal:testing", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", - "@com_google_googletest//:gtest", "@com_google_protobuf//:protobuf", ], ) @@ -62,12 +29,12 @@ cc_test( srcs = ["matchers_test.cc"], deps = [ ":matchers", + "//eval/public/containers:container_backed_list_impl", "//eval/public/structs:cel_proto_wrapper", "//eval/testutil:test_message_cc_proto", + "//internal:testing", "//testutil:util", "@com_google_absl//absl/status", "@com_google_absl//absl/time", - "@com_google_googletest//:gtest", - "@com_google_googletest//:gtest_main", ], ) diff --git a/eval/public/testing/debug_string.cc b/eval/public/testing/debug_string.cc deleted file mode 100644 index bf550eccb..000000000 --- a/eval/public/testing/debug_string.cc +++ /dev/null @@ -1,151 +0,0 @@ -#include "eval/public/testing/debug_string.h" - -#include "google/protobuf/message.h" -#include "absl/strings/escaping.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" -#include "absl/strings/str_join.h" -#include "absl/strings/substitute.h" -#include "absl/time/time.h" -#include "eval/public/cel_attribute.h" -#include "eval/public/unknown_function_result_set.h" -#include "eval/public/unknown_set.h" - -namespace google { -namespace api { -namespace expr { -namespace runtime { -namespace test { - -namespace { - -// Forward declare -- depends on value string visitor. -std::string AttributeString(const CelAttribute* attr); - -std::string FunctionCallString(const UnknownFunctionResult* fn); - -struct ValueStringVisitor { - std::string operator()(int64_t arg) { return absl::StrFormat("%d", arg); } - - std::string operator()(uint64_t arg) { return absl::StrFormat("%d", arg); } - - std::string operator()(bool arg) { return (arg) ? "true" : "false"; } - - std::string operator()(double arg) { return absl::StrFormat("%f", arg); } - - std::string operator()(const google::protobuf::Message* arg) { - if (arg == nullptr) { - return "NULL"; - } - return arg->DebugString(); - } - std::string operator()(CelValue::StringHolder arg) { - return absl::StrFormat("'%s'", arg.value()); - } - - std::string operator()(CelValue::BytesHolder arg) { - return absl::StrFormat("0x%s", absl::BytesToHexString(arg.value())); - } - - std::string operator()(absl::Time arg) { return absl::FormatTime(arg); } - - std::string operator()(absl::Duration arg) { - return absl::FormatDuration(arg); - } - - std::string operator()(const CelList* arg) { - std::vector elements; - elements.reserve(arg->size()); - for (int i = 0; i < arg->size(); i++) { - elements.push_back(DebugString(arg->operator[](i))); - } - return absl::StrCat("[", absl::StrJoin(elements, ", "), "]"); - } - - std::string operator()(const CelMap* arg) { - const CelList* keys = arg->ListKeys(); - std::vector elements; - elements.reserve(keys->size()); - for (int i = 0; i < keys->size(); i++) { - elements.push_back( - absl::Substitute("$0:$1", DebugString((*keys)[i]), - DebugString(arg->operator[]((*keys)[i]).value()))); - } - return absl::Substitute("{$0}", absl::StrJoin(elements, ", ")); - } - - std::string operator()(CelValue::CelTypeHolder arg) { - return absl::StrFormat("'%s'", arg.value()); - } - - std::string operator()(const CelError* arg) { return arg->ToString(); } - - std::string operator()(const UnknownSet* arg) { - std::vector attrs; - attrs.reserve(arg->unknown_attributes().attributes().size()); - for (const auto* attr : arg->unknown_attributes().attributes()) { - attrs.push_back(AttributeString(attr)); - } - std::vector fns; - fns.reserve( - arg->unknown_function_results().unknown_function_results().size()); - for (const auto* fn : - arg->unknown_function_results().unknown_function_results()) { - fns.push_back(FunctionCallString(fn)); - } - return absl::Substitute("{attributes:[$0], functions:[$1]}", - absl::StrJoin(attrs, ", "), - absl::StrJoin(fns, ", ")); - } -}; - -std::string AttributeString(const CelAttribute* attr) { - // qualification = - std::string output(attr->variable().ident_expr().name()); - for (const auto& q : attr->qualifier_path()) { - absl::StrAppend(&output, ".", q.Visit(ValueStringVisitor())); - } - return output; -} - -std::string FunctionCallString(const UnknownFunctionResult* fn) { - std::vector args; - std::string call; - args.reserve(fn->arguments().size()); - if (fn->descriptor().receiver_style()) { - if (fn->arguments().empty()) { - absl::StrAppend(&call, "."); - } else { - absl::StrAppend(&call, DebugString(fn->arguments()[0]), "."); - } - absl::StrAppend(&call, fn->descriptor().name()); - for (int i = 1; i < fn->arguments().size(); i++) { - args.push_back(DebugString(fn->arguments()[i])); - } - } else { - absl::StrAppend(&call, fn->descriptor().name()); - for (int i = 0; i < fn->arguments().size(); i++) { - args.push_back(DebugString(fn->arguments()[i])); - } - } - return absl::Substitute("$0($1)", call, absl::StrJoin(args, ", ")); -} - -} // namespace - -// String rerpesentation of the underlying value. -std::string DebugValueString(const CelValue& value) { - return value.Visit(ValueStringVisitor()); -} - -// String representation of the cel value. -std::string DebugString(const CelValue& value) { - return absl::Substitute("<$0,$1>", CelValue::TypeName(value.type()), - DebugValueString(value)); -} - -} // namespace test -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google diff --git a/eval/public/testing/debug_string.h b/eval/public/testing/debug_string.h deleted file mode 100644 index 2060fc777..000000000 --- a/eval/public/testing/debug_string.h +++ /dev/null @@ -1,26 +0,0 @@ -#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_TESTING_DEBUG_STRING_H_ -#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_TESTING_DEBUG_STRING_H_ - -#include "eval/public/cel_value.h" - -namespace google { -namespace api { -namespace expr { -namespace runtime { -namespace test { - -// String rerpesentation of the underlying value. -std::string DebugValueString(const CelValue& value); - -// String representation of the cel value. This should only be used for -// informational purposes and the exact format may change. In particular, -// ordering is not guaranteed for some container types. -std::string DebugString(const CelValue& value); - -} // namespace test -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google - -#endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_TESTING_DEBUG_STRING_H_ diff --git a/eval/public/testing/debug_string_test.cc b/eval/public/testing/debug_string_test.cc deleted file mode 100644 index 5ff9dd497..000000000 --- a/eval/public/testing/debug_string_test.cc +++ /dev/null @@ -1,162 +0,0 @@ -#include "eval/public/testing/debug_string.h" - -#include "google/protobuf/struct.pb.h" -#include "google/protobuf/arena.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "absl/status/status.h" -#include "absl/time/time.h" -#include "eval/public/cel_attribute.h" -#include "eval/public/cel_function.h" -#include "eval/public/cel_value.h" -#include "eval/public/structs/cel_proto_wrapper.h" -#include "eval/public/unknown_attribute_set.h" -#include "eval/public/unknown_function_result_set.h" -#include "eval/public/unknown_set.h" -#include "eval/testutil/test_message.pb.h" - -namespace google { -namespace api { -namespace expr { -namespace runtime { -namespace test { -namespace { -using testing::AnyOf; -using testing::HasSubstr; - -TEST(DebugString, Primitives) { - constexpr char data[] = { - '\x01', '\x01', '\x00', '\xc0', '\xff', - }; - std::string bytestring(data, 5); - CelError error = absl::InternalError("error"); - - EXPECT_EQ(DebugString(CelValue::CreateStringView("hello world")), - ""); - EXPECT_EQ(DebugString(CelValue::CreateBytes(&bytestring)), - ""); - EXPECT_EQ(DebugString(CelValue::CreateBool(false)), ""); - EXPECT_EQ(DebugString(CelValue::CreateDouble(1.5)), ""); - EXPECT_EQ(DebugString(CelValue::CreateDuration(absl::Seconds(2))), - ""); - EXPECT_THAT(DebugString(CelValue::CreateTimestamp(absl::FromUnixSeconds(1))), - AnyOf(HasSubstr(""); - EXPECT_EQ(DebugString(CelValue::CreateInt64(-1)), - ""); // no transform - - EXPECT_EQ(DebugString(CelValue::CreateUint64(1)), - ""); // no transform -} - -TEST(DebugString, Messages) { - google::protobuf::Arena arena; - TestMessage message; - message.add_int64_list(1); - message.add_int64_list(2); - - EXPECT_EQ(DebugString(CelValue::CreateNull()), ""); - EXPECT_EQ(DebugString(CelProtoWrapper::CreateMessage(&message, &arena)), - ""); -} - -TEST(DebugString, Lists) { - google::protobuf::Arena arena; - - protobuf::ListValue list_msg; - list_msg.add_values()->set_bool_value(true); - list_msg.add_values()->set_bool_value(false); - - // converted to a list - EXPECT_EQ(DebugString(CelProtoWrapper::CreateMessage(&list_msg, &arena)), - ", ]>"); -} - -TEST(DebugString, Maps) { - google::protobuf::Arena arena; - - // Converted to a map on CelValue::Create. - protobuf::Struct struct_msg; - (*struct_msg.mutable_fields())["field1"].set_bool_value(true); - (*struct_msg.mutable_fields())["field2"].set_bool_value(false); - - // Ordering isn't guaranteed for the backing map for a converted struct. - EXPECT_THAT(DebugString(CelProtoWrapper::CreateMessage(&struct_msg, &arena)), - AnyOf(":, " - ":" - "}>", - ":, " - ":" - "}>")); -} - -TEST(DebugString, UnknownSet) { - google::protobuf::Arena arena; - google::api::expr::v1alpha1::Expr ident; - ident.mutable_ident_expr()->set_name("var"); - CelAttribute attr( - ident, - {CelAttributeQualifier::Create(CelValue::CreateInt64(1)), - CelAttributeQualifier::Create(CelValue::CreateStringView("field"))}); - UnknownFunctionResult function_result( - CelFunctionDescriptor("IntFn", false, {CelValue::Type::kInt64}), 1, - {CelValue::CreateInt64(1)}); - UnknownSet unknown_set(UnknownAttributeSet({&attr}), - UnknownFunctionResultSet(&function_result)); - EXPECT_EQ(DebugString(CelValue::CreateUnknownSet(&unknown_set)), - ")]}>"); // no transform -} - -TEST(DebugString, Type) { - constexpr char data[] = { - '\x01', '\x01', '\x00', '\xc0', '\xff', - }; - std::string bytestring(data, 5); - EXPECT_EQ( - DebugString(CelValue::CreateStringView("hello world").ObtainCelType()), - ""); - EXPECT_EQ(DebugString(CelValue::CreateBytes(&bytestring).ObtainCelType()), - ""); - EXPECT_EQ(DebugString(CelValue::CreateBool(false).ObtainCelType()), - ""); - EXPECT_EQ(DebugString(CelValue::CreateDouble(1.5).ObtainCelType()), - ""); - EXPECT_EQ( - DebugString(CelValue::CreateDuration(absl::Seconds(2)).ObtainCelType()), - ""); - EXPECT_EQ( - DebugString( - CelValue::CreateTimestamp(absl::FromUnixSeconds(1)).ObtainCelType()), - ""); - EXPECT_EQ(DebugString(CelValue::CreateInt64(-1).ObtainCelType()), - ""); - - EXPECT_EQ(DebugString(CelValue::CreateUint64(1).ObtainCelType()), - ""); - - google::protobuf::Arena arena; - - protobuf::ListValue list_msg; - EXPECT_EQ( - DebugString( - CelProtoWrapper::CreateMessage(&list_msg, &arena).ObtainCelType()), - ""); - // Converted to a map on CelValue::Create. - protobuf::Struct struct_msg; - EXPECT_EQ( - DebugString( - CelProtoWrapper::CreateMessage(&struct_msg, &arena).ObtainCelType()), - ""); -} - -} // namespace -} // namespace test -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google diff --git a/eval/public/testing/matchers.cc b/eval/public/testing/matchers.cc index 2315a0127..4f728c730 100644 --- a/eval/public/testing/matchers.cc +++ b/eval/public/testing/matchers.cc @@ -1,25 +1,27 @@ #include "eval/public/testing/matchers.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" +#include +#include +#include + #include "absl/strings/string_view.h" +#include "eval/public/cel_value.h" #include "eval/public/set_util.h" -#include "eval/public/testing/debug_string.h" +#include "internal/testing.h" +#include "google/protobuf/message.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace google::api::expr::runtime { void PrintTo(const CelValue& value, std::ostream* os) { - *os << test::DebugString(value); + *os << value.DebugString(); } namespace test { namespace { -using testing::MatcherInterface; -using testing::MatchResultListener; +using ::testing::_; +using ::testing::MatcherInterface; +using ::testing::MatchResultListener; class CelValueEqualImpl : public MatcherInterface { public: @@ -31,7 +33,7 @@ class CelValueEqualImpl : public MatcherInterface { } void DescribeTo(std::ostream* os) const override { - *os << DebugString(value_); + *os << value_.DebugString(); } private: @@ -42,8 +44,8 @@ class CelValueEqualImpl : public MatcherInterface { template class CelValueMatcherImpl : public testing::MatcherInterface { public: - CelValueMatcherImpl(testing::Matcher m) - : underlying_type_matcher_(m) {} + explicit CelValueMatcherImpl(testing::Matcher m) + : underlying_type_matcher_(std::move(m)) {} bool MatchAndExplain(const CelValue& v, testing::MatchResultListener* listener) const override { UnderlyingType arg; @@ -61,26 +63,55 @@ class CelValueMatcherImpl : public testing::MatcherInterface { const testing::Matcher underlying_type_matcher_; }; +// Template specialization for google::protobuf::Message. +template <> +class CelValueMatcherImpl + : public testing::MatcherInterface { + public: + explicit CelValueMatcherImpl(testing::Matcher m) + : underlying_type_matcher_(std::move(m)) {} + bool MatchAndExplain(const CelValue& v, + testing::MatchResultListener* listener) const override { + CelValue::MessageWrapper arg; + return v.GetValue(&arg) && arg.HasFullProto() && + underlying_type_matcher_.Matches( + google::protobuf::DownCastMessage(arg.message_ptr())); + } + + void DescribeTo(std::ostream* os) const override { + *os << absl::StrCat("type is ", + CelValue::TypeName(CelValue::Type::kMessage), " and "); + underlying_type_matcher_.DescribeTo(os); + } + + private: + const testing::Matcher underlying_type_matcher_; +}; + } // namespace CelValueMatcher EqualsCelValue(const CelValue& v) { return CelValueMatcher(new CelValueEqualImpl(v)); } +CelValueMatcher IsCelNull() { + return CelValueMatcher(new CelValueMatcherImpl(_)); +} + CelValueMatcher IsCelBool(testing::Matcher m) { - return CelValueMatcher(new CelValueMatcherImpl(m)); + return CelValueMatcher(new CelValueMatcherImpl(std::move(m))); } CelValueMatcher IsCelInt64(testing::Matcher m) { - return CelValueMatcher(new CelValueMatcherImpl(m)); + return CelValueMatcher(new CelValueMatcherImpl(std::move(m))); } CelValueMatcher IsCelUint64(testing::Matcher m) { - return CelValueMatcher(new CelValueMatcherImpl(m)); + return CelValueMatcher(new CelValueMatcherImpl(std::move(m))); } CelValueMatcher IsCelDouble(testing::Matcher m) { - return CelValueMatcher(new CelValueMatcherImpl(m)); + return CelValueMatcher(new CelValueMatcherImpl(std::move(m))); } CelValueMatcher IsCelString(testing::Matcher m) { @@ -94,15 +125,16 @@ CelValueMatcher IsCelBytes(testing::Matcher m) { } CelValueMatcher IsCelMessage(testing::Matcher m) { - return CelValueMatcher(new CelValueMatcherImpl(m)); + return CelValueMatcher( + new CelValueMatcherImpl(std::move(m))); } CelValueMatcher IsCelDuration(testing::Matcher m) { - return CelValueMatcher(new CelValueMatcherImpl(m)); + return CelValueMatcher(new CelValueMatcherImpl(std::move(m))); } CelValueMatcher IsCelTimestamp(testing::Matcher m) { - return CelValueMatcher(new CelValueMatcherImpl(m)); + return CelValueMatcher(new CelValueMatcherImpl(std::move(m))); } CelValueMatcher IsCelError(testing::Matcher m) { @@ -112,7 +144,4 @@ CelValueMatcher IsCelError(testing::Matcher m) { } } // namespace test -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime diff --git a/eval/public/testing/matchers.h b/eval/public/testing/matchers.h index d2ef85677..5bd73dd1d 100644 --- a/eval/public/testing/matchers.h +++ b/eval/public/testing/matchers.h @@ -1,17 +1,17 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_TESTING_MATCHERS_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_TESTING_MATCHERS_H_ +#include #include +#include -#include "google/protobuf/message.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/time/time.h" #include "eval/public/cel_value.h" -#include "eval/public/set_util.h" -#include "eval/public/testing/debug_string.h" -#include "eval/public/unknown_set.h" +#include "internal/testing.h" +#include "google/protobuf/message.h" namespace google { namespace api { @@ -29,10 +29,13 @@ using CelValueMatcher = testing::Matcher; // Tests equality to CelValue v using the set_util implementation. CelValueMatcher EqualsCelValue(const CelValue& v); +// Matches CelValues of type null. +CelValueMatcher IsCelNull(); + // Matches CelValues of type bool whose held value matches |m|. CelValueMatcher IsCelBool(testing::Matcher m); -// Matches CelValues of type int64_t whose held value matches |m|. +// Matches CelValues of type int64 whose held value matches |m|. CelValueMatcher IsCelInt64(testing::Matcher m); // Matches CelValues of type uint64_t whose held value matches |m|. @@ -60,8 +63,50 @@ CelValueMatcher IsCelTimestamp(testing::Matcher m); // The matcher |m| is wrapped to allow using the testing::status::... matchers. CelValueMatcher IsCelError(testing::Matcher m); -// TODO(issues/73): add helpers for working with maps, unknown sets, and -// lists. +// A matcher that wraps a Container matcher so that container matchers can be +// used for matching CelList. +// +// This matcher can be avoided if CelList supported the iterators needed by the +// standard container matchers but given that it is an interface it is a much +// larger project. +// +// TODO(issues/73): Re-use CelValueMatcherImpl. There are template details +// that need to be worked out specifically on how CelValueMatcherImpl can accept +// a generic matcher for CelList instead of testing::Matcher. +template +class CelListMatcher : public testing::MatcherInterface { + public: + explicit CelListMatcher(ContainerMatcher m) : container_matcher_(m) {} + + bool MatchAndExplain(const CelValue& v, + testing::MatchResultListener* listener) const override { + const CelList* cel_list; + if (!v.GetValue(&cel_list) || cel_list == nullptr) return false; + + std::vector cel_vector; + cel_vector.reserve(cel_list->size()); + for (int i = 0; i < cel_list->size(); ++i) { + cel_vector.push_back((*cel_list)[i]); + } + return container_matcher_.Matches(cel_vector); + } + + void DescribeTo(std::ostream* os) const override { + CelValue::Type type = + static_cast(CelValue::IndexOf::value); + *os << absl::StrCat("type is ", CelValue::TypeName(type), " and "); + container_matcher_.DescribeTo(os); + } + + private: + const testing::Matcher> container_matcher_; +}; + +template +CelValueMatcher IsCelList(ContainerMatcher m) { + return CelValueMatcher(new CelListMatcher(m)); +} +// TODO(issues/73): add helpers for working with maps and unknown sets. } // namespace test } // namespace runtime diff --git a/eval/public/testing/matchers_test.cc b/eval/public/testing/matchers_test.cc index ac21531c2..774f91578 100644 --- a/eval/public/testing/matchers_test.cc +++ b/eval/public/testing/matchers_test.cc @@ -1,25 +1,24 @@ #include "eval/public/testing/matchers.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" #include "absl/status/status.h" #include "absl/time/time.h" +#include "eval/public/containers/container_backed_list_impl.h" #include "eval/public/structs/cel_proto_wrapper.h" #include "eval/testutil/test_message.pb.h" +#include "internal/testing.h" #include "testutil/util.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { -namespace test { +namespace google::api::expr::runtime::test { namespace { -using testing::DoubleEq; -using testing::DoubleNear; -using testing::Gt; -using testing::Lt; -using testing::Not; +using ::testing::Contains; +using ::testing::DoubleEq; +using ::testing::DoubleNear; +using ::testing::ElementsAre; +using ::testing::Gt; +using ::testing::Lt; +using ::testing::Not; +using ::testing::UnorderedElementsAre; using testutil::EqualsProto; TEST(IsCelValue, EqualitySmoketest) { @@ -61,6 +60,9 @@ TEST(IsCelValue, EqualitySmoketest) { } TEST(PrimitiveMatchers, Smoketest) { + EXPECT_THAT(CelValue::CreateNull(), IsCelNull()); + EXPECT_THAT(CelValue::CreateBool(false), Not(IsCelNull())); + EXPECT_THAT(CelValue::CreateBool(true), IsCelBool(true)); EXPECT_THAT(CelValue::CreateBool(false), IsCelBool(Not(true))); @@ -118,9 +120,36 @@ TEST(SpecialMatchers, SmokeTest) { EXPECT_THAT(message, IsCelMessage(EqualsProto(proto_message))); } +TEST(ListMatchers, NotList) { + EXPECT_THAT(CelValue::CreateInt64(1), + Not(IsCelList(Contains(IsCelInt64(1))))); +} + +TEST(ListMatchers, All) { + ContainerBackedListImpl list({ + CelValue::CreateInt64(1), + CelValue::CreateInt64(2), + CelValue::CreateInt64(3), + CelValue::CreateInt64(4), + }); + CelValue cel_list = CelValue::CreateList(&list); + EXPECT_THAT(cel_list, IsCelList(Contains(IsCelInt64(3)))); + EXPECT_THAT(cel_list, IsCelList(Not(Contains(IsCelInt64(0))))); + + EXPECT_THAT(cel_list, IsCelList(ElementsAre(IsCelInt64(1), IsCelInt64(2), + IsCelInt64(3), IsCelInt64(4)))); + EXPECT_THAT(cel_list, + IsCelList(Not(ElementsAre(IsCelInt64(2), IsCelInt64(1), + IsCelInt64(3), IsCelInt64(4))))); + + EXPECT_THAT(cel_list, + IsCelList(UnorderedElementsAre(IsCelInt64(2), IsCelInt64(1), + IsCelInt64(4), IsCelInt64(3)))); + EXPECT_THAT( + cel_list, + IsCelList(Not(UnorderedElementsAre(IsCelInt64(2), IsCelInt64(1), + IsCelInt64(4), IsCelInt64(0))))); +} + } // namespace -} // namespace test -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime::test diff --git a/eval/public/transform_utility.cc b/eval/public/transform_utility.cc index 0ae4b44e8..6cb859c19 100644 --- a/eval/public/transform_utility.cc +++ b/eval/public/transform_utility.cc @@ -1,25 +1,30 @@ #include "eval/public/transform_utility.h" -#include "google/api/expr/v1alpha1/value.pb.h" +#include +#include +#include +#include + +#include "cel/expr/value.pb.h" #include "google/protobuf/any.pb.h" #include "google/protobuf/struct.pb.h" -#include "google/protobuf/arena.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "eval/public/cel_value.h" #include "eval/public/containers/container_backed_list_impl.h" #include "eval/public/containers/container_backed_map_impl.h" #include "eval/public/structs/cel_proto_wrapper.h" -#include "internal/proto_util.h" -#include "base/status_macros.h" - +#include "internal/proto_time_encoding.h" +#include "internal/status_macros.h" namespace google { namespace api { namespace expr { namespace runtime { -absl::Status CelValueToValue(const CelValue& value, Value* result) { +absl::Status CelValueToValue(const CelValue& value, Value* result, + google::protobuf::Arena* arena) { switch (value.type()) { case CelValue::Type::kBool: result->set_bool_value(value.BoolOrDie()); @@ -43,18 +48,29 @@ absl::Status CelValueToValue(const CelValue& value, Value* result) { break; case CelValue::Type::kDuration: { google::protobuf::Duration duration; - expr::internal::EncodeDuration(value.DurationOrDie(), &duration); + auto status = + cel::internal::EncodeDuration(value.DurationOrDie(), &duration); + if (!status.ok()) { + return status; + } result->mutable_object_value()->PackFrom(duration); break; } case CelValue::Type::kTimestamp: { google::protobuf::Timestamp timestamp; - expr::internal::EncodeTime(value.TimestampOrDie(), ×tamp); + auto status = + cel::internal::EncodeTime(value.TimestampOrDie(), ×tamp); + if (!status.ok()) { + return status; + } result->mutable_object_value()->PackFrom(timestamp); break; } + case CelValue::Type::kNullType: + result->set_null_value(google::protobuf::NullValue::NULL_VALUE); + break; case CelValue::Type::kMessage: - if (value.MessageOrDie() == nullptr) { + if (value.IsNull()) { result->set_null_value(google::protobuf::NullValue::NULL_VALUE); } else { result->mutable_object_value()->PackFrom(*value.MessageOrDie()); @@ -64,25 +80,26 @@ absl::Status CelValueToValue(const CelValue& value, Value* result) { auto& list = *value.ListOrDie(); auto* list_value = result->mutable_list_value(); for (int i = 0; i < list.size(); ++i) { - RETURN_IF_ERROR(CelValueToValue(list[i], list_value->add_values())); + CEL_RETURN_IF_ERROR(CelValueToValue(list.Get(arena, i), + list_value->add_values(), arena)); } break; } case CelValue::Type::kMap: { auto* map_value = result->mutable_map_value(); auto& cel_map = *value.MapOrDie(); - const auto& keys = *cel_map.ListKeys(); - for (int i = 0; i < keys.size(); ++i) { - CelValue key = keys[i]; + CEL_ASSIGN_OR_RETURN(const auto* keys, cel_map.ListKeys(arena)); + for (int i = 0; i < keys->size(); ++i) { + CelValue key = (*keys).Get(arena, i); auto* entry = map_value->add_entries(); - RETURN_IF_ERROR(CelValueToValue(key, entry->mutable_key())); - auto optional_value = cel_map[key]; + CEL_RETURN_IF_ERROR(CelValueToValue(key, entry->mutable_key(), arena)); + auto optional_value = cel_map.Get(arena, key); if (!optional_value) { return absl::Status(absl::StatusCode::kInternal, "key not found in map"); } - RETURN_IF_ERROR( - CelValueToValue(optional_value.value(), entry->mutable_value())); + CEL_RETURN_IF_ERROR( + CelValueToValue(*optional_value, entry->mutable_value(), arena)); } break; } @@ -113,7 +130,8 @@ absl::StatusOr ValueToCelValue(const Value& value, case Value::kBoolValue: return CelValue::CreateBool(value.bool_value()); case Value::kBytesValue: - return CelValue::CreateBytes(CelValue::BytesHolder(&value.bytes_value())); + return CelValue::CreateBytes(CelValue::BytesHolder( + arena->Create(arena, value.bytes_value()))); case Value::kDoubleValue: return CelValue::CreateDouble(value.double_value()); case Value::kEnumValue: @@ -123,7 +141,7 @@ absl::StatusOr ValueToCelValue(const Value& value, case Value::kListValue: { std::vector list; for (const auto& subvalue : value.list_value().values()) { - ASSIGN_OR_RETURN(auto list_value, ValueToCelValue(subvalue, arena)); + CEL_ASSIGN_OR_RETURN(auto list_value, ValueToCelValue(subvalue, arena)); list.push_back(list_value); } return CelValue::CreateList( @@ -132,27 +150,34 @@ absl::StatusOr ValueToCelValue(const Value& value, case Value::kMapValue: { std::vector> key_values; for (const auto& entry : value.map_value().entries()) { - ASSIGN_OR_RETURN(auto map_key, ValueToCelValue(entry.key(), arena)); - ASSIGN_OR_RETURN(auto map_value, ValueToCelValue(entry.value(), arena)); + CEL_ASSIGN_OR_RETURN(auto map_key, ValueToCelValue(entry.key(), arena)); + CEL_RETURN_IF_ERROR(CelValue::CheckMapKeyType(map_key)); + CEL_ASSIGN_OR_RETURN(auto map_value, + ValueToCelValue(entry.value(), arena)); key_values.push_back(std::pair(map_key, map_value)); } - auto cel_map = + CEL_ASSIGN_OR_RETURN( + auto cel_map, CreateContainerBackedMap(absl::Span>( - key_values.data(), key_values.size())) - .release(); - arena->Own(cel_map); - return CelValue::CreateMap(cel_map); + key_values.data(), key_values.size()))); + auto* cel_map_ptr = cel_map.release(); + arena->Own(cel_map_ptr); + return CelValue::CreateMap(cel_map_ptr); } case Value::kNullValue: return CelValue::CreateNull(); - case Value::kObjectValue: - return CelProtoWrapper::CreateMessage(&value.object_value(), arena); + case Value::kObjectValue: { + auto cel_value = + CelProtoWrapper::CreateMessage(&value.object_value(), arena); + if (cel_value.IsError()) return *cel_value.ErrorOrDie(); + return cel_value; + } case Value::kStringValue: - return CelValue::CreateString( - CelValue::StringHolder(&value.string_value())); + return CelValue::CreateString(CelValue::StringHolder( + arena->Create(arena, value.string_value()))); case Value::kTypeValue: - return CelValue::CreateCelType( - CelValue::CelTypeHolder(&value.type_value())); + return CelValue::CreateCelType(CelValue::CelTypeHolder( + arena->Create(arena, value.type_value()))); case Value::kUint64Value: return CelValue::CreateUint64(value.uint64_value()); case Value::KIND_NOT_SET: @@ -161,7 +186,6 @@ absl::StatusOr ValueToCelValue(const Value& value, } } - } // namespace runtime } // namespace expr } // namespace api diff --git a/eval/public/transform_utility.h b/eval/public/transform_utility.h index ca601a8d7..ad664cd5f 100644 --- a/eval/public/transform_utility.h +++ b/eval/public/transform_utility.h @@ -1,28 +1,35 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_TRANSFORM_UTILITY_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_TRANSFORM_UTILITY_H_ -#include "google/api/expr/v1alpha1/value.pb.h" +#include "cel/expr/value.pb.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "eval/public/cel_value.h" +#include "google/protobuf/arena.h" namespace google { namespace api { namespace expr { namespace runtime { -using google::api::expr::v1alpha1::Value; +using cel::expr::Value; -// Translates a CelValue into a google::api::expr::v1alpha1::Value. Returns an error if +// Translates a CelValue into a cel::expr::Value. Returns an error if // translation is not supported. -absl::Status CelValueToValue(const CelValue& value, Value* result); +absl::Status CelValueToValue(const CelValue& value, Value* result, + google::protobuf::Arena* arena); -// Translates a google::api::expr::v1alpha1::Value into a CelValue. Allocates any required +inline absl::Status CelValueToValue(const CelValue& value, Value* result) { + google::protobuf::Arena arena; + return CelValueToValue(value, result, &arena); +} + +// Translates a cel::expr::Value into a CelValue. Allocates any required // external data on the provided arena. Returns an error if translation is not // supported. absl::StatusOr ValueToCelValue(const Value& value, google::protobuf::Arena* arena); - } // namespace runtime } // namespace expr diff --git a/eval/public/unknown_attribute_set.h b/eval/public/unknown_attribute_set.h index b3abdeeb2..0992b94e2 100644 --- a/eval/public/unknown_attribute_set.h +++ b/eval/public/unknown_attribute_set.h @@ -1,10 +1,7 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_UNKNOWN_ATTRIBUTE_SET_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_UNKNOWN_ATTRIBUTE_SET_H_ -#include - -#include "absl/container/flat_hash_set.h" -#include "eval/public/cel_attribute.h" +#include "base/attribute_set.h" namespace google { namespace api { @@ -13,51 +10,7 @@ namespace runtime { // UnknownAttributeSet is a container for CEL attributes that are identified as // unknown during expression evaluation. -class UnknownAttributeSet { - public: - UnknownAttributeSet(const UnknownAttributeSet& other) = default; - UnknownAttributeSet& operator=(const UnknownAttributeSet& other) = default; - - UnknownAttributeSet() {} - UnknownAttributeSet(const std::vector& attributes) { - attributes_.reserve(attributes.size()); - for (const auto& attr : attributes) { - Add(attr); - } - } - - UnknownAttributeSet(const UnknownAttributeSet& set1, - const UnknownAttributeSet& set2) - : attributes_(set1.attributes()) { - attributes_.reserve(set1.attributes().size() + set2.attributes().size()); - for (const auto& attr : set2.attributes()) { - Add(attr); - } - } - - std::vector attributes() const { return attributes_; } - - static UnknownAttributeSet Merge(const UnknownAttributeSet& set1, - const UnknownAttributeSet& set2) { - return UnknownAttributeSet(set1, set2); - } - - private: - void Add(const CelAttribute* attribute) { - if (!attribute) { - return; - } - for (auto attr : attributes_) { - if (*attr == *attribute) { - return; - } - } - attributes_.push_back(attribute); - } - - // Attribute container. - std::vector attributes_; -}; +using UnknownAttributeSet = ::cel::AttributeSet; } // namespace runtime } // namespace expr diff --git a/eval/public/unknown_attribute_set_test.cc b/eval/public/unknown_attribute_set_test.cc index 775628f4a..efd27537f 100644 --- a/eval/public/unknown_attribute_set_test.cc +++ b/eval/public/unknown_attribute_set_test.cc @@ -1,11 +1,12 @@ #include "eval/public/unknown_attribute_set.h" #include +#include +#include -#include "gmock/gmock.h" -#include "gtest/gtest.h" #include "eval/public/cel_attribute.h" #include "eval/public/cel_value.h" +#include "internal/testing.h" namespace google { namespace api { @@ -14,79 +15,73 @@ namespace runtime { namespace { -using testing::Eq; +using ::testing::Eq; -using google::api::expr::v1alpha1::Expr; +using cel::expr::Expr; TEST(UnknownAttributeSetTest, TestCreate) { - Expr expr; - expr.mutable_ident_expr()->set_name("root"); - const std::string kAttr1 = "a1"; const std::string kAttr2 = "a2"; const std::string kAttr3 = "a3"; std::shared_ptr cel_attr = std::make_shared( - expr, std::vector( - {CelAttributeQualifier::Create(CelValue::CreateString(&kAttr1)), - CelAttributeQualifier::Create(CelValue::CreateInt64(1)), - CelAttributeQualifier::Create(CelValue::CreateUint64(2)), - CelAttributeQualifier::Create(CelValue::CreateBool(true))})); - - UnknownAttributeSet unknown_set({cel_attr.get()}); - EXPECT_THAT(unknown_set.attributes().size(), Eq(1)); - EXPECT_THAT(*(unknown_set.attributes()[0]), Eq(*cel_attr)); + "root", std::vector( + {CreateCelAttributeQualifier(CelValue::CreateString(&kAttr1)), + CreateCelAttributeQualifier(CelValue::CreateInt64(1)), + CreateCelAttributeQualifier(CelValue::CreateUint64(2)), + CreateCelAttributeQualifier(CelValue::CreateBool(true))})); + + UnknownAttributeSet unknown_set({*cel_attr}); + EXPECT_THAT(unknown_set.size(), Eq(1)); + EXPECT_THAT(*(unknown_set.begin()), Eq(*cel_attr)); } TEST(UnknownAttributeSetTest, TestMergeSets) { - Expr expr; - expr.mutable_ident_expr()->set_name("root"); - const std::string kAttr1 = "a1"; const std::string kAttr2 = "a2"; const std::string kAttr3 = "a3"; - std::shared_ptr cel_attr1 = std::make_shared( - expr, std::vector( - {CelAttributeQualifier::Create(CelValue::CreateString(&kAttr1)), - CelAttributeQualifier::Create(CelValue::CreateInt64(1)), - CelAttributeQualifier::Create(CelValue::CreateUint64(2)), - CelAttributeQualifier::Create(CelValue::CreateBool(true))})); - - std::shared_ptr cel_attr1_copy = std::make_shared( - expr, std::vector( - {CelAttributeQualifier::Create(CelValue::CreateString(&kAttr1)), - CelAttributeQualifier::Create(CelValue::CreateInt64(1)), - CelAttributeQualifier::Create(CelValue::CreateUint64(2)), - CelAttributeQualifier::Create(CelValue::CreateBool(true))})); - - std::shared_ptr cel_attr2 = std::make_shared( - expr, std::vector( - {CelAttributeQualifier::Create(CelValue::CreateString(&kAttr1)), - CelAttributeQualifier::Create(CelValue::CreateInt64(2)), - CelAttributeQualifier::Create(CelValue::CreateUint64(2)), - CelAttributeQualifier::Create(CelValue::CreateBool(true))})); - - std::shared_ptr cel_attr3 = std::make_shared( - expr, std::vector( - {CelAttributeQualifier::Create(CelValue::CreateString(&kAttr1)), - CelAttributeQualifier::Create(CelValue::CreateInt64(2)), - CelAttributeQualifier::Create(CelValue::CreateUint64(2)), - CelAttributeQualifier::Create(CelValue::CreateBool(false))})); - - UnknownAttributeSet unknown_set1({cel_attr1.get(), cel_attr2.get()}); - UnknownAttributeSet unknown_set2({cel_attr1_copy.get(), cel_attr3.get()}); + CelAttribute cel_attr1( + "root", std::vector( + {CreateCelAttributeQualifier(CelValue::CreateString(&kAttr1)), + CreateCelAttributeQualifier(CelValue::CreateInt64(1)), + CreateCelAttributeQualifier(CelValue::CreateUint64(2)), + CreateCelAttributeQualifier(CelValue::CreateBool(true))})); + + CelAttribute cel_attr1_copy( + "root", std::vector( + {CreateCelAttributeQualifier(CelValue::CreateString(&kAttr1)), + CreateCelAttributeQualifier(CelValue::CreateInt64(1)), + CreateCelAttributeQualifier(CelValue::CreateUint64(2)), + CreateCelAttributeQualifier(CelValue::CreateBool(true))})); + + CelAttribute cel_attr2( + "root", std::vector( + {CreateCelAttributeQualifier(CelValue::CreateString(&kAttr1)), + CreateCelAttributeQualifier(CelValue::CreateInt64(2)), + CreateCelAttributeQualifier(CelValue::CreateUint64(2)), + CreateCelAttributeQualifier(CelValue::CreateBool(true))})); + + CelAttribute cel_attr3( + "root", std::vector( + {CreateCelAttributeQualifier(CelValue::CreateString(&kAttr1)), + CreateCelAttributeQualifier(CelValue::CreateInt64(2)), + CreateCelAttributeQualifier(CelValue::CreateUint64(2)), + CreateCelAttributeQualifier(CelValue::CreateBool(false))})); + + UnknownAttributeSet unknown_set1({cel_attr1, cel_attr2}); + UnknownAttributeSet unknown_set2({cel_attr1_copy, cel_attr3}); UnknownAttributeSet unknown_set3 = UnknownAttributeSet::Merge(unknown_set1, unknown_set2); - EXPECT_THAT(unknown_set3.attributes().size(), Eq(3)); + EXPECT_THAT(unknown_set3.size(), Eq(3)); std::vector attrs1; - for (auto attr_ptr : unknown_set3.attributes()) { - attrs1.push_back(*attr_ptr); + for (const auto& attr_ptr : unknown_set3) { + attrs1.push_back(attr_ptr); } - std::vector attrs2 = {*cel_attr1, *cel_attr2, *cel_attr3}; + std::vector attrs2 = {cel_attr1, cel_attr2, cel_attr3}; EXPECT_THAT(attrs1, testing::UnorderedPointwise(Eq(), attrs2)); } diff --git a/eval/public/unknown_function_result_set.cc b/eval/public/unknown_function_result_set.cc index b2ef5b84d..60cd20ea3 100644 --- a/eval/public/unknown_function_result_set.cc +++ b/eval/public/unknown_function_result_set.cc @@ -1,99 +1 @@ #include "eval/public/unknown_function_result_set.h" - -#include - -#include "absl/container/btree_set.h" -#include "eval/public/cel_function.h" -#include "eval/public/cel_options.h" -#include "eval/public/cel_value.h" -#include "eval/public/set_util.h" - -namespace google { -namespace api { -namespace expr { -namespace runtime { -namespace { - -// Tests that lhs descriptor is less than (name, receiver call style, -// arg types). -// Argument type Any is not treated specially. For example: -// {"f", false, {kAny}} > {"f", false, {kInt64}} -bool DescriptorLessThan(const CelFunctionDescriptor& lhs, - const CelFunctionDescriptor& rhs) { - if (lhs.name() < rhs.name()) { - return true; - } - if (lhs.name() > rhs.name()) { - return false; - } - - if (lhs.receiver_style() < rhs.receiver_style()) { - return true; - } - if (lhs.receiver_style() > rhs.receiver_style()) { - return false; - } - - if (lhs.types() >= rhs.types()) { - return false; - } - - return true; -} - -bool UnknownFunctionResultLessThan(const UnknownFunctionResult& lhs, - const UnknownFunctionResult& rhs) { - if (DescriptorLessThan(lhs.descriptor(), rhs.descriptor())) { - return true; - } - if (DescriptorLessThan(rhs.descriptor(), lhs.descriptor())) { - return false; - } - - if (lhs.arguments().size() < rhs.arguments().size()) { - return true; - } - - if (lhs.arguments().size() > rhs.arguments().size()) { - return false; - } - - for (size_t i = 0; i < lhs.arguments().size(); i++) { - if (CelValueLessThan(lhs.arguments()[i], rhs.arguments()[i])) { - return true; - } - if (CelValueLessThan(rhs.arguments()[i], lhs.arguments()[i])) { - return false; - } - } - - // equal - return false; -} - -} // namespace - -bool UnknownFunctionComparator::operator()( - const UnknownFunctionResult* lhs, const UnknownFunctionResult* rhs) const { - return UnknownFunctionResultLessThan(*lhs, *rhs); -} - -bool UnknownFunctionResult::IsEqualTo( - const UnknownFunctionResult& other) const { - return !(UnknownFunctionResultLessThan(*this, other) || - UnknownFunctionResultLessThan(other, *this)); -} - -// Implementation for merge constructor. -UnknownFunctionResultSet::UnknownFunctionResultSet( - const UnknownFunctionResultSet& lhs, const UnknownFunctionResultSet& rhs) - : unknown_function_results_(lhs.unknown_function_results()) { - for (const UnknownFunctionResult* call : rhs.unknown_function_results()) { - unknown_function_results_.insert(call); - } -} - -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google diff --git a/eval/public/unknown_function_result_set.h b/eval/public/unknown_function_result_set.h index 891b3713f..b0d4d1cc6 100644 --- a/eval/public/unknown_function_result_set.h +++ b/eval/public/unknown_function_result_set.h @@ -1,12 +1,8 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_UNKNOWN_FUNCTION_RESULT_SET_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_UNKNOWN_FUNCTION_RESULT_SET_H_ -#include - -#include "google/api/expr/v1alpha1/syntax.pb.h" -#include "absl/container/btree_map.h" -#include "absl/container/btree_set.h" -#include "eval/public/cel_function.h" +#include "base/function_result.h" +#include "base/function_result_set.h" namespace google { namespace api { @@ -15,67 +11,13 @@ namespace runtime { // Represents a function result that is unknown at the time of execution. This // allows for lazy evaluation of expensive functions. -class UnknownFunctionResult { - public: - UnknownFunctionResult(const CelFunctionDescriptor& descriptor, int64_t expr_id, - const std::vector& arguments) - : descriptor_(descriptor), expr_id_(expr_id), arguments_(arguments) {} - - // The descriptor of the called function that return Unknown. - const CelFunctionDescriptor& descriptor() const { return descriptor_; } - - // The id of the |Expr| that triggered the function call step. Provided - // informationally -- if two different |Expr|s generate the same unknown call, - // they will be treated as the same unknown function result. - int64_t call_expr_id() const { return expr_id_; } - - // The arguments of the function call that generated the unknown. - const std::vector& arguments() const { return arguments_; } - - // Equality operator provided for testing. Compatible with set less-than - // comparator. - // Compares descriptor then arguments elementwise. - bool IsEqualTo(const UnknownFunctionResult& other) const; - - private: - CelFunctionDescriptor descriptor_; - int64_t expr_id_; - std::vector arguments_; -}; - -// Comparator for set semantics. -struct UnknownFunctionComparator { - bool operator()(const UnknownFunctionResult*, - const UnknownFunctionResult*) const; -}; +using UnknownFunctionResult = ::cel::FunctionResult; // Represents a collection of unknown function results at a particular point in // execution. Execution should advance further if this set of unknowns are // provided. It may not advance if only a subset are provided. // Set semantics use |IsEqualTo()| defined on |UnknownFunctionResult|. -class UnknownFunctionResultSet { - public: - // Empty set - UnknownFunctionResultSet() {} - - // Merge constructor -- effectively union(lhs, rhs). - UnknownFunctionResultSet(const UnknownFunctionResultSet& lhs, - const UnknownFunctionResultSet& rhs); - - // Initialize with a single UnknownFunctionResult. - UnknownFunctionResultSet(const UnknownFunctionResult* initial) - : unknown_function_results_{initial} {} - - using Container = - absl::btree_set; - - const Container& unknown_function_results() const { - return unknown_function_results_; - } - - private: - Container unknown_function_results_; -}; +using UnknownFunctionResultSet = ::cel::FunctionResultSet; } // namespace runtime } // namespace expr diff --git a/eval/public/unknown_function_result_set_test.cc b/eval/public/unknown_function_result_set_test.cc index 022ec19d0..745b5b9ff 100644 --- a/eval/public/unknown_function_result_set_test.cc +++ b/eval/public/unknown_function_result_set_test.cc @@ -3,14 +3,12 @@ #include #include +#include #include "google/protobuf/duration.pb.h" #include "google/protobuf/empty.pb.h" #include "google/protobuf/struct.pb.h" #include "google/protobuf/timestamp.pb.h" -#include "google/protobuf/arena.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" #include "absl/time/clock.h" #include "absl/time/time.h" #include "absl/types/span.h" @@ -19,6 +17,9 @@ #include "eval/public/containers/container_backed_list_impl.h" #include "eval/public/containers/container_backed_map_impl.h" #include "eval/public/structs/cel_proto_wrapper.h" +#include "internal/testing.h" +#include "google/protobuf/arena.h" + namespace google { namespace api { namespace expr { @@ -28,478 +29,54 @@ namespace { using ::google::protobuf::ListValue; using ::google::protobuf::Struct; using ::google::protobuf::Arena; -using testing::Eq; -using testing::SizeIs; +using ::testing::Eq; +using ::testing::SizeIs; CelFunctionDescriptor kTwoInt("TwoInt", false, {CelValue::Type::kInt64, CelValue::Type::kInt64}); CelFunctionDescriptor kOneInt("OneInt", false, {CelValue::Type::kInt64}); -// Helper to confirm the set comparator works. -bool IsLessThan(const UnknownFunctionResult& lhs, - const UnknownFunctionResult& rhs) { - return UnknownFunctionComparator()(&lhs, &rhs); -} - -TEST(UnknownFunctionResult, ArgumentCapture) { - UnknownFunctionResult call1( - kTwoInt, /*expr_id=*/0, - {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}); - - EXPECT_THAT(call1.arguments(), SizeIs(2)); - EXPECT_THAT(call1.arguments().at(0).Int64OrDie(), Eq(1)); -} - TEST(UnknownFunctionResult, Equals) { - UnknownFunctionResult call1( - kTwoInt, /*expr_id=*/0, - {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}); + UnknownFunctionResult call1(kTwoInt, /*expr_id=*/0); - UnknownFunctionResult call2( - kTwoInt, /*expr_id=*/0, - {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}); + UnknownFunctionResult call2(kTwoInt, /*expr_id=*/0); EXPECT_TRUE(call1.IsEqualTo(call2)); - EXPECT_FALSE(IsLessThan(call1, call2)); - EXPECT_FALSE(IsLessThan(call2, call1)); - UnknownFunctionResult call3(kOneInt, /*expr_id=*/0, - {CelValue::CreateInt64(1)}); + UnknownFunctionResult call3(kOneInt, /*expr_id=*/0); - UnknownFunctionResult call4(kOneInt, /*expr_id=*/0, - {CelValue::CreateInt64(1)}); + UnknownFunctionResult call4(kOneInt, /*expr_id=*/0); EXPECT_TRUE(call3.IsEqualTo(call4)); + + UnknownFunctionResultSet call_set({call1, call3}); + EXPECT_EQ(call_set.size(), 2); + EXPECT_EQ(*call_set.begin(), call3); + EXPECT_EQ(*(++call_set.begin()), call1); } TEST(UnknownFunctionResult, InequalDescriptor) { - UnknownFunctionResult call1( - kTwoInt, /*expr_id=*/0, - {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}); + UnknownFunctionResult call1(kTwoInt, /*expr_id=*/0); - UnknownFunctionResult call2(kOneInt, /*expr_id=*/0, - {CelValue::CreateInt64(1)}); + UnknownFunctionResult call2(kOneInt, /*expr_id=*/0); EXPECT_FALSE(call1.IsEqualTo(call2)); - EXPECT_TRUE(IsLessThan(call2, call1)); CelFunctionDescriptor one_uint("OneInt", false, {CelValue::Type::kUint64}); - UnknownFunctionResult call3(kOneInt, /*expr_id=*/0, - {CelValue::CreateInt64(1)}); - - UnknownFunctionResult call4(one_uint, /*expr_id=*/0, - {CelValue::CreateUint64(1)}); - - EXPECT_FALSE(call3.IsEqualTo(call4)); - EXPECT_TRUE(IsLessThan(call3, call4)); -} - -TEST(UnknownFunctionResult, InequalArgs) { - UnknownFunctionResult call1( - kTwoInt, /*expr_id=*/0, - {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}); - - UnknownFunctionResult call2( - kTwoInt, /*expr_id=*/0, - {CelValue::CreateInt64(1), CelValue::CreateInt64(3)}); - - EXPECT_FALSE(call1.IsEqualTo(call2)); - EXPECT_TRUE(IsLessThan(call1, call2)); - - UnknownFunctionResult call3( - kTwoInt, /*expr_id=*/0, - {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}); + UnknownFunctionResult call3(kOneInt, /*expr_id=*/0); - UnknownFunctionResult call4(kTwoInt, /*expr_id=*/0, - {CelValue::CreateInt64(1)}); + UnknownFunctionResult call4(one_uint, /*expr_id=*/0); EXPECT_FALSE(call3.IsEqualTo(call4)); - EXPECT_TRUE(IsLessThan(call4, call3)); -} - -TEST(UnknownFunctionResult, ListsEqual) { - ContainerBackedListImpl cel_list_1(std::vector{ - CelValue::CreateInt64(1), CelValue::CreateInt64(2)}); - - ContainerBackedListImpl cel_list_2(std::vector{ - CelValue::CreateInt64(1), CelValue::CreateInt64(2)}); - - CelFunctionDescriptor desc("OneList", false, {CelValue::Type::kList}); - - UnknownFunctionResult call1(desc, /*expr_id=*/0, - {CelValue::CreateList(&cel_list_1)}); - UnknownFunctionResult call2(desc, /*expr_id=*/0, - {CelValue::CreateList(&cel_list_2)}); - - // [1, 2] == [1, 2] - EXPECT_TRUE(call1.IsEqualTo(call2)); -} - -TEST(UnknownFunctionResult, ListsDifferentSizes) { - ContainerBackedListImpl cel_list_1(std::vector{ - CelValue::CreateInt64(1), CelValue::CreateInt64(2)}); - - ContainerBackedListImpl cel_list_2(std::vector{ - CelValue::CreateInt64(1), - CelValue::CreateInt64(2), - CelValue::CreateInt64(3), - }); - - CelFunctionDescriptor desc("OneList", false, {CelValue::Type::kList}); - - UnknownFunctionResult call1(desc, /*expr_id=*/0, - {CelValue::CreateList(&cel_list_1)}); - UnknownFunctionResult call2(desc, /*expr_id=*/0, - {CelValue::CreateList(&cel_list_2)}); - - // [1, 2] == [1, 2, 3] - EXPECT_FALSE(call1.IsEqualTo(call2)); - EXPECT_TRUE(IsLessThan(call1, call2)); -} - -TEST(UnknownFunctionResult, ListsDifferentMembers) { - ContainerBackedListImpl cel_list_1(std::vector{ - CelValue::CreateInt64(1), CelValue::CreateInt64(2)}); - - ContainerBackedListImpl cel_list_2(std::vector{ - CelValue::CreateInt64(2), CelValue::CreateInt64(2)}); - - CelFunctionDescriptor desc("OneList", false, {CelValue::Type::kList}); - - UnknownFunctionResult call1(desc, /*expr_id=*/0, - {CelValue::CreateList(&cel_list_1)}); - UnknownFunctionResult call2(desc, /*expr_id=*/0, - {CelValue::CreateList(&cel_list_2)}); - - // [1, 2] == [2, 2] - EXPECT_FALSE(call1.IsEqualTo(call2)); - EXPECT_TRUE(IsLessThan(call1, call2)); -} - -TEST(UnknownFunctionResult, MapsEqual) { - std::vector> values{ - {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}, - {CelValue::CreateInt64(2), CelValue::CreateInt64(4)}}; - - auto cel_map_1 = CreateContainerBackedMap(absl::MakeSpan(values)); - auto cel_map_2 = CreateContainerBackedMap(absl::MakeSpan(values)); - - CelFunctionDescriptor desc("OneMap", false, {CelValue::Type::kMap}); - - UnknownFunctionResult call1(desc, /*expr_id=*/0, - {CelValue::CreateMap(cel_map_1.get())}); - UnknownFunctionResult call2(desc, /*expr_id=*/0, - {CelValue::CreateMap(cel_map_2.get())}); - - // {1: 2, 2: 4} == {1: 2, 2: 4} - EXPECT_TRUE(call1.IsEqualTo(call2)); -} - -TEST(UnknownFunctionResult, MapsDifferentSizes) { - std::vector> values{ - {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}, - {CelValue::CreateInt64(2), CelValue::CreateInt64(4)}}; - - auto cel_map_1 = CreateContainerBackedMap(absl::MakeSpan(values)); - - std::vector> values2{ - {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}, - {CelValue::CreateInt64(2), CelValue::CreateInt64(4)}, - {CelValue::CreateInt64(3), CelValue::CreateInt64(6)}}; - - auto cel_map_2 = CreateContainerBackedMap(absl::MakeSpan(values2)); - - CelFunctionDescriptor desc("OneMap", false, {CelValue::Type::kMap}); - - UnknownFunctionResult call1(desc, /*expr_id=*/0, - {CelValue::CreateMap(cel_map_1.get())}); - UnknownFunctionResult call2(desc, /*expr_id=*/0, - {CelValue::CreateMap(cel_map_2.get())}); - - // {1: 2, 2: 4} == {1: 2, 2: 4, 3: 6} - EXPECT_FALSE(call1.IsEqualTo(call2)); - EXPECT_TRUE(IsLessThan(call1, call2)); -} - -TEST(UnknownFunctionResult, MapsDifferentElements) { - std::vector> values{ - {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}, - {CelValue::CreateInt64(2), CelValue::CreateInt64(4)}, - {CelValue::CreateInt64(3), CelValue::CreateInt64(6)}}; - - auto cel_map_1 = CreateContainerBackedMap(absl::MakeSpan(values)); - - std::vector> values2{ - {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}, - {CelValue::CreateInt64(2), CelValue::CreateInt64(4)}, - {CelValue::CreateInt64(4), CelValue::CreateInt64(8)}}; - - auto cel_map_2 = CreateContainerBackedMap(absl::MakeSpan(values2)); - - std::vector> values3{ - {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}, - {CelValue::CreateInt64(2), CelValue::CreateInt64(4)}, - {CelValue::CreateInt64(3), CelValue::CreateInt64(5)}}; - - auto cel_map_3 = CreateContainerBackedMap(absl::MakeSpan(values3)); - - CelFunctionDescriptor desc("OneMap", false, {CelValue::Type::kMap}); - - UnknownFunctionResult call1(desc, /*expr_id=*/0, - {CelValue::CreateMap(cel_map_1.get())}); - UnknownFunctionResult call2(desc, /*expr_id=*/0, - {CelValue::CreateMap(cel_map_2.get())}); - UnknownFunctionResult call3(desc, /*expr_id=*/0, - {CelValue::CreateMap(cel_map_3.get())}); - - // {1: 2, 2: 4, 3: 6} == {1: 2, 2: 4, 4: 8} - EXPECT_FALSE(call1.IsEqualTo(call2)); - EXPECT_TRUE(IsLessThan(call1, call2)); - // {1: 2, 2: 4, 3: 6} == {1: 2, 2: 4, 3: 5} - EXPECT_FALSE(call1.IsEqualTo(call3)); - EXPECT_TRUE(IsLessThan(call3, call1)); -} - -TEST(UnknownFunctionResult, Messages) { - protobuf::Empty message1; - protobuf::Empty message2; - google::protobuf::Arena arena; - - CelFunctionDescriptor desc("OneMessage", false, {CelValue::Type::kMessage}); - - UnknownFunctionResult call1( - desc, /*expr_id=*/0, {CelProtoWrapper::CreateMessage(&message1, &arena)}); - UnknownFunctionResult call2( - desc, /*expr_id=*/0, {CelProtoWrapper::CreateMessage(&message2, &arena)}); - UnknownFunctionResult call3( - desc, /*expr_id=*/0, {CelProtoWrapper::CreateMessage(&message1, &arena)}); - - // &message1 == &message2 - EXPECT_FALSE(call1.IsEqualTo(call2)); - - // &message1 == &message1 - EXPECT_TRUE(call1.IsEqualTo(call3)); -} - -TEST(UnknownFunctionResult, AnyDescriptor) { - CelFunctionDescriptor anyDesc("OneAny", false, {CelValue::Type::kAny}); - - UnknownFunctionResult callAnyInt1(anyDesc, /*expr_id=*/0, - {CelValue::CreateInt64(2)}); - UnknownFunctionResult callInt(kOneInt, /*expr_id=*/0, - {CelValue::CreateInt64(2)}); - - UnknownFunctionResult callAnyInt2(anyDesc, /*expr_id=*/0, - {CelValue::CreateInt64(2)}); - UnknownFunctionResult callAnyUint(anyDesc, /*expr_id=*/0, - {CelValue::CreateUint64(2)}); - - EXPECT_FALSE(callAnyInt1.IsEqualTo(callInt)); - EXPECT_TRUE(IsLessThan(callAnyInt1, callInt)); - EXPECT_FALSE(callAnyInt1.IsEqualTo(callAnyUint)); - EXPECT_TRUE(IsLessThan(callAnyInt1, callAnyUint)); - EXPECT_TRUE(callAnyInt1.IsEqualTo(callAnyInt2)); -} - -TEST(UnknownFunctionResult, Strings) { - CelFunctionDescriptor desc("OneString", false, {CelValue::Type::kString}); - - UnknownFunctionResult callStringSmile(desc, /*expr_id=*/0, - {CelValue::CreateStringView("😁")}); - UnknownFunctionResult callStringFrown(desc, /*expr_id=*/0, - {CelValue::CreateStringView("🙁")}); - UnknownFunctionResult callStringSmile2(desc, /*expr_id=*/0, - {CelValue::CreateStringView("😁")}); - - EXPECT_TRUE(callStringSmile.IsEqualTo(callStringSmile2)); - EXPECT_FALSE(callStringSmile.IsEqualTo(callStringFrown)); -} - -TEST(UnknownFunctionResult, DurationHandling) { - google::protobuf::Arena arena; - absl::Duration duration1 = absl::Seconds(5); - protobuf::Duration duration2; - duration2.set_seconds(5); - - CelFunctionDescriptor durationDesc("OneDuration", false, - {CelValue::Type::kDuration}); - - UnknownFunctionResult callDuration1(durationDesc, /*expr_id=*/0, - {CelValue::CreateDuration(duration1)}); - UnknownFunctionResult callDuration2( - durationDesc, /*expr_id=*/0, - {CelProtoWrapper::CreateMessage(&duration2, &arena)}); - UnknownFunctionResult callDuration3( - durationDesc, /*expr_id=*/0, - {CelProtoWrapper::CreateDuration(&duration2)}); - - EXPECT_TRUE(callDuration1.IsEqualTo(callDuration2)); - EXPECT_TRUE(callDuration1.IsEqualTo(callDuration3)); -} - -TEST(UnknownFunctionResult, TimestampHandling) { - google::protobuf::Arena arena; - absl::Time ts1 = absl::FromUnixMillis(1000); - protobuf::Timestamp ts2; - ts2.set_seconds(1); - - CelFunctionDescriptor timestampDesc("OneTimestamp", false, - {CelValue::Type::kTimestamp}); - - UnknownFunctionResult callTimestamp1(timestampDesc, /*expr_id=*/0, - {CelValue::CreateTimestamp(ts1)}); - UnknownFunctionResult callTimestamp2( - timestampDesc, /*expr_id=*/0, - {CelProtoWrapper::CreateMessage(&ts2, &arena)}); - UnknownFunctionResult callTimestamp3( - timestampDesc, /*expr_id=*/0, {CelProtoWrapper::CreateTimestamp(&ts2)}); - - EXPECT_TRUE(callTimestamp1.IsEqualTo(callTimestamp2)); - EXPECT_TRUE(callTimestamp1.IsEqualTo(callTimestamp3)); -} - -// This tests that the conversion and different map backing implementations are -// compatible with the equality tests. -TEST(UnknownFunctionResult, ProtoStructTreatedAsMap) { - Arena arena; - - const std::vector kFields = {"field1", "field2", "field3"}; - - Struct value_struct; - - auto& value1 = (*value_struct.mutable_fields())[kFields[0]]; - value1.set_bool_value(true); - - auto& value2 = (*value_struct.mutable_fields())[kFields[1]]; - value2.set_number_value(1.0); - - auto& value3 = (*value_struct.mutable_fields())[kFields[2]]; - value3.set_string_value("test"); - - CelValue proto_struct = CelProtoWrapper::CreateMessage(&value_struct, &arena); - ASSERT_TRUE(proto_struct.IsMap()); - - std::vector> values{ - {CelValue::CreateStringView(kFields[2]), - CelValue::CreateStringView("test")}, - {CelValue::CreateStringView(kFields[1]), CelValue::CreateDouble(1.0)}, - {CelValue::CreateStringView(kFields[0]), CelValue::CreateBool(true)}}; - - auto backing_map = CreateContainerBackedMap(absl::MakeSpan(values)); - - CelValue cel_map = CelValue::CreateMap(backing_map.get()); - - CelFunctionDescriptor desc("OneMap", false, {CelValue::Type::kMap}); - - UnknownFunctionResult proto_struct_result(desc, /*expr_id=*/0, - {proto_struct}); - UnknownFunctionResult cel_map_result(desc, /*expr_id=*/0, {cel_map}); - - EXPECT_TRUE(proto_struct_result.IsEqualTo(cel_map_result)); -} - -// This tests that the conversion and different map backing implementations are -// compatible with the equality tests. -TEST(UnknownFunctionResult, ProtoListTreatedAsList) { - Arena arena; - - ListValue list_value; - - list_value.add_values()->set_bool_value(true); - list_value.add_values()->set_number_value(1.0); - list_value.add_values()->set_string_value("test"); - - CelValue proto_list = CelProtoWrapper::CreateMessage(&list_value, &arena); - ASSERT_TRUE(proto_list.IsList()); - - std::vector list_values{CelValue::CreateBool(true), - CelValue::CreateDouble(1.0), - CelValue::CreateStringView("test")}; - - ContainerBackedListImpl list_backing(list_values); - - CelValue cel_list = CelValue::CreateList(&list_backing); - - CelFunctionDescriptor desc("OneList", false, {CelValue::Type::kList}); - - UnknownFunctionResult proto_list_result(desc, /*expr_id=*/0, {proto_list}); - UnknownFunctionResult cel_list_result(desc, /*expr_id=*/0, {cel_list}); - - EXPECT_TRUE(cel_list_result.IsEqualTo(proto_list_result)); -} - -TEST(UnknownFunctionResult, NestedProtoTypes) { - Arena arena; - - ListValue list_value; - - list_value.add_values()->set_bool_value(true); - list_value.add_values()->set_number_value(1.0); - list_value.add_values()->set_string_value("test"); - - std::vector list_values{CelValue::CreateBool(true), - CelValue::CreateDouble(1.0), - CelValue::CreateStringView("test")}; - - ContainerBackedListImpl list_backing(list_values); - - CelValue cel_list = CelValue::CreateList(&list_backing); - - Struct value_struct; - - *(value_struct.mutable_fields()->operator[]("field").mutable_list_value()) = - list_value; - - std::vector> values{ - {CelValue::CreateStringView("field"), cel_list}}; - - auto backing_map = CreateContainerBackedMap(absl::MakeSpan(values)); - - CelValue cel_map = CelValue::CreateMap(backing_map.get()); - CelValue proto_map = CelProtoWrapper::CreateMessage(&value_struct, &arena); - - CelFunctionDescriptor desc("OneMap", false, {CelValue::Type::kMap}); - - UnknownFunctionResult cel_map_result(desc, /*expr_id=*/0, {cel_map}); - UnknownFunctionResult proto_struct_result(desc, /*expr_id=*/0, {proto_map}); - - EXPECT_TRUE(proto_struct_result.IsEqualTo(cel_map_result)); -} - -UnknownFunctionResult MakeUnknown(int64_t i) { - return UnknownFunctionResult(kOneInt, /*expr_id=*/0, - {CelValue::CreateInt64(i)}); -} - -testing::Matcher UnknownMatches( - const UnknownFunctionResult& obj) { - return testing::Truly([&](const UnknownFunctionResult* to_match) { - return obj.IsEqualTo(*to_match); - }); -} - -TEST(UnknownFunctionResultSet, Merge) { - UnknownFunctionResult a = MakeUnknown(1); - UnknownFunctionResult b = MakeUnknown(2); - UnknownFunctionResult c = MakeUnknown(3); - UnknownFunctionResult d = MakeUnknown(1); - - UnknownFunctionResultSet a1(&a); - UnknownFunctionResultSet b1(&b); - UnknownFunctionResultSet c1(&c); - UnknownFunctionResultSet d1(&d); - - UnknownFunctionResultSet ab(a1, b1); - UnknownFunctionResultSet cd(c1, d1); - - UnknownFunctionResultSet merged(ab, cd); - EXPECT_THAT(merged.unknown_function_results(), SizeIs(3)); - EXPECT_THAT(merged.unknown_function_results(), - testing::UnorderedElementsAre( - UnknownMatches(a), UnknownMatches(b), UnknownMatches(c))); + UnknownFunctionResultSet call_set({call1, call3, call4}); + EXPECT_EQ(call_set.size(), 3); + auto it = call_set.begin(); + EXPECT_EQ(*it++, call3); + EXPECT_EQ(*it++, call4); + EXPECT_EQ(*it++, call1); } } // namespace diff --git a/eval/public/unknown_set.h b/eval/public/unknown_set.h index 3b7168afe..244497c34 100644 --- a/eval/public/unknown_set.h +++ b/eval/public/unknown_set.h @@ -1,8 +1,9 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_UNKNOWN_SET_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_UNKNOWN_SET_H_ -#include "eval/public/unknown_attribute_set.h" -#include "eval/public/unknown_function_result_set.h" +#include "base/internal/unknown_set.h" +#include "eval/public/unknown_attribute_set.h" // IWYU pragma: keep +#include "eval/public/unknown_function_result_set.h" // IWYU pragma: keep namespace google { namespace api { @@ -11,38 +12,7 @@ namespace runtime { // Class representing a collection of unknowns from a single evaluation pass of // a CEL expression. -class UnknownSet { - public: - // Initilization specifying subcontainers - explicit UnknownSet( - const google::api::expr::runtime::UnknownAttributeSet& attrs) - : unknown_attributes_(attrs) {} - explicit UnknownSet(const UnknownFunctionResultSet& function_results) - : unknown_function_results_(function_results) {} - UnknownSet(const UnknownAttributeSet& attrs, - const UnknownFunctionResultSet& function_results) - : unknown_attributes_(attrs), - unknown_function_results_(function_results) {} - // Initialization for empty set - UnknownSet() {} - // Merge constructor - UnknownSet(const UnknownSet& set1, const UnknownSet& set2) - : unknown_attributes_(set1.unknown_attributes(), - set2.unknown_attributes()), - unknown_function_results_(set1.unknown_function_results(), - set2.unknown_function_results()) {} - - const UnknownAttributeSet& unknown_attributes() const { - return unknown_attributes_; - } - const UnknownFunctionResultSet& unknown_function_results() const { - return unknown_function_results_; - } - - private: - UnknownAttributeSet unknown_attributes_; - UnknownFunctionResultSet unknown_function_results_; -}; +using UnknownSet = ::cel::base_internal::UnknownSet; } // namespace runtime } // namespace expr diff --git a/eval/public/unknown_set_test.cc b/eval/public/unknown_set_test.cc index 3e4c06cda..3a0d151a5 100644 --- a/eval/public/unknown_set_test.cc +++ b/eval/public/unknown_set_test.cc @@ -1,12 +1,14 @@ #include "eval/public/unknown_set.h" -#include "google/api/expr/v1alpha1/syntax.pb.h" -#include "google/protobuf/arena.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" +#include + +#include "cel/expr/syntax.pb.h" #include "eval/public/cel_attribute.h" +#include "eval/public/cel_function.h" #include "eval/public/unknown_attribute_set.h" #include "eval/public/unknown_function_result_set.h" +#include "internal/testing.h" +#include "google/protobuf/arena.h" namespace google { namespace api { @@ -15,51 +17,33 @@ namespace runtime { namespace { using ::google::protobuf::Arena; -using testing::IsEmpty; -using testing::UnorderedElementsAre; +using ::testing::IsEmpty; +using ::testing::UnorderedElementsAre; UnknownFunctionResultSet MakeFunctionResult(Arena* arena, int64_t id) { CelFunctionDescriptor desc("OneInt", false, {CelValue::Type::kInt64}); - std::vector call_args{CelValue::CreateInt64(id)}; - const auto* function_result = Arena::Create( - arena, desc, /*expr_id=*/0, call_args); - return UnknownFunctionResultSet(function_result); + return UnknownFunctionResultSet(UnknownFunctionResult(desc, /*expr_id=*/0)); } UnknownAttributeSet MakeAttribute(Arena* arena, int64_t id) { - google::api::expr::v1alpha1::Expr expr; - expr.mutable_ident_expr()->set_name("x"); - std::vector attr_trail{ - CelAttributeQualifier::Create(CelValue::CreateInt64(id))}; + CreateCelAttributeQualifier(CelValue::CreateInt64(id))}; - const auto* attr = Arena::Create(arena, expr, attr_trail); - return UnknownAttributeSet({attr}); + return UnknownAttributeSet({CelAttribute("x", std::move(attr_trail))}); } MATCHER_P(UnknownAttributeIs, id, "") { - const CelAttribute* attr = arg; - if (attr->qualifier_path().size() != 1) { + const CelAttribute& attr = arg; + if (attr.qualifier_path().size() != 1) { return false; } - auto maybe_qualifier = attr->qualifier_path()[0].GetInt64Key(); + auto maybe_qualifier = attr.qualifier_path()[0].GetInt64Key(); if (!maybe_qualifier.has_value()) { return false; } return maybe_qualifier.value() == id; } -MATCHER_P(UnknownFunctionResultIs, id, "") { - const UnknownFunctionResult* result = arg; - if (result->arguments().size() != 1) { - return false; - } - if (!result->arguments()[0].IsInt64()) { - return false; - } - return result->arguments()[0].Int64OrDie() == id; -} - TEST(UnknownSet, AttributesMerge) { Arena arena; UnknownSet a(MakeAttribute(&arena, 1)); @@ -69,35 +53,17 @@ TEST(UnknownSet, AttributesMerge) { UnknownSet e(c, d); EXPECT_THAT( - d.unknown_attributes().attributes(), + d.unknown_attributes(), UnorderedElementsAre(UnknownAttributeIs(1), UnknownAttributeIs(2))); EXPECT_THAT( - e.unknown_attributes().attributes(), + e.unknown_attributes(), UnorderedElementsAre(UnknownAttributeIs(1), UnknownAttributeIs(2))); } -TEST(UnknownSet, FunctionsMerge) { - Arena arena; - - UnknownSet a(MakeFunctionResult(&arena, 1)); - UnknownSet b(MakeFunctionResult(&arena, 2)); - UnknownSet c(MakeFunctionResult(&arena, 2)); - UnknownSet d(a, b); - UnknownSet e(c, d); - - EXPECT_THAT(d.unknown_function_results().unknown_function_results(), - UnorderedElementsAre(UnknownFunctionResultIs(1), - UnknownFunctionResultIs(2))); - EXPECT_THAT(e.unknown_function_results().unknown_function_results(), - UnorderedElementsAre(UnknownFunctionResultIs(1), - UnknownFunctionResultIs(2))); -} - TEST(UnknownSet, DefaultEmpty) { UnknownSet empty_set; - EXPECT_THAT(empty_set.unknown_attributes().attributes(), IsEmpty()); - EXPECT_THAT(empty_set.unknown_function_results().unknown_function_results(), - IsEmpty()); + EXPECT_THAT(empty_set.unknown_attributes(), IsEmpty()); + EXPECT_THAT(empty_set.unknown_function_results(), IsEmpty()); } TEST(UnknownSet, MixedMerges) { @@ -109,17 +75,11 @@ TEST(UnknownSet, MixedMerges) { UnknownSet d(a, b); UnknownSet e(c, d); - EXPECT_THAT(d.unknown_attributes().attributes(), + EXPECT_THAT(d.unknown_attributes(), UnorderedElementsAre(UnknownAttributeIs(1))); - EXPECT_THAT(d.unknown_function_results().unknown_function_results(), - UnorderedElementsAre(UnknownFunctionResultIs(1), - UnknownFunctionResultIs(2))); EXPECT_THAT( - e.unknown_attributes().attributes(), + e.unknown_attributes(), UnorderedElementsAre(UnknownAttributeIs(1), UnknownAttributeIs(2))); - EXPECT_THAT(e.unknown_function_results().unknown_function_results(), - UnorderedElementsAre(UnknownFunctionResultIs(1), - UnknownFunctionResultIs(2))); } } // namespace diff --git a/eval/public/value_export_util.cc b/eval/public/value_export_util.cc index 89ef53022..bca8a8d65 100644 --- a/eval/public/value_export_util.cc +++ b/eval/public/value_export_util.cc @@ -1,23 +1,18 @@ #include "eval/public/value_export_util.h" -#include "google/protobuf/util/json_util.h" -#include "google/protobuf/util/time_util.h" +#include + #include "absl/strings/escaping.h" #include "absl/strings/str_cat.h" -#include "internal/proto_util.h" +#include "internal/proto_time_encoding.h" +#include "google/protobuf/util/json_util.h" +#include "google/protobuf/util/time_util.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace google::api::expr::runtime { using google::protobuf::Duration; -using google::protobuf::ListValue; -using google::protobuf::Struct; using google::protobuf::Timestamp; using google::protobuf::Value; -using google::protobuf::FieldDescriptor; -using google::protobuf::Message; using google::protobuf::util::TimeUtil; absl::Status KeyAsString(const CelValue& value, std::string* key) { @@ -43,7 +38,8 @@ absl::Status KeyAsString(const CelValue& value, std::string* key) { } // Export content of CelValue as google.protobuf.Value. -absl::Status ExportAsProtoValue(const CelValue& in_value, Value* out_value) { +absl::Status ExportAsProtoValue(const CelValue& in_value, Value* out_value, + google::protobuf::Arena* arena) { if (in_value.IsNull()) { out_value->set_null_value(google::protobuf::NULL_VALUE); return absl::OkStatus(); @@ -71,19 +67,27 @@ absl::Status ExportAsProtoValue(const CelValue& in_value, Value* out_value) { break; } case CelValue::Type::kBytes: { - absl::Base64Escape(in_value.BytesOrDie().value(), - out_value->mutable_string_value()); + *out_value->mutable_string_value() = + absl::Base64Escape(in_value.BytesOrDie().value()); break; } case CelValue::Type::kDuration: { Duration duration; - expr::internal::EncodeDuration(in_value.DurationOrDie(), &duration); + auto status = + cel::internal::EncodeDuration(in_value.DurationOrDie(), &duration); + if (!status.ok()) { + return status; + } out_value->set_string_value(TimeUtil::ToString(duration)); break; } case CelValue::Type::kTimestamp: { Timestamp timestamp; - expr::internal::EncodeTime(in_value.TimestampOrDie(), ×tamp); + auto status = + cel::internal::EncodeTime(in_value.TimestampOrDie(), ×tamp); + if (!status.ok()) { + return status; + } out_value->set_string_value(TimeUtil::ToString(timestamp)); break; } @@ -108,8 +112,8 @@ absl::Status ExportAsProtoValue(const CelValue& in_value, Value* out_value) { const CelList* cel_list = in_value.ListOrDie(); auto out_values = out_value->mutable_list_value(); for (int i = 0; i < cel_list->size(); i++) { - auto status = - ExportAsProtoValue((*cel_list)[i], out_values->add_values()); + auto status = ExportAsProtoValue((*cel_list).Get(arena, i), + out_values->add_values(), arena); if (!status.ok()) { return status; } @@ -118,19 +122,19 @@ absl::Status ExportAsProtoValue(const CelValue& in_value, Value* out_value) { } case CelValue::Type::kMap: { const CelMap* cel_map = in_value.MapOrDie(); - auto keys_list = cel_map->ListKeys(); + CEL_ASSIGN_OR_RETURN(auto keys_list, cel_map->ListKeys(arena)); auto out_values = out_value->mutable_struct_value()->mutable_fields(); for (int i = 0; i < keys_list->size(); i++) { std::string key; - CelValue map_key = (*keys_list)[i]; + CelValue map_key = (*keys_list).Get(arena, i); auto status = KeyAsString(map_key, &key); if (!status.ok()) { return status; } - auto map_value_ref = (*cel_map)[map_key]; + auto map_value_ref = (*cel_map).Get(arena, map_key); CelValue map_value = (map_value_ref) ? map_value_ref.value() : CelValue(); - status = ExportAsProtoValue(map_value, &((*out_values)[key])); + status = ExportAsProtoValue(map_value, &((*out_values)[key]), arena); if (!status.ok()) { return status; } @@ -144,7 +148,4 @@ absl::Status ExportAsProtoValue(const CelValue& in_value, Value* out_value) { return absl::OkStatus(); } -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime diff --git a/eval/public/value_export_util.h b/eval/public/value_export_util.h index 6fbf9f8c4..26217452a 100644 --- a/eval/public/value_export_util.h +++ b/eval/public/value_export_util.h @@ -4,23 +4,25 @@ #include "google/protobuf/struct.pb.h" #include "absl/status/status.h" #include "eval/public/cel_value.h" +#include "google/protobuf/arena.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace google::api::expr::runtime { // Exports content of CelValue as google.protobuf.Value. // Current limitations: // - exports integer values as doubles (Value.number_value); // - exports integer keys in maps as strings; // - handles Duration and Timestamp as generic messages. -absl::Status ExportAsProtoValue(const CelValue &in_value, - google::protobuf::Value *out_value); +absl::Status ExportAsProtoValue(const CelValue& in_value, + google::protobuf::Value* out_value, + google::protobuf::Arena* arena); -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +inline absl::Status ExportAsProtoValue(const CelValue& in_value, + google::protobuf::Value* out_value) { + google::protobuf::Arena arena; + return ExportAsProtoValue(in_value, out_value, &arena); +} + +} // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_VALUE_EXPORT_UTIL_H_ diff --git a/eval/public/value_export_util_test.cc b/eval/public/value_export_util_test.cc index 86ebfa297..5f82958f1 100644 --- a/eval/public/value_export_util_test.cc +++ b/eval/public/value_export_util_test.cc @@ -1,21 +1,19 @@ #include "eval/public/value_export_util.h" +#include #include +#include -#include "gmock/gmock.h" -#include "gtest/gtest.h" #include "absl/strings/str_cat.h" #include "eval/public/containers/container_backed_list_impl.h" #include "eval/public/containers/container_backed_map_impl.h" #include "eval/public/structs/cel_proto_wrapper.h" #include "eval/testutil/test_message.pb.h" +#include "internal/status_macros.h" +#include "internal/testing.h" #include "testutil/util.h" -#include "base/status_macros.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace google::api::expr::runtime { namespace { @@ -137,7 +135,7 @@ TEST(ValueExportUtilTest, ConvertRepeatedBoolValue) { Arena arena; Value value; - TestMessage *msg = Arena::CreateMessage(&arena); + TestMessage* msg = Arena::Create(&arena); msg->add_bool_list(true); msg->add_bool_list(false); CelValue cel_value = CelProtoWrapper::CreateMessage(msg, &arena); @@ -156,7 +154,7 @@ TEST(ValueExportUtilTest, ConvertRepeatedInt32Value) { Arena arena; Value value; - TestMessage *msg = Arena::CreateMessage(&arena); + TestMessage* msg = Arena::Create(&arena); msg->add_int32_list(2); msg->add_int32_list(3); CelValue cel_value = CelProtoWrapper::CreateMessage(msg, &arena); @@ -175,7 +173,7 @@ TEST(ValueExportUtilTest, ConvertRepeatedInt64Value) { Arena arena; Value value; - TestMessage *msg = Arena::CreateMessage(&arena); + TestMessage* msg = Arena::Create(&arena); msg->add_int64_list(2); msg->add_int64_list(3); CelValue cel_value = CelProtoWrapper::CreateMessage(msg, &arena); @@ -194,7 +192,7 @@ TEST(ValueExportUtilTest, ConvertRepeatedUint64Value) { Arena arena; Value value; - TestMessage *msg = Arena::CreateMessage(&arena); + TestMessage* msg = Arena::Create(&arena); msg->add_uint64_list(2); msg->add_uint64_list(3); CelValue cel_value = CelProtoWrapper::CreateMessage(msg, &arena); @@ -213,7 +211,7 @@ TEST(ValueExportUtilTest, ConvertRepeatedDoubleValue) { Arena arena; Value value; - TestMessage *msg = Arena::CreateMessage(&arena); + TestMessage* msg = Arena::Create(&arena); msg->add_double_list(2); msg->add_double_list(3); CelValue cel_value = CelProtoWrapper::CreateMessage(msg, &arena); @@ -232,7 +230,7 @@ TEST(ValueExportUtilTest, ConvertRepeatedStringValue) { Arena arena; Value value; - TestMessage *msg = Arena::CreateMessage(&arena); + TestMessage* msg = Arena::Create(&arena); msg->add_string_list("test1"); msg->add_string_list("test2"); CelValue cel_value = CelProtoWrapper::CreateMessage(msg, &arena); @@ -251,7 +249,7 @@ TEST(ValueExportUtilTest, ConvertRepeatedBytesValue) { Arena arena; Value value; - TestMessage *msg = Arena::CreateMessage(&arena); + TestMessage* msg = Arena::Create(&arena); msg->add_bytes_list("test1"); msg->add_bytes_list("test2"); CelValue cel_value = CelProtoWrapper::CreateMessage(msg, &arena); @@ -298,13 +296,14 @@ TEST(ValueExportUtilTest, ConvertCelMapWithStringKey) { {CelValue::CreateString(&key2), CelValue::CreateString(&value2)}); auto cel_map = CreateContainerBackedMap( - absl::Span>(map_entries)); + absl::Span>(map_entries)) + .value(); CelValue cel_value = CelValue::CreateMap(cel_map.get()); EXPECT_OK(ExportAsProtoValue(cel_value, &value)); EXPECT_EQ(value.kind_case(), Value::KindCase::kStructValue); - const auto &fields = value.struct_value().fields(); + const auto& fields = value.struct_value().fields(); EXPECT_EQ(fields.at(key1).string_value(), value1); EXPECT_EQ(fields.at(key2).string_value(), value2); @@ -325,13 +324,14 @@ TEST(ValueExportUtilTest, ConvertCelMapWithInt64Key) { {CelValue::CreateInt64(key2), CelValue::CreateString(&value2)}); auto cel_map = CreateContainerBackedMap( - absl::Span>(map_entries)); + absl::Span>(map_entries)) + .value(); CelValue cel_value = CelValue::CreateMap(cel_map.get()); EXPECT_OK(ExportAsProtoValue(cel_value, &value)); EXPECT_EQ(value.kind_case(), Value::KindCase::kStructValue); - const auto &fields = value.struct_value().fields(); + const auto& fields = value.struct_value().fields(); EXPECT_EQ(fields.at(absl::StrCat(key1)).string_value(), value1); EXPECT_EQ(fields.at(absl::StrCat(key2)).string_value(), value2); @@ -339,7 +339,4 @@ TEST(ValueExportUtilTest, ConvertCelMapWithInt64Key) { } // namespace -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime diff --git a/eval/tests/BUILD b/eval/tests/BUILD index 2753cecb8..9163548d1 100644 --- a/eval/tests/BUILD +++ b/eval/tests/BUILD @@ -2,35 +2,185 @@ # # +load("@com_google_protobuf//bazel:cc_proto_library.bzl", "cc_proto_library") +load("@com_google_protobuf//bazel:proto_library.bzl", "proto_library") +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + package(default_visibility = ["//visibility:public"]) -licenses(["notice"]) # Apache 2.0 +licenses(["notice"]) exports_files(["LICENSE"]) cc_test( name = "benchmark_test", - size = "small", srcs = [ "benchmark_test.cc", ], - tags = ["manual"], + tags = [ + "benchmark", + "manual", + ], deps = [ ":request_context_cc_proto", "//eval/public:activation", "//eval/public:builtin_func_registrar", "//eval/public:cel_expr_builder_factory", "//eval/public:cel_expression", + "//eval/public:cel_options", "//eval/public:cel_value", "//eval/public/containers:container_backed_list_impl", + "//eval/public/containers:container_backed_map_impl", "//eval/public/structs:cel_proto_wrapper", - "@com_github_google_googlebench//:benchmark", - "@com_github_google_googlebench//:benchmark_main", + "//internal:benchmark", + "//internal:status_macros", + "//internal:testing", + "//parser", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:node_hash_set", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/strings", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_googleapis//google/rpc/context:attribute_context_cc_proto", + "@com_google_protobuf//:protobuf", + "@com_google_protobuf//:struct_cc_proto", + ], +) + +cc_test( + name = "modern_benchmark_test", + srcs = [ + "modern_benchmark_test.cc", + ], + tags = [ + "benchmark", + "manual", + ], + deps = [ + ":request_context_cc_proto", + "//common:allocator", + "//common:casting", + "//common:legacy_value", + "//common:memory", + "//common:native_type", + "//common:value", + "//extensions:comprehensions_v2_functions", + "//extensions:comprehensions_v2_macros", + "//extensions/protobuf:runtime_adapter", + "//extensions/protobuf:value", + "//internal:benchmark", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//internal:testing_message_factory", + "//parser", + "//parser:macro", + "//parser:macro_registry", + "//runtime", + "//runtime:activation", + "//runtime:constant_folding", + "//runtime:runtime_options", + "//runtime:standard_runtime_builder_factory", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:node_hash_set", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_googleapis//google/rpc/context:attribute_context_cc_proto", + "@com_google_protobuf//:protobuf", + "@com_google_protobuf//:struct_cc_proto", + ], +) + +cc_test( + name = "allocation_benchmark_test", + size = "small", + srcs = [ + "allocation_benchmark_test.cc", + ], + tags = [ + "benchmark", + "manual", + ], + deps = [ + ":request_context_cc_proto", + "//eval/public:activation", + "//eval/public:builtin_func_registrar", + "//eval/public:cel_expr_builder_factory", + "//eval/public:cel_expression", + "//eval/public:cel_value", + "//internal:benchmark", + "//internal:testing", + "//parser", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_googleapis//google/rpc/context:attribute_context_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "memory_safety_test", + srcs = [ + "memory_safety_test.cc", + ], + deps = [ + "//eval/public:activation", + "//eval/public:builtin_func_registrar", + "//eval/public:cel_expr_builder_factory", + "//eval/public:cel_expression", + "//eval/public:cel_function_adapter", + "//eval/public:cel_options", + "//eval/public/testing:matchers", + "//internal:testing", + "//parser", + "//testutil:util", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", - "@com_google_googletest//:gtest_main", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_googleapis//google/rpc/context:attribute_context_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "expression_builder_benchmark_test", + size = "small", + srcs = [ + "expression_builder_benchmark_test.cc", + ], + tags = [ + "benchmark", + "manual", + ], + deps = [ + ":request_context_cc_proto", + "//common:minimal_descriptor_pool", + "//eval/public:builtin_func_registrar", + "//eval/public:cel_expr_builder_factory", + "//eval/public:cel_expression", + "//eval/public:cel_options", + "//eval/public:cel_type_registry", + "//internal:benchmark", + "//internal:status_macros", + "//internal:testing", + "//parser", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], ) @@ -42,7 +192,6 @@ cc_test( "end_to_end_test.cc", ], deps = [ - "//base:status_macros", "//eval/public:activation", "//eval/public:builtin_func_registrar", "//eval/public:cel_expr_builder_factory", @@ -50,9 +199,13 @@ cc_test( "//eval/public:cel_value", "//eval/public/structs:cel_proto_wrapper", "//eval/testutil:test_message_cc_proto", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", - "@com_google_googletest//:gtest_main", + "//internal:status_macros", + "//internal:testing", + "//testutil:util", + "@com_google_absl//absl/status", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", + "@com_google_protobuf//:struct_cc_proto", ], ) @@ -63,8 +216,8 @@ cc_test( "unknowns_end_to_end_test.cc", ], deps = [ - "//base:status_macros", - "//eval/eval:evaluator_core", + "//base:attributes", + "//base:function_result", "//eval/public:activation", "//eval/public:builtin_func_registrar", "//eval/public:cel_attribute", @@ -74,14 +227,22 @@ cc_test( "//eval/public:cel_options", "//eval/public:cel_value", "//eval/public:unknown_set", - "//eval/public/containers:container_backed_list_impl", "//eval/public/containers:container_backed_map_impl", "//eval/public/structs:cel_proto_wrapper", - "@com_google_absl//absl/container:btree", + "//internal:status_macros", + "//internal:testing", + "//parser", + "//runtime/internal:activation_attribute_matcher_access", + "//runtime/internal:attribute_matcher", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", + "@com_google_protobuf//:struct_cc_proto", ], ) @@ -102,8 +263,9 @@ cc_library( testonly = 1, hdrs = ["mock_cel_expression.h"], deps = [ - "//eval/public:activation", + "//eval/public:base_activation", "//eval/public:cel_expression", - "@com_google_googletest//:gtest", + "//internal:testing_no_main", + "@com_google_absl//absl/status:statusor", ], ) diff --git a/eval/tests/README.md b/eval/tests/README.md index 1eddf51af..d2227641d 100644 --- a/eval/tests/README.md +++ b/eval/tests/README.md @@ -2,11 +2,11 @@ ## Benchmarks To run the benchmark tests: -`blaze run -c opt --dynamic_mode=off //eval/tests:benchmark_test --benchmarks=all` +`blaze run -c opt --dynamic_mode=off //eval/tests:benchmark_test --benchmark_filter=all` or -`blaze run -c opt --dynamic_mode=off //eval/tests:unknowns_benchmark_test --benchmarks=all` +`blaze run -c opt --dynamic_mode=off //eval/tests:unknowns_benchmark_test --benchmark_filter=all` see go/benchmark diff --git a/eval/tests/allocation_benchmark_test.cc b/eval/tests/allocation_benchmark_test.cc new file mode 100644 index 000000000..425355e3a --- /dev/null +++ b/eval/tests/allocation_benchmark_test.cc @@ -0,0 +1,254 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include + +#include "cel/expr/syntax.pb.h" +#include "google/rpc/context/attribute_context.pb.h" +#include "absl/status/status.h" +#include "absl/strings/substitute.h" +#include "eval/public/activation.h" +#include "eval/public/builtin_func_registrar.h" +#include "eval/public/cel_expr_builder_factory.h" +#include "eval/public/cel_expression.h" +#include "eval/public/cel_value.h" +#include "eval/tests/request_context.pb.h" +#include "internal/benchmark.h" +#include "internal/testing.h" +#include "parser/parser.h" +#include "google/protobuf/arena.h" + +namespace google::api::expr::runtime { +namespace { + +using ::absl_testing::StatusIs; +using ::cel::expr::ParsedExpr; +using ::google::api::expr::parser::Parse; +using ::testing::HasSubstr; + +// Evaluates cel expression: +// '"1" + "1" + ...' +static void BM_StrCatLocalArena(benchmark::State& state) { + std::string expr("'1'"); + int len = state.range(0); + auto builder = CreateCelExpressionBuilder(); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + + for (int i = 0; i < len; i++) { + expr = absl::Substitute("($0 + $0)", expr); + } + + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse(expr)); + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder->CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + for (auto _ : state) { + google::protobuf::Arena arena; + Activation activation; + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation, &arena)); + CelValue::StringHolder holder; + ASSERT_TRUE(result.GetValue(&holder)); + ASSERT_EQ(holder.value().length(), 1 << len); + } +} +BENCHMARK(BM_StrCatLocalArena)->DenseRange(0, 8, 2); + +// Evaluates cel expression: +// '("1" + "1") + ...' +static void BM_StrCatSharedArena(benchmark::State& state) { + google::protobuf::Arena arena; + std::string expr("'1'"); + int len = state.range(0); + auto builder = CreateCelExpressionBuilder(); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + + for (int i = 0; i < len; i++) { + expr = absl::Substitute("($0 + $0)", expr); + } + + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse(expr)); + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder->CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + for (auto _ : state) { + Activation activation; + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation, &arena)); + CelValue::StringHolder holder; + ASSERT_TRUE(result.GetValue(&holder)); + ASSERT_EQ(holder.value().length(), 1 << len); + } +} + +// Expression grows exponentially. +BENCHMARK(BM_StrCatSharedArena)->DenseRange(0, 8, 2); + +// Series of simple expressions that are expected to require an allocation. +static void BM_AllocateString(benchmark::State& state) { + google::protobuf::Arena arena; + std::string expr("'1' + '1'"); + auto builder = CreateCelExpressionBuilder(); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse(expr)); + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder->CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + for (auto _ : state) { + Activation activation; + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation, &arena)); + CelValue::StringHolder holder; + ASSERT_TRUE(result.GetValue(&holder)); + ASSERT_EQ(holder.value(), "11"); + } +} +BENCHMARK(BM_AllocateString); + +static void BM_AllocateError(benchmark::State& state) { + google::protobuf::Arena arena; + std::string expr("1 / 0"); + auto builder = CreateCelExpressionBuilder(); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse(expr)); + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder->CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + for (auto _ : state) { + Activation activation; + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation, &arena)); + const CelError* value; + ASSERT_TRUE(result.GetValue(&value)); + ASSERT_THAT(*value, StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("divide by zero"))); + } +} +BENCHMARK(BM_AllocateError); + +static void BM_AllocateMap(benchmark::State& state) { + google::protobuf::Arena arena; + std::string expr("{1: 2, 3: 4}"); + auto builder = CreateCelExpressionBuilder(); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse(expr)); + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder->CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + for (auto _ : state) { + Activation activation; + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(result.IsMap()); + } +} + +BENCHMARK(BM_AllocateMap); + +static void BM_AllocateMessage(benchmark::State& state) { + google::protobuf::Arena arena; + std::string expr( + "google.api.expr.runtime.RequestContext{" + "ip: '192.168.0.1'," + "path: '/root'}"); + // Make sure RequestContext is loaded in the generated descriptor pool. + RequestContext context; + static_cast(context); + auto builder = CreateCelExpressionBuilder(); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse(expr)); + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder->CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + for (auto _ : state) { + Activation activation; + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(result.IsMessage()); + } +} + +BENCHMARK(BM_AllocateMessage); + +static void BM_AllocateLargeMessage(benchmark::State& state) { + // Make sure attribute context is loaded in the generated descriptor pool. + rpc::context::AttributeContext context; + static_cast(context); + + google::protobuf::Arena arena; + std::string expr(R"( + google.rpc.context.AttributeContext{ + source: google.rpc.context.AttributeContext.Peer{ + ip: '192.168.0.1', + port: 1025, + labels: {"abc": "123", "def": "456"} + }, + request: google.rpc.context.AttributeContext.Request{ + method: 'GET', + path: 'root', + host: 'www.example.com' + }, + resource: google.rpc.context.AttributeContext.Resource{ + labels: {"abc": "123", "def": "456"}, + } + })"); + auto builder = CreateCelExpressionBuilder(); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse(expr)); + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder->CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + for (auto _ : state) { + Activation activation; + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(result.IsMessage()); + } +} + +BENCHMARK(BM_AllocateLargeMessage); + +static void BM_AllocateList(benchmark::State& state) { + google::protobuf::Arena arena; + std::string expr("[1, 2, 3, 4]"); + auto builder = CreateCelExpressionBuilder(); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse(expr)); + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder->CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + for (auto _ : state) { + Activation activation; + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(result.IsList()); + } +} +BENCHMARK(BM_AllocateList); + +} // namespace +} // namespace google::api::expr::runtime diff --git a/eval/tests/benchmark_test.cc b/eval/tests/benchmark_test.cc index a7a35f5fc..f188dc0b7 100644 --- a/eval/tests/benchmark_test.cc +++ b/eval/tests/benchmark_test.cc @@ -1,20 +1,36 @@ -#include "benchmark/benchmark.h" +#include "internal/benchmark.h" -#include "google/api/expr/v1alpha1/syntax.pb.h" -#include "google/protobuf/text_format.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" +#include +#include +#include + +#include "cel/expr/syntax.pb.h" +#include "google/protobuf/struct.pb.h" +#include "google/rpc/context/attribute_context.pb.h" #include "absl/base/attributes.h" +#include "absl/container/btree_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/container/node_hash_set.h" +#include "absl/flags/flag.h" #include "absl/strings/match.h" #include "eval/public/activation.h" #include "eval/public/builtin_func_registrar.h" #include "eval/public/cel_expr_builder_factory.h" #include "eval/public/cel_expression.h" +#include "eval/public/cel_options.h" #include "eval/public/cel_value.h" #include "eval/public/containers/container_backed_list_impl.h" +#include "eval/public/containers/container_backed_map_impl.h" #include "eval/public/structs/cel_proto_wrapper.h" #include "eval/tests/request_context.pb.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "parser/parser.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/text_format.h" + +ABSL_FLAG(bool, enable_optimizations, false, "enable const folding opt"); +ABSL_FLAG(bool, enable_recursive_planning, false, "enable recursive planning"); namespace google { namespace api { @@ -23,16 +39,35 @@ namespace runtime { namespace { -using google::api::expr::v1alpha1::Expr; -using google::api::expr::v1alpha1::SourceInfo; +using ::cel::expr::Expr; +using ::cel::expr::ParsedExpr; +using ::cel::expr::SourceInfo; +using ::google::rpc::context::AttributeContext; + +InterpreterOptions GetOptions(google::protobuf::Arena& arena) { + InterpreterOptions options; + + if (absl::GetFlag(FLAGS_enable_optimizations)) { + options.constant_arena = &arena; + options.constant_folding = true; + } + + if (absl::GetFlag(FLAGS_enable_recursive_planning)) { + options.max_recursion_depth = -1; + } + + return options; +} // Benchmark test // Evaluates cel expression: // '1 + 1 + 1 .... +1' static void BM_Eval(benchmark::State& state) { - auto builder = CreateCelExpressionBuilder(); - auto reg_status = RegisterBuiltinFunctions(builder->GetRegistry()); - CHECK_OK(reg_status); + google::protobuf::Arena arena; + InterpreterOptions options = GetOptions(arena); + + auto builder = CreateCelExpressionBuilder(options); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); int len = state.range(0); @@ -49,32 +84,78 @@ static void BM_Eval(benchmark::State& state) { cur_expr->mutable_const_expr()->set_int64_value(1); SourceInfo source_info; - auto cel_expr_status = builder->CreateExpression(&root_expr, &source_info); - CHECK_OK(cel_expr_status.status()); - - std::unique_ptr cel_expr = std::move(cel_expr_status.value()); + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder->CreateExpression(&root_expr, &source_info)); for (auto _ : state) { google::protobuf::Arena arena; Activation activation; - auto eval_result = cel_expr->Evaluate(activation, &arena); - CHECK_OK(eval_result.status()); + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(result.IsInt64()); + ASSERT_TRUE(result.Int64OrDie() == len + 1); + } +} + +BENCHMARK(BM_Eval)->Range(1, 10000); + +absl::Status EmptyCallback(int64_t expr_id, const CelValue& value, + google::protobuf::Arena* arena) { + return absl::OkStatus(); +} + +// Benchmark test +// Traces cel expression with an empty callback: +// '1 + 1 + 1 .... +1' +static void BM_Eval_Trace(benchmark::State& state) { + google::protobuf::Arena arena; + InterpreterOptions options = GetOptions(arena); + options.enable_recursive_tracing = true; + + auto builder = CreateCelExpressionBuilder(options); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); - CelValue result = eval_result.value(); - GOOGLE_CHECK(result.IsInt64()); - GOOGLE_CHECK(result.Int64OrDie() == len + 1); + int len = state.range(0); + + Expr root_expr; + Expr* cur_expr = &root_expr; + + for (int i = 0; i < len; i++) { + Expr::Call* call = cur_expr->mutable_call_expr(); + call->set_function("_+_"); + call->add_args()->mutable_const_expr()->set_int64_value(1); + cur_expr = call->add_args(); + } + + cur_expr->mutable_const_expr()->set_int64_value(1); + + SourceInfo source_info; + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder->CreateExpression(&root_expr, &source_info)); + + for (auto _ : state) { + google::protobuf::Arena arena; + Activation activation; + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Trace(activation, &arena, EmptyCallback)); + ASSERT_TRUE(result.IsInt64()); + ASSERT_TRUE(result.Int64OrDie() == len + 1); } } -BENCHMARK(BM_Eval)->Range(1, 32768); +// A number higher than 10k leads to a stack overflow due to the recursive +// nature of the proto to native type conversion. +BENCHMARK(BM_Eval_Trace)->Range(1, 10000); // Benchmark test // Evaluates cel expression: // '"a" + "a" + "a" .... + "a"' static void BM_EvalString(benchmark::State& state) { - auto builder = CreateCelExpressionBuilder(); - auto reg_status = RegisterBuiltinFunctions(builder->GetRegistry()); - CHECK_OK(reg_status); + google::protobuf::Arena arena; + InterpreterOptions options = GetOptions(arena); + + auto builder = CreateCelExpressionBuilder(options); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); int len = state.range(0); @@ -91,504 +172,74 @@ static void BM_EvalString(benchmark::State& state) { cur_expr->mutable_const_expr()->set_string_value("a"); SourceInfo source_info; - auto cel_expr_status = builder->CreateExpression(&root_expr, &source_info); - CHECK_OK(cel_expr_status.status()); - - std::unique_ptr cel_expr = std::move(cel_expr_status.value()); + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder->CreateExpression(&root_expr, &source_info)); for (auto _ : state) { google::protobuf::Arena arena; Activation activation; - auto eval_result = cel_expr->Evaluate(activation, &arena); - CHECK_OK(eval_result.status()); - - CelValue result = eval_result.value(); - GOOGLE_CHECK(result.IsString()); - GOOGLE_CHECK(result.StringOrDie().value().size() == len + 1); + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(result.IsString()); + ASSERT_TRUE(result.StringOrDie().value().size() == len + 1); } } -BENCHMARK(BM_EvalString)->Range(1, 32768); +// A number higher than 10k leads to a stack overflow due to the recursive +// nature of the proto to native type conversion. +BENCHMARK(BM_EvalString)->Range(1, 10000); -std::string CELAstFlattenedMap() { - return R"( -call_expr: < - function: "_&&_" - args: < - call_expr: < - function: "!_" - args: < - call_expr: < - function: "@in" - args: < - ident_expr: < - name: "ip" - > - > - args: < - list_expr: < - elements: < - const_expr: < - string_value: "10.0.1.4" - > - > - elements: < - const_expr: < - string_value: "10.0.1.5" - > - > - elements: < - const_expr: < - string_value: "10.0.1.6" - > - > - > - > - > - > - > - > - args: < - call_expr: < - function: "_||_" - args: < - call_expr: < - function: "_||_" - args: < - call_expr: < - function: "_&&_" - args: < - call_expr: < - target: < - ident_expr: < - name: "path" - > - > - function: "startsWith" - args: < - const_expr: < - string_value: "v1" - > - > - > - > - args: < - call_expr: < - function: "@in" - args: < - ident_expr: < - name: "token" - > - > - args: < - list_expr: < - elements: < - const_expr: < - string_value: "v1" - > - > - elements: < - const_expr: < - string_value: "v2" - > - > - elements: < - const_expr: < - string_value: "admin" - > - > - > - > - > - > - > - > - args: < - call_expr: < - function: "_&&_" - args: < - call_expr: < - target: < - ident_expr: < - name: "path" - > - > - function: "startsWith" - args: < - const_expr: < - string_value: "v2" - > - > - > - > - args: < - call_expr: < - function: "@in" - args: < - ident_expr: < - name: "token" - > - > - args: < - list_expr: < - elements: < - const_expr: < - string_value: "v2" - > - > - elements: < - const_expr: < - string_value: "admin" - > - > - > - > - > - > - > - > - > - > - args: < - call_expr: < - function: "_&&_" - args: < - call_expr: < - function: "_&&_" - args: < - call_expr: < - target: < - ident_expr: < - name: "path" - > - > - function: "startsWith" - args: < - const_expr: < - string_value: "/admin" - > - > - > - > - args: < - call_expr: < - function: "_==_" - args: < - ident_expr: < - name: "token" - > - > - args: < - const_expr: < - string_value: "admin" - > - > - > - > - > - > - args: < - call_expr: < - function: "@in" - args: < - ident_expr: < - name: "ip" - > - > - args: < - list_expr: < - elements: < - const_expr: < - string_value: "10.0.1.1" - > - > - elements: < - const_expr: < - string_value: "10.0.1.2" - > - > - elements: < - const_expr: < - string_value: "10.0.1.3" - > - > - > - > - > - > - > - > - > - > -> -)"; -} +// Benchmark test +// Traces cel expression with an empty callback: +// '"a" + "a" + "a" .... + "a"' +static void BM_EvalString_Trace(benchmark::State& state) { + google::protobuf::Arena arena; + InterpreterOptions options = GetOptions(arena); + options.enable_recursive_tracing = true; -// This proto is obtained from CELAstFlattenedMap by replacing "ip", "token", -// and "path" idents with selector expressions for "request.ip", -// "request.token", and "request.path". -std::string CELAst() { - return R"( -call_expr: < - function: "_&&_" - args: < - call_expr: < - function: "!_" - args: < - call_expr: < - function: "@in" - args: < - select_expr: < - operand: < - ident_expr: < - name: "request" - > - > - field: "ip" - > - > - args: < - list_expr: < - elements: < - const_expr: < - string_value: "10.0.1.4" - > - > - elements: < - const_expr: < - string_value: "10.0.1.5" - > - > - elements: < - const_expr: < - string_value: "10.0.1.6" - > - > - > - > - > - > - > - > - args: < - call_expr: < - function: "_||_" - args: < - call_expr: < - function: "_||_" - args: < - call_expr: < - function: "_&&_" - args: < - call_expr: < - target: < - select_expr: < - operand: < - ident_expr: < - name: "request" - > - > - field: "path" - > - > - function: "startsWith" - args: < - const_expr: < - string_value: "v1" - > - > - > - > - args: < - call_expr: < - function: "@in" - args: < - select_expr: < - operand: < - ident_expr: < - name: "request" - > - > - field: "token" - > - > - args: < - list_expr: < - elements: < - const_expr: < - string_value: "v1" - > - > - elements: < - const_expr: < - string_value: "v2" - > - > - elements: < - const_expr: < - string_value: "admin" - > - > - > - > - > - > - > - > - args: < - call_expr: < - function: "_&&_" - args: < - call_expr: < - target: < - select_expr: < - operand: < - ident_expr: < - name: "request" - > - > - field: "path" - > - > - function: "startsWith" - args: < - const_expr: < - string_value: "v2" - > - > - > - > - args: < - call_expr: < - function: "@in" - args: < - select_expr: < - operand: < - ident_expr: < - name: "request" - > - > - field: "token" - > - > - args: < - list_expr: < - elements: < - const_expr: < - string_value: "v2" - > - > - elements: < - const_expr: < - string_value: "admin" - > - > - > - > - > - > - > - > - > - > - args: < - call_expr: < - function: "_&&_" - args: < - call_expr: < - function: "_&&_" - args: < - call_expr: < - target: < - select_expr: < - operand: < - ident_expr: < - name: "request" - > - > - field: "path" - > - > - function: "startsWith" - args: < - const_expr: < - string_value: "/admin" - > - > - > - > - args: < - call_expr: < - function: "_==_" - args: < - select_expr: < - operand: < - ident_expr: < - name: "request" - > - > - field: "token" - > - > - args: < - const_expr: < - string_value: "admin" - > - > - > - > - > - > - args: < - call_expr: < - function: "@in" - args: < - select_expr: < - operand: < - ident_expr: < - name: "request" - > - > - field: "ip" - > - > - args: < - list_expr: < - elements: < - const_expr: < - string_value: "10.0.1.1" - > - > - elements: < - const_expr: < - string_value: "10.0.1.2" - > - > - elements: < - const_expr: < - string_value: "10.0.1.3" - > - > - > - > - > - > - > - > - > - > -> -)"; + auto builder = CreateCelExpressionBuilder(options); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); + + int len = state.range(0); + + Expr root_expr; + Expr* cur_expr = &root_expr; + + for (int i = 0; i < len; i++) { + Expr::Call* call = cur_expr->mutable_call_expr(); + call->set_function("_+_"); + call->add_args()->mutable_const_expr()->set_string_value("a"); + cur_expr = call->add_args(); + } + + cur_expr->mutable_const_expr()->set_string_value("a"); + + SourceInfo source_info; + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder->CreateExpression(&root_expr, &source_info)); + + for (auto _ : state) { + google::protobuf::Arena arena; + Activation activation; + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Trace(activation, &arena, EmptyCallback)); + ASSERT_TRUE(result.IsString()); + ASSERT_TRUE(result.StringOrDie().value().size() == len + 1); + } } +// A number higher than 10k leads to a stack overflow due to the recursive +// nature of the proto to native type conversion. +BENCHMARK(BM_EvalString_Trace)->Range(1, 10000); + const char kIP[] = "10.0.1.2"; const char kPath[] = "/admin/edit"; const char kToken[] = "admin"; ABSL_ATTRIBUTE_NOINLINE -bool NativeCheck(std::map& attributes, - const std::unordered_set& denylists, - const absl::node_hash_set& allowlists) { +bool NativeCheck(absl::btree_map& attributes, + const absl::flat_hash_set& denylists, + const absl::flat_hash_set& allowlists) { auto& ip = attributes["ip"]; auto& path = attributes["path"]; auto& token = attributes["token"]; @@ -615,56 +266,50 @@ bool NativeCheck(std::map& attributes, void BM_PolicyNative(benchmark::State& state) { const auto denylists = - std::unordered_set{"10.0.1.4", "10.0.1.5", "10.0.1.6"}; + absl::flat_hash_set{"10.0.1.4", "10.0.1.5", "10.0.1.6"}; const auto allowlists = - absl::node_hash_set{"10.0.1.1", "10.0.1.2", "10.0.1.3"}; - auto attributes = std::map{ + absl::flat_hash_set{"10.0.1.1", "10.0.1.2", "10.0.1.3"}; + auto attributes = absl::btree_map{ {"ip", kIP}, {"token", kToken}, {"path", kPath}}; for (auto _ : state) { auto result = NativeCheck(attributes, denylists, allowlists); - GOOGLE_CHECK(result); + ASSERT_TRUE(result); } } BENCHMARK(BM_PolicyNative); -/* - Evaluates an expression: - - !(ip in ["10.0.1.4", "10.0.1.5", "10.0.1.6"]) && - ( - (path.startsWith("v1") && token in ["v1", "v2", "admin"]) || - (path.startsWith("v2") && token in ["v2", "admin"]) || - (path.startsWith("/admin") && token == "admin" && ip in ["10.0.1.1", - "10.0.1.2", "10.0.1.3"]) - ) -*/ void BM_PolicySymbolic(benchmark::State& state) { google::protobuf::Arena arena; - Expr expr; - google::protobuf::TextFormat::ParseFromString(CELAstFlattenedMap(), &expr); + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse(R"cel( + !(ip in ["10.0.1.4", "10.0.1.5", "10.0.1.6"]) && + ((path.startsWith("v1") && token in ["v1", "v2", "admin"]) || + (path.startsWith("v2") && token in ["v2", "admin"]) || + (path.startsWith("/admin") && token == "admin" && ip in [ + "10.0.1.1", "10.0.1.2", "10.0.1.3" + ]) + ))cel")); - InterpreterOptions options; + InterpreterOptions options = GetOptions(arena); options.constant_folding = true; options.constant_arena = &arena; auto builder = CreateCelExpressionBuilder(options); - CHECK_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); SourceInfo source_info; - auto cel_expression_status = builder->CreateExpression(&expr, &source_info); - CHECK_OK(cel_expression_status.status()); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression( + &parsed_expr.expr(), &source_info)); - auto cel_expression = std::move(cel_expression_status.value()); Activation activation; activation.InsertValue("ip", CelValue::CreateStringView(kIP)); activation.InsertValue("path", CelValue::CreateStringView(kPath)); activation.InsertValue("token", CelValue::CreateStringView(kToken)); for (auto _ : state) { - auto eval_result = cel_expression->Evaluate(activation, &arena); - CHECK_OK(eval_result.status()); - GOOGLE_CHECK(eval_result.value().BoolOrDie()); + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(result.BoolOrDie()); } } @@ -672,7 +317,7 @@ BENCHMARK(BM_PolicySymbolic); class RequestMap : public CelMap { public: - absl::optional operator[](CelValue key) const override { + std::optional operator[](CelValue key) const override { if (!key.IsString()) { return {}; } @@ -687,31 +332,39 @@ class RequestMap : public CelMap { return {}; } int size() const override { return 3; } - const CelList* ListKeys() const override { return nullptr; } + absl::StatusOr ListKeys() const override { + return absl::UnimplementedError("CelMap::ListKeys is not implemented"); + } }; // Uses a lazily constructed map container for "ip", "path", and "token". void BM_PolicySymbolicMap(benchmark::State& state) { google::protobuf::Arena arena; - Expr expr; - google::protobuf::TextFormat::ParseFromString(CELAst(), &expr); + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse(R"cel( + !(request.ip in ["10.0.1.4", "10.0.1.5", "10.0.1.6"]) && + ((request.path.startsWith("v1") && request.token in ["v1", "v2", "admin"]) || + (request.path.startsWith("v2") && request.token in ["v2", "admin"]) || + (request.path.startsWith("/admin") && request.token == "admin" && + request.ip in ["10.0.1.1", "10.0.1.2", "10.0.1.3"]) + ))cel")); - auto builder = CreateCelExpressionBuilder(); - CHECK_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + InterpreterOptions options = GetOptions(arena); + + auto builder = CreateCelExpressionBuilder(options); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); SourceInfo source_info; - auto cel_expression_status = builder->CreateExpression(&expr, &source_info); - CHECK_OK(cel_expression_status.status()); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression( + &parsed_expr.expr(), &source_info)); - auto cel_expression = std::move(cel_expression_status.value()); Activation activation; RequestMap request; activation.InsertValue("request", CelValue::CreateMap(&request)); for (auto _ : state) { - auto eval_result = cel_expression->Evaluate(activation, &arena); - CHECK_OK(eval_result.status()); - GOOGLE_CHECK(eval_result.value().BoolOrDie()); + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(result.BoolOrDie()); } } @@ -720,17 +373,23 @@ BENCHMARK(BM_PolicySymbolicMap); // Uses a protobuf container for "ip", "path", and "token". void BM_PolicySymbolicProto(benchmark::State& state) { google::protobuf::Arena arena; - Expr expr; - google::protobuf::TextFormat::ParseFromString(CELAst(), &expr); + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse(R"cel( + !(request.ip in ["10.0.1.4", "10.0.1.5", "10.0.1.6"]) && + ((request.path.startsWith("v1") && request.token in ["v1", "v2", "admin"]) || + (request.path.startsWith("v2") && request.token in ["v2", "admin"]) || + (request.path.startsWith("/admin") && request.token == "admin" && + request.ip in ["10.0.1.1", "10.0.1.2", "10.0.1.3"]) + ))cel")); - auto builder = CreateCelExpressionBuilder(); - CHECK_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + InterpreterOptions options = GetOptions(arena); + + auto builder = CreateCelExpressionBuilder(options); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); SourceInfo source_info; - auto cel_expression_status = builder->CreateExpression(&expr, &source_info); - CHECK_OK(cel_expression_status.status()); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression( + &parsed_expr.expr(), &source_info)); - auto cel_expression = std::move(cel_expression_status.value()); Activation activation; RequestContext request; request.set_ip(kIP); @@ -738,16 +397,16 @@ void BM_PolicySymbolicProto(benchmark::State& state) { request.set_token(kToken); activation.InsertValue("request", CelProtoWrapper::CreateMessage(&request, &arena)); - for (auto _ : state) { - auto eval_result = cel_expression->Evaluate(activation, &arena); - CHECK_OK(eval_result.status()); - GOOGLE_CHECK(eval_result.value().BoolOrDie()); + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(result.BoolOrDie()); } } BENCHMARK(BM_PolicySymbolicProto); +// This expression has no equivalent CEL constexpr char kListSum[] = R"( id: 1 comprehension_expr: < @@ -756,7 +415,7 @@ comprehension_expr: < iter_range: < id: 2 ident_expr: < - name: "list" + name: "list_var" > > accu_init: < @@ -801,7 +460,7 @@ void BM_Comprehension(benchmark::State& state) { google::protobuf::Arena arena; Expr expr; Activation activation; - GOOGLE_CHECK(google::protobuf::TextFormat::ParseFromString(kListSum, &expr)); + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kListSum, &expr)); int len = state.range(0); std::vector list; @@ -811,21 +470,292 @@ void BM_Comprehension(benchmark::State& state) { } ContainerBackedListImpl cel_list(std::move(list)); - activation.InsertValue("list", CelValue::CreateList(&cel_list)); - auto builder = CreateCelExpressionBuilder(); - CHECK_OK(RegisterBuiltinFunctions(builder->GetRegistry())); - auto expr_plan = builder->CreateExpression(&expr, nullptr); - CHECK_OK(expr_plan.status()); + activation.InsertValue("list_var", CelValue::CreateList(&cel_list)); + + InterpreterOptions options = GetOptions(arena); + options.comprehension_max_iterations = 10000000; + auto builder = CreateCelExpressionBuilder(options); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); + + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder->CreateExpression(&expr, nullptr)); for (auto _ : state) { - auto result = expr_plan.value()->Evaluate(activation, &arena); - CHECK_OK(result.status()); - GOOGLE_CHECK(result->IsInt64()); - CHECK_EQ(result->Int64OrDie(), len); + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(result.IsInt64()); + ASSERT_EQ(result.Int64OrDie(), len); } } BENCHMARK(BM_Comprehension)->Range(1, 1 << 20); +void BM_Comprehension_Trace(benchmark::State& state) { + google::protobuf::Arena arena; + Expr expr; + Activation activation; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kListSum, &expr)); + + int len = state.range(0); + std::vector list; + list.reserve(len); + for (int i = 0; i < len; i++) { + list.push_back(CelValue::CreateInt64(1)); + } + + ContainerBackedListImpl cel_list(std::move(list)); + activation.InsertValue("list_var", CelValue::CreateList(&cel_list)); + InterpreterOptions options = GetOptions(arena); + options.enable_recursive_tracing = true; + + options.comprehension_max_iterations = 10000000; + auto builder = CreateCelExpressionBuilder(options); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); + + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder->CreateExpression(&expr, nullptr)); + for (auto _ : state) { + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Trace(activation, &arena, EmptyCallback)); + ASSERT_TRUE(result.IsInt64()); + ASSERT_EQ(result.Int64OrDie(), len); + } +} + +BENCHMARK(BM_Comprehension_Trace)->Range(1, 1 << 20); + +void BM_HasMap(benchmark::State& state) { + google::protobuf::Arena arena; + Activation activation; + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, + parser::Parse("has(request.path) && !has(request.ip)")); + + InterpreterOptions options = GetOptions(arena); + auto builder = CreateCelExpressionBuilder(options); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); + + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder->CreateExpression(&parsed_expr.expr(), nullptr)); + + std::vector> map_pairs{ + {CelValue::CreateStringView("path"), CelValue::CreateStringView("path")}}; + auto cel_map = + CreateContainerBackedMap(absl::Span>( + map_pairs.data(), map_pairs.size())); + activation.InsertValue("request", CelValue::CreateMap((*cel_map).get())); + + for (auto _ : state) { + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(result.IsBool()); + ASSERT_TRUE(result.BoolOrDie()); + } +} + +BENCHMARK(BM_HasMap); + +void BM_HasProto(benchmark::State& state) { + google::protobuf::Arena arena; + Activation activation; + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, + parser::Parse("has(request.path) && !has(request.ip)")); + InterpreterOptions options = GetOptions(arena); + auto builder = CreateCelExpressionBuilder(options); + auto reg_status = RegisterBuiltinFunctions(builder->GetRegistry(), options); + + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder->CreateExpression(&parsed_expr.expr(), nullptr)); + + RequestContext request; + request.set_path(kPath); + request.set_token(kToken); + activation.InsertValue("request", + CelProtoWrapper::CreateMessage(&request, &arena)); + + for (auto _ : state) { + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(result.IsBool()); + ASSERT_TRUE(result.BoolOrDie()); + } +} + +BENCHMARK(BM_HasProto); + +void BM_HasProtoMap(benchmark::State& state) { + google::protobuf::Arena arena; + Activation activation; + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, + parser::Parse("has(request.headers.create_time) && " + "!has(request.headers.update_time)")); + InterpreterOptions options = GetOptions(arena); + auto builder = CreateCelExpressionBuilder(options); + auto reg_status = RegisterBuiltinFunctions(builder->GetRegistry(), options); + + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder->CreateExpression(&parsed_expr.expr(), nullptr)); + + RequestContext request; + request.mutable_headers()->insert({"create_time", "2021-01-01"}); + activation.InsertValue("request", + CelProtoWrapper::CreateMessage(&request, &arena)); + + for (auto _ : state) { + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(result.IsBool()); + ASSERT_TRUE(result.BoolOrDie()); + } +} + +BENCHMARK(BM_HasProtoMap); + +void BM_ReadProtoMap(benchmark::State& state) { + google::protobuf::Arena arena; + Activation activation; + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse(R"cel( + request.headers.create_time == "2021-01-01" + )cel")); + InterpreterOptions options = GetOptions(arena); + auto builder = CreateCelExpressionBuilder(options); + auto reg_status = RegisterBuiltinFunctions(builder->GetRegistry(), options); + + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder->CreateExpression(&parsed_expr.expr(), nullptr)); + + RequestContext request; + request.mutable_headers()->insert({"create_time", "2021-01-01"}); + activation.InsertValue("request", + CelProtoWrapper::CreateMessage(&request, &arena)); + + for (auto _ : state) { + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(result.IsBool()); + ASSERT_TRUE(result.BoolOrDie()); + } +} + +BENCHMARK(BM_ReadProtoMap); + +void BM_NestedProtoFieldRead(benchmark::State& state) { + google::protobuf::Arena arena; + Activation activation; + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse(R"cel( + !request.a.b.c.d.e + )cel")); + InterpreterOptions options = GetOptions(arena); + auto builder = CreateCelExpressionBuilder(options); + auto reg_status = RegisterBuiltinFunctions(builder->GetRegistry(), options); + + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder->CreateExpression(&parsed_expr.expr(), nullptr)); + + RequestContext request; + request.mutable_a()->mutable_b()->mutable_c()->mutable_d()->set_e(false); + activation.InsertValue("request", + CelProtoWrapper::CreateMessage(&request, &arena)); + + for (auto _ : state) { + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(result.IsBool()); + ASSERT_TRUE(result.BoolOrDie()); + } +} + +BENCHMARK(BM_NestedProtoFieldRead); + +void BM_NestedProtoFieldReadDefaults(benchmark::State& state) { + google::protobuf::Arena arena; + Activation activation; + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse(R"cel( + !request.a.b.c.d.e + )cel")); + InterpreterOptions options = GetOptions(arena); + auto builder = CreateCelExpressionBuilder(options); + auto reg_status = RegisterBuiltinFunctions(builder->GetRegistry(), options); + + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder->CreateExpression(&parsed_expr.expr(), nullptr)); + + RequestContext request; + activation.InsertValue("request", + CelProtoWrapper::CreateMessage(&request, &arena)); + + for (auto _ : state) { + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(result.IsBool()); + ASSERT_TRUE(result.BoolOrDie()); + } +} + +BENCHMARK(BM_NestedProtoFieldReadDefaults); + +void BM_ProtoStructAccess(benchmark::State& state) { + google::protobuf::Arena arena; + Activation activation; + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse(R"cel( + has(request.auth.claims.iss) && request.auth.claims.iss == 'accounts.google.com' + )cel")); + InterpreterOptions options = GetOptions(arena); + auto builder = CreateCelExpressionBuilder(options); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); + + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder->CreateExpression(&parsed_expr.expr(), nullptr)); + + AttributeContext::Request request; + auto* auth = request.mutable_auth(); + (*auth->mutable_claims()->mutable_fields())["iss"].set_string_value( + "accounts.google.com"); + activation.InsertValue("request", + CelProtoWrapper::CreateMessage(&request, &arena)); + + for (auto _ : state) { + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(result.IsBool()); + ASSERT_TRUE(result.BoolOrDie()); + } +} + +BENCHMARK(BM_ProtoStructAccess); + +void BM_ProtoListAccess(benchmark::State& state) { + google::protobuf::Arena arena; + Activation activation; + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse(R"cel( + "//.../accessLevels/MY_LEVEL_4" in request.auth.access_levels + )cel")); + InterpreterOptions options = GetOptions(arena); + auto builder = CreateCelExpressionBuilder(options); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); + + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder->CreateExpression(&parsed_expr.expr(), nullptr)); + + AttributeContext::Request request; + auto* auth = request.mutable_auth(); + auth->add_access_levels("//.../accessLevels/MY_LEVEL_0"); + auth->add_access_levels("//.../accessLevels/MY_LEVEL_1"); + auth->add_access_levels("//.../accessLevels/MY_LEVEL_2"); + auth->add_access_levels("//.../accessLevels/MY_LEVEL_3"); + auth->add_access_levels("//.../accessLevels/MY_LEVEL_4"); + activation.InsertValue("request", + CelProtoWrapper::CreateMessage(&request, &arena)); + + for (auto _ : state) { + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(result.IsBool()); + ASSERT_TRUE(result.BoolOrDie()); + } +} + +BENCHMARK(BM_ProtoListAccess); + +// This expression has no equivalent CEL expression. // Sum a square with a nested comprehension constexpr char kNestedListSum[] = R"( id: 1 @@ -835,7 +765,7 @@ comprehension_expr: < iter_range: < id: 2 ident_expr: < - name: "list" + name: "list_var" > > accu_init: < @@ -862,7 +792,7 @@ comprehension_expr: < iter_range: < id: 9 ident_expr: < - name: "list" + name: "list_var" > > accu_init: < @@ -923,7 +853,7 @@ void BM_NestedComprehension(benchmark::State& state) { google::protobuf::Arena arena; Expr expr; Activation activation; - GOOGLE_CHECK(google::protobuf::TextFormat::ParseFromString(kNestedListSum, &expr)); + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kNestedListSum, &expr)); int len = state.range(0); std::vector list; @@ -933,25 +863,165 @@ void BM_NestedComprehension(benchmark::State& state) { } ContainerBackedListImpl cel_list(std::move(list)); - activation.InsertValue("list", CelValue::CreateList(&cel_list)); - auto builder = CreateCelExpressionBuilder(); - CHECK_OK(RegisterBuiltinFunctions(builder->GetRegistry())); - auto expr_plan = builder->CreateExpression(&expr, nullptr); - CHECK_OK(expr_plan.status()); + activation.InsertValue("list_var", CelValue::CreateList(&cel_list)); + InterpreterOptions options = GetOptions(arena); + options.comprehension_max_iterations = 10000000; + auto builder = CreateCelExpressionBuilder(options); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); + + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder->CreateExpression(&expr, nullptr)); + for (auto _ : state) { - auto result = expr_plan.value()->Evaluate(activation, &arena); - CHECK_OK(result.status()); - GOOGLE_CHECK(result->IsInt64()); - CHECK_EQ(result->Int64OrDie(), len * len); + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(result.IsInt64()); + ASSERT_EQ(result.Int64OrDie(), len * len); } } BENCHMARK(BM_NestedComprehension)->Range(1, 1 << 10); -void BM_ComprehensionCpp(benchmark::State& state) { +void BM_NestedComprehension_Trace(benchmark::State& state) { google::protobuf::Arena arena; Expr expr; Activation activation; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kNestedListSum, &expr)); + + int len = state.range(0); + std::vector list; + list.reserve(len); + for (int i = 0; i < len; i++) { + list.push_back(CelValue::CreateInt64(1)); + } + + ContainerBackedListImpl cel_list(std::move(list)); + activation.InsertValue("list_var", CelValue::CreateList(&cel_list)); + InterpreterOptions options = GetOptions(arena); + options.comprehension_max_iterations = 10000000; + options.enable_comprehension_list_append = true; + options.enable_recursive_tracing = true; + + auto builder = CreateCelExpressionBuilder(options); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder->CreateExpression(&expr, nullptr)); + + for (auto _ : state) { + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Trace(activation, &arena, EmptyCallback)); + ASSERT_TRUE(result.IsInt64()); + ASSERT_EQ(result.Int64OrDie(), len * len); + } +} + +BENCHMARK(BM_NestedComprehension_Trace)->Range(1, 1 << 10); + +void BM_ListComprehension(benchmark::State& state) { + google::protobuf::Arena arena; + Activation activation; + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, + parser::Parse("list_var.map(x, x * 2)")); + + int len = state.range(0); + std::vector list; + list.reserve(len); + for (int i = 0; i < len; i++) { + list.push_back(CelValue::CreateInt64(1)); + } + + ContainerBackedListImpl cel_list(std::move(list)); + activation.InsertValue("list_var", CelValue::CreateList(&cel_list)); + InterpreterOptions options = GetOptions(arena); + options.comprehension_max_iterations = 10000000; + options.enable_comprehension_list_append = true; + auto builder = CreateCelExpressionBuilder(options); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); + ASSERT_OK_AND_ASSIGN( + auto cel_expr, builder->CreateExpression(&(parsed_expr.expr()), nullptr)); + + for (auto _ : state) { + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(result.IsList()); + ASSERT_EQ(result.ListOrDie()->size(), len); + } +} + +BENCHMARK(BM_ListComprehension)->Range(1, 1 << 16); + +void BM_ListComprehension_Trace(benchmark::State& state) { + google::protobuf::Arena arena; + Activation activation; + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, + parser::Parse("list_var.map(x, x * 2)")); + + int len = state.range(0); + std::vector list; + list.reserve(len); + for (int i = 0; i < len; i++) { + list.push_back(CelValue::CreateInt64(1)); + } + + ContainerBackedListImpl cel_list(std::move(list)); + activation.InsertValue("list_var", CelValue::CreateList(&cel_list)); + InterpreterOptions options = GetOptions(arena); + options.comprehension_max_iterations = 10000000; + options.enable_comprehension_list_append = true; + options.enable_recursive_tracing = true; + + auto builder = CreateCelExpressionBuilder(options); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); + ASSERT_OK_AND_ASSIGN( + auto cel_expr, builder->CreateExpression(&(parsed_expr.expr()), nullptr)); + + for (auto _ : state) { + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Trace(activation, &arena, EmptyCallback)); + ASSERT_TRUE(result.IsList()); + ASSERT_EQ(result.ListOrDie()->size(), len); + } +} + +BENCHMARK(BM_ListComprehension_Trace)->Range(1, 1 << 16); + +void BM_ListComprehension_Opt(benchmark::State& state) { + google::protobuf::Arena arena; + Activation activation; + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, + parser::Parse("list_var.map(x, x * 2)")); + + int len = state.range(0); + std::vector list; + list.reserve(len); + for (int i = 0; i < len; i++) { + list.push_back(CelValue::CreateInt64(1)); + } + + ContainerBackedListImpl cel_list(std::move(list)); + activation.InsertValue("list_var", CelValue::CreateList(&cel_list)); + InterpreterOptions options; + options.constant_arena = &arena; + options.constant_folding = true; + options.comprehension_max_iterations = 10000000; + options.enable_comprehension_list_append = true; + auto builder = CreateCelExpressionBuilder(options); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + ASSERT_OK_AND_ASSIGN( + auto cel_expr, builder->CreateExpression(&(parsed_expr.expr()), nullptr)); + + for (auto _ : state) { + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(result.IsList()); + ASSERT_EQ(result.ListOrDie()->size(), len); + } +} + +BENCHMARK(BM_ListComprehension_Opt)->Range(1, 1 << 16); + +void BM_ComprehensionCpp(benchmark::State& state) { + Activation activation; int len = state.range(0); std::vector list; @@ -969,7 +1039,7 @@ void BM_ComprehensionCpp(benchmark::State& state) { }; for (auto _ : state) { int result = op(); - CHECK_EQ(result, len); + ASSERT_EQ(result, len); } } diff --git a/eval/tests/end_to_end_test.cc b/eval/tests/end_to_end_test.cc index f28907da9..dca0b36ee 100644 --- a/eval/tests/end_to_end_test.cc +++ b/eval/tests/end_to_end_test.cc @@ -1,7 +1,10 @@ -#include "google/api/expr/v1alpha1/syntax.pb.h" -#include "google/protobuf/text_format.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" +#include +#include +#include + +#include "cel/expr/syntax.pb.h" +#include "google/protobuf/struct.pb.h" +#include "absl/status/status.h" #include "eval/public/activation.h" #include "eval/public/builtin_func_registrar.h" #include "eval/public/cel_expr_builder_factory.h" @@ -9,7 +12,10 @@ #include "eval/public/cel_value.h" #include "eval/public/structs/cel_proto_wrapper.h" #include "eval/testutil/test_message.pb.h" -#include "base/status_macros.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "testutil/util.h" +#include "google/protobuf/text_format.h" namespace google { namespace api { @@ -18,11 +24,34 @@ namespace runtime { namespace { +using ::absl_testing::StatusIs; +using ::cel::expr::Expr; +using ::cel::expr::SourceInfo; using ::google::protobuf::Arena; -using google::protobuf::TextFormat; - -using google::api::expr::v1alpha1::Expr; -using google::api::expr::v1alpha1::SourceInfo; +using ::google::protobuf::TextFormat; + +// Simple one parameter function that records the message argument it receives. +class RecordArgFunction : public CelFunction { + public: + explicit RecordArgFunction(const std::string& name, + std::vector* output) + : CelFunction( + CelFunctionDescriptor{name, false, {CelValue::Type::kMessage}}), + output_(*output) {} + + absl::Status Evaluate(absl::Span args, CelValue* result, + google::protobuf::Arena* arena) const override { + if (args.size() != 1) { + return absl::Status(absl::StatusCode::kInvalidArgument, + "Bad arguments number"); + } + output_.push_back(args.at(0)); + *result = CelValue::CreateBool(true); + return absl::OkStatus(); + } + + std::vector& output_; +}; // Simple end-to-end test, which also serves as usage example. TEST(EndToEndTest, SimpleOnePlusOne) { @@ -54,12 +83,8 @@ TEST(EndToEndTest, SimpleOnePlusOne) { ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); // Create CelExpression from AST (Expr object). - auto cel_expression_status = builder->CreateExpression(&expr, &source_info); - - ASSERT_OK(cel_expression_status); - - auto cel_expression = std::move(cel_expression_status.value()); - + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder->CreateExpression(&expr, &source_info)); Activation activation; // Bind value to "var" parameter. @@ -68,19 +93,14 @@ TEST(EndToEndTest, SimpleOnePlusOne) { Arena arena; // Run evaluation. - auto eval_status = cel_expression->Evaluate(activation, &arena); - - ASSERT_OK(eval_status); - - CelValue result = eval_status.value(); - + ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsInt64()); EXPECT_EQ(result.Int64OrDie(), 2); } // Simple end-to-end test, which also serves as usage example. TEST(EndToEndTest, EmptyStringCompare) { - // AST CEL equivalent of "var.string_value == """ + // AST CEL equivalent of "var.string_value == '' && var.int64_value == 0" constexpr char kExpr0[] = R"( call_expr: < function: "_&&_" @@ -138,12 +158,8 @@ TEST(EndToEndTest, EmptyStringCompare) { ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); // Create CelExpression from AST (Expr object). - auto cel_expression_status = builder->CreateExpression(&expr, &source_info); - - ASSERT_OK(cel_expression_status); - - auto cel_expression = std::move(cel_expression_status.value()); - + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder->CreateExpression(&expr, &source_info)); Activation activation; // Bind value to "var" parameter. @@ -157,14 +173,131 @@ TEST(EndToEndTest, EmptyStringCompare) { activation.InsertValue("var", CelProtoWrapper::CreateMessage(&data, &arena)); // Run evaluation. - auto eval_status = cel_expression->Evaluate(activation, &arena); + ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(result.IsBool()); + EXPECT_TRUE(result.BoolOrDie()); +} - ASSERT_OK(eval_status); +TEST(EndToEndTest, NullLiteral) { + // AST CEL equivalent of "Value{null_value: NullValue.NULL_VALUE}" + constexpr char kExpr0[] = R"( + struct_expr: < + message_name: "Value" + entries: < + field_key: "null_value" + value: < + select_expr: < + operand: < + ident_expr: < + name: "NullValue" + > + > + field: "NULL_VALUE" + > + > + > + > + )"; - CelValue result = eval_status.value(); + Expr expr; + SourceInfo source_info; + TextFormat::ParseFromString(kExpr0, &expr); - ASSERT_TRUE(result.IsBool()); - EXPECT_TRUE(result.BoolOrDie()); + // Obtain CEL Expression builder. + std::unique_ptr builder = CreateCelExpressionBuilder(); + builder->set_container("google.protobuf"); + + // Builtin registration. + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + + // Create CelExpression from AST (Expr object). + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder->CreateExpression(&expr, &source_info)); + Activation activation; + Arena arena; + // Run evaluation. + ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(result.IsNull()); +} + +// Equivalent to 'RecordArg(test_message)' +constexpr char kNullMessageHandlingExpr[] = R"pb( + id: 1 + call_expr: < + function: "RecordArg" + args: < + ident_expr: < name: "test_message" > + id: 2 + > + > +)pb"; + +TEST(EndToEndTest, StrictNullHandling) { + InterpreterOptions options; + + Expr expr; + ASSERT_TRUE( + google::protobuf::TextFormat::ParseFromString(kNullMessageHandlingExpr, &expr)); + SourceInfo info; + + auto builder = CreateCelExpressionBuilder(options); + std::vector extension_calls; + ASSERT_OK(builder->GetRegistry()->Register( + std::make_unique("RecordArg", &extension_calls))); + + ASSERT_OK_AND_ASSIGN(auto expression, + builder->CreateExpression(&expr, &info)); + + Activation activation; + google::protobuf::Arena arena; + activation.InsertValue("test_message", CelValue::CreateNull()); + + ASSERT_OK_AND_ASSIGN(CelValue result, + expression->Evaluate(activation, &arena)); + const CelError* result_value; + ASSERT_TRUE(result.GetValue(&result_value)) << result.DebugString(); + EXPECT_THAT(*result_value, + StatusIs(absl::StatusCode::kUnknown, + testing::HasSubstr("No matching overloads"))); +} + +TEST(EndToEndTest, OutOfRangeDurationConstant) { + InterpreterOptions options; + options.enable_timestamp_duration_overflow_errors = true; + + Expr expr; + // Duration representable in absl::Duration, but out of range for CelValue + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"( + call_expr { + function: "type" + args { + const_expr { + duration_value { + seconds: 28552639587287040 + } + } + } + })", + &expr)); + SourceInfo info; + + auto builder = CreateCelExpressionBuilder(options); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); + + ASSERT_OK_AND_ASSIGN(auto expression, + builder->CreateExpression(&expr, &info)); + + Activation activation; + google::protobuf::Arena arena; + + ASSERT_OK_AND_ASSIGN(CelValue result, + expression->Evaluate(activation, &arena)); + const CelError* result_value; + ASSERT_TRUE(result.GetValue(&result_value)) << result.DebugString(); + EXPECT_THAT(*result_value, + StatusIs(absl::StatusCode::kInvalidArgument, + testing::HasSubstr("Duration is out of range"))); } } // namespace diff --git a/eval/tests/expression_builder_benchmark_test.cc b/eval/tests/expression_builder_benchmark_test.cc new file mode 100644 index 000000000..410df8902 --- /dev/null +++ b/eval/tests/expression_builder_benchmark_test.cc @@ -0,0 +1,466 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include + +#include "cel/expr/checked.pb.h" +#include "cel/expr/syntax.pb.h" +#include "absl/log/absl_check.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "common/minimal_descriptor_pool.h" +#include "eval/public/builtin_func_registrar.h" +#include "eval/public/cel_expr_builder_factory.h" +#include "eval/public/cel_expression.h" +#include "eval/public/cel_options.h" +#include "eval/public/cel_type_registry.h" +#include "eval/tests/request_context.pb.h" +#include "internal/benchmark.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "parser/parser.h" +#include "google/protobuf/arena.h" + +namespace google::api::expr::runtime { + +namespace { + +using cel::expr::CheckedExpr; +using cel::expr::ParsedExpr; +using google::api::expr::parser::Parse; + +enum BenchmarkParam : int { + kDefault = 0, + kFoldConstants = 1, + kRecursivePlanning = 2, + kRecursivePlanningWithConstantFolding = 3, +}; + +std::string LabelForParam(BenchmarkParam param) { + switch (param) { + case BenchmarkParam::kDefault: + return "default"; + case BenchmarkParam::kFoldConstants: + return "fold_constants"; + case BenchmarkParam::kRecursivePlanning: + return "recursive_planning"; + case BenchmarkParam::kRecursivePlanningWithConstantFolding: + return "recursive_planning_with_constant_folding"; + } + return "unknown"; +} + +void BM_RegisterBuiltins(benchmark::State& state) { + for (auto _ : state) { + auto builder = CreateCelExpressionBuilder(); + auto reg_status = RegisterBuiltinFunctions(builder->GetRegistry()); + ASSERT_OK(reg_status); + } +} + +BENCHMARK(BM_RegisterBuiltins); + +InterpreterOptions OptionsForParam(BenchmarkParam param, google::protobuf::Arena& arena) { + InterpreterOptions options; + switch (param) { + case BenchmarkParam::kFoldConstants: + case BenchmarkParam::kRecursivePlanningWithConstantFolding: + options.constant_arena = &arena; + options.constant_folding = true; + break; + case BenchmarkParam::kDefault: + case BenchmarkParam::kRecursivePlanning: + options.constant_folding = false; + break; + } + switch (param) { + case BenchmarkParam::kRecursivePlanning: + case BenchmarkParam::kRecursivePlanningWithConstantFolding: + options.max_recursion_depth = 48; + break; + case BenchmarkParam::kDefault: + case BenchmarkParam::kFoldConstants: + options.max_recursion_depth = 0; + break; + } + return options; +} + +void BM_SymbolicPolicy(benchmark::State& state) { + auto param = static_cast(state.range(0)); + state.SetLabel(LabelForParam(param)); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse(R"cel( + !(request.ip in ["10.0.1.4", "10.0.1.5", "10.0.1.6"]) && + ((request.path.startsWith("v1") && request.token in ["v1", "v2", "admin"]) || + (request.path.startsWith("v2") && request.token in ["v2", "admin"]) || + (request.path.startsWith("/admin") && request.token == "admin" && + request.ip in ["10.0.1.1", "10.0.1.2", "10.0.1.3"]) + ))cel")); + + google::protobuf::Arena arena; + InterpreterOptions options = OptionsForParam(param, arena); + + auto builder = CreateCelExpressionBuilder(options); + auto reg_status = RegisterBuiltinFunctions(builder->GetRegistry()); + ASSERT_OK(reg_status); + + for (auto _ : state) { + ASSERT_OK_AND_ASSIGN( + auto expression, + builder->CreateExpression(&expr.expr(), &expr.source_info())); + arena.Reset(); + } +} + +BENCHMARK(BM_SymbolicPolicy) + ->Arg(BenchmarkParam::kDefault) + ->Arg(BenchmarkParam::kFoldConstants) + ->Arg(BenchmarkParam::kRecursivePlanning) + ->Arg(BenchmarkParam::kRecursivePlanningWithConstantFolding); + +absl::StatusOr> MakeBuilderForEnums( + absl::string_view container, absl::string_view enum_type, + int num_enum_values) { + auto builder = + CreateCelExpressionBuilder(cel::GetMinimalDescriptorPool(), nullptr, {}); + builder->set_container(std::string(container)); + CelTypeRegistry* type_registry = builder->GetTypeRegistry(); + std::vector enumerators; + enumerators.reserve(num_enum_values); + for (int i = 0; i < num_enum_values; ++i) { + enumerators.push_back( + CelTypeRegistry::Enumerator{absl::StrCat("ENUM_VALUE_", i), i}); + } + type_registry->RegisterEnum(enum_type, std::move(enumerators)); + + CEL_RETURN_IF_ERROR(RegisterBuiltinFunctions(builder->GetRegistry())); + return builder; +} + +void BM_EnumResolutionSimple(benchmark::State& state) { + static const CelExpressionBuilder* builder = []() { + auto builder = MakeBuilderForEnums("", "com.example.TestEnum", 4); + ABSL_CHECK_OK(builder.status()); + return builder->release(); + }(); + + ASSERT_OK_AND_ASSIGN(auto expr, Parse("com.example.TestEnum.ENUM_VALUE_0")); + + for (auto _ : state) { + ASSERT_OK_AND_ASSIGN( + auto expression, + builder->CreateExpression(&expr.expr(), &expr.source_info())); + benchmark::DoNotOptimize(expression); + } +} + +BENCHMARK(BM_EnumResolutionSimple)->ThreadRange(1, 32); + +void BM_EnumResolutionContainer(benchmark::State& state) { + static const CelExpressionBuilder* builder = []() { + auto builder = + MakeBuilderForEnums("com.example", "com.example.TestEnum", 4); + ABSL_CHECK_OK(builder.status()); + return builder->release(); + }(); + + ASSERT_OK_AND_ASSIGN(auto expr, Parse("TestEnum.ENUM_VALUE_0")); + + for (auto _ : state) { + ASSERT_OK_AND_ASSIGN( + auto expression, + builder->CreateExpression(&expr.expr(), &expr.source_info())); + benchmark::DoNotOptimize(expression); + } +} + +BENCHMARK(BM_EnumResolutionContainer)->ThreadRange(1, 32); + +void BM_EnumResolution32Candidate(benchmark::State& state) { + static const CelExpressionBuilder* builder = []() { + auto builder = + MakeBuilderForEnums("com.example.foo", "com.example.foo.TestEnum", 8); + ABSL_CHECK_OK(builder.status()); + return builder->release(); + }(); + + ASSERT_OK_AND_ASSIGN(auto expr, + Parse("com.example.foo.TestEnum.ENUM_VALUE_0")); + + for (auto _ : state) { + ASSERT_OK_AND_ASSIGN( + auto expression, + builder->CreateExpression(&expr.expr(), &expr.source_info())); + benchmark::DoNotOptimize(expression); + } +} + +BENCHMARK(BM_EnumResolution32Candidate)->ThreadRange(1, 32); + +void BM_EnumResolution256Candidate(benchmark::State& state) { + static const CelExpressionBuilder* builder = []() { + auto builder = + MakeBuilderForEnums("com.example.foo", "com.example.foo.TestEnum", 64); + ABSL_CHECK_OK(builder.status()); + return builder->release(); + }(); + + ASSERT_OK_AND_ASSIGN(auto expr, + Parse("com.example.foo.TestEnum.ENUM_VALUE_0")); + + for (auto _ : state) { + ASSERT_OK_AND_ASSIGN( + auto expression, + builder->CreateExpression(&expr.expr(), &expr.source_info())); + benchmark::DoNotOptimize(expression); + } +} + +BENCHMARK(BM_EnumResolution256Candidate)->ThreadRange(1, 32); + +void BM_NestedComprehension(benchmark::State& state) { + auto param = static_cast(state.range(0)); + state.SetLabel(LabelForParam(param)); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse(R"( + [4, 5, 6].all(x, [1, 2, 3].all(y, x > y) && [7, 8, 9].all(z, x < z)) + )")); + + google::protobuf::Arena arena; + InterpreterOptions options = OptionsForParam(param, arena); + + auto builder = CreateCelExpressionBuilder(options); + auto reg_status = RegisterBuiltinFunctions(builder->GetRegistry()); + ASSERT_OK(reg_status); + + for (auto _ : state) { + ASSERT_OK_AND_ASSIGN( + auto expression, + builder->CreateExpression(&expr.expr(), &expr.source_info())); + arena.Reset(); + } +} + +BENCHMARK(BM_NestedComprehension) + ->Arg(BenchmarkParam::kDefault) + ->Arg(BenchmarkParam::kFoldConstants) + ->Arg(BenchmarkParam::kRecursivePlanning) + ->Arg(BenchmarkParam::kRecursivePlanningWithConstantFolding); + +void BM_Comparisons(benchmark::State& state) { + auto param = static_cast(state.range(0)); + state.SetLabel(LabelForParam(param)); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse(R"( + v11 < v12 && v12 < v13 + && v21 > v22 && v22 > v23 + && v31 == v32 && v32 == v33 + && v11 != v12 && v12 != v13 + )")); + + google::protobuf::Arena arena; + InterpreterOptions options = OptionsForParam(param, arena); + + auto builder = CreateCelExpressionBuilder(options); + auto reg_status = RegisterBuiltinFunctions(builder->GetRegistry()); + ASSERT_OK(reg_status); + + for (auto _ : state) { + ASSERT_OK_AND_ASSIGN( + auto expression, + builder->CreateExpression(&expr.expr(), &expr.source_info())); + arena.Reset(); + } +} + +BENCHMARK(BM_Comparisons) + ->Arg(BenchmarkParam::kDefault) + ->Arg(BenchmarkParam::kFoldConstants) + ->Arg(BenchmarkParam::kRecursivePlanning) + ->Arg(BenchmarkParam::kRecursivePlanningWithConstantFolding); + +void BM_ComparisonsConcurrent(benchmark::State& state) { + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse(R"( + v11 < v12 && v12 < v13 + && v21 > v22 && v22 > v23 + && v31 == v32 && v32 == v33 + && v11 != v12 && v12 != v13 + )")); + + static const CelExpressionBuilder* builder = [] { + InterpreterOptions options; + auto builder = CreateCelExpressionBuilder(options); + auto reg_status = RegisterBuiltinFunctions(builder->GetRegistry()); + ABSL_CHECK_OK(reg_status); + return builder.release(); + }(); + + for (auto _ : state) { + ASSERT_OK_AND_ASSIGN( + auto expression, + builder->CreateExpression(&expr.expr(), &expr.source_info())); + benchmark::DoNotOptimize(expression); + } +} + +BENCHMARK(BM_ComparisonsConcurrent)->ThreadRange(1, 32); + +void RegexPrecompilationBench(bool enabled, benchmark::State& state) { + auto param = static_cast(state.range(0)); + state.SetLabel(absl::StrCat(LabelForParam(param), "_", + enabled ? "enabled" : "disabled")); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse(R"cel( + input_str.matches(r'192\.168\.' + '[0-9]{1,3}' + r'\.' + '[0-9]{1,3}') || + input_str.matches(r'10(\.[0-9]{1,3}){3}') + )cel")); + + // Fake a checked expression with enough reference information for the expr + // builder to identify the regex as optimize-able. + CheckedExpr checked_expr; + checked_expr.mutable_expr()->Swap(expr.mutable_expr()); + checked_expr.mutable_source_info()->Swap(expr.mutable_source_info()); + (*checked_expr.mutable_reference_map())[2].add_overload_id("matches_string"); + (*checked_expr.mutable_reference_map())[11].add_overload_id("matches_string"); + + google::protobuf::Arena arena; + InterpreterOptions options = OptionsForParam(param, arena); + options.enable_regex_precompilation = enabled; + + auto builder = CreateCelExpressionBuilder(options); + auto reg_status = RegisterBuiltinFunctions(builder->GetRegistry()); + ASSERT_OK(reg_status); + + for (auto _ : state) { + ASSERT_OK_AND_ASSIGN(auto expression, + builder->CreateExpression(&checked_expr)); + arena.Reset(); + } +} + +void BM_RegexPrecompilationDisabled(benchmark::State& state) { + RegexPrecompilationBench(false, state); +} + +BENCHMARK(BM_RegexPrecompilationDisabled) + ->Arg(BenchmarkParam::kDefault) + ->Arg(BenchmarkParam::kFoldConstants) + ->Arg(BenchmarkParam::kRecursivePlanning) + ->Arg(BenchmarkParam::kRecursivePlanningWithConstantFolding); + +void BM_RegexPrecompilationEnabled(benchmark::State& state) { + RegexPrecompilationBench(true, state); +} + +BENCHMARK(BM_RegexPrecompilationEnabled) + ->Arg(BenchmarkParam::kDefault) + ->Arg(BenchmarkParam::kFoldConstants) + ->Arg(BenchmarkParam::kRecursivePlanning) + ->Arg(BenchmarkParam::kRecursivePlanningWithConstantFolding); + +void BM_StringConcat(benchmark::State& state) { + auto param = static_cast(state.range(0)); + state.SetLabel(LabelForParam(param)); + auto size = state.range(1); + + std::string source = "'1234567890' + '1234567890'"; + auto height = static_cast(std::log2(size)); + for (int i = 1; i < height; i++) { + // Force the parse to be a binary tree, otherwise we can hit + // recursion limits. + source = absl::StrCat("(", source, " + ", source, ")"); + } + + // add a non const branch to the expression. + absl::StrAppend(&source, " + identifier"); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse(source)); + + google::protobuf::Arena arena; + InterpreterOptions options = OptionsForParam(param, arena); + + auto builder = CreateCelExpressionBuilder(options); + auto reg_status = RegisterBuiltinFunctions(builder->GetRegistry()); + ASSERT_OK(reg_status); + + for (auto _ : state) { + ASSERT_OK_AND_ASSIGN( + auto expression, + builder->CreateExpression(&expr.expr(), &expr.source_info())); + arena.Reset(); + } +} + +BENCHMARK(BM_StringConcat) + ->Args({BenchmarkParam::kDefault, 2}) + ->Args({BenchmarkParam::kDefault, 4}) + ->Args({BenchmarkParam::kDefault, 8}) + ->Args({BenchmarkParam::kDefault, 16}) + ->Args({BenchmarkParam::kDefault, 32}) + ->Args({BenchmarkParam::kFoldConstants, 2}) + ->Args({BenchmarkParam::kFoldConstants, 4}) + ->Args({BenchmarkParam::kFoldConstants, 8}) + ->Args({BenchmarkParam::kFoldConstants, 16}) + ->Args({BenchmarkParam::kFoldConstants, 32}) + ->Args({BenchmarkParam::kRecursivePlanning, 2}) + ->Args({BenchmarkParam::kRecursivePlanning, 4}) + ->Args({BenchmarkParam::kRecursivePlanning, 8}) + ->Args({BenchmarkParam::kRecursivePlanning, 16}) + ->Args({BenchmarkParam::kRecursivePlanning, 32}) + ->Args({BenchmarkParam::kRecursivePlanningWithConstantFolding, 2}) + ->Args({BenchmarkParam::kRecursivePlanningWithConstantFolding, 4}) + ->Args({BenchmarkParam::kRecursivePlanningWithConstantFolding, 8}) + ->Args({BenchmarkParam::kRecursivePlanningWithConstantFolding, 16}) + ->Args({BenchmarkParam::kRecursivePlanningWithConstantFolding, 32}); + +void BM_StringConcat32Concurrent(benchmark::State& state) { + std::string source = "'1234567890' + '1234567890'"; + auto height = static_cast(std::log2(32)); + for (int i = 1; i < height; i++) { + // Force the parse to be a binary tree, otherwise we can hit + // recursion limits. + source = absl::StrCat("(", source, " + ", source, ")"); + } + + // add a non const branch to the expression. + absl::StrAppend(&source, " + identifier"); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse(source)); + + static const CelExpressionBuilder* builder = [] { + InterpreterOptions options; + auto builder = CreateCelExpressionBuilder(options); + auto reg_status = RegisterBuiltinFunctions(builder->GetRegistry()); + ABSL_CHECK_OK(reg_status); + return builder.release(); + }(); + + for (auto _ : state) { + ASSERT_OK_AND_ASSIGN( + auto expression, + builder->CreateExpression(&expr.expr(), &expr.source_info())); + benchmark::DoNotOptimize(expression); + } +} + +BENCHMARK(BM_StringConcat32Concurrent)->ThreadRange(1, 32); + +} // namespace +} // namespace google::api::expr::runtime diff --git a/eval/tests/memory_safety_test.cc b/eval/tests/memory_safety_test.cc new file mode 100644 index 000000000..a88844fed --- /dev/null +++ b/eval/tests/memory_safety_test.cc @@ -0,0 +1,319 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Tests for memory safety using the CEL Evaluator. +#include +#include +#include +#include + +#include "cel/expr/syntax.pb.h" +#include "google/rpc/context/attribute_context.pb.h" +#include "absl/status/status.h" +#include "absl/strings/match.h" +#include "eval/public/activation.h" +#include "eval/public/builtin_func_registrar.h" +#include "eval/public/cel_expr_builder_factory.h" +#include "eval/public/cel_expression.h" +#include "eval/public/cel_function_adapter.h" +#include "eval/public/cel_options.h" +#include "eval/public/testing/matchers.h" +#include "internal/testing.h" +#include "parser/parser.h" +#include "testutil/util.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/message.h" + +namespace google::api::expr::runtime { +namespace { + +using ::absl_testing::IsOkAndHolds; +using ::cel::expr::ParsedExpr; +using ::google::rpc::context::AttributeContext; +using testutil::EqualsProto; + +struct TestCase { + std::string name; + std::string expression; + absl::flat_hash_map activation; + test::CelValueMatcher expected_matcher; + bool reference_resolver_enabled = false; +}; + +enum Options { + kDefault, + kExhaustive, + kFoldConstants, + kFoldConstantsManagedArena +}; + +using ParamType = std::tuple; + +std::string TestCaseName(const testing::TestParamInfo& param_info) { + const ParamType& param = param_info.param; + absl::string_view opt; + switch (std::get<1>(param)) { + case Options::kDefault: + opt = "default"; + break; + case Options::kExhaustive: + opt = "exhaustive"; + break; + case Options::kFoldConstants: + opt = "opt"; + break; + case Options::kFoldConstantsManagedArena: + opt = "opt_managed_arena"; + break; + } + + return absl::StrCat(std::get<0>(param).name, "_", opt); +} + +class EvaluatorMemorySafetyTest : public testing::TestWithParam { + public: + EvaluatorMemorySafetyTest() { + google::protobuf::LinkMessageReflection(); + google::protobuf::LinkMessageReflection(); + google::protobuf::LinkMessageReflection(); + } + + protected: + const TestCase& GetTestCase() { return std::get<0>(GetParam()); } + + InterpreterOptions GetOptions() { + InterpreterOptions options; + options.constant_arena = &arena_; + + switch (std::get<1>(GetParam())) { + case Options::kDefault: + options.enable_regex_precompilation = false; + options.constant_folding = false; + options.enable_comprehension_list_append = false; + options.enable_comprehension_vulnerability_check = true; + options.short_circuiting = true; + break; + case Options::kExhaustive: + options.enable_regex_precompilation = false; + options.constant_folding = false; + options.enable_comprehension_list_append = false; + options.enable_comprehension_vulnerability_check = true; + options.short_circuiting = false; + break; + case Options::kFoldConstants: + options.enable_regex_precompilation = true; + options.constant_folding = true; + options.enable_comprehension_list_append = true; + options.enable_comprehension_vulnerability_check = false; + options.short_circuiting = true; + break; + case Options::kFoldConstantsManagedArena: + options.enable_regex_precompilation = true; + options.constant_folding = true; + options.enable_comprehension_list_append = true; + options.enable_comprehension_vulnerability_check = false; + options.short_circuiting = true; + options.constant_arena = nullptr; + break; + } + + options.enable_qualified_identifier_rewrites = + GetTestCase().reference_resolver_enabled; + + return options; + } + + google::protobuf::Arena arena_; +}; + +bool IsPrivateIpv4Impl(google::protobuf::Arena* arena, CelValue::StringHolder addr) { + // Implementation for demonstration, this is simple but incomplete and + // brittle. + return absl::StartsWith(addr.value(), "192.168.") || + absl::StartsWith(addr.value(), "10."); +} + +TEST_P(EvaluatorMemorySafetyTest, Basic) { + const auto& test_case = GetTestCase(); + InterpreterOptions options = GetOptions(); + + std::unique_ptr builder = + CreateCelExpressionBuilder(options); + builder->set_container("google.rpc.context"); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); + + absl::string_view function_name = "IsPrivate"; + if (test_case.reference_resolver_enabled) { + function_name = "net.IsPrivate"; + } + ASSERT_OK((FunctionAdapter::CreateAndRegister( + function_name, false, &IsPrivateIpv4Impl, builder->GetRegistry()))); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse(test_case.expression)); + + ASSERT_OK_AND_ASSIGN( + auto plan, builder->CreateExpression(&expr.expr(), &expr.source_info())); + + Activation activation; + for (const auto& [key, value] : test_case.activation) { + activation.InsertValue(key, value); + } + + absl::StatusOr got = plan->Evaluate(activation, &arena_); + + EXPECT_THAT(got, IsOkAndHolds(test_case.expected_matcher)); +} + +// Check no use after free errors if evaluated after AST is freed. +TEST_P(EvaluatorMemorySafetyTest, NoAstDependency) { + const auto& test_case = GetTestCase(); + InterpreterOptions options = GetOptions(); + + std::unique_ptr builder = + CreateCelExpressionBuilder(options); + builder->set_container("google.rpc.context"); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); + + absl::string_view function_name = "IsPrivate"; + if (test_case.reference_resolver_enabled) { + function_name = "net.IsPrivate"; + } + ASSERT_OK((FunctionAdapter::CreateAndRegister( + function_name, false, &IsPrivateIpv4Impl, builder->GetRegistry()))); + + auto parsed_expr = parser::Parse(test_case.expression); + ASSERT_OK(parsed_expr.status()); + auto expr = std::make_unique(std::move(parsed_expr).value()); + ASSERT_OK_AND_ASSIGN( + std::unique_ptr plan, + builder->CreateExpression(&expr->expr(), &expr->source_info())); + + expr.reset(); // ParsedExpr expr freed + + Activation activation; + for (const auto& [key, value] : test_case.activation) { + activation.InsertValue(key, value); + } + + absl::StatusOr got = plan->Evaluate(activation, &arena_); + + EXPECT_THAT(got, IsOkAndHolds(test_case.expected_matcher)); +} + +// TODO(uncreated-issue/25): make expression plan memory safe after builder is freed. +// TEST_P(EvaluatorMemorySafetyTest, NoBuilderDependency) + +INSTANTIATE_TEST_SUITE_P( + Expression, EvaluatorMemorySafetyTest, + testing::Combine( + testing::ValuesIn(std::vector{ + { + "bool", + "(true && false) || x || y == 'test_str'", + {{"x", CelValue::CreateBool(false)}, + {"y", CelValue::CreateStringView("test_str")}}, + test::IsCelBool(true), + }, + { + "const_str", + "condition ? 'left_hand_string' : 'right_hand_string'", + {{"condition", CelValue::CreateBool(false)}}, + test::IsCelString("right_hand_string"), + }, + { + "long_const_string", + "condition ? 'left_hand_string' : " + "'long_right_hand_string_0123456789'", + {{"condition", CelValue::CreateBool(false)}}, + test::IsCelString("long_right_hand_string_0123456789"), + }, + { + "computed_string", + "(condition ? 'a.b' : 'b.c') + '.d.e.f'", + {{"condition", CelValue::CreateBool(false)}}, + test::IsCelString("b.c.d.e.f"), + }, + { + "regex", + R"('192.168.128.64'.matches(r'^192\.168\.[0-2]?[0-9]?[0-9]\.[0-2]?[0-9]?[0-9]') )", + {}, + test::IsCelBool(true), + }, + { + "list_create", + "[1, 2, 3, 4, 5, 6][3] == 4", + {}, + test::IsCelBool(true), + }, + { + "list_create_strings", + "['1', '2', '3', '4', '5', '6'][2] == '3'", + {}, + test::IsCelBool(true), + }, + { + "map_create", + "{'1': 'one', '2': 'two'}['2']", + {}, + test::IsCelString("two"), + }, + { + "struct_create", + R"( + AttributeContext{ + request: AttributeContext.Request{ + method: 'GET', + path: '/index' + }, + origin: AttributeContext.Peer{ + ip: '10.0.0.1' + } + } + )", + {}, + test::IsCelMessage(EqualsProto(R"pb( + request { method: "GET" path: "/index" } + origin { ip: "10.0.0.1" } + )pb")), + }, + {"extension_function", + "IsPrivate('8.8.8.8')", + {}, + test::IsCelBool(false), + /*enable_reference_resolver=*/false}, + {"namespaced_function", + "net.IsPrivate('192.168.0.1')", + {}, + test::IsCelBool(true), + /*enable_reference_resolver=*/true}, + { + "comprehension", + "['abc', 'def', 'ghi', 'jkl'].exists(el, el == 'mno')", + {}, + test::IsCelBool(false), + }, + { + "comprehension_complex", + "['a' + 'b' + 'c', 'd' + 'ef', 'g' + 'hi', 'j' + 'kl']" + ".exists(el, el.startsWith('g'))", + {}, + test::IsCelBool(true), + }}), + testing::Values(Options::kDefault, Options::kExhaustive, + Options::kFoldConstants, + Options::kFoldConstantsManagedArena)), + &TestCaseName); + +} // namespace +} // namespace google::api::expr::runtime diff --git a/eval/tests/mock_cel_expression.h b/eval/tests/mock_cel_expression.h index bd04e831e..07b32b29f 100644 --- a/eval/tests/mock_cel_expression.h +++ b/eval/tests/mock_cel_expression.h @@ -3,14 +3,12 @@ #include -#include "gmock/gmock.h" -#include "eval/public/activation.h" +#include "absl/status/statusor.h" +#include "eval/public/base_activation.h" #include "eval/public/cel_expression.h" +#include "internal/testing.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace google::api::expr::runtime { class MockCelExpression : public CelExpression { public: @@ -36,9 +34,6 @@ class MockCelExpression : public CelExpression { (const, override)); }; -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_TESTS_MOCK_CEL_EXPRESION_H_ diff --git a/eval/tests/modern_benchmark_test.cc b/eval/tests/modern_benchmark_test.cc new file mode 100644 index 000000000..005f93aa5 --- /dev/null +++ b/eval/tests/modern_benchmark_test.cc @@ -0,0 +1,1335 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// General benchmarks for CEL evaluator. + +#include +#include +#include +#include +#include +#include +#include + +#include "cel/expr/syntax.pb.h" +#include "google/protobuf/struct.pb.h" +#include "google/rpc/context/attribute_context.pb.h" +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/container/btree_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/container/node_hash_set.h" +#include "absl/flags/flag.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/match.h" +#include "common/allocator.h" +#include "common/casting.h" +#include "common/native_type.h" +#include "common/value.h" +#include "eval/tests/request_context.pb.h" +#include "extensions/comprehensions_v2_functions.h" +#include "extensions/comprehensions_v2_macros.h" +#include "extensions/protobuf/runtime_adapter.h" +#include "extensions/protobuf/value.h" +#include "internal/benchmark.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "internal/testing_message_factory.h" +#include "parser/macro_registry.h" +#include "parser/parser.h" +#include "runtime/activation.h" +#include "runtime/constant_folding.h" +#include "runtime/runtime.h" +#include "runtime/runtime_options.h" +#include "runtime/standard_runtime_builder_factory.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" +#include "google/protobuf/text_format.h" + +ABSL_FLAG(bool, enable_recursive_planning, false, "enable recursive planning"); + +namespace cel { + +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::IsOkAndHolds; +using ::cel::extensions::ProtobufRuntimeAdapter; +using ::cel::expr::Expr; +using ::cel::expr::ParsedExpr; +using ::cel::expr::SourceInfo; +using ::google::api::expr::parser::EnrichedParse; +using ::google::api::expr::parser::Parse; +using ::google::api::expr::runtime::RequestContext; +using ::google::rpc::context::AttributeContext; + +RuntimeOptions GetOptions() { + RuntimeOptions options; + + if (absl::GetFlag(FLAGS_enable_recursive_planning)) { + options.max_recursion_depth = -1; + } + + return options; +} + +enum class ConstFoldingEnabled { kNo, kYes }; + +std::unique_ptr StandardRuntimeOrDie( + const cel::RuntimeOptions& options, google::protobuf::Arena* arena = nullptr, + ConstFoldingEnabled const_folding = ConstFoldingEnabled::kNo) { + auto builder = CreateStandardRuntimeBuilder( + internal::GetTestingDescriptorPool(), options); + ABSL_CHECK_OK(builder.status()); + + switch (const_folding) { + case ConstFoldingEnabled::kNo: + break; + case ConstFoldingEnabled::kYes: + ABSL_CHECK(arena != nullptr); + ABSL_CHECK_OK(extensions::EnableConstantFolding(*builder)); + break; + } + + auto runtime = std::move(builder).value().Build(); + ABSL_CHECK_OK(runtime.status()); + return std::move(runtime).value(); +} + +template +Value WrapMessageOrDie(const T& message, google::protobuf::Arena* absl_nonnull arena) { + auto value = extensions::ProtoMessageToValue( + message, internal::GetTestingDescriptorPool(), + internal::GetTestingMessageFactory(), arena); + ABSL_CHECK_OK(value.status()); + return std::move(value).value(); +} + +// Benchmark test +// Evaluates cel expression: +// '1 + 1 + 1 .... +1' +static void BM_Eval(benchmark::State& state) { + RuntimeOptions options = GetOptions(); + auto runtime = StandardRuntimeOrDie(options); + + int len = state.range(0); + + Expr root_expr; + Expr* cur_expr = &root_expr; + + for (int i = 0; i < len; i++) { + Expr::Call* call = cur_expr->mutable_call_expr(); + call->set_function("_+_"); + call->add_args()->mutable_const_expr()->set_int64_value(1); + cur_expr = call->add_args(); + } + + cur_expr->mutable_const_expr()->set_int64_value(1); + + SourceInfo source_info; + ASSERT_OK_AND_ASSIGN(auto cel_expr, ProtobufRuntimeAdapter::CreateProgram( + *runtime, root_expr)); + + for (auto _ : state) { + google::protobuf::Arena arena; + Activation activation; + ASSERT_OK_AND_ASSIGN(cel::Value result, + cel_expr->Evaluate(&arena, activation)); + ASSERT_TRUE(InstanceOf(result)); + ASSERT_TRUE(Cast(result) == len + 1); + } +} + +BENCHMARK(BM_Eval)->Range(1, 10000); + +absl::Status EmptyCallback(int64_t expr_id, const Value&, + const google::protobuf::DescriptorPool* absl_nonnull, + google::protobuf::MessageFactory* absl_nonnull, + google::protobuf::Arena* absl_nonnull) { + return absl::OkStatus(); +} + +// Benchmark test +// Traces cel expression with an empty callback: +// '1 + 1 + 1 .... +1' +static void BM_Eval_Trace(benchmark::State& state) { + RuntimeOptions options = GetOptions(); + options.enable_recursive_tracing = true; + + auto runtime = StandardRuntimeOrDie(options); + + int len = state.range(0); + + Expr root_expr; + Expr* cur_expr = &root_expr; + + for (int i = 0; i < len; i++) { + Expr::Call* call = cur_expr->mutable_call_expr(); + call->set_function("_+_"); + call->add_args()->mutable_const_expr()->set_int64_value(1); + cur_expr = call->add_args(); + } + + cur_expr->mutable_const_expr()->set_int64_value(1); + + SourceInfo source_info; + ASSERT_OK_AND_ASSIGN(auto cel_expr, ProtobufRuntimeAdapter::CreateProgram( + *runtime, root_expr)); + + for (auto _ : state) { + google::protobuf::Arena arena; + Activation activation; + ASSERT_OK_AND_ASSIGN(cel::Value result, + cel_expr->Trace(&arena, activation, EmptyCallback)); + ASSERT_TRUE(InstanceOf(result)); + ASSERT_TRUE(Cast(result) == len + 1); + } +} + +// A number higher than 10k leads to a stack overflow due to the recursive +// nature of the proto to native type conversion. +BENCHMARK(BM_Eval_Trace)->Range(1, 10000); + +// Benchmark test +// Evaluates cel expression: +// '"a" + "a" + "a" .... + "a"' +static void BM_EvalString(benchmark::State& state) { + RuntimeOptions options = GetOptions(); + + auto runtime = StandardRuntimeOrDie(options); + + int len = state.range(0); + + Expr root_expr; + Expr* cur_expr = &root_expr; + + for (int i = 0; i < len; i++) { + Expr::Call* call = cur_expr->mutable_call_expr(); + call->set_function("_+_"); + call->add_args()->mutable_const_expr()->set_string_value("a"); + cur_expr = call->add_args(); + } + + cur_expr->mutable_const_expr()->set_string_value("a"); + + SourceInfo source_info; + ASSERT_OK_AND_ASSIGN(auto cel_expr, ProtobufRuntimeAdapter::CreateProgram( + *runtime, root_expr)); + + for (auto _ : state) { + google::protobuf::Arena arena; + Activation activation; + ASSERT_OK_AND_ASSIGN(cel::Value result, + cel_expr->Evaluate(&arena, activation)); + ASSERT_TRUE(InstanceOf(result)); + ASSERT_TRUE(Cast(result).Size() == len + 1); + } +} + +// A number higher than 10k leads to a stack overflow due to the recursive +// nature of the proto to native type conversion. +BENCHMARK(BM_EvalString)->Range(1, 10000); + +// Benchmark test +// Traces cel expression with an empty callback: +// '"a" + "a" + "a" .... + "a"' +static void BM_EvalString_Trace(benchmark::State& state) { + RuntimeOptions options = GetOptions(); + options.enable_recursive_tracing = true; + + auto runtime = StandardRuntimeOrDie(options); + + int len = state.range(0); + + Expr root_expr; + Expr* cur_expr = &root_expr; + + for (int i = 0; i < len; i++) { + Expr::Call* call = cur_expr->mutable_call_expr(); + call->set_function("_+_"); + call->add_args()->mutable_const_expr()->set_string_value("a"); + cur_expr = call->add_args(); + } + + cur_expr->mutable_const_expr()->set_string_value("a"); + + SourceInfo source_info; + ASSERT_OK_AND_ASSIGN(auto cel_expr, ProtobufRuntimeAdapter::CreateProgram( + *runtime, root_expr)); + + for (auto _ : state) { + google::protobuf::Arena arena; + Activation activation; + ASSERT_OK_AND_ASSIGN(cel::Value result, + cel_expr->Trace(&arena, activation, EmptyCallback)); + ASSERT_TRUE(InstanceOf(result)); + ASSERT_TRUE(Cast(result).Size() == len + 1); + } +} + +// A number higher than 10k leads to a stack overflow due to the recursive +// nature of the proto to native type conversion. +BENCHMARK(BM_EvalString_Trace)->Range(1, 10000); + +const char kIP[] = "10.0.1.2"; +const char kPath[] = "/admin/edit"; +const char kToken[] = "admin"; + +ABSL_ATTRIBUTE_NOINLINE +bool NativeCheck(absl::btree_map& attributes, + const absl::flat_hash_set& denylists, + const absl::flat_hash_set& allowlists) { + auto& ip = attributes["ip"]; + auto& path = attributes["path"]; + auto& token = attributes["token"]; + if (denylists.find(ip) != denylists.end()) { + return false; + } + if (absl::StartsWith(path, "v1")) { + if (token == "v1" || token == "v2" || token == "admin") { + return true; + } + } else if (absl::StartsWith(path, "v2")) { + if (token == "v2" || token == "admin") { + return true; + } + } else if (absl::StartsWith(path, "/admin")) { + if (token == "admin") { + if (allowlists.find(ip) != allowlists.end()) { + return true; + } + } + } + return false; +} + +void BM_PolicyNative(benchmark::State& state) { + const auto denylists = + absl::flat_hash_set{"10.0.1.4", "10.0.1.5", "10.0.1.6"}; + const auto allowlists = + absl::flat_hash_set{"10.0.1.1", "10.0.1.2", "10.0.1.3"}; + auto attributes = absl::btree_map{ + {"ip", kIP}, {"token", kToken}, {"path", kPath}}; + for (auto _ : state) { + auto result = NativeCheck(attributes, denylists, allowlists); + ASSERT_TRUE(result); + } +} + +BENCHMARK(BM_PolicyNative); + +void BM_PolicySymbolic(benchmark::State& state) { + google::protobuf::Arena arena; + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse(R"cel( + !(ip in ["10.0.1.4", "10.0.1.5", "10.0.1.6"]) && + ((path.startsWith("v1") && token in ["v1", "v2", "admin"]) || + (path.startsWith("v2") && token in ["v2", "admin"]) || + (path.startsWith("/admin") && token == "admin" && ip in [ + "10.0.1.1", "10.0.1.2", "10.0.1.3" + ]) + ))cel")); + + RuntimeOptions options = GetOptions(); + auto runtime = + StandardRuntimeOrDie(options, &arena, ConstFoldingEnabled::kYes); + + ASSERT_OK_AND_ASSIGN(auto cel_expr, ProtobufRuntimeAdapter::CreateProgram( + *runtime, parsed_expr)); + + Activation activation; + activation.InsertOrAssignValue("ip", StringValue(&arena, kIP)); + activation.InsertOrAssignValue("path", StringValue(&arena, kPath)); + activation.InsertOrAssignValue("token", StringValue(&arena, kToken)); + + for (auto _ : state) { + ASSERT_OK_AND_ASSIGN(cel::Value result, + cel_expr->Evaluate(&arena, activation)); + auto result_bool = As(result); + ASSERT_TRUE(result_bool && result_bool->NativeValue()); + } +} + +BENCHMARK(BM_PolicySymbolic); + +class RequestMapImpl : public CustomMapValueInterface { + public: + size_t Size() const override { return 3; } + + absl::Status ListKeys( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + ListValue* absl_nonnull result) const override { + return absl::UnimplementedError("Unsupported"); + } + + absl::StatusOr NewIterator() const override { + return absl::UnimplementedError("Unsupported"); + } + + std::string DebugString() const override { return "RequestMapImpl"; } + + absl::Status ConvertToJsonObject( + const google::protobuf::DescriptorPool* absl_nonnull, + google::protobuf::MessageFactory* absl_nonnull, + google::protobuf::Message* absl_nonnull) const override { + return absl::UnimplementedError("Unsupported"); + } + + CustomMapValue Clone(google::protobuf::Arena* absl_nonnull arena) const override { + return CustomMapValue(google::protobuf::Arena::Create(arena), arena); + } + + protected: + // Called by `Find` after performing various argument checks. + absl::StatusOr Find( + const Value& key, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const override { + auto string_value = As(key); + if (!string_value) { + return false; + } + if (string_value->Equals("ip")) { + *result = StringValue(kIP); + } else if (string_value->Equals("path")) { + *result = StringValue(kPath); + } else if (string_value->Equals("token")) { + *result = StringValue(kToken); + } else { + return false; + } + return true; + } + + // Called by `Has` after performing various argument checks. + absl::StatusOr Has( + const Value& key, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const override { + return absl::UnimplementedError("Unsupported."); + } + + private: + NativeTypeId GetNativeTypeId() const override { + return NativeTypeId::For(); + } +}; + +// Uses a lazily constructed map container for "ip", "path", and "token". +void BM_PolicySymbolicMap(benchmark::State& state) { + google::protobuf::Arena arena; + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse(R"cel( + !(request.ip in ["10.0.1.4", "10.0.1.5", "10.0.1.6"]) && + ((request.path.startsWith("v1") && request.token in ["v1", "v2", "admin"]) || + (request.path.startsWith("v2") && request.token in ["v2", "admin"]) || + (request.path.startsWith("/admin") && request.token == "admin" && + request.ip in ["10.0.1.1", "10.0.1.2", "10.0.1.3"]) + ))cel")); + + RuntimeOptions options = GetOptions(); + + auto runtime = StandardRuntimeOrDie(options); + + SourceInfo source_info; + ASSERT_OK_AND_ASSIGN(auto cel_expr, ProtobufRuntimeAdapter::CreateProgram( + *runtime, parsed_expr)); + + Activation activation; + CustomMapValue map_value(google::protobuf::Arena::Create(&arena), + &arena); + + activation.InsertOrAssignValue("request", std::move(map_value)); + + for (auto _ : state) { + ASSERT_OK_AND_ASSIGN(cel::Value result, + cel_expr->Evaluate(&arena, activation)); + ASSERT_TRUE(InstanceOf(result) && + Cast(result).NativeValue()); + } +} + +BENCHMARK(BM_PolicySymbolicMap); + +// Uses a protobuf container for "ip", "path", and "token". +void BM_PolicySymbolicProto(benchmark::State& state) { + google::protobuf::Arena arena; + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse(R"cel( + !(request.ip in ["10.0.1.4", "10.0.1.5", "10.0.1.6"]) && + ((request.path.startsWith("v1") && request.token in ["v1", "v2", "admin"]) || + (request.path.startsWith("v2") && request.token in ["v2", "admin"]) || + (request.path.startsWith("/admin") && request.token == "admin" && + request.ip in ["10.0.1.1", "10.0.1.2", "10.0.1.3"]) + ))cel")); + + RuntimeOptions options = GetOptions(); + + auto runtime = StandardRuntimeOrDie(options); + + SourceInfo source_info; + ASSERT_OK_AND_ASSIGN(auto cel_expr, ProtobufRuntimeAdapter::CreateProgram( + *runtime, parsed_expr)); + + Activation activation; + RequestContext request; + request.set_ip(kIP); + request.set_path(kPath); + request.set_token(kToken); + activation.InsertOrAssignValue("request", WrapMessageOrDie(request, &arena)); + for (auto _ : state) { + ASSERT_OK_AND_ASSIGN(cel::Value result, + cel_expr->Evaluate(&arena, activation)); + ASSERT_TRUE(InstanceOf(result) && + Cast(result).NativeValue()); + } +} + +BENCHMARK(BM_PolicySymbolicProto); + +// This expression has no equivalent CEL +constexpr char kListSum[] = R"( +id: 1 +comprehension_expr: < + accu_var: "__result__" + iter_var: "x" + iter_range: < + id: 2 + ident_expr: < + name: "list_var" + > + > + accu_init: < + id: 3 + const_expr: < + int64_value: 0 + > + > + loop_step: < + id: 4 + call_expr: < + function: "_+_" + args: < + id: 5 + ident_expr: < + name: "__result__" + > + > + args: < + id: 6 + ident_expr: < + name: "x" + > + > + > + > + loop_condition: < + id: 7 + const_expr: < + bool_value: true + > + > + result: < + id: 8 + ident_expr: < + name: "__result__" + > + > +>)"; + +void BM_Comprehension(benchmark::State& state) { + RuntimeOptions options = GetOptions(); + options.comprehension_max_iterations = 10000000; + auto runtime = StandardRuntimeOrDie(options); + + Expr expr; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kListSum, &expr)); + + google::protobuf::Arena arena; + Activation activation; + + auto list_builder = cel::NewListValueBuilder(&arena); + + int len = state.range(0); + list_builder->Reserve(len); + for (int i = 0; i < len; i++) { + ASSERT_THAT(list_builder->Add(IntValue(1)), IsOk()); + } + + activation.InsertOrAssignValue("list_var", std::move(*list_builder).Build()); + + ASSERT_OK_AND_ASSIGN(auto cel_expr, + ProtobufRuntimeAdapter::CreateProgram(*runtime, expr)); + for (auto _ : state) { + ASSERT_OK_AND_ASSIGN(cel::Value result, + cel_expr->Evaluate(&arena, activation)); + ASSERT_TRUE(InstanceOf(result)); + ASSERT_EQ(Cast(result), len); + } +} + +BENCHMARK(BM_Comprehension)->Range(1, 1 << 20); + +void BM_Comprehension_Trace(benchmark::State& state) { + RuntimeOptions options = GetOptions(); + options.enable_recursive_tracing = true; + + options.comprehension_max_iterations = 10000000; + auto runtime = StandardRuntimeOrDie(options); + google::protobuf::Arena arena; + Expr expr; + Activation activation; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kListSum, &expr)); + + ASSERT_OK_AND_ASSIGN(auto cel_expr, + ProtobufRuntimeAdapter::CreateProgram(*runtime, expr)); + + auto list_builder = cel::NewListValueBuilder(&arena); + + int len = state.range(0); + list_builder->Reserve(len); + for (int i = 0; i < len; i++) { + ASSERT_THAT(list_builder->Add(IntValue(1)), IsOk()); + } + activation.InsertOrAssignValue("list_var", std::move(*list_builder).Build()); + + for (auto _ : state) { + ASSERT_OK_AND_ASSIGN(cel::Value result, + cel_expr->Trace(&arena, activation, EmptyCallback)); + ASSERT_TRUE(InstanceOf(result)); + ASSERT_EQ(Cast(result), len); + } +} + +BENCHMARK(BM_Comprehension_Trace)->Range(1, 1 << 20); + +void BM_HasMap(benchmark::State& state) { + google::protobuf::Arena arena; + Activation activation; + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, + Parse("has(request.path) && !has(request.ip)")); + + RuntimeOptions options = GetOptions(); + auto runtime = StandardRuntimeOrDie(options); + + ASSERT_OK_AND_ASSIGN(auto cel_expr, ProtobufRuntimeAdapter::CreateProgram( + *runtime, parsed_expr)); + + auto map_builder = cel::NewMapValueBuilder(&arena); + + ASSERT_THAT( + map_builder->Put(cel::StringValue("path"), cel::StringValue("path")), + IsOk()); + + activation.InsertOrAssignValue("request", std::move(*map_builder).Build()); + + for (auto _ : state) { + ASSERT_OK_AND_ASSIGN(cel::Value result, + cel_expr->Evaluate(&arena, activation)); + ASSERT_TRUE(InstanceOf(result) && + Cast(result).NativeValue()); + } +} + +BENCHMARK(BM_HasMap); + +void BM_HasProto(benchmark::State& state) { + RuntimeOptions options = GetOptions(); + auto runtime = StandardRuntimeOrDie(options); + + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, + Parse("has(request.path) && !has(request.ip)")); + ASSERT_OK_AND_ASSIGN(auto cel_expr, ProtobufRuntimeAdapter::CreateProgram( + *runtime, parsed_expr)); + + google::protobuf::Arena arena; + Activation activation; + + RequestContext request; + request.set_path(kPath); + request.set_token(kToken); + activation.InsertOrAssignValue("request", WrapMessageOrDie(request, &arena)); + + for (auto _ : state) { + ASSERT_OK_AND_ASSIGN(cel::Value result, + cel_expr->Evaluate(&arena, activation)); + ASSERT_TRUE(InstanceOf(result) && + Cast(result).NativeValue()); + } +} + +BENCHMARK(BM_HasProto); + +void BM_HasProtoMap(benchmark::State& state) { + RuntimeOptions options = GetOptions(); + auto runtime = StandardRuntimeOrDie(options); + + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, + Parse("has(request.headers.create_time) && " + "!has(request.headers.update_time)")); + ASSERT_OK_AND_ASSIGN(auto cel_expr, ProtobufRuntimeAdapter::CreateProgram( + *runtime, parsed_expr)); + + google::protobuf::Arena arena; + Activation activation; + + RequestContext request; + request.mutable_headers()->insert({"create_time", "2021-01-01"}); + activation.InsertOrAssignValue("request", WrapMessageOrDie(request, &arena)); + + for (auto _ : state) { + ASSERT_OK_AND_ASSIGN(cel::Value result, + cel_expr->Evaluate(&arena, activation)); + ASSERT_TRUE(InstanceOf(result) && + Cast(result).NativeValue()); + } +} + +BENCHMARK(BM_HasProtoMap); + +void BM_ReadProtoMap(benchmark::State& state) { + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse(R"cel( + request.headers.create_time == "2021-01-01" + )cel")); + + RuntimeOptions options = GetOptions(); + auto runtime = StandardRuntimeOrDie(options); + + ASSERT_OK_AND_ASSIGN(auto cel_expr, ProtobufRuntimeAdapter::CreateProgram( + *runtime, parsed_expr)); + + google::protobuf::Arena arena; + Activation activation; + + RequestContext request; + request.mutable_headers()->insert({"create_time", "2021-01-01"}); + activation.InsertOrAssignValue("request", WrapMessageOrDie(request, &arena)); + + for (auto _ : state) { + ASSERT_OK_AND_ASSIGN(cel::Value result, + cel_expr->Evaluate(&arena, activation)); + ASSERT_TRUE(InstanceOf(result) && + Cast(result).NativeValue()); + } +} + +BENCHMARK(BM_ReadProtoMap); + +void BM_NestedProtoFieldRead(benchmark::State& state) { + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse(R"cel( + !request.a.b.c.d.e + )cel")); + + RuntimeOptions options = GetOptions(); + auto runtime = StandardRuntimeOrDie(options); + + ASSERT_OK_AND_ASSIGN(auto cel_expr, ProtobufRuntimeAdapter::CreateProgram( + *runtime, parsed_expr)); + + google::protobuf::Arena arena; + Activation activation; + + RequestContext request; + request.mutable_a()->mutable_b()->mutable_c()->mutable_d()->set_e(false); + activation.InsertOrAssignValue("request", WrapMessageOrDie(request, &arena)); + + for (auto _ : state) { + ASSERT_OK_AND_ASSIGN(cel::Value result, + cel_expr->Evaluate(&arena, activation)); + ASSERT_TRUE(InstanceOf(result) && + Cast(result).NativeValue()); + } +} + +BENCHMARK(BM_NestedProtoFieldRead); + +void BM_NestedProtoFieldReadDefaults(benchmark::State& state) { + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse(R"cel( + !request.a.b.c.d.e + )cel")); + + RuntimeOptions options = GetOptions(); + auto runtime = StandardRuntimeOrDie(options); + + ASSERT_OK_AND_ASSIGN(auto cel_expr, ProtobufRuntimeAdapter::CreateProgram( + *runtime, parsed_expr)); + + google::protobuf::Arena arena; + Activation activation; + + RequestContext request; + activation.InsertOrAssignValue("request", WrapMessageOrDie(request, &arena)); + + for (auto _ : state) { + ASSERT_OK_AND_ASSIGN(cel::Value result, + cel_expr->Evaluate(&arena, activation)); + ASSERT_TRUE(InstanceOf(result) && + Cast(result).NativeValue()); + } +} + +BENCHMARK(BM_NestedProtoFieldReadDefaults); + +void BM_ProtoStructAccess(benchmark::State& state) { + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse(R"cel( + has(request.auth.claims.iss) && request.auth.claims.iss == 'accounts.google.com' + )cel")); + + RuntimeOptions options = GetOptions(); + auto runtime = StandardRuntimeOrDie(options); + + ASSERT_OK_AND_ASSIGN(auto cel_expr, ProtobufRuntimeAdapter::CreateProgram( + *runtime, parsed_expr)); + + google::protobuf::Arena arena; + Activation activation; + + AttributeContext::Request request; + auto* auth = request.mutable_auth(); + (*auth->mutable_claims()->mutable_fields())["iss"].set_string_value( + "accounts.google.com"); + activation.InsertOrAssignValue("request", WrapMessageOrDie(request, &arena)); + + for (auto _ : state) { + ASSERT_OK_AND_ASSIGN(cel::Value result, + cel_expr->Evaluate(&arena, activation)); + ASSERT_TRUE(InstanceOf(result) && + Cast(result).NativeValue()); + } +} + +BENCHMARK(BM_ProtoStructAccess); + +void BM_ProtoListAccess(benchmark::State& state) { + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse(R"cel( + "//.../accessLevels/MY_LEVEL_4" in request.auth.access_levels + )cel")); + + RuntimeOptions options = GetOptions(); + auto runtime = StandardRuntimeOrDie(options); + + ASSERT_OK_AND_ASSIGN(auto cel_expr, ProtobufRuntimeAdapter::CreateProgram( + *runtime, parsed_expr)); + + google::protobuf::Arena arena; + Activation activation; + + AttributeContext::Request request; + auto* auth = request.mutable_auth(); + auth->add_access_levels("//.../accessLevels/MY_LEVEL_0"); + auth->add_access_levels("//.../accessLevels/MY_LEVEL_1"); + auth->add_access_levels("//.../accessLevels/MY_LEVEL_2"); + auth->add_access_levels("//.../accessLevels/MY_LEVEL_3"); + auth->add_access_levels("//.../accessLevels/MY_LEVEL_4"); + activation.InsertOrAssignValue("request", WrapMessageOrDie(request, &arena)); + + for (auto _ : state) { + ASSERT_OK_AND_ASSIGN(cel::Value result, + cel_expr->Evaluate(&arena, activation)); + ASSERT_TRUE(InstanceOf(result) && + Cast(result).NativeValue()); + } +} + +BENCHMARK(BM_ProtoListAccess); + +// This expression has no equivalent CEL expression. +// Sum a square with a nested comprehension +constexpr char kNestedListSum[] = R"( +id: 1 +comprehension_expr: < + accu_var: "__result__" + iter_var: "x" + iter_range: < + id: 2 + ident_expr: < + name: "list_var" + > + > + accu_init: < + id: 3 + const_expr: < + int64_value: 0 + > + > + loop_step: < + id: 4 + call_expr: < + function: "_+_" + args: < + id: 5 + ident_expr: < + name: "__result__" + > + > + args: < + id: 6 + comprehension_expr: < + accu_var: "__result__" + iter_var: "x" + iter_range: < + id: 9 + ident_expr: < + name: "list_var" + > + > + accu_init: < + id: 10 + const_expr: < + int64_value: 0 + > + > + loop_step: < + id: 11 + call_expr: < + function: "_+_" + args: < + id: 12 + ident_expr: < + name: "__result__" + > + > + args: < + id: 13 + ident_expr: < + name: "x" + > + > + > + > + loop_condition: < + id: 14 + const_expr: < + bool_value: true + > + > + result: < + id: 15 + ident_expr: < + name: "__result__" + > + > + > + > + > + > + loop_condition: < + id: 7 + const_expr: < + bool_value: true + > + > + result: < + id: 8 + ident_expr: < + name: "__result__" + > + > +>)"; + +void BM_NestedComprehension(benchmark::State& state) { + Expr expr; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kNestedListSum, &expr)); + + RuntimeOptions options = GetOptions(); + options.comprehension_max_iterations = 10000000; + auto runtime = StandardRuntimeOrDie(options); + + google::protobuf::Arena arena; + Activation activation; + + auto list_builder = cel::NewListValueBuilder(&arena); + + int len = state.range(0); + list_builder->Reserve(len); + for (int i = 0; i < len; i++) { + ASSERT_THAT(list_builder->Add(IntValue(1)), IsOk()); + } + + activation.InsertOrAssignValue("list_var", std::move(*list_builder).Build()); + + ASSERT_OK_AND_ASSIGN(auto cel_expr, + ProtobufRuntimeAdapter::CreateProgram(*runtime, expr)); + + for (auto _ : state) { + ASSERT_OK_AND_ASSIGN(cel::Value result, + cel_expr->Evaluate(&arena, activation)); + ASSERT_TRUE(InstanceOf(result)); + ASSERT_EQ(Cast(result), len * len); + } +} + +BENCHMARK(BM_NestedComprehension)->Range(1, 1 << 10); + +void BM_NestedComprehension_Trace(benchmark::State& state) { + Expr expr; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kNestedListSum, &expr)); + + RuntimeOptions options = GetOptions(); + options.comprehension_max_iterations = 10000000; + options.enable_comprehension_list_append = true; + options.enable_recursive_tracing = true; + + auto runtime = StandardRuntimeOrDie(options); + + google::protobuf::Arena arena; + Activation activation; + + auto list_builder = cel::NewListValueBuilder(&arena); + + int len = state.range(0); + list_builder->Reserve(len); + for (int i = 0; i < len; i++) { + ASSERT_THAT(list_builder->Add(IntValue(1)), IsOk()); + } + + activation.InsertOrAssignValue("list_var", std::move(*list_builder).Build()); + + ASSERT_OK_AND_ASSIGN(auto cel_expr, + ProtobufRuntimeAdapter::CreateProgram(*runtime, expr)); + + for (auto _ : state) { + ASSERT_OK_AND_ASSIGN(cel::Value result, + cel_expr->Trace(&arena, activation, &EmptyCallback)); + ASSERT_TRUE(InstanceOf(result)); + ASSERT_EQ(Cast(result), len * len); + } +} + +BENCHMARK(BM_NestedComprehension_Trace)->Range(1, 1 << 10); + +void BM_ListComprehension(benchmark::State& state) { + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse("list_var.map(x, x * 2)")); + + RuntimeOptions options = GetOptions(); + options.comprehension_max_iterations = 10000000; + options.enable_comprehension_list_append = true; + auto runtime = StandardRuntimeOrDie(options); + + ASSERT_OK_AND_ASSIGN(auto cel_expr, ProtobufRuntimeAdapter::CreateProgram( + *runtime, parsed_expr)); + + google::protobuf::Arena arena; + Activation activation; + + auto list_builder = cel::NewListValueBuilder(&arena); + + int len = state.range(0); + list_builder->Reserve(len); + for (int i = 0; i < len; i++) { + ASSERT_THAT(list_builder->Add(IntValue(1)), IsOk()); + } + + activation.InsertOrAssignValue("list_var", std::move(*list_builder).Build()); + + for (auto _ : state) { + ASSERT_OK_AND_ASSIGN(cel::Value result, + cel_expr->Evaluate(&arena, activation)); + ASSERT_TRUE(InstanceOf(result)); + ASSERT_THAT(Cast(result).Size(), IsOkAndHolds(len)); + } +} + +BENCHMARK(BM_ListComprehension)->Range(1, 1 << 16); + +void BM_ListComprehension_Trace(benchmark::State& state) { + google::protobuf::Arena arena; + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse("list_var.map(x, x * 2)")); + + RuntimeOptions options = GetOptions(); + options.comprehension_max_iterations = 10000000; + options.enable_comprehension_list_append = true; + options.enable_recursive_tracing = true; + + auto runtime = StandardRuntimeOrDie(options); + ASSERT_OK_AND_ASSIGN(auto cel_expr, ProtobufRuntimeAdapter::CreateProgram( + *runtime, parsed_expr)); + + Activation activation; + + auto list_builder = cel::NewListValueBuilder(&arena); + + int len = state.range(0); + list_builder->Reserve(len); + for (int i = 0; i < len; i++) { + ASSERT_THAT(list_builder->Add(IntValue(1)), IsOk()); + } + + activation.InsertOrAssignValue("list_var", std::move(*list_builder).Build()); + + for (auto _ : state) { + ASSERT_OK_AND_ASSIGN(cel::Value result, + cel_expr->Trace(&arena, activation, EmptyCallback)); + ASSERT_TRUE(InstanceOf(result)); + ASSERT_THAT(Cast(result).Size(), IsOkAndHolds(len)); + } +} + +BENCHMARK(BM_ListComprehension_Trace)->Range(1, 1 << 16); + +void BM_ExistsComprehensionBestCase(benchmark::State& state) { + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, + Parse("my_int_list.exists(x, x == 1)")); + + RuntimeOptions options = GetOptions(); + auto runtime = StandardRuntimeOrDie(options); + + ASSERT_OK_AND_ASSIGN(auto cel_expr, ProtobufRuntimeAdapter::CreateProgram( + *runtime, parsed_expr)); + + google::protobuf::Arena arena; + Activation activation; + + auto list_builder = cel::NewListValueBuilder(&arena); + + ASSERT_THAT(list_builder->Add(IntValue(1)), IsOk()); + + activation.InsertOrAssignValue("my_int_list", + std::move(*list_builder).Build()); + + for (auto _ : state) { + ASSERT_OK_AND_ASSIGN(cel::Value result, + cel_expr->Evaluate(&arena, activation)); + ASSERT_TRUE(result.IsBool()); + ASSERT_TRUE(result.GetBool().NativeValue()); + } +} + +BENCHMARK(BM_ExistsComprehensionBestCase); + +void BM_ExistsComprehensionWorstCase(benchmark::State& state) { + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, + Parse("my_int_list.exists(x, x == -1)")); + + RuntimeOptions options = GetOptions(); + auto runtime = StandardRuntimeOrDie(options); + + ASSERT_OK_AND_ASSIGN(auto cel_expr, ProtobufRuntimeAdapter::CreateProgram( + *runtime, parsed_expr)); + + google::protobuf::Arena arena; + Activation activation; + + auto list_builder = cel::NewListValueBuilder(&arena); + int len = state.range(0); + list_builder->Reserve(len); + + for (int i = 0; i < len; i++) { + ASSERT_THAT(list_builder->Add(IntValue(i)), IsOk()); + } + + activation.InsertOrAssignValue("my_int_list", + std::move(*list_builder).Build()); + + for (auto _ : state) { + ASSERT_OK_AND_ASSIGN(cel::Value result, + cel_expr->Evaluate(&arena, activation)); + ASSERT_TRUE(result.IsBool()); + ASSERT_FALSE(result.GetBool().NativeValue()); + } +} + +BENCHMARK(BM_ExistsComprehensionWorstCase)->Range(1, 1 << 10); + +void BM_AllComprehensionBestCase(benchmark::State& state) { + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, + Parse("my_int_list.exists(x, x != 1)")); + + RuntimeOptions options = GetOptions(); + auto runtime = StandardRuntimeOrDie(options); + + ASSERT_OK_AND_ASSIGN(auto cel_expr, ProtobufRuntimeAdapter::CreateProgram( + *runtime, parsed_expr)); + + google::protobuf::Arena arena; + Activation activation; + + auto list_builder = cel::NewListValueBuilder(&arena); + + ASSERT_THAT(list_builder->Add(IntValue(1)), IsOk()); + + activation.InsertOrAssignValue("my_int_list", + std::move(*list_builder).Build()); + + for (auto _ : state) { + ASSERT_OK_AND_ASSIGN(cel::Value result, + cel_expr->Evaluate(&arena, activation)); + ASSERT_TRUE(result.IsBool()); + ASSERT_FALSE(result.GetBool().NativeValue()); + } +} + +BENCHMARK(BM_AllComprehensionBestCase); + +void BM_AllComprehensionWorstCase(benchmark::State& state) { + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, + Parse("my_int_list.all(x, x != -1)")); + + RuntimeOptions options = GetOptions(); + auto runtime = StandardRuntimeOrDie(options); + + ASSERT_OK_AND_ASSIGN(auto cel_expr, ProtobufRuntimeAdapter::CreateProgram( + *runtime, parsed_expr)); + + google::protobuf::Arena arena; + Activation activation; + + auto list_builder = cel::NewListValueBuilder(&arena); + int len = state.range(0); + list_builder->Reserve(len); + + for (int i = 0; i < len; i++) { + ASSERT_THAT(list_builder->Add(IntValue(i)), IsOk()); + } + + activation.InsertOrAssignValue("my_int_list", + std::move(*list_builder).Build()); + + for (auto _ : state) { + ASSERT_OK_AND_ASSIGN(cel::Value result, + cel_expr->Evaluate(&arena, activation)); + ASSERT_TRUE(result.IsBool()); + ASSERT_TRUE(result.GetBool().NativeValue()); + } +} + +BENCHMARK(BM_AllComprehensionWorstCase)->Range(1, 1 << 10); + +void BM_ListComprehension_Opt(benchmark::State& state) { + google::protobuf::Arena arena; + + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse("list_var.map(x, x * 2)")); + + RuntimeOptions options = GetOptions(); + options.comprehension_max_iterations = 10000000; + options.enable_comprehension_list_append = true; + auto runtime = + StandardRuntimeOrDie(options, &arena, ConstFoldingEnabled::kYes); + + Activation activation; + + auto list_builder = cel::NewListValueBuilder(&arena); + + int len = state.range(0); + list_builder->Reserve(len); + for (int i = 0; i < len; i++) { + ASSERT_THAT(list_builder->Add(IntValue(1)), IsOk()); + } + + activation.InsertOrAssignValue("list_var", std::move(*list_builder).Build()); + + ASSERT_OK_AND_ASSIGN(auto cel_expr, ProtobufRuntimeAdapter::CreateProgram( + *runtime, parsed_expr)); + + for (auto _ : state) { + ASSERT_OK_AND_ASSIGN(cel::Value result, + cel_expr->Evaluate(&arena, activation)); + ASSERT_TRUE(InstanceOf(result)); + ASSERT_THAT(Cast(result).Size(), IsOkAndHolds(len)); + } +} + +BENCHMARK(BM_ListComprehension_Opt)->Range(1, 1 << 16); + +void BM_ComprehensionCpp(benchmark::State& state) { + Activation activation; + + std::vector list; + + int len = state.range(0); + list.reserve(len); + for (int i = 0; i < len; i++) { + list.push_back(IntValue(1)); + } + + auto op = [&list]() { + int sum = 0; + for (const auto& value : list) { + sum += Cast(value).NativeValue(); + } + return sum; + }; + for (auto _ : state) { + int result = op(); + ASSERT_EQ(result, len); + } +} + +BENCHMARK(BM_ComprehensionCpp)->Range(1, 1 << 20); + +void BM_MapTransformComprehension(benchmark::State& state) { + ASSERT_OK_AND_ASSIGN(auto source, + NewSource("map_var.transformMapEntry(k, v, {v:k})")); + + MacroRegistry registry; + ASSERT_THAT( + extensions::RegisterComprehensionsV2Macros(registry, ParserOptions()), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto parsed_expr, + EnrichedParse(*source, registry, ParserOptions())); + + RuntimeOptions options = GetOptions(); + options.comprehension_max_iterations = 10000000; + + // This is a critical optimization: it allows the comprehension to accumulate + // results in a mutable map instead of cloning and augmenting an unmodifiable + // map on every iteration. + options.enable_comprehension_mutable_map = true; + + ASSERT_OK_AND_ASSIGN(auto builder, + CreateStandardRuntimeBuilder( + internal::GetTestingDescriptorPool(), options)); + + ASSERT_THAT(extensions::RegisterComprehensionsV2Functions( + builder.function_registry(), options), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + google::protobuf::Arena arena; + Activation activation; + + auto map_builder = cel::NewMapValueBuilder(&arena); + + int len = state.range(0); + map_builder->Reserve(len); + for (int i = 0; i < len; i++) { + ASSERT_THAT(map_builder->Put(IntValue(i), IntValue(i)), IsOk()); + } + + activation.InsertOrAssignValue("map_var", std::move(*map_builder).Build()); + + ASSERT_OK_AND_ASSIGN(auto cel_expr, ProtobufRuntimeAdapter::CreateProgram( + *runtime, parsed_expr.parsed_expr())); + + for (auto _ : state) { + ASSERT_OK_AND_ASSIGN(cel::Value result, + cel_expr->Evaluate(&arena, activation)); + ASSERT_TRUE(InstanceOf(result)); + ASSERT_THAT(Cast(result).Size(), IsOkAndHolds(len)); + } +} + +BENCHMARK(BM_MapTransformComprehension)->Range(1, 1 << 16); + +} // namespace + +} // namespace cel diff --git a/eval/tests/request_context.proto b/eval/tests/request_context.proto index 9e771cfed..446cd2df2 100644 --- a/eval/tests/request_context.proto +++ b/eval/tests/request_context.proto @@ -6,7 +6,22 @@ option cc_enable_arenas = true; // Message representing a sample request context message RequestContext { + // Example for deeply nested messages. + message D { + bool e = 1; + } + message C { + D d = 1; + } + message B { + C c = 1; + } + message A { + B b = 1; + } string ip = 1; string path = 2; string token = 3; + map headers = 4; + A a = 5; } diff --git a/eval/tests/unknowns_end_to_end_test.cc b/eval/tests/unknowns_end_to_end_test.cc index 672846534..71ffe652c 100644 --- a/eval/tests/unknowns_end_to_end_test.cc +++ b/eval/tests/unknowns_end_to_end_test.cc @@ -4,16 +4,21 @@ // the unknowns is particular to the runtime. #include +#include +#include +#include +#include "cel/expr/syntax.pb.h" #include "google/protobuf/struct.pb.h" -#include "google/protobuf/arena.h" -#include "google/protobuf/text_format.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "absl/container/btree_map.h" +#include "absl/base/no_destructor.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" -#include "eval/eval/evaluator_core.h" +#include "base/attribute.h" +#include "base/function_result.h" #include "eval/public/activation.h" #include "eval/public/builtin_func_registrar.h" #include "eval/public/cel_attribute.h" @@ -22,11 +27,16 @@ #include "eval/public/cel_function.h" #include "eval/public/cel_options.h" #include "eval/public/cel_value.h" -#include "eval/public/containers/container_backed_list_impl.h" #include "eval/public/containers/container_backed_map_impl.h" #include "eval/public/structs/cel_proto_wrapper.h" #include "eval/public/unknown_set.h" -#include "base/status_macros.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "parser/parser.h" +#include "runtime/internal/activation_attribute_matcher_access.h" +#include "runtime/internal/attribute_matcher.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/text_format.h" namespace google { namespace api { @@ -34,76 +44,36 @@ namespace expr { namespace runtime { namespace { -using google::api::expr::v1alpha1::Expr; +using ::absl_testing::IsOk; +using ::cel::runtime_internal::ActivationAttributeMatcherAccess; +using ::cel::expr::Expr; +using ::cel::expr::ParsedExpr; +using ::google::api::expr::parser::Parse; using ::google::protobuf::Arena; -using testing::ElementsAre; - -// var1 > 3 && F1('arg1') || var2 > 3 && F2('arg2') -constexpr char kExprTextproto[] = R"pb( - id: 13 - call_expr { - function: "_||_" - args { - id: 6 - call_expr { - function: "_&&_" - args { - id: 2 - call_expr { - function: "_>_" - args { - id: 1 - ident_expr { name: "var1" } - } - args { - id: 3 - const_expr { int64_value: 3 } - } - } - } - args { - id: 4 - call_expr { - function: "F1" - args { - id: 5 - const_expr { string_value: "arg1" } - } - } - } - } - } - args { - id: 12 - call_expr { - function: "_&&_" - args { - id: 8 - call_expr { - function: "_>_" - args { - id: 7 - ident_expr { name: "var2" } - } - args { - id: 9 - const_expr { int64_value: 3 } - } - } - } - args { - id: 10 - call_expr { - function: "F2" - args { - id: 11 - const_expr { string_value: "arg2" } - } - } - } - } - } - })pb"; +using ::testing::ElementsAre; +using ::testing::UnorderedElementsAre; + +absl::StatusOr MakeCelMap(absl::string_view expr, + google::protobuf::Arena* arena) { + static CelExpressionBuilder* builder = []() { + return CreateCelExpressionBuilder(InterpreterOptions()).release(); + }(); + static absl::NoDestructor activation; + + CEL_ASSIGN_OR_RETURN(ParsedExpr parsed_expr, Parse(expr)); + + CEL_ASSIGN_OR_RETURN(auto plan, + builder->CreateExpression(&parsed_expr.expr(), nullptr)); + absl::StatusOr result = plan->Evaluate(*activation, arena); + if (!result.ok()) { + return result.status(); + } + if (!result->IsMap()) { + return absl::FailedPreconditionError( + absl::StrCat("expression did not evaluate to a map: ", expr)); + } + return result; +} enum class FunctionResponse { kUnknown, kTrue, kFalse }; @@ -146,32 +116,29 @@ class UnknownsTest : public testing::Test { InterpreterOptions options; options.unknown_processing = opts; builder_ = CreateCelExpressionBuilder(options); - ASSERT_OK(RegisterBuiltinFunctions(builder_->GetRegistry())); - ASSERT_OK( - builder_->GetRegistry()->RegisterLazyFunction(CreateDescriptor("F1"))); - ASSERT_OK( - builder_->GetRegistry()->RegisterLazyFunction(CreateDescriptor("F2"))); - ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kExprTextproto, &expr_)) - << "error parsing expr"; + ASSERT_THAT(RegisterBuiltinFunctions(builder_->GetRegistry()), IsOk()); + ASSERT_THAT( + builder_->GetRegistry()->RegisterLazyFunction(CreateDescriptor("F1")), + IsOk()); + ASSERT_THAT( + builder_->GetRegistry()->RegisterLazyFunction(CreateDescriptor("F2")), + IsOk()); } protected: Arena arena_; Activation activation_; std::unique_ptr builder_; - google::api::expr::v1alpha1::Expr expr_; }; -MATCHER_P2(FunctionCallIs, fn_name, fn_arg, "") { - const UnknownFunctionResult* result = arg; - return result->arguments().size() == 1 && result->arguments()[0].IsString() && - result->arguments()[0].StringOrDie().value() == fn_arg && - result->descriptor().name() == fn_name; +MATCHER_P(FunctionCallIs, fn_name, "") { + const cel::FunctionResult& result = arg; + return result.descriptor().name() == fn_name; } MATCHER_P(AttributeIs, attr, "") { - const CelAttribute* result = arg; - return result->variable().ident_expr().name() == attr; + const cel::Attribute& result = arg; + return result.AsString().value_or("") == attr; } TEST_F(UnknownsTest, NoUnknowns) { @@ -179,20 +146,23 @@ TEST_F(UnknownsTest, NoUnknowns) { activation_.InsertValue("var1", CelValue::CreateInt64(3)); activation_.InsertValue("var2", CelValue::CreateInt64(5)); - ASSERT_OK(activation_.InsertFunction( - std::make_unique("F1", FunctionResponse::kFalse))); - ASSERT_OK(activation_.InsertFunction( - std::make_unique("F2", FunctionResponse::kTrue))); - - // var1 > 3 && F1('arg1') || var2 > 3 && F2('arg2') - auto plan = builder_->CreateExpression(&expr_, nullptr); - ASSERT_OK(plan); - - auto maybe_response = plan.value()->Evaluate(activation_, &arena_); - ASSERT_OK(maybe_response); - CelValue response = maybe_response.value(); - - ASSERT_TRUE(response.IsBool()); + ASSERT_THAT(activation_.InsertFunction(std::make_unique( + "F1", FunctionResponse::kFalse)), + IsOk()); + ASSERT_THAT(activation_.InsertFunction(std::make_unique( + "F2", FunctionResponse::kTrue)), + IsOk()); + + ASSERT_OK_AND_ASSIGN( + ParsedExpr expr, + Parse("var1 > 3 && F1('arg1') || var2 > 3 && F2('arg2')")); + auto plan = builder_->CreateExpression(&expr.expr(), nullptr); + ASSERT_THAT(plan, IsOk()); + + ASSERT_OK_AND_ASSIGN(CelValue response, + plan.value()->Evaluate(activation_, &arena_)); + + ASSERT_TRUE(response.IsBool()) << response.DebugString(); EXPECT_TRUE(response.BoolOrDie()); } @@ -200,21 +170,24 @@ TEST_F(UnknownsTest, UnknownAttributes) { PrepareBuilder(UnknownProcessingOptions::kAttributeOnly); activation_.set_unknown_attribute_patterns({CelAttributePattern("var1", {})}); activation_.InsertValue("var2", CelValue::CreateInt64(3)); - ASSERT_OK(activation_.InsertFunction( - std::make_unique("F1", FunctionResponse::kTrue))); - ASSERT_OK(activation_.InsertFunction( - std::make_unique("F2", FunctionResponse::kFalse))); - - // var1 > 3 && F1('arg1') || var2 > 3 && F2('arg2') - auto plan = builder_->CreateExpression(&expr_, nullptr); - ASSERT_OK(plan); - - auto maybe_response = plan.value()->Evaluate(activation_, &arena_); - ASSERT_OK(maybe_response); - CelValue response = maybe_response.value(); + ASSERT_THAT(activation_.InsertFunction(std::make_unique( + "F1", FunctionResponse::kTrue)), + IsOk()); + ASSERT_THAT(activation_.InsertFunction(std::make_unique( + "F2", FunctionResponse::kFalse)), + IsOk()); + + ASSERT_OK_AND_ASSIGN( + ParsedExpr expr, + Parse("var1 > 3 && F1('arg1') || var2 > 3 && F2('arg2')")); + auto plan = builder_->CreateExpression(&expr.expr(), nullptr); + ASSERT_THAT(plan, IsOk()); + + ASSERT_OK_AND_ASSIGN(CelValue response, + plan.value()->Evaluate(activation_, &arena_)); ASSERT_TRUE(response.IsUnknownSet()); - EXPECT_THAT(response.UnknownSetOrDie()->unknown_attributes().attributes(), + EXPECT_THAT(response.UnknownSetOrDie()->unknown_attributes(), ElementsAre(AttributeIs("var1"))); } @@ -222,39 +195,88 @@ TEST_F(UnknownsTest, UnknownAttributesPruning) { PrepareBuilder(UnknownProcessingOptions::kAttributeOnly); activation_.set_unknown_attribute_patterns({CelAttributePattern("var1", {})}); activation_.InsertValue("var2", CelValue::CreateInt64(5)); - ASSERT_OK(activation_.InsertFunction( - std::make_unique("F1", FunctionResponse::kTrue))); - ASSERT_OK(activation_.InsertFunction( - std::make_unique("F2", FunctionResponse::kTrue))); - - // var1 > 3 && F1('arg1') || var2 > 3 && F2('arg2') - auto plan = builder_->CreateExpression(&expr_, nullptr); - ASSERT_OK(plan); - - auto maybe_response = plan.value()->Evaluate(activation_, &arena_); - ASSERT_OK(maybe_response); - CelValue response = maybe_response.value(); + ASSERT_THAT(activation_.InsertFunction(std::make_unique( + "F1", FunctionResponse::kTrue)), + IsOk()); + ASSERT_THAT(activation_.InsertFunction(std::make_unique( + "F2", FunctionResponse::kTrue)), + IsOk()); + + ASSERT_OK_AND_ASSIGN( + ParsedExpr expr, + Parse("var1 > 3 && F1('arg1') || var2 > 3 && F2('arg2')")); + auto plan = builder_->CreateExpression(&expr.expr(), nullptr); + ASSERT_THAT(plan, IsOk()); + + ASSERT_OK_AND_ASSIGN(CelValue response, + plan.value()->Evaluate(activation_, &arena_)); ASSERT_TRUE(response.IsBool()); EXPECT_TRUE(response.BoolOrDie()); } +class CustomMatcher : public cel::runtime_internal::AttributeMatcher { + public: + MatchResult CheckForUnknown(const cel::Attribute& attr) const override { + // Rendering to a string just for ease of testing. + std::string name = attr.AsString().value_or(""); + if (name == "var1") { + return MatchResult::PARTIAL; + } else if (name == "var1.foo") { + return MatchResult::FULL; + } + return MatchResult::NONE; + } +}; + +TEST_F(UnknownsTest, UnknownAttributesCustomMatcher) { + PrepareBuilder(UnknownProcessingOptions::kAttributeOnly); + + ASSERT_OK_AND_ASSIGN(auto var1, MakeCelMap("{'bar': 1}", &arena_)); + activation_.InsertValue("var1", var1); + CustomMatcher matcher; + ActivationAttributeMatcherAccess::SetAttributeMatcher(activation_, &matcher); + + ASSERT_THAT(activation_.InsertFunction(std::make_unique( + "F1", FunctionResponse::kTrue, CelValue::Type::kMap)), + IsOk()); + ASSERT_THAT(activation_.InsertFunction(std::make_unique( + "F2", FunctionResponse::kTrue)), + IsOk()); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, + Parse("F1(var1) || var1.foo || var1.bar")); + auto plan = builder_->CreateExpression(&expr.expr(), nullptr); + ASSERT_THAT(plan, IsOk()); + + ASSERT_OK_AND_ASSIGN(CelValue response, + plan.value()->Evaluate(activation_, &arena_)); + + ASSERT_TRUE(response.IsUnknownSet()) << response.DebugString(); + EXPECT_THAT( + response.UnknownSetOrDie()->unknown_attributes(), + UnorderedElementsAre(AttributeIs("var1"), AttributeIs("var1.foo"))); +} + TEST_F(UnknownsTest, UnknownFunctionsWithoutOptionError) { PrepareBuilder(UnknownProcessingOptions::kAttributeOnly); activation_.InsertValue("var1", CelValue::CreateInt64(5)); activation_.InsertValue("var2", CelValue::CreateInt64(3)); - ASSERT_OK(activation_.InsertFunction( - std::make_unique("F1", FunctionResponse::kUnknown))); - ASSERT_OK(activation_.InsertFunction( - std::make_unique("F2", FunctionResponse::kFalse))); - - // var1 > 3 && F1('arg1') || var2 > 3 && F2('arg2') - auto plan = builder_->CreateExpression(&expr_, nullptr); - ASSERT_OK(plan); - - auto maybe_response = plan.value()->Evaluate(activation_, &arena_); - ASSERT_OK(maybe_response); - CelValue response = maybe_response.value(); + ASSERT_THAT(activation_.InsertFunction(std::make_unique( + "F1", FunctionResponse::kUnknown)), + IsOk()); + ASSERT_THAT(activation_.InsertFunction(std::make_unique( + "F2", FunctionResponse::kFalse)), + IsOk()); + + ASSERT_OK_AND_ASSIGN( + ParsedExpr expr, + Parse("var1 > 3 && F1('arg1') || var2 > 3 && F2('arg2')")); + auto plan = builder_->CreateExpression(&expr.expr(), nullptr); + ASSERT_THAT(plan, IsOk()); + + ASSERT_OK_AND_ASSIGN(CelValue response, + plan.value()->Evaluate(activation_, &arena_)); ASSERT_TRUE(response.IsError()); EXPECT_EQ(response.ErrorOrDie()->code(), absl::StatusCode::kUnavailable); @@ -264,24 +286,25 @@ TEST_F(UnknownsTest, UnknownFunctions) { PrepareBuilder(UnknownProcessingOptions::kAttributeAndFunction); activation_.InsertValue("var1", CelValue::CreateInt64(5)); activation_.InsertValue("var2", CelValue::CreateInt64(5)); - ASSERT_OK(activation_.InsertFunction( - std::make_unique("F1", FunctionResponse::kUnknown))); - ASSERT_OK(activation_.InsertFunction( - std::make_unique("F2", FunctionResponse::kFalse))); - - // var1 > 3 && F1('arg1') || var2 > 3 && F2('arg2') - auto plan = builder_->CreateExpression(&expr_, nullptr); - ASSERT_OK(plan); - - auto maybe_response = plan.value()->Evaluate(activation_, &arena_); - ASSERT_OK(maybe_response); - CelValue response = maybe_response.value(); - - ASSERT_TRUE(response.IsUnknownSet()) << response.ErrorOrDie()->ToString(); - EXPECT_THAT(response.UnknownSetOrDie() - ->unknown_function_results() - .unknown_function_results(), - ElementsAre(FunctionCallIs("F1", "arg1"))); + ASSERT_THAT(activation_.InsertFunction(std::make_unique( + "F1", FunctionResponse::kUnknown)), + IsOk()); + ASSERT_THAT(activation_.InsertFunction(std::make_unique( + "F2", FunctionResponse::kFalse)), + IsOk()); + + ASSERT_OK_AND_ASSIGN( + ParsedExpr expr, + Parse("var1 > 3 && F1('arg1') || var2 > 3 && F2('arg2')")); + auto plan = builder_->CreateExpression(&expr.expr(), nullptr); + ASSERT_THAT(plan, IsOk()); + + ASSERT_OK_AND_ASSIGN(CelValue response, + plan.value()->Evaluate(activation_, &arena_)); + + ASSERT_TRUE(response.IsUnknownSet()) << *response.ErrorOrDie(); + EXPECT_THAT(response.UnknownSetOrDie()->unknown_function_results(), + ElementsAre(FunctionCallIs("F1"))); } TEST_F(UnknownsTest, UnknownsMerge) { @@ -289,25 +312,26 @@ TEST_F(UnknownsTest, UnknownsMerge) { activation_.InsertValue("var1", CelValue::CreateInt64(5)); activation_.set_unknown_attribute_patterns({CelAttributePattern("var2", {})}); - ASSERT_OK(activation_.InsertFunction( - std::make_unique("F1", FunctionResponse::kUnknown))); - ASSERT_OK(activation_.InsertFunction( - std::make_unique("F2", FunctionResponse::kTrue))); - - // var1 > 3 && F1('arg1') || var2 > 3 && F2('arg2') - auto plan = builder_->CreateExpression(&expr_, nullptr); - ASSERT_OK(plan); - - auto maybe_response = plan.value()->Evaluate(activation_, &arena_); - ASSERT_OK(maybe_response); - CelValue response = maybe_response.value(); - - ASSERT_TRUE(response.IsUnknownSet()) << response.ErrorOrDie()->ToString(); - EXPECT_THAT(response.UnknownSetOrDie() - ->unknown_function_results() - .unknown_function_results(), - ElementsAre(FunctionCallIs("F1", "arg1"))); - EXPECT_THAT(response.UnknownSetOrDie()->unknown_attributes().attributes(), + ASSERT_THAT(activation_.InsertFunction(std::make_unique( + "F1", FunctionResponse::kUnknown)), + IsOk()); + ASSERT_THAT(activation_.InsertFunction(std::make_unique( + "F2", FunctionResponse::kTrue)), + IsOk()); + + ASSERT_OK_AND_ASSIGN( + ParsedExpr expr, + Parse("var1 > 3 && F1('arg1') || var2 > 3 && F2('arg2')")); + auto plan = builder_->CreateExpression(&expr.expr(), nullptr); + ASSERT_THAT(plan, IsOk()); + + ASSERT_OK_AND_ASSIGN(CelValue response, + plan.value()->Evaluate(activation_, &arena_)); + + ASSERT_TRUE(response.IsUnknownSet()) << *response.ErrorOrDie(); + EXPECT_THAT(response.UnknownSetOrDie()->unknown_function_results(), + ElementsAre(FunctionCallIs("F1"))); + EXPECT_THAT(response.UnknownSetOrDie()->unknown_attributes(), ElementsAre(AttributeIs("var2"))); } @@ -425,9 +449,10 @@ class UnknownsCompTest : public testing::Test { InterpreterOptions options; options.unknown_processing = opts; builder_ = CreateCelExpressionBuilder(options); - ASSERT_OK(RegisterBuiltinFunctions(builder_->GetRegistry())); - ASSERT_OK(builder_->GetRegistry()->RegisterLazyFunction( - CreateDescriptor("Fn", CelValue::Type::kInt64))); + ASSERT_THAT(RegisterBuiltinFunctions(builder_->GetRegistry()), IsOk()); + ASSERT_THAT(builder_->GetRegistry()->RegisterLazyFunction( + CreateDescriptor("Fn", CelValue::Type::kInt64)), + IsOk()); ASSERT_TRUE( google::protobuf::TextFormat::ParseFromString(kListCompExistsExpr, &expr_)) << "error parsing expr"; @@ -443,22 +468,21 @@ class UnknownsCompTest : public testing::Test { TEST_F(UnknownsCompTest, UnknownsMerge) { PrepareBuilder(UnknownProcessingOptions::kAttributeAndFunction); - ASSERT_OK(activation_.InsertFunction(std::make_unique( - "Fn", FunctionResponse::kUnknown, CelValue::Type::kInt64))); + ASSERT_THAT(activation_.InsertFunction(std::make_unique( + "Fn", FunctionResponse::kUnknown, CelValue::Type::kInt64)), + IsOk()); // [1, 2, 3, 4, 5, 6, 7, 8, 9, 10].exists(x, Fn(x) > 5) auto build_status = builder_->CreateExpression(&expr_, nullptr); - ASSERT_OK(build_status); + ASSERT_THAT(build_status, IsOk()); auto eval_status = build_status.value()->Evaluate(activation_, &arena_); - ASSERT_OK(eval_status); + ASSERT_THAT(eval_status, IsOk()); CelValue response = eval_status.value(); - ASSERT_TRUE(response.IsUnknownSet()) << response.ErrorOrDie()->ToString(); - EXPECT_THAT(response.UnknownSetOrDie() - ->unknown_function_results() - .unknown_function_results(), - testing::SizeIs(10)); + ASSERT_TRUE(response.IsUnknownSet()) << *response.ErrorOrDie(); + EXPECT_THAT(response.UnknownSetOrDie()->unknown_function_results(), + testing::SizeIs(1)); } constexpr char kListCompCondExpr[] = R"pb( @@ -561,9 +585,10 @@ class UnknownsCompCondTest : public testing::Test { InterpreterOptions options; options.unknown_processing = opts; builder_ = CreateCelExpressionBuilder(options); - ASSERT_OK(RegisterBuiltinFunctions(builder_->GetRegistry())); - ASSERT_OK(builder_->GetRegistry()->RegisterLazyFunction( - CreateDescriptor("Fn", CelValue::Type::kInt64))); + ASSERT_THAT(RegisterBuiltinFunctions(builder_->GetRegistry()), IsOk()); + ASSERT_THAT(builder_->GetRegistry()->RegisterLazyFunction( + CreateDescriptor("Fn", CelValue::Type::kInt64)), + IsOk()); ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kListCompCondExpr, &expr_)) << "error parsing expr"; } @@ -578,37 +603,36 @@ class UnknownsCompCondTest : public testing::Test { TEST_F(UnknownsCompCondTest, UnknownConditionReturned) { PrepareBuilder(UnknownProcessingOptions::kAttributeAndFunction); - ASSERT_OK(activation_.InsertFunction(std::make_unique( - "Fn", FunctionResponse::kUnknown, CelValue::Type::kInt64))); + ASSERT_THAT(activation_.InsertFunction(std::make_unique( + "Fn", FunctionResponse::kUnknown, CelValue::Type::kInt64)), + IsOk()); // [1, 2, 3].exists_one(x, Fn(x)) auto build_status = builder_->CreateExpression(&expr_, nullptr); - ASSERT_OK(build_status); + ASSERT_THAT(build_status, IsOk()); auto eval_status = build_status.value()->Evaluate(activation_, &arena_); - ASSERT_OK(eval_status); + ASSERT_THAT(eval_status, IsOk()); CelValue response = eval_status.value(); - ASSERT_TRUE(response.IsUnknownSet()) << response.ErrorOrDie()->ToString(); + ASSERT_TRUE(response.IsUnknownSet()) << *response.ErrorOrDie(); // The comprehension ends on the first non-bool condition, so we only get one // call captured in the UnknownSet. - EXPECT_THAT(response.UnknownSetOrDie() - ->unknown_function_results() - .unknown_function_results(), + EXPECT_THAT(response.UnknownSetOrDie()->unknown_function_results(), testing::SizeIs(1)); } TEST_F(UnknownsCompCondTest, ErrorConditionReturned) { PrepareBuilder(UnknownProcessingOptions::kAttributeAndFunction); - // No implementation for Fn(int64_t) provided in activation -- this turns into a + // No implementation for Fn(int64) provided in activation -- this turns into a // CelError. // [1, 2, 3].exists_one(x, Fn(x)) auto build_status = builder_->CreateExpression(&expr_, nullptr); - ASSERT_OK(build_status); + ASSERT_THAT(build_status, IsOk()); auto eval_status = build_status.value()->Evaluate(activation_, &arena_); - ASSERT_OK(eval_status); + ASSERT_THAT(eval_status, IsOk()); CelValue response = eval_status.value(); ASSERT_TRUE(response.IsError()) << CelValue::TypeName(response.type()); @@ -687,9 +711,10 @@ TEST(UnknownsIterAttrTest, IterAttributeTrail) { options.unknown_processing = UnknownProcessingOptions::kAttributeAndFunction; auto builder = CreateCelExpressionBuilder(options); - ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); - ASSERT_OK(builder->GetRegistry()->RegisterLazyFunction( - CreateDescriptor("Fn", CelValue::Type::kMap))); + ASSERT_THAT(RegisterBuiltinFunctions(builder->GetRegistry()), IsOk()); + ASSERT_THAT(builder->GetRegistry()->RegisterLazyFunction( + CreateDescriptor("Fn", CelValue::Type::kMap)), + IsOk()); ASSERT_TRUE( google::protobuf::TextFormat::ParseFromString(kListCompExistsWithAttrExpr, &expr)) << "error parsing expr"; @@ -702,24 +727,24 @@ TEST(UnknownsIterAttrTest, IterAttributeTrail) { // var[1]['elem1'] is unknown activation.set_unknown_attribute_patterns({CelAttributePattern( "var", { - CelAttributeQualifierPattern::Create(CelValue::CreateInt64(1)), - CelAttributeQualifierPattern::Create( + CreateCelAttributeQualifierPattern(CelValue::CreateInt64(1)), + CreateCelAttributeQualifierPattern( CelValue::CreateStringView("elem1")), })}); - ASSERT_OK(activation.InsertFunction(std::make_unique( - "Fn", FunctionResponse::kFalse, CelValue::Type::kMap))); + ASSERT_THAT(activation.InsertFunction(std::make_unique( + "Fn", FunctionResponse::kFalse, CelValue::Type::kMap)), + IsOk()); CelValue response = plan->Evaluate(activation, &arena).value(); ASSERT_TRUE(response.IsUnknownSet()) << CelValue::TypeName(response.type()); - ASSERT_EQ( - response.UnknownSetOrDie()->unknown_attributes().attributes().size(), 1); + ASSERT_EQ(response.UnknownSetOrDie()->unknown_attributes().size(), 1); // 'var[1]' is partially unknown when we make the function call so we treat it // as unknown. ASSERT_EQ(response.UnknownSetOrDie() ->unknown_attributes() - .attributes()[0] + .begin() ->qualifier_path() .size(), 1); @@ -732,7 +757,7 @@ TEST(UnknownsIterAttrTest, IterAttributeTrailMapKeyTypes) { Arena arena; UnknownSet unknown_set; - CelError error; + CelError error = absl::CancelledError(); std::vector> backing; @@ -742,13 +767,14 @@ TEST(UnknownsIterAttrTest, IterAttributeTrailMapKeyTypes) { {CelValue::CreateError(&error), CelValue::CreateBool(false)}); backing.push_back({CelValue::CreateBool(true), CelValue::CreateBool(false)}); - auto map_impl = CreateContainerBackedMap(absl::MakeSpan(backing)); + auto map_impl = CreateContainerBackedMap(absl::MakeSpan(backing)).value(); options.unknown_processing = UnknownProcessingOptions::kAttributeAndFunction; auto builder = CreateCelExpressionBuilder(options); - ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); - ASSERT_OK(builder->GetRegistry()->RegisterLazyFunction( - CreateDescriptor("Fn", CelValue::Type::kBool))); + ASSERT_THAT(RegisterBuiltinFunctions(builder->GetRegistry()), IsOk()); + ASSERT_THAT(builder->GetRegistry()->RegisterLazyFunction( + CreateDescriptor("Fn", CelValue::Type::kBool)), + IsOk()); ASSERT_TRUE( google::protobuf::TextFormat::ParseFromString(kListCompExistsWithAttrExpr, &expr)) << "error parsing expr"; @@ -758,13 +784,14 @@ TEST(UnknownsIterAttrTest, IterAttributeTrailMapKeyTypes) { activation.InsertValue("var", CelValue::CreateMap(map_impl.get())); - ASSERT_OK(activation.InsertFunction(std::make_unique( - "Fn", FunctionResponse::kFalse, CelValue::Type::kBool))); + ASSERT_THAT(activation.InsertFunction(std::make_unique( + "Fn", FunctionResponse::kFalse, CelValue::Type::kBool)), + IsOk()); CelValue response = plan->Evaluate(activation, &arena).value(); ASSERT_TRUE(response.IsUnknownSet()) << CelValue::TypeName(response.type()); - ASSERT_EQ(response.UnknownSetOrDie(), &unknown_set); + ASSERT_EQ(*response.UnknownSetOrDie(), unknown_set); } TEST(UnknownsIterAttrTest, IterAttributeTrailMapKeyTypesShortcutted) { @@ -774,7 +801,7 @@ TEST(UnknownsIterAttrTest, IterAttributeTrailMapKeyTypesShortcutted) { Arena arena; UnknownSet unknown_set; - CelError error; + CelError error = absl::CancelledError(); std::vector> backing; @@ -784,13 +811,14 @@ TEST(UnknownsIterAttrTest, IterAttributeTrailMapKeyTypesShortcutted) { {CelValue::CreateError(&error), CelValue::CreateBool(false)}); backing.push_back({CelValue::CreateBool(true), CelValue::CreateBool(false)}); - auto map_impl = CreateContainerBackedMap(absl::MakeSpan(backing)); + auto map_impl = CreateContainerBackedMap(absl::MakeSpan(backing)).value(); options.unknown_processing = UnknownProcessingOptions::kAttributeAndFunction; auto builder = CreateCelExpressionBuilder(options); - ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); - ASSERT_OK(builder->GetRegistry()->RegisterLazyFunction( - CreateDescriptor("Fn", CelValue::Type::kBool))); + ASSERT_THAT(RegisterBuiltinFunctions(builder->GetRegistry()), IsOk()); + ASSERT_THAT(builder->GetRegistry()->RegisterLazyFunction( + CreateDescriptor("Fn", CelValue::Type::kBool)), + IsOk()); ASSERT_TRUE( google::protobuf::TextFormat::ParseFromString(kListCompExistsWithAttrExpr, &expr)) << "error parsing expr"; @@ -800,8 +828,9 @@ TEST(UnknownsIterAttrTest, IterAttributeTrailMapKeyTypesShortcutted) { activation.InsertValue("var", CelValue::CreateMap(map_impl.get())); - ASSERT_OK(activation.InsertFunction(std::make_unique( - "Fn", FunctionResponse::kTrue, CelValue::Type::kBool))); + ASSERT_THAT(activation.InsertFunction(std::make_unique( + "Fn", FunctionResponse::kTrue, CelValue::Type::kBool)), + IsOk()); CelValue response = plan->Evaluate(activation, &arena).value(); ASSERT_TRUE(response.IsBool()) << CelValue::TypeName(response.type()); @@ -882,35 +911,36 @@ TEST(UnknownsIterAttrTest, IterAttributeTrailMap) { options.unknown_processing = UnknownProcessingOptions::kAttributeAndFunction; auto builder = CreateCelExpressionBuilder(options); - ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); - ASSERT_OK(builder->GetRegistry()->RegisterLazyFunction( - CreateDescriptor("Fn", CelValue::Type::kDouble))); + ASSERT_THAT(RegisterBuiltinFunctions(builder->GetRegistry()), IsOk()); + ASSERT_THAT(builder->GetRegistry()->RegisterLazyFunction( + CreateDescriptor("Fn", CelValue::Type::kDouble)), + IsOk()); ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kMapElementsComp, &expr)) << "error parsing expr"; activation.InsertValue("var", CelProtoWrapper::CreateMessage(&list, &arena)); // var[1]['key'] is unknown activation.set_unknown_attribute_patterns({CelAttributePattern( - "var", { - CelAttributeQualifierPattern::Create(CelValue::CreateInt64(1)), - CelAttributeQualifierPattern::Create( - CelValue::CreateStringView("key")), - })}); + "var", + { + CreateCelAttributeQualifierPattern(CelValue::CreateInt64(1)), + CreateCelAttributeQualifierPattern(CelValue::CreateStringView("key")), + })}); - ASSERT_OK(activation.InsertFunction(std::make_unique( - "Fn", FunctionResponse::kFalse, CelValue::Type::kDouble))); + ASSERT_THAT(activation.InsertFunction(std::make_unique( + "Fn", FunctionResponse::kFalse, CelValue::Type::kDouble)), + IsOk()); auto plan = builder->CreateExpression(&expr, nullptr).value(); CelValue response = plan->Evaluate(activation, &arena).value(); ASSERT_TRUE(response.IsUnknownSet()) << CelValue::TypeName(response.type()); - ASSERT_EQ( - response.UnknownSetOrDie()->unknown_attributes().attributes().size(), 1); + ASSERT_EQ(response.UnknownSetOrDie()->unknown_attributes().size(), 1); // 'var[1].key' is unknown when we make the Fn function call. // comprehension is: ((([] + false) + unk) + false) -> unk ASSERT_EQ(response.UnknownSetOrDie() ->unknown_attributes() - .attributes()[0] + .begin() ->qualifier_path() .size(), 2); @@ -984,6 +1014,52 @@ constexpr char kFilterElementsComp[] = R"pb( } })pb"; +TEST(UnknownsIterAttrTest, IterAttributeTrailExact) { + InterpreterOptions options; + Activation activation; + Arena arena; + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, Parse("list_var.exists(x, x)")); + + protobuf::Value element; + element.set_bool_value(false); + protobuf::ListValue list; + *list.add_values() = element; + *list.add_values() = element; + *list.add_values() = element; + + (*list.mutable_values())[0].set_bool_value(true); + + options.unknown_processing = UnknownProcessingOptions::kAttributeAndFunction; + auto builder = CreateCelExpressionBuilder(options); + ASSERT_THAT(RegisterBuiltinFunctions(builder->GetRegistry()), IsOk()); + activation.InsertValue("list_var", + CelProtoWrapper::CreateMessage(&list, &arena)); + + // list_var[0] + std::vector unknown_attribute_patterns; + unknown_attribute_patterns.push_back(CelAttributePattern( + "list_var", + {CreateCelAttributeQualifierPattern(CelValue::CreateInt64(0))})); + activation.set_unknown_attribute_patterns( + std::move(unknown_attribute_patterns)); + + ASSERT_OK_AND_ASSIGN( + auto plan, builder->CreateExpression(&expr.expr(), &expr.source_info())); + + CelValue response = plan->Evaluate(activation, &arena).value(); + + ASSERT_TRUE(response.IsUnknownSet()) << CelValue::TypeName(response.type()); + ASSERT_EQ(response.UnknownSetOrDie()->unknown_attributes().size(), 1); + + ASSERT_EQ(response.UnknownSetOrDie() + ->unknown_attributes() + .begin() + ->qualifier_path() + .size(), + 1); +} + TEST(UnknownsIterAttrTest, IterAttributeTrailFilterValues) { InterpreterOptions options; Expr expr; @@ -1005,7 +1081,7 @@ TEST(UnknownsIterAttrTest, IterAttributeTrailFilterValues) { options.unknown_processing = UnknownProcessingOptions::kAttributeAndFunction; auto builder = CreateCelExpressionBuilder(options); - ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + ASSERT_THAT(RegisterBuiltinFunctions(builder->GetRegistry()), IsOk()); ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kFilterElementsComp, &expr)) << "error parsing expr"; activation.InsertValue("var", CelProtoWrapper::CreateMessage(&list, &arena)); @@ -1013,8 +1089,8 @@ TEST(UnknownsIterAttrTest, IterAttributeTrailFilterValues) { // var[1]['value_key'] is unknown activation.set_unknown_attribute_patterns({CelAttributePattern( "var", { - CelAttributeQualifierPattern::Create(CelValue::CreateInt64(1)), - CelAttributeQualifierPattern::Create( + CreateCelAttributeQualifierPattern(CelValue::CreateInt64(1)), + CreateCelAttributeQualifierPattern( CelValue::CreateStringView("value_key")), })}); @@ -1022,13 +1098,12 @@ TEST(UnknownsIterAttrTest, IterAttributeTrailFilterValues) { CelValue response = plan->Evaluate(activation, &arena).value(); ASSERT_TRUE(response.IsUnknownSet()) << CelValue::TypeName(response.type()); - ASSERT_EQ( - response.UnknownSetOrDie()->unknown_attributes().attributes().size(), 1); + ASSERT_EQ(response.UnknownSetOrDie()->unknown_attributes().size(), 1); // 'var[1].value_key' is unknown when we make the cons function call. // comprehension is: ((([] + [1]) + unk) + [1]) -> unk ASSERT_EQ(response.UnknownSetOrDie() ->unknown_attributes() - .attributes()[0] + .begin() ->qualifier_path() .size(), 2); @@ -1055,7 +1130,7 @@ TEST(UnknownsIterAttrTest, IterAttributeTrailFilterConditions) { options.unknown_processing = UnknownProcessingOptions::kAttributeAndFunction; auto builder = CreateCelExpressionBuilder(options); - ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + ASSERT_THAT(RegisterBuiltinFunctions(builder->GetRegistry()), IsOk()); ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kFilterElementsComp, &expr)) << "error parsing expr"; activation.InsertValue("var", CelProtoWrapper::CreateMessage(&list, &arena)); @@ -1065,15 +1140,15 @@ TEST(UnknownsIterAttrTest, IterAttributeTrailFilterConditions) { {CelAttributePattern( "var", { - CelAttributeQualifierPattern::Create(CelValue::CreateInt64(1)), - CelAttributeQualifierPattern::Create( + CreateCelAttributeQualifierPattern(CelValue::CreateInt64(1)), + CreateCelAttributeQualifierPattern( CelValue::CreateStringView("filter_key")), }), CelAttributePattern( "var", { - CelAttributeQualifierPattern::Create(CelValue::CreateInt64(0)), - CelAttributeQualifierPattern::Create( + CreateCelAttributeQualifierPattern(CelValue::CreateInt64(0)), + CreateCelAttributeQualifierPattern( CelValue::CreateStringView("filter_key")), })}); @@ -1088,11 +1163,10 @@ TEST(UnknownsIterAttrTest, IterAttributeTrailFilterConditions) { // loop2: (true)? unk{1} + [1] : unk{1} -> unk{1} // result: unk{1} ASSERT_TRUE(response.IsUnknownSet()) << CelValue::TypeName(response.type()); - ASSERT_EQ( - response.UnknownSetOrDie()->unknown_attributes().attributes().size(), 1); + ASSERT_EQ(response.UnknownSetOrDie()->unknown_attributes().size(), 1); ASSERT_EQ(response.UnknownSetOrDie() ->unknown_attributes() - .attributes()[0] + .begin() ->qualifier_path() .size(), 2); diff --git a/eval/testutil/BUILD b/eval/testutil/BUILD index 268e225b1..cb35e6752 100644 --- a/eval/testutil/BUILD +++ b/eval/testutil/BUILD @@ -1,10 +1,13 @@ +load("@com_google_protobuf//bazel:cc_proto_library.bzl", "cc_proto_library") +load("@com_google_protobuf//bazel:proto_library.bzl", "proto_library") + # This package contains testing utility code package(default_visibility = ["//visibility:public"]) -licenses(["notice"]) # Apache 2.0 +licenses(["notice"]) proto_library( - name = "test_message_protos", + name = "test_message_proto", srcs = [ "test_message.proto", ], @@ -19,5 +22,18 @@ proto_library( cc_proto_library( name = "test_message_cc_proto", - deps = [":test_message_protos"], + deps = [":test_message_proto"], +) + +proto_library( + name = "test_extensions_proto", + srcs = [ + "test_extensions.proto", + ], + deps = ["@com_google_protobuf//:wrappers_proto"], +) + +cc_proto_library( + name = "test_extensions_cc_proto", + deps = [":test_extensions_proto"], ) diff --git a/eval/testutil/args.proto b/eval/testutil/args.proto deleted file mode 100644 index f4ec6991e..000000000 --- a/eval/testutil/args.proto +++ /dev/null @@ -1,47 +0,0 @@ -syntax = "proto3"; - -package google.api.expr.runtime; -option cc_enable_arenas = true; - -// Message representing errors -// during CEL evaluation. -message Argument { - oneof arg_kind { - bool bool_value = 1; - int64 int64_value = 2; - uint64 uint64_value = 3; - - float float_value = 4; - double double_value = 5; - - string string_value = 6; - bytes bytes_value = 7; - - google.protobuf.Duration duration = 8; - google.protobuf.Timestamp timestamp = 9; - } - - TestMessage message_value = 12; - - repeated int32 int32_list = 101; - repeated int64 int64_list = 102; - - repeated uint32 uint32_list = 103; - repeated uint64 uint64_list = 104; - - repeated float float_list = 105; - repeated double double_list = 106; - - repeated string string_list = 107; - repeated string cord_list = 108 [ctype = CORD]; - repeated bytes bytes_list = 109; - - repeated bool bool_list = 110; - - repeated TestEnum enum_list = 111; - repeated TestMessage message_list = 112; - - map int64_int32_map = 201; - map uint64_int32_map = 202; - map string_int32_map = 203; -} diff --git a/eval/testutil/test_extensions.proto b/eval/testutil/test_extensions.proto new file mode 100644 index 000000000..4a422c62b --- /dev/null +++ b/eval/testutil/test_extensions.proto @@ -0,0 +1,38 @@ +syntax = "proto2"; + +package google.api.expr.runtime; + +import "google/protobuf/wrappers.proto"; + +option cc_enable_arenas = true; +option java_multiple_files = true; + +enum TestExtEnum { + TEST_EXT_UNSPECIFIED = 0; + TEST_EXT_1 = 10; + TEST_EXT_2 = 20; + TEST_EXT_3 = 30; +} + +// This proto is used to show how extensions are tracked as fields +// with fully qualified names. +message TestExtensions { + optional string name = 1; + + extensions 100 to max; +} + +// Package scoped extensions. +extend TestExtensions { + optional TestExtensions nested_ext = 100; + optional int32 int32_ext = 101; + optional google.protobuf.Int32Value int32_wrapper_ext = 102; +} + +// Message scoped extensions. +message TestMessageExtensions { + extend TestExtensions { + repeated string repeated_string_exts = 103; + optional TestExtEnum enum_ext = 104; + } +} \ No newline at end of file diff --git a/eval/testutil/test_message.proto b/eval/testutil/test_message.proto index 22fb71c70..b59d9bc19 100644 --- a/eval/testutil/test_message.proto +++ b/eval/testutil/test_message.proto @@ -43,27 +43,33 @@ message TestMessage { TestMessage message_value = 12; + reserved 99; + repeated int32 int32_list = 101; repeated int64 int64_list = 102; - repeated uint32 uint32_list = 103; repeated uint64 uint64_list = 104; - repeated float float_list = 105; repeated double double_list = 106; - repeated string string_list = 107; repeated string cord_list = 108 [ctype = CORD]; repeated bytes bytes_list = 109; - repeated bool bool_list = 110; - repeated TestEnum enum_list = 111; repeated TestMessage message_list = 112; + repeated google.protobuf.Timestamp timestamp_list = 113; map int64_int32_map = 201; map uint64_int32_map = 202; map string_int32_map = 203; + map bool_int32_map = 204; + map int32_int32_map = 205; + map uint32_uint32_map = 206; + map int32_float_map = 207; + map int64_enum_map = 208; + map string_timestamp_map = 209; + map string_message_map = 210; + map int64_timestamp_map = 211; // Well-known types. google.protobuf.Any any_value = 300; diff --git a/extensions/BUILD b/extensions/BUILD new file mode 100644 index 000000000..05104a4a5 --- /dev/null +++ b/extensions/BUILD @@ -0,0 +1,860 @@ +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "encoders", + srcs = ["encoders.cc"], + hdrs = ["encoders.h"], + deps = [ + "//checker:type_checker_builder", + "//common:decl", + "//common:type", + "//common:value", + "//compiler", + "//eval/public:cel_function_registry", + "//eval/public:cel_options", + "//internal:status_macros", + "//runtime:function_adapter", + "//runtime:function_registry", + "//runtime:runtime_options", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "encoders_test", + srcs = ["encoders_test.cc"], + deps = [ + ":encoders", + "//checker:standard_library", + "//checker:validation_result", + "//compiler", + "//compiler:compiler_factory", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//runtime", + "//runtime:activation", + "//runtime:runtime_options", + "//runtime:standard_runtime_builder_factory", + "@com_google_absl//absl/status:status_matchers", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "proto_ext", + srcs = ["proto_ext.cc"], + hdrs = ["proto_ext.h"], + deps = [ + "//common:expr", + "//compiler", + "//internal:status_macros", + "//parser:macro", + "//parser:macro_expr_factory", + "//parser:macro_registry", + "//parser:options", + "//parser:parser_interface", + "@com_google_absl//absl/functional:overload", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + "@com_google_absl//absl/types:variant", + ], +) + +cc_library( + name = "math_ext", + srcs = ["math_ext.cc"], + hdrs = ["math_ext.h"], + deps = [ + ":math_ext_decls", + "//common:casting", + "//common:value", + "//eval/public:cel_function_registry", + "//eval/public:cel_number", + "//eval/public:cel_options", + "//internal:status_macros", + "//runtime:function_adapter", + "//runtime:function_registry", + "//runtime:runtime_options", + "@com_google_absl//absl/base", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "math_ext_macros", + srcs = ["math_ext_macros.cc"], + hdrs = ["math_ext_macros.h"], + deps = [ + "//common:ast", + "//common:constant", + "//parser:macro", + "//parser:macro_expr_factory", + "//parser:macro_registry", + "//parser:options", + "@com_google_absl//absl/functional:overload", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + "@com_google_absl//absl/types:variant", + ], +) + +cc_library( + name = "math_ext_decls", + srcs = ["math_ext_decls.cc"], + hdrs = ["math_ext_decls.h"], + deps = [ + ":math_ext_macros", + "//checker:type_checker_builder", + "//checker/internal:builtins_arena", + "//common:decl", + "//common:type", + "//common:type_kind", + "//compiler", + "//internal:status_macros", + "//parser:parser_interface", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "math_ext_test", + srcs = ["math_ext_test.cc"], + deps = [ + ":math_ext", + ":math_ext_decls", + ":math_ext_macros", + "//checker:standard_library", + "//checker:type_check_issue", + "//checker:validation_result", + "//common:decl", + "//common:function_descriptor", + "//common:type", + "//compiler", + "//compiler:compiler_factory", + "//compiler:standard_library", + "//eval/public:activation", + "//eval/public:builtin_func_registrar", + "//eval/public:cel_expr_builder_factory", + "//eval/public:cel_expression", + "//eval/public:cel_function", + "//eval/public:cel_options", + "//eval/public:cel_value", + "//eval/public/containers:container_backed_list_impl", + "//eval/public/testing:matchers", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//parser", + "//runtime:activation", + "//runtime:runtime_options", + "//runtime:standard_runtime_builder_factory", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + +# New users should use ":regex_ext" instead. +cc_library( + name = "regex_functions", + srcs = ["regex_functions.cc"], + hdrs = ["regex_functions.h"], + deps = [ + "//checker:type_checker_builder", + "//checker/internal:builtins_arena", + "//common:decl", + "//common:type", + "//common:value", + "//eval/public:cel_function_registry", + "//eval/public:cel_options", + "//internal:re2_options", + "//internal:status_macros", + "//runtime:function_adapter", + "//runtime:function_registry", + "//runtime:runtime_options", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/functional:bind_front", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_protobuf//:protobuf", + "@com_googlesource_code_re2//:re2", + ], +) + +cc_library( + name = "bindings_ext", + srcs = ["bindings_ext.cc"], + hdrs = ["bindings_ext.h"], + deps = [ + "//checker:type_checker_builder", + "//common:decl", + "//common:expr", + "//common:type", + "//compiler", + "//internal:status_macros", + "//parser:macro", + "//parser:macro_expr_factory", + "//parser:macro_registry", + "//parser:options", + "//parser:parser_interface", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + ], +) + +cc_test( + name = "regex_functions_test", + srcs = [ + "regex_functions_test.cc", + ], + deps = [ + ":regex_functions", + "//checker:standard_library", + "//checker:validation_result", + "//common:value", + "//common:value_testing", + "//compiler", + "//compiler:compiler_factory", + "//extensions/protobuf:runtime_adapter", + "//internal:status_macros", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//parser", + "//runtime", + "//runtime:activation", + "//runtime:reference_resolver", + "//runtime:runtime_builder", + "//runtime:runtime_options", + "//runtime:standard_runtime_builder_factory", + "@com_google_absl//absl/log:absl_log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "bindings_ext_test", + srcs = ["bindings_ext_test.cc"], + deps = [ + ":bindings_ext", + "//base:attributes", + "//eval/public:activation", + "//eval/public:builtin_func_registrar", + "//eval/public:cel_expr_builder_factory", + "//eval/public:cel_expression", + "//eval/public:cel_function", + "//eval/public:cel_function_adapter", + "//eval/public:cel_options", + "//eval/public:cel_value", + "//eval/public/structs:cel_proto_wrapper", + "//eval/public/testing:matchers", + "//internal:testing", + "//parser", + "//parser:macro", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr/conformance/proto2:test_all_types_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "bindings_ext_benchmark_test", + srcs = ["bindings_ext_benchmark_test.cc"], + tags = ["benchmark"], + deps = [ + ":bindings_ext", + "//eval/public:activation", + "//eval/public:builtin_func_registrar", + "//eval/public:cel_expr_builder_factory", + "//eval/public:cel_expression", + "//eval/public:cel_options", + "//eval/public:cel_value", + "//eval/public/testing:matchers", + "//internal:benchmark", + "//internal:testing", + "//parser", + "//parser:macro", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/log:absl_check", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "select_optimization", + srcs = ["select_optimization.cc"], + hdrs = ["select_optimization.h"], + deps = [ + "//base:attributes", + "//base:builtins", + "//common:ast", + "//common:ast_rewrite", + "//common:casting", + "//common:constant", + "//common:expr", + "//common:function_descriptor", + "//common:kind", + "//common:native_type", + "//common:type", + "//common:value", + "//eval/compiler:flat_expr_builder", + "//eval/compiler:flat_expr_builder_extensions", + "//eval/eval:attribute_trail", + "//eval/eval:direct_expression_step", + "//eval/eval:evaluator_core", + "//eval/eval:expression_step_base", + "//internal:casts", + "//internal:number", + "//internal:status_macros", + "//runtime:runtime_builder", + "//runtime/internal:errors", + "//runtime/internal:runtime_friend_access", + "//runtime/internal:runtime_impl", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/functional:overload", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + "@com_google_absl//absl/types:variant", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "select_optimization_test", + srcs = ["select_optimization_test.cc"], + deps = [ + ":select_optimization", + "//base:ast", + "//base:attributes", + "//base:builtins", + "//checker:type_checker_builder", + "//checker:validation_result", + "//common:ast", + "//common:decl", + "//common:decl_proto", + "//common:expr", + "//common:kind", + "//common:memory", + "//common:value", + "//compiler", + "//compiler:compiler_factory", + "//compiler:optional", + "//compiler:standard_library", + "//eval/compiler:flat_expr_builder", + "//eval/compiler:flat_expr_builder_extensions", + "//eval/compiler:resolver", + "//eval/eval:evaluator_core", + "//eval/internal:interop", + "//eval/public:cel_type_registry", + "//eval/public:cel_value", + "//eval/public/structs:cel_proto_wrapper", + "//eval/public/structs:legacy_type_adapter", + "//eval/public/structs:legacy_type_info_apis", + "//extensions/protobuf:ast_converters", + "//internal:number", + "//internal:status_macros", + "//internal:testing", + "//parser", + "//runtime:activation", + "//runtime:function_adapter", + "//runtime:function_registry", + "//runtime:runtime_issue", + "//runtime:runtime_options", + "//runtime:type_registry", + "//runtime/internal:issue_collector", + "//runtime/internal:runtime_env", + "//runtime/internal:runtime_env_testing", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr/conformance/proto2:test_all_types_cc_proto", + "@com_google_protobuf//:empty_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "lists_functions", + srcs = ["lists_functions.cc"], + hdrs = ["lists_functions.h"], + deps = [ + "//checker:type_checker_builder", + "//checker/internal:builtins_arena", + "//common:decl", + "//common:expr", + "//common:operators", + "//common:type", + "//common:value", + "//common:value_kind", + "//compiler", + "//internal:status_macros", + "//parser:macro", + "//parser:macro_expr_factory", + "//parser:macro_registry", + "//parser:options", + "//parser:parser_interface", + "//runtime:function_adapter", + "//runtime:function_registry", + "//runtime:runtime_options", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "lists_functions_test", + srcs = ["lists_functions_test.cc"], + deps = [ + ":lists_functions", + "//checker:type_check_issue", + "//checker:validation_result", + "//common:source", + "//common:value", + "//common:value_testing", + "//compiler", + "//compiler:compiler_factory", + "//compiler:standard_library", + "//extensions/protobuf:runtime_adapter", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//parser", + "//parser:macro_registry", + "//parser:options", + "//parser:standard_macros", + "//runtime", + "//runtime:activation", + "//runtime:reference_resolver", + "//runtime:runtime_builder", + "//runtime:runtime_options", + "//runtime:standard_runtime_builder_factory", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/strings:string_view", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "sets_functions", + srcs = ["sets_functions.cc"], + hdrs = ["sets_functions.h"], + deps = [ + "//base:function_adapter", + "//checker:type_checker_builder", + "//common:decl", + "//common:type", + "//common:value", + "//compiler", + "//eval/public:cel_function_registry", + "//eval/public:cel_options", + "//internal:status_macros", + "//runtime:function_registry", + "//runtime:runtime_options", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "sets_functions_test", + srcs = ["sets_functions_test.cc"], + deps = [ + ":sets_functions", + "//checker:standard_library", + "//checker:validation_result", + "//common:ast_proto", + "//common:minimal_descriptor_pool", + "//compiler:compiler_factory", + "//eval/public:activation", + "//eval/public:builtin_func_registrar", + "//eval/public:cel_expr_builder_factory", + "//eval/public:cel_expression", + "//eval/public:cel_function_adapter", + "//eval/public:cel_options", + "//eval/public:cel_value", + "//internal:testing", + "//runtime:runtime_options", + "@com_google_absl//absl/status:status_matchers", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "sets_functions_benchmark_test", + srcs = ["sets_functions_benchmark_test.cc"], + tags = ["benchmark"], + deps = [ + ":sets_functions", + "//common:value", + "//eval/internal:interop", + "//eval/public:activation", + "//eval/public:builtin_func_registrar", + "//eval/public:cel_expr_builder_factory", + "//eval/public:cel_expression", + "//eval/public:cel_options", + "//eval/public:cel_value", + "//eval/public/containers:container_backed_list_impl", + "//internal:benchmark", + "//internal:status_macros", + "//internal:testing", + "//parser", + "//runtime:runtime_options", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "strings", + srcs = ["strings.cc"], + hdrs = ["strings.h"], + deps = [ + ":formatting", + "//checker:type_checker_builder", + "//checker/internal:builtins_arena", + "//common:decl", + "//common:type", + "//common:value", + "//compiler", + "//eval/public:cel_function_registry", + "//eval/public:cel_options", + "//internal:status_macros", + "//runtime:function_adapter", + "//runtime:function_registry", + "//runtime:runtime_options", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + "@com_google_absl//absl/strings:string_view", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "strings_test", + srcs = ["strings_test.cc"], + deps = [ + ":strings", + "//checker:standard_library", + "//checker:type_check_issue", + "//checker:type_checker_builder", + "//checker:validation_result", + "//common:ast", + "//common:decl", + "//common:type", + "//common:value", + "//compiler", + "//compiler:compiler_factory", + "//compiler:standard_library", + "//extensions/protobuf:runtime_adapter", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//parser", + "//parser:options", + "//runtime", + "//runtime:activation", + "//runtime:runtime_builder", + "//runtime:runtime_options", + "//runtime:standard_runtime_builder_factory", + "//testutil:baseline_tests", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "comprehensions_v2_functions", + srcs = ["comprehensions_v2_functions.cc"], + hdrs = ["comprehensions_v2_functions.h"], + deps = [ + "//common:value", + "//eval/public:cel_function_registry", + "//eval/public:cel_options", + "//internal:status_macros", + "//runtime:function_adapter", + "//runtime:function_registry", + "//runtime:runtime_options", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "comprehensions_v2_macros", + srcs = ["comprehensions_v2_macros.cc"], + hdrs = ["comprehensions_v2_macros.h"], + deps = [ + "//common:expr", + "//common:operators", + "//compiler", + "//internal:status_macros", + "//parser:macro", + "//parser:macro_expr_factory", + "//parser:macro_registry", + "//parser:options", + "//parser:parser_interface", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "comprehensions_v2", + srcs = ["comprehensions_v2.cc"], + hdrs = ["comprehensions_v2.h"], + deps = [ + ":comprehensions_v2_functions", + ":comprehensions_v2_macros", + "//checker:type_checker_builder", + "//checker/internal:builtins_arena", + "//common:decl", + "//common:type", + "//compiler", + "//internal:status_macros", + "//parser:macro_registry", + "//parser:options", + "//parser:parser_interface", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/status", + ], +) + +cc_test( + name = "comprehensions_v2_test", + srcs = ["comprehensions_v2_test.cc"], + deps = [ + ":bindings_ext", + ":comprehensions_v2", + ":comprehensions_v2_functions", + ":strings", + "//checker:standard_library", + "//checker:validation_result", + "//common:value", + "//common:value_testing", + "//compiler:compiler_factory", + "//compiler:optional", + "//internal:status_macros", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//runtime", + "//runtime:activation", + "//runtime:optional_types", + "//runtime:runtime_options", + "//runtime:standard_runtime_builder_factory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "formatting", + srcs = ["formatting.cc"], + hdrs = ["formatting.h"], + deps = [ + "//common:value", + "//common:value_kind", + "//internal:status_macros", + "//runtime:function_adapter", + "//runtime:function_registry", + "//runtime:runtime_options", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/numeric:bits", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/time", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "regex_ext", + srcs = ["regex_ext.cc"], + hdrs = ["regex_ext.h"], + deps = [ + "//checker:type_checker_builder", + "//checker/internal:builtins_arena", + "//common:decl", + "//common:type", + "//common:value", + "//compiler", + "//eval/public:cel_function_registry", + "//eval/public:cel_options", + "//internal:casts", + "//internal:re2_options", + "//internal:status_macros", + "//runtime:function_adapter", + "//runtime:function_registry", + "//runtime:runtime_builder", + "//runtime/internal:runtime_friend_access", + "//runtime/internal:runtime_impl", + "//validator", + "//validator:regex_validator", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/functional:bind_front", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_protobuf//:protobuf", + "@com_googlesource_code_re2//:re2", + ], +) + +cc_test( + name = "regex_ext_test", + srcs = ["regex_ext_test.cc"], + deps = [ + ":regex_ext", + "//checker:standard_library", + "//checker:validation_result", + "//common:kind", + "//common:value", + "//common:value_testing", + "//compiler", + "//compiler:compiler_factory", + "//eval/public:activation", + "//eval/public:cel_expr_builder_factory", + "//eval/public:cel_expression", + "//eval/public:cel_function_registry", + "//eval/public:cel_options", + "//extensions/protobuf:runtime_adapter", + "//internal:status_macros", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//parser", + "//runtime", + "//runtime:activation", + "//runtime:optional_types", + "//runtime:reference_resolver", + "//runtime:runtime_options", + "//runtime:standard_runtime_builder_factory", + "//validator", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "formatting_test", + srcs = ["formatting_test.cc"], + deps = [ + ":formatting", + "//common:value", + "//extensions/protobuf:runtime_adapter", + "//internal:parse_text_proto", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//internal:testing_message_factory", + "//parser", + "//parser:options", + "//runtime", + "//runtime:activation", + "//runtime:runtime_builder", + "//runtime:runtime_options", + "//runtime:standard_runtime_builder_factory", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/time", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) diff --git a/extensions/bindings_ext.cc b/extensions/bindings_ext.cc new file mode 100644 index 000000000..4823c077c --- /dev/null +++ b/extensions/bindings_ext.cc @@ -0,0 +1,103 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/bindings_ext.h" + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "checker/type_checker_builder.h" +#include "common/decl.h" +#include "common/expr.h" +#include "common/type.h" +#include "compiler/compiler.h" +#include "internal/status_macros.h" +#include "parser/macro.h" +#include "parser/macro_expr_factory.h" +#include "parser/parser_interface.h" + +namespace cel::extensions { + +namespace { + +static constexpr char kCelNamespace[] = "cel"; +static constexpr char kBind[] = "bind"; +static constexpr char kBlock[] = "cel.@block"; +static constexpr char kBlockOverloadId[] = "cel_block_list"; +static constexpr char kUnusedIterVar[] = "#unused"; + +bool IsTargetNamespace(const Expr& target) { + return target.has_ident_expr() && target.ident_expr().name() == kCelNamespace; +} + +inline absl::Status ConfigureParser(ParserBuilder& parser_builder) { + for (const Macro& macro : bindings_macros()) { + CEL_RETURN_IF_ERROR(parser_builder.AddMacro(macro)); + } + return absl::OkStatus(); +} + +absl::Status ConfigureChecker(int version, + TypeCheckerBuilder& type_checker_builder) { + if (version < 1) { + return absl::OkStatus(); + } + static Type kParam(TypeParamType("T")); + CEL_ASSIGN_OR_RETURN( + auto decl, + MakeFunctionDecl(kBlock, MakeOverloadDecl(kBlockOverloadId, kParam, + ListType(), kParam))); + return type_checker_builder.AddFunction(std::move(decl)); +} + +} // namespace + +std::vector bindings_macros() { + absl::StatusOr cel_bind = Macro::Receiver( + kBind, 3, + [](MacroExprFactory& factory, Expr& target, + absl::Span args) -> absl::optional { + if (!IsTargetNamespace(target)) { + return std::nullopt; + } + if (!args[0].has_ident_expr()) { + return factory.ReportErrorAt( + args[0], "cel.bind() variable name must be a simple identifier"); + } + auto var_name = args[0].ident_expr().name(); + return factory.NewComprehension(kUnusedIterVar, factory.NewList(), + std::move(var_name), std::move(args[1]), + factory.NewBoolConst(false), + std::move(args[0]), std::move(args[2])); + }); + return {*cel_bind}; +} + +CompilerLibrary BindingsCompilerLibrary(int version) { + return CompilerLibrary( + "cel.lib.ext.bindings", &ConfigureParser, + [version](auto& b) { return ConfigureChecker(version, b); }); +} + +CheckerLibrary BindingsCheckerLibrary(int version) { + return CheckerLibrary{"cel.lib.ext.bindings", [version](auto& b) { + return ConfigureChecker(version, b); + }}; +} + +} // namespace cel::extensions diff --git a/extensions/bindings_ext.h b/extensions/bindings_ext.h new file mode 100644 index 000000000..40b83a37f --- /dev/null +++ b/extensions/bindings_ext.h @@ -0,0 +1,46 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_BINDINGS_EXT_H_ +#define THIRD_PARTY_CEL_CPP_EXTENSIONS_BINDINGS_EXT_H_ + +#include + +#include "absl/status/status.h" +#include "compiler/compiler.h" +#include "parser/macro.h" +#include "parser/macro_registry.h" +#include "parser/options.h" + +namespace cel::extensions { + +constexpr int kBindingsVersionLatest = 1; +// bindings_macros() returns a macro for cel.bind() which can be used to support +// local variable bindings within expressions. +std::vector bindings_macros(); + +inline absl::Status RegisterBindingsMacros(MacroRegistry& registry, + const ParserOptions&) { + return registry.RegisterMacros(bindings_macros()); +} + +// Declarations for the bindings extension library. +CompilerLibrary BindingsCompilerLibrary(int version = kBindingsVersionLatest); + +// Declarations for the bindings extension library. +CheckerLibrary BindingsCheckerLibrary(int version = kBindingsVersionLatest); + +} // namespace cel::extensions + +#endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_BINDINGS_EXT_H_ diff --git a/extensions/bindings_ext_benchmark_test.cc b/extensions/bindings_ext_benchmark_test.cc new file mode 100644 index 000000000..52203d810 --- /dev/null +++ b/extensions/bindings_ext_benchmark_test.cc @@ -0,0 +1,252 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "cel/expr/syntax.pb.h" +#include "absl/base/no_destructor.h" +#include "absl/log/absl_check.h" +#include "eval/public/activation.h" +#include "eval/public/builtin_func_registrar.h" +#include "eval/public/cel_expr_builder_factory.h" +#include "eval/public/cel_expression.h" +#include "eval/public/cel_options.h" +#include "eval/public/cel_value.h" +#include "eval/public/testing/matchers.h" +#include "extensions/bindings_ext.h" +#include "internal/benchmark.h" +#include "internal/testing.h" +#include "parser/macro.h" +#include "parser/parser.h" +#include "google/protobuf/arena.h" + +namespace cel::extensions { +namespace { + +using ::google::api::expr::parser::ParseWithMacros; +using ::google::api::expr::runtime::Activation; +using ::google::api::expr::runtime::CelValue; +using ::google::api::expr::runtime::InterpreterOptions; +using ::google::api::expr::runtime::test::CelValueMatcher; +using ::google::api::expr::runtime::test::IsCelBool; +using ::google::api::expr::runtime::test::IsCelString; + +struct BenchmarkCase { + std::string name; + std::string expression; + CelValueMatcher matcher; +}; + +const std::vector& BenchmarkCases() { + static absl::NoDestructor> cases( + std::vector{ + {"simple", R"(cel.bind(x, "ab", x))", IsCelString("ab")}, + {"multiple_references", R"(cel.bind(x, "ab", x + x + x + x))", + IsCelString("abababab")}, + {"nested", + R"( + cel.bind( + x, + "ab", + cel.bind( + y, + "cd", + x + y + "ef")))", + IsCelString("abcdef")}, + {"nested_defintion", + R"( + cel.bind( + x, + "ab", + cel.bind( + y, + x + "cd", + y + "ef" + )))", + IsCelString("abcdef")}, + {"bind_outside_loop", + R"( + cel.bind( + outer_value, + [1, 2, 3], + [3, 2, 1].all( + value, + value in outer_value) + ))", + IsCelBool(true)}, + {"bind_inside_loop", + R"( + [3, 2, 1].all( + x, + cel.bind(value, x * x, value < 16) + ))", + IsCelBool(true)}, + {"bind_loop_bind", + R"( + cel.bind( + outer_value, + {1: 2, 2: 3, 3: 4}, + outer_value.all( + key, + cel.bind( + value, + outer_value[key], + value == key + 1 + ) + )))", + IsCelBool(true)}, + {"ternary_depends_on_bind", + R"( + cel.bind( + a, + "ab", + (true && a.startsWith("c")) ? a : "cd" + ))", + IsCelString("cd")}, + {"ternary_does_not_depend_on_bind", + R"( + cel.bind( + a, + "ab", + (false && a.startsWith("c")) ? a : "cd" + ))", + IsCelString("cd")}, + {"twice_nested_defintion", + R"( + cel.bind( + x, + "ab", + cel.bind( + y, + x + "cd", + cel.bind( + z, + y + "ef", + z))) + )", + IsCelString("abcdef")}, + }); + + return *cases; +} + +class BindingsBenchmarkTest : public ::testing::TestWithParam { + protected: + google::protobuf::Arena arena_; +}; + +TEST_P(BindingsBenchmarkTest, CheckBenchmarkCaseWorks) { + const BenchmarkCase& benchmark = GetParam(); + + std::vector all_macros = Macro::AllMacros(); + std::vector bindings_macros = cel::extensions::bindings_macros(); + all_macros.insert(all_macros.end(), bindings_macros.begin(), + bindings_macros.end()); + ASSERT_OK_AND_ASSIGN( + auto expr, ParseWithMacros(benchmark.expression, all_macros, "")); + + InterpreterOptions options; + auto builder = + google::api::expr::runtime::CreateCelExpressionBuilder(options); + + ASSERT_OK(google::api::expr::runtime::RegisterBuiltinFunctions( + builder->GetRegistry())); + + ASSERT_OK_AND_ASSIGN(auto program, builder->CreateExpression( + &expr.expr(), &expr.source_info())); + + Activation activation; + google::protobuf::Arena arena; + ASSERT_OK_AND_ASSIGN(CelValue result, program->Evaluate(activation, &arena)); + + EXPECT_THAT(result, benchmark.matcher); +} + +void RunBenchmark(const BenchmarkCase& benchmark, benchmark::State& state) { + std::vector all_macros = Macro::AllMacros(); + std::vector bindings_macros = cel::extensions::bindings_macros(); + all_macros.insert(all_macros.end(), bindings_macros.begin(), + bindings_macros.end()); + ASSERT_OK_AND_ASSIGN( + auto expr, ParseWithMacros(benchmark.expression, all_macros, "")); + + InterpreterOptions options; + auto builder = + google::api::expr::runtime::CreateCelExpressionBuilder(options); + + ASSERT_OK(google::api::expr::runtime::RegisterBuiltinFunctions( + builder->GetRegistry())); + + ASSERT_OK_AND_ASSIGN(auto program, builder->CreateExpression( + &expr.expr(), &expr.source_info())); + + Activation activation; + google::protobuf::Arena arena; + for (auto _ : state) { + auto result = program->Evaluate(activation, &arena); + benchmark::DoNotOptimize(result); + ABSL_DCHECK_OK(result); + ABSL_DCHECK(benchmark.matcher.Matches(*result)); + } +} + +void BM_Simple(benchmark::State& state) { + RunBenchmark(BenchmarkCases()[0], state); +} +void BM_MultipleReferences(benchmark::State& state) { + RunBenchmark(BenchmarkCases()[1], state); +} +void BM_Nested(benchmark::State& state) { + RunBenchmark(BenchmarkCases()[2], state); +} +void BM_NestedDefinition(benchmark::State& state) { + RunBenchmark(BenchmarkCases()[3], state); +} +void BM_BindOusideLoop(benchmark::State& state) { + RunBenchmark(BenchmarkCases()[4], state); +} +void BM_BindInsideLoop(benchmark::State& state) { + RunBenchmark(BenchmarkCases()[5], state); +} +void BM_BindLoopBind(benchmark::State& state) { + RunBenchmark(BenchmarkCases()[6], state); +} +void BM_TernaryDependsOnBind(benchmark::State& state) { + RunBenchmark(BenchmarkCases()[7], state); +} +void BM_TernaryDoesNotDependOnBind(benchmark::State& state) { + RunBenchmark(BenchmarkCases()[8], state); +} +void BM_TwiceNestedDefinition(benchmark::State& state) { + RunBenchmark(BenchmarkCases()[9], state); +} + +BENCHMARK(BM_Simple); +BENCHMARK(BM_MultipleReferences); +BENCHMARK(BM_Nested); +BENCHMARK(BM_NestedDefinition); +BENCHMARK(BM_BindOusideLoop); +BENCHMARK(BM_BindInsideLoop); +BENCHMARK(BM_BindLoopBind); +BENCHMARK(BM_TernaryDependsOnBind); +BENCHMARK(BM_TernaryDoesNotDependOnBind); +BENCHMARK(BM_TwiceNestedDefinition); + +INSTANTIATE_TEST_SUITE_P(BindingsBenchmarkTest, BindingsBenchmarkTest, + ::testing::ValuesIn(BenchmarkCases())); + +} // namespace +} // namespace cel::extensions diff --git a/extensions/bindings_ext_test.cc b/extensions/bindings_ext_test.cc new file mode 100644 index 000000000..c8b12c24a --- /dev/null +++ b/extensions/bindings_ext_test.cc @@ -0,0 +1,872 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/bindings_ext.h" + +#include +#include +#include +#include +#include + +#include "cel/expr/syntax.pb.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "base/attribute.h" +#include "eval/public/activation.h" +#include "eval/public/builtin_func_registrar.h" +#include "eval/public/cel_expr_builder_factory.h" +#include "eval/public/cel_expression.h" +#include "eval/public/cel_function.h" +#include "eval/public/cel_function_adapter.h" +#include "eval/public/cel_options.h" +#include "eval/public/cel_value.h" +#include "eval/public/structs/cel_proto_wrapper.h" +#include "eval/public/testing/matchers.h" +#include "internal/testing.h" +#include "parser/macro.h" +#include "parser/parser.h" +#include "cel/expr/conformance/proto2/test_all_types.pb.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/text_format.h" + +namespace cel::extensions { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::cel::expr::conformance::proto2::NestedTestAllTypes; +using ::cel::expr::CheckedExpr; +using ::cel::expr::Expr; +using ::cel::expr::ParsedExpr; +using ::cel::expr::SourceInfo; +using ::google::api::expr::parser::ParseWithMacros; +using ::google::api::expr::runtime::Activation; +using ::google::api::expr::runtime::CelExpressionBuilder; +using ::google::api::expr::runtime::CelFunction; +using ::google::api::expr::runtime::CelFunctionDescriptor; +using ::google::api::expr::runtime::CelProtoWrapper; +using ::google::api::expr::runtime::CelValue; +using ::google::api::expr::runtime::CreateCelExpressionBuilder; +using ::google::api::expr::runtime::FunctionAdapter; +using ::google::api::expr::runtime::InterpreterOptions; +using ::google::api::expr::runtime::RegisterBuiltinFunctions; +using ::google::api::expr::runtime::UnknownProcessingOptions; +using ::google::api::expr::runtime::test::IsCelInt64; +using ::google::protobuf::Arena; +using ::google::protobuf::TextFormat; +using ::testing::Contains; +using ::testing::HasSubstr; +using ::testing::Pair; + +struct TestInfo { + std::string expr; + std::string err = ""; +}; + +class TestFunction : public CelFunction { + public: + explicit TestFunction(absl::string_view name) + : CelFunction(CelFunctionDescriptor( + name, true, + {CelValue::Type::kBool, CelValue::Type::kBool, + CelValue::Type::kBool, CelValue::Type::kBool})) {} + + absl::Status Evaluate(absl::Span args, CelValue* result, + Arena* arena) const override { + *result = CelValue::CreateBool(true); + return absl::OkStatus(); + } +}; + +// Test function used to test macro collision and non-expansion. +constexpr absl::string_view kBind = "bind"; +std::unique_ptr CreateBindFunction() { + return std::make_unique(kBind); +} + +class BindingsExtTest + : public testing::TestWithParam> { + protected: + const TestInfo& GetTestInfo() { return std::get<0>(GetParam()); } + bool GetEnableConstantFolding() { return std::get<1>(GetParam()); } + bool GetEnableRecursivePlan() { return std::get<2>(GetParam()); } +}; + +TEST_P(BindingsExtTest, Default) { + const TestInfo& test_info = GetTestInfo(); + Arena arena; + std::vector all_macros = Macro::AllMacros(); + std::vector bindings_macros = cel::extensions::bindings_macros(); + all_macros.insert(all_macros.end(), bindings_macros.begin(), + bindings_macros.end()); + auto result = ParseWithMacros(test_info.expr, all_macros, ""); + if (!test_info.err.empty()) { + EXPECT_THAT(result.status(), StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr(test_info.err))); + return; + } + EXPECT_THAT(result, IsOk()); + + ParsedExpr parsed_expr = *result; + Expr expr = parsed_expr.expr(); + SourceInfo source_info = parsed_expr.source_info(); + + // Obtain CEL Expression builder. + InterpreterOptions options; + options.enable_heterogeneous_equality = true; + options.enable_empty_wrapper_null_unboxing = true; + options.constant_folding = GetEnableConstantFolding(); + options.constant_arena = &arena; + options.max_recursion_depth = GetEnableRecursivePlan() ? -1 : 0; + std::unique_ptr builder = + CreateCelExpressionBuilder(options); + ASSERT_OK(builder->GetRegistry()->Register(CreateBindFunction())); + + // Register builtins and configure the execution environment. + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + + // Create CelExpression from AST (Expr object). + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder->CreateExpression(&expr, &source_info)); + Activation activation; + // Run evaluation. + ASSERT_OK_AND_ASSIGN(CelValue out, cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(out.IsBool()) << out.DebugString(); + EXPECT_EQ(out.BoolOrDie(), true); +} + +TEST_P(BindingsExtTest, Tracing) { + const TestInfo& test_info = GetTestInfo(); + Arena arena; + std::vector all_macros = Macro::AllMacros(); + std::vector bindings_macros = cel::extensions::bindings_macros(); + all_macros.insert(all_macros.end(), bindings_macros.begin(), + bindings_macros.end()); + auto result = ParseWithMacros(test_info.expr, all_macros, ""); + if (!test_info.err.empty()) { + EXPECT_THAT(result.status(), StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr(test_info.err))); + return; + } + EXPECT_THAT(result, IsOk()); + + ParsedExpr parsed_expr = *result; + Expr expr = parsed_expr.expr(); + SourceInfo source_info = parsed_expr.source_info(); + + // Obtain CEL Expression builder. + InterpreterOptions options; + options.enable_heterogeneous_equality = true; + options.enable_empty_wrapper_null_unboxing = true; + options.constant_folding = GetEnableConstantFolding(); + options.constant_arena = &arena; + options.max_recursion_depth = GetEnableRecursivePlan() ? -1 : 0; + std::unique_ptr builder = + CreateCelExpressionBuilder(options); + ASSERT_OK(builder->GetRegistry()->Register(CreateBindFunction())); + + // Register builtins and configure the execution environment. + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + + // Create CelExpression from AST (Expr object). + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder->CreateExpression(&expr, &source_info)); + Activation activation; + // Run evaluation. + ASSERT_OK_AND_ASSIGN( + CelValue out, + cel_expr->Trace(activation, &arena, + [](int64_t, const CelValue&, google::protobuf::Arena*) { + return absl::OkStatus(); + })); + ASSERT_TRUE(out.IsBool()) << out.DebugString(); + EXPECT_EQ(out.BoolOrDie(), true); +} + +INSTANTIATE_TEST_SUITE_P( + CelBindingsExtTest, BindingsExtTest, + testing::Combine( + testing::ValuesIn( + {{"cel.bind(t, true, t)"}, + {"cel.bind(msg, \"hello\", msg + msg + msg) == " + "\"hellohellohello\""}, + {"cel.bind(t1, true, cel.bind(t2, true, t1 && t2))"}, + {"cel.bind(valid_elems, [1, 2, 3], " + "[3, 4, 5].exists(e, e in valid_elems))"}, + {"cel.bind(valid_elems, [1, 2, 3], " + "![4, 5].exists(e, e in valid_elems))"}, + // Implementation detail: bind variables and comprehension + // variables get mapped to an int index in the same space. Check + // that mixing them works. + {R"( + cel.bind( + my_list, + ['a', 'b', 'c'].map(x, x + '_'), + [0, 1, 2].map(y, my_list[y] + string(y))) == + ['a_0', 'b_1', 'c_2'])"}, + // Check scoping rules. + {"cel.bind(x, 1, " + " cel.bind(x, x + 1, x)) == 2"}, + // Testing a bound function with the same macro name, but non-cel + // namespace. The function mirrors the macro signature, but just + // returns true. + {"false.bind(false, false, false)"}, + // Error case where the variable name is not a simple identifier. + {"cel.bind(bad.name, true, bad.name)", + "variable name must be a simple identifier"}}), + /*constant_folding*/ testing::Bool(), + /*recursive_plan*/ testing::Bool())); + +constexpr absl::string_view kTraceExpr = R"pb( + expr: { + id: 11 + comprehension_expr: { + iter_var: "#unused" + iter_range: { + id: 8 + list_expr: {} + } + accu_var: "x" + accu_init: { + id: 4 + const_expr: { int64_value: 20 } + } + loop_condition: { + id: 9 + const_expr: { bool_value: false } + } + loop_step: { + id: 10 + ident_expr: { name: "x" } + } + result: { + id: 6 + call_expr: { + function: "_*_" + args: { + id: 5 + ident_expr: { name: "x" } + } + args: { + id: 7 + ident_expr: { name: "x" } + } + } + } + } + })pb"; + +TEST(BindingsExtTest, TraceSupport) { + ParsedExpr expr; + ASSERT_TRUE(TextFormat::ParseFromString(kTraceExpr, &expr)); + InterpreterOptions options; + options.enable_heterogeneous_equality = true; + options.enable_empty_wrapper_null_unboxing = true; + std::unique_ptr builder = + CreateCelExpressionBuilder(options); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + + ASSERT_OK_AND_ASSIGN( + auto plan, builder->CreateExpression(&expr.expr(), &expr.source_info())); + Activation activation; + google::protobuf::Arena arena; + absl::flat_hash_map ids; + ASSERT_OK_AND_ASSIGN( + auto result, + plan->Trace(activation, &arena, + [&](int64_t id, const CelValue& value, google::protobuf::Arena* arena) { + ids[id] = value; + return absl::OkStatus(); + })); + + EXPECT_TRUE(result.IsInt64() && result.Int64OrDie() == 400) + << result.DebugString(); + + EXPECT_THAT(ids, Contains(Pair(4, IsCelInt64(20)))); + EXPECT_THAT(ids, Contains(Pair(7, IsCelInt64(20)))); +} + +// Test bind expression with nested field selection. +// +// cel.bind(submsg, +// msg.child.child, +// (false) ? +// TestAllTypes{single_int64: -42}.single_int64 : +// submsg.payload.single_int64) +constexpr absl::string_view kFieldSelectTestExpr = R"pb( + reference_map: { + key: 4 + value: { name: "msg" } + } + reference_map: { + key: 8 + value: { overload_id: "conditional" } + } + reference_map: { + key: 9 + value: { name: "cel.expr.conformance.proto2.TestAllTypes" } + } + reference_map: { + key: 13 + value: { name: "submsg" } + } + reference_map: { + key: 18 + value: { name: "submsg" } + } + type_map: { + key: 4 + value: { message_type: "cel.expr.conformance.proto2.NestedTestAllTypes" } + } + type_map: { + key: 5 + value: { message_type: "cel.expr.conformance.proto2.NestedTestAllTypes" } + } + type_map: { + key: 6 + value: { message_type: "cel.expr.conformance.proto2.NestedTestAllTypes" } + } + type_map: { + key: 7 + value: { primitive: BOOL } + } + type_map: { + key: 8 + value: { primitive: INT64 } + } + type_map: { + key: 9 + value: { message_type: "cel.expr.conformance.proto2.TestAllTypes" } + } + type_map: { + key: 11 + value: { primitive: INT64 } + } + type_map: { + key: 12 + value: { primitive: INT64 } + } + type_map: { + key: 13 + value: { message_type: "cel.expr.conformance.proto2.NestedTestAllTypes" } + } + type_map: { + key: 14 + value: { message_type: "cel.expr.conformance.proto2.TestAllTypes" } + } + type_map: { + key: 15 + value: { primitive: INT64 } + } + type_map: { + key: 16 + value: { list_type: { elem_type: { dyn: {} } } } + } + type_map: { + key: 17 + value: { primitive: BOOL } + } + type_map: { + key: 18 + value: { message_type: "cel.expr.conformance.proto2.NestedTestAllTypes" } + } + type_map: { + key: 19 + value: { primitive: INT64 } + } + source_info: { + location: "" + line_offsets: 120 + positions: { key: 1 value: 0 } + positions: { key: 2 value: 8 } + positions: { key: 3 value: 9 } + positions: { key: 4 value: 17 } + positions: { key: 5 value: 20 } + positions: { key: 6 value: 26 } + positions: { key: 7 value: 35 } + positions: { key: 8 value: 42 } + positions: { key: 9 value: 56 } + positions: { key: 10 value: 69 } + positions: { key: 11 value: 71 } + positions: { key: 12 value: 75 } + positions: { key: 13 value: 91 } + positions: { key: 14 value: 97 } + positions: { key: 15 value: 105 } + positions: { key: 16 value: 8 } + positions: { key: 17 value: 8 } + positions: { key: 18 value: 8 } + positions: { key: 19 value: 8 } + macro_calls: { + key: 19 + value: { + call_expr: { + target: { + id: 1 + ident_expr: { name: "cel" } + } + function: "bind" + args: { + id: 3 + ident_expr: { name: "submsg" } + } + args: { + id: 6 + select_expr: { + operand: { + id: 5 + select_expr: { + operand: { + id: 4 + ident_expr: { name: "msg" } + } + field: "child" + } + } + field: "child" + } + } + args: { + id: 8 + call_expr: { + function: "_?_:_" + args: { + id: 7 + const_expr: { bool_value: false } + } + args: { + id: 12 + select_expr: { + operand: { + id: 9 + struct_expr: { + message_name: "cel.expr.conformance.proto2.TestAllTypes" + entries: { + id: 10 + field_key: "single_int64" + value: { + id: 11 + const_expr: { int64_value: -42 } + } + } + } + } + field: "single_int64" + } + } + args: { + id: 15 + select_expr: { + operand: { + id: 14 + select_expr: { + operand: { + id: 13 + ident_expr: { name: "submsg" } + } + field: "payload" + } + } + field: "single_int64" + } + } + } + } + } + } + } + } + expr: { + id: 19 + comprehension_expr: { + iter_var: "#unused" + iter_range: { + id: 16 + list_expr: {} + } + accu_var: "submsg" + accu_init: { + id: 6 + select_expr: { + operand: { + id: 5 + select_expr: { + operand: { + id: 4 + ident_expr: { name: "msg" } + } + field: "child" + } + } + field: "child" + } + } + loop_condition: { + id: 17 + const_expr: { bool_value: false } + } + loop_step: { + id: 18 + ident_expr: { name: "submsg" } + } + result: { + id: 8 + call_expr: { + function: "_?_:_" + args: { + id: 7 + const_expr: { bool_value: false } + } + args: { + id: 12 + select_expr: { + operand: { + id: 9 + struct_expr: { + message_name: "cel.expr.conformance.proto2.TestAllTypes" + entries: { + id: 10 + field_key: "single_int64" + value: { + id: 11 + const_expr: { int64_value: -42 } + } + } + } + } + field: "single_int64" + } + } + args: { + id: 15 + select_expr: { + operand: { + id: 14 + select_expr: { + operand: { + id: 13 + ident_expr: { name: "submsg" } + } + field: "payload" + } + } + field: "single_int64" + } + } + } + } + } + })pb"; + +class BindingsExtInteractionsTest : public testing::TestWithParam { + protected: + bool GetEnableSelectOptimization() { return GetParam(); } +}; + +TEST_P(BindingsExtInteractionsTest, SelectOptimization) { + CheckedExpr expr; + ASSERT_TRUE(TextFormat::ParseFromString(kFieldSelectTestExpr, &expr)); + InterpreterOptions options; + options.enable_empty_wrapper_null_unboxing = true; + options.enable_select_optimization = GetEnableSelectOptimization(); + std::unique_ptr builder = + CreateCelExpressionBuilder(options); + + ASSERT_OK(builder->GetRegistry()->Register(CreateBindFunction())); + + // Register builtins and configure the execution environment. + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + + // Create CelExpression from AST (Expr object). + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression(&expr)); + Arena arena; + Activation activation; + + NestedTestAllTypes msg; + msg.mutable_child()->mutable_child()->mutable_payload()->set_single_int64(42); + + activation.InsertValue("msg", CelProtoWrapper::CreateMessage(&msg, &arena)); + + // Run evaluation. + ASSERT_OK_AND_ASSIGN(CelValue out, cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(out.IsInt64()); + EXPECT_EQ(out.Int64OrDie(), 42); +} + +TEST_P(BindingsExtInteractionsTest, UnknownAttributesSelectOptimization) { + CheckedExpr expr; + ASSERT_TRUE(TextFormat::ParseFromString(kFieldSelectTestExpr, &expr)); + InterpreterOptions options; + options.enable_empty_wrapper_null_unboxing = true; + options.unknown_processing = UnknownProcessingOptions::kAttributeOnly; + options.enable_select_optimization = GetEnableSelectOptimization(); + std::unique_ptr builder = + CreateCelExpressionBuilder(options); + + ASSERT_OK(builder->GetRegistry()->Register(CreateBindFunction())); + + // Register builtins and configure the execution environment. + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + + // Create CelExpression from AST (Expr object). + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression(&expr)); + Arena arena; + Activation activation; + activation.set_unknown_attribute_patterns({AttributePattern( + "msg", {AttributeQualifierPattern::OfString("child"), + AttributeQualifierPattern::OfString("child")})}); + + NestedTestAllTypes msg; + msg.mutable_child()->mutable_child()->mutable_payload()->set_single_int64(42); + + activation.InsertValue("msg", CelProtoWrapper::CreateMessage(&msg, &arena)); + + // Run evaluation. + ASSERT_OK_AND_ASSIGN(CelValue out, cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(out.IsUnknownSet()); + EXPECT_THAT(out.UnknownSetOrDie()->unknown_attributes(), + testing::ElementsAre( + Attribute("msg", {AttributeQualifier::OfString("child"), + AttributeQualifier::OfString("child")}))); +} + +TEST_P(BindingsExtInteractionsTest, + UnknownAttributeSelectOptimizationReturnValue) { + CheckedExpr expr; + ASSERT_TRUE(TextFormat::ParseFromString(kFieldSelectTestExpr, &expr)); + InterpreterOptions options; + options.enable_empty_wrapper_null_unboxing = true; + options.unknown_processing = UnknownProcessingOptions::kAttributeOnly; + options.enable_select_optimization = GetEnableSelectOptimization(); + std::unique_ptr builder = + CreateCelExpressionBuilder(options); + + ASSERT_OK(builder->GetRegistry()->Register(CreateBindFunction())); + + // Register builtins and configure the execution environment. + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + + // Create CelExpression from AST (Expr object). + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression(&expr)); + Arena arena; + Activation activation; + activation.set_unknown_attribute_patterns({AttributePattern( + "msg", {AttributeQualifierPattern::OfString("child"), + AttributeQualifierPattern::OfString("child"), + AttributeQualifierPattern::OfString("payload"), + AttributeQualifierPattern::OfString("single_int64")})}); + + NestedTestAllTypes msg; + msg.mutable_child()->mutable_child()->mutable_payload()->set_single_int64(42); + + activation.InsertValue("msg", CelProtoWrapper::CreateMessage(&msg, &arena)); + + // Run evaluation. + ASSERT_OK_AND_ASSIGN(CelValue out, cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(out.IsUnknownSet()) << out.DebugString(); + EXPECT_THAT(out.UnknownSetOrDie()->unknown_attributes(), + testing::ElementsAre(Attribute( + "msg", {AttributeQualifier::OfString("child"), + AttributeQualifier::OfString("child"), + AttributeQualifier::OfString("payload"), + AttributeQualifier::OfString("single_int64")}))); +} + +TEST_P(BindingsExtInteractionsTest, MissingAttributesSelectOptimization) { + CheckedExpr expr; + ASSERT_TRUE(TextFormat::ParseFromString(kFieldSelectTestExpr, &expr)); + InterpreterOptions options; + options.enable_empty_wrapper_null_unboxing = true; + options.enable_missing_attribute_errors = true; + options.enable_select_optimization = GetEnableSelectOptimization(); + std::unique_ptr builder = + CreateCelExpressionBuilder(options); + + ASSERT_OK(builder->GetRegistry()->Register(CreateBindFunction())); + + // Register builtins and configure the execution environment. + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + + // Create CelExpression from AST (Expr object). + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression(&expr)); + Arena arena; + Activation activation; + activation.set_missing_attribute_patterns({AttributePattern( + "msg", {AttributeQualifierPattern::OfString("child"), + AttributeQualifierPattern::OfString("child"), + AttributeQualifierPattern::OfString("payload"), + AttributeQualifierPattern::OfString("single_int64")})}); + + NestedTestAllTypes msg; + msg.mutable_child()->mutable_child()->mutable_payload()->set_single_int64(42); + + activation.InsertValue("msg", CelProtoWrapper::CreateMessage(&msg, &arena)); + + // Run evaluation. + ASSERT_OK_AND_ASSIGN(CelValue out, cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(out.IsError()) << out.DebugString(); + EXPECT_THAT(out.ErrorOrDie()->ToString(), + HasSubstr("msg.child.child.payload.single_int64")); +} + +TEST_P(BindingsExtInteractionsTest, UnknownAttribute) { + std::vector all_macros = Macro::AllMacros(); + std::vector bindings_macros = cel::extensions::bindings_macros(); + all_macros.insert(all_macros.end(), bindings_macros.begin(), + bindings_macros.end()); + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, ParseWithMacros( + R"( + cel.bind( + x, + msg.child.payload.single_int64, + x < 42 || 1 == 1))", + all_macros)); + + InterpreterOptions options; + options.enable_empty_wrapper_null_unboxing = true; + options.unknown_processing = UnknownProcessingOptions::kAttributeOnly; + options.enable_select_optimization = GetEnableSelectOptimization(); + std::unique_ptr builder = + CreateCelExpressionBuilder(options); + + ASSERT_OK(builder->GetRegistry()->Register(CreateBindFunction())); + + // Register builtins and configure the execution environment. + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + + // Create CelExpression from AST (Expr object). + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression( + &expr.expr(), &expr.source_info())); + Arena arena; + Activation activation; + activation.set_unknown_attribute_patterns({AttributePattern( + "msg", {AttributeQualifierPattern::OfString("child"), + AttributeQualifierPattern::OfString("payload"), + AttributeQualifierPattern::OfString("single_int64")})}); + + NestedTestAllTypes msg; + msg.mutable_child()->mutable_child()->mutable_payload()->set_single_int64(42); + + activation.InsertValue("msg", CelProtoWrapper::CreateMessage(&msg, &arena)); + + // Run evaluation. + ASSERT_OK_AND_ASSIGN(CelValue out, cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(out.IsBool()) << out.DebugString(); + EXPECT_TRUE(out.BoolOrDie()); +} + +TEST_P(BindingsExtInteractionsTest, UnknownAttributeReturnValue) { + std::vector all_macros = Macro::AllMacros(); + std::vector bindings_macros = cel::extensions::bindings_macros(); + all_macros.insert(all_macros.end(), bindings_macros.begin(), + bindings_macros.end()); + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, ParseWithMacros( + R"( + cel.bind( + x, + msg.child.payload.single_int64, + x))", + all_macros)); + + InterpreterOptions options; + options.enable_empty_wrapper_null_unboxing = true; + options.unknown_processing = UnknownProcessingOptions::kAttributeOnly; + options.enable_select_optimization = GetEnableSelectOptimization(); + std::unique_ptr builder = + CreateCelExpressionBuilder(options); + + ASSERT_OK(builder->GetRegistry()->Register(CreateBindFunction())); + + // Register builtins and configure the execution environment. + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + + // Create CelExpression from AST (Expr object). + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression( + &expr.expr(), &expr.source_info())); + Arena arena; + Activation activation; + activation.set_unknown_attribute_patterns({AttributePattern( + "msg", {AttributeQualifierPattern::OfString("child"), + AttributeQualifierPattern::OfString("payload"), + AttributeQualifierPattern::OfString("single_int64")})}); + + NestedTestAllTypes msg; + msg.mutable_child()->mutable_child()->mutable_payload()->set_single_int64(42); + + activation.InsertValue("msg", CelProtoWrapper::CreateMessage(&msg, &arena)); + + // Run evaluation. + ASSERT_OK_AND_ASSIGN(CelValue out, cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(out.IsUnknownSet()) << out.DebugString(); + EXPECT_THAT(out.UnknownSetOrDie()->unknown_attributes(), + testing::ElementsAre(Attribute( + "msg", {AttributeQualifier::OfString("child"), + AttributeQualifier::OfString("payload"), + AttributeQualifier::OfString("single_int64")}))); +} + +TEST_P(BindingsExtInteractionsTest, MissingAttribute) { + std::vector all_macros = Macro::AllMacros(); + std::vector bindings_macros = cel::extensions::bindings_macros(); + all_macros.insert(all_macros.end(), bindings_macros.begin(), + bindings_macros.end()); + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, ParseWithMacros( + R"( + cel.bind( + x, + msg.child.payload.single_int64, + x < 42 || 1 == 2))", + all_macros)); + + InterpreterOptions options; + options.enable_empty_wrapper_null_unboxing = true; + options.enable_missing_attribute_errors = true; + options.enable_select_optimization = GetEnableSelectOptimization(); + std::unique_ptr builder = + CreateCelExpressionBuilder(options); + + ASSERT_OK(builder->GetRegistry()->Register(CreateBindFunction())); + + // Register builtins and configure the execution environment. + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + + // Create CelExpression from AST (Expr object). + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression( + &expr.expr(), &expr.source_info())); + Arena arena; + Activation activation; + activation.set_missing_attribute_patterns({AttributePattern( + "msg", {AttributeQualifierPattern::OfString("child"), + AttributeQualifierPattern::OfString("payload"), + AttributeQualifierPattern::OfString("single_int64")})}); + + NestedTestAllTypes msg; + msg.mutable_child()->mutable_child()->mutable_payload()->set_single_int64(42); + + activation.InsertValue("msg", CelProtoWrapper::CreateMessage(&msg, &arena)); + + // Run evaluation. + ASSERT_OK_AND_ASSIGN(CelValue out, cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(out.IsError()) << out.DebugString(); + EXPECT_THAT(out.ErrorOrDie()->ToString(), + HasSubstr("msg.child.payload.single_int64")); +} + +INSTANTIATE_TEST_SUITE_P(BindingsExtInteractionsTest, + BindingsExtInteractionsTest, + /*enable_select_optimization=*/testing::Bool()); + +} // namespace +} // namespace cel::extensions diff --git a/extensions/comprehensions_v2.cc b/extensions/comprehensions_v2.cc new file mode 100644 index 000000000..486369c1e --- /dev/null +++ b/extensions/comprehensions_v2.cc @@ -0,0 +1,72 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/comprehensions_v2.h" + +#include "absl/base/no_destructor.h" +#include "absl/status/status.h" +#include "checker/internal/builtins_arena.h" +#include "checker/type_checker_builder.h" +#include "common/decl.h" +#include "common/type.h" +#include "compiler/compiler.h" +#include "extensions/comprehensions_v2_macros.h" +#include "internal/status_macros.h" +#include "parser/parser_interface.h" + +using ::cel::checker_internal::BuiltinsArena; + +namespace cel::extensions { + +namespace { + +// Arbitrary type parameter name A. +TypeParamType TypeParamA() { return TypeParamType("A"); } + +// Arbitrary type parameter name B. +TypeParamType TypeParamB() { return TypeParamType("B"); } + +Type MapOfAB() { + static absl::NoDestructor kInstance( + MapType(BuiltinsArena(), TypeParamA(), TypeParamB())); + return *kInstance; +} + +absl::Status AddComprehensionsV2Functions(TypeCheckerBuilder& builder) { + FunctionDecl map_insert; + map_insert.set_name("cel.@mapInsert"); + CEL_RETURN_IF_ERROR(map_insert.AddOverload( + MakeOverloadDecl("@mapInsert_map_key_value", MapOfAB(), MapOfAB(), + TypeParamA(), TypeParamB()))); + CEL_RETURN_IF_ERROR(map_insert.AddOverload( + MakeOverloadDecl("@mapInsert_map_map", MapOfAB(), MapOfAB(), MapOfAB()))); + return builder.AddFunction(map_insert); +} + +absl::Status ConfigureParser(ParserBuilder& parser_builder) { + return RegisterComprehensionsV2Macros(parser_builder); +} + +} // namespace + +CompilerLibrary ComprehensionsV2CompilerLibrary() { + return CompilerLibrary("cel.lib.ext.comprev2", &ConfigureParser, + &AddComprehensionsV2Functions); +} + +CheckerLibrary ComprehensionsV2CheckerLibrary() { + return CheckerLibrary{"cel.lib.ext.comprev2", &AddComprehensionsV2Functions}; +} + +} // namespace cel::extensions diff --git a/extensions/comprehensions_v2.h b/extensions/comprehensions_v2.h new file mode 100644 index 000000000..94f984708 --- /dev/null +++ b/extensions/comprehensions_v2.h @@ -0,0 +1,39 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_COMPREHENSIONS_V2_H_ +#define THIRD_PARTY_CEL_CPP_EXTENSIONS_COMPREHENSIONS_V2_H_ + +#include "absl/status/status.h" +#include "checker/type_checker_builder.h" +#include "compiler/compiler.h" +#include "extensions/comprehensions_v2_functions.h" // IWYU pragma: export +#include "parser/macro_registry.h" +#include "parser/options.h" + +namespace cel::extensions { + +// Registers the macros defined by the comprehension v2 extension. +absl::Status RegisterComprehensionsV2Macros(MacroRegistry& registry, + const ParserOptions& options); + +// Declarations for the comprehensions v2 extension library. +CompilerLibrary ComprehensionsV2CompilerLibrary(); + +// Declarations for the comprehensions v2 extension library. +CheckerLibrary ComprehensionsV2CheckerLibrary(); + +} // namespace cel::extensions + +#endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_COMPREHENSIONS_V2_H_ diff --git a/extensions/comprehensions_v2_functions.cc b/extensions/comprehensions_v2_functions.cc new file mode 100644 index 000000000..bf23780c0 --- /dev/null +++ b/extensions/comprehensions_v2_functions.cc @@ -0,0 +1,148 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/comprehensions_v2_functions.h" + +#include + +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "common/value.h" +#include "common/values/map_value_builder.h" +#include "eval/public/cel_function_registry.h" +#include "eval/public/cel_options.h" +#include "internal/status_macros.h" +#include "runtime/function_adapter.h" +#include "runtime/function_registry.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel::extensions { + +namespace { + +absl::StatusOr MapInsertKeyValue( + const MapValue& map, const Value& key, const Value& value, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + if (auto mutable_map_value = common_internal::AsMutableMapValue(map); + mutable_map_value) { + // Fast path, runtime has given us a mutable map. We can mutate it directly + // and return it. + CEL_RETURN_IF_ERROR(mutable_map_value->Put(key, value)) + .With(ErrorValueReturn()); + return map; + } + // Slow path, we have to make a copy. + auto builder = NewMapValueBuilder(arena); + if (auto size = map.Size(); size.ok()) { + builder->Reserve(*size + 1); + } else { + size.IgnoreError(); + } + CEL_RETURN_IF_ERROR( + map.ForEach( + [&builder](const Value& key, + const Value& value) -> absl::StatusOr { + CEL_RETURN_IF_ERROR(builder->Put(key, value)); + return true; + }, + descriptor_pool, message_factory, arena)) + .With(ErrorValueReturn()); + CEL_RETURN_IF_ERROR(builder->Put(key, value)).With(ErrorValueReturn()); + return std::move(*builder).Build(); +} + +absl::StatusOr MapInsertMap( + const MapValue& map, const MapValue& value, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + if (auto mutable_map_value = common_internal::AsMutableMapValue(map); + mutable_map_value) { + // Fast path, runtime has given us a mutable map. We can mutate it directly + // and return it. + CEL_RETURN_IF_ERROR( + value.ForEach( + [&mutable_map_value](const Value& key, + const Value& value) -> absl::StatusOr { + CEL_RETURN_IF_ERROR(mutable_map_value->Put(key, value)); + return true; + }, + descriptor_pool, message_factory, arena)) + .With(ErrorValueReturn()); + return map; + } + // Slow path, we have to make a copy. + auto builder = NewMapValueBuilder(arena); + if (auto size = map.Size(); size.ok()) { + builder->Reserve(*size + 1); + } else { + size.IgnoreError(); + } + CEL_RETURN_IF_ERROR( + map.ForEach( + [&builder](const Value& key, + const Value& value) -> absl::StatusOr { + CEL_RETURN_IF_ERROR(builder->Put(key, value)); + return true; + }, + descriptor_pool, message_factory, arena)) + .With(ErrorValueReturn()); + CEL_RETURN_IF_ERROR( + value.ForEach( + [&builder](const Value& key, + const Value& value) -> absl::StatusOr { + CEL_RETURN_IF_ERROR(builder->Put(key, value)); + return true; + }, + descriptor_pool, message_factory, arena)) + .With(ErrorValueReturn()); + return std::move(*builder).Build(); +} + +} // namespace + +absl::Status RegisterComprehensionsV2Functions(FunctionRegistry& registry, + const RuntimeOptions& options) { + CEL_RETURN_IF_ERROR(registry.Register( + TernaryFunctionAdapter, MapValue, Value, + Value>::CreateDescriptor("cel.@mapInsert", + /*receiver_style=*/false), + TernaryFunctionAdapter, MapValue, Value, + Value>::WrapFunction(&MapInsertKeyValue))); + + CEL_RETURN_IF_ERROR(registry.Register( + BinaryFunctionAdapter, MapValue, MapValue>:: + CreateDescriptor("cel.@mapInsert", + /*receiver_style=*/false), + BinaryFunctionAdapter, MapValue, + MapValue>::WrapFunction(&MapInsertMap))); + + return absl::OkStatus(); +} + +absl::Status RegisterComprehensionsV2Functions( + google::api::expr::runtime::CelFunctionRegistry* registry, + const google::api::expr::runtime::InterpreterOptions& options) { + return RegisterComprehensionsV2Functions( + registry->InternalGetRegistry(), + google::api::expr::runtime::ConvertToRuntimeOptions(options)); +} + +} // namespace cel::extensions diff --git a/extensions/comprehensions_v2_functions.h b/extensions/comprehensions_v2_functions.h new file mode 100644 index 000000000..8f99780a2 --- /dev/null +++ b/extensions/comprehensions_v2_functions.h @@ -0,0 +1,35 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_COMPREHENSIONS_V2_FUNCTIONS_H_ +#define THIRD_PARTY_CEL_CPP_EXTENSIONS_COMPREHENSIONS_V2_FUNCTIONS_H_ + +#include "absl/status/status.h" +#include "eval/public/cel_function_registry.h" +#include "eval/public/cel_options.h" +#include "runtime/function_registry.h" +#include "runtime/runtime_options.h" + +namespace cel::extensions { + +// Register comprehension v2 functions. +absl::Status RegisterComprehensionsV2Functions(FunctionRegistry& registry, + const RuntimeOptions& options); +absl::Status RegisterComprehensionsV2Functions( + google::api::expr::runtime::CelFunctionRegistry* registry, + const google::api::expr::runtime::InterpreterOptions& options); + +} // namespace cel::extensions + +#endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_COMPREHENSIONS_V2_FUNCTIONS_H_ diff --git a/extensions/comprehensions_v2_macros.cc b/extensions/comprehensions_v2_macros.cc new file mode 100644 index 000000000..a054626f9 --- /dev/null +++ b/extensions/comprehensions_v2_macros.cc @@ -0,0 +1,571 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/comprehensions_v2_macros.h" + +#include +#include +#include + +#include "absl/base/no_destructor.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "common/expr.h" +#include "common/operators.h" +#include "internal/status_macros.h" +#include "parser/macro.h" +#include "parser/macro_expr_factory.h" +#include "parser/macro_registry.h" +#include "parser/options.h" +#include "parser/parser_interface.h" + +namespace cel::extensions { + +namespace { + +using ::google::api::expr::common::CelOperator; + +bool IsSimpleIdentifier(const Expr& expr) { + return expr.has_ident_expr() && !expr.ident_expr().name().empty() && + !absl::StartsWith(expr.ident_expr().name(), "."); +} + +absl::optional ExpandAllMacro2(MacroExprFactory& factory, Expr& target, + absl::Span args) { + if (args.size() != 3) { + return factory.ReportError("all() requires 3 arguments"); + } + if (!IsSimpleIdentifier(args[0])) { + return factory.ReportErrorAt( + args[0], "all() first variable name must be a simple identifier"); + } + if (!IsSimpleIdentifier(args[1])) { + return factory.ReportErrorAt( + args[1], "all() second variable name must be a simple identifier"); + } + if (args[0].ident_expr().name() == args[1].ident_expr().name()) { + return factory.ReportErrorAt( + args[0], + "all() second variable must be different from the first variable"); + } + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { + return factory.ReportErrorAt( + args[0], absl::StrCat("all() first variable name cannot be ", + kDeprecatedAccumulatorVariableName)); + } + if (args[1].ident_expr().name() == kDeprecatedAccumulatorVariableName) { + return factory.ReportErrorAt( + args[1], absl::StrCat("all() second variable name cannot be ", + kDeprecatedAccumulatorVariableName)); + } + auto init = factory.NewBoolConst(true); + auto condition = + factory.NewCall(CelOperator::NOT_STRICTLY_FALSE, factory.NewAccuIdent()); + auto step = factory.NewCall(CelOperator::LOGICAL_AND, factory.NewAccuIdent(), + std::move(args[2])); + auto result = factory.NewAccuIdent(); + return factory.NewComprehension( + args[0].ident_expr().name(), args[1].ident_expr().name(), + std::move(target), factory.AccuVarName(), std::move(init), + std::move(condition), std::move(step), std::move(result)); +} + +Macro MakeAllMacro2() { + auto status_or_macro = Macro::Receiver(CelOperator::ALL, 3, ExpandAllMacro2); + ABSL_CHECK_OK(status_or_macro); // Crash OK + return std::move(*status_or_macro); +} + +absl::optional ExpandExistsMacro2(MacroExprFactory& factory, Expr& target, + absl::Span args) { + if (args.size() != 3) { + return factory.ReportError("exists() requires 3 arguments"); + } + if (!IsSimpleIdentifier(args[0])) { + return factory.ReportErrorAt( + args[0], "exists() first variable name must be a simple identifier"); + } + if (!IsSimpleIdentifier(args[1])) { + return factory.ReportErrorAt( + args[1], "exists() second variable name must be a simple identifier"); + } + if (args[0].ident_expr().name() == args[1].ident_expr().name()) { + return factory.ReportErrorAt( + args[0], + "exists() second variable must be different from the first variable"); + } + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { + return factory.ReportErrorAt( + args[0], absl::StrCat("exists() first variable name cannot be ", + kDeprecatedAccumulatorVariableName)); + } + if (args[1].ident_expr().name() == kDeprecatedAccumulatorVariableName) { + return factory.ReportErrorAt( + args[1], absl::StrCat("exists() second variable name cannot be ", + kDeprecatedAccumulatorVariableName)); + } + auto init = factory.NewBoolConst(false); + auto condition = factory.NewCall( + CelOperator::NOT_STRICTLY_FALSE, + factory.NewCall(CelOperator::LOGICAL_NOT, factory.NewAccuIdent())); + auto step = factory.NewCall(CelOperator::LOGICAL_OR, factory.NewAccuIdent(), + std::move(args[2])); + auto result = factory.NewAccuIdent(); + return factory.NewComprehension( + args[0].ident_expr().name(), args[1].ident_expr().name(), + std::move(target), factory.AccuVarName(), std::move(init), + std::move(condition), std::move(step), std::move(result)); +} + +Macro MakeExistsMacro2() { + auto status_or_macro = + Macro::Receiver(CelOperator::EXISTS, 3, ExpandExistsMacro2); + ABSL_CHECK_OK(status_or_macro); // Crash OK + return std::move(*status_or_macro); +} + +absl::optional ExpandExistsOneMacro2(MacroExprFactory& factory, + Expr& target, + absl::Span args) { + if (args.size() != 3) { + return factory.ReportError("existsOne() requires 3 arguments"); + } + if (!IsSimpleIdentifier(args[0])) { + return factory.ReportErrorAt( + args[0], "existsOne() first variable name must be a simple identifier"); + } + if (!IsSimpleIdentifier(args[1])) { + return factory.ReportErrorAt( + args[1], + "existsOne() second variable name must be a simple identifier"); + } + if (args[0].ident_expr().name() == args[1].ident_expr().name()) { + return factory.ReportErrorAt( + args[0], + "existsOne() second variable must be different " + "from the first variable"); + } + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { + return factory.ReportErrorAt( + args[0], absl::StrCat("existsOne() first variable name cannot be ", + kDeprecatedAccumulatorVariableName)); + } + if (args[1].ident_expr().name() == kDeprecatedAccumulatorVariableName) { + return factory.ReportErrorAt( + args[1], absl::StrCat("existsOne() second variable name cannot be ", + kDeprecatedAccumulatorVariableName)); + } + auto init = factory.NewIntConst(0); + auto condition = factory.NewBoolConst(true); + auto step = + factory.NewCall(CelOperator::CONDITIONAL, std::move(args[2]), + factory.NewCall(CelOperator::ADD, factory.NewAccuIdent(), + factory.NewIntConst(1)), + factory.NewAccuIdent()); + auto result = factory.NewCall(CelOperator::EQUALS, factory.NewAccuIdent(), + factory.NewIntConst(1)); + return factory.NewComprehension( + args[0].ident_expr().name(), args[1].ident_expr().name(), + std::move(target), factory.AccuVarName(), std::move(init), + std::move(condition), std::move(step), std::move(result)); +} + +Macro MakeExistsOneMacro2() { + auto status_or_macro = Macro::Receiver("existsOne", 3, ExpandExistsOneMacro2); + ABSL_CHECK_OK(status_or_macro); // Crash OK + return std::move(*status_or_macro); +} + +absl::optional ExpandTransformList3Macro(MacroExprFactory& factory, + Expr& target, + absl::Span args) { + if (args.size() != 3) { + return factory.ReportError("transformList() requires 3 arguments"); + } + if (!IsSimpleIdentifier(args[0])) { + return factory.ReportErrorAt( + args[0], + "transformList() first variable name must be a simple identifier"); + } + if (!IsSimpleIdentifier(args[1])) { + return factory.ReportErrorAt( + args[1], + "transformList() second variable name must be a simple identifier"); + } + if (args[0].ident_expr().name() == args[1].ident_expr().name()) { + return factory.ReportErrorAt(args[0], + "transformList() second variable must be " + "different from the first variable"); + } + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { + return factory.ReportErrorAt( + args[0], absl::StrCat("transformList() first variable name cannot be ", + kDeprecatedAccumulatorVariableName)); + } + if (args[1].ident_expr().name() == kDeprecatedAccumulatorVariableName) { + return factory.ReportErrorAt( + args[1], absl::StrCat("transformList() second variable name cannot be ", + kDeprecatedAccumulatorVariableName)); + } + std::string iter_var = args[0].ident_expr().name(); + std::string iter_var2 = args[1].ident_expr().name(); + Expr step = factory.NewCall( + CelOperator::ADD, factory.NewAccuIdent(), + factory.NewList(factory.NewListElement(std::move(args[2])))); + return factory.NewComprehension(std::move(iter_var), std::move(iter_var2), + std::move(target), factory.AccuVarName(), + factory.NewList(), factory.NewBoolConst(true), + std::move(step), factory.NewAccuIdent()); +} + +Macro MakeTransformList3Macro() { + auto status_or_macro = + Macro::Receiver("transformList", 3, ExpandTransformList3Macro); + ABSL_CHECK_OK(status_or_macro); // Crash OK + return std::move(*status_or_macro); +} + +absl::optional ExpandTransformList4Macro(MacroExprFactory& factory, + Expr& target, + absl::Span args) { + if (args.size() != 4) { + return factory.ReportError("transformList() requires 4 arguments"); + } + if (!IsSimpleIdentifier(args[0])) { + return factory.ReportErrorAt( + args[0], + "transformList() first variable name must be a simple identifier"); + } + if (!IsSimpleIdentifier(args[1])) { + return factory.ReportErrorAt( + args[1], + "transformList() second variable name must be a simple identifier"); + } + if (args[0].ident_expr().name() == args[1].ident_expr().name()) { + return factory.ReportErrorAt(args[0], + "transformList() second variable must be " + "different from the first variable"); + } + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { + return factory.ReportErrorAt( + args[0], absl::StrCat("transformList() first variable name cannot be ", + kDeprecatedAccumulatorVariableName)); + } + if (args[1].ident_expr().name() == kDeprecatedAccumulatorVariableName) { + return factory.ReportErrorAt( + args[1], absl::StrCat("transformList() second variable name cannot be ", + kDeprecatedAccumulatorVariableName)); + } + std::string iter_var = args[0].ident_expr().name(); + std::string iter_var2 = args[1].ident_expr().name(); + Expr step = factory.NewCall( + CelOperator::ADD, factory.NewAccuIdent(), + factory.NewList(factory.NewListElement(std::move(args[3])))); + step = factory.NewCall(CelOperator::CONDITIONAL, std::move(args[2]), + std::move(step), factory.NewAccuIdent()); + return factory.NewComprehension(std::move(iter_var), std::move(iter_var2), + std::move(target), factory.AccuVarName(), + factory.NewList(), factory.NewBoolConst(true), + std::move(step), factory.NewAccuIdent()); +} + +Macro MakeTransformList4Macro() { + auto status_or_macro = + Macro::Receiver("transformList", 4, ExpandTransformList4Macro); + ABSL_CHECK_OK(status_or_macro); // Crash OK + return std::move(*status_or_macro); +} + +absl::optional ExpandTransformMap3Macro(MacroExprFactory& factory, + Expr& target, + absl::Span args) { + if (args.size() != 3) { + return factory.ReportError("transformMap() requires 3 arguments"); + } + if (!IsSimpleIdentifier(args[0])) { + return factory.ReportErrorAt( + args[0], + "transformMap() first variable name must be a simple identifier"); + } + if (!IsSimpleIdentifier(args[1])) { + return factory.ReportErrorAt( + args[1], + "transformMap() second variable name must be a simple identifier"); + } + if (args[0].ident_expr().name() == args[1].ident_expr().name()) { + return factory.ReportErrorAt(args[0], + "transformMap() second variable must be " + "different from the first variable"); + } + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { + return factory.ReportErrorAt( + args[0], absl::StrCat("transformMap() first variable name cannot be ", + kDeprecatedAccumulatorVariableName)); + } + if (args[1].ident_expr().name() == kDeprecatedAccumulatorVariableName) { + return factory.ReportErrorAt( + args[1], absl::StrCat("transformMap() second variable name cannot be ", + kDeprecatedAccumulatorVariableName)); + } + std::string iter_var = args[0].ident_expr().name(); + std::string iter_var2 = args[1].ident_expr().name(); + Expr step = factory.NewCall("cel.@mapInsert", factory.NewAccuIdent(), + std::move(args[0]), std::move(args[2])); + return factory.NewComprehension(std::move(iter_var), std::move(iter_var2), + std::move(target), factory.AccuVarName(), + factory.NewMap(), factory.NewBoolConst(true), + std::move(step), factory.NewAccuIdent()); +} + +Macro MakeTransformMap3Macro() { + auto status_or_macro = + Macro::Receiver("transformMap", 3, ExpandTransformMap3Macro); + ABSL_CHECK_OK(status_or_macro); // Crash OK + return std::move(*status_or_macro); +} + +absl::optional ExpandTransformMap4Macro(MacroExprFactory& factory, + Expr& target, + absl::Span args) { + if (args.size() != 4) { + return factory.ReportError("transformMap() requires 4 arguments"); + } + if (!IsSimpleIdentifier(args[0])) { + return factory.ReportErrorAt( + args[0], + "transformMap() first variable name must be a simple identifier"); + } + if (!IsSimpleIdentifier(args[1])) { + return factory.ReportErrorAt( + args[1], + "transformMap() second variable name must be a simple identifier"); + } + if (args[0].ident_expr().name() == args[1].ident_expr().name()) { + return factory.ReportErrorAt(args[0], + "transformMap() second variable must be " + "different from the first variable"); + } + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { + return factory.ReportErrorAt( + args[0], absl::StrCat("transformMap() first variable name cannot be ", + kDeprecatedAccumulatorVariableName)); + } + if (args[1].ident_expr().name() == kDeprecatedAccumulatorVariableName) { + return factory.ReportErrorAt( + args[1], absl::StrCat("transformMap() second variable name cannot be ", + kDeprecatedAccumulatorVariableName)); + } + std::string iter_var = args[0].ident_expr().name(); + std::string iter_var2 = args[1].ident_expr().name(); + Expr step = factory.NewCall("cel.@mapInsert", factory.NewAccuIdent(), + std::move(args[0]), std::move(args[3])); + step = factory.NewCall(CelOperator::CONDITIONAL, std::move(args[2]), + std::move(step), factory.NewAccuIdent()); + return factory.NewComprehension(std::move(iter_var), std::move(iter_var2), + std::move(target), factory.AccuVarName(), + factory.NewMap(), factory.NewBoolConst(true), + std::move(step), factory.NewAccuIdent()); +} + +Macro MakeTransformMap4Macro() { + auto status_or_macro = + Macro::Receiver("transformMap", 4, ExpandTransformMap4Macro); + ABSL_CHECK_OK(status_or_macro); // Crash OK + return std::move(*status_or_macro); +} + +absl::optional ExpandTransformMapEntry3Macro(MacroExprFactory& factory, + Expr& target, + absl::Span args) { + if (args.size() != 3) { + return factory.ReportError("transformMapEntry() requires 3 arguments"); + } + if (!IsSimpleIdentifier(args[0])) { + return factory.ReportErrorAt( + args[0], + "transformMapEntry() first variable name must be a simple identifier"); + } + if (!IsSimpleIdentifier(args[1])) { + return factory.ReportErrorAt( + args[1], + "transformMapEntry() second variable name must be a simple identifier"); + } + if (args[0].ident_expr().name() == args[1].ident_expr().name()) { + return factory.ReportErrorAt(args[0], + "transformMapEntry() second variable must be " + "different from the first variable"); + } + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { + return factory.ReportErrorAt( + args[0], + absl::StrCat("transformMapEntry() first variable name cannot be ", + kDeprecatedAccumulatorVariableName)); + } + if (args[1].ident_expr().name() == kDeprecatedAccumulatorVariableName) { + return factory.ReportErrorAt( + args[1], + absl::StrCat("transformMapEntry() second variable name cannot be ", + kDeprecatedAccumulatorVariableName)); + } + std::string iter_var = args[0].ident_expr().name(); + std::string iter_var2 = args[1].ident_expr().name(); + Expr step = factory.NewCall("cel.@mapInsert", factory.NewAccuIdent(), + std::move(args[2])); + return factory.NewComprehension(std::move(iter_var), std::move(iter_var2), + std::move(target), factory.AccuVarName(), + factory.NewMap(), factory.NewBoolConst(true), + std::move(step), factory.NewAccuIdent()); +} + +Macro MakeTransformMap3EntryMacro() { + auto status_or_macro = + Macro::Receiver("transformMapEntry", 3, ExpandTransformMapEntry3Macro); + ABSL_CHECK_OK(status_or_macro); // Crash OK + return std::move(*status_or_macro); +} + +absl::optional ExpandTransformMapEntry4Macro(MacroExprFactory& factory, + Expr& target, + absl::Span args) { + if (args.size() != 4) { + return factory.ReportError("transformMapEntry() requires 4 arguments"); + } + if (!IsSimpleIdentifier(args[0])) { + return factory.ReportErrorAt( + args[0], + "transformMapEntry() first variable name must be a simple identifier"); + } + if (!IsSimpleIdentifier(args[1])) { + return factory.ReportErrorAt( + args[1], + "transformMapEntry() second variable name must be a simple identifier"); + } + if (args[0].ident_expr().name() == args[1].ident_expr().name()) { + return factory.ReportErrorAt(args[0], + "transformMapEntry() second variable must be " + "different from the first variable"); + } + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { + return factory.ReportErrorAt( + args[0], + absl::StrCat("transformMapEntry() first variable name cannot be ", + kDeprecatedAccumulatorVariableName)); + } + if (args[1].ident_expr().name() == kDeprecatedAccumulatorVariableName) { + return factory.ReportErrorAt( + args[1], + absl::StrCat("transformMapEntry() second variable name cannot be ", + kDeprecatedAccumulatorVariableName)); + } + std::string iter_var = args[0].ident_expr().name(); + std::string iter_var2 = args[1].ident_expr().name(); + Expr step = factory.NewCall("cel.@mapInsert", factory.NewAccuIdent(), + std::move(args[3])); + step = factory.NewCall(CelOperator::CONDITIONAL, std::move(args[2]), + std::move(step), factory.NewAccuIdent()); + return factory.NewComprehension(std::move(iter_var), std::move(iter_var2), + std::move(target), factory.AccuVarName(), + factory.NewMap(), factory.NewBoolConst(true), + std::move(step), factory.NewAccuIdent()); +} + +Macro MakeTransformMapEntry4Macro() { + auto status_or_macro = + Macro::Receiver("transformMapEntry", 4, ExpandTransformMapEntry4Macro); + ABSL_CHECK_OK(status_or_macro); // Crash OK + return std::move(*status_or_macro); +} + +const Macro& AllMacro2() { + static const absl::NoDestructor macro(MakeAllMacro2()); + return *macro; +} + +const Macro& ExistsMacro2() { + static const absl::NoDestructor macro(MakeExistsMacro2()); + return *macro; +} + +const Macro& ExistsOneMacro2() { + static const absl::NoDestructor macro(MakeExistsOneMacro2()); + return *macro; +} + +const Macro& TransformList3Macro() { + static const absl::NoDestructor macro(MakeTransformList3Macro()); + return *macro; +} + +const Macro& TransformList4Macro() { + static const absl::NoDestructor macro(MakeTransformList4Macro()); + return *macro; +} + +const Macro& TransformMap3Macro() { + static const absl::NoDestructor macro(MakeTransformMap3Macro()); + return *macro; +} + +const Macro& TransformMap4Macro() { + static const absl::NoDestructor macro(MakeTransformMap4Macro()); + return *macro; +} + +const Macro& TransformMapEntry3Macro() { + static const absl::NoDestructor macro(MakeTransformMap3EntryMacro()); + return *macro; +} + +const Macro& TransformMapEntry4Macro() { + static const absl::NoDestructor macro(MakeTransformMapEntry4Macro()); + return *macro; +} + +} // namespace + +std::vector AllMacros() { + return {AllMacro2(), + ExistsMacro2(), + ExistsOneMacro2(), + TransformList3Macro(), + TransformList4Macro(), + TransformMap3Macro(), + TransformMap4Macro(), + TransformMapEntry3Macro(), + TransformMapEntry4Macro()}; +} + +// Registers the macros defined by the comprehension v2 extension. +absl::Status RegisterComprehensionsV2Macros(MacroRegistry& registry, + const ParserOptions&) { + for (const Macro& macro : AllMacros()) { + CEL_RETURN_IF_ERROR(registry.RegisterMacro(macro)); + } + + return absl::OkStatus(); +} + +absl::Status RegisterComprehensionsV2Macros(ParserBuilder& parser_builder) { + for (const Macro& macro : AllMacros()) { + CEL_RETURN_IF_ERROR(parser_builder.AddMacro(macro)); + } + + return absl::OkStatus(); +} + +} // namespace cel::extensions diff --git a/extensions/comprehensions_v2_macros.h b/extensions/comprehensions_v2_macros.h new file mode 100644 index 000000000..fed6e9284 --- /dev/null +++ b/extensions/comprehensions_v2_macros.h @@ -0,0 +1,34 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_COMPREHENSIONS_V2_MACROS_H_ +#define THIRD_PARTY_CEL_CPP_EXTENSIONS_COMPREHENSIONS_V2_MACROS_H_ + +#include "absl/status/status.h" +#include "compiler/compiler.h" +#include "parser/macro_registry.h" +#include "parser/options.h" + +namespace cel::extensions { + +// Registers the macros defined by the comprehension v2 extension. +absl::Status RegisterComprehensionsV2Macros(MacroRegistry& registry, + const ParserOptions& options); + +// Registers the macros defined by the comprehension v2 extension. +absl::Status RegisterComprehensionsV2Macros(ParserBuilder& parser_builder); + +} // namespace cel::extensions + +#endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_COMPREHENSIONS_V2_MACROS_H_ diff --git a/extensions/comprehensions_v2_test.cc b/extensions/comprehensions_v2_test.cc new file mode 100644 index 000000000..25645af5c --- /dev/null +++ b/extensions/comprehensions_v2_test.cc @@ -0,0 +1,575 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/comprehensions_v2.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "checker/standard_library.h" +#include "checker/validation_result.h" +#include "common/value_testing.h" +#include "common/values/list_value_builder.h" +#include "common/values/map_value_builder.h" +#include "compiler/compiler_factory.h" +#include "compiler/optional.h" +#include "extensions/bindings_ext.h" +#include "extensions/comprehensions_v2_functions.h" +#include "extensions/strings.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "runtime/activation.h" +#include "runtime/optional_types.h" +#include "runtime/runtime.h" +#include "runtime/runtime_options.h" +#include "runtime/standard_runtime_builder_factory.h" +#include "google/protobuf/arena.h" + +namespace cel::extensions { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::IsOkAndHolds; +using ::absl_testing::StatusIs; +using ::cel::test::BoolValueIs; +using ::cel::test::ErrorValueIs; +using ::testing::HasSubstr; +using ::testing::TestWithParam; + +absl::StatusOr> CreateProgram( + const std::string& expression, bool enable_mutable_accumulator, + int max_recursion_depth) { + // Configure the compiler + CEL_ASSIGN_OR_RETURN( + auto compiler_builder, + NewCompilerBuilder(internal::GetTestingDescriptorPool())); + CEL_RETURN_IF_ERROR(compiler_builder->AddLibrary(StandardCheckerLibrary())); + CEL_RETURN_IF_ERROR(compiler_builder->AddLibrary(OptionalCompilerLibrary())); + CEL_RETURN_IF_ERROR(compiler_builder->AddLibrary(BindingsCompilerLibrary())); + CEL_RETURN_IF_ERROR(compiler_builder->AddLibrary(StringsCompilerLibrary())); + CEL_RETURN_IF_ERROR(compiler_builder->AddLibrary( + extensions::ComprehensionsV2CompilerLibrary())); + + CEL_ASSIGN_OR_RETURN(auto compiler, std::move(*compiler_builder).Build()); + + // Configure the runtime + cel::RuntimeOptions options; + options.enable_qualified_type_identifiers = true; + options.enable_comprehension_list_append = enable_mutable_accumulator; + options.enable_comprehension_mutable_map = enable_mutable_accumulator; + options.max_recursion_depth = max_recursion_depth; + + CEL_ASSIGN_OR_RETURN(auto runtime_builder, + CreateStandardRuntimeBuilder( + internal::GetTestingDescriptorPool(), options)); + CEL_RETURN_IF_ERROR(EnableOptionalTypes(runtime_builder)); + CEL_RETURN_IF_ERROR( + RegisterStringsFunctions(runtime_builder.function_registry(), options)); + CEL_RETURN_IF_ERROR(RegisterComprehensionsV2Functions( + runtime_builder.function_registry(), options)); + CEL_ASSIGN_OR_RETURN(std::unique_ptr runtime, + std::move(runtime_builder).Build()); + + CEL_ASSIGN_OR_RETURN(ValidationResult result, compiler->Compile(expression)); + if (!result.IsValid()) { + return absl::Status(absl::StatusCode::kInvalidArgument, + result.FormatError()); + } + return runtime->CreateProgram(*result.ReleaseAst()); +} + +struct TestOptions { + bool enable_mutable_accumulator; + int max_recursion_depth; +}; + +struct ComprehensionsV2TestCase { + std::string expression; + absl::StatusCode expected_status_code = absl::StatusCode::kOk; + std::string expected_error; +}; + +class ComprehensionsV2Test + : public TestWithParam> { +}; + +TEST_P(ComprehensionsV2Test, Basic) { + const ComprehensionsV2TestCase& test_case = std::get<0>(GetParam()); + const TestOptions& options = std::get<1>(GetParam()); + + absl::StatusOr> program = + CreateProgram(test_case.expression, options.enable_mutable_accumulator, + options.max_recursion_depth); + + if (!program.ok()) { + EXPECT_THAT(program, StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr(test_case.expected_error))); + // The error is expected. Nothing more to do in this test case + return; + } + + ASSERT_THAT(program, IsOk()); + + google::protobuf::Arena arena; + Activation activation; + + if (test_case.expected_status_code == absl::StatusCode::kOk) { + EXPECT_THAT(program.value()->Evaluate(&arena, activation), + IsOkAndHolds(BoolValueIs(true))) + << test_case.expression; + } else { + EXPECT_THAT(program.value()->Evaluate(&arena, activation), + IsOkAndHolds(ErrorValueIs(StatusIs( + test_case.expected_status_code, test_case.expected_error)))) + << test_case.expression; + } +} + +INSTANTIATE_TEST_SUITE_P( + ComprehensionsV2Test, ComprehensionsV2Test, + ::testing::Combine( + ::testing::ValuesIn({ + // list.all() + {.expression = "[1, 2, 3, 4].all(i, v, i < 5 && v > 0)"}, + {.expression = "[1, 2, 3, 4].all(i, v, i < v)"}, + {.expression = "[1, 2, 3, 4].all(i, v, i > v) == false"}, + { + .expression = + R"cel(cel.bind(listA, [1, 2, 3, 4], cel.bind(listB, [1, 2, 3, 4, 5], listA.all(i, v, listB[?i].hasValue() && listB[i] == v))))cel", + }, + { + .expression = + R"cel(cel.bind(listA, [1, 2, 3, 4, 5, 6], cel.bind(listB, [1, 2, 3, 4, 5], listA.all(i, v, listB[?i].hasValue() && listB[i] == v))) == false)cel", + }, + { + .expression = "[].all(__result__, v, v == 0)", + .expected_error = "variable name cannot be __result__", + }, + { + .expression = "[].all(__result__, v, v == 0)", + .expected_error = "variable name cannot be __result__", + }, + { + .expression = "[].all(i, __result__, i == 0)", + .expected_error = "variable name cannot be __result__", + }, + { + .expression = "[].all(e, e, e == e)", + .expected_error = + "second variable must be different from the first variable", + }, + { + .expression = "[].all(foo.bar, e, true)", + .expected_error = + "first variable name must be a simple identifier", + }, + { + .expression = "[].all(e, foo.bar, true)", + .expected_error = + "second variable name must be a simple identifier", + }, + + // list.exists() + { + .expression = + R"cel(cel.bind(l, ['hello', 'world', 'hello!', 'worlds'], l.exists(i, v, v.startsWith('hello') && l[?(i+1)].optMap(next, next.endsWith('world')).orValue(false))))cel", + }, + { + .expression = "[].exists(__result__, v, v == 0)", + .expected_error = "variable name cannot be __result__", + }, + { + .expression = "[].exists(i, __result__, i == 0)", + .expected_error = "variable name cannot be __result__", + }, + { + .expression = "[].exists(e, e, e == e)", + .expected_error = + "second variable must be different from the first variable", + }, + { + .expression = "[].exists(foo.bar, e, true)", + .expected_error = + "first variable name must be a simple identifier", + }, + { + .expression = "[].exists(e, foo.bar, true)", + .expected_error = + "second variable name must be a simple identifier", + }, + // list.existsOne() + { + .expression = + R"cel(cel.bind(l, ['hello', 'world', 'hello!', 'worlds'], l.existsOne(i, v, v.startsWith('hello') && l[?(i+1)].optMap(next, next.endsWith('world')).orValue(false))))cel", + }, + { + .expression = + R"cel(cel.bind(l, ['hello', 'goodbye', 'hello!', 'goodbye'], l.existsOne(i, v, v.startsWith('hello') && l[?(i+1)].optMap(next, next == 'goodbye').orValue(false))) == false)cel", + }, + { + .expression = "[].existsOne(__result__, v, v == 0)", + .expected_error = "variable name cannot be __result__", + }, + { + .expression = "[].existsOne(i, __result__, i == 0)", + .expected_error = "variable name cannot be __result__", + }, + { + .expression = "[].existsOne(e, e, e == e)", + .expected_error = + "second variable must be different from the first variable", + }, + { + .expression = "[].existsOne(foo.bar, e, true)", + .expected_error = + "first variable name must be a simple identifier", + }, + { + .expression = "[].existsOne(e, foo.bar, true)", + .expected_error = + "second variable name must be a simple identifier", + }, + // list.transformList() + { + .expression = + R"cel(['Hello', 'world'].transformList(i, v, '[' + string(i) + ']' + v.lowerAscii()) == ['[0]hello', '[1]world'])cel", + }, + { + .expression = + R"cel(['hello', 'world'].transformList(i, v, v.startsWith('greeting'), '[' + string(i) + ']' + v) == [])cel", + }, + { + .expression = + R"cel([1, 2, 3].transformList(indexVar, valueVar, (indexVar * valueVar) + valueVar) == [1, 4, 9])cel", + }, + { + .expression = + R"cel([1, 2, 3].transformList(indexVar, valueVar, indexVar % 2 == 0, (indexVar * valueVar) + valueVar) == [1, 9])cel", + }, + { + .expression = "[].transformList(__result__, v, v)", + .expected_error = "variable name cannot be __result__", + }, + { + .expression = "[].transformList(i, __result__, v)", + .expected_error = "variable name cannot be __result__", + }, + { + .expression = "[].transformList(e, e, e)", + .expected_error = + "second variable must be different from the first variable", + }, + { + .expression = "[].transformList(foo.bar, e, e)", + .expected_error = + "first variable name must be a simple identifier", + }, + { + .expression = "[].transformList(e, foo.bar, e)", + .expected_error = + "second variable name must be a simple identifier", + }, + { + .expression = "[].transformList(__result__, v, v == 0, v)", + .expected_error = "variable name cannot be __result__", + }, + { + .expression = "[].transformList(i, __result__, i == 0, v)", + .expected_error = "variable name cannot be __result__", + }, + { + .expression = "[].transformList(e, e, e == e, e)", + .expected_error = + "second variable must be different from the first variable", + }, + { + .expression = "[].transformList(foo.bar, e, true, e)", + .expected_error = + "first variable name must be a simple identifier", + }, + { + .expression = "[].transformList(e, foo.bar, true, e)", + .expected_error = + "second variable name must be a simple identifier", + }, + // list.transformMap() + { + .expression = + R"cel(['Hello', 'world'].transformMap(i, v, [v.lowerAscii()]) == {0: ['hello'], 1: ['world']})cel", + }, + { + .expression = + R"cel([1, 2, 3].transformMap(indexVar, valueVar, (indexVar * valueVar) + valueVar) == {0: 1, 1: 4, 2: 9})cel", + }, + { + .expression = + R"cel([1, 2, 3].transformMap(indexVar, valueVar, indexVar % 2 == 0, (indexVar * valueVar) + valueVar) == {0: 1, 2: 9})cel", + }, + // map.all() + { + .expression = + R"cel({'hello': 'world', 'hello!': 'world'}.all(k, v, k.startsWith('hello') && v == 'world'))cel", + }, + { + .expression = + R"cel({'hello': 'world', 'hello!': 'worlds'}.all(k, v, k.startsWith('hello') && v.endsWith('world')) == false)cel", + }, + // map.exists() + { + .expression = + R"cel({'hello': 'world', 'hello!': 'worlds'}.exists(k, v, k.startsWith('hello') && v.endsWith('world')))cel", + }, + // map.existsOne() + { + .expression = + R"cel({'hello': 'world', 'hello!': 'worlds'}.existsOne(k, v, k.startsWith('hello') && v.endsWith('world')))cel", + }, + { + .expression = + R"cel({'hello': 'world', 'hello!': 'wow, world'}.existsOne(k, v, k.startsWith('hello') && v.endsWith('world')) == false)cel", + }, + // map.transformList() + { + .expression = + R"cel({'Hello': 'world'}.transformList(k, v, k.lowerAscii() + "=" + v) == ['hello=world'])cel", + }, + { + .expression = + R"cel({'hello': 'world'}.transformList(k, v, k.startsWith('greeting'), k + "=" + v) == [])cel", + }, + { + .expression = + R"cel(cel.bind(m, {'farewell': 'goodbye', 'greeting': 'hello'}.transformList(k, _, k), m == ['farewell', 'greeting'] || m == ['greeting', 'farewell']))cel", + }, + { + .expression = + R"cel(cel.bind(m, {'greeting': 'hello', 'farewell': 'goodbye'}.transformList(_, v, v), m == ['goodbye', 'hello'] || m == ['hello', 'goodbye']))cel", + }, + // map.transformMap() + { + .expression = + R"cel({'hello': 'world', 'goodbye': 'cruel world'}.transformMap(k, v, k + ', ' + v + '!') == {'hello': 'hello, world!', 'goodbye': 'goodbye, cruel world!'})cel", + }, + { + .expression = + R"cel({'hello': 'world', 'goodbye': 'cruel world'}.transformMap(k, v, v.startsWith('world'), k + ", " + v + "!") == {'hello': 'hello, world!'})cel", + }, + { + .expression = "{}.transformMap(__result__, v, v)", + .expected_error = "variable name cannot be __result__", + }, + { + .expression = "{}.transformMap(k, __result__, v)", + .expected_error = "variable name cannot be __result__", + }, + { + .expression = "{}.transformMap(e, e, e)", + .expected_error = + "second variable must be different from the first variable", + }, + { + .expression = "{}.transformMap(foo.bar, e, e)", + .expected_error = + "first variable name must be a simple identifier", + }, + { + .expression = "{}.transformMap(e, foo.bar, e)", + .expected_error = + "second variable name must be a simple identifier", + }, + { + .expression = "{}.transformMap(__result__, v, v == 0, v)", + .expected_error = "variable name cannot be __result__", + }, + { + .expression = "{}.transformMap(k, __result__, k == 0, v)", + .expected_error = "variable name cannot be __result__", + }, + { + .expression = "{}.transformMap(e, e, e == e, e)", + .expected_error = + "second variable must be different from the first variable", + }, + { + .expression = "{}.transformMap(foo.bar, e, true, e)", + .expected_error = + "first variable name must be a simple identifier", + }, + { + .expression = "{}.transformMap(e, foo.bar, true, e)", + .expected_error = + "second variable name must be a simple identifier", + }, + // map.transformMapEntry + { + .expression = + R"cel({'hello': 'world', 'greetings': 'tacocat'}.transformMapEntry(k, v, {v: k}) == {'world': 'hello', 'tacocat': 'greetings'})cel", + }, + { + .expression = + R"cel({'hello': 'world', 'greetings': 'tacocat'}.transformMapEntry(k, v, {}) == {})cel", + }, + { + .expression = + R"cel({'a': 'same', 'c': 'same'}.transformMapEntry(k, v, {v: k}))cel", + .expected_status_code = absl::StatusCode::kAlreadyExists, + .expected_error = "duplicate key in map", + }, + { + .expression = "{}.transformMapEntry(__result__, v, v)", + .expected_error = "variable name cannot be __result__", + }, + { + .expression = "{}.transformMapEntry(k, __result__, v)", + .expected_error = "variable name cannot be __result__", + }, + { + .expression = "{}.transformMapEntry(e, e, e)", + .expected_error = + "second variable must be different from the first variable", + }, + { + .expression = "{}.transformMapEntry(foo.bar, e, e)", + .expected_error = + "first variable name must be a simple identifier", + }, + { + .expression = "{}.transformMapEntry(e, foo.bar, e)", + .expected_error = + "second variable name must be a simple identifier", + }, + // transformMapEntry(k, v, filter, expr) + { + .expression = + R"cel({'hello': 'world', 'same': 'same'}.transformMapEntry(k, v, k != v, {v: k}) == {'world': 'hello'})cel", + }, + { + .expression = "{}.transformMapEntry(__result__, v, v == 0, v)", + .expected_error = "variable name cannot be __result__", + }, + { + .expression = "{}.transformMapEntry(k, __result__, k == 0, v)", + .expected_error = "variable name cannot be __result__", + }, + { + .expression = "{}.transformMapEntry(e, e, e == e, e)", + .expected_error = + "second variable must be different from the first variable", + }, + { + .expression = "{}.transformMapEntry(foo.bar, e, true, e)", + .expected_error = + "first variable name must be a simple identifier", + }, + { + .expression = "{}.transformMapEntry(e, foo.bar, true, e)", + .expected_error = + "second variable name must be a simple identifier", + }, + // list.transformMapEntry + { + .expression = + R"cel(['one', 'two'].transformMapEntry(k, v, {k + 1: 'is ' + v}) == {1: 'is one', 2: 'is two'})cel", + }, + }), + ::testing::ValuesIn({ + { + .enable_mutable_accumulator = true, + .max_recursion_depth = 0, + }, + { + .enable_mutable_accumulator = false, + .max_recursion_depth = 0, + }, + { + .enable_mutable_accumulator = true, + .max_recursion_depth = -1, + }, + { + .enable_mutable_accumulator = false, + .max_recursion_depth = -1, + }, + }))); + +class ComprehensionsV2TestMutableAccumulator + : public TestWithParam> { +}; + +TEST_P(ComprehensionsV2TestMutableAccumulator, MutableAccumulator) { + const ComprehensionsV2TestCase& test_case = std::get<0>(GetParam()); + const TestOptions& options = std::get<1>(GetParam()); + + ASSERT_OK_AND_ASSIGN( + std::unique_ptr program, + CreateProgram(test_case.expression, options.enable_mutable_accumulator, + options.max_recursion_depth)); + + google::protobuf::Arena arena; + Activation activation; + ASSERT_OK_AND_ASSIGN(auto result, program->Evaluate(&arena, activation)); + bool is_mutable_accumulator = common_internal::IsMutableListValue(result) || + common_internal::IsMutableMapValue(result); + EXPECT_EQ(is_mutable_accumulator, options.enable_mutable_accumulator); +} + +INSTANTIATE_TEST_SUITE_P( + ComprehensionsV2Test, ComprehensionsV2TestMutableAccumulator, + ::testing::Combine( + ::testing::ValuesIn({ + {.expression = + R"cel(['Hello', 'world'].transformList(i, v, i))cel"}, + { + .expression = + R"cel({'hello': 'world'}.transformMap(k, v, k + v))cel", + }, + { + .expression = + R"cel(['hello', 'world'].transformMap(k, v, v))cel", + }, + { + .expression = + R"cel({'hello': 'world'}.transformMapEntry(k, v, {v: k}))cel", + }, + { + .expression = + R"cel(['hello', 'world'].transformMapEntry(k, v, {v: k}))cel", + }, + }), + ::testing::ValuesIn({ + { + .enable_mutable_accumulator = true, + .max_recursion_depth = 0, + }, + { + .enable_mutable_accumulator = false, + .max_recursion_depth = 0, + }, + { + .enable_mutable_accumulator = true, + .max_recursion_depth = -1, + }, + { + .enable_mutable_accumulator = false, + .max_recursion_depth = -1, + }, + }))); + +} // namespace +} // namespace cel::extensions diff --git a/extensions/encoders.cc b/extensions/encoders.cc new file mode 100644 index 000000000..66431b30b --- /dev/null +++ b/extensions/encoders.cc @@ -0,0 +1,119 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/encoders.h" + +#include +#include + +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/escaping.h" +#include "checker/type_checker_builder.h" +#include "common/decl.h" +#include "common/type.h" +#include "common/value.h" +#include "compiler/compiler.h" +#include "eval/public/cel_function_registry.h" +#include "eval/public/cel_options.h" +#include "internal/status_macros.h" +#include "runtime/function_adapter.h" +#include "runtime/function_registry.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel::extensions { + +namespace { + +absl::StatusOr Base64Decode( + const StringValue& value, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + std::string in; + std::string out; + if (!absl::Base64Unescape(value.NativeString(in), &out)) { + return ErrorValue{absl::InvalidArgumentError("invalid base64 data")}; + } + return BytesValue(arena, std::move(out)); +} + +absl::StatusOr Base64Encode( + const BytesValue& value, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + std::string in; + std::string out; + out = absl::Base64Escape(value.NativeString(in)); + return StringValue(arena, std::move(out)); +} + +absl::Status RegisterEncodersDecls(TypeCheckerBuilder& builder) { + CEL_ASSIGN_OR_RETURN( + auto base64_decode_decl, + MakeFunctionDecl( + "base64.decode", + MakeOverloadDecl("base64_decode_string", BytesType(), StringType()))); + + CEL_ASSIGN_OR_RETURN( + auto base64_encode_decl, + MakeFunctionDecl( + "base64.encode", + MakeOverloadDecl("base64_encode_bytes", StringType(), BytesType()))); + + CEL_RETURN_IF_ERROR(builder.AddFunction(base64_decode_decl)); + CEL_RETURN_IF_ERROR(builder.AddFunction(base64_encode_decl)); + return absl::OkStatus(); +} + +} // namespace + +absl::Status RegisterEncodersFunctions(FunctionRegistry& registry, + const RuntimeOptions&) { + CEL_RETURN_IF_ERROR(registry.Register( + UnaryFunctionAdapter, + StringValue>::CreateDescriptor("base64.decode", + false), + UnaryFunctionAdapter, StringValue>::WrapFunction( + &Base64Decode))); + CEL_RETURN_IF_ERROR(registry.Register( + UnaryFunctionAdapter, BytesValue>::CreateDescriptor( + "base64.encode", false), + UnaryFunctionAdapter, BytesValue>::WrapFunction( + &Base64Encode))); + return absl::OkStatus(); +} + +absl::Status RegisterEncodersFunctions( + google::api::expr::runtime::CelFunctionRegistry* absl_nonnull registry, + const google::api::expr::runtime::InterpreterOptions& options) { + return RegisterEncodersFunctions( + registry->InternalGetRegistry(), + google::api::expr::runtime::ConvertToRuntimeOptions(options)); +} + +CheckerLibrary EncodersCheckerLibrary() { + return {"cel.lib.ext.encoders", &RegisterEncodersDecls}; +} + +CompilerLibrary EncodersCompilerLibrary() { + return CompilerLibrary::FromCheckerLibrary(EncodersCheckerLibrary()); +} + +} // namespace cel::extensions diff --git a/extensions/encoders.h b/extensions/encoders.h new file mode 100644 index 000000000..2187f7fc6 --- /dev/null +++ b/extensions/encoders.h @@ -0,0 +1,45 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_ENCODERS_H_ +#define THIRD_PARTY_CEL_CPP_EXTENSIONS_ENCODERS_H_ + +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "checker/type_checker_builder.h" +#include "compiler/compiler.h" +#include "eval/public/cel_function_registry.h" +#include "eval/public/cel_options.h" +#include "runtime/function_registry.h" +#include "runtime/runtime_options.h" + +namespace cel::extensions { + +// Register encoders functions. +absl::Status RegisterEncodersFunctions(FunctionRegistry& registry, + const RuntimeOptions& options); + +absl::Status RegisterEncodersFunctions( + google::api::expr::runtime::CelFunctionRegistry* absl_nonnull registry, + const google::api::expr::runtime::InterpreterOptions& options); + +// Declarations for the encoders extension library. +CheckerLibrary EncodersCheckerLibrary(); + +// Compiler library for the encoders extension. +CompilerLibrary EncodersCompilerLibrary(); + +} // namespace cel::extensions + +#endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_ENCODERS_H_ diff --git a/extensions/encoders_test.cc b/extensions/encoders_test.cc new file mode 100644 index 000000000..c95588e29 --- /dev/null +++ b/extensions/encoders_test.cc @@ -0,0 +1,91 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/encoders.h" + +#include +#include +#include + +#include "absl/status/status_matchers.h" +#include "checker/standard_library.h" +#include "checker/validation_result.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "runtime/activation.h" +#include "runtime/runtime.h" +#include "runtime/runtime_options.h" +#include "runtime/standard_runtime_builder_factory.h" +#include "google/protobuf/arena.h" + +namespace cel::extensions { +namespace { + +using ::absl_testing::IsOk; + +struct TestCase { + std::string expr; +}; + +class EncodersTest : public ::testing::TestWithParam {}; + +TEST_P(EncodersTest, ParseCheckEval) { + const TestCase& test_case = GetParam(); + + // Configure the compiler. + ASSERT_OK_AND_ASSIGN( + auto compiler_builder, + NewCompilerBuilder(internal::GetTestingDescriptorPool())); + ASSERT_THAT(compiler_builder->AddLibrary(StandardCheckerLibrary()), IsOk()); + ASSERT_THAT( + compiler_builder->AddLibrary(extensions::EncodersCheckerLibrary()), + IsOk()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, + std::move(*compiler_builder).Build()); + + // Configure the runtime. + cel::RuntimeOptions runtime_options; + ASSERT_OK_AND_ASSIGN( + auto runtime_builder, + CreateStandardRuntimeBuilder(internal::GetTestingDescriptorPool(), + runtime_options)); + ASSERT_THAT(RegisterEncodersFunctions(runtime_builder.function_registry(), + runtime_options), + IsOk()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, + std::move(runtime_builder).Build()); + + // Compile, plan, evaluate. + ASSERT_OK_AND_ASSIGN(ValidationResult result, + compiler->Compile(test_case.expr)); + ASSERT_TRUE(result.IsValid()); + ASSERT_OK_AND_ASSIGN(auto program, + runtime->CreateProgram(*result.ReleaseAst())); + + google::protobuf::Arena arena; + Activation activation; + ASSERT_OK_AND_ASSIGN(auto value, program->Evaluate(&arena, activation)); + ASSERT_TRUE(value.IsBool()); + ASSERT_TRUE(value.GetBool()); +} + +INSTANTIATE_TEST_SUITE_P( + EncodersTest, EncodersTest, + testing::Values(TestCase{"base64.encode(b'hello') == 'aGVsbG8='"}, + TestCase{"base64.decode('aGVsbG8=') == b'hello'"})); + +} // namespace +} // namespace cel::extensions diff --git a/extensions/formatting.cc b/extensions/formatting.cc new file mode 100644 index 000000000..252fdc7bd --- /dev/null +++ b/extensions/formatting.cc @@ -0,0 +1,570 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/formatting.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/container/btree_map.h" +#include "absl/numeric/bits.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/ascii.h" +#include "absl/strings/escaping.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "common/value.h" +#include "common/value_kind.h" +#include "internal/status_macros.h" +#include "runtime/function_adapter.h" +#include "runtime/function_registry.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel::extensions { + +namespace { + +static constexpr int32_t kNanosPerMillisecond = 1000000; +static constexpr int32_t kNanosPerMicrosecond = 1000; +static constexpr int kMaxPrecision = 1000; + +absl::StatusOr FormatString( + const Value& value, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND); + +absl::StatusOr>> ParsePrecision( + absl::string_view format, int max_precision) { + if (format.empty() || format[0] != '.') return std::pair{0, std::nullopt}; + + int64_t i = 1; + while (i < format.size() && absl::ascii_isdigit(format[i])) { + ++i; + } + if (i == format.size()) { + return absl::InvalidArgumentError( + "unable to find end of precision specifier"); + } + int precision; + if (!absl::SimpleAtoi(format.substr(1, i - 1), &precision)) { + return absl::InvalidArgumentError( + "unable to convert precision specifier to integer"); + } + if (precision > max_precision) { + return absl::InvalidArgumentError( + absl::StrCat("precision specifier exceeds maximum of ", max_precision)); + } + return std::pair{i, precision}; +} + +absl::StatusOr FormatDuration( + const Value& value, std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) { + absl::Duration duration = value.GetDuration(); + if (duration == absl::ZeroDuration()) { + return "0s"; + } + if (duration < absl::ZeroDuration()) { + scratch.append("-"); + duration = absl::AbsDuration(duration); + } + int64_t seconds = absl::ToInt64Seconds(duration); + absl::StrAppend(&scratch, seconds); + int64_t nanos = absl::ToInt64Nanoseconds(duration - absl::Seconds(seconds)); + if (nanos != 0) { + scratch.append("."); + if (nanos % kNanosPerMillisecond == 0) { + scratch.append(absl::StrFormat("%03d", nanos / kNanosPerMillisecond)); + } else if (nanos % kNanosPerMicrosecond == 0) { + scratch.append(absl::StrFormat("%06d", nanos / kNanosPerMicrosecond)); + } else { + scratch.append(absl::StrFormat("%09d", nanos)); + } + } + scratch.append("s"); + return scratch; +} + +absl::StatusOr FormatDouble( + double value, std::optional precision, bool use_scientific_notation, + std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) { + static constexpr int kDefaultPrecision = 6; + if (std::isnan(value)) { + return "NaN"; + } else if (value == std::numeric_limits::infinity()) { + return "Infinity"; + } else if (value == -std::numeric_limits::infinity()) { + return "-Infinity"; + } + auto format = absl::StrCat("%.", precision.value_or(kDefaultPrecision), + use_scientific_notation ? "e" : "f"); + if (use_scientific_notation) { + scratch = absl::StrFormat(*absl::ParsedFormat<'e'>::New(format), value); + } else { + scratch = absl::StrFormat(*absl::ParsedFormat<'f'>::New(format), value); + } + return scratch; +} + +absl::StatusOr FormatList( + const Value& value, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) { + CEL_ASSIGN_OR_RETURN(auto it, value.GetList().NewIterator()); + scratch.clear(); + scratch.push_back('['); + std::string value_scratch; + + while (it->HasNext()) { + CEL_ASSIGN_OR_RETURN(auto next, + it->Next(descriptor_pool, message_factory, arena)); + absl::string_view next_str; + value_scratch.clear(); + CEL_ASSIGN_OR_RETURN( + next_str, FormatString(next, descriptor_pool, message_factory, arena, + value_scratch)); + absl::StrAppend(&scratch, next_str); + absl::StrAppend(&scratch, ", "); + } + if (scratch.size() > 1) { + scratch.resize(scratch.size() - 2); + } + scratch.push_back(']'); + return scratch; +} + +absl::StatusOr FormatMap( + const Value& value, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) { + absl::btree_map value_map; + std::string value_scratch; + CEL_RETURN_IF_ERROR(value.GetMap().ForEach( + [&](const Value& key, const Value& value) -> absl::StatusOr { + if (key.kind() != ValueKind::kString && + key.kind() != ValueKind::kBool && key.kind() != ValueKind::kInt && + key.kind() != ValueKind::kUint) { + return absl::InvalidArgumentError( + absl::StrCat("map keys must be strings, booleans, integers, or " + "unsigned integers, was given ", + key.GetTypeName())); + } + value_scratch.clear(); + CEL_ASSIGN_OR_RETURN(auto key_str, + FormatString(key, descriptor_pool, message_factory, + arena, value_scratch)); + value_map.emplace(key_str, value); + return true; + }, + descriptor_pool, message_factory, arena)); + + scratch.clear(); + scratch.push_back('{'); + for (const auto& [key, value] : value_map) { + value_scratch.clear(); + CEL_ASSIGN_OR_RETURN(auto value_str, + FormatString(value, descriptor_pool, message_factory, + arena, value_scratch)); + absl::StrAppend(&scratch, key, ": "); + absl::StrAppend(&scratch, value_str); + absl::StrAppend(&scratch, ", "); + } + if (scratch.size() > 1) { + scratch.resize(scratch.size() - 2); + } + scratch.push_back('}'); + return scratch; +} + +absl::StatusOr FormatString( + const Value& value, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) { + switch (value.kind()) { + case ValueKind::kList: + return FormatList(value, descriptor_pool, message_factory, arena, + scratch); + case ValueKind::kMap: + return FormatMap(value, descriptor_pool, message_factory, arena, scratch); + case ValueKind::kString: + return value.GetString().NativeString(scratch); + case ValueKind::kBytes: + return value.GetBytes().NativeString(scratch); + case ValueKind::kNull: + return "null"; + case ValueKind::kInt: + absl::StrAppend(&scratch, value.GetInt().NativeValue()); + return scratch; + case ValueKind::kUint: + absl::StrAppend(&scratch, value.GetUint().NativeValue()); + return scratch; + case ValueKind::kDouble: { + auto number = value.GetDouble().NativeValue(); + if (std::isnan(number)) { + return "NaN"; + } + if (number == std::numeric_limits::infinity()) { + return "Infinity"; + } + if (number == -std::numeric_limits::infinity()) { + return "-Infinity"; + } + absl::StrAppend(&scratch, number); + return scratch; + } + case ValueKind::kTimestamp: + absl::StrAppend(&scratch, value.DebugString()); + return scratch; + case ValueKind::kDuration: + return FormatDuration(value, scratch); + case ValueKind::kBool: + if (value.GetBool().NativeValue()) { + return "true"; + } + return "false"; + case ValueKind::kType: + return value.GetType().name(); + default: + return absl::InvalidArgumentError(absl::StrFormat( + "could not convert argument %s to string", value.GetTypeName())); + } +} + +absl::StatusOr FormatDecimal( + const Value& value, std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) { + scratch.clear(); + switch (value.kind()) { + case ValueKind::kInt: + absl::StrAppend(&scratch, value.GetInt().NativeValue()); + return scratch; + case ValueKind::kUint: + absl::StrAppend(&scratch, value.GetUint().NativeValue()); + return scratch; + case ValueKind::kDouble: + return FormatDouble(value.GetDouble().NativeValue(), + /*precision=*/std::nullopt, + /*use_scientific_notation=*/false, scratch); + default: + return absl::InvalidArgumentError( + absl::StrCat("decimal clause can only be used on numbers, was given ", + value.GetTypeName())); + } +} + +absl::StatusOr FormatBinary( + const Value& value, std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) { + decltype(value.GetUint().NativeValue()) unsigned_value; + bool sign_bit = false; + switch (value.kind()) { + case ValueKind::kInt: { + auto tmp = value.GetInt().NativeValue(); + if (tmp < 0) { + sign_bit = true; + // Negating min int is undefined behavior, so we need to use unsigned + // arithmetic. + using unsigned_type = std::make_unsigned::type; + unsigned_value = -static_cast(tmp); + } else { + unsigned_value = tmp; + } + break; + } + case ValueKind::kUint: + unsigned_value = value.GetUint().NativeValue(); + break; + case ValueKind::kBool: + if (value.GetBool().NativeValue()) { + return "1"; + } + return "0"; + default: + return absl::InvalidArgumentError(absl::StrCat( + "binary clause can only be used on integers and bools, was given ", + value.GetTypeName())); + } + + if (unsigned_value == 0) { + return "0"; + } + + int size = absl::bit_width(unsigned_value) + sign_bit; + scratch.resize(size); + for (int i = size - 1; i >= 0; --i) { + if (unsigned_value & 1) { + scratch[i] = '1'; + } else { + scratch[i] = '0'; + } + unsigned_value >>= 1; + } + if (sign_bit) { + scratch[0] = '-'; + } + return scratch; +} + +absl::StatusOr FormatHex( + const Value& value, bool use_upper_case, + std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) { + switch (value.kind()) { + case ValueKind::kString: + scratch = absl::BytesToHexString(value.GetString().NativeString(scratch)); + break; + case ValueKind::kBytes: + scratch = absl::BytesToHexString(value.GetBytes().NativeString(scratch)); + break; + case ValueKind::kInt: { + // Golang supports signed hex, but absl::StrFormat does not. To be + // compatible, we need to add a leading '-' if the value is negative. + auto tmp = value.GetInt().NativeValue(); + if (tmp < 0) { + // Negating min int is undefined behavior, so we need to use unsigned + // arithmetic. + using unsigned_type = std::make_unsigned::type; + scratch = absl::StrFormat("-%x", -static_cast(tmp)); + } else { + scratch = absl::StrFormat("%x", tmp); + } + break; + } + case ValueKind::kUint: + scratch = absl::StrFormat("%x", value.GetUint().NativeValue()); + break; + default: + return absl::InvalidArgumentError( + absl::StrCat("hex clause can only be used on integers, byte buffers, " + "and strings, was given ", + value.GetTypeName())); + } + if (use_upper_case) { + absl::AsciiStrToUpper(&scratch); + } + return scratch; +} + +absl::StatusOr FormatOctal( + const Value& value, std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) { + switch (value.kind()) { + case ValueKind::kInt: { + // Golang supports signed octals, but absl::StrFormat does not. To be + // compatible, we need to add a leading '-' if the value is negative. + auto tmp = value.GetInt().NativeValue(); + if (tmp < 0) { + // Negating min int is undefined behavior, so we need to use unsigned + // arithmetic. + using unsigned_type = std::make_unsigned::type; + scratch = absl::StrFormat("-%o", -static_cast(tmp)); + } else { + scratch = absl::StrFormat("%o", tmp); + } + return scratch; + } + case ValueKind::kUint: + scratch = absl::StrFormat("%o", value.GetUint().NativeValue()); + return scratch; + default: + return absl::InvalidArgumentError( + absl::StrCat("octal clause can only be used on integers, was given ", + value.GetTypeName())); + } +} + +absl::StatusOr GetDouble(const Value& value, std::string& scratch) { + if (value.kind() == ValueKind::kString) { + auto str = value.GetString().NativeString(scratch); + if (str == "NaN") { + return std::nan(""); + } else if (str == "Infinity") { + return std::numeric_limits::infinity(); + } else if (str == "-Infinity") { + return -std::numeric_limits::infinity(); + } else { + return absl::InvalidArgumentError( + absl::StrCat("only \"NaN\", \"Infinity\", and \"-Infinity\" are " + "supported for conversion to double: ", + str)); + } + } + if (value.kind() == ValueKind::kInt) { + return static_cast(value.GetInt().NativeValue()); + } + if (value.kind() == ValueKind::kUint) { + return static_cast(value.GetUint().NativeValue()); + } + if (value.kind() != ValueKind::kDouble) { + return absl::InvalidArgumentError( + absl::StrCat("expected a double but got a ", value.GetTypeName())); + } + return value.GetDouble().NativeValue(); +} + +absl::StatusOr FormatFixed( + const Value& value, std::optional precision, + std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) { + CEL_ASSIGN_OR_RETURN(auto number, GetDouble(value, scratch)); + return FormatDouble(number, precision, + /*use_scientific_notation=*/false, scratch); +} + +absl::StatusOr FormatScientific( + const Value& value, std::optional precision, + std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) { + CEL_ASSIGN_OR_RETURN(auto number, GetDouble(value, scratch)); + return FormatDouble(number, precision, + /*use_scientific_notation=*/true, scratch); +} + +absl::StatusOr> ParseAndFormatClause( + absl::string_view format, const Value& value, int max_precision, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) { + CEL_ASSIGN_OR_RETURN(auto precision_pair, + ParsePrecision(format, max_precision)); + auto [read, precision] = precision_pair; + switch (format[read]) { + case 's': { + CEL_ASSIGN_OR_RETURN(auto result, + FormatString(value, descriptor_pool, message_factory, + arena, scratch)); + return std::pair{read, result}; + } + case 'd': { + CEL_ASSIGN_OR_RETURN(auto result, FormatDecimal(value, scratch)); + return std::pair{read, result}; + } + case 'f': { + CEL_ASSIGN_OR_RETURN(auto result, FormatFixed(value, precision, scratch)); + return std::pair{read, result}; + } + case 'e': { + CEL_ASSIGN_OR_RETURN(auto result, + FormatScientific(value, precision, scratch)); + return std::pair{read, result}; + } + case 'b': { + CEL_ASSIGN_OR_RETURN(auto result, FormatBinary(value, scratch)); + return std::pair{read, result}; + } + case 'x': + case 'X': { + CEL_ASSIGN_OR_RETURN( + auto result, + FormatHex(value, + /*use_upper_case=*/format[read] == 'X', scratch)); + return std::pair{read, result}; + } + case 'o': { + CEL_ASSIGN_OR_RETURN(auto result, FormatOctal(value, scratch)); + return std::pair{read, result}; + } + default: + return absl::InvalidArgumentError(absl::StrFormat( + "unrecognized formatting clause \"%c\"", format[read])); + } +} + +absl::StatusOr Format( + const StringValue& format_value, const ListValue& args, int max_precision, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + std::string format_scratch, clause_scratch; + absl::string_view format = format_value.NativeString(format_scratch); + std::string result; + result.reserve(format.size()); + int64_t arg_index = 0; + CEL_ASSIGN_OR_RETURN(int64_t args_size, args.Size()); + for (int64_t i = 0; i < format.size(); ++i) { + clause_scratch.clear(); + if (format[i] != '%') { + result.push_back(format[i]); + continue; + } + ++i; + if (i >= format.size()) { + return ErrorValue( + absl::InvalidArgumentError("unexpected end of format string")); + } + if (format[i] == '%') { + result.push_back('%'); + continue; + } + if (arg_index >= args_size) { + return ErrorValue(absl::InvalidArgumentError( + absl::StrFormat("index %d out of range", arg_index))); + } + CEL_ASSIGN_OR_RETURN(auto value, args.Get(arg_index++, descriptor_pool, + message_factory, arena)); + + auto clause = ParseAndFormatClause(format.substr(i), value, max_precision, + descriptor_pool, message_factory, arena, + clause_scratch); + if (!clause.ok()) { + return ErrorValue(std::move(clause).status()); + } + absl::StrAppend(&result, clause->second); + i += clause->first; + } + return StringValue::From(std::move(result), arena); +} + +} // namespace + +absl::Status RegisterStringFormattingFunctions( + FunctionRegistry& registry, const RuntimeOptions& options, + StringsExtensionFormatOptions format_options) { + const int max_precision = + std::clamp(format_options.max_precision, 0, kMaxPrecision); + CEL_RETURN_IF_ERROR(registry.Register( + BinaryFunctionAdapter, StringValue, ListValue>:: + CreateDescriptor("format", /*receiver_style=*/true), + BinaryFunctionAdapter, StringValue, ListValue>:: + WrapFunction( + [max_precision]( + const StringValue& format, const ListValue& args, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + return Format(format, args, max_precision, descriptor_pool, + message_factory, arena); + }))); + return absl::OkStatus(); +} + +} // namespace cel::extensions diff --git a/extensions/formatting.h b/extensions/formatting.h new file mode 100644 index 000000000..88954857b --- /dev/null +++ b/extensions/formatting.h @@ -0,0 +1,39 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_FORMATTING_H_ +#define THIRD_PARTY_CEL_CPP_EXTENSIONS_FORMATTING_H_ + +#include "absl/status/status.h" +#include "runtime/function_registry.h" +#include "runtime/runtime_options.h" + +namespace cel::extensions { + +struct StringsExtensionFormatOptions { + // The maximum precision to permit for formatting floating-point numbers. + int max_precision = 1000; +}; + +// Register extension functions for string formatting. +// +// This implements (string).format([args...]) in the strings extension. Most +// users should add these functions via `extensions/strings.h` instead. +absl::Status RegisterStringFormattingFunctions( + FunctionRegistry& registry, const RuntimeOptions& options, + StringsExtensionFormatOptions format_options = {}); + +} // namespace cel::extensions + +#endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_FORMATTING_H_ diff --git a/extensions/formatting_test.cc b/extensions/formatting_test.cc new file mode 100644 index 000000000..6a7fb300b --- /dev/null +++ b/extensions/formatting_test.cc @@ -0,0 +1,980 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/formatting.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "cel/expr/syntax.pb.h" +#include "absl/base/no_destructor.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status_matchers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "common/value.h" +#include "extensions/protobuf/runtime_adapter.h" +#include "internal/parse_text_proto.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "internal/testing_message_factory.h" +#include "parser/options.h" +#include "parser/parser.h" +#include "runtime/activation.h" +#include "runtime/runtime.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_options.h" +#include "runtime/standard_runtime_builder_factory.h" +#include "cel/expr/conformance/proto3/test_all_types.pb.h" +#include "google/protobuf/arena.h" + +namespace cel::extensions { +namespace { + +using ::absl_testing::IsOk; +using ::cel::expr::conformance::proto3::TestAllTypes; +using ::cel::expr::ParsedExpr; +using ::google::api::expr::parser::Parse; +using ::google::api::expr::parser::ParserOptions; +using ::testing::HasSubstr; +using ::testing::TestWithParam; +using ::testing::ValuesIn; + +using StringFormatLimitsTest = TestWithParam; + +// Check that formatted floating points are reversible. +TEST_P(StringFormatLimitsTest, FormatLimits) { + google::protobuf::Arena arena; + const RuntimeOptions options; + ASSERT_OK_AND_ASSIGN(auto builder, + CreateStandardRuntimeBuilder( + internal::GetTestingDescriptorPool(), options)); + ASSERT_THAT( + RegisterStringFormattingFunctions(builder.function_registry(), options), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, + Parse(GetParam(), "", ParserOptions{})); + ASSERT_OK_AND_ASSIGN(std::unique_ptr program, + ProtobufRuntimeAdapter::CreateProgram(*runtime, expr)); + Activation activation; + + static_assert(std::numeric_limits::min_exponent == -1021); + for (double x : { + 0x1p-1021, + 0x3p-1021, + std::numeric_limits::epsilon() * 0x1p-3, + std::numeric_limits::epsilon() * 0x7p-3, + 1.1 / 7.0 * 1e-101, + 1.2 / 7.0 * 1e-101, + }) { + activation.InsertOrAssignValue("x", DoubleValue(x)); + ASSERT_OK_AND_ASSIGN(Value value, program->Evaluate(&arena, activation)); + ASSERT_TRUE(value.Is()); + EXPECT_TRUE(value.GetBool().NativeValue()); + } +} + +TEST(StringFormatLimitsTest, MaxPrecisionOption) { + google::protobuf::Arena arena; + const RuntimeOptions options; + StringsExtensionFormatOptions format_options; + format_options.max_precision = 99; + ASSERT_OK_AND_ASSIGN(auto builder, + CreateStandardRuntimeBuilder( + internal::GetTestingDescriptorPool(), options)); + ASSERT_THAT(RegisterStringFormattingFunctions(builder.function_registry(), + options, format_options), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, Parse("'%.100f'.format([1.123])", + "", ParserOptions{})); + ASSERT_OK_AND_ASSIGN(std::unique_ptr program, + ProtobufRuntimeAdapter::CreateProgram(*runtime, expr)); + Activation activation; + + ASSERT_OK_AND_ASSIGN(Value value, program->Evaluate(&arena, activation)); + ASSERT_TRUE(value.Is()); + EXPECT_THAT(value.GetError().ToStatus().message(), + HasSubstr("precision specifier exceeds maximum of 99")); +} + +INSTANTIATE_TEST_SUITE_P(StringFormatLimitsTest, StringFormatLimitsTest, + ValuesIn({ + "double('%.326f'.format([x])) == x", + "double('%.17e'.format([x])) == x", + })); + +struct FormattingTestCase { + std::string name; + std::string format; + std::string format_args; + absl::flat_hash_map> + dyn_args; + std::string expected; + std::optional error = std::nullopt; +}; + +google::protobuf::Arena* GetTestArena() { + static absl::NoDestructor arena; + return &*arena; +} + +template +ParsedMessageValue MakeMessage(absl::string_view text) { + return ParsedMessageValue( + internal::DynamicParseTextProto(GetTestArena(), text, + internal::GetTestingDescriptorPool(), + internal::GetTestingMessageFactory()), + GetTestArena()); +} + +using StringFormatTest = TestWithParam; +TEST_P(StringFormatTest, TestStringFormatting) { + const FormattingTestCase& test_case = GetParam(); + google::protobuf::Arena arena; + const RuntimeOptions options; + ASSERT_OK_AND_ASSIGN(auto builder, + CreateStandardRuntimeBuilder( + internal::GetTestingDescriptorPool(), options)); + auto registration_status = + RegisterStringFormattingFunctions(builder.function_registry(), options); + if (test_case.error.has_value() && !registration_status.ok()) { + EXPECT_THAT(registration_status.message(), HasSubstr(*test_case.error)); + return; + } else { + ASSERT_THAT(registration_status, IsOk()); + } + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + auto expr_str = absl::StrFormat("'''%s'''.format([%s])", test_case.format, + test_case.format_args); + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, + Parse(expr_str, "", ParserOptions{})); + ASSERT_OK_AND_ASSIGN(std::unique_ptr program, + ProtobufRuntimeAdapter::CreateProgram(*runtime, expr)); + + Activation activation; + for (const auto& [name, value] : test_case.dyn_args) { + if (std::holds_alternative(value)) { + activation.InsertOrAssignValue(name, + StringValue{std::get(value)}); + } else if (std::holds_alternative(value)) { + activation.InsertOrAssignValue(name, BoolValue{std::get(value)}); + } else if (std::holds_alternative(value)) { + activation.InsertOrAssignValue(name, IntValue{std::get(value)}); + } else if (std::holds_alternative(value)) { + activation.InsertOrAssignValue(name, IntValue{std::get(value)}); + } else if (std::holds_alternative(value)) { + activation.InsertOrAssignValue(name, + UintValue{std::get(value)}); + } else if (std::holds_alternative(value)) { + activation.InsertOrAssignValue(name, + DoubleValue{std::get(value)}); + } else if (std::holds_alternative(value)) { + activation.InsertOrAssignValue( + name, DurationValue{std::get(value)}); + } else if (std::holds_alternative(value)) { + activation.InsertOrAssignValue( + name, TimestampValue{std::get(value)}); + } else if (std::holds_alternative(value)) { + activation.InsertOrAssignValue(name, std::get(value)); + } + } + auto result = program->Evaluate(&arena, activation); + if (test_case.error.has_value()) { + if (result.ok()) { + EXPECT_THAT(result->DebugString(), HasSubstr(*test_case.error)); + } else { + EXPECT_THAT(result.status().message(), HasSubstr(*test_case.error)); + } + } else { + if (!result.ok()) { + // Make it easier to debug the test case. + ASSERT_THAT(result.status().message(), ""); + // Make sure test case stops here. + ASSERT_TRUE(result.ok()); + } + ASSERT_TRUE(result->Is()); + EXPECT_THAT(result->GetString().ToString(), test_case.expected); + } +} + +INSTANTIATE_TEST_SUITE_P( + TestStringFormatting, StringFormatTest, + ValuesIn({ + { + .name = "Basic", + .format = "%s %s!", + .format_args = "'hello', 'world'", + .expected = "hello world!", + }, + { + .name = "EscapedPercentSign", + .format = "Percent sign %%!", + .format_args = "'hello', 'world'", + .expected = "Percent sign %!", + }, + { + .name = "IncompleteCase", + .format = "%", + .format_args = "'hello'", + .error = "unexpected end of format string", + }, + { + .name = "MissingFormatArg", + .format = "%s", + .format_args = "", + .error = "index 0 out of range", + }, + { + .name = "MissingFormatArg2", + .format = "%s, %s", + .format_args = "'hello'", + .error = "index 1 out of range", + }, + { + .name = "InvalidPrecision", + .format = "%.6", + .format_args = "'hello'", + .error = "unable to find end of precision specifier", + }, + { + .name = "InvalidPrecision2", + .format = "%.f", + .format_args = "'hello'", + .error = "unable to convert precision specifier to integer", + }, + { + .name = "InvalidPrecision3", + .format = "%.", + .format_args = "'hello'", + .error = "unable to find end of precision specifier", + }, + { + .name = "InvalidPrecisionOutOfRange", + .format = "%.1001f", + .format_args = "1.2345", + .error = "precision specifier exceeds maximum of 100", + }, + { + .name = "DecimalFormatingClause", + .format = "int %d, uint %d", + .format_args = "-1, uint(2)", + .expected = R"(int -1, uint 2)", + }, + { + .name = "OctalFormatingClause", + .format = "int %o, uint %o", + .format_args = "-10, uint(20)", + .expected = R"(int -12, uint 24)", + }, + { + .name = "OctalDoesNotWorkWithDouble", + .format = "double %o", + .format_args = "double(\"-Inf\")", + .error = + "octal clause can only be used on integers, was given double", + }, + { + .name = "HexFormatingClause", + .format = "int %x, uint %X, string %x, bytes %X", + .format_args = "-10, uint(255), 'hello', b'world'", + .expected = "int -a, uint FF, string 68656c6c6f, bytes 776F726C64", + }, + { + .name = "HexFormatingClauseLeadingZero", + .format = "string: %x", + .format_args = R"(b'\x00\x00hello\x00')", + .expected = "string: 000068656c6c6f00", + }, + { + .name = "HexDoesNotWorkWithDouble", + .format = "double %x", + .format_args = "double(\"-Inf\")", + .error = "hex clause can only be used on integers, byte buffers, " + "and strings, was given double", + }, + { + .name = "BinaryFormatingClause", + .format = "int %b, uint %b, bool %b, bool %b", + .format_args = "-32, uint(20), false, true", + .expected = "int -100000, uint 10100, bool 0, bool 1", + }, + { + .name = "BinaryFormatingClauseLimits", + .format = "min_int %b, max_int %b, max_uint %b", + .format_args = + absl::StrCat(std::numeric_limits::min(), ",", + std::numeric_limits::max(), ",", + std::numeric_limits::max(), "u"), + .expected = "min_int " + "-10000000000000000000000000000000000000000000000000000" + "00000000000, max_int " + "111111111111111111111111111111111111111111111111111111" + "111111111, max_uint " + "111111111111111111111111111111111111111111111111111111" + "1111111111", + }, + { + .name = "BinaryFormatingClauseZero", + .format = "zero %b", + .format_args = "0", + .expected = "zero 0", + }, + { + .name = "HexFormatingClauseLimits", + .format = "min_int %x, max_int %x, max_uint %x", + .format_args = + absl::StrCat(std::numeric_limits::min(), ",", + std::numeric_limits::max(), ",", + std::numeric_limits::max(), "u"), + .expected = "min_int -8000000000000000, max_int 7fffffffffffffff, " + "max_uint ffffffffffffffff", + }, + { + .name = "OctalFormatingClauseLimits", + .format = "min_int %o, max_int %o, max_uint %o", + .format_args = + absl::StrCat(std::numeric_limits::min(), ",", + std::numeric_limits::max(), ",", + std::numeric_limits::max(), "u"), + .expected = + "min_int -1000000000000000000000, max_int " + "777777777777777777777, max_uint 1777777777777777777777", + }, + { + .name = "FixedClauseFormatting", + .format = "%f", + .format_args = "10000.1234", + .expected = "10000.123400", + }, + { + .name = "FixedClauseFormattingWithPrecision", + .format = "%.2f", + .format_args = "10000.1234", + .expected = "10000.12", + }, + { + .name = "ListSupportForStringWithQuotes", + .format = "%s", + .format_args = R"(["a\"b","a\\b"])", + .expected = "[a\"b, a\\b]", + }, + { + .name = "ListSupportForStringWithDouble", + .format = "%s", + .format_args = + R"([double("NaN"),double("Infinity"), double("-Infinity")])", + .expected = "[NaN, Infinity, -Infinity]", + }, + FormattingTestCase{ + .name = "FixedClauseFormattingWithDynArgs", + .format = "%.2f %d", + .format_args = "arg, message.single_int32", + .dyn_args = + { + {"arg", 10000.1234}, + {"message", + MakeMessage(R"pb(single_int32: 42)pb")}, + }, + .expected = "10000.12 42", + }, + { + .name = "NoOp", + .format = "no substitution", + .expected = "no substitution", + }, + { + .name = "MidStringSubstitution", + .format = "str is %s and some more", + .format_args = "'filler'", + .expected = "str is filler and some more", + }, + { + .name = "PercentEscaping", + .format = "%% and also %%", + .expected = "% and also %", + }, + { + .name = "SubstitutionInsideEscapedPercentSigns", + .format = "%%%s%%", + .format_args = "'text'", + .expected = "%text%", + }, + { + .name = "SubstitutionWithOneEscapedPercentSignOnTheRight", + .format = "%s%%", + .format_args = "'percent on the right'", + .expected = "percent on the right%", + }, + { + .name = "SubstitutionWithOneEscapedPercentSignOnTheLeft", + .format = "%%%s", + .format_args = "'percent on the left'", + .expected = "%percent on the left", + }, + { + .name = "MultipleSubstitutions", + .format = "%d %d %d, %s %s %s, %d %d %d, %s %s %s", + .format_args = "1, 2, 3, 'A', 'B', 'C', 4, 5, 6, 'D', 'E', 'F'", + .expected = "1 2 3, A B C, 4 5 6, D E F", + }, + { + .name = "PercentSignEscapeSequenceSupport", + .format = "\u0025\u0025escaped \u0025s\u0025\u0025", + .format_args = "'percent'", + .expected = "%escaped percent%", + }, + { + .name = "FixedPointFormattingClause", + .format = "%.3f", + .format_args = "1.2345", + .expected = "1.234", + }, + { + .name = "BinaryFormattingClause", + .format = "this is 5 in binary: %b", + .format_args = "5", + .expected = "this is 5 in binary: 101", + }, + { + .name = "UintSupportForBinaryFormatting", + .format = "unsigned 64 in binary: %b", + .format_args = "uint(64)", + .expected = "unsigned 64 in binary: 1000000", + }, + { + .name = "BoolSupportForBinaryFormatting", + .format = "bit set from bool: %b", + .format_args = "true", + .expected = "bit set from bool: 1", + }, + { + .name = "OctalFormattingClause", + .format = "%o", + .format_args = "11", + .expected = "13", + }, + { + .name = "UintSupportForOctalFormattingClause", + .format = "this is an unsigned octal: %o", + .format_args = "uint(65535)", + .expected = "this is an unsigned octal: 177777", + }, + { + .name = "LowercaseHexadecimalFormattingClause", + .format = "%x is 20 in hexadecimal", + .format_args = "30", + .expected = "1e is 20 in hexadecimal", + }, + { + .name = "UppercaseHexadecimalFormattingClause", + .format = "%X is 20 in hexadecimal", + .format_args = "30", + .expected = "1E is 20 in hexadecimal", + }, + { + .name = "UnsignedSupportForHexadecimalFormattingClause", + .format = "%X is 6000 in hexadecimal", + .format_args = "uint(6000)", + .expected = "1770 is 6000 in hexadecimal", + }, + { + .name = "StringSupportWithHexadecimalFormattingClause", + .format = "%x", + .format_args = R"("Hello world!")", + .expected = "48656c6c6f20776f726c6421", + }, + { + .name = "StringSupportWithUppercaseHexadecimalFormattingClause", + .format = "%X", + .format_args = R"("Hello world!")", + .expected = "48656C6C6F20776F726C6421", + }, + { + .name = "ByteSupportWithHexadecimalFormattingClause", + .format = "%x", + .format_args = R"(b"byte string")", + .expected = "6279746520737472696e67", + }, + { + .name = "ByteSupportWithUppercaseHexadecimalFormattingClause", + .format = "%X", + .format_args = R"(b"byte string")", + .expected = "6279746520737472696E67", + }, + { + .name = "ScientificNotationFormattingClause", + .format = "%.6e", + .format_args = "1052.032911275", + .expected = "1.052033e+03", + }, + { + .name = "ScientificNotationFormattingClause2", + .format = "%e", + .format_args = "1234.0", + .expected = "1.234000e+03", + }, + { + .name = "DefaultPrecisionForFixedPointClause", + .format = "%f", + .format_args = "2.71828", + .expected = "2.718280", + }, + { + .name = "DefaultPrecisionForScientificNotation", + .format = "%e", + .format_args = "2.71828", + .expected = "2.718280e+00", + }, + { + .name = "FixedPointClauseWithInt", + .format = "%f", + .format_args = "3", + .expected = "3.000000", + }, + { + .name = "ScientificNotationWithUint", + .format = "%e", + .format_args = "uint(3)", + .expected = "3.000000e+00", + }, + { + .name = "NaNSupportForFixedPoint", + .format = "%f", + .format_args = "\"NaN\"", + .expected = "NaN", + }, + { + .name = "PositiveInfinitySupportForFixedPoint", + .format = "%f", + .format_args = "\"Infinity\"", + .expected = "Infinity", + }, + { + .name = "NegativeInfinitySupportForFixedPoint", + .format = "%f", + .format_args = "\"-Infinity\"", + .expected = "-Infinity", + }, + { + .name = "UintSupportForDecimalClause", + .format = "%d", + .format_args = "uint(64)", + .expected = "64", + }, + { + .name = "NullSupportForString", + .format = "null: %s", + .format_args = "null", + .expected = "null: null", + }, + { + .name = "IntSupportForString", + .format = "%s", + .format_args = "999999999999", + .expected = "999999999999", + }, + { + .name = "BytesSupportForString", + .format = "some bytes: %s", + .format_args = "b\"xyz\"", + .expected = "some bytes: xyz", + }, + { + .name = "TypeSupportForString", + .format = "type is %s", + .format_args = "type(\"test string\")", + .expected = "type is string", + }, + { + .name = "TimestampSupportForString", + .format = "%s", + .format_args = "timestamp(\"2023-02-03T23:31:20+00:00\")", + .expected = "2023-02-03T23:31:20Z", + }, + { + .name = "DurationSupportForString", + .format = "%s", + .format_args = "duration(\"1h45m47s\")", + .expected = "6347s", + }, + { + .name = "ListSupportForString", + .format = "%s", + .format_args = + R"(["abc", 3.14, null, [9, 8, 7, 6], timestamp("2023-02-03T23:31:20Z")])", + .expected = + R"([abc, 3.14, null, [9, 8, 7, 6], 2023-02-03T23:31:20Z])", + }, + { + .name = "MapSupportForString", + .format = "%s", + .format_args = + R"({"key1": b"xyz", "key5": null, "key2": duration("7200s"), "key4": true, "key3": 2.71828})", + .expected = + R"({key1: xyz, key2: 7200s, key3: 2.71828, key4: true, key5: null})", + }, + { + .name = "MapSupportAllKeyTypes", + .format = "map with multiple key types: %s", + .format_args = + R"({1: "value1", uint(2): "value2", true: double("NaN")})", + .expected = "map with multiple key types: {1: value1, 2: value2, " + "true: NaN}", + }, + { + .name = "MapAfterDecimalFormatting", + .format = "%d %s", + .format_args = R"(42, {"key": 1})", + .expected = "42 {key: 1}", + }, + { + .name = "BooleanSupportForString", + .format = "true bool: %s, false bool: %s", + .format_args = "true, false", + .expected = "true bool: true, false bool: false", + }, + FormattingTestCase{ + .name = "DynTypeSupportForStringFormattingClause", + .format = "Dynamic String: %s", + .format_args = R"(dynStr)", + .dyn_args = {{"dynStr", std::string("a string")}}, + .expected = "Dynamic String: a string", + }, + FormattingTestCase{ + .name = "DynTypeSupportForNumbersWithStringFormattingClause", + .format = "Dynamic Int Str: %s Dynamic Double Str: %s", + .format_args = R"(dynIntStr, dynDoubleStr)", + .dyn_args = + { + {"dynIntStr", 32}, + {"dynDoubleStr", 56.8}, + }, + .expected = "Dynamic Int Str: 32 Dynamic Double Str: 56.8", + }, + FormattingTestCase{ + .name = "DynTypeSupportForIntegerFormattingClause", + .format = "Dynamic Int: %d", + .format_args = R"(dynInt)", + .dyn_args = {{"dynInt", 128}}, + .expected = "Dynamic Int: 128", + }, + FormattingTestCase{ + .name = "DynTypeSupportForIntegerFormattingClauseUnsigned", + .format = "Dynamic Unsigned Int: %d", + .format_args = R"(dynUnsignedInt)", + .dyn_args = {{"dynUnsignedInt", uint64_t{256}}}, + .expected = "Dynamic Unsigned Int: 256", + }, + FormattingTestCase{ + .name = "DynTypeSupportForHexFormattingClause", + .format = "Dynamic Hex Int: %x", + .format_args = R"(dynHexInt)", + .dyn_args = {{"dynHexInt", 22}}, + .expected = "Dynamic Hex Int: 16", + }, + FormattingTestCase{ + .name = "DynTypeSupportForHexFormattingClauseUppercase", + .format = "Dynamic Hex Int: %X (uppercase)", + .format_args = R"(dynHexInt)", + .dyn_args = {{"dynHexInt", 26}}, + .expected = "Dynamic Hex Int: 1A (uppercase)", + }, + FormattingTestCase{ + .name = "DynTypeSupportForUnsignedHexFormattingClause", + .format = "Dynamic Hex Int: %x (unsigned)", + .format_args = R"(dynUnsignedHexInt)", + .dyn_args = {{"dynUnsignedHexInt", uint64_t{500}}}, + .expected = "Dynamic Hex Int: 1f4 (unsigned)", + }, + FormattingTestCase{ + .name = "DynTypeSupportForFixedPointFormattingClause", + .format = "Dynamic Double: %.3f", + .format_args = R"(dynDouble)", + .dyn_args = {{"dynDouble", 4.5}}, + .expected = "Dynamic Double: 4.500", + }, + FormattingTestCase{ + .name = "DynTypeSupportForFixedPointFormattingClauseCommaSeparatorL" + "ocale", + .format = "Dynamic Double: %f", + .format_args = R"(dynDouble)", + .dyn_args = {{"dynDouble", 4.5}}, + .expected = "Dynamic Double: 4.500000", + }, + FormattingTestCase{ + .name = "DynTypeSupportForScientificNotation", + .format = "(Dynamic Type) E: %e", + .format_args = R"(dynE)", + .dyn_args = {{"dynE", 2.71828}}, + .expected = "(Dynamic Type) E: 2.718280e+00", + }, + FormattingTestCase{ + .name = "DynTypeNaNInfinitySupportForFixedPoint", + .format = "NaN: %f, Infinity: %f", + .format_args = R"(dynNaN, dynInf)", + .dyn_args = {{"dynNaN", std::nan("")}, + {"dynInf", std::numeric_limits::infinity()}}, + .expected = "NaN: NaN, Infinity: Infinity", + }, + FormattingTestCase{ + .name = "DynTypeSupportForTimestamp", + .format = "Dynamic Type Timestamp: %s", + .format_args = R"(dynTime)", + .dyn_args = {{"dynTime", absl::FromUnixSeconds(1257894000)}}, + .expected = "Dynamic Type Timestamp: 2009-11-10T23:00:00Z", + }, + FormattingTestCase{ + .name = "DynTypeSupportForDuration", + .format = "Dynamic Type Duration: %s", + .format_args = R"(dynDuration)", + .dyn_args = {{"dynDuration", absl::Hours(2) + absl::Minutes(25) + + absl::Seconds(47)}}, + .expected = "Dynamic Type Duration: 8747s", + }, + FormattingTestCase{ + .name = "DynTypeSupportForMaps", + .format = "Dynamic Type Map with Duration: %s", + .format_args = R"({6:dyn(duration("422s"))})", + .expected = "Dynamic Type Map with Duration: {6: 422s}", + }, + FormattingTestCase{ + .name = "DurationsWithSubseconds", + .format = "Durations with subseconds: %s", + .format_args = + R"([duration("422s"), duration("2s123ms"), duration("1us"), duration("1ns"), duration("-1000000ns")])", + .expected = "Durations with subseconds: [422s, 2.123s, 0.000001s, " + "0.000000001s, -0.001s]", + }, + { + .name = "UnrecognizedFormattingClause", + .format = "%a", + .format_args = "1", + .error = "unrecognized formatting clause \"a\"", + }, + { + .name = "OutOfBoundsArgIndex", + .format = "%d %d %d", + .format_args = "0, 1", + .error = "index 2 out of range", + }, + { + .name = "StringSubstitutionIsNotAllowedWithBinaryClause", + .format = "string is %b", + .format_args = "\"abc\"", + .error = "binary clause can only be used on integers and bools, " + "was given string", + }, + { + .name = "DurationSubstitutionIsNotAllowedWithDecimalClause", + .format = "%d", + .format_args = "duration(\"30m2s\")", + .error = "decimal clause can only be used on numbers, was given " + "google.protobuf.Duration", + }, + { + .name = "StringSubstitutionIsNotAllowedWithOctalClause", + .format = "octal: %o", + .format_args = "\"a string\"", + .error = + "octal clause can only be used on integers, was given string", + }, + { + .name = "DoubleSubstitutionIsNotAllowedWithHexClause", + .format = "double is %x", + .format_args = "0.5", + .error = "hex clause can only be used on integers, byte buffers, " + "and strings, was given double", + }, + { + .name = "UppercaseIsNotAllowedForScientificClause", + .format = "double is %E", + .format_args = "0.5", + .error = "unrecognized formatting clause \"E\"", + }, + { + .name = "ObjectIsNotAllowed", + .format = "object is %s", + .format_args = "cel.expr.conformance.proto3.TestAllTypes{}", + .error = "could not convert argument " + "cel.expr.conformance.proto3.TestAllTypes to string", + }, + { + .name = "ObjectInsideList", + .format = "%s", + .format_args = "[1, 2, cel.expr.conformance.proto3.TestAllTypes{}]", + .error = "could not convert argument " + "cel.expr.conformance.proto3.TestAllTypes to string", + }, + { + .name = "ObjectInsideMap", + .format = "%s", + .format_args = + "{1: \"a\", 2: cel.expr.conformance.proto3.TestAllTypes{}}", + .error = "could not convert argument " + "cel.expr.conformance.proto3.TestAllTypes to string", + }, + { + .name = "NullNotAllowedForDecimalClause", + .format = "null: %d", + .format_args = "null", + .error = "decimal clause can only be used on numbers, was given " + "null_type", + }, + { + .name = "NullNotAllowedForScientificNotationClause", + .format = "null: %e", + .format_args = "null", + .error = "expected a double but got a null_type", + }, + { + .name = "NullNotAllowedForFixedPointClause", + .format = "null: %f", + .format_args = "null", + .error = "expected a double but got a null_type", + }, + { + .name = "NullNotAllowedForHexadecimalClause", + .format = "null: %x", + .format_args = "null", + .error = "hex clause can only be used on integers, byte buffers, " + "and strings, was given null_type", + }, + { + .name = "NullNotAllowedForUppercaseHexadecimalClause", + .format = "null: %X", + .format_args = "null", + .error = "hex clause can only be used on integers, byte buffers, " + "and strings, was given null_type", + }, + { + .name = "NullNotAllowedForBinaryClause", + .format = "null: %b", + .format_args = "null", + .error = "binary clause can only be used on integers and bools, " + "was given null_type", + }, + { + .name = "NullNotAllowedForOctalClause", + .format = "null: %o", + .format_args = "null", + .error = "octal clause can only be used on integers, was given " + "null_type", + }, + { + .name = "NegativeBinaryFormattingClause", + .format = "this is -5 in binary: %b", + .format_args = "-5", + .expected = "this is -5 in binary: -101", + }, + { + .name = "NegativeOctalFormattingClause", + .format = "%o", + .format_args = "-11", + .expected = "-13", + }, + { + .name = "NegativeHexadecimalFormattingClause", + .format = "%x is -30 in hexadecimal", + .format_args = "-30", + .expected = "-1e is -30 in hexadecimal", + }, + { + .name = "DefaultPrecisionForString", + .format = "%s", + .format_args = "2.71", + .expected = "2.71", + }, + { + .name = "DefaultListPrecisionForString", + .format = "%s", + .format_args = "[2.71]", + .expected = + "[2.71]", // Different from Golang (2.710000) consistent with + // the precision of a double outside of a list. + }, + { + .name = "AutomaticRoundingForString", + .format = "%s", + .format_args = "10002.71", + .expected = "10002.7", // Different from Golang (10002.71) which + // does not round. + }, + { + .name = "DefaultScientificNotationForString", + .format = "%s", + .format_args = "0.000000002", + .expected = "2e-09", + }, + { + .name = "DefaultListScientificNotationForString", + .format = "%s", + .format_args = "[0.000000002]", + .expected = + "[2e-09]", // Different from Golang (0.000000) consistent with + // the notation of a double outside of a list. + }, + { + .name = "NaNSupportForString", + .format = "%s", + .format_args = R"(double("NaN"))", + .expected = "NaN", + }, + { + .name = "PositiveInfinitySupportForString", + .format = "%s", + .format_args = R"(double("Inf"))", + .expected = "Infinity", + }, + { + .name = "NegativeInfinitySupportForString", + .format = "%s", + .format_args = R"(double("-Inf"))", + .expected = "-Infinity", + }, + { + .name = "InfinityListSupportForString", + .format = "%s", + .format_args = R"([double("NaN"), double("+Inf"), double("-Inf")])", + .expected = "[NaN, Infinity, -Infinity]", + }, + { + .name = "SmallDurationSupportForString", + .format = "%s", + .format_args = R"(duration("2ns"))", + .expected = "0.000000002s", + }, + }), + [](const testing::TestParamInfo& info) { + return info.param.name; + }); + +} // namespace +} // namespace cel::extensions diff --git a/extensions/lists_functions.cc b/extensions/lists_functions.cc new file mode 100644 index 000000000..bfe05d887 --- /dev/null +++ b/extensions/lists_functions.cc @@ -0,0 +1,702 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/lists_functions.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/base/macros.h" +#include "absl/base/no_destructor.h" +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "checker/internal/builtins_arena.h" +#include "checker/type_checker_builder.h" +#include "common/decl.h" +#include "common/expr.h" +#include "common/operators.h" +#include "common/type.h" +#include "common/value.h" +#include "common/value_kind.h" +#include "compiler/compiler.h" +#include "internal/status_macros.h" +#include "parser/macro.h" +#include "parser/macro_expr_factory.h" +#include "parser/macro_registry.h" +#include "parser/options.h" +#include "parser/parser_interface.h" +#include "runtime/function_adapter.h" +#include "runtime/function_registry.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel::extensions { +namespace { + +using ::cel::checker_internal::BuiltinsArena; + +absl::Span SortableTypes() { + static const Type kTypes[]{cel::IntType(), cel::UintType(), + cel::DoubleType(), cel::BoolType(), + cel::DurationType(), cel::TimestampType(), + cel::StringType(), cel::BytesType()}; + + return kTypes; +} + +// Slow distinct() implementation that uses Equal() to compare values in O(n^2). +absl::Status ListDistinctHeterogeneousImpl( + const ListValue& list, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, ListValueBuilder* absl_nonnull builder, + int64_t start_index = 0, std::vector seen = {}) { + CEL_ASSIGN_OR_RETURN(size_t size, list.Size()); + for (int64_t i = start_index; i < size; ++i) { + CEL_ASSIGN_OR_RETURN(Value value, + list.Get(i, descriptor_pool, message_factory, arena)); + bool is_distinct = true; + for (const Value& seen_value : seen) { + CEL_ASSIGN_OR_RETURN(Value equal, value.Equal(seen_value, descriptor_pool, + message_factory, arena)); + if (equal.IsTrue()) { + is_distinct = false; + break; + } + } + if (is_distinct) { + seen.push_back(value); + CEL_RETURN_IF_ERROR(builder->Add(value)); + } + } + return absl::OkStatus(); +} + +// Fast distinct() implementation for homogeneous hashable types. Falls back to +// the slow implementation if the list is not actually homogeneous. +template +absl::Status ListDistinctHomogeneousHashableImpl( + const ListValue& list, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, ListValueBuilder* absl_nonnull builder) { + absl::flat_hash_set seen; + CEL_ASSIGN_OR_RETURN(size_t size, list.Size()); + for (int64_t i = 0; i < size; ++i) { + CEL_ASSIGN_OR_RETURN(Value value, + list.Get(i, descriptor_pool, message_factory, arena)); + if (auto typed_value = value.As(); typed_value.has_value()) { + if (seen.contains(*typed_value)) { + continue; + } + seen.insert(*typed_value); + CEL_RETURN_IF_ERROR(builder->Add(value)); + } else { + // List is not homogeneous, fall back to the slow implementation. + // Keep the existing list builder, which already constructed the list of + // all the distinct values (that were homogeneous so far) up to index i. + // Pass the seen values as a vector to the slow implementation. + std::vector seen_values{seen.begin(), seen.end()}; + return ListDistinctHeterogeneousImpl(list, descriptor_pool, + message_factory, arena, builder, i, + std::move(seen_values)); + } + } + return absl::OkStatus(); +} + +absl::StatusOr ListDistinct( + const ListValue& list, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + CEL_ASSIGN_OR_RETURN(size_t size, list.Size()); + // If the list is empty or has a single element, we can return it as is. + if (size < 2) { + return list; + } + + // We need a set to keep track of the seen values. + // + // By default, for unhashable types, this set is implemented as a vector of + // all the seen values, which means that we will perform O(n^2) comparisons + // between the values. + // + // For efficiency purposes, if the first element of the list is hashable, we + // will use a specialized implementation that is faster for homogeneous lists + // of hashable types. + // If the list is not homogeneous, we will fall back to the slow + // implementation. + // + // The total runtime cost is O(n) for homogeneous lists of hashable types, and + // O(n^2) for all other cases. + auto builder = NewListValueBuilder(arena); + CEL_ASSIGN_OR_RETURN(Value first, + list.Get(0, descriptor_pool, message_factory, arena)); + switch (first.kind()) { + case ValueKind::kInt: { + CEL_RETURN_IF_ERROR(ListDistinctHomogeneousHashableImpl( + list, descriptor_pool, message_factory, arena, builder.get())); + break; + } + case ValueKind::kUint: { + CEL_RETURN_IF_ERROR(ListDistinctHomogeneousHashableImpl( + list, descriptor_pool, message_factory, arena, builder.get())); + break; + } + case ValueKind::kBool: { + CEL_RETURN_IF_ERROR(ListDistinctHomogeneousHashableImpl( + list, descriptor_pool, message_factory, arena, builder.get())); + break; + } + case ValueKind::kString: { + CEL_RETURN_IF_ERROR(ListDistinctHomogeneousHashableImpl( + list, descriptor_pool, message_factory, arena, builder.get())); + break; + } + default: { + CEL_RETURN_IF_ERROR(ListDistinctHeterogeneousImpl( + list, descriptor_pool, message_factory, arena, builder.get())); + break; + } + } + return std::move(*builder).Build(); +} + +absl::Status ListFlattenImpl( + const ListValue& list, int64_t remaining_depth, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, ListValueBuilder* absl_nonnull builder) { + CEL_ASSIGN_OR_RETURN(size_t size, list.Size()); + for (int64_t i = 0; i < size; ++i) { + CEL_ASSIGN_OR_RETURN(Value value, + list.Get(i, descriptor_pool, message_factory, arena)); + if (absl::optional list_value = value.AsList(); + list_value.has_value() && remaining_depth > 0) { + CEL_RETURN_IF_ERROR(ListFlattenImpl(*list_value, remaining_depth - 1, + descriptor_pool, message_factory, + arena, builder)); + } else { + CEL_RETURN_IF_ERROR(builder->Add(std::move(value))); + } + } + return absl::OkStatus(); +} + +absl::StatusOr ListFlatten( + const ListValue& list, int64_t depth, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + if (depth < 0) { + return ErrorValue( + absl::InvalidArgumentError("flatten(): level must be non-negative")); + } + auto builder = NewListValueBuilder(arena); + CEL_RETURN_IF_ERROR(ListFlattenImpl(list, depth, descriptor_pool, + message_factory, arena, builder.get())); + return std::move(*builder).Build(); +} + +absl::StatusOr ListRange( + int64_t end, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + auto builder = NewListValueBuilder(arena); + builder->Reserve(end); + for (int64_t i = 0; i < end; ++i) { + CEL_RETURN_IF_ERROR(builder->Add(IntValue(i))); + } + return std::move(*builder).Build(); +} + +absl::StatusOr ListReverse( + const ListValue& list, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + auto builder = NewListValueBuilder(arena); + CEL_ASSIGN_OR_RETURN(size_t size, list.Size()); + for (ptrdiff_t i = size - 1; i >= 0; --i) { + CEL_ASSIGN_OR_RETURN(Value value, + list.Get(i, descriptor_pool, message_factory, arena)); + CEL_RETURN_IF_ERROR(builder->Add(value)); + } + return std::move(*builder).Build(); +} + +absl::StatusOr ListSlice( + const ListValue& list, int64_t start, int64_t end, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + CEL_ASSIGN_OR_RETURN(size_t size, list.Size()); + if (start < 0 || end < 0) { + return ErrorValue(absl::InvalidArgumentError(absl::StrFormat( + "cannot slice(%d, %d), negative indexes not supported", start, end))); + } + if (start > end) { + return cel::ErrorValue(absl::InvalidArgumentError( + absl::StrFormat("cannot slice(%d, %d), start index must be less than " + "or equal to end index", + start, end))); + } + if (size < end) { + return cel::ErrorValue(absl::InvalidArgumentError(absl::StrFormat( + "cannot slice(%d, %d), list is length %d", start, end, size))); + } + auto builder = NewListValueBuilder(arena); + for (int64_t i = start; i < end; ++i) { + CEL_ASSIGN_OR_RETURN(Value val, + list.Get(i, descriptor_pool, message_factory, arena)); + CEL_RETURN_IF_ERROR(builder->Add(val)); + } + return std::move(*builder).Build(); +} + +template +absl::StatusOr ListSortByAssociatedKeysNative( + const ListValue& list, const ListValue& keys, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + CEL_ASSIGN_OR_RETURN(size_t size, list.Size()); + // If the list is empty or has a single element, we can return it as is. + if (size < 2) { + return list; + } + std::vector keys_vec; + absl::Status status = keys.ForEach( + [&keys_vec](const Value& value) -> absl::StatusOr { + if (auto typed_value = value.As(); typed_value.has_value()) { + keys_vec.push_back(*typed_value); + } else { + return absl::InvalidArgumentError( + "sort(): list elements must have the same type"); + } + return true; + }, + descriptor_pool, message_factory, arena); + if (!status.ok()) { + return ErrorValue(status); + } + ABSL_ASSERT(keys_vec.size() == size); // Already checked by the caller. + std::vector sorted_indices(keys_vec.size()); + std::iota(sorted_indices.begin(), sorted_indices.end(), 0); + std::sort( + sorted_indices.begin(), sorted_indices.end(), + [&](int64_t a, int64_t b) -> bool { return keys_vec[a] < keys_vec[b]; }); + + // Now sorted_indices contains the indices of the keys in sorted order. + // We can use it to build the sorted list. + auto builder = NewListValueBuilder(arena); + for (const auto& index : sorted_indices) { + CEL_ASSIGN_OR_RETURN( + Value value, list.Get(index, descriptor_pool, message_factory, arena)); + CEL_RETURN_IF_ERROR(builder->Add(value)); + } + return std::move(*builder).Build(); +} + +// Internal function used for the implementation of sort() and sortBy(). +// +// Sorts a list of arbitrary elements, according to the order produced by +// sorting another list of comparable elements. If the element type of the keys +// is not comparable or the element types are not the same, the function will +// produce an error. +// +// .@sortByAssociatedKeys() -> +// U in {int, uint, double, bool, duration, timestamp, string, bytes} +// +// Example: +// +// ["foo", "bar", "baz"].@sortByAssociatedKeys([3, 1, 2]) +// -> returns ["bar", "baz", "foo"] +absl::StatusOr ListSortByAssociatedKeys( + const ListValue& list, const ListValue& keys, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + CEL_ASSIGN_OR_RETURN(size_t list_size, list.Size()); + CEL_ASSIGN_OR_RETURN(size_t keys_size, keys.Size()); + if (list_size != keys_size) { + return ErrorValue(absl::InvalidArgumentError( + absl::StrFormat("@sortByAssociatedKeys() expected a list of the same " + "size as the associated keys list, but got %d and %d " + "elements respectively.", + list_size, keys_size))); + } + // Empty lists are already sorted. + // We don't check for size == 1 because the list could contain a single + // element of a type that is not supported by this function. + if (list_size == 0) { + return list; + } + CEL_ASSIGN_OR_RETURN(Value first, + keys.Get(0, descriptor_pool, message_factory, arena)); + switch (first.kind()) { + case ValueKind::kInt: + return ListSortByAssociatedKeysNative( + list, keys, descriptor_pool, message_factory, arena); + case ValueKind::kUint: + return ListSortByAssociatedKeysNative( + list, keys, descriptor_pool, message_factory, arena); + case ValueKind::kDouble: + return ListSortByAssociatedKeysNative( + list, keys, descriptor_pool, message_factory, arena); + case ValueKind::kBool: + return ListSortByAssociatedKeysNative( + list, keys, descriptor_pool, message_factory, arena); + case ValueKind::kString: + return ListSortByAssociatedKeysNative( + list, keys, descriptor_pool, message_factory, arena); + case ValueKind::kTimestamp: + return ListSortByAssociatedKeysNative( + list, keys, descriptor_pool, message_factory, arena); + case ValueKind::kDuration: + return ListSortByAssociatedKeysNative( + list, keys, descriptor_pool, message_factory, arena); + case ValueKind::kBytes: + return ListSortByAssociatedKeysNative( + list, keys, descriptor_pool, message_factory, arena); + default: + return ErrorValue(absl::InvalidArgumentError( + absl::StrFormat("sort(): unsupported type %s", first.GetTypeName()))); + } +} + +// Create an expression equivalent to: +// target.map(varIdent, mapExpr) +absl::optional MakeMapComprehension(MacroExprFactory& factory, + Expr target, Expr var_ident, + Expr map_expr) { + auto step = factory.NewCall( + google::api::expr::common::CelOperator::ADD, factory.NewAccuIdent(), + factory.NewList(factory.NewListElement(std::move(map_expr)))); + auto var_name = var_ident.ident_expr().name(); + return factory.NewComprehension(std::move(var_name), std::move(target), + factory.AccuVarName(), factory.NewList(), + factory.NewBoolConst(true), std::move(step), + factory.NewAccuIdent()); +} + +// Create an expression equivalent to: +// cel.bind(varIdent, varExpr, call_expr) +absl::optional MakeBindComprehension(MacroExprFactory& factory, + Expr var_ident, Expr var_expr, + Expr call_expr) { + auto var_name = var_ident.ident_expr().name(); + return factory.NewComprehension( + "#unused", factory.NewList(), std::move(var_name), std::move(var_expr), + factory.NewBoolConst(false), std::move(var_ident), std::move(call_expr)); +} + +// This macro transforms an expression like: +// +// mylistExpr.sortBy(e, -math.abs(e)) +// +// into something equivalent to: +// +// cel.bind( +// @__sortBy_input__, +// myListExpr, +// @__sortBy_input__.@sortByAssociatedKeys( +// @__sortBy_input__.map(e, -math.abs(e) +// ) +// ) +Macro ListSortByMacro() { + absl::StatusOr sortby_macro = Macro::Receiver( + "sortBy", 2, + [](MacroExprFactory& factory, Expr& target, + absl::Span args) -> absl::optional { + if (!target.has_ident_expr() && !target.has_select_expr() && + !target.has_list_expr() && !target.has_comprehension_expr() && + !target.has_call_expr()) { + return factory.ReportErrorAt( + target, + "sortBy can only be applied to a list, identifier, " + "comprehension, call or select expression"); + } + + auto sortby_input_ident = factory.NewIdent("@__sortBy_input__"); + auto sortby_input_expr = std::move(target); + auto key_ident = std::move(args[0]); + auto key_expr = std::move(args[1]); + + // Build the map expression: + // map_compr := @__sortBy_input__.map(key_ident, key_expr) + auto map_compr = + MakeMapComprehension(factory, factory.Copy(sortby_input_ident), + std::move(key_ident), std::move(key_expr)); + if (!map_compr.has_value()) { + return std::nullopt; + } + + // Build the call expression: + // call_expr := @__sortBy_input__.@sortByAssociatedKeys(map_compr) + std::vector call_args; + call_args.push_back(std::move(*map_compr)); + auto call_expr = factory.NewMemberCall("@sortByAssociatedKeys", + std::move(sortby_input_ident), + absl::MakeSpan(call_args)); + + // Build the returned bind expression: + // cel.bind(@__sortBy_input__, target, call_expr) + auto var_ident = factory.NewIdent("@__sortBy_input__"); + auto var_expr = std::move(sortby_input_expr); + auto bind_compr = + MakeBindComprehension(factory, std::move(var_ident), + std::move(var_expr), std::move(call_expr)); + return bind_compr; + }); + return *sortby_macro; +} + +absl::StatusOr ListSort( + const ListValue& list, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + return ListSortByAssociatedKeys(list, list, descriptor_pool, message_factory, + arena); +} + +absl::Status RegisterListDistinctFunction(FunctionRegistry& registry) { + return UnaryFunctionAdapter, const ListValue&>:: + RegisterMemberOverload("distinct", &ListDistinct, registry); +} + +absl::Status RegisterListFlattenFunction(FunctionRegistry& registry) { + CEL_RETURN_IF_ERROR( + (BinaryFunctionAdapter, const ListValue&, + int64_t>::RegisterMemberOverload("flatten", + &ListFlatten, + registry))); + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter, const ListValue&>:: + RegisterMemberOverload( + "flatten", + [](const ListValue& list, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + return ListFlatten(list, 1, descriptor_pool, message_factory, + arena); + }, + registry))); + return absl::OkStatus(); +} + +absl::Status RegisterListRangeFunction(FunctionRegistry& registry) { + return UnaryFunctionAdapter, + int64_t>::RegisterGlobalOverload("lists.range", + &ListRange, + registry); +} + +absl::Status RegisterListReverseFunction(FunctionRegistry& registry) { + return UnaryFunctionAdapter, const ListValue&>:: + RegisterMemberOverload("reverse", &ListReverse, registry); +} + +absl::Status RegisterListSliceFunction(FunctionRegistry& registry) { + return TernaryFunctionAdapter, const ListValue&, + int64_t, + int64_t>::RegisterMemberOverload("slice", + &ListSlice, + registry); +} + +absl::Status RegisterListSortFunction(FunctionRegistry& registry) { + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter, const ListValue&>:: + RegisterMemberOverload("sort", &ListSort, registry))); + CEL_RETURN_IF_ERROR( + (BinaryFunctionAdapter< + absl::StatusOr, const ListValue&, + const ListValue&>::RegisterMemberOverload("@sortByAssociatedKeys", + &ListSortByAssociatedKeys, + registry))); + return absl::OkStatus(); +} + +const Type& ListIntType() { + static absl::NoDestructor kInstance( + ListType(BuiltinsArena(), IntType())); + return *kInstance; +} + +const Type& ListTypeParamType() { + static absl::NoDestructor kInstance( + ListType(BuiltinsArena(), TypeParamType("T"))); + return *kInstance; +} + +absl::Status RegisterListsCheckerDecls(TypeCheckerBuilder& builder, + int version) { + CEL_ASSIGN_OR_RETURN( + FunctionDecl distinct_decl, + MakeFunctionDecl("distinct", MakeMemberOverloadDecl( + "list_distinct", ListTypeParamType(), + ListTypeParamType()))); + + CEL_ASSIGN_OR_RETURN( + FunctionDecl flatten_decl, + MakeFunctionDecl( + "flatten", + MakeMemberOverloadDecl("list_flatten_int", ListType(), ListType(), + IntType()), + MakeMemberOverloadDecl("list_flatten", ListType(), ListType()))); + + CEL_ASSIGN_OR_RETURN( + FunctionDecl range_decl, + MakeFunctionDecl( + "lists.range", + MakeOverloadDecl("list_range", ListIntType(), IntType()))); + + CEL_ASSIGN_OR_RETURN( + FunctionDecl reverse_decl, + MakeFunctionDecl( + "reverse", MakeMemberOverloadDecl("list_reverse", ListTypeParamType(), + ListTypeParamType()))); + + CEL_ASSIGN_OR_RETURN( + FunctionDecl slice_decl, + MakeFunctionDecl( + "slice", + MakeMemberOverloadDecl("list_slice", ListTypeParamType(), + ListTypeParamType(), IntType(), IntType()))); + + static const absl::NoDestructor> kSortableListTypes([] { + std::vector instance; + instance.reserve(SortableTypes().size()); + for (const Type& type : SortableTypes()) { + instance.push_back(ListType(BuiltinsArena(), type)); + } + return instance; + }()); + + FunctionDecl sort_decl; + sort_decl.set_name("sort"); + FunctionDecl sort_by_key_decl; + sort_by_key_decl.set_name("@sortByAssociatedKeys"); + + for (const Type& list_type : *kSortableListTypes) { + std::string elem_type_name(list_type.AsList()->GetElement().name()); + + CEL_RETURN_IF_ERROR(sort_decl.AddOverload(MakeMemberOverloadDecl( + absl::StrCat("list_", elem_type_name, "_sort"), list_type, list_type))); + CEL_RETURN_IF_ERROR(sort_by_key_decl.AddOverload(MakeMemberOverloadDecl( + absl::StrCat("list_", elem_type_name, "_sortByAssociatedKeys"), + ListTypeParamType(), ListTypeParamType(), list_type))); + } + + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(slice_decl))); + if (version == 0) { + return absl::OkStatus(); + } + + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(flatten_decl))); + if (version == 1) { + return absl::OkStatus(); + } + + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(sort_decl))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(sort_by_key_decl))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(distinct_decl))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(range_decl))); + // MergeFunction is used to combine with the reverse function + // defined in strings extension. + CEL_RETURN_IF_ERROR(builder.MergeFunction(std::move(reverse_decl))); + return absl::OkStatus(); +} + +std::vector lists_macros(int version) { + switch (version) { + case 0: + return {}; + case 1: + return {}; + case 2: + default: + return {ListSortByMacro()}; + }; +} + +absl::Status ConfigureParser(ParserBuilder& builder, int version) { + for (const Macro& macro : lists_macros(version)) { + CEL_RETURN_IF_ERROR(builder.AddMacro(macro)); + } + return absl::OkStatus(); +} + +} // namespace + +absl::Status RegisterListsFunctions(FunctionRegistry& registry, + const RuntimeOptions& options, + int version) { + CEL_RETURN_IF_ERROR(RegisterListSliceFunction(registry)); + if (version == 0) { + return absl::OkStatus(); + } + + // Since version 1 + CEL_RETURN_IF_ERROR(RegisterListFlattenFunction(registry)); + if (version == 1) { + return absl::OkStatus(); + } + + // Since version 2 + CEL_RETURN_IF_ERROR(RegisterListDistinctFunction(registry)); + CEL_RETURN_IF_ERROR(RegisterListRangeFunction(registry)); + CEL_RETURN_IF_ERROR(RegisterListReverseFunction(registry)); + CEL_RETURN_IF_ERROR(RegisterListSortFunction(registry)); + return absl::OkStatus(); +} + +absl::Status RegisterListsMacros(MacroRegistry& registry, const ParserOptions&, + int version) { + return registry.RegisterMacros(lists_macros(version)); +} + +CheckerLibrary ListsCheckerLibrary(int version) { + return {.id = "cel.lib.ext.lists", + .configure = [version](TypeCheckerBuilder& builder) { + return RegisterListsCheckerDecls(builder, version); + }}; +} + +CompilerLibrary ListsCompilerLibrary(int version) { + auto lib = CompilerLibrary::FromCheckerLibrary(ListsCheckerLibrary(version)); + lib.configure_parser = [version](ParserBuilder& builder) { + return ConfigureParser(builder, version); + }; + return lib; +} + +} // namespace cel::extensions diff --git a/extensions/lists_functions.h b/extensions/lists_functions.h new file mode 100644 index 000000000..0b057170f --- /dev/null +++ b/extensions/lists_functions.h @@ -0,0 +1,103 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_LISTS_FUNCTIONS_H_ +#define THIRD_PARTY_CEL_CPP_EXTENSIONS_LISTS_FUNCTIONS_H_ + +#include "absl/status/status.h" +#include "checker/type_checker_builder.h" +#include "compiler/compiler.h" +#include "parser/macro_registry.h" +#include "parser/options.h" +#include "runtime/function_registry.h" +#include "runtime/runtime_options.h" + +namespace cel::extensions { + +constexpr int kListsExtensionLatestVersion = 2; + +// Register implementations for list extension functions. +// +// === Since version 0 === +// .slice(start: int, end: int) -> list(T) +// +// === Since version 1 === +// .flatten() -> list(dyn) +// .flatten(limit: int) -> list(dyn) +// +// === Since version 2 === +// lists.range(n: int) -> list(int) +// +// .distinct() -> list(T) +// +// .reverse() -> list(T) +// +// .sort() -> list(T) +// +absl::Status RegisterListsFunctions(FunctionRegistry& registry, + const RuntimeOptions& options, + int version = kListsExtensionLatestVersion); + +// Register list macros. +// +// === Since version 2 === +// +// .sortBy(, ) +absl::Status RegisterListsMacros(MacroRegistry& registry, + const ParserOptions& options, + int version = kListsExtensionLatestVersion); + +// Type check declarations for the lists extension library. +// Provides decls for the following functions: +// +// === Since version 0 === +// .slice(start: int, end: int) -> list(T) +// +// === Since version 1 === +// .flatten() -> list(dyn) +// .flatten(limit: int) -> list(dyn) +// +// === Since version 2 === +// lists.range(n: int) -> list(int) +// +// .distinct() -> list(T) +// +// .reverse() -> list(T) +// +// .sort() -> list(T_) where T_ is partially orderable +CheckerLibrary ListsCheckerLibrary(int version = kListsExtensionLatestVersion); + +// Provides decls for the following functions: +// +// === Since version 0 === +// .slice(start: int, end: int) -> list(T) +// +// === Since version 1 === +// .flatten() -> list(dyn) +// .flatten(limit: int) -> list(dyn) +// +// === Since version 2 === +// lists.range(n: int) -> list(int) +// +// .distinct() -> list(T) +// +// .reverse() -> list(T) +// +// .sort() -> list(T_) where T_ is partially orderable +CompilerLibrary ListsCompilerLibrary( + int version = kListsExtensionLatestVersion); + +} // namespace cel::extensions + +#endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_SETS_FUNCTIONS_H_ diff --git a/extensions/lists_functions_test.cc b/extensions/lists_functions_test.cc new file mode 100644 index 000000000..8e9a3c3f5 --- /dev/null +++ b/extensions/lists_functions_test.cc @@ -0,0 +1,461 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/lists_functions.h" + +#include +#include +#include +#include + +#include "cel/expr/syntax.pb.h" +#include "absl/algorithm/container.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/strings/string_view.h" +#include "checker/type_check_issue.h" +#include "checker/validation_result.h" +#include "common/source.h" +#include "common/value.h" +#include "common/value_testing.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "compiler/standard_library.h" +#include "extensions/protobuf/runtime_adapter.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "parser/macro_registry.h" +#include "parser/options.h" +#include "parser/parser.h" +#include "parser/standard_macros.h" +#include "runtime/activation.h" +#include "runtime/reference_resolver.h" +#include "runtime/runtime.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_options.h" +#include "runtime/standard_runtime_builder_factory.h" +#include "google/protobuf/arena.h" + +namespace cel::extensions { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::cel::test::ErrorValueIs; +using ::cel::expr::Expr; +using ::cel::expr::ParsedExpr; +using ::cel::expr::SourceInfo; +using ::testing::Contains; +using ::testing::HasSubstr; +using ::testing::IsEmpty; +using ::testing::ValuesIn; + +struct TestInfo { + std::string expr; + std::string err = ""; +}; + +class ListsFunctionsTest : public testing::TestWithParam {}; + +TEST_P(ListsFunctionsTest, EndToEnd) { + const TestInfo& test_info = GetParam(); + RecordProperty("cel_expression", test_info.expr); + if (!test_info.err.empty()) { + RecordProperty("cel_expected_error", test_info.err); + } + + ASSERT_OK_AND_ASSIGN(auto source, cel::NewSource(test_info.expr, "")); + + MacroRegistry macro_registry; + ParserOptions parser_options{.add_macro_calls = true}; + ASSERT_THAT(RegisterStandardMacros(macro_registry, parser_options), IsOk()); + ASSERT_THAT(RegisterListsMacros(macro_registry, parser_options), IsOk()); + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, + google::api::expr::parser::Parse(*source, macro_registry, + parser_options)); + Expr expr = parsed_expr.expr(); + SourceInfo source_info = parsed_expr.source_info(); + + google::protobuf::Arena arena; + const auto options = RuntimeOptions{}; + ASSERT_OK_AND_ASSIGN(auto builder, + CreateStandardRuntimeBuilder( + internal::GetTestingDescriptorPool(), options)); + + // Needed to resolve namespaced functions when evaluating a ParsedExpr. + ASSERT_THAT(cel::EnableReferenceResolver( + builder, cel::ReferenceResolverEnabled::kAlways), + IsOk()); + EXPECT_THAT(RegisterListsFunctions(builder.function_registry(), options), + IsOk()); + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr program, + ProtobufRuntimeAdapter::CreateProgram(*runtime, expr)); + + Activation activation; + ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation)); + if (!test_info.err.empty()) { + EXPECT_THAT(result, + ErrorValueIs(StatusIs(testing::_, HasSubstr(test_info.err)))); + return; + } + ASSERT_TRUE(result.IsBool()) + << test_info.expr << " -> " << result.DebugString(); + EXPECT_TRUE(result.GetBool().NativeValue()) + << test_info.expr << " -> " << result.DebugString(); +} + +INSTANTIATE_TEST_SUITE_P( + ListsFunctionsTest, ListsFunctionsTest, + testing::ValuesIn({ + // lists.range() + {R"cel(lists.range(4) == [0,1,2,3])cel"}, + {R"cel(lists.range(0) == [])cel"}, + + // .reverse() + {R"cel([5,1,2,3].reverse() == [3,2,1,5])cel"}, + {R"cel([] == [])cel"}, + {R"cel([1] == [1])cel"}, + {R"cel( + ['are', 'you', 'as', 'bored', 'as', 'I', 'am'].reverse() + == ['am', 'I', 'as', 'bored', 'as', 'you', 'are'] + )cel"}, + {R"cel( + [false, true, true].reverse().reverse() == [false, true, true] + )cel"}, + + // .slice() + {R"cel([1,2,3,4].slice(0, 4) == [1,2,3,4])cel"}, + {R"cel([1,2,3,4].slice(0, 0) == [])cel"}, + {R"cel([1,2,3,4].slice(1, 1) == [])cel"}, + {R"cel([1,2,3,4].slice(4, 4) == [])cel"}, + {R"cel([1,2,3,4].slice(1, 3) == [2, 3])cel"}, + {R"cel([1,2,3,4].slice(3, 0))cel", + "cannot slice(3, 0), start index must be less than or equal to end " + "index"}, + {R"cel([1,2,3,4].slice(0, 10))cel", + "cannot slice(0, 10), list is length 4"}, + {R"cel([1,2,3,4].slice(-5, 10))cel", + "cannot slice(-5, 10), negative indexes not supported"}, + {R"cel([1,2,3,4].slice(-5, -3))cel", + "cannot slice(-5, -3), negative indexes not supported"}, + + // .flatten() + {R"cel(dyn([]).flatten() == [])cel"}, + {R"cel(dyn([1,2,3,4]).flatten() == [1,2,3,4])cel"}, + {R"cel([1,[2,[3,4]]].flatten() == [1,2,[3,4]])cel"}, + {R"cel([1,2,[],[],[3,4]].flatten() == [1,2,3,4])cel"}, + {R"cel([1,[2,[3,4]]].flatten(2) == [1,2,3,4])cel"}, + {R"cel([1,[2,[3,[4]]]].flatten(-1))cel", "level must be non-negative"}, + + // .sort() + {R"cel([].sort() == [])cel"}, + {R"cel([1].sort() == [1])cel"}, + {R"cel([4, 3, 2, 1].sort() == [1, 2, 3, 4])cel"}, + {R"cel(["d", "a", "b", "c"].sort() == ["a", "b", "c", "d"])cel"}, + {R"cel([b"d", b"a", b"aa"].sort() == [b"a", b"aa", b"d"])cel"}, + {R"cel( + [1.0, -1.5, 2.0, 1.0, -1.5, -1.5].sort() + == [-1.5, -1.5, -1.5, 1.0, 1.0, 2.0] + )cel"}, + {R"cel( + [42u, 3u, 1337u, 42u, 1337u, 3u, 42u].sort() + == [3u, 3u, 42u, 42u, 42u, 1337u, 1337u] + )cel"}, + {R"cel([false, true, false].sort() == [false, false, true])cel"}, + {R"cel( + [ + timestamp('2024-01-03T00:00:00Z'), + timestamp('2024-01-01T00:00:00Z'), + timestamp('2024-01-02T00:00:00Z'), + ].sort() == [ + timestamp('2024-01-01T00:00:00Z'), + timestamp('2024-01-02T00:00:00Z'), + timestamp('2024-01-03T00:00:00Z'), + ] + )cel"}, + {R"cel( + [duration('1m'), duration('2s'), duration('3h')].sort() + == [duration('2s'), duration('1m'), duration('3h')] + )cel"}, + {R"cel(["d", 3, 2, "c"].sort())cel", + "list elements must have the same type"}, + {R"cel([google.api.expr.runtime.TestMessage{}].sort())cel", + "unsupported type google.api.expr.runtime.TestMessage"}, + {R"cel([[1], [2]].sort())cel", "unsupported type list"}, + + // .sortBy() + {R"cel([].sortBy(e, e) == [])cel"}, + {R"cel(["a"].sortBy(e, e) == ["a"])cel"}, + {R"cel( + [-3, 1, -5, -2, 4].sortBy(e, -(e * e)) == [-5, 4, -3, -2, 1] + )cel"}, + {R"cel( + [-3, 1, -5, -2, 4].map(e, e * 2).sortBy(e, -(e * e)) + == [-10, 8, -6, -4, 2] + )cel"}, + {R"cel(lists.range(3).sortBy(e, -e) == [2, 1, 0])cel"}, + {R"cel( + ["a", "c", "b", "first"].sortBy(e, e == "first" ? "" : e) + == ["first", "a", "b", "c"] + )cel"}, + {R"cel( + [ + google.api.expr.runtime.TestMessage{string_value: 'foo'}, + google.api.expr.runtime.TestMessage{string_value: 'bar'}, + google.api.expr.runtime.TestMessage{string_value: 'baz'} + ].sortBy(e, e.string_value) == [ + google.api.expr.runtime.TestMessage{string_value: 'bar'}, + google.api.expr.runtime.TestMessage{string_value: 'baz'}, + google.api.expr.runtime.TestMessage{string_value: 'foo'} + ] + )cel"}, + {R"cel([[2], [1], [3]].sortBy(e, e[0]) == [[1], [2], [3]])cel"}, + {R"cel([[1], ["a"]].sortBy(e, e[0]))cel", + "list elements must have the same type"}, + {R"cel([[1], [2]].sortBy(e, e))cel", "unsupported type list"}, + {R"cel([google.api.expr.runtime.TestMessage{}].sortBy(e, e))cel", + "unsupported type google.api.expr.runtime.TestMessage"}, + + // .distinct() + {R"cel([].distinct() == [])cel"}, + {R"cel([1].distinct() == [1])cel"}, + {R"cel([-2, 5, -2, 1, 1, 5, -2, 1].distinct() == [-2, 5, 1])cel"}, + {R"cel( + [2u, 5u, 100u, 1u, 1u, 5u, 2u, 1u].distinct() == [2u, 5u, 100u, 1u] + )cel"}, + {R"cel([false, true, true, false].distinct() == [false, true])cel"}, + {R"cel( + ['c', 'a', 'a', 'b', 'a', 'b', 'c', 'c'].distinct() + == ['c', 'a', 'b'] + )cel"}, + {R"cel([1, 2.0, "c", 3, "c", 1].distinct() == [1, 2.0, "c", 3])cel"}, + {R"cel([1, 1.0, 2].distinct() == [1, 2])cel"}, + {R"cel([1, 1u].distinct() == [1])cel"}, + {R"cel([[1], [1], [2]].distinct() == [[1], [2]])cel"}, + {R"cel( + [ + google.api.expr.runtime.TestMessage{string_value: 'a'}, + google.api.expr.runtime.TestMessage{string_value: 'b'}, + google.api.expr.runtime.TestMessage{string_value: 'a'} + ].distinct() == [ + google.api.expr.runtime.TestMessage{string_value: 'a'}, + google.api.expr.runtime.TestMessage{string_value: 'b'} + ] + )cel"}, + {R"cel( + [ + google.api.expr.runtime.TestMessage{string_value: 'a'}, + 1, + 42.0, + [1, 2, 3], + false, + ].distinct() == [ + google.api.expr.runtime.TestMessage{string_value: 'a'}, + 1, + 42.0, + [1, 2, 3], + false, + ] + )cel"}, + })); + +TEST(ListsFunctionsTest, ListSortByMacroParseError) { + ASSERT_OK_AND_ASSIGN(auto source, + cel::NewSource("100.sortBy(e, e)", "")); + MacroRegistry macro_registry; + ParserOptions parser_options{.add_macro_calls = true}; + ASSERT_THAT(RegisterListsMacros(macro_registry, parser_options), IsOk()); + EXPECT_THAT( + google::api::expr::parser::Parse(*source, macro_registry, parser_options), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("sortBy can only be applied to"))); +} + +struct ListCheckerTestCase { + std::string expr; + std::string error_substr; +}; + +class ListsCheckerLibraryTest + : public ::testing::TestWithParam { + public: + void SetUp() override { + // Arrange: Configure the compiler. + // Add the lists checker library to the compiler builder. + ASSERT_OK_AND_ASSIGN( + std::unique_ptr compiler_builder, + NewCompilerBuilder(internal::GetTestingDescriptorPool())); + ASSERT_THAT(compiler_builder->AddLibrary(StandardCompilerLibrary()), + IsOk()); + ASSERT_THAT(compiler_builder->AddLibrary(ListsCompilerLibrary()), IsOk()); + compiler_builder->GetCheckerBuilder().set_container( + "cel.expr.conformance.proto3"); + ASSERT_OK_AND_ASSIGN(compiler_, std::move(*compiler_builder).Build()); + } + + std::unique_ptr compiler_; +}; + +TEST_P(ListsCheckerLibraryTest, ListsFunctionsTypeCheckerSuccess) { + // Act & Assert: Compile the expression and validate the result. + ASSERT_OK_AND_ASSIGN(ValidationResult result, + compiler_->Compile(GetParam().expr)); + absl::string_view error_substr = GetParam().error_substr; + EXPECT_EQ(result.IsValid(), error_substr.empty()); + + if (!error_substr.empty()) { + EXPECT_THAT(result.FormatError(), HasSubstr(error_substr)); + } +} + +// Returns a vector of test cases for the ListsCheckerLibraryTest. +// Returns both positive and negative test cases for the lists functions. +std::vector createListsCheckerParams() { + return { + // lists.distinct() + {R"([1,2,3,4,4].distinct() == [1,2,3,4])"}, + {R"('abc'.distinct() == [1,2,3,4])", + "no matching overload for 'distinct'"}, + {R"([1,2,3,4,4].distinct() == 'abc')", "no matching overload for '_==_'"}, + {R"([1,2,3,4,4].distinct(1) == [1,2,3,4])", "undeclared reference"}, + // lists.flatten() + {R"([1,2,3,4].flatten() == [1,2,3,4])"}, + {R"([1,2,3,4].flatten(1) == [1,2,3,4])"}, + {R"('abc'.flatten() == [1,2,3,4])", "no matching overload for 'flatten'"}, + {R"([1,2,3,4].flatten() == 'abc')", "no matching overload for '_==_'"}, + {R"('abc'.flatten(1) == [1,2,3,4])", + "no matching overload for 'flatten'"}, + {R"([1,2,3,4].flatten('abc') == [1,2,3,4])", + "no matching overload for 'flatten'"}, + {R"([1,2,3,4].flatten(1) == 'abc')", "no matching overload"}, + // lists.range() + {R"(lists.range(4) == [0,1,2,3])"}, + {R"(lists.range('abc') == [])", "no matching overload for 'lists.range'"}, + {R"(lists.range(4) == 'abc')", "no matching overload for '_==_'"}, + {R"(lists.range(4, 4) == [0,1,2,3])", "undeclared reference"}, + // lists.reverse() + {R"([1,2,3,4].reverse() == [4,3,2,1])"}, + {R"('abc'.reverse() == [])", "no matching overload for 'reverse'"}, + {R"([1,2,3,4].reverse() == 'abc')", "no matching overload for '_==_'"}, + {R"([1,2,3,4].reverse(1) == [4,3,2,1])", "undeclared reference"}, + // lists.slice() + {R"([1,2,3,4].slice(0, 4) == [1,2,3,4])"}, + {R"('abc'.slice(0, 4) == [1,2,3,4])", "no matching overload for 'slice'"}, + {R"([1,2,3,4].slice('abc', 4) == [1,2,3,4])", + "no matching overload for 'slice'"}, + {R"([1,2,3,4].slice(0, 'abc') == [1,2,3,4])", + "no matching overload for 'slice'"}, + {R"([1,2,3,4].slice(0, 4) == 'abc')", "no matching overload for '_==_'"}, + {R"([1,2,3,4].slice(0, 2, 3) == [1,2,3,4])", "undeclared reference"}, + // lists.sort() + {R"([1,2,3,4].sort() == [1,2,3,4])"}, + {R"([TestAllTypes{}, TestAllTypes{}].sort() == [])", + "no matching overload for 'sort'"}, + {R"('abc'.sort() == [])", "no matching overload for 'sort'"}, + {R"([1,2,3,4].sort() == 'abc')", "no matching overload for '_==_'"}, + {R"([1,2,3,4].sort(2) == [1,2,3,4])", "undeclared reference"}, + // sortBy macro + {R"([1,2,3,4].sortBy(x, -x) == [4,3,2,1])"}, + {R"([TestAllTypes{}, TestAllTypes{}].sortBy(x, x) == [])", + "no matching overload for '@sortByAssociatedKeys'"}, + {R"( + [TestAllTypes{single_int64: 2}, TestAllTypes{single_int64: 1}] + .sortBy(x, x.single_int64) == + [TestAllTypes{single_int64: 1}, TestAllTypes{single_int64: 2}])"}, + }; +} + +INSTANTIATE_TEST_SUITE_P(ListsCheckerLibraryTest, ListsCheckerLibraryTest, + ValuesIn(createListsCheckerParams())); + +struct ListsExtensionVersionTestCase { + std::string expr; + std::vector expected_supported_versions; +}; + +class ListsExtensionVersionTest + : public ::testing::TestWithParam {}; + +TEST_P(ListsExtensionVersionTest, ListsExtensionVersions) { + const ListsExtensionVersionTestCase& test_case = GetParam(); + for (int version = 0; + version <= cel::extensions::kListsExtensionLatestVersion; ++version) { + CompilerLibrary compiler_library = ListsCompilerLibrary(version); + + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + cel::NewCompilerBuilder(internal::GetTestingDescriptorPool(), + CompilerOptions())); + ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk()); + ASSERT_THAT(builder->AddLibrary(std::move(compiler_library)), IsOk()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, builder->Build()); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + compiler->Compile(test_case.expr)); + if (absl::c_contains(test_case.expected_supported_versions, version)) { + EXPECT_THAT(result.GetIssues(), IsEmpty()) + << "Expected no issues for expr: " << test_case.expr + << " at version: " << version << " but got: " << result.FormatError(); + } else { + EXPECT_THAT(result.GetIssues(), + Contains(Property(&TypeCheckIssue::message, + HasSubstr("undeclared reference")))); + } + } +}; + +std::vector CreateListsExtensionVersionParams() { + return { + ListsExtensionVersionTestCase{ + .expr = "[0,1,2,3].slice(0, 2)", + .expected_supported_versions = {0, 1, 2}, + }, + ListsExtensionVersionTestCase{ + .expr = "[[0]].flatten()", + .expected_supported_versions = {1, 2}, + }, + ListsExtensionVersionTestCase{ + .expr = "[[0]].flatten(1)", + .expected_supported_versions = {1, 2}, + }, + ListsExtensionVersionTestCase{ + .expr = "[1,2,3,4].sort()", + .expected_supported_versions = {2}, + }, + ListsExtensionVersionTestCase{ + .expr = "[1,2,3,4].sortBy(x, x)", + .expected_supported_versions = {2}, + }, + ListsExtensionVersionTestCase{ + .expr = "[1,2,3,4].distinct()", + .expected_supported_versions = {2}, + }, + ListsExtensionVersionTestCase{ + .expr = "lists.range(4)", + .expected_supported_versions = {2}, + }, + ListsExtensionVersionTestCase{ + .expr = "[1,2,3,4].reverse()", + .expected_supported_versions = {2}, + }, + }; +} + +INSTANTIATE_TEST_SUITE_P(ListsExtensionVersionTest, ListsExtensionVersionTest, + ValuesIn(CreateListsExtensionVersionParams())); + +} // namespace +} // namespace cel::extensions diff --git a/extensions/math_ext.cc b/extensions/math_ext.cc new file mode 100644 index 000000000..a7773da19 --- /dev/null +++ b/extensions/math_ext.cc @@ -0,0 +1,479 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/math_ext.h" + +#include +#include +#include + +#include "absl/base/casts.h" +#include "absl/base/nullability.h" +#include "absl/base/optimization.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "common/casting.h" +#include "common/value.h" +#include "eval/public/cel_function_registry.h" +#include "eval/public/cel_number.h" +#include "eval/public/cel_options.h" +#include "internal/status_macros.h" +#include "runtime/function_adapter.h" +#include "runtime/function_registry.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel::extensions { + +namespace { + +using ::google::api::expr::runtime::CelFunctionRegistry; +using ::google::api::expr::runtime::CelNumber; +using ::google::api::expr::runtime::InterpreterOptions; + +static constexpr char kMathMin[] = "math.@min"; +static constexpr char kMathMax[] = "math.@max"; + +struct ToValueVisitor { + Value operator()(uint64_t v) const { return UintValue{v}; } + Value operator()(int64_t v) const { return IntValue{v}; } + Value operator()(double v) const { return DoubleValue{v}; } +}; + +Value NumberToValue(CelNumber number) { + return number.visit(ToValueVisitor{}); +} + +absl::StatusOr ValueToNumber(const Value& value, + absl::string_view function) { + if (auto int_value = As(value); int_value) { + return CelNumber::FromInt64(int_value->NativeValue()); + } + if (auto uint_value = As(value); uint_value) { + return CelNumber::FromUint64(uint_value->NativeValue()); + } + if (auto double_value = As(value); double_value) { + return CelNumber::FromDouble(double_value->NativeValue()); + } + return absl::InvalidArgumentError( + absl::StrCat(function, " arguments must be numeric")); +} + +CelNumber MinNumber(CelNumber v1, CelNumber v2) { + if (v2 < v1) { + return v2; + } + return v1; +} + +Value MinValue(CelNumber v1, CelNumber v2) { + return NumberToValue(MinNumber(v1, v2)); +} + +template +Value Identity(T v1) { + return NumberToValue(CelNumber(v1)); +} + +template +Value Min(T v1, U v2) { + return MinValue(CelNumber(v1), CelNumber(v2)); +} + +absl::StatusOr MinList( + const ListValue& values, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + CEL_ASSIGN_OR_RETURN(auto iterator, values.NewIterator()); + if (!iterator->HasNext()) { + return ErrorValue( + absl::InvalidArgumentError("math.@min argument must not be empty")); + } + Value value; + CEL_RETURN_IF_ERROR( + iterator->Next(descriptor_pool, message_factory, arena, &value)); + absl::StatusOr current = ValueToNumber(value, kMathMin); + if (!current.ok()) { + return ErrorValue{current.status()}; + } + CelNumber min = *current; + while (iterator->HasNext()) { + CEL_RETURN_IF_ERROR( + iterator->Next(descriptor_pool, message_factory, arena, &value)); + absl::StatusOr other = ValueToNumber(value, kMathMin); + if (!other.ok()) { + return ErrorValue{other.status()}; + } + min = MinNumber(min, *other); + } + return NumberToValue(min); +} + +CelNumber MaxNumber(CelNumber v1, CelNumber v2) { + if (v2 > v1) { + return v2; + } + return v1; +} + +Value MaxValue(CelNumber v1, CelNumber v2) { + return NumberToValue(MaxNumber(v1, v2)); +} + +template +Value Max(T v1, U v2) { + return MaxValue(CelNumber(v1), CelNumber(v2)); +} + +absl::StatusOr MaxList( + const ListValue& values, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + CEL_ASSIGN_OR_RETURN(auto iterator, values.NewIterator()); + if (!iterator->HasNext()) { + return ErrorValue( + absl::InvalidArgumentError("math.@max argument must not be empty")); + } + Value value; + CEL_RETURN_IF_ERROR( + iterator->Next(descriptor_pool, message_factory, arena, &value)); + absl::StatusOr current = ValueToNumber(value, kMathMax); + if (!current.ok()) { + return ErrorValue{current.status()}; + } + CelNumber min = *current; + while (iterator->HasNext()) { + CEL_RETURN_IF_ERROR( + iterator->Next(descriptor_pool, message_factory, arena, &value)); + absl::StatusOr other = ValueToNumber(value, kMathMax); + if (!other.ok()) { + return ErrorValue{other.status()}; + } + min = MaxNumber(min, *other); + } + return NumberToValue(min); +} + +template +absl::Status RegisterCrossNumericMin(FunctionRegistry& registry) { + CEL_RETURN_IF_ERROR( + (BinaryFunctionAdapter::RegisterGlobalOverload( + kMathMin, Min, registry))); + + CEL_RETURN_IF_ERROR( + (BinaryFunctionAdapter::RegisterGlobalOverload( + kMathMin, Min, registry))); + + return absl::OkStatus(); +} + +template +absl::Status RegisterCrossNumericMax(FunctionRegistry& registry) { + CEL_RETURN_IF_ERROR( + (BinaryFunctionAdapter::RegisterGlobalOverload( + kMathMax, Max, registry))); + + CEL_RETURN_IF_ERROR( + (BinaryFunctionAdapter::RegisterGlobalOverload( + kMathMax, Max, registry))); + + return absl::OkStatus(); +} + +double CeilDouble(double value) { return std::ceil(value); } + +double FloorDouble(double value) { return std::floor(value); } + +double RoundDouble(double value) { return std::round(value); } + +double TruncDouble(double value) { return std::trunc(value); } + +double SqrtDouble(double value) { return std::sqrt(value); } + +double SqrtInt(int64_t value) { return std::sqrt(value); } + +double SqrtUint(uint64_t value) { return std::sqrt(value); } + +bool IsInfDouble(double value) { return std::isinf(value); } + +bool IsNaNDouble(double value) { return std::isnan(value); } + +bool IsFiniteDouble(double value) { return std::isfinite(value); } + +double AbsDouble(double value) { return std::fabs(value); } + +Value AbsInt(int64_t value) { + if (ABSL_PREDICT_FALSE(value == std::numeric_limits::min())) { + return ErrorValue(absl::InvalidArgumentError("integer overflow")); + } + return IntValue(value < 0 ? -value : value); +} + +uint64_t AbsUint(uint64_t value) { return value; } + +double SignDouble(double value) { + if (std::isnan(value)) { + return value; + } + if (value == 0.0) { + return 0.0; + } + return std::signbit(value) ? -1.0 : 1.0; +} + +int64_t SignInt(int64_t value) { return value < 0 ? -1 : value > 0 ? 1 : 0; } + +uint64_t SignUint(uint64_t value) { return value == 0 ? 0 : 1; } + +int64_t BitAndInt(int64_t lhs, int64_t rhs) { return lhs & rhs; } + +uint64_t BitAndUint(uint64_t lhs, uint64_t rhs) { return lhs & rhs; } + +int64_t BitOrInt(int64_t lhs, int64_t rhs) { return lhs | rhs; } + +uint64_t BitOrUint(uint64_t lhs, uint64_t rhs) { return lhs | rhs; } + +int64_t BitXorInt(int64_t lhs, int64_t rhs) { return lhs ^ rhs; } + +uint64_t BitXorUint(uint64_t lhs, uint64_t rhs) { return lhs ^ rhs; } + +int64_t BitNotInt(int64_t value) { return ~value; } + +uint64_t BitNotUint(uint64_t value) { return ~value; } + +Value BitShiftLeftInt(int64_t lhs, int64_t rhs) { + if (ABSL_PREDICT_FALSE(rhs < 0)) { + return ErrorValue(absl::InvalidArgumentError( + absl::StrCat("math.bitShiftLeft() invalid negative shift: ", rhs))); + } + if (rhs > 63) { + return IntValue(0); + } + // Shift in the unsigned domain to avoid undefined behaviour when lhs is + // negative or the shift moves bits into the sign bit, matching the bit + // pattern semantics already used by bitShiftRight. + return IntValue(absl::bit_cast(absl::bit_cast(lhs) + << static_cast(rhs))); +} + +Value BitShiftLeftUint(uint64_t lhs, int64_t rhs) { + if (ABSL_PREDICT_FALSE(rhs < 0)) { + return ErrorValue(absl::InvalidArgumentError( + absl::StrCat("math.bitShiftLeft() invalid negative shift: ", rhs))); + } + if (rhs > 63) { + return UintValue(0); + } + return UintValue(lhs << static_cast(rhs)); +} + +Value BitShiftRightInt(int64_t lhs, int64_t rhs) { + if (ABSL_PREDICT_FALSE(rhs < 0)) { + return ErrorValue(absl::InvalidArgumentError( + absl::StrCat("math.bitShiftRight() invalid negative shift: ", rhs))); + } + if (rhs > 63) { + return IntValue(0); + } + // We do not perform a sign extension shift, per the spec we just do the same + // thing as uint. + return IntValue(absl::bit_cast(absl::bit_cast(lhs) >> + static_cast(rhs))); +} + +Value BitShiftRightUint(uint64_t lhs, int64_t rhs) { + if (ABSL_PREDICT_FALSE(rhs < 0)) { + return ErrorValue(absl::InvalidArgumentError( + absl::StrCat("math.bitShiftRight() invalid negative shift: ", rhs))); + } + if (rhs > 63) { + return UintValue(0); + } + return UintValue(lhs >> static_cast(rhs)); +} + +} // namespace + +absl::Status RegisterMathExtensionFunctions(FunctionRegistry& registry, + const RuntimeOptions& options, + int version) { + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + kMathMin, Identity, registry))); + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + kMathMin, Identity, registry))); + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + kMathMin, Identity, registry))); + CEL_RETURN_IF_ERROR( + (BinaryFunctionAdapter::RegisterGlobalOverload( + kMathMin, Min, registry))); + CEL_RETURN_IF_ERROR( + (BinaryFunctionAdapter::RegisterGlobalOverload( + kMathMin, Min, registry))); + CEL_RETURN_IF_ERROR( + (BinaryFunctionAdapter::RegisterGlobalOverload( + kMathMin, Min, registry))); + CEL_RETURN_IF_ERROR((RegisterCrossNumericMin(registry))); + CEL_RETURN_IF_ERROR((RegisterCrossNumericMin(registry))); + CEL_RETURN_IF_ERROR((RegisterCrossNumericMin(registry))); + CEL_RETURN_IF_ERROR(( + UnaryFunctionAdapter, + ListValue>::RegisterGlobalOverload(kMathMin, MinList, + registry))); + + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + kMathMax, Identity, registry))); + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + kMathMax, Identity, registry))); + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + kMathMax, Identity, registry))); + CEL_RETURN_IF_ERROR( + (BinaryFunctionAdapter::RegisterGlobalOverload( + kMathMax, Max, registry))); + CEL_RETURN_IF_ERROR( + (BinaryFunctionAdapter::RegisterGlobalOverload( + kMathMax, Max, registry))); + CEL_RETURN_IF_ERROR( + (BinaryFunctionAdapter::RegisterGlobalOverload( + kMathMax, Max, registry))); + CEL_RETURN_IF_ERROR((RegisterCrossNumericMax(registry))); + CEL_RETURN_IF_ERROR((RegisterCrossNumericMax(registry))); + CEL_RETURN_IF_ERROR((RegisterCrossNumericMax(registry))); + CEL_RETURN_IF_ERROR(( + UnaryFunctionAdapter, + ListValue>::RegisterGlobalOverload(kMathMax, MaxList, + registry))); + if (version == 0) { + return absl::OkStatus(); + } + + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + "math.ceil", CeilDouble, registry))); + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + "math.floor", FloorDouble, registry))); + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + "math.round", RoundDouble, registry))); + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + "math.trunc", TruncDouble, registry))); + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + "math.isInf", IsInfDouble, registry))); + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + "math.isNaN", IsNaNDouble, registry))); + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + "math.isFinite", IsFiniteDouble, registry))); + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + "math.abs", AbsDouble, registry))); + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + "math.abs", AbsInt, registry))); + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + "math.abs", AbsUint, registry))); + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + "math.sign", SignDouble, registry))); + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + "math.sign", SignInt, registry))); + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + "math.sign", SignUint, registry))); + + CEL_RETURN_IF_ERROR( + (BinaryFunctionAdapter::RegisterGlobalOverload( + "math.bitAnd", BitAndInt, registry))); + CEL_RETURN_IF_ERROR( + (BinaryFunctionAdapter::RegisterGlobalOverload("math.bitAnd", + BitAndUint, + registry))); + CEL_RETURN_IF_ERROR( + (BinaryFunctionAdapter::RegisterGlobalOverload( + "math.bitOr", BitOrInt, registry))); + CEL_RETURN_IF_ERROR( + (BinaryFunctionAdapter::RegisterGlobalOverload("math.bitOr", + BitOrUint, + registry))); + CEL_RETURN_IF_ERROR( + (BinaryFunctionAdapter::RegisterGlobalOverload( + "math.bitXor", BitXorInt, registry))); + CEL_RETURN_IF_ERROR( + (BinaryFunctionAdapter::RegisterGlobalOverload("math.bitXor", + BitXorUint, + registry))); + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + "math.bitNot", BitNotInt, registry))); + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + "math.bitNot", BitNotUint, registry))); + CEL_RETURN_IF_ERROR( + (BinaryFunctionAdapter::RegisterGlobalOverload( + "math.bitShiftLeft", BitShiftLeftInt, registry))); + CEL_RETURN_IF_ERROR( + (BinaryFunctionAdapter::RegisterGlobalOverload( + "math.bitShiftLeft", BitShiftLeftUint, registry))); + CEL_RETURN_IF_ERROR( + (BinaryFunctionAdapter::RegisterGlobalOverload( + "math.bitShiftRight", BitShiftRightInt, registry))); + CEL_RETURN_IF_ERROR( + (BinaryFunctionAdapter::RegisterGlobalOverload( + "math.bitShiftRight", BitShiftRightUint, registry))); + + if (version == 1) { + return absl::OkStatus(); + } + + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + "math.sqrt", SqrtDouble, registry))); + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + "math.sqrt", SqrtInt, registry))); + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + "math.sqrt", SqrtUint, registry))); + + return absl::OkStatus(); +} + +absl::Status RegisterMathExtensionFunctions(CelFunctionRegistry* registry, + const InterpreterOptions& options) { + return RegisterMathExtensionFunctions( + registry->InternalGetRegistry(), + google::api::expr::runtime::ConvertToRuntimeOptions(options)); +} + +} // namespace cel::extensions diff --git a/extensions/math_ext.h b/extensions/math_ext.h new file mode 100644 index 000000000..fe000e476 --- /dev/null +++ b/extensions/math_ext.h @@ -0,0 +1,39 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_MATH_EXT_H_ +#define THIRD_PARTY_CEL_CPP_EXTENSIONS_MATH_EXT_H_ + +#include "absl/status/status.h" +#include "eval/public/cel_function_registry.h" +#include "eval/public/cel_options.h" +#include "extensions/math_ext_decls.h" +#include "runtime/function_registry.h" +#include "runtime/runtime_options.h" + +namespace cel::extensions { + +// Register extension functions for supporting mathematical operations above +// and beyond the set defined in the CEL standard environment. +absl::Status RegisterMathExtensionFunctions( + FunctionRegistry& registry, const RuntimeOptions& options, + int version = kMathExtensionLatestVersion); + +absl::Status RegisterMathExtensionFunctions( + google::api::expr::runtime::CelFunctionRegistry* registry, + const google::api::expr::runtime::InterpreterOptions& options); + +} // namespace cel::extensions + +#endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_MATH_EXT_H_ diff --git a/extensions/math_ext_decls.cc b/extensions/math_ext_decls.cc new file mode 100644 index 000000000..a7091cef6 --- /dev/null +++ b/extensions/math_ext_decls.cc @@ -0,0 +1,335 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/math_ext_decls.h" + +#include + +#include "absl/base/no_destructor.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "checker/internal/builtins_arena.h" +#include "checker/type_checker_builder.h" +#include "common/decl.h" +#include "common/type.h" +#include "common/type_kind.h" +#include "compiler/compiler.h" +#include "extensions/math_ext_macros.h" +#include "internal/status_macros.h" +#include "parser/parser_interface.h" + +namespace cel::extensions { +namespace { + +constexpr char kMathExtensionName[] = "cel.lib.ext.math"; + +const Type& ListIntType() { + static absl::NoDestructor kInstance( + ListType(checker_internal::BuiltinsArena(), IntType())); + return *kInstance; +} + +const Type& ListDoubleType() { + static absl::NoDestructor kInstance( + ListType(checker_internal::BuiltinsArena(), DoubleType())); + return *kInstance; +} + +const Type& ListUintType() { + static absl::NoDestructor kInstance( + ListType(checker_internal::BuiltinsArena(), UintType())); + return *kInstance; +} + +std::string OverloadTypeName(const Type& type) { + switch (type.kind()) { + case cel::TypeKind::kInt: + return "int"; + case TypeKind::kDouble: + return "double"; + case TypeKind::kUint: + return "uint"; + case TypeKind::kList: + return absl::StrCat("list_", + OverloadTypeName(type.AsList()->GetElement())); + default: + return "unsupported"; + } +} + +absl::Status AddMinMaxDecls(TypeCheckerBuilder& builder) { + const Type kNumerics[] = {IntType(), DoubleType(), UintType()}; + const Type kListNumerics[] = {ListIntType(), ListDoubleType(), + ListUintType()}; + + constexpr char kMinOverloadPrefix[] = "math_@min_"; + constexpr char kMaxOverloadPrefix[] = "math_@max_"; + + FunctionDecl min_decl; + min_decl.set_name("math.@min"); + + FunctionDecl max_decl; + max_decl.set_name("math.@max"); + + for (const Type& type : kNumerics) { + // Unary overloads + CEL_RETURN_IF_ERROR(min_decl.AddOverload(MakeOverloadDecl( + absl::StrCat(kMinOverloadPrefix, OverloadTypeName(type)), type, type))); + + CEL_RETURN_IF_ERROR(max_decl.AddOverload(MakeOverloadDecl( + absl::StrCat(kMaxOverloadPrefix, OverloadTypeName(type)), type, type))); + + // Pairwise overloads + for (const Type& other_type : kNumerics) { + Type out_type = DynType(); + if (type.kind() == other_type.kind()) { + out_type = type; + } + CEL_RETURN_IF_ERROR(min_decl.AddOverload(MakeOverloadDecl( + absl::StrCat(kMinOverloadPrefix, OverloadTypeName(type), "_", + OverloadTypeName(other_type)), + out_type, type, other_type))); + + CEL_RETURN_IF_ERROR(max_decl.AddOverload(MakeOverloadDecl( + absl::StrCat(kMaxOverloadPrefix, OverloadTypeName(type), "_", + OverloadTypeName(other_type)), + out_type, type, other_type))); + } + } + + // List overloads + for (const Type& type : kListNumerics) { + CEL_RETURN_IF_ERROR(min_decl.AddOverload(MakeOverloadDecl( + absl::StrCat(kMinOverloadPrefix, OverloadTypeName(type)), + type.AsList()->GetElement(), type))); + + CEL_RETURN_IF_ERROR(max_decl.AddOverload(MakeOverloadDecl( + absl::StrCat(kMaxOverloadPrefix, OverloadTypeName(type)), + type.AsList()->GetElement(), type))); + } + + CEL_RETURN_IF_ERROR(builder.AddFunction(min_decl)); + CEL_RETURN_IF_ERROR(builder.AddFunction(max_decl)); + + return absl::OkStatus(); +} + +absl::Status AddSignednessDecls(TypeCheckerBuilder& builder) { + const Type kNumerics[] = {IntType(), DoubleType(), UintType()}; + + FunctionDecl sign_decl; + sign_decl.set_name("math.sign"); + + FunctionDecl abs_decl; + abs_decl.set_name("math.abs"); + + for (const Type& type : kNumerics) { + CEL_RETURN_IF_ERROR(sign_decl.AddOverload(MakeOverloadDecl( + absl::StrCat("math_sign_", OverloadTypeName(type)), type, type))); + CEL_RETURN_IF_ERROR(abs_decl.AddOverload(MakeOverloadDecl( + absl::StrCat("math_abs_", OverloadTypeName(type)), type, type))); + } + + CEL_RETURN_IF_ERROR(builder.AddFunction(sign_decl)); + CEL_RETURN_IF_ERROR(builder.AddFunction(abs_decl)); + + return absl::OkStatus(); +} + +absl::Status AddSqrtDecls(TypeCheckerBuilder& builder) { + const Type kNumerics[] = {IntType(), DoubleType(), UintType()}; + + FunctionDecl sqrt_decl; + sqrt_decl.set_name("math.sqrt"); + + for (const Type& type : kNumerics) { + CEL_RETURN_IF_ERROR(sqrt_decl.AddOverload( + MakeOverloadDecl(absl::StrCat("math_sqrt_", OverloadTypeName(type)), + DoubleType(), type))); + } + + CEL_RETURN_IF_ERROR(builder.AddFunction(sqrt_decl)); + + return absl::OkStatus(); +} + +absl::Status AddFloatingPointDecls(TypeCheckerBuilder& builder) { + // Rounding + CEL_ASSIGN_OR_RETURN( + auto ceil_decl, + MakeFunctionDecl( + "math.ceil", + MakeOverloadDecl("math_ceil_double", DoubleType(), DoubleType()))); + + CEL_ASSIGN_OR_RETURN( + auto floor_decl, + MakeFunctionDecl( + "math.floor", + MakeOverloadDecl("math_floor_double", DoubleType(), DoubleType()))); + + CEL_ASSIGN_OR_RETURN( + auto round_decl, + MakeFunctionDecl( + "math.round", + MakeOverloadDecl("math_round_double", DoubleType(), DoubleType()))); + CEL_ASSIGN_OR_RETURN( + auto trunc_decl, + MakeFunctionDecl( + "math.trunc", + MakeOverloadDecl("math_trunc_double", DoubleType(), DoubleType()))); + + // FP helpers + CEL_ASSIGN_OR_RETURN( + auto is_inf_decl, + MakeFunctionDecl( + "math.isInf", + MakeOverloadDecl("math_isInf_double", BoolType(), DoubleType()))); + + CEL_ASSIGN_OR_RETURN( + auto is_nan_decl, + MakeFunctionDecl( + "math.isNaN", + MakeOverloadDecl("math_isNaN_double", BoolType(), DoubleType()))); + + CEL_ASSIGN_OR_RETURN( + auto is_finite_decl, + MakeFunctionDecl( + "math.isFinite", + MakeOverloadDecl("math_isFinite_double", BoolType(), DoubleType()))); + + CEL_RETURN_IF_ERROR(builder.AddFunction(ceil_decl)); + CEL_RETURN_IF_ERROR(builder.AddFunction(floor_decl)); + CEL_RETURN_IF_ERROR(builder.AddFunction(round_decl)); + CEL_RETURN_IF_ERROR(builder.AddFunction(trunc_decl)); + CEL_RETURN_IF_ERROR(builder.AddFunction(is_inf_decl)); + CEL_RETURN_IF_ERROR(builder.AddFunction(is_nan_decl)); + CEL_RETURN_IF_ERROR(builder.AddFunction(is_finite_decl)); + + return absl::OkStatus(); +} + +absl::Status AddBitwiseDecls(TypeCheckerBuilder& builder) { + const Type kBitwiseTypes[] = {IntType(), UintType()}; + + FunctionDecl bit_and_decl; + bit_and_decl.set_name("math.bitAnd"); + + FunctionDecl bit_or_decl; + bit_or_decl.set_name("math.bitOr"); + + FunctionDecl bit_xor_decl; + bit_xor_decl.set_name("math.bitXor"); + + FunctionDecl bit_not_decl; + bit_not_decl.set_name("math.bitNot"); + + FunctionDecl bit_lshift_decl; + bit_lshift_decl.set_name("math.bitShiftLeft"); + + FunctionDecl bit_rshift_decl; + bit_rshift_decl.set_name("math.bitShiftRight"); + + for (const Type& type : kBitwiseTypes) { + CEL_RETURN_IF_ERROR(bit_and_decl.AddOverload( + MakeOverloadDecl(absl::StrCat("math_bitAnd_", OverloadTypeName(type), + "_", OverloadTypeName(type)), + type, type, type))); + + CEL_RETURN_IF_ERROR(bit_or_decl.AddOverload( + MakeOverloadDecl(absl::StrCat("math_bitOr_", OverloadTypeName(type), + "_", OverloadTypeName(type)), + type, type, type))); + + CEL_RETURN_IF_ERROR(bit_xor_decl.AddOverload( + MakeOverloadDecl(absl::StrCat("math_bitXor_", OverloadTypeName(type), + "_", OverloadTypeName(type)), + type, type, type))); + + CEL_RETURN_IF_ERROR(bit_not_decl.AddOverload( + MakeOverloadDecl(absl::StrCat("math_bitNot_", OverloadTypeName(type), + "_", OverloadTypeName(type)), + type, type))); + + CEL_RETURN_IF_ERROR(bit_lshift_decl.AddOverload(MakeOverloadDecl( + absl::StrCat("math_bitShiftLeft_", OverloadTypeName(type), "_int"), + type, type, IntType()))); + + CEL_RETURN_IF_ERROR(bit_rshift_decl.AddOverload(MakeOverloadDecl( + absl::StrCat("math_bitShiftRight_", OverloadTypeName(type), "_int"), + type, type, IntType()))); + } + + CEL_RETURN_IF_ERROR(builder.AddFunction(bit_and_decl)); + CEL_RETURN_IF_ERROR(builder.AddFunction(bit_or_decl)); + CEL_RETURN_IF_ERROR(builder.AddFunction(bit_xor_decl)); + CEL_RETURN_IF_ERROR(builder.AddFunction(bit_not_decl)); + CEL_RETURN_IF_ERROR(builder.AddFunction(bit_lshift_decl)); + CEL_RETURN_IF_ERROR(builder.AddFunction(bit_rshift_decl)); + + return absl::OkStatus(); +} + +absl::Status AddMathExtensionDeclarations(TypeCheckerBuilder& builder, + int version) { + CEL_RETURN_IF_ERROR(AddMinMaxDecls(builder)); + if (version == 0) { + return absl::OkStatus(); + } + + CEL_RETURN_IF_ERROR(AddSignednessDecls(builder)); + CEL_RETURN_IF_ERROR(AddFloatingPointDecls(builder)); + CEL_RETURN_IF_ERROR(AddBitwiseDecls(builder)); + if (version == 1) { + return absl::OkStatus(); + } + CEL_RETURN_IF_ERROR(AddSqrtDecls(builder)); + + return absl::OkStatus(); +} + +absl::Status AddMathExtensionMacros(ParserBuilder& builder, int version) { + for (const auto& m : math_macros()) { + // At the moment, all macros are supported in all versions. When we add a + // new macro, we must add a version check here. + CEL_RETURN_IF_ERROR(builder.AddMacro(m)); + } + return absl::OkStatus(); +} + +} // namespace + +// Configuration for cel::Compiler to enable the math extension declarations. +CompilerLibrary MathCompilerLibrary(int version) { + return CompilerLibrary( + kMathExtensionName, + [version](ParserBuilder& builder) { + return AddMathExtensionMacros(builder, version); + }, + [version](TypeCheckerBuilder& builder) { + return AddMathExtensionDeclarations(builder, version); + }); +} + +// Configuration for cel::TypeChecker to enable the math extension declarations. +CheckerLibrary MathCheckerLibrary(int version) { + return { + .id = kMathExtensionName, + .configure = + [version](TypeCheckerBuilder& builder) { + return AddMathExtensionDeclarations(builder, version); + }, + }; +} + +} // namespace cel::extensions diff --git a/extensions/math_ext_decls.h b/extensions/math_ext_decls.h new file mode 100644 index 000000000..624649a39 --- /dev/null +++ b/extensions/math_ext_decls.h @@ -0,0 +1,33 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_MATH_EXT_DECLS_H_ +#define THIRD_PARTY_CEL_CPP_EXTENSIONS_MATH_EXT_DECLS_H_ + +#include "checker/type_checker_builder.h" +#include "compiler/compiler.h" + +namespace cel::extensions { + +constexpr int kMathExtensionLatestVersion = 2; + +// Configuration for cel::Compiler to enable the math extension declarations. +CompilerLibrary MathCompilerLibrary(int version = kMathExtensionLatestVersion); + +// Configuration for cel::TypeChecker to enable the math extension declarations. +CheckerLibrary MathCheckerLibrary(int version = kMathExtensionLatestVersion); + +} // namespace cel::extensions + +#endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_MATH_EXT_DECLS_H_ diff --git a/extensions/math_ext_macros.cc b/extensions/math_ext_macros.cc new file mode 100644 index 000000000..08b163132 --- /dev/null +++ b/extensions/math_ext_macros.cc @@ -0,0 +1,192 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/math_ext_macros.h" + +#include +#include + +#include "absl/functional/overload.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "absl/types/variant.h" +#include "common/ast.h" +#include "common/constant.h" +#include "parser/macro.h" +#include "parser/macro_expr_factory.h" + +namespace cel::extensions { + +namespace { + +static constexpr absl::string_view kMathNamespace = "math"; +static constexpr absl::string_view kLeast = "least"; +static constexpr absl::string_view kGreatest = "greatest"; + +static constexpr char kMathMin[] = "math.@min"; +static constexpr char kMathMax[] = "math.@max"; + +bool IsTargetNamespace(const Expr &target) { + return target.has_ident_expr() && + target.ident_expr().name() == kMathNamespace; +} + +bool IsValidArgType(const Expr &arg) { + return absl::visit( + absl::Overload([](const UnspecifiedExpr &) -> bool { return false; }, + [](const Constant &const_expr) -> bool { + return const_expr.has_double_value() || + const_expr.has_int_value() || + const_expr.has_uint_value(); + }, + [](const ListExpr &) -> bool { return false; }, + [](const StructExpr &) -> bool { return false; }, + [](const MapExpr &) -> bool { return false; }, + // This is intended for call and select expressions. + [](const auto &) -> bool { return true; }), + arg.kind()); +} + +absl::optional CheckInvalidArgs(MacroExprFactory &factory, + absl::string_view macro, + absl::Span arguments) { + for (const auto &argument : arguments) { + if (!IsValidArgType(argument)) { + return factory.ReportErrorAt( + argument, + absl::StrCat(macro, " simple literal arguments must be numeric")); + } + } + + return std::nullopt; +} + +bool IsListLiteralWithValidArgs(const Expr &arg) { + if (const auto *list_expr = arg.has_list_expr() ? &arg.list_expr() : nullptr; + list_expr) { + if (list_expr->elements().empty()) { + return false; + } + for (const auto &element : list_expr->elements()) { + if (!IsValidArgType(element.expr())) { + return false; + } + } + return true; + } + return false; +} + +} // namespace + +std::vector math_macros() { + absl::StatusOr least = Macro::ReceiverVarArg( + kLeast, + [](MacroExprFactory &factory, Expr &target, + absl::Span arguments) -> absl::optional { + if (!IsTargetNamespace(target)) { + return std::nullopt; + } + + switch (arguments.size()) { + case 0: + return factory.ReportErrorAt( + target, "math.least() requires at least one argument."); + case 1: { + if (!IsListLiteralWithValidArgs(arguments[0]) && + !IsValidArgType(arguments[0])) { + return factory.ReportErrorAt( + arguments[0], "math.least() invalid single argument value."); + } + + return factory.NewCall(kMathMin, arguments); + } + case 2: { + if (auto error = + CheckInvalidArgs(factory, "math.least()", arguments); + error) { + return std::move(*error); + } + return factory.NewCall(kMathMin, arguments); + } + default: + if (auto error = + CheckInvalidArgs(factory, "math.least()", arguments); + error) { + return std::move(*error); + } + std::vector elements; + elements.reserve(arguments.size()); + for (auto &argument : arguments) { + elements.push_back(factory.NewListElement(std::move(argument))); + } + return factory.NewCall(kMathMin, + factory.NewList(std::move(elements))); + } + }); + absl::StatusOr greatest = Macro::ReceiverVarArg( + kGreatest, + [](MacroExprFactory &factory, Expr &target, + absl::Span arguments) -> absl::optional { + if (!IsTargetNamespace(target)) { + return std::nullopt; + } + + switch (arguments.size()) { + case 0: { + return factory.ReportErrorAt( + target, "math.greatest() requires at least one argument."); + } + case 1: { + if (!IsListLiteralWithValidArgs(arguments[0]) && + !IsValidArgType(arguments[0])) { + return factory.ReportErrorAt( + arguments[0], + "math.greatest() invalid single argument value."); + } + + return factory.NewCall(kMathMax, arguments); + } + case 2: { + if (auto error = + CheckInvalidArgs(factory, "math.greatest()", arguments); + error) { + return std::move(*error); + } + return factory.NewCall(kMathMax, arguments); + } + default: { + if (auto error = + CheckInvalidArgs(factory, "math.greatest()", arguments); + error) { + return std::move(*error); + } + std::vector elements; + elements.reserve(arguments.size()); + for (auto &argument : arguments) { + elements.push_back(factory.NewListElement(std::move(argument))); + } + return factory.NewCall(kMathMax, + factory.NewList(std::move(elements))); + } + } + }); + + return {*least, *greatest}; +} + +} // namespace cel::extensions diff --git a/extensions/math_ext_macros.h b/extensions/math_ext_macros.h new file mode 100644 index 000000000..0c482e49f --- /dev/null +++ b/extensions/math_ext_macros.h @@ -0,0 +1,38 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_MATH_EXT_MACROS_H_ +#define THIRD_PARTY_CEL_CPP_EXTENSIONS_MATH_EXT_MACROS_H_ + +#include + +#include "absl/status/status.h" +#include "parser/macro.h" +#include "parser/macro_registry.h" +#include "parser/options.h" + +namespace cel::extensions { + +// math_macros() returns the namespaced helper macros for math.least() and +// math.greatest(). +std::vector math_macros(); + +inline absl::Status RegisterMathMacros(MacroRegistry& registry, + const ParserOptions&) { + return registry.RegisterMacros(math_macros()); +} + +} // namespace cel::extensions + +#endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_MATH_EXT_MACROS_H_ diff --git a/extensions/math_ext_test.cc b/extensions/math_ext_test.cc new file mode 100644 index 000000000..ce05ae6ed --- /dev/null +++ b/extensions/math_ext_test.cc @@ -0,0 +1,691 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/math_ext.h" + +#include +#include +#include +#include + +#include "cel/expr/syntax.pb.h" +#include "absl/algorithm/container.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "checker/standard_library.h" +#include "checker/type_check_issue.h" +#include "checker/validation_result.h" +#include "common/decl.h" +#include "common/function_descriptor.h" +#include "common/type.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "compiler/standard_library.h" +#include "eval/public/activation.h" +#include "eval/public/builtin_func_registrar.h" +#include "eval/public/cel_expr_builder_factory.h" +#include "eval/public/cel_expression.h" +#include "eval/public/cel_function.h" +#include "eval/public/cel_options.h" +#include "eval/public/cel_value.h" +#include "eval/public/containers/container_backed_list_impl.h" +#include "eval/public/testing/matchers.h" +#include "extensions/math_ext_decls.h" +#include "extensions/math_ext_macros.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "parser/parser.h" +#include "runtime/activation.h" +#include "runtime/runtime_options.h" +#include "runtime/standard_runtime_builder_factory.h" +#include "google/protobuf/arena.h" + +namespace cel::extensions { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::cel::expr::Expr; +using ::cel::expr::ParsedExpr; +using ::cel::expr::SourceInfo; +using ::google::api::expr::parser::ParseWithMacros; +using ::google::api::expr::runtime::Activation; +using ::google::api::expr::runtime::CelExpressionBuilder; +using ::google::api::expr::runtime::CelFunction; +using ::google::api::expr::runtime::CelFunctionDescriptor; +using ::google::api::expr::runtime::CelValue; +using ::google::api::expr::runtime::ContainerBackedListImpl; +using ::google::api::expr::runtime::CreateCelExpressionBuilder; +using ::google::api::expr::runtime::InterpreterOptions; +using ::google::api::expr::runtime::RegisterBuiltinFunctions; +using ::google::api::expr::runtime::test::EqualsCelValue; +using ::google::protobuf::Arena; +using ::testing::HasSubstr; +using ::testing::IsEmpty; +using ::testing::ValuesIn; + +constexpr absl::string_view kMathMin = "math.@min"; +constexpr absl::string_view kMathMax = "math.@max"; + +struct TestCase { + absl::string_view operation; + CelValue arg1; + absl::optional arg2; + CelValue result; +}; + +TestCase MinCase(CelValue v1, CelValue v2, CelValue result) { + return TestCase{kMathMin, v1, v2, result}; +} + +TestCase MinCase(CelValue list, CelValue result) { + return TestCase{kMathMin, list, std::nullopt, result}; +} + +TestCase MaxCase(CelValue v1, CelValue v2, CelValue result) { + return TestCase{kMathMax, v1, v2, result}; +} + +TestCase MaxCase(CelValue list, CelValue result) { + return TestCase{kMathMax, list, std::nullopt, result}; +} + +struct MacroTestCase { + absl::string_view expr; + absl::string_view err = ""; +}; + +class TestFunction : public CelFunction { + public: + explicit TestFunction(absl::string_view name) + : CelFunction(MakeDescriptor(name)) {} + + static FunctionDescriptor MakeDescriptor(absl::string_view name) { + return FunctionDescriptor(name, true, + {CelValue::Type::kBool, CelValue::Type::kInt64, + CelValue::Type::kInt64}); + } + + absl::Status Evaluate(absl::Span args, CelValue* result, + Arena* arena) const override { + *result = CelValue::CreateBool(true); + return absl::OkStatus(); + } +}; + +// Test function used to test macro collision and non-expansion. +constexpr absl::string_view kGreatest = "greatest"; +std::unique_ptr CreateGreatestFunction() { + return std::make_unique(kGreatest); +} + +constexpr absl::string_view kLeast = "least"; +std::unique_ptr CreateLeastFunction() { + return std::make_unique(kLeast); +} + +Expr CallExprOneArg(absl::string_view operation) { + Expr expr; + auto call = expr.mutable_call_expr(); + call->set_function(operation); + + auto arg = call->add_args(); + auto ident = arg->mutable_ident_expr(); + ident->set_name("a"); + return expr; +} + +Expr CallExprTwoArgs(absl::string_view operation) { + Expr expr; + auto call = expr.mutable_call_expr(); + call->set_function(operation); + + auto arg = call->add_args(); + auto ident = arg->mutable_ident_expr(); + ident->set_name("a"); + + arg = call->add_args(); + ident = arg->mutable_ident_expr(); + ident->set_name("b"); + return expr; +} + +void ExpectResult(const TestCase& test_case) { + Expr expr; + Activation activation; + activation.InsertValue("a", test_case.arg1); + if (test_case.arg2.has_value()) { + activation.InsertValue("b", *test_case.arg2); + expr = CallExprTwoArgs(test_case.operation); + } else { + expr = CallExprOneArg(test_case.operation); + } + + SourceInfo source_info; + InterpreterOptions options; + std::unique_ptr builder = + CreateCelExpressionBuilder(options); + ASSERT_OK(RegisterMathExtensionFunctions(builder->GetRegistry(), options)); + ASSERT_OK_AND_ASSIGN(auto cel_expression, + builder->CreateExpression(&expr, &source_info)); + + google::protobuf::Arena arena; + ASSERT_OK_AND_ASSIGN(auto value, + cel_expression->Evaluate(activation, &arena)); + if (!test_case.result.IsError()) { + EXPECT_THAT(value, EqualsCelValue(test_case.result)); + } else { + auto expected = test_case.result.ErrorOrDie(); + EXPECT_THAT(*value.ErrorOrDie(), + StatusIs(expected->code(), HasSubstr(expected->message()))); + } +} + +using MathExtParamsTest = testing::TestWithParam; +TEST_P(MathExtParamsTest, MinMaxTests) { ExpectResult(GetParam()); } + +INSTANTIATE_TEST_SUITE_P( + MathExtParamsTest, MathExtParamsTest, + testing::ValuesIn({ + MinCase(CelValue::CreateInt64(3L), CelValue::CreateInt64(2L), + CelValue::CreateInt64(2L)), + MinCase(CelValue::CreateInt64(-1L), CelValue::CreateUint64(2u), + CelValue::CreateInt64(-1L)), + MinCase(CelValue::CreateInt64(-1L), CelValue::CreateDouble(-1.1), + CelValue::CreateDouble(-1.1)), + MinCase(CelValue::CreateDouble(-2.0), CelValue::CreateDouble(-1.1), + CelValue::CreateDouble(-2.0)), + MinCase(CelValue::CreateDouble(3.1), CelValue::CreateInt64(2), + CelValue::CreateInt64(2)), + MinCase(CelValue::CreateDouble(2.5), CelValue::CreateUint64(2u), + CelValue::CreateUint64(2u)), + MinCase(CelValue::CreateUint64(2u), CelValue::CreateDouble(-1.1), + CelValue::CreateDouble(-1.1)), + MinCase(CelValue::CreateUint64(3u), CelValue::CreateInt64(20), + CelValue::CreateUint64(3u)), + MinCase(CelValue::CreateUint64(4u), CelValue::CreateUint64(2u), + CelValue::CreateUint64(2u)), + MinCase(CelValue::CreateInt64(2L), CelValue::CreateUint64(2u), + CelValue::CreateInt64(2L)), + MinCase(CelValue::CreateInt64(-1L), CelValue::CreateDouble(-1.0), + CelValue::CreateInt64(-1L)), + MinCase(CelValue::CreateDouble(2.0), CelValue::CreateInt64(2), + CelValue::CreateDouble(2.0)), + MinCase(CelValue::CreateDouble(2.0), CelValue::CreateUint64(2u), + CelValue::CreateDouble(2.0)), + MinCase(CelValue::CreateUint64(2u), CelValue::CreateDouble(2.0), + CelValue::CreateUint64(2u)), + MinCase(CelValue::CreateUint64(3u), CelValue::CreateInt64(3), + CelValue::CreateUint64(3u)), + + MaxCase(CelValue::CreateInt64(3L), CelValue::CreateInt64(2L), + CelValue::CreateInt64(3L)), + MaxCase(CelValue::CreateInt64(-1L), CelValue::CreateUint64(2u), + CelValue::CreateUint64(2u)), + MaxCase(CelValue::CreateInt64(-1L), CelValue::CreateDouble(-1.1), + CelValue::CreateInt64(-1L)), + MaxCase(CelValue::CreateDouble(-2.0), CelValue::CreateDouble(-1.1), + CelValue::CreateDouble(-1.1)), + MaxCase(CelValue::CreateDouble(3.1), CelValue::CreateInt64(2), + CelValue::CreateDouble(3.1)), + MaxCase(CelValue::CreateDouble(2.5), CelValue::CreateUint64(2u), + CelValue::CreateDouble(2.5)), + MaxCase(CelValue::CreateUint64(2u), CelValue::CreateDouble(-1.1), + CelValue::CreateUint64(2u)), + MaxCase(CelValue::CreateUint64(3u), CelValue::CreateInt64(20), + CelValue::CreateInt64(20)), + MaxCase(CelValue::CreateUint64(4u), CelValue::CreateUint64(2u), + CelValue::CreateUint64(4u)), + MaxCase(CelValue::CreateInt64(2L), CelValue::CreateUint64(2u), + CelValue::CreateInt64(2L)), + MaxCase(CelValue::CreateInt64(-1L), CelValue::CreateDouble(-1.0), + CelValue::CreateInt64(-1L)), + MaxCase(CelValue::CreateDouble(2.0), CelValue::CreateInt64(2), + CelValue::CreateDouble(2.0)), + MaxCase(CelValue::CreateDouble(2.0), CelValue::CreateUint64(2u), + CelValue::CreateDouble(2.0)), + MaxCase(CelValue::CreateUint64(2u), CelValue::CreateDouble(2.0), + CelValue::CreateUint64(2u)), + MaxCase(CelValue::CreateUint64(3u), CelValue::CreateInt64(3), + CelValue::CreateUint64(3u)), + })); + +TEST(MathExtTest, MinMaxList) { + ContainerBackedListImpl single_item_list({CelValue::CreateInt64(1)}); + ExpectResult(MinCase(CelValue::CreateList(&single_item_list), + CelValue::CreateInt64(1))); + ExpectResult(MaxCase(CelValue::CreateList(&single_item_list), + CelValue::CreateInt64(1))); + + ContainerBackedListImpl list({CelValue::CreateInt64(1), + CelValue::CreateUint64(2u), + CelValue::CreateDouble(-1.1)}); + ExpectResult( + MinCase(CelValue::CreateList(&list), CelValue::CreateDouble(-1.1))); + ExpectResult( + MaxCase(CelValue::CreateList(&list), CelValue::CreateUint64(2u))); + + absl::Status empty_list_err = + absl::InvalidArgumentError("argument must not be empty"); + CelValue err_value = CelValue::CreateError(&empty_list_err); + ContainerBackedListImpl empty_list({}); + ExpectResult(MinCase(CelValue::CreateList(&empty_list), err_value)); + ExpectResult(MaxCase(CelValue::CreateList(&empty_list), err_value)); + + absl::Status bad_arg_err = + absl::InvalidArgumentError("arguments must be numeric"); + err_value = CelValue::CreateError(&bad_arg_err); + + ContainerBackedListImpl bad_single_item({CelValue::CreateBool(true)}); + ExpectResult(MinCase(CelValue::CreateList(&bad_single_item), err_value)); + ExpectResult(MaxCase(CelValue::CreateList(&bad_single_item), err_value)); + + ContainerBackedListImpl bad_middle_item({CelValue::CreateInt64(1), + CelValue::CreateBool(false), + CelValue::CreateDouble(-1.1)}); + ExpectResult(MinCase(CelValue::CreateList(&bad_middle_item), err_value)); + ExpectResult(MaxCase(CelValue::CreateList(&bad_middle_item), err_value)); +} + +using MathExtMacroParamsTest = testing::TestWithParam; +TEST_P(MathExtMacroParamsTest, ParserTests) { + const MacroTestCase& test_case = GetParam(); + auto result = ParseWithMacros(test_case.expr, cel::extensions::math_macros(), + ""); + if (!test_case.err.empty()) { + EXPECT_THAT(result.status(), StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr(test_case.err))); + return; + } + ASSERT_OK(result); + + ParsedExpr parsed_expr = *result; + Expr expr = parsed_expr.expr(); + SourceInfo source_info = parsed_expr.source_info(); + InterpreterOptions options; + options.enable_qualified_identifier_rewrites = true; + std::unique_ptr builder = + CreateCelExpressionBuilder(options); + ASSERT_OK(builder->GetRegistry()->Register(CreateGreatestFunction())); + ASSERT_OK(builder->GetRegistry()->Register(CreateLeastFunction())); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); + ASSERT_OK(RegisterMathExtensionFunctions(builder->GetRegistry(), options)); + ASSERT_OK_AND_ASSIGN(auto cel_expression, + builder->CreateExpression(&expr, &source_info)); + + google::protobuf::Arena arena; + Activation activation; + ASSERT_OK_AND_ASSIGN(auto value, + cel_expression->Evaluate(activation, &arena)); + + ASSERT_TRUE(value.IsBool()); + EXPECT_EQ(value.BoolOrDie(), true); +} + +TEST_P(MathExtMacroParamsTest, ParserAndCheckerTests) { + const MacroTestCase& test_case = GetParam(); + CompilerOptions compile_opts; + compile_opts.adapt_parser_errors = true; + ASSERT_OK_AND_ASSIGN(auto compiler_builder, + cel::NewCompilerBuilder( + internal::GetTestingDescriptorPool(), compile_opts)); + + ASSERT_THAT(compiler_builder->AddLibrary(StandardCheckerLibrary()), IsOk()); + ASSERT_THAT(compiler_builder->AddLibrary(MathCompilerLibrary()), IsOk()); + + // Add test functions that check macro (non-)expansion. + ASSERT_OK_AND_ASSIGN( + auto least_decl, + MakeFunctionDecl("least", MakeMemberOverloadDecl("bool_least_int_int", + /*result*/ BoolType(), + /*receiver*/ BoolType(), + IntType(), IntType()))); + ASSERT_OK_AND_ASSIGN(auto greatest_decl, + MakeFunctionDecl("greatest", MakeMemberOverloadDecl( + "bool_greatest_int_int", + /*result*/ BoolType(), + /*receiver*/ BoolType(), + IntType(), IntType()))); + + ASSERT_THAT(compiler_builder->GetCheckerBuilder().AddFunction(least_decl), + IsOk()); + ASSERT_THAT(compiler_builder->GetCheckerBuilder().AddFunction(greatest_decl), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto compiler, std::move(*compiler_builder).Build()); + + ASSERT_OK_AND_ASSIGN(auto result, + compiler->Compile(test_case.expr, "")); + + if (!test_case.err.empty()) { + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.FormatError(), HasSubstr(test_case.err)); + return; + } + + ASSERT_TRUE(result.IsValid()) << result.FormatError(); + + RuntimeOptions opts; + ASSERT_OK_AND_ASSIGN( + auto runtime_builder, + CreateStandardRuntimeBuilder(internal::GetTestingDescriptorPool(), opts)); + + ASSERT_THAT( + RegisterMathExtensionFunctions(runtime_builder.function_registry(), opts), + IsOk()); + + ASSERT_THAT( + runtime_builder.function_registry().Register( + TestFunction::MakeDescriptor(kGreatest), CreateGreatestFunction()), + IsOk()); + ASSERT_THAT( + runtime_builder.function_registry().Register( + TestFunction::MakeDescriptor(kLeast), CreateGreatestFunction()), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(runtime_builder).Build()); + ASSERT_OK_AND_ASSIGN(auto ast, result.ReleaseAst()); + ASSERT_OK_AND_ASSIGN(auto program, runtime->CreateProgram(std::move(ast))); + + google::protobuf::Arena arena; + cel::Activation activation; + ASSERT_OK_AND_ASSIGN(auto value, program->Evaluate(&arena, activation)); + + ASSERT_TRUE(value.IsBool()); + EXPECT_EQ(value.GetBool(), true); +} + +INSTANTIATE_TEST_SUITE_P( + MathExtMacrosParamsTest, MathExtMacroParamsTest, + testing::ValuesIn( + {// Tests for math.least + {"math.least(-0.5) == -0.5"}, + {"math.least(-1) == -1"}, + {"math.least(1u) == 1u"}, + {"math.least(42.0, -0.5) == -0.5"}, + {"math.least(-1, 0) == -1"}, + {"math.least(-1, -1) == -1"}, + {"math.least(1u, 42u) == 1u"}, + {"math.least(42.0, -0.5, -0.25) == -0.5"}, + {"math.least(-1, 0, 1) == -1"}, + {"math.least(-1, -1, -1) == -1"}, + {"math.least(1u, 42u, 0u) == 0u"}, + // math.least two arg overloads across type. + {"math.least(1, 1.0) == 1"}, + {"math.least(1, -2.0) == -2.0"}, + {"math.least(2, 1u) == 1u"}, + {"math.least(1.5, 2) == 1.5"}, + {"math.least(1.5, -2) == -2"}, + {"math.least(2.5, 1u) == 1u"}, + {"math.least(1u, 2) == 1u"}, + {"math.least(1u, -2) == -2"}, + {"math.least(2u, 2.5) == 2u"}, + // math.least with dynamic values across type. + {"math.least(1u, dyn(42)) == 1"}, + {"math.least(1u, dyn(42), dyn(0.0)) == 0u"}, + // math.least with a list literal. + {"math.least([1u, 42u, 0u]) == 0u"}, + // math.least errors + { + "math.least()", + "math.least() requires at least one argument.", + }, + { + "math.least('hello')", + "math.least() invalid single argument value.", + }, + { + "math.least({})", + "math.least() invalid single argument value", + }, + { + "math.least([])", + "math.least() invalid single argument value", + }, + { + "math.least([1, true])", + "math.least() invalid single argument value", + }, + { + "math.least(1, true)", + "math.least() simple literal arguments must be numeric", + }, + { + "math.least(1, 2, true)", + "math.least() simple literal arguments must be numeric", + }, + + // Tests for math.greatest + {"math.greatest(-0.5) == -0.5"}, + {"math.greatest(-1) == -1"}, + {"math.greatest(1u) == 1u"}, + {"math.greatest(42.0, -0.5) == 42.0"}, + {"math.greatest(-1, 0) == 0"}, + {"math.greatest(-1, -1) == -1"}, + {"math.greatest(1u, 42u) == 42u"}, + {"math.greatest(42.0, -0.5, -0.25) == 42.0"}, + {"math.greatest(-1, 0, 1) == 1"}, + {"math.greatest(-1, -1, -1) == -1"}, + {"math.greatest(1u, 42u, 0u) == 42u"}, + // math.least two arg overloads across type. + {"math.greatest(1, 1.0) == 1"}, + {"math.greatest(1, -2.0) == 1"}, + {"math.greatest(2, 1u) == 2"}, + {"math.greatest(1.5, 2) == 2"}, + {"math.greatest(1.5, -2) == 1.5"}, + {"math.greatest(2.5, 1u) == 2.5"}, + {"math.greatest(1u, 2) == 2"}, + {"math.greatest(1u, -2) == 1u"}, + {"math.greatest(2u, 2.5) == 2.5"}, + // math.greatest with dynamic values across type. + {"math.greatest(1u, dyn(42)) == 42.0"}, + {"math.greatest(1u, dyn(0.0), 0u) == 1"}, + // math.greatest with a list literal + {"math.greatest([1u, dyn(0.0), 0u]) == 1"}, + // math.greatest errors + { + "math.greatest()", + "math.greatest() requires at least one argument.", + }, + { + "math.greatest('hello')", + "math.greatest() invalid single argument value.", + }, + { + "math.greatest({})", + "math.greatest() invalid single argument value", + }, + { + "math.greatest([])", + "math.greatest() invalid single argument value", + }, + { + "math.greatest([1, true])", + "math.greatest() invalid single argument value", + }, + { + "math.greatest(1, true)", + "math.greatest() simple literal arguments must be numeric", + }, + { + "math.greatest(1, 2, true)", + "math.greatest() simple literal arguments must be numeric", + }, + // Call signatures which trigger macro expansion, but which do not + // get expanded. The function just returns true. + { + "false.greatest(1,2)", + }, + { + "true.least(1,2)", + }, + // Basic coverage for function definitions. Behavior is tested in the + // conformance tests. + {"math.sign(-12) == -1"}, + {"math.sign(0u) == 0u"}, + {"math.sign(42.01) == 1.0"}, + {"math.abs(-12) == 12"}, + {"math.abs(0u) == 0u"}, + {"math.abs(42.01) == 42.01"}, + {"math.ceil(42.01) == 43.0"}, + {"math.floor(42.01) == 42.0"}, + {"math.round(42.5) == 43.0"}, + {"math.sqrt(49.0) == 7.0"}, + {"math.sqrt(0) == 0.0"}, + {"math.sqrt(1) == 1.0"}, + {"math.sqrt(25u) == 5.0"}, + {"math.sqrt(38.44) == 6.2"}, + {"math.isNaN(math.sqrt(-15)) == true"}, + {"math.trunc(42.0) == 42.0"}, + {"math.isInf(42.0 / 0.0) == true"}, + {"math.isNaN(double('nan')) == true"}, + {"math.isFinite(42.1) == true"}, + {"math.bitAnd(3, 1) == 1"}, + {"math.bitAnd(3u, 1u) == 1u"}, + {"math.bitOr(2, 1) == 3"}, + {"math.bitOr(2u, 1u) == 3u"}, + {"math.bitXor(3, 1) == 2"}, + {"math.bitXor(3u, 1u) == 2u"}, + {"math.bitNot(2) == -3"}, + {"math.bitAnd(math.bitNot(0x3u), 0xFFu) == 0xFCu"}, + {"math.bitShiftLeft(1, 1) == 2"}, + {"math.bitShiftLeft(-1, 1) == -2"}, + {"math.bitShiftLeft(-4, 2) == -16"}, + {"math.bitShiftLeft(1u, 1) == 2u"}, + {"math.bitShiftRight(4, 1) == 2"}, + {"math.bitShiftRight(4u, 1) == 2u"}})); + +struct MathExtensionVersionTestCase { + std::string expr; + std::vector expected_supported_versions; +}; + +class MathExtensionVersionTest + : public ::testing::TestWithParam {}; + +TEST_P(MathExtensionVersionTest, MathExtensionVersions) { + const MathExtensionVersionTestCase& test_case = GetParam(); + for (int version = 0; version <= cel::extensions::kMathExtensionLatestVersion; + ++version) { + CompilerLibrary compiler_library = MathCompilerLibrary(version); + + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + cel::NewCompilerBuilder(internal::GetTestingDescriptorPool(), + CompilerOptions())); + ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk()); + ASSERT_THAT(builder->AddLibrary(std::move(compiler_library)), IsOk()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, builder->Build()); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + compiler->Compile(test_case.expr)); + if (absl::c_contains(test_case.expected_supported_versions, version)) { + EXPECT_THAT(result.GetIssues(), IsEmpty()) + << "Expected no issues for expr: " << test_case.expr + << " at version: " << version << " but got: " << result.FormatError(); + } else { + EXPECT_THAT(result.GetIssues(), + Contains(Property(&TypeCheckIssue::message, + HasSubstr("undeclared reference")))) + << "Expected undeclared reference for expr: " << test_case.expr + << " at version: " << version; + } + } +}; + +std::vector CreateMathExtensionVersionParams() { + return { + MathExtensionVersionTestCase{ + .expr = "math.least([0,1,2,3])", + .expected_supported_versions = {0, 1, 2}, + }, + MathExtensionVersionTestCase{ + .expr = "math.greatest([0,1,2,3])", + .expected_supported_versions = {0, 1, 2}, + }, + MathExtensionVersionTestCase{ + .expr = "math.ceil(1.5)", + .expected_supported_versions = {1, 2}, + }, + MathExtensionVersionTestCase{ + .expr = "math.floor(1.5)", + .expected_supported_versions = {1, 2}, + }, + MathExtensionVersionTestCase{ + .expr = "math.round(1.5)", + .expected_supported_versions = {1, 2}, + }, + MathExtensionVersionTestCase{ + .expr = "math.trunc(1.5)", + .expected_supported_versions = {1, 2}, + }, + MathExtensionVersionTestCase{ + .expr = "math.isInf(1.5)", + .expected_supported_versions = {1, 2}, + }, + MathExtensionVersionTestCase{ + .expr = "math.isNaN(1.5)", + .expected_supported_versions = {1, 2}, + }, + MathExtensionVersionTestCase{ + .expr = "math.isFinite(1.5)", + .expected_supported_versions = {1, 2}, + }, + MathExtensionVersionTestCase{ + .expr = "math.abs(1.5)", + .expected_supported_versions = {1, 2}, + }, + MathExtensionVersionTestCase{ + .expr = "math.sign(1.5)", + .expected_supported_versions = {1, 2}, + }, + MathExtensionVersionTestCase{ + .expr = "math.bitAnd(1, 1)", + .expected_supported_versions = {1, 2}, + }, + MathExtensionVersionTestCase{ + .expr = "math.bitOr(1, 1)", + .expected_supported_versions = {1, 2}, + }, + MathExtensionVersionTestCase{ + .expr = "math.bitXor(1, 1)", + .expected_supported_versions = {1, 2}, + }, + MathExtensionVersionTestCase{ + .expr = "math.bitNot(1)", + .expected_supported_versions = {1, 2}, + }, + MathExtensionVersionTestCase{ + .expr = "math.bitShiftLeft(1, 1)", + .expected_supported_versions = {1, 2}, + }, + MathExtensionVersionTestCase{ + .expr = "math.bitShiftRight(1, 1)", + .expected_supported_versions = {1, 2}, + }, + MathExtensionVersionTestCase{ + .expr = "math.sqrt(1.5)", + .expected_supported_versions = {2}, + }, + }; +} + +INSTANTIATE_TEST_SUITE_P(MathExtensionVersionTest, MathExtensionVersionTest, + ValuesIn(CreateMathExtensionVersionParams())); + +} // namespace +} // namespace cel::extensions diff --git a/extensions/proto_ext.cc b/extensions/proto_ext.cc new file mode 100644 index 000000000..48618f7ae --- /dev/null +++ b/extensions/proto_ext.cc @@ -0,0 +1,128 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/proto_ext.h" + +#include +#include +#include + +#include "absl/functional/overload.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "absl/types/variant.h" +#include "common/expr.h" +#include "compiler/compiler.h" +#include "internal/status_macros.h" +#include "parser/macro.h" +#include "parser/macro_expr_factory.h" +#include "parser/parser_interface.h" + +namespace cel::extensions { + +namespace { + +static constexpr char kProtoNamespace[] = "proto"; +static constexpr char kGetExt[] = "getExt"; +static constexpr char kHasExt[] = "hasExt"; + +absl::optional ValidateExtensionIdentifier(const Expr& expr) { + return absl::visit( + absl::Overload( + [](const SelectExpr& select_expr) -> absl::optional { + if (select_expr.test_only()) { + return std::nullopt; + } + auto op_name = ValidateExtensionIdentifier(select_expr.operand()); + if (!op_name.has_value()) { + return std::nullopt; + } + return absl::StrCat(*op_name, ".", select_expr.field()); + }, + [](const IdentExpr& ident_expr) -> absl::optional { + return ident_expr.name(); + }, + [](const auto&) -> absl::optional { + return std::nullopt; + }), + expr.kind()); +} + +absl::optional GetExtensionFieldName(const Expr& expr) { + if (const auto* select_expr = + expr.has_select_expr() ? &expr.select_expr() : nullptr; + select_expr) { + return ValidateExtensionIdentifier(expr); + } + return std::nullopt; +} + +bool IsExtensionCall(const Expr& target) { + if (const auto* ident_expr = + target.has_ident_expr() ? &target.ident_expr() : nullptr; + ident_expr) { + return ident_expr->name() == kProtoNamespace; + } + return false; +} + +absl::Status ConfigureParser(ParserBuilder& builder) { + for (const auto& macro : proto_macros()) { + CEL_RETURN_IF_ERROR(builder.AddMacro(macro)); + } + return absl::OkStatus(); +} + +} // namespace + +std::vector proto_macros() { + absl::StatusOr getExt = Macro::Receiver( + kGetExt, 2, + [](MacroExprFactory& factory, Expr& target, + absl::Span arguments) -> absl::optional { + if (!IsExtensionCall(target)) { + return std::nullopt; + } + auto extFieldName = GetExtensionFieldName(arguments[1]); + if (!extFieldName.has_value()) { + return factory.ReportErrorAt(arguments[1], "invalid extension field"); + } + return factory.NewSelect(std::move(arguments[0]), + std::move(*extFieldName)); + }); + absl::StatusOr hasExt = Macro::Receiver( + kHasExt, 2, + [](MacroExprFactory& factory, Expr& target, + absl::Span arguments) -> absl::optional { + if (!IsExtensionCall(target)) { + return std::nullopt; + } + auto extFieldName = GetExtensionFieldName(arguments[1]); + if (!extFieldName.has_value()) { + return factory.ReportErrorAt(arguments[1], "invalid extension field"); + } + return factory.NewPresenceTest(std::move(arguments[0]), + std::move(*extFieldName)); + }); + return {*hasExt, *getExt}; +} + +CompilerLibrary ProtoExtCompilerLibrary() { + return CompilerLibrary("cel.lib.ext.proto", ConfigureParser); +} + +} // namespace cel::extensions diff --git a/extensions/proto_ext.h b/extensions/proto_ext.h new file mode 100644 index 000000000..82e086aba --- /dev/null +++ b/extensions/proto_ext.h @@ -0,0 +1,42 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTO_EXT_H_ +#define THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTO_EXT_H_ + +#include + +#include "absl/status/status.h" +#include "compiler/compiler.h" +#include "parser/macro.h" +#include "parser/macro_registry.h" +#include "parser/options.h" + +namespace cel::extensions { + +// proto_macros returns the macros which are useful for working with protobuf +// objects in CEL. Specifically, the proto.getExt() and proto.hasExt() macros. +std::vector proto_macros(); + +// Library for the proto extensions. +CompilerLibrary ProtoExtCompilerLibrary(); + +inline absl::Status RegisterProtoMacros(MacroRegistry& registry, + const ParserOptions&) { + return registry.RegisterMacros(proto_macros()); +} + +} // namespace cel::extensions + +#endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTO_EXT_H_ diff --git a/extensions/protobuf/BUILD b/extensions/protobuf/BUILD new file mode 100644 index 000000000..3f4081b09 --- /dev/null +++ b/extensions/protobuf/BUILD @@ -0,0 +1,224 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +package( + # Under active development, not yet being released. + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) + +cc_library( + name = "memory_manager", + srcs = ["memory_manager.cc"], + hdrs = ["memory_manager.h"], + deps = [ + "//common:memory", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "memory_manager_test", + srcs = ["memory_manager_test.cc"], + deps = [ + ":memory_manager", + "//common:memory", + "//internal:testing", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "ast_converters", + hdrs = ["ast_converters.h"], + deps = [ + "//common:ast", + "//common:ast_proto", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/status:statusor", + "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + ], +) + +cc_library( + name = "runtime_adapter", + srcs = ["runtime_adapter.cc"], + hdrs = ["runtime_adapter.h"], + deps = [ + ":ast_converters", + "//internal:status_macros", + "//runtime", + "//runtime:runtime_builder", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "enum_adapter", + srcs = ["enum_adapter.cc"], + hdrs = ["enum_adapter.h"], + deps = [ + "//runtime:type_registry", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "value", + hdrs = [ + "value.h", + ], + deps = [ + "//common:memory", + "//common:type", + "//common:value", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/meta:type_traits", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + "@com_google_protobuf//:duration_cc_proto", + "@com_google_protobuf//:protobuf", + "@com_google_protobuf//:struct_cc_proto", + "@com_google_protobuf//:timestamp_cc_proto", + "@com_google_protobuf//:wrappers_cc_proto", + ], +) + +cc_test( + name = "value_test", + srcs = [ + "value_test.cc", + ], + deps = [ + ":value", + "//base:attributes", + "//common:casting", + "//common:value", + "//common:value_kind", + "//common:value_testing", + "//internal:testing", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/time", + "@com_google_cel_spec//proto/cel/expr/conformance/proto2:test_all_types_cc_proto", + "@com_google_protobuf//:any_cc_proto", + "@com_google_protobuf//:duration_cc_proto", + "@com_google_protobuf//:protobuf", + "@com_google_protobuf//:struct_cc_proto", + "@com_google_protobuf//:timestamp_cc_proto", + "@com_google_protobuf//:wrappers_cc_proto", + ], +) + +cc_test( + name = "value_end_to_end_test", + srcs = ["value_end_to_end_test.cc"], + deps = [ + ":runtime_adapter", + "//common:value", + "//common:value_testing", + "//internal:testing", + "//parser", + "//runtime", + "//runtime:activation", + "//runtime:runtime_options", + "//runtime:standard_runtime_builder_factory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/time", + "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "bind_proto_to_activation", + srcs = ["bind_proto_to_activation.cc"], + hdrs = ["bind_proto_to_activation.h"], + deps = [ + ":value", + "//common:casting", + "//common:value", + "//internal:status_macros", + "//runtime:activation", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "bind_proto_to_activation_test", + srcs = ["bind_proto_to_activation_test.cc"], + deps = [ + ":bind_proto_to_activation", + "//common:casting", + "//common:value", + "//common:value_testing", + "//internal:testing", + "//runtime:activation", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/types:optional", + "@com_google_cel_spec//proto/cel/expr/conformance/proto2:test_all_types_cc_proto", + "@com_google_protobuf//:protobuf", + "@com_google_protobuf//:wrappers_cc_proto", + ], +) + +cc_library( + name = "value_testing", + testonly = True, + hdrs = ["value_testing.h"], + deps = [ + ":value", + "//common:value", + "//internal:testing", + "@com_google_absl//absl/status", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "value_testing_test", + srcs = ["value_testing_test.cc"], + deps = [ + ":value", + ":value_testing", + "//common:value", + "//common:value_testing", + "//internal:proto_matchers", + "//internal:testing", + "@com_google_cel_spec//proto/cel/expr/conformance/proto2:test_all_types_cc_proto", + ], +) diff --git a/extensions/protobuf/ast_converters.h b/extensions/protobuf/ast_converters.h new file mode 100644 index 000000000..a8295c552 --- /dev/null +++ b/extensions/protobuf/ast_converters.h @@ -0,0 +1,56 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_AST_CONVERTERS_H_ +#define THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_AST_CONVERTERS_H_ + +#include + +#include "cel/expr/checked.pb.h" +#include "cel/expr/syntax.pb.h" +#include "absl/base/attributes.h" +#include "absl/status/statusor.h" +#include "common/ast.h" +#include "common/ast_proto.h" + +namespace cel::extensions { + +// Creates a runtime AST from a parsed-only protobuf AST. +// May return a non-ok Status if the AST is malformed (e.g. unset required +// fields). +ABSL_DEPRECATED("Use cel::CreateAstFromParsedExpr instead.") +inline absl::StatusOr> CreateAstFromParsedExpr( + const cel::expr::Expr& expr, + const cel::expr::SourceInfo* source_info = nullptr) { + return cel::CreateAstFromParsedExpr(expr, source_info); +} + +ABSL_DEPRECATED("Use cel::CreateAstFromParsedExpr instead.") +inline absl::StatusOr> CreateAstFromParsedExpr( + const cel::expr::ParsedExpr& parsed_expr) { + return cel::CreateAstFromParsedExpr(parsed_expr); +} + +// Creates a runtime AST from a checked protobuf AST. +// May return a non-ok Status if the AST is malformed (e.g. unset required +// fields). +ABSL_DEPRECATED("Use cel::CreateAstFromCheckedExpr instead.") +inline absl::StatusOr> CreateAstFromCheckedExpr( + const cel::expr::CheckedExpr& checked_expr) { + return cel::CreateAstFromCheckedExpr(checked_expr); +} + +} // namespace cel::extensions + +#endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_AST_CONVERTERS_H_ diff --git a/extensions/protobuf/bind_proto_to_activation.cc b/extensions/protobuf/bind_proto_to_activation.cc new file mode 100644 index 000000000..aa151cb85 --- /dev/null +++ b/extensions/protobuf/bind_proto_to_activation.cc @@ -0,0 +1,92 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/protobuf/bind_proto_to_activation.h" + +#include + +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "common/value.h" +#include "internal/status_macros.h" +#include "runtime/activation.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel::extensions::protobuf_internal { + +namespace { + +using ::google::protobuf::Descriptor; + +absl::StatusOr ShouldBindField( + const google::protobuf::FieldDescriptor* field_desc, const StructValue& struct_value, + BindProtoUnsetFieldBehavior unset_field_behavior) { + if (unset_field_behavior == BindProtoUnsetFieldBehavior::kBindDefaultValue || + field_desc->is_repeated()) { + return true; + } + return struct_value.HasFieldByNumber(field_desc->number()); +} + +absl::StatusOr GetFieldValue( + const google::protobuf::FieldDescriptor* field_desc, const StructValue& struct_value, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + // Special case unset any. + if (field_desc->cpp_type() == google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE && + field_desc->message_type()->well_known_type() == + Descriptor::WELLKNOWNTYPE_ANY) { + CEL_ASSIGN_OR_RETURN(bool present, + struct_value.HasFieldByNumber(field_desc->number())); + if (!present) { + return NullValue(); + } + } + + return struct_value.GetFieldByNumber(field_desc->number(), descriptor_pool, + message_factory, arena); +} + +} // namespace + +absl::Status BindProtoToActivation( + const Descriptor& descriptor, const StructValue& struct_value, + BindProtoUnsetFieldBehavior unset_field_behavior, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Activation* absl_nonnull activation) { + for (int i = 0; i < descriptor.field_count(); i++) { + const google::protobuf::FieldDescriptor* field_desc = descriptor.field(i); + CEL_ASSIGN_OR_RETURN( + bool should_bind, + ShouldBindField(field_desc, struct_value, unset_field_behavior)); + if (!should_bind) { + continue; + } + + CEL_ASSIGN_OR_RETURN( + Value field, GetFieldValue(field_desc, struct_value, descriptor_pool, + message_factory, arena)); + + activation->InsertOrAssignValue(field_desc->name(), std::move(field)); + } + + return absl::OkStatus(); +} + +} // namespace cel::extensions::protobuf_internal diff --git a/extensions/protobuf/bind_proto_to_activation.h b/extensions/protobuf/bind_proto_to_activation.h new file mode 100644 index 000000000..61f43c13d --- /dev/null +++ b/extensions/protobuf/bind_proto_to_activation.h @@ -0,0 +1,130 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_BIND_PROTO_TO_ACTIVATION_H_ +#define THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_BIND_PROTO_TO_ACTIVATION_H_ + +#include + +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "common/casting.h" +#include "common/value.h" +#include "extensions/protobuf/value.h" +#include "internal/status_macros.h" +#include "runtime/activation.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel::extensions { + +// Option for handling unset fields on the context proto. +enum class BindProtoUnsetFieldBehavior { + // Bind the message defined default or zero value. + kBindDefaultValue, + // Skip binding unset fields, no value is bound for the corresponding + // variable. + kSkip +}; + +namespace protobuf_internal { + +// Implements binding provided the context message has already +// been adapted to a suitable struct value. +absl::Status BindProtoToActivation( + const google::protobuf::Descriptor& descriptor, const StructValue& struct_value, + BindProtoUnsetFieldBehavior unset_field_behavior, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Activation* absl_nonnull activation); + +} // namespace protobuf_internal + +// Utility method, that takes a protobuf Message and interprets it as a +// namespace, binding its fields to Activation. This is often referred to as a +// context message. +// +// Field names and values become respective names and values of parameters +// bound to the Activation object. +// Example: +// Assume we have a protobuf message of type: +// message Person { +// int age = 1; +// string name = 2; +// } +// +// The sample code snippet will look as follows: +// +// Person person; +// person.set_name("John Doe"); +// person.age(42); +// +// CEL_RETURN_IF_ERROR(BindProtoToActivation(person, value_factory, +// activation)); +// +// After this snippet, activation will have two parameters bound: +// "name", with string value of "John Doe" +// "age", with int value of 42. +// +// The default behavior for unset fields is to skip them. E.g. if the name field +// is not set on the Person message, it will not be bound in to the activation. +// BindProtoUnsetFieldBehavior::kBindDefault, will bind the cc proto api default +// for the field (either an explicit default value or a type specific default). +// +// For repeated fields, an unset field is bound as an empty list. +template +absl::Status BindProtoToActivation( + const T& context, BindProtoUnsetFieldBehavior unset_field_behavior, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Activation* absl_nonnull activation) { + static_assert(std::is_base_of_v); + // TODO(uncreated-issue/68): for simplicity, just convert the whole message to a + // struct value. For performance, may be better to convert members as needed. + CEL_ASSIGN_OR_RETURN( + Value parent, + ProtoMessageToValue(context, descriptor_pool, message_factory, arena)); + + if (!InstanceOf(parent)) { + return absl::InvalidArgumentError( + absl::StrCat("context is a well-known type: ", context.GetTypeName())); + } + const StructValue& struct_value = Cast(parent); + + const google::protobuf::Descriptor* descriptor = context.GetDescriptor(); + + if (descriptor == nullptr) { + return absl::InvalidArgumentError( + absl::StrCat("context missing descriptor: ", context.GetTypeName())); + } + + return protobuf_internal::BindProtoToActivation( + *descriptor, struct_value, unset_field_behavior, descriptor_pool, + message_factory, arena, activation); +} +template +absl::Status BindProtoToActivation( + const T& context, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Activation* absl_nonnull activation) { + return BindProtoToActivation(context, BindProtoUnsetFieldBehavior::kSkip, + descriptor_pool, message_factory, arena, + activation); +} + +} // namespace cel::extensions + +#endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_BIND_PROTO_TO_ACTIVATION_H_ diff --git a/extensions/protobuf/bind_proto_to_activation_test.cc b/extensions/protobuf/bind_proto_to_activation_test.cc new file mode 100644 index 000000000..680b4b353 --- /dev/null +++ b/extensions/protobuf/bind_proto_to_activation_test.cc @@ -0,0 +1,245 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/protobuf/bind_proto_to_activation.h" + +#include "google/protobuf/wrappers.pb.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/types/optional.h" +#include "common/casting.h" +#include "common/value.h" +#include "common/value_testing.h" +#include "internal/testing.h" +#include "runtime/activation.h" +#include "cel/expr/conformance/proto2/test_all_types.pb.h" +#include "google/protobuf/arena.h" + +namespace cel::extensions { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::IsOkAndHolds; +using ::absl_testing::StatusIs; +using ::cel::expr::conformance::proto2::TestAllTypes; +using ::cel::test::IntValueIs; +using ::testing::Eq; +using ::testing::HasSubstr; +using ::testing::Optional; + +using BindProtoToActivationTest = common_internal::ValueTest<>; + +TEST_F(BindProtoToActivationTest, BindProtoToActivation) { + TestAllTypes test_all_types; + test_all_types.set_single_int64(123); + Activation activation; + + ASSERT_THAT(BindProtoToActivation(test_all_types, descriptor_pool(), + message_factory(), arena(), &activation), + IsOk()); + + EXPECT_THAT(activation.FindVariable("single_int64", descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(Optional(IntValueIs(123)))); +} + +TEST_F(BindProtoToActivationTest, BindProtoToActivationWktUnsupported) { + google::protobuf::Int64Value int64_value; + int64_value.set_value(123); + Activation activation; + + EXPECT_THAT(BindProtoToActivation(int64_value, descriptor_pool(), + message_factory(), arena(), &activation), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("google.protobuf.Int64Value"))); +} + +TEST_F(BindProtoToActivationTest, BindProtoToActivationSkip) { + TestAllTypes test_all_types; + test_all_types.set_single_int64(123); + Activation activation; + + ASSERT_THAT(BindProtoToActivation(test_all_types, descriptor_pool(), + message_factory(), arena(), &activation), + IsOk()); + + EXPECT_THAT(activation.FindVariable("single_int32", descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(Eq(std::nullopt))); + EXPECT_THAT(activation.FindVariable("single_sint32", descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(Eq(std::nullopt))); +} + +TEST_F(BindProtoToActivationTest, BindProtoToActivationDefault) { + TestAllTypes test_all_types; + test_all_types.set_single_int64(123); + Activation activation; + + ASSERT_THAT( + BindProtoToActivation( + test_all_types, BindProtoUnsetFieldBehavior::kBindDefaultValue, + descriptor_pool(), message_factory(), arena(), &activation), + IsOk()); + + // from test_all_types.proto + // optional int32 single_int32 = 1 [default = -32]; + EXPECT_THAT(activation.FindVariable("single_int32", descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(Optional(IntValueIs(-32)))); + EXPECT_THAT(activation.FindVariable("single_sint32", descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(Optional(IntValueIs(0)))); +} + +// Special case any fields. Mirrors go evaluator behavior. +TEST_F(BindProtoToActivationTest, BindProtoToActivationDefaultAny) { + TestAllTypes test_all_types; + test_all_types.set_single_int64(123); + Activation activation; + + ASSERT_THAT( + BindProtoToActivation( + test_all_types, BindProtoUnsetFieldBehavior::kBindDefaultValue, + descriptor_pool(), message_factory(), arena(), &activation), + IsOk()); + + EXPECT_THAT(activation.FindVariable("single_any", descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(Optional(test::IsNullValue()))); +} + +MATCHER_P(IsListValueOfSize, size, "") { + const Value& v = arg; + + auto value = As(v); + if (!value) { + return false; + } + auto s = value->Size(); + return s.ok() && *s == size; +} + +TEST_F(BindProtoToActivationTest, BindProtoToActivationRepeated) { + TestAllTypes test_all_types; + test_all_types.add_repeated_int64(123); + test_all_types.add_repeated_int64(456); + test_all_types.add_repeated_int64(789); + + Activation activation; + + ASSERT_THAT(BindProtoToActivation(test_all_types, descriptor_pool(), + message_factory(), arena(), &activation), + IsOk()); + + EXPECT_THAT(activation.FindVariable("repeated_int64", descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(Optional(IsListValueOfSize(3)))); +} + +TEST_F(BindProtoToActivationTest, BindProtoToActivationRepeatedEmpty) { + TestAllTypes test_all_types; + test_all_types.set_single_int64(123); + Activation activation; + + ASSERT_THAT(BindProtoToActivation(test_all_types, descriptor_pool(), + message_factory(), arena(), &activation), + IsOk()); + + EXPECT_THAT(activation.FindVariable("repeated_int32", descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(Optional(IsListValueOfSize(0)))); +} + +TEST_F(BindProtoToActivationTest, BindProtoToActivationRepeatedComplex) { + TestAllTypes test_all_types; + auto* nested = test_all_types.add_repeated_nested_message(); + nested->set_bb(123); + nested = test_all_types.add_repeated_nested_message(); + nested->set_bb(456); + nested = test_all_types.add_repeated_nested_message(); + nested->set_bb(789); + Activation activation; + + ASSERT_THAT(BindProtoToActivation(test_all_types, descriptor_pool(), + message_factory(), arena(), &activation), + IsOk()); + + EXPECT_THAT( + activation.FindVariable("repeated_nested_message", descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(Optional(IsListValueOfSize(3)))); +} + +MATCHER_P(IsMapValueOfSize, size, "") { + const Value& v = arg; + + auto value = As(v); + if (!value) { + return false; + } + auto s = value->Size(); + return s.ok() && *s == size; +} + +TEST_F(BindProtoToActivationTest, BindProtoToActivationMap) { + TestAllTypes test_all_types; + (*test_all_types.mutable_map_int64_int64())[1] = 2; + (*test_all_types.mutable_map_int64_int64())[2] = 4; + + Activation activation; + + ASSERT_THAT(BindProtoToActivation(test_all_types, descriptor_pool(), + message_factory(), arena(), &activation), + IsOk()); + + EXPECT_THAT(activation.FindVariable("map_int64_int64", descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(Optional(IsMapValueOfSize(2)))); +} + +TEST_F(BindProtoToActivationTest, BindProtoToActivationMapEmpty) { + TestAllTypes test_all_types; + test_all_types.set_single_int64(123); + Activation activation; + + ASSERT_THAT(BindProtoToActivation(test_all_types, descriptor_pool(), + message_factory(), arena(), &activation), + IsOk()); + + EXPECT_THAT(activation.FindVariable("map_int32_int32", descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(Optional(IsMapValueOfSize(0)))); +} + +TEST_F(BindProtoToActivationTest, BindProtoToActivationMapComplex) { + TestAllTypes test_all_types; + TestAllTypes::NestedMessage value; + value.set_bb(42); + (*test_all_types.mutable_map_int64_message())[1] = value; + (*test_all_types.mutable_map_int64_message())[2] = value; + + Activation activation; + + ASSERT_THAT(BindProtoToActivation(test_all_types, descriptor_pool(), + message_factory(), arena(), &activation), + IsOk()); + + EXPECT_THAT(activation.FindVariable("map_int64_message", descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(Optional(IsMapValueOfSize(2)))); +} + +} // namespace +} // namespace cel::extensions diff --git a/extensions/protobuf/enum_adapter.cc b/extensions/protobuf/enum_adapter.cc new file mode 100644 index 000000000..113b1e7d1 --- /dev/null +++ b/extensions/protobuf/enum_adapter.cc @@ -0,0 +1,48 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "extensions/protobuf/enum_adapter.h" + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "runtime/type_registry.h" +#include "google/protobuf/descriptor.h" + +namespace cel::extensions { + +absl::Status RegisterProtobufEnum( + TypeRegistry& registry, const google::protobuf::EnumDescriptor* enum_descriptor) { + if (registry.resolveable_enums().contains(enum_descriptor->full_name())) { + return absl::AlreadyExistsError( + absl::StrCat(enum_descriptor->full_name(), " already registered.")); + } + + // TODO(uncreated-issue/42): the registry enum implementation runs linear lookups for + // constants since this isn't expected to happen at runtime. Consider updating + // if / when strong enum typing is implemented. + std::vector enumerators; + enumerators.reserve(enum_descriptor->value_count()); + for (int i = 0; i < enum_descriptor->value_count(); i++) { + enumerators.push_back({std::string(enum_descriptor->value(i)->name()), + enum_descriptor->value(i)->number()}); + } + registry.RegisterEnum(enum_descriptor->full_name(), std::move(enumerators)); + + return absl::OkStatus(); +} + +} // namespace cel::extensions diff --git a/extensions/protobuf/enum_adapter.h b/extensions/protobuf/enum_adapter.h new file mode 100644 index 000000000..c5c1c5ebf --- /dev/null +++ b/extensions/protobuf/enum_adapter.h @@ -0,0 +1,30 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_ENUM_ADAPTER_H_ +#define THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_ENUM_ADAPTER_H_ + +#include "absl/status/status.h" +#include "runtime/type_registry.h" +#include "google/protobuf/descriptor.h" + +namespace cel::extensions { + +// Register a resolveable enum for the given runtime builder. +absl::Status RegisterProtobufEnum( + TypeRegistry& registry, const google::protobuf::EnumDescriptor* enum_descriptor); + +} // namespace cel::extensions + +#endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_ENUM_ADAPTER_H_ diff --git a/extensions/protobuf/internal/BUILD b/extensions/protobuf/internal/BUILD new file mode 100644 index 000000000..4a3a3e82b --- /dev/null +++ b/extensions/protobuf/internal/BUILD @@ -0,0 +1,58 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:cc_library.bzl", "cc_library") + +package( + # Under active development, not yet being released. + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) + +cc_library( + name = "map_reflection", + srcs = ["map_reflection.cc"], + hdrs = ["map_reflection.h"], + deps = [ + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "qualify", + srcs = ["qualify.cc"], + hdrs = ["qualify.h"], + deps = [ + ":map_reflection", + "//base:attributes", + "//base:builtins", + "//common:kind", + "//common:memory", + "//internal:status_macros", + "//runtime:runtime_options", + "//runtime/internal:errors", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/functional:overload", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:variant", + "@com_google_protobuf//:protobuf", + ], +) diff --git a/extensions/protobuf/internal/map_reflection.cc b/extensions/protobuf/internal/map_reflection.cc new file mode 100644 index 000000000..605e4437d --- /dev/null +++ b/extensions/protobuf/internal/map_reflection.cc @@ -0,0 +1,132 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/protobuf/internal/map_reflection.h" + +#include "absl/base/nullability.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/map_field.h" +#include "google/protobuf/message.h" + +namespace google::protobuf::expr { + +class CelMapReflectionFriend final { + public: + static bool LookupMapValue(const Reflection& reflection, + const Message& message, + const FieldDescriptor& field, const MapKey& key, + MapValueConstRef* value) { + return reflection.LookupMapValue(message, &field, key, value); + } + + static bool ContainsMapKey(const Reflection& reflection, + const Message& message, + const FieldDescriptor& field, const MapKey& key) { + return reflection.ContainsMapKey(message, &field, key); + } + + static int MapSize(const google::protobuf::Reflection& reflection, + const google::protobuf::Message& message, + const google::protobuf::FieldDescriptor& field) { + return reflection.MapSize(message, &field); + } + + static google::protobuf::ConstMapIterator ConstMapBegin( + const google::protobuf::Reflection& reflection, const google::protobuf::Message& message, + const google::protobuf::FieldDescriptor& field) { + return reflection.ConstMapBegin(&message, &field); + } + + static google::protobuf::ConstMapIterator ConstMapEnd( + const google::protobuf::Reflection& reflection, const google::protobuf::Message& message, + const google::protobuf::FieldDescriptor& field) { + return reflection.ConstMapEnd(&message, &field); + } + + static bool InsertOrLookupMapValue(const google::protobuf::Reflection& reflection, + google::protobuf::Message* message, + const google::protobuf::FieldDescriptor& field, + const google::protobuf::MapKey& key, + google::protobuf::MapValueRef* value) { + return reflection.InsertOrLookupMapValue(message, &field, key, value); + } + + static bool DeleteMapValue(const google::protobuf::Reflection* absl_nonnull reflection, + google::protobuf::Message* absl_nonnull message, + const google::protobuf::FieldDescriptor* absl_nonnull field, + const google::protobuf::MapKey& key) { + return reflection->DeleteMapValue(message, field, key); + } +}; + +} // namespace google::protobuf::expr + +namespace cel::extensions::protobuf_internal { + +bool LookupMapValue(const google::protobuf::Reflection& reflection, + const google::protobuf::Message& message, + const google::protobuf::FieldDescriptor& field, + const google::protobuf::MapKey& key, + google::protobuf::MapValueConstRef* value) { + return google::protobuf::expr::CelMapReflectionFriend::LookupMapValue( + reflection, message, field, key, value); +} + +bool ContainsMapKey(const google::protobuf::Reflection& reflection, + const google::protobuf::Message& message, + const google::protobuf::FieldDescriptor& field, + const google::protobuf::MapKey& key) { + return google::protobuf::expr::CelMapReflectionFriend::ContainsMapKey( + reflection, message, field, key); +} + +int MapSize(const google::protobuf::Reflection& reflection, + const google::protobuf::Message& message, + const google::protobuf::FieldDescriptor& field) { + return google::protobuf::expr::CelMapReflectionFriend::MapSize(reflection, message, + field); +} + +google::protobuf::ConstMapIterator ConstMapBegin(const google::protobuf::Reflection& reflection, + const google::protobuf::Message& message, + const google::protobuf::FieldDescriptor& field) { + return google::protobuf::expr::CelMapReflectionFriend::ConstMapBegin(reflection, + message, field); +} + +google::protobuf::ConstMapIterator ConstMapEnd(const google::protobuf::Reflection& reflection, + const google::protobuf::Message& message, + const google::protobuf::FieldDescriptor& field) { + return google::protobuf::expr::CelMapReflectionFriend::ConstMapEnd(reflection, message, + field); +} + +bool InsertOrLookupMapValue(const google::protobuf::Reflection& reflection, + google::protobuf::Message* message, + const google::protobuf::FieldDescriptor& field, + const google::protobuf::MapKey& key, + google::protobuf::MapValueRef* value) { + return google::protobuf::expr::CelMapReflectionFriend::InsertOrLookupMapValue( + reflection, message, field, key, value); +} + +bool DeleteMapValue(const google::protobuf::Reflection* absl_nonnull reflection, + google::protobuf::Message* absl_nonnull message, + const google::protobuf::FieldDescriptor* absl_nonnull field, + const google::protobuf::MapKey& key) { + return google::protobuf::expr::CelMapReflectionFriend::DeleteMapValue( + reflection, message, field, key); +} + +} // namespace cel::extensions::protobuf_internal diff --git a/extensions/protobuf/internal/map_reflection.h b/extensions/protobuf/internal/map_reflection.h new file mode 100644 index 000000000..681d7693d --- /dev/null +++ b/extensions/protobuf/internal/map_reflection.h @@ -0,0 +1,67 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_INTERNAL_MAP_REFLECTION_H_ +#define THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_INTERNAL_MAP_REFLECTION_H_ + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/map_field.h" +#include "google/protobuf/message.h" + +#ifndef GOOGLE_PROTOBUF_HAS_CEL_MAP_REFLECTION_FRIEND +#error "protobuf library is too old, please update to version 3.15.0 or newer" +#endif + +namespace cel::extensions::protobuf_internal { + +bool LookupMapValue(const google::protobuf::Reflection& reflection, + const google::protobuf::Message& message, + const google::protobuf::FieldDescriptor& field, + const google::protobuf::MapKey& key, google::protobuf::MapValueConstRef* value) + ABSL_ATTRIBUTE_NONNULL(); + +bool ContainsMapKey(const google::protobuf::Reflection& reflection, + const google::protobuf::Message& message, + const google::protobuf::FieldDescriptor& field, + const google::protobuf::MapKey& key); + +int MapSize(const google::protobuf::Reflection& reflection, + const google::protobuf::Message& message, + const google::protobuf::FieldDescriptor& field); + +google::protobuf::ConstMapIterator ConstMapBegin(const google::protobuf::Reflection& reflection, + const google::protobuf::Message& message, + const google::protobuf::FieldDescriptor& field); + +google::protobuf::ConstMapIterator ConstMapEnd(const google::protobuf::Reflection& reflection, + const google::protobuf::Message& message, + const google::protobuf::FieldDescriptor& field); + +bool InsertOrLookupMapValue(const google::protobuf::Reflection& reflection, + google::protobuf::Message* message, + const google::protobuf::FieldDescriptor& field, + const google::protobuf::MapKey& key, + google::protobuf::MapValueRef* value) + ABSL_ATTRIBUTE_NONNULL(); + +bool DeleteMapValue(const google::protobuf::Reflection* absl_nonnull reflection, + google::protobuf::Message* absl_nonnull message, + const google::protobuf::FieldDescriptor* absl_nonnull field, + const google::protobuf::MapKey& key); + +} // namespace cel::extensions::protobuf_internal + +#endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_INTERNAL_MAP_REFLECTION_H_ diff --git a/extensions/protobuf/internal/qualify.cc b/extensions/protobuf/internal/qualify.cc new file mode 100644 index 000000000..37ad30011 --- /dev/null +++ b/extensions/protobuf/internal/qualify.cc @@ -0,0 +1,457 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/protobuf/internal/qualify.h" + +#include +#include + +#include "absl/functional/overload.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/types/optional.h" +#include "absl/types/variant.h" +#include "base/attribute.h" +#include "base/builtins.h" +#include "common/kind.h" +#include "common/memory.h" +#include "extensions/protobuf/internal/map_reflection.h" +#include "internal/status_macros.h" +#include "runtime/internal/errors.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/map_field.h" +#include "google/protobuf/message.h" +#include "google/protobuf/reflection.h" + +#undef GetMessage + +namespace cel::extensions::protobuf_internal { + +namespace { + +const google::protobuf::FieldDescriptor* GetNormalizedFieldByNumber( + const google::protobuf::Descriptor* descriptor, const google::protobuf::Reflection* reflection, + int field_number) { + const google::protobuf::FieldDescriptor* field_desc = + descriptor->FindFieldByNumber(field_number); + if (field_desc == nullptr && reflection != nullptr) { + field_desc = reflection->FindKnownExtensionByNumber(field_number); + } + return field_desc; +} + +// JSON container types and Any have special unpacking rules. +// +// Not considered for qualify traversal for simplicity, but +// could be supported in a follow-up if needed. +bool IsUnsupportedQualifyType(const google::protobuf::Descriptor& desc) { + switch (desc.well_known_type()) { + case google::protobuf::Descriptor::WELLKNOWNTYPE_ANY: + case google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT: + case google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE: + case google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE: + return true; + default: + return false; + } +} + +constexpr int kKeyTag = 1; +constexpr int kValueTag = 2; + +bool MatchesMapKeyType(const google::protobuf::FieldDescriptor* key_desc, + const cel::AttributeQualifier& key) { + switch (key_desc->cpp_type()) { + case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: + return key.kind() == cel::Kind::kBool; + case google::protobuf::FieldDescriptor::CPPTYPE_INT32: + // fall through + case google::protobuf::FieldDescriptor::CPPTYPE_INT64: + return key.kind() == cel::Kind::kInt64; + case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: + // fall through + case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: + return key.kind() == cel::Kind::kUint64; + case google::protobuf::FieldDescriptor::CPPTYPE_STRING: + return key.kind() == cel::Kind::kString; + + default: + return false; + } +} + +absl::StatusOr> LookupMapValue( + const google::protobuf::Message* message, const google::protobuf::Reflection* reflection, + const google::protobuf::FieldDescriptor* field_desc, + const google::protobuf::FieldDescriptor* key_desc, + const cel::AttributeQualifier& key) { + if (!MatchesMapKeyType(key_desc, key)) { + return runtime_internal::CreateInvalidMapKeyTypeError( + key_desc->cpp_type_name()); + } + + std::string proto_key_string; + google::protobuf::MapKey proto_key; + switch (key_desc->cpp_type()) { + case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: + proto_key.SetBoolValue(*key.GetBoolKey()); + break; + case google::protobuf::FieldDescriptor::CPPTYPE_INT32: { + int64_t key_value = *key.GetInt64Key(); + if (key_value > std::numeric_limits::max() || + key_value < std::numeric_limits::lowest()) { + return absl::OutOfRangeError("integer overflow"); + } + proto_key.SetInt32Value(key_value); + } break; + case google::protobuf::FieldDescriptor::CPPTYPE_INT64: + proto_key.SetInt64Value(*key.GetInt64Key()); + break; + case google::protobuf::FieldDescriptor::CPPTYPE_STRING: { + proto_key_string = std::string(*key.GetStringKey()); + proto_key.SetStringValue(proto_key_string); + } break; + case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: { + uint64_t key_value = *key.GetUint64Key(); + if (key_value > std::numeric_limits::max()) { + return absl::OutOfRangeError("unsigned integer overflow"); + } + proto_key.SetUInt32Value(key_value); + } break; + case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: { + proto_key.SetUInt64Value(*key.GetUint64Key()); + } break; + default: + return runtime_internal::CreateInvalidMapKeyTypeError( + key_desc->cpp_type_name()); + } + + // Look the value up + google::protobuf::MapValueConstRef value_ref; + bool found = cel::extensions::protobuf_internal::LookupMapValue( + *reflection, *message, *field_desc, proto_key, &value_ref); + if (!found) { + return std::nullopt; + } + return value_ref; +} + +bool FieldIsPresent(const google::protobuf::Message* message, + const google::protobuf::FieldDescriptor* field_desc, + const google::protobuf::Reflection* reflection) { + if (field_desc->is_map()) { + // When the map field appears in a has(msg.map_field) expression, the map + // is considered 'present' when it is non-empty. Since maps are repeated + // fields they don't participate with standard proto presence testing + // since the repeated field is always at least empty. + return reflection->FieldSize(*message, field_desc) != 0; + } + + if (field_desc->is_repeated()) { + // When the list field appears in a has(msg.list_field) expression, the + // list is considered 'present' when it is non-empty. + return reflection->FieldSize(*message, field_desc) != 0; + } + + // Standard proto presence test for non-repeated fields. + return reflection->HasField(*message, field_desc); +} + +} // namespace + +absl::Status ProtoQualifyState::ApplySelectQualifier( + const cel::SelectQualifier& qualifier, MemoryManagerRef memory_manager) { + return absl::visit( + absl::Overload( + [&](const cel::AttributeQualifier& qualifier) -> absl::Status { + if (repeated_field_desc_ == nullptr) { + return absl::UnimplementedError( + "dynamic field access on message not supported"); + } + return ApplyAttributeQualifer(qualifier, memory_manager); + }, + [&](const cel::FieldSpecifier& field_specifier) -> absl::Status { + if (repeated_field_desc_ != nullptr) { + return absl::UnimplementedError( + "strong field access on container not supported"); + } + return ApplyFieldSpecifier(field_specifier, memory_manager); + }), + qualifier); +} + +absl::Status ProtoQualifyState::ApplyLastQualifierHas( + const cel::SelectQualifier& qualifier, MemoryManagerRef memory_manager) { + const cel::FieldSpecifier* specifier = + absl::get_if(&qualifier); + return absl::visit( + absl::Overload( + [&](const cel::AttributeQualifier& qualifier) mutable + -> absl::Status { + if (qualifier.kind() != cel::Kind::kString || + repeated_field_desc_ == nullptr || + !repeated_field_desc_->is_map()) { + SetResultFromError( + runtime_internal::CreateNoMatchingOverloadError("has"), + memory_manager); + return absl::OkStatus(); + } + return MapHas(qualifier, memory_manager); + }, + [&](const cel::FieldSpecifier& field_specifier) mutable + -> absl::Status { + const auto* field_desc = GetNormalizedFieldByNumber( + descriptor_, reflection_, specifier->number); + if (field_desc == nullptr) { + SetResultFromError( + runtime_internal::CreateNoSuchFieldError(specifier->name), + memory_manager); + return absl::OkStatus(); + } + SetResultFromBool( + FieldIsPresent(message_, field_desc, reflection_)); + return absl::OkStatus(); + }), + qualifier); +} + +absl::Status ProtoQualifyState::ApplyLastQualifierGet( + const cel::SelectQualifier& qualifier, MemoryManagerRef memory_manager) { + return absl::visit( + absl::Overload( + [&](const cel::AttributeQualifier& attr_qualifier) mutable + -> absl::Status { + if (repeated_field_desc_ == nullptr) { + return absl::UnimplementedError( + "dynamic field access on message not supported"); + } + if (repeated_field_desc_->is_map()) { + return ApplyLastQualifierGetMap(attr_qualifier, memory_manager); + } + return ApplyLastQualifierGetList(attr_qualifier, memory_manager); + }, + [&](const cel::FieldSpecifier& specifier) mutable -> absl::Status { + if (repeated_field_desc_ != nullptr) { + return absl::UnimplementedError( + "strong field access on container not supported"); + } + return ApplyLastQualifierMessageGet(specifier, memory_manager); + }), + qualifier); +} + +absl::Status ProtoQualifyState::ApplyFieldSpecifier( + const cel::FieldSpecifier& field_specifier, + MemoryManagerRef memory_manager) { + const google::protobuf::FieldDescriptor* field_desc = GetNormalizedFieldByNumber( + descriptor_, reflection_, field_specifier.number); + if (field_desc == nullptr) { + SetResultFromError( + runtime_internal::CreateNoSuchFieldError(field_specifier.name), + memory_manager); + return absl::OkStatus(); + } + + if (field_desc->is_repeated()) { + repeated_field_desc_ = field_desc; + return absl::OkStatus(); + } + + if (field_desc->cpp_type() != google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE || + IsUnsupportedQualifyType(*field_desc->message_type())) { + CEL_RETURN_IF_ERROR(SetResultFromField(message_, field_desc, + ProtoWrapperTypeOptions::kUnsetNull, + memory_manager)); + return absl::OkStatus(); + } + + message_ = &reflection_->GetMessage(*message_, field_desc); + descriptor_ = message_->GetDescriptor(); + reflection_ = message_->GetReflection(); + return absl::OkStatus(); +} + +absl::StatusOr ProtoQualifyState::CheckListIndex( + const cel::AttributeQualifier& qualifier) const { + if (qualifier.kind() != cel::Kind::kInt64) { + return runtime_internal::CreateNoMatchingOverloadError( + cel::builtin::kIndex); + } + + int index = *qualifier.GetInt64Key(); + int size = reflection_->FieldSize(*message_, repeated_field_desc_); + if (index < 0 || index >= size) { + return absl::InvalidArgumentError( + absl::StrCat("index out of bounds: index=", index, " size=", size)); + } + return index; +} + +absl::Status ProtoQualifyState::ApplyAttributeQualifierList( + const cel::AttributeQualifier& qualifier, MemoryManagerRef memory_manager) { + ABSL_DCHECK_NE(repeated_field_desc_, nullptr); + ABSL_DCHECK(!repeated_field_desc_->is_map()); + ABSL_DCHECK_EQ(repeated_field_desc_->cpp_type(), + google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE); + + auto index_or = CheckListIndex(qualifier); + if (!index_or.ok()) { + SetResultFromError(std::move(index_or).status(), memory_manager); + return absl::OkStatus(); + } + + if (IsUnsupportedQualifyType(*repeated_field_desc_->message_type())) { + CEL_RETURN_IF_ERROR(SetResultFromRepeatedField( + message_, repeated_field_desc_, *index_or, memory_manager)); + return absl::OkStatus(); + } + + message_ = &reflection_->GetRepeatedMessage(*message_, repeated_field_desc_, + *index_or); + descriptor_ = message_->GetDescriptor(); + reflection_ = message_->GetReflection(); + repeated_field_desc_ = nullptr; + return absl::OkStatus(); +} + +absl::StatusOr ProtoQualifyState::CheckMapIndex( + const cel::AttributeQualifier& qualifier) const { + const auto* key_desc = + repeated_field_desc_->message_type()->FindFieldByNumber(kKeyTag); + + CEL_ASSIGN_OR_RETURN( + absl::optional value_ref, + LookupMapValue(message_, reflection_, repeated_field_desc_, key_desc, + qualifier)); + + if (!value_ref.has_value()) { + std::string key_string; + absl::StatusOr key_string_or = qualifier.AsString(); + if (key_string_or.ok()) { + key_string = *key_string_or; + } + return runtime_internal::CreateNoSuchKeyError(key_string); + } + return std::move(value_ref).value(); +} + +absl::Status ProtoQualifyState::ApplyAttributeQualifierMap( + const cel::AttributeQualifier& qualifier, MemoryManagerRef memory_manager) { + ABSL_DCHECK_NE(repeated_field_desc_, nullptr); + ABSL_DCHECK(repeated_field_desc_->is_map()); + ABSL_DCHECK_EQ(repeated_field_desc_->cpp_type(), + google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE); + + absl::StatusOr value_ref = CheckMapIndex(qualifier); + if (!value_ref.ok()) { + SetResultFromError(std::move(value_ref).status(), memory_manager); + return absl::OkStatus(); + } + + const auto* value_desc = + repeated_field_desc_->message_type()->FindFieldByNumber(kValueTag); + + if (value_desc->cpp_type() != google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE || + IsUnsupportedQualifyType(*value_desc->message_type())) { + CEL_RETURN_IF_ERROR(SetResultFromMapField(message_, value_desc, *value_ref, + memory_manager)); + return absl::OkStatus(); + } + + message_ = &(value_ref->GetMessageValue()); + descriptor_ = message_->GetDescriptor(); + reflection_ = message_->GetReflection(); + repeated_field_desc_ = nullptr; + return absl::OkStatus(); +} + +absl::Status ProtoQualifyState::ApplyAttributeQualifer( + const cel::AttributeQualifier& qualifier, MemoryManagerRef memory_manager) { + ABSL_DCHECK_NE(repeated_field_desc_, nullptr); + if (repeated_field_desc_->cpp_type() != + google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { + return absl::InternalError("Unexpected qualify intermediate type"); + } + if (repeated_field_desc_->is_map()) { + return ApplyAttributeQualifierMap(qualifier, memory_manager); + } // else simple repeated + return ApplyAttributeQualifierList(qualifier, memory_manager); +} + +absl::Status ProtoQualifyState::MapHas(const cel::AttributeQualifier& key, + MemoryManagerRef memory_manager) { + const auto* key_desc = + repeated_field_desc_->message_type()->FindFieldByNumber(kKeyTag); + + absl::StatusOr> value_ref = + LookupMapValue(message_, reflection_, repeated_field_desc_, key_desc, + key); + + if (!value_ref.ok()) { + SetResultFromError(std::move(value_ref).status(), memory_manager); + return absl::OkStatus(); + } + + SetResultFromBool(value_ref->has_value()); + return absl::OkStatus(); +} + +absl::Status ProtoQualifyState::ApplyLastQualifierMessageGet( + const cel::FieldSpecifier& specifier, MemoryManagerRef memory_manager) { + const auto* field_desc = + GetNormalizedFieldByNumber(descriptor_, reflection_, specifier.number); + if (field_desc == nullptr) { + SetResultFromError(runtime_internal::CreateNoSuchFieldError(specifier.name), + memory_manager); + return absl::OkStatus(); + } + return SetResultFromField(message_, field_desc, + ProtoWrapperTypeOptions::kUnsetNull, + memory_manager); +} + +absl::Status ProtoQualifyState::ApplyLastQualifierGetList( + const cel::AttributeQualifier& qualifier, MemoryManagerRef memory_manager) { + ABSL_DCHECK(!repeated_field_desc_->is_map()); + + absl::StatusOr index = CheckListIndex(qualifier); + if (!index.ok()) { + SetResultFromError(std::move(index).status(), memory_manager); + return absl::OkStatus(); + } + return SetResultFromRepeatedField(message_, repeated_field_desc_, *index, + memory_manager); +} + +absl::Status ProtoQualifyState::ApplyLastQualifierGetMap( + const cel::AttributeQualifier& qualifier, MemoryManagerRef memory_manager) { + ABSL_DCHECK(repeated_field_desc_->is_map()); + + absl::StatusOr value_ref = CheckMapIndex(qualifier); + + if (!value_ref.ok()) { + SetResultFromError(std::move(value_ref).status(), memory_manager); + return absl::OkStatus(); + } + + const auto* value_desc = + repeated_field_desc_->message_type()->FindFieldByNumber(kValueTag); + return SetResultFromMapField(message_, value_desc, *value_ref, + memory_manager); +} + +} // namespace cel::extensions::protobuf_internal diff --git a/extensions/protobuf/internal/qualify.h b/extensions/protobuf/internal/qualify.h new file mode 100644 index 000000000..39b5120f5 --- /dev/null +++ b/extensions/protobuf/internal/qualify.h @@ -0,0 +1,117 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_INTERNAL_QUALIFY_H_ +#define THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_INTERNAL_QUALIFY_H_ + +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "base/attribute.h" +#include "common/memory.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/map_field.h" +#include "google/protobuf/message.h" +#include "google/protobuf/reflection.h" + +namespace cel::extensions::protobuf_internal { + +class ProtoQualifyState { + public: + ProtoQualifyState(const google::protobuf::Message* absl_nonnull message, + const google::protobuf::Descriptor* absl_nonnull descriptor, + const google::protobuf::Reflection* absl_nonnull reflection) + : message_(message), + descriptor_(descriptor), + reflection_(reflection), + repeated_field_desc_(nullptr) {} + + virtual ~ProtoQualifyState() = default; + + ProtoQualifyState(const ProtoQualifyState&) = delete; + ProtoQualifyState& operator=(const ProtoQualifyState&) = delete; + + absl::Status ApplySelectQualifier(const cel::SelectQualifier& qualifier, + MemoryManagerRef memory_manager); + + absl::Status ApplyLastQualifierHas(const cel::SelectQualifier& qualifier, + MemoryManagerRef memory_manager); + + absl::Status ApplyLastQualifierGet(const cel::SelectQualifier& qualifier, + MemoryManagerRef memory_manager); + + private: + virtual void SetResultFromError(absl::Status status, + MemoryManagerRef memory_manager) = 0; + + virtual void SetResultFromBool(bool value) = 0; + + virtual absl::Status SetResultFromField( + const google::protobuf::Message* message, const google::protobuf::FieldDescriptor* field, + ProtoWrapperTypeOptions unboxing_option, + MemoryManagerRef memory_manager) = 0; + + virtual absl::Status SetResultFromRepeatedField( + const google::protobuf::Message* message, const google::protobuf::FieldDescriptor* field, + int index, MemoryManagerRef memory_manager) = 0; + + virtual absl::Status SetResultFromMapField( + const google::protobuf::Message* message, const google::protobuf::FieldDescriptor* field, + const google::protobuf::MapValueConstRef& value, + MemoryManagerRef memory_manager) = 0; + + absl::Status ApplyFieldSpecifier(const cel::FieldSpecifier& field_specifier, + MemoryManagerRef memory_manager); + + absl::StatusOr CheckListIndex( + const cel::AttributeQualifier& qualifier) const; + + absl::Status ApplyAttributeQualifierList( + const cel::AttributeQualifier& qualifier, + MemoryManagerRef memory_manager); + + absl::StatusOr CheckMapIndex( + const cel::AttributeQualifier& qualifier) const; + + absl::Status ApplyAttributeQualifierMap( + const cel::AttributeQualifier& qualifier, + MemoryManagerRef memory_manager); + + absl::Status ApplyAttributeQualifer(const cel::AttributeQualifier& qualifier, + MemoryManagerRef memory_manager); + + absl::Status MapHas(const cel::AttributeQualifier& key, + MemoryManagerRef memory_manager); + + absl::Status ApplyLastQualifierMessageGet( + const cel::FieldSpecifier& specifier, MemoryManagerRef memory_manager); + + absl::Status ApplyLastQualifierGetList( + const cel::AttributeQualifier& qualifier, + MemoryManagerRef memory_manager); + + absl::Status ApplyLastQualifierGetMap( + const cel::AttributeQualifier& qualifier, + MemoryManagerRef memory_manager); + + const google::protobuf::Message* absl_nonnull message_; + const google::protobuf::Descriptor* absl_nonnull descriptor_; + const google::protobuf::Reflection* absl_nonnull reflection_; + const google::protobuf::FieldDescriptor* absl_nullable repeated_field_desc_; +}; + +} // namespace cel::extensions::protobuf_internal + +#endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_INTERNAL_QUALIFY_H_ diff --git a/extensions/protobuf/memory_manager.cc b/extensions/protobuf/memory_manager.cc new file mode 100644 index 000000000..5b3e6e74b --- /dev/null +++ b/extensions/protobuf/memory_manager.cc @@ -0,0 +1,37 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/protobuf/memory_manager.h" + +#include "absl/base/nullability.h" +#include "common/memory.h" +#include "google/protobuf/arena.h" + +namespace cel { + +namespace extensions { + +MemoryManagerRef ProtoMemoryManager(google::protobuf::Arena* arena) { + return arena != nullptr ? MemoryManagerRef::Pooling(arena) + : MemoryManagerRef::ReferenceCounting(); +} + +google::protobuf::Arena* absl_nullable ProtoMemoryManagerArena( + MemoryManager memory_manager) { + return memory_manager.arena(); +} + +} // namespace extensions + +} // namespace cel diff --git a/extensions/protobuf/memory_manager.h b/extensions/protobuf/memory_manager.h new file mode 100644 index 000000000..08c1204db --- /dev/null +++ b/extensions/protobuf/memory_manager.h @@ -0,0 +1,56 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_MEMORY_MANAGER_H_ +#define THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_MEMORY_MANAGER_H_ + +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "common/memory.h" +#include "google/protobuf/arena.h" + +namespace cel::extensions { + +// Returns an appropriate `MemoryManagerRef` wrapping `google::protobuf::Arena`. The +// lifetime of objects creating using the resulting `MemoryManagerRef` is tied +// to that of `google::protobuf::Arena`. +// +// IMPORTANT: Passing `nullptr` here will result in getting +// `MemoryManagerRef::ReferenceCounting()`. +MemoryManager ProtoMemoryManager(google::protobuf::Arena* arena); +inline MemoryManager ProtoMemoryManagerRef(google::protobuf::Arena* arena) { + return ProtoMemoryManager(arena); +} + +// Gets the underlying `google::protobuf::Arena`. If `MemoryManager` was not created using +// either `ProtoMemoryManagerRef` or `ProtoMemoryManager`, this returns +// `nullptr`. +google::protobuf::Arena* absl_nullable ProtoMemoryManagerArena( + MemoryManager memory_manager); +// Allocate and construct `T` using the `ProtoMemoryManager` provided as +// `memory_manager`. `memory_manager` must be `ProtoMemoryManager` or behavior +// is undefined. Unlike `MemoryManager::New`, this method supports arena-enabled +// messages. +template +ABSL_MUST_USE_RESULT T* NewInProtoArena(MemoryManager memory_manager, + Args&&... args) { + return google::protobuf::Arena::Create(ProtoMemoryManagerArena(memory_manager), + std::forward(args)...); +} + +} // namespace cel::extensions + +#endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_MEMORY_MANAGER_H_ diff --git a/extensions/protobuf/memory_manager_test.cc b/extensions/protobuf/memory_manager_test.cc new file mode 100644 index 000000000..ddab4cf32 --- /dev/null +++ b/extensions/protobuf/memory_manager_test.cc @@ -0,0 +1,58 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/protobuf/memory_manager.h" + +#include "common/memory.h" +#include "internal/testing.h" +#include "google/protobuf/arena.h" + +namespace cel::extensions { +namespace { + +using ::testing::Eq; +using ::testing::IsNull; +using ::testing::NotNull; + +TEST(ProtoMemoryManager, MemoryManagement) { + google::protobuf::Arena arena; + auto memory_manager = ProtoMemoryManager(&arena); + EXPECT_EQ(memory_manager.memory_management(), MemoryManagement::kPooling); +} + +TEST(ProtoMemoryManager, Arena) { + google::protobuf::Arena arena; + auto memory_manager = ProtoMemoryManager(&arena); + EXPECT_THAT(ProtoMemoryManagerArena(memory_manager), NotNull()); +} + +TEST(ProtoMemoryManagerRef, MemoryManagement) { + google::protobuf::Arena arena; + auto memory_manager = ProtoMemoryManagerRef(&arena); + EXPECT_EQ(memory_manager.memory_management(), MemoryManagement::kPooling); + memory_manager = ProtoMemoryManagerRef(nullptr); + EXPECT_EQ(memory_manager.memory_management(), + MemoryManagement::kReferenceCounting); +} + +TEST(ProtoMemoryManagerRef, Arena) { + google::protobuf::Arena arena; + auto memory_manager = ProtoMemoryManagerRef(&arena); + EXPECT_THAT(ProtoMemoryManagerArena(memory_manager), Eq(&arena)); + memory_manager = ProtoMemoryManagerRef(nullptr); + EXPECT_THAT(ProtoMemoryManagerArena(memory_manager), IsNull()); +} + +} // namespace +} // namespace cel::extensions diff --git a/extensions/protobuf/runtime_adapter.cc b/extensions/protobuf/runtime_adapter.cc new file mode 100644 index 000000000..ca9f9354a --- /dev/null +++ b/extensions/protobuf/runtime_adapter.cc @@ -0,0 +1,54 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/protobuf/runtime_adapter.h" + +#include +#include + +#include "cel/expr/checked.pb.h" +#include "cel/expr/syntax.pb.h" +#include "absl/status/statusor.h" +#include "extensions/protobuf/ast_converters.h" +#include "internal/status_macros.h" +#include "runtime/runtime.h" + +namespace cel::extensions { + +absl::StatusOr> +ProtobufRuntimeAdapter::CreateProgram( + const Runtime& runtime, const cel::expr::CheckedExpr& expr, + const Runtime::CreateProgramOptions options) { + CEL_ASSIGN_OR_RETURN(auto ast, CreateAstFromCheckedExpr(expr)); + return runtime.CreateTraceableProgram(std::move(ast), options); +} + +absl::StatusOr> +ProtobufRuntimeAdapter::CreateProgram( + const Runtime& runtime, const cel::expr::ParsedExpr& expr, + const Runtime::CreateProgramOptions options) { + CEL_ASSIGN_OR_RETURN(auto ast, CreateAstFromParsedExpr(expr)); + return runtime.CreateTraceableProgram(std::move(ast), options); +} + +absl::StatusOr> +ProtobufRuntimeAdapter::CreateProgram( + const Runtime& runtime, const cel::expr::Expr& expr, + const cel::expr::SourceInfo* source_info, + const Runtime::CreateProgramOptions options) { + CEL_ASSIGN_OR_RETURN(auto ast, CreateAstFromParsedExpr(expr, source_info)); + return runtime.CreateTraceableProgram(std::move(ast), options); +} + +} // namespace cel::extensions diff --git a/extensions/protobuf/runtime_adapter.h b/extensions/protobuf/runtime_adapter.h new file mode 100644 index 000000000..49af58a07 --- /dev/null +++ b/extensions/protobuf/runtime_adapter.h @@ -0,0 +1,51 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_RUNTIME_ADAPTER_H_ +#define THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_RUNTIME_ADAPTER_H_ + +#include + +#include "cel/expr/checked.pb.h" +#include "cel/expr/syntax.pb.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "runtime/runtime.h" +#include "runtime/runtime_builder.h" +#include "google/protobuf/descriptor.h" + +namespace cel::extensions { + +// Helper class for cel::Runtime that converts the pb serialization format for +// expressions to the internal AST format. +class ProtobufRuntimeAdapter { + public: + // Only to be used for static member functions. + ProtobufRuntimeAdapter() = delete; + + static absl::StatusOr> CreateProgram( + const Runtime& runtime, const cel::expr::CheckedExpr& expr, + const Runtime::CreateProgramOptions options = {}); + static absl::StatusOr> CreateProgram( + const Runtime& runtime, const cel::expr::ParsedExpr& expr, + const Runtime::CreateProgramOptions options = {}); + static absl::StatusOr> CreateProgram( + const Runtime& runtime, const cel::expr::Expr& expr, + const cel::expr::SourceInfo* source_info = nullptr, + const Runtime::CreateProgramOptions options = {}); +}; + +} // namespace cel::extensions + +#endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_RUNTIME_ADAPTER_H_ diff --git a/extensions/protobuf/value.h b/extensions/protobuf/value.h new file mode 100644 index 000000000..b7a654064 --- /dev/null +++ b/extensions/protobuf/value.h @@ -0,0 +1,98 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Utilities for wrapping and unwrapping cel::Values representing protobuf +// message types. + +#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_VALUE_H_ +#define THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_VALUE_H_ + +#include +#include + +#include "google/protobuf/duration.pb.h" +#include "google/protobuf/struct.pb.h" +#include "google/protobuf/timestamp.pb.h" +#include "google/protobuf/wrappers.pb.h" +#include "absl/base/nullability.h" +#include "absl/meta/type_traits.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/str_cat.h" +#include "common/memory.h" +#include "common/type.h" +#include "common/value.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel::extensions { + +// Adapt a protobuf message to a cel::Value. +// +// Handles unwrapping message types with special meanings in CEL (WKTs). +// +// T value must be a protobuf message class. +template +std::enable_if_t>, + absl::StatusOr> +ProtoMessageToValue(T&& value, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + return Value::FromMessage(std::forward(value), descriptor_pool, + message_factory, arena); +} + +inline absl::Status ProtoMessageFromValue(const Value& value, + google::protobuf::Message& dest_message) { + const auto* dest_descriptor = dest_message.GetDescriptor(); + const google::protobuf::Message* src_message = nullptr; + if (auto legacy_struct_value = + cel::common_internal::AsLegacyStructValue(value); + legacy_struct_value) { + src_message = legacy_struct_value->message_ptr(); + } + if (auto parsed_message_value = value.AsParsedMessage(); + parsed_message_value) { + src_message = cel::to_address(*parsed_message_value); + } + if (src_message != nullptr) { + const auto* src_descriptor = src_message->GetDescriptor(); + if (dest_descriptor == src_descriptor) { + dest_message.CopyFrom(*src_message); + return absl::OkStatus(); + } + if (dest_descriptor->full_name() == src_descriptor->full_name()) { + absl::Cord serialized; + if (!src_message->SerializePartialToCord(&serialized)) { + return absl::UnknownError(absl::StrCat("failed to serialize message: ", + src_descriptor->full_name())); + } + if (!dest_message.ParsePartialFromCord(serialized)) { + return absl::UnknownError(absl::StrCat("failed to parse message: ", + dest_descriptor->full_name())); + } + return absl::OkStatus(); + } + } + return TypeConversionError(value.GetRuntimeType(), + MessageType(dest_descriptor)) + .NativeValue(); +} + +} // namespace cel::extensions + +#endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_VALUE_H_ diff --git a/extensions/protobuf/value_end_to_end_test.cc b/extensions/protobuf/value_end_to_end_test.cc new file mode 100644 index 000000000..69a59bc19 --- /dev/null +++ b/extensions/protobuf/value_end_to_end_test.cc @@ -0,0 +1,1087 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Functional tests for protobuf backed CEL structs in the default runtime. + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "common/value.h" +#include "common/value_testing.h" +#include "extensions/protobuf/runtime_adapter.h" +#include "internal/testing.h" +#include "parser/parser.h" +#include "runtime/activation.h" +#include "runtime/runtime.h" +#include "runtime/runtime_options.h" +#include "runtime/standard_runtime_builder_factory.h" +#include "cel/expr/conformance/proto3/test_all_types.pb.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" +#include "google/protobuf/text_format.h" + +namespace cel::extensions { +namespace { + +using ::absl_testing::StatusIs; +using ::cel::expr::conformance::proto3::TestAllTypes; +using ::cel::test::BoolValueIs; +using ::cel::test::BytesValueIs; +using ::cel::test::DoubleValueIs; +using ::cel::test::DurationValueIs; +using ::cel::test::ErrorValueIs; +using ::cel::test::IntValueIs; +using ::cel::test::IsNullValue; +using ::cel::test::ListValueIs; +using ::cel::test::MapValueIs; +using ::cel::test::StringValueIs; +using ::cel::test::StructValueIs; +using ::cel::test::TimestampValueIs; +using ::cel::test::UintValueIs; +using ::cel::test::ValueMatcher; +using ::cel::expr::ParsedExpr; +using ::google::api::expr::parser::Parse; +using ::testing::_; +using ::testing::AnyOf; +using ::testing::HasSubstr; +using ::testing::TestWithParam; + +struct TestCase { + std::string name; + std::string expr; + std::string msg_textproto; + ValueMatcher matcher; + + template + friend void AbslStringify(S& sink, const TestCase& tc) { + sink.Append(tc.name); + } +}; + +class ProtobufValueEndToEndTest : public TestWithParam { + public: + ProtobufValueEndToEndTest() = default; + + protected: + const TestCase& test_case() const { return GetParam(); } + + google::protobuf::Arena arena_; +}; + +TEST_P(ProtobufValueEndToEndTest, Runner) { + TestAllTypes message; + + ASSERT_TRUE( + google::protobuf::TextFormat::ParseFromString(test_case().msg_textproto, &message)); + + Activation activation; + activation.InsertOrAssignValue( + "msg", + Value::FromMessage(message, google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), &arena_)); + + RuntimeOptions opts; + opts.enable_empty_wrapper_null_unboxing = true; + ASSERT_OK_AND_ASSIGN(auto builder, + CreateStandardRuntimeBuilder( + google::protobuf::DescriptorPool::generated_pool(), opts)); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, + std::move(builder).Build()); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, Parse(test_case().expr)); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr program, + ProtobufRuntimeAdapter::CreateProgram(*runtime, expr)); + + ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena_, activation)); + + EXPECT_THAT(result, test_case().matcher); +} + +INSTANTIATE_TEST_SUITE_P( + Singular, ProtobufValueEndToEndTest, + testing::ValuesIn(std::vector{ + {"single_int64", "msg.single_int64", + R"pb( + single_int64: 42 + )pb", + IntValueIs(42)}, + {"single_int64_has", "has(msg.single_int64)", + R"pb( + single_int64: 42 + )pb", + BoolValueIs(true)}, + {"single_int64_has_false", "has(msg.single_int64)", "", + BoolValueIs(false)}, + {"single_int32", "msg.single_int32", + R"pb( + single_int32: 42 + )pb", + IntValueIs(42)}, + {"single_uint64", "msg.single_uint64", + R"pb( + single_uint64: 42 + )pb", + UintValueIs(42)}, + {"single_uint32", "msg.single_uint32", + R"pb( + single_uint32: 42 + )pb", + UintValueIs(42)}, + {"single_sint64", "msg.single_sint64", + R"pb( + single_sint64: 42 + )pb", + IntValueIs(42)}, + {"single_sint32", "msg.single_sint32", + R"pb( + single_sint32: 42 + )pb", + IntValueIs(42)}, + {"single_fixed64", "msg.single_fixed64", + R"pb( + single_fixed64: 42 + )pb", + UintValueIs(42)}, + {"single_fixed32", "msg.single_fixed32", + R"pb( + single_fixed32: 42 + )pb", + UintValueIs(42)}, + {"single_sfixed64", "msg.single_sfixed64", + R"pb( + single_sfixed64: 42 + )pb", + IntValueIs(42)}, + {"single_sfixed32", "msg.single_sfixed32", + R"pb( + single_sfixed32: 42 + )pb", + IntValueIs(42)}, + {"single_float", "msg.single_float", + R"pb( + single_float: 4.25 + )pb", + DoubleValueIs(4.25)}, + {"single_double", "msg.single_double", + R"pb( + single_double: 4.25 + )pb", + DoubleValueIs(4.25)}, + {"single_bool", "msg.single_bool", + R"pb( + single_bool: true + )pb", + BoolValueIs(true)}, + {"single_string", "msg.single_string", + R"pb( + single_string: "Hello 😀" + )pb", + StringValueIs("Hello 😀")}, + {"single_bytes", "msg.single_bytes", + R"pb( + single_bytes: "Hello" + )pb", + BytesValueIs("Hello")}, + {"wkt_duration", "msg.single_duration", + R"pb( + single_duration { seconds: 10 } + )pb", + DurationValueIs(absl::Seconds(10))}, + {"wkt_duration_default", "msg.single_duration", "", + DurationValueIs(absl::Seconds(0))}, + {"wkt_timestamp", "msg.single_timestamp", + R"pb( + single_timestamp { seconds: 10 } + )pb", + TimestampValueIs(absl::FromUnixSeconds(10))}, + {"wkt_timestamp_default", "msg.single_timestamp", "", + TimestampValueIs(absl::UnixEpoch())}, + {"wkt_int64", "msg.single_int64_wrapper", + R"pb( + single_int64_wrapper { value: -20 } + )pb", + IntValueIs(-20)}, + {"wkt_int64_default", "msg.single_int64_wrapper", "", IsNullValue()}, + {"wkt_int32", "msg.single_int32_wrapper", + R"pb( + single_int32_wrapper { value: -10 } + )pb", + IntValueIs(-10)}, + {"wkt_int32_default", "msg.single_int32_wrapper", "", IsNullValue()}, + {"wkt_uint64", "msg.single_uint64_wrapper", + R"pb( + single_uint64_wrapper { value: 10 } + )pb", + UintValueIs(10)}, + {"wkt_uint64_default", "msg.single_uint64_wrapper", "", IsNullValue()}, + {"wkt_uint32", "msg.single_uint32_wrapper", + R"pb( + single_uint32_wrapper { value: 11 } + )pb", + UintValueIs(11)}, + {"wkt_uint32_default", "msg.single_uint32_wrapper", "", IsNullValue()}, + {"wkt_float", "msg.single_float_wrapper", + R"pb( + single_float_wrapper { value: 10.25 } + )pb", + DoubleValueIs(10.25)}, + {"wkt_float_default", "msg.single_float_wrapper", "", IsNullValue()}, + {"wkt_double", "msg.single_double_wrapper", + R"pb( + single_double_wrapper { value: 10.25 } + )pb", + DoubleValueIs(10.25)}, + {"wkt_double_default", "msg.single_double_wrapper", "", IsNullValue()}, + {"wkt_bool", "msg.single_bool_wrapper", + R"pb( + single_bool_wrapper { value: false } + )pb", + BoolValueIs(false)}, + {"wkt_bool_default", "msg.single_bool_wrapper", "", IsNullValue()}, + {"wkt_string", "msg.single_string_wrapper", + R"pb( + single_string_wrapper { value: "abcd" } + )pb", + StringValueIs("abcd")}, + {"wkt_string_default", "msg.single_string_wrapper", "", IsNullValue()}, + {"wkt_bytes", "msg.single_bytes_wrapper", + R"pb( + single_bytes_wrapper { value: "abcd" } + )pb", + BytesValueIs("abcd")}, + {"wkt_bytes_default", "msg.single_bytes_wrapper", "", IsNullValue()}, + {"wkt_null", "msg.null_value", + R"pb( + null_value: NULL_VALUE + )pb", + IsNullValue()}, + {"message_field", "msg.standalone_message", + R"pb( + standalone_message { bb: 2 } + )pb", + StructValueIs(_)}, + {"message_field_has", "has(msg.standalone_message)", + R"pb( + standalone_message { bb: 2 } + )pb", + BoolValueIs(true)}, + {"message_field_has_false", "has(msg.standalone_message)", "", + BoolValueIs(false)}, + {"single_enum", "msg.standalone_enum", + R"pb( + standalone_enum: BAR + )pb", + // BAR + IntValueIs(1)}})); + +INSTANTIATE_TEST_SUITE_P( + Repeated, ProtobufValueEndToEndTest, + testing::ValuesIn(std::vector{ + {"repeated_int64", "msg.repeated_int64[0]", + R"pb( + repeated_int64: 42 + )pb", + IntValueIs(42)}, + {"repeated_int64_has", "has(msg.repeated_int64)", + R"pb( + repeated_int64: 42 + )pb", + BoolValueIs(true)}, + {"repeated_int64_has_false", "has(msg.repeated_int64)", "", + BoolValueIs(false)}, + {"repeated_int32", "msg.repeated_int32[0]", + R"pb( + repeated_int32: 42 + )pb", + IntValueIs(42)}, + {"repeated_uint64", "msg.repeated_uint64[0]", + R"pb( + repeated_uint64: 42 + )pb", + UintValueIs(42)}, + {"repeated_uint32", "msg.repeated_uint32[0]", + R"pb( + repeated_uint32: 42 + )pb", + UintValueIs(42)}, + {"repeated_sint64", "msg.repeated_sint64[0]", + R"pb( + repeated_sint64: 42 + )pb", + IntValueIs(42)}, + {"repeated_sint32", "msg.repeated_sint32[0]", + R"pb( + repeated_sint32: 42 + )pb", + IntValueIs(42)}, + {"repeated_fixed64", "msg.repeated_fixed64[0]", + R"pb( + repeated_fixed64: 42 + )pb", + UintValueIs(42)}, + {"repeated_fixed32", "msg.repeated_fixed32[0]", + R"pb( + repeated_fixed32: 42 + )pb", + UintValueIs(42)}, + {"repeated_sfixed64", "msg.repeated_sfixed64[0]", + R"pb( + repeated_sfixed64: 42 + )pb", + IntValueIs(42)}, + {"repeated_sfixed32", "msg.repeated_sfixed32[0]", + R"pb( + repeated_sfixed32: 42 + )pb", + IntValueIs(42)}, + {"repeated_float", "msg.repeated_float[0]", + R"pb( + repeated_float: 4.25 + )pb", + DoubleValueIs(4.25)}, + {"repeated_double", "msg.repeated_double[0]", + R"pb( + repeated_double: 4.25 + )pb", + DoubleValueIs(4.25)}, + {"repeated_bool", "msg.repeated_bool[0]", + R"pb( + repeated_bool: true + )pb", + BoolValueIs(true)}, + {"repeated_string", "msg.repeated_string[0]", + R"pb( + repeated_string: "Hello 😀" + )pb", + StringValueIs("Hello 😀")}, + {"repeated_bytes", "msg.repeated_bytes[0]", + R"pb( + repeated_bytes: "Hello" + )pb", + BytesValueIs("Hello")}, + {"wkt_duration", "msg.repeated_duration[0]", + R"pb( + repeated_duration { seconds: 10 } + )pb", + DurationValueIs(absl::Seconds(10))}, + {"wkt_timestamp", "msg.repeated_timestamp[0]", + R"pb( + repeated_timestamp { seconds: 10 } + )pb", + TimestampValueIs(absl::FromUnixSeconds(10))}, + {"wkt_int64", "msg.repeated_int64_wrapper[0]", + R"pb( + repeated_int64_wrapper { value: -20 } + )pb", + IntValueIs(-20)}, + {"wkt_int32", "msg.repeated_int32_wrapper[0]", + R"pb( + repeated_int32_wrapper { value: -10 } + )pb", + IntValueIs(-10)}, + {"wkt_uint64", "msg.repeated_uint64_wrapper[0]", + R"pb( + repeated_uint64_wrapper { value: 10 } + )pb", + UintValueIs(10)}, + {"wkt_uint32", "msg.repeated_uint32_wrapper[0]", + R"pb( + repeated_uint32_wrapper { value: 11 } + )pb", + UintValueIs(11)}, + {"wkt_float", "msg.repeated_float_wrapper[0]", + R"pb( + repeated_float_wrapper { value: 10.25 } + )pb", + DoubleValueIs(10.25)}, + {"wkt_double", "msg.repeated_double_wrapper[0]", + R"pb( + repeated_double_wrapper { value: 10.25 } + )pb", + DoubleValueIs(10.25)}, + {"wkt_bool", "msg.repeated_bool_wrapper[0]", + R"pb( + + repeated_bool_wrapper { value: false } + )pb", + BoolValueIs(false)}, + {"wkt_string", "msg.repeated_string_wrapper[0]", + R"pb( + repeated_string_wrapper { value: "abcd" } + )pb", + StringValueIs("abcd")}, + {"wkt_bytes", "msg.repeated_bytes_wrapper[0]", + R"pb( + repeated_bytes_wrapper { value: "abcd" } + )pb", + BytesValueIs("abcd")}, + {"wkt_null", "msg.repeated_null_value[0]", + R"pb( + repeated_null_value: NULL_VALUE + )pb", + IsNullValue()}, + {"message_field", "msg.repeated_nested_message[0]", + R"pb( + repeated_nested_message { bb: 42 } + )pb", + StructValueIs(_)}, + {"repeated_enum", "msg.repeated_nested_enum[0]", + R"pb( + repeated_nested_enum: BAR + )pb", + // BAR + IntValueIs(1)}, + // Implements CEL list interface + {"repeated_size", "msg.repeated_int64.size()", + R"pb( + repeated_int64: 42 repeated_int64: 43 + )pb", + IntValueIs(2)}, + {"in_repeated", "42 in msg.repeated_int64", + R"pb( + repeated_int64: 42 repeated_int64: 43 + )pb", + BoolValueIs(true)}, + {"in_repeated_false", "44 in msg.repeated_int64", + R"pb( + repeated_int64: 42 repeated_int64: 43 + )pb", + BoolValueIs(false)}, + {"repeated_compre_exists", "msg.repeated_int64.exists(x, x > 42)", + R"pb( + repeated_int64: 42 repeated_int64: 43 + )pb", + BoolValueIs(true)}, + {"repeated_compre_map", "msg.repeated_int64.map(x, x * 2)[0]", + R"pb( + repeated_int64: 42 repeated_int64: 43 + )pb", + IntValueIs(84)}, + })); + +INSTANTIATE_TEST_SUITE_P( + Maps, ProtobufValueEndToEndTest, + testing::ValuesIn(std::vector{ + {"map_bool_int64", "msg.map_bool_int64[false]", + R"pb( + map_bool_int64 { key: false value: 42 } + )pb", + IntValueIs(42)}, + {"map_bool_int64_has", "has(msg.map_bool_int64)", + R"pb( + map_bool_int64 { key: false value: 42 } + )pb", + BoolValueIs(true)}, + {"map_bool_int64_has_false", "has(msg.map_bool_int64)", "", + BoolValueIs(false)}, + {"map_bool_int32", "msg.map_bool_int32[false]", + R"pb( + map_bool_int32 { key: false value: 42 } + )pb", + IntValueIs(42)}, + {"map_bool_uint64", "msg.map_bool_uint64[false]", + R"pb( + map_bool_uint64 { key: false value: 42 } + )pb", + UintValueIs(42)}, + {"map_bool_uint32", "msg.map_bool_uint32[false]", + R"pb( + map_bool_uint32 { key: false, value: 42 } + )pb", + UintValueIs(42)}, + {"map_bool_float", "msg.map_bool_float[false]", + R"pb( + map_bool_float { key: false value: 4.25 } + )pb", + DoubleValueIs(4.25)}, + {"map_bool_double", "msg.map_bool_double[false]", + R"pb( + map_bool_double { key: false value: 4.25 } + )pb", + DoubleValueIs(4.25)}, + {"map_bool_bool", "msg.map_bool_bool[false]", + R"pb( + map_bool_bool { key: false value: true } + )pb", + BoolValueIs(true)}, + {"map_bool_string", "msg.map_bool_string[false]", + R"pb( + map_bool_string { key: false value: "Hello 😀" } + )pb", + StringValueIs("Hello 😀")}, + {"map_bool_bytes", "msg.map_bool_bytes[false]", + R"pb( + map_bool_bytes { key: false value: "Hello" } + )pb", + BytesValueIs("Hello")}, + {"wkt_duration", "msg.map_bool_duration[false]", + R"pb( + map_bool_duration { + key: false + value { seconds: 10 } + } + )pb", + DurationValueIs(absl::Seconds(10))}, + {"wkt_timestamp", "msg.map_bool_timestamp[false]", + R"pb( + map_bool_timestamp { + key: false + value { seconds: 10 } + } + )pb", + TimestampValueIs(absl::FromUnixSeconds(10))}, + {"wkt_int64", "msg.map_bool_int64_wrapper[false]", + R"pb( + map_bool_int64_wrapper { + key: false + value { value: -20 } + } + )pb", + IntValueIs(-20)}, + {"wkt_int32", "msg.map_bool_int32_wrapper[false]", + R"pb( + map_bool_int32_wrapper { + key: false + value { value: -10 } + } + )pb", + IntValueIs(-10)}, + {"wkt_uint64", "msg.map_bool_uint64_wrapper[false]", + R"pb( + map_bool_uint64_wrapper { + key: false + value { value: 10 } + } + )pb", + UintValueIs(10)}, + {"wkt_uint32", "msg.map_bool_uint32_wrapper[false]", + R"pb( + map_bool_uint32_wrapper { + key: false + value { value: 11 } + } + )pb", + UintValueIs(11)}, + {"wkt_float", "msg.map_bool_float_wrapper[false]", + R"pb( + map_bool_float_wrapper { + key: false + value { value: 10.25 } + } + )pb", + DoubleValueIs(10.25)}, + {"wkt_double", "msg.map_bool_double_wrapper[false]", + R"pb( + map_bool_double_wrapper { + key: false + value { value: 10.25 } + } + )pb", + DoubleValueIs(10.25)}, + {"wkt_bool", "msg.map_bool_bool_wrapper[false]", + R"pb( + map_bool_bool_wrapper { + key: false + value { value: false } + } + )pb", + BoolValueIs(false)}, + {"wkt_string", "msg.map_bool_string_wrapper[false]", + R"pb( + map_bool_string_wrapper { + key: false + value { value: "abcd" } + } + )pb", + StringValueIs("abcd")}, + {"wkt_bytes", "msg.map_bool_bytes_wrapper[false]", + R"pb( + map_bool_bytes_wrapper { + key: false + value { value: "abcd" } + } + )pb", + BytesValueIs("abcd")}, + {"wkt_null", "msg.map_bool_null_value[false]", + R"pb( + map_bool_null_value { key: false value: NULL_VALUE } + )pb", + IsNullValue()}, + {"message_field", "msg.map_bool_message[false]", + R"pb( + map_bool_message { + key: false + value { bb: 42 } + } + )pb", + StructValueIs(_)}, + {"map_bool_enum", "msg.map_bool_enum[false]", + R"pb( + map_bool_enum { key: false value: BAR } + )pb", + // BAR + IntValueIs(1)}, + // Simplified for remaining key types. + {"map_int32_int64", "msg.map_int32_int64[42]", + R"pb( + map_int32_int64 { key: 42 value: -42 } + )pb", + IntValueIs(-42)}, + {"map_int64_int64", "msg.map_int64_int64[42]", + R"pb( + map_int64_int64 { key: 42 value: -42 } + )pb", + IntValueIs(-42)}, + {"map_uint32_int64", "msg.map_uint32_int64[42u]", + R"pb( + map_uint32_int64 { key: 42 value: -42 } + )pb", + IntValueIs(-42)}, + {"map_uint64_int64", "msg.map_uint64_int64[42u]", + R"pb( + map_uint64_int64 { key: 42 value: -42 } + )pb", + IntValueIs(-42)}, + {"map_string_int64", "msg.map_string_int64['key1']", + R"pb( + map_string_int64 { key: "key1" value: -42 } + )pb", + IntValueIs(-42)}, + // Implements CEL map + {"in_map_int64_true", "42 in msg.map_int64_int64", + R"pb( + map_int64_int64 { key: 42 value: -42 } + map_int64_int64 { key: 43 value: -43 } + )pb", + BoolValueIs(true)}, + {"in_map_int64_false", "44 in msg.map_int64_int64", + R"pb( + map_int64_int64 { key: 42 value: -42 } + map_int64_int64 { key: 43 value: -43 } + )pb", + BoolValueIs(false)}, + {"int_map_int64_compre_exists", + "msg.map_int64_int64.exists(key, key > 42)", + R"pb( + map_int64_int64 { key: 42 value: -42 } + map_int64_int64 { key: 43 value: -43 } + )pb", + BoolValueIs(true)}, + {"int_map_int64_compre_map", + "msg.map_int64_int64.map(key, key + 20)[0]", + R"pb( + map_int64_int64 { key: 42 value: -42 } + map_int64_int64 { key: 43 value: -43 } + )pb", + + IntValueIs(AnyOf(62, 63))}, + {"map_string_key_not_found", "msg.map_string_int64['key2']", + R"pb( + map_string_int64 { key: "key1" value: -42 } + )pb", + ErrorValueIs(StatusIs(absl::StatusCode::kNotFound, + HasSubstr("Key not found in map")))}, + {"map_string_select_key", "msg.map_string_int64.key1", + R"pb( + map_string_int64 { key: "key1" value: -42 } + )pb", + IntValueIs(-42)}, + {"map_string_has_key", "has(msg.map_string_int64.key1)", + R"pb( + map_string_int64 { key: "key1" value: -42 } + )pb", + BoolValueIs(true)}, + {"map_string_has_key_false", "has(msg.map_string_int64.key2)", + R"pb( + map_string_int64 { key: "key1" value: -42 } + )pb", + BoolValueIs(false)}, + {"map_int32_out_of_range", "msg.map_int32_int64[0x1FFFFFFFF]", + R"pb( + map_int32_int64 { key: 10 value: -42 } + )pb", + ErrorValueIs(StatusIs(absl::StatusCode::kNotFound, + HasSubstr("Key not found in map")))}, + {"map_uint32_out_of_range", "msg.map_uint32_int64[0x1FFFFFFFFu]", + R"pb( + map_uint32_int64 { key: 10 value: -42 } + )pb", + ErrorValueIs(StatusIs(absl::StatusCode::kNotFound, + HasSubstr("Key not found in map")))}})); + +MATCHER_P(CelSizeIs, size, "") { + auto s = arg.Size(); + return s.ok() && *s == size; +} + +INSTANTIATE_TEST_SUITE_P( + JsonWrappers, ProtobufValueEndToEndTest, + testing::ValuesIn(std::vector{ + {"single_struct", "msg.single_struct", + R"pb( + single_struct { + fields { + key: "field1" + value { null_value: NULL_VALUE } + } + } + )pb", + MapValueIs(CelSizeIs(1))}, + {"single_struct_null_value_field", "msg.single_struct['field1']", + R"pb( + single_struct { + fields { + key: "field1" + value { null_value: NULL_VALUE } + } + } + )pb", + IsNullValue()}, + {"single_struct_number_value_field", "msg.single_struct['field1']", + R"pb( + single_struct { + fields { + key: "field1" + value { number_value: 10.25 } + } + } + )pb", + DoubleValueIs(10.25)}, + {"single_struct_bool_value_field", "msg.single_struct['field1']", + R"pb( + single_struct { + fields { + key: "field1" + value { bool_value: true } + } + } + )pb", + BoolValueIs(true)}, + {"single_struct_string_value_field", "msg.single_struct['field1']", + R"pb( + single_struct { + fields { + key: "field1" + value { string_value: "abcd" } + } + } + )pb", + StringValueIs("abcd")}, + {"single_struct_struct_value_field", "msg.single_struct['field1']", + R"pb( + single_struct { + fields { + key: "field1" + value { + struct_value { + fields { + key: "field2", + value: { null_value: NULL_VALUE } + } + } + } + } + } + )pb", + MapValueIs(CelSizeIs(1))}, + {"single_struct_list_value_field", "msg.single_struct['field1']", + R"pb( + single_struct { + fields { + key: "field1" + value { list_value { values { null_value: NULL_VALUE } } } + } + } + )pb", + ListValueIs(CelSizeIs(1))}, + {"single_struct_select_field", "msg.single_struct.field1", + R"pb( + single_struct { + fields { + key: "field1" + value { bool_value: true } + } + } + )pb", + BoolValueIs(true)}, + {"single_struct_has_field", "has(msg.single_struct.field1)", + R"pb( + single_struct { + fields { + key: "field1" + value { bool_value: true } + } + } + )pb", + BoolValueIs(true)}, + {"single_struct_has_field_false", "has(msg.single_struct.field2)", + R"pb( + single_struct { + fields { + key: "field1" + value { bool_value: true } + } + } + )pb", + BoolValueIs(false)}, + {"single_struct_map_size", "msg.single_struct.size()", + R"pb( + single_struct { + fields { + key: "field1" + value { bool_value: true } + } + fields { + key: "field2" + value { bool_value: false } + } + } + )pb", + IntValueIs(2)}, + {"single_struct_map_in", "'field2' in msg.single_struct", + R"pb( + single_struct { + fields { + key: "field1" + value { bool_value: true } + } + fields { + key: "field2" + value { bool_value: false } + } + } + )pb", + BoolValueIs(true)}, + {"single_struct_map_compre_exists", + "msg.single_struct.exists(key, key == 'field2')", + R"pb( + single_struct { + fields { + key: "field1" + value { bool_value: true } + } + fields { + key: "field2" + value { bool_value: false } + } + } + )pb", + BoolValueIs(true)}, + {"single_struct_map_compre_map", + "'__field1' in msg.single_struct.map(key, '__' + key)", + R"pb( + single_struct { + fields { + key: "field1" + value { bool_value: true } + } + fields { + key: "field2" + value { bool_value: false } + } + } + )pb", + BoolValueIs(true)}, + {"single_list_value", "msg.list_value", + R"pb( + list_value { values { null_value: NULL_VALUE } } + )pb", + ListValueIs(CelSizeIs(1))}, + {"single_list_value_index_null", "msg.list_value[0]", + R"pb( + list_value { values { null_value: NULL_VALUE } } + )pb", + IsNullValue()}, + {"single_list_value_index_number", "msg.list_value[0]", + R"pb( + list_value { values { number_value: 10.25 } } + )pb", + DoubleValueIs(10.25)}, + {"single_list_value_index_string", "msg.list_value[0]", + R"pb( + list_value { values { string_value: "abc" } } + )pb", + StringValueIs("abc")}, + {"single_list_value_index_bool", "msg.list_value[0]", + R"pb( + list_value { values { bool_value: false } } + )pb", + BoolValueIs(false)}, + {"single_list_value_list_size", "msg.list_value.size()", + R"pb( + list_value { + values { bool_value: false } + values { bool_value: false } + } + )pb", + IntValueIs(2)}, + {"single_list_value_list_in", "10.25 in msg.list_value", + R"pb( + list_value { + values { number_value: 10.0 } + values { number_value: 10.25 } + } + )pb", + BoolValueIs(true)}, + {"single_list_value_list_compre_exists", + "msg.list_value.exists(x, x == 10.25)", + R"pb( + list_value { + values { number_value: 10.0 } + values { number_value: 10.25 } + } + )pb", + BoolValueIs(true)}, + {"single_list_value_list_compre_map", + "msg.list_value.map(x, x + 0.5)[1]", + R"pb( + list_value { + values { number_value: 10.0 } + values { number_value: 10.25 } + } + )pb", + DoubleValueIs(10.75)}, + {"single_list_value_index_struct", "msg.list_value[0]", + R"pb( + list_value { + values { + struct_value { + fields { + key: "field1" + value { null_value: NULL_VALUE } + } + } + } + } + )pb", + MapValueIs(CelSizeIs(1))}, + {"single_list_value_index_list", "msg.list_value[0]", + R"pb( + list_value { + values { list_value { values { null_value: NULL_VALUE } } } + } + )pb", + ListValueIs(CelSizeIs(1))}, + {"single_json_value_null", "msg.single_value", + R"pb( + single_value { null_value: NULL_VALUE } + )pb", + IsNullValue()}, + {"single_json_value_number", "msg.single_value", + R"pb( + single_value { number_value: 13.25 } + )pb", + DoubleValueIs(13.25)}, + {"single_json_value_string", "msg.single_value", + R"pb( + single_value { string_value: "abcd" } + )pb", + StringValueIs("abcd")}, + {"single_json_value_bool", "msg.single_value", + R"pb( + single_value { bool_value: false } + )pb", + BoolValueIs(false)}, + {"single_json_value_struct", "msg.single_value", + R"pb( + single_value { struct_value {} } + )pb", + MapValueIs(CelSizeIs(0))}, + {"single_json_value_list", "msg.single_value", + R"pb( + single_value { list_value {} } + )pb", + ListValueIs(CelSizeIs(0))}, + })); + +// TODO(uncreated-issue/66): any support needs the reflection impl for looking up the +// type name and corresponding deserializer (outside of the WKTs which are +// special cased). +INSTANTIATE_TEST_SUITE_P( + Any, ProtobufValueEndToEndTest, + testing::ValuesIn(std::vector{ + {"single_any_wkt_int64", "msg.single_any", + R"pb( + single_any { + [type.googleapis.com/google.protobuf.Int64Value] { value: 42 } + } + )pb", + IntValueIs(42)}, + {"single_any_wkt_int32", "msg.single_any", + R"pb( + single_any { + [type.googleapis.com/google.protobuf.Int32Value] { value: 42 } + } + )pb", + IntValueIs(42)}, + {"single_any_wkt_uint64", "msg.single_any", + R"pb( + single_any { + [type.googleapis.com/google.protobuf.UInt64Value] { value: 42 } + } + )pb", + UintValueIs(42)}, + {"single_any_wkt_uint32", "msg.single_any", + R"pb( + single_any { + [type.googleapis.com/google.protobuf.UInt32Value] { value: 42 } + } + )pb", + UintValueIs(42)}, + {"single_any_wkt_double", "msg.single_any", + R"pb( + single_any { + [type.googleapis.com/google.protobuf.DoubleValue] { value: 30.5 } + } + )pb", + DoubleValueIs(30.5)}, + {"single_any_wkt_string", "msg.single_any", + R"pb( + single_any { + [type.googleapis.com/google.protobuf.StringValue] { value: "abcd" } + } + )pb", + StringValueIs("abcd")}, + + {"repeated_any_wkt_string", "msg.repeated_any[0]", + R"pb( + repeated_any { + [type.googleapis.com/google.protobuf.StringValue] { value: "abcd" } + } + )pb", + StringValueIs("abcd")}, + {"map_int64_any_wkt_string", "msg.map_int64_any[0]", + R"pb( + map_int64_any { + key: 0 + value { + [type.googleapis.com/google.protobuf.StringValue] { + value: "abcd" + } + } + } + )pb", + StringValueIs("abcd")}, + })); + +} // namespace +} // namespace cel::extensions diff --git a/extensions/protobuf/value_test.cc b/extensions/protobuf/value_test.cc new file mode 100644 index 000000000..20d9dce2f --- /dev/null +++ b/extensions/protobuf/value_test.cc @@ -0,0 +1,800 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/protobuf/value.h" + +#include +#include +#include +#include +#include + +#include "google/protobuf/any.pb.h" +#include "google/protobuf/duration.pb.h" +#include "google/protobuf/struct.pb.h" +#include "google/protobuf/timestamp.pb.h" +#include "google/protobuf/wrappers.pb.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "base/attribute.h" +#include "common/casting.h" +#include "common/value.h" +#include "common/value_kind.h" +#include "common/value_testing.h" +#include "internal/testing.h" +#include "cel/expr/conformance/proto2/test_all_types.pb.h" +#include "google/protobuf/text_format.h" + +namespace cel::extensions { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::IsOkAndHolds; +using ::absl_testing::StatusIs; +using ::cel::expr::conformance::proto2::TestAllTypes; +using ::cel::test::BoolValueIs; +using ::cel::test::BytesValueIs; +using ::cel::test::DoubleValueIs; +using ::cel::test::DurationValueIs; +using ::cel::test::ErrorValueIs; +using ::cel::test::IntValueIs; +using ::cel::test::ListValueIs; +using ::cel::test::MapValueIs; +using ::cel::test::StringValueIs; +using ::cel::test::StructValueFieldHas; +using ::cel::test::StructValueFieldIs; +using ::cel::test::StructValueIs; +using ::cel::test::TimestampValueIs; +using ::cel::test::UintValueIs; +using ::cel::test::ValueKindIs; +using ::testing::_; +using ::testing::ElementsAre; +using ::testing::Eq; +using ::testing::HasSubstr; +using ::testing::IsTrue; +using ::testing::Pair; +using ::testing::UnorderedElementsAre; + +template +T ParseTextOrDie(absl::string_view text) { + T proto; + ABSL_CHECK(google::protobuf::TextFormat::ParseFromString(text, &proto)); + return proto; +} + +using ProtoValueTest = common_internal::ValueTest<>; + +class ProtoValueWrapTest : public ProtoValueTest {}; + +TEST_F(ProtoValueWrapTest, ProtoBoolValueToValue) { + google::protobuf::BoolValue message; + message.set_value(true); + EXPECT_THAT(ProtoMessageToValue(message, descriptor_pool(), message_factory(), + arena()), + IsOkAndHolds(BoolValueIs(Eq(true)))); + EXPECT_THAT(ProtoMessageToValue(std::move(message), descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(BoolValueIs(Eq(true)))); +} + +TEST_F(ProtoValueWrapTest, ProtoInt32ValueToValue) { + google::protobuf::Int32Value message; + message.set_value(1); + EXPECT_THAT(ProtoMessageToValue(message, descriptor_pool(), message_factory(), + arena()), + IsOkAndHolds(IntValueIs(Eq(1)))); + EXPECT_THAT(ProtoMessageToValue(std::move(message), descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(IntValueIs(Eq(1)))); +} + +TEST_F(ProtoValueWrapTest, ProtoInt64ValueToValue) { + google::protobuf::Int64Value message; + message.set_value(1); + EXPECT_THAT(ProtoMessageToValue(message, descriptor_pool(), message_factory(), + arena()), + IsOkAndHolds(IntValueIs(Eq(1)))); + EXPECT_THAT(ProtoMessageToValue(std::move(message), descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(IntValueIs(Eq(1)))); +} + +TEST_F(ProtoValueWrapTest, ProtoUInt32ValueToValue) { + google::protobuf::UInt32Value message; + message.set_value(1); + EXPECT_THAT(ProtoMessageToValue(message, descriptor_pool(), message_factory(), + arena()), + IsOkAndHolds(UintValueIs(Eq(1)))); + EXPECT_THAT(ProtoMessageToValue(std::move(message), descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(UintValueIs(Eq(1)))); +} + +TEST_F(ProtoValueWrapTest, ProtoUInt64ValueToValue) { + google::protobuf::UInt64Value message; + message.set_value(1); + EXPECT_THAT(ProtoMessageToValue(message, descriptor_pool(), message_factory(), + arena()), + IsOkAndHolds(UintValueIs(Eq(1)))); + EXPECT_THAT(ProtoMessageToValue(std::move(message), descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(UintValueIs(Eq(1)))); +} + +TEST_F(ProtoValueWrapTest, ProtoFloatValueToValue) { + google::protobuf::FloatValue message; + message.set_value(1); + EXPECT_THAT(ProtoMessageToValue(message, descriptor_pool(), message_factory(), + arena()), + IsOkAndHolds(DoubleValueIs(Eq(1)))); + EXPECT_THAT(ProtoMessageToValue(std::move(message), descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(DoubleValueIs(Eq(1)))); +} + +TEST_F(ProtoValueWrapTest, ProtoDoubleValueToValue) { + google::protobuf::DoubleValue message; + message.set_value(1); + EXPECT_THAT(ProtoMessageToValue(message, descriptor_pool(), message_factory(), + arena()), + IsOkAndHolds(DoubleValueIs(Eq(1)))); + EXPECT_THAT(ProtoMessageToValue(std::move(message), descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(DoubleValueIs(Eq(1)))); +} + +TEST_F(ProtoValueWrapTest, ProtoBytesValueToValue) { + google::protobuf::BytesValue message; + message.set_value("foo"); + EXPECT_THAT(ProtoMessageToValue(message, descriptor_pool(), message_factory(), + arena()), + IsOkAndHolds(BytesValueIs(Eq("foo")))); + EXPECT_THAT(ProtoMessageToValue(std::move(message), descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(BytesValueIs(Eq("foo")))); +} + +TEST_F(ProtoValueWrapTest, ProtoStringValueToValue) { + google::protobuf::StringValue message; + message.set_value("foo"); + EXPECT_THAT(ProtoMessageToValue(message, descriptor_pool(), message_factory(), + arena()), + IsOkAndHolds(StringValueIs(Eq("foo")))); + EXPECT_THAT(ProtoMessageToValue(std::move(message), descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(StringValueIs(Eq("foo")))); +} + +TEST_F(ProtoValueWrapTest, ProtoDurationToValue) { + google::protobuf::Duration message; + message.set_seconds(1); + message.set_nanos(1); + EXPECT_THAT(ProtoMessageToValue(message, descriptor_pool(), message_factory(), + arena()), + IsOkAndHolds(DurationValueIs( + Eq(absl::Seconds(1) + absl::Nanoseconds(1))))); + EXPECT_THAT(ProtoMessageToValue(std::move(message), descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(DurationValueIs( + Eq(absl::Seconds(1) + absl::Nanoseconds(1))))); +} + +TEST_F(ProtoValueWrapTest, ProtoTimestampToValue) { + google::protobuf::Timestamp message; + message.set_seconds(1); + message.set_nanos(1); + EXPECT_THAT( + ProtoMessageToValue(message, descriptor_pool(), message_factory(), + arena()), + IsOkAndHolds(TimestampValueIs( + Eq(absl::UnixEpoch() + absl::Seconds(1) + absl::Nanoseconds(1))))); + EXPECT_THAT( + ProtoMessageToValue(std::move(message), descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(TimestampValueIs( + Eq(absl::UnixEpoch() + absl::Seconds(1) + absl::Nanoseconds(1))))); +} + +TEST_F(ProtoValueWrapTest, ProtoMessageToValue) { + TestAllTypes message; + EXPECT_THAT(ProtoMessageToValue(message, descriptor_pool(), message_factory(), + arena()), + IsOkAndHolds(ValueKindIs(Eq(ValueKind::kStruct)))); + EXPECT_THAT(ProtoMessageToValue(std::move(message), descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(ValueKindIs(Eq(ValueKind::kStruct)))); +} + +TEST_F(ProtoValueWrapTest, GetFieldByName) { + ASSERT_OK_AND_ASSIGN( + auto value, + ProtoMessageToValue(ParseTextOrDie( + R"pb(single_int32: 1, + single_int64: 1 + single_uint32: 1 + single_uint64: 1 + single_float: 1 + single_double: 1 + single_bool: true + single_string: "foo" + single_bytes: "foo")pb"), + descriptor_pool(), message_factory(), arena())); + EXPECT_THAT(value, StructValueIs(StructValueFieldIs( + "single_int32", IntValueIs(Eq(1)), descriptor_pool(), + message_factory(), arena()))); + EXPECT_THAT(value, + StructValueIs(StructValueFieldHas("single_int32", IsTrue()))); + EXPECT_THAT(value, StructValueIs(StructValueFieldIs( + "single_int64", IntValueIs(Eq(1)), descriptor_pool(), + message_factory(), arena()))); + EXPECT_THAT(value, + StructValueIs(StructValueFieldHas("single_int64", IsTrue()))); + EXPECT_THAT(value, StructValueIs(StructValueFieldIs( + "single_uint32", UintValueIs(Eq(1)), descriptor_pool(), + message_factory(), arena()))); + EXPECT_THAT(value, + StructValueIs(StructValueFieldHas("single_uint32", IsTrue()))); + EXPECT_THAT(value, StructValueIs(StructValueFieldIs( + "single_uint64", UintValueIs(Eq(1)), descriptor_pool(), + message_factory(), arena()))); + EXPECT_THAT(value, + StructValueIs(StructValueFieldHas("single_uint64", IsTrue()))); +} + +TEST_F(ProtoValueWrapTest, GetFieldNoSuchField) { + ASSERT_OK_AND_ASSIGN( + auto value, ProtoMessageToValue( + ParseTextOrDie(R"pb(single_int32: 1)pb"), + descriptor_pool(), message_factory(), arena())); + ASSERT_THAT(value, StructValueIs(_)); + + StructValue struct_value = Cast(value); + EXPECT_THAT(struct_value.GetFieldByName("does_not_exist", descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(ErrorValueIs(StatusIs(absl::StatusCode::kNotFound, + HasSubstr("no_such_field"))))); +} + +TEST_F(ProtoValueWrapTest, GetFieldByNumber) { + ASSERT_OK_AND_ASSIGN( + auto value, + ProtoMessageToValue(ParseTextOrDie( + R"pb(single_int32: 1, + single_int64: 2 + single_uint32: 3 + single_uint64: 4 + single_float: 1.25 + single_double: 1.5 + single_bool: true + single_string: "foo" + single_bytes: "foo")pb"), + descriptor_pool(), message_factory(), arena())); + EXPECT_THAT(value, StructValueIs(_)); + StructValue struct_value = Cast(value); + + EXPECT_THAT(struct_value.GetFieldByNumber( + TestAllTypes::kSingleInt32FieldNumber, descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(IntValueIs(1))); + EXPECT_THAT(struct_value.GetFieldByNumber( + TestAllTypes::kSingleInt64FieldNumber, descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(IntValueIs(2))); + EXPECT_THAT(struct_value.GetFieldByNumber( + TestAllTypes::kSingleUint32FieldNumber, descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(UintValueIs(3))); + EXPECT_THAT(struct_value.GetFieldByNumber( + TestAllTypes::kSingleUint64FieldNumber, descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(UintValueIs(4))); + + EXPECT_THAT(struct_value.GetFieldByNumber( + TestAllTypes::kSingleFloatFieldNumber, descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(DoubleValueIs(1.25))); + + EXPECT_THAT(struct_value.GetFieldByNumber( + TestAllTypes::kSingleDoubleFieldNumber, descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(DoubleValueIs(1.5))); + + EXPECT_THAT(struct_value.GetFieldByNumber( + TestAllTypes::kSingleBoolFieldNumber, descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(BoolValueIs(true))); + + EXPECT_THAT(struct_value.GetFieldByNumber( + TestAllTypes::kSingleStringFieldNumber, descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(StringValueIs("foo"))); + + EXPECT_THAT(struct_value.GetFieldByNumber( + TestAllTypes::kSingleBytesFieldNumber, descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(BytesValueIs("foo"))); +} + +TEST_F(ProtoValueWrapTest, GetFieldByNumberNoSuchField) { + ASSERT_OK_AND_ASSIGN( + auto value, + ProtoMessageToValue(ParseTextOrDie( + R"pb(single_int32: 1, + single_int64: 2 + single_uint32: 3 + single_uint64: 4 + single_float: 1.25 + single_double: 1.5 + single_bool: true + single_string: "foo" + single_bytes: "foo")pb"), + descriptor_pool(), message_factory(), arena())); + EXPECT_THAT(value, StructValueIs(_)); + StructValue struct_value = Cast(value); + + EXPECT_THAT(struct_value.GetFieldByNumber(999, descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(ErrorValueIs(StatusIs(absl::StatusCode::kNotFound, + HasSubstr("no_such_field"))))); + + // Out of range. + EXPECT_THAT(struct_value.GetFieldByNumber(0x1ffffffff, descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(ErrorValueIs(StatusIs(absl::StatusCode::kNotFound, + HasSubstr("no_such_field"))))); +} + +TEST_F(ProtoValueWrapTest, HasFieldByNumber) { + ASSERT_OK_AND_ASSIGN( + auto value, ProtoMessageToValue( + ParseTextOrDie(R"pb(single_int32: 1, + single_int64: 2)pb"), + descriptor_pool(), message_factory(), arena())); + EXPECT_THAT(value, StructValueIs(_)); + StructValue struct_value = Cast(value); + + EXPECT_THAT( + struct_value.HasFieldByNumber(TestAllTypes::kSingleInt32FieldNumber), + IsOkAndHolds(BoolValue(true))); + EXPECT_THAT( + struct_value.HasFieldByNumber(TestAllTypes::kSingleInt64FieldNumber), + IsOkAndHolds(BoolValue(true))); + EXPECT_THAT( + struct_value.HasFieldByNumber(TestAllTypes::kSingleStringFieldNumber), + IsOkAndHolds(BoolValue(false))); + EXPECT_THAT( + struct_value.HasFieldByNumber(TestAllTypes::kSingleBytesFieldNumber), + IsOkAndHolds(BoolValue(false))); +} + +TEST_F(ProtoValueWrapTest, HasFieldByNumberNoSuchField) { + ASSERT_OK_AND_ASSIGN( + auto value, ProtoMessageToValue( + ParseTextOrDie(R"pb(single_int32: 1, + single_int64: 2)pb"), + descriptor_pool(), message_factory(), arena())); + EXPECT_THAT(value, StructValueIs(_)); + StructValue struct_value = Cast(value); + + // Has returns a status directly instead of a CEL error as in Get. + EXPECT_THAT( + struct_value.HasFieldByNumber(999), + StatusIs(absl::StatusCode::kNotFound, HasSubstr("no_such_field"))); + EXPECT_THAT( + struct_value.HasFieldByNumber(0x1ffffffff), + StatusIs(absl::StatusCode::kNotFound, HasSubstr("no_such_field"))); +} + +TEST_F(ProtoValueWrapTest, ProtoMessageEqual) { + ASSERT_OK_AND_ASSIGN( + auto value, + ProtoMessageToValue(ParseTextOrDie( + R"pb(single_int32: 1, single_int64: 2 + )pb"), + descriptor_pool(), message_factory(), arena())); + ASSERT_OK_AND_ASSIGN( + auto value2, + ProtoMessageToValue(ParseTextOrDie( + R"pb(single_int32: 1, single_int64: 2 + )pb"), + descriptor_pool(), message_factory(), arena())); + EXPECT_THAT(value.Equal(value, descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(BoolValueIs(true))); + EXPECT_THAT( + value2.Equal(value, descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(BoolValueIs(true))); +} + +TEST_F(ProtoValueWrapTest, ProtoMessageEqualFalse) { + ASSERT_OK_AND_ASSIGN( + auto value, + ProtoMessageToValue(ParseTextOrDie( + R"pb(single_int32: 1, single_int64: 2 + )pb"), + descriptor_pool(), message_factory(), arena())); + ASSERT_OK_AND_ASSIGN( + auto value2, + ProtoMessageToValue(ParseTextOrDie( + R"pb(single_int32: 2, single_int64: 1 + )pb"), + descriptor_pool(), message_factory(), arena())); + EXPECT_THAT( + value2.Equal(value, descriptor_pool(), message_factory(), arena()), + IsOkAndHolds(BoolValueIs(false))); +} + +TEST_F(ProtoValueWrapTest, ProtoMessageForEachField) { + ASSERT_OK_AND_ASSIGN( + auto value, + ProtoMessageToValue(ParseTextOrDie( + R"pb(single_int32: 1, single_int64: 2 + )pb"), + descriptor_pool(), message_factory(), arena())); + EXPECT_THAT(value, StructValueIs(_)); + StructValue struct_value = Cast(value); + + std::vector fields; + auto cb = [&fields](absl::string_view field, + const Value&) -> absl::StatusOr { + fields.push_back(std::string(field)); + return true; + }; + ASSERT_THAT(struct_value.ForEachField(cb, descriptor_pool(), + message_factory(), arena()), + IsOk()); + EXPECT_THAT(fields, UnorderedElementsAre("single_int32", "single_int64")); +} + +TEST_F(ProtoValueWrapTest, ProtoMessageQualify) { + ASSERT_OK_AND_ASSIGN( + auto value, + ProtoMessageToValue(ParseTextOrDie( + R"pb( + standalone_message { bb: 42 } + )pb"), + descriptor_pool(), message_factory(), arena())); + EXPECT_THAT(value, StructValueIs(_)); + StructValue struct_value = Cast(value); + + std::vector qualifiers{ + FieldSpecifier{TestAllTypes::kStandaloneMessageFieldNumber, + "standalone_message"}, + FieldSpecifier{TestAllTypes::NestedMessage::kBbFieldNumber, "bb"}}; + + Value scratch; + int count; + EXPECT_THAT( + struct_value.Qualify(qualifiers, + /*presence_test=*/false, descriptor_pool(), + message_factory(), arena(), &scratch, &count), + IsOk()); + + EXPECT_THAT(scratch, IntValueIs(42)); +} + +TEST_F(ProtoValueWrapTest, ProtoMessageQualifyHas) { + ASSERT_OK_AND_ASSIGN( + auto value, + ProtoMessageToValue(ParseTextOrDie( + R"pb( + standalone_message { bb: 42 } + )pb"), + descriptor_pool(), message_factory(), arena())); + EXPECT_THAT(value, StructValueIs(_)); + StructValue struct_value = Cast(value); + + std::vector qualifiers{ + FieldSpecifier{TestAllTypes::kStandaloneMessageFieldNumber, + "standalone_message"}, + FieldSpecifier{TestAllTypes::NestedMessage::kBbFieldNumber, "bb"}}; + + Value scratch; + int count; + EXPECT_THAT( + struct_value.Qualify(qualifiers, + /*presence_test=*/true, descriptor_pool(), + message_factory(), arena(), &scratch, &count), + IsOk()); + + EXPECT_THAT(scratch, BoolValueIs(true)); +} + +TEST_F(ProtoValueWrapTest, ProtoInt64MapListKeys) { + ASSERT_OK_AND_ASSIGN( + auto value, + ProtoMessageToValue(ParseTextOrDie( + R"pb( + map_int64_int64 { key: 10 value: 20 })pb"), + descriptor_pool(), message_factory(), arena())); + ASSERT_OK_AND_ASSIGN(auto map_value, Cast(value).GetFieldByName( + "map_int64_int64", descriptor_pool(), + message_factory(), arena())); + + ASSERT_THAT(map_value, MapValueIs(_)); + + ASSERT_OK_AND_ASSIGN(ListValue key_set, + Cast(map_value).ListKeys( + descriptor_pool(), message_factory(), arena())); + + EXPECT_THAT(key_set.Size(), IsOkAndHolds(1)); + + ASSERT_OK_AND_ASSIGN(Value key0, key_set.Get(0, descriptor_pool(), + message_factory(), arena())); + + EXPECT_THAT(key0, IntValueIs(10)); +} + +TEST_F(ProtoValueWrapTest, ProtoInt32MapListKeys) { + ASSERT_OK_AND_ASSIGN( + auto value, + ProtoMessageToValue(ParseTextOrDie( + R"pb( + map_int32_int64 { key: 10 value: 20 })pb"), + descriptor_pool(), message_factory(), arena())); + ASSERT_OK_AND_ASSIGN(auto map_value, Cast(value).GetFieldByName( + "map_int32_int64", descriptor_pool(), + message_factory(), arena())); + + ASSERT_THAT(map_value, MapValueIs(_)); + + ASSERT_OK_AND_ASSIGN(ListValue key_set, + Cast(map_value).ListKeys( + descriptor_pool(), message_factory(), arena())); + + EXPECT_THAT(key_set.Size(), IsOkAndHolds(1)); + + ASSERT_OK_AND_ASSIGN(Value key0, key_set.Get(0, descriptor_pool(), + message_factory(), arena())); + + EXPECT_THAT(key0, IntValueIs(10)); +} + +TEST_F(ProtoValueWrapTest, ProtoBoolMapListKeys) { + ASSERT_OK_AND_ASSIGN( + auto value, + ProtoMessageToValue(ParseTextOrDie( + R"pb( + map_bool_int64 { key: false value: 20 })pb"), + descriptor_pool(), message_factory(), arena())); + ASSERT_OK_AND_ASSIGN(auto map_value, Cast(value).GetFieldByName( + "map_bool_int64", descriptor_pool(), + message_factory(), arena())); + + ASSERT_THAT(map_value, MapValueIs(_)); + + ASSERT_OK_AND_ASSIGN(ListValue key_set, + Cast(map_value).ListKeys( + descriptor_pool(), message_factory(), arena())); + + EXPECT_THAT(key_set.Size(), IsOkAndHolds(1)); + + ASSERT_OK_AND_ASSIGN(Value key0, key_set.Get(0, descriptor_pool(), + message_factory(), arena())); + + EXPECT_THAT(key0, BoolValueIs(false)); +} + +TEST_F(ProtoValueWrapTest, ProtoUint32MapListKeys) { + ASSERT_OK_AND_ASSIGN( + auto value, + ProtoMessageToValue(ParseTextOrDie( + R"pb( + map_uint32_int64 { key: 11 value: 20 })pb"), + descriptor_pool(), message_factory(), arena())); + ASSERT_OK_AND_ASSIGN( + auto map_value, + Cast(value).GetFieldByName( + "map_uint32_int64", descriptor_pool(), message_factory(), arena())); + + ASSERT_THAT(map_value, MapValueIs(_)); + + ASSERT_OK_AND_ASSIGN(ListValue key_set, + Cast(map_value).ListKeys( + descriptor_pool(), message_factory(), arena())); + + EXPECT_THAT(key_set.Size(), IsOkAndHolds(1)); + + ASSERT_OK_AND_ASSIGN(Value key0, key_set.Get(0, descriptor_pool(), + message_factory(), arena())); + + EXPECT_THAT(key0, UintValueIs(11)); +} + +TEST_F(ProtoValueWrapTest, ProtoUint64MapListKeys) { + ASSERT_OK_AND_ASSIGN( + auto value, + ProtoMessageToValue(ParseTextOrDie( + R"pb( + map_uint64_int64 { key: 11 value: 20 })pb"), + descriptor_pool(), message_factory(), arena())); + ASSERT_OK_AND_ASSIGN( + auto map_value, + Cast(value).GetFieldByName( + "map_uint64_int64", descriptor_pool(), message_factory(), arena())); + + ASSERT_THAT(map_value, MapValueIs(_)); + + ASSERT_OK_AND_ASSIGN(ListValue key_set, + Cast(map_value).ListKeys( + descriptor_pool(), message_factory(), arena())); + + EXPECT_THAT(key_set.Size(), IsOkAndHolds(1)); + + ASSERT_OK_AND_ASSIGN(Value key0, key_set.Get(0, descriptor_pool(), + message_factory(), arena())); + + EXPECT_THAT(key0, UintValueIs(11)); +} + +TEST_F(ProtoValueWrapTest, ProtoStringMapListKeys) { + ASSERT_OK_AND_ASSIGN( + auto value, ProtoMessageToValue( + + ParseTextOrDie( + R"pb( + map_string_int64 { key: "key1" value: 20 })pb"), + descriptor_pool(), message_factory(), arena())); + ASSERT_OK_AND_ASSIGN( + auto map_value, + Cast(value).GetFieldByName( + "map_string_int64", descriptor_pool(), message_factory(), arena())); + + ASSERT_THAT(map_value, MapValueIs(_)); + + ASSERT_OK_AND_ASSIGN(ListValue key_set, + Cast(map_value).ListKeys( + descriptor_pool(), message_factory(), arena())); + + EXPECT_THAT(key_set.Size(), IsOkAndHolds(1)); + + ASSERT_OK_AND_ASSIGN(Value key0, key_set.Get(0, descriptor_pool(), + message_factory(), arena())); + + EXPECT_THAT(key0, StringValueIs("key1")); +} + +TEST_F(ProtoValueWrapTest, ProtoMapIterator) { + ASSERT_OK_AND_ASSIGN( + auto value, + ProtoMessageToValue(ParseTextOrDie( + R"pb( + map_int64_int64 { key: 10 value: 20 } + map_int64_int64 { key: 12 value: 24 } + )pb"), + descriptor_pool(), message_factory(), arena())); + ASSERT_OK_AND_ASSIGN( + auto field_value, + Cast(value).GetFieldByName( + "map_int64_int64", descriptor_pool(), message_factory(), arena())); + + ASSERT_THAT(field_value, MapValueIs(_)); + + MapValue map_value = Cast(field_value); + + std::vector keys; + + ASSERT_OK_AND_ASSIGN(auto iter, map_value.NewIterator()); + + while (iter->HasNext()) { + ASSERT_OK_AND_ASSIGN( + keys.emplace_back(), + iter->Next(descriptor_pool(), message_factory(), arena())); + } + + EXPECT_THAT(keys, UnorderedElementsAre(IntValueIs(10), IntValueIs(12))); +} + +TEST_F(ProtoValueWrapTest, ProtoMapForEach) { + ASSERT_OK_AND_ASSIGN( + auto value, + ProtoMessageToValue(ParseTextOrDie( + R"pb( + map_int64_int64 { key: 10 value: 20 } + map_int64_int64 { key: 12 value: 24 } + )pb"), + descriptor_pool(), message_factory(), arena())); + ASSERT_OK_AND_ASSIGN( + auto field_value, + Cast(value).GetFieldByName( + "map_int64_int64", descriptor_pool(), message_factory(), arena())); + + ASSERT_THAT(field_value, MapValueIs(_)); + + MapValue map_value = Cast(field_value); + + std::vector> pairs; + + auto cb = [&pairs](const Value& key, + const Value& value) -> absl::StatusOr { + pairs.push_back(std::pair(key, value)); + return true; + }; + ASSERT_THAT( + map_value.ForEach(cb, descriptor_pool(), message_factory(), arena()), + IsOk()); + + EXPECT_THAT(pairs, + UnorderedElementsAre(Pair(IntValueIs(10), IntValueIs(20)), + Pair(IntValueIs(12), IntValueIs(24)))); +} + +TEST_F(ProtoValueWrapTest, ProtoListIterator) { + ASSERT_OK_AND_ASSIGN( + auto value, + ProtoMessageToValue(ParseTextOrDie( + R"pb( + repeated_int64: 1 repeated_int64: 2 + )pb"), + descriptor_pool(), message_factory(), arena())); + ASSERT_OK_AND_ASSIGN( + auto field_value, + Cast(value).GetFieldByName( + "repeated_int64", descriptor_pool(), message_factory(), arena())); + + ASSERT_THAT(field_value, ListValueIs(_)); + + ListValue list_value = Cast(field_value); + + std::vector elements; + + ASSERT_OK_AND_ASSIGN(auto iter, list_value.NewIterator()); + + while (iter->HasNext()) { + ASSERT_OK_AND_ASSIGN( + elements.emplace_back(), + iter->Next(descriptor_pool(), message_factory(), arena())); + } + + EXPECT_THAT(elements, ElementsAre(IntValueIs(1), IntValueIs(2))); +} + +TEST_F(ProtoValueWrapTest, ProtoListForEachWithIndex) { + ASSERT_OK_AND_ASSIGN( + auto value, + ProtoMessageToValue(ParseTextOrDie( + R"pb( + repeated_int64: 1 repeated_int64: 2 + )pb"), + descriptor_pool(), message_factory(), arena())); + ASSERT_OK_AND_ASSIGN( + auto field_value, + Cast(value).GetFieldByName( + "repeated_int64", descriptor_pool(), message_factory(), arena())); + + ASSERT_THAT(field_value, ListValueIs(_)); + + ListValue list_value = Cast(field_value); + + std::vector> elements; + + auto cb = [&elements](size_t index, + const Value& value) -> absl::StatusOr { + elements.push_back(std::pair(index, value)); + return true; + }; + + ASSERT_THAT( + list_value.ForEach(cb, descriptor_pool(), message_factory(), arena()), + IsOk()); + + EXPECT_THAT(elements, + ElementsAre(Pair(0, IntValueIs(1)), Pair(1, IntValueIs(2)))); +} + +} // namespace +} // namespace cel::extensions diff --git a/extensions/protobuf/value_testing.h b/extensions/protobuf/value_testing.h new file mode 100644 index 000000000..bf1dbb95f --- /dev/null +++ b/extensions/protobuf/value_testing.h @@ -0,0 +1,78 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_VALUE_TESTING_H_ +#define THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_VALUE_TESTING_H_ + +#include +#include + +#include "absl/status/status.h" +#include "common/value.h" +#include "extensions/protobuf/value.h" +#include "internal/testing.h" +#include "google/protobuf/message.h" + +namespace cel::extensions::test { + +template +class StructValueAsProtoMatcher { + public: + using is_gtest_matcher = void; + + explicit StructValueAsProtoMatcher(testing::Matcher&& m) + : m_(std::move(m)) {} + + bool MatchAndExplain(cel::Value v, + testing::MatchResultListener* result_listener) const { + MessageType msg; + absl::Status s = ProtoMessageFromValue(v, msg); + if (!s.ok()) { + *result_listener << "cannot convert to " + << MessageType::descriptor()->full_name() << ": " << s; + return false; + } + return m_.MatchAndExplain(msg, result_listener); + } + + void DescribeTo(std::ostream* os) const { + *os << "matches proto message " << m_; + } + + void DescribeNegationTo(std::ostream* os) const { + *os << "does not match proto message " << m_; + } + + private: + testing::Matcher m_; +}; + +// Returns a matcher that matches a cel::Value against a proto message. +// +// Example usage: +// +// EXPECT_THAT(value, StructValueAsProto(EqualsProto(R"pb( +// single_int32: 1 +// single_string: "foo" +// )pb"))); +template +inline StructValueAsProtoMatcher StructValueAsProto( + testing::Matcher&& m) { + static_assert(std::is_base_of_v); + return StructValueAsProtoMatcher(std::move(m)); +} + +} // namespace cel::extensions::test + +#endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_VALUE_TESTING_H_ diff --git a/extensions/protobuf/value_testing_test.cc b/extensions/protobuf/value_testing_test.cc new file mode 100644 index 000000000..d84930349 --- /dev/null +++ b/extensions/protobuf/value_testing_test.cc @@ -0,0 +1,48 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/protobuf/value_testing.h" + +#include "common/value.h" +#include "common/value_testing.h" +#include "extensions/protobuf/value.h" +#include "internal/proto_matchers.h" +#include "internal/testing.h" +#include "cel/expr/conformance/proto2/test_all_types.pb.h" + +namespace cel::extensions::test { +namespace { + +using ::cel::expr::conformance::proto2::TestAllTypes; +using ::cel::extensions::ProtoMessageToValue; +using ::cel::internal::test::EqualsProto; + +using ProtoValueTestingTest = common_internal::ValueTest<>; + +TEST_F(ProtoValueTestingTest, StructValueAsProtoSimple) { + TestAllTypes test_proto; + test_proto.set_single_int32(42); + test_proto.set_single_string("foo"); + + ASSERT_OK_AND_ASSIGN(cel::Value v, + ProtoMessageToValue(test_proto, descriptor_pool(), + message_factory(), arena())); + EXPECT_THAT(v, StructValueAsProto(EqualsProto(R"pb( + single_int32: 42 + single_string: "foo" + )pb"))); +} + +} // namespace +} // namespace cel::extensions::test diff --git a/extensions/regex_ext.cc b/extensions/regex_ext.cc new file mode 100644 index 000000000..9c06d90c2 --- /dev/null +++ b/extensions/regex_ext.cc @@ -0,0 +1,352 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/regex_ext.h" + +#include +#include +#include +#include + +#include "absl/base/no_destructor.h" +#include "absl/base/nullability.h" +#include "absl/functional/bind_front.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "checker/internal/builtins_arena.h" +#include "checker/type_checker_builder.h" +#include "common/decl.h" +#include "common/type.h" +#include "common/value.h" +#include "compiler/compiler.h" +#include "eval/public/cel_function_registry.h" +#include "eval/public/cel_options.h" +#include "internal/casts.h" +#include "internal/re2_options.h" +#include "internal/status_macros.h" +#include "runtime/function_adapter.h" +#include "runtime/function_registry.h" +#include "runtime/internal/runtime_friend_access.h" +#include "runtime/internal/runtime_impl.h" +#include "runtime/runtime_builder.h" +#include "validator/regex_validator.h" +#include "validator/validator.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" +#include "re2/re2.h" + +namespace cel::extensions { +namespace { + +using ::cel::checker_internal::BuiltinsArena; + +Value Extract(int regex_max_program_size, const StringValue& target, + const StringValue& regex, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + std::string target_scratch; + std::string regex_scratch; + absl::string_view target_view = target.ToStringView(&target_scratch); + absl::string_view regex_view = regex.ToStringView(®ex_scratch); + RE2 re2(regex_view, cel::internal::MakeRE2Options()); + CEL_RETURN_IF_ERROR(cel::internal::CheckRE2(re2, regex_max_program_size)) + .With(ErrorValueReturn()); + const int group_count = re2.NumberOfCapturingGroups(); + if (group_count > 1) { + return ErrorValue(absl::InvalidArgumentError(absl::StrFormat( + "regular expression has more than one capturing group: %s", + regex_view))); + } + + // Space for the full match (\0) and the first capture group (\1). + absl::string_view submatches[2]; + if (re2.Match(target_view, 0, target_view.length(), RE2::UNANCHORED, + submatches, 2)) { + // Return the capture group if it exists else return the full match. + const absl::string_view result_view = + (group_count == 1) ? submatches[1] : submatches[0]; + return OptionalValue::Of(StringValue::From(result_view, arena), arena); + } + + return OptionalValue::None(); +} + +Value ExtractAll(int regex_max_program_size, const StringValue& target, + const StringValue& regex, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + std::string target_scratch; + std::string regex_scratch; + absl::string_view target_view = target.ToStringView(&target_scratch); + absl::string_view regex_view = regex.ToStringView(®ex_scratch); + RE2 re2(regex_view, cel::internal::MakeRE2Options()); + CEL_RETURN_IF_ERROR(cel::internal::CheckRE2(re2, regex_max_program_size)) + .With(ErrorValueReturn()); + const int group_count = re2.NumberOfCapturingGroups(); + if (group_count > 1) { + return ErrorValue(absl::InvalidArgumentError(absl::StrFormat( + "regular expression has more than one capturing group: %s", + regex_view))); + } + + auto builder = NewListValueBuilder(arena); + absl::string_view temp_target = target_view; + + // Space for the full match (\0) and the first capture group (\1). + absl::string_view submatches[2]; + const int group_to_extract = (group_count == 1) ? 1 : 0; + + while (re2.Match(temp_target, 0, temp_target.length(), RE2::UNANCHORED, + submatches, group_count + 1)) { + const absl::string_view& full_match = submatches[0]; + const absl::string_view& desired_capture = submatches[group_to_extract]; + + // Avoid infinite loops on zero-length matches + if (full_match.empty()) { + if (temp_target.empty()) { + break; + } + temp_target.remove_prefix(1); + continue; + } + + if (group_count == 1 && desired_capture.empty()) { + temp_target.remove_prefix(full_match.data() - temp_target.data() + + full_match.length()); + continue; + } + + absl::Status status = + builder->Add(StringValue::From(desired_capture, arena)); + if (!status.ok()) { + return ErrorValue(status); + } + temp_target.remove_prefix(full_match.data() - temp_target.data() + + full_match.length()); + } + + return std::move(*builder).Build(); +} + +Value ReplaceAll(int regex_max_program_size, const StringValue& target, + const StringValue& regex, const StringValue& replacement, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + std::string target_scratch; + std::string regex_scratch; + std::string replacement_scratch; + absl::string_view target_view = target.ToStringView(&target_scratch); + absl::string_view regex_view = regex.ToStringView(®ex_scratch); + absl::string_view replacement_view = + replacement.ToStringView(&replacement_scratch); + RE2 re2(regex_view, cel::internal::MakeRE2Options()); + CEL_RETURN_IF_ERROR(cel::internal::CheckRE2(re2, regex_max_program_size)) + .With(ErrorValueReturn()); + std::string error_string; + if (!re2.CheckRewriteString(replacement_view, &error_string)) { + return ErrorValue(absl::InvalidArgumentError( + absl::StrFormat("invalid replacement string: %s", error_string))); + } + + std::string output(target_view); + RE2::GlobalReplace(&output, re2, replacement_view); + + return StringValue::From(std::move(output), arena); +} + +Value ReplaceN(int regex_max_program_size, const StringValue& target, + const StringValue& regex, const StringValue& replacement, + int64_t count, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + if (count == 0) { + return target; + } + if (count < 0) { + return ReplaceAll(regex_max_program_size, target, regex, replacement, + descriptor_pool, message_factory, arena); + } + + std::string target_scratch; + std::string regex_scratch; + std::string replacement_scratch; + absl::string_view target_view = target.ToStringView(&target_scratch); + absl::string_view regex_view = regex.ToStringView(®ex_scratch); + absl::string_view replacement_view = + replacement.ToStringView(&replacement_scratch); + RE2 re2(regex_view, cel::internal::MakeRE2Options()); + CEL_RETURN_IF_ERROR(cel::internal::CheckRE2(re2, regex_max_program_size)) + .With(ErrorValueReturn()); + std::string error_string; + if (!re2.CheckRewriteString(replacement_view, &error_string)) { + return ErrorValue(absl::InvalidArgumentError( + absl::StrFormat("invalid replacement string: %s", error_string))); + } + + std::string output; + absl::string_view temp_target = target_view; + int replaced_count = 0; + // RE2's Rewrite only supports substitutions for groups \0 through \9. + absl::string_view match[10]; + int nmatch = std::min(9, re2.NumberOfCapturingGroups()) + 1; + + while (replaced_count < count && + re2.Match(temp_target, 0, temp_target.length(), RE2::UNANCHORED, match, + nmatch)) { + absl::string_view full_match = match[0]; + + output.append(temp_target.data(), full_match.data() - temp_target.data()); + + if (!re2.Rewrite(&output, replacement_view, match, nmatch)) { + // This should ideally not happen given CheckRewriteString passed + return ErrorValue(absl::InternalError("rewrite failed unexpectedly")); + } + + temp_target.remove_prefix(full_match.data() - temp_target.data() + + full_match.length()); + replaced_count++; + } + + output.append(temp_target.data(), temp_target.length()); + + return StringValue::From(std::move(output), arena); +} + +absl::Status RegisterRegexExtensionFunctions(FunctionRegistry& registry, + bool disable_extract, + int regex_max_program_size) { + if (!disable_extract) { + CEL_RETURN_IF_ERROR(( + BinaryFunctionAdapter, StringValue, StringValue>:: + RegisterGlobalOverload( + "regex.extract", + absl::bind_front(&Extract, regex_max_program_size), registry))); + } + CEL_RETURN_IF_ERROR( + (BinaryFunctionAdapter, StringValue, StringValue>:: + RegisterGlobalOverload( + "regex.extractAll", + absl::bind_front(&ExtractAll, regex_max_program_size), + registry))); + CEL_RETURN_IF_ERROR( + (TernaryFunctionAdapter< + absl::StatusOr, StringValue, StringValue, + StringValue>::RegisterGlobalOverload("regex.replace", + absl::bind_front( + &ReplaceAll, + regex_max_program_size), + registry))); + CEL_RETURN_IF_ERROR( + (QuaternaryFunctionAdapter, StringValue, + StringValue, StringValue, int64_t>:: + RegisterGlobalOverload( + "regex.replace", + absl::bind_front(&ReplaceN, regex_max_program_size), registry))); + return absl::OkStatus(); +} + +const Type& OptionalStringType() { + static absl::NoDestructor kInstance( + OptionalType(BuiltinsArena(), StringType())); + return *kInstance; +} + +const Type& ListStringType() { + static absl::NoDestructor kInstance( + ListType(BuiltinsArena(), StringType())); + return *kInstance; +} + +absl::Status RegisterRegexCheckerDecls(TypeCheckerBuilder& builder) { + CEL_ASSIGN_OR_RETURN( + FunctionDecl extract_decl, + MakeFunctionDecl( + "regex.extract", + MakeOverloadDecl("regex_extract_string_string", OptionalStringType(), + StringType(), StringType()))); + + CEL_ASSIGN_OR_RETURN( + FunctionDecl extract_all_decl, + MakeFunctionDecl( + "regex.extractAll", + MakeOverloadDecl("regex_extractAll_string_string", ListStringType(), + StringType(), StringType()))); + + CEL_ASSIGN_OR_RETURN( + FunctionDecl replace_decl, + MakeFunctionDecl( + "regex.replace", + MakeOverloadDecl("regex_replace_string_string_string", StringType(), + StringType(), StringType(), StringType()), + MakeOverloadDecl("regex_replace_string_string_string_int", + StringType(), StringType(), StringType(), + StringType(), IntType()))); + + CEL_RETURN_IF_ERROR(builder.AddFunction(extract_decl)); + CEL_RETURN_IF_ERROR(builder.AddFunction(extract_all_decl)); + CEL_RETURN_IF_ERROR(builder.AddFunction(replace_decl)); + return absl::OkStatus(); +} + +} // namespace + +absl::Status RegisterRegexExtensionFunctions(RuntimeBuilder& builder) { + auto& runtime = cel::internal::down_cast( + runtime_internal::RuntimeFriendAccess::GetMutableRuntime(builder)); + if (!runtime.expr_builder().optional_types_enabled()) { + return absl::InvalidArgumentError( + "regex extensions requires the optional types to be enabled"); + } + if (runtime.expr_builder().options().enable_regex) { + CEL_RETURN_IF_ERROR(RegisterRegexExtensionFunctions( + builder.function_registry(), + /*disable_extract=*/false, + runtime.expr_builder().options().regex_max_program_size)); + } + return absl::OkStatus(); +} + +absl::Status RegisterRegexExtensionFunctions( + google::api::expr::runtime::CelFunctionRegistry* registry, + const google::api::expr::runtime::InterpreterOptions& options) { + if (options.enable_regex) { + return RegisterRegexExtensionFunctions(registry->InternalGetRegistry(), + /*disable_extract=*/true, + options.regex_max_program_size); + } + return absl::OkStatus(); +} + +CheckerLibrary RegexExtCheckerLibrary() { + return {.id = "cel.lib.ext.regex", .configure = RegisterRegexCheckerDecls}; +} + +CompilerLibrary RegexExtCompilerLibrary() { + return CompilerLibrary::FromCheckerLibrary(RegexExtCheckerLibrary()); +} + +Validation RegexExtValidator() { + return RegexPatternValidator( + /*id=*/"", + {{"regex.extract", 1}, {"regex.extractAll", 1}, {"regex.replace", 1}}); +} + +} // namespace cel::extensions diff --git a/extensions/regex_ext.h b/extensions/regex_ext.h new file mode 100644 index 000000000..7b32aee00 --- /dev/null +++ b/extensions/regex_ext.h @@ -0,0 +1,131 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This extension depends on the CEL optional type. Please ensure that the +// EnableOptionalTypes is called when using regex extensions. +// +// # Replace +// +// The `regex.replace` function replaces all non-overlapping substring of a +// regex pattern in the target string with the given replacement string. +// Optionally, you can limit the number of replacements by providing a count +// argument. When the count is a negative number, the function acts as replace +// all. Only numeric (\N) capture group references are supported in the +// replacement string, with validation for correctness. Backslashed-escaped +// digits (\1 to \9) within the replacement argument can be used to insert text +// matching the corresponding parenthesized group in the regexp pattern. An +// error will be thrown for invalid regex or replace string. +// +// regex.replace(target: string, pattern: string, +// replacement: string) -> string +// regex.replace(target: string, pattern: string, +// replacement: string, count: int) -> string +// +// Examples: +// +// regex.replace('hello world hello', 'hello', 'hi') == 'hi world hi' +// regex.replace('banana', 'a', 'x', 0) == 'banana' +// regex.replace('banana', 'a', 'x', 1) == 'bxnana' +// regex.replace('banana', 'a', 'x', -12) == 'bxnxnx' +// regex.replace('foo bar', '(fo)o (ba)r', r'\2 \1') == 'ba fo' +// regex.replace('test', '(.)', r'\2') \\ Runtime Error invalid replace +// string regex.replace('foo bar', '(', '$2 $1') \\ Runtime Error invalid +// +// # Extract +// +// The `regex.extract` function returns the first match of a regex pattern in a +// string. If no match is found, it returns an optional none value. An error +// will be thrown for invalid regex or for multiple capture groups. +// +// regex.extract(target: string, pattern: string) -> optional +// +// Examples: +// +// regex.extract('item-A, item-B', 'item-(\\w+)') == optional.of('A') +// regex.extract('HELLO', 'hello') == optional.empty() +// regex.extract('testuser@testdomain', '(.*)@([^.]*)') // Runtime Error +// multiple capture group +// +// # Extract All +// +// The `regex.extractAll` function returns a list of all matches of a regex +// pattern in a target string. If no matches are found, it returns an empty +// list. An error will be thrown for invalid regex or for multiple capture +// groups. +// +// regex.extractAll(target: string, pattern: string) -> list +// +// Examples: +// +// regex.extractAll('id:123, id:456', 'id:\\d+') == ['id:123', 'id:456'] +// regex.extractAll('testuser@testdomain', '(.*)@([^.]*)') // Runtime Error +// multiple capture group + +#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_REGEX_EXT_H_ +#define THIRD_PARTY_CEL_CPP_EXTENSIONS_REGEX_EXT_H_ + +#include "absl/status/status.h" +#include "checker/type_checker_builder.h" +#include "compiler/compiler.h" +#include "eval/public/cel_function_registry.h" +#include "eval/public/cel_options.h" +#include "runtime/runtime_builder.h" +#include "validator/validator.h" + +namespace cel::extensions { + +// Register extension functions for regular expressions for +// google::api::expr::runtime::CelValue runtime. +// +// Note: CelValue does not support optional types, so regex.extract is +// unsupported. +absl::Status RegisterRegexExtensionFunctions( + google::api::expr::runtime::CelFunctionRegistry* registry, + const google::api::expr::runtime::InterpreterOptions& options); + +// Register extension functions for regular expressions. +absl::Status RegisterRegexExtensionFunctions(RuntimeBuilder& builder); + +// Type check declarations for the regex extension library. +// Provides decls for the following functions: +// +// regex.replace(target: str, pattern: str, replacement: str) -> str +// +// regex.replace(target: str, pattern: str, replacement: str, count: int) -> str +// +// regex.extract(target: str, pattern: str) -> optional +// +// regex.extractAll(target: str, pattern: str) -> list +CheckerLibrary RegexExtCheckerLibrary(); + +// Provides decls for the following functions: +// +// regex.replace(target: str, pattern: str, replacement: str) -> str +// +// regex.replace(target: str, pattern: str, replacement: str, count: int) -> str +// +// regex.extract(target: str, pattern: str) -> optional +// +// regex.extractAll(target: str, pattern: str) -> list +CompilerLibrary RegexExtCompilerLibrary(); + +// Returns a `Validation` that checks all calls to the CEL regex extension +// functions. +// +// It validates that if the pattern is a literal string, it is a valid regular +// expression. +Validation RegexExtValidator(); + +} // namespace cel::extensions +#endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_REGEX_EXT_H_ diff --git a/extensions/regex_ext_test.cc b/extensions/regex_ext_test.cc new file mode 100644 index 000000000..26d9936aa --- /dev/null +++ b/extensions/regex_ext_test.cc @@ -0,0 +1,541 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/regex_ext.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "checker/standard_library.h" +#include "checker/validation_result.h" +#include "common/kind.h" +#include "common/value.h" +#include "common/value_testing.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "eval/public/activation.h" +#include "eval/public/cel_expr_builder_factory.h" +#include "eval/public/cel_expression.h" +#include "eval/public/cel_function_registry.h" +#include "eval/public/cel_options.h" +#include "extensions/protobuf/runtime_adapter.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "parser/parser.h" +#include "runtime/activation.h" +#include "runtime/optional_types.h" +#include "runtime/reference_resolver.h" +#include "runtime/runtime.h" +#include "runtime/runtime_options.h" +#include "runtime/standard_runtime_builder_factory.h" +#include "validator/validator.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/extension_set.h" + +namespace cel::extensions { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::IsOkAndHolds; +using ::absl_testing::StatusIs; +using ::cel::test::BoolValueIs; +using ::cel::test::ErrorValueIs; +using ::cel::test::OptionalValueIs; +using ::cel::test::OptionalValueIsEmpty; +using ::cel::test::StringValueIs; +using ::google::api::expr::parser::Parse; +using ::google::api::expr::runtime::CelExpressionBuilder; +using ::google::api::expr::runtime::CelFunctionRegistry; +using ::google::api::expr::runtime::CreateCelExpressionBuilder; +using ::google::api::expr::runtime::InterpreterOptions; +using ::testing::HasSubstr; +using ::testing::IsEmpty; +using ::testing::SizeIs; +using ::testing::TestWithParam; +using ::testing::ValuesIn; + +using LegacyActivation = google::api::expr::runtime::Activation; + +TEST(RegexExtTest, BuildFailsWithoutOptionalSupport) { + RuntimeOptions options; + options.enable_regex = true; + options.enable_qualified_type_identifiers = true; + + ASSERT_OK_AND_ASSIGN(auto builder, + CreateStandardRuntimeBuilder( + internal::GetTestingDescriptorPool(), options)); + ASSERT_THAT( + EnableReferenceResolver(builder, ReferenceResolverEnabled::kAlways), + IsOk()); + // Optional types are NOT enabled. + ASSERT_THAT(RegisterRegexExtensionFunctions(builder), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("regex extensions requires the optional types " + "to be enabled"))); +} + +TEST(RegexExtTest, LegacyRuntimeSmokeTest) { + InterpreterOptions options; + options.enable_regex = true; + options.enable_qualified_type_identifiers = true; + options.enable_qualified_identifier_rewrites = true; + + std::unique_ptr builder = CreateCelExpressionBuilder( + internal::GetTestingDescriptorPool(), nullptr, options); + + // Optional types are NOT enabled. + ASSERT_THAT(RegisterRegexExtensionFunctions(builder->GetRegistry(), options), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto expr, + Parse("regex.extractAll('hello world', 'hello (.*)')")); + LegacyActivation activation; + google::protobuf::Arena arena; + ASSERT_OK_AND_ASSIGN(auto program, builder->CreateExpression( + &expr.expr(), &expr.source_info())); + ASSERT_OK_AND_ASSIGN(auto result, program->Evaluate(activation, &arena)); + ASSERT_TRUE(result.IsList()); + ASSERT_EQ(result.ListOrDie()->size(), 1); + ASSERT_TRUE(result.ListOrDie()->Get(&arena, 0).IsString()); + EXPECT_EQ(result.ListOrDie()->Get(&arena, 0).StringOrDie().value(), "world"); +} + +TEST(RegexExtTest, DoesNotRegisterExtractForLegacy) { + InterpreterOptions options; + options.enable_regex = true; + + CelFunctionRegistry registry; + // Optional types are not usable in legacy runtime, so extract should not be + // registered. + ASSERT_THAT(RegisterRegexExtensionFunctions(®istry, options), IsOk()); + EXPECT_THAT( + registry.FindStaticOverloads("regex.extract", false, + {cel::Kind::kString, cel::Kind::kString}), + IsEmpty()); + EXPECT_THAT( + registry.FindStaticOverloads("regex.extractAll", false, + {cel::Kind::kString, cel::Kind::kString}), + SizeIs(1)); + EXPECT_THAT(registry.FindStaticOverloads( + "regex.replace", false, + {cel::Kind::kString, cel::Kind::kString, cel::Kind::kString}), + SizeIs(1)); + EXPECT_THAT( + registry.FindStaticOverloads("regex.replace", false, + {cel::Kind::kString, cel::Kind::kString, + cel::Kind::kString, cel::Kind::kInt64}), + SizeIs(1)); +} + +TEST(RegexExtTest, FollowsRegexOption) { + InterpreterOptions options; + options.enable_regex = false; + + CelFunctionRegistry registry; + ASSERT_THAT(RegisterRegexExtensionFunctions(®istry, options), IsOk()); + EXPECT_THAT( + registry.FindStaticOverloads("regex.extract", false, + {cel::Kind::kString, cel::Kind::kString}), + IsEmpty()); + EXPECT_THAT( + registry.FindStaticOverloads("regex.extractAll", false, + {cel::Kind::kString, cel::Kind::kString}), + IsEmpty()); + EXPECT_THAT(registry.FindStaticOverloads( + "regex.replace", false, + {cel::Kind::kString, cel::Kind::kString, cel::Kind::kString}), + IsEmpty()); + EXPECT_THAT( + registry.FindStaticOverloads("regex.replace", false, + {cel::Kind::kString, cel::Kind::kString, + cel::Kind::kString, cel::Kind::kInt64}), + IsEmpty()); +} + +enum class EvaluationType { + kBoolTrue, + kOptionalValue, + kOptionalNone, + kRuntimeError, + kUnknownStaticError, + kInvalidArgStaticError +}; + +struct RegexExtTestCase { + EvaluationType evaluation_type; + std::string expr; + std::string expected_result = ""; +}; + +class RegexExtTest : public TestWithParam { + public: + void SetUp() override { + RuntimeOptions options; + options.enable_regex = true; + options.enable_qualified_type_identifiers = true; + + ASSERT_OK_AND_ASSIGN(auto builder, + CreateStandardRuntimeBuilder( + internal::GetTestingDescriptorPool(), options)); + ASSERT_THAT( + EnableReferenceResolver(builder, ReferenceResolverEnabled::kAlways), + IsOk()); + ASSERT_THAT(EnableOptionalTypes(builder), IsOk()); + ASSERT_THAT(RegisterRegexExtensionFunctions(builder), IsOk()); + ASSERT_OK_AND_ASSIGN(runtime_, std::move(builder).Build()); + } + + absl::StatusOr TestEvaluate(const std::string& expr_string) { + CEL_ASSIGN_OR_RETURN(auto parsed_expr, Parse(expr_string)); + CEL_ASSIGN_OR_RETURN(std::unique_ptr program, + cel::extensions::ProtobufRuntimeAdapter::CreateProgram( + *runtime_, parsed_expr)); + Activation activation; + return program->Evaluate(&arena_, activation); + } + + google::protobuf::Arena arena_; + std::unique_ptr runtime_; +}; + +std::vector regexTestCases() { + return { + // Tests for extract Function + {EvaluationType::kOptionalValue, + R"(regex.extract('hello world', 'hello (.*)'))", "world"}, + {EvaluationType::kOptionalValue, + R"(regex.extract('item-A, item-B', r'item-(\w+)'))", "A"}, + {EvaluationType::kOptionalValue, + R"(regex.extract('The color is red', r'The color is (\w+)'))", "red"}, + {EvaluationType::kOptionalValue, + R"(regex.extract('The color is red', r'The color is \w+'))", + "The color is red"}, + {EvaluationType::kOptionalValue, "regex.extract('brand', 'brand')", + "brand"}, + {EvaluationType::kOptionalNone, + "regex.extract('hello world', 'goodbye (.*)')"}, + {EvaluationType::kOptionalNone, "regex.extract('HELLO', 'hello')"}, + {EvaluationType::kOptionalNone, R"(regex.extract('', r'\w+'))"}, + {EvaluationType::kBoolTrue, + "regex.extract('4122345432', '22').orValue('777') == '22'"}, + {EvaluationType::kBoolTrue, + "regex.extract('4122345432', '22').or(optional.of('777')) == " + "optional.of('22')"}, + + // Tests for extractAll Function + {EvaluationType::kBoolTrue, + "regex.extractAll('id:123, id:456', 'assa') == []"}, + {EvaluationType::kBoolTrue, + R"(regex.extractAll('id:123, id:456', r'id:\d+') == ['id:123','id:456'])"}, + {EvaluationType::kBoolTrue, + R"(regex.extractAll('Files: f_1.txt, f_2.csv', r'f_(\d+)')==['1','2'])"}, + {EvaluationType::kBoolTrue, + R"(regex.extractAll('testuser@', '(?P.*)@') == ['testuser'])"}, + {EvaluationType::kBoolTrue, + R"cel(regex.extractAll('t@gmail.com, a@y.com, 22@sdad.com', + '(?P.*)@') == ['t@gmail.com, a@y.com, 22'])cel"}, + {EvaluationType::kBoolTrue, + R"cel(regex.extractAll('t@gmail.com, a@y.com, 22@sdad.com', + r'(?P\w+)@') == ['t','a', '22'])cel"}, + {EvaluationType::kBoolTrue, + "regex.extractAll('banananana', '(ana)') == ['ana', 'ana']"}, + {EvaluationType::kBoolTrue, + R"(regex.extractAll('item:a1, topic:b2', + r'(?:item:|topic:)([a-z]\d)') == ['a1', 'b2'])"}, + {EvaluationType::kBoolTrue, + R"(regex.extractAll('val=a, val=, val=c', 'val=([^,]*)')==['a','c'])"}, + {EvaluationType::kBoolTrue, + "regex.extractAll('key=, key=, key=', 'key=([^,]*)') == []"}, + {EvaluationType::kBoolTrue, + R"(regex.extractAll('a b c', r'(\S*)\s*') == ['a', 'b', 'c'])"}, + {EvaluationType::kBoolTrue, + "regex.extractAll('abc', 'a|b*') == ['a','b']"}, + {EvaluationType::kBoolTrue, + "regex.extractAll('abc', 'a|(b)|c*') == ['b']"}, + + // Tests for replace Function + {EvaluationType::kBoolTrue, + "regex.replace('abc', '$', '_end') == 'abc_end'"}, + {EvaluationType::kBoolTrue, + R"(regex.replace('a-b', r'\b', '|') == '|a|-|b|')"}, + {EvaluationType::kBoolTrue, + R"(regex.replace('foo bar', '(fo)o (ba)r', r'\2 \1') == 'ba fo')"}, + {EvaluationType::kBoolTrue, + R"(regex.replace('foo bar', 'foo', r'\\') == '\\ bar')"}, + {EvaluationType::kBoolTrue, + "regex.replace('banana', 'ana', 'x') == 'bxna'"}, + {EvaluationType::kBoolTrue, + R"(regex.replace('abc', 'b(.)', r'x\1') == 'axc')"}, + {EvaluationType::kBoolTrue, + "regex.replace('hello world hello', 'hello', 'hi') == 'hi world hi'"}, + {EvaluationType::kBoolTrue, + R"(regex.replace('ac', 'a(b)?c', r'[\1]') == '[]')"}, + {EvaluationType::kBoolTrue, + "regex.replace('apple pie', 'p', 'X') == 'aXXle Xie'"}, + {EvaluationType::kBoolTrue, + R"(regex.replace('remove all spaces', r'\s', '') == + 'removeallspaces')"}, + {EvaluationType::kBoolTrue, + R"(regex.replace('digit:99919291992', r'\d+', '3') == 'digit:3')"}, + {EvaluationType::kBoolTrue, + R"cel(regex.replace('foo bar baz', r'\w+', r'(\0)') == + '(foo) (bar) (baz)')cel"}, + {EvaluationType::kBoolTrue, "regex.replace('', 'a', 'b') == ''"}, + {EvaluationType::kBoolTrue, + R"cel(regex.replace('User: Alice, Age: 30', + r'User: (?P\w+), Age: (?P\d+)', + '${name} is ${age} years old') == '${name} is ${age} years old')cel"}, + {EvaluationType::kBoolTrue, + R"cel(regex.replace('User: Alice, Age: 30', + r'User: (?P\w+), Age: (?P\d+)', r'\1 is \2 years old') == + 'Alice is 30 years old')cel"}, + {EvaluationType::kBoolTrue, + "regex.replace('hello ☃', '☃', '❄') == 'hello ❄'"}, + {EvaluationType::kBoolTrue, + R"(regex.replace('id=123', r'id=(?P\d+)', r'value: \1') == + 'value: 123')"}, + {EvaluationType::kBoolTrue, + "regex.replace('banana', 'a', 'x') == 'bxnxnx'"}, + {EvaluationType::kBoolTrue, + R"(regex.replace(regex.replace('%(foo) %(bar) %2', r'%\((\w+)\)', + r'${\1}'),r'%(\d+)', r'$\1') == '${foo} ${bar} $2')"}, + {EvaluationType::kBoolTrue, + R"(regex.replace('abc def', r'(abc)', r'\\1') == r'\1 def')"}, + {EvaluationType::kBoolTrue, + R"(regex.replace('abc def', r'(abc)', r'\\2') == r'\2 def')"}, + {EvaluationType::kBoolTrue, + R"(regex.replace('abc def', r'(abc)', r'\\{word}') == '\\{word} def')"}, + {EvaluationType::kBoolTrue, + R"(regex.replace('abc def', r'(abc)', r'\\word') == '\\word def')"}, + {EvaluationType::kBoolTrue, + "regex.replace('abc', '^', 'start_') == 'start_abc'"}, + + // Tests for replace Function with count variable + {EvaluationType::kBoolTrue, + R"(regex.replace('foofoo', 'foo', 'bar', + 9223372036854775807) == 'barbar')"}, + {EvaluationType::kBoolTrue, + "regex.replace('banana', 'a', 'x', 0) == 'banana'"}, + {EvaluationType::kBoolTrue, + "regex.replace('banana', 'a', 'x', 1) == 'bxnana'"}, + {EvaluationType::kBoolTrue, + "regex.replace('banana', 'a', 'x', 2) == 'bxnxna'"}, + {EvaluationType::kBoolTrue, + "regex.replace('banana', 'a', 'x', 100) == 'bxnxnx'"}, + {EvaluationType::kBoolTrue, + "regex.replace('banana', 'a', 'x', -1) == 'bxnxnx'"}, + {EvaluationType::kBoolTrue, + "regex.replace('banana', 'a', 'x', -100) == 'bxnxnx'"}, + {EvaluationType::kBoolTrue, + R"cel(regex.replace('cat-dog dog-cat cat-dog dog-cat', '(cat)-(dog)', + r'\2-\1', 1) == 'dog-cat dog-cat cat-dog dog-cat')cel"}, + {EvaluationType::kBoolTrue, + R"cel(regex.replace('cat-dog dog-cat cat-dog dog-cat', '(cat)-(dog)', + r'\2-\1', 2) == 'dog-cat dog-cat dog-cat dog-cat')cel"}, + {EvaluationType::kBoolTrue, + R"(regex.replace('a.b.c', r'\.', '-', 1) == 'a-b.c')"}, + {EvaluationType::kBoolTrue, + R"(regex.replace('a.b.c', r'\.', '-', -1) == 'a-b-c')"}, + {EvaluationType::kBoolTrue, + R"(regex.replace('123456789ABC', + '(\\d)(\\d)(\\d)(\\d)(\\d)(\\d)(\\d)(\\d)(\\d)(\\w)(\\w)(\\w)','X', 1) + == 'X')"}, + {EvaluationType::kBoolTrue, + R"(regex.replace('123456789ABC', + '(\\d)(\\d)(\\d)(\\d)(\\d)(\\d)(\\d)(\\d)(\\d)(\\w)(\\w)(\\w)', + r'\1-\9-X', 1) == '1-9-X')"}, + + // Static Errors + {EvaluationType::kUnknownStaticError, "regex.replace('abc', '^', 1)", + "No matching overloads found : regex.replace(string, string, int64)"}, + {EvaluationType::kUnknownStaticError, "regex.replace('abc', '^', '1','')", + "No matching overloads found : regex.replace(string, string, string, " + "string)"}, + {EvaluationType::kUnknownStaticError, "regex.extract('foo bar', 1)", + "No matching overloads found : regex.extract(string, int64)"}, + {EvaluationType::kInvalidArgStaticError, + "regex.extract('foo bar', 1, 'bar')", + "No overload found in reference resolve step for extract"}, + {EvaluationType::kInvalidArgStaticError, "regex.extractAll()", + "No overload found in reference resolve step for extractAll"}, + + // Runtime Errors + {EvaluationType::kRuntimeError, R"(regex.extract('foo', 'fo(o+)(abc'))", + "invalid regular expression: missing ): fo(o+)(abc"}, + {EvaluationType::kRuntimeError, R"(regex.extractAll('foo bar', '[a-z'))", + "invalid regular expression: missing ]: [a-z"}, + {EvaluationType::kRuntimeError, + R"(regex.replace('foo bar', '[a-z', 'a'))", + "invalid regular expression: missing ]: [a-z"}, + {EvaluationType::kRuntimeError, + R"(regex.replace('foo bar', '[a-z', 'a', 1))", + "invalid regular expression: missing ]: [a-z"}, + {EvaluationType::kRuntimeError, + R"(regex.replace('id=123', r'id=(?P\d+)', r'value: \values'))", + R"(invalid replacement string: Rewrite schema error: '\' must be followed by a digit or '\'.)"}, + {EvaluationType::kRuntimeError, R"(regex.replace('test', '(t)', '\\2'))", + "invalid replacement string: Rewrite schema requests 2 matches, but " + "the regexp only has 1 parenthesized subexpressions"}, + {EvaluationType::kRuntimeError, + R"(regex.replace('id=123', r'id=(?P\d+)', '\\', 1))", + R"(invalid replacement string: Rewrite schema error: '\' not allowed at end.)"}, + {EvaluationType::kRuntimeError, + R"(regex.extract('phone: 415-5551212', r'phone: ((\d{3})-)?'))", + R"(regular expression has more than one capturing group: phone: ((\d{3})-)?)"}, + {EvaluationType::kRuntimeError, + R"(regex.extractAll('testuser@testdomain', '(.*)@([^.]*)'))", + R"(regular expression has more than one capturing group: (.*)@([^.]*))"}, + }; +} + +TEST_P(RegexExtTest, RegexExtTests) { + const RegexExtTestCase& test_case = GetParam(); + auto result = TestEvaluate(test_case.expr); + + switch (test_case.evaluation_type) { + case EvaluationType::kRuntimeError: + EXPECT_THAT(result, IsOkAndHolds(ErrorValueIs( + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr(test_case.expected_result))))) + << "Expression: " << test_case.expr; + break; + case EvaluationType::kUnknownStaticError: + EXPECT_THAT(result, IsOkAndHolds(ErrorValueIs( + StatusIs(absl::StatusCode::kUnknown, + HasSubstr(test_case.expected_result))))) + << "Expression: " << test_case.expr; + break; + case EvaluationType::kInvalidArgStaticError: + EXPECT_THAT(result, StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr(test_case.expected_result))) + << "Expression: " << test_case.expr; + break; + case EvaluationType::kOptionalNone: + EXPECT_THAT(result, IsOkAndHolds(OptionalValueIsEmpty())) + << "Expression: " << test_case.expr; + break; + case EvaluationType::kOptionalValue: + EXPECT_THAT(result, IsOkAndHolds(OptionalValueIs( + StringValueIs(test_case.expected_result)))) + << "Expression: " << test_case.expr; + break; + case EvaluationType::kBoolTrue: + EXPECT_THAT(result, IsOkAndHolds(BoolValueIs(true))) + << "Expression: " << test_case.expr; + break; + } +} + +INSTANTIATE_TEST_SUITE_P(RegexExtTest, RegexExtTest, + ValuesIn(regexTestCases())); + +struct RegexCheckerTestCase { + std::string expr_string; + std::string error_substr; +}; + +class RegexExtCheckerLibraryTest : public TestWithParam { + public: + void SetUp() override { + // Arrange: Configure the compiler. + // Add the regex checker library to the compiler builder. + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler_builder, + NewCompilerBuilder(descriptor_pool_)); + ASSERT_THAT(compiler_builder->AddLibrary(StandardCheckerLibrary()), IsOk()); + ASSERT_THAT(compiler_builder->AddLibrary(RegexExtCompilerLibrary()), + IsOk()); + ASSERT_OK_AND_ASSIGN(compiler_, std::move(*compiler_builder).Build()); + } + + const google::protobuf::DescriptorPool* descriptor_pool_ = + internal::GetTestingDescriptorPool(); + std::unique_ptr compiler_; +}; + +TEST_P(RegexExtCheckerLibraryTest, RegexExtTypeCheckerTests) { + // Act & Assert: Compile the expression and validate the result. + ASSERT_OK_AND_ASSIGN(ValidationResult result, + compiler_->Compile(GetParam().expr_string)); + absl::string_view error_substr = GetParam().error_substr; + EXPECT_EQ(result.IsValid(), error_substr.empty()); + + if (!error_substr.empty()) { + EXPECT_THAT(result.FormatError(), HasSubstr(error_substr)); + } +} + +std::vector createRegexCheckerParams() { + return { + {"regex.replace('abc', 'a', 's') == 'sbc'"}, + {"regex.replace('abc', 'a', 's') == 121", + "found no matching overload for '_==_' applied to '(string, int)"}, + {"regex.replace('abc', 'j', '1', 2) == 9.0", + "found no matching overload for '_==_' applied to '(string, double)"}, + {"regex.extractAll('banananana', '(ana)') == ['ana', 'ana']"}, + {"regex.extract('foo bar', 'f') == 121", + "found no matching overload for '_==_' applied to " + "'(optional_type(string), int)'"}, + }; +} + +INSTANTIATE_TEST_SUITE_P(RegexExtCheckerLibraryTest, RegexExtCheckerLibraryTest, + ValuesIn(createRegexCheckerParams())); + +absl::StatusOr> CreateRegexExtCompiler() { + CEL_ASSIGN_OR_RETURN( + auto builder, NewCompilerBuilder(internal::GetTestingDescriptorPool())); + CEL_RETURN_IF_ERROR(builder->AddLibrary(StandardCheckerLibrary())); + CEL_RETURN_IF_ERROR(builder->AddLibrary(RegexExtCompilerLibrary())); + return std::move(*builder).Build(); +} + +class RegexExtValidatorTest : public TestWithParam {}; + +TEST_P(RegexExtValidatorTest, Basic) { + ASSERT_OK_AND_ASSIGN(auto compiler, CreateRegexExtCompiler()); + + Validator validator; + validator.AddValidation(RegexExtValidator()); + + ASSERT_OK_AND_ASSIGN(auto result, compiler->Compile(GetParam().expr_string)); + validator.UpdateValidationResult(result); + + EXPECT_EQ(result.IsValid(), GetParam().error_substr.empty()) + << "Expression: " << GetParam().expr_string; + if (!GetParam().error_substr.empty()) { + EXPECT_THAT(result.FormatError(), HasSubstr(GetParam().error_substr)); + } +} + +INSTANTIATE_TEST_SUITE_P(RegexExtValidatorTest, RegexExtValidatorTest, + testing::ValuesIn(std::vector{ + {"regex.extract('hello world', 'hello (.*)')"}, + {"regex.extract('hello world', 'hello ([') ", + "invalid regular expression"}, + {"regex.extractAll('hello world', 'hello (.*)')"}, + {"regex.extractAll('hello world', 'hello ([') ", + "invalid regular expression"}, + {"regex.replace('hello world', 'hello', 'hi')"}, + {"regex.replace('hello world', 'he([', 'hi') ", + "invalid regular expression"}, + })); +} // namespace +} // namespace cel::extensions diff --git a/extensions/regex_functions.cc b/extensions/regex_functions.cc new file mode 100644 index 000000000..005987ae4 --- /dev/null +++ b/extensions/regex_functions.cc @@ -0,0 +1,237 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/regex_functions.h" + +#include +#include +#include +#include + +#include "absl/base/no_destructor.h" +#include "absl/base/nullability.h" +#include "absl/functional/bind_front.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "checker/internal/builtins_arena.h" +#include "checker/type_checker_builder.h" +#include "common/decl.h" +#include "common/type.h" +#include "common/value.h" +#include "eval/public/cel_function_registry.h" +#include "eval/public/cel_options.h" +#include "internal/re2_options.h" +#include "internal/status_macros.h" +#include "runtime/function_adapter.h" +#include "runtime/function_registry.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" +#include "re2/re2.h" + +namespace cel::extensions { +namespace { + +using ::cel::checker_internal::BuiltinsArena; +using ::google::api::expr::runtime::CelFunctionRegistry; +using ::google::api::expr::runtime::InterpreterOptions; + +// Extract matched group values from the given target string and rewrite the +// string +Value ExtractString(int regex_max_program_size, const StringValue& target, + const StringValue& regex, const StringValue& rewrite, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + std::string regex_scratch; + std::string target_scratch; + std::string rewrite_scratch; + absl::string_view regex_view = regex.ToStringView(®ex_scratch); + absl::string_view target_view = target.ToStringView(&target_scratch); + absl::string_view rewrite_view = rewrite.ToStringView(&rewrite_scratch); + + RE2 re2(regex_view, cel::internal::MakeRE2Options()); + CEL_RETURN_IF_ERROR(cel::internal::CheckRE2(re2, regex_max_program_size)) + .With(ErrorValueReturn()); + std::string output; + bool result = RE2::Extract(target_view, re2, rewrite_view, &output); + if (!result) { + return ErrorValue(absl::InvalidArgumentError( + "Unable to extract string for the given regex")); + } + return StringValue::From(std::move(output), arena); +} + +// Captures the first unnamed/named group value +// NOTE: For capturing all the groups, use CaptureStringN instead +Value CaptureString(int regex_max_program_size, const StringValue& target, + const StringValue& regex, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + std::string regex_scratch; + std::string target_scratch; + absl::string_view regex_view = regex.ToStringView(®ex_scratch); + absl::string_view target_view = target.ToStringView(&target_scratch); + RE2 re2(regex_view, cel::internal::MakeRE2Options()); + CEL_RETURN_IF_ERROR(cel::internal::CheckRE2(re2, regex_max_program_size)) + .With(ErrorValueReturn()); + std::string output; + bool result = RE2::FullMatch(target_view, re2, &output); + if (!result) { + return ErrorValue(absl::InvalidArgumentError( + "Unable to capture groups for the given regex")); + } else { + return StringValue::From(std::move(output), arena); + } +} + +// Does a FullMatchN on the given string and regex and returns a map with pairs as follows: +// a. For a named group - +// b. For an unnamed group - +absl::StatusOr CaptureStringN( + int regex_max_program_size, const StringValue& target, + const StringValue& regex, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + std::string target_scratch; + std::string regex_scratch; + absl::string_view target_view = target.ToStringView(&target_scratch); + absl::string_view regex_view = regex.ToStringView(®ex_scratch); + RE2 re2(regex_view, cel::internal::MakeRE2Options()); + CEL_RETURN_IF_ERROR(cel::internal::CheckRE2(re2, regex_max_program_size)) + .With(ErrorValueReturn()); + const int capturing_groups_count = re2.NumberOfCapturingGroups(); + const auto& named_capturing_groups_map = re2.CapturingGroupNames(); + if (capturing_groups_count <= 0) { + return ErrorValue(absl::InvalidArgumentError( + "Capturing groups were not found in the given regex.")); + } + std::vector captured_strings(capturing_groups_count); + std::vector captured_string_addresses(capturing_groups_count); + std::vector argv(capturing_groups_count); + for (int j = 0; j < capturing_groups_count; j++) { + captured_string_addresses[j] = &captured_strings[j]; + argv[j] = &captured_string_addresses[j]; + } + bool result = + RE2::FullMatchN(target_view, re2, argv.data(), capturing_groups_count); + if (!result) { + return ErrorValue(absl::InvalidArgumentError( + "Unable to capture groups for the given regex")); + } + auto builder = cel::NewMapValueBuilder(arena); + builder->Reserve(capturing_groups_count); + for (int index = 1; index <= capturing_groups_count; index++) { + auto it = named_capturing_groups_map.find(index); + std::string name = it != named_capturing_groups_map.end() + ? it->second + : std::to_string(index); + CEL_RETURN_IF_ERROR(builder->Put( + StringValue::From(std::move(name), arena), + StringValue::From(std::move(captured_strings[index - 1]), arena))); + } + return std::move(*builder).Build(); +} + +absl::Status RegisterRegexFunctions(FunctionRegistry& registry, + int max_regex_program_size) { + // Register Regex Extract Function + CEL_RETURN_IF_ERROR( + (TernaryFunctionAdapter< + absl::StatusOr, StringValue, StringValue, + StringValue>::RegisterGlobalOverload(kRegexExtract, + absl::bind_front( + &ExtractString, + max_regex_program_size), + registry))); + + // Register Regex Captures Function + CEL_RETURN_IF_ERROR( + (BinaryFunctionAdapter, StringValue, StringValue>:: + RegisterGlobalOverload( + kRegexCapture, + absl::bind_front(&CaptureString, max_regex_program_size), + registry))); + + // Register Regex CaptureN Function + CEL_RETURN_IF_ERROR( + (BinaryFunctionAdapter, StringValue, StringValue>:: + RegisterGlobalOverload( + kRegexCaptureN, + absl::bind_front(&CaptureStringN, max_regex_program_size), + registry))); + return absl::OkStatus(); +} + +const Type& CaptureNMapType() { + static absl::NoDestructor kInstance( + MapType(BuiltinsArena(), StringType(), StringType())); + return *kInstance; +} + +absl::Status RegisterRegexDecls(TypeCheckerBuilder& builder) { + CEL_ASSIGN_OR_RETURN( + FunctionDecl regex_extract_decl, + MakeFunctionDecl( + std::string(kRegexExtract), + MakeOverloadDecl("re_extract_string_string_string", StringType(), + StringType(), StringType(), StringType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(regex_extract_decl)); + + CEL_ASSIGN_OR_RETURN( + FunctionDecl regex_capture_decl, + MakeFunctionDecl( + std::string(kRegexCapture), + MakeOverloadDecl("re_capture_string_string", StringType(), + StringType(), StringType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(regex_capture_decl)); + + CEL_ASSIGN_OR_RETURN( + FunctionDecl regex_capture_n_decl, + MakeFunctionDecl( + std::string(kRegexCaptureN), + MakeOverloadDecl("re_captureN_string_string", CaptureNMapType(), + StringType(), StringType()))); + return builder.AddFunction(regex_capture_n_decl); +} + +} // namespace + +absl::Status RegisterRegexFunctions(FunctionRegistry& registry, + const RuntimeOptions& options) { + if (options.enable_regex) { + CEL_RETURN_IF_ERROR( + RegisterRegexFunctions(registry, options.regex_max_program_size)); + } + return absl::OkStatus(); +} + +absl::Status RegisterRegexFunctions(CelFunctionRegistry* registry, + const InterpreterOptions& options) { + CEL_RETURN_IF_ERROR(RegisterRegexFunctions( + registry->InternalGetRegistry(), + google::api::expr::runtime::ConvertToRuntimeOptions(options))); + return absl::OkStatus(); +} + +CheckerLibrary RegexCheckerLibrary() { + return {.id = "cpp_regex", .configure = RegisterRegexDecls}; +} + +} // namespace cel::extensions diff --git a/extensions/regex_functions.h b/extensions/regex_functions.h new file mode 100644 index 000000000..62c83ebdd --- /dev/null +++ b/extensions/regex_functions.h @@ -0,0 +1,52 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Definitions for extension functions wrapping C++ RE2 APIs. These are +// only defined for the C++ CEL library and distinct from the regex +// extension library (supported by other implementations). + +#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_REGEX_FUNCTIONS_H_ +#define THIRD_PARTY_CEL_CPP_EXTENSIONS_REGEX_FUNCTIONS_H_ + +#include "absl/base/attributes.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "checker/type_checker_builder.h" +#include "eval/public/cel_function_registry.h" +#include "eval/public/cel_options.h" +#include "runtime/function_registry.h" +#include "runtime/runtime_options.h" + +namespace cel::extensions { + +inline constexpr absl::string_view kRegexExtract = "re.extract"; +inline constexpr absl::string_view kRegexCapture = "re.capture"; +inline constexpr absl::string_view kRegexCaptureN = "re.captureN"; + +// Register Extract and Capture Functions for RE2 +// Requires options.enable_regex to be true +// The canonical regex extensions supported by the CEL team are registered +// via the `RegisterRegexExtensionsFunctions`. This extension is deprecated. +ABSL_DEPRECATED("Use RegisterRegexExtensionsFunctions instead.") +absl::Status RegisterRegexFunctions( + google::api::expr::runtime::CelFunctionRegistry* registry, + const google::api::expr::runtime::InterpreterOptions& options); +absl::Status RegisterRegexFunctions(FunctionRegistry& registry, + const RuntimeOptions& options); + +// Declarations for the regex extension library. +CheckerLibrary RegexCheckerLibrary(); + +} // namespace cel::extensions +#endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_REGEX_FUNCTIONS_H_ diff --git a/extensions/regex_functions_test.cc b/extensions/regex_functions_test.cc new file mode 100644 index 000000000..92a4da6bb --- /dev/null +++ b/extensions/regex_functions_test.cc @@ -0,0 +1,296 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/regex_functions.h" + +#include +#include +#include +#include + +#include "absl/log/absl_log.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "checker/standard_library.h" +#include "checker/validation_result.h" +#include "common/value.h" +#include "common/value_testing.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "extensions/protobuf/runtime_adapter.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "parser/parser.h" +#include "runtime/activation.h" +#include "runtime/reference_resolver.h" +#include "runtime/runtime.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_options.h" +#include "runtime/standard_runtime_builder_factory.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/extension_set.h" +#include "google/protobuf/message.h" + +namespace cel::extensions { + +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::IsOkAndHolds; +using ::absl_testing::StatusIs; +using ::cel::test::ErrorValueIs; +using ::cel::test::MapValueElements; +using ::cel::test::MapValueIs; +using ::cel::test::StringValueIs; +using ::google::api::expr::parser::Parse; +using ::testing::HasSubstr; +using ::testing::UnorderedElementsAre; +using ::testing::ValuesIn; + +struct TestCase { + const std::string expr_string; + const std::string expected_result; +}; + +class RegexFunctionsTest : public ::testing::TestWithParam { + public: + void SetUp() override { + RuntimeOptions options; + options.enable_regex = true; + options.enable_qualified_type_identifiers = true; + + ASSERT_OK_AND_ASSIGN( + RuntimeBuilder builder, + CreateStandardRuntimeBuilder(descriptor_pool_, options)); + ASSERT_THAT( + EnableReferenceResolver(builder, ReferenceResolverEnabled::kAlways), + IsOk()); + ASSERT_THAT(RegisterRegexFunctions(builder.function_registry(), options), + IsOk()); + ASSERT_OK_AND_ASSIGN(runtime_, std::move(builder).Build()); + } + + absl::StatusOr TestEvaluate(const std::string& expr_string) { + CEL_ASSIGN_OR_RETURN(auto parsed_expr, Parse(expr_string)); + CEL_ASSIGN_OR_RETURN(std::unique_ptr program, + cel::extensions::ProtobufRuntimeAdapter::CreateProgram( + *runtime_, parsed_expr)); + Activation activation; + return program->Evaluate(&arena_, activation); + } + + const google::protobuf::DescriptorPool* descriptor_pool_ = + internal::GetTestingDescriptorPool(); + google::protobuf::MessageFactory* message_factory_ = + google::protobuf::MessageFactory::generated_factory(); + google::protobuf::Arena arena_; + std::unique_ptr runtime_; +}; + +TEST_F(RegexFunctionsTest, CaptureStringSuccessWithCombinationOfGroups) { + // combination of named and unnamed groups should return a celmap + EXPECT_THAT( + TestEvaluate((R"cel( + re.captureN( + 'The user testuser belongs to testdomain', + 'The (user|domain) (?P.*) belongs to (?P.*)' + ) + )cel")), + IsOkAndHolds(MapValueIs(MapValueElements( + UnorderedElementsAre( + Pair(StringValueIs("1"), StringValueIs("user")), + Pair(StringValueIs("Username"), StringValueIs("testuser")), + Pair(StringValueIs("Domain"), StringValueIs("testdomain"))), + descriptor_pool_, message_factory_, &arena_)))); +} + +TEST_F(RegexFunctionsTest, CaptureStringSuccessWithSingleNamedGroup) { + // Regex containing one named group should return a map + EXPECT_THAT( + TestEvaluate(R"cel(re.captureN('testuser@', '(?P.*)@'))cel"), + IsOkAndHolds(MapValueIs(MapValueElements( + UnorderedElementsAre( + Pair(StringValueIs("username"), StringValueIs("testuser"))), + descriptor_pool_, message_factory_, &arena_)))); +} + +TEST_F(RegexFunctionsTest, CaptureStringSuccessWithMultipleUnamedGroups) { + // Regex containing all unnamed groups should return a map + EXPECT_THAT( + TestEvaluate( + R"cel(re.captureN('testuser@testdomain', '(.*)@([^.]*)'))cel"), + IsOkAndHolds(MapValueIs(MapValueElements( + UnorderedElementsAre( + Pair(StringValueIs("1"), StringValueIs("testuser")), + Pair(StringValueIs("2"), StringValueIs("testdomain"))), + descriptor_pool_, message_factory_, &arena_)))); +} + +// Extract String: Extract named and unnamed strings +TEST_F(RegexFunctionsTest, ExtractStringWithNamedAndUnnamedGroups) { + EXPECT_THAT(TestEvaluate(R"cel( + re.extract( + 'The user testuser belongs to testdomain', + 'The (user|domain) (?P.*) belongs to (?P.*)', + '\\3 contains \\1 \\2') + )cel"), + IsOkAndHolds(StringValueIs("testdomain contains user testuser"))); +} + +// Extract String: Extract with empty strings +TEST_F(RegexFunctionsTest, ExtractStringWithEmptyStrings) { + EXPECT_THAT(TestEvaluate(R"cel(re.extract('', '', ''))cel"), + IsOkAndHolds(StringValueIs(""))); +} + +// Extract String: Extract unnamed strings +TEST_F(RegexFunctionsTest, ExtractStringWithUnnamedGroups) { + EXPECT_THAT(TestEvaluate(R"cel( + re.extract('testuser@google.com', '(.*)@([^.]*)', '\\2!\\1') + )cel"), + IsOkAndHolds(StringValueIs("google!testuser"))); +} + +// Extract String: Extract string with no captured groups +TEST_F(RegexFunctionsTest, ExtractStringWithNoGroups) { + EXPECT_THAT(TestEvaluate(R"cel(re.extract('foo', '.*', '\'\\0\''))cel"), + IsOkAndHolds(StringValueIs("'foo'"))); +} + +// Capture String: Success with matching unnamed group +TEST_F(RegexFunctionsTest, CaptureStringWithUnnamedGroups) { + EXPECT_THAT(TestEvaluate(R"cel(re.capture('foo', 'fo(o)'))cel"), + IsOkAndHolds(StringValueIs("o"))); +} + +std::vector createParams() { + return { + {// Extract String: Fails for mismatched regex + (R"(re.extract('foo', 'f(o+)(s)', '\\1\\2'))"), + "Unable to extract string for the given regex"}, + {// Extract String: Fails when rewritten string has too many placeholders + (R"(re.extract('foo', 'f(o+)', '\\1\\2'))"), + "Unable to extract string for the given regex"}, + {// Extract String: Fails when invalid regular expression + (R"(re.extract('foo', 'f(o+)(abc', '\\1\\2'))"), + "invalid regular expression"}, + {// Capture String: Empty regex + (R"(re.capture('foo', ''))"), + "Unable to capture groups for the given regex"}, + {// Capture String: No Capturing groups + (R"(re.capture('foo', '.*'))"), + "Unable to capture groups for the given regex"}, + {// Capture String: Mismatched String + (R"(re.capture('', 'bar'))"), + "Unable to capture groups for the given regex"}, + {// Capture String: Mismatched groups + (R"(re.capture('foo', 'fo(o+)(s)'))"), + "Unable to capture groups for the given regex"}, + {// Capture String: invalid regular expression + (R"(re.capture('foo', 'fo(o+)(abc'))"), "invalid regular expression"}, + {// Capture String N: Empty regex + (R"(re.captureN('foo', ''))"), + "Capturing groups were not found in the given regex."}, + {// Capture String N: No Capturing groups + (R"(re.captureN('foo', '.*'))"), + "Capturing groups were not found in the given regex."}, + {// Capture String N: Mismatched String + (R"(re.captureN('', 'bar'))"), + "Capturing groups were not found in the given regex."}, + {// Capture String N: Mismatched groups + (R"(re.captureN('foo', 'fo(o+)(s)'))"), + "Unable to capture groups for the given regex"}, + {// Capture String N: invalid regular expression + (R"(re.captureN('foo', 'fo(o+)(abc'))"), "invalid regular expression"}, + }; +} + +TEST_P(RegexFunctionsTest, RegexFunctionsTests) { + const TestCase& test_case = GetParam(); + ABSL_LOG(INFO) << "Testing Cel Expression: " << test_case.expr_string; + EXPECT_THAT(TestEvaluate(test_case.expr_string), + IsOkAndHolds(ErrorValueIs( + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr(test_case.expected_result))))); +} + +INSTANTIATE_TEST_SUITE_P(RegexFunctionsTest, RegexFunctionsTest, + ValuesIn(createParams())); + +struct RegexCheckerTestCase { + const std::string expr_string; + bool is_valid; +}; + +class RegexCheckerLibraryTest + : public ::testing::TestWithParam { + public: + void SetUp() override { + // Arrange: Configure the compiler. + // Add the regex checker library to the compiler builder. + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler_builder, + NewCompilerBuilder(descriptor_pool_)); + ASSERT_THAT(compiler_builder->AddLibrary(StandardCheckerLibrary()), IsOk()); + ASSERT_THAT(compiler_builder->AddLibrary(RegexCheckerLibrary()), IsOk()); + ASSERT_OK_AND_ASSIGN(compiler_, std::move(*compiler_builder).Build()); + } + + const google::protobuf::DescriptorPool* descriptor_pool_ = + internal::GetTestingDescriptorPool(); + std::unique_ptr compiler_; +}; + +TEST_P(RegexCheckerLibraryTest, RegexFunctionsTypeCheckerSuccess) { + // Act & Assert: Compile the expression and validate the result. + ASSERT_OK_AND_ASSIGN(ValidationResult result, + compiler_->Compile(GetParam().expr_string)); + EXPECT_EQ(result.IsValid(), GetParam().is_valid); +} + +// Returns a vector of test cases for the RegexCheckerLibraryTest. +// Returns both positive and negative test cases for the regex functions. +std::vector createRegexCheckerParams() { + return { + {R"(re.extract('testuser@google.com', '(.*)@([^.]*)', '\\2!\\1') == 'google!testuser')", + true}, + {R"(re.extract(1, '(.*)@([^.]*)', '\\2!\\1') == 'google!testuser')", + false}, + {R"(re.extract('testuser@google.com', ['1', '2'], '\\2!\\1') == 'google!testuser')", + false}, + {R"(re.extract('testuser@google.com', '(.*)@([^.]*)', false) == 'google!testuser')", + false}, + {R"(re.extract('testuser@google.com', '(.*)@([^.]*)', '\\2!\\1') == 2.2)", + false}, + {R"(re.captureN('testuser@', '(?P.*)@') == {'username': 'testuser'})", + true}, + {R"(re.captureN(['foo', 'bar'], '(?P.*)@') == {'username': 'testuser'})", + false}, + {R"(re.captureN('testuser@', 2) == {'username': 'testuser'})", false}, + {R"(re.captureN('testuser@', '(?P.*)@') == true)", false}, + {R"(re.capture('foo', 'fo(o)') == 'o')", true}, + {R"(re.capture('foo', 2) == 'o')", false}, + {R"(re.capture(true, 'fo(o)') == 'o')", false}, + {R"(re.capture('foo', 'fo(o)') == ['o'])", false}, + }; +} + +INSTANTIATE_TEST_SUITE_P(RegexCheckerLibraryTest, RegexCheckerLibraryTest, + ValuesIn(createRegexCheckerParams())); + +} // namespace + +} // namespace cel::extensions diff --git a/extensions/select_optimization.cc b/extensions/select_optimization.cc new file mode 100644 index 000000000..0cc64311a --- /dev/null +++ b/extensions/select_optimization.cc @@ -0,0 +1,958 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/select_optimization.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" +#include "absl/functional/overload.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/match.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "absl/types/variant.h" +#include "base/attribute.h" +#include "base/builtins.h" +#include "common/ast.h" +#include "common/ast_rewrite.h" +#include "common/casting.h" +#include "common/constant.h" +#include "common/expr.h" +#include "common/function_descriptor.h" +#include "common/kind.h" +#include "common/native_type.h" +#include "common/type.h" +#include "common/value.h" +#include "eval/compiler/flat_expr_builder.h" +#include "eval/compiler/flat_expr_builder_extensions.h" +#include "eval/eval/attribute_trail.h" +#include "eval/eval/direct_expression_step.h" +#include "eval/eval/evaluator_core.h" +#include "eval/eval/expression_step_base.h" +#include "internal/casts.h" +#include "internal/number.h" +#include "internal/status_macros.h" +#include "runtime/internal/errors.h" +#include "runtime/internal/runtime_friend_access.h" +#include "runtime/internal/runtime_impl.h" +#include "runtime/runtime_builder.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel::extensions { +namespace { + +using ::cel::Ast; +using ::cel::AstRewriterBase; +using ::cel::CallExpr; +using ::cel::ConstantKind; +using ::cel::Expr; +using ::cel::ExprKind; +using ::cel::SelectExpr; +using ::google::api::expr::runtime::AttributeTrail; +using ::google::api::expr::runtime::DirectExpressionStep; +using ::google::api::expr::runtime::ExecutionFrame; +using ::google::api::expr::runtime::ExecutionFrameBase; +using ::google::api::expr::runtime::ExpressionStepBase; +using ::google::api::expr::runtime::PlannerContext; +using ::google::api::expr::runtime::ProgramOptimizer; + +// Represents a single select operation (field access or indexing). +// For struct-typed field accesses, includes the field name and the field +// number. +struct SelectInstruction { + int64_t number; + std::string name; +}; + +// Represents a single qualifier in a traversal path. +// TODO(uncreated-issue/51): support variable indexes. +using QualifierInstruction = + std::variant; + +struct SelectPath { + Expr* operand; + std::vector select_instructions; + bool test_only; + // TODO(uncreated-issue/54): support for optionals. +}; + +// Generates the AST representation of the qualification path for the optimized +// select branch. I.e., the list-typed second argument of the cel.@attribute +// call. +Expr MakeSelectPathExpr( + const std::vector& select_instructions) { + Expr result; + auto& ast_list = result.mutable_list_expr().mutable_elements(); + ast_list.reserve(select_instructions.size()); + auto visitor = absl::Overload( + [&](const SelectInstruction& instruction) { + Expr ast_instruction; + Expr field_number; + field_number.mutable_const_expr().set_int64_value(instruction.number); + Expr field_name; + field_name.mutable_const_expr().set_string_value(instruction.name); + auto& field_specifier = + ast_instruction.mutable_list_expr().mutable_elements(); + field_specifier.emplace_back().set_expr(std::move(field_number)); + field_specifier.emplace_back().set_expr(std::move(field_name)); + + ast_list.emplace_back().set_expr(std::move(ast_instruction)); + }, + [&](absl::string_view instruction) { + Expr const_expr; + const_expr.mutable_const_expr().set_string_value(instruction); + ast_list.emplace_back().set_expr(std::move(const_expr)); + }, + [&](int64_t instruction) { + Expr const_expr; + const_expr.mutable_const_expr().set_int64_value(instruction); + ast_list.emplace_back().set_expr(std::move(const_expr)); + }, + [&](uint64_t instruction) { + Expr const_expr; + const_expr.mutable_const_expr().set_uint64_value(instruction); + ast_list.emplace_back().set_expr(std::move(const_expr)); + }, + [&](bool instruction) { + Expr const_expr; + const_expr.mutable_const_expr().set_bool_value(instruction); + ast_list.emplace_back().set_expr(std::move(const_expr)); + }); + + for (const auto& instruction : select_instructions) { + absl::visit(visitor, instruction); + } + return result; +} + +// Returns a single select operation based on the inferred type of the operand +// and the field name. If the operand type doesn't define the field, returns +// nullopt. +std::optional GetSelectInstruction( + const StructType& runtime_type, PlannerContext& planner_context, + absl::string_view field_name) { + auto field_or = planner_context.type_reflector() + .FindStructTypeFieldByName(runtime_type, field_name) + .value_or(std::nullopt); + if (field_or.has_value()) { + return SelectInstruction{field_or->number(), std::string(field_or->name())}; + } + return std::nullopt; +} + +absl::StatusOr SelectQualifierFromList(const ListExpr& list) { + if (list.elements().size() != 2) { + return absl::InvalidArgumentError("Invalid cel.attribute select list"); + } + + const Expr& field_number = list.elements()[0].expr(); + const Expr& field_name = list.elements()[1].expr(); + + if (!field_number.has_const_expr() || + !field_number.const_expr().has_int64_value()) { + return absl::InvalidArgumentError( + "Invalid cel.attribute field select number"); + } + + if (!field_name.has_const_expr() || + !field_name.const_expr().has_string_value()) { + return absl::InvalidArgumentError( + "Invalid cel.attribute field select name"); + } + + return FieldSpecifier{field_number.const_expr().int64_value(), + field_name.const_expr().string_value()}; +} + +// Returns a qualifier instruction derived from a unoptimized ast. +absl::StatusOr SelectInstructionFromConstant( + const Constant& constant) { + if (constant.has_int_value()) { + return QualifierInstruction(constant.int_value()); + } else if (constant.has_uint_value()) { + return QualifierInstruction(constant.uint_value()); + } else if (constant.has_bool_value()) { + return QualifierInstruction(constant.bool_value()); + } else if (constant.has_string_value()) { + return QualifierInstruction(constant.string_value()); + } else if (constant.has_double_value()) { + cel::internal::Number number(constant.double_value()); + if (number.LosslessConvertibleToInt()) { + return QualifierInstruction(number.AsInt()); + } else if (number.LosslessConvertibleToUint()) { + return QualifierInstruction(number.AsUint()); + } + } + + return absl::InvalidArgumentError("invalid index constant for cel.attribute"); +} + +absl::StatusOr SelectQualifierFromConstant( + const Constant& constant) { + if (constant.has_int_value()) { + return AttributeQualifier::OfInt(constant.int_value()); + } else if (constant.has_uint_value()) { + return AttributeQualifier::OfUint(constant.uint_value()); + } else if (constant.has_bool_value()) { + return AttributeQualifier::OfBool(constant.bool_value()); + } else if (constant.has_string_value()) { + return AttributeQualifier::OfString(constant.string_value()); + } + // TODO(uncreated-issue/51): double keys could possibly be valid selectors, but + // the other stacks don't implement the optimization yet and we normalize the + // key to a uint or int if we do the late AST rewrite during planning. + + return absl::InvalidArgumentError("invalid cel.attribute constant"); +} + +absl::StatusOr ListIndexFromQualifier(const AttributeQualifier& qual) { + int64_t value = -1; + switch (qual.kind()) { + case Kind::kInt: + value = *qual.GetInt64Key(); + break; + default: + // TODO(uncreated-issue/51): type-checker will reject an unsigned literal, but + // should be supported as a dyn / variable. + return runtime_internal::CreateNoMatchingOverloadError( + cel::builtin::kIndex); + } + + if (value < 0) { + return absl::InvalidArgumentError("list index less than 0"); + } + + return static_cast(value); +} + +absl::StatusOr MapKeyFromQualifier(const AttributeQualifier& qual, + google::protobuf::Arena* absl_nonnull arena) { + switch (qual.kind()) { + case Kind::kInt: + return cel::IntValue(*qual.GetInt64Key()); + case Kind::kUint: + return cel::UintValue(*qual.GetUint64Key()); + case Kind::kBool: + return cel::BoolValue(*qual.GetBoolKey()); + case Kind::kString: + return StringValue::From(*qual.GetStringKey(), arena); + default: + return runtime_internal::CreateNoMatchingOverloadError( + cel::builtin::kIndex); + } +} + +absl::StatusOr ApplyQualifier( + const Value& operand, const SelectQualifier& qualifier, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + return absl::visit( + absl::Overload( + [&](const FieldSpecifier& field_specifier) -> absl::StatusOr { + if (!operand.Is()) { + return cel::ErrorValue( + cel::runtime_internal::CreateNoMatchingOverloadError( + "")); + } + CEL_ASSIGN_OR_RETURN( + bool present, + elem->GetStruct().HasFieldByName(field_specifier.name)); + return cel::BoolValue(present); + }, + [&](const AttributeQualifier& qualifier) -> absl::StatusOr { + if (!elem->Is() || qualifier.kind() != Kind::kString) { + return cel::ErrorValue( + cel::runtime_internal::CreateNoMatchingOverloadError( + "has")); + } + + return elem->GetMap().Has( + StringValue(arena, *qualifier.GetStringKey()), + descriptor_pool, message_factory, arena); + }), + last_instruction); + } + + return ApplyQualifier(*elem, last_instruction, descriptor_pool, + message_factory, arena); +} + +absl::StatusOr> SelectInstructionsFromCall( + const CallExpr& call) { + if (call.args().size() < 2 || !call.args()[1].has_list_expr()) { + return absl::InvalidArgumentError("Invalid cel.attribute call"); + } + std::vector instructions; + const auto& ast_path = call.args()[1].list_expr().elements(); + instructions.reserve(ast_path.size()); + + for (const ListExprElement& element : ast_path) { + // Optimized field select. + if (element.has_expr()) { + const auto& element_expr = element.expr(); + if (element_expr.has_list_expr()) { + CEL_ASSIGN_OR_RETURN(instructions.emplace_back(), + SelectQualifierFromList(element_expr.list_expr())); + } else if (element_expr.has_const_expr()) { + CEL_ASSIGN_OR_RETURN( + instructions.emplace_back(), + SelectQualifierFromConstant(element_expr.const_expr())); + } else { + return absl::InvalidArgumentError("Invalid cel.attribute call"); + } + } else { + return absl::InvalidArgumentError("Invalid cel.attribute call"); + } + } + + // TODO(uncreated-issue/54): support for optionals. + + return instructions; +} + +class RewriterImpl : public AstRewriterBase { + public: + RewriterImpl(const Ast& ast, PlannerContext& planner_context) + : ast_(ast), planner_context_(planner_context) {} + + void PreVisitExpr(const Expr& expr) override { path_.push_back(&expr); } + + void PreVisitSelect(const Expr& expr, const SelectExpr& select) override { + const Expr& operand = select.operand(); + const std::string& field_name = select.field(); + // Select optimization can generalize to lists and maps, but for now only + // support message traversal. + const TypeSpec checker_type = ast_.GetTypeOrDyn(operand.id()); + + std::optional rt_type = + (checker_type.has_message_type()) + ? GetRuntimeType(checker_type.message_type().type()) + : std::nullopt; + if (rt_type.has_value() && (*rt_type).Is()) { + const StructType& runtime_type = rt_type->GetStruct(); + std::optional field_or = + GetSelectInstruction(runtime_type, planner_context_, field_name); + if (field_or.has_value()) { + candidates_[&expr] = std::move(field_or).value(); + } + } else if (checker_type.has_map_type()) { + candidates_[&expr] = QualifierInstruction(field_name); + } + // else + // TODO(uncreated-issue/54): add support for either dyn or any. Excluded to + // simplify program plan. + } + + void PreVisitCall(const Expr& expr, const CallExpr& call) override { + if (call.args().size() != 2 || call.function() != ::cel::builtin::kIndex) { + return; + } + + const auto& qualifier_expr = call.args()[1]; + if (qualifier_expr.has_const_expr()) { + auto qualifier_or = + SelectInstructionFromConstant(qualifier_expr.const_expr()); + if (!qualifier_or.ok()) { + // TODO(uncreated-issue/54): should warn, but by default warnings fail overall + // program planning. + return; + } + candidates_[&expr] = std::move(qualifier_or).value(); + } + // TODO(uncreated-issue/54): support variable indexes + } + + bool PostVisitRewrite(Expr& expr) override { + if (!progress_status_.ok()) { + return false; + } + path_.pop_back(); + auto candidate_iter = candidates_.find(&expr); + if (candidate_iter == candidates_.end()) { + return false; + } + + // On post visit, filter candidates that aren't rooted on a message or a + // select chain. + const QualifierInstruction& candidate = candidate_iter->second; + if (!HasOptimizeableRoot(&expr, candidate)) { + candidates_.erase(candidate_iter); + return false; + } + + if (!path_.empty() && candidates_.find(path_.back()) != candidates_.end()) { + // parent is optimizeable, defer rewriting until we consider the parent. + return false; + } + + SelectPath path = GetSelectPath(&expr); + + // generate the new cel.attribute call. + absl::string_view fn = path.test_only ? kCelHasField : kCelAttribute; + + Expr operand(std::move(*path.operand)); + Expr call; + call.set_id(expr.id()); + call.mutable_call_expr().set_function(std::string(fn)); + call.mutable_call_expr().mutable_args().reserve(2); + + call.mutable_call_expr().mutable_args().push_back(std::move(operand)); + call.mutable_call_expr().mutable_args().push_back( + MakeSelectPathExpr(path.select_instructions)); + + // TODO(uncreated-issue/54): support for optionals. + expr = std::move(call); + + return true; + } + + absl::Status GetProgressStatus() const { return progress_status_; } + + private: + SelectPath GetSelectPath(Expr* expr) { + SelectPath result; + result.test_only = false; + Expr* operand = expr; + auto candidate_iter = candidates_.find(operand); + while (candidate_iter != candidates_.end()) { + result.select_instructions.push_back(candidate_iter->second); + if (operand->has_select_expr()) { + if (operand->select_expr().test_only()) { + result.test_only = true; + } + operand = &(operand->mutable_select_expr().mutable_operand()); + } else { + ABSL_DCHECK(operand->has_call_expr()); + operand = &(operand->mutable_call_expr().mutable_args()[0]); + } + candidate_iter = candidates_.find(operand); + } + absl::c_reverse(result.select_instructions); + result.operand = operand; + return result; + } + + // Check whether the candidate has a message type as a root (the operand for + // the batched select operation). + // Called on post visit. + bool HasOptimizeableRoot(const Expr* expr, + const QualifierInstruction& candidate) { + if (absl::holds_alternative(candidate)) { + return true; + } + const Expr* operand = nullptr; + if (expr->has_call_expr() && expr->call_expr().args().size() == 2 && + expr->call_expr().function() == ::cel::builtin::kIndex) { + operand = &expr->call_expr().args()[0]; + } else if (expr->has_select_expr()) { + operand = &expr->select_expr().operand(); + } + + if (operand == nullptr) { + return false; + } + + return candidates_.find(operand) != candidates_.end(); + } + + std::optional GetRuntimeType(absl::string_view type_name) { + return planner_context_.type_reflector().FindType(type_name).value_or( + std::nullopt); + } + + void SetProgressStatus(const absl::Status& status) { + if (progress_status_.ok() && !status.ok()) { + progress_status_ = status; + } + } + + const Ast& ast_; + PlannerContext& planner_context_; + // ids of potentially optimizeable expr nodes. + absl::flat_hash_map candidates_; + std::vector path_; + absl::Status progress_status_; +}; + +class OptimizedSelectImpl { + public: + OptimizedSelectImpl(std::vector select_path, + std::vector qualifiers, + bool presence_test, SelectOptimizationOptions options) + : select_path_(std::move(select_path)), + qualifiers_(std::move(qualifiers)), + presence_test_(presence_test), + options_(options) + + { + ABSL_DCHECK(!select_path_.empty()); + } + + // Move constructible. + OptimizedSelectImpl(const OptimizedSelectImpl&) = delete; + OptimizedSelectImpl& operator=(const OptimizedSelectImpl&) = delete; + OptimizedSelectImpl(OptimizedSelectImpl&&) = default; + OptimizedSelectImpl& operator=(OptimizedSelectImpl&&) = delete; + + absl::StatusOr ApplySelect(ExecutionFrameBase& frame, + const StructValue& struct_value) const; + + AttributeTrail GetAttributeTrail(const AttributeTrail& operand_trail) const; + + std::optional attribute() const { return attribute_; } + + const std::vector& qualifiers() const { + return qualifiers_; + } + + private: + std::optional attribute_; + std::vector select_path_; + std::vector qualifiers_; + bool presence_test_; + SelectOptimizationOptions options_; +}; + +// Check for unknowns or missing attributes. +absl::StatusOr> CheckForMarkedAttributes( + ExecutionFrameBase& frame, const AttributeTrail& attribute_trail) { + if (attribute_trail.empty()) { + return std::nullopt; + } + + if (frame.unknown_processing_enabled() && + frame.attribute_utility().CheckForUnknownExact(attribute_trail)) { + // Check if the inferred attribute is marked. Only matches if this attribute + // or a parent is marked unknown (use_partial = false). + // Partial matches (i.e. descendant of this attribute is marked) aren't + // considered yet in case another operation would select an unmarked + // descended attribute. + // + // TODO(uncreated-issue/51): this may return a more specific attribute than the + // declared pattern. Follow up will truncate the returned attribute to match + // the pattern. + return frame.attribute_utility().CreateUnknownSet( + attribute_trail.attribute()); + } + + if (frame.missing_attribute_errors_enabled() && + frame.attribute_utility().CheckForMissingAttribute(attribute_trail)) { + return frame.attribute_utility().CreateMissingAttributeError( + attribute_trail.attribute()); + } + + return std::nullopt; +} + +absl::StatusOr OptimizedSelectImpl::ApplySelect( + ExecutionFrameBase& frame, const StructValue& struct_value) const { + auto value_or = + (options_.force_fallback_implementation) + ? absl::UnimplementedError("Forced fallback impl") + : struct_value.Qualify(select_path_, presence_test_, + frame.descriptor_pool(), + frame.message_factory(), frame.arena()); + + if (!value_or.ok()) { + if (value_or.status().code() == absl::StatusCode::kUnimplemented) { + return FallbackSelect(struct_value, select_path_, presence_test_, + frame.descriptor_pool(), frame.message_factory(), + frame.arena()); + } + + return value_or.status(); + } + + if (value_or->second < 0 || value_or->second >= select_path_.size()) { + return std::move(value_or->first); + } + + return FallbackSelect( + value_or->first, + absl::MakeConstSpan(select_path_).subspan(value_or->second), + presence_test_, frame.descriptor_pool(), frame.message_factory(), + frame.arena()); +} + +AttributeTrail OptimizedSelectImpl::GetAttributeTrail( + const AttributeTrail& operand_trail) const { + if (operand_trail.empty()) { + return AttributeTrail(); + } + std::vector qualifiers = std::vector( + operand_trail.attribute().qualifier_path().begin(), + operand_trail.attribute().qualifier_path().end()); + qualifiers.reserve(qualifiers_.size() + qualifiers.size()); + absl::c_copy(qualifiers_, std::back_inserter(qualifiers)); + return AttributeTrail( + Attribute(std::string(operand_trail.attribute().variable_name()), + std::move(qualifiers))); +} + +class StackMachineImpl : public ExpressionStepBase { + public: + StackMachineImpl(int expr_id, OptimizedSelectImpl impl) + : ExpressionStepBase(expr_id), impl_(std::move(impl)) {} + + absl::Status Evaluate(ExecutionFrame* frame) const override; + + private: + // Get the effective attribute for the optimized select expression. + // Assumes the operand is the top of stack if the attribute wasn't known at + // plan time. + AttributeTrail GetAttributeTrail(ExecutionFrame* frame) const; + + OptimizedSelectImpl impl_; +}; + +AttributeTrail StackMachineImpl::GetAttributeTrail( + ExecutionFrame* frame) const { + const auto& attr = frame->value_stack().PeekAttribute(); + return impl_.GetAttributeTrail(attr); +} + +absl::Status StackMachineImpl::Evaluate(ExecutionFrame* frame) const { + // Default empty. + AttributeTrail attribute_trail; + // TODO(uncreated-issue/51): add support for variable qualifiers and string literal + // variable names. + constexpr size_t kStackInputs = 1; + + // For now, we expect the operand to be top of stack. + const Value& operand = frame->value_stack().Peek(); + + if (operand->Is() || operand->Is()) { + // Just forward the error which is already top of stack. + return absl::OkStatus(); + } + + if (frame->enable_attribute_tracking()) { + // Compute the attribute trail then check for any marked values. + // When possible, this is computed at plan time based on the optimized + // select arguments. + // TODO(uncreated-issue/51): add support variable qualifiers + attribute_trail = GetAttributeTrail(frame); + CEL_ASSIGN_OR_RETURN(std::optional value, + CheckForMarkedAttributes(*frame, attribute_trail)); + if (value.has_value()) { + frame->value_stack().Pop(kStackInputs); + frame->value_stack().Push(std::move(value).value(), + std::move(attribute_trail)); + return absl::OkStatus(); + } + } + + if (!operand->Is()) { + return absl::InvalidArgumentError( + "Expected struct type for select optimization."); + } + + CEL_ASSIGN_OR_RETURN(Value result, + impl_.ApplySelect(*frame, operand.GetStruct())); + + frame->value_stack().Pop(kStackInputs); + frame->value_stack().Push(std::move(result), std::move(attribute_trail)); + return absl::OkStatus(); +} + +class RecursiveImpl : public DirectExpressionStep { + public: + RecursiveImpl(int64_t expr_id, std::unique_ptr operand, + OptimizedSelectImpl impl) + : DirectExpressionStep(expr_id), + operand_(std::move(operand)), + impl_(std::move(impl)) {} + + absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, + AttributeTrail& attribute) const override; + + private: + // Get the effective attribute for the optimized select expression. + // Assumes the operand is the top of stack if the attribute wasn't known at + // plan time. + AttributeTrail GetAttributeTrail(const AttributeTrail& operand_trail) const; + std::unique_ptr operand_; + OptimizedSelectImpl impl_; +}; + +AttributeTrail RecursiveImpl::GetAttributeTrail( + const AttributeTrail& operand_trail) const { + return impl_.GetAttributeTrail(operand_trail); +} + +absl::Status RecursiveImpl::Evaluate(ExecutionFrameBase& frame, Value& result, + AttributeTrail& attribute) const { + CEL_RETURN_IF_ERROR(operand_->Evaluate(frame, result, attribute)); + + if (InstanceOf(result) || InstanceOf(result)) { + // Just forward. + return absl::OkStatus(); + } + + if (frame.attribute_tracking_enabled()) { + attribute = impl_.GetAttributeTrail(attribute); + CEL_ASSIGN_OR_RETURN(auto value, + CheckForMarkedAttributes(frame, attribute)); + if (value.has_value()) { + result = std::move(value).value(); + return absl::OkStatus(); + } + } + + if (!InstanceOf(result)) { + return absl::InvalidArgumentError( + "Expected struct type for select optimization"); + } + CEL_ASSIGN_OR_RETURN(result, + impl_.ApplySelect(frame, Cast(result))); + return absl::OkStatus(); +} + +class SelectOptimizer : public ProgramOptimizer { + public: + explicit SelectOptimizer(const SelectOptimizationOptions& options) + : options_(options) {} + + absl::Status OnPreVisit(PlannerContext& context, const Expr& node) override { + return absl::OkStatus(); + } + + absl::Status OnPostVisit(PlannerContext& context, const Expr& node) override; + + private: + SelectOptimizationOptions options_; +}; + +absl::Status SelectOptimizer::OnPostVisit(PlannerContext& context, + const Expr& node) { + if (!node.has_call_expr()) { + return absl::OkStatus(); + } + + absl::string_view fn = node.call_expr().function(); + if (fn != kCelHasField && fn != kCelAttribute) { + return absl::OkStatus(); + } + + if (node.call_expr().args().size() < 2 || + node.call_expr().args().size() > 3) { + return absl::InvalidArgumentError("Invalid cel.attribute call"); + } + + if (node.call_expr().args().size() == 3) { + return absl::UnimplementedError("Optionals not yet supported"); + } + + CEL_ASSIGN_OR_RETURN(std::vector instructions, + SelectInstructionsFromCall(node.call_expr())); + + if (instructions.empty()) { + return absl::InvalidArgumentError("Invalid cel.attribute no select steps."); + } + + bool presence_test = false; + + if (fn == kCelHasField) { + presence_test = true; + } + + const Expr& operand = node.call_expr().args()[0]; + absl::string_view identifier; + if (operand.has_ident_expr()) { + identifier = operand.ident_expr().name(); + } + + if (absl::StrContains(identifier, ".")) { + return absl::UnimplementedError("qualified identifiers not supported."); + } + + std::vector qualifiers; + qualifiers.reserve(instructions.size()); + for (const auto& instruction : instructions) { + qualifiers.push_back( + absl::visit(absl::Overload( + [](const FieldSpecifier& field) { + return AttributeQualifier::OfString(field.name); + }, + [](const AttributeQualifier& q) { return q; }), + instruction)); + } + + // TODO(uncreated-issue/51): If the first argument is a string literal, the custom + // step needs to handle variable lookup. + auto* subexpression = context.program_builder().GetSubexpression(&node); + if (subexpression == nullptr || subexpression->IsFlattened()) { + // No information on the subprogram, can't optimize. + return absl::OkStatus(); + } + + OptimizedSelectImpl impl(std::move(instructions), std::move(qualifiers), + presence_test, options_); + + if (subexpression->IsRecursive()) { + auto program = subexpression->ExtractRecursiveProgram(); + auto deps = program.step->ExtractDependencies(); + if (!deps.has_value() || deps->empty()) { + return absl::InvalidArgumentError("Unexpected cel.@attribute call"); + } + subexpression->set_recursive_program( + std::make_unique(node.id(), std::move(deps->at(0)), + std::move(impl)), + program.depth); + return absl::OkStatus(); + } + + google::api::expr::runtime::ExecutionPath path; + + // else, we need to preserve the original plan for the first argument. + if (context.GetSubplan(operand).empty()) { + // Indicates another extension modified the step. Nothing to do here. + return absl::OkStatus(); + } + CEL_ASSIGN_OR_RETURN(auto operand_subplan, context.ExtractSubplan(operand)); + absl::c_move(operand_subplan, std::back_inserter(path)); + + path.push_back( + std::make_unique(node.id(), std::move(impl))); + + return context.ReplaceSubplan(node, std::move(path)); +} + +google::api::expr::runtime::FlatExprBuilder* GetFlatExprBuilder( + RuntimeBuilder& builder) { + auto& runtime = + runtime_internal::RuntimeFriendAccess::GetMutableRuntime(builder); + if (runtime_internal::RuntimeFriendAccess::RuntimeTypeId(runtime) == + NativeTypeId::For()) { + auto& runtime_impl = + cel::internal::down_cast(runtime); + return &runtime_impl.expr_builder(); + } + return nullptr; +} + +} // namespace + +absl::Status SelectOptimizationAstUpdater::UpdateAst(PlannerContext& context, + Ast& ast) const { + RewriterImpl rewriter(ast, context); + AstRewrite(ast.mutable_root_expr(), rewriter); + return rewriter.GetProgressStatus(); +} + +google::api::expr::runtime::ProgramOptimizerFactory +CreateSelectOptimizationProgramOptimizer( + const SelectOptimizationOptions& options) { + return [=](PlannerContext& context, const Ast& ast) { + return std::make_unique(options); + }; +} + +absl::Status EnableSelectOptimization( + cel::RuntimeBuilder& builder, const SelectOptimizationOptions& options) { + auto* flat_expr_builder = GetFlatExprBuilder(builder); + if (flat_expr_builder == nullptr) { + return absl::InvalidArgumentError( + "SelectOptimization requires default runtime implementation"); + } + + flat_expr_builder->AddAstTransform( + std::make_unique()); + // Add overloads for select optimization signature. + // These are never bound, only used to prevent the builder from failing on + // the overloads check. + CEL_RETURN_IF_ERROR(builder.function_registry().RegisterLazyFunction( + FunctionDescriptor(kCelAttribute, false, {Kind::kAny, Kind::kList}))); + + CEL_RETURN_IF_ERROR(builder.function_registry().RegisterLazyFunction( + FunctionDescriptor(kCelHasField, false, {Kind::kAny, Kind::kList}))); + // Add runtime implementation. + flat_expr_builder->AddProgramOptimizer( + CreateSelectOptimizationProgramOptimizer(options)); + return absl::OkStatus(); +} + +} // namespace cel::extensions diff --git a/extensions/select_optimization.h b/extensions/select_optimization.h new file mode 100644 index 000000000..4de81b1b0 --- /dev/null +++ b/extensions/select_optimization.h @@ -0,0 +1,90 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_SELECT_OPTIMIZATION_H_ +#define THIRD_PARTY_CEL_CPP_EXTENSIONS_SELECT_OPTIMIZATION_H_ + +#include "absl/status/status.h" +#include "common/ast.h" +#include "eval/compiler/flat_expr_builder_extensions.h" +#include "runtime/runtime_builder.h" + +namespace cel::extensions { + +constexpr char kCelAttribute[] = "cel.@attribute"; +constexpr char kCelHasField[] = "cel.@hasField"; + +// Configuration options for the select optimization. +struct SelectOptimizationOptions { + // Force the program to use the fallback implementation for the select. + // This implementation simply collapses the select operation into one program + // step and calls the normal field accessors on the Struct value. + // + // Normally, the fallback implementation is used when the Qualify operation is + // unimplemented for a given StructType. This option is exposed for testing or + // to more closely match behavior of unoptimized expressions. + bool force_fallback_implementation = false; +}; + +// Enable select optimization on the given RuntimeBuilder, replacing long +// select chains with a single operation. +// +// This assumes that the type information at check time agrees with the +// configured types at runtime. +// +// Important: The select optimization follows spec behavior for traversals. +// - `enable_empty_wrapper_null_unboxing` is ignored and optimized traversals +// always operates as though it is `true`. +// - `enable_heterogeneous_equality` is ignored and optimized traversals +// always operate as though it is `true`. +// +// This should only be called *once* on a given runtime builder. +// +// Assumes the default runtime implementation, an error with code +// InvalidArgument is returned if it is not. +// +// Note: implementation does not support optional field traversal, and will +// instead revert to the normal implementation instead of trying to optimize. +absl::Status EnableSelectOptimization( + cel::RuntimeBuilder& builder, + const SelectOptimizationOptions& options = {}); + +// =============================================================== +// Implementation details -- CEL users should not depend on these. +// Exposed here for enabling on Legacy APIs. They expose internal details +// which are not guaranteed to be stable. +// =============================================================== + +// Scans ast for optimizable select branches. +// +// In general, this should be done by a type checker but may be deferred to +// runtime. +// +// This assumes the runtime type registry has the same definitions as the one +// used by the type checker. +class SelectOptimizationAstUpdater + : public google::api::expr::runtime::AstTransform { + public: + SelectOptimizationAstUpdater() = default; + + absl::Status UpdateAst(google::api::expr::runtime::PlannerContext& context, + cel::Ast& ast) const override; +}; + +google::api::expr::runtime::ProgramOptimizerFactory +CreateSelectOptimizationProgramOptimizer( + const SelectOptimizationOptions& options = {}); + +} // namespace cel::extensions +#endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_SELECT_OPTIMIZATION_H_ diff --git a/extensions/select_optimization_test.cc b/extensions/select_optimization_test.cc new file mode 100644 index 000000000..c14c4d461 --- /dev/null +++ b/extensions/select_optimization_test.cc @@ -0,0 +1,1957 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/select_optimization.h" + +#include +#include +#include +#include +#include +#include + +#include "cel/expr/checked.pb.h" +#include "cel/expr/syntax.pb.h" +#include "google/protobuf/empty.pb.h" +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "base/ast.h" +#include "base/attribute.h" +#include "base/builtins.h" +#include "checker/type_checker_builder.h" +#include "checker/validation_result.h" +#include "common/ast.h" +#include "common/decl.h" +#include "common/decl_proto.h" +#include "common/expr.h" +#include "common/kind.h" +#include "common/memory.h" +#include "common/value.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "compiler/optional.h" +#include "compiler/standard_library.h" +#include "eval/compiler/flat_expr_builder.h" +#include "eval/compiler/flat_expr_builder_extensions.h" +#include "eval/compiler/resolver.h" +#include "eval/eval/evaluator_core.h" +#include "eval/internal/interop.h" +#include "eval/public/cel_type_registry.h" +#include "eval/public/cel_value.h" +#include "eval/public/structs/cel_proto_wrapper.h" +#include "eval/public/structs/legacy_type_adapter.h" +#include "eval/public/structs/legacy_type_info_apis.h" +#include "extensions/protobuf/ast_converters.h" +#include "internal/number.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "parser/parser.h" +#include "runtime/activation.h" +#include "runtime/function_adapter.h" +#include "runtime/function_registry.h" +#include "runtime/internal/issue_collector.h" +#include "runtime/internal/runtime_env.h" +#include "runtime/internal/runtime_env_testing.h" +#include "runtime/runtime_issue.h" +#include "runtime/runtime_options.h" +#include "runtime/type_registry.h" +#include "cel/expr/conformance/proto2/test_all_types.pb.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/extension_set.h" +#include "google/protobuf/message.h" +#include "google/protobuf/text_format.h" + +namespace cel::extensions { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::cel::expr::conformance::proto2::NestedTestAllTypes; +using ::cel::runtime_internal::NewTestingRuntimeEnv; +using ::cel::runtime_internal::RuntimeEnv; +using ::cel::expr::ParsedExpr; +using ::google::api::expr::parser::Parse; +using ::google::api::expr::runtime::CelProtoWrapper; +using ::google::api::expr::runtime::CelValue; +using ::google::api::expr::runtime::FlatExprBuilder; +using ::google::api::expr::runtime::FlatExpression; +using ::google::api::expr::runtime::LegacyTypeAccessApis; +using ::google::api::expr::runtime::LegacyTypeInfoApis; +using ::google::api::expr::runtime::LegacyTypeMutationApis; +using ::google::protobuf::Empty; +using ::testing::_; +using ::testing::AllOf; +using ::testing::AnyOf; +using ::testing::ElementsAre; +using ::testing::Eq; +using ::testing::HasSubstr; +using ::testing::NiceMock; +using ::testing::Return; +using ::testing::SizeIs; +using ::testing::Truly; + +namespace conformancepb = ::cel::expr::conformance; + +using MessageWrapper = CelValue::MessageWrapper; + +absl::Status ApplyDecl(absl::string_view decl, TypeCheckerBuilder& builder) { + cel::expr::Decl decl_proto; + + if (!google::protobuf::TextFormat::ParseFromString(decl, &decl_proto)) { + return absl::InvalidArgumentError("failed to parse decl"); + } + if (decl_proto.has_ident()) { + CEL_ASSIGN_OR_RETURN( + cel::VariableDecl d, + cel::VariableDeclFromProto(decl_proto.name(), decl_proto.ident(), + builder.descriptor_pool(), builder.arena())); + CEL_RETURN_IF_ERROR(builder.AddVariable(std::move(d))); + } else if (decl_proto.has_function()) { + CEL_ASSIGN_OR_RETURN( + cel::FunctionDecl d, + cel::FunctionDeclFromProto(decl_proto.name(), decl_proto.function(), + builder.descriptor_pool(), builder.arena())); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(d))); + } else { + return absl::InvalidArgumentError("decl has no ident or function"); + } + return absl::OkStatus(); +} + +absl::StatusOr> NewTestCompiler() { + CompilerOptions options; + options.parser_options.enable_quoted_identifiers = true; + CEL_ASSIGN_OR_RETURN(std::unique_ptr builder, + cel::NewCompilerBuilder( + google::protobuf::DescriptorPool::generated_pool(), options)); + + CEL_RETURN_IF_ERROR(builder->AddLibrary(cel::StandardCompilerLibrary())); + CEL_RETURN_IF_ERROR(builder->AddLibrary(cel::OptionalCompilerLibrary())); + auto& checker_builder = builder->GetCheckerBuilder(); + google::protobuf::LinkMessageReflection(); + + checker_builder.set_container("cel.expr.conformance"); + + CEL_RETURN_IF_ERROR(ApplyDecl( + R"pb( + name: "nested_test_all_types" + ident { + type { + message_type: "cel.expr.conformance.proto2.NestedTestAllTypes" + } + } + )pb", + checker_builder)); + CEL_RETURN_IF_ERROR(ApplyDecl( + R"pb( + name: "test_all_types" + ident { + type { message_type: "cel.expr.conformance.proto2.TestAllTypes" } + } + )pb", + checker_builder)); + CEL_RETURN_IF_ERROR(ApplyDecl( + R"pb( + name: "a" + ident { + type { + message_type: "cel.expr.conformance.proto2.NestedTestAllTypes" + } + } + )pb", + checker_builder)); + + CEL_RETURN_IF_ERROR(ApplyDecl( + R"pb( + name: "b" + ident { + type { + message_type: "cel.expr.conformance.proto2.NestedTestAllTypes" + } + } + )pb", + checker_builder)); + + CEL_RETURN_IF_ERROR(ApplyDecl( + R"pb( + name: "custom_predicate" + function { + overloads { + doc: "An example predicate function for checking attribute tracking for " + "the result of the optimized select chain." + overload_id: "custom_predicate_TestAllTypesNestedType" + params { + message_type: "cel.expr.conformance.proto2.TestAllTypes.NestedMessage" + } + result_type { primitive: BOOL } + } + } + )pb", + checker_builder)); + + return builder->Build(); +} + +const cel::Compiler& TestCaseCompiler() { + static const Compiler* compiler = []() { + auto compiler = NewTestCompiler(); + ABSL_CHECK_OK(compiler); + return compiler->release(); + }(); + return *compiler; +} + +absl::StatusOr> CompileForTestCase( + absl::string_view expr) { + CEL_ASSIGN_OR_RETURN(cel::ValidationResult r, + TestCaseCompiler().Compile(expr)); + if (!r.IsValid()) { + return absl::InvalidArgumentError(r.FormatError()); + } + return r.ReleaseAst(); +} + +class MockAccessApis : public LegacyTypeInfoApis, public LegacyTypeAccessApis { + public: + std::string DebugString( + const MessageWrapper& wrapped_message) const override { + return "MockAccessApis"; + } + + absl::string_view GetTypename( + const MessageWrapper& wrapped_message) const override { + return "MockAccessApis"; + } + + const LegacyTypeAccessApis* GetAccessApis( + const MessageWrapper& wrapped_message) const override { + return this; + } + + const LegacyTypeMutationApis* GetMutationApis( + const MessageWrapper& wrapped_message) const override { + return nullptr; + } + + std::optional< + google::api::expr::runtime::LegacyTypeInfoApis::FieldDescription> + FindFieldByName(absl::string_view field_name) const override { + return std::nullopt; + } + + MOCK_METHOD(absl::StatusOr, GetField, + (absl::string_view field_name, + const CelValue::MessageWrapper& instance, + ProtoWrapperTypeOptions unboxing_option, + cel::MemoryManagerRef memory_manager), + (const, override)); + + MOCK_METHOD(absl::StatusOr, HasField, + (absl::string_view field_name, + const CelValue::MessageWrapper& value), + (const, override)); + + MOCK_METHOD(absl::StatusOr, + Qualify, + (absl::Span qualifiers, + const CelValue::MessageWrapper& instance, bool presence_test, + MemoryManagerRef memory_manager), + (const, override)); + + bool IsEqualTo( + const CelValue::MessageWrapper& instance, + const CelValue::MessageWrapper& other_instance) const override { + return false; + } + + std::vector ListFields( + const CelValue::MessageWrapper& instance) const override { + return {}; + } +}; + +std::pair MakeMockLegacyMessage( + google::protobuf::Arena* arena) { + auto* mock_access_apis = + google::protobuf::Arena::Create>(arena); + auto* message = google::protobuf::Arena::Create(arena); + + CelValue::MessageWrapper::Builder wrapper(message); + return {mock_access_apis, + CelValue::CreateMessageWrapper(wrapper.Build(mock_access_apis))}; +} + +absl::Status TestBindLegacyValue(absl::string_view variable, + CelValue legacy_value, google::protobuf::Arena* arena, + Activation& act) { + CEL_ASSIGN_OR_RETURN(Value value, + interop_internal::FromLegacyValue(arena, legacy_value)); + + act.InsertOrAssignValue(variable, std::move(value)); + return absl::OkStatus(); +} + +absl::Status TestBindLegacyMessage(absl::string_view variable, + const google::protobuf::Message& message, + google::protobuf::Arena* arena, cel::Activation& act) { + CelValue legacy_value = CelProtoWrapper::CreateMessage(&message, arena); + + return TestBindLegacyValue(variable, legacy_value, arena, act); +} + +class SelectOptimizationTest : public testing::Test { + public: + SelectOptimizationTest() + : env_(NewTestingRuntimeEnv()), + legacy_registry_(env_->legacy_type_registry), + type_registry_(env_->type_registry), + function_registry_(env_->function_registry), + resolver_("", function_registry_, type_registry_, + type_registry_.GetComposedTypeProvider()), + issue_collector_(RuntimeIssue::Severity::kError), + context_(env_, resolver_, runtime_options_, + type_registry_.GetComposedTypeProvider(), issue_collector_, + program_builder_, shared_arena_) { + runtime_options_.fail_on_warnings = false; + } + + void SetUp() override { + google::protobuf::LinkMessageReflection(); + ASSERT_THAT( + function_registry_.Register( + UnaryFunctionAdapter::CreateDescriptor( + "custom_predicate", false), + UnaryFunctionAdapter::WrapFunction( + [](const StructValue&) { return true; })), + IsOk()); + } + + protected: + absl_nonnull std::shared_ptr env_; + google::api::expr::runtime::CelTypeRegistry& legacy_registry_; + TypeRegistry& type_registry_; + FunctionRegistry& function_registry_; + google::protobuf::Arena arena_; + RuntimeOptions runtime_options_; + google::api::expr::runtime::Resolver resolver_; + cel::runtime_internal::IssueCollector issue_collector_; + google::api::expr::runtime::ProgramBuilder program_builder_; + std::shared_ptr shared_arena_; + google::api::expr::runtime::PlannerContext context_; +}; + +MATCHER_P2(SelectFieldEntry, id, name, "") { + const cel::Expr& entry = arg.expr(); + + if (entry.list_expr().elements().size() != 2) { + *result_listener << "want 2-tuple entry, got " + << entry.list_expr().elements().size(); + return false; + } + + int64_t got_id = + entry.list_expr().elements()[0].expr().const_expr().int64_value(); + absl::string_view got_name = + entry.list_expr().elements()[1].expr().const_expr().string_value(); + + *result_listener << "want " << id << ": '" << name << "'" << " got " << got_id + << ": '" << got_name << "'"; + + return entry.list_expr().elements()[0].expr().const_expr().int64_value() == + id && + entry.list_expr().elements()[1].expr().const_expr().string_value() == + name; +} + +std::string ToString(const AttributeQualifier& qualifier) { + switch (qualifier.kind()) { + case Kind::kInt: + return absl::StrCat(*qualifier.GetInt64Key()); + case Kind::kString: + return absl::StrCat("'", *qualifier.GetStringKey(), "'"); + case Kind::kUint: + return absl::StrCat(*qualifier.GetUint64Key()); + case Kind::kBool: + return absl::StrCat(*qualifier.GetBoolKey()); + default: + return ""; + } +} + +MATCHER_P(SelectQualifier, qualifier, + absl::StrCat("SelectQualifier: ", ToString(qualifier))) { + const cel::Expr& entry = arg.expr(); + + if (!entry.has_const_expr()) { + *result_listener << "wanted const_expr"; + return false; + } + + cel::AttributeQualifier got_qualifier; + if (entry.const_expr().has_int64_value()) { + got_qualifier = AttributeQualifier::OfInt(entry.const_expr().int64_value()); + } else if (entry.const_expr().has_string_value()) { + got_qualifier = + AttributeQualifier::OfString(entry.const_expr().string_value()); + } else if (entry.const_expr().has_bool_value()) { + got_qualifier = AttributeQualifier::OfBool(entry.const_expr().bool_value()); + } else if (entry.const_expr().has_uint64_value()) { + got_qualifier = + AttributeQualifier::OfUint(entry.const_expr().uint64_value()); + } + + *result_listener << "want " << ToString(qualifier) << " got " + << ToString(got_qualifier); + + return qualifier == got_qualifier; +} + +TEST_F(SelectOptimizationTest, AstTransformSelect) { + ASSERT_OK_AND_ASSIGN( + std::unique_ptr ast, + CompileForTestCase( + "nested_test_all_types.child.payload.standalone_message.bb")); + + SelectOptimizationAstUpdater updater; + EXPECT_THAT(updater.UpdateAst(context_, *ast), IsOk()); + + const auto& attr_call = ast->root_expr().call_expr(); + EXPECT_EQ(attr_call.function(), "cel.@attribute"); + + ASSERT_THAT(attr_call.args(), SizeIs(2)); + + EXPECT_EQ(attr_call.args()[0].ident_expr().name(), "nested_test_all_types"); + + EXPECT_THAT( + attr_call.args()[1].list_expr().elements(), + ElementsAre(SelectFieldEntry(1, "child"), SelectFieldEntry(2, "payload"), + SelectFieldEntry(23, "standalone_message"), + SelectFieldEntry(1, "bb"))); +} + +TEST_F(SelectOptimizationTest, AstTransformSelectPresence) { + ASSERT_OK_AND_ASSIGN( + std::unique_ptr ast, + CompileForTestCase( + "has(nested_test_all_types.child.payload.standalone_message.bb)")); + + SelectOptimizationAstUpdater updater; + EXPECT_THAT(updater.UpdateAst(context_, *ast), IsOk()); + + const auto& attr_call = ast->root_expr().call_expr(); + EXPECT_EQ(attr_call.function(), "cel.@hasField"); + + ASSERT_THAT(attr_call.args(), SizeIs(2)); + + EXPECT_EQ(attr_call.args()[0].ident_expr().name(), "nested_test_all_types"); + + EXPECT_THAT( + attr_call.args()[1].list_expr().elements(), + ElementsAre(SelectFieldEntry(1, "child"), SelectFieldEntry(2, "payload"), + SelectFieldEntry(23, "standalone_message"), + SelectFieldEntry(1, "bb"))); +} + +TEST_F(SelectOptimizationTest, AstTransformComplexSelect) { + ASSERT_OK_AND_ASSIGN( + std::unique_ptr ast, + CompileForTestCase( + "((false)? a.child.child : b.child).child.payload.single_int64")); + + SelectOptimizationAstUpdater updater; + EXPECT_THAT(updater.UpdateAst(context_, *ast), IsOk()); + + const auto& attr_call = ast->root_expr().call_expr(); + EXPECT_EQ(attr_call.function(), "cel.@attribute"); + + ASSERT_THAT(attr_call.args(), SizeIs(2)); + + EXPECT_THAT( + attr_call.args()[1].list_expr().elements(), + ElementsAre(SelectFieldEntry(1, "child"), SelectFieldEntry(2, "payload"), + SelectFieldEntry(2, "single_int64"))); + + const auto& operand = attr_call.args()[0]; + + EXPECT_EQ(operand.call_expr().function(), cel::builtin::kTernary); + ASSERT_THAT(operand.call_expr().args(), SizeIs(3)); + + const auto& true_branch = operand.call_expr().args()[1]; + + EXPECT_EQ(true_branch.call_expr().function(), "cel.@attribute"); + ASSERT_THAT(true_branch.call_expr().args(), SizeIs(2)); + + EXPECT_THAT( + true_branch.call_expr().args()[1].list_expr().elements(), + ElementsAre(SelectFieldEntry(1, "child"), SelectFieldEntry(1, "child"))); +} + +TEST_F(SelectOptimizationTest, AstTransformMapIndexTraversal) { + // nested_test_all_types.payload.map_string_message['$not_a_field'].bb + ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, + CompileForTestCase("nested_test_all_types.payload.map_" + "string_message['$not_a_field'].bb")); + + SelectOptimizationAstUpdater updater; + EXPECT_THAT(updater.UpdateAst(context_, *ast), IsOk()); + + const auto& attr_call = ast->root_expr().call_expr(); + EXPECT_EQ(attr_call.function(), "cel.@attribute"); + + ASSERT_THAT(attr_call.args(), SizeIs(2)); + + EXPECT_THAT( + attr_call.args()[1].list_expr().elements(), + ElementsAre(SelectFieldEntry(2, "payload"), + SelectFieldEntry(227, "map_string_message"), + SelectQualifier(AttributeQualifier::OfString("$not_a_field")), + SelectFieldEntry(1, "bb"))); + + const auto& operand = attr_call.args()[0]; + + EXPECT_EQ(operand.ident_expr().name(), "nested_test_all_types"); +} + +TEST_F(SelectOptimizationTest, AstTransformMapIndexUnsupportedConstant) { + // nested_test_all_types.payload.map_string_message['$not_a_field'].bb + ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, + CompileForTestCase("nested_test_all_types.payload.map_" + "string_message['$not_a_field'].bb")); + + // Type-checker shouldn't allow a bytes key, so simulating here for + // coverage. + ast->mutable_root_expr() + .mutable_select_expr() + .mutable_operand() + .mutable_call_expr() + .mutable_args()[1] + .mutable_const_expr() + .set_bytes_value("$not_a_field"); + + // We don't fail here, but we also don't optimize past the map lookup with + // an unsupported constant key. + SelectOptimizationAstUpdater updater; + EXPECT_THAT(updater.UpdateAst(context_, *ast), IsOk()); + EXPECT_EQ(ast->root_expr().call_expr().function(), "cel.@attribute"); + ASSERT_THAT(ast->root_expr().call_expr().args(), SizeIs(2)); + EXPECT_EQ(ast->root_expr().call_expr().args()[0].call_expr().function(), + "_[_]"); + // cel.@attribute( + // cel.@attribute( + // nested_test_all_types, + // [payload, map_string_message])[b'$not_a_field'], + // [bb]) + EXPECT_THAT(ast->root_expr().call_expr().args()[1].list_expr().elements(), + SizeIs(1)); +} + +TEST_F(SelectOptimizationTest, AstTransformMapIndexHeterogeneousDoubleKey) { + ASSERT_OK_AND_ASSIGN( + std::unique_ptr ast, + CompileForTestCase("nested_test_all_types.payload.single_any[1.0].bb")); + + SelectOptimizationAstUpdater updater; + EXPECT_THAT(updater.UpdateAst(context_, *ast), IsOk()); + EXPECT_EQ(ast->root_expr().select_expr().field(), "bb"); + // TODO(uncreated-issue/51): Right now we don't optimize past a dyn/any field + // and discard the select optimization if the root isn't a message, so we will + // consider the double as a candidate then discard it. + EXPECT_THAT(ast->root_expr().select_expr().operand().call_expr().function(), + "cel.@attribute"); + ASSERT_THAT(ast->root_expr().select_expr().operand().call_expr().args(), + SizeIs(2)); + EXPECT_THAT(ast->root_expr() + .select_expr() + .operand() + .call_expr() + .args()[1] + .list_expr() + .elements(), + SizeIs(3)); +} + +TEST_F(SelectOptimizationTest, AstTransformMapIndexHeterogeneousDoubleKeyUint) { + constexpr uint64_t kBigUint = + static_cast(internal::kMaxDoubleRepresentableAsUint); + ASSERT_OK_AND_ASSIGN( + std::unique_ptr ast, + CompileForTestCase(absl::StrCat( + "nested_test_all_types.payload.single_any[", kBigUint, ".0].bb"))); + + SelectOptimizationAstUpdater updater; + EXPECT_THAT(updater.UpdateAst(context_, *ast), IsOk()); + EXPECT_EQ(ast->root_expr().select_expr().field(), "bb"); + // TODO(uncreated-issue/51): Right now we don't optimize past a dyn/any field + // and discard additional select steps. + EXPECT_THAT(ast->root_expr().select_expr().operand().call_expr().function(), + "cel.@attribute"); + ASSERT_THAT(ast->root_expr().select_expr().operand().call_expr().args(), + SizeIs(2)); + EXPECT_THAT(ast->root_expr() + .select_expr() + .operand() + .call_expr() + .args()[1] + .list_expr() + .elements(), + SizeIs(3)); +} + +TEST_F(SelectOptimizationTest, AstTransformFilterToMessageRoot) { + // {'field_like_key': + // nested_test_all_types}.field_like_key.payload.single_int64 + ASSERT_OK_AND_ASSIGN( + std::unique_ptr ast, + CompileForTestCase( + "{'field_like_key': " + "nested_test_all_types}.field_like_key.payload.single_int64")); + + SelectOptimizationAstUpdater updater; + EXPECT_THAT(updater.UpdateAst(context_, *ast), IsOk()); + + const auto& attr_call = ast->root_expr().call_expr(); + EXPECT_EQ(attr_call.function(), "cel.@attribute"); + + ASSERT_THAT(attr_call.args(), SizeIs(2)); + + EXPECT_THAT(attr_call.args()[1].list_expr().elements(), + ElementsAre(SelectFieldEntry(2, "payload"), + SelectFieldEntry(2, "single_int64"))); + + const auto& operand = attr_call.args()[0]; + + EXPECT_EQ(operand.select_expr().field(), "field_like_key"); +} + +TEST_F(SelectOptimizationTest, AstTransformMapDotTraversal) { + // nested_test_all_types.payload.map_string_message.field_like_key.bb + ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, + CompileForTestCase("nested_test_all_types.payload.map_" + "string_message.field_like_key.bb")); + + SelectOptimizationAstUpdater updater; + EXPECT_THAT(updater.UpdateAst(context_, *ast), IsOk()); + + const auto& attr_call = ast->root_expr().call_expr(); + EXPECT_EQ(attr_call.function(), "cel.@attribute"); + + ASSERT_THAT(attr_call.args(), SizeIs(2)); + + EXPECT_THAT(attr_call.args()[1].list_expr().elements(), + ElementsAre(SelectFieldEntry(2, "payload"), + SelectFieldEntry(227, "map_string_message"), + SelectQualifier( + AttributeQualifier::OfString("field_like_key")), + SelectFieldEntry(1, "bb"))); + + const auto& operand = attr_call.args()[0]; + + EXPECT_EQ(operand.ident_expr().name(), "nested_test_all_types"); +} + +TEST_F(SelectOptimizationTest, AstTransformAnyDotTraversal) { + ASSERT_OK_AND_ASSIGN( + std::unique_ptr ast, + CompileForTestCase( + "nested_test_all_types.payload.single_any.single_int64")); + + SelectOptimizationAstUpdater updater; + EXPECT_THAT(updater.UpdateAst(context_, *ast), IsOk()); + + // When fully supported, we'd expect this to collapse to one attribute call. + const auto& attr_call = ast->root_expr().select_expr().operand().call_expr(); + EXPECT_EQ(attr_call.function(), "cel.@attribute"); + + ASSERT_THAT(attr_call.args(), SizeIs(2)); + + EXPECT_THAT(attr_call.args()[1].list_expr().elements(), + ElementsAre(SelectFieldEntry(2, "payload"), + SelectFieldEntry(100, "single_any"))); + + const auto& operand = attr_call.args()[0]; + + EXPECT_EQ(operand.ident_expr().name(), "nested_test_all_types"); +} + +TEST_F(SelectOptimizationTest, AstTransformRepeated) { + // nested_test_all_types.payload.repeated_nested_message[1].bb + ASSERT_OK_AND_ASSIGN( + std::unique_ptr ast, + CompileForTestCase( + "nested_test_all_types.payload.repeated_nested_message[1].bb")); + + SelectOptimizationAstUpdater updater; + EXPECT_THAT(updater.UpdateAst(context_, *ast), IsOk()); + + // When fully supported, we'd expect this to collapse to one attribute call. + const auto& attr_call = ast->root_expr().call_expr(); + EXPECT_EQ(attr_call.function(), "cel.@attribute"); + + ASSERT_THAT(attr_call.args(), SizeIs(2)); + + EXPECT_THAT(attr_call.args()[1].list_expr().elements(), + ElementsAre(SelectFieldEntry(2, "payload"), + SelectFieldEntry(51, "repeated_nested_message"), + SelectQualifier(AttributeQualifier::OfInt(1)), + SelectFieldEntry(1, "bb"))); + + const auto& operand = attr_call.args()[0]; + + EXPECT_EQ(operand.ident_expr().name(), "nested_test_all_types"); +} + +TEST_F(SelectOptimizationTest, AstTransformParseOnlyNotUpdated) { + google::protobuf::LinkMessageReflection(); + + FlatExprBuilder builder(env_, runtime_options_); + + builder.AddAstTransform(std::make_unique()); + + // nested_test_all_types.payload.repeated_nested_message[1].bb + ASSERT_OK_AND_ASSIGN( + ParsedExpr expr, + Parse("nested_test_all_types.payload.repeated_nested_message[1].bb")); + ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, + CreateAstFromParsedExpr(expr)); + ASSERT_OK_AND_ASSIGN(FlatExpression plan, + builder.CreateExpressionImpl(std::move(ast), nullptr)); + + NestedTestAllTypes var; + + var.mutable_payload()->add_repeated_nested_message(); + var.mutable_payload()->add_repeated_nested_message()->set_bb(42); + + cel::Activation act; + ASSERT_THAT(TestBindLegacyMessage("nested_test_all_types", var, &arena_, act), + IsOk()); + + auto state = plan.MakeEvaluatorState(env_->descriptor_pool.get(), + env_->MutableMessageFactory(), &arena_); + ASSERT_OK_AND_ASSIGN( + Value result, + plan.EvaluateWithCallback( + act, /*embedder_context=*/nullptr, + google::api::expr::runtime::EvaluationListener(), state)); + + ASSERT_TRUE(result->Is()) << result->DebugString(); + + EXPECT_EQ(result.GetInt().NativeValue(), 42); +} + +TEST_F(SelectOptimizationTest, ProgramOptimizerUnoptimizedAst) { + google::protobuf::LinkMessageReflection(); + + FlatExprBuilder builder(env_, runtime_options_); + + builder.AddProgramOptimizer(CreateSelectOptimizationProgramOptimizer()); + + // nested_test_all_types.child.payload.standalone_message.bb + ASSERT_OK_AND_ASSIGN( + std::unique_ptr ast, + CompileForTestCase( + "nested_test_all_types.child.payload.standalone_message.bb")); + + ASSERT_OK_AND_ASSIGN(FlatExpression plan, + builder.CreateExpressionImpl(std::move(ast), nullptr)); + + NestedTestAllTypes var; + + var.mutable_child()->mutable_payload()->mutable_standalone_message()->set_bb( + 42); + + cel::Activation act; + ASSERT_THAT(TestBindLegacyMessage("nested_test_all_types", var, &arena_, act), + IsOk()); + + auto state = plan.MakeEvaluatorState(env_->descriptor_pool.get(), + env_->MutableMessageFactory(), &arena_); + ASSERT_OK_AND_ASSIGN( + Value result, + plan.EvaluateWithCallback( + act, /*embedder_context=*/nullptr, + google::api::expr::runtime::EvaluationListener(), state)); + + ASSERT_TRUE(result->Is()) << result->DebugString(); + + EXPECT_EQ(result.GetInt().NativeValue(), 42); +} + +TEST_F(SelectOptimizationTest, MissingAttributeIndependentOfUnknown) { + google::protobuf::LinkMessageReflection(); + + RuntimeOptions options = runtime_options_; + options.unknown_processing = UnknownProcessingOptions::kDisabled; + options.enable_missing_attribute_errors = true; + + FlatExprBuilder builder(env_, options); + + builder.AddAstTransform(std::make_unique()); + builder.AddProgramOptimizer(CreateSelectOptimizationProgramOptimizer()); + + ASSERT_OK_AND_ASSIGN( + std::unique_ptr ast, + CompileForTestCase("custom_predicate(nested_test_all_types.child.payload." + "standalone_message)")); + + ASSERT_OK_AND_ASSIGN(FlatExpression plan, + builder.CreateExpressionImpl(std::move(ast), nullptr)); + + cel::Activation act; + // activation only uses a ptr to the underlying message, persist them. + NestedTestAllTypes var; + + act.SetMissingPatterns( + {AttributePattern("nested_test_all_types", + { + AttributeQualifierPattern::OfString("child"), + AttributeQualifierPattern::OfString("payload"), + })}); + + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb( + child { payload { standalone_message { bb: 20 } } } + )pb", + &var)); + ASSERT_THAT(TestBindLegacyMessage("nested_test_all_types", var, &arena_, act), + IsOk()); + + auto state = plan.MakeEvaluatorState(env_->descriptor_pool.get(), + env_->MutableMessageFactory(), &arena_); + ASSERT_OK_AND_ASSIGN( + Value result, + plan.EvaluateWithCallback( + act, /*embedder_context=*/nullptr, + google::api::expr::runtime::EvaluationListener(), state)); + + ASSERT_TRUE(result->Is()) << result->DebugString(); + EXPECT_THAT(result.GetError().NativeValue(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("nested_test_all_types.child.payload"))); +} + +TEST_F(SelectOptimizationTest, NullUnboxingOptionHonored) { + google::protobuf::LinkMessageReflection(); + + RuntimeOptions options = runtime_options_; + options.enable_empty_wrapper_null_unboxing = true; + + FlatExprBuilder builder(env_, options); + + builder.AddAstTransform(std::make_unique()); + builder.AddProgramOptimizer(CreateSelectOptimizationProgramOptimizer()); + + // nested_test_all_types.payload.single_int64_wrapper + ASSERT_OK_AND_ASSIGN( + std::unique_ptr ast, + CompileForTestCase("nested_test_all_types.payload.single_int64_wrapper")); + + ASSERT_OK_AND_ASSIGN(FlatExpression plan, + builder.CreateExpressionImpl(std::move(ast), nullptr)); + + cel::Activation act; + // activation only uses a ptr to the underlying message, persist them. + NestedTestAllTypes var; + + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb( + payload {} + )pb", + &var)); + ASSERT_THAT(TestBindLegacyMessage("nested_test_all_types", var, &arena_, act), + IsOk()); + + auto state = plan.MakeEvaluatorState(env_->descriptor_pool.get(), + env_->MutableMessageFactory(), &arena_); + ASSERT_OK_AND_ASSIGN( + Value result, + plan.EvaluateWithCallback( + act, /*embedder_context=*/nullptr, + google::api::expr::runtime::EvaluationListener(), state)); + + ASSERT_TRUE(result->Is()) << result->DebugString(); +} + +using ActivationSetupFn = + std::function; + +struct ProgramOptimizerTestCase { + std::string case_name; + std::string expr; + // identifier -> NestedTestAllTypes textproto + absl::flat_hash_map vars; + ActivationSetupFn setup_activation; + std::function&)> expectations; +}; + +class SelectOptimizationProgramOptimizerTest + : public SelectOptimizationTest, + public testing::WithParamInterface {}; + +TEST_P(SelectOptimizationProgramOptimizerTest, Default) { + const ProgramOptimizerTestCase& test_case = GetParam(); + google::protobuf::LinkMessageReflection(); + + RuntimeOptions options = runtime_options_; + options.unknown_processing = UnknownProcessingOptions::kAttributeAndFunction; + options.enable_missing_attribute_errors = true; + + FlatExprBuilder builder(env_, options); + + builder.AddAstTransform(std::make_unique()); + builder.AddProgramOptimizer(CreateSelectOptimizationProgramOptimizer()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, + CompileForTestCase(test_case.expr)); + + ASSERT_OK_AND_ASSIGN(FlatExpression plan, + builder.CreateExpressionImpl(std::move(ast), nullptr)); + + cel::Activation act; + // activation only uses a ptr to the underlying message, persist them. + std::vector> vars; + for (const auto& entry : test_case.vars) { + vars.push_back(std::make_unique()); + ASSERT_TRUE( + google::protobuf::TextFormat::ParseFromString(entry.second, vars.back().get())); + ASSERT_THAT(TestBindLegacyMessage(entry.first, *vars.back(), &arena_, act), + IsOk()); + } + + if (test_case.setup_activation != nullptr) { + ASSERT_THAT(test_case.setup_activation(&arena_, act), IsOk()); + } + + auto state = plan.MakeEvaluatorState(env_->descriptor_pool.get(), + env_->MutableMessageFactory(), &arena_); + absl::StatusOr result = plan.EvaluateWithCallback( + act, /*embedder_context=*/nullptr, + google::api::expr::runtime::EvaluationListener(), state); + + ASSERT_NO_FATAL_FAILURE(test_case.expectations(result)); +} + +TEST_P(SelectOptimizationProgramOptimizerTest, ForceFallbackImpl) { + const ProgramOptimizerTestCase& test_case = GetParam(); + google::protobuf::LinkMessageReflection(); + + RuntimeOptions options = runtime_options_; + options.unknown_processing = UnknownProcessingOptions::kAttributeAndFunction; + options.enable_missing_attribute_errors = true; + + FlatExprBuilder builder(env_, options); + SelectOptimizationOptions select_options; + select_options.force_fallback_implementation = true; + + builder.AddAstTransform(std::make_unique()); + builder.AddProgramOptimizer( + CreateSelectOptimizationProgramOptimizer(select_options)); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, + CompileForTestCase(test_case.expr)); + + ASSERT_OK_AND_ASSIGN(FlatExpression plan, + builder.CreateExpressionImpl(std::move(ast), nullptr)); + + cel::Activation act; + // activation only uses a ptr to the underlying message, persist them. + std::vector> vars; + for (const auto& entry : test_case.vars) { + vars.push_back(std::make_unique()); + ASSERT_TRUE( + google::protobuf::TextFormat::ParseFromString(entry.second, vars.back().get())); + ASSERT_THAT(TestBindLegacyMessage(entry.first, *vars.back(), &arena_, act), + IsOk()); + } + + if (test_case.setup_activation != nullptr) { + ASSERT_THAT(test_case.setup_activation(&arena_, act), IsOk()); + } + + auto state = plan.MakeEvaluatorState(env_->descriptor_pool.get(), + env_->MutableMessageFactory(), &arena_); + absl::StatusOr result = plan.EvaluateWithCallback( + act, /*embedder_context=*/nullptr, + google::api::expr::runtime::EvaluationListener(), state); + + ASSERT_NO_FATAL_FAILURE(test_case.expectations(result)); +} + +INSTANTIATE_TEST_SUITE_P( + TestCases, SelectOptimizationProgramOptimizerTest, + testing::ValuesIn({ + { + "chained_select_success", + "nested_test_all_types.child.payload.standalone_message.bb", + {{"nested_test_all_types", + R"pb( + child { payload { standalone_message { bb: 42 } } } + )pb"}}, + ActivationSetupFn(), + [](const absl::StatusOr& got) { + ASSERT_OK_AND_ASSIGN(Value result, got); + ASSERT_TRUE(result->Is()) << result->DebugString(); + EXPECT_EQ(result.GetInt().NativeValue(), 42); + }, + }, + { + "chained_select_defaults_success", + "nested_test_all_types.child.payload.standalone_message.bb", + {{"nested_test_all_types", R"pb()pb"}}, + ActivationSetupFn(), + + [](const absl::StatusOr& got) { + ASSERT_OK_AND_ASSIGN(Value result, got); + ASSERT_TRUE(result->Is()) << result->DebugString(); + EXPECT_EQ(result.GetInt().NativeValue(), 0); + }, + }, + { + "chained_select_partial_success", + "nested_test_all_types.child.payload.standalone_message.bb", + {}, + [](google::protobuf::Arena* arena, Activation& act) { + auto mock_pair = MakeMockLegacyMessage(arena); + MockAccessApis* mock = mock_pair.first; + CelValue mocked_value = mock_pair.second; + ON_CALL(*mock, Qualify(SizeIs(4), _, /*presence_test=*/false, _)) + .WillByDefault( + Return(LegacyTypeAccessApis::LegacyQualifyResult{ + mocked_value, 3})); + ON_CALL(*mock, GetField("bb", _, _, _)) + .WillByDefault(Return(CelValue::CreateInt64(42))); + + // Support the forced-fallback case. + ON_CALL(*mock, GetField(AnyOf(Eq("child"), Eq("payload"), + Eq("standalone_message")), + _, _, _)) + .WillByDefault(Return(mocked_value)); + + return TestBindLegacyValue("nested_test_all_types", mocked_value, + arena, act); + }, + [](const absl::StatusOr& got) { + ASSERT_OK_AND_ASSIGN(Value result, got); + ASSERT_TRUE(result->Is()) << result->DebugString(); + EXPECT_EQ(result.GetInt().NativeValue(), 42); + }, + }, + { + "chained_select_presence_partial_present", + "has(nested_test_all_types.child.payload.standalone_message.bb)", + {}, + [](google::protobuf::Arena* arena, Activation& act) { + auto mock_pair = MakeMockLegacyMessage(arena); + MockAccessApis* mock = mock_pair.first; + CelValue mocked_value = mock_pair.second; + ON_CALL(*mock, Qualify(SizeIs(4), _, /*presence_test=*/true, _)) + .WillByDefault( + Return(LegacyTypeAccessApis::LegacyQualifyResult{ + mocked_value, 3})); + ON_CALL(*mock, HasField("bb", _)).WillByDefault(Return(true)); + ON_CALL(*mock, GetField("bb", _, _, _)) + .WillByDefault(Return(CelValue::CreateInt64(42))); + + // Support the forced-fallback case. + ON_CALL(*mock, GetField(AnyOf(Eq("child"), Eq("payload"), + Eq("standalone_message")), + _, _, _)) + .WillByDefault(Return(mocked_value)); + ON_CALL(*mock, HasField(AnyOf(Eq("child"), Eq("payload"), + Eq("standalone_message")), + _)) + .WillByDefault(Return(true)); + + return TestBindLegacyValue("nested_test_all_types", mocked_value, + arena, act); + }, + [](const absl::StatusOr& got) { + ASSERT_OK_AND_ASSIGN(Value result, got); + ASSERT_TRUE(result->Is()) << result->DebugString(); + EXPECT_TRUE(result.GetBool().NativeValue()); + }, + }, + { + "chained_select_not_bound", + "nested_test_all_types.child.payload.standalone_message.bb", + {}, // not set + ActivationSetupFn(), + + [](const absl::StatusOr& got) { + ASSERT_OK_AND_ASSIGN(Value result, got); + ASSERT_TRUE(result->Is()) << result->DebugString(); + EXPECT_THAT(result.GetError().NativeValue(), + StatusIs(absl::StatusCode::kUnknown, + HasSubstr("nested_test_all_types"))); + }, + }, + { + // Some clients will use maps to represent a protobuf message at + // runtime. This is not yet supported. + "chained_select_map_as_root_unsupported", + "nested_test_all_types.child.payload.standalone_message.bb", + {}, // not set + [](google::protobuf::Arena* arena, Activation& act) -> absl::Status { + auto builder = cel::NewMapValueBuilder(arena); + CEL_RETURN_IF_ERROR( + builder->Put(cel::StringValue("child"), cel::NullValue())); + + auto value = std::move(*builder).Build(); + + act.InsertOrAssignValue("nested_test_all_types", + std::move(value)); + + return absl::OkStatus(); + }, + + [](const absl::StatusOr& got) { + EXPECT_THAT(got.status(), + StatusIs(absl::StatusCode::kInvalidArgument)); + }, + }, + { + // Some clients will use maps to represent a protobuf at runtime, + // this is not yet supported. + "chained_select_noncontainer_as_root_unsupported", + "nested_test_all_types.child.payload.standalone_message.bb", + {}, // not set + [](google::protobuf::Arena* arena, Activation& act) { + act.InsertOrAssignValue("nested_test_all_types", + cel::DurationValue(absl::Seconds(1))); + return absl::OkStatus(); + }, + + [](const absl::StatusOr& got) { + EXPECT_THAT(got.status(), + StatusIs(absl::StatusCode::kInvalidArgument)); + }, + }, + { + "complex_select_success", + "((false)? a.child.child : b.child).child.payload.single_int64", + {{"a", ""}, + {"b", + R"pb( + child { child { payload { single_int64: -42 } } } + )pb"}}, + ActivationSetupFn(), + + [](const absl::StatusOr& got) { + ASSERT_OK_AND_ASSIGN(Value result, got); + ASSERT_TRUE(result->Is()) << result->DebugString(); + EXPECT_EQ(result.GetInt().NativeValue(), -42); + }, + }, + { + "chained_select_presence_present", + "has(nested_test_all_types.child.payload.standalone_message.bb)", + {{"nested_test_all_types", + R"pb( + child { payload { standalone_message { bb: 2 } } } + )pb"}}, + ActivationSetupFn(), + [](const absl::StatusOr& got) { + ASSERT_OK_AND_ASSIGN(Value result, got); + ASSERT_TRUE(result->Is()) << result->DebugString(); + EXPECT_TRUE(result.GetBool().NativeValue()); + }, + }, + { + "chained_select_presence_not_present", + "has(nested_test_all_types.child.payload.standalone_message.bb)", + {{"nested_test_all_types", + R"pb( + child { payload { standalone_message {} } } + )pb"}}, + ActivationSetupFn(), + [](const absl::StatusOr& got) { + ASSERT_OK_AND_ASSIGN(Value result, got); + ASSERT_TRUE(result->Is()) << result->DebugString(); + EXPECT_FALSE(result.GetBool().NativeValue()); + }, + }, + { + "select_with_map_supported", + "nested_test_all_types.payload.map_string_message['$not_a_field']." + "bb", + {{"nested_test_all_types", + R"pb( + payload { + map_string_message { + key: "$not_a_field", + value { bb: 5 } + } + } + )pb"}}, + ActivationSetupFn(), + + [](const absl::StatusOr& got) { + ASSERT_OK_AND_ASSIGN(Value result, got); + ASSERT_TRUE(result->Is()) << result->DebugString(); + EXPECT_EQ(result.GetInt().NativeValue(), 5); + }, + }, + { + "select_with_map_no_such_key", + "nested_test_all_types.payload.map_string_message['$not_a_field']." + "bb", + {{"nested_test_all_types", + R"pb( + payload { + map_string_message { + key: "a_different_field", + value { bb: 5 } + } + } + )pb"}}, + ActivationSetupFn(), + [](const absl::StatusOr& got) { + ASSERT_OK_AND_ASSIGN(Value result, got); + ASSERT_TRUE(result->Is()) << result->DebugString(); + EXPECT_THAT(result.GetError().NativeValue(), + StatusIs(absl::StatusCode::kNotFound, + AllOf(HasSubstr("Key not found"), + HasSubstr("$not_a_field")))); + }, + }, + { + "select_with_repeated_supported", + "nested_test_all_types.payload.repeated_nested_message[1].bb", + {{"nested_test_all_types", + R"pb( + payload { + repeated_nested_message {} + repeated_nested_message { bb: 7 } + } + )pb"}}, + ActivationSetupFn(), + [](const absl::StatusOr& got) { + ASSERT_OK_AND_ASSIGN(Value result, got); + ASSERT_TRUE(result->Is()) << result->DebugString(); + EXPECT_EQ(result.GetInt().NativeValue(), 7); + }, + }, + { + "select_with_repeated_index_out_of_bounds", + "nested_test_all_types.payload.repeated_nested_message[1].bb", + {{"nested_test_all_types", + R"pb( + payload { repeated_nested_message {} } + )pb"}}, + ActivationSetupFn(), + [](const absl::StatusOr& got) { + ASSERT_OK_AND_ASSIGN(Value result, got); + ASSERT_TRUE(result->Is()) << result->DebugString(); + EXPECT_THAT(result.GetError().NativeValue(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("index out of bounds"))); + }, + }, + { + "unknown_field", + "((false)? a.child.child : b.child).child.payload.single_int64", + {{"a", ""}, + {"b", + R"pb( + child { child { payload { single_int64: -42 } } } + )pb"}}, + [](google::protobuf::Arena*, Activation& act) { + act.SetUnknownPatterns({AttributePattern( + "b", {AttributeQualifierPattern::OfString("child"), + AttributeQualifierPattern::OfString("child")})}); + return absl::OkStatus(); + }, + + [](const absl::StatusOr& got) { + ASSERT_OK_AND_ASSIGN(Value result, got); + ASSERT_TRUE(result->Is()) << result->DebugString(); + EXPECT_THAT( + result.GetUnknown().attribute_set(), + ElementsAre(Eq(Attribute( + "b", { + AttributeQualifier::OfString("child"), + AttributeQualifier::OfString("child"), + AttributeQualifier::OfString("payload"), + AttributeQualifier::OfString("single_int64"), + })))); + }, + }, + { + "unknown_field_partial", + "((false)? a.child.child : b.child).child.payload.single_int64", + {{"a", ""}, + {"b", + R"pb( + child { child { payload { single_int64: -42 } } } + )pb"}}, + [](google::protobuf::Arena*, Activation& act) { + act.SetUnknownPatterns({AttributePattern( + "b", {AttributeQualifierPattern::OfString("child"), + AttributeQualifierPattern::OfString("child"), + AttributeQualifierPattern::OfString("child")})}); + return absl::OkStatus(); + }, + + [](const absl::StatusOr& got) { + ASSERT_OK_AND_ASSIGN(Value result, got); + ASSERT_TRUE(result->Is()) << result->DebugString(); + EXPECT_EQ(result.GetInt().NativeValue(), -42); + }, + }, + { + "unknown_ident", + "((false)? a.child.child : b.child).child.payload.single_int64", + {{"a", ""}, + {"b", + R"pb( + child { child { payload { single_int64: -42 } } } + )pb"}}, + [](google::protobuf::Arena*, Activation& act) { + act.SetUnknownPatterns({ + AttributePattern("b", {}), + }); + return absl::OkStatus(); + }, + + [](const absl::StatusOr& got) { + ASSERT_OK_AND_ASSIGN(Value result, got); + ASSERT_TRUE(result->Is()) << result->DebugString(); + EXPECT_THAT(result.GetUnknown().attribute_set(), + ElementsAre(Truly([](const Attribute& attr) { + return attr.variable_name() == "b"; + }))); + }, + }, + { + "unknown_pruned", + "((false)? a.child.child : b.child).child.payload.single_int64", + {{"a", ""}, + {"b", + R"pb( + child { child { payload { single_int64: -42 } } } + )pb"}}, + [](google::protobuf::Arena*, Activation& act) { + act.SetUnknownPatterns({ + AttributePattern("a", {}), + }); + return absl::OkStatus(); + }, + [](const absl::StatusOr& got) { + ASSERT_OK_AND_ASSIGN(Value result, got); + ASSERT_TRUE(result->Is()) << result->DebugString(); + EXPECT_EQ(result.GetInt().NativeValue(), -42); + }, + }, + { + "missing_field", + "custom_predicate(nested_test_all_types.child.payload.standalone_" + "message)", + {{"nested_test_all_types", + R"pb( + child { payload { standalone_message { bb: 20 } } } + )pb"}}, + [](google::protobuf::Arena*, Activation& act) { + act.SetMissingPatterns({AttributePattern( + "nested_test_all_types", + { + AttributeQualifierPattern::OfString("child"), + })}); + return absl::OkStatus(); + }, + + [](const absl::StatusOr& got) { + ASSERT_OK_AND_ASSIGN(Value result, got); + ASSERT_TRUE(result->Is()) << result->DebugString(); + EXPECT_THAT(result.GetError().NativeValue().message(), + HasSubstr("nested_test_all_types.child.payload." + "standalone_message")); + }, + }, + { + "missing_field_partial", + "custom_predicate(nested_test_all_types.child.payload.standalone_" + "message)", + {{"nested_test_all_types", + R"pb( + child { payload { standalone_message { bb: 20 } } } + )pb"}}, + [](google::protobuf::Arena*, Activation& act) { + act.SetMissingPatterns({AttributePattern( + "b", {AttributeQualifierPattern::OfString("child"), + AttributeQualifierPattern::OfString("child")})}); + return absl::OkStatus(); + }, + + [](const absl::StatusOr& got) { + ASSERT_OK_AND_ASSIGN(Value result, got); + ASSERT_TRUE(result->Is()) << result->DebugString(); + EXPECT_TRUE(result.GetBool().NativeValue()); + }, + }, + { + "select_wrapper_int_leaf", + "nested_test_all_types.payload.single_int64_wrapper", + {{"nested_test_all_types", + R"pb( + payload { single_int64_wrapper { value: 10 } } + )pb"}}, + nullptr, + [](const absl::StatusOr& got) { + ASSERT_OK_AND_ASSIGN(Value result, got); + ASSERT_TRUE(result->Is()) << result->DebugString(); + EXPECT_EQ(result.GetInt().NativeValue(), 10); + }, + }, + { + "select_repeated_leaf", + "nested_test_all_types.payload.repeated_int64", + {{"nested_test_all_types", + R"pb( + payload { repeated_int64: 10 } + )pb"}}, + nullptr, + [](const absl::StatusOr& got) { + ASSERT_OK_AND_ASSIGN(Value result, got); + ASSERT_TRUE(result->Is()) << result->DebugString(); + }, + }, + { + "select_map_leaf", + "nested_test_all_types.payload.map_string_int64", + {{"nested_test_all_types", + R"pb( + payload { map_string_int64 { key: "key", value: 12 } } + )pb"}}, + nullptr, + [](const absl::StatusOr& got) { + ASSERT_OK_AND_ASSIGN(Value result, got); + ASSERT_TRUE(result->Is()) << result->DebugString(); + }, + }, + { + "select_with_map_dot", + "nested_test_all_types.payload.map_string_message.field_like_key." + "bb", + {{"nested_test_all_types", + R"pb( + payload { + map_string_message { + key: "field_like_key", + value { bb: 42 } + } + } + )pb"}}, + nullptr, + [](const absl::StatusOr& got) { + ASSERT_OK_AND_ASSIGN(Value result, got); + ASSERT_TRUE(result->Is()) << result->DebugString(); + EXPECT_EQ(result.GetInt().NativeValue(), 42); + }, + }, + { + "select_with_map_bool", + "nested_test_all_types.payload.map_bool_message[false].bb", + {{"nested_test_all_types", + R"pb( + payload { + map_bool_message { + key: false, + value { bb: 42 } + } + } + )pb"}}, + nullptr, + [](const absl::StatusOr& got) { + ASSERT_OK_AND_ASSIGN(Value result, got); + ASSERT_TRUE(result->Is()) << result->DebugString(); + EXPECT_EQ(result.GetInt().NativeValue(), 42); + }, + }, + { + "select_with_map_int", + "nested_test_all_types.payload.map_int64_message[-1].bb", + {{"nested_test_all_types", + R"pb( + payload { + map_int64_message { + key: -1, + value { bb: 42 } + } + } + )pb"}}, + nullptr, + [](const absl::StatusOr& got) { + ASSERT_OK_AND_ASSIGN(Value result, got); + ASSERT_TRUE(result->Is()) << result->DebugString(); + EXPECT_EQ(result.GetInt().NativeValue(), 42); + }, + }, + { + "select_with_map_uint", + "nested_test_all_types.payload.map_uint64_message[1u].bb", + {{"nested_test_all_types", + R"pb( + payload { + map_uint64_message { + key: 1, + value { bb: 42 } + } + } + )pb"}}, + nullptr, + [](const absl::StatusOr& got) { + ASSERT_OK_AND_ASSIGN(Value result, got); + ASSERT_TRUE(result->Is()) << result->DebugString(); + EXPECT_EQ(result.GetInt().NativeValue(), 42); + }, + }, + { + "select_with_repeated", + "nested_test_all_types.payload.repeated_nested_message[1].bb", + {{"nested_test_all_types", + R"pb( + payload { + repeated_nested_message {} + repeated_nested_message { bb: 42 } + } + )pb"}}, + nullptr, + [](const absl::StatusOr& got) { + ASSERT_OK_AND_ASSIGN(Value result, got); + ASSERT_TRUE(result->Is()) << result->DebugString(); + EXPECT_EQ(result.GetInt().NativeValue(), 42); + }, + }, + { + "select_with_any", + "nested_test_all_types.payload.single_any.single_int64", + {{"nested_test_all_types", + R"pb( + payload { + single_any { + [type.googleapis.com/cel.expr.conformance.proto2 + .TestAllTypes] { single_int64: 42 } + } + } + )pb"}}, + nullptr, + [](const absl::StatusOr& got) { + ASSERT_OK_AND_ASSIGN(Value result, got); + ASSERT_TRUE(result->Is()) << result->DebugString(); + EXPECT_EQ(result.GetInt().NativeValue(), 42); + }, + }, + { + "has_repeated_leaf_true", + "has(nested_test_all_types.payload.repeated_int64)", + {{"nested_test_all_types", + R"pb( + payload { repeated_int64: 42 } + )pb"}}, + nullptr, + [](const absl::StatusOr& got) { + ASSERT_OK_AND_ASSIGN(Value result, got); + ASSERT_TRUE(result->Is()) << result->DebugString(); + EXPECT_TRUE(result.GetBool().NativeValue()); + }, + }, + { + "has_repeated_leaf_false", + "has(nested_test_all_types.payload.repeated_int64)", + {{"nested_test_all_types", + R"pb( + payload {} + )pb"}}, + nullptr, + [](const absl::StatusOr& got) { + ASSERT_OK_AND_ASSIGN(Value result, got); + ASSERT_TRUE(result->Is()) << result->DebugString(); + EXPECT_FALSE(result.GetBool().NativeValue()); + }, + }, + { + "has_map_leaf_true", + "has(nested_test_all_types.payload.map_string_int64)", + {{"nested_test_all_types", + R"pb( + payload { map_string_int64 { key: "string" value: 12 } } + )pb"}}, + nullptr, + [](const absl::StatusOr& got) { + ASSERT_OK_AND_ASSIGN(Value result, got); + ASSERT_TRUE(result->Is()) << result->DebugString(); + EXPECT_TRUE(result.GetBool().NativeValue()); + }, + }, + { + "has_map_leaf_false", + "has(nested_test_all_types.payload.map_string_int64)", + {{"nested_test_all_types", + R"pb( + payload {} + )pb"}}, + nullptr, + [](const absl::StatusOr& got) { + ASSERT_OK_AND_ASSIGN(Value result, got); + ASSERT_TRUE(result->Is()) << result->DebugString(); + EXPECT_FALSE(result.GetBool().NativeValue()); + }, + }, + { + "has_map_field_like_key", + "has(nested_test_all_types.payload.map_string_int64.field_like_" + "key)", + {{"nested_test_all_types", + R"pb( + payload { map_string_int64 { key: "field_like_key" value: 12 } } + )pb"}}, + nullptr, + [](const absl::StatusOr& got) { + ASSERT_OK_AND_ASSIGN(Value result, got); + ASSERT_TRUE(result->Is()) << result->DebugString(); + EXPECT_TRUE(result.GetBool().NativeValue()); + }, + }, + { + "has_map_field_like_key_false", + "has(nested_test_all_types.payload.map_string_int64.field_like_" + "key)", + {{"nested_test_all_types", + R"pb( + payload { map_string_int64 { key: "wrong_key" value: 12 } } + )pb"}}, + nullptr, + [](const absl::StatusOr& got) { + ASSERT_OK_AND_ASSIGN(Value result, got); + ASSERT_TRUE(result->Is()) << result->DebugString(); + EXPECT_FALSE(result.GetBool().NativeValue()); + }, + }, + { + "select_wrong_runtime_type", + "test_all_types.single_int64", + {{}}, + [](google::protobuf::Arena* arena, Activation& activation) { + activation.InsertOrAssignValue("test_all_types", + cel::IntValue(42)); + return absl::OkStatus(); + }, + [](const absl::StatusOr& got) { + EXPECT_THAT(got, StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Expected struct type"))); + }, + }, + { + "select_with_struct", + "nested_test_all_types.payload.single_struct['key']['subkey']", + {{"nested_test_all_types", + R"pb(payload { + single_struct { + fields { + key: "key" + value { + struct_value { + fields { + key: "subkey" + value { bool_value: true } + } + } + } + } + } + })pb"}}, + nullptr, + [](const absl::StatusOr& got) { + ASSERT_OK_AND_ASSIGN(Value result, got); + ASSERT_TRUE(result->Is()) << result->DebugString(); + EXPECT_TRUE(result.GetBool().NativeValue()); + }, + }, + { + "select_with_list_value", + "nested_test_all_types.payload.list_value[0]['subkey']", + {{"nested_test_all_types", + R"pb(payload { + list_value { + values { + struct_value { + fields { + key: "subkey" + value { bool_value: true } + } + } + } + } + })pb"}}, + nullptr, + [](const absl::StatusOr& got) { + ASSERT_OK_AND_ASSIGN(Value result, got); + ASSERT_TRUE(result->Is()) << result->DebugString(); + EXPECT_TRUE(result.GetBool().NativeValue()); + }, + }, + { + "select_with_value", + "nested_test_all_types.payload.single_value['key']['subkey']", + {{"nested_test_all_types", + R"pb(payload { + single_value { + struct_value { + fields { + key: "key" + value { + struct_value { + fields { + key: "subkey" + value { bool_value: true } + } + } + } + } + } + } + })pb"}}, + nullptr, + [](const absl::StatusOr& got) { + ASSERT_OK_AND_ASSIGN(Value result, got); + ASSERT_TRUE(result->Is()) << result->DebugString(); + EXPECT_TRUE(result.GetBool().NativeValue()); + }, + }, + }), + + [](const testing::TestParamInfo& info) { + return info.param.case_name; + }); + +// Tests covering unexpected / malformed ASTs. +// +// These cases shouldn't be possible under normal usage, but are possible if +// there's a bug in the optimizer implementation or if a hand-rolled AST is +// used. +class SelectOptimizationUnexpectedAstTest : public SelectOptimizationTest { + public: + SelectOptimizationUnexpectedAstTest() + : SelectOptimizationTest(), next_id_(1) {} + + Expr NextExpr() { + Expr result; + result.set_id(next_id_++); + return result; + } + + cel::ListExprElement NextListExprElement() { + cel::ListExprElement element; + element.set_expr(NextExpr()); + return element; + } + + protected: + int64_t next_id_; +}; + +TEST_F(SelectOptimizationUnexpectedAstTest, WrongArgumentCount) { + std::unique_ptr ast = std::make_unique(NextExpr(), SourceInfo()); + + ast->mutable_root_expr().mutable_call_expr().set_function(kCelAttribute); + ast->mutable_root_expr() + .mutable_call_expr() + .mutable_args() + .emplace_back(NextExpr()) + .mutable_ident_expr() + .set_name("ident"); + + FlatExprBuilder builder(env_, runtime_options_); + + builder.AddProgramOptimizer(CreateSelectOptimizationProgramOptimizer()); + + EXPECT_THAT(builder.CreateExpressionImpl(std::move(ast), nullptr), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST_F(SelectOptimizationUnexpectedAstTest, EmptySelectPath) { + std::unique_ptr ast = std::make_unique(NextExpr(), SourceInfo()); + + ast->mutable_root_expr().mutable_call_expr().set_function(kCelAttribute); + ast->mutable_root_expr() + .mutable_call_expr() + .mutable_args() + .emplace_back(NextExpr()) + .mutable_ident_expr() + .set_name("ident"); + ast->mutable_root_expr() + .mutable_call_expr() + .mutable_args() + .emplace_back(NextExpr()) + .mutable_list_expr(); + + FlatExprBuilder builder(env_, runtime_options_); + + builder.AddProgramOptimizer(CreateSelectOptimizationProgramOptimizer()); + + EXPECT_THAT(builder.CreateExpressionImpl(std::move(ast), nullptr), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST_F(SelectOptimizationUnexpectedAstTest, MalformedSelectPathNotPair) { + std::unique_ptr ast = std::make_unique(NextExpr(), SourceInfo()); + + ast->mutable_root_expr().mutable_call_expr().set_function(kCelAttribute); + ast->mutable_root_expr() + .mutable_call_expr() + .mutable_args() + .emplace_back(NextExpr()) + .mutable_ident_expr() + .set_name("ident"); + auto& select_step_list = ast->mutable_root_expr() + .mutable_call_expr() + .mutable_args() + .emplace_back(NextExpr()) + .mutable_list_expr(); + + auto& select_step_element = select_step_list.mutable_elements() + .emplace_back(NextListExprElement()) + .mutable_expr() + .mutable_list_expr(); + + select_step_element.mutable_elements() + .emplace_back(NextListExprElement()) + .mutable_expr() + .mutable_const_expr() + .set_string_value("field"); + + FlatExprBuilder builder(env_, runtime_options_); + + builder.AddProgramOptimizer(CreateSelectOptimizationProgramOptimizer()); + + EXPECT_THAT(builder.CreateExpressionImpl(std::move(ast), nullptr), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST_F(SelectOptimizationUnexpectedAstTest, MalformedSelectPathWrongPairTypes) { + std::unique_ptr ast = std::make_unique(NextExpr(), SourceInfo()); + + ast->mutable_root_expr().mutable_call_expr().set_function(kCelAttribute); + ast->mutable_root_expr() + .mutable_call_expr() + .mutable_args() + .emplace_back(NextExpr()) + .mutable_ident_expr() + .set_name("ident"); + auto& select_step_list = ast->mutable_root_expr() + .mutable_call_expr() + .mutable_args() + .emplace_back(NextExpr()) + .mutable_list_expr(); + + auto& select_step_element = select_step_list.mutable_elements() + .emplace_back(NextListExprElement()) + .mutable_expr() + .mutable_list_expr(); + + select_step_element.mutable_elements() + .emplace_back(NextListExprElement()) + .mutable_expr() + .mutable_const_expr() + .set_string_value("field"); + + select_step_element.mutable_elements() + .emplace_back(NextListExprElement()) + .mutable_expr() + .mutable_const_expr() + .set_int64_value(1); + + FlatExprBuilder builder(env_, runtime_options_); + + builder.AddProgramOptimizer(CreateSelectOptimizationProgramOptimizer()); + + EXPECT_THAT(builder.CreateExpressionImpl(std::move(ast), nullptr), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST_F(SelectOptimizationUnexpectedAstTest, + MalformedSelectPathUnsupportedConstant) { + std::unique_ptr ast = std::make_unique(NextExpr(), SourceInfo()); + + ast->mutable_root_expr().mutable_call_expr().set_function(kCelAttribute); + ast->mutable_root_expr() + .mutable_call_expr() + .mutable_args() + .emplace_back(NextExpr()) + .mutable_ident_expr() + .set_name("ident"); + auto& select_step_list = ast->mutable_root_expr() + .mutable_call_expr() + .mutable_args() + .emplace_back(NextExpr()) + .mutable_list_expr(); + + auto& select_step_element = select_step_list.mutable_elements() + .emplace_back(NextListExprElement()) + .mutable_expr(); + + select_step_element.mutable_const_expr().set_bytes_value("bytes_key"); + + FlatExprBuilder builder(env_, runtime_options_); + + builder.AddProgramOptimizer(CreateSelectOptimizationProgramOptimizer()); + + EXPECT_THAT(builder.CreateExpressionImpl(std::move(ast), nullptr), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST_F(SelectOptimizationUnexpectedAstTest, OptionalNotYetSupported) { + std::unique_ptr ast = std::make_unique(NextExpr(), SourceInfo()); + + ast->mutable_root_expr().mutable_call_expr().set_function(kCelAttribute); + auto& call_args = ast->mutable_root_expr().mutable_call_expr().mutable_args(); + call_args.emplace_back(NextExpr()).mutable_ident_expr().set_name("ident"); + + auto& list_expr = call_args.emplace_back(NextExpr()).mutable_list_expr(); + auto& fields = list_expr.mutable_elements() + .emplace_back(NextListExprElement()) + .mutable_expr() + .mutable_list_expr() + .mutable_elements(); + + fields.emplace_back(NextListExprElement()) + .mutable_expr() + .mutable_const_expr() + .set_int64_value(1); + fields.emplace_back(NextListExprElement()) + .mutable_expr() + .mutable_const_expr() + .set_string_value("field"); + + call_args.emplace_back(NextExpr()).mutable_const_expr().set_int64_value(0); + + FlatExprBuilder builder(env_, runtime_options_); + + builder.AddProgramOptimizer(CreateSelectOptimizationProgramOptimizer()); + + EXPECT_THAT(builder.CreateExpressionImpl(std::move(ast), nullptr), + StatusIs(absl::StatusCode::kUnimplemented)); +} + +} // namespace +} // namespace cel::extensions diff --git a/extensions/sets_functions.cc b/extensions/sets_functions.cc new file mode 100644 index 000000000..ebe163550 --- /dev/null +++ b/extensions/sets_functions.cc @@ -0,0 +1,171 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/sets_functions.h" + +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "base/function_adapter.h" +#include "checker/type_checker_builder.h" +#include "common/decl.h" +#include "common/type.h" +#include "common/value.h" +#include "eval/public/cel_function_registry.h" +#include "eval/public/cel_options.h" +#include "internal/status_macros.h" +#include "runtime/function_registry.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel::extensions { +using google::api::expr::runtime::CelFunctionRegistry; +using google::api::expr::runtime::ConvertToRuntimeOptions; +using google::api::expr::runtime::InterpreterOptions; + +namespace { + +absl::StatusOr SetsContains( + const ListValue& list, const ListValue& sublist, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + bool any_missing = false; + CEL_RETURN_IF_ERROR(sublist.ForEach( + [&](const Value& sublist_element) -> absl::StatusOr { + CEL_ASSIGN_OR_RETURN(auto contains, + list.Contains(sublist_element, descriptor_pool, + message_factory, arena)); + + // Treat CEL error as missing + any_missing = + !contains->Is() || !contains.GetBool().NativeValue(); + // The first false result will terminate the loop. + return !any_missing; + }, + descriptor_pool, message_factory, arena)); + return BoolValue(!any_missing); +} + +absl::StatusOr SetsIntersects( + const ListValue& list, const ListValue& sublist, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + bool exists = false; + CEL_RETURN_IF_ERROR(list.ForEach( + [&](const Value& list_element) -> absl::StatusOr { + CEL_ASSIGN_OR_RETURN(auto contains, + sublist.Contains(list_element, descriptor_pool, + message_factory, arena)); + // Treat contains return CEL error as false for the sake of + // intersecting. + exists = contains->Is() && contains.GetBool().NativeValue(); + return !exists; + }, + descriptor_pool, message_factory, arena)); + + return BoolValue(exists); +} + +absl::StatusOr SetsEquivalent( + const ListValue& list, const ListValue& sublist, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + CEL_ASSIGN_OR_RETURN( + auto contains_sublist, + SetsContains(list, sublist, descriptor_pool, message_factory, arena)); + if (contains_sublist.Is() && + !contains_sublist.GetBool().NativeValue()) { + return contains_sublist; + } + return SetsContains(sublist, list, descriptor_pool, message_factory, arena); +} + +absl::Status RegisterSetsContainsFunction(FunctionRegistry& registry) { + return registry.Register( + BinaryFunctionAdapter< + absl::StatusOr, const ListValue&, + const ListValue&>::CreateDescriptor("sets.contains", + /*receiver_style=*/false), + BinaryFunctionAdapter, const ListValue&, + const ListValue&>::WrapFunction(SetsContains)); +} + +absl::Status RegisterSetsIntersectsFunction(FunctionRegistry& registry) { + return registry.Register( + BinaryFunctionAdapter< + absl::StatusOr, const ListValue&, + const ListValue&>::CreateDescriptor("sets.intersects", + /*receiver_style=*/false), + BinaryFunctionAdapter, const ListValue&, + const ListValue&>::WrapFunction(SetsIntersects)); +} + +absl::Status RegisterSetsEquivalentFunction(FunctionRegistry& registry) { + return registry.Register( + BinaryFunctionAdapter< + absl::StatusOr, const ListValue&, + const ListValue&>::CreateDescriptor("sets.equivalent", + /*receiver_style=*/false), + BinaryFunctionAdapter, const ListValue&, + const ListValue&>::WrapFunction(SetsEquivalent)); +} + +absl::Status RegisterSetsDecls(TypeCheckerBuilder& b) { + ListType list_t(b.arena(), TypeParamType("T")); + CEL_ASSIGN_OR_RETURN( + auto decl, + MakeFunctionDecl("sets.contains", + MakeOverloadDecl("list_sets_contains_list", BoolType(), + list_t, list_t))); + CEL_RETURN_IF_ERROR(b.AddFunction(decl)); + + CEL_ASSIGN_OR_RETURN( + decl, MakeFunctionDecl("sets.equivalent", + MakeOverloadDecl("list_sets_equivalent_list", + BoolType(), list_t, list_t))); + CEL_RETURN_IF_ERROR(b.AddFunction(decl)); + + CEL_ASSIGN_OR_RETURN( + decl, MakeFunctionDecl("sets.intersects", + MakeOverloadDecl("list_sets_intersects_list", + BoolType(), list_t, list_t))); + return b.AddFunction(decl); +} + +} // namespace + +CheckerLibrary SetsCheckerLibrary() { + return {.id = "cel.lib.ext.sets", .configure = RegisterSetsDecls}; +} + +absl::Status RegisterSetsFunctions(FunctionRegistry& registry, + const RuntimeOptions& options) { + CEL_RETURN_IF_ERROR(RegisterSetsContainsFunction(registry)); + CEL_RETURN_IF_ERROR(RegisterSetsIntersectsFunction(registry)); + CEL_RETURN_IF_ERROR(RegisterSetsEquivalentFunction(registry)); + return absl::OkStatus(); +} + +absl::Status RegisterSetsFunctions(CelFunctionRegistry* registry, + const InterpreterOptions& options) { + return RegisterSetsFunctions(registry->InternalGetRegistry(), + ConvertToRuntimeOptions(options)); +} + +} // namespace cel::extensions diff --git a/extensions/sets_functions.h b/extensions/sets_functions.h new file mode 100644 index 000000000..a49e52174 --- /dev/null +++ b/extensions/sets_functions.h @@ -0,0 +1,45 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_SETS_FUNCTIONS_H_ +#define THIRD_PARTY_CEL_CPP_EXTENSIONS_SETS_FUNCTIONS_H_ + +#include "absl/status/status.h" +#include "checker/type_checker_builder.h" +#include "compiler/compiler.h" +#include "eval/public/cel_function_registry.h" +#include "eval/public/cel_options.h" +#include "runtime/function_registry.h" +#include "runtime/runtime_options.h" + +namespace cel::extensions { + +// Declarations for the sets functions. +CheckerLibrary SetsCheckerLibrary(); + +inline CompilerLibrary SetsCompilerLibrary() { + return CompilerLibrary::FromCheckerLibrary(SetsCheckerLibrary()); +} + +// Register set functions. +absl::Status RegisterSetsFunctions(FunctionRegistry& registry, + const RuntimeOptions& options); + +absl::Status RegisterSetsFunctions( + google::api::expr::runtime::CelFunctionRegistry* registry, + const google::api::expr::runtime::InterpreterOptions& options); + +} // namespace cel::extensions + +#endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_SETS_FUNCTIONS_H_ diff --git a/extensions/sets_functions_benchmark_test.cc b/extensions/sets_functions_benchmark_test.cc new file mode 100644 index 000000000..0b51f1464 --- /dev/null +++ b/extensions/sets_functions_benchmark_test.cc @@ -0,0 +1,339 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "cel/expr/syntax.pb.h" +#include "absl/base/nullability.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_replace.h" +#include "common/value.h" +#include "eval/internal/interop.h" +#include "eval/public/activation.h" +#include "eval/public/builtin_func_registrar.h" +#include "eval/public/cel_expr_builder_factory.h" +#include "eval/public/cel_expression.h" +#include "eval/public/cel_options.h" +#include "eval/public/cel_value.h" +#include "eval/public/containers/container_backed_list_impl.h" +#include "extensions/sets_functions.h" +#include "internal/benchmark.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "parser/parser.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" + +namespace cel::extensions { +namespace { + +using ::cel::Value; +using ::cel::expr::ParsedExpr; +using ::google::api::expr::parser::Parse; +using ::google::api::expr::runtime::Activation; +using ::google::api::expr::runtime::CelValue; +using ::google::api::expr::runtime::ContainerBackedListImpl; +using ::google::api::expr::runtime::CreateCelExpressionBuilder; +using ::google::api::expr::runtime::InterpreterOptions; +using ::google::api::expr::runtime::RegisterBuiltinFunctions; + +enum class ListImpl : int { kLegacy = 0, kWrappedModern = 1, kRhsConstant = 2 }; +int ToNumber(ListImpl impl) { return static_cast(impl); } +ListImpl FromNumber(int number) { + switch (number) { + case 0: + return ListImpl::kLegacy; + case 1: + return ListImpl::kWrappedModern; + case 2: + return ListImpl::kRhsConstant; + default: + return ListImpl::kLegacy; + } +} + +struct TestCase { + std::string test_name; + std::string expr; + ListImpl list_impl; + int size; + CelValue result; + + std::string MakeLabel(int len) const { + std::string list_impl; + switch (this->list_impl) { + case ListImpl::kRhsConstant: + list_impl = "rhs_constant"; + break; + case ListImpl::kWrappedModern: + list_impl = "wrapped_modern"; + break; + case ListImpl::kLegacy: + list_impl = "legacy"; + break; + } + + return absl::StrCat(test_name, "/", list_impl, "/", len); + } +}; + +class ListStorage { + public: + virtual ~ListStorage() = default; +}; + +class LegacyListStorage : public ListStorage { + public: + LegacyListStorage(ContainerBackedListImpl x, ContainerBackedListImpl y) + : x_(std::move(x)), y_(std::move(y)) {} + + CelValue x() { return CelValue::CreateList(&x_); } + CelValue y() { return CelValue::CreateList(&y_); } + + private: + ContainerBackedListImpl x_; + ContainerBackedListImpl y_; +}; + +class ModernListStorage : public ListStorage { + public: + ModernListStorage(Value x, Value y) : x_(std::move(x)), y_(std::move(y)) {} + + CelValue x() { + return interop_internal::ModernValueToLegacyValueOrDie(&arena_, x_); + } + CelValue y() { + return interop_internal::ModernValueToLegacyValueOrDie(&arena_, y_); + } + + private: + google::protobuf::Arena arena_; + Value x_; + Value y_; +}; + +absl::StatusOr> RegisterLegacyLists( + bool overlap, int len, Activation& activation) { + std::vector x; + std::vector y; + x.reserve(len + 1); + y.reserve(len + 1); + if (overlap) { + x.push_back(CelValue::CreateInt64(2)); + y.push_back(CelValue::CreateInt64(1)); + } + + for (int i = 0; i < len; i++) { + x.push_back(CelValue::CreateInt64(1)); + y.push_back(CelValue::CreateInt64(2)); + } + + auto result = std::make_unique( + ContainerBackedListImpl(std::move(x)), + ContainerBackedListImpl(std::move(y))); + + activation.InsertValue("x", result->x()); + activation.InsertValue("y", result->y()); + return result; +} + +// Constant list literal that has the same elements as the bound test cases. +std::string ConstantList(bool overlap, int len) { + std::string list_body; + for (int i = 0; i < len; i++) { + } + return absl::StrCat("[", overlap ? "1, " : "", + absl::StrJoin(std::vector(len, "2"), ", "), + "]"); +} + +absl::StatusOr> RegisterModernLists( + bool overlap, int len, google::protobuf::Arena* absl_nonnull arena, + Activation& activation) { + auto x_builder = cel::NewListValueBuilder(arena); + auto y_builder = cel::NewListValueBuilder(arena); + + x_builder->Reserve(len + 1); + y_builder->Reserve(len + 1); + + if (overlap) { + CEL_RETURN_IF_ERROR(x_builder->Add(cel::IntValue(2))); + CEL_RETURN_IF_ERROR(y_builder->Add(cel::IntValue(1))); + } + + for (int i = 0; i < len; i++) { + CEL_RETURN_IF_ERROR(x_builder->Add(cel::IntValue(1))); + CEL_RETURN_IF_ERROR(y_builder->Add(cel::IntValue(2))); + } + + auto x = std::move(*x_builder).Build(); + auto y = std::move(*y_builder).Build(); + auto result = std::make_unique(std::move(x), std::move(y)); + activation.InsertValue("x", result->x()); + activation.InsertValue("y", result->y()); + + return result; +} + +absl::StatusOr> RegisterLists( + bool overlap, int len, bool use_modern, google::protobuf::Arena* absl_nonnull arena, + Activation& activation) { + if (use_modern) { + return RegisterModernLists(overlap, len, arena, activation); + } else { + return RegisterLegacyLists(overlap, len, activation); + } +} + +void RunBenchmark(const TestCase& test_case, benchmark::State& state) { + bool lists_overlap = test_case.result.BoolOrDie(); + + std::string expr = test_case.expr; + if (test_case.list_impl == ListImpl::kRhsConstant) { + expr = absl::StrReplaceAll( + expr, {{"y", ConstantList(lists_overlap, test_case.size)}}); + } + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse(expr)); + + google::protobuf::Arena arena; + + InterpreterOptions options; + options.constant_folding = true; + options.constant_arena = &arena; + options.enable_qualified_identifier_rewrites = true; + auto builder = CreateCelExpressionBuilder(options); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); + ASSERT_OK(RegisterSetsFunctions(builder->GetRegistry()->InternalGetRegistry(), + cel::RuntimeOptions{})); + ASSERT_OK_AND_ASSIGN( + auto cel_expr, builder->CreateExpression(&(parsed_expr.expr()), nullptr)); + + Activation activation; + ASSERT_OK_AND_ASSIGN( + auto storage, + RegisterLists(test_case.result.BoolOrDie(), test_case.size, + test_case.list_impl == ListImpl::kWrappedModern, &arena, + activation)); + + state.SetLabel(test_case.MakeLabel(test_case.size)); + for (auto _ : state) { + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(result.IsBool()); + ASSERT_EQ(result.BoolOrDie(), test_case.result.BoolOrDie()) + << test_case.test_name; + } +} + +void BM_SetsIntersectsTrue(benchmark::State& state) { + ListImpl impl = FromNumber(state.range(0)); + int size = state.range(1); + + RunBenchmark({"sets.intersects_true", "sets.intersects(x, y)", impl, size, + CelValue::CreateBool(true)}, + state); +} + +void BM_SetsIntersectsFalse(benchmark::State& state) { + ListImpl impl = FromNumber(state.range(0)); + int size = state.range(1); + + RunBenchmark({"sets.intersects_false", "sets.intersects(x, y)", impl, size, + CelValue::CreateBool(false)}, + state); +} + +void BM_SetsIntersectsComprehensionTrue(benchmark::State& state) { + ListImpl impl = FromNumber(state.range(0)); + int size = state.range(1); + + RunBenchmark({"comprehension_intersects_true", "x.exists(i, i in y)", impl, + size, CelValue::CreateBool(true)}, + state); +} + +void BM_SetsIntersectsComprehensionFalse(benchmark::State& state) { + ListImpl impl = FromNumber(state.range(0)); + int size = state.range(1); + + RunBenchmark({"comprehension_intersects_false", "x.exists(i, i in y)", impl, + size, CelValue::CreateBool(false)}, + state); +} + +void BM_SetsEquivalentTrue(benchmark::State& state) { + ListImpl impl = FromNumber(state.range(0)); + int size = state.range(1); + + RunBenchmark({"sets.equivalent_true", "sets.equivalent(x, y)", impl, size, + CelValue::CreateBool(true)}, + state); +} + +void BM_SetsEquivalentFalse(benchmark::State& state) { + ListImpl impl = FromNumber(state.range(0)); + int size = state.range(1); + + RunBenchmark({"sets.equivalent_false", "sets.equivalent(x, y)", impl, size, + CelValue::CreateBool(false)}, + state); +} + +void BM_SetsEquivalentComprehensionTrue(benchmark::State& state) { + ListImpl impl = FromNumber(state.range(0)); + int size = state.range(1); + + RunBenchmark( + {"comprehension_equivalent_true", "x.all(i, i in y) && y.all(j, j in x)", + impl, size, CelValue::CreateBool(true)}, + state); +} + +void BM_SetsEquivalentComprehensionFalse(benchmark::State& state) { + ListImpl impl = FromNumber(state.range(0)); + int size = state.range(1); + + RunBenchmark( + {"comprehension_equivalent_false", "x.all(i, i in y) && y.all(j, j in x)", + impl, size, CelValue::CreateBool(false)}, + state); +} + +template +void BenchArgs(Benchmark* bench) { + for (ListImpl impl : + {ListImpl::kLegacy, ListImpl::kWrappedModern, ListImpl::kRhsConstant}) { + for (int size : {1, 8, 32, 64, 256}) { + bench->ArgPair(ToNumber(impl), size); + } + } +} + +BENCHMARK(BM_SetsIntersectsComprehensionTrue)->Apply(BenchArgs); +BENCHMARK(BM_SetsIntersectsComprehensionFalse)->Apply(BenchArgs); +BENCHMARK(BM_SetsIntersectsTrue)->Apply(BenchArgs); +BENCHMARK(BM_SetsIntersectsFalse)->Apply(BenchArgs); + +BENCHMARK(BM_SetsEquivalentComprehensionTrue)->Apply(BenchArgs); +BENCHMARK(BM_SetsEquivalentComprehensionFalse)->Apply(BenchArgs); +BENCHMARK(BM_SetsEquivalentTrue)->Apply(BenchArgs); +BENCHMARK(BM_SetsEquivalentFalse)->Apply(BenchArgs); + +} // namespace +} // namespace cel::extensions diff --git a/extensions/sets_functions_test.cc b/extensions/sets_functions_test.cc new file mode 100644 index 000000000..dc6768f34 --- /dev/null +++ b/extensions/sets_functions_test.cc @@ -0,0 +1,172 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/sets_functions.h" + +#include +#include + +#include "cel/expr/syntax.pb.h" +#include "absl/status/status_matchers.h" +#include "checker/standard_library.h" +#include "checker/validation_result.h" +#include "common/ast_proto.h" +#include "common/minimal_descriptor_pool.h" +#include "compiler/compiler_factory.h" +#include "eval/public/activation.h" +#include "eval/public/builtin_func_registrar.h" +#include "eval/public/cel_expr_builder_factory.h" +#include "eval/public/cel_expression.h" +#include "eval/public/cel_function_adapter.h" +#include "eval/public/cel_options.h" +#include "eval/public/cel_value.h" +#include "internal/testing.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" + +namespace cel::extensions { +namespace { + +using ::google::api::expr::runtime::Activation; +using ::google::api::expr::runtime::CelExpressionBuilder; +using ::google::api::expr::runtime::CelValue; +using ::google::api::expr::runtime::CreateCelExpressionBuilder; +using ::google::api::expr::runtime::FunctionAdapter; +using ::google::api::expr::runtime::InterpreterOptions; + +using ::absl_testing::IsOk; +using ::google::protobuf::Arena; + +struct TestInfo { + std::string expr; +}; + +class CelSetsFunctionsTest : public testing::TestWithParam {}; + +TEST_P(CelSetsFunctionsTest, EndToEnd) { + const TestInfo& test_info = GetParam(); + ASSERT_OK_AND_ASSIGN(auto compiler_builder, + NewCompilerBuilder(cel::GetMinimalDescriptorPool())); + + ASSERT_THAT(compiler_builder->AddLibrary(StandardCheckerLibrary()), IsOk()); + ASSERT_THAT(compiler_builder->AddLibrary(SetsCompilerLibrary()), IsOk()); + ASSERT_OK_AND_ASSIGN(auto compiler, compiler_builder->Build()); + + ASSERT_OK_AND_ASSIGN(ValidationResult compiled, + compiler->Compile(test_info.expr)); + + ASSERT_TRUE(compiled.IsValid()) << compiled.FormatError(); + + cel::expr::CheckedExpr checked_expr; + ASSERT_THAT(AstToCheckedExpr(*compiled.GetAst(), &checked_expr), IsOk()); + + // Obtain CEL Expression builder. + InterpreterOptions options; + options.enable_heterogeneous_equality = true; + options.enable_empty_wrapper_null_unboxing = true; + options.enable_qualified_identifier_rewrites = true; + std::unique_ptr builder = + CreateCelExpressionBuilder(options); + ASSERT_THAT(RegisterSetsFunctions(builder->GetRegistry(), options), IsOk()); + ASSERT_THAT(google::api::expr::runtime::RegisterBuiltinFunctions( + builder->GetRegistry(), options), + IsOk()); + + // Create CelExpression from AST (Expr object). + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression(&checked_expr)); + Arena arena; + Activation activation; + // Run evaluation. + ASSERT_OK_AND_ASSIGN(CelValue out, cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(out.IsBool()) << test_info.expr << " -> " << out.DebugString(); + EXPECT_TRUE(out.BoolOrDie()) << test_info.expr << " -> " << out.DebugString(); +} + +INSTANTIATE_TEST_SUITE_P( + CelSetsFunctionsTest, CelSetsFunctionsTest, + testing::ValuesIn({ + {"sets.contains([], [])"}, + {"sets.contains([1], [])"}, + {"sets.contains([1], [1])"}, + {"sets.contains([1], [1, 1])"}, + {"sets.contains([1, 1], [1])"}, + {"sets.contains([2, 1], [1])"}, + {"sets.contains([1], [1.0, 1u])"}, + {"sets.contains([1, 2], [2u, 2.0])"}, + {"sets.contains([1, 2u], [2, 2.0])"}, + {"!sets.contains([1], [2])"}, + {"!sets.contains([1], [1, 2])"}, + {"!sets.contains([1], [\"1\", 1])"}, + {"!sets.contains([1], [1.1, 2])"}, + {"sets.intersects([1], [1])"}, + {"sets.intersects([1], [1, 1])"}, + {"sets.intersects([1, 1], [1])"}, + {"sets.intersects([2, 1], [1])"}, + {"sets.intersects([1], [1, 2])"}, + {"sets.intersects([1], [1.0, 2])"}, + {"sets.intersects([1, 2], [2u, 2, 2.0])"}, + {"sets.intersects([1, 2], [1u, 2, 2.3])"}, + {"!sets.intersects([], [])"}, + {"!sets.intersects([1], [])"}, + {"!sets.intersects([1], [2])"}, + {"!sets.intersects([1], [\"1\", 2])"}, + {"!sets.intersects([1], [1.1, 2u])"}, + {"sets.equivalent([], [])"}, + {"sets.equivalent([1], [1])"}, + {"sets.equivalent([1], [1, 1])"}, + {"sets.equivalent([1, 1, 2], [2, 2, 1])"}, + {"sets.equivalent([1, 1], [1])"}, + {"sets.equivalent([1], [1u, 1.0])"}, + {"sets.equivalent([1], [1u, 1.0])"}, + {"sets.equivalent([1, 2, 3], [3u, 2.0, 1])"}, + {"!sets.equivalent([2, 1], [1])"}, + {"!sets.equivalent([1], [1, 2])"}, + {"!sets.equivalent([1, 2], [2u, 2, 2.0])"}, + {"!sets.equivalent([1, 2], [1u, 2, 2.3])"}, + + {"sets.equivalent([false, true], [true, false])"}, + {"!sets.equivalent([true], [false])"}, + + {"sets.equivalent(['foo', 'bar'], ['bar', 'foo'])"}, + {"!sets.equivalent(['foo'], ['bar'])"}, + + {"sets.equivalent([b'foo', b'bar'], [b'bar', b'foo'])"}, + {"!sets.equivalent([b'foo'], [b'bar'])"}, + + {"sets.equivalent([null], [null])"}, + {"!sets.equivalent([null], [])"}, + + {"sets.equivalent([type(1), type(1u)], [type(1u), type(1)])"}, + {"!sets.equivalent([type(1)], [type(1u)])"}, + + {"sets.equivalent([duration('0s'), duration('1s')], [duration('1s'), " + "duration('0s')])"}, + {"!sets.equivalent([duration('0s')], [duration('1s')])"}, + + {"sets.equivalent([timestamp('1970-01-01T00:00:00Z'), " + "timestamp('1970-01-01T00:00:01Z')], " + "[timestamp('1970-01-01T00:00:01Z'), " + "timestamp('1970-01-01T00:00:00Z')])"}, + {"!sets.equivalent([timestamp('1970-01-01T00:00:00Z')], " + "[timestamp('1970-01-01T00:00:01Z')])"}, + + {"sets.equivalent([[false, true]], [[false, true]])"}, + {"!sets.equivalent([[false, true]], [[true, false]])"}, + + {"sets.equivalent([{'foo': true, 'bar': false}], [{'bar': false, " + "'foo': true}])"}, + })); + +} // namespace +} // namespace cel::extensions diff --git a/extensions/strings.cc b/extensions/strings.cc new file mode 100644 index 000000000..54fda20d6 --- /dev/null +++ b/extensions/strings.cc @@ -0,0 +1,432 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/strings.h" + +#include +#include +#include + +#include "absl/base/no_destructor.h" +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "checker/internal/builtins_arena.h" +#include "checker/type_checker_builder.h" +#include "common/decl.h" +#include "common/type.h" +#include "common/value.h" +#include "eval/public/cel_function_registry.h" +#include "eval/public/cel_options.h" +#include "extensions/formatting.h" +#include "internal/status_macros.h" +#include "runtime/function_adapter.h" +#include "runtime/function_registry.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel::extensions { + +namespace { + +using ::cel::checker_internal::BuiltinsArena; + +struct AppendToStringVisitor { + std::string& append_to; + + void operator()(absl::string_view string) const { append_to.append(string); } + + void operator()(const absl::Cord& cord) const { + append_to.append(static_cast(cord)); + } +}; + +absl::StatusOr Join2( + const ListValue& value, const StringValue& separator, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + return separator.Join(value, descriptor_pool, message_factory, arena); +} + +absl::StatusOr Join1( + const ListValue& value, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + return StringValue().Join(value, descriptor_pool, message_factory, arena); +} + +absl::StatusOr Split3( + const StringValue& string, const StringValue& delimiter, int64_t limit, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + return string.Split(delimiter, limit, arena); +} + +absl::StatusOr Split2( + const StringValue& string, const StringValue& delimiter, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + return string.Split(delimiter, arena); +} + +absl::StatusOr Replace2(const StringValue& string, + const StringValue& old_sub, + const StringValue& new_sub, int64_t limit, + const google::protobuf::DescriptorPool* absl_nonnull, + google::protobuf::MessageFactory* absl_nonnull, + google::protobuf::Arena* absl_nonnull arena) { + return string.Replace(old_sub, new_sub, limit, arena); +} + +absl::StatusOr Replace1( + const StringValue& string, const StringValue& old_sub, + const StringValue& new_sub, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + return string.Replace(old_sub, new_sub, -1, arena); +} + +Value CharAt(const StringValue& string, int64_t pos) { + return string.CharAt(pos); +} + +int64_t IndexOf2(const StringValue& haystack, const StringValue& needle) { + return haystack.IndexOf(needle).value_or(-1); +} + +Value IndexOf3(const StringValue& haystack, const StringValue& needle, + int64_t pos) { + if (pos > haystack.Size()) { + return ErrorValue{ + absl::InvalidArgumentError(absl::StrCat("index out of range: ", pos))}; + } + return IntValue(haystack.IndexOf(needle, pos).value_or(-1)); +} + +int64_t LastIndexOf2(const StringValue& haystack, const StringValue& needle) { + return haystack.LastIndexOf(needle).value_or(-1); +} + +Value LastIndexOf3(const StringValue& haystack, const StringValue& needle, + int64_t pos) { + if (pos < 0 || pos > haystack.Size()) { + return ErrorValue{ + absl::InvalidArgumentError(absl::StrCat("index out of range: ", pos))}; + } + return IntValue(haystack.LastIndexOf(needle, pos).value_or(-1)); +} + +Value Substring2(const StringValue& string, int64_t start) { + return string.Substring(start); +} + +Value Substring3(const StringValue& string, int64_t start, int64_t end) { + return string.Substring(start, end); +} + +StringValue Trim(const StringValue& string) { return string.Trim(); } + +StringValue LowerAscii(const StringValue& string, + const google::protobuf::DescriptorPool* absl_nonnull, + google::protobuf::MessageFactory* absl_nonnull, + google::protobuf::Arena* absl_nonnull arena) { + return string.LowerAscii(arena); +} + +StringValue UpperAscii(const StringValue& string, + const google::protobuf::DescriptorPool* absl_nonnull, + google::protobuf::MessageFactory* absl_nonnull, + google::protobuf::Arena* absl_nonnull arena) { + return string.UpperAscii(arena); +} + +StringValue Quote(const StringValue& string, + const google::protobuf::DescriptorPool* absl_nonnull, + google::protobuf::MessageFactory* absl_nonnull, + google::protobuf::Arena* absl_nonnull arena) { + return string.Quote(arena); +} + +StringValue Reverse(const StringValue& string, + const google::protobuf::DescriptorPool* absl_nonnull, + google::protobuf::MessageFactory* absl_nonnull, + google::protobuf::Arena* absl_nonnull arena) { + return string.Reverse(arena); +} + +const Type& ListStringType() { + static absl::NoDestructor kInstance( + ListType(BuiltinsArena(), StringType())); + return *kInstance; +} + +absl::Status RegisterStringsDecls(TypeCheckerBuilder& builder, int version) { + // Runtime Supported functions. + CEL_ASSIGN_OR_RETURN( + auto join_decl, + MakeFunctionDecl( + "join", + MakeMemberOverloadDecl("list_join", StringType(), ListStringType()), + MakeMemberOverloadDecl("list_join_string", StringType(), + ListStringType(), StringType()))); + CEL_ASSIGN_OR_RETURN( + auto split_decl, + MakeFunctionDecl( + "split", + MakeMemberOverloadDecl("string_split_string", ListStringType(), + StringType(), StringType()), + MakeMemberOverloadDecl("string_split_string_int", ListStringType(), + StringType(), StringType(), IntType()))); + CEL_ASSIGN_OR_RETURN( + auto lower_decl, + MakeFunctionDecl("lowerAscii", + MakeMemberOverloadDecl("string_lower_ascii", + StringType(), StringType()))); + + CEL_ASSIGN_OR_RETURN( + auto replace_decl, + MakeFunctionDecl( + "replace", + MakeMemberOverloadDecl("string_replace_string_string", StringType(), + StringType(), StringType(), StringType()), + MakeMemberOverloadDecl("string_replace_string_string_int", + StringType(), StringType(), StringType(), + StringType(), IntType()))); + + // Additional functions described in the spec. + CEL_ASSIGN_OR_RETURN( + auto char_at_decl, + MakeFunctionDecl( + "charAt", MakeMemberOverloadDecl("string_char_at_int", StringType(), + StringType(), IntType()))); + CEL_ASSIGN_OR_RETURN( + auto index_of_decl, + MakeFunctionDecl( + "indexOf", + MakeMemberOverloadDecl("string_index_of_string", IntType(), + StringType(), StringType()), + MakeMemberOverloadDecl("string_index_of_string_int", IntType(), + StringType(), StringType(), IntType()))); + CEL_ASSIGN_OR_RETURN( + auto last_index_of_decl, + MakeFunctionDecl( + "lastIndexOf", + MakeMemberOverloadDecl("string_last_index_of_string", IntType(), + StringType(), StringType()), + MakeMemberOverloadDecl("string_last_index_of_string_int", IntType(), + StringType(), StringType(), IntType()))); + + CEL_ASSIGN_OR_RETURN( + auto substring_decl, + MakeFunctionDecl( + "substring", + MakeMemberOverloadDecl("string_substring_int", StringType(), + StringType(), IntType()), + MakeMemberOverloadDecl("string_substring_int_int", StringType(), + StringType(), IntType(), IntType()))); + CEL_ASSIGN_OR_RETURN( + auto upper_ascii_decl, + MakeFunctionDecl("upperAscii", + MakeMemberOverloadDecl("string_upper_ascii", + StringType(), StringType()))); + CEL_ASSIGN_OR_RETURN( + auto format_decl, + MakeFunctionDecl("format", + MakeMemberOverloadDecl("string_format", StringType(), + StringType(), ListType()))); + CEL_ASSIGN_OR_RETURN( + auto quote_decl, + MakeFunctionDecl( + "strings.quote", + MakeOverloadDecl("strings_quote", StringType(), StringType()))); + + CEL_ASSIGN_OR_RETURN( + auto reverse_decl, + MakeFunctionDecl("reverse", + MakeMemberOverloadDecl("string_reverse", StringType(), + StringType()))); + + CEL_ASSIGN_OR_RETURN( + auto trim_decl, + MakeFunctionDecl("trim", MakeMemberOverloadDecl( + "string_trim", StringType(), StringType()))); + + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(split_decl))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(lower_decl))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(replace_decl))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(char_at_decl))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(index_of_decl))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(last_index_of_decl))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(substring_decl))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(upper_ascii_decl))); + CEL_RETURN_IF_ERROR(builder.MergeFunction(std::move(trim_decl))); + if (version == 0) { + return absl::OkStatus(); + } + + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(format_decl))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(quote_decl))); + if (version == 1) { + return absl::OkStatus(); + } + + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(join_decl))); + if (version == 2) { + return absl::OkStatus(); + } + + // MergeFunction is used to combine with the reverse function + // defined in cel.lib.ext.lists extension. + CEL_RETURN_IF_ERROR(builder.MergeFunction(std::move(reverse_decl))); + + return absl::OkStatus(); +} + +} // namespace + +absl::Status RegisterStringsFunctions( + FunctionRegistry& registry, const RuntimeOptions& options, + const StringsExtensionOptions& extension_options) { + const int version = extension_options.version; + CEL_RETURN_IF_ERROR(registry.Register( + BinaryFunctionAdapter, StringValue, StringValue>:: + CreateDescriptor("split", /*receiver_style=*/true), + BinaryFunctionAdapter, StringValue, + StringValue>::WrapFunction(Split2))); + CEL_RETURN_IF_ERROR(registry.Register( + TernaryFunctionAdapter< + absl::StatusOr, StringValue, StringValue, + int64_t>::CreateDescriptor("split", /*receiver_style=*/true), + TernaryFunctionAdapter, StringValue, StringValue, + int64_t>::WrapFunction(Split3))); + CEL_RETURN_IF_ERROR(registry.Register( + UnaryFunctionAdapter, StringValue>:: + CreateDescriptor("lowerAscii", /*receiver_style=*/true), + UnaryFunctionAdapter, StringValue>::WrapFunction( + LowerAscii))); + CEL_RETURN_IF_ERROR(registry.Register( + UnaryFunctionAdapter, StringValue>:: + CreateDescriptor("upperAscii", /*receiver_style=*/true), + UnaryFunctionAdapter, StringValue>::WrapFunction( + UpperAscii))); + CEL_RETURN_IF_ERROR(registry.Register( + TernaryFunctionAdapter< + absl::StatusOr, StringValue, StringValue, + StringValue>::CreateDescriptor("replace", /*receiver_style=*/true), + TernaryFunctionAdapter, StringValue, StringValue, + StringValue>::WrapFunction(Replace1))); + CEL_RETURN_IF_ERROR(registry.Register( + QuaternaryFunctionAdapter< + absl::StatusOr, StringValue, StringValue, StringValue, + int64_t>::CreateDescriptor("replace", /*receiver_style=*/true), + QuaternaryFunctionAdapter, StringValue, StringValue, + StringValue, int64_t>::WrapFunction(Replace2))); + CEL_RETURN_IF_ERROR( + (BinaryFunctionAdapter::RegisterMemberOverload("charAt", &CharAt, + registry))); + CEL_RETURN_IF_ERROR( + (BinaryFunctionAdapter::RegisterMemberOverload("indexOf", + &IndexOf2, + registry))); + CEL_RETURN_IF_ERROR( + (TernaryFunctionAdapter::RegisterMemberOverload("indexOf", + &IndexOf3, + registry))); + CEL_RETURN_IF_ERROR( + (BinaryFunctionAdapter::RegisterMemberOverload("lastIndexOf", + &LastIndexOf2, + registry))); + CEL_RETURN_IF_ERROR( + (TernaryFunctionAdapter::RegisterMemberOverload("lastIndexOf", + &LastIndexOf3, + registry))); + CEL_RETURN_IF_ERROR( + (BinaryFunctionAdapter::RegisterMemberOverload("substring", + &Substring2, + registry))); + CEL_RETURN_IF_ERROR( + (TernaryFunctionAdapter::RegisterMemberOverload("substring", + &Substring3, + registry))); + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterMemberOverload( + "trim", &Trim, registry))); + if (version == 0) { + return absl::OkStatus(); + } + + CEL_RETURN_IF_ERROR(RegisterStringFormattingFunctions( + registry, options, {extension_options.max_precision})); + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + "strings.quote", &Quote, registry))); + if (version == 1) { + return absl::OkStatus(); + } + + CEL_RETURN_IF_ERROR(registry.Register( + UnaryFunctionAdapter, ListValue>::CreateDescriptor( + "join", /*receiver_style=*/true), + UnaryFunctionAdapter, ListValue>::WrapFunction( + Join1))); + CEL_RETURN_IF_ERROR(registry.Register( + BinaryFunctionAdapter, ListValue, StringValue>:: + CreateDescriptor("join", /*receiver_style=*/true), + BinaryFunctionAdapter, ListValue, + StringValue>::WrapFunction(Join2))); + if (version == 2) { + return absl::OkStatus(); + } + + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterMemberOverload( + "reverse", &Reverse, registry))); + return absl::OkStatus(); +} + +absl::Status RegisterStringsFunctions( + google::api::expr::runtime::CelFunctionRegistry* registry, + const google::api::expr::runtime::InterpreterOptions& options, + const StringsExtensionOptions& extension_options) { + return RegisterStringsFunctions( + registry->InternalGetRegistry(), + google::api::expr::runtime::ConvertToRuntimeOptions(options), + extension_options); +} + +CheckerLibrary StringsCheckerLibrary(const StringsExtensionOptions& options) { + const int version = options.version; + return {"strings", [version](TypeCheckerBuilder& builder) { + return RegisterStringsDecls(builder, version); + }}; +} + +} // namespace cel::extensions diff --git a/extensions/strings.h b/extensions/strings.h new file mode 100644 index 000000000..3ec92d603 --- /dev/null +++ b/extensions/strings.h @@ -0,0 +1,73 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_STRINGS_H_ +#define THIRD_PARTY_CEL_CPP_EXTENSIONS_STRINGS_H_ + +#include "absl/status/status.h" +#include "checker/type_checker_builder.h" +#include "compiler/compiler.h" +#include "eval/public/cel_function_registry.h" +#include "eval/public/cel_options.h" +#include "runtime/function_registry.h" +#include "runtime/runtime_options.h" + +namespace cel::extensions { + +constexpr int kStringsExtensionLatestVersion = 4; + +struct StringsExtensionOptions { + int version = kStringsExtensionLatestVersion; + + // Maximum precision allowed for floating point format specifiers in + // format() function. This is used for both fixed and scientific notations. + // Value must be in the range [0, 1000], otherwise clamped. + // + // Does not affect default precisions for %e and %f format specifiers. + int max_precision = 1000; +}; + +// Register extension functions for strings. +absl::Status RegisterStringsFunctions( + FunctionRegistry& registry, const RuntimeOptions& options, + const StringsExtensionOptions& extension_options = {}); + +absl::Status RegisterStringsFunctions( + google::api::expr::runtime::CelFunctionRegistry* registry, + const google::api::expr::runtime::InterpreterOptions& options, + const StringsExtensionOptions& extension_options = {}); + +CheckerLibrary StringsCheckerLibrary( + const StringsExtensionOptions& extension_options = {}); + +inline CheckerLibrary StringsCheckerLibrary(int version) { + StringsExtensionOptions options; + options.version = version; + return StringsCheckerLibrary(options); +} + +inline CompilerLibrary StringsCompilerLibrary( + const StringsExtensionOptions& options = {}) { + return CompilerLibrary::FromCheckerLibrary(StringsCheckerLibrary(options)); +} + +inline CompilerLibrary StringsCompilerLibrary(int version) { + StringsExtensionOptions options; + options.version = version; + return StringsCompilerLibrary(options); +} + +} // namespace cel::extensions + +#endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_STRINGS_H_ diff --git a/extensions/strings_test.cc b/extensions/strings_test.cc new file mode 100644 index 000000000..c3059808f --- /dev/null +++ b/extensions/strings_test.cc @@ -0,0 +1,473 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/strings.h" + +#include +#include +#include +#include + +#include "cel/expr/syntax.pb.h" +#include "absl/algorithm/container.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "checker/standard_library.h" +#include "checker/type_check_issue.h" +#include "checker/type_checker_builder.h" +#include "checker/validation_result.h" +#include "common/ast.h" +#include "common/decl.h" +#include "common/type.h" +#include "common/value.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "compiler/standard_library.h" +#include "extensions/protobuf/runtime_adapter.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "parser/options.h" +#include "parser/parser.h" +#include "runtime/activation.h" +#include "runtime/runtime.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_options.h" +#include "runtime/standard_runtime_builder_factory.h" +#include "testutil/baseline_tests.h" +#include "google/protobuf/arena.h" + +namespace cel::extensions { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::cel::expr::ParsedExpr; +using ::google::api::expr::parser::Parse; +using ::google::api::expr::parser::ParserOptions; +using ::testing::HasSubstr; +using ::testing::IsEmpty; +using ::testing::Values; +using ::testing::ValuesIn; + +TEST(StringsCheckerLibrary, SmokeTest) { + ASSERT_OK_AND_ASSIGN( + auto builder, NewCompilerBuilder(internal::GetTestingDescriptorPool())); + ASSERT_THAT(builder->AddLibrary(StringsCheckerLibrary()), IsOk()); + ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); + ASSERT_THAT(builder->GetCheckerBuilder().AddVariable( + MakeVariableDecl("foo", StringType())), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto compiler, std::move(*builder).Build()); + + ASSERT_OK_AND_ASSIGN( + ValidationResult result, + compiler->Compile("foo.replace('he', 'we', 1) == 'wello hello'")); + ASSERT_TRUE(result.IsValid()); + + EXPECT_EQ(test::FormatBaselineAst(*result.GetAst()), + R"(_==_( + foo~string^foo.replace( + "he"~string, + "we"~string, + 1~int + )~string^string_replace_string_string_int, + "wello hello"~string +)~bool^equals)"); +} + +TEST(StringsExtTest, MaxPrecisionOption) { + StringsExtensionOptions extension_options; + extension_options.max_precision = 99; + + ASSERT_OK_AND_ASSIGN( + auto compiler_builder, + NewCompilerBuilder(internal::GetTestingDescriptorPool())); + + ASSERT_THAT(compiler_builder->AddLibrary(StandardCompilerLibrary()), IsOk()); + ASSERT_THAT(compiler_builder->AddLibrary(StringsCompilerLibrary()), IsOk()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, + compiler_builder->Build()); + + ASSERT_OK_AND_ASSIGN( + ValidationResult result, + compiler->Compile("'abc %.100f'.format([2.0])", "")); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, result.ReleaseAst()); + + RuntimeOptions opts; + ASSERT_OK_AND_ASSIGN( + auto runtime_builder, + CreateStandardRuntimeBuilder(internal::GetTestingDescriptorPool(), opts)); + + ASSERT_THAT(RegisterStringsFunctions(runtime_builder.function_registry(), + opts, extension_options), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(runtime_builder).Build()); + ASSERT_OK_AND_ASSIGN(auto program, runtime->CreateProgram(std::move(ast))); + + google::protobuf::Arena arena; + cel::Activation activation; + ASSERT_OK_AND_ASSIGN(auto value, program->Evaluate(&arena, activation)); + + ASSERT_TRUE(value.Is()); + EXPECT_THAT(value.GetError().ToStatus(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("precision specifier exceeds maximum of 99"))); +} + +using StringsExtFunctionsTest = testing::TestWithParam; + +TEST_P(StringsExtFunctionsTest, ParserAndCheckerTests) { + const std::string& expr = GetParam(); + + ASSERT_OK_AND_ASSIGN( + auto compiler_builder, + NewCompilerBuilder(internal::GetTestingDescriptorPool())); + + ASSERT_THAT(compiler_builder->AddLibrary(StandardCompilerLibrary()), IsOk()); + ASSERT_THAT(compiler_builder->AddLibrary(StringsCompilerLibrary()), IsOk()); + + ASSERT_OK_AND_ASSIGN(auto compiler, std::move(*compiler_builder).Build()); + + auto result = compiler->Compile(expr, ""); + + ASSERT_THAT(result, IsOk()); + ASSERT_TRUE(result->IsValid()); + + RuntimeOptions opts; + ASSERT_OK_AND_ASSIGN( + auto runtime_builder, + CreateStandardRuntimeBuilder(internal::GetTestingDescriptorPool(), opts)); + + ASSERT_THAT( + RegisterStringsFunctions(runtime_builder.function_registry(), opts), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(runtime_builder).Build()); + + ASSERT_OK_AND_ASSIGN(auto program, + runtime->CreateProgram(*result->ReleaseAst())); + + google::protobuf::Arena arena; + cel::Activation activation; + ASSERT_OK_AND_ASSIGN(auto value, program->Evaluate(&arena, activation)); + + ASSERT_TRUE(value.Is()); + EXPECT_TRUE(value.GetBool().NativeValue()); +} + +INSTANTIATE_TEST_SUITE_P( + StringsExtMacrosParamsTest, StringsExtFunctionsTest, + testing::Values( + // Tests for split() + "'hello world!'.split('') == ['h', 'e', 'l', 'l', 'o', ' ', " + "'w', 'o', 'r', 'l', 'd', '!']", + // Tests for replace() + "'hello hello'.replace('he', 'we') == 'wello wello'", + "'hello hello'.replace('he', 'we', -1) == 'wello wello'", + "'hello hello'.replace('he', 'we', 1) == 'wello hello'", + "'hello hello'.replace('he', 'we', 0) == 'hello hello'", + // Tests for lowerAscii() + "'UPPER lower'.lowerAscii() == 'upper lower'", + // Tests for upperAscii() + "'UPPER lower'.upperAscii() == 'UPPER LOWER'", + // Tests for format() + "'abc %.3f'.format([2.0]) == 'abc 2.000'", + + // Tests for charAt() + "'tacocat'.charAt(3) == 'o'", "'tacocat'.charAt(7) == ''", + "'©αT'.charAt(0) == '©' && '©αT'.charAt(1) == 'α' && '©αT'.charAt(2) " + "== 'T'", + + // Tests for indexOf() + "'tacocat'.indexOf('') == 0", "'tacocat'.indexOf('ac') == 1", + "'tacocat'.indexOf('none') == -1", "'tacocat'.indexOf('', 3) == 3", + "'tacocat'.indexOf('a', 3) == 5", "'tacocat'.indexOf('at', 3) == 5", + "'ta©o©αT'.indexOf('©') == 2", "'ta©o©αT'.indexOf('©', 3) == 4", + "'ta©o©αT'.indexOf('©αT', 3) == 4", "'ta©o©αT'.indexOf('©α', 5) == -1", + "'ijk'.indexOf('k') == 2", "'hello wello'.indexOf('hello wello') == 0", + "'hello wello'.indexOf('ello', 6) == 7", + "'hello wello'.indexOf('elbo room!!') == -1", + "'hello wello'.indexOf('elbo room!!!') == -1", + "''.lastIndexOf('@@') == -1", "'tacocat'.lastIndexOf('') == 7", + "'tacocat'.lastIndexOf('at') == 5", + "'tacocat'.lastIndexOf('none') == -1", + "'tacocat'.lastIndexOf('', 3) == 3", + "'tacocat'.lastIndexOf('a', 3) == 1", "'ta©o©αT'.lastIndexOf('©') == 4", + "'ta©o©αT'.lastIndexOf('©', 3) == 2", + "'ta©o©αT'.lastIndexOf('©α', 4) == 4", + "'hello wello'.lastIndexOf('ello', 6) == 1", + "'hello wello'.lastIndexOf('low') == -1", + "'hello wello'.lastIndexOf('elbo room!!') == -1", + "'hello wello'.lastIndexOf('elbo room!!!') == -1", + "'hello wello'.lastIndexOf('hello wello') == 0", + "'bananananana'.lastIndexOf('nana', 7) == 6", + + // Tests for substring() + "'tacocat'.substring(4) == 'cat'", "'tacocat'.substring(7) == ''", + "'tacocat'.substring(0, 4) == 'taco'", + "'tacocat'.substring(4, 4) == ''", + "'ta©o©αT'.substring(2, 6) == '©o©α'", + "'ta©o©αT'.substring(7, 7) == ''", + + // Tests for reverse() + "''.reverse() == ''", "'hello'.reverse() == 'olleh'", + "'©αT'.reverse() == 'Tα©'", "'gums'.reverse() == 'smug'", + "'palindromes'.reverse() == 'semordnilap'", + "'John Smith'.reverse() == 'htimS nhoJ'", + "'u180etext'.reverse() == 'txete081u'", + "'2600+U'.reverse() == 'U+0062'", + "'\u180e\u200b\u200c\u200d\u2060\ufeff'.reverse() == " + "'\ufeff\u2060\u200d\u200c\u200b\u180e'", + + // Tests for strings.quote() + R"(strings.quote("first\nsecond") == "\"first\\nsecond\"")", + R"(strings.quote("bell\a") == "\"bell\\a\"")", + R"(strings.quote("\bbackspace") == "\"\\bbackspace\"")", + R"(strings.quote("\fform feed") == "\"\\fform feed\"")", + R"(strings.quote("carriage \r return") == "\"carriage \\r return\"")", + R"(strings.quote("vertical \v tab") == "\"vertical \\v tab\"")", + R"(strings.quote("verbatim") == "\"verbatim\"")", + R"(strings.quote("ends with \\") == "\"ends with \\\\\"")", + R"(strings.quote("\\ starts with") == "\"\\\\ starts with\"")", + + // Tests for trim() + R"(' \f\n\r\t\vtext '.trim() == 'text')", + R"('\u0085\u00a0\u1680text'.trim() == 'text')", + R"('text\u2000\u2001\u2002\u2003\u2004\u2004\u2006\u2007\u2008\u2009'.trim() == 'text')", + R"('\u200atext\u2028\u2029\u202F\u205F\u3000'.trim() == 'text')", + R"(' hello world '.trim() == 'hello world')")); + +// Basic test for the included declarations. +// Additional coverage for behavior in the spec tests. +class StringsCheckerLibraryTest : public ::testing::TestWithParam { +}; + +TEST_P(StringsCheckerLibraryTest, TypeChecks) { + const std::string& expr = GetParam(); + ASSERT_OK_AND_ASSIGN( + auto builder, NewCompilerBuilder(internal::GetTestingDescriptorPool())); + ASSERT_THAT(builder->AddLibrary(StringsCompilerLibrary()), IsOk()); + ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk()); + + ASSERT_OK_AND_ASSIGN(auto compiler, std::move(*builder).Build()); + + ASSERT_OK_AND_ASSIGN(ValidationResult result, compiler->Compile(expr)); + EXPECT_TRUE(result.IsValid()) << "Failed to compile: " << expr; +} + +INSTANTIATE_TEST_SUITE_P( + Expressions, StringsCheckerLibraryTest, + Values("['a', 'b', 'c'].join() == 'abc'", + "['a', 'b', 'c'].join('|') == 'a|b|c'", + "'a|b|c'.split('|') == ['a', 'b', 'c']", + "'a|b|c'.split('|', 1) == ['a', 'b|c']", + "'a|b|c'.split('|') == ['a', 'b', 'c']", + "'AbC'.lowerAscii() == 'abc'", + "'tacocat'.replace('cat', 'dog') == 'tacodog'", + "'tacocat'.replace('aco', 'an', 2) == 'tacocat'", + "'tacocat'.charAt(2) == 'c'", "'tacocat'.indexOf('c') == 2", + "'tacocat'.indexOf('c', 3) == 4", "'tacocat'.lastIndexOf('c') == 4", + "'tacocat'.lastIndexOf('c', 5) == -1", + "'tacocat'.substring(1) == 'acocat'", + "'tacocat'.substring(1, 3) == 'aco'", "'aBc'.upperAscii() == 'ABC'", + "'abc %d'.format([2]) == 'abc 2'", + "strings.quote('abc') == \"'abc 2'\"", "'abc'.reverse() == 'cba'", + "'ta©o©αT'.substring(7, 7) == ''")); + +class StringsOverloadNotFoundTest + : public ::testing::TestWithParam {}; + +TEST_P(StringsOverloadNotFoundTest, PlannerTests) { + const std::string& expr_string = GetParam(); + const auto options = RuntimeOptions{}; + ASSERT_OK_AND_ASSIGN(auto builder, + CreateStandardRuntimeBuilder( + internal::GetTestingDescriptorPool(), options)); + EXPECT_THAT(RegisterStringsFunctions(builder.function_registry(), options), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, + Parse(expr_string, "", ParserOptions{})); + + EXPECT_THAT( + ProtobufRuntimeAdapter::CreateProgram(*runtime, expr), + absl_testing::StatusIs(absl::StatusCode::kInvalidArgument, + testing::HasSubstr("No overloads provided"))); +} + +INSTANTIATE_TEST_SUITE_P( + OverloadNotFound, StringsOverloadNotFoundTest, + Values( + // string_ext.type_errors/indexof_ternary_invalid_arguments + "'42'.indexOf('4', 0, 1) == 0", + // string_ext.type_errors/replace_quaternary_invalid_argument + "'42'.replace('2', '1', 1, false) == '41'", + // string_ext.type_errors/split_ternary_invalid_argument + "'42'.split('2', 1, 1) == ['4']", + // string_ext.type_errors/substring_ternary_invalid_argument + "'hello'.substring(1, 2, 3) == ''")); + +class StringsRuntimeErrorTest : public ::testing::TestWithParam {}; + +TEST_P(StringsRuntimeErrorTest, EvaluationErrors) { + const std::string& expr = GetParam(); + + ASSERT_OK_AND_ASSIGN( + auto compiler_builder, + NewCompilerBuilder(internal::GetTestingDescriptorPool())); + + ASSERT_THAT(compiler_builder->AddLibrary(StandardCompilerLibrary()), IsOk()); + ASSERT_THAT(compiler_builder->AddLibrary(StringsCompilerLibrary()), IsOk()); + + ASSERT_OK_AND_ASSIGN(auto compiler, std::move(*compiler_builder).Build()); + + auto result = compiler->Compile(expr, ""); + + ASSERT_THAT(result, IsOk()); + ASSERT_TRUE(result->IsValid()); + + RuntimeOptions opts; + ASSERT_OK_AND_ASSIGN( + auto runtime_builder, + CreateStandardRuntimeBuilder(internal::GetTestingDescriptorPool(), opts)); + + ASSERT_THAT( + RegisterStringsFunctions(runtime_builder.function_registry(), opts), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(runtime_builder).Build()); + + ASSERT_OK_AND_ASSIGN(auto program, + runtime->CreateProgram(*result->ReleaseAst())); + + google::protobuf::Arena arena; + cel::Activation activation; + ASSERT_OK_AND_ASSIGN(auto value, program->Evaluate(&arena, activation)); + + ASSERT_TRUE(value.Is()); + EXPECT_THAT(value.As()->NativeValue().code(), + absl::StatusCode::kInvalidArgument); +} + +INSTANTIATE_TEST_SUITE_P(EvaluationErrors, StringsRuntimeErrorTest, + Values("'a'.substring(-1)", "'a'.substring(2)", + "'a'.substring(0, -1)", "'a'.substring(0, 2)", + "'a'.substring(1, 0)")); + +struct StringsExtensionVersionTestCase { + std::string expr; + std::vector expected_supported_versions; +}; + +class StringsExtensionVersionTest + : public ::testing::TestWithParam {}; + +TEST_P(StringsExtensionVersionTest, StringsExtensionVersions) { + const StringsExtensionVersionTestCase& test_case = GetParam(); + for (int version = 0; + version <= cel::extensions::kStringsExtensionLatestVersion; ++version) { + CompilerLibrary compiler_library = StringsCompilerLibrary(version); + + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + cel::NewCompilerBuilder(internal::GetTestingDescriptorPool(), + CompilerOptions())); + ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk()); + ASSERT_THAT(builder->AddLibrary(std::move(compiler_library)), IsOk()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, builder->Build()); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + compiler->Compile(test_case.expr)); + if (absl::c_contains(test_case.expected_supported_versions, version)) { + EXPECT_THAT(result.GetIssues(), IsEmpty()) + << "Expected no issues for expr: " << test_case.expr + << " at version: " << version << " but got: " << result.FormatError(); + } else { + EXPECT_THAT(result.GetIssues(), + Contains(Property(&TypeCheckIssue::message, + HasSubstr("undeclared reference")))); + } + } +}; + +std::vector +CreateStringsExtensionVersionParams() { + return { + StringsExtensionVersionTestCase{ + .expr = "'foo'.charAt(0)", + .expected_supported_versions = {0, 1, 2, 3, 4}, + }, + StringsExtensionVersionTestCase{ + .expr = "'foo'.indexOf('f')", + .expected_supported_versions = {0, 1, 2, 3, 4}, + }, + StringsExtensionVersionTestCase{ + .expr = "'foo'.lastIndexOf('f')", + .expected_supported_versions = {0, 1, 2, 3, 4}, + }, + StringsExtensionVersionTestCase{ + .expr = "'foo'.lowerAscii()", + .expected_supported_versions = {0, 1, 2, 3, 4}, + }, + StringsExtensionVersionTestCase{ + .expr = "'foo'.replace('f', 'b')", + .expected_supported_versions = {0, 1, 2, 3, 4}, + }, + StringsExtensionVersionTestCase{ + .expr = "'foo'.split('o')", + .expected_supported_versions = {0, 1, 2, 3, 4}, + }, + StringsExtensionVersionTestCase{ + .expr = "'foo'.substring(0, 1)", + .expected_supported_versions = {0, 1, 2, 3, 4}, + }, + StringsExtensionVersionTestCase{ + .expr = "'foo'.trim()", + .expected_supported_versions = {0, 1, 2, 3, 4}, + }, + StringsExtensionVersionTestCase{ + .expr = "'foo'.upperAscii()", + .expected_supported_versions = {0, 1, 2, 3, 4}, + }, + StringsExtensionVersionTestCase{ + .expr = "'%d'.format([1])", + .expected_supported_versions = {1, 2, 3, 4}, + }, + StringsExtensionVersionTestCase{ + .expr = "strings.quote('foo')", + .expected_supported_versions = {1, 2, 3, 4}, + }, + StringsExtensionVersionTestCase{ + .expr = "['a', 'b', 'c'].join(',')", + .expected_supported_versions = {2, 3, 4}, + }, + StringsExtensionVersionTestCase{ + .expr = "'foo'.reverse()", + .expected_supported_versions = {3, 4}, + }, + }; +} + +INSTANTIATE_TEST_SUITE_P(StringsExtensionVersionTest, + StringsExtensionVersionTest, + ValuesIn(CreateStringsExtensionVersionParams())); + +} // namespace +} // namespace cel::extensions diff --git a/internal/BUILD b/internal/BUILD index 7819be972..6d0efab72 100644 --- a/internal/BUILD +++ b/internal/BUILD @@ -1,290 +1,861 @@ -# Description -# Internal implemenation details and libraries. +# Copyright 2021 Google LLC # -# Uses the namespace google::api::expr::internal +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") +load("//bazel:cel_cc_embed.bzl", "cel_cc_embed") +load("//bazel:cel_proto_transitive_descriptor_set.bzl", "cel_proto_transitive_descriptor_set") package(default_visibility = ["//visibility:public"]) -licenses(["notice"]) # Apache 2.0 +licenses(["notice"]) cc_library( - name = "ref_countable", - srcs = ["ref_countable.cc"], - hdrs = [ - "ref_countable.h", + name = "align", + hdrs = ["align.h"], + deps = [ + "@com_google_absl//absl/base", + "@com_google_absl//absl/base:config", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/numeric:bits", ], +) + +cc_test( + name = "align_test", + srcs = ["align_test.cc"], + tags = ["no_test_msvc"], deps = [ - ":holder", - ":specialize", + ":align", + ":testing", + ], +) + +cc_library( + name = "new", + srcs = ["new.cc"], + hdrs = ["new.h"], + deps = [ + ":align", + "@com_google_absl//absl/base:config", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/numeric:bits", ], ) cc_test( - name = "ref_countable_test", - srcs = ["ref_countable_test.cc"], + name = "new_test", + srcs = ["new_test.cc"], deps = [ - ":ref_countable", - "//testutil:util", - "@com_google_absl//absl/memory", - "@com_google_googletest//:gtest_main", + ":new", + ":testing", ], ) cc_library( - name = "handle", - hdrs = [ - "handle.h", + name = "benchmark", + testonly = True, + hdrs = ["benchmark.h"], + deps = ["@com_github_google_benchmark//:benchmark_main"], +) + +cc_library( + name = "casts", + hdrs = ["casts.h"], +) + +cc_library( + name = "re2_options", + hdrs = ["re2_options.h"], + deps = [ + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_googlesource_code_re2//:re2", ], +) + +cc_library( + name = "runfiles", + srcs = ["runfiles.cc"], + hdrs = ["runfiles.h"], deps = [ - ":hash_util", - ":specialize", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@rules_cc//cc/runfiles", ], ) -cc_test( - name = "handle_test", - srcs = ["handle_test.cc"], +cc_library( + name = "status_builder", + hdrs = ["status_builder.h"], deps = [ - ":handle", - "@com_google_absl//absl/container:node_hash_set", - "@com_google_googletest//:gtest_main", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/status", ], ) cc_library( - name = "holder", - hdrs = [ - "holder.h", + name = "overflow", + srcs = ["overflow.cc"], + hdrs = ["overflow.h"], + deps = [ + ":status_macros", + ":time", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/time", ], +) + +cc_test( + name = "overflow_test", + srcs = ["overflow_test.cc"], deps = [ - ":port", - ":specialize", - ":types", - ":visitor_util", + ":overflow", + ":testing", + "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/status", + "@com_google_absl//absl/time", ], ) +cc_library( + name = "number", + hdrs = ["number.h"], + deps = ["@com_google_absl//absl/types:variant"], +) + cc_test( - name = "holder_test", - srcs = ["holder_test.cc"], + name = "number_test", + srcs = ["number_test.cc"], deps = [ - ":holder", - ":types", - "//testutil:util", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest_main", + ":number", + ":testing", ], ) cc_library( - name = "hash_util", - srcs = [ - "hash_util.cc", + name = "exceptions", + hdrs = ["exceptions.h"], + deps = ["@com_google_absl//absl/base:config"], +) + +cc_library( + name = "status_macros", + hdrs = ["status_macros.h"], + deps = [ + ":status_builder", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/status", ], - hdrs = [ - "hash_util.h", +) + +cc_library( + name = "string_pool", + srcs = ["string_pool.cc"], + hdrs = ["string_pool.h"], + deps = [ + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:die_if_null", + "@com_google_absl//absl/strings:cord", + "@com_google_absl//absl/strings:string_view", + "@com_google_protobuf//:protobuf", ], +) + +cc_test( + name = "string_pool_test", + srcs = ["string_pool_test.cc"], deps = [ - ":port", - ":specialize", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/time", - "@com_google_googleapis//google/rpc:status_cc_proto", + ":string_pool", + ":testing", + "@com_google_absl//absl/strings:string_view", "@com_google_protobuf//:protobuf", ], ) cc_library( - name = "status_util", - srcs = ["status_util.cc"], - hdrs = [ - "status_util.h", + name = "strings", + srcs = ["strings.cc"], + hdrs = ["strings.h"], + deps = [ + ":lexis", + ":unicode", + ":utf8", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", ], +) + +cc_test( + name = "strings_test", + srcs = ["strings_test.cc"], deps = [ + ":strings", + ":testing", + ":utf8", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", - "@com_google_googleapis//google/rpc:code_cc_proto", - "@com_google_googleapis//google/rpc:status_cc_proto", + "@com_google_absl//absl/strings:cord", + "@com_google_absl//absl/strings:cord_test_helpers", + "@com_google_absl//absl/strings:str_format", ], ) cc_library( - name = "visitor_util", - hdrs = [ - "visitor_util.h", + name = "lexis", + srcs = ["lexis.cc"], + hdrs = ["lexis.h"], + deps = [ + "@com_google_absl//absl/base", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", ], +) + +cc_test( + name = "lexis_test", + srcs = ["lexis_test.cc"], deps = [ - ":specialize", - ":types", + ":lexis", + ":testing", + ], +) + +cc_library( + name = "proto_util", + hdrs = ["proto_util.h"], + deps = [ + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:str_format", + "@com_google_protobuf//:differencer", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "proto_util_test", + srcs = ["proto_util_test.cc"], + deps = [ + ":proto_util", + ":testing", + "//eval/public/structs:cel_proto_descriptor_pool_builder", + "@com_google_absl//absl/status", + "@com_google_protobuf//:duration_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "proto_time_encoding", + srcs = ["proto_time_encoding.cc"], + hdrs = ["proto_time_encoding.h"], + deps = [ + ":status_macros", + ":time", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:optional", - "@com_google_absl//absl/types:variant", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/time", + "@com_google_protobuf//:duration_cc_proto", + "@com_google_protobuf//:time_util", + "@com_google_protobuf//:timestamp_cc_proto", ], ) cc_test( - name = "visitor_util_test", - srcs = ["visitor_util_test.cc"], + name = "proto_time_encoding_test", + srcs = ["proto_time_encoding_test.cc"], deps = [ - ":adapter_util", - ":visitor_util", + ":proto_time_encoding", + ":testing", "//testutil:util", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest_main", + "@com_google_absl//absl/time", + "@com_google_protobuf//:duration_cc_proto", + "@com_google_protobuf//:timestamp_cc_proto", ], ) cc_library( - name = "adapter_util", + name = "testing", + testonly = True, + srcs = [ + "testing.cc", + ], hdrs = [ - "adapter_util.h", + "testing.h", ], deps = [ - ":visitor_util", + ":status_macros", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", ], ) cc_library( - name = "cel_printer", - srcs = ["cel_printer.cc"], + name = "testing_no_main", + testonly = True, + srcs = [ + "testing.cc", + ], hdrs = [ - "cel_printer.h", + "testing.h", + ], + deps = [ + ":status_macros", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest", ], +) + +cc_library( + name = "time", + srcs = ["time.cc"], + hdrs = ["time.h"], deps = [ - ":specialize", - ":types", - ":visitor_util", + ":status_macros", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/time", - "@com_google_protobuf//:protobuf", + "@com_google_protobuf//:time_util", ], ) cc_test( - name = "cel_printer_test", - srcs = ["cel_printer_test.cc"], + name = "time_test", + srcs = ["time_test.cc"], deps = [ - ":cel_printer", - "@com_google_googletest//:gtest_main", + ":testing", + ":time", + "@com_google_absl//absl/status", + "@com_google_absl//absl/time", + "@com_google_protobuf//:time_util", ], ) cc_library( - name = "proto_util", - srcs = ["proto_util.cc"], - hdrs = ["proto_util.h"], + name = "unicode", + hdrs = ["unicode.h"], +) + +cc_library( + name = "utf8", + srcs = ["utf8.cc"], + hdrs = ["utf8.h"], deps = [ - ":status_util", - "//common:macros", - "@com_google_absl//absl/memory", + ":unicode", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/strings", - "@com_google_absl//absl/time", - "@com_google_googleapis//google/rpc:status_cc_proto", + "@com_google_absl//absl/strings:cord", + ], +) + +cc_test( + name = "utf8_test", + srcs = ["utf8_test.cc"], + deps = [ + ":benchmark", + ":testing", + ":utf8", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + "@com_google_absl//absl/strings:cord_test_helpers", + ], +) + +cc_library( + name = "proto_matchers", + testonly = True, + hdrs = ["proto_matchers.h"], + deps = [ + ":casts", + ":testing", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/memory", + "@com_google_protobuf//:differencer", "@com_google_protobuf//:protobuf", ], ) cc_library( - name = "map_impl", - srcs = ["map_impl.cc"], - hdrs = ["map_impl.h"], + name = "proto_file_util", + testonly = True, + hdrs = ["proto_file_util.h"], deps = [ - ":status_util", - "//common:macros", - "//common:value", - "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_protobuf//:protobuf", + "@com_google_protobuf//src/google/protobuf/io", ], ) cc_library( - name = "list_impl", - hdrs = ["list_impl.h"], + name = "names", + srcs = ["names.cc"], + hdrs = ["names.h"], + deps = [ + ":lexis", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + ], +) + +cc_test( + name = "names_test", + srcs = ["names_test.cc"], deps = [ - ":holder", - "//common:macros", - "//common:value", + ":names", + ":testing", ], ) cc_library( - name = "value_internal", - hdrs = ["value_internal.h"], + name = "to_address", + hdrs = ["to_address.h"], deps = [ - ":adapter_util", - ":cast", - ":ref_countable", - "//common:enum", - "//common:error", - "//common:id", - "//common:parent_ref", - "//common:type", - "//common:unknown", - "@com_google_absl//absl/types:optional", - "@com_google_absl//absl/types:variant", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/meta:type_traits", ], ) cc_test( - name = "value_internal_test", - srcs = ["value_internal_test.cc"], + name = "to_address_test", + srcs = ["to_address_test.cc"], deps = [ - ":value_internal", - "//testutil:util", - "@com_google_absl//absl/memory", + ":testing", + ":to_address", + ], +) + +cel_proto_transitive_descriptor_set( + name = "empty_descriptor_set", + deps = [ + "@com_google_protobuf//:empty_proto", + ], +) + +cel_cc_embed( + name = "empty_descriptor_set_embed", + src = ":empty_descriptor_set", +) + +cc_library( + name = "empty_descriptors", + srcs = ["empty_descriptors.cc"], + hdrs = ["empty_descriptors.h"], + textual_hdrs = [":empty_descriptor_set_embed"], + deps = [ + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:die_if_null", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "empty_descriptors_test", + srcs = ["empty_descriptors_test.cc"], + deps = [ + ":empty_descriptors", + ":testing", + ], +) + +cel_proto_transitive_descriptor_set( + name = "minimal_descriptor_set", + deps = [ + "@com_google_protobuf//:any_proto", + "@com_google_protobuf//:duration_proto", + "@com_google_protobuf//:struct_proto", + "@com_google_protobuf//:timestamp_proto", + "@com_google_protobuf//:wrappers_proto", + ], +) + +cel_cc_embed( + name = "minimal_descriptor_set_embed", + src = ":minimal_descriptor_set", +) + +alias( + name = "minimal_descriptor_pool", + actual = ":minimal_descriptors", +) + +cc_library( + name = "minimal_descriptors", + srcs = ["minimal_descriptors.cc"], + hdrs = [ + "minimal_descriptor_database.h", + "minimal_descriptor_pool.h", + ], + textual_hdrs = [":minimal_descriptor_set_embed"], + deps = [ + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest_main", + "@com_google_absl//absl/strings:string_view", + "@com_google_protobuf//:protobuf", + ], +) + +cel_proto_transitive_descriptor_set( + name = "testing_descriptor_set", + testonly = True, + deps = [ + "//eval/testutil:test_extensions_proto", + "//eval/testutil:test_message_proto", + "//testutil:test_json_names_proto", + "@com_google_cel_spec//proto/cel/expr:checked_proto", + "@com_google_cel_spec//proto/cel/expr:expr_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_proto", + "@com_google_cel_spec//proto/cel/expr:value_proto", + "@com_google_cel_spec//proto/cel/expr/conformance/proto2:test_all_types_proto", + "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_proto", + "@com_google_protobuf//:any_proto", + "@com_google_protobuf//:duration_proto", + "@com_google_protobuf//:empty_proto", + "@com_google_protobuf//:field_mask_proto", + "@com_google_protobuf//:struct_proto", + "@com_google_protobuf//:timestamp_proto", + "@com_google_protobuf//:wrappers_proto", + ], +) + +cel_cc_embed( + name = "testing_descriptor_set_embed", + testonly = True, + src = ":testing_descriptor_set", +) + +cc_library( + name = "testing_descriptor_pool", + testonly = True, + srcs = ["testing_descriptor_pool.cc"], + hdrs = ["testing_descriptor_pool.h"], + textual_hdrs = [":testing_descriptor_set_embed"], + deps = [ + ":noop_delete", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/log:absl_check", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "testing_descriptor_pool_test", + srcs = ["testing_descriptor_pool_test.cc"], + deps = [ + ":testing", + ":testing_descriptor_pool", + "@com_google_protobuf//:protobuf", ], ) cc_library( - name = "port", - hdrs = ["port.h"], + name = "message_type_name", + hdrs = ["message_type_name.h"], + deps = [ + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/strings:string_view", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "message_type_name_test", + srcs = ["message_type_name_test.cc"], + deps = [ + ":message_type_name", + ":testing", + "@com_google_protobuf//:any_cc_proto", + ], ) cc_library( - name = "specialize", - hdrs = ["specialize.h"], + name = "parse_text_proto", + testonly = True, + hdrs = ["parse_text_proto.h"], + deps = [ + ":message_type_name", + ":testing_descriptor_pool", + ":testing_message_factory", + "//common:memory", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:die_if_null", + "@com_google_absl//absl/strings:cord", + "@com_google_absl//absl/strings:string_view", + "@com_google_protobuf//:protobuf", + ], ) cc_library( - name = "cast", - hdrs = ["cast.h"], + name = "equals_text_proto", + testonly = True, + srcs = ["equals_text_proto.cc"], + hdrs = ["equals_text_proto.h"], deps = [ - ":port", - ":specialize", - ":types", + ":parse_text_proto", + ":testing", + ":testing_descriptor_pool", + ":testing_message_factory", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/memory", - "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/strings:cord", + "@com_google_absl//absl/strings:string_view", + "@com_google_protobuf//:differencer", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "testing_message_factory", + testonly = True, + srcs = ["testing_message_factory.cc"], + hdrs = ["testing_message_factory.h"], + deps = [ + ":testing_descriptor_pool", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/base:nullability", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "well_known_types", + srcs = ["well_known_types.cc"], + hdrs = ["well_known_types.h"], + deps = [ + ":protobuf_runtime_version", + ":status_macros", + "//common:any", + "//common:json", + "//common:memory", + "//extensions/protobuf/internal:map_reflection", + "@com_google_absl//absl/base", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/functional:overload", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:variant", + "@com_google_protobuf//:any_cc_proto", + "@com_google_protobuf//:duration_cc_proto", + "@com_google_protobuf//:field_mask_cc_proto", + "@com_google_protobuf//:protobuf", + "@com_google_protobuf//:struct_cc_proto", + "@com_google_protobuf//:time_util", + "@com_google_protobuf//:timestamp_cc_proto", + "@com_google_protobuf//:wrappers_cc_proto", ], ) cc_test( - name = "cast_test", - srcs = ["cast_test.cc"], + name = "well_known_types_test", + srcs = ["well_known_types_test.cc"], deps = [ - ":cast", - "@com_google_googletest//:gtest_main", + ":message_type_name", + ":minimal_descriptor_pool", + ":parse_text_proto", + ":testing", + ":testing_descriptor_pool", + ":testing_message_factory", + ":well_known_types", + "//common:memory", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/log:die_if_null", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:cord", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:variant", + "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto", + "@com_google_protobuf//:any_cc_proto", + "@com_google_protobuf//:duration_cc_proto", + "@com_google_protobuf//:field_mask_cc_proto", + "@com_google_protobuf//:protobuf", + "@com_google_protobuf//:struct_cc_proto", + "@com_google_protobuf//:timestamp_cc_proto", + "@com_google_protobuf//:wrappers_cc_proto", ], ) cc_library( - name = "types", - hdrs = [ - "types.h", + name = "json", + srcs = ["json.cc"], + hdrs = ["json.h"], + deps = [ + ":status_macros", + ":strings", + ":well_known_types", + "//extensions/protobuf/internal:map_reflection", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/functional:overload", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:variant", + "@com_google_protobuf//:duration_cc_proto", + "@com_google_protobuf//:protobuf", + "@com_google_protobuf//:struct_cc_proto", + "@com_google_protobuf//:time_util", + "@com_google_protobuf//:timestamp_cc_proto", ], +) + +cc_test( + name = "json_test", + srcs = ["json_test.cc"], deps = [ - ":port", - ":specialize", - "@com_google_absl//absl/memory", + ":equals_text_proto", + ":json", + ":message_type_name", + ":parse_text_proto", + ":testing", + ":testing_descriptor_pool", + ":testing_message_factory", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/log:die_if_null", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/strings:string_view", + "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto", + "@com_google_protobuf//:any_cc_proto", + "@com_google_protobuf//:duration_cc_proto", + "@com_google_protobuf//:field_mask_cc_proto", + "@com_google_protobuf//:protobuf", + "@com_google_protobuf//:struct_cc_proto", + "@com_google_protobuf//:timestamp_cc_proto", + "@com_google_protobuf//:wrappers_cc_proto", + ], +) + +cc_library( + name = "message_equality", + srcs = ["message_equality.cc"], + hdrs = ["message_equality.h"], + deps = [ + ":json", + ":number", + ":status_macros", + ":well_known_types", + "//common:memory", + "//extensions/protobuf/internal:map_reflection", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/functional:overload", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:variant", + "@com_google_protobuf//:differencer", + "@com_google_protobuf//:protobuf", ], ) cc_test( - name = "types_test", - srcs = ["types_test.cc"], + name = "message_equality_test", + srcs = ["message_equality_test.cc"], + tags = ["no_test_msvc"], deps = [ - ":types", - "@com_google_googletest//:gtest_main", + ":message_equality", + ":message_type_name", + ":parse_text_proto", + ":testing", + ":testing_descriptor_pool", + ":testing_message_factory", + ":well_known_types", + "//common:allocator", + "//common:memory", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:die_if_null", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:optional", + "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto", + "@com_google_protobuf//:any_cc_proto", + "@com_google_protobuf//:duration_cc_proto", + "@com_google_protobuf//:protobuf", + "@com_google_protobuf//:struct_cc_proto", + "@com_google_protobuf//:timestamp_cc_proto", + "@com_google_protobuf//:wrappers_cc_proto", + ], +) + +cc_library( + name = "protobuf_runtime_version", + hdrs = ["protobuf_runtime_version.h"], + deps = ["@com_google_protobuf//:protobuf"], +) + +cc_library( + name = "noop_delete", + hdrs = ["noop_delete.h"], + deps = ["@com_google_absl//absl/base:nullability"], +) + +cc_library( + name = "manual", + hdrs = ["manual.h"], + deps = [ + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", ], ) diff --git a/internal/adapter_util.h b/internal/adapter_util.h deleted file mode 100644 index 28dea6e7e..000000000 --- a/internal/adapter_util.h +++ /dev/null @@ -1,177 +0,0 @@ -/** - * Utilities for adapters. - * - * Adapters are visitors that accept a single argument. - * - * The primary utilities provided in this library include: - * - MaybeAdapt/MaybeAdaptResult: Tries to apply the given adapter - * if possible, otherwise returns the given value unchanged. - * - VisitorAdapter/AdaptVisitor: A visitor that tries to apply the - * given adapter to every argument before passing those arguments on to the - * wrapped visitor. - * - CompositeAdapter: An adapter that (mabye) applies the provided adapters in - * order. - */ - -#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_ADAPTER_UTIL_H_H_ -#define THIRD_PARTY_CEL_CPP_INTERNAL_ADAPTER_UTIL_H_H_ - -#include - -#include "internal/visitor_util.h" - -namespace google { -namespace api { -namespace expr { -namespace internal { - -/** An adapter that passes through the give value unchanged. */ -struct IdentityAdapter { - template - T&& operator()(T&& value) { - return std::forward(value); - } -}; - -/** The result of a call to MaybeAdapt. */ -template -using MaybeAdaptResultType = MaybeVisitResultType; - -/** - * A helper function that applies the adapter if a suitable overload is found, - * otherwise the value is returned unchanged. - */ -template -MaybeAdaptResultType MaybeAdapt(Adapter&& adpt, T&& value) { - return MaybeVisit(std::forward(adpt), IdentityAdapter(), - std::forward(value)); -} - -/** - * (Maybe) applies `Adapter` to every argument before passing them to Visitor. - */ -template -class VisitorAdapter { - public: - VisitorAdapter() {} - explicit VisitorAdapter(Visitor&& vis) : vis_(std::forward(vis)) {} - VisitorAdapter(Visitor&& vis, Adapter&& adapter) - : vis_(std::forward(vis)), - adapter_(std::forward(adapter)) {} - - template - VisitResultType...> operator()( - Args&&... args) { - return vis_(MaybeAdapt(adapter_, args)...); - } - - private: - Visitor vis_; - Adapter adapter_; -}; - -template -VisitorAdapter AdaptVisitor(Visitor&& vis, - Adapter&& adapter) { - return VisitorAdapter(std::forward(vis), - std::forward(adapter)); -} - -/** An adapter that (maybe) applies the given adapters in order. */ -template -class CompositeAdapter; - -// Only a single adapter. -template -class CompositeAdapter { - public: - CompositeAdapter() = default; - CompositeAdapter(const CompositeAdapter&) = default; - CompositeAdapter(CompositeAdapter&&) = default; - explicit CompositeAdapter(Adapter&& adpt) - : adpt_(std::forward(adpt)) {} - - template - using ResultType = MaybeAdaptResultType; - - template - ResultType operator()(T&& value) { - return MaybeAdapt(adpt_, std::forward(value)); - } - - private: - Adapter adpt_; -}; - -// Multiple adapters, so pull of the head, and construct a new adapter from -// the tail. Then use MaybeVisit to try and visit Head first. -template -class CompositeAdapter { - private: - using Adapter = Head; - using NextAdapter = CompositeAdapter; - - public: - template - using ResultType = - VisitResultType>; - - CompositeAdapter() = default; - CompositeAdapter(const CompositeAdapter&) = default; - CompositeAdapter(CompositeAdapter&&) = default; - CompositeAdapter(Adapter&& adpt, Tail&&... next) - : adpt_(std::forward(adpt_)), - next_(std::forward(next)...) {} - - template - ResultType operator()(T&& value) { - return next_(MaybeAdapt(adpt_, std::forward(value))); - } - - private: - Adapter adpt_; - NextAdapter next_; -}; - -/** Helper function to construct a CompositeAdapter. */ -template -CompositeAdapter MakeCompositeAdapter(Adapters&&... vis) { - return CompositeAdapter(std::forward(vis)...); -} - -/** An adapter wrapper that restricts adaptation to the specified types. */ -template -struct StrictAdapter { - template - StrictAdapter(Args&&... args) : adpt(std::forward(args)...) {} - - template - specialize_ift, VisitResultType> operator()( - T&& value) { - return adpt(std::forward(value)); - } - - Adapter adpt; -}; - -/** An adapter wrapper that applies Adapter and converts the result to T. */ -template -struct ConvertAdapter { - template - ConvertAdapter(Args&&... args) : adpt(std::forward(args)...) {} - - template - // Only enable for types accepted by Adapter - specialize_ifd> operator()(U&& value) { - return T(adpt(std::forward(value))); - } - - Adapter adpt; -}; - -} // namespace internal -} // namespace expr -} // namespace api -} // namespace google - -#endif // THIRD_PARTY_CEL_CPP_INTERNAL_ADAPTER_UTIL_H_H_ diff --git a/internal/align.h b/internal/align.h new file mode 100644 index 000000000..244dcbf44 --- /dev/null +++ b/internal/align.h @@ -0,0 +1,95 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_ALIGN_H_ +#define THIRD_PARTY_CEL_CPP_INTERNAL_ALIGN_H_ + +#include +#include +#include + +#include "absl/base/casts.h" +#include "absl/base/config.h" +#include "absl/base/macros.h" +#include "absl/numeric/bits.h" + +namespace cel::internal { + +template +constexpr std::enable_if_t< + std::conjunction_v, std::is_unsigned>, T> +AlignmentMask(T alignment) { + ABSL_ASSERT(absl::has_single_bit(alignment)); + return alignment - T{1}; +} + +template +std::enable_if_t, std::is_unsigned>, + T> +AlignDown(T x, size_t alignment) { + ABSL_ASSERT(absl::has_single_bit(alignment)); +#if ABSL_HAVE_BUILTIN(__builtin_align_up) + return __builtin_align_down(x, alignment); +#else + using C = std::common_type_t; + return static_cast(static_cast(x) & + ~AlignmentMask(static_cast(alignment))); +#endif +} + +template +std::enable_if_t, T> AlignDown(T x, size_t alignment) { + return absl::bit_cast(AlignDown(absl::bit_cast(x), alignment)); +} + +template +std::enable_if_t, std::is_unsigned>, + T> +AlignUp(T x, size_t alignment) { + ABSL_ASSERT(absl::has_single_bit(alignment)); +#if ABSL_HAVE_BUILTIN(__builtin_align_up) + return __builtin_align_up(x, alignment); +#else + using C = std::common_type_t; + return static_cast(AlignDown( + static_cast(x) + AlignmentMask(static_cast(alignment)), alignment)); +#endif +} + +template +std::enable_if_t, T> AlignUp(T x, size_t alignment) { + return absl::bit_cast(AlignUp(absl::bit_cast(x), alignment)); +} + +template +constexpr std::enable_if_t< + std::conjunction_v, std::is_unsigned>, bool> +IsAligned(T x, size_t alignment) { + ABSL_ASSERT(absl::has_single_bit(alignment)); +#if ABSL_HAVE_BUILTIN(__builtin_is_aligned) + return __builtin_is_aligned(x, alignment); +#else + using C = std::common_type_t; + return (static_cast(x) & AlignmentMask(static_cast(alignment))) == C{0}; +#endif +} + +template +std::enable_if_t, bool> IsAligned(T x, size_t alignment) { + return IsAligned(absl::bit_cast(x), alignment); +} + +} // namespace cel::internal + +#endif // THIRD_PARTY_CEL_CPP_INTERNAL_ALIGN_H_ diff --git a/internal/align_test.cc b/internal/align_test.cc new file mode 100644 index 000000000..b1f31a9f6 --- /dev/null +++ b/internal/align_test.cc @@ -0,0 +1,61 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "internal/align.h" + +#include +#include + +#include "internal/testing.h" + +namespace cel::internal { +namespace { + +TEST(AlignmentMask, Masks) { + EXPECT_EQ(AlignmentMask(size_t{1}), size_t{0}); + EXPECT_EQ(AlignmentMask(size_t{2}), size_t{1}); + EXPECT_EQ(AlignmentMask(size_t{4}), size_t{3}); +} + +TEST(AlignDown, Aligns) { + EXPECT_EQ(AlignDown(uintptr_t{3}, 4), 0); + EXPECT_EQ(AlignDown(uintptr_t{0}, 4), 0); + EXPECT_EQ(AlignDown(uintptr_t{5}, 4), 4); + EXPECT_EQ(AlignDown(uintptr_t{4}, 4), 4); + + uint64_t val = 0; + EXPECT_EQ(AlignDown(&val, alignof(val)), &val); +} + +TEST(AlignUp, Aligns) { + EXPECT_EQ(AlignUp(uintptr_t{0}, 4), 0); + EXPECT_EQ(AlignUp(uintptr_t{3}, 4), 4); + EXPECT_EQ(AlignUp(uintptr_t{5}, 4), 8); + + uint64_t val = 0; + EXPECT_EQ(AlignUp(&val, alignof(val)), &val); +} + +TEST(IsAligned, Aligned) { + EXPECT_TRUE(IsAligned(uintptr_t{0}, 4)); + EXPECT_TRUE(IsAligned(uintptr_t{4}, 4)); + EXPECT_FALSE(IsAligned(uintptr_t{3}, 4)); + EXPECT_FALSE(IsAligned(uintptr_t{5}, 4)); + + uint64_t val = 0; + EXPECT_TRUE(IsAligned(&val, alignof(val))); +} + +} // namespace +} // namespace cel::internal diff --git a/internal/benchmark.h b/internal/benchmark.h new file mode 100644 index 000000000..6a34fa0b0 --- /dev/null +++ b/internal/benchmark.h @@ -0,0 +1,20 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_BENCHMARK_H_ +#define THIRD_PARTY_CEL_CPP_INTERNAL_BENCHMARK_H_ + +#include "benchmark/benchmark.h" // IWYU pragma: export + +#endif // THIRD_PARTY_CEL_CPP_INTERNAL_BENCHMARK_H_ diff --git a/internal/cast.h b/internal/cast.h deleted file mode 100644 index f5c3349b4..000000000 --- a/internal/cast.h +++ /dev/null @@ -1,127 +0,0 @@ -#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_DOWN_CAST_H_H_ -#define THIRD_PARTY_CEL_CPP_INTERNAL_DOWN_CAST_H_H_ - -#include -#include - -#include "absl/memory/memory.h" -#include "absl/types/optional.h" -#include "internal/port.h" -#include "internal/specialize.h" -#include "internal/types.h" - -namespace google { -namespace api { -namespace expr { -namespace internal { - -template -struct StaticDownCastHelper; - -// A down casting function that only checks RTT if NDEBUG is not defined and -// works for std::unique_ptr. -template -inline T static_down_cast(F&& value) { - return std::forward(StaticDownCastHelper::cast(std::forward(value))); -} - -template -struct StaticDownCastHelper { - template - inline static T* cast(F* value) { - static_assert(std::is_base_of::value, "bad down cast."); - assert(dynamic_cast(value) != nullptr); - return static_cast(value); - } -}; - -template -struct StaticDownCastHelper { - template - inline static T& cast(F&& value) { - return *StaticDownCastHelper::cast(&value); - } -}; - -template -struct StaticDownCastHelper { - template - inline static T&& cast(F&& value) { - return std::forward(*StaticDownCastHelper::cast(&value)); - } -}; - -template -struct StaticDownCastHelper> { - template - inline static std::unique_ptr cast(std::unique_ptr&& value) { - return absl::WrapUnique(StaticDownCastHelper::cast(value.release())); - } -}; - -// Default impl returns true if the types are the same. -template -struct RepresentableAsHelper { - static constexpr bool check(const U&) { - return std::is_same, U>::value; - } -}; - -// Convertible pointers always return true. -template -struct RepresentableAsHelper>> { - static constexpr bool check(const U*) { return true; } -}; - -// Convertible references always return true. -template -struct RepresentableAsHelper>> { - static constexpr bool check(const U&) { return true; } -}; - -// Numeric types check boundaries. -template -struct RepresentableAsHelper>> { - static bool check(const U& value) { - // Handle infinity and nan. - if (std::numeric_limits::has_infinity && !std::isfinite(value)) { - return true; - } - // Explicitly handle signed to avoid implicit conversion issues. - if (!std::numeric_limits::is_signed && value < 0) { - return false; - } - return value >= std::numeric_limits::min() && - value <= std::numeric_limits::max(); - } -}; - -/** Returns if the value is representable as the given type T. */ -template -bool representable_as(const U& value) { - return RepresentableAsHelper::check(value); -} - -/** Converts a smart pointer to an absl::optional value. */ -template -absl::optional copy_if(const F& value) { - if (value) { - return absl::optional(absl::in_place, *value); - } - return absl::nullopt; -} - -/** Converts a pointer to an absl::optional value. */ -template -absl::optional copy_if(const T* value) { - return copy_if(value); -} - -} // namespace internal -} // namespace expr -} // namespace api -} // namespace google - -#endif // THIRD_PARTY_CEL_CPP_INTERNAL_DOWN_CAST_H_H_ diff --git a/internal/cast_test.cc b/internal/cast_test.cc deleted file mode 100644 index 75368fbca..000000000 --- a/internal/cast_test.cc +++ /dev/null @@ -1,85 +0,0 @@ -#include "internal/cast.h" - -#include - -#include "gtest/gtest.h" - -namespace google { -namespace api { -namespace expr { -namespace internal { -namespace { - -struct A {}; -struct B : A {}; -struct C {}; - -TEST(CastTest, RepAs_Numeric) { - EXPECT_TRUE(representable_as(1)); - EXPECT_TRUE(representable_as(1u)); - EXPECT_TRUE(representable_as(1.5)); - EXPECT_FALSE(representable_as(std::numeric_limits::max())); - EXPECT_FALSE(representable_as(std::numeric_limits::max())); - EXPECT_FALSE( - representable_as(std::numeric_limits::infinity())); - - EXPECT_TRUE(representable_as(1)); - EXPECT_TRUE(representable_as(1u)); - EXPECT_TRUE(representable_as(1.5)); - EXPECT_FALSE(representable_as(-1)); - EXPECT_FALSE(representable_as(-1.0)); - EXPECT_FALSE( - representable_as(std::numeric_limits::infinity())); - - EXPECT_TRUE(representable_as(1)); - EXPECT_TRUE(representable_as(1u)); - EXPECT_TRUE(representable_as(1.5)); - EXPECT_FALSE(representable_as(std::numeric_limits::max())); - EXPECT_TRUE(representable_as(std::numeric_limits::infinity())); - EXPECT_TRUE( - representable_as(std::numeric_limits::quiet_NaN())); -} - -TEST(CastTest, RepAs_Value) { - A a; - B b; - C c; - - // Representable as self. - EXPECT_TRUE(representable_as(a)); - EXPECT_TRUE(representable_as(a)); - EXPECT_TRUE(representable_as(&a)); - - // Representable as ref or pointer to base class. - EXPECT_FALSE(representable_as(b)); - EXPECT_TRUE(representable_as(b)); - EXPECT_TRUE(representable_as(&b)); - - // Defaults to false when conversion would be required. - EXPECT_FALSE(representable_as(b)); - EXPECT_FALSE(representable_as(c)); - - // Bad down casting. - EXPECT_FALSE(representable_as(a)); - EXPECT_FALSE(representable_as(&a)); - - // Down casting not currently supported - A& b_as_a = b; - EXPECT_FALSE(representable_as(b_as_a)); - EXPECT_FALSE(representable_as(&b_as_a)); -} - -TEST(CastTest, CopyIf) { - A* null_ptr = nullptr; - std::unique_ptr null_uptr; - A a; - EXPECT_EQ(absl::nullopt, copy_if(null_ptr)); - EXPECT_TRUE(copy_if(&a).has_value()); - EXPECT_EQ(absl::nullopt, copy_if(null_uptr)); -} - -} // namespace -} // namespace internal -} // namespace expr -} // namespace api -} // namespace google diff --git a/internal/casts.h b/internal/casts.h new file mode 100644 index 000000000..495c2a017 --- /dev/null +++ b/internal/casts.h @@ -0,0 +1,50 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_CASTS_H_ +#define THIRD_PARTY_CEL_CPP_INTERNAL_CASTS_H_ + +#include +#include +#include + +namespace cel::internal { + +template +To down_cast(From* from) { + static_assert(std::is_pointer_v, "Target type not a pointer."); + static_assert((std::is_base_of_v>), + "Target type not derived from source type."); +#if !defined(__GNUC__) || defined(__GXX_RTTI) + assert(from == nullptr || dynamic_cast(from) != nullptr); +#endif + return static_cast(from); +} + +template +To down_cast(From& from) { + static_assert(std::is_lvalue_reference_v, + "Target type not a lvalue reference."); + static_assert((std::is_base_of_v>), + "Target type not derived from source type."); +#if !defined(__GNUC__) || defined(__GXX_RTTI) + assert(dynamic_cast>>( + std::addressof(from)) != nullptr); +#endif + return static_cast(from); +} + +} // namespace cel::internal + +#endif // THIRD_PARTY_CEL_CPP_INTERNAL_CASTS_H_ diff --git a/internal/cel_printer.cc b/internal/cel_printer.cc deleted file mode 100644 index 3787a2812..000000000 --- a/internal/cel_printer.cc +++ /dev/null @@ -1,38 +0,0 @@ -#include "internal/cel_printer.h" - -#include "google/protobuf/duration.pb.h" -#include "google/protobuf/timestamp.pb.h" -#include "absl/strings/str_cat.h" - -namespace google { -namespace api { -namespace expr { -namespace internal { - -constexpr const absl::string_view BaseJoinPolicy::kValueDelim; - -constexpr const absl::string_view ListJoinPolicy::kStart; -constexpr const absl::string_view ListJoinPolicy::kEnd; -constexpr const absl::string_view SetJoinPolicy::kStart; -constexpr const absl::string_view SetJoinPolicy::kEnd; -constexpr const absl::string_view CallJoinPolicy::kStart; -constexpr const absl::string_view CallJoinPolicy::kEnd; - -constexpr const absl::string_view MapJoinPolicy::kKeyDelim; -constexpr const absl::string_view ObjectJoinPolicy::kKeyDelim; - -std::string ScalarPrinter::operator()(absl::Time value) { - return ToCallString( - google::protobuf::Timestamp::descriptor()->full_name(), - absl::FormatTime("%Y-%m-%dT%H:%M:%E*SZ", value, absl::UTCTimeZone())); -} - -std::string ScalarPrinter::operator()(absl::Duration value) { - return ToCallString(google::protobuf::Duration::descriptor()->full_name(), - absl::FormatDuration(value)); -} - -} // namespace internal -} // namespace expr -} // namespace api -} // namespace google diff --git a/internal/cel_printer.h b/internal/cel_printer.h deleted file mode 100644 index 330a37be7..000000000 --- a/internal/cel_printer.h +++ /dev/null @@ -1,358 +0,0 @@ -/** - * Helper classes to converts native value into CEL expressions. - */ - -#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_CEL_PRINTER_H_ -#define THIRD_PARTY_CEL_CPP_INTERNAL_CEL_PRINTER_H_ - -#include - -#include "absl/strings/escaping.h" -#include "absl/time/time.h" -#include "internal/specialize.h" -#include "internal/types.h" -#include "internal/visitor_util.h" - -namespace google { -namespace api { -namespace expr { -namespace internal { - -// Print specific overload helpers. -template -using print_if = specialize_if; -template -using print_ift = specialize_ift; - -// Helper to print a string value without quotes/escaping. -struct RawString { - std::string name; - const std::string& ToString() const { return name; } -}; - -/** - * Printer for all scalar types (null, bool, int, unit, double, string). - */ -struct ScalarPrinter { - inline std::string operator()(std::nullptr_t) { return "null"; } - inline std::string operator()(bool value) { return value ? "true" : "false"; } - std::string operator()(absl::Time value); - std::string operator()(absl::Duration value); - - // Print int directly. - template - print_ift> operator()(T value) { - return absl::StrCat(value); - } - - // Print uint directly with a trailing 'u'. - template - print_ift> operator()(T value) { - return absl::StrCat(value, "u"); - } - - // Print float types directly and add a trailing '.0' if necessary. - template - print_ift> operator()(T value) { - T ipart; - if (std::isfinite(value) && std::modf(value, &ipart) == 0.0) { - return absl::StrCat(value, ".0"); - } - return absl::StrCat(value); - } - - // Quote and escape any 'string' type. - template - print_ift> operator()(T&& value) { - return absl::StrCat("\"", absl::CEscape(value), "\""); - } -}; - -/** - * A printer that forwards to functions on the value. - */ -struct ForwardingPrinter { - // If the value defines a ToString function, call it. - template - specialize_ifd().ToString())> operator()( - T&& value) { - return std::string(value.ToString()); - } - - // If the value defines a ToDebugString function, call it. - template - specialize_ifd().ToDebugString())> - operator()(T&& value) { - return std::string(value.ToDebugString()); - } -}; - -/** - * A printer that can print all CelValue values. - */ -struct CelPrinter : OrderedVisitor {}; - -/** - * The type of key used in a sequence. - */ -enum KeyType { - kNoKey, - kValueKey, - kIdentKey, -}; - -/** - * The base join policy for use with sequence printers. - */ -struct BaseJoinPolicy { - static constexpr const absl::string_view kValueDelim = ", "; - static constexpr const KeyType kKeyType = kNoKey; -}; - -/** - * Produces: [, , ...] - */ -struct ListJoinPolicy : BaseJoinPolicy { - static constexpr const absl::string_view kStart = "["; - static constexpr const absl::string_view kEnd = "]"; -}; - -/** - * Produces: {, , ...} - */ -struct SetJoinPolicy : BaseJoinPolicy { - static constexpr const absl::string_view kStart = "{"; - static constexpr const absl::string_view kEnd = "}"; -}; - -/** - * Produces: (, , ...) - */ -struct CallJoinPolicy : BaseJoinPolicy { - static constexpr const absl::string_view kStart = "("; - static constexpr const absl::string_view kEnd = ")"; -}; - -/** - * Produces: {: , : , ...} - */ -struct MapJoinPolicy : SetJoinPolicy { - static constexpr const absl::string_view kKeyDelim = ": "; - static constexpr KeyType kKeyType = kValueKey; -}; - -/** - * Produces: {=, =, ...} - */ -struct ObjectJoinPolicy : SetJoinPolicy { - static constexpr const absl::string_view kKeyDelim = "="; - static constexpr const KeyType kKeyType = kIdentKey; -}; - -// Join policy specific overload helpers. -template -using print_if_no_key = print_if; -template -using print_if_has_key = print_if; -template -using print_if_value_key = print_if; -template -using print_if_ident_key = print_if; - -/** - * A printer for entries in a sequence. - */ -template -struct EntryPrinter { - CelPrinter value_printer; - - // Print an entry with a value key. - template - print_if_value_key operator()(K&& key, V&& value) { - return absl::StrCat(value_printer(std::forward(key)), - JoinPolicy::kKeyDelim, - value_printer(std::forward(value))); - } - - // Print an entry with a ident key. - template - print_if_ident_key operator()(absl::string_view ident_key, - V&& value) { - if (ident_key.empty()) { - return value_printer(std::forward(value)); - } - return absl::StrCat(ident_key, JoinPolicy::kKeyDelim, - value_printer(std::forward(value))); - } - - template - print_if_no_key operator()(T&& entry) { - // Pass through. - return value_printer(std::forward(entry)); - } - - // Forward first and second to the proper overload. - template - print_if_has_key operator()(T&& entry) { - return (*this)(std::forward(entry).first, std::forward(entry).second); - } -}; - -/** - * A sequence printer for standard containers. - */ -template -struct SequencePrinter { - EntryPrinter entry_printer; - - template - std::string operator()(absl::string_view name, T&& value) { - std::string result; - absl::StrAppend(&result, name, JoinPolicy::kStart); - auto itr = value.begin(); - if (itr != value.end()) { - absl::StrAppend(&result, entry_printer(*itr)); - while (++itr != value.end()) { - absl::StrAppend(&result, JoinPolicy::kValueDelim, entry_printer(*itr)); - } - } - absl::StrAppend(&result, JoinPolicy::kEnd); - return result; - } -}; - -/** - * A sequence printer for sequences with variable types, known at compile time. - */ -template -struct VarSequencePrinter { - EntryPrinter entry_printer; - - template - std::string operator()(absl::string_view name, Args&&... args) { - std::string result; - absl::StrAppend(&result, name, JoinPolicy::kStart); - PrintArgs(&result, std::forward(args)...); - absl::StrAppend(&result, JoinPolicy::kEnd); - return result; - } - - private: - // No args. - void PrintArgs(std::string*) {} - - // Args for a non-keyed collection. - template - void PrintArgs(print_if_no_key* result, V&& value, - Args&&... args) { - absl::StrAppend(result, entry_printer(std::forward(value))); - if (!args_empty::value) { - absl::StrAppend(result, JoinPolicy::kValueDelim); - PrintArgs(result, std::forward(args)...); - } - } - - // Args for a keyed collection. - template - void PrintArgs(print_if_has_key* result, K&& key, V&& value, - Args&&... args) { - absl::StrAppend( - result, entry_printer(std::forward(key), std::forward(value))); - if (!args_empty::value) { - absl::StrAppend(result, JoinPolicy::kValueDelim); - PrintArgs(result, std::forward(args)...); - } - } -}; - -/** - * A sequence builder for all types of sequences. - */ -template -class SequenceBuilder { - public: - template - void Add(Args&&... args) { - absl::StrAppend(&result_, entry_printer_(std::forward(args)...), - JoinPolicy::kValueDelim); - } - - std::string Build(absl::string_view name = "") { - absl::string_view result(result_); - if (!result.empty()) { - result = - result.substr(0, result.size() - JoinPolicy::kValueDelim.length()); - } - return absl::StrCat(name, JoinPolicy::kStart, result, JoinPolicy::kEnd); - } - - private: - std::string result_; - EntryPrinter entry_printer_; -}; - -/** - * Helper function to print a single value. - */ -template -std::string ToString(T&& value) { - CelPrinter printer; - return printer(std::forward(value)); -} - -/** - * Helper function to print a list value. - */ -template -std::string ToListString(T&& value) { - SequencePrinter printer; - return printer("", std::forward(value)); -} - -// Helper overload to make initializer list literals work. -template -std::string ToListString(std::initializer_list&& value) { - SequencePrinter printer; - return printer("", std::forward>(value)); -} - -/** - * Helper function to print a map value. - */ -template -std::string ToMapString(T&& value) { - SequencePrinter printer; - return printer("", std::forward(value)); -} - -// Helper overload to make initializer list literals work. -template > -std::string ToMapString(std::initializer_list&& value) { - SequencePrinter printer; - return printer("", std::forward>(value)); -} - -/** - * Helper function to print a call to a function. - */ -template -std::string ToCallString(absl::string_view name, Args&&... args) { - VarSequencePrinter printer; - return printer(name, std::forward(args)...); -} - -/** - * Helper function to print object creation. - */ -template -std::string ToObjectString(absl::string_view name, Args&&... args) { - VarSequencePrinter printer; - return printer(name, std::forward(args)...); -} - -} // namespace internal -} // namespace expr -} // namespace api -} // namespace google - -#endif // THIRD_PARTY_CEL_CPP_INTERNAL_CEL_PRINTER_H_ diff --git a/internal/cel_printer_test.cc b/internal/cel_printer_test.cc deleted file mode 100644 index f6743fdd9..000000000 --- a/internal/cel_printer_test.cc +++ /dev/null @@ -1,94 +0,0 @@ -#include "internal/cel_printer.h" - -#include - -#include "gtest/gtest.h" - -namespace google { -namespace api { -namespace expr { -namespace internal { - -TEST(CelPrinterTest, ToString_String) { - EXPECT_EQ("\"\"", ToString("")); - EXPECT_EQ("\"hi\"", ToString(std::string("hi"))); - EXPECT_EQ("\"h\\000i\"", ToString(absl::string_view("h\000i", 3))); -} - -TEST(CelPrinterTest, ToString_Number) { - EXPECT_EQ("1", ToString(1)); - EXPECT_EQ("1u", ToString(1u)); - EXPECT_EQ("1.0", ToString(1.0)); -} - -TEST(CelPrinterTest, ToListString) { - EXPECT_EQ("[]", ToListString({})); - EXPECT_EQ("[1]", ToListString({1})); - EXPECT_EQ("[1u, 2u]", ToListString({1u, 2u})); - EXPECT_EQ("[1.0, 2.0, 3.5]", ToListString({1.0, 2.0, 3.5})); - auto actual = ToListString( - {"one", "2", absl::string_view("h\000i", 3), std::string("4")}); - EXPECT_EQ("[\"one\", \"2\", \"h\\000i\", \"4\"]", actual); - - EXPECT_EQ("[1, 2, 3]", ToListString>({1, 2, 3})); -} - -TEST(CelPrinterTest, EntryPrinter) { - EntryPrinter printer; - EXPECT_EQ("1: 2u", printer(1, 2u)); - EXPECT_EQ("1: 2u", printer(std::make_pair(1, 2u))); -} - -TEST(CelPrinterTest, ToMapString) { - EXPECT_EQ("{}", ToMapString({})); - EXPECT_EQ("{1: 2u}", ToMapString({std::make_pair(1, 2u)})); - EXPECT_EQ("{1: 2u, 3: 4u}", ToMapString({std::make_pair(1, 2u), {3, 4u}})); - - auto actual = ToMapString>({{3, 4u}, {1, 2u}}); - EXPECT_EQ("{1: 2u, 3: 4u}", actual); - actual = ToMapString>({{3, 4u}, {1, 2u}}); - EXPECT_EQ("{3: 4u, 1: 2u}", actual); -} - -TEST(CelPrinterTest, SequenceBuilder) { - SequenceBuilder builder; - builder.Add(1, 2u); - builder.Add(3.0, "four"); - EXPECT_EQ("name{1: 2u, 3.0: \"four\"}", builder.Build("name")); -} - -TEST(CelPrinterTest, ToCallString) { - EXPECT_EQ("()", ToCallString("")); - EXPECT_EQ("name()", ToCallString("name")); - EXPECT_EQ("name(1)", ToCallString("name", 1)); - EXPECT_EQ("name(1, 2u)", ToCallString("name", 1, 2u)); - EXPECT_EQ("name(1, 2u, 3.0)", ToCallString("name", 1, 2u, 3.0)); - EXPECT_EQ("name(1, 2u, 3.0, \"4\")", ToCallString("name", 1, 2u, 3.0, "4")); -} - -TEST(CelPrinterTest, ToObjectString) { - EXPECT_EQ("{}", ToObjectString("")); - EXPECT_EQ("object_type{}", ToObjectString("object_type")); - EXPECT_EQ("object_type{1}", ToObjectString("object_type", "", 1)); - EXPECT_EQ("object_type{1, uint=2u}", - ToObjectString("object_type", "", 1, "uint", 2u)); - EXPECT_EQ("object_type{1, uint=2u, 3.0}", - ToObjectString("object_type", "", 1, "uint", 2u, "", 3.0)); - EXPECT_EQ( - "object_type{1, uint=2u, 3.0, string=\"4\"}", - ToObjectString("object_type", "", 1, "uint", 2u, "", 3.0, "string", "4")); -} - -TEST(CelPrinterTest, NaN) { - EXPECT_EQ("nan", ToString(std::numeric_limits::quiet_NaN())); -} - -TEST(CelPrinterTest, Inf) { - EXPECT_EQ("inf", ToString(std::numeric_limits::infinity())); - EXPECT_EQ("-inf", ToString(-std::numeric_limits::infinity())); -} - -} // namespace internal -} // namespace expr -} // namespace api -} // namespace google diff --git a/internal/empty_descriptors.cc b/internal/empty_descriptors.cc new file mode 100644 index 000000000..05e3843a5 --- /dev/null +++ b/internal/empty_descriptors.cc @@ -0,0 +1,72 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "internal/empty_descriptors.h" + +#include + +#include "google/protobuf/descriptor.pb.h" +#include "absl/base/attributes.h" +#include "absl/base/macros.h" +#include "absl/base/no_destructor.h" +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/log/die_if_null.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/dynamic_message.h" +#include "google/protobuf/message.h" + +namespace cel::internal { + +namespace { + +ABSL_CONST_INIT const uint8_t kEmptyDescriptorSet[] = { +#include "internal/empty_descriptor_set_embed.inc" +}; + +const google::protobuf::DescriptorPool* absl_nonnull GetEmptyDescriptorPool() { + static const google::protobuf::DescriptorPool* absl_nonnull const pool = []() { + google::protobuf::FileDescriptorSet file_desc_set; + ABSL_CHECK(file_desc_set.ParseFromArray( // Crash OK + kEmptyDescriptorSet, ABSL_ARRAYSIZE(kEmptyDescriptorSet))); + auto* pool = new google::protobuf::DescriptorPool(); + for (const auto& file_desc : file_desc_set.file()) { + ABSL_CHECK(pool->BuildFile(file_desc) != nullptr); // Crash OK + } + return pool; + }(); + return pool; +} + +google::protobuf::MessageFactory* absl_nonnull GetEmptyMessageFactory() { + static absl::NoDestructor factory; + return &*factory; +} + +} // namespace + +const google::protobuf::Message* absl_nonnull GetEmptyDefaultInstance() { + static const google::protobuf::Message* absl_nonnull const instance = []() { + return ABSL_DIE_IF_NULL( // Crash OK + ABSL_DIE_IF_NULL( // Crash OK + GetEmptyMessageFactory()->GetPrototype( + ABSL_DIE_IF_NULL( // Crash OK + GetEmptyDescriptorPool()->FindMessageTypeByName( + "google.protobuf.Empty"))))) + ->New(); + }(); + return instance; +} + +} // namespace cel::internal diff --git a/internal/empty_descriptors.h b/internal/empty_descriptors.h new file mode 100644 index 000000000..dfe6f2e3b --- /dev/null +++ b/internal/empty_descriptors.h @@ -0,0 +1,31 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_EMPTY_DESCRIPTORS_H_ +#define THIRD_PARTY_CEL_CPP_INTERNAL_EMPTY_DESCRIPTORS_H_ + +#include "absl/base/nullability.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel::internal { + +// GetEmptyDefaultInstance returns a pointer to a `google::protobuf::Message` which is an +// instance of `google.protobuf.Empty`. The returned `google::protobuf::Message` is valid +// for the lifetime of the process. +const google::protobuf::Message* absl_nonnull GetEmptyDefaultInstance(); + +} // namespace cel::internal + +#endif // THIRD_PARTY_CEL_CPP_INTERNAL_EMPTY_DESCRIPTORS_H_ diff --git a/internal/empty_descriptors_test.cc b/internal/empty_descriptors_test.cc new file mode 100644 index 000000000..c14bd1bc9 --- /dev/null +++ b/internal/empty_descriptors_test.cc @@ -0,0 +1,32 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "internal/empty_descriptors.h" + +#include "internal/testing.h" + +namespace cel::internal { +namespace { + +using ::testing::NotNull; + +TEST(GetEmptyDefaultInstance, Empty) { + const auto* empty = GetEmptyDefaultInstance(); + ASSERT_THAT(empty, NotNull()); + EXPECT_EQ(empty->GetDescriptor()->full_name(), "google.protobuf.Empty"); + EXPECT_EQ(empty, GetEmptyDefaultInstance()); +} + +} // namespace +} // namespace cel::internal diff --git a/internal/equals_text_proto.cc b/internal/equals_text_proto.cc new file mode 100644 index 000000000..c9a6f517d --- /dev/null +++ b/internal/equals_text_proto.cc @@ -0,0 +1,82 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "internal/equals_text_proto.h" + +#include +#include + +#include "absl/log/absl_check.h" +#include "absl/memory/memory.h" +#include "absl/strings/cord.h" +#include "internal/testing.h" +#include "google/protobuf/message.h" +#include "google/protobuf/message_lite.h" +#include "google/protobuf/text_format.h" +#include "google/protobuf/util/message_differencer.h" + +namespace cel::internal { + +void TextProtoMatcher::DescribeTo(std::ostream* os) const { + std::string text; + ABSL_CHECK( // Crash OK + google::protobuf::TextFormat::PrintToString(*message_, &text)); + *os << "is equal to <" << text << ">"; +} + +void TextProtoMatcher::DescribeNegationTo(std::ostream* os) const { + std::string text; + ABSL_CHECK( // Crash OK + google::protobuf::TextFormat::PrintToString(*message_, &text)); + *os << "is not equal to <" << text << ">"; +} + +bool TextProtoMatcher::MatchAndExplain( + const google::protobuf::MessageLite& other, + ::testing::MatchResultListener* listener) const { + if (other.GetTypeName() != message_->GetTypeName()) { + if (listener->IsInterested()) { + *listener << "whose type should be " << message_->GetTypeName() + << " but actually is " << other.GetTypeName(); + } + return false; + } + google::protobuf::util::MessageDifferencer differencer; + std::string diff; + if (listener->IsInterested()) { + differencer.ReportDifferencesToString(&diff); + } + bool match; + if (const auto* other_full_message = + google::protobuf::DynamicCastMessage(&other); + other_full_message != nullptr && + other_full_message->GetDescriptor() == message_->GetDescriptor()) { + match = differencer.Compare(*other_full_message, *message_); + } else { + auto other_message = absl::WrapUnique(message_->New()); + absl::Cord serialized; + ABSL_CHECK(other.SerializeToString(&serialized)); // Crash OK + ABSL_CHECK(other_message->ParseFromString(serialized)); // Crash OK + match = differencer.Compare(*other_message, *message_); + } + if (!match && listener->IsInterested()) { + if (!diff.empty() && diff.back() == '\n') { + diff.erase(diff.end() - 1); + } + *listener << "with the difference:\n" << diff; + } + return match; +} + +} // namespace cel::internal diff --git a/internal/equals_text_proto.h b/internal/equals_text_proto.h new file mode 100644 index 000000000..ac27a6d85 --- /dev/null +++ b/internal/equals_text_proto.h @@ -0,0 +1,65 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_EQUALS_PROTO_H_ +#define THIRD_PARTY_CEL_CPP_INTERNAL_EQUALS_PROTO_H_ + +#include + +#include "absl/base/nullability.h" +#include "absl/strings/string_view.h" +#include "internal/parse_text_proto.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "internal/testing_message_factory.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" +#include "google/protobuf/message_lite.h" + +namespace cel::internal { + +class TextProtoMatcher { + public: + TextProtoMatcher(const google::protobuf::Message* absl_nonnull message, + const google::protobuf::DescriptorPool* absl_nonnull pool, + google::protobuf::MessageFactory* absl_nonnull factory) + : message_(message), pool_(pool), factory_(factory) {} + + void DescribeTo(std::ostream* os) const; + + void DescribeNegationTo(std::ostream* os) const; + + bool MatchAndExplain(const google::protobuf::MessageLite& other, + ::testing::MatchResultListener* listener) const; + + private: + const google::protobuf::Message* absl_nonnull message_; + const google::protobuf::DescriptorPool* absl_nonnull pool_; + google::protobuf::MessageFactory* absl_nonnull factory_; +}; + +template +::testing::PolymorphicMatcher EqualsTextProto( + google::protobuf::Arena* absl_nonnull arena, absl::string_view text, + const google::protobuf::DescriptorPool* absl_nonnull pool = + GetTestingDescriptorPool(), + google::protobuf::MessageFactory* absl_nonnull factory = GetTestingMessageFactory()) { + return ::testing::MakePolymorphicMatcher(TextProtoMatcher( + DynamicParseTextProto(arena, text, pool, factory), pool, factory)); +} + +} // namespace cel::internal + +#endif // THIRD_PARTY_CEL_CPP_INTERNAL_EQUALS_PROTO_H_ diff --git a/internal/exceptions.h b/internal/exceptions.h new file mode 100644 index 000000000..2b53f25c5 --- /dev/null +++ b/internal/exceptions.h @@ -0,0 +1,35 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_EXCEPTIONS_H_ +#define THIRD_PARTY_CEL_CPP_INTERNAL_EXCEPTIONS_H_ + +#include "absl/base/config.h" // IWYU pragma: keep + +#ifdef ABSL_HAVE_EXCEPTIONS +#define CEL_INTERNAL_TRY try +#define CEL_INTERNAL_CATCH_ANY catch (...) +#define CEL_INTERNAL_RETHROW \ + do { \ + throw; \ + } while (false) +#else +#define CEL_INTERNAL_TRY if (true) +#define CEL_INTERNAL_CATCH_ANY else if (false) +#define CEL_INTERNAL_RETHROW \ + do { \ + } while (false) +#endif + +#endif // THIRD_PARTY_CEL_CPP_INTERNAL_EXCEPTIONS_H_ diff --git a/internal/handle.h b/internal/handle.h deleted file mode 100644 index ff4a1d914..000000000 --- a/internal/handle.h +++ /dev/null @@ -1,100 +0,0 @@ -#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_HANDLE_H_ -#define THIRD_PARTY_CEL_CPP_INTERNAL_HANDLE_H_ - -#include "internal/hash_util.h" -#include "internal/specialize.h" - -namespace google { -namespace api { -namespace expr { -namespace internal { - -class HandleBase { - protected: - ~HandleBase() = default; -}; - -/** - * Handle provides a strongly typed wrapper around an integral type suitable - * for use as an ID. - * - * Handle disallows mutation with the exception of assignment from - * another Handle of the same template instantiation type. Handle defines - * comparison operators (==,!=,<,<=,>=,>). For example: - * - * class DatabaseId : public Handle {...}; - * class TableId : public Handle {...}; - * - * DatabaseId(1) == 1; // compile time error - * DatabaseId(1) + 1; // compile time error - * DatabaseId(1) == TableId(1); // compile time error - * DatabaseId(1) == DatabaseId(1); // true - * DatabaseId(1) == DatabaseId(2); // false - * DatabaseId(1) <= DatabaseId(2); // true - */ -template -class Handle : public HandleBase { - public: - struct Hasher { - std::size_t operator()(const DerivedType& handle) const { - return internal::Hash(handle.value()); - } - }; - - constexpr explicit Handle(T value) : value_(value) {} - - constexpr T value() const { return value_; } - - protected: - static_assert(std::is_enum::value || std::is_integral::value || - std::is_pointer::value, - "ValueType must be an integer or pointer type"); - T value_; - - // Can't be destroyed directly. - ~Handle() = default; -}; - -template -constexpr specialize_ift, bool> operator==( - const T& lhs, const T& rhs) { - return lhs.value() == rhs.value(); -} - -template -constexpr specialize_ift, bool> operator!=( - const T& lhs, const T& rhs) { - return lhs.value() != rhs.value(); -} - -// Comparison operator useful for data structures that require ordering among -// elements like set<>, map<>, etc... -template -constexpr specialize_ift, bool> operator<( - const T& lhs, const T& rhs) { - return lhs.value() < rhs.value(); -} - -template -constexpr specialize_ift, bool> operator<=( - const T& lhs, const T& rhs) { - return lhs.value() <= rhs.value(); -} - -template -constexpr specialize_ift, bool> operator>( - const T& lhs, const T& rhs) { - return lhs.value() > rhs.value(); -} -template -constexpr specialize_ift, bool> operator>=( - const T& lhs, const T& rhs) { - return lhs.value() >= rhs.value(); -} - -} // namespace internal -} // namespace expr -} // namespace api -} // namespace google - -#endif // THIRD_PARTY_CEL_CPP_INTERNAL_HANDLE_H_ diff --git a/internal/handle_test.cc b/internal/handle_test.cc deleted file mode 100644 index b0120f16a..000000000 --- a/internal/handle_test.cc +++ /dev/null @@ -1,66 +0,0 @@ -#include "internal/handle.h" - -#include "gtest/gtest.h" -#include "absl/container/node_hash_set.h" - -namespace google { -namespace api { -namespace expr { -namespace internal { -namespace { - -class IntHandle : public Handle { - public: - constexpr explicit IntHandle(int value) : Handle(value) {} -}; - -class OtherIntHandle : public Handle { - public: - constexpr explicit OtherIntHandle(int value) : Handle(value) {} -}; - -// Should be usable in a constexpr. -constexpr IntHandle kOne = IntHandle(1); -constexpr bool kOneVsTwo = kOne == IntHandle(2) && kOne != IntHandle(2) && - kOne < IntHandle(2) && kOne <= IntHandle(2) && - kOne > IntHandle(2) && kOne >= IntHandle(2); - -TEST(Handle, TypeSafty) { - auto convertible = std::is_convertible::value; - EXPECT_FALSE(convertible); - convertible = std::is_convertible::value; - EXPECT_FALSE(convertible); -} - -TEST(Handle, Operators) { - EXPECT_TRUE(IntHandle(1) == IntHandle(1)); - EXPECT_TRUE(IntHandle(1) != IntHandle(2)); - EXPECT_TRUE(IntHandle(1) < IntHandle(2)); - EXPECT_TRUE(IntHandle(1) <= IntHandle(2)); - EXPECT_TRUE(IntHandle(2) > IntHandle(1)); - EXPECT_TRUE(IntHandle(2) >= IntHandle(1)); -} - -TEST(Handle, Hash) { - absl::node_hash_set handles; - handles.emplace(1); - handles.insert(IntHandle(1)); - handles.emplace(2); - - EXPECT_EQ(handles.size(), 2); -} - -TEST(Handle, Order) { - std::set handles; - handles.emplace(1); - handles.insert(IntHandle(1)); - handles.emplace(2); - - EXPECT_EQ(handles.size(), 2); -} - -} // namespace -} // namespace internal -} // namespace expr -} // namespace api -} // namespace google diff --git a/internal/hash_util.cc b/internal/hash_util.cc deleted file mode 100644 index c44bd347e..000000000 --- a/internal/hash_util.cc +++ /dev/null @@ -1,41 +0,0 @@ -#include "internal/hash_util.h" - -#include - -namespace google { -namespace api { -namespace expr { -namespace internal { - -std::size_t HashImpl(const std::string& value, specialize) { - return StdHash(value); -} - -std::size_t HashImpl(const google::rpc::Status& value, specialize) { - std::size_t hash = Hash(value.code()); - AccumulateHash(value.message(), &hash); - return hash; -} - -std::size_t HashImpl(absl::string_view value, specialize) { - return StdHash(std::string(value)); -} - -std::size_t HashImpl(absl::Duration value, specialize) { - return StdHash(absl::ToInt64Nanoseconds(value)); -} - -std::size_t HashImpl(absl::Time value, specialize) { - return StdHash(absl::ToUnixNanos(value)); -} - -std::size_t HashImpl(std::nullptr_t, specialize) { return 0; } - -std::size_t HashImpl(const google::protobuf::Any& value, specialize) { - return Hash(value.type_url(), value.value()); -} - -} // namespace internal -} // namespace expr -} // namespace api -} // namespace google diff --git a/internal/hash_util.h b/internal/hash_util.h deleted file mode 100644 index 655c75c1c..000000000 --- a/internal/hash_util.h +++ /dev/null @@ -1,163 +0,0 @@ -#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_HASH_UTIL_H_ -#define THIRD_PARTY_CEL_CPP_INTERNAL_HASH_UTIL_H_ - -#include -#include - -#include "google/protobuf/any.pb.h" -#include "google/rpc/status.pb.h" -#include "absl/strings/string_view.h" -#include "absl/time/time.h" -#include "internal/port.h" -#include "internal/specialize.h" - -namespace google { -namespace api { -namespace expr { -namespace internal { - -/** Offset to use so that Hash() is not always 0. */ -constexpr const std::size_t kIntegralTypeOffset = 109; - -/** Hash value reserved to represent 'not computed'. */ -constexpr const std::size_t kNoHash = std::numeric_limits::max(); - -// Lightly hash two hash codes together. When used repetitively to mix more -// than two values, the new values should be in the first argument. -ABSL_MUST_USE_RESULT inline size_t MixHash(size_t new_hash, size_t accu) { - static const size_t kMul = static_cast(0xc6a4a7935bd1e995ULL); - // Multiplicative hashing will mix bits better in the msb end ... - accu *= kMul; - // ... and rotating will move the better mixed msb-bits to lsb-bits. - return ((accu << 21) | (accu >> (std::numeric_limits::digits - 21))) + - new_hash; -} - -// Lightly hash two hash codes together, such that the result is order -// independent. -ABSL_MUST_USE_RESULT inline std::size_t MixHashNoOrder(size_t new_hash, - size_t accu) { - return new_hash + accu; -} - -template -ABSL_MUST_USE_RESULT std::size_t Hash(T&& value); - -template -ABSL_MUST_USE_RESULT std::size_t Hash(T&& head, Args&&... rest); - -template -void AccumulateHash(T&& new_value, std::size_t* accu) { - *accu = MixHash(Hash(std::forward(new_value)), *accu); -} - -template -void AccumulateHashNoOrder(T&& new_value, std::size_t* accu) { - *accu = MixHashNoOrder(Hash(std::forward(new_value)), *accu); -} - -// Helper to call std::hash function. -template -std::size_t StdHash(const T& value) { - return std::hash{}(value); -} - -template -std::size_t HashImpl(const T& value, general) { - return StdHash(value); -} - -std::size_t HashImpl(const std::string& value, specialize); -std::size_t HashImpl(absl::string_view value, specialize); -std::size_t HashImpl(absl::Duration value, specialize); -std::size_t HashImpl(absl::Time value, specialize); -std::size_t HashImpl(std::nullptr_t, specialize); -std::size_t HashImpl(const google::rpc::Status& value, specialize); -std::size_t HashImpl(const google::protobuf::Any& value, specialize); - -// Hack for supporting 'cords'. -template -std::size_t HashImpl(const T& cord, specialize_for) { - return StdHash(cord.ToString()); -} - -// Specialization for any type that defines a `hash_code` function. -template -std::size_t HashImpl(const T& value, specialize_for) { - return value.hash_code(); -} - -// Specialization for classes that define a zero arg constructable hasher. -template -std::size_t HashImpl(const T& value, specialize_for) { - return typename T::Hasher()(value); -} - -// Specialization for enums. -template -std::size_t HashImpl(T&& v, specialize_ift>) { - return Hash(static_cast::type>(v)); -} - -template -ABSL_MUST_USE_RESULT std::size_t Hash(T&& value) { - return HashImpl(std::forward(value), specialize()); -} - -template -ABSL_MUST_USE_RESULT std::size_t Hash(T&& head, Args&&... rest) { - return MixHash(Hash(std::forward(head)), - Hash(std::forward(rest)...)); -} - -template -ABSL_MUST_USE_RESULT std::size_t HashNoOrder(T&& value) { - return Hash(std::forward(value)); -} - -template -ABSL_MUST_USE_RESULT std::size_t HashNoOrder(T&& head, Args&&... rest) { - return MixHashNoOrder(Hash(std::forward(head)), - Hash(std::forward(rest)...)); -} - -template -ABSL_MUST_USE_RESULT std::size_t HashPair(F&& first, S&& second) { - size_t h1 = Hash(first); - size_t h2 = Hash(second); - if (std::is_integral::value) { - // We want to avoid absl::Hash({x, y}) == 0 for common values of {x, y}. - // hash is the identity function for integral types X, so without this, - // absl::Hash({0, 0}) would be 0. - h1 += kIntegralTypeOffset; - } - return MixHash(h1, h2); -} - -/** A visitor for use with CelValue's absl::variant. */ -struct Hasher { - template - std::size_t operator()(const T& value) const { - return Hash(value); - } -}; - -template -std::size_t LazyComputeHash(T&& hash_fn, std::atomic* hash_code) { - std::size_t code = hash_code->load(std::memory_order_relaxed); - if (code == internal::kNoHash) { - code = hash_fn(); - if (code == internal::kNoHash) { - --code; - } - hash_code->store(code, std::memory_order_relaxed); - } - return code; -} - -} // namespace internal -} // namespace expr -} // namespace api -} // namespace google - -#endif // THIRD_PARTY_CEL_CPP_INTERNAL_HASH_UTIL_H_ diff --git a/internal/holder.h b/internal/holder.h deleted file mode 100644 index 76833aef1..000000000 --- a/internal/holder.h +++ /dev/null @@ -1,239 +0,0 @@ -#ifndef THIRD_PARTY_CEL_CPP_COMMON_INTERNAL_HOLDERS_H_ -#define THIRD_PARTY_CEL_CPP_COMMON_INTERNAL_HOLDERS_H_ - -#include -#include - -#include "internal/port.h" -#include "internal/specialize.h" -#include "internal/types.h" -#include "internal/visitor_util.h" - -namespace google { -namespace api { -namespace expr { -namespace internal { - -/** - * A class that holds a value using the given holder policy. - * - * This class is largely used to smooth over the awkward syntax for - * referencing dependent template type names (see the type of value_). - * - * @tparam T the type of value to hold. - * @tparam HolderPolicy the policy by which the value should be held. - */ -template -class Holder { - using ValueType = typename HolderPolicy::template ValueType; - - public: - Holder(Holder&) = default; - Holder(const Holder&) = default; - Holder(Holder&&) = default; - Holder& operator=(const Holder&) = default; - Holder& operator=(Holder&&) = default; - - template - explicit Holder(Args&&... args) - : value_(HolderPolicy::template Create( - std::forward(args)...)) {} - - T& value() { return HolderPolicy::template get(value_); } - const T& value() const { return HolderPolicy::template get(value_); } - - T& operator*() { return value(); } - const T& operator*() const { return value(); } - - T* operator->() { return &value(); } - const T* operator->() const { return &value(); } - - private: - ValueType value_; -}; - -struct BaseHolderPolicy { - /** - * In-place create by default. - * @tparam V the hold's value type. - * @tparam T the type being held. - */ - template - static V Create(Args&&... args) { - return V(std::forward(args)...); - } -}; - -/** A holder policy that keeps a copy of the value. */ -struct Copy : BaseHolderPolicy { - constexpr static const bool kOwnsValue = true; - - template - using ValueType = remove_const_t; - - template - static T& get(T& value) { - return value; - } - - template - static const T& get(const T& value) { - return value; - } -}; - -template -using CopyHolder = Holder; - -/** A holder policy that keeps a unique_ptr of the given type. */ -struct OwnedPtr : BaseHolderPolicy { - constexpr static const bool kOwnsValue = true; - - template - using ValueType = std::unique_ptr; - - template - static T& get(std::unique_ptr& value) { - assert(value != nullptr); - return *value; - } - - template - static const T& get(const std::unique_ptr& value) { - assert(value != nullptr); - return *value; - } -}; - -template -using OwnedPtrHolder = Holder; - -/** A holder policy that keeps a raw pointer of the given type. */ -struct UnownedPtr : BaseHolderPolicy { - constexpr static const bool kOwnsValue = false; - - template - using ValueType = T*; - - template - static T& get(T* value) { - assert(value != nullptr); - return *value; - } - - template - static const T& get(const T* value) { - assert(value != nullptr); - return *value; - } -}; - -template -using UnownedPtrHolder = Holder; - -/** A holder policy the keeps a reference on a parent container. */ -template -struct ParentOwned : BaseHolderPolicy { - // It owns the value since it owns a ref on the parent. - constexpr static const bool kOwnsValue = true; - - template - using InternalValueType = typename HolderPolicy::template ValueType; - - template - static V Create(const P& parent, Args&&... args) { - return V(parent, HolderPolicy::template Create, T>( - std::forward(args)...)); - } - - /** - * The holder stores a pair of a reference to the parent and the real value. - */ - template - using ValueType = std::pair>; - - /** Returns the real value from the pair. */ - template - static T& get(std::pair>& value) { - return HolderPolicy::get(value.second); - } - - /** Returns the real const value from the pair. */ - template - static const T& get(const std::pair>& value) { - return HolderPolicy::get(value.second); - } -}; - -template -using ParentOwnedPtr = ParentOwned; - -template -using ParentOwnedCopy = ParentOwned; - -/** An adapter that returns a held value regardless of the holder policy. */ -struct HolderAdapter { - template - const T& operator()(const Holder& value) { - return value.value(); - } - - template - T& operator()(Holder& value) { - return value.value(); - } - - template - T operator()(Holder&& value) { - return std::move(value.value()); - } -}; - -/** - * An adapter that returns a pointer to the held value regardless of the holder - * policy. - */ -struct HolderPtrAdapter { - template - const T* operator()(const Holder& value) { - return &value.value(); - } - - template - T* operator()(Holder& value) { - return &value.value(); - } -}; - -/** A wrapper that only passes through holders of a specific set of types. */ -template -struct StrictHolderAdapter : Adapter { - template - StrictHolderAdapter(Args&&... args) : Adapter(std::forward(args)...) {} - - template - specialize_ift, - VisitResultType&>> - operator()(const Holder& value) { - return Adapter::operator()(value); - } - - template - specialize_ift, VisitResultType&>> - operator()(Holder& value) { - return Adapter::operator()(value); - } - - template - specialize_ift, VisitResultType&&>> - operator()(Holder&& value) { - return Adapter::operator()(std::move(value)); - } -}; - -} // namespace internal -} // namespace expr -} // namespace api -} // namespace google - -#endif // THIRD_PARTY_CEL_CPP_COMMON_INTERNAL_HOLDERS_H_ diff --git a/internal/holder_test.cc b/internal/holder_test.cc deleted file mode 100644 index ac06d043c..000000000 --- a/internal/holder_test.cc +++ /dev/null @@ -1,233 +0,0 @@ -#include "internal/holder.h" - -#include "gtest/gtest.h" -#include "absl/strings/str_cat.h" -#include "internal/types.h" -#include "testutil/util.h" - -namespace google { -namespace api { -namespace expr { -namespace internal { -namespace { - -TEST(Holder, Access) { - Holder holder; - - // Can be accessed like a smart pointer (similar to std::optional) - EXPECT_EQ("", *holder); - EXPECT_TRUE(holder->empty()); - *holder = "hi"; - EXPECT_EQ("hi", holder.value()); - EXPECT_FALSE(holder->empty()); - holder.value() = "bye"; - EXPECT_EQ("bye", holder.value()); - EXPECT_FALSE(holder->empty()); -} - -TEST(Holder, Copy) { - // Copy supports all modes - using HolderType = Holder; - EXPECT_TRUE(std::is_copy_constructible::value); - EXPECT_TRUE(std::is_copy_assignable::value); - EXPECT_TRUE(std::is_move_constructible::value); - EXPECT_TRUE(std::is_move_assignable::value); - - // Value can be mutated. - HolderType holder(1); - testutil::ExpectSameType(); - EXPECT_EQ(1, holder.value()); - holder.value() = 2; - EXPECT_EQ(2, holder.value()); - - // Const holder cannot be assigned or have its value changed. - EXPECT_TRUE(std::is_copy_constructible::value); - EXPECT_FALSE(std::is_copy_assignable::value); - EXPECT_TRUE(std::is_move_constructible::value); - EXPECT_FALSE(std::is_move_assignable::value); - const HolderType const_holder(2); - testutil::ExpectSameType(); - EXPECT_EQ(2, const_holder.value()); -} - -TEST(Holder, Copy_const) { - // Const Copy supports all modes. - using HolderType = Holder; - EXPECT_TRUE(std::is_copy_constructible::value); - EXPECT_TRUE(std::is_copy_assignable::value); - EXPECT_TRUE(std::is_move_constructible::value); - EXPECT_TRUE(std::is_move_assignable::value); - - // Value cannot be changed. - HolderType holder(1); - testutil::ExpectSameType(); - EXPECT_EQ(1, holder.value()); - - // Const holder has the same properties. - EXPECT_TRUE(std::is_copy_constructible::value); - EXPECT_FALSE(std::is_copy_assignable::value); - EXPECT_TRUE(std::is_move_constructible::value); - EXPECT_FALSE(std::is_move_assignable::value); - const HolderType const_holder(2); - testutil::ExpectSameType(); - EXPECT_EQ(2, const_holder.value()); -} - -TEST(Holder, OwnedPtr) { - // OwnedPtr can only be moved. - using HolderType = Holder; - EXPECT_FALSE(std::is_copy_constructible::value); - EXPECT_FALSE(std::is_copy_assignable::value); - EXPECT_TRUE(std::is_move_constructible::value); - EXPECT_TRUE(std::is_move_assignable::value); - - // Null cannot be accessed. - HolderType holder; - testutil::ExpectSameType(); -#ifndef NDEBUG // Assert only throws when debugging. - EXPECT_DEATH(holder.value(), "null"); - holder = HolderType(nullptr); - EXPECT_DEATH(holder.value(), "null"); -#endif - - // Value can be mutated. - holder = HolderType(absl::make_unique(1)); - EXPECT_EQ(1, holder.value()); - holder.value() = 2; - EXPECT_EQ(2, holder.value()); - - // Const holder is not assignable, and cannot have its value changed. - EXPECT_FALSE(std::is_copy_constructible::value); - EXPECT_FALSE(std::is_copy_assignable::value); - EXPECT_TRUE(std::is_move_constructible::value); - EXPECT_FALSE(std::is_move_assignable::value); - const HolderType const_holder(absl::make_unique(2)); - testutil::ExpectSameType(); - EXPECT_EQ(2, const_holder.value()); -} - -TEST(Holder, OwnedPtr_const) { - // OwnedPtr of a const value cannot be copied, but can be assigned. - using HolderType = Holder; - EXPECT_FALSE(std::is_copy_constructible::value); - EXPECT_FALSE(std::is_copy_assignable::value); - EXPECT_TRUE(std::is_move_constructible::value); - EXPECT_TRUE(std::is_move_assignable::value); - - // Null cannot be accessed. - HolderType holder; -#ifndef NDEBUG // Assert only throws when debugging. - EXPECT_DEATH(holder.value(), "null"); - holder = HolderType(nullptr); - EXPECT_DEATH(holder.value(), "null"); -#endif - - // Value cannot be changed. - testutil::ExpectSameType(); - // Holder can be assigned. - holder = HolderType(absl::make_unique(1)); - EXPECT_EQ(1, holder.value()); - holder = HolderType(absl::make_unique(2)); - EXPECT_EQ(2, holder.value()); - - // Const version const only be moved. - EXPECT_FALSE(std::is_copy_constructible::value); - EXPECT_FALSE(std::is_copy_assignable::value); - EXPECT_TRUE(std::is_move_constructible::value); - EXPECT_FALSE(std::is_move_assignable::value); - const HolderType const_holder(absl::make_unique(3)); - testutil::ExpectSameType(); - EXPECT_EQ(3, const_holder.value()); -} - -TEST(Holder, UnownedPtr) { - // UnownedPtr supports all modes. - using HolderType = Holder; - EXPECT_TRUE(std::is_copy_constructible::value); - EXPECT_TRUE(std::is_copy_assignable::value); - EXPECT_TRUE(std::is_move_constructible::value); - EXPECT_TRUE(std::is_move_assignable::value); - - // Null cannot be accessed. - HolderType holder(static_cast(0)); -#ifndef NDEBUG // Assert only throws when debugging. - EXPECT_DEATH(holder.value(), "null"); - holder = HolderType(nullptr); - EXPECT_DEATH(holder.value(), "null"); -#endif - - // Value can be mutated - testutil::ExpectSameType(); - int i = 1; - holder = HolderType(&i); - EXPECT_EQ(1, holder.value()); - holder.value() = 2; - EXPECT_EQ(2, holder.value()); - EXPECT_EQ(2, i); - - // Const holder cannot be assigned, and value cannot be changed. - EXPECT_TRUE(std::is_copy_constructible::value); - EXPECT_FALSE(std::is_copy_assignable::value); - EXPECT_TRUE(std::is_move_constructible::value); - EXPECT_FALSE(std::is_move_assignable::value); - const HolderType const_holder(&i); - testutil::ExpectSameType(); - EXPECT_EQ(2, const_holder.value()); - EXPECT_EQ(&holder.value(), &const_holder.value()); -} - -TEST(Holder, UnownedPtr_const) { - // UnownedPtr to a const value supports all modes. - using HolderType = Holder; - EXPECT_TRUE(std::is_copy_constructible::value); - EXPECT_TRUE(std::is_copy_assignable::value); - EXPECT_TRUE(std::is_move_constructible::value); - EXPECT_TRUE(std::is_move_assignable::value); - - // Null cannot be accessed. - HolderType holder(static_cast(0)); -#ifndef NDEBUG // Assert only throws when debugging. - EXPECT_DEATH(holder.value(), "null"); - holder = HolderType(nullptr); - EXPECT_DEATH(holder.value(), "null"); -#endif - - // Value cannot be changed, but holder can be assigned. - testutil::ExpectSameType(); - int i = 1; - holder = HolderType(&i); - EXPECT_EQ(1, holder.value()); - - // Const holder cannot be assigned, and value cannot be changed. - EXPECT_TRUE(std::is_copy_constructible::value); - EXPECT_FALSE(std::is_copy_assignable::value); - EXPECT_TRUE(std::is_move_constructible::value); - EXPECT_FALSE(std::is_move_assignable::value); - const HolderType const_holder(&i); - testutil::ExpectSameType(); - EXPECT_EQ(1, const_holder.value()); - EXPECT_EQ(&holder.value(), &const_holder.value()); -} - -// Crazy policy that inits everything with the string "cat#", where # is the -// number of args. -struct CatHolderPolicy : Copy { - template - static V Create(Args&&... args) { - return absl::StrCat("cat", args_size::value); - } -}; - -TEST(Holder, Create) { - using HolderType = Holder; - HolderType holder; - EXPECT_EQ("cat0", *holder); - holder = HolderType(1, "foo"); - EXPECT_EQ("cat2", *holder); -} - -} // namespace -} // namespace internal -} // namespace expr -} // namespace api -} // namespace google diff --git a/internal/json.cc b/internal/json.cc new file mode 100644 index 000000000..cdd4c1a5d --- /dev/null +++ b/internal/json.cc @@ -0,0 +1,2041 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "internal/json.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "google/protobuf/duration.pb.h" +#include "google/protobuf/struct.pb.h" +#include "google/protobuf/timestamp.pb.h" +#include "absl/base/attributes.h" +#include "absl/base/no_destructor.h" +#include "absl/base/nullability.h" +#include "absl/base/optimization.h" +#include "absl/functional/overload.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/ascii.h" +#include "absl/strings/cord.h" +#include "absl/strings/escaping.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/types/variant.h" +#include "extensions/protobuf/internal/map_reflection.h" +#include "internal/status_macros.h" +#include "internal/strings.h" +#include "internal/well_known_types.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/map_field.h" +#include "google/protobuf/message.h" +#include "google/protobuf/message_lite.h" +#include "google/protobuf/util/time_util.h" + +#undef GetMessage + +namespace cel::internal { + +namespace { + +using ::cel::well_known_types::AsVariant; +using ::cel::well_known_types::GetListValueReflection; +using ::cel::well_known_types::GetRepeatedBytesField; +using ::cel::well_known_types::GetRepeatedStringField; +using ::cel::well_known_types::GetStructReflection; +using ::cel::well_known_types::GetValueReflection; +using ::cel::well_known_types::JsonReflection; +using ::cel::well_known_types::ListValueReflection; +using ::cel::well_known_types::Reflection; +using ::cel::well_known_types::StructReflection; +using ::cel::well_known_types::ValueReflection; +using ::google::protobuf::Descriptor; +using ::google::protobuf::FieldDescriptor; +using ::google::protobuf::util::TimeUtil; + +// Yanked from the implementation `google::protobuf::util::TimeUtil`. +template +absl::Status SnakeCaseToCamelCaseImpl(Chars input, + std::string* absl_nonnull output) { + output->clear(); + bool after_underscore = false; + for (char input_char : input) { + if (absl::ascii_isupper(input_char)) { + // The field name must not contain uppercase letters. + return absl::InvalidArgumentError( + "field mask path name contains uppercase letters"); + } + if (after_underscore) { + if (absl::ascii_islower(input_char)) { + output->push_back(absl::ascii_toupper(input_char)); + after_underscore = false; + } else { + // The character after a "_" must be a lowercase letter. + return absl::InvalidArgumentError( + "field mask path contains '_' not followed by a lowercase letter"); + } + } else if (input_char == '_') { + after_underscore = true; + } else { + output->push_back(input_char); + } + } + if (after_underscore) { + // Trailing "_". + return absl::InvalidArgumentError("field mask path contains trailing '_'"); + } + return absl::OkStatus(); +} + +absl::Status SnakeCaseToCamelCase(const well_known_types::StringValue& input, + std::string* absl_nonnull output) { + return absl::visit(absl::Overload( + [&](absl::string_view string) -> absl::Status { + return SnakeCaseToCamelCaseImpl(string, output); + }, + [&](const absl::Cord& cord) -> absl::Status { + return SnakeCaseToCamelCaseImpl(cord.Chars(), + output); + }), + AsVariant(input)); +} + +class MessageToJsonState; + +using MapFieldKeyToString = std::string (*)(const google::protobuf::MapKey&); + +std::string BoolMapFieldKeyToString(const google::protobuf::MapKey& key) { + return key.GetBoolValue() ? "true" : "false"; +} + +std::string Int32MapFieldKeyToString(const google::protobuf::MapKey& key) { + return absl::StrCat(key.GetInt32Value()); +} + +std::string Int64MapFieldKeyToString(const google::protobuf::MapKey& key) { + return absl::StrCat(key.GetInt64Value()); +} + +std::string UInt32MapFieldKeyToString(const google::protobuf::MapKey& key) { + return absl::StrCat(key.GetUInt32Value()); +} + +std::string UInt64MapFieldKeyToString(const google::protobuf::MapKey& key) { + return absl::StrCat(key.GetUInt64Value()); +} + +std::string StringMapFieldKeyToString(const google::protobuf::MapKey& key) { + return std::string(key.GetStringValue()); +} + +MapFieldKeyToString GetMapFieldKeyToString( + const google::protobuf::FieldDescriptor* absl_nonnull field) { + switch (field->cpp_type()) { + case FieldDescriptor::CPPTYPE_BOOL: + return &BoolMapFieldKeyToString; + case FieldDescriptor::CPPTYPE_INT32: + return &Int32MapFieldKeyToString; + case FieldDescriptor::CPPTYPE_INT64: + return &Int64MapFieldKeyToString; + case FieldDescriptor::CPPTYPE_UINT32: + return &UInt32MapFieldKeyToString; + case FieldDescriptor::CPPTYPE_UINT64: + return &UInt64MapFieldKeyToString; + case FieldDescriptor::CPPTYPE_STRING: + return &StringMapFieldKeyToString; + default: + ABSL_UNREACHABLE(); + } +} + +using MapFieldValueToValue = absl::Status (MessageToJsonState::*)( + const google::protobuf::MapValueConstRef& value, + const google::protobuf::FieldDescriptor* absl_nonnull field, + google::protobuf::MessageLite* absl_nonnull result); + +using RepeatedFieldToValue = absl::Status (MessageToJsonState::*)( + const google::protobuf::Reflection* absl_nonnull reflection, + const google::protobuf::Message& message, + const google::protobuf::FieldDescriptor* absl_nonnull field, int index, + google::protobuf::MessageLite* absl_nonnull result); + +class MessageToJsonState { + public: + MessageToJsonState(const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory) + : descriptor_pool_(descriptor_pool), message_factory_(message_factory) {} + + virtual ~MessageToJsonState() = default; + + absl::Status ToJson(const google::protobuf::Message& message, + google::protobuf::MessageLite* absl_nonnull result) { + const auto* descriptor = message.GetDescriptor(); + switch (descriptor->well_known_type()) { + case Descriptor::WELLKNOWNTYPE_DOUBLEVALUE: { + CEL_RETURN_IF_ERROR(reflection_.DoubleValue().Initialize(descriptor)); + SetNumberValue(result, reflection_.DoubleValue().GetValue(message)); + } break; + case Descriptor::WELLKNOWNTYPE_FLOATVALUE: { + CEL_RETURN_IF_ERROR(reflection_.FloatValue().Initialize(descriptor)); + SetNumberValue(result, reflection_.FloatValue().GetValue(message)); + } break; + case Descriptor::WELLKNOWNTYPE_INT64VALUE: { + CEL_RETURN_IF_ERROR(reflection_.Int64Value().Initialize(descriptor)); + SetNumberValue(result, reflection_.Int64Value().GetValue(message)); + } break; + case Descriptor::WELLKNOWNTYPE_UINT64VALUE: { + CEL_RETURN_IF_ERROR(reflection_.UInt64Value().Initialize(descriptor)); + SetNumberValue(result, reflection_.UInt64Value().GetValue(message)); + } break; + case Descriptor::WELLKNOWNTYPE_INT32VALUE: { + CEL_RETURN_IF_ERROR(reflection_.Int32Value().Initialize(descriptor)); + SetNumberValue(result, reflection_.Int32Value().GetValue(message)); + } break; + case Descriptor::WELLKNOWNTYPE_UINT32VALUE: { + CEL_RETURN_IF_ERROR(reflection_.UInt32Value().Initialize(descriptor)); + SetNumberValue(result, reflection_.UInt32Value().GetValue(message)); + } break; + case Descriptor::WELLKNOWNTYPE_STRINGVALUE: { + CEL_RETURN_IF_ERROR(reflection_.StringValue().Initialize(descriptor)); + StringValueToJson(reflection_.StringValue().GetValue(message, scratch_), + result); + } break; + case Descriptor::WELLKNOWNTYPE_BYTESVALUE: { + CEL_RETURN_IF_ERROR(reflection_.BytesValue().Initialize(descriptor)); + BytesValueToJson(reflection_.BytesValue().GetValue(message, scratch_), + result); + } break; + case Descriptor::WELLKNOWNTYPE_BOOLVALUE: { + CEL_RETURN_IF_ERROR(reflection_.BoolValue().Initialize(descriptor)); + SetBoolValue(result, reflection_.BoolValue().GetValue(message)); + } break; + case Descriptor::WELLKNOWNTYPE_ANY: { + CEL_ASSIGN_OR_RETURN(auto unpacked, + well_known_types::UnpackAnyFrom( + result->GetArena(), reflection_.Any(), message, + descriptor_pool_, message_factory_)); + auto* struct_result = MutableStructValue(result); + const auto* unpacked_descriptor = unpacked->GetDescriptor(); + SetStringValue(InsertField(struct_result, "@type"), + absl::StrCat("type.googleapis.com/", + unpacked_descriptor->full_name())); + switch (unpacked_descriptor->well_known_type()) { + case Descriptor::WELLKNOWNTYPE_DOUBLEVALUE: + ABSL_FALLTHROUGH_INTENDED; + case Descriptor::WELLKNOWNTYPE_FLOATVALUE: + ABSL_FALLTHROUGH_INTENDED; + case Descriptor::WELLKNOWNTYPE_INT64VALUE: + ABSL_FALLTHROUGH_INTENDED; + case Descriptor::WELLKNOWNTYPE_UINT64VALUE: + ABSL_FALLTHROUGH_INTENDED; + case Descriptor::WELLKNOWNTYPE_INT32VALUE: + ABSL_FALLTHROUGH_INTENDED; + case Descriptor::WELLKNOWNTYPE_UINT32VALUE: + ABSL_FALLTHROUGH_INTENDED; + case Descriptor::WELLKNOWNTYPE_STRINGVALUE: + ABSL_FALLTHROUGH_INTENDED; + case Descriptor::WELLKNOWNTYPE_BYTESVALUE: + ABSL_FALLTHROUGH_INTENDED; + case Descriptor::WELLKNOWNTYPE_BOOLVALUE: + ABSL_FALLTHROUGH_INTENDED; + case Descriptor::WELLKNOWNTYPE_FIELDMASK: + ABSL_FALLTHROUGH_INTENDED; + case Descriptor::WELLKNOWNTYPE_DURATION: + ABSL_FALLTHROUGH_INTENDED; + case Descriptor::WELLKNOWNTYPE_TIMESTAMP: + ABSL_FALLTHROUGH_INTENDED; + case Descriptor::WELLKNOWNTYPE_VALUE: + ABSL_FALLTHROUGH_INTENDED; + case Descriptor::WELLKNOWNTYPE_LISTVALUE: + ABSL_FALLTHROUGH_INTENDED; + case Descriptor::WELLKNOWNTYPE_STRUCT: + return ToJson(*unpacked, InsertField(struct_result, "value")); + default: + if (unpacked_descriptor->full_name() == "google.protobuf.Empty") { + MutableStructValue(InsertField(struct_result, "value")); + return absl::OkStatus(); + } else { + return MessageToJson(*unpacked, struct_result); + } + } + } + case Descriptor::WELLKNOWNTYPE_FIELDMASK: { + CEL_RETURN_IF_ERROR(reflection_.FieldMask().Initialize(descriptor)); + std::vector paths; + const int paths_size = reflection_.FieldMask().PathsSize(message); + for (int i = 0; i < paths_size; ++i) { + CEL_RETURN_IF_ERROR(SnakeCaseToCamelCase( + reflection_.FieldMask().Paths(message, i, scratch_), + &paths.emplace_back())); + } + SetStringValue(result, absl::StrJoin(paths, ",")); + } break; + case Descriptor::WELLKNOWNTYPE_DURATION: { + CEL_RETURN_IF_ERROR(reflection_.Duration().Initialize(descriptor)); + google::protobuf::Duration duration; + duration.set_seconds(reflection_.Duration().GetSeconds(message)); + duration.set_nanos(reflection_.Duration().GetNanos(message)); + SetStringValue(result, TimeUtil::ToString(duration)); + } break; + case Descriptor::WELLKNOWNTYPE_TIMESTAMP: { + CEL_RETURN_IF_ERROR(reflection_.Timestamp().Initialize(descriptor)); + google::protobuf::Timestamp timestamp; + timestamp.set_seconds(reflection_.Timestamp().GetSeconds(message)); + timestamp.set_nanos(reflection_.Timestamp().GetNanos(message)); + SetStringValue(result, TimeUtil::ToString(timestamp)); + } break; + case Descriptor::WELLKNOWNTYPE_VALUE: { + absl::Cord serialized; + if (!message.SerializePartialToString(&serialized)) { + return absl::UnknownError( + "failed to serialize message google.protobuf.Value"); + } + if (!result->ParsePartialFromString(serialized)) { + return absl::UnknownError( + "failed to parsed message: google.protobuf.Value"); + } + } break; + case Descriptor::WELLKNOWNTYPE_LISTVALUE: { + absl::Cord serialized; + if (!message.SerializePartialToString(&serialized)) { + return absl::UnknownError( + "failed to serialize message google.protobuf.ListValue"); + } + if (!MutableListValue(result)->ParsePartialFromString(serialized)) { + return absl::UnknownError( + "failed to parsed message: google.protobuf.ListValue"); + } + } break; + case Descriptor::WELLKNOWNTYPE_STRUCT: { + absl::Cord serialized; + if (!message.SerializePartialToString(&serialized)) { + return absl::UnknownError( + "failed to serialize message google.protobuf.Struct"); + } + if (!MutableStructValue(result)->ParsePartialFromString(serialized)) { + return absl::UnknownError( + "failed to parsed message: google.protobuf.Struct"); + } + } break; + default: + return MessageToJson(message, MutableStructValue(result)); + } + return absl::OkStatus(); + } + + absl::Status ToJsonObject(const google::protobuf::Message& message, + google::protobuf::MessageLite* absl_nonnull result) { + return MessageToJson(message, result); + } + + absl::Status FieldToJson(const google::protobuf::Message& message, + const google::protobuf::FieldDescriptor* absl_nonnull field, + google::protobuf::MessageLite* absl_nonnull result) { + return MessageFieldToJson(message, field, result); + } + + absl::Status FieldToJsonArray( + const google::protobuf::Message& message, + const google::protobuf::FieldDescriptor* absl_nonnull field, + google::protobuf::MessageLite* absl_nonnull result) { + return MessageRepeatedFieldToJson(message, field, result); + } + + absl::Status FieldToJsonObject( + const google::protobuf::Message& message, + const google::protobuf::FieldDescriptor* absl_nonnull field, + google::protobuf::MessageLite* absl_nonnull result) { + return MessageMapFieldToJson(message, field, result); + } + + virtual absl::Status Initialize( + google::protobuf::MessageLite* absl_nonnull message) = 0; + + private: + absl::StatusOr GetMapFieldValueToValue( + const google::protobuf::FieldDescriptor* absl_nonnull field) { + switch (field->type()) { + case FieldDescriptor::TYPE_DOUBLE: + return &MessageToJsonState::MapDoubleFieldToValue; + case FieldDescriptor::TYPE_FLOAT: + return &MessageToJsonState::MapFloatFieldToValue; + case FieldDescriptor::TYPE_FIXED64: + ABSL_FALLTHROUGH_INTENDED; + case FieldDescriptor::TYPE_UINT64: + return &MessageToJsonState::MapUInt64FieldToValue; + case FieldDescriptor::TYPE_BOOL: + return &MessageToJsonState::MapBoolFieldToValue; + case FieldDescriptor::TYPE_STRING: + return &MessageToJsonState::MapStringFieldToValue; + case FieldDescriptor::TYPE_GROUP: + ABSL_FALLTHROUGH_INTENDED; + case FieldDescriptor::TYPE_MESSAGE: + return &MessageToJsonState::MapMessageFieldToValue; + case FieldDescriptor::TYPE_BYTES: + return &MessageToJsonState::MapBytesFieldToValue; + case FieldDescriptor::TYPE_FIXED32: + ABSL_FALLTHROUGH_INTENDED; + case FieldDescriptor::TYPE_UINT32: + return &MessageToJsonState::MapUInt32FieldToValue; + case FieldDescriptor::TYPE_ENUM: { + const auto* enum_descriptor = field->enum_type(); + if (enum_descriptor->full_name() == "google.protobuf.NullValue") { + return &MessageToJsonState::MapNullFieldToValue; + } else { + return &MessageToJsonState::MapEnumFieldToValue; + } + } + case FieldDescriptor::TYPE_SFIXED32: + ABSL_FALLTHROUGH_INTENDED; + case FieldDescriptor::TYPE_SINT32: + ABSL_FALLTHROUGH_INTENDED; + case FieldDescriptor::TYPE_INT32: + return &MessageToJsonState::MapInt32FieldToValue; + case FieldDescriptor::TYPE_SFIXED64: + ABSL_FALLTHROUGH_INTENDED; + case FieldDescriptor::TYPE_SINT64: + ABSL_FALLTHROUGH_INTENDED; + case FieldDescriptor::TYPE_INT64: + return &MessageToJsonState::MapInt64FieldToValue; + default: + return absl::InvalidArgumentError(absl::StrCat( + "unexpected message field type: ", field->type_name())); + } + } + + absl::Status MapBoolFieldToValue( + const google::protobuf::MapValueConstRef& value, + const google::protobuf::FieldDescriptor* absl_nonnull field, + google::protobuf::MessageLite* absl_nonnull result) { + ABSL_DCHECK_EQ(value.type(), field->cpp_type()); + ABSL_DCHECK(!field->is_repeated()); + ABSL_DCHECK_EQ(field->cpp_type(), FieldDescriptor::CPPTYPE_BOOL); + SetBoolValue(result, value.GetBoolValue()); + return absl::OkStatus(); + } + + absl::Status MapInt32FieldToValue( + const google::protobuf::MapValueConstRef& value, + const google::protobuf::FieldDescriptor* absl_nonnull field, + google::protobuf::MessageLite* absl_nonnull result) { + ABSL_DCHECK_EQ(value.type(), field->cpp_type()); + ABSL_DCHECK(!field->is_repeated()); + ABSL_DCHECK_EQ(field->cpp_type(), FieldDescriptor::CPPTYPE_INT32); + SetNumberValue(result, value.GetInt32Value()); + return absl::OkStatus(); + } + + absl::Status MapInt64FieldToValue( + const google::protobuf::MapValueConstRef& value, + const google::protobuf::FieldDescriptor* absl_nonnull field, + google::protobuf::MessageLite* absl_nonnull result) { + ABSL_DCHECK_EQ(value.type(), field->cpp_type()); + ABSL_DCHECK(!field->is_repeated()); + ABSL_DCHECK_EQ(field->cpp_type(), FieldDescriptor::CPPTYPE_INT64); + SetNumberValue(result, value.GetInt64Value()); + return absl::OkStatus(); + } + + absl::Status MapUInt32FieldToValue( + const google::protobuf::MapValueConstRef& value, + const google::protobuf::FieldDescriptor* absl_nonnull field, + google::protobuf::MessageLite* absl_nonnull result) { + ABSL_DCHECK_EQ(value.type(), field->cpp_type()); + ABSL_DCHECK(!field->is_repeated()); + ABSL_DCHECK_EQ(field->cpp_type(), FieldDescriptor::CPPTYPE_UINT32); + SetNumberValue(result, value.GetUInt32Value()); + return absl::OkStatus(); + } + + absl::Status MapUInt64FieldToValue( + const google::protobuf::MapValueConstRef& value, + const google::protobuf::FieldDescriptor* absl_nonnull field, + google::protobuf::MessageLite* absl_nonnull result) { + ABSL_DCHECK_EQ(value.type(), field->cpp_type()); + ABSL_DCHECK(!field->is_repeated()); + ABSL_DCHECK_EQ(field->cpp_type(), FieldDescriptor::CPPTYPE_UINT64); + SetNumberValue(result, value.GetUInt64Value()); + return absl::OkStatus(); + } + + absl::Status MapFloatFieldToValue( + const google::protobuf::MapValueConstRef& value, + const google::protobuf::FieldDescriptor* absl_nonnull field, + google::protobuf::MessageLite* absl_nonnull result) { + ABSL_DCHECK_EQ(value.type(), field->cpp_type()); + ABSL_DCHECK(!field->is_repeated()); + ABSL_DCHECK_EQ(field->cpp_type(), FieldDescriptor::CPPTYPE_FLOAT); + SetNumberValue(result, value.GetFloatValue()); + return absl::OkStatus(); + } + + absl::Status MapDoubleFieldToValue( + const google::protobuf::MapValueConstRef& value, + const google::protobuf::FieldDescriptor* absl_nonnull field, + google::protobuf::MessageLite* absl_nonnull result) { + ABSL_DCHECK_EQ(value.type(), field->cpp_type()); + ABSL_DCHECK(!field->is_repeated()); + ABSL_DCHECK_EQ(field->cpp_type(), FieldDescriptor::CPPTYPE_DOUBLE); + SetNumberValue(result, value.GetDoubleValue()); + return absl::OkStatus(); + } + + absl::Status MapBytesFieldToValue( + const google::protobuf::MapValueConstRef& value, + const google::protobuf::FieldDescriptor* absl_nonnull field, + google::protobuf::MessageLite* absl_nonnull result) { + ABSL_DCHECK_EQ(value.type(), field->cpp_type()); + ABSL_DCHECK(!field->is_repeated()); + ABSL_DCHECK_EQ(field->type(), FieldDescriptor::TYPE_BYTES); + ABSL_DCHECK_EQ(field->cpp_type(), FieldDescriptor::CPPTYPE_STRING); + SetStringValueFromBytes(result, value.GetStringValue()); + return absl::OkStatus(); + } + + absl::Status MapStringFieldToValue( + const google::protobuf::MapValueConstRef& value, + const google::protobuf::FieldDescriptor* absl_nonnull field, + google::protobuf::MessageLite* absl_nonnull result) { + ABSL_DCHECK_EQ(value.type(), field->cpp_type()); + ABSL_DCHECK(!field->is_repeated()); + ABSL_DCHECK_EQ(field->type(), FieldDescriptor::TYPE_STRING); + ABSL_DCHECK_EQ(field->cpp_type(), FieldDescriptor::CPPTYPE_STRING); + SetStringValue(result, value.GetStringValue()); + return absl::OkStatus(); + } + + absl::Status MapMessageFieldToValue( + const google::protobuf::MapValueConstRef& value, + const google::protobuf::FieldDescriptor* absl_nonnull field, + google::protobuf::MessageLite* absl_nonnull result) { + ABSL_DCHECK_EQ(value.type(), field->cpp_type()); + ABSL_DCHECK(!field->is_repeated()); + ABSL_DCHECK_EQ(field->cpp_type(), FieldDescriptor::CPPTYPE_MESSAGE); + return ToJson(value.GetMessageValue(), result); + } + + absl::Status MapEnumFieldToValue( + const google::protobuf::MapValueConstRef& value, + const google::protobuf::FieldDescriptor* absl_nonnull field, + google::protobuf::MessageLite* absl_nonnull result) { + ABSL_DCHECK_EQ(value.type(), field->cpp_type()); + ABSL_DCHECK(!field->is_repeated()); + ABSL_DCHECK_EQ(field->cpp_type(), FieldDescriptor::CPPTYPE_ENUM); + ABSL_DCHECK_NE(field->enum_type()->full_name(), + "google.protobuf.NullValue"); + if (const auto* value_descriptor = + field->enum_type()->FindValueByNumber(value.GetEnumValue()); + value_descriptor != nullptr) { + SetStringValue(result, value_descriptor->name()); + } else { + SetNumberValue(result, value.GetEnumValue()); + } + return absl::OkStatus(); + } + + absl::Status MapNullFieldToValue( + const google::protobuf::MapValueConstRef& value, + const google::protobuf::FieldDescriptor* absl_nonnull field, + google::protobuf::MessageLite* absl_nonnull result) { + ABSL_DCHECK_EQ(value.type(), field->cpp_type()); + ABSL_DCHECK(!field->is_repeated()); + ABSL_DCHECK_EQ(field->cpp_type(), FieldDescriptor::CPPTYPE_ENUM); + ABSL_DCHECK_EQ(field->enum_type()->full_name(), + "google.protobuf.NullValue"); + SetNullValue(result); + return absl::OkStatus(); + } + + absl::StatusOr GetRepeatedFieldToValue( + const google::protobuf::FieldDescriptor* absl_nonnull field) { + switch (field->type()) { + case FieldDescriptor::TYPE_DOUBLE: + return &MessageToJsonState::RepeatedDoubleFieldToValue; + case FieldDescriptor::TYPE_FLOAT: + return &MessageToJsonState::RepeatedFloatFieldToValue; + case FieldDescriptor::TYPE_FIXED64: + ABSL_FALLTHROUGH_INTENDED; + case FieldDescriptor::TYPE_UINT64: + return &MessageToJsonState::RepeatedUInt64FieldToValue; + case FieldDescriptor::TYPE_BOOL: + return &MessageToJsonState::RepeatedBoolFieldToValue; + case FieldDescriptor::TYPE_STRING: + return &MessageToJsonState::RepeatedStringFieldToValue; + case FieldDescriptor::TYPE_GROUP: + ABSL_FALLTHROUGH_INTENDED; + case FieldDescriptor::TYPE_MESSAGE: + return &MessageToJsonState::RepeatedMessageFieldToValue; + case FieldDescriptor::TYPE_BYTES: + return &MessageToJsonState::RepeatedBytesFieldToValue; + case FieldDescriptor::TYPE_FIXED32: + ABSL_FALLTHROUGH_INTENDED; + case FieldDescriptor::TYPE_UINT32: + return &MessageToJsonState::RepeatedUInt32FieldToValue; + case FieldDescriptor::TYPE_ENUM: { + const auto* enum_descriptor = field->enum_type(); + if (enum_descriptor->full_name() == "google.protobuf.NullValue") { + return &MessageToJsonState::RepeatedNullFieldToValue; + } else { + return &MessageToJsonState::RepeatedEnumFieldToValue; + } + } + case FieldDescriptor::TYPE_SFIXED32: + ABSL_FALLTHROUGH_INTENDED; + case FieldDescriptor::TYPE_SINT32: + ABSL_FALLTHROUGH_INTENDED; + case FieldDescriptor::TYPE_INT32: + return &MessageToJsonState::RepeatedInt32FieldToValue; + case FieldDescriptor::TYPE_SFIXED64: + ABSL_FALLTHROUGH_INTENDED; + case FieldDescriptor::TYPE_SINT64: + ABSL_FALLTHROUGH_INTENDED; + case FieldDescriptor::TYPE_INT64: + return &MessageToJsonState::RepeatedInt64FieldToValue; + default: + return absl::InvalidArgumentError(absl::StrCat( + "unexpected message field type: ", field->type_name())); + } + } + + absl::Status RepeatedBoolFieldToValue( + const google::protobuf::Reflection* absl_nonnull reflection, + const google::protobuf::Message& message, + const google::protobuf::FieldDescriptor* absl_nonnull field, int index, + google::protobuf::MessageLite* absl_nonnull result) { + ABSL_DCHECK_EQ(reflection, message.GetReflection()); + ABSL_DCHECK(!field->is_map() && field->is_repeated()); + ABSL_DCHECK_EQ(field->cpp_type(), FieldDescriptor::CPPTYPE_BOOL); + SetBoolValue(result, reflection->GetRepeatedBool(message, field, index)); + return absl::OkStatus(); + } + + absl::Status RepeatedInt32FieldToValue( + const google::protobuf::Reflection* absl_nonnull reflection, + const google::protobuf::Message& message, + const google::protobuf::FieldDescriptor* absl_nonnull field, int index, + google::protobuf::MessageLite* absl_nonnull result) { + ABSL_DCHECK_EQ(reflection, message.GetReflection()); + ABSL_DCHECK(!field->is_map() && field->is_repeated()); + ABSL_DCHECK_EQ(field->cpp_type(), FieldDescriptor::CPPTYPE_INT32); + SetNumberValue(result, reflection->GetRepeatedInt32(message, field, index)); + return absl::OkStatus(); + } + + absl::Status RepeatedInt64FieldToValue( + const google::protobuf::Reflection* absl_nonnull reflection, + const google::protobuf::Message& message, + const google::protobuf::FieldDescriptor* absl_nonnull field, int index, + google::protobuf::MessageLite* absl_nonnull result) { + ABSL_DCHECK_EQ(reflection, message.GetReflection()); + ABSL_DCHECK(!field->is_map() && field->is_repeated()); + ABSL_DCHECK_EQ(field->cpp_type(), FieldDescriptor::CPPTYPE_INT64); + SetNumberValue(result, reflection->GetRepeatedInt64(message, field, index)); + return absl::OkStatus(); + } + + absl::Status RepeatedUInt32FieldToValue( + const google::protobuf::Reflection* absl_nonnull reflection, + const google::protobuf::Message& message, + const google::protobuf::FieldDescriptor* absl_nonnull field, int index, + google::protobuf::MessageLite* absl_nonnull result) { + ABSL_DCHECK_EQ(reflection, message.GetReflection()); + ABSL_DCHECK(!field->is_map() && field->is_repeated()); + ABSL_DCHECK_EQ(field->cpp_type(), FieldDescriptor::CPPTYPE_UINT32); + SetNumberValue(result, + reflection->GetRepeatedUInt32(message, field, index)); + return absl::OkStatus(); + } + + absl::Status RepeatedUInt64FieldToValue( + const google::protobuf::Reflection* absl_nonnull reflection, + const google::protobuf::Message& message, + const google::protobuf::FieldDescriptor* absl_nonnull field, int index, + google::protobuf::MessageLite* absl_nonnull result) { + ABSL_DCHECK_EQ(reflection, message.GetReflection()); + ABSL_DCHECK(!field->is_map() && field->is_repeated()); + ABSL_DCHECK_EQ(field->cpp_type(), FieldDescriptor::CPPTYPE_UINT64); + SetNumberValue(result, + reflection->GetRepeatedUInt64(message, field, index)); + return absl::OkStatus(); + } + + absl::Status RepeatedFloatFieldToValue( + const google::protobuf::Reflection* absl_nonnull reflection, + const google::protobuf::Message& message, + const google::protobuf::FieldDescriptor* absl_nonnull field, int index, + google::protobuf::MessageLite* absl_nonnull result) { + ABSL_DCHECK_EQ(reflection, message.GetReflection()); + ABSL_DCHECK(!field->is_map() && field->is_repeated()); + ABSL_DCHECK_EQ(field->cpp_type(), FieldDescriptor::CPPTYPE_FLOAT); + SetNumberValue(result, reflection->GetRepeatedFloat(message, field, index)); + return absl::OkStatus(); + } + + absl::Status RepeatedDoubleFieldToValue( + const google::protobuf::Reflection* absl_nonnull reflection, + const google::protobuf::Message& message, + const google::protobuf::FieldDescriptor* absl_nonnull field, int index, + google::protobuf::MessageLite* absl_nonnull result) { + ABSL_DCHECK_EQ(reflection, message.GetReflection()); + ABSL_DCHECK(!field->is_map() && field->is_repeated()); + ABSL_DCHECK_EQ(field->cpp_type(), FieldDescriptor::CPPTYPE_DOUBLE); + SetNumberValue(result, + reflection->GetRepeatedDouble(message, field, index)); + return absl::OkStatus(); + } + + absl::Status RepeatedBytesFieldToValue( + const google::protobuf::Reflection* absl_nonnull reflection, + const google::protobuf::Message& message, + const google::protobuf::FieldDescriptor* absl_nonnull field, int index, + google::protobuf::MessageLite* absl_nonnull result) { + ABSL_DCHECK_EQ(reflection, message.GetReflection()); + ABSL_DCHECK(!field->is_map() && field->is_repeated()); + ABSL_DCHECK_EQ(field->type(), FieldDescriptor::TYPE_BYTES); + ABSL_DCHECK_EQ(field->cpp_type(), FieldDescriptor::CPPTYPE_STRING); + absl::visit(absl::Overload( + [&](absl::string_view string) -> void { + SetStringValueFromBytes(result, string); + }, + [&](absl::Cord&& cord) -> void { + SetStringValueFromBytes(result, cord); + }), + AsVariant(GetRepeatedBytesField(reflection, message, field, + index, scratch_))); + return absl::OkStatus(); + } + + absl::Status RepeatedStringFieldToValue( + const google::protobuf::Reflection* absl_nonnull reflection, + const google::protobuf::Message& message, + const google::protobuf::FieldDescriptor* absl_nonnull field, int index, + google::protobuf::MessageLite* absl_nonnull result) { + ABSL_DCHECK_EQ(reflection, message.GetReflection()); + ABSL_DCHECK(!field->is_map() && field->is_repeated()); + ABSL_DCHECK_EQ(field->type(), FieldDescriptor::TYPE_STRING); + ABSL_DCHECK_EQ(field->cpp_type(), FieldDescriptor::CPPTYPE_STRING); + absl::visit( + absl::Overload( + [&](absl::string_view string) -> void { + SetStringValue(result, string); + }, + [&](absl::Cord&& cord) -> void { SetStringValue(result, cord); }), + AsVariant(GetRepeatedStringField(reflection, message, field, index, + scratch_))); + return absl::OkStatus(); + } + + absl::Status RepeatedMessageFieldToValue( + const google::protobuf::Reflection* absl_nonnull reflection, + const google::protobuf::Message& message, + const google::protobuf::FieldDescriptor* absl_nonnull field, int index, + google::protobuf::MessageLite* absl_nonnull result) { + ABSL_DCHECK_EQ(reflection, message.GetReflection()); + ABSL_DCHECK(!field->is_map() && field->is_repeated()); + ABSL_DCHECK_EQ(field->cpp_type(), FieldDescriptor::CPPTYPE_MESSAGE); + return ToJson(reflection->GetRepeatedMessage(message, field, index), + result); + } + + absl::Status RepeatedEnumFieldToValue( + const google::protobuf::Reflection* absl_nonnull reflection, + const google::protobuf::Message& message, + const google::protobuf::FieldDescriptor* absl_nonnull field, int index, + google::protobuf::MessageLite* absl_nonnull result) { + ABSL_DCHECK_EQ(reflection, message.GetReflection()); + ABSL_DCHECK(!field->is_map() && field->is_repeated()); + ABSL_DCHECK_EQ(field->cpp_type(), FieldDescriptor::CPPTYPE_ENUM); + ABSL_DCHECK_NE(field->enum_type()->full_name(), + "google.protobuf.NullValue"); + if (const auto* value = reflection->GetRepeatedEnum(message, field, index); + value != nullptr) { + SetStringValue(result, value->name()); + } else { + SetNumberValue(result, + reflection->GetRepeatedEnumValue(message, field, index)); + } + return absl::OkStatus(); + } + + absl::Status RepeatedNullFieldToValue( + const google::protobuf::Reflection* absl_nonnull reflection, + const google::protobuf::Message& message, + const google::protobuf::FieldDescriptor* absl_nonnull field, int index, + google::protobuf::MessageLite* absl_nonnull result) { + ABSL_DCHECK_EQ(reflection, message.GetReflection()); + ABSL_DCHECK(!field->is_map() && field->is_repeated()); + ABSL_DCHECK_EQ(field->cpp_type(), FieldDescriptor::CPPTYPE_ENUM); + ABSL_DCHECK_EQ(field->enum_type()->full_name(), + "google.protobuf.NullValue"); + SetNullValue(result); + return absl::OkStatus(); + } + + absl::Status MessageMapFieldToJson( + const google::protobuf::Message& message, + const google::protobuf::FieldDescriptor* absl_nonnull field, + google::protobuf::MessageLite* absl_nonnull result) { + const auto* reflection = message.GetReflection(); + if (reflection->FieldSize(message, field) == 0) { + return absl::OkStatus(); + } + const auto key_to_string = + GetMapFieldKeyToString(field->message_type()->map_key()); + const auto* value_descriptor = field->message_type()->map_value(); + CEL_ASSIGN_OR_RETURN(const auto value_to_value, + GetMapFieldValueToValue(value_descriptor)); + auto begin = extensions::protobuf_internal::ConstMapBegin(*reflection, + message, *field); + const auto end = extensions::protobuf_internal::ConstMapEnd( + *reflection, message, *field); + for (; begin != end; ++begin) { + auto key = (*key_to_string)(begin.GetKey()); + CEL_RETURN_IF_ERROR((this->*value_to_value)( + begin.GetValueRef(), value_descriptor, InsertField(result, key))); + } + return absl::OkStatus(); + } + + absl::Status MessageRepeatedFieldToJson( + const google::protobuf::Message& message, + const google::protobuf::FieldDescriptor* absl_nonnull field, + google::protobuf::MessageLite* absl_nonnull result) { + const auto* reflection = message.GetReflection(); + const int size = reflection->FieldSize(message, field); + if (size == 0) { + return absl::OkStatus(); + } + CEL_ASSIGN_OR_RETURN(const auto to_value, GetRepeatedFieldToValue(field)); + for (int index = 0; index < size; ++index) { + CEL_RETURN_IF_ERROR((this->*to_value)(reflection, message, field, index, + AddValues(result))); + } + return absl::OkStatus(); + } + + absl::Status MessageFieldToJson( + const google::protobuf::Message& message, + const google::protobuf::FieldDescriptor* absl_nonnull field, + google::protobuf::MessageLite* absl_nonnull result) { + if (field->is_map()) { + return MessageMapFieldToJson(message, field, MutableStructValue(result)); + } + if (field->is_repeated()) { + return MessageRepeatedFieldToJson(message, field, + MutableListValue(result)); + } + const auto* reflection = message.GetReflection(); + switch (field->type()) { + case FieldDescriptor::TYPE_DOUBLE: + SetNumberValue(result, reflection->GetDouble(message, field)); + break; + case FieldDescriptor::TYPE_FLOAT: + SetNumberValue(result, reflection->GetFloat(message, field)); + break; + case FieldDescriptor::TYPE_FIXED64: + ABSL_FALLTHROUGH_INTENDED; + case FieldDescriptor::TYPE_UINT64: + SetNumberValue(result, reflection->GetUInt64(message, field)); + break; + case FieldDescriptor::TYPE_BOOL: + SetBoolValue(result, reflection->GetBool(message, field)); + break; + case FieldDescriptor::TYPE_STRING: + StringValueToJson( + well_known_types::GetStringField(message, field, scratch_), result); + break; + case FieldDescriptor::TYPE_GROUP: + ABSL_FALLTHROUGH_INTENDED; + case FieldDescriptor::TYPE_MESSAGE: + return ToJson((reflection->GetMessage)(message, field), result); + case FieldDescriptor::TYPE_BYTES: + BytesValueToJson( + well_known_types::GetBytesField(message, field, scratch_), result); + break; + case FieldDescriptor::TYPE_FIXED32: + ABSL_FALLTHROUGH_INTENDED; + case FieldDescriptor::TYPE_UINT32: + SetNumberValue(result, reflection->GetUInt32(message, field)); + break; + case FieldDescriptor::TYPE_ENUM: { + const auto* enum_descriptor = field->enum_type(); + if (enum_descriptor->full_name() == "google.protobuf.NullValue") { + SetNullValue(result); + } else { + const auto* enum_value_descriptor = + reflection->GetEnum(message, field); + if (enum_value_descriptor != nullptr) { + SetStringValue(result, enum_value_descriptor->name()); + } else { + SetNumberValue(result, reflection->GetEnumValue(message, field)); + } + } + } break; + case FieldDescriptor::TYPE_SFIXED32: + ABSL_FALLTHROUGH_INTENDED; + case FieldDescriptor::TYPE_SINT32: + ABSL_FALLTHROUGH_INTENDED; + case FieldDescriptor::TYPE_INT32: + SetNumberValue(result, reflection->GetInt32(message, field)); + break; + case FieldDescriptor::TYPE_SFIXED64: + ABSL_FALLTHROUGH_INTENDED; + case FieldDescriptor::TYPE_SINT64: + ABSL_FALLTHROUGH_INTENDED; + case FieldDescriptor::TYPE_INT64: + SetNumberValue(result, reflection->GetInt64(message, field)); + break; + default: + return absl::InvalidArgumentError(absl::StrCat( + "unexpected message field type: ", field->type_name())); + } + return absl::OkStatus(); + } + + absl::Status MessageToJson(const google::protobuf::Message& message, + google::protobuf::MessageLite* absl_nonnull result) { + std::vector fields; + const auto* reflection = message.GetReflection(); + reflection->ListFields(message, &fields); + if (!fields.empty()) { + for (const auto* field : fields) { + CEL_RETURN_IF_ERROR(MessageFieldToJson( + message, field, InsertField(result, field->json_name()))); + } + } + return absl::OkStatus(); + } + + void StringValueToJson(const well_known_types::StringValue& value, + google::protobuf::MessageLite* absl_nonnull result) const { + absl::visit(absl::Overload([&](absl::string_view string) + -> void { SetStringValue(result, string); }, + [&](const absl::Cord& cord) -> void { + SetStringValue(result, cord); + }), + AsVariant(value)); + } + + void BytesValueToJson(const well_known_types::BytesValue& value, + google::protobuf::MessageLite* absl_nonnull result) const { + absl::visit(absl::Overload( + [&](absl::string_view string) -> void { + SetStringValueFromBytes(result, string); + }, + [&](const absl::Cord& cord) -> void { + SetStringValueFromBytes(result, cord); + }), + AsVariant(value)); + } + + virtual void SetNullValue( + google::protobuf::MessageLite* absl_nonnull message) const = 0; + + virtual void SetBoolValue(google::protobuf::MessageLite* absl_nonnull message, + bool value) const = 0; + + virtual void SetNumberValue(google::protobuf::MessageLite* absl_nonnull message, + double value) const = 0; + + void SetNumberValue(google::protobuf::MessageLite* absl_nonnull message, + float value) const { + SetNumberValue(message, static_cast(value)); + } + + virtual void SetNumberValue(google::protobuf::MessageLite* absl_nonnull message, + int64_t value) const = 0; + + void SetNumberValue(google::protobuf::MessageLite* absl_nonnull message, + int32_t value) const { + SetNumberValue(message, static_cast(value)); + } + + virtual void SetNumberValue(google::protobuf::MessageLite* absl_nonnull message, + uint64_t value) const = 0; + + void SetNumberValue(google::protobuf::MessageLite* absl_nonnull message, + uint32_t value) const { + SetNumberValue(message, static_cast(value)); + } + + virtual void SetStringValue(google::protobuf::MessageLite* absl_nonnull message, + absl::string_view value) const = 0; + + virtual void SetStringValue(google::protobuf::MessageLite* absl_nonnull message, + const absl::Cord& value) const = 0; + + void SetStringValueFromBytes(google::protobuf::MessageLite* absl_nonnull message, + absl::string_view value) const { + if (value.empty()) { + SetStringValue(message, value); + return; + } + SetStringValue(message, absl::Base64Escape(value)); + } + + void SetStringValueFromBytes(google::protobuf::MessageLite* absl_nonnull message, + const absl::Cord& value) const { + if (value.empty()) { + SetStringValue(message, value); + return; + } + if (auto flat = value.TryFlat(); flat) { + SetStringValue(message, absl::Base64Escape(*flat)); + return; + } + SetStringValue(message, + absl::Base64Escape(static_cast(value))); + } + + virtual google::protobuf::MessageLite* absl_nonnull MutableListValue( + google::protobuf::MessageLite* absl_nonnull message) const = 0; + + virtual google::protobuf::MessageLite* absl_nonnull MutableStructValue( + google::protobuf::MessageLite* absl_nonnull message) const = 0; + + virtual google::protobuf::MessageLite* absl_nonnull AddValues( + google::protobuf::MessageLite* absl_nonnull message) const = 0; + + virtual google::protobuf::MessageLite* absl_nonnull InsertField( + google::protobuf::MessageLite* absl_nonnull message, + absl::string_view name) const = 0; + + const google::protobuf::DescriptorPool* absl_nonnull const descriptor_pool_; + google::protobuf::MessageFactory* absl_nonnull const message_factory_; + std::string scratch_; + Reflection reflection_; +}; + +class GeneratedMessageToJsonState final : public MessageToJsonState { + public: + using MessageToJsonState::MessageToJsonState; + + absl::Status Initialize(google::protobuf::MessageLite* absl_nonnull message) override { + // Nothing to do. + return absl::OkStatus(); + } + + private: + void SetNullValue(google::protobuf::MessageLite* absl_nonnull message) const override { + ValueReflection::SetNullValue( + google::protobuf::DownCastMessage(message)); + } + + void SetBoolValue(google::protobuf::MessageLite* absl_nonnull message, + bool value) const override { + ValueReflection::SetBoolValue( + google::protobuf::DownCastMessage(message), value); + } + + void SetNumberValue(google::protobuf::MessageLite* absl_nonnull message, + double value) const override { + ValueReflection::SetNumberValue( + google::protobuf::DownCastMessage(message), value); + } + + void SetNumberValue(google::protobuf::MessageLite* absl_nonnull message, + int64_t value) const override { + ValueReflection::SetNumberValue( + google::protobuf::DownCastMessage(message), value); + } + + void SetNumberValue(google::protobuf::MessageLite* absl_nonnull message, + uint64_t value) const override { + ValueReflection::SetNumberValue( + google::protobuf::DownCastMessage(message), value); + } + + void SetStringValue(google::protobuf::MessageLite* absl_nonnull message, + absl::string_view value) const override { + ValueReflection::SetStringValue( + google::protobuf::DownCastMessage(message), value); + } + + void SetStringValue(google::protobuf::MessageLite* absl_nonnull message, + const absl::Cord& value) const override { + ValueReflection::SetStringValue( + google::protobuf::DownCastMessage(message), value); + } + + google::protobuf::MessageLite* absl_nonnull MutableListValue( + google::protobuf::MessageLite* absl_nonnull message) const override { + return ValueReflection::MutableListValue( + google::protobuf::DownCastMessage(message)); + } + + google::protobuf::MessageLite* absl_nonnull MutableStructValue( + google::protobuf::MessageLite* absl_nonnull message) const override { + return ValueReflection::MutableStructValue( + google::protobuf::DownCastMessage(message)); + } + + google::protobuf::MessageLite* absl_nonnull AddValues( + google::protobuf::MessageLite* absl_nonnull message) const override { + return ListValueReflection::AddValues( + google::protobuf::DownCastMessage(message)); + } + + google::protobuf::MessageLite* absl_nonnull InsertField( + google::protobuf::MessageLite* absl_nonnull message, + absl::string_view name) const override { + return StructReflection::InsertField( + google::protobuf::DownCastMessage(message), name); + } +}; + +class DynamicMessageToJsonState final : public MessageToJsonState { + public: + using MessageToJsonState::MessageToJsonState; + + absl::Status Initialize(google::protobuf::MessageLite* absl_nonnull message) override { + CEL_RETURN_IF_ERROR(reflection_.Initialize( + google::protobuf::DownCastMessage(message)->GetDescriptor())); + return absl::OkStatus(); + } + + private: + void SetNullValue(google::protobuf::MessageLite* absl_nonnull message) const override { + reflection_.Value().SetNullValue( + google::protobuf::DownCastMessage(message)); + } + + void SetBoolValue(google::protobuf::MessageLite* absl_nonnull message, + bool value) const override { + reflection_.Value().SetBoolValue( + google::protobuf::DownCastMessage(message), value); + } + + void SetNumberValue(google::protobuf::MessageLite* absl_nonnull message, + double value) const override { + reflection_.Value().SetNumberValue( + google::protobuf::DownCastMessage(message), value); + } + + void SetNumberValue(google::protobuf::MessageLite* absl_nonnull message, + int64_t value) const override { + reflection_.Value().SetNumberValue( + google::protobuf::DownCastMessage(message), value); + } + + void SetNumberValue(google::protobuf::MessageLite* absl_nonnull message, + uint64_t value) const override { + reflection_.Value().SetNumberValue( + google::protobuf::DownCastMessage(message), value); + } + + void SetStringValue(google::protobuf::MessageLite* absl_nonnull message, + absl::string_view value) const override { + reflection_.Value().SetStringValue( + google::protobuf::DownCastMessage(message), value); + } + + void SetStringValue(google::protobuf::MessageLite* absl_nonnull message, + const absl::Cord& value) const override { + reflection_.Value().SetStringValue( + google::protobuf::DownCastMessage(message), value); + } + + google::protobuf::MessageLite* absl_nonnull MutableListValue( + google::protobuf::MessageLite* absl_nonnull message) const override { + return reflection_.Value().MutableListValue( + google::protobuf::DownCastMessage(message)); + } + + google::protobuf::MessageLite* absl_nonnull MutableStructValue( + google::protobuf::MessageLite* absl_nonnull message) const override { + return reflection_.Value().MutableStructValue( + google::protobuf::DownCastMessage(message)); + } + + google::protobuf::MessageLite* absl_nonnull AddValues( + google::protobuf::MessageLite* absl_nonnull message) const override { + return reflection_.ListValue().AddValues( + google::protobuf::DownCastMessage(message)); + } + + google::protobuf::MessageLite* absl_nonnull InsertField( + google::protobuf::MessageLite* absl_nonnull message, + absl::string_view name) const override { + return reflection_.Struct().InsertField( + google::protobuf::DownCastMessage(message), name); + } + + JsonReflection reflection_; +}; + +} // namespace + +absl::Status MessageToJson( + const google::protobuf::Message& message, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Value* absl_nonnull result) { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(result != nullptr); + auto state = std::make_unique(descriptor_pool, + message_factory); + CEL_RETURN_IF_ERROR(state->Initialize(result)); + return state->ToJson(message, result); +} + +absl::Status MessageToJson( + const google::protobuf::Message& message, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Struct* absl_nonnull result) { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(result != nullptr); + auto state = std::make_unique(descriptor_pool, + message_factory); + CEL_RETURN_IF_ERROR(state->Initialize(result)); + return state->ToJsonObject(message, result); +} + +absl::Status MessageToJson( + const google::protobuf::Message& message, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull result) { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(result != nullptr); + auto state = std::make_unique(descriptor_pool, + message_factory); + CEL_RETURN_IF_ERROR(state->Initialize(result)); + switch (result->GetDescriptor()->well_known_type()) { + case google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE: + return state->ToJson(message, result); + case google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT: + return state->ToJsonObject(message, result); + default: + return absl::InvalidArgumentError("cannot convert message to JSON array"); + } +} + +absl::Status MessageFieldToJson( + const google::protobuf::Message& message, + const google::protobuf::FieldDescriptor* absl_nonnull field, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Value* absl_nonnull result) { + ABSL_DCHECK_EQ(field->containing_type(), message.GetDescriptor()); + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(result != nullptr); + auto state = std::make_unique(descriptor_pool, + message_factory); + CEL_RETURN_IF_ERROR(state->Initialize(result)); + return state->FieldToJson(message, field, result); +} + +absl::Status MessageFieldToJson( + const google::protobuf::Message& message, + const google::protobuf::FieldDescriptor* absl_nonnull field, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::ListValue* absl_nonnull result) { + ABSL_DCHECK_EQ(field->containing_type(), message.GetDescriptor()); + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(result != nullptr); + auto state = std::make_unique(descriptor_pool, + message_factory); + CEL_RETURN_IF_ERROR(state->Initialize(result)); + return state->FieldToJsonArray(message, field, result); +} + +absl::Status MessageFieldToJson( + const google::protobuf::Message& message, + const google::protobuf::FieldDescriptor* absl_nonnull field, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Struct* absl_nonnull result) { + ABSL_DCHECK_EQ(field->containing_type(), message.GetDescriptor()); + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(result != nullptr); + auto state = std::make_unique(descriptor_pool, + message_factory); + CEL_RETURN_IF_ERROR(state->Initialize(result)); + return state->FieldToJsonObject(message, field, result); +} + +absl::Status MessageFieldToJson( + const google::protobuf::Message& message, + const google::protobuf::FieldDescriptor* absl_nonnull field, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull result) { + ABSL_DCHECK_EQ(field->containing_type(), message.GetDescriptor()); + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(result != nullptr); + auto state = std::make_unique(descriptor_pool, + message_factory); + CEL_RETURN_IF_ERROR(state->Initialize(result)); + switch (result->GetDescriptor()->well_known_type()) { + case google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE: + return state->FieldToJson(message, field, result); + case google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE: + return state->FieldToJsonArray(message, field, result); + case google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT: + return state->FieldToJsonObject(message, field, result); + default: + return absl::InternalError("unreachable"); + } +} + +absl::Status CheckJson(const google::protobuf::MessageLite& message) { + if (const auto* generated_message = + google::protobuf::DynamicCastMessage(&message); + generated_message) { + return absl::OkStatus(); + } + if (const auto* dynamic_message = + google::protobuf::DynamicCastMessage(&message); + dynamic_message) { + CEL_ASSIGN_OR_RETURN(auto reflection, + GetValueReflection(dynamic_message->GetDescriptor())); + CEL_RETURN_IF_ERROR( + GetListValueReflection(reflection.GetListValueDescriptor()).status()); + CEL_RETURN_IF_ERROR( + GetStructReflection(reflection.GetStructDescriptor()).status()); + return absl::OkStatus(); + } + return absl::InvalidArgumentError( + absl::StrCat("message must be an instance of `google.protobuf.Value`: ", + message.GetTypeName())); +} + +absl::Status CheckJsonList(const google::protobuf::MessageLite& message) { + if (const auto* generated_message = + google::protobuf::DynamicCastMessage(&message); + generated_message) { + return absl::OkStatus(); + } + if (const auto* dynamic_message = + google::protobuf::DynamicCastMessage(&message); + dynamic_message) { + CEL_ASSIGN_OR_RETURN( + auto reflection, + GetListValueReflection(dynamic_message->GetDescriptor())); + CEL_ASSIGN_OR_RETURN(auto value_reflection, + GetValueReflection(reflection.GetValueDescriptor())); + CEL_RETURN_IF_ERROR( + GetStructReflection(value_reflection.GetStructDescriptor()).status()); + return absl::OkStatus(); + } + return absl::InvalidArgumentError(absl::StrCat( + "message must be an instance of `google.protobuf.ListValue`: ", + message.GetTypeName())); +} + +absl::Status CheckJsonMap(const google::protobuf::MessageLite& message) { + if (const auto* generated_message = + google::protobuf::DynamicCastMessage(&message); + generated_message) { + return absl::OkStatus(); + } + if (const auto* dynamic_message = + google::protobuf::DynamicCastMessage(&message); + dynamic_message) { + CEL_ASSIGN_OR_RETURN(auto reflection, + GetStructReflection(dynamic_message->GetDescriptor())); + CEL_ASSIGN_OR_RETURN(auto value_reflection, + GetValueReflection(reflection.GetValueDescriptor())); + CEL_RETURN_IF_ERROR( + GetListValueReflection(value_reflection.GetListValueDescriptor()) + .status()); + return absl::OkStatus(); + } + return absl::InvalidArgumentError( + absl::StrCat("message must be an instance of `google.protobuf.Struct`: ", + message.GetTypeName())); +} + +namespace { + +class JsonMapIterator final { + public: + using Generated = + typename google::protobuf::Map::const_iterator; + using Dynamic = google::protobuf::ConstMapIterator; + using Value = std::pair; + + // NOLINTNEXTLINE(google-explicit-constructor) + JsonMapIterator(Generated generated) : variant_(std::move(generated)) {} + + // NOLINTNEXTLINE(google-explicit-constructor) + JsonMapIterator(Dynamic dynamic) : variant_(std::move(dynamic)) {} + + JsonMapIterator(const JsonMapIterator&) = default; + JsonMapIterator(JsonMapIterator&&) = default; + JsonMapIterator& operator=(const JsonMapIterator&) = default; + JsonMapIterator& operator=(JsonMapIterator&&) = default; + + Value Next(std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) { + Value result; + absl::visit(absl::Overload( + [&](Generated& generated) -> void { + result = std::pair{absl::string_view(generated->first), + &generated->second}; + ++generated; + }, + [&](Dynamic& dynamic) -> void { + const auto& key = dynamic.GetKey().GetStringValue(); + scratch.assign(key.data(), key.size()); + result = + std::pair{absl::string_view(scratch), + &dynamic.GetValueRef().GetMessageValue()}; + ++dynamic; + }), + variant_); + return result; + } + + private: + std::variant variant_; +}; + +class JsonAccessor { + public: + virtual ~JsonAccessor() = default; + + virtual google::protobuf::Value::KindCase GetKindCase( + const google::protobuf::MessageLite& message) const = 0; + + virtual bool GetBoolValue(const google::protobuf::MessageLite& message) const = 0; + + virtual double GetNumberValue(const google::protobuf::MessageLite& message) const = 0; + + virtual well_known_types::StringValue GetStringValue( + const google::protobuf::MessageLite& message, + std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) const = 0; + + virtual const google::protobuf::MessageLite& GetListValue( + const google::protobuf::MessageLite& message) const = 0; + + virtual int ValuesSize(const google::protobuf::MessageLite& message) const = 0; + + virtual const google::protobuf::MessageLite& Values(const google::protobuf::MessageLite& message, + int index) const = 0; + + virtual const google::protobuf::MessageLite& GetStructValue( + const google::protobuf::MessageLite& message) const = 0; + + virtual int FieldsSize(const google::protobuf::MessageLite& message) const = 0; + + virtual const google::protobuf::MessageLite* absl_nullable FindField( + const google::protobuf::MessageLite& message, absl::string_view name) const = 0; + + virtual JsonMapIterator IterateFields( + const google::protobuf::MessageLite& message) const = 0; +}; + +class GeneratedJsonAccessor final : public JsonAccessor { + public: + static const GeneratedJsonAccessor* absl_nonnull Singleton() { + static const absl::NoDestructor singleton; + return &*singleton; + } + + google::protobuf::Value::KindCase GetKindCase( + const google::protobuf::MessageLite& message) const override { + return ValueReflection::GetKindCase( + google::protobuf::DownCastMessage(message)); + } + + bool GetBoolValue(const google::protobuf::MessageLite& message) const override { + return ValueReflection::GetBoolValue( + google::protobuf::DownCastMessage(message)); + } + + double GetNumberValue(const google::protobuf::MessageLite& message) const override { + return ValueReflection::GetNumberValue( + google::protobuf::DownCastMessage(message)); + } + + well_known_types::StringValue GetStringValue( + const google::protobuf::MessageLite& message, std::string&) const override { + return ValueReflection::GetStringValue( + google::protobuf::DownCastMessage(message)); + } + + const google::protobuf::MessageLite& GetListValue( + const google::protobuf::MessageLite& message) const override { + return ValueReflection::GetListValue( + google::protobuf::DownCastMessage(message)); + } + + int ValuesSize(const google::protobuf::MessageLite& message) const override { + return ListValueReflection::ValuesSize( + google::protobuf::DownCastMessage(message)); + } + + const google::protobuf::MessageLite& Values(const google::protobuf::MessageLite& message, + int index) const override { + return ListValueReflection::Values( + google::protobuf::DownCastMessage(message), index); + } + + const google::protobuf::MessageLite& GetStructValue( + const google::protobuf::MessageLite& message) const override { + return ValueReflection::GetStructValue( + google::protobuf::DownCastMessage(message)); + } + + int FieldsSize(const google::protobuf::MessageLite& message) const override { + return StructReflection::FieldsSize( + google::protobuf::DownCastMessage(message)); + } + + const google::protobuf::MessageLite* absl_nullable FindField( + const google::protobuf::MessageLite& message, + absl::string_view name) const override { + return StructReflection::FindField( + google::protobuf::DownCastMessage(message), name); + } + + JsonMapIterator IterateFields( + const google::protobuf::MessageLite& message) const override { + return StructReflection::BeginFields( + google::protobuf::DownCastMessage(message)); + } +}; + +class DynamicJsonAccessor final : public JsonAccessor { + public: + void InitializeValue(const google::protobuf::Message& message) { + ABSL_CHECK_OK(reflection_.Initialize(message.GetDescriptor())); // Crash OK + } + + void InitializeListValue(const google::protobuf::Message& message) { + ABSL_CHECK_OK(reflection_.Initialize(message.GetDescriptor())); // Crash OK + } + + void InitializeStruct(const google::protobuf::Message& message) { + ABSL_CHECK_OK(reflection_.Initialize(message.GetDescriptor())); // Crash OK + } + + google::protobuf::Value::KindCase GetKindCase( + const google::protobuf::MessageLite& message) const override { + return reflection_.Value().GetKindCase( + google::protobuf::DownCastMessage(message)); + } + + bool GetBoolValue(const google::protobuf::MessageLite& message) const override { + return reflection_.Value().GetBoolValue( + google::protobuf::DownCastMessage(message)); + } + + double GetNumberValue(const google::protobuf::MessageLite& message) const override { + return reflection_.Value().GetNumberValue( + google::protobuf::DownCastMessage(message)); + } + + well_known_types::StringValue GetStringValue( + const google::protobuf::MessageLite& message, std::string& scratch) const override { + return reflection_.Value().GetStringValue( + google::protobuf::DownCastMessage(message), scratch); + } + + const google::protobuf::MessageLite& GetListValue( + const google::protobuf::MessageLite& message) const override { + return reflection_.Value().GetListValue( + google::protobuf::DownCastMessage(message)); + } + + int ValuesSize(const google::protobuf::MessageLite& message) const override { + return reflection_.ListValue().ValuesSize( + google::protobuf::DownCastMessage(message)); + } + + const google::protobuf::MessageLite& Values(const google::protobuf::MessageLite& message, + int index) const override { + return reflection_.ListValue().Values( + google::protobuf::DownCastMessage(message), index); + } + + const google::protobuf::MessageLite& GetStructValue( + const google::protobuf::MessageLite& message) const override { + return reflection_.Value().GetStructValue( + google::protobuf::DownCastMessage(message)); + } + + int FieldsSize(const google::protobuf::MessageLite& message) const override { + return reflection_.Struct().FieldsSize( + google::protobuf::DownCastMessage(message)); + } + + const google::protobuf::MessageLite* absl_nullable FindField( + const google::protobuf::MessageLite& message, + absl::string_view name) const override { + return reflection_.Struct().FindField( + google::protobuf::DownCastMessage(message), name); + } + + JsonMapIterator IterateFields( + const google::protobuf::MessageLite& message) const override { + return reflection_.Struct().BeginFields( + google::protobuf::DownCastMessage(message)); + } + + private: + JsonReflection reflection_; +}; + +std::string JsonStringDebugString(const well_known_types::StringValue& value) { + return absl::visit(absl::Overload( + [&](absl::string_view string) -> std::string { + return FormatStringLiteral(string); + }, + [&](const absl::Cord& cord) -> std::string { + return FormatStringLiteral(cord); + }), + well_known_types::AsVariant(value)); +} + +std::string JsonNumberDebugString(double value) { + if (std::isfinite(value)) { + if (std::floor(value) != value) { + // The double is not representable as a whole number, so use + // absl::StrCat which will add decimal places. + return absl::StrCat(value); + } + // absl::StrCat historically would represent 0.0 as 0, and we want the + // decimal places so ZetaSQL correctly assumes the type as double + // instead of int64. + std::string stringified = absl::StrCat(value); + if (!absl::StrContains(stringified, '.')) { + absl::StrAppend(&stringified, ".0"); + } else { + // absl::StrCat has a decimal now? Use it directly. + } + return stringified; + } + if (std::isnan(value)) { + return "nan"; + } + if (std::signbit(value)) { + return "-infinity"; + } + return "+infinity"; +} + +class JsonDebugStringState final { + public: + JsonDebugStringState(const JsonAccessor* absl_nonnull accessor, + std::string* absl_nonnull output) + : accessor_(accessor), output_(output) {} + + void ValueDebugString(const google::protobuf::MessageLite& message) { + const auto kind_case = accessor_->GetKindCase(message); + switch (kind_case) { + case google::protobuf::Value::KIND_NOT_SET: + ABSL_FALLTHROUGH_INTENDED; + case google::protobuf::Value::kNullValue: + output_->append("null"); + break; + case google::protobuf::Value::kBoolValue: + if (accessor_->GetBoolValue(message)) { + output_->append("true"); + } else { + output_->append("false"); + } + break; + case google::protobuf::Value::kNumberValue: + output_->append( + JsonNumberDebugString(accessor_->GetNumberValue(message))); + break; + case google::protobuf::Value::kStringValue: + output_->append(JsonStringDebugString( + accessor_->GetStringValue(message, scratch_))); + break; + case google::protobuf::Value::kListValue: + ListValueDebugString(accessor_->GetListValue(message)); + break; + case google::protobuf::Value::kStructValue: + StructDebugString(accessor_->GetStructValue(message)); + break; + default: + // Should not get here, but if for some terrible reason + // `google.protobuf.Value` is expanded, just skip. + break; + } + } + + void ListValueDebugString(const google::protobuf::MessageLite& message) { + const int size = accessor_->ValuesSize(message); + output_->push_back('['); + for (int i = 0; i < size; ++i) { + if (i > 0) { + output_->append(", "); + } + ValueDebugString(accessor_->Values(message, i)); + } + output_->push_back(']'); + } + + void StructDebugString(const google::protobuf::MessageLite& message) { + const int size = accessor_->FieldsSize(message); + std::string key_scratch; + well_known_types::StringValue key; + const google::protobuf::MessageLite* absl_nonnull value; + auto iterator = accessor_->IterateFields(message); + output_->push_back('{'); + for (int i = 0; i < size; ++i) { + if (i > 0) { + output_->append(", "); + } + std::tie(key, value) = iterator.Next(key_scratch); + output_->append(JsonStringDebugString(key)); + output_->append(": "); + ValueDebugString(*value); + } + output_->push_back('}'); + } + + private: + const JsonAccessor* absl_nonnull const accessor_; + std::string* absl_nonnull const output_; + std::string scratch_; +}; + +} // namespace + +std::string JsonDebugString(const google::protobuf::Value& message) { + std::string output; + JsonDebugStringState(GeneratedJsonAccessor::Singleton(), &output) + .ValueDebugString(message); + return output; +} + +std::string JsonDebugString(const google::protobuf::Message& message) { + DynamicJsonAccessor accessor; + accessor.InitializeValue(message); + std::string output; + JsonDebugStringState(&accessor, &output).ValueDebugString(message); + return output; +} + +std::string JsonListDebugString(const google::protobuf::ListValue& message) { + std::string output; + JsonDebugStringState(GeneratedJsonAccessor::Singleton(), &output) + .ListValueDebugString(message); + return output; +} + +std::string JsonListDebugString(const google::protobuf::Message& message) { + DynamicJsonAccessor accessor; + accessor.InitializeListValue(message); + std::string output; + JsonDebugStringState(&accessor, &output).ListValueDebugString(message); + return output; +} + +std::string JsonMapDebugString(const google::protobuf::Struct& message) { + std::string output; + JsonDebugStringState(GeneratedJsonAccessor::Singleton(), &output) + .StructDebugString(message); + return output; +} + +std::string JsonMapDebugString(const google::protobuf::Message& message) { + DynamicJsonAccessor accessor; + accessor.InitializeStruct(message); + std::string output; + JsonDebugStringState(&accessor, &output).StructDebugString(message); + return output; +} + +namespace { + +class JsonEqualsState final { + public: + explicit JsonEqualsState(const JsonAccessor* absl_nonnull lhs_accessor, + const JsonAccessor* absl_nonnull rhs_accessor) + : lhs_accessor_(lhs_accessor), rhs_accessor_(rhs_accessor) {} + + bool ValueEqual(const google::protobuf::MessageLite& lhs, + const google::protobuf::MessageLite& rhs) { + auto lhs_kind_case = lhs_accessor_->GetKindCase(lhs); + if (lhs_kind_case == google::protobuf::Value::KIND_NOT_SET) { + lhs_kind_case = google::protobuf::Value::kNullValue; + } + auto rhs_kind_case = rhs_accessor_->GetKindCase(rhs); + if (rhs_kind_case == google::protobuf::Value::KIND_NOT_SET) { + rhs_kind_case = google::protobuf::Value::kNullValue; + } + if (lhs_kind_case != rhs_kind_case) { + return false; + } + switch (lhs_kind_case) { + case google::protobuf::Value::KIND_NOT_SET: + ABSL_UNREACHABLE(); + case google::protobuf::Value::kNullValue: + return true; + case google::protobuf::Value::kBoolValue: + return lhs_accessor_->GetBoolValue(lhs) == + rhs_accessor_->GetBoolValue(rhs); + case google::protobuf::Value::kNumberValue: + return lhs_accessor_->GetNumberValue(lhs) == + rhs_accessor_->GetNumberValue(rhs); + case google::protobuf::Value::kStringValue: + return lhs_accessor_->GetStringValue(lhs, lhs_scratch_) == + rhs_accessor_->GetStringValue(rhs, rhs_scratch_); + case google::protobuf::Value::kListValue: + return ListValueEqual(lhs_accessor_->GetListValue(lhs), + rhs_accessor_->GetListValue(rhs)); + case google::protobuf::Value::kStructValue: + return StructEqual(lhs_accessor_->GetStructValue(lhs), + rhs_accessor_->GetStructValue(rhs)); + default: + // Should not get here, but if for some terrible reason + // `google.protobuf.Value` is expanded, default to false. + return false; + } + } + + bool ListValueEqual(const google::protobuf::MessageLite& lhs, + const google::protobuf::MessageLite& rhs) { + const int lhs_size = lhs_accessor_->ValuesSize(lhs); + const int rhs_size = rhs_accessor_->ValuesSize(rhs); + if (lhs_size != rhs_size) { + return false; + } + for (int i = 0; i < lhs_size; ++i) { + if (!ValueEqual(lhs_accessor_->Values(lhs, i), + rhs_accessor_->Values(rhs, i))) { + return false; + } + } + return true; + } + + bool StructEqual(const google::protobuf::MessageLite& lhs, + const google::protobuf::MessageLite& rhs) { + const int lhs_size = lhs_accessor_->FieldsSize(lhs); + const int rhs_size = rhs_accessor_->FieldsSize(rhs); + if (lhs_size != rhs_size) { + return false; + } + if (lhs_size == 0) { + return true; + } + std::string lhs_key_scratch; + well_known_types::StringValue lhs_key; + const google::protobuf::MessageLite* absl_nonnull lhs_value; + auto lhs_iterator = lhs_accessor_->IterateFields(lhs); + for (int i = 0; i < lhs_size; ++i) { + std::tie(lhs_key, lhs_value) = lhs_iterator.Next(lhs_key_scratch); + if (const auto* rhs_value = rhs_accessor_->FindField( + rhs, absl::visit( + absl::Overload( + [](absl::string_view string) -> absl::string_view { + return string; + }, + [&lhs_key_scratch]( + const absl::Cord& cord) -> absl::string_view { + if (auto flat = cord.TryFlat(); flat) { + return *flat; + } + absl::CopyCordToString(cord, &lhs_key_scratch); + return absl::string_view(lhs_key_scratch); + }), + AsVariant(lhs_key))); + rhs_value == nullptr || !ValueEqual(*lhs_value, *rhs_value)) { + return false; + } + } + return true; + } + + private: + const JsonAccessor* absl_nonnull const lhs_accessor_; + const JsonAccessor* absl_nonnull const rhs_accessor_; + std::string lhs_scratch_; + std::string rhs_scratch_; +}; + +} // namespace + +bool JsonEquals(const google::protobuf::Value& lhs, + const google::protobuf::Value& rhs) { + return JsonEqualsState(GeneratedJsonAccessor::Singleton(), + GeneratedJsonAccessor::Singleton()) + .ValueEqual(lhs, rhs); +} + +bool JsonEquals(const google::protobuf::Value& lhs, + const google::protobuf::Message& rhs) { + DynamicJsonAccessor rhs_accessor; + rhs_accessor.InitializeValue(rhs); + return JsonEqualsState(GeneratedJsonAccessor::Singleton(), &rhs_accessor) + .ValueEqual(lhs, rhs); +} + +bool JsonEquals(const google::protobuf::Message& lhs, + const google::protobuf::Value& rhs) { + DynamicJsonAccessor lhs_accessor; + lhs_accessor.InitializeValue(lhs); + return JsonEqualsState(&lhs_accessor, GeneratedJsonAccessor::Singleton()) + .ValueEqual(lhs, rhs); +} + +bool JsonEquals(const google::protobuf::Message& lhs, const google::protobuf::Message& rhs) { + DynamicJsonAccessor lhs_accessor; + lhs_accessor.InitializeValue(lhs); + DynamicJsonAccessor rhs_accessor; + rhs_accessor.InitializeValue(rhs); + return JsonEqualsState(&lhs_accessor, &rhs_accessor).ValueEqual(lhs, rhs); +} + +bool JsonEquals(const google::protobuf::MessageLite& lhs, + const google::protobuf::MessageLite& rhs) { + const auto* lhs_generated = + google::protobuf::DynamicCastMessage(&lhs); + const auto* rhs_generated = + google::protobuf::DynamicCastMessage(&rhs); + if (lhs_generated && rhs_generated) { + return JsonEquals(*lhs_generated, *rhs_generated); + } + if (lhs_generated) { + return JsonEquals(*lhs_generated, + google::protobuf::DownCastMessage(rhs)); + } + if (rhs_generated) { + return JsonEquals(google::protobuf::DownCastMessage(lhs), + *rhs_generated); + } + return JsonEquals(google::protobuf::DownCastMessage(lhs), + google::protobuf::DownCastMessage(rhs)); +} + +bool JsonListEquals(const google::protobuf::ListValue& lhs, + const google::protobuf::ListValue& rhs) { + return JsonEqualsState(GeneratedJsonAccessor::Singleton(), + GeneratedJsonAccessor::Singleton()) + .ListValueEqual(lhs, rhs); +} + +bool JsonListEquals(const google::protobuf::ListValue& lhs, + const google::protobuf::Message& rhs) { + DynamicJsonAccessor rhs_accessor; + rhs_accessor.InitializeListValue(rhs); + return JsonEqualsState(GeneratedJsonAccessor::Singleton(), &rhs_accessor) + .ListValueEqual(lhs, rhs); +} + +bool JsonListEquals(const google::protobuf::Message& lhs, + const google::protobuf::ListValue& rhs) { + DynamicJsonAccessor lhs_accessor; + lhs_accessor.InitializeListValue(lhs); + return JsonEqualsState(&lhs_accessor, GeneratedJsonAccessor::Singleton()) + .ListValueEqual(lhs, rhs); +} + +bool JsonListEquals(const google::protobuf::Message& lhs, const google::protobuf::Message& rhs) { + DynamicJsonAccessor lhs_accessor; + lhs_accessor.InitializeListValue(lhs); + DynamicJsonAccessor rhs_accessor; + rhs_accessor.InitializeListValue(rhs); + return JsonEqualsState(&lhs_accessor, &rhs_accessor).ListValueEqual(lhs, rhs); +} + +bool JsonListEquals(const google::protobuf::MessageLite& lhs, + const google::protobuf::MessageLite& rhs) { + const auto* lhs_generated = + google::protobuf::DynamicCastMessage(&lhs); + const auto* rhs_generated = + google::protobuf::DynamicCastMessage(&rhs); + if (lhs_generated && rhs_generated) { + return JsonListEquals(*lhs_generated, *rhs_generated); + } + if (lhs_generated) { + return JsonListEquals(*lhs_generated, + google::protobuf::DownCastMessage(rhs)); + } + if (rhs_generated) { + return JsonListEquals(google::protobuf::DownCastMessage(lhs), + *rhs_generated); + } + return JsonListEquals(google::protobuf::DownCastMessage(lhs), + google::protobuf::DownCastMessage(rhs)); +} + +bool JsonMapEquals(const google::protobuf::Struct& lhs, + const google::protobuf::Struct& rhs) { + return JsonEqualsState(GeneratedJsonAccessor::Singleton(), + GeneratedJsonAccessor::Singleton()) + .StructEqual(lhs, rhs); +} + +bool JsonMapEquals(const google::protobuf::Struct& lhs, + const google::protobuf::Message& rhs) { + DynamicJsonAccessor rhs_accessor; + rhs_accessor.InitializeStruct(rhs); + return JsonEqualsState(GeneratedJsonAccessor::Singleton(), &rhs_accessor) + .StructEqual(lhs, rhs); +} + +bool JsonMapEquals(const google::protobuf::Message& lhs, + const google::protobuf::Struct& rhs) { + DynamicJsonAccessor lhs_accessor; + lhs_accessor.InitializeStruct(lhs); + return JsonEqualsState(&lhs_accessor, GeneratedJsonAccessor::Singleton()) + .StructEqual(lhs, rhs); +} + +bool JsonMapEquals(const google::protobuf::Message& lhs, const google::protobuf::Message& rhs) { + DynamicJsonAccessor lhs_accessor; + lhs_accessor.InitializeStruct(lhs); + DynamicJsonAccessor rhs_accessor; + rhs_accessor.InitializeStruct(rhs); + return JsonEqualsState(&lhs_accessor, &rhs_accessor).StructEqual(lhs, rhs); +} + +bool JsonMapEquals(const google::protobuf::MessageLite& lhs, + const google::protobuf::MessageLite& rhs) { + const auto* lhs_generated = + google::protobuf::DynamicCastMessage(&lhs); + const auto* rhs_generated = + google::protobuf::DynamicCastMessage(&rhs); + if (lhs_generated && rhs_generated) { + return JsonMapEquals(*lhs_generated, *rhs_generated); + } + if (lhs_generated) { + return JsonMapEquals(*lhs_generated, + google::protobuf::DownCastMessage(rhs)); + } + if (rhs_generated) { + return JsonMapEquals(google::protobuf::DownCastMessage(lhs), + *rhs_generated); + } + return JsonMapEquals(google::protobuf::DownCastMessage(lhs), + google::protobuf::DownCastMessage(rhs)); +} + +} // namespace cel::internal diff --git a/internal/json.h b/internal/json.h new file mode 100644 index 000000000..e35909d0e --- /dev/null +++ b/internal/json.h @@ -0,0 +1,141 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_JSON_H_ +#define THIRD_PARTY_CEL_CPP_INTERNAL_JSON_H_ + +#include + +#include "google/protobuf/struct.pb.h" +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel::internal { + +// Converts the given message to its `google.protobuf.Value` equivalent +// representation. This is similar to `google::protobuf::json::MessageToJsonString()`, +// except that this results in structured serialization. +absl::Status MessageToJson( + const google::protobuf::Message& message, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Value* absl_nonnull result); +absl::Status MessageToJson( + const google::protobuf::Message& message, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Struct* absl_nonnull result); +absl::Status MessageToJson( + const google::protobuf::Message& message, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull result); + +// Converts the given message field to its `google.protobuf.Value` equivalent +// representation. This is similar to `google::protobuf::json::MessageToJsonString()`, +// except that this results in structured serialization. +absl::Status MessageFieldToJson( + const google::protobuf::Message& message, + const google::protobuf::FieldDescriptor* absl_nonnull field, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Value* absl_nonnull result); +absl::Status MessageFieldToJson( + const google::protobuf::Message& message, + const google::protobuf::FieldDescriptor* absl_nonnull field, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::ListValue* absl_nonnull result); +absl::Status MessageFieldToJson( + const google::protobuf::Message& message, + const google::protobuf::FieldDescriptor* absl_nonnull field, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Struct* absl_nonnull result); +absl::Status MessageFieldToJson( + const google::protobuf::Message& message, + const google::protobuf::FieldDescriptor* absl_nonnull field, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Message* absl_nonnull result); + +// Checks that the instance of `google.protobuf.Value` has a descriptor which is +// well formed. +inline absl::Status CheckJson(const google::protobuf::Value&) { + return absl::OkStatus(); +} +absl::Status CheckJson(const google::protobuf::MessageLite& message); + +// Checks that the instance of `google.protobuf.ListValue` has a descriptor +// which is well formed. +inline absl::Status CheckJsonList(const google::protobuf::ListValue&) { + return absl::OkStatus(); +} +absl::Status CheckJsonList(const google::protobuf::MessageLite& message); + +// Checks that the instance of `google.protobuf.Struct` has a descriptor which +// is well formed. +inline absl::Status CheckJsonMap(const google::protobuf::Struct&) { + return absl::OkStatus(); +} +absl::Status CheckJsonMap(const google::protobuf::MessageLite& message); + +// Produces a debug string for the given instance of `google.protobuf.Value`. +std::string JsonDebugString(const google::protobuf::Value& message); +std::string JsonDebugString(const google::protobuf::Message& message); + +// Produces a debug string for the given instance of +// `google.protobuf.ListValue`. +std::string JsonListDebugString(const google::protobuf::ListValue& message); +std::string JsonListDebugString(const google::protobuf::Message& message); + +// Produces a debug string for the given instance of `google.protobuf.Struct`. +std::string JsonMapDebugString(const google::protobuf::Struct& message); +std::string JsonMapDebugString(const google::protobuf::Message& message); + +// Compares the given instances of `google.protobuf.Value` for equality. +bool JsonEquals(const google::protobuf::Value& lhs, + const google::protobuf::Value& rhs); +bool JsonEquals(const google::protobuf::Value& lhs, const google::protobuf::Message& rhs); +bool JsonEquals(const google::protobuf::Message& lhs, const google::protobuf::Value& rhs); +bool JsonEquals(const google::protobuf::Message& lhs, const google::protobuf::Message& rhs); +bool JsonEquals(const google::protobuf::MessageLite& lhs, const google::protobuf::MessageLite& rhs); + +// Compares the given instances of `google.protobuf.ListValue` for equality. +bool JsonListEquals(const google::protobuf::ListValue& lhs, + const google::protobuf::ListValue& rhs); +bool JsonListEquals(const google::protobuf::ListValue& lhs, + const google::protobuf::Message& rhs); +bool JsonListEquals(const google::protobuf::Message& lhs, + const google::protobuf::ListValue& rhs); +bool JsonListEquals(const google::protobuf::Message& lhs, const google::protobuf::Message& rhs); +bool JsonListEquals(const google::protobuf::MessageLite& lhs, + const google::protobuf::MessageLite& rhs); + +// Compares the given instances of `google.protobuf.Struct` for equality. +bool JsonMapEquals(const google::protobuf::Struct& lhs, + const google::protobuf::Struct& rhs); +bool JsonMapEquals(const google::protobuf::Struct& lhs, + const google::protobuf::Message& rhs); +bool JsonMapEquals(const google::protobuf::Message& lhs, + const google::protobuf::Struct& rhs); +bool JsonMapEquals(const google::protobuf::Message& lhs, const google::protobuf::Message& rhs); +bool JsonMapEquals(const google::protobuf::MessageLite& lhs, + const google::protobuf::MessageLite& rhs); + +} // namespace cel::internal + +#endif // THIRD_PARTY_CEL_CPP_INTERNAL_JSON_H_ diff --git a/internal/json_test.cc b/internal/json_test.cc new file mode 100644 index 000000000..5f88b117a --- /dev/null +++ b/internal/json_test.cc @@ -0,0 +1,2990 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "internal/json.h" + +#include "google/protobuf/any.pb.h" +#include "google/protobuf/duration.pb.h" +#include "google/protobuf/field_mask.pb.h" +#include "google/protobuf/struct.pb.h" +#include "google/protobuf/timestamp.pb.h" +#include "google/protobuf/wrappers.pb.h" +#include "absl/base/nullability.h" +#include "absl/log/die_if_null.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/strings/string_view.h" +#include "internal/equals_text_proto.h" +#include "internal/message_type_name.h" +#include "internal/parse_text_proto.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "internal/testing_message_factory.h" +#include "cel/expr/conformance/proto3/test_all_types.pb.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel::internal { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::testing::AnyOf; +using ::testing::HasSubstr; +using ::testing::Test; + +using TestAllTypesProto3 = ::cel::expr::conformance::proto3::TestAllTypes; + +class CheckJsonTest : public Test { + public: + google::protobuf::Arena* absl_nonnull arena() { return &arena_; } + + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool() { + return GetTestingDescriptorPool(); + } + + google::protobuf::MessageFactory* absl_nonnull message_factory() { + return GetTestingMessageFactory(); + } + + template + T* MakeGenerated() { + return google::protobuf::Arena::Create(arena()); + } + + template + google::protobuf::Message* MakeDynamic() { + const auto* descriptor = ABSL_DIE_IF_NULL( + descriptor_pool()->FindMessageTypeByName(MessageTypeNameFor())); + const auto* prototype = + ABSL_DIE_IF_NULL(message_factory()->GetPrototype(descriptor)); + return ABSL_DIE_IF_NULL(prototype->New(arena())); + } + + private: + google::protobuf::Arena arena_; +}; + +TEST_F(CheckJsonTest, Value_Generated) { + EXPECT_THAT(CheckJson(*MakeGenerated()), IsOk()); +} + +TEST_F(CheckJsonTest, Value_Dynamic) { + EXPECT_THAT(CheckJson(*MakeDynamic()), IsOk()); +} + +TEST_F(CheckJsonTest, ListValue_Generated) { + EXPECT_THAT(CheckJsonList(*MakeGenerated()), + IsOk()); +} + +TEST_F(CheckJsonTest, ListValue_Dynamic) { + EXPECT_THAT(CheckJsonList(*MakeDynamic()), + IsOk()); +} + +TEST_F(CheckJsonTest, Struct_Generated) { + EXPECT_THAT(CheckJsonMap(*MakeGenerated()), IsOk()); +} + +TEST_F(CheckJsonTest, Struct_Dynamic) { + EXPECT_THAT(CheckJsonMap(*MakeDynamic()), IsOk()); +} + +class MessageToJsonTest : public Test { + public: + google::protobuf::Arena* absl_nonnull arena() { return &arena_; } + + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool() { + return GetTestingDescriptorPool(); + } + + google::protobuf::MessageFactory* absl_nonnull message_factory() { + return GetTestingMessageFactory(); + } + + template + T* MakeGenerated() { + return google::protobuf::Arena::Create(arena()); + } + + template + google::protobuf::Message* MakeDynamic() { + const auto* descriptor = ABSL_DIE_IF_NULL( + descriptor_pool()->FindMessageTypeByName(MessageTypeNameFor())); + const auto* prototype = + ABSL_DIE_IF_NULL(message_factory()->GetPrototype(descriptor)); + return ABSL_DIE_IF_NULL(prototype->New(arena())); + } + + template + auto DynamicParseTextProto(absl::string_view text) { + return ::cel::internal::DynamicParseTextProto( + arena(), text, descriptor_pool(), message_factory()); + } + + template + auto EqualsTextProto(absl::string_view text) { + return ::cel::internal::EqualsTextProto(arena(), text, descriptor_pool(), + message_factory()); + } + + private: + google::protobuf::Arena arena_; +}; + +TEST_F(MessageToJsonTest, BoolValue_Generated) { + auto* result = MakeGenerated(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(value: true)pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(bool_value: true)pb")); +} + +TEST_F(MessageToJsonTest, BoolValue_Dynamic) { + auto* result = MakeDynamic(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(value: true)pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(bool_value: true)pb")); +} + +TEST_F(MessageToJsonTest, Int32Value_Generated) { + auto* result = MakeGenerated(); + EXPECT_THAT( + MessageToJson(*DynamicParseTextProto( + R"pb(value: 1)pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(number_value: 1.0)pb")); +} + +TEST_F(MessageToJsonTest, Int32Value_Dynamic) { + auto* result = MakeDynamic(); + EXPECT_THAT( + MessageToJson(*DynamicParseTextProto( + R"pb(value: 1)pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(number_value: 1.0)pb")); +} + +TEST_F(MessageToJsonTest, Int64Value_Generated) { + auto* result = MakeGenerated(); + EXPECT_THAT( + MessageToJson(*DynamicParseTextProto( + R"pb(value: 1)pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(number_value: 1.0)pb")); +} + +TEST_F(MessageToJsonTest, Int64Value_Dynamic) { + auto* result = MakeDynamic(); + EXPECT_THAT( + MessageToJson(*DynamicParseTextProto( + R"pb(value: 1)pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(number_value: 1.0)pb")); +} + +TEST_F(MessageToJsonTest, UInt32Value_Generated) { + auto* result = MakeGenerated(); + EXPECT_THAT( + MessageToJson(*DynamicParseTextProto( + R"pb(value: 1)pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(number_value: 1.0)pb")); +} + +TEST_F(MessageToJsonTest, UInt32Value_Dynamic) { + auto* result = MakeDynamic(); + EXPECT_THAT( + MessageToJson(*DynamicParseTextProto( + R"pb(value: 1)pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(number_value: 1.0)pb")); +} + +TEST_F(MessageToJsonTest, UInt64Value_Generated) { + auto* result = MakeGenerated(); + EXPECT_THAT( + MessageToJson(*DynamicParseTextProto( + R"pb(value: 1)pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(number_value: 1.0)pb")); +} + +TEST_F(MessageToJsonTest, UInt64Value_Dynamic) { + auto* result = MakeDynamic(); + EXPECT_THAT( + MessageToJson(*DynamicParseTextProto( + R"pb(value: 1)pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(number_value: 1.0)pb")); +} + +TEST_F(MessageToJsonTest, FloatValue_Generated) { + auto* result = MakeGenerated(); + EXPECT_THAT( + MessageToJson(*DynamicParseTextProto( + R"pb(value: 1.0)pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(number_value: 1.0)pb")); +} + +TEST_F(MessageToJsonTest, FloatValue_Dynamic) { + auto* result = MakeDynamic(); + EXPECT_THAT( + MessageToJson(*DynamicParseTextProto( + R"pb(value: 1.0)pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(number_value: 1.0)pb")); +} + +TEST_F(MessageToJsonTest, DoubleValue_Generated) { + auto* result = MakeGenerated(); + EXPECT_THAT( + MessageToJson(*DynamicParseTextProto( + R"pb(value: 1.0)pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(number_value: 1.0)pb")); +} + +TEST_F(MessageToJsonTest, DoubleValue_Dynamic) { + auto* result = MakeDynamic(); + EXPECT_THAT( + MessageToJson(*DynamicParseTextProto( + R"pb(value: 1.0)pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(number_value: 1.0)pb")); +} + +TEST_F(MessageToJsonTest, BytesValue_Generated) { + auto* result = MakeGenerated(); + EXPECT_THAT( + MessageToJson(*DynamicParseTextProto( + R"pb(value: "foo")pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(string_value: "Zm9v")pb")); +} + +TEST_F(MessageToJsonTest, BytesValue_Dynamic) { + auto* result = MakeDynamic(); + EXPECT_THAT( + MessageToJson(*DynamicParseTextProto( + R"pb(value: "foo")pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(string_value: "Zm9v")pb")); +} + +TEST_F(MessageToJsonTest, StringValue_Generated) { + auto* result = MakeGenerated(); + EXPECT_THAT( + MessageToJson(*DynamicParseTextProto( + R"pb(value: "foo")pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(string_value: "foo")pb")); +} + +TEST_F(MessageToJsonTest, StringValue_Dynamic) { + auto* result = MakeDynamic(); + EXPECT_THAT( + MessageToJson(*DynamicParseTextProto( + R"pb(value: "foo")pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(string_value: "foo")pb")); +} + +TEST_F(MessageToJsonTest, Duration_Generated) { + auto* result = MakeGenerated(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(seconds: 1 nanos: 1)pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(string_value: "1.000000001s")pb")); +} + +TEST_F(MessageToJsonTest, Duration_Dynamic) { + auto* result = MakeDynamic(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(seconds: 1 nanos: 1)pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(string_value: "1.000000001s")pb")); +} + +TEST_F(MessageToJsonTest, Timestamp_Generated) { + auto* result = MakeGenerated(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(seconds: 1 nanos: 1)pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, + EqualsTextProto( + R"pb(string_value: "1970-01-01T00:00:01.000000001Z")pb")); +} + +TEST_F(MessageToJsonTest, Timestamp_Dynamic) { + auto* result = MakeDynamic(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(seconds: 1 nanos: 1)pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, + EqualsTextProto( + R"pb(string_value: "1970-01-01T00:00:01.000000001Z")pb")); +} + +TEST_F(MessageToJsonTest, Value_Generated) { + auto* result = MakeGenerated(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(bool_value: true)pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(bool_value: true)pb")); +} + +TEST_F(MessageToJsonTest, Value_Dynamic) { + auto* result = MakeDynamic(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(bool_value: true)pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(bool_value: true)pb")); +} + +TEST_F(MessageToJsonTest, ListValue_Generated) { + auto* result = MakeGenerated(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(values { bool_value: true })pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, + EqualsTextProto( + R"pb(list_value: { values { bool_value: true } })pb")); +} + +TEST_F(MessageToJsonTest, ListValue_Dynamic) { + auto* result = MakeDynamic(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(values { bool_value: true })pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, + EqualsTextProto( + R"pb(list_value: { values { bool_value: true } })pb")); +} + +TEST_F(MessageToJsonTest, Struct_Generated) { + auto* result = MakeGenerated(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(fields { + key: "foo" + value: { bool_value: true } + })pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(struct_value: { + fields { + key: "foo" + value: { bool_value: true } + } + })pb")); +} + +TEST_F(MessageToJsonTest, Struct_Dynamic) { + auto* result = MakeDynamic(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(fields { + key: "foo" + value: { bool_value: true } + })pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(struct_value: { + fields { + key: "foo" + value: { bool_value: true } + } + })pb")); +} + +TEST_F(MessageToJsonTest, FieldMask_Generated) { + auto* result = MakeGenerated(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(paths: "foo" paths: "bar")pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(string_value: "foo,bar")pb")); +} + +TEST_F(MessageToJsonTest, FieldMask_Dynamic) { + auto* result = MakeDynamic(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(paths: "foo" paths: "bar")pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(string_value: "foo,bar")pb")); +} + +TEST_F(MessageToJsonTest, FieldMask_BadUpperCase) { + auto* result = MakeDynamic(); + EXPECT_THAT( + MessageToJson(*DynamicParseTextProto( + R"pb(paths: "Foo")pb"), + descriptor_pool(), message_factory(), result), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("field mask path name contains uppercase letters"))); +} + +TEST_F(MessageToJsonTest, FieldMask_BadUnderscoreUpperCase) { + auto* result = MakeDynamic(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(paths: "foo_?")pb"), + descriptor_pool(), message_factory(), result), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("field mask path contains '_' not followed by " + "a lowercase letter"))); +} + +TEST_F(MessageToJsonTest, FieldMask_BadTrailingUnderscore) { + auto* result = MakeDynamic(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(paths: "foo_")pb"), + descriptor_pool(), message_factory(), result), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("field mask path contains trailing '_'"))); +} + +TEST_F(MessageToJsonTest, Any_WellKnownType_Generated) { + auto* result = MakeGenerated(); + EXPECT_THAT( + MessageToJson( + *DynamicParseTextProto( + R"pb(type_url: "type.googleapis.com/google.protobuf.BoolValue" + value: "\x08\x01")pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT( + *result, + EqualsTextProto( + R"pb(struct_value: { + fields { + key: "@type" + value: { + string_value: "type.googleapis.com/google.protobuf.BoolValue" + } + } + fields { + key: "value" + value: { bool_value: true } + } + })pb")); +} + +TEST_F(MessageToJsonTest, Any_WellKnownType_Dynamic) { + auto* result = MakeDynamic(); + EXPECT_THAT( + MessageToJson( + *DynamicParseTextProto( + R"pb(type_url: "type.googleapis.com/google.protobuf.BoolValue" + value: "\x08\x01")pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT( + *result, + EqualsTextProto( + R"pb(struct_value: { + fields { + key: "@type" + value: { + string_value: "type.googleapis.com/google.protobuf.BoolValue" + } + } + fields { + key: "value" + value: { bool_value: true } + } + })pb")); +} + +TEST_F(MessageToJsonTest, Any_Empty_Generated) { + auto* result = MakeGenerated(); + EXPECT_THAT( + MessageToJson( + *DynamicParseTextProto( + R"pb(type_url: "type.googleapis.com/google.protobuf.Empty")pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT( + *result, + EqualsTextProto( + R"pb(struct_value: { + fields { + key: "@type" + value: { + string_value: "type.googleapis.com/google.protobuf.Empty" + } + } + fields { + key: "value" + value: { struct_value: {} } + } + })pb")); +} + +TEST_F(MessageToJsonTest, Any_Empty_Dynamic) { + auto* result = MakeDynamic(); + EXPECT_THAT( + MessageToJson( + *DynamicParseTextProto( + R"pb(type_url: "type.googleapis.com/google.protobuf.Empty")pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT( + *result, + EqualsTextProto( + R"pb(struct_value: { + fields { + key: "@type" + value: { + string_value: "type.googleapis.com/google.protobuf.Empty" + } + } + fields { + key: "value" + value: { struct_value: {} } + } + })pb")); +} + +TEST_F(MessageToJsonTest, Any_Generated) { + auto* result = MakeGenerated(); + EXPECT_THAT( + MessageToJson( + *DynamicParseTextProto( + R"pb(type_url: "type.googleapis.com/cel.expr.conformance.proto3.TestAllTypes" + value: "\x68\x01")pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT( + *result, + EqualsTextProto( + R"pb(struct_value: { + fields { + key: "@type" + value: { + string_value: "type.googleapis.com/cel.expr.conformance.proto3.TestAllTypes" + } + } + fields { + key: "singleBool" + value: { bool_value: true } + } + })pb")); +} + +TEST_F(MessageToJsonTest, Any_Dynamic) { + auto* result = MakeDynamic(); + EXPECT_THAT( + MessageToJson( + *DynamicParseTextProto( + R"pb(type_url: "type.googleapis.com/cel.expr.conformance.proto3.TestAllTypes" + value: "\x68\x01")pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT( + *result, + EqualsTextProto( + R"pb(struct_value: { + fields { + key: "@type" + value: { + string_value: "type.googleapis.com/cel.expr.conformance.proto3.TestAllTypes" + } + } + fields { + key: "singleBool" + value: { bool_value: true } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_Bool_Generated) { + auto* result = MakeGenerated(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(single_bool: true)pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(struct_value: { + fields { + key: "singleBool" + value: { bool_value: true } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_Bool_Dynamic) { + auto* result = MakeDynamic(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(single_bool: true)pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(struct_value: { + fields { + key: "singleBool" + value: { bool_value: true } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_Int32_Generated) { + auto* result = MakeGenerated(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(single_int32: 1)pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(struct_value: { + fields { + key: "singleInt32" + value: { number_value: 1.0 } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_Int32_Dynamic) { + auto* result = MakeDynamic(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(single_int32: 1)pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(struct_value: { + fields { + key: "singleInt32" + value: { number_value: 1.0 } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_Int64_Generated) { + auto* result = MakeGenerated(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(single_int64: 1)pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(struct_value: { + fields { + key: "singleInt64" + value: { number_value: 1.0 } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_Int64_Dynamic) { + auto* result = MakeDynamic(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(single_int64: 1)pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(struct_value: { + fields { + key: "singleInt64" + value: { number_value: 1.0 } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_UInt32_Generated) { + auto* result = MakeGenerated(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(single_uint32: 1)pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(struct_value: { + fields { + key: "singleUint32" + value: { number_value: 1.0 } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_UInt32_Dynamic) { + auto* result = MakeDynamic(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(single_uint32: 1)pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(struct_value: { + fields { + key: "singleUint32" + value: { number_value: 1.0 } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_UInt64_Generated) { + auto* result = MakeGenerated(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(single_uint64: 1)pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(struct_value: { + fields { + key: "singleUint64" + value: { number_value: 1.0 } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_UInt64_Dynamic) { + auto* result = MakeDynamic(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(single_uint64: 1)pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(struct_value: { + fields { + key: "singleUint64" + value: { number_value: 1.0 } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_Float_Generated) { + auto* result = MakeGenerated(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(single_float: 1.0)pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(struct_value: { + fields { + key: "singleFloat" + value: { number_value: 1.0 } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_Float_Dynamic) { + auto* result = MakeDynamic(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(single_float: 1.0)pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(struct_value: { + fields { + key: "singleFloat" + value: { number_value: 1.0 } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_Double_Generated) { + auto* result = MakeGenerated(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(single_double: 1.0)pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(struct_value: { + fields { + key: "singleDouble" + value: { number_value: 1.0 } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_Double_Dynamic) { + auto* result = MakeDynamic(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(single_double: 1.0)pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(struct_value: { + fields { + key: "singleDouble" + value: { number_value: 1.0 } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_Bytes_Generated) { + auto* result = MakeGenerated(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(single_bytes: "foo")pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(struct_value: { + fields { + key: "singleBytes" + value: { string_value: "Zm9v" } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_Bytes_Dynamic) { + auto* result = MakeDynamic(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(single_bytes: "foo")pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(struct_value: { + fields { + key: "singleBytes" + value: { string_value: "Zm9v" } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_String_Generated) { + auto* result = MakeGenerated(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(single_string: "foo")pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(struct_value: { + fields { + key: "singleString" + value: { string_value: "foo" } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_String_Dynamic) { + auto* result = MakeDynamic(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(single_string: "foo")pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(struct_value: { + fields { + key: "singleString" + value: { string_value: "foo" } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_Message_Generated) { + auto* result = MakeGenerated(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(standalone_message: { bb: 1 })pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(struct_value: { + fields { + key: "standaloneMessage" + value: { + struct_value: { + fields { + key: "bb" + value: { number_value: 1.0 } + } + } + } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_Message_Dynamic) { + auto* result = MakeDynamic(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(standalone_message: { bb: 1 })pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(struct_value: { + fields { + key: "standaloneMessage" + value: { + struct_value: { + fields { + key: "bb" + value: { number_value: 1.0 } + } + } + } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_Enum_Generated) { + auto* result = MakeGenerated(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(standalone_enum: BAR)pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(struct_value: { + fields { + key: "standaloneEnum" + value: { string_value: "BAR" } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_Enum_Dynamic) { + auto* result = MakeDynamic(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(standalone_enum: BAR)pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(struct_value: { + fields { + key: "standaloneEnum" + value: { string_value: "BAR" } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_RepeatedBool_Generated) { + auto* result = MakeGenerated(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(repeated_bool: true)pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT( + *result, + EqualsTextProto( + R"pb(struct_value: { + fields { + key: "repeatedBool" + value: { list_value: { values: { bool_value: true } } } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_RepeatedBool_Dynamic) { + auto* result = MakeDynamic(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(repeated_bool: true)pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT( + *result, + EqualsTextProto( + R"pb(struct_value: { + fields { + key: "repeatedBool" + value: { list_value: { values: { bool_value: true } } } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_RepeatedInt32_Generated) { + auto* result = MakeGenerated(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(repeated_int32: 1)pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT( + *result, + EqualsTextProto( + R"pb(struct_value: { + fields { + key: "repeatedInt32" + value: { list_value: { values: { number_value: 1.0 } } } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_RepeatedInt32_Dynamic) { + auto* result = MakeDynamic(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(repeated_int32: 1)pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT( + *result, + EqualsTextProto( + R"pb(struct_value: { + fields { + key: "repeatedInt32" + value: { list_value: { values: { number_value: 1.0 } } } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_RepeatedInt64_Generated) { + auto* result = MakeGenerated(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(repeated_int64: 1)pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT( + *result, + EqualsTextProto( + R"pb(struct_value: { + fields { + key: "repeatedInt64" + value: { list_value: { values: { number_value: 1.0 } } } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_RepeatedInt64_Dynamic) { + auto* result = MakeDynamic(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(repeated_int64: 1)pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT( + *result, + EqualsTextProto( + R"pb(struct_value: { + fields { + key: "repeatedInt64" + value: { list_value: { values: { number_value: 1.0 } } } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_RepeatedUInt32_Generated) { + auto* result = MakeGenerated(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(repeated_uint32: 1)pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT( + *result, + EqualsTextProto( + R"pb(struct_value: { + fields { + key: "repeatedUint32" + value: { list_value: { values: { number_value: 1.0 } } } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_RepeatedUInt32_Dynamic) { + auto* result = MakeDynamic(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(repeated_uint32: 1)pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT( + *result, + EqualsTextProto( + R"pb(struct_value: { + fields { + key: "repeatedUint32" + value: { list_value: { values: { number_value: 1.0 } } } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_RepeatedUInt64_Generated) { + auto* result = MakeGenerated(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(repeated_uint64: 1)pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT( + *result, + EqualsTextProto( + R"pb(struct_value: { + fields { + key: "repeatedUint64" + value: { list_value: { values: { number_value: 1.0 } } } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_RepeatedUInt64_Dynamic) { + auto* result = MakeDynamic(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(repeated_uint64: 1)pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT( + *result, + EqualsTextProto( + R"pb(struct_value: { + fields { + key: "repeatedUint64" + value: { list_value: { values: { number_value: 1.0 } } } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_RepeatedFloat_Generated) { + auto* result = MakeGenerated(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(repeated_float: 1.0)pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT( + *result, + EqualsTextProto( + R"pb(struct_value: { + fields { + key: "repeatedFloat" + value: { list_value: { values: { number_value: 1.0 } } } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_RepeatedFloat_Dynamic) { + auto* result = MakeDynamic(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(repeated_float: 1.0)pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT( + *result, + EqualsTextProto( + R"pb(struct_value: { + fields { + key: "repeatedFloat" + value: { list_value: { values: { number_value: 1.0 } } } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_RepeatedDouble_Generated) { + auto* result = MakeGenerated(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(repeated_double: 1.0)pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT( + *result, + EqualsTextProto( + R"pb(struct_value: { + fields { + key: "repeatedDouble" + value: { list_value: { values: { number_value: 1.0 } } } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_RepeatedDouble_Dynamic) { + auto* result = MakeDynamic(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(repeated_double: 1.0)pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT( + *result, + EqualsTextProto( + R"pb(struct_value: { + fields { + key: "repeatedDouble" + value: { list_value: { values: { number_value: 1.0 } } } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_RepeatedBytes_Generated) { + auto* result = MakeGenerated(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(repeated_bytes: "foo")pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT( + *result, + EqualsTextProto( + R"pb(struct_value: { + fields { + key: "repeatedBytes" + value: { list_value: { values: { string_value: "Zm9v" } } } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_RepeatedBytes_Dynamic) { + auto* result = MakeDynamic(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(repeated_bytes: "foo")pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT( + *result, + EqualsTextProto( + R"pb(struct_value: { + fields { + key: "repeatedBytes" + value: { list_value: { values: { string_value: "Zm9v" } } } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_RepeatedString_Generated) { + auto* result = MakeGenerated(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(repeated_string: "foo")pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT( + *result, + EqualsTextProto( + R"pb(struct_value: { + fields { + key: "repeatedString" + value: { list_value: { values: { string_value: "foo" } } } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_RepeatedString_Dynamic) { + auto* result = MakeDynamic(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(repeated_string: "foo")pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT( + *result, + EqualsTextProto( + R"pb(struct_value: { + fields { + key: "repeatedString" + value: { list_value: { values: { string_value: "foo" } } } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_RepeatedMessage_Generated) { + auto* result = MakeGenerated(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(repeated_nested_message: { bb: 1 })pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(struct_value: { + fields { + key: "repeatedNestedMessage" + value: { + list_value: { + values: { + struct_value: { + fields { + key: "bb" + value: { number_value: 1.0 } + } + } + } + } + } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_RepeatedMessage_Dynamic) { + auto* result = MakeDynamic(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(repeated_nested_message: { bb: 1 })pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(struct_value: { + fields { + key: "repeatedNestedMessage" + value: { + list_value: { + values: { + struct_value: { + fields { + key: "bb" + value: { number_value: 1.0 } + } + } + } + } + } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_RepeatedEnum_Generated) { + auto* result = MakeGenerated(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(repeated_nested_enum: BAR)pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT( + *result, + EqualsTextProto( + R"pb(struct_value: { + fields { + key: "repeatedNestedEnum" + value: { list_value: { values: { string_value: "BAR" } } } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_RepeatedEnum_Dynamic) { + auto* result = MakeDynamic(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(repeated_nested_enum: BAR)pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT( + *result, + EqualsTextProto( + R"pb(struct_value: { + fields { + key: "repeatedNestedEnum" + value: { list_value: { values: { string_value: "BAR" } } } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_RepeatedNull_Generated) { + auto* result = MakeGenerated(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(repeated_null_value: NULL_VALUE)pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT( + *result, + EqualsTextProto( + R"pb(struct_value: { + fields { + key: "repeatedNullValue" + value: { list_value: { values: { null_value: NULL_VALUE } } } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_RepeatedNull_Dynamic) { + auto* result = MakeDynamic(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(repeated_null_value: NULL_VALUE)pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT( + *result, + EqualsTextProto( + R"pb(struct_value: { + fields { + key: "repeatedNullValue" + value: { list_value: { values: { null_value: NULL_VALUE } } } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_MapBoolBool_Generated) { + auto* result = MakeGenerated(); + EXPECT_THAT( + MessageToJson(*DynamicParseTextProto( + R"pb(map_bool_bool: { key: true value: true })pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(struct_value: { + fields { + key: "mapBoolBool" + value: { + struct_value: { + fields { + key: "true" + value: { bool_value: true } + } + } + } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_MapBoolBool_Dynamic) { + auto* result = MakeDynamic(); + EXPECT_THAT( + MessageToJson(*DynamicParseTextProto( + R"pb(map_bool_bool: { key: true value: true })pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(struct_value: { + fields { + key: "mapBoolBool" + value: { + struct_value: { + fields { + key: "true" + value: { bool_value: true } + } + } + } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_MapInt32Int32_Generated) { + auto* result = MakeGenerated(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(map_int32_int32: { key: 1 value: 1 })pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(struct_value: { + fields { + key: "mapInt32Int32" + value: { + struct_value: { + fields { + key: "1" + value: { number_value: 1.0 } + } + } + } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_MapInt32Int32_Dynamic) { + auto* result = MakeDynamic(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(map_int32_int32: { key: 1 value: 1 })pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(struct_value: { + fields { + key: "mapInt32Int32" + value: { + struct_value: { + fields { + key: "1" + value: { number_value: 1.0 } + } + } + } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_MapInt64Int64_Generated) { + auto* result = MakeGenerated(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(map_int64_int64: { key: 1 value: 1 })pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(struct_value: { + fields { + key: "mapInt64Int64" + value: { + struct_value: { + fields { + key: "1" + value: { number_value: 1.0 } + } + } + } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_MapInt64Int64_Dynamic) { + auto* result = MakeDynamic(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(map_int64_int64: { key: 1 value: 1 })pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(struct_value: { + fields { + key: "mapInt64Int64" + value: { + struct_value: { + fields { + key: "1" + value: { number_value: 1.0 } + } + } + } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_MapUInt32UInt32_Generated) { + auto* result = MakeGenerated(); + EXPECT_THAT( + MessageToJson(*DynamicParseTextProto( + R"pb(map_uint32_uint32: { key: 1 value: 1 })pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(struct_value: { + fields { + key: "mapUint32Uint32" + value: { + struct_value: { + fields { + key: "1" + value: { number_value: 1.0 } + } + } + } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_MapUInt32UInt32_Dynamic) { + auto* result = MakeDynamic(); + EXPECT_THAT( + MessageToJson(*DynamicParseTextProto( + R"pb(map_uint32_uint32: { key: 1 value: 1 })pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(struct_value: { + fields { + key: "mapUint32Uint32" + value: { + struct_value: { + fields { + key: "1" + value: { number_value: 1.0 } + } + } + } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_MapUInt64UInt64_Generated) { + auto* result = MakeGenerated(); + EXPECT_THAT( + MessageToJson(*DynamicParseTextProto( + R"pb(map_uint64_uint64: { key: 1 value: 1 })pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(struct_value: { + fields { + key: "mapUint64Uint64" + value: { + struct_value: { + fields { + key: "1" + value: { number_value: 1.0 } + } + } + } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_MapUInt64UInt64_Dynamic) { + auto* result = MakeDynamic(); + EXPECT_THAT( + MessageToJson(*DynamicParseTextProto( + R"pb(map_uint64_uint64: { key: 1 value: 1 })pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(struct_value: { + fields { + key: "mapUint64Uint64" + value: { + struct_value: { + fields { + key: "1" + value: { number_value: 1.0 } + } + } + } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_MapStringString_Generated) { + auto* result = MakeGenerated(); + EXPECT_THAT(MessageToJson( + *DynamicParseTextProto( + R"pb(map_string_string: { key: "foo" value: "bar" })pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(struct_value: { + fields { + key: "mapStringString" + value: { + struct_value: { + fields { + key: "foo" + value: { string_value: "bar" } + } + } + } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_MapStringString_Dynamic) { + auto* result = MakeDynamic(); + EXPECT_THAT(MessageToJson( + *DynamicParseTextProto( + R"pb(map_string_string: { key: "foo" value: "bar" })pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(struct_value: { + fields { + key: "mapStringString" + value: { + struct_value: { + fields { + key: "foo" + value: { string_value: "bar" } + } + } + } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_MapStringFloat_Generated) { + auto* result = MakeGenerated(); + EXPECT_THAT( + MessageToJson(*DynamicParseTextProto( + R"pb(map_string_float: { key: "foo" value: 1.0 })pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(struct_value: { + fields { + key: "mapStringFloat" + value: { + struct_value: { + fields { + key: "foo" + value: { number_value: 1.0 } + } + } + } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_MapStringFloat_Dynamic) { + auto* result = MakeDynamic(); + EXPECT_THAT( + MessageToJson(*DynamicParseTextProto( + R"pb(map_string_float: { key: "foo" value: 1.0 })pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(struct_value: { + fields { + key: "mapStringFloat" + value: { + struct_value: { + fields { + key: "foo" + value: { number_value: 1.0 } + } + } + } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_MapStringDouble_Generated) { + auto* result = MakeGenerated(); + EXPECT_THAT( + MessageToJson(*DynamicParseTextProto( + R"pb(map_string_double: { key: "foo" value: 1.0 })pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(struct_value: { + fields { + key: "mapStringDouble" + value: { + struct_value: { + fields { + key: "foo" + value: { number_value: 1.0 } + } + } + } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_MapStringDouble_Dynamic) { + auto* result = MakeDynamic(); + EXPECT_THAT( + MessageToJson(*DynamicParseTextProto( + R"pb(map_string_double: { key: "foo" value: 1.0 })pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(struct_value: { + fields { + key: "mapStringDouble" + value: { + struct_value: { + fields { + key: "foo" + value: { number_value: 1.0 } + } + } + } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_MapStringBytes_Generated) { + auto* result = MakeGenerated(); + EXPECT_THAT( + MessageToJson(*DynamicParseTextProto( + R"pb(map_string_bytes: { key: "foo" value: "bar" })pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(struct_value: { + fields { + key: "mapStringBytes" + value: { + struct_value: { + fields { + key: "foo" + value: { string_value: "YmFy" } + } + } + } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_MapStringBytes_Dynamic) { + auto* result = MakeDynamic(); + EXPECT_THAT( + MessageToJson(*DynamicParseTextProto( + R"pb(map_string_bytes: { key: "foo" value: "bar" })pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(struct_value: { + fields { + key: "mapStringBytes" + value: { + struct_value: { + fields { + key: "foo" + value: { string_value: "YmFy" } + } + } + } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_MapStringMessage_Generated) { + auto* result = MakeGenerated(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(map_string_message: { + key: "foo" + value: { bb: 1 } + })pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(struct_value: { + fields { + key: "mapStringMessage" + value: { + struct_value: { + fields { + key: "foo" + value: { + struct_value: { + fields { + key: "bb" + value: { number_value: 1.0 } + } + } + } + } + } + } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_MapStringMessage_Dynamic) { + auto* result = MakeDynamic(); + EXPECT_THAT(MessageToJson(*DynamicParseTextProto( + R"pb(map_string_message: { + key: "foo" + value: { bb: 1 } + })pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(struct_value: { + fields { + key: "mapStringMessage" + value: { + struct_value: { + fields { + key: "foo" + value: { + struct_value: { + fields { + key: "bb" + value: { number_value: 1.0 } + } + } + } + } + } + } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_MapStringEnum_Generated) { + auto* result = MakeGenerated(); + EXPECT_THAT( + MessageToJson(*DynamicParseTextProto( + R"pb(map_string_enum: { key: "foo" value: BAR })pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(struct_value: { + fields { + key: "mapStringEnum" + value: { + struct_value: { + fields { + key: "foo" + value: { string_value: "BAR" } + } + } + } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_MapStringEnum_Dynamic) { + auto* result = MakeDynamic(); + EXPECT_THAT( + MessageToJson(*DynamicParseTextProto( + R"pb(map_string_enum: { key: "foo" value: BAR })pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(struct_value: { + fields { + key: "mapStringEnum" + value: { + struct_value: { + fields { + key: "foo" + value: { string_value: "BAR" } + } + } + } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_MapStringNull_Generated) { + auto* result = MakeGenerated(); + EXPECT_THAT( + MessageToJson( + *DynamicParseTextProto( + R"pb(map_string_null_value: { key: "foo" value: NULL_VALUE })pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(struct_value: { + fields { + key: "mapStringNullValue" + value: { + struct_value: { + fields { + key: "foo" + value: { null_value: NULL_VALUE } + } + } + } + } + })pb")); +} + +TEST_F(MessageToJsonTest, TestAllTypesProto3_MapStringNull_Dynamic) { + auto* result = MakeDynamic(); + EXPECT_THAT( + MessageToJson( + *DynamicParseTextProto( + R"pb(map_string_null_value: { key: "foo" value: NULL_VALUE })pb"), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(struct_value: { + fields { + key: "mapStringNullValue" + value: { + struct_value: { + fields { + key: "foo" + value: { null_value: NULL_VALUE } + } + } + } + } + })pb")); +} + +class MessageFieldToJsonTest : public Test { + public: + google::protobuf::Arena* absl_nonnull arena() { return &arena_; } + + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool() { + return GetTestingDescriptorPool(); + } + + google::protobuf::MessageFactory* absl_nonnull message_factory() { + return GetTestingMessageFactory(); + } + + template + T* MakeGenerated() { + return google::protobuf::Arena::Create(arena()); + } + + template + google::protobuf::Message* MakeDynamic() { + const auto* descriptor = ABSL_DIE_IF_NULL( + descriptor_pool()->FindMessageTypeByName(MessageTypeNameFor())); + const auto* prototype = + ABSL_DIE_IF_NULL(message_factory()->GetPrototype(descriptor)); + return ABSL_DIE_IF_NULL(prototype->New(arena())); + } + + template + auto DynamicParseTextProto(absl::string_view text) { + return ::cel::internal::DynamicParseTextProto( + arena(), text, descriptor_pool(), message_factory()); + } + + template + auto EqualsTextProto(absl::string_view text) { + return ::cel::internal::EqualsTextProto(arena(), text, descriptor_pool(), + message_factory()); + } + + private: + google::protobuf::Arena arena_; +}; + +TEST_F(MessageFieldToJsonTest, TestAllTypesProto3_Generated) { + auto* result = MakeGenerated(); + EXPECT_THAT( + MessageFieldToJson( + *DynamicParseTextProto( + R"pb(single_bool: true)pb"), + ABSL_DIE_IF_NULL( + ABSL_DIE_IF_NULL(descriptor_pool()->FindMessageTypeByName( + "cel.expr.conformance.proto3.TestAllTypes")) + ->FindFieldByName("single_bool")), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(bool_value: true)pb")); +} + +TEST_F(MessageFieldToJsonTest, TestAllTypesProto3_Dynamic) { + auto* result = MakeDynamic(); + EXPECT_THAT( + MessageFieldToJson( + *DynamicParseTextProto( + R"pb(single_bool: true)pb"), + ABSL_DIE_IF_NULL( + ABSL_DIE_IF_NULL(descriptor_pool()->FindMessageTypeByName( + "cel.expr.conformance.proto3.TestAllTypes")) + ->FindFieldByName("single_bool")), + descriptor_pool(), message_factory(), result), + IsOk()); + EXPECT_THAT(*result, EqualsTextProto( + R"pb(bool_value: true)pb")); +} + +class JsonDebugStringTest : public Test { + public: + google::protobuf::Arena* absl_nonnull arena() { return &arena_; } + + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool() { + return GetTestingDescriptorPool(); + } + + google::protobuf::MessageFactory* absl_nonnull message_factory() { + return GetTestingMessageFactory(); + } + + template + auto GeneratedParseTextProto(absl::string_view text) { + return ::cel::internal::GeneratedParseTextProto( + arena(), text, descriptor_pool(), message_factory()); + } + + template + auto DynamicParseTextProto(absl::string_view text) { + return ::cel::internal::DynamicParseTextProto( + arena(), text, descriptor_pool(), message_factory()); + } + + private: + google::protobuf::Arena arena_; +}; + +TEST_F(JsonDebugStringTest, Null_Generated) { + EXPECT_EQ(JsonDebugString( + *GeneratedParseTextProto(R"pb()pb")), + "null"); +} + +TEST_F(JsonDebugStringTest, Null_Dynamic) { + EXPECT_EQ(JsonDebugString( + *DynamicParseTextProto(R"pb()pb")), + "null"); +} + +TEST_F(JsonDebugStringTest, Bool_Generated) { + EXPECT_EQ(JsonDebugString(*GeneratedParseTextProto( + R"pb(bool_value: false)pb")), + "false"); + EXPECT_EQ(JsonDebugString(*GeneratedParseTextProto( + R"pb(bool_value: true)pb")), + "true"); +} + +TEST_F(JsonDebugStringTest, Bool_Dynamic) { + EXPECT_EQ(JsonDebugString(*DynamicParseTextProto( + R"pb(bool_value: false)pb")), + "false"); + EXPECT_EQ(JsonDebugString(*DynamicParseTextProto( + R"pb(bool_value: true)pb")), + "true"); +} + +TEST_F(JsonDebugStringTest, Number_Generated) { + EXPECT_EQ(JsonDebugString(*GeneratedParseTextProto( + R"pb(number_value: 1.0)pb")), + "1.0"); + EXPECT_EQ(JsonDebugString(*GeneratedParseTextProto( + R"pb(number_value: 1.1)pb")), + "1.1"); + EXPECT_EQ(JsonDebugString(*GeneratedParseTextProto( + R"pb(number_value: infinity)pb")), + "+infinity"); + EXPECT_EQ(JsonDebugString(*GeneratedParseTextProto( + R"pb(number_value: -infinity)pb")), + "-infinity"); + EXPECT_EQ(JsonDebugString(*GeneratedParseTextProto( + R"pb(number_value: nan)pb")), + "nan"); +} + +TEST_F(JsonDebugStringTest, Number_Dynamic) { + EXPECT_EQ(JsonDebugString(*DynamicParseTextProto( + R"pb(number_value: 1.0)pb")), + "1.0"); + EXPECT_EQ(JsonDebugString(*DynamicParseTextProto( + R"pb(number_value: 1.1)pb")), + "1.1"); + EXPECT_EQ(JsonDebugString(*DynamicParseTextProto( + R"pb(number_value: infinity)pb")), + "+infinity"); + EXPECT_EQ(JsonDebugString(*DynamicParseTextProto( + R"pb(number_value: -infinity)pb")), + "-infinity"); + EXPECT_EQ(JsonDebugString(*DynamicParseTextProto( + R"pb(number_value: nan)pb")), + "nan"); +} + +TEST_F(JsonDebugStringTest, String_Generated) { + EXPECT_EQ(JsonDebugString(*GeneratedParseTextProto( + R"pb(string_value: "foo")pb")), + "\"foo\""); +} + +TEST_F(JsonDebugStringTest, String_Dynamic) { + EXPECT_EQ(JsonDebugString(*DynamicParseTextProto( + R"pb(string_value: "foo")pb")), + "\"foo\""); +} + +TEST_F(JsonDebugStringTest, List_Generated) { + EXPECT_EQ(JsonDebugString(*GeneratedParseTextProto( + R"pb(list_value: { + values {} + values { bool_value: true } + })pb")), + "[null, true]"); + EXPECT_EQ( + JsonListDebugString(*GeneratedParseTextProto( + R"pb( + values {} + values { bool_value: true })pb")), + "[null, true]"); +} + +TEST_F(JsonDebugStringTest, List_Dynamic) { + EXPECT_EQ(JsonDebugString(*DynamicParseTextProto( + R"pb(list_value: { + values {} + values { bool_value: true } + })pb")), + "[null, true]"); + EXPECT_EQ( + JsonListDebugString(*DynamicParseTextProto( + R"pb( + values {} + values { bool_value: true })pb")), + "[null, true]"); +} + +TEST_F(JsonDebugStringTest, Struct_Generated) { + EXPECT_THAT(JsonDebugString(*GeneratedParseTextProto( + R"pb(struct_value: { + fields { + key: "foo" + value: {} + } + fields { + key: "bar" + value: { bool_value: true } + } + })pb")), + AnyOf("{\"foo\": null, \"bar\": true}", + "{\"bar\": true, \"foo\": null}")); + EXPECT_THAT( + JsonMapDebugString(*GeneratedParseTextProto( + R"pb( + fields { + key: "foo" + value: {} + } + fields { + key: "bar" + value: { bool_value: true } + })pb")), + AnyOf("{\"foo\": null, \"bar\": true}", + "{\"bar\": true, \"foo\": null}")); +} + +TEST_F(JsonDebugStringTest, Struct_Dynamic) { + EXPECT_THAT(JsonDebugString(*DynamicParseTextProto( + R"pb(struct_value: { + fields { + key: "foo" + value: {} + } + fields { + key: "bar" + value: { bool_value: true } + } + })pb")), + AnyOf("{\"foo\": null, \"bar\": true}", + "{\"bar\": true, \"foo\": null}")); + EXPECT_THAT( + JsonMapDebugString(*DynamicParseTextProto( + R"pb( + fields { + key: "foo" + value: {} + } + fields { + key: "bar" + value: { bool_value: true } + })pb")), + AnyOf("{\"foo\": null, \"bar\": true}", + "{\"bar\": true, \"foo\": null}")); +} + +class JsonEqualsTest : public Test { + public: + google::protobuf::Arena* absl_nonnull arena() { return &arena_; } + + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool() { + return GetTestingDescriptorPool(); + } + + google::protobuf::MessageFactory* absl_nonnull message_factory() { + return GetTestingMessageFactory(); + } + + template + auto GeneratedParseTextProto(absl::string_view text) { + return ::cel::internal::GeneratedParseTextProto( + arena(), text, descriptor_pool(), message_factory()); + } + + template + auto DynamicParseTextProto(absl::string_view text) { + return ::cel::internal::DynamicParseTextProto( + arena(), text, descriptor_pool(), message_factory()); + } + + private: + google::protobuf::Arena arena_; +}; + +TEST_F(JsonEqualsTest, Null_Null_Generated_Generated) { + EXPECT_TRUE( + JsonEquals(*GeneratedParseTextProto(R"pb()pb"), + *GeneratedParseTextProto(R"pb()pb"))); +} + +TEST_F(JsonEqualsTest, Null_Null_Generated_Dynamic) { + EXPECT_TRUE( + JsonEquals(*GeneratedParseTextProto(R"pb()pb"), + *DynamicParseTextProto(R"pb()pb"))); +} + +TEST_F(JsonEqualsTest, Null_Null_Dynamic_Generated) { + EXPECT_TRUE( + JsonEquals(*DynamicParseTextProto(R"pb()pb"), + *GeneratedParseTextProto(R"pb()pb"))); +} + +TEST_F(JsonEqualsTest, Null_Null_Dynamic_Dynamic) { + EXPECT_TRUE( + JsonEquals(*DynamicParseTextProto(R"pb()pb"), + *DynamicParseTextProto(R"pb()pb"))); +} + +TEST_F(JsonEqualsTest, Bool_Bool_Generated_Generated) { + EXPECT_TRUE(JsonEquals(*GeneratedParseTextProto( + R"pb(bool_value: true)pb"), + *GeneratedParseTextProto( + R"pb(bool_value: true)pb"))); +} + +TEST_F(JsonEqualsTest, Bool_Bool_Generated_Dynamic) { + EXPECT_TRUE(JsonEquals(*GeneratedParseTextProto( + R"pb(bool_value: true)pb"), + *DynamicParseTextProto( + R"pb(bool_value: true)pb"))); +} + +TEST_F(JsonEqualsTest, Bool_Bool_Dynamic_Generated) { + EXPECT_TRUE(JsonEquals(*DynamicParseTextProto( + R"pb(bool_value: true)pb"), + *GeneratedParseTextProto( + R"pb(bool_value: true)pb"))); +} + +TEST_F(JsonEqualsTest, Bool_Bool_Dynamic_Dynamic) { + EXPECT_TRUE(JsonEquals(*DynamicParseTextProto( + R"pb(bool_value: true)pb"), + *DynamicParseTextProto( + R"pb(bool_value: true)pb"))); +} + +TEST_F(JsonEqualsTest, Number_Number_Generated_Generated) { + EXPECT_TRUE(JsonEquals(*GeneratedParseTextProto( + R"pb(number_value: 1.0)pb"), + *GeneratedParseTextProto( + R"pb(number_value: 1.0)pb"))); +} + +TEST_F(JsonEqualsTest, Number_Number_Generated_Dynamic) { + EXPECT_TRUE(JsonEquals(*GeneratedParseTextProto( + R"pb(number_value: 1.0)pb"), + *DynamicParseTextProto( + R"pb(number_value: 1.0)pb"))); +} + +TEST_F(JsonEqualsTest, Number_Number_Dynamic_Generated) { + EXPECT_TRUE(JsonEquals(*DynamicParseTextProto( + R"pb(number_value: 1.0)pb"), + *GeneratedParseTextProto( + R"pb(number_value: 1.0)pb"))); +} + +TEST_F(JsonEqualsTest, Number_Number_Dynamic_Dynamic) { + EXPECT_TRUE(JsonEquals(*DynamicParseTextProto( + R"pb(number_value: 1.0)pb"), + *DynamicParseTextProto( + R"pb(number_value: 1.0)pb"))); +} + +TEST_F(JsonEqualsTest, String_String_Generated_Generated) { + EXPECT_TRUE(JsonEquals(*GeneratedParseTextProto( + R"pb(string_value: "foo")pb"), + *GeneratedParseTextProto( + R"pb(string_value: "foo")pb"))); +} + +TEST_F(JsonEqualsTest, String_String_Generated_Dynamic) { + EXPECT_TRUE(JsonEquals(*GeneratedParseTextProto( + R"pb(string_value: "foo")pb"), + *DynamicParseTextProto( + R"pb(string_value: "foo")pb"))); +} + +TEST_F(JsonEqualsTest, String_String_Dynamic_Generated) { + EXPECT_TRUE(JsonEquals(*DynamicParseTextProto( + R"pb(string_value: "foo")pb"), + *GeneratedParseTextProto( + R"pb(string_value: "foo")pb"))); +} + +TEST_F(JsonEqualsTest, String_String_Dynamic_Dynamic) { + EXPECT_TRUE(JsonEquals(*DynamicParseTextProto( + R"pb(string_value: "foo")pb"), + *DynamicParseTextProto( + R"pb(string_value: "foo")pb"))); +} + +TEST_F(JsonEqualsTest, List_List_Generated_Generated) { + EXPECT_TRUE(JsonEquals(*GeneratedParseTextProto( + R"pb(list_value: { + values {} + values { bool_value: true } + })pb"), + *GeneratedParseTextProto( + R"pb(list_value: { + values {} + values { bool_value: true } + })pb"))); + EXPECT_TRUE(JsonEquals(static_cast( + *GeneratedParseTextProto( + R"pb(list_value: { + values {} + values { bool_value: true } + })pb")), + static_cast( + *GeneratedParseTextProto( + R"pb(list_value: { + values {} + values { bool_value: true } + })pb")))); + EXPECT_TRUE( + JsonListEquals(*GeneratedParseTextProto( + R"pb( + values {} + values { bool_value: true } + )pb"), + *GeneratedParseTextProto( + R"pb( + values {} + values { bool_value: true } + )pb"))); + EXPECT_TRUE( + JsonListEquals(static_cast( + *GeneratedParseTextProto( + R"pb( + values {} + values { bool_value: true } + )pb")), + static_cast( + *GeneratedParseTextProto( + R"pb( + values {} + values { bool_value: true } + )pb")))); +} + +TEST_F(JsonEqualsTest, List_List_Generated_Dynamic) { + EXPECT_TRUE(JsonEquals(*GeneratedParseTextProto( + R"pb(list_value: { + values {} + values { bool_value: true } + })pb"), + *DynamicParseTextProto( + R"pb(list_value: { + values {} + values { bool_value: true } + })pb"))); + EXPECT_TRUE(JsonEquals(static_cast( + *GeneratedParseTextProto( + R"pb(list_value: { + values {} + values { bool_value: true } + })pb")), + static_cast( + *DynamicParseTextProto( + R"pb(list_value: { + values {} + values { bool_value: true } + })pb")))); + EXPECT_TRUE( + JsonListEquals(*GeneratedParseTextProto( + R"pb( + values {} + values { bool_value: true } + )pb"), + *DynamicParseTextProto( + R"pb( + values {} + values { bool_value: true } + )pb"))); + EXPECT_TRUE( + JsonListEquals(static_cast( + *GeneratedParseTextProto( + R"pb( + values {} + values { bool_value: true } + )pb")), + static_cast( + *DynamicParseTextProto( + R"pb( + values {} + values { bool_value: true } + )pb")))); +} + +TEST_F(JsonEqualsTest, List_List_Dynamic_Generated) { + EXPECT_TRUE(JsonEquals(*DynamicParseTextProto( + R"pb(list_value: { + values {} + values { bool_value: true } + })pb"), + *GeneratedParseTextProto( + R"pb(list_value: { + values {} + values { bool_value: true } + })pb"))); + EXPECT_TRUE(JsonEquals(static_cast( + *DynamicParseTextProto( + R"pb(list_value: { + values {} + values { bool_value: true } + })pb")), + static_cast( + *GeneratedParseTextProto( + R"pb(list_value: { + values {} + values { bool_value: true } + })pb")))); + EXPECT_TRUE( + JsonListEquals(*DynamicParseTextProto( + R"pb( + values {} + values { bool_value: true } + )pb"), + *GeneratedParseTextProto( + R"pb( + values {} + values { bool_value: true } + )pb"))); + EXPECT_TRUE( + JsonListEquals(static_cast( + *DynamicParseTextProto( + R"pb( + values {} + values { bool_value: true } + )pb")), + static_cast( + *GeneratedParseTextProto( + R"pb( + values {} + values { bool_value: true } + )pb")))); +} + +TEST_F(JsonEqualsTest, List_List_Dynamic_Dynamic) { + EXPECT_TRUE(JsonEquals(*DynamicParseTextProto( + R"pb(list_value: { + values {} + values { bool_value: true } + })pb"), + *DynamicParseTextProto( + R"pb(list_value: { + values {} + values { bool_value: true } + })pb"))); + EXPECT_TRUE(JsonEquals(static_cast( + *DynamicParseTextProto( + R"pb(list_value: { + values {} + values { bool_value: true } + })pb")), + static_cast( + *DynamicParseTextProto( + R"pb(list_value: { + values {} + values { bool_value: true } + })pb")))); + EXPECT_TRUE( + JsonListEquals(*DynamicParseTextProto( + R"pb( + values {} + values { bool_value: true } + )pb"), + *DynamicParseTextProto( + R"pb( + values {} + values { bool_value: true } + )pb"))); + EXPECT_TRUE( + JsonListEquals(static_cast( + *DynamicParseTextProto( + R"pb( + values {} + values { bool_value: true } + )pb")), + static_cast( + *DynamicParseTextProto( + R"pb( + values {} + values { bool_value: true } + )pb")))); +} + +TEST_F(JsonEqualsTest, Map_Map_Generated_Generated) { + EXPECT_TRUE(JsonEquals(*GeneratedParseTextProto( + R"pb(struct_value: { + fields { + key: "foo" + value: {} + } + fields { + key: "bar" + value: { bool_value: true } + } + })pb"), + *GeneratedParseTextProto( + R"pb(struct_value: { + fields { + key: "foo" + value: {} + } + fields { + key: "bar" + value: { bool_value: true } + } + })pb"))); + EXPECT_TRUE(JsonEquals(static_cast( + *GeneratedParseTextProto( + R"pb(struct_value: { + fields { + key: "foo" + value: {} + } + fields { + key: "bar" + value: { bool_value: true } + } + })pb")), + static_cast( + *GeneratedParseTextProto( + R"pb(struct_value: { + fields { + key: "foo" + value: {} + } + fields { + key: "bar" + value: { bool_value: true } + } + })pb")))); + EXPECT_TRUE(JsonMapEquals(*GeneratedParseTextProto( + R"pb( + fields { + key: "foo" + value: {} + } + fields { + key: "bar" + value: { bool_value: true } + } + )pb"), + *GeneratedParseTextProto( + R"pb( + fields { + key: "foo" + value: {} + } + fields { + key: "bar" + value: { bool_value: true } + } + )pb"))); + EXPECT_TRUE( + JsonMapEquals(static_cast( + *GeneratedParseTextProto( + R"pb( + fields { + key: "foo" + value: {} + } + fields { + key: "bar" + value: { bool_value: true } + } + )pb")), + static_cast( + *GeneratedParseTextProto( + R"pb( + fields { + key: "foo" + value: {} + } + fields { + key: "bar" + value: { bool_value: true } + } + )pb")))); +} + +TEST_F(JsonEqualsTest, Map_Map_Generated_Dynamic) { + EXPECT_TRUE(JsonEquals(*GeneratedParseTextProto( + R"pb(struct_value: { + fields { + key: "foo" + value: {} + } + fields { + key: "bar" + value: { bool_value: true } + } + })pb"), + *DynamicParseTextProto( + R"pb(struct_value: { + fields { + key: "foo" + value: {} + } + fields { + key: "bar" + value: { bool_value: true } + } + })pb"))); + EXPECT_TRUE(JsonEquals(static_cast( + *GeneratedParseTextProto( + R"pb(struct_value: { + fields { + key: "foo" + value: {} + } + fields { + key: "bar" + value: { bool_value: true } + } + })pb")), + static_cast( + *DynamicParseTextProto( + R"pb(struct_value: { + fields { + key: "foo" + value: {} + } + fields { + key: "bar" + value: { bool_value: true } + } + })pb")))); + EXPECT_TRUE(JsonMapEquals(*GeneratedParseTextProto( + R"pb( + fields { + key: "foo" + value: {} + } + fields { + key: "bar" + value: { bool_value: true } + } + )pb"), + *DynamicParseTextProto( + R"pb( + fields { + key: "foo" + value: {} + } + fields { + key: "bar" + value: { bool_value: true } + } + )pb"))); + EXPECT_TRUE( + JsonMapEquals(static_cast( + *GeneratedParseTextProto( + R"pb( + fields { + key: "foo" + value: {} + } + fields { + key: "bar" + value: { bool_value: true } + } + )pb")), + static_cast( + *DynamicParseTextProto( + R"pb( + fields { + key: "foo" + value: {} + } + fields { + key: "bar" + value: { bool_value: true } + } + )pb")))); +} + +TEST_F(JsonEqualsTest, Map_Map_Dynamic_Generated) { + EXPECT_TRUE(JsonEquals(*DynamicParseTextProto( + R"pb(struct_value: { + fields { + key: "foo" + value: {} + } + fields { + key: "bar" + value: { bool_value: true } + } + })pb"), + *GeneratedParseTextProto( + R"pb(struct_value: { + fields { + key: "foo" + value: {} + } + fields { + key: "bar" + value: { bool_value: true } + } + })pb"))); + EXPECT_TRUE(JsonEquals(static_cast( + *DynamicParseTextProto( + R"pb(struct_value: { + fields { + key: "foo" + value: {} + } + fields { + key: "bar" + value: { bool_value: true } + } + })pb")), + static_cast( + *GeneratedParseTextProto( + R"pb(struct_value: { + fields { + key: "foo" + value: {} + } + fields { + key: "bar" + value: { bool_value: true } + } + })pb")))); + EXPECT_TRUE(JsonMapEquals(*DynamicParseTextProto( + R"pb( + fields { + key: "foo" + value: {} + } + fields { + key: "bar" + value: { bool_value: true } + } + )pb"), + *GeneratedParseTextProto( + R"pb( + fields { + key: "foo" + value: {} + } + fields { + key: "bar" + value: { bool_value: true } + } + )pb"))); + EXPECT_TRUE( + JsonMapEquals(static_cast( + *DynamicParseTextProto( + R"pb( + fields { + key: "foo" + value: {} + } + fields { + key: "bar" + value: { bool_value: true } + } + )pb")), + static_cast( + *GeneratedParseTextProto( + R"pb( + fields { + key: "foo" + value: {} + } + fields { + key: "bar" + value: { bool_value: true } + } + )pb")))); +} + +TEST_F(JsonEqualsTest, Map_Map_Dynamic_Dynamic) { + EXPECT_TRUE(JsonEquals(*DynamicParseTextProto( + R"pb(struct_value: { + fields { + key: "foo" + value: {} + } + fields { + key: "bar" + value: { bool_value: true } + } + })pb"), + *DynamicParseTextProto( + R"pb(struct_value: { + fields { + key: "foo" + value: {} + } + fields { + key: "bar" + value: { bool_value: true } + } + })pb"))); + EXPECT_TRUE(JsonEquals(static_cast( + *DynamicParseTextProto( + R"pb(struct_value: { + fields { + key: "foo" + value: {} + } + fields { + key: "bar" + value: { bool_value: true } + } + })pb")), + static_cast( + *DynamicParseTextProto( + R"pb(struct_value: { + fields { + key: "foo" + value: {} + } + fields { + key: "bar" + value: { bool_value: true } + } + })pb")))); + EXPECT_TRUE(JsonMapEquals(*DynamicParseTextProto( + R"pb( + fields { + key: "foo" + value: {} + } + fields { + key: "bar" + value: { bool_value: true } + } + )pb"), + *DynamicParseTextProto( + R"pb( + fields { + key: "foo" + value: {} + } + fields { + key: "bar" + value: { bool_value: true } + } + )pb"))); + EXPECT_TRUE( + JsonMapEquals(static_cast( + *DynamicParseTextProto( + R"pb( + fields { + key: "foo" + value: {} + } + fields { + key: "bar" + value: { bool_value: true } + } + )pb")), + static_cast( + *DynamicParseTextProto( + R"pb( + fields { + key: "foo" + value: {} + } + fields { + key: "bar" + value: { bool_value: true } + } + )pb")))); +} + +} // namespace +} // namespace cel::internal diff --git a/internal/lexis.cc b/internal/lexis.cc new file mode 100644 index 000000000..e81fb8e39 --- /dev/null +++ b/internal/lexis.cc @@ -0,0 +1,79 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "internal/lexis.h" + +#include "absl/base/call_once.h" +#include "absl/base/macros.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/ascii.h" + +namespace cel::internal { + +namespace { + +ABSL_CONST_INIT absl::once_flag reserved_keywords_once_flag = {}; +ABSL_CONST_INIT absl::flat_hash_set* reserved_keywords = + nullptr; + +void InitializeReservedKeywords() { + ABSL_ASSERT(reserved_keywords == nullptr); + reserved_keywords = new absl::flat_hash_set(); + reserved_keywords->insert("false"); + reserved_keywords->insert("true"); + reserved_keywords->insert("null"); + reserved_keywords->insert("in"); + reserved_keywords->insert("as"); + reserved_keywords->insert("break"); + reserved_keywords->insert("const"); + reserved_keywords->insert("continue"); + reserved_keywords->insert("else"); + reserved_keywords->insert("for"); + reserved_keywords->insert("function"); + reserved_keywords->insert("if"); + reserved_keywords->insert("import"); + reserved_keywords->insert("let"); + reserved_keywords->insert("loop"); + reserved_keywords->insert("package"); + reserved_keywords->insert("namespace"); + reserved_keywords->insert("return"); + reserved_keywords->insert("var"); + reserved_keywords->insert("void"); + reserved_keywords->insert("while"); +} + +} // namespace + +bool LexisIsReserved(absl::string_view text) { + absl::call_once(reserved_keywords_once_flag, InitializeReservedKeywords); + return reserved_keywords->find(text) != reserved_keywords->end(); +} + +bool LexisIsIdentifier(absl::string_view text) { + if (text.empty()) { + return false; + } + char first = text.front(); + if (!absl::ascii_isalpha(first) && first != '_') { + return false; + } + for (size_t index = 1; index < text.size(); index++) { + if (!absl::ascii_isalnum(text[index]) && text[index] != '_') { + return false; + } + } + return !LexisIsReserved(text); +} + +} // namespace cel::internal diff --git a/internal/lexis.h b/internal/lexis.h new file mode 100644 index 000000000..e3697a639 --- /dev/null +++ b/internal/lexis.h @@ -0,0 +1,32 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_LEXIS_H_ +#define THIRD_PARTY_CEL_CPP_INTERNAL_LEXIS_H_ + +#include "absl/strings/string_view.h" + +namespace cel::internal { + +// Returns true if the given text matches RESERVED per the lexis of the CEL +// specification. +bool LexisIsReserved(absl::string_view text); + +// Returns true if the given text matches IDENT per the lexis of the CEL +// specification, fales otherwise. +bool LexisIsIdentifier(absl::string_view text); + +} // namespace cel::internal + +#endif // THIRD_PARTY_CEL_CPP_INTERNAL_LEXIS_H_ diff --git a/internal/lexis_test.cc b/internal/lexis_test.cc new file mode 100644 index 000000000..fdd3ae19d --- /dev/null +++ b/internal/lexis_test.cc @@ -0,0 +1,65 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "internal/lexis.h" + +#include "internal/testing.h" + +namespace cel::internal { +namespace { + +struct LexisTestCase final { + absl::string_view text; + bool ok; +}; + +using LexisIsReservedTest = testing::TestWithParam; + +TEST_P(LexisIsReservedTest, Compliance) { + const LexisTestCase& test_case = GetParam(); + if (test_case.ok) { + EXPECT_TRUE(LexisIsReserved(test_case.text)); + } else { + EXPECT_FALSE(LexisIsReserved(test_case.text)); + } +} + +INSTANTIATE_TEST_SUITE_P(LexisIsReservedTest, LexisIsReservedTest, + testing::ValuesIn({{"true", true}, + {"cel", false}})); + +using LexisIsIdentifierTest = testing::TestWithParam; + +TEST_P(LexisIsIdentifierTest, Compliance) { + const LexisTestCase& test_case = GetParam(); + if (test_case.ok) { + EXPECT_TRUE(LexisIsIdentifier(test_case.text)); + } else { + EXPECT_FALSE(LexisIsIdentifier(test_case.text)); + } +} + +INSTANTIATE_TEST_SUITE_P( + LexisIsIdentifierTest, LexisIsIdentifierTest, + testing::ValuesIn( + {{"true", false}, {"0abc", false}, {"-abc", false}, + {".abc", false}, {"~abc", false}, {"!abc", false}, + {"abc-", false}, {"abc.", false}, {"abc~", false}, + {"abc!", false}, {"cel", true}, {"cel0", true}, + {"_cel", true}, {"_cel0", true}, {"cel_", true}, + {"cel0_", true}, {"cel_cel", true}, {"cel0_cel", true}, + {"cel_cel0", true}, {"cel0_cel0", true}})); + +} // namespace +} // namespace cel::internal diff --git a/internal/list_impl.h b/internal/list_impl.h deleted file mode 100644 index 34bb913e2..000000000 --- a/internal/list_impl.h +++ /dev/null @@ -1,55 +0,0 @@ -#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_LIST_IMPL_H_ -#define THIRD_PARTY_CEL_CPP_INTERNAL_LIST_IMPL_H_ - -#include "common/macros.h" -#include "common/value.h" -#include "internal/holder.h" - -namespace google { -namespace api { -namespace expr { -namespace internal { - -// A wrapper for a native c++ list container. -template -class ListWrapper : public common::List { - public: - template - explicit ListWrapper(Args&&... args) : value_(std::forward(args)...) {} - - inline std::size_t size() const override { return value_->size(); } - inline bool owns_value() const override { return HolderPolicy::kOwnsValue; } - - common::Value Get(std::size_t index) const override; - - google::rpc::Status ForEach( - const std::function& call) - const override; - - private: - Holder value_; -}; - -template -common::Value ListWrapper::Get( - std::size_t index) const { - return GetValue((*value_)[index]); -} - -template -google::rpc::Status ListWrapper::ForEach( - const std::function& call) - const { - for (const auto& elem : *value_) { - RETURN_IF_STATUS_ERROR(call(GetValue(elem))); - } - return OkStatus(); -} - -} // namespace internal -} // namespace expr -} // namespace api -} // namespace google - -#endif // THIRD_PARTY_CEL_CPP_INTERNAL_LIST_IMPL_H_ diff --git a/internal/manual.h b/internal/manual.h new file mode 100644 index 000000000..fb81a9b13 --- /dev/null +++ b/internal/manual.h @@ -0,0 +1,91 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_MANUAL_H_ +#define THIRD_PARTY_CEL_CPP_INTERNAL_MANUAL_H_ + +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" + +namespace cel::internal { + +template +class Manual final { + public: + static_assert(!std::is_reference_v, "T must not be a reference"); + static_assert(!std::is_array_v, "T must not be an array"); + static_assert(!std::is_const_v, "T must not be const qualified"); + static_assert(!std::is_volatile_v, "T must not be volatile qualified"); + + using element_type = T; + + Manual() = default; + + Manual(const Manual&) = delete; + Manual(Manual&&) = delete; + + ~Manual() = default; + + Manual& operator=(const Manual&) = delete; + Manual& operator=(Manual&&) = delete; + + constexpr T* absl_nonnull get() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return std::launder(reinterpret_cast(&storage_[0])); + } + + constexpr const T* absl_nonnull get() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return std::launder(reinterpret_cast(&storage_[0])); + } + + constexpr T& operator*() ABSL_ATTRIBUTE_LIFETIME_BOUND { return *get(); } + + constexpr const T& operator*() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return *get(); + } + + constexpr T* absl_nonnull operator->() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return get(); + } + + constexpr const T* absl_nonnull operator->() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return get(); + } + + template + T* absl_nonnull Construct(Args&&... args) ABSL_ATTRIBUTE_LIFETIME_BOUND { + return ::new (static_cast(&storage_[0])) + T(std::forward(args)...); + } + + T* absl_nonnull DefaultConstruct() { + return ::new (static_cast(&storage_[0])) T; + } + + T* absl_nonnull ValueConstruct() { + return ::new (static_cast(&storage_[0])) T(); + } + + void Destruct() { get()->~T(); } + + private: + alignas(T) char storage_[sizeof(T)]; +}; + +} // namespace cel::internal + +#endif // THIRD_PARTY_CEL_CPP_INTERNAL_MANUAL_H_ diff --git a/internal/map_impl.cc b/internal/map_impl.cc deleted file mode 100644 index 8f4aef3a6..000000000 --- a/internal/map_impl.cc +++ /dev/null @@ -1,30 +0,0 @@ -#include "internal/map_impl.h" -#include "common/macros.h" -#include "internal/status_util.h" - -namespace google { -namespace api { -namespace expr { -namespace internal { - -google::rpc::Status MapImpl::ForEach( - const std::function& call) const { - for (const auto& entry : value_) { - RETURN_IF_STATUS_ERROR(call(entry.first, entry.second)); - } - return OkStatus(); -} - -common::Value MapImpl::GetImpl(const common::Value& key) const { - auto itr = value_.find(key); - if (itr == value_.end()) { - return common::Value::FromError(NoSuchKey(key.ToString())); - } - return itr->second; -} - -} // namespace internal -} // namespace expr -} // namespace api -} // namespace google diff --git a/internal/map_impl.h b/internal/map_impl.h deleted file mode 100644 index bcb461f2a..000000000 --- a/internal/map_impl.h +++ /dev/null @@ -1,86 +0,0 @@ -#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_MAP_IMPL_H_ -#define THIRD_PARTY_CEL_CPP_INTERNAL_MAP_IMPL_H_ - -#include "absl/container/node_hash_map.h" -#include "common/macros.h" -#include "common/value.h" - -namespace google { -namespace api { -namespace expr { -namespace internal { - -/** A simple Value -> Value map implementation. */ -class MapImpl final : public common::Map { - public: - explicit MapImpl(absl::node_hash_map&& value) - : value_(std::move(value)) {} - - inline std::size_t size() const override { return value_.size(); } - - google::rpc::Status ForEach( - const std::function& call) const override; - - inline bool owns_value() const override { return true; } - - protected: - common::Value GetImpl(const common::Value& key) const override; - - private: - absl::node_hash_map value_; -}; - -template -class MapWrapper : public common::Map { - public: - template - explicit MapWrapper(Args&&... args) : value_(std::forward(args)...) {} - - inline std::size_t size() const override { return value_->size(); } - inline bool owns_value() const override { return HolderPolicy::kOwnsValue; } - - google::rpc::Status ForEach( - const std::function& call) const override; - - protected: - common::Value GetImpl(const common::Value& key) const override; - - private: - Holder value_; -}; - -template -common::Value MapWrapper::GetImpl( - const common::Value& key) const { - auto* key_value = key.get_if(); - if (key_value && representable_as(*key_value)) { - auto itr = value_->find(*key_value); - if (itr != value_->end()) { - return GetValue(itr->second); - } - } - return common::Value::FromError(NoSuchKey(key.ToString())); -} - -template -google::rpc::Status MapWrapper::ForEach( - const std::function& call) const { - for (const auto& entry : *value_) { - RETURN_IF_STATUS_ERROR(call(GetValue(entry.first), - GetValue(entry.second))); - } - return OkStatus(); -} - -} // namespace internal -} // namespace expr -} // namespace api -} // namespace google - -#endif // THIRD_PARTY_CEL_CPP_INTERNAL_MAP_IMPL_H_ diff --git a/internal/message_equality.cc b/internal/message_equality.cc new file mode 100644 index 000000000..33ef78089 --- /dev/null +++ b/internal/message_equality.cc @@ -0,0 +1,1490 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "internal/message_equality.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/base/optimization.h" +#include "absl/functional/overload.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "absl/types/variant.h" +#include "common/memory.h" +#include "extensions/protobuf/internal/map_reflection.h" +#include "internal/json.h" +#include "internal/number.h" +#include "internal/status_macros.h" +#include "internal/well_known_types.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" +#include "google/protobuf/util/message_differencer.h" + +#undef GetMessage + +namespace cel::internal { + +namespace { + +using ::cel::extensions::protobuf_internal::ConstMapBegin; +using ::cel::extensions::protobuf_internal::ConstMapEnd; +using ::cel::extensions::protobuf_internal::LookupMapValue; +using ::cel::extensions::protobuf_internal::MapSize; +using ::google::protobuf::Descriptor; +using ::google::protobuf::DescriptorPool; +using ::google::protobuf::FieldDescriptor; +using ::google::protobuf::Message; +using ::google::protobuf::MessageFactory; +using ::google::protobuf::util::MessageDifferencer; + +class EquatableListValue final + : public std::reference_wrapper { + public: + using std::reference_wrapper::reference_wrapper; +}; + +class EquatableStruct final + : public std::reference_wrapper { + public: + using std::reference_wrapper::reference_wrapper; +}; + +class EquatableAny final + : public std::reference_wrapper { + public: + using std::reference_wrapper::reference_wrapper; +}; + +class EquatableMessage final + : public std::reference_wrapper { + public: + using std::reference_wrapper::reference_wrapper; +}; + +using EquatableValue = + std::variant; + +struct NullValueEqualer { + bool operator()(std::nullptr_t, std::nullptr_t) const { return true; } + + template + std::enable_if_t>, bool> + operator()(std::nullptr_t, const T&) const { + return false; + } +}; + +struct BoolValueEqualer { + bool operator()(bool lhs, bool rhs) const { return lhs == rhs; } + + template + std::enable_if_t>, bool> operator()( + bool, const T&) const { + return false; + } +}; + +struct BytesValueEqualer { + bool operator()(const well_known_types::BytesValue& lhs, + const well_known_types::BytesValue& rhs) const { + return lhs == rhs; + } + + template + std::enable_if_t< + std::negation_v>, bool> + operator()(const well_known_types::BytesValue&, const T&) const { + return false; + } +}; + +struct IntValueEqualer { + bool operator()(int64_t lhs, int64_t rhs) const { return lhs == rhs; } + + bool operator()(int64_t lhs, uint64_t rhs) const { + return Number::FromInt64(lhs) == Number::FromUint64(rhs); + } + + bool operator()(int64_t lhs, double rhs) const { + return Number::FromInt64(lhs) == Number::FromDouble(rhs); + } + + template + std::enable_if_t>, + std::negation>, + std::negation>>, + bool> + operator()(int64_t, const T&) const { + return false; + } +}; + +struct UintValueEqualer { + bool operator()(uint64_t lhs, int64_t rhs) const { + return Number::FromUint64(lhs) == Number::FromInt64(rhs); + } + + bool operator()(uint64_t lhs, uint64_t rhs) const { return lhs == rhs; } + + bool operator()(uint64_t lhs, double rhs) const { + return Number::FromUint64(lhs) == Number::FromDouble(rhs); + } + + template + std::enable_if_t>, + std::negation>, + std::negation>>, + bool> + operator()(uint64_t, const T&) const { + return false; + } +}; + +struct DoubleValueEqualer { + bool operator()(double lhs, int64_t rhs) const { + return Number::FromDouble(lhs) == Number::FromInt64(rhs); + } + + bool operator()(double lhs, uint64_t rhs) const { + return Number::FromDouble(lhs) == Number::FromUint64(rhs); + } + + bool operator()(double lhs, double rhs) const { return lhs == rhs; } + + template + std::enable_if_t>, + std::negation>, + std::negation>>, + bool> + operator()(double, const T&) const { + return false; + } +}; + +struct StringValueEqualer { + bool operator()(const well_known_types::StringValue& lhs, + const well_known_types::StringValue& rhs) const { + return lhs == rhs; + } + + template + std::enable_if_t< + std::negation_v>, bool> + operator()(const well_known_types::StringValue&, const T&) const { + return false; + } +}; + +struct DurationEqualer { + bool operator()(absl::Duration lhs, absl::Duration rhs) const { + return lhs == rhs; + } + + template + std::enable_if_t>, bool> + operator()(absl::Duration, const T&) const { + return false; + } +}; + +struct TimestampEqualer { + bool operator()(absl::Time lhs, absl::Time rhs) const { return lhs == rhs; } + + template + std::enable_if_t>, bool> + operator()(absl::Time, const T&) const { + return false; + } +}; + +struct ListValueEqualer { + bool operator()(EquatableListValue lhs, EquatableListValue rhs) const { + return JsonListEquals(lhs, rhs); + } + + template + std::enable_if_t>, bool> + operator()(EquatableListValue, const T&) const { + return false; + } +}; + +struct StructEqualer { + bool operator()(EquatableStruct lhs, EquatableStruct rhs) const { + return JsonMapEquals(lhs, rhs); + } + + template + std::enable_if_t>, bool> + operator()(EquatableStruct, const T&) const { + return false; + } +}; + +struct AnyEqualer { + bool operator()(EquatableAny lhs, EquatableAny rhs) const { + auto lhs_reflection = + well_known_types::GetAnyReflectionOrDie(lhs.get().GetDescriptor()); + std::string lhs_type_url_scratch; + std::string lhs_value_scratch; + auto rhs_reflection = + well_known_types::GetAnyReflectionOrDie(rhs.get().GetDescriptor()); + std::string rhs_type_url_scratch; + std::string rhs_value_scratch; + return lhs_reflection.GetTypeUrl(lhs.get(), lhs_type_url_scratch) == + rhs_reflection.GetTypeUrl(rhs.get(), rhs_type_url_scratch) && + lhs_reflection.GetValue(lhs.get(), lhs_value_scratch) == + rhs_reflection.GetValue(rhs.get(), rhs_value_scratch); + } + + template + std::enable_if_t>, bool> + operator()(EquatableAny, const T&) const { + return false; + } +}; + +struct MessageEqualer { + bool operator()(EquatableMessage lhs, EquatableMessage rhs) const { + return lhs.get().GetDescriptor() == rhs.get().GetDescriptor() && + MessageDifferencer::Equals(lhs.get(), rhs.get()); + } + + template + std::enable_if_t>, bool> + operator()(EquatableMessage, const T&) const { + return false; + } +}; + +struct EquatableValueReflection final { + well_known_types::DoubleValueReflection double_value_reflection; + well_known_types::FloatValueReflection float_value_reflection; + well_known_types::Int64ValueReflection int64_value_reflection; + well_known_types::UInt64ValueReflection uint64_value_reflection; + well_known_types::Int32ValueReflection int32_value_reflection; + well_known_types::UInt32ValueReflection uint32_value_reflection; + well_known_types::StringValueReflection string_value_reflection; + well_known_types::BytesValueReflection bytes_value_reflection; + well_known_types::BoolValueReflection bool_value_reflection; + well_known_types::AnyReflection any_reflection; + well_known_types::DurationReflection duration_reflection; + well_known_types::TimestampReflection timestamp_reflection; + well_known_types::ValueReflection value_reflection; + well_known_types::ListValueReflection list_value_reflection; + well_known_types::StructReflection struct_reflection; +}; + +absl::StatusOr AsEquatableValue( + EquatableValueReflection& reflection, + const Message& message ABSL_ATTRIBUTE_LIFETIME_BOUND, + const Descriptor* absl_nonnull descriptor, + Descriptor::WellKnownType well_known_type, + std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) { + switch (well_known_type) { + case Descriptor::WELLKNOWNTYPE_DOUBLEVALUE: + CEL_RETURN_IF_ERROR( + reflection.double_value_reflection.Initialize(descriptor)); + return reflection.double_value_reflection.GetValue(message); + case Descriptor::WELLKNOWNTYPE_FLOATVALUE: + CEL_RETURN_IF_ERROR( + reflection.float_value_reflection.Initialize(descriptor)); + return static_cast( + reflection.float_value_reflection.GetValue(message)); + case Descriptor::WELLKNOWNTYPE_INT64VALUE: + CEL_RETURN_IF_ERROR( + reflection.int64_value_reflection.Initialize(descriptor)); + return reflection.int64_value_reflection.GetValue(message); + case Descriptor::WELLKNOWNTYPE_UINT64VALUE: + CEL_RETURN_IF_ERROR( + reflection.uint64_value_reflection.Initialize(descriptor)); + return reflection.uint64_value_reflection.GetValue(message); + case Descriptor::WELLKNOWNTYPE_INT32VALUE: + CEL_RETURN_IF_ERROR( + reflection.int32_value_reflection.Initialize(descriptor)); + return static_cast( + reflection.int32_value_reflection.GetValue(message)); + case Descriptor::WELLKNOWNTYPE_UINT32VALUE: + CEL_RETURN_IF_ERROR( + reflection.uint32_value_reflection.Initialize(descriptor)); + return static_cast( + reflection.uint32_value_reflection.GetValue(message)); + case Descriptor::WELLKNOWNTYPE_STRINGVALUE: + CEL_RETURN_IF_ERROR( + reflection.string_value_reflection.Initialize(descriptor)); + return reflection.string_value_reflection.GetValue(message, scratch); + case Descriptor::WELLKNOWNTYPE_BYTESVALUE: + CEL_RETURN_IF_ERROR( + reflection.bytes_value_reflection.Initialize(descriptor)); + return reflection.bytes_value_reflection.GetValue(message, scratch); + case Descriptor::WELLKNOWNTYPE_BOOLVALUE: + CEL_RETURN_IF_ERROR( + reflection.bool_value_reflection.Initialize(descriptor)); + return reflection.bool_value_reflection.GetValue(message); + case Descriptor::WELLKNOWNTYPE_VALUE: { + CEL_RETURN_IF_ERROR(reflection.value_reflection.Initialize(descriptor)); + const auto kind_case = reflection.value_reflection.GetKindCase(message); + switch (kind_case) { + case google::protobuf::Value::KIND_NOT_SET: + ABSL_FALLTHROUGH_INTENDED; + case google::protobuf::Value::kNullValue: + return nullptr; + case google::protobuf::Value::kBoolValue: + return reflection.value_reflection.GetBoolValue(message); + case google::protobuf::Value::kNumberValue: + return reflection.value_reflection.GetNumberValue(message); + case google::protobuf::Value::kStringValue: + return reflection.value_reflection.GetStringValue(message, scratch); + case google::protobuf::Value::kListValue: + return EquatableListValue( + reflection.value_reflection.GetListValue(message)); + case google::protobuf::Value::kStructValue: + return EquatableStruct( + reflection.value_reflection.GetStructValue(message)); + default: + return absl::InternalError( + absl::StrCat("unexpected value kind case: ", kind_case)); + } + } + case Descriptor::WELLKNOWNTYPE_LISTVALUE: + return EquatableListValue(message); + case Descriptor::WELLKNOWNTYPE_STRUCT: + return EquatableStruct(message); + case Descriptor::WELLKNOWNTYPE_DURATION: + CEL_RETURN_IF_ERROR( + reflection.duration_reflection.Initialize(descriptor)); + return reflection.duration_reflection.ToAbslDuration(message); + case Descriptor::WELLKNOWNTYPE_TIMESTAMP: + CEL_RETURN_IF_ERROR( + reflection.timestamp_reflection.Initialize(descriptor)); + return reflection.timestamp_reflection.ToAbslTime(message); + case Descriptor::WELLKNOWNTYPE_ANY: + return EquatableAny(message); + default: + return EquatableMessage(message); + } +} + +absl::StatusOr AsEquatableValue( + EquatableValueReflection& reflection, + const Message& message ABSL_ATTRIBUTE_LIFETIME_BOUND, + const Descriptor* absl_nonnull descriptor, + std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) { + return AsEquatableValue(reflection, message, descriptor, + descriptor->well_known_type(), scratch); +} + +absl::StatusOr AsEquatableValue( + EquatableValueReflection& reflection, + const Message& message ABSL_ATTRIBUTE_LIFETIME_BOUND, + const FieldDescriptor* absl_nonnull field, + std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) { + ABSL_DCHECK(!field->is_repeated() && !field->is_map()); + switch (field->cpp_type()) { + case FieldDescriptor::CPPTYPE_INT32: + return static_cast( + message.GetReflection()->GetInt32(message, field)); + case FieldDescriptor::CPPTYPE_INT64: + return message.GetReflection()->GetInt64(message, field); + case FieldDescriptor::CPPTYPE_UINT32: + return static_cast( + message.GetReflection()->GetUInt32(message, field)); + case FieldDescriptor::CPPTYPE_UINT64: + return message.GetReflection()->GetUInt64(message, field); + case FieldDescriptor::CPPTYPE_DOUBLE: + return message.GetReflection()->GetDouble(message, field); + case FieldDescriptor::CPPTYPE_FLOAT: + return static_cast( + message.GetReflection()->GetFloat(message, field)); + case FieldDescriptor::CPPTYPE_BOOL: + return message.GetReflection()->GetBool(message, field); + case FieldDescriptor::CPPTYPE_ENUM: + if (field->enum_type()->full_name() == "google.protobuf.NullValue") { + return nullptr; + } + return static_cast( + message.GetReflection()->GetEnumValue(message, field)); + case FieldDescriptor::CPPTYPE_STRING: + if (field->type() == FieldDescriptor::TYPE_BYTES) { + return well_known_types::GetBytesField(message, field, scratch); + } + return well_known_types::GetStringField(message, field, scratch); + case FieldDescriptor::CPPTYPE_MESSAGE: + return AsEquatableValue( + reflection, message.GetReflection()->GetMessage(message, field), + field->message_type(), scratch); + default: + return absl::InternalError( + absl::StrCat("unexpected field type: ", field->cpp_type_name())); + } +} + +bool IsAny(const Message& message) { + return message.GetDescriptor()->well_known_type() == + Descriptor::WELLKNOWNTYPE_ANY; +} + +bool IsAnyField(const FieldDescriptor* absl_nonnull field) { + return field->type() == FieldDescriptor::TYPE_MESSAGE && + field->message_type()->well_known_type() == + Descriptor::WELLKNOWNTYPE_ANY; +} + +absl::StatusOr MapValueAsEquatableValue( + google::protobuf::Arena* absl_nonnull arena, const DescriptorPool* absl_nonnull pool, + MessageFactory* absl_nonnull factory, EquatableValueReflection& reflection, + const google::protobuf::MapValueConstRef& value, + const FieldDescriptor* absl_nonnull field, std::string& scratch, + Unique& unpacked) { + if (IsAnyField(field)) { + CEL_ASSIGN_OR_RETURN(unpacked, well_known_types::UnpackAnyIfResolveable( + arena, reflection.any_reflection, + value.GetMessageValue(), pool, factory)); + if (unpacked) { + return AsEquatableValue(reflection, *unpacked, unpacked->GetDescriptor(), + scratch); + } + return AsEquatableValue(reflection, value.GetMessageValue(), + value.GetMessageValue().GetDescriptor(), scratch); + } + switch (field->cpp_type()) { + case FieldDescriptor::CPPTYPE_INT32: + return static_cast(value.GetInt32Value()); + case FieldDescriptor::CPPTYPE_INT64: + return value.GetInt64Value(); + case FieldDescriptor::CPPTYPE_UINT32: + return static_cast(value.GetUInt32Value()); + case FieldDescriptor::CPPTYPE_UINT64: + return value.GetUInt64Value(); + case FieldDescriptor::CPPTYPE_DOUBLE: + return value.GetDoubleValue(); + case FieldDescriptor::CPPTYPE_FLOAT: + return static_cast(value.GetFloatValue()); + case FieldDescriptor::CPPTYPE_BOOL: + return value.GetBoolValue(); + case FieldDescriptor::CPPTYPE_ENUM: + if (field->enum_type()->full_name() == "google.protobuf.NullValue") { + return nullptr; + } + return static_cast(value.GetEnumValue()); + case FieldDescriptor::CPPTYPE_STRING: + if (field->type() == FieldDescriptor::TYPE_BYTES) { + return well_known_types::BytesValue( + absl::string_view(value.GetStringValue())); + } + return well_known_types::StringValue( + absl::string_view(value.GetStringValue())); + case FieldDescriptor::CPPTYPE_MESSAGE: { + const auto& message = value.GetMessageValue(); + return AsEquatableValue(reflection, message, message.GetDescriptor(), + scratch); + } + default: + return absl::InternalError( + absl::StrCat("unexpected field type: ", field->cpp_type_name())); + } +} + +absl::StatusOr RepeatedFieldAsEquatableValue( + google::protobuf::Arena* absl_nonnull arena, const DescriptorPool* absl_nonnull pool, + MessageFactory* absl_nonnull factory, EquatableValueReflection& reflection, + const Message& message, const FieldDescriptor* absl_nonnull field, + int index, std::string& scratch, Unique& unpacked) { + if (IsAnyField(field)) { + const auto& field_value = + message.GetReflection()->GetRepeatedMessage(message, field, index); + CEL_ASSIGN_OR_RETURN(unpacked, well_known_types::UnpackAnyIfResolveable( + arena, reflection.any_reflection, + field_value, pool, factory)); + if (unpacked) { + return AsEquatableValue(reflection, *unpacked, unpacked->GetDescriptor(), + scratch); + } + return AsEquatableValue(reflection, field_value, + field_value.GetDescriptor(), scratch); + } + switch (field->cpp_type()) { + case FieldDescriptor::CPPTYPE_INT32: + return static_cast( + message.GetReflection()->GetRepeatedInt32(message, field, index)); + case FieldDescriptor::CPPTYPE_INT64: + return message.GetReflection()->GetRepeatedInt64(message, field, index); + case FieldDescriptor::CPPTYPE_UINT32: + return static_cast( + message.GetReflection()->GetRepeatedUInt32(message, field, index)); + case FieldDescriptor::CPPTYPE_UINT64: + return message.GetReflection()->GetRepeatedUInt64(message, field, index); + case FieldDescriptor::CPPTYPE_DOUBLE: + return message.GetReflection()->GetRepeatedDouble(message, field, index); + case FieldDescriptor::CPPTYPE_FLOAT: + return static_cast( + message.GetReflection()->GetRepeatedFloat(message, field, index)); + case FieldDescriptor::CPPTYPE_BOOL: + return message.GetReflection()->GetRepeatedBool(message, field, index); + case FieldDescriptor::CPPTYPE_ENUM: + if (field->enum_type()->full_name() == "google.protobuf.NullValue") { + return nullptr; + } + return static_cast( + message.GetReflection()->GetRepeatedEnumValue(message, field, index)); + case FieldDescriptor::CPPTYPE_STRING: + if (field->type() == FieldDescriptor::TYPE_BYTES) { + return well_known_types::GetRepeatedBytesField(message, field, index, + scratch); + } + return well_known_types::GetRepeatedStringField(message, field, index, + scratch); + case FieldDescriptor::CPPTYPE_MESSAGE: { + const auto& submessage = + message.GetReflection()->GetRepeatedMessage(message, field, index); + return AsEquatableValue(reflection, submessage, + submessage.GetDescriptor(), scratch); + } + default: + return absl::InternalError( + absl::StrCat("unexpected field type: ", field->cpp_type_name())); + } +} + +// Compare two `EquatableValue` for equality. +bool EquatableValueEquals(const EquatableValue& lhs, + const EquatableValue& rhs) { + return absl::visit( + absl::Overload(NullValueEqualer{}, BoolValueEqualer{}, + BytesValueEqualer{}, IntValueEqualer{}, UintValueEqualer{}, + DoubleValueEqualer{}, StringValueEqualer{}, + DurationEqualer{}, TimestampEqualer{}, ListValueEqualer{}, + StructEqualer{}, AnyEqualer{}, MessageEqualer{}), + lhs, rhs); +} + +// Attempts to coalesce one map key to another. Returns true if it was possible, +// false otherwise. +bool CoalesceMapKey(const google::protobuf::MapKey& src, + FieldDescriptor::CppType dest_type, + google::protobuf::MapKey* absl_nonnull dest) { + switch (src.type()) { + case FieldDescriptor::CPPTYPE_BOOL: + if (dest_type != FieldDescriptor::CPPTYPE_BOOL) { + return false; + } + dest->SetBoolValue(src.GetBoolValue()); + return true; + case FieldDescriptor::CPPTYPE_INT32: { + const auto src_value = src.GetInt32Value(); + switch (dest_type) { + case FieldDescriptor::CPPTYPE_INT32: + dest->SetInt32Value(src_value); + return true; + case FieldDescriptor::CPPTYPE_INT64: + dest->SetInt64Value(src_value); + return true; + case FieldDescriptor::CPPTYPE_UINT32: + if (src_value < 0) { + return false; + } + dest->SetUInt32Value(static_cast(src_value)); + return true; + case FieldDescriptor::CPPTYPE_UINT64: + if (src_value < 0) { + return false; + } + dest->SetUInt64Value(static_cast(src_value)); + return true; + default: + return false; + } + } + case FieldDescriptor::CPPTYPE_INT64: { + const auto src_value = src.GetInt64Value(); + switch (dest_type) { + case FieldDescriptor::CPPTYPE_INT32: + if (src_value < std::numeric_limits::min() || + src_value > std::numeric_limits::max()) { + return false; + } + dest->SetInt32Value(static_cast(src_value)); + return true; + case FieldDescriptor::CPPTYPE_INT64: + dest->SetInt64Value(src_value); + return true; + case FieldDescriptor::CPPTYPE_UINT32: + if (src_value < 0 || + src_value > std::numeric_limits::max()) { + return false; + } + dest->SetUInt32Value(static_cast(src_value)); + return true; + case FieldDescriptor::CPPTYPE_UINT64: + if (src_value < 0) { + return false; + } + dest->SetUInt64Value(static_cast(src_value)); + return true; + default: + return false; + } + } + case FieldDescriptor::CPPTYPE_UINT32: { + const auto src_value = src.GetUInt32Value(); + switch (dest_type) { + case FieldDescriptor::CPPTYPE_INT32: + if (src_value > std::numeric_limits::max()) { + return false; + } + dest->SetInt32Value(static_cast(src_value)); + return true; + case FieldDescriptor::CPPTYPE_INT64: + dest->SetInt64Value(static_cast(src_value)); + return true; + case FieldDescriptor::CPPTYPE_UINT32: + dest->SetUInt32Value(src_value); + return true; + case FieldDescriptor::CPPTYPE_UINT64: + dest->SetUInt64Value(static_cast(src_value)); + return true; + default: + return false; + } + } + case FieldDescriptor::CPPTYPE_UINT64: { + const auto src_value = src.GetUInt64Value(); + switch (dest_type) { + case FieldDescriptor::CPPTYPE_INT32: + if (src_value > std::numeric_limits::max()) { + return false; + } + dest->SetInt32Value(static_cast(src_value)); + return true; + case FieldDescriptor::CPPTYPE_INT64: + if (src_value > std::numeric_limits::max()) { + return false; + } + dest->SetInt64Value(static_cast(src_value)); + return true; + case FieldDescriptor::CPPTYPE_UINT32: + if (src_value > std::numeric_limits::max()) { + return false; + } + dest->SetUInt32Value(src_value); + return true; + case FieldDescriptor::CPPTYPE_UINT64: + dest->SetUInt64Value(src_value); + return true; + default: + return false; + } + } + case FieldDescriptor::CPPTYPE_STRING: + if (dest_type != FieldDescriptor::CPPTYPE_STRING) { + return false; + } + dest->SetStringValue(src.GetStringValue()); + return true; + default: + // Only bool, integrals, and string may be map keys. + ABSL_UNREACHABLE(); + } +} + +// Bits used for categorizing equality. Can be used to cheaply check whether two +// categories are comparable for equality by performing an AND and checking if +// the result against `kNone`. +enum class EquatableCategory { + kNone = 0, + + kNullLike = 1 << 0, + kBoolLike = 1 << 1, + kNumericLike = 1 << 2, + kBytesLike = 1 << 3, + kStringLike = 1 << 4, + kList = 1 << 5, + kMap = 1 << 6, + kMessage = 1 << 7, + kDuration = 1 << 8, + kTimestamp = 1 << 9, + + kAny = kNullLike | kBoolLike | kNumericLike | kBytesLike | kStringLike | + kList | kMap | kMessage | kDuration | kTimestamp, + kValue = kNullLike | kBoolLike | kNumericLike | kStringLike | kList | kMap, +}; + +constexpr EquatableCategory operator&(EquatableCategory lhs, + EquatableCategory rhs) { + return static_cast( + static_cast>(lhs) & + static_cast>(rhs)); +} + +constexpr bool operator==(EquatableCategory lhs, EquatableCategory rhs) { + return static_cast>(lhs) == + static_cast>(rhs); +} + +EquatableCategory GetEquatableCategory( + const Descriptor* absl_nonnull descriptor) { + switch (descriptor->well_known_type()) { + case Descriptor::WELLKNOWNTYPE_BOOLVALUE: + return EquatableCategory::kBoolLike; + case Descriptor::WELLKNOWNTYPE_FLOATVALUE: + ABSL_FALLTHROUGH_INTENDED; + case Descriptor::WELLKNOWNTYPE_DOUBLEVALUE: + ABSL_FALLTHROUGH_INTENDED; + case Descriptor::WELLKNOWNTYPE_INT32VALUE: + ABSL_FALLTHROUGH_INTENDED; + case Descriptor::WELLKNOWNTYPE_UINT32VALUE: + ABSL_FALLTHROUGH_INTENDED; + case Descriptor::WELLKNOWNTYPE_INT64VALUE: + ABSL_FALLTHROUGH_INTENDED; + case Descriptor::WELLKNOWNTYPE_UINT64VALUE: + return EquatableCategory::kNumericLike; + case Descriptor::WELLKNOWNTYPE_BYTESVALUE: + return EquatableCategory::kBytesLike; + case Descriptor::WELLKNOWNTYPE_STRINGVALUE: + return EquatableCategory::kStringLike; + case Descriptor::WELLKNOWNTYPE_VALUE: + return EquatableCategory::kValue; + case Descriptor::WELLKNOWNTYPE_LISTVALUE: + return EquatableCategory::kList; + case Descriptor::WELLKNOWNTYPE_STRUCT: + return EquatableCategory::kMap; + case Descriptor::WELLKNOWNTYPE_ANY: + return EquatableCategory::kAny; + case Descriptor::WELLKNOWNTYPE_DURATION: + return EquatableCategory::kDuration; + case Descriptor::WELLKNOWNTYPE_TIMESTAMP: + return EquatableCategory::kTimestamp; + default: + return EquatableCategory::kAny; + } +} + +EquatableCategory GetEquatableFieldCategory( + const FieldDescriptor* absl_nonnull field) { + switch (field->cpp_type()) { + case FieldDescriptor::CPPTYPE_ENUM: + return field->enum_type()->full_name() == "google.protobuf.NullValue" + ? EquatableCategory::kNullLike + : EquatableCategory::kNumericLike; + case FieldDescriptor::CPPTYPE_BOOL: + return EquatableCategory::kBoolLike; + case FieldDescriptor::CPPTYPE_FLOAT: + ABSL_FALLTHROUGH_INTENDED; + case FieldDescriptor::CPPTYPE_DOUBLE: + ABSL_FALLTHROUGH_INTENDED; + case FieldDescriptor::CPPTYPE_INT32: + ABSL_FALLTHROUGH_INTENDED; + case FieldDescriptor::CPPTYPE_UINT32: + ABSL_FALLTHROUGH_INTENDED; + case FieldDescriptor::CPPTYPE_INT64: + ABSL_FALLTHROUGH_INTENDED; + case FieldDescriptor::CPPTYPE_UINT64: + return EquatableCategory::kNumericLike; + case FieldDescriptor::CPPTYPE_STRING: + return field->type() == FieldDescriptor::TYPE_BYTES + ? EquatableCategory::kBytesLike + : EquatableCategory::kStringLike; + case FieldDescriptor::CPPTYPE_MESSAGE: + return GetEquatableCategory(field->message_type()); + default: + // Ugh. Force any future additions to compare instead of short circuiting. + return EquatableCategory::kAny; + } +} + +class MessageEqualsState final { + public: + MessageEqualsState(const DescriptorPool* absl_nonnull pool, + MessageFactory* absl_nonnull factory) + : pool_(pool), factory_(factory) {} + + // Equality between messages. + absl::StatusOr Equals(const Message& lhs, const Message& rhs) { + const auto* lhs_descriptor = lhs.GetDescriptor(); + const auto* rhs_descriptor = rhs.GetDescriptor(); + // Deal with well known types, starting with any. + auto lhs_well_known_type = lhs_descriptor->well_known_type(); + auto rhs_well_known_type = rhs_descriptor->well_known_type(); + const Message* absl_nonnull lhs_ptr = &lhs; + const Message* absl_nonnull rhs_ptr = &rhs; + Unique lhs_unpacked; + Unique rhs_unpacked; + // Deal with any first. We could in theory check if we should bother + // unpacking, but that is more complicated. We can always implement it + // later. + if (lhs_well_known_type == Descriptor::WELLKNOWNTYPE_ANY) { + CEL_ASSIGN_OR_RETURN( + lhs_unpacked, + well_known_types::UnpackAnyIfResolveable( + &arena_, lhs_reflection_.any_reflection, lhs, pool_, factory_)); + if (lhs_unpacked) { + lhs_ptr = cel::to_address(lhs_unpacked); + lhs_descriptor = lhs_ptr->GetDescriptor(); + lhs_well_known_type = lhs_descriptor->well_known_type(); + } + } + if (rhs_well_known_type == Descriptor::WELLKNOWNTYPE_ANY) { + CEL_ASSIGN_OR_RETURN( + rhs_unpacked, + well_known_types::UnpackAnyIfResolveable( + &arena_, rhs_reflection_.any_reflection, rhs, pool_, factory_)); + if (rhs_unpacked) { + rhs_ptr = cel::to_address(rhs_unpacked); + rhs_descriptor = rhs_ptr->GetDescriptor(); + rhs_well_known_type = rhs_descriptor->well_known_type(); + } + } + CEL_ASSIGN_OR_RETURN( + auto lhs_value, + AsEquatableValue(lhs_reflection_, *lhs_ptr, lhs_descriptor, + lhs_well_known_type, lhs_scratch_)); + CEL_ASSIGN_OR_RETURN( + auto rhs_value, + AsEquatableValue(rhs_reflection_, *rhs_ptr, rhs_descriptor, + rhs_well_known_type, rhs_scratch_)); + return EquatableValueEquals(lhs_value, rhs_value); + } + + // Equality between map message fields. + absl::StatusOr MapFieldEquals( + const Message& lhs, const FieldDescriptor* absl_nonnull lhs_field, + const Message& rhs, const FieldDescriptor* absl_nonnull rhs_field) { + ABSL_DCHECK(lhs_field->is_map()); + ABSL_DCHECK_EQ(lhs_field->containing_type(), lhs.GetDescriptor()); + ABSL_DCHECK(rhs_field->is_map()); + ABSL_DCHECK_EQ(rhs_field->containing_type(), rhs.GetDescriptor()); + const auto* lhs_entry = lhs_field->message_type(); + const auto* lhs_entry_key_field = lhs_entry->map_key(); + const auto* lhs_entry_value_field = lhs_entry->map_value(); + const auto* rhs_entry = rhs_field->message_type(); + const auto* rhs_entry_key_field = rhs_entry->map_key(); + const auto* rhs_entry_value_field = rhs_entry->map_value(); + // Perform cheap test which checks whether the left and right can even be + // compared for equality. + if (lhs_field != rhs_field && + ((GetEquatableFieldCategory(lhs_entry_key_field) & + GetEquatableFieldCategory(rhs_entry_key_field)) == + EquatableCategory::kNone || + (GetEquatableFieldCategory(lhs_entry_value_field) & + GetEquatableFieldCategory(rhs_entry_value_field)) == + EquatableCategory::kNone)) { + // Short-circuit. + return false; + } + const auto* lhs_reflection = lhs.GetReflection(); + const auto* rhs_reflection = rhs.GetReflection(); + if (MapSize(*lhs_reflection, lhs, *lhs_field) != + MapSize(*rhs_reflection, rhs, *rhs_field)) { + return false; + } + auto lhs_begin = ConstMapBegin(*lhs_reflection, lhs, *lhs_field); + const auto lhs_end = ConstMapEnd(*lhs_reflection, lhs, *lhs_field); + Unique lhs_unpacked; + EquatableValue lhs_value; + Unique rhs_unpacked; + EquatableValue rhs_value; + google::protobuf::MapKey rhs_map_key; + google::protobuf::MapValueConstRef rhs_map_value; + for (; lhs_begin != lhs_end; ++lhs_begin) { + if (!CoalesceMapKey(lhs_begin.GetKey(), rhs_entry_key_field->cpp_type(), + &rhs_map_key)) { + return false; + } + if (!LookupMapValue(*rhs_reflection, rhs, *rhs_field, rhs_map_key, + &rhs_map_value)) { + return false; + } + CEL_ASSIGN_OR_RETURN(lhs_value, + MapValueAsEquatableValue( + &arena_, pool_, factory_, lhs_reflection_, + lhs_begin.GetValueRef(), lhs_entry_value_field, + lhs_scratch_, lhs_unpacked)); + CEL_ASSIGN_OR_RETURN( + rhs_value, + MapValueAsEquatableValue(&arena_, pool_, factory_, rhs_reflection_, + rhs_map_value, rhs_entry_value_field, + rhs_scratch_, rhs_unpacked)); + if (!EquatableValueEquals(lhs_value, rhs_value)) { + return false; + } + } + return true; + } + + // Equality between repeated message fields. + absl::StatusOr RepeatedFieldEquals( + const Message& lhs, const FieldDescriptor* absl_nonnull lhs_field, + const Message& rhs, const FieldDescriptor* absl_nonnull rhs_field) { + ABSL_DCHECK(lhs_field->is_repeated() && !lhs_field->is_map()); + ABSL_DCHECK_EQ(lhs_field->containing_type(), lhs.GetDescriptor()); + ABSL_DCHECK(rhs_field->is_repeated() && !rhs_field->is_map()); + ABSL_DCHECK_EQ(rhs_field->containing_type(), rhs.GetDescriptor()); + // Perform cheap test which checks whether the left and right can even be + // compared for equality. + if (lhs_field != rhs_field && + (GetEquatableFieldCategory(lhs_field) & + GetEquatableFieldCategory(rhs_field)) == EquatableCategory::kNone) { + // Short-circuit. + return false; + } + const auto* lhs_reflection = lhs.GetReflection(); + const auto* rhs_reflection = rhs.GetReflection(); + const auto size = lhs_reflection->FieldSize(lhs, lhs_field); + if (size != rhs_reflection->FieldSize(rhs, rhs_field)) { + return false; + } + Unique lhs_unpacked; + EquatableValue lhs_value; + Unique rhs_unpacked; + EquatableValue rhs_value; + for (int i = 0; i < size; ++i) { + CEL_ASSIGN_OR_RETURN(lhs_value, + RepeatedFieldAsEquatableValue( + &arena_, pool_, factory_, lhs_reflection_, lhs, + lhs_field, i, lhs_scratch_, lhs_unpacked)); + CEL_ASSIGN_OR_RETURN(rhs_value, + RepeatedFieldAsEquatableValue( + &arena_, pool_, factory_, rhs_reflection_, rhs, + rhs_field, i, rhs_scratch_, rhs_unpacked)); + if (!EquatableValueEquals(lhs_value, rhs_value)) { + return false; + } + } + return true; + } + + // Equality between singular message fields and/or messages. If the field is + // `nullptr`, we are performing equality on the message itself rather than the + // corresponding field. + absl::StatusOr SingularFieldEquals( + const Message& lhs, const FieldDescriptor* absl_nullable lhs_field, + const Message& rhs, const FieldDescriptor* absl_nullable rhs_field) { + ABSL_DCHECK(lhs_field == nullptr || + (!lhs_field->is_repeated() && !lhs_field->is_map())); + ABSL_DCHECK(lhs_field == nullptr || + lhs_field->containing_type() == lhs.GetDescriptor()); + ABSL_DCHECK(rhs_field == nullptr || + (!rhs_field->is_repeated() && !rhs_field->is_map())); + ABSL_DCHECK(rhs_field == nullptr || + rhs_field->containing_type() == rhs.GetDescriptor()); + // Perform cheap test which checks whether the left and right can even be + // compared for equality. + if (lhs_field != rhs_field && + ((lhs_field != nullptr ? GetEquatableFieldCategory(lhs_field) + : GetEquatableCategory(lhs.GetDescriptor())) & + (rhs_field != nullptr ? GetEquatableFieldCategory(rhs_field) + : GetEquatableCategory(rhs.GetDescriptor()))) == + EquatableCategory::kNone) { + // Short-circuit. + return false; + } + const Message* absl_nonnull lhs_ptr = &lhs; + const Message* absl_nonnull rhs_ptr = &rhs; + Unique lhs_unpacked; + Unique rhs_unpacked; + if (lhs_field != nullptr && IsAnyField(lhs_field)) { + CEL_ASSIGN_OR_RETURN(lhs_unpacked, + well_known_types::UnpackAnyIfResolveable( + &arena_, lhs_reflection_.any_reflection, + lhs.GetReflection()->GetMessage(lhs, lhs_field), + pool_, factory_)); + if (lhs_unpacked) { + lhs_ptr = cel::to_address(lhs_unpacked); + lhs_field = nullptr; + } + } else if (lhs_field == nullptr && IsAny(lhs)) { + CEL_ASSIGN_OR_RETURN( + lhs_unpacked, + well_known_types::UnpackAnyIfResolveable( + &arena_, lhs_reflection_.any_reflection, lhs, pool_, factory_)); + if (lhs_unpacked) { + lhs_ptr = cel::to_address(lhs_unpacked); + } + } + if (rhs_field != nullptr && IsAnyField(rhs_field)) { + CEL_ASSIGN_OR_RETURN(rhs_unpacked, + well_known_types::UnpackAnyIfResolveable( + &arena_, rhs_reflection_.any_reflection, + rhs.GetReflection()->GetMessage(rhs, rhs_field), + pool_, factory_)); + if (rhs_unpacked) { + rhs_ptr = cel::to_address(rhs_unpacked); + rhs_field = nullptr; + } + } else if (rhs_field == nullptr && IsAny(rhs)) { + CEL_ASSIGN_OR_RETURN( + rhs_unpacked, + well_known_types::UnpackAnyIfResolveable( + &arena_, rhs_reflection_.any_reflection, rhs, pool_, factory_)); + if (rhs_unpacked) { + rhs_ptr = cel::to_address(rhs_unpacked); + } + } + EquatableValue lhs_value; + if (lhs_field != nullptr) { + CEL_ASSIGN_OR_RETURN( + lhs_value, + AsEquatableValue(lhs_reflection_, *lhs_ptr, lhs_field, lhs_scratch_)); + } else { + CEL_ASSIGN_OR_RETURN( + lhs_value, AsEquatableValue(lhs_reflection_, *lhs_ptr, + lhs_ptr->GetDescriptor(), lhs_scratch_)); + } + EquatableValue rhs_value; + if (rhs_field != nullptr) { + CEL_ASSIGN_OR_RETURN( + rhs_value, + AsEquatableValue(rhs_reflection_, *rhs_ptr, rhs_field, rhs_scratch_)); + } else { + CEL_ASSIGN_OR_RETURN( + rhs_value, AsEquatableValue(rhs_reflection_, *rhs_ptr, + rhs_ptr->GetDescriptor(), rhs_scratch_)); + } + return EquatableValueEquals(lhs_value, rhs_value); + } + + absl::StatusOr FieldEquals( + const Message& lhs, const FieldDescriptor* absl_nullable lhs_field, + const Message& rhs, const FieldDescriptor* absl_nullable rhs_field) { + ABSL_DCHECK(lhs_field != nullptr || + rhs_field != nullptr); // Both cannot be null. + if (lhs_field != nullptr && lhs_field->is_map()) { + // map == map + // map == google.protobuf.Value + // map == google.protobuf.Struct + // map == google.protobuf.Any + + // Right hand side should be a map, `google.protobuf.Value`, + // `google.protobuf.Struct`, or `google.protobuf.Any`. + if (rhs_field != nullptr && rhs_field->is_map()) { + // map == map + return MapFieldEquals(lhs, lhs_field, rhs, rhs_field); + } + if (rhs_field != nullptr && + (rhs_field->is_repeated() || + rhs_field->type() != FieldDescriptor::TYPE_MESSAGE)) { + return false; + } + const Message* absl_nullable rhs_packed = nullptr; + Unique rhs_unpacked; + if (rhs_field != nullptr && IsAnyField(rhs_field)) { + rhs_packed = &rhs.GetReflection()->GetMessage(rhs, rhs_field); + } else if (rhs_field == nullptr && IsAny(rhs)) { + rhs_packed = &rhs; + } + if (rhs_packed != nullptr) { + CEL_RETURN_IF_ERROR(rhs_reflection_.any_reflection.Initialize( + rhs_packed->GetDescriptor())); + auto rhs_type_url = rhs_reflection_.any_reflection.GetTypeUrl( + *rhs_packed, rhs_scratch_); + if (!rhs_type_url.ConsumePrefix("type.googleapis.com/") && + !rhs_type_url.ConsumePrefix("type.googleprod.com/")) { + return false; + } + if (rhs_type_url != "google.protobuf.Value" && + rhs_type_url != "google.protobuf.Struct" && + rhs_type_url != "google.protobuf.Any") { + return false; + } + CEL_ASSIGN_OR_RETURN(rhs_unpacked, + well_known_types::UnpackAnyIfResolveable( + &arena_, rhs_reflection_.any_reflection, + *rhs_packed, pool_, factory_)); + if (rhs_unpacked) { + rhs_field = nullptr; + } + } + const Message* absl_nonnull rhs_message = + rhs_field != nullptr + ? &rhs.GetReflection()->GetMessage(rhs, rhs_field) + : rhs_unpacked != nullptr ? cel::to_address(rhs_unpacked) + : &rhs; + const auto* rhs_descriptor = rhs_message->GetDescriptor(); + const auto rhs_well_known_type = rhs_descriptor->well_known_type(); + switch (rhs_well_known_type) { + case Descriptor::WELLKNOWNTYPE_VALUE: { + // map == google.protobuf.Value + CEL_RETURN_IF_ERROR( + rhs_reflection_.value_reflection.Initialize(rhs_descriptor)); + if (rhs_reflection_.value_reflection.GetKindCase(*rhs_message) != + google::protobuf::Value::kStructValue) { + return false; + } + CEL_RETURN_IF_ERROR(rhs_reflection_.struct_reflection.Initialize( + rhs_reflection_.value_reflection.GetStructDescriptor())); + return MapFieldEquals( + lhs, lhs_field, + rhs_reflection_.value_reflection.GetStructValue(*rhs_message), + rhs_reflection_.struct_reflection.GetFieldsDescriptor()); + } + case Descriptor::WELLKNOWNTYPE_STRUCT: { + // map == google.protobuf.Struct + CEL_RETURN_IF_ERROR( + rhs_reflection_.struct_reflection.Initialize(rhs_descriptor)); + return MapFieldEquals( + lhs, lhs_field, *rhs_message, + rhs_reflection_.struct_reflection.GetFieldsDescriptor()); + } + default: + return false; + } + // Explicitly unreachable, for ease of reading. Control never leaves this + // if statement. + ABSL_UNREACHABLE(); + } + if (rhs_field != nullptr && rhs_field->is_map()) { + // google.protobuf.Value == map + // google.protobuf.Struct == map + // google.protobuf.Any == map + + // Left hand side should be singular `google.protobuf.Value` + // `google.protobuf.Struct`, or `google.protobuf.Any`. + ABSL_DCHECK(lhs_field == nullptr || + !lhs_field->is_map()); // Handled above. + if (lhs_field != nullptr && + (lhs_field->is_repeated() || + lhs_field->type() != FieldDescriptor::TYPE_MESSAGE)) { + return false; + } + const Message* absl_nullable lhs_packed = nullptr; + Unique lhs_unpacked; + if (lhs_field != nullptr && IsAnyField(lhs_field)) { + lhs_packed = &lhs.GetReflection()->GetMessage(lhs, lhs_field); + } else if (lhs_field == nullptr && IsAny(lhs)) { + lhs_packed = &lhs; + } + if (lhs_packed != nullptr) { + CEL_RETURN_IF_ERROR(lhs_reflection_.any_reflection.Initialize( + lhs_packed->GetDescriptor())); + auto lhs_type_url = lhs_reflection_.any_reflection.GetTypeUrl( + *lhs_packed, lhs_scratch_); + if (!lhs_type_url.ConsumePrefix("type.googleapis.com/") && + !lhs_type_url.ConsumePrefix("type.googleprod.com/")) { + return false; + } + if (lhs_type_url != "google.protobuf.Value" && + lhs_type_url != "google.protobuf.Struct" && + lhs_type_url != "google.protobuf.Any") { + return false; + } + CEL_ASSIGN_OR_RETURN(lhs_unpacked, + well_known_types::UnpackAnyIfResolveable( + &arena_, lhs_reflection_.any_reflection, + *lhs_packed, pool_, factory_)); + if (lhs_unpacked) { + lhs_field = nullptr; + } + } + const Message* absl_nonnull lhs_message = + lhs_field != nullptr + ? &lhs.GetReflection()->GetMessage(lhs, lhs_field) + : lhs_unpacked != nullptr ? cel::to_address(lhs_unpacked) + : &lhs; + const auto* lhs_descriptor = lhs_message->GetDescriptor(); + const auto lhs_well_known_type = lhs_descriptor->well_known_type(); + switch (lhs_well_known_type) { + case Descriptor::WELLKNOWNTYPE_VALUE: { + // map == google.protobuf.Value + CEL_RETURN_IF_ERROR( + lhs_reflection_.value_reflection.Initialize(lhs_descriptor)); + if (lhs_reflection_.value_reflection.GetKindCase(*lhs_message) != + google::protobuf::Value::kStructValue) { + return false; + } + CEL_RETURN_IF_ERROR(lhs_reflection_.struct_reflection.Initialize( + lhs_reflection_.value_reflection.GetStructDescriptor())); + return MapFieldEquals( + lhs_reflection_.value_reflection.GetStructValue(*lhs_message), + lhs_reflection_.struct_reflection.GetFieldsDescriptor(), rhs, + rhs_field); + } + case Descriptor::WELLKNOWNTYPE_STRUCT: { + // map == google.protobuf.Struct + CEL_RETURN_IF_ERROR( + lhs_reflection_.struct_reflection.Initialize(lhs_descriptor)); + return MapFieldEquals( + *lhs_message, + lhs_reflection_.struct_reflection.GetFieldsDescriptor(), rhs, + rhs_field); + } + default: + return false; + } + // Explicitly unreachable, for ease of reading. Control never leaves this + // if statement. + ABSL_UNREACHABLE(); + } + ABSL_DCHECK(lhs_field == nullptr || + !lhs_field->is_map()); // Handled above. + ABSL_DCHECK(rhs_field == nullptr || + !rhs_field->is_map()); // Handled above. + if (lhs_field != nullptr && lhs_field->is_repeated()) { + // repeated == repeated + // repeated == google.protobuf.Value + // repeated == google.protobuf.ListValue + // repeated == google.protobuf.Any + + // Right hand side should be a repeated, `google.protobuf.Value`, + // `google.protobuf.ListValue`, or `google.protobuf.Any`. + if (rhs_field != nullptr && rhs_field->is_repeated()) { + // map == map + return RepeatedFieldEquals(lhs, lhs_field, rhs, rhs_field); + } + if (rhs_field != nullptr && + rhs_field->type() != FieldDescriptor::TYPE_MESSAGE) { + return false; + } + const Message* absl_nullable rhs_packed = nullptr; + Unique rhs_unpacked; + if (rhs_field != nullptr && IsAnyField(rhs_field)) { + rhs_packed = &rhs.GetReflection()->GetMessage(rhs, rhs_field); + } else if (rhs_field == nullptr && IsAny(rhs)) { + rhs_packed = &rhs; + } + if (rhs_packed != nullptr) { + CEL_RETURN_IF_ERROR(rhs_reflection_.any_reflection.Initialize( + rhs_packed->GetDescriptor())); + auto rhs_type_url = rhs_reflection_.any_reflection.GetTypeUrl( + *rhs_packed, rhs_scratch_); + if (!rhs_type_url.ConsumePrefix("type.googleapis.com/") && + !rhs_type_url.ConsumePrefix("type.googleprod.com/")) { + return false; + } + if (rhs_type_url != "google.protobuf.Value" && + rhs_type_url != "google.protobuf.ListValue" && + rhs_type_url != "google.protobuf.Any") { + return false; + } + CEL_ASSIGN_OR_RETURN(rhs_unpacked, + well_known_types::UnpackAnyIfResolveable( + &arena_, rhs_reflection_.any_reflection, + *rhs_packed, pool_, factory_)); + if (rhs_unpacked) { + rhs_field = nullptr; + } + } + const Message* absl_nonnull rhs_message = + rhs_field != nullptr + ? &rhs.GetReflection()->GetMessage(rhs, rhs_field) + : rhs_unpacked != nullptr ? cel::to_address(rhs_unpacked) + : &rhs; + const auto* rhs_descriptor = rhs_message->GetDescriptor(); + const auto rhs_well_known_type = rhs_descriptor->well_known_type(); + switch (rhs_well_known_type) { + case Descriptor::WELLKNOWNTYPE_VALUE: { + // map == google.protobuf.Value + CEL_RETURN_IF_ERROR( + rhs_reflection_.value_reflection.Initialize(rhs_descriptor)); + if (rhs_reflection_.value_reflection.GetKindCase(*rhs_message) != + google::protobuf::Value::kListValue) { + return false; + } + CEL_RETURN_IF_ERROR(rhs_reflection_.list_value_reflection.Initialize( + rhs_reflection_.value_reflection.GetListValueDescriptor())); + return RepeatedFieldEquals( + lhs, lhs_field, + rhs_reflection_.value_reflection.GetListValue(*rhs_message), + rhs_reflection_.list_value_reflection.GetValuesDescriptor()); + } + case Descriptor::WELLKNOWNTYPE_LISTVALUE: { + // map == google.protobuf.ListValue + CEL_RETURN_IF_ERROR( + rhs_reflection_.list_value_reflection.Initialize(rhs_descriptor)); + return RepeatedFieldEquals( + lhs, lhs_field, *rhs_message, + rhs_reflection_.list_value_reflection.GetValuesDescriptor()); + } + default: + return false; + } + // Explicitly unreachable, for ease of reading. Control never leaves this + // if statement. + ABSL_UNREACHABLE(); + } + if (rhs_field != nullptr && rhs_field->is_repeated()) { + // google.protobuf.Value == repeated + // google.protobuf.ListValue == repeated + // google.protobuf.Any == repeated + + // Left hand side should be singular `google.protobuf.Value` + // `google.protobuf.ListValue`, or `google.protobuf.Any`. + ABSL_DCHECK(lhs_field == nullptr || + !lhs_field->is_repeated()); // Handled above. + if (lhs_field != nullptr && + lhs_field->type() != FieldDescriptor::TYPE_MESSAGE) { + return false; + } + const Message* absl_nullable lhs_packed = nullptr; + Unique lhs_unpacked; + if (lhs_field != nullptr && IsAnyField(lhs_field)) { + lhs_packed = &lhs.GetReflection()->GetMessage(lhs, lhs_field); + } else if (lhs_field == nullptr && IsAny(lhs)) { + lhs_packed = &lhs; + } + if (lhs_packed != nullptr) { + CEL_RETURN_IF_ERROR(lhs_reflection_.any_reflection.Initialize( + lhs_packed->GetDescriptor())); + auto lhs_type_url = lhs_reflection_.any_reflection.GetTypeUrl( + *lhs_packed, lhs_scratch_); + if (!lhs_type_url.ConsumePrefix("type.googleapis.com/") && + !lhs_type_url.ConsumePrefix("type.googleprod.com/")) { + return false; + } + if (lhs_type_url != "google.protobuf.Value" && + lhs_type_url != "google.protobuf.ListValue" && + lhs_type_url != "google.protobuf.Any") { + return false; + } + CEL_ASSIGN_OR_RETURN(lhs_unpacked, + well_known_types::UnpackAnyIfResolveable( + &arena_, lhs_reflection_.any_reflection, + *lhs_packed, pool_, factory_)); + if (lhs_unpacked) { + lhs_field = nullptr; + } + } + const Message* absl_nonnull lhs_message = + lhs_field != nullptr + ? &lhs.GetReflection()->GetMessage(lhs, lhs_field) + : lhs_unpacked != nullptr ? cel::to_address(lhs_unpacked) + : &lhs; + const auto* lhs_descriptor = lhs_message->GetDescriptor(); + const auto lhs_well_known_type = lhs_descriptor->well_known_type(); + switch (lhs_well_known_type) { + case Descriptor::WELLKNOWNTYPE_VALUE: { + // map == google.protobuf.Value + CEL_RETURN_IF_ERROR( + lhs_reflection_.value_reflection.Initialize(lhs_descriptor)); + if (lhs_reflection_.value_reflection.GetKindCase(*lhs_message) != + google::protobuf::Value::kListValue) { + return false; + } + CEL_RETURN_IF_ERROR(lhs_reflection_.list_value_reflection.Initialize( + lhs_reflection_.value_reflection.GetListValueDescriptor())); + return RepeatedFieldEquals( + lhs_reflection_.value_reflection.GetListValue(*lhs_message), + lhs_reflection_.list_value_reflection.GetValuesDescriptor(), rhs, + rhs_field); + } + case Descriptor::WELLKNOWNTYPE_LISTVALUE: { + // map == google.protobuf.ListValue + CEL_RETURN_IF_ERROR( + lhs_reflection_.list_value_reflection.Initialize(lhs_descriptor)); + return RepeatedFieldEquals( + *lhs_message, + lhs_reflection_.list_value_reflection.GetValuesDescriptor(), rhs, + rhs_field); + } + default: + return false; + } + // Explicitly unreachable, for ease of reading. Control never leaves this + // if statement. + ABSL_UNREACHABLE(); + } + return SingularFieldEquals(lhs, lhs_field, rhs, rhs_field); + } + + private: + const DescriptorPool* absl_nonnull const pool_; + MessageFactory* absl_nonnull const factory_; + google::protobuf::Arena arena_; + EquatableValueReflection lhs_reflection_; + EquatableValueReflection rhs_reflection_; + std::string lhs_scratch_; + std::string rhs_scratch_; +}; + +} // namespace + +absl::StatusOr MessageEquals(const Message& lhs, const Message& rhs, + const DescriptorPool* absl_nonnull pool, + MessageFactory* absl_nonnull factory) { + ABSL_DCHECK(pool != nullptr); + ABSL_DCHECK(factory != nullptr); + if (&lhs == &rhs) { + return true; + } + // MessageEqualsState has quite a large size, so we allocate it on the heap. + // Ideally we should just hold most of the state at runtime in something like + // `FlatExpressionEvaluatorState`, so we can avoid allocating this repeatedly. + return std::make_unique(pool, factory)->Equals(lhs, rhs); +} + +absl::StatusOr MessageFieldEquals( + const Message& lhs, const FieldDescriptor* absl_nonnull lhs_field, + const Message& rhs, const FieldDescriptor* absl_nonnull rhs_field, + const DescriptorPool* absl_nonnull pool, + MessageFactory* absl_nonnull factory) { + ABSL_DCHECK(lhs_field != nullptr); + ABSL_DCHECK(rhs_field != nullptr); + ABSL_DCHECK(pool != nullptr); + ABSL_DCHECK(factory != nullptr); + if (&lhs == &rhs && lhs_field == rhs_field) { + return true; + } + // MessageEqualsState has quite a large size, so we allocate it on the heap. + // Ideally we should just hold most of the state at runtime in something like + // `FlatExpressionEvaluatorState`, so we can avoid allocating this repeatedly. + return std::make_unique(pool, factory) + ->FieldEquals(lhs, lhs_field, rhs, rhs_field); +} + +absl::StatusOr MessageFieldEquals( + const google::protobuf::Message& lhs, const google::protobuf::Message& rhs, + const google::protobuf::FieldDescriptor* absl_nonnull rhs_field, + const google::protobuf::DescriptorPool* absl_nonnull pool, + google::protobuf::MessageFactory* absl_nonnull factory) { + ABSL_DCHECK(rhs_field != nullptr); + ABSL_DCHECK(pool != nullptr); + ABSL_DCHECK(factory != nullptr); + // MessageEqualsState has quite a large size, so we allocate it on the heap. + // Ideally we should just hold most of the state at runtime in something like + // `FlatExpressionEvaluatorState`, so we can avoid allocating this repeatedly. + return std::make_unique(pool, factory) + ->FieldEquals(lhs, nullptr, rhs, rhs_field); +} + +absl::StatusOr MessageFieldEquals( + const google::protobuf::Message& lhs, + const google::protobuf::FieldDescriptor* absl_nonnull lhs_field, + const google::protobuf::Message& rhs, const google::protobuf::DescriptorPool* absl_nonnull pool, + google::protobuf::MessageFactory* absl_nonnull factory) { + ABSL_DCHECK(lhs_field != nullptr); + ABSL_DCHECK(pool != nullptr); + ABSL_DCHECK(factory != nullptr); + // MessageEqualsState has quite a large size, so we allocate it on the heap. + // Ideally we should just hold most of the state at runtime in something like + // `FlatExpressionEvaluatorState`, so we can avoid allocating this repeatedly. + return std::make_unique(pool, factory) + ->FieldEquals(lhs, lhs_field, rhs, nullptr); +} + +} // namespace cel::internal diff --git a/internal/message_equality.h b/internal/message_equality.h new file mode 100644 index 000000000..3f7fabd2c --- /dev/null +++ b/internal/message_equality.h @@ -0,0 +1,54 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_MESSAGE_EQUALITY_H_ +#define THIRD_PARTY_CEL_CPP_INTERNAL_MESSAGE_EQUALITY_H_ + +#include "absl/base/nullability.h" +#include "absl/status/statusor.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel::internal { + +// Tests whether one message is equal to another following CEL equality +// semantics. +absl::StatusOr MessageEquals( + const google::protobuf::Message& lhs, const google::protobuf::Message& rhs, + const google::protobuf::DescriptorPool* absl_nonnull pool, + google::protobuf::MessageFactory* absl_nonnull factory); + +// Tests whether one message field is equal to another following CEL equality +// semantics. +absl::StatusOr MessageFieldEquals( + const google::protobuf::Message& lhs, + const google::protobuf::FieldDescriptor* absl_nonnull lhs_field, + const google::protobuf::Message& rhs, + const google::protobuf::FieldDescriptor* absl_nonnull rhs_field, + const google::protobuf::DescriptorPool* absl_nonnull pool, + google::protobuf::MessageFactory* absl_nonnull factory); +absl::StatusOr MessageFieldEquals( + const google::protobuf::Message& lhs, const google::protobuf::Message& rhs, + const google::protobuf::FieldDescriptor* absl_nonnull rhs_field, + const google::protobuf::DescriptorPool* absl_nonnull pool, + google::protobuf::MessageFactory* absl_nonnull factory); +absl::StatusOr MessageFieldEquals( + const google::protobuf::Message& lhs, + const google::protobuf::FieldDescriptor* absl_nonnull lhs_field, + const google::protobuf::Message& rhs, const google::protobuf::DescriptorPool* absl_nonnull pool, + google::protobuf::MessageFactory* absl_nonnull factory); + +} // namespace cel::internal + +#endif // THIRD_PARTY_CEL_CPP_INTERNAL_MESSAGE_EQUALITY_H_ diff --git a/internal/message_equality_test.cc b/internal/message_equality_test.cc new file mode 100644 index 000000000..092edd71b --- /dev/null +++ b/internal/message_equality_test.cc @@ -0,0 +1,1055 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "internal/message_equality.h" + +#include +#include +#include + +#include "google/protobuf/any.pb.h" +#include "google/protobuf/duration.pb.h" +#include "google/protobuf/struct.pb.h" +#include "google/protobuf/timestamp.pb.h" +#include "google/protobuf/wrappers.pb.h" +#include "absl/base/no_destructor.h" +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/log/die_if_null.h" +#include "absl/status/status_matchers.h" +#include "absl/strings/cord.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "common/allocator.h" +#include "common/memory.h" +#include "internal/message_type_name.h" +#include "internal/parse_text_proto.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "internal/testing_message_factory.h" +#include "internal/well_known_types.h" +#include "cel/expr/conformance/proto3/test_all_types.pb.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel::internal { +namespace { + +using ::absl_testing::IsOkAndHolds; +using ::testing::IsFalse; +using ::testing::IsTrue; +using ::testing::TestParamInfo; +using ::testing::TestWithParam; +using ::testing::ValuesIn; + +using TestAllTypesProto3 = ::cel::expr::conformance::proto3::TestAllTypes; + +google::protobuf::Arena* GetTestArena() { + static absl::NoDestructor arena; + return &*arena; +} + +template +google::protobuf::Message* ParseTextProto(absl::string_view text) { + return DynamicParseTextProto(GetTestArena(), text, + GetTestingDescriptorPool(), + GetTestingMessageFactory()); +} + +struct UnaryMessageEqualsTestParam { + std::string name; + std::vector ops; + bool equal; +}; + +std::string UnaryMessageEqualsTestParamName( + const TestParamInfo& param_info) { + return param_info.param.name; +} + +using UnaryMessageEqualsTest = TestWithParam; + +google::protobuf::Message* PackMessage(const google::protobuf::Message& message) { + const auto* descriptor = + ABSL_DIE_IF_NULL(GetTestingDescriptorPool()->FindMessageTypeByName( + MessageTypeNameFor())); + const auto* prototype = + ABSL_DIE_IF_NULL(GetTestingMessageFactory()->GetPrototype(descriptor)); + auto instance = prototype->New(GetTestArena()); + auto reflection = well_known_types::GetAnyReflectionOrDie(descriptor); + reflection.SetTypeUrl( + cel::to_address(instance), + absl::StrCat("type.googleapis.com/", message.GetTypeName())); + absl::Cord value; + ABSL_CHECK(message.SerializeToString(&value)); + reflection.SetValue(cel::to_address(instance), value); + return instance; +} + +TEST_P(UnaryMessageEqualsTest, Equals) { + const auto* pool = GetTestingDescriptorPool(); + auto* factory = GetTestingMessageFactory(); + const auto& test_case = GetParam(); + for (const auto& lhs : test_case.ops) { + for (const auto& rhs : test_case.ops) { + if (!test_case.equal && &lhs == &rhs) { + continue; + } + EXPECT_THAT(MessageEquals(*lhs, *rhs, pool, factory), + IsOkAndHolds(test_case.equal)) + << lhs->ShortDebugString() << " " << rhs->ShortDebugString(); + EXPECT_THAT(MessageEquals(*rhs, *lhs, pool, factory), + IsOkAndHolds(test_case.equal)) + << lhs->ShortDebugString() << " " << rhs->ShortDebugString(); + // Test any. + auto lhs_any = PackMessage(*lhs); + auto rhs_any = PackMessage(*rhs); + EXPECT_THAT(MessageEquals(*lhs_any, *rhs, pool, factory), + IsOkAndHolds(test_case.equal)) + << lhs_any->ShortDebugString() << " " << rhs->ShortDebugString(); + EXPECT_THAT(MessageEquals(*lhs, *rhs_any, pool, factory), + IsOkAndHolds(test_case.equal)) + << lhs->ShortDebugString() << " " << rhs_any->ShortDebugString(); + EXPECT_THAT(MessageEquals(*lhs_any, *rhs_any, pool, factory), + IsOkAndHolds(test_case.equal)) + << lhs_any->ShortDebugString() << " " << rhs_any->ShortDebugString(); + } + } +} + +INSTANTIATE_TEST_SUITE_P( + UnaryMessageEqualsTest, UnaryMessageEqualsTest, + ValuesIn({ + { + .name = "NullValue_Equal", + .ops = + { + ParseTextProto(R"pb()pb"), + ParseTextProto( + R"pb(null_value: NULL_VALUE)pb"), + }, + .equal = true, + }, + { + .name = "BoolValue_False_Equal", + .ops = + { + ParseTextProto(R"pb()pb"), + ParseTextProto( + R"pb(value: false)pb"), + ParseTextProto( + R"pb(bool_value: false)pb"), + }, + .equal = true, + }, + { + .name = "BoolValue_True_Equal", + .ops = + { + ParseTextProto( + R"pb(value: true)pb"), + ParseTextProto(R"pb(bool_value: + true)pb"), + }, + .equal = true, + }, + { + .name = "StringValue_Empty_Equal", + .ops = + { + ParseTextProto(R"pb()pb"), + ParseTextProto( + R"pb(value: "")pb"), + ParseTextProto( + R"pb(string_value: "")pb"), + }, + .equal = true, + }, + { + .name = "StringValue_Equal", + .ops = + { + ParseTextProto( + R"pb(value: "foo")pb"), + ParseTextProto( + R"pb(string_value: "foo")pb"), + }, + .equal = true, + }, + { + .name = "BytesValue_Empty_Equal", + .ops = + { + ParseTextProto(R"pb()pb"), + ParseTextProto( + R"pb(value: "")pb"), + }, + .equal = true, + }, + { + .name = "BytesValue_Equal", + .ops = + { + ParseTextProto( + R"pb(value: "foo")pb"), + ParseTextProto( + R"pb(value: "foo")pb"), + }, + .equal = true, + }, + { + .name = "ListValue_Equal", + .ops = + { + ParseTextProto( + R"pb(list_value: { values { bool_value: true } })pb"), + ParseTextProto( + R"pb(values { bool_value: true })pb"), + }, + .equal = true, + }, + { + .name = "ListValue_NotEqual", + .ops = + { + ParseTextProto( + R"pb(list_value: { values { number_value: 0.0 } })pb"), + ParseTextProto( + R"pb(values { number_value: 1.0 })pb"), + ParseTextProto( + R"pb(list_value: { values { number_value: 2.0 } })pb"), + ParseTextProto( + R"pb(values { number_value: 3.0 })pb"), + }, + .equal = false, + }, + { + .name = "StructValue_Equal", + .ops = + { + ParseTextProto( + R"pb(struct_value: { + fields { + key: "foo" + value: { bool_value: true } + } + })pb"), + ParseTextProto( + R"pb(fields { + key: "foo" + value: { bool_value: true } + })pb"), + }, + .equal = true, + }, + { + .name = "StructValue_NotEqual", + .ops = + { + ParseTextProto( + R"pb(struct_value: { + fields { + key: "foo" + value: { number_value: 0.0 } + } + })pb"), + ParseTextProto( + R"pb( + fields { + key: "bar" + value: { number_value: 0.0 } + })pb"), + ParseTextProto( + R"pb(struct_value: { + fields { + key: "foo" + value: { number_value: 1.0 } + } + })pb"), + ParseTextProto( + R"pb( + fields { + key: "bar" + value: { number_value: 1.0 } + })pb"), + }, + .equal = false, + }, + { + .name = "Heterogeneous_Equal", + .ops = + { + ParseTextProto(R"pb()pb"), + ParseTextProto(R"pb()pb"), + ParseTextProto(R"pb()pb"), + ParseTextProto(R"pb()pb"), + ParseTextProto(R"pb()pb"), + ParseTextProto(R"pb()pb"), + ParseTextProto(R"pb(number_value: + 0.0)pb"), + }, + .equal = true, + }, + { + .name = "Message_Equals", + .ops = + { + ParseTextProto(R"pb()pb"), + ParseTextProto(R"pb()pb"), + }, + .equal = true, + }, + { + .name = "Heterogeneous_NotEqual", + .ops = + { + ParseTextProto( + R"pb(value: false)pb"), + ParseTextProto( + R"pb(value: 0)pb"), + ParseTextProto( + R"pb(value: 1)pb"), + ParseTextProto( + R"pb(value: 2)pb"), + ParseTextProto( + R"pb(value: 3)pb"), + ParseTextProto( + R"pb(value: 4.0)pb"), + ParseTextProto( + R"pb(value: 5.0)pb"), + ParseTextProto(R"pb()pb"), + ParseTextProto(R"pb(bool_value: + true)pb"), + ParseTextProto(R"pb(number_value: + 6.0)pb"), + ParseTextProto( + R"pb(string_value: "bar")pb"), + ParseTextProto( + R"pb(value: "foo")pb"), + ParseTextProto( + R"pb(value: "")pb"), + ParseTextProto( + R"pb(value: "foo")pb"), + ParseTextProto( + R"pb(list_value: {})pb"), + ParseTextProto( + R"pb(values { bool_value: true })pb"), + ParseTextProto(R"pb(struct_value: + {})pb"), + ParseTextProto( + R"pb(fields { + key: "foo" + value: { bool_value: false } + })pb"), + ParseTextProto(R"pb()pb"), + ParseTextProto( + R"pb(seconds: 1 nanos: 1)pb"), + ParseTextProto(R"pb()pb"), + ParseTextProto( + R"pb(seconds: 1 nanos: 1)pb"), + ParseTextProto(R"pb()pb"), + ParseTextProto( + R"pb(single_bool: true)pb"), + }, + .equal = false, + }, + }), + UnaryMessageEqualsTestParamName); + +struct UnaryMessageFieldEqualsTestParam { + std::string name; + std::string message; + std::vector fields; + bool equal; +}; + +std::string UnaryMessageFieldEqualsTestParamName( + const TestParamInfo& param_info) { + return param_info.param.name; +} + +using UnaryMessageFieldEqualsTest = + TestWithParam; + +void PackMessageTo(const google::protobuf::Message& message, google::protobuf::Message* instance) { + auto reflection = + *well_known_types::GetAnyReflection(instance->GetDescriptor()); + reflection.SetTypeUrl( + instance, absl::StrCat("type.googleapis.com/", message.GetTypeName())); + absl::Cord value; + ABSL_CHECK(message.SerializeToString(&value)); + reflection.SetValue(instance, value); +} + +absl::optional, + const google::protobuf::FieldDescriptor* absl_nonnull>> +PackTestAllTypesProto3Field(const google::protobuf::Message& message, + const google::protobuf::FieldDescriptor* absl_nonnull field) { + if (field->is_map()) { + return std::nullopt; + } + if (field->is_repeated() && + field->type() == google::protobuf::FieldDescriptor::TYPE_MESSAGE) { + const auto* descriptor = message.GetDescriptor(); + const auto* any_field = descriptor->FindFieldByName("repeated_any"); + auto packed = WrapShared(message.New(), NewDeleteAllocator<>{}); + const int size = message.GetReflection()->FieldSize(message, field); + for (int i = 0; i < size; ++i) { + PackMessageTo( + message.GetReflection()->GetRepeatedMessage(message, field, i), + packed->GetReflection()->AddMessage(cel::to_address(packed), + any_field)); + } + return std::pair{packed, any_field}; + } + if (!field->is_repeated() && + field->type() == google::protobuf::FieldDescriptor::TYPE_MESSAGE) { + const auto* descriptor = message.GetDescriptor(); + const auto* any_field = descriptor->FindFieldByName("single_any"); + auto packed = WrapShared(message.New(), NewDeleteAllocator<>{}); + PackMessageTo(message.GetReflection()->GetMessage(message, field), + packed->GetReflection()->MutableMessage( + cel::to_address(packed), any_field)); + return std::pair{packed, any_field}; + } + return std::nullopt; +} + +TEST_P(UnaryMessageFieldEqualsTest, Equals) { + // We perform exhaustive comparison by testing for equality (or inequality) + // against all combinations of fields. Additionally we convert to + // `google.protobuf.Any` where applicable. This is all done for coverage and + // to ensure different combinations, regardless of argument order, produce the + // same result. + + const auto* pool = GetTestingDescriptorPool(); + auto* factory = GetTestingMessageFactory(); + const auto& test_case = GetParam(); + auto lhs_message = ParseTextProto(test_case.message); + auto rhs_message = ParseTextProto(test_case.message); + const auto* descriptor = ABSL_DIE_IF_NULL( + pool->FindMessageTypeByName(MessageTypeNameFor())); + for (const auto& lhs : test_case.fields) { + for (const auto& rhs : test_case.fields) { + if (!test_case.equal && lhs == rhs) { + // When testing for inequality, do not compare the same field to itself. + continue; + } + const auto* lhs_field = + ABSL_DIE_IF_NULL(descriptor->FindFieldByName(lhs)); + const auto* rhs_field = + ABSL_DIE_IF_NULL(descriptor->FindFieldByName(rhs)); + EXPECT_THAT(MessageFieldEquals(*lhs_message, lhs_field, *rhs_message, + rhs_field, pool, factory), + IsOkAndHolds(test_case.equal)) + << lhs_message->ShortDebugString() << " " << lhs_field->name() << " " + << rhs_message->ShortDebugString() << " " << rhs_field->name(); + EXPECT_THAT(MessageFieldEquals(*rhs_message, rhs_field, *lhs_message, + lhs_field, pool, factory), + IsOkAndHolds(test_case.equal)) + << lhs_message->ShortDebugString() << " " << lhs_field->name() << " " + << rhs_message->ShortDebugString() << " " << rhs_field->name(); + if (!lhs_field->is_repeated() && + lhs_field->type() == google::protobuf::FieldDescriptor::TYPE_MESSAGE) { + EXPECT_THAT(MessageFieldEquals(lhs_message->GetReflection()->GetMessage( + *lhs_message, lhs_field), + *rhs_message, rhs_field, pool, factory), + IsOkAndHolds(test_case.equal)) + << lhs_message->ShortDebugString() << " " << lhs_field->name() + << " " << rhs_message->ShortDebugString() << " " + << rhs_field->name(); + EXPECT_THAT(MessageFieldEquals(*rhs_message, rhs_field, + lhs_message->GetReflection()->GetMessage( + *lhs_message, lhs_field), + pool, factory), + IsOkAndHolds(test_case.equal)) + << lhs_message->ShortDebugString() << " " << lhs_field->name() + << " " << rhs_message->ShortDebugString() << " " + << rhs_field->name(); + } + if (!rhs_field->is_repeated() && + rhs_field->type() == google::protobuf::FieldDescriptor::TYPE_MESSAGE) { + EXPECT_THAT(MessageFieldEquals(*lhs_message, lhs_field, + rhs_message->GetReflection()->GetMessage( + *rhs_message, rhs_field), + pool, factory), + IsOkAndHolds(test_case.equal)) + << lhs_message->ShortDebugString() << " " << lhs_field->name() + << " " << rhs_message->ShortDebugString() << " " + << rhs_field->name(); + EXPECT_THAT(MessageFieldEquals(rhs_message->GetReflection()->GetMessage( + *rhs_message, rhs_field), + *lhs_message, lhs_field, pool, factory), + IsOkAndHolds(test_case.equal)) + << lhs_message->ShortDebugString() << " " << lhs_field->name() + << " " << rhs_message->ShortDebugString() << " " + << rhs_field->name(); + } + // Test `google.protobuf.Any`. + absl::optional, + const google::protobuf::FieldDescriptor* absl_nonnull>> + lhs_any = PackTestAllTypesProto3Field(*lhs_message, lhs_field); + absl::optional, + const google::protobuf::FieldDescriptor* absl_nonnull>> + rhs_any = PackTestAllTypesProto3Field(*rhs_message, rhs_field); + if (lhs_any) { + EXPECT_THAT(MessageFieldEquals(*lhs_any->first, lhs_any->second, + *rhs_message, rhs_field, pool, factory), + IsOkAndHolds(test_case.equal)) + << lhs_any->first->ShortDebugString() << " " + << rhs_message->ShortDebugString(); + if (!lhs_any->second->is_repeated()) { + EXPECT_THAT( + MessageFieldEquals(lhs_any->first->GetReflection()->GetMessage( + *lhs_any->first, lhs_any->second), + *rhs_message, rhs_field, pool, factory), + IsOkAndHolds(test_case.equal)) + << lhs_any->first->ShortDebugString() << " " + << rhs_message->ShortDebugString(); + } + } + if (rhs_any) { + EXPECT_THAT(MessageFieldEquals(*lhs_message, lhs_field, *rhs_any->first, + rhs_any->second, pool, factory), + IsOkAndHolds(test_case.equal)) + << lhs_message->ShortDebugString() << " " + << rhs_any->first->ShortDebugString(); + if (!rhs_any->second->is_repeated()) { + EXPECT_THAT( + MessageFieldEquals(*lhs_message, lhs_field, + rhs_any->first->GetReflection()->GetMessage( + *rhs_any->first, rhs_any->second), + pool, factory), + IsOkAndHolds(test_case.equal)) + << lhs_message->ShortDebugString() << " " + << rhs_any->first->ShortDebugString(); + } + } + if (lhs_any && rhs_any) { + EXPECT_THAT( + MessageFieldEquals(*lhs_any->first, lhs_any->second, + *rhs_any->first, rhs_any->second, pool, factory), + IsOkAndHolds(test_case.equal)) + << lhs_any->first->ShortDebugString() << " " + << rhs_any->first->ShortDebugString(); + } + } + } +} + +INSTANTIATE_TEST_SUITE_P( + UnaryMessageFieldEqualsTest, UnaryMessageFieldEqualsTest, + ValuesIn({ + { + .name = "Heterogeneous_Single_Equal", + .message = R"pb( + single_int32: 1 + single_int64: 1 + single_uint32: 1 + single_uint64: 1 + single_float: 1 + single_double: 1 + single_value: { number_value: 1 } + single_int32_wrapper: { value: 1 } + single_int64_wrapper: { value: 1 } + single_uint32_wrapper: { value: 1 } + single_uint64_wrapper: { value: 1 } + single_float_wrapper: { value: 1 } + single_double_wrapper: { value: 1 } + standalone_enum: BAR + )pb", + .fields = + { + "single_int32", + "single_int64", + "single_uint32", + "single_uint64", + "single_float", + "single_double", + "single_value", + "single_int32_wrapper", + "single_int64_wrapper", + "single_uint32_wrapper", + "single_uint64_wrapper", + "single_float_wrapper", + "single_double_wrapper", + "standalone_enum", + }, + .equal = true, + }, + { + .name = "Heterogeneous_Single_NotEqual", + .message = R"pb( + null_value: NULL_VALUE + single_bool: false + single_int32: 2 + single_int64: 3 + single_uint32: 4 + single_uint64: 5 + single_float: NaN + single_double: NaN + single_string: "foo" + single_bytes: "foo" + single_value: { number_value: 8 } + single_int32_wrapper: { value: 9 } + single_int64_wrapper: { value: 10 } + single_uint32_wrapper: { value: 11 } + single_uint64_wrapper: { value: 12 } + single_float_wrapper: { value: 13 } + single_double_wrapper: { value: 14 } + single_string_wrapper: { value: "bar" } + single_bytes_wrapper: { value: "bar" } + standalone_enum: BAR + )pb", + .fields = + { + "null_value", + "single_bool", + "single_int32", + "single_int64", + "single_uint32", + "single_uint64", + "single_float", + "single_double", + "single_string", + "single_bytes", + "single_value", + "single_int32_wrapper", + "single_int64_wrapper", + "single_uint32_wrapper", + "single_uint64_wrapper", + "single_float_wrapper", + "single_double_wrapper", + "standalone_enum", + }, + .equal = false, + }, + { + .name = "Heterogeneous_Repeated_Equal", + .message = R"pb( + repeated_int32: 1 + repeated_int64: 1 + repeated_uint32: 1 + repeated_uint64: 1 + repeated_float: 1 + repeated_double: 1 + repeated_value: { number_value: 1 } + repeated_int32_wrapper: { value: 1 } + repeated_int64_wrapper: { value: 1 } + repeated_uint32_wrapper: { value: 1 } + repeated_uint64_wrapper: { value: 1 } + repeated_float_wrapper: { value: 1 } + repeated_double_wrapper: { value: 1 } + repeated_nested_enum: BAR + single_value: { list_value: { values { number_value: 1 } } } + list_value: { values { number_value: 1 } } + )pb", + .fields = + { + "repeated_int32", + "repeated_int64", + "repeated_uint32", + "repeated_uint64", + "repeated_float", + "repeated_double", + "repeated_value", + "repeated_int32_wrapper", + "repeated_int64_wrapper", + "repeated_uint32_wrapper", + "repeated_uint64_wrapper", + "repeated_float_wrapper", + "repeated_double_wrapper", + "repeated_nested_enum", + "single_value", + "list_value", + }, + .equal = true, + }, + { + .name = "Heterogeneous_Repeated_NotEqual", + .message = R"pb( + repeated_null_value: NULL_VALUE + repeated_bool: false + repeated_int32: 2 + repeated_int64: 3 + repeated_uint32: 4 + repeated_uint64: 5 + repeated_float: 6 + repeated_double: 7 + repeated_string: "foo" + repeated_bytes: "foo" + repeated_value: { number_value: 8 } + repeated_int32_wrapper: { value: 9 } + repeated_int64_wrapper: { value: 10 } + repeated_uint32_wrapper: { value: 11 } + repeated_uint64_wrapper: { value: 12 } + repeated_float_wrapper: { value: 13 } + repeated_double_wrapper: { value: 14 } + repeated_string_wrapper: { value: "bar" } + repeated_bytes_wrapper: { value: "bar" } + repeated_nested_enum: BAR + )pb", + .fields = + { + "repeated_null_value", + "repeated_bool", + "repeated_int32", + "repeated_int64", + "repeated_uint32", + "repeated_uint64", + "repeated_float", + "repeated_double", + "repeated_string", + "repeated_bytes", + "repeated_value", + "repeated_int32_wrapper", + "repeated_int64_wrapper", + "repeated_uint32_wrapper", + "repeated_uint64_wrapper", + "repeated_float_wrapper", + "repeated_double_wrapper", + "repeated_nested_enum", + }, + .equal = false, + }, + { + .name = "Heterogeneous_Map_Equal", + .message = R"pb( + map_int32_int32 { key: 1 value: 1 } + map_int32_uint32 { key: 1 value: 1 } + map_int32_int64 { key: 1 value: 1 } + map_int32_uint64 { key: 1 value: 1 } + map_int32_float { key: 1 value: 1 } + map_int32_double { key: 1 value: 1 } + map_int32_enum { key: 1 value: BAR } + map_int32_value { + key: 1 + value: { number_value: 1 } + } + map_int32_int32_wrapper { + key: 1 + value: { value: 1 } + } + map_int32_uint32_wrapper { + key: 1 + value: { value: 1 } + } + map_int32_int64_wrapper { + key: 1 + value: { value: 1 } + } + map_int32_uint64_wrapper { + key: 1 + value: { value: 1 } + } + map_int32_float_wrapper { + key: 1 + value: { value: 1 } + } + map_int32_double_wrapper { + key: 1 + value: { value: 1 } + } + map_int64_int32 { key: 1 value: 1 } + map_int64_uint32 { key: 1 value: 1 } + map_int64_int64 { key: 1 value: 1 } + map_int64_uint64 { key: 1 value: 1 } + map_int64_float { key: 1 value: 1 } + map_int64_double { key: 1 value: 1 } + map_int64_enum { key: 1 value: BAR } + map_int64_value { + key: 1 + value: { number_value: 1 } + } + map_int64_int32_wrapper { + key: 1 + value: { value: 1 } + } + map_int64_uint32_wrapper { + key: 1 + value: { value: 1 } + } + map_int64_int64_wrapper { + key: 1 + value: { value: 1 } + } + map_int64_uint64_wrapper { + key: 1 + value: { value: 1 } + } + map_int64_float_wrapper { + key: 1 + value: { value: 1 } + } + map_int64_double_wrapper { + key: 1 + value: { value: 1 } + } + map_uint32_int32 { key: 1 value: 1 } + map_uint32_uint32 { key: 1 value: 1 } + map_uint32_int64 { key: 1 value: 1 } + map_uint32_uint64 { key: 1 value: 1 } + map_uint32_float { key: 1 value: 1 } + map_uint32_double { key: 1 value: 1 } + map_uint32_enum { key: 1 value: BAR } + map_uint32_value { + key: 1 + value: { number_value: 1 } + } + map_uint32_int32_wrapper { + key: 1 + value: { value: 1 } + } + map_uint32_uint32_wrapper { + key: 1 + value: { value: 1 } + } + map_uint32_int64_wrapper { + key: 1 + value: { value: 1 } + } + map_uint32_uint64_wrapper { + key: 1 + value: { value: 1 } + } + map_uint32_float_wrapper { + key: 1 + value: { value: 1 } + } + map_uint32_double_wrapper { + key: 1 + value: { value: 1 } + } + map_uint64_int32 { key: 1 value: 1 } + map_uint64_uint32 { key: 1 value: 1 } + map_uint64_int64 { key: 1 value: 1 } + map_uint64_uint64 { key: 1 value: 1 } + map_uint64_float { key: 1 value: 1 } + map_uint64_double { key: 1 value: 1 } + map_uint64_enum { key: 1 value: BAR } + map_uint64_value { + key: 1 + value: { number_value: 1 } + } + map_uint64_int32_wrapper { + key: 1 + value: { value: 1 } + } + map_uint64_uint32_wrapper { + key: 1 + value: { value: 1 } + } + map_uint64_int64_wrapper { + key: 1 + value: { value: 1 } + } + map_uint64_uint64_wrapper { + key: 1 + value: { value: 1 } + } + map_uint64_float_wrapper { + key: 1 + value: { value: 1 } + } + map_uint64_double_wrapper { + key: 1 + value: { value: 1 } + } + )pb", + .fields = + { + "map_int32_int32", "map_int32_uint32", + "map_int32_int64", "map_int32_uint64", + "map_int32_float", "map_int32_double", + "map_int32_enum", "map_int32_value", + "map_int32_int32_wrapper", "map_int32_uint32_wrapper", + "map_int32_int64_wrapper", "map_int32_uint64_wrapper", + "map_int32_float_wrapper", "map_int32_double_wrapper", + "map_int64_int32", "map_int64_uint32", + "map_int64_int64", "map_int64_uint64", + "map_int64_float", "map_int64_double", + "map_int64_enum", "map_int64_value", + "map_int64_int32_wrapper", "map_int64_uint32_wrapper", + "map_int64_int64_wrapper", "map_int64_uint64_wrapper", + "map_int64_float_wrapper", "map_int64_double_wrapper", + "map_uint32_int32", "map_uint32_uint32", + "map_uint32_int64", "map_uint32_uint64", + "map_uint32_float", "map_uint32_double", + "map_uint32_enum", "map_uint32_value", + "map_uint32_int32_wrapper", "map_uint32_uint32_wrapper", + "map_uint32_int64_wrapper", "map_uint32_uint64_wrapper", + "map_uint32_float_wrapper", "map_uint32_double_wrapper", + "map_uint64_int32", "map_uint64_uint32", + "map_uint64_int64", "map_uint64_uint64", + "map_uint64_float", "map_uint64_double", + "map_uint64_enum", "map_uint64_value", + "map_uint64_int32_wrapper", "map_uint64_uint32_wrapper", + "map_uint64_int64_wrapper", "map_uint64_uint64_wrapper", + "map_uint64_float_wrapper", "map_uint64_double_wrapper", + }, + .equal = true, + }, + { + .name = "Heterogeneous_Map_NotEqual", + .message = R"pb( + map_bool_bool { key: false value: false } + map_bool_int32 { key: false value: 1 } + map_bool_uint32 { key: false value: 0 } + map_int32_int32 { key: 0x7FFFFFFF value: 1 } + map_int64_int64 { key: 0x7FFFFFFFFFFFFFFF value: 1 } + map_uint32_uint32 { key: 0xFFFFFFFF value: 1 } + map_uint64_uint64 { key: 0xFFFFFFFFFFFFFFFF value: 1 } + map_string_string { key: "foo" value: "bar" } + map_string_bytes { key: "foo" value: "bar" } + map_int32_bytes { key: -2147483648 value: "bar" } + map_int64_bytes { key: -9223372036854775808 value: "bar" } + map_int32_float { key: -2147483648 value: 1 } + map_int64_double { key: -9223372036854775808 value: 1 } + map_uint32_string { key: 0xFFFFFFFF value: "bar" } + map_uint64_string { key: 0xFFFFFFFF value: "foo" } + map_uint32_bytes { key: 0xFFFFFFFF value: "bar" } + map_uint64_bytes { key: 0xFFFFFFFF value: "foo" } + map_uint32_bool { key: 0xFFFFFFFF value: false } + map_uint64_bool { key: 0xFFFFFFFF value: true } + single_value: { + struct_value: { + fields { + key: "bar" + value: { string_value: "foo" } + } + } + } + single_struct: { + fields { + key: "baz" + value: { string_value: "foo" } + } + } + standalone_message: {} + )pb", + .fields = + { + "map_bool_bool", "map_bool_int32", + "map_bool_uint32", "map_int32_int32", + "map_int64_int64", "map_uint32_uint32", + "map_uint64_uint64", "map_string_string", + "map_string_bytes", "map_int32_bytes", + "map_int64_bytes", "map_int32_float", + "map_int64_double", "map_uint32_string", + "map_uint64_string", "map_uint32_bytes", + "map_uint64_bytes", "map_uint32_bool", + "map_uint64_bool", "single_value", + "single_struct", "standalone_message", + }, + .equal = false, + }, + }), + UnaryMessageFieldEqualsTestParamName); + +TEST(MessageEquals, AnyFallback) { + const auto* pool = GetTestingDescriptorPool(); + auto* factory = GetTestingMessageFactory(); + google::protobuf::Arena arena; + auto message1 = DynamicParseTextProto( + &arena, R"pb(single_any: { + type_url: "type.googleapis.com/message.that.does.not.Exist" + value: "foo" + })pb", + pool, factory); + auto message2 = DynamicParseTextProto( + &arena, R"pb(single_any: { + type_url: "type.googleapis.com/message.that.does.not.Exist" + value: "foo" + })pb", + pool, factory); + auto message3 = DynamicParseTextProto( + &arena, R"pb(single_any: { + type_url: "type.googleapis.com/message.that.does.not.Exist" + value: "bar" + })pb", + pool, factory); + EXPECT_THAT(MessageEquals(*message1, *message2, pool, factory), + IsOkAndHolds(IsTrue())); + EXPECT_THAT(MessageEquals(*message2, *message1, pool, factory), + IsOkAndHolds(IsTrue())); + EXPECT_THAT(MessageEquals(*message1, *message3, pool, factory), + IsOkAndHolds(IsFalse())); + EXPECT_THAT(MessageEquals(*message3, *message1, pool, factory), + IsOkAndHolds(IsFalse())); +} + +TEST(MessageFieldEquals, AnyFallback) { + const auto* pool = GetTestingDescriptorPool(); + auto* factory = GetTestingMessageFactory(); + google::protobuf::Arena arena; + auto message1 = DynamicParseTextProto( + &arena, R"pb(single_any: { + type_url: "type.googleapis.com/message.that.does.not.Exist" + value: "foo" + })pb", + pool, factory); + auto message2 = DynamicParseTextProto( + &arena, R"pb(single_any: { + type_url: "type.googleapis.com/message.that.does.not.Exist" + value: "foo" + })pb", + pool, factory); + auto message3 = DynamicParseTextProto( + &arena, R"pb(single_any: { + type_url: "type.googleapis.com/message.that.does.not.Exist" + value: "bar" + })pb", + pool, factory); + EXPECT_THAT(MessageFieldEquals( + *message1, + ABSL_DIE_IF_NULL( + message1->GetDescriptor()->FindFieldByName("single_any")), + *message2, + ABSL_DIE_IF_NULL( + message2->GetDescriptor()->FindFieldByName("single_any")), + pool, factory), + IsOkAndHolds(IsTrue())); + EXPECT_THAT(MessageFieldEquals( + *message2, + ABSL_DIE_IF_NULL( + message2->GetDescriptor()->FindFieldByName("single_any")), + *message1, + ABSL_DIE_IF_NULL( + message1->GetDescriptor()->FindFieldByName("single_any")), + pool, factory), + IsOkAndHolds(IsTrue())); + EXPECT_THAT(MessageFieldEquals( + *message1, + ABSL_DIE_IF_NULL( + message1->GetDescriptor()->FindFieldByName("single_any")), + *message3, + ABSL_DIE_IF_NULL( + message3->GetDescriptor()->FindFieldByName("single_any")), + pool, factory), + IsOkAndHolds(IsFalse())); + EXPECT_THAT(MessageFieldEquals( + *message3, + ABSL_DIE_IF_NULL( + message3->GetDescriptor()->FindFieldByName("single_any")), + *message1, + ABSL_DIE_IF_NULL( + message1->GetDescriptor()->FindFieldByName("single_any")), + pool, factory), + IsOkAndHolds(IsFalse())); +} + +} // namespace +} // namespace cel::internal diff --git a/internal/message_type_name.h b/internal/message_type_name.h new file mode 100644 index 000000000..c496f3b22 --- /dev/null +++ b/internal/message_type_name.h @@ -0,0 +1,56 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_MESSAGE_TYPE_NAME_H_ +#define THIRD_PARTY_CEL_CPP_INTERNAL_MESSAGE_TYPE_NAME_H_ + +#include +#include + +#include "absl/base/no_destructor.h" +#include "absl/strings/string_view.h" +#include "google/protobuf/message.h" +#include "google/protobuf/message_lite.h" + +namespace cel::internal { + +// MessageTypeNameFor returns the fully qualified message type name of a +// generated message. This is a portable version which works with the lite +// runtime as well. + +template +std::enable_if_t< + std::conjunction_v, + std::negation>>, + absl::string_view> +MessageTypeNameFor() { + static_assert(!std::is_const_v, "T must not be const qualified"); + static_assert(!std::is_volatile_v, "T must not be volatile qualified"); + static_assert(!std::is_reference_v, "T must not be a reference"); + static const absl::NoDestructor kTypeName(T().GetTypeName()); + return *kTypeName; +} + +template +std::enable_if_t, absl::string_view> +MessageTypeNameFor() { + static_assert(!std::is_const_v, "T must not be const qualified"); + static_assert(!std::is_volatile_v, "T must not be volatile qualified"); + static_assert(!std::is_reference_v, "T must not be a reference"); + return T::descriptor()->full_name(); +} + +} // namespace cel::internal + +#endif // THIRD_PARTY_CEL_CPP_INTERNAL_MESSAGE_TYPE_NAME_H_ diff --git a/internal/message_type_name_test.cc b/internal/message_type_name_test.cc new file mode 100644 index 000000000..2abc7eed9 --- /dev/null +++ b/internal/message_type_name_test.cc @@ -0,0 +1,28 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "internal/message_type_name.h" + +#include "google/protobuf/any.pb.h" +#include "internal/testing.h" + +namespace cel::internal { +namespace { + +TEST(MessageTypeNameFor, Generated) { + EXPECT_EQ(MessageTypeNameFor(), "google.protobuf.Any"); +} + +} // namespace +} // namespace cel::internal diff --git a/internal/minimal_descriptor_database.h b/internal/minimal_descriptor_database.h new file mode 100644 index 000000000..03e94b168 --- /dev/null +++ b/internal/minimal_descriptor_database.h @@ -0,0 +1,32 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_MINIMAL_DESCRIPTOR_DATABASE_H_ +#define THIRD_PARTY_CEL_CPP_INTERNAL_MINIMAL_DESCRIPTOR_DATABASE_H_ + +#include "absl/base/nullability.h" +#include "google/protobuf/descriptor_database.h" + +namespace cel::internal { + +// GetMinimalDescriptorDatabase returns a pointer to a +// `google::protobuf::DescriptorDatabase` which includes has the minimally necessary +// descriptors required by the Common Expression Language. The returning +// `proto2::DescripDescriptorDatabasetorPool` is valid for the lifetime of the +// process. +google::protobuf::DescriptorDatabase* absl_nonnull GetMinimalDescriptorDatabase(); + +} // namespace cel::internal + +#endif // THIRD_PARTY_CEL_CPP_INTERNAL_MINIMAL_DESCRIPTOR_DATABASE_H_ diff --git a/internal/minimal_descriptor_pool.h b/internal/minimal_descriptor_pool.h new file mode 100644 index 000000000..c7cb6946d --- /dev/null +++ b/internal/minimal_descriptor_pool.h @@ -0,0 +1,40 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_MINIMAL_DESCRIPTOR_POOL_H_ +#define THIRD_PARTY_CEL_CPP_INTERNAL_MINIMAL_DESCRIPTOR_POOL_H_ + +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "google/protobuf/descriptor.h" + +namespace cel::internal { + +// GetMinimalDescriptorPool returns a pointer to a `google::protobuf::DescriptorPool` +// which includes has the minimally necessary descriptors required by the Common +// Expression Language. The returning `google::protobuf::DescriptorPool` is valid for the +// lifetime of the process. +// +// This descriptor pool can be used as an underlay for another descriptor pool: +// +// google::protobuf::DescriptorPool my_descriptor_pool(GetMinimalDescriptorPool()); +const google::protobuf::DescriptorPool* absl_nonnull GetMinimalDescriptorPool(); + +// If required, adds the minimally required descriptors to the pool. +absl::Status AddMinimumRequiredDescriptorsToPool( + google::protobuf::DescriptorPool* absl_nonnull pool); + +} // namespace cel::internal + +#endif // THIRD_PARTY_CEL_CPP_INTERNAL_MINIMAL_DESCRIPTOR_POOL_H_ diff --git a/internal/minimal_descriptors.cc b/internal/minimal_descriptors.cc new file mode 100644 index 000000000..f0b96e838 --- /dev/null +++ b/internal/minimal_descriptors.cc @@ -0,0 +1,114 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "google/protobuf/descriptor.pb.h" +#include "absl/base/attributes.h" +#include "absl/base/macros.h" +#include "absl/base/no_destructor.h" +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "internal/minimal_descriptor_database.h" +#include "internal/minimal_descriptor_pool.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/descriptor_database.h" + +namespace cel::internal { + +namespace { + +ABSL_CONST_INIT const uint8_t kMinimalDescriptorSet[] = { +#include "internal/minimal_descriptor_set_embed.inc" +}; + +const google::protobuf::FileDescriptorSet* GetMinimumFileDescriptorSet() { + static google::protobuf::FileDescriptorSet* const file_desc_set = []() { + google::protobuf::FileDescriptorSet* file_desc_set = new google::protobuf::FileDescriptorSet(); + ABSL_CHECK(file_desc_set->ParseFromArray( // Crash OK + kMinimalDescriptorSet, ABSL_ARRAYSIZE(kMinimalDescriptorSet))); + return file_desc_set; + }(); + return file_desc_set; +} + +} // namespace + +const google::protobuf::DescriptorPool* absl_nonnull GetMinimalDescriptorPool() { + static const google::protobuf::DescriptorPool* absl_nonnull const pool = []() { + const google::protobuf::FileDescriptorSet* file_desc_set = + GetMinimumFileDescriptorSet(); + auto* pool = new google::protobuf::DescriptorPool(); + for (const auto& file_desc : file_desc_set->file()) { + ABSL_CHECK(pool->BuildFile(file_desc) != nullptr); // Crash OK + } + return pool; + }(); + return pool; +} + +google::protobuf::DescriptorDatabase* absl_nonnull GetMinimalDescriptorDatabase() { + static absl::NoDestructor database( + *GetMinimalDescriptorPool()); + return &*database; +} + +namespace { + +class DescriptorErrorCollector final + : public google::protobuf::DescriptorPool::ErrorCollector { + public: + void RecordError(absl::string_view, absl::string_view element_name, + const google::protobuf::Message*, ErrorLocation, + absl::string_view message) override { + errors_.push_back(absl::StrCat(element_name, ": ", message)); + } + + bool FoundErrors() const { return !errors_.empty(); } + + std::string FormatErrors() const { return absl::StrJoin(errors_, "\n\t"); } + + private: + std::vector errors_; +}; + +} // namespace + +absl::Status AddMinimumRequiredDescriptorsToPool( + google::protobuf::DescriptorPool* absl_nonnull pool) { + const google::protobuf::FileDescriptorSet* file_desc_set = + GetMinimumFileDescriptorSet(); + for (const auto& file_desc : file_desc_set->file()) { + if (pool->FindFileByName(file_desc.name()) != nullptr) { + continue; + } + DescriptorErrorCollector error_collector; + if (pool->BuildFileCollectingErrors(file_desc, &error_collector) == + nullptr) { + ABSL_DCHECK(error_collector.FoundErrors()); + return absl::UnknownError( + absl::StrCat("Failed to build file descriptor for ", file_desc.name(), + ":\n\t", error_collector.FormatErrors())); + } + } + return absl::OkStatus(); +} + +} // namespace cel::internal diff --git a/internal/names.cc b/internal/names.cc new file mode 100644 index 000000000..c1e32fad7 --- /dev/null +++ b/internal/names.cc @@ -0,0 +1,35 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "internal/names.h" + +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "internal/lexis.h" + +namespace cel::internal { + +bool IsValidRelativeName(absl::string_view name) { + if (name.empty()) { + return false; + } + for (const auto& id : absl::StrSplit(name, '.')) { + if (!LexisIsIdentifier(id)) { + return false; + } + } + return true; +} + +} // namespace cel::internal diff --git a/internal/names.h b/internal/names.h new file mode 100644 index 000000000..e9e7879d7 --- /dev/null +++ b/internal/names.h @@ -0,0 +1,26 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_NAMES_H_ +#define THIRD_PARTY_CEL_CPP_INTERNAL_NAMES_H_ + +#include "absl/strings/string_view.h" + +namespace cel::internal { + +bool IsValidRelativeName(absl::string_view name); + +} // namespace cel::internal + +#endif // THIRD_PARTY_CEL_CPP_INTERNAL_NAMES_H_ diff --git a/internal/names_test.cc b/internal/names_test.cc new file mode 100644 index 000000000..45315cf26 --- /dev/null +++ b/internal/names_test.cc @@ -0,0 +1,50 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "internal/names.h" + +#include "internal/testing.h" + +namespace cel::internal { +namespace { + +struct NamesTestCase final { + absl::string_view text; + bool ok; +}; + +using IsValidRelativeNameTest = testing::TestWithParam; + +TEST_P(IsValidRelativeNameTest, Compliance) { + const NamesTestCase& test_case = GetParam(); + if (test_case.ok) { + EXPECT_TRUE(IsValidRelativeName(test_case.text)); + } else { + EXPECT_FALSE(IsValidRelativeName(test_case.text)); + } +} + +INSTANTIATE_TEST_SUITE_P(IsValidRelativeNameTest, IsValidRelativeNameTest, + testing::ValuesIn({{"foo", true}, + {"foo.Bar", true}, + {"", false}, + {".", false}, + {".foo", false}, + {".foo.Bar", false}, + {"foo..Bar", false}, + {"foo.Bar.", + false}})); + +} // namespace +} // namespace cel::internal diff --git a/internal/new.cc b/internal/new.cc new file mode 100644 index 000000000..31ec82a08 --- /dev/null +++ b/internal/new.cc @@ -0,0 +1,142 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "internal/new.h" + +#include +#include +#include +#include + +#ifdef _MSC_VER +#include +#endif + +#include "absl/base/config.h" +#include "absl/base/optimization.h" +#include "absl/log/absl_check.h" +#include "absl/numeric/bits.h" +#include "internal/align.h" + +#if defined(__cpp_aligned_new) && __cpp_aligned_new >= 201606L +#define CEL_INTERNAL_HAVE_ALIGNED_NEW 1 +#endif + +#if defined(__cpp_sized_deallocation) && __cpp_sized_deallocation >= 201309L +#define CEL_INTERNAL_HAVE_SIZED_DELETE 1 +#endif + +namespace cel::internal { + +namespace { + +[[noreturn, maybe_unused]] void ThrowStdBadAlloc() { +#ifdef ABSL_HAVE_EXCEPTIONS + throw std::bad_alloc(); +#else + std::abort(); +#endif +} + +} // namespace + +void* New(size_t size) { return ::operator new(size); } + +void* AlignedNew(size_t size, std::align_val_t alignment) { + ABSL_DCHECK(absl::has_single_bit(static_cast(alignment))); +#ifdef CEL_INTERNAL_HAVE_ALIGNED_NEW + return ::operator new(size, alignment); +#else + if (static_cast(alignment) <= kDefaultNewAlignment) { + return New(size); + } +#if defined(_MSC_VER) + void* ptr = _aligned_malloc(size, static_cast(alignment)); + if (ABSL_PREDICT_FALSE(size != 0 && ptr == nullptr)) { + ThrowStdBadAlloc(); + } + return ptr; +#elif defined(__APPLE__) + void* ptr; + if (ABSL_PREDICT_FALSE( + posix_memalign(&ptr, static_cast(alignment), size) != 0)) { + ThrowStdBadAlloc(); + } + return ptr; +#else + void* ptr = std::aligned_alloc(static_cast(alignment), size); + if (ABSL_PREDICT_FALSE(size != 0 && ptr == nullptr)) { + ThrowStdBadAlloc(); + } + return ptr; +#endif +#endif +} + +std::pair SizeReturningNew(size_t size) { + return std::pair{::operator new(size), size}; +} + +std::pair SizeReturningAlignedNew(size_t size, + std::align_val_t alignment) { + ABSL_DCHECK(absl::has_single_bit(static_cast(alignment))); +#ifdef CEL_INTERNAL_HAVE_ALIGNED_NEW + return std::pair{::operator new(size, alignment), size}; +#else + return std::pair{AlignedNew(size, alignment), size}; +#endif +} + +void Delete(void* ptr) noexcept { ::operator delete(ptr); } + +void SizedDelete(void* ptr, size_t size) noexcept { +#ifdef CEL_INTERNAL_HAVE_SIZED_DELETE + ::operator delete(ptr, size); +#else + ::operator delete(ptr); +#endif +} + +void AlignedDelete(void* ptr, std::align_val_t alignment) noexcept { + ABSL_DCHECK(absl::has_single_bit(static_cast(alignment))); +#ifdef CEL_INTERNAL_HAVE_ALIGNED_NEW + ::operator delete(ptr, alignment); +#else + if (static_cast(alignment) <= kDefaultNewAlignment) { + ::operator delete(ptr); + } else { +#if defined(_MSC_VER) + _aligned_free(ptr); +#else + std::free(ptr); +#endif + } +#endif +} + +void SizedAlignedDelete(void* ptr, size_t size, + std::align_val_t alignment) noexcept { + ABSL_DCHECK(absl::has_single_bit(static_cast(alignment))); +#ifdef CEL_INTERNAL_HAVE_ALIGNED_NEW +#ifdef CEL_INTERNAL_HAVE_SIZED_DELETE + ::operator delete(ptr, size, alignment); +#else + ::operator delete(ptr, alignment); +#endif +#else + AlignedDelete(ptr, alignment); +#endif +} + +} // namespace cel::internal diff --git a/internal/new.h b/internal/new.h new file mode 100644 index 000000000..a4a2ea676 --- /dev/null +++ b/internal/new.h @@ -0,0 +1,61 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_NEW_H_ +#define THIRD_PARTY_CEL_CPP_INTERNAL_NEW_H_ + +#include +#include +#include + +namespace cel::internal { + +inline constexpr size_t kDefaultNewAlignment = +#ifdef __STDCPP_DEFAULT_NEW_ALIGNMENT__ + __STDCPP_DEFAULT_NEW_ALIGNMENT__ +#else + alignof(std::max_align_t) +#endif + ; // NOLINT(whitespace/semicolon) + +// Allocates memory which has a size of at least `size` and a minimum alignment +// of `kDefaultNewAlignment`. +void* New(size_t size); + +// Allocates memory which has a size of at least `size` and a minimum alignment +// of `alignment`. To deallocate, the caller must use `AlignedDelete` or +// `SizedAlignedDelete`. +void* AlignedNew(size_t size, std::align_val_t alignment); + +std::pair SizeReturningNew(size_t size); + +// Allocates memory which has a size of at least `size` and a minimum alignment +// of `alignment`, returns a pointer to the allocated memory and the actual +// usable allocation size. To deallocate, the caller must use `AlignedDelete` or +// `SizedAlignedDelete`. +std::pair SizeReturningAlignedNew(size_t size, + std::align_val_t alignment); + +void Delete(void* ptr) noexcept; + +void SizedDelete(void* ptr, size_t size) noexcept; + +void AlignedDelete(void* ptr, std::align_val_t alignment) noexcept; + +void SizedAlignedDelete(void* ptr, size_t size, + std::align_val_t alignment) noexcept; + +} // namespace cel::internal + +#endif // THIRD_PARTY_CEL_CPP_INTERNAL_NEW_H_ diff --git a/internal/new_test.cc b/internal/new_test.cc new file mode 100644 index 000000000..7a4d1dca0 --- /dev/null +++ b/internal/new_test.cc @@ -0,0 +1,67 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "internal/new.h" + +#include +#include +#include +#include + +#include "internal/testing.h" + +namespace cel::internal { +namespace { + +using ::testing::Ge; +using ::testing::NotNull; + +TEST(New, Basic) { + void* p = New(sizeof(uint64_t)); + EXPECT_THAT(p, NotNull()); + Delete(p); +} + +TEST(AlignedNew, Basic) { + void* p = + AlignedNew(alignof(std::max_align_t) * 2, + static_cast(alignof(std::max_align_t) * 2)); + EXPECT_THAT(p, NotNull()); + AlignedDelete(p, + static_cast(alignof(std::max_align_t) * 2)); +} + +TEST(SizeReturningNew, Basic) { + void* p; + size_t n; + std::tie(p, n) = SizeReturningNew(sizeof(uint64_t)); + EXPECT_THAT(p, NotNull()); + EXPECT_THAT(n, Ge(sizeof(uint64_t))); + SizedDelete(p, n); +} + +TEST(SizeReturningAlignedNew, Basic) { + void* p; + size_t n; + std::tie(p, n) = SizeReturningAlignedNew( + alignof(std::max_align_t) * 2, + static_cast(alignof(std::max_align_t) * 2)); + EXPECT_THAT(p, NotNull()); + EXPECT_THAT(n, Ge(alignof(std::max_align_t) * 2)); + SizedAlignedDelete( + p, n, static_cast(alignof(std::max_align_t) * 2)); +} + +} // namespace +} // namespace cel::internal diff --git a/internal/noop_delete.h b/internal/noop_delete.h new file mode 100644 index 000000000..7b362d98d --- /dev/null +++ b/internal/noop_delete.h @@ -0,0 +1,53 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_NOOP_DELETE_H_ +#define THIRD_PARTY_CEL_CPP_INTERNAL_NOOP_DELETE_H_ + +#include + +#include "absl/base/nullability.h" + +namespace cel::internal { + +// Like `std::default_delete`, except it does nothing. +template +struct NoopDelete { + static_assert(!std::is_function::value, + "NoopDelete cannot be instantiated for function types"); + + constexpr NoopDelete() noexcept = default; + constexpr NoopDelete(const NoopDelete&) noexcept = default; + + template < + typename U, + typename = std::enable_if_t>, std::is_convertible>>> + // NOLINTNEXTLINE(google-explicit-constructor) + constexpr NoopDelete(const NoopDelete&) noexcept {} + + constexpr void operator()(T* absl_nullable) const noexcept { + static_assert(sizeof(T) >= 0, "cannot delete an incomplete type"); + static_assert(!std::is_void::value, "cannot delete an incomplete type"); + } +}; + +template +inline constexpr NoopDelete NoopDeleteFor() noexcept { + return NoopDelete{}; +} + +} // namespace cel::internal + +#endif // THIRD_PARTY_CEL_CPP_INTERNAL_NOOP_DELETE_H_ diff --git a/internal/number.h b/internal/number.h new file mode 100644 index 000000000..c1c1d14e8 --- /dev/null +++ b/internal/number.h @@ -0,0 +1,299 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_NUMBER_H_ +#define THIRD_PARTY_CEL_CPP_INTERNAL_NUMBER_H_ + +#include +#include + +#include "absl/types/variant.h" + +namespace cel::internal { + +constexpr int64_t kInt64Max = std::numeric_limits::max(); +constexpr int64_t kInt64Min = std::numeric_limits::lowest(); +constexpr uint64_t kUint64Max = std::numeric_limits::max(); +constexpr uint64_t kUintToIntMax = static_cast(kInt64Max); +constexpr double kDoubleToIntMax = static_cast(kInt64Max); +constexpr double kDoubleToIntMin = static_cast(kInt64Min); +constexpr double kDoubleToUintMax = static_cast(kUint64Max); + +// The highest integer values that are round-trippable after rounding and +// casting to double. +template +constexpr int RoundingError() { + return 1 << (std::numeric_limits::digits - + std::numeric_limits::digits - 1); +} + +constexpr double kMaxDoubleRepresentableAsInt = + static_cast(kInt64Max - RoundingError()); +constexpr double kMaxDoubleRepresentableAsUint = + static_cast(kUint64Max - RoundingError()); + +#define CEL_ABSL_VISIT_CONSTEXPR + +using NumberVariant = absl::variant; + +enum class ComparisonResult { + kLesser, + kEqual, + kGreater, + // Special case for nan. + kNanInequal +}; + +// Return the inverse relation (i.e. Invert(cmp(b, a)) is the same as cmp(a, b). +constexpr ComparisonResult Invert(ComparisonResult result) { + switch (result) { + case ComparisonResult::kLesser: + return ComparisonResult::kGreater; + case ComparisonResult::kGreater: + return ComparisonResult::kLesser; + case ComparisonResult::kEqual: + return ComparisonResult::kEqual; + case ComparisonResult::kNanInequal: + return ComparisonResult::kNanInequal; + } +} + +template +struct ConversionVisitor { + template + constexpr OutType operator()(InType v) { + return static_cast(v); + } +}; + +template +constexpr ComparisonResult Compare(T a, T b) { + return (a > b) ? ComparisonResult::kGreater + : (a == b) ? ComparisonResult::kEqual + : ComparisonResult::kLesser; +} + +constexpr ComparisonResult DoubleCompare(double a, double b) { + // constexpr friendly isnan check. + if (!(a == a) || !(b == b)) { + return ComparisonResult::kNanInequal; + } + return Compare(a, b); +} + +// Implement generic numeric comparison against double value. +struct DoubleCompareVisitor { + constexpr explicit DoubleCompareVisitor(double v) : v(v) {} + + constexpr ComparisonResult operator()(double other) const { + return DoubleCompare(v, other); + } + + constexpr ComparisonResult operator()(uint64_t other) const { + if (v > kDoubleToUintMax) { + return ComparisonResult::kGreater; + } else if (v < 0) { + return ComparisonResult::kLesser; + } else { + return DoubleCompare(v, static_cast(other)); + } + } + + constexpr ComparisonResult operator()(int64_t other) const { + if (v > kDoubleToIntMax) { + return ComparisonResult::kGreater; + } else if (v < kDoubleToIntMin) { + return ComparisonResult::kLesser; + } else { + return DoubleCompare(v, static_cast(other)); + } + } + double v; +}; + +// Implement generic numeric comparison against uint value. +// Delegates to double comparison if either variable is double. +struct UintCompareVisitor { + constexpr explicit UintCompareVisitor(uint64_t v) : v(v) {} + + constexpr ComparisonResult operator()(double other) const { + return Invert(DoubleCompareVisitor(other)(v)); + } + + constexpr ComparisonResult operator()(uint64_t other) const { + return Compare(v, other); + } + + constexpr ComparisonResult operator()(int64_t other) const { + if (v > kUintToIntMax || other < 0) { + return ComparisonResult::kGreater; + } else { + return Compare(v, static_cast(other)); + } + } + uint64_t v; +}; + +// Implement generic numeric comparison against int value. +// Delegates to uint / double if either value is uint / double. +struct IntCompareVisitor { + constexpr explicit IntCompareVisitor(int64_t v) : v(v) {} + + constexpr ComparisonResult operator()(double other) { + return Invert(DoubleCompareVisitor(other)(v)); + } + + constexpr ComparisonResult operator()(uint64_t other) { + return Invert(UintCompareVisitor(other)(v)); + } + + constexpr ComparisonResult operator()(int64_t other) { + return Compare(v, other); + } + int64_t v; +}; + +struct CompareVisitor { + explicit constexpr CompareVisitor(NumberVariant rhs) : rhs(rhs) {} + + CEL_ABSL_VISIT_CONSTEXPR ComparisonResult operator()(double v) { + return absl::visit(DoubleCompareVisitor(v), rhs); + } + + CEL_ABSL_VISIT_CONSTEXPR ComparisonResult operator()(uint64_t v) { + return absl::visit(UintCompareVisitor(v), rhs); + } + + CEL_ABSL_VISIT_CONSTEXPR ComparisonResult operator()(int64_t v) { + return absl::visit(IntCompareVisitor(v), rhs); + } + NumberVariant rhs; +}; + +struct LosslessConvertibleToIntVisitor { + constexpr bool operator()(double value) const { + return value >= kDoubleToIntMin && value <= kMaxDoubleRepresentableAsInt && + value == static_cast(static_cast(value)); + } + constexpr bool operator()(uint64_t value) const { + return value <= kUintToIntMax; + } + constexpr bool operator()(int64_t value) const { return true; } +}; + +struct LosslessConvertibleToUintVisitor { + constexpr bool operator()(double value) const { + return value >= 0 && value <= kMaxDoubleRepresentableAsUint && + value == static_cast(static_cast(value)); + } + constexpr bool operator()(uint64_t value) const { return true; } + constexpr bool operator()(int64_t value) const { return value >= 0; } +}; + +// Utility class for CEL number operations. +// +// In CEL expressions, comparisons between different numeric types are treated +// as all happening on the same continuous number line. This generally means +// that integers and doubles in convertible range are compared after converting +// to doubles (tolerating some loss of precision). +// +// This extends to key lookups -- {1: 'abc'}[1.0f] is expected to work since +// 1.0 == 1 in CEL. +class Number { + public: + // Factories to resolve ambiguous overload resolution against literals. + static constexpr Number FromInt64(int64_t value) { return Number(value); } + static constexpr Number FromUint64(uint64_t value) { return Number(value); } + static constexpr Number FromDouble(double value) { return Number(value); } + + constexpr explicit Number(double double_value) : value_(double_value) {} + constexpr explicit Number(int64_t int_value) : value_(int_value) {} + constexpr explicit Number(uint64_t uint_value) : value_(uint_value) {} + + // Return a double representation of the value. + CEL_ABSL_VISIT_CONSTEXPR double AsDouble() const { + return absl::visit(internal::ConversionVisitor(), value_); + } + + // Return signed int64 representation for the value. + // Caller must guarantee the underlying value is representatble as an + // int. + CEL_ABSL_VISIT_CONSTEXPR int64_t AsInt() const { + return absl::visit(internal::ConversionVisitor(), value_); + } + + // Return unsigned int64 representation for the value. + // Caller must guarantee the underlying value is representable as an + // uint. + CEL_ABSL_VISIT_CONSTEXPR uint64_t AsUint() const { + return absl::visit(internal::ConversionVisitor(), value_); + } + + // For key lookups, check if the conversion to signed int is lossless. + CEL_ABSL_VISIT_CONSTEXPR bool LosslessConvertibleToInt() const { + return absl::visit(internal::LosslessConvertibleToIntVisitor(), value_); + } + + // For key lookups, check if the conversion to unsigned int is lossless. + CEL_ABSL_VISIT_CONSTEXPR bool LosslessConvertibleToUint() const { + return absl::visit(internal::LosslessConvertibleToUintVisitor(), value_); + } + + CEL_ABSL_VISIT_CONSTEXPR bool operator<(Number other) const { + return Compare(other) == internal::ComparisonResult::kLesser; + } + + CEL_ABSL_VISIT_CONSTEXPR bool operator<=(Number other) const { + internal::ComparisonResult cmp = Compare(other); + return cmp != internal::ComparisonResult::kGreater && + cmp != internal::ComparisonResult::kNanInequal; + } + + CEL_ABSL_VISIT_CONSTEXPR bool operator>(Number other) const { + return Compare(other) == internal::ComparisonResult::kGreater; + } + + CEL_ABSL_VISIT_CONSTEXPR bool operator>=(Number other) const { + internal::ComparisonResult cmp = Compare(other); + return cmp != internal::ComparisonResult::kLesser && + cmp != internal::ComparisonResult::kNanInequal; + } + + CEL_ABSL_VISIT_CONSTEXPR bool operator==(Number other) const { + return Compare(other) == internal::ComparisonResult::kEqual; + } + + CEL_ABSL_VISIT_CONSTEXPR bool operator!=(Number other) const { + return Compare(other) != internal::ComparisonResult::kEqual; + } + + // Visit the underlying number representation, a variant of double, uint64_t, + // or int64_t. + template + T visit(Op&& op) const { + return absl::visit(std::forward(op), value_); + } + + private: + internal::NumberVariant value_; + + CEL_ABSL_VISIT_CONSTEXPR internal::ComparisonResult Compare( + Number other) const { + return absl::visit(internal::CompareVisitor(other.value_), value_); + } +}; + +} // namespace cel::internal + +#endif // THIRD_PARTY_CEL_CPP_INTERNAL_NUMBER_H_ diff --git a/internal/number_test.cc b/internal/number_test.cc new file mode 100644 index 000000000..3cdcf2b2d --- /dev/null +++ b/internal/number_test.cc @@ -0,0 +1,64 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "internal/number.h" + +#include +#include + +#include "internal/testing.h" + +namespace cel::internal { +namespace { + +TEST(Number, Basic) { + EXPECT_GT(Number(1.1), Number::FromInt64(1)); + EXPECT_LT(Number::FromUint64(1), Number(1.1)); + EXPECT_EQ(Number(1.1), Number(1.1)); + + EXPECT_EQ(Number::FromUint64(1), Number::FromUint64(1)); + EXPECT_EQ(Number::FromInt64(1), Number::FromUint64(1)); + EXPECT_GT(Number::FromUint64(1), Number::FromInt64(-1)); + + EXPECT_EQ(Number::FromInt64(-1), Number::FromInt64(-1)); +} + +TEST(Number, Conversions) { + EXPECT_TRUE(Number::FromDouble(1.0).LosslessConvertibleToInt()); + EXPECT_TRUE(Number::FromDouble(1.0).LosslessConvertibleToUint()); + EXPECT_FALSE(Number::FromDouble(1.1).LosslessConvertibleToInt()); + EXPECT_FALSE(Number::FromDouble(1.1).LosslessConvertibleToUint()); + EXPECT_TRUE(Number::FromDouble(-1.0).LosslessConvertibleToInt()); + EXPECT_FALSE(Number::FromDouble(-1.0).LosslessConvertibleToUint()); + EXPECT_TRUE(Number::FromDouble(kDoubleToIntMin).LosslessConvertibleToInt()); + + // Need to add/substract a large number since double resolution is low at this + // range. + EXPECT_FALSE(Number::FromDouble(kMaxDoubleRepresentableAsUint + + RoundingError()) + .LosslessConvertibleToUint()); + EXPECT_FALSE(Number::FromDouble(kMaxDoubleRepresentableAsInt + + RoundingError()) + .LosslessConvertibleToInt()); + EXPECT_FALSE( + Number::FromDouble(kDoubleToIntMin - 1025).LosslessConvertibleToInt()); + + EXPECT_EQ(Number::FromInt64(1).AsUint(), 1u); + EXPECT_EQ(Number::FromUint64(1).AsInt(), 1); + EXPECT_EQ(Number::FromDouble(1.0).AsUint(), 1); + EXPECT_EQ(Number::FromDouble(1.0).AsInt(), 1); +} + +} // namespace +} // namespace cel::internal diff --git a/internal/overflow.cc b/internal/overflow.cc new file mode 100644 index 000000000..8cc209384 --- /dev/null +++ b/internal/overflow.cc @@ -0,0 +1,339 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "internal/overflow.h" + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/time/time.h" +#include "internal/status_macros.h" +#include "internal/time.h" + +namespace cel::internal { +namespace { + +constexpr int64_t kInt32Max = std::numeric_limits::max(); +constexpr int64_t kInt32Min = std::numeric_limits::lowest(); +constexpr int64_t kInt64Max = std::numeric_limits::max(); +constexpr int64_t kInt64Min = std::numeric_limits::lowest(); +constexpr uint64_t kUint32Max = std::numeric_limits::max(); +ABSL_ATTRIBUTE_UNUSED constexpr uint64_t kUint64Max = + std::numeric_limits::max(); +constexpr uint64_t kUintToIntMax = static_cast(kInt64Max); +constexpr double kDoubleToIntMax = static_cast(kInt64Max); +constexpr double kDoubleToIntMin = static_cast(kInt64Min); +const double kDoubleTwoTo64 = std::ldexp(1.0, 64); // 1.0 * 2^64 + +const absl::Duration kOneSecondDuration = absl::Seconds(1); +const int64_t kOneSecondNanos = absl::ToInt64Nanoseconds(kOneSecondDuration); +// Number of seconds between `0001-01-01T00:00:00Z` and Unix epoch. +const int64_t kMinUnixTime = + absl::ToInt64Seconds(MinTimestamp() - absl::UnixEpoch()); + +// Number of seconds between `9999-12-31T23:59:59.999999999Z` and Unix epoch. +const int64_t kMaxUnixTime = + absl::ToInt64Seconds(MaxTimestamp() - absl::UnixEpoch()); + +absl::Status CheckRange(bool valid_expression, + absl::string_view error_message) { + return valid_expression ? absl::OkStatus() + : absl::OutOfRangeError(error_message); +} + +absl::Status CheckArgument(bool valid_expression, + absl::string_view error_message) { + return valid_expression ? absl::OkStatus() + : absl::InvalidArgumentError(error_message); +} + +// Determine whether the duration is finite. +bool IsFinite(absl::Duration d) { + return d != absl::InfiniteDuration() && d != -absl::InfiniteDuration(); +} + +// Determine whether the time is finite. +bool IsFinite(absl::Time t) { + return t != absl::InfiniteFuture() && t != absl::InfinitePast(); +} + +} // namespace + +absl::StatusOr CheckedAdd(int64_t x, int64_t y) { +#if ABSL_HAVE_BUILTIN(__builtin_add_overflow) + int64_t sum; + if (!__builtin_add_overflow(x, y, &sum)) { + return sum; + } + return absl::OutOfRangeError("integer overflow"); +#else + CEL_RETURN_IF_ERROR(CheckRange( + y > 0 ? x <= kInt64Max - y : x >= kInt64Min - y, "integer overflow")); + return x + y; +#endif +} + +absl::StatusOr CheckedSub(int64_t x, int64_t y) { +#if ABSL_HAVE_BUILTIN(__builtin_sub_overflow) + int64_t diff; + if (!__builtin_sub_overflow(x, y, &diff)) { + return diff; + } + return absl::OutOfRangeError("integer overflow"); +#else + CEL_RETURN_IF_ERROR(CheckRange( + y < 0 ? x <= kInt64Max + y : x >= kInt64Min + y, "integer overflow")); + return x - y; +#endif +} + +absl::StatusOr CheckedNegation(int64_t v) { +#if ABSL_HAVE_BUILTIN(__builtin_mul_overflow) + int64_t prod; + if (!__builtin_mul_overflow(v, -1, &prod)) { + return prod; + } + return absl::OutOfRangeError("integer overflow"); +#else + CEL_RETURN_IF_ERROR(CheckRange(v != kInt64Min, "integer overflow")); + return -v; +#endif +} + +absl::StatusOr CheckedMul(int64_t x, int64_t y) { +#if ABSL_HAVE_BUILTIN(__builtin_mul_overflow) + int64_t prod; + if (!__builtin_mul_overflow(x, y, &prod)) { + return prod; + } + return absl::OutOfRangeError("integer overflow"); +#else + CEL_RETURN_IF_ERROR( + CheckRange(!((x == -1 && y == kInt64Min) || (y == -1 && x == kInt64Min) || + (x > 0 && y > 0 && x > kInt64Max / y) || + (x < 0 && y < 0 && x < kInt64Max / y) || + // Avoid dividing kInt64Min by -1, use whichever value of x + // or y is positive as the divisor. + (x > 0 && y < 0 && y < kInt64Min / x) || + (x < 0 && y > 0 && x < kInt64Min / y)), + "integer overflow")); + return x * y; +#endif +} + +absl::StatusOr CheckedDiv(int64_t x, int64_t y) { + CEL_RETURN_IF_ERROR( + CheckRange(x != kInt64Min || y != -1, "integer overflow")); + CEL_RETURN_IF_ERROR(CheckArgument(y != 0, "divide by zero")); + return x / y; +} + +absl::StatusOr CheckedMod(int64_t x, int64_t y) { + CEL_RETURN_IF_ERROR( + CheckRange(x != kInt64Min || y != -1, "integer overflow")); + CEL_RETURN_IF_ERROR(CheckArgument(y != 0, "modulus by zero")); + return x % y; +} + +absl::StatusOr CheckedAdd(uint64_t x, uint64_t y) { +#if ABSL_HAVE_BUILTIN(__builtin_add_overflow) + uint64_t sum; + if (!__builtin_add_overflow(x, y, &sum)) { + return sum; + } + return absl::OutOfRangeError("unsigned integer overflow"); +#else + CEL_RETURN_IF_ERROR( + CheckRange(x <= kUint64Max - y, "unsigned integer overflow")); + return x + y; +#endif +} + +absl::StatusOr CheckedSub(uint64_t x, uint64_t y) { +#if ABSL_HAVE_BUILTIN(__builtin_sub_overflow) + uint64_t diff; + if (!__builtin_sub_overflow(x, y, &diff)) { + return diff; + } + return absl::OutOfRangeError("unsigned integer overflow"); +#else + CEL_RETURN_IF_ERROR(CheckRange(y <= x, "unsigned integer overflow")); + return x - y; +#endif +} + +absl::StatusOr CheckedMul(uint64_t x, uint64_t y) { +#if ABSL_HAVE_BUILTIN(__builtin_mul_overflow) + uint64_t prod; + if (!__builtin_mul_overflow(x, y, &prod)) { + return prod; + } + return absl::OutOfRangeError("unsigned integer overflow"); +#else + CEL_RETURN_IF_ERROR( + CheckRange(y == 0 || x <= kUint64Max / y, "unsigned integer overflow")); + return x * y; +#endif +} + +absl::StatusOr CheckedDiv(uint64_t x, uint64_t y) { + CEL_RETURN_IF_ERROR(CheckArgument(y != 0, "divide by zero")); + return x / y; +} + +absl::StatusOr CheckedMod(uint64_t x, uint64_t y) { + CEL_RETURN_IF_ERROR(CheckArgument(y != 0, "modulus by zero")); + return x % y; +} + +absl::StatusOr CheckedAdd(absl::Duration x, absl::Duration y) { + CEL_RETURN_IF_ERROR( + CheckRange(IsFinite(x) && IsFinite(y), "integer overflow")); + // absl::Duration can handle +- infinite durations, but the Go time.Duration + // implementation caps the durations to those expressible within a single + // int64 rather than (seconds int64, nanos int32). + // + // The absl implementation mirrors the protobuf implementation which supports + // durations on the order of +- 10,000 years, but Go only supports +- 290 year + // durations. + // + // Since Go is the more conservative of the implementations and 290 year + // durations seem quite reasonable, this code mirrors the conservative + // overflow behavior which would be observed in Go. + CEL_ASSIGN_OR_RETURN(int64_t nanos, CheckedAdd(absl::ToInt64Nanoseconds(x), + absl::ToInt64Nanoseconds(y))); + return absl::Nanoseconds(nanos); +} + +absl::StatusOr CheckedSub(absl::Duration x, absl::Duration y) { + CEL_RETURN_IF_ERROR( + CheckRange(IsFinite(x) && IsFinite(y), "integer overflow")); + CEL_ASSIGN_OR_RETURN(int64_t nanos, CheckedSub(absl::ToInt64Nanoseconds(x), + absl::ToInt64Nanoseconds(y))); + return absl::Nanoseconds(nanos); +} + +absl::StatusOr CheckedNegation(absl::Duration v) { + CEL_RETURN_IF_ERROR(CheckRange(IsFinite(v), "integer overflow")); + CEL_ASSIGN_OR_RETURN(int64_t nanos, + CheckedNegation(absl::ToInt64Nanoseconds(v))); + return absl::Nanoseconds(nanos); +} + +absl::StatusOr CheckedAdd(absl::Time t, absl::Duration d) { + CEL_RETURN_IF_ERROR( + CheckRange(IsFinite(t) && IsFinite(d), "timestamp overflow")); + // First we break time into its components by truncating and subtracting. + const int64_t s1 = absl::ToUnixSeconds(t); + const int64_t ns1 = (t - absl::FromUnixSeconds(s1)) / absl::Nanoseconds(1); + + // Second we break duration into its components by dividing and modulo. + // Truncate to seconds. + const int64_t s2 = d / kOneSecondDuration; + // Get remainder. + const int64_t ns2 = absl::ToInt64Nanoseconds(d % kOneSecondDuration); + + // Add seconds first, detecting any overflow. + CEL_ASSIGN_OR_RETURN(int64_t s, CheckedAdd(s1, s2)); + // Nanoseconds cannot overflow as nanos are normalized to [0, 999999999]. + absl::Duration ns = absl::Nanoseconds(ns2 + ns1); + + // Normalize nanoseconds to be positive and carry extra nanos to seconds. + if (ns < absl::ZeroDuration() || ns >= kOneSecondDuration) { + // Add seconds, or no-op if nanseconds negative (ns never < -999_999_999ns) + CEL_ASSIGN_OR_RETURN(s, CheckedAdd(s, ns / kOneSecondDuration)); + ns -= (ns / kOneSecondDuration) * kOneSecondDuration; + // Subtract a second to make the nanos positive. + if (ns < absl::ZeroDuration()) { + CEL_ASSIGN_OR_RETURN(s, CheckedAdd(s, -1)); + ns += kOneSecondDuration; + } + } + // Check if the the number of seconds from Unix epoch is within our acceptable + // range. + CEL_RETURN_IF_ERROR( + CheckRange(s >= kMinUnixTime && s <= kMaxUnixTime, "timestamp overflow")); + + // Return resulting time. + return absl::FromUnixSeconds(s) + ns; +} + +absl::StatusOr CheckedSub(absl::Time t, absl::Duration d) { + CEL_ASSIGN_OR_RETURN(auto neg_duration, CheckedNegation(d)); + return CheckedAdd(t, neg_duration); +} + +absl::StatusOr CheckedSub(absl::Time t1, absl::Time t2) { + CEL_RETURN_IF_ERROR( + CheckRange(IsFinite(t1) && IsFinite(t2), "integer overflow")); + // First we break time into its components by truncating and subtracting. + const int64_t s1 = absl::ToUnixSeconds(t1); + const int64_t ns1 = (t1 - absl::FromUnixSeconds(s1)) / absl::Nanoseconds(1); + const int64_t s2 = absl::ToUnixSeconds(t2); + const int64_t ns2 = (t2 - absl::FromUnixSeconds(s2)) / absl::Nanoseconds(1); + + // Subtract seconds first, detecting any overflow. + CEL_ASSIGN_OR_RETURN(int64_t s, CheckedSub(s1, s2)); + // Nanoseconds cannot overflow as nanos are normalized to [0, 999999999]. + absl::Duration ns = absl::Nanoseconds(ns1 - ns2); + + // Scale the seconds result to nanos. + CEL_ASSIGN_OR_RETURN(const int64_t t, CheckedMul(s, kOneSecondNanos)); + // Add the seconds (scaled to nanos) to the nanosecond value. + CEL_ASSIGN_OR_RETURN(const int64_t v, + CheckedAdd(t, absl::ToInt64Nanoseconds(ns))); + return absl::Nanoseconds(v); +} + +absl::StatusOr CheckedDoubleToInt64(double v) { + CEL_RETURN_IF_ERROR( + CheckRange(std::isfinite(v) && v < kDoubleToIntMax && v > kDoubleToIntMin, + "double out of int64 range")); + return static_cast(v); +} + +absl::StatusOr CheckedDoubleToUint64(double v) { + CEL_RETURN_IF_ERROR( + CheckRange(std::isfinite(v) && v >= 0 && v < kDoubleTwoTo64, + "double out of uint64 range")); + return static_cast(v); +} + +absl::StatusOr CheckedInt64ToUint64(int64_t v) { + CEL_RETURN_IF_ERROR(CheckRange(v >= 0, "int64 out of uint64 range")); + return static_cast(v); +} + +absl::StatusOr CheckedInt64ToInt32(int64_t v) { + CEL_RETURN_IF_ERROR( + CheckRange(v >= kInt32Min && v <= kInt32Max, "int64 out of int32 range")); + return static_cast(v); +} + +absl::StatusOr CheckedUint64ToInt64(uint64_t v) { + CEL_RETURN_IF_ERROR( + CheckRange(v <= kUintToIntMax, "uint64 out of int64 range")); + return static_cast(v); +} + +absl::StatusOr CheckedUint64ToUint32(uint64_t v) { + CEL_RETURN_IF_ERROR( + CheckRange(v <= kUint32Max, "uint64 out of uint32 range")); + return static_cast(v); +} + +} // namespace cel::internal diff --git a/internal/overflow.h b/internal/overflow.h new file mode 100644 index 000000000..15a60eaf1 --- /dev/null +++ b/internal/overflow.h @@ -0,0 +1,178 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_OVERFLOW_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_OVERFLOW_H_ + +#include + +#include "absl/status/statusor.h" +#include "absl/time/time.h" + +namespace cel::internal { + +// Add two int64_t values together. +// If overflow is detected, return an absl::StatusCode::kOutOfRangeError, e.g. +// int64_t_max + 1 +absl::StatusOr CheckedAdd(int64_t x, int64_t y); + +// Subtract two int64_t values from each other. +// If overflow is detected, return an absl::StatusCode::kOutOfRangeError. e.g. +// int64_t_min - 1 +absl::StatusOr CheckedSub(int64_t x, int64_t y); + +// Negate an int64_t value. +// If overflow is detected, return an absl::StatusCode::kOutOfRangeError, e.g. +// negate(int64_t_min) +absl::StatusOr CheckedNegation(int64_t v); + +// Multiply two int64_t values together. +// If overflow is detected, return an absl::StatusCode::kOutOfRangeError. e.g. +// 2 * int64_t_max +absl::StatusOr CheckedMul(int64_t x, int64_t y); + +// Divide one int64_t value into another. +// If overflow is detected, return an absl::StatusCode::kOutOfRangeError, e.g. +// int64_t_min / -1 +absl::StatusOr CheckedDiv(int64_t x, int64_t y); + +// Compute the modulus of x into y. +// If overflow is detected, return an absl::StatusCode::kOutOfRangeError, e.g. +// int64_t_min % -1 +absl::StatusOr CheckedMod(int64_t x, int64_t y); + +// Add two uint64_t values together. +// If overflow is detected, return an absl::StatusCode::kOutOfRangeError, e.g. +// uint64_t_max + 1 +absl::StatusOr CheckedAdd(uint64_t x, uint64_t y); + +// Subtract two uint64_t values from each other. +// If overflow is detected, return an absl::StatusCode::kOutOfRangeError, e.g. +// 1 - uint64_t_max +absl::StatusOr CheckedSub(uint64_t x, uint64_t y); + +// Multiply two uint64_t values together. +// If overflow is detected, return an absl::StatusCode::kOutOfRangeError, e.g. +// 2 * uint64_t_max +absl::StatusOr CheckedMul(uint64_t x, uint64_t y); + +// Divide one uint64_t value into another. +absl::StatusOr CheckedDiv(uint64_t x, uint64_t y); + +// Compute the modulus of x into y. +// If 'y' is zero, the function will return an +// absl::StatusCode::kInvalidArgumentError, e.g. 1 / 0. +absl::StatusOr CheckedMod(uint64_t x, uint64_t y); + +// Add two durations together. +// If overflow is detected, return an absl::StatusCode::kOutOfRangeError, e.g. +// duration(int64_t_max, "ns") + duration(int64_t_max, "ns") +// +// Note, absl::Duration is effectively an int64_t under the covers, which means +// the same cases that would result in overflow for int64_t values would hold +// true for absl::Duration values. +absl::StatusOr CheckedAdd(absl::Duration x, absl::Duration y); + +// Subtract two durations from each other. +// If overflow is detected, return an absl::StatusCode::kOutOfRangeError, e.g. +// duration(int64_t_min, "ns") - duration(1, "ns") +// +// Note, absl::Duration is effectively an int64_t under the covers, which means +// the same cases that would result in overflow for int64_t values would hold +// true for absl::Duration values. +absl::StatusOr CheckedSub(absl::Duration x, absl::Duration y); + +// Negate a duration. +// If overflow is detected, return an absl::StatusCode::kOutOfRangeError, e.g. +// negate(duration(int64_t_min, "ns")). +absl::StatusOr CheckedNegation(absl::Duration v); + +// Add an absl::Time and absl::Duration value together. +// If overflow is detected, return an absl::StatusCode::kOutOfRangeError, e.g. +// timestamp(unix_epoch_max) + duration(1, "ns") +// +// Valid time values must be between `0001-01-01T00:00:00Z` (-62135596800s) and +// `9999-12-31T23:59:59.999999999Z` (253402300799s). +absl::StatusOr CheckedAdd(absl::Time t, absl::Duration d); + +// Subtract an absl::Time and absl::Duration value together. +// If overflow is detected, return an absl::StatusCode::kOutOfRangeError, e.g. +// timestamp(unix_epoch_min) - duration(1, "ns") +// +// Valid time values must be between `0001-01-01T00:00:00Z` (-62135596800s) and +// `9999-12-31T23:59:59.999999999Z` (253402300799s). +absl::StatusOr CheckedSub(absl::Time t, absl::Duration d); + +// Subtract two absl::Time values from each other to produce an absl::Duration. +// If overflow is detected, return an absl::StatusCode::kOutOfRangeError, e.g. +// timestamp(unix_epoch_min) - timestamp(unix_epoch_max) +absl::StatusOr CheckedSub(absl::Time t1, absl::Time t2); + +// Convert a double value to an int64_t if possible. +// If the double exceeds the values representable in an int64_t the function +// will return an absl::StatusCode::kOutOfRangeError. +// +// Only finite double values may be converted to an int64_t. CEL may also reject +// some conversions if the value falls into a range where overflow would be +// ambiguous. +// +// The behavior of the static_cast(double) assembly instruction on +// x86 (cvttsd2si) can be manipulated by the header: +// https://en.cppreference.com/w/cpp/numeric/fenv/feround. This means that the +// set of values which will result in a valid or invalid conversion are +// environment dependent and the implementation must err on the side of caution +// and reject possibly valid values which might be invalid based on environment +// settings. +absl::StatusOr CheckedDoubleToInt64(double v); + +// Convert a double value to a uint64_t if possible. +// If the double exceeds the values representable in a uint64_t the function +// will return an absl::StatusCode::kOutOfRangeError. +// +// Only finite double values may be converted to a uint64_t. CEL may also reject +// some conversions if the value falls into a range where overflow would be +// ambiguous. +// +// The behavior of the static_cast(double) assembly instruction on +// x86 (cvttsd2si) can be manipulated by the header: +// https://en.cppreference.com/w/cpp/numeric/fenv/feround. This means that the +// set of values which will result in a valid or invalid conversion are +// environment dependent and the implementation must err on the side of caution +// and reject possibly valid values which might be invalid based on environment +// settings. +absl::StatusOr CheckedDoubleToUint64(double v); + +// Convert an int64_t value to a uint64_t value if possible. +// If the int64_t exceeds the values representable in a uint64_t the function +// will return an absl::StatusCode::kOutOfRangeError. +absl::StatusOr CheckedInt64ToUint64(int64_t v); + +// Convert an int64_t value to an int32_t value if possible. +// If the int64_t exceeds the values representable in an int32_t the function +// will return an absl::StatusCode::kOutOfRangeError. +absl::StatusOr CheckedInt64ToInt32(int64_t v); + +// Convert a uint64_t value to an int64_t value if possible. +// If the uint64_t exceeds the values representable in an int64_t the function +// will return an absl::StatusCode::kOutOfRangeError. +absl::StatusOr CheckedUint64ToInt64(uint64_t v); + +// Convert a uint64_t value to a uint32_t value if possible. +// If the uint64_t exceeds the values representable in a uint32_t the function +// will return an absl::StatusCode::kOutOfRangeError. +absl::StatusOr CheckedUint64ToUint32(uint64_t v); + +} // namespace cel::internal + +#endif // THIRD_PARTY_CEL_CPP_COMMON_OVERFLOW_H_ diff --git a/internal/overflow_test.cc b/internal/overflow_test.cc new file mode 100644 index 000000000..213e7a79d --- /dev/null +++ b/internal/overflow_test.cc @@ -0,0 +1,682 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "internal/overflow.h" + +#include +#include +#include +#include + +#include "absl/functional/function_ref.h" +#include "absl/status/status.h" +#include "absl/time/time.h" +#include "internal/testing.h" + +namespace cel::internal { +namespace { + +using ::testing::HasSubstr; +using ::testing::ValuesIn; + +template +struct TestCase { + std::string test_name; + absl::FunctionRef()> op; + absl::StatusOr result; +}; + +template +void ExpectResult(const T& test_case) { + auto result = test_case.op(); + ASSERT_EQ(result.status().code(), test_case.result.status().code()); + if (result.ok()) { + EXPECT_EQ(*result, *test_case.result); + } else { + EXPECT_THAT(result.status().message(), + HasSubstr(test_case.result.status().message())); + } +} + +using IntTestCase = TestCase; +using CheckedIntResultTest = testing::TestWithParam; +TEST_P(CheckedIntResultTest, IntOperations) { ExpectResult(GetParam()); } + +INSTANTIATE_TEST_SUITE_P( + CheckedIntMathTest, CheckedIntResultTest, + ValuesIn(std::vector{ + // Addition tests. + {"OneAddOne", [] { return CheckedAdd(int64_t{1L}, 1L); }, 2L}, + {"ZeroAddOne", [] { return CheckedAdd(int64_t{0}, 1L); }, 1L}, + {"ZeroAddMinusOne", [] { return CheckedAdd(int64_t{0}, -1L); }, -1L}, + {"OneAddZero", [] { return CheckedAdd(int64_t{1L}, 0); }, 1L}, + {"MinusOneAddZero", [] { return CheckedAdd(int64_t{-1L}, 0); }, -1L}, + {"OneAddIntMax", + [] { + return CheckedAdd(int64_t{1L}, std::numeric_limits::max()); + }, + absl::OutOfRangeError("integer overflow")}, + {"MinusOneAddIntMin", + [] { + return CheckedAdd(int64_t{-1L}, + std::numeric_limits::lowest()); + }, + absl::OutOfRangeError("integer overflow")}, + + // Subtraction tests. + {"TwoSubThree", [] { return CheckedSub(int64_t{2L}, 3L); }, -1L}, + {"TwoSubZero", [] { return CheckedSub(int64_t{2L}, 0); }, 2L}, + {"ZeroSubTwo", [] { return CheckedSub(int64_t{0}, 2L); }, -2L}, + {"MinusTwoSubThree", [] { return CheckedSub(int64_t{-2L}, 3L); }, -5L}, + {"MinusTwoSubZero", [] { return CheckedSub(int64_t{-2L}, 0); }, -2L}, + {"ZeroSubMinusTwo", [] { return CheckedSub(int64_t{0}, -2L); }, 2L}, + {"IntMinSubIntMax", + [] { + return CheckedSub(std::numeric_limits::max(), + std::numeric_limits::lowest()); + }, + absl::OutOfRangeError("integer overflow")}, + + // Multiplication tests. + {"TwoMulThree", [] { return CheckedMul(int64_t{2L}, 3L); }, 6L}, + {"MinusTwoMulThree", [] { return CheckedMul(int64_t{-2L}, 3L); }, -6L}, + {"MinusTwoMulMinusThree", [] { return CheckedMul(int64_t{-2L}, -3L); }, + 6L}, + {"TwoMulMinusThree", [] { return CheckedMul(int64_t{2L}, -3L); }, -6L}, + {"TwoMulIntMax", + [] { + return CheckedMul(int64_t{2L}, std::numeric_limits::max()); + }, + absl::OutOfRangeError("integer overflow")}, + {"MinusOneMulIntMin", + [] { + return CheckedMul(int64_t{-1L}, + std::numeric_limits::lowest()); + }, + absl::OutOfRangeError("integer overflow")}, + {"IntMinMulMinusOne", + [] { + return CheckedMul(std::numeric_limits::lowest(), + int64_t{-1L}); + }, + absl::OutOfRangeError("integer overflow")}, + {"IntMinMulZero", + [] { + return CheckedMul(std::numeric_limits::lowest(), + int64_t{0}); + }, + 0}, + {"ZeroMulIntMin", + [] { + return CheckedMul(int64_t{0}, + std::numeric_limits::lowest()); + }, + 0}, + {"IntMaxMulZero", + [] { + return CheckedMul(std::numeric_limits::max(), int64_t{0}); + }, + 0}, + {"ZeroMulIntMax", + [] { + return CheckedMul(int64_t{0}, std::numeric_limits::max()); + }, + 0}, + + // Division cases. + {"ZeroDivOne", [] { return CheckedDiv(int64_t{0}, 1L); }, 0}, + {"TenDivTwo", [] { return CheckedDiv(int64_t{10L}, 2L); }, 5}, + {"TenDivMinusOne", [] { return CheckedDiv(int64_t{10L}, -1L); }, -10}, + {"MinusTenDivMinusOne", [] { return CheckedDiv(int64_t{-10L}, -1L); }, + 10}, + {"MinusTenDivTwo", [] { return CheckedDiv(int64_t{-10L}, 2L); }, -5}, + {"OneDivZero", [] { return CheckedDiv(int64_t{1L}, 0L); }, + absl::InvalidArgumentError("divide by zero")}, + {"IntMinDivMinusOne", + [] { + return CheckedDiv(std::numeric_limits::lowest(), + int64_t{-1L}); + }, + absl::OutOfRangeError("integer overflow")}, + + // Modulus cases. + {"ZeroModTwo", [] { return CheckedMod(int64_t{0}, 2L); }, 0}, + {"TwoModTwo", [] { return CheckedMod(int64_t{2L}, 2L); }, 0}, + {"ThreeModTwo", [] { return CheckedMod(int64_t{3L}, 2L); }, 1L}, + {"TwoModZero", [] { return CheckedMod(int64_t{2L}, 0); }, + absl::InvalidArgumentError("modulus by zero")}, + {"IntMinModTwo", + [] { + return CheckedMod(std::numeric_limits::lowest(), + int64_t{2L}); + }, + 0}, + {"IntMaxModMinusOne", + [] { + return CheckedMod(std::numeric_limits::max(), int64_t{-1L}); + }, + 0}, + {"IntMinModMinusOne", + [] { + return CheckedMod(std::numeric_limits::lowest(), + int64_t{-1L}); + }, + absl::OutOfRangeError("integer overflow")}, + + // Negation cases. + {"NegateOne", [] { return CheckedNegation(int64_t{1L}); }, -1L}, + {"NegateMinInt64", + [] { return CheckedNegation(std::numeric_limits::lowest()); }, + absl::OutOfRangeError("integer overflow")}, + + // Numeric conversion cases for uint -> int, double -> int + {"Uint64Conversion", [] { return CheckedUint64ToInt64(uint64_t{1UL}); }, + 1L}, + {"Uint32MaxConversion", + [] { + return CheckedUint64ToInt64( + static_cast(std::numeric_limits::max())); + }, + std::numeric_limits::max()}, + {"Uint32MaxConversionError", + [] { + return CheckedUint64ToInt64( + static_cast(std::numeric_limits::max())); + }, + absl::OutOfRangeError("out of int64 range")}, + {"DoubleConversion", [] { return CheckedDoubleToInt64(double{100.1}); }, + 100L}, + {"DoubleInt64MaxConversionError", + [] { + return CheckedDoubleToInt64( + static_cast(std::numeric_limits::max())); + }, + absl::OutOfRangeError("out of int64 range")}, + {"DoubleInt64MaxMinus512Conversion", + [] { + return CheckedDoubleToInt64( + static_cast(std::numeric_limits::max() - 512)); + }, + std::numeric_limits::max() - 1023}, + {"DoubleInt64MaxMinus1024Conversion", + [] { + return CheckedDoubleToInt64( + static_cast(std::numeric_limits::max() - 1024)); + }, + std::numeric_limits::max() - 1023}, + {"DoubleInt64MinConversionError", + [] { + return CheckedDoubleToInt64( + static_cast(std::numeric_limits::lowest())); + }, + absl::OutOfRangeError("out of int64 range")}, + {"DoubleInt64MinMinusOneConversionError", + [] { + return CheckedDoubleToInt64( + static_cast(std::numeric_limits::lowest()) - + 1.0); + }, + absl::OutOfRangeError("out of int64 range")}, + {"DoubleInt64MinMinus511ConversionError", + [] { + return CheckedDoubleToInt64( + static_cast(std::numeric_limits::lowest()) - + 511.0); + }, + absl::OutOfRangeError("out of int64 range")}, + {"InfiniteConversionError", + [] { + return CheckedDoubleToInt64(std::numeric_limits::infinity()); + }, + absl::OutOfRangeError("out of int64 range")}, + {"NegRangeConversionError", + [] { return CheckedDoubleToInt64(double{-1.0e99}); }, + absl::OutOfRangeError("out of int64 range")}, + {"PosRangeConversionError", + [] { return CheckedDoubleToInt64(double{1.0e99}); }, + absl::OutOfRangeError("out of int64 range")}, + }), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + +using UintTestCase = TestCase; +using CheckedUintResultTest = testing::TestWithParam; +TEST_P(CheckedUintResultTest, UnsignedOperations) { ExpectResult(GetParam()); } + +INSTANTIATE_TEST_SUITE_P( + CheckedUintMathTest, CheckedUintResultTest, + ValuesIn(std::vector{ + // Addition tests. + {"OneAddOne", [] { return CheckedAdd(uint64_t{1UL}, 1UL); }, 2UL}, + {"ZeroAddOne", [] { return CheckedAdd(uint64_t{0}, 1UL); }, 1UL}, + {"OneAddZero", [] { return CheckedAdd(uint64_t{1UL}, 0); }, 1UL}, + {"OneAddIntMax", + [] { + return CheckedAdd(uint64_t{1UL}, + std::numeric_limits::max()); + }, + absl::OutOfRangeError("unsigned integer overflow")}, + + // Subtraction tests. + {"OneSubOne", [] { return CheckedSub(uint64_t{1UL}, 1UL); }, 0}, + {"ZeroSubOne", [] { return CheckedSub(uint64_t{0}, 1UL); }, + absl::OutOfRangeError("unsigned integer overflow")}, + {"OneSubZero", [] { return CheckedSub(uint64_t{1UL}, 0); }, 1UL}, + + // Multiplication tests. + {"OneMulOne", [] { return CheckedMul(uint64_t{1UL}, 1UL); }, 1UL}, + {"ZeroMulOne", [] { return CheckedMul(uint64_t{0}, 1UL); }, 0}, + {"OneMulZero", [] { return CheckedMul(uint64_t{1UL}, 0); }, 0}, + {"TwoMulUintMax", + [] { + return CheckedMul(uint64_t{2UL}, + std::numeric_limits::max()); + }, + absl::OutOfRangeError("unsigned integer overflow")}, + + // Division tests. + {"TwoDivTwo", [] { return CheckedDiv(uint64_t{2UL}, 2UL); }, 1UL}, + {"TwoDivFour", [] { return CheckedDiv(uint64_t{2UL}, 4UL); }, 0}, + {"OneDivZero", [] { return CheckedDiv(uint64_t{1UL}, 0); }, + absl::InvalidArgumentError("divide by zero")}, + + // Modulus tests. + {"TwoModTwo", [] { return CheckedMod(uint64_t{2UL}, 2UL); }, 0}, + {"TwoModFour", [] { return CheckedMod(uint64_t{2UL}, 4UL); }, 2UL}, + {"OneModZero", [] { return CheckedMod(uint64_t{1UL}, 0); }, + absl::InvalidArgumentError("modulus by zero")}, + + // Conversion test cases for int -> uint, double -> uint. + {"Int64Conversion", [] { return CheckedInt64ToUint64(int64_t{1L}); }, + 1UL}, + {"Int64MaxConversion", + [] { + return CheckedInt64ToUint64(std::numeric_limits::max()); + }, + static_cast(std::numeric_limits::max())}, + {"NegativeInt64ConversionError", + [] { return CheckedInt64ToUint64(int64_t{-1L}); }, + absl::OutOfRangeError("out of uint64 range")}, + {"DoubleConversion", + [] { return CheckedDoubleToUint64(double{100.1}); }, 100UL}, + {"DoubleUint64MaxConversionError", + [] { + return CheckedDoubleToUint64( + static_cast(std::numeric_limits::max())); + }, + absl::OutOfRangeError("out of uint64 range")}, + {"DoubleUint64MaxMinus512Conversion", + [] { + return CheckedDoubleToUint64( + static_cast(std::numeric_limits::max() - 512)); + }, + absl::OutOfRangeError("out of uint64 range")}, + {"DoubleUint64MaxMinus1024Conversion", + [] { + return CheckedDoubleToUint64(static_cast( + std::numeric_limits::max() - 1024)); + }, + std::numeric_limits::max() - 2047}, + {"InfiniteConversionError", + [] { + return CheckedDoubleToUint64( + std::numeric_limits::infinity()); + }, + absl::OutOfRangeError("out of uint64 range")}, + {"NegConversionError", + [] { return CheckedDoubleToUint64(double{-1.1}); }, + absl::OutOfRangeError("out of uint64 range")}, + {"NegRangeConversionError", + [] { return CheckedDoubleToUint64(double{-1.0e99}); }, + absl::OutOfRangeError("out of uint64 range")}, + {"PosRangeConversionError", + [] { return CheckedDoubleToUint64(double{1.0e99}); }, + absl::OutOfRangeError("out of uint64 range")}, + }), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + +using DurationTestCase = TestCase; +using CheckedDurationResultTest = testing::TestWithParam; +TEST_P(CheckedDurationResultTest, DurationOperations) { + ExpectResult(GetParam()); +} + +INSTANTIATE_TEST_SUITE_P( + CheckedDurationMathTest, CheckedDurationResultTest, + ValuesIn(std::vector{ + // Addition tests. + {"OneSecondAddOneSecond", + [] { return CheckedAdd(absl::Seconds(1), absl::Seconds(1)); }, + absl::Seconds(2)}, + {"MaxDurationAddOneNano", + [] { + return CheckedAdd( + absl::Nanoseconds(std::numeric_limits::max()), + absl::Nanoseconds(1)); + }, + absl::OutOfRangeError("integer overflow")}, + {"MinDurationAddMinusOneNano", + [] { + return CheckedAdd( + absl::Nanoseconds(std::numeric_limits::lowest()), + absl::Nanoseconds(-1)); + }, + absl::OutOfRangeError("integer overflow")}, + {"InfinityAddOneNano", + [] { + return CheckedAdd(absl::InfiniteDuration(), absl::Nanoseconds(1)); + }, + absl::OutOfRangeError("integer overflow")}, + {"NegInfinityAddOneNano", + [] { + return CheckedAdd(-absl::InfiniteDuration(), absl::Nanoseconds(1)); + }, + absl::OutOfRangeError("integer overflow")}, + {"OneSecondAddInfinity", + [] { + return CheckedAdd(absl::Nanoseconds(1), absl::InfiniteDuration()); + }, + absl::OutOfRangeError("integer overflow")}, + {"OneSecondAddNegInfinity", + [] { + return CheckedAdd(absl::Nanoseconds(1), -absl::InfiniteDuration()); + }, + absl::OutOfRangeError("integer overflow")}, + + // Subtraction tests for duration - duration. + {"OneSecondSubOneSecond", + [] { return CheckedSub(absl::Seconds(1), absl::Seconds(1)); }, + absl::ZeroDuration()}, + {"MinDurationSubOneSecond", + [] { + return CheckedSub( + absl::Nanoseconds(std::numeric_limits::lowest()), + absl::Nanoseconds(1)); + }, + absl::OutOfRangeError("integer overflow")}, + {"InfinitySubOneNano", + [] { + return CheckedSub(absl::InfiniteDuration(), absl::Nanoseconds(1)); + }, + absl::OutOfRangeError("integer overflow")}, + {"NegInfinitySubOneNano", + [] { + return CheckedSub(-absl::InfiniteDuration(), absl::Nanoseconds(1)); + }, + absl::OutOfRangeError("integer overflow")}, + {"OneNanoSubInfinity", + [] { + return CheckedSub(absl::Nanoseconds(1), absl::InfiniteDuration()); + }, + absl::OutOfRangeError("integer overflow")}, + {"OneNanoSubNegInfinity", + [] { + return CheckedSub(absl::Nanoseconds(1), -absl::InfiniteDuration()); + }, + absl::OutOfRangeError("integer overflow")}, + + // Subtraction tests for time - time. + {"TimeSubOneSecond", + [] { + return CheckedSub(absl::FromUnixSeconds(100), + absl::FromUnixSeconds(1)); + }, + absl::Seconds(99)}, + {"TimeWithNanosPositive", + [] { + return CheckedSub(absl::FromUnixSeconds(2) + absl::Nanoseconds(1), + absl::FromUnixSeconds(1) - absl::Nanoseconds(1)); + }, + absl::Seconds(1) + absl::Nanoseconds(2)}, + {"TimeWithNanosNegative", + [] { + return CheckedSub(absl::FromUnixSeconds(1) + absl::Nanoseconds(1), + absl::FromUnixSeconds(2) + absl::Seconds(1) - + absl::Nanoseconds(1)); + }, + absl::Seconds(-2) + absl::Nanoseconds(2)}, + {"MinTimestampMinusOne", + [] { + return CheckedSub( + absl::FromUnixSeconds(std::numeric_limits::lowest()), + absl::FromUnixSeconds(1)); + }, + absl::OutOfRangeError("integer overflow")}, + {"InfinitePastSubOneSecond", + [] { + return CheckedSub(absl::InfinitePast(), absl::FromUnixSeconds(1)); + }, + absl::OutOfRangeError("integer overflow")}, + {"InfiniteFutureSubOneMinusSecond", + [] { + return CheckedSub(absl::InfiniteFuture(), absl::FromUnixSeconds(-1)); + }, + absl::OutOfRangeError("integer overflow")}, + {"InfiniteFutureSubInfinitePast", + [] { + return CheckedSub(absl::InfiniteFuture(), absl::InfinitePast()); + }, + absl::OutOfRangeError("integer overflow")}, + {"InfinitePastSubInfiniteFuture", + [] { + return CheckedSub(absl::InfinitePast(), absl::InfiniteFuture()); + }, + absl::OutOfRangeError("integer overflow")}, + + // Negation cases. + {"NegateOneSecond", [] { return CheckedNegation(absl::Seconds(1)); }, + absl::Seconds(-1)}, + {"NegateMinDuration", + [] { + return CheckedNegation( + absl::Nanoseconds(std::numeric_limits::lowest())); + }, + absl::OutOfRangeError("integer overflow")}, + {"NegateInfiniteDuration", + [] { return CheckedNegation(absl::InfiniteDuration()); }, + absl::OutOfRangeError("integer overflow")}, + {"NegateNegInfiniteDuration", + [] { return CheckedNegation(-absl::InfiniteDuration()); }, + absl::OutOfRangeError("integer overflow")}, + }), + [](const testing::TestParamInfo& + info) { return info.param.test_name; }); + +using TimeTestCase = TestCase; +using CheckedTimeResultTest = testing::TestWithParam; +TEST_P(CheckedTimeResultTest, TimeDurationOperations) { + ExpectResult(GetParam()); +} + +INSTANTIATE_TEST_SUITE_P( + CheckedTimeDurationMathTest, CheckedTimeResultTest, + ValuesIn(std::vector{ + // Addition tests. + {"DateAddOneHourMinusOneMilli", + [] { + return CheckedAdd(absl::FromUnixSeconds(3506), + absl::Hours(1) + absl::Milliseconds(-1)); + }, + absl::FromUnixSeconds(7106) + absl::Milliseconds(-1)}, + {"DateAddOneHourOneNano", + [] { + return CheckedAdd(absl::FromUnixSeconds(3506), + absl::Hours(1) + absl::Nanoseconds(1)); + }, + absl::FromUnixSeconds(7106) + absl::Nanoseconds(1)}, + {"MaxIntAddOneSecond", + [] { + return CheckedAdd( + absl::FromUnixSeconds(std::numeric_limits::max()), + absl::Seconds(1)); + }, + absl::OutOfRangeError("integer overflow")}, + {"MaxTimestampAddOneSecond", + [] { + return CheckedAdd(absl::FromUnixSeconds(253402300799), + absl::Seconds(1)); + }, + absl::OutOfRangeError("timestamp overflow")}, + {"TimeWithNanosNegative", + [] { + return CheckedAdd(absl::FromUnixSeconds(1) + absl::Nanoseconds(1), + absl::Nanoseconds(-999999999)); + }, + absl::FromUnixNanos(2)}, + {"TimeWithNanosPositive", + [] { + return CheckedAdd( + absl::FromUnixSeconds(1) + absl::Nanoseconds(999999999), + absl::Nanoseconds(999999999)); + }, + absl::FromUnixSeconds(2) + absl::Nanoseconds(999999998)}, + {"SecondsAddInfinity", + [] { + return CheckedAdd( + absl::FromUnixSeconds(1) + absl::Nanoseconds(999999999), + absl::InfiniteDuration()); + }, + absl::OutOfRangeError("timestamp overflow")}, + {"SecondsAddNegativeInfinity", + [] { + return CheckedAdd( + absl::FromUnixSeconds(1) + absl::Nanoseconds(999999999), + -absl::InfiniteDuration()); + }, + absl::OutOfRangeError("timestamp overflow")}, + {"InfiniteFutureAddNegativeInfinity", + [] { + return CheckedAdd(absl::InfiniteFuture(), -absl::InfiniteDuration()); + }, + absl::OutOfRangeError("timestamp overflow")}, + {"InfinitePastAddInfinity", + [] { + return CheckedAdd(absl::InfinitePast(), absl::InfiniteDuration()); + }, + absl::OutOfRangeError("timestamp overflow")}, + + // Subtraction tests. + {"DateSubOneHour", + [] { return CheckedSub(absl::FromUnixSeconds(3506), absl::Hours(1)); }, + absl::FromUnixSeconds(-94)}, + {"MinTimestampSubOneSecond", + [] { + return CheckedSub(absl::FromUnixSeconds(-62135596800), + absl::Seconds(1)); + }, + absl::OutOfRangeError("timestamp overflow")}, + {"MinIntSubOneViaNanos", + [] { + return CheckedSub( + absl::FromUnixSeconds(std::numeric_limits::min()), + absl::Nanoseconds(1)); + }, + absl::OutOfRangeError("integer overflow")}, + {"MinTimestampSubOneViaNanosScaleOverflow", + [] { + return CheckedSub( + absl::FromUnixSeconds(-62135596800) + absl::Nanoseconds(1), + absl::Nanoseconds(999999999)); + }, + absl::OutOfRangeError("timestamp overflow")}, + {"SecondsSubInfinity", + [] { + return CheckedSub( + absl::FromUnixSeconds(1) + absl::Nanoseconds(999999999), + absl::InfiniteDuration()); + }, + absl::OutOfRangeError("integer overflow")}, + {"SecondsSubNegInfinity", + [] { + return CheckedSub( + absl::FromUnixSeconds(1) + absl::Nanoseconds(999999999), + -absl::InfiniteDuration()); + }, + absl::OutOfRangeError("integer overflow")}, + }), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + +using ConvertInt64Int32TestCase = TestCase; +using CheckedConvertInt64Int32Test = + testing::TestWithParam; +TEST_P(CheckedConvertInt64Int32Test, Conversions) { ExpectResult(GetParam()); } + +INSTANTIATE_TEST_SUITE_P( + CheckedConvertInt64Int32Test, CheckedConvertInt64Int32Test, + ValuesIn(std::vector{ + {"SimpleConversion", [] { return CheckedInt64ToInt32(int64_t{1L}); }, + 1}, + {"Int32MaxConversion", + [] { + return CheckedInt64ToInt32( + static_cast(std::numeric_limits::max())); + }, + std::numeric_limits::max()}, + {"Int32MaxConversionError", + [] { + return CheckedInt64ToInt32( + static_cast(std::numeric_limits::max())); + }, + absl::OutOfRangeError("out of int32 range")}, + {"Int32MinConversion", + [] { + return CheckedInt64ToInt32( + static_cast(std::numeric_limits::lowest())); + }, + std::numeric_limits::lowest()}, + {"Int32MinConversionError", + [] { + return CheckedInt64ToInt32( + static_cast(std::numeric_limits::lowest())); + }, + absl::OutOfRangeError("out of int32 range")}, + }), + [](const testing::TestParamInfo& + info) { return info.param.test_name; }); + +using ConvertUint64Uint32TestCase = TestCase; +using CheckedConvertUint64Uint32Test = + testing::TestWithParam; +TEST_P(CheckedConvertUint64Uint32Test, Conversions) { + ExpectResult(GetParam()); +} + +INSTANTIATE_TEST_SUITE_P( + CheckedConvertUint64Uint32Test, CheckedConvertUint64Uint32Test, + ValuesIn(std::vector{ + {"SimpleConversion", + [] { return CheckedUint64ToUint32(uint64_t{1UL}); }, 1U}, + {"Uint32MaxConversion", + [] { + return CheckedUint64ToUint32( + static_cast(std::numeric_limits::max())); + }, + std::numeric_limits::max()}, + {"Uint32MaxConversionError", + [] { + return CheckedUint64ToUint32( + static_cast(std::numeric_limits::max())); + }, + absl::OutOfRangeError("out of uint32 range")}, + }), + [](const testing::TestParamInfo& + info) { return info.param.test_name; }); + +} // namespace +} // namespace cel::internal diff --git a/internal/parse_text_proto.h b/internal/parse_text_proto.h new file mode 100644 index 000000000..772c24382 --- /dev/null +++ b/internal/parse_text_proto.h @@ -0,0 +1,121 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_PARSE_TEXT_PROTO_H_ +#define THIRD_PARTY_CEL_CPP_INTERNAL_PARSE_TEXT_PROTO_H_ + +#include + +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/log/die_if_null.h" +#include "absl/strings/cord.h" +#include "absl/strings/string_view.h" +#include "common/memory.h" +#include "internal/message_type_name.h" +#include "internal/testing_descriptor_pool.h" +#include "internal/testing_message_factory.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" +#include "google/protobuf/message_lite.h" +#include "google/protobuf/text_format.h" + +namespace cel::internal { + +// `GeneratedParseTextProto` parses the text format protocol buffer message as +// the message with the same name as `T`, looked up in the provided descriptor +// pool, returning as the generated message. This works regardless of whether +// all messages are built with the lite runtime or not. +template +std::enable_if_t, T* absl_nonnull> +GeneratedParseTextProto( + google::protobuf::Arena* absl_nonnull arena, absl::string_view text, + const google::protobuf::DescriptorPool* absl_nonnull pool = + GetTestingDescriptorPool(), + google::protobuf::MessageFactory* absl_nonnull factory = GetTestingMessageFactory()) { + // Full runtime. + const auto* descriptor = ABSL_DIE_IF_NULL( // Crash OK + pool->FindMessageTypeByName(MessageTypeNameFor())); + const auto* dynamic_message_prototype = + ABSL_DIE_IF_NULL(factory->GetPrototype(descriptor)); // Crash OK + auto* dynamic_message = dynamic_message_prototype->New(arena); + ABSL_CHECK( // Crash OK + google::protobuf::TextFormat::ParseFromString(text, dynamic_message)); + if (auto* generated_message = google::protobuf::DynamicCastMessage(dynamic_message); + generated_message != nullptr) { + // Same thing, no need to serialize and parse. + return generated_message; + } + auto* message = google::protobuf::Arena::Create(arena); + absl::Cord serialized_message; + ABSL_CHECK( // Crash OK + dynamic_message->SerializeToCord(&serialized_message)); + ABSL_CHECK(message->ParseFromCord(serialized_message)); // Crash OK + return message; +} + +// `GeneratedParseTextProto` parses the text format protocol buffer message as +// the message with the same name as `T`, looked up in the provided descriptor +// pool, returning as the generated message. This works regardless of whether +// all messages are built with the lite runtime or not. +template +std::enable_if_t< + std::conjunction_v, + std::negation>>, + T* absl_nonnull> +GeneratedParseTextProto( + google::protobuf::Arena* absl_nonnull arena, absl::string_view text, + const google::protobuf::DescriptorPool* absl_nonnull pool = + GetTestingDescriptorPool(), + google::protobuf::MessageFactory* absl_nonnull factory = GetTestingMessageFactory()) { + // Lite runtime. + const auto* descriptor = ABSL_DIE_IF_NULL( // Crash OK + pool->FindMessageTypeByName(MessageTypeNameFor())); + const auto* dynamic_message_prototype = + ABSL_DIE_IF_NULL(factory->GetPrototype(descriptor)); // Crash OK + auto* dynamic_message = dynamic_message_prototype->New(arena); + ABSL_CHECK( // Crash OK + google::protobuf::TextFormat::ParseFromString(text, dynamic_message)); + auto* message = google::protobuf::Arena::Create(arena); + absl::Cord serialized_message; + ABSL_CHECK( // Crash OK + dynamic_message->SerializeToCord(&serialized_message)); + ABSL_CHECK(message->ParseFromCord(serialized_message)); // Crash OK + return message; +} + +// `DynamicParseTextProto` parses the text format protocol buffer message as the +// dynamic message with the same name as `T`, looked up in the provided +// descriptor pool, returning the dynamic message. +template +google::protobuf::Message* absl_nonnull DynamicParseTextProto( + google::protobuf::Arena* absl_nonnull arena, absl::string_view text, + const google::protobuf::DescriptorPool* absl_nonnull pool = + GetTestingDescriptorPool(), + google::protobuf::MessageFactory* absl_nonnull factory = GetTestingMessageFactory()) { + static_assert(std::is_base_of_v); + const auto* descriptor = ABSL_DIE_IF_NULL( // Crash OK + pool->FindMessageTypeByName(MessageTypeNameFor())); + const auto* dynamic_message_prototype = + ABSL_DIE_IF_NULL(factory->GetPrototype(descriptor)); // Crash OK + auto* dynamic_message = dynamic_message_prototype->New(arena); + ABSL_CHECK(google::protobuf::TextFormat::ParseFromString( // Crash OK + text, cel::to_address(dynamic_message))); + return dynamic_message; +} + +} // namespace cel::internal + +#endif // THIRD_PARTY_CEL_CPP_INTERNAL_PARSE_TEXT_PROTO_H_ diff --git a/internal/port.h b/internal/port.h deleted file mode 100644 index 07473fed7..000000000 --- a/internal/port.h +++ /dev/null @@ -1,66 +0,0 @@ -// This files is a forwarding header for other headers containing various -// portability macros and functions. - -#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_PORT_H_ -#define THIRD_PARTY_CEL_CPP_INTERNAL_PORT_H_ - -#include - -namespace google { -namespace api { -namespace expr { -namespace internal { - -// Back port some helpers. -// Defined in std as of c++14. -template -using decay_t = typename std::decay::type; -template -using enable_if_t = typename std::enable_if::type; -template -using conditional_t = typename std::conditional::type; -template -using remove_const_t = typename std::remove_const::type; -template -using remove_reference_t = typename std::remove_reference::type; -template -using remove_cv_t = typename std::remove_cv::type; -template -using remove_const_t = typename std::remove_const::type; -template -using remove_volatile_t = typename std::remove_volatile::type; - -// Defined in std as of c++17 -template -struct conjunction : std::true_type {}; -template -struct conjunction : T {}; -template -struct conjunction - : std::conditional, T>::type {}; -template -struct disjunction : std::false_type {}; -template -struct disjunction : B1 {}; -template -struct disjunction - : conditional_t> {}; -template -using bool_constant = std::integral_constant; -template -struct negation : bool_constant(B::value)> {}; - -// Defined in std as of c++20 -template -struct remove_cvref { - typedef remove_cv_t> type; -}; -template -using remove_cvref_t = typename remove_cvref::type; - -} // namespace internal -} // namespace expr -} // namespace api -} // namespace google - -#endif // THIRD_PARTY_CEL_CPP_INTERNAL_PORT_H_ diff --git a/internal/proto_file_util.h b/internal/proto_file_util.h new file mode 100644 index 000000000..7a17fe04c --- /dev/null +++ b/internal/proto_file_util.h @@ -0,0 +1,73 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_PROTO_FILE_UTIL_H_ +#define THIRD_PARTY_CEL_CPP_INTERNAL_PROTO_FILE_UTIL_H_ + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "google/protobuf/io/zero_copy_stream_impl.h" +#include "google/protobuf/text_format.h" + +namespace cel::internal::test { + +// Reads a binary protobuf message of MessageType from the given path. +template +absl::Status ReadBinaryProtoFromFile(absl::string_view file_name, + MessageType& message) { + std::ifstream file; + file.open(std::string(file_name), std::fstream::in | std::fstream::binary); + if (!file.is_open()) { + return absl::NotFoundError(absl::StrFormat("Failed to open file '%s': %s", + file_name, strerror(errno))); + } + + if (!message.ParseFromIstream(&file)) { + return absl::InvalidArgumentError( + absl::StrFormat("Failed to parse proto of type '%s' from file '%s'", + message.GetTypeName(), file_name)); + } + + return absl::OkStatus(); +} + +// Reads a text protobuf message of MessageType from the given path. +template +absl::Status ReadTextProtoFromFile(absl::string_view file_name, + MessageType& message) { + std::ifstream file; + file.open(std::string(file_name), std::fstream::in | std::fstream::binary); + if (!file.is_open()) { + return absl::NotFoundError(absl::StrFormat("Failed to open file '%s': %s", + file_name, strerror(errno))); + } + + google::protobuf::io::IstreamInputStream stream(&file); + if (!google::protobuf::TextFormat::Parse(&stream, &message)) { + return absl::InvalidArgumentError( + absl::StrFormat("Failed to parse proto of type '%s' from file '%s'", + message.GetTypeName(), file_name)); + } + return absl::OkStatus(); +} + +} // namespace cel::internal::test + +#endif // THIRD_PARTY_CEL_CPP_INTERNAL_PROTO_FILE_UTIL_H_ diff --git a/internal/proto_matchers.h b/internal/proto_matchers.h new file mode 100644 index 000000000..02250634b --- /dev/null +++ b/internal/proto_matchers.h @@ -0,0 +1,140 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_PROTO_MATCHERS_H_ +#define THIRD_PARTY_CEL_CPP_INTERNAL_PROTO_MATCHERS_H_ + +#include +#include +#include + +#include "absl/log/absl_check.h" +#include "absl/memory/memory.h" +#include "internal/testing.h" +#include "google/protobuf/message.h" +#include "google/protobuf/message_lite.h" +#include "google/protobuf/text_format.h" +#include "google/protobuf/util/message_differencer.h" + +namespace cel::internal::test { + +/** + * Simple implementation of a proto matcher comparing string representations. + * + * IMPORTANT: Only use this for protos whose textual representation is + * deterministic (that may not be the case for the map collection type). + */ +class TextProtoMatcher { + public: + explicit inline TextProtoMatcher(absl::string_view expected) + : expected_(expected) {} + + bool MatchAndExplain(const google::protobuf::MessageLite& p, + ::testing::MatchResultListener* listener) const { + return MatchAndExplain(google::protobuf::DownCastMessage(p), + listener); + } + + bool MatchAndExplain(const google::protobuf::MessageLite* p, + ::testing::MatchResultListener* listener) const { + return MatchAndExplain(google::protobuf::DownCastMessage(p), + listener); + } + + bool MatchAndExplain(const google::protobuf::Message& p, + ::testing::MatchResultListener* listener) const { + auto message = absl::WrapUnique(p.New()); + ABSL_CHECK(google::protobuf::TextFormat::ParseFromString(expected_, message.get())); + return google::protobuf::util::MessageDifferencer::Equals( + *message, google::protobuf::DownCastMessage(p)); + } + + bool MatchAndExplain(const google::protobuf::Message* p, + ::testing::MatchResultListener* listener) const { + auto message = absl::WrapUnique(p->New()); + ABSL_CHECK(google::protobuf::TextFormat::ParseFromString(expected_, message.get())); + return google::protobuf::util::MessageDifferencer::Equals( + *message, google::protobuf::DownCastMessage(*p)); + } + + inline void DescribeTo(::std::ostream* os) const { *os << expected_; } + inline void DescribeNegationTo(::std::ostream* os) const { + *os << "not equal to expected message: " << expected_; + } + + private: + const std::string expected_; +}; + +/** + * Simple implementation of a proto matcher comparing string representations. + * + * IMPORTANT: Only use this for protos whose textual representation is + * deterministic (that may not be the case for the map collection type). + */ +class ProtoMatcher { + public: + explicit inline ProtoMatcher(const google::protobuf::Message& expected) + : expected_(expected.New()) { + expected_->CopyFrom(expected); + } + + bool MatchAndExplain(const google::protobuf::MessageLite& p, + ::testing::MatchResultListener* listener) const { + return MatchAndExplain(google::protobuf::DownCastMessage(p), + listener); + } + + bool MatchAndExplain(const google::protobuf::MessageLite* p, + ::testing::MatchResultListener* listener) const { + return MatchAndExplain(google::protobuf::DownCastMessage(p), + listener); + } + + bool MatchAndExplain(const google::protobuf::Message& p, + ::testing::MatchResultListener* /* listener */) const { + return google::protobuf::util::MessageDifferencer::Equals(*expected_, p); + } + + bool MatchAndExplain(const google::protobuf::Message* p, + ::testing::MatchResultListener* /* listener */) const { + return google::protobuf::util::MessageDifferencer::Equals(*expected_, *p); + } + + inline void DescribeTo(::std::ostream* os) const { + *os << expected_->DebugString(); + } + inline void DescribeNegationTo(::std::ostream* os) const { + *os << "not equal to expected message: " << expected_->DebugString(); + } + + private: + std::shared_ptr expected_; +}; + +// Polymorphic matcher to compare any two protos. +inline ::testing::PolymorphicMatcher EqualsProto( + absl::string_view x) { + return ::testing::MakePolymorphicMatcher(TextProtoMatcher(x)); +} + +// Polymorphic matcher to compare any two protos. +inline ::testing::PolymorphicMatcher EqualsProto( + const google::protobuf::Message& x) { + return ::testing::MakePolymorphicMatcher(ProtoMatcher(x)); +} + +} // namespace cel::internal::test + +#endif // THIRD_PARTY_CEL_CPP_INTERNAL_PROTO_MATCHERS_H_ diff --git a/internal/proto_time_encoding.cc b/internal/proto_time_encoding.cc new file mode 100644 index 000000000..194aab396 --- /dev/null +++ b/internal/proto_time_encoding.cc @@ -0,0 +1,103 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "internal/proto_time_encoding.h" + +#include + +#include "google/protobuf/duration.pb.h" +#include "google/protobuf/timestamp.pb.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/time/time.h" +#include "internal/status_macros.h" +#include "internal/time.h" +#include "google/protobuf/util/time_util.h" + +namespace cel::internal { + +namespace { + +absl::Status Validate(absl::Time time) { + if (time < cel::internal::MinTimestamp()) { + return absl::InvalidArgumentError("time below min"); + } + + if (time > cel::internal::MaxTimestamp()) { + return absl::InvalidArgumentError("time above max"); + } + return absl::OkStatus(); +} + +absl::Status CelValidateDuration(absl::Duration duration) { + if (duration < cel::internal::MinDuration()) { + return absl::InvalidArgumentError("duration below min"); + } + + if (duration > cel::internal::MaxDuration()) { + return absl::InvalidArgumentError("duration above max"); + } + return absl::OkStatus(); +} + +} // namespace + +absl::Duration DecodeDuration(const google::protobuf::Duration& proto) { + return absl::Seconds(proto.seconds()) + absl::Nanoseconds(proto.nanos()); +} + +absl::Time DecodeTime(const google::protobuf::Timestamp& proto) { + return absl::FromUnixSeconds(proto.seconds()) + + absl::Nanoseconds(proto.nanos()); +} + +absl::Status EncodeDuration(absl::Duration duration, + google::protobuf::Duration* proto) { + CEL_RETURN_IF_ERROR(CelValidateDuration(duration)); + // s and n may both be negative, per the Duration proto spec. + const int64_t s = absl::IDivDuration(duration, absl::Seconds(1), &duration); + const int64_t n = + absl::IDivDuration(duration, absl::Nanoseconds(1), &duration); + proto->set_seconds(s); + proto->set_nanos(n); + return absl::OkStatus(); +} + +absl::StatusOr EncodeDurationToString(absl::Duration duration) { + google::protobuf::Duration d; + auto status = EncodeDuration(duration, &d); + if (!status.ok()) { + return status; + } + return google::protobuf::util::TimeUtil::ToString(d); +} + +absl::Status EncodeTime(absl::Time time, google::protobuf::Timestamp* proto) { + CEL_RETURN_IF_ERROR(Validate(time)); + const int64_t s = absl::ToUnixSeconds(time); + proto->set_seconds(s); + proto->set_nanos((time - absl::FromUnixSeconds(s)) / absl::Nanoseconds(1)); + return absl::OkStatus(); +} + +absl::StatusOr EncodeTimeToString(absl::Time time) { + google::protobuf::Timestamp t; + auto status = EncodeTime(time, &t); + if (!status.ok()) { + return status; + } + return google::protobuf::util::TimeUtil::ToString(t); +} + +} // namespace cel::internal diff --git a/internal/proto_time_encoding.h b/internal/proto_time_encoding.h new file mode 100644 index 000000000..aa4128ee7 --- /dev/null +++ b/internal/proto_time_encoding.h @@ -0,0 +1,49 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Defines basic encode/decode operations for proto time and duration formats. +#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_PROTO_TIME_ENCODING_H_ +#define THIRD_PARTY_CEL_CPP_INTERNAL_PROTO_TIME_ENCODING_H_ + +#include "google/protobuf/duration.pb.h" +#include "google/protobuf/timestamp.pb.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/time/time.h" + +namespace cel::internal { + +/** Helper function to encode a duration in a google::protobuf::Duration. */ +absl::Status EncodeDuration(absl::Duration duration, + google::protobuf::Duration* proto); + +/** Helper function to encode an absl::Duration to a JSON-formatted string. */ +absl::StatusOr EncodeDurationToString(absl::Duration duration); + +/** Helper function to encode a time in a google::protobuf::Timestamp. */ +absl::Status EncodeTime(absl::Time time, google::protobuf::Timestamp* proto); + +/** Helper function to encode an absl::Time to a JSON-formatted string. */ +absl::StatusOr EncodeTimeToString(absl::Time time); + +/** Helper function to decode a duration from a google::protobuf::Duration. */ +absl::Duration DecodeDuration(const google::protobuf::Duration& proto); + +/** Helper function to decode a time from a google::protobuf::Timestamp. */ +absl::Time DecodeTime(const google::protobuf::Timestamp& proto); + +} // namespace cel::internal + +#endif // THIRD_PARTY_CEL_CPP_INTERNAL_PROTO_TIME_ENCODING_H_ diff --git a/internal/proto_time_encoding_test.cc b/internal/proto_time_encoding_test.cc new file mode 100644 index 000000000..29b2d2af6 --- /dev/null +++ b/internal/proto_time_encoding_test.cc @@ -0,0 +1,74 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "internal/proto_time_encoding.h" + +#include "google/protobuf/duration.pb.h" +#include "google/protobuf/timestamp.pb.h" +#include "absl/time/time.h" +#include "internal/testing.h" +#include "testutil/util.h" + +namespace cel::internal { +namespace { + +using ::google::api::expr::testutil::EqualsProto; + +TEST(EncodeDuration, Basic) { + google::protobuf::Duration proto_duration; + ASSERT_OK( + EncodeDuration(absl::Seconds(2) + absl::Nanoseconds(3), &proto_duration)); + + EXPECT_THAT(proto_duration, EqualsProto("seconds: 2 nanos: 3")); +} + +TEST(EncodeDurationToString, Basic) { + ASSERT_OK_AND_ASSIGN( + std::string json, + EncodeDurationToString(absl::Seconds(5) + absl::Nanoseconds(20))); + EXPECT_EQ(json, "5.000000020s"); +} + +TEST(EncodeTime, Basic) { + google::protobuf::Timestamp proto_timestamp; + ASSERT_OK(EncodeTime(absl::FromUnixMillis(300000), &proto_timestamp)); + + EXPECT_THAT(proto_timestamp, EqualsProto("seconds: 300")); +} + +TEST(EncodeTimeToString, Basic) { + ASSERT_OK_AND_ASSIGN(std::string json, + EncodeTimeToString(absl::FromUnixMillis(80030))); + + EXPECT_EQ(json, "1970-01-01T00:01:20.030Z"); +} + +TEST(DecodeDuration, Basic) { + google::protobuf::Duration proto_duration; + proto_duration.set_seconds(450); + proto_duration.set_nanos(4); + + EXPECT_EQ(DecodeDuration(proto_duration), + absl::Seconds(450) + absl::Nanoseconds(4)); +} + +TEST(DecodeTime, Basic) { + google::protobuf::Timestamp proto_timestamp; + proto_timestamp.set_seconds(450); + + EXPECT_EQ(DecodeTime(proto_timestamp), absl::FromUnixSeconds(450)); +} + +} // namespace +} // namespace cel::internal diff --git a/internal/proto_util.cc b/internal/proto_util.cc deleted file mode 100644 index b2dd0a22b..000000000 --- a/internal/proto_util.cc +++ /dev/null @@ -1,72 +0,0 @@ -#include "internal/proto_util.h" -#include "google/protobuf/duration.pb.h" -#include "google/protobuf/timestamp.pb.h" -#include "google/rpc/status.pb.h" -#include "absl/strings/str_cat.h" -#include "common/macros.h" -#include "internal/status_util.h" - -namespace google { -namespace api { -namespace expr { -namespace internal { - -namespace { - -google::rpc::Status Validate(absl::Duration duration) { - if (duration < MakeGoogleApiDurationMin()) { - return InvalidArgumentError(absl::StrCat("duration below min")); - } - - if (duration > MakeGoogleApiDurationMax()) { - return InvalidArgumentError(absl::StrCat("duration above max")); - } - return OkStatus(); -} - -google::rpc::Status Validate(absl::Time time) { - if (time < MakeGoogleApiTimeMin()) { - return InvalidArgumentError(absl::StrCat("time below min")); - } - - if (time > MakeGoogleApiTimeMax()) { - return InvalidArgumentError(absl::StrCat("time above max")); - } - return OkStatus(); -} - -} // namespace - -absl::Duration DecodeDuration(const google::protobuf::Duration& proto) { - return absl::Seconds(proto.seconds()) + absl::Nanoseconds(proto.nanos()); -} - -absl::Time DecodeTime(const google::protobuf::Timestamp& proto) { - return absl::FromUnixSeconds(proto.seconds()) + - absl::Nanoseconds(proto.nanos()); -} - -google::rpc::Status EncodeDuration(absl::Duration duration, - google::protobuf::Duration* proto) { - RETURN_IF_STATUS_ERROR(Validate(duration)); - // s and n may both be negative, per the Duration proto spec. - const int64_t s = absl::IDivDuration(duration, absl::Seconds(1), &duration); - const int64_t n = absl::IDivDuration(duration, absl::Nanoseconds(1), &duration); - proto->set_seconds(s); - proto->set_nanos(n); - return OkStatus(); -} - -google::rpc::Status EncodeTime(absl::Time time, - google::protobuf::Timestamp* proto) { - RETURN_IF_STATUS_ERROR(Validate(time)); - const int64_t s = absl::ToUnixSeconds(time); - proto->set_seconds(s); - proto->set_nanos((time - absl::FromUnixSeconds(s)) / absl::Nanoseconds(1)); - return OkStatus(); -} - -} // namespace internal -} // namespace expr -} // namespace api -} // namespace google diff --git a/internal/proto_util.h b/internal/proto_util.h index 12534ec50..5f28581d9 100644 --- a/internal/proto_util.h +++ b/internal/proto_util.h @@ -1,74 +1,82 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #ifndef THIRD_PARTY_CEL_CPP_INTERNAL_PROTO_UTIL_H_ #define THIRD_PARTY_CEL_CPP_INTERNAL_PROTO_UTIL_H_ -#include "google/protobuf/duration.pb.h" -#include "google/protobuf/timestamp.pb.h" -#include "google/rpc/status.pb.h" +#include +#include + +#include "google/protobuf/descriptor.pb.h" +#include "absl/status/status.h" +#include "absl/strings/str_format.h" #include "google/protobuf/util/message_differencer.h" -#include "absl/memory/memory.h" -#include "absl/time/time.h" namespace google { namespace api { namespace expr { namespace internal { -struct DefaultProtoEqual { - inline bool operator()(const google::protobuf::Message& lhs, - const google::protobuf::Message& rhs) const { - return google::protobuf::util::MessageDifferencer::Equals(lhs, rhs); - } -}; - -/** Helper function to encode a duration in a google::protobuf::Duration. */ -google::rpc::Status EncodeDuration(absl::Duration duration, - google::protobuf::Duration* proto); - -/** Helper function to encode a time in a google::protobuf::Timestamp. */ -google::rpc::Status EncodeTime(absl::Time time, - google::protobuf::Timestamp* proto); - -/** Helper function to decode a duration from a google::protobuf::Duration. */ -absl::Duration DecodeDuration(const google::protobuf::Duration& proto); - -/** Helper function to decode a time from a google::protobuf::Timestamp. */ -absl::Time DecodeTime(const google::protobuf::Timestamp& proto); - -/** Returns the min absl::Duration that can be represented as -/ * google::protobuf::Duration. */ -inline absl::Duration MakeGoogleApiDurationMin() { - return absl::Seconds(-315576000000) + absl::Nanoseconds(-999999999); -} - -/** Returns the max absl::Duration that can be represented as -/ * google::protobuf::Duration. */ -inline absl::Duration MakeGoogleApiDurationMax() { - return absl::Seconds(315576000000) + absl::Nanoseconds(999999999); -} +template +absl::Status ValidateStandardMessageType( + const google::protobuf::DescriptorPool& descriptor_pool) { + if constexpr (std::is_base_of_v) { + const google::protobuf::Descriptor* descriptor = MessageType::descriptor(); + const google::protobuf::Descriptor* descriptor_from_pool = + descriptor_pool.FindMessageTypeByName(descriptor->full_name()); + if (descriptor_from_pool == nullptr) { + return absl::NotFoundError( + absl::StrFormat("Descriptor '%s' not found in descriptor pool", + descriptor->full_name())); + } + if (descriptor_from_pool == descriptor) { + return absl::OkStatus(); + } + google::protobuf::DescriptorProto descriptor_proto; + google::protobuf::DescriptorProto descriptor_from_pool_proto; + descriptor->CopyTo(&descriptor_proto); + descriptor_from_pool->CopyTo(&descriptor_from_pool_proto); -/** Returns the min absl::Time that can be represented as -/ * google::protobuf::Timestamp. */ -inline absl::Time MakeGoogleApiTimeMin() { - return absl::UnixEpoch() + absl::Seconds(-62135596800); -} - -/** Returns the max absl::Time that can be represented as -/ * google::protobuf::Timestamp. */ -inline absl::Time MakeGoogleApiTimeMax() { - return absl::UnixEpoch() + absl::Seconds(253402300799) + - absl::Nanoseconds(999999999); -} - -inline std::unique_ptr Clone(const google::protobuf::Message& value) { - auto result = absl::WrapUnique(value.New()); - result->CopyFrom(value); - return result; -} - -inline std::unique_ptr Clone(google::protobuf::Message&& value) { - auto result = absl::WrapUnique(value.New()); - result->GetReflection()->Swap(&value, result.get()); - return result; + google::protobuf::util::MessageDifferencer descriptor_differencer; + std::string differences; + descriptor_differencer.ReportDifferencesToString(&differences); + // The json_name is a compiler detail and does not change the message + // content. It can differ, e.g., between C++ and Go compilers. Hence ignore. + const google::protobuf::FieldDescriptor* json_name_field_desc = + google::protobuf::FieldDescriptorProto::descriptor()->FindFieldByName( + "json_name"); + if (json_name_field_desc != nullptr) { + descriptor_differencer.IgnoreField(json_name_field_desc); + } + if (!descriptor_differencer.Compare(descriptor_proto, + descriptor_from_pool_proto)) { + return absl::FailedPreconditionError(absl::StrFormat( + "The descriptor for '%s' in the descriptor pool differs from the " + "compiled-in generated version as follows: %s", + descriptor->full_name(), differences)); + } + } else { + // Lite runtime. Just verify the message exists. + const auto& type_name = MessageType::default_instance().GetTypeName(); + const google::protobuf::Descriptor* descriptor_from_pool = + descriptor_pool.FindMessageTypeByName(type_name); + if (descriptor_from_pool == nullptr) { + return absl::NotFoundError(absl::StrFormat( + "Descriptor '%s' not found in descriptor pool", type_name)); + } + } + return absl::OkStatus(); } } // namespace internal diff --git a/internal/proto_util_test.cc b/internal/proto_util_test.cc new file mode 100644 index 000000000..179ad50bd --- /dev/null +++ b/internal/proto_util_test.cc @@ -0,0 +1,64 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "internal/proto_util.h" + +#include "google/protobuf/duration.pb.h" +#include "google/protobuf/descriptor.pb.h" +#include "absl/status/status.h" +#include "eval/public/structs/cel_proto_descriptor_pool_builder.h" +#include "internal/testing.h" + +namespace cel::internal { +namespace { + +using google::api::expr::internal::ValidateStandardMessageType; +using google::api::expr::runtime::GetStandardMessageTypesFileDescriptorSet; + +using ::absl_testing::StatusIs; +using ::testing::HasSubstr; + +TEST(ProtoUtil, ValidateStandardMessageTypesRejectsIncompatible) { + google::protobuf::DescriptorPool descriptor_pool; + google::protobuf::FileDescriptorSet standard_fds = + GetStandardMessageTypesFileDescriptorSet(); + + const google::protobuf::Descriptor* descriptor = + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.protobuf.Duration"); + ASSERT_NE(descriptor, nullptr); + google::protobuf::FileDescriptorProto file_descriptor_proto; + descriptor->file()->CopyTo(&file_descriptor_proto); + // We emulate a modification by external code that replaced the nanos by a + // millis field. + google::protobuf::FieldDescriptorProto seconds_desc_proto; + google::protobuf::FieldDescriptorProto nanos_desc_proto; + descriptor->FindFieldByName("seconds")->CopyTo(&seconds_desc_proto); + descriptor->FindFieldByName("nanos")->CopyTo(&nanos_desc_proto); + nanos_desc_proto.set_name("millis"); + file_descriptor_proto.mutable_message_type(0)->clear_field(); + *file_descriptor_proto.mutable_message_type(0)->add_field() = + seconds_desc_proto; + *file_descriptor_proto.mutable_message_type(0)->add_field() = + nanos_desc_proto; + + descriptor_pool.BuildFile(file_descriptor_proto); + + EXPECT_THAT( + ValidateStandardMessageType(descriptor_pool), + StatusIs(absl::StatusCode::kFailedPrecondition, HasSubstr("differs"))); +} + +} // namespace +} // namespace cel::internal diff --git a/internal/protobuf_runtime_version.h b/internal/protobuf_runtime_version.h new file mode 100644 index 000000000..2873a409d --- /dev/null +++ b/internal/protobuf_runtime_version.h @@ -0,0 +1,32 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_PROTOBUF_VERSION_H_ +#define THIRD_PARTY_CEL_CPP_INTERNAL_PROTOBUF_VERSION_H_ + +#ifdef __has_include +#if __has_include("third_party/protobuf/runtime_version.h") +#include "google/protobuf/runtime_version.h" // IWYU pragma: keep +#endif +#endif + +#ifdef PROTOBUF_OSS_VERSION +#define CEL_INTERNAL_PROTOBUF_OSS_VERSION_PREREQ(major, minor, patch) \ + ((major) * 1000000 + (minor) * 1000 + (patch) <= PROTOBUF_OSS_VERSION) +#else +// Older versions of protobuf did not have the macro. +#define CEL_INTERNAL_PROTOBUF_OSS_VERSION_PREREQ(major, minor, patch) 0 +#endif + +#endif // THIRD_PARTY_CEL_CPP_INTERNAL_PROTOBUF_VERSION_H_ diff --git a/internal/re2_options.h b/internal/re2_options.h new file mode 100644 index 000000000..25a30f6bd --- /dev/null +++ b/internal/re2_options.h @@ -0,0 +1,61 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_RE2_OPTIONS_H_ +#define THIRD_PARTY_CEL_CPP_INTERNAL_RE2_OPTIONS_H_ + +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "re2/re2.h" + +namespace cel::internal { + +inline RE2::Options MakeRE2Options() { + RE2::Options options; + options.set_log_errors(false); + return options; +} + +inline absl::Status CheckRE2(const RE2& re, int max_program_size) { + if (!re.ok()) { + switch (re.error_code()) { + case RE2::ErrorInternal: + return absl::InternalError( + absl::StrCat("internal RE2 error: ", re.error())); + case RE2::ErrorPatternTooLarge: + return absl::InvalidArgumentError( + absl::StrCat("regular expression too large: ", re.error())); + default: + return absl::InvalidArgumentError( + absl::StrCat("invalid regular expression: ", re.error())); + } + } + int program_size = re.ProgramSize(); + if (max_program_size > 0 && program_size > 0 && + program_size > max_program_size) { + return absl::InvalidArgumentError( + "regular expression exceeds max allowed size"); + } + int reverse_program_size = re.ReverseProgramSize(); + if (max_program_size > 0 && reverse_program_size > 0 && + reverse_program_size > max_program_size) { + return absl::InvalidArgumentError( + "regular expression exceeds max allowed size"); + } + return absl::OkStatus(); +} + +} // namespace cel::internal + +#endif // THIRD_PARTY_CEL_CPP_INTERNAL_RE2_OPTIONS_H_ diff --git a/internal/ref_countable.cc b/internal/ref_countable.cc deleted file mode 100644 index d97bc2e3f..000000000 --- a/internal/ref_countable.cc +++ /dev/null @@ -1,45 +0,0 @@ -#include "internal/ref_countable.h" - -namespace google { -namespace api { -namespace expr { -namespace internal { - -RefCountable::~RefCountable() { assert(unowned()); } - -std::size_t RefCountable::owner_count() const { - return refcount_.load(std::memory_order_acquire); -} - -bool RefCountable::single_owner() const { - // If we are the sole owner, only we have a view of the refcount, so no memory - // barrier needed. - return refcount_.load(std::memory_order_relaxed) == 1; -} - -bool RefCountable::unowned() const { - // If refcounting is not being used, then this value must not have changed - // since construction. If refcounting is being used and this value is 0, then - // a call to unowned() is inherently racy all ready. So no memory barrier - // needed. - return refcount_.load(std::memory_order_relaxed) == 0; -} - -void RefCountable::Ref() const { - refcount_.fetch_add(1, std::memory_order_relaxed); -} - -bool RefCountable::Unref() const { - auto prev_value = refcount_.fetch_sub(1, std::memory_order_release); - assert(prev_value > 0); - if (prev_value == 1) { - std::atomic_thread_fence(std::memory_order_acquire); - return true; - } - return false; -} - -} // namespace internal -} // namespace expr -} // namespace api -} // namespace google diff --git a/internal/ref_countable.h b/internal/ref_countable.h deleted file mode 100644 index 514cbce88..000000000 --- a/internal/ref_countable.h +++ /dev/null @@ -1,262 +0,0 @@ -#ifndef THIRD_PARTY_CEL_CPP_COMMON_INTERNAL_REFERENCE_COUNTED_CEL_VALUE_H_ -#define THIRD_PARTY_CEL_CPP_COMMON_INTERNAL_REFERENCE_COUNTED_CEL_VALUE_H_ - -#include -#include -#include -#include - -#include "internal/holder.h" -#include "internal/specialize.h" - -namespace google { -namespace api { -namespace expr { -namespace internal { - -/** - * A base class for optionally reference counted objects. - * - * Like std::shared_ptr, reference counting is optional and only active when - * used with a smart pointer (ReffedPtr). The main difference is that the - * counter is embedded within the object, so the size of the smart pointer - * is equal to that of a normal pointer. - * - * Unlike typical reference counting implementations, inheriting from this class - * does not restrict the function of the subclass. For example, subclasses can - * be constructed on the stack, passed by value, or used with other smart - * pointer classes (e.g. std::unique_ptr). Additionally, a virtual destructor - * is not required (although it is often useful). - */ -class RefCountable { - protected: - // Cannot be constructed or destructed directly. - constexpr RefCountable() : refcount_(0) {} - ~RefCountable(); - - // The refcount is not copied. - constexpr RefCountable(const RefCountable&) : refcount_(0) {} - RefCountable& operator=(const RefCountable&) { return *this; } - - std::size_t owner_count() const; - - /** Returns true if this object only has a single owner. */ - bool single_owner() const; - - /** Returns true if this object is not owned by a ReffedPtr. */ - bool unowned() const; - - private: - template - friend class ReffedPtr; - - mutable std::atomic refcount_; - - /** Increments the ref count. */ - void Ref() const; - - /** Decrements the ref count and returns true if the count becomes 0. */ - bool Unref() const; -}; - -/** - * A smart pointer to a reference countable object. - * - * @tparam T The type to point to. Should either be 'final' or have a virtual - * destructor. - */ -template -class ReffedPtr { - public: - template - static ReffedPtr Make(Args&&... args) { - return ReffedPtr(new T(std::forward(args)...)); - } - - constexpr ReffedPtr() {} - constexpr ReffedPtr(std::nullptr_t) {} - explicit ReffedPtr(T* ptr) : ptr_(ptr) { MaybeRef(); } - explicit ReffedPtr(std::unique_ptr ptr) : ptr_(ptr.release()) { - MaybeRef(); - } - - ReffedPtr(const ReffedPtr& other) : ptr_(other.ptr_) { MaybeRef(); } - template >> - ReffedPtr(const ReffedPtr& other) : ptr_(other.ptr_) { - MaybeRef(); - } - ReffedPtr(ReffedPtr&& other) : ptr_(other.ptr_) { other.ptr_ = nullptr; } - template >> - ReffedPtr(ReffedPtr&& other) : ptr_(other.ptr_) { - other.ptr_ = nullptr; - } - - ReffedPtr& operator=(const ReffedPtr& other) { return EqImpl(other); } - template >> - ReffedPtr& operator=(const ReffedPtr& other) { - return EqImpl(other); - } - - ReffedPtr& operator=(ReffedPtr&& other) { return EqImpl(std::move(other)); } - template >> - ReffedPtr& operator=(ReffedPtr&& other) { - return EqImpl(std::move(other)); - } - - inline ReffedPtr& operator=(std::nullptr_t) { - reset(); - return *this; - } - constexpr inline bool operator==(std::nullptr_t) const { - return ptr_ == nullptr; - } - constexpr inline bool operator!=(std::nullptr_t) const { - return ptr_ != nullptr; - } - - void reset(); - - ~ReffedPtr() { reset(); } - - constexpr T* get() const { return ptr_; } - constexpr T& operator*() const { return *ptr_; } - constexpr T* operator->() const { return get(); } - - private: - template - friend class ReffedPtr; - - T* ptr_ = nullptr; - - void MaybeRef() { - if (ptr_ != nullptr) { - ptr_->Ref(); - } - } - - template - ReffedPtr& EqImpl(const ReffedPtr& other) { - reset(); - ptr_ = other.ptr_; - MaybeRef(); - return *this; - } - - template - ReffedPtr& EqImpl(ReffedPtr&& other) { - reset(); - ptr_ = other.ptr_; - other.ptr_ = nullptr; - return *this; - } -}; - -template -ReffedPtr MakeReffed(Args&&... args) { - return ReffedPtr::Make(std::forward(args)...); -} - -template -ReffedPtr MakeReffed(std::unique_ptr value) { - return ReffedPtr(std::move(value)); -} - -template -ReffedPtr WrapReffed(T* value) { - return ReffedPtr(value); -} - -/** - * A ReffedPtr based HolderPolicy. - */ -struct RefPtr : BaseHolderPolicy { - constexpr static const bool kOwnsValue = true; - - template - using ValueType = ReffedPtr; - - template - static T& get(ReffedPtr& value) { - return *value; - } - - template - static const T& get(const ReffedPtr& value) { - return *value; - } -}; - -template -using RefPtrHolder = Holder; - -/** - * A reference countable holder. - * - * @see Holder - */ -template -class RefCountableHolder : public RefCountable { - public: - template - explicit RefCountableHolder(Args&&... args) - : value_(std::forward(args)...) {} - - T& value() { return HolderPolicy::template get(value_); } - const T& value() const { return HolderPolicy::template get(value_); } - - private: - typename HolderPolicy::template ValueType value_; -}; - -/** A reference counting HolderPolicy */ -template -struct Ref : BaseHolderPolicy { - constexpr static const bool kOwnsValue = true; - - template - using ValueType = ReffedPtr>; - - // Forward to Make function. - template - static V Create(Args&&... args) { - return V::Make(std::forward(args)...); - } - - template - static T& get(ReffedPtr>& value) { - return value->value(); - } - - template - static const T& get( - const ReffedPtr>& value) { - return value->value(); - } -}; - -template -using RefCopyHolder = Holder>; - -// If the size of T is smaller than MAX, it is stored inline, otherwise -// it is stored in a heap allocated, reference counted holder. -template -using SizeLimitHolder = conditional_t, - Holder>>; - -template -void ReffedPtr::reset() { - if (ptr_ == nullptr) { - return; - } - if (ptr_->Unref()) { - delete ptr_; - } - ptr_ = nullptr; -} - -} // namespace internal -} // namespace expr -} // namespace api -} // namespace google - -#endif // THIRD_PARTY_CEL_CPP_EVAL_INTERNAL_REFERENCE_COUNTED_CEL_VALUE_H_ diff --git a/internal/ref_countable_test.cc b/internal/ref_countable_test.cc deleted file mode 100644 index 3a2c272b8..000000000 --- a/internal/ref_countable_test.cc +++ /dev/null @@ -1,345 +0,0 @@ -#include "internal/ref_countable.h" - -#include "gtest/gtest.h" -#include "absl/memory/memory.h" -#include "testutil/util.h" - -namespace google { -namespace api { -namespace expr { -namespace internal { -namespace { - -class TestRefCounted : public RefCountable { - public: - explicit TestRefCounted(int id) : id_(id) {} - - int id() const { return id_; } - bool unowned() const { return RefCountable::unowned(); } - bool single_owner() const { return RefCountable::single_owner(); } - std::size_t owner_count() const { return RefCountable::owner_count(); } - - private: - int id_; - int data_[5]; -}; - -TEST(RefCountableTest, StackValue) { - TestRefCounted v1(1); - - EXPECT_TRUE(v1.unowned()); - EXPECT_FALSE(v1.single_owner()); - EXPECT_EQ(v1.id(), 1); -} - -TEST(RefCountableTest, UnqiuePtr) { - auto v1 = absl::make_unique(1); - - EXPECT_TRUE(v1->unowned()); - EXPECT_FALSE(v1->single_owner()); - EXPECT_EQ(v1->id(), 1); -} - -TEST(RefCountableTest, ReffedPtr) { - auto v1 = MakeReffed(1); - - EXPECT_FALSE(v1->unowned()); - EXPECT_TRUE(v1->single_owner()); - EXPECT_EQ(v1->id(), 1); -} - -TEST(RefCountableTest, ReffedPtr_Eq) { - auto v1 = MakeReffed(1); - auto v2 = v1; - - EXPECT_FALSE(v1->unowned()); - EXPECT_FALSE(v1->single_owner()); - EXPECT_EQ(2, v1->owner_count()); - EXPECT_EQ(v2->id(), 1); -} - -TEST(RefCountableTest, ReffedPtr_Move) { - auto v1 = MakeReffed(1); - auto v2 = std::move(v1); - - EXPECT_FALSE(v2->unowned()); - EXPECT_TRUE(v2->single_owner()); - EXPECT_EQ(1, v2->owner_count()); - EXPECT_EQ(v2->id(), 1); - - ReffedPtr v3(std::move(v2)); - EXPECT_FALSE(v3->unowned()); - EXPECT_TRUE(v3->single_owner()); - EXPECT_EQ(1, v3->owner_count()); - EXPECT_EQ(v3->id(), 1); -} - -TEST(RefCountableTest, ReffedPtr_CopyConstructed) { - auto v1 = MakeReffed(1); - ReffedPtr v2(v1); - - EXPECT_FALSE(v1->unowned()); - EXPECT_FALSE(v1->single_owner()); - EXPECT_EQ(2, v1->owner_count()); - EXPECT_EQ(v1->id(), 1); -} - -TEST(RefCountableTest, ReffedPtr_PtrConstructed) { - auto v1 = MakeReffed(1); - ReffedPtr v2(&*v1); - - EXPECT_FALSE(v1->unowned()); - EXPECT_FALSE(v1->single_owner()); - EXPECT_EQ(2, v1->owner_count()); - EXPECT_EQ(v1->id(), 1); -} - -TEST(RefCountableTest, ReffedPtr_Reset) { - auto v1 = MakeReffed(1); - - EXPECT_FALSE(v1->unowned()); - EXPECT_TRUE(v1->single_owner()); - EXPECT_EQ(v1->id(), 1); - - auto v2 = v1; - // Can be constructed directly from the raw pointer with out breaking ref - // counting. - ReffedPtr v3(&*v1); - - // They all point to the same object. - EXPECT_EQ(&*v1, &*v2); - EXPECT_EQ(&*v1, &*v3); - EXPECT_FALSE(v3->unowned()); - EXPECT_FALSE(v3->single_owner()); - EXPECT_EQ(v3->id(), 1); - EXPECT_EQ(3, v3->owner_count()); - - v3.reset(); - EXPECT_FALSE(v1->unowned()); - EXPECT_FALSE(v1->single_owner()); - EXPECT_EQ(2, v1->owner_count()); - - v2.reset(); - EXPECT_FALSE(v1->unowned()); - EXPECT_TRUE(v1->single_owner()); - EXPECT_EQ(1, v1->owner_count()); -} - -TEST(RefCountableTest, Copy) { - TestRefCounted v1(1); - TestRefCounted v2(2); - TestRefCounted v3(v1); - - EXPECT_TRUE(v1.unowned()); - EXPECT_FALSE(v1.single_owner()); - EXPECT_EQ(v1.id(), 1); - - EXPECT_TRUE(v3.unowned()); - EXPECT_FALSE(v3.single_owner()); - EXPECT_EQ(v3.id(), 1); - - EXPECT_TRUE(v2.unowned()); - EXPECT_FALSE(v2.single_owner()); - EXPECT_EQ(v2.id(), 2); - - v2 = v1; - EXPECT_TRUE(v2.unowned()); - EXPECT_FALSE(v2.single_owner()); - EXPECT_EQ(v2.id(), 1); -} - -TEST(RefCountableTest, CopyFromPtr) { - auto v1 = MakeReffed(1); - TestRefCounted v2(2); - - EXPECT_FALSE(v1->unowned()); - EXPECT_TRUE(v1->single_owner()); - EXPECT_EQ(v1->id(), 1); - EXPECT_TRUE(v2.unowned()); - EXPECT_FALSE(v2.single_owner()); - EXPECT_EQ(v2.id(), 2); - - v2 = *v1; - // Only the value is copied. - EXPECT_FALSE(v1->unowned()); - EXPECT_TRUE(v1->single_owner()); - EXPECT_EQ(v1->id(), 1); - EXPECT_TRUE(v2.unowned()); - EXPECT_FALSE(v2.single_owner()); - EXPECT_EQ(v2.id(), 1); -} - -TEST(ReffedPtrTest, ConstConversion) { - ReffedPtr const_ref = MakeReffed(1); - EXPECT_EQ(const_ref->id(), 1); -} - -TEST(Ref, NotRefCountable) { - using HolderType = Holder>; - EXPECT_TRUE(std::is_copy_constructible::value); - EXPECT_TRUE(std::is_copy_assignable::value); - EXPECT_TRUE(std::is_move_constructible::value); - EXPECT_TRUE(std::is_move_assignable::value); - - // Value can be mutated. - HolderType holder(1); - testutil::ExpectSameType(); - EXPECT_EQ(1, holder.value()); - holder.value() = 2; - EXPECT_EQ(2, holder.value()); - - // Value is shared. - HolderType holder2 = holder; - EXPECT_EQ(&holder.value(), &holder2.value()); - - // Const holder cannot be assigned or have its value changed. - EXPECT_TRUE(std::is_copy_constructible::value); - EXPECT_FALSE(std::is_copy_assignable::value); - EXPECT_TRUE(std::is_move_constructible::value); - EXPECT_FALSE(std::is_move_assignable::value); - const HolderType const_holder(2); - testutil::ExpectSameType(); - EXPECT_EQ(2, const_holder.value()); -} - -TEST(Ref, NotRefCountable_cost) { - // All modes work (unlike Copy). - using HolderType = Holder>; - EXPECT_TRUE(std::is_copy_constructible::value); - EXPECT_TRUE(std::is_copy_assignable::value); - EXPECT_TRUE(std::is_move_constructible::value); - EXPECT_TRUE(std::is_move_assignable::value); - - // Value cannot be changed. - HolderType holder(1); - testutil::ExpectSameType(); - EXPECT_EQ(1, holder.value()); - - // Value is shared. - HolderType holder2 = holder; - EXPECT_EQ(&holder.value(), &holder2.value()); - - // Const holder has the same properties. - EXPECT_TRUE(std::is_copy_constructible::value); - EXPECT_FALSE(std::is_copy_assignable::value); - EXPECT_TRUE(std::is_move_constructible::value); - EXPECT_FALSE(std::is_move_assignable::value); - const HolderType const_holder(2); - testutil::ExpectSameType(); - EXPECT_EQ(2, const_holder.value()); -} - -TEST(RefPtr, RefCountable) { - using HolderType = Holder; - EXPECT_TRUE(std::is_copy_constructible::value); - EXPECT_TRUE(std::is_copy_assignable::value); - EXPECT_TRUE(std::is_move_constructible::value); - EXPECT_TRUE(std::is_move_assignable::value); - - // Value can be mutated. - HolderType holder(new TestRefCounted(1)); - testutil::ExpectSameType(); - EXPECT_EQ(1, holder.value().id()); - holder.value() = TestRefCounted(2); - EXPECT_EQ(2, holder.value().id()); - - // Value is shared. - HolderType holder2 = holder; - EXPECT_EQ(&holder.value(), &holder2.value()); - - // Const holder cannot be assigned or have its value changed. - EXPECT_TRUE(std::is_copy_constructible::value); - EXPECT_FALSE(std::is_copy_assignable::value); - EXPECT_TRUE(std::is_move_constructible::value); - EXPECT_FALSE(std::is_move_assignable::value); - const HolderType const_holder(new TestRefCounted(2)); - testutil::ExpectSameType(); - EXPECT_EQ(2, const_holder.value().id()); -} - -TEST(RefPtr, RefCountable_cost) { - // All modes work (unlike Copy). - using HolderType = Holder; - EXPECT_TRUE(std::is_copy_constructible::value); - EXPECT_TRUE(std::is_copy_assignable::value); - EXPECT_TRUE(std::is_move_constructible::value); - EXPECT_TRUE(std::is_move_assignable::value); - - // Value cannot be changed. - HolderType holder(new TestRefCounted(1)); - testutil::ExpectSameType(); - EXPECT_EQ(1, holder.value().id()); - - // Value is shared. - HolderType holder2 = holder; - EXPECT_EQ(&holder.value(), &holder2.value()); - - // Const holder has the same properties. - EXPECT_TRUE(std::is_copy_constructible::value); - EXPECT_FALSE(std::is_copy_assignable::value); - EXPECT_TRUE(std::is_move_constructible::value); - EXPECT_FALSE(std::is_move_assignable::value); - const HolderType const_holder(new TestRefCounted(2)); - testutil::ExpectSameType(); - EXPECT_EQ(2, const_holder.value().id()); -} - -TEST(SizeLimitHolder, Inline16) { - using HolderType = SizeLimitHolder; - EXPECT_TRUE(std::is_copy_constructible::value); - EXPECT_TRUE(std::is_copy_assignable::value); - EXPECT_TRUE(std::is_move_constructible::value); - EXPECT_TRUE(std::is_move_assignable::value); - EXPECT_LE(sizeof(int16_t), 8); - EXPECT_LE(sizeof(HolderType), 8); - EXPECT_EQ(sizeof(HolderType), sizeof(int16_t)); - HolderType h1(1); - HolderType h2 = h1; - - testutil::ExpectSameType(); - EXPECT_EQ(h1.value(), h2.value()); - EXPECT_NE(&h1.value(), &h2.value()); -} - -TEST(SizeLimitHolder, Inline64) { - using HolderType = SizeLimitHolder; - EXPECT_TRUE(std::is_copy_constructible::value); - EXPECT_TRUE(std::is_copy_assignable::value); - EXPECT_TRUE(std::is_move_constructible::value); - EXPECT_TRUE(std::is_move_assignable::value); - EXPECT_LE(sizeof(int64_t), 8); - EXPECT_LE(sizeof(HolderType), 8); - EXPECT_EQ(sizeof(HolderType), sizeof(int64_t)); - - HolderType h1(1); - HolderType h2 = h1; - - testutil::ExpectSameType(); - EXPECT_EQ(h1.value(), h2.value()); - EXPECT_NE(&h1.value(), &h2.value()); -} - -TEST(SizeLimitHolder, OverSized) { - using HolderType = SizeLimitHolder; - EXPECT_TRUE(std::is_copy_constructible::value); - EXPECT_TRUE(std::is_copy_assignable::value); - EXPECT_TRUE(std::is_move_constructible::value); - EXPECT_TRUE(std::is_move_assignable::value); - EXPECT_GT(sizeof(TestRefCounted), 8); - EXPECT_LE(sizeof(HolderType), 8); - - HolderType h1(TestRefCounted(1)); - HolderType h2 = h1; - - testutil::ExpectSameType(); - EXPECT_EQ(h1.value().id(), h2.value().id()); - EXPECT_EQ(&h1.value(), &h2.value()); -} - -} // namespace -} // namespace internal -} // namespace expr -} // namespace api -} // namespace google diff --git a/internal/runfiles.cc b/internal/runfiles.cc new file mode 100644 index 000000000..bffbfa9d1 --- /dev/null +++ b/internal/runfiles.cc @@ -0,0 +1,53 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "internal/runfiles.h" + +#include +#include +#include + +#include "rules_cc/cc/runfiles/runfiles.h" +#include "absl/log/absl_check.h" + +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" + +namespace cel::internal { + +std::string ResolveRunfilesPath(absl::string_view path) { + using ::rules_cc::cc::runfiles::Runfiles; + static Runfiles* runfiles = []() { + std::string error; + auto runfiles = + Runfiles::CreateForTest(BAZEL_CURRENT_REPOSITORY, &error); + ABSL_QCHECK(runfiles != nullptr) + << absl::StrCat("failed to init runfiles", error); + return runfiles; + }(); + return runfiles->Rlocation(std::string(path)); +} + +absl::Status GetFileContents(absl::string_view path, std::string* out) { + std::ifstream file{std::string(path)}; + if (!file.is_open()) { + return absl::NotFoundError(absl::StrCat("Failed to open file: ", path)); + } + out->append((std::istreambuf_iterator(file)), + std::istreambuf_iterator()); + return absl::OkStatus(); +} + +} // namespace cel::internal diff --git a/internal/runfiles.h b/internal/runfiles.h new file mode 100644 index 000000000..11fdcf337 --- /dev/null +++ b/internal/runfiles.h @@ -0,0 +1,36 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Utilities for working with bazel runfiles. + +#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_RUNFILES_H_ +#define THIRD_PARTY_CEL_CPP_INTERNAL_RUNFILES_H_ + +#include + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" + +namespace cel::internal { + +// Resolves a path relative to the runfiles directory. +// Intended for resolving test cases from cel-spec and cel-policy. +std::string ResolveRunfilesPath(absl::string_view path); + +// Read contents of a file at a resolved path to a string. +absl::Status GetFileContents(absl::string_view path, std::string* out); + +} // namespace cel::internal + +#endif // THIRD_PARTY_CEL_CPP_INTERNAL_RUNFILES_H_ diff --git a/internal/specialize.h b/internal/specialize.h deleted file mode 100644 index 1837c4ee3..000000000 --- a/internal/specialize.h +++ /dev/null @@ -1,115 +0,0 @@ -/** - * Classes and helpers to allow preferring a specialized function, even if it - * is tied for "most specialized" with a generic implementation. - * Two flavors are supported. The first works with function overloads: - * - * template void foo(T&& v, generic); - * template void foo(T&& v, specialize_ift>); - * template void foo(T&& v) { - * return foo(std::forward(v), specialize()); - * } - * - * In this case, if T is a pointer type, the second `foo` overload is chosen. - * Without general/specialize, a call to `foo` with a pointer type would be - * ambiguous. - * - * The second flavor works with class secializations: - * - * template - * struct Foo { - * ... default impl ... - * }; - * - * template - * struct Foo>> { - * ... specialized impl ... - * }; - * - * In this case, if T is a string type, the second `Foo` implemenation is - * chosen. Without specialize, each string type would have to be instantiated - * explicitly. - * - * Specialize helper functions come in three different flavors. A traling 't' - * indicates a type containing a value is expected. A trailing 'd' indicates - * that the arguments just need to be defined. - */ - -#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_SPECIALIZE_H_ -#define THIRD_PARTY_CEL_CPP_INTERNAL_SPECIALIZE_H_ - -#include - -namespace google { -namespace api { -namespace expr { -namespace internal { - -struct general {}; -struct specialize : general {}; - -/** Returns an instance of T for use in decltype expressions. */ -template -T inst_of(); - -/** - * Resolves to std::true_type, when all args are valid, otherwise is - * not defined. - */ -template -struct is_defined : std::true_type {}; - -/** - * A type that is only defined if the template argument is `true`. - * - * Not that at least one dependent type argument is required when used to - * select between overloads with identical signatures. A compile time error - * with a message similar to "'enable_if' cannot be used to disable this - * declaration" is produced when this is not the case. Any local template - * argument can be added to `Args` to resolve this issue. - * - * @tparam B The value to be tested. - * @tparam T The type returned when B is true. - * @tparam Args Any additional types that must resolve or that are needed to - * distinguish between two overload. - */ -template -using specialize_if = - typename std::enable_if::value, T>::type; - -/** - * A type that is only defined if the template argument defines a constexpr - * 'value' that resolves to 'true'. - * - * For use with many std::* helper types. - */ -template -using specialize_ift = specialize_if; - -/** - * A type that is only defined if the template arguments are defined. - */ -template -using specialize_ifd = specialize_if; - -/** - * A type that is only defined if all `Arg` types can be resolved. - * - * Used to take advantage of SFINAE so that a specialization considered - * if all `Arg` types can be resolved. See `is_ptr` for an example. - */ -template -using specialize_for = specialize_ifd; - -/** - * A type that is only defined the return type of C is R. - */ -template -using specialize_if_returns = - specialize_ift::type>, T>; - -} // namespace internal -} // namespace expr -} // namespace api -} // namespace google - -#endif // THIRD_PARTY_CEL_CPP_INTERNAL_SPECIALIZE_H_ diff --git a/internal/status_builder.h b/internal/status_builder.h new file mode 100644 index 000000000..9caa6c462 --- /dev/null +++ b/internal/status_builder.h @@ -0,0 +1,114 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_STATUS_BUILDER_H_ +#define THIRD_PARTY_CEL_CPP_INTERNAL_STATUS_BUILDER_H_ + +#include +#include + +#include "absl/base/attributes.h" +#include "absl/status/status.h" + +namespace cel::internal { + +class StatusBuilder; + +template +inline constexpr bool StatusBuilderResultMatches = + std::is_same_v>, Expected>; + +template +using StatusBuilderPurePolicy = std::enable_if_t< + StatusBuilderResultMatches, + std::invoke_result_t>; + +template +using StatusBuilderSideEffect = + std::enable_if_t, + std::invoke_result_t>; + +template +using StatusBuilderConversion = std::enable_if_t< + !StatusBuilderResultMatches && + !StatusBuilderResultMatches, + std::invoke_result_t>; + +class StatusBuilder final { + public: + StatusBuilder() = default; + + explicit StatusBuilder(const absl::Status& status) : status_(status) {} + + StatusBuilder(const StatusBuilder&) = default; + + StatusBuilder(StatusBuilder&&) = default; + + ~StatusBuilder() = default; + + StatusBuilder& operator=(const StatusBuilder&) = default; + + StatusBuilder& operator=(StatusBuilder&&) = default; + + bool ok() const { return status_.ok(); } + + absl::StatusCode code() const { return status_.code(); } + + operator absl::Status() const& { return status_; } // NOLINT + + operator absl::Status() && { return std::move(status_); } // NOLINT + + template + auto With( + Adaptor&& adaptor) & -> StatusBuilderPurePolicy { + return std::forward(adaptor)(*this); + } + template + ABSL_MUST_USE_RESULT auto With( + Adaptor&& + adaptor) && -> StatusBuilderPurePolicy { + return std::forward(adaptor)(std::move(*this)); + } + + template + auto With( + Adaptor&& adaptor) & -> StatusBuilderSideEffect { + return std::forward(adaptor)(*this); + } + template + ABSL_MUST_USE_RESULT auto With( + Adaptor&& + adaptor) && -> StatusBuilderSideEffect { + return std::forward(adaptor)(std::move(*this)); + } + + template + auto With( + Adaptor&& adaptor) & -> StatusBuilderConversion { + return std::forward(adaptor)(*this); + } + template + ABSL_MUST_USE_RESULT auto With( + Adaptor&& + adaptor) && -> StatusBuilderConversion { + return std::forward(adaptor)(std::move(*this)); + } + + private: + absl::Status status_; +}; + +} // namespace cel::internal + +#endif // THIRD_PARTY_CEL_CPP_INTERNAL_STATUS_BUILDER_H_ diff --git a/internal/status_macros.h b/internal/status_macros.h new file mode 100644 index 000000000..a4b662df6 --- /dev/null +++ b/internal/status_macros.h @@ -0,0 +1,152 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_STATUS_MACROS_H_ +#define THIRD_PARTY_CEL_CPP_INTERNAL_STATUS_MACROS_H_ + +#include + +#include "absl/base/optimization.h" +#include "absl/status/status.h" +#include "internal/status_builder.h" + +#define CEL_RETURN_IF_ERROR(expr) \ + CEL_INTERNAL_STATUS_MACROS_IMPL_ELSE_BLOCKER_ \ + if (::cel::internal::StatusAdaptor cel_internal_status_macro = {(expr)}) { \ + } else /* NOLINT */ \ + return cel_internal_status_macro.Consume() + +// The GNU compiler historically emitted warnings for obscure usages of +// `if (foo) if (bar) {} else`. This suppresses that. + +// clang-format off +#define CEL_INTERNAL_STATUS_MACROS_IMPL_ELSE_BLOCKER_ \ + switch (0) case 0: default: /* NOLINT */ +// clang-format on + +#define CEL_ASSIGN_OR_RETURN(...) \ + CEL_INTERNAL_STATUS_MACROS_GET_VARIADIC_( \ + (__VA_ARGS__, CEL_INTERNAL_STATUS_MACROS_ASSIGN_OR_RETURN_3_, \ + CEL_INTERNAL_STATUS_MACROS_ASSIGN_OR_RETURN_2_)) \ + (__VA_ARGS__) + +// The following are macro magic to select either the 2 arg variant or 3 arg +// variant of CEL_ASSIGN_OR_RETURN. + +#define CEL_INTERNAL_STATUS_MACROS_GET_VARIADIC_HELPER_(_1, _2, _3, NAME, ...) \ + NAME +#define CEL_INTERNAL_STATUS_MACROS_GET_VARIADIC_(args) \ + CEL_INTERNAL_STATUS_MACROS_GET_VARIADIC_HELPER_ args + +#define CEL_INTERNAL_STATUS_MACROS_ASSIGN_OR_RETURN_2_(lhs, rexpr) \ + CEL_INTERNAL_STATUS_MACROS_ASSIGN_OR_RETURN_( \ + CEL_INTERNAL_STATUS_MACROS_CONCAT(_status_or_value, __LINE__), lhs, \ + rexpr, \ + return absl::Status(std::move(CEL_INTERNAL_STATUS_MACROS_CONCAT( \ + _status_or_value, __LINE__)) \ + .status())) + +#define CEL_INTERNAL_STATUS_MACROS_ASSIGN_OR_RETURN_3_(lhs, rexpr, \ + error_expression) \ + CEL_INTERNAL_STATUS_MACROS_ASSIGN_OR_RETURN_( \ + CEL_INTERNAL_STATUS_MACROS_CONCAT(_status_or_value, __LINE__), lhs, \ + rexpr, \ + ::cel::internal::StatusBuilder _( \ + std::move( \ + CEL_INTERNAL_STATUS_MACROS_CONCAT(_status_or_value, __LINE__)) \ + .status()); \ + (void)_; /* error_expression is allowed to not use this variable */ \ + return (error_expression)) + +// Common implementation of CEL_ASSIGN_OR_RETURN. Both the 2 arg variant and 3 +// arg variant are implemented by this macro. + +#define CEL_INTERNAL_STATUS_MACROS_ASSIGN_OR_RETURN_(statusor, lhs, rexpr, \ + error_expression) \ + auto statusor = (rexpr); \ + if (ABSL_PREDICT_FALSE(!statusor.ok())) { \ + error_expression; \ + } \ + CEL_INTERNAL_STATUS_MACROS_UNPARENTHESIZE_IF_PARENTHESIZED(lhs) = \ + std::move(statusor).value() + +#define CEL_INTERNAL_STATUS_MACROS_IS_EMPTY_INNER(...) \ + CEL_INTERNAL_STATUS_MACROS_IS_EMPTY_INNER_HELPER((__VA_ARGS__, 0, 1)) + +// MSVC historically expands variadic macros incorrectly, so another level of +// indirection is required. +#define CEL_INTERNAL_STATUS_MACROS_IS_EMPTY_INNER_HELPER(args) \ + CEL_INTERNAL_STATUS_MACROS_IS_EMPTY_INNER_I args +#define CEL_INTERNAL_STATUS_MACROS_IS_EMPTY_INNER_I(e0, e1, is_empty, ...) \ + is_empty + +#define CEL_INTERNAL_STATUS_MACROS_IS_EMPTY(...) \ + CEL_INTERNAL_STATUS_MACROS_IS_EMPTY_I(__VA_ARGS__) +#define CEL_INTERNAL_STATUS_MACROS_IS_EMPTY_I(...) \ + CEL_INTERNAL_STATUS_MACROS_IS_EMPTY_INNER(_, ##__VA_ARGS__) + +#define CEL_INTERNAL_STATUS_MACROS_IF_1(_Then, _Else) _Then +#define CEL_INTERNAL_STATUS_MACROS_IF_0(_Then, _Else) _Else +#define CEL_INTERNAL_STATUS_MACROS_IF(_Cond, _Then, _Else) \ + CEL_INTERNAL_STATUS_MACROS_CONCAT(CEL_INTERNAL_STATUS_MACROS_IF_, _Cond) \ + (_Then, _Else) + +#define CEL_INTERNAL_STATUS_MACROS_EAT(...) +#define CEL_INTERNAL_STATUS_MACROS_REM(...) __VA_ARGS__ +#define CEL_INTERNAL_STATUS_MACROS_EMPTY() + +// Expands to 1 if the input is surrounded by parenthesis, 0 otherwise. +#define CEL_INTERNAL_STATUS_MACROS_IS_PARENTHESIZED(...) \ + CEL_INTERNAL_STATUS_MACROS_IS_EMPTY( \ + CEL_INTERNAL_STATUS_MACROS_EAT __VA_ARGS__) + +// If the input is surrounded by parenthesis, remove them. Otherwise expand it +// unchanged. +#define CEL_INTERNAL_STATUS_MACROS_UNPARENTHESIZE_IF_PARENTHESIZED(...) \ + CEL_INTERNAL_STATUS_MACROS_IF( \ + CEL_INTERNAL_STATUS_MACROS_IS_PARENTHESIZED(__VA_ARGS__), \ + CEL_INTERNAL_STATUS_MACROS_REM, CEL_INTERNAL_STATUS_MACROS_EMPTY()) \ + __VA_ARGS__ + +#define CEL_INTERNAL_STATUS_MACROS_CONCAT_HELPER(x, y) x##y +#define CEL_INTERNAL_STATUS_MACROS_CONCAT(x, y) \ + CEL_INTERNAL_STATUS_MACROS_CONCAT_HELPER(x, y) + +namespace cel::internal { + +class StatusAdaptor final { + public: + StatusAdaptor() = default; + + StatusAdaptor(const StatusAdaptor&) = delete; + + StatusAdaptor(StatusAdaptor&&) = delete; + + StatusAdaptor(const absl::Status& status) : builder_(status) {} // NOLINT + + StatusAdaptor& operator=(const StatusAdaptor&) = delete; + + StatusAdaptor& operator=(StatusAdaptor&&) = delete; + + StatusBuilder&& Consume() { return std::move(builder_); } + + explicit operator bool() const { return ABSL_PREDICT_TRUE(builder_.ok()); } + + private: + StatusBuilder builder_; +}; + +} // namespace cel::internal + +#endif // THIRD_PARTY_CEL_CPP_INTERNAL_STATUS_MACROS_H_ diff --git a/internal/status_util.cc b/internal/status_util.cc deleted file mode 100644 index 893156c4d..000000000 --- a/internal/status_util.cc +++ /dev/null @@ -1,75 +0,0 @@ -#include "internal/status_util.h" - -#include "google/rpc/code.pb.h" -#include "absl/strings/str_cat.h" - -namespace google { -namespace api { -namespace expr { -namespace internal { - -namespace { - -google::rpc::Status NewStatus(google::rpc::Code code, - absl::string_view message) { - google::rpc::Status error; - error.set_code(code); - error.set_message(std::string(message)); - return error; -} - -} // namespace - -google::rpc::Status InvalidArgumentError(absl::string_view message) { - return NewStatus(google::rpc::Code::INVALID_ARGUMENT, message); -} - -google::rpc::Status NotFoundError(absl::string_view message) { - return NewStatus(google::rpc::Code::NOT_FOUND, message); -} - -google::rpc::Status UnimplementedError(absl::string_view message) { - return NewStatus(google::rpc::Code::UNIMPLEMENTED, message); -} - -google::rpc::Status OutOfRangeError(absl::string_view message) { - return NewStatus(google::rpc::Code::OUT_OF_RANGE, message); -} - -google::rpc::Status CancelledError() { - return NewStatus(google::rpc::Code::CANCELLED, ""); -} - -google::rpc::Status InternalError(absl::string_view message) { - return NewStatus(google::rpc::Code::INTERNAL, message); -} - -google::rpc::Status OutOfRangeError(size_t index, size_t size) { - return OutOfRangeError(absl::StrCat(index, " exceeds size ", size)); -} - -google::rpc::Status NoSuchCall(absl::string_view call, - absl::string_view full_type_name) { - return NotFoundError(absl::StrCat(call, " not found in ", full_type_name)); -} -google::rpc::Status NoSuchMember(absl::string_view member, - absl::string_view full_type_name) { - return NotFoundError(absl::StrCat(member, " not found in ", full_type_name)); -} -google::rpc::Status UnknownType(absl::string_view full_type_name) { - return InvalidArgumentError(absl::StrCat("Unknown type: ", full_type_name)); -} -google::rpc::Status ParseError(absl::string_view full_type_name) { - return InvalidArgumentError(absl::StrCat("Could not parse ", full_type_name)); -} - -google::rpc::Status UnexpectedType(absl::string_view full_type_name, - absl::string_view context) { - return InvalidArgumentError( - absl::StrCat("Unexpected type, ", full_type_name, ", in ", context)); -} - -} // namespace internal -} // namespace expr -} // namespace api -} // namespace google diff --git a/internal/status_util.h b/internal/status_util.h deleted file mode 100644 index dfa299477..000000000 --- a/internal/status_util.h +++ /dev/null @@ -1,47 +0,0 @@ -#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_STATUS_UTIL_H_ -#define THIRD_PARTY_CEL_CPP_INTERNAL_STATUS_UTIL_H_ - -#include "google/rpc/code.pb.h" -#include "google/rpc/status.pb.h" -#include "absl/strings/string_view.h" - -namespace google { -namespace api { -namespace expr { -namespace internal { - -// Helper functions to create common error values. -google::rpc::Status InvalidArgumentError(absl::string_view message); -google::rpc::Status NotFoundError(absl::string_view message); -google::rpc::Status UnimplementedError(absl::string_view message); -google::rpc::Status OutOfRangeError(absl::string_view message); -google::rpc::Status InternalError(absl::string_view message); - -google::rpc::Status CancelledError(); - -google::rpc::Status OutOfRangeError(size_t index, size_t size); -google::rpc::Status NoSuchCall(absl::string_view call, - absl::string_view full_type_name); -google::rpc::Status NoSuchMember(absl::string_view member, - absl::string_view full_type_name); -google::rpc::Status UnexpectedType(absl::string_view full_type_name, - absl::string_view context); - -inline google::rpc::Status NoSuchKey(absl::string_view key_as_string) { - return NoSuchMember(key_as_string, "map"); -} - -google::rpc::Status UnknownType(absl::string_view full_type_name); -google::rpc::Status ParseError(absl::string_view full_type_name); - -inline google::rpc::Status OkStatus() { return google::rpc::Status(); } -inline bool IsOk(const google::rpc::Status& status) { - return status.code() == google::rpc::Code::OK; -} - -} // namespace internal -} // namespace expr -} // namespace api -} // namespace google - -#endif // THIRD_PARTY_CEL_CPP_INTERNAL_STATUS_UTIL_H_ diff --git a/internal/string_pool.cc b/internal/string_pool.cc new file mode 100644 index 000000000..b38c45c7f --- /dev/null +++ b/internal/string_pool.cc @@ -0,0 +1,79 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "internal/string_pool.h" + +#include +#include +#include + +#include "absl/base/optimization.h" +#include "absl/strings/cord.h" +#include "absl/strings/string_view.h" +#include "google/protobuf/arena.h" + +namespace cel::internal { + +absl::string_view StringPool::InternString(absl::string_view string) { + if (string.empty()) { + return ""; + } + return *strings_.lazy_emplace(string, [&](const auto& ctor) { + char* data = + reinterpret_cast(arena()->AllocateAligned(string.size())); + std::memcpy(data, string.data(), string.size()); + ctor(absl::string_view(data, string.size())); + }); +} + +absl::string_view StringPool::InternString(std::string&& string) { + if (string.empty()) { + return ""; + } + return *strings_.lazy_emplace(string, [&](const auto& ctor) { + if (string.size() <= sizeof(std::string)) { + char* data = + reinterpret_cast(arena()->AllocateAligned(string.size())); + std::memcpy(data, string.data(), string.size()); + ctor(absl::string_view(data, string.size())); + } else { + google::protobuf::Arena* arena = this->arena(); + ABSL_ASSUME(arena != nullptr); + ctor(absl::string_view( + *google::protobuf::Arena::Create(arena, std::move(string)))); + } + }); +} + +absl::string_view StringPool::InternString(const absl::Cord& string) { + if (string.empty()) { + return ""; + } + return *strings_.lazy_emplace(string, [&](const auto& ctor) { + char* data = + reinterpret_cast(arena()->AllocateAligned(string.size())); + absl::Cord::CharIterator string_begin = string.char_begin(); + const absl::Cord::CharIterator string_end = string.char_end(); + char* p = data; + while (string_begin != string_end) { + absl::string_view chunk = absl::Cord::ChunkRemaining(string_begin); + std::memcpy(p, chunk.data(), chunk.size()); + p += chunk.size(); + absl::Cord::Advance(&string_begin, chunk.size()); + } + ctor(absl::string_view(data, string.size())); + }); +} + +} // namespace cel::internal diff --git a/internal/string_pool.h b/internal/string_pool.h new file mode 100644 index 000000000..8028107ab --- /dev/null +++ b/internal/string_pool.h @@ -0,0 +1,59 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_STRING_POOL_H_ +#define THIRD_PARTY_CEL_CPP_INTERNAL_STRING_POOL_H_ + +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/die_if_null.h" +#include "absl/strings/cord.h" +#include "absl/strings/string_view.h" +#include "google/protobuf/arena.h" + +namespace cel::internal { + +// `StringPool` efficiently performs string interning using `google::protobuf::Arena`. +// +// This class is thread compatible, but typically requires external +// synchronization or serial usage. +class StringPool final { + public: + explicit StringPool( + google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND) + : arena_(ABSL_DIE_IF_NULL(arena)) {} // Crash OK + + google::protobuf::Arena* absl_nonnull arena() const { return arena_; } + + absl::string_view InternString(const char* absl_nullable string) { + return InternString(absl::NullSafeStringView(string)); + } + + absl::string_view InternString(absl::string_view string); + + absl::string_view InternString(std::string&& string); + + absl::string_view InternString(const absl::Cord& string); + + private: + google::protobuf::Arena* absl_nonnull const arena_; + absl::flat_hash_set strings_; +}; + +} // namespace cel::internal + +#endif // THIRD_PARTY_CEL_CPP_INTERNAL_STRING_POOL_H_ diff --git a/internal/string_pool_test.cc b/internal/string_pool_test.cc new file mode 100644 index 000000000..8bc2765dc --- /dev/null +++ b/internal/string_pool_test.cc @@ -0,0 +1,40 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "internal/string_pool.h" + +#include "absl/strings/string_view.h" +#include "internal/testing.h" +#include "google/protobuf/arena.h" + +namespace cel::internal { +namespace { + +TEST(StringPool, EmptyString) { + google::protobuf::Arena arena; + StringPool string_pool(&arena); + absl::string_view interned_string = string_pool.InternString(""); + EXPECT_EQ(interned_string.data(), string_pool.InternString("").data()); +} + +TEST(StringPool, InternString) { + google::protobuf::Arena arena; + StringPool string_pool(&arena); + absl::string_view interned_string = string_pool.InternString("Hello, world!"); + EXPECT_EQ(interned_string.data(), + string_pool.InternString("Hello, world!").data()); +} + +} // namespace +} // namespace cel::internal diff --git a/internal/strings.cc b/internal/strings.cc new file mode 100644 index 000000000..a272aaa46 --- /dev/null +++ b/internal/strings.cc @@ -0,0 +1,693 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "internal/strings.h" + +#include + +#include "absl/base/attributes.h" +#include "absl/status/status.h" +#include "absl/strings/ascii.h" +#include "absl/strings/cord.h" +#include "absl/strings/escaping.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "internal/lexis.h" +#include "internal/unicode.h" +#include "internal/utf8.h" + +namespace cel::internal { + +namespace { + +constexpr char kHexTable[] = "0123456789abcdef"; + +constexpr int HexDigitToInt(char x) { + if (x > '9') { + x += 9; + } + return x & 0xf; +} + +constexpr bool IsOctalDigit(char x) { return x >= '0' && x <= '7'; } + +// Returns true when following conditions are met: +// - is a suffix of . +// - No other unescaped occurrence of inside (apart from +// being a suffix). +// Returns false otherwise. If is non-NULL, returns an error message in +// . If is non-NULL, returns the offset in that +// corresponds to the location of the error. +bool CheckForClosingString(absl::string_view source, + absl::string_view closing_str, std::string* error) { + if (closing_str.empty()) return true; + + const char* p = source.data(); + const char* end = p + source.size(); + + bool is_closed = false; + while (p + closing_str.length() <= end) { + if (*p != '\\') { + size_t cur_pos = p - source.data(); + bool is_closing = + absl::StartsWith(absl::ClippedSubstr(source, cur_pos), closing_str); + if (is_closing && p + closing_str.length() < end) { + if (error) { + *error = + absl::StrCat("String cannot contain unescaped ", closing_str); + } + return false; + } + is_closed = is_closing && (p + closing_str.length() == end); + } else { + p++; // Read past the escaped character. + } + p++; + } + + if (!is_closed) { + if (error) { + *error = absl::StrCat("String must end with ", closing_str); + } + return false; + } + + return true; +} + +// ---------------------------------------------------------------------- +// CUnescapeInternal() +// Unescapes C escape sequences and is the reverse of CEscape(). +// +// If 'source' is valid, stores the unescaped string and its size in +// 'dest' and 'dest_len' respectively, and returns true. Otherwise +// returns false and optionally stores the error description in +// 'error' and the error offset in 'error_offset'. If 'error' is +// nonempty on return, 'error_offset' is in range [0, str.size()]. +// Set 'error' and 'error_offset' to NULL to disable error reporting. +// +// 'dest' must point to a buffer that is at least as big as 'source'. The +// unescaped string cannot grow bigger than the source string since no +// unescaped sequence is longer than the corresponding escape sequence. +// 'source' and 'dest' must not be the same. +// +// If is non-empty, for to be valid: +// - It must end with . +// - Should not contain any other unescaped occurrence of . +// ---------------------------------------------------------------------- +bool UnescapeInternal(absl::string_view source, absl::string_view closing_str, + bool is_raw_literal, bool is_bytes_literal, + std::string* dest, std::string* error) { + if (!CheckForClosingString(source, closing_str, error)) { + return false; + } + + if (ABSL_PREDICT_FALSE(source.empty())) { + *dest = std::string(); + return true; + } + + // Strip off the closing_str from the end before unescaping. + source = source.substr(0, source.size() - closing_str.size()); + if (!is_bytes_literal) { + if (!Utf8IsValid(source)) { + if (error) { + *error = absl::StrCat("Structurally invalid UTF8 string: ", + EscapeBytes(source)); + } + return false; + } + } + + dest->reserve(source.size()); + + const char* p = source.data(); + const char* end = p + source.size(); + const char* last_byte = end - 1; + + while (p < end) { + if (*p != '\\') { + if (*p != '\r') { + dest->push_back(*p++); + } else { + // All types of newlines in different platforms i.e. '\r', '\n', '\r\n' + // are replaced with '\n'. + dest->push_back('\n'); + p++; + if (p < end && *p == '\n') { + p++; + } + } + } else { + if ((p + 1) > last_byte) { + if (error) { + *error = is_raw_literal + ? "Raw literals cannot end with odd number of \\" + : is_bytes_literal ? "Bytes literal cannot end with \\" + : "String literal cannot end with \\"; + } + return false; + } + if (is_raw_literal) { + // For raw literals, all escapes are valid and those characters ('\\' + // and the escaped character) come through literally in the string. + dest->push_back(*p++); + dest->push_back(*p++); + continue; + } + // Any error that occurs in the escape is accounted to the start of + // the escape. + p++; // Read past the escape character. + + switch (*p) { + case 'a': + dest->push_back('\a'); + break; + case 'b': + dest->push_back('\b'); + break; + case 'f': + dest->push_back('\f'); + break; + case 'n': + dest->push_back('\n'); + break; + case 'r': + dest->push_back('\r'); + break; + case 't': + dest->push_back('\t'); + break; + case 'v': + dest->push_back('\v'); + break; + case '\\': + dest->push_back('\\'); + break; + case '?': + dest->push_back('\?'); + break; // \? Who knew? + case '\'': + dest->push_back('\''); + break; + case '"': + dest->push_back('\"'); + break; + case '`': + dest->push_back('`'); + break; + case '0': + ABSL_FALLTHROUGH_INTENDED; + case '1': + ABSL_FALLTHROUGH_INTENDED; + case '2': + ABSL_FALLTHROUGH_INTENDED; + case '3': { + // Octal escape '\ddd': requires exactly 3 octal digits. Note that + // the highest valid escape sequence is '\377'. + // For string literals, octal and hex escape sequences are interpreted + // as unicode code points, and the related UTF8-encoded character is + // added to the destination. For bytes literals, octal and hex + // escape sequences are interpreted as a single byte value. + const char* octal_start = p; + if (p + 2 >= end) { + if (error) { + *error = + "Illegal escape sequence: Octal escape must be followed by 3 " + "octal digits but saw: \\" + + std::string(octal_start, end - p); + } + // Error offset was set to the start of the escape above the switch. + return false; + } + const char* octal_end = p + 2; + char32_t ch = 0; + for (; p <= octal_end; ++p) { + if (IsOctalDigit(*p)) { + ch = ch * 8 + *p - '0'; + } else { + if (error) { + *error = + "Illegal escape sequence: Octal escape must be followed by " + "3 octal digits but saw: \\" + + std::string(octal_start, 3); + } + // Error offset was set to the start of the escape above the + // switch. + return false; + } + } + p = octal_end; // p points at last digit. + if (is_bytes_literal) { + dest->push_back(static_cast(ch)); + } else { + Utf8Encode(*dest, ch); + } + break; + } + case 'x': + ABSL_FALLTHROUGH_INTENDED; + case 'X': { + // Hex escape '\xhh': requires exactly 2 hex digits. + // For string literals, octal and hex escape sequences are + // interpreted as unicode code points, and the related UTF8-encoded + // character is added to the destination. For bytes literals, octal + // and hex escape sequences are interpreted as a single byte value. + const char* hex_start = p; + if (p + 2 >= end) { + if (error) { + *error = + "Illegal escape sequence: Hex escape must be followed by 2 " + "hex digits but saw: \\" + + std::string(hex_start, end - p); + } + // Error offset was set to the start of the escape above the switch. + return false; + } + char32_t ch = 0; + const char* hex_end = p + 2; + for (++p; p <= hex_end; ++p) { + if (absl::ascii_isxdigit(*p)) { + ch = (ch << 4) + HexDigitToInt(*p); + } else { + if (error) { + *error = + "Illegal escape sequence: Hex escape must be followed by 2 " + "hex digits but saw: \\" + + std::string(hex_start, 3); + } + // Error offset was set to the start of the escape above the + // switch. + return false; + } + } + p = hex_end; // p points at last digit. + if (is_bytes_literal) { + dest->push_back(static_cast(ch)); + } else { + Utf8Encode(*dest, ch); + } + break; + } + case 'u': { + if (is_bytes_literal) { + if (error) { + *error = + std::string( + "Illegal escape sequence: Unicode escape sequence \\") + + *p + " cannot be used in bytes literals"; + } + // Error offset was set to the start of the escape above the switch. + return false; + } + // \uhhhh => Read 4 hex digits as a code point, + // then write it as UTF-8 bytes. + char32_t cp = 0; + const char* hex_start = p; + if (p + 4 >= end) { + if (error) { + *error = + "Illegal escape sequence: \\u must be followed by 4 hex " + "digits but saw: \\" + + std::string(hex_start, end - p); + } + // Error offset was set to the start of the escape above the switch. + return false; + } + for (int i = 0; i < 4; ++i) { + // Look one char ahead. + if (absl::ascii_isxdigit(p[1])) { + cp = (cp << 4) + HexDigitToInt(*++p); // Advance p. + } else { + if (error) { + *error = + "Illegal escape sequence: \\u must be followed by 4 " + "hex digits but saw: \\" + + std::string(hex_start, 5); + } + // Error offset was set to the start of the escape above the + // switch. + return false; + } + } + if (!UnicodeIsValid(cp)) { + if (error) { + *error = "Illegal escape sequence: Unicode value \\" + + std::string(hex_start, 5) + " is invalid"; + } + // Error offset was set to the start of the escape above the switch. + return false; + } + Utf8Encode(*dest, cp); + break; + } + case 'U': { + if (is_bytes_literal) { + if (error) { + *error = + std::string( + "Illegal escape sequence: Unicode escape sequence \\") + + *p + " cannot be used in bytes literals"; + } + return false; + } + // \Uhhhhhhhh => convert 8 hex digits to UTF-8. Note that the + // first two digits must be 00: The valid range is + // '\U00000000' to '\U0010FFFF' (excluding surrogates). + char32_t cp = 0; + const char* hex_start = p; + if (p + 8 >= end) { + if (error) { + *error = + "Illegal escape sequence: \\U must be followed by 8 hex " + "digits but saw: \\" + + std::string(hex_start, end - p); + } + // Error offset was set to the start of the escape above the switch. + return false; + } + for (int i = 0; i < 8; ++i) { + // Look one char ahead. + if (absl::ascii_isxdigit(p[1])) { + cp = (cp << 4) + HexDigitToInt(*++p); + if (cp > 0x10FFFF) { + if (error) { + *error = "Illegal escape sequence: Value of \\" + + std::string(hex_start, 9) + + " exceeds Unicode limit (0x0010FFFF)"; + } + // Error offset was set to the start of the escape above the + // switch. + return false; + } + } else { + if (error) { + *error = + "Illegal escape sequence: \\U must be followed by 8 " + "hex digits but saw: \\" + + std::string(hex_start, 9); + } + // Error offset was set to the start of the escape above the + // switch. + return false; + } + } + if (!UnicodeIsValid(cp)) { + if (error) { + *error = "Illegal escape sequence: Unicode value \\" + + std::string(hex_start, 9) + " is invalid"; + } + // Error offset was set to the start of the escape above the switch. + return false; + } + Utf8Encode(*dest, cp); + break; + } + case '\r': + ABSL_FALLTHROUGH_INTENDED; + case '\n': { + if (error) { + *error = "Illegal escaped newline"; + } + // Error offset was set to the start of the escape above the switch. + return false; + } + default: { + if (error) { + *error = std::string("Illegal escape sequence: \\") + *p; + } + // Error offset was set to the start of the escape above the switch. + return false; + } + } + p++; // read past letter we escaped + } + } + + dest->shrink_to_fit(); + + return true; +} + +std::string EscapeInternal(absl::string_view src, bool escape_all_bytes, + char escape_quote_char) { + std::string dest; + // Worst case size is every byte has to be hex escaped, so 4 char for every + // byte. + dest.reserve(src.size() * 4); + bool last_hex_escape = false; // true if last output char was \xNN. + const char* p = src.data(); + const char* end = p + src.size(); + for (; p < end; ++p) { + unsigned char c = static_cast(*p); + bool is_hex_escape = false; + switch (c) { + case '\n': + dest.append("\\n"); + break; + case '\r': + dest.append("\\r"); + break; + case '\t': + dest.append("\\t"); + break; + case '\\': + dest.append("\\\\"); + break; + case '\'': + ABSL_FALLTHROUGH_INTENDED; + case '\"': + ABSL_FALLTHROUGH_INTENDED; + case '`': + // Escape only quote chars that match escape_quote_char. + if (escape_quote_char == 0 || c == escape_quote_char) { + dest.push_back('\\'); + } + dest.push_back(c); + break; + default: + // Note that if we emit \xNN and the src character after that is a hex + // digit then that digit must be escaped too to prevent it being + // interpreted as part of the character code by C. + if ((!escape_all_bytes || c < 0x80) && + (!absl::ascii_isprint(c) || + (last_hex_escape && absl::ascii_isxdigit(c)))) { + dest.append("\\x"); + dest.push_back(kHexTable[c / 16]); + dest.push_back(kHexTable[c % 16]); + is_hex_escape = true; + } else { + dest.push_back(c); + break; + } + } + last_hex_escape = is_hex_escape; + } + dest.shrink_to_fit(); + return dest; +} + +bool MayBeTripleQuotedString(absl::string_view str) { + return (str.size() >= 6 && + ((absl::StartsWith(str, "\"\"\"") && absl::EndsWith(str, "\"\"\"")) || + (absl::StartsWith(str, "'''") && absl::EndsWith(str, "'''")))); +} + +bool MayBeStringLiteral(absl::string_view str) { + return (str.size() >= 2 && str[0] == str[str.size() - 1] && + (str[0] == '\'' || str[0] == '"')); +} + +bool MayBeBytesLiteral(absl::string_view str) { + return (str.size() >= 3 && absl::StartsWithIgnoreCase(str, "b") && + str[1] == str[str.size() - 1] && (str[1] == '\'' || str[1] == '"')); +} + +bool MayBeRawStringLiteral(absl::string_view str) { + return (str.size() >= 3 && absl::StartsWithIgnoreCase(str, "r") && + str[1] == str[str.size() - 1] && (str[1] == '\'' || str[1] == '"')); +} + +bool MayBeRawBytesLiteral(absl::string_view str) { + return (str.size() >= 4 && + (absl::StartsWithIgnoreCase(str, "rb") || + absl::StartsWithIgnoreCase(str, "br")) && + (str[2] == str[str.size() - 1]) && (str[2] == '\'' || str[2] == '"')); +} + +} // namespace + +absl::StatusOr UnescapeString(absl::string_view str) { + std::string out; + std::string error; + if (!UnescapeInternal(str, "", false, false, &out, &error)) { + return absl::InvalidArgumentError( + absl::StrCat("Invalid escaped string: ", error)); + } + return out; +} + +absl::StatusOr UnescapeBytes(absl::string_view str) { + std::string out; + std::string error; + if (!UnescapeInternal(str, "", false, true, &out, &error)) { + return absl::InvalidArgumentError( + absl::StrCat("Invalid escaped bytes: ", error)); + } + return out; +} + +std::string EscapeString(absl::string_view str) { + return EscapeInternal(str, true, '\0'); +} + +std::string EscapeBytes(absl::string_view str, bool escape_all_bytes, + char escape_quote_char) { + std::string escaped_bytes; + const char* p = str.data(); + const char* end = p + str.size(); + for (; p < end; ++p) { + unsigned char c = *p; + if (escape_all_bytes || !absl::ascii_isprint(c)) { + escaped_bytes += "\\x"; + escaped_bytes += absl::BytesToHexString(absl::string_view(p, 1)); + } else { + switch (c) { + // Note that we only handle printable escape characters here. All + // unprintable (\n, \r, \t, etc.) are hex escaped above. + case '\\': + escaped_bytes += "\\\\"; + break; + case '\'': + case '"': + case '`': + // Escape only quote chars that match escape_quote_char. + if (escape_quote_char == 0 || c == escape_quote_char) { + escaped_bytes += '\\'; + } + escaped_bytes += c; + break; + default: + escaped_bytes += c; + break; + } + } + } + return escaped_bytes; +} + +absl::StatusOr ParseStringLiteral(absl::string_view str) { + std::string out; + bool is_string_literal = MayBeStringLiteral(str); + bool is_raw_string_literal = MayBeRawStringLiteral(str); + if (!is_string_literal && !is_raw_string_literal) { + return absl::InvalidArgumentError("Invalid string literal"); + } + + absl::string_view copy_str = str; + if (is_raw_string_literal) { + // Strip off the prefix 'r' from the raw string content before parsing. + copy_str = absl::ClippedSubstr(copy_str, 1); + } + + bool is_triple_quoted = MayBeTripleQuotedString(copy_str); + // Starts after the opening quotes {""", '''} or {", '}. + int quotes_length = is_triple_quoted ? 3 : 1; + absl::string_view quotes = copy_str.substr(0, quotes_length); + copy_str = absl::ClippedSubstr(copy_str, quotes_length); + std::string error; + if (!UnescapeInternal(copy_str, quotes, is_raw_string_literal, false, &out, + &error)) { + return absl::InvalidArgumentError( + absl::StrCat("Invalid string literal: ", error)); + } + return out; +} + +absl::StatusOr ParseBytesLiteral(absl::string_view str) { + std::string out; + bool is_bytes_literal = MayBeBytesLiteral(str); + bool is_raw_bytes_literal = MayBeRawBytesLiteral(str); + if (!is_bytes_literal && !is_raw_bytes_literal) { + return absl::InvalidArgumentError("Invalid bytes literal"); + } + + absl::string_view copy_str = str; + if (is_raw_bytes_literal) { + // Strip off the prefix {"rb", "br"} from the raw bytes content before + copy_str = absl::ClippedSubstr(copy_str, 2); + } else { + // Strip off the prefix 'b' from the bytes content before parsing. + copy_str = absl::ClippedSubstr(copy_str, 1); + } + + bool is_triple_quoted = MayBeTripleQuotedString(copy_str); + // Starts after the opening quotes {""", '''} or {", '}. + int quotes_length = is_triple_quoted ? 3 : 1; + absl::string_view quotes = copy_str.substr(0, quotes_length); + // Includes the closing quotes. + copy_str = absl::ClippedSubstr(copy_str, quotes_length); + std::string error; + if (!UnescapeInternal(copy_str, quotes, is_raw_bytes_literal, true, &out, + &error)) { + return absl::InvalidArgumentError( + absl::StrCat("Invalid bytes literal: ", error)); + } + return out; +} + +std::string FormatStringLiteral(absl::string_view str) { + absl::string_view quote = + (str.find('"') != str.npos && str.find('\'') == str.npos) ? "'" : "\""; + return absl::StrCat(quote, EscapeInternal(str, true, quote[0]), quote); +} + +std::string FormatStringLiteral(const absl::Cord& str) { + if (auto flat = str.TryFlat(); flat) { + return FormatStringLiteral(*flat); + } + return FormatStringLiteral(static_cast(str)); +} + +std::string FormatSingleQuotedStringLiteral(absl::string_view str) { + return absl::StrCat("'", EscapeInternal(str, true, '\''), "'"); +} + +std::string FormatDoubleQuotedStringLiteral(absl::string_view str) { + return absl::StrCat("\"", EscapeInternal(str, true, '"'), "\""); +} + +std::string FormatBytesLiteral(absl::string_view str) { + absl::string_view quote = + (str.find('"') != str.npos && str.find('\'') == str.npos) ? "'" : "\""; + return absl::StrCat("b", quote, EscapeBytes(str, false, quote[0]), quote); +} + +std::string FormatSingleQuotedBytesLiteral(absl::string_view str) { + return absl::StrCat("b'", EscapeBytes(str, false, '\''), "'"); +} + +std::string FormatDoubleQuotedBytesLiteral(absl::string_view str) { + return absl::StrCat("b\"", EscapeBytes(str, false, '"'), "\""); +} + +absl::StatusOr ParseIdentifier(absl::string_view str) { + if (!LexisIsIdentifier(str)) { + return absl::InvalidArgumentError("Invalid identifier"); + } + return std::string(str); +} + +} // namespace cel::internal diff --git a/internal/strings.h b/internal/strings.h new file mode 100644 index 000000000..ae82a14fd --- /dev/null +++ b/internal/strings.h @@ -0,0 +1,90 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_STRINGS_H_ +#define THIRD_PARTY_CEL_CPP_INTERNAL_STRINGS_H_ + +#include + +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/string_view.h" + +namespace cel::internal { + +// Expand escaped characters according to CEL escaping rules. +// This is for raw strings with no quoting. +absl::StatusOr UnescapeString(absl::string_view str); + +// Expand escaped characters according to CEL escaping rules. +// Rules for bytes values are slightly different than those for strings. This +// is for raw literals with no quoting. +absl::StatusOr UnescapeBytes(absl::string_view str); + +// Escape a string without quoting it. All quote characters are escaped. +std::string EscapeString(absl::string_view str); + +// Escape a bytes value without quoting it. Escaped bytes use hex escapes. +// If is true then all bytes are escaped. Otherwise only +// unprintable bytes and escape/quote characters are escaped. +// If is not 0, then quotes that do not match are not +// escaped. +std::string EscapeBytes(absl::string_view str, bool escape_all_bytes = false, + char escape_quote_char = '\0'); + +// Unquote and unescape a quoted CEL string literal (of the form '...', +// "...", r'...' or r"..."). +// If an error occurs and is not NULL, then it is populated with +// the relevant error message. If is not NULL, it is populated +// with the offset in at which the invalid input occurred. +absl::StatusOr ParseStringLiteral(absl::string_view str); + +// Unquote and unescape a CEL bytes literal (of the form b'...', +// b"...", rb'...', rb"...", br'...' or br"..."). +// If an error occurs and is not NULL, then it is populated with +// the relevant error message. If is not NULL, it is populated +// with the offset in at which the invalid input occurred. +absl::StatusOr ParseBytesLiteral(absl::string_view str); + +// Return a quoted and escaped CEL string literal for . +// May choose to quote with ' or " to produce nicer output. +std::string FormatStringLiteral(absl::string_view str); +std::string FormatStringLiteral(const absl::Cord& str); + +// Return a quoted and escaped CEL string literal for . +// Always uses single quotes. +std::string FormatSingleQuotedStringLiteral(absl::string_view str); + +// Return a quoted and escaped CEL string literal for . +// Always uses double quotes. +std::string FormatDoubleQuotedStringLiteral(absl::string_view str); + +// Return a quoted and escaped CEL bytes literal for . +// Prefixes with b and may choose to quote with ' or " to produce nicer output. +std::string FormatBytesLiteral(absl::string_view str); + +// Return a quoted and escaped CEL bytes literal for . +// Prefixes with b and always uses single quotes. +std::string FormatSingleQuotedBytesLiteral(absl::string_view str); + +// Return a quoted and escaped CEL bytes literal for . +// Prefixes with b and always uses double quotes. +std::string FormatDoubleQuotedBytesLiteral(absl::string_view str); + +// Parse a CEL identifier. +absl::StatusOr ParseIdentifier(absl::string_view str); + +} // namespace cel::internal + +#endif // THIRD_PARTY_CEL_CPP_INTERNAL_STRINGS_H_ diff --git a/internal/strings_test.cc b/internal/strings_test.cc new file mode 100644 index 000000000..fcdb6d4ec --- /dev/null +++ b/internal/strings_test.cc @@ -0,0 +1,876 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "internal/strings.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/ascii.h" +#include "absl/strings/cord.h" +#include "absl/strings/cord_test_helpers.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "internal/testing.h" +#include "internal/utf8.h" + +namespace cel::internal { +namespace { + +using ::absl_testing::StatusIs; + +constexpr char kUnicodeNotAllowedInBytes1[] = + "Unicode escape sequence \\u cannot be used in bytes literals"; +constexpr char kUnicodeNotAllowedInBytes2[] = + "Unicode escape sequence \\U cannot be used in bytes literals"; + +// takes a string literal of the form '...', r'...', "..." or r"...". +// is the expected parsed form of . +void TestQuotedString(const std::string& unquoted, const std::string& quoted) { + auto status_or_unquoted = ParseStringLiteral(quoted); + EXPECT_OK(status_or_unquoted) << unquoted; + EXPECT_EQ(unquoted, status_or_unquoted.value()) << quoted; +} + +void TestString(const std::string& unquoted) { + TestQuotedString(unquoted, FormatStringLiteral(unquoted)); + TestQuotedString(unquoted, FormatStringLiteral(absl::Cord(unquoted))); + if (unquoted.size() > 1) { + const size_t mid = unquoted.size() / 2; + TestQuotedString(unquoted, FormatStringLiteral(absl::MakeFragmentedCord( + {absl::string_view(unquoted).substr(0, mid), + absl::string_view(unquoted).substr(mid)}))); + } + TestQuotedString(unquoted, + absl::StrCat("'''", EscapeString(unquoted), "'''")); + TestQuotedString(unquoted, + absl::StrCat("\"\"\"", EscapeString(unquoted), "\"\"\"")); +} + +void TestRawString(const std::string& unquoted) { + const std::string quote = (!absl::StrContains(unquoted, "'")) ? "'" : "\""; + TestQuotedString(unquoted, absl::StrCat("r", quote, unquoted, quote)); + TestQuotedString(unquoted, absl::StrCat("r\"", unquoted, "\"")); + TestQuotedString(unquoted, absl::StrCat("r'''", unquoted, "'''")); + TestQuotedString(unquoted, absl::StrCat("r\"\"\"", unquoted, "\"\"\"")); +} + +// is the quoted version of and represents the original +// string mentioned in the test case. +// This method compares the unescaped against its round trip version +// i.e. after carrying out escaping followed by unescaping on it. +void TestBytesEscaping(const std::string& unquoted, const std::string& quoted) { + ASSERT_OK_AND_ASSIGN(auto unescaped, UnescapeBytes(unquoted)); + const std::string escaped = EscapeBytes(unescaped); + ASSERT_OK_AND_ASSIGN(auto unescaped2, UnescapeBytes(escaped)); + EXPECT_EQ(unescaped, unescaped2); + std::string escaped2 = EscapeBytes(unescaped, true); + ASSERT_OK_AND_ASSIGN(auto unescaped3, UnescapeBytes(escaped2)); + EXPECT_EQ(unescaped, unescaped3); +} + +// takes a byte literal of the form b'...', b'''...''' +void TestBytesLiteral(const std::string& quoted) { + // Parse the literal. + ASSERT_OK_AND_ASSIGN(auto unquoted, ParseBytesLiteral(quoted)); + + // Take the parsed literal and turn it back to a literal. + std::string requoted = FormatBytesLiteral(unquoted); + // Parse it again. + ASSERT_OK_AND_ASSIGN(auto unquoted2, ParseBytesLiteral(requoted)); + // Test the parsed literal forms for equality, not the unparsed forms. + // This is because the unparsed forms can have different representations for + // the same data, i.e., \000 and \x00. + EXPECT_EQ(unquoted, unquoted2) + << "unquoted : " << unquoted << "\nunquoted2: " << unquoted2; + + TestBytesEscaping(unquoted, quoted); +} + +// takes a raw byte literal of the form rb'...', br'...', rb'''...''' +// or br'''...'''. is the expected parsed form of . +void TestQuotedRawBytesLiteral(const std::string& unquoted, + const std::string& quoted) { + ASSERT_OK_AND_ASSIGN(auto actual_unquoted, ParseBytesLiteral(quoted)); + EXPECT_EQ(unquoted, actual_unquoted) << "quoted: " << quoted; +} + +// takes a string of not escaped unquoted bytes. +void TestUnescapedBytes(const std::string& unquoted) { + TestBytesLiteral(FormatBytesLiteral(unquoted)); +} + +void TestRawBytes(const std::string& unquoted) { + const std::string quote = (!absl::StrContains(unquoted, "'")) ? "'" : "\""; + TestQuotedRawBytesLiteral(unquoted, + absl::StrCat("rb", quote, unquoted, quote)); + TestQuotedRawBytesLiteral(unquoted, + absl::StrCat("br", quote, unquoted, quote)); + TestQuotedRawBytesLiteral(unquoted, absl::StrCat("rb'''", unquoted, "'''")); + TestQuotedRawBytesLiteral(unquoted, absl::StrCat("br'''", unquoted, "'''")); + TestQuotedRawBytesLiteral(unquoted, + absl::StrCat("rb\"\"\"", unquoted, "\"\"\"")); + TestQuotedRawBytesLiteral(unquoted, + absl::StrCat("br\"\"\"", unquoted, "\"\"\"")); +} + +void TestParseString(const std::string& orig) { + EXPECT_OK(ParseStringLiteral(orig)) << orig; +} + +void TestParseBytes(const std::string& orig) { + EXPECT_OK(ParseBytesLiteral(orig)) << orig; +} + +void TestStringEscaping(const std::string& orig) { + const std::string escaped = EscapeString(orig); + ASSERT_OK_AND_ASSIGN(auto unescaped, UnescapeString(escaped)); + EXPECT_EQ(orig, unescaped) << "escaped: " << escaped; +} + +void TestValue(const std::string& orig) { + TestStringEscaping(orig); + TestString(orig); +} + +// Test that is treated as invalid, with error offset +// and an error that contains substring +// . The last arguments are optional because most +// flat-out bad inputs are rejected without further information. +void TestInvalidString(const std::string& str, + const std::string& expected_error_substr = "") { + auto status_or_string = ParseStringLiteral(str); + EXPECT_THAT(status_or_string, StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_TRUE(absl::StrContains(status_or_string.status().message(), + expected_error_substr)); +} + +// Test that is treated as invalid, with error offset +// and an error that contains substring +// . The last arguments are optional because most +// flat-out bad inputs are rejected without further information. +void TestInvalidBytes(const std::string& str, + const std::string& expected_error_substr = "") { + auto status_or_string = ParseBytesLiteral(str); + EXPECT_THAT(status_or_string, StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_TRUE(absl::StrContains(status_or_string.status().message(), + expected_error_substr)); +} + +TEST(StringsTest, TestParsingOfAllEscapeCharacters) { + // All the valid escapes. + const std::set valid_escapes = {'a', 'b', 'f', 'n', 'r', 't', + 'v', '\\', '?', '"', '\'', '`', + 'u', 'U', 'x', 'X'}; + for (int escape_char_int = 0; escape_char_int < 255; ++escape_char_int) { + char escape_char = static_cast(escape_char_int); + absl::string_view escape_piece(&escape_char, 1); + if (valid_escapes.find(escape_char) != valid_escapes.end()) { + if (escape_char == '\'') { + TestParseString(absl::StrCat("\"a\\", escape_piece, "0010ffff\"")); + } + TestParseString(absl::StrCat("'a\\", escape_piece, "0010ffff'")); + TestParseString(absl::StrCat("'''a\\", escape_piece, "0010ffff'''")); + } else if (absl::ascii_isdigit(escape_char)) { + // Can also escape 0-3. + const std::string test_string = + absl::StrCat("'a\\", escape_piece, "00b'"); + const std::string test_triple_quoted_string = + absl::StrCat("'''a\\", escape_piece, "00b'''"); + if (escape_char <= '3') { + TestParseString(test_string); + TestParseString(test_triple_quoted_string); + } else { + TestInvalidString(test_string, "Illegal escape sequence: "); + TestInvalidString(test_triple_quoted_string, + "Illegal escape sequence: "); + } + } else { + if (Utf8IsValid(escape_piece)) { + const std::string expected_error = + ((escape_char == '\n' || escape_char == '\r') + ? "Illegal escaped newline" + : "Illegal escape sequence: "); + TestInvalidString(absl::StrCat("'a\\", escape_piece, "b'"), + expected_error); + TestInvalidString(absl::StrCat("'''a\\", escape_piece, "b'''"), + expected_error); + } else { + TestInvalidString(absl::StrCat("'a\\", escape_piece, "b'"), + "Structurally invalid UTF8" // + " string"); + TestInvalidString(absl::StrCat("'''a\\", escape_piece, "b'''"), + "Structurally invalid UTF8" // + " string"); + } + } + } +} + +TEST(StringsTest, TestParsingOfOctalEscapes) { + for (int idx = 0; idx < 256; ++idx) { + const char end_char = (idx % 8) + '0'; + const char mid_char = ((idx / 8) % 8) + '0'; + const char lead_char = (idx / 64) + '0'; + absl::string_view lead_piece(&lead_char, 1); + absl::string_view mid_piece(&mid_char, 1); + absl::string_view end_piece(&end_char, 1); + const std::string test_string = + absl::StrCat(lead_piece, mid_piece, end_piece); + TestParseString(absl::StrCat("'\\", test_string, "'")); + TestParseString(absl::StrCat("'''\\", test_string, "'''")); + TestParseBytes(absl::StrCat("b'\\", test_string, "'")); + } + TestInvalidString("'\\'", "String must end with '"); + TestInvalidString("'abc\\'", "String must end with '"); + TestInvalidString("'''\\'''", "String must end with '''"); + TestInvalidString("'''abc\\'''", "String must end with '''"); + TestInvalidString( + "'\\0'", "Octal escape must be followed by 3 octal digits but saw: \\0"); + TestInvalidString( + "'''abc\\0'''", + "Octal escape must be followed by 3 octal digits but saw: \\0"); + TestInvalidString( + "'\\00'", + "Octal escape must be followed by 3 octal digits but saw: \\00"); + TestInvalidString( + "'''ab\\00'''", + "Octal escape must be followed by 3 octal digits but saw: \\00"); + TestInvalidString( + "'a\\008'", + "Octal escape must be followed by 3 octal digits but saw: \\008"); + TestInvalidString( + "'''\\008'''", + "Octal escape must be followed by 3 octal digits but saw: \\008"); + TestInvalidString("'\\400'", "Illegal escape sequence: \\4"); + TestInvalidString("'''\\400'''", "Illegal escape sequence: \\4"); + TestInvalidString("'\\777'", "Illegal escape sequence: \\7"); + TestInvalidString("'''\\777'''", "Illegal escape sequence: \\7"); +} + +TEST(StringsTest, TestParsingOfHexEscapes) { + for (int idx = 0; idx < 256; ++idx) { + char lead_char = absl::StrFormat("%X", idx / 16)[0]; + char end_char = absl::StrFormat("%x", idx % 16)[0]; + absl::string_view lead_piece(&lead_char, 1); + absl::string_view end_piece(&end_char, 1); + TestParseString(absl::StrCat("'\\x", lead_piece, end_piece, "'")); + TestParseString(absl::StrCat("'''\\x", lead_piece, end_piece, "'''")); + TestParseString(absl::StrCat("'\\X", lead_piece, end_piece, "'")); + TestParseString(absl::StrCat("'''\\X", lead_piece, end_piece, "'''")); + TestParseBytes(absl::StrCat("b'\\X", lead_piece, end_piece, "'")); + } + TestInvalidString("'\\x'"); + TestInvalidString("'''\\x'''"); + TestInvalidString("'\\x0'"); + TestInvalidString("'''\\x0'''"); + TestInvalidString("'\\x0G'"); + TestInvalidString("'''\\x0G'''"); +} + +TEST(StringsTest, RoundTrip) { + // Empty string is valid as a string but not an identifier. + TestStringEscaping(""); + TestString(""); + + TestValue("abc"); + TestValue("abc123"); + TestValue("123abc"); + TestValue("_abc123"); + TestValue("_123"); + TestValue("abc def"); + TestValue("a`b"); + TestValue("a77b"); + TestValue("\"abc\""); + TestValue("'abc'"); + TestValue("`abc`"); + TestValue("aaa'bbb\"ccc`ddd"); + TestValue("\n"); + TestValue("\\"); + TestValue("\\n"); + TestValue("\x12"); + TestValue("a,g 8q483 *(YG(*$(&*98fg\\r\\n\\t\x12gb"); + + // Value with an embedded zero char. + std::string s = "abc"; + s[1] = 0; + TestValue(s); + + // Reserved SQL keyword, which must be quoted as an identifier. + TestValue("select"); + TestValue("SELECT"); + TestValue("SElecT"); + // Non-reserved SQL keyword, which shouldn't be quoted. + TestValue("options"); + + // Note that control characters and other odd byte values such as \0 are + // allowed in string literals as long as they are utf8 structurally valid. + TestValue("\x01\x31"); + TestValue("abc\xb\x42\141bc"); + TestValue("123\1\x31\x32\x33"); + TestValue("\\\"\xe8\xb0\xb7\xe6\xad\x8c\\\" is Google\\\'s Chinese name"); +} + +TEST(StringsTest, InvalidString) { + const std::string kInvalidStringLiteral = "Invalid string literal"; + + TestInvalidString("A", kInvalidStringLiteral); // No quote at all + TestInvalidString("'", kInvalidStringLiteral); // No closing quote + TestInvalidString("\"", kInvalidStringLiteral); // No closing quote + TestInvalidString("a'", kInvalidStringLiteral); // No opening quote + TestInvalidString("a\"", kInvalidStringLiteral); // No opening quote + TestInvalidString("'''", "String cannot contain unescaped '"); + TestInvalidString("\"\"\"", "String cannot contain unescaped \""); + TestInvalidString("''''", "String cannot contain unescaped '"); + TestInvalidString("\"\"\"\"", "String cannot contain unescaped \""); + TestInvalidString("'''''", "String cannot contain unescaped '"); + TestInvalidString("\"\"\"\"\"", "String cannot contain unescaped \""); + TestInvalidString("'''''''", "String cannot contain unescaped '''"); + TestInvalidString("\"\"\"\"\"\"\"", "String cannot contain unescaped \"\"\""); + TestInvalidString("'''''''''", "String cannot contain unescaped '''"); + TestInvalidString("\"\"\"\"\"\"\"\"\"", + "String cannot contain unescaped \"\"\""); + + TestInvalidString("abc"); + + TestInvalidString("'abc'def'", "String cannot contain unescaped '"); + TestInvalidString("'abc''def'", "String cannot contain unescaped '"); + TestInvalidString("\"abc\"\"def\"", "String cannot contain unescaped \""); + TestInvalidString("'''abc'''def'''", "String cannot contain unescaped '''"); + TestInvalidString("\"\"\"abc\"\"\"def\"\"\"", + "String cannot contain unescaped \"\"\""); + + TestInvalidString("'abc"); + TestInvalidString("\"abc"); + TestInvalidString("'''abc"); + TestInvalidString("\"\"\"abc"); + + TestInvalidString("abc'"); + TestInvalidString("abc\""); + TestInvalidString("abc'''"); + TestInvalidString("abc\"\"\""); + + TestInvalidString("\"abc'"); + TestInvalidString("'abc\""); + TestInvalidString("'''abc'", "String cannot contain unescaped '"); + TestInvalidString("'''abc\""); + + TestInvalidString("'''a'", "String cannot contain unescaped '"); + TestInvalidString("\"\"\"a\"", "String cannot contain unescaped \""); + TestInvalidString("'''a''", "String cannot contain unescaped '"); + TestInvalidString("\"\"\"a\"\"", "String cannot contain unescaped \""); + TestInvalidString("'''a''''", "String cannot contain unescaped '''"); + TestInvalidString("\"\"\"a\"\"\"\"", + "String cannot contain unescaped \"\"\""); + + TestInvalidString("'''abc\"\"\""); + TestInvalidString("\"\"\"abc'"); + TestInvalidString("\"\"\"abc\"", "String cannot contain unescaped \""); + TestInvalidString("\"\"\"abc'''"); + TestInvalidString("'''\\\''''''", "String cannot contain unescaped '''"); + TestInvalidString("\"\"\"\\\"\"\"\"\"\"", + "String cannot contain unescaped \"\"\""); + TestInvalidString("''''\\\'''''", "String cannot contain unescaped '''"); + TestInvalidString("\"\"\"\"\\\"\"\"\"\"", + "String cannot contain unescaped \"\"\""); + TestInvalidString("\"\"\"'a' \"b\"\"\"\"", + "String cannot contain unescaped \"\"\""); + + TestInvalidString("`abc`"); + + TestInvalidString("'abc\\'", "String must end with '"); + TestInvalidString("\"abc\\\"", "String must end with \""); + TestInvalidString("'''abc\\'''", "String must end with '''"); + TestInvalidString("\"\"\"abc\\\"\"\"", "String must end with \"\"\""); + + TestInvalidString("'\\U12345678'", + "Value of \\U12345678 exceeds Unicode limit (0x0010FFFF)"); + + // All trailing escapes. + TestInvalidString("'\\"); + TestInvalidString("\"\\"); + TestInvalidString("''''''\\"); + TestInvalidString("\"\"\"\"\"\"\\"); + TestInvalidString("''\\\\"); + TestInvalidString("\"\"\\\\"); + TestInvalidString("''''''\\\\"); + TestInvalidString("\"\"\"\"\"\"\\\\"); + + // String with an unescaped 0 byte. + std::string s = "abc"; + s[1] = 0; + TestInvalidString(s); + // Note: These are C-escapes to define the invalid strings. + TestInvalidString("'\xc1'", "Structurally invalid UTF8 string"); + TestInvalidString("'\xca'", "Structurally invalid UTF8 string"); + TestInvalidString("'\xcc'", "Structurally invalid UTF8 string"); + TestInvalidString("'\xFA'", "Structurally invalid UTF8 string"); + TestInvalidString("'\xc1\xca\x1b\x62\x19o\xcc\x04'", + "Structurally invalid UTF8 string"); + + TestInvalidString("'\xc2\xc0'", + "Structurally invalid UTF8 string"); // First byte ok utf8, + // invalid together. + TestValue("\xc2\xbf"); // Same first byte, good sequence. + + // These are all valid prefixes for utf8 characters, but the characters + // are not complete. + TestInvalidString( + "'\xc2'", + "Structurally invalid UTF8 string"); // Should be 2 byte utf8 character. + TestInvalidString( + "'\xc3'", + "Structurally invalid UTF8 string"); // Should be 2 byte utf8 character. + TestInvalidString( + "'\xe0'", + "Structurally invalid UTF8 string"); // Should be 3 byte utf8 character. + TestInvalidString( + "'\xe0\xac'", + "Structurally invalid UTF8 string"); // Should be 3 byte utf8 character. + TestInvalidString( + "'\xf0'", + "Structurally invalid UTF8 string"); // Should be 4 byte utf8 character. + TestInvalidString( + "'\xf0\x90'", + "Structurally invalid UTF8 string"); // Should be 4 byte utf8 character. + TestInvalidString( + "'\xf0\x90\x80'", + "Structurally invalid UTF8 string"); // Should be 4 byte utf8 character. +} + +TEST(BytesTest, RoundTrip) { + TestBytesLiteral("b\"\""); + TestBytesLiteral("b\"\"\"\"\"\""); + TestUnescapedBytes(""); + + TestBytesLiteral("b'\\000\\x00AAA\\xfF\\377'"); + TestBytesLiteral("b'''\\000\\x00AAA\\xfF\\377'''"); + TestBytesLiteral("b'\\a\\b\\f\\n\\r\\t\\v\\\\\\?\\\"\\'\\`\\x00\\Xff'"); + TestBytesLiteral("b'''\\a\\b\\f\\n\\r\\t\\v\\\\\\?\\\"\\'\\`\\x00\\Xff'''"); + + TestBytesLiteral("b'\\n\\012\\x0A'"); // Different newline representations. + TestBytesLiteral("b'''\\n\\012\\x0A'''"); + + // Note the C-escaping to define the bytes. These are invalid strings for + // various reasons, but are valid as bytes. + TestUnescapedBytes("\xc1"); + TestUnescapedBytes("\xca"); + TestUnescapedBytes("\xcc"); + TestUnescapedBytes("\xFA"); + TestUnescapedBytes("\xc1\xca\x1b\x62\x19o\xcc\x04"); +} + +TEST(BytesTest, ToBytesLiteralTests) { + // ToBytesLiteral will choose to quote with ' if it will avoid escaping. + // Non-printable bytes are escaped as hex. For printable bytes, only the + // escape character and quote character are escaped. + EXPECT_EQ("b\"\"", FormatBytesLiteral("")); + EXPECT_EQ("b\"abc\"", FormatBytesLiteral("abc")); + EXPECT_EQ("b\"abc'def\"", FormatBytesLiteral("abc'def")); + EXPECT_EQ("b'abc\"def'", FormatBytesLiteral("abc\"def")); + EXPECT_EQ("b\"abc`def\"", FormatBytesLiteral("abc`def")); + EXPECT_EQ("b\"abc'\\\"`def\"", FormatBytesLiteral("abc'\"`def")); + + // Override the quoting style to use single quotes. + EXPECT_EQ("b''", FormatSingleQuotedBytesLiteral("")); + EXPECT_EQ("b'abc'", FormatSingleQuotedBytesLiteral("abc")); + EXPECT_EQ("b'abc\\'def'", FormatSingleQuotedBytesLiteral("abc'def")); + EXPECT_EQ("b'abc\"def'", FormatSingleQuotedBytesLiteral("abc\"def")); + EXPECT_EQ("b'abc`def'", FormatSingleQuotedBytesLiteral("abc`def")); + EXPECT_EQ("b'abc\\'\"`def'", FormatSingleQuotedBytesLiteral("abc'\"`def")); + + // Override the quoting style to use double quotes. + EXPECT_EQ("b\"\"", FormatDoubleQuotedBytesLiteral("")); + EXPECT_EQ("b\"abc\"", FormatDoubleQuotedBytesLiteral("abc")); + EXPECT_EQ("b\"abc'def\"", FormatDoubleQuotedBytesLiteral("abc'def")); + EXPECT_EQ("b\"abc\\\"def\"", FormatDoubleQuotedBytesLiteral("abc\"def")); + EXPECT_EQ("b\"abc`def\"", FormatDoubleQuotedBytesLiteral("abc`def")); + EXPECT_EQ("b\"abc'\\\"`def\"", FormatDoubleQuotedBytesLiteral("abc'\"`def")); + + EXPECT_EQ("b\"\\x07-\\x08-\\x0c-\\x0a-\\x0d-\\x09-\\x0b-\\\\-?-\\\"-'-`\"", + FormatBytesLiteral("\a-\b-\f-\n-\r-\t-\v-\\-?-\"-'-`")); + + EXPECT_EQ("b\"\\x0a\"", FormatBytesLiteral("\n")); + + ASSERT_OK_AND_ASSIGN(auto unquoted, + ParseBytesLiteral("b'\\n\\012\\x0a\\x0A'")); + EXPECT_EQ("b\"\\x0a\\x0a\\x0a\\x0a\"", FormatBytesLiteral(unquoted)); +} + +TEST(ByesTest, InvalidBytes) { + TestInvalidBytes("A", "Invalid bytes literal"); // No quotes + TestInvalidBytes("b'A", "Invalid bytes literal"); // No ending quote + TestInvalidBytes("'A'", "Invalid bytes literal"); // No ending quote + TestInvalidBytes("'A'", "Invalid bytes literal"); // No 'b' prefix. + TestInvalidBytes("'''A'''"); + TestInvalidBytes("b'k\\u0030'", kUnicodeNotAllowedInBytes1); + TestInvalidBytes("b'''\\u0030'''", kUnicodeNotAllowedInBytes1); + TestInvalidBytes("b'\\U00000030'", kUnicodeNotAllowedInBytes2); + TestInvalidBytes("b'''qwerty\\U00000030'''", kUnicodeNotAllowedInBytes2); + EXPECT_FALSE(UnescapeBytes("abc\\u0030").ok()); + EXPECT_FALSE(UnescapeBytes("abc\\U00000030").ok()); + EXPECT_FALSE(UnescapeBytes("abc\\U00000030").ok()); +} + +TEST(RawStringsTest, ValidCases) { + TestRawString(""); + TestRawString("1"); + TestRawString("\\x53"); + TestRawString("\\x123"); + TestRawString("\\001"); + TestRawString("a\\44'A"); + TestRawString("a\\e"); + TestRawString("\\ea"); + TestRawString("\\U1234"); + TestRawString("\\u"); + TestRawString("\\xc2\\\\"); + TestRawString("f\\(abc',(.*),def\\?"); + TestRawString("a\\\"b"); +} + +TEST(RawStringsTest, InvalidRawStrings) { + TestInvalidString("r\"\\\"", "String must end with \""); + TestInvalidString("r\"\\\\\\\"", "String must end with \""); + TestInvalidString("r\""); + TestInvalidString("r"); + TestInvalidString("rb\"\""); + TestInvalidString("b\"\""); + TestInvalidString("r'''", "String cannot contain unescaped '"); +} + +TEST(RawBytesTest, ValidCases) { + TestRawBytes(""); + TestRawBytes("1"); + TestRawBytes("\\x53"); + TestRawBytes("\\x123"); + TestRawBytes("\\001"); + TestRawBytes("a\\44'A"); + TestRawBytes("a\\e"); + TestRawBytes("\\ea"); + TestRawBytes("\\U1234"); + TestRawBytes("\\u"); + TestRawBytes("\\xc2\\\\"); + TestRawBytes("f\\(abc',(.*),def\\?"); +} + +TEST(RawBytesTest, InvalidRawBytes) { + TestInvalidBytes("r''"); + TestInvalidBytes("r''''''"); + TestInvalidBytes("rrb''"); + TestInvalidBytes("brb''"); + TestInvalidBytes("rb'a\\e"); + TestInvalidBytes("rb\"\\\"", "String must end with \""); + TestInvalidBytes("br\"\\\\\\\"", "String must end with \""); + TestInvalidBytes("rb"); + TestInvalidBytes("br"); + TestInvalidBytes("rb\""); + TestInvalidBytes("rb\"\"\"", "String cannot contain unescaped \""); + TestInvalidBytes("rb\"xyz\"\"", "String cannot contain unescaped \""); +} + +TEST(StringsTest, QuotedForms) { + // EscapeString escapes all quote characters. + EXPECT_EQ("", EscapeString("")); + EXPECT_EQ("abc", EscapeString("abc")); + EXPECT_EQ("abc\\'def", EscapeString("abc'def")); + EXPECT_EQ("abc\\\"def", EscapeString("abc\"def")); + EXPECT_EQ("abc\\`def", EscapeString("abc`def")); + + // ToStringLiteral will choose to quote with ' if it will avoid escaping. + // Other quoted characters will not be escaped. + EXPECT_EQ("\"\"", FormatStringLiteral("")); + EXPECT_EQ("\"abc\"", FormatStringLiteral("abc")); + EXPECT_EQ("\"abc'def\"", FormatStringLiteral("abc'def")); + EXPECT_EQ("'abc\"def'", FormatStringLiteral("abc\"def")); + EXPECT_EQ("\"abc`def\"", FormatStringLiteral("abc`def")); + EXPECT_EQ("\"abc'\\\"`def\"", FormatStringLiteral("abc'\"`def")); + + // Override the quoting style to use single quotes. + EXPECT_EQ("''", FormatSingleQuotedStringLiteral("")); + EXPECT_EQ("'abc'", FormatSingleQuotedStringLiteral("abc")); + EXPECT_EQ("'abc\\'def'", FormatSingleQuotedStringLiteral("abc'def")); + EXPECT_EQ("'abc\"def'", FormatSingleQuotedStringLiteral("abc\"def")); + EXPECT_EQ("'abc`def'", FormatSingleQuotedStringLiteral("abc`def")); + EXPECT_EQ("'abc\\'\"`def'", FormatSingleQuotedStringLiteral("abc'\"`def")); + + // Override the quoting style to use double quotes. + EXPECT_EQ("\"\"", FormatDoubleQuotedStringLiteral("")); + EXPECT_EQ("\"abc\"", FormatDoubleQuotedStringLiteral("abc")); + EXPECT_EQ("\"abc'def\"", FormatDoubleQuotedStringLiteral("abc'def")); + EXPECT_EQ("\"abc\\\"def\"", FormatDoubleQuotedStringLiteral("abc\"def")); + EXPECT_EQ("\"abc`def\"", FormatDoubleQuotedStringLiteral("abc`def")); + EXPECT_EQ("\"abc'\\\"`def\"", FormatDoubleQuotedStringLiteral("abc'\"`def")); +} + +void ExpectParsedString(const std::string& expected, + const std::vector& quoted_strings) { + for (const std::string& quoted : quoted_strings) { + ASSERT_OK_AND_ASSIGN(auto parsed, ParseStringLiteral(quoted)); + EXPECT_EQ(expected, parsed); + } +} + +void ExpectParsedBytes(const std::string& expected, + const std::vector& quoted_strings) { + for (const std::string& quoted : quoted_strings) { + ASSERT_OK_AND_ASSIGN(auto parsed, ParseBytesLiteral(quoted)); + EXPECT_EQ(expected, parsed); + } +} + +TEST(StringsTest, Parse) { + ExpectParsedString("abc", + {"'abc'", "\"abc\"", "'''abc'''", "\"\"\"abc\"\"\""}); + ExpectParsedString( + "abc\ndef\x12ghi", + {"'abc\\ndef\\x12ghi'", "\"abc\\ndef\\x12ghi\"", + "'''abc\\ndef\\x12ghi'''", "\"\"\"abc\\ndef\\x12ghi\"\"\""}); + ExpectParsedString("\xF4\x8F\xBF\xBD", + {"'\\U0010FFFD'", "\"\\U0010FFFD\"", "'''\\U0010FFFD'''", + "\"\"\"\\U0010FFFD\"\"\""}); + + // Some more test cases for triple quoted content. + ExpectParsedString("", {"''''''", "\"\"\"\"\"\""}); + ExpectParsedString("'\"", {"''''\"'''"}); + ExpectParsedString("''''''", {"'''''\\'''\\''''"}); + ExpectParsedString("'", {"'''\\''''"}); + ExpectParsedString("''", {"'''\\'\\''''"}); + ExpectParsedString("'\"", {"''''\"'''"}); + ExpectParsedString("'a", {"''''a'''"}); + ExpectParsedString("\"a", {"\"\"\"\"a\"\"\""}); + ExpectParsedString("''a", {"'''''a'''"}); + ExpectParsedString("\"\"a", {"\"\"\"\"\"a\"\"\""}); +} + +TEST(StringsTest, TestNewlines) { + ExpectParsedString("a\nb", {"'''a\rb'''", "'''a\nb'''", "'''a\r\nb'''"}); + ExpectParsedString("a\n\nb", {"'''a\n\rb'''", "'''a\r\n\r\nb'''"}); + // Escaped newlines. + ExpectParsedString("a\nb", {"'''a\\nb'''"}); + ExpectParsedString("a\rb", {"'''a\\rb'''"}); + ExpectParsedString("a\r\nb", {"'''a\\r\\nb'''"}); +} + +TEST(RawStringsTest, CompareRawAndRegularStringParsing) { + ExpectParsedString("\\n", + {"r'\\n'", "r\"\\n\"", "r'''\\n'''", "r\"\"\"\\n\"\"\""}); + ExpectParsedString("\n", + {"'\\n'", "\"\\n\"", "'''\\n'''", "\"\"\"\\n\"\"\""}); + + ExpectParsedString("\\e", + {"r'\\e'", "r\"\\e\"", "r'''\\e'''", "r\"\"\"\\e\"\"\""}); + TestInvalidString("'\\e'", "Illegal escape sequence: \\e"); + TestInvalidString("\"\\e\"", "Illegal escape sequence: \\e"); + TestInvalidString("'''\\e'''", "Illegal escape sequence: \\e"); + TestInvalidString("\"\"\"\\e\"\"\"", "Illegal escape sequence: \\e"); + + ExpectParsedString( + "\\x0", {"r'\\x0'", "r\"\\x0\"", "r'''\\x0'''", "r\"\"\"\\x0\"\"\""}); + constexpr char kHexError[] = + "Hex escape must be followed by 2 hex digits but saw: \\x0"; + TestInvalidString("'\\x0'", kHexError); + TestInvalidString("\"\\x0\"", kHexError); + TestInvalidString("'''\\x0'''", kHexError); + TestInvalidString("\"\"\"\\x0\"\"\"", kHexError); + + ExpectParsedString("\\'", {"r'\\\''"}); + ExpectParsedString("'", {"'\\\''"}); + ExpectParsedString("\\\"", {"r\"\\\"\""}); + ExpectParsedString("\"", {"\"\\\"\""}); + ExpectParsedString("''\\'", {"r'''\'\'\\\''''"}); + ExpectParsedString("'''", {"'''\'\'\\\''''"}); + ExpectParsedString("\"\"\\\"", {"r\"\"\"\"\"\\\"\"\"\""}); + ExpectParsedString("\"\"\"", {"\"\"\"\"\"\\\"\"\"\""}); +} + +TEST(RawBytesTest, CompareRawAndRegularBytesParsing) { + ExpectParsedBytes("\\n", {"rb'\\n'", "br'\\n'", "rb\"\\n\"", "br\"\\n\""}); + ExpectParsedBytes("\n", {"b'\\n'", "b\"\\n\""}); + + ExpectParsedBytes("\\u0030", {"rb'\\u0030'", "br'\\u0030'", "rb\"\\u0030\"", + "br\"\\u0030\""}); + TestInvalidBytes("b'\\u0030'", kUnicodeNotAllowedInBytes1); + TestInvalidBytes("b\"\\u0030\"", kUnicodeNotAllowedInBytes1); + TestInvalidBytes("b\"abc\\u0030\"", kUnicodeNotAllowedInBytes1); + + ExpectParsedBytes("\\U00000030", {"rb'\\U00000030'", "br'\\U00000030'", + "rb\"\\U00000030\"", "br\"\\U00000030\""}); + TestInvalidBytes("b'\\U00000030'", kUnicodeNotAllowedInBytes2); + TestInvalidBytes("b\"\\U00000030\"", kUnicodeNotAllowedInBytes2); + TestInvalidBytes("b\"abc\\U00000030\"", kUnicodeNotAllowedInBytes2); + + ExpectParsedBytes("\\e", {"rb'\\e'", "br'\\e'", "rb\"\\e\"", "br\"\\e\""}); + TestInvalidBytes("b'\\e'", "Illegal escape sequence: \\e"); + TestInvalidBytes("b\"\\e\"", "Illegal escape sequence: \\e"); + TestInvalidBytes("b\"abcd\\e\"", "Illegal escape sequence: \\e"); + + ExpectParsedBytes("\\'", {"rb'\\\''", "br'\\\''"}); + ExpectParsedBytes("'", {"b'\\\''"}); + ExpectParsedBytes("\\\"", {"rb\"\\\"\"", "br\"\\\"\""}); + ExpectParsedBytes("\"", {"b\"\\\"\""}); + ExpectParsedBytes("''\\'", {"rb'''\'\'\\\''''", "br'''\'\'\\\''''"}); + ExpectParsedBytes("'''", {"b'''\'\'\\\''''"}); + ExpectParsedBytes("\"\"\\\"", + {"rb\"\"\"\"\"\\\"\"\"\"", "br\"\"\"\"\"\\\"\"\"\""}); + ExpectParsedBytes("\"\"\"", {"b\"\"\"\"\"\\\"\"\"\""}); +} + +struct epair { + std::string escaped; + std::string unescaped; +}; + +// Copied from strings/escaping_test.cc, CEscape::BasicEscaping. +TEST(StringsTest, UTF8Escape) { + epair utf8_hex_values[] = { + {"\x20\xe4\xbd\xa0\\t\xe5\xa5\xbd,\\r!\\n", + "\x20\xe4\xbd\xa0\t\xe5\xa5\xbd,\r!\n"}, + {"\xe8\xa9\xa6\xe9\xa8\x93\\\' means \\\"test\\\"", + "\xe8\xa9\xa6\xe9\xa8\x93\' means \"test\""}, + {"\\\\\xe6\x88\x91\\\\:\\\\\xe6\x9d\xa8\xe6\xac\xa2\\\\", + "\\\xe6\x88\x91\\:\\\xe6\x9d\xa8\xe6\xac\xa2\\"}, + {"\xed\x81\xac\xeb\xa1\xac\\x08\\t\\n\\x0b\\x0c\\r", + "\xed\x81\xac\xeb\xa1\xac\010\011\012\013\014\015"}}; + + for (int i = 0; i < ABSL_ARRAYSIZE(utf8_hex_values); ++i) { + std::string escaped = EscapeString(utf8_hex_values[i].unescaped); + EXPECT_EQ(escaped, utf8_hex_values[i].escaped); + } +} + +// Originally copied from strings/escaping_test.cc, Unescape::BasicFunction, +// but changes for '\\xABCD' which only parses 2 hex digits after the escape. +TEST(StringsTest, UTF8Unescape) { + epair tests[] = {{"\\u0030", "0"}, + {"\\u00A3", "\xC2\xA3"}, + {"\\u22FD", "\xE2\x8B\xBD"}, + {"\\ud7FF", "\xED\x9F\xBF"}, + {"\\u22FD", "\xE2\x8B\xBD"}, + {"\\U00010000", "\xF0\x90\x80\x80"}, + {"\\U0000E000", "\xEE\x80\x80"}, + {"\\U0001DFFF", "\xF0\x9D\xBF\xBF"}, + {"\\U0010FFFD", "\xF4\x8F\xBF\xBD"}, + {"\\xAbCD", + "\xc2\xab" + "CD"}, + {"\\253CD", + "\xc2\xab" + "CD"}, + {"\\x4141", "A41"}}; + for (int i = 0; i < ABSL_ARRAYSIZE(tests); ++i) { + const std::string& e = tests[i].escaped; + const std::string& u = tests[i].unescaped; + ASSERT_OK_AND_ASSIGN(auto out, UnescapeString(e)); + EXPECT_EQ(u, out) << "original escaped: '" << e << "'\nunescaped: '" << out + << "'\nexpected unescaped: '" << u << "'"; + } + std::string bad[] = {"\\u1", // too short + "\\U1", // too short + "\\Uffffff", "\\777"}; // exceeds 0xff + for (int i = 0; i < ABSL_ARRAYSIZE(bad); ++i) { + const std::string& e = bad[i]; + auto status_or_string = UnescapeString(e); + EXPECT_THAT(status_or_string, StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_TRUE(absl::StrContains(status_or_string.status().message(), + "Invalid escaped string")); + } +} + +TEST(StringsTest, TestUnescapeErrorMessages) { + std::string error_string; + std::string out; + + auto status_or_string = UnescapeString("\\2"); + EXPECT_THAT(status_or_string, StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_TRUE(absl::StrContains( + status_or_string.status().message(), + "Illegal escape sequence: Octal escape must be followed by 3 octal " + "digits but saw: \\2")); + + status_or_string = UnescapeString("\\22X0"); + EXPECT_THAT(status_or_string, StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_TRUE(absl::StrContains( + status_or_string.status().message(), + "Illegal escape sequence: Octal escape must be followed by 3 octal " + "digits but saw: \\22X")); + + status_or_string = UnescapeString("\\X0"); + EXPECT_THAT(status_or_string, StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_TRUE(absl::StrContains( + status_or_string.status().message(), + "Illegal escape sequence: Hex escape must be followed by 2 hex digits " + "but saw: \\X0")); + + status_or_string = UnescapeString("\\x0G0"); + EXPECT_THAT(status_or_string, StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_TRUE(absl::StrContains( + status_or_string.status().message(), + "Illegal escape sequence: Hex escape must be followed by 2 hex digits " + "but saw: \\x0G")); + + status_or_string = UnescapeString("\\u00"); + EXPECT_THAT(status_or_string, StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_TRUE(absl::StrContains( + status_or_string.status().message(), + "Illegal escape sequence: \\u must be followed by 4 hex digits but saw: " + "\\u00")); + + status_or_string = UnescapeString("\\ude8c"); + EXPECT_THAT(status_or_string, StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_TRUE(absl::StrContains( + status_or_string.status().message(), + "Illegal escape sequence: Unicode value \\ude8c is invalid")); + + status_or_string = UnescapeString("\\u000G0"); + EXPECT_THAT(status_or_string, StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_TRUE(absl::StrContains( + status_or_string.status().message(), + "Illegal escape sequence: \\u must be followed by 4 hex digits but saw: " + "\\u000G")); + + status_or_string = UnescapeString("\\U00"); + EXPECT_THAT(status_or_string, StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_TRUE(absl::StrContains( + status_or_string.status().message(), + "Illegal escape sequence: \\U must be followed by 8 hex digits but saw: " + "\\U00")); + + status_or_string = UnescapeString("\\U000000G00"); + EXPECT_THAT(status_or_string, StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_TRUE(absl::StrContains( + status_or_string.status().message(), + "Illegal escape sequence: \\U must be followed by 8 hex digits but saw: " + "\\U000000G0")); + + status_or_string = UnescapeString("\\U0000D83D"); + EXPECT_THAT(status_or_string, StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_TRUE(absl::StrContains( + status_or_string.status().message(), + "Illegal escape sequence: Unicode value \\U0000D83D is invalid")); + + status_or_string = UnescapeString("\\UFFFFFFFF0"); + EXPECT_THAT(status_or_string, StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_TRUE(absl::StrContains( + status_or_string.status().message(), + "Illegal escape sequence: Value of \\UFFFFFFFF exceeds Unicode limit " + "(0x0010FFFF)")); +} + +} // namespace +} // namespace cel::internal diff --git a/internal/testing.cc b/internal/testing.cc new file mode 100644 index 000000000..84aa58cce --- /dev/null +++ b/internal/testing.cc @@ -0,0 +1,31 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "internal/testing.h" + +#include "absl/strings/str_cat.h" // IWYU pragma: keep + +namespace cel::internal { + +void AddFatalFailure(const char* file, int line, absl::string_view expression, + const StatusBuilder& builder) { + GTEST_MESSAGE_AT_(file, line, + absl::StrCat(expression, " returned error: ", + absl::Status(builder).ToString( + absl::StatusToStringMode::kWithEverything)) + .c_str(), + ::testing::TestPartResult::kFatalFailure); +} + +} // namespace cel::internal diff --git a/internal/testing.h b/internal/testing.h new file mode 100644 index 000000000..e1b9f7498 --- /dev/null +++ b/internal/testing.h @@ -0,0 +1,44 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_TESTING_H_ +#define THIRD_PARTY_CEL_CPP_INTERNAL_TESTING_H_ + +#include "gmock/gmock.h" // IWYU pragma: export +#include "gtest/gtest.h" // IWYU pragma: export +#include "absl/status/status_matchers.h" +#include "internal/status_macros.h" // IWYU pragma: keep + +#ifndef ASSERT_OK +#define ASSERT_OK(expr) ASSERT_THAT(expr, ::absl_testing::IsOk()) +#endif + +#ifndef EXPECT_OK +#define EXPECT_OK(expr) EXPECT_THAT(expr, ::absl_testing::IsOk()) +#endif + +#ifndef ASSERT_OK_AND_ASSIGN +#define ASSERT_OK_AND_ASSIGN(lhs, rhs) \ + CEL_ASSIGN_OR_RETURN( \ + lhs, rhs, ::cel::internal::AddFatalFailure(__FILE__, __LINE__, #rhs, _)) +#endif + +namespace cel::internal { + +void AddFatalFailure(const char* file, int line, absl::string_view expression, + const StatusBuilder& builder); + +} // namespace cel::internal + +#endif // THIRD_PARTY_CEL_CPP_INTERNAL_TESTING_H_ diff --git a/internal/testing_descriptor_pool.cc b/internal/testing_descriptor_pool.cc new file mode 100644 index 000000000..eaa89eb5e --- /dev/null +++ b/internal/testing_descriptor_pool.cc @@ -0,0 +1,62 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "internal/testing_descriptor_pool.h" + +#include +#include + +#include "google/protobuf/descriptor.pb.h" +#include "absl/base/attributes.h" +#include "absl/base/macros.h" +#include "absl/base/no_destructor.h" +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "internal/noop_delete.h" +#include "google/protobuf/descriptor.h" + +namespace cel::internal { + +namespace { + +ABSL_CONST_INIT const uint8_t kTestingDescriptorSet[] = { +#include "internal/testing_descriptor_set_embed.inc" +}; + +} // namespace + +const google::protobuf::DescriptorPool* absl_nonnull GetTestingDescriptorPool() { + static const google::protobuf::DescriptorPool* absl_nonnull const pool = []() { + google::protobuf::FileDescriptorSet file_desc_set; + ABSL_CHECK(file_desc_set.ParseFromArray( // Crash OK + kTestingDescriptorSet, ABSL_ARRAYSIZE(kTestingDescriptorSet))); + auto* pool = new google::protobuf::DescriptorPool(); + for (const auto& file_desc : file_desc_set.file()) { + ABSL_CHECK(pool->BuildFile(file_desc) != nullptr); // Crash OK + } + return pool; + }(); + return pool; +} + +absl_nonnull std::shared_ptr +GetSharedTestingDescriptorPool() { + static const absl::NoDestructor< + absl_nonnull std::shared_ptr> + instance(GetTestingDescriptorPool(), + internal::NoopDeleteFor()); + return *instance; +} + +} // namespace cel::internal diff --git a/internal/testing_descriptor_pool.h b/internal/testing_descriptor_pool.h new file mode 100644 index 000000000..0f8f63fcc --- /dev/null +++ b/internal/testing_descriptor_pool.h @@ -0,0 +1,35 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_TESTING_DESCRIPTOR_POOL_H_ +#define THIRD_PARTY_CEL_CPP_INTERNAL_TESTING_DESCRIPTOR_POOL_H_ + +#include + +#include "absl/base/nullability.h" +#include "google/protobuf/descriptor.h" + +namespace cel::internal { + +// GetTestingDescriptorPool returns a pointer to a `google::protobuf::DescriptorPool` +// which includes has the necessary descriptors required for the purposes of +// testing. The returning `google::protobuf::DescriptorPool` is valid for the lifetime of +// the process. +const google::protobuf::DescriptorPool* absl_nonnull GetTestingDescriptorPool(); +absl_nonnull std::shared_ptr +GetSharedTestingDescriptorPool(); + +} // namespace cel::internal + +#endif // THIRD_PARTY_CEL_CPP_INTERNAL_TESTING_DESCRIPTOR_POOL_H_ diff --git a/internal/testing_descriptor_pool_test.cc b/internal/testing_descriptor_pool_test.cc new file mode 100644 index 000000000..093ce8beb --- /dev/null +++ b/internal/testing_descriptor_pool_test.cc @@ -0,0 +1,175 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "internal/testing_descriptor_pool.h" + +#include "internal/testing.h" +#include "google/protobuf/descriptor.h" + +namespace cel::internal { +namespace { + +using ::testing::NotNull; + +TEST(TestingDescriptorPool, NullValue) { + ASSERT_THAT(GetTestingDescriptorPool()->FindEnumTypeByName( + "google.protobuf.NullValue"), + NotNull()); +} + +TEST(TestingDescriptorPool, BoolValue) { + const auto* desc = GetTestingDescriptorPool()->FindMessageTypeByName( + "google.protobuf.BoolValue"); + ASSERT_THAT(desc, NotNull()); + EXPECT_EQ(desc->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_BOOLVALUE); +} + +TEST(TestingDescriptorPool, Int32Value) { + const auto* desc = GetTestingDescriptorPool()->FindMessageTypeByName( + "google.protobuf.Int32Value"); + ASSERT_THAT(desc, NotNull()); + EXPECT_EQ(desc->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_INT32VALUE); +} + +TEST(TestingDescriptorPool, Int64Value) { + const auto* desc = GetTestingDescriptorPool()->FindMessageTypeByName( + "google.protobuf.Int64Value"); + ASSERT_THAT(desc, NotNull()); + EXPECT_EQ(desc->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_INT64VALUE); +} + +TEST(TestingDescriptorPool, UInt32Value) { + const auto* desc = GetTestingDescriptorPool()->FindMessageTypeByName( + "google.protobuf.UInt32Value"); + ASSERT_THAT(desc, NotNull()); + EXPECT_EQ(desc->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_UINT32VALUE); +} + +TEST(TestingDescriptorPool, UInt64Value) { + const auto* desc = GetTestingDescriptorPool()->FindMessageTypeByName( + "google.protobuf.UInt64Value"); + ASSERT_THAT(desc, NotNull()); + EXPECT_EQ(desc->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_UINT64VALUE); +} + +TEST(TestingDescriptorPool, FloatValue) { + const auto* desc = GetTestingDescriptorPool()->FindMessageTypeByName( + "google.protobuf.FloatValue"); + ASSERT_THAT(desc, NotNull()); + EXPECT_EQ(desc->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_FLOATVALUE); +} + +TEST(TestingDescriptorPool, DoubleValue) { + const auto* desc = GetTestingDescriptorPool()->FindMessageTypeByName( + "google.protobuf.DoubleValue"); + ASSERT_THAT(desc, NotNull()); + EXPECT_EQ(desc->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_DOUBLEVALUE); +} + +TEST(TestingDescriptorPool, BytesValue) { + const auto* desc = GetTestingDescriptorPool()->FindMessageTypeByName( + "google.protobuf.BytesValue"); + ASSERT_THAT(desc, NotNull()); + EXPECT_EQ(desc->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_BYTESVALUE); +} + +TEST(TestingDescriptorPool, StringValue) { + const auto* desc = GetTestingDescriptorPool()->FindMessageTypeByName( + "google.protobuf.StringValue"); + ASSERT_THAT(desc, NotNull()); + EXPECT_EQ(desc->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_STRINGVALUE); +} + +TEST(TestingDescriptorPool, Any) { + const auto* desc = + GetTestingDescriptorPool()->FindMessageTypeByName("google.protobuf.Any"); + ASSERT_THAT(desc, NotNull()); + EXPECT_EQ(desc->well_known_type(), google::protobuf::Descriptor::WELLKNOWNTYPE_ANY); +} + +TEST(TestingDescriptorPool, Duration) { + const auto* desc = GetTestingDescriptorPool()->FindMessageTypeByName( + "google.protobuf.Duration"); + ASSERT_THAT(desc, NotNull()); + EXPECT_EQ(desc->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_DURATION); +} + +TEST(TestingDescriptorPool, Timestamp) { + const auto* desc = GetTestingDescriptorPool()->FindMessageTypeByName( + "google.protobuf.Timestamp"); + ASSERT_THAT(desc, NotNull()); + EXPECT_EQ(desc->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_TIMESTAMP); +} + +TEST(TestingDescriptorPool, Value) { + const auto* desc = GetTestingDescriptorPool()->FindMessageTypeByName( + "google.protobuf.Value"); + ASSERT_THAT(desc, NotNull()); + EXPECT_EQ(desc->well_known_type(), google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE); +} + +TEST(TestingDescriptorPool, ListValue) { + const auto* desc = GetTestingDescriptorPool()->FindMessageTypeByName( + "google.protobuf.ListValue"); + ASSERT_THAT(desc, NotNull()); + EXPECT_EQ(desc->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE); +} + +TEST(TestingDescriptorPool, Struct) { + const auto* desc = GetTestingDescriptorPool()->FindMessageTypeByName( + "google.protobuf.Struct"); + ASSERT_THAT(desc, NotNull()); + EXPECT_EQ(desc->well_known_type(), google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT); +} + +TEST(TestingDescriptorPool, FieldMask) { + const auto* desc = GetTestingDescriptorPool()->FindMessageTypeByName( + "google.protobuf.FieldMask"); + ASSERT_THAT(desc, NotNull()); + EXPECT_EQ(desc->well_known_type(), + google::protobuf::Descriptor::WELLKNOWNTYPE_FIELDMASK); +} + +TEST(TestingDescriptorPool, Empty) { + const auto* desc = GetTestingDescriptorPool()->FindMessageTypeByName( + "google.protobuf.Empty"); + ASSERT_THAT(desc, NotNull()); +} + +TEST(TestingDescriptorPool, TestAllTypesProto2) { + EXPECT_THAT(GetTestingDescriptorPool()->FindMessageTypeByName( + "cel.expr.conformance.proto2.TestAllTypes"), + NotNull()); +} + +TEST(TestingDescriptorPool, TestAllTypesProto3) { + EXPECT_THAT(GetTestingDescriptorPool()->FindMessageTypeByName( + "cel.expr.conformance.proto3.TestAllTypes"), + NotNull()); +} + +} // namespace +} // namespace cel::internal diff --git a/internal/testing_message_factory.cc b/internal/testing_message_factory.cc new file mode 100644 index 000000000..958c60c3e --- /dev/null +++ b/internal/testing_message_factory.cc @@ -0,0 +1,31 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "internal/testing_message_factory.h" + +#include "absl/base/no_destructor.h" +#include "absl/base/nullability.h" +#include "internal/testing_descriptor_pool.h" +#include "google/protobuf/dynamic_message.h" +#include "google/protobuf/message.h" + +namespace cel::internal { + +google::protobuf::MessageFactory* absl_nonnull GetTestingMessageFactory() { + static absl::NoDestructor factory( + GetTestingDescriptorPool()); + return &*factory; +} + +} // namespace cel::internal diff --git a/internal/testing_message_factory.h b/internal/testing_message_factory.h new file mode 100644 index 000000000..35406d0fc --- /dev/null +++ b/internal/testing_message_factory.h @@ -0,0 +1,31 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_TESTING_MESSAGE_FACTORY_H_ +#define THIRD_PARTY_CEL_CPP_INTERNAL_TESTING_MESSAGE_FACTORY_H_ + +#include "absl/base/nullability.h" +#include "google/protobuf/message.h" + +namespace cel::internal { + +// GetTestingMessageFactory returns a pointer to a `google::protobuf::MessageFactory` +// which should be used with the descriptor pool returned by +// `GetTestingDescriptorPool`. The returning `google::protobuf::MessageFactory` is valid +// for the lifetime of the process. +google::protobuf::MessageFactory* absl_nonnull GetTestingMessageFactory(); + +} // namespace cel::internal + +#endif // THIRD_PARTY_CEL_CPP_INTERNAL_TESTING_MESSAGE_FACTORY_H_ diff --git a/internal/time.cc b/internal/time.cc new file mode 100644 index 000000000..45945613d --- /dev/null +++ b/internal/time.cc @@ -0,0 +1,198 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "internal/time.h" + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "internal/status_macros.h" +#include "google/protobuf/util/time_util.h" + +namespace cel::internal { + +namespace { + +std::string RawFormatTimestamp(absl::Time timestamp) { + return absl::FormatTime("%Y-%m-%d%ET%H:%M:%E*SZ", timestamp, + absl::UTCTimeZone()); +} + +} // namespace + +absl::Duration MaxDuration() { + // This currently supports a larger range then the current CEL spec. The + // intent is to widen the CEL spec to support the larger range and match + // google.protobuf.Duration from protocol buffer messages, which this + // implementation currently supports. + // TODO(google/cel-spec/issues/214): revisit + return absl::Seconds(google::protobuf::util::TimeUtil::kDurationMaxSeconds) + + absl::Nanoseconds(google::protobuf::util::TimeUtil::kDurationMaxNanoseconds); +} + +absl::Duration MinDuration() { + // This currently supports a larger range then the current CEL spec. The + // intent is to widen the CEL spec to support the larger range and match + // google.protobuf.Duration from protocol buffer messages, which this + // implementation currently supports. + // TODO(google/cel-spec/issues/214): revisit + return absl::Seconds(google::protobuf::util::TimeUtil::kDurationMinSeconds) + + absl::Nanoseconds(google::protobuf::util::TimeUtil::kDurationMinNanoseconds); +} + +absl::Time MaxTimestamp() { + return absl::UnixEpoch() + + absl::Seconds(google::protobuf::util::TimeUtil::kTimestampMaxSeconds) + + absl::Nanoseconds(google::protobuf::util::TimeUtil::kTimestampMaxNanoseconds); +} + +absl::Time MinTimestamp() { + return absl::UnixEpoch() + + absl::Seconds(google::protobuf::util::TimeUtil::kTimestampMinSeconds) + + absl::Nanoseconds(google::protobuf::util::TimeUtil::kTimestampMinNanoseconds); +} + +absl::Status ValidateDuration(absl::Duration duration) { + if (duration < MinDuration()) { + return absl::InvalidArgumentError( + absl::StrCat("Duration \"", absl::FormatDuration(duration), + "\" below minimum allowed duration \"", + absl::FormatDuration(MinDuration()), "\"")); + } + if (duration > MaxDuration()) { + return absl::InvalidArgumentError( + absl::StrCat("Duration \"", absl::FormatDuration(duration), + "\" above maximum allowed duration \"", + absl::FormatDuration(MaxDuration()), "\"")); + } + return absl::OkStatus(); +} + +absl::StatusOr ParseDuration(absl::string_view input) { + absl::Duration duration; + if (!absl::ParseDuration(input, &duration)) { + return absl::InvalidArgumentError("Failed to parse duration from string"); + } + return duration; +} + +absl::StatusOr FormatDuration(absl::Duration duration) { + CEL_RETURN_IF_ERROR(ValidateDuration(duration)); + return absl::FormatDuration(duration); +} + +std::string DebugStringDuration(absl::Duration duration) { + return absl::FormatDuration(duration); +} + +absl::Status ValidateTimestamp(absl::Time timestamp) { + if (timestamp < MinTimestamp()) { + return absl::InvalidArgumentError( + absl::StrCat("Timestamp \"", RawFormatTimestamp(timestamp), + "\" below minimum allowed timestamp \"", + RawFormatTimestamp(MinTimestamp()), "\"")); + } + if (timestamp > MaxTimestamp()) { + return absl::InvalidArgumentError( + absl::StrCat("Timestamp \"", RawFormatTimestamp(timestamp), + "\" above maximum allowed timestamp \"", + RawFormatTimestamp(MaxTimestamp()), "\"")); + } + return absl::OkStatus(); +} + +absl::StatusOr ParseTimestamp(absl::string_view input) { + absl::Time timestamp; + std::string err; + if (!absl::ParseTime(absl::RFC3339_full, input, absl::UTCTimeZone(), + ×tamp, &err)) { + return err.empty() ? absl::InvalidArgumentError( + "Failed to parse timestamp from string") + : absl::InvalidArgumentError(absl::StrCat( + "Failed to parse timestamp from string: ", err)); + } + CEL_RETURN_IF_ERROR(ValidateTimestamp(timestamp)); + return timestamp; +} + +absl::StatusOr FormatTimestamp(absl::Time timestamp) { + CEL_RETURN_IF_ERROR(ValidateTimestamp(timestamp)); + return RawFormatTimestamp(timestamp); +} + +std::string FormatNanos(int32_t nanos) { + constexpr int32_t kNanosPerMillisecond = 1000000; + constexpr int32_t kNanosPerMicrosecond = 1000; + + if (nanos % kNanosPerMillisecond == 0) { + return absl::StrFormat("%03d", nanos / kNanosPerMillisecond); + } else if (nanos % kNanosPerMicrosecond == 0) { + return absl::StrFormat("%06d", nanos / kNanosPerMicrosecond); + } + return absl::StrFormat("%09d", nanos); +} + +absl::StatusOr EncodeDurationToJson(absl::Duration duration) { + // Adapted from protobuf time_util. + CEL_RETURN_IF_ERROR(ValidateDuration(duration)); + std::string result; + int64_t seconds = absl::IDivDuration(duration, absl::Seconds(1), &duration); + int64_t nanos = absl::IDivDuration(duration, absl::Nanoseconds(1), &duration); + + if (seconds < 0 || nanos < 0) { + result = "-"; + seconds = -seconds; + nanos = -nanos; + } + + absl::StrAppend(&result, seconds); + if (nanos != 0) { + absl::StrAppend(&result, ".", FormatNanos(nanos)); + } + + absl::StrAppend(&result, "s"); + return result; +} + +absl::StatusOr EncodeTimestampToJson(absl::Time timestamp) { + // Adapted from protobuf time_util. + static constexpr absl::string_view kTimestampFormat = "%E4Y-%m-%dT%H:%M:%S"; + CEL_RETURN_IF_ERROR(ValidateTimestamp(timestamp)); + // Handle nanos and the seconds separately to match proto JSON format. + absl::Time unix_seconds = + absl::FromUnixSeconds(absl::ToUnixSeconds(timestamp)); + int64_t n = (timestamp - unix_seconds) / absl::Nanoseconds(1); + + std::string result = + absl::FormatTime(kTimestampFormat, unix_seconds, absl::UTCTimeZone()); + + if (n > 0) { + absl::StrAppend(&result, ".", FormatNanos(n)); + } + + absl::StrAppend(&result, "Z"); + return result; +} + +std::string DebugStringTimestamp(absl::Time timestamp) { + return RawFormatTimestamp(timestamp); +} + +} // namespace cel::internal diff --git a/internal/time.h b/internal/time.h new file mode 100644 index 000000000..402cb6c8b --- /dev/null +++ b/internal/time.h @@ -0,0 +1,65 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_TIME_H_ +#define THIRD_PARTY_CEL_CPP_INTERNAL_TIME_H_ + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" + +namespace cel::internal { + +absl::Duration MaxDuration(); + +absl::Duration MinDuration(); + +absl::Time MaxTimestamp(); + +absl::Time MinTimestamp(); + +absl::Status ValidateDuration(absl::Duration duration); + +absl::StatusOr ParseDuration(absl::string_view input); + +// Human-friendly format for duration provided to match DebugString. +// Checks that the duration is in the supported range for CEL values. +absl::StatusOr FormatDuration(absl::Duration duration); + +// Encodes duration as a string for JSON. +// This implementation is compatible with protobuf. +absl::StatusOr EncodeDurationToJson(absl::Duration duration); + +std::string DebugStringDuration(absl::Duration duration); + +absl::Status ValidateTimestamp(absl::Time timestamp); + +absl::StatusOr ParseTimestamp(absl::string_view input); + +// Human-friendly format for timestamp provided to match DebugString. +// Checks that the timestamp is in the supported range for CEL values. +absl::StatusOr FormatTimestamp(absl::Time timestamp); + +// Encodes timestamp as a string for JSON. +// This implementation is compatible with protobuf. +absl::StatusOr EncodeTimestampToJson(absl::Time timestamp); + +std::string DebugStringTimestamp(absl::Time timestamp); + +} // namespace cel::internal + +#endif // THIRD_PARTY_CEL_CPP_INTERNAL_TIME_H_ diff --git a/internal/time_test.cc b/internal/time_test.cc new file mode 100644 index 000000000..94eb4bf32 --- /dev/null +++ b/internal/time_test.cc @@ -0,0 +1,188 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "internal/time.h" + +#include + +#include "absl/status/status.h" +#include "absl/time/time.h" +#include "internal/testing.h" +#include "google/protobuf/util/time_util.h" + +namespace cel::internal { +namespace { + +using ::absl_testing::StatusIs; + +TEST(MaxDuration, ProtoEquiv) { + EXPECT_EQ(MaxDuration(), + absl::Seconds(google::protobuf::util::TimeUtil::kDurationMaxSeconds) + + absl::Nanoseconds(999999999)); +} + +TEST(MinDuration, ProtoEquiv) { + EXPECT_EQ(MinDuration(), + absl::Seconds(google::protobuf::util::TimeUtil::kDurationMinSeconds) + + absl::Nanoseconds(-999999999)); +} + +TEST(MaxTimestamp, ProtoEquiv) { + EXPECT_EQ(MaxTimestamp(), + absl::UnixEpoch() + + absl::Seconds(google::protobuf::util::TimeUtil::kTimestampMaxSeconds) + + absl::Nanoseconds(999999999)); +} + +TEST(MinTimestamp, ProtoEquiv) { + EXPECT_EQ(MinTimestamp(), + absl::UnixEpoch() + + absl::Seconds(google::protobuf::util::TimeUtil::kTimestampMinSeconds)); +} + +TEST(ParseDuration, Conformance) { + absl::Duration parsed; + ASSERT_OK_AND_ASSIGN(parsed, internal::ParseDuration("1s")); + EXPECT_EQ(parsed, absl::Seconds(1)); + ASSERT_OK_AND_ASSIGN(parsed, internal::ParseDuration("0.010s")); + EXPECT_EQ(parsed, absl::Milliseconds(10)); + ASSERT_OK_AND_ASSIGN(parsed, internal::ParseDuration("0.000010s")); + EXPECT_EQ(parsed, absl::Microseconds(10)); + ASSERT_OK_AND_ASSIGN(parsed, internal::ParseDuration("0.000000010s")); + EXPECT_EQ(parsed, absl::Nanoseconds(10)); + + EXPECT_THAT(internal::ParseDuration("abc"), + StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(internal::ParseDuration("1c"), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(FormatDuration, Conformance) { + std::string formatted; + ASSERT_OK_AND_ASSIGN(formatted, internal::FormatDuration(absl::Seconds(1))); + EXPECT_EQ(formatted, "1s"); + ASSERT_OK_AND_ASSIGN(formatted, + internal::FormatDuration(absl::Milliseconds(10))); + EXPECT_EQ(formatted, "10ms"); + ASSERT_OK_AND_ASSIGN(formatted, + internal::FormatDuration(absl::Microseconds(10))); + EXPECT_EQ(formatted, "10us"); + ASSERT_OK_AND_ASSIGN(formatted, + internal::FormatDuration(absl::Nanoseconds(10))); + EXPECT_EQ(formatted, "10ns"); + + EXPECT_THAT(internal::FormatDuration(absl::InfiniteDuration()), + StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(internal::FormatDuration(-absl::InfiniteDuration()), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(ParseTimestamp, Conformance) { + absl::Time parsed; + ASSERT_OK_AND_ASSIGN(parsed, internal::ParseTimestamp("1-01-01T00:00:00Z")); + EXPECT_EQ(parsed, MinTimestamp()); + ASSERT_OK_AND_ASSIGN( + parsed, internal::ParseTimestamp("9999-12-31T23:59:59.999999999Z")); + EXPECT_EQ(parsed, MaxTimestamp()); + ASSERT_OK_AND_ASSIGN(parsed, + internal::ParseTimestamp("1970-01-01T00:00:00Z")); + EXPECT_EQ(parsed, absl::UnixEpoch()); + ASSERT_OK_AND_ASSIGN(parsed, + internal::ParseTimestamp("1970-01-01T00:00:00.010Z")); + EXPECT_EQ(parsed, absl::UnixEpoch() + absl::Milliseconds(10)); + ASSERT_OK_AND_ASSIGN(parsed, + internal::ParseTimestamp("1970-01-01T00:00:00.000010Z")); + EXPECT_EQ(parsed, absl::UnixEpoch() + absl::Microseconds(10)); + ASSERT_OK_AND_ASSIGN( + parsed, internal::ParseTimestamp("1970-01-01T00:00:00.000000010Z")); + EXPECT_EQ(parsed, absl::UnixEpoch() + absl::Nanoseconds(10)); + + EXPECT_THAT(internal::ParseTimestamp("abc"), + StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(internal::ParseTimestamp("10000-01-01T00:00:00Z"), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(FormatTimestamp, Conformance) { + std::string formatted; + ASSERT_OK_AND_ASSIGN(formatted, internal::FormatTimestamp(MinTimestamp())); + EXPECT_EQ(formatted, "1-01-01T00:00:00Z"); + ASSERT_OK_AND_ASSIGN(formatted, internal::FormatTimestamp(MaxTimestamp())); + EXPECT_EQ(formatted, "9999-12-31T23:59:59.999999999Z"); + ASSERT_OK_AND_ASSIGN(formatted, internal::FormatTimestamp(absl::UnixEpoch())); + EXPECT_EQ(formatted, "1970-01-01T00:00:00Z"); + ASSERT_OK_AND_ASSIGN( + formatted, + internal::FormatTimestamp(absl::UnixEpoch() + absl::Milliseconds(10))); + EXPECT_EQ(formatted, "1970-01-01T00:00:00.01Z"); + ASSERT_OK_AND_ASSIGN( + formatted, + internal::FormatTimestamp(absl::UnixEpoch() + absl::Microseconds(10))); + EXPECT_EQ(formatted, "1970-01-01T00:00:00.00001Z"); + ASSERT_OK_AND_ASSIGN( + formatted, + internal::FormatTimestamp(absl::UnixEpoch() + absl::Nanoseconds(10))); + EXPECT_EQ(formatted, "1970-01-01T00:00:00.00000001Z"); + + EXPECT_THAT(internal::FormatTimestamp(absl::InfiniteFuture()), + StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(internal::FormatTimestamp(absl::InfinitePast()), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(EncodeDurationToJson, Conformance) { + std::string formatted; + ASSERT_OK_AND_ASSIGN(formatted, EncodeDurationToJson(absl::Seconds(1))); + EXPECT_EQ(formatted, "1s"); + ASSERT_OK_AND_ASSIGN(formatted, EncodeDurationToJson(absl::Milliseconds(10))); + EXPECT_EQ(formatted, "0.010s"); + ASSERT_OK_AND_ASSIGN(formatted, EncodeDurationToJson(absl::Microseconds(10))); + EXPECT_EQ(formatted, "0.000010s"); + ASSERT_OK_AND_ASSIGN(formatted, EncodeDurationToJson(absl::Nanoseconds(10))); + EXPECT_EQ(formatted, "0.000000010s"); + + EXPECT_THAT(EncodeDurationToJson(absl::InfiniteDuration()), + StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(EncodeDurationToJson(-absl::InfiniteDuration()), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(EncodeTimestampToJson, Conformance) { + std::string formatted; + ASSERT_OK_AND_ASSIGN(formatted, EncodeTimestampToJson(MinTimestamp())); + EXPECT_EQ(formatted, "0001-01-01T00:00:00Z"); + ASSERT_OK_AND_ASSIGN(formatted, EncodeTimestampToJson(MaxTimestamp())); + EXPECT_EQ(formatted, "9999-12-31T23:59:59.999999999Z"); + ASSERT_OK_AND_ASSIGN(formatted, EncodeTimestampToJson(absl::UnixEpoch())); + EXPECT_EQ(formatted, "1970-01-01T00:00:00Z"); + ASSERT_OK_AND_ASSIGN( + formatted, + EncodeTimestampToJson(absl::UnixEpoch() + absl::Milliseconds(10))); + EXPECT_EQ(formatted, "1970-01-01T00:00:00.010Z"); + ASSERT_OK_AND_ASSIGN( + formatted, + EncodeTimestampToJson(absl::UnixEpoch() + absl::Microseconds(10))); + EXPECT_EQ(formatted, "1970-01-01T00:00:00.000010Z"); + ASSERT_OK_AND_ASSIGN(formatted, EncodeTimestampToJson(absl::UnixEpoch() + + absl::Nanoseconds(10))); + EXPECT_EQ(formatted, "1970-01-01T00:00:00.000000010Z"); + + EXPECT_THAT(EncodeTimestampToJson(absl::InfiniteFuture()), + StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(EncodeTimestampToJson(absl::InfinitePast()), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +} // namespace +} // namespace cel::internal diff --git a/internal/to_address.h b/internal/to_address.h new file mode 100644 index 000000000..36e7eeb60 --- /dev/null +++ b/internal/to_address.h @@ -0,0 +1,68 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_TO_ADDRESS_H_ +#define THIRD_PARTY_CEL_CPP_INTERNAL_TO_ADDRESS_H_ + +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/meta/type_traits.h" + +namespace cel::internal { + +// ----------------------------------------------------------------------------- +// Function Template: to_address() +// ----------------------------------------------------------------------------- +// +// Backport of std::to_address introduced in C++20. Enables obtaining the +// address of an object regardless of whether the pointer is raw or fancy. +#if defined(__cpp_lib_to_address) && __cpp_lib_to_address >= 201711L +using std::to_address; +#else +template +constexpr T* to_address(T* ptr) noexcept { + static_assert(!std::is_function::value, "T must not be a function"); + return ptr; +} + +template +struct PointerTraitsToAddress { + static constexpr auto Dispatch( + const T& p ABSL_ATTRIBUTE_LIFETIME_BOUND) noexcept { + return internal::to_address(p.operator->()); + } +}; + +template +struct PointerTraitsToAddress< + T, std::void_t::to_address( + std::declval()))> > { + static constexpr auto Dispatch( + const T& p ABSL_ATTRIBUTE_LIFETIME_BOUND) noexcept { + return std::pointer_traits::to_address(p); + } +}; + +template +constexpr auto to_address(const T& ptr ABSL_ATTRIBUTE_LIFETIME_BOUND) noexcept { + return PointerTraitsToAddress::Dispatch(ptr); +} +#endif + +} // namespace cel::internal + +#endif // THIRD_PARTY_CEL_CPP_INTERNAL_TO_ADDRESS_H_ diff --git a/internal/to_address_test.cc b/internal/to_address_test.cc new file mode 100644 index 000000000..554cfd29d --- /dev/null +++ b/internal/to_address_test.cc @@ -0,0 +1,72 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "internal/to_address.h" + +#include + +#include "internal/testing.h" + +namespace cel { +namespace { + +TEST(ToAddress, RawPointer) { + char c; + EXPECT_EQ(internal::to_address(&c), &c); +} + +struct ImplicitFancyPointer { + using element_type = char; + + char* operator->() const { return ptr; } + + char* ptr; +}; + +struct ExplicitFancyPointer { + char* ptr; +}; + +} // namespace +} // namespace cel + +namespace std { + +template <> +struct pointer_traits : pointer_traits { + static constexpr char* to_address( + const cel::ExplicitFancyPointer& efp) noexcept { + return efp.ptr; + } +}; + +} // namespace std + +namespace cel { +namespace { + +TEST(ToAddress, FancyPointerNoPointerTraits) { + char c; + ImplicitFancyPointer ip{&c}; + EXPECT_EQ(internal::to_address(ip), &c); +} + +TEST(ToAddress, FancyPointerWithPointerTraits) { + char c; + ExplicitFancyPointer ip{&c}; + EXPECT_EQ(internal::to_address(ip), &c); +} + +} // namespace +} // namespace cel diff --git a/internal/types.h b/internal/types.h deleted file mode 100644 index d2bd2a270..000000000 --- a/internal/types.h +++ /dev/null @@ -1,156 +0,0 @@ -// Helpers to work with sets of types. - -#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_TYPE_UTIL_H_ -#define THIRD_PARTY_CEL_CPP_INTERNAL_TYPE_UTIL_H_ - -#include -#include - -#include "absl/memory/memory.h" -#include "absl/strings/string_view.h" -#include "internal/port.h" -#include "internal/specialize.h" - -namespace google { -namespace api { -namespace expr { -namespace internal { - -// Short names for AND, OR and NOT. -template -using and_t = conjunction; -template -using or_t = disjunction; -template -using not_t = negation; - -// Holder for a static list of types. -template -using types = std::tuple; - -// Helper that defines 'value' to be the number of types in T. -template -using types_size = std::tuple_size; - -template -using types_cat = decltype(std::tuple_cat(inst_of>()...)); - -// Helper that resolves to the Ith type in T. -template -using type_at = typename std::tuple_element::type; - -// Helper that defines 'value' to be true if there are no types in T. -template -using types_empty = std::integral_constant::value == 0>; - -// Helper that returns the index of the first occurrence of E in T. -template -struct type_index { - static constexpr std::size_t value = -1; -}; -template -struct type_index, I> { - static constexpr std::size_t value = I; -}; -template -struct type_index, I> - : type_index, I + 1> {}; - -// Helper that defines 'value' to be true if E occurs in T. -template -using type_in = std::integral_constant::value != - static_cast(-1)>; - -// A 'type map' lookup of `Key` in a map from Keys -> Values. -template -using get_type = type_at::value, Values>; - -// Helper definitions for packed types. -template -using args_size = types_size>; -template -using args_empty = types_empty>; -template -using arg_in = type_in>; - -template -using type_is = std::is_same; - -template -using type_not = negation>; - -/** - * Tests if a type is a raw or smart pointer. - * - * Any type that defines overloads for * and -> is considered a smart pointer. - */ -template -struct is_ptr : std::is_pointer> {}; -template -struct is_ptr::operator*), - decltype(&decay_t::operator->)>> - : std::true_type {}; - -/** - * Tests if a type is convertible to absl::string_view. - */ -template -using is_string = and_t, std::nullptr_t>, - std::is_convertible>; - -/** - * Tests if a type is a signed integer type. - */ -template -using is_int = conjunction>, - std::is_signed>>; - -/** - * Tests if a type is an unsigned integer type. - */ -template -using is_uint = - conjunction, bool>, std::is_integral>, - std::is_unsigned>>; - -/** - * Tests if a type is a floating point type. - */ -template -using is_float = std::is_floating_point>; - -template -using is_numeric = or_t, is_uint, is_float>; - -// Containers define a "value_type" and "iterator". -// Note: The full spec can be found at -// https://en.cppreference.com/w/cpp/named_req/Container -template -struct is_container : public std::false_type {}; -template -struct is_container::value_type, - typename remove_cvref_t::iterator>> - : public std::true_type {}; - -// Maps are containers that also define a "mapped_type". -template -struct is_map : public std::false_type {}; -template -struct is_map::mapped_type>> - : public is_container {}; - -// Lists are containers that are not maps. -template -using is_list = bool_constant::value && !is_map::value>; - -// Used to create a compiler error when a specialized function/class is -// instantiated with an unsupported type. -template -struct UnsupportedType; - -} // namespace internal -} // namespace expr -} // namespace api -} // namespace google - -#endif diff --git a/internal/types_test.cc b/internal/types_test.cc deleted file mode 100644 index 7a5fce0a8..000000000 --- a/internal/types_test.cc +++ /dev/null @@ -1,21 +0,0 @@ -#include "internal/types.h" - -#include "gtest/gtest.h" - -namespace google { -namespace api { -namespace expr { -namespace internal { - -TEST(Types, Numeric) { - static_assert(!is_numeric::value, "bool is not numeric"); -} - -TEST(Types, String) { - static_assert(!is_string::value, "nullptr is not a string"); -} - -} // namespace internal -} // namespace expr -} // namespace api -} // namespace google diff --git a/internal/unicode.h b/internal/unicode.h new file mode 100644 index 000000000..5723258f7 --- /dev/null +++ b/internal/unicode.h @@ -0,0 +1,28 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_UNICODE_H_ +#define THIRD_PARTY_CEL_CPP_INTERNAL_UNICODE_H_ + +namespace cel::internal { + +inline constexpr char32_t kUnicodeReplacementCharacter = 0xfffd; + +constexpr bool UnicodeIsValid(char32_t code_point) { + return code_point < 0xd800 || (code_point > 0xdfff && code_point <= 0x10ffff); +} + +} // namespace cel::internal + +#endif // THIRD_PARTY_CEL_CPP_INTERNAL_UNICODE_H_ diff --git a/internal/utf8.cc b/internal/utf8.cc new file mode 100644 index 000000000..8cda91505 --- /dev/null +++ b/internal/utf8.cc @@ -0,0 +1,526 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "internal/utf8.h" + +#include +#include +#include +#include +#include + +#include "absl/base/macros.h" +#include "absl/base/nullability.h" +#include "absl/base/optimization.h" +#include "absl/log/absl_check.h" +#include "absl/strings/cord.h" +#include "absl/strings/string_view.h" +#include "internal/unicode.h" + +// Implementation is based on +// https://go.googlesource.com/go/+/refs/heads/master/src/unicode/utf8/utf8.go +// but adapted for C++. + +namespace cel::internal { + +namespace { + +constexpr uint8_t kUtf8RuneSelf = 0x80; +constexpr size_t kUtf8Max = 4; + +constexpr uint8_t kLow = 0x80; +constexpr uint8_t kHigh = 0xbf; + +constexpr uint8_t kMaskX = 0x3f; +constexpr uint8_t kMask2 = 0x1f; +constexpr uint8_t kMask3 = 0xf; +constexpr uint8_t kMask4 = 0x7; + +constexpr uint8_t kTX = 0x80; +constexpr uint8_t kT2 = 0xc0; +constexpr uint8_t kT3 = 0xe0; +constexpr uint8_t kT4 = 0xf0; + +constexpr uint8_t kXX = 0xf1; +constexpr uint8_t kAS = 0xf0; +constexpr uint8_t kS1 = 0x02; +constexpr uint8_t kS2 = 0x13; +constexpr uint8_t kS3 = 0x03; +constexpr uint8_t kS4 = 0x23; +constexpr uint8_t kS5 = 0x34; +constexpr uint8_t kS6 = 0x04; +constexpr uint8_t kS7 = 0x44; + +// NOLINTBEGIN +// clang-format off +constexpr uint8_t kLeading[256] = { + // 1 2 3 4 5 6 7 8 9 A B C D E F + kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, // 0x00-0x0F + kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, // 0x10-0x1F + kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, // 0x20-0x2F + kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, // 0x30-0x3F + kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, // 0x40-0x4F + kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, // 0x50-0x5F + kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, // 0x60-0x6F + kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, // 0x70-0x7F + // 1 2 3 4 5 6 7 8 9 A B C D E F + kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, // 0x80-0x8F + kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, // 0x90-0x9F + kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, // 0xA0-0xAF + kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, // 0xB0-0xBF + kXX, kXX, kS1, kS1, kS1, kS1, kS1, kS1, kS1, kS1, kS1, kS1, kS1, kS1, kS1, kS1, // 0xC0-0xCF + kS1, kS1, kS1, kS1, kS1, kS1, kS1, kS1, kS1, kS1, kS1, kS1, kS1, kS1, kS1, kS1, // 0xD0-0xDF + kS2, kS3, kS3, kS3, kS3, kS3, kS3, kS3, kS3, kS3, kS3, kS3, kS3, kS4, kS3, kS3, // 0xE0-0xEF + kS5, kS6, kS6, kS6, kS7, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, // 0xF0-0xFF +}; +// clang-format on +// NOLINTEND + +constexpr std::pair kAccept[16] = { + {kLow, kHigh}, {0xa0, kHigh}, {kLow, 0x9f}, {0x90, kHigh}, + {kLow, 0x8f}, {0x0, 0x0}, {0x0, 0x0}, {0x0, 0x0}, + {0x0, 0x0}, {0x0, 0x0}, {0x0, 0x0}, {0x0, 0x0}, + {0x0, 0x0}, {0x0, 0x0}, {0x0, 0x0}, {0x0, 0x0}, +}; + +class StringReader final { + public: + constexpr explicit StringReader(absl::string_view input) : input_(input) {} + + size_t Remaining() const { return input_.size(); } + + bool HasRemaining() const { return !input_.empty(); } + + absl::string_view Peek(size_t n) { + ABSL_ASSERT(n <= Remaining()); + return input_.substr(0, n); + } + + char Read() { + ABSL_ASSERT(HasRemaining()); + char value = input_.front(); + input_.remove_prefix(1); + return value; + } + + void Advance(size_t n) { + ABSL_ASSERT(n <= Remaining()); + input_.remove_prefix(n); + } + + void Reset(absl::string_view input) { input_ = input; } + + private: + absl::string_view input_; +}; + +class CordReader final { + public: + explicit CordReader(const absl::Cord& input) + : input_(input), size_(input_.size()), buffer_(), index_(0) {} + + size_t Remaining() const { return size_; } + + bool HasRemaining() const { return size_ != 0; } + + absl::string_view Peek(size_t n) { + ABSL_ASSERT(n <= Remaining()); + if (n == 0) { + return absl::string_view(); + } + if (n <= buffer_.size() - index_) { + // Enough data remaining in temporary buffer. + return absl::string_view(buffer_.data() + index_, n); + } + // We do not have enough data. See if we can fit it without allocating by + // shifting data back to the beginning of the buffer. + if (buffer_.capacity() >= n) { + // It will fit in the current capacity, see if we need to shift the + // existing data to make it fit. + if (buffer_.capacity() - buffer_.size() < n && index_ != 0) { + // We need to shift. + buffer_.erase(buffer_.begin(), buffer_.begin() + index_); + index_ = 0; + } + } + // Ensure we never reserve less than kUtf8Max. + buffer_.reserve(std::max(buffer_.size() + n, kUtf8Max)); + size_t to_copy = n - (buffer_.size() - index_); + absl::CopyCordToString(input_.Subcord(0, to_copy), &buffer_); + input_.RemovePrefix(to_copy); + return absl::string_view(buffer_.data() + index_, n); + } + + char Read() { + char value = Peek(1).front(); + Advance(1); + return value; + } + + void Advance(size_t n) { + ABSL_ASSERT(n <= Remaining()); + if (n == 0) { + return; + } + if (index_ < buffer_.size()) { + size_t count = std::min(n, buffer_.size() - index_); + index_ += count; + n -= count; + size_ -= count; + if (index_ < buffer_.size()) { + return; + } + // Temporary buffer is empty, clear it. + buffer_.clear(); + index_ = 0; + } + input_.RemovePrefix(n); + size_ -= n; + } + + void Reset(const absl::Cord& input) { + input_ = input; + size_ = input_.size(); + buffer_.clear(); + index_ = 0; + } + + private: + absl::Cord input_; + size_t size_; + std::string buffer_; + size_t index_; +}; + +template +bool Utf8IsValidImpl(BufferedByteReader* reader) { + while (reader->HasRemaining()) { + const auto b = static_cast(reader->Read()); + if (b < kUtf8RuneSelf) { + continue; + } + const auto leading = kLeading[b]; + if (leading == kXX) { + return false; + } + const auto size = static_cast(leading & 7) - 1; + if (size > reader->Remaining()) { + return false; + } + const absl::string_view segment = reader->Peek(size); + const auto& accept = kAccept[leading >> 4]; + if (static_cast(segment[0]) < accept.first || + static_cast(segment[0]) > accept.second) { + return false; + } else if (size == 1) { + } else if (static_cast(segment[1]) < kLow || + static_cast(segment[1]) > kHigh) { + return false; + } else if (size == 2) { + } else if (static_cast(segment[2]) < kLow || + static_cast(segment[2]) > kHigh) { + return false; + } + reader->Advance(size); + } + return true; +} + +template +size_t Utf8CodePointCountImpl(BufferedByteReader* reader) { + size_t count = 0; + while (reader->HasRemaining()) { + count++; + const auto b = static_cast(reader->Read()); + if (b < kUtf8RuneSelf) { + continue; + } + const auto leading = kLeading[b]; + if (leading == kXX) { + continue; + } + auto size = static_cast(leading & 7) - 1; + if (size > reader->Remaining()) { + continue; + } + const absl::string_view segment = reader->Peek(size); + const auto& accept = kAccept[leading >> 4]; + if (static_cast(segment[0]) < accept.first || + static_cast(segment[0]) > accept.second) { + size = 0; + } else if (size == 1) { + } else if (static_cast(segment[1]) < kLow || + static_cast(segment[1]) > kHigh) { + size = 0; + } else if (size == 2) { + } else if (static_cast(segment[2]) < kLow || + static_cast(segment[2]) > kHigh) { + size = 0; + } + reader->Advance(size); + } + return count; +} + +template +std::pair Utf8ValidateImpl(BufferedByteReader* reader) { + size_t count = 0; + while (reader->HasRemaining()) { + const auto b = static_cast(reader->Read()); + if (b < kUtf8RuneSelf) { + count++; + continue; + } + const auto leading = kLeading[b]; + if (leading == kXX) { + return {count, false}; + } + const auto size = static_cast(leading & 7) - 1; + if (size > reader->Remaining()) { + return {count, false}; + } + const absl::string_view segment = reader->Peek(size); + const auto& accept = kAccept[leading >> 4]; + if (static_cast(segment[0]) < accept.first || + static_cast(segment[0]) > accept.second) { + return {count, false}; + } else if (size == 1) { + count++; + } else if (static_cast(segment[1]) < kLow || + static_cast(segment[1]) > kHigh) { + return {count, false}; + } else if (size == 2) { + count++; + } else if (static_cast(segment[2]) < kLow || + static_cast(segment[2]) > kHigh) { + return {count, false}; + } else { + count++; + } + reader->Advance(size); + } + return {count, true}; +} + +} // namespace + +bool Utf8IsValid(absl::string_view str) { + StringReader reader(str); + bool valid = Utf8IsValidImpl(&reader); + ABSL_ASSERT((reader.Reset(str), valid == Utf8ValidateImpl(&reader).second)); + return valid; +} + +bool Utf8IsValid(const absl::Cord& str) { + CordReader reader(str); + bool valid = Utf8IsValidImpl(&reader); + ABSL_ASSERT((reader.Reset(str), valid == Utf8ValidateImpl(&reader).second)); + return valid; +} + +size_t Utf8CodePointCount(absl::string_view str) { + StringReader reader(str); + return Utf8CodePointCountImpl(&reader); +} + +size_t Utf8CodePointCount(const absl::Cord& str) { + CordReader reader(str); + return Utf8CodePointCountImpl(&reader); +} + +std::pair Utf8Validate(absl::string_view str) { + StringReader reader(str); + auto result = Utf8ValidateImpl(&reader); + ABSL_ASSERT((reader.Reset(str), result.second == Utf8IsValidImpl(&reader))); + return result; +} + +std::pair Utf8Validate(const absl::Cord& str) { + CordReader reader(str); + auto result = Utf8ValidateImpl(&reader); + ABSL_ASSERT((reader.Reset(str), result.second == Utf8IsValidImpl(&reader))); + return result; +} + +namespace { + +size_t Utf8DecodeImpl(uint8_t b, uint8_t leading, size_t size, + absl::string_view str, + char32_t* absl_nullable code_point) { + const auto& accept = kAccept[leading >> 4]; + const auto b1 = static_cast(str.front()); + if (ABSL_PREDICT_FALSE(b1 < accept.first || b1 > accept.second)) { + if (code_point != nullptr) { + *code_point = kUnicodeReplacementCharacter; + } + return 1; + } + if (size <= 1) { + if (code_point != nullptr) { + *code_point = (static_cast(b & kMask2) << 6) | + static_cast(b1 & kMaskX); + } + return 2; + } + str.remove_prefix(1); + const auto b2 = static_cast(str.front()); + if (ABSL_PREDICT_FALSE(b2 < kLow || b2 > kHigh)) { + if (code_point != nullptr) { + *code_point = kUnicodeReplacementCharacter; + } + return 1; + } + if (size <= 2) { + if (code_point != nullptr) { + *code_point = (static_cast(b & kMask3) << 12) | + (static_cast(b1 & kMaskX) << 6) | + static_cast(b2 & kMaskX); + } + return 3; + } + str.remove_prefix(1); + const auto b3 = static_cast(str.front()); + if (ABSL_PREDICT_FALSE(b3 < kLow || b3 > kHigh)) { + if (code_point != nullptr) { + *code_point = kUnicodeReplacementCharacter; + } + return 1; + } + if (code_point != nullptr) { + *code_point = (static_cast(b & kMask4) << 18) | + (static_cast(b1 & kMaskX) << 12) | + (static_cast(b2 & kMaskX) << 6) | + static_cast(b3 & kMaskX); + } + return 4; +} + +} // namespace + +size_t Utf8Decode(absl::string_view str, char32_t* absl_nullable code_point) { + ABSL_DCHECK(!str.empty()); + const auto b = static_cast(str.front()); + if (b < kUtf8RuneSelf) { + if (code_point != nullptr) { + *code_point = static_cast(b); + } + return 1; + } + const auto leading = kLeading[b]; + if (ABSL_PREDICT_FALSE(leading == kXX)) { + if (code_point != nullptr) { + *code_point = kUnicodeReplacementCharacter; + } + return 1; + } + auto size = static_cast(leading & 7) - 1; + str.remove_prefix(1); + if (ABSL_PREDICT_FALSE(size > str.size())) { + if (code_point != nullptr) { + *code_point = kUnicodeReplacementCharacter; + } + return 1; + } + return Utf8DecodeImpl(b, leading, size, str, code_point); +} + +size_t Utf8Decode(const absl::Cord::CharIterator& it, + char32_t* absl_nullable code_point) { + absl::string_view str = absl::Cord::ChunkRemaining(it); + ABSL_DCHECK(!str.empty()); + const auto b = static_cast(str.front()); + if (b < kUtf8RuneSelf) { + if (code_point != nullptr) { + *code_point = static_cast(b); + } + return 1; + } + const auto leading = kLeading[b]; + if (ABSL_PREDICT_FALSE(leading == kXX)) { + if (code_point != nullptr) { + *code_point = kUnicodeReplacementCharacter; + } + return 1; + } + auto size = static_cast(leading & 7) - 1; + str.remove_prefix(1); + if (ABSL_PREDICT_TRUE(size <= str.size())) { + // Fast path. + return Utf8DecodeImpl(b, leading, size, str, code_point); + } + absl::Cord::CharIterator current = it; + absl::Cord::Advance(¤t, 1); + char buffer[3]; + size_t buffer_len = 0; + while (buffer_len < size) { + str = absl::Cord::ChunkRemaining(current); + if (ABSL_PREDICT_FALSE(str.empty())) { + if (code_point != nullptr) { + *code_point = kUnicodeReplacementCharacter; + } + return 1; + } + size_t to_copy = std::min(size_t{3} - buffer_len, str.size()); + std::memcpy(buffer + buffer_len, str.data(), to_copy); + buffer_len += to_copy; + absl::Cord::Advance(¤t, to_copy); + } + return Utf8DecodeImpl(b, leading, size, absl::string_view(buffer, buffer_len), + code_point); +} + +size_t Utf8Encode(char32_t code_point, std::string* absl_nonnull buffer) { + ABSL_DCHECK(buffer != nullptr); + + char storage[4]; + size_t storage_len = Utf8Encode(code_point, storage); + buffer->append(storage, storage_len); + return storage_len; +} + +size_t Utf8Encode(char32_t code_point, char* absl_nonnull buffer) { + ABSL_DCHECK(buffer != nullptr); + + if (ABSL_PREDICT_FALSE(!UnicodeIsValid(code_point))) { + code_point = kUnicodeReplacementCharacter; + } + size_t storage_len = 0; + if (code_point <= 0x7f) { + buffer[storage_len++] = static_cast(static_cast(code_point)); + } else if (code_point <= 0x7ff) { + buffer[storage_len++] = + static_cast(kT2 | static_cast(code_point >> 6)); + buffer[storage_len++] = + static_cast(kTX | (static_cast(code_point) & kMaskX)); + } else if (code_point <= 0xffff) { + buffer[storage_len++] = + static_cast(kT3 | static_cast(code_point >> 12)); + buffer[storage_len++] = static_cast( + kTX | (static_cast(code_point >> 6) & kMaskX)); + buffer[storage_len++] = + static_cast(kTX | (static_cast(code_point) & kMaskX)); + } else { + buffer[storage_len++] = + static_cast(kT4 | static_cast(code_point >> 18)); + buffer[storage_len++] = static_cast( + kTX | (static_cast(code_point >> 12) & kMaskX)); + buffer[storage_len++] = static_cast( + kTX | (static_cast(code_point >> 6) & kMaskX)); + buffer[storage_len++] = + static_cast(kTX | (static_cast(code_point) & kMaskX)); + } + return storage_len; +} + +} // namespace cel::internal diff --git a/internal/utf8.h b/internal/utf8.h new file mode 100644 index 000000000..f6b530636 --- /dev/null +++ b/internal/utf8.h @@ -0,0 +1,82 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_UTF8_H_ +#define THIRD_PARTY_CEL_CPP_INTERNAL_UTF8_H_ + +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/strings/cord.h" +#include "absl/strings/string_view.h" + +namespace cel::internal { + +// Returns true if the given UTF-8 encoded string is not malformed, false +// otherwise. +bool Utf8IsValid(absl::string_view str); +bool Utf8IsValid(const absl::Cord& str); + +// Returns the number of Unicode code points in the UTF-8 encoded string. +// +// If there are any invalid bytes, they will each be counted as an invalid code +// point. +size_t Utf8CodePointCount(absl::string_view str); +size_t Utf8CodePointCount(const absl::Cord& str); + +// Validates the given UTF-8 encoded string. The first return value is the +// number of code points and its meaning depends on the second return value. If +// the second return value is true the entire string is not malformed and the +// first return value is the number of code points. If the second return value +// is false the string is malformed and the first return value is the number of +// code points up until the malformed sequence was encountered. +std::pair Utf8Validate(absl::string_view str); +std::pair Utf8Validate(const absl::Cord& str); + +// Decodes the next code point, returning the decoded code point and the number +// of code units (a.k.a. bytes) consumed. In the event that an invalid code unit +// sequence is returned the replacement character, U+FFFD, is returned with a +// code unit count of 1. As U+FFFD requires 3 code units when encoded, this can +// be used to differentiate valid input from malformed input. +size_t Utf8Decode(absl::string_view str, char32_t* absl_nullable code_point); +size_t Utf8Decode(const absl::Cord::CharIterator& it, + char32_t* absl_nullable code_point); +inline std::pair Utf8Decode(absl::string_view str) { + char32_t code_point; + size_t code_units = Utf8Decode(str, &code_point); + return std::pair{code_point, code_units}; +} +inline std::pair Utf8Decode( + const absl::Cord::CharIterator& it) { + char32_t code_point; + size_t code_units = Utf8Decode(it, &code_point); + return std::pair{code_point, code_units}; +} + +// Encodes the given code point and appends it to the buffer. If the code point +// is an unpaired surrogate or outside of the valid Unicode range it is replaced +// with the replacement character, U+FFFD. +size_t Utf8Encode(char32_t code_point, std::string* absl_nonnull buffer); +size_t Utf8Encode(char32_t code_point, char* absl_nonnull buffer); +ABSL_DEPRECATED("Use other overload") +inline size_t Utf8Encode(std::string& buffer, char32_t code_point) { + return Utf8Encode(code_point, &buffer); +} + +} // namespace cel::internal + +#endif // THIRD_PARTY_CEL_CPP_INTERNAL_UTF8_H_ diff --git a/internal/utf8_test.cc b/internal/utf8_test.cc new file mode 100644 index 000000000..800102b12 --- /dev/null +++ b/internal/utf8_test.cc @@ -0,0 +1,421 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "internal/utf8.h" + +#include +#include + +#include "absl/strings/cord.h" +#include "absl/strings/cord_test_helpers.h" +#include "absl/strings/escaping.h" +#include "absl/strings/string_view.h" +#include "internal/benchmark.h" +#include "internal/testing.h" + +// Tests is based on +// https://go.googlesource.com/go/+/refs/heads/master/src/unicode/utf8/utf8.go +// but adapted for C++. + +namespace cel::internal { +namespace { + +TEST(Utf8IsValid, String) { + EXPECT_TRUE(Utf8IsValid("")); + EXPECT_TRUE(Utf8IsValid("a")); + EXPECT_TRUE(Utf8IsValid("abc")); + EXPECT_TRUE(Utf8IsValid("\xd0\x96")); + EXPECT_TRUE(Utf8IsValid("\xd0\x96\xd0\x96")); + EXPECT_TRUE(Utf8IsValid( + "\xd0\xb1\xd1\x80\xd1\x8d\xd0\xb4-\xd0\x9b\xd0\x93\xd0\xa2\xd0\x9c")); + EXPECT_TRUE(Utf8IsValid("\xe2\x98\xba\xe2\x98\xbb\xe2\x98\xb9")); + EXPECT_TRUE(Utf8IsValid("a\ufffdb")); + EXPECT_TRUE(Utf8IsValid("\xf4\x8f\xbf\xbf")); + + EXPECT_FALSE(Utf8IsValid("\x42\xfa")); + EXPECT_FALSE(Utf8IsValid("\x42\xfa\x43")); + EXPECT_FALSE(Utf8IsValid("\xf4\x90\x80\x80")); + EXPECT_FALSE(Utf8IsValid("\xf7\xbf\xbf\xbf")); + EXPECT_FALSE(Utf8IsValid("\xfb\xbf\xbf\xbf\xbf")); + EXPECT_FALSE(Utf8IsValid("\xc0\x80")); + EXPECT_FALSE(Utf8IsValid("\xed\xa0\x80")); + EXPECT_FALSE(Utf8IsValid("\xed\xbf\xbf")); +} + +TEST(Utf8IsValid, Cord) { + EXPECT_TRUE(Utf8IsValid(absl::Cord(""))); + EXPECT_TRUE(Utf8IsValid(absl::Cord("a"))); + EXPECT_TRUE(Utf8IsValid(absl::Cord("abc"))); + EXPECT_TRUE(Utf8IsValid(absl::Cord("\xd0\x96"))); + EXPECT_TRUE(Utf8IsValid(absl::Cord("\xd0\x96\xd0\x96"))); + EXPECT_TRUE(Utf8IsValid(absl::Cord( + "\xd0\xb1\xd1\x80\xd1\x8d\xd0\xb4-\xd0\x9b\xd0\x93\xd0\xa2\xd0\x9c"))); + EXPECT_TRUE(Utf8IsValid(absl::Cord("\xe2\x98\xba\xe2\x98\xbb\xe2\x98\xb9"))); + EXPECT_TRUE(Utf8IsValid(absl::Cord("a\ufffdb"))); + EXPECT_TRUE(Utf8IsValid(absl::Cord("\xf4\x8f\xbf\xbf"))); + + EXPECT_FALSE(Utf8IsValid(absl::Cord("\x42\xfa"))); + EXPECT_FALSE(Utf8IsValid(absl::Cord("\x42\xfa\x43"))); + EXPECT_FALSE(Utf8IsValid(absl::Cord("\xf4\x90\x80\x80"))); + EXPECT_FALSE(Utf8IsValid(absl::Cord("\xf7\xbf\xbf\xbf"))); + EXPECT_FALSE(Utf8IsValid(absl::Cord("\xfb\xbf\xbf\xbf\xbf"))); + EXPECT_FALSE(Utf8IsValid(absl::Cord("\xc0\x80"))); + EXPECT_FALSE(Utf8IsValid(absl::Cord("\xed\xa0\x80"))); + EXPECT_FALSE(Utf8IsValid(absl::Cord("\xed\xbf\xbf"))); +} + +TEST(Utf8CodePointCount, String) { + EXPECT_EQ(Utf8CodePointCount("abcd"), 4); + EXPECT_EQ(Utf8CodePointCount("1,2,3,4"), 7); + EXPECT_EQ(Utf8CodePointCount("\xe2\x98\xba\xe2\x98\xbb\xe2\x98\xb9"), 3); + EXPECT_EQ(Utf8CodePointCount(absl::string_view("\xe2\x00", 2)), 2); + EXPECT_EQ(Utf8CodePointCount("\xe2\x80"), 2); + EXPECT_EQ(Utf8CodePointCount("a\xe2\x80"), 3); +} + +TEST(Utf8CodePointCount, Cord) { + EXPECT_EQ(Utf8CodePointCount(absl::Cord("abcd")), 4); + EXPECT_EQ(Utf8CodePointCount(absl::Cord("1,2,3,4")), 7); + EXPECT_EQ( + Utf8CodePointCount(absl::Cord("\xe2\x98\xba\xe2\x98\xbb\xe2\x98\xb9")), + 3); + EXPECT_EQ(Utf8CodePointCount(absl::Cord(absl::string_view("\xe2\x00", 2))), + 2); + EXPECT_EQ(Utf8CodePointCount(absl::Cord("\xe2\x80")), 2); + EXPECT_EQ(Utf8CodePointCount(absl::Cord("a\xe2\x80")), 3); +} + +TEST(Utf8Validate, String) { + EXPECT_TRUE(Utf8Validate("").second); + EXPECT_TRUE(Utf8Validate("a").second); + EXPECT_TRUE(Utf8Validate("abc").second); + EXPECT_TRUE(Utf8Validate("\xd0\x96").second); + EXPECT_TRUE(Utf8Validate("\xd0\x96\xd0\x96").second); + EXPECT_TRUE( + Utf8Validate( + "\xd0\xb1\xd1\x80\xd1\x8d\xd0\xb4-\xd0\x9b\xd0\x93\xd0\xa2\xd0\x9c") + .second); + EXPECT_TRUE(Utf8Validate("\xe2\x98\xba\xe2\x98\xbb\xe2\x98\xb9").second); + EXPECT_TRUE(Utf8Validate("a\ufffdb").second); + EXPECT_TRUE(Utf8Validate("\xf4\x8f\xbf\xbf").second); + + EXPECT_FALSE(Utf8Validate("\x42\xfa").second); + EXPECT_FALSE(Utf8Validate("\x42\xfa\x43").second); + EXPECT_FALSE(Utf8Validate("\xf4\x90\x80\x80").second); + EXPECT_FALSE(Utf8Validate("\xf7\xbf\xbf\xbf").second); + EXPECT_FALSE(Utf8Validate("\xfb\xbf\xbf\xbf\xbf").second); + EXPECT_FALSE(Utf8Validate("\xc0\x80").second); + EXPECT_FALSE(Utf8Validate("\xed\xa0\x80").second); + EXPECT_FALSE(Utf8Validate("\xed\xbf\xbf").second); + + EXPECT_EQ(Utf8Validate("abcd").first, 4); + EXPECT_EQ(Utf8Validate("1,2,3,4").first, 7); + EXPECT_EQ(Utf8Validate("\xe2\x98\xba\xe2\x98\xbb\xe2\x98\xb9").first, 3); + EXPECT_EQ(Utf8Validate(absl::string_view("\xe2\x00", 2)).first, 0); + EXPECT_EQ(Utf8Validate("\xe2\x80").first, 0); + EXPECT_EQ(Utf8Validate("a\xe2\x80").first, 1); +} + +TEST(Utf8Validate, Cord) { + EXPECT_TRUE(Utf8Validate(absl::Cord("")).second); + EXPECT_TRUE(Utf8Validate(absl::Cord("a")).second); + EXPECT_TRUE(Utf8Validate(absl::Cord("abc")).second); + EXPECT_TRUE(Utf8Validate(absl::Cord("\xd0\x96")).second); + EXPECT_TRUE(Utf8Validate(absl::Cord("\xd0\x96\xd0\x96")).second); + EXPECT_TRUE(Utf8Validate(absl::Cord("\xd0\xb1\xd1\x80\xd1\x8d\xd0\xb4-" + "\xd0\x9b\xd0\x93\xd0\xa2\xd0\x9c")) + .second); + EXPECT_TRUE( + Utf8Validate(absl::Cord("\xe2\x98\xba\xe2\x98\xbb\xe2\x98\xb9")).second); + EXPECT_TRUE(Utf8Validate(absl::Cord("a\ufffdb")).second); + EXPECT_TRUE(Utf8Validate(absl::Cord("\xf4\x8f\xbf\xbf")).second); + + EXPECT_FALSE(Utf8Validate(absl::Cord("\x42\xfa")).second); + EXPECT_FALSE(Utf8Validate(absl::Cord("\x42\xfa\x43")).second); + EXPECT_FALSE(Utf8Validate(absl::Cord("\xf4\x90\x80\x80")).second); + EXPECT_FALSE(Utf8Validate(absl::Cord("\xf7\xbf\xbf\xbf")).second); + EXPECT_FALSE(Utf8Validate(absl::Cord("\xfb\xbf\xbf\xbf\xbf")).second); + EXPECT_FALSE(Utf8Validate(absl::Cord("\xc0\x80")).second); + EXPECT_FALSE(Utf8Validate(absl::Cord("\xed\xa0\x80")).second); + EXPECT_FALSE(Utf8Validate(absl::Cord("\xed\xbf\xbf")).second); + + EXPECT_EQ(Utf8Validate(absl::Cord("abcd")).first, 4); + EXPECT_EQ(Utf8Validate(absl::Cord("1,2,3,4")).first, 7); + EXPECT_EQ( + Utf8Validate(absl::Cord("\xe2\x98\xba\xe2\x98\xbb\xe2\x98\xb9")).first, + 3); + EXPECT_EQ(Utf8Validate(absl::Cord(absl::string_view("\xe2\x00", 2))).first, + 0); + EXPECT_EQ(Utf8Validate(absl::Cord("\xe2\x80")).first, 0); + EXPECT_EQ(Utf8Validate(absl::Cord("a\xe2\x80")).first, 1); +} + +struct Utf8EncodeTestCase final { + char32_t code_point; + absl::string_view code_units; +}; + +using Utf8EncodeTest = testing::TestWithParam; + +TEST_P(Utf8EncodeTest, Compliance) { + const Utf8EncodeTestCase& test_case = GetParam(); + std::string result; + EXPECT_EQ(Utf8Encode(result, test_case.code_point), + test_case.code_units.size()); + EXPECT_EQ(result, test_case.code_units); +} + +INSTANTIATE_TEST_SUITE_P(Utf8EncodeTest, Utf8EncodeTest, + testing::ValuesIn({ + {0x0000, absl::string_view("\x00", 1)}, + {0x0001, "\x01"}, + {0x007e, "\x7e"}, + {0x007f, "\x7f"}, + {0x0080, "\xc2\x80"}, + {0x0081, "\xc2\x81"}, + {0x00bf, "\xc2\xbf"}, + {0x00c0, "\xc3\x80"}, + {0x00c1, "\xc3\x81"}, + {0x00c8, "\xc3\x88"}, + {0x00d0, "\xc3\x90"}, + {0x00e0, "\xc3\xa0"}, + {0x00f0, "\xc3\xb0"}, + {0x00f8, "\xc3\xb8"}, + {0x00ff, "\xc3\xbf"}, + {0x0100, "\xc4\x80"}, + {0x07ff, "\xdf\xbf"}, + {0x0400, "\xd0\x80"}, + {0x0800, "\xe0\xa0\x80"}, + {0x0801, "\xe0\xa0\x81"}, + {0x1000, "\xe1\x80\x80"}, + {0xd000, "\xed\x80\x80"}, + {0xd7ff, "\xed\x9f\xbf"}, + {0xe000, "\xee\x80\x80"}, + {0xfffe, "\xef\xbf\xbe"}, + {0xffff, "\xef\xbf\xbf"}, + {0x10000, "\xf0\x90\x80\x80"}, + {0x10001, "\xf0\x90\x80\x81"}, + {0x40000, "\xf1\x80\x80\x80"}, + {0x10fffe, "\xf4\x8f\xbf\xbe"}, + {0x10ffff, "\xf4\x8f\xbf\xbf"}, + {0xFFFD, "\xef\xbf\xbd"}, + })); + +struct Utf8DecodeTestCase final { + char32_t code_point; + absl::string_view code_units; +}; + +using Utf8DecodeTest = testing::TestWithParam; + +TEST_P(Utf8DecodeTest, StringView) { + const Utf8DecodeTestCase& test_case = GetParam(); + auto [code_point, code_units] = Utf8Decode(test_case.code_units); + EXPECT_EQ(code_units, test_case.code_units.size()) + << absl::CHexEscape(test_case.code_units); + EXPECT_EQ(code_point, test_case.code_point) + << absl::CHexEscape(test_case.code_units); + EXPECT_EQ(Utf8Decode(test_case.code_units, nullptr), + test_case.code_units.size()); +} + +TEST_P(Utf8DecodeTest, Cord) { + const Utf8DecodeTestCase& test_case = GetParam(); + auto cord = absl::Cord(test_case.code_units); + auto it = cord.char_begin(); + auto [code_point, code_units] = Utf8Decode(it); + absl::Cord::Advance(&it, code_units); + EXPECT_EQ(it, cord.char_end()); + EXPECT_EQ(code_units, test_case.code_units.size()) + << absl::CHexEscape(test_case.code_units); + EXPECT_EQ(code_point, test_case.code_point) + << absl::CHexEscape(test_case.code_units); + it = cord.char_begin(); + EXPECT_EQ(Utf8Decode(it, nullptr), test_case.code_units.size()); +} + +std::vector FragmentString(absl::string_view text) { + std::vector fragments; + fragments.reserve(text.size()); + for (const auto& c : text) { + fragments.emplace_back().push_back(c); + } + return fragments; +} + +TEST_P(Utf8DecodeTest, CordFragmented) { + const Utf8DecodeTestCase& test_case = GetParam(); + auto cord = absl::MakeFragmentedCord(FragmentString(test_case.code_units)); + auto it = cord.char_begin(); + auto [code_point, code_units] = Utf8Decode(it); + absl::Cord::Advance(&it, code_units); + EXPECT_EQ(it, cord.char_end()); + EXPECT_EQ(code_units, test_case.code_units.size()) + << absl::CHexEscape(test_case.code_units); + EXPECT_EQ(code_point, test_case.code_point) + << absl::CHexEscape(test_case.code_units); +} + +INSTANTIATE_TEST_SUITE_P(Utf8DecodeTest, Utf8DecodeTest, + testing::ValuesIn({ + {0x0000, absl::string_view("\x00", 1)}, + {0x0001, "\x01"}, + {0x007e, "\x7e"}, + {0x007f, "\x7f"}, + {0x0080, "\xc2\x80"}, + {0x0081, "\xc2\x81"}, + {0x00bf, "\xc2\xbf"}, + {0x00c0, "\xc3\x80"}, + {0x00c1, "\xc3\x81"}, + {0x00c8, "\xc3\x88"}, + {0x00d0, "\xc3\x90"}, + {0x00e0, "\xc3\xa0"}, + {0x00f0, "\xc3\xb0"}, + {0x00f8, "\xc3\xb8"}, + {0x00ff, "\xc3\xbf"}, + {0x0100, "\xc4\x80"}, + {0x07ff, "\xdf\xbf"}, + {0x0400, "\xd0\x80"}, + {0x0800, "\xe0\xa0\x80"}, + {0x0801, "\xe0\xa0\x81"}, + {0x1000, "\xe1\x80\x80"}, + {0xd000, "\xed\x80\x80"}, + {0xd7ff, "\xed\x9f\xbf"}, + {0xe000, "\xee\x80\x80"}, + {0xfffe, "\xef\xbf\xbe"}, + {0xffff, "\xef\xbf\xbf"}, + {0x10000, "\xf0\x90\x80\x80"}, + {0x10001, "\xf0\x90\x80\x81"}, + {0x40000, "\xf1\x80\x80\x80"}, + {0x10fffe, "\xf4\x8f\xbf\xbe"}, + {0x10ffff, "\xf4\x8f\xbf\xbf"}, + {0xFFFD, "\xef\xbf\xbd"}, + })); + +void BM_Utf8CodePointCount_String_AsciiTen(benchmark::State& state) { + for (auto s : state) { + benchmark::DoNotOptimize(Utf8CodePointCount("0123456789")); + } +} + +BENCHMARK(BM_Utf8CodePointCount_String_AsciiTen); + +void BM_Utf8CodePointCount_Cord_AsciiTen(benchmark::State& state) { + absl::Cord value("0123456789"); + for (auto s : state) { + benchmark::DoNotOptimize(Utf8CodePointCount(value)); + } +} + +BENCHMARK(BM_Utf8CodePointCount_Cord_AsciiTen); + +void BM_Utf8CodePointCount_String_JapaneseTen(benchmark::State& state) { + for (auto s : state) { + benchmark::DoNotOptimize(Utf8CodePointCount( + "\xe6\x97\xa5\xe6\x9c\xac\xe8\xaa\x9e\xe6\x97\xa5\xe6\x9c\xac\xe8\xaa" + "\x9e\xe6\x97\xa5\xe6\x9c\xac\xe8\xaa\x9e\xe6\x97\xa5")); + } +} + +BENCHMARK(BM_Utf8CodePointCount_String_JapaneseTen); + +void BM_Utf8CodePointCount_Cord_JapaneseTen(benchmark::State& state) { + absl::Cord value( + "\xe6\x97\xa5\xe6\x9c\xac\xe8\xaa\x9e\xe6\x97\xa5\xe6\x9c\xac\xe8\xaa" + "\x9e\xe6\x97\xa5\xe6\x9c\xac\xe8\xaa\x9e\xe6\x97\xa5"); + for (auto s : state) { + benchmark::DoNotOptimize(Utf8CodePointCount(value)); + } +} + +BENCHMARK(BM_Utf8CodePointCount_Cord_JapaneseTen); + +void BM_Utf8IsValid_String_AsciiTen(benchmark::State& state) { + for (auto s : state) { + benchmark::DoNotOptimize(Utf8IsValid("0123456789")); + } +} + +BENCHMARK(BM_Utf8IsValid_String_AsciiTen); + +void BM_Utf8IsValid_Cord_AsciiTen(benchmark::State& state) { + absl::Cord value("0123456789"); + for (auto s : state) { + benchmark::DoNotOptimize(Utf8IsValid(value)); + } +} + +BENCHMARK(BM_Utf8IsValid_Cord_AsciiTen); + +void BM_Utf8IsValid_String_JapaneseTen(benchmark::State& state) { + for (auto s : state) { + benchmark::DoNotOptimize(Utf8IsValid( + "\xe6\x97\xa5\xe6\x9c\xac\xe8\xaa\x9e\xe6\x97\xa5\xe6\x9c\xac\xe8\xaa" + "\x9e\xe6\x97\xa5\xe6\x9c\xac\xe8\xaa\x9e\xe6\x97\xa5")); + } +} + +BENCHMARK(BM_Utf8IsValid_String_JapaneseTen); + +void BM_Utf8IsValid_Cord_JapaneseTen(benchmark::State& state) { + absl::Cord value( + "\xe6\x97\xa5\xe6\x9c\xac\xe8\xaa\x9e\xe6\x97\xa5\xe6\x9c\xac\xe8\xaa" + "\x9e\xe6\x97\xa5\xe6\x9c\xac\xe8\xaa\x9e\xe6\x97\xa5"); + for (auto s : state) { + benchmark::DoNotOptimize(Utf8IsValid(value)); + } +} + +BENCHMARK(BM_Utf8IsValid_Cord_JapaneseTen); + +void BM_Utf8Validate_String_AsciiTen(benchmark::State& state) { + for (auto s : state) { + benchmark::DoNotOptimize(Utf8Validate("0123456789")); + } +} + +BENCHMARK(BM_Utf8Validate_String_AsciiTen); + +void BM_Utf8Validate_Cord_AsciiTen(benchmark::State& state) { + absl::Cord value("0123456789"); + for (auto s : state) { + benchmark::DoNotOptimize(Utf8Validate(value)); + } +} + +BENCHMARK(BM_Utf8Validate_Cord_AsciiTen); + +void BM_Utf8Validate_String_JapaneseTen(benchmark::State& state) { + for (auto s : state) { + benchmark::DoNotOptimize(Utf8Validate( + "\xe6\x97\xa5\xe6\x9c\xac\xe8\xaa\x9e\xe6\x97\xa5\xe6\x9c\xac\xe8\xaa" + "\x9e\xe6\x97\xa5\xe6\x9c\xac\xe8\xaa\x9e\xe6\x97\xa5")); + } +} + +BENCHMARK(BM_Utf8Validate_String_JapaneseTen); + +void BM_Utf8Validate_Cord_JapaneseTen(benchmark::State& state) { + absl::Cord value( + "\xe6\x97\xa5\xe6\x9c\xac\xe8\xaa\x9e\xe6\x97\xa5\xe6\x9c\xac\xe8\xaa" + "\x9e\xe6\x97\xa5\xe6\x9c\xac\xe8\xaa\x9e\xe6\x97\xa5"); + for (auto s : state) { + benchmark::DoNotOptimize(Utf8Validate(value)); + } +} + +BENCHMARK(BM_Utf8Validate_Cord_JapaneseTen); + +} // namespace +} // namespace cel::internal diff --git a/internal/value_internal.h b/internal/value_internal.h deleted file mode 100644 index f9cee1334..000000000 --- a/internal/value_internal.h +++ /dev/null @@ -1,398 +0,0 @@ -// Internal declarations for value.h/.cc. - -#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_VALUE_INTERNAL_H_H_ -#define THIRD_PARTY_CEL_CPP_INTERNAL_VALUE_INTERNAL_H_H_ - -#include -#include "absl/types/optional.h" -#include "absl/types/variant.h" -#include "common/enum.h" -#include "common/error.h" -#include "common/id.h" -#include "common/parent_ref.h" -#include "common/type.h" -#include "common/unknown.h" -#include "internal/adapter_util.h" -#include "internal/cast.h" -#include "internal/ref_countable.h" - -namespace google { -namespace api { -namespace expr { - -namespace common { - -class List; -class Map; -class Object; - -} // namespace common - -namespace internal { - -// The types stored directly in ValueData. -using InlineTypes = types; -// The types stored by reffed copy in ValueData. -using RefCopyTypes = types; - -template -using is_shared_value = std::is_convertible; - -// If a type can be returned by pointer or reference. -template -using is_direct_type = - or_t, type_in, is_shared_value>; - -// If a type can be stored in a 'Value'. -template -using is_value_type = or_t, is_numeric, is_string>; - -class BaseValue { - protected: - // The return type of `Value::get_if` call for T. - template - using GetIfType = - conditional_t::value, const T*, absl::optional>; - - // The return type of `Value::get` call for T. - template - using GetType = conditional_t::value, const T&, T>; - - // If the string is 8 bytes, it is assumed to be copy-on-write and is stored - // inline. Otherwise it is held in a ref-counted container. - using OwnedStr = SizeLimitHolder; - - // Holds the string_view and parent in a refcounted container. - using ParentOwnedStr = - Holder>>; - - // Holds the string_view in a ref counted container. - using UnownedStr = RefCopyHolder; - - // The variant type used by `expr::Value`. - using ValueData = absl::variant< - // Small types can be stored inline. - CopyHolder, // null_value - CopyHolder, // bool_value - CopyHolder, // int_value - CopyHolder, // uint_value - CopyHolder, // double_value - CopyHolder, // enum_value - CopyHolder, // type_value - CopyHolder, // type_value - CopyHolder, // type_value - CopyHolder, // unknown of a single id. - - // Other types are stored as unowned, owned pointers. - // An intrusive shared pointer is used to minimize the variant size. - OwnedStr, ParentOwnedStr, UnownedStr, // string_value - OwnedStr, ParentOwnedStr, UnownedStr, // bytes_value - RefPtrHolder, - UnownedPtrHolder, // map_value - RefPtrHolder, - UnownedPtrHolder, // list_value - RefPtrHolder, - UnownedPtrHolder, // object_value - RefCopyHolder, // duration - RefCopyHolder, // time - RefCopyHolder, // enum_value - RefCopyHolder, // type_value - RefCopyHolder, // error - RefCopyHolder>; // unknown - - // The public types associated with each Value::Kind entry. - using KindToType = types; // kUnknown - - enum Index { - kNull, // null_value - kBool, // bool_value - kInt, // int_value - kUInt, // uint_value - kDouble, // double_value - kNamedEnum, // enum_value - kBasicType, // type_value - kObjectType, // type_value - kEnumType, // type_value - kId, // unknown - kInlineEnd, - - kStr = kInlineEnd, // string_value - kStrView, // string_value - kStrPtr, // string_value - kBytes, // bytes_value - kBytesView, // bytes_value - kBytesPtr, // bytes_value - kMap, // map_value - kMapPtr, // map_value - kList, // list_value - kListPtr, // list_value - kObject, // object_value - kObjectPtr, // object_value - kOptionalOwnershipEnd, - - kDuration = kOptionalOwnershipEnd, // duration - kTime, // time - kObjectEnd, - kUnnamedEnum = kObjectEnd, // enum_value - kUnrecognizedType, // type_value - kValueEnd, - - kError = kValueEnd, // error - kUnknown, // unknown - DATA_SIZE, - }; - - /** - * A helper function to grab either the unowned or owned string value. - * - * Assumes the owned value is immediately after the unowned value. - * - * @tparam I the index of the unowned value. - */ - template - static absl::string_view GetStr(const ValueData& data); - - BaseValue() = default; - ~BaseValue() = default; - - private: - friend class ValueVisitTest; - friend class ValueAdapterTest; - - template - using NumericValueType = - conditional_t::value, int64_t, - conditional_t::value, uint64_t, double>>; - template - using CustomValueType = conditional_t< - std::is_convertible::value, common::Map, - conditional_t::value, - common::List, common::Object>>; - - template - using HolderType = conditional_t::value, - CopyHolder, RefCopyHolder>; - - // A ValueData visitor that, for any type in 'Alts', returns the associated - // value as T. - template - using GetVisitor = OrderedVisitor< - ConvertAdapter, T>, Fallback>; - - template - using GetPtrVisitor = - OrderedVisitor, Fallback>; - - // Base class for types that are stored as one of several other types. - template - struct BaseTypeHelper { - static absl::optional get_if(const ValueData* data); - static T get(const ValueData& data); - }; - - /** - * An adapter that converts private Value internal types into public types. - */ - struct ValueDataAdapter { - // Convert the single id case into Unknown. - common::Unknown operator()(const common::Id& value) { - return common::Unknown(value); - } - - // Normalize string values to string_view - template - specialize_ift, absl::string_view> operator()(T& value); - }; - - protected: - using ValueAdapter = CompositeAdapter; - - /** The adapted visitor type. */ - template - using AdaptedVisitor = VisitorAdapter; - - /** - * The result of applying an adapted visitor to Value. - * - * Values must contain only Value types. - */ - template - using VisitType = - decltype(absl::visit(inst_of&&>(), - inst_of()...)); - - // A collection of helper functions for the given type. - template - struct TypeHelper : BaseTypeHelper {}; -}; - -template -absl::string_view BaseValue::GetStr(const ValueData& data) { - switch (data.index()) { - case I: // Owned - return *absl::get(data); - case I + 1: // Parent owned - return *absl::get(data); - default: // Unowned - return *absl::get(data); - } -} - -// Base class for types that are stored directly. -template -struct BaseValue::BaseTypeHelper { - static const T* get_if(const ValueData* data); - static const T& get(const ValueData& data); -}; - -// Base class for types that are stored as a runtime compatible type. -template -struct BaseValue::BaseTypeHelper { - static absl::optional get_if(const ValueData* data); - static T get(const ValueData& data); -}; - -// Base class for SharedValue types. -template -struct BaseValue::BaseTypeHelper { - using C = CustomValueType; - static const T* get_if(const ValueData* data); - static const T& get(const ValueData& data); -}; - -// Specialization for CustomValues. -template -struct BaseValue::TypeHelper>> - : BaseTypeHelper {}; - -// Specialization for numeric types. -template -struct BaseValue::TypeHelper>> - : BaseTypeHelper> {}; - -template <> -struct BaseValue::TypeHelper - : BaseTypeHelper {}; - -template <> -struct BaseValue::TypeHelper - : BaseTypeHelper {}; - -template <> -struct BaseValue::TypeHelper - : BaseTypeHelper {}; -template -specialize_ift, absl::string_view> BaseValue::ValueDataAdapter:: -operator()(T& value) { - return value; -} - -template -absl::optional BaseValue::BaseTypeHelper::get_if( - const ValueData* data) { - using R = absl::optional; - return absl::visit(GetVisitor, R, F, Alts...>(), *data); -} - -template -T BaseValue::BaseTypeHelper::get(const ValueData& data) { - return absl::visit(GetVisitor, T, F, Alts...>(), data); -} - -template -const T* BaseValue::BaseTypeHelper::get_if(const ValueData* data) { - if (auto holder = absl::get_if>(data)) { - return &holder->value(); - } - return nullptr; -} - -template -const T& BaseValue::BaseTypeHelper::get(const ValueData& data) { - return *absl::get>(data); -} - -template -absl::optional BaseValue::BaseTypeHelper::get_if( - const ValueData* data) { - auto value = BaseTypeHelper::get_if(data); - if (!value || !representable_as(*value)) { - return absl::nullopt; - } - return absl::optional(absl::in_place, *value); -} - -template -T BaseValue::BaseTypeHelper::get(const ValueData& data) { - const N& value = BaseTypeHelper::get(data); - if (!representable_as(value)) { - // Throw bad_variant_access without using `throw` keyword. - return absl::get<0>(absl::variant(absl::in_place_index<1>, 1)); - } - return T(value); -} - -template -const T* BaseValue::BaseTypeHelper::get_if( - const ValueData* data) { - return C::template cast_if( - absl::visit(GetPtrVisitor, C>(), *data)); -} - -template -const T& BaseValue::BaseTypeHelper::get( - const ValueData& data) { - const T* value = get_if(&data); - if (value == nullptr) { - // Throw bad_variant_access without using `throw` keyword. - return *absl::get<0>(absl::variant(absl::in_place_index<1>, 1)); - } - return *value; -} - -} // namespace internal -} // namespace expr -} // namespace api -} // namespace google - -// Hash specialization for parented owned values. -namespace std { -template -struct hash> { - std::size_t operator()( - const std::pair& value) { - return google::api::expr::internal::Hash(*value.second); - } -}; - -template -struct hash> { - std::size_t operator()( - const std::pair& value) { - return google::api::expr::internal::Hash(value.second); - } -}; - -} // namespace std - -#endif // THIRD_PARTY_CEL_CPP_INTERNAL_VALUE_INTERNAL_H_H_ diff --git a/internal/value_internal_test.cc b/internal/value_internal_test.cc deleted file mode 100644 index 4cf8ccc62..000000000 --- a/internal/value_internal_test.cc +++ /dev/null @@ -1,117 +0,0 @@ -#include "internal/value_internal.h" - -#include "gtest/gtest.h" -#include "absl/memory/memory.h" -#include "absl/strings/escaping.h" -#include "absl/strings/str_cat.h" -#include "testutil/util.h" - -namespace google { -namespace api { -namespace expr { -namespace internal { - -using testutil::ExpectSameType; - -struct RefOnlyType : public RefCountable { - public: - RefOnlyType() = default; - RefOnlyType(const RefOnlyType&) = delete; - RefOnlyType(RefOnlyType&&) = delete; - RefOnlyType& operator=(const RefOnlyType&) = delete; - RefOnlyType& operator=(RefOnlyType&&) = delete; - - bool operator==(const RefOnlyType& rhs) const { return this == &rhs; } -}; - -class ValueAdapterTest : public ::testing::Test { - public: - using ValueData = BaseValue::ValueData; - using ValueAdapter = BaseValue::ValueAdapter; - template - using BaseTypeHelper = BaseValue::BaseTypeHelper; - template - using TypeHelper = BaseValue::TypeHelper; - using UnownedStr = BaseValue::UnownedStr; - using OwnedStr = BaseValue::OwnedStr; - - template - void TestAdapter(T&& value, E&& expected) { - // Verify that the adapter produces the expected type. - ExpectSameType(value)))>(); - - EXPECT_EQ(expected, MaybeAdapt(ValueAdapter(), std::forward(value))); - } - - template - void TestValueAdapter(T&& value, E&& expected) { - ExpectSameType(value)))>(); - EXPECT_EQ(expected, ValueAdapter()(value)); - TestAdapter(std::forward(value), - std::forward(expected)); - } -}; - -TEST_F(ValueAdapterTest, Bool) { - ExpectSameType, BaseTypeHelper>(); - ExpectSameType::get_if( - inst_of()))>(); - ExpectSameType::get_if( - inst_of()))>(); - static_assert(!is_numeric::value, "blah"); -} - -TEST_F(ValueAdapterTest, NullPtr) { - // nullptr passes through by value, unchanged - TestValueAdapter(nullptr, nullptr); -} - -TEST_F(ValueAdapterTest, RefPtr) { - // Smart pointer is dereferenced. - auto ptr = RefPtrHolder(new RefOnlyType()); - TestAdapter(ptr, *ptr); - auto cptr = RefPtrHolder(new RefOnlyType()); - TestAdapter(cptr, *cptr); -} - -TEST_F(ValueAdapterTest, RefCopy) { - auto i = RefCopyHolder(1); - TestAdapter(i, 1); - - const auto& const_ptr_int = i; - TestAdapter(const_ptr_int, 1); - - auto ptr_const_int = RefCopyHolder(2); - TestAdapter(ptr_const_int, 2); -} - -TEST_F(ValueAdapterTest, String) { - // Strings are normalized to string_view. - absl::string_view view("hi"); - TestAdapter(view, "hi"); - std::string value = "hi"; - TestAdapter(value, "hi"); - - const std::string& cvalue = value; - TestAdapter(cvalue, "hi"); - - UnownedStr unowned(cvalue); - TestAdapter(unowned, "hi"); - - using ParentOwnedStrPolicy = Ref, Copy>>; - using ParentOwnedStr = Holder; - auto parent = MakeReffed(); - ParentOwnedStr parent_owned(parent, cvalue); - TestValueAdapter(parent_owned, "hi"); - - OwnedStr owned("hi"); - TestAdapter(owned, "hi"); -} - -} // namespace internal -} // namespace expr -} // namespace api -} // namespace google diff --git a/internal/visitor_util.h b/internal/visitor_util.h deleted file mode 100644 index 6a80b8208..000000000 --- a/internal/visitor_util.h +++ /dev/null @@ -1,189 +0,0 @@ -/** - * Utilities for working with visitor patterns. - * - * Terminology: - * - A 'visitor' is a callable that can be passed arguments and returns - * a result. - * - A 'mixin visitor' is a visitor only handles a subset of cases. These can - * be merged using `OrderedVisitor`. - * - * The primary utilities provided in this library include: - * - VisitResultType: The return type of apply a visitor to a list of arguments. - * - MaybeVisit/MaybeVisitResult: Uses the given visitor if - * possible, otherwise uses a given fallback visitor. - * - OrderedVisitor: A visitor that calls the first provided visitor that - * supports the given arguments. Useful for mixin visitor or to resolve - * call ambiguities. - */ - -#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_VISITOR_UTIL_H_ -#define THIRD_PARTY_CEL_CPP_INTERNAL_VISITOR_UTIL_H_ - -#include - -#include "absl/strings/string_view.h" -#include "absl/types/optional.h" -#include "absl/types/variant.h" -#include "internal/specialize.h" -#include "internal/types.h" - -namespace google { -namespace api { -namespace expr { -namespace internal { - -/** - * The result type of applying the visitor once. - * - * Undefined when a matching overload is not present or *ambiguous*. - */ -template -using VisitResultType = typename std::result_of::type; - -// Specialization for when Visitor can be applied to the arguments. -template -VisitResultType MaybeVisitImpl(Visitor&& vis, - Fallback&& fallback, - specialize, Args&&... args) { - return vis(std::forward(args)...); -} - -// Specialization for when a specialized Visitor can be applied to the -// arguments. -template -VisitResultType MaybeVisitImpl( - Visitor&& vis, Fallback&& fallback, specialize, Args&&... args) { - return vis(std::forward(args)..., specialize()); -} - -// Specialization for when Visitor cannot be applied to the arguments. -template -VisitResultType MaybeVisitImpl(Visitor&& vis, - Fallback&& fallback, general, - Args&&... args) { - return fallback(std::forward(args)...); -} - -/** The type returned by a call to MaybeVisit. */ -template -using MaybeVisitResultType = - decltype(MaybeVisitImpl(inst_of(), inst_of(), - specialize(), inst_of()...)); - -/** - * A helper function that tries to apply Visitor, and falls back on Fallback - * when it cannot. - */ -template -MaybeVisitResultType MaybeVisit(Visitor&& vis, - Fallback&& fallback, - Args&&... args) { - return MaybeVisitImpl(std::forward(vis), - std::forward(fallback), specialize(), - std::forward(args)...); -} - -/** A visitor that delegates to the first applicable visitor. */ -template -class OrderedVisitor; - -// Only a single visitor, so expose it direclty. -template -class OrderedVisitor : public Visitor { - public: - OrderedVisitor() = default; - OrderedVisitor(const OrderedVisitor&) = default; - OrderedVisitor(OrderedVisitor&&) = default; - explicit OrderedVisitor(Visitor&& vis) - : Visitor(std::forward(vis)) {} - - template - using ResultType = VisitResultType; -}; - -// Multiple visitors, so pull of the head, and construct a new visitor from -// the tail. Then use MaybeVisit to try and visit Head first. -template -class OrderedVisitor { - private: - using Visitor = Head; - using Fallback = OrderedVisitor; - - public: - template - using ResultType = MaybeVisitResultType; - - OrderedVisitor() = default; - OrderedVisitor(const OrderedVisitor&) = default; - OrderedVisitor(OrderedVisitor&&) = default; - OrderedVisitor(Visitor&& vis, Tail&&... fallback) - : vis_(std::forward(vis)), - fallback_(std::forward(fallback)...) {} - - template - ResultType operator()(Args&&... args) { - return MaybeVisit(vis_, fallback_, std::forward(args)...); - } - - private: - Visitor vis_; - Fallback fallback_; -}; - -/** Helper function to construct an OrderedVisitor. */ -template -OrderedVisitor MakeOrderedVisitor(Visitors&&... vis) { - return OrderedVisitor(std::forward(vis)...); -} - -/** A visitor that ignores all arguments and returns the given value. */ -template -struct DefaultVisitor { - template - T operator()(Args&&... args) { - return T(); - } -}; - -// A visitor that throws a absl::bad_variant_access exception. -template -struct BadAccessVisitor { - template - R operator()(Args&&...) { - // Throw bad_variant_access without using `throw` keyword. - return absl::get<0>(absl::variant(absl::in_place_index<1>, 1)); - } -}; - -/** - * A visitor that check the equality of two arguments if they are the same type. - * - * If the arguments are of different types, false is returned, otherwise the - * result of the == operator is returned. - */ -struct StrictEqVisitor { - // Specialization for double that treats NaN as equal. - bool operator()(double lhs, double rhs) { - if (lhs == rhs) { - return true; - } - return std::isnan(lhs) && std::isnan(rhs); - } - - template - bool operator()(const T& lhs, const T& rhs) { - return lhs == rhs; - } - - template - bool operator()(const T& lhs, const S& rhs) { - return false; - } -}; - -} // namespace internal -} // namespace expr -} // namespace api -} // namespace google - -#endif // THIRD_PARTY_CEL_CPP_INTERNAL_VISITOR_UTIL_H_ diff --git a/internal/visitor_util_test.cc b/internal/visitor_util_test.cc deleted file mode 100644 index 1827b4527..000000000 --- a/internal/visitor_util_test.cc +++ /dev/null @@ -1,150 +0,0 @@ -#include "internal/visitor_util.h" -#include "internal/adapter_util.h" - -#include "gtest/gtest.h" -#include "absl/memory/memory.h" -#include "testutil/util.h" - -namespace google { -namespace api { -namespace expr { -namespace internal { -namespace { - -using testutil::ExpectSameType; - -struct NoCopyType { - NoCopyType() = default; - NoCopyType(const NoCopyType&) = delete; - NoCopyType(NoCopyType&&) = default; -}; - -struct MyStruct {}; - -struct TestAdapter { - int operator()(int value) { return value + 1; } - - absl::string_view operator()(const MyStruct& value, general) { - return "general"; - } - absl::string_view operator()(const MyStruct& value, specialize) { - return "specialize"; - } - - template - MaybeAdaptResultType operator()(T* ptr) { - return MaybeAdapt(*this, *ptr); - } -}; - -TEST(MaybeVisitTest, MaybeAdaptResultType) { - // Finds the right function for all convertible types. - ExpectSameType>(); - ExpectSameType>(); - ExpectSameType>(); - ExpectSameType>(); - ExpectSameType>(); - ExpectSameType>(); - ExpectSameType>(); - ExpectSameType>(); - ExpectSameType>(); - ExpectSameType>(); - ExpectSameType>(); - ExpectSameType>(); - - // Even recursively. - ExpectSameType>(); - ExpectSameType>(); - ExpectSameType>(); - ExpectSameType>(); - ExpectSameType>(); - ExpectSameType>(); - ExpectSameType>(); - ExpectSameType>(); - ExpectSameType>(); - ExpectSameType>(); - - // Reference is preserved. - ExpectSameType>(); - // Move is preserved. - ExpectSameType>(); - // Even recursively. - ExpectSameType>(); -} - -TEST(MaybeVisitTest, ExactMatch) { - // Visits when exactly matches. - EXPECT_EQ(MaybeAdapt(TestAdapter(), 1), 2); - // Event recursively. - int i = 9; - EXPECT_EQ(MaybeAdapt(TestAdapter(), &i), 10); -} - -TEST(MaybeVisitTest, Convertible) { - // Visits when convertible. - EXPECT_EQ(MaybeAdapt(TestAdapter(), 2.5), 3); - // Even recursively. - double d = 9.5; - EXPECT_EQ(MaybeAdapt(TestAdapter(), &d), 10); -} - -TEST(MaybeVisitTest, NoOverload) { - // Does not visit mismatched. - EXPECT_EQ(MaybeAdapt(TestAdapter(), std::string("hi")), "hi"); - - // Works with move only. - NoCopyType no_copy; - EXPECT_EQ(&MaybeAdapt(TestAdapter(), no_copy), &no_copy); - - // Even recursively. - EXPECT_EQ(&MaybeAdapt(TestAdapter(), &no_copy), &no_copy); -} - -TEST(MaybeVisitTest, Specialize) { - // Automatically specializes. - EXPECT_EQ(MaybeAdapt(TestAdapter(), MyStruct()), "specialize"); -} - -struct StringVisitor { - absl::string_view operator()(absl::string_view any_string) { - return "string"; - } -}; - -struct IntVisitor { - absl::string_view operator()(int any_int) { return "int"; } -}; - -struct DoubleVisitor { - absl::string_view operator()(double any_double) { return "double"; } -}; - -struct NoCopyVisitor { - absl::string_view operator()(const NoCopyType& any_int) { - return "NoCopyType"; - } -}; - -TEST(OrderedVisitorTest, OrderedVisitor) { - auto vis = MakeOrderedVisitor(StringVisitor(), IntVisitor(), DoubleVisitor(), - NoCopyVisitor(), DefaultVisitor()); - EXPECT_EQ("string", vis("hi")); - EXPECT_EQ("int", vis(1)); - // Double is converted to an int, as IntVisitor is higher priority. - EXPECT_EQ("int", vis(2.5)); - EXPECT_EQ("NoCopyType", vis(NoCopyType())); - EXPECT_EQ("", vis(MyStruct())); -} - -} // namespace -} // namespace internal -} // namespace expr -} // namespace api -} // namespace google diff --git a/internal/well_known_types.cc b/internal/well_known_types.cc new file mode 100644 index 000000000..02e50c3e3 --- /dev/null +++ b/internal/well_known_types.cc @@ -0,0 +1,2181 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "internal/well_known_types.h" + +#include +#include +#include +#include +#include +#include + +#include "google/protobuf/any.pb.h" +#include "google/protobuf/duration.pb.h" +#include "google/protobuf/field_mask.pb.h" +#include "google/protobuf/struct.pb.h" +#include "google/protobuf/timestamp.pb.h" +#include "google/protobuf/wrappers.pb.h" +#include "google/protobuf/descriptor.pb.h" +#include "absl/base/attributes.h" +#include "absl/base/call_once.h" +#include "absl/base/no_destructor.h" +#include "absl/base/nullability.h" +#include "absl/base/optimization.h" +#include "absl/functional/overload.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/escaping.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/strings/strip.h" +#include "absl/time/time.h" +#include "absl/types/variant.h" +#include "common/json.h" +#include "common/memory.h" +#include "extensions/protobuf/internal/map_reflection.h" +#include "internal/protobuf_runtime_version.h" +#include "internal/status_macros.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/map_field.h" +#include "google/protobuf/message.h" +#include "google/protobuf/message_lite.h" +#include "google/protobuf/reflection.h" +#include "google/protobuf/util/time_util.h" + +namespace cel::well_known_types { + +namespace { + +using ::google::protobuf::Descriptor; +using ::google::protobuf::DescriptorPool; +using ::google::protobuf::EnumDescriptor; +using ::google::protobuf::FieldDescriptor; +using ::google::protobuf::OneofDescriptor; +using ::google::protobuf::util::TimeUtil; + +using CppStringType = ::google::protobuf::FieldDescriptor::CppStringType; + +FieldDescriptor::Label GetFieldLabel( + const FieldDescriptor* absl_nonnull field) { + if (field->is_required()) { + return FieldDescriptor::LABEL_REQUIRED; + } else if (field->is_repeated()) { + return FieldDescriptor::LABEL_REPEATED; + } else { + return FieldDescriptor::LABEL_OPTIONAL; + } +} + +absl::string_view FlatStringValue( + const StringValue& value ABSL_ATTRIBUTE_LIFETIME_BOUND, + std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) { + return absl::visit( + absl::Overload( + [](absl::string_view string) -> absl::string_view { return string; }, + [&](const absl::Cord& cord) -> absl::string_view { + if (auto flat = cord.TryFlat(); flat) { + return *flat; + } + scratch = static_cast(cord); + return scratch; + }), + AsVariant(value)); +} + +StringValue CopyStringValue(const StringValue& value, + std::string& scratch + ABSL_ATTRIBUTE_LIFETIME_BOUND) { + return absl::visit( + absl::Overload( + [&](absl::string_view string) -> StringValue { + if (string.data() != scratch.data()) { + scratch.assign(string.data(), string.size()); + return scratch; + } + return string; + }, + [](const absl::Cord& cord) -> StringValue { return cord; }), + AsVariant(value)); +} + +BytesValue CopyBytesValue(const BytesValue& value, + std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) { + return absl::visit( + absl::Overload( + [&](absl::string_view string) -> BytesValue { + if (string.data() != scratch.data()) { + scratch.assign(string.data(), string.size()); + return scratch; + } + return string; + }, + [](const absl::Cord& cord) -> BytesValue { return cord; }), + AsVariant(value)); +} + +google::protobuf::Reflection::ScratchSpace& GetScratchSpace() { + static absl::NoDestructor scratch_space; + return *scratch_space; +} + +template +Variant GetStringField(const google::protobuf::Reflection* absl_nonnull reflection, + const google::protobuf::Message& message, + const FieldDescriptor* absl_nonnull field, + CppStringType string_type, + std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) { + ABSL_DCHECK(field->cpp_string_type() == string_type); + switch (string_type) { + case CppStringType::kCord: + return reflection->GetCord(message, field); + case CppStringType::kView: + ABSL_FALLTHROUGH_INTENDED; + case CppStringType::kString: + // Message is guaranteed to be storing as some sort of contiguous array of + // bytes, there is no need to copy. But unfortunately `GetStringView` + // forces taking scratch space. + return reflection->GetStringView(message, field, GetScratchSpace()); + default: + return absl::string_view( + reflection->GetStringReference(message, field, &scratch)); + } +} + +template +Variant GetStringField(const google::protobuf::Message& message, + const FieldDescriptor* absl_nonnull field, + CppStringType string_type, + std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) { + return GetStringField(message.GetReflection(), message, field, + string_type, scratch); +} + +template +Variant GetRepeatedStringField( + const google::protobuf::Reflection* absl_nonnull reflection, + const google::protobuf::Message& message, const FieldDescriptor* absl_nonnull field, + CppStringType string_type, int index, + std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) { + ABSL_DCHECK(field->cpp_string_type() == string_type); + switch (string_type) { + case CppStringType::kView: + ABSL_FALLTHROUGH_INTENDED; + case CppStringType::kString: + // Message is guaranteed to be storing as some sort of contiguous array of + // bytes, there is no need to copy. But unfortunately `GetStringView` + // forces taking scratch space. + return reflection->GetRepeatedStringView(message, field, index, + GetScratchSpace()); + default: + return absl::string_view(reflection->GetRepeatedStringReference( + message, field, index, &scratch)); + } +} + +template +Variant GetRepeatedStringField( + const google::protobuf::Message& message, const FieldDescriptor* absl_nonnull field, + CppStringType string_type, int index, + std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) { + return GetRepeatedStringField(message.GetReflection(), message, + field, string_type, index, scratch); +} + +absl::StatusOr GetMessageTypeByName( + const DescriptorPool* absl_nonnull pool, absl::string_view name) { + const auto* descriptor = pool->FindMessageTypeByName(name); + if (ABSL_PREDICT_FALSE(descriptor == nullptr)) { + return absl::InvalidArgumentError(absl::StrCat( + "descriptor missing for protocol buffer message well known type: ", + name)); + } + return descriptor; +} + +absl::StatusOr GetEnumTypeByName( + const DescriptorPool* absl_nonnull pool, absl::string_view name) { + const auto* descriptor = pool->FindEnumTypeByName(name); + if (ABSL_PREDICT_FALSE(descriptor == nullptr)) { + return absl::InvalidArgumentError(absl::StrCat( + "descriptor missing for protocol buffer enum well known type: ", name)); + } + return descriptor; +} + +absl::StatusOr GetOneofByName( + const Descriptor* absl_nonnull descriptor, absl::string_view name) { + const auto* oneof = descriptor->FindOneofByName(name); + if (ABSL_PREDICT_FALSE(oneof == nullptr)) { + return absl::InvalidArgumentError(absl::StrCat( + "oneof missing for protocol buffer message well known type: ", + descriptor->full_name(), ".", name)); + } + return oneof; +} + +absl::StatusOr GetFieldByNumber( + const Descriptor* absl_nonnull descriptor, int32_t number) { + const auto* field = descriptor->FindFieldByNumber(number); + if (ABSL_PREDICT_FALSE(field == nullptr)) { + return absl::InvalidArgumentError(absl::StrCat( + "field missing for protocol buffer message well known type: ", + descriptor->full_name(), ".", number)); + } + return field; +} + +absl::Status CheckFieldType(const FieldDescriptor* absl_nonnull field, + FieldDescriptor::Type type) { + if (ABSL_PREDICT_FALSE(field->type() != type)) { + return absl::InvalidArgumentError(absl::StrCat( + "unexpected field type for protocol buffer message well known type: ", + field->full_name(), " ", field->type_name())); + } + return absl::OkStatus(); +} + +absl::Status CheckFieldCppType(const FieldDescriptor* absl_nonnull field, + FieldDescriptor::CppType cpp_type) { + if (ABSL_PREDICT_FALSE(field->cpp_type() != cpp_type)) { + return absl::InvalidArgumentError(absl::StrCat( + "unexpected field type for protocol buffer message well known type: ", + field->full_name(), " ", field->cpp_type_name())); + } + return absl::OkStatus(); +} + +absl::string_view LabelToString(FieldDescriptor::Label label) { + switch (label) { + case FieldDescriptor::LABEL_REPEATED: + return "REPEATED"; + case FieldDescriptor::LABEL_REQUIRED: + return "REQUIRED"; + case FieldDescriptor::LABEL_OPTIONAL: + return "OPTIONAL"; + default: + return "ERROR"; + } +} + +absl::Status CheckFieldCardinality(const FieldDescriptor* absl_nonnull field, + FieldDescriptor::Label label) { + if (ABSL_PREDICT_FALSE(GetFieldLabel(field) != label)) { + return absl::InvalidArgumentError(absl::StrCat( + "unexpected field cardinality for protocol buffer message " + "well known type: ", + field->full_name(), " ", LabelToString(GetFieldLabel(field)))); + } + return absl::OkStatus(); +} + +absl::string_view WellKnownTypeToString( + Descriptor::WellKnownType well_known_type) { + switch (well_known_type) { + case Descriptor::WELLKNOWNTYPE_BOOLVALUE: + return "BOOLVALUE"; + case Descriptor::WELLKNOWNTYPE_INT32VALUE: + return "INT32VALUE"; + case Descriptor::WELLKNOWNTYPE_INT64VALUE: + return "INT64VALUE"; + case Descriptor::WELLKNOWNTYPE_UINT32VALUE: + return "UINT32VALUE"; + case Descriptor::WELLKNOWNTYPE_UINT64VALUE: + return "UINT64VALUE"; + case Descriptor::WELLKNOWNTYPE_FLOATVALUE: + return "FLOATVALUE"; + case Descriptor::WELLKNOWNTYPE_DOUBLEVALUE: + return "DOUBLEVALUE"; + case Descriptor::WELLKNOWNTYPE_BYTESVALUE: + return "BYTESVALUE"; + case Descriptor::WELLKNOWNTYPE_STRINGVALUE: + return "STRINGVALUE"; + case Descriptor::WELLKNOWNTYPE_ANY: + return "ANY"; + case Descriptor::WELLKNOWNTYPE_DURATION: + return "DURATION"; + case Descriptor::WELLKNOWNTYPE_TIMESTAMP: + return "TIMESTAMP"; + case Descriptor::WELLKNOWNTYPE_VALUE: + return "VALUE"; + case Descriptor::WELLKNOWNTYPE_LISTVALUE: + return "LISTVALUE"; + case Descriptor::WELLKNOWNTYPE_STRUCT: + return "STRUCT"; + case Descriptor::WELLKNOWNTYPE_FIELDMASK: + return "FIELDMASK"; + default: + return "ERROR"; + } +} + +absl::Status CheckWellKnownType(const Descriptor* absl_nonnull descriptor, + Descriptor::WellKnownType well_known_type) { + if (ABSL_PREDICT_FALSE(descriptor->well_known_type() != well_known_type)) { + return absl::InvalidArgumentError(absl::StrCat( + "expected message to be well known type: ", descriptor->full_name(), + " ", WellKnownTypeToString(descriptor->well_known_type()))); + } + return absl::OkStatus(); +} + +absl::Status CheckFieldWellKnownType( + const FieldDescriptor* absl_nonnull field, + Descriptor::WellKnownType well_known_type) { + ABSL_DCHECK_EQ(field->cpp_type(), FieldDescriptor::CPPTYPE_MESSAGE); + if (ABSL_PREDICT_FALSE(field->message_type()->well_known_type() != + well_known_type)) { + return absl::InvalidArgumentError(absl::StrCat( + "expected message field to be well known type for protocol buffer " + "message well known type: ", + field->full_name(), " ", + WellKnownTypeToString(field->message_type()->well_known_type()))); + } + return absl::OkStatus(); +} + +absl::Status CheckFieldOneof(const FieldDescriptor* absl_nonnull field, + const OneofDescriptor* absl_nonnull oneof, + int index) { + if (ABSL_PREDICT_FALSE(field->containing_oneof() != oneof)) { + return absl::InvalidArgumentError( + absl::StrCat("expected field to be member of oneof for protocol buffer " + "message well known type: ", + field->full_name())); + } + if (ABSL_PREDICT_FALSE(field->index_in_oneof() != index)) { + return absl::InvalidArgumentError(absl::StrCat( + "expected field to have index in oneof of ", index, + " for protocol buffer " + "message well known type: ", + field->full_name(), " oneof_index=", field->index_in_oneof())); + } + return absl::OkStatus(); +} + +absl::Status CheckMapField(const FieldDescriptor* absl_nonnull field) { + if (ABSL_PREDICT_FALSE(!field->is_map())) { + return absl::InvalidArgumentError( + absl::StrCat("expected field to be map for protocol buffer " + "message well known type: ", + field->full_name())); + } + return absl::OkStatus(); +} + +} // namespace + +bool StringValue::ConsumePrefix(absl::string_view prefix) { + return absl::visit(absl::Overload( + [&](absl::string_view& value) { + return absl::ConsumePrefix(&value, prefix); + }, + [&](absl::Cord& cord) { + if (cord.StartsWith(prefix)) { + cord.RemovePrefix(prefix.size()); + return true; + } + return false; + }), + AsVariant(*this)); +} + +StringValue GetStringField(const google::protobuf::Reflection* absl_nonnull reflection, + const google::protobuf::Message& message, + const FieldDescriptor* absl_nonnull field, + std::string& scratch) { + ABSL_DCHECK_EQ(reflection, message.GetReflection()); + ABSL_DCHECK(!field->is_map() && !field->is_repeated()); + ABSL_DCHECK_EQ(field->type(), FieldDescriptor::TYPE_STRING); + ABSL_DCHECK_EQ(field->cpp_type(), FieldDescriptor::CPPTYPE_STRING); + return GetStringField(reflection, message, field, + field->cpp_string_type(), scratch); +} + +BytesValue GetBytesField(const google::protobuf::Reflection* absl_nonnull reflection, + const google::protobuf::Message& message, + const FieldDescriptor* absl_nonnull field, + std::string& scratch) { + ABSL_DCHECK_EQ(reflection, message.GetReflection()); + ABSL_DCHECK(!field->is_map() && !field->is_repeated()); + ABSL_DCHECK_EQ(field->type(), FieldDescriptor::TYPE_BYTES); + ABSL_DCHECK_EQ(field->cpp_type(), FieldDescriptor::CPPTYPE_STRING); + return GetStringField(reflection, message, field, + field->cpp_string_type(), scratch); +} + +StringValue GetRepeatedStringField( + const google::protobuf::Reflection* absl_nonnull reflection, + const google::protobuf::Message& message, const FieldDescriptor* absl_nonnull field, + int index, std::string& scratch) { + ABSL_DCHECK_EQ(reflection, message.GetReflection()); + ABSL_DCHECK(!field->is_map() && field->is_repeated()); + ABSL_DCHECK_EQ(field->type(), FieldDescriptor::TYPE_STRING); + ABSL_DCHECK_EQ(field->cpp_type(), FieldDescriptor::CPPTYPE_STRING); + return GetRepeatedStringField( + reflection, message, field, field->cpp_string_type(), index, scratch); +} + +BytesValue GetRepeatedBytesField( + const google::protobuf::Reflection* absl_nonnull reflection, + const google::protobuf::Message& message, const FieldDescriptor* absl_nonnull field, + int index, std::string& scratch) { + ABSL_DCHECK_EQ(reflection, message.GetReflection()); + ABSL_DCHECK(!field->is_map() && field->is_repeated()); + ABSL_DCHECK_EQ(field->type(), FieldDescriptor::TYPE_BYTES); + ABSL_DCHECK_EQ(field->cpp_type(), FieldDescriptor::CPPTYPE_STRING); + return GetRepeatedStringField( + reflection, message, field, field->cpp_string_type(), index, scratch); +} + +absl::Status NullValueReflection::Initialize( + const DescriptorPool* absl_nonnull pool) { + CEL_ASSIGN_OR_RETURN(const auto* descriptor, + GetEnumTypeByName(pool, "google.protobuf.NullValue")); + return Initialize(descriptor); +} + +absl::Status NullValueReflection::Initialize( + const EnumDescriptor* absl_nonnull descriptor) { + if (descriptor_ != descriptor) { + if (ABSL_PREDICT_FALSE(descriptor->full_name() != + "google.protobuf.NullValue")) { + return absl::InvalidArgumentError(absl::StrCat( + "expected enum to be well known type: ", descriptor->full_name(), + " google.protobuf.NullValue")); + } + descriptor_ = nullptr; + value_ = descriptor->FindValueByNumber(0); + if (ABSL_PREDICT_FALSE(value_ == nullptr)) { + return absl::InvalidArgumentError( + "well known protocol buffer enum missing value: " + "google.protobuf.NullValue.NULL_VALUE"); + } + if (ABSL_PREDICT_FALSE(descriptor->value_count() != 1)) { + std::vector values; + values.reserve(static_cast(descriptor->value_count())); + for (int i = 0; i < descriptor->value_count(); ++i) { + values.push_back(descriptor->value(i)->name()); + } + return absl::InvalidArgumentError( + absl::StrCat("well known protocol buffer enum has multiple values: [", + absl::StrJoin(values, ", "), "]")); + } + descriptor_ = descriptor; + } + return absl::OkStatus(); +} + +absl::Status BoolValueReflection::Initialize( + const DescriptorPool* absl_nonnull pool) { + CEL_ASSIGN_OR_RETURN(const auto* descriptor, + GetMessageTypeByName(pool, "google.protobuf.BoolValue")); + return Initialize(descriptor); +} + +absl::Status BoolValueReflection::Initialize( + const Descriptor* absl_nonnull descriptor) { + if (descriptor_ != descriptor) { + CEL_RETURN_IF_ERROR(CheckWellKnownType(descriptor, kWellKnownType)); + descriptor_ = nullptr; + CEL_ASSIGN_OR_RETURN(value_field_, GetFieldByNumber(descriptor, 1)); + CEL_RETURN_IF_ERROR( + CheckFieldCppType(value_field_, FieldDescriptor::CPPTYPE_BOOL)); + CEL_RETURN_IF_ERROR( + CheckFieldCardinality(value_field_, FieldDescriptor::LABEL_OPTIONAL)); + descriptor_ = descriptor; + } + return absl::OkStatus(); +} + +bool BoolValueReflection::GetValue(const google::protobuf::Message& message) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); + return message.GetReflection()->GetBool(message, value_field_); +} + +void BoolValueReflection::SetValue(google::protobuf::Message* absl_nonnull message, + bool value) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); + message->GetReflection()->SetBool(message, value_field_, value); +} + +absl::StatusOr GetBoolValueReflection( + const Descriptor* absl_nonnull descriptor) { + BoolValueReflection reflection; + CEL_RETURN_IF_ERROR(reflection.Initialize(descriptor)); + return reflection; +} + +absl::Status Int32ValueReflection::Initialize( + const DescriptorPool* absl_nonnull pool) { + CEL_ASSIGN_OR_RETURN( + const auto* descriptor, + GetMessageTypeByName(pool, "google.protobuf.Int32Value")); + return Initialize(descriptor); +} + +absl::Status Int32ValueReflection::Initialize( + const Descriptor* absl_nonnull descriptor) { + if (descriptor_ != descriptor) { + CEL_RETURN_IF_ERROR(CheckWellKnownType(descriptor, kWellKnownType)); + descriptor_ = nullptr; + CEL_ASSIGN_OR_RETURN(value_field_, GetFieldByNumber(descriptor, 1)); + CEL_RETURN_IF_ERROR( + CheckFieldCppType(value_field_, FieldDescriptor::CPPTYPE_INT32)); + CEL_RETURN_IF_ERROR( + CheckFieldCardinality(value_field_, FieldDescriptor::LABEL_OPTIONAL)); + descriptor_ = descriptor; + } + return absl::OkStatus(); +} + +int32_t Int32ValueReflection::GetValue(const google::protobuf::Message& message) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); + return message.GetReflection()->GetInt32(message, value_field_); +} + +void Int32ValueReflection::SetValue(google::protobuf::Message* absl_nonnull message, + int32_t value) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); + message->GetReflection()->SetInt32(message, value_field_, value); +} + +absl::StatusOr GetInt32ValueReflection( + const Descriptor* absl_nonnull descriptor) { + Int32ValueReflection reflection; + CEL_RETURN_IF_ERROR(reflection.Initialize(descriptor)); + return reflection; +} + +absl::Status Int64ValueReflection::Initialize( + const DescriptorPool* absl_nonnull pool) { + CEL_ASSIGN_OR_RETURN( + const auto* descriptor, + GetMessageTypeByName(pool, "google.protobuf.Int64Value")); + return Initialize(descriptor); +} + +absl::Status Int64ValueReflection::Initialize( + const Descriptor* absl_nonnull descriptor) { + if (descriptor_ != descriptor) { + CEL_RETURN_IF_ERROR(CheckWellKnownType(descriptor, kWellKnownType)); + descriptor_ = nullptr; + CEL_ASSIGN_OR_RETURN(value_field_, GetFieldByNumber(descriptor, 1)); + CEL_RETURN_IF_ERROR( + CheckFieldCppType(value_field_, FieldDescriptor::CPPTYPE_INT64)); + CEL_RETURN_IF_ERROR( + CheckFieldCardinality(value_field_, FieldDescriptor::LABEL_OPTIONAL)); + descriptor_ = descriptor; + } + return absl::OkStatus(); +} + +int64_t Int64ValueReflection::GetValue(const google::protobuf::Message& message) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); + return message.GetReflection()->GetInt64(message, value_field_); +} + +void Int64ValueReflection::SetValue(google::protobuf::Message* absl_nonnull message, + int64_t value) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); + message->GetReflection()->SetInt64(message, value_field_, value); +} + +absl::StatusOr GetInt64ValueReflection( + const Descriptor* absl_nonnull descriptor) { + Int64ValueReflection reflection; + CEL_RETURN_IF_ERROR(reflection.Initialize(descriptor)); + return reflection; +} + +absl::Status UInt32ValueReflection::Initialize( + const DescriptorPool* absl_nonnull pool) { + CEL_ASSIGN_OR_RETURN( + const auto* descriptor, + GetMessageTypeByName(pool, "google.protobuf.UInt32Value")); + return Initialize(descriptor); +} + +absl::Status UInt32ValueReflection::Initialize( + const Descriptor* absl_nonnull descriptor) { + if (descriptor_ != descriptor) { + CEL_RETURN_IF_ERROR(CheckWellKnownType(descriptor, kWellKnownType)); + descriptor_ = nullptr; + CEL_ASSIGN_OR_RETURN(value_field_, GetFieldByNumber(descriptor, 1)); + CEL_RETURN_IF_ERROR( + CheckFieldCppType(value_field_, FieldDescriptor::CPPTYPE_UINT32)); + CEL_RETURN_IF_ERROR( + CheckFieldCardinality(value_field_, FieldDescriptor::LABEL_OPTIONAL)); + descriptor_ = descriptor; + } + return absl::OkStatus(); +} + +uint32_t UInt32ValueReflection::GetValue(const google::protobuf::Message& message) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); + return message.GetReflection()->GetUInt32(message, value_field_); +} + +void UInt32ValueReflection::SetValue(google::protobuf::Message* absl_nonnull message, + uint32_t value) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); + message->GetReflection()->SetUInt32(message, value_field_, value); +} + +absl::StatusOr GetUInt32ValueReflection( + const Descriptor* absl_nonnull descriptor) { + UInt32ValueReflection reflection; + CEL_RETURN_IF_ERROR(reflection.Initialize(descriptor)); + return reflection; +} + +absl::Status UInt64ValueReflection::Initialize( + const DescriptorPool* absl_nonnull pool) { + CEL_ASSIGN_OR_RETURN( + const auto* descriptor, + GetMessageTypeByName(pool, "google.protobuf.UInt64Value")); + return Initialize(descriptor); +} + +absl::Status UInt64ValueReflection::Initialize( + const Descriptor* absl_nonnull descriptor) { + if (descriptor_ != descriptor) { + CEL_RETURN_IF_ERROR(CheckWellKnownType(descriptor, kWellKnownType)); + descriptor_ = nullptr; + CEL_ASSIGN_OR_RETURN(value_field_, GetFieldByNumber(descriptor, 1)); + CEL_RETURN_IF_ERROR( + CheckFieldCppType(value_field_, FieldDescriptor::CPPTYPE_UINT64)); + CEL_RETURN_IF_ERROR( + CheckFieldCardinality(value_field_, FieldDescriptor::LABEL_OPTIONAL)); + descriptor_ = descriptor; + } + return absl::OkStatus(); +} + +uint64_t UInt64ValueReflection::GetValue(const google::protobuf::Message& message) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); + return message.GetReflection()->GetUInt64(message, value_field_); +} + +void UInt64ValueReflection::SetValue(google::protobuf::Message* absl_nonnull message, + uint64_t value) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); + message->GetReflection()->SetUInt64(message, value_field_, value); +} + +absl::StatusOr GetUInt64ValueReflection( + const Descriptor* absl_nonnull descriptor) { + UInt64ValueReflection reflection; + CEL_RETURN_IF_ERROR(reflection.Initialize(descriptor)); + return reflection; +} + +absl::Status FloatValueReflection::Initialize( + const DescriptorPool* absl_nonnull pool) { + CEL_ASSIGN_OR_RETURN( + const auto* descriptor, + GetMessageTypeByName(pool, "google.protobuf.FloatValue")); + return Initialize(descriptor); +} + +absl::Status FloatValueReflection::Initialize( + const Descriptor* absl_nonnull descriptor) { + if (descriptor_ != descriptor) { + CEL_RETURN_IF_ERROR(CheckWellKnownType(descriptor, kWellKnownType)); + descriptor_ = nullptr; + CEL_ASSIGN_OR_RETURN(value_field_, GetFieldByNumber(descriptor, 1)); + CEL_RETURN_IF_ERROR( + CheckFieldCppType(value_field_, FieldDescriptor::CPPTYPE_FLOAT)); + CEL_RETURN_IF_ERROR( + CheckFieldCardinality(value_field_, FieldDescriptor::LABEL_OPTIONAL)); + descriptor_ = descriptor; + } + return absl::OkStatus(); +} + +float FloatValueReflection::GetValue(const google::protobuf::Message& message) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); + return message.GetReflection()->GetFloat(message, value_field_); +} + +void FloatValueReflection::SetValue(google::protobuf::Message* absl_nonnull message, + float value) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); + message->GetReflection()->SetFloat(message, value_field_, value); +} + +absl::StatusOr GetFloatValueReflection( + const Descriptor* absl_nonnull descriptor) { + FloatValueReflection reflection; + CEL_RETURN_IF_ERROR(reflection.Initialize(descriptor)); + return reflection; +} + +absl::Status DoubleValueReflection::Initialize( + const DescriptorPool* absl_nonnull pool) { + CEL_ASSIGN_OR_RETURN( + const auto* descriptor, + GetMessageTypeByName(pool, "google.protobuf.DoubleValue")); + return Initialize(descriptor); +} + +absl::Status DoubleValueReflection::Initialize( + const Descriptor* absl_nonnull descriptor) { + if (descriptor_ != descriptor) { + CEL_RETURN_IF_ERROR(CheckWellKnownType(descriptor, kWellKnownType)); + descriptor_ = nullptr; + CEL_ASSIGN_OR_RETURN(value_field_, GetFieldByNumber(descriptor, 1)); + CEL_RETURN_IF_ERROR( + CheckFieldCppType(value_field_, FieldDescriptor::CPPTYPE_DOUBLE)); + CEL_RETURN_IF_ERROR( + CheckFieldCardinality(value_field_, FieldDescriptor::LABEL_OPTIONAL)); + descriptor_ = descriptor; + } + return absl::OkStatus(); +} + +double DoubleValueReflection::GetValue(const google::protobuf::Message& message) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); + return message.GetReflection()->GetDouble(message, value_field_); +} + +void DoubleValueReflection::SetValue(google::protobuf::Message* absl_nonnull message, + double value) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); + message->GetReflection()->SetDouble(message, value_field_, value); +} + +absl::StatusOr GetDoubleValueReflection( + const Descriptor* absl_nonnull descriptor) { + DoubleValueReflection reflection; + CEL_RETURN_IF_ERROR(reflection.Initialize(descriptor)); + return reflection; +} + +absl::Status BytesValueReflection::Initialize( + const DescriptorPool* absl_nonnull pool) { + CEL_ASSIGN_OR_RETURN( + const auto* descriptor, + GetMessageTypeByName(pool, "google.protobuf.BytesValue")); + return Initialize(descriptor); +} + +absl::Status BytesValueReflection::Initialize( + const Descriptor* absl_nonnull descriptor) { + if (descriptor_ != descriptor) { + CEL_RETURN_IF_ERROR(CheckWellKnownType(descriptor, kWellKnownType)); + descriptor_ = nullptr; + CEL_ASSIGN_OR_RETURN(value_field_, GetFieldByNumber(descriptor, 1)); + CEL_RETURN_IF_ERROR( + CheckFieldType(value_field_, FieldDescriptor::TYPE_BYTES)); + CEL_RETURN_IF_ERROR( + CheckFieldCardinality(value_field_, FieldDescriptor::LABEL_OPTIONAL)); + value_field_string_type_ = value_field_->cpp_string_type(); + descriptor_ = descriptor; + } + return absl::OkStatus(); +} + +BytesValue BytesValueReflection::GetValue(const google::protobuf::Message& message, + std::string& scratch) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); + return GetStringField(message, value_field_, + value_field_string_type_, scratch); +} + +void BytesValueReflection::SetValue(google::protobuf::Message* absl_nonnull message, + absl::string_view value) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); + message->GetReflection()->SetString(message, value_field_, + std::string(value)); +} + +void BytesValueReflection::SetValue(google::protobuf::Message* absl_nonnull message, + const absl::Cord& value) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); + message->GetReflection()->SetString(message, value_field_, value); +} + +absl::StatusOr GetBytesValueReflection( + const Descriptor* absl_nonnull descriptor) { + BytesValueReflection reflection; + CEL_RETURN_IF_ERROR(reflection.Initialize(descriptor)); + return reflection; +} + +absl::Status StringValueReflection::Initialize( + const DescriptorPool* absl_nonnull pool) { + CEL_ASSIGN_OR_RETURN( + const auto* descriptor, + GetMessageTypeByName(pool, "google.protobuf.StringValue")); + return Initialize(descriptor); +} + +absl::Status StringValueReflection::Initialize( + const Descriptor* absl_nonnull descriptor) { + if (descriptor_ != descriptor) { + CEL_RETURN_IF_ERROR(CheckWellKnownType(descriptor, kWellKnownType)); + descriptor_ = nullptr; + CEL_ASSIGN_OR_RETURN(value_field_, GetFieldByNumber(descriptor, 1)); + CEL_RETURN_IF_ERROR( + CheckFieldType(value_field_, FieldDescriptor::TYPE_STRING)); + CEL_RETURN_IF_ERROR( + CheckFieldCardinality(value_field_, FieldDescriptor::LABEL_OPTIONAL)); + value_field_string_type_ = value_field_->cpp_string_type(); + descriptor_ = descriptor; + } + return absl::OkStatus(); +} + +StringValue StringValueReflection::GetValue(const google::protobuf::Message& message, + std::string& scratch) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); + return GetStringField(message, value_field_, + value_field_string_type_, scratch); +} + +void StringValueReflection::SetValue(google::protobuf::Message* absl_nonnull message, + absl::string_view value) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); + message->GetReflection()->SetString(message, value_field_, + std::string(value)); +} + +void StringValueReflection::SetValue(google::protobuf::Message* absl_nonnull message, + const absl::Cord& value) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); + message->GetReflection()->SetString(message, value_field_, value); +} + +absl::StatusOr GetStringValueReflection( + const Descriptor* absl_nonnull descriptor) { + StringValueReflection reflection; + CEL_RETURN_IF_ERROR(reflection.Initialize(descriptor)); + return reflection; +} + +absl::Status AnyReflection::Initialize( + const DescriptorPool* absl_nonnull pool) { + CEL_ASSIGN_OR_RETURN(const auto* descriptor, + GetMessageTypeByName(pool, "google.protobuf.Any")); + return Initialize(descriptor); +} + +absl::Status AnyReflection::Initialize( + const Descriptor* absl_nonnull descriptor) { + if (descriptor_ != descriptor) { + CEL_RETURN_IF_ERROR(CheckWellKnownType(descriptor, kWellKnownType)); + descriptor_ = nullptr; + CEL_ASSIGN_OR_RETURN(type_url_field_, GetFieldByNumber(descriptor, 1)); + CEL_RETURN_IF_ERROR( + CheckFieldType(type_url_field_, FieldDescriptor::TYPE_STRING)); + CEL_RETURN_IF_ERROR(CheckFieldCardinality(type_url_field_, + FieldDescriptor::LABEL_OPTIONAL)); + type_url_field_string_type_ = type_url_field_->cpp_string_type(); + CEL_ASSIGN_OR_RETURN(value_field_, GetFieldByNumber(descriptor, 2)); + CEL_RETURN_IF_ERROR( + CheckFieldType(value_field_, FieldDescriptor::TYPE_BYTES)); + CEL_RETURN_IF_ERROR( + CheckFieldCardinality(value_field_, FieldDescriptor::LABEL_OPTIONAL)); + value_field_string_type_ = value_field_->cpp_string_type(); + descriptor_ = descriptor; + } + return absl::OkStatus(); +} + +void AnyReflection::SetTypeUrl(google::protobuf::Message* absl_nonnull message, + absl::string_view type_url) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); + message->GetReflection()->SetString(message, type_url_field_, + std::string(type_url)); +} + +void AnyReflection::SetValue(google::protobuf::Message* absl_nonnull message, + const absl::Cord& value) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); + message->GetReflection()->SetString(message, value_field_, value); +} + +StringValue AnyReflection::GetTypeUrl(const google::protobuf::Message& message, + std::string& scratch) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); + return GetStringField(message, type_url_field_, + type_url_field_string_type_, scratch); +} + +BytesValue AnyReflection::GetValue(const google::protobuf::Message& message, + std::string& scratch) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); + return GetStringField(message, value_field_, + value_field_string_type_, scratch); +} + +absl::StatusOr GetAnyReflection( + const Descriptor* absl_nonnull descriptor) { + AnyReflection reflection; + CEL_RETURN_IF_ERROR(reflection.Initialize(descriptor)); + return reflection; +} + +AnyReflection GetAnyReflectionOrDie( + const google::protobuf::Descriptor* absl_nonnull descriptor) { + AnyReflection reflection; + ABSL_CHECK_OK(reflection.Initialize(descriptor)); // Crash OK + return reflection; +} + +absl::Status DurationReflection::Initialize( + const DescriptorPool* absl_nonnull pool) { + CEL_ASSIGN_OR_RETURN(const auto* descriptor, + GetMessageTypeByName(pool, "google.protobuf.Duration")); + return Initialize(descriptor); +} + +absl::Status DurationReflection::Initialize( + const Descriptor* absl_nonnull descriptor) { + if (descriptor_ != descriptor) { + CEL_RETURN_IF_ERROR(CheckWellKnownType(descriptor, kWellKnownType)); + descriptor_ = nullptr; + CEL_ASSIGN_OR_RETURN(seconds_field_, GetFieldByNumber(descriptor, 1)); + CEL_RETURN_IF_ERROR( + CheckFieldCppType(seconds_field_, FieldDescriptor::CPPTYPE_INT64)); + CEL_RETURN_IF_ERROR( + CheckFieldCardinality(seconds_field_, FieldDescriptor::LABEL_OPTIONAL)); + CEL_ASSIGN_OR_RETURN(nanos_field_, GetFieldByNumber(descriptor, 2)); + CEL_RETURN_IF_ERROR( + CheckFieldCppType(nanos_field_, FieldDescriptor::CPPTYPE_INT32)); + CEL_RETURN_IF_ERROR( + CheckFieldCardinality(nanos_field_, FieldDescriptor::LABEL_OPTIONAL)); + descriptor_ = descriptor; + } + return absl::OkStatus(); +} + +int64_t DurationReflection::GetSeconds(const google::protobuf::Message& message) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); + return message.GetReflection()->GetInt64(message, seconds_field_); +} + +int32_t DurationReflection::GetNanos(const google::protobuf::Message& message) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); + return message.GetReflection()->GetInt32(message, nanos_field_); +} + +void DurationReflection::SetSeconds(google::protobuf::Message* absl_nonnull message, + int64_t value) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); + message->GetReflection()->SetInt64(message, seconds_field_, value); +} + +void DurationReflection::SetNanos(google::protobuf::Message* absl_nonnull message, + int32_t value) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); + message->GetReflection()->SetInt32(message, nanos_field_, value); +} + +absl::Status DurationReflection::SetFromAbslDuration( + google::protobuf::Message* absl_nonnull message, absl::Duration duration) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); + int64_t seconds = absl::IDivDuration(duration, absl::Seconds(1), &duration); + if (ABSL_PREDICT_FALSE(seconds < TimeUtil::kDurationMinSeconds || + seconds > TimeUtil::kDurationMaxSeconds)) { + return absl::InvalidArgumentError( + absl::StrCat("invalid duration seconds: ", seconds)); + } + int32_t nanos = static_cast( + absl::IDivDuration(duration, absl::Nanoseconds(1), &duration)); + if (ABSL_PREDICT_FALSE(nanos < TimeUtil::kDurationMinNanoseconds || + nanos > TimeUtil::kDurationMaxNanoseconds)) { + return absl::InvalidArgumentError( + absl::StrCat("invalid duration nanoseconds: ", nanos)); + } + if ((seconds < 0 && nanos > 0) || (seconds > 0 && nanos < 0)) { + return absl::InvalidArgumentError(absl::StrCat( + "duration sign mismatch: seconds=", seconds, ", nanoseconds=", nanos)); + } + SetSeconds(message, seconds); + SetNanos(message, nanos); + return absl::OkStatus(); +} + +absl::Status DurationReflection::SetFromAbslDuration( + GeneratedMessageType* absl_nonnull message, absl::Duration duration) { + int64_t seconds = absl::IDivDuration(duration, absl::Seconds(1), &duration); + if (ABSL_PREDICT_FALSE(seconds < TimeUtil::kDurationMinSeconds || + seconds > TimeUtil::kDurationMaxSeconds)) { + return absl::InvalidArgumentError( + absl::StrCat("invalid duration seconds: ", seconds)); + } + int32_t nanos = static_cast( + absl::IDivDuration(duration, absl::Nanoseconds(1), &duration)); + if (ABSL_PREDICT_FALSE(nanos < TimeUtil::kDurationMinNanoseconds || + nanos > TimeUtil::kDurationMaxNanoseconds)) { + return absl::InvalidArgumentError( + absl::StrCat("invalid duration nanoseconds: ", nanos)); + } + if ((seconds < 0 && nanos > 0) || (seconds > 0 && nanos < 0)) { + return absl::InvalidArgumentError(absl::StrCat( + "duration sign mismatch: seconds=", seconds, ", nanoseconds=", nanos)); + } + SetSeconds(message, seconds); + SetNanos(message, nanos); + return absl::OkStatus(); +} + +void DurationReflection::UnsafeSetFromAbslDuration( + google::protobuf::Message* absl_nonnull message, absl::Duration duration) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); + int64_t seconds = absl::IDivDuration(duration, absl::Seconds(1), &duration); + int32_t nanos = static_cast( + absl::IDivDuration(duration, absl::Nanoseconds(1), &duration)); + SetSeconds(message, seconds); + SetNanos(message, nanos); +} + +absl::StatusOr DurationReflection::ToAbslDuration( + const google::protobuf::Message& message) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); + int64_t seconds = GetSeconds(message); + if (ABSL_PREDICT_FALSE(seconds < TimeUtil::kDurationMinSeconds || + seconds > TimeUtil::kDurationMaxSeconds)) { + return absl::InvalidArgumentError( + absl::StrCat("invalid duration seconds: ", seconds)); + } + int32_t nanos = GetNanos(message); + if (ABSL_PREDICT_FALSE(nanos < TimeUtil::kDurationMinNanoseconds || + nanos > TimeUtil::kDurationMaxNanoseconds)) { + return absl::InvalidArgumentError( + absl::StrCat("invalid duration nanoseconds: ", nanos)); + } + if ((seconds < 0 && nanos > 0) || (seconds > 0 && nanos < 0)) { + return absl::InvalidArgumentError(absl::StrCat( + "duration sign mismatch: seconds=", seconds, ", nanoseconds=", nanos)); + } + return absl::Seconds(seconds) + absl::Nanoseconds(nanos); +} + +absl::Duration DurationReflection::UnsafeToAbslDuration( + const google::protobuf::Message& message) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); + int64_t seconds = GetSeconds(message); + int32_t nanos = GetNanos(message); + return absl::Seconds(seconds) + absl::Nanoseconds(nanos); +} + +absl::StatusOr GetDurationReflection( + const Descriptor* absl_nonnull descriptor) { + DurationReflection reflection; + CEL_RETURN_IF_ERROR(reflection.Initialize(descriptor)); + return reflection; +} + +absl::Status TimestampReflection::Initialize( + const DescriptorPool* absl_nonnull pool) { + CEL_ASSIGN_OR_RETURN(const auto* descriptor, + GetMessageTypeByName(pool, "google.protobuf.Timestamp")); + return Initialize(descriptor); +} + +absl::Status TimestampReflection::Initialize( + const Descriptor* absl_nonnull descriptor) { + if (descriptor_ != descriptor) { + CEL_RETURN_IF_ERROR(CheckWellKnownType(descriptor, kWellKnownType)); + descriptor_ = nullptr; + CEL_ASSIGN_OR_RETURN(seconds_field_, GetFieldByNumber(descriptor, 1)); + CEL_RETURN_IF_ERROR( + CheckFieldCppType(seconds_field_, FieldDescriptor::CPPTYPE_INT64)); + CEL_RETURN_IF_ERROR( + CheckFieldCardinality(seconds_field_, FieldDescriptor::LABEL_OPTIONAL)); + CEL_ASSIGN_OR_RETURN(nanos_field_, GetFieldByNumber(descriptor, 2)); + CEL_RETURN_IF_ERROR( + CheckFieldCppType(nanos_field_, FieldDescriptor::CPPTYPE_INT32)); + CEL_RETURN_IF_ERROR( + CheckFieldCardinality(nanos_field_, FieldDescriptor::LABEL_OPTIONAL)); + descriptor_ = descriptor; + } + return absl::OkStatus(); +} + +int64_t TimestampReflection::GetSeconds(const google::protobuf::Message& message) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); + return message.GetReflection()->GetInt64(message, seconds_field_); +} + +int32_t TimestampReflection::GetNanos(const google::protobuf::Message& message) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); + return message.GetReflection()->GetInt32(message, nanos_field_); +} + +void TimestampReflection::SetSeconds(google::protobuf::Message* absl_nonnull message, + int64_t value) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); + message->GetReflection()->SetInt64(message, seconds_field_, value); +} + +void TimestampReflection::SetNanos(google::protobuf::Message* absl_nonnull message, + int32_t value) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); + message->GetReflection()->SetInt32(message, nanos_field_, value); +} + +absl::Status TimestampReflection::SetFromAbslTime( + google::protobuf::Message* absl_nonnull message, absl::Time time) const { + int64_t seconds = absl::ToUnixSeconds(time); + if (ABSL_PREDICT_FALSE(seconds < TimeUtil::kTimestampMinSeconds || + seconds > TimeUtil::kTimestampMaxSeconds)) { + return absl::InvalidArgumentError( + absl::StrCat("invalid timestamp seconds: ", seconds)); + } + int64_t nanos = static_cast((time - absl::FromUnixSeconds(seconds)) / + absl::Nanoseconds(1)); + if (ABSL_PREDICT_FALSE(nanos < TimeUtil::kTimestampMinNanoseconds || + nanos > TimeUtil::kTimestampMaxNanoseconds)) { + return absl::InvalidArgumentError( + absl::StrCat("invalid timestamp nanoseconds: ", nanos)); + } + SetSeconds(message, seconds); + SetNanos(message, static_cast(nanos)); + return absl::OkStatus(); +} + +absl::Status TimestampReflection::SetFromAbslTime( + GeneratedMessageType* absl_nonnull message, absl::Time time) { + int64_t seconds = absl::ToUnixSeconds(time); + if (ABSL_PREDICT_FALSE(seconds < TimeUtil::kTimestampMinSeconds || + seconds > TimeUtil::kTimestampMaxSeconds)) { + return absl::InvalidArgumentError( + absl::StrCat("invalid timestamp seconds: ", seconds)); + } + int64_t nanos = static_cast((time - absl::FromUnixSeconds(seconds)) / + absl::Nanoseconds(1)); + if (ABSL_PREDICT_FALSE(nanos < TimeUtil::kTimestampMinNanoseconds || + nanos > TimeUtil::kTimestampMaxNanoseconds)) { + return absl::InvalidArgumentError( + absl::StrCat("invalid timestamp nanoseconds: ", nanos)); + } + SetSeconds(message, seconds); + SetNanos(message, static_cast(nanos)); + return absl::OkStatus(); +} + +void TimestampReflection::UnsafeSetFromAbslTime( + google::protobuf::Message* absl_nonnull message, absl::Time time) const { + int64_t seconds = absl::ToUnixSeconds(time); + int32_t nanos = static_cast((time - absl::FromUnixSeconds(seconds)) / + absl::Nanoseconds(1)); + SetSeconds(message, seconds); + SetNanos(message, nanos); +} + +absl::StatusOr TimestampReflection::ToAbslTime( + const google::protobuf::Message& message) const { + int64_t seconds = GetSeconds(message); + if (ABSL_PREDICT_FALSE(seconds < TimeUtil::kTimestampMinSeconds || + seconds > TimeUtil::kTimestampMaxSeconds)) { + return absl::InvalidArgumentError( + absl::StrCat("invalid timestamp seconds: ", seconds)); + } + int32_t nanos = GetNanos(message); + if (ABSL_PREDICT_FALSE(nanos < TimeUtil::kTimestampMinNanoseconds || + nanos > TimeUtil::kTimestampMaxNanoseconds)) { + return absl::InvalidArgumentError( + absl::StrCat("invalid timestamp nanoseconds: ", nanos)); + } + return absl::UnixEpoch() + absl::Seconds(seconds) + absl::Nanoseconds(nanos); +} + +absl::Time TimestampReflection::UnsafeToAbslTime( + const google::protobuf::Message& message) const { + int64_t seconds = GetSeconds(message); + int32_t nanos = GetNanos(message); + return absl::UnixEpoch() + absl::Seconds(seconds) + absl::Nanoseconds(nanos); +} + +absl::StatusOr GetTimestampReflection( + const Descriptor* absl_nonnull descriptor) { + TimestampReflection reflection; + CEL_RETURN_IF_ERROR(reflection.Initialize(descriptor)); + return reflection; +} + +void ValueReflection::SetNumberValue( + google::protobuf::Value* absl_nonnull message, int64_t value) { + if (value < kJsonMinInt || value > kJsonMaxInt) { + SetStringValue(message, absl::StrCat(value)); + return; + } + SetNumberValue(message, static_cast(value)); +} + +void ValueReflection::SetNumberValue( + google::protobuf::Value* absl_nonnull message, uint64_t value) { + if (value > kJsonMaxUint) { + SetStringValue(message, absl::StrCat(value)); + return; + } + SetNumberValue(message, static_cast(value)); +} + +absl::Status ValueReflection::Initialize( + const DescriptorPool* absl_nonnull pool) { + CEL_ASSIGN_OR_RETURN(const auto* descriptor, + GetMessageTypeByName(pool, "google.protobuf.Value")); + return Initialize(descriptor); +} + +absl::Status ValueReflection::Initialize( + const Descriptor* absl_nonnull descriptor) { + if (descriptor_ != descriptor) { + CEL_RETURN_IF_ERROR(CheckWellKnownType(descriptor, kWellKnownType)); + descriptor_ = nullptr; + CEL_ASSIGN_OR_RETURN(kind_field_, GetOneofByName(descriptor, "kind")); + CEL_ASSIGN_OR_RETURN(null_value_field_, GetFieldByNumber(descriptor, 1)); + CEL_RETURN_IF_ERROR( + CheckFieldCppType(null_value_field_, FieldDescriptor::CPPTYPE_ENUM)); + CEL_RETURN_IF_ERROR(CheckFieldCardinality(null_value_field_, + FieldDescriptor::LABEL_OPTIONAL)); + CEL_RETURN_IF_ERROR(CheckFieldOneof(null_value_field_, kind_field_, 0)); + CEL_ASSIGN_OR_RETURN(bool_value_field_, GetFieldByNumber(descriptor, 4)); + CEL_RETURN_IF_ERROR( + CheckFieldCppType(bool_value_field_, FieldDescriptor::CPPTYPE_BOOL)); + CEL_RETURN_IF_ERROR(CheckFieldCardinality(bool_value_field_, + FieldDescriptor::LABEL_OPTIONAL)); + CEL_RETURN_IF_ERROR(CheckFieldOneof(bool_value_field_, kind_field_, 3)); + CEL_ASSIGN_OR_RETURN(number_value_field_, GetFieldByNumber(descriptor, 2)); + CEL_RETURN_IF_ERROR(CheckFieldCppType(number_value_field_, + FieldDescriptor::CPPTYPE_DOUBLE)); + CEL_RETURN_IF_ERROR(CheckFieldCardinality(number_value_field_, + FieldDescriptor::LABEL_OPTIONAL)); + CEL_RETURN_IF_ERROR(CheckFieldOneof(number_value_field_, kind_field_, 1)); + CEL_ASSIGN_OR_RETURN(string_value_field_, GetFieldByNumber(descriptor, 3)); + CEL_RETURN_IF_ERROR(CheckFieldCppType(string_value_field_, + FieldDescriptor::CPPTYPE_STRING)); + CEL_RETURN_IF_ERROR(CheckFieldCardinality(string_value_field_, + FieldDescriptor::LABEL_OPTIONAL)); + CEL_RETURN_IF_ERROR(CheckFieldOneof(string_value_field_, kind_field_, 2)); + string_value_field_string_type_ = string_value_field_->cpp_string_type(); + CEL_ASSIGN_OR_RETURN(list_value_field_, GetFieldByNumber(descriptor, 6)); + CEL_RETURN_IF_ERROR( + CheckFieldCppType(list_value_field_, FieldDescriptor::CPPTYPE_MESSAGE)); + CEL_RETURN_IF_ERROR(CheckFieldCardinality(list_value_field_, + FieldDescriptor::LABEL_OPTIONAL)); + CEL_RETURN_IF_ERROR(CheckFieldOneof(list_value_field_, kind_field_, 5)); + CEL_RETURN_IF_ERROR(CheckFieldWellKnownType( + list_value_field_, Descriptor::WELLKNOWNTYPE_LISTVALUE)); + CEL_ASSIGN_OR_RETURN(struct_value_field_, GetFieldByNumber(descriptor, 5)); + CEL_RETURN_IF_ERROR(CheckFieldCppType(struct_value_field_, + FieldDescriptor::CPPTYPE_MESSAGE)); + CEL_RETURN_IF_ERROR(CheckFieldCardinality(struct_value_field_, + FieldDescriptor::LABEL_OPTIONAL)); + CEL_RETURN_IF_ERROR(CheckFieldOneof(struct_value_field_, kind_field_, 4)); + CEL_RETURN_IF_ERROR(CheckFieldWellKnownType( + struct_value_field_, Descriptor::WELLKNOWNTYPE_STRUCT)); + descriptor_ = descriptor; + } + return absl::OkStatus(); +} + +google::protobuf::Value::KindCase ValueReflection::GetKindCase( + const google::protobuf::Message& message) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); + const auto* field = + message.GetReflection()->GetOneofFieldDescriptor(message, kind_field_); + return field != nullptr ? static_cast( + field->index_in_oneof() + 1) + : google::protobuf::Value::KIND_NOT_SET; +} + +bool ValueReflection::GetBoolValue(const google::protobuf::Message& message) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); + return message.GetReflection()->GetBool(message, bool_value_field_); +} + +double ValueReflection::GetNumberValue(const google::protobuf::Message& message) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); + return message.GetReflection()->GetDouble(message, number_value_field_); +} + +StringValue ValueReflection::GetStringValue(const google::protobuf::Message& message, + std::string& scratch) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); + return GetStringField(message, string_value_field_, + string_value_field_string_type_, scratch); +} + +const google::protobuf::Message& ValueReflection::GetListValue( + const google::protobuf::Message& message) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); +#undef GetMessage + return message.GetReflection()->GetMessage(message, list_value_field_); +} + +const google::protobuf::Message& ValueReflection::GetStructValue( + const google::protobuf::Message& message) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); +#undef GetMessage + return message.GetReflection()->GetMessage(message, struct_value_field_); +} + +void ValueReflection::SetNullValue( + google::protobuf::Message* absl_nonnull message) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); + message->GetReflection()->SetEnumValue(message, null_value_field_, 0); +} + +void ValueReflection::SetBoolValue(google::protobuf::Message* absl_nonnull message, + bool value) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); + message->GetReflection()->SetBool(message, bool_value_field_, value); +} + +void ValueReflection::SetNumberValue(google::protobuf::Message* absl_nonnull message, + int64_t value) const { + if (value < kJsonMinInt || value > kJsonMaxInt) { + SetStringValue(message, absl::StrCat(value)); + return; + } + SetNumberValue(message, static_cast(value)); +} + +void ValueReflection::SetNumberValue(google::protobuf::Message* absl_nonnull message, + uint64_t value) const { + if (value > kJsonMaxUint) { + SetStringValue(message, absl::StrCat(value)); + return; + } + SetNumberValue(message, static_cast(value)); +} + +void ValueReflection::SetNumberValue(google::protobuf::Message* absl_nonnull message, + double value) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); + message->GetReflection()->SetDouble(message, number_value_field_, value); +} + +void ValueReflection::SetStringValue(google::protobuf::Message* absl_nonnull message, + absl::string_view value) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); + message->GetReflection()->SetString(message, string_value_field_, + std::string(value)); +} + +void ValueReflection::SetStringValue(google::protobuf::Message* absl_nonnull message, + const absl::Cord& value) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); + message->GetReflection()->SetString(message, string_value_field_, value); +} + +void ValueReflection::SetStringValueFromBytes( + google::protobuf::Message* absl_nonnull message, absl::string_view value) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); + if (value.empty()) { + SetStringValue(message, value); + return; + } + SetStringValue(message, absl::Base64Escape(value)); +} + +void ValueReflection::SetStringValueFromBytes( + google::protobuf::Message* absl_nonnull message, const absl::Cord& value) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); + if (value.empty()) { + SetStringValue(message, value); + return; + } + if (auto flat = value.TryFlat(); flat) { + SetStringValue(message, absl::Base64Escape(*flat)); + return; + } + std::string flat; + absl::CopyCordToString(value, &flat); + SetStringValue(message, absl::Base64Escape(flat)); +} + +void ValueReflection::SetStringValueFromDuration( + google::protobuf::Message* absl_nonnull message, absl::Duration duration) const { + google::protobuf::Duration proto; + proto.set_seconds(absl::IDivDuration(duration, absl::Seconds(1), &duration)); + proto.set_nanos(static_cast( + absl::IDivDuration(duration, absl::Nanoseconds(1), &duration))); + ABSL_DCHECK(TimeUtil::IsDurationValid(proto)); + SetStringValue(message, TimeUtil::ToString(proto)); +} + +void ValueReflection::SetStringValueFromTimestamp( + google::protobuf::Message* absl_nonnull message, absl::Time time) const { + google::protobuf::Timestamp proto; + proto.set_seconds(absl::ToUnixSeconds(time)); + proto.set_nanos((time - absl::FromUnixSeconds(proto.seconds())) / + absl::Nanoseconds(1)); + ABSL_DCHECK(TimeUtil::IsTimestampValid(proto)); + SetStringValue(message, TimeUtil::ToString(proto)); +} + +google::protobuf::Message* absl_nonnull ValueReflection::MutableListValue( + google::protobuf::Message* absl_nonnull message) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); + return message->GetReflection()->MutableMessage(message, list_value_field_); +} + +google::protobuf::Message* absl_nonnull ValueReflection::MutableStructValue( + google::protobuf::Message* absl_nonnull message) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); + return message->GetReflection()->MutableMessage(message, struct_value_field_); +} + +Unique ValueReflection::ReleaseListValue( + google::protobuf::Message* absl_nonnull message) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); + const auto* reflection = message->GetReflection(); + if (!reflection->HasField(*message, list_value_field_)) { + reflection->MutableMessage(message, list_value_field_); + } + return WrapUnique( + reflection->UnsafeArenaReleaseMessage(message, list_value_field_), + message->GetArena()); +} + +Unique ValueReflection::ReleaseStructValue( + google::protobuf::Message* absl_nonnull message) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); + const auto* reflection = message->GetReflection(); + if (!reflection->HasField(*message, struct_value_field_)) { + reflection->MutableMessage(message, struct_value_field_); + } + return WrapUnique( + reflection->UnsafeArenaReleaseMessage(message, struct_value_field_), + message->GetArena()); +} + +absl::StatusOr GetValueReflection( + const Descriptor* absl_nonnull descriptor) { + ValueReflection reflection; + CEL_RETURN_IF_ERROR(reflection.Initialize(descriptor)); + return reflection; +} +ValueReflection GetValueReflectionOrDie( + const google::protobuf::Descriptor* absl_nonnull descriptor) { + ValueReflection reflection; + ABSL_CHECK_OK(reflection.Initialize(descriptor)); // Crash OK; + return reflection; +} + +absl::Status ListValueReflection::Initialize( + const DescriptorPool* absl_nonnull pool) { + CEL_ASSIGN_OR_RETURN(const auto* descriptor, + GetMessageTypeByName(pool, "google.protobuf.ListValue")); + return Initialize(descriptor); +} + +absl::Status ListValueReflection::Initialize( + const Descriptor* absl_nonnull descriptor) { + if (descriptor_ != descriptor) { + CEL_RETURN_IF_ERROR(CheckWellKnownType(descriptor, kWellKnownType)); + descriptor_ = nullptr; + CEL_ASSIGN_OR_RETURN(values_field_, GetFieldByNumber(descriptor, 1)); + CEL_RETURN_IF_ERROR( + CheckFieldCppType(values_field_, FieldDescriptor::CPPTYPE_MESSAGE)); + CEL_RETURN_IF_ERROR( + CheckFieldCardinality(values_field_, FieldDescriptor::LABEL_REPEATED)); + CEL_RETURN_IF_ERROR(CheckFieldWellKnownType( + values_field_, Descriptor::WELLKNOWNTYPE_VALUE)); + descriptor_ = descriptor; + } + return absl::OkStatus(); +} + +int ListValueReflection::ValuesSize(const google::protobuf::Message& message) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); + return message.GetReflection()->FieldSize(message, values_field_); +} + +google::protobuf::RepeatedFieldRef ListValueReflection::Values( + const google::protobuf::Message& message) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); + return message.GetReflection()->GetRepeatedFieldRef( + message, values_field_); +} + +const google::protobuf::Message& ListValueReflection::Values( + const google::protobuf::Message& message ABSL_ATTRIBUTE_LIFETIME_BOUND, + int index) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); + return message.GetReflection()->GetRepeatedMessage(message, values_field_, + index); +} + +google::protobuf::MutableRepeatedFieldRef +ListValueReflection::MutableValues( + google::protobuf::Message* absl_nonnull message) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); + return message->GetReflection()->GetMutableRepeatedFieldRef( + message, values_field_); +} + +google::protobuf::Message* absl_nonnull ListValueReflection::AddValues( + google::protobuf::Message* absl_nonnull message) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); + return message->GetReflection()->AddMessage(message, values_field_); +} + +absl::StatusOr GetListValueReflection( + const Descriptor* absl_nonnull descriptor) { + ListValueReflection reflection; + CEL_RETURN_IF_ERROR(reflection.Initialize(descriptor)); + return reflection; +} + +ListValueReflection GetListValueReflectionOrDie( + const google::protobuf::Descriptor* absl_nonnull descriptor) { + ListValueReflection reflection; + ABSL_CHECK_OK(reflection.Initialize(descriptor)); // Crash OK + return reflection; +} + +absl::Status StructReflection::Initialize( + const DescriptorPool* absl_nonnull pool) { + CEL_ASSIGN_OR_RETURN(const auto* descriptor, + GetMessageTypeByName(pool, "google.protobuf.Struct")); + return Initialize(descriptor); +} + +absl::Status StructReflection::Initialize( + const Descriptor* absl_nonnull descriptor) { + if (descriptor_ != descriptor) { + CEL_RETURN_IF_ERROR(CheckWellKnownType(descriptor, kWellKnownType)); + descriptor_ = nullptr; + CEL_ASSIGN_OR_RETURN(fields_field_, GetFieldByNumber(descriptor, 1)); + CEL_RETURN_IF_ERROR(CheckMapField(fields_field_)); + fields_key_field_ = fields_field_->message_type()->map_key(); + CEL_RETURN_IF_ERROR( + CheckFieldCppType(fields_key_field_, FieldDescriptor::CPPTYPE_STRING)); + CEL_RETURN_IF_ERROR(CheckFieldCardinality(fields_key_field_, + FieldDescriptor::LABEL_OPTIONAL)); + fields_value_field_ = fields_field_->message_type()->map_value(); + CEL_RETURN_IF_ERROR(CheckFieldCppType(fields_value_field_, + FieldDescriptor::CPPTYPE_MESSAGE)); + CEL_RETURN_IF_ERROR(CheckFieldCardinality(fields_value_field_, + FieldDescriptor::LABEL_OPTIONAL)); + CEL_RETURN_IF_ERROR(CheckFieldWellKnownType( + fields_value_field_, Descriptor::WELLKNOWNTYPE_VALUE)); + descriptor_ = descriptor; + } + return absl::OkStatus(); +} + +int StructReflection::FieldsSize(const google::protobuf::Message& message) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); + return cel::extensions::protobuf_internal::MapSize(*message.GetReflection(), + message, *fields_field_); +} + +google::protobuf::ConstMapIterator StructReflection::BeginFields( + const google::protobuf::Message& message) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); + return cel::extensions::protobuf_internal::ConstMapBegin( + *message.GetReflection(), message, *fields_field_); +} + +google::protobuf::ConstMapIterator StructReflection::EndFields( + const google::protobuf::Message& message) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); + return cel::extensions::protobuf_internal::ConstMapEnd( + *message.GetReflection(), message, *fields_field_); +} + +bool StructReflection::ContainsField(const google::protobuf::Message& message, + absl::string_view name) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); +#if CEL_INTERNAL_PROTOBUF_OSS_VERSION_PREREQ(5, 30, 0) + google::protobuf::MapKey key; + key.SetStringValue(name); +#else + std::string key_scratch(name); + google::protobuf::MapKey key; + key.SetStringValue(key_scratch); +#endif + return cel::extensions::protobuf_internal::ContainsMapKey( + *message.GetReflection(), message, *fields_field_, key); +} + +const google::protobuf::Message* absl_nullable StructReflection::FindField( + const google::protobuf::Message& message, absl::string_view name) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); +#if CEL_INTERNAL_PROTOBUF_OSS_VERSION_PREREQ(5, 30, 0) + google::protobuf::MapKey key; + key.SetStringValue(name); +#else + std::string key_scratch(name); + google::protobuf::MapKey key; + key.SetStringValue(key_scratch); +#endif + google::protobuf::MapValueConstRef value; + if (cel::extensions::protobuf_internal::LookupMapValue( + *message.GetReflection(), message, *fields_field_, key, &value)) { + return &value.GetMessageValue(); + } + return nullptr; +} + +google::protobuf::Message* absl_nonnull StructReflection::InsertField( + google::protobuf::Message* absl_nonnull message, absl::string_view name) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); +#if CEL_INTERNAL_PROTOBUF_OSS_VERSION_PREREQ(5, 30, 0) + google::protobuf::MapKey key; + key.SetStringValue(name); +#else + std::string key_scratch(name); + google::protobuf::MapKey key; + key.SetStringValue(key_scratch); +#endif + google::protobuf::MapValueRef value; + cel::extensions::protobuf_internal::InsertOrLookupMapValue( + *message->GetReflection(), message, *fields_field_, key, &value); + return value.MutableMessageValue(); +} + +bool StructReflection::DeleteField(google::protobuf::Message* absl_nonnull message, + absl::string_view name) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); +#if CEL_INTERNAL_PROTOBUF_OSS_VERSION_PREREQ(5, 30, 0) + google::protobuf::MapKey key; + key.SetStringValue(name); +#else + std::string key_scratch(name); + google::protobuf::MapKey key; + key.SetStringValue(key_scratch); +#endif + return cel::extensions::protobuf_internal::DeleteMapValue( + message->GetReflection(), message, fields_field_, key); +} + +absl::StatusOr GetStructReflection( + const Descriptor* absl_nonnull descriptor) { + StructReflection reflection; + CEL_RETURN_IF_ERROR(reflection.Initialize(descriptor)); + return reflection; +} + +StructReflection GetStructReflectionOrDie( + const google::protobuf::Descriptor* absl_nonnull descriptor) { + StructReflection reflection; + ABSL_CHECK_OK(reflection.Initialize(descriptor)); // Crash OK + return reflection; +} + +absl::Status FieldMaskReflection::Initialize( + const google::protobuf::DescriptorPool* absl_nonnull pool) { + CEL_ASSIGN_OR_RETURN(const auto* descriptor, + GetMessageTypeByName(pool, "google.protobuf.FieldMask")); + return Initialize(descriptor); +} + +absl::Status FieldMaskReflection::Initialize( + const google::protobuf::Descriptor* absl_nonnull descriptor) { + if (descriptor_ != descriptor) { + CEL_RETURN_IF_ERROR(CheckWellKnownType(descriptor, kWellKnownType)); + descriptor_ = nullptr; + CEL_ASSIGN_OR_RETURN(paths_field_, GetFieldByNumber(descriptor, 1)); + CEL_RETURN_IF_ERROR( + CheckFieldCppType(paths_field_, FieldDescriptor::CPPTYPE_STRING)); + CEL_RETURN_IF_ERROR( + CheckFieldCardinality(paths_field_, FieldDescriptor::LABEL_REPEATED)); + paths_field_string_type_ = paths_field_->cpp_string_type(); + descriptor_ = descriptor; + } + return absl::OkStatus(); +} + +int FieldMaskReflection::PathsSize(const google::protobuf::Message& message) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); + return message.GetReflection()->FieldSize(message, paths_field_); +} + +StringValue FieldMaskReflection::Paths(const google::protobuf::Message& message, + int index, std::string& scratch) const { + return GetRepeatedStringField( + message, paths_field_, paths_field_string_type_, index, scratch); +} + +absl::StatusOr GetFieldMaskReflection( + const google::protobuf::Descriptor* absl_nonnull descriptor) { + FieldMaskReflection reflection; + CEL_RETURN_IF_ERROR(reflection.Initialize(descriptor)); + return reflection; +} + +absl::Status JsonReflection::Initialize( + const google::protobuf::DescriptorPool* absl_nonnull pool) { + CEL_RETURN_IF_ERROR(Value().Initialize(pool)); + CEL_RETURN_IF_ERROR(ListValue().Initialize(pool)); + CEL_RETURN_IF_ERROR(Struct().Initialize(pool)); + return absl::OkStatus(); +} + +absl::Status JsonReflection::Initialize( + const google::protobuf::Descriptor* absl_nonnull descriptor) { + switch (descriptor->well_known_type()) { + case google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE: + CEL_RETURN_IF_ERROR(Value().Initialize(descriptor)); + CEL_RETURN_IF_ERROR( + ListValue().Initialize(Value().GetListValueDescriptor())); + CEL_RETURN_IF_ERROR(Struct().Initialize(Value().GetStructDescriptor())); + return absl::OkStatus(); + case google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE: + CEL_RETURN_IF_ERROR(ListValue().Initialize(descriptor)); + CEL_RETURN_IF_ERROR(Value().Initialize(ListValue().GetValueDescriptor())); + CEL_RETURN_IF_ERROR(Struct().Initialize(Value().GetStructDescriptor())); + return absl::OkStatus(); + case google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT: + CEL_RETURN_IF_ERROR(Struct().Initialize(descriptor)); + CEL_RETURN_IF_ERROR(Value().Initialize(Struct().GetValueDescriptor())); + CEL_RETURN_IF_ERROR( + ListValue().Initialize(Value().GetListValueDescriptor())); + return absl::OkStatus(); + default: + return absl::InvalidArgumentError( + absl::StrCat("expected message to be JSON-like well known type: ", + descriptor->full_name(), " ", + WellKnownTypeToString(descriptor->well_known_type()))); + } +} + +bool JsonReflection::IsInitialized() const { + return Value().IsInitialized() && ListValue().IsInitialized() && + Struct().IsInitialized(); +} + +namespace { + +[[maybe_unused]] ABSL_CONST_INIT absl::once_flag + link_well_known_message_reflection; + +void LinkWellKnownMessageReflection() { + google::protobuf::LinkMessageReflection(); + google::protobuf::LinkMessageReflection(); + google::protobuf::LinkMessageReflection(); + google::protobuf::LinkMessageReflection(); + google::protobuf::LinkMessageReflection(); + google::protobuf::LinkMessageReflection(); + google::protobuf::LinkMessageReflection(); + google::protobuf::LinkMessageReflection(); + google::protobuf::LinkMessageReflection(); + google::protobuf::LinkMessageReflection(); + google::protobuf::LinkMessageReflection(); + google::protobuf::LinkMessageReflection(); + google::protobuf::LinkMessageReflection(); + google::protobuf::LinkMessageReflection(); + google::protobuf::LinkMessageReflection(); + google::protobuf::LinkMessageReflection(); +} + +} // namespace + +absl::Status Reflection::Initialize(const DescriptorPool* absl_nonnull pool) { + if (pool == DescriptorPool::generated_pool()) { + absl::call_once(link_well_known_message_reflection, + &LinkWellKnownMessageReflection); + } + CEL_RETURN_IF_ERROR(NullValue().Initialize(pool)); + CEL_RETURN_IF_ERROR(BoolValue().Initialize(pool)); + CEL_RETURN_IF_ERROR(Int32Value().Initialize(pool)); + CEL_RETURN_IF_ERROR(Int64Value().Initialize(pool)); + CEL_RETURN_IF_ERROR(UInt32Value().Initialize(pool)); + CEL_RETURN_IF_ERROR(UInt64Value().Initialize(pool)); + CEL_RETURN_IF_ERROR(FloatValue().Initialize(pool)); + CEL_RETURN_IF_ERROR(DoubleValue().Initialize(pool)); + CEL_RETURN_IF_ERROR(BytesValue().Initialize(pool)); + CEL_RETURN_IF_ERROR(StringValue().Initialize(pool)); + CEL_RETURN_IF_ERROR(Any().Initialize(pool)); + CEL_RETURN_IF_ERROR(Duration().Initialize(pool)); + CEL_RETURN_IF_ERROR(Timestamp().Initialize(pool)); + CEL_RETURN_IF_ERROR(Json().Initialize(pool)); + // google.protobuf.FieldMask is not strictly mandatory, but we do have to + // treat it specifically for JSON. So use it if we have it. + if (const auto* descriptor = + pool->FindMessageTypeByName("google.protobuf.FieldMask"); + descriptor != nullptr) { + CEL_RETURN_IF_ERROR(FieldMask().Initialize(descriptor)); + } + return absl::OkStatus(); +} + +bool Reflection::IsInitialized() const { + // Check that everything is initialized except field mask, which is optional. + return NullValue().IsInitialized() && BoolValue().IsInitialized() && + Int32Value().IsInitialized() && Int64Value().IsInitialized() && + UInt32Value().IsInitialized() && UInt64Value().IsInitialized() && + FloatValue().IsInitialized() && DoubleValue().IsInitialized() && + BytesValue().IsInitialized() && StringValue().IsInitialized() && + Any().IsInitialized() && Duration().IsInitialized() && + Timestamp().IsInitialized() && Json().IsInitialized(); +} + +namespace { + +// AdaptListValue verifies the message is the well known type +// `google.protobuf.ListValue` and performs the complicated logic of reimaging +// it as `ListValue`. If adapted is empty, we return as a reference. If adapted +// is present, message must be a reference to the value held in adapted and it +// will be returned by value. +absl::StatusOr AdaptListValue(google::protobuf::Arena* absl_nullable arena, + const google::protobuf::Message& message, + Unique adapted) { + ABSL_DCHECK(!adapted || &message == cel::to_address(adapted)); + const auto* descriptor = message.GetDescriptor(); + if (ABSL_PREDICT_FALSE(descriptor == nullptr)) { + return absl::InvalidArgumentError( + absl::StrCat("missing descriptor for protocol buffer message: ", + message.GetTypeName())); + } + // Not much to do. Just verify the well known type is well-formed. + CEL_RETURN_IF_ERROR(GetListValueReflection(descriptor).status()); + if (adapted) { + return ListValue(std::move(adapted)); + } + return ListValue(std::cref(message)); +} + +// AdaptStruct verifies the message is the well known type +// `google.protobuf.Struct` and performs the complicated logic of reimaging it +// as `Struct`. If adapted is empty, we return as a reference. If adapted is +// present, message must be a reference to the value held in adapted and it will +// be returned by value. +absl::StatusOr AdaptStruct(google::protobuf::Arena* absl_nullable arena, + const google::protobuf::Message& message, + Unique adapted) { + ABSL_DCHECK(!adapted || &message == cel::to_address(adapted)); + const auto* descriptor = message.GetDescriptor(); + if (ABSL_PREDICT_FALSE(descriptor == nullptr)) { + return absl::InvalidArgumentError( + absl::StrCat("missing descriptor for protocol buffer message: ", + message.GetTypeName())); + } + // Not much to do. Just verify the well known type is well-formed. + CEL_RETURN_IF_ERROR(GetStructReflection(descriptor).status()); + if (adapted) { + return Struct(std::move(adapted)); + } + return Struct(std::cref(message)); +} + +// AdaptAny recursively unpacks a protocol buffer message which is an instance +// of `google.protobuf.Any`. +absl::StatusOr> AdaptAny( + google::protobuf::Arena* absl_nullable arena, AnyReflection& reflection, + const google::protobuf::Message& message, const Descriptor* absl_nonnull descriptor, + const DescriptorPool* absl_nonnull pool, + google::protobuf::MessageFactory* absl_nonnull factory, bool error_if_unresolveable) { + ABSL_DCHECK_EQ(descriptor->well_known_type(), Descriptor::WELLKNOWNTYPE_ANY); + const google::protobuf::Message* absl_nonnull to_unwrap = &message; + Unique unwrapped; + std::string type_url_scratch; + std::string value_scratch; + do { + CEL_RETURN_IF_ERROR(reflection.Initialize(descriptor)); + StringValue type_url = reflection.GetTypeUrl(*to_unwrap, type_url_scratch); + absl::string_view type_url_view = + FlatStringValue(type_url, type_url_scratch); + if (!absl::ConsumePrefix(&type_url_view, "type.googleapis.com/") && + !absl::ConsumePrefix(&type_url_view, "type.googleprod.com/")) { + if (!error_if_unresolveable) { + break; + } + return absl::InvalidArgumentError(absl::StrCat( + "unable to find descriptor for type URL: ", type_url_view)); + } + const auto* packed_descriptor = pool->FindMessageTypeByName(type_url_view); + if (packed_descriptor == nullptr) { + if (!error_if_unresolveable) { + break; + } + return absl::InvalidArgumentError(absl::StrCat( + "unable to find descriptor for type name: ", type_url_view)); + } + const auto* prototype = factory->GetPrototype(packed_descriptor); + if (prototype == nullptr) { + return absl::InvalidArgumentError(absl::StrCat( + "unable to build prototype for type name: ", type_url_view)); + } + BytesValue value = reflection.GetValue(*to_unwrap, value_scratch); + Unique unpacked = WrapUnique(prototype->New(arena), arena); + const bool ok = absl::visit(absl::Overload( + [&](absl::string_view string) -> bool { + return unpacked->ParseFromString(string); + }, + [&](const absl::Cord& cord) -> bool { + return unpacked->ParseFromString(cord); + }), + AsVariant(value)); + if (!ok) { + return absl::InvalidArgumentError(absl::StrCat( + "failed to unpack protocol buffer message: ", type_url_view)); + } + // We can only update unwrapped at this point, not before. This is because + // we could have been unpacking from unwrapped itself. + unwrapped = std::move(unpacked); + to_unwrap = cel::to_address(unwrapped); + descriptor = to_unwrap->GetDescriptor(); + if (descriptor == nullptr) { + return absl::InvalidArgumentError( + absl::StrCat("missing descriptor for protocol buffer message: ", + to_unwrap->GetTypeName())); + } + } while (descriptor->well_known_type() == Descriptor::WELLKNOWNTYPE_ANY); + return unwrapped; +} + +} // namespace + +absl::StatusOr> UnpackAnyFrom( + google::protobuf::Arena* absl_nullable arena, AnyReflection& reflection, + const google::protobuf::Message& message, + const google::protobuf::DescriptorPool* absl_nonnull pool, + google::protobuf::MessageFactory* absl_nonnull factory) { + ABSL_DCHECK_EQ(message.GetDescriptor()->well_known_type(), + Descriptor::WELLKNOWNTYPE_ANY); + return AdaptAny(arena, reflection, message, message.GetDescriptor(), pool, + factory, /*error_if_unresolveable=*/true); +} + +absl::StatusOr> UnpackAnyIfResolveable( + google::protobuf::Arena* absl_nullable arena, AnyReflection& reflection, + const google::protobuf::Message& message, + const google::protobuf::DescriptorPool* absl_nonnull pool, + google::protobuf::MessageFactory* absl_nonnull factory) { + ABSL_DCHECK_EQ(message.GetDescriptor()->well_known_type(), + Descriptor::WELLKNOWNTYPE_ANY); + return AdaptAny(arena, reflection, message, message.GetDescriptor(), pool, + factory, /*error_if_unresolveable=*/false); +} + +absl::StatusOr AdaptFromMessage( + google::protobuf::Arena* absl_nullable arena, const google::protobuf::Message& message, + const DescriptorPool* absl_nonnull pool, + google::protobuf::MessageFactory* absl_nonnull factory, std::string& scratch) { + const auto* descriptor = message.GetDescriptor(); + if (ABSL_PREDICT_FALSE(descriptor == nullptr)) { + return absl::InvalidArgumentError( + absl::StrCat("missing descriptor for protocol buffer message: ", + message.GetTypeName())); + } + const google::protobuf::Message* absl_nonnull to_adapt; + Unique adapted; + Descriptor::WellKnownType well_known_type = descriptor->well_known_type(); + if (well_known_type == Descriptor::WELLKNOWNTYPE_ANY) { + AnyReflection reflection; + CEL_ASSIGN_OR_RETURN( + adapted, UnpackAnyFrom(arena, reflection, message, pool, factory)); + to_adapt = cel::to_address(adapted); + // GetDescriptor() is guaranteed to be nonnull by AdaptAny(). + descriptor = to_adapt->GetDescriptor(); + well_known_type = descriptor->well_known_type(); + } else { + to_adapt = &message; + } + switch (descriptor->well_known_type()) { + case Descriptor::WELLKNOWNTYPE_DOUBLEVALUE: { + CEL_ASSIGN_OR_RETURN(auto reflection, + GetDoubleValueReflection(descriptor)); + return reflection.GetValue(*to_adapt); + } + case Descriptor::WELLKNOWNTYPE_FLOATVALUE: { + CEL_ASSIGN_OR_RETURN(auto reflection, + GetFloatValueReflection(descriptor)); + return reflection.GetValue(*to_adapt); + } + case Descriptor::WELLKNOWNTYPE_INT64VALUE: { + CEL_ASSIGN_OR_RETURN(auto reflection, + GetInt64ValueReflection(descriptor)); + return reflection.GetValue(*to_adapt); + } + case Descriptor::WELLKNOWNTYPE_UINT64VALUE: { + CEL_ASSIGN_OR_RETURN(auto reflection, + GetUInt64ValueReflection(descriptor)); + return reflection.GetValue(*to_adapt); + } + case Descriptor::WELLKNOWNTYPE_INT32VALUE: { + CEL_ASSIGN_OR_RETURN(auto reflection, + GetInt32ValueReflection(descriptor)); + return reflection.GetValue(*to_adapt); + } + case Descriptor::WELLKNOWNTYPE_UINT32VALUE: { + CEL_ASSIGN_OR_RETURN(auto reflection, + GetUInt32ValueReflection(descriptor)); + return reflection.GetValue(*to_adapt); + } + case Descriptor::WELLKNOWNTYPE_STRINGVALUE: { + CEL_ASSIGN_OR_RETURN(auto reflection, + GetStringValueReflection(descriptor)); + auto value = reflection.GetValue(*to_adapt, scratch); + if (adapted) { + // value might actually be a view of data owned by adapted, force a copy + // to scratch if that is the case. + value = CopyStringValue(value, scratch); + } + return value; + } + case Descriptor::WELLKNOWNTYPE_BYTESVALUE: { + CEL_ASSIGN_OR_RETURN(auto reflection, + GetBytesValueReflection(descriptor)); + auto value = reflection.GetValue(*to_adapt, scratch); + if (adapted) { + // value might actually be a view of data owned by adapted, force a copy + // to scratch if that is the case. + value = CopyBytesValue(value, scratch); + } + return value; + } + case Descriptor::WELLKNOWNTYPE_BOOLVALUE: { + CEL_ASSIGN_OR_RETURN(auto reflection, GetBoolValueReflection(descriptor)); + return reflection.GetValue(*to_adapt); + } + case Descriptor::WELLKNOWNTYPE_ANY: + // This is unreachable, as AdaptAny() above recursively unpacks. + ABSL_UNREACHABLE(); + case Descriptor::WELLKNOWNTYPE_DURATION: { + CEL_ASSIGN_OR_RETURN(auto reflection, GetDurationReflection(descriptor)); + return reflection.ToAbslDuration(*to_adapt); + } + case Descriptor::WELLKNOWNTYPE_TIMESTAMP: { + CEL_ASSIGN_OR_RETURN(auto reflection, GetTimestampReflection(descriptor)); + return reflection.ToAbslTime(*to_adapt); + } + case Descriptor::WELLKNOWNTYPE_VALUE: { + CEL_ASSIGN_OR_RETURN(auto reflection, GetValueReflection(descriptor)); + const auto kind_case = reflection.GetKindCase(*to_adapt); + switch (kind_case) { + case google::protobuf::Value::KIND_NOT_SET: + ABSL_FALLTHROUGH_INTENDED; + case google::protobuf::Value::kNullValue: + return nullptr; + case google::protobuf::Value::kNumberValue: + return reflection.GetNumberValue(*to_adapt); + case google::protobuf::Value::kStringValue: { + auto value = reflection.GetStringValue(*to_adapt, scratch); + if (adapted) { + value = CopyStringValue(value, scratch); + } + return value; + } + case google::protobuf::Value::kBoolValue: + return reflection.GetBoolValue(*to_adapt); + case google::protobuf::Value::kStructValue: { + if (adapted) { + // We can release. + adapted = reflection.ReleaseStructValue(cel::to_address(adapted)); + to_adapt = cel::to_address(adapted); + } else { + to_adapt = &reflection.GetStructValue(*to_adapt); + } + return AdaptStruct(arena, *to_adapt, std::move(adapted)); + } + case google::protobuf::Value::kListValue: { + if (adapted) { + // We can release. + adapted = reflection.ReleaseListValue(cel::to_address(adapted)); + to_adapt = cel::to_address(adapted); + } else { + to_adapt = &reflection.GetListValue(*to_adapt); + } + return AdaptListValue(arena, *to_adapt, std::move(adapted)); + } + default: + return absl::InvalidArgumentError( + absl::StrCat("unexpected value kind case: ", kind_case)); + } + } + case Descriptor::WELLKNOWNTYPE_LISTVALUE: + return AdaptListValue(arena, *to_adapt, std::move(adapted)); + case Descriptor::WELLKNOWNTYPE_STRUCT: + return AdaptStruct(arena, *to_adapt, std::move(adapted)); + default: + if (adapted) { + return adapted; + } + return std::monostate{}; + } +} + +} // namespace cel::well_known_types diff --git a/internal/well_known_types.h b/internal/well_known_types.h new file mode 100644 index 000000000..f63e5e76b --- /dev/null +++ b/internal/well_known_types.h @@ -0,0 +1,1593 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file provides handling for well known protocol buffer types, which is +// agnostic to whether the types are dynamic or generated. It also performs +// exhaustive verification of the structure of the well known message types, +// ensuring they will work as intended throughout the rest of our codebase. +// +// For each well know type, there is a class `XReflection` where `X` is the +// unqualified well know type name. Each class can be initialized from a +// descriptor pool or a descriptor. Once initialized, they can be used with +// messages which use that exact descriptor. Using them with a different version +// of the descriptor from a separate descriptor pool results in undefined +// behavior. If unsure, you can initialize multiple times. If initializing with +// the same descriptor, it is a noop. + +#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_WELL_KNOWN_TYPES_H_ +#define THIRD_PARTY_CEL_CPP_INTERNAL_WELL_KNOWN_TYPES_H_ + +#include +#include +#include +#include + +#include "google/protobuf/any.pb.h" +#include "google/protobuf/duration.pb.h" +#include "google/protobuf/field_mask.pb.h" +#include "google/protobuf/struct.pb.h" +#include "google/protobuf/timestamp.pb.h" +#include "google/protobuf/wrappers.pb.h" +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "absl/types/variant.h" +#include "common/any.h" +#include "common/memory.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/map_field.h" +#include "google/protobuf/message.h" +#include "google/protobuf/reflection.h" + +namespace cel::well_known_types { + +// Strongly typed variant capable of holding the value representation of any +// protocol buffer message string field. We do this instead of type aliasing to +// avoid collisions in other variants such as `well_known_types::Value`. +class StringValue final : public absl::variant { + public: + using absl::variant::variant; + + bool ConsumePrefix(absl::string_view prefix); +}; + +// Older versions of GCC do not deal with inheriting from variant correctly when +// using `visit`, so we cheat by upcasting. +inline const absl::variant& AsVariant( + const StringValue& value) { + return static_cast&>( + value); +} +inline absl::variant& AsVariant( + StringValue& value) { + return static_cast&>(value); +} +inline const absl::variant&& AsVariant( + const StringValue&& value) { + return static_cast&&>( + value); +} +inline absl::variant&& AsVariant( + StringValue&& value) { + return static_cast&&>(value); +} + +inline bool operator==(const StringValue& lhs, const StringValue& rhs) { + return absl::visit( + [](const auto& lhs, const auto& rhs) { return lhs == rhs; }, + AsVariant(lhs), AsVariant(rhs)); +} + +inline bool operator!=(const StringValue& lhs, const StringValue& rhs) { + return !operator==(lhs, rhs); +} + +template +void AbslStringify(S& sink, const StringValue& value) { + sink.Append(absl::visit( + [&](const auto& value) -> std::string { return absl::StrCat(value); }, + AsVariant(value))); +} + +StringValue GetStringField(const google::protobuf::Reflection* absl_nonnull reflection, + const google::protobuf::Message& message + ABSL_ATTRIBUTE_LIFETIME_BOUND, + const google::protobuf::FieldDescriptor* absl_nonnull field, + std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND); +inline StringValue GetStringField( + const google::protobuf::Message& message ABSL_ATTRIBUTE_LIFETIME_BOUND, + const google::protobuf::FieldDescriptor* absl_nonnull field, + std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) { + return GetStringField(message.GetReflection(), message, field, scratch); +} + +StringValue GetRepeatedStringField( + const google::protobuf::Reflection* absl_nonnull reflection, + const google::protobuf::Message& message ABSL_ATTRIBUTE_LIFETIME_BOUND, + const google::protobuf::FieldDescriptor* absl_nonnull field, int index, + std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND); +inline StringValue GetRepeatedStringField( + const google::protobuf::Message& message ABSL_ATTRIBUTE_LIFETIME_BOUND, + const google::protobuf::FieldDescriptor* absl_nonnull field, int index, + std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) { + return GetRepeatedStringField(message.GetReflection(), message, field, index, + scratch); +} + +// Strongly typed variant capable of holding the value representation of any +// protocol buffer message bytes field. We do this instead of type aliasing to +// avoid collisions in other variants such as `well_known_types::Value`. +class BytesValue final : public absl::variant { + public: + using absl::variant::variant; +}; + +// Older versions of GCC do not deal with inheriting from variant correctly when +// using `visit`, so we cheat by upcasting. +inline const absl::variant& AsVariant( + const BytesValue& value) { + return static_cast&>( + value); +} +inline absl::variant& AsVariant( + BytesValue& value) { + return static_cast&>(value); +} +inline const absl::variant&& AsVariant( + const BytesValue&& value) { + return static_cast&&>( + value); +} +inline absl::variant&& AsVariant( + BytesValue&& value) { + return static_cast&&>(value); +} + +inline bool operator==(const BytesValue& lhs, const BytesValue& rhs) { + return absl::visit( + [](const auto& lhs, const auto& rhs) { return lhs == rhs; }, + AsVariant(lhs), AsVariant(rhs)); +} + +inline bool operator!=(const BytesValue& lhs, const BytesValue& rhs) { + return !operator==(lhs, rhs); +} + +template +void AbslStringify(S& sink, const BytesValue& value) { + sink.Append(absl::visit( + [&](const auto& value) -> std::string { return absl::StrCat(value); }, + AsVariant(value))); +} + +BytesValue GetBytesField(const google::protobuf::Reflection* absl_nonnull reflection, + const google::protobuf::Message& message + ABSL_ATTRIBUTE_LIFETIME_BOUND, + const google::protobuf::FieldDescriptor* absl_nonnull field, + std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND); +inline BytesValue GetBytesField( + const google::protobuf::Message& message ABSL_ATTRIBUTE_LIFETIME_BOUND, + const google::protobuf::FieldDescriptor* absl_nonnull field, + std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) { + return GetBytesField(message.GetReflection(), message, field, scratch); +} + +BytesValue GetRepeatedBytesField( + const google::protobuf::Reflection* absl_nonnull reflection, + const google::protobuf::Message& message ABSL_ATTRIBUTE_LIFETIME_BOUND, + const google::protobuf::FieldDescriptor* absl_nonnull field, int index, + std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND); +inline BytesValue GetRepeatedBytesField( + const google::protobuf::Message& message ABSL_ATTRIBUTE_LIFETIME_BOUND, + const google::protobuf::FieldDescriptor* absl_nonnull field, int index, + std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) { + return GetRepeatedBytesField(message.GetReflection(), message, field, index, + scratch); +} + +class NullValueReflection final { + public: + NullValueReflection() = default; + NullValueReflection(const NullValueReflection&) = default; + NullValueReflection& operator=(const NullValueReflection&) = default; + + absl::Status Initialize(const google::protobuf::DescriptorPool* absl_nonnull pool); + + absl::Status Initialize( + const google::protobuf::EnumDescriptor* absl_nonnull descriptor); + + bool IsInitialized() const { return descriptor_ != nullptr; } + + private: + const google::protobuf::EnumDescriptor* absl_nullable descriptor_ = nullptr; + const google::protobuf::EnumValueDescriptor* absl_nullable value_ = nullptr; +}; + +class BoolValueReflection final { + public: + static constexpr google::protobuf::Descriptor::WellKnownType kWellKnownType = + google::protobuf::Descriptor::WELLKNOWNTYPE_BOOLVALUE; + + using GeneratedMessageType = google::protobuf::BoolValue; + + static bool GetValue(const GeneratedMessageType& message) { + return message.value(); + } + + static void SetValue(GeneratedMessageType* absl_nonnull message, bool value) { + message->set_value(value); + } + + BoolValueReflection() = default; + BoolValueReflection(const BoolValueReflection&) = default; + BoolValueReflection& operator=(const BoolValueReflection&) = default; + + absl::Status Initialize(const google::protobuf::DescriptorPool* absl_nonnull pool); + + absl::Status Initialize(const google::protobuf::Descriptor* absl_nonnull descriptor); + + bool IsInitialized() const { return descriptor_ != nullptr; } + + const google::protobuf::Descriptor* absl_nonnull GetDescriptor() const { + ABSL_DCHECK(IsInitialized()); + return descriptor_; + } + + bool GetValue(const google::protobuf::Message& message) const; + + void SetValue(google::protobuf::Message* absl_nonnull message, bool value) const; + + private: + const google::protobuf::Descriptor* absl_nullable descriptor_ = nullptr; + const google::protobuf::FieldDescriptor* absl_nullable value_field_ = nullptr; +}; + +absl::StatusOr GetBoolValueReflection( + const google::protobuf::Descriptor* absl_nonnull descriptor + ABSL_ATTRIBUTE_LIFETIME_BOUND); + +class Int32ValueReflection final { + public: + static constexpr google::protobuf::Descriptor::WellKnownType kWellKnownType = + google::protobuf::Descriptor::WELLKNOWNTYPE_INT32VALUE; + + using GeneratedMessageType = google::protobuf::Int32Value; + + static int32_t GetValue(const GeneratedMessageType& message) { + return message.value(); + } + + static void SetValue(GeneratedMessageType* absl_nonnull message, + int32_t value) { + message->set_value(value); + } + + Int32ValueReflection() = default; + Int32ValueReflection(const Int32ValueReflection&) = default; + Int32ValueReflection& operator=(const Int32ValueReflection&) = default; + + absl::Status Initialize(const google::protobuf::DescriptorPool* absl_nonnull pool); + + absl::Status Initialize(const google::protobuf::Descriptor* absl_nonnull descriptor); + + bool IsInitialized() const { return descriptor_ != nullptr; } + + const google::protobuf::Descriptor* absl_nonnull GetDescriptor() const { + ABSL_DCHECK(IsInitialized()); + return descriptor_; + } + + int32_t GetValue(const google::protobuf::Message& message) const; + + void SetValue(google::protobuf::Message* absl_nonnull message, int32_t value) const; + + private: + const google::protobuf::Descriptor* absl_nullable descriptor_ = nullptr; + const google::protobuf::FieldDescriptor* absl_nullable value_field_ = nullptr; +}; + +absl::StatusOr GetInt32ValueReflection( + const google::protobuf::Descriptor* absl_nonnull descriptor + ABSL_ATTRIBUTE_LIFETIME_BOUND); + +class Int64ValueReflection final { + public: + static constexpr google::protobuf::Descriptor::WellKnownType kWellKnownType = + google::protobuf::Descriptor::WELLKNOWNTYPE_INT64VALUE; + + using GeneratedMessageType = google::protobuf::Int64Value; + + static int64_t GetValue(const GeneratedMessageType& message) { + return message.value(); + } + + static void SetValue(GeneratedMessageType* absl_nonnull message, + int64_t value) { + message->set_value(value); + } + + Int64ValueReflection() = default; + Int64ValueReflection(const Int64ValueReflection&) = default; + Int64ValueReflection& operator=(const Int64ValueReflection&) = default; + + absl::Status Initialize(const google::protobuf::DescriptorPool* absl_nonnull pool); + + absl::Status Initialize(const google::protobuf::Descriptor* absl_nonnull descriptor); + + bool IsInitialized() const { return descriptor_ != nullptr; } + + const google::protobuf::Descriptor* absl_nonnull GetDescriptor() const { + ABSL_DCHECK(IsInitialized()); + return descriptor_; + } + + int64_t GetValue(const google::protobuf::Message& message) const; + + void SetValue(google::protobuf::Message* absl_nonnull message, int64_t value) const; + + private: + const google::protobuf::Descriptor* absl_nullable descriptor_ = nullptr; + const google::protobuf::FieldDescriptor* absl_nullable value_field_ = nullptr; +}; + +absl::StatusOr GetInt64ValueReflection( + const google::protobuf::Descriptor* absl_nonnull descriptor + ABSL_ATTRIBUTE_LIFETIME_BOUND); + +class UInt32ValueReflection final { + public: + static constexpr google::protobuf::Descriptor::WellKnownType kWellKnownType = + google::protobuf::Descriptor::WELLKNOWNTYPE_UINT32VALUE; + + using GeneratedMessageType = google::protobuf::UInt32Value; + + static uint32_t GetValue(const GeneratedMessageType& message) { + return message.value(); + } + + static void SetValue(GeneratedMessageType* absl_nonnull message, + uint32_t value) { + message->set_value(value); + } + + UInt32ValueReflection() = default; + UInt32ValueReflection(const UInt32ValueReflection&) = default; + UInt32ValueReflection& operator=(const UInt32ValueReflection&) = default; + + absl::Status Initialize(const google::protobuf::DescriptorPool* absl_nonnull pool); + + absl::Status Initialize(const google::protobuf::Descriptor* absl_nonnull descriptor); + + bool IsInitialized() const { return descriptor_ != nullptr; } + + const google::protobuf::Descriptor* absl_nonnull GetDescriptor() const { + ABSL_DCHECK(IsInitialized()); + return descriptor_; + } + + uint32_t GetValue(const google::protobuf::Message& message) const; + + void SetValue(google::protobuf::Message* absl_nonnull message, uint32_t value) const; + + private: + const google::protobuf::Descriptor* absl_nullable descriptor_ = nullptr; + const google::protobuf::FieldDescriptor* absl_nullable value_field_ = nullptr; +}; + +absl::StatusOr GetUInt32ValueReflection( + const google::protobuf::Descriptor* absl_nonnull descriptor + ABSL_ATTRIBUTE_LIFETIME_BOUND); + +class UInt64ValueReflection final { + public: + static constexpr google::protobuf::Descriptor::WellKnownType kWellKnownType = + google::protobuf::Descriptor::WELLKNOWNTYPE_UINT64VALUE; + + using GeneratedMessageType = google::protobuf::UInt64Value; + + static uint64_t GetValue(const GeneratedMessageType& message) { + return message.value(); + } + + static void SetValue(GeneratedMessageType* absl_nonnull message, + uint64_t value) { + message->set_value(value); + } + + UInt64ValueReflection() = default; + UInt64ValueReflection(const UInt64ValueReflection&) = default; + UInt64ValueReflection& operator=(const UInt64ValueReflection&) = default; + + absl::Status Initialize(const google::protobuf::DescriptorPool* absl_nonnull pool); + + absl::Status Initialize(const google::protobuf::Descriptor* absl_nonnull descriptor); + + bool IsInitialized() const { return descriptor_ != nullptr; } + + const google::protobuf::Descriptor* absl_nonnull GetDescriptor() const { + ABSL_DCHECK(IsInitialized()); + return descriptor_; + } + + uint64_t GetValue(const google::protobuf::Message& message) const; + + void SetValue(google::protobuf::Message* absl_nonnull message, uint64_t value) const; + + private: + const google::protobuf::Descriptor* absl_nullable descriptor_ = nullptr; + const google::protobuf::FieldDescriptor* absl_nullable value_field_ = nullptr; +}; + +absl::StatusOr GetUInt64ValueReflection( + const google::protobuf::Descriptor* absl_nonnull descriptor + ABSL_ATTRIBUTE_LIFETIME_BOUND); + +class FloatValueReflection final { + public: + static constexpr google::protobuf::Descriptor::WellKnownType kWellKnownType = + google::protobuf::Descriptor::WELLKNOWNTYPE_FLOATVALUE; + + using GeneratedMessageType = google::protobuf::FloatValue; + + static float GetValue(const GeneratedMessageType& message) { + return message.value(); + } + + static void SetValue(GeneratedMessageType* absl_nonnull message, + float value) { + message->set_value(value); + } + + FloatValueReflection() = default; + FloatValueReflection(const FloatValueReflection&) = default; + FloatValueReflection& operator=(const FloatValueReflection&) = default; + + absl::Status Initialize(const google::protobuf::DescriptorPool* absl_nonnull pool); + + absl::Status Initialize(const google::protobuf::Descriptor* absl_nonnull descriptor); + + bool IsInitialized() const { return descriptor_ != nullptr; } + + const google::protobuf::Descriptor* absl_nonnull GetDescriptor() const { + ABSL_DCHECK(IsInitialized()); + return descriptor_; + } + + float GetValue(const google::protobuf::Message& message) const; + + void SetValue(google::protobuf::Message* absl_nonnull message, float value) const; + + private: + const google::protobuf::Descriptor* absl_nullable descriptor_ = nullptr; + const google::protobuf::FieldDescriptor* absl_nullable value_field_ = nullptr; +}; + +absl::StatusOr GetFloatValueReflection( + const google::protobuf::Descriptor* absl_nonnull descriptor + ABSL_ATTRIBUTE_LIFETIME_BOUND); + +class DoubleValueReflection final { + public: + static constexpr google::protobuf::Descriptor::WellKnownType kWellKnownType = + google::protobuf::Descriptor::WELLKNOWNTYPE_DOUBLEVALUE; + + using GeneratedMessageType = google::protobuf::DoubleValue; + + static double GetValue(const GeneratedMessageType& message) { + return message.value(); + } + + static void SetValue(GeneratedMessageType* absl_nonnull message, + double value) { + message->set_value(value); + } + + DoubleValueReflection() = default; + DoubleValueReflection(const DoubleValueReflection&) = default; + DoubleValueReflection& operator=(const DoubleValueReflection&) = default; + + absl::Status Initialize(const google::protobuf::DescriptorPool* absl_nonnull pool); + + absl::Status Initialize(const google::protobuf::Descriptor* absl_nonnull descriptor); + + bool IsInitialized() const { return descriptor_ != nullptr; } + + const google::protobuf::Descriptor* absl_nonnull GetDescriptor() const { + ABSL_DCHECK(IsInitialized()); + return descriptor_; + } + + double GetValue(const google::protobuf::Message& message) const; + + void SetValue(google::protobuf::Message* absl_nonnull message, double value) const; + + private: + const google::protobuf::Descriptor* absl_nullable descriptor_ = nullptr; + const google::protobuf::FieldDescriptor* absl_nullable value_field_ = nullptr; +}; + +absl::StatusOr GetDoubleValueReflection( + const google::protobuf::Descriptor* absl_nonnull descriptor + ABSL_ATTRIBUTE_LIFETIME_BOUND); + +class BytesValueReflection final { + public: + static constexpr google::protobuf::Descriptor::WellKnownType kWellKnownType = + google::protobuf::Descriptor::WELLKNOWNTYPE_BYTESVALUE; + + using GeneratedMessageType = google::protobuf::BytesValue; + + static absl::Cord GetValue(const GeneratedMessageType& message) { + return absl::Cord(message.value()); + } + + static void SetValue(GeneratedMessageType* absl_nonnull message, + const absl::Cord& value) { + message->set_value(static_cast(value)); + } + + BytesValueReflection() = default; + BytesValueReflection(const BytesValueReflection&) = default; + BytesValueReflection& operator=(const BytesValueReflection&) = default; + + absl::Status Initialize(const google::protobuf::DescriptorPool* absl_nonnull pool); + + absl::Status Initialize(const google::protobuf::Descriptor* absl_nonnull descriptor); + + bool IsInitialized() const { return descriptor_ != nullptr; } + + const google::protobuf::Descriptor* absl_nonnull GetDescriptor() const { + ABSL_DCHECK(IsInitialized()); + return descriptor_; + } + + BytesValue GetValue(const google::protobuf::Message& message + ABSL_ATTRIBUTE_LIFETIME_BOUND, + std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) const; + + void SetValue(google::protobuf::Message* absl_nonnull message, + absl::string_view value) const; + + void SetValue(google::protobuf::Message* absl_nonnull message, + const absl::Cord& value) const; + + private: + const google::protobuf::Descriptor* absl_nullable descriptor_ = nullptr; + const google::protobuf::FieldDescriptor* absl_nullable value_field_ = nullptr; + google::protobuf::FieldDescriptor::CppStringType value_field_string_type_; +}; + +absl::StatusOr GetBytesValueReflection( + const google::protobuf::Descriptor* absl_nonnull descriptor + ABSL_ATTRIBUTE_LIFETIME_BOUND); + +class StringValueReflection final { + public: + static constexpr google::protobuf::Descriptor::WellKnownType kWellKnownType = + google::protobuf::Descriptor::WELLKNOWNTYPE_STRINGVALUE; + + using GeneratedMessageType = google::protobuf::StringValue; + + static absl::string_view GetValue( + const GeneratedMessageType& message ABSL_ATTRIBUTE_LIFETIME_BOUND) { + return message.value(); + } + + static void SetValue(GeneratedMessageType* absl_nonnull message, + absl::string_view value) { + message->set_value(value); + } + + StringValueReflection() = default; + StringValueReflection(const StringValueReflection&) = default; + StringValueReflection& operator=(const StringValueReflection&) = default; + + absl::Status Initialize(const google::protobuf::DescriptorPool* absl_nonnull pool); + + absl::Status Initialize(const google::protobuf::Descriptor* absl_nonnull descriptor); + + bool IsInitialized() const { return descriptor_ != nullptr; } + + const google::protobuf::Descriptor* absl_nonnull GetDescriptor() const { + ABSL_DCHECK(IsInitialized()); + return descriptor_; + } + + StringValue GetValue( + const google::protobuf::Message& message ABSL_ATTRIBUTE_LIFETIME_BOUND, + std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) const; + + void SetValue(google::protobuf::Message* absl_nonnull message, + absl::string_view value) const; + + void SetValue(google::protobuf::Message* absl_nonnull message, + const absl::Cord& value) const; + + private: + const google::protobuf::Descriptor* absl_nullable descriptor_ = nullptr; + const google::protobuf::FieldDescriptor* absl_nullable value_field_ = nullptr; + google::protobuf::FieldDescriptor::CppStringType value_field_string_type_; +}; + +absl::StatusOr GetStringValueReflection( + const google::protobuf::Descriptor* absl_nonnull descriptor + ABSL_ATTRIBUTE_LIFETIME_BOUND); + +class AnyReflection final { + public: + static constexpr google::protobuf::Descriptor::WellKnownType kWellKnownType = + google::protobuf::Descriptor::WELLKNOWNTYPE_ANY; + + using GeneratedMessageType = google::protobuf::Any; + + static absl::string_view GetTypeUrl( + const GeneratedMessageType& message ABSL_ATTRIBUTE_LIFETIME_BOUND) { + return message.type_url(); + } + + static absl::Cord GetValue(const GeneratedMessageType& message) { + return GetAnyValueAsCord(message); + } + + static void SetTypeUrl(GeneratedMessageType* absl_nonnull message, + absl::string_view type_url) { + message->set_type_url(type_url); + } + + static void SetValue(GeneratedMessageType* absl_nonnull message, + const absl::Cord& value) { + SetAnyValueFromCord(message, value); + } + + AnyReflection() = default; + AnyReflection(const AnyReflection&) = default; + AnyReflection& operator=(const AnyReflection&) = default; + + absl::Status Initialize(const google::protobuf::DescriptorPool* absl_nonnull pool); + + absl::Status Initialize(const google::protobuf::Descriptor* absl_nonnull descriptor); + + bool IsInitialized() const { return descriptor_ != nullptr; } + + const google::protobuf::Descriptor* absl_nonnull GetDescriptor() const { + ABSL_DCHECK(IsInitialized()); + return descriptor_; + } + + void SetTypeUrl(google::protobuf::Message* absl_nonnull message, + absl::string_view type_url) const; + + void SetValue(google::protobuf::Message* absl_nonnull message, + const absl::Cord& value) const; + + StringValue GetTypeUrl( + const google::protobuf::Message& message ABSL_ATTRIBUTE_LIFETIME_BOUND, + std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) const; + + BytesValue GetValue(const google::protobuf::Message& message + ABSL_ATTRIBUTE_LIFETIME_BOUND, + std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) const; + + private: + const google::protobuf::Descriptor* absl_nullable descriptor_ = nullptr; + const google::protobuf::FieldDescriptor* absl_nullable type_url_field_ = nullptr; + const google::protobuf::FieldDescriptor* absl_nullable value_field_ = nullptr; + google::protobuf::FieldDescriptor::CppStringType type_url_field_string_type_; + google::protobuf::FieldDescriptor::CppStringType value_field_string_type_; +}; + +absl::StatusOr GetAnyReflection( + const google::protobuf::Descriptor* absl_nonnull descriptor + ABSL_ATTRIBUTE_LIFETIME_BOUND); + +AnyReflection GetAnyReflectionOrDie( + const google::protobuf::Descriptor* absl_nonnull descriptor + ABSL_ATTRIBUTE_LIFETIME_BOUND); + +class DurationReflection final { + public: + static constexpr google::protobuf::Descriptor::WellKnownType kWellKnownType = + google::protobuf::Descriptor::WELLKNOWNTYPE_DURATION; + + using GeneratedMessageType = google::protobuf::Duration; + + static int64_t GetSeconds(const GeneratedMessageType& message) { + return message.seconds(); + } + + static int64_t GetNanos(const GeneratedMessageType& message) { + return message.nanos(); + } + + static void SetSeconds(GeneratedMessageType* absl_nonnull message, + int64_t value) { + message->set_seconds(value); + } + + static void SetNanos(GeneratedMessageType* absl_nonnull message, + int32_t value) { + message->set_nanos(value); + } + + static absl::Status SetFromAbslDuration( + GeneratedMessageType* absl_nonnull message, absl::Duration duration); + + DurationReflection() = default; + DurationReflection(const DurationReflection&) = default; + DurationReflection& operator=(const DurationReflection&) = default; + + absl::Status Initialize(const google::protobuf::DescriptorPool* absl_nonnull pool); + + absl::Status Initialize(const google::protobuf::Descriptor* absl_nonnull descriptor); + + bool IsInitialized() const { return descriptor_ != nullptr; } + + const google::protobuf::Descriptor* absl_nonnull GetDescriptor() const { + ABSL_DCHECK(IsInitialized()); + return descriptor_; + } + + int64_t GetSeconds(const google::protobuf::Message& message) const; + + int32_t GetNanos(const google::protobuf::Message& message) const; + + void SetSeconds(google::protobuf::Message* absl_nonnull message, int64_t value) const; + + void SetNanos(google::protobuf::Message* absl_nonnull message, int32_t value) const; + + absl::Status SetFromAbslDuration(google::protobuf::Message* absl_nonnull message, + absl::Duration duration) const; + + // Converts `absl::Duration` to `google.protobuf.Duration` without performing + // validity checks. Avoid use. + void UnsafeSetFromAbslDuration(google::protobuf::Message* absl_nonnull message, + absl::Duration duration) const; + + absl::StatusOr ToAbslDuration( + const google::protobuf::Message& message) const; + + // Converts `google.protobuf.Duration` to `absl::Duration` without performing + // validity checks. Avoid use. + absl::Duration UnsafeToAbslDuration(const google::protobuf::Message& message) const; + + private: + const google::protobuf::Descriptor* absl_nullable descriptor_ = nullptr; + const google::protobuf::FieldDescriptor* absl_nullable seconds_field_ = nullptr; + const google::protobuf::FieldDescriptor* absl_nullable nanos_field_ = nullptr; +}; + +absl::StatusOr GetDurationReflection( + const google::protobuf::Descriptor* absl_nonnull descriptor + ABSL_ATTRIBUTE_LIFETIME_BOUND); + +class TimestampReflection final { + public: + static constexpr google::protobuf::Descriptor::WellKnownType kWellKnownType = + google::protobuf::Descriptor::WELLKNOWNTYPE_TIMESTAMP; + + using GeneratedMessageType = google::protobuf::Timestamp; + + static int64_t GetSeconds(const GeneratedMessageType& message) { + return message.seconds(); + } + + static int64_t GetNanos(const GeneratedMessageType& message) { + return message.nanos(); + } + + static void SetSeconds(GeneratedMessageType* absl_nonnull message, + int64_t value) { + message->set_seconds(value); + } + + static void SetNanos(GeneratedMessageType* absl_nonnull message, + int32_t value) { + message->set_nanos(value); + } + + static absl::Status SetFromAbslTime( + GeneratedMessageType* absl_nonnull message, absl::Time time); + + TimestampReflection() = default; + TimestampReflection(const TimestampReflection&) = default; + TimestampReflection& operator=(const TimestampReflection&) = default; + + absl::Status Initialize(const google::protobuf::DescriptorPool* absl_nonnull pool); + + absl::Status Initialize(const google::protobuf::Descriptor* absl_nonnull descriptor); + + bool IsInitialized() const { return descriptor_ != nullptr; } + + const google::protobuf::Descriptor* absl_nonnull GetDescriptor() const { + ABSL_DCHECK(IsInitialized()); + return descriptor_; + } + + int64_t GetSeconds(const google::protobuf::Message& message) const; + + int32_t GetNanos(const google::protobuf::Message& message) const; + + void SetSeconds(google::protobuf::Message* absl_nonnull message, int64_t value) const; + + void SetNanos(google::protobuf::Message* absl_nonnull message, int32_t value) const; + + absl::StatusOr ToAbslTime(const google::protobuf::Message& message) const; + + // Converts `absl::Time` to `google.protobuf.Timestamp` without performing + // validity checks. Avoid use. + absl::Time UnsafeToAbslTime(const google::protobuf::Message& message) const; + + absl::Status SetFromAbslTime(google::protobuf::Message* absl_nonnull message, + absl::Time time) const; + + // Converts `google.protobuf.Timestamp` to `absl::Time` without performing + // validity checks. Avoid use. + void UnsafeSetFromAbslTime(google::protobuf::Message* absl_nonnull message, + absl::Time time) const; + + private: + const google::protobuf::Descriptor* absl_nullable descriptor_ = nullptr; + const google::protobuf::FieldDescriptor* absl_nullable seconds_field_ = nullptr; + const google::protobuf::FieldDescriptor* absl_nullable nanos_field_ = nullptr; +}; + +absl::StatusOr GetTimestampReflection( + const google::protobuf::Descriptor* absl_nonnull descriptor + ABSL_ATTRIBUTE_LIFETIME_BOUND); + +class ValueReflection final { + public: + static constexpr google::protobuf::Descriptor::WellKnownType kWellKnownType = + google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE; + + using GeneratedMessageType = google::protobuf::Value; + + static google::protobuf::Value::KindCase GetKindCase( + const google::protobuf::Value& message) { + return message.kind_case(); + } + + static bool GetBoolValue(const GeneratedMessageType& message) { + return message.bool_value(); + } + + static double GetNumberValue(const GeneratedMessageType& message) { + return message.number_value(); + } + + static absl::string_view GetStringValue( + const GeneratedMessageType& message ABSL_ATTRIBUTE_LIFETIME_BOUND) { + return message.string_value(); + } + + static const google::protobuf::ListValue& GetListValue( + const GeneratedMessageType& message ABSL_ATTRIBUTE_LIFETIME_BOUND) { + return message.list_value(); + } + + static const google::protobuf::Struct& GetStructValue( + const GeneratedMessageType& message ABSL_ATTRIBUTE_LIFETIME_BOUND) { + return message.struct_value(); + } + + static void SetNullValue(GeneratedMessageType* absl_nonnull message) { + message->set_null_value(google::protobuf::NULL_VALUE); + } + + static void SetBoolValue(GeneratedMessageType* absl_nonnull message, + bool value) { + message->set_bool_value(value); + } + + static void SetNumberValue(GeneratedMessageType* absl_nonnull message, + int64_t value); + + static void SetNumberValue(GeneratedMessageType* absl_nonnull message, + uint64_t value); + + static void SetNumberValue(GeneratedMessageType* absl_nonnull message, + double value) { + message->set_number_value(value); + } + + static void SetStringValue(GeneratedMessageType* absl_nonnull message, + absl::string_view value) { + message->set_string_value(value); + } + + static void SetStringValue(GeneratedMessageType* absl_nonnull message, + const absl::Cord& value) { + message->set_string_value(static_cast(value)); + } + + static google::protobuf::ListValue* absl_nonnull MutableListValue( + GeneratedMessageType* absl_nonnull message) { + return message->mutable_list_value(); + } + + static google::protobuf::Struct* absl_nonnull MutableStructValue( + GeneratedMessageType* absl_nonnull message) { + return message->mutable_struct_value(); + } + + ValueReflection() = default; + ValueReflection(const ValueReflection&) = default; + ValueReflection& operator=(const ValueReflection&) = default; + + absl::Status Initialize(const google::protobuf::DescriptorPool* absl_nonnull pool); + + absl::Status Initialize(const google::protobuf::Descriptor* absl_nonnull descriptor); + + bool IsInitialized() const { return descriptor_ != nullptr; } + + const google::protobuf::Descriptor* absl_nonnull GetDescriptor() const { + ABSL_DCHECK(IsInitialized()); + return descriptor_; + } + + const google::protobuf::Descriptor* absl_nonnull GetStructDescriptor() const { + ABSL_DCHECK(IsInitialized()); + return struct_value_field_->message_type(); + } + + const google::protobuf::Descriptor* absl_nonnull GetListValueDescriptor() const { + ABSL_DCHECK(IsInitialized()); + return list_value_field_->message_type(); + } + + google::protobuf::Value::KindCase GetKindCase( + const google::protobuf::Message& message) const; + + bool GetBoolValue(const google::protobuf::Message& message) const; + + double GetNumberValue(const google::protobuf::Message& message) const; + + StringValue GetStringValue( + const google::protobuf::Message& message ABSL_ATTRIBUTE_LIFETIME_BOUND, + std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) const; + + const google::protobuf::Message& GetListValue( + const google::protobuf::Message& message ABSL_ATTRIBUTE_LIFETIME_BOUND) const; + + const google::protobuf::Message& GetStructValue( + const google::protobuf::Message& message ABSL_ATTRIBUTE_LIFETIME_BOUND) const; + + void SetNullValue(google::protobuf::Message* absl_nonnull message) const; + + void SetBoolValue(google::protobuf::Message* absl_nonnull message, bool value) const; + + void SetNumberValue(google::protobuf::Message* absl_nonnull message, + int64_t value) const; + + void SetNumberValue(google::protobuf::Message* absl_nonnull message, + uint64_t value) const; + + void SetNumberValue(google::protobuf::Message* absl_nonnull message, + double value) const; + + void SetStringValue(google::protobuf::Message* absl_nonnull message, + absl::string_view value) const; + + void SetStringValue(google::protobuf::Message* absl_nonnull message, + const absl::Cord& value) const; + + void SetStringValueFromBytes(google::protobuf::Message* absl_nonnull message, + absl::string_view value) const; + + void SetStringValueFromBytes(google::protobuf::Message* absl_nonnull message, + const absl::Cord& value) const; + + void SetStringValueFromDuration(google::protobuf::Message* absl_nonnull message, + absl::Duration duration) const; + + void SetStringValueFromTimestamp(google::protobuf::Message* absl_nonnull message, + absl::Time time) const; + + google::protobuf::Message* absl_nonnull MutableListValue( + google::protobuf::Message* absl_nonnull message) const; + + google::protobuf::Message* absl_nonnull MutableStructValue( + google::protobuf::Message* absl_nonnull message) const; + + Unique ReleaseListValue( + google::protobuf::Message* absl_nonnull message) const; + + Unique ReleaseStructValue( + google::protobuf::Message* absl_nonnull message) const; + + private: + const google::protobuf::Descriptor* absl_nullable descriptor_ = nullptr; + const google::protobuf::OneofDescriptor* absl_nullable kind_field_ = nullptr; + const google::protobuf::FieldDescriptor* absl_nullable null_value_field_ = nullptr; + const google::protobuf::FieldDescriptor* absl_nullable bool_value_field_ = nullptr; + const google::protobuf::FieldDescriptor* absl_nullable number_value_field_ = nullptr; + const google::protobuf::FieldDescriptor* absl_nullable string_value_field_ = nullptr; + const google::protobuf::FieldDescriptor* absl_nullable list_value_field_ = nullptr; + const google::protobuf::FieldDescriptor* absl_nullable struct_value_field_ = nullptr; + google::protobuf::FieldDescriptor::CppStringType string_value_field_string_type_; +}; + +absl::StatusOr GetValueReflection( + const google::protobuf::Descriptor* absl_nonnull descriptor + ABSL_ATTRIBUTE_LIFETIME_BOUND); + +// `GetValueReflectionOrDie()` is the same as `GetValueReflection` +// except that it aborts if `descriptor` is not a well formed descriptor of +// `google.protobuf.Value`. This should only be used in places where it is +// guaranteed that the aforementioned prerequisites are met. +ValueReflection GetValueReflectionOrDie( + const google::protobuf::Descriptor* absl_nonnull descriptor + ABSL_ATTRIBUTE_LIFETIME_BOUND); + +class ListValueReflection final { + public: + static constexpr google::protobuf::Descriptor::WellKnownType kWellKnownType = + google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE; + + using GeneratedMessageType = google::protobuf::ListValue; + + static int ValuesSize(const GeneratedMessageType& message) { + return message.values_size(); + } + + static const google::protobuf::RepeatedPtrField& Values( + const GeneratedMessageType& message ABSL_ATTRIBUTE_LIFETIME_BOUND) { + return message.values(); + } + + static const google::protobuf::Value& Values( + const GeneratedMessageType& message ABSL_ATTRIBUTE_LIFETIME_BOUND, + int index) { + return message.values(index); + } + + static google::protobuf::RepeatedPtrField& MutableValues( + GeneratedMessageType* absl_nonnull message + ABSL_ATTRIBUTE_LIFETIME_BOUND) { + return *message->mutable_values(); + } + + static google::protobuf::Value* absl_nonnull AddValues( + GeneratedMessageType* absl_nonnull message + ABSL_ATTRIBUTE_LIFETIME_BOUND) { + return message->add_values(); + } + + absl::Status Initialize(const google::protobuf::DescriptorPool* absl_nonnull pool); + + absl::Status Initialize(const google::protobuf::Descriptor* absl_nonnull descriptor); + + bool IsInitialized() const { return descriptor_ != nullptr; } + + const google::protobuf::Descriptor* absl_nonnull GetDescriptor() const { + ABSL_DCHECK(IsInitialized()); + return descriptor_; + } + + const google::protobuf::Descriptor* absl_nonnull GetValueDescriptor() const { + ABSL_DCHECK(IsInitialized()); + return values_field_->message_type(); + } + + const google::protobuf::FieldDescriptor* absl_nonnull GetValuesDescriptor() const { + ABSL_DCHECK(IsInitialized()); + return values_field_; + } + + int ValuesSize(const google::protobuf::Message& message) const; + + google::protobuf::RepeatedFieldRef Values( + const google::protobuf::Message& message ABSL_ATTRIBUTE_LIFETIME_BOUND) const; + + const google::protobuf::Message& Values(const google::protobuf::Message& message + ABSL_ATTRIBUTE_LIFETIME_BOUND, + int index) const; + + google::protobuf::MutableRepeatedFieldRef MutableValues( + google::protobuf::Message* absl_nonnull message + ABSL_ATTRIBUTE_LIFETIME_BOUND) const; + + google::protobuf::Message* absl_nonnull AddValues( + google::protobuf::Message* absl_nonnull message) const; + + private: + const google::protobuf::Descriptor* absl_nullable descriptor_ = nullptr; + const google::protobuf::FieldDescriptor* absl_nullable values_field_ = nullptr; +}; + +absl::StatusOr GetListValueReflection( + const google::protobuf::Descriptor* absl_nonnull descriptor + ABSL_ATTRIBUTE_LIFETIME_BOUND); + +// `GetListValueReflectionOrDie()` is the same as `GetListValueReflection` +// except that it aborts if `descriptor` is not a well formed descriptor of +// `google.protobuf.ListValue`. This should only be used in places where it is +// guaranteed that the aforementioned prerequisites are met. +ListValueReflection GetListValueReflectionOrDie( + const google::protobuf::Descriptor* absl_nonnull descriptor + ABSL_ATTRIBUTE_LIFETIME_BOUND); + +class StructReflection final { + public: + static constexpr google::protobuf::Descriptor::WellKnownType kWellKnownType = + google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT; + + using GeneratedMessageType = google::protobuf::Struct; + + static int FieldsSize(const GeneratedMessageType& message) { + return message.fields_size(); + } + + static auto BeginFields( + const GeneratedMessageType& message ABSL_ATTRIBUTE_LIFETIME_BOUND) { + return message.fields().begin(); + } + + static auto EndFields( + const GeneratedMessageType& message ABSL_ATTRIBUTE_LIFETIME_BOUND) { + return message.fields().end(); + } + + static bool ContainsField(const GeneratedMessageType& message, + absl::string_view name) { + return message.fields().contains(name); + } + + static const google::protobuf::Value* absl_nullable FindField( + const GeneratedMessageType& message ABSL_ATTRIBUTE_LIFETIME_BOUND, + absl::string_view name) { + if (auto it = message.fields().find(name); it != message.fields().end()) { + return &it->second; + } + return nullptr; + } + + static google::protobuf::Value* absl_nonnull InsertField( + GeneratedMessageType* absl_nonnull message ABSL_ATTRIBUTE_LIFETIME_BOUND, + absl::string_view name) { + return &(*message->mutable_fields())[name]; + } + + static bool DeleteField(GeneratedMessageType* absl_nonnull message, + absl::string_view name) { + return message->mutable_fields()->erase(name) > 0; + } + + absl::Status Initialize(const google::protobuf::DescriptorPool* absl_nonnull pool); + + absl::Status Initialize(const google::protobuf::Descriptor* absl_nonnull descriptor); + + bool IsInitialized() const { return descriptor_ != nullptr; } + + const google::protobuf::Descriptor* absl_nonnull GetDescriptor() const { + ABSL_DCHECK(IsInitialized()); + return descriptor_; + } + + const google::protobuf::Descriptor* absl_nonnull GetValueDescriptor() const { + ABSL_DCHECK(IsInitialized()); + return fields_value_field_->message_type(); + } + + const google::protobuf::FieldDescriptor* absl_nonnull GetFieldsDescriptor() const { + ABSL_DCHECK(IsInitialized()); + return fields_field_; + } + + int FieldsSize(const google::protobuf::Message& message) const; + + google::protobuf::ConstMapIterator BeginFields( + const google::protobuf::Message& message ABSL_ATTRIBUTE_LIFETIME_BOUND) const; + + google::protobuf::ConstMapIterator EndFields( + const google::protobuf::Message& message ABSL_ATTRIBUTE_LIFETIME_BOUND) const; + + bool ContainsField(const google::protobuf::Message& message, + absl::string_view name) const; + + const google::protobuf::Message* absl_nullable FindField( + const google::protobuf::Message& message ABSL_ATTRIBUTE_LIFETIME_BOUND, + absl::string_view name) const; + + google::protobuf::Message* absl_nonnull InsertField( + google::protobuf::Message* absl_nonnull message ABSL_ATTRIBUTE_LIFETIME_BOUND, + absl::string_view name) const; + + bool DeleteField(google::protobuf::Message* absl_nonnull message + ABSL_ATTRIBUTE_LIFETIME_BOUND, + absl::string_view name) const; + + private: + const google::protobuf::Descriptor* absl_nullable descriptor_ = nullptr; + const google::protobuf::FieldDescriptor* absl_nullable fields_field_ = nullptr; + const google::protobuf::FieldDescriptor* absl_nullable fields_key_field_ = nullptr; + const google::protobuf::FieldDescriptor* absl_nullable fields_value_field_ = nullptr; +}; + +absl::StatusOr GetStructReflection( + const google::protobuf::Descriptor* absl_nonnull descriptor + ABSL_ATTRIBUTE_LIFETIME_BOUND); + +// `GetStructReflectionOrDie()` is the same as `GetStructReflection` +// except that it aborts if `descriptor` is not a well formed descriptor of +// `google.protobuf.Struct`. This should only be used in places where it is +// guaranteed that the aforementioned prerequisites are met. +StructReflection GetStructReflectionOrDie( + const google::protobuf::Descriptor* absl_nonnull descriptor + ABSL_ATTRIBUTE_LIFETIME_BOUND); + +class FieldMaskReflection final { + public: + static constexpr google::protobuf::Descriptor::WellKnownType kWellKnownType = + google::protobuf::Descriptor::WELLKNOWNTYPE_FIELDMASK; + + using GeneratedMessageType = google::protobuf::FieldMask; + + static int PathsSize(const GeneratedMessageType& message) { + return message.paths_size(); + } + + static absl::string_view Paths(const GeneratedMessageType& message + ABSL_ATTRIBUTE_LIFETIME_BOUND, + int index) { + return message.paths(index); + } + + absl::Status Initialize(const google::protobuf::DescriptorPool* absl_nonnull pool); + + absl::Status Initialize(const google::protobuf::Descriptor* absl_nonnull descriptor); + + bool IsInitialized() const { return descriptor_ != nullptr; } + + const google::protobuf::Descriptor* absl_nonnull GetDescriptor() const { + ABSL_DCHECK(IsInitialized()); + return descriptor_; + } + + int PathsSize(const google::protobuf::Message& message) const; + + StringValue Paths( + const google::protobuf::Message& message ABSL_ATTRIBUTE_LIFETIME_BOUND, int index, + std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) const; + + private: + const google::protobuf::Descriptor* absl_nullable descriptor_ = nullptr; + const google::protobuf::FieldDescriptor* absl_nullable paths_field_ = nullptr; + google::protobuf::FieldDescriptor::CppStringType paths_field_string_type_; +}; + +absl::StatusOr GetFieldMaskReflection( + const google::protobuf::Descriptor* absl_nonnull descriptor + ABSL_ATTRIBUTE_LIFETIME_BOUND); + +using ListValuePtr = Unique; + +using ListValueConstRef = std::reference_wrapper; + +using StructPtr = Unique; + +using StructConstRef = std::reference_wrapper; + +// Variant holding `std::reference_wrapper` or `Unique`, either of which is an +// instance of `google.protobuf.ListValue` which is either a generated message +// or dynamic message. +class ListValue final : public absl::variant { + using absl::variant::variant; +}; + +// Older versions of GCC do not deal with inheriting from variant correctly when +// using `visit`, so we cheat by upcasting. +inline const absl::variant& AsVariant( + const ListValue& value) { + return static_cast&>( + value); +} +inline absl::variant& AsVariant( + ListValue& value) { + return static_cast&>(value); +} +inline const absl::variant&& AsVariant( + const ListValue&& value) { + return static_cast&&>( + value); +} +inline absl::variant&& AsVariant( + ListValue&& value) { + return static_cast&&>(value); +} + +// Variant holding `std::reference_wrapper` or `Unique`, either of which is an +// instance of `google.protobuf.Struct` which is either a generated message or +// dynamic message. +class Struct final : public absl::variant { + public: + using absl::variant::variant; +}; + +// Older versions of GCC do not deal with inheriting from variant correctly when +// using `visit`, so we cheat by upcasting. +inline const absl::variant& AsVariant( + const Struct& value) { + return static_cast&>(value); +} +inline absl::variant& AsVariant(Struct& value) { + return static_cast&>(value); +} +inline const absl::variant&& AsVariant( + const Struct&& value) { + return static_cast&&>(value); +} +inline absl::variant&& AsVariant(Struct&& value) { + return static_cast&&>(value); +} + +// Variant capable of representing any unwrapped well known type or message. +using Value = absl::variant>; + +// Unpacks the given instance of `google.protobuf.Any`. +absl::StatusOr> UnpackAnyFrom( + google::protobuf::Arena* absl_nullable arena ABSL_ATTRIBUTE_LIFETIME_BOUND, + AnyReflection& reflection, const google::protobuf::Message& message, + const google::protobuf::DescriptorPool* absl_nonnull pool + ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::MessageFactory* absl_nonnull factory ABSL_ATTRIBUTE_LIFETIME_BOUND); + +// Unpacks the given instance of `google.protobuf.Any` if it is resolvable. +absl::StatusOr> UnpackAnyIfResolveable( + google::protobuf::Arena* absl_nullable arena ABSL_ATTRIBUTE_LIFETIME_BOUND, + AnyReflection& reflection, const google::protobuf::Message& message, + const google::protobuf::DescriptorPool* absl_nonnull pool + ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::MessageFactory* absl_nonnull factory ABSL_ATTRIBUTE_LIFETIME_BOUND); + +// Performs any necessary unwrapping of a well known message type. If no +// unwrapping is necessary, the resulting `Value` holds the alternative +// `absl::monostate`. +absl::StatusOr AdaptFromMessage( + google::protobuf::Arena* absl_nullable arena ABSL_ATTRIBUTE_LIFETIME_BOUND, + const google::protobuf::Message& message ABSL_ATTRIBUTE_LIFETIME_BOUND, + const google::protobuf::DescriptorPool* absl_nonnull pool + ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::MessageFactory* absl_nonnull factory ABSL_ATTRIBUTE_LIFETIME_BOUND, + std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND); + +class JsonReflection final { + public: + JsonReflection() = default; + JsonReflection(const JsonReflection&) = default; + JsonReflection& operator=(const JsonReflection&) = default; + + absl::Status Initialize(const google::protobuf::DescriptorPool* absl_nonnull pool); + + absl::Status Initialize(const google::protobuf::Descriptor* absl_nonnull descriptor); + + bool IsInitialized() const; + + ValueReflection& Value() ABSL_ATTRIBUTE_LIFETIME_BOUND { return value_; } + + ListValueReflection& ListValue() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return list_value_; + } + + StructReflection& Struct() ABSL_ATTRIBUTE_LIFETIME_BOUND { return struct_; } + + const ValueReflection& Value() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return value_; + } + + const ListValueReflection& ListValue() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return list_value_; + } + + const StructReflection& Struct() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return struct_; + } + + private: + ValueReflection value_; + ListValueReflection list_value_; + StructReflection struct_; +}; + +class Reflection final { + public: + Reflection() = default; + Reflection(const Reflection&) = default; + Reflection& operator=(const Reflection&) = default; + + absl::Status Initialize(const google::protobuf::DescriptorPool* absl_nonnull pool); + + bool IsInitialized() const; + + // At the moment we only use this class for verifying well known types in + // descriptor pools. We could eagerly initialize it and cache it somewhere to + // make things faster. + + BoolValueReflection& BoolValue() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return bool_value_; + } + + Int32ValueReflection& Int32Value() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return int32_value_; + } + + Int64ValueReflection& Int64Value() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return int64_value_; + } + + UInt32ValueReflection& UInt32Value() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return uint32_value_; + } + + UInt64ValueReflection& UInt64Value() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return uint64_value_; + } + + FloatValueReflection& FloatValue() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return float_value_; + } + + DoubleValueReflection& DoubleValue() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return double_value_; + } + + BytesValueReflection& BytesValue() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return bytes_value_; + } + + StringValueReflection& StringValue() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return string_value_; + } + + AnyReflection& Any() ABSL_ATTRIBUTE_LIFETIME_BOUND { return any_; } + + DurationReflection& Duration() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return duration_; + } + + TimestampReflection& Timestamp() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return timestamp_; + } + + JsonReflection& Json() ABSL_ATTRIBUTE_LIFETIME_BOUND { return json_; } + + ValueReflection& Value() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return Json().Value(); + } + + ListValueReflection& ListValue() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return Json().ListValue(); + } + + StructReflection& Struct() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return Json().Struct(); + } + + FieldMaskReflection& FieldMask() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return field_mask_; + } + + const BoolValueReflection& BoolValue() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return bool_value_; + } + + const Int32ValueReflection& Int32Value() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return int32_value_; + } + + const Int64ValueReflection& Int64Value() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return int64_value_; + } + + const UInt32ValueReflection& UInt32Value() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return uint32_value_; + } + + const UInt64ValueReflection& UInt64Value() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return uint64_value_; + } + + const FloatValueReflection& FloatValue() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return float_value_; + } + + const DoubleValueReflection& DoubleValue() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return double_value_; + } + + const BytesValueReflection& BytesValue() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return bytes_value_; + } + + const StringValueReflection& StringValue() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return string_value_; + } + + const AnyReflection& Any() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return any_; + } + + const DurationReflection& Duration() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return duration_; + } + + const TimestampReflection& Timestamp() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return timestamp_; + } + + const JsonReflection& Json() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return json_; + } + + const ValueReflection& Value() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return Json().Value(); + } + + const ListValueReflection& ListValue() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return Json().ListValue(); + } + + const StructReflection& Struct() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return Json().Struct(); + } + + const FieldMaskReflection& FieldMask() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return field_mask_; + } + + private: + NullValueReflection& NullValue() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return null_value_; + } + + const NullValueReflection& NullValue() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return null_value_; + } + + NullValueReflection null_value_; + BoolValueReflection bool_value_; + Int32ValueReflection int32_value_; + Int64ValueReflection int64_value_; + UInt32ValueReflection uint32_value_; + UInt64ValueReflection uint64_value_; + FloatValueReflection float_value_; + DoubleValueReflection double_value_; + BytesValueReflection bytes_value_; + StringValueReflection string_value_; + AnyReflection any_; + DurationReflection duration_; + TimestampReflection timestamp_; + JsonReflection json_; + FieldMaskReflection field_mask_; +}; + +} // namespace cel::well_known_types + +#endif // THIRD_PARTY_CEL_CPP_INTERNAL_WELL_KNOWN_TYPES_H_ diff --git a/internal/well_known_types_test.cc b/internal/well_known_types_test.cc new file mode 100644 index 000000000..afc8ce396 --- /dev/null +++ b/internal/well_known_types_test.cc @@ -0,0 +1,978 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "internal/well_known_types.h" + +#include +#include +#include + +#include "google/protobuf/any.pb.h" +#include "google/protobuf/duration.pb.h" +#include "google/protobuf/field_mask.pb.h" +#include "google/protobuf/struct.pb.h" +#include "google/protobuf/timestamp.pb.h" +#include "google/protobuf/wrappers.pb.h" +#include "google/protobuf/descriptor.pb.h" +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/log/die_if_null.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "absl/types/variant.h" +#include "common/memory.h" +#include "internal/message_type_name.h" +#include "internal/minimal_descriptor_pool.h" +#include "internal/parse_text_proto.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "internal/testing_message_factory.h" +#include "cel/expr/conformance/proto3/test_all_types.pb.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel::well_known_types { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::IsOkAndHolds; +using ::absl_testing::StatusIs; +using ::cel::internal::GetMinimalDescriptorPool; +using ::cel::internal::GetTestingDescriptorPool; +using ::cel::internal::GetTestingMessageFactory; +using ::testing::_; +using ::testing::HasSubstr; +using ::testing::IsNull; +using ::testing::NotNull; +using ::testing::Test; +using ::testing::VariantWith; + +using TestAllTypesProto3 = ::cel::expr::conformance::proto3::TestAllTypes; + +class ReflectionTest : public Test { + public: + google::protobuf::Arena* absl_nonnull arena() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return &arena_; + } + + std::string& scratch_space() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return scratch_space_; + } + + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool() { + return GetTestingDescriptorPool(); + } + + google::protobuf::MessageFactory* absl_nonnull message_factory() { + return GetTestingMessageFactory(); + } + + template + T* absl_nonnull MakeGenerated() { + return google::protobuf::Arena::Create(arena()); + } + + template + google::protobuf::Message* absl_nonnull MakeDynamic() { + const auto* descriptor = + ABSL_DIE_IF_NULL(descriptor_pool()->FindMessageTypeByName( + internal::MessageTypeNameFor())); + const auto* prototype = + ABSL_DIE_IF_NULL(message_factory()->GetPrototype(descriptor)); + return prototype->New(arena()); + } + + private: + google::protobuf::Arena arena_; + std::string scratch_space_; +}; + +TEST_F(ReflectionTest, MinimalDescriptorPool) { + EXPECT_THAT(Reflection().Initialize(GetMinimalDescriptorPool()), IsOk()); +} + +TEST_F(ReflectionTest, TestingDescriptorPool) { + EXPECT_THAT(Reflection().Initialize(GetTestingDescriptorPool()), IsOk()); +} + +TEST_F(ReflectionTest, BoolValue_Generated) { + auto* value = MakeGenerated(); + EXPECT_EQ(BoolValueReflection::GetValue(*value), false); + BoolValueReflection::SetValue(value, true); + EXPECT_EQ(BoolValueReflection::GetValue(*value), true); +} + +TEST_F(ReflectionTest, BoolValue_Dynamic) { + auto* value = MakeDynamic(); + ASSERT_OK_AND_ASSIGN( + auto reflection, + GetBoolValueReflection(ABSL_DIE_IF_NULL(value->GetDescriptor()))); + EXPECT_EQ(reflection.GetValue(*value), false); + reflection.SetValue(value, true); + EXPECT_EQ(reflection.GetValue(*value), true); +} + +TEST_F(ReflectionTest, Int32Value_Generated) { + auto* value = MakeGenerated(); + EXPECT_EQ(Int32ValueReflection::GetValue(*value), 0); + Int32ValueReflection::SetValue(value, 1); + EXPECT_EQ(Int32ValueReflection::GetValue(*value), 1); +} + +TEST_F(ReflectionTest, Int32Value_Dynamic) { + auto* value = MakeDynamic(); + ASSERT_OK_AND_ASSIGN( + auto reflection, + GetInt32ValueReflection(ABSL_DIE_IF_NULL(value->GetDescriptor()))); + EXPECT_EQ(reflection.GetValue(*value), 0); + reflection.SetValue(value, 1); + EXPECT_EQ(reflection.GetValue(*value), 1); +} + +TEST_F(ReflectionTest, Int64Value_Generated) { + auto* value = MakeGenerated(); + EXPECT_EQ(Int64ValueReflection::GetValue(*value), 0); + Int64ValueReflection::SetValue(value, 1); + EXPECT_EQ(Int64ValueReflection::GetValue(*value), 1); +} + +TEST_F(ReflectionTest, Int64Value_Dynamic) { + auto* value = MakeDynamic(); + ASSERT_OK_AND_ASSIGN( + auto reflection, + GetInt64ValueReflection(ABSL_DIE_IF_NULL(value->GetDescriptor()))); + EXPECT_EQ(reflection.GetValue(*value), 0); + reflection.SetValue(value, 1); + EXPECT_EQ(reflection.GetValue(*value), 1); +} + +TEST_F(ReflectionTest, UInt32Value_Generated) { + auto* value = MakeGenerated(); + EXPECT_EQ(UInt32ValueReflection::GetValue(*value), 0); + UInt32ValueReflection::SetValue(value, 1); + EXPECT_EQ(UInt32ValueReflection::GetValue(*value), 1); +} + +TEST_F(ReflectionTest, UInt32Value_Dynamic) { + auto* value = MakeDynamic(); + ASSERT_OK_AND_ASSIGN( + auto reflection, + GetUInt32ValueReflection(ABSL_DIE_IF_NULL(value->GetDescriptor()))); + EXPECT_EQ(reflection.GetValue(*value), 0); + reflection.SetValue(value, 1); + EXPECT_EQ(reflection.GetValue(*value), 1); +} + +TEST_F(ReflectionTest, UInt64Value_Generated) { + auto* value = MakeGenerated(); + EXPECT_EQ(UInt64ValueReflection::GetValue(*value), 0); + UInt64ValueReflection::SetValue(value, 1); + EXPECT_EQ(UInt64ValueReflection::GetValue(*value), 1); +} + +TEST_F(ReflectionTest, UInt64Value_Dynamic) { + auto* value = MakeDynamic(); + ASSERT_OK_AND_ASSIGN( + auto reflection, + GetUInt64ValueReflection(ABSL_DIE_IF_NULL(value->GetDescriptor()))); + EXPECT_EQ(reflection.GetValue(*value), 0); + reflection.SetValue(value, 1); + EXPECT_EQ(reflection.GetValue(*value), 1); +} + +TEST_F(ReflectionTest, FloatValue_Generated) { + auto* value = MakeGenerated(); + EXPECT_EQ(FloatValueReflection::GetValue(*value), 0); + FloatValueReflection::SetValue(value, 1); + EXPECT_EQ(FloatValueReflection::GetValue(*value), 1); +} + +TEST_F(ReflectionTest, FloatValue_Dynamic) { + auto* value = MakeDynamic(); + ASSERT_OK_AND_ASSIGN( + auto reflection, + GetFloatValueReflection(ABSL_DIE_IF_NULL(value->GetDescriptor()))); + EXPECT_EQ(reflection.GetValue(*value), 0); + reflection.SetValue(value, 1); + EXPECT_EQ(reflection.GetValue(*value), 1); +} + +TEST_F(ReflectionTest, DoubleValue_Generated) { + auto* value = MakeGenerated(); + EXPECT_EQ(DoubleValueReflection::GetValue(*value), 0); + DoubleValueReflection::SetValue(value, 1); + EXPECT_EQ(DoubleValueReflection::GetValue(*value), 1); +} + +TEST_F(ReflectionTest, DoubleValue_Dynamic) { + auto* value = MakeDynamic(); + ASSERT_OK_AND_ASSIGN( + auto reflection, + GetDoubleValueReflection(ABSL_DIE_IF_NULL(value->GetDescriptor()))); + EXPECT_EQ(reflection.GetValue(*value), 0); + reflection.SetValue(value, 1); + EXPECT_EQ(reflection.GetValue(*value), 1); +} + +TEST_F(ReflectionTest, BytesValue_Generated) { + auto* value = MakeGenerated(); + EXPECT_EQ(BytesValueReflection::GetValue(*value), ""); + BytesValueReflection::SetValue(value, absl::Cord("Hello World!")); + EXPECT_EQ(BytesValueReflection::GetValue(*value), "Hello World!"); +} + +TEST_F(ReflectionTest, BytesValue_Dynamic) { + auto* value = MakeDynamic(); + std::string scratch; + ASSERT_OK_AND_ASSIGN( + auto reflection, + GetBytesValueReflection(ABSL_DIE_IF_NULL(value->GetDescriptor()))); + EXPECT_EQ(reflection.GetValue(*value, scratch), ""); + reflection.SetValue(value, "Hello World!"); + EXPECT_EQ(reflection.GetValue(*value, scratch), "Hello World!"); + reflection.SetValue(value, absl::Cord()); + EXPECT_EQ(reflection.GetValue(*value, scratch), ""); +} + +TEST_F(ReflectionTest, StringValue_Generated) { + auto* value = MakeGenerated(); + EXPECT_EQ(StringValueReflection::GetValue(*value), ""); + StringValueReflection::SetValue(value, "Hello World!"); + EXPECT_EQ(StringValueReflection::GetValue(*value), "Hello World!"); +} + +TEST_F(ReflectionTest, StringValue_Dynamic) { + auto* value = MakeDynamic(); + std::string scratch; + ASSERT_OK_AND_ASSIGN( + auto reflection, + GetStringValueReflection(ABSL_DIE_IF_NULL(value->GetDescriptor()))); + EXPECT_EQ(reflection.GetValue(*value, scratch), ""); + reflection.SetValue(value, "Hello World!"); + EXPECT_EQ(reflection.GetValue(*value, scratch), "Hello World!"); + reflection.SetValue(value, absl::Cord()); + EXPECT_EQ(reflection.GetValue(*value, scratch), ""); +} + +TEST_F(ReflectionTest, Any_Generated) { + auto* value = MakeGenerated(); + EXPECT_EQ(AnyReflection::GetTypeUrl(*value), ""); + AnyReflection::SetTypeUrl(value, "Hello World!"); + EXPECT_EQ(AnyReflection::GetTypeUrl(*value), "Hello World!"); + EXPECT_EQ(AnyReflection::GetValue(*value), ""); + AnyReflection::SetValue(value, absl::Cord("Hello World!")); + EXPECT_EQ(AnyReflection::GetValue(*value), "Hello World!"); +} + +TEST_F(ReflectionTest, Any_Dynamic) { + auto* value = MakeDynamic(); + std::string scratch; + ASSERT_OK_AND_ASSIGN( + auto reflection, + GetAnyReflection(ABSL_DIE_IF_NULL(value->GetDescriptor()))); + EXPECT_EQ(reflection.GetTypeUrl(*value, scratch), ""); + reflection.SetTypeUrl(value, "Hello World!"); + EXPECT_EQ(reflection.GetTypeUrl(*value, scratch), "Hello World!"); + EXPECT_EQ(reflection.GetValue(*value, scratch), ""); + reflection.SetValue(value, absl::Cord("Hello World!")); + EXPECT_EQ(reflection.GetValue(*value, scratch), "Hello World!"); +} + +TEST_F(ReflectionTest, Duration_Generated) { + auto* value = MakeGenerated(); + EXPECT_EQ(DurationReflection::GetSeconds(*value), 0); + DurationReflection::SetSeconds(value, 1); + EXPECT_EQ(DurationReflection::GetSeconds(*value), 1); + EXPECT_EQ(DurationReflection::GetNanos(*value), 0); + DurationReflection::SetNanos(value, 1); + EXPECT_EQ(DurationReflection::GetNanos(*value), 1); + + EXPECT_THAT(DurationReflection::SetFromAbslDuration( + value, absl::Seconds(1) + absl::Nanoseconds(1)), + IsOk()); + EXPECT_EQ(value->seconds(), 1); + EXPECT_EQ(value->nanos(), 1); + + EXPECT_THAT( + DurationReflection::SetFromAbslDuration(value, absl::InfiniteDuration()), + StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT( + DurationReflection::SetFromAbslDuration(value, -absl::InfiniteDuration()), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST_F(ReflectionTest, Duration_Dynamic) { + auto* value = MakeDynamic(); + ASSERT_OK_AND_ASSIGN( + auto reflection, + GetDurationReflection(ABSL_DIE_IF_NULL(value->GetDescriptor()))); + EXPECT_EQ(reflection.GetSeconds(*value), 0); + reflection.SetSeconds(value, 1); + EXPECT_EQ(reflection.GetSeconds(*value), 1); + EXPECT_EQ(reflection.GetNanos(*value), 0); + reflection.SetNanos(value, 1); + EXPECT_EQ(reflection.GetNanos(*value), 1); + + EXPECT_THAT(reflection.SetFromAbslDuration( + value, absl::Seconds(1) + absl::Nanoseconds(1)), + IsOk()); + EXPECT_EQ(reflection.GetSeconds(*value), 1); + EXPECT_EQ(reflection.GetNanos(*value), 1); + + EXPECT_THAT(reflection.SetFromAbslDuration(value, absl::InfiniteDuration()), + StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(reflection.SetFromAbslDuration(value, -absl::InfiniteDuration()), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST_F(ReflectionTest, Timestamp_Generated) { + auto* value = MakeGenerated(); + EXPECT_EQ(TimestampReflection::GetSeconds(*value), 0); + TimestampReflection::SetSeconds(value, 1); + EXPECT_EQ(TimestampReflection::GetSeconds(*value), 1); + EXPECT_EQ(TimestampReflection::GetNanos(*value), 0); + TimestampReflection::SetNanos(value, 1); + EXPECT_EQ(TimestampReflection::GetNanos(*value), 1); + + EXPECT_THAT( + TimestampReflection::SetFromAbslTime( + value, absl::UnixEpoch() + absl::Seconds(1) + absl::Nanoseconds(1)), + IsOk()); + EXPECT_EQ(value->seconds(), 1); + EXPECT_EQ(value->nanos(), 1); + + EXPECT_THAT( + TimestampReflection::SetFromAbslTime(value, absl::InfiniteFuture()), + StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(TimestampReflection::SetFromAbslTime(value, absl::InfinitePast()), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST_F(ReflectionTest, Timestamp_Dynamic) { + auto* value = MakeDynamic(); + ASSERT_OK_AND_ASSIGN( + auto reflection, + GetTimestampReflection(ABSL_DIE_IF_NULL(value->GetDescriptor()))); + EXPECT_EQ(reflection.GetSeconds(*value), 0); + reflection.SetSeconds(value, 1); + EXPECT_EQ(reflection.GetSeconds(*value), 1); + EXPECT_EQ(reflection.GetNanos(*value), 0); + reflection.SetNanos(value, 1); + EXPECT_EQ(reflection.GetNanos(*value), 1); + + EXPECT_THAT( + reflection.SetFromAbslTime( + value, absl::UnixEpoch() + absl::Seconds(1) + absl::Nanoseconds(1)), + IsOk()); + EXPECT_EQ(reflection.GetSeconds(*value), 1); + EXPECT_EQ(reflection.GetNanos(*value), 1); + + EXPECT_THAT(reflection.SetFromAbslTime(value, absl::InfiniteFuture()), + StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(reflection.SetFromAbslTime(value, absl::InfinitePast()), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST_F(ReflectionTest, Value_Generated) { + auto* value = MakeGenerated(); + EXPECT_EQ(ValueReflection::GetKindCase(*value), + google::protobuf::Value::KIND_NOT_SET); + ValueReflection::SetNullValue(value); + EXPECT_EQ(ValueReflection::GetKindCase(*value), + google::protobuf::Value::kNullValue); + ValueReflection::SetBoolValue(value, true); + EXPECT_EQ(ValueReflection::GetKindCase(*value), + google::protobuf::Value::kBoolValue); + EXPECT_EQ(ValueReflection::GetBoolValue(*value), true); + ValueReflection::SetNumberValue(value, 1.0); + EXPECT_EQ(ValueReflection::GetKindCase(*value), + google::protobuf::Value::kNumberValue); + EXPECT_EQ(ValueReflection::GetNumberValue(*value), 1.0); + ValueReflection::SetStringValue(value, "Hello World!"); + EXPECT_EQ(ValueReflection::GetKindCase(*value), + google::protobuf::Value::kStringValue); + EXPECT_EQ(ValueReflection::GetStringValue(*value), "Hello World!"); + ValueReflection::MutableListValue(value); + EXPECT_EQ(ValueReflection::GetKindCase(*value), + google::protobuf::Value::kListValue); + EXPECT_EQ(ValueReflection::GetListValue(*value).ByteSizeLong(), 0); + ValueReflection::MutableStructValue(value); + EXPECT_EQ(ValueReflection::GetKindCase(*value), + google::protobuf::Value::kStructValue); + EXPECT_EQ(ValueReflection::GetStructValue(*value).ByteSizeLong(), 0); +} + +TEST_F(ReflectionTest, Value_Dynamic) { + auto* value = MakeDynamic(); + std::string scratch; + ASSERT_OK_AND_ASSIGN( + auto reflection, + GetValueReflection(ABSL_DIE_IF_NULL(value->GetDescriptor()))); + EXPECT_EQ(reflection.GetKindCase(*value), + google::protobuf::Value::KIND_NOT_SET); + reflection.SetNullValue(value); + EXPECT_EQ(reflection.GetKindCase(*value), + google::protobuf::Value::kNullValue); + reflection.SetBoolValue(value, true); + EXPECT_EQ(reflection.GetKindCase(*value), + google::protobuf::Value::kBoolValue); + EXPECT_EQ(reflection.GetBoolValue(*value), true); + reflection.SetNumberValue(value, 1.0); + EXPECT_EQ(reflection.GetKindCase(*value), + google::protobuf::Value::kNumberValue); + EXPECT_EQ(reflection.GetNumberValue(*value), 1.0); + reflection.SetStringValue(value, "Hello World!"); + EXPECT_EQ(reflection.GetKindCase(*value), + google::protobuf::Value::kStringValue); + EXPECT_EQ(reflection.GetStringValue(*value, scratch), "Hello World!"); + reflection.MutableListValue(value); + EXPECT_EQ(reflection.GetKindCase(*value), + google::protobuf::Value::kListValue); + EXPECT_EQ(reflection.GetListValue(*value).ByteSizeLong(), 0); + EXPECT_THAT(reflection.ReleaseListValue(value), NotNull()); + reflection.MutableStructValue(value); + EXPECT_EQ(reflection.GetKindCase(*value), + google::protobuf::Value::kStructValue); + EXPECT_EQ(reflection.GetStructValue(*value).ByteSizeLong(), 0); + EXPECT_THAT(reflection.ReleaseStructValue(value), NotNull()); +} + +TEST_F(ReflectionTest, ListValue_Generated) { + auto* value = MakeGenerated(); + EXPECT_EQ(ListValueReflection::ValuesSize(*value), 0); + EXPECT_EQ(ListValueReflection::Values(*value).size(), 0); + EXPECT_EQ(ListValueReflection::MutableValues(value).size(), 0); +} + +TEST_F(ReflectionTest, ListValue_Dynamic) { + auto* value = MakeDynamic(); + ASSERT_OK_AND_ASSIGN( + auto reflection, + GetListValueReflection(ABSL_DIE_IF_NULL(value->GetDescriptor()))); + EXPECT_EQ(reflection.ValuesSize(*value), 0); + EXPECT_EQ(reflection.Values(*value).size(), 0); + EXPECT_EQ(reflection.MutableValues(value).size(), 0); +} + +TEST_F(ReflectionTest, StructValue_Generated) { + auto* value = MakeGenerated(); + EXPECT_EQ(StructReflection::FieldsSize(*value), 0); + EXPECT_EQ(StructReflection::BeginFields(*value), + StructReflection::EndFields(*value)); + EXPECT_FALSE(StructReflection::ContainsField(*value, "foo")); + EXPECT_THAT(StructReflection::FindField(*value, "foo"), IsNull()); + EXPECT_THAT(StructReflection::InsertField(value, "foo"), NotNull()); + EXPECT_TRUE(StructReflection::DeleteField(value, "foo")); +} + +TEST_F(ReflectionTest, StructValue_Dynamic) { + auto* value = MakeDynamic(); + ASSERT_OK_AND_ASSIGN( + auto reflection, + GetStructReflection(ABSL_DIE_IF_NULL(value->GetDescriptor()))); + EXPECT_EQ(reflection.FieldsSize(*value), 0); + EXPECT_EQ(reflection.BeginFields(*value), reflection.EndFields(*value)); + EXPECT_FALSE(reflection.ContainsField(*value, "foo")); + EXPECT_THAT(reflection.FindField(*value, "foo"), IsNull()); + EXPECT_THAT(reflection.InsertField(value, "foo"), NotNull()); + EXPECT_TRUE(reflection.DeleteField(value, "foo")); +} + +TEST_F(ReflectionTest, FieldMask_Generated) { + auto* value = MakeGenerated(); + EXPECT_EQ(FieldMaskReflection::PathsSize(*value), 0); + value->add_paths("foo"); + EXPECT_EQ(FieldMaskReflection::PathsSize(*value), 1); + EXPECT_EQ(FieldMaskReflection::Paths(*value, 0), "foo"); +} + +TEST_F(ReflectionTest, FieldMask_Dynamic) { + auto* value = MakeDynamic(); + ASSERT_OK_AND_ASSIGN( + auto reflection, + GetFieldMaskReflection(ABSL_DIE_IF_NULL(value->GetDescriptor()))); + EXPECT_EQ(reflection.PathsSize(*value), 0); + value->GetReflection()->AddString( + &*value, + ABSL_DIE_IF_NULL(value->GetDescriptor()->FindFieldByName("paths")), + "foo"); + EXPECT_EQ(reflection.PathsSize(*value), 1); + EXPECT_EQ(reflection.Paths(*value, 0, scratch_space()), "foo"); +} + +TEST_F(ReflectionTest, NullValue_MissingValue) { + google::protobuf::DescriptorPool descriptor_pool; + { + google::protobuf::FileDescriptorProto file_proto; + file_proto.set_name("google/protobuf/struct.proto"); + file_proto.set_syntax("editions"); + file_proto.set_edition(google::protobuf::EDITION_2023); + file_proto.set_package("google.protobuf"); + auto* enum_proto = file_proto.add_enum_type(); + enum_proto->set_name("NullValue"); + auto* value_proto = enum_proto->add_value(); + value_proto->set_number(1); + value_proto->set_name("NULL_VALUE"); + enum_proto->mutable_options()->mutable_features()->set_enum_type( + google::protobuf::FeatureSet::CLOSED); + ASSERT_THAT(descriptor_pool.BuildFile(file_proto), NotNull()); + } + EXPECT_THAT( + NullValueReflection().Initialize(&descriptor_pool), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("well known protocol buffer enum missing value: "))); +} + +TEST_F(ReflectionTest, NullValue_MultipleValues) { + google::protobuf::DescriptorPool descriptor_pool; + { + google::protobuf::FileDescriptorProto file_proto; + file_proto.set_name("google/protobuf/struct.proto"); + file_proto.set_syntax("proto3"); + file_proto.set_package("google.protobuf"); + auto* enum_proto = file_proto.add_enum_type(); + enum_proto->set_name("NullValue"); + auto* value_proto = enum_proto->add_value(); + value_proto->set_number(0); + value_proto->set_name("NULL_VALUE"); + value_proto = enum_proto->add_value(); + value_proto->set_number(1); + value_proto->set_name("NULL_VALUE2"); + ASSERT_THAT(descriptor_pool.BuildFile(file_proto), NotNull()); + } + EXPECT_THAT( + NullValueReflection().Initialize(&descriptor_pool), + StatusIs( + absl::StatusCode::kInvalidArgument, + HasSubstr("well known protocol buffer enum has multiple values: "))); +} + +TEST_F(ReflectionTest, EnumDescriptorMissing) { + google::protobuf::DescriptorPool descriptor_pool; + EXPECT_THAT(NullValueReflection().Initialize(&descriptor_pool), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("descriptor missing for protocol buffer enum " + "well known type: "))); +} + +TEST_F(ReflectionTest, MessageDescriptorMissing) { + google::protobuf::DescriptorPool descriptor_pool; + EXPECT_THAT(BoolValueReflection().Initialize(&descriptor_pool), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("descriptor missing for protocol buffer " + "message well known type: "))); +} + +class AdaptFromMessageTest : public Test { + public: + google::protobuf::Arena* absl_nonnull arena() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return &arena_; + } + + std::string& scratch_space() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return scratch_space_; + } + + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool() { + return GetTestingDescriptorPool(); + } + + google::protobuf::MessageFactory* absl_nonnull message_factory() { + return GetTestingMessageFactory(); + } + + template + google::protobuf::Message* absl_nonnull MakeDynamic() { + const auto* descriptor_pool = GetTestingDescriptorPool(); + const auto* descriptor = + ABSL_DIE_IF_NULL(descriptor_pool->FindMessageTypeByName( + internal::MessageTypeNameFor())); + const auto* prototype = + ABSL_DIE_IF_NULL(GetTestingMessageFactory()->GetPrototype(descriptor)); + return prototype->New(arena()); + } + + template + google::protobuf::Message* DynamicParseTextProto(absl::string_view text) { + return ::cel::internal::DynamicParseTextProto( + arena(), text, descriptor_pool(), message_factory()); + } + + absl::StatusOr AdaptFromMessage(const google::protobuf::Message& message) { + return well_known_types::AdaptFromMessage( + arena(), message, descriptor_pool(), message_factory(), + scratch_space()); + } + + private: + google::protobuf::Arena arena_; + std::string scratch_space_; +}; + +TEST_F(AdaptFromMessageTest, BoolValue) { + auto message = + DynamicParseTextProto(R"pb(value: true)pb"); + EXPECT_THAT(AdaptFromMessage(*message), + IsOkAndHolds(VariantWith(true))); +} + +TEST_F(AdaptFromMessageTest, Int32Value) { + auto message = + DynamicParseTextProto(R"pb(value: 1)pb"); + EXPECT_THAT(AdaptFromMessage(*message), + IsOkAndHolds(VariantWith(1))); +} + +TEST_F(AdaptFromMessageTest, Int64Value) { + auto message = + DynamicParseTextProto(R"pb(value: 1)pb"); + EXPECT_THAT(AdaptFromMessage(*message), + IsOkAndHolds(VariantWith(1))); +} + +TEST_F(AdaptFromMessageTest, UInt32Value) { + auto message = + DynamicParseTextProto(R"pb(value: 1)pb"); + EXPECT_THAT(AdaptFromMessage(*message), + IsOkAndHolds(VariantWith(1))); +} + +TEST_F(AdaptFromMessageTest, UInt64Value) { + auto message = + DynamicParseTextProto(R"pb(value: 1)pb"); + EXPECT_THAT(AdaptFromMessage(*message), + IsOkAndHolds(VariantWith(1))); +} + +TEST_F(AdaptFromMessageTest, FloatValue) { + auto message = + DynamicParseTextProto(R"pb(value: 1.0)pb"); + EXPECT_THAT(AdaptFromMessage(*message), IsOkAndHolds(VariantWith(1))); +} + +TEST_F(AdaptFromMessageTest, DoubleValue) { + auto message = + DynamicParseTextProto(R"pb(value: 1.0)pb"); + EXPECT_THAT(AdaptFromMessage(*message), IsOkAndHolds(VariantWith(1))); +} + +TEST_F(AdaptFromMessageTest, BytesValue) { + auto message = DynamicParseTextProto( + R"pb(value: "foo")pb"); + EXPECT_THAT(AdaptFromMessage(*message), + IsOkAndHolds(VariantWith(BytesValue("foo")))); +} + +TEST_F(AdaptFromMessageTest, StringValue) { + auto message = DynamicParseTextProto( + R"pb(value: "foo")pb"); + EXPECT_THAT(AdaptFromMessage(*message), + IsOkAndHolds(VariantWith(StringValue("foo")))); +} + +TEST_F(AdaptFromMessageTest, Duration) { + auto message = DynamicParseTextProto( + R"pb(seconds: 1 nanos: 1)pb"); + EXPECT_THAT(AdaptFromMessage(*message), + IsOkAndHolds(VariantWith(absl::Seconds(1) + + absl::Nanoseconds(1)))); +} + +TEST_F(AdaptFromMessageTest, Duration_SecondsOutOfRange) { + auto message = DynamicParseTextProto( + R"pb(seconds: 0x7fffffffffffffff nanos: 1)pb"); + EXPECT_THAT(AdaptFromMessage(*message), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("invalid duration seconds: "))); +} + +TEST_F(AdaptFromMessageTest, Duration_NanosOutOfRange) { + auto message = DynamicParseTextProto( + R"pb(seconds: 1 nanos: 0x7fffffff)pb"); + EXPECT_THAT(AdaptFromMessage(*message), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("invalid duration nanoseconds: "))); +} + +TEST_F(AdaptFromMessageTest, Duration_SignMismatch) { + auto message = + DynamicParseTextProto(R"pb(seconds: -1 + nanos: 1)pb"); + EXPECT_THAT(AdaptFromMessage(*message), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("duration sign mismatch: "))); +} + +TEST_F(AdaptFromMessageTest, Timestamp) { + auto message = + DynamicParseTextProto(R"pb(seconds: 1 + nanos: 1)pb"); + EXPECT_THAT( + AdaptFromMessage(*message), + IsOkAndHolds(VariantWith( + absl::UnixEpoch() + absl::Seconds(1) + absl::Nanoseconds(1)))); +} + +TEST_F(AdaptFromMessageTest, Timestamp_SecondsOutOfRange) { + auto message = DynamicParseTextProto( + R"pb(seconds: 0x7fffffffffffffff nanos: 1)pb"); + EXPECT_THAT(AdaptFromMessage(*message), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("invalid timestamp seconds: "))); +} + +TEST_F(AdaptFromMessageTest, Timestamp_NanosOutOfRange) { + auto message = DynamicParseTextProto( + R"pb(seconds: 1 nanos: 0x7fffffff)pb"); + EXPECT_THAT(AdaptFromMessage(*message), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("invalid timestamp nanoseconds: "))); +} + +TEST_F(AdaptFromMessageTest, Value_NullValue) { + auto message = DynamicParseTextProto( + R"pb(null_value: NULL_VALUE)pb"); + EXPECT_THAT(AdaptFromMessage(*message), + IsOkAndHolds(VariantWith(nullptr))); +} + +TEST_F(AdaptFromMessageTest, Value_BoolValue) { + auto message = + DynamicParseTextProto(R"pb(bool_value: true)pb"); + EXPECT_THAT(AdaptFromMessage(*message), + IsOkAndHolds(VariantWith(true))); +} + +TEST_F(AdaptFromMessageTest, Value_NumberValue) { + auto message = DynamicParseTextProto( + R"pb(number_value: 1.0)pb"); + EXPECT_THAT(AdaptFromMessage(*message), + IsOkAndHolds(VariantWith(1.0))); +} + +TEST_F(AdaptFromMessageTest, Value_StringValue) { + auto message = DynamicParseTextProto( + R"pb(string_value: "foo")pb"); + EXPECT_THAT(AdaptFromMessage(*message), + IsOkAndHolds(VariantWith(StringValue("foo")))); +} + +TEST_F(AdaptFromMessageTest, Value_ListValue) { + auto message = + DynamicParseTextProto(R"pb(list_value: {})pb"); + EXPECT_THAT( + AdaptFromMessage(*message), + IsOkAndHolds(VariantWith(VariantWith(_)))); +} + +TEST_F(AdaptFromMessageTest, Value_StructValue) { + auto message = + DynamicParseTextProto(R"pb(struct_value: {})pb"); + EXPECT_THAT( + AdaptFromMessage(*message), + IsOkAndHolds(VariantWith(VariantWith(_)))); +} + +TEST_F(AdaptFromMessageTest, ListValue) { + auto message = DynamicParseTextProto(R"pb()pb"); + EXPECT_THAT( + AdaptFromMessage(*message), + IsOkAndHolds(VariantWith(VariantWith(_)))); +} + +TEST_F(AdaptFromMessageTest, Struct) { + auto message = DynamicParseTextProto(R"pb()pb"); + EXPECT_THAT( + AdaptFromMessage(*message), + IsOkAndHolds(VariantWith(VariantWith(_)))); +} + +TEST_F(AdaptFromMessageTest, TestAllTypesProto3) { + auto message = DynamicParseTextProto(R"pb()pb"); + EXPECT_THAT(AdaptFromMessage(*message), + IsOkAndHolds(VariantWith(std::monostate()))); +} + +TEST_F(AdaptFromMessageTest, Any_BoolValue) { + auto message = DynamicParseTextProto( + R"pb(type_url: "type.googleapis.com/google.protobuf.BoolValue")pb"); + EXPECT_THAT(AdaptFromMessage(*message), + IsOkAndHolds(VariantWith(false))); +} + +TEST_F(AdaptFromMessageTest, Any_Int32Value) { + auto message = DynamicParseTextProto( + R"pb(type_url: "type.googleapis.com/google.protobuf.Int32Value")pb"); + EXPECT_THAT(AdaptFromMessage(*message), + IsOkAndHolds(VariantWith(0))); +} + +TEST_F(AdaptFromMessageTest, Any_Int64Value) { + auto message = DynamicParseTextProto( + R"pb(type_url: "type.googleapis.com/google.protobuf.Int64Value")pb"); + EXPECT_THAT(AdaptFromMessage(*message), + IsOkAndHolds(VariantWith(0))); +} + +TEST_F(AdaptFromMessageTest, Any_UInt32Value) { + auto message = DynamicParseTextProto( + R"pb(type_url: "type.googleapis.com/google.protobuf.UInt32Value")pb"); + EXPECT_THAT(AdaptFromMessage(*message), + IsOkAndHolds(VariantWith(0))); +} + +TEST_F(AdaptFromMessageTest, Any_UInt64Value) { + auto message = DynamicParseTextProto( + R"pb(type_url: "type.googleapis.com/google.protobuf.UInt64Value")pb"); + EXPECT_THAT(AdaptFromMessage(*message), + IsOkAndHolds(VariantWith(0))); +} + +TEST_F(AdaptFromMessageTest, Any_FloatValue) { + auto message = DynamicParseTextProto( + R"pb(type_url: "type.googleapis.com/google.protobuf.FloatValue")pb"); + EXPECT_THAT(AdaptFromMessage(*message), IsOkAndHolds(VariantWith(0))); +} + +TEST_F(AdaptFromMessageTest, Any_DoubleValue) { + auto message = DynamicParseTextProto( + R"pb(type_url: "type.googleapis.com/google.protobuf.DoubleValue")pb"); + EXPECT_THAT(AdaptFromMessage(*message), IsOkAndHolds(VariantWith(0))); +} + +TEST_F(AdaptFromMessageTest, Any_BytesValue) { + auto message = DynamicParseTextProto( + R"pb(type_url: "type.googleapis.com/google.protobuf.BytesValue")pb"); + EXPECT_THAT(AdaptFromMessage(*message), + IsOkAndHolds(VariantWith(BytesValue()))); +} + +TEST_F(AdaptFromMessageTest, Any_StringValue) { + auto message = DynamicParseTextProto( + R"pb(type_url: "type.googleapis.com/google.protobuf.StringValue")pb"); + EXPECT_THAT(AdaptFromMessage(*message), + IsOkAndHolds(VariantWith(StringValue()))); +} + +TEST_F(AdaptFromMessageTest, Any_Duration) { + auto message = DynamicParseTextProto( + R"pb(type_url: "type.googleapis.com/google.protobuf.Duration")pb"); + EXPECT_THAT(AdaptFromMessage(*message), + IsOkAndHolds(VariantWith(absl::ZeroDuration()))); +} + +TEST_F(AdaptFromMessageTest, Any_Timestamp) { + auto message = DynamicParseTextProto( + R"pb(type_url: "type.googleapis.com/google.protobuf.Timestamp")pb"); + EXPECT_THAT(AdaptFromMessage(*message), + IsOkAndHolds(VariantWith(absl::UnixEpoch()))); +} + +TEST_F(AdaptFromMessageTest, Any_Value_NullValue) { + auto message = DynamicParseTextProto( + R"pb(type_url: "type.googleapis.com/google.protobuf.Value")pb"); + EXPECT_THAT(AdaptFromMessage(*message), + IsOkAndHolds(VariantWith(nullptr))); +} + +TEST_F(AdaptFromMessageTest, Any_Value_BoolValue) { + auto message = DynamicParseTextProto( + + R"pb(type_url: "type.googleapis.com/google.protobuf.Value" + value: "\x20\x01")pb"); // bool_value: true + EXPECT_THAT(AdaptFromMessage(*message), + IsOkAndHolds(VariantWith(true))); +} + +TEST_F(AdaptFromMessageTest, Any_Value_NumberValue) { + auto message = DynamicParseTextProto( + R"pb(type_url: "type.googleapis.com/google.protobuf.Value" + value: "\x11\x00\x00\x00\x00\x00\x00\x00\x00")pb"); // number_value: + // 1.0 + EXPECT_THAT(AdaptFromMessage(*message), + IsOkAndHolds(VariantWith(0.0))); +} + +TEST_F(AdaptFromMessageTest, Any_Value_StringValue) { + auto message = DynamicParseTextProto( + R"pb(type_url: "type.googleapis.com/google.protobuf.Value" + value: "\x1a\x03\x66\x6f\x6f")pb"); // string_value: "foo" + EXPECT_THAT(AdaptFromMessage(*message), + IsOkAndHolds(VariantWith(StringValue("foo")))); +} + +TEST_F(AdaptFromMessageTest, Any_Value_ListValue) { + auto message = DynamicParseTextProto( + R"pb(type_url: "type.googleapis.com/google.protobuf.Value" + value: "\x32\x00")pb"); // list_value: {} + EXPECT_THAT(AdaptFromMessage(*message), + IsOkAndHolds(VariantWith( + VariantWith(NotNull())))); +} + +TEST_F(AdaptFromMessageTest, Any_Value_StructValue) { + auto message = DynamicParseTextProto( + R"pb(type_url: "type.googleapis.com/google.protobuf.Value" + value: "\x2a\x00")pb"); // struct_value: {} + EXPECT_THAT( + AdaptFromMessage(*message), + IsOkAndHolds(VariantWith(VariantWith(NotNull())))); +} + +TEST_F(AdaptFromMessageTest, Any_ListValue) { + auto message = DynamicParseTextProto( + R"pb(type_url: "type.googleapis.com/google.protobuf.ListValue")pb"); + EXPECT_THAT(AdaptFromMessage(*message), + IsOkAndHolds(VariantWith( + VariantWith(NotNull())))); +} + +TEST_F(AdaptFromMessageTest, Any_Struct) { + auto message = DynamicParseTextProto( + R"pb(type_url: "type.googleapis.com/google.protobuf.Struct")pb"); + EXPECT_THAT( + AdaptFromMessage(*message), + IsOkAndHolds(VariantWith(VariantWith(NotNull())))); +} + +TEST_F(AdaptFromMessageTest, Any_TestAllTypesProto3) { + auto message = DynamicParseTextProto( + R"pb(type_url: "type.googleapis.com/cel.expr.conformance.proto3.TestAllTypes")pb"); + EXPECT_THAT(AdaptFromMessage(*message), + IsOkAndHolds(VariantWith>(NotNull()))); +} + +TEST_F(AdaptFromMessageTest, Any_BadTypeUrlDomain) { + auto message = DynamicParseTextProto( + R"pb(type_url: "type.example.com/google.protobuf.BoolValue")pb"); + EXPECT_THAT(AdaptFromMessage(*message), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("unable to find descriptor for type URL: "))); +} + +TEST_F(AdaptFromMessageTest, Any_UnknownMessage) { + auto message = DynamicParseTextProto( + R"pb(type_url: "type.googleapis.com/message.that.does.not.Exist")pb"); + EXPECT_THAT(AdaptFromMessage(*message), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("unable to find descriptor for type name: "))); +} + +} // namespace +} // namespace cel::well_known_types diff --git a/parser/BUILD b/parser/BUILD index f6378ae37..6650d9fe9 100644 --- a/parser/BUILD +++ b/parser/BUILD @@ -1,15 +1,23 @@ -load("//bazel:antlr.bzl", "antlr_cc_library") +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. -package( - default_visibility = ["//visibility:public"], -) +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") -licenses(["notice"]) # Apache 2.0 +package(default_visibility = ["//visibility:public"]) -antlr_cc_library( - name = "cel", - src = "Cel.g4", -) +licenses(["notice"]) cc_library( name = "parser", @@ -22,16 +30,45 @@ cc_library( copts = [ "-fexceptions", ], + defines = [ + "ANTLR4CPP_STATIC", + ], deps = [ - ":cel_cc_parser", ":macro", + ":macro_expr_factory", + ":macro_registry", + ":options", + ":parser_interface", ":source_factory", - ":visitor", - "@antlr4_runtimes//:cpp", + "//common:ast", + "//common:constant", + "//common:expr_factory", + "//common:operators", + "//common:source", + "//common/ast:expr_proto", + "//common/ast:source_info_proto", + "//internal:lexis", + "//internal:status_macros", + "//internal:strings", + "//internal:utf8", + "//parser/internal:cel_cc_parser", + "@antlr4-cpp-runtime", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/functional:overload", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:optional", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_absl//absl/types:span", + "@com_google_absl//absl/types:variant", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], ) @@ -43,79 +80,202 @@ cc_library( hdrs = [ "macro.h", ], - copts = [ - "-fexceptions", - ], deps = [ - ":source_factory", + ":macro_expr_factory", + "//common:expr", "//common:operators", - "@com_google_absl//absl/strings:str_format", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "//internal:lexis", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", ], ) cc_library( - name = "visitor", + name = "macro_registry", srcs = [ - "balancer.cc", - "visitor.cc", + "macro_registry.cc", ], hdrs = [ - "balancer.h", - "visitor.h", - ], - copts = [ - "-fexceptions", + "macro_registry.h", ], deps = [ - ":cel_cc_parser", ":macro", - ":source_factory", - "//common:escaping", - "//common:operators", - "@com_google_absl//absl/memory", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:optional", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", - "@com_google_protobuf//:protobuf", + "@com_google_absl//absl/types:span", + ], +) + +cc_test( + name = "macro_registry_test", + srcs = ["macro_registry_test.cc"], + deps = [ + ":macro", + ":macro_registry", + "//internal:testing", + "@com_google_absl//absl/status", + "@com_google_absl//absl/types:optional", ], ) cc_library( - name = "source_factory", - srcs = [ - "source_factory.cc", + name = "macro_expr_factory", + srcs = ["macro_expr_factory.cc"], + hdrs = ["macro_expr_factory.h"], + deps = [ + "//common:constant", + "//common:expr", + "//common:expr_factory", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/functional:overload", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:variant", ], +) + +cc_test( + name = "macro_expr_factory_test", + srcs = ["macro_expr_factory_test.cc"], + deps = [ + ":macro_expr_factory", + "//common:expr", + "//common:expr_factory", + "//internal:testing", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "source_factory", hdrs = [ "source_factory.h", ], - copts = [ - "-fexceptions", +) + +cc_library( + name = "options", + hdrs = ["options.h"], + deps = [ + "//parser/internal:options", + "@com_google_absl//absl/base:core_headers", ], +) + +cc_test( + name = "parser_test", + srcs = ["parser_test.cc"], deps = [ - ":cel_cc_parser", - "//common:operators", - "@antlr4_runtimes//:cpp", - "@com_google_absl//absl/memory", + ":macro", + ":options", + ":parser", + ":parser_interface", + ":source_factory", + "//common:constant", + "//common:expr", + "//common:source", + "//internal:testing", + "//testutil:expr_printer", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:optional", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", - "@com_google_protobuf//:protobuf", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], ) cc_test( - name = "parser_test", - srcs = ["parser_test.cc"], + name = "parser_benchmarks", + srcs = ["parser_benchmarks.cc"], + tags = ["benchmark"], deps = [ + ":macro", + ":options", ":parser", ":source_factory", + "//common:constant", + "//common:expr", + "//common:source", + "//internal:benchmark", + "//internal:testing", "//testutil:expr_printer", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:optional", - "@com_google_googletest//:gtest_main", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + ], +) + +cc_library( + name = "standard_macros", + srcs = ["standard_macros.cc"], + hdrs = ["standard_macros.h"], + deps = [ + ":macro", + ":macro_registry", + ":options", + "//internal:status_macros", + "@com_google_absl//absl/status", + ], +) + +cc_library( + name = "parser_interface", + hdrs = ["parser_interface.h"], + deps = [ + ":macro", + ":options", + "//common:ast", + "//common:source", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + ], +) + +cc_library( + name = "parser_subset_factory", + srcs = ["parser_subset_factory.cc"], + hdrs = ["parser_subset_factory.h"], + deps = [ + ":macro", + ":parser_interface", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + ], +) + +cc_test( + name = "standard_macros_test", + srcs = ["standard_macros_test.cc"], + deps = [ + ":macro_registry", + ":options", + ":parser", + ":standard_macros", + "//common:source", + "//internal:testing", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", ], ) diff --git a/parser/balancer.cc b/parser/balancer.cc deleted file mode 100644 index 88cfbd6d9..000000000 --- a/parser/balancer.cc +++ /dev/null @@ -1,48 +0,0 @@ -#include "parser/balancer.h" - -#include "parser/source_factory.h" - -namespace google { -namespace api { -namespace expr { -namespace parser { - -ExpressionBalancer::ExpressionBalancer(std::shared_ptr sf, - std::string function, Expr expr) - : sf_(sf), function_(function), terms_{expr}, ops_{} {} - -void ExpressionBalancer::addTerm(int64_t op, Expr term) { - terms_.push_back(term); - ops_.push_back(op); -} - -Expr ExpressionBalancer::balance() { - if (terms_.size() == 1) { - return terms_[0]; - } - return balancedTree(0, ops_.size() - 1); -} - -Expr ExpressionBalancer::balancedTree(int lo, int hi) { - int mid = (lo + hi + 1) / 2; - - Expr left; - if (mid == lo) { - left = terms_[mid]; - } else { - left = balancedTree(lo, mid - 1); - } - - Expr right; - if (mid == hi) { - right = terms_[mid + 1]; - } else { - right = balancedTree(mid + 1, hi); - } - return sf_->newGlobalCall(ops_[mid], function_, {left, right}); -} - -} // namespace parser -} // namespace expr -} // namespace api -} // namespace google diff --git a/parser/balancer.h b/parser/balancer.h deleted file mode 100644 index 623eb9323..000000000 --- a/parser/balancer.h +++ /dev/null @@ -1,58 +0,0 @@ -#ifndef THIRD_PARTY_CEL_CPP_PARSER_BALANCER_H_ -#define THIRD_PARTY_CEL_CPP_PARSER_BALANCER_H_ - -#include -#include -#include - -#include "google/api/expr/v1alpha1/syntax.pb.h" - -namespace google { -namespace api { -namespace expr { -namespace parser { - -class SourceFactory; - -using google::api::expr::v1alpha1::Expr; - -// balancer performs tree balancing on operators whose arguments are of equal -// precedence. -// -// The purpose of the balancer is to ensure a compact serialization format for -// the logical &&, || operators which have a tendency to create long DAGs which -// are skewed in one direction. Since the operators are commutative re-ordering -// the terms *must not* affect the evaluation result. -// -// Based on code from //third_party/cel/go/parser/helper.go -class ExpressionBalancer { - public: - ExpressionBalancer(std::shared_ptr sf, std::string function, - Expr expr); - - // addTerm adds an operation identifier and term to the set of terms to be - // balanced. - void addTerm(int64_t op, Expr term); - - // balance creates a balanced tree from the sub-terms and returns the final - // Expr value. - Expr balance(); - - private: - // balancedTree recursively balances the terms provided to a commutative - // operator. - Expr balancedTree(int lo, int hi); - - private: - std::shared_ptr sf_; - std::string function_; - std::vector terms_; - std::vector ops_; -}; - -} // namespace parser -} // namespace expr -} // namespace api -} // namespace google - -#endif // THIRD_PARTY_CEL_CPP_PARSER_BALANCER_H_ diff --git a/parser/internal/BUILD b/parser/internal/BUILD new file mode 100644 index 000000000..af815588e --- /dev/null +++ b/parser/internal/BUILD @@ -0,0 +1,31 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("//bazel:antlr.bzl", "antlr_cc_library") + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) + +cc_library( + name = "options", + hdrs = ["options.h"], +) + +antlr_cc_library( + name = "cel", + src = "Cel.g4", + package = "cel_parser_internal", +) diff --git a/parser/Cel.g4 b/parser/internal/Cel.g4 similarity index 57% rename from parser/Cel.g4 rename to parser/internal/Cel.g4 index 62d8cf8d0..9b2c73954 100644 --- a/parser/Cel.g4 +++ b/parser/internal/Cel.g4 @@ -1,6 +1,16 @@ -// Common Expression Language grammar for C++ -// Based on Java grammar with the following changes: -// - rename grammar from CEL to Cel to generate C++ style compatible names. +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. grammar Cel; @@ -35,47 +45,67 @@ calc ; unary - : member # MemberExpr - | (ops+='!')+ member # LogicalNot - | (ops+='-')+ member # Negate + : member # MemberExpr + | (ops+='!')+ member # LogicalNot + | (ops+='-')+ member # Negate ; member - : primary # PrimaryExpr - | member op='.' id=IDENTIFIER (open='(' args=exprList? ')')? # SelectOrCall - | member op='[' index=expr ']' # Index - | member op='{' entries=fieldInitializerList? '}' # CreateMessage + : primary # PrimaryExpr + | member op='.' (opt='?')? id=escapeIdent # Select + | member op='.' id=IDENTIFIER open='(' args=exprList? ')' # MemberCall + | member op='[' (opt='?')? index=expr ']' # Index ; primary - : leadingDot='.'? id=IDENTIFIER (op='(' args=exprList? ')')? # IdentOrGlobalCall - | '(' e=expr ')' # Nested - | op='[' elems=exprList? ','? ']' # CreateList - | op='{' entries=mapInitializerList? '}' # CreateStruct - | literal # ConstantLiteral + : leadingDot='.'? id=IDENTIFIER # Ident + | leadingDot='.'? id=IDENTIFIER (op='(' args=exprList? ')') # GlobalCall + | '(' e=expr ')' # Nested + | op='[' elems=listInit? ','? ']' # CreateList + | op='{' entries=mapInitializerList? ','? '}' # CreateMap + | leadingDot='.'? ids+=IDENTIFIER (ops+='.' ids+=IDENTIFIER)* + op='{' entries=fieldInitializerList? ','? '}' # CreateMessage + | literal # ConstantLiteral ; exprList : e+=expr (',' e+=expr)* ; +listInit + : elems+=optExpr (',' elems+=optExpr)* + ; + fieldInitializerList - : fields+=IDENTIFIER cols+=':' values+=expr (',' fields+=IDENTIFIER cols+=':' values+=expr)* + : fields+=optField cols+=':' values+=expr (',' fields+=optField cols+=':' values+=expr)* + ; + +optField + : (opt='?')? escapeIdent ; mapInitializerList - : keys+=expr cols+=':' values+=expr (',' keys+=expr cols+=':' values+=expr)* + : keys+=optExpr cols+=':' values+=expr (',' keys+=optExpr cols+=':' values+=expr)* + ; + +escapeIdent + : id=IDENTIFIER # SimpleIdentifier + | id=ESC_IDENTIFIER # EscapedIdentifier + ; + +optExpr + : (opt='?')? e=expr ; literal : sign=MINUS? tok=NUM_INT # Int - | tok=NUM_UINT # Uint + | tok=NUM_UINT # Uint | sign=MINUS? tok=NUM_FLOAT # Double - | tok=STRING # String - | tok=BYTES # Bytes - | tok=TRUE # BoolTrue - | tok=FALSE # BoolFalse - | tok=NUL # Null + | tok=STRING # String + | tok=BYTES # Bytes + | tok=CEL_TRUE # BoolTrue + | tok=CEL_FALSE # BoolFalse + | tok=NUL # Null ; // Lexer Rules @@ -83,6 +113,7 @@ literal EQUALS : '=='; NOT_EQUALS : '!='; +IN: 'in'; LESS : '<'; LESS_EQUALS : '<='; GREATER_EQUALS : '>='; @@ -106,8 +137,8 @@ PLUS : '+'; STAR : '*'; SLASH : '/'; PERCENT : '%'; -TRUE : 'true'; -FALSE : 'false'; +CEL_TRUE : 'true'; +CEL_FALSE : 'false'; NUL : 'null'; fragment BACKSLASH : '\\'; @@ -173,3 +204,4 @@ STRING BYTES : ('b' | 'B') STRING; IDENTIFIER : (LETTER | '_') ( LETTER | DIGIT | '_')*; +ESC_IDENTIFIER : '`' (LETTER | DIGIT | '_' | '.' | '-' | '/' | ' ')+ '`'; diff --git a/parser/internal/options.h b/parser/internal/options.h new file mode 100644 index 000000000..ec2552204 --- /dev/null +++ b/parser/internal/options.h @@ -0,0 +1,28 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_PARSER_INTERNAL_OPTIONS_H_ +#define THIRD_PARTY_CEL_CPP_PARSER_INTERNAL_OPTIONS_H_ + +namespace cel_parser_internal { + +inline constexpr int kDefaultErrorRecoveryLimit = 12; +inline constexpr int kDefaultMaxRecursionDepth = 32; +inline constexpr int kExpressionSizeCodepointLimit = 100'000; +inline constexpr int kDefaultErrorRecoveryTokenLookaheadLimit = 512; +inline constexpr bool kDefaultAddMacroCalls = false; + +} // namespace cel_parser_internal + +#endif // THIRD_PARTY_CEL_CPP_PARSER_INTERNAL_OPTIONS_H_ diff --git a/parser/macro.cc b/parser/macro.cc index 7eb8559c1..815b07401 100644 --- a/parser/macro.cc +++ b/parser/macro.cc @@ -1,114 +1,501 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #include "parser/macro.h" -#include "absl/strings/str_format.h" +#include +#include +#include +#include +#include +#include + +#include "absl/base/no_destructor.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "common/expr.h" #include "common/operators.h" -#include "parser/source_factory.h" +#include "internal/lexis.h" +#include "parser/macro_expr_factory.h" + +namespace cel { + +namespace { -namespace google { -namespace api { -namespace expr { -namespace parser { +using google::api::expr::common::CelOperator; -using common::CelOperator; +bool IsSimpleIdentifier(const Expr& expr) { + return expr.has_ident_expr() && !expr.ident_expr().name().empty() && + !absl::StartsWith(expr.ident_expr().name(), "."); +} + +inline MacroExpander ToMacroExpander(GlobalMacroExpander expander) { + ABSL_DCHECK(expander); + return [expander = std::move(expander)]( + MacroExprFactory& factory, + absl::optional> target, + absl::Span arguments) -> absl::optional { + ABSL_DCHECK(!target.has_value()); + return (expander)(factory, arguments); + }; +} -std::string Macro::macroKey() const { - if (var_arg_style_) { - return absl::StrFormat("%s:*:%s", function_, - receiver_style_ ? "true" : "false"); - } else { - return absl::StrFormat("%s:%d:%s", function_, arg_count_, - receiver_style_ ? "true" : "false"); +inline MacroExpander ToMacroExpander(ReceiverMacroExpander expander) { + ABSL_DCHECK(expander); + return [expander = std::move(expander)]( + MacroExprFactory& factory, + absl::optional> target, + absl::Span arguments) -> absl::optional { + ABSL_DCHECK(target.has_value()); + return (expander)(factory, *target, arguments); + }; +} + +absl::optional ExpandHasMacro(MacroExprFactory& factory, + absl::Span args) { + if (args.size() != 1) { + return factory.ReportError("has() requires 1 arguments"); + } + if (!args[0].has_select_expr() || args[0].select_expr().test_only()) { + return factory.ReportErrorAt(args[0], + "has() argument must be a field selection"); } + return factory.NewPresenceTest( + args[0].mutable_select_expr().release_operand(), + args[0].mutable_select_expr().release_field()); +} + +Macro MakeHasMacro() { + auto macro_or_status = Macro::Global(CelOperator::HAS, 1, ExpandHasMacro); + ABSL_CHECK_OK(macro_or_status); // Crash OK + return std::move(*macro_or_status); +} + +absl::optional ExpandAllMacro(MacroExprFactory& factory, Expr& target, + absl::Span args) { + if (args.size() != 2) { + return factory.ReportError("all() requires 2 arguments"); + } + if (!IsSimpleIdentifier(args[0])) { + return factory.ReportErrorAt( + args[0], "all() variable name must be a simple identifier"); + } + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { + return factory.ReportErrorAt( + args[1], absl::StrCat("all() variable name cannot be ", + kDeprecatedAccumulatorVariableName)); + } + auto init = factory.NewBoolConst(true); + auto condition = + factory.NewCall(CelOperator::NOT_STRICTLY_FALSE, factory.NewAccuIdent()); + auto step = factory.NewCall(CelOperator::LOGICAL_AND, factory.NewAccuIdent(), + std::move(args[1])); + auto result = factory.NewAccuIdent(); + return factory.NewComprehension(args[0].ident_expr().name(), + std::move(target), factory.AccuVarName(), + std::move(init), std::move(condition), + std::move(step), std::move(result)); +} + +Macro MakeAllMacro() { + auto status_or_macro = Macro::Receiver(CelOperator::ALL, 2, ExpandAllMacro); + ABSL_CHECK_OK(status_or_macro); // Crash OK + return std::move(*status_or_macro); +} + +absl::optional ExpandExistsMacro(MacroExprFactory& factory, Expr& target, + absl::Span args) { + if (args.size() != 2) { + return factory.ReportError("exists() requires 2 arguments"); + } + if (!IsSimpleIdentifier(args[0])) { + return factory.ReportErrorAt( + args[0], "exists() variable name must be a simple identifier"); + } + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { + return factory.ReportErrorAt( + args[1], absl::StrCat("exists() variable name cannot be ", + kDeprecatedAccumulatorVariableName)); + } + auto init = factory.NewBoolConst(false); + auto condition = factory.NewCall( + CelOperator::NOT_STRICTLY_FALSE, + factory.NewCall(CelOperator::LOGICAL_NOT, factory.NewAccuIdent())); + auto step = factory.NewCall(CelOperator::LOGICAL_OR, factory.NewAccuIdent(), + std::move(args[1])); + auto result = factory.NewAccuIdent(); + return factory.NewComprehension(args[0].ident_expr().name(), + std::move(target), factory.AccuVarName(), + std::move(init), std::move(condition), + std::move(step), std::move(result)); +} + +Macro MakeExistsMacro() { + auto status_or_macro = + Macro::Receiver(CelOperator::EXISTS, 2, ExpandExistsMacro); + ABSL_CHECK_OK(status_or_macro); // Crash OK + return std::move(*status_or_macro); +} + +absl::optional ExpandExistsOneMacro(MacroExprFactory& factory, + Expr& target, absl::Span args) { + if (args.size() != 2) { + return factory.ReportError("exists_one() requires 2 arguments"); + } + if (!IsSimpleIdentifier(args[0])) { + return factory.ReportErrorAt( + args[0], "exists_one() variable name must be a simple identifier"); + } + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { + return factory.ReportErrorAt( + args[1], absl::StrCat("exists_one() variable name cannot be ", + kDeprecatedAccumulatorVariableName)); + } + auto init = factory.NewIntConst(0); + auto condition = factory.NewBoolConst(true); + auto accu_ident = factory.NewAccuIdent(); + auto const_1 = factory.NewIntConst(1); + auto inc_step = factory.NewCall(CelOperator::ADD, std::move(accu_ident), + std::move(const_1)); + + auto step = factory.NewCall(CelOperator::CONDITIONAL, std::move(args[1]), + std::move(inc_step), factory.NewAccuIdent()); + accu_ident = factory.NewAccuIdent(); + auto result = factory.NewCall(CelOperator::EQUALS, std::move(accu_ident), + factory.NewIntConst(1)); + return factory.NewComprehension(args[0].ident_expr().name(), + std::move(target), factory.AccuVarName(), + std::move(init), std::move(condition), + std::move(step), std::move(result)); +} + +Macro MakeExistsOneMacro() { + auto status_or_macro = + Macro::Receiver(CelOperator::EXISTS_ONE, 2, ExpandExistsOneMacro); + ABSL_CHECK_OK(status_or_macro); // Crash OK + return std::move(*status_or_macro); +} + +absl::optional ExpandMap2Macro(MacroExprFactory& factory, Expr& target, + absl::Span args) { + if (args.size() != 2) { + return factory.ReportError("map() requires 2 arguments"); + } + if (!IsSimpleIdentifier(args[0])) { + return factory.ReportErrorAt( + args[0], "map() variable name must be a simple identifier"); + } + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { + return factory.ReportErrorAt( + args[1], absl::StrCat("map() variable name cannot be ", + kDeprecatedAccumulatorVariableName)); + } + auto init = factory.NewList(); + auto condition = factory.NewBoolConst(true); + auto accu_ref = factory.NewAccuIdent(); + auto accu_update = + factory.NewList(factory.NewListElement(std::move(args[1]))); + auto step = factory.NewCall(CelOperator::ADD, std::move(accu_ref), + std::move(accu_update)); + return factory.NewComprehension(args[0].ident_expr().name(), + std::move(target), factory.AccuVarName(), + std::move(init), std::move(condition), + std::move(step), factory.NewAccuIdent()); +} + +Macro MakeMap2Macro() { + auto status_or_macro = Macro::Receiver(CelOperator::MAP, 2, ExpandMap2Macro); + ABSL_CHECK_OK(status_or_macro); // Crash OK + return std::move(*status_or_macro); +} + +absl::optional ExpandMap3Macro(MacroExprFactory& factory, Expr& target, + absl::Span args) { + if (args.size() != 3) { + return factory.ReportError("map() requires 3 arguments"); + } + if (!IsSimpleIdentifier(args[0])) { + return factory.ReportErrorAt( + args[0], "map() variable name must be a simple identifier"); + } + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { + return factory.ReportErrorAt( + args[1], absl::StrCat("map() variable name cannot be ", + kDeprecatedAccumulatorVariableName)); + } + auto init = factory.NewList(); + auto condition = factory.NewBoolConst(true); + auto accu_ref = factory.NewAccuIdent(); + auto accu_update = + factory.NewList(factory.NewListElement(std::move(args[2]))); + auto step = factory.NewCall(CelOperator::ADD, std::move(accu_ref), + std::move(accu_update)); + step = factory.NewCall(CelOperator::CONDITIONAL, std::move(args[1]), + std::move(step), factory.NewAccuIdent()); + return factory.NewComprehension(args[0].ident_expr().name(), + std::move(target), factory.AccuVarName(), + std::move(init), std::move(condition), + std::move(step), factory.NewAccuIdent()); +} + +Macro MakeMap3Macro() { + auto status_or_macro = Macro::Receiver(CelOperator::MAP, 3, ExpandMap3Macro); + ABSL_CHECK_OK(status_or_macro); // Crash OK + return std::move(*status_or_macro); +} + +absl::optional ExpandFilterMacro(MacroExprFactory& factory, Expr& target, + absl::Span args) { + if (args.size() != 2) { + return factory.ReportError("filter() requires 2 arguments"); + } + if (!IsSimpleIdentifier(args[0])) { + return factory.ReportErrorAt( + args[0], "filter() variable name must be a simple identifier"); + } + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { + return factory.ReportErrorAt( + args[1], absl::StrCat("filter() variable name cannot be ", + kDeprecatedAccumulatorVariableName)); + } + auto name = args[0].ident_expr().name(); + + auto init = factory.NewList(); + auto condition = factory.NewBoolConst(true); + auto accu_ref = factory.NewAccuIdent(); + auto accu_update = + factory.NewList(factory.NewListElement(std::move(args[0]))); + auto step = factory.NewCall(CelOperator::ADD, std::move(accu_ref), + std::move(accu_update)); + step = factory.NewCall(CelOperator::CONDITIONAL, std::move(args[1]), + std::move(step), factory.NewAccuIdent()); + return factory.NewComprehension(std::move(name), std::move(target), + factory.AccuVarName(), std::move(init), + std::move(condition), std::move(step), + factory.NewAccuIdent()); +} + +Macro MakeFilterMacro() { + auto status_or_macro = + Macro::Receiver(CelOperator::FILTER, 2, ExpandFilterMacro); + ABSL_CHECK_OK(status_or_macro); // Crash OK + return std::move(*status_or_macro); +} + +absl::optional ExpandOptMapMacro(MacroExprFactory& factory, Expr& target, + absl::Span args) { + if (args.size() != 2) { + return factory.ReportError("optMap() requires 2 arguments"); + } + if (!IsSimpleIdentifier(args[0])) { + return factory.ReportErrorAt( + args[0], "optMap() variable name must be a simple identifier"); + } + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { + return factory.ReportErrorAt( + args[1], absl::StrCat("optMap() variable name cannot be ", + kDeprecatedAccumulatorVariableName)); + } + auto var_name = args[0].ident_expr().name(); + + auto target_copy = factory.Copy(target); + std::vector call_args; + call_args.reserve(3); + call_args.push_back(factory.NewMemberCall("hasValue", std::move(target))); + auto iter_range = factory.NewList(); + auto accu_init = factory.NewMemberCall("value", std::move(target_copy)); + auto condition = factory.NewBoolConst(false); + auto fold = factory.NewComprehension( + "#unused", std::move(iter_range), std::move(var_name), + std::move(accu_init), std::move(condition), std::move(args[0]), + std::move(args[1])); + call_args.push_back(factory.NewCall("optional.of", std::move(fold))); + call_args.push_back(factory.NewCall("optional.none")); + return factory.NewCall(CelOperator::CONDITIONAL, std::move(call_args)); +} + +Macro MakeOptMapMacro() { + auto status_or_macro = Macro::Receiver("optMap", 2, ExpandOptMapMacro); + ABSL_CHECK_OK(status_or_macro); // Crash OK + return std::move(*status_or_macro); +} + +absl::optional ExpandOptFlatMapMacro(MacroExprFactory& factory, + Expr& target, + absl::Span args) { + if (args.size() != 2) { + return factory.ReportError("optFlatMap() requires 2 arguments"); + } + if (!IsSimpleIdentifier(args[0])) { + return factory.ReportErrorAt( + args[0], "optFlatMap() variable name must be a simple identifier"); + } + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { + return factory.ReportErrorAt( + args[1], absl::StrCat("optFlatMap() variable name cannot be ", + kDeprecatedAccumulatorVariableName)); + } + auto var_name = args[0].ident_expr().name(); + + auto target_copy = factory.Copy(target); + std::vector call_args; + call_args.reserve(3); + call_args.push_back(factory.NewMemberCall("hasValue", std::move(target))); + auto iter_range = factory.NewList(); + auto accu_init = factory.NewMemberCall("value", std::move(target_copy)); + auto condition = factory.NewBoolConst(false); + call_args.push_back(factory.NewComprehension( + "#unused", std::move(iter_range), std::move(var_name), + std::move(accu_init), std::move(condition), std::move(args[0]), + std::move(args[1]))); + call_args.push_back(factory.NewCall("optional.none")); + return factory.NewCall(CelOperator::CONDITIONAL, std::move(call_args)); +} + +Macro MakeOptFlatMapMacro() { + auto status_or_macro = + Macro::Receiver("optFlatMap", 2, ExpandOptFlatMapMacro); + ABSL_CHECK_OK(status_or_macro); // Crash OK + return std::move(*status_or_macro); +} + +} // namespace + +absl::StatusOr Macro::Global(absl::string_view name, + size_t argument_count, + GlobalMacroExpander expander) { + if (!expander) { + return absl::InvalidArgumentError( + absl::StrCat("macro expander for `", name, "` cannot be empty")); + } + return Make(name, argument_count, ToMacroExpander(std::move(expander)), + /*receiver_style=*/false, /*var_arg_style=*/false); +} + +absl::StatusOr Macro::GlobalVarArg(absl::string_view name, + GlobalMacroExpander expander) { + if (!expander) { + return absl::InvalidArgumentError( + absl::StrCat("macro expander for `", name, "` cannot be empty")); + } + return Make(name, 0, ToMacroExpander(std::move(expander)), + /*receiver_style=*/false, + /*var_arg_style=*/true); +} + +absl::StatusOr Macro::Receiver(absl::string_view name, + size_t argument_count, + ReceiverMacroExpander expander) { + if (!expander) { + return absl::InvalidArgumentError( + absl::StrCat("macro expander for `", name, "` cannot be empty")); + } + return Make(name, argument_count, ToMacroExpander(std::move(expander)), + /*receiver_style=*/true, /*var_arg_style=*/false); +} + +absl::StatusOr Macro::ReceiverVarArg(absl::string_view name, + ReceiverMacroExpander expander) { + if (!expander) { + return absl::InvalidArgumentError( + absl::StrCat("macro expander for `", name, "` cannot be empty")); + } + return Make(name, 0, ToMacroExpander(std::move(expander)), + /*receiver_style=*/true, + /*var_arg_style=*/true); } std::vector Macro::AllMacros() { - return { - // The macro "has(m.f)" which tests the presence of a field, avoiding the - // need to specify the field as a string. - Macro(CelOperator::HAS, 1, - [](std::shared_ptr sf, int64_t macro_id, Expr* target, - const std::vector& args) { - if (!args.empty() && args[0].has_select_expr()) { - const auto& sel_expr = args[0].select_expr(); - return sf->newPresenceTestForMacro(macro_id, sel_expr.operand(), - sel_expr.field()); - } else { - // error - return Expr(); - } - }), - - // The macro "range.all(var, predicate)", which is true if for all - // elements - // in range the predicate holds. - Macro( - CelOperator::ALL, 2, - [](std::shared_ptr sf, int64_t macro_id, Expr* target, - const std::vector& args) { - return sf->newQuantifierExprForMacro(SourceFactory::QUANTIFIER_ALL, - macro_id, target, args); - }, - /* receiver style*/ true), - - // The macro "range.exists(var, predicate)", which is true if for at least - // one element in range the predicate holds. - Macro( - CelOperator::EXISTS, 2, - [](std::shared_ptr sf, int64_t macro_id, Expr* target, - const std::vector& args) { - return sf->newQuantifierExprForMacro( - SourceFactory::QUANTIFIER_EXISTS, macro_id, target, args); - }, - /* receiver style*/ true), - - // The macro "range.exists_one(var, predicate)", which is true if for - // exactly one element in range the predicate holds. - Macro( - CelOperator::EXISTS_ONE, 2, - [](std::shared_ptr sf, int64_t macro_id, Expr* target, - const std::vector& args) { - return sf->newQuantifierExprForMacro( - SourceFactory::QUANTIFIER_EXISTS_ONE, macro_id, target, args); - }, - /* receiver style*/ true), - - // The macro "range.map(var, function)", applies the function to the vars - // in - // the range. - Macro( - CelOperator::MAP, 2, - [](std::shared_ptr sf, int64_t macro_id, Expr* target, - const std::vector& args) { - return sf->newMapForMacro(macro_id, target, args); - }, - /* receiver style*/ true), - - // The macro "range.map(var, predicate, function)", applies the function - // to - // the vars in the range for which the predicate holds true. The other - // variables are filtered out. - Macro( - CelOperator::MAP, 3, - [](std::shared_ptr sf, int64_t macro_id, Expr* target, - const std::vector& args) { - return sf->newMapForMacro(macro_id, target, args); - }, - /* receiver style*/ true), - - // The macro "range.filter(var, predicate)", filters out the variables for - // which the - // predicate is false. - Macro( - CelOperator::FILTER, 2, - [](std::shared_ptr sf, int64_t macro_id, Expr* target, - const std::vector& args) { - return sf->newFilterExprForMacro(macro_id, target, args); - }, - /* receiver style*/ true), - }; + return {HasMacro(), AllMacro(), ExistsMacro(), ExistsOneMacro(), + Map2Macro(), Map3Macro(), FilterMacro()}; +} + +std::string Macro::Key(absl::string_view name, size_t argument_count, + bool receiver_style, bool var_arg_style) { + if (var_arg_style) { + return absl::StrCat(name, ":*:", receiver_style ? "true" : "false"); + } + return absl::StrCat(name, ":", argument_count, ":", + receiver_style ? "true" : "false"); +} + +absl::StatusOr Macro::Make(absl::string_view name, size_t argument_count, + MacroExpander expander, bool receiver_style, + bool var_arg_style) { + if (!internal::LexisIsIdentifier(name)) { + return absl::InvalidArgumentError(absl::StrCat( + "macro function name `", name, "` is not a valid identifier")); + } + if (!expander) { + return absl::InvalidArgumentError( + absl::StrCat("macro expander for `", name, "` cannot be empty")); + } + return Macro(std::make_shared( + std::string(name), + Key(name, argument_count, receiver_style, var_arg_style), argument_count, + std::move(expander), receiver_style, var_arg_style)); +} + +const Macro& HasMacro() { + static const absl::NoDestructor macro(MakeHasMacro()); + return *macro; +} + +const Macro& AllMacro() { + static const absl::NoDestructor macro(MakeAllMacro()); + return *macro; +} + +const Macro& ExistsMacro() { + static const absl::NoDestructor macro(MakeExistsMacro()); + return *macro; +} + +const Macro& ExistsOneMacro() { + static const absl::NoDestructor macro(MakeExistsOneMacro()); + return *macro; +} + +const Macro& Map2Macro() { + static const absl::NoDestructor macro(MakeMap2Macro()); + return *macro; +} + +const Macro& Map3Macro() { + static const absl::NoDestructor macro(MakeMap3Macro()); + return *macro; +} + +const Macro& FilterMacro() { + static const absl::NoDestructor macro(MakeFilterMacro()); + return *macro; +} + +const Macro& OptMapMacro() { + static const absl::NoDestructor macro(MakeOptMapMacro()); + return *macro; +} + +const Macro& OptFlatMapMacro() { + static const absl::NoDestructor macro(MakeOptFlatMapMacro()); + return *macro; } -} // namespace parser -} // namespace expr -} // namespace api -} // namespace google +} // namespace cel diff --git a/parser/macro.h b/parser/macro.h index 7a30e725c..e39990fbe 100644 --- a/parser/macro.h +++ b/parser/macro.h @@ -1,96 +1,229 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #ifndef THIRD_PARTY_CEL_CPP_PARSER_MACRO_H_ #define THIRD_PARTY_CEL_CPP_PARSER_MACRO_H_ +#include #include #include #include +#include +#include -#include "google/api/expr/v1alpha1/syntax.pb.h" - -namespace google { -namespace api { -namespace expr { -namespace parser { +#include "absl/base/attributes.h" +#include "absl/functional/any_invocable.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "common/expr.h" +#include "parser/macro_expr_factory.h" -using google::api::expr::v1alpha1::Expr; +namespace cel { -class SourceFactory; - -// MacroExpander converts the target and args of a function call that matches a +// MacroExpander converts the arguments of a function call that matches a // Macro. // -// Note: when the Macros.IsReceiverStyle() is true, the target argument will be -// empty. -using MacroExpander = - std::function sf, int64_t macro_id, Expr*, - const std::vector&)>; +// If this is a receiver-style macro, the second argument (optional expr) will +// be engaged. In the case of a global call, it will be `absl::nullopt`. +// +// Should return the replacement subexpression if replacement should occur, +// otherwise absl::nullopt. If `absl::nullopt` is returned, none of the +// arguments including the target must have been modified. Doing so is undefined +// behavior. Otherwise the expander is free to mutate the arguments and either +// include or exclude them from the result. +// +// We use `std::reference_wrapper` to be consistent with the fact that we +// do not use raw pointers elsewhere with `Expr` and friends. Ideally we would +// just use `absl::optional`, but that is not currently allowed and our +// `optional_ref` is internal. +using MacroExpander = absl::AnyInvocable( + MacroExprFactory&, absl::optional>, + absl::Span) const>; + +// `GlobalMacroExpander` is a `MacroExpander` for global macros. +using GlobalMacroExpander = absl::AnyInvocable( + MacroExprFactory&, absl::Span) const>; + +// `ReceiverMacroExpander` is a `MacroExpander` for receiver-style macros. +using ReceiverMacroExpander = absl::AnyInvocable( + MacroExprFactory&, Expr&, absl::Span) const>; // Macro interface for describing the function signature to match and the // MacroExpander to apply. // // Note: when a Macro should apply to multiple overloads (based on arg count) of // a given function, a Macro should be created per arg-count. -class Macro { +class Macro final { public: - // Create a Macro for a global function with the specified number of arguments - Macro(const std::string& function, int arg_count, MacroExpander expander, - bool receiver_style = false) - : function_(function), - receiver_style_(receiver_style), - var_arg_style_(false), - arg_count_(arg_count), - expander_(expander) {} - - Macro(const std::string& function, MacroExpander expander, - bool receiver_style = false) - : function_(function), - receiver_style_(receiver_style), - var_arg_style_(true), - arg_count_(0), - expander_(expander) {} + static absl::StatusOr Global(absl::string_view name, + size_t argument_count, + GlobalMacroExpander expander); + + static absl::StatusOr GlobalVarArg(absl::string_view name, + GlobalMacroExpander expander); + + static absl::StatusOr Receiver(absl::string_view name, + size_t argument_count, + ReceiverMacroExpander expander); + + static absl::StatusOr ReceiverVarArg(absl::string_view name, + ReceiverMacroExpander expander); + + Macro(const Macro&) = default; + Macro(Macro&&) = default; + Macro& operator=(const Macro&) = default; + Macro& operator=(Macro&&) = default; // Function name to match. - std::string function() const { return function_; } + absl::string_view function() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return rep_->function; + } - // ArgCount for the function call. + // argument_count() for the function call. // // When the macro is a var-arg style macro, the return value will be zero, but // the MacroKey will contain a `*` where the arg count would have been. - int argCount() const { return arg_count_; } + size_t argument_count() const { return rep_->arg_count; } + + // is_receiver_style returns true if the macro matches a receiver style call. + bool is_receiver_style() const { return rep_->receiver_style; } - // IsReceiverStyle returns true if the macro matches a receiver style call. - bool isReceiverStyle() const { return receiver_style_; } + bool is_variadic() const { return rep_->var_arg_style; } - // MacroKey returns the macro signatures accepted by this macro. + // key() returns the macro signatures accepted by this macro. // // Format: `::`. // // When the macros is a var-arg style macro, the `arg-count` value is // represented as a `*`. - std::string macroKey() const; + absl::string_view key() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return rep_->key; + } // Expander returns the MacroExpander to apply when the macro key matches the // parsed call signature. - const MacroExpander& expander() const { return expander_; } + const MacroExpander& expander() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return rep_->expander; + } + + ABSL_MUST_USE_RESULT absl::optional Expand( + MacroExprFactory& factory, + absl::optional> target, + absl::Span arguments) const { + return (expander())(factory, target, arguments); + } - Expr expand(std::shared_ptr sf, int64_t macro_id, Expr* target, - const std::vector& args) { - return expander_(sf, macro_id, target, args); + friend void swap(Macro& lhs, Macro& rhs) noexcept { + using std::swap; + swap(lhs.rep_, rhs.rep_); } + ABSL_DEPRECATED("use MacroRegistry and RegisterStandardMacros") static std::vector AllMacros(); private: - std::string function_; - bool receiver_style_; - bool var_arg_style_; - int arg_count_; - MacroExpander expander_; + struct Rep final { + Rep(std::string function, std::string key, size_t arg_count, + MacroExpander expander, bool receiver_style, bool var_arg_style) + : function(std::move(function)), + key(std::move(key)), + arg_count(arg_count), + expander(std::move(expander)), + receiver_style(receiver_style), + var_arg_style(var_arg_style) {} + + std::string function; + std::string key; + size_t arg_count; + MacroExpander expander; + bool receiver_style; + bool var_arg_style; + }; + + static std::string Key(absl::string_view name, size_t argument_count, + bool receiver_style, bool var_arg_style); + + static absl::StatusOr Make(absl::string_view name, + size_t argument_count, + MacroExpander expander, bool receiver_style, + bool var_arg_style); + + explicit Macro(std::shared_ptr rep) : rep_(std::move(rep)) {} + + std::shared_ptr rep_; }; -} // namespace parser -} // namespace expr -} // namespace api -} // namespace google +// The macro "has(m.f)" which tests the presence of a field, avoiding the +// need to specify the field as a string. +const Macro& HasMacro(); + +// The macro "range.all(var, predicate)", which is true if for all +// elements in range the predicate holds. +const Macro& AllMacro(); + +// The macro "range.exists(var, predicate)", which is true if for at least +// one element in range the predicate holds. +const Macro& ExistsMacro(); + +// The macro "range.exists_one(var, predicate)", which is true if for +// exactly one element in range the predicate holds. +const Macro& ExistsOneMacro(); + +// The macro "range.map(var, function)", applies the function to the vars +// in the range. +const Macro& Map2Macro(); + +// The macro "range.map(var, predicate, function)", applies the function +// to the vars in the range for which the predicate holds true. The other +// variables are filtered out. +const Macro& Map3Macro(); + +// The macro "range.filter(var, predicate)", filters out the variables for +// which the predicate is false. +const Macro& FilterMacro(); + +// `OptMapMacro` +// +// Apply a transformation to the optional's underlying value if it is not empty +// and return an optional typed result based on the transformation. The +// transformation expression type must return a type T which is wrapped into +// an optional. +// +// msg.?elements.optMap(e, e.size()).orValue(0) +const Macro& OptMapMacro(); + +// `OptFlatMapMacro` +// +// Apply a transformation to the optional's underlying value if it is not empty +// and return the result. The transform expression must return an optional(T) +// rather than type T. This can be useful when dealing with zero values and +// conditionally generating an empty or non-empty result in ways which cannot +// be expressed with `optMap`. +// +// msg.?elements.optFlatMap(e, e[?0]) // return the first element if present. +const Macro& OptFlatMapMacro(); + +} // namespace cel + +namespace google::api::expr::parser { + +using MacroExpander = cel::MacroExpander; + +using Macro = cel::Macro; + +} // namespace google::api::expr::parser #endif // THIRD_PARTY_CEL_CPP_PARSER_MACRO_H_ diff --git a/parser/macro_expr_factory.cc b/parser/macro_expr_factory.cc new file mode 100644 index 000000000..7e654126b --- /dev/null +++ b/parser/macro_expr_factory.cc @@ -0,0 +1,128 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "parser/macro_expr_factory.h" + +#include +#include + +#include "absl/functional/overload.h" +#include "absl/types/optional.h" +#include "absl/types/variant.h" +#include "common/constant.h" +#include "common/expr.h" + +namespace cel { + +Expr MacroExprFactory::Copy(const Expr& expr) { + // Copying logic is recursive at the moment, we alter it to be iterative in + // the future. + return absl::visit( + absl::Overload( + [this, &expr](const UnspecifiedExpr&) -> Expr { + return NewUnspecified(CopyId(expr)); + }, + [this, &expr](const Constant& const_expr) -> Expr { + return NewConst(CopyId(expr), const_expr); + }, + [this, &expr](const IdentExpr& ident_expr) -> Expr { + return NewIdent(CopyId(expr), ident_expr.name()); + }, + [this, &expr](const SelectExpr& select_expr) -> Expr { + const auto id = CopyId(expr); + return select_expr.test_only() + ? NewPresenceTest(id, Copy(select_expr.operand()), + select_expr.field()) + : NewSelect(id, Copy(select_expr.operand()), + select_expr.field()); + }, + [this, &expr](const CallExpr& call_expr) -> Expr { + const auto id = CopyId(expr); + absl::optional target; + if (call_expr.has_target()) { + target = Copy(call_expr.target()); + } + std::vector args; + args.reserve(call_expr.args().size()); + for (const auto& arg : call_expr.args()) { + args.push_back(Copy(arg)); + } + return target.has_value() + ? NewMemberCall(id, call_expr.function(), + std::move(*target), std::move(args)) + : NewCall(id, call_expr.function(), std::move(args)); + }, + [this, &expr](const ListExpr& list_expr) -> Expr { + const auto id = CopyId(expr); + std::vector elements; + elements.reserve(list_expr.elements().size()); + for (const auto& element : list_expr.elements()) { + elements.push_back(Copy(element)); + } + return NewList(id, std::move(elements)); + }, + [this, &expr](const StructExpr& struct_expr) -> Expr { + const auto id = CopyId(expr); + std::vector fields; + fields.reserve(struct_expr.fields().size()); + for (const auto& field : struct_expr.fields()) { + fields.push_back(Copy(field)); + } + return NewStruct(id, struct_expr.name(), std::move(fields)); + }, + [this, &expr](const MapExpr& map_expr) -> Expr { + const auto id = CopyId(expr); + std::vector entries; + entries.reserve(map_expr.entries().size()); + for (const auto& entry : map_expr.entries()) { + entries.push_back(Copy(entry)); + } + return NewMap(id, std::move(entries)); + }, + [this, &expr](const ComprehensionExpr& comprehension_expr) -> Expr { + const auto id = CopyId(expr); + auto iter_range = Copy(comprehension_expr.iter_range()); + auto accu_init = Copy(comprehension_expr.accu_init()); + auto loop_condition = Copy(comprehension_expr.loop_condition()); + auto loop_step = Copy(comprehension_expr.loop_step()); + auto result = Copy(comprehension_expr.result()); + return NewComprehension( + id, comprehension_expr.iter_var(), std::move(iter_range), + comprehension_expr.accu_var(), std::move(accu_init), + std::move(loop_condition), std::move(loop_step), + std::move(result)); + }), + expr.kind()); +} + +ListExprElement MacroExprFactory::Copy(const ListExprElement& element) { + return NewListElement(Copy(element.expr()), element.optional()); +} + +StructExprField MacroExprFactory::Copy(const StructExprField& field) { + auto field_id = CopyId(field.id()); + auto field_value = Copy(field.value()); + return NewStructField(field_id, field.name(), std::move(field_value), + field.optional()); +} + +MapExprEntry MacroExprFactory::Copy(const MapExprEntry& entry) { + auto entry_id = CopyId(entry.id()); + auto entry_key = Copy(entry.key()); + auto entry_value = Copy(entry.value()); + return NewMapEntry(entry_id, std::move(entry_key), std::move(entry_value), + entry.optional()); +} + +} // namespace cel diff --git a/parser/macro_expr_factory.h b/parser/macro_expr_factory.h new file mode 100644 index 000000000..c66aa4fe0 --- /dev/null +++ b/parser/macro_expr_factory.h @@ -0,0 +1,327 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_PARSER_MACRO_EXPR_FACTORY_H_ +#define THIRD_PARTY_CEL_CPP_PARSER_MACRO_EXPR_FACTORY_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/strings/string_view.h" +#include "common/expr.h" +#include "common/expr_factory.h" + +namespace cel { + +class ParserMacroExprFactory; +class TestMacroExprFactory; + +// `MacroExprFactory` is a specialization of `ExprFactory` for `MacroExpander` +// which disallows explicitly specifying IDs. +class MacroExprFactory : protected ExprFactory { + protected: + using ExprFactory::IsArrayLike; + using ExprFactory::IsExprLike; + using ExprFactory::IsStringLike; + + template + struct IsRValue + : std::bool_constant< + std::disjunction_v, std::is_same>> {}; + + public: + ABSL_MUST_USE_RESULT Expr Copy(const Expr& expr); + + ABSL_MUST_USE_RESULT ListExprElement Copy(const ListExprElement& element); + + ABSL_MUST_USE_RESULT StructExprField Copy(const StructExprField& field); + + ABSL_MUST_USE_RESULT MapExprEntry Copy(const MapExprEntry& entry); + + ABSL_MUST_USE_RESULT Expr NewUnspecified() { + return NewUnspecified(NextId()); + } + + ABSL_MUST_USE_RESULT Expr NewNullConst() { return NewNullConst(NextId()); } + + ABSL_MUST_USE_RESULT Expr NewBoolConst(bool value) { + return NewBoolConst(NextId(), value); + } + + ABSL_MUST_USE_RESULT Expr NewIntConst(int64_t value) { + return NewIntConst(NextId(), value); + } + + ABSL_MUST_USE_RESULT Expr NewUintConst(uint64_t value) { + return NewUintConst(NextId(), value); + } + + ABSL_MUST_USE_RESULT Expr NewDoubleConst(double value) { + return NewDoubleConst(NextId(), value); + } + + ABSL_MUST_USE_RESULT Expr NewBytesConst(std::string value) { + return NewBytesConst(NextId(), std::move(value)); + } + + ABSL_MUST_USE_RESULT Expr NewBytesConst(absl::string_view value) { + return NewBytesConst(NextId(), value); + } + + ABSL_MUST_USE_RESULT Expr NewBytesConst(const char* absl_nullable value) { + return NewBytesConst(NextId(), value); + } + + ABSL_MUST_USE_RESULT Expr NewStringConst(std::string value) { + return NewStringConst(NextId(), std::move(value)); + } + + ABSL_MUST_USE_RESULT Expr NewStringConst(absl::string_view value) { + return NewStringConst(NextId(), value); + } + + ABSL_MUST_USE_RESULT Expr NewStringConst(const char* absl_nullable value) { + return NewStringConst(NextId(), value); + } + + template ::value>> + ABSL_MUST_USE_RESULT Expr NewIdent(Name name) { + return NewIdent(NextId(), std::move(name)); + } + + absl::string_view AccuVarName() { return ExprFactory::AccuVarName(); } + + ABSL_MUST_USE_RESULT Expr NewAccuIdent() { return NewAccuIdent(NextId()); } + + template ::value>, + typename = std::enable_if_t::value>> + ABSL_MUST_USE_RESULT Expr NewSelect(Operand operand, Field field) { + return NewSelect(NextId(), std::move(operand), std::move(field)); + } + + template ::value>, + typename = std::enable_if_t::value>> + ABSL_MUST_USE_RESULT Expr NewPresenceTest(Operand operand, Field field) { + return NewPresenceTest(NextId(), std::move(operand), std::move(field)); + } + + template < + typename Function, typename... Args, + typename = std::enable_if_t::value>, + typename = std::enable_if_t...>>> + ABSL_MUST_USE_RESULT Expr NewCall(Function function, Args&&... args) { + std::vector array; + array.reserve(sizeof...(Args)); + (array.push_back(std::forward(args)), ...); + return NewCall(NextId(), std::move(function), std::move(array)); + } + + template ::value>, + typename = std::enable_if_t::value>> + ABSL_MUST_USE_RESULT Expr NewCall(Function function, Args args) { + return NewCall(NextId(), std::move(function), std::move(args)); + } + + template < + typename Function, typename Target, typename... Args, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t...>>> + ABSL_MUST_USE_RESULT Expr NewMemberCall(Function function, Target target, + Args&&... args) { + std::vector array; + array.reserve(sizeof...(Args)); + (array.push_back(std::forward(args)), ...); + return NewMemberCall(NextId(), std::move(function), std::move(target), + std::move(array)); + } + + template ::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>> + ABSL_MUST_USE_RESULT Expr NewMemberCall(Function function, Target target, + Args args) { + return NewMemberCall(NextId(), std::move(function), std::move(target), + std::move(args)); + } + + using ExprFactory::NewListElement; + + template ...>>> + ABSL_MUST_USE_RESULT Expr NewList(Elements&&... elements) { + std::vector array; + array.reserve(sizeof...(Elements)); + (array.push_back(std::forward(elements)), ...); + return NewList(NextId(), std::move(array)); + } + + template ::value>> + ABSL_MUST_USE_RESULT Expr NewList(Elements elements) { + return NewList(NextId(), std::move(elements)); + } + + template ::value>, + typename = std::enable_if_t::value>> + ABSL_MUST_USE_RESULT StructExprField NewStructField(Name name, Value value, + bool optional = false) { + return NewStructField(NextId(), std::move(name), std::move(value), + optional); + } + + template ::value>, + typename = std::enable_if_t< + std::conjunction_v...>>> + ABSL_MUST_USE_RESULT Expr NewStruct(Name name, Fields&&... fields) { + std::vector array; + array.reserve(sizeof...(Fields)); + (array.push_back(std::forward(fields)), ...); + return NewStruct(NextId(), std::move(name), std::move(array)); + } + + template < + typename Name, typename Fields, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>> + ABSL_MUST_USE_RESULT Expr NewStruct(Name name, Fields fields) { + return NewStruct(NextId(), std::move(name), std::move(fields)); + } + + template ::value>, + typename = std::enable_if_t::value>> + ABSL_MUST_USE_RESULT MapExprEntry NewMapEntry(Key key, Value value, + bool optional = false) { + return NewMapEntry(NextId(), std::move(key), std::move(value), optional); + } + + template ...>>> + ABSL_MUST_USE_RESULT Expr NewMap(Entries&&... entries) { + std::vector array; + array.reserve(sizeof...(Entries)); + (array.push_back(std::forward(entries)), ...); + return NewMap(NextId(), std::move(array)); + } + + template ::value>> + ABSL_MUST_USE_RESULT Expr NewMap(Entries entries) { + return NewMap(NextId(), std::move(entries)); + } + + template ::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>> + ABSL_MUST_USE_RESULT Expr + NewComprehension(IterVar iter_var, IterRange iter_range, AccuVar accu_var, + AccuInit accu_init, LoopCondition loop_condition, + LoopStep loop_step, Result result) { + return NewComprehension(NextId(), std::move(iter_var), + std::move(iter_range), std::move(accu_var), + std::move(accu_init), std::move(loop_condition), + std::move(loop_step), std::move(result)); + } + + template ::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>> + ABSL_MUST_USE_RESULT Expr NewComprehension( + IterVar iter_var, IterVar2 iter_var2, IterRange iter_range, + AccuVar accu_var, AccuInit accu_init, LoopCondition loop_condition, + LoopStep loop_step, Result result) { + return NewComprehension(NextId(), std::move(iter_var), std::move(iter_var2), + std::move(iter_range), std::move(accu_var), + std::move(accu_init), std::move(loop_condition), + std::move(loop_step), std::move(result)); + } + + ABSL_MUST_USE_RESULT virtual Expr ReportError(absl::string_view message) = 0; + + ABSL_MUST_USE_RESULT virtual Expr ReportErrorAt( + const Expr& expr, absl::string_view message) = 0; + + protected: + using ExprFactory::AccuVarName; + using ExprFactory::NewAccuIdent; + using ExprFactory::NewBoolConst; + using ExprFactory::NewBytesConst; + using ExprFactory::NewCall; + using ExprFactory::NewComprehension; + using ExprFactory::NewConst; + using ExprFactory::NewDoubleConst; + using ExprFactory::NewIdent; + using ExprFactory::NewIntConst; + using ExprFactory::NewList; + using ExprFactory::NewMap; + using ExprFactory::NewMapEntry; + using ExprFactory::NewMemberCall; + using ExprFactory::NewNullConst; + using ExprFactory::NewPresenceTest; + using ExprFactory::NewSelect; + using ExprFactory::NewStringConst; + using ExprFactory::NewStruct; + using ExprFactory::NewStructField; + using ExprFactory::NewUintConst; + using ExprFactory::NewUnspecified; + + ABSL_MUST_USE_RESULT virtual ExprId NextId() = 0; + + ABSL_MUST_USE_RESULT virtual ExprId CopyId(ExprId id) = 0; + + ABSL_MUST_USE_RESULT ExprId CopyId(const Expr& expr) { + return CopyId(expr.id()); + } + + private: + friend class ParserMacroExprFactory; + friend class TestMacroExprFactory; + + explicit MacroExprFactory() = default; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_PARSER_MACRO_EXPR_FACTORY_H_ diff --git a/parser/macro_expr_factory_test.cc b/parser/macro_expr_factory_test.cc new file mode 100644 index 000000000..b95cbe16f --- /dev/null +++ b/parser/macro_expr_factory_test.cc @@ -0,0 +1,202 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "parser/macro_expr_factory.h" + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "common/expr.h" +#include "common/expr_factory.h" +#include "internal/testing.h" + +namespace cel { + +class TestMacroExprFactory final : public MacroExprFactory { + public: + TestMacroExprFactory() = default; + + ExprId id() const { return id_; } + + Expr ReportError(absl::string_view) override { + return NewUnspecified(NextId()); + } + + Expr ReportErrorAt(const Expr&, absl::string_view) override { + return NewUnspecified(NextId()); + } + + using MacroExprFactory::NewBind; + using MacroExprFactory::NewBoolConst; + using MacroExprFactory::NewCall; + using MacroExprFactory::NewComprehension; + using MacroExprFactory::NewIdent; + using MacroExprFactory::NewList; + using MacroExprFactory::NewListElement; + using MacroExprFactory::NewMap; + using MacroExprFactory::NewMapEntry; + using MacroExprFactory::NewMemberCall; + using MacroExprFactory::NewSelect; + using MacroExprFactory::NewStruct; + using MacroExprFactory::NewStructField; + using MacroExprFactory::NewUnspecified; + + protected: + ExprId NextId() override { return id_++; } + + ExprId CopyId(ExprId id) override { + if (id == 0) { + return 0; + } + return NextId(); + } + + private: + int64_t id_ = 1; +}; + +namespace { + +using ::testing::IsEmpty; + +TEST(MacroExprFactory, CopyUnspecified) { + TestMacroExprFactory factory; + EXPECT_EQ(factory.Copy(factory.NewUnspecified()), factory.NewUnspecified(2)); +} + +TEST(MacroExprFactory, CopyIdent) { + TestMacroExprFactory factory; + EXPECT_EQ(factory.Copy(factory.NewIdent("foo")), factory.NewIdent(2, "foo")); +} + +TEST(MacroExprFactory, CopyConst) { + TestMacroExprFactory factory; + EXPECT_EQ(factory.Copy(factory.NewBoolConst(true)), + factory.NewBoolConst(2, true)); +} + +TEST(MacroExprFactory, CopySelect) { + TestMacroExprFactory factory; + EXPECT_EQ(factory.Copy(factory.NewSelect(factory.NewIdent("foo"), "bar")), + factory.NewSelect(3, factory.NewIdent(4, "foo"), "bar")); +} + +TEST(MacroExprFactory, CopyCall) { + TestMacroExprFactory factory; + std::vector copied_args; + copied_args.reserve(1); + copied_args.push_back(factory.NewIdent(6, "baz")); + EXPECT_EQ(factory.Copy(factory.NewMemberCall("bar", factory.NewIdent("foo"), + factory.NewIdent("baz"))), + factory.NewMemberCall(4, "bar", factory.NewIdent(5, "foo"), + absl::MakeSpan(copied_args))); +} + +TEST(MacroExprFactory, CopyList) { + TestMacroExprFactory factory; + std::vector copied_elements; + copied_elements.reserve(1); + copied_elements.push_back(factory.NewListElement(factory.NewIdent(4, "foo"))); + EXPECT_EQ(factory.Copy(factory.NewList( + factory.NewListElement(factory.NewIdent("foo")))), + factory.NewList(3, absl::MakeSpan(copied_elements))); +} + +TEST(MacroExprFactory, CopyStruct) { + TestMacroExprFactory factory; + std::vector copied_fields; + copied_fields.reserve(1); + copied_fields.push_back( + factory.NewStructField(5, "bar", factory.NewIdent(6, "baz"))); + EXPECT_EQ(factory.Copy(factory.NewStruct( + "foo", factory.NewStructField("bar", factory.NewIdent("baz")))), + factory.NewStruct(4, "foo", absl::MakeSpan(copied_fields))); +} + +TEST(MacroExprFactory, CopyMap) { + TestMacroExprFactory factory; + std::vector copied_entries; + copied_entries.reserve(1); + copied_entries.push_back(factory.NewMapEntry(6, factory.NewIdent(7, "bar"), + factory.NewIdent(8, "baz"))); + EXPECT_EQ(factory.Copy(factory.NewMap(factory.NewMapEntry( + factory.NewIdent("bar"), factory.NewIdent("baz")))), + factory.NewMap(5, absl::MakeSpan(copied_entries))); +} + +TEST(MacroExprFactory, CopyComprehension) { + TestMacroExprFactory factory; + EXPECT_EQ( + factory.Copy(factory.NewComprehension( + "foo", factory.NewList(), "bar", factory.NewBoolConst(true), + factory.NewIdent("baz"), factory.NewIdent("foo"), + factory.NewIdent("bar"))), + factory.NewComprehension( + 7, "foo", factory.NewList(8, std::vector()), "bar", + factory.NewBoolConst(9, true), factory.NewIdent(10, "baz"), + factory.NewIdent(11, "foo"), factory.NewIdent(12, "bar"))); +} + +TEST(MacroExprFactory, NewBind) { + TestMacroExprFactory factory; + Expr bind_expr = factory.NewIdent(10, "x"); + Expr rest_expr = factory.NewIdent(20, "y"); + + auto next_id = [id = 100]() mutable { return id++; }; + + Expr expr = + factory.NewBind(next_id, "a", std::move(bind_expr), std::move(rest_expr)); + + EXPECT_EQ(expr.id(), 100); + ASSERT_TRUE(expr.has_comprehension_expr()); + + const auto& comp = expr.comprehension_expr(); + EXPECT_EQ(comp.iter_var(), "#unused"); + + ASSERT_TRUE(comp.has_iter_range()); + EXPECT_EQ(comp.iter_range().id(), 101); + EXPECT_EQ(comp.iter_range().kind_case(), ExprKindCase::kListExpr); + EXPECT_THAT(comp.iter_range().list_expr().elements(), IsEmpty()); + + EXPECT_EQ(comp.accu_var(), "a"); + + ASSERT_TRUE(comp.has_accu_init()); + Expr expected_bind_expr; + expected_bind_expr.set_id(10); + expected_bind_expr.mutable_ident_expr().set_name("x"); + EXPECT_EQ(comp.accu_init(), expected_bind_expr); + + ASSERT_TRUE(comp.has_loop_condition()); + EXPECT_EQ(comp.loop_condition().id(), 102); + EXPECT_EQ(comp.loop_condition().kind_case(), ExprKindCase::kConstant); + EXPECT_TRUE(comp.loop_condition().const_expr().has_bool_value()); + EXPECT_FALSE(comp.loop_condition().const_expr().bool_value()); + + ASSERT_TRUE(comp.has_loop_step()); + EXPECT_EQ(comp.loop_step().id(), 103); + EXPECT_EQ(comp.loop_step().kind_case(), ExprKindCase::kIdentExpr); + EXPECT_EQ(comp.loop_step().ident_expr().name(), "a"); + + ASSERT_TRUE(comp.has_result()); + Expr expected_rest_expr; + expected_rest_expr.set_id(20); + expected_rest_expr.mutable_ident_expr().set_name("y"); + EXPECT_EQ(comp.result(), expected_rest_expr); +} + +} // namespace +} // namespace cel diff --git a/parser/macro_registry.cc b/parser/macro_registry.cc new file mode 100644 index 000000000..d36761e87 --- /dev/null +++ b/parser/macro_registry.cc @@ -0,0 +1,87 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "parser/macro_registry.h" + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "parser/macro.h" + +namespace cel { + +absl::Status MacroRegistry::RegisterMacro(const Macro& macro) { + if (!RegisterMacroImpl(macro)) { + return absl::AlreadyExistsError( + absl::StrCat("macro already exists: ", macro.key())); + } + return absl::OkStatus(); +} + +absl::Status MacroRegistry::RegisterMacros(absl::Span macros) { + for (size_t i = 0; i < macros.size(); ++i) { + const auto& macro = macros[i]; + if (!RegisterMacroImpl(macro)) { + for (size_t j = 0; j < i; ++j) { + macros_.erase(macros[j].key()); + } + return absl::AlreadyExistsError( + absl::StrCat("macro already exists: ", macro.key())); + } + } + return absl::OkStatus(); +} + +absl::optional MacroRegistry::FindMacro(absl::string_view name, + size_t arg_count, + bool receiver_style) const { + // :: + if (name.empty() || absl::StrContains(name, ':')) { + return std::nullopt; + } + // Try argument count specific key first. + auto key = absl::StrCat(name, ":", arg_count, ":", + receiver_style ? "true" : "false"); + if (auto it = macros_.find(key); it != macros_.end()) { + return it->second; + } + // Next try variadic. + key = absl::StrCat(name, ":*:", receiver_style ? "true" : "false"); + if (auto it = macros_.find(key); it != macros_.end()) { + return it->second; + } + return std::nullopt; +} + +std::vector MacroRegistry::ListMacros() const { + std::vector macros; + macros.reserve(macros_.size()); + for (auto it = macros_.begin(); it != macros_.end(); ++it) { + macros.push_back(it->second); + } + return macros; +} + +bool MacroRegistry::RegisterMacroImpl(const Macro& macro) { + return macros_.insert(std::pair{macro.key(), macro}).second; +} + +} // namespace cel diff --git a/parser/macro_registry.h b/parser/macro_registry.h new file mode 100644 index 000000000..01a0634ef --- /dev/null +++ b/parser/macro_registry.h @@ -0,0 +1,59 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_PARSER_MACRO_REGISTRY_H_ +#define THIRD_PARTY_CEL_CPP_PARSER_MACRO_REGISTRY_H_ + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "parser/macro.h" + +namespace cel { + +class MacroRegistry final { + public: + MacroRegistry() = default; + + // Move-only. + MacroRegistry(MacroRegistry&&) = default; + MacroRegistry& operator=(MacroRegistry&&) = default; + + // Registers `macro`. + absl::Status RegisterMacro(const Macro& macro); + + // Registers all `macros`. If an error is encountered registering one, the + // rest are not registered and the error is returned. + absl::Status RegisterMacros(absl::Span macros); + + absl::optional FindMacro(absl::string_view name, size_t arg_count, + bool receiver_style) const; + + // Returns a copy of all registered macros. + std::vector ListMacros() const; + + private: + bool RegisterMacroImpl(const Macro& macro); + + absl::flat_hash_map macros_; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_PARSER_MACRO_REGISTRY_H_ diff --git a/parser/macro_registry_test.cc b/parser/macro_registry_test.cc new file mode 100644 index 000000000..db8a99ab2 --- /dev/null +++ b/parser/macro_registry_test.cc @@ -0,0 +1,44 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "parser/macro_registry.h" + +#include "absl/status/status.h" +#include "absl/types/optional.h" +#include "internal/testing.h" +#include "parser/macro.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::testing::Eq; +using ::testing::Ne; + +TEST(MacroRegistry, RegisterAndFind) { + MacroRegistry macros; + EXPECT_THAT(macros.RegisterMacro(HasMacro()), IsOk()); + EXPECT_THAT(macros.FindMacro("has", 1, false), Ne(std::nullopt)); +} + +TEST(MacroRegistry, RegisterRollsback) { + MacroRegistry macros; + EXPECT_THAT(macros.RegisterMacros({HasMacro(), AllMacro(), AllMacro()}), + StatusIs(absl::StatusCode::kAlreadyExists)); + EXPECT_THAT(macros.FindMacro("has", 1, false), Eq(std::nullopt)); +} + +} // namespace +} // namespace cel diff --git a/parser/options.h b/parser/options.h new file mode 100644 index 000000000..719bed454 --- /dev/null +++ b/parser/options.h @@ -0,0 +1,96 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_PARSER_OPTIONS_H_ +#define THIRD_PARTY_CEL_CPP_PARSER_OPTIONS_H_ + +#include "absl/base/attributes.h" +#include "parser/internal/options.h" + +namespace cel { + +// Options for configuring the limits and features of the parser. +struct ParserOptions final { + // Limit of the number of error recovery attempts made by the ANTLR parser + // when processing an input. This limit, when reached, will halt further + // parsing of the expression. + int error_recovery_limit = ::cel_parser_internal::kDefaultErrorRecoveryLimit; + + // Limit on the amount of recursive parse instructions permitted when building + // the abstract syntax tree for the expression. This prevents pathological + // inputs from causing stack overflows. + int max_recursion_depth = ::cel_parser_internal::kDefaultMaxRecursionDepth; + + // Limit on the number of codepoints in the input string which the parser will + // attempt to parse. + int expression_size_codepoint_limit = + ::cel_parser_internal::kExpressionSizeCodepointLimit; + + // Limit on the number of lookahead tokens to consume when attempting to + // recover from an error. + int error_recovery_token_lookahead_limit = + ::cel_parser_internal::kDefaultErrorRecoveryTokenLookaheadLimit; + + // Add macro calls to macro_calls list in source_info. + bool add_macro_calls = ::cel_parser_internal::kDefaultAddMacroCalls; + + // Enable support for optional syntax. + bool enable_optional_syntax = false; + + // Disable standard macros (has, all, exists, exists_one, filter, map). + bool disable_standard_macros = false; + + // Deprecated: The builtin and extension macros now always use the new + // accumulator variable name. + // This option has no effect. + bool enable_hidden_accumulator_var = true; + + // Enables support for identifier quoting syntax: + // "message.`skewer-case-field`" + // + // Limited to field specifiers in select and message creation, + // enabled by default + bool enable_quoted_identifiers = true; + + // Enables parsing logical AND & OR operators as a single flat variadic call + // instead of a balanced/nested binary AST structure. + bool enable_variadic_logical_operators = false; +}; + +} // namespace cel + +namespace google::api::expr::parser { + +using ParserOptions = ::cel::ParserOptions; + +ABSL_DEPRECATED("Use ParserOptions().error_recovery_limit instead.") +inline constexpr int kDefaultErrorRecoveryLimit = + ::cel_parser_internal::kDefaultErrorRecoveryLimit; +ABSL_DEPRECATED("Use ParserOptions().max_recursion_depth instead.") +inline constexpr int kDefaultMaxRecursionDepth = + ::cel_parser_internal::kDefaultMaxRecursionDepth; +ABSL_DEPRECATED("Use ParserOptions().expression_size_codepoint_limit instead.") +inline constexpr int kExpressionSizeCodepointLimit = + ::cel_parser_internal::kExpressionSizeCodepointLimit; +ABSL_DEPRECATED( + "Use ParserOptions().error_recovery_token_lookahead_limit instead.") +inline constexpr int kDefaultErrorRecoveryTokenLookaheadLimit = + ::cel_parser_internal::kDefaultErrorRecoveryTokenLookaheadLimit; +ABSL_DEPRECATED("Use ParserOptions().add_macro_calls instead.") +inline constexpr bool kDefaultAddMacroCalls = + ::cel_parser_internal::kDefaultAddMacroCalls; + +} // namespace google::api::expr::parser + +#endif // THIRD_PARTY_CEL_CPP_PARSER_OPTIONS_H_ diff --git a/parser/parser.cc b/parser/parser.cc index 35b804f42..24b4ca079 100644 --- a/parser/parser.cc +++ b/parser/parser.cc @@ -1,43 +1,1565 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #include "parser/parser.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "cel/expr/syntax.pb.h" +#include "absl/base/macros.h" +#include "absl/base/optimization.h" +#include "absl/cleanup/cleanup.h" +#include "absl/container/btree_map.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/functional/overload.h" +#include "absl/log/absl_check.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/match.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_replace.h" +#include "absl/strings/string_view.h" #include "absl/types/optional.h" -#include "parser/cel_grammar.inc/cel_grammar/CelLexer.h" -#include "parser/cel_grammar.inc/cel_grammar/CelParser.h" -#include "parser/source_factory.h" -#include "parser/visitor.h" +#include "absl/types/span.h" +#include "absl/types/variant.h" #include "antlr4-runtime.h" +#include "common/ast.h" +#include "common/ast/expr_proto.h" +#include "common/ast/source_info_proto.h" +#include "common/constant.h" +#include "common/expr_factory.h" +#include "common/operators.h" +#include "common/source.h" +#include "internal/lexis.h" +#include "internal/status_macros.h" +#include "internal/strings.h" +#include "internal/utf8.h" +#pragma push_macro("IN") +#undef IN +#include "parser/internal/CelBaseVisitor.h" +#include "parser/internal/CelLexer.h" +#include "parser/internal/CelParser.h" +#pragma pop_macro("IN") +#include "parser/macro.h" +#include "parser/macro_expr_factory.h" +#include "parser/macro_registry.h" +#include "parser/options.h" +#include "parser/parser_interface.h" +#include "parser/source_factory.h" + +namespace google::api::expr::parser { +namespace { +class ParserVisitor; +} +} // namespace google::api::expr::parser + +namespace cel { + +namespace { + +constexpr const char kHiddenAccumulatorVariableName[] = "@result"; + +std::any ExprPtrToAny(std::unique_ptr&& expr) { + return std::make_any(expr.release()); +} + +std::any ExprToAny(Expr&& expr) { + return ExprPtrToAny(std::make_unique(std::move(expr))); +} + +std::unique_ptr ExprPtrFromAny(std::any&& any) { + return absl::WrapUnique(std::any_cast(std::move(any))); +} + +Expr ExprFromAny(std::any&& any) { + auto expr = ExprPtrFromAny(std::move(any)); + return std::move(*expr); +} + +struct ParserError { + std::string message; + SourceRange range; +}; + +std::string DisplayParserError(const cel::Source& source, + SourceLocation location, + absl::string_view message) { + return absl::StrCat(absl::StrFormat("ERROR: %s:%zu:%zu: %s", + source.description(), location.line, + // add one to the 0-based column + location.column + 1, message), + source.DisplayErrorLocation(location)); +} + +int32_t PositiveOrMax(int32_t value) { + return value >= 0 ? value : std::numeric_limits::max(); +} + +SourceRange SourceRangeFromToken(const antlr4::Token* token) { + SourceRange range; + if (token != nullptr) { + if (auto start = token->getStartIndex(); start != INVALID_INDEX) { + range.begin = static_cast(start); + } + if (auto end = token->getStopIndex(); end != INVALID_INDEX) { + range.end = static_cast(end + 1); + } + } + return range; +} + +SourceRange SourceRangeFromParserRuleContext( + const antlr4::ParserRuleContext* context) { + SourceRange range; + if (context != nullptr) { + if (auto start = context->getStart() != nullptr + ? context->getStart()->getStartIndex() + : INVALID_INDEX; + start != INVALID_INDEX) { + range.begin = static_cast(start); + } + if (auto end = context->getStop() != nullptr + ? context->getStop()->getStopIndex() + : INVALID_INDEX; + end != INVALID_INDEX) { + range.end = static_cast(end + 1); + } + } + return range; +} + +} // namespace + +class ParserMacroExprFactory final : public MacroExprFactory { + public: + explicit ParserMacroExprFactory(const cel::Source& source) + : source_(source) {} + + void BeginMacro(SourceRange macro_position) { + macro_position_ = macro_position; + } + + void EndMacro() { macro_position_ = SourceRange{}; } + + Expr ReportError(absl::string_view message) override { + return ReportError(macro_position_, message); + } + + Expr ReportError(int64_t expr_id, absl::string_view message) { + return ReportError(GetSourceRange(expr_id), message); + } + + Expr ReportError(SourceRange range, absl::string_view message) { + ++error_count_; + if (errors_.size() <= 100) { + errors_.push_back(ParserError{std::string(message), range}); + } + return NewUnspecified(NextId(range)); + } + + Expr ReportErrorAt(const Expr& expr, absl::string_view message) override { + return ReportError(GetSourceRange(expr.id()), message); + } + + SourceRange GetSourceRange(int64_t id) const { + if (auto it = positions_.find(id); it != positions_.end()) { + return it->second; + } + return SourceRange{}; + } + + int64_t NextId(const SourceRange& range) { + auto id = expr_id_++; + if (range.begin != -1 || range.end != -1) { + positions_.insert(std::pair{id, range}); + } + return id; + } + + bool HasErrors() const { return error_count_ != 0; } + + std::vector CollectIssues() { + // Errors are collected as they are encountered, not by their location + // within the source. To have a more stable error message as implementation + // details change, we sort the collected errors by their source location + // first. + std::stable_sort( + errors_.begin(), errors_.end(), + [](const ParserError& lhs, const ParserError& rhs) -> bool { + auto lhs_begin = PositiveOrMax(lhs.range.begin); + auto lhs_end = PositiveOrMax(lhs.range.end); + auto rhs_begin = PositiveOrMax(rhs.range.begin); + auto rhs_end = PositiveOrMax(rhs.range.end); + return lhs_begin < rhs_begin || + (lhs_begin == rhs_begin && lhs_end < rhs_end); + }); + // Build the summary error message using the sorted errors. + bool errors_truncated = error_count_ > 100; + std::vector issues; + issues.reserve( + errors_.size() + + errors_truncated); // Reserve space for the transform and an + // additional element when truncation occurs. + std::transform( + errors_.begin(), errors_.end(), std::back_inserter(issues), + [this](const ParserError& error) { + auto location = + source_.GetLocation(error.range.begin).value_or(SourceLocation{}); + return cel::ParseIssue(location, error.message); + }); + if (errors_truncated) { + issues.push_back(cel::ParseIssue( + absl::StrCat(error_count_ - 100, " more errors were truncated."))); + } + return issues; + } + + void AddMacroCall(int64_t macro_id, absl::string_view function, + absl::optional target, std::vector arguments) { + macro_calls_.insert( + {macro_id, target.has_value() + ? NewMemberCall(0, function, std::move(*target), + std::move(arguments)) + : NewCall(0, function, std::move(arguments))}); + } + + Expr BuildMacroCallArg(const Expr& expr) { + if (auto it = macro_calls_.find(expr.id()); it != macro_calls_.end()) { + return NewUnspecified(expr.id()); + } + return absl::visit( + absl::Overload( + [this, &expr](const UnspecifiedExpr&) -> Expr { + return NewUnspecified(expr.id()); + }, + [this, &expr](const Constant& const_expr) -> Expr { + return NewConst(expr.id(), const_expr); + }, + [this, &expr](const IdentExpr& ident_expr) -> Expr { + return NewIdent(expr.id(), ident_expr.name()); + }, + [this, &expr](const SelectExpr& select_expr) -> Expr { + return select_expr.test_only() + ? NewPresenceTest( + expr.id(), + BuildMacroCallArg(select_expr.operand()), + select_expr.field()) + : NewSelect(expr.id(), + BuildMacroCallArg(select_expr.operand()), + select_expr.field()); + }, + [this, &expr](const CallExpr& call_expr) -> Expr { + std::vector macro_arguments; + macro_arguments.reserve(call_expr.args().size()); + for (const auto& argument : call_expr.args()) { + macro_arguments.push_back(BuildMacroCallArg(argument)); + } + absl::optional macro_target; + if (call_expr.has_target()) { + macro_target = BuildMacroCallArg(call_expr.target()); + } + return macro_target.has_value() + ? NewMemberCall(expr.id(), call_expr.function(), + std::move(*macro_target), + std::move(macro_arguments)) + : NewCall(expr.id(), call_expr.function(), + std::move(macro_arguments)); + }, + [this, &expr](const ListExpr& list_expr) -> Expr { + std::vector macro_elements; + macro_elements.reserve(list_expr.elements().size()); + for (const auto& element : list_expr.elements()) { + auto& cloned_element = macro_elements.emplace_back(); + if (element.has_expr()) { + cloned_element.set_expr(BuildMacroCallArg(element.expr())); + } + cloned_element.set_optional(element.optional()); + } + return NewList(expr.id(), std::move(macro_elements)); + }, + [this, &expr](const StructExpr& struct_expr) -> Expr { + std::vector macro_fields; + macro_fields.reserve(struct_expr.fields().size()); + for (const auto& field : struct_expr.fields()) { + auto& macro_field = macro_fields.emplace_back(); + macro_field.set_id(field.id()); + macro_field.set_name(field.name()); + macro_field.set_value(BuildMacroCallArg(field.value())); + macro_field.set_optional(field.optional()); + } + return NewStruct(expr.id(), struct_expr.name(), + std::move(macro_fields)); + }, + [this, &expr](const MapExpr& map_expr) -> Expr { + std::vector macro_entries; + macro_entries.reserve(map_expr.entries().size()); + for (const auto& entry : map_expr.entries()) { + auto& macro_entry = macro_entries.emplace_back(); + macro_entry.set_id(entry.id()); + macro_entry.set_key(BuildMacroCallArg(entry.key())); + macro_entry.set_value(BuildMacroCallArg(entry.value())); + macro_entry.set_optional(entry.optional()); + } + return NewMap(expr.id(), std::move(macro_entries)); + }, + [this, &expr](const ComprehensionExpr& comprehension_expr) -> Expr { + return NewComprehension( + expr.id(), comprehension_expr.iter_var(), + BuildMacroCallArg(comprehension_expr.iter_range()), + comprehension_expr.accu_var(), + BuildMacroCallArg(comprehension_expr.accu_init()), + BuildMacroCallArg(comprehension_expr.loop_condition()), + BuildMacroCallArg(comprehension_expr.loop_step()), + BuildMacroCallArg(comprehension_expr.result())); + }), + expr.kind()); + } + + using ExprFactory::NewBoolConst; + using ExprFactory::NewBytesConst; + using ExprFactory::NewCall; + using ExprFactory::NewComprehension; + using ExprFactory::NewConst; + using ExprFactory::NewDoubleConst; + using ExprFactory::NewIdent; + using ExprFactory::NewIntConst; + using ExprFactory::NewList; + using ExprFactory::NewListElement; + using ExprFactory::NewMap; + using ExprFactory::NewMapEntry; + using ExprFactory::NewMemberCall; + using ExprFactory::NewNullConst; + using ExprFactory::NewPresenceTest; + using ExprFactory::NewSelect; + using ExprFactory::NewStringConst; + using ExprFactory::NewStruct; + using ExprFactory::NewStructField; + using ExprFactory::NewUintConst; + using ExprFactory::NewUnspecified; + + const absl::btree_map& positions() const { + return positions_; + } + + const absl::flat_hash_map& macro_calls() const { + return macro_calls_; + } + + absl::flat_hash_map release_macro_calls() { + using std::swap; + absl::flat_hash_map result; + swap(result, macro_calls_); + return result; + } + + void EraseId(ExprId id) { + positions_.erase(id); + if (expr_id_ == id + 1) { + --expr_id_; + } + } + + protected: + int64_t NextId() override { return NextId(macro_position_); } -namespace google { -namespace api { -namespace expr { -namespace parser { + int64_t CopyId(int64_t id) override { + if (id == 0) { + return 0; + } + return NextId(GetSourceRange(id)); + } -using antlr4::ANTLRInputStream; -using antlr4::CommonTokenStream; -using antlr4::ParseCancellationException; -using antlr4::ParserRuleContext; + private: + int64_t expr_id_ = 1; + absl::btree_map positions_; + absl::flat_hash_map macro_calls_; + std::vector errors_; + size_t error_count_ = 0; + const Source& source_; + SourceRange macro_position_; +}; -using antlr4::tree::ErrorNode; -using antlr4::tree::TerminalNode; +} // namespace cel -using google::api::expr::v1alpha1::Expr; -using google::api::expr::v1alpha1::ParsedExpr; +namespace google::api::expr::parser { namespace { +using ::antlr4::CharStream; +using ::antlr4::CommonTokenStream; +using ::antlr4::DefaultErrorStrategy; +using ::antlr4::ParseCancellationException; +using ::antlr4::Parser; +using ::antlr4::ParserRuleContext; +using ::antlr4::Token; +using ::antlr4::misc::IntervalSet; +using ::antlr4::tree::ErrorNode; +using ::antlr4::tree::ParseTreeListener; +using ::antlr4::tree::TerminalNode; +using ::cel::Expr; +using ::cel::ExprFromAny; +using ::cel::ExprKind; +using ::cel::ExprToAny; +using ::cel::IdentExpr; +using ::cel::ListExprElement; +using ::cel::MapExprEntry; +using ::cel::SelectExpr; +using ::cel::SourceRangeFromParserRuleContext; +using ::cel::SourceRangeFromToken; +using ::cel::StructExprField; +using ::cel_parser_internal::CelBaseVisitor; +using ::cel_parser_internal::CelLexer; +using ::cel_parser_internal::CelParser; +using common::CelOperator; +using common::ReverseLookupOperator; +using ::cel::expr::ParsedExpr; + +class CodePointStream final : public CharStream { + public: + CodePointStream(cel::SourceContentView buffer, absl::string_view source_name) + : buffer_(buffer), + source_name_(source_name), + size_(buffer_.size()), + index_(0) {} + + void consume() override { + if (ABSL_PREDICT_FALSE(index_ >= size_)) { + ABSL_ASSERT(LA(1) == IntStream::EOF); + throw antlr4::IllegalStateException("cannot consume EOF"); + } + index_++; + } + + size_t LA(ptrdiff_t i) override { + if (ABSL_PREDICT_FALSE(i == 0)) { + return 0; + } + auto p = static_cast(index_); + if (i < 0) { + i++; + if (p + i - 1 < 0) { + return IntStream::EOF; + } + } + if (p + i - 1 >= static_cast(size_)) { + return IntStream::EOF; + } + return buffer_.at(static_cast(p + i - 1)); + } + + ptrdiff_t mark() override { return -1; } + + void release(ptrdiff_t marker) override {} + + size_t index() override { return index_; } + + void seek(size_t index) override { index_ = std::min(index, size_); } + + size_t size() override { return size_; } + + std::string getSourceName() const override { + return source_name_.empty() ? IntStream::UNKNOWN_SOURCE_NAME + : std::string(source_name_); + } + + std::string getText(const antlr4::misc::Interval& interval) override { + if (ABSL_PREDICT_FALSE(interval.a < 0 || interval.b < 0)) { + return std::string(); + } + size_t start = static_cast(interval.a); + if (ABSL_PREDICT_FALSE(start >= size_)) { + return std::string(); + } + size_t stop = static_cast(interval.b); + if (ABSL_PREDICT_FALSE(stop >= size_)) { + stop = size_ - 1; + } + return buffer_.ToString(static_cast(start), + static_cast(stop) + 1); + } + + std::string toString() const override { return buffer_.ToString(); } + + private: + cel::SourceContentView const buffer_; + const absl::string_view source_name_; + const size_t size_; + size_t index_; +}; + +// Scoped helper for incrementing the parse recursion count. +// Increments on creation, decrements on destruction (stack unwind). +class ScopedIncrement final { + public: + explicit ScopedIncrement(int& recursion_depth) + : recursion_depth_(recursion_depth) { + ++recursion_depth_; + } + + ~ScopedIncrement() { --recursion_depth_; } + + private: + int& recursion_depth_; +}; + +// balancer performs tree balancing on operators whose arguments are of equal +// precedence. +// +// The purpose of the balancer is to ensure a compact serialization format for +// the logical &&, || operators which have a tendency to create long DAGs which +// are skewed in one direction. Since the operators are commutative re-ordering +// the terms *must not* affect the evaluation result. +// +// Based on code from //third_party/cel/go/parser/helper.go +class ExpressionBalancer final { + public: + ExpressionBalancer(cel::ParserMacroExprFactory& factory, std::string function, + Expr expr); + + // addTerm adds an operation identifier and term to the set of terms to be + // balanced. + void AddTerm(int64_t op, Expr term); + + // balance creates a balanced tree from the sub-terms and returns the final + // Expr value. + Expr Balance(bool enable_variadic = false); + + private: + // balancedTree recursively balances the terms provided to a commutative + // operator. + Expr BalancedTree(int lo, int hi); + + private: + cel::ParserMacroExprFactory& factory_; + std::string function_; + std::vector terms_; + std::vector ops_; +}; + +ExpressionBalancer::ExpressionBalancer(cel::ParserMacroExprFactory& factory, + std::string function, Expr expr) + : factory_(factory), function_(std::move(function)) { + terms_.push_back(std::move(expr)); +} + +void ExpressionBalancer::AddTerm(int64_t op, Expr term) { + terms_.push_back(std::move(term)); + ops_.push_back(op); +} + +Expr ExpressionBalancer::Balance(bool enable_variadic) { + if (terms_.size() == 1) { + return std::move(terms_[0]); + } + if (enable_variadic) { + return factory_.NewCall(ops_[0], function_, std::move(terms_)); + } + return BalancedTree(0, ops_.size() - 1); +} + +Expr ExpressionBalancer::BalancedTree(int lo, int hi) { + int mid = (lo + hi + 1) / 2; + + std::vector arguments; + arguments.reserve(2); + + if (mid == lo) { + arguments.push_back(std::move(terms_[mid])); + } else { + arguments.push_back(BalancedTree(lo, mid - 1)); + } + + if (mid == hi) { + arguments.push_back(std::move(terms_[mid + 1])); + } else { + arguments.push_back(BalancedTree(mid + 1, hi)); + } + return factory_.NewCall(ops_[mid], function_, std::move(arguments)); +} + +std::string FormatIssues(const cel::Source& source, + absl::Span issues) { + return absl::StrJoin( + issues, "\n", [&source](std::string* out, const cel::ParseIssue& issue) { + absl::StrAppend(out, cel::DisplayParserError(source, issue.location(), + issue.message())); + }); +} + +class ParserVisitor final : public CelBaseVisitor, + public antlr4::BaseErrorListener { + public: + ParserVisitor(const cel::Source& source, int max_recursion_depth, + const cel::MacroRegistry& macro_registry, + bool add_macro_calls = false, + bool enable_optional_syntax = false, + bool enable_quoted_identifiers = false, + bool enable_variadic_logical_operators = false) + : source_(source), + factory_(source_), + macro_registry_(macro_registry), + recursion_depth_(0), + max_recursion_depth_(max_recursion_depth), + add_macro_calls_(add_macro_calls), + enable_optional_syntax_(enable_optional_syntax), + enable_quoted_identifiers_(enable_quoted_identifiers), + enable_variadic_logical_operators_(enable_variadic_logical_operators) {} + + ~ParserVisitor() override = default; + + std::any visit(antlr4::tree::ParseTree* tree) override; + + std::any visitStart(CelParser::StartContext* ctx) override; + std::any visitExpr(CelParser::ExprContext* ctx) override; + std::any visitConditionalOr(CelParser::ConditionalOrContext* ctx) override; + std::any visitConditionalAnd(CelParser::ConditionalAndContext* ctx) override; + std::any visitRelation(CelParser::RelationContext* ctx) override; + std::any visitCalc(CelParser::CalcContext* ctx) override; + std::any visitUnary(CelParser::UnaryContext* ctx); + std::any visitLogicalNot(CelParser::LogicalNotContext* ctx) override; + std::any visitNegate(CelParser::NegateContext* ctx) override; + std::any visitSelect(CelParser::SelectContext* ctx) override; + std::any visitMemberCall(CelParser::MemberCallContext* ctx) override; + std::any visitIndex(CelParser::IndexContext* ctx) override; + std::any visitCreateMessage(CelParser::CreateMessageContext* ctx) override; + std::any visitFieldInitializerList( + CelParser::FieldInitializerListContext* ctx) override; + std::vector visitFields( + CelParser::FieldInitializerListContext* ctx); + std::any visitGlobalCall(CelParser::GlobalCallContext* ctx) override; + std::any visitIdent(CelParser::IdentContext* ctx) override; + std::any visitNested(CelParser::NestedContext* ctx) override; + std::any visitCreateList(CelParser::CreateListContext* ctx) override; + std::vector visitList(CelParser::ListInitContext* ctx); + std::vector visitList(CelParser::ExprListContext* ctx); + std::any visitCreateMap(CelParser::CreateMapContext* ctx) override; + std::any visitConstantLiteral( + CelParser::ConstantLiteralContext* ctx) override; + std::any visitPrimaryExpr(CelParser::PrimaryExprContext* ctx) override; + std::any visitMemberExpr(CelParser::MemberExprContext* ctx) override; + + std::any visitMapInitializerList( + CelParser::MapInitializerListContext* ctx) override; + std::vector visitEntries( + CelParser::MapInitializerListContext* ctx); + std::any visitInt(CelParser::IntContext* ctx) override; + std::any visitUint(CelParser::UintContext* ctx) override; + std::any visitDouble(CelParser::DoubleContext* ctx) override; + std::any visitString(CelParser::StringContext* ctx) override; + std::any visitBytes(CelParser::BytesContext* ctx) override; + std::any visitBoolTrue(CelParser::BoolTrueContext* ctx) override; + std::any visitBoolFalse(CelParser::BoolFalseContext* ctx) override; + std::any visitNull(CelParser::NullContext* ctx) override; + // Note: this is destructive and intended to be called after the parse is + // finished. + cel::SourceInfo GetSourceInfo(); + EnrichedSourceInfo enriched_source_info() const; + void syntaxError(antlr4::Recognizer* recognizer, + antlr4::Token* offending_symbol, size_t line, size_t col, + const std::string& msg, std::exception_ptr e) override; + bool HasErrored() const; + + std::vector CollectIssues(); + + private: + template + Expr GlobalCallOrMacro(int64_t expr_id, absl::string_view function, + Args&&... args) { + std::vector arguments; + arguments.reserve(sizeof...(Args)); + (arguments.push_back(std::forward(args)), ...); + return GlobalCallOrMacroImpl(expr_id, function, std::move(arguments)); + } + + Expr GlobalCallOrMacroImpl(int64_t expr_id, absl::string_view function, + std::vector args); + Expr ReceiverCallOrMacroImpl(int64_t expr_id, absl::string_view function, + Expr target, std::vector args); + std::string ExtractQualifiedName(antlr4::ParserRuleContext* ctx, + const Expr& e); + + std::string NormalizeIdentifier(CelParser::EscapeIdentContext* ctx); + // Attempt to unnest parse context. + // + // Walk the parse tree to the first complex term to reduce recursive depth in + // the visit* calls. + antlr4::tree::ParseTree* UnnestContext(antlr4::tree::ParseTree* tree); + + private: + const cel::Source& source_; + cel::ParserMacroExprFactory factory_; + const cel::MacroRegistry& macro_registry_; + int recursion_depth_; + const int max_recursion_depth_; + const bool add_macro_calls_; + const bool enable_optional_syntax_; + const bool enable_quoted_identifiers_; + const bool enable_variadic_logical_operators_; +}; + +template ::value>> +T* tree_as(antlr4::tree::ParseTree* tree) { + return dynamic_cast(tree); +} + +std::any ParserVisitor::visit(antlr4::tree::ParseTree* tree) { + ScopedIncrement inc(recursion_depth_); + if (recursion_depth_ > max_recursion_depth_) { + return ExprToAny(factory_.ReportError( + absl::StrFormat("Exceeded max recursion depth of %d when parsing.", + max_recursion_depth_))); + } + tree = UnnestContext(tree); + if (auto* ctx = tree_as(tree)) { + return visitStart(ctx); + } else if (auto* ctx = tree_as(tree)) { + return visitExpr(ctx); + } else if (auto* ctx = tree_as(tree)) { + return visitConditionalAnd(ctx); + } else if (auto* ctx = tree_as(tree)) { + return visitConditionalOr(ctx); + } else if (auto* ctx = tree_as(tree)) { + return visitRelation(ctx); + } else if (auto* ctx = tree_as(tree)) { + return visitCalc(ctx); + } else if (auto* ctx = tree_as(tree)) { + return visitLogicalNot(ctx); + } else if (auto* ctx = tree_as(tree)) { + return visitPrimaryExpr(ctx); + } else if (auto* ctx = tree_as(tree)) { + return visitMemberExpr(ctx); + } else if (auto* ctx = tree_as(tree)) { + return visitSelect(ctx); + } else if (auto* ctx = tree_as(tree)) { + return visitMemberCall(ctx); + } else if (auto* ctx = tree_as(tree)) { + return visitMapInitializerList(ctx); + } else if (auto* ctx = tree_as(tree)) { + return visitNegate(ctx); + } else if (auto* ctx = tree_as(tree)) { + return visitIndex(ctx); + } else if (auto* ctx = tree_as(tree)) { + return visitUnary(ctx); + } else if (auto* ctx = tree_as(tree)) { + return visitCreateList(ctx); + } else if (auto* ctx = tree_as(tree)) { + return visitCreateMessage(ctx); + } else if (auto* ctx = tree_as(tree)) { + return visitCreateMap(ctx); + } + + if (tree) { + return ExprToAny( + factory_.ReportError(SourceRangeFromParserRuleContext( + tree_as(tree)), + "unknown parsetree type")); + } + return ExprToAny(factory_.ReportError("<> parsetree")); +} + +std::any ParserVisitor::visitPrimaryExpr(CelParser::PrimaryExprContext* pctx) { + CelParser::PrimaryContext* primary = pctx->primary(); + if (auto* ctx = tree_as(primary)) { + return visitNested(ctx); + } else if (auto* ctx = tree_as(primary)) { + return visitIdent(ctx); + } else if (auto* ctx = tree_as(primary)) { + return visitGlobalCall(ctx); + } else if (auto* ctx = tree_as(primary)) { + return visitCreateList(ctx); + } else if (auto* ctx = tree_as(primary)) { + return visitCreateMap(ctx); + } else if (auto* ctx = tree_as(primary)) { + return visitCreateMessage(ctx); + } else if (auto* ctx = tree_as(primary)) { + return visitConstantLiteral(ctx); + } + if (factory_.HasErrors()) { + // ANTLR creates PrimaryContext rather than a derived class during certain + // error conditions. This is odd, but we ignore it as we already have errors + // that occurred. + return ExprToAny(factory_.NewUnspecified(factory_.NextId({}))); + } + return ExprToAny(factory_.ReportError(SourceRangeFromParserRuleContext(pctx), + "invalid primary expression")); +} + +std::any ParserVisitor::visitMemberExpr(CelParser::MemberExprContext* mctx) { + CelParser::MemberContext* member = mctx->member(); + if (auto* ctx = tree_as(member)) { + return visitPrimaryExpr(ctx); + } else if (auto* ctx = tree_as(member)) { + return visitSelect(ctx); + } else if (auto* ctx = tree_as(member)) { + return visitMemberCall(ctx); + } else if (auto* ctx = tree_as(member)) { + return visitIndex(ctx); + } + return ExprToAny(factory_.ReportError(SourceRangeFromParserRuleContext(mctx), + "unsupported simple expression")); +} + +std::any ParserVisitor::visitStart(CelParser::StartContext* ctx) { + return visit(ctx->expr()); +} + +antlr4::tree::ParseTree* ParserVisitor::UnnestContext( + antlr4::tree::ParseTree* tree) { + antlr4::tree::ParseTree* last = nullptr; + while (tree != last) { + last = tree; + + if (auto* ctx = tree_as(tree)) { + tree = ctx->expr(); + } + + if (auto* ctx = tree_as(tree)) { + if (ctx->op != nullptr) { + return ctx; + } + tree = ctx->e; + } + + if (auto* ctx = tree_as(tree)) { + if (!ctx->ops.empty()) { + return ctx; + } + tree = ctx->e; + } + + if (auto* ctx = tree_as(tree)) { + if (!ctx->ops.empty()) { + return ctx; + } + tree = ctx->e; + } + + if (auto* ctx = tree_as(tree)) { + if (ctx->calc() == nullptr) { + return ctx; + } + tree = ctx->calc(); + } + + if (auto* ctx = tree_as(tree)) { + if (ctx->unary() == nullptr) { + return ctx; + } + tree = ctx->unary(); + } + + if (auto* ctx = tree_as(tree)) { + tree = ctx->member(); + } + + if (auto* ctx = tree_as(tree)) { + if (auto* nested = tree_as(ctx->primary())) { + tree = nested->e; + } else { + return ctx; + } + } + } + + return tree; +} + +std::any ParserVisitor::visitExpr(CelParser::ExprContext* ctx) { + auto result = ExprFromAny(visit(ctx->e)); + if (!ctx->op) { + return ExprToAny(std::move(result)); + } + std::vector arguments; + arguments.reserve(3); + arguments.push_back(std::move(result)); + int64_t op_id = factory_.NextId(SourceRangeFromToken(ctx->op)); + arguments.push_back(ExprFromAny(visit(ctx->e1))); + arguments.push_back(ExprFromAny(visit(ctx->e2))); + + return ExprToAny( + factory_.NewCall(op_id, CelOperator::CONDITIONAL, std::move(arguments))); +} + +std::any ParserVisitor::visitConditionalOr( + CelParser::ConditionalOrContext* ctx) { + auto result = ExprFromAny(visit(ctx->e)); + if (ctx->ops.empty()) { + return ExprToAny(std::move(result)); + } + ExpressionBalancer b(factory_, CelOperator::LOGICAL_OR, std::move(result)); + for (size_t i = 0; i < ctx->ops.size(); ++i) { + auto op = ctx->ops[i]; + if (i >= ctx->e1.size()) { + return ExprToAny( + factory_.ReportError(SourceRangeFromParserRuleContext(ctx), + "unexpected character, wanted '||'")); + } + auto next = ExprFromAny(visit(ctx->e1[i])); + int64_t op_id = factory_.NextId(SourceRangeFromToken(op)); + b.AddTerm(op_id, std::move(next)); + } + return ExprToAny(b.Balance(enable_variadic_logical_operators_)); +} + +std::any ParserVisitor::visitConditionalAnd( + CelParser::ConditionalAndContext* ctx) { + auto result = ExprFromAny(visit(ctx->e)); + if (ctx->ops.empty()) { + return ExprToAny(std::move(result)); + } + ExpressionBalancer b(factory_, CelOperator::LOGICAL_AND, std::move(result)); + for (size_t i = 0; i < ctx->ops.size(); ++i) { + auto op = ctx->ops[i]; + if (i >= ctx->e1.size()) { + return ExprToAny( + factory_.ReportError(SourceRangeFromParserRuleContext(ctx), + "unexpected character, wanted '&&'")); + } + auto next = ExprFromAny(visit(ctx->e1[i])); + int64_t op_id = factory_.NextId(SourceRangeFromToken(op)); + b.AddTerm(op_id, std::move(next)); + } + return ExprToAny(b.Balance(enable_variadic_logical_operators_)); +} + +std::any ParserVisitor::visitRelation(CelParser::RelationContext* ctx) { + if (ctx->calc()) { + return visit(ctx->calc()); + } + std::string op_text; + if (ctx->op) { + op_text = ctx->op->getText(); + } + auto op = ReverseLookupOperator(op_text); + if (op) { + auto lhs = ExprFromAny(visit(ctx->relation(0))); + int64_t op_id = factory_.NextId(SourceRangeFromToken(ctx->op)); + auto rhs = ExprFromAny(visit(ctx->relation(1))); + return ExprToAny( + GlobalCallOrMacro(op_id, *op, std::move(lhs), std::move(rhs))); + } + return ExprToAny(factory_.ReportError(SourceRangeFromParserRuleContext(ctx), + "operator not found")); +} + +std::any ParserVisitor::visitCalc(CelParser::CalcContext* ctx) { + if (ctx->unary()) { + return visit(ctx->unary()); + } + std::string op_text; + if (ctx->op) { + op_text = ctx->op->getText(); + } + auto op = ReverseLookupOperator(op_text); + if (op) { + auto lhs = ExprFromAny(visit(ctx->calc(0))); + int64_t op_id = factory_.NextId(SourceRangeFromToken(ctx->op)); + auto rhs = ExprFromAny(visit(ctx->calc(1))); + return ExprToAny( + GlobalCallOrMacro(op_id, *op, std::move(lhs), std::move(rhs))); + } + return ExprToAny(factory_.ReportError(SourceRangeFromParserRuleContext(ctx), + "operator not found")); +} + +std::any ParserVisitor::visitUnary(CelParser::UnaryContext* ctx) { + return ExprToAny(factory_.NewStringConst( + factory_.NextId(SourceRangeFromParserRuleContext(ctx)), "<>")); +} + +std::any ParserVisitor::visitLogicalNot(CelParser::LogicalNotContext* ctx) { + if (ctx->ops.size() % 2 == 0) { + return visit(ctx->member()); + } + int64_t op_id = factory_.NextId(SourceRangeFromToken(ctx->ops[0])); + auto target = ExprFromAny(visit(ctx->member())); + return ExprToAny( + GlobalCallOrMacro(op_id, CelOperator::LOGICAL_NOT, std::move(target))); +} + +std::any ParserVisitor::visitNegate(CelParser::NegateContext* ctx) { + if (ctx->ops.size() % 2 == 0) { + return visit(ctx->member()); + } + int64_t op_id = factory_.NextId(SourceRangeFromToken(ctx->ops[0])); + auto target = ExprFromAny(visit(ctx->member())); + return ExprToAny( + GlobalCallOrMacro(op_id, CelOperator::NEGATE, std::move(target))); +} + +std::string ParserVisitor::NormalizeIdentifier( + CelParser::EscapeIdentContext* ctx) { + if (auto* raw_id = tree_as(ctx); raw_id) { + return raw_id->id->getText(); + } + if (auto* escaped_id = tree_as(ctx); + escaped_id) { + if (!enable_quoted_identifiers_) { + factory_.ReportError(SourceRangeFromParserRuleContext(ctx), + "unsupported syntax '`'"); + } + auto escaped_id_text = escaped_id->id->getText(); + return escaped_id_text.substr(1, escaped_id_text.size() - 2); + } + + // Fallthrough might occur if the parser is in an error state. + return ""; +} + +std::any ParserVisitor::visitSelect(CelParser::SelectContext* ctx) { + auto operand = ExprFromAny(visit(ctx->member())); + // Handle the error case where no valid identifier is specified. + if (!ctx->id || !ctx->op) { + return ExprToAny(factory_.NewUnspecified( + factory_.NextId(SourceRangeFromParserRuleContext(ctx)))); + } + auto id = NormalizeIdentifier(ctx->id); + if (ctx->opt != nullptr) { + if (!enable_optional_syntax_) { + return ExprToAny(factory_.ReportError( + SourceRangeFromParserRuleContext(ctx), "unsupported syntax '.?'")); + } + auto op_id = factory_.NextId(SourceRangeFromToken(ctx->op)); + std::vector arguments; + arguments.reserve(2); + arguments.push_back(std::move(operand)); + arguments.push_back(factory_.NewStringConst( + factory_.NextId(SourceRangeFromParserRuleContext(ctx)), std::move(id))); + return ExprToAny(factory_.NewCall(op_id, "_?._", std::move(arguments))); + } + return ExprToAny( + factory_.NewSelect(factory_.NextId(SourceRangeFromToken(ctx->op)), + std::move(operand), std::move(id))); +} + +std::any ParserVisitor::visitMemberCall(CelParser::MemberCallContext* ctx) { + auto operand = ExprFromAny(visit(ctx->member())); + // Handle the error case where no valid identifier is specified. + if (!ctx->id) { + return ExprToAny(factory_.NewUnspecified( + factory_.NextId(SourceRangeFromParserRuleContext(ctx)))); + } + auto id = ctx->id->getText(); + int64_t op_id = factory_.NextId(SourceRangeFromToken(ctx->open)); + auto args = visitList(ctx->args); + return ExprToAny( + ReceiverCallOrMacroImpl(op_id, id, std::move(operand), std::move(args))); +} + +std::any ParserVisitor::visitIndex(CelParser::IndexContext* ctx) { + auto target = ExprFromAny(visit(ctx->member())); + int64_t op_id = factory_.NextId(SourceRangeFromToken(ctx->op)); + auto index = ExprFromAny(visit(ctx->index)); + if (!enable_optional_syntax_ && ctx->opt != nullptr) { + return ExprToAny(factory_.ReportError(SourceRangeFromParserRuleContext(ctx), + "unsupported syntax '.?'")); + } + return ExprToAny(GlobalCallOrMacro( + op_id, ctx->opt != nullptr ? "_[?_]" : CelOperator::INDEX, + std::move(target), std::move(index))); +} + +std::any ParserVisitor::visitCreateMessage( + CelParser::CreateMessageContext* ctx) { + std::vector parts; + parts.reserve(ctx->ids.size()); + for (const auto* id : ctx->ids) { + parts.push_back(id->getText()); + } + std::string name; + if (ctx->leadingDot) { + name.push_back('.'); + name.append(absl::StrJoin(parts, ".")); + } else { + name = absl::StrJoin(parts, "."); + } + int64_t obj_id = factory_.NextId(SourceRangeFromParserRuleContext(ctx)); + std::vector fields; + if (ctx->entries) { + fields = visitFields(ctx->entries); + } + return ExprToAny( + factory_.NewStruct(obj_id, std::move(name), std::move(fields))); +} + +std::any ParserVisitor::visitFieldInitializerList( + CelParser::FieldInitializerListContext* ctx) { + return ExprToAny(factory_.ReportError(SourceRangeFromParserRuleContext(ctx), + "<>")); +} + +std::vector ParserVisitor::visitFields( + CelParser::FieldInitializerListContext* ctx) { + std::vector res; + if (!ctx || ctx->fields.empty()) { + return res; + } + + res.reserve(ctx->fields.size()); + for (size_t i = 0; i < ctx->fields.size(); ++i) { + if (i >= ctx->cols.size() || i >= ctx->values.size()) { + // This is the result of a syntax error detected elsewhere. + return res; + } + auto* f = ctx->fields[i]; + if (!f->escapeIdent()) { + ABSL_DCHECK(HasErrored()); + // This is the result of a syntax error detected elsewhere. + return res; + } + + std::string id = NormalizeIdentifier(f->escapeIdent()); + + int64_t init_id = factory_.NextId(SourceRangeFromToken(ctx->cols[i])); + if (!enable_optional_syntax_ && f->opt) { + factory_.ReportError(SourceRangeFromParserRuleContext(ctx), + "unsupported syntax '?'"); + continue; + } + auto value = ExprFromAny(visit(ctx->values[i])); + res.push_back(factory_.NewStructField(init_id, std::move(id), + std::move(value), f->opt != nullptr)); + } + + return res; +} + +std::any ParserVisitor::visitIdent(CelParser::IdentContext* ctx) { + std::string ident_name; + if (ctx->leadingDot) { + ident_name = "."; + } + if (!ctx->id) { + return ExprToAny(factory_.NewUnspecified( + factory_.NextId(SourceRangeFromParserRuleContext(ctx)))); + } + // check if ID is in reserved identifiers + if (cel::internal::LexisIsReserved(ctx->id->getText())) { + return ExprToAny(factory_.ReportError( + SourceRangeFromParserRuleContext(ctx), + absl::StrFormat("reserved identifier: %s", ctx->id->getText()))); + } + + ident_name += ctx->id->getText(); + + return ExprToAny(factory_.NewIdent( + factory_.NextId(SourceRangeFromToken(ctx->id)), std::move(ident_name))); +} + +std::any ParserVisitor::visitGlobalCall(CelParser::GlobalCallContext* ctx) { + std::string ident_name; + if (ctx->leadingDot) { + ident_name = "."; + } + if (!ctx->id || !ctx->op) { + return ExprToAny(factory_.NewUnspecified( + factory_.NextId(SourceRangeFromParserRuleContext(ctx)))); + } + // check if ID is in reserved identifiers + if (cel::internal::LexisIsReserved(ctx->id->getText())) { + return ExprToAny(factory_.ReportError( + SourceRangeFromParserRuleContext(ctx), + absl::StrFormat("reserved identifier: %s", ctx->id->getText()))); + } + + ident_name += ctx->id->getText(); + + int64_t op_id = factory_.NextId(SourceRangeFromToken(ctx->op)); + auto args = visitList(ctx->args); + return ExprToAny( + GlobalCallOrMacroImpl(op_id, std::move(ident_name), std::move(args))); +} + +std::any ParserVisitor::visitNested(CelParser::NestedContext* ctx) { + return visit(ctx->e); +} + +std::any ParserVisitor::visitCreateList(CelParser::CreateListContext* ctx) { + int64_t list_id = factory_.NextId(SourceRangeFromParserRuleContext(ctx)); + auto elems = visitList(ctx->elems); + return ExprToAny(factory_.NewList(list_id, std::move(elems))); +} + +std::vector ParserVisitor::visitList( + CelParser::ListInitContext* ctx) { + std::vector rv; + if (!ctx) return rv; + rv.reserve(ctx->elems.size()); + for (size_t i = 0; i < ctx->elems.size(); ++i) { + auto* expr_ctx = ctx->elems[i]; + if (expr_ctx == nullptr) { + return rv; + } + if (!enable_optional_syntax_ && expr_ctx->opt != nullptr) { + factory_.ReportError(SourceRangeFromParserRuleContext(ctx), + "unsupported syntax '?'"); + rv.push_back(factory_.NewListElement(factory_.NewUnspecified(0), false)); + continue; + } + rv.push_back(factory_.NewListElement(ExprFromAny(visitExpr(expr_ctx->e)), + expr_ctx->opt != nullptr)); + } + return rv; +} + +std::vector ParserVisitor::visitList(CelParser::ExprListContext* ctx) { + std::vector rv; + if (!ctx) return rv; + std::transform(ctx->e.begin(), ctx->e.end(), std::back_inserter(rv), + [this](CelParser::ExprContext* expr_ctx) { + return ExprFromAny(visitExpr(expr_ctx)); + }); + return rv; +} + +std::any ParserVisitor::visitCreateMap(CelParser::CreateMapContext* ctx) { + int64_t struct_id = factory_.NextId(SourceRangeFromParserRuleContext(ctx)); + std::vector entries; + if (ctx->entries) { + entries = visitEntries(ctx->entries); + } + return ExprToAny(factory_.NewMap(struct_id, std::move(entries))); +} + +std::any ParserVisitor::visitConstantLiteral( + CelParser::ConstantLiteralContext* clctx) { + CelParser::LiteralContext* literal = clctx->literal(); + if (auto* ctx = tree_as(literal)) { + return visitInt(ctx); + } else if (auto* ctx = tree_as(literal)) { + return visitUint(ctx); + } else if (auto* ctx = tree_as(literal)) { + return visitDouble(ctx); + } else if (auto* ctx = tree_as(literal)) { + return visitString(ctx); + } else if (auto* ctx = tree_as(literal)) { + return visitBytes(ctx); + } else if (auto* ctx = tree_as(literal)) { + return visitBoolFalse(ctx); + } else if (auto* ctx = tree_as(literal)) { + return visitBoolTrue(ctx); + } else if (auto* ctx = tree_as(literal)) { + return visitNull(ctx); + } + return ExprToAny(factory_.ReportError(SourceRangeFromParserRuleContext(clctx), + "invalid constant literal expression")); +} + +std::any ParserVisitor::visitMapInitializerList( + CelParser::MapInitializerListContext* ctx) { + return ExprToAny(factory_.ReportError(SourceRangeFromParserRuleContext(ctx), + "<>")); +} + +std::vector ParserVisitor::visitEntries( + CelParser::MapInitializerListContext* ctx) { + std::vector res; + if (!ctx || ctx->keys.empty()) { + return res; + } + + res.reserve(ctx->cols.size()); + for (size_t i = 0; i < ctx->cols.size(); ++i) { + auto id = factory_.NextId(SourceRangeFromToken(ctx->cols[i])); + if (!enable_optional_syntax_ && ctx->keys[i]->opt) { + factory_.ReportError(SourceRangeFromParserRuleContext(ctx), + "unsupported syntax '?'"); + res.push_back(factory_.NewMapEntry(0, factory_.NewUnspecified(0), + factory_.NewUnspecified(0), false)); + continue; + } + auto key = ExprFromAny(visit(ctx->keys[i]->e)); + auto value = ExprFromAny(visit(ctx->values[i])); + res.push_back(factory_.NewMapEntry(id, std::move(key), std::move(value), + ctx->keys[i]->opt != nullptr)); + } + return res; +} + +std::any ParserVisitor::visitInt(CelParser::IntContext* ctx) { + std::string value; + if (ctx->sign) { + value = ctx->sign->getText(); + } + value += ctx->tok->getText(); + int64_t int_value; + if (absl::StartsWith(ctx->tok->getText(), "0x")) { + if (absl::SimpleHexAtoi(value, &int_value)) { + return ExprToAny(factory_.NewIntConst( + factory_.NextId(SourceRangeFromParserRuleContext(ctx)), int_value)); + } else { + return ExprToAny(factory_.ReportError( + SourceRangeFromParserRuleContext(ctx), "invalid hex int literal")); + } + } + if (absl::SimpleAtoi(value, &int_value)) { + return ExprToAny(factory_.NewIntConst( + factory_.NextId(SourceRangeFromParserRuleContext(ctx)), int_value)); + } else { + return ExprToAny(factory_.ReportError(SourceRangeFromParserRuleContext(ctx), + "invalid int literal")); + } +} + +std::any ParserVisitor::visitUint(CelParser::UintContext* ctx) { + std::string value = ctx->tok->getText(); + // trim the 'u' designator included in the uint literal. + if (!value.empty()) { + value.resize(value.size() - 1); + } + uint64_t uint_value; + if (absl::StartsWith(ctx->tok->getText(), "0x")) { + if (absl::SimpleHexAtoi(value, &uint_value)) { + return ExprToAny(factory_.NewUintConst( + factory_.NextId(SourceRangeFromParserRuleContext(ctx)), uint_value)); + } else { + return ExprToAny(factory_.ReportError( + SourceRangeFromParserRuleContext(ctx), "invalid hex uint literal")); + } + } + if (absl::SimpleAtoi(value, &uint_value)) { + return ExprToAny(factory_.NewUintConst( + factory_.NextId(SourceRangeFromParserRuleContext(ctx)), uint_value)); + } else { + return ExprToAny(factory_.ReportError(SourceRangeFromParserRuleContext(ctx), + "invalid uint literal")); + } +} + +std::any ParserVisitor::visitDouble(CelParser::DoubleContext* ctx) { + std::string value; + if (ctx->sign) { + value = ctx->sign->getText(); + } + value += ctx->tok->getText(); + double double_value; + if (absl::SimpleAtod(value, &double_value)) { + return ExprToAny(factory_.NewDoubleConst( + factory_.NextId(SourceRangeFromParserRuleContext(ctx)), double_value)); + } else { + return ExprToAny(factory_.ReportError(SourceRangeFromParserRuleContext(ctx), + "invalid double literal")); + } +} + +std::any ParserVisitor::visitString(CelParser::StringContext* ctx) { + auto status_or_value = cel::internal::ParseStringLiteral(ctx->tok->getText()); + if (!status_or_value.ok()) { + return ExprToAny(factory_.ReportError(SourceRangeFromParserRuleContext(ctx), + status_or_value.status().message())); + } + return ExprToAny(factory_.NewStringConst( + factory_.NextId(SourceRangeFromParserRuleContext(ctx)), + std::move(status_or_value).value())); +} + +std::any ParserVisitor::visitBytes(CelParser::BytesContext* ctx) { + auto status_or_value = cel::internal::ParseBytesLiteral(ctx->tok->getText()); + if (!status_or_value.ok()) { + return ExprToAny(factory_.ReportError(SourceRangeFromParserRuleContext(ctx), + status_or_value.status().message())); + } + return ExprToAny(factory_.NewBytesConst( + factory_.NextId(SourceRangeFromParserRuleContext(ctx)), + std::move(status_or_value).value())); +} + +std::any ParserVisitor::visitBoolTrue(CelParser::BoolTrueContext* ctx) { + return ExprToAny(factory_.NewBoolConst( + factory_.NextId(SourceRangeFromParserRuleContext(ctx)), true)); +} + +std::any ParserVisitor::visitBoolFalse(CelParser::BoolFalseContext* ctx) { + return ExprToAny(factory_.NewBoolConst( + factory_.NextId(SourceRangeFromParserRuleContext(ctx)), false)); +} + +std::any ParserVisitor::visitNull(CelParser::NullContext* ctx) { + return ExprToAny(factory_.NewNullConst( + factory_.NextId(SourceRangeFromParserRuleContext(ctx)))); +} + +cel::SourceInfo ParserVisitor::GetSourceInfo() { + cel::SourceInfo source_info; + source_info.set_location(std::string(source_.description())); + for (const auto& positions : factory_.positions()) { + source_info.mutable_positions().insert( + std::pair{positions.first, positions.second.begin}); + } + source_info.mutable_line_offsets().reserve(source_.line_offsets().size()); + for (const auto& line_offset : source_.line_offsets()) { + source_info.mutable_line_offsets().push_back(line_offset); + } + + source_info.mutable_macro_calls() = factory_.release_macro_calls(); + return source_info; +} + +EnrichedSourceInfo ParserVisitor::enriched_source_info() const { + std::map> offsets; + for (const auto& positions : factory_.positions()) { + offsets.insert( + std::pair{positions.first, + std::pair{positions.second.begin, positions.second.end - 1}}); + } + return EnrichedSourceInfo(std::move(offsets)); +} + +void ParserVisitor::syntaxError(antlr4::Recognizer* recognizer, + antlr4::Token* offending_symbol, size_t line, + size_t col, const std::string& msg, + std::exception_ptr e) { + cel::SourceRange range; + if (auto position = source_.GetPosition(cel::SourceLocation{ + static_cast(line), static_cast(col)}); + position) { + range.begin = *position; + } + factory_.ReportError(range, absl::StrCat("Syntax error: ", msg)); +} + +bool ParserVisitor::HasErrored() const { return factory_.HasErrors(); } + +std::vector ParserVisitor::CollectIssues() { + return factory_.CollectIssues(); +} + +Expr ParserVisitor::GlobalCallOrMacroImpl(int64_t expr_id, + absl::string_view function, + std::vector args) { + if (auto macro = macro_registry_.FindMacro(function, args.size(), false); + macro) { + std::vector macro_args; + if (add_macro_calls_) { + macro_args.reserve(args.size()); + for (const auto& arg : args) { + macro_args.push_back(factory_.BuildMacroCallArg(arg)); + } + } + factory_.BeginMacro(factory_.GetSourceRange(expr_id)); + auto expr = macro->Expand(factory_, std::nullopt, absl::MakeSpan(args)); + factory_.EndMacro(); + if (expr) { + if (add_macro_calls_) { + factory_.AddMacroCall(expr->id(), function, std::nullopt, + std::move(macro_args)); + } + // We did not end up using `expr_id`. Delete metadata. + factory_.EraseId(expr_id); + return std::move(*expr); + } + } + + return factory_.NewCall(expr_id, function, std::move(args)); +} + +Expr ParserVisitor::ReceiverCallOrMacroImpl(int64_t expr_id, + absl::string_view function, + Expr target, + std::vector args) { + if (auto macro = macro_registry_.FindMacro(function, args.size(), true); + macro) { + Expr macro_target; + std::vector macro_args; + if (add_macro_calls_) { + macro_args.reserve(args.size()); + macro_target = factory_.BuildMacroCallArg(target); + for (const auto& arg : args) { + macro_args.push_back(factory_.BuildMacroCallArg(arg)); + } + } + factory_.BeginMacro(factory_.GetSourceRange(expr_id)); + auto expr = macro->Expand(factory_, std::ref(target), absl::MakeSpan(args)); + factory_.EndMacro(); + if (expr) { + if (add_macro_calls_) { + factory_.AddMacroCall(expr->id(), function, std::move(macro_target), + std::move(macro_args)); + } + // We did not end up using `expr_id`. Delete metadata. + factory_.EraseId(expr_id); + return std::move(*expr); + } + } + return factory_.NewMemberCall(expr_id, function, std::move(target), + std::move(args)); +} + +std::string ParserVisitor::ExtractQualifiedName(antlr4::ParserRuleContext* ctx, + const Expr& e) { + if (e == Expr{}) { + return ""; + } + + if (const auto* ident_expr = absl::get_if(&e.kind()); ident_expr) { + return ident_expr->name(); + } + if (const auto* select_expr = absl::get_if(&e.kind()); + select_expr) { + std::string prefix = ExtractQualifiedName(ctx, select_expr->operand()); + if (!prefix.empty()) { + return absl::StrCat(prefix, ".", select_expr->field()); + } + } + factory_.ReportError(factory_.GetSourceRange(e.id()), + "expected a qualified name"); + return ""; +} + +// Replacements for absl::StrReplaceAll for escaping standard whitespace +// characters. +static constexpr auto kStandardReplacements = + std::array, 3>{ + std::make_pair("\n", "\\n"), + std::make_pair("\r", "\\r"), + std::make_pair("\t", "\\t"), + }; + +static constexpr absl::string_view kSingleQuote = "'"; + // ExprRecursionListener extends the standard ANTLR CelParser to ensure that // recursive entries into the 'expr' rule are limited to a configurable depth so // as to prevent stack overflows. -class ExprRecursionListener : public ::antlr4::tree::ParseTreeListener { +class ExprRecursionListener final : public ParseTreeListener { public: - ExprRecursionListener( + explicit ExprRecursionListener( const int max_recursion_depth = kDefaultMaxRecursionDepth) : max_recursion_depth_(max_recursion_depth), recursion_depth_(0) {} + ~ExprRecursionListener() override {} - void visitTerminal(TerminalNode* node) override{}; - void visitErrorNode(ErrorNode* error) override{}; + void visitTerminal(TerminalNode* node) override {}; + void visitErrorNode(ErrorNode* error) override {}; void enterEveryRule(ParserRuleContext* ctx) override; void exitEveryRule(ParserRuleContext* ctx) override; @@ -50,8 +1572,8 @@ void ExprRecursionListener::enterEveryRule(ParserRuleContext* ctx) { // Throw a ParseCancellationException since the parsing would otherwise // continue if this were treated as a syntax error and the problem would // continue to manifest. - if (ctx->getRuleIndex() == ::cel_grammar::CelParser::RuleExpr) { - if (recursion_depth_ >= max_recursion_depth_) { + if (ctx->getRuleIndex() == CelParser::RuleExpr) { + if (recursion_depth_ > max_recursion_depth_) { throw ParseCancellationException( absl::StrFormat("Expression recursion limit exceeded. limit: %d", max_recursion_depth_)); @@ -61,76 +1583,357 @@ void ExprRecursionListener::enterEveryRule(ParserRuleContext* ctx) { } void ExprRecursionListener::exitEveryRule(ParserRuleContext* ctx) { - if (ctx->getRuleIndex() == ::cel_grammar::CelParser::RuleExpr) { + if (ctx->getRuleIndex() == CelParser::RuleExpr) { recursion_depth_--; } } -} // namespace +class RecoveryLimitErrorStrategy final : public DefaultErrorStrategy { + public: + explicit RecoveryLimitErrorStrategy( + int recovery_limit = kDefaultErrorRecoveryLimit, + int recovery_token_lookahead_limit = + kDefaultErrorRecoveryTokenLookaheadLimit) + : recovery_limit_(recovery_limit), + recovery_attempts_(0), + recovery_token_lookahead_limit_(recovery_token_lookahead_limit) {} -absl::StatusOr Parse(const std::string& expression, - const std::string& description, - const int max_recursion_depth) { - return ParseWithMacros(expression, Macro::AllMacros(), description, - max_recursion_depth); -} + void recover(Parser* recognizer, std::exception_ptr e) override { + checkRecoveryLimit(recognizer); + DefaultErrorStrategy::recover(recognizer, e); + } -absl::StatusOr ParseWithMacros(const std::string& expression, - const std::vector& macros, - const std::string& description, - const int max_recursion_depth) { - auto result = - EnrichedParse(expression, macros, description, max_recursion_depth); - if (result.ok()) { - return result->parsed_expr(); + Token* recoverInline(Parser* recognizer) override { + checkRecoveryLimit(recognizer); + return DefaultErrorStrategy::recoverInline(recognizer); } - return result.status(); -} -absl::StatusOr EnrichedParse( - const std::string& expression, const std::vector& macros, - const std::string& description, const int max_recursion_depth) { - ANTLRInputStream input(expression); - ::cel_grammar::CelLexer lexer(&input); - CommonTokenStream tokens(&lexer); - ::cel_grammar::CelParser parser(&tokens); - ExprRecursionListener listener(max_recursion_depth); - ParserVisitor visitor(description, expression, max_recursion_depth, macros); - - lexer.removeErrorListeners(); - parser.removeErrorListeners(); - lexer.addErrorListener(&visitor); - parser.addErrorListener(&visitor); - parser.addParseListener(&listener); - - // if we were to ignore errors completely: - // std::shared_ptr error_strategy(new BailErrorStrategy()); - // parser.setErrorHandler(error_strategy); - - ::cel_grammar::CelParser::StartContext* root; + // Override the ANTLR implementation to introduce a token lookahead limit as + // this prevents pathologically constructed, yet small (< 16kb) inputs from + // consuming inordinate amounts of compute. + // + // This method is only called on error recovery paths. + void consumeUntil(Parser* recognizer, const IntervalSet& set) override { + size_t ttype = recognizer->getInputStream()->LA(1); + int recovery_search_depth = 0; + while (ttype != Token::EOF && !set.contains(ttype) && + recovery_search_depth++ < recovery_token_lookahead_limit_) { + recognizer->consume(); + ttype = recognizer->getInputStream()->LA(1); + } + // Halt all parsing if the lookahead limit is reached during error recovery. + if (recovery_search_depth == recovery_token_lookahead_limit_) { + throw ParseCancellationException("Unable to find a recovery token"); + } + } + + protected: + std::string escapeWSAndQuote(const std::string& s) const override { + std::string result; + result.reserve(s.size() + 2); + absl::StrAppend(&result, kSingleQuote, s, kSingleQuote); + absl::StrReplaceAll(kStandardReplacements, &result); + return result; + } + + private: + void checkRecoveryLimit(Parser* recognizer) { + if (recovery_attempts_++ >= recovery_limit_) { + std::string too_many_errors = + absl::StrFormat("More than %d parse errors.", recovery_limit_); + recognizer->notifyErrorListeners(too_many_errors); + throw ParseCancellationException(too_many_errors); + } + } + + int recovery_limit_; + int recovery_attempts_; + int recovery_token_lookahead_limit_; +}; + +struct ParseResult { + cel::Expr expr; + cel::SourceInfo source_info; + EnrichedSourceInfo enriched_source_info; +}; + +absl::StatusOr ParseImpl( + const cel::Source& source, const cel::MacroRegistry& registry, + const ParserOptions& options, + std::vector* parse_issues = nullptr) { try { - root = parser.start(); - } catch (ParseCancellationException& e) { - return absl::CancelledError(e.what()); - } catch (std::exception& e) { + CodePointStream input(source.content(), source.description()); + if (input.size() > options.expression_size_codepoint_limit) { + return absl::InvalidArgumentError(absl::StrCat( + "expression size exceeds codepoint limit.", " input size: ", + input.size(), ", limit: ", options.expression_size_codepoint_limit)); + } + CelLexer lexer(&input); + CommonTokenStream tokens(&lexer); + CelParser parser(&tokens); + ExprRecursionListener listener(options.max_recursion_depth); + ParserVisitor visitor( + source, options.max_recursion_depth, registry, options.add_macro_calls, + options.enable_optional_syntax, options.enable_quoted_identifiers, + options.enable_variadic_logical_operators); + + lexer.removeErrorListeners(); + parser.removeErrorListeners(); + lexer.addErrorListener(&visitor); + parser.addErrorListener(&visitor); + parser.addParseListener(&listener); + + // Limit the number of error recovery attempts to prevent bad expressions + // from consuming lots of cpu / memory. + parser.setErrorHandler(std::make_shared( + options.error_recovery_limit, + options.error_recovery_token_lookahead_limit)); + + Expr expr; + try { + expr = ExprFromAny(visitor.visit(parser.start())); + } catch (const ParseCancellationException& e) { + if (visitor.HasErrored()) { + auto issues = visitor.CollectIssues(); + std::string error_message = FormatIssues(source, issues); + if (parse_issues != nullptr) { + *parse_issues = std::move(issues); + } + return absl::InvalidArgumentError(error_message); + } + return absl::CancelledError(e.what()); + } + + if (visitor.HasErrored()) { + auto issues = visitor.CollectIssues(); + std::string error_message = FormatIssues(source, issues); + if (parse_issues != nullptr) { + *parse_issues = std::move(issues); + } + return absl::InvalidArgumentError(error_message); + } + + return { + ParseResult{.expr = std::move(expr), + .source_info = visitor.GetSourceInfo(), + .enriched_source_info = visitor.enriched_source_info()}}; + } catch (const std::exception& e) { return absl::AbortedError(e.what()); + } catch (const char* what) { + // ANTLRv4 has historically thrown C string literals. + return absl::AbortedError(what); + } catch (...) { + // We guarantee to never throw and always return a status. + return absl::UnknownError("An unknown exception occurred"); } +} + +class ParserImpl : public cel::Parser { + public: + explicit ParserImpl(const ParserOptions& options, + cel::MacroRegistry macro_registry, + absl::flat_hash_set library_ids) + : options_(options), + macro_registry_(std::move(macro_registry)), + library_ids_(std::move(library_ids)) {} + + absl::StatusOr> ParseImpl( + const cel::Source& source, + std::vector* parse_issues) const override { + CEL_ASSIGN_OR_RETURN(auto parse_result, + ::google::api::expr::parser::ParseImpl( + source, macro_registry_, options_, parse_issues)); + return std::make_unique(std::move(parse_result.expr), + std::move(parse_result.source_info)); + } + + std::unique_ptr ToBuilder() const override; + + private: + const ParserOptions options_; + const cel::MacroRegistry macro_registry_; + absl::flat_hash_set library_ids_; +}; + +class ParserBuilderImpl : public cel::ParserBuilder { + public: + explicit ParserBuilderImpl(const ParserOptions& options) + : options_(options) {} + + ParserOptions& GetOptions() override { return options_; } + + absl::Status AddMacro(const cel::Macro& macro) override { + for (const auto& existing_macro : macros_) { + if (existing_macro.key() == macro.key()) { + return absl::AlreadyExistsError( + absl::StrCat("macro already exists: ", macro.key())); + } + } + macros_.push_back(macro); + return absl::OkStatus(); + } + + absl::Status AddLibrary(cel::ParserLibrary library) override { + if (!library.id.empty()) { + auto [it, inserted] = library_ids_.insert(library.id); + if (!inserted) { + return absl::AlreadyExistsError( + absl::StrCat("parser library already exists: ", library.id)); + } + } + libraries_.push_back(std::move(library)); + return absl::OkStatus(); + } + + absl::Status AddLibrarySubset(cel::ParserLibrarySubset subset) override { + if (subset.library_id.empty()) { + return absl::InvalidArgumentError("subset must have a library id"); + } + std::string library_id = subset.library_id; + auto [it, inserted] = + library_subsets_.insert({library_id, std::move(subset)}); + if (!inserted) { + return absl::AlreadyExistsError( + absl::StrCat("parser library subset already exists: ", library_id)); + } + return absl::OkStatus(); + } + + absl::StatusOr> Build() override { + using std::swap; + // Save the old configured macros so they aren't affected by applying the + // libraries and can be restored if an error occurs. + std::vector individual_macros; + swap(individual_macros, macros_); + absl::Cleanup cleanup([&] { swap(macros_, individual_macros); }); + + cel::MacroRegistry macro_registry; + + for (const auto& library : libraries_) { + CEL_RETURN_IF_ERROR(library.configure(*this)); + if (!library.id.empty()) { + auto it = library_subsets_.find(library.id); + if (it != library_subsets_.end()) { + const cel::ParserLibrarySubset& subset = it->second; + for (const auto& macro : macros_) { + if (subset.should_include_macro(macro)) { + CEL_RETURN_IF_ERROR(macro_registry.RegisterMacro(macro)); + } + } + macros_.clear(); + continue; + } + } + + CEL_RETURN_IF_ERROR(macro_registry.RegisterMacros(macros_)); + macros_.clear(); + } + + absl::flat_hash_set library_ids(library_ids_); - Expr expr = visitor.visit(root).as(); + // Hack to support adding the standard library macros either by option or + // with a library configurer. + if (!options_.disable_standard_macros && !library_ids_.contains("stdlib")) { + CEL_RETURN_IF_ERROR(macro_registry.RegisterMacros(Macro::AllMacros())); + library_ids.insert("stdlib"); + } + + if (options_.enable_optional_syntax && !library_ids_.contains("optional")) { + CEL_RETURN_IF_ERROR(macro_registry.RegisterMacro(cel::OptMapMacro())); + CEL_RETURN_IF_ERROR(macro_registry.RegisterMacro(cel::OptFlatMapMacro())); + library_ids.insert("optional"); + } + CEL_RETURN_IF_ERROR(macro_registry.RegisterMacros(individual_macros)); + return std::make_unique(options_, std::move(macro_registry), + std::move(library_ids)); + } + + private: + friend class ParserImpl; + + ParserOptions options_; + std::vector macros_; + absl::flat_hash_set library_ids_; + std::vector libraries_; + absl::flat_hash_map library_subsets_; +}; + +std::unique_ptr ParserImpl::ToBuilder() const { + auto ins = std::make_unique(options_); + ins->library_ids_ = library_ids_; + ins->macros_ = macro_registry_.ListMacros(); + return ins; +} - if (visitor.hasErrored()) { - return absl::InvalidArgumentError(visitor.errorMessage()); +} // namespace + +absl::StatusOr Parse(absl::string_view expression, + absl::string_view description, + const ParserOptions& options) { + std::vector macros; + if (!options.disable_standard_macros) { + macros = Macro::AllMacros(); + } + if (options.enable_optional_syntax) { + macros.push_back(cel::OptMapMacro()); + macros.push_back(cel::OptFlatMapMacro()); } + return ParseWithMacros(expression, macros, description, options); +} - // root is deleted as part of the parser context +absl::StatusOr ParseWithMacros(absl::string_view expression, + const std::vector& macros, + absl::string_view description, + const ParserOptions& options) { + CEL_ASSIGN_OR_RETURN(auto verbose_parsed_expr, + EnrichedParse(expression, macros, description, options)); + return verbose_parsed_expr.parsed_expr(); +} + +absl::StatusOr EnrichedParse( + absl::string_view expression, const std::vector& macros, + absl::string_view description, const ParserOptions& options) { + CEL_ASSIGN_OR_RETURN(auto source, + cel::NewSource(expression, std::string(description))); + cel::MacroRegistry macro_registry; + CEL_RETURN_IF_ERROR(macro_registry.RegisterMacros(macros)); + return EnrichedParse(*source, macro_registry, options); +} + +absl::StatusOr EnrichedParse( + const cel::Source& source, const cel::MacroRegistry& registry, + const ParserOptions& options) { + CEL_ASSIGN_OR_RETURN(ParseResult parse_result, + ParseImpl(source, registry, options)); ParsedExpr parsed_expr; - parsed_expr.mutable_expr()->CopyFrom(expr); - parsed_expr.mutable_source_info()->CopyFrom(visitor.sourceInfo()); - auto enriched_source_info = visitor.enrichedSourceInfo(); - return VerboseParsedExpr(parsed_expr, enriched_source_info); + CEL_RETURN_IF_ERROR(cel::ast_internal::ExprToProto( + parse_result.expr, parsed_expr.mutable_expr())); + + CEL_RETURN_IF_ERROR(cel::ast_internal::SourceInfoToProto( + parse_result.source_info, parsed_expr.mutable_source_info())); + return VerboseParsedExpr(std::move(parsed_expr), + std::move(parse_result.enriched_source_info)); +} + +absl::StatusOr Parse( + const cel::Source& source, const cel::MacroRegistry& registry, + const ParserOptions& options) { + CEL_ASSIGN_OR_RETURN(auto verbose_expr, + EnrichedParse(source, registry, options)); + return verbose_expr.parsed_expr(); +} + +} // namespace google::api::expr::parser + +namespace cel { + +// Creates a new parser builder. +// +// Intended for use with the Compiler class, most users should prefer the free +// functions above for independent parsing of expressions. +std::unique_ptr NewParserBuilder(const ParserOptions& options) { + return std::make_unique( + options); } -} // namespace parser -} // namespace expr -} // namespace api -} // namespace google +} // namespace cel diff --git a/parser/parser.h b/parser/parser.h index 57ad7de8a..4b32c1c42 100644 --- a/parser/parser.h +++ b/parser/parser.h @@ -1,27 +1,50 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// CEL does not support calling the parser during C++ static initialization. +// Callers must ensure the parser is only invoked after C++ static initializers +// are run. Failing to do so is undefined behavior. The current reason for this +// is the parser uses ANTLRv4, which also makes no guarantees about being safe +// with regard to C++ static initialization. As such, neither do we. + #ifndef THIRD_PARTY_CEL_CPP_PARSER_PARSER_H_ #define THIRD_PARTY_CEL_CPP_PARSER_PARSER_H_ -#include "google/api/expr/v1alpha1/syntax.pb.h" -#include "absl/types/optional.h" +#include +#include +#include + +#include "cel/expr/syntax.pb.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "common/source.h" #include "parser/macro.h" +#include "parser/macro_registry.h" +#include "parser/options.h" +#include "parser/parser_interface.h" #include "parser/source_factory.h" -#include "absl/status/statusor.h" - -namespace google { -namespace api { -namespace expr { -namespace parser { -constexpr int kDefaultMaxRecursionDepth = 250; +namespace google::api::expr::parser { class VerboseParsedExpr { public: - VerboseParsedExpr(const google::api::expr::v1alpha1::ParsedExpr& parsed_expr, - const EnrichedSourceInfo& enriched_source_info) - : parsed_expr_(parsed_expr), - enriched_source_info_(enriched_source_info) {} + VerboseParsedExpr(cel::expr::ParsedExpr parsed_expr, + EnrichedSourceInfo enriched_source_info) + : parsed_expr_(std::move(parsed_expr)), + enriched_source_info_(std::move(enriched_source_info)) {} - const google::api::expr::v1alpha1::ParsedExpr& parsed_expr() const { + const cel::expr::ParsedExpr& parsed_expr() const { return parsed_expr_; } const EnrichedSourceInfo& enriched_source_info() const { @@ -29,27 +52,51 @@ class VerboseParsedExpr { } private: - google::api::expr::v1alpha1::ParsedExpr parsed_expr_; + cel::expr::ParsedExpr parsed_expr_; EnrichedSourceInfo enriched_source_info_; }; +// See comments at the top of the file for information about usage during C++ +// static initialization. absl::StatusOr EnrichedParse( - const std::string& expression, const std::vector& macros, - const std::string& description = "", - int max_recursion_depth = kDefaultMaxRecursionDepth); - -absl::StatusOr Parse( - const std::string& expression, const std::string& description = "", - int max_recursion_depth = kDefaultMaxRecursionDepth); - -absl::StatusOr ParseWithMacros( - const std::string& expression, const std::vector& macros, - const std::string& description = "", - int max_recursion_depth = kDefaultMaxRecursionDepth); - -} // namespace parser -} // namespace expr -} // namespace api -} // namespace google + absl::string_view expression, const std::vector& macros, + absl::string_view description = "", + const ParserOptions& options = ParserOptions()); + +// See comments at the top of the file for information about usage during C++ +// static initialization. +absl::StatusOr Parse( + absl::string_view expression, absl::string_view description = "", + const ParserOptions& options = ParserOptions()); + +// See comments at the top of the file for information about usage during C++ +// static initialization. +absl::StatusOr ParseWithMacros( + absl::string_view expression, const std::vector& macros, + absl::string_view description = "", + const ParserOptions& options = ParserOptions()); + +// See comments at the top of the file for information about usage during C++ +// static initialization. +absl::StatusOr EnrichedParse( + const cel::Source& source, const cel::MacroRegistry& registry, + const ParserOptions& options = ParserOptions()); + +// See comments at the top of the file for information about usage during C++ +// static initialization. +absl::StatusOr Parse( + const cel::Source& source, const cel::MacroRegistry& registry, + const ParserOptions& options = ParserOptions()); + +} // namespace google::api::expr::parser + +namespace cel { +// Creates a new parser builder. +// +// Intended for use with the Compiler class, most users should prefer the free +// functions above for independent parsing of expressions. +std::unique_ptr NewParserBuilder( + const ParserOptions& options = {}); +} // namespace cel #endif // THIRD_PARTY_CEL_CPP_PARSER_PARSER_H_ diff --git a/parser/parser_benchmarks.cc b/parser/parser_benchmarks.cc new file mode 100644 index 000000000..b05f9b1f5 --- /dev/null +++ b/parser/parser_benchmarks.cc @@ -0,0 +1,282 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "cel/expr/syntax.pb.h" +#include "absl/log/absl_check.h" +#include "absl/status/status_matchers.h" +#include "absl/strings/string_view.h" +#include "internal/benchmark.h" +#include "internal/testing.h" +#include "parser/macro.h" +#include "parser/options.h" +#include "parser/parser.h" + +namespace google::api::expr::parser { + +namespace { + +using ::absl_testing::IsOk; +using ::testing::Not; + +enum class ParseResult { kSuccess, kError }; + +struct TestInfo { + static TestInfo ErrorCase(absl::string_view expr) { + TestInfo info; + info.expr = expr; + info.result = ParseResult::kError; + return info; + } + // The expression to parse. + std::string expr = ""; + + // The expected result of the parse. + ParseResult result = ParseResult::kSuccess; +}; + +const std::vector& GetTestCases() { + static const std::vector* kInstance = new std::vector{ + // Simple test cases we started with + {"x * 2"}, + {"x * 2u"}, + {"x * 2.0"}, + {"\"\\u2764\""}, + {"\"\u2764\""}, + {"! false"}, + {"-a"}, + {"a.b(5)"}, + {"a[3]"}, + {"SomeMessage{foo: 5, bar: \"xyz\"}"}, + {"[3, 4, 5]"}, + {"{foo: 5, bar: \"xyz\"}"}, + {"a > 5 && a < 10"}, + {"a < 5 || a > 10"}, + TestInfo::ErrorCase("{"), + + // test cases from Go + {"\"A\""}, + {"true"}, + {"false"}, + {"0"}, + {"42"}, + {"0u"}, + {"23u"}, + {"24u"}, + {"0xAu"}, + {"-0xA"}, + {"0xA"}, + {"-1"}, + {"4--4"}, + {"4--4.1"}, + {"b\"abc\""}, + {"23.39"}, + {"!a"}, + {"a"}, + {"a?b:c"}, + {"a || b"}, + {"a || b || c || d || e || f "}, + {"a && b"}, + {"a && b && c && d && e && f && g"}, + {"a && b && c && d || e && f && g && h"}, + {"a + b"}, + {"a - b"}, + {"a * b"}, + {"a / b"}, + {"a % b"}, + {"a in b"}, + {"a == b"}, + {"a != b"}, + {"a > b"}, + {"a >= b"}, + {"a < b"}, + {"a <= b"}, + {"a.b"}, + {"a.b.c"}, + {"a[b]"}, + {"foo{ }"}, + {"foo{ a:b }"}, + {"foo{ a:b, c:d }"}, + {"{}"}, + {"{a:b, c:d}"}, + {"[]"}, + {"[a]"}, + {"[a, b, c]"}, + {"(a)"}, + {"((a))"}, + {"a()"}, + {"a(b)"}, + {"a(b, c)"}, + {"a.b()"}, + {"a.b(c)"}, + {"aaa.bbb(ccc)"}, + + // Parse error tests + TestInfo::ErrorCase("*@a | b"), + TestInfo::ErrorCase("a | b"), + TestInfo::ErrorCase("?"), + TestInfo::ErrorCase("t{>C}"), + + // Macro tests + {"has(m.f)"}, + {"m.exists_one(v, f)"}, + {"m.map(v, f)"}, + {"m.map(v, p, f)"}, + {"m.filter(v, p)"}, + + // Tests from Java parser + {"[] + [1,2,3,] + [4]"}, + {"{1:2u, 2:3u}"}, + {"TestAllTypes{single_int32: 1, single_int64: 2}"}, + + TestInfo::ErrorCase("TestAllTypes(){single_int32: 1, single_int64: 2}"), + {"size(x) == x.size()"}, + TestInfo::ErrorCase("1 + $"), + TestInfo::ErrorCase("1 + 2\n" + "3 +"), + {"\"\\\"\""}, + {"[1,3,4][0]"}, + TestInfo::ErrorCase("1.all(2, 3)"), + {"x[\"a\"].single_int32 == 23"}, + {"x.single_nested_message != null"}, + {"false && !true || false ? 2 : 3"}, + {"b\"abc\" + B\"def\""}, + {"1 + 2 * 3 - 1 / 2 == 6 % 1"}, + {"---a"}, + TestInfo::ErrorCase("1 + +"), + {"\"abc\" + \"def\""}, + TestInfo::ErrorCase("{\"a\": 1}.\"a\""), + {"\"\\xC3\\XBF\""}, + {"\"\\303\\277\""}, + {"\"hi\\u263A \\u263Athere\""}, + {"\"\\U000003A8\\?\""}, + {"\"\\a\\b\\f\\n\\r\\t\\v'\\\"\\\\\\? Legal escapes\""}, + TestInfo::ErrorCase("\"\\xFh\""), + TestInfo::ErrorCase( + "\"\\a\\b\\f\\n\\r\\t\\v\\'\\\"\\\\\\? Illegal escape \\>\""), + {"'😁' in ['😁', '😑', '😦']"}, + {"'\u00ff' in ['\u00ff', '\u00ff', '\u00ff']"}, + {"'\u00ff' in ['\uffff', '\U00100000', '\U0010ffff']"}, + {"'\u00ff' in ['\U00100000', '\uffff', '\U0010ffff']"}, + TestInfo::ErrorCase("'😁' in ['😁', '😑', '😦']\n" + " && in.😁"), + TestInfo::ErrorCase("as"), + TestInfo::ErrorCase("break"), + TestInfo::ErrorCase("const"), + TestInfo::ErrorCase("continue"), + TestInfo::ErrorCase("else"), + TestInfo::ErrorCase("for"), + TestInfo::ErrorCase("function"), + TestInfo::ErrorCase("if"), + TestInfo::ErrorCase("import"), + TestInfo::ErrorCase("in"), + TestInfo::ErrorCase("let"), + TestInfo::ErrorCase("loop"), + TestInfo::ErrorCase("package"), + TestInfo::ErrorCase("namespace"), + TestInfo::ErrorCase("return"), + TestInfo::ErrorCase("var"), + TestInfo::ErrorCase("void"), + TestInfo::ErrorCase("while"), + TestInfo::ErrorCase("[1, 2, 3].map(var, var * var)"), + TestInfo::ErrorCase("[\n\t\r[\n\t\r[\n\t\r]\n\t\r]\n\t\r"), + + // Identifier quoting syntax tests. + {"a.`b`"}, + {"a.`b-c`"}, + {"a.`b c`"}, + {"a.`b/c`"}, + {"a.`b.c`"}, + {"a.`in`"}, + {"A{`b`: 1}"}, + {"A{`b-c`: 1}"}, + {"A{`b c`: 1}"}, + {"A{`b/c`: 1}"}, + {"A{`b.c`: 1}"}, + {"A{`in`: 1}"}, + {"has(a.`b/c`)"}, + // Unsupported quoted identifiers. + TestInfo::ErrorCase("a.`b\tc`"), + TestInfo::ErrorCase("a.`@foo`"), + TestInfo::ErrorCase("a.`$foo`"), + TestInfo::ErrorCase("`a.b`"), + TestInfo::ErrorCase("`a.b`()"), + TestInfo::ErrorCase("foo.`a.b`()"), + // Macro calls tests + {"x.filter(y, y.filter(z, z > 0))"}, + {"has(a.b).filter(c, c)"}, + {"x.filter(y, y.exists(z, has(z.a)) && y.exists(z, has(z.b)))"}, + {"has(a.b).asList().exists(c, c)"}, + TestInfo::ErrorCase("b'\\UFFFFFFFF'"), + {"a.?b[?0] && a[?c]"}, + {"{?'key': value}"}, + {"[?a, ?b]"}, + {"[?a[?b]]"}, + {"Msg{?field: value}"}, + {"m.optMap(v, f)"}, + {"m.optFlatMap(v, f)"}}; + return *kInstance; +} + +class BenchmarkCaseTest : public testing::TestWithParam {}; + +TEST_P(BenchmarkCaseTest, ExpectedResult) { + std::vector macros = Macro::AllMacros(); + macros.push_back(cel::OptMapMacro()); + macros.push_back(cel::OptFlatMapMacro()); + const TestInfo& test_info = GetParam(); + ParserOptions options; + options.enable_optional_syntax = true; + options.enable_quoted_identifiers = true; + + auto result = EnrichedParse(test_info.expr, macros, "", options); + switch (test_info.result) { + case ParseResult::kSuccess: + ASSERT_THAT(result, IsOk()); + break; + case ParseResult::kError: + ASSERT_THAT(result, Not(IsOk())); + break; + } +} + +INSTANTIATE_TEST_SUITE_P(CelParserTest, BenchmarkCaseTest, + testing::ValuesIn(GetTestCases())); + +// This is not a proper microbenchmark, but is used to check for major +// regressions in the ANTLR generated code or concurrency issues. Each benchmark +// iteration parses all of the basic test cases from the unit-tests. +void BM_Parse(benchmark::State& state) { + std::vector macros = Macro::AllMacros(); + macros.push_back(cel::OptMapMacro()); + macros.push_back(cel::OptFlatMapMacro()); + ParserOptions options; + options.enable_optional_syntax = true; + options.enable_quoted_identifiers = true; + for (auto s : state) { + for (const auto& test_case : GetTestCases()) { + auto result = ParseWithMacros(test_case.expr, macros, "", options); + ABSL_DCHECK_EQ(result.ok(), test_case.result == ParseResult::kSuccess); + benchmark::DoNotOptimize(result); + } + } +} + +BENCHMARK(BM_Parse)->ThreadRange(1, std::thread::hardware_concurrency()); + +} // namespace +} // namespace google::api::expr::parser diff --git a/parser/parser_interface.h b/parser/parser_interface.h new file mode 100644 index 000000000..ad6e8ca84 --- /dev/null +++ b/parser/parser_interface.h @@ -0,0 +1,139 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#ifndef THIRD_PARTY_CEL_CPP_PARSER_PARSER_INTERFACE_H_ +#define THIRD_PARTY_CEL_CPP_PARSER_PARSER_INTERFACE_H_ + +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/functional/any_invocable.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "common/ast.h" +#include "common/source.h" +#include "parser/macro.h" +#include "parser/options.h" + +namespace cel { + +class Parser; +class ParserBuilder; + +// Callable for configuring a ParserBuilder. +using ParserBuilderConfigurer = + absl::AnyInvocable; + +struct ParserLibrary { + // Optional identifier to avoid collisions re-adding the same macros. If + // empty, it is not considered for collision detection. + std::string id; + ParserBuilderConfigurer configure; +}; + +// Declares a subset of a parser library. +struct ParserLibrarySubset { + // The id of the library to subset. Only one subset can be applied per + // library id. + // + // Must be non-empty. + std::string library_id; + + using MacroPredicate = absl::AnyInvocable; + MacroPredicate should_include_macro; +}; + +// Interface for building a CEL parser, see comments on `Parser` below. +class ParserBuilder { + public: + virtual ~ParserBuilder() = default; + + // Returns the (mutable) current parser options. + virtual ParserOptions& GetOptions() = 0; + + // Adds a macro to the parser. + // Standard macros should be automatically added based on parser options. + virtual absl::Status AddMacro(const cel::Macro& macro) = 0; + + virtual absl::Status AddLibrary(ParserLibrary library) = 0; + + virtual absl::Status AddLibrarySubset(ParserLibrarySubset subset) = 0; + + // Builds a new parser instance, may error if incompatible macros are added. + virtual absl::StatusOr> Build() = 0; +}; + +// Information about a parse failure. +class ParseIssue { + public: + explicit ParseIssue(std::string message) : message_(std::move(message)) {} + ParseIssue(SourceLocation location, std::string message) + : location_(location), message_(std::move(message)) {} + + ParseIssue(const ParseIssue& other) = default; + ParseIssue& operator=(const ParseIssue& other) = default; + ParseIssue(ParseIssue&& other) = default; + ParseIssue& operator=(ParseIssue&& other) = default; + + SourceLocation location() const { return location_; } + absl::string_view message() const { return message_; } + + private: + SourceLocation location_; + std::string message_; +}; + +// Interface for stateful CEL parser objects for use with a `Compiler` +// (bundled parse and type check). This is not needed for most users: +// prefer using the free functions in `parser.h` for more flexibility. +class Parser { + public: + virtual ~Parser() = default; + + // Parses the given source into a CEL AST. + absl::StatusOr> Parse( + const cel::Source& source) const; + + // Parses the given source into a CEL AST, collecting parse errors in + // `issues`. If `issues` is non-null, it will be cleared and all parse + // issues will be appended to it. + absl::StatusOr> Parse( + const cel::Source& source, std::vector* issues) const; + + // Returns a builder initialized with the configuration of this parser. + virtual std::unique_ptr ToBuilder() const = 0; + + protected: + virtual absl::StatusOr> ParseImpl( + const cel::Source& source, + std::vector* absl_nullable parse_issues) const = 0; +}; + +inline absl::StatusOr> Parser::Parse( + const cel::Source& source) const { + return ParseImpl(source, nullptr); +} + +inline absl::StatusOr> Parser::Parse( + const cel::Source& source, std::vector* issues) const { + if (issues != nullptr) issues->clear(); + return ParseImpl(source, issues); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_PARSER_PARSER_INTERFACE_H_ diff --git a/parser/parser_subset_factory.cc b/parser/parser_subset_factory.cc new file mode 100644 index 000000000..fb72a950a --- /dev/null +++ b/parser/parser_subset_factory.cc @@ -0,0 +1,54 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "parser/parser_subset_factory.h" + +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "parser/macro.h" +#include "parser/parser_interface.h" + +namespace cel { + +cel::ParserLibrarySubset::MacroPredicate IncludeMacrosByNamePredicate( + absl::flat_hash_set macro_names) { + return [macro_names_set = std::move(macro_names)](const Macro& macro) { + return macro_names_set.contains(macro.function()); + }; +} + +cel::ParserLibrarySubset::MacroPredicate IncludeMacrosByNamePredicate( + absl::Span macro_names) { + return IncludeMacrosByNamePredicate( + absl::flat_hash_set(macro_names.begin(), macro_names.end())); +} + +cel::ParserLibrarySubset::MacroPredicate ExcludeMacrosByNamePredicate( + absl::flat_hash_set macro_names) { + return [macro_names_set = std::move(macro_names)](const Macro& macro) { + return !macro_names_set.contains(macro.function()); + }; +} + +cel::ParserLibrarySubset::MacroPredicate ExcludeMacrosByNamePredicate( + absl::Span macro_names) { + return ExcludeMacrosByNamePredicate( + absl::flat_hash_set(macro_names.begin(), macro_names.end())); +} + +} // namespace cel diff --git a/parser/parser_subset_factory.h b/parser/parser_subset_factory.h new file mode 100644 index 000000000..87ee74f99 --- /dev/null +++ b/parser/parser_subset_factory.h @@ -0,0 +1,41 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_PARSER_PARSER_SUBSET_FACTORY_H_ +#define THIRD_PARTY_CEL_CPP_PARSER_PARSER_SUBSET_FACTORY_H_ + +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "parser/parser_interface.h" + +namespace cel { + +// Predicate that only includes the given macro by name. +cel::ParserLibrarySubset::MacroPredicate IncludeMacrosByNamePredicate( + absl::flat_hash_set macro_names); +cel::ParserLibrarySubset::MacroPredicate IncludeMacrosByNamePredicate( + absl::Span macro_names); + +// Predicate that excludes the given macros by name. +cel::ParserLibrarySubset::MacroPredicate ExcludeMacrosByNamePredicate( + absl::flat_hash_set macro_names); +cel::ParserLibrarySubset::MacroPredicate ExcludeMacrosByNamePredicate( + absl::Span macro_names); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_PARSER_PARSER_SUBSET_FACTORY_H_ diff --git a/parser/parser_test.cc b/parser/parser_test.cc index 1e0b35d4b..35f11b413 100644 --- a/parser/parser_test.cc +++ b/parser/parser_test.cc @@ -1,34 +1,63 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #include "parser/parser.h" -#include -#include +#include #include +#include #include #include -#include "gmock/gmock.h" -#include "gtest/gtest.h" +#include "cel/expr/syntax.pb.h" +#include "absl/algorithm/container.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/strings/ascii.h" +#include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "absl/types/optional.h" +#include "common/constant.h" +#include "common/expr.h" +#include "common/source.h" +#include "internal/testing.h" +#include "parser/macro.h" +#include "parser/options.h" +#include "parser/parser_interface.h" #include "parser/source_factory.h" #include "testutil/expr_printer.h" -namespace google { -namespace api { -namespace expr { -namespace parser { +namespace google::api::expr::parser { + namespace { -using ::google::api::expr::v1alpha1::Expr; -using ::google::api::expr::v1alpha1::ParsedExpr; -using testing::Not; +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::cel::ConstantKindCase; +using ::cel::ExprKindCase; +using ::cel::test::ExprPrinter; +using ::cel::expr::Expr; +using ::testing::HasSubstr; +using ::testing::Not; struct TestInfo { TestInfo(const std::string& I, const std::string& P, const std::string& E = "", const std::string& L = "", - const std::string& R = "") - : I(I), P(P), E(E), L(L), R(R) {} + const std::string& R = "", const std::string& M = "") + : I(I), P(P), E(E), L(L), R(R), M(M) {} // I contains the input expression to be parsed. std::string I; @@ -45,6 +74,9 @@ struct TestInfo { // R contains the expected enriched source info output of the expression tree. std::string R; + + // M contains the expected macro call output of hte expression tree. + std::string M; }; std::vector test_cases = { @@ -62,7 +94,7 @@ std::vector test_cases = { {"x * 2.0", "_*_(\n" " x^#1:Expr.Ident#,\n" - " 2.^#3:double#\n" + " 2.0^#3:double#\n" ")^#2:Expr.Call#"}, {"\"\\u2764\"", "\"\u2764\"^#1:string#"}, {"\"\u2764\"", "\"\u2764\"^#1:string#"}, @@ -85,9 +117,9 @@ std::vector test_cases = { ")^#2:Expr.Call#"}, {"SomeMessage{foo: 5, bar: \"xyz\"}", "SomeMessage{\n" - " foo:5^#4:int64#^#3:Expr.CreateStruct.Entry#,\n" - " bar:\"xyz\"^#6:string#^#5:Expr.CreateStruct.Entry#\n" - "}^#2:Expr.CreateStruct#"}, + " foo:5^#3:int64#^#2:Expr.CreateStruct.Entry#,\n" + " bar:\"xyz\"^#5:string#^#4:Expr.CreateStruct.Entry#\n" + "}^#1:Expr.CreateStruct#"}, {"[3, 4, 5]", "[\n" " 3^#2:int64#,\n" @@ -124,7 +156,8 @@ std::vector test_cases = { {"{", "", "ERROR: :1:2: Syntax error: mismatched input '' expecting " "{'[', " - "'{', '}', '(', '.', '-', '!', 'true', 'false', 'null', NUM_FLOAT, " + "'{', '}', '(', '.', ',', '-', '!', '\\u003F', 'true', 'false', 'null', " + "NUM_FLOAT, " "NUM_INT, " "NUM_UINT, STRING, BYTES, IDENTIFIER}\n | {\n" " | .^"}, @@ -138,6 +171,9 @@ std::vector test_cases = { {"0u", "0u^#1:uint64#"}, {"23u", "23u^#1:uint64#"}, {"24u", "24u^#1:uint64#"}, + {"0xAu", "10u^#1:uint64#"}, + {"-0xA", "-10^#1:int64#"}, + {"0xA", "10^#1:int64#"}, {"-1", "-1^#1:int64#"}, {"4--4", "_-_(\n" @@ -302,16 +338,16 @@ std::vector test_cases = { " a^#1:Expr.Ident#,\n" " b^#3:Expr.Ident#\n" ")^#2:Expr.Call#"}, - {"foo{ }", "foo{}^#2:Expr.CreateStruct#"}, + {"foo{ }", "foo{}^#1:Expr.CreateStruct#"}, {"foo{ a:b }", "foo{\n" - " a:b^#4:Expr.Ident#^#3:Expr.CreateStruct.Entry#\n" - "}^#2:Expr.CreateStruct#"}, + " a:b^#3:Expr.Ident#^#2:Expr.CreateStruct.Entry#\n" + "}^#1:Expr.CreateStruct#"}, {"foo{ a:b, c:d }", "foo{\n" - " a:b^#4:Expr.Ident#^#3:Expr.CreateStruct.Entry#,\n" - " c:d^#6:Expr.Ident#^#5:Expr.CreateStruct.Entry#\n" - "}^#2:Expr.CreateStruct#"}, + " a:b^#3:Expr.Ident#^#2:Expr.CreateStruct.Entry#,\n" + " c:d^#5:Expr.Ident#^#4:Expr.CreateStruct.Entry#\n" + "}^#1:Expr.CreateStruct#"}, {"{}", "{}^#1:Expr.CreateStruct#"}, {"{a:b, c:d}", "{\n" @@ -367,15 +403,15 @@ std::vector test_cases = { // Parse error tests {"*@a | b", "", - "ERROR: :1:2: Syntax error: token recognition error at: '@'\n" - " | *@a | b\n" - " | .^\n" "ERROR: :1:1: Syntax error: extraneous input '*' expecting {'[', " "'{', " "'(', '.', '-', '!', 'true', 'false', 'null', NUM_FLOAT, NUM_INT, " "NUM_UINT, STRING, BYTES, IDENTIFIER}\n" " | *@a | b\n" " | ^\n" + "ERROR: :1:2: Syntax error: token recognition error at: '@'\n" + " | *@a | b\n" + " | .^\n" "ERROR: :1:5: Syntax error: token recognition error at: '| '\n" " | *@a | b\n" " | ....^\n" @@ -396,20 +432,20 @@ std::vector test_cases = { "ERROR: :1:2: Syntax error: mismatched input '' expecting " "{'[', '{', '(', '.', '-', '!', 'true', 'false', 'null', NUM_FLOAT, " "NUM_INT, NUM_UINT, STRING, BYTES, IDENTIFIER}\n | ?\n | .^\n" - "ERROR: :4294967295:0: <> parsetree\n | \n | ^"}, + "ERROR: :4294967295:0: <> parsetree"}, {"t{>C}", "", "ERROR: :1:3: Syntax error: extraneous input '>' expecting {'}', " - "IDENTIFIER}\n | t{>C}\n | ..^\nERROR: :1:5: Syntax error: " + "',', '\\u003F', IDENTIFIER, ESC_IDENTIFIER}\n | t{>C}\n | ..^\nERROR: " + ":1:5: " + "Syntax error: " "mismatched input '}' expecting ':'\n | t{>C}\n | ....^"}, // Macro tests - { - "has(m.f)", - "m^#2:Expr.Ident#.f~test-only~^#4:Expr.Select#", - "", - "m^#2[1,4]#.f~test-only~^#4[1,3]#", - "[1,3,3]^#[2,4,4]^#[3,5,5]^#[4,3,3]", - }, + {"has(m.f)", "m^#2:Expr.Ident#.f~test-only~^#4:Expr.Select#", "", + "m^#2[1,4]#.f~test-only~^#4[1,3]#", "[2,4,4]^#[3,5,5]^#[4,3,3]", + "has(\n" + " m^#2:Expr.Ident#.f^#3:Expr.Select#\n" + ")^#4:has"}, {"m.exists_one(v, f)", "__comprehension__(\n" " // Variable\n" @@ -417,28 +453,30 @@ std::vector test_cases = { " // Target\n" " m^#1:Expr.Ident#,\n" " // Accumulator\n" - " __result__,\n" + " @result,\n" " // Init\n" " 0^#5:int64#,\n" " // LoopCondition\n" - " _<=_(\n" - " __result__^#7:Expr.Ident#,\n" - " 1^#6:int64#\n" - " )^#8:Expr.Call#,\n" + " true^#6:bool#,\n" " // LoopStep\n" " _?_:_(\n" " f^#4:Expr.Ident#,\n" " _+_(\n" - " __result__^#9:Expr.Ident#,\n" - " 1^#6:int64#\n" - " )^#10:Expr.Call#,\n" - " __result__^#11:Expr.Ident#\n" - " )^#12:Expr.Call#,\n" + " @result^#7:Expr.Ident#,\n" + " 1^#8:int64#\n" + " )^#9:Expr.Call#,\n" + " @result^#10:Expr.Ident#\n" + " )^#11:Expr.Call#,\n" " // Result\n" " _==_(\n" - " __result__^#13:Expr.Ident#,\n" - " 1^#6:int64#\n" - " )^#14:Expr.Call#)^#15:Expr.Comprehension#"}, + " @result^#12:Expr.Ident#,\n" + " 1^#13:int64#\n" + " )^#14:Expr.Call#)^#15:Expr.Comprehension#", + "", "", "", + "m^#1:Expr.Ident#.exists_one(\n" + " v^#3:Expr.Ident#,\n" + " f^#4:Expr.Ident#\n" + ")^#15:exists_one"}, {"m.map(v, f)", "__comprehension__(\n" " // Variable\n" @@ -446,20 +484,25 @@ std::vector test_cases = { " // Target\n" " m^#1:Expr.Ident#,\n" " // Accumulator\n" - " __result__,\n" + " @result,\n" " // Init\n" - " []^#6:Expr.CreateList#,\n" + " []^#5:Expr.CreateList#,\n" " // LoopCondition\n" - " true^#7:bool#,\n" + " true^#6:bool#,\n" " // LoopStep\n" " _+_(\n" - " __result__^#5:Expr.Ident#,\n" + " @result^#7:Expr.Ident#,\n" " [\n" " f^#4:Expr.Ident#\n" " ]^#8:Expr.CreateList#\n" " )^#9:Expr.Call#,\n" " // Result\n" - " __result__^#5:Expr.Ident#)^#10:Expr.Comprehension#"}, + " @result^#10:Expr.Ident#)^#11:Expr.Comprehension#", + "", "", "", + "m^#1:Expr.Ident#.map(\n" + " v^#3:Expr.Ident#,\n" + " f^#4:Expr.Ident#\n" + ")^#11:map"}, {"m.map(v, p, f)", "__comprehension__(\n" " // Variable\n" @@ -467,24 +510,30 @@ std::vector test_cases = { " // Target\n" " m^#1:Expr.Ident#,\n" " // Accumulator\n" - " __result__,\n" + " @result,\n" " // Init\n" - " []^#7:Expr.CreateList#,\n" + " []^#6:Expr.CreateList#,\n" " // LoopCondition\n" - " true^#8:bool#,\n" + " true^#7:bool#,\n" " // LoopStep\n" " _?_:_(\n" " p^#4:Expr.Ident#,\n" " _+_(\n" - " __result__^#6:Expr.Ident#,\n" + " @result^#8:Expr.Ident#,\n" " [\n" " f^#5:Expr.Ident#\n" " ]^#9:Expr.CreateList#\n" " )^#10:Expr.Call#,\n" - " __result__^#6:Expr.Ident#\n" - " )^#11:Expr.Call#,\n" + " @result^#11:Expr.Ident#\n" + " )^#12:Expr.Call#,\n" " // Result\n" - " __result__^#6:Expr.Ident#)^#12:Expr.Comprehension#"}, + " @result^#13:Expr.Ident#)^#14:Expr.Comprehension#", + "", "", "", + "m^#1:Expr.Ident#.map(\n" + " v^#3:Expr.Ident#,\n" + " p^#4:Expr.Ident#,\n" + " f^#5:Expr.Ident#\n" + ")^#14:map"}, {"m.filter(v, p)", "__comprehension__(\n" " // Variable\n" @@ -492,24 +541,29 @@ std::vector test_cases = { " // Target\n" " m^#1:Expr.Ident#,\n" " // Accumulator\n" - " __result__,\n" + " @result,\n" " // Init\n" - " []^#6:Expr.CreateList#,\n" + " []^#5:Expr.CreateList#,\n" " // LoopCondition\n" - " true^#7:bool#,\n" + " true^#6:bool#,\n" " // LoopStep\n" " _?_:_(\n" " p^#4:Expr.Ident#,\n" " _+_(\n" - " __result__^#5:Expr.Ident#,\n" + " @result^#7:Expr.Ident#,\n" " [\n" " v^#3:Expr.Ident#\n" " ]^#8:Expr.CreateList#\n" " )^#9:Expr.Call#,\n" - " __result__^#5:Expr.Ident#\n" - " )^#10:Expr.Call#,\n" + " @result^#10:Expr.Ident#\n" + " )^#11:Expr.Call#,\n" " // Result\n" - " __result__^#5:Expr.Ident#)^#11:Expr.Comprehension#"}, + " @result^#12:Expr.Ident#)^#13:Expr.Comprehension#", + "", "", "", + "m^#1:Expr.Ident#.filter(\n" + " v^#3:Expr.Ident#,\n" + " p^#4:Expr.Ident#\n" + ")^#13:filter"}, // Tests from Java parser {"[] + [1,2,3,] + [4]", @@ -533,13 +587,13 @@ std::vector test_cases = { "}^#1:Expr.CreateStruct#"}, {"TestAllTypes{single_int32: 1, single_int64: 2}", "TestAllTypes{\n" - " single_int32:1^#4:int64#^#3:Expr.CreateStruct.Entry#,\n" - " single_int64:2^#6:int64#^#5:Expr.CreateStruct.Entry#\n" - "}^#2:Expr.CreateStruct#"}, + " single_int32:1^#3:int64#^#2:Expr.CreateStruct.Entry#,\n" + " single_int64:2^#5:int64#^#4:Expr.CreateStruct.Entry#\n" + "}^#1:Expr.CreateStruct#"}, {"TestAllTypes(){single_int32: 1, single_int64: 2}", "", - "ERROR: :1:13: expected a qualified name\n" + "ERROR: :1:15: Syntax error: mismatched input '{' expecting \n" " | TestAllTypes(){single_int32: 1, single_int64: 2}\n" - " | ............^"}, + " | ..............^"}, {"size(x) == x.size()", "_==_(\n" " size(\n" @@ -574,9 +628,30 @@ std::vector test_cases = { " 0^#6:int64#\n" ")^#5:Expr.Call#"}, {"1.all(2, 3)", "", - "ERROR: :1:7: argument must be a simple name\n" + "ERROR: :1:7: all() variable name must be a simple identifier\n" " | 1.all(2, 3)\n" " | ......^"}, + {"[].all(.x, x)", "", + "ERROR: :1:9: all() variable name must be a simple identifier\n" + " | [].all(.x, x)\n" + " | ........^"}, + {"[].exists(.x, x)", "", + "ERROR: :1:12: exists() variable name must be a simple identifier\n" + " | [].exists(.x, x)\n" + " | ...........^"}, + {"[].exists_one(.x, x)", "", + "ERROR: :1:16: exists_one() variable name must be a simple " + "identifier\n" + " | [].exists_one(.x, x)\n" + " | ...............^"}, + {"[].map(.x, x, x)", "", + "ERROR: :1:9: map() variable name must be a simple identifier\n" + " | [].map(.x, x, x)\n" + " | ........^"}, + {"[].filter(.x, x)", "", + "ERROR: :1:12: filter() variable name must be a simple identifier\n" + " | [].filter(.x, x)\n" + " | ...........^"}, {"x[\"a\"].single_int32 == 23", "_==_(\n" " _[_](\n" @@ -653,8 +728,8 @@ std::vector test_cases = { " \"def\"^#3:string#\n" ")^#2:Expr.Call#"}, {"{\"a\": 1}.\"a\"", "", - "ERROR: :1:10: Syntax error: mismatched input '\"a\"' " - "expecting IDENTIFIER\n" + "ERROR: :1:10: Syntax error: no viable alternative at input " + "'.\"a\"'\n" " | {\"a\": 1}.\"a\"\n" " | .........^"}, {"\"\\xC3\\XBF\"", "\"ÿ\"^#1:string#"}, @@ -662,7 +737,7 @@ std::vector test_cases = { {"\"hi\\u263A \\u263Athere\"", "\"hi☺ ☺there\"^#1:string#"}, {"\"\\U000003A8\\?\"", "\"Ψ?\"^#1:string#"}, {"\"\\a\\b\\f\\n\\r\\t\\v'\\\"\\\\\\? Legal escapes\"", - "\"\\a\\b\\f\\n\\r\\t\\v'\\\"\\? Legal escapes\"^#1:string#"}, + "\"\\x07\\x08\\x0c\\n\\r\\t\\x0b'\\\"\\\\? Legal escapes\"^#1:string#"}, {"\"\\xFh\"", "", "ERROR: :1:1: Syntax error: token recognition error at: '\"\\xFh'\n" " | \"\\xFh\"\n" @@ -699,6 +774,33 @@ std::vector test_cases = { " \"😦\"^#6:string#\n" " ]^#3:Expr.CreateList#\n" ")^#2:Expr.Call#"}, + {"'\u00ff' in ['\u00ff', '\u00ff', '\u00ff']", + "@in(\n" + " \"\u00ff\"^#1:string#,\n" + " [\n" + " \"\u00ff\"^#4:string#,\n" + " \"\u00ff\"^#5:string#,\n" + " \"\u00ff\"^#6:string#\n" + " ]^#3:Expr.CreateList#\n" + ")^#2:Expr.Call#"}, + {"'\u00ff' in ['\uffff', '\U00100000', '\U0010ffff']", + "@in(\n" + " \"\u00ff\"^#1:string#,\n" + " [\n" + " \"\uffff\"^#4:string#,\n" + " \"\U00100000\"^#5:string#,\n" + " \"\U0010ffff\"^#6:string#\n" + " ]^#3:Expr.CreateList#\n" + ")^#2:Expr.Call#"}, + {"'\u00ff' in ['\U00100000', '\uffff', '\U0010ffff']", + "@in(\n" + " \"\u00ff\"^#1:string#,\n" + " [\n" + " \"\U00100000\"^#4:string#,\n" + " \"\uffff\"^#5:string#,\n" + " \"\U0010ffff\"^#6:string#\n" + " ]^#3:Expr.CreateList#\n" + ")^#2:Expr.Call#"}, {"'😁' in ['😁', '😑', '😦']\n" " && in.😁", "", @@ -709,10 +811,10 @@ std::vector test_cases = { " | ......^\n" "ERROR: :2:10: Syntax error: token recognition error at: '😁'\n" " | && in.😁\n" - " | .........^\n" - "ERROR: :2:11: Syntax error: missing IDENTIFIER at ''\n" + " | .........^\n" + "ERROR: :2:11: Syntax error: no viable alternative at input '.'\n" " | && in.😁\n" - " | ..........^"}, + " | ..........^"}, {"as", "", "ERROR: :1:1: reserved identifier: as\n" " | as\n" @@ -797,15 +899,15 @@ std::vector test_cases = { "ERROR: :1:15: reserved identifier: var\n" " | [1, 2, 3].map(var, var * var)\n" " | ..............^\n" + "ERROR: :1:15: map() variable name must be a simple identifier\n" + " | [1, 2, 3].map(var, var * var)\n" + " | ..............^\n" "ERROR: :1:20: reserved identifier: var\n" " | [1, 2, 3].map(var, var * var)\n" " | ...................^\n" "ERROR: :1:26: reserved identifier: var\n" " | [1, 2, 3].map(var, var * var)\n" - " | .........................^\n" - "ERROR: :1:15: argument is not an identifier\n" - " | [1, 2, 3].map(var, var * var)\n" - " | ..............^"}, + " | .........................^"}, {"[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" "[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" "[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" @@ -814,61 +916,519 @@ std::vector test_cases = { "]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]" "]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]" "]]]]]]", - "", "Expression recursion limit exceeded. limit: 250"}, + "", "Expression recursion limit exceeded. limit: 32", "", "", ""}, { // Note, the ANTLR parse stack may recurse much more deeply and permit // more detailed expressions than the visitor can recurse over in // practice. "[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[['just fine'],[1],[2],[3],[4],[5]]]]]]]" "]]]]]]]]]]]]]]]]]]]]]]]]", - "" // parse output not validated as it is too large. - }}; + "", // parse output not validated as it is too large. + "", + "", + "", + "", + }, + { + "[\n\t\r[\n\t\r[\n\t\r]\n\t\r]\n\t\r", + "", // parse output not validated as it is too large. + "ERROR: :6:3: Syntax error: mismatched input '' expecting " + "{']', ','}\n" + " | \r\n" + " | ..^", + }, -class KindAndIdAdorner : public testutil::ExpressionAdorner { + // Identifier quoting syntax tests. + {"a.`b`", "a^#1:Expr.Ident#.b^#2:Expr.Select#"}, + {"a.`b-c`", "a^#1:Expr.Ident#.b-c^#2:Expr.Select#"}, + {"a.`b c`", "a^#1:Expr.Ident#.b c^#2:Expr.Select#"}, + {"a.`b/c`", "a^#1:Expr.Ident#.b/c^#2:Expr.Select#"}, + {"a.`b.c`", "a^#1:Expr.Ident#.b.c^#2:Expr.Select#"}, + {"a.`in`", "a^#1:Expr.Ident#.in^#2:Expr.Select#"}, + {"A{`b`: 1}", + "A{\n" + " b:1^#3:int64#^#2:Expr.CreateStruct.Entry#\n" + "}^#1:Expr.CreateStruct#"}, + {"A{`b-c`: 1}", + "A{\n" + " b-c:1^#3:int64#^#2:Expr.CreateStruct.Entry#\n" + "}^#1:Expr.CreateStruct#"}, + {"A{`b c`: 1}", + "A{\n" + " b c:1^#3:int64#^#2:Expr.CreateStruct.Entry#\n" + "}^#1:Expr.CreateStruct#"}, + {"A{`b/c`: 1}", + "A{\n" + " b/c:1^#3:int64#^#2:Expr.CreateStruct.Entry#\n" + "}^#1:Expr.CreateStruct#"}, + {"A{`b.c`: 1}", + "A{\n" + " b.c:1^#3:int64#^#2:Expr.CreateStruct.Entry#\n" + "}^#1:Expr.CreateStruct#"}, + {"A{`in`: 1}", + "A{\n" + " in:1^#3:int64#^#2:Expr.CreateStruct.Entry#\n" + "}^#1:Expr.CreateStruct#"}, + {"has(a.`b/c`)", "a^#2:Expr.Ident#.b/c~test-only~^#4:Expr.Select#"}, + // Unsupported quoted identifiers. + {"a.`b\tc`", "", + "ERROR: :1:3: Syntax error: token recognition error at: '`b\\t'\n" + " | a.`b c`\n" + " | ..^\n" + "ERROR: :1:7: Syntax error: token recognition error at: '`'\n" + " | a.`b c`\n" + " | ......^"}, + {"a.`@foo`", "", + "ERROR: :1:3: Syntax error: token recognition error at: '`@'\n" + " | a.`@foo`\n" + " | ..^\n" + "ERROR: :1:8: Syntax error: token recognition error at: '`'\n" + " | a.`@foo`\n" + " | .......^"}, + {"a.`$foo`", "", + "ERROR: :1:3: Syntax error: token recognition error at: '`$'\n" + " | a.`$foo`\n" + " | ..^\n" + "ERROR: :1:8: Syntax error: token recognition error at: '`'\n" + " | a.`$foo`\n" + " | .......^"}, + {"`a.b`", "", + "ERROR: :1:1: Syntax error: mismatched input '`a.b`' expecting " + "{'[', '{', " + "'(', '.', '-', '!', 'true', 'false', 'null', NUM_FLOAT, NUM_INT, " + "NUM_UINT, STRING, " + "BYTES, IDENTIFIER}\n" + " | `a.b`\n" + " | ^"}, + {"`a.b`()", "", + "ERROR: :1:1: Syntax error: extraneous input '`a.b`' expecting " + "{'[', '{', '(', '.', '-', '!', 'true', 'false', 'null', NUM_FLOAT, " + "NUM_INT, NUM_UINT, STRING, BYTES, IDENTIFIER}\n" + " | `a.b`()\n" + " | ^\n" + "ERROR: :1:7: Syntax error: mismatched input ')' expecting {'[', " + "'{', '(', '.', '-', '!', 'true', 'false', 'null', NUM_FLOAT, NUM" + "_INT, NUM_UINT, STRING, BYTES, IDENTIFIER}\n" + " | `a.b`()\n" + " | ......^"}, + {"foo.`a.b`()", "", + "ERROR: :1:10: Syntax error: mismatched input '(' expecting \n" + " | foo.`a.b`()\n" + " | .........^"}, + + // Macro calls tests + {"x.filter(y, y.filter(z, z > 0))", + "__comprehension__(\n" + " // Variable\n" + " y,\n" + " // Target\n" + " x^#1:Expr.Ident#,\n" + " // Accumulator\n" + " @result,\n" + " // Init\n" + " []^#19:Expr.CreateList#,\n" + " // LoopCondition\n" + " true^#20:bool#,\n" + " // LoopStep\n" + " _?_:_(\n" + " __comprehension__(\n" + " // Variable\n" + " z,\n" + " // Target\n" + " y^#4:Expr.Ident#,\n" + " // Accumulator\n" + " @result,\n" + " // Init\n" + " []^#10:Expr.CreateList#,\n" + " // LoopCondition\n" + " true^#11:bool#,\n" + " // LoopStep\n" + " _?_:_(\n" + " _>_(\n" + " z^#7:Expr.Ident#,\n" + " 0^#9:int64#\n" + " )^#8:Expr.Call#,\n" + " _+_(\n" + " @result^#12:Expr.Ident#,\n" + " [\n" + " z^#6:Expr.Ident#\n" + " ]^#13:Expr.CreateList#\n" + " )^#14:Expr.Call#,\n" + " @result^#15:Expr.Ident#\n" + " )^#16:Expr.Call#,\n" + " // Result\n" + " @result^#17:Expr.Ident#)^#18:Expr.Comprehension#,\n" + " _+_(\n" + " @result^#21:Expr.Ident#,\n" + " [\n" + " y^#3:Expr.Ident#\n" + " ]^#22:Expr.CreateList#\n" + " )^#23:Expr.Call#,\n" + " @result^#24:Expr.Ident#\n" + " )^#25:Expr.Call#,\n" + " // Result\n" + " @result^#26:Expr.Ident#)^#27:Expr.Comprehension#" + "", + "", "", "", + "x^#1:Expr.Ident#.filter(\n" + " y^#3:Expr.Ident#,\n" + " ^#18:filter#\n" + ")^#27:filter#,\n" + "y^#4:Expr.Ident#.filter(\n" + " z^#6:Expr.Ident#,\n" + " _>_(\n" + " z^#7:Expr.Ident#,\n" + " 0^#9:int64#\n" + " )^#8:Expr.Call#\n" + ")^#18:filter"}, + {"has(a.b).filter(c, c)", + "__comprehension__(\n" + " // Variable\n" + " c,\n" + " // Target\n" + " a^#2:Expr.Ident#.b~test-only~^#4:Expr.Select#,\n" + " // Accumulator\n" + " @result,\n" + " // Init\n" + " []^#8:Expr.CreateList#,\n" + " // LoopCondition\n" + " true^#9:bool#,\n" + " // LoopStep\n" + " _?_:_(\n" + " c^#7:Expr.Ident#,\n" + " _+_(\n" + " @result^#10:Expr.Ident#,\n" + " [\n" + " c^#6:Expr.Ident#\n" + " ]^#11:Expr.CreateList#\n" + " )^#12:Expr.Call#,\n" + " @result^#13:Expr.Ident#\n" + " )^#14:Expr.Call#,\n" + " // Result\n" + " @result^#15:Expr.Ident#)^#16:Expr.Comprehension#", + "", "", "", + "^#4:has#.filter(\n" + " c^#6:Expr.Ident#,\n" + " c^#7:Expr.Ident#\n" + ")^#16:filter#,\n" + "has(\n" + " a^#2:Expr.Ident#.b^#3:Expr.Select#\n" + ")^#4:has"}, + {"x.filter(y, y.exists(z, has(z.a)) && y.exists(z, has(z.b)))", + "__comprehension__(\n" + " // Variable\n" + " y,\n" + " // Target\n" + " x^#1:Expr.Ident#,\n" + " // Accumulator\n" + " @result,\n" + " // Init\n" + " []^#35:Expr.CreateList#,\n" + " // LoopCondition\n" + " true^#36:bool#,\n" + " // LoopStep\n" + " _?_:_(\n" + " _&&_(\n" + " __comprehension__(\n" + " // Variable\n" + " z,\n" + " // Target\n" + " y^#4:Expr.Ident#,\n" + " // Accumulator\n" + " @result,\n" + " // Init\n" + " false^#11:bool#,\n" + " // LoopCondition\n" + " @not_strictly_false(\n" + " !_(\n" + " @result^#12:Expr.Ident#\n" + " )^#13:Expr.Call#\n" + " )^#14:Expr.Call#,\n" + " // LoopStep\n" + " _||_(\n" + " @result^#15:Expr.Ident#,\n" + " z^#8:Expr.Ident#.a~test-only~^#10:Expr.Select#\n" + " )^#16:Expr.Call#,\n" + " // Result\n" + " @result^#17:Expr.Ident#)^#18:Expr.Comprehension#,\n" + " __comprehension__(\n" + " // Variable\n" + " z,\n" + " // Target\n" + " y^#19:Expr.Ident#,\n" + " // Accumulator\n" + " @result,\n" + " // Init\n" + " false^#26:bool#,\n" + " // LoopCondition\n" + " @not_strictly_false(\n" + " !_(\n" + " @result^#27:Expr.Ident#\n" + " )^#28:Expr.Call#\n" + " )^#29:Expr.Call#,\n" + " // LoopStep\n" + " _||_(\n" + " @result^#30:Expr.Ident#,\n" + " z^#23:Expr.Ident#.b~test-only~^#25:Expr.Select#\n" + " )^#31:Expr.Call#,\n" + " // Result\n" + " @result^#32:Expr.Ident#)^#33:Expr.Comprehension#\n" + " )^#34:Expr.Call#,\n" + " _+_(\n" + " @result^#37:Expr.Ident#,\n" + " [\n" + " y^#3:Expr.Ident#\n" + " ]^#38:Expr.CreateList#\n" + " )^#39:Expr.Call#,\n" + " @result^#40:Expr.Ident#\n" + " )^#41:Expr.Call#,\n" + " // Result\n" + " @result^#42:Expr.Ident#)^#43:Expr.Comprehension#", + "", "", "", + "x^#1:Expr.Ident#.filter(\n" + " y^#3:Expr.Ident#,\n" + " _&&_(\n" + " ^#18:exists#,\n" + " ^#33:exists#\n" + " )^#34:Expr.Call#\n" + ")^#43:filter#,\n" + "y^#19:Expr.Ident#.exists(\n" + " z^#21:Expr.Ident#,\n" + " ^#25:has#\n" + ")^#33:exists#,\n" + "has(\n" + " z^#23:Expr.Ident#.b^#24:Expr.Select#\n" + ")^#25:has#,\n" + "y^#4:Expr.Ident#." + "exists(\n" + " z^#6:Expr.Ident#,\n" + " ^#10:has#\n" + ")^#18:exists#,\n" + "has(\n" + " z^#8:Expr.Ident#.a^#9:Expr.Select#\n" + ")^#10:has"}, + {"has(a.b).asList().exists(c, c)", + "__comprehension__(\n" + " // Variable\n" + " c,\n" + " // Target\n" + " a^#2:Expr.Ident#.b~test-only~^#4:Expr.Select#.asList()^#5:Expr.Call#,\n" + " // Accumulator\n" + " @result,\n" + " // Init\n" + " false^#9:bool#,\n" + " // LoopCondition\n" + " @not_strictly_false(\n" + " !_(\n" + " @result^#10:Expr.Ident#\n" + " )^#11:Expr.Call#\n" + " )^#12:Expr.Call#,\n" + " // LoopStep\n" + " _||_(\n" + " @result^#13:Expr.Ident#,\n" + " c^#8:Expr.Ident#\n" + " )^#14:Expr.Call#,\n" + " // Result\n" + " @result^#15:Expr.Ident#)^#16:Expr.Comprehension#", + "", "", "", + "^#4:has#.asList()^#5:Expr.Call#.exists(\n" + " c^#7:Expr.Ident#,\n" + " c^#8:Expr.Ident#\n" + ")^#16:exists#,\n" + "has(\n" + " a^#2:Expr.Ident#.b^#3:Expr.Select#\n" + ")^#4:has"}, + {"[has(a.b), has(c.d)].exists(e, e)", + "__comprehension__(\n" + " // Variable\n" + " e,\n" + " // Target\n" + " [\n" + " a^#3:Expr.Ident#.b~test-only~^#5:Expr.Select#,\n" + " c^#7:Expr.Ident#.d~test-only~^#9:Expr.Select#\n" + " ]^#1:Expr.CreateList#,\n" + " // Accumulator\n" + " @result,\n" + " // Init\n" + " false^#13:bool#,\n" + " // LoopCondition\n" + " @not_strictly_false(\n" + " !_(\n" + " @result^#14:Expr.Ident#\n" + " )^#15:Expr.Call#\n" + " )^#16:Expr.Call#,\n" + " // LoopStep\n" + " _||_(\n" + " @result^#17:Expr.Ident#,\n" + " e^#12:Expr.Ident#\n" + " )^#18:Expr.Call#,\n" + " // Result\n" + " @result^#19:Expr.Ident#)^#20:Expr.Comprehension#", + "", "", "", + "[\n" + " ^#5:has#,\n" + " ^#9:has#\n" + "]^#1:Expr.CreateList#.exists(\n" + " e^#11:Expr.Ident#,\n" + " e^#12:Expr.Ident#\n" + ")^#20:exists#,\n" + "has(\n" + " c^#7:Expr.Ident#.d^#8:Expr.Select#\n" + ")^#9:has#,\n" + "has(\n" + " a^#3:Expr.Ident#.b^#4:Expr.Select#\n" + ")^#5:has"}, + {"b'\\UFFFFFFFF'", "", + "ERROR: :1:1: Invalid bytes literal: Illegal escape sequence: " + "Unicode escape sequence \\U cannot be used in bytes literals\n | " + "b'\\UFFFFFFFF'\n | ^"}, + {"a.?b[?0] && a[?c]", + "_&&_(\n _[?_](\n _?._(\n a^#1:Expr.Ident#,\n " + "\"b\"^#3:string#\n )^#2:Expr.Call#,\n 0^#5:int64#\n " + ")^#4:Expr.Call#,\n _[?_](\n a^#6:Expr.Ident#,\n " + "c^#8:Expr.Ident#\n )^#7:Expr.Call#\n)^#9:Expr.Call#"}, + {"{?'key': value}", + "{\n " + "?\"key\"^#3:string#:value^#4:Expr.Ident#^#2:Expr.CreateStruct.Entry#\n}^#" + "1:Expr.CreateStruct#"}, + {"[?a, ?b]", + "[\n ?a^#2:Expr.Ident#,\n ?b^#3:Expr.Ident#\n]^#1:Expr.CreateList#"}, + {"[?a[?b]]", + "[\n ?_[?_](\n a^#2:Expr.Ident#,\n b^#4:Expr.Ident#\n " + ")^#3:Expr.Call#\n]^#1:Expr.CreateList#"}, + {"Msg{?field: value}", + "Msg{\n " + "?field:value^#3:Expr.Ident#^#2:Expr.CreateStruct.Entry#\n}^#1:Expr." + "CreateStruct#"}, + {"m.optMap(v, f)", + "_?_:_(\n m^#1:Expr.Ident#.hasValue()^#6:Expr.Call#,\n optional.of(\n " + " __comprehension__(\n // Variable\n #unused,\n // " + "Target\n []^#7:Expr.CreateList#,\n // Accumulator\n v,\n " + " // Init\n m^#5:Expr.Ident#.value()^#8:Expr.Call#,\n // " + "LoopCondition\n false^#9:bool#,\n // LoopStep\n " + "v^#3:Expr.Ident#,\n // Result\n " + "f^#4:Expr.Ident#)^#10:Expr.Comprehension#\n )^#11:Expr.Call#,\n " + "optional.none()^#12:Expr.Call#\n)^#13:Expr.Call#"}, + {"m.optFlatMap(v, f)", + "_?_:_(\n m^#1:Expr.Ident#.hasValue()^#6:Expr.Call#,\n " + "__comprehension__(\n // Variable\n #unused,\n // Target\n " + "[]^#7:Expr.CreateList#,\n // Accumulator\n v,\n // Init\n " + "m^#5:Expr.Ident#.value()^#8:Expr.Call#,\n // LoopCondition\n " + "false^#9:bool#,\n // LoopStep\n v^#3:Expr.Ident#,\n // Result\n " + " f^#4:Expr.Ident#)^#10:Expr.Comprehension#,\n " + "optional.none()^#11:Expr.Call#\n)^#12:Expr.Call#"}}; + +absl::string_view ConstantKind(const cel::Constant& c) { + switch (c.kind_case()) { + case ConstantKindCase::kBool: + return "bool"; + case ConstantKindCase::kInt: + return "int64"; + case ConstantKindCase::kUint: + return "uint64"; + case ConstantKindCase::kDouble: + return "double"; + case ConstantKindCase::kString: + return "string"; + case ConstantKindCase::kBytes: + return "bytes"; + case ConstantKindCase::kNull: + return "NullValue"; + default: + return "unspecified_constant"; + } +} + +absl::string_view ExprKind(const cel::Expr& e) { + switch (e.kind_case()) { + case ExprKindCase::kConstant: + // special cased, this doesn't appear. + return "Expr.Constant"; + case ExprKindCase::kIdentExpr: + return "Expr.Ident"; + case ExprKindCase::kSelectExpr: + return "Expr.Select"; + case ExprKindCase::kCallExpr: + return "Expr.Call"; + case ExprKindCase::kListExpr: + return "Expr.CreateList"; + case ExprKindCase::kMapExpr: + case ExprKindCase::kStructExpr: + return "Expr.CreateStruct"; + case ExprKindCase::kComprehensionExpr: + return "Expr.Comprehension"; + default: + return "unspecified_expr"; + } +} + +class KindAndIdAdorner : public cel::test::ExpressionAdorner { public: - std::string adorn(const Expr& e) const override { + // Use default source_info constructor to make source_info "optional". This + // will prevent macro_calls lookups from interfering with adorning expressions + // that don't need to use macro_calls, such as the parsed AST. + explicit KindAndIdAdorner( + const cel::expr::SourceInfo& source_info = + cel::expr::SourceInfo::default_instance()) + : source_info_(source_info) {} + + std::string Adorn(const cel::Expr& e) const override { + // source_info_ might be empty on non-macro_calls tests + if (source_info_.macro_calls_size() != 0 && + source_info_.macro_calls().contains(e.id())) { + return absl::StrFormat( + "^#%d:%s#", e.id(), + source_info_.macro_calls().at(e.id()).call_expr().function()); + } + if (e.has_const_expr()) { auto& const_expr = e.const_expr(); - auto reflection = const_expr.GetReflection(); - auto oneof = const_expr.GetDescriptor()->FindOneofByName("constant_kind"); - auto field_desc = reflection->GetOneofFieldDescriptor(const_expr, oneof); - auto enum_desc = field_desc->enum_type(); - if (enum_desc) { - return absl::StrFormat("^#%d:%s#", e.id(), nameChain(enum_desc)); - } else { - return absl::StrFormat("^#%d:%s#", e.id(), field_desc->type_name()); - } + return absl::StrCat("^#", e.id(), ":", ConstantKind(const_expr), "#"); } else { - auto reflection = e.GetReflection(); - auto oneof = e.GetDescriptor()->FindOneofByName("expr_kind"); - auto desc = reflection->GetOneofFieldDescriptor(e, oneof)->message_type(); - return absl::StrFormat("^#%d:%s#", e.id(), nameChain(desc)); + return absl::StrCat("^#", e.id(), ":", ExprKind(e), "#"); } } - std::string adorn(const Expr::CreateStruct::Entry& e) const override { + std::string AdornStructField(const cel::StructExprField& e) const override { return absl::StrFormat("^#%d:Expr.CreateStruct.Entry#", e.id()); } - private: - template - std::string nameChain(const T* descriptor) const { - std::list name_chain{descriptor->name()}; - const google::protobuf::Descriptor* desc = descriptor->containing_type(); - while (desc) { - name_chain.push_front(desc->name()); - desc = desc->containing_type(); - } - return absl::StrJoin(name_chain, "."); + std::string AdornMapEntry(const cel::MapExprEntry& e) const override { + return absl::StrFormat("^#%d:Expr.CreateStruct.Entry#", e.id()); } + + private: + const cel::expr::SourceInfo& source_info_; }; -class LocationAdorner : public testutil::ExpressionAdorner { +class LocationAdorner : public cel::test::ExpressionAdorner { public: - LocationAdorner(const google::api::expr::v1alpha1::SourceInfo& source_info) + explicit LocationAdorner(const cel::expr::SourceInfo& source_info) : source_info_(source_info) {} - absl::optional> getLocation(int64_t id) const { + std::string Adorn(const cel::Expr& e) const override { + return LocationToString(e.id()); + } + + std::string AdornStructField(const cel::StructExprField& e) const override { + return LocationToString(e.id()); + } + + std::string AdornMapEntry(const cel::MapExprEntry& e) const override { + return LocationToString(e.id()); + } + + private: + std::string LocationToString(int64_t id) const { + auto loc = GetLocation(id); + if (loc) { + return absl::StrFormat("^#%d[%d,%d]#", id, loc->first, loc->second); + } else { + return absl::StrFormat("^#%d[NO_POS]#", id); + } + } + + absl::optional> GetLocation(int64_t id) const { absl::optional> location; const auto& positions = source_info_.positions(); if (positions.find(id) == positions.end()) { @@ -891,38 +1451,7 @@ class LocationAdorner : public testutil::ExpressionAdorner { return std::make_pair(line, col); } - std::string adorn(const Expr& e) const override { - auto loc = getLocation(e.id()); - if (loc) { - return absl::StrFormat("^#%d[%d,%d]#", e.id(), loc->first, loc->second); - } else { - return absl::StrFormat("^#%d[NO_POS]#", e.id()); - } - } - - std::string adorn(const Expr::CreateStruct::Entry& e) const override { - auto loc = getLocation(e.id()); - if (loc) { - return absl::StrFormat("^#%d[%d,%d]#", e.id(), loc->first, loc->second); - } else { - return absl::StrFormat("^#%d[NO_POS]#", e.id()); - } - } - - private: - template - std::string nameChain(const T* descriptor) const { - std::list name_chain{descriptor->name()}; - const google::protobuf::Descriptor* desc = descriptor->containing_type(); - while (desc) { - name_chain.push_front(desc->name()); - desc = desc->containing_type(); - } - return absl::StrJoin(name_chain, "."); - } - - private: - const google::api::expr::v1alpha1::SourceInfo& source_info_; + const cel::expr::SourceInfo& source_info_; }; std::string ConvertEnrichedSourceInfoToString( @@ -935,46 +1464,449 @@ std::string ConvertEnrichedSourceInfoToString( return absl::StrJoin(offsets, "^#"); } +std::string ConvertMacroCallsToString( + const cel::expr::SourceInfo& source_info) { + KindAndIdAdorner macro_calls_adorner(source_info); + ExprPrinter w(macro_calls_adorner); + // Use a list so we can sort the macro calls ensuring order for appending + std::vector> macro_calls; + for (auto pair : source_info.macro_calls()) { + // Set ID to the map key for the adorner + pair.second.set_id(pair.first); + macro_calls.push_back(pair); + } + // Sort in reverse because the first macro will have the highest id + absl::c_sort(macro_calls, + [](const std::pair& p1, + const std::pair& p2) { + return p1.first > p2.first; + }); + std::string result = ""; + for (const auto& pair : macro_calls) { + result += w.PrintProto(pair.second) += ",\n"; + } + // substring last ",\n" + return result.substr(0, result.size() - 3); +} + class ExpressionTest : public testing::TestWithParam {}; TEST_P(ExpressionTest, Parse) { const TestInfo& test_info = GetParam(); + ParserOptions options; + if (!test_info.M.empty()) { + options.add_macro_calls = true; + } + options.enable_optional_syntax = true; + options.enable_quoted_identifiers = true; - auto result = EnrichedParse(test_info.I, Macro::AllMacros()); + std::vector macros = Macro::AllMacros(); + macros.push_back(cel::OptMapMacro()); + macros.push_back(cel::OptFlatMapMacro()); + auto result = EnrichedParse(test_info.I, macros, "", options); if (test_info.E.empty()) { - EXPECT_TRUE(result.ok()); + ASSERT_THAT(result, IsOk()); } else { - EXPECT_FALSE(result.ok()); - EXPECT_EQ(result.status().message(), test_info.E); + EXPECT_THAT(result, Not(IsOk())); + EXPECT_EQ(test_info.E, result.status().message()); } if (!test_info.P.empty()) { KindAndIdAdorner kind_and_id_adorner; - testutil::ExprPrinter w(kind_and_id_adorner); - std::string adorned_string = w.print(result.value().parsed_expr().expr()); - EXPECT_EQ(test_info.P, adorned_string); + ExprPrinter w(kind_and_id_adorner); + std::string adorned_string = w.PrintProto(result->parsed_expr().expr()); + EXPECT_EQ(test_info.P, adorned_string) + << result->parsed_expr().ShortDebugString(); } if (!test_info.L.empty()) { - LocationAdorner location_adorner( - result.value().parsed_expr().source_info()); - testutil::ExprPrinter w(location_adorner); - std::string adorned_string = w.print(result.value().parsed_expr().expr()); - EXPECT_EQ(test_info.L, adorned_string); + LocationAdorner location_adorner(result->parsed_expr().source_info()); + ExprPrinter w(location_adorner); + std::string adorned_string = w.PrintProto(result->parsed_expr().expr()); + EXPECT_EQ(test_info.L, adorned_string) + << result->parsed_expr().ShortDebugString(); + ; } if (!test_info.R.empty()) { - EXPECT_EQ(ConvertEnrichedSourceInfoToString( - result.value().enriched_source_info()), - test_info.R); + EXPECT_EQ(test_info.R, ConvertEnrichedSourceInfoToString( + result->enriched_source_info())); + } + + if (!test_info.M.empty()) { + EXPECT_EQ(test_info.M, ConvertMacroCallsToString( + result.value().parsed_expr().source_info())) + << result->parsed_expr().ShortDebugString(); + ; } } +TEST(ExpressionTest, CompositeExpressionOffsets) { + ParserOptions options; + std::vector macros = Macro::AllMacros(); + + std::string list_expr = "[1, 2]"; + auto list_result = EnrichedParse(list_expr, macros, "", options); + ASSERT_THAT(list_result, IsOk()); + auto list_offsets = list_result->enriched_source_info().offsets(); + EXPECT_EQ(list_offsets.at(1), std::make_pair(0, 5)); + + std::string map_expr = "{'a': 1}"; + auto map_result = EnrichedParse(map_expr, macros, "", options); + ASSERT_THAT(map_result, IsOk()); + auto map_offsets = map_result->enriched_source_info().offsets(); + EXPECT_EQ(map_offsets.at(1), std::make_pair(0, 7)); + + std::string msg_expr = "Msg{f: 1}"; + auto msg_result = EnrichedParse(msg_expr, macros, "", options); + ASSERT_THAT(msg_result, IsOk()); + auto msg_offsets = msg_result->enriched_source_info().offsets(); + EXPECT_EQ(msg_offsets.at(1), std::make_pair(0, 8)); +} + +TEST(ExpressionTest, TsanOom) { + Parse( + "[[a([[???[a[[??[a([[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" + "[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" + "[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" + "[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" + "[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" + "[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" + "[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" + "[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" + "[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" + "[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" + "[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" + "[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" + "[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" + "[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[???[" + "a([[????") + .IgnoreError(); +} + +TEST(ExpressionTest, ErrorRecoveryLimits) { + ParserOptions options; + options.error_recovery_limit = 1; + auto result = Parse("......", "", options); + EXPECT_THAT(result, Not(IsOk())); + EXPECT_EQ(result.status().message(), + "ERROR: :1:1: Syntax error: More than 1 parse errors.\n | ......\n " + "| ^\nERROR: :1:2: Syntax error: no viable alternative at input " + "'..'\n | ......\n | .^"); +} + +TEST(ExpressionTest, ExpressionSizeLimit) { + ParserOptions options; + options.expression_size_codepoint_limit = 10; + auto result = Parse("...............", "", options); + EXPECT_THAT(result, Not(IsOk())); + EXPECT_EQ( + result.status().message(), + "expression size exceeds codepoint limit. input size: 15, limit: 10"); +} + +TEST(ExpressionTest, RecursionDepthLongArgList) { + ParserOptions options; + // The particular number here is an implementation detail: the underlying + // visitor will recurse up to 8 times before branching to the create list or + // const steps. The call graph looks something like: + // visit->visitStart->visit->visitExpr->visit->visitOr->visit->visitAnd->visit + // ->visitRelation->visit->visitCalc->visit->visitUnary->visit->visitPrimary + // ->visitCreateList->visit[arg]->visitExpr... + // The expected max depth for create list with an arbitrary number of elements + // is 15. + options.max_recursion_depth = 16; + + EXPECT_THAT(Parse("[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]", "", options), IsOk()); +} + +TEST(ExpressionTest, RecursionDepthExceeded) { + ParserOptions options; + // AST visitor will recurse a variable amount depending on the terms used in + // the expression. This check occurs in the business logic converting the raw + // Antlr parse tree into an Expr. There is a separate check (via a custom + // listener) for AST depth while running the antlr generated parser. + options.max_recursion_depth = 6; + auto result = Parse("1 + 2 + 3 + 4 + 5 + 6 + 7", "", options); + + EXPECT_THAT(result, Not(IsOk())); + EXPECT_THAT(result.status().message(), + HasSubstr("Exceeded max recursion depth of 6 when parsing.")); +} + +TEST(ExpressionTest, DisableQuotedIdentifiers) { + ParserOptions options; + options.enable_quoted_identifiers = false; + auto result = Parse("foo.`bar`", "", options); + + EXPECT_THAT(result, Not(IsOk())); + EXPECT_THAT(result.status().message(), + HasSubstr("ERROR: :1:5: unsupported syntax '`'\n" + " | foo.`bar`\n" + " | ....^")); +} + +TEST(ExpressionTest, DisableStandardMacros) { + ParserOptions options; + options.disable_standard_macros = true; + + auto result = Parse("has(foo.bar)", "", options); + + ASSERT_THAT(result, IsOk()); + KindAndIdAdorner kind_and_id_adorner; + ExprPrinter w(kind_and_id_adorner); + std::string adorned_string = w.PrintProto(result->expr()); + EXPECT_EQ(adorned_string, + "has(\n" + " foo^#2:Expr.Ident#.bar^#3:Expr.Select#\n" + ")^#1:Expr.Call#") + << adorned_string; +} + +TEST(ExpressionTest, RecursionDepthIgnoresParentheses) { + ParserOptions options; + options.max_recursion_depth = 6; + auto result = Parse("(((1 + 2 + 3 + 4 + (5 + 6))))", "", options); + + EXPECT_THAT(result, IsOk()); +} + +TEST(NewParserBuilderTest, Defaults) { + auto builder = cel::NewParserBuilder(); + ASSERT_OK_AND_ASSIGN(auto parser, std::move(*builder).Build()); + + ASSERT_OK_AND_ASSIGN(auto source, + cel::NewSource("has(a.b) && [].exists(x, x > 0)")); + ASSERT_OK_AND_ASSIGN(auto ast, parser->Parse(*source)); + + EXPECT_FALSE(ast->IsChecked()); +} + +TEST(NewParserBuilderTest, CustomMacros) { + auto builder = cel::NewParserBuilder(); + builder->GetOptions().disable_standard_macros = true; + ASSERT_THAT(builder->AddMacro(cel::HasMacro()), IsOk()); + ASSERT_OK_AND_ASSIGN(auto parser, std::move(*builder).Build()); + builder.reset(); + + ASSERT_OK_AND_ASSIGN(auto source, cel::NewSource("has(a.b) && [].map(x, x)")); + ASSERT_OK_AND_ASSIGN(auto ast, parser->Parse(*source)); + + EXPECT_FALSE(ast->IsChecked()); + KindAndIdAdorner kind_and_id_adorner; + ExprPrinter w(kind_and_id_adorner); + EXPECT_EQ(w.Print(ast->root_expr()), + "_&&_(\n" + " a^#2:Expr.Ident#.b~test-only~^#4:Expr.Select#,\n" + " []^#5:Expr.CreateList#.map(\n" + " x^#7:Expr.Ident#,\n" + " x^#8:Expr.Ident#\n" + " )^#6:Expr.Call#\n" + ")^#9:Expr.Call#"); +} + +TEST(NewParserBuilderTest, StandardMacrosNotAddedWithStdlib) { + auto builder = cel::NewParserBuilder(); + builder->GetOptions().disable_standard_macros = false; + // Add a fake stdlib to check that we don't try to add the standard macros + // again. Emulates what happens when we add support for subsetting stdlib by + // ids. + ASSERT_THAT(builder->AddLibrary({"stdlib", + [](cel::ParserBuilder& b) { + return b.AddMacro(cel::HasMacro()); + }}), + IsOk()); + ASSERT_OK_AND_ASSIGN(auto parser, std::move(*builder).Build()); + builder.reset(); + + ASSERT_OK_AND_ASSIGN(auto source, cel::NewSource("has(a.b) && [].map(x, x)")); + ASSERT_OK_AND_ASSIGN(auto ast, parser->Parse(*source)); + + EXPECT_FALSE(ast->IsChecked()); + KindAndIdAdorner kind_and_id_adorner; + ExprPrinter w(kind_and_id_adorner); + EXPECT_EQ(w.Print(ast->root_expr()), + "_&&_(\n" + " a^#2:Expr.Ident#.b~test-only~^#4:Expr.Select#,\n" + " []^#5:Expr.CreateList#.map(\n" + " x^#7:Expr.Ident#,\n" + " x^#8:Expr.Ident#\n" + " )^#6:Expr.Call#\n" + ")^#9:Expr.Call#"); +} + +TEST(NewParserBuilderTest, ForwardsOptions) { + auto builder = cel::NewParserBuilder(); + builder->GetOptions().enable_optional_syntax = true; + ASSERT_OK_AND_ASSIGN(auto parser, std::move(*builder).Build()); + ASSERT_OK_AND_ASSIGN(auto source, cel::NewSource("a.?b")); + ASSERT_OK_AND_ASSIGN(auto ast, parser->Parse(*source)); + EXPECT_FALSE(ast->IsChecked()); + + builder = cel::NewParserBuilder(); + builder->GetOptions().enable_optional_syntax = false; + ASSERT_OK_AND_ASSIGN(parser, std::move(*builder).Build()); + ASSERT_OK_AND_ASSIGN(source, cel::NewSource("a.?b")); + EXPECT_THAT(parser->Parse(*source), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(NewParserBuilderTest, ToBuilderCopiesConfig) { + auto builder = cel::NewParserBuilder(); + builder->GetOptions().enable_optional_syntax = true; + builder->GetOptions().disable_standard_macros = true; + ASSERT_THAT(builder->AddLibrary({"custom_lib", + [](cel::ParserBuilder& b) { + return b.AddMacro(cel::HasMacro()); + }}), + IsOk()); + ASSERT_OK_AND_ASSIGN(auto parser, std::move(*builder).Build()); + + auto derived_builder = parser->ToBuilder(); + EXPECT_TRUE(derived_builder->GetOptions().enable_optional_syntax); + + ASSERT_OK_AND_ASSIGN(auto derived_parser, + std::move(*derived_builder).Build()); + + ASSERT_OK_AND_ASSIGN(auto source, cel::NewSource("a.?b && has(a.b)")); + ASSERT_OK_AND_ASSIGN(auto ast, derived_parser->Parse(*source)); + EXPECT_FALSE(ast->IsChecked()); +} + +TEST(NewParserBuilderTest, ToBuilderHandlesStdlibAndOptionalByLibrary) { + auto builder = cel::NewParserBuilder(); + builder->GetOptions().disable_standard_macros = true; + builder->GetOptions().enable_optional_syntax = false; + + // Abusing the library ids for testing. Real uses should use subsetting. + ASSERT_THAT( + builder->AddLibrary( + {"stdlib", [](cel::ParserBuilder& b) { return absl::OkStatus(); }}), + IsOk()); + ASSERT_THAT( + builder->AddLibrary( + {"optional", [](cel::ParserBuilder& b) { return absl::OkStatus(); }}), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto parser, std::move(*builder).Build()); + + auto derived_builder = parser->ToBuilder(); + // Should be ignored now. + derived_builder->GetOptions().disable_standard_macros = false; + derived_builder->GetOptions().enable_optional_syntax = true; + + ASSERT_OK_AND_ASSIGN(auto derived_parser, + std::move(*derived_builder).Build()); + + ASSERT_OK_AND_ASSIGN(auto source, cel::NewSource("has(a.b)")); + ASSERT_OK_AND_ASSIGN(auto ast, derived_parser->Parse(*source)); + + KindAndIdAdorner kind_and_id_adorner; + ExprPrinter w(kind_and_id_adorner); + EXPECT_EQ(w.Print(ast->root_expr()), + "has(\n" + " a^#2:Expr.Ident#.b^#3:Expr.Select#\n" + ")^#1:Expr.Call#"); +} + +TEST(NewParserBuilderTest, ToBuilderPreservesStdlibAndOptionalFromOptions) { + auto builder = cel::NewParserBuilder(); + builder->GetOptions().disable_standard_macros = false; + builder->GetOptions().enable_optional_syntax = true; + + ASSERT_OK_AND_ASSIGN(auto parser, std::move(*builder).Build()); + + auto derived_builder = parser->ToBuilder(); + + ASSERT_OK_AND_ASSIGN(auto derived_parser, + std::move(*derived_builder).Build()); + + ASSERT_OK_AND_ASSIGN(auto source, cel::NewSource("has(a.b) && [?a]")); + ASSERT_OK_AND_ASSIGN(auto ast, derived_parser->Parse(*source)); + EXPECT_FALSE(ast->IsChecked()); +} + +struct VariadicLogicalOperatorsTestCase { + std::string input; + std::string expected_adorned_string; +}; + +class VariadicLogicalOperatorsTest + : public testing::TestWithParam {}; + +TEST_P(VariadicLogicalOperatorsTest, Parse) { + const auto& test_case = GetParam(); + auto builder = cel::NewParserBuilder(); + builder->GetOptions().enable_variadic_logical_operators = true; + ASSERT_OK_AND_ASSIGN(auto parser, std::move(*builder).Build()); + ASSERT_OK_AND_ASSIGN(auto source, cel::NewSource(test_case.input)); + ASSERT_OK_AND_ASSIGN(auto ast, parser->Parse(*source)); + + KindAndIdAdorner kind_and_id_adorner; + ExprPrinter w(kind_and_id_adorner); + std::string adorned_string = w.Print(ast->root_expr()); + EXPECT_EQ(adorned_string, test_case.expected_adorned_string); +} + +INSTANTIATE_TEST_SUITE_P( + VariadicLogicalOperators, VariadicLogicalOperatorsTest, + testing::Values( + VariadicLogicalOperatorsTestCase{ + .input = "a && b && c && d", + .expected_adorned_string = "_&&_(\n" + " a^#1:Expr.Ident#,\n" + " b^#2:Expr.Ident#,\n" + " c^#4:Expr.Ident#,\n" + " d^#6:Expr.Ident#\n" + ")^#3:Expr.Call#"}, + VariadicLogicalOperatorsTestCase{ + .input = "a || b || c || d", + .expected_adorned_string = "_||_(\n" + " a^#1:Expr.Ident#,\n" + " b^#2:Expr.Ident#,\n" + " c^#4:Expr.Ident#,\n" + " d^#6:Expr.Ident#\n" + ")^#3:Expr.Call#"}, + VariadicLogicalOperatorsTestCase{ + .input = "a && b && (c || d || e)", + .expected_adorned_string = "_&&_(\n" + " a^#1:Expr.Ident#,\n" + " b^#2:Expr.Ident#,\n" + " _||_(\n" + " c^#4:Expr.Ident#,\n" + " d^#5:Expr.Ident#,\n" + " e^#7:Expr.Ident#\n" + " )^#6:Expr.Call#\n" + ")^#3:Expr.Call#"})); + +TEST(ParserTest, ParseFailurePopulatesIssues) { + auto builder = cel::NewParserBuilder(); + ASSERT_OK_AND_ASSIGN(auto parser, std::move(*builder).Build()); + + ASSERT_OK_AND_ASSIGN(auto source, cel::NewSource("a +", "test.cel")); + std::vector issues; + auto ast_result = parser->Parse(*source, &issues); + EXPECT_THAT(ast_result, Not(IsOk())); + ASSERT_THAT(issues, testing::SizeIs(1)); + EXPECT_THAT(ast_result.status().message(), + HasSubstr("ERROR: test.cel:1:4: Syntax error: mismatched input " + "'' expecting")); + EXPECT_THAT(issues[0].message(), + HasSubstr("Syntax error: mismatched input '' expecting")); + EXPECT_EQ(issues[0].location().line, 1); + // 0-based, but adjusted to 1-based in error message. + EXPECT_EQ(issues[0].location().column, 3); +} + +std::string TestName(const testing::TestParamInfo& test_info) { + std::string name = absl::StrCat(test_info.index, "-", test_info.param.I); + absl::c_replace_if(name, [](char c) { return !absl::ascii_isalnum(c); }, '_'); + return name; + return name; +} + INSTANTIATE_TEST_SUITE_P(CelParserTest, ExpressionTest, - testing::ValuesIn(test_cases)); + testing::ValuesIn(test_cases), TestName); } // namespace -} // namespace parser -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::parser diff --git a/parser/source_factory.cc b/parser/source_factory.cc deleted file mode 100644 index c8fea90ec..000000000 --- a/parser/source_factory.cc +++ /dev/null @@ -1,544 +0,0 @@ -#include "parser/source_factory.h" - -#include "google/protobuf/struct.pb.h" -#include "absl/memory/memory.h" -#include "absl/strings/numbers.h" -#include "absl/strings/str_format.h" -#include "absl/strings/str_join.h" -#include "absl/strings/str_split.h" -#include "common/operators.h" - -namespace google { -namespace api { -namespace expr { -namespace parser { -namespace { - -const int kMaxErrorsToReport = 100; - -using common::CelOperator; -using google::api::expr::v1alpha1::Expr; - -} // namespace - -SourceFactory::SourceFactory(const std::string& expression) - : next_id_(1), num_errors_(0) { - calcLineOffsets(expression); -} - -int64_t SourceFactory::id(const antlr4::Token* token) { - int64_t new_id = next_id_; - positions_.emplace( - new_id, SourceLocation{(int32_t)token->getLine(), - (int32_t)token->getCharPositionInLine(), - (int32_t)token->getStopIndex(), line_offsets_}); - next_id_ += 1; - return new_id; -} - -const SourceFactory::SourceLocation& SourceFactory::getSourceLocation( - int64_t id) const { - return positions_.at(id); -} - -const SourceFactory::SourceLocation SourceFactory::noLocation() { - return SourceLocation(-1, -1, -1, {}); -} - -int64_t SourceFactory::id(antlr4::ParserRuleContext* ctx) { - return id(ctx->getStart()); -} - -int64_t SourceFactory::id(const SourceLocation& location) { - int64_t new_id = next_id_; - positions_.emplace(new_id, location); - next_id_ += 1; - return new_id; -} - -int64_t SourceFactory::nextMacroId(int64_t macro_id) { - return id(getSourceLocation(macro_id)); -} - -Expr SourceFactory::newExpr(int64_t id) { - Expr expr; - expr.set_id(id); - return expr; -} - -Expr SourceFactory::newExpr(antlr4::ParserRuleContext* ctx) { - return newExpr(id(ctx)); -} - -Expr SourceFactory::newExpr(const antlr4::Token* token) { - return newExpr(id(token)); -} - -Expr SourceFactory::newGlobalCall(int64_t id, const std::string& function, - const std::vector& args) { - Expr expr = newExpr(id); - auto call_expr = expr.mutable_call_expr(); - call_expr->set_function(function); - std::for_each(args.begin(), args.end(), [&call_expr](const Expr& e) { - call_expr->add_args()->CopyFrom(e); - }); - return expr; -} - -Expr SourceFactory::newGlobalCallForMacro(int64_t macro_id, - const std::string& function, - const std::vector& args) { - return newGlobalCall(nextMacroId(macro_id), function, args); -} - -Expr SourceFactory::newReceiverCall(int64_t id, const std::string& function, - Expr& target, - const std::vector& args) { - Expr expr = newExpr(id); - auto call_expr = expr.mutable_call_expr(); - call_expr->set_function(function); - call_expr->mutable_target()->CopyFrom(target); - std::for_each(args.begin(), args.end(), [&call_expr](const Expr& e) { - call_expr->add_args()->CopyFrom(e); - }); - return expr; -} - -Expr SourceFactory::newIdent(const antlr4::Token* token, - const std::string& ident_name) { - Expr expr = newExpr(token); - expr.mutable_ident_expr()->set_name(ident_name); - return expr; -} - -Expr SourceFactory::newIdentForMacro(int64_t macro_id, - const std::string& ident_name) { - Expr expr = newExpr(nextMacroId(macro_id)); - expr.mutable_ident_expr()->set_name(ident_name); - return expr; -} - -Expr SourceFactory::newSelect( - ::cel_grammar::CelParser::SelectOrCallContext* ctx, Expr& operand, - const std::string& field) { - Expr expr = newExpr(ctx->op); - auto select_expr = expr.mutable_select_expr(); - select_expr->mutable_operand()->CopyFrom(operand); - select_expr->set_field(field); - return expr; -} - -Expr SourceFactory::newPresenceTestForMacro(int64_t macro_id, const Expr& operand, - const std::string& field) { - Expr expr = newExpr(nextMacroId(macro_id)); - auto select_expr = expr.mutable_select_expr(); - select_expr->mutable_operand()->CopyFrom(operand); - select_expr->set_field(field); - select_expr->set_test_only(true); - return expr; -} - -Expr SourceFactory::newObject( - int64_t obj_id, std::string type_name, - const std::vector& entries) { - auto expr = newExpr(obj_id); - auto struct_expr = expr.mutable_struct_expr(); - struct_expr->set_message_name(type_name); - std::for_each(entries.begin(), entries.end(), - [struct_expr](const Expr::CreateStruct::Entry& e) { - struct_expr->add_entries()->CopyFrom(e); - }); - return expr; -} - -Expr::CreateStruct::Entry SourceFactory::newObjectField( - int64_t field_id, const std::string& field, const Expr& value) { - Expr::CreateStruct::Entry entry; - entry.set_id(field_id); - entry.set_field_key(field); - entry.mutable_value()->CopyFrom(value); - return entry; -} - -Expr SourceFactory::newComprehension(int64_t id, const std::string& iter_var, - const Expr& iter_range, - const std::string& accu_var, - const Expr& accu_init, - const Expr& condition, const Expr& step, - const Expr& result) { - Expr expr = newExpr(id); - auto comp_expr = expr.mutable_comprehension_expr(); - comp_expr->set_iter_var(iter_var); - comp_expr->mutable_iter_range()->CopyFrom(iter_range); - comp_expr->set_accu_var(accu_var); - comp_expr->mutable_accu_init()->CopyFrom(accu_init); - comp_expr->mutable_loop_condition()->CopyFrom(condition); - comp_expr->mutable_loop_step()->CopyFrom(step); - comp_expr->mutable_result()->CopyFrom(result); - return expr; -} - -Expr SourceFactory::foldForMacro(int64_t macro_id, const std::string& iter_var, - const Expr& iter_range, - const std::string& accu_var, - const Expr& accu_init, const Expr& condition, - const Expr& step, const Expr& result) { - return newComprehension(nextMacroId(macro_id), iter_var, iter_range, accu_var, - accu_init, condition, step, result); -} - -Expr SourceFactory::newList(int64_t list_id, const std::vector& elems) { - auto expr = newExpr(list_id); - auto list_expr = expr.mutable_list_expr(); - std::for_each(elems.begin(), elems.end(), [list_expr](const Expr& e) { - list_expr->add_elements()->CopyFrom(e); - }); - return expr; -} - -Expr SourceFactory::newQuantifierExprForMacro( - SourceFactory::QuantifierKind kind, int64_t macro_id, Expr* target, - const std::vector& args) { - if (args.empty()) { - return Expr(); - } - if (!args[0].has_ident_expr()) { - auto loc = getSourceLocation(args[0].id()); - return reportError(loc, "argument must be a simple name"); - } - std::string v = args[0].ident_expr().name(); - - // traditional variable name assigned to the fold accumulator variable. - const std::string AccumulatorName = "__result__"; - - auto accu_ident = [this, ¯o_id, &AccumulatorName]() { - return newIdentForMacro(macro_id, AccumulatorName); - }; - - Expr init; - Expr condition; - Expr step; - Expr result; - switch (kind) { - case QUANTIFIER_ALL: - init = newLiteralBoolForMacro(macro_id, true); - condition = newGlobalCallForMacro( - macro_id, CelOperator::NOT_STRICTLY_FALSE, {accu_ident()}); - step = newGlobalCallForMacro(macro_id, CelOperator::LOGICAL_AND, - {accu_ident(), args[1]}); - result = accu_ident(); - break; - - case QUANTIFIER_EXISTS: - init = newLiteralBoolForMacro(macro_id, false); - condition = newGlobalCallForMacro( - macro_id, CelOperator::NOT_STRICTLY_FALSE, - {newGlobalCallForMacro(macro_id, CelOperator::LOGICAL_NOT, - {accu_ident()})}); - step = newGlobalCallForMacro(macro_id, CelOperator::LOGICAL_OR, - {accu_ident(), args[1]}); - result = accu_ident(); - break; - - case QUANTIFIER_EXISTS_ONE: { - Expr zero_expr = newLiteralIntForMacro(macro_id, 0); - Expr one_expr = newLiteralIntForMacro(macro_id, 1); - init = zero_expr; - condition = newGlobalCallForMacro(macro_id, CelOperator::LESS_EQUALS, - {accu_ident(), one_expr}); - step = newGlobalCallForMacro( - macro_id, CelOperator::CONDITIONAL, - {args[1], - newGlobalCallForMacro(macro_id, CelOperator::ADD, - {accu_ident(), one_expr}), - accu_ident()}); - result = newGlobalCallForMacro(macro_id, CelOperator::EQUALS, - {accu_ident(), one_expr}); - break; - } - } - return foldForMacro(macro_id, v, *target, AccumulatorName, init, condition, - step, result); -} - -Expr SourceFactory::newFilterExprForMacro(int64_t macro_id, Expr* target, - const std::vector& args) { - if (args.empty()) { - return Expr(); - } - if (!args[0].has_ident_expr()) { - auto loc = getSourceLocation(args[0].id()); - return reportError(loc, "argument is not an identifier"); - } - std::string v = args[0].ident_expr().name(); - - // traditional variable name assigned to the fold accumulator variable. - const std::string AccumulatorName = "__result__"; - - Expr filter = args[1]; - Expr accu_expr = newIdentForMacro(macro_id, AccumulatorName); - Expr init = newListForMacro(macro_id, {}); - Expr condition = newLiteralBoolForMacro(macro_id, true); - Expr step = - newGlobalCallForMacro(macro_id, CelOperator::ADD, - {accu_expr, newListForMacro(macro_id, {args[0]})}); - step = newGlobalCallForMacro(macro_id, CelOperator::CONDITIONAL, - {filter, step, accu_expr}); - return foldForMacro(macro_id, v, *target, AccumulatorName, init, condition, - step, accu_expr); -} - -Expr SourceFactory::newListForMacro(int64_t macro_id, - const std::vector& elems) { - return newList(nextMacroId(macro_id), elems); -} - -Expr SourceFactory::newMap( - int64_t map_id, const std::vector& entries) { - auto expr = newExpr(map_id); - auto struct_expr = expr.mutable_struct_expr(); - std::for_each(entries.begin(), entries.end(), - [struct_expr](const Expr::CreateStruct::Entry& e) { - struct_expr->add_entries()->CopyFrom(e); - }); - return expr; -} - -Expr SourceFactory::newMapForMacro(int64_t macro_id, Expr* target, - const std::vector& args) { - if (args.empty()) { - return Expr(); - } - if (!args[0].has_ident_expr()) { - auto loc = getSourceLocation(args[0].id()); - return reportError(loc, "argument is not an identifier"); - } - std::string v = args[0].ident_expr().name(); - - Expr fn; - Expr filter; - bool has_filter = false; - if (args.size() == 3) { - filter = args[1]; - has_filter = true; - fn = args[2]; - } else { - fn = args[1]; - } - - // traditional variable name assigned to the fold accumulator variable. - const std::string AccumulatorName = "__result__"; - - Expr accu_expr = newIdentForMacro(macro_id, AccumulatorName); - Expr init = newListForMacro(macro_id, {}); - Expr condition = newLiteralBoolForMacro(macro_id, true); - Expr step = newGlobalCallForMacro( - macro_id, CelOperator::ADD, {accu_expr, newListForMacro(macro_id, {fn})}); - if (has_filter) { - step = newGlobalCallForMacro(macro_id, CelOperator::CONDITIONAL, - {filter, step, accu_expr}); - } - return foldForMacro(macro_id, v, *target, AccumulatorName, init, condition, - step, accu_expr); -} - -Expr::CreateStruct::Entry SourceFactory::newMapEntry(int64_t entry_id, - const Expr& key, - const Expr& value) { - Expr::CreateStruct::Entry entry; - entry.set_id(entry_id); - entry.mutable_map_key()->CopyFrom(key); - entry.mutable_value()->CopyFrom(value); - return entry; -} - -Expr SourceFactory::newLiteralInt(antlr4::ParserRuleContext* ctx, int64_t value) { - Expr expr = newExpr(ctx); - expr.mutable_const_expr()->set_int64_value(value); - return expr; -} - -Expr SourceFactory::newLiteralIntForMacro(int64_t macro_id, int64_t value) { - Expr expr = newExpr(nextMacroId(macro_id)); - expr.mutable_const_expr()->set_int64_value(value); - return expr; -} - -Expr SourceFactory::newLiteralUint(antlr4::ParserRuleContext* ctx, - uint64_t value) { - Expr expr = newExpr(ctx); - expr.mutable_const_expr()->set_uint64_value(value); - return expr; -} - -Expr SourceFactory::newLiteralDouble(antlr4::ParserRuleContext* ctx, - double value) { - Expr expr = newExpr(ctx); - expr.mutable_const_expr()->set_double_value(value); - return expr; -} - -Expr SourceFactory::newLiteralString(antlr4::ParserRuleContext* ctx, - const std::string& s) { - Expr expr = newExpr(ctx); - expr.mutable_const_expr()->set_string_value(s); - return expr; -} - -Expr SourceFactory::newLiteralBytes(antlr4::ParserRuleContext* ctx, - const std::string& b) { - Expr expr = newExpr(ctx); - expr.mutable_const_expr()->set_bytes_value(b); - return expr; -} - -Expr SourceFactory::newLiteralBool(antlr4::ParserRuleContext* ctx, bool b) { - Expr expr = newExpr(ctx); - expr.mutable_const_expr()->set_bool_value(b); - return expr; -} - -Expr SourceFactory::newLiteralBoolForMacro(int64_t macro_id, bool b) { - Expr expr = newExpr(nextMacroId(macro_id)); - expr.mutable_const_expr()->set_bool_value(b); - return expr; -} - -Expr SourceFactory::newLiteralNull(antlr4::ParserRuleContext* ctx) { - Expr expr = newExpr(ctx); - expr.mutable_const_expr()->set_null_value(::google::protobuf::NULL_VALUE); - return expr; -} - -Expr SourceFactory::reportError(antlr4::ParserRuleContext* ctx, - const std::string& msg) { - num_errors_ += 1; - Expr expr = newExpr(ctx); - if (errors_truncated_.size() < kMaxErrorsToReport) { - errors_truncated_.emplace_back(msg, positions_.at(expr.id())); - } - return expr; -} - -Expr SourceFactory::reportError(int32_t line, int32_t col, const std::string& msg) { - num_errors_ += 1; - SourceLocation loc(line, col, /*offset_end=*/-1, line_offsets_); - if (errors_truncated_.size() < kMaxErrorsToReport) { - errors_truncated_.emplace_back(msg, loc); - } - return newExpr(id(loc)); -} - -Expr SourceFactory::reportError(const SourceFactory::SourceLocation& loc, - const std::string& msg) { - num_errors_ += 1; - if (errors_truncated_.size() < kMaxErrorsToReport) { - errors_truncated_.emplace_back(msg, loc); - } - return newExpr(id(loc)); -} - -std::string SourceFactory::errorMessage(const std::string& description, - const std::string& expression) const { - std::vector messages; - std::transform( - errors_truncated_.begin(), errors_truncated_.end(), - std::back_inserter(messages), - [this, description, expression](const SourceFactory::Error& error) { - std::string s = absl::StrFormat("ERROR: %s:%zu:%zu: %s", description, - error.location.line, - // add one to the 0-based column - error.location.col + 1, error.message); - std::string snippet = getSourceLine(error.location.line, expression); - std::string::size_type pos = 0; - while ((pos = snippet.find("\t", pos)) != std::string::npos) { - snippet.replace(pos, 1, " "); - } - std::string src_line = "\n | " + snippet; - std::string ind_line = "\n | "; - for (int i = 0; i < error.location.col; ++i) { - ind_line += "."; - } - ind_line += "^"; - s += src_line + ind_line; - return s; - }); - if (num_errors_ > kMaxErrorsToReport) { - messages.emplace_back(absl::StrCat(num_errors_ - kMaxErrorsToReport, - " more errors were truncated.")); - } - return absl::StrJoin(messages, "\n"); -} - -bool SourceFactory::isReserved(const std::string& ident_name) { - static std::vector reserved_words = { - "as", "break", "const", "continue", "else", "false", "for", - "function", "if", "import", "in", "let", "loop", "package", - "namespace", "null", "return", "true", "var", "void", "while"}; - return std::find(reserved_words.begin(), reserved_words.end(), ident_name) != - reserved_words.end(); -} - -google::api::expr::v1alpha1::SourceInfo SourceFactory::sourceInfo() const { - google::api::expr::v1alpha1::SourceInfo source_info; - source_info.set_location(""); - auto positions = source_info.mutable_positions(); - std::for_each(positions_.begin(), positions_.end(), - [positions](const std::pair& loc) { - positions->insert({loc.first, loc.second.offset}); - }); - std::for_each( - line_offsets_.begin(), line_offsets_.end(), - [&source_info](int32_t offset) { source_info.add_line_offsets(offset); }); - return source_info; -} - -EnrichedSourceInfo SourceFactory::enrichedSourceInfo() const { - std::map> offset; - std::for_each( - positions_.begin(), positions_.end(), - [&offset](const std::pair& loc) { - offset.insert({loc.first, {loc.second.offset, loc.second.offset_end}}); - }); - return EnrichedSourceInfo(offset); -} - -void SourceFactory::calcLineOffsets(const std::string& expression) { - std::vector lines = absl::StrSplit(expression, '\n'); - int offset = 0; - line_offsets_.resize(lines.size()); - for (size_t i = 0; i < lines.size(); ++i) { - offset += lines[i].size() + 1; - line_offsets_[i] = offset; - } -} - -absl::optional SourceFactory::findLineOffset(int32_t line) const { - // note that err.line is 1-based, - // while we need the 0-based index - if (line == 1) { - return 0; - } else if (line > 1 && line <= static_cast(line_offsets_.size())) { - return line_offsets_[line - 2]; - } - return {}; -} - -std::string SourceFactory::getSourceLine(int32_t line, - const std::string& expression) const { - auto char_start = findLineOffset(line); - if (!char_start) { - return ""; - } - auto char_end = findLineOffset(line + 1); - if (char_end) { - return expression.substr(*char_start, *char_end - *char_end - 1); - } else { - return expression.substr(*char_start); - } -} - -} // namespace parser -} // namespace expr -} // namespace api -} // namespace google diff --git a/parser/source_factory.h b/parser/source_factory.h index 79d766f45..501e1017a 100644 --- a/parser/source_factory.h +++ b/parser/source_factory.h @@ -1,26 +1,37 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #ifndef THIRD_PARTY_CEL_CPP_PARSER_SOURCE_FACTORY_H_ #define THIRD_PARTY_CEL_CPP_PARSER_SOURCE_FACTORY_H_ -#include +#include +#include #include -#include - -#include "google/api/expr/v1alpha1/syntax.pb.h" -#include "absl/types/optional.h" -#include "parser/cel_grammar.inc/cel_grammar/CelParser.h" -#include "antlr4-runtime.h" - -namespace google { -namespace api { -namespace expr { -namespace parser { -using google::api::expr::v1alpha1::Expr; +namespace google::api::expr::parser { class EnrichedSourceInfo { public: - EnrichedSourceInfo(const std::map>& offsets) - : offsets_(offsets) {} + explicit EnrichedSourceInfo( + std::map> offsets) + : offsets_(std::move(offsets)) {} + + EnrichedSourceInfo() = default; + EnrichedSourceInfo(const EnrichedSourceInfo& other) = default; + EnrichedSourceInfo& operator=(const EnrichedSourceInfo& other) = default; + EnrichedSourceInfo(EnrichedSourceInfo&& other) = default; + EnrichedSourceInfo& operator=(EnrichedSourceInfo&& other) = default; const std::map>& offsets() const { return offsets_; @@ -31,134 +42,6 @@ class EnrichedSourceInfo { std::map> offsets_; }; -// Provide tools to generate expressions during parsing. -// Keeps track of ID and source location information. -// Shares functionality with //third_party/cel/go/parser/helper.go -class SourceFactory { - public: - struct SourceLocation { - SourceLocation(int32_t line, int32_t col, int32_t offset_end, - const std::vector& line_offsets) - : line(line), col(col), offset_end(offset_end) { - if (line == 1) { - offset = col; - } else if (line > 1) { - offset = line_offsets[line - 2] + col; - } else { - offset = -1; - } - } - int32_t line; - int32_t col; - int32_t offset_end; - int32_t offset; - }; - - struct Error { - Error(std::string message, SourceLocation location) - : message(message), location(location) {} - std::string message; - SourceLocation location; - }; - - enum QuantifierKind { - QUANTIFIER_ALL, - QUANTIFIER_EXISTS, - QUANTIFIER_EXISTS_ONE - }; - - SourceFactory(const std::string& expression); - - int64_t id(const antlr4::Token* token); - int64_t id(antlr4::ParserRuleContext* ctx); - int64_t id(const SourceLocation& location); - - int64_t nextMacroId(int64_t macro_id); - - const SourceLocation& getSourceLocation(int64_t id) const; - - static const SourceLocation noLocation(); - - Expr newExpr(int64_t id); - Expr newExpr(antlr4::ParserRuleContext* ctx); - Expr newExpr(const antlr4::Token* token); - Expr newGlobalCall(int64_t id, const std::string& function, - const std::vector& args); - Expr newGlobalCallForMacro(int64_t macro_id, const std::string& function, - const std::vector& args); - Expr newReceiverCall(int64_t id, const std::string& function, Expr& target, - const std::vector& args); - Expr newIdent(const antlr4::Token* token, const std::string& ident_name); - Expr newIdentForMacro(int64_t macro_id, const std::string& ident_name); - Expr newSelect(::cel_grammar::CelParser::SelectOrCallContext* ctx, - Expr& operand, const std::string& field); - Expr newPresenceTestForMacro(int64_t macro_id, const Expr& operand, - const std::string& field); - Expr newObject(int64_t obj_id, std::string type_name, - const std::vector& entries); - Expr::CreateStruct::Entry newObjectField(int64_t field_id, - const std::string& field, - const Expr& value); - Expr newComprehension(int64_t id, const std::string& iter_var, - const Expr& iter_range, const std::string& accu_var, - const Expr& accu_init, const Expr& condition, - const Expr& step, const Expr& result); - - Expr foldForMacro(int64_t macro_id, const std::string& iter_var, - const Expr& iter_range, const std::string& accu_var, - const Expr& accu_init, const Expr& condition, - const Expr& step, const Expr& result); - Expr newQuantifierExprForMacro(QuantifierKind kind, int64_t macro_id, - Expr* target, const std::vector& args); - Expr newFilterExprForMacro(int64_t macro_id, Expr* target, - const std::vector& args); - - Expr newList(int64_t list_id, const std::vector& elems); - Expr newListForMacro(int64_t macro_id, const std::vector& elems); - Expr newMap(int64_t map_id, - const std::vector& entries); - Expr newMapForMacro(int64_t macro_id, Expr* target, - const std::vector& args); - Expr::CreateStruct::Entry newMapEntry(int64_t entry_id, const Expr& key, - const Expr& value); - Expr newLiteralInt(antlr4::ParserRuleContext* ctx, int64_t value); - Expr newLiteralIntForMacro(int64_t macro_id, int64_t value); - Expr newLiteralUint(antlr4::ParserRuleContext* ctx, uint64_t value); - Expr newLiteralDouble(antlr4::ParserRuleContext* ctx, double value); - Expr newLiteralString(antlr4::ParserRuleContext* ctx, const std::string& s); - Expr newLiteralBytes(antlr4::ParserRuleContext* ctx, const std::string& b); - Expr newLiteralBool(antlr4::ParserRuleContext* ctx, bool b); - Expr newLiteralBoolForMacro(int64_t macro_id, bool b); - Expr newLiteralNull(antlr4::ParserRuleContext* ctx); - - Expr reportError(antlr4::ParserRuleContext* ctx, const std::string& msg); - Expr reportError(int32_t line, int32_t col, const std::string& msg); - Expr reportError(const SourceLocation& loc, const std::string& msg); - - bool isReserved(const std::string& ident_name); - google::api::expr::v1alpha1::SourceInfo sourceInfo() const; - EnrichedSourceInfo enrichedSourceInfo() const; - const std::vector& errors() const { return errors_truncated_; } - std::string errorMessage(const std::string& description, - const std::string& expression) const; - - private: - void calcLineOffsets(const std::string& expression); - absl::optional findLineOffset(int32_t line) const; - std::string getSourceLine(int32_t line, const std::string& expression) const; - - private: - int64_t next_id_; - std::map positions_; - // Truncated at kMaxErrorsToReport. - std::vector errors_truncated_; - int64_t num_errors_; - std::vector line_offsets_; -}; - -} // namespace parser -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::parser #endif // THIRD_PARTY_CEL_CPP_PARSER_SOURCE_FACTORY_H_ diff --git a/parser/standard_macros.cc b/parser/standard_macros.cc new file mode 100644 index 000000000..15069d45b --- /dev/null +++ b/parser/standard_macros.cc @@ -0,0 +1,41 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "parser/standard_macros.h" + +#include "absl/status/status.h" +#include "internal/status_macros.h" +#include "parser/macro.h" +#include "parser/macro_registry.h" +#include "parser/options.h" + +namespace cel { + +absl::Status RegisterStandardMacros(MacroRegistry& registry, + const ParserOptions& options) { + CEL_RETURN_IF_ERROR(registry.RegisterMacro(HasMacro())); + CEL_RETURN_IF_ERROR(registry.RegisterMacro(AllMacro())); + CEL_RETURN_IF_ERROR(registry.RegisterMacro(ExistsMacro())); + CEL_RETURN_IF_ERROR(registry.RegisterMacro(ExistsOneMacro())); + CEL_RETURN_IF_ERROR(registry.RegisterMacro(Map2Macro())); + CEL_RETURN_IF_ERROR(registry.RegisterMacro(Map3Macro())); + CEL_RETURN_IF_ERROR(registry.RegisterMacro(FilterMacro())); + if (options.enable_optional_syntax) { + CEL_RETURN_IF_ERROR(registry.RegisterMacro(OptMapMacro())); + CEL_RETURN_IF_ERROR(registry.RegisterMacro(OptFlatMapMacro())); + } + return absl::OkStatus(); +} + +} // namespace cel diff --git a/parser/standard_macros.h b/parser/standard_macros.h new file mode 100644 index 000000000..2f3b28563 --- /dev/null +++ b/parser/standard_macros.h @@ -0,0 +1,31 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_PARSER_STANDARD_MACROS_H_ +#define THIRD_PARTY_CEL_CPP_PARSER_STANDARD_MACROS_H_ + +#include "absl/status/status.h" +#include "parser/macro_registry.h" +#include "parser/options.h" + +namespace cel { + +// Registers the standard macros defined by the Common Expression Language. +// https://github.com/google/cel-spec/blob/master/doc/langdef.md#macros +absl::Status RegisterStandardMacros(MacroRegistry& registry, + const ParserOptions& options); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_PARSER_STANDARD_MACROS_H_ diff --git a/parser/standard_macros_test.cc b/parser/standard_macros_test.cc new file mode 100644 index 000000000..a79390f06 --- /dev/null +++ b/parser/standard_macros_test.cc @@ -0,0 +1,95 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "parser/standard_macros.h" + +#include + +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "common/source.h" +#include "internal/testing.h" +#include "parser/macro_registry.h" +#include "parser/options.h" +#include "parser/parser.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::google::api::expr::parser::EnrichedParse; +using ::testing::HasSubstr; + +struct StandardMacrosTestCase { + std::string expression; + std::string error; +}; + +using StandardMacrosTest = ::testing::TestWithParam; + +TEST_P(StandardMacrosTest, Errors) { + const auto& test_param = GetParam(); + ASSERT_OK_AND_ASSIGN(auto source, NewSource(test_param.expression)); + + ParserOptions options; + options.enable_optional_syntax = true; + + MacroRegistry registry; + ASSERT_THAT(RegisterStandardMacros(registry, options), IsOk()); + + EXPECT_THAT(EnrichedParse(*source, registry, options), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr(test_param.error))); +} + +INSTANTIATE_TEST_SUITE_P( + StandardMacrosTest, StandardMacrosTest, + ::testing::ValuesIn({ + { + .expression = "[].all(__result__, __result__ == 0)", + .error = "variable name cannot be __result__", + }, + { + .expression = "[].exists(__result__, __result__ == 0)", + .error = "variable name cannot be __result__", + }, + { + .expression = "[].exists_one(__result__, __result__ == 0)", + .error = "variable name cannot be __result__", + }, + { + .expression = "[].map(__result__, __result__)", + .error = "variable name cannot be __result__", + }, + { + .expression = "[].map(__result__, true, __result__)", + .error = "variable name cannot be __result__", + }, + { + .expression = "[].filter(__result__, __result__ == 0)", + .error = "variable name cannot be __result__", + }, + { + .expression = "foo.optMap(__result__, __result__)", + .error = "variable name cannot be __result__", + }, + { + .expression = "foo.optFlatMap(__result__, __result__)", + .error = "variable name cannot be __result__", + }, + })); + +} // namespace +} // namespace cel diff --git a/parser/visitor.cc b/parser/visitor.cc deleted file mode 100644 index 7ba3a0917..000000000 --- a/parser/visitor.cc +++ /dev/null @@ -1,564 +0,0 @@ -#include "parser/visitor.h" - -#include - -#include "google/protobuf/struct.pb.h" -#include "absl/memory/memory.h" -#include "absl/strings/numbers.h" -#include "absl/strings/str_format.h" -#include "absl/strings/str_join.h" -#include "common/escaping.h" -#include "common/operators.h" -#include "parser/balancer.h" -#include "parser/source_factory.h" - -namespace google { -namespace api { -namespace expr { -namespace parser { - -using common::CelOperator; -using common::ReverseLookupOperator; - -using ::cel_grammar::CelParser; -using google::api::expr::v1alpha1::Expr; - -ParserVisitor::ParserVisitor(const std::string& description, - const std::string& expression, - const int max_recursion_depth, - const std::vector& macros) - : description_(description), - expression_(expression), - sf_(std::make_shared(expression)), - recursion_depth_(0), - max_recursion_depth_(max_recursion_depth) { - for (const auto& m : macros) { - macros_.emplace(m.macroKey(), m); - } -} - -ParserVisitor::~ParserVisitor() {} - -template ::value>> -T* tree_as(antlr4::tree::ParseTree* tree) { - return dynamic_cast(tree); -} - -antlrcpp::Any ParserVisitor::visit(antlr4::tree::ParseTree* tree) { - recursion_depth_ += 1; - if (recursion_depth_ > max_recursion_depth_) { - return sf_->reportError( - SourceFactory::noLocation(), - absl::StrFormat("Exceeded max recursion depth of %d when parsing.", - max_recursion_depth_)); - } - if (auto* ctx = tree_as(tree)) { - return visitStart(ctx); - } else if (auto* ctx = tree_as(tree)) { - return visitExpr(ctx); - } else if (auto* ctx = tree_as(tree)) { - return visitConditionalAnd(ctx); - } else if (auto* ctx = tree_as(tree)) { - return visitConditionalOr(ctx); - } else if (auto* ctx = tree_as(tree)) { - return visitRelation(ctx); - } else if (auto* ctx = tree_as(tree)) { - return visitCalc(ctx); - } else if (auto* ctx = tree_as(tree)) { - return visitLogicalNot(ctx); - } else if (auto* ctx = tree_as(tree)) { - return visitPrimaryExpr(ctx); - } else if (auto* ctx = tree_as(tree)) { - return visitMemberExpr(ctx); - } else if (auto* ctx = tree_as(tree)) { - return visitSelectOrCall(ctx); - } else if (auto* ctx = tree_as(tree)) { - return visitMapInitializerList(ctx); - } else if (auto* ctx = tree_as(tree)) { - return visitNegate(ctx); - } else if (auto* ctx = tree_as(tree)) { - return visitIndex(ctx); - } else if (auto* ctx = tree_as(tree)) { - return visitUnary(ctx); - } else if (auto* ctx = tree_as(tree)) { - return visitCreateList(ctx); - } else if (auto* ctx = tree_as(tree)) { - return visitCreateMessage(ctx); - } else if (auto* ctx = tree_as(tree)) { - return visitCreateStruct(ctx); - } - - if (tree) { - return sf_->reportError(tree_as(tree), - "unknown parsetree type"); - } - return sf_->reportError(SourceFactory::noLocation(), "<> parsetree"); -} - -antlrcpp::Any ParserVisitor::visitPrimaryExpr( - CelParser::PrimaryExprContext* pctx) { - CelParser::PrimaryContext* primary = pctx->primary(); - if (auto* ctx = tree_as(primary)) { - return visitNested(ctx); - } else if (auto* ctx = - tree_as(primary)) { - return visitIdentOrGlobalCall(ctx); - } else if (auto* ctx = tree_as(primary)) { - return visitCreateList(ctx); - } else if (auto* ctx = tree_as(primary)) { - return visitCreateStruct(ctx); - } else if (auto* ctx = tree_as(primary)) { - return visitConstantLiteral(ctx); - } - return sf_->reportError(pctx, "invalid primary expression"); -} - -antlrcpp::Any ParserVisitor::visitMemberExpr( - CelParser::MemberExprContext* mctx) { - CelParser::MemberContext* member = mctx->member(); - if (auto* ctx = tree_as(member)) { - return visitPrimaryExpr(ctx); - } else if (auto* ctx = tree_as(member)) { - return visitSelectOrCall(ctx); - } else if (auto* ctx = tree_as(member)) { - return visitIndex(ctx); - } else if (auto* ctx = tree_as(member)) { - return visitCreateMessage(ctx); - } - return sf_->reportError(mctx, "unsupported simple expression"); -} - -antlrcpp::Any ParserVisitor::visitStart(CelParser::StartContext* ctx) { - return visit(ctx->expr()); -} - -antlrcpp::Any ParserVisitor::visitExpr(CelParser::ExprContext* ctx) { - auto result = visit(ctx->e); - if (!ctx->op) { - return result; - } - int64_t op_id = sf_->id(ctx->op); - Expr if_true = visit(ctx->e1); - Expr if_false = visit(ctx->e2); - - return globalCallOrMacro(op_id, CelOperator::CONDITIONAL, - {result, if_true, if_false}); -} - -antlrcpp::Any ParserVisitor::visitConditionalOr( - CelParser::ConditionalOrContext* ctx) { - auto result = visit(ctx->e); - if (ctx->ops.empty()) { - return result; - } - ExpressionBalancer b(sf_, CelOperator::LOGICAL_OR, result); - for (size_t i = 0; i < ctx->ops.size(); ++i) { - auto op = ctx->ops[i]; - if (i >= ctx->e1.size()) { - return sf_->reportError(ctx, "unexpected character, wanted '||'"); - } - auto next = visit(ctx->e1[i]).as(); - int64_t op_id = sf_->id(op); - b.addTerm(op_id, next); - } - return b.balance(); -} - -antlrcpp::Any ParserVisitor::visitConditionalAnd( - CelParser::ConditionalAndContext* ctx) { - auto result = visit(ctx->e); - if (ctx->ops.empty()) { - return result; - } - ExpressionBalancer b(sf_, CelOperator::LOGICAL_AND, result); - for (size_t i = 0; i < ctx->ops.size(); ++i) { - auto op = ctx->ops[i]; - if (i >= ctx->e1.size()) { - return sf_->reportError(ctx, "unexpected character, wanted '&&'"); - } - auto next = visit(ctx->e1[i]).as(); - int64_t op_id = sf_->id(op); - b.addTerm(op_id, next); - } - return b.balance(); -} - -antlrcpp::Any ParserVisitor::visitRelation(CelParser::RelationContext* ctx) { - if (ctx->calc()) { - return visit(ctx->calc()); - } - std::string op_text; - if (ctx->op) { - op_text = ctx->op->getText(); - } - auto op = ReverseLookupOperator(op_text); - if (op) { - auto lhs = visit(ctx->relation(0)).as(); - int64_t op_id = sf_->id(ctx->op); - auto rhs = visit(ctx->relation(1)).as(); - return globalCallOrMacro(op_id, *op, {lhs, rhs}); - } - return sf_->reportError(ctx, "operator not found"); -} - -antlrcpp::Any ParserVisitor::visitCalc(CelParser::CalcContext* ctx) { - if (ctx->unary()) { - return visit(ctx->unary()); - } - std::string op_text; - if (ctx->op) { - op_text = ctx->op->getText(); - } - auto op = ReverseLookupOperator(op_text); - if (op) { - auto lhs = visit(ctx->calc(0)).as(); - int64_t op_id = sf_->id(ctx->op); - auto rhs = visit(ctx->calc(1)).as(); - return globalCallOrMacro(op_id, *op, {lhs, rhs}); - } - return sf_->reportError(ctx, "operator not found"); -} - -antlrcpp::Any ParserVisitor::visitUnary(CelParser::UnaryContext* ctx) { - return sf_->newLiteralString(ctx, "<>"); -} - -antlrcpp::Any ParserVisitor::visitLogicalNot( - CelParser::LogicalNotContext* ctx) { - if (ctx->ops.size() % 2 == 0) { - return visit(ctx->member()); - } - int64_t op_id = sf_->id(ctx->ops[0]); - auto target = visit(ctx->member()); - return globalCallOrMacro(op_id, CelOperator::LOGICAL_NOT, {target}); -} - -antlrcpp::Any ParserVisitor::visitNegate(CelParser::NegateContext* ctx) { - if (ctx->ops.size() % 2 == 0) { - return visit(ctx->member()); - } - int64_t op_id = sf_->id(ctx->ops[0]); - auto target = visit(ctx->member()); - return globalCallOrMacro(op_id, CelOperator::NEGATE, {target}); -} - -antlrcpp::Any ParserVisitor::visitSelectOrCall( - CelParser::SelectOrCallContext* ctx) { - auto operand = visit(ctx->member()).as(); - // Handle the error case where no valid identifier is specified. - if (!ctx->id) { - return sf_->newExpr(ctx); - } - auto id = ctx->id->getText(); - if (ctx->open) { - int64_t op_id = sf_->id(ctx->open); - return receiverCallOrMacro(op_id, id, operand, visitList(ctx->args)); - } - return sf_->newSelect(ctx, operand, id); -} - -antlrcpp::Any ParserVisitor::visitIndex(CelParser::IndexContext* ctx) { - auto target = visit(ctx->member()).as(); - int64_t op_id = sf_->id(ctx->op); - auto index = visit(ctx->index).as(); - return globalCallOrMacro(op_id, CelOperator::INDEX, {target, index}); -} - -antlrcpp::Any ParserVisitor::visitCreateMessage( - CelParser::CreateMessageContext* ctx) { - auto target = visit(ctx->member()).as(); - int64_t obj_id = sf_->id(ctx->op); - std::string message_name = extractQualifiedName(ctx, &target); - if (!message_name.empty()) { - auto entries = visitFieldInitializerList(ctx->entries) - .as>(); - return sf_->newObject(obj_id, message_name, entries); - } else { - return sf_->newExpr(obj_id); - } -} - -antlrcpp::Any ParserVisitor::visitFieldInitializerList( - CelParser::FieldInitializerListContext* ctx) { - std::vector res; - if (!ctx || ctx->fields.empty()) { - return res; - } - - res.resize(ctx->fields.size()); - for (size_t i = 0; i < ctx->fields.size(); ++i) { - if (i >= ctx->cols.size() || i >= ctx->values.size()) { - // This is the result of a syntax error detected elsewhere. - return res; - } - const auto& f = ctx->fields[i]; - int64_t init_id = sf_->id(ctx->cols[i]); - auto value = visit(ctx->values[i]).as(); - auto field = sf_->newObjectField(init_id, f->getText(), value); - res[i] = field; - } - - return res; -} - -antlrcpp::Any ParserVisitor::visitIdentOrGlobalCall( - CelParser::IdentOrGlobalCallContext* ctx) { - std::string ident_name; - if (ctx->leadingDot) { - ident_name = "."; - } - if (!ctx->id) { - return sf_->newExpr(ctx); - } - if (sf_->isReserved(ctx->id->getText())) { - return sf_->reportError( - ctx, absl::StrFormat("reserved identifier: %s", ctx->id->getText())); - } - // check if ID is in reserved identifiers - ident_name += ctx->id->getText(); - if (ctx->op) { - int64_t op_id = sf_->id(ctx->op); - return globalCallOrMacro(op_id, ident_name, visitList(ctx->args)); - } - return sf_->newIdent(ctx->id, ident_name); -} - -antlrcpp::Any ParserVisitor::visitNested(CelParser::NestedContext* ctx) { - return visit(ctx->e); -} - -antlrcpp::Any ParserVisitor::visitCreateList( - CelParser::CreateListContext* ctx) { - int64_t list_id = sf_->id(ctx->op); - return sf_->newList(list_id, visitList(ctx->elems)); -} - -std::vector ParserVisitor::visitList(CelParser::ExprListContext* ctx) { - std::vector rv; - if (!ctx) return rv; - std::transform(ctx->e.begin(), ctx->e.end(), std::back_inserter(rv), - [this](CelParser::ExprContext* expr_ctx) { - return visitExpr(expr_ctx).as(); - }); - return rv; -} - -antlrcpp::Any ParserVisitor::visitCreateStruct( - CelParser::CreateStructContext* ctx) { - int64_t struct_id = sf_->id(ctx->op); - std::vector entries; - if (ctx->entries) { - entries = visitMapInitializerList(ctx->entries) - .as>(); - } - return sf_->newMap(struct_id, entries); -} - -antlrcpp::Any ParserVisitor::visitConstantLiteral( - CelParser::ConstantLiteralContext* clctx) { - CelParser::LiteralContext* literal = clctx->literal(); - if (auto* ctx = tree_as(literal)) { - return visitInt(ctx); - } else if (auto* ctx = tree_as(literal)) { - return visitUint(ctx); - } else if (auto* ctx = tree_as(literal)) { - return visitDouble(ctx); - } else if (auto* ctx = tree_as(literal)) { - return visitString(ctx); - } else if (auto* ctx = tree_as(literal)) { - return visitBytes(ctx); - } else if (auto* ctx = tree_as(literal)) { - return visitBoolFalse(ctx); - } else if (auto* ctx = tree_as(literal)) { - return visitBoolTrue(ctx); - } else if (auto* ctx = tree_as(literal)) { - return visitNull(ctx); - } - return sf_->reportError(clctx, "invalid constant literal expression"); -} - -antlrcpp::Any ParserVisitor::visitMapInitializerList( - CelParser::MapInitializerListContext* ctx) { - std::vector res; - if (!ctx || ctx->keys.empty()) { - return res; - } - - res.resize(ctx->cols.size()); - for (size_t i = 0; i < ctx->cols.size(); ++i) { - int64_t col_id = sf_->id(ctx->cols[i]); - auto key = visit(ctx->keys[i]); - auto value = visit(ctx->values[i]); - res[i] = sf_->newMapEntry(col_id, key, value); - } - return res; -} - -antlrcpp::Any ParserVisitor::visitInt(CelParser::IntContext* ctx) { - std::string value; - if (ctx->sign) { - value = ctx->sign->getText(); - } - value += ctx->tok->getText(); - int64_t int_value; - if (absl::SimpleAtoi(value, &int_value)) { - return sf_->newLiteralInt(ctx, int_value); - } else { - return sf_->reportError(ctx, "invalid int literal"); - } -} - -antlrcpp::Any ParserVisitor::visitUint(CelParser::UintContext* ctx) { - std::string value = ctx->tok->getText(); - // trim the 'u' designator included in the uint literal. - if (!value.empty()) { - value.resize(value.size() - 1); - } - uint64_t uint_value; - if (absl::SimpleAtoi(value, &uint_value)) { - return sf_->newLiteralUint(ctx, uint_value); - } else { - return sf_->reportError(ctx, "invalid uint literal"); - } -} - -antlrcpp::Any ParserVisitor::visitDouble(CelParser::DoubleContext* ctx) { - std::string value; - if (ctx->sign) { - value = ctx->sign->getText(); - } - value += ctx->tok->getText(); - double double_value; - if (absl::SimpleAtod(value, &double_value)) { - return sf_->newLiteralDouble(ctx, double_value); - } else { - return sf_->reportError(ctx, "invalid double literal"); - } -} - -antlrcpp::Any ParserVisitor::visitString(CelParser::StringContext* ctx) { - std::string value = unquote(ctx, ctx->tok->getText(), /* is bytes */ false); - return sf_->newLiteralString(ctx, value); -} - -antlrcpp::Any ParserVisitor::visitBytes(CelParser::BytesContext* ctx) { - std::string value = unquote(ctx, ctx->tok->getText().substr(1), - /* is bytes */ true); - return sf_->newLiteralBytes(ctx, value); -} - -antlrcpp::Any ParserVisitor::visitBoolTrue(CelParser::BoolTrueContext* ctx) { - return sf_->newLiteralBool(ctx, true); -} - -antlrcpp::Any ParserVisitor::visitBoolFalse(CelParser::BoolFalseContext* ctx) { - return sf_->newLiteralBool(ctx, false); -} - -antlrcpp::Any ParserVisitor::visitNull(CelParser::NullContext* ctx) { - return sf_->newLiteralNull(ctx); -} - -google::api::expr::v1alpha1::SourceInfo ParserVisitor::sourceInfo() const { - return sf_->sourceInfo(); -} - -EnrichedSourceInfo ParserVisitor::enrichedSourceInfo() const { - return sf_->enrichedSourceInfo(); -} - -void ParserVisitor::syntaxError(antlr4::Recognizer* recognizer, - antlr4::Token* offending_symbol, size_t line, - size_t col, const std::string& msg, - std::exception_ptr e) { - sf_->reportError(line, col, "Syntax error: " + msg); -} - -bool ParserVisitor::hasErrored() const { return !sf_->errors().empty(); } - -std::string ParserVisitor::errorMessage() const { - return sf_->errorMessage(description_, expression_); -} - -Expr ParserVisitor::globalCallOrMacro(int64_t expr_id, std::string function, - std::vector args) { - Expr macro_expr; - if (expandMacro(expr_id, function, nullptr, args, ¯o_expr)) { - return macro_expr; - } - - return sf_->newGlobalCall(expr_id, function, args); -} - -Expr ParserVisitor::receiverCallOrMacro(int64_t expr_id, std::string function, - Expr target, std::vector args) { - Expr macro_expr; - if (expandMacro(expr_id, function, &target, args, ¯o_expr)) { - return macro_expr; - } - - return sf_->newReceiverCall(expr_id, function, target, args); -} - -bool ParserVisitor::expandMacro(int64_t expr_id, std::string function, - Expr* target, std::vector args, - Expr* macro_expr) { - std::string macro_key = absl::StrFormat("%s:%d:%s", function, args.size(), - target ? "true" : "false"); - auto m = macros_.find(macro_key); - if (m == macros_.end()) { - std::string var_arg_macro_key = - absl::StrFormat("%s:*:%s", function, target ? "true" : "false"); - m = macros_.find(var_arg_macro_key); - if (m == macros_.end()) { - return false; - } - } - - Expr expr = m->second.expand(sf_, expr_id, target, args); - if (expr.expr_kind_case() != Expr::EXPR_KIND_NOT_SET) { - macro_expr->CopyFrom(expr); - return true; - } - return false; -} - -std::string ParserVisitor::unquote(antlr4::ParserRuleContext* ctx, - const std::string& s, bool is_bytes) { - auto text = unescape(s, is_bytes); - if (!text) { - sf_->reportError(ctx, "failed to unquote"); - return s; - } - return *text; -} - -std::string ParserVisitor::extractQualifiedName(antlr4::ParserRuleContext* ctx, - const Expr* e) { - if (!e) { - return ""; - } - - switch (e->expr_kind_case()) { - case Expr::kIdentExpr: - return e->ident_expr().name(); - case Expr::kSelectExpr: { - auto& s = e->select_expr(); - std::string prefix = extractQualifiedName(ctx, &s.operand()); - if (!prefix.empty()) { - return prefix + "." + s.field(); - } - } break; - default: - break; - } - sf_->reportError(sf_->getSourceLocation(e->id()), - "expected a qualified name"); - return ""; -} - -} // namespace parser -} // namespace expr -} // namespace api -} // namespace google diff --git a/parser/visitor.h b/parser/visitor.h deleted file mode 100644 index 7f5cb738a..000000000 --- a/parser/visitor.h +++ /dev/null @@ -1,117 +0,0 @@ -#ifndef THIRD_PARTY_CEL_CPP_PARSER_VISITOR_H_ -#define THIRD_PARTY_CEL_CPP_PARSER_VISITOR_H_ - -#include "google/api/expr/v1alpha1/syntax.pb.h" -#include "absl/types/optional.h" -#include "parser/cel_grammar.inc/cel_grammar/CelBaseVisitor.h" -#include "parser/macro.h" -#include "parser/source_factory.h" - -namespace google { -namespace api { -namespace expr { -namespace parser { - -class SourceFactory; - -class ParserVisitor : public ::cel_grammar::CelBaseVisitor, - public antlr4::BaseErrorListener { - public: - ParserVisitor(const std::string& description, const std::string& expression, - const int max_recursion_depth, - const std::vector& macros = {}); - virtual ~ParserVisitor(); - - antlrcpp::Any visit(antlr4::tree::ParseTree* tree) override; - - antlrcpp::Any visitStart( - ::cel_grammar::CelParser::StartContext* ctx) override; - antlrcpp::Any visitExpr(::cel_grammar::CelParser::ExprContext* ctx) override; - antlrcpp::Any visitConditionalOr( - ::cel_grammar::CelParser::ConditionalOrContext* ctx) override; - antlrcpp::Any visitConditionalAnd( - ::cel_grammar::CelParser::ConditionalAndContext* ctx) override; - antlrcpp::Any visitRelation( - ::cel_grammar::CelParser::RelationContext* ctx) override; - antlrcpp::Any visitCalc(::cel_grammar::CelParser::CalcContext* ctx) override; - antlrcpp::Any visitUnary(::cel_grammar::CelParser::UnaryContext* ctx); - antlrcpp::Any visitLogicalNot( - ::cel_grammar::CelParser::LogicalNotContext* ctx) override; - antlrcpp::Any visitNegate( - ::cel_grammar::CelParser::NegateContext* ctx) override; - antlrcpp::Any visitSelectOrCall( - ::cel_grammar::CelParser::SelectOrCallContext* ctx) override; - antlrcpp::Any visitIndex( - ::cel_grammar::CelParser::IndexContext* ctx) override; - antlrcpp::Any visitCreateMessage( - ::cel_grammar::CelParser::CreateMessageContext* ctx) override; - antlrcpp::Any visitFieldInitializerList( - ::cel_grammar::CelParser::FieldInitializerListContext* ctx) override; - antlrcpp::Any visitIdentOrGlobalCall( - ::cel_grammar::CelParser::IdentOrGlobalCallContext* ctx) override; - antlrcpp::Any visitNested( - ::cel_grammar::CelParser::NestedContext* ctx) override; - antlrcpp::Any visitCreateList( - ::cel_grammar::CelParser::CreateListContext* ctx) override; - std::vector visitList( - ::cel_grammar::CelParser::ExprListContext* ctx); - antlrcpp::Any visitCreateStruct( - ::cel_grammar::CelParser::CreateStructContext* ctx) override; - antlrcpp::Any visitConstantLiteral( - ::cel_grammar::CelParser::ConstantLiteralContext* ctx) override; - antlrcpp::Any visitPrimaryExpr( - ::cel_grammar::CelParser::PrimaryExprContext* ctx) override; - antlrcpp::Any visitMemberExpr( - ::cel_grammar::CelParser::MemberExprContext* ctx) override; - - antlrcpp::Any visitMapInitializerList( - ::cel_grammar::CelParser::MapInitializerListContext* ctx) override; - antlrcpp::Any visitInt(::cel_grammar::CelParser::IntContext* ctx) override; - antlrcpp::Any visitUint(::cel_grammar::CelParser::UintContext* ctx) override; - antlrcpp::Any visitDouble( - ::cel_grammar::CelParser::DoubleContext* ctx) override; - antlrcpp::Any visitString( - ::cel_grammar::CelParser::StringContext* ctx) override; - antlrcpp::Any visitBytes( - ::cel_grammar::CelParser::BytesContext* ctx) override; - antlrcpp::Any visitBoolTrue( - ::cel_grammar::CelParser::BoolTrueContext* ctx) override; - antlrcpp::Any visitBoolFalse( - ::cel_grammar::CelParser::BoolFalseContext* ctx) override; - antlrcpp::Any visitNull(::cel_grammar::CelParser::NullContext* ctx) override; - google::api::expr::v1alpha1::SourceInfo sourceInfo() const; - EnrichedSourceInfo enrichedSourceInfo() const; - void syntaxError(antlr4::Recognizer* recognizer, - antlr4::Token* offending_symbol, size_t line, size_t col, - const std::string& msg, std::exception_ptr e) override; - bool hasErrored() const; - - std::string errorMessage() const; - - private: - Expr globalCallOrMacro(int64_t expr_id, std::string function, - std::vector args); - Expr receiverCallOrMacro(int64_t expr_id, std::string function, Expr target, - std::vector args); - bool expandMacro(int64_t expr_id, std::string function, Expr* target, - std::vector args, Expr* macro_expr); - std::string unquote(antlr4::ParserRuleContext* ctx, const std::string& s, - bool is_bytes); - std::string extractQualifiedName(antlr4::ParserRuleContext* ctx, - const Expr* e); - - private: - std::string description_; - std::string expression_; - std::shared_ptr sf_; - std::map macros_; - int recursion_depth_; - const int max_recursion_depth_; -}; - -} // namespace parser -} // namespace expr -} // namespace api -} // namespace google - -#endif // THIRD_PARTY_CEL_CPP_PARSER_VISITOR_H_ diff --git a/policy/BUILD b/policy/BUILD new file mode 100644 index 000000000..19195be2b --- /dev/null +++ b/policy/BUILD @@ -0,0 +1,239 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "cel_policy", + srcs = [ + "cel_policy.cc", + ], + hdrs = [ + "cel_policy.h", + ], + deps = [ + "//common:source", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "cel_policy_test", + srcs = ["cel_policy_test.cc"], + deps = [ + ":cel_policy", + "//common:source", + "//internal:testing", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "cel_policy_parser", + srcs = [ + "cel_policy_parse_context.cc", + "cel_policy_parse_result.cc", + ], + hdrs = [ + "cel_policy_parse_context.h", + "cel_policy_parse_result.h", + "cel_policy_parser.h", + ], + deps = [ + ":cel_policy", + "//common:source", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "yaml_policy_parser", + srcs = [ + "yaml_policy_parser.cc", + ], + hdrs = ["yaml_policy_parser.h"], + copts = ["-fexceptions"], + features = ["-use_header_modules"], + deps = [ + ":cel_policy", + ":cel_policy_parser", + "//common:source", + "//internal:status_macros", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@yaml-cpp", + ], +) + +cc_library( + name = "cel_policy_validation_result", + srcs = [ + "cel_policy_validation_result.cc", + ], + hdrs = [ + "cel_policy_validation_result.h", + ], + deps = [ + ":cel_policy", + ":cel_policy_parser", + "//common:ast", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "compiler", + srcs = ["compiler.cc"], + hdrs = ["compiler.h"], + deps = [ + ":cel_policy", + ":cel_policy_parser", + ":cel_policy_validation_result", + "//checker:type_check_issue", + "//checker:validation_result", + "//common:ast", + "//common:ast_rewrite", + "//common:constant", + "//common:container", + "//common:decl", + "//common:expr", + "//common:format_type_name", + "//common:navigable_ast", + "//common:source", + "//common:type", + "//common:type_kind", + "//compiler", + "//internal:status_macros", + "//policy/internal:issue_reporter", + "//policy/internal:optimizer_expr_factory", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "yaml_policy_parser_test", + srcs = [ + "test_custom_yaml_policy_parser.cc", + "yaml_policy_parser_test.cc", + ], + data = [ + "//policy/testdata:policy_testdata", + ], + deps = [ + ":cel_policy", + ":cel_policy_parser", + ":yaml_policy_parser", + "//common:source", + "//internal:runfiles", + "//internal:status_macros", + "//internal:testing", + "@com_google_absl//absl/log:absl_log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@yaml-cpp", + ], +) + +cc_test( + name = "compiler_test", + srcs = ["compiler_test.cc"], + data = [ + "//policy/testdata:policy_testdata", + ], + deps = [ + ":cel_policy", + ":cel_policy_parser", + ":cel_policy_validation_result", + ":compiler", + ":yaml_policy_parser", + "//common:ast", + "//common:decl", + "//common:navigable_ast", + "//common:source", + "//common:type", + "//common:value", + "//common:value_testing", + "//compiler", + "//compiler:compiler_factory", + "//compiler:optional", + "//compiler:standard_library", + "//extensions:bindings_ext", + "//internal:runfiles", + "//internal:status_macros", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//runtime", + "//runtime:activation", + "//runtime:optional_types", + "//runtime:runtime_builder", + "//runtime:runtime_options", + "//runtime:standard_runtime_builder_factory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "test_util", + testonly = True, + srcs = ["test_util.cc"], + hdrs = ["test_util.h"], + copts = ["-fexceptions"], + features = ["-use_header_modules"], + deps = [ + "//internal:status_macros", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_cel_spec//proto/cel/expr:value_cc_proto", + "@com_google_cel_spec//proto/cel/expr/conformance/test:suite_cc_proto", + "@com_google_protobuf//:struct_cc_proto", + "@yaml-cpp", + ], +) diff --git a/policy/cel_policy.cc b/policy/cel_policy.cc new file mode 100644 index 000000000..c2d97edeb --- /dev/null +++ b/policy/cel_policy.cc @@ -0,0 +1,273 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "policy/cel_policy.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "common/source.h" + +namespace cel { + +namespace { + +std::string IdDebugString(CelPolicyElementId id) { + if (id == -1) { + return ""; + } + return absl::StrCat("#", id, "> "); +} + +std::string IndentBlock(absl::string_view text) { + if (text.empty()) { + return ""; + } + std::vector lines; + for (absl::string_view line : absl::StrSplit(text, '\n')) { + if (line.empty()) { + lines.push_back(""); + } else { + lines.push_back(absl::StrCat(" ", line)); + } + } + return absl::StrJoin(lines, "\n"); +} + +} // namespace + +void CelPolicySource::NoteSourcePosition(CelPolicyElementId id, + SourcePosition position) { + source_positions_[id] = position; +} + +std::optional CelPolicySource::GetSourcePosition( + CelPolicyElementId id) const { + auto it = source_positions_.find(id); + if (it == source_positions_.end()) { + return std::nullopt; + } + return it->second; +} + +std::optional CelPolicySource::GetSourceLocation( + CelPolicyElementId id) const { + auto it = source_positions_.find(id); + if (it == source_positions_.end()) { + return std::nullopt; + } + return policy_source_->GetLocation(it->second); +} + +std::string CelPolicySource::DebugString() const { + std::string result; + + // Sort the source elements in descending order of position + std::vector> sorted_positions; + for (const auto& pair : source_positions_) { + sorted_positions.push_back(pair); + } + std::sort(sorted_positions.begin(), sorted_positions.end(), + [](const auto& a, const auto& b) { + if (a.second == b.second) { + return a.first < b.first; + } + return a.second > b.second; + }); + + result = policy_source_->content().ToString(); + for (const auto& [id, position] : sorted_positions) { + result.insert(position, IdDebugString(id)); + } + return result; +} + +std::string ValueString::DebugString() const { + return absl::StrCat(IdDebugString(id_), "\"", value_, "\""); +} + +std::string Import::DebugString() const { + std::string result; + absl::StrAppend(&result, IdDebugString(id_), "name: ", name_.DebugString()); + return result; +} + +std::string OutputBlock::DebugString() const { + std::string result; + absl::StrAppend(&result, "output: ", output_.DebugString()); + if (explanation_.has_value()) { + absl::StrAppend(&result, "\nexplanation: ", explanation_->DebugString()); + } + return result; +} + +Match::Match(const Match& other) + : id_(other.id_), condition_(other.condition_) { + if (std::holds_alternative(other.result_)) { + result_ = std::get(other.result_); + } else if (std::holds_alternative(other.result_)) { + result_ = std::get(other.result_); + } else { + result_ = + std::make_unique(*std::get>(other.result_)); + } +} + +Match& Match::operator=(const Match& other) { + if (this != &other) { + id_ = other.id_; + condition_ = other.condition_; + if (std::holds_alternative(other.result_)) { + result_ = std::get(other.result_); + } else if (std::holds_alternative(other.result_)) { + result_ = std::get(other.result_); + } else { + result_ = std::make_unique( + *std::get>(other.result_)); + } + } + return *this; +} + +std::string Match::DebugString() const { + std::string result; + absl::StrAppend(&result, IdDebugString(id_), "match: {\n"); + if (condition_.has_value()) { + absl::StrAppend(&result, " condition: ", condition_->DebugString(), "\n"); + } + if (has_rule()) { + absl::StrAppend(&result, " result:\n", + IndentBlock(IndentBlock(rule().DebugString())), "\n"); + } else { + absl::StrAppend(&result, " result: {\n", + IndentBlock(IndentBlock(output_block().DebugString())), + "\n }\n"); + } + absl::StrAppend(&result, "}"); + return result; +} + +std::string Variable::DebugString() const { + std::string result; + absl::StrAppend(&result, "variable: {\n"); + absl::StrAppend(&result, " name: ", name_.DebugString(), "\n"); + absl::StrAppend(&result, " expression: ", expression_.DebugString(), "\n"); + if (description_.has_value()) { + absl::StrAppend(&result, " description: ", description_->DebugString(), + "\n"); + } + if (display_name_.has_value()) { + absl::StrAppend(&result, " display_name: ", display_name_->DebugString(), + "\n"); + } + absl::StrAppend(&result, "}"); + return result; +} + +std::string Rule::DebugString() const { + std::string result; + absl::StrAppend(&result, IdDebugString(id_), "rule: {\n"); + if (rule_id_.has_value()) { + absl::StrAppend(&result, " rule_id: ", rule_id_->DebugString(), "\n"); + } + if (description_.has_value()) { + absl::StrAppend(&result, " description: ", description_->DebugString(), + "\n"); + } + for (const Variable& variable : variables_) { + absl::StrAppend(&result, IndentBlock(variable.DebugString()), "\n"); + } + for (const Match& match : matches_) { + absl::StrAppend(&result, IndentBlock(match.DebugString()), "\n"); + } + absl::StrAppend(&result, "}"); + return result; +} + +std::string MetadataValueDebugString(std::any value) { + if (value.type() == typeid(std::monostate)) { + return "null"; + } + if (value.type() == typeid(ValueString)) { + return std::any_cast(value).DebugString(); + } + if (value.type() == typeid(bool)) { + return std::any_cast(value) ? "true" : "false"; + } + if (value.type() == typeid(int)) { + return absl::StrCat(std::any_cast(value)); + } + if (value.type() == typeid(std::string)) { + return std::any_cast(value); + } + return absl::StrCat("typeid: ", value.type().name()); +} + +std::string CelPolicy::DebugString() const { + std::string result; + absl::StrAppend(&result, "CelPolicy{\n"); + absl::StrAppend( + &result, + " ===========================================================\n"); + absl::StrAppend(&result, IndentBlock(IndentBlock(source_->DebugString())), + "\n"); + absl::StrAppend( + &result, + " ===========================================================\n"); + absl::StrAppend(&result, " name: ", name_.DebugString(), "\n"); + if (description_.has_value()) { + absl::StrAppend(&result, " description: ", description_->DebugString(), + "\n"); + } + if (display_name_.has_value()) { + absl::StrAppend(&result, " display_name: ", display_name_->DebugString(), + "\n"); + } + if (!metadata_.empty()) { + std::vector sorted_keys; + for (const auto& [key, _] : metadata_) { + sorted_keys.push_back(key); + } + std::sort(sorted_keys.begin(), sorted_keys.end()); + + absl::StrAppend(&result, " metadata: {\n"); + for (const auto& key : sorted_keys) { + const auto& value = metadata_.at(key); + absl::StrAppend(&result, " ", key, ": ", + MetadataValueDebugString(value), "\n"); + } + absl::StrAppend(&result, " }\n"); + } + if (!imports_.empty()) { + absl::StrAppend(&result, " imports:\n"); + for (const Import& import : imports_) { + absl::StrAppend(&result, " ", import.DebugString(), "\n"); + } + } + absl::StrAppend(&result, IndentBlock(rule_.DebugString()), "\n"); + absl::StrAppend(&result, "}"); + return result; +} + +} // namespace cel diff --git a/policy/cel_policy.h b/policy/cel_policy.h new file mode 100644 index 000000000..af8f7c977 --- /dev/null +++ b/policy/cel_policy.h @@ -0,0 +1,320 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_H_ +#define THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" +#include "absl/log/absl_check.h" +#include "absl/strings/string_view.h" +#include "common/source.h" + +namespace cel { + +using CelPolicyElementId = int32_t; + +class CelPolicySource { + public: + explicit CelPolicySource(cel::SourcePtr policy_source) + : policy_source_(std::move(policy_source)) {} + + const Source* absl_nonnull content() const { return policy_source_.get(); } + + void NoteSourcePosition(CelPolicyElementId id, SourcePosition position); + + std::optional GetSourcePosition(CelPolicyElementId id) const; + + std::optional GetSourceLocation(CelPolicyElementId id) const; + + std::string DebugString() const; + + private: + cel::SourcePtr policy_source_; + absl::flat_hash_map source_positions_; +}; + +class ValueString { + public: + ValueString() : id_(-1) {} + + explicit ValueString(CelPolicyElementId id, absl::string_view value) + : id_(id), value_(value) {} + + CelPolicyElementId id() const { return id_; } + absl::string_view value() const { return value_; } + + std::string DebugString() const; + + private: + CelPolicyElementId id_; + std::string value_; +}; + +class Import { + public: + Import(CelPolicyElementId id, ValueString name) + : id_(id), name_(std::move(name)) {} + CelPolicyElementId id() const { return id_; } + const ValueString& name() const { return name_; } + + std::string DebugString() const; + + private: + CelPolicyElementId id_; + ValueString name_; +}; + +// Defines a variable that can be used in CEL expressions within the policy. +// Variables are evaluated once and stored in the activation context. +class Variable { + public: + const ValueString& name() const { return name_; } + void set_name(ValueString name) { name_ = std::move(name); } + + const ValueString& expression() const { return expression_; } + void set_expression(ValueString expression) { + expression_ = std::move(expression); + } + + std::optional description() const { return description_; } + void set_description(ValueString description) { + description_ = std::move(description); + } + + std::optional display_name() const { return display_name_; } + void set_display_name(ValueString display_name) { + display_name_ = std::move(display_name); + } + + std::string DebugString() const; + + private: + ValueString name_; + ValueString expression_; + std::optional description_; + std::optional display_name_; +}; + +class Rule; + +class OutputBlock { + public: + OutputBlock() = default; + OutputBlock(ValueString output, std::optional explanation) + : output_(std::move(output)), explanation_(std::move(explanation)) {} + + const ValueString& output() const { return output_; } + void set_output(ValueString output) { output_ = std::move(output); } + + const std::optional& explanation() const { return explanation_; } + void set_explanation(ValueString explanation) { + explanation_ = std::move(explanation); + } + + std::string DebugString() const; + + private: + ValueString output_; + std::optional explanation_; +}; + +// Defines a match condition and result. +// If the result is a Rule, it is considered a sub-rule and will be evaluated +// only if the match condition evaluates to true. +class Match { + public: + Match() = default; + Match(const Match& other); + Match& operator=(const Match& other); + + CelPolicyElementId id() const; + void set_id(CelPolicyElementId id); + + bool has_condition() const; + std::optional condition() const; + void set_condition(ValueString condition); + + bool has_output_block() const; + const OutputBlock& output_block() const; + OutputBlock& mutable_output_block(); + + bool has_rule() const; + const Rule& rule() const; + Rule& mutable_rule(); + + void set_result(OutputBlock result); + void set_result(std::unique_ptr result); + + std::string DebugString() const; + + private: + CelPolicyElementId id_ = -1; + std::optional condition_; + std::variant> result_; +}; + +// Rule is the body of the policy and contains a list of variables and matches. +// Variables are evaluated once and stored in the activation context. +// Matches are evaluated in order and the first match is returned. If the +// match contains a sub-rule, the sub-rule is evaluated only if the match +// condition evaluates to true. +class Rule { + public: + Rule() = default; + Rule(const Rule& other) = default; + + CelPolicyElementId id() const { return id_; } + void set_id(CelPolicyElementId id) { id_ = id; } + + const std::optional& rule_id() const { return rule_id_; } + void set_rule_id(ValueString rule_id) { rule_id_ = std::move(rule_id); } + + const std::optional& description() const { return description_; } + void set_description(ValueString description) { + description_ = std::move(description); + } + + const std::vector& variables() const { return variables_; } + std::vector& mutable_variables() { return variables_; } + + const std::vector& matches() const { return matches_; } + std::vector& mutable_matches() { return matches_; } + + std::string DebugString() const; + + private: + CelPolicyElementId id_ = -1; + std::optional rule_id_; + std::optional description_; + std::vector variables_; + std::vector matches_; +}; + +// CelPolicy is the top-level policy object. +// It contains a source, name, description, display name, imports, and a rule. +// The source is the CEL policy source code. +// The name, description, and display name are metadata about the policy. +// The rule is the main body of the policy. +class CelPolicy { + public: + explicit CelPolicy(std::shared_ptr source) + : source_(std::move(source)) {} + + CelPolicy(const CelPolicy& other) = default; + CelPolicy& operator=(const CelPolicy& other) = default; + + const CelPolicySource* absl_nullable source() const { return source_.get(); } + const std::shared_ptr& source_ptr() const { return source_; } + + const ValueString& name() const { return name_; } + void set_name(ValueString name) { name_ = std::move(name); } + + std::optional description() const { return description_; } + void set_description(ValueString description) { + description_ = std::move(description); + } + std::optional display_name() const { return display_name_; } + void set_display_name(ValueString display_name) { + display_name_ = std::move(display_name); + } + const absl::flat_hash_map& metadata() const { + return metadata_; + } + absl::flat_hash_map& mutable_metadata() { + return metadata_; + } + const std::vector& imports() const { return imports_; } + std::vector& mutable_imports() { return imports_; } + + const Rule& rule() const { return rule_; } + Rule& mutable_rule() { return rule_; } + + std::string DebugString() const; + + private: + std::shared_ptr source_; + ValueString name_; + std::optional description_; + std::optional display_name_; + absl::flat_hash_map metadata_; + std::vector imports_; + Rule rule_; +}; + +// Implementation details. + +inline CelPolicyElementId Match::id() const { return id_; } +inline void Match::set_id(CelPolicyElementId id) { id_ = id; } + +inline bool Match::has_condition() const { return condition_.has_value(); } + +inline std::optional Match::condition() const { + return condition_; +} + +inline void Match::set_condition(ValueString condition) { + condition_ = std::move(condition); +} + +inline bool Match::has_output_block() const { + return std::holds_alternative(result_); +} + +inline const OutputBlock& Match::output_block() const { + ABSL_DCHECK(std::holds_alternative(result_)); + return std::get(result_); +} + +inline OutputBlock& Match::mutable_output_block() { + if (!std::holds_alternative(result_)) { + result_ = OutputBlock(); + } + return std::get(result_); +} + +inline bool Match::has_rule() const { + return std::holds_alternative>(result_); +} + +inline const Rule& Match::rule() const { + ABSL_DCHECK(std::holds_alternative>(result_)); + return *std::get>(result_); +} + +inline Rule& Match::mutable_rule() { + ABSL_DCHECK(std::holds_alternative>(result_)); + return *std::get>(result_); +} + +inline void Match::set_result(OutputBlock result) { + result_ = std::move(result); +} + +inline void Match::set_result(std::unique_ptr result) { + result_ = std::move(result); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_H_ diff --git a/policy/cel_policy_parse_context.cc b/policy/cel_policy_parse_context.cc new file mode 100644 index 000000000..66861d085 --- /dev/null +++ b/policy/cel_policy_parse_context.cc @@ -0,0 +1,49 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "policy/cel_policy_parse_context.h" + +#include +#include +#include +#include + +#include "absl/log/absl_check.h" +#include "policy/cel_policy.h" +#include "policy/cel_policy_parse_result.h" + +namespace cel { + +CelPolicy& CelPolicyParseContext::policy() const { + ABSL_CHECK(policy_ != nullptr) + << "CelPolicyParseContext::policy() called after GetResult()"; + return *policy_; +} + +CelPolicyParseResult CelPolicyParseContext::GetResult() { + if (policy_ != nullptr && issues_.empty()) { + return CelPolicyParseResult(std::move(policy_source_), std::move(policy_), + std::move(issues_)); + } + policy_.reset(); + return CelPolicyParseResult(std::move(policy_source_), nullptr, + std::move(issues_)); +} + +void CelPolicyParseContext::ReportError(CelPolicyElementId element_id, + std::string_view message) { + issues_.push_back(CelPolicyIssue(element_id, std::string(message))); +} + +} // namespace cel diff --git a/policy/cel_policy_parse_context.h b/policy/cel_policy_parse_context.h new file mode 100644 index 000000000..6482fa1ae --- /dev/null +++ b/policy/cel_policy_parse_context.h @@ -0,0 +1,65 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_PARSE_CONTEXT_H_ +#define THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_PARSE_CONTEXT_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "policy/cel_policy.h" +#include "policy/cel_policy_parse_result.h" + +namespace cel { + +// A mutable context for parsing a CelPolicy. An instance of this class is +// created for each policy parse and is passed to the parser, which is meant to +// be stateless. +// +// Parsers call methods on this class to report issues and populate the policy +// being parsed. Call GetResult() to obtain the resulting CelPolicyParseResult, +// which takes ownership of the parsed policy. Do not use the context after +// calling GetResult(). +class CelPolicyParseContext { + public: + explicit CelPolicyParseContext(std::shared_ptr policy_source) + : policy_source_(std::move(policy_source)), + policy_(std::make_unique(policy_source_)) {} + + CelPolicySource& policy_source() const { return *policy_source_; } + + // Returns the policy being parsed. It should not be used after + // calling GetResult(). + CelPolicy& policy() const; + + // The context should not be used after calling GetResult(). + CelPolicyParseResult GetResult(); + + // Reports an error for the given element with the given error message. + void ReportError(CelPolicyElementId id, std::string_view message); + + CelPolicyElementId next_element_id() { return next_element_id_++; } + + private: + std::shared_ptr policy_source_; + CelPolicyElementId next_element_id_ = 0; + std::vector issues_; + std::unique_ptr policy_; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_PARSE_CONTEXT_H_ diff --git a/policy/cel_policy_parse_result.cc b/policy/cel_policy_parse_result.cc new file mode 100644 index 000000000..32d6431bb --- /dev/null +++ b/policy/cel_policy_parse_result.cc @@ -0,0 +1,91 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "policy/cel_policy_parse_result.h" + +#include +#include + +#include "absl/base/nullability.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "common/source.h" +#include "policy/cel_policy.h" + +namespace cel { +namespace { + +absl::string_view SeverityString(CelPolicyIssue::Severity severity) { + switch (severity) { + case CelPolicyIssue::Severity::kInformation: + return "INFORMATION"; + case CelPolicyIssue::Severity::kWarning: + return "WARNING"; + case CelPolicyIssue::Severity::kError: + return "ERROR"; + case CelPolicyIssue::Severity::kDeprecated: + return "DEPRECATED"; + default: + return "SEVERITY_UNSPECIFIED"; + } +} + +} // namespace + +std::string CelPolicyIssue::ToDisplayString( + const CelPolicySource* absl_nullable source) const { + SourceLocation location; + std::string description; + std::string snippet; + if (source != nullptr) { + if (relative_position_) { + std::optional base = + source->GetSourcePosition(element_id_); + if (element_id_ == -1) { + base.emplace(0); + } + if (base) { + location = source->content() + ->GetLocation(*base + *relative_position_) + .value_or(SourceLocation{}); + } + } else { + location = + source->GetSourceLocation(element_id_).value_or(SourceLocation{}); + } + description = std::string(source->content()->description()); + snippet = source->content()->DisplayErrorLocation(location); + } + + const int display_column = location.column >= 0 ? location.column + 1 : -1; + + return absl::StrFormat("%s: %s:%d:%d: %s%s", SeverityString(severity_), + description, location.line, display_column, message_, + snippet); +} + +std::string CelPolicyParseResult::FormattedIssues() const { + std::string formatted_issues; + for (const CelPolicyIssue& issue : issues_) { + if (!formatted_issues.empty()) { + absl::StrAppend(&formatted_issues, "\n"); + } + absl::StrAppend(&formatted_issues, issue.ToDisplayString(*policy_source_)); + } + return formatted_issues; +} + +} // namespace cel diff --git a/policy/cel_policy_parse_result.h b/policy/cel_policy_parse_result.h new file mode 100644 index 000000000..2bf80b1ce --- /dev/null +++ b/policy/cel_policy_parse_result.h @@ -0,0 +1,105 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_PARSE_RESULT_H_ +#define THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_PARSE_RESULT_H_ + +#include +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "common/source.h" +#include "policy/cel_policy.h" + +namespace cel { + +class CelPolicyIssue { + public: + enum class Severity { kInformation, kDeprecated, kWarning, kError }; + + CelPolicyIssue(CelPolicyElementId element_id, absl::string_view message) + : element_id_(element_id), message_(message) {} + CelPolicyIssue(CelPolicyElementId element_id, Severity severity, + absl::string_view message) + : element_id_(element_id), severity_(severity), message_(message) {} + CelPolicyIssue(CelPolicyElementId element_id, + SourcePosition relative_position, absl::string_view message) + : element_id_(element_id), + relative_position_(relative_position), + message_(message) {} + CelPolicyIssue(CelPolicyElementId element_id, + SourcePosition relative_position, Severity severity, + absl::string_view message) + : element_id_(element_id), + relative_position_(relative_position), + severity_(severity), + message_(message) {} + + std::string ToDisplayString( + const CelPolicySource* absl_nullable source) const; + std::string ToDisplayString(const CelPolicySource& source) const { + return ToDisplayString(&source); + } + + Severity severity() const { return severity_; } + absl::string_view message() const { return message_; } + + private: + CelPolicyElementId element_id_; + std::optional relative_position_; + Severity severity_ = Severity::kError; + std::string message_; +}; + +class CelPolicyParseResult { + public: + explicit CelPolicyParseResult(std::shared_ptr policy_source, + std::unique_ptr policy, + std::vector issues) + : policy_source_(std::move(policy_source)), + policy_(std::move(policy)), + issues_(std::move(issues)) {} + + bool IsValid() const { return policy_ != nullptr; } + + const CelPolicy* absl_nullable GetPolicy() const { return policy_.get(); } + + absl::StatusOr> ReleasePolicy() { + if (policy_ == nullptr) { + return absl::FailedPreconditionError( + "CelPolicyParseResult is empty. Check for Issues."); + } + return std::move(policy_); + } + + absl::Span GetIssues() const { return issues_; } + + std::string FormattedIssues() const; + + private: + std::shared_ptr policy_source_; + absl_nullable std::unique_ptr policy_; + std::vector issues_; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_PARSE_RESULT_H_ diff --git a/policy/cel_policy_parser.h b/policy/cel_policy_parser.h new file mode 100644 index 000000000..0a11c9e68 --- /dev/null +++ b/policy/cel_policy_parser.h @@ -0,0 +1,40 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_PARSER_H_ +#define THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_PARSER_H_ + +#include "absl/status/status.h" +#include "policy/cel_policy_parse_context.h" + +namespace cel { + +// A policy parser for a given policy format. The type `T` parameter is the +// representation of the input file format, such as `` for YAML. +// +// Parsers are intended to be stateless: all state, including the resulting +// policy and any issues encountered, should be kept in the context passed to +// the `ParsePolicy` method. +template +class CelPolicyParser { + public: + virtual ~CelPolicyParser() = default; + + // Parses the input and populates a CelPolicy in the context. + virtual absl::Status ParsePolicy(CelPolicyParseContext& ctx) const = 0; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_PARSER_H_ diff --git a/policy/cel_policy_test.cc b/policy/cel_policy_test.cc new file mode 100644 index 000000000..640247e7f --- /dev/null +++ b/policy/cel_policy_test.cc @@ -0,0 +1,220 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "policy/cel_policy.h" + +#include +#include +#include +#include + +#include "absl/strings/str_replace.h" +#include "common/source.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +using testing::Field; +using testing::Optional; +using testing::SizeIs; + +TEST(CelPolicyBuilderTest, Build) { + CelPolicyElementId next_id = 1; + ASSERT_OK_AND_ASSIGN(SourcePtr source, NewSource("CEL\n policy\n source")); + std::shared_ptr policy_source = + std::make_shared(std::move(source)); + CelPolicy policy(policy_source); + policy.set_name(ValueString(next_id++, "test_policy")); + policy.set_description(ValueString(next_id++, "test_description")); + policy.set_display_name(ValueString(next_id++, "test_display_name")); + ValueString import1_name = ValueString(next_id++, "test_import1"); + policy.mutable_imports().push_back(Import(next_id++, import1_name)); + ValueString import2_name = ValueString(next_id++, "test_import2"); + policy.mutable_imports().push_back(Import(next_id++, import2_name)); + + Rule& rule = policy.mutable_rule(); + rule.set_id(next_id++); + rule.set_rule_id(ValueString(next_id++, "test_rule_id")); + rule.set_description(ValueString(next_id++, "test_rule_description")); + + Variable variable; + variable.set_name(ValueString(next_id++, "test_variable")); + variable.set_expression(ValueString(next_id++, "test_expression")); + variable.set_description(ValueString(next_id++, "test_variable_description")); + variable.set_display_name( + ValueString(next_id++, "test_variable_display_name")); + + Match match1; + match1.set_id(next_id++); + match1.set_condition(ValueString(next_id++, "test_condition")); + CelPolicyElementId output_id = next_id++; + CelPolicyElementId explanation_id = next_id++; + match1.set_result( + OutputBlock(ValueString(output_id, "test_result"), + ValueString(explanation_id, "test_explanation"))); + + Match match2; + match2.set_id(next_id++); + match2.set_condition(ValueString(next_id++, "test_condition2")); + + auto sub_rule = std::make_unique(); + sub_rule->set_id(next_id++); + sub_rule->set_rule_id(ValueString(next_id++, "sub_rule_id")); + sub_rule->set_description(ValueString(next_id++, "sub_rule_description")); + Match sub_rule_match; + sub_rule_match.set_id(next_id++); + sub_rule_match.set_condition(ValueString(next_id++, "sub_rule_condition")); + sub_rule_match.set_result( + OutputBlock(ValueString(next_id++, "sub_rule_result"), std::nullopt)); + sub_rule->mutable_matches().push_back(sub_rule_match); + + match2.set_result(std::move(sub_rule)); + + rule.mutable_variables().push_back(variable); + rule.mutable_matches().push_back(match1); + rule.mutable_matches().push_back(match2); + + EXPECT_EQ(policy.name().value(), "test_policy"); + ASSERT_TRUE(policy.description().has_value()); + EXPECT_EQ(policy.description()->value(), "test_description"); + ASSERT_TRUE(policy.display_name().has_value()); + EXPECT_EQ(policy.display_name()->value(), "test_display_name"); + + ASSERT_THAT(policy.imports(), SizeIs(2)); + + EXPECT_EQ(policy.imports()[0].name().value(), "test_import1"); + EXPECT_EQ(policy.imports()[1].name().value(), "test_import2"); + ASSERT_TRUE(policy.rule().rule_id().has_value()); + EXPECT_EQ(policy.rule().rule_id()->value(), "test_rule_id"); + ASSERT_TRUE(policy.rule().description().has_value()); + EXPECT_EQ(policy.rule().description()->value(), "test_rule_description"); + + ASSERT_THAT(policy.rule().variables(), SizeIs(1)); + + EXPECT_EQ(policy.rule().variables()[0].name().value(), "test_variable"); + EXPECT_EQ(policy.rule().variables()[0].expression().value(), + "test_expression"); + ASSERT_TRUE(policy.rule().variables()[0].description().has_value()); + EXPECT_EQ(policy.rule().variables()[0].description()->value(), + "test_variable_description"); + ASSERT_TRUE(policy.rule().variables()[0].display_name().has_value()); + EXPECT_EQ(policy.rule().variables()[0].display_name()->value(), + "test_variable_display_name"); + + ASSERT_THAT(policy.rule().matches(), SizeIs(2)); + + EXPECT_EQ(policy.rule().matches()[0].condition().value().value(), + "test_condition"); + ASSERT_TRUE(policy.rule().matches()[0].has_output_block()); + EXPECT_EQ(policy.rule().matches()[0].output_block().output().value(), + "test_result"); + ASSERT_TRUE( + policy.rule().matches()[0].output_block().explanation().has_value()); + EXPECT_EQ(policy.rule().matches()[0].output_block().explanation()->value(), + "test_explanation"); + + EXPECT_EQ(policy.rule().matches()[1].condition().value().value(), + "test_condition2"); + ASSERT_TRUE(policy.rule().matches()[1].has_rule()); + ASSERT_TRUE(policy.rule().matches()[1].rule().rule_id().has_value()); + EXPECT_EQ(policy.rule().matches()[1].rule().rule_id()->value(), + "sub_rule_id"); + ASSERT_TRUE(policy.rule().matches()[1].rule().description().has_value()); + EXPECT_EQ(policy.rule().matches()[1].rule().description()->value(), + "sub_rule_description"); + ASSERT_THAT(policy.rule().matches()[1].rule().matches(), SizeIs(1)); + EXPECT_EQ(policy.rule() + .matches()[1] + .rule() + .matches()[0] + .condition() + .value() + .value(), + "sub_rule_condition"); + + std::string actual = policy.DebugString(); + EXPECT_EQ(actual, absl::StrReplaceAll(R"(CelPolicy{ + =========================================================== + CEL + policy + source + =========================================================== + name: #1> "test_policy" + description: #2> "test_description" + display_name: #3> "test_display_name" + imports: + #5> name: #4> "test_import1" + #7> name: #6> "test_import2" + #8> rule: { + rule_id: #9> "test_rule_id" + description: #10> "test_rule_description" + variable: { + name: #11> "test_variable" + expression: #12> "test_expression" + description: #13> "test_variable_description" + display_name: #14> "test_variable_display_name" + } + #15> match: { + condition: #16> "test_condition" + result: { + output: #17> "test_result" + explanation: #18> "test_explanation" + } + } + #19> match: { + condition: #20> "test_condition2" + result: + #21> rule: { + rule_id: #22> "sub_rule_id" + description: #23> "sub_rule_description" + #24> match: { + condition: #25> "sub_rule_condition" + result: { + output: #26> "sub_rule_result" + } + } + } + } + } + })", + {{"\n ", "\n"}})); +} + +TEST(CelPolicySourceTest, Build) { + std::string source = + "name: test_policy\n imports:\n - name: test_import\n"; + + ASSERT_OK_AND_ASSIGN(SourcePtr source_ptr, NewSource(source)); + CelPolicySource policy_source(std::move(source_ptr)); + policy_source.NoteSourcePosition(1, source.find("test_policy")); + policy_source.NoteSourcePosition(2, source.find("test_import")); + + EXPECT_THAT(policy_source.GetSourcePosition(1), Optional(6)); + EXPECT_THAT(policy_source.GetSourceLocation(1), + Optional(AllOf(Field(&SourceLocation::line, 1), + Field(&SourceLocation::column, 6)))); + EXPECT_THAT(policy_source.GetSourcePosition(2), Optional(44)); + EXPECT_THAT(policy_source.GetSourceLocation(2), + Optional(AllOf(Field(&SourceLocation::line, 3), + Field(&SourceLocation::column, 13)))); + EXPECT_EQ(policy_source.GetSourcePosition(3), std::nullopt); + EXPECT_EQ(policy_source.GetSourceLocation(3), std::nullopt); + EXPECT_EQ( + policy_source.DebugString(), + "name: #1> test_policy\n imports:\n - name: #2> test_import\n"); +} + +} // namespace +} // namespace cel diff --git a/policy/cel_policy_validation_result.cc b/policy/cel_policy_validation_result.cc new file mode 100644 index 000000000..e257f064c --- /dev/null +++ b/policy/cel_policy_validation_result.cc @@ -0,0 +1,32 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "policy/cel_policy_validation_result.h" + +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "policy/cel_policy_parse_result.h" + +namespace cel { + +std::string CelPolicyValidationResult::FormatIssues() const { + return absl::StrJoin( + issues_, "\n", [this](std::string* out, const CelPolicyIssue& issue) { + absl::StrAppend(out, issue.ToDisplayString(source_.get())); + }); +} + +} // namespace cel diff --git a/policy/cel_policy_validation_result.h b/policy/cel_policy_validation_result.h new file mode 100644 index 000000000..bddb9a3ca --- /dev/null +++ b/policy/cel_policy_validation_result.h @@ -0,0 +1,84 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_VALIDATION_RESULT_H_ +#define THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_VALIDATION_RESULT_H_ + +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "common/ast.h" +#include "policy/cel_policy.h" +#include "policy/cel_policy_parse_result.h" + +namespace cel { + +// CelPolicyValidationResult holds the result of policy compilation. +// +// Policy compilation/validation errors are captured in issues. +class CelPolicyValidationResult { + public: + CelPolicyValidationResult( + std::unique_ptr ast, std::vector issues, + std::shared_ptr source = nullptr) + : ast_(std::move(ast)), + issues_(std::move(issues)), + source_(std::move(source)) {} + + explicit CelPolicyValidationResult( + std::vector issues, + std::shared_ptr source = nullptr) + : ast_(nullptr), issues_(std::move(issues)), source_(std::move(source)) {} + + // Returns true if validation succeeded and an AST is present. + bool IsValid() const { return ast_ != nullptr; } + + // Returns the AST if validation was successful. + const Ast* absl_nullable GetAst() const { return ast_.get(); } + + // Moves out and returns the AST. + absl::StatusOr> ReleaseAst() { + if (ast_ == nullptr) { + return absl::FailedPreconditionError( + "CelPolicyValidationResult is empty. Check for CelPolicyIssues."); + } + return std::move(ast_); + } + + // Returns the list of issues encountered during compilation. + absl::Span GetIssues() const { return issues_; } + + // Returns the contained policy source, if any. + const CelPolicySource* absl_nullable GetSource() const { + return source_.get(); + } + + // Returns a formatted error string of the compiled issues. + std::string FormatIssues() const; + + private: + absl_nullable std::unique_ptr ast_; + std::vector issues_; + std::shared_ptr source_; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_VALIDATION_RESULT_H_ diff --git a/policy/compiler.cc b/policy/compiler.cc new file mode 100644 index 000000000..7a892447c --- /dev/null +++ b/policy/compiler.cc @@ -0,0 +1,1058 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "policy/compiler.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/nullability.h" +#include "absl/cleanup/cleanup.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "checker/type_check_issue.h" +#include "checker/validation_result.h" +#include "common/ast.h" +#include "common/ast_rewrite.h" +#include "common/constant.h" +#include "common/container.h" +#include "common/decl.h" +#include "common/expr.h" +#include "common/format_type_name.h" +#include "common/navigable_ast.h" +#include "common/source.h" +#include "common/type.h" +#include "common/type_kind.h" +#include "compiler/compiler.h" +#include "internal/status_macros.h" +#include "policy/cel_policy.h" +#include "policy/cel_policy_parse_result.h" +#include "policy/cel_policy_validation_result.h" +#include "policy/internal/issue_reporter.h" +#include "policy/internal/optimizer_expr_factory.h" +#include "google/protobuf/arena.h" + +namespace cel { +namespace { + +constexpr absl::string_view kCelBlock = "cel.@block"; + +enum class RuleSemantics { + // TODO(b/506179116): will also need "aggregate" or similar concept. + kFirstMatch, + + kNotForUseWithExhaustiveSwitchStatements, +}; + +template +void AbslStringify(Sink& s, RuleSemantics semantics) { + switch (semantics) { + case RuleSemantics::kFirstMatch: + s.Append("first_match"); + return; + default: + s.Append(""); + return; + } +} + +struct EmbeddedAst { + CelPolicyElementId id; + std::unique_ptr ast; +}; + +struct CompiledVariable { + std::string ident; + EmbeddedAst ast; +}; + +struct CompiledOutputBlock { + EmbeddedAst output_ast; + cel::Type result_type; + std::optional explanation_ast; +}; + +struct CompiledRule; + +struct CompiledMatch { + using Production = + std::variant absl_nonnull, + CompiledOutputBlock>; + + CelPolicyElementId id; + std::optional condition; + Production production; +}; + +struct CompiledRule { + CelPolicyElementId id; + std::vector variables; + std::vector matches; + // Not set if cannot be determined. + std::optional result_type; +}; + +std::optional GetOutputType( + const CompiledMatch::Production& production) { + return std::visit( + [](const auto& production) -> std::optional { + if constexpr (std::is_same_v, + CompiledOutputBlock>) { + return production.result_type; + } else if constexpr (std::is_same_v, + std::unique_ptr>) { + return production->result_type; + } + return std::nullopt; + }, + production); +} + +// Internal representation of the compiled policy elements. +// +// This is used for checking the component expression before composing into the +// final AST based on the provided rule semantics. +class IntermediateCompiledPolicy { + public: + CompiledRule& mutable_root_rule() { return root_rule_; } + + const CompiledRule& root_rule() const { return root_rule_; } + + void set_name(absl::string_view name) { name_ = name; } + absl::string_view name() const { return name_; } + void set_display_name(absl::string_view display_name) { + display_name_ = display_name; + } + absl::string_view display_name() const { return display_name_; } + void set_description(absl::string_view description) { + description_ = description; + } + absl::string_view description() const { return description_; } + + void set_semantics(RuleSemantics semantics) { semantics_ = semantics; } + RuleSemantics semantics() const { return semantics_; } + + private: + std::string name_; + std::string display_name_; + std::string description_; + RuleSemantics semantics_ = RuleSemantics::kFirstMatch; + + CompiledRule root_rule_; +}; + +CelPolicyIssue::Severity MapSeverity(cel::TypeCheckIssue::Severity severity) { + switch (severity) { + case cel::TypeCheckIssue::Severity::kError: + return CelPolicyIssue::Severity::kError; + case cel::TypeCheckIssue::Severity::kWarning: + return CelPolicyIssue::Severity::kWarning; + case cel::TypeCheckIssue::Severity::kDeprecated: + return CelPolicyIssue::Severity::kDeprecated; + default: + return CelPolicyIssue::Severity::kError; + } +} + +bool IsWrapperOf(cel::TypeKind wrapper_kind, cel::TypeKind primitive_kind) { + switch (wrapper_kind) { + case cel::TypeKind::kBoolWrapper: + return primitive_kind == cel::TypeKind::kBool; + case cel::TypeKind::kIntWrapper: + return primitive_kind == cel::TypeKind::kInt; + case cel::TypeKind::kUintWrapper: + return primitive_kind == cel::TypeKind::kUint; + case cel::TypeKind::kDoubleWrapper: + return primitive_kind == cel::TypeKind::kDouble; + case cel::TypeKind::kStringWrapper: + return primitive_kind == cel::TypeKind::kString; + case cel::TypeKind::kBytesWrapper: + return primitive_kind == cel::TypeKind::kBytes; + default: + return false; + } +} + +cel::Type FilterSpecialTypes(cel::Type type) { + if (type.IsTypeParam()) { + // Free type param should not appear in the output type, but if it does, + // force it to dyn. + return DynType(); + } + if (type.IsEnum()) { + return IntType{}; + } + if (type.IsError()) { + return DynType(); + } + if (type.IsType()) { + // drop parameters so all type types are compatible. + return TypeType{}; + } + return type; +} + +// Returns true if `from` is assignable to `to`. +// +// Slightly adjusted from the standard routine to cover some edge cases around +// null and wrappers. +// +// TODO(b/522391716): try to standardize assignability checks. +bool OutputTypeIsAssignable(cel::Type from, cel::Type to) { + from = FilterSpecialTypes(from); + to = FilterSpecialTypes(to); + + // Any and dyn are assignable to/from everything. + if (from.kind() == cel::TypeKind::kAny || + from.kind() == cel::TypeKind::kDyn || to.kind() == cel::TypeKind::kAny || + to.kind() == cel::TypeKind::kDyn) { + return true; + } + + // Wrappers auto-unwrap. + if (IsWrapperOf(from.kind(), to.kind()) || + IsWrapperOf(to.kind(), from.kind())) { + return true; + } + + // Null is assignable to anything that is message-like. + if (from.kind() == cel::TypeKind::kNull) { + switch (to.kind()) { + case cel::TypeKind::kNull: + case cel::TypeKind::kStruct: + case cel::TypeKind::kOpaque: + case cel::TypeKind::kTimestamp: + case cel::TypeKind::kDuration: + case cel::TypeKind::kBytesWrapper: + case cel::TypeKind::kBoolWrapper: + case cel::TypeKind::kIntWrapper: + case cel::TypeKind::kUintWrapper: + case cel::TypeKind::kDoubleWrapper: + case cel::TypeKind::kStringWrapper: + return true; + default: + return false; + } + } + + if (from.kind() != to.kind()) { + return false; + } + + if (from.name() != to.name()) { + return false; + } + + if (from.GetParameters().size() != to.GetParameters().size()) { + return false; + } + + for (int i = 0; i < from.GetParameters().size(); ++i) { + if (!OutputTypeIsAssignable(from.GetParameters()[i], + to.GetParameters()[i])) { + return false; + } + } + + return true; +} + +bool OutputTypeIsCompatible(cel::Type from, cel::Type to) { + // We don't handle widening like in a self-contained CEL expression, but + // permit some cases where one type is more specific than the other. + return OutputTypeIsAssignable(from, to) || OutputTypeIsAssignable(to, from); +} + +bool HasErrors(const policy_internal::IssueReporter& issues) { + for (const auto& issue : issues.issues()) { + if (issue.severity() == CelPolicyIssue::Severity::kError) { + return true; + } + } + return false; +} + +// Note on lifetime safety: +// +// The output policy will contain references to types that are owned by the +// arena member of this class. This is safe as long as the policy compiler lives +// as long as the output policies. +class PolicyCompiler { + public: + explicit PolicyCompiler(policy_internal::IssueReporter* issues, + std::unique_ptr base_compiler) + : issues_(*issues), base_compiler_(std::move(base_compiler)) {} + + absl::string_view GetSourceDescription() const { + if (src_ == nullptr) { + return ""; + } + return src_->content()->description(); + } + + void AdaptTypeCheckIssues(CelPolicyElementId id, const ValidationResult& r) { + const Source* source = r.GetSource(); + + for (const auto& iss : r.GetIssues()) { + std::optional offset; + if (source != nullptr) { + offset = source->GetPosition(iss.location()); + } + if (offset.has_value()) { + issues_.ReportOffsetIssue(id, offset.value(), + MapSeverity(iss.severity()), iss.message()); + continue; + } + issues_.ReportIssue(id, MapSeverity(iss.severity()), iss.message()); + } + } + + absl::StatusOr CompileOutputBlock( + const cel::OutputBlock& output_block, const Compiler* env) { + CompiledOutputBlock output; + CEL_ASSIGN_OR_RETURN(auto output_validation, + env->Compile(output_block.output().value(), + GetSourceDescription(), &arena_)); + AdaptTypeCheckIssues(output_block.output().id(), output_validation); + + cel::Type result_type = DynType(); + if (output_validation.IsValid()) { + CEL_ASSIGN_OR_RETURN(auto ast, output_validation.ReleaseAst()); + auto root_expr_id = ast->root_expr().id(); + output.output_ast = + EmbeddedAst{output_block.output().id(), std::move(ast)}; + if (auto it = output_validation.GetResolvedTypeMap().find(root_expr_id); + it != output_validation.GetResolvedTypeMap().end()) { + result_type = it->second; + } + } + if (output_block.explanation().has_value()) { + CEL_ASSIGN_OR_RETURN(auto explanation_validation, + env->Compile(output_block.explanation()->value(), + GetSourceDescription(), &arena_)); + AdaptTypeCheckIssues(output_block.explanation()->id(), + explanation_validation); + if (explanation_validation.IsValid()) { + CEL_ASSIGN_OR_RETURN(auto ast, explanation_validation.ReleaseAst()); + if (ast->GetReturnType().primitive() != PrimitiveType::kString) { + issues_.ReportError(output_block.explanation()->id(), + "explanation must evaluate to string"); + } else { + output.explanation_ast = + EmbeddedAst{output_block.explanation()->id(), std::move(ast)}; + } + } + } + output.result_type = result_type; + return output; + } + + absl::Status CompileMatch(const Match& match, const Compiler* env, + CompiledRule* out) { + CompiledMatch c_match; + c_match.id = match.id(); + if (match.condition().has_value()) { + CEL_ASSIGN_OR_RETURN(auto validation, + env->Compile(match.condition()->value(), + GetSourceDescription(), &arena_)); + AdaptTypeCheckIssues(match.condition()->id(), validation); + if (validation.IsValid()) { + CEL_ASSIGN_OR_RETURN(auto ast, validation.ReleaseAst()); + if (ast->GetReturnType().primitive() != PrimitiveType::kBool) { + issues_.ReportError(match.condition()->id(), + "condition must evaluate to bool"); + } + c_match.condition = + EmbeddedAst{match.condition()->id(), std::move(ast)}; + } + } + + if (match.has_output_block()) { + CEL_ASSIGN_OR_RETURN(c_match.production, + CompileOutputBlock(match.output_block(), env)); + } else if (match.has_rule()) { + auto rule = std::make_unique(); + CEL_RETURN_IF_ERROR(CompileRule(match.rule(), env, rule.get())); + c_match.production = std::move(rule); + } else { + issues_.ReportError(match.id(), "match must specify an output or rule"); + } + out->matches.push_back(std::move(c_match)); + return absl::OkStatus(); + } + + absl::Status CompileRule(const Rule& rule, const cel::Compiler* env, + CompiledRule* out) { + out->id = rule.id(); + std::unique_ptr buf; + + absl::flat_hash_set seen_variables; + for (const auto& variable : rule.variables()) { + std::string name(variable.name().value()); + if (!seen_variables.insert(name).second) { + issues_.ReportError( + variable.expression().id(), + absl::StrCat("overlapping identifier for name 'variables.", name, + "'")); + continue; + } + std::string ident = absl::StrCat("variables.", name); + CEL_ASSIGN_OR_RETURN(auto validation, + env->Compile(variable.expression().value(), + GetSourceDescription(), &arena_)); + AdaptTypeCheckIssues(variable.expression().id(), validation); + if (!validation.IsValid()) { + continue; + } + CEL_ASSIGN_OR_RETURN(auto ast, validation.ReleaseAst()); + cel::Type result_type = DynType(); + + if (auto it = validation.GetResolvedTypeMap().find(ast->root_expr().id()); + it != validation.GetResolvedTypeMap().end()) { + result_type = it->second; + } + out->variables.push_back(CompiledVariable{ + ident, + EmbeddedAst{variable.expression().id(), std::move(ast)}, + }); + auto next = env->ToBuilder(); + auto status = next->GetCheckerBuilder().AddOrReplaceVariable( + MakeVariableDecl(ident, result_type)); + if (!status.ok()) { + issues_.ReportError(variable.expression().id(), status.message()); + continue; + } + CEL_ASSIGN_OR_RETURN(buf, next->Build()); + env = buf.get(); + } + + std::optional overall_type; + for (const auto& match : rule.matches()) { + CEL_RETURN_IF_ERROR(CompileMatch(match, env, out)); + if (!overall_type.has_value()) { + overall_type = GetOutputType(out->matches.back().production); + continue; + } + + if (std::optional match_type = + GetOutputType(out->matches.back().production); + match_type.has_value()) { + if (!OutputTypeIsCompatible(*match_type, *overall_type)) { + issues_.ReportError( + match.id(), + absl::StrCat("incompatible output types: block has output type ", + FormatTypeName(*match_type), + ", but previous outputs have type ", + FormatTypeName(*overall_type))); + } + } + } + + out->result_type = overall_type; + return absl::OkStatus(); + } + + absl::Status CompilePolicy(const CelPolicy& policy, + IntermediateCompiledPolicy* out) { + src_ = policy.source(); + out->set_semantics(RuleSemantics::kFirstMatch); + out->set_name(policy.name().value()); + out->set_display_name( + policy.display_name().value_or(ValueString{}).value()); + out->set_description(policy.description().value_or(ValueString{}).value()); + + return CompileRule(policy.rule(), base_compiler_.get(), + &out->mutable_root_rule()); + } + + private: + google::protobuf::Arena arena_; + const CelPolicySource* absl_nullable src_; + policy_internal::IssueReporter& issues_; + std::unique_ptr base_compiler_; +}; + +bool IsExhaustive(const CompiledRule& rule); + +class FirstMatchComposer { + public: + FirstMatchComposer(const IntermediateCompiledPolicy& icp, + const Compiler& compiler, + policy_internal::IssueReporter& issues) + : issues_(issues), icp_(icp), compiler_(compiler) {} + + absl::Status Compose(); + + bool success() const { return ast_ != nullptr; } + + std::unique_ptr ReleaseAst() { return std::move(ast_); } + + private: + using VariableScope = absl::flat_hash_map; + + std::optional ResolvePolicyVariable(absl::string_view reference); + + absl::flat_hash_map ResolveBlockIndexes(const Ast& ast); + + bool CheckMatchStructure(const CompiledRule& rule); + + // Returns true if already optional wrapped. + absl::StatusOr ComposeRule(const CompiledRule& rule, Expr& init, + Expr& insertion_expr); + + // returns true if already optional wrapped. + absl::StatusOr ComposeProduction( + const CompiledRule& rule, const CompiledMatch::Production& production, + Expr& init, Expr& insertion_expr); + + void MapVariables(Ast& ast); + + void ComposeRuleVariables(const CompiledRule& rule, Expr& init, + Expr& insertion_expr); + + policy_internal::IssueReporter& issues_; + OptimizerExprFactory factory_; + const IntermediateCompiledPolicy& icp_; + const Compiler& compiler_; + std::vector scopes_; + bool optionalize_ = false; + std::unique_ptr ast_; +}; + +absl::Status FirstMatchComposer::Compose() { + ABSL_DCHECK(icp_.semantics() == RuleSemantics::kFirstMatch); + + factory_.mutable_ast().mutable_root_expr() = factory_.NewCall( + "cel.@block", factory_.NewList(), factory_.NewUnspecified()); + auto& block_init_list = factory_.mutable_ast() + .mutable_root_expr() + .mutable_call_expr() + .mutable_args()[0]; + auto& insertion_expr = factory_.mutable_ast() + .mutable_root_expr() + .mutable_call_expr() + .mutable_args()[1]; + optionalize_ = !IsExhaustive(icp_.root_rule()); + if (!CheckMatchStructure(icp_.root_rule())) { + return absl::OkStatus(); + } + CEL_ASSIGN_OR_RETURN( + bool optional_wrapped, + ComposeRule(icp_.root_rule(), block_init_list, insertion_expr)); + + if (optional_wrapped != optionalize_) { + return absl::InternalError( + "composition failed to handle non-exhaustive rules"); + } + + CEL_ASSIGN_OR_RETURN(cel::ValidationResult result, + compiler_.GetTypeChecker().Check(factory_.ast())); + if (!result.IsValid()) { + for (const auto& iss : result.GetIssues()) { + issues_.ReportError(icp_.root_rule().id, iss.message()); + } + return absl::OkStatus(); + } + + CEL_ASSIGN_OR_RETURN(ast_, result.ReleaseAst()); + + return absl::OkStatus(); +} + +bool IsTriviallyTrueCondition(const CompiledMatch& match) { + if (!match.condition.has_value() || match.condition->ast == nullptr) { + return true; + } + const cel::Expr& expr = match.condition->ast->root_expr(); + if (expr.has_const_expr()) { + const cel::Constant& const_expr = expr.const_expr(); + if (const_expr.has_bool_value() && const_expr.bool_value()) { + return true; + } + } + return false; +} + +bool IsExhaustive(const CompiledRule& rule); + +bool IsExhaustive(const CompiledMatch& match) { + if (std::holds_alternative(match.production)) { + return true; + } + + const auto* nested_rule_ptr = + std::get_if>(&match.production); + ABSL_DCHECK(nested_rule_ptr != nullptr); + const CompiledRule& nested_rule = **nested_rule_ptr; + return IsExhaustive(nested_rule); +} + +bool IsExhaustive(const CompiledRule& rule) { + if (rule.matches.empty()) { + // Validation should fail, but generalization would be false. + return false; + } + bool has_default = false; + for (const auto& match : rule.matches) { + if (IsTriviallyTrueCondition(match) && IsExhaustive(match)) { + // If this isn't the last match in the rule, it should get flagged + // during validation since it means there are trivially unreachable + // matches. + has_default = true; + } + if (!IsTriviallyTrueCondition(match) && !IsExhaustive(match)) { + // There is a nested rule that might return an optional.none(). + return false; + } + } + // Otherwise, everything in this branch is exhaustive so we can defer + // wrapping. + return has_default; +} + +bool FirstMatchComposer::CheckMatchStructure(const CompiledRule& rule) { + if (rule.matches.empty()) { + issues_.ReportError(rule.id, "rule does not specify match conditions"); + return false; + } + + bool valid = true; + bool seen_trivially_true = false; + + for (const auto& match : rule.matches) { + if (seen_trivially_true) { + if (std::holds_alternative(match.production)) { + issues_.ReportError(match.id, "match creates unreachable outputs"); + } else if (std::holds_alternative>( + match.production)) { + issues_.ReportError(match.id, "rule creates unreachable outputs"); + } + valid = false; + } + + if (IsTriviallyTrueCondition(match) && IsExhaustive(match)) { + seen_trivially_true = true; + } + + if (auto* nested_rule = + std::get_if>(&match.production); + nested_rule != nullptr) { + ABSL_DCHECK(*nested_rule != nullptr); + if (!CheckMatchStructure(**nested_rule)) { + valid = false; + } + } + } + + return valid; +} + +std::optional FirstMatchComposer::ResolvePolicyVariable( + absl::string_view reference) { + for (auto scope_iter = scopes_.rbegin(); scope_iter != scopes_.rend(); + ++scope_iter) { + if (auto it = scope_iter->find(reference); it != scope_iter->end()) { + return it->second; + } + } + return std::nullopt; +} + +class IndexRewrite : public AstRewriterBase { + public: + explicit IndexRewrite(absl::flat_hash_map expr_id_to_index, + OptimizerExprFactory& factory) + : expr_id_to_index_(std::move(expr_id_to_index)), factory_(factory) {} + + bool PreVisitRewrite(Expr& e) override { + if (auto it = expr_id_to_index_.find(e.id()); + it != expr_id_to_index_.end()) { + e.mutable_ident_expr().set_name(absl::StrCat("@index", it->second)); + factory_.RecordReplacement(e.id(), e); + return true; + } + return false; + } + + private: + absl::flat_hash_map expr_id_to_index_; + OptimizerExprFactory& factory_; +}; + +absl::StatusOr FirstMatchComposer::ComposeRule(const CompiledRule& rule, + Expr& init, + Expr& insertion_expr) { + scopes_.emplace_back(); + auto pop_scope = absl::MakeCleanup([this]() { scopes_.pop_back(); }); + ComposeRuleVariables(rule, init, insertion_expr); + Expr* insertion_point = &insertion_expr; + const bool has_default = IsTriviallyTrueCondition(rule.matches.back()); + const bool needs_wrap = !IsExhaustive(rule); + size_t end = rule.matches.size() - (has_default ? 1 : 0); + for (size_t i = 0; i < end; i++) { + const auto& match = rule.matches[i]; + if (IsTriviallyTrueCondition(match) && IsExhaustive(match)) { + return absl::InternalError("detected unreachable match after validation"); + } + + Expr production; + CEL_ASSIGN_OR_RETURN( + bool is_wrapped, + ComposeProduction(rule, match.production, init, production)); + if (needs_wrap && !is_wrapped) { + production = factory_.NewCall("optional.of", std::move(production)); + } + + if (!IsTriviallyTrueCondition(match)) { + Ast condition = *match.condition->ast; + MapVariables(condition); + factory_.StartCopyContext(); + auto copy = factory_.Copy(condition.root_expr()); + auto source_info = factory_.RemapSourceInfo(condition.source_info()); + factory_.MergeSourceInfo(source_info); + *insertion_point = factory_.NewCall("_?_:_", std::move(copy)); + insertion_point->mutable_call_expr().mutable_args().push_back( + std::move(production)); + ABSL_DCHECK(!(!needs_wrap && is_wrapped)) + << "unexpected wrapping in exhaustive policy."; + insertion_point = &insertion_point->mutable_call_expr().add_args(); + continue; + } + + if (!is_wrapped) { + return absl::InternalError( + "composition failed. expected optional wrapped rule but got a plain " + "value"); + } + auto fn = needs_wrap ? "or" : "orValue"; + *insertion_point = factory_.NewMemberCall(fn, std::move(production)); + insertion_point = &insertion_point->mutable_call_expr().add_args(); + } + + if (has_default) { + const auto& match = rule.matches.back(); + Expr production; + CEL_ASSIGN_OR_RETURN( + bool is_wrapped, + ComposeProduction(rule, match.production, init, production)); + if (needs_wrap && !is_wrapped) { + production = factory_.NewCall("optional.of", std::move(production)); + } + *insertion_point = std::move(production); + ABSL_DCHECK(!(!needs_wrap && is_wrapped)) + << "unexpected wrapping in exhaustive policy."; + + return needs_wrap; + } + + // Otherwise, we fell through a non-exhaustive rule. + *insertion_point = factory_.NewCall("optional.none"); + return true; +} + +absl::StatusOr FirstMatchComposer::ComposeProduction( + const CompiledRule& rule, const CompiledMatch::Production& production, + Expr& init, Expr& insertion_expr) { + if (auto* nested_rule = + std::get_if>(&production); + nested_rule != nullptr) { + return ComposeRule(**nested_rule, init, insertion_expr); + } + auto* output = std::get_if(&production); + if (output == nullptr) { + return absl::InternalError("unexpected rule production type"); + } + const EmbeddedAst& output_ast = output->output_ast; + Ast ast = *output_ast.ast; + MapVariables(ast); + factory_.StartCopyContext(); + Expr to_insert = factory_.Copy(ast.root_expr()); + auto source_info = factory_.RemapSourceInfo(ast.source_info()); + factory_.MergeSourceInfo(source_info); + insertion_expr = std::move(to_insert); + + return false; +} + +absl::flat_hash_map FirstMatchComposer::ResolveBlockIndexes( + const Ast& ast) { + absl::flat_hash_map out; + for (auto it = ast.reference_map().begin(); it != ast.reference_map().end(); + it++) { + const Reference& ref = it->second; + if (!it->second.overload_id().empty()) { + continue; + } + if (!absl::StartsWith(ref.name(), "variable")) { + continue; + } + if (auto index = ResolvePolicyVariable(ref.name()); index.has_value()) { + out[it->first] = *index; + } + } + return out; +} + +void FirstMatchComposer::MapVariables(Ast& ast) { + absl::flat_hash_map edit_map = ResolveBlockIndexes(ast); + IndexRewrite rewriter(std::move(edit_map), factory_); + AstRewrite(ast.mutable_root_expr(), rewriter); +} + +void FirstMatchComposer::ComposeRuleVariables(const CompiledRule& rule, + Expr& init, + Expr& insertion_expr) { + for (const auto& variable : rule.variables) { + Ast ast = *variable.ast.ast; + MapVariables(ast); + factory_.StartCopyContext(); + auto insertion = factory_.Copy(ast.root_expr()); + // TODO(b/506179116): apply the position offsets here. + auto info = factory_.RemapSourceInfo(ast.source_info()); + ABSL_DCHECK(init.has_list_expr()); + int index = init.mutable_list_expr().elements().size(); + init.mutable_list_expr().mutable_elements().push_back( + factory_.NewListElement(std::move(insertion))); + scopes_.back()[variable.ident] = index; + } +} + +bool HasComprehensionParent(const NavigableAstNode& node) { + const NavigableAstNode* curr = &node; + while (curr != nullptr) { + if (curr->node_kind() == NodeKind::kComprehension) { + return true; + } + curr = curr->parent(); + } + return false; +} + +// Unnester implementation. +class Unnester { + public: + Unnester(Ast ast, int height, policy_internal::IssueReporter& issues) + : factory_(std::move(ast)), height_(height), issues_(issues) {} + + // Run the unnesting. + // The class cannot be reused after this is called. + absl::StatusOr Unnest() { + if (height_ > 0) { + CEL_RETURN_IF_ERROR(Slice()); + } + CEL_RETURN_IF_ERROR(Cleanup()); + return std::move(factory_.mutable_ast()); + } + + private: + // The core unnest routine. + absl::Status Slice(); + // Fixup the AST post-unnesting. + absl::Status Cleanup(); + + void ReportErrorAtId(int64_t id, absl::string_view message); + + OptimizerExprFactory factory_; + int height_; + policy_internal::IssueReporter& issues_; +}; + +class UnnestRewriter : public AstRewriterBase { + public: + explicit UnnestRewriter(OptimizerExprFactory& f, Expr& block_list_expr, + absl::Span cuts) + : factory_(f), cuts_(cuts), block_list_expr_(block_list_expr) {} + + bool PostVisitRewrite(Expr& expr) override { + using std::swap; + // Post order so we always see children before parents. + // No need to copy metadata since we're only moving exprs or minting + // new ones. + if (absl::c_contains(cuts_, expr.id())) { + size_t idx = block_list_expr_.list_expr().elements().size(); + Expr value = factory_.NewIdent(absl::StrCat("@index", idx)); + factory_.RecordReplacement(expr.id(), value, /*keep_metadata=*/true); + swap(value, expr); + block_list_expr_.mutable_list_expr().mutable_elements().push_back( + factory_.NewListElement(std::move(value))); + return true; + } + return false; + } + + private: + OptimizerExprFactory& factory_; + absl::Span cuts_; + Expr& block_list_expr_; +}; + +absl::Status Unnester::Slice() { + Expr& root = factory_.mutable_ast().mutable_root_expr(); + if (root.call_expr().function() != kCelBlock || + root.call_expr().args().size() != 2 || + !root.call_expr().args()[0].has_list_expr()) { + return absl::InternalError("malformed AST detected during unnesting"); + } + // Two passes, we identify the slice points (bottom up), then cut + // and paste the leaves into the block list. + NavigableAst nav_ast = NavigableAst::Build(factory_.ast().root_expr()); + + ABSL_DCHECK(nav_ast.IdsAreUnique()); + bool can_cut = true; + std::vector cuts; + for (const NavigableAstNode& node : nav_ast.Root().DescendantsPostorder()) { + // Subsequent cuts will be height_ + 1 in the block, indices. Within the + // error margin we specified. + if (node.height() % height_ == 0) { + if (HasComprehensionParent(node)) { + ReportErrorAtId( + node.expr()->id(), + absl::StrCat( + "cannot unnest AST due to comprehension. cannot accommodate " + "height limit of ", + height_)); + can_cut = false; + continue; + } + if (&node == &nav_ast.Root()) { + // If evenly divisible by height, don't cut since it will net a taller + // AST. + continue; + } + cuts.push_back(node.expr()->id()); + } + } + + if (!can_cut || cuts.empty()) { + return absl::OkStatus(); + } + + Expr& block_list_expr = root.mutable_call_expr().mutable_args()[0]; + Expr& insertion_expr = root.mutable_call_expr().mutable_args()[1]; + + UnnestRewriter rewriter(factory_, block_list_expr, cuts); + AstRewrite(insertion_expr, rewriter); + + return absl::OkStatus(); +} + +absl::Status Unnester::Cleanup() { + using std::swap; + + const auto& ast = factory_.ast(); + if (ast.root_expr().call_expr().function() != kCelBlock || + ast.root_expr().call_expr().args().size() != 2 || + !ast.root_expr().call_expr().args()[0].has_list_expr()) { + return absl::InternalError("malformed AST detected during unnesting"); + } + if (ast.root_expr().call_expr().args()[0].list_expr().elements().empty()) { + Expr value = std::move(factory_.mutable_ast() + .mutable_root_expr() + .mutable_call_expr() + .mutable_args()[1]); + factory_.mutable_ast().mutable_root_expr() = std::move(value); + } + + return absl::OkStatus(); +} + +void Unnester::ReportErrorAtId(int64_t id, absl::string_view message) { + int32_t position = 0; + auto it = factory_.ast().source_info().positions().find(id); + if (it != factory_.ast().source_info().positions().end()) { + position = it->second; + } + issues_.ReportError(-1, position, message); +} +} // namespace + +// Compiles a CEL policy using the provided CEL compiler as a base environment. +absl::StatusOr CompilePolicy( + const Compiler& compiler, const CelPolicy& policy, + const CompilePolicyOptions& options) { + policy_internal::IssueReporter issues; + if (options.unnesting_height_limit != 0 && + options.unnesting_height_limit < 2) { + return absl::InvalidArgumentError( + "unnesting_height_limit must be at least 2"); + } + auto builder = compiler.ToBuilder(); + ExpressionContainer cont; + for (const auto& import : policy.imports()) { + auto status = cont.AddAbbreviation(import.name().value()); + if (!status.ok()) { + issues.ReportError( + import.name().id(), + absl::StrCat("'", import.name().value(), "': ", status.message())); + } + } + + builder->GetCheckerBuilder().SetExpressionContainer(cont); + CEL_ASSIGN_OR_RETURN(auto base_compiler, builder->Build()); + + PolicyCompiler policy_compiler(&issues, std::move(base_compiler)); + + IntermediateCompiledPolicy icp; + CEL_RETURN_IF_ERROR(policy_compiler.CompilePolicy(policy, &icp)); + + if (HasErrors(issues)) { + return CelPolicyValidationResult(issues.ReleaseIssues(), + policy.source_ptr()); + } + + CEL_ASSIGN_OR_RETURN(base_compiler, builder->Build()); + switch (icp.semantics()) { + case RuleSemantics::kFirstMatch: { + FirstMatchComposer composer(icp, *base_compiler, issues); + CEL_RETURN_IF_ERROR(composer.Compose()); + if (!composer.success()) { + return CelPolicyValidationResult(issues.ReleaseIssues(), + policy.source_ptr()); + } + + auto ast = composer.ReleaseAst(); + Unnester unnester(std::move(*ast), options.unnesting_height_limit, + issues); + CEL_ASSIGN_OR_RETURN(Ast unnested_ast, unnester.Unnest()); + + if (HasErrors(issues)) { + return CelPolicyValidationResult(issues.ReleaseIssues(), + policy.source_ptr()); + } + + return CelPolicyValidationResult( + std::make_unique(std::move(unnested_ast)), {}, + policy.source_ptr()); + } + default: + return absl::UnimplementedError( + absl::StrCat("Unsupported RuleSemantics: ", icp.semantics())); + } +} + +} // namespace cel diff --git a/policy/compiler.h b/policy/compiler.h new file mode 100644 index 000000000..0187bd1a2 --- /dev/null +++ b/policy/compiler.h @@ -0,0 +1,50 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_POLICY_COMPILER_H_ +#define THIRD_PARTY_CEL_CPP_POLICY_COMPILER_H_ + +#include "absl/status/statusor.h" +#include "compiler/compiler.h" +#include "policy/cel_policy.h" +#include "policy/cel_policy_validation_result.h" + +namespace cel { + +struct CompilePolicyOptions { + // If greater than 0, the compiler will attempt to unnest rule branches + // at the specified height. The overall height of the final AST may exceed + // this by a small, fixed margin. + // + // To avoid slicing comprehensions, subexpressions within comprehensions + // are not eligible for unnesting. If the height limit cannot be accommodated, + // an error with code InvalidArgument is returned. + // + // If the AST is converted to proto, even relatively low levels of nesting + // can cause problems in serialization/deserialization. This does not apply + // if the AST is used directly by the runtime. + int unnesting_height_limit = 0; +}; + +// Compiles a CEL policy using the provided CEL compiler as a base environment. +// +// TODO(b/506179116): Implementation in progress. Functionally complete, +// but errors are not consistent with other implementations. +absl::StatusOr CompilePolicy( + const Compiler& compiler, const CelPolicy& policy, + const CompilePolicyOptions& options = {}); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_POLICY_COMPILER_H_ diff --git a/policy/compiler_test.cc b/policy/compiler_test.cc new file mode 100644 index 000000000..8db494b45 --- /dev/null +++ b/policy/compiler_test.cc @@ -0,0 +1,946 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "policy/compiler.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "common/ast.h" +#include "common/decl.h" +#include "common/navigable_ast.h" +#include "common/source.h" +#include "common/type.h" +#include "common/types/message_type.h" +#include "common/value.h" +#include "common/value_testing.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "compiler/optional.h" +#include "compiler/standard_library.h" +#include "extensions/bindings_ext.h" +#include "internal/runfiles.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "policy/cel_policy.h" +#include "policy/cel_policy_parse_result.h" +#include "policy/cel_policy_validation_result.h" +#include "policy/yaml_policy_parser.h" +#include "runtime/activation.h" +#include "runtime/optional_types.h" +#include "runtime/runtime.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_options.h" +#include "runtime/standard_runtime_builder_factory.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::cel::test::IntValueIs; +using ::cel::test::OptionalValueIs; +using ::cel::test::OptionalValueIsEmpty; +using ::cel::test::StringValueIs; +using ::cel::test::ValueMatcher; + +constexpr absl::string_view kTestPolicyFilePath = +"_main/policy/testdata/cel_policy.yaml"; + +absl::StatusOr> BuildTestCompiler() { + CompilerOptions opts; + opts.adapt_parser_errors = true; + opts.parser_options.enable_optional_syntax = true; + CEL_ASSIGN_OR_RETURN( + auto builder, + NewCompilerBuilder(internal::GetSharedTestingDescriptorPool(), opts)); + + CEL_RETURN_IF_ERROR(builder->AddLibrary(cel::StandardCompilerLibrary())); + CEL_RETURN_IF_ERROR(builder->AddLibrary(cel::OptionalCompilerLibrary())); + CEL_RETURN_IF_ERROR( + builder->AddLibrary(cel::extensions::BindingsCompilerLibrary())); + + CEL_RETURN_IF_ERROR(builder->GetCheckerBuilder().AddVariable( + cel::MakeVariableDecl("x", IntType()))); + CEL_RETURN_IF_ERROR(builder->GetCheckerBuilder().AddVariable( + cel::MakeVariableDecl("y", IntType()))); + + const google::protobuf::Descriptor* descriptor = + cel::internal::GetSharedTestingDescriptorPool()->FindMessageTypeByName( + "cel.expr.conformance.proto3.TestAllTypes"); + if (descriptor == nullptr) { + return absl::InternalError("Failed to find TestAllTypes descriptor"); + } + CEL_RETURN_IF_ERROR(builder->GetCheckerBuilder().AddVariable( + cel::MakeVariableDecl("spec", cel::MessageType(descriptor)))); + + return builder->Build(); +} + +absl::StatusOr> ParsePolicyFromYaml( + absl::string_view yaml_content) { + CEL_ASSIGN_OR_RETURN(auto source, cel::NewSource(yaml_content, "test.yaml")); + + std::shared_ptr policy_source = + std::make_shared(std::move(source)); + CEL_ASSIGN_OR_RETURN(auto parse_result, + cel::ParseYamlCelPolicy(policy_source)); + + if (!parse_result.IsValid()) { + return absl::InvalidArgumentError("Invalid policy YAML structure"); + } + return parse_result.ReleasePolicy(); +} + +TEST(CompilerTest, SmokeTest) { + std::string contents; + std::string test_file = + cel::internal::ResolveRunfilesPath(kTestPolicyFilePath); + auto read_status = cel::internal::GetFileContents(test_file, &contents); + ASSERT_THAT(read_status, IsOk()); + + auto source_or = cel::NewSource(contents, "cel_policy.yaml"); + ASSERT_THAT(source_or.status(), IsOk()); + auto source = *std::move(source_or); + + std::shared_ptr policy_source = + std::make_shared(std::move(source)); + auto parse_result_or = cel::ParseYamlCelPolicy(policy_source); + ASSERT_THAT(parse_result_or.status(), IsOk()); + auto parse_result = *std::move(parse_result_or); + + ASSERT_TRUE(parse_result.IsValid()); + const CelPolicy* policy = parse_result.GetPolicy(); + + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy)); + ASSERT_TRUE(result.IsValid()); +} + +TEST(CompilerTest, VariableOutOfScopeReportsError) { + absl::string_view yaml = R"yaml( +name: cel_policy +rule: + id: test_rule + match: + - condition: variables.non_existent == 10 + output: '"error"' +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy)); + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.FormatIssues(), + testing::HasSubstr("undeclared reference")); +} + +TEST(CompilerTest, ConditionNotBoolReportsError) { + absl::string_view yaml = R"yaml( +name: cel_policy +rule: + id: test_rule + match: + - condition: 10 + output: '"error"' +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy)); + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.FormatIssues(), + testing::HasSubstr("condition must evaluate to bool")); +} + +TEST(CompilerTest, InvalidOutputExpressionReportsError) { + absl::string_view yaml = R"yaml( +name: cel_policy +rule: + id: test_rule + match: + - condition: true + output: undeclared_var +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy)); + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.FormatIssues(), + testing::HasSubstr("undeclared reference")); +} + +TEST(CompilerTest, UnreachableMatchAfterTriviallyTrueCondition) { + absl::string_view yaml = R"yaml( +name: cel_policy +rule: + id: test_rule + match: + - condition: true + output: '"first"' + - condition: true + output: '"second"' +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy)); + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.FormatIssues(), + testing::HasSubstr("match creates unreachable outputs")); +} + +TEST(CompilerTest, UnreachableMatchAfterUnconditionalExhaustiveSubRule) { + absl::string_view yaml = R"yaml( +name: dead_branch +rule: + match: + - rule: + match: + - output: 1 + - output: 2 +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy)); + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.FormatIssues(), + testing::HasSubstr("match creates unreachable outputs")); +} + +TEST(CompilerTest, RuleWithoutMatchesReportsError) { + absl::string_view yaml = R"yaml( +name: cel_policy +rule: + id: test_rule +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy)); + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.FormatIssues(), + testing::HasSubstr("rule does not specify match conditions")); +} + +TEST(CompilerTest, ExhaustivePolicyCompiles) { + absl::string_view yaml = R"yaml( +name: cel_policy +rule: + id: test_rule + variables: + - name: test_var + expression: 10 + match: + - condition: variables.test_var > 15 + output: '"greater than 15"' + - condition: variables.test_var > 5 + output: '"greater than 5"' + - output: '"default"' +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy)); + ASSERT_TRUE(result.IsValid()); + EXPECT_TRUE(result.GetAst()->is_checked()); +} + +TEST(CompilerTest, NonExhaustivePolicyCompiles) { + absl::string_view yaml = R"yaml( +name: cel_policy +rule: + id: test_rule + variables: + - name: test_var + expression: 10 + match: + - condition: variables.test_var > 5 + output: '"greater than 5"' +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy)); + ASSERT_TRUE(result.IsValid()); +} + +TEST(CompilerTest, PolicyReferencesEnvInput) { + absl::string_view yaml = R"yaml( +name: cel_policy +rule: + id: test_rule + match: + - condition: spec.single_int32 > 10 + output: '"greater than 10"' + - output: '"default"' +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy)); + ASSERT_TRUE(result.IsValid()); + EXPECT_TRUE(result.GetAst()->is_checked()); +} + +struct EvaluationTestCase { + std::string name; + std::string yaml_policy; + struct Input { + int64_t x; + int64_t y; + } input; + ValueMatcher expected_result_matcher; +}; + +class PolicyEvaluationTest : public testing::TestWithParam { +}; + +TEST_P(PolicyEvaluationTest, Evaluate) { + const auto& test_case = GetParam(); + + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(test_case.yaml_policy)); + ASSERT_OK_AND_ASSIGN(auto validation_result, + CompilePolicy(*compiler, *policy)); + ASSERT_TRUE(validation_result.IsValid()); + ASSERT_OK_AND_ASSIGN(auto ast, validation_result.ReleaseAst()); + + // Set up runtime + cel::RuntimeOptions opts; + opts.enable_qualified_type_identifiers = true; + ASSERT_OK_AND_ASSIGN( + cel::RuntimeBuilder rt_builder, + cel::CreateStandardRuntimeBuilder( + cel::internal::GetSharedTestingDescriptorPool(), opts)); + ASSERT_THAT(cel::extensions::EnableOptionalTypes(rt_builder), IsOk()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, + std::move(rt_builder).Build()); + + ASSERT_OK_AND_ASSIGN(auto program, runtime->CreateProgram(std::move(ast))); + + // Set up activation + cel::Activation activation; + activation.InsertOrAssignValue("x", cel::IntValue(test_case.input.x)); + activation.InsertOrAssignValue("y", cel::IntValue(test_case.input.y)); + + google::protobuf::Arena arena; + ASSERT_OK_AND_ASSIGN(cel::Value result, + program->Evaluate(&arena, activation)); + + EXPECT_THAT(result, test_case.expected_result_matcher); +} + +constexpr absl::string_view kEvalPolicyYaml = R"yaml( +name: cel_policy +rule: + id: test_rule + match: + - condition: x > 10 && y > 10 + output: '"both greater than 10"' + - condition: x > 10 + output: '"x greater than 10"' + - condition: y > 10 + output: '"y greater than 10"' + - output: '"default"' +)yaml"; + +INSTANTIATE_TEST_SUITE_P( + PolicyEvaluationTest, PolicyEvaluationTest, + testing::Values( + EvaluationTestCase{ + .name = "BothGreaterThan10", + .yaml_policy = std::string(kEvalPolicyYaml), + .input = {.x = 15, .y = 15}, + .expected_result_matcher = StringValueIs("both greater than 10"), + }, + EvaluationTestCase{ + .name = "XGreaterThan10", + .yaml_policy = std::string(kEvalPolicyYaml), + .input = {.x = 15, .y = 5}, + .expected_result_matcher = StringValueIs("x greater than 10"), + }, + EvaluationTestCase{ + .name = "YGreaterThan10", + .yaml_policy = std::string(kEvalPolicyYaml), + .input = {.x = 5, .y = 15}, + .expected_result_matcher = StringValueIs("y greater than 10"), + }, + EvaluationTestCase{ + .name = "Default", + .yaml_policy = std::string(kEvalPolicyYaml), + .input = {.x = 5, .y = 5}, + .expected_result_matcher = StringValueIs("default"), + }), + [](const testing::TestParamInfo& info) { + return info.param.name; + }); + +constexpr absl::string_view kNonExhaustivePolicyYaml = R"yaml( +name: nested_rule4 +rule: + match: + - condition: x > 0 + rule: + match: + - condition: x < 3 + output: 1 + - condition: x < 5 + output: 2 + - condition: x < 0 + rule: + match: + - condition: x > -2 + output: 3 + - condition: x > -4 + output: 4 + - output: 5 +)yaml"; + +INSTANTIATE_TEST_SUITE_P( + NonExhaustivePolicyEvaluation, PolicyEvaluationTest, + testing::Values( + EvaluationTestCase{ + .name = "XEquals0_FallthroughTopLevel", + .yaml_policy = std::string(kNonExhaustivePolicyYaml), + .input = {.x = 0, .y = 0}, + .expected_result_matcher = OptionalValueIsEmpty(), + }, + EvaluationTestCase{ + .name = "XEquals2_MatchesFirstNested", + .yaml_policy = std::string(kNonExhaustivePolicyYaml), + .input = {.x = 2, .y = 0}, + .expected_result_matcher = OptionalValueIs(IntValueIs(1)), + }, + EvaluationTestCase{ + .name = "XEquals6_FallthroughNested", + .yaml_policy = std::string(kNonExhaustivePolicyYaml), + .input = {.x = 6, .y = 0}, + .expected_result_matcher = OptionalValueIsEmpty(), + }, + EvaluationTestCase{ + .name = "XEqualsMinus1_MatchesMinus2", + .yaml_policy = std::string(kNonExhaustivePolicyYaml), + .input = {.x = -1, .y = 0}, + .expected_result_matcher = OptionalValueIs(IntValueIs(3)), + }, + EvaluationTestCase{ + .name = "XEqualsMinus3_MatchesMinus4", + .yaml_policy = std::string(kNonExhaustivePolicyYaml), + .input = {.x = -3, .y = 0}, + .expected_result_matcher = OptionalValueIs(IntValueIs(4)), + }, + EvaluationTestCase{ + .name = "XEqualsMinus5_MatchesDefault", + .yaml_policy = std::string(kNonExhaustivePolicyYaml), + .input = {.x = -5, .y = 0}, + .expected_result_matcher = OptionalValueIs(IntValueIs(5)), + }), + [](const testing::TestParamInfo& info) { + return info.param.name; + }); + +constexpr absl::string_view kNestedVariablePolicyYaml = R"yaml( +name: nested_rule4 +rule: + variables: + - name: i + expression: "1" + - name: j + expression: "2" + match: + - condition: x > 0 + rule: + variables: + - name: k + expression: "3" + match: + - output: "variables.i + variables.j + variables.k" + - condition: x < 0 + rule: + variables: + - name: j + expression: "5" + - name: k + expression: "4" + match: + - output: "variables.i + variables.j + variables.k" + - output: "variables.i + variables.j" +)yaml"; + +INSTANTIATE_TEST_SUITE_P( + NestedVariablePolicyEvaluation, PolicyEvaluationTest, + testing::Values( + EvaluationTestCase{ + .name = "XGreaterThan0", + .yaml_policy = std::string(kNestedVariablePolicyYaml), + .input = {.x = 1, .y = 0}, + .expected_result_matcher = IntValueIs(6), + }, + EvaluationTestCase{ + .name = "XLessThan0", + .yaml_policy = std::string(kNestedVariablePolicyYaml), + .input = {.x = -1, .y = 0}, + .expected_result_matcher = IntValueIs(10), + }, + EvaluationTestCase{ + .name = "XEquals0", + .yaml_policy = std::string(kNestedVariablePolicyYaml), + .input = {.x = 0, .y = 0}, + .expected_result_matcher = IntValueIs(3), + }), + [](const testing::TestParamInfo& info) { + return info.param.name; + }); + +constexpr absl::string_view + kOptionalChainingUnconditionalSubRuleOptionalParentYaml = R"yaml( +name: optional_chaining +rule: + match: + - rule: + id: r2 + match: + - condition: x > 0 + output: 1 + - output: 2 + condition: x < 0 +)yaml"; + +INSTANTIATE_TEST_SUITE_P( + OptionalChainingUnconditionalSubRuleOptionalParent, PolicyEvaluationTest, + testing::Values( + EvaluationTestCase{ + .name = "XEquals1", + .yaml_policy = std::string( + kOptionalChainingUnconditionalSubRuleOptionalParentYaml), + .input = {.x = 1, .y = 0}, + .expected_result_matcher = OptionalValueIs(IntValueIs(1)), + }, + EvaluationTestCase{ + .name = "XEqualsMinus1", + .yaml_policy = std::string( + kOptionalChainingUnconditionalSubRuleOptionalParentYaml), + .input = {.x = -1, .y = 0}, + .expected_result_matcher = OptionalValueIs(IntValueIs(2)), + }), + [](const testing::TestParamInfo& info) { + return info.param.name; + }); + +constexpr absl::string_view kOptionalChainingUnconditionalSubRuleYaml = R"yaml( +name: optional_chaining +rule: + id: r1 + match: + - rule: + id: r2 + match: + - condition: x > 0 + output: 1 + - output: 2 +)yaml"; + +INSTANTIATE_TEST_SUITE_P( + OptionalChainingUnconditionalSubRule, PolicyEvaluationTest, + testing::Values( + EvaluationTestCase{ + .name = "XEquals1", + .yaml_policy = + std::string(kOptionalChainingUnconditionalSubRuleYaml), + .input = {.x = 1, .y = 0}, + .expected_result_matcher = IntValueIs(1), + }, + EvaluationTestCase{ + .name = "XEqualsMinus1", + .yaml_policy = + std::string(kOptionalChainingUnconditionalSubRuleYaml), + .input = {.x = -1, .y = 0}, + .expected_result_matcher = IntValueIs(2), + }), + [](const testing::TestParamInfo& info) { + return info.param.name; + }); + +constexpr absl::string_view kOptionalChainingUnconditionalComplexYaml = R"yaml( +name: optional_chaining +rule: + match: + - condition: x > 0 + rule: + match: + - rule: + match: + - condition: x == 1 + output: 1 + - output: 2 + - rule: + match: + - condition: x == -1 + output: 3 + - condition: x == -2 + output: 4 +)yaml"; + +INSTANTIATE_TEST_SUITE_P( + OptionalChainingUnconditionalComplex, PolicyEvaluationTest, + testing::Values( + EvaluationTestCase{ + .name = "XEquals1", + .yaml_policy = + std::string(kOptionalChainingUnconditionalComplexYaml), + .input = {.x = 1, .y = 0}, + .expected_result_matcher = OptionalValueIs(IntValueIs(1)), + }, + EvaluationTestCase{ + .name = "XEqualsMinus1", + .yaml_policy = + std::string(kOptionalChainingUnconditionalComplexYaml), + .input = {.x = -1, .y = 0}, + .expected_result_matcher = OptionalValueIs(IntValueIs(3)), + }, + EvaluationTestCase{ + .name = "XEqualsMinus2", + .yaml_policy = + std::string(kOptionalChainingUnconditionalComplexYaml), + .input = {.x = -2, .y = 0}, + .expected_result_matcher = OptionalValueIs(IntValueIs(4)), + }, + EvaluationTestCase{ + .name = "XEqualsMinus3", + .yaml_policy = + std::string(kOptionalChainingUnconditionalComplexYaml), + .input = {.x = -3, .y = 0}, + .expected_result_matcher = OptionalValueIsEmpty(), + }), + [](const testing::TestParamInfo& info) { + return info.param.name; + }); + +constexpr absl::string_view kUnconditionalExhaustiveSubRuleAsLastMatchYaml = + R"yaml( +name: exhaustive_unconditional_subrule +rule: + match: + - condition: x > 0 + output: 1 + - rule: + match: + - condition: y > 0 + output: 2 + - output: 3 +)yaml"; + +INSTANTIATE_TEST_SUITE_P( + UnconditionalExhaustiveSubRuleAsLastMatch, PolicyEvaluationTest, + testing::Values( + EvaluationTestCase{ + .name = "XEquals1", + .yaml_policy = + std::string(kUnconditionalExhaustiveSubRuleAsLastMatchYaml), + .input = {.x = 1, .y = 0}, + .expected_result_matcher = IntValueIs(1), + }, + EvaluationTestCase{ + .name = "XEquals0_YEquals1", + .yaml_policy = + std::string(kUnconditionalExhaustiveSubRuleAsLastMatchYaml), + .input = {.x = 0, .y = 1}, + .expected_result_matcher = IntValueIs(2), + }, + EvaluationTestCase{ + .name = "XEquals0_YEquals0", + .yaml_policy = + std::string(kUnconditionalExhaustiveSubRuleAsLastMatchYaml), + .input = {.x = 0, .y = 0}, + .expected_result_matcher = IntValueIs(3), + }), + [](const testing::TestParamInfo& info) { + return info.param.name; + }); + +constexpr absl::string_view kUnconditionalNonExhaustiveSubRuleAsLastMatchYaml = + R"yaml( +name: non_exhaustive_unconditional_subrule +rule: + match: + - condition: x > 0 + output: 1 + - rule: + match: + - condition: y > 0 + output: 2 +)yaml"; + +INSTANTIATE_TEST_SUITE_P( + UnconditionalNonExhaustiveSubRuleAsLastMatch, PolicyEvaluationTest, + testing::Values( + EvaluationTestCase{ + .name = "XEquals1", + .yaml_policy = + std::string(kUnconditionalNonExhaustiveSubRuleAsLastMatchYaml), + .input = {.x = 1, .y = 0}, + .expected_result_matcher = OptionalValueIs(IntValueIs(1)), + }, + EvaluationTestCase{ + .name = "XEquals0_YEquals1", + .yaml_policy = + std::string(kUnconditionalNonExhaustiveSubRuleAsLastMatchYaml), + .input = {.x = 0, .y = 1}, + .expected_result_matcher = OptionalValueIs(IntValueIs(2)), + }, + EvaluationTestCase{ + .name = "XEquals0_YEquals0", + .yaml_policy = + std::string(kUnconditionalNonExhaustiveSubRuleAsLastMatchYaml), + .input = {.x = 0, .y = 0}, + .expected_result_matcher = OptionalValueIsEmpty(), + }), + [](const testing::TestParamInfo& info) { + return info.param.name; + }); + +TEST(CompilerTest, ImportsAndAbbreviations) { + absl::string_view yaml = R"yaml( +name: imports_test +imports: + - name: cel.expr.conformance.proto3.TestAllTypes +rule: + match: + - condition: 'spec == TestAllTypes{single_int32: 10}' + output: '"matched"' + - output: '"default"' +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + auto ast_or = CompilePolicy(*compiler, *policy); + ASSERT_THAT(ast_or, IsOk()); +} + +TEST(CompilerTest, MatchWithoutProductionReportsError) { + absl::string_view yaml = R"yaml( +name: cel_policy +rule: + id: test_rule + match: + - condition: true +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy)); + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.FormatIssues(), + testing::HasSubstr("match must specify an output or rule")); +} + +int GetAstHeight(const cel::Ast& ast) { + auto nav_ast = cel::NavigableAst::Build(ast.root_expr()); + return nav_ast.Root().height(); +} + +TEST(CompilerTest, UnnestHeightValidation) { + absl::string_view yaml = R"yaml( +name: cel_policy +rule: + id: test_rule + match: + - condition: true + output: '"ok"' +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + + CompilePolicyOptions options; + options.unnesting_height_limit = 1; + auto status_or = CompilePolicy(*compiler, *policy, options); + EXPECT_THAT(status_or.status(), + StatusIs(absl::StatusCode::kInvalidArgument, + testing::HasSubstr( + "unnesting_height_limit must be at least 2"))); + + options.unnesting_height_limit = 2; + EXPECT_THAT(CompilePolicy(*compiler, *policy, options), IsOk()); +} + +constexpr absl::string_view kDeepPolicyYaml = R"yaml( +name: deep_policy +rule: + match: + - condition: x > 0 + rule: + match: + - condition: x > 1 + rule: + match: + - condition: x > 2 + rule: + match: + - condition: x > 3 + rule: + match: + - condition: x > 4 + rule: + match: + - condition: x > 5 + output: 6 + - output: 5 + - output: 4 + - output: 3 + - output: 2 + - output: 1 + - output: 0 +)yaml"; + +TEST(CompilerTest, UnnestHeightReduction) { + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(kDeepPolicyYaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + + // Compile without unnesting + CompilePolicyOptions options_no_unnest; + options_no_unnest.unnesting_height_limit = 0; + ASSERT_OK_AND_ASSIGN(auto result_no_unnest, + CompilePolicy(*compiler, *policy, options_no_unnest)); + ASSERT_TRUE(result_no_unnest.IsValid()); + ASSERT_OK_AND_ASSIGN(auto ast_no_unnest, result_no_unnest.ReleaseAst()); + int height_no_unnest = GetAstHeight(*ast_no_unnest); + + CompilePolicyOptions options_unnest; + options_unnest.unnesting_height_limit = 2; + ASSERT_OK_AND_ASSIGN(auto result_unnest, + CompilePolicy(*compiler, *policy, options_unnest)); + ASSERT_TRUE(result_unnest.IsValid()); + ASSERT_OK_AND_ASSIGN(auto ast_unnest, result_unnest.ReleaseAst()); + int height_unnest = GetAstHeight(*ast_unnest); + + EXPECT_EQ(height_no_unnest, 8); + EXPECT_EQ(height_unnest, 5); + EXPECT_LT(height_unnest, height_no_unnest); +} + +TEST(CompilerTest, UnnestComprehensionFailure) { + absl::string_view yaml = R"yaml( +name: comprehension_policy +rule: + match: + - condition: x > 0 + rule: + match: + - condition: "[1, 2].all(i, i > x)" + output: 1 + - output: 2 + - output: 0 +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + + CompilePolicyOptions options; + options.unnesting_height_limit = 2; + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy, options)); + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.FormatIssues(), + testing::HasSubstr("cannot unnest AST due to comprehension")); +} + +struct UnnestEvaluationTestCase { + std::string name; + int64_t x; + ValueMatcher expected; +}; + +class UnnestedDeepPolicyEvaluationTest + : public testing::TestWithParam {}; + +TEST_P(UnnestedDeepPolicyEvaluationTest, Evaluate) { + const auto& tc = GetParam(); + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(kDeepPolicyYaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + + CompilePolicyOptions options; + options.unnesting_height_limit = 2; + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy, options)); + ASSERT_TRUE(result.IsValid()); + ASSERT_OK_AND_ASSIGN(auto ast, result.ReleaseAst()); + + // Set up runtime + cel::RuntimeOptions opts; + opts.enable_qualified_type_identifiers = true; + ASSERT_OK_AND_ASSIGN( + cel::RuntimeBuilder rt_builder, + cel::CreateStandardRuntimeBuilder( + cel::internal::GetSharedTestingDescriptorPool(), opts)); + ASSERT_THAT(cel::extensions::EnableOptionalTypes(rt_builder), IsOk()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, + std::move(rt_builder).Build()); + + ASSERT_OK_AND_ASSIGN(auto program, runtime->CreateProgram(std::move(ast))); + + cel::Activation activation; + activation.InsertOrAssignValue("x", cel::IntValue(tc.x)); + google::protobuf::Arena arena; + + ASSERT_OK_AND_ASSIGN(cel::Value res, program->Evaluate(&arena, activation)); + + EXPECT_THAT(res, tc.expected); +} + +INSTANTIATE_TEST_SUITE_P( + UnnestedDeepPolicyEvaluation, UnnestedDeepPolicyEvaluationTest, + testing::Values(UnnestEvaluationTestCase{"XEquals6", 6, IntValueIs(6)}, + UnnestEvaluationTestCase{"XEquals5", 5, IntValueIs(5)}, + UnnestEvaluationTestCase{"XEquals4", 4, IntValueIs(4)}, + UnnestEvaluationTestCase{"XEquals3", 3, IntValueIs(3)}, + UnnestEvaluationTestCase{"XEquals2", 2, IntValueIs(2)}, + UnnestEvaluationTestCase{"XEquals1", 1, IntValueIs(1)}, + UnnestEvaluationTestCase{"XEquals0", 0, IntValueIs(0)}, + UnnestEvaluationTestCase{"XEqualsMinus1", -1, + IntValueIs(0)}), + [](const testing::TestParamInfo< + UnnestedDeepPolicyEvaluationTest::ParamType>& info) { + return info.param.name; + }); + +TEST(CompilerTest, UnnestCleanupRunsWhenDisabled) { + // A policy without variables and without nesting. + absl::string_view yaml = R"yaml( +name: cel_policy +rule: + id: test_rule + match: + - condition: true + output: '"ok"' +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + + CompilePolicyOptions options; + options.unnesting_height_limit = 0; // Disabled + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy, options)); + ASSERT_TRUE(result.IsValid()); + ASSERT_OK_AND_ASSIGN(auto ast, result.ReleaseAst()); + + // If cleanup ran, it should have optimized away the trivial `cel.@block`. + // So the root expression should NOT be a call to `cel.@block`. + // It should be just the constant `"ok"`. + auto nav_ast = cel::NavigableAst::Build(ast->root_expr()); + EXPECT_FALSE(nav_ast.Root().expr()->has_call_expr() && + nav_ast.Root().expr()->call_expr().function() == "cel.@block"); + EXPECT_TRUE(nav_ast.Root().expr()->has_const_expr()); +} +} // namespace +} // namespace cel diff --git a/policy/internal/BUILD b/policy/internal/BUILD new file mode 100644 index 000000000..30f43d431 --- /dev/null +++ b/policy/internal/BUILD @@ -0,0 +1,68 @@ +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "issue_reporter", + srcs = ["issue_reporter.cc"], + hdrs = ["issue_reporter.h"], + deps = [ + "//common:source", + "//policy:cel_policy", + "//policy:cel_policy_parser", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "optimizer_expr_factory", + srcs = ["optimizer_expr_factory.cc"], + hdrs = ["optimizer_expr_factory.h"], + deps = [ + "//common:ast", + "//common:ast_rewrite", + "//common:ast_traverse", + "//common:ast_visitor_base", + "//common:constant", + "//common:expr", + "//common:expr_factory", + "//common:source", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:optional", + ], +) + +cc_test( + name = "optimizer_expr_factory_test", + srcs = ["optimizer_expr_factory_test.cc"], + deps = [ + ":optimizer_expr_factory", + "//common:ast", + "//common:ast_proto", + "//common:ast_rewrite", + "//common:decl", + "//common:expr", + "//common:expr_factory", + "//common:source", + "//common:type", + "//compiler", + "//compiler:compiler_factory", + "//compiler:standard_library", + "//internal:status_macros", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//testutil:expr_printer", + "//tools:cel_unparser", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + ], +) diff --git a/policy/internal/issue_reporter.cc b/policy/internal/issue_reporter.cc new file mode 100644 index 000000000..944e687d6 --- /dev/null +++ b/policy/internal/issue_reporter.cc @@ -0,0 +1,45 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "policy/internal/issue_reporter.h" + +#include "absl/strings/string_view.h" +#include "common/source.h" +#include "policy/cel_policy.h" + +namespace cel::policy_internal { + +void IssueReporter::ReportIssue(CelPolicyElementId element, Severity severity, + absl::string_view message) { + issues_.push_back({element, severity, message}); +} + +void IssueReporter::ReportOffsetIssue(CelPolicyElementId element, + cel::SourcePosition relative_position, + Severity severity, + absl::string_view message) { + issues_.push_back({element, relative_position, severity, message}); +} + +void IssueReporter::ReportError(CelPolicyElementId element, + absl::string_view message) { + ReportIssue(element, Severity::kError, message); +} + +void IssueReporter::ReportError(CelPolicyElementId element, SourcePosition pos, + absl::string_view message) { + ReportOffsetIssue(element, pos, Severity::kError, message); +} + +} // namespace cel::policy_internal diff --git a/policy/internal/issue_reporter.h b/policy/internal/issue_reporter.h new file mode 100644 index 000000000..3f88806ef --- /dev/null +++ b/policy/internal/issue_reporter.h @@ -0,0 +1,57 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_POLICY_INTERNAL_ISSUE_REPORTER_H_ +#define THIRD_PARTY_CEL_CPP_POLICY_INTERNAL_ISSUE_REPORTER_H_ + +#include + +#include "absl/strings/string_view.h" +#include "common/source.h" +#include "policy/cel_policy.h" +#include "policy/cel_policy_parse_result.h" + +namespace cel::policy_internal { + +class IssueReporter { + private: + using Severity = CelPolicyIssue::Severity; + + public: + void ReportIssue(CelPolicyElementId element, Severity severity, + absl::string_view message); + + void ReportOffsetIssue(CelPolicyElementId element, + cel::SourcePosition relative_position, + Severity severity, absl::string_view message); + + void ReportError(CelPolicyElementId element, absl::string_view message); + void ReportError(CelPolicyElementId element, SourcePosition relative_pos, + absl::string_view message); + + std::vector ReleaseIssues() { + using std::swap; + std::vector out; + swap(out, issues_); + return out; + } + const std::vector& issues() const { return issues_; } + + private: + std::vector issues_; +}; + +} // namespace cel::policy_internal + +#endif // THIRD_PARTY_CEL_CPP_POLICY_INTERNAL_ISSUE_REPORTER_H_ diff --git a/policy/internal/optimizer_expr_factory.cc b/policy/internal/optimizer_expr_factory.cc new file mode 100644 index 000000000..6c89ae958 --- /dev/null +++ b/policy/internal/optimizer_expr_factory.cc @@ -0,0 +1,373 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "policy/internal/optimizer_expr_factory.h" + +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/functional/any_invocable.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "common/ast.h" +#include "common/ast_rewrite.h" +#include "common/ast_traverse.h" +#include "common/ast_visitor_base.h" +#include "common/constant.h" +#include "common/expr.h" +#include "common/expr_factory.h" +#include "common/source.h" + +namespace cel { + +namespace { + +class MaxIdVisitor final : public AstVisitorBase { + public: + ExprId max_id() const { return max_id_; } + + void PreVisitExpr(const Expr& expr) override { + max_id_ = std::max(max_id_, expr.id()); + } + + void PostVisitExpr(const Expr&) override {} + + void PostVisitStruct(const Expr&, const StructExpr& struct_expr) override { + for (const auto& field : struct_expr.fields()) { + max_id_ = std::max(max_id_, field.id()); + } + } + + void PostVisitMap(const Expr&, const MapExpr& map_expr) override { + for (const auto& entry : map_expr.entries()) { + max_id_ = std::max(max_id_, entry.id()); + } + } + + private: + ExprId max_id_ = 0; +}; + +ExprId GetMaxId(const Expr& expr) { + MaxIdVisitor visitor; + AstTraverse(expr, visitor); + return visitor.max_id(); +} + +ExprId GetMaxId(const Ast& ast) { + ExprId max_id = GetMaxId(ast.root_expr()); + for (const auto& [id, _] : ast.source_info().positions()) { + max_id = std::max(max_id, id); + } + for (const auto& [id, expr] : ast.source_info().macro_calls()) { + max_id = std::max(max_id, id); + max_id = std::max(max_id, GetMaxId(expr)); + } + return max_id; +} + +// Replaces nested macros in a macro_calls expr with reference nodes. +// +// The macro_calls map is used for retaining the original structure of the +// parsed expression before macro expansion. When a macro appears inside another +// macro, the parser will replace the inner macro expr node with an unspecified +// expr with the inner macro's ID in the macro_calls map to save space. +class MakeMacroCallRewrite final : public AstRewriterBase { + public: + explicit MakeMacroCallRewrite(const SourceInfo& source_info) + : source_info_(source_info) {} + + bool PreVisitRewrite(Expr& expr) override { + if (source_info_.macro_calls().find(expr.id()) != + source_info_.macro_calls().end()) { + ExprId id = expr.id(); + expr.mutable_kind() = UnspecifiedExpr(); + expr.set_id(id); + return true; + } + return false; + } + + private: + const SourceInfo& source_info_; +}; + +// Updates macro_calls map entries to reflect a replaced expression in the +// main AST. +class ReplaceMacroCallRewrite final : public AstRewriterBase { + public: + ReplaceMacroCallRewrite(ExprId old_id, const Expr& replacement, + const SourceInfo& source_info) + : old_id_(old_id), replacement_(replacement), source_info_(source_info) {} + + bool PreVisitRewrite(Expr& expr) override { + if (expr.id() == old_id_) { + expr = macro_replacement(); + return true; + } + return false; + } + + Expr macro_replacement() { + if (!macro_replacement_) { + macro_replacement_.emplace(replacement_); + MakeMacroCallRewrite hole_creator(source_info_); + AstRewrite(*macro_replacement_, hole_creator); + } + return *macro_replacement_; + } + + private: + ExprId old_id_; + const Expr& replacement_; + absl::optional macro_replacement_; + const SourceInfo& source_info_; +}; + +void ReplaceSubExpr(Expr& expr, ExprId old_id, const Expr& replacement, + const SourceInfo& source_info) { + ReplaceMacroCallRewrite rewriter(old_id, replacement, source_info); + AstRewrite(expr, rewriter); +} + +class IdRewriter : public AstRewriterBase { + using CopyIdFn = absl::AnyInvocable; + + public: + explicit IdRewriter(CopyIdFn copy_id) : copy_id_(std::move(copy_id)) {} + + // No structure changes just ids. + bool PreVisitRewrite(Expr& expr) override { + expr.set_id(copy_id_(expr.id())); + if (expr.has_struct_expr()) { + for (auto& field : expr.mutable_struct_expr().mutable_fields()) { + field.set_id(copy_id_(field.id())); + } + } else if (expr.has_map_expr()) { + for (auto& entry : expr.mutable_map_expr().mutable_entries()) { + entry.set_id(copy_id_(entry.id())); + } + } + return false; + } + + private: + CopyIdFn copy_id_; +}; + +} // namespace + +OptimizerExprFactory::OptimizerExprFactory(Ast basis) + : ast_(std::move(basis)), next_id_(GetMaxId(ast_) + 1) {} + +OptimizerExprFactory::OptimizerExprFactory() : next_id_(1) {} + +Expr OptimizerExprFactory::Copy(const Expr& expr) { + Expr copied = expr; + IdRewriter rewriter([this](ExprId id) { return CopyId(id); }); + AstRewrite(copied, rewriter); + return copied; +} + +ListExprElement OptimizerExprFactory::Copy(const ListExprElement& element) { + return NewListElement(Copy(element.expr()), element.optional()); +} + +StructExprField OptimizerExprFactory::Copy(const StructExprField& field) { + auto field_id = CopyId(field.id()); + auto field_value = Copy(field.value()); + return NewStructField(field_id, field.name(), std::move(field_value), + field.optional()); +} + +MapExprEntry OptimizerExprFactory::Copy(const MapExprEntry& entry) { + auto entry_id = CopyId(entry.id()); + auto entry_key = Copy(entry.key()); + auto entry_value = Copy(entry.value()); + return NewMapEntry(entry_id, std::move(entry_key), std::move(entry_value), + entry.optional()); +} + +ExprId OptimizerExprFactory::NextId() { return next_id_++; } + +ExprId OptimizerExprFactory::CopyId(ExprId id) { + if (id == 0) { + return 0; + } + auto it = renumbers_.find(id); + if (it != renumbers_.end()) { + return it->second; + } + ExprId new_id = NextId(); + renumbers_[id] = new_id; + return new_id; +} + +SourceInfo OptimizerExprFactory::RemapSourceInfo(const SourceInfo& info, + SourcePosition offset) { + SourceInfo out; + + for (const auto& [old_id, macro_expr] : info.macro_calls()) { + if (auto it = renumbers_.find(old_id); it != renumbers_.end()) { + ExprId new_id = it->second; + out.mutable_macro_calls()[new_id] = Copy(macro_expr); + } + } + + for (const auto& [old_id, new_id] : renumbers_) { + if (auto it = info.positions().find(old_id); it != info.positions().end()) { + out.mutable_positions()[new_id] = it->second + offset; + } + } + + return out; +} + +void OptimizerExprFactory::MergeSourceInfo(const SourceInfo& info) { + auto& target_info = ast_.mutable_source_info(); + + for (const auto& [id, pos] : info.positions()) { + auto [it, inserted] = target_info.mutable_positions().insert({id, pos}); + if (!inserted) { + issues_.push_back(Issue{id, "conflicting ID in positions merge"}); + } + } + + for (const auto& [id, expr] : info.macro_calls()) { + auto [it, inserted] = target_info.mutable_macro_calls().insert({id, expr}); + if (!inserted) { + issues_.push_back(Issue{id, "conflicting ID in macro calls merge"}); + } + } + + // TODO(b/506179116): need to add some check that we aren't + // introducing incompatible tags. Not possible in the policy compiler right + // now. + for (const auto& ext : info.extensions()) { + auto& target_exts = target_info.mutable_extensions(); + if (!absl::c_linear_search(target_exts, ext)) { + target_exts.push_back(ext); + } + } +} + +void OptimizerExprFactory::RecordReplacement(ExprId id, const Expr& replacement, + bool keep_metadata) { + auto& source_info = ast_.mutable_source_info(); + if (!keep_metadata) { + source_info.mutable_positions().erase(id); + source_info.mutable_macro_calls().erase(id); + } + + for (auto& [macro_id, macro_expr] : source_info.mutable_macro_calls()) { + ReplaceSubExpr(macro_expr, id, replacement, source_info); + } +} + +Expr OptimizerExprFactory::ReportError(absl::string_view message) { + ExprId id = NextId(); + issues_.push_back(Issue{id, std::string(message)}); + return NewUnspecified(id); +} + +Expr OptimizerExprFactory::ReportErrorAt(const Expr& expr, + absl::string_view message) { + issues_.push_back(Issue{expr.id(), std::string(message)}); + return NewUnspecified(NextId()); +} + +Expr OptimizerExprFactory::ReportErrorAtCopy(const Expr& expr, + absl::string_view message) { + issues_.push_back(Issue{CopyId(expr.id()), std::string(message)}); + return NewUnspecified(NextId()); +} + +Expr OptimizerExprFactory::NewUnspecified() { return NewUnspecified(NextId()); } + +Expr OptimizerExprFactory::NewNullConst() { return NewNullConst(NextId()); } + +Expr OptimizerExprFactory::NewBoolConst(bool value) { + return NewBoolConst(NextId(), value); +} + +Expr OptimizerExprFactory::NewIntConst(int64_t value) { + return NewIntConst(NextId(), value); +} + +Expr OptimizerExprFactory::NewUintConst(uint64_t value) { + return NewUintConst(NextId(), value); +} + +Expr OptimizerExprFactory::NewDoubleConst(double value) { + return NewDoubleConst(NextId(), value); +} + +Expr OptimizerExprFactory::NewBytesConst(std::string value) { + return NewBytesConst(NextId(), std::move(value)); +} + +Expr OptimizerExprFactory::NewBytesConst(absl::string_view value) { + return NewBytesConst(NextId(), value); +} + +Expr OptimizerExprFactory::NewBytesConst(const char* value) { + return NewBytesConst(NextId(), value); +} + +Expr OptimizerExprFactory::NewStringConst(std::string value) { + return NewStringConst(NextId(), std::move(value)); +} + +Expr OptimizerExprFactory::NewStringConst(absl::string_view value) { + return NewStringConst(NextId(), value); +} + +Expr OptimizerExprFactory::NewStringConst(const char* value) { + return NewStringConst(NextId(), value); +} + +absl::flat_hash_map OptimizerExprFactory::ConsumeRenumbers() { + using std::swap; + absl::flat_hash_map out; + swap(out, renumbers_); + return out; +} + +void OptimizerExprFactory::StartCopyContext() { renumbers_.clear(); } + +const std::vector& OptimizerExprFactory::issues() + const { + return issues_; +} + +const Ast& OptimizerExprFactory::ast() const { return ast_; } + +Ast& OptimizerExprFactory::mutable_ast() { return ast_; } + +absl::string_view OptimizerExprFactory::AccuVarName() { + return ExprFactory::AccuVarName(); +} + +Expr OptimizerExprFactory::NewAccuIdent() { return NewAccuIdent(NextId()); } + +ExprId OptimizerExprFactory::CopyId(const Expr& expr) { + return CopyId(expr.id()); +} + +} // namespace cel diff --git a/policy/internal/optimizer_expr_factory.h b/policy/internal/optimizer_expr_factory.h new file mode 100644 index 000000000..6f63f1485 --- /dev/null +++ b/policy/internal/optimizer_expr_factory.h @@ -0,0 +1,419 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_POLICY_INTERNAL_OPTIMIZER_EXPR_FACTORY_H_ +#define THIRD_PARTY_CEL_CPP_POLICY_INTERNAL_OPTIMIZER_EXPR_FACTORY_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "common/ast.h" +#include "common/expr.h" +#include "common/expr_factory.h" +#include "common/source.h" + +namespace cel { + +class ParserMacroExprFactory; +class TestOptimizerExprFactory; + +// `OptimizerExprFactory` is a specialization of `ExprFactory` used for AST +// optimization. It provides utilities for correcting metadata for modified +// ASTs. +class OptimizerExprFactory : protected ExprFactory { + public: + struct Issue { + ExprId location = 0; + std::string message; + }; + + explicit OptimizerExprFactory(Ast basis); + OptimizerExprFactory(); + + protected: + using ExprFactory::IsArrayLike; + using ExprFactory::IsExprLike; + using ExprFactory::IsStringLike; + + template + struct IsRValue + : std::bool_constant< + std::disjunction_v, std::is_same>> {}; + + public: + // Consume the current set of renumberings. + absl::flat_hash_map ConsumeRenumbers(); + + // Starts a new copy context. The current set of renumberings are cleared. + void StartCopyContext(); + + const std::vector& issues() const; + + // Record that a node in the working AST was replaced. This is used to correct + // metadata referencing the old ID. + void RecordReplacement(ExprId id, const Expr& replacement, + bool keep_metadata = false); + + // Makes a copy of source metadata that is remapped to new expr Ids using + // current renumberings. This is suitable for merging into the main source + // info. + SourceInfo RemapSourceInfo(const SourceInfo& info, SourcePosition offset = 0); + + // Merge a remapped SourceInfo into the current one. + void MergeSourceInfo(const SourceInfo& info); + + const Ast& ast() const; + Ast& mutable_ast(); + + absl::string_view AccuVarName(); + + ABSL_MUST_USE_RESULT Expr Copy(const Expr& expr); + + ABSL_MUST_USE_RESULT ListExprElement Copy(const ListExprElement& element); + + ABSL_MUST_USE_RESULT StructExprField Copy(const StructExprField& field); + + ABSL_MUST_USE_RESULT MapExprEntry Copy(const MapExprEntry& entry); + + ABSL_MUST_USE_RESULT Expr NewUnspecified(); + + ABSL_MUST_USE_RESULT Expr NewNullConst(); + + ABSL_MUST_USE_RESULT Expr NewBoolConst(bool value); + + ABSL_MUST_USE_RESULT Expr NewIntConst(int64_t value); + + ABSL_MUST_USE_RESULT Expr NewUintConst(uint64_t value); + + ABSL_MUST_USE_RESULT Expr NewDoubleConst(double value); + + ABSL_MUST_USE_RESULT Expr NewBytesConst(std::string value); + + ABSL_MUST_USE_RESULT Expr NewBytesConst(absl::string_view value); + + ABSL_MUST_USE_RESULT Expr NewBytesConst(const char* absl_nullable value); + + ABSL_MUST_USE_RESULT Expr NewStringConst(std::string value); + + ABSL_MUST_USE_RESULT Expr NewStringConst(absl::string_view value); + + ABSL_MUST_USE_RESULT Expr NewStringConst(const char* absl_nullable value); + + template ::value>> + ABSL_MUST_USE_RESULT Expr NewIdent(Name name); + + ABSL_MUST_USE_RESULT Expr NewAccuIdent(); + + template ::value>, + typename = std::enable_if_t::value>> + ABSL_MUST_USE_RESULT Expr NewSelect(Operand operand, Field field); + + template ::value>, + typename = std::enable_if_t::value>> + ABSL_MUST_USE_RESULT Expr NewPresenceTest(Operand operand, Field field); + + template < + typename Function, typename... Args, + typename = std::enable_if_t::value>, + typename = std::enable_if_t...>>> + ABSL_MUST_USE_RESULT Expr NewCall(Function function, Args&&... args); + + template ::value>, + typename = std::enable_if_t::value>> + ABSL_MUST_USE_RESULT Expr NewCall(Function function, Args args); + + template < + typename Function, typename Target, typename... Args, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t...>>> + ABSL_MUST_USE_RESULT Expr NewMemberCall(Function function, Target target, + Args&&... args); + + template ::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>> + ABSL_MUST_USE_RESULT Expr NewMemberCall(Function function, Target target, + Args args); + + using ExprFactory::NewListElement; + + template ...>>> + ABSL_MUST_USE_RESULT Expr NewList(Elements&&... elements); + + template ::value>> + ABSL_MUST_USE_RESULT Expr NewList(Elements elements); + + template ::value>, + typename = std::enable_if_t::value>> + ABSL_MUST_USE_RESULT StructExprField NewStructField(Name name, Value value, + bool optional = false); + + template ::value>, + typename = std::enable_if_t< + std::conjunction_v...>>> + ABSL_MUST_USE_RESULT Expr NewStruct(Name name, Fields&&... fields); + + template < + typename Name, typename Fields, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>> + ABSL_MUST_USE_RESULT Expr NewStruct(Name name, Fields fields); + + template ::value>, + typename = std::enable_if_t::value>> + ABSL_MUST_USE_RESULT MapExprEntry NewMapEntry(Key key, Value value, + bool optional = false); + + template ...>>> + ABSL_MUST_USE_RESULT Expr NewMap(Entries&&... entries); + + template ::value>> + ABSL_MUST_USE_RESULT Expr NewMap(Entries entries); + + template ::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>> + ABSL_MUST_USE_RESULT Expr NewComprehension(IterVar iter_var, + IterRange iter_range, + AccuVar accu_var, + AccuInit accu_init, + LoopCondition loop_condition, + LoopStep loop_step, Result result); + + template ::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>> + ABSL_MUST_USE_RESULT Expr NewComprehension( + IterVar iter_var, IterVar2 iter_var2, IterRange iter_range, + AccuVar accu_var, AccuInit accu_init, LoopCondition loop_condition, + LoopStep loop_step, Result result); + + ABSL_MUST_USE_RESULT Expr ReportError(absl::string_view message); + + // Reports an error at the id in the optimized AST. + ABSL_MUST_USE_RESULT Expr ReportErrorAt(const Expr& expr, + absl::string_view message); + // Reports an error at the mapped id of the copy of expr in the optimized AST. + ABSL_MUST_USE_RESULT Expr ReportErrorAtCopy(const Expr& expr, + absl::string_view message); + + protected: + ABSL_MUST_USE_RESULT ExprId NextId(); + + ABSL_MUST_USE_RESULT ExprId CopyId(ExprId id); + + ABSL_MUST_USE_RESULT ExprId CopyId(const Expr& expr); + + using ExprFactory::AccuVarName; + using ExprFactory::NewAccuIdent; + using ExprFactory::NewBoolConst; + using ExprFactory::NewBytesConst; + using ExprFactory::NewCall; + using ExprFactory::NewComprehension; + using ExprFactory::NewConst; + using ExprFactory::NewDoubleConst; + using ExprFactory::NewIdent; + using ExprFactory::NewIntConst; + using ExprFactory::NewList; + using ExprFactory::NewMap; + using ExprFactory::NewMapEntry; + using ExprFactory::NewMemberCall; + using ExprFactory::NewNullConst; + using ExprFactory::NewPresenceTest; + using ExprFactory::NewSelect; + using ExprFactory::NewStringConst; + using ExprFactory::NewStruct; + using ExprFactory::NewStructField; + using ExprFactory::NewUintConst; + using ExprFactory::NewUnspecified; + + private: + Ast ast_; + absl::flat_hash_map renumbers_; + std::vector issues_; + + ExprId next_id_ = 1; +}; + +// Implementation details. + +template +Expr OptimizerExprFactory::NewIdent(Name name) { + return NewIdent(NextId(), std::move(name)); +} + +template +Expr OptimizerExprFactory::NewSelect(Operand operand, Field field) { + return NewSelect(NextId(), std::move(operand), std::move(field)); +} + +template +Expr OptimizerExprFactory::NewPresenceTest(Operand operand, Field field) { + return NewPresenceTest(NextId(), std::move(operand), std::move(field)); +} + +template +Expr OptimizerExprFactory::NewCall(Function function, Args&&... args) { + std::vector array; + array.reserve(sizeof...(Args)); + (array.push_back(std::forward(args)), ...); + return NewCall(NextId(), std::move(function), std::move(array)); +} + +template +Expr OptimizerExprFactory::NewCall(Function function, Args args) { + return NewCall(NextId(), std::move(function), std::move(args)); +} + +template +Expr OptimizerExprFactory::NewMemberCall(Function function, Target target, + Args&&... args) { + std::vector array; + array.reserve(sizeof...(Args)); + (array.push_back(std::forward(args)), ...); + return NewMemberCall(NextId(), std::move(function), std::move(target), + std::move(array)); +} + +template +Expr OptimizerExprFactory::NewMemberCall(Function function, Target target, + Args args) { + return NewMemberCall(NextId(), std::move(function), std::move(target), + std::move(args)); +} + +template +Expr OptimizerExprFactory::NewList(Elements&&... elements) { + std::vector array; + array.reserve(sizeof...(Elements)); + (array.push_back(std::forward(elements)), ...); + return NewList(NextId(), std::move(array)); +} + +template +Expr OptimizerExprFactory::NewList(Elements elements) { + return NewList(NextId(), std::move(elements)); +} + +template +StructExprField OptimizerExprFactory::NewStructField(Name name, Value value, + bool optional) { + return NewStructField(NextId(), std::move(name), std::move(value), optional); +} + +template +Expr OptimizerExprFactory::NewStruct(Name name, Fields&&... fields) { + std::vector array; + array.reserve(sizeof...(Fields)); + (array.push_back(std::forward(fields)), ...); + return NewStruct(NextId(), std::move(name), std::move(array)); +} + +template +Expr OptimizerExprFactory::NewStruct(Name name, Fields fields) { + return NewStruct(NextId(), std::move(name), std::move(fields)); +} + +template +MapExprEntry OptimizerExprFactory::NewMapEntry(Key key, Value value, + bool optional) { + return NewMapEntry(NextId(), std::move(key), std::move(value), optional); +} + +template +Expr OptimizerExprFactory::NewMap(Entries&&... entries) { + std::vector array; + array.reserve(sizeof...(Entries)); + (array.push_back(std::forward(entries)), ...); + return NewMap(NextId(), std::move(array)); +} + +template +Expr OptimizerExprFactory::NewMap(Entries entries) { + return NewMap(NextId(), std::move(entries)); +} + +template +Expr OptimizerExprFactory::NewComprehension(IterVar iter_var, + IterRange iter_range, + AccuVar accu_var, + AccuInit accu_init, + LoopCondition loop_condition, + LoopStep loop_step, Result result) { + return NewComprehension(NextId(), std::move(iter_var), std::move(iter_range), + std::move(accu_var), std::move(accu_init), + std::move(loop_condition), std::move(loop_step), + std::move(result)); +} + +template +Expr OptimizerExprFactory::NewComprehension( + IterVar iter_var, IterVar2 iter_var2, IterRange iter_range, + AccuVar accu_var, AccuInit accu_init, LoopCondition loop_condition, + LoopStep loop_step, Result result) { + return NewComprehension(NextId(), std::move(iter_var), std::move(iter_var2), + std::move(iter_range), std::move(accu_var), + std::move(accu_init), std::move(loop_condition), + std::move(loop_step), std::move(result)); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_POLICY_INTERNAL_OPTIMIZER_EXPR_FACTORY_H_ diff --git a/policy/internal/optimizer_expr_factory_test.cc b/policy/internal/optimizer_expr_factory_test.cc new file mode 100644 index 000000000..1b14b5628 --- /dev/null +++ b/policy/internal/optimizer_expr_factory_test.cc @@ -0,0 +1,570 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "policy/internal/optimizer_expr_factory.h" + +#include +#include +#include + +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "common/ast.h" +#include "common/ast_proto.h" +#include "common/ast_rewrite.h" +#include "common/decl.h" +#include "common/expr.h" +#include "common/expr_factory.h" +#include "common/source.h" +#include "common/type.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "compiler/standard_library.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "testutil/expr_printer.h" +#include "tools/cel_unparser.h" + +namespace cel { + +using ::testing::SizeIs; + +// Expose protected members of OptimizerExprFactory for use in tests +// +// These allow setting explicit IDs which is not safe for the optimizing +// factory. +class TestOptimizerExprFactory final : public OptimizerExprFactory { + public: + using OptimizerExprFactory::OptimizerExprFactory; + + using OptimizerExprFactory::NewBoolConst; + using OptimizerExprFactory::NewCall; + using OptimizerExprFactory::NewComprehension; + using OptimizerExprFactory::NewIdent; + using OptimizerExprFactory::NewList; + using OptimizerExprFactory::NewListElement; + using OptimizerExprFactory::NewMap; + using OptimizerExprFactory::NewMapEntry; + using OptimizerExprFactory::NewMemberCall; + using OptimizerExprFactory::NewSelect; + using OptimizerExprFactory::NewStruct; + using OptimizerExprFactory::NewStructField; + using OptimizerExprFactory::NewUnspecified; + using OptimizerExprFactory::NextId; +}; + +namespace { + +class ReplaceExprRewriter final : public AstRewriterBase { + public: + ReplaceExprRewriter(ExprId old_id, const Expr& replacement) + : old_id_(old_id), replacement_(replacement) {} + + bool PreVisitRewrite(Expr& expr) override { + if (expr.id() == old_id_) { + expr = replacement_; + return true; + } + return false; + } + + private: + ExprId old_id_; + const Expr& replacement_; +}; + +void ReplaceExprInTree(Expr& expr, ExprId old_id, const Expr& replacement) { + ReplaceExprRewriter rewriter(old_id, replacement); + AstRewrite(expr, rewriter); +} + +absl::StatusOr> CreateTestCompiler() { + CompilerOptions opts; + opts.parser_options.add_macro_calls = true; + CEL_ASSIGN_OR_RETURN( + auto builder, cel::NewCompilerBuilder( + cel::internal::GetSharedTestingDescriptorPool(), opts)); + CEL_RETURN_IF_ERROR(builder->AddLibrary(cel::StandardCompilerLibrary())); + CEL_RETURN_IF_ERROR(builder->GetCheckerBuilder().AddVariable( + cel::MakeVariableDecl("to_replace", cel::DynType()))); + return builder->Build(); +} + +TEST(OptimizerExprFactory, CopyUnspecified) { + TestOptimizerExprFactory factory{Ast()}; + EXPECT_EQ(factory.Copy(factory.NewUnspecified()), factory.NewUnspecified(2)); +} + +TEST(OptimizerExprFactory, CopyIdent) { + TestOptimizerExprFactory factory{Ast()}; + EXPECT_EQ(factory.Copy(factory.NewIdent("foo")), factory.NewIdent(2, "foo")); +} + +TEST(OptimizerExprFactory, CopyConst) { + TestOptimizerExprFactory factory{Ast()}; + EXPECT_EQ(factory.Copy(factory.NewBoolConst(true)), + factory.NewBoolConst(2, true)); +} + +TEST(OptimizerExprFactory, CopySelect) { + TestOptimizerExprFactory factory{Ast()}; + EXPECT_EQ(factory.Copy(factory.NewSelect(factory.NewIdent("foo"), "bar")), + factory.NewSelect(3, factory.NewIdent(4, "foo"), "bar")); +} + +TEST(OptimizerExprFactory, CopyCall) { + TestOptimizerExprFactory factory{Ast()}; + std::vector copied_args; + copied_args.reserve(1); + copied_args.push_back(factory.NewIdent(6, "baz")); + EXPECT_EQ(factory.Copy(factory.NewMemberCall("bar", factory.NewIdent("foo"), + factory.NewIdent("baz"))), + factory.NewMemberCall(4, "bar", factory.NewIdent(5, "foo"), + absl::MakeSpan(copied_args))); +} + +TEST(OptimizerExprFactory, CopyList) { + TestOptimizerExprFactory factory{Ast()}; + std::vector copied_elements; + copied_elements.reserve(1); + copied_elements.push_back(factory.NewListElement(factory.NewIdent(4, "foo"))); + EXPECT_EQ(factory.Copy(factory.NewList( + factory.NewListElement(factory.NewIdent("foo")))), + factory.NewList(3, absl::MakeSpan(copied_elements))); +} + +TEST(OptimizerExprFactory, CopyStruct) { + TestOptimizerExprFactory factory{Ast()}; + std::vector copied_fields; + copied_fields.reserve(1); + copied_fields.push_back( + factory.NewStructField(5, "bar", factory.NewIdent(6, "baz"))); + EXPECT_EQ(factory.Copy(factory.NewStruct( + "foo", factory.NewStructField("bar", factory.NewIdent("baz")))), + factory.NewStruct(4, "foo", absl::MakeSpan(copied_fields))); +} + +TEST(OptimizerExprFactory, CopyMap) { + TestOptimizerExprFactory factory{Ast()}; + std::vector copied_entries; + copied_entries.reserve(1); + copied_entries.push_back(factory.NewMapEntry(6, factory.NewIdent(7, "bar"), + factory.NewIdent(8, "baz"))); + EXPECT_EQ(factory.Copy(factory.NewMap(factory.NewMapEntry( + factory.NewIdent("bar"), factory.NewIdent("baz")))), + factory.NewMap(5, absl::MakeSpan(copied_entries))); +} + +TEST(OptimizerExprFactory, CopyComprehension) { + TestOptimizerExprFactory factory{Ast()}; + EXPECT_EQ( + factory.Copy(factory.NewComprehension( + "foo", factory.NewList(), "bar", factory.NewBoolConst(true), + factory.NewIdent("baz"), factory.NewIdent("foo"), + factory.NewIdent("bar"))), + factory.NewComprehension( + 7, "foo", factory.NewList(8, std::vector()), "bar", + factory.NewBoolConst(9, true), factory.NewIdent(10, "baz"), + factory.NewIdent(11, "foo"), factory.NewIdent(12, "bar"))); +} + +TEST(OptimizerExprFactory, RemapSourceInfo) { + TestOptimizerExprFactory factory{Ast()}; + Expr orig = factory.NewIdent("foo"); // allocates ID 1 + Expr copied = factory.Copy(orig); // copies ID 1 to mapped ID 2 + + SourceInfo info; + info.mutable_positions()[1] = 42; // old ID 1 has position 42 + + SourceInfo remapped = factory.RemapSourceInfo(info, 10); + + // remapped should have ID 2 mapped to position 42 + 10 = 52 + auto it = remapped.positions().find(2); + ASSERT_NE(it, remapped.positions().end()); + EXPECT_EQ(it->second, 52); +} + +TEST(OptimizerExprFactory, RemapSourceInfoWithMacroCalls) { + TestOptimizerExprFactory factory{Ast()}; + Expr orig = factory.NewIdent("foo"); // allocates ID 1 + Expr copied = factory.Copy(orig); // copies ID 1 to mapped ID 2 + + SourceInfo info; + // old ID 1 has macro call with ID 3 + info.mutable_macro_calls()[1] = factory.NewIdent("bar"); + + SourceInfo remapped = factory.RemapSourceInfo(info, 10); + + // remapped should have ID 2 mapped to the copied macro call + // since "bar" has ID 3, Copy(bar) should map ID 3 to ID 4 + + auto it = remapped.macro_calls().find(2); + ASSERT_NE(it, remapped.macro_calls().end()); + + // The macro call should be an Ident with new ID 4 + EXPECT_EQ(it->second.id(), 4); + EXPECT_TRUE(it->second.has_ident_expr()); + EXPECT_EQ(it->second.ident_expr().name(), "bar"); +} + +TEST(OptimizerExprFactory, ReportError) { + TestOptimizerExprFactory factory{Ast()}; + Expr err_expr = factory.ReportError("something went wrong"); + + // err_expr should be unspecified with ID 1 + EXPECT_EQ(err_expr.id(), 1); + EXPECT_EQ(err_expr.kind_case(), ExprKindCase::kUnspecifiedExpr); + + // issues_ should have 1 entry with ID 1 and correct message + ASSERT_EQ(factory.issues().size(), 1); + EXPECT_EQ(factory.issues()[0].location, 1); + EXPECT_EQ(factory.issues()[0].message, "something went wrong"); +} + +TEST(OptimizerExprFactory, ReportErrorAt) { + TestOptimizerExprFactory factory{Ast()}; + Expr orig = factory.NewIdent("foo"); // allocates ID 1 + Expr copied = factory.Copy(orig); // copies ID 1 to mapped ID 2 + + Expr err_expr = factory.ReportErrorAtCopy(orig, "error on foo"); + + // err_expr should be unspecified with ID 3 (NextId) + EXPECT_EQ(err_expr.id(), 3); + EXPECT_EQ(err_expr.kind_case(), ExprKindCase::kUnspecifiedExpr); + + // issues_ should have 1 entry with mapped ID 2 and correct message + ASSERT_EQ(factory.issues().size(), 1); + EXPECT_EQ(factory.issues()[0].location, 2); + EXPECT_EQ(factory.issues()[0].message, "error on foo"); +} + +TEST(OptimizerExprFactory, MergeSourceInfo) { + // Create a base AST with some source info + SourceInfo base_info; + base_info.set_syntax_version("cel1"); + base_info.set_location("test.cel"); + base_info.mutable_positions()[1] = 10; + + Ast base_ast(Expr(), std::move(base_info)); + + TestOptimizerExprFactory factory{std::move(base_ast)}; + + // Create a new source info to merge + SourceInfo new_info; + new_info.mutable_positions()[2] = 20; + + factory.MergeSourceInfo(new_info); + + // The merged source info should have both positions + const auto& merged_info = factory.ast().source_info(); + EXPECT_EQ(merged_info.syntax_version(), "cel1"); + EXPECT_EQ(merged_info.location(), "test.cel"); + + auto it1 = merged_info.positions().find(1); + ASSERT_NE(it1, merged_info.positions().end()); + EXPECT_EQ(it1->second, 10); + + auto it2 = merged_info.positions().find(2); + ASSERT_NE(it2, merged_info.positions().end()); + EXPECT_EQ(it2->second, 20); +} + +TEST(OptimizerExprFactory, MergeSourceInfoConflict) { + SourceInfo base_info; + base_info.mutable_positions()[1] = 10; + + Ast base_ast(Expr(), std::move(base_info)); + TestOptimizerExprFactory factory{std::move(base_ast)}; + + SourceInfo new_info; + new_info.mutable_positions()[1] = 20; // conflicting ID 1 + + factory.MergeSourceInfo(new_info); + + // Should report an error for the conflict + ASSERT_EQ(factory.issues().size(), 1); + EXPECT_EQ(factory.issues()[0].location, 1); + EXPECT_EQ(factory.issues()[0].message, "conflicting ID in positions merge"); +} + +TEST(OptimizerExprFactory, RecordReplacement) { + SourceInfo base_info; + base_info.mutable_positions()[1] = 10; + base_info.mutable_positions()[2] = 20; + + TestOptimizerExprFactory factory{Ast()}; + + // macro_calls[1] maps ID 1 to macro call "bar(foo)" (where "foo" has ID 1) + base_info.mutable_macro_calls()[1] = + factory.NewCall("bar", factory.NewIdent(1, "foo")); + + // macro_calls[2] maps ID 2 to macro call "baz(foo)" (where "foo" has ID 1) + base_info.mutable_macro_calls()[2] = + factory.NewCall("baz", factory.NewIdent(1, "foo")); + + Ast base_ast(Expr(), std::move(base_info)); + TestOptimizerExprFactory optimizer{std::move(base_ast)}; + + // Record the replacement of ID 1 by a new Ident "replacement" with ID 3 + optimizer.RecordReplacement(1, factory.NewIdent(3, "replacement")); + + const auto& result_info = optimizer.ast().source_info(); + + // 1. ID 1 should be erased from positions + EXPECT_EQ(result_info.positions().find(1), result_info.positions().end()); + EXPECT_NE(result_info.positions().find(2), result_info.positions().end()); + + // 2. ID 1 should be erased from macro_calls keys + EXPECT_EQ(result_info.macro_calls().find(1), result_info.macro_calls().end()); + + // 3. macro_calls[2] should still exist, but its argument referencing ID 1 + // should be replaced with the Ident "replacement" with ID 3 inline + auto it = result_info.macro_calls().find(2); + ASSERT_NE(it, result_info.macro_calls().end()); + + const Expr& macro_expr = it->second; + ASSERT_TRUE(macro_expr.has_call_expr()); + ASSERT_EQ(macro_expr.call_expr().args().size(), 1); + + const Expr& arg = macro_expr.call_expr().args()[0]; + EXPECT_EQ(arg.id(), 3); + EXPECT_TRUE(arg.has_ident_expr()); + EXPECT_EQ(arg.ident_expr().name(), "replacement"); +} + +class IdAdorner : public cel::test::ExpressionAdorner { + public: + std::string Adorn(const cel::Expr& e) const override { + return absl::StrCat("#", e.id()); + } + + std::string AdornStructField(const cel::StructExprField& e) const override { + return absl::StrCat("#", e.id()); + } + + std::string AdornMapEntry(const cel::MapExprEntry& e) const override { + return absl::StrCat("#", e.id()); + } +}; + +TEST(OptimizerExprFactory, UnparseCopiedMacroCall) { + // Arrange: create an template expression and one to inline. + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, + CreateTestCompiler()); + + ASSERT_OK_AND_ASSIGN(auto basis_result, + compiler->Compile("[1].map(x, x + to_replace)")); + ASSERT_TRUE(basis_result.IsValid()); + ASSERT_OK_AND_ASSIGN(auto basis_ast, basis_result.ReleaseAst()); + + ASSERT_OK_AND_ASSIGN(auto copy_result, + compiler->Compile("[1].filter(x, x > 2).size()")); + ASSERT_TRUE(copy_result.IsValid()); + ASSERT_OK_AND_ASSIGN(auto copy_ast, copy_result.ReleaseAst()); + + // Locate the "to_replace" IdentExpr node in reference_map + ExprId to_replace_id = 0; + for (const auto& [id, ref] : basis_ast->reference_map()) { + if (ref.name() == "to_replace") { + to_replace_id = id; + break; + } + } + ASSERT_NE(to_replace_id, 0); + + // Act: implement the optimization. + TestOptimizerExprFactory factory{std::move(*basis_ast)}; + Expr copied_expr = factory.Copy(copy_ast->root_expr()); + SourceInfo remapped_info = factory.RemapSourceInfo(copy_ast->source_info()); + factory.MergeSourceInfo(remapped_info); + + ReplaceExprInTree(factory.mutable_ast().mutable_root_expr(), to_replace_id, + copied_expr); + factory.RecordReplacement(to_replace_id, copied_expr); + + // Test AST structure. + EXPECT_EQ( + cel::test::ExprPrinter(IdAdorner()).Print(factory.ast().root_expr()), + R"(__comprehension__( + // Variable + x, + // Target + [ + 1#2 + ]#1, + // Accumulator + @result, + // Init + []#8, + // LoopCondition + true#9, + // LoopStep + _+_( + @result#10, + [ + _+_( + x#5, + __comprehension__( + // Variable + x, + // Target + [ + 1#18 + ]#17, + // Accumulator + @result, + // Init + []#19, + // LoopCondition + true#20, + // LoopStep + _?_:_( + _>_( + x#23, + 2#24 + )#22, + _+_( + @result#26, + [ + x#28 + ]#27 + )#25, + @result#29 + )#21, + // Result + @result#30)#16.size()#15 + )#6 + ]#11 + )#12, + // Result + @result#13)#14)"); + + // Check that the structure is compatible with unparser. + cel::expr::ParsedExpr optimized_parsed; + auto status = AstToParsedExpr(factory.ast(), &optimized_parsed); + ASSERT_THAT(status, absl_testing::IsOk()); + ASSERT_OK_AND_ASSIGN(std::string unparsed, + google::api::expr::Unparse(optimized_parsed)); + + EXPECT_EQ(unparsed, "[1].map(x, x + [1].filter(x, x > 2).size())"); + + const CallExpr& call_expr = factory.mutable_ast() + .mutable_source_info() + .mutable_macro_calls()[14] + .mutable_call_expr(); + ASSERT_THAT(call_expr.args(), SizeIs(2)); + ASSERT_THAT(call_expr.args()[1].call_expr().args(), SizeIs(2)); + EXPECT_EQ(call_expr.args()[1].call_expr().args()[1].id(), 15); + + EXPECT_EQ(call_expr.args()[1].call_expr().args()[1].call_expr().target().id(), + 16); + EXPECT_EQ(call_expr.args()[1] + .call_expr() + .args()[1] + .call_expr() + .target() + .kind_case(), + ExprKindCase::kUnspecifiedExpr); +} + +TEST(OptimizerExprFactory, CopyMultipleAstsWithConsumeRenumbers) { + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, + CreateTestCompiler()); + + ASSERT_OK_AND_ASSIGN(auto ast1_result, compiler->Compile("[1]")); + ASSERT_TRUE(ast1_result.IsValid()); + ASSERT_OK_AND_ASSIGN(auto ast1, ast1_result.ReleaseAst()); + + ASSERT_OK_AND_ASSIGN(auto ast2_result, compiler->Compile("2")); + ASSERT_TRUE(ast2_result.IsValid()); + ASSERT_OK_AND_ASSIGN(auto ast2, ast2_result.ReleaseAst()); + + TestOptimizerExprFactory factory{Ast()}; + + Expr copied1 = factory.Copy(ast1->root_expr()); + auto renumbers1 = factory.ConsumeRenumbers(); + + Expr copied2 = factory.Copy(ast2->root_expr()); + auto renumbers2 = factory.ConsumeRenumbers(); + + EXPECT_EQ(renumbers1.size(), 2); + EXPECT_EQ(renumbers2.size(), 1); + + EXPECT_NE(copied1.id(), copied2.id()); + EXPECT_GT(copied2.id(), copied1.id()); +} + +TEST(OptimizerExprFactory, MaxIdVisitorExprKinds) { + ASSERT_OK_AND_ASSIGN(auto compiler, CreateTestCompiler()); + + // Expression that covers all the kinds. + ASSERT_OK_AND_ASSIGN(auto source, NewSource(R"cel( + Struct{field : 1} || + {'key' : 'value'} || [1].exists(x, x) || foo(bar))cel")); + ASSERT_OK_AND_ASSIGN(auto ast, compiler->GetParser().Parse(*source)); + + TestOptimizerExprFactory factory{std::move(*ast)}; + + EXPECT_EQ(factory.NextId(), 26); +} + +TEST(OptimizerExprFactory, CopyListElement) { + TestOptimizerExprFactory factory{Ast()}; + ListExprElement orig = factory.NewListElement(factory.NewIdent("foo")); + ListExprElement copied = factory.Copy(orig); + EXPECT_EQ(copied.expr(), factory.NewIdent(2, "foo")); +} + +TEST(OptimizerExprFactory, CopyStructField) { + TestOptimizerExprFactory factory{Ast()}; + StructExprField orig = factory.NewStructField("bar", factory.NewIdent("baz")); + StructExprField copied = factory.Copy(orig); + EXPECT_EQ(copied.id(), 3); + EXPECT_EQ(copied.name(), "bar"); + EXPECT_EQ(copied.value(), factory.NewIdent(4, "baz")); +} + +TEST(OptimizerExprFactory, CopyMapEntry) { + TestOptimizerExprFactory factory{Ast()}; + MapExprEntry orig = + factory.NewMapEntry(factory.NewIdent("bar"), factory.NewIdent("baz")); + MapExprEntry copied = factory.Copy(orig); + EXPECT_EQ(copied.id(), 4); + EXPECT_EQ(copied.key(), factory.NewIdent(5, "bar")); + EXPECT_EQ(copied.value(), factory.NewIdent(6, "baz")); +} + +TEST(OptimizerExprFactory, MergeSourceInfoMacroConflict) { + SourceInfo base_info; + base_info.mutable_macro_calls()[1] = Expr(); + + Ast base_ast(Expr(), std::move(base_info)); + TestOptimizerExprFactory factory{std::move(base_ast)}; + + SourceInfo new_info; + new_info.mutable_macro_calls()[1] = Expr(); + + factory.MergeSourceInfo(new_info); + + ASSERT_EQ(factory.issues().size(), 1); + EXPECT_EQ(factory.issues()[0].location, 1); + EXPECT_EQ(factory.issues()[0].message, "conflicting ID in macro calls merge"); +} + +} // namespace +} // namespace cel diff --git a/policy/test_custom_yaml_policy_parser.cc b/policy/test_custom_yaml_policy_parser.cc new file mode 100644 index 000000000..faced6952 --- /dev/null +++ b/policy/test_custom_yaml_policy_parser.cc @@ -0,0 +1,188 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "internal/status_macros.h" +#include "policy/cel_policy.h" +#include "policy/cel_policy_parse_context.h" +#include "policy/cel_policy_parser.h" +#include "policy/yaml_policy_parser.h" +#include "yaml-cpp/node/node.h" +#include "yaml-cpp/yaml.h" // IWYU pragma: keep + +namespace cel::internal { + +// TestCustomYamlPolicyParser is used to support unit tests for custom tags +// and custom policy structures. It demonstrates the versatility of the +// cel::YamlPolicyParser framework API by implementing custom tag and block +// parsing without needing to modify the core parser. +class TestCustomYamlPolicyParser : public cel::YamlPolicyParser { + absl::StatusOr ParsePolicyTag(CelPolicyParseContext& ctx, + const ValueString& tag_name, + const YAML::Node& node) const override { + if (tag_name.value() == "name" || tag_name.value() == "description" || + tag_name.value() == "imports") { + return cel::YamlPolicyParser::ParsePolicyTag(ctx, tag_name, node); + } + if (tag_name.value() == "purpose") { + std::optional purpose = + GetValueString(ctx, node, "Policy purpose is not a string"); + if (purpose.has_value()) { + ctx.policy().mutable_metadata()["purpose"] = *purpose; + } + return true; + } + if (tag_name.value() == "version") { + std::optional version = + GetValueString(ctx, node, "Policy version is not a string"); + if (!version.has_value()) { + return true; + } + int version_int; + if (!absl::SimpleAtoi(version->value(), &version_int)) { + ctx.ReportError(version->id(), + absl::StrCat("Policy version is not an integer: ", + version->value())); + return true; + } + ctx.policy().mutable_metadata()["version"] = version_int; + return true; + } + + if (tag_name.value() == "conditions") { + if (!node.IsSequence()) { + ctx.ReportError(tag_name.id(), "Policy 'conditions' is not a sequence"); + return true; + } + for (const YAML::Node& condition : node) { + // Track the number of existing matches before parsing. When ParseMatch + // evaluates an 'else' block, it recursively triggers parsing and adds + // internal inner matches directly to the rule's match vector. + // Inserting the outer match at begin() + size_before ensures that the + // primary outer 'if' condition is always evaluated before its nested + // 'else' fallbacks. + // + // Example: + // if: x > 0 + // then: "positive" + // else: "negative" + // + // The inner "negative" match is parsed and appended to rule.matches() + // by the inner recursive call, before the outer "x > 0" match finishes. + // Inserting at size_before places the "x > 0" match ahead of the inner + // one. + size_t size_before = ctx.policy().rule().matches().size(); + CEL_ASSIGN_OR_RETURN(Match match, + cel::YamlPolicyParser::ParseMatch( + ctx, condition, ctx.policy().mutable_rule())); + ctx.policy().mutable_rule().mutable_matches().insert( + ctx.policy().mutable_rule().mutable_matches().begin() + size_before, + std::move(match)); + } + + return true; + } + return false; + } + + absl::Status ParseThenBlock(CelPolicyParseContext& ctx, + const YAML::Node& value_node, + Match& match) const { + if (value_node.IsScalar()) { + std::optional val = GetValueString( + ctx, value_node, "Policy condition 'then' is not a string"); + if (val.has_value()) { + OutputBlock output; + output.set_output(*val); + match.set_result(output); + } + } else if (value_node.IsMap()) { + auto nested_rule = std::make_unique(); + CEL_ASSIGN_OR_RETURN( + Match nested_match, + cel::YamlPolicyParser::ParseMatch(ctx, value_node, *nested_rule)); + nested_rule->mutable_matches().insert( + nested_rule->mutable_matches().begin(), std::move(nested_match)); + match.set_result(std::move(nested_rule)); + } else { + ctx.ReportError(CollectMetadata(ctx, value_node), + "Bad syntax in 'if/then' block"); + } + return absl::OkStatus(); + } + + absl::Status ParseElseBlock(CelPolicyParseContext& ctx, + const YAML::Node& value_node, Rule& rule) const { + if (value_node.IsScalar()) { + std::optional val = GetValueString( + ctx, value_node, "Policy condition 'else' is not a string"); + if (val.has_value()) { + Match else_match; + else_match.set_id(CollectMetadata(ctx, value_node)); + OutputBlock output; + output.set_output(*val); + else_match.set_result(output); + rule.mutable_matches().push_back(std::move(else_match)); + } + } else if (value_node.IsMap()) { + size_t size_before = rule.matches().size(); + CEL_ASSIGN_OR_RETURN(Match match, cel::YamlPolicyParser::ParseMatch( + ctx, value_node, rule)); + rule.mutable_matches().insert( + rule.mutable_matches().begin() + size_before, std::move(match)); + } else { + ctx.ReportError(CollectMetadata(ctx, value_node), + "Bad syntax in 'if/then' block"); + } + return absl::OkStatus(); + } + + absl::StatusOr ParseMatchTag(CelPolicyParseContext& ctx, + const ValueString& tag_name, + const YAML::Node& node, Match& match, + Rule& rule) const override { + if (tag_name.value() == "if") { + std::optional condition = + GetValueString(ctx, node, "Policy 'if' condition is not a string"); + if (condition.has_value()) { + match.set_condition(*condition); + } + return true; + } + if (tag_name.value() == "then") { + CEL_RETURN_IF_ERROR(ParseThenBlock(ctx, node, match)); + return true; + } + if (tag_name.value() == "else") { + CEL_RETURN_IF_ERROR(ParseElseBlock(ctx, node, rule)); + return true; + } + return false; + } +}; + +const CelPolicyParser& GetTestCustomYamlPolicyParser() { + static const auto* const parser = new TestCustomYamlPolicyParser(); + return *parser; +} + +} // namespace cel::internal diff --git a/policy/test_util.cc b/policy/test_util.cc new file mode 100644 index 000000000..9fe1e43d1 --- /dev/null +++ b/policy/test_util.cc @@ -0,0 +1,221 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +#include "policy/test_util.h" + +#include +#include +#include +#include + +#include "cel/expr/eval.pb.h" +#include "cel/expr/value.pb.h" +#include "google/protobuf/struct.pb.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "internal/status_macros.h" +#include "yaml-cpp/yaml.h" + +namespace cel::test { + +namespace { + +absl::Status YamlToExprValue(const YAML::Node& node, + cel::expr::Value* proto) { + if (node.IsNull()) { + proto->set_null_value(google::protobuf::NULL_VALUE); + return absl::OkStatus(); + } + if (node.IsScalar()) { + // Try bool + try { + proto->set_bool_value(node.as()); + return absl::OkStatus(); + } catch (...) { + } + // Try int64 + try { + int64_t val; + if (YAML::convert::decode(node, val)) { + proto->set_int64_value(val); + return absl::OkStatus(); + } + } catch (...) { + } + // Try double + try { + double val; + if (YAML::convert::decode(node, val)) { + proto->set_double_value(val); + return absl::OkStatus(); + } + } catch (...) { + } + // Fallback to string + proto->set_string_value(node.as()); + return absl::OkStatus(); + } + if (node.IsSequence()) { + auto* list = proto->mutable_list_value(); + for (const auto& elem : node) { + CEL_RETURN_IF_ERROR(YamlToExprValue(elem, list->add_values())); + } + return absl::OkStatus(); + } + if (node.IsMap()) { + auto* map_val = proto->mutable_map_value(); + for (auto it = node.begin(); it != node.end(); ++it) { + auto* entry = map_val->add_entries(); + CEL_RETURN_IF_ERROR(YamlToExprValue(it->first, entry->mutable_key())); + CEL_RETURN_IF_ERROR(YamlToExprValue(it->second, entry->mutable_value())); + } + return absl::OkStatus(); + } + return absl::InvalidArgumentError("Unknown YAML node type"); +} + +absl::Status ParseInputValue( + const YAML::Node& node, + cel::expr::conformance::test::InputValue* input_val) { + if (node.IsMap() && node["expr"].IsDefined()) { + input_val->set_expr(node["expr"].as()); + return absl::OkStatus(); + } + if (node.IsMap() && node["value"].IsDefined()) { + return YamlToExprValue(node["value"], input_val->mutable_value()); + } + return YamlToExprValue(node, input_val->mutable_value()); +} + +absl::Status ParseTestOutput(const YAML::Node& node, + cel::expr::conformance::test::TestOutput* output) { + if (!node.IsDefined()) { + return absl::InvalidArgumentError("Missing output node"); + } + if (node.IsMap()) { + if (node["expr"].IsDefined()) { + output->set_result_expr(node["expr"].as()); + return absl::OkStatus(); + } + if (node["value"].IsDefined()) { + return YamlToExprValue(node["value"], output->mutable_result_value()); + } + if (node["error"].IsDefined()) { + auto* eval_error = output->mutable_eval_error(); + eval_error->add_errors()->set_message(node["error"].as()); + return absl::OkStatus(); + } + if (node["error_set"].IsDefined()) { + auto* eval_error = output->mutable_eval_error(); + for (const auto& err : node["error_set"]) { + eval_error->add_errors()->set_message(err.as()); + } + return absl::OkStatus(); + } + if (node["unknown"].IsDefined()) { + auto* unknown = output->mutable_unknown(); + for (const auto& expr_id_node : node["unknown"]) { + unknown->add_exprs(expr_id_node.as()); + } + return absl::OkStatus(); + } + } + return YamlToExprValue(node, output->mutable_result_value()); +} + +absl::StatusOr +ParsePolicyTestSuiteYamlImpl(absl::string_view yaml_content) { + YAML::Node tests_node; + try { + tests_node = YAML::Load(std::string(yaml_content)); + } catch (const std::exception& e) { + return absl::InvalidArgumentError( + absl::StrCat("Failed to parse YAML: ", e.what())); + } + + cel::expr::conformance::test::TestSuite test_suite; + if (tests_node["description"].IsDefined()) { + test_suite.set_description(tests_node["description"].as()); + } + + YAML::Node sections = tests_node["sections"]; + if (!sections.IsDefined()) { + sections = tests_node["section"]; // support singular format + } + if (!sections.IsDefined()) { + return absl::InvalidArgumentError( + "Missing 'sections' or 'section' in tests YAML"); + } + + for (const auto& section_node : sections) { + auto* section = test_suite.add_sections(); + if (section_node["name"].IsDefined()) { + section->set_name(section_node["name"].as()); + } + if (section_node["description"].IsDefined()) { + section->set_description(section_node["description"].as()); + } + + YAML::Node tests = section_node["tests"]; + if (!tests.IsDefined()) { + tests = section_node["test"]; // support singular format + } + if (!tests.IsDefined()) { + continue; + } + + for (const auto& test_node : tests) { + auto* test_case = section->add_tests(); + if (test_node["name"].IsDefined()) { + test_case->set_name(test_node["name"].as()); + } + if (test_node["description"].IsDefined()) { + test_case->set_description(test_node["description"].as()); + } + if (test_node["context_expr"].IsDefined()) { + test_case->mutable_input_context()->set_context_expr( + test_node["context_expr"].as()); + } + + YAML::Node input_node = test_node["input"]; + if (input_node.IsDefined() && input_node.IsMap()) { + auto* input_map = test_case->mutable_input(); + for (auto it = input_node.begin(); it != input_node.end(); ++it) { + std::string var_name = it->first.as(); + cel::expr::conformance::test::InputValue input_val; + CEL_RETURN_IF_ERROR(ParseInputValue(it->second, &input_val)); + (*input_map)[var_name] = std::move(input_val); + } + } + + YAML::Node output_node = test_node["output"]; + if (output_node.IsDefined()) { + CEL_RETURN_IF_ERROR( + ParseTestOutput(output_node, test_case->mutable_output())); + } + } + } + + return test_suite; +} + +} // namespace + +absl::StatusOr +ParsePolicyTestSuiteYaml(absl::string_view yaml_content) { + try { + return ParsePolicyTestSuiteYamlImpl(yaml_content); + } catch (...) { + return absl::InvalidArgumentError("Failed to parse YAML"); + } +} + +} // namespace cel::test diff --git a/policy/test_util.h b/policy/test_util.h new file mode 100644 index 000000000..5fe306050 --- /dev/null +++ b/policy/test_util.h @@ -0,0 +1,33 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_POLICY_TEST_UTIL_H_ +#define THIRD_PARTY_CEL_CPP_POLICY_TEST_UTIL_H_ + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "cel/expr/conformance/test/suite.pb.h" + +namespace cel::test { + +// Parses a YAML content representing a policy test suite (tests.yaml) +// and adapts it to the cel.expr.conformance.test.TestSuite protobuf message. +// +// TODO(uncreated-issue/92): Move to the testrunner library. +absl::StatusOr +ParsePolicyTestSuiteYaml(absl::string_view yaml_content); + +} // namespace cel::test + +#endif // THIRD_PARTY_CEL_CPP_POLICY_TEST_UTIL_H_ diff --git a/policy/testdata/BUILD b/policy/testdata/BUILD new file mode 100644 index 000000000..10a26fa0b --- /dev/null +++ b/policy/testdata/BUILD @@ -0,0 +1,19 @@ +package( + default_testonly = True, + default_visibility = ["//visibility:public"], +) + +filegroup( + name = "policy_testdata", + srcs = glob([ + "*.yaml", + "*.baseline", + ]), +) + +exports_files( + srcs = glob([ + "*.yaml", + "*.baseline", + ]), +) diff --git a/policy/testdata/cel_policy.yaml b/policy/testdata/cel_policy.yaml new file mode 100644 index 000000000..010ad8855 --- /dev/null +++ b/policy/testdata/cel_policy.yaml @@ -0,0 +1,42 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Environment: +# spec: TestAllTypes +name: cel_policy +description: A test policy for CEL +display_name: Cel Policy +imports: +- name: cel.expr.conformance.proto3.TestAllTypes +- name: cel.expr.conformance.proto3.TestAllTypes.NestedEnum +rule: + id: test_rule + description: test rule description + variables: + - name: test_var + expression: > + TestAllTypes{single_int64: 10}.single_int64 + match: + - condition: > + spec.single_int32 > TestAllTypes{single_int64: 10}.single_int64 + output: | + "invalid spec, got single_int32=" + string(spec.single_int32) + ", wanted <= 10" + explanation: | + "invalid spec, spec is greater than 10" + - condition: > + spec.standalone_enum == NestedEnum.BAR + output: | + "invalid spec, reference to BAR is not allowed" + - condition: spec.single_int64 == variables.test_var + output: '"invalid spec: exactly matches test_var"' + explanation: '"the spec cannot have single_int64 set to a known bad value"' \ No newline at end of file diff --git a/policy/testdata/cel_policy_parser.baseline b/policy/testdata/cel_policy_parser.baseline new file mode 100644 index 000000000..7a6678bfe --- /dev/null +++ b/policy/testdata/cel_policy_parser.baseline @@ -0,0 +1,89 @@ +POLICY SOURCE: cel_policy.yaml +-------------------------------------------------------------------- +PARSED POLICY: +CelPolicy{ + =========================================================== + # Copyright 2026 Google LLC + # + # Licensed under the Apache License, Version 2.0 (the "License"); + # you may not use this file except in compliance with the License. + # You may obtain a copy of the License at + # + # https://www.apache.org/licenses/LICENSE-2.0 + # + # Unless required by applicable law or agreed to in writing, software + # distributed under the License is distributed on an "AS IS" BASIS, + # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + # See the License for the specific language governing permissions and + # limitations under the License. + # Environment: + # spec: TestAllTypes + #0> name: #1> cel_policy + #2> description: #3> A test policy for CEL + #4> display_name: #5> Cel Policy + #6> imports: + - #7> name: #8> cel.expr.conformance.proto3.TestAllTypes + - #9> name: #10> cel.expr.conformance.proto3.TestAllTypes.NestedEnum + #11> rule: + #13> #12> id: #14> test_rule + #15> description: #16> test rule description + #17> variables: + - #18> name: #19> test_var + #20> expression: #21> > + TestAllTypes{single_int64: 10}.single_int64 + #22> match: + - #24> #23> condition: #25> > + spec.single_int32 > TestAllTypes{single_int64: 10}.single_int64 + #26> output: #27> | + "invalid spec, got single_int32=" + string(spec.single_int32) + ", wanted <= 10" + #28> explanation: #29> | + "invalid spec, spec is greater than 10" + - #31> #30> condition: #32> > + spec.standalone_enum == NestedEnum.BAR + #33> output: #34> | + "invalid spec, reference to BAR is not allowed" + - #36> #35> condition: #37> spec.single_int64 == variables.test_var + #38> output: #39> '"invalid spec: exactly matches test_var"' + #40> explanation: #41> '"the spec cannot have single_int64 set to a known bad value"' + =========================================================== + name: #1> "cel_policy" + description: #3> "A test policy for CEL" + display_name: #5> "Cel Policy" + imports: + #7> name: #8> "cel.expr.conformance.proto3.TestAllTypes" + #9> name: #10> "cel.expr.conformance.proto3.TestAllTypes.NestedEnum" + #12> rule: { + rule_id: #14> "test_rule" + description: #16> "test rule description" + variable: { + name: #19> "test_var" + expression: #21> "TestAllTypes{single_int64: 10}.single_int64 + " + } + #23> match: { + condition: #25> "spec.single_int32 > TestAllTypes{single_int64: 10}.single_int64 + " + result: { + output: #27> ""invalid spec, got single_int32=" + string(spec.single_int32) + ", wanted <= 10" + " + explanation: #29> ""invalid spec, spec is greater than 10" + " + } + } + #30> match: { + condition: #32> "spec.standalone_enum == NestedEnum.BAR + " + result: { + output: #34> ""invalid spec, reference to BAR is not allowed" + " + } + } + #35> match: { + condition: #37> "spec.single_int64 == variables.test_var" + result: { + output: #39> ""invalid spec: exactly matches test_var"" + explanation: #41> ""the spec cannot have single_int64 set to a known bad value"" + } + } + } +} diff --git a/policy/testdata/custom_policy_format.yaml b/policy/testdata/custom_policy_format.yaml new file mode 100644 index 000000000..a67356906 --- /dev/null +++ b/policy/testdata/custom_policy_format.yaml @@ -0,0 +1,29 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +name: cel_policy_custom_tags +description: A custom policy format +imports: +- name: cel.expr.conformance.proto3.TestAllTypes +purpose: test +version: 42 +conditions: +- if: spec.single_string == "none" + then: "'zero'" + else: + if: spec.single_string == "integer" + then: + if: spec.single_int32 > 0 + then: "'positive integer'" + else: "'negative integer'" + else: "'not an integer'" diff --git a/policy/testdata/custom_policy_format_parser.baseline b/policy/testdata/custom_policy_format_parser.baseline new file mode 100644 index 000000000..d5b1a2235 --- /dev/null +++ b/policy/testdata/custom_policy_format_parser.baseline @@ -0,0 +1,75 @@ +POLICY SOURCE: custom_policy_format.yaml +-------------------------------------------------------------------- +PARSED POLICY: +CelPolicy{ + =========================================================== + # Copyright 2026 Google LLC + # + # Licensed under the Apache License, Version 2.0 (the "License"); + # you may not use this file except in compliance with the License. + # You may obtain a copy of the License at + # + # https://www.apache.org/licenses/LICENSE-2.0 + # + # Unless required by applicable law or agreed to in writing, software + # distributed under the License is distributed on an "AS IS" BASIS, + # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + # See the License for the specific language governing permissions and + # limitations under the License. + #0> name: #1> cel_policy_custom_tags + #2> description: #3> A custom policy format + #4> imports: + - #5> name: #6> cel.expr.conformance.proto3.TestAllTypes + #7> purpose: #8> test + #9> version: #10> 42 + #11> conditions: + - #13> #12> if: #14> spec.single_string == "none" + #15> then: #16> "'zero'" + #17> else: + #19> #18> if: #20> spec.single_string == "integer" + #21> then: + #23> #22> if: #24> spec.single_int32 > 0 + #25> then: #26> "'positive integer'" + #27> else: #29> #28> "'negative integer'" + #30> else: #32> #31> "'not an integer'" + + =========================================================== + name: #1> "cel_policy_custom_tags" + description: #3> "A custom policy format" + metadata: { + purpose: #8> "test" + version: 42 + } + imports: + #5> name: #6> "cel.expr.conformance.proto3.TestAllTypes" + rule: { + #12> match: { + condition: #14> "spec.single_string == "none"" + result: { + output: #16> "'zero'" + } + } + #18> match: { + condition: #20> "spec.single_string == "integer"" + result: + rule: { + #22> match: { + condition: #24> "spec.single_int32 > 0" + result: { + output: #26> "'positive integer'" + } + } + #29> match: { + result: { + output: #28> "'negative integer'" + } + } + } + } + #32> match: { + result: { + output: #31> "'not an integer'" + } + } + } +} diff --git a/policy/testdata/custom_policy_format_with_errors.yaml b/policy/testdata/custom_policy_format_with_errors.yaml new file mode 100644 index 000000000..594747c60 --- /dev/null +++ b/policy/testdata/custom_policy_format_with_errors.yaml @@ -0,0 +1,33 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +name: cel_policy_custom_tags +description: A custom policy format +imports: +- name: cel.expr.conformance.proto3.TestAllTypes +purpose: + - testing +version: new +conditions: +- if: + spec.single_string: "none" + then: "'zero'" + else: "'not zero'" +- if: spec.single_string == "number" + then: + if: spec.single_int32 > 0 + then: "'positive integer'" + else: + - ignore +- else: "'negative integer'" + diff --git a/policy/testdata/custom_policy_format_with_errors_parser.baseline b/policy/testdata/custom_policy_format_with_errors_parser.baseline new file mode 100644 index 000000000..978d27bda --- /dev/null +++ b/policy/testdata/custom_policy_format_with_errors_parser.baseline @@ -0,0 +1,16 @@ +POLICY SOURCE: custom_policy_format_with_errors.yaml +-------------------------------------------------------------------- +-------------------------------------------------------------------- +PARSER ISSUES: +ERROR: custom_policy_format_with_errors.yaml:19:3: Policy purpose is not a string + | - testing + | ..^ +ERROR: custom_policy_format_with_errors.yaml:20:10: Policy version is not an integer: new + | version: new + | .........^ +ERROR: custom_policy_format_with_errors.yaml:23:5: Policy 'if' condition is not a string + | spec.single_string: "none" + | ....^ +ERROR: custom_policy_format_with_errors.yaml:31:7: Bad syntax in 'if/then' block + | - ignore + | ......^ diff --git a/policy/testdata/nested_rule.yaml b/policy/testdata/nested_rule.yaml new file mode 100644 index 000000000..2b07faa64 --- /dev/null +++ b/policy/testdata/nested_rule.yaml @@ -0,0 +1,37 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +name: nested_rule +rule: + variables: + - name: "permitted_regions" + expression: "['us', 'uk', 'es']" + match: + - rule: + id: "banned regions" + description: > + determine whether the resource origin is in the banned + list. If the region is also in the permitted list, the + ban has no effect. + variables: + - name: "banned_regions" + expression: "{'us': false, 'ru': false, 'ir': false}" + match: + - condition: | + resource.origin in variables.banned_regions && + !(resource.origin in variables.permitted_regions) + output: "{'banned': true}" + - condition: resource.origin in variables.permitted_regions + output: "{'banned': false}" + - output: "{'banned': true}" + explanation: "'resource is in the banned region ' + resource.origin" \ No newline at end of file diff --git a/policy/testdata/nested_rule_parser.baseline b/policy/testdata/nested_rule_parser.baseline new file mode 100644 index 000000000..128f81bda --- /dev/null +++ b/policy/testdata/nested_rule_parser.baseline @@ -0,0 +1,84 @@ +POLICY SOURCE: nested_rule.yaml +-------------------------------------------------------------------- +PARSED POLICY: +CelPolicy{ + =========================================================== + # Copyright 2024 Google LLC + # + # Licensed under the Apache License, Version 2.0 (the "License"); + # you may not use this file except in compliance with the License. + # You may obtain a copy of the License at + # + # https://www.apache.org/licenses/LICENSE-2.0 + # + # Unless required by applicable law or agreed to in writing, software + # distributed under the License is distributed on an "AS IS" BASIS, + # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + # See the License for the specific language governing permissions and + # limitations under the License. + #0> name: #1> nested_rule + #2> rule: + #4> #3> variables: + - #5> name: #6> "permitted_regions" + #7> expression: #8> "['us', 'uk', 'es']" + #9> match: + - #11> #10> rule: + #13> #12> id: #14> "banned regions" + #15> description: #16> > + determine whether the resource origin is in the banned + list. If the region is also in the permitted list, the + ban has no effect. + #17> variables: + - #18> name: #19> "banned_regions" + #20> expression: #21> "{'us': false, 'ru': false, 'ir': false}" + #22> match: + - #24> #23> condition: #25> | + resource.origin in variables.banned_regions && + !(resource.origin in variables.permitted_regions) + #26> output: #27> "{'banned': true}" + - #29> #28> condition: #30> resource.origin in variables.permitted_regions + #31> output: #32> "{'banned': false}" + - #34> #33> output: #35> "{'banned': true}" + #36> explanation: #37> "'resource is in the banned region ' + resource.origin" + =========================================================== + name: #1> "nested_rule" + description: "nested_rule.yaml" + #3> rule: { + variable: { + name: #6> "permitted_regions" + expression: #8> "['us', 'uk', 'es']" + } + #10> match: { + result: + #12> rule: { + rule_id: #14> "banned regions" + description: #16> "determine whether the resource origin is in the banned list. If the region is also in the permitted list, the ban has no effect. + " + variable: { + name: #19> "banned_regions" + expression: #21> "{'us': false, 'ru': false, 'ir': false}" + } + #23> match: { + condition: #25> "resource.origin in variables.banned_regions && + !(resource.origin in variables.permitted_regions) + " + result: { + output: #27> "{'banned': true}" + } + } + } + } + #28> match: { + condition: #30> "resource.origin in variables.permitted_regions" + result: { + output: #32> "{'banned': false}" + } + } + #33> match: { + result: { + output: #35> "{'banned': true}" + explanation: #37> "'resource is in the banned region ' + resource.origin" + } + } + } +} diff --git a/policy/yaml_policy_parser.cc b/policy/yaml_policy_parser.cc new file mode 100644 index 000000000..c838cff33 --- /dev/null +++ b/policy/yaml_policy_parser.cc @@ -0,0 +1,411 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "policy/yaml_policy_parser.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "common/source.h" +#include "internal/status_macros.h" +#include "policy/cel_policy.h" +#include "policy/cel_policy_parse_context.h" +#include "policy/cel_policy_parse_result.h" +#include "policy/cel_policy_parser.h" +#include "yaml-cpp/exceptions.h" +#include "yaml-cpp/node/node.h" +#include "yaml-cpp/node/parse.h" +#include "yaml-cpp/null.h" +#include "yaml-cpp/yaml.h" // IWYU pragma: keep + +namespace cel { + +CelPolicyElementId YamlPolicyParser::CollectMetadata( + CelPolicyParseContext& ctx, const YAML::Node& node) const { + CelPolicyElementId element_id = ctx.next_element_id(); + if (!node.Mark().is_null()) { + ctx.policy_source().NoteSourcePosition(element_id, node.Mark().pos); + } + return element_id; +} + +std::optional YamlPolicyParser::GetValueString( + CelPolicyParseContext& ctx, const YAML::Node& node, + std::string_view error_message) const { + if (!node.IsDefined()) { + // This should never happen since the YAML syntax has already been checked. + return std::nullopt; + } + + CelPolicyElementId id = CollectMetadata(ctx, node); + if (!node.IsScalar()) { + ctx.ReportError(id, error_message); + return std::nullopt; + } + + try { + return ValueString(id, node.as()); + } catch (YAML::Exception& e) { + // This should never happen since we already checked that the node is a + // scalar and all scalars can be converted to strings. + return std::nullopt; + } +} + +absl::Status YamlPolicyParser::ParsePolicy(CelPolicyParseContext& ctx) const { + const Source* source = ctx.policy_source().content(); + if (source == nullptr) { + return absl::OkStatus(); + } + + ctx.policy().set_description(ValueString(-1, source->description())); + std::string text = source->content().ToString(); + YAML::Node node; + try { + node = YAML::Load(text); + } catch (YAML::Exception& e) { + if (!e.mark.is_null()) { + ctx.policy_source().NoteSourcePosition(0, e.mark.pos); + } + ctx.ReportError(0, "Invalid CEL policy YAML syntax"); + return absl::OkStatus(); + } + + if (!node.IsMap()) { + ctx.ReportError(CollectMetadata(ctx, node), "Policy is not a map"); + return absl::OkStatus(); + } + + for (auto it = node.begin(); it != node.end(); ++it) { + const YAML::Node key_node = it->first; + const YAML::Node value_node = it->second; + std::optional key = + GetValueString(ctx, key_node, "Policy tag is not a string"); + if (!key.has_value()) { + continue; + } + CEL_ASSIGN_OR_RETURN(bool handled, ParsePolicyTag(ctx, *key, value_node)); + if (!handled) { + ctx.ReportError( + key->id(), + absl::StrCat("Unrecognized top-level policy tag: ", key->value())); + } + } + + return absl::OkStatus(); +} + +absl::StatusOr YamlPolicyParser::ParsePolicyTag( + CelPolicyParseContext& ctx, const ValueString& tag_name, + const YAML::Node& node) const { + if (tag_name.value() == "imports") { + CEL_RETURN_IF_ERROR(ParseImports(ctx, node)); + return true; + } + if (tag_name.value() == "name") { + std::optional name = + GetValueString(ctx, node, "Policy 'name' is not a string"); + if (name.has_value()) { + ctx.policy().set_name(*name); + } + return true; + } + if (tag_name.value() == "description") { + std::optional description = + GetValueString(ctx, node, "Policy 'description' is not a string"); + if (description.has_value()) { + ctx.policy().set_description(*description); + } + return true; + } + if (tag_name.value() == "display_name") { + std::optional display_name = + GetValueString(ctx, node, "Policy 'display_name' is not a string"); + if (display_name.has_value()) { + ctx.policy().set_display_name(*display_name); + } + return true; + } + if (tag_name.value() == "rule") { + CEL_RETURN_IF_ERROR(ParseRule(ctx, node, ctx.policy().mutable_rule())); + return true; + } + return false; +} + +absl::Status YamlPolicyParser::ParseImports(CelPolicyParseContext& ctx, + const YAML::Node& node) const { + if (!node.IsSequence()) { + ctx.ReportError(CollectMetadata(ctx, node), + "Policy 'imports' is not a sequence"); + return absl::OkStatus(); + } + + for (const YAML::Node& import : node) { + CelPolicyElementId import_id = CollectMetadata(ctx, import); + if (!import.IsMap()) { + ctx.ReportError(import_id, "Import is not a map"); + continue; + } + const YAML::Node& name_node = import["name"]; + if (!name_node.IsDefined()) { + ctx.ReportError(import_id, "No 'name' tag in import"); + continue; + } + std::optional import_name = + GetValueString(ctx, name_node, "Import name is not a string"); + if (import_name.has_value()) { + ctx.policy().mutable_imports().push_back(Import(import_id, *import_name)); + } + } + return absl::OkStatus(); +} + +absl::Status YamlPolicyParser::ParseRule(CelPolicyParseContext& ctx, + const YAML::Node& node, + Rule& rule) const { + if (!node.IsMap()) { + ctx.ReportError(CollectMetadata(ctx, node), "Policy 'rule' is not a map"); + return absl::OkStatus(); + } + rule.set_id(CollectMetadata(ctx, node)); + + for (auto it = node.begin(); it != node.end(); ++it) { + const YAML::Node key_node = it->first; + const YAML::Node value_node = it->second; + std::optional key = + GetValueString(ctx, key_node, "Policy rule tag is not a string"); + if (!key.has_value()) { + continue; + } + CEL_ASSIGN_OR_RETURN(bool handled, + ParseRuleTag(ctx, *key, value_node, rule)); + if (!handled) { + ctx.ReportError(key->id(), absl::StrCat("Unrecognized policy rule tag: ", + key->value())); + } + } + return absl::OkStatus(); +} + +absl::StatusOr YamlPolicyParser::ParseRuleTag(CelPolicyParseContext& ctx, + const ValueString& tag_name, + const YAML::Node& node, + Rule& rule) const { + if (tag_name.value() == "id") { + std::optional rule_id = + GetValueString(ctx, node, "Policy rule 'id' is not a string"); + if (rule_id.has_value()) { + rule.set_rule_id(*rule_id); + } + return true; + } + if (tag_name.value() == "description") { + std::optional description = + GetValueString(ctx, node, "Policy rule 'description' is not a string"); + if (description.has_value()) { + rule.set_description(*description); + } + return true; + } + if (tag_name.value() == "variables") { + if (!node.IsSequence()) { + ctx.ReportError(CollectMetadata(ctx, node), + "Policy rule 'variables' is not a sequence"); + return true; + } + for (const YAML::Node& variable_node : node) { + CEL_ASSIGN_OR_RETURN(Variable variable, + ParseVariable(ctx, variable_node, rule)); + rule.mutable_variables().push_back(std::move(variable)); + } + return true; + } + if (tag_name.value() == "match") { + if (!node.IsSequence()) { + ctx.ReportError(CollectMetadata(ctx, node), + "Policy rule 'match' is not a sequence"); + return true; + } + for (const YAML::Node& match_node : node) { + CEL_ASSIGN_OR_RETURN(Match match, ParseMatch(ctx, match_node, rule)); + rule.mutable_matches().push_back(std::move(match)); + } + return true; + } + return false; +} + +absl::StatusOr YamlPolicyParser::ParseVariable( + CelPolicyParseContext& ctx, const YAML::Node& node, Rule& rule) const { + Variable variable; + if (!node.IsMap()) { + ctx.ReportError(CollectMetadata(ctx, node), + "Policy rule 'variable' is not a map"); + return variable; + } + for (auto it = node.begin(); it != node.end(); ++it) { + const YAML::Node key_node = it->first; + const YAML::Node value_node = it->second; + std::optional key = + GetValueString(ctx, key_node, "Policy variable tag is not a string"); + if (!key.has_value()) { + continue; + } + CEL_ASSIGN_OR_RETURN(bool handled, + ParseVariableTag(ctx, *key, value_node, variable)); + if (!handled) { + ctx.ReportError( + key->id(), + absl::StrCat("Unrecognized policy variable tag: ", key->value())); + } + } + return variable; +} + +absl::StatusOr YamlPolicyParser::ParseVariableTag( + CelPolicyParseContext& ctx, const ValueString& tag_name, + const YAML::Node& node, Variable& variable) const { + if (tag_name.value() == "name") { + std::optional name = + GetValueString(ctx, node, "Policy variable 'name' is not a string"); + if (name.has_value()) { + variable.set_name(*name); + } + return true; + } + if (tag_name.value() == "expression") { + std::optional expression = GetValueString( + ctx, node, "Policy variable 'expression' is not a string"); + if (expression.has_value()) { + variable.set_expression(*expression); + } + return true; + } + return false; +} + +absl::StatusOr YamlPolicyParser::ParseMatch(CelPolicyParseContext& ctx, + const YAML::Node& node, + Rule& rule) const { + Match match; + match.set_id(CollectMetadata(ctx, node)); + if (!node.IsMap()) { + ctx.ReportError(match.id(), "Policy rule 'match' is not a map"); + return match; + } + for (auto it = node.begin(); it != node.end(); ++it) { + const YAML::Node key_node = it->first; + const YAML::Node value_node = it->second; + std::optional key = + GetValueString(ctx, key_node, "Policy match tag is not a string"); + if (!key.has_value()) { + continue; + } + CEL_ASSIGN_OR_RETURN(bool handled, + ParseMatchTag(ctx, *key, value_node, match, rule)); + if (!handled) { + ctx.ReportError(key->id(), absl::StrCat("Unrecognized policy match tag: ", + key->value())); + } + } + + if (match.has_output_block()) { + if (match.output_block().output().value().empty() && + match.output_block().explanation().has_value()) { + ctx.ReportError(match.id(), "Match specifies explanation but no output"); + } + } + + return match; +} + +absl::StatusOr YamlPolicyParser::ParseMatchTag( + CelPolicyParseContext& ctx, const ValueString& tag_name, + const YAML::Node& node, Match& match, Rule& rule) const { + if (tag_name.value() == "condition") { + std::optional condition = + GetValueString(ctx, node, "Policy match 'condition' is not a string"); + if (condition.has_value()) { + match.set_condition(*condition); + } + return true; + } + if (tag_name.value() == "explanation") { + std::optional explanation = + GetValueString(ctx, node, "Policy match 'explanation' is not a string"); + if (explanation.has_value()) { + if (match.has_rule()) { + ctx.ReportError( + tag_name.id(), + "Cannot specify explanation when a nested rule is present"); + } else { + match.mutable_output_block().set_explanation(*explanation); + } + } + return true; + } + if (tag_name.value() == "output") { + std::optional output = + GetValueString(ctx, node, "Policy match 'output' is not a string"); + if (output.has_value()) { + if (match.has_rule()) { + ctx.ReportError(tag_name.id(), + "Cannot specify output when a nested rule is present"); + } else { + match.mutable_output_block().set_output(*output); + } + } + return true; + } + if (tag_name.value() == "rule") { + if (match.has_output_block()) { + ctx.ReportError(tag_name.id(), + "Cannot specify nested rule when output/explanation is " + "present"); + } + auto nested_rule = std::make_unique(); + CEL_RETURN_IF_ERROR(ParseRule(ctx, node, *nested_rule)); + match.set_result(std::move(nested_rule)); + return true; + } + return false; +} + +const CelPolicyParser& GetDefaultYamlPolicyParser() { + static const auto* const parser = new YamlPolicyParser(); + return *parser; +} + +absl::StatusOr ParseYamlCelPolicy( + std::shared_ptr policy_source) { + return ParseYamlCelPolicy(std::move(policy_source), + GetDefaultYamlPolicyParser()); +} + +absl::StatusOr ParseYamlCelPolicy( + std::shared_ptr policy_source, + const CelPolicyParser& parser) { + CelPolicyParseContext ctx(std::move(policy_source)); + CEL_RETURN_IF_ERROR(parser.ParsePolicy(ctx)); + return ctx.GetResult(); +} + +} // namespace cel diff --git a/policy/yaml_policy_parser.h b/policy/yaml_policy_parser.h new file mode 100644 index 000000000..469209333 --- /dev/null +++ b/policy/yaml_policy_parser.h @@ -0,0 +1,135 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_POLICY_YAML_POLICY_PARSER_H_ +#define THIRD_PARTY_CEL_CPP_POLICY_YAML_POLICY_PARSER_H_ + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "policy/cel_policy.h" +#include "policy/cel_policy_parse_context.h" +#include "policy/cel_policy_parse_result.h" +#include "policy/cel_policy_parser.h" +#include "yaml-cpp/node/node.h" + +namespace cel { + +// A parser for YAML-based CEL policies. +// +// To support additional or alternative YAML elements, subclass +// `YamlPolicyParser` and override specific parsing methods, `Parse*` +class YamlPolicyParser : public CelPolicyParser { + public: + std::optional GetValueString( + CelPolicyParseContext& ctx, const YAML::Node& node, + std::string_view error_message) const; + + absl::Status ParsePolicy(CelPolicyParseContext& ctx) const override; + + protected: + // Collects metadata (e.g. source position) for the given YAML node, stores it + // in the context, and returns an ID that can be used to refer to it. + virtual CelPolicyElementId CollectMetadata(CelPolicyParseContext& ctx, + const YAML::Node& node) const; + + // Parses a top-level tag in the policy YAML. + // Returns true if the tag was handled. + // + // Note that an OkStatus does not necessarily mean that parsing was successful + // - only that it can continue. + virtual absl::StatusOr ParsePolicyTag(CelPolicyParseContext& ctx, + const ValueString& tag_name, + const YAML::Node& node) const; + + // Parses the imports section of the policy YAML. + // + // Note that an OkStatus does not necessarily mean that parsing was successful + // - only that it can continue. + virtual absl::Status ParseImports(CelPolicyParseContext& ctx, + const YAML::Node& node) const; + + // Parses a rule element of the policy YAML, which may be the top-level rule + // or a sub-rule of a match. + // + // Note that an OkStatus does not necessarily mean that parsing was successful + // - only that it can continue. + virtual absl::Status ParseRule(CelPolicyParseContext& ctx, + const YAML::Node& node, Rule& rule) const; + + // Parses a tag in a policy YAML rule. + // Returns true if the tag was handled. + // + // Note that an OkStatus does not necessarily mean that parsing was successful + // - only that it can continue. + virtual absl::StatusOr ParseRuleTag(CelPolicyParseContext& ctx, + const ValueString& tag_name, + const YAML::Node& node, + Rule& rule) const; + + // Parses a variable element of the policy YAML. + // + // Note that an OkStatus does not necessarily mean that parsing was successful + // - only that it can continue. + virtual absl::StatusOr ParseVariable(CelPolicyParseContext& ctx, + const YAML::Node& node, + Rule& rule) const; + + // Parses a tag in a policy YAML variable. + // Returns true if the tag was handled. + // + // Note that an OkStatus does not necessarily mean that parsing was successful + // - only that it can continue. + virtual absl::StatusOr ParseVariableTag(CelPolicyParseContext& ctx, + const ValueString& tag_name, + const YAML::Node& node, + Variable& variable) const; + + // Parses a match element of the policy YAML. + // + // Note that an OkStatus does not necessarily mean that parsing was successful + // - only that it can continue. + virtual absl::StatusOr ParseMatch(CelPolicyParseContext& ctx, + const YAML::Node& node, + Rule& rule) const; + + // Parses a tag in a policy YAML match. + // Returns true if the tag was handled. + // + // Note that an OkStatus does not necessarily mean that parsing was successful + // - only that it can continue. + virtual absl::StatusOr ParseMatchTag(CelPolicyParseContext& ctx, + const ValueString& tag_name, + const YAML::Node& node, + Match& match, Rule& rule) const; +}; + +// Returns a default implementation of YamlPolicyParser. +const CelPolicyParser& GetDefaultYamlPolicyParser(); + +absl::StatusOr ParseYamlCelPolicy( + std::shared_ptr policy_source, + const CelPolicyParser& parser); + +// YAML CelPolicy parser that uses the default format as implemented by +// `YamlPolicyParser`. +absl::StatusOr ParseYamlCelPolicy( + std::shared_ptr policy_source); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_POLICY_YAML_POLICY_PARSER_H_ diff --git a/policy/yaml_policy_parser_test.cc b/policy/yaml_policy_parser_test.cc new file mode 100644 index 000000000..4e7dfc49c --- /dev/null +++ b/policy/yaml_policy_parser_test.cc @@ -0,0 +1,305 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "policy/yaml_policy_parser.h" + +#include +#include +#include +#include +#include + +#include "absl/log/absl_log.h" +#include "absl/status/status_matchers.h" +#include "absl/strings/ascii.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "common/source.h" +#include "internal/runfiles.h" +#include "internal/testing.h" +#include "policy/cel_policy.h" +#include "policy/cel_policy_parse_result.h" +#include "policy/cel_policy_parser.h" +#include "yaml-cpp/node/node.h" + +namespace cel { + +namespace internal { +const CelPolicyParser& GetTestCustomYamlPolicyParser(); +} // namespace internal + +namespace { + +using ::absl_testing::IsOk; +using ::testing::HasSubstr; +using ::testing::IsNull; + +constexpr absl::string_view kTestPolicyFilePath = +"_main/policy/testdata/"; + +constexpr absl::string_view kBaselineSeparator = + "--------------------------------------------------------------------\n"; + +struct YamlPolicyParserTestCase { + std::string policy_source_file; + std::string baseline_file; + const cel::CelPolicyParser& (*parser_factory)(); +}; + +using YamlPolicyParserTest = testing::TestWithParam; + +TEST_P(YamlPolicyParserTest, Parse) { + std::string contents; + std::string test_file = cel::internal::ResolveRunfilesPath( + absl::StrCat(kTestPolicyFilePath, GetParam().policy_source_file)); + ASSERT_THAT(cel::internal::GetFileContents(test_file, &contents), IsOk()); + + std::string baseline; + std::string baseline_file = cel::internal::ResolveRunfilesPath( + absl::StrCat(kTestPolicyFilePath, GetParam().baseline_file)); + ASSERT_THAT(cel::internal::GetFileContents(baseline_file, &baseline), IsOk()); + baseline = absl::StripAsciiWhitespace(baseline); + + std::ostringstream out; + out << "POLICY SOURCE: " << GetParam().policy_source_file << "\n"; + + ASSERT_OK_AND_ASSIGN(cel::SourcePtr source, + cel::NewSource(contents, GetParam().policy_source_file)); + std::shared_ptr policy_source = + std::make_shared(std::move(source)); + + ASSERT_OK_AND_ASSIGN( + CelPolicyParseResult parse_result, + cel::ParseYamlCelPolicy(policy_source, GetParam().parser_factory())); + + out << kBaselineSeparator; + if (parse_result.IsValid()) { + out << "PARSED POLICY:\n"; + out << parse_result.GetPolicy()->DebugString(); + } else { + ASSERT_THAT(parse_result.GetPolicy(), IsNull()); + out << kBaselineSeparator; + out << "PARSER ISSUES:\n"; + for (const auto& issue : parse_result.GetIssues()) { + out << issue.ToDisplayString(*policy_source) << "\n"; + } + } + + std::string actual(absl::StripAsciiWhitespace(out.str())); + if (actual != baseline) { + // Log the actual result to make it easier to copy/paste into the baseline + // file when updating the tests. + ABSL_LOG(INFO) << "Actual:\n" << actual; + EXPECT_EQ(actual, baseline); + } +} + +INSTANTIATE_TEST_SUITE_P( + Formats, YamlPolicyParserTest, + testing::ValuesIn({ + YamlPolicyParserTestCase{ + .policy_source_file = "cel_policy.yaml", + .baseline_file = "cel_policy_parser.baseline", + .parser_factory = GetDefaultYamlPolicyParser, + }, + YamlPolicyParserTestCase{ + .policy_source_file = "nested_rule.yaml", + .baseline_file = "nested_rule_parser.baseline", + .parser_factory = GetDefaultYamlPolicyParser, + }, + YamlPolicyParserTestCase{ + .policy_source_file = "custom_policy_format.yaml", + .baseline_file = "custom_policy_format_parser.baseline", + .parser_factory = internal::GetTestCustomYamlPolicyParser, + }, + YamlPolicyParserTestCase{ + .policy_source_file = "custom_policy_format_with_errors.yaml", + .baseline_file = "custom_policy_format_with_errors_parser.baseline", + .parser_factory = internal::GetTestCustomYamlPolicyParser, + }, + })); + +struct ParseTestCase { + std::string yaml; + std::string expected_error; +}; + +using YamlPolicyParseErrorTest = testing::TestWithParam; + +TEST_P(YamlPolicyParseErrorTest, YamlSyntaxError) { + const ParseTestCase& param = GetParam(); + ASSERT_OK_AND_ASSIGN(cel::SourcePtr source, + cel::NewSource(param.yaml, "test")); + std::shared_ptr policy_source = + std::make_shared(std::move(source)); + ASSERT_OK_AND_ASSIGN(CelPolicyParseResult parse_result, + cel::ParseYamlCelPolicy(policy_source)); + EXPECT_THAT(parse_result.FormattedIssues(), HasSubstr(param.expected_error)); +} + +std::vector GetParseTestCases() { + return { + ParseTestCase{ + .yaml = R"yaml( ? [ John, Doe ]: age: 30 )yaml", + .expected_error = "1:22: Invalid CEL policy YAML syntax\n" + " | ? [ John, Doe ]: age: 30 \n" + " | .....................^", + }, + ParseTestCase{ + .yaml = R"yaml( invalid yaml )yaml", + .expected_error = "1:2: Policy is not a map\n" + " | invalid yaml \n" + " | .^", + }, + ParseTestCase{ + .yaml = R"yaml( + ? [1, 2, 3] + : "Prime numbers sequence" + )yaml", + .expected_error = "2:23: Policy tag is not a string\n" + " | ? [1, 2, 3]\n" + " | ......................^", + }, + ParseTestCase{ + .yaml = R"yaml( + imports: N/A + )yaml", + .expected_error = "2:28: Policy 'imports' is not a sequence\n" + " | imports: N/A\n" + " | ...........................^", + }, + ParseTestCase{ + .yaml = R"yaml( + imports: + - cel.expr.conformance + )yaml", + .expected_error = "3:21: Import is not a map\n" + " | - cel.expr.conformance\n" + " | ....................^", + }, + ParseTestCase{ + .yaml = R"yaml( + imports: + - name: + - cel.expr.conformance + )yaml", + .expected_error = "4:21: Import name is not a string\n" + " | - cel.expr.conformance\n" + " | ....................^", + }, + ParseTestCase{ + .yaml = R"yaml( + rule: do something + )yaml", + .expected_error = "2:25: Policy 'rule' is not a map\n" + " | rule: do something\n" + " | ........................^", + }, + ParseTestCase{ + .yaml = R"yaml( + rule: + id: + - 22 + )yaml", + .expected_error = "4:21: Policy rule 'id' is not a string\n" + " | - 22\n" + " | ....................^", + }, + ParseTestCase{ + .yaml = R"yaml( + rule: + variables: + no vars + )yaml", + .expected_error = "4:23: Policy rule 'variables' is not a sequence\n" + " | no vars\n" + " | ......................^", + }, + ParseTestCase{ + .yaml = R"yaml( + rule: + variables: + - name: + foo: bar + )yaml", + .expected_error = "5:25: Policy variable 'name' is not a string\n" + " | foo: bar\n" + " | ........................^", + }, + ParseTestCase{ + .yaml = R"yaml( + rule: + variables: + - name: test_var + expression: + - 22 + )yaml", + .expected_error = + "6:23: Policy variable 'expression' is not a string\n" + " | - 22\n" + " | ......................^", + }, + ParseTestCase{ + .yaml = R"yaml( + rule: + variables: + - name: '\u0041\u00a9\u20ac\U0001f680' + - '\u0041\u00a9\u20ac\U0001f680': name + )yaml", + .expected_error = + "5:23: Unrecognized policy variable tag: " + "\\u0041\\u00a9\\u20ac\\U0001f680\n" + " | - '\\u0041\\u00a9\\u20ac\\U0001f680': " + "name\n" + " | ......................^", + }, + }; +} + +INSTANTIATE_TEST_SUITE_P(YamlPolicyParseErrorTest, YamlPolicyParseErrorTest, + ::testing::ValuesIn(GetParseTestCases())); + +TEST(YamlPolicyParserTest, OffsetIssueFormatting) { + // TODO(b/506179116): will need to copy the go implementation in extracting + // the source string from the YAML document instead of the interpreted string + // value to fix up error locations in folded and block literals. + std::string contents; + std::string test_file = cel::internal::ResolveRunfilesPath( + absl::StrCat(kTestPolicyFilePath, "cel_policy.yaml")); + ASSERT_THAT(cel::internal::GetFileContents(test_file, &contents), IsOk()); + + ASSERT_OK_AND_ASSIGN(cel::SourcePtr source, + cel::NewSource(contents, "cel_policy.yaml")); + std::shared_ptr policy_source = + std::make_shared(std::move(source)); + ASSERT_OK_AND_ASSIGN(CelPolicyParseResult parse_result, + cel::ParseYamlCelPolicy(policy_source)); + + ASSERT_TRUE(parse_result.IsValid()); + const CelPolicy* policy = parse_result.GetPolicy(); + + CelPolicyElementId name_id = policy->name().id(); + + CelPolicyIssue issue(name_id, 4, CelPolicyIssue::Severity::kError, + "Test error"); + + std::string formatted = issue.ToDisplayString(*policy_source); + + EXPECT_THAT(formatted, HasSubstr("ERROR: cel_policy.yaml:16:11: Test error")); + EXPECT_THAT(formatted, HasSubstr(" | name: cel_policy")); + EXPECT_THAT(formatted, HasSubstr(" | ..........^")); +} + +} // namespace +} // namespace cel diff --git a/protoutil/BUILD b/protoutil/BUILD deleted file mode 100644 index 60ae22e85..000000000 --- a/protoutil/BUILD +++ /dev/null @@ -1,90 +0,0 @@ -# Description -# Libraries for working with protobuffer messages and well-known types. -# -# Uses the namespace google::api::expr::protoutil. - -package(default_visibility = ["//visibility:public"]) - -licenses(["notice"]) # Apache 2.0 - -cc_library( - name = "type_registry", - srcs = ["type_registry.cc"], - hdrs = [ - "type_registry.h", - ], - deps = [ - "//common:macros", - "//common:parent_ref", - "//common:type", - "//common:value", - "//internal:cast", - "//internal:map_impl", - "//internal:proto_util", - "@com_google_absl//absl/container:node_hash_map", - "@com_google_protobuf//:protobuf", - ], -) - -cc_test( - name = "type_registry_test", - srcs = ["type_registry_test.cc"], - deps = [ - ":type_registry", - "//common:value", - "//internal:ref_countable", - "//testutil:util", - "@com_google_googleapis//google/type:money_cc_proto", - "@com_google_googletest//:gtest_main", - "@com_google_protobuf//:protobuf", - ], -) - -cc_library( - name = "converters", - srcs = [ - "converters.cc", - ], - hdrs = [ - "converters.h", - ], - deps = [ - ":type_registry", - "//common:macros", - "//common:parent_ref", - "//common:value", - "//internal:holder", - "//internal:list_impl", - "//internal:map_impl", - "//internal:proto_util", - "//internal:ref_countable", - "//internal:status_util", - "//internal:types", - "@com_google_absl//absl/container:node_hash_map", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@com_google_protobuf//:protobuf", - ], -) - -cc_test( - name = "converters_test", - srcs = ["converters_test.cc"], - data = [ - "@com_google_cel_spec//testdata", - ], - deps = [ - ":converters", - ":type_registry", - "//common:value", - "//testutil:test_data_io", - "//testutil:test_data_util", - "//testutil:util", - "@com_google_absl//absl/memory", - "@com_google_cel_spec//testdata:test_data_cc_proto", - "@com_google_cel_spec//testdata:test_value_cc_proto", - "@com_google_googleapis//google/type:money_cc_proto", - "@com_google_googletest//:gtest_main", - "@com_google_protobuf//:protobuf", - ], -) diff --git a/protoutil/converters.cc b/protoutil/converters.cc deleted file mode 100644 index c66f55a6c..000000000 --- a/protoutil/converters.cc +++ /dev/null @@ -1,264 +0,0 @@ -#include "protoutil/converters.h" - -#include -#include - -#include "google/protobuf/descriptor.h" -#include "google/protobuf/util/message_differencer.h" -#include "absl/container/node_hash_map.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "common/macros.h" -#include "internal/holder.h" -#include "internal/map_impl.h" -#include "internal/proto_util.h" -#include "internal/ref_countable.h" -#include "internal/status_util.h" - -namespace google { -namespace api { -namespace expr { -namespace protoutil { - -using ::google::api::expr::internal::OkStatus; - -namespace { - -template -bool RegisterFromCall(TypeRegistry* registry) { - return registry->RegisterConstructor( - [](const T& value) { return ValueFrom(value); }); -} - -template -bool RegisterFromPtrCall(TypeRegistry* registry) { - return registry->RegisterConstructor( - [](std::unique_ptr value) { return ValueFrom(std::move(value)); }); -} - -template -bool RegisterForCall(TypeRegistry* registry) { - return registry->RegisterConstructor( - [](const T* value, const common::RefProvider& parent) { - return ValueFor(value, parent); - }); -} - -template -class ListValue final : public common::List { - public: - template - explicit ListValue(Args&&... args) : holder_(std::forward(args)...) {} - - std::size_t size() const override { return holder_.value().values_size(); } - - common::Value Get(std::size_t index) const override { - if (index >= static_cast(holder_.value().values_size())) { - return common::Value::FromError( - internal::OutOfRangeError(index, holder_.value().values_size())); - } - return ValueFor(&holder_.value().values(index), SelfRefProvider()); - } - - google::rpc::Status ForEach( - const std::function& call) - const override { - for (const auto& elem : holder_.value().values()) { - RETURN_IF_STATUS_ERROR(call(ValueFor(&elem, SelfRefProvider()))); - } - return OkStatus(); - } - - bool owns_value() const override { return HolderPolicy::kOwnsValue; } - - private: - internal::Holder holder_; -}; - -template -class Struct final : public common::Map { - public: - template - explicit Struct(Args&&... args) : holder_(std::forward(args)...) {} - - inline std::size_t size() const override { - return holder_.value().fields_size(); - } - - google::rpc::Status ForEach( - const std::function& call) const override { - for (const auto& field : holder_.value().fields()) { - RETURN_IF_STATUS_ERROR( - call(common::Value::ForString(field.first, SelfRefProvider()), - ValueFor(field.second, SelfRefProvider()))); - } - return internal::OkStatus(); - } - - inline bool owns_value() const override { return true; } - - protected: - common::Value GetImpl(const common::Value& key) const override; - - private: - internal::Holder holder_; -}; - -common::Value BuildMapFor(const google::protobuf::Struct* struct_value, - common::ParentRef parent) { - absl::node_hash_map result; - for (const auto& field : struct_value->fields()) { - result.emplace(common::Value::ForString(field.first, parent), - ValueFor(&field.second, parent)); - } - // The keys and values grabbed a ref on parent if needed, so we don't need one - // separately. - return common::Value::MakeMap(std::move(result)); -} - -common::Value BuildMapFrom(google::protobuf::Struct&& struct_value) { - absl::node_hash_map result; - for (auto& fields : *struct_value.mutable_fields()) { - result.emplace(common::Value::FromString(fields.first), - ValueFrom(std::move(fields.second))); - } - return common::Value::MakeMap(std::move(result)); -} - -common::Value BuildMapFrom(const google::protobuf::Struct& struct_value) { - absl::node_hash_map result; - for (const auto& fields : struct_value.fields()) { - result.emplace(common::Value::FromString(fields.first), - ValueFrom(fields.second)); - } - return common::Value::MakeMap(std::move(result)); -} - -} // namespace - -// Converters for google::protobuf::Value. -common::Value ValueFrom(const google::protobuf::Value& value) { - switch (value.kind_case()) { - case google::protobuf::Value::kNullValue: - return common::Value::NullValue(); - case google::protobuf::Value::kBoolValue: - return common::Value::FromBool(value.bool_value()); - case google::protobuf::Value::kNumberValue: - return common::Value::FromDouble(value.number_value()); - case google::protobuf::Value::kStringValue: - return common::Value::FromString(value.string_value()); - case google::protobuf::Value::kStructValue: - return ValueFrom(value.struct_value()); - default: - return common::Value::FromError( - internal::UnimplementedError(absl::StrCat(value.kind_case()))); - } -} - -common::Value ValueFrom(google::protobuf::Value&& value) { - switch (value.kind_case()) { - case google::protobuf::Value::kStructValue: - return ValueFrom(absl::WrapUnique(value.release_struct_value())); - case google::protobuf::Value::kListValue: - return ValueFrom(absl::WrapUnique(value.release_list_value())); - default: - return ValueFrom(value); - } -} - -common::Value ValueFrom(std::unique_ptr value) { - return ValueFrom(std::move(*value)); -} - -common::Value ValueFor(const google::protobuf::Value* value, - common::ParentRef parent) { - switch (value->kind_case()) { - case google::protobuf::Value::kStructValue: - return ValueFor(&value->struct_value(), parent); - case google::protobuf::Value::kListValue: - return ValueFor(&value->list_value(), parent); - default: - return ValueFrom(*value); - } -} - -// Converters for google::protobuf::Struct. -common::Value ValueFrom(const google::protobuf::Struct& value) { - return BuildMapFrom(value); -} - -common::Value ValueFrom(std::unique_ptr value) { - return BuildMapFrom(std::move(*value)); -} -common::Value ValueFor(const google::protobuf::Struct* value, - common::ParentRef parent) { - return BuildMapFor(value, parent); -} - -// Converters for google::protobuf::ListValue -common::Value ValueFrom(const google::protobuf::ListValue& value) { - return common::Value::MakeList>(value); -} -common::Value ValueFrom(std::unique_ptr value) { - return common::Value::MakeList>( - std::move(value)); -} -common::Value ValueFor(const google::protobuf::ListValue* value, - common::ParentRef parent) { - if (!parent) { - return ValueFrom(*value); - } - if (parent->RequiresReference()) { - return common::Value::MakeList>>( - parent->GetRef(), value); - } - return common::Value::MakeList>(value); -} - -// Converters for time/duration. -common::Value ValueFrom(const google::protobuf::Timestamp& value) { - return common::Value::FromTime(internal::DecodeTime(value)); -} - -common::Value ValueFrom(const google::protobuf::Duration& value) { - return common::Value::FromDuration(internal::DecodeDuration(value)); -} - -bool RegisterConvertersWith(TypeRegistry* registry) { - bool success = true; - success &= registry->RegisterConstructor( - common::EnumType(google::protobuf::NullValue_descriptor()), - [](common::EnumType, int32_t) { return common::Value::NullValue(); }); - success &= RegisterFromCall(registry); - success &= RegisterFromPtrCall(registry); - success &= RegisterForCall(registry); - - success &= RegisterFromCall(registry); - success &= RegisterFromPtrCall(registry); - success &= RegisterForCall(registry); - - success &= RegisterFromCall(registry); - success &= RegisterFromPtrCall(registry); - success &= RegisterForCall(registry); - - success &= RegisterFromCall(registry); - success &= RegisterFromCall(registry); - success &= RegisterFromCall(registry); - success &= RegisterFromCall(registry); - success &= RegisterFromCall(registry); - success &= RegisterFromCall(registry); - success &= RegisterFromCall(registry); - success &= RegisterFromCall(registry); - success &= RegisterFromCall(registry); - success &= RegisterFromCall(registry); - success &= RegisterFromCall(registry); - return success; -} - -} // namespace protoutil -} // namespace expr -} // namespace api -} // namespace google diff --git a/protoutil/converters.h b/protoutil/converters.h deleted file mode 100644 index f0ac61ca1..000000000 --- a/protoutil/converters.h +++ /dev/null @@ -1,86 +0,0 @@ -// Converter functions from common c++ representations to Value. - -#ifndef THIRD_PARTY_CEL_CPP_PROTOUTIL_CONVERTERS_H_ -#define THIRD_PARTY_CEL_CPP_PROTOUTIL_CONVERTERS_H_ - -#include - -#include "google/protobuf/any.pb.h" -#include "google/protobuf/duration.pb.h" -#include "google/protobuf/struct.pb.h" -#include "google/protobuf/timestamp.pb.h" -#include "google/protobuf/wrappers.pb.h" -#include "google/protobuf/message.h" -#include "common/parent_ref.h" -#include "common/value.h" -#include "protoutil/type_registry.h" - -namespace google { -namespace api { -namespace expr { -namespace protoutil { - -/** Registers all converter functions with the given type registry. */ -bool RegisterConvertersWith(TypeRegistry* registry); - -// Converters for google::protobuf::Value. -common::Value ValueFrom(const google::protobuf::Value& value); -common::Value ValueFrom(google::protobuf::Value&& value); -common::Value ValueFrom(std::unique_ptr value); -common::Value ValueFor(const google::protobuf::Value* value, - common::ParentRef parent = common::NoParent()); - -// Converters for google::protobuf::Struct. -common::Value ValueFrom(const google::protobuf::Struct& value); -common::Value ValueFrom(std::unique_ptr value); -common::Value ValueFor(const google::protobuf::Struct* value, - common::ParentRef parent = common::NoParent()); - -// Converters for google::protobuf::ListValue -common::Value ValueFrom(const google::protobuf::ListValue& value); -common::Value ValueFrom(std::unique_ptr value); -common::Value ValueFor(const google::protobuf::ListValue* value, - common::ParentRef parent = common::NoParent()); - -// Converters for time/duration. -common::Value ValueFrom(const google::protobuf::Duration& value); -common::Value ValueFrom(const google::protobuf::Timestamp& value); - -// Converters for wrapped values. -inline common::Value ValueFrom(google::protobuf::NullValue value) { - return common::Value::NullValue(); -} -inline common::Value ValueFrom(const google::protobuf::BoolValue& value) { - return common::Value::FromBool(value.value()); -} -inline common::Value ValueFrom(const google::protobuf::Int32Value& value) { - return common::Value::FromInt(value.value()); -} -inline common::Value ValueFrom(const google::protobuf::Int64Value& value) { - return common::Value::FromInt(value.value()); -} -inline common::Value ValueFrom(const google::protobuf::UInt32Value& value) { - return common::Value::FromUInt(value.value()); -} -inline common::Value ValueFrom(const google::protobuf::UInt64Value& value) { - return common::Value::FromUInt(value.value()); -} -inline common::Value ValueFrom(const google::protobuf::FloatValue& value) { - return common::Value::FromDouble(value.value()); -} -inline common::Value ValueFrom(const google::protobuf::DoubleValue& value) { - return common::Value::FromDouble(value.value()); -} -inline common::Value ValueFrom(const google::protobuf::StringValue& value) { - return common::Value::FromString(value.value()); -} -inline common::Value ValueFrom(const google::protobuf::BytesValue& value) { - return common::Value::FromBytes(std::string(value.value())); -} - -} // namespace protoutil -} // namespace expr -} // namespace api -} // namespace google - -#endif // THIRD_PARTY_CEL_CPP_COMMON_INTERNAL_PROTO_CONVERTERS_H_ diff --git a/protoutil/converters_test.cc b/protoutil/converters_test.cc deleted file mode 100644 index 952c05b67..000000000 --- a/protoutil/converters_test.cc +++ /dev/null @@ -1,155 +0,0 @@ -#include "protoutil/converters.h" - -#include - -#include "google/protobuf/any.pb.h" -#include "google/type/money.pb.h" -#include "gtest/gtest.h" -#include "absl/memory/memory.h" -#include "common/value.h" -#include "protoutil/type_registry.h" -#include "testutil/test_data_io.h" -#include "testutil/test_data_util.h" -#include "testdata/test_data.pb.h" -#include "testdata/test_value.pb.h" - -namespace google { -namespace api { -namespace expr { -namespace protoutil { - -using testdata::TestValue; - -namespace { - -const TypeRegistry* kReg = []() { - TypeRegistry* registry = new TypeRegistry; - RegisterConvertersWith(registry); - return registry; -}(); - -TEST(ConverterTest, ForProto) { - google::type::Money money; - money.set_nanos(1); - - common::Value value = kReg->ValueFor(&money); - EXPECT_EQ(value.kind(), common::Value::Kind::kObject); - EXPECT_FALSE(value.is_inline()); - EXPECT_TRUE(value.is_value()); - EXPECT_FALSE(value.owns_value()); - EXPECT_EQ(value.object_value().GetMember("nanos").int_value(), 1); - EXPECT_EQ(value.object_value().object_type().full_name(), - "google.type.Money"); -} - -TEST(ConverterTest, FromProto) { - auto money = absl::make_unique(); - money->set_nanos(1); - - common::Value value = kReg->ValueFrom(std::move(money)); - EXPECT_EQ(value.kind(), common::Value::Kind::kObject); - EXPECT_FALSE(value.is_inline()); - EXPECT_TRUE(value.is_value()); - EXPECT_TRUE(value.owns_value()); - EXPECT_EQ(value.object_value().GetMember("nanos").int_value(), 1); - EXPECT_EQ(value.object_value().object_type().full_name(), - "google.type.Money"); -} - -TEST(ConverterTest, FromProto_Any) { - google::type::Money money; - money.set_nanos(1); - - google::protobuf::Any any; - any.PackFrom(money); - - common::Value value = kReg->ValueFrom(any); - EXPECT_EQ(value.kind(), common::Value::Kind::kObject); - EXPECT_FALSE(value.is_inline()); - EXPECT_TRUE(value.is_value()); - EXPECT_TRUE(value.owns_value()); - EXPECT_EQ(value.object_value().GetMember("nanos").int_value(), 1); - EXPECT_EQ(value.object_value().object_type().full_name(), - "google.type.Money"); -} - -TEST(ConverterTest, FromProto_AnyPtr) { - google::type::Money money; - money.set_nanos(1); - - google::protobuf::Any any; - any.PackFrom(money); - - common::Value value = kReg->ValueFor(&any); - EXPECT_EQ(value.kind(), common::Value::Kind::kObject); - EXPECT_FALSE(value.is_inline()); - EXPECT_TRUE(value.is_value()); - EXPECT_TRUE(value.owns_value()); - EXPECT_EQ(value.object_value().GetMember("nanos").int_value(), 1); - EXPECT_EQ(value.object_value().object_type().full_name(), - "google.type.Money"); -} - -class ValueTest : public ::testing::TestWithParam { - public: - ValueTest() { v1beta1::InitValueDifferencer(&v1beta1_differ_); } - - protected: - private: - google::protobuf::util::MessageDifferencer v1beta1_differ_; -}; - -TEST_P(ValueTest, SelfEqual) { - for (const auto& lhs : GetParam().proto()) { - SCOPED_TRACE(lhs.ShortDebugString()); - auto lhs_value = - kReg->ValueFor(&lhs).object_value().GetMember(lhs.value_field_name()); - SCOPED_TRACE(lhs_value); - for (const auto& rhs : GetParam().proto()) { - SCOPED_TRACE(rhs.ShortDebugString()); - auto rhs_value = - kReg->ValueFor(&rhs).object_value().GetMember(rhs.value_field_name()); - SCOPED_TRACE(rhs_value); - EXPECT_EQ(lhs_value, rhs_value); - } - } -} - -INSTANTIATE_TEST_SUITE_P( - UniqueValues, ValueTest, - ::testing::ValuesIn( - testutil::ReadTestData("unique_values").test_values().values()), - testutil::TestDataParamName()); - -class UniqueValueTest - : public ::testing::TestWithParam> { - public: -}; - -TEST_P(UniqueValueTest, NotEqual) { - for (const auto& lhs : GetParam().first.proto()) { - SCOPED_TRACE(lhs.ShortDebugString()); - auto lhs_value = - kReg->ValueFor(&lhs).object_value().GetMember(lhs.value_field_name()); - SCOPED_TRACE(lhs_value); - for (const auto& rhs : GetParam().second.proto()) { - SCOPED_TRACE(rhs.ShortDebugString()); - auto rhs_value = - kReg->ValueFor(&rhs).object_value().GetMember(rhs.value_field_name()); - SCOPED_TRACE(rhs_value); - EXPECT_NE(lhs_value, rhs_value); - } - } -} - -INSTANTIATE_TEST_SUITE_P( - All, UniqueValueTest, - ::testing::ValuesIn(testutil::AllPairs( - testutil::ReadTestData("unique_values").test_values())), - testutil::TestDataParamName()); - -} // namespace -} // namespace protoutil -} // namespace expr -} // namespace api -} // namespace google diff --git a/protoutil/type_registry.cc b/protoutil/type_registry.cc deleted file mode 100644 index ba5c7854a..000000000 --- a/protoutil/type_registry.cc +++ /dev/null @@ -1,643 +0,0 @@ -#include "protoutil/type_registry.h" - -#include "google/protobuf/reflection.h" -#include "google/protobuf/util/message_differencer.h" -#include "absl/container/node_hash_map.h" -#include "common/macros.h" -#include "internal/map_impl.h" -#include "internal/proto_util.h" - -namespace google { -namespace api { -namespace expr { -namespace protoutil { -namespace { - -common::Value MissingCall(const common::ObjectType& type) { - return common::Value::FromError(internal::InternalError( - absl::StrCat("Missing callback for ", type.value()->full_name()))); -} - -absl::string_view FindObjectType(const google::protobuf::Any* value) { - absl::string_view object_type = value->type_url(); - // Unfortunately the proto2 function that does this is internal. - return object_type.substr(object_type.find('/') + 1); -} - -google::protobuf::util::MessageDifferencer& GetDiffer() { - static google::protobuf::util::MessageDifferencer* differ = []() { - auto* comp = new google::protobuf::util::DefaultFieldComparator(); - comp->set_float_comparison(google::protobuf::util::DefaultFieldComparator::EXACT); - comp->set_treat_nan_as_equal(true); - - auto* differ = new google::protobuf::util::MessageDifferencer; - differ->set_field_comparator(comp); - return differ; - }(); - return *differ; -} - -template -class BaseProtoList : public common::List { - public: - explicit BaseProtoList(const common::ParentRef& parent, - const google::protobuf::RepeatedFieldRef& value) - : parent_ref_(parent->GetRef()), value_(value) {} - - inline std::size_t size() const final { return value_.size(); } - inline bool owns_value() const final { return true; } - - protected: - common::ValueRef parent_ref_; - - google::protobuf::RepeatedFieldRef value_; -}; - -template -class ProtoList final : public BaseProtoList { - public: - explicit ProtoList(const common::ParentRef& parent, - const google::protobuf::RepeatedFieldRef& value) - : BaseProtoList(parent, value) {} - - common::Value Get(std::size_t index) const override { - return common::List::GetValue(this->value_.Get(index)); - } - - google::rpc::Status ForEach( - const std::function& call) - const override { - for (const auto& elem : this->value_) { - RETURN_IF_STATUS_ERROR(call(common::List::GetValue(elem))); - } - return internal::OkStatus(); - } -}; - -class BaseProtoRefList : public common::List { - public: - BaseProtoRefList(const common::ParentRef& parent, const google::protobuf::Message* msg, - const google::protobuf::FieldDescriptor* field) - : parent_ref_(parent->GetRef()), msg_(msg), field_(field) {} - - inline std::size_t size() const final { - return msg_->GetReflection()->FieldSize(*msg_, field_); - } - inline bool owns_value() const final { return true; } - - protected: - common::ValueRef parent_ref_; - const google::protobuf::Message* msg_; - const google::protobuf::FieldDescriptor* field_; -}; - -template -class ProtoStrList final : public BaseProtoRefList { - public: - ProtoStrList(const common::ParentRef& parent, const google::protobuf::Message* msg, - const google::protobuf::FieldDescriptor* field) - : BaseProtoRefList(parent, msg, field) {} - - common::Value Get(std::size_t index) const override { - std::string scratch; - const std::string& value = - msg_->GetReflection()->GetRepeatedStringReference(*msg_, field_, index, - &scratch); - if (&value == &scratch) { - return common::Value::From(value); - } - return common::Value::For(value, SelfRefProvider()); - } -}; - -class ProtoMsgList final : public BaseProtoRefList { - public: - ProtoMsgList(const TypeRegistry* reg, const common::ParentRef& parent, - const google::protobuf::Message* msg, const google::protobuf::FieldDescriptor* field) - : BaseProtoRefList(parent, msg, field), reg_(reg) {} - - common::Value Get(std::size_t index) const override { - return reg_->ValueFor( - &msg_->GetReflection()->GetRepeatedMessage(*msg_, field_, index), - SelfRefProvider()); - } - - private: - const TypeRegistry* reg_; -}; - -class ProtoEnumList final : public BaseProtoRefList { - public: - ProtoEnumList(const TypeRegistry* reg, const common::ParentRef& parent, - const google::protobuf::Message* msg, - const google::protobuf::FieldDescriptor* field) - : BaseProtoRefList(parent, msg, field), reg_(reg) {} - - common::Value Get(std::size_t index) const override { - return reg_->ValueFrom( - common::EnumType(field_->enum_type()), - msg_->GetReflection()->GetRepeatedEnumValue(*msg_, field_, index)); - } - - private: - const TypeRegistry* reg_; -}; - -/** - * A Object for a google.protobuf.Any that could not be decoded. - */ -template -class UnrecognizedMessageObject final : public common::Object { - public: - template - UnrecognizedMessageObject(T&& value) : holder_(std::forward(value)) {} - - common::Value GetMember(absl::string_view name) const override { - return common::Value::FromError( - internal::UnknownType(object_type().full_name())); - }; - - common::Type object_type() const override { - return common::Type(FindObjectType(&holder_.value())); - } - - void To(google::protobuf::Any* value) const override { - *value = holder_.value(); - } - - bool owns_value() const override { return HolderPolicy::kOwnsValue; } - - google::rpc::Status ForEach( - const std::function& call) const override { - return internal::UnknownType(object_type().full_name()); - } - - protected: - std::size_t ComputeHash() const override { - return internal::Hash(object_type().full_name(), holder_->value()); - } - - bool EqualsImpl(const common::Object& same_type) const final { - const UnrecognizedMessageObject* other = - cast_if(&same_type); - if (other == nullptr) { - return false; - } - return GetDiffer().Equals(holder_.value(), other->holder_.value()); - } - - private: - internal::Holder holder_; -}; - -/** - * A Object class for a proto message. - */ -template -class MessageObject final : public common::Object { - public: - template - explicit MessageObject(const TypeRegistry* registry, Args&&... args) - : registry_(registry), holder_(std::forward(args)...) {} - - template - common::Value MakeList(const google::protobuf::RepeatedFieldRef& value) const { - return common::Value::MakeList>(SelfRefProvider(), - value); - } - - common::Value BuildMapFor(const google::protobuf::FieldDescriptor* field, - const google::protobuf::Message* msg) const { - // Proto maps are represented by a repeated message with two fields - // (key and value) - absl::node_hash_map result; - const google::protobuf::Reflection* refl = msg->GetReflection(); - for (int index = 0; index < refl->FieldSize(*msg, field); ++index) { - const auto& entry = refl->GetRepeatedMessage(*msg, field, index); - const google::protobuf::FieldDescriptor* key_field = - entry.GetDescriptor()->FindFieldByNumber(1); - const google::protobuf::FieldDescriptor* value_field = - entry.GetDescriptor()->FindFieldByNumber(2); - result.emplace(GetFieldValue(key_field, &entry), - GetFieldValue(value_field, &entry)); - } - // The keys and values grabbed a ref on parent if needed, so we don't need - // one separately. - return common::Value::MakeMap(std::move(result)); - } - - common::Value GetFieldValue(const google::protobuf::FieldDescriptor* field, - const google::protobuf::Message* msg) const { - const google::protobuf::Reflection* refl = msg->GetReflection(); - if (field->is_map()) { - return BuildMapFor(field, msg); - } - if (field->is_repeated()) { - switch (field->cpp_type()) { - case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: - return MakeList( - refl->GetRepeatedFieldRef(*msg, field)); - case google::protobuf::FieldDescriptor::CPPTYPE_INT32: - return MakeList( - refl->GetRepeatedFieldRef(*msg, field)); - case google::protobuf::FieldDescriptor::CPPTYPE_INT64: - return MakeList( - refl->GetRepeatedFieldRef(*msg, field)); - case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: - return common::Value::MakeList( - registry_, SelfRefProvider(), msg, field); - case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: - return MakeList( - refl->GetRepeatedFieldRef(*msg, field)); - case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: - return MakeList( - refl->GetRepeatedFieldRef(*msg, field)); - case google::protobuf::FieldDescriptor::CPPTYPE_FLOAT: - return MakeList( - refl->GetRepeatedFieldRef(*msg, field)); - case google::protobuf::FieldDescriptor::CPPTYPE_DOUBLE: - return MakeList( - refl->GetRepeatedFieldRef(*msg, field)); - case google::protobuf::FieldDescriptor::CPPTYPE_STRING: - if (field->type() == google::protobuf::FieldDescriptor::TYPE_STRING) { - return common::Value::MakeList< - ProtoStrList>(SelfRefProvider(), - msg, field); - } else { - return common::Value::MakeList< - ProtoStrList>(SelfRefProvider(), - msg, field); - } - case google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE: - return common::Value::MakeList( - registry_, SelfRefProvider(), msg, field); - default: - return common::Value::FromError(internal::UnimplementedError("")); - } - } - - switch (field->cpp_type()) { - case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: - return common::Value::FromBool(refl->GetBool(*msg, field)); - case google::protobuf::FieldDescriptor::CPPTYPE_INT32: - return common::Value::FromInt(refl->GetInt32(*msg, field)); - case google::protobuf::FieldDescriptor::CPPTYPE_INT64: - return common::Value::FromInt(refl->GetInt64(*msg, field)); - case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: - return registry_->ValueFrom(common::EnumType(field->enum_type()), - refl->GetEnumValue(*msg, field)); - case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: - return common::Value::FromUInt(refl->GetUInt32(*msg, field)); - case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: - return common::Value::FromUInt(refl->GetUInt64(*msg, field)); - - case google::protobuf::FieldDescriptor::CPPTYPE_FLOAT: - return common::Value::FromDouble(refl->GetFloat(*msg, field)); - case google::protobuf::FieldDescriptor::CPPTYPE_DOUBLE: - return common::Value::FromDouble(refl->GetDouble(*msg, field)); - case google::protobuf::FieldDescriptor::CPPTYPE_STRING: { - std::string scratch; - const auto& value = refl->GetStringReference(*msg, field, &scratch); - if (field->type() == google::protobuf::FieldDescriptor::TYPE_STRING) { - if (&scratch == &value) { - return common::Value::FromString(value); - } else { - return common::Value::ForString(value, SelfRefProvider()); - } - } else { - if (&scratch == &value) { - return common::Value::FromBytes(value); - } else { - return common::Value::ForBytes(value, SelfRefProvider()); - } - } - } - - case google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE: { - // Create value that holds on the parent instead of copying the - // message value. - const google::protobuf::Message* sub_msg = &refl->GetMessage(*msg, field); - if (refl->HasField(*msg, field)) { - return registry_->ValueFor(sub_msg, SelfRefProvider()); - } - return registry_->GetDefault(sub_msg); - } - default: - return common::Value::FromError(internal::UnimplementedError("")); - } - } - - common::Value GetMember(absl::string_view name) const override { - std::string str_name(name); - auto* field = holder_->GetDescriptor()->FindFieldByName(str_name); - if (field == nullptr) { - return common::Value::FromError( - internal::NoSuchMember(name, holder_->GetDescriptor()->full_name())); - } - return GetFieldValue(field, &holder_.value()); - } - - common::Type object_type() const override { - return common::Type(common::ObjectType(holder_.value().GetDescriptor())); - } - - void To(google::protobuf::Any* value) const override { - value->PackFrom(holder_.value()); - } - - bool owns_value() const override { return HolderPolicy::kOwnsValue; } - - common::Value ContainsMember(absl::string_view name) const override { - std::string str_name(name); - return common::Value::FromBool( - holder_->GetDescriptor()->FindFieldByName(str_name) != nullptr); - } - - google::rpc::Status ForEach( - const std::function& call) const override { - const google::protobuf::Descriptor* desc = holder_->GetDescriptor(); - for (int i = 0; i < desc->field_count(); ++i) { - const auto* field_desc = desc->field(i); - RETURN_IF_STATUS_ERROR(call(field_desc->name(), - GetFieldValue(field_desc, &holder_.value()))); - } - return internal::OkStatus(); - } - - protected: - bool EqualsImpl(const Object& rhs) const override { - const MessageObject* other = cast_if(&rhs); - if (other == nullptr) { - return false; - } - return GetDiffer().Equals(holder_.value(), other->holder_.value()); - } - - private: - const TypeRegistry* registry_; - internal::Holder holder_; -}; - -template -bool RegisterDefaultImpl(const common::Value& default_value, T* entry) { - assert(default_value.is_value()); - if (entry->default_value.is_value()) { - // Already registered. - return false; - } - entry->default_value = default_value; - return true; -} - -template -bool RegisterImpl(T&& new_ctor, T* existing_ctor) { - assert(new_ctor != nullptr); - if (*existing_ctor != nullptr) { - // Already registered. - return false; - } - *existing_ctor = std::forward(new_ctor); - return true; -} - -} // namespace - -bool TypeRegistry::RegisterDefault(const common::ObjectType& object_type, - const common::Value& default_value) { - return RegisterDefaultImpl(default_value, &object_registry_[object_type]); -} - -bool TypeRegistry::RegisterConstructor( - const common::ObjectType& object_type, - std::function from_ctor) { - return RegisterImpl(std::move(from_ctor), - &object_registry_[object_type].from_ctor); -} - -bool TypeRegistry::RegisterConstructor( - const common::ObjectType& object_type, - std::function from_ctor) { - return RegisterImpl(std::move(from_ctor), - &object_registry_[object_type].from_move_ctor); -} - -bool TypeRegistry::RegisterConstructor( - const common::ObjectType& object_type, - std::function)> from_ctor) { - return RegisterImpl(std::move(from_ctor), - &object_registry_[object_type].from_ptr_ctor); -} - -bool TypeRegistry::RegisterConstructor( - const common::ObjectType& object_type, - std::function for_ctor) { - return RegisterImpl(std::move(for_ctor), - &object_registry_[object_type].for_ctor); -} - -bool TypeRegistry::RegisterConstructor( - const common::ObjectType& object_type, - std::function - for_ctor) { - return RegisterImpl(std::move(for_ctor), - &object_registry_[object_type].for_pnt_ctor); -} - -bool TypeRegistry::RegisterConstructor( - const common::EnumType& enum_type, - std::function from_ctor) { - return RegisterImpl(std::move(from_ctor), - &enum_registry_[enum_type].from_ctor); -} - -common::Value TypeRegistry::GetDefault( - const google::protobuf::Message* default_msg) const { - common::ObjectType type(default_msg->GetDescriptor()); - if (type == common::ObjectType::For()) { - return common::Value::NullValue(); - } - auto itr = object_registry_.find(type); - if (itr != object_registry_.end()) { - return itr->second.default_value.is_value() ? itr->second.default_value - : common::Value::NullValue(); - } - return ValueForUnregistered(default_msg); -} - -common::Value TypeRegistry::ValueFrom(const google::protobuf::Message& value) const { - common::ObjectType type(value.GetDescriptor()); - if (type == common::ObjectType::For()) { - return ValueFromAny(static_cast(value)); - } - - auto entry = GetCalls(type); - if (entry.from_ctor) { - return entry.from_ctor(value); - } - - // Try to fall back on other from calls. - auto ptr = internal::Clone(value); - if (entry.from_ptr_ctor) { - return entry.from_ptr_ctor(std::move(ptr)); - } else if (entry.from_move_ctor) { - return entry.from_move_ctor(std::move(*ptr)); - } - - if (entry.for_ctor || entry.for_pnt_ctor) { - return MissingCall(type); - } - - return ValueFromUnregistered(std::move(ptr)); -} - -common::Value TypeRegistry::ValueFrom(google::protobuf::Message&& value) const { - common::ObjectType type(value.GetDescriptor()); - if (type == common::ObjectType::For()) { - return ValueFromAny(static_cast(value)); - } - - auto entry = GetCalls(type); - if (entry.from_move_ctor) { - return entry.from_move_ctor(std::move(value)); - } - - // Try to fallback on other from calls. - if (entry.from_ctor) { - return entry.from_ctor(value); - } else if (entry.from_ptr_ctor) { - return entry.from_ptr_ctor(internal::Clone(std::move(value))); - } - - if (entry.for_ctor || entry.for_pnt_ctor) { - return MissingCall(type); - } - - return ValueFromUnregistered(internal::Clone(std::move(value))); -} - -common::Value TypeRegistry::ValueFrom( - std::unique_ptr value) const { - common::ObjectType type(value->GetDescriptor()); - if (type == common::ObjectType::For()) { - return ValueFromAny(static_cast(*value)); - } - - auto entry = GetCalls(type); - if (entry.from_ptr_ctor != nullptr) { - return entry.from_ptr_ctor(std::move(value)); - } - - // Try to fallback on other from_* calls. - if (entry.from_move_ctor) { - return entry.from_move_ctor(std::move(*value)); - } else if (entry.from_ctor) { - return entry.from_ctor(*value); - } - - if (entry.for_ctor || entry.for_pnt_ctor) { - return MissingCall(type); - } - - return ValueFromUnregistered(std::move(value)); -} - -common::Value TypeRegistry::ValueFor(const google::protobuf::Message* value, - common::ParentRef parent) const { - if (parent == absl::nullopt) { - return ValueFrom(*value); - } - common::ObjectType type(value->GetDescriptor()); - if (type == common::ObjectType::For()) { - return ValueFromAny(static_cast(*value)); - } - - ObjectRegistryEntry entry = GetCalls(type); - // Try for_pnt_ctor. - if (!parent->RequiresReference() && entry.for_ctor) { - return entry.for_ctor(value); - } else if (entry.for_pnt_ctor) { - return entry.for_pnt_ctor(value, *parent); - } - - // Try to fallback on from_* calls. - if (entry.from_ctor != nullptr) { - return entry.from_ctor(*value); - } else if (entry.from_ptr_ctor != nullptr) { - return entry.from_ptr_ctor(internal::Clone(*value)); - } else if (entry.from_move_ctor != nullptr) { - auto ptr = internal::Clone(*value); - return entry.from_move_ctor(std::move(*ptr)); - } - - if (entry.for_ctor) { - return MissingCall(type); - } - return ValueForUnregistered(value, *parent); -} - -common::Value TypeRegistry::ValueFrom(const common::EnumType& type, - int32_t value) const { - std::function from_ctor; - auto itr = enum_registry_.find(type); - if (itr != enum_registry_.end()) { - from_ctor = itr->second.from_ctor; - } - if (from_ctor) { - return from_ctor(type, value); - } - - return common::Value::FromInt(value); -} - -common::Value TypeRegistry::ValueFromUnregistered( - std::unique_ptr value) const { - return common::Value::MakeObject>( - this, std::move(value)); -} - -common::Value TypeRegistry::ValueForUnregistered( - const google::protobuf::Message* value, common::RefProvider parent) const { - if (parent.RequiresReference()) { - return common::Value::MakeObject>>( - this, parent.GetRef(), value); - } - return common::Value::MakeObject>(this, - value); -} - -common::Value TypeRegistry::ValueFromAny( - const google::protobuf::Any& value) const { - common::Type type(FindObjectType(&value)); - if (!type.is_object()) { - return common::Value::MakeObject>( - value); - } - auto unpacked = type.object_type().Unpack(value); - if (unpacked == nullptr) { - return common::Value::FromError(internal::ParseError(type.full_name())); - } - return ValueFrom(std::move(unpacked)); -} - -TypeRegistry::ObjectRegistryEntry TypeRegistry::GetCalls( - const common::ObjectType& type) const { - TypeRegistry::ObjectRegistryEntry result; - auto itr = object_registry_.find(type); - if (itr != object_registry_.end()) { - result.from_ctor = itr->second.from_ctor; - result.from_ptr_ctor = itr->second.from_ptr_ctor; - result.from_move_ctor = itr->second.from_move_ctor; - result.for_ctor = itr->second.for_ctor; - result.for_pnt_ctor = itr->second.for_pnt_ctor; - } - return result; -} - -} // namespace protoutil -} // namespace expr -} // namespace api -} // namespace google diff --git a/protoutil/type_registry.h b/protoutil/type_registry.h deleted file mode 100644 index 9b5a96da4..000000000 --- a/protoutil/type_registry.h +++ /dev/null @@ -1,336 +0,0 @@ -#ifndef THIRD_PARTY_CEL_CPP_PROTOUTIL_TYPE_REGISTRY_H_ -#define THIRD_PARTY_CEL_CPP_PROTOUTIL_TYPE_REGISTRY_H_ - -#include - -#include "absl/container/node_hash_map.h" -#include "common/parent_ref.h" -#include "common/type.h" -#include "common/value.h" -#include "internal/cast.h" - -namespace google { -namespace api { -namespace expr { -namespace protoutil { - -/** - * A registry for adapting `google::protobuf::Message` to `expr::Value`. - * - * All messages have a default implementation that matches typical protobuf - * codegen. - * - * Constructor functions and default values can be registered to customize - * the behavior of specific protobuf message and enum types. - * - * For messages, several different types of constructor arguments are supported: - * (1) `const proto2:Message &` for copy constructors. - * (2) `google::protobuf::Message &&` for move constructors. - * (3) `std::unique_ptr` for owned ptr constructors. - * (4) `const google::protobuf::Message*' for unowned pointer constructors. In this case - * the argument is guaranteed out live any `Value` returned. Typically used - * for Arena allocated protobufs. - * (5) 'const proto2:Message*, const RefProvider&` for 'view' constructors. In - * this case the message is guaranteed to live at least as long as any - * `ValueRef` retrieved from the provider. Used to support 'view' - * implementations when elements of a `Container` are accessed. - * - * Helper functions are provided to register functions and classes that accept - * concrete message types. For example: - * - * // Register a (1)-style constructor function. - * GOOGLE_CHECK(reg.RegisterConstructor( - * [](const google::type::Money& value) { - * ... - * })); - * - * class Money : public google::api::expr::Object { - * public: - * Money(const google::type::Money& value); - * Money(google::type::Money&& value); - * }; - * // Register both Money constructors and set the default value to 'null'. - * GOOGLE_CHECK(2 == reg.RegisterClass(); - * - * Not all constructor types need to be registered: - * - If the underlying protobuf message is never needed by the returned - * Value, only a (1)-style constructor need be registered. - * - If some of the protobuf message fields can be used directly by the - * returned Value, (2) and (4)-style constructors can be registered to - * reduce memory copying. - * - If the entire protobuf message can be used directly by the returned Value, - * (3) and (4)-style constructors can be registered to reduce memory copying. - * - Register a (5)-style constructor if containers should lazily construct - * Values on access. - * - * At least one of (1), (2), and (3) must be registered for every customized - * type. If this is not the case, and error may be returned instead of a value. - * - * If a default value is not registered for a customized type, 'null' will be - * used. - * - * Callbacks are tried in the following order: - * (1) -> (3) -> (2) - * (2) -> (3) -> (1) - * (3) -> (2) -> (1) - * (4) -> (1) -> (3) -> (2) - * (5) -> (1) -> (3) -> (2) - * - * This class must live longer than any value created from it. - */ -class TypeRegistry { - public: - /** Set the default value, returned when a protobuf field is unset. */ - bool RegisterDefault(const common::ObjectType& object_type, - const common::Value& default_value); - - /** Set the default value, returned when a protobuf field is unset. */ - template - bool RegisterDefault(const common::Value& default_value) { - return RegisterDefault(common::ObjectType(ProtoType::descriptor()), - default_value); - } - - /** Register a copy constructor callable for the given object_type. */ - bool RegisterConstructor( - const common::ObjectType& object_type, - std::function from_ctor); - - /** Register a move constructor callable for the given object_type. */ - bool RegisterConstructor( - const common::ObjectType& object_type, - std::function from_ctor); - - /** Register an owned ptr constructor callable for the given object_type. */ - bool RegisterConstructor( - const common::ObjectType& object_type, - std::function)> from_ctor); - - /** Register an unowned ptr constructor callable for the given object_type. */ - bool RegisterConstructor( - const common::ObjectType& object_type, - std::function for_ctor); - - /** Register a view constructor callable for the given object_type when - * owned by a parent container. */ - bool RegisterConstructor( - const common::ObjectType& object_type, - std::function - for_ctor); - - /** Register a copy constructor callable for the given enum_type. */ - bool RegisterConstructor( - const common::EnumType& enum_type, - std::function from_ctor); - - /** - * Register a constructor for a concrete protobuf message. - * - * The signature of 'call' is used to determin which type of constructor - * callable it is. - * - * @tparam ProtoType the concrete protobuf message class. - * @tparam C the constructor type. Templitzed to allow inline of lambda - * invocation. - * @returns false if a constructor was previously registered. - */ - template - bool RegisterConstructor(C&& call); - - /** - * Register all applicable constructors of implementation class. - * - * Default values is set to 'null' no other default has been registered. - * - * @tparam ProtoType the concrete protobuf message class. - * @tparam Impl the implementation class to register. - * @returns The number of constructors successfully registered. - */ - template - std::size_t RegisterClass(); - - /** - * Register all applicable constructors of implementation class and - * set the default value. - * - * @tparam ProtoType the concrete protobuf message class. - * @tparam Impl the implementation class to register. - * @returns The number of constructors successfully registered. - */ - template - std::size_t RegisterClass(const common::Value& default_value) { - RegisterDefault(common::ObjectType::For(), default_value); - return RegisterClass(); - } - - /** Return the default value for the given default_instance */ - common::Value GetDefault(const google::protobuf::Message* default_msg) const; - - /** Return a value created from the given message. */ - common::Value ValueFrom(const google::protobuf::Message& value) const; - /** Return a value created from the given message. */ - common::Value ValueFrom(google::protobuf::Message&& value) const; - /** Return a value created from the given message. */ - common::Value ValueFrom(std::unique_ptr value) const; - /** Return a value created for the given message. */ - common::Value ValueFor(const google::protobuf::Message* value, - common::ParentRef parent = common::NoParent()) const; - - /** Return a value crated from the given enum value */ - common::Value ValueFrom(const common::EnumType& type, int32_t value) const; - - private: - struct ObjectRegistryEntry { - common::Value default_value = common::Value::FromUnknown(common::Id(-1)); - std::function from_ctor; - std::function from_move_ctor; - std::function)> - from_ptr_ctor; - std::function for_ctor; - std::function - for_pnt_ctor; - }; - - struct EnumRegistryEntry { - std::function from_ctor; - }; - - absl::node_hash_map object_registry_; - absl::node_hash_map enum_registry_; - - common::Value ValueFromUnregistered( - std::unique_ptr value) const; - common::Value ValueForUnregistered( - const google::protobuf::Message* value, - common::RefProvider parent = common::NoParent()) const; - common::Value ValueFromAny(const google::protobuf::Any& value) const; - - ObjectRegistryEntry GetCalls(const common::ObjectType& type) const; -}; - -namespace type_registry_internal { -using internal::general; -using internal::inst_of; -using internal::specialize_for; -using internal::specialize_ifd; -using internal::static_down_cast; - -// Wrap a callable, casting the first argument from M to T. -template -std::function WrapCall(C call) { - return [call](M value, Args&&... args) { - return call(static_down_cast(std::forward(value)), - std::forward(args)...); - }; -} - -// Wrap a constructor, casting the first argument from M to T. -template -std::function WrapCtor( - specialize_for(), inst_of()...))>) { - return [](M value, Args&&... args) { - return common::Value::MakeObject( - static_down_cast(std::forward(value)), - std::forward(args)...); - }; -} - -// No constructor specialization found, return nullptr. -template -std::function WrapCtor(general) { - return nullptr; -} - -template -using reg_if = - specialize_ifd()(inst_of(), inst_of()...))>; - -template -reg_if Register(TypeRegistry* registry, C call) { - return registry->RegisterConstructor( - common::ObjectType(ProtoType::descriptor()), - WrapCall(call)); -} - -template -reg_if> Register(TypeRegistry* registry, C call) { - return registry->RegisterConstructor( - common::ObjectType(ProtoType::descriptor()), - WrapCall, std::unique_ptr>( - call)); -} - -template -reg_if Register(TypeRegistry* registry, C call) { - return registry->RegisterConstructor( - common::ObjectType(ProtoType::descriptor()), - WrapCall(call)); -} - -template -reg_if Register( - TypeRegistry* registry, C call) { - return registry->RegisterConstructor( - common::ObjectType(ProtoType::descriptor()), - WrapCall(call)); -} - -} // namespace type_registry_internal - -template -bool TypeRegistry::RegisterConstructor(C&& call) { - return type_registry_internal::Register(this, - std::forward(call)); -} - -template -std::size_t TypeRegistry::RegisterClass() { - static_assert(!std::is_abstract::value, "class cannot be abstract"); - auto from_ctor = type_registry_internal::WrapCtor( - internal::specialize()); - auto from_move_ctor = - type_registry_internal::WrapCtor( - internal::specialize()); - auto from_ptr_ctor = - type_registry_internal::WrapCtor, - std::unique_ptr>( - internal::specialize()); - auto for_ctor = type_registry_internal::WrapCtor( - internal::specialize()); - auto for_pnt_ctor = - type_registry_internal::WrapCtor( - internal::specialize()); - common::ObjectType type(ProtoType::descriptor()); - std::size_t successes = 0; - if (from_ctor != nullptr && RegisterConstructor(type, from_ctor)) { - ++successes; - } - if (from_move_ctor != nullptr && RegisterConstructor(type, from_move_ctor)) { - ++successes; - } - if (from_ptr_ctor != nullptr && RegisterConstructor(type, from_ptr_ctor)) { - ++successes; - } - if (for_ctor != nullptr && RegisterConstructor(type, for_ctor)) { - ++successes; - } - if (for_pnt_ctor != nullptr && RegisterConstructor(type, for_pnt_ctor)) { - ++successes; - } - return successes; -} - -} // namespace protoutil -} // namespace expr -} // namespace api -} // namespace google - -#endif // THIRD_PARTY_CEL_CPP_COMMON_OBJECT_REGISTRY_H_ diff --git a/protoutil/type_registry_test.cc b/protoutil/type_registry_test.cc deleted file mode 100644 index f56f9c826..000000000 --- a/protoutil/type_registry_test.cc +++ /dev/null @@ -1,303 +0,0 @@ -#include "protoutil/type_registry.h" - -#include "google/protobuf/struct.pb.h" -#include "google/type/money.pb.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "common/value.h" -#include "internal/ref_countable.h" -#include "testutil/util.h" - -namespace google { -namespace api { -namespace expr { -namespace protoutil { -namespace { - -using testutil::EqualsProto; - -// A base object that provides dummy impls for many required overrides. -struct BaseObject : common::Object { - common::Type object_type() const override { - return common::Type(common::ObjectType::For()); - }; - - common::Value GetMember(absl::string_view name) const override { - return common::Value::NullValue(); - } - - inline google::rpc::Status ForEach( - const std::function& call) const final { - return internal::OkStatus(); - } - - /** Serialize the object to protobuf any. */ - void To(google::protobuf::Any* value) const override {} - - bool EqualsImpl(const common::Object& same_type) const override { - return true; - } - - // Expose SelfRef for testing. - common::ParentRef SelfRefProvider() const { - return common::Object::SelfRefProvider(); - } -}; - -// An impl that holds a copy. -struct MoneyCopyImpl final : BaseObject { - explicit MoneyCopyImpl(const google::type::Money& value) - : value(value), moved(false) {} - explicit MoneyCopyImpl(google::type::Money&& value) - : value(std::move(value)), moved(true) {} - - bool owns_value() const override { return true; } - - google::type::Money value; - bool moved; -}; - -// An impl that holds a owned pointer. -struct MoneyOwnedImpl final : BaseObject { - explicit MoneyOwnedImpl(std::unique_ptr value) - : value(std::move(value)) {} - - std::unique_ptr value; - - bool owns_value() const override { return true; } -}; - -// An impl that holds a raw pointer. -struct MoneyUnownedImpl final : BaseObject { - explicit MoneyUnownedImpl(const google::type::Money* value) : value(value) {} - explicit MoneyUnownedImpl(const google::type::Money* value, - const common::RefProvider& parent) - : value(value), parent(parent.GetRef()) {} - - const google::type::Money* value; - common::ValueRef parent; - - bool owns_value() const override { return parent; } -}; - -TEST(TypeRegistry, Copy) { - google::type::Money expected_value; - expected_value.set_nanos(1); - - TypeRegistry reg; - auto result = reg.RegisterClass(); - EXPECT_EQ(2, result); - - auto test_value = [&expected_value](common::Value value, bool moved) { - ASSERT_NE(value.get_if(), nullptr) << value; - // Always owns value. - EXPECT_TRUE(value.owns_value()); - // Should have used move constructor if expected. - EXPECT_EQ(value.get_if()->moved, moved); - // Should have the right value. - EXPECT_THAT(value.get_if()->value, - EqualsProto(expected_value)); - }; - - // by const ref. - test_value(reg.ValueFrom(expected_value), false); - // by move. - test_value(reg.ValueFrom(google::type::Money(expected_value)), true); - // by owned ptr. - test_value( - reg.ValueFrom(absl::make_unique(expected_value)), - true); - // by unowned ptr. - test_value(reg.ValueFor(&expected_value), false); - // by parent owned ptr. - - auto parent = - internal::MakeReffed(google::type::Money(expected_value)); - test_value(reg.ValueFor(&expected_value, parent->SelfRefProvider()), false); -} - -TEST(TypeRegistry, Owned) { - google::type::Money expected_value; - expected_value.set_nanos(1); - MoneyOwnedImpl expected_object( - absl::make_unique(expected_value)); - - TypeRegistry reg; - auto result = reg.RegisterClass(); - EXPECT_EQ(1, result); - - auto test_value = [&expected_value](common::Value value) { - ASSERT_NE(value.get_if(), nullptr) << value; - // Always owns the value. - EXPECT_TRUE(value.owns_value()); - // Should have the right value. - EXPECT_THAT(*value.get_if()->value, - EqualsProto(expected_value)); - }; - - // by const ref. - test_value(reg.ValueFrom(expected_value)); - // by move. - test_value(reg.ValueFrom(google::type::Money(expected_value))); - // by owned ptr. - test_value( - reg.ValueFrom(absl::make_unique(expected_value))); - // by unowned ptr. - test_value(reg.ValueFor(&expected_value)); - // by parent owned ptr. - auto parent = internal::MakeReffed(expected_value); - test_value(reg.ValueFor(&expected_value, parent->SelfRefProvider())); -} - -TEST(TypeRegistry, Unowned) { - google::type::Money expected_value; - expected_value.set_nanos(1); - MoneyUnownedImpl expected_object(&expected_value); - - TypeRegistry reg; - auto result = reg.RegisterClass(); - EXPECT_EQ(2, result); - - auto test_value = [&expected_value](common::Value value, bool has_parent) { - ASSERT_NE(value.get_if(), nullptr) << value; - // Transitively owns the value when it holds a reference to a parent. - EXPECT_EQ(value.owns_value(), has_parent); - EXPECT_EQ(value.get_if()->parent, has_parent); - // Should point to the original value. - EXPECT_EQ(value.get_if()->value, &expected_value); - }; - - common::Value error = common::Value::FromError( - internal::InternalError("Missing callback for google.type.Money")); - - // All From* functions error, as no from constructor is provided. - // by const ref. - EXPECT_EQ(reg.ValueFrom(expected_value), error); - // by move. - EXPECT_EQ(reg.ValueFrom(google::type::Money(expected_value)), error); - // by owned ptr. - EXPECT_EQ( - reg.ValueFrom(absl::make_unique(expected_value)), - error); - - // by unowned ptr. - test_value(reg.ValueFor(&expected_value), false); - // by parent owned ptr. - // Ignores parent because parent doesn't own its value. - auto parent1 = internal::MakeReffed(&expected_value); - EXPECT_FALSE(parent1->owns_value()); - test_value(reg.ValueFor(&expected_value, parent1->SelfRefProvider()), false); - auto parent2 = internal::MakeReffed(expected_value); - EXPECT_TRUE(parent2->owns_value()); - test_value(reg.ValueFor(&expected_value, parent2->SelfRefProvider()), true); -} - -// Test that multiple impls can be registered for the same type. -TEST(TypeRegistry, All) { - TypeRegistry reg; - int result = 0; - result += reg.RegisterClass(); - result += reg.RegisterClass(); - result += reg.RegisterClass(); - EXPECT_EQ(5, result); - - google::type::Money expected_value; - expected_value.set_nanos(1); - MoneyUnownedImpl expected_object(&expected_value); - - // by const ref. - EXPECT_NE(nullptr, reg.ValueFrom(expected_value).get_if()); - // by move. - EXPECT_NE(nullptr, reg.ValueFrom(google::type::Money(expected_value)) - .get_if()); - // by owned ptr. - EXPECT_NE( - nullptr, - reg.ValueFrom(absl::make_unique(expected_value)) - .get_if()); - // by unowned ptr. - EXPECT_NE(nullptr, reg.ValueFor(&expected_value).get_if()); - // by parent owned ptr. - auto parent = internal::MakeReffed(expected_value); - EXPECT_NE(nullptr, reg.ValueFor(&expected_value, parent->SelfRefProvider()) - .get_if()); -} - -// Test that enums can be customized. -TEST(TypeRegistry, Enum) { - TypeRegistry reg; - google::protobuf::Value val; - auto actual = reg.ValueFrom(val).object_value().GetMember("null_value"); - EXPECT_EQ(actual, common::Value::FromInt(0)); - reg.RegisterConstructor( - common::EnumType(google::protobuf::NullValue_descriptor()), - [](common::EnumType, int32_t) { return common::Value::NullValue(); }); - actual = reg.ValueFrom(val).object_value().GetMember("null_value"); - EXPECT_EQ(actual, common::Value::NullValue()); -} - -// Test that unknown types work. -TEST(TypeRegistry, UnknownAny) { - TypeRegistry reg; - google::protobuf::Any unknown; - unknown.set_type_url("bad_type"); - unknown.set_value("hello"); - - auto value = reg.ValueFrom(unknown); - EXPECT_EQ(value.GetType(), common::Value::FromType("bad_type")); - EXPECT_EQ(value.object_value().GetMember("bye"), - common::Value::FromError(internal::UnknownType("bad_type"))); - - // Equal to itself. - EXPECT_EQ(value.hash_code(), reg.ValueFrom(unknown).hash_code()); - EXPECT_EQ(value, reg.ValueFrom(unknown)); - - // Not equal to other. - google::protobuf::Any other_unknown; - other_unknown.set_type_url("bad_type"); - other_unknown.set_value("bye"); - EXPECT_NE(value.hash_code(), reg.ValueFrom(other_unknown).hash_code()); - EXPECT_NE(value, reg.ValueFrom(other_unknown)); - - // Round trips losslessly. - google::protobuf::Any actual; - value.object_value().To(&actual); - EXPECT_THAT(actual, EqualsProto(unknown)); -} - -// Test that known types work. -TEST(TypeRegistry, KnownAny) { - TypeRegistry reg; - google::type::Money money; - money.set_nanos(100); - google::protobuf::Any known; - known.PackFrom(money); - - auto value = reg.ValueFrom(known); - EXPECT_EQ(value.GetType(), common::Value::FromType("google.type.Money")); - EXPECT_EQ(value.object_value().GetMember("nanos"), - common::Value::FromInt(100)); - - // Equal to itself. - EXPECT_EQ(value.hash_code(), reg.ValueFrom(known).hash_code()); - EXPECT_EQ(value, reg.ValueFrom(known)); - - // Not equal to other. - google::protobuf::Any other_known; - money.set_units(10); - other_known.PackFrom(money); - EXPECT_NE(value.hash_code(), reg.ValueFrom(other_known).hash_code()); - EXPECT_NE(value, reg.ValueFrom(other_known)); - - // Round trips losslessly. - google::protobuf::Any actual; - value.object_value().To(&actual); - EXPECT_THAT(actual, EqualsProto(known)); -} - -} // namespace -} // namespace protoutil -} // namespace expr -} // namespace api -} // namespace google diff --git a/runtime/BUILD b/runtime/BUILD new file mode 100644 index 000000000..34ff411a1 --- /dev/null +++ b/runtime/BUILD @@ -0,0 +1,672 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +package( + # Under active development, not yet being released. + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) + +cc_library( + name = "activation_interface", + hdrs = ["activation_interface.h"], + deps = [ + ":function_overload_reference", + "//base:attributes", + "//common:value", + "//internal:status_macros", + "//runtime/internal:attribute_matcher", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "function_overload_reference", + hdrs = ["function_overload_reference.h"], + deps = [ + ":function", + "//common:function_descriptor", + ], +) + +cc_library( + name = "function_provider", + hdrs = ["function_provider.h"], + deps = [ + ":activation_interface", + ":function_overload_reference", + "//common:function_descriptor", + "@com_google_absl//absl/status:statusor", + ], +) + +cc_library( + name = "activation", + srcs = ["activation.cc"], + hdrs = ["activation.h"], + deps = [ + ":activation_interface", + ":function", + ":function_overload_reference", + "//base:attributes", + "//common:function_descriptor", + "//common:value", + "//internal:status_macros", + "//runtime/internal:attribute_matcher", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "activation_test", + srcs = ["activation_test.cc"], + deps = [ + ":activation", + ":function", + ":function_overload_reference", + "//base:attributes", + "//common:function_descriptor", + "//common:value", + "//common:value_testing", + "//internal:testing", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "register_function_helper", + hdrs = ["register_function_helper.h"], + deps = + [ + ":function_registry", + "//common:function_descriptor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "function_registry", + srcs = ["function_registry.cc"], + hdrs = ["function_registry.h"], + deps = + [ + ":activation_interface", + ":function", + ":function_overload_reference", + ":function_provider", + "//common:function_descriptor", + "//common:kind", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + ], +) + +cc_test( + name = "function_registry_test", + srcs = ["function_registry_test.cc"], + deps = [ + ":activation", + ":function", + ":function_adapter", + ":function_overload_reference", + ":function_provider", + ":function_registry", + "//common:function_descriptor", + "//common:kind", + "//common:value", + "//internal:testing", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "runtime_options", + hdrs = ["runtime_options.h"], + deps = ["@com_google_absl//absl/base:core_headers"], +) + +cc_library( + name = "type_registry", + srcs = ["type_registry.cc"], + hdrs = ["type_registry.h"], + deps = [ + "//base:data", + "//common:type", + "//common:value", + "//runtime/internal:legacy_runtime_type_provider", + "//runtime/internal:runtime_type_provider", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "runtime", + hdrs = ["runtime.h"], + deps = [ + ":activation_interface", + ":runtime_issue", + "//base:ast", + "//base:data", + "//common:native_type", + "//common:value", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "runtime_builder", + hdrs = ["runtime_builder.h"], + deps = [ + ":function_registry", + ":runtime", + ":runtime_options", + ":type_registry", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status:statusor", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "runtime_builder_factory", + srcs = ["runtime_builder_factory.cc"], + hdrs = ["runtime_builder_factory.h"], + deps = [ + ":runtime_builder", + ":runtime_options", + "//internal:noop_delete", + "//internal:status_macros", + "//runtime/internal:runtime_env", + "//runtime/internal:runtime_impl", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status:statusor", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "standard_runtime_builder_factory", + srcs = ["standard_runtime_builder_factory.cc"], + hdrs = ["standard_runtime_builder_factory.h"], + deps = [ + ":runtime_builder", + ":runtime_builder_factory", + ":runtime_options", + ":standard_functions", + "//internal:noop_delete", + "//internal:status_macros", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status:statusor", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "standard_runtime_builder_factory_test", + srcs = ["standard_runtime_builder_factory_test.cc"], + deps = [ + ":activation", + ":runtime", + ":runtime_issue", + ":runtime_options", + ":standard_runtime_builder_factory", + "//base:builtins", + "//common:source", + "//common:value", + "//common:value_testing", + "//extensions:bindings_ext", + "//extensions/protobuf:runtime_adapter", + "//internal:testing", + "//parser", + "//parser:macro_registry", + "//parser:standard_macros", + "//runtime/internal:runtime_impl", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "standard_functions", + srcs = ["standard_functions.cc"], + hdrs = ["standard_functions.h"], + deps = [ + ":function_registry", + ":runtime_options", + "//internal:status_macros", + "//runtime/standard:arithmetic_functions", + "//runtime/standard:comparison_functions", + "//runtime/standard:container_functions", + "//runtime/standard:container_membership_functions", + "//runtime/standard:equality_functions", + "//runtime/standard:logical_functions", + "//runtime/standard:regex_functions", + "//runtime/standard:string_functions", + "//runtime/standard:time_functions", + "//runtime/standard:type_conversion_functions", + "@com_google_absl//absl/status", + ], +) + +cc_library( + name = "constant_folding", + srcs = ["constant_folding.cc"], + hdrs = ["constant_folding.h"], + deps = [ + ":runtime", + ":runtime_builder", + "//common:typeinfo", + "//eval/compiler:constant_folding", + "//internal:casts", + "//internal:noop_delete", + "//internal:status_macros", + "//runtime/internal:runtime_friend_access", + "//runtime/internal:runtime_impl", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "constant_folding_test", + srcs = ["constant_folding_test.cc"], + deps = [ + ":activation", + ":constant_folding", + ":runtime_builder", + ":runtime_options", + ":standard_runtime_builder_factory", + "//base:function_adapter", + "//common:function_descriptor", + "//common:value", + "//extensions/protobuf:runtime_adapter", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//parser", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "regex_precompilation", + srcs = ["regex_precompilation.cc"], + hdrs = ["regex_precompilation.h"], + deps = [ + ":runtime", + ":runtime_builder", + "//common:native_type", + "//eval/compiler:regex_precompilation_optimization", + "//internal:casts", + "//internal:status_macros", + "//runtime/internal:runtime_friend_access", + "//runtime/internal:runtime_impl", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], +) + +cc_test( + name = "regex_precompilation_test", + srcs = ["regex_precompilation_test.cc"], + deps = [ + ":activation", + ":constant_folding", + ":regex_precompilation", + ":register_function_helper", + ":runtime_builder", + ":runtime_options", + ":standard_runtime_builder_factory", + "//base:function_adapter", + "//common:value", + "//extensions/protobuf:runtime_adapter", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//parser", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "reference_resolver", + srcs = ["reference_resolver.cc"], + hdrs = ["reference_resolver.h"], + deps = [ + ":runtime", + ":runtime_builder", + "//common:native_type", + "//eval/compiler:qualified_reference_resolver", + "//internal:casts", + "//internal:status_macros", + "//runtime/internal:runtime_friend_access", + "//runtime/internal:runtime_impl", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log:absl_log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], +) + +cc_test( + name = "reference_resolver_test", + srcs = ["reference_resolver_test.cc"], + deps = [ + ":activation", + ":reference_resolver", + ":register_function_helper", + ":runtime_builder", + ":runtime_options", + ":standard_runtime_builder_factory", + "//base:function_adapter", + "//common:value", + "//extensions/protobuf:runtime_adapter", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//parser", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "runtime_issue", + hdrs = ["runtime_issue.h"], + deps = ["@com_google_absl//absl/status"], +) + +cc_library( + name = "comprehension_vulnerability_check", + srcs = ["comprehension_vulnerability_check.cc"], + hdrs = ["comprehension_vulnerability_check.h"], + deps = [ + ":runtime", + ":runtime_builder", + "//common:native_type", + "//eval/compiler:comprehension_vulnerability_check", + "//internal:casts", + "//internal:status_macros", + "//runtime/internal:runtime_friend_access", + "//runtime/internal:runtime_impl", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], +) + +cc_test( + name = "comprehension_vulnerability_check_test", + srcs = ["comprehension_vulnerability_check_test.cc"], + deps = [ + ":comprehension_vulnerability_check", + ":runtime_builder", + ":runtime_options", + ":standard_runtime_builder_factory", + "//extensions/protobuf:runtime_adapter", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//parser", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "function_adapter", + hdrs = ["function_adapter.h"], + deps = [ + ":function", + ":register_function_helper", + "//common:function_descriptor", + "//common:value", + "//internal:status_macros", + "//runtime/internal:function_adapter", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "function_adapter_test", + srcs = ["function_adapter_test.cc"], + deps = [ + ":function", + ":function_adapter", + "//common:function_descriptor", + "//common:kind", + "//common:value", + "//common:value_testing", + "//internal:testing", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + ], +) + +cc_library( + name = "optional_types", + srcs = ["optional_types.cc"], + hdrs = ["optional_types.h"], + deps = [ + ":function_registry", + ":runtime_builder", + ":runtime_options", + "//base:function_adapter", + "//common:casting", + "//common:type", + "//common:value", + "//internal:casts", + "//internal:number", + "//internal:status_macros", + "//runtime/internal:errors", + "//runtime/internal:runtime_friend_access", + "//runtime/internal:runtime_impl", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:optional", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "optional_types_test", + srcs = ["optional_types_test.cc"], + deps = [ + ":activation", + ":function", + ":optional_types", + ":reference_resolver", + ":runtime", + ":runtime_builder", + ":runtime_options", + ":standard_runtime_builder_factory", + "//common:function_descriptor", + "//common:kind", + "//common:value", + "//common:value_testing", + "//extensions/protobuf:runtime_adapter", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//parser", + "//parser:options", + "//runtime/internal:runtime_impl", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "function", + hdrs = [ + "function.h", + ], + deps = [ + "//common:value", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "memory_safety_test", + srcs = ["memory_safety_test.cc"], + deps = [ + ":activation", + ":constant_folding", + ":function_adapter", + ":optional_types", + ":reference_resolver", + ":regex_precompilation", + ":runtime", + ":runtime_builder", + ":runtime_options", + ":standard_runtime_builder_factory", + "//checker:validation_result", + "//common:decl", + "//common:type", + "//common:value", + "//common:value_testing", + "//compiler", + "//compiler:compiler_factory", + "//compiler:optional", + "//compiler:standard_library", + "//internal:status_macros", + "//internal:testing", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:variant", + "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto", + "@com_google_protobuf//:any_cc_proto", + "@com_google_protobuf//:differencer", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "embedder_context", + hdrs = ["embedder_context.h"], + deps = [ + "//common:typeinfo", + "//common:value", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/types:optional", + ], +) + +cc_test( + name = "embedder_context_test", + srcs = ["embedder_context_test.cc"], + deps = [ + ":embedder_context", + "//common:typeinfo", + "//internal:testing", + "@com_google_absl//absl/types:optional", + ], +) diff --git a/runtime/activation.cc b/runtime/activation.cc new file mode 100644 index 000000000..e999f7a02 --- /dev/null +++ b/runtime/activation.cc @@ -0,0 +1,141 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "runtime/activation.h" + +#include +#include +#include + +#include "absl/base/macros.h" +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/optional.h" +#include "common/function_descriptor.h" +#include "common/value.h" +#include "internal/status_macros.h" +#include "runtime/function.h" +#include "runtime/function_overload_reference.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel { + +absl::StatusOr Activation::FindVariable( + absl::string_view name, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(result != nullptr); + + auto iter = values_.find(name); + if (iter == values_.end()) { + return false; + } + + const ValueEntry& entry = iter->second; + if (entry.provider.has_value()) { + return ProvideValue(name, descriptor_pool, message_factory, arena, result); + } + if (entry.value.has_value()) { + *result = *entry.value; + return true; + } + return false; +} + +absl::StatusOr Activation::ProvideValue( + absl::string_view name, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + absl::MutexLock lock(mutex_); + auto iter = values_.find(name); + ABSL_ASSERT(iter != values_.end()); + ValueEntry& entry = iter->second; + if (entry.value.has_value()) { + *result = *entry.value; + return true; + } + + CEL_ASSIGN_OR_RETURN( + auto provided, + (*entry.provider)(name, descriptor_pool, message_factory, arena)); + if (provided.has_value()) { + entry.value = std::move(provided); + *result = *entry.value; + return true; + } + return false; +} + +std::vector Activation::FindFunctionOverloads( + absl::string_view name) const { + std::vector result; + auto iter = functions_.find(name); + if (iter != functions_.end()) { + const std::vector& overloads = iter->second; + result.reserve(overloads.size()); + for (const auto& overload : overloads) { + result.push_back({*overload.descriptor, *overload.implementation}); + } + } + return result; +} + +bool Activation::InsertOrAssignValue(absl::string_view name, Value value) { + return values_ + .insert_or_assign(name, ValueEntry{std::move(value), absl::nullopt}) + .second; +} + +bool Activation::InsertOrAssignValueProvider(absl::string_view name, + ValueProvider provider) { + return values_ + .insert_or_assign(name, ValueEntry{absl::nullopt, std::move(provider)}) + .second; +} + +bool Activation::InsertFunction(const cel::FunctionDescriptor& descriptor, + std::unique_ptr impl) { + auto& overloads = functions_[descriptor.name()]; + for (auto& overload : overloads) { + if (overload.descriptor->ShapeMatches(descriptor)) { + return false; + } + } + overloads.push_back( + {std::make_unique(descriptor), std::move(impl)}); + return true; +} + +Activation::Activation(Activation&& other) { + using std::swap; + swap(*this, other); +} + +Activation& Activation::operator=(Activation&& other) { + using std::swap; + Activation tmp(std::move(other)); + swap(*this, tmp); + return *this; +} + +} // namespace cel diff --git a/runtime/activation.h b/runtime/activation.h new file mode 100644 index 000000000..8c4fb4073 --- /dev/null +++ b/runtime/activation.h @@ -0,0 +1,184 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_ACTIVATION_H_ +#define THIRD_PARTY_CEL_CPP_RUNTIME_ACTIVATION_H_ + +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" +#include "absl/functional/any_invocable.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "base/attribute.h" +#include "common/function_descriptor.h" +#include "common/value.h" +#include "runtime/activation_interface.h" +#include "runtime/function.h" +#include "runtime/function_overload_reference.h" +#include "runtime/internal/attribute_matcher.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel { + +namespace runtime_internal { +class ActivationAttributeMatcherAccess; +} + +// Thread-compatible implementation of a CEL Activation. +// +// Values can either be provided eagerly or via a provider. +class Activation final : public ActivationInterface { + public: + // Definition for value providers. + using ValueProvider = + absl::AnyInvocable>( + absl::string_view, const google::protobuf::DescriptorPool* absl_nonnull, + google::protobuf::MessageFactory* absl_nonnull, google::protobuf::Arena* absl_nonnull)>; + + Activation() = default; + + // Move only. + Activation(Activation&& other); + + Activation& operator=(Activation&& other); + + // Implements ActivationInterface. + absl::StatusOr FindVariable( + absl::string_view name, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const override; + using ActivationInterface::FindVariable; + + std::vector FindFunctionOverloads( + absl::string_view name) const override; + + absl::Span GetUnknownAttributes() + const override { + return unknown_patterns_; + } + + absl::Span GetMissingAttributes() + const override { + return missing_patterns_; + } + + // Bind a value to a named variable. + // + // Returns false if the entry for name was overwritten. + bool InsertOrAssignValue(absl::string_view name, Value value); + + // Bind a provider to a named variable. The result of the provider may be + // memoized by the activation. + // + // Returns false if the entry for name was overwritten. + bool InsertOrAssignValueProvider(absl::string_view name, + ValueProvider provider); + + void AddUnknownPattern(cel::AttributePattern pattern) { + unknown_patterns_.push_back(std::move(pattern)); + } + + void SetUnknownPatterns(std::vector patterns) { + unknown_patterns_ = std::move(patterns); + } + + void AddMissingPattern(cel::AttributePattern pattern) { + missing_patterns_.push_back(std::move(pattern)); + } + + void SetMissingPatterns(std::vector patterns) { + missing_patterns_ = std::move(patterns); + } + + // Returns true if the function was inserted (no other registered function has + // a matching descriptor). + bool InsertFunction(const cel::FunctionDescriptor& descriptor, + std::unique_ptr impl); + + private: + struct ValueEntry { + // If provider is present, then access must be synchronized to maintain + // thread-compatible semantics for the lazily provided value. + absl::optional value; + absl::optional provider; + }; + + struct FunctionEntry { + std::unique_ptr descriptor; + std::unique_ptr implementation; + }; + + friend class runtime_internal::ActivationAttributeMatcherAccess; + + void SetAttributeMatcher(const runtime_internal::AttributeMatcher* matcher) { + attribute_matcher_ = matcher; + } + + void SetAttributeMatcher( + std::unique_ptr matcher) { + owned_attribute_matcher_ = std::move(matcher); + attribute_matcher_ = owned_attribute_matcher_.get(); + } + + const runtime_internal::AttributeMatcher* absl_nullable GetAttributeMatcher() + const override { + return attribute_matcher_; + } + + friend void swap(Activation& a, Activation& b) { + using std::swap; + swap(a.values_, b.values_); + swap(a.functions_, b.functions_); + swap(a.unknown_patterns_, b.unknown_patterns_); + swap(a.missing_patterns_, b.missing_patterns_); + } + + // Internal getter for provided values. + // Assumes entry for name is present and is a provided value. + // Handles synchronization for caching the provided value. + absl::StatusOr ProvideValue( + absl::string_view name, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; + + // mutex_ used for safe caching of provided variables + mutable absl::Mutex mutex_; + mutable absl::flat_hash_map values_; + + std::vector unknown_patterns_; + std::vector missing_patterns_; + + const runtime_internal::AttributeMatcher* attribute_matcher_ = nullptr; + std::unique_ptr + owned_attribute_matcher_; + + absl::flat_hash_map> functions_; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_RUNTIME_ACTIVATION_H_ diff --git a/runtime/activation_interface.h b/runtime/activation_interface.h new file mode 100644 index 000000000..c589468de --- /dev/null +++ b/runtime/activation_interface.h @@ -0,0 +1,109 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_ACTIVATION_INTERFACE_H_ +#define THIRD_PARTY_CEL_CPP_RUNTIME_ACTIVATION_INTERFACE_H_ + +#include + +#include "absl/base/nullability.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "base/attribute.h" +#include "common/value.h" +#include "internal/status_macros.h" +#include "runtime/function_overload_reference.h" +#include "runtime/internal/attribute_matcher.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel { + +namespace runtime_internal { +class ActivationAttributeMatcherAccess; +} // namespace runtime_internal + +// Interface for providing runtime with variable lookups. +// +// Clients should prefer to use one of the concrete implementations provided by +// the CEL library rather than implementing this interface directly. +// TODO(uncreated-issue/40): After finalizing, make this public and add instructions +// for clients to migrate. +class ActivationInterface { + public: + virtual ~ActivationInterface() = default; + + // Find value for a string (possibly qualified) variable name. + virtual absl::StatusOr FindVariable( + absl::string_view name, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const = 0; + absl::StatusOr> FindVariable( + absl::string_view name, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const { + Value result; + CEL_ASSIGN_OR_RETURN( + auto found, + FindVariable(name, descriptor_pool, message_factory, arena, &result)); + if (found) { + return result; + } + return absl::nullopt; + } + + // Find a set of context function overloads by name. + virtual std::vector FindFunctionOverloads( + absl::string_view name) const = 0; + + // Return a list of unknown attribute patterns. + // + // If an attribute (select path) encountered during evaluation matches any of + // the patterns, the value will be treated as unknown and propagated in an + // unknown set. + // + // The returned span must remain valid for the duration of any evaluation + // using this this activation. + virtual absl::Span GetUnknownAttributes() + const = 0; + + // Return a list of missing attribute patterns. + // + // If an attribute (select path) encountered during evaluation matches any of + // the patterns, the value will be treated as missing and propagated as an + // error. + // + // The returned span must remain valid for the duration of any evaluation + // using this activation. + virtual absl::Span GetMissingAttributes() + const = 0; + + private: + friend class runtime_internal::ActivationAttributeMatcherAccess; + + // Returns the attribute matcher for this activation. + virtual const runtime_internal::AttributeMatcher* absl_nullable + GetAttributeMatcher() const { + return nullptr; + } +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_RUNTIME_ACTIVATION_INTERFACE_H_ diff --git a/runtime/activation_test.cc b/runtime/activation_test.cc new file mode 100644 index 000000000..4303116a3 --- /dev/null +++ b/runtime/activation_test.cc @@ -0,0 +1,419 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "runtime/activation.h" + +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "base/attribute.h" +#include "common/function_descriptor.h" +#include "common/value.h" +#include "common/value_testing.h" +#include "internal/testing.h" +#include "runtime/function.h" +#include "runtime/function_overload_reference.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOkAndHolds; +using ::absl_testing::StatusIs; +using testing::ElementsAre; +using testing::Eq; +using testing::IsEmpty; +using testing::Optional; +using testing::SizeIs; +using testing::Truly; +using testing::UnorderedElementsAre; + +MATCHER_P(IsIntValue, x, absl::StrCat("is IntValue Handle with value ", x)) { + const Value& handle = arg; + + return handle->Is() && handle.GetInt().NativeValue() == x; +} + +MATCHER_P(AttributePatternMatches, val, "matches AttributePattern") { + const AttributePattern& pattern = arg; + const Attribute& expected = val; + + return pattern.IsMatch(expected) == AttributePattern::MatchType::FULL; +} + +class FunctionImpl : public cel::Function { + public: + FunctionImpl() = default; + + absl::StatusOr Invoke(absl::Span args, + const InvokeContext& context) const override { + return NullValue(); + } +}; + +using ActivationTest = common_internal::ValueTest<>; + +TEST_F(ActivationTest, ValueNotFound) { + Activation activation; + + EXPECT_THAT(activation.FindVariable("var1", descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(Eq(absl::nullopt))); +} + +TEST_F(ActivationTest, InsertValue) { + Activation activation; + EXPECT_TRUE(activation.InsertOrAssignValue("var1", IntValue(42))); + + EXPECT_THAT(activation.FindVariable("var1", descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(Optional(IsIntValue(42)))); +} + +TEST_F(ActivationTest, InsertValueOverwrite) { + Activation activation; + EXPECT_TRUE(activation.InsertOrAssignValue("var1", IntValue(42))); + EXPECT_FALSE(activation.InsertOrAssignValue("var1", IntValue(0))); + + EXPECT_THAT(activation.FindVariable("var1", descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(Optional(IsIntValue(0)))); +} + +TEST_F(ActivationTest, InsertProvider) { + Activation activation; + + EXPECT_TRUE(activation.InsertOrAssignValueProvider( + "var1", + [](absl::string_view name, const google::protobuf::DescriptorPool* absl_nonnull, + google::protobuf::MessageFactory* absl_nonnull, + google::protobuf::Arena* absl_nonnull) { return IntValue(42); })); + + EXPECT_THAT(activation.FindVariable("var1", descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(Optional(IsIntValue(42)))); +} + +TEST_F(ActivationTest, InsertProviderForwardsNotFound) { + Activation activation; + + EXPECT_TRUE(activation.InsertOrAssignValueProvider( + "var1", + [](absl::string_view name, const google::protobuf::DescriptorPool* absl_nonnull, + google::protobuf::MessageFactory* absl_nonnull, + google::protobuf::Arena* absl_nonnull) { return absl::nullopt; })); + + EXPECT_THAT(activation.FindVariable("var1", descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(Eq(absl::nullopt))); +} + +TEST_F(ActivationTest, InsertProviderForwardsStatus) { + Activation activation; + + EXPECT_TRUE(activation.InsertOrAssignValueProvider( + "var1", + [](absl::string_view name, const google::protobuf::DescriptorPool* absl_nonnull, + google::protobuf::MessageFactory* absl_nonnull, + google::protobuf::Arena* absl_nonnull) { return absl::InternalError("test"); })); + + EXPECT_THAT(activation.FindVariable("var1", descriptor_pool(), + message_factory(), arena()), + StatusIs(absl::StatusCode::kInternal, "test")); +} + +TEST_F(ActivationTest, ProviderMemoized) { + Activation activation; + int call_count = 0; + + EXPECT_TRUE(activation.InsertOrAssignValueProvider( + "var1", [&call_count](absl::string_view name, + const google::protobuf::DescriptorPool* absl_nonnull, + google::protobuf::MessageFactory* absl_nonnull, + google::protobuf::Arena* absl_nonnull) { + call_count++; + return IntValue(42); + })); + + EXPECT_THAT(activation.FindVariable("var1", descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(Optional(IsIntValue(42)))); + EXPECT_THAT(activation.FindVariable("var1", descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(Optional(IsIntValue(42)))); + EXPECT_EQ(call_count, 1); +} + +TEST_F(ActivationTest, InsertProviderOverwrite) { + Activation activation; + + EXPECT_TRUE(activation.InsertOrAssignValueProvider( + "var1", + [](absl::string_view name, const google::protobuf::DescriptorPool* absl_nonnull, + google::protobuf::MessageFactory* absl_nonnull, + google::protobuf::Arena* absl_nonnull) { return IntValue(42); })); + EXPECT_FALSE(activation.InsertOrAssignValueProvider( + "var1", + [](absl::string_view name, const google::protobuf::DescriptorPool* absl_nonnull, + google::protobuf::MessageFactory* absl_nonnull, + google::protobuf::Arena* absl_nonnull) { return IntValue(0); })); + + EXPECT_THAT(activation.FindVariable("var1", descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(Optional(IsIntValue(0)))); +} + +TEST_F(ActivationTest, ValuesAndProvidersShareNamespace) { + Activation activation; + bool called = false; + + EXPECT_TRUE(activation.InsertOrAssignValue("var1", IntValue(41))); + EXPECT_TRUE(activation.InsertOrAssignValue("var2", IntValue(41))); + + EXPECT_FALSE(activation.InsertOrAssignValueProvider( + "var1", [&called](absl::string_view name, + const google::protobuf::DescriptorPool* absl_nonnull, + google::protobuf::MessageFactory* absl_nonnull, + google::protobuf::Arena* absl_nonnull) { + called = true; + return IntValue(42); + })); + + EXPECT_THAT(activation.FindVariable("var1", descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(Optional(IsIntValue(42)))); + EXPECT_THAT(activation.FindVariable("var2", descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(Optional(IsIntValue(41)))); + EXPECT_TRUE(called); +} + +TEST_F(ActivationTest, SetUnknownAttributes) { + Activation activation; + + activation.SetUnknownPatterns( + {AttributePattern("var1", + {AttributeQualifierPattern::OfString("field1")}), + AttributePattern("var1", + {AttributeQualifierPattern::OfString("field2")})}); + + EXPECT_THAT( + activation.GetUnknownAttributes(), + ElementsAre(AttributePatternMatches(Attribute( + "var1", {AttributeQualifier::OfString("field1")})), + AttributePatternMatches(Attribute( + "var1", {AttributeQualifier::OfString("field2")})))); +} + +TEST_F(ActivationTest, ClearUnknownAttributes) { + Activation activation; + + activation.SetUnknownPatterns( + {AttributePattern("var1", + {AttributeQualifierPattern::OfString("field1")}), + AttributePattern("var1", + {AttributeQualifierPattern::OfString("field2")})}); + activation.SetUnknownPatterns({}); + + EXPECT_THAT(activation.GetUnknownAttributes(), IsEmpty()); +} + +TEST_F(ActivationTest, SetMissingAttributes) { + Activation activation; + + activation.SetMissingPatterns( + {AttributePattern("var1", + {AttributeQualifierPattern::OfString("field1")}), + AttributePattern("var1", + {AttributeQualifierPattern::OfString("field2")})}); + + EXPECT_THAT( + activation.GetMissingAttributes(), + ElementsAre(AttributePatternMatches(Attribute( + "var1", {AttributeQualifier::OfString("field1")})), + AttributePatternMatches(Attribute( + "var1", {AttributeQualifier::OfString("field2")})))); +} + +TEST_F(ActivationTest, ClearMissingAttributes) { + Activation activation; + + activation.SetMissingPatterns( + {AttributePattern("var1", + {AttributeQualifierPattern::OfString("field1")}), + AttributePattern("var1", + {AttributeQualifierPattern::OfString("field2")})}); + activation.SetMissingPatterns({}); + + EXPECT_THAT(activation.GetMissingAttributes(), IsEmpty()); +} + +TEST_F(ActivationTest, InsertFunctionOk) { + Activation activation; + + EXPECT_TRUE( + activation.InsertFunction(FunctionDescriptor("Fn", false, {Kind::kUint}), + std::make_unique())); + EXPECT_TRUE( + activation.InsertFunction(FunctionDescriptor("Fn", false, {Kind::kInt}), + std::make_unique())); + EXPECT_TRUE( + activation.InsertFunction(FunctionDescriptor("Fn2", false, {Kind::kInt}), + std::make_unique())); + + EXPECT_THAT( + activation.FindFunctionOverloads("Fn"), + UnorderedElementsAre( + Truly([](const FunctionOverloadReference& ref) { + return ref.descriptor.name() == "Fn" && + ref.descriptor.types() == std::vector{Kind::kUint}; + }), + Truly([](const FunctionOverloadReference& ref) { + return ref.descriptor.name() == "Fn" && + ref.descriptor.types() == std::vector{Kind::kInt}; + }))) + << "expected overloads Fn(int), Fn(uint)"; +} + +TEST_F(ActivationTest, InsertFunctionFails) { + Activation activation; + + EXPECT_TRUE( + activation.InsertFunction(FunctionDescriptor("Fn", false, {Kind::kAny}), + std::make_unique())); + EXPECT_FALSE( + activation.InsertFunction(FunctionDescriptor("Fn", false, {Kind::kInt}), + std::make_unique())); + + EXPECT_THAT(activation.FindFunctionOverloads("Fn"), + ElementsAre(Truly([](const FunctionOverloadReference& ref) { + return ref.descriptor.name() == "Fn" && + ref.descriptor.types() == std::vector{Kind::kAny}; + }))) + << "expected overload Fn(any)"; +} + +TEST_F(ActivationTest, MoveAssignment) { + Activation moved_from; + + ASSERT_TRUE( + moved_from.InsertFunction(FunctionDescriptor("Fn", false, {Kind::kAny}), + std::make_unique())); + ASSERT_TRUE(moved_from.InsertOrAssignValue("val", IntValue(42))); + + ASSERT_TRUE(moved_from.InsertOrAssignValueProvider( + "val_provided", + [](absl::string_view name, const google::protobuf::DescriptorPool* absl_nonnull, + google::protobuf::MessageFactory* absl_nonnull, google::protobuf::Arena* absl_nonnull) + -> absl::StatusOr> { return IntValue(42); })); + moved_from.SetUnknownPatterns( + {AttributePattern("var1", + {AttributeQualifierPattern::OfString("field1")}), + AttributePattern("var1", + {AttributeQualifierPattern::OfString("field2")})}); + moved_from.SetMissingPatterns( + {AttributePattern("var1", + {AttributeQualifierPattern::OfString("field1")}), + AttributePattern("var1", + {AttributeQualifierPattern::OfString("field2")})}); + + Activation moved_to; + moved_to = std::move(moved_from); + + EXPECT_THAT(moved_to.FindVariable("val", descriptor_pool(), message_factory(), + arena()), + IsOkAndHolds(Optional(IsIntValue(42)))); + EXPECT_THAT(moved_to.FindVariable("val_provided", descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(Optional(IsIntValue(42)))); + EXPECT_THAT(moved_to.FindFunctionOverloads("Fn"), SizeIs(1)); + EXPECT_THAT(moved_to.GetUnknownAttributes(), SizeIs(2)); + EXPECT_THAT(moved_to.GetMissingAttributes(), SizeIs(2)); + + // moved from value is empty. (well defined but not specified state) + // NOLINTBEGIN(bugprone-use-after-move) + EXPECT_THAT(moved_from.FindVariable("val", descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(Eq(absl::nullopt))); + EXPECT_THAT(moved_from.FindVariable("val_provided", descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(Eq(absl::nullopt))); + EXPECT_THAT(moved_from.FindFunctionOverloads("Fn"), SizeIs(0)); + EXPECT_THAT(moved_from.GetUnknownAttributes(), SizeIs(0)); + EXPECT_THAT(moved_from.GetMissingAttributes(), SizeIs(0)); + // NOLINTEND(bugprone-use-after-move) +} + +TEST_F(ActivationTest, MoveCtor) { + Activation moved_from; + + ASSERT_TRUE( + moved_from.InsertFunction(FunctionDescriptor("Fn", false, {Kind::kAny}), + std::make_unique())); + ASSERT_TRUE(moved_from.InsertOrAssignValue("val", IntValue(42))); + + ASSERT_TRUE(moved_from.InsertOrAssignValueProvider( + "val_provided", + [](absl::string_view name, const google::protobuf::DescriptorPool* absl_nonnull, + google::protobuf::MessageFactory* absl_nonnull, google::protobuf::Arena* absl_nonnull) + -> absl::StatusOr> { return IntValue(42); })); + moved_from.SetUnknownPatterns( + {AttributePattern("var1", + {AttributeQualifierPattern::OfString("field1")}), + AttributePattern("var1", + {AttributeQualifierPattern::OfString("field2")})}); + moved_from.SetMissingPatterns( + {AttributePattern("var1", + {AttributeQualifierPattern::OfString("field1")}), + AttributePattern("var1", + {AttributeQualifierPattern::OfString("field2")})}); + + Activation moved_to = std::move(moved_from); + + EXPECT_THAT(moved_to.FindVariable("val", descriptor_pool(), message_factory(), + arena()), + IsOkAndHolds(Optional(IsIntValue(42)))); + EXPECT_THAT(moved_to.FindVariable("val_provided", descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(Optional(IsIntValue(42)))); + EXPECT_THAT(moved_to.FindFunctionOverloads("Fn"), SizeIs(1)); + EXPECT_THAT(moved_to.GetUnknownAttributes(), SizeIs(2)); + EXPECT_THAT(moved_to.GetMissingAttributes(), SizeIs(2)); + + // moved from value is empty. + // NOLINTBEGIN(bugprone-use-after-move) + EXPECT_THAT(moved_from.FindVariable("val", descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(Eq(absl::nullopt))); + EXPECT_THAT(moved_from.FindVariable("val_provided", descriptor_pool(), + message_factory(), arena()), + IsOkAndHolds(Eq(absl::nullopt))); + EXPECT_THAT(moved_from.FindFunctionOverloads("Fn"), SizeIs(0)); + EXPECT_THAT(moved_from.GetUnknownAttributes(), SizeIs(0)); + EXPECT_THAT(moved_from.GetMissingAttributes(), SizeIs(0)); + // NOLINTEND(bugprone-use-after-move) +} + +} // namespace +} // namespace cel diff --git a/runtime/comprehension_vulnerability_check.cc b/runtime/comprehension_vulnerability_check.cc new file mode 100644 index 000000000..2ab6657c2 --- /dev/null +++ b/runtime/comprehension_vulnerability_check.cc @@ -0,0 +1,66 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "runtime/comprehension_vulnerability_check.h" + +#include "absl/base/macros.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "common/native_type.h" +#include "eval/compiler/comprehension_vulnerability_check.h" +#include "internal/casts.h" +#include "internal/status_macros.h" +#include "runtime/internal/runtime_friend_access.h" +#include "runtime/internal/runtime_impl.h" +#include "runtime/runtime.h" +#include "runtime/runtime_builder.h" + +namespace cel { + +namespace { + +using ::cel::internal::down_cast; +using ::cel::runtime_internal::RuntimeFriendAccess; +using ::cel::runtime_internal::RuntimeImpl; +using ::google::api::expr::runtime::CreateComprehensionVulnerabilityCheck; + +absl::StatusOr RuntimeImplFromBuilder( + RuntimeBuilder& builder) { + Runtime& runtime = RuntimeFriendAccess::GetMutableRuntime(builder); + + if (RuntimeFriendAccess::RuntimeTypeId(runtime) != + NativeTypeId::For()) { + return absl::UnimplementedError( + "constant folding only supported on the default cel::Runtime " + "implementation."); + } + + RuntimeImpl& runtime_impl = down_cast(runtime); + + return &runtime_impl; +} + +} // namespace + +absl::Status EnableComprehensionVulnerabiltyCheck( + cel::RuntimeBuilder& builder) { + CEL_ASSIGN_OR_RETURN(RuntimeImpl * runtime_impl, + RuntimeImplFromBuilder(builder)); + ABSL_ASSERT(runtime_impl != nullptr); + runtime_impl->expr_builder().AddProgramOptimizer( + CreateComprehensionVulnerabilityCheck()); + return absl::OkStatus(); +} + +} // namespace cel diff --git a/runtime/comprehension_vulnerability_check.h b/runtime/comprehension_vulnerability_check.h new file mode 100644 index 000000000..0b7b18dd7 --- /dev/null +++ b/runtime/comprehension_vulnerability_check.h @@ -0,0 +1,35 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_COMPREHENSION_VULNERABILITY_CHECK_H_ +#define THIRD_PARTY_CEL_CPP_RUNTIME_COMPREHENSION_VULNERABILITY_CHECK_H_ + +#include "absl/status/status.h" +#include "runtime/runtime_builder.h" + +namespace cel { + +// Enable a check for memory vulnerabilities within comprehension +// sub-expressions. +// +// Note: This flag is not necessary if you are only using Core CEL macros. +// +// Consider enabling this feature when using custom comprehensions, and +// absolutely enable the feature when using hand-written ASTs for +// comprehension expressions. +// +// This check is not exhaustive and shouldn't be used with deeply nested ASTs. +absl::Status EnableComprehensionVulnerabiltyCheck(RuntimeBuilder& builder); +} // namespace cel +#endif // THIRD_PARTY_CEL_CPP_RUNTIME_COMPREHENSION_VULNERABILITY_CHECK_H_ diff --git a/runtime/comprehension_vulnerability_check_test.cc b/runtime/comprehension_vulnerability_check_test.cc new file mode 100644 index 000000000..ba9c7572a --- /dev/null +++ b/runtime/comprehension_vulnerability_check_test.cc @@ -0,0 +1,155 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "runtime/comprehension_vulnerability_check.h" + +#include + +#include "cel/expr/syntax.pb.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "extensions/protobuf/runtime_adapter.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "parser/parser.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_options.h" +#include "runtime/standard_runtime_builder_factory.h" +#include "google/protobuf/text_format.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::cel::extensions::ProtobufRuntimeAdapter; +using ::cel::expr::ParsedExpr; +using ::google::api::expr::parser::Parse; +using ::google::protobuf::TextFormat; +using ::testing::HasSubstr; + +constexpr absl::string_view kVulnerableExpr = R"pb( + expr { + id: 1 + comprehension_expr { + iter_var: "unused" + accu_var: "accu" + result { + id: 2 + ident_expr { name: "accu" } + } + accu_init { + id: 11 + list_expr { + elements { + id: 12 + const_expr { int64_value: 0 } + } + } + } + loop_condition { + id: 13 + const_expr { bool_value: true } + } + loop_step { + id: 3 + call_expr { + function: "_+_" + args { + id: 4 + ident_expr { name: "accu" } + } + args { + id: 5 + ident_expr { name: "accu" } + } + } + } + iter_range { + id: 6 + list_expr { + elements { + id: 7 + const_expr { int64_value: 0 } + } + elements { + id: 8 + const_expr { int64_value: 0 } + } + elements { + id: 9 + const_expr { int64_value: 0 } + } + elements { + id: 10 + const_expr { int64_value: 0 } + } + } + } + } + } +)pb"; + +TEST(ComprehensionVulnerabilityCheck, EnabledVulnerable) { + RuntimeOptions runtime_options; + ASSERT_OK_AND_ASSIGN( + RuntimeBuilder builder, + CreateStandardRuntimeBuilder(internal::GetTestingDescriptorPool(), + runtime_options)); + ASSERT_OK(EnableComprehensionVulnerabiltyCheck(builder)); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + ParsedExpr expr; + ASSERT_TRUE(TextFormat::ParseFromString(kVulnerableExpr, &expr)); + + EXPECT_THAT( + ProtobufRuntimeAdapter::CreateProgram(*runtime, expr), + StatusIs( + absl::StatusCode::kInvalidArgument, + HasSubstr("Comprehension contains memory exhaustion vulnerability"))); +} + +TEST(ComprehensionVulnerabilityCheck, EnabledNotVulnerable) { + RuntimeOptions runtime_options; + ASSERT_OK_AND_ASSIGN( + RuntimeBuilder builder, + CreateStandardRuntimeBuilder(internal::GetTestingDescriptorPool(), + runtime_options)); + ASSERT_OK(EnableComprehensionVulnerabiltyCheck(builder)); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, Parse("[0, 0, 0, 0].map(x, x + 1)")); + + EXPECT_THAT(ProtobufRuntimeAdapter::CreateProgram(*runtime, expr), IsOk()); +} + +TEST(ComprehensionVulnerabilityCheck, DisabledVulnerable) { + RuntimeOptions runtime_options; + ASSERT_OK_AND_ASSIGN( + RuntimeBuilder builder, + CreateStandardRuntimeBuilder(internal::GetTestingDescriptorPool(), + runtime_options)); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + ParsedExpr expr; + ASSERT_TRUE(TextFormat::ParseFromString(kVulnerableExpr, &expr)); + + EXPECT_THAT(ProtobufRuntimeAdapter::CreateProgram(*runtime, expr), IsOk()); +} + +} // namespace +} // namespace cel diff --git a/runtime/constant_folding.cc b/runtime/constant_folding.cc new file mode 100644 index 000000000..2d14154dc --- /dev/null +++ b/runtime/constant_folding.cc @@ -0,0 +1,158 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "runtime/constant_folding.h" + +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "common/typeinfo.h" +#include "eval/compiler/constant_folding.h" +#include "internal/casts.h" +#include "internal/noop_delete.h" +#include "internal/status_macros.h" +#include "runtime/internal/runtime_friend_access.h" +#include "runtime/internal/runtime_impl.h" +#include "runtime/runtime.h" +#include "runtime/runtime_builder.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/message.h" + +namespace cel::extensions { +namespace { + +using ::cel::internal::down_cast; +using ::cel::runtime_internal::RuntimeFriendAccess; +using ::cel::runtime_internal::RuntimeImpl; + +absl::StatusOr RuntimeImplFromBuilder( + RuntimeBuilder& builder ABSL_ATTRIBUTE_LIFETIME_BOUND) { + Runtime& runtime = RuntimeFriendAccess::GetMutableRuntime(builder); + if (RuntimeFriendAccess::RuntimeTypeId(runtime) != TypeId()) { + return absl::UnimplementedError( + "constant folding only supported on the default cel::Runtime " + "implementation."); + } + return down_cast(&runtime); +} + +absl::Status EnableConstantFoldingImpl( + RuntimeBuilder& builder, absl_nullable std::shared_ptr arena, + absl_nullable std::shared_ptr message_factory) { + CEL_ASSIGN_OR_RETURN(RuntimeImpl* absl_nonnull runtime_impl, + RuntimeImplFromBuilder(builder)); + if (arena != nullptr) { + runtime_impl->environment().KeepAlive(arena); + } + if (message_factory != nullptr) { + runtime_impl->environment().KeepAlive(message_factory); + } + runtime_impl->expr_builder().AddProgramOptimizer( + runtime_internal::CreateConstantFoldingOptimizer( + std::move(arena), std::move(message_factory))); + return absl::OkStatus(); +} + +} // namespace + +absl::Status EnableConstantFolding(RuntimeBuilder& builder) { + return EnableConstantFoldingImpl(builder, nullptr, nullptr); +} + +absl::Status EnableConstantFolding(RuntimeBuilder& builder, + google::protobuf::Arena* absl_nonnull arena) { + ABSL_DCHECK(arena != nullptr); + return EnableConstantFoldingImpl( + builder, + std::shared_ptr(arena, + internal::NoopDeleteFor()), + nullptr); +} + +absl::Status EnableConstantFolding( + RuntimeBuilder& builder, + absl_nonnull std::shared_ptr arena) { + ABSL_DCHECK(arena != nullptr); + return EnableConstantFoldingImpl(builder, std::move(arena), nullptr); +} + +absl::Status EnableConstantFolding( + RuntimeBuilder& builder, + google::protobuf::MessageFactory* absl_nonnull message_factory) { + ABSL_DCHECK(message_factory != nullptr); + return EnableConstantFoldingImpl( + builder, nullptr, + std::shared_ptr( + message_factory, internal::NoopDeleteFor())); +} + +absl::Status EnableConstantFolding( + RuntimeBuilder& builder, + absl_nonnull std::shared_ptr message_factory) { + ABSL_DCHECK(message_factory != nullptr); + return EnableConstantFoldingImpl(builder, nullptr, + std::move(message_factory)); +} + +absl::Status EnableConstantFolding( + RuntimeBuilder& builder, google::protobuf::Arena* absl_nonnull arena, + google::protobuf::MessageFactory* absl_nonnull message_factory) { + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(message_factory != nullptr); + return EnableConstantFoldingImpl( + builder, + std::shared_ptr(arena, + internal::NoopDeleteFor()), + std::shared_ptr( + message_factory, internal::NoopDeleteFor())); +} + +absl::Status EnableConstantFolding( + RuntimeBuilder& builder, google::protobuf::Arena* absl_nonnull arena, + absl_nonnull std::shared_ptr message_factory) { + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(message_factory != nullptr); + return EnableConstantFoldingImpl( + builder, + std::shared_ptr(arena, + internal::NoopDeleteFor()), + std::move(message_factory)); +} + +absl::Status EnableConstantFolding( + RuntimeBuilder& builder, absl_nonnull std::shared_ptr arena, + google::protobuf::MessageFactory* absl_nonnull message_factory) { + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(message_factory != nullptr); + return EnableConstantFoldingImpl( + builder, std::move(arena), + std::shared_ptr( + message_factory, internal::NoopDeleteFor())); +} + +absl::Status EnableConstantFolding( + RuntimeBuilder& builder, absl_nonnull std::shared_ptr arena, + absl_nonnull std::shared_ptr message_factory) { + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(message_factory != nullptr); + return EnableConstantFoldingImpl(builder, std::move(arena), + std::move(message_factory)); +} + +} // namespace cel::extensions diff --git a/runtime/constant_folding.h b/runtime/constant_folding.h new file mode 100644 index 000000000..27a87f8cd --- /dev/null +++ b/runtime/constant_folding.h @@ -0,0 +1,69 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_CONSTANT_FOLDING_H_ +#define THIRD_PARTY_CEL_CPP_RUNTIME_CONSTANT_FOLDING_H_ + +#include + +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "runtime/runtime_builder.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/message.h" + +namespace cel::extensions { + +// Enable constant folding in the runtime being built. +// +// Constant folding eagerly evaluates sub-expressions with all constant inputs +// at plan time to simplify the resulting program. User functions are executed +// if they are eagerly bound. +// +// The provided, the `google::protobuf::Arena` must outlive the resulting runtime +// and any program it creates. Otherwise the runtime will create one as needed +// during planning for each program, unless one is explicitly provided during +// planning. +// +// The provided, the `google::protobuf::MessageFactory` must outlive the resulting runtime +// and any program it creates. Otherwise the runtime will create one as needed +// and use it for all planning and the resulting programs created from the +// runtime, unless one is explicitly provided during planning or evaluation. +absl::Status EnableConstantFolding(RuntimeBuilder& builder); +absl::Status EnableConstantFolding(RuntimeBuilder& builder, + google::protobuf::Arena* absl_nonnull arena); +absl::Status EnableConstantFolding( + RuntimeBuilder& builder, absl_nonnull std::shared_ptr arena); +absl::Status EnableConstantFolding( + RuntimeBuilder& builder, + google::protobuf::MessageFactory* absl_nonnull message_factory); +absl::Status EnableConstantFolding( + RuntimeBuilder& builder, + absl_nonnull std::shared_ptr message_factory); +absl::Status EnableConstantFolding( + RuntimeBuilder& builder, google::protobuf::Arena* absl_nonnull arena, + google::protobuf::MessageFactory* absl_nonnull message_factory); +absl::Status EnableConstantFolding( + RuntimeBuilder& builder, google::protobuf::Arena* absl_nonnull arena, + absl_nonnull std::shared_ptr message_factory); +absl::Status EnableConstantFolding( + RuntimeBuilder& builder, absl_nonnull std::shared_ptr arena, + google::protobuf::MessageFactory* absl_nonnull message_factory); +absl::Status EnableConstantFolding( + RuntimeBuilder& builder, absl_nonnull std::shared_ptr arena, + absl_nonnull std::shared_ptr message_factory); + +} // namespace cel::extensions + +#endif // THIRD_PARTY_CEL_CPP_RUNTIME_CONSTANT_FOLDING_H_ diff --git a/runtime/constant_folding_test.cc b/runtime/constant_folding_test.cc new file mode 100644 index 000000000..c59d5602a --- /dev/null +++ b/runtime/constant_folding_test.cc @@ -0,0 +1,228 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "runtime/constant_folding.h" + +#include +#include +#include +#include + +#include "cel/expr/syntax.pb.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "base/function_adapter.h" +#include "common/function_descriptor.h" +#include "common/value.h" +#include "extensions/protobuf/runtime_adapter.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "parser/parser.h" +#include "runtime/activation.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_options.h" +#include "runtime/standard_runtime_builder_factory.h" +#include "google/protobuf/arena.h" + +namespace cel::extensions { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::cel::expr::ParsedExpr; +using ::google::api::expr::parser::Parse; +using ::testing::HasSubstr; + +using ValueMatcher = testing::Matcher; + +struct TestCase { + std::string name; + std::string expression; + ValueMatcher result_matcher; + absl::Status status; +}; + +MATCHER_P(IsIntValue, expected, "") { + const Value& value = arg; + return value->Is() && value.GetInt().NativeValue() == expected; +} + +MATCHER_P(IsBoolValue, expected, "") { + const Value& value = arg; + return value->Is() && value.GetBool().NativeValue() == expected; +} + +MATCHER_P(IsErrorValue, expected_substr, "") { + const Value& value = arg; + return value->Is() && + absl::StrContains(value.GetError().NativeValue().message(), + expected_substr); +} + +class ConstantFoldingExtTest : public testing::TestWithParam {}; + +TEST_P(ConstantFoldingExtTest, Runner) { + google::protobuf::Arena arena; + RuntimeOptions options; + const TestCase& test_case = GetParam(); + ASSERT_OK_AND_ASSIGN(cel::RuntimeBuilder builder, + CreateStandardRuntimeBuilder( + internal::GetTestingDescriptorPool(), options)); + + auto status = BinaryFunctionAdapter, const StringValue&, + const StringValue&>:: + RegisterGlobalOverload( + "prepend", + [](const StringValue& value, const StringValue& prefix) { + return StringValue( + absl::StrCat(prefix.ToString(), value.ToString())); + }, + builder.function_registry()); + ASSERT_THAT(status, IsOk()); + + ASSERT_THAT(EnableConstantFolding(builder), IsOk()); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse(test_case.expression)); + + ASSERT_OK_AND_ASSIGN(auto program, ProtobufRuntimeAdapter::CreateProgram( + *runtime, parsed_expr)); + Activation activation; + + auto result = program->Evaluate(&arena, activation); + if (test_case.status.ok()) { + ASSERT_OK_AND_ASSIGN(Value value, std::move(result)); + + EXPECT_THAT(value, test_case.result_matcher); + return; + } + + EXPECT_THAT(result.status(), StatusIs(test_case.status.code(), + HasSubstr(test_case.status.message()))); +} + +INSTANTIATE_TEST_SUITE_P( + Cases, ConstantFoldingExtTest, + testing::ValuesIn(std::vector{ + {"sum", "1 + 2 + 3", IsIntValue(6)}, + {"list_create", "[1, 2, 3, 4].filter(x, x < 4).size()", IsIntValue(3)}, + {"string_concat", "('12' + '34' + '56' + '78' + '90').size()", + IsIntValue(10)}, + {"comprehension", "[1, 2, 3, 4].exists(x, x in [4, 5, 6, 7])", + IsBoolValue(true)}, + {"nested_comprehension", + "[1, 2, 3, 4].exists(x, [1, 2, 3, 4].all(y, y <= x))", + IsBoolValue(true)}, + {"runtime_error", "[1, 2, 3, 4].exists(x, ['4'].all(y, y <= x))", + IsErrorValue("No matching overloads")}, + {"map_create", "{'abc': 'def', 'abd': 'deg'}.size()", IsIntValue(2)}, + {"custom_function", "prepend('def', 'abc') == 'abcdef'", + IsBoolValue(true)}}), + + [](const testing::TestParamInfo& info) { + return info.param.name; + }); + +TEST(ConstantFoldingExtTest, LazyFunctionNotFolded) { + google::protobuf::Arena arena; + RuntimeOptions options; + + ASSERT_OK_AND_ASSIGN(cel::RuntimeBuilder builder, + CreateStandardRuntimeBuilder( + internal::GetTestingDescriptorPool(), options)); + int call_count = 0; + using FunctionAdapter = + BinaryFunctionAdapter, const StringValue&, + const StringValue&>; + auto fn = FunctionAdapter::WrapFunction( + [&call_count](const StringValue& value, const StringValue& prefix) { + call_count++; + return StringValue(absl::StrCat(prefix.ToString(), value.ToString())); + }); + FunctionDescriptor descriptor = FunctionAdapter::CreateDescriptor( + "lazy_prepend", /*receiver_style=*/false); + ASSERT_THAT(builder.function_registry().RegisterLazyFunction(descriptor), + IsOk()); + + ASSERT_THAT(EnableConstantFolding(builder), IsOk()); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, + Parse("lazy_prepend('def', 'abc') == 'abcdef'")); + + ASSERT_OK_AND_ASSIGN(auto program, ProtobufRuntimeAdapter::CreateProgram( + *runtime, parsed_expr)); + EXPECT_EQ(call_count, 0); + Activation activation; + activation.InsertFunction(descriptor, std::move(fn)); + + ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation)); + EXPECT_EQ(call_count, 1); + EXPECT_THAT(result, IsBoolValue(true)); + + ASSERT_OK_AND_ASSIGN(result, program->Evaluate(&arena, activation)); + EXPECT_EQ(call_count, 2); + EXPECT_THAT(result, IsBoolValue(true)); +} + +TEST(ConstantFoldingExtTest, ContextualFunctionNotFolded) { + google::protobuf::Arena arena; + RuntimeOptions options; + ASSERT_OK_AND_ASSIGN(cel::RuntimeBuilder builder, + CreateStandardRuntimeBuilder( + internal::GetTestingDescriptorPool(), options)); + int call_count = 0; + + auto status = BinaryFunctionAdapter< + absl::StatusOr, const StringValue&, + const StringValue&>::Register("contextual_prepend", + /*receiver_style=*/false, + [&call_count](const StringValue& value, + const StringValue& prefix) { + call_count++; + return StringValue(absl::StrCat( + prefix.ToString(), value.ToString())); + }, + builder.function_registry(), + {/*.is_strict=*/true, + /*is_contextual=*/true}); + ASSERT_THAT(status, IsOk()); + + ASSERT_THAT(EnableConstantFolding(builder), IsOk()); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, + Parse("contextual_prepend('def', 'abc') == 'abcdef'")); + + ASSERT_OK_AND_ASSIGN(auto program, ProtobufRuntimeAdapter::CreateProgram( + *runtime, parsed_expr)); + EXPECT_EQ(call_count, 0); + Activation activation; + ASSERT_OK_AND_ASSIGN(Value value, program->Evaluate(&arena, activation)); + EXPECT_EQ(call_count, 1); + EXPECT_THAT(value, IsBoolValue(true)); + + ASSERT_OK_AND_ASSIGN(value, program->Evaluate(&arena, activation)); + EXPECT_EQ(call_count, 2); + EXPECT_THAT(value, IsBoolValue(true)); +} + +} // namespace +} // namespace cel::extensions diff --git a/runtime/embedder_context.h b/runtime/embedder_context.h new file mode 100644 index 000000000..49407882e --- /dev/null +++ b/runtime/embedder_context.h @@ -0,0 +1,147 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_EMBEDDER_CONTEXT_H_ +#define THIRD_PARTY_CEL_CPP_RUNTIME_EMBEDDER_CONTEXT_H_ + +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/log/absl_check.h" +#include "absl/types/optional.h" +#include "common/typeinfo.h" +#include "common/value.h" + +namespace cel { + +// EmbedderContext is used to package custom content defined by the embedder +// during CEL evaluation. The custom content is indexed by type. Value types +// are returned as absl::optional where T is the value type. Pointer types +// are returned as T*. +// +// The content values must be trivially copyable and have a size <= 16 bytes. +// These are typically pointers or small value types (e.g. primitives, enums). +// +// An all zero memory value is used to represent an empty value. The caller +// must provide some way to disambiguate if that is a meaningfully distinct +// value from nullopt / nullptr. +// +// Scope is used to provide a distinction between multiple usages of CEL in the +// same binary. +class EmbedderContext { + public: + template + static EmbedderContext From(Args... args); + + // Convenience using a default scope. + template + static EmbedderContext From(Args... args) { + return From(args...); + } + + template + std::enable_if_t, absl::optional> Get() const; + + template + std::enable_if_t, T> Get() const; + + template + std::enable_if_t, absl::optional> Get() const { + return Get(); + } + + template + std::enable_if_t, T> Get() const { + return Get(); + } + + private: + template + void Set(T arg, Ts... args); + + template + void Set() {} + + absl::InlinedVector values_; + // These are included to check for bad accesses in debug mode. + absl::InlinedVector type_ids_; + TypeInfo scope_; +}; + +template +void EmbedderContext::Set(Arg arg, Args... args) { + using IndexType = std::decay_t; + size_t index = TypeIdInSet::template IndexFor(); + if (index >= values_.size()) { + values_.resize(index + 1, cel::CustomValueContent::Zero()); + type_ids_.resize(index + 1); + } + values_[index] = cel::CustomValueContent::From(arg); + type_ids_[index] = cel::TypeId(); + Set(args...); +} + +template +std::enable_if_t, absl::optional> +EmbedderContext::Get() const { + ABSL_DCHECK_EQ(cel::TypeId(), scope_) + << "EmbedderContext::Get wrong scope"; + using IndexType = std::decay_t; + size_t index = TypeIdInSet::template IndexFor(); + if (index >= values_.size()) { + return absl::nullopt; + } + + const auto& content = values_[index]; + if (content.IsZero()) return absl::nullopt; + + ABSL_DCHECK_EQ(type_ids_.size(), values_.size()); + ABSL_DCHECK_EQ(type_ids_[index], cel::TypeId()) + << "EmbedderContext::Get wrong type id"; + + return content.To(); +} + +template +std::enable_if_t, T> EmbedderContext::Get() const { + ABSL_DCHECK_EQ(cel::TypeId(), scope_) + << "EmbedderContext::Get wrong scope"; + using IndexType = std::decay_t; + size_t index = TypeIdInSet::template IndexFor(); + if (index >= values_.size()) { + return nullptr; + } + + const auto& content = values_[index]; + if (content.IsZero()) return nullptr; + + ABSL_DCHECK_EQ(type_ids_.size(), values_.size()); + ABSL_DCHECK_EQ(type_ids_[index], cel::TypeId()) + << "EmbedderContext::Get wrong type id"; + + return content.To(); +} + +template +EmbedderContext EmbedderContext::From(Args... args) { + EmbedderContext context; + context.scope_ = TypeId(); + context.Set(args...); + return context; +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_RUNTIME_EMBEDDER_CONTEXT_H_ diff --git a/runtime/embedder_context_test.cc b/runtime/embedder_context_test.cc new file mode 100644 index 000000000..d8cbbb736 --- /dev/null +++ b/runtime/embedder_context_test.cc @@ -0,0 +1,93 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "runtime/embedder_context.h" + +#include + +#include "absl/types/optional.h" +#include "common/typeinfo.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +using ::testing::Optional; + +TEST(EmbedderContextTest, From) { + struct TestScope {}; + EmbedderContext context = EmbedderContext::From(int64_t{42}); + EXPECT_THAT((context.Get()), Optional(42)); + EXPECT_EQ((context.Get()), absl::nullopt); + + EmbedderContext context2 = EmbedderContext::From(uint64_t{42}); + EXPECT_THAT((context2.Get()), Optional(42)); + EXPECT_EQ((context2.Get()), absl::nullopt); + + // Side effect, but checking that we keep a dense range. + EXPECT_EQ(cel::TypeIdInSet::Size(), 2); +} + +TEST(EmbedderContextTest, FromOutOfLine) { + struct TestScope {}; + EmbedderContext context = + EmbedderContext::From(int64_t{42}, uint64_t{43}, double{44}); + + EXPECT_THAT((context.Get()), Optional(42)); + EXPECT_THAT((context.Get()), Optional(43)); + EXPECT_THAT((context.Get()), Optional(44)); + EXPECT_EQ((context.Get()), absl::nullopt); + + // Note: Referencing a type not intended to be stored will still reserve a + // slot in the TypeIdInSet. + EXPECT_EQ(cel::TypeIdInSet::Size(), 4); +} + +TEST(EmbedderContextTest, FromPtrs) { + struct TestScope {}; + struct TestPointee { + } foo; + int64_t pointee2; + + EmbedderContext context = EmbedderContext::From( + &foo, const_cast(&pointee2)); + EXPECT_EQ((context.Get()), &pointee2); + EXPECT_EQ((context.Get()), &foo); + + EmbedderContext context2 = EmbedderContext::From(&foo); + EXPECT_EQ((context2.Get()), nullptr); + EXPECT_EQ((context2.Get()), &foo); + + // Note: const int* not the same as int*. + EXPECT_EQ(cel::TypeIdInSet::Size(), 3); +} + +TEST(EmbedderContextTest, FromDefaultScope) { + EmbedderContext context = EmbedderContext::From(int64_t{42}); + EXPECT_THAT((context.Get()), Optional(42)); + EXPECT_EQ((context.Get()), absl::nullopt); +} + +// These death assertions are only enabled when compiled in debug mode. +// Caller is responsible for adequately testing since we're limited in what +// we can statically check due to the type-erasure. +TEST(EmbedderContextDeathTest, GetWithWrongScope) { + struct TestScope {}; + EmbedderContext context = EmbedderContext::From(int64_t{42}); + EXPECT_DEBUG_DEATH( + { context.Get(); }, "EmbedderContext::Get wrong scope"); +} + +} // namespace +} // namespace cel diff --git a/runtime/function.h b/runtime/function.h new file mode 100644 index 000000000..a2c842f81 --- /dev/null +++ b/runtime/function.h @@ -0,0 +1,115 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_FUNCTION_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_FUNCTION_H_ + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "common/value.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel { + +class EmbedderContext; + +// Interface for extension functions. +// +// The host for the CEL environment may provide implementations to define custom +// extension functions. +// +// The runtime expects functions to be deterministic and side-effect free. +class Function { + public: + virtual ~Function() = default; + + // Context for the function invocation. + // + // Collects evaluation state that may be needed for the function to operate. + // + // The function implementation should not retain a reference to the context + // object beyond the duration of the function call or modify the InvokeContext + // itself. + class InvokeContext { + public: + InvokeContext( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + const EmbedderContext* absl_nullable embedder_context = nullptr) + : descriptor_pool_(descriptor_pool), + message_factory_(message_factory), + arena_(arena), + embedder_context_(embedder_context) {} + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool() const { + return descriptor_pool_; + } + + google::protobuf::MessageFactory* absl_nonnull message_factory() const { + return message_factory_; + } + + google::protobuf::Arena* absl_nonnull arena() const { return arena_; } + + const EmbedderContext* absl_nullable embedder_context() const { + return embedder_context_; + } + + void set_embedder_context( + const EmbedderContext* absl_nullable embedder_context) { + embedder_context_ = embedder_context; + } + + private: + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool_; + google::protobuf::MessageFactory* absl_nonnull message_factory_; + google::protobuf::Arena* absl_nonnull arena_; + const EmbedderContext* absl_nullable embedder_context_; + }; + + ABSL_DEPRECATED("Use the InvokeContext overload instead.") + inline absl::StatusOr Invoke( + absl::Span args, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const; + + // Attempt to evaluate an extension function based on the runtime arguments + // during the evaluation of a CEL expression. + // + // A non-ok status is interpreted as an unrecoverable error in evaluation ( + // e.g. data corruption). This stops evaluation and is propagated immediately. + // + // A cel::ErrorValue typed result is considered a recoverable error and + // follows CEL's logical short-circuiting behavior. + virtual absl::StatusOr Invoke(absl::Span args, + const InvokeContext& context) const = 0; +}; + +absl::StatusOr Function::Invoke( + absl::Span args, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const { + InvokeContext context(descriptor_pool, message_factory, arena); + return Invoke(args, context); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_FUNCTION_H_ diff --git a/runtime/function_adapter.h b/runtime/function_adapter.h new file mode 100644 index 000000000..62932a027 --- /dev/null +++ b/runtime/function_adapter.h @@ -0,0 +1,830 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Definitions for template helpers to wrap C++ functions as CEL extension +// function implementations. + +#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_FUNCTION_ADAPTER_H_ +#define THIRD_PARTY_CEL_CPP_RUNTIME_FUNCTION_ADAPTER_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/functional/any_invocable.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "common/function_descriptor.h" +#include "common/value.h" +#include "internal/status_macros.h" +#include "runtime/function.h" +#include "runtime/internal/function_adapter.h" +#include "runtime/register_function_helper.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel { + +namespace runtime_internal { + +template +struct AdaptedTypeTraits { + using AssignableType = T; + + static T ToArg(AssignableType v) { return v; } +}; + +// Specialization for cref parameters without forcing a temporary copy of the +// underlying handle argument. +template <> +struct AdaptedTypeTraits { + using AssignableType = const Value*; + + static std::reference_wrapper ToArg(AssignableType v) { + return *v; + } +}; + +template <> +struct AdaptedTypeTraits { + using AssignableType = const StringValue*; + + static std::reference_wrapper ToArg(AssignableType v) { + return *v; + } +}; + +template <> +struct AdaptedTypeTraits { + using AssignableType = const BytesValue*; + + static std::reference_wrapper ToArg(AssignableType v) { + return *v; + } +}; + +// Partial specialization for other cases. +// +// These types aren't referenceable since they aren't actually +// represented as alternatives in the underlying variant. +// +// This still requires an implicit copy and corresponding ref-count increase. +template +struct AdaptedTypeTraits { + using AssignableType = T; + + static T ToArg(AssignableType v) { return v; } +}; + +template +struct AdaptHelperImpl { + template + static absl::Status Apply(absl::Span input, T& output) { + static_assert(sizeof...(Args) > 0); + static_assert(std::tuple_size_v == sizeof...(Args)); + CEL_RETURN_IF_ERROR(ValueToAdaptedVisitor{input[I]}(&std::get(output))); + if constexpr (I == sizeof...(Args) - 1) { + return absl::OkStatus(); + } else { + CEL_RETURN_IF_ERROR( + (AdaptHelperImpl::template Apply(input, output))); + } + return absl::OkStatus(); + } +}; + +template +struct AdaptHelper { + template + static absl::Status Apply(absl::Span input, T& output) { + return AdaptHelperImpl<0, Args...>::template Apply(input, output); + } +}; + +template +struct ToArgsImpl { + template + struct El { + using type = T; + constexpr static size_t index = I; + }; + + template + struct ZipHolder { + template + static ResultType ToArgs(Op&& op, const TupleType& argbuffer, + const Function::InvokeContext& context) { + return std::forward(op)( + runtime_internal::AdaptedTypeTraits::ToArg( + std::get(argbuffer))..., + context); + } + }; + + template + static ZipHolder...> MakeZip(const std::index_sequence&) { + return ZipHolder...>{}; + } +}; + +template +struct ToArgsHelper { + template + static ResultType Apply(Op&& op, const TupleType& argbuffer, + const Function::InvokeContext& context) { + using Impl = ToArgsImpl; + using Zip = decltype(Impl::MakeZip(std::index_sequence_for{})); + return Zip::template ToArgs(std::forward(op), argbuffer, + context); + } +}; + +} // namespace runtime_internal + +// Adapter class for generating CEL extension functions from a one argument +// function. +// +// See documentation for Binary Function adapter for general recommendations. +// +// Example Usage: +// double Invert(ValueManager&, double x) { +// return 1 / x; +// } +// +// { +// std::unique_ptr builder; +// +// CEL_RETURN_IF_ERROR( +// builder->GetRegistry()->Register( +// UnaryFunctionAdapter::CreateDescriptor("inv", +// /*receiver_style=*/false), +// UnaryFunctionAdapter::WrapFunction(&Invert))); +// } +// // example CEL expression +// inv(4) == 1/4 [true] +template +class NullaryFunctionAdapter + : public RegisterHelper> { + public: + using FunctionType = + absl::AnyInvocable; + + static std::unique_ptr WrapFunction(FunctionType fn) { + return std::make_unique(std::move(fn)); + } + + template + static std::enable_if_t< + std::is_invocable_v, + std::unique_ptr> + WrapFunction(F&& function) { + return WrapFunction([function = std::forward(function)]( + const Function::InvokeContext& context) -> T { + return function(context.descriptor_pool(), context.message_factory(), + context.arena()); + }); + } + + template + static std::enable_if_t, + std::unique_ptr> + WrapFunction(F&& function) { + return WrapFunction([function = std::forward(function)]( + const Function::InvokeContext& context) -> T { + return function(); + }); + } + + static FunctionDescriptor CreateDescriptor(absl::string_view name, + bool receiver_style, + bool is_strict) { + return CreateDescriptor(name, receiver_style, + {is_strict, /*is_contextual=*/false}); + } + + static FunctionDescriptor CreateDescriptor( + absl::string_view name, bool receiver_style, + FunctionDescriptorOptions options = {}) { + return FunctionDescriptor(name, receiver_style, {}, options); + } + + private: + class UnaryFunctionImpl : public Function { + public: + explicit UnaryFunctionImpl(FunctionType fn) : fn_(std::move(fn)) {} + absl::StatusOr Invoke( + absl::Span args, + const Function::InvokeContext& context) const final { + if (args.size() != 0) { + return absl::InvalidArgumentError( + "unexpected number of arguments for nullary function"); + } + + if constexpr (std::is_same_v || + std::is_same_v>) { + return fn_(context); + } else { + T result = fn_(context); + + return runtime_internal::AdaptedToValueVisitor{}(std::move(result)); + } + } + + private: + FunctionType fn_; + }; +}; + +// Adapter class for generating CEL extension functions from a one argument +// function. +// +// See documentation for Binary Function adapter for general recommendations. +// +// Example Usage: +// double Invert(ValueManager&, double x) { +// return 1 / x; +// } +// +// { +// std::unique_ptr builder; +// +// CEL_RETURN_IF_ERROR( +// builder->GetRegistry()->Register( +// UnaryFunctionAdapter::CreateDescriptor("inv", +// /*receiver_style=*/false), +// UnaryFunctionAdapter::WrapFunction(&Invert))); +// } +// // example CEL expression +// inv(4) == 1/4 [true] +template +class UnaryFunctionAdapter : public RegisterHelper> { + public: + using FunctionType = + absl::AnyInvocable; + + static std::unique_ptr WrapFunction(FunctionType fn) { + return std::make_unique(std::move(fn)); + } + + template + static std::enable_if_t< + std::is_invocable_v, + std::unique_ptr> + WrapFunction(F&& function) { + return WrapFunction( + [function = std::forward(function)]( + U arg1, const Function::InvokeContext& context) -> T { + return function(arg1, context.descriptor_pool(), + context.message_factory(), context.arena()); + }); + } + + template + static std::enable_if_t, + std::unique_ptr> + WrapFunction(F&& function) { + return WrapFunction( + [function = std::forward(function)]( + U arg1, const Function::InvokeContext& context) -> T { + return function(arg1); + }); + } + + static FunctionDescriptor CreateDescriptor(absl::string_view name, + bool receiver_style, + bool is_strict) { + return CreateDescriptor( + name, receiver_style, + FunctionDescriptorOptions{is_strict, /*is_contextual=*/false}); + } + + static FunctionDescriptor CreateDescriptor( + absl::string_view name, bool receiver_style, + FunctionDescriptorOptions options = {}) { + return FunctionDescriptor(name, receiver_style, + {runtime_internal::AdaptedKind()}, options); + } + + private: + class UnaryFunctionImpl : public Function { + public: + explicit UnaryFunctionImpl(FunctionType fn) : fn_(std::move(fn)) {} + absl::StatusOr Invoke( + absl::Span args, + const Function::InvokeContext& context) const final { + using ArgTraits = runtime_internal::AdaptedTypeTraits; + if (args.size() != 1) { + return absl::InvalidArgumentError( + "unexpected number of arguments for unary function"); + } + typename ArgTraits::AssignableType arg1; + + CEL_RETURN_IF_ERROR( + runtime_internal::ValueToAdaptedVisitor{args[0]}(&arg1)); + if constexpr (std::is_same_v || + std::is_same_v>) { + return fn_(ArgTraits::ToArg(arg1), context); + } else { + T result = fn_(ArgTraits::ToArg(arg1), context); + + return runtime_internal::AdaptedToValueVisitor{}(std::move(result)); + } + } + + private: + FunctionType fn_; + }; +}; + +// Adapter class for generating CEL extension functions from a two argument +// function. Generates an implementation of the cel::Function interface that +// calls the function to wrap. +// +// Extension functions must distinguish between recoverable errors (error that +// should participate in CEL's error pruning) and unrecoverable errors (a non-ok +// absl::Status that stops evaluation). The function to wrap may return +// StatusOr to propagate a Status, or return a Value with an Error +// value to introduce a CEL error. +// +// To introduce an extension function that may accept any kind of CEL value as +// an argument, the wrapped function should use a Value parameter and +// check the type of the argument at evaluation time. +// +// Supported CEL to C++ type mappings: +// bool -> bool +// double -> double +// uint -> uint64_t +// int -> int64_t +// timestamp -> absl::Time +// duration -> absl::Duration +// +// Complex types may be referred to by cref or value. +// To return these, users should return a Value. +// any/dyn -> Value, const Value& +// string -> StringValue | const StringValue& +// bytes -> BytesValue | const BytesValue& +// list -> ListValue | const ListValue& +// map -> MapValue | const MapValue& +// struct -> StructValue | const StructValue& +// null -> NullValue | const NullValue& +// +// To intercept error and unknown arguments, users must use a non-strict +// overload with all arguments typed as any and check the kind of the +// Value argument. +// +// Example Usage: +// double SquareDifference(ValueManager&, double x, double y) { +// return x * x - y * y; +// } +// +// { +// RuntimeBuilder builder; +// // Initialize Expression builder with built-ins as needed. +// +// CEL_RETURN_IF_ERROR( +// builder.function_registry().Register( +// BinaryFunctionAdapter::CreateDescriptor( +// "sq_diff", /*receiver_style=*/false), +// BinaryFunctionAdapter::WrapFunction( +// &SquareDifference))); +// +// +// // Alternative shorthand +// // See RegisterHelper (template base class) for details. +// // runtime/register_function_helper.h +// auto status = BinaryFunctionAdapter:: +// RegisterGlobalOverload( +// "sq_diff", +// &SquareDifference, +// builder.function_registry()); +// CEL_RETURN_IF_ERROR(status); +// } +// +// example CEL expression: +// sq_diff(4, 3) == 7 [true] +// +template +class BinaryFunctionAdapter + : public RegisterHelper> { + public: + using FunctionType = + absl::AnyInvocable; + + static std::unique_ptr WrapFunction(FunctionType fn) { + return std::make_unique(std::move(fn)); + } + + template + static std::enable_if_t< + std::is_invocable_v, + std::unique_ptr> + WrapFunction(F&& function) { + return WrapFunction( + [function = std::forward(function)]( + U arg1, V arg2, const Function::InvokeContext& context) -> T { + return function(arg1, arg2, context.descriptor_pool(), + context.message_factory(), context.arena()); + }); + } + + template + static std::enable_if_t, + std::unique_ptr> + WrapFunction(F&& function) { + return WrapFunction( + [function = std::forward(function)]( + U arg1, V arg2, const Function::InvokeContext& context) -> T { + return function(arg1, arg2); + }); + } + + static FunctionDescriptor CreateDescriptor(absl::string_view name, + bool receiver_style, + bool is_strict) { + return CreateDescriptor(name, receiver_style, + {is_strict, /*is_contextual=*/false}); + } + + static FunctionDescriptor CreateDescriptor( + absl::string_view name, bool receiver_style, + FunctionDescriptorOptions options = {}) { + return FunctionDescriptor(name, receiver_style, + {runtime_internal::AdaptedKind(), + runtime_internal::AdaptedKind()}, + options); + } + + private: + class BinaryFunctionImpl : public Function { + public: + explicit BinaryFunctionImpl(FunctionType fn) : fn_(std::move(fn)) {} + absl::StatusOr Invoke( + absl::Span args, + const Function::InvokeContext& context) const final { + using Arg1Traits = runtime_internal::AdaptedTypeTraits; + using Arg2Traits = runtime_internal::AdaptedTypeTraits; + if (args.size() != 2) { + return absl::InvalidArgumentError( + "unexpected number of arguments for binary function"); + } + typename Arg1Traits::AssignableType arg1; + typename Arg2Traits::AssignableType arg2; + CEL_RETURN_IF_ERROR( + runtime_internal::ValueToAdaptedVisitor{args[0]}(&arg1)); + CEL_RETURN_IF_ERROR( + runtime_internal::ValueToAdaptedVisitor{args[1]}(&arg2)); + + if constexpr (std::is_same_v || + std::is_same_v>) { + return fn_(Arg1Traits::ToArg(arg1), Arg2Traits::ToArg(arg2), context); + } else { + T result = + fn_(Arg1Traits::ToArg(arg1), Arg2Traits::ToArg(arg2), context); + + return runtime_internal::AdaptedToValueVisitor{}(std::move(result)); + } + } + + private: + BinaryFunctionAdapter::FunctionType fn_; + }; +}; + +template +class TernaryFunctionAdapter + : public RegisterHelper> { + public: + using FunctionType = + absl::AnyInvocable; + + static std::unique_ptr WrapFunction(FunctionType fn) { + return std::make_unique(std::move(fn)); + } + + template + static std::enable_if_t< + std::is_invocable_v< + F, U, V, W, const google::protobuf::DescriptorPool* absl_nonnull, + google::protobuf::MessageFactory* absl_nonnull, google::protobuf::Arena* absl_nonnull>, + std::unique_ptr> + WrapFunction(F&& function) { + return WrapFunction([function = std::forward(function)]( + U arg1, V arg2, W arg3, + const Function::InvokeContext& context) -> T { + return function(arg1, arg2, arg3, context.descriptor_pool(), + context.message_factory(), context.arena()); + }); + } + + template + static std::enable_if_t, + std::unique_ptr> + WrapFunction(F&& function) { + return WrapFunction([function = std::forward(function)]( + U arg1, V arg2, W arg3, + const Function::InvokeContext& context) -> T { + return function(arg1, arg2, arg3); + }); + } + + static FunctionDescriptor CreateDescriptor(absl::string_view name, + bool receiver_style, + bool is_strict) { + return CreateDescriptor( + name, receiver_style, + FunctionDescriptorOptions{is_strict, /*is_contextual=*/false}); + } + + static FunctionDescriptor CreateDescriptor( + absl::string_view name, bool receiver_style, + FunctionDescriptorOptions options = {}) { + return FunctionDescriptor( + name, receiver_style, + {runtime_internal::AdaptedKind(), runtime_internal::AdaptedKind(), + runtime_internal::AdaptedKind()}, + options); + } + + private: + class TernaryFunctionImpl : public Function { + public: + explicit TernaryFunctionImpl(FunctionType fn) : fn_(std::move(fn)) {} + absl::StatusOr Invoke( + absl::Span args, + const Function::InvokeContext& context) const final { + using Arg1Traits = runtime_internal::AdaptedTypeTraits; + using Arg2Traits = runtime_internal::AdaptedTypeTraits; + using Arg3Traits = runtime_internal::AdaptedTypeTraits; + if (args.size() != 3) { + return absl::InvalidArgumentError( + "unexpected number of arguments for ternary function"); + } + typename Arg1Traits::AssignableType arg1; + typename Arg2Traits::AssignableType arg2; + typename Arg3Traits::AssignableType arg3; + CEL_RETURN_IF_ERROR( + runtime_internal::ValueToAdaptedVisitor{args[0]}(&arg1)); + CEL_RETURN_IF_ERROR( + runtime_internal::ValueToAdaptedVisitor{args[1]}(&arg2)); + CEL_RETURN_IF_ERROR( + runtime_internal::ValueToAdaptedVisitor{args[2]}(&arg3)); + + if constexpr (std::is_same_v || + std::is_same_v>) { + return fn_(Arg1Traits::ToArg(arg1), Arg2Traits::ToArg(arg2), + Arg3Traits::ToArg(arg3), context); + } else { + T result = fn_(Arg1Traits::ToArg(arg1), Arg2Traits::ToArg(arg2), + Arg3Traits::ToArg(arg3), context); + + return runtime_internal::AdaptedToValueVisitor{}(std::move(result)); + } + } + + private: + TernaryFunctionAdapter::FunctionType fn_; + }; +}; + +template +class QuaternaryFunctionAdapter + : public RegisterHelper> { + public: + using FunctionType = + absl::AnyInvocable; + + static std::unique_ptr WrapFunction(FunctionType fn) { + return std::make_unique(std::move(fn)); + } + + template + static std::enable_if_t< + std::is_invocable_v< + F, U, V, W, X, const google::protobuf::DescriptorPool* absl_nonnull, + google::protobuf::MessageFactory* absl_nonnull, google::protobuf::Arena* absl_nonnull>, + std::unique_ptr> + WrapFunction(F&& function) { + return WrapFunction([function = std::forward(function)]( + U arg1, V arg2, W arg3, X arg4, + const Function::InvokeContext& context) -> T { + return function(arg1, arg2, arg3, arg4, context.descriptor_pool(), + context.message_factory(), context.arena()); + }); + } + + template + static std::enable_if_t, + std::unique_ptr> + WrapFunction(F&& function) { + return WrapFunction([function = std::forward(function)]( + U arg1, V arg2, W arg3, X arg4, + const Function::InvokeContext& context) -> T { + return function(arg1, arg2, arg3, arg4); + }); + } + + static FunctionDescriptor CreateDescriptor(absl::string_view name, + bool receiver_style, + bool is_strict) { + return CreateDescriptor(name, receiver_style, + {is_strict, /*is_contextual=*/false}); + } + + static FunctionDescriptor CreateDescriptor( + absl::string_view name, bool receiver_style, + FunctionDescriptorOptions options = {}) { + return FunctionDescriptor( + name, receiver_style, + {runtime_internal::AdaptedKind(), runtime_internal::AdaptedKind(), + runtime_internal::AdaptedKind(), + runtime_internal::AdaptedKind()}, + options); + } + + private: + class QuaternaryFunctionImpl : public Function { + public: + explicit QuaternaryFunctionImpl(FunctionType fn) : fn_(std::move(fn)) {} + absl::StatusOr Invoke( + absl::Span args, + const Function::InvokeContext& context) const final { + using Arg1Traits = runtime_internal::AdaptedTypeTraits; + using Arg2Traits = runtime_internal::AdaptedTypeTraits; + using Arg3Traits = runtime_internal::AdaptedTypeTraits; + using Arg4Traits = runtime_internal::AdaptedTypeTraits; + if (args.size() != 4) { + return absl::InvalidArgumentError( + "unexpected number of arguments for quaternary function"); + } + typename Arg1Traits::AssignableType arg1; + typename Arg2Traits::AssignableType arg2; + typename Arg3Traits::AssignableType arg3; + typename Arg4Traits::AssignableType arg4; + CEL_RETURN_IF_ERROR( + runtime_internal::ValueToAdaptedVisitor{args[0]}(&arg1)); + CEL_RETURN_IF_ERROR( + runtime_internal::ValueToAdaptedVisitor{args[1]}(&arg2)); + CEL_RETURN_IF_ERROR( + runtime_internal::ValueToAdaptedVisitor{args[2]}(&arg3)); + CEL_RETURN_IF_ERROR( + runtime_internal::ValueToAdaptedVisitor{args[3]}(&arg4)); + + if constexpr (std::is_same_v || + std::is_same_v>) { + return fn_(Arg1Traits::ToArg(arg1), Arg2Traits::ToArg(arg2), + Arg3Traits::ToArg(arg3), Arg4Traits::ToArg(arg4), context); + } else { + T result = + fn_(Arg1Traits::ToArg(arg1), Arg2Traits::ToArg(arg2), + Arg3Traits::ToArg(arg3), Arg4Traits::ToArg(arg4), context); + + return runtime_internal::AdaptedToValueVisitor{}(std::move(result)); + } + } + + private: + QuaternaryFunctionAdapter::FunctionType fn_; + }; +}; + +// Primary template for n-ary adapter. +template +class NaryFunctionAdapter; + +template +class NaryFunctionAdapter : public NullaryFunctionAdapter {}; + +template +class NaryFunctionAdapter : public UnaryFunctionAdapter {}; + +template +class NaryFunctionAdapter : public BinaryFunctionAdapter {}; + +template +class NaryFunctionAdapter + : public TernaryFunctionAdapter {}; + +template +class NaryFunctionAdapter + : public QuaternaryFunctionAdapter {}; + +// N-ary function adapter. +// +// Prefer using one of the specific count adapters above for readability and +// better error messages. +template +class NaryFunctionAdapter + : public RegisterHelper> { + public: + using FunctionType = + absl::AnyInvocable; + + static FunctionDescriptor CreateDescriptor(absl::string_view name, + bool receiver_style, + bool is_strict) { + return CreateDescriptor(name, receiver_style, + {is_strict, /*is_contextual=*/false}); + } + + static FunctionDescriptor CreateDescriptor( + absl::string_view name, bool receiver_style, + FunctionDescriptorOptions options = {}) { + return FunctionDescriptor(name, receiver_style, + {runtime_internal::AdaptedKind()...}, + options); + } + + static std::unique_ptr WrapFunction(FunctionType fn) { + return std::make_unique(std::move(fn)); + } + + template + static std::enable_if_t< + std::is_invocable_v< + F, Args..., const google::protobuf::DescriptorPool* absl_nonnull, + google::protobuf::MessageFactory* absl_nonnull, google::protobuf::Arena* absl_nonnull>, + std::unique_ptr> + WrapFunction(F&& function) { + return WrapFunction( + [function = std::forward(function)]( + Args... args, const Function::InvokeContext& context) -> T { + return function(args..., context.descriptor_pool(), + context.message_factory(), context.arena()); + }); + } + + template + static std::enable_if_t, + std::unique_ptr> + WrapFunction(F&& function) { + return WrapFunction( + [function = std::forward(function)]( + Args... args, const Function::InvokeContext& context) -> T { + return function(args...); + }); + } + + private: + class NaryFunctionImpl : public Function { + private: + using ArgBuffer = std::tuple< + typename runtime_internal::AdaptedTypeTraits::AssignableType...>; + + public: + explicit NaryFunctionImpl(FunctionType fn) : fn_(std::move(fn)) {} + absl::StatusOr Invoke( + absl::Span args, + const Function::InvokeContext& context) const final { + if (args.size() != sizeof...(Args)) { + return absl::InvalidArgumentError( + absl::StrCat("unexpected number of arguments for ", sizeof...(Args), + "-ary function")); + } + ArgBuffer arg_buffer; + CEL_RETURN_IF_ERROR( + runtime_internal::AdaptHelper::Apply(args, arg_buffer)); + if constexpr (std::is_same_v || + std::is_same_v>) { + return runtime_internal::ToArgsHelper::template Apply( + fn_, arg_buffer, context); + } else { + T result = runtime_internal::ToArgsHelper::template Apply( + fn_, arg_buffer, context); + return runtime_internal::AdaptedToValueVisitor{}(std::move(result)); + } + } + + private: + FunctionType fn_; + }; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_RUNTIME_FUNCTION_ADAPTER_H_ diff --git a/runtime/function_adapter_test.cc b/runtime/function_adapter_test.cc new file mode 100644 index 000000000..910020fdf --- /dev/null +++ b/runtime/function_adapter_test.cc @@ -0,0 +1,864 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "runtime/function_adapter.h" + +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "common/function_descriptor.h" +#include "common/kind.h" +#include "common/value.h" +#include "common/value_testing.h" +#include "internal/testing.h" +#include "runtime/function.h" + +namespace cel { +namespace { + +using ::absl_testing::StatusIs; +using ::testing::ElementsAre; +using ::testing::HasSubstr; +using ::testing::IsEmpty; + +class FunctionAdapterTest : public common_internal::ValueTest<> { + using Base = common_internal::ValueTest<>; + + public: + FunctionAdapterTest() + : Base(), test_context_(descriptor_pool(), message_factory(), arena()) {} + + const Function::InvokeContext& test_invoke_context() const { + return test_context_; + } + + protected: + cel::Function::InvokeContext test_context_; +}; + +TEST_F(FunctionAdapterTest, UnaryFunctionAdapterWrapFunctionOldOverload) { + using FunctionAdapter = UnaryFunctionAdapter; + + std::unique_ptr wrapped = FunctionAdapter::WrapFunction( + [](const StringValue& x, + const Function::InvokeContext& context) -> StringValue { + std::string buf; + absl::string_view s = x.ToStringView(&buf); + buf = absl::StrCat("pre_", s); + return StringValue::From(std::move(buf), context.arena()); + }); + + std::vector args{StringValue::Wrap(absl::string_view("foo"), arena())}; + ASSERT_OK_AND_ASSIGN( + auto result, + wrapped->Invoke(args, descriptor_pool(), message_factory(), arena())); + + EXPECT_THAT(result, test::StringValueIs("pre_foo")); + ASSERT_OK_AND_ASSIGN(result, wrapped->Invoke(args, test_invoke_context())); + + EXPECT_THAT(result, test::StringValueIs("pre_foo")); +} + +TEST_F(FunctionAdapterTest, UnaryFunctionAdapterWrapFunctionInt) { + using FunctionAdapter = UnaryFunctionAdapter; + + std::unique_ptr wrapped = + FunctionAdapter::WrapFunction([](int64_t x) -> int64_t { return x + 2; }); + + std::vector args{IntValue(40)}; + ASSERT_OK_AND_ASSIGN(auto result, + wrapped->Invoke(args, test_invoke_context())); + + ASSERT_TRUE(result->Is()); + EXPECT_EQ(result.GetInt().NativeValue(), 42); +} + +TEST_F(FunctionAdapterTest, UnaryFunctionAdapterWrapFunctionDouble) { + using FunctionAdapter = UnaryFunctionAdapter; + std::unique_ptr wrapped = + FunctionAdapter::WrapFunction([](double x) -> double { return x * 2; }); + + std::vector args{DoubleValue(40.0)}; + ASSERT_OK_AND_ASSIGN(auto result, + wrapped->Invoke(args, test_invoke_context())); + + ASSERT_TRUE(result->Is()); + EXPECT_EQ(result.GetDouble().NativeValue(), 80.0); +} + +TEST_F(FunctionAdapterTest, UnaryFunctionAdapterWrapFunctionUint) { + using FunctionAdapter = UnaryFunctionAdapter; + std::unique_ptr wrapped = FunctionAdapter::WrapFunction( + [](uint64_t x) -> uint64_t { return x - 2; }); + + std::vector args{UintValue(44)}; + ASSERT_OK_AND_ASSIGN(auto result, + wrapped->Invoke(args, test_invoke_context())); + + ASSERT_TRUE(result->Is()); + EXPECT_EQ(result.GetUint().NativeValue(), 42); +} + +TEST_F(FunctionAdapterTest, UnaryFunctionAdapterWrapFunctionBool) { + using FunctionAdapter = UnaryFunctionAdapter; + std::unique_ptr wrapped = + FunctionAdapter::WrapFunction([](bool x) -> bool { return !x; }); + + std::vector args{BoolValue(true)}; + ASSERT_OK_AND_ASSIGN(auto result, + wrapped->Invoke(args, test_invoke_context())); + + ASSERT_TRUE(result->Is()); + EXPECT_EQ(result.GetBool().NativeValue(), false); +} + +TEST_F(FunctionAdapterTest, UnaryFunctionAdapterWrapFunctionTimestamp) { + using FunctionAdapter = UnaryFunctionAdapter; + std::unique_ptr wrapped = FunctionAdapter::WrapFunction( + [](absl::Time x) -> absl::Time { return x + absl::Minutes(1); }); + + std::vector args; + args.emplace_back() = TimestampValue(absl::UnixEpoch()); + ASSERT_OK_AND_ASSIGN(auto result, + wrapped->Invoke(args, test_invoke_context())); + + ASSERT_TRUE(result->Is()); + EXPECT_EQ(result.GetTimestamp().NativeValue(), + absl::UnixEpoch() + absl::Minutes(1)); +} + +TEST_F(FunctionAdapterTest, UnaryFunctionAdapterWrapFunctionDuration) { + using FunctionAdapter = UnaryFunctionAdapter; + std::unique_ptr wrapped = FunctionAdapter::WrapFunction( + [](absl::Duration x) -> absl::Duration { return x + absl::Seconds(2); }); + + std::vector args; + args.emplace_back() = DurationValue(absl::Seconds(6)); + ASSERT_OK_AND_ASSIGN(auto result, + wrapped->Invoke(args, test_invoke_context())); + + ASSERT_TRUE(result->Is()); + EXPECT_EQ(result.GetDuration().NativeValue(), absl::Seconds(8)); +} + +TEST_F(FunctionAdapterTest, UnaryFunctionAdapterWrapFunctionString) { + using FunctionAdapter = UnaryFunctionAdapter; + std::unique_ptr wrapped = + FunctionAdapter::WrapFunction([](const StringValue& x) -> StringValue { + return StringValue("pre_" + x.ToString()); + }); + + std::vector args; + args.emplace_back() = StringValue("string"); + ASSERT_OK_AND_ASSIGN(auto result, + wrapped->Invoke(args, test_invoke_context())); + + ASSERT_TRUE(result->Is()); + EXPECT_EQ(result.GetString().ToString(), "pre_string"); +} + +TEST_F(FunctionAdapterTest, UnaryFunctionAdapterWrapFunctionBytes) { + using FunctionAdapter = UnaryFunctionAdapter; + std::unique_ptr wrapped = + FunctionAdapter::WrapFunction([](const BytesValue& x) -> BytesValue { + return BytesValue("pre_" + x.ToString()); + }); + + std::vector args; + args.emplace_back() = BytesValue("bytes"); + ASSERT_OK_AND_ASSIGN(auto result, + wrapped->Invoke(args, test_invoke_context())); + + ASSERT_TRUE(result->Is()); + EXPECT_EQ(result.GetBytes().ToString(), "pre_bytes"); +} + +TEST_F(FunctionAdapterTest, UnaryFunctionAdapterWrapFunctionAny) { + using FunctionAdapter = UnaryFunctionAdapter; + std::unique_ptr wrapped = FunctionAdapter::WrapFunction( + [](const Value& x) -> uint64_t { return x.GetUint().NativeValue() - 2; }); + + std::vector args{UintValue(44)}; + ASSERT_OK_AND_ASSIGN(auto result, + wrapped->Invoke(args, test_invoke_context())); + + ASSERT_TRUE(result->Is()); + EXPECT_EQ(result.GetUint().NativeValue(), 42); +} + +TEST_F(FunctionAdapterTest, UnaryFunctionAdapterWrapFunctionReturnError) { + using FunctionAdapter = UnaryFunctionAdapter; + std::unique_ptr wrapped = + FunctionAdapter::WrapFunction([](uint64_t x) -> Value { + return ErrorValue(absl::InvalidArgumentError("test_error")); + }); + + std::vector args{UintValue(44)}; + ASSERT_OK_AND_ASSIGN(auto result, + wrapped->Invoke(args, test_invoke_context())); + + ASSERT_TRUE(result->Is()); + EXPECT_THAT(result.GetError().NativeValue(), + StatusIs(absl::StatusCode::kInvalidArgument, "test_error")); +} + +TEST_F(FunctionAdapterTest, UnaryFunctionAdapterWrapFunctionPropagateStatus) { + using FunctionAdapter = + UnaryFunctionAdapter, uint64_t>; + std::unique_ptr wrapped = + FunctionAdapter::WrapFunction([](uint64_t x) -> absl::StatusOr { + // Returning a status directly stops CEL evaluation and + // immediately returns. + return absl::InternalError("test_error"); + }); + + std::vector args{UintValue(44)}; + EXPECT_THAT(wrapped->Invoke(args, test_invoke_context()), + StatusIs(absl::StatusCode::kInternal, "test_error")); +} + +TEST_F(FunctionAdapterTest, + UnaryFunctionAdapterWrapFunctionReturnStatusOrValue) { + using FunctionAdapter = + UnaryFunctionAdapter, uint64_t>; + std::unique_ptr wrapped = FunctionAdapter::WrapFunction( + [](uint64_t x) -> absl::StatusOr { return x; }); + + std::vector args{UintValue(44)}; + ASSERT_OK_AND_ASSIGN(Value result, + wrapped->Invoke(args, test_invoke_context())); + EXPECT_EQ(result.GetUint().NativeValue(), 44); +} + +TEST_F(FunctionAdapterTest, + UnaryFunctionAdapterWrapFunctionWrongArgCountError) { + using FunctionAdapter = + UnaryFunctionAdapter, uint64_t>; + std::unique_ptr wrapped = FunctionAdapter::WrapFunction( + [](uint64_t x) -> absl::StatusOr { return 42; }); + + std::vector args{UintValue(44), UintValue(43)}; + EXPECT_THAT(wrapped->Invoke(args, test_invoke_context()), + StatusIs(absl::StatusCode::kInvalidArgument, + "unexpected number of arguments for unary function")); +} + +TEST_F(FunctionAdapterTest, UnaryFunctionAdapterWrapFunctionWrongArgTypeError) { + using FunctionAdapter = + UnaryFunctionAdapter, uint64_t>; + std::unique_ptr wrapped = FunctionAdapter::WrapFunction( + [](uint64_t x) -> absl::StatusOr { return 42; }); + + std::vector args{DoubleValue(44)}; + EXPECT_THAT(wrapped->Invoke(args, test_invoke_context()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("expected uint value"))); +} + +TEST_F(FunctionAdapterTest, UnaryFunctionAdapterCreateDescriptorInt) { + FunctionDescriptor desc = + UnaryFunctionAdapter, int64_t>::CreateDescriptor( + "Increment", false); + + EXPECT_EQ(desc.name(), "Increment"); + EXPECT_TRUE(desc.is_strict()); + EXPECT_FALSE(desc.receiver_style()); + EXPECT_THAT(desc.types(), ElementsAre(Kind::kInt64)); +} + +TEST_F(FunctionAdapterTest, UnaryFunctionAdapterCreateDescriptorDouble) { + FunctionDescriptor desc = + UnaryFunctionAdapter, double>::CreateDescriptor( + "Mult2", true); + + EXPECT_EQ(desc.name(), "Mult2"); + EXPECT_TRUE(desc.is_strict()); + EXPECT_TRUE(desc.receiver_style()); + EXPECT_THAT(desc.types(), ElementsAre(Kind::kDouble)); +} + +TEST_F(FunctionAdapterTest, UnaryFunctionAdapterCreateDescriptorUint) { + FunctionDescriptor desc = + UnaryFunctionAdapter, uint64_t>::CreateDescriptor( + "Increment", false); + + EXPECT_EQ(desc.name(), "Increment"); + EXPECT_TRUE(desc.is_strict()); + EXPECT_FALSE(desc.receiver_style()); + EXPECT_THAT(desc.types(), ElementsAre(Kind::kUint64)); +} + +TEST_F(FunctionAdapterTest, UnaryFunctionAdapterCreateDescriptorBool) { + FunctionDescriptor desc = + UnaryFunctionAdapter, bool>::CreateDescriptor( + "Not", false); + + EXPECT_EQ(desc.name(), "Not"); + EXPECT_TRUE(desc.is_strict()); + EXPECT_FALSE(desc.receiver_style()); + EXPECT_THAT(desc.types(), ElementsAre(Kind::kBool)); +} + +TEST_F(FunctionAdapterTest, UnaryFunctionAdapterCreateDescriptorTimestamp) { + FunctionDescriptor desc = + UnaryFunctionAdapter, absl::Time>::CreateDescriptor( + "AddMinute", false); + + EXPECT_EQ(desc.name(), "AddMinute"); + EXPECT_TRUE(desc.is_strict()); + EXPECT_FALSE(desc.receiver_style()); + EXPECT_THAT(desc.types(), ElementsAre(Kind::kTimestamp)); +} + +TEST_F(FunctionAdapterTest, UnaryFunctionAdapterCreateDescriptorDuration) { + FunctionDescriptor desc = + UnaryFunctionAdapter, + absl::Duration>::CreateDescriptor("AddFiveSeconds", + false); + + EXPECT_EQ(desc.name(), "AddFiveSeconds"); + EXPECT_TRUE(desc.is_strict()); + EXPECT_FALSE(desc.receiver_style()); + EXPECT_THAT(desc.types(), ElementsAre(Kind::kDuration)); +} + +TEST_F(FunctionAdapterTest, UnaryFunctionAdapterCreateDescriptorString) { + FunctionDescriptor desc = + UnaryFunctionAdapter, + StringValue>::CreateDescriptor("Prepend", false); + + EXPECT_EQ(desc.name(), "Prepend"); + EXPECT_TRUE(desc.is_strict()); + EXPECT_FALSE(desc.receiver_style()); + EXPECT_THAT(desc.types(), ElementsAre(Kind::kString)); +} + +TEST_F(FunctionAdapterTest, UnaryFunctionAdapterCreateDescriptorBytes) { + FunctionDescriptor desc = + UnaryFunctionAdapter, BytesValue>::CreateDescriptor( + "Prepend", false); + + EXPECT_EQ(desc.name(), "Prepend"); + EXPECT_TRUE(desc.is_strict()); + EXPECT_FALSE(desc.receiver_style()); + EXPECT_THAT(desc.types(), ElementsAre(Kind::kBytes)); +} + +TEST_F(FunctionAdapterTest, UnaryFunctionAdapterCreateDescriptorAny) { + FunctionDescriptor desc = + UnaryFunctionAdapter, Value>::CreateDescriptor( + "Increment", false); + + EXPECT_EQ(desc.name(), "Increment"); + EXPECT_TRUE(desc.is_strict()); + EXPECT_FALSE(desc.receiver_style()); + EXPECT_THAT(desc.types(), ElementsAre(Kind::kAny)); +} + +TEST_F(FunctionAdapterTest, UnaryFunctionAdapterCreateDescriptorNonStrict) { + FunctionDescriptor desc = + UnaryFunctionAdapter, Value>::CreateDescriptor( + "Increment", false, + /*is_strict=*/false); + + EXPECT_EQ(desc.name(), "Increment"); + EXPECT_FALSE(desc.is_strict()); + EXPECT_FALSE(desc.receiver_style()); + EXPECT_THAT(desc.types(), ElementsAre(Kind::kAny)); +} + +TEST_F(FunctionAdapterTest, BinaryFunctionAdapterWrapFunctionInt) { + using FunctionAdapter = BinaryFunctionAdapter; + std::unique_ptr wrapped = FunctionAdapter::WrapFunction( + [](int64_t x, int64_t y) -> int64_t { return x + y; }); + + std::vector args{IntValue(21), IntValue(21)}; + ASSERT_OK_AND_ASSIGN(auto result, + wrapped->Invoke(args, test_invoke_context())); + + ASSERT_TRUE(result->Is()); + EXPECT_EQ(result.GetInt().NativeValue(), 42); +} + +TEST_F(FunctionAdapterTest, BinaryFunctionAdapterWrapFunctionDouble) { + using FunctionAdapter = BinaryFunctionAdapter; + std::unique_ptr wrapped = FunctionAdapter::WrapFunction( + [](double x, double y) -> double { return x * y; }); + + std::vector args{DoubleValue(40.0), DoubleValue(2.0)}; + ASSERT_OK_AND_ASSIGN(auto result, + wrapped->Invoke(args, test_invoke_context())); + + ASSERT_TRUE(result->Is()); + EXPECT_EQ(result.GetDouble().NativeValue(), 80.0); +} + +TEST_F(FunctionAdapterTest, BinaryFunctionAdapterWrapFunctionUint) { + using FunctionAdapter = BinaryFunctionAdapter; + std::unique_ptr wrapped = FunctionAdapter::WrapFunction( + [](uint64_t x, uint64_t y) -> uint64_t { return x - y; }); + + std::vector args{UintValue(44), UintValue(2)}; + ASSERT_OK_AND_ASSIGN(auto result, + wrapped->Invoke(args, test_invoke_context())); + + ASSERT_TRUE(result->Is()); + EXPECT_EQ(result.GetUint().NativeValue(), 42); +} + +TEST_F(FunctionAdapterTest, BinaryFunctionAdapterWrapFunctionBool) { + using FunctionAdapter = BinaryFunctionAdapter; + std::unique_ptr wrapped = FunctionAdapter::WrapFunction( + [](bool x, bool y) -> bool { return x != y; }); + + std::vector args{BoolValue(false), BoolValue(true)}; + ASSERT_OK_AND_ASSIGN(auto result, + wrapped->Invoke(args, test_invoke_context())); + + ASSERT_TRUE(result->Is()); + EXPECT_EQ(result.GetBool().NativeValue(), true); +} + +TEST_F(FunctionAdapterTest, BinaryFunctionAdapterWrapFunctionTimestamp) { + using FunctionAdapter = + BinaryFunctionAdapter; + std::unique_ptr wrapped = FunctionAdapter::WrapFunction( + [](absl::Time x, absl::Time y) -> absl::Time { return x > y ? x : y; }); + + std::vector args; + args.emplace_back() = TimestampValue(absl::UnixEpoch() + absl::Seconds(1)); + args.emplace_back() = TimestampValue(absl::UnixEpoch() + absl::Seconds(2)); + + ASSERT_OK_AND_ASSIGN(auto result, + wrapped->Invoke(args, test_invoke_context())); + + ASSERT_TRUE(result->Is()); + EXPECT_EQ(result.GetTimestamp().NativeValue(), + absl::UnixEpoch() + absl::Seconds(2)); +} + +TEST_F(FunctionAdapterTest, BinaryFunctionAdapterWrapFunctionDuration) { + using FunctionAdapter = + BinaryFunctionAdapter; + std::unique_ptr wrapped = FunctionAdapter::WrapFunction( + [](absl::Duration x, absl::Duration y) -> absl::Duration { + return x > y ? x : y; + }); + + std::vector args; + args.emplace_back() = DurationValue(absl::Seconds(5)); + args.emplace_back() = DurationValue(absl::Seconds(2)); + + ASSERT_OK_AND_ASSIGN(auto result, + wrapped->Invoke(args, test_invoke_context())); + + ASSERT_TRUE(result->Is()); + EXPECT_EQ(result.GetDuration().NativeValue(), absl::Seconds(5)); +} + +TEST_F(FunctionAdapterTest, BinaryFunctionAdapterWrapFunctionString) { + using FunctionAdapter = + BinaryFunctionAdapter, const StringValue&, + const StringValue&>; + std::unique_ptr wrapped = FunctionAdapter::WrapFunction( + [](const StringValue& x, + const StringValue& y) -> absl::StatusOr { + return StringValue(x.ToString() + y.ToString()); + }); + + std::vector args; + args.emplace_back() = StringValue("abc"); + args.emplace_back() = StringValue("def"); + + ASSERT_OK_AND_ASSIGN(auto result, + wrapped->Invoke(args, test_invoke_context())); + + ASSERT_TRUE(result->Is()); + EXPECT_EQ(result.GetString().ToString(), "abcdef"); +} + +TEST_F(FunctionAdapterTest, BinaryFunctionAdapterWrapFunctionBytes) { + using FunctionAdapter = + BinaryFunctionAdapter, const BytesValue&, + const BytesValue&>; + std::unique_ptr wrapped = FunctionAdapter::WrapFunction( + [](const BytesValue& x, + const BytesValue& y) -> absl::StatusOr { + return BytesValue(x.ToString() + y.ToString()); + }); + + std::vector args; + args.emplace_back() = BytesValue("abc"); + args.emplace_back() = BytesValue("def"); + + ASSERT_OK_AND_ASSIGN(auto result, + wrapped->Invoke(args, test_invoke_context())); + + ASSERT_TRUE(result->Is()); + EXPECT_EQ(result.GetBytes().ToString(), "abcdef"); +} + +TEST_F(FunctionAdapterTest, BinaryFunctionAdapterWrapFunctionAny) { + using FunctionAdapter = BinaryFunctionAdapter; + std::unique_ptr wrapped = FunctionAdapter::WrapFunction( + [](const Value& x, const Value& y) -> uint64_t { + return x.GetUint().NativeValue() - + static_cast(y.GetDouble().NativeValue()); + }); + + std::vector args{UintValue(44), DoubleValue(2)}; + ASSERT_OK_AND_ASSIGN(auto result, + wrapped->Invoke(args, test_invoke_context())); + + ASSERT_TRUE(result->Is()); + EXPECT_EQ(result.GetUint().NativeValue(), 42); +} + +TEST_F(FunctionAdapterTest, BinaryFunctionAdapterWrapFunctionReturnError) { + using FunctionAdapter = BinaryFunctionAdapter; + std::unique_ptr wrapped = + FunctionAdapter::WrapFunction([](int64_t x, uint64_t y) -> Value { + return ErrorValue(absl::InvalidArgumentError("test_error")); + }); + + std::vector args{IntValue(44), UintValue(44)}; + ASSERT_OK_AND_ASSIGN(auto result, + wrapped->Invoke(args, test_invoke_context())); + + ASSERT_TRUE(result->Is()); + EXPECT_THAT(result.GetError().NativeValue(), + StatusIs(absl::StatusCode::kInvalidArgument, "test_error")); +} + +TEST_F(FunctionAdapterTest, BinaryFunctionAdapterWrapFunctionPropagateStatus) { + using FunctionAdapter = + BinaryFunctionAdapter, int64_t, uint64_t>; + std::unique_ptr wrapped = FunctionAdapter::WrapFunction( + [](int64_t, uint64_t x) -> absl::StatusOr { + // Returning a status directly stops CEL evaluation and + // immediately returns. + return absl::InternalError("test_error"); + }); + + std::vector args{IntValue(43), UintValue(44)}; + EXPECT_THAT(wrapped->Invoke(args, test_invoke_context()), + StatusIs(absl::StatusCode::kInternal, "test_error")); +} + +TEST_F(FunctionAdapterTest, + BinaryFunctionAdapterWrapFunctionWrongArgCountError) { + using FunctionAdapter = + BinaryFunctionAdapter, uint64_t, double>; + std::unique_ptr wrapped = FunctionAdapter::WrapFunction( + [](uint64_t x, double y) -> absl::StatusOr { return 42; }); + + std::vector args{UintValue(44)}; + EXPECT_THAT(wrapped->Invoke(args, test_invoke_context()), + StatusIs(absl::StatusCode::kInvalidArgument, + "unexpected number of arguments for binary function")); +} + +TEST_F(FunctionAdapterTest, + BinaryFunctionAdapterWrapFunctionWrongArgTypeError) { + using FunctionAdapter = + BinaryFunctionAdapter, uint64_t, uint64_t>; + std::unique_ptr wrapped = FunctionAdapter::WrapFunction( + [](int64_t x, int64_t y) -> absl::StatusOr { return 42; }); + + std::vector args{DoubleValue(44), DoubleValue(44)}; + EXPECT_THAT(wrapped->Invoke(args, test_invoke_context()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("expected uint value"))); +} + +TEST_F(FunctionAdapterTest, BinaryFunctionAdapterCreateDescriptorInt) { + FunctionDescriptor desc = + BinaryFunctionAdapter, int64_t, + int64_t>::CreateDescriptor("Add", false); + + EXPECT_EQ(desc.name(), "Add"); + EXPECT_TRUE(desc.is_strict()); + EXPECT_FALSE(desc.receiver_style()); + EXPECT_THAT(desc.types(), ElementsAre(Kind::kInt64, Kind::kInt64)); +} + +TEST_F(FunctionAdapterTest, BinaryFunctionAdapterCreateDescriptorDouble) { + FunctionDescriptor desc = + BinaryFunctionAdapter, double, + double>::CreateDescriptor("Mult", true); + + EXPECT_EQ(desc.name(), "Mult"); + EXPECT_TRUE(desc.is_strict()); + EXPECT_TRUE(desc.receiver_style()); + EXPECT_THAT(desc.types(), ElementsAre(Kind::kDouble, Kind::kDouble)); +} + +TEST_F(FunctionAdapterTest, BinaryFunctionAdapterCreateDescriptorUint) { + FunctionDescriptor desc = + BinaryFunctionAdapter, uint64_t, + uint64_t>::CreateDescriptor("Add", false); + + EXPECT_EQ(desc.name(), "Add"); + EXPECT_TRUE(desc.is_strict()); + EXPECT_FALSE(desc.receiver_style()); + EXPECT_THAT(desc.types(), ElementsAre(Kind::kUint64, Kind::kUint64)); +} + +TEST_F(FunctionAdapterTest, BinaryFunctionAdapterCreateDescriptorBool) { + FunctionDescriptor desc = + BinaryFunctionAdapter, bool, + bool>::CreateDescriptor("Xor", false); + + EXPECT_EQ(desc.name(), "Xor"); + EXPECT_TRUE(desc.is_strict()); + EXPECT_FALSE(desc.receiver_style()); + EXPECT_THAT(desc.types(), ElementsAre(Kind::kBool, Kind::kBool)); +} + +TEST_F(FunctionAdapterTest, BinaryFunctionAdapterCreateDescriptorTimestamp) { + FunctionDescriptor desc = + BinaryFunctionAdapter, absl::Time, + absl::Time>::CreateDescriptor("Max", false); + + EXPECT_EQ(desc.name(), "Max"); + EXPECT_TRUE(desc.is_strict()); + EXPECT_FALSE(desc.receiver_style()); + EXPECT_THAT(desc.types(), ElementsAre(Kind::kTimestamp, Kind::kTimestamp)); +} + +TEST_F(FunctionAdapterTest, BinaryFunctionAdapterCreateDescriptorDuration) { + FunctionDescriptor desc = + BinaryFunctionAdapter, absl::Duration, + absl::Duration>::CreateDescriptor("Max", false); + + EXPECT_EQ(desc.name(), "Max"); + EXPECT_TRUE(desc.is_strict()); + EXPECT_FALSE(desc.receiver_style()); + EXPECT_THAT(desc.types(), ElementsAre(Kind::kDuration, Kind::kDuration)); +} + +TEST_F(FunctionAdapterTest, BinaryFunctionAdapterCreateDescriptorString) { + FunctionDescriptor desc = + BinaryFunctionAdapter, StringValue, + StringValue>::CreateDescriptor("Concat", false); + + EXPECT_EQ(desc.name(), "Concat"); + EXPECT_TRUE(desc.is_strict()); + EXPECT_FALSE(desc.receiver_style()); + EXPECT_THAT(desc.types(), ElementsAre(Kind::kString, Kind::kString)); +} + +TEST_F(FunctionAdapterTest, BinaryFunctionAdapterCreateDescriptorBytes) { + FunctionDescriptor desc = + BinaryFunctionAdapter, BytesValue, + BytesValue>::CreateDescriptor("Concat", false); + + EXPECT_EQ(desc.name(), "Concat"); + EXPECT_TRUE(desc.is_strict()); + EXPECT_FALSE(desc.receiver_style()); + EXPECT_THAT(desc.types(), ElementsAre(Kind::kBytes, Kind::kBytes)); +} + +TEST_F(FunctionAdapterTest, BinaryFunctionAdapterCreateDescriptorAny) { + FunctionDescriptor desc = + BinaryFunctionAdapter, Value, + Value>::CreateDescriptor("Add", false); + EXPECT_EQ(desc.name(), "Add"); + EXPECT_TRUE(desc.is_strict()); + EXPECT_FALSE(desc.receiver_style()); + EXPECT_THAT(desc.types(), ElementsAre(Kind::kAny, Kind::kAny)); +} + +TEST_F(FunctionAdapterTest, BinaryFunctionAdapterCreateDescriptorNonStrict) { + FunctionDescriptor desc = + BinaryFunctionAdapter, Value, + Value>::CreateDescriptor("Add", false, false); + EXPECT_EQ(desc.name(), "Add"); + EXPECT_FALSE(desc.is_strict()); + EXPECT_FALSE(desc.receiver_style()); + EXPECT_THAT(desc.types(), ElementsAre(Kind::kAny, Kind::kAny)); +} + +TEST_F(FunctionAdapterTest, NaryFunctionAdapterCreateDescriptor0Args) { + FunctionDescriptor desc = + NullaryFunctionAdapter>::CreateDescriptor( + "ZeroArgs", false); + + EXPECT_EQ(desc.name(), "ZeroArgs"); + EXPECT_TRUE(desc.is_strict()); + EXPECT_FALSE(desc.receiver_style()); + EXPECT_THAT(desc.types(), IsEmpty()); +} + +TEST_F(FunctionAdapterTest, NaryFunctionAdapterWrapFunction0Args) { + std::unique_ptr fn = + NullaryFunctionAdapter>::WrapFunction( + []() { return StringValue("abc"); }); + + ASSERT_OK_AND_ASSIGN(auto result, fn->Invoke({}, descriptor_pool(), + message_factory(), arena())); + ASSERT_TRUE(result->Is()); + EXPECT_EQ(result.GetString().ToString(), "abc"); +} + +TEST_F(FunctionAdapterTest, NaryFunctionAdapterCreateDescriptor3Args) { + FunctionDescriptor desc = TernaryFunctionAdapter< + absl::StatusOr, int64_t, bool, + const StringValue&>::CreateDescriptor("MyFormatter", false); + + EXPECT_EQ(desc.name(), "MyFormatter"); + EXPECT_TRUE(desc.is_strict()); + EXPECT_FALSE(desc.receiver_style()); + EXPECT_THAT(desc.types(), + ElementsAre(Kind::kInt64, Kind::kBool, Kind::kString)); +} + +TEST_F(FunctionAdapterTest, NaryFunctionAdapterWrapFunction3Args) { + std::unique_ptr fn = NaryFunctionAdapter< + absl::StatusOr, int64_t, bool, + const StringValue&>::WrapFunction([](int64_t int_val, bool bool_val, + const StringValue& string_val) + -> absl::StatusOr { + return StringValue(absl::StrCat(int_val, "_", (bool_val ? "true" : "false"), + "_", string_val.ToString())); + }); + + std::vector args{IntValue(42), BoolValue(false)}; + args.emplace_back() = StringValue("abcd"); + ASSERT_OK_AND_ASSIGN(auto result, fn->Invoke(args, descriptor_pool(), + message_factory(), arena())); + ASSERT_TRUE(result->Is()); + EXPECT_EQ(result.GetString().ToString(), "42_false_abcd"); +} + +TEST_F(FunctionAdapterTest, NaryFunctionAdapterWrapFunction3ArgsBadArgType) { + std::unique_ptr fn = NaryFunctionAdapter< + absl::StatusOr, int64_t, bool, + const StringValue&>::WrapFunction([](int64_t int_val, bool bool_val, + const StringValue& string_val) + -> absl::StatusOr { + return StringValue(absl::StrCat(int_val, "_", (bool_val ? "true" : "false"), + "_", string_val.ToString())); + }); + + std::vector args{IntValue(42), BoolValue(false)}; + args.emplace_back() = TimestampValue(absl::UnixEpoch()); + EXPECT_THAT(fn->Invoke(args, test_invoke_context()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("expected string value"))); +} + +TEST_F(FunctionAdapterTest, NaryFunctionAdapterWrapFunction3ArgsBadArgCount) { + std::unique_ptr fn = NaryFunctionAdapter< + absl::StatusOr, int64_t, bool, + const StringValue&>::WrapFunction([](int64_t int_val, bool bool_val, + const StringValue& string_val) + -> absl::StatusOr { + return StringValue(absl::StrCat(int_val, "_", (bool_val ? "true" : "false"), + "_", string_val.ToString())); + }); + + std::vector args{IntValue(42), BoolValue(false)}; + EXPECT_THAT(fn->Invoke(args, test_invoke_context()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("unexpected number of arguments"))); +} + +TEST_F(FunctionAdapterTest, NaryFunctionAdapterCreateDescriptor5Args) { + FunctionDescriptor desc = + NaryFunctionAdapter, int64_t, bool, + const StringValue&, int64_t, + int64_t>::CreateDescriptor("MyFormatter", false); + + EXPECT_EQ(desc.name(), "MyFormatter"); + EXPECT_TRUE(desc.is_strict()); + EXPECT_FALSE(desc.receiver_style()); + EXPECT_THAT(desc.types(), + ElementsAre(Kind::kInt64, Kind::kBool, Kind::kString, + Kind::kInt64, Kind::kInt64)); +} + +TEST_F(FunctionAdapterTest, NaryFunctionAdapterWrapFunction5Args) { + std::unique_ptr fn = NaryFunctionAdapter< + absl::StatusOr, int64_t, bool, const StringValue&, int64_t, + int64_t>::WrapFunction([](int64_t int_val, bool bool_val, + const StringValue& string_val, + int64_t extra_arg, + int64_t extra_arg2) -> absl::StatusOr { + return StringValue(absl::StrCat(int_val, "_", (bool_val ? "true" : "false"), + "_", string_val.ToString(), "_", extra_arg, + "_", extra_arg2)); + }); + + std::vector args{IntValue(42), BoolValue(false)}; + args.emplace_back() = StringValue("abcd"); + args.push_back(IntValue(123)); + args.push_back(IntValue(456)); + ASSERT_OK_AND_ASSIGN(auto result, fn->Invoke(args, descriptor_pool(), + message_factory(), arena())); + ASSERT_TRUE(result->Is()); + EXPECT_EQ(result.GetString().ToString(), "42_false_abcd_123_456"); +} + +TEST_F(FunctionAdapterTest, NaryFunctionAdapterWrapFunction5ArgsBadArgType) { + std::unique_ptr fn = NaryFunctionAdapter< + absl::StatusOr, int64_t, bool, const StringValue&, int64_t, + int64_t>::WrapFunction([](int64_t int_val, bool bool_val, + const StringValue& string_val, + int64_t extra_arg, + int64_t extra_arg2) -> absl::StatusOr { + static_cast(extra_arg); + static_cast(extra_arg2); + return StringValue(absl::StrCat(int_val, "_", (bool_val ? "true" : "false"), + "_", string_val.ToString())); + }); + + std::vector args{IntValue(42), BoolValue(false)}; + args.emplace_back() = TimestampValue(absl::UnixEpoch()); + args.push_back(IntValue(123)); + args.push_back(IntValue(456)); + EXPECT_THAT(fn->Invoke(args, test_invoke_context()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("expected string value"))); +} + +TEST_F(FunctionAdapterTest, NaryFunctionAdapterWrapFunction5ArgsBadArgCount) { + std::unique_ptr fn = NaryFunctionAdapter< + absl::StatusOr, int64_t, bool, const StringValue&, int64_t, + int64_t>::WrapFunction([](int64_t int_val, bool bool_val, + const StringValue& string_val, + int64_t extra_arg, + int64_t extra_arg2) -> absl::StatusOr { + static_cast(extra_arg); + static_cast(extra_arg2); + return StringValue(absl::StrCat(int_val, "_", (bool_val ? "true" : "false"), + "_", string_val.ToString())); + }); + + std::vector args{IntValue(42), BoolValue(false)}; + EXPECT_THAT(fn->Invoke(args, test_invoke_context()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("unexpected number of arguments"))); +} + +} // namespace +} // namespace cel diff --git a/runtime/function_overload_reference.h b/runtime/function_overload_reference.h new file mode 100644 index 000000000..f27e1ff74 --- /dev/null +++ b/runtime/function_overload_reference.h @@ -0,0 +1,34 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_FUNCTION_OVERLOAD_REFERENCE_H_ +#define THIRD_PARTY_CEL_CPP_RUNTIME_FUNCTION_OVERLOAD_REFERENCE_H_ + +#include "common/function_descriptor.h" +#include "runtime/function.h" + +namespace cel { + +// Represents a view to a single overload for a function. +// +// Clients must take care to not persist instances beyond the lifetime of the +// owning object. +struct FunctionOverloadReference { + const FunctionDescriptor& descriptor; + const Function& implementation; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_RUNTIME_FUNCTION_OVERLOAD_REFERENCE_H_ diff --git a/runtime/function_provider.h b/runtime/function_provider.h new file mode 100644 index 000000000..679d7f159 --- /dev/null +++ b/runtime/function_provider.h @@ -0,0 +1,46 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_FUNCTION_PROVIDER_H_ +#define THIRD_PARTY_CEL_CPP_RUNTIME_FUNCTION_PROVIDER_H_ + +#include "absl/status/statusor.h" +#include "common/function_descriptor.h" +#include "runtime/activation_interface.h" +#include "runtime/function_overload_reference.h" + +namespace cel::runtime_internal { + +// Interface for providers of lazily bound functions. +// +// Lazily bound functions may have an implementation that is dependent on the +// evaluation context (as represented by the Activation). +class FunctionProvider { + public: + virtual ~FunctionProvider() = default; + + // Returns a reference to a function implementation based on the provided + // Activation. Given the same activation, this should return the same Function + // instance. The cel::FunctionOverloadReference is assumed to be stable for + // the life of the Activation. + // + // An empty optional result is interpreted as no matching overload. + virtual absl::StatusOr> GetFunction( + const FunctionDescriptor& descriptor, + const ActivationInterface& activation) const = 0; +}; + +} // namespace cel::runtime_internal + +#endif // THIRD_PARTY_CEL_CPP_RUNTIME_FUNCTION_PROVIDER_H_ diff --git a/runtime/function_registry.cc b/runtime/function_registry.cc new file mode 100644 index 000000000..59f267255 --- /dev/null +++ b/runtime/function_registry.cc @@ -0,0 +1,263 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "runtime/function_registry.h" + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/node_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "common/function_descriptor.h" +#include "common/kind.h" +#include "runtime/activation_interface.h" +#include "runtime/function.h" +#include "runtime/function_overload_reference.h" +#include "runtime/function_provider.h" + +namespace cel { +namespace { + +// Impl for simple provider that looks up functions in an activation function +// registry. +class ActivationFunctionProviderImpl + : public cel::runtime_internal::FunctionProvider { + public: + ActivationFunctionProviderImpl() = default; + + absl::StatusOr> GetFunction( + const cel::FunctionDescriptor& descriptor, + const cel::ActivationInterface& activation) const override { + std::vector overloads = + activation.FindFunctionOverloads(descriptor.name()); + + std::optional matching_overload = absl::nullopt; + + for (const auto& overload : overloads) { + if (overload.descriptor.ShapeMatches(descriptor)) { + if (matching_overload.has_value()) { + return absl::Status(absl::StatusCode::kInvalidArgument, + "Couldn't resolve function."); + } + matching_overload.emplace(overload); + } + } + + return matching_overload; + } +}; + +// Create a CelFunctionProvider that just looks up the functions inserted in the +// Activation. This is a convenience implementation for a simple, common +// use-case. +std::unique_ptr +CreateActivationFunctionProvider() { + return std::make_unique(); +} + +} // namespace + +absl::Status FunctionRegistry::Register( + const cel::FunctionDescriptor& descriptor, + std::unique_ptr implementation) { + if (DescriptorRegistered(descriptor)) { + return absl::Status( + absl::StatusCode::kAlreadyExists, + "CelFunction with specified parameters already registered"); + } + if (!ValidateNonStrictOverload(descriptor)) { + return absl::Status(absl::StatusCode::kAlreadyExists, + "Only one overload is allowed for non-strict function"); + } + + auto& overloads = functions_[descriptor.name()]; + overloads.static_overloads.push_back( + StaticFunctionEntry(descriptor, std::move(implementation))); + return absl::OkStatus(); +} + +absl::Status FunctionRegistry::RegisterLazyFunction( + const cel::FunctionDescriptor& descriptor) { + if (DescriptorRegistered(descriptor)) { + return absl::Status( + absl::StatusCode::kAlreadyExists, + "CelFunction with specified parameters already registered"); + } + if (!ValidateNonStrictOverload(descriptor)) { + return absl::Status(absl::StatusCode::kAlreadyExists, + "Only one overload is allowed for non-strict function"); + } + auto& overloads = functions_[descriptor.name()]; + + overloads.lazy_overloads.push_back( + LazyFunctionEntry(descriptor, CreateActivationFunctionProvider())); + + return absl::OkStatus(); +} + +std::vector +FunctionRegistry::FindStaticOverloads(absl::string_view name, + bool receiver_style, + absl::Span types) const { + std::vector matched_funcs; + + auto overloads = functions_.find(name); + if (overloads == functions_.end()) { + return matched_funcs; + } + + for (const auto& overload : overloads->second.static_overloads) { + if (overload.descriptor->ShapeMatches(receiver_style, types)) { + matched_funcs.push_back({*overload.descriptor, *overload.implementation}); + } + } + + return matched_funcs; +} + +std::vector +FunctionRegistry::FindStaticOverloadsByArity(absl::string_view name, + bool receiver_style, + size_t arity) const { + std::vector matched_funcs; + + auto overloads = functions_.find(name); + if (overloads == functions_.end()) { + return matched_funcs; + } + + for (const auto& overload : overloads->second.static_overloads) { + if (overload.descriptor->receiver_style() == receiver_style && + overload.descriptor->types().size() == arity) { + matched_funcs.push_back({*overload.descriptor, *overload.implementation}); + } + } + + return matched_funcs; +} + +std::vector FunctionRegistry::FindLazyOverloads( + absl::string_view name, bool receiver_style, + absl::Span types) const { + std::vector matched_funcs; + + auto overloads = functions_.find(name); + if (overloads == functions_.end()) { + return matched_funcs; + } + + for (const auto& entry : overloads->second.lazy_overloads) { + if (entry.descriptor->ShapeMatches(receiver_style, types)) { + matched_funcs.push_back({*entry.descriptor, *entry.function_provider}); + } + } + + return matched_funcs; +} + +std::vector +FunctionRegistry::FindLazyOverloadsByArity(absl::string_view name, + bool receiver_style, + size_t arity) const { + std::vector matched_funcs; + + auto overloads = functions_.find(name); + if (overloads == functions_.end()) { + return matched_funcs; + } + + for (const auto& entry : overloads->second.lazy_overloads) { + if (entry.descriptor->receiver_style() == receiver_style && + entry.descriptor->types().size() == arity) { + matched_funcs.push_back({*entry.descriptor, *entry.function_provider}); + } + } + + return matched_funcs; +} + +absl::node_hash_map> +FunctionRegistry::ListFunctions() const { + absl::node_hash_map> + descriptor_map; + + for (const auto& entry : functions_) { + std::vector descriptors; + const RegistryEntry& function_entry = entry.second; + descriptors.reserve(function_entry.static_overloads.size() + + function_entry.lazy_overloads.size()); + for (const auto& entry : function_entry.static_overloads) { + descriptors.push_back(entry.descriptor.get()); + } + for (const auto& entry : function_entry.lazy_overloads) { + descriptors.push_back(entry.descriptor.get()); + } + descriptor_map[entry.first] = std::move(descriptors); + } + + return descriptor_map; +} + +bool FunctionRegistry::DescriptorRegistered( + const cel::FunctionDescriptor& descriptor) const { + auto overloads = functions_.find(descriptor.name()); + if (overloads == functions_.end()) { + return false; + } + const RegistryEntry& entry = overloads->second; + for (const auto& static_ovl : entry.static_overloads) { + if (static_ovl.descriptor->ShapeMatches(descriptor)) { + return true; + } + } + for (const auto& lazy_ovl : entry.lazy_overloads) { + if (lazy_ovl.descriptor->ShapeMatches(descriptor)) { + return true; + } + } + return false; +} + +bool FunctionRegistry::ValidateNonStrictOverload( + const cel::FunctionDescriptor& descriptor) const { + auto overloads = functions_.find(descriptor.name()); + if (overloads == functions_.end()) { + return true; + } + const RegistryEntry& entry = overloads->second; + if (!descriptor.is_strict()) { + // If the newly added overload is a non-strict function, we require that + // there are no other overloads, which is not possible here. + return false; + } + // If the newly added overload is a strict function, we need to make sure + // that no previous overloads are registered non-strict. If the list of + // overload is not empty, we only need to check the first overload. This is + // because if the first overload is strict, other overloads must also be + // strict by the rule. + return (entry.static_overloads.empty() || + entry.static_overloads[0].descriptor->is_strict()) && + (entry.lazy_overloads.empty() || + entry.lazy_overloads[0].descriptor->is_strict()); +} + +} // namespace cel diff --git a/runtime/function_registry.h b/runtime/function_registry.h new file mode 100644 index 000000000..6a227978d --- /dev/null +++ b/runtime/function_registry.h @@ -0,0 +1,160 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_FUNCTION_REGISTRY_H_ +#define THIRD_PARTY_CEL_CPP_RUNTIME_FUNCTION_REGISTRY_H_ + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/node_hash_map.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "common/function_descriptor.h" +#include "common/kind.h" +#include "runtime/function.h" +#include "runtime/function_overload_reference.h" +#include "runtime/function_provider.h" + +namespace cel { + +// FunctionRegistry manages binding builtin or custom CEL functions to +// implementations. +// +// The registry is consulted during program planning to tie overload candidates +// to the CEL function in the AST getting planned. +// +// The registry takes ownership of the cel::Function objects -- the registry +// must outlive any program planned using it. +// +// This class is move-only. +class FunctionRegistry { + public: + // Represents a single overload for a lazily provided function. + struct LazyOverload { + const cel::FunctionDescriptor& descriptor; + const cel::runtime_internal::FunctionProvider& provider; + }; + + FunctionRegistry() = default; + + // Move-only + FunctionRegistry(FunctionRegistry&&) = default; + FunctionRegistry& operator=(FunctionRegistry&&) = default; + + // Register a function implementation for the given descriptor. + // Function registration should be performed prior to CelExpression creation. + absl::Status Register(const cel::FunctionDescriptor& descriptor, + std::unique_ptr implementation); + + // Register a lazily provided function. + // Internally, the registry binds a FunctionProvider that provides an overload + // at evaluation time by resolving against the overloads provided by an + // implementation of cel::ActivationInterface. + absl::Status RegisterLazyFunction(const cel::FunctionDescriptor& descriptor); + + // Find subset of cel::Function implementations that match overload conditions + // As types may not be available during expression compilation, + // further narrowing of this subset will happen at evaluation stage. + // + // name - the name of CEL function (as distinct from overload ID); + // receiver_style - indicates whether function has receiver style; + // types - argument types. If type is not known during compilation, + // cel::Kind::kAny should be passed. + // + // Results refer to underlying registry entries by reference. Results are + // invalid after the registry is deleted. + std::vector FindStaticOverloads( + absl::string_view name, bool receiver_style, + absl::Span types) const; + + std::vector FindStaticOverloadsByArity( + absl::string_view name, bool receiver_style, size_t arity) const; + + // Find subset of cel::Function providers that match overload conditions. + // As types may not be available during expression compilation, + // further narrowing of this subset will happen at evaluation stage. + // + // name - the name of CEL function (as distinct from overload ID); + // receiver_style - indicates whether function has receiver style; + // types - argument types. If type is not known during compilation, + // cel::Kind::kAny should be passed. + // + // Results refer to underlying registry entries by reference. Results are + // invalid after the registry is deleted. + std::vector FindLazyOverloads( + absl::string_view name, bool receiver_style, + absl::Span types) const; + + std::vector FindLazyOverloadsByArity(absl::string_view name, + bool receiver_style, + size_t arity) const; + + // Retrieve list of registered function descriptors. This includes both + // static and lazy functions. + absl::node_hash_map> + ListFunctions() const; + + private: + struct StaticFunctionEntry { + StaticFunctionEntry(const cel::FunctionDescriptor& descriptor, + std::unique_ptr impl) + : descriptor(std::make_unique(descriptor)), + implementation(std::move(impl)) {} + + // Extra indirection needed to preserve pointer stability for the + // descriptors. + std::unique_ptr descriptor; + std::unique_ptr implementation; + }; + + struct LazyFunctionEntry { + LazyFunctionEntry( + const cel::FunctionDescriptor& descriptor, + std::unique_ptr provider) + : descriptor(std::make_unique(descriptor)), + function_provider(std::move(provider)) {} + + // Extra indirection needed to preserve pointer stability for the + // descriptors. + std::unique_ptr descriptor; + std::unique_ptr function_provider; + }; + + struct RegistryEntry { + std::vector static_overloads; + std::vector lazy_overloads; + }; + + // Returns whether the descriptor is registered either as a lazy function or + // as a static function. + bool DescriptorRegistered(const cel::FunctionDescriptor& descriptor) const; + + // Returns true if after adding this function, the rule "a non-strict + // function should have only a single overload" will be preserved. + bool ValidateNonStrictOverload( + const cel::FunctionDescriptor& descriptor) const; + + // indexed by function name (not type checker overload id). + absl::flat_hash_map functions_; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_RUNTIME_FUNCTION_REGISTRY_H_ diff --git a/runtime/function_registry_test.cc b/runtime/function_registry_test.cc new file mode 100644 index 000000000..53916777a --- /dev/null +++ b/runtime/function_registry_test.cc @@ -0,0 +1,302 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "runtime/function_registry.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "common/function_descriptor.h" +#include "common/kind.h" +#include "common/value.h" +#include "internal/testing.h" +#include "runtime/activation.h" +#include "runtime/function.h" +#include "runtime/function_adapter.h" +#include "runtime/function_overload_reference.h" +#include "runtime/function_provider.h" + +namespace cel { + +namespace { + +using ::absl_testing::StatusIs; +using ::cel::runtime_internal::FunctionProvider; +using ::testing::ElementsAre; +using ::testing::HasSubstr; +using ::testing::SizeIs; +using ::testing::Truly; + +class ConstIntFunction : public cel::Function { + public: + static cel::FunctionDescriptor MakeDescriptor() { + return {"ConstFunction", false, {}}; + } + + absl::StatusOr Invoke(absl::Span args, + const InvokeContext& context) const override { + return IntValue(42); + } +}; + +TEST(FunctionRegistryTest, InsertAndRetrieveLazyFunction) { + cel::FunctionDescriptor lazy_function_desc{"LazyFunction", false, {}}; + FunctionRegistry registry; + Activation activation; + ASSERT_OK(registry.RegisterLazyFunction(lazy_function_desc)); + + const auto descriptors = + registry.FindLazyOverloads("LazyFunction", false, {}); + EXPECT_THAT(descriptors, SizeIs(1)); +} + +// Confirm that lazy and static functions share the same descriptor space: +// i.e. you can't insert both a lazy function and a static function for the same +// descriptors. +TEST(FunctionRegistryTest, LazyAndStaticFunctionShareDescriptorSpace) { + FunctionRegistry registry; + cel::FunctionDescriptor desc = ConstIntFunction::MakeDescriptor(); + ASSERT_OK(registry.RegisterLazyFunction(desc)); + + absl::Status status = registry.Register(ConstIntFunction::MakeDescriptor(), + std::make_unique()); + EXPECT_FALSE(status.ok()); +} + +TEST(FunctionRegistryTest, FindStaticOverloadsReturns) { + FunctionRegistry registry; + cel::FunctionDescriptor desc = ConstIntFunction::MakeDescriptor(); + ASSERT_OK(registry.Register(desc, std::make_unique())); + + std::vector overloads = + registry.FindStaticOverloads(desc.name(), false, {}); + + EXPECT_THAT(overloads, + ElementsAre(Truly( + [](const cel::FunctionOverloadReference& overload) -> bool { + return overload.descriptor.name() == "ConstFunction"; + }))) + << "Expected single ConstFunction()"; +} + +TEST(FunctionRegistryTest, ListFunctions) { + cel::FunctionDescriptor lazy_function_desc{"LazyFunction", false, {}}; + FunctionRegistry registry; + + ASSERT_OK(registry.RegisterLazyFunction(lazy_function_desc)); + EXPECT_OK(registry.Register(ConstIntFunction::MakeDescriptor(), + std::make_unique())); + + auto registered_functions = registry.ListFunctions(); + + EXPECT_THAT(registered_functions, SizeIs(2)); + EXPECT_THAT(registered_functions["LazyFunction"], SizeIs(1)); + EXPECT_THAT(registered_functions["ConstFunction"], SizeIs(1)); +} + +TEST(FunctionRegistryTest, DefaultLazyProviderNoOverloadFound) { + FunctionRegistry registry; + Activation activation; + cel::FunctionDescriptor lazy_function_desc{"LazyFunction", false, {}}; + EXPECT_OK(registry.RegisterLazyFunction(lazy_function_desc)); + + auto providers = registry.FindLazyOverloads("LazyFunction", false, {}); + ASSERT_THAT(providers, SizeIs(1)); + const FunctionProvider& provider = providers[0].provider; + ASSERT_OK_AND_ASSIGN( + std::optional func, + provider.GetFunction({"LazyFunc", false, {cel::Kind::kInt64}}, + activation)); + + EXPECT_EQ(func, absl::nullopt); +} + +TEST(FunctionRegistryTest, DefaultLazyProviderReturnsImpl) { + FunctionRegistry registry; + Activation activation; + EXPECT_OK(registry.RegisterLazyFunction( + FunctionDescriptor("LazyFunction", false, {Kind::kAny}))); + EXPECT_TRUE(activation.InsertFunction( + FunctionDescriptor("LazyFunction", false, {Kind::kInt}), + UnaryFunctionAdapter::WrapFunction( + [](int64_t x) { return 2 * x; }))); + EXPECT_TRUE(activation.InsertFunction( + FunctionDescriptor("LazyFunction", false, {Kind::kDouble}), + UnaryFunctionAdapter::WrapFunction( + [](double x) { return 2 * x; }))); + + auto providers = + registry.FindLazyOverloads("LazyFunction", false, {Kind::kInt}); + ASSERT_THAT(providers, SizeIs(1)); + const FunctionProvider& provider = providers[0].provider; + ASSERT_OK_AND_ASSIGN( + std::optional func, + provider.GetFunction( + FunctionDescriptor("LazyFunction", false, {Kind::kInt}), activation)); + + ASSERT_TRUE(func.has_value()); + EXPECT_EQ(func->descriptor.name(), "LazyFunction"); + EXPECT_EQ(func->descriptor.types(), std::vector{cel::Kind::kInt64}); +} + +TEST(FunctionRegistryTest, DefaultLazyProviderAmbiguousOverload) { + FunctionRegistry registry; + Activation activation; + EXPECT_OK(registry.RegisterLazyFunction( + FunctionDescriptor("LazyFunction", false, {Kind::kAny}))); + EXPECT_TRUE(activation.InsertFunction( + FunctionDescriptor("LazyFunction", false, {Kind::kInt}), + UnaryFunctionAdapter::WrapFunction( + [](int64_t x) { return 2 * x; }))); + EXPECT_TRUE(activation.InsertFunction( + FunctionDescriptor("LazyFunction", false, {Kind::kDouble}), + UnaryFunctionAdapter::WrapFunction( + [](double x) { return 2 * x; }))); + + auto providers = + registry.FindLazyOverloads("LazyFunction", false, {Kind::kInt}); + ASSERT_THAT(providers, SizeIs(1)); + const FunctionProvider& provider = providers[0].provider; + + EXPECT_THAT( + provider.GetFunction( + FunctionDescriptor("LazyFunction", false, {Kind::kAny}), activation), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Couldn't resolve function"))); +} + +TEST(FunctionRegistryTest, CanRegisterNonStrictFunction) { + { + FunctionRegistry registry; + cel::FunctionDescriptor descriptor("NonStrictFunction", + /*receiver_style=*/false, {Kind::kAny}, + /*is_strict=*/false); + ASSERT_OK( + registry.Register(descriptor, std::make_unique())); + EXPECT_THAT( + registry.FindStaticOverloads("NonStrictFunction", false, {Kind::kAny}), + SizeIs(1)); + } + { + FunctionRegistry registry; + cel::FunctionDescriptor descriptor("NonStrictLazyFunction", + /*receiver_style=*/false, {Kind::kAny}, + /*is_strict=*/false); + EXPECT_OK(registry.RegisterLazyFunction(descriptor)); + EXPECT_THAT(registry.FindLazyOverloads("NonStrictLazyFunction", false, + {Kind::kAny}), + SizeIs(1)); + } +} + +using NonStrictTestCase = std::tuple; +using NonStrictRegistrationFailTest = testing::TestWithParam; + +TEST_P(NonStrictRegistrationFailTest, + IfOtherOverloadExistsRegisteringNonStrictFails) { + bool existing_function_is_lazy, new_function_is_lazy; + std::tie(existing_function_is_lazy, new_function_is_lazy) = GetParam(); + FunctionRegistry registry; + cel::FunctionDescriptor descriptor("OverloadedFunction", + /*receiver_style=*/false, {Kind::kAny}, + /*is_strict=*/true); + if (existing_function_is_lazy) { + ASSERT_OK(registry.RegisterLazyFunction(descriptor)); + } else { + ASSERT_OK( + registry.Register(descriptor, std::make_unique())); + } + cel::FunctionDescriptor new_descriptor("OverloadedFunction", + /*receiver_style=*/false, + {Kind::kAny, Kind::kAny}, + /*is_strict=*/false); + absl::Status status; + if (new_function_is_lazy) { + status = registry.RegisterLazyFunction(new_descriptor); + } else { + status = + registry.Register(new_descriptor, std::make_unique()); + } + EXPECT_THAT(status, StatusIs(absl::StatusCode::kAlreadyExists, + HasSubstr("Only one overload"))); +} + +TEST_P(NonStrictRegistrationFailTest, + IfOtherNonStrictExistsRegisteringStrictFails) { + bool existing_function_is_lazy, new_function_is_lazy; + std::tie(existing_function_is_lazy, new_function_is_lazy) = GetParam(); + FunctionRegistry registry; + cel::FunctionDescriptor descriptor("OverloadedFunction", + /*receiver_style=*/false, {Kind::kAny}, + /*is_strict=*/false); + if (existing_function_is_lazy) { + ASSERT_OK(registry.RegisterLazyFunction(descriptor)); + } else { + ASSERT_OK( + registry.Register(descriptor, std::make_unique())); + } + cel::FunctionDescriptor new_descriptor("OverloadedFunction", + /*receiver_style=*/false, + {Kind::kAny, Kind::kAny}, + /*is_strict=*/true); + absl::Status status; + if (new_function_is_lazy) { + status = registry.RegisterLazyFunction(new_descriptor); + } else { + status = + registry.Register(new_descriptor, std::make_unique()); + } + EXPECT_THAT(status, StatusIs(absl::StatusCode::kAlreadyExists, + HasSubstr("Only one overload"))); +} + +TEST_P(NonStrictRegistrationFailTest, CanRegisterStrictFunctionsWithoutLimit) { + bool existing_function_is_lazy, new_function_is_lazy; + std::tie(existing_function_is_lazy, new_function_is_lazy) = GetParam(); + FunctionRegistry registry; + cel::FunctionDescriptor descriptor("OverloadedFunction", + /*receiver_style=*/false, {Kind::kAny}, + /*is_strict=*/true); + if (existing_function_is_lazy) { + ASSERT_OK(registry.RegisterLazyFunction(descriptor)); + } else { + ASSERT_OK( + registry.Register(descriptor, std::make_unique())); + } + cel::FunctionDescriptor new_descriptor("OverloadedFunction", + /*receiver_style=*/false, + {Kind::kAny, Kind::kAny}, + /*is_strict=*/true); + absl::Status status; + if (new_function_is_lazy) { + status = registry.RegisterLazyFunction(new_descriptor); + } else { + status = + registry.Register(new_descriptor, std::make_unique()); + } + EXPECT_OK(status); +} + +INSTANTIATE_TEST_SUITE_P(NonStrictRegistrationFailTest, + NonStrictRegistrationFailTest, + testing::Combine(testing::Bool(), testing::Bool())); + +} // namespace + +} // namespace cel diff --git a/runtime/internal/BUILD b/runtime/internal/BUILD new file mode 100644 index 000000000..1223ff6d1 --- /dev/null +++ b/runtime/internal/BUILD @@ -0,0 +1,226 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +package( + # Internals for cel/runtime. + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) + +cc_library( + name = "runtime_friend_access", + hdrs = ["runtime_friend_access.h"], + deps = [ + "//common:native_type", + "//runtime", + "//runtime:runtime_builder", + ], +) + +cc_library( + name = "runtime_env", + srcs = ["runtime_env.cc"], + hdrs = ["runtime_env.h"], + deps = [ + "//eval/public:cel_function_registry", + "//eval/public:cel_type_registry", + "//internal:noop_delete", + "//internal:well_known_types", + "//runtime:function_registry", + "//runtime:type_registry", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/status", + "@com_google_absl//absl/synchronization", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "runtime_impl", + srcs = ["runtime_impl.cc"], + hdrs = ["runtime_impl.h"], + deps = [ + ":runtime_env", + "//base:ast", + "//base:data", + "//common:native_type", + "//common:value", + "//eval/compiler:flat_expr_builder", + "//eval/eval:attribute_trail", + "//eval/eval:comprehension_slots", + "//eval/eval:direct_expression_step", + "//eval/eval:evaluator_core", + "//internal:casts", + "//internal:status_macros", + "//internal:well_known_types", + "//runtime", + "//runtime:activation_interface", + "//runtime:function_registry", + "//runtime:runtime_options", + "//runtime:type_registry", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status:statusor", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "convert_constant", + srcs = ["convert_constant.cc"], + hdrs = ["convert_constant.h"], + deps = [ + "//common:allocator", + "//common:ast", + "//common:constant", + "//common:value", + "//eval/internal:errors", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:variant", + ], +) + +cc_library( + name = "errors", + srcs = ["errors.cc"], + hdrs = ["errors.h"], + deps = [ + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + "@com_google_absl//absl/time", + ], +) + +cc_library( + name = "issue_collector", + hdrs = ["issue_collector.h"], + deps = [ + "//runtime:runtime_issue", + "@com_google_absl//absl/status", + "@com_google_absl//absl/types:span", + ], +) + +cc_test( + name = "issue_collector_test", + srcs = ["issue_collector_test.cc"], + deps = [ + ":issue_collector", + "//internal:testing", + "//runtime:runtime_issue", + "@com_google_absl//absl/status", + ], +) + +cc_library( + name = "function_adapter", + hdrs = [ + "function_adapter.h", + ], + deps = [ + "//common:casting", + "//common:kind", + "//common:value", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/time", + ], +) + +cc_test( + name = "function_adapter_test", + srcs = ["function_adapter_test.cc"], + deps = [ + ":function_adapter", + "//common:kind", + "//common:value", + "//internal:testing", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/time", + ], +) + +cc_library( + name = "runtime_env_testing", + testonly = True, + srcs = ["runtime_env_testing.cc"], + hdrs = ["runtime_env_testing.h"], + deps = [ + ":runtime_env", + "//internal:noop_delete", + "//internal:testing_descriptor_pool", + "//internal:testing_message_factory", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/log:absl_check", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "legacy_runtime_type_provider", + hdrs = ["legacy_runtime_type_provider.h"], + deps = [ + "//eval/public/structs:protobuf_descriptor_type_provider", + "@com_google_absl//absl/base:nullability", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "runtime_type_provider", + srcs = ["runtime_type_provider.cc"], + hdrs = ["runtime_type_provider.h"], + deps = [ + "//common:type", + "//common:value", + "//internal:status_macros", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:optional", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "attribute_matcher", + hdrs = ["attribute_matcher.h"], + deps = ["//base:attributes"], +) + +cc_library( + name = "activation_attribute_matcher_access", + srcs = ["activation_attribute_matcher_access.cc"], + hdrs = ["activation_attribute_matcher_access.h"], + deps = [ + ":attribute_matcher", + "//eval/public:activation", + "//runtime:activation", + "@com_google_absl//absl/base:nullability", + ], +) diff --git a/runtime/internal/activation_attribute_matcher_access.cc b/runtime/internal/activation_attribute_matcher_access.cc new file mode 100644 index 000000000..7d358ba23 --- /dev/null +++ b/runtime/internal/activation_attribute_matcher_access.cc @@ -0,0 +1,61 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "runtime/internal/activation_attribute_matcher_access.h" + +#include +#include + +#include "absl/base/nullability.h" +#include "eval/public/activation.h" +#include "runtime/activation.h" +#include "runtime/internal/attribute_matcher.h" + +namespace cel::runtime_internal { + +void ActivationAttributeMatcherAccess::SetAttributeMatcher( + google::api::expr::runtime::Activation& activation, + const AttributeMatcher* matcher) { + activation.SetAttributeMatcher(matcher); +} + +void ActivationAttributeMatcherAccess::SetAttributeMatcher( + google::api::expr::runtime::Activation& activation, + std::unique_ptr matcher) { + activation.SetAttributeMatcher(std::move(matcher)); +} + +const AttributeMatcher* absl_nullable +ActivationAttributeMatcherAccess::GetAttributeMatcher( + const google::api::expr::runtime::BaseActivation& activation) { + return activation.GetAttributeMatcher(); +} + +void ActivationAttributeMatcherAccess::SetAttributeMatcher( + Activation& activation, const AttributeMatcher* matcher) { + activation.SetAttributeMatcher(matcher); +} + +void ActivationAttributeMatcherAccess::SetAttributeMatcher( + Activation& activation, std::unique_ptr matcher) { + activation.SetAttributeMatcher(std::move(matcher)); +} + +const AttributeMatcher* absl_nullable +ActivationAttributeMatcherAccess::GetAttributeMatcher( + const ActivationInterface& activation) { + return activation.GetAttributeMatcher(); +} + +} // namespace cel::runtime_internal diff --git a/runtime/internal/activation_attribute_matcher_access.h b/runtime/internal/activation_attribute_matcher_access.h new file mode 100644 index 000000000..2741be692 --- /dev/null +++ b/runtime/internal/activation_attribute_matcher_access.h @@ -0,0 +1,60 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_ACTIVATION_MATCHER_ACCESS_H_ +#define THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_ACTIVATION_MATCHER_ACCESS_H_ + +#include + +#include "absl/base/nullability.h" +#include "runtime/internal/attribute_matcher.h" + +namespace google::api::expr::runtime { +class Activation; +class BaseActivation; +} // namespace google::api::expr::runtime + +namespace cel { +class Activation; +class ActivationInterface; +} // namespace cel + +namespace cel::runtime_internal { + +class ActivationAttributeMatcherAccess { + public: + static void SetAttributeMatcher( + google::api::expr::runtime::Activation& activation, + const AttributeMatcher* matcher); + + static void SetAttributeMatcher( + google::api::expr::runtime::Activation& activation, + std::unique_ptr matcher); + + static const AttributeMatcher* absl_nullable GetAttributeMatcher( + const google::api::expr::runtime::BaseActivation& activation); + + static void SetAttributeMatcher(Activation& activation, + const AttributeMatcher* matcher); + + static void SetAttributeMatcher( + Activation& activation, std::unique_ptr matcher); + + static const AttributeMatcher* absl_nullable GetAttributeMatcher( + const ActivationInterface& activation); +}; + +} // namespace cel::runtime_internal + +#endif // THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_ACTIVATION_MATCHER_ACCESS_H_ diff --git a/runtime/internal/attribute_matcher.h b/runtime/internal/attribute_matcher.h new file mode 100644 index 000000000..a168b714c --- /dev/null +++ b/runtime/internal/attribute_matcher.h @@ -0,0 +1,48 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_ATTRIBUTE_MATCHER_H_ +#define THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_ATTRIBUTE_MATCHER_H_ + +#include "base/attribute.h" + +namespace cel::runtime_internal { + +// Interface for matching unknown and missing attributes against the +// observed attribute trail at runtime. +class AttributeMatcher { + public: + using MatchResult = cel::AttributePattern::MatchType; + + virtual ~AttributeMatcher() = default; + + // Checks whether the attribute trail matches any unknown patterns. + // Used to identify and collect referenced unknowns in an UnknownValue. + virtual MatchResult CheckForUnknown(const Attribute& attr + [[maybe_unused]]) const { + return MatchResult::NONE; + }; + + // Checks whether the attribute trail matches any missing patterns. + // Used to identify missing attributes, and report an error if referenced + // directly. + virtual MatchResult CheckForMissing(const Attribute& attr + [[maybe_unused]]) const { + return MatchResult::NONE; + }; +}; + +} // namespace cel::runtime_internal + +#endif // THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_ATTRIBUTE_MATCHER_H_ diff --git a/runtime/internal/convert_constant.cc b/runtime/internal/convert_constant.cc new file mode 100644 index 000000000..33f382858 --- /dev/null +++ b/runtime/internal/convert_constant.cc @@ -0,0 +1,78 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "runtime/internal/convert_constant.h" + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/time/time.h" +#include "absl/types/variant.h" +#include "common/allocator.h" +#include "common/constant.h" +#include "common/value.h" +#include "eval/internal/errors.h" + +namespace cel::runtime_internal { +namespace { +using ::cel::Constant; + +struct ConvertVisitor { + Allocator<> allocator; + + absl::StatusOr operator()(std::monostate) { + return absl::InvalidArgumentError("unspecified constant"); + } + absl::StatusOr operator()(std::nullptr_t) { return NullValue(); } + absl::StatusOr operator()(bool value) { return BoolValue(value); } + absl::StatusOr operator()(int64_t value) { + return IntValue(value); + } + absl::StatusOr operator()(uint64_t value) { + return UintValue(value); + } + absl::StatusOr operator()(double value) { + return DoubleValue(value); + } + absl::StatusOr operator()(const cel::StringConstant& value) { + return StringValue(allocator, value); + } + absl::StatusOr operator()(const cel::BytesConstant& value) { + return BytesValue(allocator, value); + } + absl::StatusOr operator()(const absl::Duration duration) { + if (duration >= kDurationHigh || duration <= kDurationLow) { + return ErrorValue(*DurationOverflowError()); + } + return UnsafeDurationValue(duration); + } + absl::StatusOr operator()(const absl::Time timestamp) { + return UnsafeTimestampValue(timestamp); + } +}; + +} // namespace + +// Converts an Ast constant into a runtime value, managed according to the +// given value factory. +// +// A status maybe returned if value creation fails. +absl::StatusOr ConvertConstant(const Constant& constant, + Allocator<> allocator) { + return absl::visit(ConvertVisitor{allocator}, constant.constant_kind()); +} + +} // namespace cel::runtime_internal diff --git a/runtime/internal/convert_constant.h b/runtime/internal/convert_constant.h new file mode 100644 index 000000000..f1ac0c850 --- /dev/null +++ b/runtime/internal/convert_constant.h @@ -0,0 +1,39 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_CONVERT_CONSTANT_H_ +#define THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_CONVERT_CONSTANT_H_ + +#include "absl/status/statusor.h" +#include "common/allocator.h" +#include "common/ast.h" +#include "common/value.h" + +namespace cel::runtime_internal { + +// Adapt AST constant to a Value. +// +// Underlying data is copied for string types to keep the program independent +// from the input AST. +// +// The evaluator assumes most ast constants are valid so unchecked ValueManager +// methods are used. +// +// A status may still be returned if value creation fails according to +// value_factory's policy. +absl::StatusOr ConvertConstant(const Constant& constant, + Allocator<> allocator); + +} // namespace cel::runtime_internal + +#endif // THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_CONVERT_CONSTANT_H_ diff --git a/runtime/internal/errors.cc b/runtime/internal/errors.cc new file mode 100644 index 000000000..5d86fd5d7 --- /dev/null +++ b/runtime/internal/errors.cc @@ -0,0 +1,69 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "runtime/internal/errors.h" + +#include "absl/status/status.h" +#include "absl/strings/cord.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" + +namespace cel::runtime_internal { + +const absl::Status* DurationOverflowError() { + static const auto* const kDurationOverflow = new absl::Status( + absl::StatusCode::kInvalidArgument, "Duration is out of range"); + return kDurationOverflow; +} + +absl::Status CreateNoSuchKeyError(absl::string_view key) { + return absl::NotFoundError(absl::StrCat(kErrNoSuchKey, " : ", key)); +} + +absl::Status CreateNoMatchingOverloadError(absl::string_view fn) { + return absl::UnknownError( + absl::StrCat(kErrNoMatchingOverload, fn.empty() ? "" : " : ", fn)); +} + +absl::Status CreateNoSuchFieldError(absl::string_view field) { + return absl::Status( + absl::StatusCode::kNotFound, + absl::StrCat(kErrNoSuchField, field.empty() ? "" : " : ", field)); +} + +absl::Status CreateMissingAttributeError( + absl::string_view missing_attribute_path) { + absl::Status result = absl::InvalidArgumentError( + absl::StrCat(kErrMissingAttribute, missing_attribute_path)); + result.SetPayload(kPayloadUrlMissingAttributePath, + absl::Cord(missing_attribute_path)); + return result; +} + +absl::Status CreateInvalidMapKeyTypeError(absl::string_view key_type) { + return absl::InvalidArgumentError( + absl::StrCat("Invalid map key type: '", key_type, "'")); +} + +absl::Status CreateUnknownFunctionResultError(absl::string_view help_message) { + absl::Status result = absl::UnavailableError( + absl::StrCat("Unknown function result: ", help_message)); + result.SetPayload(kPayloadUrlUnknownFunctionResult, absl::Cord("true")); + return result; +} + +absl::Status CreateError(absl::string_view message, absl::StatusCode code) { + return absl::Status(code, message); +} + +} // namespace cel::runtime_internal diff --git a/runtime/internal/errors.h b/runtime/internal/errors.h new file mode 100644 index 000000000..b5d6ad745 --- /dev/null +++ b/runtime/internal/errors.h @@ -0,0 +1,71 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Factories and constants for well-known CEL errors. +#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_ERRORS_H_ +#define THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_ERRORS_H_ + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" + +namespace cel::runtime_internal { + +constexpr absl::string_view kErrNoMatchingOverload = + "No matching overloads found"; +constexpr absl::string_view kErrNoSuchField = "no_such_field"; +constexpr absl::string_view kErrNoSuchKey = "Key not found in map"; +// Error name for MissingAttributeError indicating that evaluation has +// accessed an attribute whose value is undefined. go/terminal-unknown +constexpr absl::string_view kErrMissingAttribute = "MissingAttributeError: "; +constexpr absl::string_view kPayloadUrlMissingAttributePath = + "missing_attribute_path"; +constexpr absl::string_view kPayloadUrlUnknownFunctionResult = + "cel_is_unknown_function_result"; + +// Exclusive bounds for valid duration values. +constexpr absl::Duration kDurationHigh = absl::Seconds(315576000001); +constexpr absl::Duration kDurationLow = absl::Seconds(-315576000001); + +const absl::Status* DurationOverflowError(); + +// At runtime, no matching overload could be found for a function invocation. +absl::Status CreateNoMatchingOverloadError(absl::string_view fn); + +// No such field for struct access. +absl::Status CreateNoSuchFieldError(absl::string_view field); + +// No such key for map access. +absl::Status CreateNoSuchKeyError(absl::string_view key); + +// Invalid key type used for map index. +absl::Status CreateInvalidMapKeyTypeError(absl::string_view key_type); + +// A missing attribute was accessed. Attributes may be declared as missing to +// they are not well defined at evaluation time. +absl::Status CreateMissingAttributeError( + absl::string_view missing_attribute_path); + +// Function result is unknown. The evaluator may convert this to an +// UnknownValue if enabled. +absl::Status CreateUnknownFunctionResultError(absl::string_view help_message); + +// The default error type uses absl::StatusCode::kUnknown. In general, a more +// specific error should be used. +absl::Status CreateError(absl::string_view message, + absl::StatusCode code = absl::StatusCode::kUnknown); + +} // namespace cel::runtime_internal + +#endif // THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_ERRORS_H_ diff --git a/runtime/internal/function_adapter.h b/runtime/internal/function_adapter.h new file mode 100644 index 000000000..9b191e577 --- /dev/null +++ b/runtime/internal/function_adapter.h @@ -0,0 +1,232 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Definitions for implementation details of the function adapter utility. + +#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_FUNCTION_ADAPTER_H_ +#define THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_FUNCTION_ADAPTER_H_ + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/time/time.h" +#include "common/casting.h" +#include "common/kind.h" +#include "common/value.h" + +namespace cel::runtime_internal { + +// Helper for triggering static asserts in an unspecialized template overload. +template +struct UnhandledType : std::false_type {}; + +// Adapts the type param Type to the appropriate Kind. +// A static assertion fails if the provided type does not map to a cel::Value +// kind. +template +constexpr Kind AdaptedKind() { + static_assert(UnhandledType::value, + "Unsupported primitive type to cel::Kind conversion"); + return Kind::kNotForUseWithExhaustiveSwitchStatements; +} + +template <> +constexpr Kind AdaptedKind() { + return Kind::kInt64; +} + +template <> +constexpr Kind AdaptedKind() { + return Kind::kUint64; +} + +template <> +constexpr Kind AdaptedKind() { + return Kind::kDouble; +} + +template <> +constexpr Kind AdaptedKind() { + return Kind::kBool; +} + +template <> +constexpr Kind AdaptedKind() { + return Kind::kTimestamp; +} + +template <> +constexpr Kind AdaptedKind() { + return Kind::kDuration; +} + +// Value types without a generic C++ type representation can be referenced by +// cref or value of the cel::*Value type. +#define VALUE_ADAPTED_KIND_OVL(value_type, kind) \ + template <> \ + constexpr Kind AdaptedKind() { \ + return kind; \ + } \ + \ + template <> \ + constexpr Kind AdaptedKind() { \ + return kind; \ + } + +VALUE_ADAPTED_KIND_OVL(Value, Kind::kAny); +VALUE_ADAPTED_KIND_OVL(StringValue, Kind::kString); +VALUE_ADAPTED_KIND_OVL(BytesValue, Kind::kBytes); +VALUE_ADAPTED_KIND_OVL(StructValue, Kind::kStruct); +VALUE_ADAPTED_KIND_OVL(MapValue, Kind::kMap); +VALUE_ADAPTED_KIND_OVL(ListValue, Kind::kList); +VALUE_ADAPTED_KIND_OVL(NullValue, Kind::kNullType); +VALUE_ADAPTED_KIND_OVL(OpaqueValue, Kind::kOpaque); +VALUE_ADAPTED_KIND_OVL(TypeValue, Kind::kType); + +#undef VALUE_ADAPTED_KIND_OVL + +// Adapt a Value to its corresponding argument type in a wrapped c++ +// function. +struct ValueToAdaptedVisitor { + absl::Status operator()(int64_t* out) const { + if (!input.IsInt()) { + return absl::InvalidArgumentError("expected int value"); + } + *out = input.GetInt().NativeValue(); + return absl::OkStatus(); + } + + absl::Status operator()(uint64_t* out) const { + if (!input.IsUint()) { + return absl::InvalidArgumentError("expected uint value"); + } + *out = input.GetUint().NativeValue(); + return absl::OkStatus(); + } + + absl::Status operator()(double* out) const { + if (!input.IsDouble()) { + return absl::InvalidArgumentError("expected double value"); + } + *out = input.GetDouble().NativeValue(); + return absl::OkStatus(); + } + + absl::Status operator()(bool* out) const { + if (!input.IsBool()) { + return absl::InvalidArgumentError("expected bool value"); + } + *out = input.GetBool().NativeValue(); + return absl::OkStatus(); + } + + absl::Status operator()(absl::Time* out) const { + if (!input.IsTimestamp()) { + return absl::InvalidArgumentError("expected timestamp value"); + } + *out = input.GetTimestamp().ToTime(); + return absl::OkStatus(); + } + + absl::Status operator()(absl::Duration* out) const { + if (!input.IsDuration()) { + return absl::InvalidArgumentError("expected duration value"); + } + *out = input.GetDuration().ToDuration(); + return absl::OkStatus(); + } + + absl::Status operator()(Value* out) const { + *out = input; + return absl::OkStatus(); + } + + absl::Status operator()(const Value** out) const { + *out = &input; + return absl::OkStatus(); + } + + template + absl::Status operator()(T* out) const { + if (!InstanceOf>(input)) { + return absl::InvalidArgumentError( + absl::StrCat("expected ", ValueKindToString(T::kKind), " value")); + } + *out = Cast>(input); + return absl::OkStatus(); + } + + template + absl::Status operator()(T** out) const { + if (!InstanceOf>(input)) { + return absl::InvalidArgumentError( + absl::StrCat("expected ", ValueKindToString(T::kKind), " value")); + } + static_assert(std::is_lvalue_reference_v< + decltype(Cast>(input))>, + "expected l-value reference return type for Cast."); + *out = &Cast>(input); + return absl::OkStatus(); + } + + const Value& input; +}; + +// Adapts the return value of a wrapped C++ function to its corresponding +// Value representation. +struct AdaptedToValueVisitor { + absl::StatusOr operator()(int64_t in) { return IntValue(in); } + + absl::StatusOr operator()(uint64_t in) { return UintValue(in); } + + absl::StatusOr operator()(double in) { return DoubleValue(in); } + + absl::StatusOr operator()(bool in) { return BoolValue(in); } + + absl::StatusOr operator()(absl::Time in) { + // Type matching may have already occurred. It's too late to change up the + // type and return an error. + return TimestampValue(in); + } + + absl::StatusOr operator()(absl::Duration in) { + // Type matching may have already occurred. It's too late to change up the + // type and return an error. + return DurationValue(in); + } + + absl::StatusOr operator()(Value in) { return in; } + + template + absl::StatusOr operator()(T in) { + return in; + } + + // Special case for StatusOr return value -- wrap the underlying value if + // present, otherwise return the status. + template + absl::StatusOr operator()(absl::StatusOr wrapped) { + if (!wrapped.ok()) { + return std::move(wrapped).status(); + } + return this->operator()(std::move(wrapped).value()); + } +}; + +} // namespace cel::runtime_internal + +#endif // THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_FUNCTION_ADAPTER_H_ diff --git a/runtime/internal/function_adapter_test.cc b/runtime/internal/function_adapter_test.cc new file mode 100644 index 000000000..643f08090 --- /dev/null +++ b/runtime/internal/function_adapter_test.cc @@ -0,0 +1,319 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "runtime/internal/function_adapter.h" + +#include + +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/time/time.h" +#include "common/kind.h" +#include "common/value.h" +#include "internal/testing.h" + +namespace cel::runtime_internal { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; + +static_assert(AdaptedKind() == Kind::kInt, "int adapts to int64_t"); +static_assert(AdaptedKind() == Kind::kUint, + "uint adapts to uint64_t"); +static_assert(AdaptedKind() == Kind::kDouble, + "double adapts to double"); +static_assert(AdaptedKind() == Kind::kBool, "bool adapts to bool"); +static_assert(AdaptedKind() == Kind::kTimestamp, + "timestamp adapts to absl::Time"); +static_assert(AdaptedKind() == Kind::kDuration, + "duration adapts to absl::Duration"); +// Handle types. +static_assert(AdaptedKind() == Kind::kAny, "any adapts to Value"); +static_assert(AdaptedKind() == Kind::kString, + "string adapts to String"); +static_assert(AdaptedKind() == Kind::kBytes, + "bytes adapts to Bytes"); +static_assert(AdaptedKind() == Kind::kStruct, + "struct adapts to StructValue"); +static_assert(AdaptedKind() == Kind::kList, + "list adapts to ListValue"); +static_assert(AdaptedKind() == Kind::kMap, "map adapts to MapValue"); +static_assert(AdaptedKind() == Kind::kNullType, + "null adapts to NullValue"); +static_assert(AdaptedKind() == Kind::kAny, + "any adapts to const Value&"); +static_assert(AdaptedKind() == Kind::kString, + "string adapts to const String&"); +static_assert(AdaptedKind() == Kind::kBytes, + "bytes adapts to const Bytes&"); +static_assert(AdaptedKind() == Kind::kStruct, + "struct adapts to const StructValue&"); +static_assert(AdaptedKind() == Kind::kList, + "list adapts to const ListValue&"); +static_assert(AdaptedKind() == Kind::kMap, + "map adapts to const MapValue&"); +static_assert(AdaptedKind() == Kind::kNullType, + "null adapts to const NullValue&"); + +class ValueToAdaptedVisitorTest : public ::testing::Test {}; + +TEST_F(ValueToAdaptedVisitorTest, Int) { + Value v = cel::IntValue(10); + + int64_t out; + ASSERT_THAT(ValueToAdaptedVisitor{v}(&out), IsOk()); + + EXPECT_EQ(out, 10); +} + +TEST_F(ValueToAdaptedVisitorTest, IntWrongKind) { + Value v = cel::UintValue(10); + + int64_t out; + EXPECT_THAT( + ValueToAdaptedVisitor{v}(&out), + StatusIs(absl::StatusCode::kInvalidArgument, "expected int value")); +} + +TEST_F(ValueToAdaptedVisitorTest, Uint) { + Value v = cel::UintValue(11); + + uint64_t out; + ASSERT_THAT(ValueToAdaptedVisitor{v}(&out), IsOk()); + + EXPECT_EQ(out, 11); +} + +TEST_F(ValueToAdaptedVisitorTest, UintWrongKind) { + Value v = cel::IntValue(11); + + uint64_t out; + EXPECT_THAT( + ValueToAdaptedVisitor{v}(&out), + StatusIs(absl::StatusCode::kInvalidArgument, "expected uint value")); +} + +TEST_F(ValueToAdaptedVisitorTest, Double) { + Value v = cel::DoubleValue(12.0); + + double out; + ASSERT_THAT(ValueToAdaptedVisitor{v}(&out), IsOk()); + + EXPECT_EQ(out, 12.0); +} + +TEST_F(ValueToAdaptedVisitorTest, DoubleWrongKind) { + Value v = cel::UintValue(10); + + double out; + EXPECT_THAT( + ValueToAdaptedVisitor{v}(&out), + StatusIs(absl::StatusCode::kInvalidArgument, "expected double value")); +} + +TEST_F(ValueToAdaptedVisitorTest, Bool) { + Value v = cel::BoolValue(false); + + bool out; + ASSERT_THAT(ValueToAdaptedVisitor{v}(&out), IsOk()); + + EXPECT_EQ(out, false); +} + +TEST_F(ValueToAdaptedVisitorTest, BoolWrongKind) { + Value v = cel::UintValue(10); + + bool out; + EXPECT_THAT( + ValueToAdaptedVisitor{v}(&out), + StatusIs(absl::StatusCode::kInvalidArgument, "expected bool value")); +} + +TEST_F(ValueToAdaptedVisitorTest, Timestamp) { + Value v = cel::TimestampValue(absl::UnixEpoch() + absl::Seconds(1)); + + absl::Time out; + ASSERT_THAT(ValueToAdaptedVisitor{v}(&out), IsOk()); + + EXPECT_EQ(out, absl::UnixEpoch() + absl::Seconds(1)); +} + +TEST_F(ValueToAdaptedVisitorTest, TimestampWrongKind) { + Value v = cel::UintValue(10); + + absl::Time out; + EXPECT_THAT( + ValueToAdaptedVisitor{v}(&out), + StatusIs(absl::StatusCode::kInvalidArgument, "expected timestamp value")); +} + +TEST_F(ValueToAdaptedVisitorTest, Duration) { + Value v = cel::DurationValue(absl::Seconds(5)); + + absl::Duration out; + ASSERT_THAT(ValueToAdaptedVisitor{v}(&out), IsOk()); + + EXPECT_EQ(out, absl::Seconds(5)); +} + +TEST_F(ValueToAdaptedVisitorTest, DurationWrongKind) { + Value v = cel::UintValue(10); + + absl::Duration out; + EXPECT_THAT( + ValueToAdaptedVisitor{v}(&out), + StatusIs(absl::StatusCode::kInvalidArgument, "expected duration value")); +} + +TEST_F(ValueToAdaptedVisitorTest, String) { + Value v = cel::StringValue("string"); + + StringValue out; + ASSERT_THAT(ValueToAdaptedVisitor{v}(&out), IsOk()); + + EXPECT_EQ(out.ToString(), "string"); +} + +TEST_F(ValueToAdaptedVisitorTest, StringWrongKind) { + Value v = cel::UintValue(10); + + StringValue out; + EXPECT_THAT( + ValueToAdaptedVisitor{v}(&out), + StatusIs(absl::StatusCode::kInvalidArgument, "expected string value")); +} + +TEST_F(ValueToAdaptedVisitorTest, Bytes) { + Value v = cel::BytesValue("bytes"); + + BytesValue out; + ASSERT_THAT(ValueToAdaptedVisitor{v}(&out), IsOk()); + + EXPECT_EQ(out.ToString(), "bytes"); +} + +TEST_F(ValueToAdaptedVisitorTest, BytesWrongKind) { + Value v = cel::UintValue(10); + + BytesValue out; + EXPECT_THAT( + ValueToAdaptedVisitor{v}(&out), + StatusIs(absl::StatusCode::kInvalidArgument, "expected bytes value")); +} + +class AdaptedToValueVisitorTest : public ::testing::Test {}; + +TEST_F(AdaptedToValueVisitorTest, Int) { + int64_t value = 10; + + ASSERT_OK_AND_ASSIGN(auto result, AdaptedToValueVisitor{}(value)); + + ASSERT_TRUE(result.IsInt()); + EXPECT_EQ(result.GetInt().NativeValue(), 10); +} + +TEST_F(AdaptedToValueVisitorTest, Double) { + double value = 10; + + ASSERT_OK_AND_ASSIGN(auto result, AdaptedToValueVisitor{}(value)); + + ASSERT_TRUE(result.IsDouble()); + EXPECT_EQ(result.GetDouble().NativeValue(), 10.0); +} + +TEST_F(AdaptedToValueVisitorTest, Uint) { + uint64_t value = 10; + + ASSERT_OK_AND_ASSIGN(auto result, AdaptedToValueVisitor{}(value)); + + ASSERT_TRUE(result.IsUint()); + EXPECT_EQ(result.GetUint().NativeValue(), 10); +} + +TEST_F(AdaptedToValueVisitorTest, Bool) { + bool value = true; + + ASSERT_OK_AND_ASSIGN(auto result, AdaptedToValueVisitor{}(value)); + + ASSERT_TRUE(result.IsBool()); + EXPECT_EQ(result.GetBool().NativeValue(), true); +} + +TEST_F(AdaptedToValueVisitorTest, Timestamp) { + absl::Time value = absl::UnixEpoch() + absl::Seconds(10); + + ASSERT_OK_AND_ASSIGN(auto result, AdaptedToValueVisitor{}(value)); + + ASSERT_TRUE(result.IsTimestamp()); + EXPECT_EQ(result.GetTimestamp().ToTime(), + absl::UnixEpoch() + absl::Seconds(10)); +} + +TEST_F(AdaptedToValueVisitorTest, Duration) { + absl::Duration value = absl::Seconds(5); + + ASSERT_OK_AND_ASSIGN(auto result, AdaptedToValueVisitor{}(value)); + + ASSERT_TRUE(result.IsDuration()); + EXPECT_EQ(result.GetDuration().ToDuration(), absl::Seconds(5)); +} + +TEST_F(AdaptedToValueVisitorTest, String) { + StringValue value = cel::StringValue("str"); + + ASSERT_OK_AND_ASSIGN(auto result, AdaptedToValueVisitor{}(value)); + + ASSERT_TRUE(result.IsString()); + EXPECT_EQ(result.GetString().ToString(), "str"); +} + +TEST_F(AdaptedToValueVisitorTest, Bytes) { + BytesValue value = cel::BytesValue("bytes"); + + ASSERT_OK_AND_ASSIGN(auto result, AdaptedToValueVisitor{}(value)); + + ASSERT_TRUE(result.IsBytes()); + EXPECT_EQ(result.GetBytes().ToString(), "bytes"); +} + +TEST_F(AdaptedToValueVisitorTest, StatusOrValue) { + absl::StatusOr value = 10; + + ASSERT_OK_AND_ASSIGN(auto result, AdaptedToValueVisitor{}(value)); + + ASSERT_TRUE(result.IsInt()); + EXPECT_EQ(result.GetInt().NativeValue(), 10); +} + +TEST_F(AdaptedToValueVisitorTest, StatusOrError) { + absl::StatusOr value = absl::InternalError("test_error"); + + EXPECT_THAT(AdaptedToValueVisitor{}(value).status(), + StatusIs(absl::StatusCode::kInternal, "test_error")); +} + +TEST_F(AdaptedToValueVisitorTest, Any) { + auto handle = cel::ErrorValue(absl::InternalError("test_error")); + + ASSERT_OK_AND_ASSIGN(auto result, AdaptedToValueVisitor{}(handle)); + + ASSERT_TRUE(result.IsError()); + EXPECT_THAT(result.GetError().NativeValue(), + StatusIs(absl::StatusCode::kInternal, "test_error")); +} + +} // namespace +} // namespace cel::runtime_internal diff --git a/runtime/internal/issue_collector.h b/runtime/internal/issue_collector.h new file mode 100644 index 000000000..e3a294d4f --- /dev/null +++ b/runtime/internal/issue_collector.h @@ -0,0 +1,64 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_ISSUE_COLLECTOR_H_ +#define THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_ISSUE_COLLECTOR_H_ + +#include +#include + +#include "absl/status/status.h" +#include "absl/types/span.h" +#include "runtime/runtime_issue.h" + +namespace cel::runtime_internal { + +// IssueCollector collects issues and reports absl::Status according to the +// configured severity limit. +class IssueCollector { + public: + // Args: + // severity: inclusive limit for issues to return as non-ok absl::Status. + explicit IssueCollector(RuntimeIssue::Severity severity_limit) + : severity_limit_(severity_limit) {} + + // move-only. + IssueCollector(const IssueCollector&) = delete; + IssueCollector& operator=(const IssueCollector&) = delete; + IssueCollector(IssueCollector&&) = default; + IssueCollector& operator=(IssueCollector&&) = default; + + // Collect an Issue. + // Returns a status according to the IssueCollector's policy and the given + // Issue. + // The Issue is always added to issues, regardless of whether AddIssue returns + // a non-ok status. + absl::Status AddIssue(RuntimeIssue issue) { + issues_.push_back(std::move(issue)); + if (issues_.back().severity() >= severity_limit_) { + return issues_.back().ToStatus(); + } + return absl::OkStatus(); + } + + absl::Span issues() const { return issues_; } + std::vector ExtractIssues() { return std::move(issues_); } + + private: + RuntimeIssue::Severity severity_limit_; + std::vector issues_; +}; + +} // namespace cel::runtime_internal + +#endif // THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_ISSUE_COLLECTOR_H_ diff --git a/runtime/internal/issue_collector_test.cc b/runtime/internal/issue_collector_test.cc new file mode 100644 index 000000000..c7caaaf9c --- /dev/null +++ b/runtime/internal/issue_collector_test.cc @@ -0,0 +1,94 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "runtime/internal/issue_collector.h" + +#include "absl/status/status.h" +#include "internal/testing.h" +#include "runtime/runtime_issue.h" + +namespace cel::runtime_internal { +namespace { + +using ::absl_testing::StatusIs; +using ::testing::ElementsAre; +using ::testing::Truly; + +template +bool ApplyMatcher(Matcher m, const T& t) { + return static_cast>(m).Matches(t); +} + +TEST(IssueCollector, CollectsIssues) { + IssueCollector issue_collector(RuntimeIssue::Severity::kError); + + EXPECT_THAT(issue_collector.AddIssue( + RuntimeIssue::CreateError(absl::InvalidArgumentError("e1"))), + StatusIs(absl::StatusCode::kInvalidArgument, "e1")); + ASSERT_OK(issue_collector.AddIssue(RuntimeIssue::CreateWarning( + absl::InvalidArgumentError("w1"), + RuntimeIssue::ErrorCode::kNoMatchingOverload))); + + EXPECT_THAT( + issue_collector.issues(), + ElementsAre( + Truly([](const RuntimeIssue& issue) { + return issue.severity() == RuntimeIssue::Severity::kError && + issue.error_code() == RuntimeIssue::ErrorCode::kOther && + ApplyMatcher( + StatusIs(absl::StatusCode::kInvalidArgument, "e1"), + issue.ToStatus()); + }), + Truly([](const RuntimeIssue& issue) { + return issue.severity() == RuntimeIssue::Severity::kWarning && + issue.error_code() == + RuntimeIssue::ErrorCode::kNoMatchingOverload && + ApplyMatcher( + StatusIs(absl::StatusCode::kInvalidArgument, "w1"), + issue.ToStatus()); + }))); +} + +TEST(IssueCollector, ReturnsStatusAtLimit) { + IssueCollector issue_collector(RuntimeIssue::Severity::kWarning); + + EXPECT_THAT(issue_collector.AddIssue( + RuntimeIssue::CreateError(absl::InvalidArgumentError("e1"))), + StatusIs(absl::StatusCode::kInvalidArgument, "e1")); + + EXPECT_THAT(issue_collector.AddIssue(RuntimeIssue::CreateWarning( + absl::InvalidArgumentError("w1"), + RuntimeIssue::ErrorCode::kNoMatchingOverload)), + StatusIs(absl::StatusCode::kInvalidArgument, "w1")); + + EXPECT_THAT( + issue_collector.issues(), + ElementsAre( + Truly([](const RuntimeIssue& issue) { + return issue.severity() == RuntimeIssue::Severity::kError && + issue.error_code() == RuntimeIssue::ErrorCode::kOther && + ApplyMatcher( + StatusIs(absl::StatusCode::kInvalidArgument, "e1"), + issue.ToStatus()); + }), + Truly([](const RuntimeIssue& issue) { + return issue.severity() == RuntimeIssue::Severity::kWarning && + issue.error_code() == + RuntimeIssue::ErrorCode::kNoMatchingOverload && + ApplyMatcher( + StatusIs(absl::StatusCode::kInvalidArgument, "w1"), + issue.ToStatus()); + }))); +} +} // namespace +} // namespace cel::runtime_internal diff --git a/runtime/internal/legacy_runtime_type_provider.h b/runtime/internal/legacy_runtime_type_provider.h new file mode 100644 index 000000000..503a79b46 --- /dev/null +++ b/runtime/internal/legacy_runtime_type_provider.h @@ -0,0 +1,37 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_LEGACY_RUNTIME_TYPE_PROVIDER_H_ +#define THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_LEGACY_RUNTIME_TYPE_PROVIDER_H_ + +#include "absl/base/nullability.h" +#include "eval/public/structs/protobuf_descriptor_type_provider.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel::runtime_internal { + +class LegacyRuntimeTypeProvider final + : public google::api::expr::runtime::ProtobufDescriptorProvider { + public: + LegacyRuntimeTypeProvider( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nullable message_factory) + : google::api::expr::runtime::ProtobufDescriptorProvider( + descriptor_pool, message_factory) {} +}; + +} // namespace cel::runtime_internal + +#endif // THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_LEGACY_RUNTIME_TYPE_PROVIDER_H_ diff --git a/runtime/internal/runtime_env.cc b/runtime/internal/runtime_env.cc new file mode 100644 index 000000000..fe5b47330 --- /dev/null +++ b/runtime/internal/runtime_env.cc @@ -0,0 +1,73 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "runtime/internal/runtime_env.h" + +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/synchronization/mutex.h" +#include "internal/noop_delete.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/dynamic_message.h" +#include "google/protobuf/message.h" + +namespace cel::runtime_internal { + +RuntimeEnv::KeepAlives::~KeepAlives() { + while (!deque.empty()) { + deque.pop_back(); + } +} + +google::protobuf::MessageFactory* absl_nonnull RuntimeEnv::MutableMessageFactory() const { + google::protobuf::MessageFactory* absl_nullable shared_message_factory = + message_factory_ptr.load(std::memory_order_relaxed); + if (shared_message_factory != nullptr) { + return shared_message_factory; + } + absl::MutexLock lock(message_factory_mutex); + shared_message_factory = message_factory_ptr.load(std::memory_order_relaxed); + if (shared_message_factory == nullptr) { + if (descriptor_pool.get() == google::protobuf::DescriptorPool::generated_pool()) { + // Using the generated descriptor pool, just use the generated message + // factory. + message_factory = std::shared_ptr( + google::protobuf::MessageFactory::generated_factory(), + internal::NoopDeleteFor()); + } else { + auto dynamic_message_factory = + std::make_shared(); + // Ensure we do not delegate to the generated factory, if the default + // every changes. We prefer being hermetic. + dynamic_message_factory->SetDelegateToGeneratedFactory(false); + message_factory = std::move(dynamic_message_factory); + } + shared_message_factory = message_factory.get(); + message_factory_ptr.store(shared_message_factory, + std::memory_order_seq_cst); + } + return shared_message_factory; +} + +void RuntimeEnv::KeepAlive(std::shared_ptr keep_alive) { + if (keep_alive == nullptr) { + return; + } + keep_alives.deque.push_back(std::move(keep_alive)); +} + +} // namespace cel::runtime_internal diff --git a/runtime/internal/runtime_env.h b/runtime/internal/runtime_env.h new file mode 100644 index 000000000..cb9d9b93d --- /dev/null +++ b/runtime/internal/runtime_env.h @@ -0,0 +1,134 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_RUNTIME_ENV_H_ +#define THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_RUNTIME_ENV_H_ + +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/base/thread_annotations.h" +#include "absl/status/status.h" +#include "absl/synchronization/mutex.h" +#include "eval/public/cel_function_registry.h" +#include "eval/public/cel_type_registry.h" +#include "internal/well_known_types.h" +#include "runtime/function_registry.h" +#include "runtime/type_registry.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel::runtime_internal { + +// Shared state used by the runtime during creation, configuration, planning, +// and evaluation. Passed around via `std::shared_ptr`. +// +// TODO(uncreated-issue/66): Make this a class. +struct RuntimeEnv final { + explicit RuntimeEnv(absl_nonnull std::shared_ptr + descriptor_pool, + absl_nullable std::shared_ptr + message_factory = nullptr) + : descriptor_pool(std::move(descriptor_pool)), + message_factory(std::move(message_factory)), + legacy_type_registry(this->descriptor_pool.get(), + this->message_factory.get()), + type_registry(legacy_type_registry.InternalGetModernRegistry()), + function_registry(legacy_function_registry.InternalGetRegistry()) { + if (this->message_factory != nullptr) { + message_factory_ptr.store(this->message_factory.get(), + std::memory_order_seq_cst); + } + } + + // Not copyable or moveable. + RuntimeEnv(const RuntimeEnv&) = delete; + RuntimeEnv(RuntimeEnv&&) = delete; + RuntimeEnv& operator=(const RuntimeEnv&) = delete; + RuntimeEnv& operator=(RuntimeEnv&&) = delete; + + // Ideally the environment would already be initialized, but things are a bit + // awkward. This should only be called once immediately after construction. + absl::Status Initialize() { + return well_known_types.Initialize(descriptor_pool.get()); + } + + bool IsInitialized() const { return well_known_types.IsInitialized(); } + + ABSL_ATTRIBUTE_UNUSED + const absl_nonnull std::shared_ptr + descriptor_pool; + + private: + // These fields deal with a message factory that is lazily initialized as + // needed. This might be called during the planning phase of an expression or + // during evaluation. We want the ability to get the message factory when it + // is already created to be cheap, so we use an atomic and a mutex for the + // slow path. + // + // Do not access any of these fields directly, use member functions. + mutable absl::Mutex message_factory_mutex; + mutable absl_nullable std::shared_ptr message_factory + ABSL_GUARDED_BY(message_factory_mutex); + // std::atomic> is not really a simple atomic, so we + // avoid it. + mutable std::atomic + message_factory_ptr = nullptr; + + struct KeepAlives final { + KeepAlives() = default; + + ~KeepAlives(); + + // Not copyable or moveable. + KeepAlives(const KeepAlives&) = delete; + KeepAlives(KeepAlives&&) = delete; + KeepAlives& operator=(const KeepAlives&) = delete; + KeepAlives& operator=(KeepAlives&&) = delete; + + std::deque> deque; + }; + + KeepAlives keep_alives; + + public: + // Because of legacy shenanigans, we use shared_ptr here. For legacy, this is + // an unowned shared_ptr (a noop deleter) pointing to the modern equivalent + // which is a member of the legacy variant. + google::api::expr::runtime::CelTypeRegistry legacy_type_registry; + google::api::expr::runtime::CelFunctionRegistry legacy_function_registry; + TypeRegistry& type_registry; + FunctionRegistry& function_registry; + + well_known_types::Reflection well_known_types; + + google::protobuf::MessageFactory* absl_nonnull MutableMessageFactory() const + ABSL_ATTRIBUTE_LIFETIME_BOUND; + + // Not thread safe. Adds `keep_alive` to a list owned by this environment + // and ensures it survives at least as long as this environment. Keep alives + // are released in reverse order of their registration. This mimics normal + // destructor rules of members. + // + // IMPORTANT: This should only be when building the runtime, and not after. + void KeepAlive(std::shared_ptr keep_alive); +}; + +} // namespace cel::runtime_internal + +#endif // THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_RUNTIME_ENV_H_ diff --git a/runtime/internal/runtime_env_testing.cc b/runtime/internal/runtime_env_testing.cc new file mode 100644 index 000000000..6de4fffcf --- /dev/null +++ b/runtime/internal/runtime_env_testing.cc @@ -0,0 +1,39 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "runtime/internal/runtime_env_testing.h" + +#include + +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "internal/noop_delete.h" +#include "internal/testing_descriptor_pool.h" +#include "internal/testing_message_factory.h" +#include "runtime/internal/runtime_env.h" +#include "google/protobuf/message.h" + +namespace cel::runtime_internal { + +absl_nonnull std::shared_ptr NewTestingRuntimeEnv() { + auto env = std::make_shared( + internal::GetSharedTestingDescriptorPool(), + std::shared_ptr( + internal::GetTestingMessageFactory(), + internal::NoopDeleteFor())); + ABSL_CHECK_OK(env->Initialize()); // Crash OK + return env; +} + +} // namespace cel::runtime_internal diff --git a/runtime/internal/runtime_env_testing.h b/runtime/internal/runtime_env_testing.h new file mode 100644 index 000000000..71b2096cd --- /dev/null +++ b/runtime/internal/runtime_env_testing.h @@ -0,0 +1,29 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_RUNTIME_ENV_TESTING_H_ +#define THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_RUNTIME_ENV_TESTING_H_ + +#include + +#include "absl/base/nullability.h" +#include "runtime/internal/runtime_env.h" + +namespace cel::runtime_internal { + +absl_nonnull std::shared_ptr NewTestingRuntimeEnv(); + +} // namespace cel::runtime_internal + +#endif // THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_RUNTIME_ENV_TESTING_H_ diff --git a/runtime/internal/runtime_friend_access.h b/runtime/internal/runtime_friend_access.h new file mode 100644 index 000000000..715f95550 --- /dev/null +++ b/runtime/internal/runtime_friend_access.h @@ -0,0 +1,45 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_RUNTIME_FRIEND_ACCESS_H_ +#define THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_RUNTIME_FRIEND_ACCESS_H_ + +#include "common/native_type.h" +#include "runtime/runtime.h" +#include "runtime/runtime_builder.h" + +namespace cel::runtime_internal { + +// Provide accessors for friend-visibility internal runtime details. +// +// CEL supported runtime extensions need implementation specific details to work +// correctly. We restrict access to prevent external usages since we don't +// guarantee stability on the implementation details. +class RuntimeFriendAccess { + public: + // Access underlying runtime instance. + static Runtime& GetMutableRuntime(RuntimeBuilder& builder) { + return builder.runtime(); + } + + // Return the internal type_id for the runtime instance for checked down + // casting. + static NativeTypeId RuntimeTypeId(Runtime& runtime) { + return runtime.GetNativeTypeId(); + } +}; + +} // namespace cel::runtime_internal + +#endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_RUNTIME_EXTENSIONS_FRIEND_ACCESS_H_ diff --git a/runtime/internal/runtime_impl.cc b/runtime/internal/runtime_impl.cc new file mode 100644 index 000000000..92d097b2c --- /dev/null +++ b/runtime/internal/runtime_impl.cc @@ -0,0 +1,159 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "runtime/internal/runtime_impl.h" + +#include +#include + +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/status/statusor.h" +#include "base/ast.h" +#include "base/type_provider.h" +#include "common/native_type.h" +#include "common/value.h" +#include "eval/eval/attribute_trail.h" +#include "eval/eval/comprehension_slots.h" +#include "eval/eval/direct_expression_step.h" +#include "eval/eval/evaluator_core.h" +#include "internal/casts.h" +#include "internal/status_macros.h" +#include "runtime/activation_interface.h" +#include "runtime/runtime.h" +#include "google/protobuf/arena.h" + +namespace cel::runtime_internal { +namespace { + +using ::google::api::expr::runtime::AttributeTrail; +using ::google::api::expr::runtime::ComprehensionSlots; +using ::google::api::expr::runtime::DirectExpressionStep; +using ::google::api::expr::runtime::ExecutionFrameBase; +using ::google::api::expr::runtime::FlatExpression; +using ::google::api::expr::runtime::WrappedDirectStep; + +class ProgramImpl final : public TraceableProgram { + public: + using EvaluationListener = TraceableProgram::EvaluationListener; + ProgramImpl( + const std::shared_ptr& environment, + FlatExpression impl) + : environment_(environment), impl_(std::move(impl)) {} + + absl::StatusOr TraceImpl( + const ActivationInterface& activation, + EvaluationListener evaluation_listener, google::protobuf::Arena* absl_nonnull arena, + const EvaluateOptions& options) const override { + ABSL_DCHECK(arena != nullptr); + auto state = + impl_.MakeEvaluatorState(environment_->descriptor_pool.get(), + options.message_factory != nullptr + ? options.message_factory + : environment_->MutableMessageFactory(), + arena); + return impl_.EvaluateWithCallback(activation, options.embedder_context, + std::move(evaluation_listener), state); + } + + const TypeProvider& GetTypeProvider() const override { + return environment_->type_registry.GetComposedTypeProvider(); + } + + private: + // Keep the Runtime environment alive while programs reference it. + std::shared_ptr environment_; + FlatExpression impl_; +}; + +class RecursiveProgramImpl final : public TraceableProgram { + public: + using EvaluationListener = TraceableProgram::EvaluationListener; + RecursiveProgramImpl( + const std::shared_ptr& environment, + FlatExpression impl, const DirectExpressionStep* absl_nonnull root) + : environment_(environment), impl_(std::move(impl)), root_(root) {} + + absl::StatusOr TraceImpl( + const ActivationInterface& activation, + EvaluationListener evaluation_listener, google::protobuf::Arena* absl_nonnull arena, + const EvaluateOptions& options) const override { + ABSL_DCHECK(arena != nullptr); + ComprehensionSlots slots(impl_.comprehension_slots_size()); + ExecutionFrameBase frame(activation, std::move(evaluation_listener), + impl_.options(), GetTypeProvider(), + environment_->descriptor_pool.get(), + options.message_factory != nullptr + ? options.message_factory + : environment_->MutableMessageFactory(), + arena, options.embedder_context, slots); + + Value result; + AttributeTrail attribute; + CEL_RETURN_IF_ERROR(root_->Evaluate(frame, result, attribute)); + + return result; + } + + const TypeProvider& GetTypeProvider() const override { + return environment_->type_registry.GetComposedTypeProvider(); + } + + private: + // Keep the Runtime environment alive while programs reference it. + std::shared_ptr environment_; + FlatExpression impl_; + const DirectExpressionStep* absl_nonnull root_; +}; + +} // namespace + +absl::StatusOr> RuntimeImpl::CreateProgram( + std::unique_ptr ast, + const Runtime::CreateProgramOptions& options) const { + return CreateTraceableProgram(std::move(ast), options); +} + +absl::StatusOr> +RuntimeImpl::CreateTraceableProgram( + std::unique_ptr ast, + const Runtime::CreateProgramOptions& options) const { + CEL_ASSIGN_OR_RETURN(auto flat_expr, expr_builder_.CreateExpressionImpl( + std::move(ast), options.issues)); + + // Special case if the program is fully recursive. + // + // This implementation avoids unnecessary allocs at evaluation time which + // improves performance notably for small expressions. + if (expr_builder_.options().max_recursion_depth != 0 && + !flat_expr.subexpressions().empty() && + // mainline expression is exactly one recursive step. + flat_expr.subexpressions().front().size() == 1 && + flat_expr.subexpressions().front().front()->GetNativeTypeId() == + NativeTypeId::For()) { + const DirectExpressionStep* root = + internal::down_cast( + flat_expr.subexpressions().front().front().get()) + ->wrapped(); + return std::make_unique(environment_, + std::move(flat_expr), root); + } + + return std::make_unique(environment_, std::move(flat_expr)); +} + +bool TestOnly_IsRecursiveImpl(const Program* program) { + return dynamic_cast(program) != nullptr; +} + +} // namespace cel::runtime_internal diff --git a/runtime/internal/runtime_impl.h b/runtime/internal/runtime_impl.h new file mode 100644 index 000000000..7c5d445f9 --- /dev/null +++ b/runtime/internal/runtime_impl.h @@ -0,0 +1,125 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_RUNTIME_IMPL_H_ +#define THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_RUNTIME_IMPL_H_ + +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/status/statusor.h" +#include "base/ast.h" +#include "base/type_provider.h" +#include "common/native_type.h" +#include "eval/compiler/flat_expr_builder.h" +#include "internal/well_known_types.h" +#include "runtime/function_registry.h" +#include "runtime/internal/runtime_env.h" +#include "runtime/runtime.h" +#include "runtime/runtime_options.h" +#include "runtime/type_registry.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel::runtime_internal { + +class RuntimeImpl : public Runtime { + public: + using Environment = RuntimeEnv; + + RuntimeImpl(absl_nonnull std::shared_ptr environment, + const RuntimeOptions& options) + : environment_(std::move(environment)), + expr_builder_(environment_, options) { + ABSL_DCHECK(environment_->well_known_types.IsInitialized()); + } + + TypeRegistry& type_registry() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return environment_->type_registry; + } + const TypeRegistry& type_registry() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return environment_->type_registry; + } + + FunctionRegistry& function_registry() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return environment_->function_registry; + } + const FunctionRegistry& function_registry() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return environment_->function_registry; + } + + const well_known_types::Reflection& well_known_types() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return environment_->well_known_types; + } + + Environment& environment() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return *environment_; + } + const Environment& environment() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return *environment_; + } + + // implement Runtime + absl::StatusOr> CreateProgram( + std::unique_ptr ast, + const Runtime::CreateProgramOptions& options) const final; + + absl::StatusOr> CreateTraceableProgram( + std::unique_ptr ast, + const Runtime::CreateProgramOptions& options) const override; + + const TypeProvider& GetTypeProvider() const override { + return environment_->type_registry.GetComposedTypeProvider(); + } + + const google::protobuf::DescriptorPool* absl_nonnull GetDescriptorPool() + const override { + return environment_->descriptor_pool.get(); + } + + google::protobuf::MessageFactory* absl_nonnull GetMessageFactory() const override { + return environment_->MutableMessageFactory(); + } + + // exposed for extensions access + google::api::expr::runtime::FlatExprBuilder& expr_builder() + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return expr_builder_; + } + + private: + NativeTypeId GetNativeTypeId() const override { + return NativeTypeId::For(); + } + // Note: this is mutable, but should only be accessed in a const context after + // building is complete. + // + // This is used to keep alive the registries while programs reference them. + std::shared_ptr environment_; + google::api::expr::runtime::FlatExprBuilder expr_builder_; +}; + +// Exposed for testing to validate program is recursively planned. +// +// Uses dynamic_casts to test. +bool TestOnly_IsRecursiveImpl(const Program* program); + +} // namespace cel::runtime_internal + +#endif // THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_RUNTIME_IMPL_H_ diff --git a/runtime/internal/runtime_type_provider.cc b/runtime/internal/runtime_type_provider.cc new file mode 100644 index 000000000..40f5ff575 --- /dev/null +++ b/runtime/internal/runtime_type_provider.cc @@ -0,0 +1,119 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "runtime/internal/runtime_type_provider.h" + +#include + +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "common/type.h" +#include "common/type_introspector.h" +#include "common/value.h" +#include "common/values/value_builder.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel::runtime_internal { + +absl::Status RuntimeTypeProvider::RegisterType(const OpaqueType& type) { + auto insertion = types_.insert(std::pair{type.name(), Type(type)}); + if (!insertion.second) { + return absl::AlreadyExistsError( + absl::StrCat("type already registered: ", insertion.first->first)); + } + return absl::OkStatus(); +} + +absl::StatusOr> RuntimeTypeProvider::FindTypeImpl( + absl::string_view name) const { + auto type = FindWellKnownType(name); + if (type.has_value()) { + return type; + } + const auto* desc = descriptor_pool_->FindMessageTypeByName(name); + if (desc != nullptr) { + return MessageType(desc); + } + + if (const auto it = types_.find(name); it != types_.end()) { + return it->second; + } + return absl::nullopt; +} + +absl::StatusOr> +RuntimeTypeProvider::FindEnumConstantImpl(absl::string_view type, + absl::string_view value) const { + auto enum_constant = FindWellKnownTypeEnumConstant(type, value); + if (enum_constant.has_value()) { + return enum_constant; + } + const google::protobuf::EnumDescriptor* enum_desc = + descriptor_pool_->FindEnumTypeByName(type); + if (enum_desc == nullptr) { + return absl::nullopt; + } + + // Note: we don't support strong enum typing at this time so only the fully + // qualified enum values are meaningful, so we don't provide any signal if the + // enum type is found but can't match the value name. + const google::protobuf::EnumValueDescriptor* value_desc = + enum_desc->FindValueByName(value); + if (value_desc == nullptr) { + return absl::nullopt; + } + + return TypeIntrospector::EnumConstant{ + EnumType(enum_desc), enum_desc->full_name(), value_desc->name(), + value_desc->number()}; +} + +absl::StatusOr> +RuntimeTypeProvider::FindStructTypeFieldByNameImpl( + absl::string_view type, absl::string_view name) const { + auto field = FindWellKnownTypeFieldByName(type, name); + if (field.has_value()) { + return field; + } + const auto* desc = descriptor_pool_->FindMessageTypeByName(type); + if (desc == nullptr) { + return absl::nullopt; + } + const auto* field_desc = desc->FindFieldByName(name); + if (field_desc == nullptr) { + field_desc = descriptor_pool_->FindExtensionByPrintableName(desc, name); + if (field_desc == nullptr) { + return absl::nullopt; + } + } + return MessageTypeField(field_desc); +} + +absl::StatusOr +RuntimeTypeProvider::NewValueBuilder( + absl::string_view name, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const { + return common_internal::NewValueBuilder(arena, descriptor_pool_, + message_factory, name); +} + +} // namespace cel::runtime_internal diff --git a/runtime/internal/runtime_type_provider.h b/runtime/internal/runtime_type_provider.h new file mode 100644 index 000000000..3f418af4d --- /dev/null +++ b/runtime/internal/runtime_type_provider.h @@ -0,0 +1,63 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_RUNTIME_TYPE_PROVIDER_H_ +#define THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_RUNTIME_TYPE_PROVIDER_H_ + +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "common/type.h" +#include "common/type_reflector.h" +#include "common/value.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel::runtime_internal { + +class RuntimeTypeProvider final : public TypeReflector { + public: + explicit RuntimeTypeProvider( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool) + : descriptor_pool_(descriptor_pool) {} + + absl::Status RegisterType(const OpaqueType& type); + + absl::StatusOr NewValueBuilder( + absl::string_view name, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const override; + + protected: + absl::StatusOr> FindTypeImpl( + absl::string_view name) const override; + + absl::StatusOr> FindEnumConstantImpl( + absl::string_view type, absl::string_view value) const override; + + absl::StatusOr> FindStructTypeFieldByNameImpl( + absl::string_view type, absl::string_view name) const override; + + private: + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool_; + absl::flat_hash_map types_; +}; + +} // namespace cel::runtime_internal + +#endif // THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_RUNTIME_TYPE_PROVIDER_H_ diff --git a/runtime/memory_safety_test.cc b/runtime/memory_safety_test.cc new file mode 100644 index 000000000..a60b4ce60 --- /dev/null +++ b/runtime/memory_safety_test.cc @@ -0,0 +1,1082 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Tests for memory safety using the CEL Evaluator. +#include +#include +#include +#include +#include +#include + +#include "google/protobuf/any.pb.h" +#include "absl/base/no_destructor.h" +#include "absl/container/flat_hash_map.h" +#include "absl/log/absl_check.h" +#include "absl/memory/memory.h" +#include "absl/status/statusor.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/types/variant.h" +#include "checker/validation_result.h" +#include "common/decl.h" +#include "common/type.h" +#include "common/value.h" +#include "common/value_testing.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "compiler/optional.h" +#include "compiler/standard_library.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "runtime/activation.h" +#include "runtime/constant_folding.h" +#include "runtime/function_adapter.h" +#include "runtime/optional_types.h" +#include "runtime/reference_resolver.h" +#include "runtime/regex_precompilation.h" +#include "runtime/runtime.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_options.h" +#include "runtime/standard_runtime_builder_factory.h" +#include "cel/expr/conformance/proto3/test_all_types.pb.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" +#include "google/protobuf/text_format.h" +#include "google/protobuf/util/message_differencer.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOkAndHolds; +using ::cel::expr::conformance::proto3::NestedTestAllTypes; +using ::cel::expr::conformance::proto3::TestAllTypes; +using ::cel::test::StringValueIs; +using ::cel::test::ValueMatcher; +using ::google::protobuf::Any; +using ::testing::Not; + +struct TestCase { + std::string name; + std::string expression; + absl::flat_hash_map> + activation; + test::ValueMatcher expected_matcher; + bool reference_resolver_enabled = false; +}; + +enum Options { kDefault, kExhaustive, kFoldConstants }; + +using ParamType = std::tuple; + +absl::StatusOr> CreateCompiler() { + google::protobuf::LinkMessageReflection(); + google::protobuf::LinkMessageReflection< + cel::expr::conformance::proto3::NestedTestAllTypes>(); + + CEL_ASSIGN_OR_RETURN( + std::unique_ptr b, + NewCompilerBuilder(google::protobuf::DescriptorPool::generated_pool())); + CEL_RETURN_IF_ERROR(b->AddLibrary(StandardCompilerLibrary())); + CEL_RETURN_IF_ERROR(b->AddLibrary(OptionalCompilerLibrary())); + b->GetCheckerBuilder().set_container("cel.expr.conformance.proto3"); + auto& cb = b->GetCheckerBuilder(); + CEL_RETURN_IF_ERROR(cb.AddVariable(MakeVariableDecl("bool_var", BoolType()))); + CEL_RETURN_IF_ERROR( + cb.AddVariable(MakeVariableDecl("string_var", StringType()))); + CEL_RETURN_IF_ERROR( + cb.AddVariable(MakeVariableDecl("condition", BoolType()))); + CEL_RETURN_IF_ERROR(cb.AddVariable(MakeVariableDecl( + "nested_test_all_types", MessageType(NestedTestAllTypes::descriptor())))); + + CEL_RETURN_IF_ERROR(cb.AddFunction( + MakeFunctionDecl("IsPrivate", MakeOverloadDecl("IsPrivate_string", + BoolType(), StringType())) + .value())); + CEL_RETURN_IF_ERROR(cb.AddFunction( + MakeFunctionDecl( + "net.IsPrivate", + MakeOverloadDecl("net_IsPrivate_string", BoolType(), StringType())) + .value())); + + return b->Build(); +} + +const Compiler& GetCompiler() { + static const Compiler* compiler = []() { + auto compiler = CreateCompiler(); + ABSL_QCHECK_OK(compiler.status()); + return compiler->release(); + }(); + return *compiler; +} + +std::string TestCaseName(const testing::TestParamInfo& param_info) { + const ParamType& param = param_info.param; + absl::string_view opt; + switch (std::get<1>(param)) { + case Options::kDefault: + opt = "default"; + break; + case Options::kExhaustive: + opt = "exhaustive"; + break; + case Options::kFoldConstants: + opt = "opt"; + break; + } + + return absl::StrCat(std::get<0>(param).name, "_", opt); +} + +bool IsPrivateIpv4Impl(const StringValue& addr) { + // Implementation for demonstration, this is simple but incomplete and + // brittle. + std::string buf; + return absl::StartsWith(addr.ToStringView(&buf), "192.168.") || + absl::StartsWith(addr.ToStringView(&buf), "10."); +} + +absl::StatusOr> ConfigureRuntimeImpl( + bool resolve_references, Options evaluation_options) { + RuntimeOptions options; + switch (evaluation_options) { + case Options::kDefault: + options.short_circuiting = true; + break; + case Options::kExhaustive: + options.short_circuiting = false; + break; + case Options::kFoldConstants: + options.enable_comprehension_list_append = true; + options.short_circuiting = true; + break; + } + options.enable_qualified_type_identifiers = resolve_references; + options.container = "cel.expr.conformance.proto3"; + CEL_ASSIGN_OR_RETURN(cel::RuntimeBuilder runtime_builder, + CreateStandardRuntimeBuilder( + google::protobuf::DescriptorPool::generated_pool(), options)); + if (resolve_references) { + CEL_RETURN_IF_ERROR(EnableReferenceResolver( + runtime_builder, ReferenceResolverEnabled::kAlways)); + CEL_RETURN_IF_ERROR(extensions::EnableOptionalTypes(runtime_builder)); + } + if (evaluation_options == Options::kFoldConstants) { + CEL_RETURN_IF_ERROR(extensions::EnableConstantFolding(runtime_builder)); + CEL_RETURN_IF_ERROR(extensions::EnableRegexPrecompilation(runtime_builder)); + } + + auto s = UnaryFunctionAdapter::Register( + "IsPrivate", false, &IsPrivateIpv4Impl, + runtime_builder.function_registry()); + CEL_RETURN_IF_ERROR(s); + s.Update(UnaryFunctionAdapter::Register( + "net.IsPrivate", false, &IsPrivateIpv4Impl, + runtime_builder.function_registry())); + CEL_RETURN_IF_ERROR(s); + + return std::move(runtime_builder).Build(); +} + +class EvaluatorMemorySafetyTest : public testing::TestWithParam { + public: + EvaluatorMemorySafetyTest() = default; + + protected: + const TestCase& GetTestCase() { return std::get<0>(GetParam()); } + + absl::StatusOr> ConfigureRuntime() { + return ConfigureRuntimeImpl(GetTestCase().reference_resolver_enabled, + std::get<1>(GetParam())); + } +}; + +void InitActivation(const TestCase& test_case, google::protobuf::Arena& arena, + Activation& activation) { + for (const auto& [key, value] : test_case.activation) { + if (absl::holds_alternative(value)) { + activation.InsertOrAssignValue(key, std::get(value)); + } else { + // Note: This assumes that the TestCase is valid for the given TEST. + // Changes to the activation map will invalidate the pointer to message + // that gets wrapped here. + activation.InsertOrAssignValue( + key, Value::WrapMessageUnsafe( + &std::get(value), + google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), &arena)); + } + } +} + +TEST_P(EvaluatorMemorySafetyTest, Basic) { + const auto& test_case = GetTestCase(); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, ConfigureRuntime()); + + ASSERT_OK_AND_ASSIGN(ValidationResult validation, + GetCompiler().Compile(test_case.expression)); + + ASSERT_TRUE(validation.IsValid()) << validation.FormatError(); + ASSERT_OK_AND_ASSIGN(auto ast, validation.ReleaseAst()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr program, + runtime->CreateProgram(std::move(ast))); + + Activation activation; + google::protobuf::Arena arena; + InitActivation(test_case, arena, activation); + absl::StatusOr got = program->Evaluate(&arena, activation); + + EXPECT_THAT(got, IsOkAndHolds(test_case.expected_matcher)); +} + +TEST_P(EvaluatorMemorySafetyTest, ProgramSafeAfterRuntimeDestroyed) { + const auto& test_case = GetTestCase(); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, ConfigureRuntime()); + + ASSERT_OK_AND_ASSIGN(ValidationResult validation, + GetCompiler().Compile(test_case.expression)); + + ASSERT_TRUE(validation.IsValid()) << validation.FormatError(); + ASSERT_OK_AND_ASSIGN(auto ast, validation.ReleaseAst()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr program, + runtime->CreateProgram(std::move(ast))); + + Activation activation; + google::protobuf::Arena arena; + InitActivation(test_case, arena, activation); + runtime.reset(); + absl::StatusOr got = program->Evaluate(&arena, activation); + EXPECT_THAT(got, IsOkAndHolds(test_case.expected_matcher)); +} + +// Helper for making an eternal string value without looking like a memory leak. +Value MakeStringValue(absl::string_view str) { + static absl::NoDestructor kArena; + return StringValue::Wrap(str, kArena.get()); +} + +NestedTestAllTypes MakeNestedTestAllTypes(absl::string_view textproto) { + NestedTestAllTypes msg; + ABSL_CHECK(google::protobuf::TextFormat::ParseFromString(textproto, &msg)); + return msg; +} + +MATCHER_P(ParsedProtoStructEquals, expected, "") { + const cel::StructValue& got = arg; + if (!got.IsParsedMessage()) { + return false; + } + auto& msg = got.GetParsedMessage(); + auto cmp = absl::WrapUnique(msg->New()); + if (!google::protobuf::TextFormat::ParseFromString(expected, cmp.get())) { + *result_listener << "Failed to parse expected proto"; + return false; + } + return google::protobuf::util::MessageDifferencer::Equals(*msg, *cmp); +} + +INSTANTIATE_TEST_SUITE_P( + Expression, EvaluatorMemorySafetyTest, + testing::Combine( + testing::ValuesIn(std::vector{ + { + "bool", + "(true && false) || bool_var || string_var == 'test_str'", + {{"bool_var", BoolValue(false)}, + {"string_var", MakeStringValue("test_str")}}, + test::BoolValueIs(true), + }, + { + "const_str", + "condition ? 'left_hand_string' : 'right_hand_string'", + {{"condition", BoolValue(false)}}, + test::StringValueIs("right_hand_string"), + }, + { + "long_const_string", + "condition ? 'left_hand_string' : " + "'long_right_hand_string_0123456789'", + {{"condition", BoolValue(false)}}, + test::StringValueIs("long_right_hand_string_0123456789"), + }, + {"optional_of_long_const_string", + "condition ? optional.of('lhs_short') : " + "optional.of('long_right_hand_string_0123456789')", + {{"condition", BoolValue(false)}}, + test::OptionalValueIs( + test::StringValueIs("long_right_hand_string_0123456789")), + // optional.of is a namespaced function. + /*enable_reference_resolver=*/true}, + { + "computed_string", + "(condition ? 'a.b' : 'b.c') + '.d.e.f'", + {{"condition", BoolValue(false)}}, + test::StringValueIs("b.c.d.e.f"), + }, + { + "regex", + R"('192.168.128.64'.matches(r'^192\.168\.[0-2]?[0-9]?[0-9]\.[0-2]?[0-9]?[0-9]') )", + {}, + test::BoolValueIs(true), + }, + { + "list_create", + "[1, 2, 3, 4, 5, 6][3] == 4", + {}, + test::BoolValueIs(true), + }, + { + "list_create_strings", + "['1', '2', '3', '4', '5', '6'][2] == '3'", + {}, + test::BoolValueIs(true), + }, + { + "map_create", + "{'1': 'one', '2': 'two'}['2']", + {}, + test::StringValueIs("two"), + }, + { + "struct_create", + R"cel( + NestedTestAllTypes{ + child: NestedTestAllTypes{ + payload: TestAllTypes{ + repeated_int32: [1, 2, 3] + } + }, + payload: TestAllTypes{ + repeated_string: ["foo", "bar", "baz"] + } + })cel", + {}, + test::StructValueIs(ParsedProtoStructEquals(R"pb( + child { payload { repeated_int32: [ 1, 2, 3 ] } } + payload { repeated_string: [ "foo", "bar", "baz" ] } + )pb")), + }, + {"extension_function", + "IsPrivate('8.8.8.8')", + {}, + test::BoolValueIs(false), + /*enable_reference_resolver=*/false}, + {"namespaced_function", + "net.IsPrivate('192.168.0.1')", + {}, + test::BoolValueIs(true), + /*enable_reference_resolver=*/true}, + { + "comprehension", + "['abc', 'def', 'ghi', 'jkl'].exists(el, el == 'mno')", + {}, + test::BoolValueIs(false), + }, + { + "comprehension_complex", + "['a' + 'b' + 'c', 'd' + 'ef', 'g' + 'hi', 'j' + 'kl']" + ".exists(el, el.startsWith('g'))", + {}, + test::BoolValueIs(true), + }, + TestCase{ + "unsafe_message_access", + "nested_test_all_types.child.payload", + {{"nested_test_all_types", + MakeNestedTestAllTypes(R"pb(child { + payload { single_int32: 1 } + })pb")}}, + test::StructValueIs( + ParsedProtoStructEquals(R"pb(single_int32: 1)pb")), + }, + TestCase{ + "unsafe_message_access_repeated_field", + "nested_test_all_types.payload.repeated_int32.size() == 3", + {{"nested_test_all_types", + MakeNestedTestAllTypes(R"pb(payload { + repeated_int32: 1 + repeated_int32: 2 + repeated_int32: 3 + })pb")}}, + test::BoolValueIs(true), + }, + TestCase{ + "unsafe_message_access_repeated_field_index", + "nested_test_all_types.payload.repeated_int32[1] == 2", + {{"nested_test_all_types", + MakeNestedTestAllTypes(R"pb(payload { + repeated_int32: 1 + repeated_int32: 2 + repeated_int32: 3 + })pb")}}, + test::BoolValueIs(true), + }, + TestCase{ + "unsafe_message_access_map_field", + "nested_test_all_types.payload.map_int32_string.size() == 2", + {{"nested_test_all_types", + MakeNestedTestAllTypes( + R"pb(payload { + map_int32_string { key: 1 value: "foo" } + map_int32_string { key: 2 value: "bar" } + })pb")}}, + test::BoolValueIs(true), + }, + TestCase{ + "unsafe_message_access_map_field_index", + "nested_test_all_types.payload.map_int32_string[1] == 'foo'", + {{"nested_test_all_types", + MakeNestedTestAllTypes( + R"pb(payload { + map_int32_string { key: 1 value: "foo" } + map_int32_string { key: 2 value: "bar" } + })pb")}}, + test::BoolValueIs(true), + }, + TestCase{ + "unsafe_message_access_string_field", + "nested_test_all_types.payload.single_string == 'foo'", + {{"nested_test_all_types", MakeNestedTestAllTypes( + R"pb(payload { + single_string: "foo" + })pb")}}, + test::BoolValueIs(true), + }, + TestCase{ + "unsafe_message_access_assign", + "NestedTestAllTypes{payload: " + "nested_test_all_types.child.payload}", + {{"nested_test_all_types", + MakeNestedTestAllTypes(R"pb(child { + payload { single_int32: 1 } + })pb")}}, + test::StructValueIs(ParsedProtoStructEquals(R"pb(payload { + single_int32: + 1 + })pb")), + }, + TestCase{ + "unsafe_message_access_assign_repeated_field", + "TestAllTypes{repeated_int32: " + "nested_test_all_types.payload.repeated_int32}", + {{"nested_test_all_types", MakeNestedTestAllTypes(R"pb( + payload { repeated_int32: [ 1, 2, 3 ] } + )pb")}}, + test::StructValueIs(ParsedProtoStructEquals( + R"pb(repeated_int32: [ 1, 2, 3 ])pb")), + }, + TestCase{ + "unsafe_message_access_assign_map_field", + "TestAllTypes{map_int32_string: " + "nested_test_all_types.payload.map_int32_string}", + {{"nested_test_all_types", MakeNestedTestAllTypes(R"pb( + payload { + map_int32_string { key: 1 value: "foo" } + map_int32_string { key: 2 value: "bar" } + } + )pb")}}, + test::StructValueIs(ParsedProtoStructEquals( + R"pb(map_int32_string { key: 1 value: "foo" } + map_int32_string { key: 2 value: "bar" })pb")), + }, + TestCase{ + "unsafe_message_access_assign_string_field", + "TestAllTypes{single_string: " + "nested_test_all_types.payload.single_string}", + {{"nested_test_all_types", MakeNestedTestAllTypes(R"pb( + payload { + single_string: 'foo is a long string that is not inlined abcdef' + } + )pb")}}, + test::StructValueIs(ParsedProtoStructEquals( + R"pb(single_string: 'foo is a long string that is not inlined abcdef')pb")), + }, + TestCase{ + "unsafe_message_access_assign_bytes_field", + "TestAllTypes{single_bytes: " + "nested_test_all_types.payload.single_bytes}", + {{"nested_test_all_types", MakeNestedTestAllTypes(R"pb( + payload { + single_bytes: 'foo is a long string that is not inlined abcdef' + } + )pb")}}, + test::StructValueIs(ParsedProtoStructEquals( + R"pb(single_bytes: 'foo is a long string that is not inlined abcdef')pb")), + }, + TestCase{ + "unsafe_message_access_assign_from_repeated_string_field", + "TestAllTypes{single_string: " + "nested_test_all_types.payload.repeated_string[0]}", + {{"nested_test_all_types", MakeNestedTestAllTypes(R"pb( + payload { + repeated_string: 'foo is a long string that is not inlined abcdef' + } + )pb")}}, + test::StructValueIs(ParsedProtoStructEquals( + R"pb(single_string: 'foo is a long string that is not inlined abcdef')pb")), + }, + TestCase{ + "unsafe_message_access_assign_from_map_string_field", + "TestAllTypes{single_string: " + "nested_test_all_types.payload.map_int32_string[1]}", + {{"nested_test_all_types", MakeNestedTestAllTypes(R"pb( + payload { + map_int32_string { + key: 1 + value: "foo is a long string that is not inlined abcdef" + } + } + )pb")}}, + test::StructValueIs(ParsedProtoStructEquals( + R"pb(single_string: "foo is a long string that is not inlined abcdef")pb")), + }, + }), + testing::Values(Options::kDefault, Options::kExhaustive, + Options::kFoldConstants)), + &TestCaseName); + +MATCHER_P(IsSameInstance, expected, "") { + return std::mem_fn(&ParsedMessageValue::operator->)(&arg) == expected; +} + +// Returns true if the string value is backed by the same instance as the +// expected string. Note: this only applies for string values that are too big +// to be inlined in the StringValue and not represented as a absl::Cord. +MATCHER_P(IsSameStringInstance, expected, "") { + const StringValue& got = arg; + std::string buf; + absl::string_view got_view = got.ToStringView(&buf); + bool result = + got_view.data() == expected.data() && got_view.size() == expected.size(); + if (!result) { + *result_listener << absl::StrFormat("got: %p, wanted: %p", got_view.data(), + expected.data()); + } + return result; +} + +class ViewTypesMemorySafetyTest : public testing::TestWithParam { + protected: + Options EvaluationOptions() { return GetParam(); } +}; + +// Test cases demonstrating how inputs as views are handled. +TEST_P(ViewTypesMemorySafetyTest, WrappedMessage) { + // Arrange: create the runtime and expression. + ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, + ConfigureRuntimeImpl(false, EvaluationOptions())); + constexpr absl::string_view kProtoValue = R"pb( + child { payload { repeated_int32: [ 1, 2, 3 ] } } + payload { repeated_string: [ "foo", "bar", "baz" ] } + )pb"; + + ASSERT_OK_AND_ASSIGN( + ValidationResult validation, + GetCompiler().Compile( + "condition ? nested_test_all_types : NestedTestAllTypes{}")); + ASSERT_TRUE(validation.IsValid()) << validation.FormatError(); + ASSERT_OK_AND_ASSIGN(auto ast, validation.ReleaseAst()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr program, + runtime->CreateProgram(std::move(ast))); + + // Act: wrap the message and evaluate the expression. + google::protobuf::Arena arena; + NestedTestAllTypes* proto = + NestedTestAllTypes::default_instance().New(&arena); + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kProtoValue, proto)); + Activation activation; + activation.InsertOrAssignValue("condition", BoolValue(true)); + activation.InsertOrAssignValue( + "nested_test_all_types", + Value::WrapMessage(proto, google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), &arena)); + ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation)); + + // Assert: the result is the input message. + ASSERT_TRUE(result.IsParsedMessage()); + const ParsedMessageValue& result_msg = result.GetParsedMessage(); + EXPECT_THAT(result_msg, + test::StructValueIs(ParsedProtoStructEquals(kProtoValue))); + EXPECT_EQ(result_msg->GetArena(), &arena); + EXPECT_THAT(result_msg, IsSameInstance(proto)); +} + +// Test cases demonstrating how inputs as views are handled. +TEST_P(ViewTypesMemorySafetyTest, WrappedMessageFields) { + // Arrange: create the runtime and expression. + ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, + ConfigureRuntimeImpl(false, EvaluationOptions())); + constexpr absl::string_view kProtoValue = R"pb( + child { payload { repeated_int32: [ 1, 2, 3 ] } } + payload { repeated_string: [ "foo", "bar", "baz" ] } + )pb"; + ASSERT_OK_AND_ASSIGN( + ValidationResult validation, + GetCompiler().Compile("nested_test_all_types.child.payload")); + ASSERT_TRUE(validation.IsValid()) << validation.FormatError(); + ASSERT_OK_AND_ASSIGN(auto ast, validation.ReleaseAst()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr program, + runtime->CreateProgram(std::move(ast))); + + // Act: wrap the message and evaluate the expression. + google::protobuf::Arena arena; + NestedTestAllTypes* proto = + NestedTestAllTypes::default_instance().New(&arena); + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kProtoValue, proto)); + Activation activation; + activation.InsertOrAssignValue("condition", BoolValue(true)); + activation.InsertOrAssignValue( + "nested_test_all_types", + Value::WrapMessage(proto, google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), &arena)); + ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation)); + + // Assert: the result is an alias of a sub-message in the input. + ASSERT_TRUE(result.IsParsedMessage()); + const ParsedMessageValue& result_msg = result.GetParsedMessage(); + EXPECT_THAT(result_msg, test::StructValueIs(ParsedProtoStructEquals( + "repeated_int32: [ 1, 2, 3 ]"))); + EXPECT_EQ(result_msg->GetArena(), &arena); + EXPECT_THAT(result_msg, IsSameInstance(&(proto->child().payload()))); +} + +TEST_P(ViewTypesMemorySafetyTest, WrappedMessageDifferentArena) { + // Arrange: create the runtime and expression. + ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, + ConfigureRuntimeImpl(false, EvaluationOptions())); + constexpr absl::string_view kProtoValue = R"pb( + child { payload { repeated_int32: [ 1, 2, 3 ] } } + payload { repeated_string: [ "foo", "bar", "baz" ] } + )pb"; + + ASSERT_OK_AND_ASSIGN( + ValidationResult validation, + GetCompiler().Compile( + "condition ? nested_test_all_types : NestedTestAllTypes{}")); + ASSERT_TRUE(validation.IsValid()) << validation.FormatError(); + ASSERT_OK_AND_ASSIGN(auto ast, validation.ReleaseAst()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr program, + runtime->CreateProgram(std::move(ast))); + + // Act: wrap the message and evaluate the expression. + google::protobuf::Arena arena; + google::protobuf::Arena other_arena; + NestedTestAllTypes* proto = + NestedTestAllTypes::default_instance().New(&other_arena); + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kProtoValue, proto)); + Activation activation; + activation.InsertOrAssignValue("condition", BoolValue(true)); + activation.InsertOrAssignValue( + "nested_test_all_types", + Value::WrapMessage(proto, google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), &arena)); + ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation)); + + // Assert: the result is a copy of the input message. + ASSERT_TRUE(result.IsParsedMessage()); + const ParsedMessageValue& result_msg = result.GetParsedMessage(); + EXPECT_THAT(result_msg, + test::StructValueIs(ParsedProtoStructEquals(kProtoValue))); + EXPECT_EQ(result_msg->GetArena(), &arena); + EXPECT_THAT(result_msg, Not(IsSameInstance(proto))); +} + +TEST_P(ViewTypesMemorySafetyTest, WrappedMessageFromAny) { + // Arrange: create the runtime. + ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, + ConfigureRuntimeImpl(false, EvaluationOptions())); + constexpr absl::string_view kProtoValue = R"pb( + child { payload { repeated_int32: [ 1, 2, 3 ] } } + payload { repeated_string: [ "foo", "bar", "baz" ] } + )pb"; + + ASSERT_OK_AND_ASSIGN( + ValidationResult validation, + GetCompiler().Compile( + "condition ? nested_test_all_types : NestedTestAllTypes{}")); + ASSERT_TRUE(validation.IsValid()) << validation.FormatError(); + ASSERT_OK_AND_ASSIGN(auto ast, validation.ReleaseAst()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr program, + runtime->CreateProgram(std::move(ast))); + + // Act: wrap the message and evaluate the expression. + google::protobuf::Arena arena; + NestedTestAllTypes proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kProtoValue, &proto)); + Any any; + any.PackFrom(proto); + Activation activation; + activation.InsertOrAssignValue("condition", BoolValue(true)); + activation.InsertOrAssignValue( + "nested_test_all_types", + Value::WrapMessage(&any, google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), &arena)); + + // Assert + ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation)); + ASSERT_TRUE(result.IsParsedMessage()); + const ParsedMessageValue& result_msg = result.GetParsedMessage(); + EXPECT_THAT(result_msg, + test::StructValueIs(ParsedProtoStructEquals(kProtoValue))); + EXPECT_EQ(result_msg->GetArena(), &arena); +} + +TEST_P(ViewTypesMemorySafetyTest, UnsafeWrappedMessageDifferentArena) { + // Arrange: create the runtime and expression. + ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, + ConfigureRuntimeImpl(false, EvaluationOptions())); + constexpr absl::string_view kProtoValue = R"pb( + child { payload { repeated_int32: [ 1, 2, 3 ] } } + payload { repeated_string: [ "foo", "bar", "baz" ] } + )pb"; + + ASSERT_OK_AND_ASSIGN( + ValidationResult validation, + GetCompiler().Compile( + "condition ? nested_test_all_types : NestedTestAllTypes{}")); + ASSERT_TRUE(validation.IsValid()) << validation.FormatError(); + ASSERT_OK_AND_ASSIGN(auto ast, validation.ReleaseAst()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr program, + runtime->CreateProgram(std::move(ast))); + + // Act: wrap the message and evaluate the expression. + // The unsafe version will alias the input message, so caller must ensure + // the input outlives the use of the `Value` rather than assuming it + // is managed by the evaluation arena. + google::protobuf::Arena arena; + NestedTestAllTypes proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kProtoValue, &proto)); + Activation activation; + activation.InsertOrAssignValue("condition", BoolValue(true)); + activation.InsertOrAssignValue( + "nested_test_all_types", + Value::WrapMessageUnsafe(&proto, google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), + &arena)); + ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation)); + + // Assert: the result is an alias of the input message. + ASSERT_TRUE(result.IsParsedMessage()); + const ParsedMessageValue& result_msg = result.GetParsedMessage(); + EXPECT_THAT(result_msg, + test::StructValueIs(ParsedProtoStructEquals(kProtoValue))); + EXPECT_EQ(result_msg->GetArena(), nullptr); + EXPECT_THAT(result_msg, IsSameInstance(&proto)); +} + +TEST_P(ViewTypesMemorySafetyTest, UnsafeWrappedMessageFields) { + // Arrange: create the runtime and expression. + ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, + ConfigureRuntimeImpl(false, EvaluationOptions())); + constexpr absl::string_view kProtoValue = R"pb( + child { payload { repeated_int32: [ 1, 2, 3 ] } } + payload { repeated_string: [ "foo", "bar", "baz" ] } + )pb"; + ASSERT_OK_AND_ASSIGN( + ValidationResult validation, + GetCompiler().Compile("nested_test_all_types.child.payload")); + ASSERT_TRUE(validation.IsValid()) << validation.FormatError(); + ASSERT_OK_AND_ASSIGN(auto ast, validation.ReleaseAst()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr program, + runtime->CreateProgram(std::move(ast))); + + // Act: wrap the message and evaluate the expression. + google::protobuf::Arena arena; + NestedTestAllTypes proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kProtoValue, &proto)); + Activation activation; + activation.InsertOrAssignValue("condition", BoolValue(true)); + activation.InsertOrAssignValue( + "nested_test_all_types", + Value::WrapMessageUnsafe(&proto, google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), + &arena)); + ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation)); + + // Assert: the result is an alias of a sub-message in the input. + ASSERT_TRUE(result.IsParsedMessage()); + const ParsedMessageValue& result_msg = result.GetParsedMessage(); + EXPECT_THAT(result_msg, test::StructValueIs(ParsedProtoStructEquals( + "repeated_int32: [ 1, 2, 3 ]"))); + EXPECT_EQ(result_msg->GetArena(), nullptr); + EXPECT_THAT(result_msg, IsSameInstance(&(proto.child().payload()))); +} + +TEST_P(ViewTypesMemorySafetyTest, UnsafeWrappedMessageRepeatedField) { + // Arrange: create the runtime and expression. + ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, + ConfigureRuntimeImpl(false, EvaluationOptions())); + constexpr absl::string_view kProtoValue = R"pb( + payload { repeated_nested_message: { bb: 42 } } + )pb"; + ASSERT_OK_AND_ASSIGN( + ValidationResult validation, + GetCompiler().Compile( + "nested_test_all_types.payload.repeated_nested_message[0]")); + ASSERT_TRUE(validation.IsValid()) << validation.FormatError(); + ASSERT_OK_AND_ASSIGN(auto ast, validation.ReleaseAst()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr program, + runtime->CreateProgram(std::move(ast))); + + // Act: wrap the message and evaluate the expression. + google::protobuf::Arena arena; + NestedTestAllTypes proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kProtoValue, &proto)); + Activation activation; + activation.InsertOrAssignValue( + "nested_test_all_types", + Value::WrapMessageUnsafe(&proto, google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), + &arena)); + ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation)); + + // Assert: the result is an alias of a sub-message in the input. + ASSERT_TRUE(result.IsParsedMessage()); + const ParsedMessageValue& result_msg = result.GetParsedMessage(); + EXPECT_THAT(result_msg, + test::StructValueIs(ParsedProtoStructEquals("bb: 42"))); + EXPECT_EQ(result_msg->GetArena(), nullptr); + EXPECT_THAT(result_msg, + IsSameInstance(&(proto.payload().repeated_nested_message(0)))); +} + +TEST_P(ViewTypesMemorySafetyTest, UnsafeWrappedMessageMapField) { + // Arrange: create the runtime and expression. + ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, + ConfigureRuntimeImpl(false, EvaluationOptions())); + ASSERT_OK_AND_ASSIGN( + ValidationResult validation, + GetCompiler().Compile( + "nested_test_all_types.payload.map_string_message['foo']")); + ASSERT_TRUE(validation.IsValid()) << validation.FormatError(); + ASSERT_OK_AND_ASSIGN(auto ast, validation.ReleaseAst()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr program, + runtime->CreateProgram(std::move(ast))); + + // Act: wrap the message and evaluate the expression. + google::protobuf::Arena arena; + NestedTestAllTypes proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb( + payload { + map_string_message: { + key: "foo" + value: { bb: 42 } + } + map_string_message: { + key: "baz" + value: { bb: 43 } + } + })pb", + &proto)); + Activation activation; + activation.InsertOrAssignValue( + "nested_test_all_types", + Value::WrapMessageUnsafe(&proto, google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), + &arena)); + ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation)); + + // Assert: the result is an alias of a sub-message in the input. + ASSERT_TRUE(result.IsParsedMessage()); + const ParsedMessageValue& result_msg = result.GetParsedMessage(); + EXPECT_THAT(result_msg, + test::StructValueIs(ParsedProtoStructEquals(R"pb(bb: 42)pb"))); + EXPECT_THAT( + result_msg, + IsSameInstance(&(proto.payload().map_string_message().at("foo")))); +} + +TEST_P(ViewTypesMemorySafetyTest, UnsafeWrappedMessageStringFields) { + // Arrange: create the runtime and expression. + ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, + ConfigureRuntimeImpl(false, EvaluationOptions())); + constexpr absl::string_view kProtoValue = R"pb( + child { payload { single_string: "foo that is too big to be inlined..." } } + )pb"; + ASSERT_OK_AND_ASSIGN( + ValidationResult validation, + GetCompiler().Compile( + "nested_test_all_types.child.payload.single_string")); + ASSERT_TRUE(validation.IsValid()) << validation.FormatError(); + ASSERT_OK_AND_ASSIGN(auto ast, validation.ReleaseAst()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr program, + runtime->CreateProgram(std::move(ast))); + + // Act: wrap the message and evaluate the expression. + google::protobuf::Arena arena; + NestedTestAllTypes proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kProtoValue, &proto)); + Activation activation; + activation.InsertOrAssignValue( + "nested_test_all_types", + Value::WrapMessageUnsafe(&proto, google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), + &arena)); + ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation)); + + // Assert: the result is an alias of a sub-message in the input. + ASSERT_TRUE(result.IsString()); + const StringValue& result_string = result.GetString(); + EXPECT_THAT(result_string, + StringValueIs("foo that is too big to be inlined...")); + EXPECT_THAT(result_string, IsSameStringInstance(absl::string_view( + proto.child().payload().single_string()))); +} + +TEST_P(ViewTypesMemorySafetyTest, UnsafeWrappedMessageRepeatedStringField) { + // Arrange: create the runtime and expression. + ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, + ConfigureRuntimeImpl(false, EvaluationOptions())); + constexpr absl::string_view kProtoValue = R"pb( + payload { repeated_string: "foo that is too big to be inlined..." } + )pb"; + ASSERT_OK_AND_ASSIGN(ValidationResult validation, + GetCompiler().Compile( + "nested_test_all_types.payload.repeated_string[0]")); + ASSERT_TRUE(validation.IsValid()) << validation.FormatError(); + ASSERT_OK_AND_ASSIGN(auto ast, validation.ReleaseAst()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr program, + runtime->CreateProgram(std::move(ast))); + + // Act: wrap the message and evaluate the expression. + google::protobuf::Arena arena; + NestedTestAllTypes proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kProtoValue, &proto)); + Activation activation; + activation.InsertOrAssignValue( + "nested_test_all_types", + Value::WrapMessageUnsafe(&proto, google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), + &arena)); + ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation)); + + // Assert: the result is an alias of a sub-message in the input. + ASSERT_TRUE(result.IsString()); + const StringValue& result_string = result.GetString(); + EXPECT_THAT(result_string, + StringValueIs("foo that is too big to be inlined...")); + EXPECT_THAT(result_string, IsSameStringInstance(absl::string_view( + proto.payload().repeated_string(0)))); +} + +TEST_P(ViewTypesMemorySafetyTest, UnsafeWrappedMessageMapStringField) { + // Arrange: create the runtime and expression. + ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, + ConfigureRuntimeImpl(false, EvaluationOptions())); + constexpr absl::string_view kProtoValue = R"pb( + payload { + map_string_string: { + key: "foo" + value: "bar that is too big to be inlined..." + } + })pb"; + ASSERT_OK_AND_ASSIGN( + ValidationResult validation, + GetCompiler().Compile( + "nested_test_all_types.payload.map_string_string['foo']")); + ASSERT_TRUE(validation.IsValid()) << validation.FormatError(); + ASSERT_OK_AND_ASSIGN(auto ast, validation.ReleaseAst()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr program, + runtime->CreateProgram(std::move(ast))); + + // Act: wrap the message and evaluate the expression. + google::protobuf::Arena arena; + NestedTestAllTypes proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kProtoValue, &proto)); + Activation activation; + activation.InsertOrAssignValue( + "nested_test_all_types", + Value::WrapMessageUnsafe(&proto, google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), + &arena)); + ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation)); + + // Assert: the result is an alias of a sub-message in the input. + ASSERT_TRUE(result.IsString()); + const StringValue& result_string = result.GetString(); + EXPECT_THAT(result_string, + StringValueIs("bar that is too big to be inlined...")); + EXPECT_THAT(result_string, + IsSameStringInstance(absl::string_view( + proto.payload().map_string_string().at("foo")))); +} + +TEST_P(ViewTypesMemorySafetyTest, UnsafeWrappedMessageStringFieldAssign) { + // Arrange: create the runtime and expression. + ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, + ConfigureRuntimeImpl(false, EvaluationOptions())); + ASSERT_OK_AND_ASSIGN( + ValidationResult validation, + GetCompiler().Compile( + "TestAllTypes{single_string: " + "nested_test_all_types.child.payload.single_string}.single_string")); + ASSERT_TRUE(validation.IsValid()) << validation.FormatError(); + ASSERT_OK_AND_ASSIGN(auto ast, validation.ReleaseAst()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr program, + runtime->CreateProgram(std::move(ast))); + + // Act: wrap the message and evaluate the expression. + google::protobuf::Arena arena; + NestedTestAllTypes proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb( + child { + payload { single_string: "foo that is too big to be inlined..." } + })pb", + &proto)); + Activation activation; + activation.InsertOrAssignValue( + "nested_test_all_types", + Value::WrapMessageUnsafe(&proto, google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), + &arena)); + ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation)); + + // Assert: check that the result is not tied to the alias. + // This is not a safe assumption generally, but making sure that the runtime + // is making a defensive copy when building a message assumed to be on the + // arena. Callers cannot safely assume this for arbitrary expressions. + proto.Clear(); + ASSERT_TRUE(result.IsString()); + const StringValue& result_string = result.GetString(); + EXPECT_THAT(result_string, + StringValueIs("foo that is too big to be inlined...")); + EXPECT_THAT(result_string, Not(IsSameStringInstance(absl::string_view( + proto.child().payload().single_string())))); +} + +INSTANTIATE_TEST_SUITE_P(Cases, ViewTypesMemorySafetyTest, + testing::Values(Options::kDefault, + Options::kExhaustive, + Options::kFoldConstants), + [](const testing::TestParamInfo& info) { + switch (info.param) { + case Options::kDefault: + return "default"; + case Options::kExhaustive: + return "exhaustive"; + case Options::kFoldConstants: + return "opt"; + } + }); + +} // namespace +} // namespace cel diff --git a/runtime/optional_types.cc b/runtime/optional_types.cc new file mode 100644 index 000000000..6678a05ed --- /dev/null +++ b/runtime/optional_types.cc @@ -0,0 +1,387 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "runtime/optional_types.h" + +#include +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "base/function_adapter.h" +#include "common/casting.h" +#include "common/type.h" +#include "common/value.h" +#include "internal/casts.h" +#include "internal/number.h" +#include "internal/status_macros.h" +#include "runtime/function_registry.h" +#include "runtime/internal/errors.h" +#include "runtime/internal/runtime_friend_access.h" +#include "runtime/internal/runtime_impl.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel::extensions { + +namespace { + +Value OptionalOf(const Value& value, const google::protobuf::DescriptorPool* absl_nonnull, + google::protobuf::MessageFactory* absl_nonnull, + google::protobuf::Arena* absl_nonnull arena) { + return OptionalValue::Of(value, arena); +} + +Value OptionalNone() { return OptionalValue::None(); } + +Value OptionalOfNonZeroValue( + const Value& value, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + if (value.IsZeroValue()) { + return OptionalNone(); + } + return OptionalOf(value, descriptor_pool, message_factory, arena); +} + +absl::StatusOr OptionalGetValue(const OpaqueValue& opaque_value) { + if (auto optional_value = opaque_value.AsOptional(); optional_value) { + return optional_value->Value(); + } + return ErrorValue{runtime_internal::CreateNoMatchingOverloadError("value")}; +} + +absl::StatusOr OptionalHasValue(const OpaqueValue& opaque_value) { + if (auto optional_value = opaque_value.AsOptional(); optional_value) { + return BoolValue{optional_value->HasValue()}; + } + return ErrorValue{ + runtime_internal::CreateNoMatchingOverloadError("hasValue")}; +} + +absl::StatusOr SelectOptionalFieldStruct( + const StructValue& struct_value, const StringValue& key, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + std::string field_name; + auto field_name_view = key.NativeString(field_name); + CEL_ASSIGN_OR_RETURN(auto has_field, + struct_value.HasFieldByName(field_name_view)); + if (!has_field) { + return OptionalValue::None(); + } + CEL_ASSIGN_OR_RETURN( + auto field, struct_value.GetFieldByName(field_name_view, descriptor_pool, + message_factory, arena)); + return OptionalValue::Of(std::move(field), arena); +} + +absl::StatusOr SelectOptionalFieldMap( + const MapValue& map, const StringValue& key, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + absl::optional value; + CEL_ASSIGN_OR_RETURN(value, + map.Find(key, descriptor_pool, message_factory, arena)); + if (value) { + return OptionalValue::Of(std::move(*value), arena); + } + return OptionalValue::None(); +} + +absl::StatusOr SelectOptionalField( + const OpaqueValue& opaque_value, const StringValue& key, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + if (auto optional_value = opaque_value.AsOptional(); optional_value) { + if (!optional_value->HasValue()) { + return OptionalValue::None(); + } + auto container = optional_value->Value(); + if (auto map_value = container.AsMap(); map_value) { + return SelectOptionalFieldMap(*map_value, key, descriptor_pool, + message_factory, arena); + } + if (auto struct_value = container.AsStruct(); struct_value) { + return SelectOptionalFieldStruct(*struct_value, key, descriptor_pool, + message_factory, arena); + } + } + return ErrorValue{runtime_internal::CreateNoMatchingOverloadError("_[?_]")}; +} + +absl::StatusOr MapOptIndexOptionalValue( + const MapValue& map, const Value& key, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + absl::optional value; + if (auto double_key = cel::As(key); double_key) { + // Try int/uint. + auto number = internal::Number::FromDouble(double_key->NativeValue()); + if (number.LosslessConvertibleToInt()) { + CEL_ASSIGN_OR_RETURN(value, + map.Find(IntValue{number.AsInt()}, descriptor_pool, + message_factory, arena)); + if (value) { + return OptionalValue::Of(std::move(*value), arena); + } + } + if (number.LosslessConvertibleToUint()) { + CEL_ASSIGN_OR_RETURN(value, + map.Find(UintValue{number.AsUint()}, descriptor_pool, + message_factory, arena)); + if (value) { + return OptionalValue::Of(std::move(*value), arena); + } + } + } else { + CEL_ASSIGN_OR_RETURN( + value, map.Find(key, descriptor_pool, message_factory, arena)); + if (value) { + return OptionalValue::Of(std::move(*value), arena); + } + if (auto int_key = key.AsInt(); int_key && int_key->NativeValue() >= 0) { + CEL_ASSIGN_OR_RETURN( + value, + map.Find(UintValue{static_cast(int_key->NativeValue())}, + descriptor_pool, message_factory, arena)); + if (value) { + return OptionalValue::Of(std::move(*value), arena); + } + } else if (auto uint_key = key.AsUint(); + uint_key && + uint_key->NativeValue() <= + static_cast(std::numeric_limits::max())) { + CEL_ASSIGN_OR_RETURN( + value, + map.Find(IntValue{static_cast(uint_key->NativeValue())}, + descriptor_pool, message_factory, arena)); + if (value) { + return OptionalValue::Of(std::move(*value), arena); + } + } + } + return OptionalValue::None(); +} + +absl::StatusOr ListOptIndexOptionalInt( + const ListValue& list, int64_t key, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + CEL_ASSIGN_OR_RETURN(auto list_size, list.Size()); + if (key < 0 || static_cast(key) >= list_size) { + return OptionalValue::None(); + } + CEL_ASSIGN_OR_RETURN(auto element, + list.Get(static_cast(key), descriptor_pool, + message_factory, arena)); + return OptionalValue::Of(std::move(element), arena); +} + +absl::StatusOr OptionalOptIndexOptionalValue( + const OpaqueValue& opaque_value, const Value& key, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + if (auto optional_value = As(opaque_value); optional_value) { + if (!optional_value->HasValue()) { + return OptionalValue::None(); + } + auto container = optional_value->Value(); + if (auto map_value = cel::As(container); map_value) { + return MapOptIndexOptionalValue(*map_value, key, descriptor_pool, + message_factory, arena); + } + if (auto list_value = cel::As(container); list_value) { + if (auto int_value = cel::As(key); int_value) { + return ListOptIndexOptionalInt(*list_value, int_value->NativeValue(), + descriptor_pool, message_factory, arena); + } + } + } + return ErrorValue{runtime_internal::CreateNoMatchingOverloadError("_[?_]")}; +} + +absl::StatusOr ListFirst(const cel::ListValue& list, + const google::protobuf::DescriptorPool* descriptor_pool, + google::protobuf::MessageFactory* message_factory, + google::protobuf::Arena* arena) { + CEL_ASSIGN_OR_RETURN(size_t size, list.Size()); + if (size == 0) { + return Value(OptionalValue::None()); + } + CEL_ASSIGN_OR_RETURN(Value value, + list.Get(0, descriptor_pool, message_factory, arena)); + return Value(OptionalValue::Of(std::move(value), arena)); +} + +absl::StatusOr ListLast(const cel::ListValue& list, + const google::protobuf::DescriptorPool* descriptor_pool, + google::protobuf::MessageFactory* message_factory, + google::protobuf::Arena* arena) { + CEL_ASSIGN_OR_RETURN(size_t size, list.Size()); + if (size == 0) { + return Value(OptionalValue::None()); + } + CEL_ASSIGN_OR_RETURN(Value value, + list.Get(static_cast(size) - 1, descriptor_pool, + message_factory, arena)); + return Value(OptionalValue::Of(std::move(value), arena)); +} + +absl::StatusOr ListUnwrapOpt( + const ListValue& list, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + auto builder = NewListValueBuilder(arena); + CEL_ASSIGN_OR_RETURN(auto list_size, list.Size()); + builder->Reserve(list_size); + + absl::Status status = list.ForEach( + [&](const Value& value) -> absl::StatusOr { + if (auto optional_value = value.AsOptional(); optional_value) { + if (optional_value->HasValue()) { + CEL_RETURN_IF_ERROR(builder->Add(optional_value->Value())); + } + } else { + return absl::InvalidArgumentError(absl::StrFormat( + "optional.unwrap() expected a list(optional(T)), but %s " + "was found in the list.", + value.GetTypeName())); + } + return true; + }, + descriptor_pool, message_factory, arena); + if (!status.ok()) { + return ErrorValue(status); + } + return std::move(*builder).Build(); +} + +absl::Status RegisterOptionalTypeFunctions(FunctionRegistry& registry, + const RuntimeOptions& options) { + if (!options.enable_qualified_type_identifiers) { + return absl::FailedPreconditionError( + "optional_type requires " + "RuntimeOptions.enable_qualified_type_identifiers"); + } + if (!options.enable_heterogeneous_equality) { + return absl::FailedPreconditionError( + "optional_type requires RuntimeOptions.enable_heterogeneous_equality"); + } + CEL_RETURN_IF_ERROR(registry.Register( + UnaryFunctionAdapter::CreateDescriptor("optional.of", + false), + UnaryFunctionAdapter::WrapFunction(&OptionalOf))); + CEL_RETURN_IF_ERROR( + registry.Register(UnaryFunctionAdapter::CreateDescriptor( + "optional.ofNonZeroValue", false), + UnaryFunctionAdapter::WrapFunction( + &OptionalOfNonZeroValue))); + CEL_RETURN_IF_ERROR(registry.Register( + NullaryFunctionAdapter::CreateDescriptor("optional.none", false), + NullaryFunctionAdapter::WrapFunction(&OptionalNone))); + CEL_RETURN_IF_ERROR(registry.Register( + UnaryFunctionAdapter, + OpaqueValue>::CreateDescriptor("value", true), + UnaryFunctionAdapter, OpaqueValue>::WrapFunction( + &OptionalGetValue))); + CEL_RETURN_IF_ERROR(registry.Register( + UnaryFunctionAdapter, + OpaqueValue>::CreateDescriptor("hasValue", true), + UnaryFunctionAdapter, OpaqueValue>::WrapFunction( + &OptionalHasValue))); + CEL_RETURN_IF_ERROR(registry.Register( + BinaryFunctionAdapter, StructValue, + StringValue>::CreateDescriptor("_?._", false), + BinaryFunctionAdapter, StructValue, StringValue>:: + WrapFunction(&SelectOptionalFieldStruct))); + CEL_RETURN_IF_ERROR(registry.Register( + BinaryFunctionAdapter, MapValue, + StringValue>::CreateDescriptor("_?._", false), + BinaryFunctionAdapter, MapValue, StringValue>:: + WrapFunction(&SelectOptionalFieldMap))); + CEL_RETURN_IF_ERROR(registry.Register( + BinaryFunctionAdapter, OpaqueValue, + StringValue>::CreateDescriptor("_?._", false), + BinaryFunctionAdapter, OpaqueValue, + StringValue>::WrapFunction(&SelectOptionalField))); + CEL_RETURN_IF_ERROR(registry.Register( + BinaryFunctionAdapter, MapValue, + Value>::CreateDescriptor("_[?_]", false), + BinaryFunctionAdapter, MapValue, + Value>::WrapFunction(&MapOptIndexOptionalValue))); + CEL_RETURN_IF_ERROR(registry.Register( + BinaryFunctionAdapter, ListValue, + int64_t>::CreateDescriptor("_[?_]", false), + BinaryFunctionAdapter, ListValue, + int64_t>::WrapFunction(&ListOptIndexOptionalInt))); + CEL_RETURN_IF_ERROR(registry.Register( + BinaryFunctionAdapter, OpaqueValue, + Value>::CreateDescriptor("_[?_]", false), + BinaryFunctionAdapter, OpaqueValue, Value>:: + WrapFunction(&OptionalOptIndexOptionalValue))); + CEL_RETURN_IF_ERROR(registry.Register( + UnaryFunctionAdapter, ListValue>::CreateDescriptor( + "optional.unwrap", false), + UnaryFunctionAdapter, ListValue>::WrapFunction( + &ListUnwrapOpt))); + CEL_RETURN_IF_ERROR(registry.Register( + UnaryFunctionAdapter, ListValue>::CreateDescriptor( + "unwrapOpt", true), + UnaryFunctionAdapter, ListValue>::WrapFunction( + &ListUnwrapOpt))); + CEL_RETURN_IF_ERROR(registry.Register( + UnaryFunctionAdapter, ListValue>::CreateDescriptor( + "first", true), + UnaryFunctionAdapter, ListValue>::WrapFunction( + &ListFirst))); + CEL_RETURN_IF_ERROR(registry.Register( + UnaryFunctionAdapter, ListValue>::CreateDescriptor( + "last", true), + UnaryFunctionAdapter, ListValue>::WrapFunction( + &ListLast))); + return absl::OkStatus(); +} + +} // namespace + +absl::Status EnableOptionalTypes(RuntimeBuilder& builder) { + auto& runtime = cel::internal::down_cast( + runtime_internal::RuntimeFriendAccess::GetMutableRuntime(builder)); + CEL_RETURN_IF_ERROR(RegisterOptionalTypeFunctions( + builder.function_registry(), runtime.expr_builder().options())); + CEL_RETURN_IF_ERROR(builder.type_registry().RegisterType(OptionalType())); + runtime.expr_builder().enable_optional_types(); + return absl::OkStatus(); +} + +} // namespace cel::extensions diff --git a/runtime/optional_types.h b/runtime/optional_types.h new file mode 100644 index 000000000..7c8087175 --- /dev/null +++ b/runtime/optional_types.h @@ -0,0 +1,152 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_OPTIONAL_TYPES_H_ +#define THIRD_PARTY_CEL_CPP_RUNTIME_OPTIONAL_TYPES_H_ + +#include "absl/status/status.h" +#include "runtime/runtime_builder.h" + +namespace cel::extensions { + +// EnableOptionalTypes enable support for optional syntax and types in CEL. +// +// The optional value type makes it possible to express whether variables have +// been provided, whether a result has been computed, and in the future whether +// an object field path, map key value, or list index has a value. +// +// # Syntax Changes +// +// OptionalTypes are unlike other CEL extensions because they modify the CEL +// syntax itself, notably through the use of a `?` preceding a field name or +// index value. +// +// ## Field Selection +// +// The optional syntax in field selection is denoted as `obj.?field`. In other +// words, if a field is set, return `optional.of(obj.field)“, else +// `optional.none()`. The optional field selection is viral in the sense that +// after the first optional selection all subsequent selections or indices +// are treated as optional, i.e. the following expressions are equivalent: +// +// obj.?field.subfield +// obj.?field.?subfield +// +// ## Indexing +// +// Similar to field selection, the optional syntax can be used in index +// expressions on maps and lists: +// +// list[?0] +// map[?key] +// +// ## Optional Field Setting +// +// When creating map or message literals, if a field may be optionally set +// based on its presence, then placing a `?` before the field name or key +// will ensure the type on the right-hand side must be optional(T) where T +// is the type of the field or key-value. +// +// The following returns a map with the key expression set only if the +// subfield is present, otherwise an empty map is created: +// +// {?key: obj.?field.subfield} +// +// ## Optional Element Setting +// +// When creating list literals, an element in the list may be optionally added +// when the element expression is preceded by a `?`: +// +// [a, ?b, ?c] // return a list with either [a], [a, b], [a, b, c], or [a, c] +// +// # Optional.Of +// +// Create an optional(T) value of a given value with type T. +// +// optional.of(10) +// +// # Optional.OfNonZeroValue +// +// Create an optional(T) value of a given value with type T if it is not a +// zero-value. A zero-value the default empty value for any given CEL type, +// including empty protobuf message types. If the value is empty, the result +// of this call will be optional.none(). +// +// optional.ofNonZeroValue([1, 2, 3]) // optional(list(int)) +// optional.ofNonZeroValue([]) // optional.none() +// optional.ofNonZeroValue(0) // optional.none() +// optional.ofNonZeroValue("") // optional.none() +// +// # Optional.None +// +// Create an empty optional value. +// +// # HasValue +// +// Determine whether the optional contains a value. +// +// optional.of(b'hello').hasValue() // true +// optional.ofNonZeroValue({}).hasValue() // false +// +// # Value +// +// Get the value contained by the optional. If the optional does not have a +// value, the result will be a CEL error. +// +// optional.of(b'hello').value() // b'hello' +// optional.ofNonZeroValue({}).value() // error +// +// # Or +// +// If the value on the left-hand side is optional.none(), the optional value +// on the right hand side is returned. If the value on the left-hand set is +// valued, then it is returned. This operation is short-circuiting and will +// only evaluate as many links in the `or` chain as are needed to return a +// non-empty optional value. +// +// obj.?field.or(m[?key]) +// l[?index].or(obj.?field.subfield).or(obj.?other) +// +// # OrValue +// +// Either return the value contained within the optional on the left-hand side +// or return the alternative value on the right hand side. +// +// m[?key].orValue("none") +// +// # OptMap +// +// Apply a transformation to the optional's underlying value if it is not empty +// and return an optional typed result based on the transformation. The +// transformation expression type must return a type T which is wrapped into +// an optional. +// +// msg.?elements.optMap(e, e.size()).orValue(0) +// +// # OptFlatMap +// +// Introduced in version: 1 +// +// Apply a transformation to the optional's underlying value if it is not empty +// and return the result. The transform expression must return an optional(T) +// rather than type T. This can be useful when dealing with zero values and +// conditionally generating an empty or non-empty result in ways which cannot +// be expressed with `optMap`. +// +// msg.?elements.optFlatMap(e, e[?0]) // return the first element if present. +absl::Status EnableOptionalTypes(RuntimeBuilder& builder); + +} // namespace cel::extensions + +#endif // THIRD_PARTY_CEL_CPP_RUNTIME_OPTIONAL_TYPES_H_ diff --git a/runtime/optional_types_test.cc b/runtime/optional_types_test.cc new file mode 100644 index 000000000..455e51988 --- /dev/null +++ b/runtime/optional_types_test.cc @@ -0,0 +1,459 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "runtime/optional_types.h" + +#include +#include +#include +#include +#include +#include + +#include "cel/expr/syntax.pb.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "common/function_descriptor.h" +#include "common/kind.h" +#include "common/value.h" +#include "common/value_testing.h" +#include "extensions/protobuf/runtime_adapter.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "parser/options.h" +#include "parser/parser.h" +#include "runtime/activation.h" +#include "runtime/function.h" +#include "runtime/internal/runtime_impl.h" +#include "runtime/reference_resolver.h" +#include "runtime/runtime.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_options.h" +#include "runtime/standard_runtime_builder_factory.h" +#include "google/protobuf/arena.h" + +namespace cel::extensions { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::cel::extensions::ProtobufRuntimeAdapter; +using ::cel::test::BoolValueIs; +using ::cel::test::IntValueIs; +using ::cel::test::OptionalValueIs; +using ::cel::test::OptionalValueIsEmpty; +using ::cel::expr::ParsedExpr; +using ::google::api::expr::parser::Parse; +using ::google::api::expr::parser::ParserOptions; +using ::testing::ElementsAre; +using ::testing::HasSubstr; +using ::testing::TestWithParam; + +MATCHER_P(MatchesOptionalReceiver1, name, "") { + const FunctionDescriptor& descriptor = arg.descriptor; + + std::vector types{Kind::kOpaque}; + return descriptor.name() == name && descriptor.receiver_style() == true && + descriptor.types() == types; +} + +MATCHER_P2(MatchesOptionalReceiver2, name, kind, "") { + const FunctionDescriptor& descriptor = arg.descriptor; + + std::vector types{Kind::kOpaque, kind}; + return descriptor.name() == name && descriptor.receiver_style() == true && + descriptor.types() == types; +} + +MATCHER_P2(MatchesOptionalSelect, kind1, kind2, "") { + const FunctionDescriptor& descriptor = arg.descriptor; + + std::vector types{kind1, kind2}; + return descriptor.name() == "_?._" && descriptor.receiver_style() == false && + descriptor.types() == types; +} + +MATCHER_P2(MatchesOptionalIndex, kind1, kind2, "") { + const FunctionDescriptor& descriptor = arg.descriptor; + + std::vector types{kind1, kind2}; + return descriptor.name() == "_[?_]" && descriptor.receiver_style() == false && + descriptor.types() == types; +} + +TEST(EnableOptionalTypes, HeterogeneousEqualityRequired) { + ASSERT_OK_AND_ASSIGN( + auto builder, + CreateStandardRuntimeBuilder( + internal::GetTestingDescriptorPool(), + RuntimeOptions{.enable_qualified_type_identifiers = true, + .enable_heterogeneous_equality = false})); + EXPECT_THAT(EnableOptionalTypes(builder), + StatusIs(absl::StatusCode::kFailedPrecondition)); +} + +TEST(EnableOptionalTypes, QualifiedTypeIdentifiersRequired) { + ASSERT_OK_AND_ASSIGN( + auto builder, + CreateStandardRuntimeBuilder( + internal::GetTestingDescriptorPool(), + RuntimeOptions{.enable_qualified_type_identifiers = false, + .enable_heterogeneous_equality = true})); + EXPECT_THAT(EnableOptionalTypes(builder), + StatusIs(absl::StatusCode::kFailedPrecondition)); +} + +TEST(EnableOptionalTypes, PreconditionsSatisfied) { + ASSERT_OK_AND_ASSIGN( + auto builder, + CreateStandardRuntimeBuilder( + internal::GetTestingDescriptorPool(), + RuntimeOptions{.enable_qualified_type_identifiers = true, + .enable_heterogeneous_equality = true})); + EXPECT_THAT(EnableOptionalTypes(builder), IsOk()); +} + +TEST(EnableOptionalTypes, Functions) { + ASSERT_OK_AND_ASSIGN( + auto builder, + CreateStandardRuntimeBuilder( + internal::GetTestingDescriptorPool(), + RuntimeOptions{.enable_qualified_type_identifiers = true, + .enable_heterogeneous_equality = true})); + ASSERT_THAT(EnableOptionalTypes(builder), IsOk()); + + EXPECT_THAT(builder.function_registry().FindStaticOverloads("hasValue", true, + {Kind::kOpaque}), + ElementsAre(MatchesOptionalReceiver1("hasValue"))); + + EXPECT_THAT(builder.function_registry().FindStaticOverloads("value", true, + {Kind::kOpaque}), + ElementsAre(MatchesOptionalReceiver1("value"))); + + EXPECT_THAT(builder.function_registry().FindStaticOverloads( + "_?._", false, {Kind::kStruct, Kind::kString}), + ElementsAre(MatchesOptionalSelect(Kind::kStruct, Kind::kString))); + EXPECT_THAT(builder.function_registry().FindStaticOverloads( + "_?._", false, {Kind::kMap, Kind::kString}), + ElementsAre(MatchesOptionalSelect(Kind::kMap, Kind::kString))); + EXPECT_THAT(builder.function_registry().FindStaticOverloads( + "_?._", false, {Kind::kOpaque, Kind::kString}), + ElementsAre(MatchesOptionalSelect(Kind::kOpaque, Kind::kString))); + + EXPECT_THAT(builder.function_registry().FindStaticOverloads( + "_[?_]", false, {Kind::kMap, Kind::kAny}), + ElementsAre(MatchesOptionalIndex(Kind::kMap, Kind::kAny))); + EXPECT_THAT(builder.function_registry().FindStaticOverloads( + "_[?_]", false, {Kind::kList, Kind::kInt}), + ElementsAre(MatchesOptionalIndex(Kind::kList, Kind::kInt))); + EXPECT_THAT(builder.function_registry().FindStaticOverloads( + "_[?_]", false, {Kind::kOpaque, Kind::kAny}), + ElementsAre(MatchesOptionalIndex(Kind::kOpaque, Kind::kAny))); +} + +struct EvaluateResultTestCase { + std::string name; + std::string expression; + test::ValueMatcher value_matcher; + + template + friend void AbslStringify(S& sink, const EvaluateResultTestCase& tc) { + sink.Append(tc.name); + } +}; + +class OptionalTypesTest + : public TestWithParam> { + public: + const EvaluateResultTestCase& GetTestCase() { + return std::get<0>(GetParam()); + } + + bool EnableShortCircuiting() { return std::get<1>(GetParam()); } +}; + +TEST_P(OptionalTypesTest, RecursivePlan) { + RuntimeOptions opts; + opts.enable_qualified_type_identifiers = true; + opts.max_recursion_depth = -1; + opts.short_circuiting = EnableShortCircuiting(); + + const EvaluateResultTestCase& test_case = GetTestCase(); + + ASSERT_OK_AND_ASSIGN( + auto builder, + CreateStandardRuntimeBuilder(internal::GetTestingDescriptorPool(), opts)); + + ASSERT_OK(EnableOptionalTypes(builder)); + ASSERT_OK( + EnableReferenceResolver(builder, ReferenceResolverEnabled::kAlways)); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, + Parse(test_case.expression, "", + ParserOptions{.enable_optional_syntax = true})); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr program, + ProtobufRuntimeAdapter::CreateProgram(*runtime, expr)); + + EXPECT_TRUE(runtime_internal::TestOnly_IsRecursiveImpl(program.get())); + + google::protobuf::Arena arena; + Activation activation; + + ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation)); + + EXPECT_THAT(result, test_case.value_matcher) << test_case.expression; +} + +TEST_P(OptionalTypesTest, Defaults) { + RuntimeOptions opts; + opts.enable_qualified_type_identifiers = true; + opts.short_circuiting = EnableShortCircuiting(); + const EvaluateResultTestCase& test_case = GetTestCase(); + + ASSERT_OK_AND_ASSIGN( + auto builder, + CreateStandardRuntimeBuilder(internal::GetTestingDescriptorPool(), opts)); + + ASSERT_OK(EnableOptionalTypes(builder)); + ASSERT_OK( + EnableReferenceResolver(builder, ReferenceResolverEnabled::kAlways)); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, + Parse(test_case.expression, "", + ParserOptions{.enable_optional_syntax = true})); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr program, + ProtobufRuntimeAdapter::CreateProgram(*runtime, expr)); + + google::protobuf::Arena arena; + Activation activation; + + ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation)); + + EXPECT_THAT(result, test_case.value_matcher) << test_case.expression; +} + +INSTANTIATE_TEST_SUITE_P( + Basic, OptionalTypesTest, + testing::Combine( + testing::ValuesIn(std::vector{ + {"optional_none_hasValue", "optional.none().hasValue()", + BoolValueIs(false)}, + {"optional_of_hasValue", "optional.of(0).hasValue()", + BoolValueIs(true)}, + {"optional_ofNonZeroValue_hasValue", + "optional.ofNonZeroValue(0).hasValue()", BoolValueIs(false)}, + {"optional_or_absent", + "optional.ofNonZeroValue(0).or(optional.ofNonZeroValue(0))", + OptionalValueIsEmpty()}, + {"optional_or_present", "optional.of(1).or(optional.none())", + OptionalValueIs(IntValueIs(1))}, + {"optional_orValue_absent", "optional.ofNonZeroValue(0).orValue(1)", + IntValueIs(1)}, + {"optional_orValue_present", "optional.of(1).orValue(2)", + IntValueIs(1)}, + {"list_of_optional", "[optional.of(1)][0].orValue(1)", + IntValueIs(1)}, + {"list_unwrap_empty", "optional.unwrap([]) == []", + BoolValueIs(true)}, + {"list_unwrap_empty_optional_none", + "optional.unwrap([optional.none(), optional.none()]) == []", + BoolValueIs(true)}, + {"list_unwrap_three_elements", + "optional.unwrap([optional.of(42), optional.none(), " + "optional.of(\"a\")]) == [42, \"a\"]", + BoolValueIs(true)}, + {"list_unwrap_no_none", + "optional.unwrap([optional.of(42), optional.of(\"a\")]) == [42, " + "\"a\"]", + BoolValueIs(true)}, + {"list_unwrapOpt_empty", "[].unwrapOpt() == []", BoolValueIs(true)}, + {"list_unwrapOpt_empty_optional_none", + "[optional.none(), optional.none()].unwrapOpt() == []", + BoolValueIs(true)}, + {"list_unwrapOpt_three_elements", + "[optional.of(42), optional.none(), " + "optional.of(\"a\")].unwrapOpt() == [42, \"a\"]", + BoolValueIs(true)}, + {"list_unwrapOpt_no_none", + "[optional.of(42), optional.of(\"a\")].unwrapOpt() == [42, \"a\"]", + BoolValueIs(true)}, + {"list_first", "[1, 2, 3].first()", OptionalValueIs(IntValueIs(1))}, + {"list_first_empty", "[].first()", OptionalValueIsEmpty()}, + {"list_last", "[1, 2, 3].last()", OptionalValueIs(IntValueIs(3))}, + {"list_last_empty", "[].last()", OptionalValueIsEmpty()}, + }), + /*enable_short_circuiting*/ testing::Bool())); + +class UnreachableFunction final : public cel::Function { + public: + explicit UnreachableFunction(int64_t* count) : count_(count) {} + + absl::StatusOr Invoke(absl::Span args, + const InvokeContext& context) const override { + ++(*count_); + return ErrorValue(absl::CancelledError()); + } + + private: + int64_t* const count_; +}; + +TEST(OptionalTypesTest, ErrorShortCircuiting) { + RuntimeOptions opts{.enable_qualified_type_identifiers = true}; + google::protobuf::Arena arena; + + ASSERT_OK_AND_ASSIGN( + auto builder, + CreateStandardRuntimeBuilder(internal::GetTestingDescriptorPool(), opts)); + + int64_t unreachable_count = 0; + + ASSERT_OK(EnableOptionalTypes(builder)); + ASSERT_OK( + EnableReferenceResolver(builder, ReferenceResolverEnabled::kAlways)); + ASSERT_OK(builder.function_registry().Register( + cel::FunctionDescriptor("unreachable", false, {}), + std::make_unique(&unreachable_count))); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + ASSERT_OK_AND_ASSIGN( + ParsedExpr expr, + Parse("optional.of(1 / 0).orValue(unreachable())", "", + ParserOptions{.enable_optional_syntax = true})); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr program, + ProtobufRuntimeAdapter::CreateProgram(*runtime, expr)); + + Activation activation; + + ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation)); + + EXPECT_EQ(unreachable_count, 0); + ASSERT_TRUE(result->Is()) << result->DebugString(); + EXPECT_THAT(result.GetError().NativeValue(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("divide by zero"))); +} + +TEST(OptionalTypesTest, CreateList_TypeConversionError) { + RuntimeOptions opts{.enable_qualified_type_identifiers = true}; + google::protobuf::Arena arena; + + ASSERT_OK_AND_ASSIGN( + auto builder, + CreateStandardRuntimeBuilder(internal::GetTestingDescriptorPool(), opts)); + + ASSERT_THAT(EnableOptionalTypes(builder), IsOk()); + ASSERT_THAT( + EnableReferenceResolver(builder, ReferenceResolverEnabled::kAlways), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, + Parse("[?foo]", "", + ParserOptions{.enable_optional_syntax = true})); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr program, + ProtobufRuntimeAdapter::CreateProgram(*runtime, expr)); + + Activation activation; + activation.InsertOrAssignValue("foo", IntValue(1)); + + ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation)); + + ASSERT_TRUE(result.IsError()) << result.DebugString(); + EXPECT_THAT(result.GetError().ToStatus(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("type conversion error"))); +} + +TEST(OptionalTypesTest, CreateMap_TypeConversionError) { + RuntimeOptions opts{.enable_qualified_type_identifiers = true}; + google::protobuf::Arena arena; + + ASSERT_OK_AND_ASSIGN( + auto builder, + CreateStandardRuntimeBuilder(internal::GetTestingDescriptorPool(), opts)); + + ASSERT_THAT(EnableOptionalTypes(builder), IsOk()); + ASSERT_THAT( + EnableReferenceResolver(builder, ReferenceResolverEnabled::kAlways), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, + Parse("{?1: foo}", "", + ParserOptions{.enable_optional_syntax = true})); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr program, + ProtobufRuntimeAdapter::CreateProgram(*runtime, expr)); + + Activation activation; + activation.InsertOrAssignValue("foo", IntValue(1)); + + ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation)); + + ASSERT_TRUE(result.IsError()) << result.DebugString(); + EXPECT_THAT(result.GetError().ToStatus(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("type conversion error"))); +} + +TEST(OptionalTypesTest, CreateStruct_KeyTypeConversionError) { + RuntimeOptions opts{.enable_qualified_type_identifiers = true}; + google::protobuf::Arena arena; + + ASSERT_OK_AND_ASSIGN( + auto builder, + CreateStandardRuntimeBuilder(internal::GetTestingDescriptorPool(), opts)); + + ASSERT_THAT(EnableOptionalTypes(builder), IsOk()); + ASSERT_THAT( + EnableReferenceResolver(builder, ReferenceResolverEnabled::kAlways), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + ASSERT_OK_AND_ASSIGN( + ParsedExpr expr, + Parse("cel.expr.conformance.proto2.TestAllTypes{?single_int32: foo}", + "", ParserOptions{.enable_optional_syntax = true})); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr program, + ProtobufRuntimeAdapter::CreateProgram(*runtime, expr)); + + Activation activation; + activation.InsertOrAssignValue("foo", IntValue(1)); + + ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation)); + + ASSERT_TRUE(result.IsError()) << result.DebugString(); + EXPECT_THAT(result.GetError().ToStatus(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("type conversion error"))); +} + +} // namespace +} // namespace cel::extensions diff --git a/runtime/reference_resolver.cc b/runtime/reference_resolver.cc new file mode 100644 index 000000000..8cb14598a --- /dev/null +++ b/runtime/reference_resolver.cc @@ -0,0 +1,77 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "runtime/reference_resolver.h" + +#include "absl/base/macros.h" +#include "absl/log/absl_log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "common/native_type.h" +#include "eval/compiler/qualified_reference_resolver.h" +#include "internal/casts.h" +#include "internal/status_macros.h" +#include "runtime/internal/runtime_friend_access.h" +#include "runtime/internal/runtime_impl.h" +#include "runtime/runtime.h" +#include "runtime/runtime_builder.h" + +namespace cel { +namespace { + +using ::cel::internal::down_cast; +using ::cel::runtime_internal::RuntimeFriendAccess; +using ::cel::runtime_internal::RuntimeImpl; + +absl::StatusOr RuntimeImplFromBuilder(RuntimeBuilder& builder) { + Runtime& runtime = RuntimeFriendAccess::GetMutableRuntime(builder); + + if (RuntimeFriendAccess::RuntimeTypeId(runtime) != + NativeTypeId::For()) { + return absl::UnimplementedError( + "regex precompilation only supported on the default cel::Runtime " + "implementation."); + } + + RuntimeImpl& runtime_impl = down_cast(runtime); + + return &runtime_impl; +} + +google::api::expr::runtime::ReferenceResolverOption Convert( + ReferenceResolverEnabled enabled) { + switch (enabled) { + case ReferenceResolverEnabled::kCheckedExpressionOnly: + return google::api::expr::runtime::ReferenceResolverOption::kCheckedOnly; + case ReferenceResolverEnabled::kAlways: + return google::api::expr::runtime::ReferenceResolverOption::kAlways; + } + ABSL_LOG(FATAL) << "unsupported ReferenceResolverEnabled enumerator: " + << static_cast(enabled); +} + +} // namespace + +absl::Status EnableReferenceResolver(RuntimeBuilder& builder, + ReferenceResolverEnabled enabled) { + CEL_ASSIGN_OR_RETURN(RuntimeImpl * runtime_impl, + RuntimeImplFromBuilder(builder)); + ABSL_ASSERT(runtime_impl != nullptr); + + runtime_impl->expr_builder().AddAstTransform( + NewReferenceResolverExtension(Convert(enabled))); + return absl::OkStatus(); +} + +} // namespace cel diff --git a/runtime/reference_resolver.h b/runtime/reference_resolver.h new file mode 100644 index 000000000..8eb144040 --- /dev/null +++ b/runtime/reference_resolver.h @@ -0,0 +1,46 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_REFERENCE_RESOLVER_H_ +#define THIRD_PARTY_CEL_CPP_RUNTIME_REFERENCE_RESOLVER_H_ + +#include "absl/status/status.h" +#include "runtime/runtime_builder.h" + +namespace cel { + +enum class ReferenceResolverEnabled { kCheckedExpressionOnly, kAlways }; + +// Enables expression rewrites to normalize the AST representation of +// references to qualified names of enum constants, variables and functions. +// +// For parse-only expressions, this is only able to disambiguate functions based +// on registered overloads in the runtime. +// +// Note: This may require making a deep copy of the input expression in order to +// apply the rewrites. +// +// Applied adjustments: +// - for dot-qualified variable names represented as select operations, +// replaces select operations with an identifier. +// - for dot-qualified functions, replaces receiver call with a global +// function call. +// - for compile time constants (such as enum values), inlines the constant +// value as a literal. +absl::Status EnableReferenceResolver(RuntimeBuilder& builder, + ReferenceResolverEnabled enabled); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_RUNTIME_REFERENCE_RESOLVER_H_ diff --git a/runtime/reference_resolver_test.cc b/runtime/reference_resolver_test.cc new file mode 100644 index 000000000..398799e13 --- /dev/null +++ b/runtime/reference_resolver_test.cc @@ -0,0 +1,364 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "runtime/reference_resolver.h" + +#include +#include + +#include "cel/expr/checked.pb.h" +#include "cel/expr/syntax.pb.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "base/function_adapter.h" +#include "common/value.h" +#include "extensions/protobuf/runtime_adapter.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "parser/parser.h" +#include "runtime/activation.h" +#include "runtime/register_function_helper.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_options.h" +#include "runtime/standard_runtime_builder_factory.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/text_format.h" + +namespace cel { +namespace { + +using ::cel::extensions::ProtobufRuntimeAdapter; +using ::cel::expr::CheckedExpr; +using ::cel::expr::Expr; +using ::cel::expr::ParsedExpr; + +using ::google::api::expr::parser::Parse; + +using ::absl_testing::StatusIs; +using ::testing::HasSubstr; + +TEST(ReferenceResolver, ResolveQualifiedFunctions) { + RuntimeOptions options; + ASSERT_OK_AND_ASSIGN(RuntimeBuilder builder, + CreateStandardRuntimeBuilder( + internal::GetTestingDescriptorPool(), options)); + + ASSERT_OK( + EnableReferenceResolver(builder, ReferenceResolverEnabled::kAlways)); + + absl::Status status = + RegisterHelper>:: + RegisterGlobalOverload( + "com.example.Exp", + [](int64_t base, int64_t exp) -> int64_t { + int64_t result = 1; + for (int64_t i = 0; i < exp; ++i) { + result *= base; + } + return result; + }, + builder.function_registry()); + ASSERT_OK(status); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, + Parse("com.example.Exp(2, 3) == 8")); + + ASSERT_OK_AND_ASSIGN(auto program, ProtobufRuntimeAdapter::CreateProgram( + *runtime, parsed_expr)); + + google::protobuf::Arena arena; + Activation activation; + + ASSERT_OK_AND_ASSIGN(Value value, program->Evaluate(&arena, activation)); + ASSERT_TRUE(value->Is()); + EXPECT_TRUE(value.GetBool().NativeValue()); +} + +TEST(ReferenceResolver, ResolveQualifiedFunctionsCheckedOnly) { + RuntimeOptions options; + ASSERT_OK_AND_ASSIGN(RuntimeBuilder builder, + CreateStandardRuntimeBuilder( + internal::GetTestingDescriptorPool(), options)); + + ASSERT_OK(EnableReferenceResolver( + builder, ReferenceResolverEnabled::kCheckedExpressionOnly)); + + absl::Status status = + RegisterHelper>:: + RegisterGlobalOverload( + "com.example.Exp", + [](int64_t base, int64_t exp) -> int64_t { + int64_t result = 1; + for (int64_t i = 0; i < exp; ++i) { + result *= base; + } + return result; + }, + builder.function_registry()); + ASSERT_OK(status); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, + Parse("com.example.Exp(2, 3) == 8")); + + EXPECT_THAT(ProtobufRuntimeAdapter::CreateProgram(*runtime, parsed_expr), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("No overloads provided"))); +} + +// com.example.x + com.example.y +constexpr absl::string_view kIdentifierExpression = R"pb( + reference_map: { + key: 3 + value: { name: "com.example.x" } + } + reference_map: { + key: 4 + value: { overload_id: "add_int64" } + } + reference_map: { + key: 7 + value: { name: "com.example.y" } + } + type_map: { + key: 3 + value: { primitive: INT64 } + } + type_map: { + key: 4 + value: { primitive: INT64 } + } + type_map: { + key: 7 + value: { primitive: INT64 } + } + source_info: { + location: "" + line_offsets: 30 + positions: { key: 1 value: 0 } + positions: { key: 2 value: 3 } + positions: { key: 3 value: 11 } + positions: { key: 4 value: 14 } + positions: { key: 5 value: 16 } + positions: { key: 6 value: 19 } + positions: { key: 7 value: 27 } + } + expr: { + id: 4 + call_expr: { + function: "_+_" + args: { + id: 3 + # compilers typically already apply this rewrite, but older saved + # expressions might preserve the original parse. + select_expr { + operand { + id: 8 + select_expr { + operand: { + id: 9 + ident_expr { name: "com" } + } + field: "example" + } + } + field: "x" + } + } + args: { + id: 7 + ident_expr: { name: "com.example.y" } + } + } + })pb"; + +TEST(ReferenceResolver, ResolveQualifiedIdentifiers) { + RuntimeOptions options; + ASSERT_OK_AND_ASSIGN(RuntimeBuilder builder, + CreateStandardRuntimeBuilder( + internal::GetTestingDescriptorPool(), options)); + + ASSERT_OK(EnableReferenceResolver( + builder, ReferenceResolverEnabled::kCheckedExpressionOnly)); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + CheckedExpr checked_expr; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kIdentifierExpression, + &checked_expr)); + + ASSERT_OK_AND_ASSIGN(auto program, ProtobufRuntimeAdapter::CreateProgram( + *runtime, checked_expr)); + + google::protobuf::Arena arena; + Activation activation; + + activation.InsertOrAssignValue("com.example.x", IntValue(3)); + activation.InsertOrAssignValue("com.example.y", IntValue(4)); + + ASSERT_OK_AND_ASSIGN(Value value, program->Evaluate(&arena, activation)); + + ASSERT_TRUE(value->Is()); + EXPECT_EQ(value.GetInt().NativeValue(), 7); +} + +TEST(ReferenceResolver, ResolveQualifiedIdentifiersSkipParseOnly) { + RuntimeOptions options; + ASSERT_OK_AND_ASSIGN(RuntimeBuilder builder, + CreateStandardRuntimeBuilder( + internal::GetTestingDescriptorPool(), options)); + + ASSERT_OK(EnableReferenceResolver( + builder, ReferenceResolverEnabled::kCheckedExpressionOnly)); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + CheckedExpr checked_expr; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kIdentifierExpression, + &checked_expr)); + + // Discard type-check information + Expr unchecked_expr = checked_expr.expr(); + ASSERT_OK_AND_ASSIGN(auto program, ProtobufRuntimeAdapter::CreateProgram( + *runtime, checked_expr.expr())); + + google::protobuf::Arena arena; + Activation activation; + + activation.InsertOrAssignValue("com.example.x", IntValue(3)); + activation.InsertOrAssignValue("com.example.y", IntValue(4)); + + ASSERT_OK_AND_ASSIGN(Value value, program->Evaluate(&arena, activation)); + + ASSERT_TRUE(value->Is()); + EXPECT_THAT(value.GetError().NativeValue(), + StatusIs(absl::StatusCode::kUnknown, HasSubstr("\"com\""))); +} + +// cel.expr.conformance.proto2.GlobalEnum.GAZ == 2 +constexpr absl::string_view kEnumExpr = R"pb( + reference_map: { + key: 8 + value: { + name: "cel.expr.conformance.proto2.GlobalEnum.GAZ" + value: { int64_value: 2 } + } + } + reference_map: { + key: 9 + value: { overload_id: "equals" } + } + type_map: { + key: 8 + value: { primitive: INT64 } + } + type_map: { + key: 9 + value: { primitive: BOOL } + } + type_map: { + key: 10 + value: { primitive: INT64 } + } + source_info: { + location: "" + line_offsets: 1 + line_offsets: 64 + line_offsets: 77 + positions: { key: 1 value: 13 } + positions: { key: 2 value: 19 } + positions: { key: 3 value: 23 } + positions: { key: 4 value: 28 } + positions: { key: 5 value: 33 } + positions: { key: 6 value: 36 } + positions: { key: 7 value: 43 } + positions: { key: 8 value: 54 } + positions: { key: 9 value: 59 } + positions: { key: 10 value: 62 } + } + expr: { + id: 9 + call_expr: { + function: "_==_" + args: { + id: 8 + ident_expr: { name: "cel.expr.conformance.proto2.GlobalEnum.GAZ" } + } + args: { + id: 10 + const_expr: { int64_value: 2 } + } + } + })pb"; + +TEST(ReferenceResolver, ResolveEnumConstants) { + RuntimeOptions options; + ASSERT_OK_AND_ASSIGN(RuntimeBuilder builder, + CreateStandardRuntimeBuilder( + internal::GetTestingDescriptorPool(), options)); + + ASSERT_OK(EnableReferenceResolver( + builder, ReferenceResolverEnabled::kCheckedExpressionOnly)); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + CheckedExpr checked_expr; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kEnumExpr, &checked_expr)); + + ASSERT_OK_AND_ASSIGN(auto program, ProtobufRuntimeAdapter::CreateProgram( + *runtime, checked_expr)); + + google::protobuf::Arena arena; + Activation activation; + + ASSERT_OK_AND_ASSIGN(Value value, program->Evaluate(&arena, activation)); + + ASSERT_TRUE(value->Is()); + EXPECT_TRUE(value.GetBool().NativeValue()); +} + +TEST(ReferenceResolver, ResolveEnumConstantsSkipParseOnly) { + RuntimeOptions options; + ASSERT_OK_AND_ASSIGN(RuntimeBuilder builder, + CreateStandardRuntimeBuilder( + internal::GetTestingDescriptorPool(), options)); + + ASSERT_OK(EnableReferenceResolver( + builder, ReferenceResolverEnabled::kCheckedExpressionOnly)); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + CheckedExpr checked_expr; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kEnumExpr, &checked_expr)); + + Expr unchecked_expr = checked_expr.expr(); + ASSERT_OK_AND_ASSIGN(auto program, ProtobufRuntimeAdapter::CreateProgram( + *runtime, unchecked_expr)); + + google::protobuf::Arena arena; + Activation activation; + + ASSERT_OK_AND_ASSIGN(Value value, program->Evaluate(&arena, activation)); + + ASSERT_TRUE(value->Is()); + EXPECT_THAT( + value.GetError().NativeValue(), + StatusIs(absl::StatusCode::kUnknown, + HasSubstr("\"cel.expr.conformance.proto2.GlobalEnum.GAZ\""))); +} + +} // namespace +} // namespace cel diff --git a/runtime/regex_precompilation.cc b/runtime/regex_precompilation.cc new file mode 100644 index 000000000..236715f94 --- /dev/null +++ b/runtime/regex_precompilation.cc @@ -0,0 +1,65 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "runtime/regex_precompilation.h" + +#include "absl/base/macros.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "common/native_type.h" +#include "eval/compiler/regex_precompilation_optimization.h" +#include "internal/casts.h" +#include "internal/status_macros.h" +#include "runtime/internal/runtime_friend_access.h" +#include "runtime/internal/runtime_impl.h" +#include "runtime/runtime.h" +#include "runtime/runtime_builder.h" + +namespace cel::extensions { +namespace { + +using ::cel::internal::down_cast; +using ::cel::runtime_internal::RuntimeFriendAccess; +using ::cel::runtime_internal::RuntimeImpl; +using ::google::api::expr::runtime::CreateRegexPrecompilationExtension; + +absl::StatusOr RuntimeImplFromBuilder(RuntimeBuilder& builder) { + Runtime& runtime = RuntimeFriendAccess::GetMutableRuntime(builder); + + if (RuntimeFriendAccess::RuntimeTypeId(runtime) != + NativeTypeId::For()) { + return absl::UnimplementedError( + "regex precompilation only supported on the default cel::Runtime " + "implementation."); + } + + RuntimeImpl& runtime_impl = down_cast(runtime); + + return &runtime_impl; +} + +} // namespace + +absl::Status EnableRegexPrecompilation(RuntimeBuilder& builder) { + CEL_ASSIGN_OR_RETURN(RuntimeImpl * runtime_impl, + RuntimeImplFromBuilder(builder)); + ABSL_ASSERT(runtime_impl != nullptr); + + runtime_impl->expr_builder().AddProgramOptimizer( + CreateRegexPrecompilationExtension( + runtime_impl->expr_builder().options().regex_max_program_size)); + return absl::OkStatus(); +} + +} // namespace cel::extensions diff --git a/runtime/regex_precompilation.h b/runtime/regex_precompilation.h new file mode 100644 index 000000000..b02493f4d --- /dev/null +++ b/runtime/regex_precompilation.h @@ -0,0 +1,32 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_REGEX_PRECOMPILATION_FOLDING_H_ +#define THIRD_PARTY_CEL_CPP_REGEX_PRECOMPILATION_FOLDING_H_ + +#include "absl/status/status.h" +#include "runtime/runtime_builder.h" + +namespace cel::extensions { + +// Enable regular expression precompilation. +// +// Attempts to precompile regular expression patterns that are known to be +// constant in 'match' calls. If an invalid pattern is encountered, expression +// planning will fail instead of returning a program. +absl::Status EnableRegexPrecompilation(RuntimeBuilder& builder); + +} // namespace cel::extensions + +#endif // THIRD_PARTY_CEL_CPP_REGEX_PRECOMPILATION_FOLDING_H_ diff --git a/runtime/regex_precompilation_test.cc b/runtime/regex_precompilation_test.cc new file mode 100644 index 000000000..85b47ef45 --- /dev/null +++ b/runtime/regex_precompilation_test.cc @@ -0,0 +1,192 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "runtime/regex_precompilation.h" + +#include +#include +#include + +#include "cel/expr/syntax.pb.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "base/function_adapter.h" +#include "common/value.h" +#include "extensions/protobuf/runtime_adapter.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "parser/parser.h" +#include "runtime/activation.h" +#include "runtime/constant_folding.h" +#include "runtime/register_function_helper.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_options.h" +#include "runtime/standard_runtime_builder_factory.h" +#include "google/protobuf/arena.h" + +namespace cel::extensions { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::cel::expr::ParsedExpr; +using ::google::api::expr::parser::Parse; +using ::testing::_; +using ::testing::HasSubstr; + +using ValueMatcher = testing::Matcher; + +struct TestCase { + std::string name; + std::string expression; + ValueMatcher result_matcher; + absl::Status create_status; +}; + +MATCHER_P(IsIntValue, expected, "") { + const Value& value = arg; + return value->Is() && value.GetInt().NativeValue() == expected; +} + +MATCHER_P(IsBoolValue, expected, "") { + const Value& value = arg; + return value->Is() && value.GetBool().NativeValue() == expected; +} + +MATCHER_P(IsErrorValue, expected_substr, "") { + const Value& value = arg; + return value->Is() && + absl::StrContains(value.GetError().NativeValue().message(), + expected_substr); +} + +class RegexPrecompilationTest : public testing::TestWithParam {}; + +TEST_P(RegexPrecompilationTest, Basic) { + RuntimeOptions options; + const TestCase& test_case = GetParam(); + ASSERT_OK_AND_ASSIGN(cel::RuntimeBuilder builder, + CreateStandardRuntimeBuilder( + internal::GetTestingDescriptorPool(), options)); + + auto status = RegisterHelper, const StringValue&, const StringValue&>>:: + RegisterGlobalOverload( + "prepend", + [](const StringValue& value, const StringValue& prefix) { + return StringValue( + absl::StrCat(prefix.ToString(), value.ToString())); + }, + builder.function_registry()); + ASSERT_THAT(status, IsOk()); + + ASSERT_THAT(EnableRegexPrecompilation(builder), IsOk()); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse(test_case.expression)); + + auto program_or = + ProtobufRuntimeAdapter::CreateProgram(*runtime, parsed_expr); + if (!test_case.create_status.ok()) { + ASSERT_THAT(program_or.status(), + StatusIs(test_case.create_status.code(), + HasSubstr(test_case.create_status.message()))); + return; + } + + ASSERT_OK_AND_ASSIGN(auto program, std::move(program_or)); + + google::protobuf::Arena arena; + Activation activation; + activation.InsertOrAssignValue("string_var", + StringValue(&arena, "string_var")); + + ASSERT_OK_AND_ASSIGN(Value value, program->Evaluate(&arena, activation)); + EXPECT_THAT(value, test_case.result_matcher); +} + +TEST_P(RegexPrecompilationTest, WithConstantFolding) { + RuntimeOptions options; + const TestCase& test_case = GetParam(); + ASSERT_OK_AND_ASSIGN(cel::RuntimeBuilder builder, + CreateStandardRuntimeBuilder( + internal::GetTestingDescriptorPool(), options)); + + auto status = RegisterHelper, const StringValue&, const StringValue&>>:: + RegisterGlobalOverload( + "prepend", + [](const StringValue& value, const StringValue& prefix) { + return StringValue( + absl::StrCat(prefix.ToString(), value.ToString())); + }, + builder.function_registry()); + ASSERT_THAT(status, IsOk()); + + ASSERT_THAT(EnableConstantFolding(builder), IsOk()); + ASSERT_THAT(EnableRegexPrecompilation(builder), IsOk()); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse(test_case.expression)); + + auto program_or = + ProtobufRuntimeAdapter::CreateProgram(*runtime, parsed_expr); + if (!test_case.create_status.ok()) { + ASSERT_THAT(program_or.status(), + StatusIs(test_case.create_status.code(), + HasSubstr(test_case.create_status.message()))); + return; + } + + ASSERT_OK_AND_ASSIGN(auto program, std::move(program_or)); + google::protobuf::Arena arena; + Activation activation; + activation.InsertOrAssignValue("string_var", + StringValue(&arena, "string_var")); + + ASSERT_OK_AND_ASSIGN(Value value, program->Evaluate(&arena, activation)); + EXPECT_THAT(value, test_case.result_matcher); +} + +INSTANTIATE_TEST_SUITE_P( + Cases, RegexPrecompilationTest, + testing::ValuesIn(std::vector{ + {"matches_receiver", R"(string_var.matches(r's\w+_var'))", + IsBoolValue(true)}, + {"matches_receiver_false", R"(string_var.matches(r'string_var\d+'))", + IsBoolValue(false)}, + {"matches_global_true", R"(matches(string_var, r's\w+_var'))", + IsBoolValue(true)}, + {"matches_global_false", R"(matches(string_var, r'string_var\d+'))", + IsBoolValue(false)}, + {"matches_bad_re2_expression", "matches('123', r'(?& info) { + return info.param.name; + }); + +} // namespace +} // namespace cel::extensions diff --git a/runtime/register_function_helper.h b/runtime/register_function_helper.h new file mode 100644 index 000000000..8cc133abc --- /dev/null +++ b/runtime/register_function_helper.h @@ -0,0 +1,99 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_REGISTER_FUNCTION_HELPER_H_ +#define THIRD_PARTY_CEL_CPP_RUNTIME_REGISTER_FUNCTION_HELPER_H_ + +#include + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "common/function_descriptor.h" +#include "runtime/function_registry.h" +namespace cel { + +// Helper class for performing registration with function adapter. +// +// Usage: +// +// auto status = RegisterHelper> +// ::RegisterGlobalOverload( +// '_<_', +// [](int64_t x, int64_t y) -> bool {return x < y}, +// registry); +// +// if (!status.ok) return status; +// +// Note: if using this with status macros (*RETURN_IF_ERROR), an extra set of +// parentheses is needed around the multi-argument template specifier. +template +class RegisterHelper { + public: + // Generic registration for an adapted function. Prefer using one of the more + // specific Register* functions. + template + static absl::Status Register(absl::string_view name, bool receiver_style, + FunctionT&& fn, FunctionRegistry& registry, + bool strict) { + return registry.Register( + AdapterT::CreateDescriptor(name, receiver_style, strict), + AdapterT::WrapFunction(std::forward(fn))); + } + + template + static absl::Status Register(absl::string_view name, bool receiver_style, + FunctionT&& fn, FunctionRegistry& registry, + FunctionDescriptorOptions options = {}) { + return registry.Register( + AdapterT::CreateDescriptor(name, receiver_style, options), + AdapterT::WrapFunction(std::forward(fn))); + } + + // Registers a global overload (.e.g. size() ) + template + static absl::Status RegisterGlobalOverload(absl::string_view name, + FunctionT&& fn, + FunctionRegistry& registry) { + return Register(name, /*receiver_style=*/false, std::forward(fn), + registry); + } + + // Registers a member overload (.e.g. .size()) + template + static absl::Status RegisterMemberOverload(absl::string_view name, + FunctionT&& fn, + FunctionRegistry& registry) { + return Register(name, /*receiver_style=*/true, std::forward(fn), + registry); + } + + // Registers a non-strict overload. + // + // Non-strict functions may receive errors or unknown values as arguments, + // and must correctly propagate them. + // + // Most extension functions should prefer 'strict' overloads where the + // evaluator handles unknown and error propagation. + template + static absl::Status RegisterNonStrictOverload(absl::string_view name, + FunctionT&& fn, + FunctionRegistry& registry) { + return Register(name, /*receiver_style=*/false, std::forward(fn), + registry, /*strict=*/false); + } +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_RUNTIME_REGISTER_FUNCTION_HELPER_H_ diff --git a/runtime/runtime.h b/runtime/runtime.h new file mode 100644 index 000000000..2db39b0e3 --- /dev/null +++ b/runtime/runtime.h @@ -0,0 +1,229 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Interfaces for runtime concepts. + +#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_RUNTIME_H_ +#define THIRD_PARTY_CEL_CPP_RUNTIME_RUNTIME_H_ + +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/functional/any_invocable.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "base/ast.h" +#include "base/type_provider.h" +#include "common/native_type.h" +#include "common/value.h" +#include "runtime/activation_interface.h" +#include "runtime/runtime_issue.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel { + +namespace runtime_internal { +class RuntimeFriendAccess; +} // namespace runtime_internal + +class EmbedderContext; + +// Options for the Program::Evaluate call. +struct EvaluateOptions { + // Optional message factory to use for the duration of the Evaluate call. + // If unset, a default message factory will be provided by the runtime. + google::protobuf::MessageFactory* absl_nullable message_factory = nullptr; + + // Optional embedder context to use for the duration of the Evaluate call. + // This is used to access custom data in extension functions. + // This is only propagated to functions that are marked as context sensitive. + const EmbedderContext* absl_nullable embedder_context = nullptr; +}; + +// Representation of an evaluable CEL expression. +// +// See Runtime below for creating new programs. +class Program { + public: + virtual ~Program() = default; + + // Evaluate the program. + // + // Non-recoverable errors (i.e. outside of CEL's notion of an error) are + // returned as a non-ok absl::Status. These are propagated immediately and do + // not participate in CEL's notion of error handling. + // + // CEL errors are represented as result with an Ok status and a held + // cel::ErrorValue result. + // + // Activation manages instances of variables available in the cel expression's + // environment. + // + // The arena will be used to as necessary to allocate values and must outlive + // the returned value, as must this program. + // + // For consistency, users should use the same arena to create values + // in the activation and for Program evaluation. + absl::StatusOr Evaluate( + google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND, + const ActivationInterface& activation, + const EvaluateOptions& options = {}) const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return EvaluateImpl(activation, arena, options); + } + + ABSL_DEPRECATED("Use the EvaluateOptions overload instead.") + absl::StatusOr Evaluate( + google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::MessageFactory* absl_nullable message_factory + ABSL_ATTRIBUTE_LIFETIME_BOUND, + const ActivationInterface& activation) const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return EvaluateImpl(activation, arena, {message_factory}); + } + + virtual const TypeProvider& GetTypeProvider() const = 0; + + protected: + virtual absl::StatusOr EvaluateImpl( + const ActivationInterface& activation, + google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND, + const EvaluateOptions& options) const ABSL_ATTRIBUTE_LIFETIME_BOUND = 0; +}; + +// Representation for a traceable CEL expression. +// +// Implementations provide an additional Trace method that evaluates the +// expression and invokes a callback allowing callers to inspect intermediate +// state during evaluation. +class TraceableProgram : public Program { + public: + // EvaluationListener may be provided to an EvaluateWithCallback call to + // inspect intermediate values during evaluation. + // + // The callback is called on after every program step that corresponds + // to an AST expression node. The value provided is the top of the value + // stack, corresponding to the result of evaluating the given sub expression. + // + // A returning a non-ok status stops evaluation and forwards the error. + using EvaluationListener = absl::AnyInvocable; + + using Program::Evaluate; + + // Evaluate the Program plan with a Listener. + // + // The given callback will be invoked after evaluating any program step + // that corresponds to an AST node in the planned CEL expression. + // + // If the callback returns a non-ok status, evaluation stops and the Status + // is forwarded as the result of the EvaluateWithCallback call. + absl::StatusOr Trace( + google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND, + const ActivationInterface& activation, + EvaluationListener evaluation_listener, + const EvaluateOptions& options = {}) const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return TraceImpl(activation, std::move(evaluation_listener), arena, + options); + } + + ABSL_DEPRECATED("Use the EvaluateOptions overload instead.") + absl::StatusOr Trace( + google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND, + google::protobuf::MessageFactory* absl_nullable message_factory + ABSL_ATTRIBUTE_LIFETIME_BOUND, + const ActivationInterface& activation, + EvaluationListener evaluation_listener) const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return TraceImpl(activation, std::move(evaluation_listener), arena, + {message_factory}); + } + + protected: + absl::StatusOr EvaluateImpl(const ActivationInterface& activation, + google::protobuf::Arena* absl_nonnull arena + ABSL_ATTRIBUTE_LIFETIME_BOUND, + const EvaluateOptions& options) const + ABSL_ATTRIBUTE_LIFETIME_BOUND override { + return TraceImpl(activation, nullptr, arena, options); + } + + virtual absl::StatusOr TraceImpl( + const ActivationInterface& activation, + EvaluationListener evaluation_listener, + google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND, + const EvaluateOptions& options) const ABSL_ATTRIBUTE_LIFETIME_BOUND = 0; +}; + +// Interface for a CEL runtime. +// +// Manages the state necessary to generate Programs. +// +// Runtime instances should be created from a RuntimeBuilder rather than +// instantiated directly. +// +// Implementations provided by CEL will be thread-compatible, but write +// operations on the underlying environment (TypeRegistry, FunctionRegistry) or +// on the implementation via down casting must be synchronized by the caller and +// may invalidate any Programs created from the Runtime. +class Runtime { + public: + struct CreateProgramOptions { + // Optional output for collecting issues encountered while planning. + // If non-null, vector is cleared and encountered issues are added. + std::vector* issues = nullptr; + }; + + virtual ~Runtime() = default; + + absl::StatusOr> CreateProgram( + std::unique_ptr ast) const { + return CreateProgram(std::move(ast), CreateProgramOptions{}); + } + + virtual absl::StatusOr> CreateProgram( + std::unique_ptr ast, + const CreateProgramOptions& options) const = 0; + + absl::StatusOr> CreateTraceableProgram( + std::unique_ptr ast) const { + return CreateTraceableProgram(std::move(ast), CreateProgramOptions{}); + } + + virtual absl::StatusOr> + CreateTraceableProgram(std::unique_ptr ast, + const CreateProgramOptions& options) const = 0; + + virtual const TypeProvider& GetTypeProvider() const = 0; + + virtual const google::protobuf::DescriptorPool* absl_nonnull GetDescriptorPool() + const = 0; + + virtual google::protobuf::MessageFactory* absl_nonnull GetMessageFactory() const = 0; + + private: + friend class runtime_internal::RuntimeFriendAccess; + + virtual NativeTypeId GetNativeTypeId() const = 0; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_RUNTIME_RUNTIME_H_ diff --git a/runtime/runtime_builder.h b/runtime/runtime_builder.h new file mode 100644 index 000000000..ff1db7b82 --- /dev/null +++ b/runtime/runtime_builder.h @@ -0,0 +1,101 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_RUNTIME_BUILDER_H_ +#define THIRD_PARTY_CEL_CPP_RUNTIME_RUNTIME_BUILDER_H_ + +#include +#include + +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/status/statusor.h" +#include "runtime/function_registry.h" +#include "runtime/runtime.h" +#include "runtime/runtime_options.h" +#include "runtime/type_registry.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +// Forward declare for friend access to avoid requiring a link dependency on the +// standard implementation and some extensions. +namespace runtime_internal { +class RuntimeFriendAccess; +} // namespace runtime_internal + +class RuntimeBuilder; +absl::StatusOr CreateRuntimeBuilder( + absl_nonnull std::shared_ptr, + const RuntimeOptions&); + +// RuntimeBuilder provides mutable accessors to configure a new runtime. +// +// Instances of this class are consumed when built. +class RuntimeBuilder { + public: + // Move-only + RuntimeBuilder(const RuntimeBuilder&) = delete; + RuntimeBuilder& operator=(const RuntimeBuilder&) = delete; + RuntimeBuilder(RuntimeBuilder&&) = default; + RuntimeBuilder& operator=(RuntimeBuilder&&) = default; + + TypeRegistry& type_registry() { + ABSL_DCHECK(runtime_ != nullptr); + return *type_registry_; + } + + FunctionRegistry& function_registry() { + ABSL_DCHECK(runtime_ != nullptr); + return *function_registry_; + } + + // Return the built runtime. + // + // The builder is left in an undefined state after this call and cannot be + // reused. + absl::StatusOr> Build() && { + return std::move(runtime_); + } + + private: + friend class runtime_internal::RuntimeFriendAccess; + friend absl::StatusOr CreateRuntimeBuilder( + absl_nonnull std::shared_ptr, + const RuntimeOptions&); + + // Constructor for a new runtime builder. + // + // It's assumed that the type registry and function registry are managed by + // the runtime. + // + // CEL users should use one of the factory functions for a new builder. + // See standard_runtime_builder_factory.h and runtime_builder_factory.h + RuntimeBuilder(TypeRegistry& type_registry, + FunctionRegistry& function_registry, + std::unique_ptr runtime) + : type_registry_(&type_registry), + function_registry_(&function_registry), + runtime_(std::move(runtime)) {} + + Runtime& runtime() { return *runtime_; } + + TypeRegistry* type_registry_; + FunctionRegistry* function_registry_; + std::unique_ptr runtime_; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_RUNTIME_RUNTIME_BUILDER_H_ diff --git a/runtime/runtime_builder_factory.cc b/runtime/runtime_builder_factory.cc new file mode 100644 index 000000000..f5e760c0b --- /dev/null +++ b/runtime/runtime_builder_factory.cc @@ -0,0 +1,68 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "runtime/runtime_builder_factory.h" + +#include +#include + +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/status/statusor.h" +#include "internal/noop_delete.h" +#include "internal/status_macros.h" +#include "runtime/internal/runtime_env.h" +#include "runtime/internal/runtime_impl.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +using ::cel::runtime_internal::RuntimeEnv; +using ::cel::runtime_internal::RuntimeImpl; + +absl::StatusOr CreateRuntimeBuilder( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + const RuntimeOptions& options) { + ABSL_DCHECK(descriptor_pool != nullptr); + return CreateRuntimeBuilder( + std::shared_ptr( + descriptor_pool, + internal::NoopDeleteFor()), + options); +} + +absl::StatusOr CreateRuntimeBuilder( + absl_nonnull std::shared_ptr descriptor_pool, + const RuntimeOptions& options) { + // TODO(uncreated-issue/57): and internal API for adding extensions that need to + // downcast to the runtime impl. + // TODO(uncreated-issue/56): add API for attaching an issue listener (replacing the + // vector overloads). + ABSL_DCHECK(descriptor_pool != nullptr); + auto environment = std::make_shared(std::move(descriptor_pool)); + CEL_RETURN_IF_ERROR(environment->Initialize()); + auto runtime_impl = + std::make_unique(std::move(environment), options); + runtime_impl->expr_builder().set_container(options.container); + + auto& type_registry = runtime_impl->type_registry(); + auto& function_registry = runtime_impl->function_registry(); + + return RuntimeBuilder(type_registry, function_registry, + std::move(runtime_impl)); +} + +} // namespace cel diff --git a/runtime/runtime_builder_factory.h b/runtime/runtime_builder_factory.h new file mode 100644 index 000000000..0cb35d62a --- /dev/null +++ b/runtime/runtime_builder_factory.h @@ -0,0 +1,65 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_RUNTIME_BUILDER_FACTORY_H_ +#define THIRD_PARTY_CEL_CPP_RUNTIME_RUNTIME_BUILDER_FACTORY_H_ + +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/status/statusor.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +// Create an unconfigured builder using the default Runtime implementation. +// +// The provided descriptor pool is used when dealing with `google.protobuf.Any` +// messages, as well as for implementing struct creation syntax +// `foo.Bar{my_field: 1}`. The descriptor pool must outlive the resulting +// RuntimeBuilder, the `Runtime` it creates, and any `Program` that the +// `Runtime` creates. The descriptor pool must include the minimally necessary +// descriptors required by CEL. Those are the following: +// - google.protobuf.NullValue +// - google.protobuf.BoolValue +// - google.protobuf.Int32Value +// - google.protobuf.Int64Value +// - google.protobuf.UInt32Value +// - google.protobuf.UInt64Value +// - google.protobuf.FloatValue +// - google.protobuf.DoubleValue +// - google.protobuf.BytesValue +// - google.protobuf.StringValue +// - google.protobuf.Any +// - google.protobuf.Duration +// - google.protobuf.Timestamp +// +// This is provided for environments that only use a subset of the CEL standard +// builtins. Most users should prefer CreateStandardRuntimeBuilder. +// +// Callers must register appropriate builtins. +absl::StatusOr CreateRuntimeBuilder( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool + ABSL_ATTRIBUTE_LIFETIME_BOUND, + const RuntimeOptions& options); +absl::StatusOr CreateRuntimeBuilder( + absl_nonnull std::shared_ptr descriptor_pool, + const RuntimeOptions& options); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_RUNTIME_RUNTIME_BUILDER_FACTORY_H_ diff --git a/runtime/runtime_issue.h b/runtime/runtime_issue.h new file mode 100644 index 000000000..d18931756 --- /dev/null +++ b/runtime/runtime_issue.h @@ -0,0 +1,88 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_RUNTIME_ISSUE_H_ +#define THIRD_PARTY_CEL_CPP_RUNTIME_RUNTIME_ISSUE_H_ + +#include + +#include "absl/status/status.h" + +namespace cel { + +// Represents an issue with a given CEL expression. +// +// The error details are represented as an absl::Status for compatibility +// reasons, but users should not depend on this. +class RuntimeIssue { + public: + // Severity of the RuntimeIssue. + // + // Can be used to determine whether to continue program planning or return + // early. + enum class Severity { + // The issue may lead to runtime errors in evaluation. + kWarning = 0, + // The expression is invalid or unsupported. + kError = 1, + // Arbitrary max value above Error. + kNotForUseWithExhaustiveSwitchStatements = 15 + }; + + // Code for well-known runtime error kinds. + enum class ErrorCode { + // Overload not provided for given function call signature. + kNoMatchingOverload, + // Field access refers to unknown field for given type. + kNoSuchField, + // Other error outside the canonical set. + kOther, + }; + + static RuntimeIssue CreateError(absl::Status status, + ErrorCode error_code = ErrorCode::kOther) { + return RuntimeIssue(std::move(status), Severity::kError, error_code); + } + + static RuntimeIssue CreateWarning(absl::Status status, + ErrorCode error_code = ErrorCode::kOther) { + return RuntimeIssue(std::move(status), Severity::kWarning, error_code); + } + + RuntimeIssue(const RuntimeIssue& other) = default; + RuntimeIssue& operator=(const RuntimeIssue& other) = default; + RuntimeIssue(RuntimeIssue&& other) = default; + RuntimeIssue& operator=(RuntimeIssue&& other) = default; + + Severity severity() const { return severity_; } + + ErrorCode error_code() const { return error_code_; } + + const absl::Status& ToStatus() const& { return status_; } + absl::Status ToStatus() && { return std::move(status_); } + + private: + RuntimeIssue(absl::Status status, Severity severity, ErrorCode error_code) + : status_(std::move(status)), + error_code_(error_code), + severity_(severity) {} + + absl::Status status_; + ErrorCode error_code_; + Severity severity_; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_RUNTIME_RUNTIME_ISSUE_H_ diff --git a/runtime/runtime_options.h b/runtime/runtime_options.h new file mode 100644 index 000000000..7a61208a0 --- /dev/null +++ b/runtime/runtime_options.h @@ -0,0 +1,196 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_RUNTIME_OPTIONS_H_ +#define THIRD_PARTY_CEL_CPP_RUNTIME_RUNTIME_OPTIONS_H_ + +#include + +#include "absl/base/attributes.h" + +namespace cel { + +// Options for unknown processing. +enum class UnknownProcessingOptions { + // No unknown processing. + kDisabled, + // Only attributes supported. + kAttributeOnly, + // Attributes and functions supported. Function results are dependent on the + // logic for handling unknown_attributes, so clients must opt in to both. + kAttributeAndFunction +}; + +// Options for handling unset wrapper types on field access. +enum class ProtoWrapperTypeOptions { + // Default: legacy behavior following proto semantics (unset behaves as though + // it is set to default value). + kUnsetProtoDefault, + // CEL spec behavior, unset wrapper is treated as a null value when accessed. + kUnsetNull, +}; + +// LINT.IfChange +// Interpreter options for controlling evaluation and builtin functions. +// +// Members should provide simple parameters for configuring core features and +// built-ins. +// +// Optimizations or features that have a heavy footprint should be added via an +// extension API. +struct RuntimeOptions { + // Default container for resolving variables, types, and functions. + // Follows protobuf namespace rules. + std::string container = ""; + + // Level of unknown support enabled. + UnknownProcessingOptions unknown_processing = + UnknownProcessingOptions::kDisabled; + + bool enable_missing_attribute_errors = false; + + // Enable timestamp duration overflow checks. + // + // The CEL-Spec indicates that overflow should occur outside the range of + // string-representable timestamps, and at the limit of durations which can be + // expressed with a single int64 value. + bool enable_timestamp_duration_overflow_errors = false; + + // Enable short-circuiting of the logical operator evaluation. If enabled, + // AND, OR, and TERNARY do not evaluate the entire expression once the the + // resulting value is known from the left-hand side. + bool short_circuiting = true; + + // Enable comprehension expressions (e.g. exists, all) + bool enable_comprehension = true; + + // Set maximum number of iterations in the comprehension expressions if + // comprehensions are enabled. The limit applies globally per an evaluation, + // including the nested loops as well. Use value 0 to disable the upper bound. + int comprehension_max_iterations = 10000; + + // Enable list append within comprehensions. Note, this option is not safe + // with hand-rolled ASTs. + bool enable_comprehension_list_append = false; + + // Enable mutable map construction within comprehensions. Note, this option is + // not safe with hand-rolled ASTs. + bool enable_comprehension_mutable_map = false; + + // Enable RE2 match() overload. + bool enable_regex = true; + + // Set maximum program size for RE2 regex if regex overload is enabled. + // Evaluates to an error if a regex exceeds it. Use value 0 to disable the + // upper bound. + int regex_max_program_size = 0; + + // Enable string() overloads. + bool enable_string_conversion = true; + + // Enable string concatenation overload. + bool enable_string_concat = true; + + // Enable list concatenation overload. + bool enable_list_concat = true; + + // Enable list membership overload. + bool enable_list_contains = true; + + // Treat builder warnings as fatal errors. + bool fail_on_warnings = true; + + // Enable the resolution of qualified type identifiers as type values instead + // of field selections. + // + // This toggle may cause certain identifiers which overlap with CEL built-in + // type or with protobuf message types linked into the binary to be resolved + // as static type values rather than as per-eval variables. + bool enable_qualified_type_identifiers = false; + + // Enable heterogeneous comparisons (e.g. support for cross-type comparisons). + ABSL_DEPRECATED( + "The ability to disable heterogeneous equality is being removed in the " + "near future") + bool enable_heterogeneous_equality = true; + + // Enables unwrapping proto wrapper types to null if unset. e.g. if an + // expression access a field of type google.protobuf.Int64Value that is unset, + // that will result in a Null cel value, as opposed to returning the + // cel representation of the proto defined default int64: 0. + bool enable_empty_wrapper_null_unboxing = false; + + // Enable lazy cel.bind alias initialization. + // + // This is now always enabled. Setting this option has no effect. It will be + // removed in a later update. + bool enable_lazy_bind_initialization = true; + + // Enable recursive planning with a maximum recursion depth for evaluable + // programs. + // + // This limit is proportional to the maximum number of recursive Evaluate + // calls that a single expression program might require while evaluating. This + // is coarse -- the actual C++ stack requirements will vary depending on the + // expression. + // + // This does not account for re-entrant evaluation in a client's extension + // function (i.e. a CEL function that calls Evaluate on another CEL program) + // + // If the limit is exceeded, the planner will return an error instead of + // planning the program. + // + // -1 means unbounded. + // 0 means disabled (using a heap-based stack machine instead), which is the + // default. + int max_recursion_depth = 0; + + // Enable tracing support for recursively planned programs. + // + // Unlike the stack machine implementation, supporting tracing can affect + // performance whether or not tracing is requested for a given evaluation. + bool enable_recursive_tracing = false; + + // Enable fast implementations for some CEL standard functions. + // + // Uses a custom implementation for some functions in the CEL standard, + // bypassing normal dispatching logic and safety checks for functions. + // + // This prevents extending or disabling these functions in most cases. The + // expression planner will make a best effort attempt to check if custom + // overloads have been added for these functions, and will attempt to use them + // if they exist. + // + // Currently applies to !_, @not_strictly_false, _==_, _!=_, @in + bool enable_fast_builtins = true; + + // When enabled, string(double) will format the double with enough precision + // to ensure that the original double value can be recovered exactly. + // + // If available, will use the `std::to_chars` standard library function to + // perform the conversion to generate the shortest representation. + // + // Otherwise, will fall back to formatting with the worst-case required + // precision. + // + // If disabled, will use the legacy behavior of rounding to 6 decimal places. + bool enable_precision_preserving_double_format = true; +}; +// LINT.ThenChange(//depot/google3/eval/public/cel_options.h) + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_RUNTIME_RUNTIME_OPTIONS_H_ diff --git a/runtime/standard/BUILD b/runtime/standard/BUILD new file mode 100644 index 000000000..02a23ef1b --- /dev/null +++ b/runtime/standard/BUILD @@ -0,0 +1,393 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +# Provides registrars for CEL standard definitions. +# TODO(uncreated-issue/41): CEL users shouldn't need to use these directly, instead they should prefer to +# use RegisterBuiltins when available. +package( + # Under active development, not yet being released. + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "comparison_functions", + srcs = [ + "comparison_functions.cc", + ], + hdrs = [ + "comparison_functions.h", + ], + deps = [ + "//base:builtins", + "//base:function_adapter", + "//common:value", + "//internal:number", + "//internal:status_macros", + "//runtime:function_registry", + "//runtime:runtime_options", + "@com_google_absl//absl/status", + "@com_google_absl//absl/time", + ], +) + +cc_test( + name = "comparison_functions_test", + size = "small", + srcs = [ + "comparison_functions_test.cc", + ], + deps = [ + ":comparison_functions", + "//base:builtins", + "//common:kind", + "//internal:testing", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "container_membership_functions", + srcs = [ + "container_membership_functions.cc", + ], + hdrs = [ + "container_membership_functions.h", + ], + deps = [ + "//base:builtins", + "//base:function_adapter", + "//common:value", + "//internal:number", + "//internal:status_macros", + "//runtime:function_registry", + "//runtime:register_function_helper", + "//runtime:runtime_options", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "container_membership_functions_test", + size = "small", + srcs = [ + "container_membership_functions_test.cc", + ], + deps = [ + ":container_membership_functions", + "//base:builtins", + "//common:function_descriptor", + "//common:kind", + "//internal:testing", + "//runtime:function_registry", + "//runtime:runtime_options", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "equality_functions", + srcs = ["equality_functions.cc"], + hdrs = ["equality_functions.h"], + deps = [ + "//base:builtins", + "//base:function_adapter", + "//common:value", + "//common:value_kind", + "//internal:number", + "//internal:status_macros", + "//runtime:function_registry", + "//runtime:register_function_helper", + "//runtime:runtime_options", + "//runtime/internal:errors", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:optional", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "equality_functions_test", + size = "small", + srcs = [ + "equality_functions_test.cc", + ], + deps = [ + ":equality_functions", + "//base:builtins", + "//common:function_descriptor", + "//common:kind", + "//internal:testing", + "//runtime:function_registry", + "//runtime:runtime_options", + "@com_google_absl//absl/status:status_matchers", + ], +) + +cc_library( + name = "logical_functions", + srcs = [ + "logical_functions.cc", + ], + hdrs = [ + "logical_functions.h", + ], + deps = [ + "//base:builtins", + "//base:function_adapter", + "//common:value", + "//internal:status_macros", + "//runtime:function_registry", + "//runtime:register_function_helper", + "//runtime:runtime_options", + "//runtime/internal:errors", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "logical_functions_test", + size = "small", + srcs = [ + "logical_functions_test.cc", + ], + deps = [ + ":logical_functions", + "//base:builtins", + "//common:function_descriptor", + "//common:kind", + "//common:value", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//internal:testing_message_factory", + "//runtime:function", + "//runtime:function_overload_reference", + "//runtime:function_registry", + "//runtime:runtime_options", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "container_functions", + srcs = ["container_functions.cc"], + hdrs = ["container_functions.h"], + deps = [ + "//base:builtins", + "//base:function_adapter", + "//common:value", + "//internal:status_macros", + "//runtime:function_registry", + "//runtime:runtime_options", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "container_functions_test", + size = "small", + srcs = [ + "container_functions_test.cc", + ], + deps = [ + ":container_functions", + "//base:builtins", + "//common:function_descriptor", + "//internal:testing", + ], +) + +cc_library( + name = "type_conversion_functions", + srcs = ["type_conversion_functions.cc"], + hdrs = ["type_conversion_functions.h"], + deps = [ + "//base:builtins", + "//base:function_adapter", + "//common:value", + "//internal:overflow", + "//internal:status_macros", + "//internal:time", + "//internal:utf8", + "//runtime:function", + "//runtime:function_registry", + "//runtime:runtime_options", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/time", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "type_conversion_functions_test", + size = "small", + srcs = [ + "type_conversion_functions_test.cc", + ], + deps = [ + ":type_conversion_functions", + "//base:builtins", + "//common:function_descriptor", + "//internal:testing", + ], +) + +cc_library( + name = "arithmetic_functions", + srcs = ["arithmetic_functions.cc"], + hdrs = ["arithmetic_functions.h"], + deps = [ + "//base:builtins", + "//base:function_adapter", + "//common:value", + "//internal:overflow", + "//internal:status_macros", + "//runtime:function_registry", + "//runtime:runtime_options", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "arithmetic_functions_test", + size = "small", + srcs = [ + "arithmetic_functions_test.cc", + ], + deps = [ + ":arithmetic_functions", + "//base:builtins", + "//common:function_descriptor", + "//internal:testing", + ], +) + +cc_library( + name = "time_functions", + srcs = ["time_functions.cc"], + hdrs = ["time_functions.h"], + deps = [ + "//base:builtins", + "//base:function_adapter", + "//common:value", + "//internal:overflow", + "//internal:status_macros", + "//runtime:function_registry", + "//runtime:runtime_options", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + ], +) + +cc_test( + name = "time_functions_test", + size = "small", + srcs = [ + "time_functions_test.cc", + ], + deps = [ + ":time_functions", + "//base:builtins", + "//common:function_descriptor", + "//internal:testing", + ], +) + +cc_library( + name = "string_functions", + srcs = ["string_functions.cc"], + hdrs = ["string_functions.h"], + deps = [ + "//base:builtins", + "//base:function_adapter", + "//common:value", + "//internal:status_macros", + "//runtime:function_registry", + "//runtime:runtime_options", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "string_functions_test", + size = "small", + srcs = [ + "string_functions_test.cc", + ], + deps = [ + ":string_functions", + "//base:builtins", + "//common:function_descriptor", + "//internal:testing", + ], +) + +cc_library( + name = "regex_functions", + srcs = ["regex_functions.cc"], + hdrs = ["regex_functions.h"], + deps = [ + "//base:builtins", + "//base:function_adapter", + "//common:value", + "//internal:re2_options", + "//internal:status_macros", + "//runtime:function_registry", + "//runtime:runtime_options", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_googlesource_code_re2//:re2", + ], +) + +cc_test( + name = "regex_functions_test", + srcs = ["regex_functions_test.cc"], + deps = [ + ":regex_functions", + "//base:builtins", + "//common:function_descriptor", + "//internal:testing", + ], +) diff --git a/runtime/standard/arithmetic_functions.cc b/runtime/standard/arithmetic_functions.cc new file mode 100644 index 000000000..a851ceb39 --- /dev/null +++ b/runtime/standard/arithmetic_functions.cc @@ -0,0 +1,233 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "runtime/standard/arithmetic_functions.h" + +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "base/builtins.h" +#include "base/function_adapter.h" +#include "common/value.h" +#include "internal/overflow.h" +#include "internal/status_macros.h" +#include "runtime/function_registry.h" +#include "runtime/runtime_options.h" + +namespace cel { +namespace { + +// Template functions providing arithmetic operations +template +Value Add(Type v0, Type v1); + +template <> +Value Add(int64_t v0, int64_t v1) { + auto sum = cel::internal::CheckedAdd(v0, v1); + if (!sum.ok()) { + return ErrorValue(sum.status()); + } + return IntValue(*sum); +} + +template <> +Value Add(uint64_t v0, uint64_t v1) { + auto sum = cel::internal::CheckedAdd(v0, v1); + if (!sum.ok()) { + return ErrorValue(sum.status()); + } + return UintValue(*sum); +} + +template <> +Value Add(double v0, double v1) { + return DoubleValue(v0 + v1); +} + +template +Value Sub(Type v0, Type v1); + +template <> +Value Sub(int64_t v0, int64_t v1) { + auto diff = cel::internal::CheckedSub(v0, v1); + if (!diff.ok()) { + return ErrorValue(diff.status()); + } + return IntValue(*diff); +} + +template <> +Value Sub(uint64_t v0, uint64_t v1) { + auto diff = cel::internal::CheckedSub(v0, v1); + if (!diff.ok()) { + return ErrorValue(diff.status()); + } + return UintValue(*diff); +} + +template <> +Value Sub(double v0, double v1) { + return DoubleValue(v0 - v1); +} + +template +Value Mul(Type v0, Type v1); + +template <> +Value Mul(int64_t v0, int64_t v1) { + auto prod = cel::internal::CheckedMul(v0, v1); + if (!prod.ok()) { + return ErrorValue(prod.status()); + } + return IntValue(*prod); +} + +template <> +Value Mul(uint64_t v0, uint64_t v1) { + auto prod = cel::internal::CheckedMul(v0, v1); + if (!prod.ok()) { + return ErrorValue(prod.status()); + } + return UintValue(*prod); +} + +template <> +Value Mul(double v0, double v1) { + return DoubleValue(v0 * v1); +} + +template +Value Div(Type v0, Type v1); + +// Division operations for integer types should check for +// division by 0 +template <> +Value Div(int64_t v0, int64_t v1) { + auto quot = cel::internal::CheckedDiv(v0, v1); + if (!quot.ok()) { + return ErrorValue(quot.status()); + } + return IntValue(*quot); +} + +// Division operations for integer types should check for +// division by 0 +template <> +Value Div(uint64_t v0, uint64_t v1) { + auto quot = cel::internal::CheckedDiv(v0, v1); + if (!quot.ok()) { + return ErrorValue(quot.status()); + } + return UintValue(*quot); +} + +template <> +Value Div(double v0, double v1) { + static_assert(std::numeric_limits::is_iec559, + "Division by zero for doubles must be supported"); + + // For double, division will result in +/- inf + return DoubleValue(v0 / v1); +} + +// Modulo operation +template +Value Modulo(Type v0, Type v1); + +// Modulo operations for integer types should check for +// division by 0 +template <> +Value Modulo(int64_t v0, int64_t v1) { + auto mod = cel::internal::CheckedMod(v0, v1); + if (!mod.ok()) { + return ErrorValue(mod.status()); + } + return IntValue(*mod); +} + +template <> +Value Modulo(uint64_t v0, uint64_t v1) { + auto mod = cel::internal::CheckedMod(v0, v1); + if (!mod.ok()) { + return ErrorValue(mod.status()); + } + return UintValue(*mod); +} + +// Helper method +// Registers all arithmetic functions for template parameter type. +template +absl::Status RegisterArithmeticFunctionsForType(FunctionRegistry& registry) { + using FunctionAdapter = cel::BinaryFunctionAdapter; + CEL_RETURN_IF_ERROR(registry.Register( + FunctionAdapter::CreateDescriptor(cel::builtin::kAdd, false), + FunctionAdapter::WrapFunction(&Add))); + + CEL_RETURN_IF_ERROR(registry.Register( + FunctionAdapter::CreateDescriptor(cel::builtin::kSubtract, false), + FunctionAdapter::WrapFunction(&Sub))); + + CEL_RETURN_IF_ERROR(registry.Register( + FunctionAdapter::CreateDescriptor(cel::builtin::kMultiply, false), + FunctionAdapter::WrapFunction(&Mul))); + + return registry.Register( + FunctionAdapter::CreateDescriptor(cel::builtin::kDivide, false), + FunctionAdapter::WrapFunction(&Div)); +} + +} // namespace + +absl::Status RegisterArithmeticFunctions(FunctionRegistry& registry, + const RuntimeOptions& options) { + CEL_RETURN_IF_ERROR(RegisterArithmeticFunctionsForType(registry)); + CEL_RETURN_IF_ERROR(RegisterArithmeticFunctionsForType(registry)); + CEL_RETURN_IF_ERROR(RegisterArithmeticFunctionsForType(registry)); + + // Modulo + CEL_RETURN_IF_ERROR(registry.Register( + BinaryFunctionAdapter::CreateDescriptor( + cel::builtin::kModulo, false), + BinaryFunctionAdapter::WrapFunction( + &Modulo))); + + CEL_RETURN_IF_ERROR(registry.Register( + BinaryFunctionAdapter::CreateDescriptor( + cel::builtin::kModulo, false), + BinaryFunctionAdapter::WrapFunction( + &Modulo))); + + // Negation group + CEL_RETURN_IF_ERROR( + registry.Register(UnaryFunctionAdapter::CreateDescriptor( + cel::builtin::kNeg, false), + UnaryFunctionAdapter::WrapFunction( + [](int64_t value) -> Value { + auto inv = cel::internal::CheckedNegation(value); + if (!inv.ok()) { + return ErrorValue(inv.status()); + } + return IntValue(*inv); + }))); + + return registry.Register( + UnaryFunctionAdapter::CreateDescriptor(cel::builtin::kNeg, + false), + UnaryFunctionAdapter::WrapFunction( + [](double value) -> double { return -value; })); +} + +} // namespace cel diff --git a/runtime/standard/arithmetic_functions.h b/runtime/standard/arithmetic_functions.h new file mode 100644 index 000000000..c58619dc0 --- /dev/null +++ b/runtime/standard/arithmetic_functions.h @@ -0,0 +1,35 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_ARITHMETIC_FUNCTIONS_H_ +#define THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_ARITHMETIC_FUNCTIONS_H_ + +#include "absl/status/status.h" +#include "runtime/function_registry.h" +#include "runtime/runtime_options.h" + +namespace cel { + +// Register builtin arithmetic operators: +// _+_ (addition), _-_ (subtraction), -_ (negation), _/_ (division), +// _*_ (multiplication), _%_ (modulo) +// +// Most users should use RegisterBuiltinFunctions, which includes these +// definitions. +absl::Status RegisterArithmeticFunctions(FunctionRegistry& registry, + const RuntimeOptions& options); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_ARITHMETIC_FUNCTIONS_H_ diff --git a/runtime/standard/arithmetic_functions_test.cc b/runtime/standard/arithmetic_functions_test.cc new file mode 100644 index 000000000..f1da74fd2 --- /dev/null +++ b/runtime/standard/arithmetic_functions_test.cc @@ -0,0 +1,86 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "runtime/standard/arithmetic_functions.h" + +#include + +#include "base/builtins.h" +#include "common/function_descriptor.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +using ::testing::UnorderedElementsAre; + +MATCHER_P2(MatchesOperatorDescriptor, name, expected_kind, "") { + const FunctionDescriptor& descriptor = arg.descriptor; + std::vector types{expected_kind, expected_kind}; + return descriptor.name() == name && descriptor.receiver_style() == false && + descriptor.types() == types; +} + +MATCHER_P(MatchesNegationDescriptor, expected_kind, "") { + const FunctionDescriptor& descriptor = arg.descriptor; + std::vector types{expected_kind}; + return descriptor.name() == builtin::kNeg && + descriptor.receiver_style() == false && descriptor.types() == types; +} + +TEST(RegisterArithmeticFunctions, Registered) { + FunctionRegistry registry; + RuntimeOptions options; + + ASSERT_OK(RegisterArithmeticFunctions(registry, options)); + + EXPECT_THAT(registry.FindStaticOverloads(builtin::kAdd, false, + {Kind::kAny, Kind::kAny}), + UnorderedElementsAre( + MatchesOperatorDescriptor(builtin::kAdd, Kind::kInt), + MatchesOperatorDescriptor(builtin::kAdd, Kind::kDouble), + MatchesOperatorDescriptor(builtin::kAdd, Kind::kUint))); + EXPECT_THAT(registry.FindStaticOverloads(builtin::kSubtract, false, + {Kind::kAny, Kind::kAny}), + UnorderedElementsAre( + MatchesOperatorDescriptor(builtin::kSubtract, Kind::kInt), + MatchesOperatorDescriptor(builtin::kSubtract, Kind::kDouble), + MatchesOperatorDescriptor(builtin::kSubtract, Kind::kUint))); + EXPECT_THAT(registry.FindStaticOverloads(builtin::kDivide, false, + {Kind::kAny, Kind::kAny}), + UnorderedElementsAre( + MatchesOperatorDescriptor(builtin::kDivide, Kind::kInt), + MatchesOperatorDescriptor(builtin::kDivide, Kind::kDouble), + MatchesOperatorDescriptor(builtin::kDivide, Kind::kUint))); + EXPECT_THAT(registry.FindStaticOverloads(builtin::kMultiply, false, + {Kind::kAny, Kind::kAny}), + UnorderedElementsAre( + MatchesOperatorDescriptor(builtin::kMultiply, Kind::kInt), + MatchesOperatorDescriptor(builtin::kMultiply, Kind::kDouble), + MatchesOperatorDescriptor(builtin::kMultiply, Kind::kUint))); + EXPECT_THAT(registry.FindStaticOverloads(builtin::kModulo, false, + {Kind::kAny, Kind::kAny}), + UnorderedElementsAre( + MatchesOperatorDescriptor(builtin::kModulo, Kind::kInt), + MatchesOperatorDescriptor(builtin::kModulo, Kind::kUint))); + EXPECT_THAT(registry.FindStaticOverloads(builtin::kNeg, false, {Kind::kAny}), + UnorderedElementsAre(MatchesNegationDescriptor(Kind::kInt), + MatchesNegationDescriptor(Kind::kDouble))); +} + +// TODO(uncreated-issue/41): move functional parsed expr tests when modern APIs for +// evaluator available. + +} // namespace +} // namespace cel diff --git a/runtime/standard/comparison_functions.cc b/runtime/standard/comparison_functions.cc new file mode 100644 index 000000000..bddd1efe9 --- /dev/null +++ b/runtime/standard/comparison_functions.cc @@ -0,0 +1,272 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "runtime/standard/comparison_functions.h" + +#include + +#include "absl/status/status.h" +#include "absl/time/time.h" +#include "base/builtins.h" +#include "base/function_adapter.h" +#include "common/value.h" +#include "internal/number.h" +#include "internal/status_macros.h" +#include "runtime/function_registry.h" +#include "runtime/runtime_options.h" + +namespace cel { + +namespace { + +using ::cel::internal::Number; + +// Comparison template functions +template +bool LessThan(Type t1, Type t2) { + return (t1 < t2); +} + +template +bool LessThanOrEqual(Type t1, Type t2) { + return (t1 <= t2); +} + +template +bool GreaterThan(Type t1, Type t2) { + return LessThan(t2, t1); +} + +template +bool GreaterThanOrEqual(Type t1, Type t2) { + return LessThanOrEqual(t2, t1); +} + +// String value comparions specializations +template <> +bool LessThan(const StringValue& t1, const StringValue& t2) { + return t1.Compare(t2) < 0; +} + +template <> +bool LessThanOrEqual(const StringValue& t1, const StringValue& t2) { + return t1.Compare(t2) <= 0; +} + +template <> +bool GreaterThan(const StringValue& t1, const StringValue& t2) { + return t1.Compare(t2) > 0; +} + +template <> +bool GreaterThanOrEqual(const StringValue& t1, const StringValue& t2) { + return t1.Compare(t2) >= 0; +} + +// bytes value comparions specializations +template <> +bool LessThan(const BytesValue& t1, const BytesValue& t2) { + return t1.Compare(t2) < 0; +} + +template <> +bool LessThanOrEqual(const BytesValue& t1, const BytesValue& t2) { + return t1.Compare(t2) <= 0; +} + +template <> +bool GreaterThan(const BytesValue& t1, const BytesValue& t2) { + return t1.Compare(t2) > 0; +} + +template <> +bool GreaterThanOrEqual(const BytesValue& t1, const BytesValue& t2) { + return t1.Compare(t2) >= 0; +} + +// Duration comparison specializations +template <> +bool LessThan(absl::Duration t1, absl::Duration t2) { + return absl::operator<(t1, t2); +} + +template <> +bool LessThanOrEqual(absl::Duration t1, absl::Duration t2) { + return absl::operator<=(t1, t2); +} + +template <> +bool GreaterThan(absl::Duration t1, absl::Duration t2) { + return absl::operator>(t1, t2); +} + +template <> +bool GreaterThanOrEqual(absl::Duration t1, absl::Duration t2) { + return absl::operator>=(t1, t2); +} + +// Timestamp comparison specializations +template <> +bool LessThan(absl::Time t1, absl::Time t2) { + return absl::operator<(t1, t2); +} + +template <> +bool LessThanOrEqual(absl::Time t1, absl::Time t2) { + return absl::operator<=(t1, t2); +} + +template <> +bool GreaterThan(absl::Time t1, absl::Time t2) { + return absl::operator>(t1, t2); +} + +template <> +bool GreaterThanOrEqual(absl::Time t1, absl::Time t2) { + return absl::operator>=(t1, t2); +} + +template +bool CrossNumericLessThan(T t, U u) { + return Number(t) < Number(u); +} + +template +bool CrossNumericGreaterThan(T t, U u) { + return Number(t) > Number(u); +} + +template +bool CrossNumericLessOrEqualTo(T t, U u) { + return Number(t) <= Number(u); +} + +template +bool CrossNumericGreaterOrEqualTo(T t, U u) { + return Number(t) >= Number(u); +} + +template +absl::Status RegisterComparisonFunctionsForType( + cel::FunctionRegistry& registry) { + using FunctionAdapter = BinaryFunctionAdapter; + CEL_RETURN_IF_ERROR(registry.Register( + FunctionAdapter::CreateDescriptor(cel::builtin::kLess, false), + FunctionAdapter::WrapFunction(LessThan))); + + CEL_RETURN_IF_ERROR(registry.Register( + FunctionAdapter::CreateDescriptor(cel::builtin::kLessOrEqual, false), + FunctionAdapter::WrapFunction(LessThanOrEqual))); + + CEL_RETURN_IF_ERROR(registry.Register( + FunctionAdapter::CreateDescriptor(cel::builtin::kGreater, false), + FunctionAdapter::WrapFunction(GreaterThan))); + + CEL_RETURN_IF_ERROR(registry.Register( + FunctionAdapter::CreateDescriptor(cel::builtin::kGreaterOrEqual, false), + FunctionAdapter::WrapFunction(GreaterThanOrEqual))); + + return absl::OkStatus(); +} + +absl::Status RegisterHomogenousComparisonFunctions( + cel::FunctionRegistry& registry) { + CEL_RETURN_IF_ERROR(RegisterComparisonFunctionsForType(registry)); + + CEL_RETURN_IF_ERROR(RegisterComparisonFunctionsForType(registry)); + + CEL_RETURN_IF_ERROR(RegisterComparisonFunctionsForType(registry)); + + CEL_RETURN_IF_ERROR(RegisterComparisonFunctionsForType(registry)); + + CEL_RETURN_IF_ERROR( + RegisterComparisonFunctionsForType(registry)); + + CEL_RETURN_IF_ERROR( + RegisterComparisonFunctionsForType(registry)); + + CEL_RETURN_IF_ERROR( + RegisterComparisonFunctionsForType(registry)); + + CEL_RETURN_IF_ERROR(RegisterComparisonFunctionsForType(registry)); + + return absl::OkStatus(); +} + +template +absl::Status RegisterCrossNumericComparisons(cel::FunctionRegistry& registry) { + using FunctionAdapter = BinaryFunctionAdapter; + CEL_RETURN_IF_ERROR(registry.Register( + FunctionAdapter::CreateDescriptor(cel::builtin::kLess, + /*receiver_style=*/false), + FunctionAdapter::WrapFunction(&CrossNumericLessThan))); + CEL_RETURN_IF_ERROR(registry.Register( + FunctionAdapter::CreateDescriptor(cel::builtin::kGreater, + /*receiver_style=*/false), + FunctionAdapter::WrapFunction(&CrossNumericGreaterThan))); + CEL_RETURN_IF_ERROR(registry.Register( + FunctionAdapter::CreateDescriptor(cel::builtin::kGreaterOrEqual, + /*receiver_style=*/false), + FunctionAdapter::WrapFunction(&CrossNumericGreaterOrEqualTo))); + CEL_RETURN_IF_ERROR(registry.Register( + FunctionAdapter::CreateDescriptor(cel::builtin::kLessOrEqual, + /*receiver_style=*/false), + FunctionAdapter::WrapFunction(&CrossNumericLessOrEqualTo))); + return absl::OkStatus(); +} + +absl::Status RegisterHeterogeneousComparisonFunctions( + cel::FunctionRegistry& registry) { + CEL_RETURN_IF_ERROR( + (RegisterCrossNumericComparisons(registry))); + CEL_RETURN_IF_ERROR( + (RegisterCrossNumericComparisons(registry))); + + CEL_RETURN_IF_ERROR( + (RegisterCrossNumericComparisons(registry))); + CEL_RETURN_IF_ERROR( + (RegisterCrossNumericComparisons(registry))); + + CEL_RETURN_IF_ERROR( + (RegisterCrossNumericComparisons(registry))); + CEL_RETURN_IF_ERROR( + (RegisterCrossNumericComparisons(registry))); + + CEL_RETURN_IF_ERROR(RegisterComparisonFunctionsForType(registry)); + CEL_RETURN_IF_ERROR(RegisterComparisonFunctionsForType(registry)); + CEL_RETURN_IF_ERROR(RegisterComparisonFunctionsForType(registry)); + CEL_RETURN_IF_ERROR(RegisterComparisonFunctionsForType(registry)); + CEL_RETURN_IF_ERROR( + RegisterComparisonFunctionsForType(registry)); + CEL_RETURN_IF_ERROR( + RegisterComparisonFunctionsForType(registry)); + CEL_RETURN_IF_ERROR( + RegisterComparisonFunctionsForType(registry)); + CEL_RETURN_IF_ERROR(RegisterComparisonFunctionsForType(registry)); + + return absl::OkStatus(); +} +} // namespace + +absl::Status RegisterComparisonFunctions(FunctionRegistry& registry, + const RuntimeOptions& options) { + if (options.enable_heterogeneous_equality) { + CEL_RETURN_IF_ERROR(RegisterHeterogeneousComparisonFunctions(registry)); + } else { + CEL_RETURN_IF_ERROR(RegisterHomogenousComparisonFunctions(registry)); + } + return absl::OkStatus(); +} + +} // namespace cel diff --git a/runtime/standard/comparison_functions.h b/runtime/standard/comparison_functions.h new file mode 100644 index 000000000..4b19f85ed --- /dev/null +++ b/runtime/standard/comparison_functions.h @@ -0,0 +1,36 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_COMPARISON_FUNCTIONS_H_ +#define THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_COMPARISON_FUNCTIONS_H_ + +#include "absl/status/status.h" +#include "runtime/function_registry.h" +#include "runtime/runtime_options.h" + +namespace cel { + +// Register built in comparison functions (<, <=, >, >=). +// +// Most users should prefer to use RegisterBuiltinFunctions. +// +// This is call is included in RegisterBuiltinFunctions -- calling both +// RegisterBuiltinFunctions and RegisterComparisonFunctions directly on the same +// registry will result in an error. +absl::Status RegisterComparisonFunctions(FunctionRegistry& registry, + const RuntimeOptions& options); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_COMPARISON_FUNCTIONS_H_ diff --git a/runtime/standard/comparison_functions_test.cc b/runtime/standard/comparison_functions_test.cc new file mode 100644 index 000000000..1963b6758 --- /dev/null +++ b/runtime/standard/comparison_functions_test.cc @@ -0,0 +1,82 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "runtime/standard/comparison_functions.h" + +#include + +#include "absl/strings/str_cat.h" +#include "base/builtins.h" +#include "common/kind.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +MATCHER_P2(DefinesHomogenousOverload, name, argument_kind, + absl::StrCat(name, " for ", KindToString(argument_kind))) { + const cel::FunctionRegistry& registry = arg; + return !registry + .FindStaticOverloads(name, /*receiver_style=*/false, + {argument_kind, argument_kind}) + .empty(); +} + +constexpr std::array kOrderableTypes = { + Kind::kBool, Kind::kInt64, Kind::kUint64, Kind::kString, + Kind::kDouble, Kind::kBytes, Kind::kDuration, Kind::kTimestamp}; + +TEST(RegisterComparisonFunctionsTest, LessThanDefined) { + RuntimeOptions default_options; + FunctionRegistry registry; + ASSERT_OK(RegisterComparisonFunctions(registry, default_options)); + for (Kind kind : kOrderableTypes) { + EXPECT_THAT(registry, DefinesHomogenousOverload(builtin::kLess, kind)); + } +} + +TEST(RegisterComparisonFunctionsTest, LessThanOrEqualDefined) { + RuntimeOptions default_options; + FunctionRegistry registry; + ASSERT_OK(RegisterComparisonFunctions(registry, default_options)); + for (Kind kind : kOrderableTypes) { + EXPECT_THAT(registry, + DefinesHomogenousOverload(builtin::kLessOrEqual, kind)); + } +} + +TEST(RegisterComparisonFunctionsTest, GreaterThanDefined) { + RuntimeOptions default_options; + FunctionRegistry registry; + ASSERT_OK(RegisterComparisonFunctions(registry, default_options)); + for (Kind kind : kOrderableTypes) { + EXPECT_THAT(registry, DefinesHomogenousOverload(builtin::kGreater, kind)); + } +} + +TEST(RegisterComparisonFunctionsTest, GreaterThanOrEqualDefined) { + RuntimeOptions default_options; + FunctionRegistry registry; + ASSERT_OK(RegisterComparisonFunctions(registry, default_options)); + for (Kind kind : kOrderableTypes) { + EXPECT_THAT(registry, + DefinesHomogenousOverload(builtin::kGreaterOrEqual, kind)); + } +} + +// TODO(uncreated-issue/41): move functional tests from wrapper library after top-level +// APIs are available for planning and running an expression. + +} // namespace +} // namespace cel diff --git a/runtime/standard/container_functions.cc b/runtime/standard/container_functions.cc new file mode 100644 index 000000000..c81dc7596 --- /dev/null +++ b/runtime/standard/container_functions.cc @@ -0,0 +1,136 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "runtime/standard/container_functions.h" + +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "base/builtins.h" +#include "base/function_adapter.h" +#include "common/value.h" +#include "common/values/list_value_builder.h" +#include "internal/status_macros.h" +#include "runtime/function_registry.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel { +namespace { + +absl::StatusOr MapSizeImpl(const MapValue& value) { + return value.Size(); +} + +absl::StatusOr ListSizeImpl(const ListValue& value) { + return value.Size(); +} + +// Concatenation for CelList type. +absl::StatusOr ConcatList( + const ListValue& value1, const ListValue& value2, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + CEL_ASSIGN_OR_RETURN(auto size1, value1.Size()); + if (size1 == 0) { + return value2; + } + CEL_ASSIGN_OR_RETURN(auto size2, value2.Size()); + if (size2 == 0) { + return value1; + } + + // TODO(uncreated-issue/50): add option for checking lists have homogenous element + // types and use a more specialized list type when possible. + auto list_builder = NewListValueBuilder(arena); + + list_builder->Reserve(size1 + size2); + + for (size_t i = 0; i < size1; i++) { + CEL_ASSIGN_OR_RETURN( + Value elem, value1.Get(i, descriptor_pool, message_factory, arena)); + CEL_RETURN_IF_ERROR(list_builder->Add(std::move(elem))); + } + for (size_t i = 0; i < size2; i++) { + CEL_ASSIGN_OR_RETURN( + Value elem, value2.Get(i, descriptor_pool, message_factory, arena)); + CEL_RETURN_IF_ERROR(list_builder->Add(std::move(elem))); + } + + return std::move(*list_builder).Build(); +} + +// AppendList will append the elements in value2 to value1. +// +// This call will only be invoked within comprehensions where `value1` is an +// intermediate result which cannot be directly assigned or co-mingled with a +// user-provided list. +absl::StatusOr AppendList(ListValue value1, const Value& value2) { + // The `value1` object cannot be directly addressed and is an intermediate + // variable. Once the comprehension completes this value will in effect be + // treated as immutable. + if (auto mutable_list_value = + cel::common_internal::AsMutableListValue(value1); + mutable_list_value) { + CEL_RETURN_IF_ERROR(mutable_list_value->Append(value2)); + return value1; + } + return absl::InvalidArgumentError("Unexpected call to runtime list append."); +} +} // namespace + +absl::Status RegisterContainerFunctions(FunctionRegistry& registry, + const RuntimeOptions& options) { + // receiver style = true/false + // Support both the global and receiver style size() for lists and maps. + for (bool receiver_style : {true, false}) { + CEL_RETURN_IF_ERROR(registry.Register( + cel::UnaryFunctionAdapter, const ListValue&>:: + CreateDescriptor(cel::builtin::kSize, receiver_style), + UnaryFunctionAdapter, + const ListValue&>::WrapFunction(ListSizeImpl))); + + CEL_RETURN_IF_ERROR(registry.Register( + UnaryFunctionAdapter, const MapValue&>:: + CreateDescriptor(cel::builtin::kSize, receiver_style), + UnaryFunctionAdapter, + const MapValue&>::WrapFunction(MapSizeImpl))); + } + + if (options.enable_list_concat) { + CEL_RETURN_IF_ERROR(registry.Register( + BinaryFunctionAdapter< + absl::StatusOr, const ListValue&, + const ListValue&>::CreateDescriptor(cel::builtin::kAdd, false), + BinaryFunctionAdapter, const ListValue&, + const ListValue&>::WrapFunction(ConcatList))); + } + + return registry.Register( + BinaryFunctionAdapter< + absl::StatusOr, ListValue, + const Value&>::CreateDescriptor(cel::builtin::kRuntimeListAppend, + false), + BinaryFunctionAdapter, ListValue, + const Value&>::WrapFunction(AppendList)); +} + +} // namespace cel diff --git a/runtime/standard/container_functions.h b/runtime/standard/container_functions.h new file mode 100644 index 000000000..7d44986f4 --- /dev/null +++ b/runtime/standard/container_functions.h @@ -0,0 +1,36 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_CONTAINER_FUNCTIONS_H_ +#define THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_CONTAINER_FUNCTIONS_H_ + +#include "absl/status/status.h" +#include "runtime/function_registry.h" +#include "runtime/runtime_options.h" + +namespace cel { + +// Register built in container functions. +// +// Most users should prefer to use RegisterBuiltinFunctions. +// +// This call is included in RegisterBuiltinFunctions -- calling both +// RegisterBuiltinFunctions and RegisterContainerFunctions directly on the same +// registry will result in an error. +absl::Status RegisterContainerFunctions(FunctionRegistry& registry, + const RuntimeOptions& options); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_CONTAINER_FUNCTIONS_H_ diff --git a/runtime/standard/container_functions_test.cc b/runtime/standard/container_functions_test.cc new file mode 100644 index 000000000..3e4838bc2 --- /dev/null +++ b/runtime/standard/container_functions_test.cc @@ -0,0 +1,99 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "runtime/standard/container_functions.h" + +#include + +#include "base/builtins.h" +#include "common/function_descriptor.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +using ::testing::IsEmpty; +using ::testing::UnorderedElementsAre; + +MATCHER_P3(MatchesDescriptor, name, receiver, expected_kinds, "") { + const FunctionDescriptor& descriptor = arg.descriptor; + const std::vector& types = expected_kinds; + return descriptor.name() == name && descriptor.receiver_style() == receiver && + descriptor.types() == types; +} + +TEST(RegisterContainerFunctions, RegistersSizeFunctions) { + FunctionRegistry registry; + RuntimeOptions options; + + ASSERT_OK(RegisterContainerFunctions(registry, options)); + + EXPECT_THAT( + registry.FindStaticOverloads(builtin::kSize, false, {Kind::kAny}), + UnorderedElementsAre(MatchesDescriptor(builtin::kSize, false, + std::vector{Kind::kList}), + MatchesDescriptor(builtin::kSize, false, + std::vector{Kind::kMap}))); + EXPECT_THAT( + registry.FindStaticOverloads(builtin::kSize, true, {Kind::kAny}), + UnorderedElementsAre(MatchesDescriptor(builtin::kSize, true, + std::vector{Kind::kList}), + MatchesDescriptor(builtin::kSize, true, + std::vector{Kind::kMap}))); +} + +TEST(RegisterContainerFunctions, RegisterListConcatEnabled) { + FunctionRegistry registry; + RuntimeOptions options; + options.enable_list_concat = true; + + ASSERT_OK(RegisterContainerFunctions(registry, options)); + + EXPECT_THAT( + registry.FindStaticOverloads(builtin::kAdd, false, + {Kind::kAny, Kind::kAny}), + UnorderedElementsAre(MatchesDescriptor( + builtin::kAdd, false, std::vector{Kind::kList, Kind::kList}))); +} + +TEST(RegisterContainerFunctions, RegisterListConcateDisabled) { + FunctionRegistry registry; + RuntimeOptions options; + options.enable_list_concat = false; + + ASSERT_OK(RegisterContainerFunctions(registry, options)); + + EXPECT_THAT(registry.FindStaticOverloads(builtin::kAdd, false, + {Kind::kAny, Kind::kAny}), + IsEmpty()); +} + +TEST(RegisterContainerFunctions, RegisterRuntimeListAppend) { + FunctionRegistry registry; + RuntimeOptions options; + + ASSERT_OK(RegisterContainerFunctions(registry, options)); + + EXPECT_THAT(registry.FindStaticOverloads(builtin::kRuntimeListAppend, false, + {Kind::kAny, Kind::kAny}), + UnorderedElementsAre(MatchesDescriptor( + builtin::kRuntimeListAppend, false, + std::vector{Kind::kList, Kind::kAny}))); +} + +// TODO(uncreated-issue/41): move functional parsed expr tests when modern APIs for +// evaluator available. + +} // namespace +} // namespace cel diff --git a/runtime/standard/container_membership_functions.cc b/runtime/standard/container_membership_functions.cc new file mode 100644 index 000000000..cc0638429 --- /dev/null +++ b/runtime/standard/container_membership_functions.cc @@ -0,0 +1,331 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "runtime/standard/container_membership_functions.h" + +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "base/builtins.h" +#include "base/function_adapter.h" +#include "common/value.h" +#include "internal/number.h" +#include "internal/status_macros.h" +#include "runtime/function_registry.h" +#include "runtime/register_function_helper.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel { +namespace { + +using ::cel::internal::Number; + +static constexpr std::array in_operators = { + cel::builtin::kIn, // @in for map and list types. + cel::builtin::kInFunction, // deprecated in() -- for backwards compat + cel::builtin::kInDeprecated, // deprecated _in_ -- for backwards compat +}; + +template +bool ValueEquals(const Value& value, T other); + +template <> +bool ValueEquals(const Value& value, bool other) { + if (auto bool_value = As(value); bool_value) { + return bool_value->NativeValue() == other; + } + return false; +} + +template <> +bool ValueEquals(const Value& value, int64_t other) { + if (auto int_value = As(value); int_value) { + return int_value->NativeValue() == other; + } + return false; +} + +template <> +bool ValueEquals(const Value& value, uint64_t other) { + if (auto uint_value = As(value); uint_value) { + return uint_value->NativeValue() == other; + } + return false; +} + +template <> +bool ValueEquals(const Value& value, double other) { + if (auto double_value = As(value); double_value) { + return double_value->NativeValue() == other; + } + return false; +} + +template <> +bool ValueEquals(const Value& value, const StringValue& other) { + if (auto string_value = As(value); string_value) { + return string_value->Equals(other); + } + return false; +} + +template <> +bool ValueEquals(const Value& value, const BytesValue& other) { + if (auto bytes_value = As(value); bytes_value) { + return bytes_value->Equals(other); + } + return false; +} + +// Template function implementing CEL in() function +template +absl::StatusOr In( + T value, const ListValue& list, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + CEL_ASSIGN_OR_RETURN(auto size, list.Size()); + Value element; + for (int i = 0; i < size; i++) { + CEL_RETURN_IF_ERROR( + list.Get(i, descriptor_pool, message_factory, arena, &element)); + if (ValueEquals(element, value)) { + return true; + } + } + + return false; +} + +// Implementation for @in operator using heterogeneous equality. +absl::StatusOr HeterogeneousEqualityIn( + const Value& value, const ListValue& list, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + return list.Contains(value, descriptor_pool, message_factory, arena); +} + +absl::Status RegisterListMembershipFunctions(FunctionRegistry& registry, + const RuntimeOptions& options) { + for (absl::string_view op : in_operators) { + if (options.enable_heterogeneous_equality) { + CEL_RETURN_IF_ERROR( + (RegisterHelper, const Value&, const ListValue&>>:: + RegisterGlobalOverload(op, &HeterogeneousEqualityIn, registry))); + } else { + CEL_RETURN_IF_ERROR( + (RegisterHelper, bool, + const ListValue&>>:: + RegisterGlobalOverload(op, In, registry))); + CEL_RETURN_IF_ERROR( + (RegisterHelper, int64_t, + const ListValue&>>:: + RegisterGlobalOverload(op, In, registry))); + CEL_RETURN_IF_ERROR( + (RegisterHelper, uint64_t, + const ListValue&>>:: + RegisterGlobalOverload(op, In, registry))); + CEL_RETURN_IF_ERROR( + (RegisterHelper, double, + const ListValue&>>:: + RegisterGlobalOverload(op, In, registry))); + CEL_RETURN_IF_ERROR( + (RegisterHelper, const StringValue&, const ListValue&>>:: + RegisterGlobalOverload(op, In, registry))); + CEL_RETURN_IF_ERROR( + (RegisterHelper, const BytesValue&, const ListValue&>>:: + RegisterGlobalOverload(op, In, registry))); + } + } + return absl::OkStatus(); +} + +absl::Status RegisterMapMembershipFunctions(FunctionRegistry& registry, + const RuntimeOptions& options) { + const bool enable_heterogeneous_equality = + options.enable_heterogeneous_equality; + + auto boolKeyInSet = + [enable_heterogeneous_equality]( + bool key, const MapValue& map_value, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) -> absl::StatusOr { + Value has; + CEL_RETURN_IF_ERROR(map_value.Has(BoolValue(key), descriptor_pool, + message_factory, arena, &has)); + if (has.IsTrue()) { + return has; + } + if (enable_heterogeneous_equality) { + return BoolValue(false); + } + return has; + }; + + auto intKeyInSet = + [enable_heterogeneous_equality]( + int64_t key, const MapValue& map_value, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) -> absl::StatusOr { + Value result; + CEL_RETURN_IF_ERROR(map_value.Has(IntValue(key), descriptor_pool, + message_factory, arena, &result)); + if (enable_heterogeneous_equality) { + if (result.IsTrue()) { + return result; + } + Number number = Number::FromInt64(key); + if (number.LosslessConvertibleToUint()) { + Value result_alt; + CEL_RETURN_IF_ERROR(map_value.Has(UintValue(number.AsUint()), + descriptor_pool, message_factory, + arena, &result_alt)); + if (result_alt.IsTrue()) { + return result_alt; + } + } + return BoolValue(false); + } + return result; + }; + + auto stringKeyInSet = + [enable_heterogeneous_equality]( + const StringValue& key, const MapValue& map_value, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) -> absl::StatusOr { + Value result; + CEL_RETURN_IF_ERROR( + map_value.Has(key, descriptor_pool, message_factory, arena, &result)); + if (result.IsBool()) { + return result; + } + if (enable_heterogeneous_equality) { + return BoolValue(false); + } + return result; + }; + + auto uintKeyInSet = + [enable_heterogeneous_equality]( + uint64_t key, const MapValue& map_value, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) -> absl::StatusOr { + Value has; + CEL_RETURN_IF_ERROR(map_value.Has(UintValue(key), descriptor_pool, + message_factory, arena, &has)); + if (enable_heterogeneous_equality) { + if (has.IsTrue()) { + return has; + } + Value has_alt; + Number number = Number::FromUint64(key); + if (number.LosslessConvertibleToInt()) { + CEL_RETURN_IF_ERROR(map_value.Has(IntValue(number.AsInt()), + descriptor_pool, message_factory, + arena, &has_alt)); + if (has.IsTrue()) { + return has; + } + } + return BoolValue(false); + } + return has; + }; + + auto doubleKeyInSet = + [](double key, const MapValue& map_value, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) -> absl::StatusOr { + Number number = Number::FromDouble(key); + if (number.LosslessConvertibleToInt()) { + Value has; + CEL_RETURN_IF_ERROR(map_value.Has(IntValue(number.AsInt()), + descriptor_pool, message_factory, arena, + &has)); + if (has.IsTrue()) { + return has; + } + } + if (number.LosslessConvertibleToUint()) { + Value has; + CEL_RETURN_IF_ERROR(map_value.Has(UintValue(number.AsUint()), + descriptor_pool, message_factory, arena, + &has)); + if (has.IsTrue()) { + return has; + } + } + return BoolValue(false); + }; + + for (auto op : in_operators) { + auto status = RegisterHelper, const StringValue&, + const MapValue&>>::RegisterGlobalOverload(op, stringKeyInSet, registry); + if (!status.ok()) return status; + + status = RegisterHelper< + BinaryFunctionAdapter, bool, const MapValue&>>:: + RegisterGlobalOverload(op, boolKeyInSet, registry); + if (!status.ok()) return status; + + status = RegisterHelper, + int64_t, const MapValue&>>:: + RegisterGlobalOverload(op, intKeyInSet, registry); + if (!status.ok()) return status; + + status = RegisterHelper, + uint64_t, const MapValue&>>:: + RegisterGlobalOverload(op, uintKeyInSet, registry); + if (!status.ok()) return status; + + if (enable_heterogeneous_equality) { + status = RegisterHelper, + double, const MapValue&>>:: + RegisterGlobalOverload(op, doubleKeyInSet, registry); + if (!status.ok()) return status; + } + } + return absl::OkStatus(); +} + +} // namespace + +absl::Status RegisterContainerMembershipFunctions( + FunctionRegistry& registry, const RuntimeOptions& options) { + if (options.enable_list_contains) { + CEL_RETURN_IF_ERROR(RegisterListMembershipFunctions(registry, options)); + } + return RegisterMapMembershipFunctions(registry, options); +} + +} // namespace cel diff --git a/runtime/standard/container_membership_functions.h b/runtime/standard/container_membership_functions.h new file mode 100644 index 000000000..fee62b6f4 --- /dev/null +++ b/runtime/standard/container_membership_functions.h @@ -0,0 +1,34 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_CONTAINER_MEMBERSHIP_FUNCTIONS_H_ +#define THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_CONTAINER_MEMBERSHIP_FUNCTIONS_H_ + +#include "absl/status/status.h" +#include "runtime/function_registry.h" +#include "runtime/runtime_options.h" + +namespace cel { + +// Register container membership functions +// in and in . +// +// The in operator follows the same behavior as equality, following the +// .enable_heterogeneous_equality option. +absl::Status RegisterContainerMembershipFunctions( + FunctionRegistry& registry, const RuntimeOptions& options); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_CONTAINER_MEMBERSHIP_FUNCTIONS_H_ diff --git a/runtime/standard/container_membership_functions_test.cc b/runtime/standard/container_membership_functions_test.cc new file mode 100644 index 000000000..9c90136d9 --- /dev/null +++ b/runtime/standard/container_membership_functions_test.cc @@ -0,0 +1,138 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "runtime/standard/container_membership_functions.h" + +#include +#include + +#include "absl/strings/string_view.h" +#include "base/builtins.h" +#include "common/function_descriptor.h" +#include "common/kind.h" +#include "internal/testing.h" +#include "runtime/function_registry.h" +#include "runtime/runtime_options.h" + +namespace cel { +namespace { + +using ::testing::UnorderedElementsAre; + +MATCHER_P3(MatchesDescriptor, name, receiver, expected_kinds, "") { + const FunctionDescriptor& descriptor = *arg; + const std::vector& types = expected_kinds; + return descriptor.name() == name && descriptor.receiver_style() == receiver && + descriptor.types() == types; +} + +static constexpr std::array kInOperators = { + builtin::kIn, builtin::kInDeprecated, builtin::kInFunction}; + +TEST(RegisterContainerMembershipFunctions, RegistersHomogeneousInOperator) { + FunctionRegistry registry; + RuntimeOptions options; + options.enable_heterogeneous_equality = false; + + ASSERT_OK(RegisterContainerMembershipFunctions(registry, options)); + + auto overloads = registry.ListFunctions(); + + for (absl::string_view operator_name : kInOperators) { + EXPECT_THAT( + overloads[operator_name], + UnorderedElementsAre( + MatchesDescriptor(operator_name, false, + std::vector{Kind::kInt, Kind::kList}), + MatchesDescriptor(operator_name, false, + std::vector{Kind::kUint, Kind::kList}), + MatchesDescriptor(operator_name, false, + std::vector{Kind::kDouble, Kind::kList}), + MatchesDescriptor(operator_name, false, + std::vector{Kind::kString, Kind::kList}), + MatchesDescriptor(operator_name, false, + std::vector{Kind::kBytes, Kind::kList}), + MatchesDescriptor(operator_name, false, + std::vector{Kind::kBool, Kind::kList}), + MatchesDescriptor(operator_name, false, + std::vector{Kind::kInt, Kind::kMap}), + MatchesDescriptor(operator_name, false, + std::vector{Kind::kUint, Kind::kMap}), + MatchesDescriptor(operator_name, false, + std::vector{Kind::kString, Kind::kMap}), + MatchesDescriptor(operator_name, false, + std::vector{Kind::kBool, Kind::kMap}))); + } +} + +TEST(RegisterContainerMembershipFunctions, RegistersHeterogeneousInOperation) { + FunctionRegistry registry; + RuntimeOptions options; + options.enable_heterogeneous_equality = true; + + ASSERT_OK(RegisterContainerMembershipFunctions(registry, options)); + + auto overloads = registry.ListFunctions(); + + for (absl::string_view operator_name : kInOperators) { + EXPECT_THAT( + overloads[operator_name], + UnorderedElementsAre( + MatchesDescriptor(operator_name, false, + std::vector{Kind::kAny, Kind::kList}), + MatchesDescriptor(operator_name, false, + std::vector{Kind::kInt, Kind::kMap}), + MatchesDescriptor(operator_name, false, + std::vector{Kind::kUint, Kind::kMap}), + MatchesDescriptor(operator_name, false, + std::vector{Kind::kDouble, Kind::kMap}), + MatchesDescriptor(operator_name, false, + std::vector{Kind::kString, Kind::kMap}), + MatchesDescriptor(operator_name, false, + std::vector{Kind::kBool, Kind::kMap}))); + } +} + +TEST(RegisterContainerMembershipFunctions, RegistersInOperatorListsDisabled) { + FunctionRegistry registry; + RuntimeOptions options; + options.enable_list_contains = false; + + ASSERT_OK(RegisterContainerMembershipFunctions(registry, options)); + + auto overloads = registry.ListFunctions(); + + for (absl::string_view operator_name : kInOperators) { + EXPECT_THAT( + overloads[operator_name], + UnorderedElementsAre( + + MatchesDescriptor(operator_name, false, + std::vector{Kind::kInt, Kind::kMap}), + MatchesDescriptor(operator_name, false, + std::vector{Kind::kUint, Kind::kMap}), + MatchesDescriptor(operator_name, false, + std::vector{Kind::kDouble, Kind::kMap}), + MatchesDescriptor(operator_name, false, + std::vector{Kind::kString, Kind::kMap}), + MatchesDescriptor(operator_name, false, + std::vector{Kind::kBool, Kind::kMap}))); + } +} + +// TODO(uncreated-issue/41): move functional parsed expr tests when modern APIs for +// evaluator available. + +} // namespace +} // namespace cel diff --git a/runtime/standard/equality_functions.cc b/runtime/standard/equality_functions.cc new file mode 100644 index 000000000..6546db16c --- /dev/null +++ b/runtime/standard/equality_functions.cc @@ -0,0 +1,612 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "runtime/standard/equality_functions.h" + +#include +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/functional/function_ref.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "absl/types/optional.h" +#include "base/builtins.h" +#include "base/function_adapter.h" +#include "common/value.h" +#include "common/value_kind.h" +#include "internal/number.h" +#include "internal/status_macros.h" +#include "runtime/function_registry.h" +#include "runtime/internal/errors.h" +#include "runtime/register_function_helper.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel { +namespace { + +using ::cel::builtin::kEqual; +using ::cel::builtin::kInequal; +using ::cel::internal::Number; + +// Declaration for the functors for generic equality operator. +// Equal only defined for same-typed values. +// Nullopt is returned if equality is not defined. +struct HomogenousEqualProvider { + static constexpr bool kIsHeterogeneous = false; + absl::StatusOr> operator()( + const Value& lhs, const Value& rhs, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const; +}; + +// Equal defined between compatible types. +// Nullopt is returned if equality is not defined. +struct HeterogeneousEqualProvider { + static constexpr bool kIsHeterogeneous = true; + + absl::StatusOr> operator()( + const Value& lhs, const Value& rhs, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const; +}; + +// Comparison template functions +template +absl::optional Inequal(Type lhs, Type rhs) { + return lhs != rhs; +} + +template <> +absl::optional Inequal(const StringValue& lhs, const StringValue& rhs) { + return !lhs.Equals(rhs); +} + +template <> +absl::optional Inequal(const BytesValue& lhs, const BytesValue& rhs) { + return !lhs.Equals(rhs); +} + +template <> +absl::optional Inequal(const NullValue&, const NullValue&) { + return false; +} + +template <> +absl::optional Inequal(const TypeValue& lhs, const TypeValue& rhs) { + return lhs.name() != rhs.name(); +} + +template +absl::optional Equal(Type lhs, Type rhs) { + return lhs == rhs; +} + +template <> +absl::optional Equal(const StringValue& lhs, const StringValue& rhs) { + return lhs.Equals(rhs); +} + +template <> +absl::optional Equal(const BytesValue& lhs, const BytesValue& rhs) { + return lhs.Equals(rhs); +} + +template <> +absl::optional Equal(const NullValue&, const NullValue&) { + return true; +} + +template <> +absl::optional Equal(const TypeValue& lhs, const TypeValue& rhs) { + return lhs.name() == rhs.name(); +} + +// Equality for lists. Template parameter provides either heterogeneous or +// homogenous equality for comparing members. +template +absl::StatusOr> ListEqual( + const ListValue& lhs, const ListValue& rhs, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + if (&lhs == &rhs) { + return true; + } + CEL_ASSIGN_OR_RETURN(auto lhs_size, lhs.Size()); + CEL_ASSIGN_OR_RETURN(auto rhs_size, rhs.Size()); + if (lhs_size != rhs_size) { + return false; + } + + for (int i = 0; i < lhs_size; ++i) { + CEL_ASSIGN_OR_RETURN(auto lhs_i, + lhs.Get(i, descriptor_pool, message_factory, arena)); + CEL_ASSIGN_OR_RETURN(auto rhs_i, + rhs.Get(i, descriptor_pool, message_factory, arena)); + CEL_ASSIGN_OR_RETURN(absl::optional eq, + EqualsProvider()(lhs_i, rhs_i, descriptor_pool, + message_factory, arena)); + if (!eq.has_value() || !*eq) { + return eq; + } + } + return true; +} + +// Opaque types only support heterogeneous equality, and by extension that means +// optionals. Heterogeneous equality being enabled is enforced by +// `EnableOptionalTypes`. +absl::StatusOr> OpaqueEqual( + const OpaqueValue& lhs, const OpaqueValue& rhs, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + Value result; + CEL_RETURN_IF_ERROR( + lhs.Equal(rhs, descriptor_pool, message_factory, arena, &result)); + if (auto bool_value = result.AsBool(); bool_value) { + return bool_value->NativeValue(); + } + return TypeConversionError(result.GetTypeName(), "bool").NativeValue(); +} + +absl::optional NumberFromValue(const Value& value) { + if (value.Is()) { + return Number::FromInt64(value.GetInt().NativeValue()); + } else if (value.Is()) { + return Number::FromUint64(value.GetUint().NativeValue()); + } else if (value.Is()) { + return Number::FromDouble(value.GetDouble().NativeValue()); + } + + return absl::nullopt; +} + +absl::StatusOr> CheckAlternativeNumericType( + const Value& key, const MapValue& rhs, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + absl::optional number = NumberFromValue(key); + + if (!number.has_value()) { + return absl::nullopt; + } + + if (!key.IsInt() && number->LosslessConvertibleToInt()) { + absl::optional entry; + CEL_ASSIGN_OR_RETURN(entry, + rhs.Find(IntValue(number->AsInt()), descriptor_pool, + message_factory, arena)); + if (entry) { + return entry; + } + } + + if (!key.IsUint() && number->LosslessConvertibleToUint()) { + absl::optional entry; + CEL_ASSIGN_OR_RETURN(entry, + rhs.Find(UintValue(number->AsUint()), descriptor_pool, + message_factory, arena)); + if (entry) { + return entry; + } + } + + return absl::nullopt; +} + +// Equality for maps. Template parameter provides either heterogeneous or +// homogenous equality for comparing values. +template +absl::StatusOr> MapEqual( + const MapValue& lhs, const MapValue& rhs, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + if (&lhs == &rhs) { + return true; + } + if (lhs.Size() != rhs.Size()) { + return false; + } + + CEL_ASSIGN_OR_RETURN(auto iter, lhs.NewIterator()); + + while (iter->HasNext()) { + CEL_ASSIGN_OR_RETURN(auto lhs_key, + iter->Next(descriptor_pool, message_factory, arena)); + + absl::optional entry; + CEL_ASSIGN_OR_RETURN( + entry, rhs.Find(lhs_key, descriptor_pool, message_factory, arena)); + + if (!entry && EqualsProvider::kIsHeterogeneous) { + CEL_ASSIGN_OR_RETURN( + entry, CheckAlternativeNumericType(lhs_key, rhs, descriptor_pool, + message_factory, arena)); + } + if (!entry) { + return false; + } + + CEL_ASSIGN_OR_RETURN(auto lhs_value, lhs.Get(lhs_key, descriptor_pool, + message_factory, arena)); + CEL_ASSIGN_OR_RETURN(absl::optional eq, + EqualsProvider()(lhs_value, *entry, descriptor_pool, + message_factory, arena)); + + if (!eq.has_value() || !*eq) { + return eq; + } + } + + return true; +} + +// Helper for wrapping ==/!= implementations. +// Name should point to a static constexpr string so the lambda capture is safe. +template +std::function +WrapComparison(Op op, absl::string_view name) { + return [op = std::move(op), name]( + Type lhs, Type rhs, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) -> Value { + absl::optional result = op(lhs, rhs); + + if (result.has_value()) { + return BoolValue(*result); + } + + return ErrorValue( + cel::runtime_internal::CreateNoMatchingOverloadError(name)); + }; +} + +// Helper method +// +// Registers all equality functions for template parameters type. +template +absl::Status RegisterEqualityFunctionsForType(cel::FunctionRegistry& registry) { + using FunctionAdapter = + cel::RegisterHelper>; + // Inequality + CEL_RETURN_IF_ERROR(FunctionAdapter::RegisterGlobalOverload( + kInequal, WrapComparison(&Inequal, kInequal), registry)); + + // Equality + CEL_RETURN_IF_ERROR(FunctionAdapter::RegisterGlobalOverload( + kEqual, WrapComparison(&Equal, kEqual), registry)); + + return absl::OkStatus(); +} + +template +auto ComplexEquality(Op&& op) { + return [op = std::forward(op)]( + const Type& t1, const Type& t2, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) -> absl::StatusOr { + CEL_ASSIGN_OR_RETURN(absl::optional result, + op(t1, t2, descriptor_pool, message_factory, arena)); + if (!result.has_value()) { + return ErrorValue( + cel::runtime_internal::CreateNoMatchingOverloadError(kEqual)); + } + return BoolValue(*result); + }; +} + +template +auto ComplexInequality(Op&& op) { + return [op = std::forward(op)]( + Type t1, Type t2, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) -> absl::StatusOr { + CEL_ASSIGN_OR_RETURN(absl::optional result, + op(t1, t2, descriptor_pool, message_factory, arena)); + if (!result.has_value()) { + return ErrorValue( + cel::runtime_internal::CreateNoMatchingOverloadError(kInequal)); + } + return BoolValue(!*result); + }; +} + +template +absl::Status RegisterComplexEqualityFunctionsForType( + absl::FunctionRef>( + Type, Type, const google::protobuf::DescriptorPool* absl_nonnull, + google::protobuf::MessageFactory* absl_nonnull, google::protobuf::Arena* absl_nonnull)> + op, + cel::FunctionRegistry& registry) { + using FunctionAdapter = cel::RegisterHelper< + BinaryFunctionAdapter, Type, Type>>; + // Inequality + CEL_RETURN_IF_ERROR(FunctionAdapter::RegisterGlobalOverload( + kInequal, ComplexInequality(op), registry)); + + // Equality + CEL_RETURN_IF_ERROR(FunctionAdapter::RegisterGlobalOverload( + kEqual, ComplexEquality(op), registry)); + + return absl::OkStatus(); +} + +absl::Status RegisterHomogenousEqualityFunctions( + cel::FunctionRegistry& registry) { + CEL_RETURN_IF_ERROR(RegisterEqualityFunctionsForType(registry)); + + CEL_RETURN_IF_ERROR(RegisterEqualityFunctionsForType(registry)); + + CEL_RETURN_IF_ERROR(RegisterEqualityFunctionsForType(registry)); + + CEL_RETURN_IF_ERROR(RegisterEqualityFunctionsForType(registry)); + + CEL_RETURN_IF_ERROR( + RegisterEqualityFunctionsForType(registry)); + + CEL_RETURN_IF_ERROR( + RegisterEqualityFunctionsForType(registry)); + + CEL_RETURN_IF_ERROR( + RegisterEqualityFunctionsForType(registry)); + + CEL_RETURN_IF_ERROR(RegisterEqualityFunctionsForType(registry)); + + CEL_RETURN_IF_ERROR( + RegisterEqualityFunctionsForType(registry)); + + CEL_RETURN_IF_ERROR( + RegisterEqualityFunctionsForType(registry)); + + CEL_RETURN_IF_ERROR( + RegisterComplexEqualityFunctionsForType( + &ListEqual, registry)); + + CEL_RETURN_IF_ERROR( + RegisterComplexEqualityFunctionsForType( + &MapEqual, registry)); + + return absl::OkStatus(); +} + +absl::Status RegisterNullMessageEqualityFunctions(FunctionRegistry& registry) { + // equals + CEL_RETURN_IF_ERROR( + (cel::RegisterHelper< + BinaryFunctionAdapter>:: + RegisterGlobalOverload( + kEqual, + [](const StructValue&, const NullValue&) { return false; }, + registry))); + + CEL_RETURN_IF_ERROR( + (cel::RegisterHelper< + BinaryFunctionAdapter>:: + RegisterGlobalOverload( + kEqual, + [](const NullValue&, const StructValue&) { return false; }, + registry))); + + // inequals + CEL_RETURN_IF_ERROR( + (cel::RegisterHelper< + BinaryFunctionAdapter>:: + RegisterGlobalOverload( + kInequal, + [](const StructValue&, const NullValue&) { return true; }, + registry))); + + return cel::RegisterHelper< + BinaryFunctionAdapter>:: + RegisterGlobalOverload( + kInequal, [](const NullValue&, const StructValue&) { return true; }, + registry); +} + +template +absl::StatusOr> HomogenousValueEqual( + const Value& v1, const Value& v2, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + if (v1.kind() != v2.kind()) { + return absl::nullopt; + } + + static_assert(std::is_lvalue_reference_v, + "unexpected value copy"); + + switch (v1->kind()) { + case ValueKind::kBool: + return Equal(v1.GetBool().NativeValue(), + v2.GetBool().NativeValue()); + case ValueKind::kNull: + return Equal(v1.GetNull(), v2.GetNull()); + case ValueKind::kInt: + return Equal(v1.GetInt().NativeValue(), + v2.GetInt().NativeValue()); + case ValueKind::kUint: + return Equal(v1.GetUint().NativeValue(), + v2.GetUint().NativeValue()); + case ValueKind::kDouble: + return Equal(v1.GetDouble().NativeValue(), + v2.GetDouble().NativeValue()); + case ValueKind::kDuration: + return Equal(v1.GetDuration().NativeValue(), + v2.GetDuration().NativeValue()); + case ValueKind::kTimestamp: + return Equal(v1.GetTimestamp().NativeValue(), + v2.GetTimestamp().NativeValue()); + case ValueKind::kCelType: + return Equal(v1.GetType(), v2.GetType()); + case ValueKind::kString: + return Equal(v1.GetString(), v2.GetString()); + case ValueKind::kBytes: + return Equal(v1.GetBytes(), v2.GetBytes()); + case ValueKind::kList: + return ListEqual(v1.GetList(), v2.GetList(), + descriptor_pool, message_factory, arena); + case ValueKind::kMap: + return MapEqual(v1.GetMap(), v2.GetMap(), descriptor_pool, + message_factory, arena); + case ValueKind::kOpaque: + return OpaqueEqual(v1.GetOpaque(), v2.GetOpaque(), descriptor_pool, + message_factory, arena); + default: + return absl::nullopt; + } +} + +absl::StatusOr EqualOverloadImpl( + const Value& lhs, const Value& rhs, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + CEL_ASSIGN_OR_RETURN(absl::optional result, + runtime_internal::ValueEqualImpl( + lhs, rhs, descriptor_pool, message_factory, arena)); + if (result.has_value()) { + return BoolValue(*result); + } + return ErrorValue( + cel::runtime_internal::CreateNoMatchingOverloadError(kEqual)); +} + +absl::StatusOr InequalOverloadImpl( + const Value& lhs, const Value& rhs, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + CEL_ASSIGN_OR_RETURN(absl::optional result, + runtime_internal::ValueEqualImpl( + lhs, rhs, descriptor_pool, message_factory, arena)); + if (result.has_value()) { + return BoolValue(!*result); + } + return ErrorValue( + cel::runtime_internal::CreateNoMatchingOverloadError(kInequal)); +} + +absl::Status RegisterHeterogeneousEqualityFunctions( + cel::FunctionRegistry& registry) { + using Adapter = cel::RegisterHelper< + BinaryFunctionAdapter, const Value&, const Value&>>; + CEL_RETURN_IF_ERROR( + Adapter::RegisterGlobalOverload(kEqual, &EqualOverloadImpl, registry)); + + CEL_RETURN_IF_ERROR(Adapter::RegisterGlobalOverload( + kInequal, &InequalOverloadImpl, registry)); + + return absl::OkStatus(); +} + +absl::StatusOr> HomogenousEqualProvider::operator()( + const Value& lhs, const Value& rhs, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const { + return HomogenousValueEqual( + lhs, rhs, descriptor_pool, message_factory, arena); +} + +absl::StatusOr> HeterogeneousEqualProvider::operator()( + const Value& lhs, const Value& rhs, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const { + return runtime_internal::ValueEqualImpl(lhs, rhs, descriptor_pool, + message_factory, arena); +} + +} // namespace + +namespace runtime_internal { + +absl::StatusOr> ValueEqualImpl( + const Value& v1, const Value& v2, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + if (v1.kind() == v2.kind()) { + if (v1.IsStruct() && v2.IsStruct()) { + CEL_ASSIGN_OR_RETURN( + Value result, + v1.GetStruct().Equal(v2, descriptor_pool, message_factory, arena)); + if (result.IsBool()) { + return result.GetBool().NativeValue(); + } + return false; + } + return HomogenousValueEqual( + v1, v2, descriptor_pool, message_factory, arena); + } + + absl::optional lhs = NumberFromValue(v1); + absl::optional rhs = NumberFromValue(v2); + + if (rhs.has_value() && lhs.has_value()) { + return *lhs == *rhs; + } + + // TODO(uncreated-issue/6): It's currently possible for the interpreter to create a + // map containing an Error. Return no matching overload to propagate an error + // instead of a false result. + if (v1.IsError() || v1.IsUnknown() || v2.IsError() || v2.IsUnknown()) { + return absl::nullopt; + } + + return false; +} + +} // namespace runtime_internal + +absl::Status RegisterEqualityFunctions(FunctionRegistry& registry, + const RuntimeOptions& options) { + if (options.enable_heterogeneous_equality) { + if (options.enable_fast_builtins) { + // If enabled, the evaluator provides an implementation that works + // directly on the value stack. + return absl::OkStatus(); + } + // Heterogeneous equality uses one generic overload that delegates to the + // right equality implementation at runtime. + CEL_RETURN_IF_ERROR(RegisterHeterogeneousEqualityFunctions(registry)); + } else { + CEL_RETURN_IF_ERROR(RegisterHomogenousEqualityFunctions(registry)); + + CEL_RETURN_IF_ERROR(RegisterNullMessageEqualityFunctions(registry)); + } + return absl::OkStatus(); +} + +} // namespace cel diff --git a/runtime/standard/equality_functions.h b/runtime/standard/equality_functions.h new file mode 100644 index 000000000..159423e50 --- /dev/null +++ b/runtime/standard/equality_functions.h @@ -0,0 +1,60 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_EQUALITY_FUNCTIONS_H_ +#define THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_EQUALITY_FUNCTIONS_H_ + +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/optional.h" +#include "common/value.h" +#include "runtime/function_registry.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel { +namespace runtime_internal { +// Exposed implementation for == operator. This is used to implement other +// runtime functions. +// +// Nullopt is returned if the comparison is undefined (e.g. special value types +// error and unknown). +absl::StatusOr> ValueEqualImpl( + const Value& v1, const Value& v2, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena); +} // namespace runtime_internal + +// Register equality functions +// ==, != +// +// options.enable_heterogeneous_equality controls which flavor of equality is +// used. +// +// For legacy equality (.enable_heterogeneous_equality = false), equality is +// defined between same-typed values only. +// +// For the CEL specification's definition of equality +// (.enable_heterogeneous_equality = true), equality is defined between most +// types, with false returned if the two different types are incomparable. +absl::Status RegisterEqualityFunctions(FunctionRegistry& registry, + const RuntimeOptions& options); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_EQUALITY_FUNCTIONS_H_ diff --git a/runtime/standard/equality_functions_test.cc b/runtime/standard/equality_functions_test.cc new file mode 100644 index 000000000..605c66d82 --- /dev/null +++ b/runtime/standard/equality_functions_test.cc @@ -0,0 +1,160 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "runtime/standard/equality_functions.h" + +#include + +#include "absl/status/status_matchers.h" +#include "base/builtins.h" +#include "common/function_descriptor.h" +#include "common/kind.h" +#include "internal/testing.h" +#include "runtime/function_registry.h" +#include "runtime/runtime_options.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::testing::IsEmpty; +using ::testing::UnorderedElementsAre; + +MATCHER_P3(MatchesDescriptor, name, receiver, expected_kinds, "") { + const FunctionDescriptor& descriptor = *arg; + const std::vector& types = expected_kinds; + return descriptor.name() == name && descriptor.receiver_style() == receiver && + descriptor.types() == types; +} + +TEST(RegisterEqualityFunctionsHomogeneous, RegistersEqualOperators) { + FunctionRegistry registry; + RuntimeOptions options; + options.enable_heterogeneous_equality = false; + + ASSERT_THAT(RegisterEqualityFunctions(registry, options), IsOk()); + auto overloads = registry.ListFunctions(); + EXPECT_THAT( + overloads[builtin::kEqual], + UnorderedElementsAre( + MatchesDescriptor(builtin::kEqual, false, + std::vector{Kind::kList, Kind::kList}), + MatchesDescriptor(builtin::kEqual, false, + std::vector{Kind::kMap, Kind::kMap}), + MatchesDescriptor(builtin::kEqual, false, + std::vector{Kind::kBool, Kind::kBool}), + MatchesDescriptor(builtin::kEqual, false, + std::vector{Kind::kInt, Kind::kInt}), + MatchesDescriptor(builtin::kEqual, false, + std::vector{Kind::kUint, Kind::kUint}), + MatchesDescriptor(builtin::kEqual, false, + std::vector{Kind::kDouble, Kind::kDouble}), + MatchesDescriptor( + builtin::kEqual, false, + std::vector{Kind::kDuration, Kind::kDuration}), + MatchesDescriptor( + builtin::kEqual, false, + std::vector{Kind::kTimestamp, Kind::kTimestamp}), + MatchesDescriptor(builtin::kEqual, false, + std::vector{Kind::kString, Kind::kString}), + MatchesDescriptor(builtin::kEqual, false, + std::vector{Kind::kBytes, Kind::kBytes}), + MatchesDescriptor(builtin::kEqual, false, + std::vector{Kind::kType, Kind::kType}), + // Structs comparable to null, but struct == struct undefined. + MatchesDescriptor(builtin::kEqual, false, + std::vector{Kind::kStruct, Kind::kNullType}), + MatchesDescriptor(builtin::kEqual, false, + std::vector{Kind::kNullType, Kind::kStruct}), + MatchesDescriptor( + builtin::kEqual, false, + std::vector{Kind::kNullType, Kind::kNullType}))); + + EXPECT_THAT( + overloads[builtin::kInequal], + UnorderedElementsAre( + MatchesDescriptor(builtin::kInequal, false, + std::vector{Kind::kList, Kind::kList}), + MatchesDescriptor(builtin::kInequal, false, + std::vector{Kind::kMap, Kind::kMap}), + MatchesDescriptor(builtin::kInequal, false, + std::vector{Kind::kBool, Kind::kBool}), + MatchesDescriptor(builtin::kInequal, false, + std::vector{Kind::kInt, Kind::kInt}), + MatchesDescriptor(builtin::kInequal, false, + std::vector{Kind::kUint, Kind::kUint}), + MatchesDescriptor(builtin::kInequal, false, + std::vector{Kind::kDouble, Kind::kDouble}), + MatchesDescriptor( + builtin::kInequal, false, + std::vector{Kind::kDuration, Kind::kDuration}), + MatchesDescriptor( + builtin::kInequal, false, + std::vector{Kind::kTimestamp, Kind::kTimestamp}), + MatchesDescriptor(builtin::kInequal, false, + std::vector{Kind::kString, Kind::kString}), + MatchesDescriptor(builtin::kInequal, false, + std::vector{Kind::kBytes, Kind::kBytes}), + MatchesDescriptor(builtin::kInequal, false, + std::vector{Kind::kType, Kind::kType}), + // Structs comparable to null, but struct != struct undefined. + MatchesDescriptor(builtin::kInequal, false, + std::vector{Kind::kStruct, Kind::kNullType}), + MatchesDescriptor(builtin::kInequal, false, + std::vector{Kind::kNullType, Kind::kStruct}), + MatchesDescriptor( + builtin::kInequal, false, + std::vector{Kind::kNullType, Kind::kNullType}))); +} + +TEST(RegisterEqualityFunctionsHeterogeneous, RegistersEqualOperators) { + FunctionRegistry registry; + RuntimeOptions options; + options.enable_heterogeneous_equality = true; + options.enable_fast_builtins = false; + + ASSERT_THAT(RegisterEqualityFunctions(registry, options), IsOk()); + auto overloads = registry.ListFunctions(); + + EXPECT_THAT( + overloads[builtin::kEqual], + UnorderedElementsAre(MatchesDescriptor( + builtin::kEqual, false, std::vector{Kind::kAny, Kind::kAny}))); + + EXPECT_THAT(overloads[builtin::kInequal], + UnorderedElementsAre(MatchesDescriptor( + builtin::kInequal, false, + std::vector{Kind::kAny, Kind::kAny}))); +} + +TEST(RegisterEqualityFunctionsHeterogeneous, + NotRegisteredWhenFastBuiltinsEnabled) { + FunctionRegistry registry; + RuntimeOptions options; + options.enable_heterogeneous_equality = true; + options.enable_fast_builtins = true; + + ASSERT_THAT(RegisterEqualityFunctions(registry, options), IsOk()); + auto overloads = registry.ListFunctions(); + + EXPECT_THAT(overloads[builtin::kEqual], IsEmpty()); + + EXPECT_THAT(overloads[builtin::kInequal], IsEmpty()); +} + +// TODO(uncreated-issue/41): move functional parsed expr tests when modern APIs for +// evaluator available. + +} // namespace +} // namespace cel diff --git a/runtime/standard/logical_functions.cc b/runtime/standard/logical_functions.cc new file mode 100644 index 000000000..cd3dd3cb5 --- /dev/null +++ b/runtime/standard/logical_functions.cc @@ -0,0 +1,66 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "runtime/standard/logical_functions.h" + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "base/builtins.h" +#include "base/function_adapter.h" +#include "common/value.h" +#include "internal/status_macros.h" +#include "runtime/function_registry.h" +#include "runtime/internal/errors.h" +#include "runtime/register_function_helper.h" +#include "runtime/runtime_options.h" + +namespace cel { +namespace { + +using ::cel::runtime_internal::CreateNoMatchingOverloadError; + +Value NotStrictlyFalseImpl(const Value& value) { + if (value.IsBool()) { + return value; + } + + if (value.IsError() || value.IsUnknown()) { + return TrueValue(); + } + + // Should only accept bool unknown or error. + return ErrorValue(CreateNoMatchingOverloadError(builtin::kNotStrictlyFalse)); +} + +} // namespace + +absl::Status RegisterLogicalFunctions(FunctionRegistry& registry, + const RuntimeOptions& options) { + // logical NOT + CEL_RETURN_IF_ERROR( + (RegisterHelper>::RegisterGlobalOverload( + builtin::kNot, [](bool value) -> bool { return !value; }, registry))); + + // Strictness + using StrictnessHelper = RegisterHelper>; + CEL_RETURN_IF_ERROR(StrictnessHelper::RegisterNonStrictOverload( + builtin::kNotStrictlyFalse, &NotStrictlyFalseImpl, registry)); + + CEL_RETURN_IF_ERROR(StrictnessHelper::RegisterNonStrictOverload( + builtin::kNotStrictlyFalseDeprecated, &NotStrictlyFalseImpl, registry)); + + return absl::OkStatus(); +} + +} // namespace cel diff --git a/runtime/standard/logical_functions.h b/runtime/standard/logical_functions.h new file mode 100644 index 000000000..5061b6f7f --- /dev/null +++ b/runtime/standard/logical_functions.h @@ -0,0 +1,36 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_LOGICAL_FUNCTIONS_H_ +#define THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_LOGICAL_FUNCTIONS_H_ + +#include "absl/status/status.h" +#include "runtime/function_registry.h" +#include "runtime/runtime_options.h" + +namespace cel { + +// Register logical operators ! and @not_strictly_false. +// +// &&, ||, ?: are special cased by the interpreter (not implemented via the +// function registry.) +// +// Most users should use RegisterBuiltinFunctions, which includes these +// definitions. +absl::Status RegisterLogicalFunctions(FunctionRegistry& registry, + const RuntimeOptions& options); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_LOGICAL_FUNCTIONS_H_ diff --git a/runtime/standard/logical_functions_test.cc b/runtime/standard/logical_functions_test.cc new file mode 100644 index 000000000..de50f5312 --- /dev/null +++ b/runtime/standard/logical_functions_test.cc @@ -0,0 +1,189 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "runtime/standard/logical_functions.h" + +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/match.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "base/builtins.h" +#include "common/function_descriptor.h" +#include "common/kind.h" +#include "common/value.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "internal/testing_message_factory.h" +#include "runtime/function.h" +#include "runtime/function_overload_reference.h" +#include "runtime/function_registry.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel { +namespace { + +using ::testing::ElementsAre; +using ::testing::HasSubstr; +using ::testing::Matcher; +using ::testing::Truly; + +MATCHER_P3(DescriptorIs, name, arg_kinds, is_receiver, "") { + const FunctionOverloadReference& ref = arg; + const FunctionDescriptor& descriptor = ref.descriptor; + return descriptor.name() == name && + descriptor.ShapeMatches(is_receiver, arg_kinds); +} + +MATCHER_P(IsBool, expected, "") { + const Value& value = arg; + return value->Is() && value.GetBool().NativeValue() == expected; +} + +// TODO(uncreated-issue/48): replace this with a parsed expr when the non-protobuf +// parser is available. +absl::StatusOr TestDispatchToFunction( + const FunctionRegistry& registry, absl::string_view simple_name, + absl::Span args, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + std::vector arg_matcher_; + arg_matcher_.reserve(args.size()); + for (const auto& value : args) { + arg_matcher_.push_back(ValueKindToKind(value->kind())); + } + std::vector refs = registry.FindStaticOverloads( + simple_name, /*receiver_style=*/false, arg_matcher_); + + if (refs.size() != 1) { + return absl::InvalidArgumentError("ambiguous overloads"); + } + + return refs[0].implementation.Invoke(args, descriptor_pool, message_factory, + arena); +} + +TEST(RegisterLogicalFunctions, NotStrictlyFalseRegistered) { + FunctionRegistry registry; + RuntimeOptions options; + + ASSERT_OK(RegisterLogicalFunctions(registry, options)); + + EXPECT_THAT( + registry.FindStaticOverloads(builtin::kNotStrictlyFalse, + /*receiver_style=*/false, {Kind::kAny}), + ElementsAre(DescriptorIs(builtin::kNotStrictlyFalse, + std::vector{Kind::kBool}, false))); +} + +TEST(RegisterLogicalFunctions, LogicalNotRegistered) { + FunctionRegistry registry; + RuntimeOptions options; + + ASSERT_OK(RegisterLogicalFunctions(registry, options)); + + EXPECT_THAT( + registry.FindStaticOverloads(builtin::kNot, + /*receiver_style=*/false, {Kind::kAny}), + ElementsAre( + DescriptorIs(builtin::kNot, std::vector{Kind::kBool}, false))); +} + +struct TestCase { + using ArgumentFactory = std::function()>; + + std::string function; + ArgumentFactory arguments; + absl::StatusOr> result_matcher; +}; + +class LogicalFunctionsTest : public testing::TestWithParam { + protected: + google::protobuf::Arena arena_; +}; + +TEST_P(LogicalFunctionsTest, Runner) { + const TestCase& test_case = GetParam(); + cel::FunctionRegistry registry; + + ASSERT_OK(RegisterLogicalFunctions(registry, RuntimeOptions())); + + std::vector args = test_case.arguments(); + + absl::StatusOr result = TestDispatchToFunction( + registry, test_case.function, args, + cel::internal::GetTestingDescriptorPool(), + cel::internal::GetTestingMessageFactory(), &arena_); + + EXPECT_EQ(result.ok(), test_case.result_matcher.ok()); + + if (!test_case.result_matcher.ok()) { + EXPECT_EQ(result.status().code(), test_case.result_matcher.status().code()); + EXPECT_THAT(result.status().message(), + HasSubstr(test_case.result_matcher.status().message())); + } else { + ASSERT_TRUE(result.ok()) << "unexpected error" << result.status(); + EXPECT_THAT(*result, *test_case.result_matcher); + } +} + +INSTANTIATE_TEST_SUITE_P( + Cases, LogicalFunctionsTest, + testing::ValuesIn(std::vector{ + TestCase{builtin::kNot, + []() -> std::vector { return {BoolValue(true)}; }, + IsBool(false)}, + TestCase{builtin::kNot, + []() -> std::vector { return {BoolValue(false)}; }, + IsBool(true)}, + TestCase{builtin::kNot, + []() -> std::vector { + return {BoolValue(true), BoolValue(false)}; + }, + absl::InvalidArgumentError("")}, + TestCase{builtin::kNotStrictlyFalse, + []() -> std::vector { return {BoolValue(true)}; }, + IsBool(true)}, + TestCase{builtin::kNotStrictlyFalse, + []() -> std::vector { return {BoolValue(false)}; }, + IsBool(false)}, + TestCase{builtin::kNotStrictlyFalse, + []() -> std::vector { + return {ErrorValue(absl::InternalError("test"))}; + }, + IsBool(true)}, + TestCase{builtin::kNotStrictlyFalse, + []() -> std::vector { return {UnknownValue()}; }, + IsBool(true)}, + TestCase{builtin::kNotStrictlyFalse, + []() -> std::vector { return {IntValue(42)}; }, + Truly([](const Value& v) { + return v->Is() && + absl::StrContains( + v.GetError().NativeValue().message(), + "No matching overloads"); + })}, + })); + +} // namespace +} // namespace cel diff --git a/runtime/standard/regex_functions.cc b/runtime/standard/regex_functions.cc new file mode 100644 index 000000000..6833f7804 --- /dev/null +++ b/runtime/standard/regex_functions.cc @@ -0,0 +1,56 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "runtime/standard/regex_functions.h" + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "base/builtins.h" +#include "base/function_adapter.h" +#include "common/value.h" +#include "internal/re2_options.h" +#include "internal/status_macros.h" +#include "runtime/function_registry.h" +#include "runtime/runtime_options.h" +#include "re2/re2.h" + +namespace cel { +namespace {} // namespace + +absl::Status RegisterRegexFunctions(FunctionRegistry& registry, + const RuntimeOptions& options) { + if (options.enable_regex) { + auto regex_matches = [max_size = options.regex_max_program_size]( + const StringValue& target, + const StringValue& regex) -> Value { + RE2 re2(regex.ToString(), cel::internal::MakeRE2Options()); + CEL_RETURN_IF_ERROR(cel::internal::CheckRE2(re2, max_size)) + .With(ErrorValueReturn()); + return BoolValue(RE2::PartialMatch(target.ToString(), re2)); + }; + + // bind str.matches(re) and matches(str, re) + for (bool receiver_style : {true, false}) { + using MatchFnAdapter = + BinaryFunctionAdapter; + CEL_RETURN_IF_ERROR( + registry.Register(MatchFnAdapter::CreateDescriptor( + cel::builtin::kRegexMatch, receiver_style), + MatchFnAdapter::WrapFunction(regex_matches))); + } + } // if options.enable_regex + + return absl::OkStatus(); +} + +} // namespace cel diff --git a/runtime/standard/regex_functions.h b/runtime/standard/regex_functions.h new file mode 100644 index 000000000..2be7568e2 --- /dev/null +++ b/runtime/standard/regex_functions.h @@ -0,0 +1,38 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_REGEX_FUNCTIONS_H_ +#define THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_REGEX_FUNCTIONS_H_ + +#include "absl/status/status.h" +#include "runtime/function_registry.h" +#include "runtime/runtime_options.h" + +namespace cel { + +// Register builtin regex functions: +// +// (string).matches(re:string) -> bool +// matches(string, re:string) -> bool +// +// These are implemented with RE2. +// +// Most users should use RegisterBuiltinFunctions, which includes these +// definitions. +absl::Status RegisterRegexFunctions(FunctionRegistry& registry, + const RuntimeOptions& options); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_REGEX_FUNCTIONS_H_ diff --git a/runtime/standard/regex_functions_test.cc b/runtime/standard/regex_functions_test.cc new file mode 100644 index 000000000..59bbe9abf --- /dev/null +++ b/runtime/standard/regex_functions_test.cc @@ -0,0 +1,77 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "runtime/standard/regex_functions.h" + +#include + +#include "base/builtins.h" +#include "common/function_descriptor.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +using ::testing::IsEmpty; +using ::testing::UnorderedElementsAre; + +enum class CallStyle { kFree, kReceiver }; + +MATCHER_P2(MatchesDescriptor, name, call_style, "") { + bool receiver_style; + switch (call_style) { + case CallStyle::kReceiver: + receiver_style = true; + break; + case CallStyle::kFree: + receiver_style = false; + break; + } + const FunctionDescriptor& descriptor = *arg; + std::vector types{Kind::kString, Kind::kString}; + return descriptor.name() == name && + descriptor.receiver_style() == receiver_style && + descriptor.types() == types; +} + +TEST(RegisterRegexFunctions, Registered) { + FunctionRegistry registry; + RuntimeOptions options; + + ASSERT_OK(RegisterRegexFunctions(registry, options)); + + auto overloads = registry.ListFunctions(); + + EXPECT_THAT(overloads[builtin::kRegexMatch], + UnorderedElementsAre( + MatchesDescriptor(builtin::kRegexMatch, CallStyle::kReceiver), + MatchesDescriptor(builtin::kRegexMatch, CallStyle::kFree))); +} + +TEST(RegisterRegexFunctions, NotRegisteredIfDisabled) { + FunctionRegistry registry; + RuntimeOptions options; + options.enable_regex = false; + + ASSERT_OK(RegisterRegexFunctions(registry, options)); + + auto overloads = registry.ListFunctions(); + + EXPECT_THAT(overloads[builtin::kRegexMatch], IsEmpty()); +} + +// TODO(uncreated-issue/41): move functional parsed expr tests when modern APIs for +// evaluator available. + +} // namespace +} // namespace cel diff --git a/runtime/standard/string_functions.cc b/runtime/standard/string_functions.cc new file mode 100644 index 000000000..2bcfe185c --- /dev/null +++ b/runtime/standard/string_functions.cc @@ -0,0 +1,140 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "runtime/standard/string_functions.h" + +#include + +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "base/builtins.h" +#include "base/function_adapter.h" +#include "common/value.h" +#include "internal/status_macros.h" +#include "runtime/function_registry.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel { +namespace { + +// Concatenation for string type. +absl::StatusOr ConcatString( + const StringValue& value1, const StringValue& value2, + const google::protobuf::DescriptorPool* absl_nonnull, + google::protobuf::MessageFactory* absl_nonnull, google::protobuf::Arena* absl_nonnull arena) { + return StringValue::Concat(value1, value2, arena); +} + +// Concatenation for bytes type. +absl::StatusOr ConcatBytes( + const BytesValue& value1, const BytesValue& value2, + const google::protobuf::DescriptorPool* absl_nonnull, + google::protobuf::MessageFactory* absl_nonnull, google::protobuf::Arena* absl_nonnull arena) { + return BytesValue::Concat(value1, value2, arena); +} + +bool StringContains(const StringValue& value, const StringValue& substr) { + return value.Contains(substr); +} + +bool StringEndsWith(const StringValue& value, const StringValue& suffix) { + return value.EndsWith(suffix); +} + +bool StringStartsWith(const StringValue& value, const StringValue& prefix) { + return value.StartsWith(prefix); +} + +absl::Status RegisterSizeFunctions(FunctionRegistry& registry) { + // String size + auto size_func = [](const StringValue& value) -> int64_t { + return value.Size(); + }; + + // Support global and receiver style size() operations on strings. + using StrSizeFnAdapter = UnaryFunctionAdapter; + CEL_RETURN_IF_ERROR(StrSizeFnAdapter::RegisterGlobalOverload( + cel::builtin::kSize, size_func, registry)); + + CEL_RETURN_IF_ERROR(StrSizeFnAdapter::RegisterMemberOverload( + cel::builtin::kSize, size_func, registry)); + + // Bytes size + auto bytes_size_func = [](const BytesValue& value) -> int64_t { + return value.Size(); + }; + + // Support global and receiver style size() operations on bytes. + using BytesSizeFnAdapter = UnaryFunctionAdapter; + CEL_RETURN_IF_ERROR(BytesSizeFnAdapter::RegisterGlobalOverload( + cel::builtin::kSize, bytes_size_func, registry)); + + return BytesSizeFnAdapter::RegisterMemberOverload(cel::builtin::kSize, + bytes_size_func, registry); +} + +absl::Status RegisterConcatFunctions(FunctionRegistry& registry) { + using StrCatFnAdapter = + BinaryFunctionAdapter, const StringValue&, + const StringValue&>; + CEL_RETURN_IF_ERROR(StrCatFnAdapter::RegisterGlobalOverload( + cel::builtin::kAdd, &ConcatString, registry)); + + using BytesCatFnAdapter = + BinaryFunctionAdapter, const BytesValue&, + const BytesValue&>; + return BytesCatFnAdapter::RegisterGlobalOverload(cel::builtin::kAdd, + &ConcatBytes, registry); +} + +} // namespace + +absl::Status RegisterStringFunctions(FunctionRegistry& registry, + const RuntimeOptions& options) { + // Basic substring tests (contains, startsWith, endsWith) + for (bool receiver_style : {true, false}) { + auto status = + BinaryFunctionAdapter:: + Register(cel::builtin::kStringContains, receiver_style, + StringContains, registry); + CEL_RETURN_IF_ERROR(status); + + status = + BinaryFunctionAdapter:: + Register(cel::builtin::kStringEndsWith, receiver_style, + StringEndsWith, registry); + CEL_RETURN_IF_ERROR(status); + + status = + BinaryFunctionAdapter:: + Register(cel::builtin::kStringStartsWith, receiver_style, + StringStartsWith, registry); + CEL_RETURN_IF_ERROR(status); + } + + // string concatenation if enabled + if (options.enable_string_concat) { + CEL_RETURN_IF_ERROR(RegisterConcatFunctions(registry)); + } + + return RegisterSizeFunctions(registry); +} + +} // namespace cel diff --git a/runtime/standard/string_functions.h b/runtime/standard/string_functions.h new file mode 100644 index 000000000..aa7fb7b6e --- /dev/null +++ b/runtime/standard/string_functions.h @@ -0,0 +1,34 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_STRING_FUNCTIONS_H_ +#define THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_STRING_FUNCTIONS_H_ + +#include "absl/status/status.h" +#include "runtime/function_registry.h" +#include "runtime/runtime_options.h" + +namespace cel { + +// Register builtin string and bytes functions: +// _+_ (concatenation), size, contains, startsWith, endsWith + +// Most users should use RegisterBuiltinFunctions, which includes these +// definitions. +absl::Status RegisterStringFunctions(FunctionRegistry& registry, + const RuntimeOptions& options); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_STRING_FUNCTIONS_H_ diff --git a/runtime/standard/string_functions_test.cc b/runtime/standard/string_functions_test.cc new file mode 100644 index 000000000..d520b3577 --- /dev/null +++ b/runtime/standard/string_functions_test.cc @@ -0,0 +1,114 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "runtime/standard/string_functions.h" + +#include + +#include "base/builtins.h" +#include "common/function_descriptor.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +using ::testing::IsEmpty; +using ::testing::UnorderedElementsAre; + +enum class CallStyle { kFree, kReceiver }; + +MATCHER_P3(MatchesDescriptor, name, call_style, expected_kinds, "") { + bool receiver_style; + switch (call_style) { + case CallStyle::kFree: + receiver_style = false; + break; + case CallStyle::kReceiver: + receiver_style = true; + break; + } + const FunctionDescriptor& descriptor = *arg; + const std::vector& types = expected_kinds; + return descriptor.name() == name && + descriptor.receiver_style() == receiver_style && + descriptor.types() == types; +} + +TEST(RegisterStringFunctions, FunctionsRegistered) { + FunctionRegistry registry; + RuntimeOptions options; + + ASSERT_OK(RegisterStringFunctions(registry, options)); + auto overloads = registry.ListFunctions(); + + EXPECT_THAT( + overloads[builtin::kAdd], + UnorderedElementsAre( + MatchesDescriptor(builtin::kAdd, CallStyle::kFree, + std::vector{Kind::kString, Kind::kString}), + MatchesDescriptor(builtin::kAdd, CallStyle::kFree, + std::vector{Kind::kBytes, Kind::kBytes}))); + + EXPECT_THAT(overloads[builtin::kSize], + UnorderedElementsAre( + MatchesDescriptor(builtin::kSize, CallStyle::kFree, + std::vector{Kind::kString}), + MatchesDescriptor(builtin::kSize, CallStyle::kFree, + std::vector{Kind::kBytes}), + MatchesDescriptor(builtin::kSize, CallStyle::kReceiver, + std::vector{Kind::kString}), + MatchesDescriptor(builtin::kSize, CallStyle::kReceiver, + std::vector{Kind::kBytes}))); + + EXPECT_THAT( + overloads[builtin::kStringContains], + UnorderedElementsAre( + MatchesDescriptor(builtin::kStringContains, CallStyle::kFree, + std::vector{Kind::kString, Kind::kString}), + + MatchesDescriptor(builtin::kStringContains, CallStyle::kReceiver, + std::vector{Kind::kString, Kind::kString}))); + EXPECT_THAT( + overloads[builtin::kStringStartsWith], + UnorderedElementsAre( + MatchesDescriptor(builtin::kStringStartsWith, CallStyle::kFree, + std::vector{Kind::kString, Kind::kString}), + + MatchesDescriptor(builtin::kStringStartsWith, CallStyle::kReceiver, + std::vector{Kind::kString, Kind::kString}))); + EXPECT_THAT( + overloads[builtin::kStringEndsWith], + UnorderedElementsAre( + MatchesDescriptor(builtin::kStringEndsWith, CallStyle::kFree, + std::vector{Kind::kString, Kind::kString}), + + MatchesDescriptor(builtin::kStringEndsWith, CallStyle::kReceiver, + std::vector{Kind::kString, Kind::kString}))); +} + +TEST(RegisterStringFunctions, ConcatSkippedWhenDisabled) { + FunctionRegistry registry; + RuntimeOptions options; + options.enable_string_concat = false; + + ASSERT_OK(RegisterStringFunctions(registry, options)); + auto overloads = registry.ListFunctions(); + + EXPECT_THAT(overloads[builtin::kAdd], IsEmpty()); +} + +// TODO(uncreated-issue/41): move functional parsed expr tests when modern APIs for +// evaluator available. + +} // namespace +} // namespace cel diff --git a/runtime/standard/time_functions.cc b/runtime/standard/time_functions.cc new file mode 100644 index 000000000..a0ec5377c --- /dev/null +++ b/runtime/standard/time_functions.cc @@ -0,0 +1,499 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "runtime/standard/time_functions.h" + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/match.h" +#include "absl/strings/str_replace.h" +#include "absl/strings/string_view.h" +#include "absl/time/civil_time.h" +#include "absl/time/time.h" +#include "base/builtins.h" +#include "base/function_adapter.h" +#include "common/value.h" +#include "internal/overflow.h" +#include "internal/status_macros.h" +#include "runtime/function_registry.h" +#include "runtime/runtime_options.h" + +namespace cel { +namespace { + +// Timestamp +absl::Status FindTimeBreakdown(absl::Time timestamp, absl::string_view tz, + absl::TimeZone::CivilInfo* breakdown) { + absl::TimeZone time_zone; + + // Early return if there is no timezone. + if (tz.empty()) { + *breakdown = time_zone.At(timestamp); + return absl::OkStatus(); + } + + // Check to see whether the timezone is an IANA timezone. + if (absl::LoadTimeZone(tz, &time_zone)) { + *breakdown = time_zone.At(timestamp); + return absl::OkStatus(); + } + + // Check for times of the format: [+-]HH:MM and convert them into durations + // specified as [+-]HHhMMm. + if (absl::StrContains(tz, ":")) { + std::string dur = absl::StrCat(tz, "m"); + absl::StrReplaceAll({{":", "h"}}, &dur); + absl::Duration d; + if (absl::ParseDuration(dur, &d)) { + timestamp += d; + *breakdown = time_zone.At(timestamp); + return absl::OkStatus(); + } + } + + // Otherwise, error. + return absl::InvalidArgumentError("Invalid timezone"); +} + +Value GetTimeBreakdownPart( + absl::Time timestamp, absl::string_view tz, + const std::function& + extractor_func) { + absl::TimeZone::CivilInfo breakdown; + auto status = FindTimeBreakdown(timestamp, tz, &breakdown); + + if (!status.ok()) { + return ErrorValue(status); + } + + return IntValue(extractor_func(breakdown)); +} + +Value GetFullYear(absl::Time timestamp, absl::string_view tz) { + return GetTimeBreakdownPart(timestamp, tz, + [](const absl::TimeZone::CivilInfo& breakdown) { + return breakdown.cs.year(); + }); +} + +Value GetMonth(absl::Time timestamp, absl::string_view tz) { + return GetTimeBreakdownPart(timestamp, tz, + [](const absl::TimeZone::CivilInfo& breakdown) { + return breakdown.cs.month() - 1; + }); +} + +Value GetDayOfYear(absl::Time timestamp, absl::string_view tz) { + return GetTimeBreakdownPart( + timestamp, tz, [](const absl::TimeZone::CivilInfo& breakdown) { + return absl::GetYearDay(absl::CivilDay(breakdown.cs)) - 1; + }); +} + +Value GetDayOfMonth(absl::Time timestamp, absl::string_view tz) { + return GetTimeBreakdownPart(timestamp, tz, + [](const absl::TimeZone::CivilInfo& breakdown) { + return breakdown.cs.day() - 1; + }); +} + +Value GetDate(absl::Time timestamp, absl::string_view tz) { + return GetTimeBreakdownPart(timestamp, tz, + [](const absl::TimeZone::CivilInfo& breakdown) { + return breakdown.cs.day(); + }); +} + +Value GetDayOfWeek(absl::Time timestamp, absl::string_view tz) { + return GetTimeBreakdownPart( + timestamp, tz, [](const absl::TimeZone::CivilInfo& breakdown) { + absl::Weekday weekday = absl::GetWeekday(breakdown.cs); + + // get day of week from the date in UTC, zero-based, zero for Sunday, + // based on GetDayOfWeek CEL function definition. + int weekday_num = static_cast(weekday); + weekday_num = (weekday_num == 6) ? 0 : weekday_num + 1; + return weekday_num; + }); +} + +Value GetHours(absl::Time timestamp, absl::string_view tz) { + return GetTimeBreakdownPart(timestamp, tz, + [](const absl::TimeZone::CivilInfo& breakdown) { + return breakdown.cs.hour(); + }); +} + +Value GetMinutes(absl::Time timestamp, absl::string_view tz) { + return GetTimeBreakdownPart(timestamp, tz, + [](const absl::TimeZone::CivilInfo& breakdown) { + return breakdown.cs.minute(); + }); +} + +Value GetSeconds(absl::Time timestamp, absl::string_view tz) { + return GetTimeBreakdownPart(timestamp, tz, + [](const absl::TimeZone::CivilInfo& breakdown) { + return breakdown.cs.second(); + }); +} + +Value GetMilliseconds(absl::Time timestamp, absl::string_view tz) { + return GetTimeBreakdownPart( + timestamp, tz, [](const absl::TimeZone::CivilInfo& breakdown) { + return absl::ToInt64Milliseconds(breakdown.subsecond); + }); +} + +absl::Status RegisterTimestampFunctions(FunctionRegistry& registry, + const RuntimeOptions& options) { + CEL_RETURN_IF_ERROR(registry.Register( + BinaryFunctionAdapter:: + CreateDescriptor(builtin::kFullYear, true), + BinaryFunctionAdapter:: + WrapFunction([](absl::Time ts, const StringValue& tz) -> Value { + return GetFullYear(ts, tz.ToString()); + }))); + + CEL_RETURN_IF_ERROR(registry.Register( + UnaryFunctionAdapter::CreateDescriptor( + builtin::kFullYear, true), + UnaryFunctionAdapter::WrapFunction( + [](absl::Time ts) -> Value { return GetFullYear(ts, ""); }))); + + CEL_RETURN_IF_ERROR(registry.Register( + BinaryFunctionAdapter:: + CreateDescriptor(builtin::kMonth, true), + BinaryFunctionAdapter:: + WrapFunction([](absl::Time ts, const StringValue& tz) -> Value { + return GetMonth(ts, tz.ToString()); + }))); + + CEL_RETURN_IF_ERROR(registry.Register( + UnaryFunctionAdapter::CreateDescriptor(builtin::kMonth, + true), + UnaryFunctionAdapter::WrapFunction( + [](absl::Time ts) -> Value { return GetMonth(ts, ""); }))); + + CEL_RETURN_IF_ERROR(registry.Register( + BinaryFunctionAdapter:: + CreateDescriptor(builtin::kDayOfYear, true), + BinaryFunctionAdapter:: + WrapFunction([](absl::Time ts, const StringValue& tz) -> Value { + return GetDayOfYear(ts, tz.ToString()); + }))); + + CEL_RETURN_IF_ERROR(registry.Register( + UnaryFunctionAdapter::CreateDescriptor( + builtin::kDayOfYear, true), + UnaryFunctionAdapter::WrapFunction( + [](absl::Time ts) -> Value { return GetDayOfYear(ts, ""); }))); + + CEL_RETURN_IF_ERROR(registry.Register( + BinaryFunctionAdapter:: + CreateDescriptor(builtin::kDayOfMonth, true), + BinaryFunctionAdapter:: + WrapFunction([](absl::Time ts, const StringValue& tz) -> Value { + return GetDayOfMonth(ts, tz.ToString()); + }))); + + CEL_RETURN_IF_ERROR(registry.Register( + UnaryFunctionAdapter::CreateDescriptor( + builtin::kDayOfMonth, true), + UnaryFunctionAdapter::WrapFunction( + [](absl::Time ts) -> Value { return GetDayOfMonth(ts, ""); }))); + + CEL_RETURN_IF_ERROR(registry.Register( + BinaryFunctionAdapter:: + CreateDescriptor(builtin::kDate, true), + BinaryFunctionAdapter:: + WrapFunction([](absl::Time ts, const StringValue& tz) -> Value { + return GetDate(ts, tz.ToString()); + }))); + + CEL_RETURN_IF_ERROR(registry.Register( + UnaryFunctionAdapter::CreateDescriptor(builtin::kDate, + true), + UnaryFunctionAdapter::WrapFunction( + [](absl::Time ts) -> Value { return GetDate(ts, ""); }))); + + CEL_RETURN_IF_ERROR(registry.Register( + BinaryFunctionAdapter:: + CreateDescriptor(builtin::kDayOfWeek, true), + BinaryFunctionAdapter:: + WrapFunction([](absl::Time ts, const StringValue& tz) -> Value { + return GetDayOfWeek(ts, tz.ToString()); + }))); + + CEL_RETURN_IF_ERROR(registry.Register( + UnaryFunctionAdapter::CreateDescriptor( + builtin::kDayOfWeek, true), + UnaryFunctionAdapter::WrapFunction( + [](absl::Time ts) -> Value { return GetDayOfWeek(ts, ""); }))); + + CEL_RETURN_IF_ERROR(registry.Register( + BinaryFunctionAdapter:: + CreateDescriptor(builtin::kHours, true), + BinaryFunctionAdapter:: + WrapFunction([](absl::Time ts, const StringValue& tz) -> Value { + return GetHours(ts, tz.ToString()); + }))); + + CEL_RETURN_IF_ERROR(registry.Register( + UnaryFunctionAdapter::CreateDescriptor(builtin::kHours, + true), + UnaryFunctionAdapter::WrapFunction( + [](absl::Time ts) -> Value { return GetHours(ts, ""); }))); + + CEL_RETURN_IF_ERROR(registry.Register( + BinaryFunctionAdapter:: + CreateDescriptor(builtin::kMinutes, true), + BinaryFunctionAdapter:: + WrapFunction([](absl::Time ts, const StringValue& tz) -> Value { + return GetMinutes(ts, tz.ToString()); + }))); + + CEL_RETURN_IF_ERROR(registry.Register( + UnaryFunctionAdapter::CreateDescriptor( + builtin::kMinutes, true), + UnaryFunctionAdapter::WrapFunction( + [](absl::Time ts) -> Value { return GetMinutes(ts, ""); }))); + + CEL_RETURN_IF_ERROR(registry.Register( + BinaryFunctionAdapter:: + CreateDescriptor(builtin::kSeconds, true), + BinaryFunctionAdapter:: + WrapFunction([](absl::Time ts, const StringValue& tz) -> Value { + return GetSeconds(ts, tz.ToString()); + }))); + + CEL_RETURN_IF_ERROR(registry.Register( + UnaryFunctionAdapter::CreateDescriptor( + builtin::kSeconds, true), + UnaryFunctionAdapter::WrapFunction( + [](absl::Time ts) -> Value { return GetSeconds(ts, ""); }))); + + CEL_RETURN_IF_ERROR(registry.Register( + BinaryFunctionAdapter:: + CreateDescriptor(builtin::kMilliseconds, true), + BinaryFunctionAdapter:: + WrapFunction([](absl::Time ts, const StringValue& tz) -> Value { + return GetMilliseconds(ts, tz.ToString()); + }))); + + return registry.Register( + UnaryFunctionAdapter::CreateDescriptor( + builtin::kMilliseconds, true), + UnaryFunctionAdapter::WrapFunction( + [](absl::Time ts) -> Value { return GetMilliseconds(ts, ""); })); +} + +absl::Status RegisterCheckedTimeArithmeticFunctions( + FunctionRegistry& registry) { + CEL_RETURN_IF_ERROR(registry.Register( + BinaryFunctionAdapter::CreateDescriptor(builtin::kAdd, + false), + BinaryFunctionAdapter, absl::Time, absl::Duration>:: + WrapFunction( + [](absl::Time t1, absl::Duration d2) -> absl::StatusOr { + auto sum = cel::internal::CheckedAdd(t1, d2); + if (!sum.ok()) { + return ErrorValue(sum.status()); + } + return TimestampValue(*sum); + }))); + + CEL_RETURN_IF_ERROR(registry.Register( + BinaryFunctionAdapter, absl::Duration, + absl::Time>::CreateDescriptor(builtin::kAdd, false), + BinaryFunctionAdapter, absl::Duration, absl::Time>:: + WrapFunction( + [](absl::Duration d2, absl::Time t1) -> absl::StatusOr { + auto sum = cel::internal::CheckedAdd(t1, d2); + if (!sum.ok()) { + return ErrorValue(sum.status()); + } + return TimestampValue(*sum); + }))); + + CEL_RETURN_IF_ERROR(registry.Register( + BinaryFunctionAdapter, absl::Duration, + absl::Duration>::CreateDescriptor(builtin::kAdd, + false), + BinaryFunctionAdapter< + absl::StatusOr, absl::Duration, + absl::Duration>::WrapFunction([](absl::Duration d1, absl::Duration d2) + -> absl::StatusOr { + auto sum = cel::internal::CheckedAdd(d1, d2); + if (!sum.ok()) { + return ErrorValue(sum.status()); + } + return DurationValue(*sum); + }))); + + CEL_RETURN_IF_ERROR(registry.Register( + BinaryFunctionAdapter, absl::Time, absl::Duration>:: + CreateDescriptor(builtin::kSubtract, false), + BinaryFunctionAdapter, absl::Time, absl::Duration>:: + WrapFunction( + [](absl::Time t1, absl::Duration d2) -> absl::StatusOr { + auto diff = cel::internal::CheckedSub(t1, d2); + if (!diff.ok()) { + return ErrorValue(diff.status()); + } + return TimestampValue(*diff); + }))); + + CEL_RETURN_IF_ERROR(registry.Register( + BinaryFunctionAdapter, absl::Time, + absl::Time>::CreateDescriptor(builtin::kSubtract, + false), + BinaryFunctionAdapter, absl::Time, absl::Time>:: + WrapFunction( + [](absl::Time t1, absl::Time t2) -> absl::StatusOr { + auto diff = cel::internal::CheckedSub(t1, t2); + if (!diff.ok()) { + return ErrorValue(diff.status()); + } + return DurationValue(*diff); + }))); + + CEL_RETURN_IF_ERROR(registry.Register( + BinaryFunctionAdapter< + absl::StatusOr, absl::Duration, + absl::Duration>::CreateDescriptor(builtin::kSubtract, false), + BinaryFunctionAdapter< + absl::StatusOr, absl::Duration, + absl::Duration>::WrapFunction([](absl::Duration d1, absl::Duration d2) + -> absl::StatusOr { + auto diff = cel::internal::CheckedSub(d1, d2); + if (!diff.ok()) { + return ErrorValue(diff.status()); + } + return DurationValue(*diff); + }))); + + return absl::OkStatus(); +} + +absl::Status RegisterUncheckedTimeArithmeticFunctions( + FunctionRegistry& registry) { + CEL_RETURN_IF_ERROR(registry.Register( + BinaryFunctionAdapter::CreateDescriptor(builtin::kAdd, + false), + BinaryFunctionAdapter::WrapFunction( + [](absl::Time t1, absl::Duration d2) -> Value { + return UnsafeTimestampValue(t1 + d2); + }))); + + CEL_RETURN_IF_ERROR(registry.Register( + BinaryFunctionAdapter::CreateDescriptor(builtin::kAdd, false), + BinaryFunctionAdapter::WrapFunction( + [](absl::Duration d2, absl::Time t1) -> Value { + return UnsafeTimestampValue(t1 + d2); + }))); + + CEL_RETURN_IF_ERROR(registry.Register( + BinaryFunctionAdapter::CreateDescriptor(builtin::kAdd, + false), + BinaryFunctionAdapter:: + WrapFunction([](absl::Duration d1, absl::Duration d2) -> Value { + return UnsafeDurationValue(d1 + d2); + }))); + + CEL_RETURN_IF_ERROR(registry.Register( + BinaryFunctionAdapter:: + CreateDescriptor(builtin::kSubtract, false), + + BinaryFunctionAdapter::WrapFunction( + + [](absl::Time t1, absl::Duration d2) -> Value { + return UnsafeTimestampValue(t1 - d2); + }))); + + CEL_RETURN_IF_ERROR(registry.Register( + BinaryFunctionAdapter::CreateDescriptor( + builtin::kSubtract, false), + BinaryFunctionAdapter::WrapFunction( + + [](absl::Time t1, absl::Time t2) -> Value { + return UnsafeDurationValue(t1 - t2); + }))); + + CEL_RETURN_IF_ERROR(registry.Register( + BinaryFunctionAdapter:: + CreateDescriptor(builtin::kSubtract, false), + BinaryFunctionAdapter:: + WrapFunction([](absl::Duration d1, absl::Duration d2) -> Value { + return UnsafeDurationValue(d1 - d2); + }))); + + return absl::OkStatus(); +} + +absl::Status RegisterDurationFunctions(FunctionRegistry& registry) { + // duration breakdown accessor functions + using DurationAccessorFunction = + UnaryFunctionAdapter; + CEL_RETURN_IF_ERROR(registry.Register( + DurationAccessorFunction::CreateDescriptor(builtin::kHours, true), + DurationAccessorFunction::WrapFunction( + [](absl::Duration d) -> int64_t { return absl::ToInt64Hours(d); }))); + + CEL_RETURN_IF_ERROR(registry.Register( + DurationAccessorFunction::CreateDescriptor(builtin::kMinutes, true), + DurationAccessorFunction::WrapFunction([](absl::Duration d) -> int64_t { + return absl::ToInt64Minutes(d); + }))); + + CEL_RETURN_IF_ERROR(registry.Register( + DurationAccessorFunction::CreateDescriptor(builtin::kSeconds, true), + DurationAccessorFunction::WrapFunction([](absl::Duration d) -> int64_t { + return absl::ToInt64Seconds(d); + }))); + + return registry.Register( + DurationAccessorFunction::CreateDescriptor(builtin::kMilliseconds, true), + DurationAccessorFunction::WrapFunction([](absl::Duration d) -> int64_t { + constexpr int64_t millis_per_second = 1000L; + return absl::ToInt64Milliseconds(d) % millis_per_second; + })); +} + +} // namespace + +absl::Status RegisterTimeFunctions(FunctionRegistry& registry, + const RuntimeOptions& options) { + CEL_RETURN_IF_ERROR(RegisterTimestampFunctions(registry, options)); + CEL_RETURN_IF_ERROR(RegisterDurationFunctions(registry)); + + // Special arithmetic operators for Timestamp and Duration + // TODO(uncreated-issue/37): deprecate unchecked time math functions when clients no + // longer depend on them. + if (options.enable_timestamp_duration_overflow_errors) { + return RegisterCheckedTimeArithmeticFunctions(registry); + } + + return RegisterUncheckedTimeArithmeticFunctions(registry); +} + +} // namespace cel diff --git a/runtime/standard/time_functions.h b/runtime/standard/time_functions.h new file mode 100644 index 000000000..d8fc2e875 --- /dev/null +++ b/runtime/standard/time_functions.h @@ -0,0 +1,56 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_TIME_FUNCTIONS_H_ +#define THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_TIME_FUNCTIONS_H_ + +#include "absl/status/status.h" +#include "runtime/function_registry.h" +#include "runtime/runtime_options.h" + +namespace cel { + +// Register builtin timestamp and duration functions: +// +// (timestamp).getFullYear() -> int +// (timestamp).getMonth() -> int +// (timestamp).getDayOfYear() -> int +// (timestamp).getDayOfMonth() -> int +// (timestamp).getDayOfWeek() -> int +// (timestamp).getDate() -> int +// (timestamp).getHours() -> int +// (timestamp).getMinutes() -> int +// (timestamp).getSeconds() -> int +// (timestamp).getMilliseconds() -> int +// +// (duration).getHours() -> int +// (duration).getMinutes() -> int +// (duration).getSeconds() -> int +// (duration).getMilliseconds() -> int +// +// _+_(timestamp, duration) -> timestamp +// _+_(duration, timestamp) -> timestamp +// _+_(duration, duration) -> duration +// _-_(timestamp, timestamp) -> duration +// _-_(timestamp, duration) -> timestamp +// _-_(duration, duration) -> duration +// +// Most users should use RegisterBuiltinFunctions, which includes these +// definitions. +absl::Status RegisterTimeFunctions(FunctionRegistry& registry, + const RuntimeOptions& options); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_TIME_FUNCTIONS_H_ diff --git a/runtime/standard/time_functions_test.cc b/runtime/standard/time_functions_test.cc new file mode 100644 index 000000000..f578a1023 --- /dev/null +++ b/runtime/standard/time_functions_test.cc @@ -0,0 +1,150 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "runtime/standard/time_functions.h" + +#include + +#include "base/builtins.h" +#include "common/function_descriptor.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +using ::testing::UnorderedElementsAre; + +MATCHER_P3(MatchesOperatorDescriptor, name, expected_kind1, expected_kind2, + "") { + const FunctionDescriptor& descriptor = *arg; + std::vector types{expected_kind1, expected_kind2}; + return descriptor.name() == name && descriptor.receiver_style() == false && + descriptor.types() == types; +} + +MATCHER_P2(MatchesTimeAccessor, name, kind, "") { + const FunctionDescriptor& descriptor = *arg; + + std::vector types{kind}; + return descriptor.name() == name && descriptor.receiver_style() == true && + descriptor.types() == types; +} + +MATCHER_P2(MatchesTimezoneTimeAccessor, name, kind, "") { + const FunctionDescriptor& descriptor = *arg; + + std::vector types{kind, Kind::kString}; + return descriptor.name() == name && descriptor.receiver_style() == true && + descriptor.types() == types; +} + +TEST(RegisterTimeFunctions, MathOperatorsRegistered) { + FunctionRegistry registry; + RuntimeOptions options; + + ASSERT_OK(RegisterTimeFunctions(registry, options)); + + auto registered_functions = registry.ListFunctions(); + + EXPECT_THAT(registered_functions[builtin::kAdd], + UnorderedElementsAre( + MatchesOperatorDescriptor(builtin::kAdd, Kind::kDuration, + Kind::kDuration), + MatchesOperatorDescriptor(builtin::kAdd, Kind::kTimestamp, + Kind::kDuration), + MatchesOperatorDescriptor(builtin::kAdd, Kind::kDuration, + Kind::kTimestamp))); + + EXPECT_THAT(registered_functions[builtin::kSubtract], + UnorderedElementsAre( + MatchesOperatorDescriptor(builtin::kSubtract, Kind::kDuration, + Kind::kDuration), + MatchesOperatorDescriptor(builtin::kSubtract, + Kind::kTimestamp, Kind::kDuration), + MatchesOperatorDescriptor( + builtin::kSubtract, Kind::kTimestamp, Kind::kTimestamp))); +} + +TEST(RegisterTimeFunctions, AccessorsRegistered) { + FunctionRegistry registry; + RuntimeOptions options; + + ASSERT_OK(RegisterTimeFunctions(registry, options)); + + auto registered_functions = registry.ListFunctions(); + EXPECT_THAT( + registered_functions[builtin::kFullYear], + UnorderedElementsAre( + MatchesTimeAccessor(builtin::kFullYear, Kind::kTimestamp), + MatchesTimezoneTimeAccessor(builtin::kFullYear, Kind::kTimestamp))); + EXPECT_THAT( + registered_functions[builtin::kDate], + UnorderedElementsAre( + MatchesTimeAccessor(builtin::kDate, Kind::kTimestamp), + MatchesTimezoneTimeAccessor(builtin::kDate, Kind::kTimestamp))); + EXPECT_THAT( + registered_functions[builtin::kMonth], + UnorderedElementsAre( + MatchesTimeAccessor(builtin::kMonth, Kind::kTimestamp), + MatchesTimezoneTimeAccessor(builtin::kMonth, Kind::kTimestamp))); + EXPECT_THAT( + registered_functions[builtin::kDayOfYear], + UnorderedElementsAre( + MatchesTimeAccessor(builtin::kDayOfYear, Kind::kTimestamp), + MatchesTimezoneTimeAccessor(builtin::kDayOfYear, Kind::kTimestamp))); + EXPECT_THAT( + registered_functions[builtin::kDayOfMonth], + UnorderedElementsAre( + MatchesTimeAccessor(builtin::kDayOfMonth, Kind::kTimestamp), + MatchesTimezoneTimeAccessor(builtin::kDayOfMonth, Kind::kTimestamp))); + EXPECT_THAT( + registered_functions[builtin::kDayOfWeek], + UnorderedElementsAre( + MatchesTimeAccessor(builtin::kDayOfWeek, Kind::kTimestamp), + MatchesTimezoneTimeAccessor(builtin::kDayOfWeek, Kind::kTimestamp))); + + EXPECT_THAT( + registered_functions[builtin::kHours], + UnorderedElementsAre( + MatchesTimeAccessor(builtin::kHours, Kind::kTimestamp), + MatchesTimezoneTimeAccessor(builtin::kHours, Kind::kTimestamp), + MatchesTimeAccessor(builtin::kHours, Kind::kDuration))); + + EXPECT_THAT( + registered_functions[builtin::kMinutes], + UnorderedElementsAre( + MatchesTimeAccessor(builtin::kMinutes, Kind::kTimestamp), + MatchesTimezoneTimeAccessor(builtin::kMinutes, Kind::kTimestamp), + MatchesTimeAccessor(builtin::kMinutes, Kind::kDuration))); + + EXPECT_THAT( + registered_functions[builtin::kSeconds], + UnorderedElementsAre( + MatchesTimeAccessor(builtin::kSeconds, Kind::kTimestamp), + MatchesTimezoneTimeAccessor(builtin::kSeconds, Kind::kTimestamp), + MatchesTimeAccessor(builtin::kSeconds, Kind::kDuration))); + + EXPECT_THAT( + registered_functions[builtin::kMilliseconds], + UnorderedElementsAre( + MatchesTimeAccessor(builtin::kMilliseconds, Kind::kTimestamp), + MatchesTimezoneTimeAccessor(builtin::kMilliseconds, Kind::kTimestamp), + MatchesTimeAccessor(builtin::kMilliseconds, Kind::kDuration))); +} + +// TODO(uncreated-issue/41): move functional parsed expr tests when modern APIs for +// evaluator available. + +} // namespace +} // namespace cel diff --git a/runtime/standard/type_conversion_functions.cc b/runtime/standard/type_conversion_functions.cc new file mode 100644 index 000000000..76e95751b --- /dev/null +++ b/runtime/standard/type_conversion_functions.cc @@ -0,0 +1,470 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "runtime/standard/type_conversion_functions.h" + +#include +#include +#include // NOLINT (required for std::to_chars_result) +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "base/builtins.h" +#include "base/function_adapter.h" +#include "common/value.h" +#include "internal/overflow.h" +#include "internal/status_macros.h" +#include "internal/time.h" +#include "internal/utf8.h" +#include "runtime/function.h" +#include "runtime/function_registry.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" + +#if defined(_LIBCPP_VERSION) && _LIBCPP_VERSION >= 14000 && \ + !defined(__APPLE__) || \ + defined(__GNUC__) && __GNUC__ >= 13 || \ + defined(_MSC_VER) && _MSC_VER >= 1920 +#define _CEL_CHAR_CONV_DOUBLE_TO_CHARS 1 +#endif + +namespace cel { +namespace { + +using ::cel::internal::EncodeDurationToJson; +using ::cel::internal::EncodeTimestampToJson; +using ::cel::internal::MaxTimestamp; +using ::cel::internal::MinTimestamp; + +Value FormatDouble(double v, const Function::InvokeContext& context) { + google::protobuf::Arena* arena = context.arena(); +#if defined(CEL_NO_CHARCONV_DOUBLE_TO_CHARS) || \ + !defined(_CEL_CHAR_CONV_DOUBLE_TO_CHARS) + // Fallback to absl::StrFormat. Slower and handles edge cases around precision + // differently but safe and covers most cases. + return StringValue::From(absl::StrFormat("%.17g", v), arena); +#else + constexpr int kBufSize = 32; + char buf[kBufSize]; + std::to_chars_result result = + std::to_chars(buf, buf + kBufSize, v, std::chars_format::general); + if (result.ec != std::errc()) { + return cel::ErrorValue(absl::InvalidArgumentError(absl::StrCat( + "double format error: ", std::make_error_code(result.ec).message()))); + } + absl::string_view out(buf, result.ptr - buf); + return StringValue::From(out, arena); +#endif +} + +Value LegacyFormatDouble(double v, const Function::InvokeContext& context) { + return StringValue::From(absl::StrCat(v), context.arena()); +} + +absl::Status RegisterBoolConversionFunctions(FunctionRegistry& registry, + const RuntimeOptions&) { + // bool -> bool + absl::Status status = + UnaryFunctionAdapter::RegisterGlobalOverload( + cel::builtin::kBool, [](bool v) { return v; }, registry); + CEL_RETURN_IF_ERROR(status); + + // string -> bool + return UnaryFunctionAdapter::RegisterGlobalOverload( + cel::builtin::kBool, + [](const StringValue& v) -> Value { + if ((v == "true") || (v == "True") || (v == "TRUE") || (v == "t") || + (v == "1")) { + return TrueValue(); + } else if ((v == "false") || (v == "FALSE") || (v == "False") || + (v == "f") || (v == "0")) { + return FalseValue(); + } else { + return ErrorValue(absl::InvalidArgumentError( + "Type conversion error from 'string' to 'bool'")); + } + }, + registry); +} + +absl::Status RegisterIntConversionFunctions(FunctionRegistry& registry, + const RuntimeOptions&) { + // bool -> int + absl::Status status = + UnaryFunctionAdapter::RegisterGlobalOverload( + cel::builtin::kInt, [](bool v) { return static_cast(v); }, + registry); + CEL_RETURN_IF_ERROR(status); + + // double -> int + status = UnaryFunctionAdapter::RegisterGlobalOverload( + cel::builtin::kInt, + [](double v) -> Value { + auto conv = cel::internal::CheckedDoubleToInt64(v); + if (!conv.ok()) { + return ErrorValue(conv.status()); + } + return IntValue(*conv); + }, + registry); + CEL_RETURN_IF_ERROR(status); + + // int -> int + status = UnaryFunctionAdapter::RegisterGlobalOverload( + cel::builtin::kInt, [](int64_t v) { return v; }, registry); + CEL_RETURN_IF_ERROR(status); + + // string -> int + status = + UnaryFunctionAdapter::RegisterGlobalOverload( + cel::builtin::kInt, + [](const StringValue& s) -> Value { + int64_t result; + if (!absl::SimpleAtoi(s.ToString(), &result)) { + return ErrorValue( + absl::InvalidArgumentError("cannot convert string to int")); + } + return IntValue(result); + }, + registry); + CEL_RETURN_IF_ERROR(status); + + // time -> int + status = UnaryFunctionAdapter::RegisterGlobalOverload( + cel::builtin::kInt, [](absl::Time t) { return absl::ToUnixSeconds(t); }, + registry); + CEL_RETURN_IF_ERROR(status); + + // uint -> int + return UnaryFunctionAdapter::RegisterGlobalOverload( + cel::builtin::kInt, + [](uint64_t v) -> Value { + auto conv = cel::internal::CheckedUint64ToInt64(v); + if (!conv.ok()) { + return ErrorValue(conv.status()); + } + return IntValue(*conv); + }, + registry); +} + +absl::Status RegisterStringConversionFunctions(FunctionRegistry& registry, + const RuntimeOptions& options) { + // May be optionally disabled to reduce potential allocs. + if (!options.enable_string_conversion) { + return absl::OkStatus(); + } + + absl::Status status = + UnaryFunctionAdapter::RegisterGlobalOverload( + cel::builtin::kString, + + [](const BytesValue& value) -> Value { + auto valid = value.NativeValue([](const auto& value) -> bool { + return internal::Utf8IsValid(value); + }); + if (!valid) { + return ErrorValue( + absl::InvalidArgumentError("malformed UTF-8 bytes")); + } + return StringValue(value.ToString()); + }, + registry); + CEL_RETURN_IF_ERROR(status); + + // bool -> string + status = UnaryFunctionAdapter::RegisterGlobalOverload( + cel::builtin::kString, + [](bool value) -> StringValue { + return StringValue(value ? "true" : "false"); + }, + registry); + CEL_RETURN_IF_ERROR(status); + + // double -> string + status = UnaryFunctionAdapter::RegisterGlobalOverload( + cel::builtin::kString, + (options.enable_precision_preserving_double_format ? &FormatDouble + : &LegacyFormatDouble), + registry); + CEL_RETURN_IF_ERROR(status); + + // int -> string + status = UnaryFunctionAdapter::RegisterGlobalOverload( + cel::builtin::kString, + [](int64_t value) -> StringValue { + return StringValue(absl::StrCat(value)); + }, + registry); + CEL_RETURN_IF_ERROR(status); + + // string -> string + status = + UnaryFunctionAdapter::RegisterGlobalOverload( + cel::builtin::kString, + [](StringValue value) -> StringValue { return value; }, registry); + CEL_RETURN_IF_ERROR(status); + + // uint -> string + status = UnaryFunctionAdapter::RegisterGlobalOverload( + cel::builtin::kString, + [](uint64_t value) -> StringValue { + return StringValue(absl::StrCat(value)); + }, + registry); + CEL_RETURN_IF_ERROR(status); + + // duration -> string + status = UnaryFunctionAdapter::RegisterGlobalOverload( + cel::builtin::kString, + [](absl::Duration value) -> Value { + auto encode = EncodeDurationToJson(value); + if (!encode.ok()) { + return ErrorValue(encode.status()); + } + return StringValue(*encode); + }, + registry); + CEL_RETURN_IF_ERROR(status); + + // timestamp -> string + return UnaryFunctionAdapter::RegisterGlobalOverload( + cel::builtin::kString, + [](absl::Time value) -> Value { + auto encode = EncodeTimestampToJson(value); + if (!encode.ok()) { + return ErrorValue(encode.status()); + } + return StringValue(*encode); + }, + registry); +} + +absl::Status RegisterUintConversionFunctions(FunctionRegistry& registry, + const RuntimeOptions&) { + // double -> uint + absl::Status status = + UnaryFunctionAdapter::RegisterGlobalOverload( + cel::builtin::kUint, + [](double v) -> Value { + auto conv = cel::internal::CheckedDoubleToUint64(v); + if (!conv.ok()) { + return ErrorValue(conv.status()); + } + return UintValue(*conv); + }, + registry); + CEL_RETURN_IF_ERROR(status); + + // int -> uint + status = UnaryFunctionAdapter::RegisterGlobalOverload( + cel::builtin::kUint, + [](int64_t v) -> Value { + auto conv = cel::internal::CheckedInt64ToUint64(v); + if (!conv.ok()) { + return ErrorValue(conv.status()); + } + return UintValue(*conv); + }, + registry); + CEL_RETURN_IF_ERROR(status); + + // string -> uint + status = + UnaryFunctionAdapter::RegisterGlobalOverload( + cel::builtin::kUint, + [](const StringValue& s) -> Value { + uint64_t result; + if (!absl::SimpleAtoi(s.ToString(), &result)) { + return ErrorValue( + absl::InvalidArgumentError("cannot convert string to uint")); + } + return UintValue(result); + }, + registry); + CEL_RETURN_IF_ERROR(status); + + // uint -> uint + return UnaryFunctionAdapter::RegisterGlobalOverload( + cel::builtin::kUint, [](uint64_t v) { return v; }, registry); +} + +absl::Status RegisterBytesConversionFunctions(FunctionRegistry& registry, + const RuntimeOptions&) { + // bytes -> bytes + absl::Status status = + UnaryFunctionAdapter::RegisterGlobalOverload( + cel::builtin::kBytes, + + [](BytesValue value) -> BytesValue { return value; }, registry); + CEL_RETURN_IF_ERROR(status); + + // string -> bytes + return UnaryFunctionAdapter, const StringValue&>:: + RegisterGlobalOverload( + cel::builtin::kBytes, + [](const StringValue& value) { return BytesValue(value.ToString()); }, + registry); +} + +absl::Status RegisterDoubleConversionFunctions(FunctionRegistry& registry, + const RuntimeOptions&) { + // double -> double + absl::Status status = + UnaryFunctionAdapter::RegisterGlobalOverload( + cel::builtin::kDouble, [](double v) { return v; }, registry); + CEL_RETURN_IF_ERROR(status); + + // int -> double + status = UnaryFunctionAdapter::RegisterGlobalOverload( + cel::builtin::kDouble, [](int64_t v) { return static_cast(v); }, + registry); + CEL_RETURN_IF_ERROR(status); + + // string -> double + status = + UnaryFunctionAdapter::RegisterGlobalOverload( + cel::builtin::kDouble, + [](const StringValue& s) -> Value { + double result; + if (absl::SimpleAtod(s.ToString(), &result)) { + return DoubleValue(result); + } else { + return ErrorValue(absl::InvalidArgumentError( + "cannot convert string to double")); + } + }, + registry); + CEL_RETURN_IF_ERROR(status); + + // uint -> double + return UnaryFunctionAdapter::RegisterGlobalOverload( + cel::builtin::kDouble, [](uint64_t v) { return static_cast(v); }, + registry); +} + +Value CreateDurationFromString(const StringValue& dur_str) { + absl::Duration d; + if (!absl::ParseDuration(dur_str.ToString(), &d)) { + return ErrorValue( + absl::InvalidArgumentError("String to Duration conversion failed")); + } + + auto status = internal::ValidateDuration(d); + if (!status.ok()) { + return ErrorValue(std::move(status)); + } + return DurationValue(d); +} + +absl::Status RegisterTimeConversionFunctions(FunctionRegistry& registry, + const RuntimeOptions& options) { + // duration() conversion from string. + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + cel::builtin::kDuration, CreateDurationFromString, registry))); + + bool enable_timestamp_duration_overflow_errors = + options.enable_timestamp_duration_overflow_errors; + + // timestamp conversion from int. + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + cel::builtin::kTimestamp, + [=](int64_t epoch_seconds) -> Value { + absl::Time ts = absl::FromUnixSeconds(epoch_seconds); + if (enable_timestamp_duration_overflow_errors) { + if (ts < MinTimestamp() || ts > MaxTimestamp()) { + return ErrorValue(absl::OutOfRangeError("timestamp overflow")); + } + } + return UnsafeTimestampValue(ts); + }, + registry))); + + // timestamp -> timestamp + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + cel::builtin::kTimestamp, + [](absl::Time value) -> Value { return TimestampValue(value); }, + registry))); + + // duration -> duration + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + cel::builtin::kDuration, + [](absl::Duration value) -> Value { return DurationValue(value); }, + registry))); + + // timestamp() conversion from string. + return UnaryFunctionAdapter:: + RegisterGlobalOverload( + cel::builtin::kTimestamp, + [=](const StringValue& time_str) -> Value { + absl::Time ts; + if (!absl::ParseTime(absl::RFC3339_full, time_str.ToString(), &ts, + nullptr)) { + return ErrorValue(absl::InvalidArgumentError( + "String to Timestamp conversion failed")); + } + if (enable_timestamp_duration_overflow_errors) { + if (ts < MinTimestamp() || ts > MaxTimestamp()) { + return ErrorValue(absl::OutOfRangeError("timestamp overflow")); + } + } + return UnsafeTimestampValue(ts); + }, + registry); +} + +} // namespace + +absl::Status RegisterTypeConversionFunctions(FunctionRegistry& registry, + const RuntimeOptions& options) { + CEL_RETURN_IF_ERROR(RegisterBoolConversionFunctions(registry, options)); + + CEL_RETURN_IF_ERROR(RegisterBytesConversionFunctions(registry, options)); + + CEL_RETURN_IF_ERROR(RegisterDoubleConversionFunctions(registry, options)); + + CEL_RETURN_IF_ERROR(RegisterIntConversionFunctions(registry, options)); + + CEL_RETURN_IF_ERROR(RegisterStringConversionFunctions(registry, options)); + + CEL_RETURN_IF_ERROR(RegisterUintConversionFunctions(registry, options)); + + CEL_RETURN_IF_ERROR(RegisterTimeConversionFunctions(registry, options)); + + // dyn() identity function. + // TODO(issues/102): strip dyn() function references at type-check time. + absl::Status status = + UnaryFunctionAdapter::RegisterGlobalOverload( + cel::builtin::kDyn, [](const Value& value) -> Value { return value; }, + registry); + CEL_RETURN_IF_ERROR(status); + + // type(dyn) -> type + return UnaryFunctionAdapter::RegisterGlobalOverload( + cel::builtin::kType, + [](const Value& value) { return TypeValue(value.GetRuntimeType()); }, + registry); +} + +} // namespace cel diff --git a/runtime/standard/type_conversion_functions.h b/runtime/standard/type_conversion_functions.h new file mode 100644 index 000000000..77b07e4dc --- /dev/null +++ b/runtime/standard/type_conversion_functions.h @@ -0,0 +1,34 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_TYPE_CONVERSION_FUNCTIONS_H_ +#define THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_TYPE_CONVERSION_FUNCTIONS_H_ + +#include "absl/status/status.h" +#include "runtime/function_registry.h" +#include "runtime/runtime_options.h" + +namespace cel { + +// Register builtin type conversion functions: +// dyn, int, uint, double, timestamp, duration, string, bytes, type +// +// Most users should use RegisterBuiltinFunctions, which includes these +// definitions. +absl::Status RegisterTypeConversionFunctions(FunctionRegistry& registry, + const RuntimeOptions& options); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_TYPE_CONVERSION_FUNCTIONS_H_ diff --git a/runtime/standard/type_conversion_functions_test.cc b/runtime/standard/type_conversion_functions_test.cc new file mode 100644 index 000000000..ece8d454f --- /dev/null +++ b/runtime/standard/type_conversion_functions_test.cc @@ -0,0 +1,183 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "runtime/standard/type_conversion_functions.h" + +#include + +#include "base/builtins.h" +#include "common/function_descriptor.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +using ::testing::IsEmpty; +using ::testing::UnorderedElementsAre; + +MATCHER_P3(MatchesUnaryDescriptor, name, receiver, expected_kind, "") { + const FunctionDescriptor& descriptor = arg.descriptor; + std::vector types{expected_kind}; + return descriptor.name() == name && descriptor.receiver_style() == receiver && + descriptor.types() == types; +} + +TEST(RegisterTypeConversionFunctions, RegisterBoolConversionFunctions) { + FunctionRegistry registry; + RuntimeOptions options; + + ASSERT_OK(RegisterTypeConversionFunctions(registry, options)); + + EXPECT_THAT( + registry.FindStaticOverloads(builtin::kBool, false, {Kind::kAny}), + UnorderedElementsAre( + MatchesUnaryDescriptor(builtin::kBool, false, Kind::kBool), + MatchesUnaryDescriptor(builtin::kBool, false, Kind::kString))); +} + +TEST(RegisterTypeConversionFunctions, RegisterIntConversionFunctions) { + FunctionRegistry registry; + RuntimeOptions options; + + ASSERT_OK(RegisterTypeConversionFunctions(registry, options)); + + EXPECT_THAT( + registry.FindStaticOverloads(builtin::kInt, false, {Kind::kAny}), + UnorderedElementsAre( + MatchesUnaryDescriptor(builtin::kInt, false, Kind::kInt), + MatchesUnaryDescriptor(builtin::kInt, false, Kind::kDouble), + MatchesUnaryDescriptor(builtin::kInt, false, Kind::kUint), + MatchesUnaryDescriptor(builtin::kInt, false, Kind::kBool), + MatchesUnaryDescriptor(builtin::kInt, false, Kind::kString), + MatchesUnaryDescriptor(builtin::kInt, false, Kind::kTimestamp))); +} + +TEST(RegisterTypeConversionFunctions, RegisterUintConversionFunctions) { + FunctionRegistry registry; + RuntimeOptions options; + + ASSERT_OK(RegisterTypeConversionFunctions(registry, options)); + + EXPECT_THAT( + registry.FindStaticOverloads(builtin::kUint, false, {Kind::kAny}), + UnorderedElementsAre( + MatchesUnaryDescriptor(builtin::kUint, false, Kind::kInt), + MatchesUnaryDescriptor(builtin::kUint, false, Kind::kDouble), + MatchesUnaryDescriptor(builtin::kUint, false, Kind::kUint), + MatchesUnaryDescriptor(builtin::kUint, false, Kind::kString))); +} + +TEST(RegisterTypeConversionFunctions, RegisterDoubleConversionFunctions) { + FunctionRegistry registry; + RuntimeOptions options; + + ASSERT_OK(RegisterTypeConversionFunctions(registry, options)); + + EXPECT_THAT( + registry.FindStaticOverloads(builtin::kDouble, false, {Kind::kAny}), + UnorderedElementsAre( + MatchesUnaryDescriptor(builtin::kDouble, false, Kind::kInt), + MatchesUnaryDescriptor(builtin::kDouble, false, Kind::kDouble), + MatchesUnaryDescriptor(builtin::kDouble, false, Kind::kUint), + MatchesUnaryDescriptor(builtin::kDouble, false, Kind::kString))); +} + +TEST(RegisterTypeConversionFunctions, RegisterStringConversionFunctions) { + FunctionRegistry registry; + RuntimeOptions options; + + options.enable_string_conversion = true; + + ASSERT_OK(RegisterTypeConversionFunctions(registry, options)); + + EXPECT_THAT( + registry.FindStaticOverloads(builtin::kString, false, {Kind::kAny}), + UnorderedElementsAre( + MatchesUnaryDescriptor(builtin::kString, false, Kind::kBool), + MatchesUnaryDescriptor(builtin::kString, false, Kind::kInt), + MatchesUnaryDescriptor(builtin::kString, false, Kind::kDouble), + MatchesUnaryDescriptor(builtin::kString, false, Kind::kUint), + MatchesUnaryDescriptor(builtin::kString, false, Kind::kString), + MatchesUnaryDescriptor(builtin::kString, false, Kind::kBytes), + MatchesUnaryDescriptor(builtin::kString, false, Kind::kDuration), + MatchesUnaryDescriptor(builtin::kString, false, Kind::kTimestamp))); +} + +TEST(RegisterTypeConversionFunctions, + RegisterStringConversionFunctionsDisabled) { + FunctionRegistry registry; + RuntimeOptions options; + options.enable_string_conversion = false; + + ASSERT_OK(RegisterTypeConversionFunctions(registry, options)); + + EXPECT_THAT( + registry.FindStaticOverloads(builtin::kString, false, {Kind::kAny}), + IsEmpty()); +} + +TEST(RegisterTypeConversionFunctions, RegisterBytesConversionFunctions) { + FunctionRegistry registry; + RuntimeOptions options; + + ASSERT_OK(RegisterTypeConversionFunctions(registry, options)); + + EXPECT_THAT( + registry.FindStaticOverloads(builtin::kBytes, false, {Kind::kAny}), + UnorderedElementsAre( + MatchesUnaryDescriptor(builtin::kBytes, false, Kind::kBytes), + MatchesUnaryDescriptor(builtin::kBytes, false, Kind::kString))); +} + +TEST(RegisterTypeConversionFunctions, RegisterTimeConversionFunctions) { + FunctionRegistry registry; + RuntimeOptions options; + + ASSERT_OK(RegisterTypeConversionFunctions(registry, options)); + + EXPECT_THAT( + registry.FindStaticOverloads(builtin::kTimestamp, false, {Kind::kAny}), + UnorderedElementsAre( + MatchesUnaryDescriptor(builtin::kTimestamp, false, Kind::kInt), + MatchesUnaryDescriptor(builtin::kTimestamp, false, Kind::kString), + MatchesUnaryDescriptor(builtin::kTimestamp, false, + Kind::kTimestamp))); + + EXPECT_THAT( + registry.FindStaticOverloads(builtin::kDuration, false, {Kind::kAny}), + UnorderedElementsAre( + MatchesUnaryDescriptor(builtin::kDuration, false, Kind::kString), + MatchesUnaryDescriptor(builtin::kDuration, false, Kind::kDuration))); +} + +TEST(RegisterTypeConversionFunctions, RegisterMetaTypeConversionFunctions) { + FunctionRegistry registry; + RuntimeOptions options; + + ASSERT_OK(RegisterTypeConversionFunctions(registry, options)); + + EXPECT_THAT(registry.FindStaticOverloads(builtin::kDyn, false, {Kind::kAny}), + UnorderedElementsAre( + MatchesUnaryDescriptor(builtin::kDyn, false, Kind::kAny))); + + EXPECT_THAT(registry.FindStaticOverloads(builtin::kType, false, {Kind::kAny}), + UnorderedElementsAre( + MatchesUnaryDescriptor(builtin::kType, false, Kind::kAny))); +} + +// TODO(uncreated-issue/41): move functional parsed expr tests when modern APIs for +// evaluator available. + +} // namespace +} // namespace cel diff --git a/runtime/standard_functions.cc b/runtime/standard_functions.cc new file mode 100644 index 000000000..320654ff6 --- /dev/null +++ b/runtime/standard_functions.cc @@ -0,0 +1,49 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "runtime/standard_functions.h" + +#include "absl/status/status.h" +#include "internal/status_macros.h" +#include "runtime/function_registry.h" +#include "runtime/runtime_options.h" +#include "runtime/standard/arithmetic_functions.h" +#include "runtime/standard/comparison_functions.h" +#include "runtime/standard/container_functions.h" +#include "runtime/standard/container_membership_functions.h" +#include "runtime/standard/equality_functions.h" +#include "runtime/standard/logical_functions.h" +#include "runtime/standard/regex_functions.h" +#include "runtime/standard/string_functions.h" +#include "runtime/standard/time_functions.h" +#include "runtime/standard/type_conversion_functions.h" + +namespace cel { + +absl::Status RegisterStandardFunctions(FunctionRegistry& registry, + const RuntimeOptions& options) { + CEL_RETURN_IF_ERROR(RegisterArithmeticFunctions(registry, options)); + CEL_RETURN_IF_ERROR(RegisterComparisonFunctions(registry, options)); + CEL_RETURN_IF_ERROR(RegisterContainerFunctions(registry, options)); + CEL_RETURN_IF_ERROR(RegisterContainerMembershipFunctions(registry, options)); + CEL_RETURN_IF_ERROR(RegisterLogicalFunctions(registry, options)); + CEL_RETURN_IF_ERROR(RegisterRegexFunctions(registry, options)); + CEL_RETURN_IF_ERROR(RegisterStringFunctions(registry, options)); + CEL_RETURN_IF_ERROR(RegisterTimeFunctions(registry, options)); + CEL_RETURN_IF_ERROR(RegisterEqualityFunctions(registry, options)); + + return RegisterTypeConversionFunctions(registry, options); +} + +} // namespace cel diff --git a/runtime/standard_functions.h b/runtime/standard_functions.h new file mode 100644 index 000000000..c01c4fb85 --- /dev/null +++ b/runtime/standard_functions.h @@ -0,0 +1,33 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_FUNCTIONS_H_ +#define THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_FUNCTIONS_H_ + +#include "absl/status/status.h" +#include "runtime/function_registry.h" +#include "runtime/runtime_options.h" + +namespace cel { + +// Register all CEL standard definitions. +// +// See +// https://github.com/google/cel-spec/blob/master/doc/langdef.md#standard-definitions +absl::Status RegisterStandardFunctions(FunctionRegistry& registry, + const RuntimeOptions& options); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_FUNCTIONS_H_ diff --git a/runtime/standard_runtime_builder_factory.cc b/runtime/standard_runtime_builder_factory.cc new file mode 100644 index 000000000..65adf2f5a --- /dev/null +++ b/runtime/standard_runtime_builder_factory.cc @@ -0,0 +1,55 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "runtime/standard_runtime_builder_factory.h" + +#include +#include + +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/status/statusor.h" +#include "internal/noop_delete.h" +#include "internal/status_macros.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_builder_factory.h" +#include "runtime/runtime_options.h" +#include "runtime/standard_functions.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +absl::StatusOr CreateStandardRuntimeBuilder( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + const RuntimeOptions& options) { + ABSL_DCHECK(descriptor_pool != nullptr); + return CreateStandardRuntimeBuilder( + std::shared_ptr( + descriptor_pool, + internal::NoopDeleteFor()), + options); +} + +absl::StatusOr CreateStandardRuntimeBuilder( + absl_nonnull std::shared_ptr descriptor_pool, + const RuntimeOptions& options) { + ABSL_DCHECK(descriptor_pool != nullptr); + CEL_ASSIGN_OR_RETURN( + auto builder, CreateRuntimeBuilder(std::move(descriptor_pool), options)); + CEL_RETURN_IF_ERROR( + RegisterStandardFunctions(builder.function_registry(), options)); + return builder; +} + +} // namespace cel diff --git a/runtime/standard_runtime_builder_factory.h b/runtime/standard_runtime_builder_factory.h new file mode 100644 index 000000000..b20423e5e --- /dev/null +++ b/runtime/standard_runtime_builder_factory.h @@ -0,0 +1,43 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_RUNTIME_BUILDER_FACTORY_H_ +#define THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_RUNTIME_BUILDER_FACTORY_H_ + +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/status/statusor.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +// Create a builder preconfigured with CEL standard definitions. +// +// See `CreateRuntimeBuilder` for a description of the requirements related to +// `descriptor_pool`. +absl::StatusOr CreateStandardRuntimeBuilder( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool + ABSL_ATTRIBUTE_LIFETIME_BOUND, + const RuntimeOptions& options); +absl::StatusOr CreateStandardRuntimeBuilder( + absl_nonnull std::shared_ptr descriptor_pool, + const RuntimeOptions& options); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_RUNTIME_BUILDER_FACTORY_H_ diff --git a/runtime/standard_runtime_builder_factory_test.cc b/runtime/standard_runtime_builder_factory_test.cc new file mode 100644 index 000000000..029897233 --- /dev/null +++ b/runtime/standard_runtime_builder_factory_test.cc @@ -0,0 +1,872 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "runtime/standard_runtime_builder_factory.h" + +#include +#include +#include +#include +#include + +#include "cel/expr/syntax.pb.h" +#include "absl/base/no_destructor.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "base/builtins.h" +#include "common/source.h" +#include "common/value.h" +#include "common/value_testing.h" +#include "extensions/bindings_ext.h" +#include "extensions/protobuf/runtime_adapter.h" +#include "internal/testing.h" +#include "parser/macro_registry.h" +#include "parser/parser.h" +#include "parser/standard_macros.h" +#include "runtime/activation.h" +#include "runtime/internal/runtime_impl.h" +#include "runtime/runtime.h" +#include "runtime/runtime_issue.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::cel::extensions::ProtobufRuntimeAdapter; +using ::cel::test::BoolValueIs; +using ::cel::test::IntValueIs; +using ::cel::expr::ParsedExpr; +using ::google::api::expr::parser::Parse; +using ::testing::ElementsAre; +using ::testing::HasSubstr; +using ::testing::TestWithParam; +using ::testing::Truly; + +const cel::MacroRegistry& GetMacros() { + static absl::NoDestructor macros([]() { + MacroRegistry registry; + ABSL_CHECK_OK(cel::RegisterStandardMacros(registry, {})); + for (const auto& macro : extensions::bindings_macros()) { + ABSL_CHECK_OK(registry.RegisterMacro(macro)); + } + return registry; + }()); + return *macros; +} + +absl::StatusOr ParseWithTestMacros(absl::string_view expression) { + auto src = cel::NewSource(expression, ""); + ABSL_CHECK_OK(src.status()); + return Parse(**src, GetMacros()); +} + +TEST(StandardRuntimeTest, RecursionLimitExceeded) { + RuntimeOptions opts; + opts.max_recursion_depth = 1; + + ASSERT_OK_AND_ASSIGN(auto builder, + CreateStandardRuntimeBuilder( + google::protobuf::DescriptorPool::generated_pool(), opts)); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, ParseWithTestMacros("1 + 2")); + + EXPECT_THAT(ProtobufRuntimeAdapter::CreateProgram(*runtime, expr), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Maximum recursion depth of 1 exceeded"))); +} + +TEST(StandardRuntimeTest, RecursionUnderLimit) { + RuntimeOptions opts; + opts.max_recursion_depth = 2; + + ASSERT_OK_AND_ASSIGN(auto builder, + CreateStandardRuntimeBuilder( + google::protobuf::DescriptorPool::generated_pool(), opts)); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, ParseWithTestMacros("1 + 2")); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr program, + ProtobufRuntimeAdapter::CreateProgram(*runtime, expr)); + + // Whether the implementation is recursive shouldn't affect observable + // behavior, but it does have performance implications (it will skip + // allocating a value stack). + EXPECT_TRUE(runtime_internal::TestOnly_IsRecursiveImpl(program.get())); + + google::protobuf::Arena arena; + Activation activation; + + ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation)); + EXPECT_THAT(result, IntValueIs(3)); +} + +TEST(StandardRuntimeTest, RecursionLimitTracksLazyExpressions) { + RuntimeOptions opts; + opts.max_recursion_depth = 8; + + ASSERT_OK_AND_ASSIGN(auto builder, + CreateStandardRuntimeBuilder( + google::protobuf::DescriptorPool::generated_pool(), opts)); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, ParseWithTestMacros(R"cel( + cel.bind(a, 4 + (3 + (2 + 1)), + cel.bind(b, 7 + (6 + (5 + a)), + 9 + (8 + b) + ) + ))cel")); + + EXPECT_THAT(ProtobufRuntimeAdapter::CreateProgram(*runtime, expr), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Maximum recursion depth of 8 exceeded"))); +} + +struct EvaluateResultTestCase { + std::string name; + std::string expression; + bool expected_result; + std::function activation_builder; + + template + friend void AbslStringify(S& sink, const EvaluateResultTestCase& tc) { + sink.Append(tc.name); + } +}; + +class StandardRuntimeTest : public TestWithParam { + public: + const EvaluateResultTestCase& GetTestCase() { return GetParam(); } +}; + +TEST_P(StandardRuntimeTest, Defaults) { + RuntimeOptions opts; + const EvaluateResultTestCase& test_case = GetTestCase(); + + ASSERT_OK_AND_ASSIGN(auto builder, + CreateStandardRuntimeBuilder( + google::protobuf::DescriptorPool::generated_pool(), opts)); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, + ParseWithTestMacros(test_case.expression)); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr program, + ProtobufRuntimeAdapter::CreateProgram(*runtime, expr)); + + EXPECT_FALSE(runtime_internal::TestOnly_IsRecursiveImpl(program.get())); + + google::protobuf::Arena arena; + Activation activation; + if (test_case.activation_builder != nullptr) { + ASSERT_THAT(test_case.activation_builder(activation), IsOk()); + } + + ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation)); + + EXPECT_THAT(result, BoolValueIs(test_case.expected_result)) + << test_case.expression; +} + +TEST_P(StandardRuntimeTest, Recursive) { + RuntimeOptions opts; + opts.max_recursion_depth = -1; + const EvaluateResultTestCase& test_case = GetTestCase(); + + ASSERT_OK_AND_ASSIGN(auto builder, + CreateStandardRuntimeBuilder( + google::protobuf::DescriptorPool::generated_pool(), opts)); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, + ParseWithTestMacros(test_case.expression)); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr program, + ProtobufRuntimeAdapter::CreateProgram(*runtime, expr)); + + // Whether the implementation is recursive shouldn't affect observable + // behavior, but it does have performance implications (it will skip + // allocating a value stack). + EXPECT_TRUE(runtime_internal::TestOnly_IsRecursiveImpl(program.get())); + + google::protobuf::Arena arena; + Activation activation; + if (test_case.activation_builder != nullptr) { + ASSERT_THAT(test_case.activation_builder(activation), IsOk()); + } + + ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation)); + EXPECT_THAT(result, BoolValueIs(test_case.expected_result)) + << test_case.expression; +} + +TEST_P(StandardRuntimeTest, FastBuiltins) { + RuntimeOptions opts; + opts.enable_fast_builtins = true; + const EvaluateResultTestCase& test_case = GetTestCase(); + + ASSERT_OK_AND_ASSIGN(auto builder, + CreateStandardRuntimeBuilder( + google::protobuf::DescriptorPool::generated_pool(), opts)); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, + ParseWithTestMacros(test_case.expression)); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr program, + ProtobufRuntimeAdapter::CreateProgram(*runtime, expr)); + + EXPECT_FALSE(runtime_internal::TestOnly_IsRecursiveImpl(program.get())); + + google::protobuf::Arena arena; + Activation activation; + if (test_case.activation_builder != nullptr) { + ASSERT_THAT(test_case.activation_builder(activation), IsOk()); + } + + ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation)); + + EXPECT_THAT(result, BoolValueIs(test_case.expected_result)) + << test_case.expression; +} + +TEST_P(StandardRuntimeTest, RecursiveFastBuiltins) { + RuntimeOptions opts; + opts.enable_fast_builtins = true; + opts.max_recursion_depth = -1; + const EvaluateResultTestCase& test_case = GetTestCase(); + + ASSERT_OK_AND_ASSIGN(auto builder, + CreateStandardRuntimeBuilder( + google::protobuf::DescriptorPool::generated_pool(), opts)); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, + ParseWithTestMacros(test_case.expression)); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr program, + ProtobufRuntimeAdapter::CreateProgram(*runtime, expr)); + + // Whether the implementation is recursive shouldn't affect observable + // behavior, but it does have performance implications (it will skip + // allocating a value stack). + EXPECT_TRUE(runtime_internal::TestOnly_IsRecursiveImpl(program.get())); + + google::protobuf::Arena arena; + Activation activation; + if (test_case.activation_builder != nullptr) { + ASSERT_THAT(test_case.activation_builder(activation), IsOk()); + } + + ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation)); + EXPECT_THAT(result, BoolValueIs(test_case.expected_result)) + << test_case.expression; +} + +INSTANTIATE_TEST_SUITE_P( + Basic, StandardRuntimeTest, + testing::ValuesIn(std::vector{ + {"int_identifier", "int_var == 42", true, + [](Activation& activation) { + activation.InsertOrAssignValue("int_var", cel::IntValue(42)); + return absl::OkStatus(); + }}, + {"logic_and_true", "true && 1 < 2", true}, + {"logic_and_false", "true && 1 > 2", false}, + {"logic_or_true", "false || 1 < 2", true}, + {"logic_or_false", "false && 1 > 2", false}, + {"ternary_true_cond", "(1 < 2 ? 'yes' : 'no') == 'yes'", true}, + {"ternary_false_cond", "(1 > 2 ? 'yes' : 'no') == 'no'", true}, + {"list_index", "['a', 'b', 'c', 'd'][1] == 'b'", true}, + {"map_index_bool", "{true: 1, false: 2}[false] == 2", true}, + {"map_index_string", "{'abc': 123}['abc'] == 123", true}, + {"map_index_int", "{1: 2, 2: 4}[2] == 4", true}, + {"map_index_uint", "{1u: 1, 2u: 2}[1u] == 1", true}, + {"map_index_coerced_double", "{1: 2, 2: 4}[2.0] == 4", true}, + })); + +INSTANTIATE_TEST_SUITE_P( + Equality, StandardRuntimeTest, + testing::ValuesIn(std::vector{ + {"eq_bool_bool_true", "false == false", true}, + {"eq_bool_bool_false", "false == true", false}, + {"eq_int_int_true", "-1 == -1", true}, + {"eq_int_int_false", "-1 == 1", false}, + {"eq_uint_uint_true", "2u == 2u", true}, + {"eq_uint_uint_false", "2u == 3u", false}, + {"eq_double_double_true", "2.4 == 2.4", true}, + {"eq_double_double_false", "2.4 == 3.3", false}, + {"eq_string_string_true", "'abc' == 'abc'", true}, + {"eq_string_string_false", "'abc' == 'def'", false}, + {"eq_bytes_bytes_true", "b'abc' == b'abc'", true}, + {"eq_bytes_bytes_false", "b'abc' == b'def'", false}, + {"eq_duration_duration_true", "duration('15m') == duration('15m')", + true}, + {"eq_duration_duration_false", "duration('15m') == duration('1h')", + false}, + {"eq_timestamp_timestamp_true", + "timestamp('1970-01-01T00:02:00Z') == " + "timestamp('1970-01-01T00:02:00Z')", + true}, + {"eq_timestamp_timestamp_false", + "timestamp('1970-01-01T00:02:00Z') == " + "timestamp('2020-01-01T00:02:00Z')", + false}, + {"eq_null_null_true", "null == null", true}, + {"eq_list_list_true", "[1, 2, 3] == [1, 2, 3]", true}, + {"eq_list_list_false", "[1, 2, 3] == [1, 2, 3, 4]", false}, + {"eq_map_map_true", "{1: 2, 2: 4} == {1: 2, 2: 4}", true}, + {"eq_map_map_false", "{1: 2, 2: 4} == {1: 2, 2: 5}", false}, + + {"neq_bool_bool_true", "false != false", false}, + {"neq_bool_bool_false", "false != true", true}, + {"neq_int_int_true", "-1 != -1", false}, + {"neq_int_int_false", "-1 != 1", true}, + {"neq_uint_uint_true", "2u != 2u", false}, + {"neq_uint_uint_false", "2u != 3u", true}, + {"neq_double_double_true", "2.4 != 2.4", false}, + {"neq_double_double_false", "2.4 != 3.3", true}, + {"neq_string_string_true", "'abc' != 'abc'", false}, + {"neq_string_string_false", "'abc' != 'def'", true}, + {"neq_bytes_bytes_true", "b'abc' != b'abc'", false}, + {"neq_bytes_bytes_false", "b'abc' != b'def'", true}, + {"neq_duration_duration_true", "duration('15m') != duration('15m')", + false}, + {"neq_duration_duration_false", "duration('15m') != duration('1h')", + true}, + {"neq_timestamp_timestamp_true", + "timestamp('1970-01-01T00:02:00Z') != " + "timestamp('1970-01-01T00:02:00Z')", + false}, + {"neq_timestamp_timestamp_false", + "timestamp('1970-01-01T00:02:00Z') != " + "timestamp('2020-01-01T00:02:00Z')", + true}, + {"neq_null_null_true", "null != null", false}, + {"neq_list_list_true", "[1, 2, 3] != [1, 2, 3]", false}, + {"neq_list_list_false", "[1, 2, 3] != [1, 2, 3, 4]", true}, + {"neq_map_map_true", "{1: 2, 2: 4} != {1: 2, 2: 4}", false}, + {"neq_map_map_false", "{1: 2, 2: 4} != {1: 2, 2: 5}", true}})); + +INSTANTIATE_TEST_SUITE_P( + ArithmeticFunctions, StandardRuntimeTest, + testing::ValuesIn(std::vector{ + {"lt_int_int_true", "-1 < 2", true}, + {"lt_int_int_false", "2 < -1", false}, + {"lt_double_double_true", "-1.1 < 2.2", true}, + {"lt_double_double_false", "2.2 < -1.1", false}, + {"lt_uint_uint_true", "1u < 2u", true}, + {"lt_uint_uint_false", "2u < 1u", false}, + {"lt_string_string_true", "'abc' < 'def'", true}, + {"lt_string_string_false", "'def' < 'abc'", false}, + {"lt_duration_duration_true", "duration('1s') < duration('2s')", true}, + {"lt_duration_duration_false", "duration('2s') < duration('1s')", + false}, + {"lt_timestamp_timestamp_true", "timestamp(1) < timestamp(2)", true}, + {"lt_timestamp_timestamp_false", "timestamp(2) < timestamp(1)", false}, + + {"gt_int_int_false", "-1 > 2", false}, + {"gt_int_int_true", "2 > -1", true}, + {"gt_double_double_false", "-1.1 > 2.2", false}, + {"gt_double_double_true", "2.2 > -1.1", true}, + {"gt_uint_uint_false", "1u > 2u", false}, + {"gt_uint_uint_true", "2u > 1u", true}, + {"gt_string_string_false", "'abc' > 'def'", false}, + {"gt_string_string_true", "'def' > 'abc'", true}, + {"gt_duration_duration_false", "duration('1s') > duration('2s')", + false}, + {"gt_duration_duration_true", "duration('2s') > duration('1s')", true}, + {"gt_timestamp_timestamp_false", "timestamp(1) > timestamp(2)", false}, + {"gt_timestamp_timestamp_true", "timestamp(2) > timestamp(1)", true}, + + {"le_int_int_true", "-1 <= -1", true}, + {"le_int_int_false", "2 <= -1", false}, + {"le_double_double_true", "-1.1 <= -1.1", true}, + {"le_double_double_false", "2.2 <= -1.1", false}, + {"le_uint_uint_true", "1u <= 1u", true}, + {"le_uint_uint_false", "2u <= 1u", false}, + {"le_string_string_true", "'abc' <= 'abc'", true}, + {"le_string_string_false", "'def' <= 'abc'", false}, + {"le_duration_duration_true", "duration('1s') <= duration('1s')", true}, + {"le_duration_duration_false", "duration('2s') <= duration('1s')", + false}, + {"le_timestamp_timestamp_true", "timestamp(1) <= timestamp(1)", true}, + {"le_timestamp_timestamp_false", "timestamp(2) <= timestamp(1)", false}, + + {"ge_int_int_false", "-1 >= 2", false}, + {"ge_int_int_true", "2 >= 2", true}, + {"ge_double_double_false", "-1.1 >= 2.2", false}, + {"ge_double_double_true", "2.2 >= 2.2", true}, + {"ge_uint_uint_false", "1u >= 2u", false}, + {"ge_uint_uint_true", "2u >= 2u", true}, + {"ge_string_string_false", "'abc' >= 'def'", false}, + {"ge_string_string_true", "'abc' >= 'abc'", true}, + {"ge_duration_duration_false", "duration('1s') >= duration('2s')", + false}, + {"ge_duration_duration_true", "duration('1s') >= duration('1s')", true}, + {"ge_timestamp_timestamp_false", "timestamp(1) >= timestamp(2)", false}, + {"ge_timestamp_timestamp_true", "timestamp(1) >= timestamp(1)", true}, + + {"sum_int_int", "1 + 2 == 3", true}, + {"sum_uint_uint", "3u + 4u == 7", true}, + {"sum_double_double", "1.0 + 2.5 == 3.5", true}, + {"sum_duration_duration", + "duration('2m') + duration('30s') == duration('150s')", true}, + {"sum_time_duration", + "timestamp(0) + duration('2m') == " + "timestamp('1970-01-01T00:02:00Z')", + true}, + + {"difference_int_int", "1 - 2 == -1", true}, + {"difference_uint_uint", "4u - 3u == 1u", true}, + {"difference_double_double", "1.0 - 2.5 == -1.5", true}, + {"difference_duration_duration", + "duration('5m') - duration('45s') == duration('4m15s')", true}, + {"difference_time_time", + "timestamp(10) - timestamp(0) == duration('10s')", true}, + {"difference_time_duration", + "timestamp(0) - duration('2m') == " + "timestamp('1969-12-31T23:58:00Z')", + true}, + + {"multiplication_int_int", "2 * 3 == 6", true}, + {"multiplication_uint_uint", "2u * 3u == 6u", true}, + {"multiplication_double_double", "2.5 * 3.0 == 7.5", true}, + + {"division_int_int", "6 / 3 == 2", true}, + {"division_uint_uint", "8u / 4u == 2u", true}, + {"division_double_double", "1.0 / 0.0 == double('inf')", true}, + + {"modulo_int_int", "6 % 4 == 2", true}, + {"modulo_uint_uint", "8u % 5u == 3u", true}, + })); + +INSTANTIATE_TEST_SUITE_P( + Macros, StandardRuntimeTest, + testing::ValuesIn(std::vector{ + {"map", "[1, 2, 3, 4].map(x, x * x)[3] == 16", true}, + {"filter", "[1, 2, 3, 4].filter(x, x < 4).size() == 3", true}, + {"exists", "[1, 2, 3, 4].exists(x, x < 4)", true}, + {"all", "[1, 2, 3, 4].all(x, x < 5)", true}})); + +INSTANTIATE_TEST_SUITE_P( + StringFunctions, StandardRuntimeTest, + testing::ValuesIn(std::vector{ + {"string_contains", "'tacocat'.contains('acoca')", true}, + {"string_contains_global", "contains('tacocat', 'dog')", false}, + {"string_ends_with", "'abcdefg'.endsWith('efg')", true}, + {"string_ends_with_global", "endsWith('abcdefg', 'fgh')", false}, + {"string_starts_with", "'abcdefg'.startsWith('abc')", true}, + {"string_starts_with_global", "startsWith('abcd', 'bcd')", false}, + {"string_size", "'Hello World! 😀'.size() == 14", true}, + {"string_size_global", "size('Hello world!') == 12", true}, + {"bytes_size", "b'0123'.size() == 4", true}, + {"bytes_size_global", "size(b'😀') == 4", true}})); + +INSTANTIATE_TEST_SUITE_P( + RegExFunctions, StandardRuntimeTest, + testing::ValuesIn(std::vector{ + {"matches_string_re", + "'127.0.0.1'.matches(r'127\\.\\d+\\.\\d+\\.\\d+')", true}, + {"matches_string_re_global", + "matches('192.168.0.1', r'127\\.\\d+\\.\\d+\\.\\d+')", false}})); + +INSTANTIATE_TEST_SUITE_P( + TimeFunctions, StandardRuntimeTest, + testing::ValuesIn(std::vector{ + {"timestamp_get_full_year", + "timestamp('2001-02-03T04:05:06.007Z').getFullYear() == 2001", true}, + {"timestamp_get_date", + "timestamp('2001-02-03T04:05:06.007Z').getDate() == 3", true}, + {"timestamp_get_hours", + "timestamp('2001-02-03T04:05:06.007Z').getHours() == 4", true}, + {"timestamp_get_minutes", + "timestamp('2001-02-03T04:05:06.007Z').getMinutes() == 5", true}, + {"timestamp_get_seconds", + "timestamp('2001-02-03T04:05:06.007Z').getSeconds() == 6", true}, + {"timestamp_get_milliseconds", + "timestamp('2001-02-03T04:05:06.007Z').getMilliseconds() == 7", true}, + // Zero based indexing + {"timestamp_get_month", + "timestamp('2001-02-03T04:05:06.007Z').getMonth() == 1", true}, + {"timestamp_get_day_of_year", + "timestamp('2001-02-03T04:05:06.007Z').getDayOfYear() == 33", true}, + {"timestamp_get_day_of_month", + "timestamp('2001-02-03T04:05:06.007Z').getDayOfMonth() == 2", true}, + {"timestamp_get_day_of_week", + "timestamp('2001-02-03T04:05:06.007Z').getDayOfWeek() == 6", true}, + {"duration_get_hours", "duration('10h20m30s40ms').getHours() == 10", + true}, + {"duration_get_minutes", + "duration('10h20m30s40ms').getMinutes() == 20 + 600", true}, + {"duration_get_seconds", + "duration('10h20m30s40ms').getSeconds() == 30 + 20 * 60 + 10 * 60 " + "* " + "60", + true}, + {"duration_get_milliseconds", + "duration('10h20m30s40ms').getMilliseconds() == 40", true}, + })); + +INSTANTIATE_TEST_SUITE_P( + TypeConversionFunctions, StandardRuntimeTest, + testing::ValuesIn(std::vector{ + {"string_timestamp", "string(timestamp(1)) == '1970-01-01T00:00:01Z'", + true}, + {"string_duration", "string(duration('10m30s')) == '630s'", true}, + {"string_int", "string(-1) == '-1'", true}, + {"string_uint", "string(1u) == '1'", true}, + {"string_double", "string(double('inf')) == 'inf'", true}, + {"string_double_nan", "string(double('nan')) == 'nan'", true}, + {"string_bytes", R"(string(b'\xF0\x9F\x98\x80') == '😀')", true}, + {"string_string", "string('hello!') == 'hello!'", true}, + {"bytes_bytes", "bytes(b'123') == b'123'", true}, + {"bytes_string", "bytes('😀') == b'\xF0\x9F\x98\x80'", true}, + {"timestamp", "timestamp(1) == timestamp('1970-01-01T00:00:01Z')", + true}, + {"duration", "duration('10h') == duration('600m')", true}, + {"double_string", "double('1.0') == 1.0", true}, + {"double_string_precision", + "double('0.14285714285714285') == 1.0 / 7.0", true}, + {"double_string_nan", "double('nan') != double('nan')", true}, + {"double_int", "double(1) == 1.0", true}, + {"double_uint", "double(1u) == 1.0", true}, + {"double_double", "double(1.0) == 1.0", true}, + {"uint_string", "uint('1') == 1u", true}, + {"uint_int", "uint(1) == 1u", true}, + {"uint_uint", "uint(1u) == 1u", true}, + {"uint_double", "uint(1.1) == 1u", true}, + {"int_string", "int('-1') == -1", true}, + {"int_int", "int(-1) == -1", true}, + {"int_uint", "int(1u) == 1", true}, + {"int_double", "int(-1.1) == -1", true}, + {"int_timestamp", "int(timestamp('1969-12-31T23:30:00Z')) == -1800", + true}, + })); + +INSTANTIATE_TEST_SUITE_P( + ContainerFunctions, StandardRuntimeTest, + testing::ValuesIn(std::vector{ + // Containers + {"map_size", "{'abc': 1, 'def': 2}.size() == 2", true}, + {"map_in", "'abc' in {'abc': 1, 'def': 2}", true}, + {"map_in_numeric", "1.0 in {1u: 1, 2u: 2}", true}, + {"list_size", "[1, 2, 3, 4].size() == 4", true}, + {"list_size_global", "size([1, 2, 3]) == 3", true}, + {"list_concat", "[1, 2] + [3, 4] == [1, 2, 3, 4]", true}, + {"list_in", "'a' in ['a', 'b', 'c', 'd']", true}, + {"list_in_numeric", "3u in [1.1, 2.3, 3.0, 4.4]", true}})); + +TEST(StandardRuntimeTest, RuntimeIssueSupport) { + RuntimeOptions options; + options.fail_on_warnings = false; + + + ASSERT_OK_AND_ASSIGN(auto builder, + CreateStandardRuntimeBuilder( + google::protobuf::DescriptorPool::generated_pool(), options)); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + { + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, + ParseWithTestMacros("unregistered_function(1)")); + + std::vector issues; + ASSERT_OK_AND_ASSIGN( + std::unique_ptr program, + ProtobufRuntimeAdapter::CreateProgram(*runtime, expr, {&issues})); + + EXPECT_THAT(issues, ElementsAre(Truly([](const RuntimeIssue& issue) { + return issue.severity() == RuntimeIssue::Severity::kWarning && + issue.error_code() == + RuntimeIssue::ErrorCode::kNoMatchingOverload; + }))); + } + + { + ASSERT_OK_AND_ASSIGN( + ParsedExpr expr, + ParseWithTestMacros( + "unregistered_function(1) || unregistered_function(2)")); + + std::vector issues; + ASSERT_OK_AND_ASSIGN( + std::unique_ptr program, + ProtobufRuntimeAdapter::CreateProgram(*runtime, expr, {&issues})); + + EXPECT_THAT( + issues, + ElementsAre( + Truly([](const RuntimeIssue& issue) { + return issue.severity() == RuntimeIssue::Severity::kWarning && + issue.error_code() == + RuntimeIssue::ErrorCode::kNoMatchingOverload; + }), + Truly([](const RuntimeIssue& issue) { + return issue.severity() == RuntimeIssue::Severity::kWarning && + issue.error_code() == + RuntimeIssue::ErrorCode::kNoMatchingOverload; + }))); + } + + { + ASSERT_OK_AND_ASSIGN( + ParsedExpr expr, + ParseWithTestMacros( + "unregistered_function(1) || unregistered_function(2) || true")); + + std::vector issues; + ASSERT_OK_AND_ASSIGN( + std::unique_ptr program, + ProtobufRuntimeAdapter::CreateProgram(*runtime, expr, {&issues})); + + EXPECT_THAT( + issues, + ElementsAre( + Truly([](const RuntimeIssue& issue) { + return issue.severity() == RuntimeIssue::Severity::kWarning && + issue.error_code() == + RuntimeIssue::ErrorCode::kNoMatchingOverload; + }), + Truly([](const RuntimeIssue& issue) { + return issue.severity() == RuntimeIssue::Severity::kWarning && + issue.error_code() == + RuntimeIssue::ErrorCode::kNoMatchingOverload; + }))); + google::protobuf::Arena arena; + Activation activation; + + ASSERT_OK_AND_ASSIGN(auto result, program->Evaluate(&arena, activation)); + EXPECT_TRUE(result->Is() && result.GetBool().NativeValue()); + } +} + +enum class EvalStrategy { kIterative, kRecursive }; + +class StandardRuntimeEvalStrategyTest + : public ::testing::TestWithParam {}; + +// Check that calls to specialized builtins are validated. +TEST_P(StandardRuntimeEvalStrategyTest, InvalidBuiltinBoolOp) { + EvalStrategy eval_strategy = GetParam(); + RuntimeOptions options; + if (eval_strategy == EvalStrategy::kRecursive) { + options.max_recursion_depth = -1; + } else { + options.max_recursion_depth = 0; + } + + + ASSERT_OK_AND_ASSIGN(auto builder, + CreateStandardRuntimeBuilder( + google::protobuf::DescriptorPool::generated_pool(), options)); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + ParsedExpr expr; + expr.mutable_expr()->mutable_call_expr()->set_function(cel::builtin::kOr); + auto* arg = expr.mutable_expr()->mutable_call_expr()->add_args(); + arg->mutable_const_expr()->set_bool_value(true); + + EXPECT_THAT(ProtobufRuntimeAdapter::CreateProgram(*runtime, expr), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST_P(StandardRuntimeEvalStrategyTest, InvalidBuiltinTernaryOp) { + EvalStrategy eval_strategy = GetParam(); + RuntimeOptions options; + if (eval_strategy == EvalStrategy::kRecursive) { + options.max_recursion_depth = -1; + } else { + options.max_recursion_depth = 0; + } + + + ASSERT_OK_AND_ASSIGN(auto builder, + CreateStandardRuntimeBuilder( + google::protobuf::DescriptorPool::generated_pool(), options)); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + ParsedExpr expr; + expr.mutable_expr()->mutable_call_expr()->set_function( + cel::builtin::kTernary); + expr.mutable_expr() + ->mutable_call_expr() + ->add_args() + ->mutable_const_expr() + ->set_bool_value(true); + expr.mutable_expr() + ->mutable_call_expr() + ->add_args() + ->mutable_const_expr() + ->set_bool_value(true); + expr.mutable_expr() + ->mutable_call_expr() + ->add_args() + ->mutable_const_expr() + ->set_bool_value(true); + expr.mutable_expr() + ->mutable_call_expr() + ->add_args() + ->mutable_const_expr() + ->set_bool_value(true); + + EXPECT_THAT(ProtobufRuntimeAdapter::CreateProgram(*runtime, expr), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST_P(StandardRuntimeEvalStrategyTest, InvalidBuiltinIndex) { + EvalStrategy eval_strategy = GetParam(); + RuntimeOptions options; + if (eval_strategy == EvalStrategy::kRecursive) { + options.max_recursion_depth = -1; + } else { + options.max_recursion_depth = 0; + } + + + ASSERT_OK_AND_ASSIGN(auto builder, + CreateStandardRuntimeBuilder( + google::protobuf::DescriptorPool::generated_pool(), options)); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + ParsedExpr expr; + expr.mutable_expr()->mutable_call_expr()->set_function(cel::builtin::kIndex); + auto* arg = expr.mutable_expr()->mutable_call_expr()->add_args(); + arg->mutable_list_expr() + ->add_elements() + ->mutable_const_expr() + ->set_int64_value(1); + + EXPECT_THAT(ProtobufRuntimeAdapter::CreateProgram(*runtime, expr), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST_P(StandardRuntimeEvalStrategyTest, InvalidBuiltinEq) { + EvalStrategy eval_strategy = GetParam(); + RuntimeOptions options; + if (eval_strategy == EvalStrategy::kRecursive) { + options.max_recursion_depth = -1; + } else { + options.max_recursion_depth = 0; + } + + + ASSERT_OK_AND_ASSIGN(auto builder, + CreateStandardRuntimeBuilder( + google::protobuf::DescriptorPool::generated_pool(), options)); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + ParsedExpr expr; + expr.mutable_expr()->mutable_call_expr()->set_function(cel::builtin::kEqual); + auto* arg = expr.mutable_expr()->mutable_call_expr()->add_args(); + arg->mutable_list_expr() + ->add_elements() + ->mutable_const_expr() + ->set_int64_value(1); + + EXPECT_THAT(ProtobufRuntimeAdapter::CreateProgram(*runtime, expr), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST_P(StandardRuntimeEvalStrategyTest, InvalidBuiltinIn) { + EvalStrategy eval_strategy = GetParam(); + RuntimeOptions options; + if (eval_strategy == EvalStrategy::kRecursive) { + options.max_recursion_depth = -1; + } else { + options.max_recursion_depth = 0; + } + + + ASSERT_OK_AND_ASSIGN(auto builder, + CreateStandardRuntimeBuilder( + google::protobuf::DescriptorPool::generated_pool(), options)); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + ParsedExpr expr; + expr.mutable_expr()->mutable_call_expr()->set_function(cel::builtin::kIn); + auto* arg = expr.mutable_expr()->mutable_call_expr()->add_args(); + arg->mutable_list_expr() + ->add_elements() + ->mutable_const_expr() + ->set_int64_value(1); + + EXPECT_THAT(ProtobufRuntimeAdapter::CreateProgram(*runtime, expr), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST_P(StandardRuntimeEvalStrategyTest, PrecisionPreservingDoubleFormat) { + EvalStrategy eval_strategy = GetParam(); + RuntimeOptions options; + if (eval_strategy == EvalStrategy::kRecursive) { + options.max_recursion_depth = -1; + } else { + options.max_recursion_depth = 0; + } + + options.enable_precision_preserving_double_format = true; + + ASSERT_OK_AND_ASSIGN(auto builder, + CreateStandardRuntimeBuilder( + google::protobuf::DescriptorPool::generated_pool(), options)); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + // Note: the string format isn't guaranteed to be shortest since we don't have + // to_chars support on all compilers, but it should still be reversible. + const absl::string_view kCases[] = {"double(string(1.0/7.0)) == 1.0/7.0", + "double(string(0.45)) == 0.45"}; + + google::protobuf::Arena arena; + Activation activation; + + for (const auto& test_case : kCases) { + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, ParseWithTestMacros(test_case)); + ASSERT_OK_AND_ASSIGN(auto program, + ProtobufRuntimeAdapter::CreateProgram(*runtime, expr)); + ASSERT_OK_AND_ASSIGN(auto result, program->Evaluate(&arena, activation)); + EXPECT_TRUE(result->Is() && result.GetBool().NativeValue()); + } +} + +INSTANTIATE_TEST_SUITE_P( + StandardRuntimeEvalStrategyTest, StandardRuntimeEvalStrategyTest, + testing::Values(EvalStrategy::kIterative, EvalStrategy::kRecursive), + [](const auto& info) -> std::string { + return info.param == EvalStrategy::kIterative ? "Iterative" : "Recursive"; + }); + +} // namespace +} // namespace cel diff --git a/runtime/type_registry.cc b/runtime/type_registry.cc new file mode 100644 index 000000000..a1e8b0328 --- /dev/null +++ b/runtime/type_registry.cc @@ -0,0 +1,84 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "runtime/type_registry.h" + +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "common/value.h" +#include "runtime/internal/legacy_runtime_type_provider.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel { + +TypeRegistry::TypeRegistry( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nullable message_factory) + : type_provider_(descriptor_pool), + legacy_type_provider_( + std::make_shared( + descriptor_pool, message_factory)) { + RegisterEnum("google.protobuf.NullValue", {{"NULL_VALUE", 0}}); +} + +void TypeRegistry::RegisterEnum(absl::string_view enum_name, + std::vector enumerators) { + { + absl::MutexLock lock(enum_value_table_mutex_); + enum_value_table_.reset(); + } + enum_types_[enum_name] = + Enumeration{std::string(enum_name), std::move(enumerators)}; +} + +std::shared_ptr> +TypeRegistry::GetEnumValueTable() const { + { + absl::ReaderMutexLock lock(enum_value_table_mutex_); + if (enum_value_table_ != nullptr) { + return enum_value_table_; + } + } + + absl::MutexLock lock(enum_value_table_mutex_); + if (enum_value_table_ != nullptr) { + return enum_value_table_; + } + std::shared_ptr> result = + std::make_shared>(); + + auto& enum_value_map = *result; + for (auto iter = enum_types_.begin(); iter != enum_types_.end(); ++iter) { + absl::string_view enum_name = iter->first; + const auto& enum_type = iter->second; + for (const auto& enumerator : enum_type.enumerators) { + auto key = absl::StrCat(enum_name, ".", enumerator.name); + enum_value_map[key] = cel::IntValue(enumerator.number); + } + } + + enum_value_table_ = result; + + return result; +} +} // namespace cel diff --git a/runtime/type_registry.h b/runtime/type_registry.h new file mode 100644 index 000000000..eadd1f1ea --- /dev/null +++ b/runtime/type_registry.h @@ -0,0 +1,155 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_TYPE_REGISTRY_H_ +#define THIRD_PARTY_CEL_CPP_RUNTIME_TYPE_REGISTRY_H_ + +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "base/type_provider.h" +#include "common/type.h" +#include "common/value.h" +#include "runtime/internal/legacy_runtime_type_provider.h" +#include "runtime/internal/runtime_type_provider.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel { + +class TypeRegistry; + +namespace runtime_internal { +const RuntimeTypeProvider& GetRuntimeTypeProvider( + const TypeRegistry& type_registry); +const absl_nonnull std::shared_ptr& +GetLegacyRuntimeTypeProvider(const TypeRegistry& type_registry); + +// Returns a memoized table of fully qualified enum values. +// +// This is populated when first requested. +std::shared_ptr> +GetEnumValueTable(const TypeRegistry& type_registry); +} // namespace runtime_internal + +// TypeRegistry manages composing TypeProviders used with a Runtime. +// +// It provides a single effective type provider to be used in a ValueManager. +class TypeRegistry { + public: + // Representation for a custom enum constant. + struct Enumerator { + std::string name; + int64_t number; + }; + + struct Enumeration { + std::string name; + std::vector enumerators; + }; + + TypeRegistry() + : TypeRegistry(google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory()) {} + + TypeRegistry(const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nullable message_factory); + + // Neither moveable nor copyable. + TypeRegistry(const TypeRegistry& other) = delete; + TypeRegistry& operator=(TypeRegistry& other) = delete; + TypeRegistry(TypeRegistry&& other) = delete; + TypeRegistry& operator=(TypeRegistry&& other) = delete; + + // Registers a type such that it can be accessed by name, i.e. `type(foo) == + // my_type`. Where `my_type` is the type being registered. + absl::Status RegisterType(const OpaqueType& type) { + return type_provider_.RegisterType(type); + } + + // Register a custom enum type. + // + // This adds the enum to the set consulted at plan time to identify constant + // enum values. + void RegisterEnum(absl::string_view enum_name, + std::vector enumerators); + + const absl::flat_hash_map& resolveable_enums() + const { + return enum_types_; + } + + // Returns the effective type provider. + const TypeProvider& GetComposedTypeProvider() const { return type_provider_; } + + private: + friend const runtime_internal::RuntimeTypeProvider& + runtime_internal::GetRuntimeTypeProvider(const TypeRegistry& type_registry); + friend const + absl_nonnull std::shared_ptr& + runtime_internal::GetLegacyRuntimeTypeProvider( + const TypeRegistry& type_registry); + + friend std::shared_ptr> + runtime_internal::GetEnumValueTable(const TypeRegistry& type_registry); + + std::shared_ptr> + GetEnumValueTable() const; + + runtime_internal::RuntimeTypeProvider type_provider_; + absl_nonnull std::shared_ptr + legacy_type_provider_; + absl::flat_hash_map enum_types_; + + // memoized fully qualified enumerator names. + // + // populated when requested. + // + // In almost all cases, this is built once and never updated, but we can't + // guarantee that with the current CelExpressionBuilder API. + // + // The cases when invalidation may occur are likely already race conditions, + // but we provide basic thread safety to avoid issues with sanitizers. + mutable std::shared_ptr> + enum_value_table_ ABSL_GUARDED_BY(enum_value_table_mutex_); + mutable absl::Mutex enum_value_table_mutex_; +}; + +namespace runtime_internal { +inline const RuntimeTypeProvider& GetRuntimeTypeProvider( + const TypeRegistry& type_registry) { + return type_registry.type_provider_; +} +inline const absl_nonnull std::shared_ptr& +GetLegacyRuntimeTypeProvider(const TypeRegistry& type_registry) { + return type_registry.legacy_type_provider_; +} +inline std::shared_ptr> +GetEnumValueTable(const TypeRegistry& type_registry) { + return type_registry.GetEnumValueTable(); +} + +} // namespace runtime_internal + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_RUNTIME_TYPE_REGISTRY_H_ diff --git a/testing/testrunner/BUILD b/testing/testrunner/BUILD new file mode 100644 index 000000000..b80167487 --- /dev/null +++ b/testing/testrunner/BUILD @@ -0,0 +1,224 @@ +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +package( + default_testonly = True, + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) + +cc_library( + name = "cel_test_context", + hdrs = ["cel_test_context.h"], + deps = [ + ":cel_expression_source", + "//common:value", + "//compiler", + "//eval/public:cel_expression", + "//runtime", + "//runtime:activation", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status:statusor", + "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", + "@com_google_cel_spec//proto/cel/expr:value_cc_proto", + "@com_google_cel_spec//proto/cel/expr/conformance/test:suite_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "runner_lib", + srcs = ["runner_lib.cc"], + hdrs = ["runner_lib.h"], + deps = [ + ":cel_expression_source", + ":cel_test_context", + ":coverage_index", + ":coverage_reporting", + "//checker:validation_result", + "//common:ast", + "//common:ast_proto", + "//common:value", + "//common/internal:value_conversion", + "//eval/public:activation", + "//eval/public:cel_expression", + "//eval/public:cel_value", + "//eval/public:transform_utility", + "//internal:status_macros", + "//internal:testing_no_main", + "//runtime", + "//runtime:activation", + "@com_google_absl//absl/functional:overload", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_cel_spec//proto/cel/expr:value_cc_proto", + "@com_google_cel_spec//proto/cel/expr/conformance/test:suite_cc_proto", + "@com_google_protobuf//:differencer", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "cel_test_factories", + hdrs = ["cel_test_factories.h"], + deps = [ + ":cel_test_context", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status:statusor", + "@com_google_cel_spec//proto/cel/expr/conformance/test:suite_cc_proto", + ], +) + +cc_test( + name = "runner_lib_test", + srcs = ["runner_lib_test.cc"], + args = [ + "--test_cel_file_path=$(location //testing/testrunner/resources:test.cel)", + ], + data = [ + "//testing/testrunner/resources:test.cel", + ], + deps = [ + ":cel_expression_source", + ":cel_test_context", + ":coverage_index", + ":runner_lib", + "//checker:type_checker_builder", + "//checker:validation_result", + "//common:ast_proto", + "//common:decl", + "//common:type", + "//common:value", + "//compiler", + "//compiler:compiler_factory", + "//compiler:standard_library", + "//eval/public:builtin_func_registrar", + "//eval/public:cel_expr_builder_factory", + "//eval/public:cel_expression", + "//internal:status_macros", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//runtime", + "//runtime:activation", + "//runtime:runtime_builder", + "//runtime:standard_runtime_builder_factory", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto", + "@com_google_cel_spec//proto/cel/expr/conformance/test:suite_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "coverage_reporting", + srcs = ["coverage_reporting.cc"], + hdrs = ["coverage_reporting.h"], + deps = [ + ":coverage_index", + "//internal:testing_no_main", + "@com_google_absl//absl/log:absl_log", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + ], +) + +cc_library( + name = "runner", + srcs = ["runner_bin.cc"], + deps = [ + ":cel_expression_source", + ":cel_test_context", + ":cel_test_factories", + ":coverage_index", + ":coverage_reporting", + ":runner_lib", + "//eval/public:cel_expression", + "//internal:status_macros", + "//internal:testing_no_main", + "//runtime", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", + "@com_google_cel_spec//proto/cel/expr/conformance/test:suite_cc_proto", + "@com_google_protobuf//:protobuf", + ], + alwayslink = True, +) + +cc_library( + name = "cel_expression_source", + hdrs = ["cel_expression_source.h"], + deps = ["@com_google_cel_spec//proto/cel/expr:checked_cc_proto"], +) + +cc_library( + name = "coverage_index", + srcs = ["coverage_index.cc"], + hdrs = ["coverage_index.h"], + deps = [ + "//common:ast", + "//common:value", + "//eval/compiler:cel_expression_builder_flat_impl", + "//eval/compiler:instrumentation", + "//eval/public:cel_expression", + "//internal:casts", + "//runtime", + "//runtime/internal:runtime_impl", + "//tools:cel_unparser", + "//tools:navigable_ast", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + ], +) + +cc_test( + name = "coverage_index_test", + srcs = ["coverage_index_test.cc"], + deps = [ + ":coverage_index", + "//checker:type_checker_builder", + "//checker:validation_result", + "//common:ast", + "//common:ast_proto", + "//common:decl", + "//common:type", + "//common:value", + "//compiler", + "//compiler:compiler_factory", + "//compiler:standard_library", + "//internal:status_macros", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//runtime", + "//runtime:activation", + "//runtime:runtime_builder", + "//runtime:standard_runtime_builder_factory", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) diff --git a/testing/testrunner/cel_cc_test.bzl b/testing/testrunner/cel_cc_test.bzl new file mode 100644 index 000000000..3aac134f6 --- /dev/null +++ b/testing/testrunner/cel_cc_test.bzl @@ -0,0 +1,126 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Rules for triggering the cc impl of the CEL test runner.""" + +load("@bazel_skylib//lib:paths.bzl", "paths") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +expr_src_type = struct( + RAW = "raw", + FILE = "file", + CHECKED = "checked", +) + +def cel_cc_test( + name, + test_suite = "", + cel_expr = "", + is_raw_expr = False, + filegroup = "", + deps = [], + enable_coverage = False, + test_data_path = "", + data = [], + **kwargs): + """trigger the cc impl of the CEL test runner. + + This rule will generate a cc_test rule. This rule will be used to trigger + the cc impl of the cel_test rule. + + Args: + name: str name for the generated artifact + test_suite: str label of a file containing a test suite. The file should have a + .textproto extension. + cel_expr: The CEL expression source. The meaning of this argument depends on `is_raw_expr`. + is_raw_expr: bool whether the cel_expr is a raw expression string. If False, + cel_expr is treated as a file path. The file type (.cel or .textproto) + is inferred from the extension. + filegroup: str label of a filegroup containing the test suite, the config and the checked + expression. + deps: list of dependencies for the cc_test rule. + data: list of data dependencies for the cc_test rule. + enable_coverage: bool whether to enable coverage collection. + test_data_path: absolute path of the directory containing the test files. This is needed only + if the test files are not located in the same directory as the BUILD file. + **kwargs: additional arguments to pass to the cc_test rule. + """ + data, test_data_path = _update_data_with_test_files( + data, + filegroup, + test_data_path, + test_suite, + cel_expr, + is_raw_expr, + ) + args = kwargs.pop("args", []) + + test_data_path = test_data_path.lstrip("/") + + if test_suite != "": + test_suite = test_data_path + "/" + test_suite + args.append("--test_suite_path=" + test_suite) + + args.append("--collect_coverage=" + str(enable_coverage)) + + if cel_expr != "": + expr_source_type = "" + expr_source = "" + if is_raw_expr: + expr_source_type = expr_src_type.RAW + expr_source = "\"" + cel_expr + "\"" + else: + _, ext = paths.split_extension(cel_expr) + + # The C++ test runner currently only supports parsing expressions from .cel files. + # Support for other CEL source types (e.g., .celpolicy, .yaml) is not yet implemented. + if ext == ".cel": + expr_source_type = expr_src_type.FILE + expr_source = test_data_path + "/" + cel_expr + else: + expr_source_type = expr_src_type.CHECKED + expr_source = "$(location " + cel_expr + ")" + + args.append("--expr_source_type=" + expr_source_type) + args.append("--expr_source=" + expr_source) + + cc_test( + name = name, + data = data, + args = args, + deps = ["//testing/testrunner:runner"] + deps, + **kwargs + ) + +def _update_data_with_test_files(data, filegroup, test_data_path, test_suite, cel_expr, is_raw_expr): + """Updates the data with the test files.""" + + if filegroup != "": + data = data + [filegroup] + elif test_data_path != "" and test_data_path != native.package_name(): + if test_suite != "": + data = data + [test_data_path + ":" + test_suite] + if cel_expr != "" and not is_raw_expr: + _, ext = paths.split_extension(cel_expr) + if ext == ".cel": + data = data + [test_data_path + ":" + cel_expr] + else: + data = data + [cel_expr] + else: + test_data_path = native.package_name() + if test_suite != "": + data = data + [test_suite] + if cel_expr != "" and not is_raw_expr: + data = data + [cel_expr] + return data, test_data_path diff --git a/testing/testrunner/cel_expression_source.h b/testing/testrunner/cel_expression_source.h new file mode 100644 index 000000000..dfdc61c5c --- /dev/null +++ b/testing/testrunner/cel_expression_source.h @@ -0,0 +1,81 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_TESTING_TESTRUNNER_CEL_EXPRESSION_SOURCE_H_ +#define THIRD_PARTY_CEL_CPP_TESTING_TESTRUNNER_CEL_EXPRESSION_SOURCE_H_ + +#include +#include +#include + +#include "cel/expr/checked.pb.h" + +namespace cel::test { + +// A wrapper class that holds one of three possible sources for a CEL +// expression using a std::variant for type safety. +class CelExpressionSource { + public: + // Distinct wrapper types are used for string-based sources to disambiguate + // them within the std::variant. + struct RawExpression { + std::string value; + }; + + struct CelFile { + std::string path; + }; + + // The variant holds one of the three possible source types. + using SourceVariant = + std::variant; + + // Creates a CelExpressionSource from a compiled + // cel::expr::CheckedExpr. + static CelExpressionSource FromCheckedExpr( + cel::expr::CheckedExpr checked_expr) { + return CelExpressionSource(std::move(checked_expr)); + } + + // Creates a CelExpressionSource from a raw CEL expression string. + static CelExpressionSource FromRawExpression(std::string raw_expression) { + return CelExpressionSource(RawExpression{std::move(raw_expression)}); + } + + // Creates a CelExpressionSource from a file path pointing to a .cel file. + static CelExpressionSource FromCelFile(std::string cel_file_path) { + return CelExpressionSource(CelFile{std::move(cel_file_path)}); + } + + // Make copyable and movable. + CelExpressionSource(const CelExpressionSource&) = default; + CelExpressionSource& operator=(const CelExpressionSource&) = default; + CelExpressionSource(CelExpressionSource&&) = default; + CelExpressionSource& operator=(CelExpressionSource&&) = default; + + // Returns the underlying variant. The caller is expected to use std::visit + // to interact with the active value in a type-safe manner. + const SourceVariant& source() const { return source_; } + + private: + // A single private constructor enforces creation via the static factories. + explicit CelExpressionSource(SourceVariant source) + : source_(std::move(source)) {} + + // A single std::variant member efficiently stores one of the possible states. + SourceVariant source_; +}; +} // namespace cel::test + +#endif // THIRD_PARTY_CEL_CPP_TESTING_TESTRUNNER_CEL_EXPRESSION_SOURCE_H_ diff --git a/testing/testrunner/cel_test_context.h b/testing/testrunner/cel_test_context.h new file mode 100644 index 000000000..0e0f21e28 --- /dev/null +++ b/testing/testrunner/cel_test_context.h @@ -0,0 +1,200 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_TOOLS_TESTRUNNER_CEL_TEST_CONTEXT_H_ +#define THIRD_PARTY_CEL_CPP_TOOLS_TESTRUNNER_CEL_TEST_CONTEXT_H_ + +#include +#include +#include +#include + +#include "cel/expr/checked.pb.h" +#include "cel/expr/value.pb.h" +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" +#include "absl/memory/memory.h" +#include "absl/status/statusor.h" +#include "common/value.h" +#include "compiler/compiler.h" +#include "eval/public/cel_expression.h" +#include "runtime/activation.h" +#include "runtime/runtime.h" +#include "testing/testrunner/cel_expression_source.h" +#include "cel/expr/conformance/test/suite.pb.h" +#include "google/protobuf/arena.h" +namespace cel::test { + +// The context class for a CEL test, holding configurations needed to evaluate +// compiled CEL expressions. +class CelTestContext { + public: + using CelActivationFactoryFn = std::function( + const cel::expr::conformance::test::TestCase& test_case, + google::protobuf::Arena* arena)>; + using AssertFn = std::function; + + // Creates a CelTestContext using a `CelExpressionBuilder`. + // + // The `CelExpressionBuilder` helps in setting up the environment for + // building the CEL expression. + // + // Example usage: + // + // CEL_REGISTER_TEST_CONTEXT_FACTORY( + // []() -> absl::StatusOr> { + // // SAFE: This setup code now runs when the lambda is invoked at + // runtime, + // // long after all static initializations are complete. + // auto cel_expression_builder = + // google::api::expr::runtime::CreateCelExpressionBuilder(); + // CelTestContextOptions options; + // return CelTestContext::CreateFromCelExpressionBuilder( + // std::move(cel_expression_builder), std::move(options)); + // }); + static std::unique_ptr CreateFromCelExpressionBuilder( + std::unique_ptr + cel_expression_builder) { + return absl::WrapUnique( + new CelTestContext(std::move(cel_expression_builder))); + } + + // Creates a CelTestContext using a `cel::Runtime`. + // + // The `cel::Runtime` is used to evaluate the CEL expression by managing + // the state needed to generate Program. + static std::unique_ptr CreateFromRuntime( + std::unique_ptr runtime) { + return absl::WrapUnique(new CelTestContext(std::move(runtime))); + } + + const cel::Runtime* absl_nullable runtime() const { return runtime_.get(); } + + const google::api::expr::runtime::CelExpressionBuilder* absl_nullable + cel_expression_builder() const { + return cel_expression_builder_.get(); + } + + const cel::Compiler* absl_nullable compiler() const { + return compiler_.get(); + } + + const CelExpressionSource* absl_nullable expression_source() const { + return expression_source_.get(); + } + + const absl::flat_hash_map& + custom_bindings() const { + return custom_bindings_; + } + + bool enable_coverage() const { return enable_coverage_; } + + // Allows the runner to inject the expression source + // parsed from command-line flags. + void SetExpressionSource(CelExpressionSource source) { + expression_source_ = + std::make_unique(std::move(source)); + } + + // Allows the runner to inject an optional CEL compiler. + void SetCompiler(std::unique_ptr compiler) { + compiler_ = std::move(compiler); + } + + // Allows the runner to inject custom bindings. + void SetCustomBindings( + absl::flat_hash_map + custom_bindings) { + custom_bindings_ = std::move(custom_bindings); + } + + // Allows the runner to inject a custom activation factory. If not set, an + // empty activation will be used. Custom bindings and test case inputs will + // be added to the activation returned by the factory. + void SetActivationFactory(CelActivationFactoryFn activation_factory) { + activation_factory_ = std::move(activation_factory); + } + + // Allows the runner to enable coverage collection. + void SetEnableCoverage(bool enable) { enable_coverage_ = enable; } + + const CelActivationFactoryFn& activation_factory() const { + return activation_factory_; + } + + // Allows the runner to inject a custom assertion function. If not set, the + // default assertion logic in TestRunner will be used. + void SetAssertFn(AssertFn assert_fn) { assert_fn_ = std::move(assert_fn); } + + const AssertFn& assert_fn() const { return assert_fn_; } + + private: + // Delete copy and move constructors. + CelTestContext(const CelTestContext&) = delete; + CelTestContext& operator=(const CelTestContext&) = delete; + CelTestContext(CelTestContext&&) = delete; + CelTestContext& operator=(CelTestContext&&) = delete; + + // Make the constructors private to enforce the use of the factory methods. + explicit CelTestContext( + std::unique_ptr + cel_expression_builder) + : cel_expression_builder_(std::move(cel_expression_builder)) {} + + explicit CelTestContext(std::unique_ptr runtime) + : runtime_(std::move(runtime)) {} + + // An optional CEL compiler. This is required for test cases where + // input or output values are themselves CEL expressions that need to be + // resolved at runtime or cel expression source is raw string or cel file. + std::unique_ptr compiler_ = nullptr; + + // A map of variable names to values that provides default bindings for the + // evaluation. + // + // These bindings can be considered context-wide defaults. If a variable name + // exists in both these custom bindings and in a specific TestCase's input, + // the value from the TestCase will take precedence and override this one. + // This logic is handled by the test runner when it constructs the final + // activation. + absl::flat_hash_map custom_bindings_; + + // The source for the CEL expression to be evaluated in the test. + std::unique_ptr expression_source_; + + // This helps in setting up the environment for building the CEL + // expression. Users should either provide a runtime, or the + // CelExpressionBuilder. + std::unique_ptr + cel_expression_builder_; + + // The runtime is used to evaluate the CEL expression by managing the state + // needed to generate Program. Users should either provide a runtime, or the + // CelExpressionBuilder. + std::unique_ptr runtime_; + + CelActivationFactoryFn activation_factory_; + AssertFn assert_fn_; + + // Whether to enable coverage collection. + bool enable_coverage_ = false; +}; + +} // namespace cel::test + +#endif // THIRD_PARTY_CEL_CPP_TOOLS_TESTRUNNER_CEL_TEST_CONTEXT_H_ diff --git a/testing/testrunner/cel_test_factories.h b/testing/testrunner/cel_test_factories.h new file mode 100644 index 000000000..61058be13 --- /dev/null +++ b/testing/testrunner/cel_test_factories.h @@ -0,0 +1,91 @@ +// Copyright 2025 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_TESTING_TESTRUNNER_CEL_TEST_FACTORIES_H_ +#define THIRD_PARTY_CEL_CPP_TESTING_TESTRUNNER_CEL_TEST_FACTORIES_H_ + +#include +#include +#include + +#include "absl/base/no_destructor.h" +#include "absl/log/absl_check.h" +#include "absl/status/statusor.h" +#include "testing/testrunner/cel_test_context.h" +#include "cel/expr/conformance/test/suite.pb.h" +namespace cel::test { +namespace internal { + +using CelTestContextFactoryFn = + std::function>()>; +using CelTestSuiteFactoryFn = + std::function; + +// Returns the factory function for creating a CelTestContext. +inline CelTestContextFactoryFn& GetCelTestContextFactory() { + static absl::NoDestructor factory; + return *factory; +} + +// Sets the factory function for creating a CelTestContext. Only one factory +// function can be set. Usage details can be found in cel_test_context.h. +inline bool SetCelTestContextFactory(CelTestContextFactoryFn factory) { + ABSL_DCHECK(GetCelTestContextFactory() == nullptr) + << "CelTestContextFactory is already set."; + GetCelTestContextFactory() = std::move(factory); + return true; +} + +// Returns the factory function for creating a CelTestSuite. +inline CelTestSuiteFactoryFn& GetCelTestSuiteFactory() { + static absl::NoDestructor factory; + return *factory; +} + +// Sets the factory function for creating a CelTestSuite. Only one factory +// function can be set. +inline bool SetCelTestSuiteFactory(CelTestSuiteFactoryFn factory) { + ABSL_DCHECK(GetCelTestSuiteFactory() == nullptr) + << "CelTestSuiteFactory is already set."; + GetCelTestSuiteFactory() = std::move(factory); + return true; +} +} // namespace internal + +// Register cel test context factories from a function or lambda. +// +// The return value of `factory_fn` should be a +// `absl::StatusOr>>`. +#define CEL_REGISTER_TEST_CONTEXT_FACTORY(factory_fn) \ + namespace { \ + const bool kTestContextFactoryRegistrationResult_##__LINE__ = \ + ::cel::test::internal::SetCelTestContextFactory(factory_fn); \ + } + +// Register cel test suite factory from a function or lambda. This is used to +// provide a custom test suite to the test runner which is useful for cases +// where the test suite is dynamically generated or where the test suite needs +// to be generated from a user provided source. +// +// The return value of `factory_fn` should be a +// `::cel::expr::conformance::test::TestSuite`. +#define CEL_REGISTER_TEST_SUITE_FACTORY(factory_fn) \ + namespace { \ + const bool kTestSuiteFactoryRegistrationResult_##__LINE__ = \ + ::cel::test::internal::SetCelTestSuiteFactory(factory_fn); \ + } + +} // namespace cel::test + +#endif // THIRD_PARTY_CEL_CPP_TESTING_TESTRUNNER_CEL_TEST_FACTORIES_H_ diff --git a/testing/testrunner/coverage_index.cc b/testing/testrunner/coverage_index.cc new file mode 100644 index 000000000..57baff593 --- /dev/null +++ b/testing/testrunner/coverage_index.cc @@ -0,0 +1,281 @@ +// Copyright 2025 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "testing/testrunner/coverage_index.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "cel/expr/syntax.pb.h" +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_replace.h" +#include "absl/strings/string_view.h" +#include "common/ast.h" +#include "common/value.h" +#include "eval/compiler/cel_expression_builder_flat_impl.h" +#include "eval/compiler/instrumentation.h" +#include "eval/public/cel_expression.h" +#include "internal/casts.h" +#include "runtime/internal/runtime_impl.h" +#include "runtime/runtime.h" +#include "tools/cel_unparser.h" +#include "tools/navigable_ast.h" + +namespace cel::test { +namespace { + +using ::cel::expr::CheckedExpr; +using ::cel::expr::Type; +using ::google::api::expr::runtime::CelExpressionBuilder; +using ::google::api::expr::runtime::Instrumentation; +using ::google::api::expr::runtime::InstrumentationFactory; + +std::string EscapeSpecialCharacters(absl::string_view expr_text) { + return absl::StrReplaceAll(expr_text, {{"\\\"", "\""}, + {"\"", "\\\""}, + {"\n", "\\n"}, + {"||", " \\| \\| "}, + {"<", "\\<"}, + {">", "\\>"}, + {"{", "\\{"}, + {"}", "\\}"}}); +} + +std::string KindToString(const NavigableProtoAstNode& node) { + if (node.parent_relation() != ChildKind::kUnspecified && + node.parent()->expr()->has_comprehension_expr()) { + const cel::expr::Expr::Comprehension& comp = + node.parent()->expr()->comprehension_expr(); + if (node.expr()->id() == comp.iter_range().id()) return "IterRange"; + if (node.expr()->id() == comp.accu_init().id()) return "AccuInit"; + if (node.expr()->id() == comp.loop_condition().id()) return "LoopCondition"; + if (node.expr()->id() == comp.loop_step().id()) return "LoopStep"; + if (node.expr()->id() == comp.result().id()) return "Result"; + } + + return absl::StrCat(NodeKindName(node.node_kind()), " Node"); +} + +const Type* absl_nullable FindCheckerType(const CheckedExpr& expr, + int64_t expr_id) { + if (auto it = expr.type_map().find(expr_id); it != expr.type_map().end()) { + return &it->second; + } + return nullptr; +} + +bool InferredBooleanNode(const CheckedExpr& checked_expr, + const NavigableProtoAstNode& node) { + int64_t node_id = node.expr()->id(); + const auto* checker_type = FindCheckerType(checked_expr, node_id); + if (checker_type != nullptr) { + return checker_type->has_primitive() && + checker_type->primitive() == Type::BOOL; + } + + return false; +} + +void TraverseAndCalculateCoverage( + const CheckedExpr& checked_expr, const NavigableProtoAstNode& node, + const absl::flat_hash_map& + stats_map, + bool log_unencountered, std::string preceeding_tabs, + CoverageIndex::CoverageReport& report, std::string& dot_graph) { + int64_t node_id = node.expr()->id(); + + const CoverageIndex::NodeCoverageStats& stats = stats_map.at(node_id); + report.nodes++; + + absl::StatusOr unparsed = + google::api::expr::Unparse(*node.expr()); + std::string expr_text = unparsed.ok() ? *unparsed : "unparse_failed"; + + bool is_interesting_bool_node = + stats.is_boolean_node && !node.expr()->has_const_expr() && + (!node.expr()->has_call_expr() || + node.expr()->call_expr().function() != "cel.@block"); + + absl::string_view node_coverage_style = kUncoveredNodeStyle; + if (stats.covered) { + if (is_interesting_bool_node) { + if (stats.has_true_branch && stats.has_false_branch) { + node_coverage_style = kCompletelyCoveredNodeStyle; + } else { + node_coverage_style = kPartiallyCoveredNodeStyle; + } + } else { + node_coverage_style = kCompletelyCoveredNodeStyle; + } + } + std::string escaped_expr_text = EscapeSpecialCharacters(expr_text); + dot_graph += absl::StrFormat( + "%d [shape=record, %s, label=\"{<1> exprID: %d | <2> %s} | <3> %s\"];\n", + node_id, node_coverage_style, node_id, KindToString(node), + escaped_expr_text); + + bool node_covered = stats.covered; + if (node_covered) { + report.covered_nodes++; + } else if (log_unencountered) { + if (is_interesting_bool_node) { + report.unencountered_nodes.push_back( + absl::StrCat("Expression ID ", node_id, " ('", expr_text, "')")); + } + log_unencountered = false; + } + + if (is_interesting_bool_node) { + report.branches += 2; + if (stats.has_true_branch) { + report.covered_boolean_outcomes++; + } else if (log_unencountered) { + report.unencountered_branches.push_back( + absl::StrCat("\n", preceeding_tabs, "Expression ID ", node_id, " ('", + expr_text, "'): Never evaluated to 'true'")); + preceeding_tabs += "\t\t"; + } + if (stats.has_false_branch) { + report.covered_boolean_outcomes++; + } else if (log_unencountered) { + report.unencountered_branches.push_back( + absl::StrCat("\n", preceeding_tabs, "Expression ID ", node_id, " ('", + expr_text, "'): Never evaluated to 'false'")); + preceeding_tabs += "\t\t"; + } + } + + for (const auto* child : node.children()) { + dot_graph += absl::StrFormat("%d -> %d;\n", node_id, child->expr()->id()); + TraverseAndCalculateCoverage(checked_expr, *child, stats_map, + log_unencountered, preceeding_tabs, report, + dot_graph); + } +} + +int32_t GetLineNumber(const cel::expr::SourceInfo& source_info, + int32_t offset) { + auto line_it = std::upper_bound(source_info.line_offsets().begin(), + source_info.line_offsets().end(), offset); + return std::distance(source_info.line_offsets().begin(), line_it) + 1; +} + +} // namespace + +void CoverageIndex::RecordCoverage(int64_t node_id, const cel::Value& value) { + NodeCoverageStats& stats = node_coverage_stats_[node_id]; + stats.covered = true; + if (node_coverage_stats_[node_id].is_boolean_node && value.IsBool()) { + if (value.AsBool()->NativeValue()) { + stats.has_true_branch = true; + } else { + stats.has_false_branch = true; + } + } +} + +void CoverageIndex::Init(const cel::expr::CheckedExpr& checked_expr) { + checked_expr_ = checked_expr; + navigable_ast_ = NavigableProtoAst::Build(checked_expr_.expr()); + for (const auto& node : navigable_ast_.Root().DescendantsPreorder()) { + NodeCoverageStats stats; + stats.is_boolean_node = InferredBooleanNode(checked_expr_, node); + node_coverage_stats_[node.expr()->id()] = stats; + } +} + +CoverageIndex::CoverageReport CoverageIndex::GetCoverageReport() const { + CoverageReport report; + if (node_coverage_stats_.empty()) { + return report; + } + + std::string dot_graph = std::string(kDigraphHeader); + TraverseAndCalculateCoverage(checked_expr_, navigable_ast_.Root(), + node_coverage_stats_, true, "", report, + dot_graph); + dot_graph += "}\n"; + report.dot_graph = dot_graph; + report.cel_expression = + google::api::expr::Unparse(checked_expr_).value_or(""); + return report; +} + +void CoverageIndex::WriteLCOV(absl::string_view path) { + std::ofstream file(std::string(path).c_str()); + if (!file.is_open()) { + return; + } + + // Maps instrumented line numbers to whether they are covered. + std::map lines; + const auto& positions = checked_expr_.source_info().positions(); + for (const auto& [node_id, stats] : node_coverage_stats_) { + auto it = positions.find(node_id); + if (it == positions.end()) continue; + int line_num = GetLineNumber(checked_expr_.source_info(), it->second); + bool& covered = lines[line_num]; + covered = covered || stats.covered; + } + + file << "SF:" << checked_expr_.source_info().location() << "\n"; + for (auto& [line_num, covered] : lines) { + file << "DA:" << line_num << "," << (covered ? 1 : 0) << "\n"; + } + file << "end_of_record\n"; +} + +InstrumentationFactory InstrumentationFactoryForCoverage( + CoverageIndex& coverage_index) { + return [&](const cel::Ast& ast) -> Instrumentation { + return [&](int64_t node_id, const cel::Value& value) -> absl::Status { + coverage_index.RecordCoverage(node_id, value); + return absl::OkStatus(); + }; + }; +} + +absl::Status EnableCoverageInRuntime(cel::Runtime& runtime, + CoverageIndex& coverage_index) { + auto& runtime_impl = + cel::internal::down_cast(runtime); + runtime_impl.expr_builder().AddProgramOptimizer( + google::api::expr::runtime::CreateInstrumentationExtension( + InstrumentationFactoryForCoverage(coverage_index))); + return absl::OkStatus(); +} + +absl::Status EnableCoverageInCelExpressionBuilder( + CelExpressionBuilder& cel_expression_builder, + CoverageIndex& coverage_index) { + auto& cel_expression_builder_impl = cel::internal::down_cast< + google::api::expr::runtime::CelExpressionBuilderFlatImpl&>( + cel_expression_builder); + cel_expression_builder_impl.flat_expr_builder().AddProgramOptimizer( + google::api::expr::runtime::CreateInstrumentationExtension( + InstrumentationFactoryForCoverage(coverage_index))); + return absl::OkStatus(); +} + +} // namespace cel::test diff --git a/testing/testrunner/coverage_index.h b/testing/testrunner/coverage_index.h new file mode 100644 index 000000000..746281494 --- /dev/null +++ b/testing/testrunner/coverage_index.h @@ -0,0 +1,123 @@ +// Copyright 2025 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_TESTING_TESTRUNNER_COVERAGE_INDEX_H_ +#define THIRD_PARTY_CEL_CPP_TESTING_TESTRUNNER_COVERAGE_INDEX_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "common/value.h" +#include "eval/public/cel_expression.h" +#include "runtime/runtime.h" +#include "tools/navigable_ast.h" + +namespace cel::test { +inline constexpr absl::string_view kDigraphHeader = "digraph {\n"; +inline constexpr absl::string_view kUncoveredNodeStyle = + R"(color="indianred2", style=filled)"; +inline constexpr absl::string_view kPartiallyCoveredNodeStyle = + R"(color="lightyellow", style=filled)"; +inline constexpr absl::string_view kCompletelyCoveredNodeStyle = + R"(color="lightgreen", style=filled)"; + +// `CoverageIndex` is a utility for tracking expression coverage based on the +// Abstract Syntax Tree (AST) of a `cel::expr::CheckedExpr`. +// +// To use `CoverageIndex`, it must first be initialized with a +// `cel::expr::CheckedExpr` using the `Init` method. This allows the +// index to build up a representation of all the nodes and potential boolean +// branches within the expression. +// +// The `CoverageIndex` is then integrated with the CEL evaluation process. +// This is done by enabling coverage either in a `cel::Runtime` or a +// `google::api::expr::runtime::CelExpressionBuilder` using the provided helper +// functions (`EnableCoverageInRuntime` or +// `EnableCoverageInCelExpressionBuilder`). When integrated, the CEL evaluation +// engine will call `RecordCoverage` for each visited expression node, allowing +// `CoverageIndex` to track which parts of the expression were executed and, +// for boolean-producing nodes, which branches were taken (true/false). +// +// After evaluation, a `CoverageReport` can be generated, summarizing the +// executed nodes and branches, and highlighting any unencountered parts of +// the expression. +class CoverageIndex { + public: + struct NodeCoverageStats { + bool is_boolean_node = false; + bool covered = false; + bool has_true_branch = false; + bool has_false_branch = false; + }; + + struct CoverageReport { + std::string cel_expression; + int64_t nodes = 0; + int64_t covered_nodes = 0; + int64_t branches = 0; + int64_t covered_boolean_outcomes = 0; + std::vector unencountered_nodes; + std::vector unencountered_branches; + std::string dot_graph; + }; + + // Initializes the coverage index with the given checked expression. + // + // The coverage index will be initialized with an entry for each node in the + // AST. + void Init(const cel::expr::CheckedExpr& checked_expr); + + // Records coverage for the given node. + // + // The coverage index will be updated with the coverage information for the + // given node. + void RecordCoverage(int64_t node_id, const cel::Value& value); + + // Returns a coverage report for the given checked expression. + CoverageReport GetCoverageReport() const; + + // Writes the coverage in LCOV format to the given path. + void WriteLCOV(absl::string_view path); + + private: + absl::flat_hash_map node_coverage_stats_; + NavigableProtoAst navigable_ast_; + cel::expr::CheckedExpr checked_expr_; +}; + +// Enables coverage tracking within the provided `cel::Runtime`. +// Note: This function ties the `runtime` instance to a single expression. +// Do not reuse this `runtime` instance with multiple expressions when coverage +// is enabled, as the `coverage_index` will accumulate results across different +// expressions, leading to incorrect coverage reports. +absl::Status EnableCoverageInCelExpressionBuilder( + google::api::expr::runtime::CelExpressionBuilder& cel_expression_builder, + CoverageIndex& coverage_index); + +// Enables coverage tracking within the provided `CelExpressionBuilder`. +// Note: This function ties the `cel_expression_builder` instance to a single +// expression. Do not reuse this `cel_expression_builder` instance with +// multiple expressions when coverage is enabled, as the `coverage_index` will +// accumulate results across different expressions, leading to incorrect +// coverage reports. +absl::Status EnableCoverageInRuntime(cel::Runtime& runtime, + CoverageIndex& coverage_index); + +} // namespace cel::test + +#endif // THIRD_PARTY_CEL_CPP_TESTING_TESTRUNNER_COVERAGE_INDEX_H_ diff --git a/testing/testrunner/coverage_index_test.cc b/testing/testrunner/coverage_index_test.cc new file mode 100644 index 000000000..6e9e2b0d3 --- /dev/null +++ b/testing/testrunner/coverage_index_test.cc @@ -0,0 +1,160 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "testing/testrunner/coverage_index.h" + +#include +#include +#include +#include +#include + +#include "cel/expr/syntax.pb.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "checker/type_checker_builder.h" +#include "checker/validation_result.h" +#include "common/ast.h" +#include "common/ast_proto.h" +#include "common/decl.h" +#include "common/type.h" +#include "common/value.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "compiler/standard_library.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "runtime/activation.h" +#include "runtime/runtime.h" +#include "runtime/runtime_builder.h" +#include "runtime/standard_runtime_builder_factory.h" +#include "google/protobuf/arena.h" + +namespace cel::test { +namespace { + +using ::absl_testing::IsOk; +using ::cel::expr::CheckedExpr; + +absl::StatusOr> CreateTestRuntime() { + CEL_ASSIGN_OR_RETURN(cel::RuntimeBuilder standard_runtime_builder, + cel::CreateStandardRuntimeBuilder( + cel::internal::GetTestingDescriptorPool(), {})); + return std::move(standard_runtime_builder).Build(); +} + +TEST(CoverageIndexTest, RecordCoverageWithErrorDoesNotCrash) { + ASSERT_OK_AND_ASSIGN( + std::unique_ptr compiler_builder, + cel::NewCompilerBuilder(cel::internal::GetTestingDescriptorPool())); + ASSERT_THAT(compiler_builder->AddLibrary(cel::StandardCompilerLibrary()), + IsOk()); + ASSERT_THAT(compiler_builder->GetCheckerBuilder().AddVariable( + cel::MakeVariableDecl("x", cel::IntType())), + IsOk()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, + std::move(compiler_builder)->Build()); + ASSERT_OK_AND_ASSIGN(cel::ValidationResult validation_result, + compiler->Compile("1/x > 1")); + CheckedExpr checked_expr; + ASSERT_THAT(cel::AstToCheckedExpr(*validation_result.GetAst(), &checked_expr), + IsOk()); + + CoverageIndex coverage_index; + coverage_index.Init(checked_expr); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, + CreateTestRuntime()); + ASSERT_THAT(EnableCoverageInRuntime(*const_cast(runtime.get()), + coverage_index), + IsOk()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, + cel::CreateAstFromCheckedExpr(checked_expr)); + ASSERT_OK_AND_ASSIGN(auto program, runtime->CreateProgram(std::move(ast))); + + cel::Activation activation; + activation.InsertOrAssignValue("x", cel::IntValue(0)); + google::protobuf::Arena arena; + ASSERT_OK_AND_ASSIGN(cel::Value result, + program->Evaluate(&arena, activation)); + EXPECT_TRUE(result.IsError()); +} + +TEST(CoverageIndexTest, WriteLCOV) { + ASSERT_OK_AND_ASSIGN( + std::unique_ptr compiler_builder, + cel::NewCompilerBuilder(cel::internal::GetTestingDescriptorPool())); + ASSERT_THAT(compiler_builder->AddLibrary(cel::StandardCompilerLibrary()), + IsOk()); + ASSERT_THAT(compiler_builder->GetCheckerBuilder().AddVariable( + cel::MakeVariableDecl("x", cel::BoolType())), + IsOk()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, + std::move(compiler_builder)->Build()); + const absl::string_view kSrc = R"(x ? +true : +false +)"; + ASSERT_OK_AND_ASSIGN(cel::ValidationResult validation_result, + compiler->Compile(kSrc)); + CheckedExpr checked_expr; + ASSERT_THAT(cel::AstToCheckedExpr(*validation_result.GetAst(), &checked_expr), + IsOk()); + checked_expr.mutable_source_info()->set_location("test.cel"); + + CoverageIndex coverage_index; + coverage_index.Init(checked_expr); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, + CreateTestRuntime()); + ASSERT_THAT(EnableCoverageInRuntime(*const_cast(runtime.get()), + coverage_index), + IsOk()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, + cel::CreateAstFromCheckedExpr(checked_expr)); + ASSERT_OK_AND_ASSIGN(auto program, runtime->CreateProgram(std::move(ast))); + + cel::Activation activation; + activation.InsertOrAssignValue("x", cel::BoolValue(true)); + google::protobuf::Arena arena; + ASSERT_OK_AND_ASSIGN(cel::Value result, + program->Evaluate(&arena, activation)); + EXPECT_TRUE(result.GetBool().NativeValue()); + + std::string temp_file = absl::StrCat(testing::TempDir(), "/coverage.lcov"); + coverage_index.WriteLCOV(temp_file); + + std::ifstream f(temp_file); + std::stringstream buffer; + buffer << f.rdbuf(); + std::string content = buffer.str(); + + // Verify content. + // We expect "test.cel" to be the source file. + EXPECT_THAT(content, testing::HasSubstr("SF:test.cel")); + // Line 1 (x ?) should be covered. + EXPECT_THAT(content, testing::HasSubstr("DA:1,1")); + // Line 2 (true) should be covered. + EXPECT_THAT(content, testing::HasSubstr("DA:2,1")); + // Line 3 (false) should be uncovered. + EXPECT_THAT(content, testing::HasSubstr("DA:3,0")); + // Line 4 (empty) should not be instrumented. + EXPECT_THAT(content, testing::Not(testing::HasSubstr("DA:4,"))); + EXPECT_THAT(content, testing::HasSubstr("end_of_record")); +} + +} // namespace +} // namespace cel::test diff --git a/testing/testrunner/coverage_reporting.cc b/testing/testrunner/coverage_reporting.cc new file mode 100644 index 000000000..d37386cc3 --- /dev/null +++ b/testing/testrunner/coverage_reporting.cc @@ -0,0 +1,124 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "testing/testrunner/coverage_reporting.h" + +#include +#include +#include +#include +#include + +#include "absl/log/absl_log.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_replace.h" +#include "absl/strings/string_view.h" +#include "internal/testing.h" +#include "testing/testrunner/coverage_index.h" + +namespace cel::test { +void CoverageReportingEnvironment::TearDown() { + CoverageIndex::CoverageReport coverage_report = + coverage_index_.GetCoverageReport(); + testing::Test::RecordProperty("CEL Expression", + coverage_report.cel_expression); + std::cout << "CEL Expression: " << coverage_report.cel_expression; + if (coverage_report.nodes == 0) { + testing::Test::RecordProperty("CEL Coverage", "No coverage stats found"); + std::cout << "CEL Coverage: " << "No coverage stats found"; + return; + } + + // Log Node Coverage results + double node_coverage = static_cast(coverage_report.covered_nodes) / + static_cast(coverage_report.nodes) * 100.0; + std::string node_coverage_string = + absl::StrFormat("%.2f%% (%d out of %d nodes covered)", node_coverage, + coverage_report.covered_nodes, coverage_report.nodes); + testing::Test::RecordProperty("AST Node Coverage", node_coverage_string); + std::cout << "AST Node Coverage: " << node_coverage_string; + if (!coverage_report.unencountered_nodes.empty()) { + testing::Test::RecordProperty( + "Interesting Unencountered Nodes", + absl::StrJoin(coverage_report.unencountered_nodes, "\n")); + std::cout << "Interesting Unencountered Nodes: " + << absl::StrJoin(coverage_report.unencountered_nodes, "\n"); + } + + // Log Branch Coverage results + double branch_coverage = 0.0; + if (coverage_report.branches > 0) { + branch_coverage = + static_cast(coverage_report.covered_boolean_outcomes) / + static_cast(coverage_report.branches) * 100.0; + } + std::string branch_coverage_string = absl::StrFormat( + "%.2f%% (%d out of %d branch outcomes covered)", branch_coverage, + coverage_report.covered_boolean_outcomes, coverage_report.branches); + testing::Test::RecordProperty("AST Branch Coverage", branch_coverage_string); + std::cout << "AST Branch Coverage: " << branch_coverage_string; + if (!coverage_report.unencountered_branches.empty()) { + testing::Test::RecordProperty( + "Interesting Unencountered Branch Paths", + absl::StrJoin(coverage_report.unencountered_branches, "\n")); + std::cout << "Interesting Unencountered Branch Paths: " + << absl::StrJoin(coverage_report.unencountered_branches, + "\n"); + } + if (!coverage_report.dot_graph.empty()) { + WriteDotGraphToArtifact(coverage_report.dot_graph); + } +} + +void CoverageReportingEnvironment::WriteDotGraphToArtifact( + absl::string_view dot_graph) { + // Save DOT graph to file in TEST_UNDECLARED_OUTPUTS_DIR or default dir + const char* outputs_dir_env = std::getenv("TEST_UNDECLARED_OUTPUTS_DIR"); + // For non-Bazel/Blaze users, we write to a subdirectory under the current + // working directory. + // NOMUTANTS --cel_artifacts is for non-Bazel/Blaze users only so not + // needed to test in our case. + std::string outputs_dir = + (outputs_dir_env == nullptr) ? "cel_artifacts" : outputs_dir_env; + std::string coverage_dir = absl::StrCat(outputs_dir, "/cel_test_coverage"); + // Creates the directory to store CEL test coverage artifacts. + // The second argument, `0755`, sets the directory's permissions in octal + // format, which is a standard for file system operations. It grants: + // - Owner: read, write, and execute permissions (7 = 4+2+1). + // - Group: read and execute permissions (5 = 4+1). + // - Others: read and execute permissions (5 = 4+1). + // This gives the owner full control while allowing other users to access + // the generated artifacts. + int mkdir_result = mkdir(coverage_dir.c_str(), 0755); + // If mkdir fails, it sets the global 'errno' variable to an error code + // indicating the reason. We check this code to specifically ignore the + // EEXIST error, which just means the directory already exists (this is not + // a real failure we need to warn about). + if (mkdir_result == 0 || errno == EEXIST) { + std::string graph_path = absl::StrCat(coverage_dir, "/coverage_graph.txt"); + std::ofstream out(graph_path); + if (out.is_open()) { + out << dot_graph; + out.close(); + } else { + ABSL_LOG(WARNING) << "Failed to open file for writing: " << graph_path; + } + } else { + ABSL_LOG(WARNING) << "Failed to create directory: " << coverage_dir + << " (reason: " << strerror(errno) << ")"; + } +} +} // namespace cel::test diff --git a/testing/testrunner/coverage_reporting.h b/testing/testrunner/coverage_reporting.h new file mode 100644 index 000000000..2e1f4ad23 --- /dev/null +++ b/testing/testrunner/coverage_reporting.h @@ -0,0 +1,43 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_TESTING_TESTRUNNER_COVERAGE_REPORTING_H_ +#define THIRD_PARTY_CEL_CPP_TESTING_TESTRUNNER_COVERAGE_REPORTING_H_ + +#include "absl/strings/string_view.h" +#include "internal/testing.h" +#include "testing/testrunner/coverage_index.h" + +namespace cel::test { +// A Google Test Environment that reports CEL coverage results in its TearDown +// phase. +// +// This class encapsulates the logic for calculating coverage statistics and +// logging them as test properties. +class CoverageReportingEnvironment : public testing::Environment { + public: + explicit CoverageReportingEnvironment(CoverageIndex& coverage_index) + : coverage_index_(coverage_index) {}; + + // Called by the Google Test framework after all tests have run. + void TearDown() override; + + private: + // Helper function to write the DOT graph to a test artifact file. + void WriteDotGraphToArtifact(absl::string_view dot_graph); + + CoverageIndex& coverage_index_; +}; +} // namespace cel::test +#endif // THIRD_PARTY_CEL_CPP_TESTING_TESTRUNNER_COVERAGE_REPORTING_H_ diff --git a/testing/testrunner/resources/BUILD b/testing/testrunner/resources/BUILD new file mode 100644 index 000000000..241746fd5 --- /dev/null +++ b/testing/testrunner/resources/BUILD @@ -0,0 +1,14 @@ +package(default_visibility = ["//visibility:public"]) + +exports_files( + [ + "test.cel", + ], +) + +filegroup( + name = "resources", + srcs = glob([ + "*.textproto", + ]), +) diff --git a/testing/testrunner/resources/simple_tests.textproto b/testing/testrunner/resources/simple_tests.textproto new file mode 100644 index 000000000..7add08851 --- /dev/null +++ b/testing/testrunner/resources/simple_tests.textproto @@ -0,0 +1,44 @@ +# proto-file: google3/third_party/cel/spec/proto/cel/expr/conformance/test/suite.proto +# proto-message: cel.expr.conformance.test.TestSuite + +name: "simple_tests" +description: "Simple tests to validate the test runner." +sections: { + name: "simple_map_operations" + description: "Tests for map operations." + tests: { + name: "literal_and_sum" + description: "Test that a map can be created and values can be accessed." + input: { + key: "x" + value { value { int64_value: 1 } } + } + input { + key: "y" + value { value { int64_value: 2 } } + } + output { + result_value { + bool_value: true + } + } + } + tests: { + name: "literal_and_sum_2_5" + description: "Test that a map can be created and values can be accessed." + input: { + key: "x" + value { value { int64_value: 2 } } + } + input { + key: "y" + value { value { int64_value: 5 } } + } + output { + result_value { + bool_value: false + } + } + } +} + diff --git a/testing/testrunner/resources/test.cel b/testing/testrunner/resources/test.cel new file mode 100644 index 000000000..e2a8707df --- /dev/null +++ b/testing/testrunner/resources/test.cel @@ -0,0 +1 @@ +x-y \ No newline at end of file diff --git a/testing/testrunner/resources/test_environment.textproto b/testing/testrunner/resources/test_environment.textproto new file mode 100644 index 000000000..77e3b180f --- /dev/null +++ b/testing/testrunner/resources/test_environment.textproto @@ -0,0 +1,15 @@ +# proto-file: third_party/cel/go/tools/compilecli/compile_input.proto +# proto-message: Environment + +declarations: { + name: "x" + ident: { + type: { primitive: INT64 } + } +} +declarations: { + name: "y" + ident: { + type: { primitive: INT64 } + } +} diff --git a/testing/testrunner/runner_bin.cc b/testing/testrunner/runner_bin.cc new file mode 100644 index 000000000..c11908ca5 --- /dev/null +++ b/testing/testrunner/runner_bin.cc @@ -0,0 +1,295 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This binary is a test runner for CEL tests. It is used to run CEL tests +// written in the CEL test suite format. +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "cel/expr/checked.pb.h" +#include "absl/flags/flag.h" +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "eval/public/cel_expression.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "runtime/runtime.h" +#include "testing/testrunner/cel_expression_source.h" +#include "testing/testrunner/cel_test_context.h" +#include "testing/testrunner/cel_test_factories.h" +#include "testing/testrunner/coverage_index.h" +#include "testing/testrunner/coverage_reporting.h" +#include "testing/testrunner/runner_lib.h" +#include "cel/expr/conformance/test/suite.pb.h" +#include "google/protobuf/text_format.h" + +ABSL_FLAG(std::string, test_suite_path, "", + "The path to the file containing the test suite to run."); +ABSL_FLAG(std::string, expr_source_type, "", + "The kind of expression source: 'raw', 'file', or 'checked'."); +ABSL_FLAG(std::string, expr_source, "", + "The value of the CEL expression source. For 'raw', it's the " + "expression string. For 'file' and 'checked', it's the file path."); + +ABSL_FLAG(bool, collect_coverage, false, "Whether to collect code coverage."); + +namespace { + +using ::cel::expr::conformance::test::TestCase; +using ::cel::expr::conformance::test::TestSuite; +using ::cel::test::CelExpressionSource; +using ::cel::test::CelTestContext; +using ::cel::test::CoverageIndex; +using ::cel::test::TestRunner; +using ::cel::expr::CheckedExpr; +using ::google::api::expr::runtime::CelExpressionBuilder; + +class CelTest : public testing::Test { + public: + explicit CelTest(std::shared_ptr test_runner, + const TestCase& test_case) + : test_runner_(std::move(test_runner)), test_case_(test_case) {} + + void TestBody() override { test_runner_->RunTest(test_case_); } + + private: + std::shared_ptr test_runner_; + TestCase test_case_; +}; + +absl::Status RegisterTests(const TestSuite& test_suite, + const std::shared_ptr& test_runner) { + for (const auto& section : test_suite.sections()) { + for (const TestCase& test_case : section.tests()) { + testing::RegisterTest( + test_suite.name().c_str(), + absl::StrCat(section.name(), "/", test_case.name()).c_str(), nullptr, + nullptr, __FILE__, __LINE__, [&test_runner, test_case]() -> CelTest* { + return new CelTest(test_runner, test_case); + }); + } + } + return absl::OkStatus(); +} + +absl::StatusOr ReadFileToString(absl::string_view file_path) { + std::ifstream file_stream{std::string(file_path)}; + if (!file_stream.is_open()) { + return absl::NotFoundError( + absl::StrCat("Unable to open file: ", file_path)); + } + std::stringstream buffer; + buffer << file_stream.rdbuf(); + return buffer.str(); +} + +template +absl::StatusOr ReadTextProtoFromFile(absl::string_view file_path) { + CEL_ASSIGN_OR_RETURN(std::string contents, ReadFileToString(file_path)); + T message; + if (!google::protobuf::TextFormat::ParseFromString(contents, &message)) { + return absl::InternalError(absl::StrCat( + "Failed to parse text-format proto from file: ", file_path)); + } + return message; +} + +absl::StatusOr ReadBinaryProtoFromFile( + absl::string_view file_path) { + CheckedExpr message; + std::ifstream file_stream{std::string(file_path), std::ios::binary}; + if (!file_stream.is_open()) { + return absl::NotFoundError( + absl::StrCat("Unable to open file: ", file_path)); + } + if (!message.ParseFromIstream(&file_stream)) { + return absl::InternalError( + absl::StrCat("Failed to parse binary proto from file: ", file_path)); + } + return message; +} + +TestSuite ReadTestSuiteFromPath(absl::string_view test_suite_path) { + absl::StatusOr test_suite_or = + ReadTextProtoFromFile(test_suite_path); + + if (!test_suite_or.ok()) { + ABSL_LOG(FATAL) << "Failed to load test suite from " << test_suite_path + << ": " << test_suite_or.status(); + } + return *std::move(test_suite_or); +} + +absl::StatusOr ReadCheckedExprFromFile( + absl::string_view file_path) { + if (absl::EndsWith(file_path, ".textproto")) { + return ReadTextProtoFromFile(file_path); + } + if (absl::EndsWith(file_path, ".binarypb")) { + return ReadBinaryProtoFromFile(file_path); + } + return absl::InvalidArgumentError(absl::StrCat( + "Unknown file extension for checked expression. ", + "Please use .textproto, .textpb, .pb, or .binarypb: ", file_path)); +} + +TestSuite GetTestSuite() { + std::string test_suite_path = absl::GetFlag(FLAGS_test_suite_path); + if (!test_suite_path.empty()) { + return ReadTestSuiteFromPath(test_suite_path); + } + + // If no test suite path is provided, use the factory function to get the + // test suite after checking if the factory function is empty or not. + std::function test_suite_factory = + cel::test::internal::GetCelTestSuiteFactory(); + if (test_suite_factory == nullptr) { + ABSL_LOG(FATAL) + << "No CEL test suite provided. Please provide a test suite using " + "either the bzl macro or the CEL_REGISTER_TEST_SUITE_FACTORY " + "preprocessor macro."; + } + return test_suite_factory(); +} + +void UpdateWithExpressionFromCommandLineFlags( + CelTestContext& cel_test_context) { + if (absl::GetFlag(FLAGS_expr_source).empty()) { + return; + } + + constexpr absl::string_view kRawExpressionKind = "raw"; + constexpr absl::string_view kFileExpressionKind = "file"; + constexpr absl::string_view kCheckedExpressionKind = "checked"; + + std::string kind = absl::GetFlag(FLAGS_expr_source_type); + std::string value = absl::GetFlag(FLAGS_expr_source); + + std::optional expression_source_from_flags; + if (kind == kRawExpressionKind) { + expression_source_from_flags = + CelExpressionSource::FromRawExpression(value); + } else if (kind == kFileExpressionKind) { + expression_source_from_flags = CelExpressionSource::FromCelFile(value); + } else if (kind == kCheckedExpressionKind) { + absl::StatusOr checked_expr = ReadCheckedExprFromFile(value); + if (!checked_expr.ok()) { + ABSL_LOG(FATAL) << "Failed to read checked expression from file: " + << checked_expr.status(); + } + expression_source_from_flags = + CelExpressionSource::FromCheckedExpr(std::move(*checked_expr)); + } else { + ABSL_LOG(FATAL) << "Unknown expression kind: " << kind; + } + + // Check for conflicting expression sources. + if (cel_test_context.expression_source() != nullptr) { + ABSL_LOG(FATAL) + << "Expression source can only be set once and is currently set via " + "the factory."; + } + + if (expression_source_from_flags.has_value()) { + cel_test_context.SetExpressionSource( + std::move(*expression_source_from_flags)); + } +} + +} // namespace + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + // Create a test context using the factory function returned by the global + // factory function provider which was initialized by the user. + absl::StatusOr> + cel_test_context_or = cel::test::internal::GetCelTestContextFactory()(); + if (!cel_test_context_or.ok()) { + ABSL_LOG(FATAL) << "Failed to create CEL test context from factory: " + << cel_test_context_or.status(); + } + std::unique_ptr cel_test_context = + std::move(cel_test_context_or.value()); + + // We manually enable coverage here instead of just setting the + // `enable_coverage` flag on the context. This is intentional and necessary + // for this binary's reporting model. + // + // This binary needs a single coverage report for all tests run. + // We create `coverage_index` here, local to the `main` function, so its + // lifetime spans the entire test run. + // + // We must pass this specific instance to the + // `CoverageReportingEnvironment`, which Google Test calls after all + // dynamically registered tests are finished. + // + // If we just set the `enable_coverage` flag, the `TestRunner`'s + // constructor (as used in our `cc_test` files) would create its own + // internal `CoverageIndex`. That internal index would be destroyed + // with the `TestRunner` and would not populate the `coverage_index` + // instance needed by our global reporter. + // + // This manual approach ensures all tests populate the same `coverage_index` + // (the one local to `main`), which is then ready for the final report. + cel::test::CoverageIndex coverage_index; + + if (absl::GetFlag(FLAGS_collect_coverage)) { + if (cel_test_context->runtime() != nullptr) { + ABSL_CHECK_OK(cel::test::EnableCoverageInRuntime( + const_cast(*cel_test_context->runtime()), + coverage_index)); + } else if (cel_test_context->cel_expression_builder() != nullptr) { + ABSL_CHECK_OK(cel::test::EnableCoverageInCelExpressionBuilder( + const_cast( + *cel_test_context->cel_expression_builder()), + coverage_index)); + } + } + + // Update the context with an expression from flags, if provided. + // This will FATAL if an expression is set by both the factory and flags. + UpdateWithExpressionFromCommandLineFlags(*cel_test_context); + + auto test_runner = std::make_shared(std::move(cel_test_context)); + ABSL_CHECK_OK(RegisterTests(GetTestSuite(), test_runner)); + + // Make sure the checked expression exists during the entire test run since + // the ast references it during coverage collection at teardown. + absl::StatusOr checked_expr = + test_runner->GetCheckedExpr(); + if (!checked_expr.ok()) { + ABSL_LOG(FATAL) << "Failed to get checked expression: " + << checked_expr.status(); + } + + if (absl::GetFlag(FLAGS_collect_coverage)) { + coverage_index.Init(*checked_expr); + testing::AddGlobalTestEnvironment( + new cel::test::CoverageReportingEnvironment(coverage_index)); + } + + return RUN_ALL_TESTS(); +} diff --git a/testing/testrunner/runner_lib.cc b/testing/testrunner/runner_lib.cc new file mode 100644 index 000000000..28806cec7 --- /dev/null +++ b/testing/testrunner/runner_lib.cc @@ -0,0 +1,443 @@ +// Copyright 2025 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "testing/testrunner/runner_lib.h" + +#include +#include +#include +#include +#include +#include + +#include "cel/expr/eval.pb.h" +#include "absl/functional/overload.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "checker/validation_result.h" +#include "common/ast.h" +#include "common/ast_proto.h" +#include "common/internal/value_conversion.h" +#include "common/value.h" +#include "eval/public/activation.h" +#include "eval/public/cel_expression.h" +#include "eval/public/cel_value.h" +#include "eval/public/transform_utility.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "runtime/activation.h" +#include "runtime/runtime.h" +#include "testing/testrunner/cel_expression_source.h" +#include "testing/testrunner/cel_test_context.h" +#include "testing/testrunner/coverage_index.h" +#include "cel/expr/conformance/test/suite.pb.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" +#include "google/protobuf/util/field_comparator.h" +#include "google/protobuf/util/message_differencer.h" + +namespace cel::test { +namespace { + +using ::cel::expr::conformance::test::InputValue; +using ::cel::expr::conformance::test::TestCase; +using ::cel::expr::conformance::test::TestOutput; +using ::cel::expr::CheckedExpr; +using ::google::api::expr::runtime::CelExpression; +using ::google::api::expr::runtime::ValueToCelValue; +using ::google::api::expr::runtime::Activation; + +using LegacyCelValue = ::google::api::expr::runtime::CelValue; +using ValueProto = ::cel::expr::Value; + +absl::StatusOr ReadFileToString(absl::string_view file_path) { + std::ifstream file_stream{std::string(file_path)}; + if (!file_stream.is_open()) { + return absl::NotFoundError( + absl::StrCat("Unable to open file: ", file_path)); + } + std::stringstream buffer; + buffer << file_stream.rdbuf(); + return buffer.str(); +} + +absl::StatusOr Compile(absl::string_view expression, + const CelTestContext& context) { + const auto* compiler = context.compiler(); + if (compiler == nullptr) { + return absl::InvalidArgumentError( + "A compiler must be provided to compile a raw expression or .cel " + "file."); + } + + CEL_ASSIGN_OR_RETURN(ValidationResult validation_result, + compiler->Compile(expression)); + if (!validation_result.IsValid()) { + return absl::InternalError(validation_result.FormatError()); + } + + CheckedExpr checked_expr; + CEL_RETURN_IF_ERROR( + AstToCheckedExpr(*validation_result.GetAst(), &checked_expr)); + return checked_expr; +} + +absl::StatusOr> Plan( + const CheckedExpr& checked_expr, const cel::Runtime* runtime) { + std::unique_ptr ast; + CEL_ASSIGN_OR_RETURN(ast, cel::CreateAstFromCheckedExpr(checked_expr)); + if (ast == nullptr) { + return absl::InternalError("No expression provided for testing."); + } + return runtime->CreateProgram(std::move(ast)); +} + +const google::protobuf::DescriptorPool* GetDescriptorPool(const CelTestContext& context) { + return context.cel_expression_builder() != nullptr + ? google::protobuf::DescriptorPool::generated_pool() + : context.runtime()->GetDescriptorPool(); +} + +google::protobuf::MessageFactory* GetMessageFactory(const CelTestContext& context) { + return context.cel_expression_builder() != nullptr + ? google::protobuf::MessageFactory::generated_factory() + : context.runtime()->GetMessageFactory(); +} + +absl::StatusOr EvalWithModernBindings( + const CheckedExpr& checked_expr, const CelTestContext& context, + const cel::Activation& activation, google::protobuf::Arena* arena) { + CEL_ASSIGN_OR_RETURN(std::unique_ptr program, + Plan(checked_expr, context.runtime())); + return program->Evaluate(arena, activation); +} + +absl::StatusOr EvalWithLegacyBindings( + const CheckedExpr& checked_expr, const CelTestContext& context, + const Activation& activation, google::protobuf::Arena* arena) { + const auto* builder = context.cel_expression_builder(); + + CEL_ASSIGN_OR_RETURN(std::unique_ptr sub_expression, + builder->CreateExpression(&checked_expr)); + + CEL_ASSIGN_OR_RETURN(LegacyCelValue legacy_result, + sub_expression->Evaluate(activation, arena)); + + ValueProto result_proto; + CEL_RETURN_IF_ERROR(CelValueToValue(legacy_result, &result_proto)); + return FromExprValue(result_proto, GetDescriptorPool(context), + GetMessageFactory(context), arena); +} + +absl::StatusOr ResolveValue(const InputValue& input_value, + const CelTestContext& context, + google::protobuf::Arena* arena) { + return FromExprValue(input_value.value(), GetDescriptorPool(context), + GetMessageFactory(context), arena); +} + +absl::StatusOr ResolveExpr(absl::string_view expr, + const CelTestContext& context, + google::protobuf::Arena* arena) { + CEL_ASSIGN_OR_RETURN(CheckedExpr checked_expr, Compile(expr, context)); + if (context.runtime() != nullptr) { + cel::Activation empty_activation; + return EvalWithModernBindings(checked_expr, context, empty_activation, + arena); + } else { + Activation empty_activation; + return EvalWithLegacyBindings(checked_expr, context, empty_activation, + arena); + } +} + +absl::StatusOr ResolveInputValue(const InputValue& input_value, + const CelTestContext& context, + google::protobuf::Arena* arena) { + switch (input_value.kind_case()) { + case InputValue::kValue: { + return ResolveValue(input_value, context, arena); + } + case InputValue::kExpr: { + return ResolveExpr(input_value.expr(), context, arena); + } + default: + return absl::InvalidArgumentError("Unknown InputValue kind."); + } +} + +absl::Status AddCustomBindingsToModernActivation(const CelTestContext& context, + cel::Activation& activation, + google::protobuf::Arena* arena) { + for (const auto& binding : context.custom_bindings()) { + CEL_ASSIGN_OR_RETURN(cel::Value value, + FromExprValue(/*value_proto=*/binding.second, + GetDescriptorPool(context), + GetMessageFactory(context), arena)); + activation.InsertOrAssignValue(/*name=*/binding.first, value); + } + return absl::OkStatus(); +} + +absl::Status AddTestCaseBindingsToModernActivation( + const TestCase& test_case, const CelTestContext& context, + cel::Activation& activation, google::protobuf::Arena* arena) { + for (const auto& binding : test_case.input()) { + CEL_ASSIGN_OR_RETURN( + cel::Value value, + ResolveInputValue(/*input_value=*/binding.second, context, arena)); + activation.InsertOrAssignValue(/*name=*/binding.first, std::move(value)); + } + return absl::OkStatus(); +} + +absl::StatusOr GetActivation(const CelTestContext& context, + const TestCase& test_case, + google::protobuf::Arena* arena) { + if (context.activation_factory() != nullptr) { + return context.activation_factory()(test_case, arena); + } + return cel::Activation(); +} + +absl::StatusOr CreateModernActivationFromBindings( + const TestCase& test_case, const CelTestContext& context, + google::protobuf::Arena* arena) { + CEL_ASSIGN_OR_RETURN(cel::Activation activation, + GetActivation(context, test_case, arena)); + CEL_RETURN_IF_ERROR( + AddCustomBindingsToModernActivation(context, activation, arena)); + + CEL_RETURN_IF_ERROR(AddTestCaseBindingsToModernActivation(test_case, context, + activation, arena)); + + return activation; +} + +absl::Status AddCustomBindingsToLegacyActivation(const CelTestContext& context, + Activation& activation, + google::protobuf::Arena* arena) { + for (const auto& binding : context.custom_bindings()) { + CEL_ASSIGN_OR_RETURN( + LegacyCelValue value, + ValueToCelValue(/*value_proto=*/binding.second, arena)); + activation.InsertValue(/*name=*/binding.first, value); + } + return absl::OkStatus(); +} + +absl::Status AddTestCaseBindingsToLegacyActivation( + const TestCase& test_case, const CelTestContext& context, + Activation& activation, google::protobuf::Arena* arena) { + auto* message_factory = GetMessageFactory(context); + auto* descriptor_pool = GetDescriptorPool(context); + for (const auto& binding : test_case.input()) { + CEL_ASSIGN_OR_RETURN( + cel::Value resolved_cel_value, + ResolveInputValue(/*input_value=*/binding.second, context, arena)); + CEL_ASSIGN_OR_RETURN(ValueProto value_proto, + ToExprValue(resolved_cel_value, descriptor_pool, + message_factory, arena)); + CEL_ASSIGN_OR_RETURN(LegacyCelValue value, + ValueToCelValue(value_proto, arena)); + activation.InsertValue(/*name=*/binding.first, value); + } + return absl::OkStatus(); +} + +absl::StatusOr CreateLegacyActivationFromBindings( + const TestCase& test_case, const CelTestContext& context, + google::protobuf::Arena* arena) { + Activation activation; + + CEL_RETURN_IF_ERROR( + AddCustomBindingsToLegacyActivation(context, activation, arena)); + + CEL_RETURN_IF_ERROR(AddTestCaseBindingsToLegacyActivation(test_case, context, + activation, arena)); + + return activation; +} + +bool IsEqual(const ValueProto& expected, const ValueProto& actual) { + static auto* kFieldComparator = []() { + auto* field_comparator = new google::protobuf::util::DefaultFieldComparator(); + field_comparator->set_treat_nan_as_equal(true); + return field_comparator; + }(); + static auto* kDifferencer = []() { + auto* differencer = new google::protobuf::util::MessageDifferencer(); + differencer->set_message_field_comparison( + google::protobuf::util::MessageDifferencer::EQUIVALENT); + differencer->set_field_comparator(kFieldComparator); + const auto* descriptor = cel::expr::MapValue::descriptor(); + const auto* entries_field = descriptor->FindFieldByName("entries"); + const auto* key_field = + entries_field->message_type()->FindFieldByName("key"); + differencer->TreatAsMap(entries_field, key_field); + return differencer; + }(); + return kDifferencer->Compare(expected, actual); +} + +MATCHER_P(MatchesValue, expected, "") { return IsEqual(arg, expected); } +} // namespace + +void TestRunner::AssertValue(const cel::Value& computed, + const TestOutput& output, google::protobuf::Arena* arena) { + if (computed.IsError()) { + ADD_FAILURE() << "Expected value but got error: " << computed.DebugString(); + return; + } + ValueProto expected_value_proto; + const auto* descriptor_pool = GetDescriptorPool(*test_context_); + auto* message_factory = GetMessageFactory(*test_context_); + if (output.has_result_value()) { + expected_value_proto = output.result_value(); + } else if (output.has_result_expr()) { + InputValue input_value; + input_value.set_expr(output.result_expr()); + ASSERT_OK_AND_ASSIGN(cel::Value resolved_cel_value, + ResolveInputValue(input_value, *test_context_, arena)); + ASSERT_OK_AND_ASSIGN(expected_value_proto, + ToExprValue(resolved_cel_value, descriptor_pool, + message_factory, arena)); + } + ValueProto computed_expr_value; + ASSERT_OK_AND_ASSIGN( + computed_expr_value, + ToExprValue(computed, descriptor_pool, message_factory, arena)); + EXPECT_THAT(computed_expr_value, MatchesValue(expected_value_proto)); +} + +void TestRunner::AssertError(const cel::Value& computed, + const TestOutput& output) { + if (!computed.IsError()) { + ADD_FAILURE() << "Expected error but got value: " << computed.DebugString(); + return; + } + absl::Status computed_status = computed.AsError()->ToStatus(); + // We selected the first error in the set for comparison because there is only + // one runtime error that is reported even if there are multiple errors in the + // critical path. + ASSERT_TRUE(output.eval_error().errors_size() == 1) + << "Expected exactly one error but got: " + << output.eval_error().errors_size(); + ASSERT_EQ(computed_status.message(), output.eval_error().errors(0).message()); +} + +void TestRunner::Assert(const cel::Value& computed, const TestCase& test_case, + google::protobuf::Arena* arena) { + if (test_context_->assert_fn()) { + test_context_->assert_fn()(computed, test_case, arena); + return; + } + TestOutput output = test_case.output(); + if (output.has_result_value() || output.has_result_expr()) { + AssertValue(computed, output, arena); + } else if (output.has_eval_error()) { + AssertError(computed, output); + } else if (output.has_unknown()) { + ADD_FAILURE() << "Unknown assertions not implemented yet."; + } else { + ADD_FAILURE() << "Unexpected output kind."; + } +} + +absl::StatusOr TestRunner::EvalWithRuntime( + const CheckedExpr& checked_expr, const TestCase& test_case, + google::protobuf::Arena* arena) { + CEL_ASSIGN_OR_RETURN( + cel::Activation activation, + CreateModernActivationFromBindings(test_case, *test_context_, arena)); + return EvalWithModernBindings(checked_expr, *test_context_, activation, + arena); +} + +absl::StatusOr TestRunner::EvalWithCelExpressionBuilder( + const CheckedExpr& checked_expr, const TestCase& test_case, + google::protobuf::Arena* arena) { + CEL_ASSIGN_OR_RETURN( + Activation activation, + CreateLegacyActivationFromBindings(test_case, *test_context_, arena)); + return EvalWithLegacyBindings(checked_expr, *test_context_, activation, + arena); +} + +absl::StatusOr TestRunner::GetCheckedExpr() const { + const CelExpressionSource* source_ptr = test_context_->expression_source(); + if (source_ptr == nullptr) { + return absl::InvalidArgumentError("No expression source provided."); + } + return std::visit( + absl::Overload([](const cel::expr::CheckedExpr& v) + -> absl::StatusOr { return v; }, + [this](const CelExpressionSource::RawExpression& v) + -> absl::StatusOr { + return Compile(v.value, *test_context_); + }, + [this](const CelExpressionSource::CelFile& v) + -> absl::StatusOr { + CEL_ASSIGN_OR_RETURN(std::string contents, + ReadFileToString(v.path)); + return Compile(contents, *test_context_); + }), + source_ptr->source()); +} + +absl::Status TestRunner::EnableCoverage() { + if (test_context_ != nullptr && test_context_->enable_coverage()) { + coverage_index_ = std::make_unique(); + + if (test_context_->runtime() != nullptr) { + auto* runtime = const_cast(test_context_->runtime()); + CEL_RETURN_IF_ERROR(EnableCoverageInRuntime(*runtime, *coverage_index_)); + } else if (test_context_->cel_expression_builder() != nullptr) { + auto* builder = + const_cast( + test_context_->cel_expression_builder()); + CEL_RETURN_IF_ERROR( + EnableCoverageInCelExpressionBuilder(*builder, *coverage_index_)); + } + } + return absl::OkStatus(); +} + +void TestRunner::RunTest(const TestCase& test_case) { + // The arena has to be declared in RunTest because cel::Value returned by + // EvalWithRuntime or EvalWithCelExpressionBuilder might contain pointers to + // the arena. The arena has to be alive during the assertion. + google::protobuf::Arena arena; + ASSERT_THAT(EnableCoverage(), absl_testing::IsOk()); + ASSERT_OK_AND_ASSIGN(CheckedExpr checked_expr, GetCheckedExpr()); + + if (coverage_index_) { + coverage_index_->Init(checked_expr); + } + + if (test_context_->runtime() != nullptr) { + ASSERT_OK_AND_ASSIGN(cel::Value result, + EvalWithRuntime(checked_expr, test_case, &arena)); + ASSERT_NO_FATAL_FAILURE(Assert(result, test_case, &arena)); + } else if (test_context_->cel_expression_builder() != nullptr) { + ASSERT_OK_AND_ASSIGN( + cel::Value result, + EvalWithCelExpressionBuilder(checked_expr, test_case, &arena)); + ASSERT_NO_FATAL_FAILURE(Assert(result, test_case, &arena)); + } +} +} // namespace cel::test diff --git a/testing/testrunner/runner_lib.h b/testing/testrunner/runner_lib.h new file mode 100644 index 000000000..4fcbed13a --- /dev/null +++ b/testing/testrunner/runner_lib.h @@ -0,0 +1,84 @@ +// Copyright 2025 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_TESTING_TESTRUNNER_RUNNER_LIBRARY_H_ +#define THIRD_PARTY_CEL_CPP_TESTING_TESTRUNNER_RUNNER_LIBRARY_H_ + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "common/value.h" +#include "testing/testrunner/cel_test_context.h" +#include "testing/testrunner/coverage_index.h" +#include "testing/testrunner/coverage_reporting.h" +#include "cel/expr/conformance/test/suite.pb.h" +#include "google/protobuf/arena.h" + +namespace cel::test { + +// The test runner class for running CEL tests. +class TestRunner { + public: + explicit TestRunner(std::unique_ptr test_context) + : test_context_(std::move(test_context)) {} + + // Automatically reports coverage results. + ~TestRunner() { + if (coverage_index_) { + CoverageReportingEnvironment reporter(*coverage_index_); + reporter.TearDown(); + } + } + + // Evaluates the checked expression in the test case, performs the + // assertions against the expected result. + void RunTest(const cel::expr::conformance::test::TestCase& test_case); + + // Returns the checked expression for the test case. + absl::StatusOr GetCheckedExpr() const; + + private: + absl::StatusOr EvalWithRuntime( + const cel::expr::CheckedExpr& checked_expr, + const cel::expr::conformance::test::TestCase& test_case, + google::protobuf::Arena* arena); + + absl::StatusOr EvalWithCelExpressionBuilder( + const cel::expr::CheckedExpr& checked_expr, + const cel::expr::conformance::test::TestCase& test_case, + google::protobuf::Arena* arena); + + void Assert(const cel::Value& computed, + const cel::expr::conformance::test::TestCase& test_case, + google::protobuf::Arena* arena); + + void AssertValue(const cel::Value& computed, + const cel::expr::conformance::test::TestOutput& output, + google::protobuf::Arena* arena); + + void AssertError(const cel::Value& computed, + const cel::expr::conformance::test::TestOutput& output); + + absl::Status EnableCoverage(); + + std::unique_ptr test_context_; + + std::unique_ptr coverage_index_; +}; + +} // namespace cel::test + +#endif // THIRD_PARTY_CEL_CPP_TESTING_TESTRUNNER_RUNNER_LIBRARY_H_ diff --git a/testing/testrunner/runner_lib_test.cc b/testing/testrunner/runner_lib_test.cc new file mode 100644 index 000000000..804826b6c --- /dev/null +++ b/testing/testrunner/runner_lib_test.cc @@ -0,0 +1,989 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "testing/testrunner/runner_lib.h" + +#include +#include +#include +#include + +#include "gtest/gtest-spi.h" +#include "absl/container/flat_hash_map.h" +#include "absl/flags/flag.h" +#include "absl/log/absl_check.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "checker/type_checker_builder.h" +#include "checker/validation_result.h" +#include "common/ast_proto.h" +#include "common/decl.h" +#include "common/type.h" +#include "common/value.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "compiler/standard_library.h" +#include "eval/public/builtin_func_registrar.h" +#include "eval/public/cel_expr_builder_factory.h" +#include "eval/public/cel_expression.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "runtime/activation.h" +#include "runtime/runtime.h" +#include "runtime/runtime_builder.h" +#include "runtime/standard_runtime_builder_factory.h" +#include "testing/testrunner/cel_expression_source.h" +#include "testing/testrunner/cel_test_context.h" +#include "testing/testrunner/coverage_index.h" +#include "cel/expr/conformance/proto3/test_all_types.pb.h" +#include "cel/expr/conformance/test/suite.pb.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" +#include "google/protobuf/text_format.h" + +ABSL_FLAG(std::string, test_cel_file_path, "", + "Path to the .cel file for testing"); + +namespace cel::test { +namespace { + +using ::cel::expr::conformance::proto3::TestAllTypes; +using ::cel::expr::conformance::test::TestCase; +using ::cel::expr::CheckedExpr; +using ::google::api::expr::runtime::CelExpressionBuilder; +using ValueProto = ::cel::expr::Value; +using ::testing::EndsWith; +using ::testing::HasSubstr; +using ::testing::Not; +using ::testing::StartsWith; + +template +T ParseTextProtoOrDie(absl::string_view text_proto) { + T result; + ABSL_CHECK(google::protobuf::TextFormat::ParseFromString(text_proto, &result)); + return result; +} + +int CountSubstrings(absl::string_view text, absl::string_view substr) { + int count = 0; + size_t pos = 0; + while ((pos = text.find(substr, pos)) != absl::string_view::npos) { + ++count; + ++pos; + } + return count; +} + +absl::StatusOr> CreateBasicCompiler() { + CEL_ASSIGN_OR_RETURN( + std::unique_ptr builder, + cel::NewCompilerBuilder(cel::internal::GetTestingDescriptorPool())); + CEL_RETURN_IF_ERROR(builder->AddLibrary(cel::StandardCompilerLibrary())); + cel::TypeCheckerBuilder& checker_builder = builder->GetCheckerBuilder(); + CEL_RETURN_IF_ERROR( + checker_builder.AddVariable(cel::MakeVariableDecl("x", cel::IntType()))); + CEL_RETURN_IF_ERROR( + checker_builder.AddVariable(cel::MakeVariableDecl("y", cel::IntType()))); + return std::move(builder)->Build(); +} + +absl::StatusOr> CreateTestRuntime() { + CEL_ASSIGN_OR_RETURN(cel::RuntimeBuilder standard_runtime_builder, + cel::CreateStandardRuntimeBuilder( + cel::internal::GetTestingDescriptorPool(), {})); + return std::move(standard_runtime_builder).Build(); +} + +absl::StatusOr> +CreateTestCelExpressionBuilder() { + auto builder = google::api::expr::runtime::CreateCelExpressionBuilder(); + CEL_RETURN_IF_ERROR(google::api::expr::runtime::RegisterBuiltinFunctions( + builder->GetRegistry())); + return builder; +} + +// Creates a static, singleton instance of the basic compiler to be shared +// across tests, avoiding repeated setup costs. +const cel::Compiler& DefaultCompiler() { + static const cel::Compiler* instance = []() { + absl::StatusOr> s = CreateBasicCompiler(); + ABSL_QCHECK_OK(s.status()); + return s->release(); + }(); + return *instance; +} + +enum class RuntimeApi { kRuntime, kBuilder }; + +// Parameterized test fixture for tests that are run against both the Runtime +// and the CelExpressionBuilder backends. +class TestRunnerParamTest : public ::testing::TestWithParam { + protected: + // Helper to create the appropriate CelTestContext based on the test + // parameter. + absl::StatusOr> CreateTestContext() { + if (GetParam() == RuntimeApi::kRuntime) { + CEL_ASSIGN_OR_RETURN(std::unique_ptr runtime, + CreateTestRuntime()); + return CelTestContext::CreateFromRuntime(std::move(runtime)); + } + CEL_ASSIGN_OR_RETURN(std::unique_ptr builder, + CreateTestCelExpressionBuilder()); + return CelTestContext::CreateFromCelExpressionBuilder(std::move(builder)); + } +}; + +TEST_P(TestRunnerParamTest, BasicTestReportsSuccess) { + ASSERT_OK_AND_ASSIGN( + cel::ValidationResult validation_result, + DefaultCompiler().Compile("{'sum': x + y, 'literal': 3}")); + CheckedExpr checked_expr; + ASSERT_THAT(cel::AstToCheckedExpr(*validation_result.GetAst(), &checked_expr), + absl_testing::IsOk()); + TestCase test_case = ParseTextProtoOrDie(R"pb( + input { + key: "x" + value { value { int64_value: 1 } } + } + input { + key: "y" + value { value { int64_value: 2 } } + } + output { + result_value { + map_value { + entries { + key { string_value: "literal" } + value { int64_value: 3 } + } + entries { + key { string_value: "sum" } + value { int64_value: 3 } + } + } + } + } + )pb"); + ASSERT_OK_AND_ASSIGN(std::unique_ptr context, + CreateTestContext()); + + context->SetExpressionSource( + CelExpressionSource::FromCheckedExpr(std::move(checked_expr))); + + TestRunner test_runner(std::move(context)); + EXPECT_NO_FATAL_FAILURE(test_runner.RunTest(test_case)); +} + +TEST_P(TestRunnerParamTest, BasicTestReportsFailure) { + ASSERT_OK_AND_ASSIGN(cel::ValidationResult validation_result, + DefaultCompiler().Compile("x + y == 3")); + CheckedExpr checked_expr; + ASSERT_THAT(cel::AstToCheckedExpr(*validation_result.GetAst(), &checked_expr), + absl_testing::IsOk()); + TestCase test_case = ParseTextProtoOrDie(R"pb( + input { + key: "x" + value { value { int64_value: 1 } } + } + input { + key: "y" + value { value { int64_value: 2 } } + } + output { result_value { bool_value: false } } + )pb"); + ASSERT_OK_AND_ASSIGN(std::unique_ptr context, + CreateTestContext()); + context->SetExpressionSource( + CelExpressionSource::FromCheckedExpr(std::move(checked_expr))); + TestRunner test_runner(std::move(context)); + EXPECT_NONFATAL_FAILURE(test_runner.RunTest(test_case), + "bool_value: true"); // expected true got false +} + +TEST_P(TestRunnerParamTest, DynamicInputAndOutputReportsSuccess) { + ASSERT_OK_AND_ASSIGN(cel::ValidationResult validation_result, + DefaultCompiler().Compile("x + y")); + CheckedExpr checked_expr; + ASSERT_THAT(cel::AstToCheckedExpr(*validation_result.GetAst(), &checked_expr), + absl_testing::IsOk()); + TestCase test_case = ParseTextProtoOrDie(R"pb( + input { + key: "x" + value { expr: "1 + 1" } + } + input { + key: "y" + value { expr: "10 - 7" } + } + output { result_expr: "7 - 2" } + )pb"); + ASSERT_OK_AND_ASSIGN(std::unique_ptr context, + CreateTestContext()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, + CreateBasicCompiler()); + context->SetCompiler(std::move(compiler)); + context->SetExpressionSource( + CelExpressionSource::FromCheckedExpr(std::move(checked_expr))); + TestRunner test_runner(std::move(context)); + EXPECT_NO_FATAL_FAILURE(test_runner.RunTest(test_case)); +} + +TEST_P(TestRunnerParamTest, DynamicInputAndOutputReportsFailure) { + ASSERT_OK_AND_ASSIGN(cel::ValidationResult validation_result, + DefaultCompiler().Compile("x + y")); + CheckedExpr checked_expr; + ASSERT_THAT(cel::AstToCheckedExpr(*validation_result.GetAst(), &checked_expr), + absl_testing::IsOk()); + TestCase test_case = ParseTextProtoOrDie(R"pb( + input { + key: "x" + value { expr: "1 + 1" } + } + input { + key: "y" + value { expr: "10 - 7" } + } + output { result_expr: "10" } + )pb"); + ASSERT_OK_AND_ASSIGN(std::unique_ptr context, + CreateTestContext()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, + CreateBasicCompiler()); + context->SetCompiler(std::move(compiler)); + context->SetExpressionSource( + CelExpressionSource::FromCheckedExpr(std::move(checked_expr))); + TestRunner test_runner(std::move(context)); + EXPECT_NONFATAL_FAILURE(test_runner.RunTest(test_case), + "int64_value: 5"); // expected 5 got 10 +} + +TEST_P(TestRunnerParamTest, RawExpressionWithCompilerReportsSuccess) { + TestCase test_case = ParseTextProtoOrDie(R"pb( + input { + key: "x" + value { value { int64_value: 10 } } + } + input { + key: "y" + value { value { int64_value: 3 } } + } + output { result_value { int64_value: 7 } } + )pb"); + ASSERT_OK_AND_ASSIGN(std::unique_ptr context, + CreateTestContext()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, + CreateBasicCompiler()); + context->SetCompiler(std::move(compiler)); + context->SetExpressionSource(CelExpressionSource::FromRawExpression("x - y")); + TestRunner test_runner(std::move(context)); + EXPECT_NO_FATAL_FAILURE(test_runner.RunTest(test_case)); +} + +TEST_P(TestRunnerParamTest, RawExpressionWithCompilerReportsFailure) { + TestCase test_case = ParseTextProtoOrDie(R"pb( + input { + key: "x" + value { value { int64_value: 10 } } + } + input { + key: "y" + value { value { int64_value: 3 } } + } + output { result_value { int64_value: 100 } } + )pb"); + ASSERT_OK_AND_ASSIGN(std::unique_ptr context, + CreateTestContext()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, + CreateBasicCompiler()); + context->SetCompiler(std::move(compiler)); + context->SetExpressionSource(CelExpressionSource::FromRawExpression("x - y")); + TestRunner test_runner(std::move(context)); + EXPECT_NONFATAL_FAILURE(test_runner.RunTest(test_case), + "int64_value: 7"); // expected 7 got 100 +} + +TEST_P(TestRunnerParamTest, CelFileWithCompilerReportsSuccess) { + const std::string cel_file_path = absl::GetFlag(FLAGS_test_cel_file_path); + ASSERT_FALSE(cel_file_path.empty()) + << "Flag --test_cel_file_path must be set"; + TestCase test_case = ParseTextProtoOrDie(R"pb( + input { + key: "x" + value { value { int64_value: 10 } } + } + input { + key: "y" + value { value { int64_value: 3 } } + } + output { result_value { int64_value: 7 } } + )pb"); + ASSERT_OK_AND_ASSIGN(std::unique_ptr context, + CreateTestContext()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, + CreateBasicCompiler()); + context->SetCompiler(std::move(compiler)); + context->SetExpressionSource(CelExpressionSource::FromCelFile(cel_file_path)); + TestRunner test_runner(std::move(context)); + EXPECT_NO_FATAL_FAILURE(test_runner.RunTest(test_case)); +} + +TEST_P(TestRunnerParamTest, CelFileWithCompilerReportsFailure) { + const std::string cel_file_path = absl::GetFlag(FLAGS_test_cel_file_path); + ASSERT_FALSE(cel_file_path.empty()) + << "Flag --test_cel_file_path must be set"; + TestCase test_case = ParseTextProtoOrDie(R"pb( + input { + key: "x" + value { value { int64_value: 10 } } + } + input { + key: "y" + value { value { int64_value: 3 } } + } + output { result_value { int64_value: 123 } } + )pb"); + ASSERT_OK_AND_ASSIGN(std::unique_ptr context, + CreateTestContext()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, + CreateBasicCompiler()); + context->SetCompiler(std::move(compiler)); + context->SetExpressionSource(CelExpressionSource::FromCelFile(cel_file_path)); + TestRunner test_runner(std::move(context)); + EXPECT_NONFATAL_FAILURE(test_runner.RunTest(test_case), + "int64_value: 7"); // expected 7 got 123 +} + +TEST_P(TestRunnerParamTest, BasicTestWithCustomBindingsSucceeds) { + ASSERT_OK_AND_ASSIGN(cel::ValidationResult validation_result, + DefaultCompiler().Compile("x + y")); + CheckedExpr checked_expr; + ASSERT_THAT(cel::AstToCheckedExpr(*validation_result.GetAst(), &checked_expr), + absl_testing::IsOk()); + + TestCase test_case = ParseTextProtoOrDie(R"pb( + input { + key: "x" + value { value { int64_value: 10 } } + } + output { result_value { int64_value: 15 } } + )pb"); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr context, + CreateTestContext()); + absl::flat_hash_map bindings; + bindings["y"] = ParseTextProtoOrDie(R"pb(int64_value: 5)pb"); + context->SetCustomBindings(std::move(bindings)); + context->SetExpressionSource( + CelExpressionSource::FromCheckedExpr(std::move(checked_expr))); + TestRunner test_runner(std::move(context)); + + EXPECT_NO_FATAL_FAILURE(test_runner.RunTest(test_case)); +} + +TEST_P(TestRunnerParamTest, BasicTestWithCustomBindingsReportsFailure) { + ASSERT_OK_AND_ASSIGN(cel::ValidationResult validation_result, + DefaultCompiler().Compile("x + y")); + CheckedExpr checked_expr; + ASSERT_THAT(cel::AstToCheckedExpr(*validation_result.GetAst(), &checked_expr), + absl_testing::IsOk()); + + TestCase test_case = ParseTextProtoOrDie(R"pb( + input { + key: "x" + value { value { int64_value: 10 } } + } + output { result_value { int64_value: 999 } } + )pb"); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr context, + CreateTestContext()); + absl::flat_hash_map bindings; + bindings["y"] = ParseTextProtoOrDie(R"pb(int64_value: 5)pb"); + context->SetCustomBindings(std::move(bindings)); + context->SetExpressionSource( + CelExpressionSource::FromCheckedExpr(std::move(checked_expr))); + TestRunner test_runner(std::move(context)); + + EXPECT_NONFATAL_FAILURE(test_runner.RunTest(test_case), + "int64_value: 15"); // expected 15 got 999. +} + +INSTANTIATE_TEST_SUITE_P(TestRunnerTests, TestRunnerParamTest, + ::testing::Values(RuntimeApi::kRuntime, + RuntimeApi::kBuilder)); + +TEST(TestRunnerStandaloneTest, DynamicInputWithoutCompilerFails) { + const std::string expected_error = + "INVALID_ARGUMENT: A compiler must be provided to compile a raw " + "expression or .cel file."; + + EXPECT_FATAL_FAILURE( + { + // Create a compiler. + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, + CreateBasicCompiler()); + + ASSERT_OK_AND_ASSIGN(cel::ValidationResult validation_result, + compiler->Compile("x + y")); + CheckedExpr checked_expr; + ASSERT_THAT( + cel::AstToCheckedExpr(*validation_result.GetAst(), &checked_expr), + absl_testing::IsOk()); + + TestCase test_case = ParseTextProtoOrDie(R"pb( + input { + key: "x" + value { expr: "1 + 1" } + } + input { + key: "y" + value { value { int64_value: 2 } } + } + output { result_value { int64_value: 3 } } + )pb"); + + // Create the expression builder. + ASSERT_OK_AND_ASSIGN(auto builder, CreateTestCelExpressionBuilder()); + + // Create the TestRunner without the compiler. + std::unique_ptr context = + CelTestContext::CreateFromCelExpressionBuilder( + /*cel_expression_builder=*/std::move(builder)); + context->SetExpressionSource( + CelExpressionSource::FromCheckedExpr(std::move(checked_expr))); + TestRunner test_runner(std::move(context)); + + test_runner.RunTest(test_case); + }, + expected_error); +} + +TEST(TestRunnerStandaloneTest, + RuntimeUsesRuntimePoolToResolveCustomProtoLiteral) { + // Create a custom CompilerBuilder. + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + cel::NewCompilerBuilder(cel::internal::GetTestingDescriptorPool())); + ASSERT_THAT(builder->AddLibrary(cel::StandardCompilerLibrary()), + absl_testing::IsOk()); + cel::TypeCheckerBuilder& checker_builder = builder->GetCheckerBuilder(); + ASSERT_THAT(checker_builder.AddVariable(cel::MakeVariableDecl( + "custom_var", cel::MessageType(TestAllTypes::descriptor()))), + absl_testing::IsOk()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, + std::move(builder)->Build()); + + // Compile the expression. + ASSERT_OK_AND_ASSIGN(cel::ValidationResult validation_result, + compiler->Compile("custom_var.single_int32 == 123")); + CheckedExpr checked_expr; + ASSERT_THAT(cel::AstToCheckedExpr(*validation_result.GetAst(), &checked_expr), + absl_testing::IsOk()); + + // Create a runtime configured with the testing descriptor pool. + ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, + CreateTestRuntime()); + + // Define the test case. The important part is the "custom_var" input, + // which forces 'ResolveValue' to run on a custom type. This succeeds because + // the testing descriptor pool (used by CreateTestRuntime()) is configured + // to contain the TestAllTypes descriptor. + TestCase test_case = ParseTextProtoOrDie(R"pb( + input { + key: "custom_var" + value { + value { + object_value { + [type.googleapis.com/cel.expr.conformance.proto3.TestAllTypes] { + single_int32: 123 + } + } + } + } + } + output { result_value { bool_value: true } } + )pb"); + + std::unique_ptr context = + CelTestContext::CreateFromRuntime(std::move(runtime)); + context->SetExpressionSource( + CelExpressionSource::FromCheckedExpr(std::move(checked_expr))); + TestRunner test_runner(std::move(context)); + EXPECT_NO_FATAL_FAILURE(test_runner.RunTest(test_case)); +} + +TEST(TestRunnerStandaloneTest, RunTestFailsWhenNoExpressionSourceIsProvided) { + const std::string expected_error = + "INVALID_ARGUMENT: No expression source provided."; + + EXPECT_FATAL_FAILURE( + { + // Create a runtime. + ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, + CreateTestRuntime()); + TestCase test_case = ParseTextProtoOrDie(R"pb( + input { + key: "x" + value { value { int64_value: 10 } } + } + input { + key: "y" + value { value { int64_value: 3 } } + } + output { result_value { int64_value: 123 } } + )pb"); + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, + CreateBasicCompiler()); + + // Create a TestRunner but without an expression source. + std::unique_ptr context = + CelTestContext::CreateFromRuntime(std::move(runtime)); + context->SetCompiler(std::move(compiler)); + TestRunner test_runner(std::move(context)); + test_runner.RunTest(test_case); + }, + expected_error); +} + +TEST(TestRunnerStandaloneTest, BasicTestWithErrorAssertion) { + // Compile the expression. + ASSERT_OK_AND_ASSIGN(cel::ValidationResult validation_result, + DefaultCompiler().Compile("x + y")); + CheckedExpr checked_expr; + ASSERT_THAT(cel::AstToCheckedExpr(*validation_result.GetAst(), &checked_expr), + absl_testing::IsOk()); + // Create a runtime. + ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, + CreateTestRuntime()); + TestCase test_case = ParseTextProtoOrDie(R"pb( + input { + key: "x" + value { value { int64_value: 1 } } + } + output { + eval_error { + errors { message: "No value with name \"y\" found in Activation" } + } + } + )pb"); + std::unique_ptr context = + CelTestContext::CreateFromRuntime(std::move(runtime)); + context->SetExpressionSource( + CelExpressionSource::FromCheckedExpr(std::move(checked_expr))); + TestRunner test_runner(std::move(context)); + EXPECT_NO_FATAL_FAILURE(test_runner.RunTest(test_case)); +} + +TEST(TestRunnerStandaloneTest, BasicTestFailsWhenExpectingErrorButGotValue) { + // Compile the expression. + ASSERT_OK_AND_ASSIGN(cel::ValidationResult validation_result, + DefaultCompiler().Compile("1 + 1")); + CheckedExpr checked_expr; + ASSERT_THAT(cel::AstToCheckedExpr(*validation_result.GetAst(), &checked_expr), + absl_testing::IsOk()); + // Create a runtime. + ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, + CreateTestRuntime()); + TestCase test_case = ParseTextProtoOrDie(R"pb( + output { + eval_error { + errors { message: "No value with name \"y\" found in Activation" } + } + } + )pb"); + std::unique_ptr context = + CelTestContext::CreateFromRuntime(std::move(runtime)); + context->SetExpressionSource( + CelExpressionSource::FromCheckedExpr(std::move(checked_expr))); + TestRunner test_runner(std::move(context)); + EXPECT_NONFATAL_FAILURE(test_runner.RunTest(test_case), + "Expected error but got value"); +} + +TEST(TestRunnerStandaloneTest, BasicTestWithActivationFactorySucceeds) { + ASSERT_OK_AND_ASSIGN(cel::ValidationResult validation_result, + DefaultCompiler().Compile("x + y")); + CheckedExpr checked_expr; + ASSERT_THAT(cel::AstToCheckedExpr(*validation_result.GetAst(), &checked_expr), + absl_testing::IsOk()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, + CreateTestRuntime()); + std::unique_ptr context = + CelTestContext::CreateFromRuntime(std::move(runtime)); + context->SetActivationFactory( + [](const TestCase& test_case, + google::protobuf::Arena* arena) -> absl::StatusOr { + cel::Activation activation; + activation.InsertOrAssignValue("x", cel::IntValue(10)); + activation.InsertOrAssignValue("y", cel::IntValue(5)); + return activation; + }); + context->SetExpressionSource( + CelExpressionSource::FromCheckedExpr(std::move(checked_expr))); + + TestCase test_case = ParseTextProtoOrDie(R"pb( + output { result_value { int64_value: 15 } } + )pb"); + TestRunner test_runner(std::move(context)); + EXPECT_NO_FATAL_FAILURE(test_runner.RunTest(test_case)); + + // Input bindings should override values set by the activation factory. + test_case = ParseTextProtoOrDie(R"pb( + input { + key: "x" + value { value { int64_value: 4 } } + } + output { result_value { int64_value: 9 } } + )pb"); + EXPECT_NO_FATAL_FAILURE(test_runner.RunTest(test_case)); +} + +TEST(TestRunnerStandaloneTest, CustomAssertFnIsUsed) { + // Compile the expression. + ASSERT_OK_AND_ASSIGN(cel::ValidationResult validation_result, + DefaultCompiler().Compile("1 + 1")); + CheckedExpr checked_expr; + ASSERT_THAT(cel::AstToCheckedExpr(*validation_result.GetAst(), &checked_expr), + absl_testing::IsOk()); + // Create a runtime. + ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, + CreateTestRuntime()); + // Set the output to a value that would fail the default assertion. + TestCase test_case = ParseTextProtoOrDie(R"pb( + output { result_value { int64_value: 102 } } + )pb"); + std::unique_ptr context = + CelTestContext::CreateFromRuntime(std::move(runtime)); + + context->SetAssertFn([&](const cel::Value& computed, + const TestCase& test_case, google::protobuf::Arena* arena) { + ASSERT_TRUE(computed.Is()); + EXPECT_EQ(computed.As().value(), 2); + }); + + context->SetExpressionSource( + CelExpressionSource::FromCheckedExpr(std::move(checked_expr))); + TestRunner test_runner(std::move(context)); + EXPECT_NO_FATAL_FAILURE(test_runner.RunTest(test_case)); +} + +TEST(CoverageTest, RuntimeCoverage) { + ASSERT_OK_AND_ASSIGN( + std::unique_ptr compiler_builder, + cel::NewCompilerBuilder(cel::internal::GetTestingDescriptorPool())); + ASSERT_THAT(compiler_builder->AddLibrary(cel::StandardCompilerLibrary()), + absl_testing::IsOk()); + ASSERT_THAT(compiler_builder->GetCheckerBuilder().AddVariable( + cel::MakeVariableDecl("x", cel::IntType())), + absl_testing::IsOk()); + ASSERT_THAT(compiler_builder->GetCheckerBuilder().AddVariable( + cel::MakeVariableDecl("y", cel::IntType())), + absl_testing::IsOk()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, + std::move(compiler_builder)->Build()); + ASSERT_OK_AND_ASSIGN(cel::ValidationResult validation_result, + compiler->Compile("x > 1 && y > 1")); + CheckedExpr checked_expr; + ASSERT_THAT(cel::AstToCheckedExpr(*validation_result.GetAst(), &checked_expr), + absl_testing::IsOk()); + TestCase test_case = ParseTextProtoOrDie(R"pb( + input { + key: "x" + value { value { int64_value: 2 } } + } + input { + key: "y" + value { value { int64_value: 0 } } + } + output { result_value { bool_value: false } } + )pb"); + + CoverageIndex coverage_index; + ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, + CreateTestRuntime()); + ASSERT_THAT(EnableCoverageInRuntime(*const_cast(runtime.get()), + coverage_index), + absl_testing::IsOk()); + + std::unique_ptr context = + CelTestContext::CreateFromRuntime(std::move(runtime)); + context->SetExpressionSource( + CelExpressionSource::FromCheckedExpr(checked_expr)); + TestRunner test_runner(std::move(context)); + coverage_index.Init(checked_expr); + EXPECT_NO_FATAL_FAILURE(test_runner.RunTest(test_case)); + + CoverageIndex::CoverageReport report = coverage_index.GetCoverageReport(); + EXPECT_GT(report.nodes, 0); + EXPECT_GT(report.covered_nodes, 0); + EXPECT_EQ(report.branches, 6); + EXPECT_EQ(report.covered_boolean_outcomes, 3); + EXPECT_THAT( + report.unencountered_branches, + ::testing::ElementsAre( + HasSubstr("\nExpression ID 7 ('x > 1 && y > 1'): Never " + "evaluated to 'true'"), + HasSubstr( + "\n\t\tExpression ID 2 ('x > 1'): Never evaluated to 'false'"), + HasSubstr( + "\n\t\tExpression ID 5 ('y > 1'): Never evaluated to 'true'"))); +} + +TEST(CoverageTest, BuilderCoverage) { + ASSERT_OK_AND_ASSIGN( + std::unique_ptr compiler_builder, + cel::NewCompilerBuilder(cel::internal::GetTestingDescriptorPool())); + ASSERT_THAT(compiler_builder->AddLibrary(cel::StandardCompilerLibrary()), + absl_testing::IsOk()); + ASSERT_THAT(compiler_builder->GetCheckerBuilder().AddVariable( + cel::MakeVariableDecl("x", cel::IntType())), + absl_testing::IsOk()); + ASSERT_THAT(compiler_builder->GetCheckerBuilder().AddVariable( + cel::MakeVariableDecl("y", cel::IntType())), + absl_testing::IsOk()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, + std::move(compiler_builder)->Build()); + ASSERT_OK_AND_ASSIGN(cel::ValidationResult validation_result, + compiler->Compile("x > 1 && y > 1")); + CheckedExpr checked_expr; + ASSERT_THAT(cel::AstToCheckedExpr(*validation_result.GetAst(), &checked_expr), + absl_testing::IsOk()); + TestCase test_case = ParseTextProtoOrDie(R"pb( + input { + key: "x" + value { value { int64_value: 0 } } + } + input { + key: "y" + value { value { int64_value: 2 } } + } + output { result_value { bool_value: false } } + )pb"); + + CoverageIndex coverage_index; + ASSERT_OK_AND_ASSIGN(std::unique_ptr builder, + CreateTestCelExpressionBuilder()); + ASSERT_THAT(EnableCoverageInCelExpressionBuilder(*builder, coverage_index), + absl_testing::IsOk()); + + std::unique_ptr context = + CelTestContext::CreateFromCelExpressionBuilder(std::move(builder)); + context->SetExpressionSource( + CelExpressionSource::FromCheckedExpr(checked_expr)); + TestRunner test_runner(std::move(context)); + coverage_index.Init(checked_expr); + EXPECT_NO_FATAL_FAILURE(test_runner.RunTest(test_case)); + + CoverageIndex::CoverageReport report = coverage_index.GetCoverageReport(); + EXPECT_GT(report.nodes, 0); + EXPECT_GT(report.covered_nodes, 0); + EXPECT_EQ(report.branches, 6); + EXPECT_EQ(report.covered_boolean_outcomes, 2); + EXPECT_THAT(report.unencountered_nodes, + ::testing::UnorderedElementsAre(HasSubstr("y > 1"))); + EXPECT_THAT( + report.unencountered_branches, + ::testing::UnorderedElementsAre(HasSubstr("Never evaluated to 'true'"), + HasSubstr("Never evaluated to 'true'"))); +} + +TEST(CoverageTest, DotGraphIsGeneratedForRuntime) { + ASSERT_OK_AND_ASSIGN( + std::unique_ptr compiler_builder, + cel::NewCompilerBuilder(cel::internal::GetTestingDescriptorPool())); + ASSERT_THAT(compiler_builder->AddLibrary(cel::StandardCompilerLibrary()), + absl_testing::IsOk()); + ASSERT_THAT(compiler_builder->GetCheckerBuilder().AddVariable( + cel::MakeVariableDecl("x", cel::IntType())), + absl_testing::IsOk()); + ASSERT_THAT(compiler_builder->GetCheckerBuilder().AddVariable( + cel::MakeVariableDecl("y", cel::IntType())), + absl_testing::IsOk()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, + std::move(compiler_builder)->Build()); + ASSERT_OK_AND_ASSIGN(cel::ValidationResult validation_result, + compiler->Compile("x > 1 && y > 1")); + CheckedExpr checked_expr; + ASSERT_THAT(cel::AstToCheckedExpr(*validation_result.GetAst(), &checked_expr), + absl_testing::IsOk()); + TestCase test_case = ParseTextProtoOrDie(R"pb( + input { + key: "x" + value { value { int64_value: 2 } } + } + input { + key: "y" + value { value { int64_value: 0 } } + } + output { result_value { bool_value: false } } + )pb"); + + CoverageIndex coverage_index; + ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, + CreateTestRuntime()); + ASSERT_THAT(EnableCoverageInRuntime(*const_cast(runtime.get()), + coverage_index), + absl_testing::IsOk()); + + std::unique_ptr context = + CelTestContext::CreateFromRuntime(std::move(runtime)); + context->SetExpressionSource( + CelExpressionSource::FromCheckedExpr(checked_expr)); + TestRunner test_runner(std::move(context)); + coverage_index.Init(checked_expr); + EXPECT_NO_FATAL_FAILURE(test_runner.RunTest(test_case)); + + CoverageIndex::CoverageReport report = coverage_index.GetCoverageReport(); + + absl::string_view dot_graph = report.dot_graph; + + // Check for graph structure + EXPECT_THAT(dot_graph, StartsWith(kDigraphHeader)); + EXPECT_THAT(dot_graph, EndsWith("}\n")); + EXPECT_THAT(dot_graph, HasSubstr("->")); + EXPECT_THAT(dot_graph, HasSubstr("shape=record")); + + // Check for the existence of complete labels for key nodes, using the actual + // expression IDs from the build log. + EXPECT_THAT(dot_graph, HasSubstr("label=\"{<1> exprID: 7 | <2> Call Node} | " + "<3> x \\> 1 && y \\> 1\"")); + EXPECT_THAT( + dot_graph, + HasSubstr("label=\"{<1> exprID: 2 | <2> Call Node} | <3> x \\> 1\"")); + EXPECT_THAT( + dot_graph, + HasSubstr("label=\"{<1> exprID: 5 | <2> Call Node} | <3> y \\> 1\"")); + + // Check for coverage styles + EXPECT_THAT(dot_graph, HasSubstr(kCompletelyCoveredNodeStyle)); + EXPECT_THAT(dot_graph, HasSubstr(kPartiallyCoveredNodeStyle)); + EXPECT_THAT(dot_graph, Not(HasSubstr(kUncoveredNodeStyle))); +} + +TEST(CoverageTest, DotGraphIsGeneratedForComprehension) { + ASSERT_OK_AND_ASSIGN( + std::unique_ptr compiler_builder, + cel::NewCompilerBuilder(cel::internal::GetTestingDescriptorPool())); + + ASSERT_THAT(compiler_builder->AddLibrary(cel::StandardCompilerLibrary()), + absl_testing::IsOk()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, + std::move(compiler_builder)->Build()); + + ASSERT_OK_AND_ASSIGN(cel::ValidationResult validation_result, + compiler->Compile("[1, 2, 3].all(i, i > 0)")); + CheckedExpr checked_expr; + ASSERT_THAT(cel::AstToCheckedExpr(*validation_result.GetAst(), &checked_expr), + absl_testing::IsOk()); + // Test case expects 'true' since all elements are > 0. + TestCase test_case = ParseTextProtoOrDie(R"pb( + output { result_value { bool_value: true } } + )pb"); + + CoverageIndex coverage_index; + ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, + CreateTestRuntime()); + ASSERT_THAT(EnableCoverageInRuntime(*const_cast(runtime.get()), + coverage_index), + absl_testing::IsOk()); + + std::unique_ptr context = + CelTestContext::CreateFromRuntime(std::move(runtime)); + context->SetExpressionSource( + CelExpressionSource::FromCheckedExpr(checked_expr)); + TestRunner test_runner(std::move(context)); + coverage_index.Init(checked_expr); + EXPECT_NO_FATAL_FAILURE(test_runner.RunTest(test_case)); + + CoverageIndex::CoverageReport report = coverage_index.GetCoverageReport(); + absl::string_view dot_graph = report.dot_graph; + + // Assert that the specific kinds for comprehension nodes are present in the + // generated graph. + EXPECT_THAT(dot_graph, HasSubstr("IterRange")); + EXPECT_THAT(dot_graph, HasSubstr("AccuInit")); + EXPECT_THAT(dot_graph, HasSubstr("LoopCondition")); + EXPECT_THAT(dot_graph, HasSubstr("LoopStep")); + EXPECT_THAT(dot_graph, HasSubstr("Result")); + + // The expression is fully evaluated, so no nodes should be uncovered. + EXPECT_THAT(dot_graph, Not(HasSubstr(kUncoveredNodeStyle))); +} + +TEST(CoverageTest, PartiallyCoveredBooleanNodeIsStyledCorrectly) { + // This test is designed to kill a mutant that incorrectly styles partially + // covered boolean nodes as completely covered. It uses a short-circuiting + // expression to ensure that some boolean nodes are only evaluated one way + // (e.g., only to 'true'), making them partially covered. + ASSERT_OK_AND_ASSIGN( + std::unique_ptr compiler_builder, + cel::NewCompilerBuilder(cel::internal::GetTestingDescriptorPool())); + ASSERT_THAT(compiler_builder->AddLibrary(cel::StandardCompilerLibrary()), + absl_testing::IsOk()); + ASSERT_THAT(compiler_builder->GetCheckerBuilder().AddVariable( + cel::MakeVariableDecl("x", cel::IntType())), + absl_testing::IsOk()); + ASSERT_THAT(compiler_builder->GetCheckerBuilder().AddVariable( + cel::MakeVariableDecl("y", cel::IntType())), + absl_testing::IsOk()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, + std::move(compiler_builder)->Build()); + ASSERT_OK_AND_ASSIGN( + cel::ValidationResult validation_result, + compiler->Compile("{'sum': x + y, 'literal': 3}.sum == 3 || x == y")); + CheckedExpr checked_expr; + ASSERT_THAT(cel::AstToCheckedExpr(*validation_result.GetAst(), &checked_expr), + absl_testing::IsOk()); + TestCase test_case = ParseTextProtoOrDie(R"pb( + input { + key: "x" + value { value { int64_value: 1 } } + } + input { + key: "y" + value { value { int64_value: 2 } } + } + output { result_value { bool_value: true } } + )pb"); + + CoverageIndex coverage_index; + ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, + CreateTestRuntime()); + ASSERT_THAT(EnableCoverageInRuntime(*const_cast(runtime.get()), + coverage_index), + absl_testing::IsOk()); + std::unique_ptr context = + CelTestContext::CreateFromRuntime(std::move(runtime)); + context->SetExpressionSource( + CelExpressionSource::FromCheckedExpr(checked_expr)); + TestRunner test_runner(std::move(context)); + coverage_index.Init(checked_expr); + EXPECT_NO_FATAL_FAILURE(test_runner.RunTest(test_case)); + + CoverageIndex::CoverageReport report = coverage_index.GetCoverageReport(); + + // With x=1, y=2, the left side of '||' is true, so the right side ('x == y') + // is short-circuited and never evaluated. + // - The '||' node and the '==' node are partially covered (only 'true'). + // - The 'x == y' branch (and its children) are uncovered. + // - All other evaluated nodes are fully covered. + EXPECT_EQ(CountSubstrings(report.dot_graph, kPartiallyCoveredNodeStyle), 2); + EXPECT_EQ(CountSubstrings(report.dot_graph, kUncoveredNodeStyle), 3); + EXPECT_EQ(CountSubstrings(report.dot_graph, kCompletelyCoveredNodeStyle), 9); +} +} // namespace +} // namespace cel::test diff --git a/testing/testrunner/user_tests/BUILD b/testing/testrunner/user_tests/BUILD new file mode 100644 index 000000000..53cd8f716 --- /dev/null +++ b/testing/testrunner/user_tests/BUILD @@ -0,0 +1,160 @@ +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("//testing/testrunner:cel_cc_test.bzl", "cel_cc_test") + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "simple_user_test", + testonly = True, + srcs = ["simple.cc"], + deps = [ + "//checker:type_checker_builder", + "//checker:validation_result", + "//common:ast_proto", + "//common:decl", + "//common:type", + "//compiler", + "//compiler:compiler_factory", + "//compiler:standard_library", + "//internal:status_macros", + "//internal:testing_descriptor_pool", + "//runtime", + "//runtime:runtime_builder", + "//runtime:standard_runtime_builder_factory", + "//testing/testrunner:cel_expression_source", + "//testing/testrunner:cel_test_context", + "//testing/testrunner:cel_test_factories", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", + "@com_google_protobuf//:protobuf", + ], + alwayslink = True, +) + +cc_library( + name = "raw_expression_user_test", + testonly = True, + srcs = ["raw_expression_test.cc"], + deps = [ + "//checker:type_checker_builder", + "//common:decl", + "//common:type", + "//compiler", + "//compiler:compiler_factory", + "//compiler:standard_library", + "//internal:status_macros", + "//internal:testing_descriptor_pool", + "//runtime", + "//runtime:runtime_builder", + "//runtime:standard_runtime_builder_factory", + "//testing/testrunner:cel_expression_source", + "//testing/testrunner:cel_test_context", + "//testing/testrunner:cel_test_factories", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_cel_spec//proto/cel/expr/conformance/test:suite_cc_proto", + "@com_google_protobuf//:protobuf", + ], + alwayslink = True, +) + +cc_library( + name = "raw_expr_and_cel_file_test", + testonly = True, + srcs = ["raw_expr_and_cel_file_test.cc"], + deps = [ + "//checker:type_checker_builder", + "//common:decl", + "//common:type", + "//compiler", + "//compiler:compiler_factory", + "//compiler:standard_library", + "//internal:status_macros", + "//internal:testing_descriptor_pool", + "//runtime", + "//runtime:runtime_builder", + "//runtime:standard_runtime_builder_factory", + "//testing/testrunner:cel_test_context", + "//testing/testrunner:cel_test_factories", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_cel_spec//proto/cel/expr/conformance/test:suite_cc_proto", + "@com_google_protobuf//:protobuf", + ], + alwayslink = True, +) + +cc_library( + name = "checked_expr_user_test", + testonly = True, + srcs = ["checked_expr_test.cc"], + deps = [ + "//internal:status_macros", + "//internal:testing_descriptor_pool", + "//runtime", + "//runtime:runtime_builder", + "//runtime:standard_runtime_builder_factory", + "//testing/testrunner:cel_test_context", + "//testing/testrunner:cel_test_factories", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_cel_spec//proto/cel/expr/conformance/test:suite_cc_proto", + "@com_google_protobuf//:protobuf", + ], + alwayslink = True, +) + +cel_cc_test( + name = "simple_test", + enable_coverage = True, + filegroup = "//testing/testrunner/resources", + test_data_path = "//testing/testrunner/resources", + test_suite = "simple_tests.textproto", + deps = [ + ":simple_user_test", + ], +) + +cel_cc_test( + name = "simple_test_with_custom_test_suite", + enable_coverage = True, + filegroup = "//testing/testrunner/resources", + test_data_path = "//testing/testrunner/resources", + deps = [ + ":simple_user_test", + ], +) + +cel_cc_test( + name = "raw_expression_test_with_custom_test_suite", + enable_coverage = True, + deps = [ + ":raw_expression_user_test", + ], +) + +cel_cc_test( + name = "subtraction_raw_expr_test", + cel_expr = "x - y", + is_raw_expr = True, + deps = [ + ":raw_expr_and_cel_file_test", + ], +) + +cel_cc_test( + name = "subtraction_cel_file_test", + cel_expr = "test.cel", + test_data_path = "//testing/testrunner/resources", + deps = [ + ":raw_expr_and_cel_file_test", + ], +) diff --git a/testing/testrunner/user_tests/checked_expr_test.cc b/testing/testrunner/user_tests/checked_expr_test.cc new file mode 100644 index 000000000..44e4b46ae --- /dev/null +++ b/testing/testrunner/user_tests/checked_expr_test.cc @@ -0,0 +1,82 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "internal/status_macros.h" +#include "internal/testing_descriptor_pool.h" +#include "runtime/runtime.h" +#include "runtime/runtime_builder.h" +#include "runtime/standard_runtime_builder_factory.h" +#include "testing/testrunner/cel_test_context.h" +#include "testing/testrunner/cel_test_factories.h" +#include "cel/expr/conformance/test/suite.pb.h" +#include "google/protobuf/text_format.h" + +namespace cel::testing { + +using ::cel::test::CelTestContext; + +template +T ParseTextProtoOrDie(absl::string_view text_proto) { + T result; + ABSL_CHECK(google::protobuf::TextFormat::ParseFromString(text_proto, &result)); + return result; +} + +CEL_REGISTER_TEST_SUITE_FACTORY([]() { + return ParseTextProtoOrDie(R"pb( + name: "cli_expression_tests" + description: "Tests designed for expressions passed via CLI flags." + sections: { + name: "subtraction_test" + description: "Tests subtraction of two variables." + tests: { + name: "variable_subtraction" + description: "Test that subtraction of two variables works." + input: { + key: "x" + value { value { int64_value: 10 } } + } + input { + key: "y" + value { value { int64_value: 5 } } + } + output { result_value { int64_value: 5 } } + } + } + )pb"); +}); + +CEL_REGISTER_TEST_CONTEXT_FACTORY( + []() -> absl::StatusOr> { + ABSL_LOG(INFO) << "Creating runtime-only test context for CheckedExpr"; + + // Create a runtime. + CEL_ASSIGN_OR_RETURN(cel::RuntimeBuilder runtime_builder, + cel::CreateStandardRuntimeBuilder( + cel::internal::GetTestingDescriptorPool(), {})); + CEL_ASSIGN_OR_RETURN(std::unique_ptr runtime, + std::move(runtime_builder).Build()); + + // Create the context with the runtime, but no compiler. + // The test runner will inject the CheckedExpr source later. + return CelTestContext::CreateFromRuntime(std::move(runtime)); + }); +} // namespace cel::testing diff --git a/testing/testrunner/user_tests/raw_expr_and_cel_file_test.cc b/testing/testrunner/user_tests/raw_expr_and_cel_file_test.cc new file mode 100644 index 000000000..b5fd59396 --- /dev/null +++ b/testing/testrunner/user_tests/raw_expr_and_cel_file_test.cc @@ -0,0 +1,103 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "checker/type_checker_builder.h" +#include "common/decl.h" +#include "common/type.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "compiler/standard_library.h" +#include "internal/status_macros.h" +#include "internal/testing_descriptor_pool.h" +#include "runtime/runtime.h" +#include "runtime/runtime_builder.h" +#include "runtime/standard_runtime_builder_factory.h" +#include "testing/testrunner/cel_test_context.h" +#include "testing/testrunner/cel_test_factories.h" +#include "cel/expr/conformance/test/suite.pb.h" +#include "google/protobuf/text_format.h" + +namespace cel::testing { + +using ::cel::test::CelTestContext; + +template +T ParseTextProtoOrDie(absl::string_view text_proto) { + T result; + ABSL_CHECK(google::protobuf::TextFormat::ParseFromString(text_proto, &result)); + return result; +} + +CEL_REGISTER_TEST_SUITE_FACTORY([]() { + return ParseTextProtoOrDie(R"pb( + name: "cli_expression_tests" + description: "Tests designed for expressions passed via CLI flags." + sections: { + name: "subtraction_test" + description: "Tests subtraction of two variables." + tests: { + name: "variable_subtraction" + description: "Test that subtraction of two variables works." + input: { + key: "x" + value { value { int64_value: 10 } } + } + input { + key: "y" + value { value { int64_value: 5 } } + } + output { result_value { int64_value: 5 } } + } + } + )pb"); +}); + +CEL_REGISTER_TEST_CONTEXT_FACTORY( + []() -> absl::StatusOr> { + ABSL_LOG(INFO) << "Creating test context for raw_expr and cel_file"; + + // Create a compiler. + CEL_ASSIGN_OR_RETURN( + std::unique_ptr builder, + cel::NewCompilerBuilder(cel::internal::GetTestingDescriptorPool())); + CEL_RETURN_IF_ERROR(builder->AddLibrary(cel::StandardCompilerLibrary())); + cel::TypeCheckerBuilder& checker_builder = builder->GetCheckerBuilder(); + CEL_RETURN_IF_ERROR(checker_builder.AddVariable( + cel::MakeVariableDecl("x", cel::IntType()))); + CEL_RETURN_IF_ERROR(checker_builder.AddVariable( + cel::MakeVariableDecl("y", cel::IntType()))); + + CEL_ASSIGN_OR_RETURN(std::unique_ptr compiler, + std::move(builder)->Build()); + + // Create a runtime. + CEL_ASSIGN_OR_RETURN(cel::RuntimeBuilder runtime_builder, + cel::CreateStandardRuntimeBuilder( + cel::internal::GetTestingDescriptorPool(), {})); + CEL_ASSIGN_OR_RETURN(std::unique_ptr runtime, + std::move(runtime_builder).Build()); + + std::unique_ptr context = + CelTestContext::CreateFromRuntime(std::move(runtime)); + context->SetCompiler(std::move(compiler)); + return context; + }); +} // namespace cel::testing diff --git a/testing/testrunner/user_tests/raw_expression_test.cc b/testing/testrunner/user_tests/raw_expression_test.cc new file mode 100644 index 000000000..e52cc39dc --- /dev/null +++ b/testing/testrunner/user_tests/raw_expression_test.cc @@ -0,0 +1,104 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "absl/log/absl_check.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "checker/type_checker_builder.h" +#include "common/decl.h" +#include "common/type.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "compiler/standard_library.h" +#include "internal/status_macros.h" +#include "internal/testing_descriptor_pool.h" +#include "runtime/runtime.h" +#include "runtime/runtime_builder.h" +#include "runtime/standard_runtime_builder_factory.h" +#include "testing/testrunner/cel_expression_source.h" +#include "testing/testrunner/cel_test_context.h" +#include "testing/testrunner/cel_test_factories.h" +#include "cel/expr/conformance/test/suite.pb.h" +#include "google/protobuf/text_format.h" + +namespace cel::testing { + +using ::cel::test::CelTestContext; + +template +T ParseTextProtoOrDie(absl::string_view text_proto) { + T result; + ABSL_CHECK(google::protobuf::TextFormat::ParseFromString(text_proto, &result)); + return result; +} + +CEL_REGISTER_TEST_SUITE_FACTORY([]() { + return ParseTextProtoOrDie(R"pb( + name: "raw_expression_tests" + description: "Tests for validating support for raw CEL expressions in test inputs and outputs." + sections: { + name: "raw_expression_io" + description: "A section for tests with raw CEL expressions in inputs and outputs." + tests: { + name: "eval_input_and_output" + description: "Test that a raw CEL expression can be provided as both an input and an expected output." + input: { + key: "x" + value { expr: "1 + 1" } + } + input: { + key: "y" + value { value { int64_value: 8 } } + } + output { result_expr: "5 * 2" } + } + } + )pb"); +}); + +CEL_REGISTER_TEST_CONTEXT_FACTORY( + []() -> absl::StatusOr> { + // Create a compiler. + CEL_ASSIGN_OR_RETURN( + std::unique_ptr builder, + cel::NewCompilerBuilder(cel::internal::GetTestingDescriptorPool())); + CEL_RETURN_IF_ERROR(builder->AddLibrary(cel::StandardCompilerLibrary())); + cel::TypeCheckerBuilder& checker_builder = builder->GetCheckerBuilder(); + CEL_RETURN_IF_ERROR(checker_builder.AddVariable( + cel::MakeVariableDecl("x", cel::IntType()))); + CEL_RETURN_IF_ERROR(checker_builder.AddVariable( + cel::MakeVariableDecl("y", cel::IntType()))); + CEL_ASSIGN_OR_RETURN(std::unique_ptr compiler, + builder->Build()); + + // Create a runtime. + CEL_ASSIGN_OR_RETURN(cel::RuntimeBuilder runtime_builder, + cel::CreateStandardRuntimeBuilder( + cel::internal::GetTestingDescriptorPool(), {})); + CEL_ASSIGN_OR_RETURN(std::unique_ptr runtime, + std::move(runtime_builder).Build()); + + std::unique_ptr context = + CelTestContext::CreateFromRuntime(std::move(runtime)); + context->SetCompiler(std::move(compiler)); + context->SetExpressionSource( + test::CelExpressionSource::FromRawExpression("x + y")); + + return context; + }); +} // namespace cel::testing diff --git a/testing/testrunner/user_tests/simple.cc b/testing/testrunner/user_tests/simple.cc new file mode 100644 index 000000000..ba0897d94 --- /dev/null +++ b/testing/testrunner/user_tests/simple.cc @@ -0,0 +1,115 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "cel/expr/checked.pb.h" +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "checker/type_checker_builder.h" +#include "checker/validation_result.h" +#include "common/ast_proto.h" +#include "common/decl.h" +#include "common/type.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "compiler/standard_library.h" +#include "internal/status_macros.h" +#include "internal/testing_descriptor_pool.h" +#include "runtime/runtime.h" +#include "runtime/runtime_builder.h" +#include "runtime/standard_runtime_builder_factory.h" +#include "testing/testrunner/cel_expression_source.h" +#include "testing/testrunner/cel_test_context.h" +#include "testing/testrunner/cel_test_factories.h" +#include "google/protobuf/text_format.h" + +namespace cel::testing { + +using ::cel::test::CelTestContext; +using ::cel::expr::CheckedExpr; + +template +T ParseTextProtoOrDie(absl::string_view text_proto) { + T result; + ABSL_CHECK(google::protobuf::TextFormat::ParseFromString(text_proto, &result)); + return result; +} + +CEL_REGISTER_TEST_SUITE_FACTORY([]() { + return ParseTextProtoOrDie(R"pb( + name: "custom_test_suite_tests" + description: "Simple tests to validate the test runner." + sections: { + name: "simple_map_operations" + description: "Tests for map operations." + tests: { + name: "literal_and_sum" + description: "Test that a map can be created and values can be accessed." + input: { + key: "x" + value { value { int64_value: 1 } } + } + input { + key: "y" + value { value { int64_value: 2 } } + } + output { result_value { bool_value: true } } + } + } + )pb"); +}); + +CEL_REGISTER_TEST_CONTEXT_FACTORY( + []() -> absl::StatusOr> { + ABSL_LOG(INFO) << "Creating test context"; + + // Create a compiler. + CEL_ASSIGN_OR_RETURN( + std::unique_ptr builder, + cel::NewCompilerBuilder(cel::internal::GetTestingDescriptorPool())); + CEL_RETURN_IF_ERROR(builder->AddLibrary(cel::StandardCompilerLibrary())); + cel::TypeCheckerBuilder& checker_builder = builder->GetCheckerBuilder(); + CEL_RETURN_IF_ERROR(checker_builder.AddVariable( + cel::MakeVariableDecl("x", cel::IntType()))); + CEL_RETURN_IF_ERROR(checker_builder.AddVariable( + cel::MakeVariableDecl("y", cel::IntType()))); + CEL_ASSIGN_OR_RETURN(std::unique_ptr compiler, + builder->Build()); + + // Compile the expression. + CEL_ASSIGN_OR_RETURN( + cel::ValidationResult validation_result, + compiler->Compile("{'sum': x + y, 'literal': 3}.sum == 3 || x == y")); + CheckedExpr checked_expr; + CEL_RETURN_IF_ERROR( + cel::AstToCheckedExpr(*validation_result.GetAst(), &checked_expr)); + + // Create a runtime. + CEL_ASSIGN_OR_RETURN(cel::RuntimeBuilder runtime_builder, + cel::CreateStandardRuntimeBuilder( + cel::internal::GetTestingDescriptorPool(), {})); + CEL_ASSIGN_OR_RETURN(std::unique_ptr runtime, + std::move(runtime_builder).Build()); + + std::unique_ptr context = + CelTestContext::CreateFromRuntime(std::move(runtime)); + context->SetExpressionSource( + test::CelExpressionSource::FromCheckedExpr(std::move(checked_expr))); + return context; + }); +} // namespace cel::testing diff --git a/testutil/BUILD b/testutil/BUILD index 450474c48..782c95ca6 100644 --- a/testutil/BUILD +++ b/testutil/BUILD @@ -1,120 +1,114 @@ -# Description -# Test utilities for cpp CEL. +# Copyright 2021 Google LLC # -# Uses the namespace google::api::expr::testutil. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@com_google_protobuf//bazel:proto_library.bzl", "proto_library") +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") package(default_visibility = ["//visibility:public"]) -licenses(["notice"]) # Apache 2.0 +licenses(["notice"]) cc_library( name = "expr_printer", srcs = ["expr_printer.cc"], hdrs = ["expr_printer.h"], deps = [ - "//common:escaping", + "//common:ast", + "//common:ast_proto", + "//common:constant", + "//common:expr", + "//internal:strings", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/log:absl_log", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], ) -cc_library( - name = "test_data_util", - srcs = [ - "test_data_util.cc", - ], - hdrs = [ - "test_data_util.h", - ], +cc_test( + name = "expr_printer_test", + srcs = ["expr_printer_test.cc"], deps = [ - "//common:type", - "//common:value", - "//internal:cel_printer", - "//internal:proto_util", - "//internal:types", - "//protoutil:converters", - "//protoutil:type_registry", - "//v1beta1:converters", + ":expr_printer", + "//common:expr", + "//internal:testing", + "//parser", + "//parser:options", + "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/strings", - "@com_google_absl//absl/time", - "@com_google_absl//absl/types:span", - "@com_google_cel_spec//testdata:test_value_cc_proto", - "@com_google_googleapis//google/api/expr/v1beta1:eval_cc_proto", - "@com_google_googleapis//google/api/expr/v1beta1:value_cc_proto", - "@com_google_googleapis//google/rpc:code_cc_proto", - "@com_google_protobuf//:protobuf", ], ) cc_library( - name = "test_data_io", - srcs = [ - "test_data_io.cc", - ], + name = "util", + testonly = True, hdrs = [ - "test_data_io.h", + "util.h", ], + deps = ["//internal:proto_matchers"], +) + +cc_library( + name = "test_macros", + testonly = True, + srcs = ["test_macros.cc"], + hdrs = ["test_macros.h"], deps = [ - "//internal:status_util", - "@com_google_absl//absl/flags:flag", - "@com_google_absl//absl/memory", + "//common:expr", + "//internal:status_macros", + "//parser:macro", + "//parser:macro_expr_factory", + "//parser:macro_registry", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", - "@com_google_cel_spec//testdata:test_data_cc_proto", - "@com_google_cel_spec//testdata:test_value_cc_proto", - "@com_google_googleapis//google/rpc:code_cc_proto", - "@com_google_googleapis//google/rpc:status_cc_proto", - "@com_google_googletest//:gtest", - "@com_google_protobuf//:protobuf", - "@com_googlesource_code_re2//:re2", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", ], ) -# Usage: -# blaze build testutil:test_data_gen && -# blaze-bin/testutil/test_data_gen >> -# third_party/cel/spec/testdata/unique_values.textpb -cc_binary( - name = "test_data_gen", - srcs = [ - "test_data_gen.cc", - ], +cc_library( + name = "baseline_tests", + testonly = True, + srcs = ["baseline_tests.cc"], + hdrs = ["baseline_tests.h"], deps = [ - ":test_data_io", - ":test_data_util", - "//common:type", - "//internal:proto_util", - "@com_google_absl//absl/flags:parse", + ":expr_printer", + "//common:ast", + "//common:expr", + "//extensions/protobuf:ast_converters", "@com_google_absl//absl/strings", - "@com_google_cel_spec//testdata:test_data_cc_proto", - "@com_google_googleapis//google/rpc:code_cc_proto", - "@com_google_googleapis//google/type:money_cc_proto", - "@com_google_protobuf//:protobuf", + "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", ], ) cc_test( - name = "test_data_test", - srcs = ["test_data_test.cc"], - data = [ - "@com_google_cel_spec//testdata", - ], + name = "baseline_tests_test", + srcs = ["baseline_tests_test.cc"], deps = [ - ":test_data_util", - "//testutil:test_data_io", - "@com_google_cel_spec//testdata:test_data_cc_proto", - "@com_google_googleapis//google/api/expr/v1beta1:value_cc_proto", - "@com_google_googletest//:gtest_main", + ":baseline_tests", + "//common:ast", + "//internal:testing", + "@com_google_protobuf//:protobuf", ], ) -cc_library( - name = "util", - hdrs = [ - "util.h", - ], - deps = [ - "@com_google_googletest//:gtest", - "@com_google_protobuf//:protobuf", - ], +proto_library( + name = "test_json_names_proto", + srcs = ["test_json_names.proto"], ) diff --git a/testutil/baseline_tests.cc b/testutil/baseline_tests.cc new file mode 100644 index 000000000..8ce43e63d --- /dev/null +++ b/testutil/baseline_tests.cc @@ -0,0 +1,83 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "testutil/baseline_tests.h" + +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "common/ast.h" +#include "common/expr.h" +#include "extensions/protobuf/ast_converters.h" +#include "testutil/expr_printer.h" + +namespace cel::test { +namespace { + +std::string FormatReference(const cel::Reference& r) { + if (r.overload_id().empty()) { + return r.name(); + } + return absl::StrJoin(r.overload_id(), "|"); +} + +class TypeAdorner : public ExpressionAdorner { + public: + explicit TypeAdorner(const Ast& ast) : ast_(ast) {} + + std::string Adorn(const Expr& e) const override { + std::string s; + + auto t = ast_.type_map().find(e.id()); + if (t != ast_.type_map().end()) { + absl::StrAppend(&s, "~", FormatTypeSpec(t->second)); + } + if (const auto r = ast_.reference_map().find(e.id()); + r != ast_.reference_map().end()) { + absl::StrAppend(&s, "^", FormatReference(r->second)); + } + return s; + } + + std::string AdornStructField(const StructExprField& e) const override { + return ""; + } + + std::string AdornMapEntry(const MapExprEntry& e) const override { return ""; } + + private: + const Ast& ast_; +}; + +} // namespace + +std::string FormatBaselineAst(const Ast& ast) { + TypeAdorner adorner(ast); + ExprPrinter printer(adorner); + return printer.Print(ast.root_expr()); +} + +std::string FormatBaselineCheckedExpr( + const cel::expr::CheckedExpr& checked) { + auto ast = cel::extensions::CreateAstFromCheckedExpr(checked); + if (!ast.ok()) { + return ast.status().ToString(); + } + return FormatBaselineAst(**ast); +} + +} // namespace cel::test diff --git a/testutil/baseline_tests.h b/testutil/baseline_tests.h new file mode 100644 index 000000000..35d85de4c --- /dev/null +++ b/testutil/baseline_tests.h @@ -0,0 +1,60 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Utilities for baseline tests. Baseline files are textual reports in a common +// format that can be used to compare the output of each of the libraries. +// +// The protobuf ast format is a bit tricky to compare directly (e.g. +// renumberings do not change the meaning of the expression), so we use a custom +// format that compares well with simple string comparisons. +// +// Example: +// ``` +// Source: Foo(a.b) +// declare a { +// variable map(string,dyn) +// } +// declare Foo { +// function foo_string(string) -> string +// function foo_int(int) -> int +// } +// =========> +// Foo( +// a~map(string,dyn)^a.b~dyn +// )~dyn^foo_string|foo_int +// +// +// ``` +#ifndef THIRD_PARTY_CEL_CPP_TESTUTIL_BASELINE_TESTS_H_ +#define THIRD_PARTY_CEL_CPP_TESTUTIL_BASELINE_TESTS_H_ + +#include + +#include "cel/expr/checked.pb.h" +#include "common/ast.h" + +namespace cel::test { + +// Returns a string representation of the AST that matches the baseline format +// used in tests across the CEL libraries. +std::string FormatBaselineAst(const Ast& ast); + +// Returns a string representation of the protobuf AST that matches the baseline +// format used in tests across the CEL libraries. +std::string FormatBaselineCheckedExpr( + const cel::expr::CheckedExpr& checked); + +} // namespace cel::test + +#endif // THIRD_PARTY_CEL_CPP_TESTUTIL_BASELINE_TEST_H_ diff --git a/testutil/baseline_tests_test.cc b/testutil/baseline_tests_test.cc new file mode 100644 index 000000000..f4e89706c --- /dev/null +++ b/testutil/baseline_tests_test.cc @@ -0,0 +1,206 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or astied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "testutil/baseline_tests.h" + +#include +#include + +#include "common/ast.h" +#include "internal/testing.h" +#include "google/protobuf/text_format.h" + +namespace cel::test { +namespace { + +using ::cel::expr::CheckedExpr; + +TEST(FormatBaselineAst, Basic) { + Ast ast; + ast.mutable_root_expr().mutable_ident_expr().set_name("foo"); + ast.mutable_root_expr().set_id(1); + ast.mutable_type_map()[1] = TypeSpec(PrimitiveType::kInt64); + ast.mutable_reference_map()[1].set_name("foo"); + + EXPECT_EQ(FormatBaselineAst(ast), "foo~int^foo"); +} + +TEST(FormatBaselineAst, NoType) { + Ast ast; + ast.mutable_root_expr().mutable_ident_expr().set_name("foo"); + ast.mutable_root_expr().set_id(1); + ast.mutable_reference_map()[1].set_name("foo"); + + EXPECT_EQ(FormatBaselineAst(ast), "foo^foo"); +} + +TEST(FormatBaselineAst, NoReference) { + Ast ast; + ast.mutable_root_expr().mutable_ident_expr().set_name("foo"); + ast.mutable_root_expr().set_id(1); + ast.mutable_type_map()[1] = TypeSpec(PrimitiveType::kInt64); + + EXPECT_EQ(FormatBaselineAst(ast), "foo~int"); +} + +TEST(FormatBaselineAst, MutlipleReferences) { + Ast ast; + ast.mutable_root_expr().mutable_call_expr().set_function("_+_"); + ast.mutable_root_expr().set_id(1); + ast.mutable_type_map()[1] = TypeSpec(DynTypeSpec()); + ast.mutable_reference_map()[1].mutable_overload_id().push_back( + "add_timestamp_duration"); + ast.mutable_reference_map()[1].mutable_overload_id().push_back( + "add_duration_duration"); + { + auto& arg1 = ast.mutable_root_expr().mutable_call_expr().add_args(); + arg1.mutable_ident_expr().set_name("a"); + arg1.set_id(2); + ast.mutable_type_map()[2] = TypeSpec(DynTypeSpec()); + ast.mutable_reference_map()[2].set_name("a"); + } + { + auto& arg2 = ast.mutable_root_expr().mutable_call_expr().add_args(); + arg2.mutable_ident_expr().set_name("b"); + arg2.set_id(3); + ast.mutable_type_map()[3] = TypeSpec(WellKnownTypeSpec::kDuration); + ast.mutable_reference_map()[3].set_name("b"); + } + + EXPECT_EQ(FormatBaselineAst(ast), + "_+_(\n" + " a~dyn^a,\n" + " b~google.protobuf.Duration^b\n" + ")~dyn^add_timestamp_duration|add_duration_duration"); +} + +TEST(FormatBaselineCheckedExpr, MutlipleReferences) { + CheckedExpr checked; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb( + expr { + id: 1 + call_expr { + function: "_+_" + args { + id: 2 + ident_expr { name: "a" } + } + args { + id: 3 + ident_expr { name: "b" } + } + } + } + type_map { + key: 1 + value { dyn {} } + } + type_map { + key: 2 + value { dyn {} } + } + type_map { + key: 3 + value { well_known: DURATION } + } + reference_map { + key: 1 + value { + overload_id: "add_timestamp_duration" + overload_id: "add_duration_duration" + } + } + reference_map { + key: 2 + value { name: "a" } + } + reference_map { + key: 3 + value { name: "b" } + } + )pb", + &checked)); + + EXPECT_EQ(FormatBaselineCheckedExpr(checked), + "_+_(\n" + " a~dyn^a,\n" + " b~google.protobuf.Duration^b\n" + ")~dyn^add_timestamp_duration|add_duration_duration"); +} + +struct TestCase { + TypeSpec type; + std::string expected_string; +}; + +class FormatBaselineTypeSpecTest : public testing::TestWithParam {}; + +TEST_P(FormatBaselineTypeSpecTest, Runner) { + Ast ast; + ast.mutable_root_expr().set_id(1); + ast.mutable_root_expr().mutable_ident_expr().set_name("x"); + ast.mutable_type_map()[1] = GetParam().type; + + EXPECT_EQ(FormatBaselineAst(ast), GetParam().expected_string); +} + +INSTANTIATE_TEST_SUITE_P( + Types, FormatBaselineTypeSpecTest, + ::testing::Values( + TestCase{TypeSpec(PrimitiveType::kBool), "x~bool"}, + TestCase{TypeSpec(PrimitiveType::kInt64), "x~int"}, + TestCase{TypeSpec(PrimitiveType::kUint64), "x~uint"}, + TestCase{TypeSpec(PrimitiveType::kDouble), "x~double"}, + TestCase{TypeSpec(PrimitiveType::kString), "x~string"}, + TestCase{TypeSpec(PrimitiveType::kBytes), "x~bytes"}, + TestCase{TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBool)), + "x~wrapper(bool)"}, + TestCase{TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kInt64)), + "x~wrapper(int)"}, + TestCase{TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kUint64)), + "x~wrapper(uint)"}, + TestCase{TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kDouble)), + "x~wrapper(double)"}, + TestCase{TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kString)), + "x~wrapper(string)"}, + TestCase{TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBytes)), + "x~wrapper(bytes)"}, + TestCase{TypeSpec(WellKnownTypeSpec::kAny), "x~google.protobuf.Any"}, + TestCase{TypeSpec(WellKnownTypeSpec::kDuration), + "x~google.protobuf.Duration"}, + TestCase{TypeSpec(WellKnownTypeSpec::kTimestamp), + "x~google.protobuf.Timestamp"}, + TestCase{TypeSpec(DynTypeSpec()), "x~dyn"}, + TestCase{TypeSpec(NullTypeSpec()), "x~null"}, + TestCase{TypeSpec(UnsetTypeSpec()), "x~*error*"}, + TestCase{TypeSpec(MessageTypeSpec("com.example.Type")), + "x~com.example.Type"}, + TestCase{TypeSpec(AbstractType("optional_type", + {TypeSpec(PrimitiveType::kInt64)})), + "x~optional_type(int)"}, + TestCase{TypeSpec(std::make_unique()), "x~type"}, + TestCase{TypeSpec(std::make_unique(PrimitiveType::kInt64)), + "x~type(int)"}, + TestCase{TypeSpec(ParamTypeSpec("T")), "x~T"}, + TestCase{TypeSpec(MapTypeSpec( + std::make_unique(PrimitiveType::kString), + std::make_unique(PrimitiveType::kString))), + "x~map(string, string)"}, + TestCase{TypeSpec(ListTypeSpec( + std::make_unique(PrimitiveType::kString))), + "x~list(string)"})); + +} // namespace +} // namespace cel::test diff --git a/testutil/expr_printer.cc b/testutil/expr_printer.cc index a8b5bde9a..40dea3c33 100644 --- a/testutil/expr_printer.cc +++ b/testutil/expr_printer.cc @@ -1,218 +1,251 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #include "testutil/expr_printer.h" +#include +#include #include +#include "absl/base/no_destructor.h" +#include "absl/log/absl_log.h" +#include "absl/status/statusor.h" +#include "absl/strings/match.h" #include "absl/strings/str_format.h" -#include "common/escaping.h" +#include "common/ast.h" +#include "common/ast_proto.h" +#include "common/constant.h" +#include "common/expr.h" +#include "internal/strings.h" -namespace google { -namespace api { -namespace expr { -namespace testutil { +namespace cel::test { namespace { -using ::google::api::expr::v1alpha1::Expr; - -class EmptyAdorner : public ExpressionAdorner { +class EmptyAdornerImpl : public ExpressionAdorner { public: - ~EmptyAdorner() {} - - std::string adorn(const Expr& e) const { return ""; } + std::string Adorn(const Expr& e) const override { return ""; } - std::string adorn(const Expr::CreateStruct::Entry& e) const { + std::string AdornStructField(const StructExprField& e) const override { return ""; } -}; -const EmptyAdorner the_empty_adorner; + std::string AdornMapEntry(const MapExprEntry& e) const override { return ""; } +}; -class Writer { +class StringBuilder { public: - Writer(const ExpressionAdorner& adorner) + explicit StringBuilder(const ExpressionAdorner& adorner) : adorner_(adorner), line_start_(true), indent_(0) {} - void appendExpr(const Expr& e) { - switch (e.expr_kind_case()) { - case Expr::kConstExpr: - append(formatLiteral(e.const_expr())); + std::string Print(const Expr& expr) { + AppendExpr(expr); + return s_; + } + + private: + void AppendExpr(const Expr& e) { + switch (e.kind_case()) { + case ExprKindCase::kConstant: + Append(FormatLiteral(e.const_expr())); + break; + case ExprKindCase::kIdentExpr: + Append(e.ident_expr().name()); break; - case Expr::kIdentExpr: - append(e.ident_expr().name()); + case ExprKindCase::kSelectExpr: + AppendSelect(e.select_expr()); break; - case Expr::kSelectExpr: - appendSelect(e.select_expr()); + case ExprKindCase::kCallExpr: + AppendCall(e.call_expr()); break; - case Expr::kCallExpr: - appendCall(e.call_expr()); + case ExprKindCase::kListExpr: + AppendList(e.list_expr()); break; - case Expr::kListExpr: - appendList(e.list_expr()); + case ExprKindCase::kMapExpr: + AppendMap(e.map_expr()); break; - case Expr::kStructExpr: - appendStruct(e.struct_expr()); + case ExprKindCase::kStructExpr: + AppendStruct(e.struct_expr()); break; - case Expr::kComprehensionExpr: - appendComprehension(e.comprehension_expr()); + case ExprKindCase::kComprehensionExpr: + AppendComprehension(e.comprehension_expr()); break; default: break; } - appendAdorn(e); + Append(adorner_.Adorn(e)); } - void appendSelect(const Expr::Select& sel) { - appendExpr(sel.operand()); - append("."); - append(sel.field()); + void AppendSelect(const SelectExpr& sel) { + AppendExpr(sel.operand()); + Append("."); + Append(sel.field()); if (sel.test_only()) { - append("~test-only~"); + Append("~test-only~"); } } - void appendCall(const Expr::Call& call) { + void AppendCall(const CallExpr& call) { if (call.has_target()) { - appendExpr(call.target()); + AppendExpr(call.target()); s_ += "."; } - append(call.function()); - append("("); - if (call.args_size() > 0) { - addIndent(); - appendLine(); - for (int i = 0; i < call.args_size(); ++i) { - const auto& arg = call.args(i); - if (i > 0) { - append(","); - appendLine(); - } - appendExpr(arg); + + Append(call.function()); + if (call.args().empty()) { + Append("()"); + return; + } + + Append("("); + Indent(); + AppendLine(); + for (int i = 0; i < call.args().size(); ++i) { + const auto& arg = call.args()[i]; + if (i > 0) { + Append(","); + AppendLine(); } - removeIndent(); - appendLine(); + AppendExpr(arg); } - append(")"); + AppendLine(); + Unindent(); + Append(")"); } - void appendList(const Expr::CreateList& list) { - append("["); - if (list.elements_size() > 0) { - appendLine(); - addIndent(); - for (int i = 0; i < list.elements_size(); ++i) { - const auto& elem = list.elements(i); - if (i > 0) { - append(","); - appendLine(); - } - appendExpr(elem); + void AppendList(const ListExpr& list) { + if (list.elements().empty()) { + Append("[]"); + return; + } + Append("["); + AppendLine(); + Indent(); + for (int i = 0; i < list.elements().size(); ++i) { + const auto& elem = list.elements()[i]; + if (i > 0) { + Append(","); + AppendLine(); + } + if (elem.optional()) { + Append("?"); } - removeIndent(); - appendLine(); + AppendExpr(elem.expr()); } - append("]"); + AppendLine(); + Unindent(); + Append("]"); } - void appendStruct(const Expr::CreateStruct& obj) { - if (obj.message_name().empty()) { - appendMap(obj); - } else { - appendObject(obj); + void AppendStruct(const StructExpr& obj) { + Append(obj.name()); + + if (obj.fields().empty()) { + Append("{}"); + return; } - } - void appendMap(const Expr::CreateStruct& obj) { - append("{"); - if (obj.entries_size() > 0) { - appendLine(); - addIndent(); - for (int i = 0; i < obj.entries_size(); ++i) { - const auto& entry = obj.entries(i); - if (i > 0) { - append(","); - appendLine(); - } - appendExpr(entry.map_key()); - append(":"); - appendExpr(entry.value()); - appendAdorn(entry); + Append("{"); + AppendLine(); + Indent(); + for (int i = 0; i < obj.fields().size(); ++i) { + const auto& entry = obj.fields()[i]; + if (i > 0) { + Append(","); + AppendLine(); + } + if (entry.optional()) { + Append("?"); } - removeIndent(); - appendLine(); + Append(entry.name()); + Append(":"); + AppendExpr(entry.value()); + Append(adorner_.AdornStructField(entry)); } - append("}"); + AppendLine(); + Unindent(); + Append("}"); } - void appendObject(const Expr::CreateStruct& obj) { - append(obj.message_name()); - append("{"); - if (obj.entries_size() > 0) { - appendLine(); - addIndent(); - for (int i = 0; i < obj.entries_size(); ++i) { - const auto& entry = obj.entries(i); - if (i > 0) { - append(","); - appendLine(); - } - append(entry.field_key()); - append(":"); - appendExpr(entry.value()); - appendAdorn(entry); + void AppendMap(const MapExpr& obj) { + if (obj.entries().empty()) { + Append("{}"); + return; + } + Append("{"); + AppendLine(); + Indent(); + for (int i = 0; i < obj.entries().size(); ++i) { + const auto& entry = obj.entries()[i]; + if (i > 0) { + Append(","); + AppendLine(); } - removeIndent(); - appendLine(); + if (entry.optional()) { + Append("?"); + } + AppendExpr(entry.key()); + Append(":"); + AppendExpr(entry.value()); + Append(adorner_.AdornMapEntry(entry)); } - append("}"); + AppendLine(); + Unindent(); + Append("}"); } - void appendComprehension(const Expr::Comprehension& comprehension) { - append("__comprehension__("); - addIndent(); - appendLine(); - append("// Variable"); - appendLine(); - append(comprehension.iter_var()); - append(","); - appendLine(); - append("// Target"); - appendLine(); - appendExpr(comprehension.iter_range()); - append(","); - appendLine(); - append("// Accumulator"); - appendLine(); - append(comprehension.accu_var()); - append(","); - appendLine(); - append("// Init"); - appendLine(); - appendExpr(comprehension.accu_init()); - append(","); - appendLine(); - append("// LoopCondition"); - appendLine(); - appendExpr(comprehension.loop_condition()); - append(","); - appendLine(); - append("// LoopStep"); - appendLine(); - appendExpr(comprehension.loop_step()); - append(","); - appendLine(); - append("// Result"); - appendLine(); - appendExpr(comprehension.result()); - append(")"); - removeIndent(); + void AppendComprehension(const ComprehensionExpr& comprehension) { + Append("__comprehension__("); + Indent(); + AppendLine(); + Append("// Variable"); + AppendLine(); + Append(comprehension.iter_var()); + Append(","); + AppendLine(); + Append("// Target"); + AppendLine(); + AppendExpr(comprehension.iter_range()); + Append(","); + AppendLine(); + Append("// Accumulator"); + AppendLine(); + Append(comprehension.accu_var()); + Append(","); + AppendLine(); + Append("// Init"); + AppendLine(); + AppendExpr(comprehension.accu_init()); + Append(","); + AppendLine(); + Append("// LoopCondition"); + AppendLine(); + AppendExpr(comprehension.loop_condition()); + Append(","); + AppendLine(); + Append("// LoopStep"); + AppendLine(); + AppendExpr(comprehension.loop_step()); + Append(","); + AppendLine(); + Append("// Result"); + AppendLine(); + AppendExpr(comprehension.result()); + Append(")"); + Unindent(); } - void appendAdorn(const Expr& e) { append(adorner_.adorn(e)); } - - void appendAdorn(const Expr::CreateStruct::Entry& e) { - append(adorner_.adorn(e)); - } - - void append(const std::string& s) { + void Append(const std::string& s) { if (line_start_) { line_start_ = false; for (int i = 0; i < indent_; ++i) { @@ -222,26 +255,27 @@ class Writer { s_ += s; } - void appendLine() { + void AppendLine() { s_ += "\n"; line_start_ = true; } - void addIndent() { indent_ += 1; } - - void removeIndent() { - if (indent_ > 0) { - indent_ -= 1; + void Indent() { ++indent_; } + void Unindent() { + if (indent_ >= 0) { + --indent_; + } else { + ABSL_LOG(ERROR) << "ExprPrinter indent underflow"; } } - std::string formatLiteral(const google::api::expr::v1alpha1::Constant& c) { - switch (c.constant_kind_case()) { - case google::api::expr::v1alpha1::Constant::kBoolValue: + std::string FormatLiteral(const Constant& c) { + switch (c.kind_case()) { + case ConstantKindCase::kBool: return absl::StrFormat("%s", c.bool_value() ? "true" : "false"); - case google::api::expr::v1alpha1::Constant::kBytesValue: - return absl::StrFormat("b\"%s\"", c.bytes_value()); - case google::api::expr::v1alpha1::Constant::kDoubleValue: { + case ConstantKindCase::kBytes: + return cel::internal::FormatDoubleQuotedBytesLiteral(c.bytes_value()); + case ConstantKindCase::kDouble: { std::string s = absl::StrFormat("%f", c.double_value()); // remove trailing zeros, i.e., convert 1.600000 to just 1.6 without // forcing a specific precision. There seems to be no flag to get this @@ -249,27 +283,24 @@ class Writer { auto idx = std::find_if_not(s.rbegin(), s.rend(), [](const char c) { return c == '0'; }); s.erase(idx.base(), s.end()); + if (absl::EndsWith(s, ".")) { + s += '0'; + } return s; } - case google::api::expr::v1alpha1::Constant::kInt64Value: - return absl::StrFormat("%d", c.int64_value()); - case google::api::expr::v1alpha1::Constant::kStringValue: - return parser::escapeAndQuote(c.string_value()); - case google::api::expr::v1alpha1::Constant::kUint64Value: - return absl::StrFormat("%uu", c.uint64_value()); - case google::api::expr::v1alpha1::Constant::kNullValue: + case ConstantKindCase::kInt: + return absl::StrFormat("%d", c.int_value()); + case ConstantKindCase::kString: + return cel::internal::FormatDoubleQuotedStringLiteral(c.string_value()); + case ConstantKindCase::kUint: + return absl::StrFormat("%uu", c.uint_value()); + case ConstantKindCase::kNull: return "null"; default: return "<>"; } } - std::string print(const Expr& expr) { - appendExpr(expr); - return s_; - } - - private: std::string s_; const ExpressionAdorner& adorner_; bool line_start_; @@ -278,16 +309,23 @@ class Writer { } // namespace -const ExpressionAdorner& empty_adorner() { - return the_empty_adorner; +const ExpressionAdorner& EmptyAdorner() { + static absl::NoDestructor kInstance; + return *kInstance; +} + +std::string ExprPrinter::PrintProto(const cel::expr::Expr& expr) const { + StringBuilder w(adorner_); + absl::StatusOr> ast = CreateAstFromParsedExpr(expr); + if (!ast.ok()) { + return std::string(ast.status().message()); + } + return w.Print(ast.value()->root_expr()); } -std::string ExprPrinter::print(const Expr& expr) const { - Writer w(adorner_); - return w.print(expr); +std::string ExprPrinter::Print(const Expr& expr) const { + StringBuilder w(adorner_); + return w.Print(expr); } -} // namespace testutil -} // namespace expr -} // namespace api -} // namespace google +} // namespace cel::test diff --git a/testutil/expr_printer.h b/testutil/expr_printer.h index 0fc9d7bae..6b0a8c161 100644 --- a/testutil/expr_printer.h +++ b/testutil/expr_printer.h @@ -1,39 +1,57 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #ifndef THIRD_PARTY_CEL_CPP_TESTUTIL_EXPR_PRINTER_H_ #define THIRD_PARTY_CEL_CPP_TESTUTIL_EXPR_PRINTER_H_ #include -#include "google/api/expr/v1alpha1/syntax.pb.h" - -namespace google { -namespace api { -namespace expr { -namespace testutil { +#include "cel/expr/syntax.pb.h" +#include "common/expr.h" -using ::google::api::expr::v1alpha1::Expr; +namespace cel::test { +// Interface for adding additional information to an expression during +// printing. class ExpressionAdorner { public: - virtual ~ExpressionAdorner() {} - virtual std::string adorn(const Expr& e) const = 0; - virtual std::string adorn(const Expr::CreateStruct::Entry& e) const = 0; + virtual ~ExpressionAdorner() = default; + virtual std::string Adorn(const Expr& e) const = 0; + virtual std::string AdornStructField(const StructExprField& e) const = 0; + virtual std::string AdornMapEntry(const MapExprEntry& e) const = 0; }; -const ExpressionAdorner& empty_adorner(); +// Default implementation of the ExpressionAdorner which does nothing. +const ExpressionAdorner& EmptyAdorner(); +// Helper class for printing an expression AST to a human readable, but detailed +// and consistently formatted string. +// +// Note: this implementation is recursive and is not suitable for printing +// arbitrarily large expressions. class ExprPrinter { public: - ExprPrinter() : adorner_(empty_adorner()) {} - ExprPrinter(const ExpressionAdorner& adorner) : adorner_(adorner) {} - std::string print(const Expr& expr) const; + ExprPrinter() : adorner_(EmptyAdorner()) {} + explicit ExprPrinter(const ExpressionAdorner& adorner) : adorner_(adorner) {} + + std::string PrintProto(const cel::expr::Expr& expr) const; + std::string Print(const Expr& expr) const; private: const ExpressionAdorner& adorner_; }; -} // namespace testutil -} // namespace expr -} // namespace api -} // namespace google +} // namespace cel::test #endif // THIRD_PARTY_CEL_CPP_TESTUTIL_EXPR_PRINTER_H_ diff --git a/testutil/expr_printer_test.cc b/testutil/expr_printer_test.cc new file mode 100644 index 000000000..9b1e7ca37 --- /dev/null +++ b/testutil/expr_printer_test.cc @@ -0,0 +1,342 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "testutil/expr_printer.h" + +#include + +#include "absl/base/no_destructor.h" +#include "absl/strings/str_cat.h" +#include "common/expr.h" +#include "internal/testing.h" +#include "parser/options.h" +#include "parser/parser.h" + +namespace cel::test { +namespace { + +using ::google::api::expr::parser::Parse; + +class TestAdorner : public ExpressionAdorner { + public: + static const TestAdorner& Get() { + static absl::NoDestructor kInstance; + return *kInstance; + } + + std::string Adorn(const Expr& e) const override { + return absl::StrCat("#", e.id()); + } + + std::string AdornStructField(const StructExprField& e) const override { + return absl::StrCat("#", e.id()); + } + + std::string AdornMapEntry(const MapExprEntry& e) const override { + return absl::StrCat("#", e.id()); + } +}; + +TEST(ExprPrinterTest, Identifier) { + Expr expr; + expr.mutable_ident_expr().set_name("foo"); + expr.set_id(1); + ExprPrinter printer(TestAdorner::Get()); + EXPECT_EQ(printer.Print(expr), ("foo#1")); +} + +TEST(ExprPrinterTest, ConstantString) { + Expr expr; + expr.mutable_const_expr().set_string_value("foo"); + expr.set_id(1); + + ExprPrinter printer(TestAdorner::Get()); + EXPECT_EQ(printer.Print(expr), (R"("foo"#1)")); +} + +TEST(ExprPrinterTest, ConstantBytes) { + Expr expr; + expr.mutable_const_expr().set_bytes_value("foo"); + expr.set_id(1); + + ExprPrinter printer(TestAdorner::Get()); + EXPECT_EQ(printer.Print(expr), (R"(b"foo"#1)")); +} + +TEST(ExprPrinterTest, ConstantInt) { + Expr expr; + expr.mutable_const_expr().set_int_value(1); + expr.set_id(1); + + ExprPrinter printer(TestAdorner::Get()); + EXPECT_EQ(printer.Print(expr), (R"(1#1)")); +} + +TEST(ExprPrinterTest, ConstantUint) { + Expr expr; + expr.mutable_const_expr().set_uint_value(1); + expr.set_id(1); + + ExprPrinter printer(TestAdorner::Get()); + EXPECT_EQ(printer.Print(expr), (R"(1u#1)")); +} + +TEST(ExprPrinterTest, ConstantDouble) { + Expr expr; + expr.mutable_const_expr().set_double_value(1.1); + expr.set_id(1); + + ExprPrinter printer(TestAdorner::Get()); + EXPECT_EQ(printer.Print(expr), (R"(1.1#1)")); +} + +TEST(ExprPrinterTest, ConstantBool) { + Expr expr; + expr.mutable_const_expr().set_bool_value(true); + expr.set_id(1); + + ExprPrinter printer(TestAdorner::Get()); + EXPECT_EQ(printer.Print(expr), (R"(true#1)")); +} + +TEST(ExprPrinterTest, Call) { + Expr expr; + expr.mutable_call_expr().set_function("foo"); + expr.set_id(1); + { + Expr& arg1 = expr.mutable_call_expr().add_args(); + arg1.mutable_const_expr().set_int_value(1); + arg1.set_id(2); + } + { + Expr& arg2 = expr.mutable_call_expr().add_args(); + arg2.mutable_const_expr().set_int_value(2); + arg2.set_id(3); + } + + ExprPrinter printer(TestAdorner::Get()); + EXPECT_EQ(printer.Print(expr), (R"(foo( + 1#2, + 2#3 +)#1)")); +} + +TEST(ExprPrinterTest, ReceiverCall) { + Expr expr; + expr.mutable_call_expr().set_function("foo"); + expr.set_id(1); + { + Expr& target = expr.mutable_call_expr().mutable_target(); + target.mutable_const_expr().set_string_value("bar"); + target.set_id(2); + } + { + Expr& arg2 = expr.mutable_call_expr().add_args(); + arg2.mutable_const_expr().set_int_value(2); + arg2.set_id(3); + } + + ExprPrinter printer(TestAdorner::Get()); + EXPECT_EQ(printer.Print(expr), (R"("bar"#2.foo( + 2#3 +)#1)")); +} + +TEST(ExprPrinterTest, List) { + Expr expr; + expr.set_id(1); + { + ListExprElement& arg1 = expr.mutable_list_expr().add_elements(); + arg1.set_optional(true); + arg1.mutable_expr().set_id(2); + arg1.mutable_expr().mutable_const_expr().set_int_value(1); + } + { + ListExprElement& arg2 = expr.mutable_list_expr().add_elements(); + arg2.set_optional(false); + arg2.mutable_expr().set_id(3); + arg2.mutable_expr().mutable_const_expr().set_int_value(2); + } + + ExprPrinter printer(TestAdorner::Get()); + EXPECT_EQ(printer.Print(expr), (R"([ + ?1#2, + 2#3 +]#1)")); +} + +TEST(ExprPrinterTest, Map) { + Expr expr; + expr.set_id(1); + { + MapExprEntry& entry = expr.mutable_map_expr().add_entries(); + entry.set_id(2); + entry.set_optional(true); + entry.mutable_key().set_id(3); + entry.mutable_key().mutable_const_expr().set_string_value("k1"); + entry.mutable_value().set_id(4); + entry.mutable_value().mutable_const_expr().set_string_value("v1"); + } + { + MapExprEntry& entry = expr.mutable_map_expr().add_entries(); + entry.set_id(5); + entry.set_optional(false); + entry.mutable_key().set_id(6); + entry.mutable_key().mutable_const_expr().set_string_value("k2"); + entry.mutable_value().set_id(7); + entry.mutable_value().mutable_const_expr().set_string_value("v2"); + } + + ExprPrinter printer(TestAdorner::Get()); + EXPECT_EQ(printer.Print(expr), (R"({ + ?"k1"#3:"v1"#4#2, + "k2"#6:"v2"#7#5 +}#1)")); +} + +TEST(ExprPrinterTest, Struct) { + Expr expr; + expr.set_id(1); + auto& struct_expr = expr.mutable_struct_expr(); + struct_expr.set_name("Foo"); + { + StructExprField& field1 = struct_expr.add_fields(); + field1.set_optional(true); + field1.set_id(2); + field1.set_name("field1"); + field1.mutable_value().set_id(3); + field1.mutable_value().mutable_const_expr().set_int_value(1); + } + { + StructExprField& field2 = struct_expr.add_fields(); + field2.set_optional(false); + field2.set_id(4); + field2.set_name("field2"); + field2.mutable_value().set_id(5); + field2.mutable_value().mutable_const_expr().set_int_value(1); + } + + ExprPrinter printer(TestAdorner::Get()); + EXPECT_EQ(printer.Print(expr), (R"(Foo{ + ?field1:1#3#2, + field2:1#5#4 +}#1)")); +} + +TEST(ExprPrinterTest, Comprehension) { + Expr expr; + expr.set_id(1); + expr.mutable_comprehension_expr().set_iter_var("x"); + expr.mutable_comprehension_expr().set_accu_var("@result"); + auto& range = expr.mutable_comprehension_expr().mutable_iter_range(); + range.set_id(2); + range.mutable_ident_expr().set_name("range"); + auto& accu_init = expr.mutable_comprehension_expr().mutable_accu_init(); + accu_init.set_id(3); + accu_init.mutable_ident_expr().set_name("accu_init"); + auto& loop_condition = + expr.mutable_comprehension_expr().mutable_loop_condition(); + loop_condition.set_id(4); + loop_condition.mutable_ident_expr().set_name("loop_condition"); + auto& loop_step = expr.mutable_comprehension_expr().mutable_loop_step(); + loop_step.set_id(5); + loop_step.mutable_ident_expr().set_name("loop_step"); + auto& result = expr.mutable_comprehension_expr().mutable_result(); + result.set_id(6); + result.mutable_ident_expr().set_name("result"); + + ExprPrinter printer(TestAdorner::Get()); + EXPECT_EQ(printer.Print(expr), R"(__comprehension__( + // Variable + x, + // Target + range#2, + // Accumulator + @result, + // Init + accu_init#3, + // LoopCondition + loop_condition#4, + // LoopStep + loop_step#5, + // Result + result#6)#1)"); +} + +TEST(ExprPrinterTest, Proto) { + ParserOptions options; + options.enable_optional_syntax = true; + options.enable_hidden_accumulator_var = true; + ASSERT_OK_AND_ASSIGN(auto parsed_expr, Parse(R"cel( + "foo".startsWith("bar") || + [1, ?2, 3].exists(x, x in {?"b": "foo"}) || + Foo{ + byte_value: b'bytes', + bool_value: false, + uint_value: 1u, + double_value: 1.1, + }.bar + )cel", + "", options)); + + ExprPrinter printer(TestAdorner::Get()); + EXPECT_EQ(printer.PrintProto(parsed_expr.expr()), + R"ast(_||_( + _||_( + "foo"#1.startsWith( + "bar"#3 + )#2, + __comprehension__( + // Variable + x, + // Target + [ + 1#5, + ?2#6, + 3#7 + ]#4, + // Accumulator + @result, + // Init + false#16, + // LoopCondition + @not_strictly_false( + !_( + @result#17 + )#18 + )#19, + // LoopStep + _||_( + @result#20, + @in( + x#10, + { + ?"b"#14:"foo"#15#13 + }#12 + )#11 + )#21, + // Result + @result#22)#23 + )#24, + Foo{ + byte_value:b"bytes"#27#26, + bool_value:false#29#28, + uint_value:1u#31#30, + double_value:1.1#33#32 + }#25.bar#34 +)#35)ast"); +} + +} // namespace +} // namespace cel::test diff --git a/testutil/test_data_gen.cc b/testutil/test_data_gen.cc deleted file mode 100644 index 1701442e2..000000000 --- a/testutil/test_data_gen.cc +++ /dev/null @@ -1,171 +0,0 @@ -#include "google/protobuf/empty.pb.h" -#include "google/rpc/code.pb.h" -#include "google/type/money.pb.h" -#include "absl/flags/parse.h" -#include "absl/strings/match.h" -#include "common/type.h" -#include "internal/proto_util.h" -#include "testutil/test_data_io.h" -#include "testutil/test_data_util.h" -#include "testdata/test_data.pb.h" - -namespace google { -namespace api { -namespace expr { -namespace testutil { - -using testdata::TestData; - -constexpr int64_t kDoubleIntegerMax = 1L << 53; - -/** - * Constructs test data that enumerate values that are non-equivalent (but may - * be interpreted as equal) - */ -TestData UniqueValues() { - TestData values; - values.set_description("Set of unique (not equivalent) test values."); - auto add_val = ::google::protobuf::RepeatedFieldBackInserter( - values.mutable_test_values()->mutable_values()); - - // Null. - add_val = NewValue(nullptr); - - // Bool. - add_val = NewValue(true); - add_val = NewValue(false); - - // Int64 - add_val = NewValue(std::numeric_limits::min(), "min"); - add_val = NewValue(-kDoubleIntegerMax - 1, "max_double_int-1"); - add_val = NewValue(-kDoubleIntegerMax, "max_double_int"); - add_val = NewValue(-1); - add_val = NewValue(0); - add_val = NewValue(1); - add_val = NewValue(kDoubleIntegerMax, "max_double_int"); - add_val = NewValue(kDoubleIntegerMax + 1, "max_double_int+1"); - add_val = NewValue(std::numeric_limits::max(), "max"); - - // Uint64 - add_val = NewValue(0u); - add_val = NewValue(1u); - add_val = - NewValue(static_cast(kDoubleIntegerMax), "max_double_int"), - add_val = NewValue(static_cast(kDoubleIntegerMax) + 1, - "max_double_int+1"); - add_val = NewValue(std::numeric_limits::max(), "max"); - - // Double - // NAN is 'equivalent' to itself, but not equal to itself. - add_val = NewValue(NAN, "NaN"); - add_val = NewValue(-std::numeric_limits::infinity(), "-inf"), - add_val = NewValue(std::numeric_limits::min(), "min"), - add_val = - NewValue(-static_cast(kDoubleIntegerMax) - 2, "max_double_int-2"); - add_val = NewValue(-static_cast(kDoubleIntegerMax), "max_double_int"); - add_val = NewValue(-1.0); - add_val = NewValue(0.0); - add_val = NewValue(1.0); - add_val = NewValue(static_cast(kDoubleIntegerMax), "max_double_int"); - add_val = - NewValue(static_cast(kDoubleIntegerMax) + 2, "max_double_int+2"); - add_val = NewValue(std::numeric_limits::max(), "max"), - add_val = NewValue(std::numeric_limits::infinity(), "+inf"), - - // String - add_val = NewValue("", "empty"); - add_val = NewValue("hi"); - - // Bytes - add_val = NewBytesValue("", "empty"); - add_val = NewBytesValue("hi"); - - // Duration - add_val = NewValue(expr::internal::MakeGoogleApiDurationMin(), "min"); - add_val = NewValue( - expr::internal::MakeGoogleApiDurationMin() + absl::Nanoseconds(1), - "min+1"); - add_val = NewValue(absl::Nanoseconds(-1), "-1"); - add_val = NewValue(absl::Nanoseconds(0), "0"); - add_val = NewValue(absl::Nanoseconds(1), "1"); - add_val = NewValue( - expr::internal::MakeGoogleApiDurationMax() - absl::Nanoseconds(1), - "max-1"); - add_val = NewValue(expr::internal::MakeGoogleApiDurationMax(), "max"); - - // Timestmap - add_val = NewValue(expr::internal::MakeGoogleApiTimeMin(), "min"); - add_val = NewValue( - expr::internal::MakeGoogleApiTimeMin() + absl::Nanoseconds(1), "min+1"); - add_val = NewValue(absl::FromUnixNanos(-1), "-1"); - add_val = NewValue(absl::FromUnixNanos(0), "0"); - add_val = NewValue(absl::FromUnixNanos(1), "1"); - add_val = NewValue( - expr::internal::MakeGoogleApiTimeMax() - absl::Nanoseconds(1), "max-1"); - add_val = NewValue(expr::internal::MakeGoogleApiTimeMax(), "max"); - - // Message - add_val = NewValue(google::protobuf::Empty(), "empty"); - google::type::Money money; - add_val = NewValue(money, "empty"); - money.set_nanos(100); - add_val = NewValue(money, "100"); - money.set_nanos(5); - add_val = NewValue(money, "5"); - - // List - add_val = WithName(NewListValue(), "list(empty)"); - add_val = WithName(NewListValue(1), "list"); - add_val = WithName(NewListValue(true, true, false), "list"); - add_val = WithName(NewListValue(nullptr, nullptr), "list"); - add_val = WithName(NewListValue(absl::Seconds(1), absl::Seconds(2)), - "list"); - add_val = WithName(NewListValue(money, money), "list"); - add_val = WithName(NewListValue(1u), "list"); - add_val = WithName(NewListValue(1.0), "list"); - add_val = WithName(NewListValue(1.0, 2.0), "list"); - add_val = WithName(NewListValue(2.0, 1.0), "list"); - add_val = WithName(NewListValue("hi", "bye"), "list"); - add_val = WithName(NewListValue(1.0, "hi"), "list"); - add_val = WithName(NewListValue("hi", 1.0), "list"); - add_val = WithName(NewListValue(1, "hi"), "list"); - add_val = WithName(NewListValue("hi", 1), "list"); - add_val = WithName(NewListValue(1u, "hi"), "list"); - add_val = WithName(NewListValue(true, "hi"), "list"); - - // Map - add_val = WithName(NewMapValue(), "map(empty)"); - add_val = WithName(NewMapValue(1, "hi"), "map"); - add_val = WithName(NewMapValue(1u, "hi"), "map"); - add_val = WithName(NewMapValue(1.0, "hi"), "map"); - add_val = WithName(NewMapValue(true, "hi"), "map"); - add_val = WithName(NewMapValue("hi", 1), "map"); - add_val = WithName(NewMapValue("hi", 1u), "map"); - add_val = WithName(NewMapValue("hi", 1.0), "map"); - add_val = WithName(NewMapValue("hi", true), "map"); - add_val = WithName(NewMapValue("hi", "bye"), "map"); - add_val = Merge({NewMapValue("hi", "bye", "foo", "bar"), - NewMapValue("foo", "bar", "hi", "bye")}, - "multiple_orderings"); - - add_val = - NewValue(expr::common::BasicType(expr::common::BasicTypeValue::kInt)); - add_val = NewValue(expr::common::Type("google.protobuf.Duration")); - add_val = NewValue(expr::common::Type("unknown.type")); - return values; -} - -void WriteData() { - auto status = WriteTestData("unique_values", UniqueValues()); - GOOGLE_CHECK(status.code() == google::rpc::Code::OK) << status.ShortDebugString(); -} - -} // namespace testutil -} // namespace expr -} // namespace api -} // namespace google - -int main(int argc, char** argv) { - absl::ParseCommandLine(argc, argv); - google::api::expr::testutil::WriteData(); -} diff --git a/testutil/test_data_io.cc b/testutil/test_data_io.cc deleted file mode 100644 index 9e486c00d..000000000 --- a/testutil/test_data_io.cc +++ /dev/null @@ -1,176 +0,0 @@ -#include "testutil/test_data_io.h" - -#include - -#include - -#include "google/rpc/code.pb.h" -#include "google/protobuf/io/zero_copy_stream_impl.h" -#include "google/protobuf/text_format.h" -#include "gtest/gtest.h" -#include "absl/flags/flag.h" -#include "absl/memory/memory.h" -#include "absl/strings/match.h" -#include "absl/strings/str_cat.h" -#include "internal/status_util.h" -#include "re2/re2.h" - -ABSL_FLAG(std::string, test_data_folder, - "com_google_cel_spec/testdata/", - "The location to read test data from."); - -ABSL_FLAG(std::string, output_dir, "", - "The location to write test data to. Writes to standard out if not " - "specified."); - -ABSL_FLAG(bool, binary, false, "If binary output should be used."); - -namespace google { -namespace api { -namespace expr { -namespace testutil { - -using testdata::TestData; -using testdata::TestValue; - -namespace { - -std::unique_ptr OpenForRead( - const std::string& filename) { - int file_descriptor; - do { - file_descriptor = open(filename.c_str(), O_RDONLY); - } while (file_descriptor < 0 && errno == EINTR); - if (file_descriptor >= 0) { - auto result = - absl::make_unique(file_descriptor); - result->SetCloseOnDelete(true); - return result; - } else { - return nullptr; - } -} - -std::unique_ptr OpenForWrite( - const std::string& filename) { - int file_descriptor; - do { - file_descriptor = open(filename.c_str(), O_WRONLY | O_CREAT); - } while (file_descriptor < 0 && errno == EINTR); - if (file_descriptor >= 0) { - auto result = - absl::make_unique(file_descriptor); - result->SetCloseOnDelete(true); - return result; - } else { - std::cerr << "Could not open file: " << errno << std::endl; - return nullptr; - } -} -std::string GetTestCaseFileName(absl::string_view dir, - absl::string_view test_name, bool binary) { - return absl::StrCat(dir, test_name, binary ? kBinaryPbExt : kTextPbExt); -} - -} // namespace - -google::rpc::Status ReadPbFile(absl::string_view absolute_file_path, - google::protobuf::Message* message) { - message->Clear(); - auto in_stream = OpenForRead(std::string(absolute_file_path)); - if (in_stream == nullptr) { - return internal::NotFoundError( - absl::StrCat("File not found: ", absolute_file_path)); - } - if (absl::EndsWith(absolute_file_path, kTextPbExt)) { - if (google::protobuf::TextFormat::Parse(in_stream.get(), message)) { - return internal::OkStatus(); - } - } else { - if (message->ParseFromZeroCopyStream(in_stream.get())) { - return internal::OkStatus(); - } - } - return internal::InvalidArgumentError( - absl::StrCat("Parsing file contents failed: ", absolute_file_path)); -} - -google::rpc::Status WritePbFile(const google::protobuf::Message& message, - absl::string_view absolute_file_path) { - auto out_stream = OpenForWrite(std::string(absolute_file_path)); - if (out_stream == nullptr) { - return internal::InvalidArgumentError( - absl::StrCat("Could not open file: ", absolute_file_path)); - } - if (absl::EndsWith(absolute_file_path, kTextPbExt)) { - if (!google::protobuf::TextFormat::Print(message, out_stream.get())) { - return internal::InvalidArgumentError( - absl::StrCat("Unable to write to file: ", absolute_file_path)); - } - } else { - if (!message.SerializePartialToZeroCopyStream(out_stream.get())) { - return internal::InvalidArgumentError( - absl::StrCat("Unable to write to file: ", absolute_file_path)); - } - } - std::cout << "Wrote: " << absolute_file_path << std::endl; - return internal::OkStatus(); -} - -google::rpc::Status WriteTestData(absl::string_view test_name, - const TestData& values) { - if (absl::GetFlag(FLAGS_output_dir).empty()) { - google::protobuf::io::OstreamOutputStream os(&std::cout); - google::protobuf::TextFormat::Print(values, &os); - return internal::OkStatus(); - } - return WritePbFile( - values, GetTestCaseFileName(absl::GetFlag(FLAGS_output_dir), test_name, - absl::GetFlag(FLAGS_binary))); -} - -TestData ReadTestData(absl::string_view test_name) { - TestData data; - auto dir = absl::StrCat(std::getenv("TEST_SRCDIR"), "/", - absl::GetFlag(FLAGS_test_data_folder)); - - auto status = ReadPbFile( - GetTestCaseFileName(dir, test_name, absl::GetFlag(FLAGS_binary)), &data); - if (status.code() == google::rpc::Code::OK) { - return data; - } - - // Check for the other file, just to be nice. - if (ReadPbFile( - GetTestCaseFileName(dir, test_name, !absl::GetFlag(FLAGS_binary)), - &data) - .code() == google::rpc::Code::OK) { - return data; - } - - // Die with the error for the first file. - GOOGLE_LOG(FATAL) << status.ShortDebugString(); -} - -std::string TestDataParamName::operator()( - const ::testing::TestParamInfo>& info) - const { - std::string first = info.param.first.name(); - std::string second = info.param.second.name(); - RE2::GlobalReplace(&first, RE2("[^a-zA-Z0-9_]"), "_"); - RE2::GlobalReplace(&second, RE2("[^a-zA-Z0-9_]"), "_"); - - return absl::StrCat(info.index, "_", first, "_v_", second); -} - -std::string TestDataParamName::operator()( - const ::testing::TestParamInfo& info) const { - std::string name = info.param.name(); - RE2::GlobalReplace(&name, RE2("[^a-zA-Z0-9_]"), "_"); - return absl::StrCat(info.index, "_", name); -} - -} // namespace testutil -} // namespace expr -} // namespace api -} // namespace google diff --git a/testutil/test_data_io.h b/testutil/test_data_io.h deleted file mode 100644 index 7c81f27e1..000000000 --- a/testutil/test_data_io.h +++ /dev/null @@ -1,72 +0,0 @@ -#ifndef THIRD_PARTY_CEL_CPP_TESTUTIL_TEST_DATA_IO_H_ -#define THIRD_PARTY_CEL_CPP_TESTUTIL_TEST_DATA_IO_H_ - -#include "google/rpc/status.pb.h" -#include "gtest/gtest.h" -#include "absl/strings/string_view.h" -#include "testdata/test_data.pb.h" -#include "testdata/test_value.pb.h" - -namespace google { -namespace api { -namespace expr { -namespace testutil { - -/** The file extension for a text proto file. */ -constexpr const char kTextPbExt[] = ".textpb"; -/** The file extension for a binary proto file. */ -constexpr const char kBinaryPbExt[] = ".binarypb"; - -/** - * Reads the given proto file. - * - * The file is treated as text if it ends with `kTextPbExt`, and binary - * in all other cases. - */ -google::rpc::Status ReadPbFile(absl::string_view absolute_file_path, - google::protobuf::Message* message); - -/** - * Writes the given proto file. - * - * The file is written as text if it ends with `kTextPbExt`, and binary - * in all other cases. - */ -google::rpc::Status WritePbFile(const google::protobuf::Message& message, - absl::string_view absolute_file_path); - -/** - * Writes test data under the given test_name. - * - * Output is controlled by the following flags: - * - binary: If the output should be in a proto binary format. - * - test_data_folder: The folder to write the files. - */ -google::rpc::Status WriteTestData(absl::string_view test_name, - const testdata::TestData& values); - -/** - * Reads test data for the given test. - * - * Read from the dir specified by the `test_data_folder` flag and CHECK fails - * if the test cannot be found. - */ -testdata::TestData ReadTestData(absl::string_view test_name); - -/** - * A helper class to generate friendly test names. - */ -struct TestDataParamName { - std::string operator()( - const ::testing::TestParamInfo& info) const; - std::string operator()( - const ::testing::TestParamInfo< - std::pair>& info) const; -}; - -} // namespace testutil -} // namespace expr -} // namespace api -} // namespace google - -#endif // THIRD_PARTY_CEL_CPP_TESTUTIL_TEST_DATA_IO_H_ diff --git a/testutil/test_data_test.cc b/testutil/test_data_test.cc deleted file mode 100644 index c7f481c32..000000000 --- a/testutil/test_data_test.cc +++ /dev/null @@ -1,95 +0,0 @@ -#include "google/api/expr/v1beta1/value.pb.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "testutil/test_data_io.h" -#include "testutil/test_data_util.h" -#include "testdata/test_data.pb.h" - -namespace google { -namespace api { -namespace expr { -namespace testutil { - -using testdata::TestValue; - -/** - * Test class for all pairs of values in "unique_values" - */ -class UniqueValuesTest - : public ::testing::TestWithParam> { - protected: - UniqueValuesTest() { v1beta1::InitValueDifferencer(&v1beta1_differ_); } - - bool Equivalent(const v1beta1::ExprValue& lhs, - const v1beta1::ExprValue& rhs) { - return v1beta1_differ_.Compare(lhs, rhs); - } - - void ExpectEquivalent(bool expected, const TestValue& lhs, - const TestValue& rhs) { - for (const auto& lhs_value : lhs.v1beta1()) { - SCOPED_TRACE(lhs_value.ShortDebugString()); - for (const auto& rhs_value : rhs.v1beta1()) { - SCOPED_TRACE(rhs_value.ShortDebugString()); - EXPECT_EQ(expected, Equivalent(lhs_value, rhs_value)); - } - } - } - - private: - google::protobuf::util::MessageDifferencer v1beta1_differ_; -}; - -/** - * Tests that the values in "unqiue_values" are not equivalent to each other. - */ -TEST_P(UniqueValuesTest, NotEquivalent) { - ExpectEquivalent(false, GetParam().first, GetParam().second); -} - -INSTANTIATE_TEST_SUITE_P( - UniqueValues, UniqueValuesTest, - ::testing::ValuesIn(AllPairs(ReadTestData("unique_values").test_values())), - TestDataParamName()); - -/** - * Test class for TestValue invariants. - */ -class TestValueTest : public ::testing::TestWithParam { - protected: - TestValueTest() { v1beta1::InitValueDifferencer(&v1beta1_differ_); } - - ::testing::AssertionResult Equivalent(const v1beta1::ExprValue& lhs, - const v1beta1::ExprValue& rhs) { - std::string diff; - v1beta1_differ_.ReportDifferencesToString(&diff); - if (v1beta1_differ_.Compare(lhs, rhs)) { - return ::testing::AssertionSuccess(); - } - return ::testing::AssertionFailure() << diff; - } - - private: - google::protobuf::util::MessageDifferencer v1beta1_differ_; -}; - -/** - * Tests that all values within a given TestValue are equivalent to each other. - */ -TEST_P(TestValueTest, Equivalent) { - const auto& first = GetParam().v1beta1(0); - for (const auto& value : GetParam().v1beta1()) { - SCOPED_TRACE(value.ShortDebugString()); - EXPECT_TRUE(Equivalent(first, value)); - } -} - -INSTANTIATE_TEST_SUITE_P( - UniqueValues, TestValueTest, - ::testing::ValuesIn(ReadTestData("unique_values").test_values().values()), - TestDataParamName()); - -} // namespace testutil -} // namespace expr -} // namespace api -} // namespace google diff --git a/testutil/test_data_util.cc b/testutil/test_data_util.cc deleted file mode 100644 index 3a1a10928..000000000 --- a/testutil/test_data_util.cc +++ /dev/null @@ -1,599 +0,0 @@ -#include "testutil/test_data_util.h" - -#include "google/protobuf/any.pb.h" -#include "google/protobuf/duration.pb.h" -#include "google/protobuf/struct.pb.h" -#include "google/protobuf/timestamp.pb.h" -#include "google/protobuf/wrappers.pb.h" -#include "google/rpc/code.pb.h" -#include "absl/strings/str_cat.h" -#include "common/value.h" -#include "internal/cel_printer.h" -#include "internal/proto_util.h" -#include "protoutil/converters.h" -#include "protoutil/type_registry.h" -#include "v1beta1/converters.h" - -namespace google { -namespace api { -namespace expr { - -using internal::CelPrinter; -using protoutil::TypeRegistry; -using testdata::TestProtoValue; -using testdata::TestValue; - -namespace v1beta1 { - -void InitValueDifferencer(google::protobuf::util::MessageDifferencer* differencer) { - static google::protobuf::util::DefaultFieldComparator* field_comparator = []() { - auto* result = new google::protobuf::util::DefaultFieldComparator(); - result->set_float_comparison(google::protobuf::util::DefaultFieldComparator::EXACT); - result->set_treat_nan_as_equal(true); - return result; - }(); - - auto map_entry_field = v1beta1::Value::descriptor() - ->FindFieldByName("map_value") - ->message_type() - ->FindFieldByName("entries"); - auto key_field = map_entry_field->message_type()->FindFieldByName("key"); - differencer->TreatAsMap(map_entry_field, key_field); - differencer->TreatAsSet( - v1beta1::ErrorSet::descriptor()->FindFieldByName("errors")); - differencer->TreatAsSet( - google::rpc::Status::descriptor()->FindFieldByName("details")); - differencer->TreatAsSet( - v1beta1::UnknownSet::descriptor()->FindFieldByName("exprs")); - differencer->set_field_comparator(field_comparator); -} - -} // namespace v1beta1 - -namespace testutil { - -namespace { - -const TypeRegistry* GetRegistry() { - static const TypeRegistry* reg = []() { - auto* reg = new TypeRegistry(); - protoutil::RegisterConvertersWith(reg); - return reg; - }(); - return reg; -} - -TestProtoValue* AddProto(TestValue* value, absl::string_view field) { - TestProtoValue* result = value->add_proto(); - result->set_value_field_name(std::string(field)); - return result; -} - -TestProtoValue* FindProto(TestValue* value, absl::string_view field) { - for (auto& proto : *value->mutable_proto()) { - if (proto.value_field_name() == field) { - return &proto; - } - } - return nullptr; -} - -const TestProtoValue* FindProto(const TestValue* value, - absl::string_view field) { - for (auto& proto : value->proto()) { - if (proto.value_field_name() == field) { - return &proto; - } - } - return nullptr; -} - -void InitAll(TestValue* value, absl::string_view prefix) { - for (int i = 0; i < TestProtoValue::descriptor()->field_count(); ++i) { - absl::string_view field_name = - TestProtoValue::descriptor()->field(i)->name(); - if (absl::StartsWith(field_name, prefix)) { - AddProto(value, field_name); - } - } -} - -void RemoveProto(TestValue* value, absl::string_view field) { - for (auto itr = value->mutable_proto()->begin(); - itr != value->mutable_proto()->end();) { - if (itr->value_field_name() == field) { - itr = value->mutable_proto()->erase(itr); - } else { - ++itr; - } - } -} - -template -bool rep_as(F value) { - if (value < std::numeric_limits::min() || - value > std::numeric_limits::max()) { - return false; - } - - return static_cast(static_cast(value)) == value; -} - -template -std::string MakeName(absl::string_view type, T&& value, - absl::string_view name = "") { - if (name.empty()) { - return absl::StrCat(type, "(", value, ")"); - } - return absl::StrCat(type, "(", name, ")"); -} - -template -TestValue MakeValue(absl::string_view type, T&& value, - absl::string_view name = "") { - TestValue result; - result.set_name(MakeName(type, value, name)); - result.add_expr(CelPrinter()(value)); - return result; -} - -} // namespace - -std::vector> AllPairs( - const testdata::TestValues& value_cases) { - std::vector> result; - for (int i = 0; i < value_cases.values().size(); ++i) { - for (int j = i + 1; j < value_cases.values().size(); ++j) { - result.emplace_back(value_cases.values(i), value_cases.values(j)); - } - } - - return result; -} - -TestValue NewValue(std::nullptr_t value) { - TestValue result; - result.set_name("null"); - - result.add_expr("null"); - - // v1beta1 values. - result.add_v1beta1()->mutable_value()->set_null_value( - google::protobuf::NullValue::NULL_VALUE); - - // proto values - // An unset well-known types. - AddProto(&result, "single_any"); - AddProto(&result, "single_duration"); - AddProto(&result, "single_timestamp"); - AddProto(&result, "single_null"); - InitAll(&result, "wrapped_"); - - // json_values - google::protobuf::Value json_null; - json_null.set_null_value(google::protobuf::NullValue::NULL_VALUE); - *AddProto(&result, "single_value")->mutable_single_value() = json_null; - AddProto(&result, "single_any")->mutable_single_any()->PackFrom(json_null); - - return result; -} - -TestValue NewValue(bool value) { - TestValue result = MakeValue("bool", value, CelPrinter()(value)); - - // v1beta1 values. - result.add_v1beta1()->mutable_value()->set_bool_value(value); - - // proto values. - AddProto(&result, "single_bool")->set_single_bool(value); - - // A json values. - google::protobuf::Value json_bool; - json_bool.set_bool_value(value); - *AddProto(&result, "single_value")->mutable_single_value() = json_bool; - AddProto(&result, "single_any")->mutable_single_any()->PackFrom(json_bool); - - // wrapped values. - google::protobuf::BoolValue wrapped_bool; - wrapped_bool.set_value(value); - *AddProto(&result, "wrapped_bool")->mutable_wrapped_bool() = wrapped_bool; - AddProto(&result, "single_any")->mutable_single_any()->PackFrom(wrapped_bool); - return result; -} - -TestValue NewValue(double value, absl::string_view name) { - TestValue result = MakeValue("double", value, name); - - // v1beta1 values. - result.add_v1beta1()->mutable_value()->set_double_value(value); - - // proto values. - AddProto(&result, "single_double")->set_single_double(value); - if (rep_as(value)) { - AddProto(&result, "single_float")->set_single_float(value); - } - - // josn values. - google::protobuf::Value json_double; - json_double.set_number_value(value); - // A json bool - *AddProto(&result, "single_value")->mutable_single_value() = json_double; - // A json bool in an any. - AddProto(&result, "single_any")->mutable_single_any()->PackFrom(json_double); - - // wrapped values. - google::protobuf::DoubleValue wrapped_double; - wrapped_double.set_value(value); - *AddProto(&result, "wrapped_double")->mutable_wrapped_double() = - wrapped_double; - AddProto(&result, "single_any") - ->mutable_single_any() - ->PackFrom(wrapped_double); - if (rep_as(value)) { - google::protobuf::FloatValue wrapped_float; - wrapped_float.set_value(value); - *AddProto(&result, "wrapped_float")->mutable_wrapped_float() = - wrapped_float; - AddProto(&result, "single_any") - ->mutable_single_any() - ->PackFrom(wrapped_double); - } - return result; -} - -TestValue NewValue(int64_t value, absl::string_view name) { - TestValue result = MakeValue("int", value, name); - - // v1beta1 value. - result.add_v1beta1()->mutable_value()->set_int64_value(value); - - // proto values. - AddProto(&result, "single_int64")->set_single_int64(value); - AddProto(&result, "single_sint64")->set_single_sint64(value); - AddProto(&result, "single_sfixed64")->set_single_sfixed64(value); - if (rep_as(value)) { - AddProto(&result, "single_int32")->set_single_int32(value); - AddProto(&result, "single_sint32")->set_single_sint32(value); - AddProto(&result, "single_sfixed32")->set_single_sfixed32(value); - AddProto(&result, "wrapped_int32") - ->mutable_wrapped_int32() - ->set_value(value); - } - - // wrapped values - AddProto(&result, "wrapped_int64")->mutable_wrapped_int64()->set_value(value); - if (rep_as(value)) { - AddProto(&result, "wrapped_int32") - ->mutable_wrapped_int32() - ->set_value(value); - } - - return result; -} - -TestValue NewValue(uint64_t value, absl::string_view name) { - TestValue result = MakeValue("uint", value, name); - result.add_v1beta1()->mutable_value()->set_uint64_value(value); - - // proto values. - AddProto(&result, "single_uint64")->set_single_uint64(value); - AddProto(&result, "single_fixed64")->set_single_fixed64(value); - if (rep_as(value)) { - AddProto(&result, "single_uint32")->set_single_uint32(value); - AddProto(&result, "single_fixed32")->set_single_fixed32(value); - } - - // wrapped values. - google::protobuf::UInt64Value wrapped_uint64; - wrapped_uint64.set_value(value); - *AddProto(&result, "wrapped_uint64")->mutable_wrapped_uint64() = - wrapped_uint64; - AddProto(&result, "single_any") - ->mutable_single_any() - ->PackFrom(wrapped_uint64); - if (rep_as(value)) { - google::protobuf::UInt32Value wrapped_uint32; - wrapped_uint32.set_value(value); - *AddProto(&result, "wrapped_uint32")->mutable_wrapped_uint32() = - wrapped_uint32; - AddProto(&result, "single_any") - ->mutable_single_any() - ->PackFrom(wrapped_uint32); - } - - return result; -} - -TestValue NewValue(absl::string_view value, absl::string_view name) { - TestValue result = MakeValue("string", value, name); - - // v1beta1 values. - result.add_v1beta1()->mutable_value()->set_string_value(std::string(value)); - - // proto values - AddProto(&result, "single_string")->set_single_string(std::string(value)); - - // json values. - google::protobuf::Value json_string; - json_string.set_string_value(std::string(value)); - *AddProto(&result, "single_value")->mutable_single_value() = json_string; - AddProto(&result, "single_any")->mutable_single_any()->PackFrom(json_string); - - // wrapped values. - google::protobuf::StringValue wrapped_string; - wrapped_string.set_value(std::string(value)); - *AddProto(&result, "wrapped_string")->mutable_wrapped_string() = - wrapped_string; - AddProto(&result, "single_any") - ->mutable_single_any() - ->PackFrom(wrapped_string); - - return result; -} - -TestValue NewValue(const char* value, absl::string_view name) { - return NewValue(absl::string_view(value), name); -} - -TestValue NewBytesValue(absl::string_view value, absl::string_view name) { - TestValue result = MakeValue("bytes", value, name); - // v1beta1 values. - result.add_v1beta1()->mutable_value()->set_bytes_value(std::string(value)); - - // proto values - AddProto(&result, "single_bytes")->set_single_bytes(std::string(value)); - - // wrapped values. - google::protobuf::BytesValue wrapped_bytes; - wrapped_bytes.set_value(std::string(value)); - *AddProto(&result, "wrapped_bytes")->mutable_wrapped_bytes() = wrapped_bytes; - AddProto(&result, "single_any") - ->mutable_single_any() - ->PackFrom(wrapped_bytes); - - return result; -} - -TestValue NewValue(const google::protobuf::Message& value, absl::string_view name) { - TestValue result; - result.set_name(MakeName(value.GetTypeName(), name, name)); - - result.add_expr(GetRegistry()->ValueFor(&value).ToString()); - - // v1beta1 values - result.add_v1beta1()->mutable_value()->mutable_object_value()->PackFrom( - value); - - // proto values. - AddProto(&result, "single_any")->mutable_single_any()->PackFrom(value); - - return result; -} - -TestValue NewValue(absl::Duration value, absl::string_view name) { - google::protobuf::Duration duration; - auto status = expr::internal::EncodeDuration(value, &duration); - assert(status.code() == google::rpc::Code::OK); - auto result = NewValue(duration, name); - *AddProto(&result, "single_duration")->mutable_single_duration() = duration; - return result; -} - -TestValue NewValue(absl::Time value, absl::string_view name) { - google::protobuf::Timestamp timestamp; - auto status = expr::internal::EncodeTime(value, ×tamp); - assert(status.code() == google::rpc::Code::OK); - auto result = NewValue(timestamp, name); - *AddProto(&result, "single_timestamp")->mutable_single_timestamp() = - timestamp; - return result; -} - -testdata::TestValue NewValue(const common::Type& type) { - TestValue result; - result.add_expr(type.ToString()); - - v1beta1::ValueTo(common::Value::FromType(type), result.add_v1beta1()); - return result; -} - -namespace internal { - -void AppendToList(TestValue* list_value) {} - -#define SINGLE_TO_REPEATED(name) \ - if (proto->value_field_name() == "repeated_" #name) { \ - if (auto* value = FindProto(&test_value, "single_" #name)) { \ - proto->add_repeated_##name(value->single_##name()); \ - return true; \ - } \ - } - -#define SINGLE_TO_REPEATED_EQ(name) \ - if (proto->value_field_name() == "repeated_" #name) { \ - if (auto* value = FindProto(&test_value, "single_" #name)) { \ - if (value->has_single_##name()) { \ - *proto->add_repeated_##name() = value->single_##name(); \ - return true; \ - } \ - } \ - } - -bool AppendToProtoList(const TestValue& test_value, TestProtoValue* proto) { - if (proto->value_field_name() == "single_list") { - if (auto* value = FindProto(&test_value, "single_value")) { - *proto->mutable_single_list()->add_values() = value->single_value(); - return true; - } - } - SINGLE_TO_REPEATED(bool); - SINGLE_TO_REPEATED(bytes); - SINGLE_TO_REPEATED(string); // no transform - SINGLE_TO_REPEATED(double); - SINGLE_TO_REPEATED(float); - SINGLE_TO_REPEATED(int32); // no transform - SINGLE_TO_REPEATED(int64); // no transform - SINGLE_TO_REPEATED(uint32); // no transform - SINGLE_TO_REPEATED(uint64); // no transform - SINGLE_TO_REPEATED(sint32); - SINGLE_TO_REPEATED(sint64); - SINGLE_TO_REPEATED(fixed32); - SINGLE_TO_REPEATED(fixed64); - SINGLE_TO_REPEATED(sfixed32); - SINGLE_TO_REPEATED(sfixed64); - SINGLE_TO_REPEATED(nested_enum); - SINGLE_TO_REPEATED(null); - SINGLE_TO_REPEATED_EQ(nested_message); - SINGLE_TO_REPEATED_EQ(duration); - SINGLE_TO_REPEATED_EQ(timestamp); - SINGLE_TO_REPEATED_EQ(any); - SINGLE_TO_REPEATED_EQ(struct); - SINGLE_TO_REPEATED_EQ(value); - SINGLE_TO_REPEATED_EQ(list); - return false; -} - -#undef SINGLE_TO_REPEATED -#undef SINGLE_TO_REPEATED_EQ - -void AppendToList(const TestValue& test_value, TestValue* list_value) { - for (auto& v1beta1 : *list_value->mutable_v1beta1()) { - *v1beta1.mutable_value()->mutable_list_value()->add_values() = - test_value.v1beta1(0).value(); - } - - for (auto& expr : *list_value->mutable_expr()) { - absl::StrAppend(&expr, test_value.expr(0), ", "); - } - - for (auto itr = list_value->mutable_proto()->begin(); - itr != list_value->mutable_proto()->end();) { - if (!AppendToProtoList(test_value, &(*itr))) { - itr = list_value->mutable_proto()->erase(itr); - } else { - ++itr; - } - } -} - -#define SINGLES_TO_MAP(key_name, value_name) \ - if (proto->value_field_name() == "map_" #key_name "_" #value_name) { \ - auto* key_value = FindProto(&key, "single_" #key_name); \ - auto* value_value = FindProto(&value, "single_" #value_name); \ - if (key_value && value_value) { \ - (*proto->mutable_map_##key_name##_##value_name()) \ - [key_value->single_##key_name()] = \ - value_value->single_##value_name(); \ - return true; \ - } \ - } - -bool AddToProtoMap(const TestValue& key, const TestValue& value, - TestProtoValue* proto) { - if (proto->value_field_name() == "single_struct") { - auto* string_key = FindProto(&key, "single_string"); - auto* value_value = FindProto(&value, "single_value"); - if (string_key && value_value) { - (*proto->mutable_single_struct() - ->mutable_fields())[string_key->single_string()] = - value_value->single_value(); - return true; - } - } - - SINGLES_TO_MAP(int64, string); // no transform - SINGLES_TO_MAP(uint64, string); // no transform - SINGLES_TO_MAP(bool, string); // no transform - SINGLES_TO_MAP(string, int64); // no transform - SINGLES_TO_MAP(string, uint64); // no transform - SINGLES_TO_MAP(string, bool); // no transform - SINGLES_TO_MAP(string, string); // no transform - return false; -} - -#undef SINGLES_TO_MAP - -void AddToMap(TestValue* map_value) {} - -void AddToMap(const TestValue& key, const TestValue& value, - TestValue* map_value) { - for (auto& v1beta1 : *map_value->mutable_v1beta1()) { - auto& entry = *v1beta1.mutable_value()->mutable_map_value()->add_entries(); - *entry.mutable_key() = key.v1beta1(0).value(); - *entry.mutable_value() = value.v1beta1(0).value(); - } - - for (auto& expr : *map_value->mutable_expr()) { - absl::StrAppend(&expr, key.expr(0), ": ", value.expr(0), ", "); - } - - for (auto itr = map_value->mutable_proto()->begin(); - itr != map_value->mutable_proto()->end();) { - if (!AddToProtoMap(key, value, &(*itr))) { - itr = map_value->mutable_proto()->erase(itr); - } else { - ++itr; - } - } -} - -void StartList(testdata::TestValue* list_value) { - list_value->add_v1beta1()->mutable_value()->mutable_list_value(); - list_value->add_expr("["); - AddProto(list_value, "single_list")->mutable_single_list(); - InitAll(list_value, "repeated_"); -} - -void EndList(testdata::TestValue* list_value) { - for (auto& expr : *list_value->mutable_expr()) { - if (absl::EndsWith(expr, ", ")) { - expr.erase(expr.size() - 2); - } - absl::StrAppend(&expr, "]"); - } -} - -void StartMap(testdata::TestValue* map_value) { - map_value->add_v1beta1()->mutable_value()->mutable_map_value(); - map_value->add_expr("{"); - AddProto(map_value, "single_struct")->mutable_single_struct(); - InitAll(map_value, "map_"); -} - -void EndMap(testdata::TestValue* map_value) { - for (auto& expr : *map_value->mutable_expr()) { - if (absl::EndsWith(expr, ", ")) { - expr.erase(expr.size() - 2); - } - absl::StrAppend(&expr, "}"); - } -} - -} // namespace internal - -TestValue WithName(TestValue value, absl::string_view name) { - value.set_name(std::string(name)); - return value; -} - -TestValue Merge(const absl::Span& values, - absl::string_view name) { - TestValue result; - for (const auto& value : values) { - if (!value.name().empty()) { - result.set_name(std::string(value.name())); - } - result.mutable_v1beta1()->MergeFrom(value.v1beta1()); - result.mutable_expr()->MergeFrom(value.expr()); - result.mutable_proto()->MergeFrom(value.proto()); - } - if (!name.empty()) { - result.set_name(std::string(name)); - } - return result; -} - -} // namespace testutil -} // namespace expr -} // namespace api -} // namespace google diff --git a/testutil/test_data_util.h b/testutil/test_data_util.h deleted file mode 100644 index 8b6e20155..000000000 --- a/testutil/test_data_util.h +++ /dev/null @@ -1,165 +0,0 @@ -#ifndef THIRD_PARTY_CEL_CPP_TESTUTIL_TEST_DATA_UTIL_H_ -#define THIRD_PARTY_CEL_CPP_TESTUTIL_TEST_DATA_UTIL_H_ - -#include "google/api/expr/v1beta1/eval.pb.h" -#include "google/api/expr/v1beta1/value.pb.h" -#include "google/protobuf/message.h" -#include "google/protobuf/util/message_differencer.h" -#include "absl/strings/match.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "absl/time/time.h" -#include "absl/types/span.h" -#include "common/type.h" -#include "internal/types.h" -#include "testutil/test_data_util.h" -#include "testdata/test_value.pb.h" - -namespace google { -namespace api { -namespace expr { - -namespace v1beta1 { - -/** - * Initializes a message differencer to support value equivelance checks - * for v1beta1::Value. - */ -void InitValueDifferencer(google::protobuf::util::MessageDifferencer* differencer); - -} // namespace v1beta1 - -namespace testutil { - -/** - * Returns the set of all pairs of different test values. - */ -std::vector> AllPairs( - const testdata::TestValues& value_cases); - -/** Converts a nullptr to a TestValue. */ -testdata::TestValue NewValue(std::nullptr_t value); -/** Converts a bool to a TestValue. */ -testdata::TestValue NewValue(bool value); -/** Converts a double to a TestValue. */ -testdata::TestValue NewValue(double value, absl::string_view name = ""); - -/** Converts a int64_t to a TestValue. */ -testdata::TestValue NewValue(int64_t value, absl::string_view name = ""); - -/** Converts a unit64_t to a TestValue. */ -testdata::TestValue NewValue(uint64_t value, absl::string_view name = ""); - -/** Converts a string_view to a TestValue. */ -testdata::TestValue NewValue(absl::string_view value, - absl::string_view name = ""); -/** Converts a string literal to a TestValue. */ -testdata::TestValue NewValue(const char* value, absl::string_view name = ""); -/** Converts a nullptr to a TestValue. */ -testdata::TestValue NewValue(const google::protobuf::Message& value, - absl::string_view name = ""); -/** Converts a duration to a TestValue. */ -testdata::TestValue NewValue(absl::Duration value, absl::string_view name = ""); -/** Converts a time to a TestValue. */ -testdata::TestValue NewValue(absl::Time value, absl::string_view name = ""); - -/** Converts a string_view of bytes to a TestValue. */ -testdata::TestValue NewBytesValue(absl::string_view value, - absl::string_view name = ""); - -testdata::TestValue NewValue(const expr::common::Type& type); - -// Helpers to remove overload ambiguity. -template -expr::internal::specialize_ift, testdata::TestValue> -NewValue(T&& value, absl::string_view name = "") { - return NewValue(static_cast(value), name); -} -template -expr::internal::specialize_ift, testdata::TestValue> -NewValue(T&& value, absl::string_view name = "") { - return NewValue(static_cast(value), name); -} -template -expr::internal::specialize_ift, testdata::TestValue> -NewValue(T&& value, absl::string_view name = "") { - return NewValue(static_cast(value), name); -} - -namespace internal { - -void AppendToList(const testdata::TestValue& test_value, - testdata::TestValue* list_value); - -void AppendToList(testdata::TestValue* list_value); -template -void AppendToList(testdata::TestValue* list_value, T&& next, - Args&&... remaining) { - AppendToList(NewValue(std::forward(next)), list_value); - AppendToList(list_value, std::forward(remaining)...); -} - -void AddToMap(const testdata::TestValue& key, const testdata::TestValue& value, - testdata::TestValue* map_value); - -void AddToMap(testdata::TestValue* map_value); - -template -void AddToMap(testdata::TestValue* map_value, K&& next_key, V&& next_value, - Args&&... remaining) { - AddToMap(NewValue(std::forward(next_key)), - NewValue(std::forward(next_value)), map_value); - AddToMap(map_value, std::forward(remaining)...); -} - -void StartList(testdata::TestValue* list_value); -void EndList(testdata::TestValue* list_value); -void StartMap(testdata::TestValue* map_value); -void EndMap(testdata::TestValue* map_value); - -} // namespace internal - -/** - * Converts arguments into a list TestValue. For example: - * - * TestValue list = NewListValue("elem1", "elem2", 3, 4.0, ...); - */ -template -testdata::TestValue NewListValue(Args&&... values) { - testdata::TestValue value; - value.set_name("list"); - - internal::StartList(&value); - internal::AppendToList(&value, std::forward(values)...); - internal::EndList(&value); - return value; -} - -/** - * Converts arguments into a map TestValue. For example: - * - * TestValue map = NewMapValue("key1", "value1", 2, 2.0, ...); - */ -template -testdata::TestValue NewMapValue(Args&&... values) { - testdata::TestValue value; - value.set_name("map"); - internal::StartMap(&value); - internal::AddToMap(&value, std::forward(values)...); - internal::EndMap(&value); - return value; -} - -/** Creates a new test value with the given name. */ -testdata::TestValue WithName(testdata::TestValue value, absl::string_view name); - -/** Merges to equivalent test values into a single value. */ -testdata::TestValue Merge(const absl::Span& values, - absl::string_view name = ""); - -} // namespace testutil -} // namespace expr -} // namespace api -} // namespace google - -#endif // THIRD_PARTY_CEL_CPP_TESTUTIL_TEST_DATA_UTIL_H_ diff --git a/testutil/test_json_names.proto b/testutil/test_json_names.proto new file mode 100644 index 000000000..a9551085b --- /dev/null +++ b/testutil/test_json_names.proto @@ -0,0 +1,31 @@ +edition = "2024"; + +package cel.cpp.testutil; + +option features.enforce_naming_style = STYLE_LEGACY; + +// This proto tests json_name options +message TestJsonNames { + int32 int32_snake_case_json_name = 1 + [json_name = "int32_snake_case_json_name"]; + int64 int64_camel_case_json_name = 2 [json_name = "int64CamelCaseJsonName"]; + uint32 uint32_default_json_name = 3; + uint64 uint64_custom_json_name = 4 [json_name = "uint64-custom-json-name"]; + + // Collides with normal field name. + string string_json_name_shadows = 5 [json_name = "single_string"]; + string single_string = 6; + + // protoc should fail on cases like these + // double double_json_shadow_default = 7 [json_name = "doubleJsonDefault"] + // double double_json_default = 8; + // double double_json_swapped_a = 7 [json_name = "double_json_swapped_b"]; + // double double_json_swapped_b = 8 [json_name = "double_json_swapped_a"]; + + extensions 100 to 199; +} + +extend TestJsonNames { + int32 int32_snake_case_ext = 100; + int64 int64CamelCaseExt = 101; +} diff --git a/testutil/test_macros.cc b/testutil/test_macros.cc new file mode 100644 index 000000000..19e9a4844 --- /dev/null +++ b/testutil/test_macros.cc @@ -0,0 +1,173 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "testutil/test_macros.h" + +#include +#include + +#include "absl/base/no_destructor.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "common/expr.h" +#include "internal/status_macros.h" +#include "parser/macro.h" +#include "parser/macro_expr_factory.h" +#include "parser/macro_registry.h" + +namespace cel::test { + +namespace { + +bool IsCelNamespace(const Expr& target) { + return target.has_ident_expr() && target.ident_expr().name() == "cel"; +} + +std::optional CelBlockMacroExpander(MacroExprFactory& factory, + Expr& target, absl::Span args) { + if (!IsCelNamespace(target)) { + return std::nullopt; + } + Expr& bindings_arg = args[0]; + if (!bindings_arg.has_list_expr()) { + return factory.ReportErrorAt( + bindings_arg, "cel.block requires the first arg to be a list literal"); + } + return factory.NewCall("cel.@block", args); +} + +std::optional CelIndexMacroExpander(MacroExprFactory& factory, + Expr& target, absl::Span args) { + if (!IsCelNamespace(target)) { + return std::nullopt; + } + Expr& index_arg = args[0]; + if (!index_arg.has_const_expr() || !index_arg.const_expr().has_int_value()) { + return factory.ReportErrorAt( + index_arg, "cel.index requires a single non-negative int constant arg"); + } + int64_t index = index_arg.const_expr().int_value(); + if (index < 0) { + return factory.ReportErrorAt( + index_arg, "cel.index requires a single non-negative int constant arg"); + } + return factory.NewIdent(absl::StrCat("@index", index)); +} + +std::optional CelIterVarMacroExpander(MacroExprFactory& factory, + Expr& target, + absl::Span args) { + if (!IsCelNamespace(target)) { + return std::nullopt; + } + Expr& depth_arg = args[0]; + if (!depth_arg.has_const_expr() || !depth_arg.const_expr().has_int_value() || + depth_arg.const_expr().int_value() < 0) { + return factory.ReportErrorAt( + depth_arg, "cel.iterVar requires two non-negative int constant args"); + } + Expr& unique_arg = args[1]; + if (!unique_arg.has_const_expr() || + !unique_arg.const_expr().has_int_value() || + unique_arg.const_expr().int_value() < 0) { + return factory.ReportErrorAt( + unique_arg, "cel.iterVar requires two non-negative int constant args"); + } + return factory.NewIdent( + absl::StrCat("@it:", depth_arg.const_expr().int_value(), ":", + unique_arg.const_expr().int_value())); +} + +std::optional CelAccuVarMacroExpander(MacroExprFactory& factory, + Expr& target, + absl::Span args) { + if (!IsCelNamespace(target)) { + return std::nullopt; + } + Expr& depth_arg = args[0]; + if (!depth_arg.has_const_expr() || !depth_arg.const_expr().has_int_value() || + depth_arg.const_expr().int_value() < 0) { + return factory.ReportErrorAt( + depth_arg, "cel.accuVar requires two non-negative int constant args"); + } + Expr& unique_arg = args[1]; + if (!unique_arg.has_const_expr() || + !unique_arg.const_expr().has_int_value() || + unique_arg.const_expr().int_value() < 0) { + return factory.ReportErrorAt( + unique_arg, "cel.accuVar requires two non-negative int constant args"); + } + return factory.NewIdent( + absl::StrCat("@ac:", depth_arg.const_expr().int_value(), ":", + unique_arg.const_expr().int_value())); +} + +Macro MakeCelBlockMacro() { + auto macro_or_status = Macro::Receiver("block", 2, CelBlockMacroExpander); + ABSL_CHECK_OK(macro_or_status); // Crash OK + return std::move(*macro_or_status); +} + +Macro MakeCelIndexMacro() { + auto macro_or_status = Macro::Receiver("index", 1, CelIndexMacroExpander); + ABSL_CHECK_OK(macro_or_status); // Crash OK + return std::move(*macro_or_status); +} + +Macro MakeCelIterVarMacro() { + auto macro_or_status = Macro::Receiver("iterVar", 2, CelIterVarMacroExpander); + ABSL_CHECK_OK(macro_or_status); // Crash OK + return std::move(*macro_or_status); +} + +Macro MakeCelAccuVarMacro() { + auto macro_or_status = Macro::Receiver("accuVar", 2, CelAccuVarMacroExpander); + ABSL_CHECK_OK(macro_or_status); // Crash OK + return std::move(*macro_or_status); +} + +} // namespace + +const Macro& CelBlockMacro() { + static const absl::NoDestructor macro(MakeCelBlockMacro()); + return *macro; +} + +const Macro& CelIndexMacro() { + static const absl::NoDestructor macro(MakeCelIndexMacro()); + return *macro; +} + +const Macro& CelIterVarMacro() { + static const absl::NoDestructor macro(MakeCelIterVarMacro()); + return *macro; +} + +const Macro& CelAccuVarMacro() { + static const absl::NoDestructor macro(MakeCelAccuVarMacro()); + return *macro; +} + +absl::Status RegisterTestMacros(MacroRegistry& registry) { + CEL_RETURN_IF_ERROR(registry.RegisterMacro(CelBlockMacro())); + CEL_RETURN_IF_ERROR(registry.RegisterMacro(CelIndexMacro())); + CEL_RETURN_IF_ERROR(registry.RegisterMacro(CelIterVarMacro())); + CEL_RETURN_IF_ERROR(registry.RegisterMacro(CelAccuVarMacro())); + return absl::OkStatus(); +} + +} // namespace cel::test diff --git a/testutil/test_macros.h b/testutil/test_macros.h new file mode 100644 index 000000000..cad897999 --- /dev/null +++ b/testutil/test_macros.h @@ -0,0 +1,33 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_TESTUTIL_TEST_MACROS_H_ +#define THIRD_PARTY_CEL_CPP_TESTUTIL_TEST_MACROS_H_ + +#include "absl/status/status.h" +#include "parser/macro.h" +#include "parser/macro_registry.h" + +namespace cel::test { + +const Macro& CelBlockMacro(); +const Macro& CelIndexMacro(); +const Macro& CelIterVarMacro(); +const Macro& CelAccuVarMacro(); + +absl::Status RegisterTestMacros(MacroRegistry& registry); + +} // namespace cel::test + +#endif // THIRD_PARTY_CEL_CPP_TESTUTIL_TEST_MACROS_H_ diff --git a/testutil/util.h b/testutil/util.h index 7eb62ea85..26c47ebe4 100644 --- a/testutil/util.h +++ b/testutil/util.h @@ -1,102 +1,28 @@ -#ifndef THIRD_PARTY_CEL_CPP_TESTUTIL_EXPECT_SAME_TYPE_H_ -#define THIRD_PARTY_CEL_CPP_TESTUTIL_EXPECT_SAME_TYPE_H_ - -#include - -#include "google/protobuf/message.h" -#include "google/protobuf/text_format.h" -#include "gmock/gmock.h" - -namespace google { -namespace api { -namespace expr { -namespace testutil { - -// A helper class that causes the compiler to print a helpful error when -// they template args don't match. -template -struct ExpectSameType; - -template -struct ExpectSameType {}; - -// Creates a proto message of type T from a textual representation. -template -T CreateProto(const std::string& textual_proto); - -/** - * Simple implementation of a proto matcher comparing string representations. - * - * IMPORTANT: Only use this for protos whose textual representation is - * deterministic (that may not be the case for the map collection type). - */ -class ProtoStringMatcher { - public: - explicit inline ProtoStringMatcher(const std::string& expected) - : expected_(expected) {} - - explicit inline ProtoStringMatcher(const google::protobuf::Message& expected) - : expected_(expected.DebugString()) {} - - template - bool MatchAndExplain(const Message& p, - ::testing::MatchResultListener* /* listener */) const; - - template - bool MatchAndExplain(const Message* p, - ::testing::MatchResultListener* /* listener */) const; - - inline void DescribeTo(::std::ostream* os) const { *os << expected_; } - inline void DescribeNegationTo(::std::ostream* os) const { - *os << "not equal to expected message: " << expected_; - } - - private: - const std::string expected_; -}; - -// Polymorphic matcher to compare any two protos. -inline ::testing::PolymorphicMatcher EqualsProto( - const std::string& x) { - return ::testing::MakePolymorphicMatcher(ProtoStringMatcher(x)); -} - -// Polymorphic matcher to compare any two protos. -inline ::testing::PolymorphicMatcher EqualsProto( - const google::protobuf::Message& x) { - return ::testing::MakePolymorphicMatcher(ProtoStringMatcher(x)); -} - -template -T CreateProto(const std::string& textual_proto) { - T proto; - google::protobuf::TextFormat::ParseFromString(textual_proto, &proto); - return proto; -} - -template -bool ProtoStringMatcher::MatchAndExplain( - const Message& p, ::testing::MatchResultListener* /* listener */) const { - // Need to CreateProto and then print as std::string so that the formatting - // matches exactly. - return p.SerializeAsString() == - CreateProto(expected_).SerializeAsString(); -} - -template -bool ProtoStringMatcher::MatchAndExplain( - const Message* p, ::testing::MatchResultListener* /* listener */) const { - // Need to CreateProto and then print as std::string so that the formatting - // matches exactly. - std::unique_ptr value; - value.reset(p->New()); - google::protobuf::TextFormat::ParseFromString(expected_, value.get()); - return p->SerializeAsString() == value->SerializeAsString(); -} - -} // namespace testutil -} // namespace expr -} // namespace api -} // namespace google - -#endif // THIRD_PARTY_CEL_CPP_TESTUTIL_EXPECT_SAME_TYPE_H_ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_TESTUTIL_UTIL_H_ +#define THIRD_PARTY_CEL_CPP_TESTUTIL_UTIL_H_ + +#include "internal/proto_matchers.h" + +namespace google::api::expr::testutil { + +// alias for old namespace +// prefer using cel::internal::test::EqualsProto. +using ::cel::internal::test::EqualsProto; + +} // namespace google::api::expr::testutil + +#endif // THIRD_PARTY_CEL_CPP_TESTUTIL_UTIL_H_ diff --git a/tools/BUILD b/tools/BUILD index 7729b01b5..af006a67b 100644 --- a/tools/BUILD +++ b/tools/BUILD @@ -1,41 +1,95 @@ -load( - "@com_github_google_flatbuffers//:build_defs.bzl", - "flatbuffer_library_public", -) +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") package(default_visibility = ["//visibility:public"]) -licenses(["notice"]) # Apache 2.0 +licenses(["notice"]) cc_library( - name = "flatbuffers_backed_impl", + name = "cel_field_extractor", + srcs = ["cel_field_extractor.cc"], + hdrs = ["cel_field_extractor.h"], + deps = [ + ":navigable_ast", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + ], +) + +cc_test( + name = "cel_field_extractor_test", + srcs = ["cel_field_extractor_test.cc"], + deps = [ + ":cel_field_extractor", + "//internal:testing", + "//parser", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status:statusor", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + ], +) + +cc_library( + name = "cel_unparser", srcs = [ - "flatbuffers_backed_impl.cc", + "cel_unparser.cc", ], hdrs = [ - "flatbuffers_backed_impl.h", + "cel_unparser.h", ], deps = [ - "//eval/public:cel_value", - "@com_github_google_flatbuffers//:flatbuffers", + "//common:operators", + "//internal:status_macros", + "//internal:strings", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_protobuf//:duration_cc_proto", + "@com_google_protobuf//:timestamp_cc_proto", + "@com_googlesource_code_re2//:re2", ], ) -flatbuffer_library_public( - name = "flatbuffers_test", - srcs = ["testdata/flatbuffers.fbs"], - outs = ["testdata/flatbuffers_generated.h"], - language_flag = "-c", - reflection_name = "flatbuffers_reflection", +cc_test( + name = "cel_unparser_test", + srcs = ["cel_unparser_test.cc"], + deps = [ + ":cel_unparser", + "//internal:proto_matchers", + "//internal:testing", + "//parser", + "//parser:options", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_protobuf//:protobuf", + ], ) cc_library( - name = "flatbuffers_test_cc", - srcs = [":flatbuffers_test"], - hdrs = [":flatbuffers_test"], - features = ["-parse_headers"], - linkstatic = True, - deps = ["@com_github_google_flatbuffers//:runtime_cc"], + name = "flatbuffers_backed_impl", + srcs = [ + "flatbuffers_backed_impl.cc", + ], + hdrs = [ + "flatbuffers_backed_impl.h", + ], + deps = [ + "//eval/public:cel_value", + "@com_github_google_flatbuffers//:flatbuffers", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + ], ) cc_test( @@ -45,12 +99,170 @@ cc_test( "flatbuffers_backed_impl_test.cc", ], data = [ - ":flatbuffers_reflection_out", + "//tools/testdata:flatbuffers_reflection_out", ], deps = [ ":flatbuffers_backed_impl", - ":flatbuffers_test_cc", + "//internal:status_macros", + "//internal:testing", "@com_github_google_flatbuffers//:flatbuffers", - "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "navigable_ast", + srcs = ["navigable_ast.cc"], + hdrs = ["navigable_ast.h"], + deps = [ + "//common/ast:navigable_ast_internal", + "//eval/public:ast_traverse", + "//eval/public:ast_visitor", + "//eval/public:ast_visitor_base", + "//eval/public:source_position", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/memory", + "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + ], +) + +cc_test( + name = "navigable_ast_test", + srcs = ["navigable_ast_test.cc"], + deps = [ + ":navigable_ast", + "//base:builtins", + "//internal:testing", + "//parser", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + ], +) + +cc_library( + name = "branch_coverage", + srcs = ["branch_coverage.cc"], + hdrs = ["branch_coverage.h"], + deps = [ + ":navigable_ast", + "//common:value", + "//eval/internal:interop", + "//eval/public:cel_value", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/functional:overload", + "@com_google_absl//absl/status", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:variant", + "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "branch_coverage_test", + srcs = ["branch_coverage_test.cc"], + data = [ + "//tools/testdata:coverage_testdata", + ], + deps = [ + ":branch_coverage", + ":navigable_ast", + "//base:builtins", + "//common:value", + "//eval/public:activation", + "//eval/public:builtin_func_registrar", + "//eval/public:cel_expr_builder_factory", + "//eval/public:cel_expression", + "//eval/public:cel_value", + "//internal:proto_file_util", + "//internal:testing", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "descriptor_pool_builder", + srcs = ["descriptor_pool_builder.cc"], + hdrs = ["descriptor_pool_builder.h"], + deps = [ + "//common:minimal_descriptor_database", + "//internal:status_macros", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "proto_to_predicate", + srcs = ["proto_to_predicate.cc"], + hdrs = ["proto_to_predicate.h"], + deps = [ + "//common:ast", + "//common:expr", + "//common:expr_factory", + "//common:operators", + "//internal:status_macros", + "@com_google_absl//absl/log:absl_log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "proto_to_predicate_test", + srcs = ["proto_to_predicate_test.cc"], + deps = [ + ":cel_unparser", + ":proto_to_predicate", + "//common:ast", + "//common:ast_proto", + "//common:value", + "//env:config", + "//env:env_runtime", + "//env:env_yaml", + "//env:runtime_std_extensions", + "//eval/testutil:test_message_cc_proto", + "//extensions/protobuf:value", + "//internal:status_macros", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//parser", + "//runtime", + "//runtime:activation", + "//tools/testdata:test_policy_cc_proto", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "descriptor_pool_builder_test", + srcs = ["descriptor_pool_builder_test.cc"], + deps = [ + ":descriptor_pool_builder", + "//internal:testing", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_cel_spec//proto/cel/expr/conformance/proto2:test_all_types_cc_proto", + "@com_google_protobuf//:protobuf", ], ) diff --git a/tools/branch_coverage.cc b/tools/branch_coverage.cc new file mode 100644 index 000000000..b5bba3ffe --- /dev/null +++ b/tools/branch_coverage.cc @@ -0,0 +1,253 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tools/branch_coverage.h" + +#include +#include + +#include "cel/expr/checked.pb.h" +#include "absl/base/no_destructor.h" +#include "absl/base/nullability.h" +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/functional/overload.h" +#include "absl/status/status.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/variant.h" +#include "common/value.h" +#include "eval/internal/interop.h" +#include "eval/public/cel_value.h" +#include "tools/navigable_ast.h" +#include "google/protobuf/arena.h" + +namespace cel { +namespace { + +using ::cel::expr::CheckedExpr; +using ::cel::expr::Type; +using ::google::api::expr::runtime::CelValue; + +const absl::Status& UnsupportedConversionError() { + static absl::NoDestructor kErr( + absl::StatusCode::kInternal, "Conversion to legacy type unsupported."); + + return *kErr; +} + +// Constant literal. +// +// These should be handled separately from variable parts of the AST to not +// inflate / deflate coverage wrt variable inputs. +struct ConstantNode {}; + +// A boolean node. +// +// Branching in CEL is mostly determined by boolean subexpression results, so +// specify intercepted values. +struct BoolNode { + int result_true; + int result_false; + int result_error; +}; + +// Catch all for other nodes. +struct OtherNode { + int result_error; +}; + +// Representation for coverage of an AST node. +struct CoverageNode { + int evaluate_count; + std::variant kind; +}; + +const Type* absl_nullable FindCheckerType(const CheckedExpr& expr, + int64_t expr_id) { + if (auto it = expr.type_map().find(expr_id); it != expr.type_map().end()) { + return &it->second; + } + return nullptr; +} + +class BranchCoverageImpl : public BranchCoverage { + public: + explicit BranchCoverageImpl(const CheckedExpr& expr) : expr_(expr) {} + + // Implement public interface. + void Record(int64_t expr_id, const Value& value) override { + auto value_or = interop_internal::ToLegacyValue(&arena_, value); + + if (!value_or.ok()) { + // TODO(uncreated-issue/65): Use pointer identity for UnsupportedConversionError + // as a sentinel value. The legacy CEL value just wraps the error pointer. + // This can be removed after the value migration is complete. + RecordImpl(expr_id, CelValue::CreateError(&UnsupportedConversionError())); + } else { + return RecordImpl(expr_id, *value_or); + } + } + + void RecordLegacyValue(int64_t expr_id, const CelValue& value) override { + return RecordImpl(expr_id, value); + } + + BranchCoverage::NodeCoverageStats StatsForNode( + int64_t expr_id) const override; + + const NavigableProtoAst& ast() const override; + const CheckedExpr& expr() const override; + + // Initializes the coverage implementation. This should be called by the + // factory function (synchronously). + // + // Other mutation operations must be synchronized since we don't have control + // of when the instrumented expressions get called. + void Init(); + + private: + friend class BranchCoverage; + + void RecordImpl(int64_t expr_id, const CelValue& value); + + // Infer it the node is boolean typed. Check the type map if available. + // Otherwise infer typing based on built-in functions. + bool InferredBoolType(const NavigableProtoAstNode& node) const; + + CheckedExpr expr_; + NavigableProtoAst ast_; + mutable absl::Mutex coverage_nodes_mu_; + absl::flat_hash_map coverage_nodes_ + ABSL_GUARDED_BY(coverage_nodes_mu_); + absl::flat_hash_set unexpected_expr_ids_ + ABSL_GUARDED_BY(coverage_nodes_mu_); + google::protobuf::Arena arena_; +}; + +BranchCoverage::NodeCoverageStats BranchCoverageImpl::StatsForNode( + int64_t expr_id) const { + BranchCoverage::NodeCoverageStats stats{ + /*is_boolean=*/false, + /*evaluation_count=*/0, + /*error_count=*/0, + /*boolean_true_count=*/0, + /*boolean_false_count=*/0, + }; + + absl::MutexLock lock(coverage_nodes_mu_); + auto it = coverage_nodes_.find(expr_id); + if (it != coverage_nodes_.end()) { + const CoverageNode& coverage_node = it->second; + stats.evaluation_count = coverage_node.evaluate_count; + absl::visit(absl::Overload([&](const ConstantNode& cov) {}, + [&](const OtherNode& cov) { + stats.error_count = cov.result_error; + }, + [&](const BoolNode& cov) { + stats.is_boolean = true; + stats.boolean_true_count = cov.result_true; + stats.boolean_false_count = cov.result_false; + stats.error_count = cov.result_error; + }), + coverage_node.kind); + return stats; + } + return stats; +} + +const NavigableProtoAst& BranchCoverageImpl::ast() const { return ast_; } + +const CheckedExpr& BranchCoverageImpl::expr() const { return expr_; } + +bool BranchCoverageImpl::InferredBoolType( + const NavigableProtoAstNode& node) const { + int64_t expr_id = node.expr()->id(); + const auto* checker_type = FindCheckerType(expr_, expr_id); + if (checker_type != nullptr) { + return checker_type->has_primitive() && + checker_type->primitive() == Type::BOOL; + } + + return false; +} + +void BranchCoverageImpl::Init() ABSL_NO_THREAD_SAFETY_ANALYSIS { + ast_ = NavigableProtoAst::Build(expr_.expr()); + for (const NavigableProtoAstNode& node : ast_.Root().DescendantsPreorder()) { + int64_t expr_id = node.expr()->id(); + + CoverageNode& coverage_node = coverage_nodes_[expr_id]; + coverage_node.evaluate_count = 0; + if (node.node_kind() == NodeKind::kConstant) { + coverage_node.kind = ConstantNode{}; + } else if (InferredBoolType(node)) { + coverage_node.kind = BoolNode{0, 0, 0}; + } else { + coverage_node.kind = OtherNode{0}; + } + } +} + +void BranchCoverageImpl::RecordImpl(int64_t expr_id, const CelValue& value) { + absl::MutexLock lock(coverage_nodes_mu_); + auto it = coverage_nodes_.find(expr_id); + if (it == coverage_nodes_.end()) { + unexpected_expr_ids_.insert(expr_id); + it = coverage_nodes_.insert({expr_id, CoverageNode{0, {}}}).first; + if (value.IsBool()) { + it->second.kind = BoolNode{0, 0, 0}; + } + } + + CoverageNode& coverage_node = it->second; + coverage_node.evaluate_count++; + bool is_error = value.IsError() && + // Filter conversion errors for evaluator internal types. + // TODO(uncreated-issue/65): RecordImpl operates on legacy values so + // special case conversion errors. This error is really just a + // sentinel value and doesn't need to round-trip between + // legacy and legacy types. + value.ErrorOrDie() != &UnsupportedConversionError(); + + absl::visit(absl::Overload([&](ConstantNode& node) {}, + [&](OtherNode& cov) { + if (is_error) { + cov.result_error++; + } + }, + [&](BoolNode& cov) { + if (value.IsBool()) { + bool held_value = value.BoolOrDie(); + if (held_value) { + cov.result_true++; + } else { + cov.result_false++; + } + } else if (is_error) { + cov.result_error++; + } + }), + coverage_node.kind); +} + +} // namespace + +std::unique_ptr CreateBranchCoverage(const CheckedExpr& expr) { + auto result = std::make_unique(expr); + result->Init(); + return result; +} + +} // namespace cel diff --git a/tools/branch_coverage.h b/tools/branch_coverage.h new file mode 100644 index 000000000..128faefed --- /dev/null +++ b/tools/branch_coverage.h @@ -0,0 +1,68 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_TOOLS_BRANCH_COVERAGE_H_ +#define THIRD_PARTY_CEL_CPP_TOOLS_BRANCH_COVERAGE_H_ + +#include +#include + +#include "cel/expr/checked.pb.h" +#include "absl/base/attributes.h" +#include "common/value.h" +#include "eval/public/cel_value.h" +#include "tools/navigable_ast.h" + +namespace cel { + +// Interface for BranchCoverage collection utility. +// +// This provides a factory for instrumentation that collects coverage +// information over multiple executions of a CEL expression. This does not +// provide any mechanism for de-duplicating multiple CheckedExpr instances +// that represent the same expression within or across processes. +// +// The default implementation is thread safe. +// +// TODO(uncreated-issue/65): add support for interesting aggregate stats. +class BranchCoverage { + public: + struct NodeCoverageStats { + bool is_boolean; + int evaluation_count; + int boolean_true_count; + int boolean_false_count; + int error_count; + }; + + virtual ~BranchCoverage() = default; + + virtual void Record(int64_t expr_id, const Value& value) = 0; + virtual void RecordLegacyValue( + int64_t expr_id, const google::api::expr::runtime::CelValue& value) = 0; + + virtual NodeCoverageStats StatsForNode(int64_t expr_id) const = 0; + + virtual const NavigableProtoAst& ast() const + ABSL_ATTRIBUTE_LIFETIME_BOUND = 0; + virtual const cel::expr::CheckedExpr& expr() const + ABSL_ATTRIBUTE_LIFETIME_BOUND = 0; +}; + +std::unique_ptr CreateBranchCoverage( + const cel::expr::CheckedExpr& expr); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_TOOLS_BRANCH_COVERAGE_H_ diff --git a/tools/branch_coverage_test.cc b/tools/branch_coverage_test.cc new file mode 100644 index 000000000..3a7a1c0a2 --- /dev/null +++ b/tools/branch_coverage_test.cc @@ -0,0 +1,418 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tools/branch_coverage.h" + +#include +#include + +#include "absl/base/no_destructor.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/strings/substitute.h" +#include "base/builtins.h" +#include "common/value.h" +#include "eval/public/activation.h" +#include "eval/public/builtin_func_registrar.h" +#include "eval/public/cel_expr_builder_factory.h" +#include "eval/public/cel_expression.h" +#include "eval/public/cel_value.h" +#include "internal/proto_file_util.h" +#include "internal/testing.h" +#include "tools/navigable_ast.h" +#include "google/protobuf/arena.h" + +namespace cel { +namespace { + +using ::cel::internal::test::ReadTextProtoFromFile; +using ::cel::expr::CheckedExpr; +using ::google::api::expr::runtime::Activation; +using ::google::api::expr::runtime::CelValue; +using ::google::api::expr::runtime::CreateCelExpressionBuilder; +using ::google::api::expr::runtime::RegisterBuiltinFunctions; + +// int1 < int2 && +// (43 > 42) && +// !(bool1 || bool2) && +// 4 / int_divisor >= 1 && +// (ternary_c ? ternary_t : ternary_f) +constexpr char kCoverageExamplePath[] = + "tools/testdata/coverage_example.textproto"; + +const CheckedExpr& TestExpression() { + static absl::NoDestructor expression([]() { + CheckedExpr value; + ABSL_CHECK_OK(ReadTextProtoFromFile(kCoverageExamplePath, value)); + return value; + }()); + return *expression; +} + +std::string FormatNodeStats(const BranchCoverage::NodeCoverageStats& stats) { + return absl::Substitute( + "is_bool: $0; evaluated: $1; bool_true: $2; bool_false: $3; error: $4", + stats.is_boolean, stats.evaluation_count, stats.boolean_true_count, + stats.boolean_false_count, stats.error_count); +} + +google::api::expr::runtime::CelEvaluationListener EvaluationListenerForCoverage( + BranchCoverage* coverage) { + return [coverage](int64_t id, const CelValue& value, google::protobuf::Arena* arena) { + coverage->RecordLegacyValue(id, value); + return absl::OkStatus(); + }; +} + +MATCHER_P(MatchesNodeStats, expected, "") { + const BranchCoverage::NodeCoverageStats& actual = arg; + + *result_listener << "\n"; + *result_listener << "Expected: " << FormatNodeStats(expected); + *result_listener << "\n"; + *result_listener << "Got: " << FormatNodeStats(actual); + + return actual.is_boolean == expected.is_boolean && + actual.evaluation_count == expected.evaluation_count && + actual.boolean_true_count == expected.boolean_true_count && + actual.boolean_false_count == expected.boolean_false_count && + actual.error_count == expected.error_count; +} + +MATCHER(NodeStatsIsBool, "") { + const BranchCoverage::NodeCoverageStats& actual = arg; + + *result_listener << "\n"; + *result_listener << "Expected: " << FormatNodeStats({true, 0, 0, 0, 0}); + *result_listener << "\n"; + *result_listener << "Got: " << FormatNodeStats(actual); + + return actual.is_boolean == true; +} + +TEST(BranchCoverage, DefaultsForUntrackedId) { + auto coverage = CreateBranchCoverage(TestExpression()); + + using Stats = BranchCoverage::NodeCoverageStats; + + EXPECT_THAT(coverage->StatsForNode(99), + MatchesNodeStats(Stats{/*is_boolean=*/false, + /*evaluation_count=*/0, + /*boolean_true_count=*/0, + /*boolean_false_count=*/0, + /*error_count=*/0})); +} + +TEST(BranchCoverage, Record) { + auto coverage = CreateBranchCoverage(TestExpression()); + + int64_t root_id = coverage->expr().expr().id(); + + coverage->Record(root_id, cel::BoolValue(false)); + + using Stats = BranchCoverage::NodeCoverageStats; + + EXPECT_THAT(coverage->StatsForNode(root_id), + MatchesNodeStats(Stats{/*is_boolean=*/true, + /*evaluation_count=*/1, + /*boolean_true_count=*/0, + /*boolean_false_count=*/1, + /*error_count=*/0})); +} + +TEST(BranchCoverage, RecordUnexpectedId) { + auto coverage = CreateBranchCoverage(TestExpression()); + + int64_t unexpected_id = 99; + + coverage->Record(unexpected_id, cel::BoolValue(false)); + + using Stats = BranchCoverage::NodeCoverageStats; + + EXPECT_THAT(coverage->StatsForNode(unexpected_id), + MatchesNodeStats(Stats{/*is_boolean=*/true, + /*evaluation_count=*/1, + /*boolean_true_count=*/0, + /*boolean_false_count=*/1, + /*error_count=*/0})); +} + +TEST(BranchCoverage, IncrementsCounters) { + auto coverage = CreateBranchCoverage(TestExpression()); + + EXPECT_TRUE(static_cast(coverage->ast())); + + auto builder = CreateCelExpressionBuilder(); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + + // int1 < int2 && + // (43 > 42) && + // !(bool1 || bool2) && + // 4 / int_divisor >= 1 && + // (ternary_c ? ternary_t : ternary_f) + ASSERT_OK_AND_ASSIGN(auto program, + builder->CreateExpression(&TestExpression())); + + google::protobuf::Arena arena; + Activation activation; + activation.InsertValue("bool1", CelValue::CreateBool(false)); + activation.InsertValue("bool2", CelValue::CreateBool(false)); + + activation.InsertValue("int1", CelValue::CreateInt64(42)); + activation.InsertValue("int2", CelValue::CreateInt64(43)); + + activation.InsertValue("int_divisor", CelValue::CreateInt64(4)); + + activation.InsertValue("ternary_c", CelValue::CreateBool(true)); + activation.InsertValue("ternary_t", CelValue::CreateBool(true)); + activation.InsertValue("ternary_f", CelValue::CreateBool(false)); + + ASSERT_OK_AND_ASSIGN( + auto result, + program->Trace(activation, &arena, + EvaluationListenerForCoverage(coverage.get()))); + + EXPECT_TRUE(result.IsBool() && result.BoolOrDie() == true); + + using Stats = BranchCoverage::NodeCoverageStats; + const NavigableProtoAst& ast = coverage->ast(); + auto root_node_stats = coverage->StatsForNode(ast.Root().expr()->id()); + + EXPECT_THAT(root_node_stats, MatchesNodeStats(Stats{/*is_boolean=*/true, + /*evaluation_count=*/1, + /*boolean_true_count=*/1, + /*boolean_false_count=*/0, + /*error_count=*/0})); + + const NavigableProtoAstNode* ternary; + for (const auto& node : ast.Root().DescendantsPreorder()) { + if (node.node_kind() == NodeKind::kCall && + node.expr()->call_expr().function() == cel::builtin::kTernary) { + ternary = &node; + break; + } + } + + ASSERT_NE(ternary, nullptr); + auto ternary_node_stats = coverage->StatsForNode(ternary->expr()->id()); + // Ternary gets optimized to conditional jumps, so it isn't instrumented + // directly in stack machine impl. + EXPECT_THAT(ternary_node_stats, NodeStatsIsBool()); + + const auto* false_node = ternary->children().at(2); + auto false_node_stats = coverage->StatsForNode(false_node->expr()->id()); + EXPECT_THAT(false_node_stats, + MatchesNodeStats(Stats{/*is_boolean=*/true, + /*evaluation_count=*/0, + /*boolean_true_count=*/0, + /*boolean_false_count=*/0, + /*error_count=*/0})); + + const NavigableProtoAstNode* not_arg_expr; + for (const auto& node : ast.Root().DescendantsPreorder()) { + if (node.node_kind() == NodeKind::kCall && + node.expr()->call_expr().function() == cel::builtin::kNot) { + not_arg_expr = node.children().at(0); + break; + } + } + + ASSERT_NE(not_arg_expr, nullptr); + auto not_expr_node_stats = coverage->StatsForNode(not_arg_expr->expr()->id()); + EXPECT_THAT(not_expr_node_stats, + MatchesNodeStats(Stats{/*is_boolean=*/true, + /*evaluation_count=*/1, + /*boolean_true_count=*/0, + /*boolean_false_count=*/1, + /*error_count=*/0})); + + const NavigableProtoAstNode* div_expr; + for (const auto& node : ast.Root().DescendantsPreorder()) { + if (node.node_kind() == NodeKind::kCall && + node.expr()->call_expr().function() == cel::builtin::kDivide) { + div_expr = &node; + break; + } + } + + ASSERT_NE(div_expr, nullptr); + auto div_expr_stats = coverage->StatsForNode(div_expr->expr()->id()); + EXPECT_THAT(div_expr_stats, MatchesNodeStats(Stats{/*is_boolean=*/false, + /*evaluation_count=*/1, + /*boolean_true_count=*/0, + /*boolean_false_count=*/0, + /*error_count=*/0})); +} + +TEST(BranchCoverage, AccumulatesAcrossRuns) { + auto coverage = CreateBranchCoverage(TestExpression()); + + EXPECT_TRUE(static_cast(coverage->ast())); + + auto builder = CreateCelExpressionBuilder(); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + + // int1 < int2 && + // (43 > 42) && + // !(bool1 || bool2) && + // 4 / int_divisor >= 1 && + // (ternary_c ? ternary_t : ternary_f) + ASSERT_OK_AND_ASSIGN(auto program, + builder->CreateExpression(&TestExpression())); + + google::protobuf::Arena arena; + Activation activation; + activation.InsertValue("bool1", CelValue::CreateBool(false)); + activation.InsertValue("bool2", CelValue::CreateBool(false)); + + activation.InsertValue("int1", CelValue::CreateInt64(42)); + activation.InsertValue("int2", CelValue::CreateInt64(43)); + + activation.InsertValue("int_divisor", CelValue::CreateInt64(4)); + + activation.InsertValue("ternary_c", CelValue::CreateBool(true)); + activation.InsertValue("ternary_t", CelValue::CreateBool(true)); + activation.InsertValue("ternary_f", CelValue::CreateBool(false)); + + ASSERT_OK_AND_ASSIGN( + auto result, + program->Trace(activation, &arena, + EvaluationListenerForCoverage(coverage.get()))); + + EXPECT_TRUE(result.IsBool() && result.BoolOrDie() == true); + + activation.RemoveValueEntry("ternary_c"); + activation.RemoveValueEntry("ternary_f"); + + activation.InsertValue("ternary_c", CelValue::CreateBool(false)); + activation.InsertValue("ternary_f", CelValue::CreateBool(false)); + + ASSERT_OK_AND_ASSIGN( + result, program->Trace(activation, &arena, + EvaluationListenerForCoverage(coverage.get()))); + + EXPECT_TRUE(result.IsBool() && result.BoolOrDie() == false) + << result.DebugString(); + + using Stats = BranchCoverage::NodeCoverageStats; + const NavigableProtoAst& ast = coverage->ast(); + auto root_node_stats = coverage->StatsForNode(ast.Root().expr()->id()); + + EXPECT_THAT(root_node_stats, MatchesNodeStats(Stats{/*is_boolean=*/true, + /*evaluation_count=*/2, + /*boolean_true_count=*/1, + /*boolean_false_count=*/1, + /*error_count=*/0})); + + const NavigableProtoAstNode* ternary; + for (const auto& node : ast.Root().DescendantsPreorder()) { + if (node.node_kind() == NodeKind::kCall && + node.expr()->call_expr().function() == cel::builtin::kTernary) { + ternary = &node; + break; + } + } + + ASSERT_NE(ternary, nullptr); + auto ternary_node_stats = coverage->StatsForNode(ternary->expr()->id()); + + // Ternary gets optimized into conditional jumps for stack machine plan. + EXPECT_THAT(ternary_node_stats, NodeStatsIsBool()); + + const auto* false_node = ternary->children().at(2); + auto false_node_stats = coverage->StatsForNode(false_node->expr()->id()); + EXPECT_THAT(false_node_stats, + MatchesNodeStats(Stats{/*is_boolean=*/true, + /*evaluation_count=*/1, + /*boolean_true_count=*/0, + /*boolean_false_count=*/1, + /*error_count=*/0})); +} + +TEST(BranchCoverage, CountsErrors) { + auto coverage = CreateBranchCoverage(TestExpression()); + + EXPECT_TRUE(static_cast(coverage->ast())); + + auto builder = CreateCelExpressionBuilder(); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + + // int1 < int2 && + // (43 > 42) && + // !(bool1 || bool2) && + // 4 / int_divisor >= 1 && + // (ternary_c ? ternary_t : ternary_f) + ASSERT_OK_AND_ASSIGN(auto program, + builder->CreateExpression(&TestExpression())); + + google::protobuf::Arena arena; + Activation activation; + activation.InsertValue("bool1", CelValue::CreateBool(false)); + activation.InsertValue("bool2", CelValue::CreateBool(false)); + + activation.InsertValue("int1", CelValue::CreateInt64(42)); + activation.InsertValue("int2", CelValue::CreateInt64(43)); + + activation.InsertValue("int_divisor", CelValue::CreateInt64(0)); + + activation.InsertValue("ternary_c", CelValue::CreateBool(true)); + activation.InsertValue("ternary_t", CelValue::CreateBool(false)); + activation.InsertValue("ternary_f", CelValue::CreateBool(false)); + + ASSERT_OK_AND_ASSIGN( + auto result, + program->Trace(activation, &arena, + EvaluationListenerForCoverage(coverage.get()))); + + EXPECT_TRUE(result.IsBool() && result.BoolOrDie() == false); + + using Stats = BranchCoverage::NodeCoverageStats; + const NavigableProtoAst& ast = coverage->ast(); + auto root_node_stats = coverage->StatsForNode(ast.Root().expr()->id()); + + EXPECT_THAT(root_node_stats, MatchesNodeStats(Stats{/*is_boolean=*/true, + /*evaluation_count=*/1, + /*boolean_true_count=*/0, + /*boolean_false_count=*/1, + /*error_count=*/0})); + + const NavigableProtoAstNode* ternary; + for (const auto& node : ast.Root().DescendantsPreorder()) { + if (node.node_kind() == NodeKind::kCall && + node.expr()->call_expr().function() == cel::builtin::kTernary) { + ternary = &node; + break; + } + } + + const NavigableProtoAstNode* div_expr; + for (const auto& node : ast.Root().DescendantsPreorder()) { + if (node.node_kind() == NodeKind::kCall && + node.expr()->call_expr().function() == cel::builtin::kDivide) { + div_expr = &node; + break; + } + } + + ASSERT_NE(div_expr, nullptr); + auto div_expr_stats = coverage->StatsForNode(div_expr->expr()->id()); + EXPECT_THAT(div_expr_stats, MatchesNodeStats(Stats{/*is_boolean=*/false, + /*evaluation_count=*/1, + /*boolean_true_count=*/0, + /*boolean_false_count=*/0, + /*error_count=*/1})); +} + +} // namespace +} // namespace cel diff --git a/tools/cel_field_extractor.cc b/tools/cel_field_extractor.cc new file mode 100644 index 000000000..50207c3cf --- /dev/null +++ b/tools/cel_field_extractor.cc @@ -0,0 +1,87 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tools/cel_field_extractor.h" + +#include +#include +#include + +#include "cel/expr/syntax.pb.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/str_join.h" +#include "tools/navigable_ast.h" + +namespace cel { + +namespace { + +bool IsComprehensionDefinedField(const cel::NavigableProtoAstNode& node) { + const cel::NavigableProtoAstNode* current_node = &node; + + while (current_node->parent() != nullptr) { + current_node = current_node->parent(); + + if (current_node->node_kind() != cel::NodeKind::kComprehension) { + continue; + } + + std::string ident_name = node.expr()->ident_expr().name(); + bool iter_var_match = + ident_name == current_node->expr()->comprehension_expr().iter_var(); + bool iter_var2_match = + ident_name == current_node->expr()->comprehension_expr().iter_var2(); + bool accu_var_match = + ident_name == current_node->expr()->comprehension_expr().accu_var(); + + if (iter_var_match || iter_var2_match || accu_var_match) { + return true; + } + } + + return false; +} + +} // namespace + +absl::flat_hash_set ExtractFieldPaths( + const cel::expr::Expr& expr) { + NavigableProtoAst ast = NavigableProtoAst::Build(expr); + + absl::flat_hash_set field_paths; + std::vector fields_in_scope; + + // Preorder traversal works because the select nodes (in a well-formed + // expression) always have only one operand, so its operand is visited + // next in the loop iteration (which results in the path being extended, + // completed, or discarded if uninteresting). + for (const cel::NavigableProtoAstNode& node : + ast.Root().DescendantsPreorder()) { + if (node.node_kind() == cel::NodeKind::kSelect) { + fields_in_scope.push_back(node.expr()->select_expr().field()); + continue; + } + if (node.node_kind() == cel::NodeKind::kIdent && + !IsComprehensionDefinedField(node)) { + fields_in_scope.push_back(node.expr()->ident_expr().name()); + std::reverse(fields_in_scope.begin(), fields_in_scope.end()); + field_paths.insert(absl::StrJoin(fields_in_scope, ".")); + } + fields_in_scope.clear(); + } + + return field_paths; +} + +} // namespace cel diff --git a/tools/cel_field_extractor.h b/tools/cel_field_extractor.h new file mode 100644 index 000000000..cfbb2370d --- /dev/null +++ b/tools/cel_field_extractor.h @@ -0,0 +1,70 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_TOOLS_CEL_FIELD_EXTRACTOR_H +#define THIRD_PARTY_CEL_CPP_TOOLS_CEL_FIELD_EXTRACTOR_H + +#include + +#include "cel/expr/syntax.pb.h" +#include "absl/container/flat_hash_set.h" + +namespace cel { + +// ExtractExpressionFieldPaths attempts to extract the set of unique field +// selection paths from top level identifiers (e.g. "request.user.id"). +// +// One possible use case for this class is to determine which fields of a +// serialized message are referenced by a CEL query, enabling partial +// deserialization for performance optimization. +// +// Implementation notes: +// The extraction logic focuses on identifying chains of `Select` operations +// that terminate with a primary identifier node (`IdentExpr`). For example, +// in the expression `message.field.subfield == 10`, the path +// "message.field.subfield" would be extracted. +// +// Identifiers defined locally within CEL comprehension expressions (e.g., +// comprehension variables aliases defined by `iter_var`, `iter_var2`, +// `accu_var` in the AST) are NOT included. Example: +// `list.exists(elem, elem.field == 'value')` would return {"list"} only. +// +// Container indexing with the _[_] is not considered, but map indexing with +// the select operator is considered. For example: +// `message.map_field.key || message.map_field['foo']` results in +// {'message.map_field.key', 'message.map_field'} +// +// This implementation does not consider type check metadata, so there is no +// understanding of whether the primary identifiers and field accesses +// necessarily map to proto messages or proto field accesses. The field +// also does not have any understanding of the type of the leaf of the +// select path. +// +// Example: +// Given the CEL expression: +// `(request.user.id == 'test' && request.user.attributes.exists(attr, +// attr.key +// == 'role')) || size(request.items) > 0` +// +// The extracted field paths would be: +// - "request.user.id" +// - "request.user.attributes" (because `attr` is a comprehension variable) +// - "request.items" + +absl::flat_hash_set ExtractFieldPaths( + const cel::expr::Expr& expr); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_TOOLS_CEL_FIELD_EXTRACTOR_H diff --git a/tools/cel_field_extractor_test.cc b/tools/cel_field_extractor_test.cc new file mode 100644 index 000000000..edf31aef9 --- /dev/null +++ b/tools/cel_field_extractor_test.cc @@ -0,0 +1,148 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tools/cel_field_extractor.h" + +#include + +#include "cel/expr/syntax.pb.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/absl_check.h" +#include "absl/status/statusor.h" +#include "internal/testing.h" +#include "parser/parser.h" + +namespace cel { + +namespace { + +using ::cel::expr::ParsedExpr; +using ::google::api::expr::parser::Parse; +using ::testing::IsEmpty; +using ::testing::UnorderedElementsAre; + +absl::flat_hash_set GetExtractedFields( + const std::string& cel_query) { + absl::StatusOr parsed_expr_or_status = Parse(cel_query); + ABSL_CHECK_OK(parsed_expr_or_status); + return ExtractFieldPaths(parsed_expr_or_status.value().expr()); +} + +TEST(TestExtractFieldPaths, CelExprWithOneField) { + EXPECT_THAT(GetExtractedFields("field_name"), + UnorderedElementsAre("field_name")); +} + +TEST(TestExtractFieldPaths, CelExprWithNoWithLiteral) { + EXPECT_THAT(GetExtractedFields("'field_name'"), IsEmpty()); +} + +TEST(TestExtractFieldPaths, CelExprWithFunctionCallOnSingleField) { + EXPECT_THAT(GetExtractedFields("!boolean_field"), + UnorderedElementsAre("boolean_field")); +} + +TEST(TestExtractFieldPaths, CelExprWithSizeFuncCallOnSingleField) { + EXPECT_THAT(GetExtractedFields("size(repeated_field)"), + UnorderedElementsAre("repeated_field")); +} + +TEST(TestExtractFieldPaths, CelExprWithNestedField) { + EXPECT_THAT(GetExtractedFields("message_field.nested_field.nested_field2"), + UnorderedElementsAre("message_field.nested_field.nested_field2")); +} + +TEST(TestExtractFieldPaths, CelExprWithNestedFieldAndIndexAccess) { + EXPECT_THAT(GetExtractedFields( + "repeated_message_field.nested_field[0].nested_field2"), + UnorderedElementsAre("repeated_message_field.nested_field")); +} + +TEST(TestExtractFieldPaths, CelExprWithMultipleFunctionCalls) { + EXPECT_THAT(GetExtractedFields( + "(size(repeated_field) > 0 && !boolean_field == true) || " + "request.valid == true && request.count == 0"), + UnorderedElementsAre("boolean_field", "repeated_field", + "request.valid", "request.count")); +} + +TEST(TestExtractFieldPaths, CelExprWithNestedComprehension) { + EXPECT_THAT( + GetExtractedFields("repeated_field_1.exists(e, e.key == 'one') && " + "req.repeated_field_2.exists(x, " + "x.y.z == 'val' &&" + "x.array.exists(y, y == 'val' && req.bool_field == " + "true && x.bool_field == false))"), + UnorderedElementsAre("req.repeated_field_2", "req.bool_field", + "repeated_field_1")); +} + +TEST(TestExtractFieldPaths, CelExprWithMultipleComprehension) { + EXPECT_THAT( + GetExtractedFields( + "repeated_field_1.exists(e, e.key == 'one' && y.field_1 == 'val') && " + "repeated_field_2.exists(y, y.key == 'one' && e.field_2 == 'val')"), + UnorderedElementsAre("repeated_field_1", "repeated_field_2", "e.field_2", + "y.field_1")); +} + +TEST(TestExtractFieldPaths, CelExprWithListLiteral) { + EXPECT_THAT(GetExtractedFields("['a', b, 3].exists(x, x == 1)"), + UnorderedElementsAre("b")); +} + +TEST(TestExtractFieldPaths, CelExprWithFunctionCallsAndRepeatedFields) { + EXPECT_THAT( + GetExtractedFields("data == 'data_1' && field_1 == 'val_1' &&" + "(matches(req.field_2, 'val_1') == true) &&" + "repeated_field[0].priority >= 200"), + UnorderedElementsAre("data", "field_1", "req.field_2", "repeated_field")); +} + +TEST(TestExtractFieldPaths, CelExprWithFunctionOnRepeatedField) { + EXPECT_THAT( + GetExtractedFields("(contains_data == false && " + "data.field_1=='value_1') || " + "size(data.nodes) > 0 && " + "data.nodes[0].field_2=='value_2'"), + UnorderedElementsAre("contains_data", "data.field_1", "data.nodes")); +} + +TEST(TestExtractFieldPaths, CelExprContainingEndsWithFunction) { + EXPECT_THAT(GetExtractedFields("data.repeated_field.exists(f, " + "f.field_1.field_2.endsWith('val_1')) || " + "data.field_3.endsWith('val_3')"), + UnorderedElementsAre("data.repeated_field", "data.field_3")); +} + +TEST(TestExtractFieldPaths, + CelExprWithMatchFunctionInsideComprehensionAndRegexConstants) { + EXPECT_THAT(GetExtractedFields("req.field_1.field_2=='val_1' && " + "data!=null && req.repeated_field.exists(f, " + "f.matches('a100.*|.*h100_80gb.*|.*h200.*'))"), + UnorderedElementsAre("req.field_1.field_2", "req.repeated_field", + "data")); +} + +TEST(TestExtractFieldPaths, CelExprWithMultipleChecksInComprehension) { + EXPECT_THAT( + GetExtractedFields("req.field.repeated_field.exists(f, f.key == 'data_1'" + " && f.str_value == 'val_1') && " + "req.metadata.type == 3"), + UnorderedElementsAre("req.field.repeated_field", "req.metadata.type")); +} + +} // namespace + +} // namespace cel diff --git a/tools/cel_unparser.cc b/tools/cel_unparser.cc new file mode 100644 index 000000000..741d91208 --- /dev/null +++ b/tools/cel_unparser.cc @@ -0,0 +1,592 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tools/cel_unparser.h" + +#include +#include +#include +#include + +#include "cel/expr/syntax.pb.h" +#include "google/protobuf/duration.pb.h" +#include "google/protobuf/timestamp.pb.h" +#include "absl/base/no_destructor.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/ascii.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "common/operators.h" +#include "internal/status_macros.h" +#include "internal/strings.h" +#include "re2/re2.h" + +namespace google::api::expr { +namespace { + +using ::cel::expr::CheckedExpr; +using ::cel::expr::Constant; +using ::cel::expr::Expr; +using ::cel::expr::ParsedExpr; +using ::cel::expr::SourceInfo; +using ::google::api::expr::common::CelOperator; +using ::google::api::expr::common::IsOperatorLeftRecursive; +using ::google::api::expr::common::IsOperatorLowerPrecedence; +using ::google::api::expr::common::IsOperatorSamePrecedence; +using ::google::api::expr::common::LookupBinaryOperator; +using ::google::api::expr::common::LookupUnaryOperator; + +constexpr absl::string_view kLeftParen = "("; +constexpr absl::string_view kRightParen = ")"; +constexpr absl::string_view kLeftBracket = "["; +constexpr absl::string_view kRightBracket = "]"; +constexpr absl::string_view kLeftBrace = "{"; +constexpr absl::string_view kRightBrace = "}"; +constexpr absl::string_view kSpace = " "; +constexpr absl::string_view kDot = "."; +constexpr absl::string_view kColon = ":"; +constexpr absl::string_view kComma = ","; +constexpr absl::string_view kBackQuote = "`"; +constexpr absl::string_view kQuestionMark = "?"; + +static const LazyRE2 kSimpleIdentifierPattern = {R"([a-zA-Z_][a-zA-Z0-9_]*)"}; + +const absl::flat_hash_set& ReservedFieldIdentifiers() { + static const absl::NoDestructor> + kReservedFieldIdentifiers( + []() { return absl::flat_hash_set{"in"}; }()); + return *kReservedFieldIdentifiers; +} + +std::string FormatField(absl::string_view field) { + if (ReservedFieldIdentifiers().contains(field) || + !RE2::FullMatch(field, *kSimpleIdentifierPattern)) { + return absl::StrCat(kBackQuote, field, kBackQuote); + } + return std::string(field); +} + +class Unparser { + public: + static absl::StatusOr Unparse(const Expr& expr, + const SourceInfo& source_info) { + Unparser unparser(expr, source_info); + return unparser.DoUnparse(); + } + + private: + const Expr& expr_; + const SourceInfo& source_info_; + std::string output_; + + Unparser(const Expr& expr, const SourceInfo& source_info) + : expr_(expr), source_info_(source_info) {} + + absl::StatusOr DoUnparse() { + CEL_RETURN_IF_ERROR(Visit(expr_)); + absl::StripAsciiWhitespace(&output_); + return std::move(output_); + } + + absl::Status Visit(const Expr& expr); + + absl::Status VisitConst(const Constant& expr); + + absl::Status VisitIdent(const Expr::Ident& expr); + + absl::Status VisitSelect(const Expr::Select& expr); + + absl::Status VisitOptSelect(const Expr::Call& expr); + + absl::Status VisitCall(const Expr::Call& expr); + + absl::Status VisitCreateList(const Expr::CreateList& expr); + + absl::Status VisitCreateStruct(const Expr::CreateStruct& expr); + + absl::Status VisitComprehension(const Expr::Comprehension& expr); + + absl::Status VisitAllMacro(const Expr::Comprehension& expr); + + absl::Status VisitExistsMacro(const Expr::Comprehension& expr); + + absl::Status VisitExistsOneMacro(const Expr::Comprehension& expr); + + absl::Status VisitMapMacro(const Expr::Comprehension& expr); + + absl::Status VisitUnary(const Expr::Call& expr, const std::string& op); + + absl::Status VisitBinary(const Expr::Call& expr, const std::string& op); + + absl::Status VisitMaybeNested(const Expr& expr, bool nested); + + absl::Status VisitIndex(const Expr::Call& expr); + + absl::Status VisitOptIndex(const Expr::Call& expr); + + absl::Status VisitTernary(const Expr::Call& expr); + + bool IsComplexOperatorWithRespectTo(const Expr& expr, const std::string& op); + + bool IsComplexOperator(const Expr& expr); + + // Returns true the given expression is + // - a call expression AND ONE of the following holds: + // - a binary operator + // - a ternary conditional operator + bool IsBinaryOrTernaryOperator(const Expr& expr); + + bool IsLogicalOperator(absl::string_view op); + + template + void Print(Ts&&... args) { + absl::StrAppend(&output_, std::forward(args)...); + } +}; + +absl::Status Unparser::Visit(const Expr& expr) { + auto macro = source_info_.macro_calls().find(expr.id()); + if (macro != source_info_.macro_calls().end()) { + return Visit(macro->second); + } + switch (expr.expr_kind_case()) { + case Expr::kConstExpr: + return VisitConst(expr.const_expr()); + case Expr::kIdentExpr: + return VisitIdent(expr.ident_expr()); + case Expr::kSelectExpr: + return VisitSelect(expr.select_expr()); + case Expr::kCallExpr: + return VisitCall(expr.call_expr()); + case Expr::kListExpr: + return VisitCreateList(expr.list_expr()); + case Expr::kStructExpr: + return VisitCreateStruct(expr.struct_expr()); + case Expr::kComprehensionExpr: + return VisitComprehension(expr.comprehension_expr()); + default: + return absl::InvalidArgumentError( + absl::StrCat("Unsupported Expr kind: ", expr.expr_kind_case())); + } +} + +absl::Status Unparser::VisitConst(const Constant& expr) { + switch (expr.constant_kind_case()) { + case Constant::kStringValue: + Print( + cel::internal::FormatDoubleQuotedStringLiteral(expr.string_value())); + break; + case Constant::kInt64Value: + Print(expr.int64_value()); + break; + case Constant::kUint64Value: + Print(expr.uint64_value(), "u"); + break; + case Constant::kBoolValue: + Print(expr.bool_value() ? "true" : "false"); + break; + case Constant::kDoubleValue: + Print(expr.double_value()); + break; + case Constant::kNullValue: + Print("null"); + break; + case Constant::kBytesValue: + Print(cel::internal::FormatDoubleQuotedBytesLiteral(expr.bytes_value())); + break; + default: + return absl::InvalidArgumentError(absl::StrCat( + "Unsupported Constant kind: ", expr.constant_kind_case())); + } + return absl::OkStatus(); +} + +absl::Status Unparser::VisitIdent(const Expr::Ident& expr) { + Print(expr.name()); + return absl::OkStatus(); +} + +absl::Status Unparser::VisitSelect(const Expr::Select& expr) { + if (expr.test_only()) { + Print(CelOperator::HAS, kLeftParen); + } + const auto& operand = expr.operand(); + bool nested = !expr.test_only() && IsBinaryOrTernaryOperator(operand); + CEL_RETURN_IF_ERROR(VisitMaybeNested(operand, nested)); + Print(kDot, FormatField(expr.field())); + if (expr.test_only()) { + Print(kRightParen); + } + return absl::OkStatus(); +} + +absl::Status Unparser::VisitOptSelect(const Expr::Call& expr) { + if (expr.args_size() != 2 || !expr.args()[1].has_const_expr() || + !expr.args()[1].const_expr().has_string_value()) { + return absl::InvalidArgumentError( + absl::StrCat("Unexpected select: ", expr.ShortDebugString())); + } + const auto& operand = expr.args()[0]; + bool nested = IsBinaryOrTernaryOperator(operand); + CEL_RETURN_IF_ERROR(VisitMaybeNested(operand, nested)); + Print(kDot, kQuestionMark, + FormatField(expr.args()[1].const_expr().string_value())); + return absl::OkStatus(); +} + +absl::Status Unparser::VisitCall(const Expr::Call& expr) { + const auto& fun = expr.function(); + absl::optional op = LookupUnaryOperator(fun); + if (op.has_value()) { + return VisitUnary(expr, *op); + } + + op = LookupBinaryOperator(fun); + if (op.has_value()) { + return VisitBinary(expr, *op); + } + + if (fun == CelOperator::INDEX) { + return VisitIndex(expr); + } + + if (fun == CelOperator::OPT_INDEX) { + return VisitOptIndex(expr); + } + + if (fun == CelOperator::OPT_SELECT) { + return VisitOptSelect(expr); + } + + if (fun == CelOperator::CONDITIONAL) { + return VisitTernary(expr); + } + + if (expr.has_target()) { + bool nested = IsBinaryOrTernaryOperator(expr.target()); + CEL_RETURN_IF_ERROR(VisitMaybeNested(expr.target(), nested)); + Print(kDot); + } + Print(fun, kLeftParen); + for (int i = 0; i < expr.args_size(); i++) { + if (i > 0) { + Print(kComma, kSpace); + } + CEL_RETURN_IF_ERROR(Visit(expr.args(i))); + } + Print(kRightParen); + return absl::OkStatus(); +} + +absl::Status Unparser::VisitCreateList(const Expr::CreateList& expr) { + Print(kLeftBracket); + for (int i = 0; i < expr.elements_size(); i++) { + if (i > 0) { + Print(kComma, kSpace); + } + if (std::find(expr.optional_indices().begin(), + expr.optional_indices().end(), + static_cast(i)) != expr.optional_indices().end()) { + Print(kQuestionMark); + } + CEL_RETURN_IF_ERROR(Visit(expr.elements(i))); + } + Print(kRightBracket); + return absl::OkStatus(); +} + +absl::Status Unparser::VisitCreateStruct(const Expr::CreateStruct& expr) { + if (!expr.message_name().empty()) { + Print(expr.message_name()); + } + Print(kLeftBrace); + for (int i = 0; i < expr.entries_size(); i++) { + if (i > 0) { + Print(kComma, kSpace); + } + + const auto& e = expr.entries(i); + if (e.optional_entry()) { + Print(kQuestionMark); + } + switch (e.key_kind_case()) { + case Expr::CreateStruct::Entry::kFieldKey: + Print(FormatField(e.field_key())); + break; + case Expr::CreateStruct::Entry::kMapKey: + CEL_RETURN_IF_ERROR(Visit(e.map_key())); + break; + default: + return absl::InvalidArgumentError( + absl::StrCat("Unexpected struct: ", expr.ShortDebugString())); + } + Print(kColon, kSpace); + CEL_RETURN_IF_ERROR(Visit(e.value())); + } + Print(kRightBrace); + return absl::OkStatus(); +} + +absl::Status Unparser::VisitComprehension(const Expr::Comprehension& expr) { + bool nested = IsComplexOperator(expr.iter_range()); + CEL_RETURN_IF_ERROR(VisitMaybeNested(expr.iter_range(), nested)); + Print(kDot); + + if (expr.loop_step().call_expr().function() == CelOperator::LOGICAL_AND) { + return VisitAllMacro(expr); + } + + if (expr.loop_step().call_expr().function() == CelOperator::LOGICAL_OR) { + return VisitExistsMacro(expr); + } + + if (expr.result().expr_kind_case() == Expr::kCallExpr) { + return VisitExistsOneMacro(expr); + } + + return VisitMapMacro(expr); +} + +absl::Status Unparser::VisitAllMacro(const Expr::Comprehension& expr) { + if (expr.loop_step().call_expr().args_size() != 2) { + return absl::InvalidArgumentError( + absl::StrCat("Unexpected all macro: ", expr.ShortDebugString())); + } + + Print(CelOperator::ALL, kLeftParen, expr.iter_var(), kComma, kSpace); + CEL_RETURN_IF_ERROR(Visit(expr.loop_step().call_expr().args(1))); + Print(kRightParen); + return absl::OkStatus(); +} + +absl::Status Unparser::VisitExistsMacro(const Expr::Comprehension& expr) { + if (expr.loop_step().call_expr().args_size() != 2) { + return absl::InvalidArgumentError( + absl::StrCat("Unexpected exists macro: ", expr.ShortDebugString())); + } + + Print(CelOperator::EXISTS, kLeftParen, expr.iter_var(), kComma, kSpace); + CEL_RETURN_IF_ERROR(Visit(expr.loop_step().call_expr().args(1))); + Print(kRightParen); + return absl::OkStatus(); +} + +absl::Status Unparser::VisitExistsOneMacro(const Expr::Comprehension& expr) { + if (expr.loop_step().call_expr().args_size() != 3) { + return absl::InvalidArgumentError( + absl::StrCat("Unexpected exists one macro: ", expr.ShortDebugString())); + } + + Print(CelOperator::EXISTS_ONE, kLeftParen, expr.iter_var(), kComma, kSpace); + CEL_RETURN_IF_ERROR(Visit(expr.loop_step().call_expr().args(0))); + Print(kRightParen); + return absl::OkStatus(); +} + +absl::Status Unparser::VisitMapMacro(const Expr::Comprehension& expr) { + Print(CelOperator::MAP, kLeftParen, expr.iter_var(), kComma, kSpace); + Expr step = expr.loop_step(); + if (step.call_expr().function() == CelOperator::CONDITIONAL) { + if (step.call_expr().args_size() != 3) { + return absl::InvalidArgumentError( + absl::StrCat("Unexpected exists map macro filter step: ", + expr.ShortDebugString())); + } + + CEL_RETURN_IF_ERROR(Visit(step.call_expr().args(0))); + Print(kComma, kSpace); + + auto temp = step.call_expr().args(1); + step = temp; + } + + if (step.call_expr().args_size() != 2 || + step.call_expr().args(1).list_expr().elements_size() != 1) { + return absl::InvalidArgumentError( + absl::StrCat("Unexpected exists map macro: ", expr.ShortDebugString())); + } + + CEL_RETURN_IF_ERROR(Visit(step.call_expr().args(1).list_expr().elements(0))); + Print(kRightParen); + return absl::OkStatus(); +} + +absl::Status Unparser::VisitUnary(const Expr::Call& expr, + const std::string& op) { + if (expr.args_size() != 1) { + return absl::InvalidArgumentError( + absl::StrCat("Unexpected unary: ", expr.ShortDebugString())); + } + Print(op); + bool nested = IsComplexOperator(expr.args(0)); + return VisitMaybeNested(expr.args(0), nested); +} + +absl::Status Unparser::VisitBinary(const Expr::Call& expr, + const std::string& op) { + if (expr.args_size() < 2) { + return absl::InvalidArgumentError( + absl::StrCat("Unexpected binary: ", expr.ShortDebugString())); + } + + const auto& fun = expr.function(); + if (IsLogicalOperator(fun)) { + for (int i = 0; i < expr.args_size(); ++i) { + if (i > 0) { + Print(kSpace, op, kSpace); + } + const auto& arg = expr.args(i); + bool arg_paren = IsComplexOperatorWithRespectTo(arg, fun); + CEL_RETURN_IF_ERROR(VisitMaybeNested(arg, arg_paren)); + } + return absl::OkStatus(); + } + + if (expr.args_size() != 2) { + return absl::InvalidArgumentError( + absl::StrCat("Unexpected binary: ", expr.ShortDebugString())); + } + + const auto& lhs = expr.args(0); + const auto& rhs = expr.args(1); + + // add parens if the current operator is lower precedence than the lhs expr + // operator. + bool lhs_paren = IsComplexOperatorWithRespectTo(lhs, fun); + // add parens if the current operator is lower precedence than the rhs expr + // operator, or the same precedence and the operator is left recursive. + bool rhs_paren = IsComplexOperatorWithRespectTo(rhs, fun); + if (!rhs_paren && IsOperatorLeftRecursive(fun)) { + rhs_paren = IsOperatorSamePrecedence(fun, rhs); + } + + CEL_RETURN_IF_ERROR(VisitMaybeNested(lhs, lhs_paren)); + Print(kSpace, op, kSpace); + return VisitMaybeNested(rhs, rhs_paren); +} + +absl::Status Unparser::VisitMaybeNested(const Expr& expr, bool nested) { + if (nested) { + Print(kLeftParen); + } + CEL_RETURN_IF_ERROR(Visit(expr)); + if (nested) { + Print(kRightParen); + } + return absl::OkStatus(); +} + +absl::Status Unparser::VisitIndex(const Expr::Call& expr) { + if (expr.args_size() != 2) { + return absl::InvalidArgumentError( + absl::StrCat("Unexpected index call: ", expr.ShortDebugString())); + } + bool nested = IsBinaryOrTernaryOperator(expr.args(0)); + CEL_RETURN_IF_ERROR(VisitMaybeNested(expr.args(0), nested)); + Print(kLeftBracket); + CEL_RETURN_IF_ERROR(Visit(expr.args(1))); + Print(kRightBracket); + return absl::OkStatus(); +} + +absl::Status Unparser::VisitOptIndex(const Expr::Call& expr) { + if (expr.args_size() != 2) { + return absl::InvalidArgumentError( + absl::StrCat("Unexpected index call: ", expr.ShortDebugString())); + } + bool nested = IsBinaryOrTernaryOperator(expr.args(0)); + CEL_RETURN_IF_ERROR(VisitMaybeNested(expr.args(0), nested)); + Print(kLeftBracket); + Print(kQuestionMark); + CEL_RETURN_IF_ERROR(Visit(expr.args(1))); + Print(kRightBracket); + return absl::OkStatus(); +} + +absl::Status Unparser::VisitTernary(const Expr::Call& expr) { + if (expr.args_size() != 3) { + return absl::InvalidArgumentError( + absl::StrCat("Unexpected ternary: ", expr.ShortDebugString())); + } + + bool nested = + IsOperatorSamePrecedence(CelOperator::CONDITIONAL, expr.args(0)) || + IsComplexOperator(expr.args(0)); + CEL_RETURN_IF_ERROR(VisitMaybeNested(expr.args(0), nested)); + + Print(kSpace, kQuestionMark, kSpace); + + nested = IsOperatorSamePrecedence(CelOperator::CONDITIONAL, expr.args(1)) || + IsComplexOperator(expr.args(1)); + CEL_RETURN_IF_ERROR(VisitMaybeNested(expr.args(1), nested)); + + Print(kSpace, kColon, kSpace); + + nested = IsOperatorSamePrecedence(CelOperator::CONDITIONAL, expr.args(2)) || + IsComplexOperator(expr.args(2)); + return VisitMaybeNested(expr.args(2), nested); +} + +bool Unparser::IsComplexOperatorWithRespectTo(const Expr& expr, + const std::string& op) { + // If the arg is not a call with more than one arg, return false. + if (!expr.has_call_expr() || expr.call_expr().args_size() < 2) { + return false; + } + // Otherwise, return whether the given op has lower precedence than expr + return IsOperatorLowerPrecedence(op, expr); +} + +bool Unparser::IsComplexOperator(const Expr& expr) { + // If the arg is a call with more than one arg, return true + return expr.has_call_expr() && expr.call_expr().args_size() >= 2; +} + +// Returns true the given expression is +// - a call expression AND ONE of the following holds: +// - a binary operator +// - a ternary conditional operator +bool Unparser::IsBinaryOrTernaryOperator(const Expr& expr) { + if (!IsComplexOperator(expr)) { + return false; + } + return LookupBinaryOperator(expr.call_expr().function()).has_value() || + IsOperatorSamePrecedence(CelOperator::CONDITIONAL, expr); +} + +bool Unparser::IsLogicalOperator(absl::string_view op) { + return op == CelOperator::LOGICAL_AND || op == CelOperator::LOGICAL_OR; +} + +} // namespace + +absl::StatusOr Unparse(const Expr& expr, + const SourceInfo* source_info) { + const SourceInfo& info = + source_info == nullptr ? SourceInfo::default_instance() : *source_info; + return Unparser::Unparse(expr, info); +} + +absl::StatusOr Unparse(const ParsedExpr& parsed_expr) { + return Unparse(parsed_expr.expr(), &parsed_expr.source_info()); +} + +absl::StatusOr Unparse(const CheckedExpr& checked_expr) { + return Unparse(checked_expr.expr(), &checked_expr.source_info()); +} + +} // namespace google::api::expr diff --git a/tools/cel_unparser.h b/tools/cel_unparser.h new file mode 100644 index 000000000..754b1013c --- /dev/null +++ b/tools/cel_unparser.h @@ -0,0 +1,60 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Provides an unparsing utility that converts an AST back into +// a human readable format. +// +// Input to the unparser is the proto AST (Expr, CheckedExpr, or ParsedExpr). +// The unparser does not do any checks to see if the ParsedExpr is syntactically +// or semantically correct but does checks enough to prevent its crash and might +// return errors in such cases. + +#ifndef THIRD_PARTY_CEL_CPP_TOOLS_UNPARSER_H_ +#define THIRD_PARTY_CEL_CPP_TOOLS_UNPARSER_H_ + +#include + +#include "cel/expr/checked.pb.h" +#include "cel/expr/syntax.pb.h" +#include "absl/base/attributes.h" +#include "absl/status/statusor.h" + +namespace google::api::expr { + +// Unparses the given expression into a human readable cel expression. +ABSL_DEPRECATED( + "Use Unparse(ParsedExpr) to ensure proper unparsing of all CEL " + "expressions. Note, ParserOptions.add_macro_calls must be set to true " + "for full fidelity unparsing.") +absl::StatusOr Unparse( + const cel::expr::Expr& expr, + const cel::expr::SourceInfo* source_info = nullptr); + +// Unparses the ParsedExpr value to a human-readable string. +// +// For the best results ensure that the expression is parsed with +// ParserOptions.add_macro_calls = true. +absl::StatusOr Unparse( + const cel::expr::ParsedExpr& parsed_expr); + +// Unparses the CheckedExpr value to a human-readable string. +// +// For the best results ensure that the expression is parsed with +// ParserOptions.add_macro_calls = true. +absl::StatusOr Unparse( + const cel::expr::CheckedExpr& checked_expr); + +} // namespace google::api::expr + +#endif // THIRD_PARTY_CEL_CPP_TOOLS_UNPARSER_H_ diff --git a/tools/cel_unparser_test.cc b/tools/cel_unparser_test.cc new file mode 100644 index 000000000..aca6e91fd --- /dev/null +++ b/tools/cel_unparser_test.cc @@ -0,0 +1,804 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tools/cel_unparser.h" + +#include + +#include "cel/expr/syntax.pb.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "internal/proto_matchers.h" +#include "internal/testing.h" +#include "parser/options.h" +#include "parser/parser.h" +#include "google/protobuf/text_format.h" + +namespace google::api::expr { +namespace { + +using ::absl_testing::StatusIs; +using ::cel::internal::test::EqualsProto; +using ::cel::expr::Expr; +using ::cel::expr::ParsedExpr; +using ::google::api::expr::parser::Parse; +using ::testing::HasSubstr; +using ::testing::ValuesIn; + +struct UnparserTestCaseTextProto { + std::string proto_text; + absl::StatusOr expr; +}; + +class UnparserTestTextProto + : public testing::TestWithParam {}; + +TEST_P(UnparserTestTextProto, Test) { + auto test_case = GetParam(); + Expr expr; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(test_case.proto_text, &expr)); + absl::StatusOr result = Unparse(expr); + if (result.ok()) { + ASSERT_OK(test_case.expr); + ASSERT_EQ(*(test_case.expr), *result); + } else { + ASSERT_THAT(result.status(), + StatusIs(test_case.expr.status().code(), + HasSubstr(test_case.expr.status().message()))); + } +} + +// these tests make explicit assumptions about specific proto structures +// that are to be observed +INSTANTIATE_TEST_SUITE_P( + UnparseCompProto, UnparserTestTextProto, + ValuesIn( + {// Empty Expr error + {"", absl::InvalidArgumentError("Unsupported Expr")}, + + // Logical operators with too few arguments (single argument) + { + R"pb( + call_expr { + function: "_&&_" + args { const_expr { bool_value: true } } + })pb", + absl::InvalidArgumentError("Unexpected binary")}, + { + R"pb( + call_expr { + function: "_||_" + args { const_expr { bool_value: true } } + })pb", + absl::InvalidArgumentError("Unexpected binary")}, + + // Constants + {"const_expr{}", absl::InvalidArgumentError("Unsupported Constant")}, + {"const_expr{bool_value: true}", "true"}, + {"const_expr{int64_value: 4}", "4"}, + {"const_expr{uint64_value: 4}", "4u"}, + + // Sequences + { + R"pb( + struct_expr { + entries { value { const_expr { uint64_value: 2 } } } + })pb", + absl::InvalidArgumentError("Unexpected struct")}, + {R"pb( + list_expr { + elements { const_expr { int64_value: 1 } } + elements { const_expr { uint64_value: 2 } } + } + )pb", + "[1, 2u]"}, + {R"pb( + struct_expr { + entries { + map_key { const_expr { int64_value: 1 } } + value { const_expr { uint64_value: 2 } } + } + entries { + map_key { const_expr { int64_value: 2 } } + value { const_expr { uint64_value: 3 } } + } + })pb", + "{1: 2u, 2: 3u}"}, + + // Messages + {R"pb( + struct_expr { + message_name: 'TestAllTypes' + entries { + field_key: 'single_int32' + value { const_expr { int64_value: 1 } } + } + entries { + field_key: 'single_int64' + value { const_expr { int64_value: 2 } } + } + } + )pb", + "TestAllTypes{single_int32: 1, single_int64: 2}"}, + + // Conditionals + {R"pb( + call_expr { function: '!_' } + )pb", + absl::InvalidArgumentError("Unexpected unary")}, + {R"pb( + call_expr { function: '_||_' } + )pb", + absl::InvalidArgumentError("Unexpected binary")}, + {R"pb( + call_expr { function: '_[_]' } + )pb", + absl::InvalidArgumentError("Unexpected index")}, + {R"pb( + call_expr { function: '_?_:_' } + )pb", + absl::InvalidArgumentError("Unexpected ternary")}, + {R"pb( + call_expr { + function: '_||_' + args { + call_expr { + function: '_&&_' + args { const_expr { bool_value: false } } + args { + call_expr { + function: '!_' + args { const_expr { bool_value: true } } + } + } + } + } + args { const_expr { bool_value: false } } + })pb", + "false && !true || false"}, + {R"pb( + call_expr { + function: '_&&_' + args { const_expr { bool_value: false } } + args { + call_expr { + function: '_||_' + args { + call_expr { + function: '!_' + args { const_expr { bool_value: true } } + } + } + args { const_expr { bool_value: false } } + } + } + })pb", + "false && (!true || false)"}, + {R"pb( + call_expr { + function: '_?_:_' + args { + call_expr { + function: '_||_' + args { + call_expr { + function: '_&&_' + args { const_expr { bool_value: false } } + args { + call_expr { + function: "!_" + args { const_expr { bool_value: true } } + } + } + } + } + args { const_expr { bool_value: false } } + } + } + args { const_expr { int64_value: 2 } } + args { const_expr { int64_value: 3 } } + })pb", + "(false && !true || false) ? 2 : 3"}, + {R"pb( + call_expr { + function: '!_' + args { + call_expr { + function: '!_' + args { const_expr { bool_value: true } } + } + } + })pb", + "!!true"}, + {R"pb( + call_expr { + function: '_?_:_' + args { + call_expr { + function: '_<_' + args { ident_expr { name: 'x' } } + args { const_expr { int64_value: 5 } } + } + } + args { ident_expr { name: 'x' } } + args { const_expr { int64_value: 5 } } + })pb", + "(x < 5) ? x : 5"}, + {R"pb( + call_expr { + function: '_?_:_' + args { + call_expr { + function: '_>_' + args { ident_expr { name: 'x' } } + args { const_expr { int64_value: 5 } } + } + } + args { + call_expr { + function: '_-_' + args { ident_expr { name: 'x' } } + args { const_expr { int64_value: 5 } } + } + } + args { const_expr { int64_value: 0 } } + })pb", + "(x > 5) ? (x - 5) : 0"}, + {R"pb( + call_expr { + function: '_?_:_' + args { + call_expr { + function: '_>_' + args { ident_expr { name: 'x' } } + args { const_expr { int64_value: 5 } } + } + } + args { + call_expr { + function: '_?_:_' + args { + call_expr { + function: '_>_' + args { ident_expr { name: 'x' } } + args { const_expr { int64_value: 10 } } + } + } + args { + call_expr { + function: '_-_' + args { ident_expr { name: 'x' } } + args { const_expr { int64_value: 10 } } + } + } + args { const_expr { int64_value: 5 } } + } + } + args { const_expr { int64_value: 0 } } + })pb", + "(x > 5) ? ((x > 10) ? (x - 10) : 5) : 0"}, + {R"pb( + call_expr { + function: '_in_' + args { ident_expr { name: 'a' } } + args { ident_expr { name: 'b' } } + })pb", + "a in b"}, + + // Calculations + {R"pb( + call_expr { + function: '_*_' + args { + call_expr { + function: '_+_' + args { const_expr { int64_value: 1 } } + args { const_expr { int64_value: 2 } } + } + } + args { const_expr { int64_value: 3 } } + })pb", + "(1 + 2) * 3"}, + {R"pb( + call_expr { + function: '_+_' + args { const_expr { int64_value: 1 } } + args { + call_expr { + function: '_*_' + args { const_expr { int64_value: 2 } } + args { const_expr { int64_value: 3 } } + } + } + })pb", + "1 + 2 * 3"}, + {R"pb( + call_expr { + function: '-_' + args { + call_expr { + function: '_*_' + args { const_expr { int64_value: 1 } } + args { const_expr { int64_value: 2 } } + } + } + })pb", + "-(1 * 2)"}, + + // Comprehensions + {R"pb( + comprehension_expr { + iter_var: 'x' + iter_range { + list_expr { + elements { const_expr { int64_value: 1 } } + elements { const_expr { int64_value: 2 } } + elements { const_expr { int64_value: 3 } } + } + } + accu_var: 'accu' + accu_init { const_expr { bool_value: true } } + loop_condition { ident_expr { name: 'accu' } } + loop_step { + call_expr { + function: '_&&_' + args { ident_expr { name: 'x' } } + args { + call_expr { + function: '_>_' + args { ident_expr { name: 'x' } } + args { const_expr { int64_value: 0 } } + } + } + } + } + result { ident_expr { name: 'accu' } } + })pb", + "[1, 2, 3].all(x, x > 0)"}, + {R"pb( + comprehension_expr { + iter_var: 'x' + iter_range { + list_expr { + elements { const_expr { int64_value: 1 } } + elements { const_expr { int64_value: 2 } } + elements { const_expr { int64_value: 3 } } + } + } + accu_var: 'accu' + accu_init { const_expr { bool_value: false } } + loop_condition { + call_expr { + function: '!_' + args { ident_expr { name: 'accu' } } + } + } + loop_step { + call_expr { + function: '_||_' + args { ident_expr { name: 'x' } } + args { + call_expr { + function: '_>_' + args { ident_expr { name: 'x' } } + args { const_expr { int64_value: 0 } } + } + } + } + } + result { ident_expr { name: 'accu' } } + })pb", + "[1, 2, 3].exists(x, x > 0)"}, + {R"pb( + comprehension_expr { + iter_var: 'x' + iter_range { + list_expr { + elements { const_expr { int64_value: 1 } } + elements { const_expr { int64_value: 2 } } + elements { const_expr { int64_value: 3 } } + } + } + accu_var: 'accu' + accu_init { list_expr {} } + loop_condition { const_expr { bool_value: false } } + loop_step { + call_expr { + function: '_?_:_' + args { + call_expr { + function: '_>=_' + args { ident_expr { name: 'x' } } + args { const_expr { int64_value: 2 } } + } + } + args { + call_expr { + function: '_+_' + args { ident_expr { name: 'accu' } } + args { + list_expr { + elements { + call_expr { + function: '_*_' + args { ident_expr { name: 'x' } } + args { const_expr { int64_value: 4 } } + } + } + } + } + } + } + args { ident_expr { name: 'accu' } } + } + } + result { ident_expr { name: 'accu' } } + })pb", + "[1, 2, 3].map(x, x >= 2, x * 4)"}, + {R"pb( + comprehension_expr { + iter_var: 'x' + iter_range { + list_expr { + elements { const_expr { int64_value: 1 } } + elements { const_expr { int64_value: 2 } } + elements { const_expr { int64_value: 3 } } + } + } + accu_var: 'accu' + accu_init { const_expr { int64_value: 0 } } + loop_condition { + call_expr { + function: '_<=_' + args { ident_expr { name: 'accu' } } + args { const_expr { int64_value: 1 } } + } + } + loop_step { + call_expr { + function: '_?_:_' + args { + call_expr { + function: '_>=_' + args { ident_expr { name: 'x' } } + args { const_expr { int64_value: 2 } } + } + } + args { + call_expr { + function: '_+_' + args { ident_expr { name: 'accu' } } + args { const_expr { int64_value: 1 } } + } + } + args { ident_expr { name: 'accu' } } + } + } + result { + call_expr { + function: '_==_' + args { ident_expr { name: 'accu' } } + args { const_expr { int64_value: 1 } } + } + } + })pb", + "[1, 2, 3].exists_one(x, x >= 2)"}, + {R"pb( + select_expr { + operand { + call_expr { + function: '_[_]' + args { ident_expr { name: 'x' } } + args { const_expr { string_value: 'a' } } + } + } + field: 'single_int32' + test_only: true + })pb", + "has(x[\"a\"].single_int32)"}, + + // This is a filter expression but is decompiled back to + // map(x, filter_function, x) for which the evaluation is + // equal to filter(x, filter_function). + {R"pb( + comprehension_expr { + iter_var: 'x' + iter_range { + list_expr { + elements { const_expr { int64_value: 1 } } + elements { const_expr { int64_value: 2 } } + elements { const_expr { int64_value: 3 } } + } + } + accu_var: 'accu' + accu_init { list_expr {} } + loop_condition { const_expr { bool_value: false } } + loop_step { + call_expr { + function: '_?_:_' + args { + call_expr { + function: '_>=_' + args { ident_expr { name: 'x' } } + args { const_expr { int64_value: 2 } } + } + } + args { + call_expr { + function: '_+_' + args { ident_expr { name: 'accu' } } + args { + list_expr { elements { ident_expr { name: 'x' } } } + } + } + } + args { ident_expr { name: 'accu' } } + } + } + result { ident_expr { name: 'accu' } } + })pb", + "[1, 2, 3].map(x, x >= 2, x)"}, + + // Index + {R"pb( + call_expr { + function: '_==_' + args { + select_expr { + operand { + call_expr { + function: '_[_]' + args { ident_expr { name: 'x' } } + args { const_expr { string_value: 'a' } } + } + } + field: 'single_int32' + } + } + args { const_expr { int64_value: 23 } } + })pb", + "x[\"a\"].single_int32 == 23"}, + {R"pb( + call_expr { + function: '_[_]' + args { + call_expr { + function: '_[_]' + args { ident_expr { name: 'a' } } + args { const_expr { int64_value: 1 } } + } + } + args { const_expr { string_value: 'b' } } + })pb", + "a[1][\"b\"]"}, + + // Functions + {R"pb( + call_expr { + function: '_!=_' + args { ident_expr { name: 'x' } } + args { const_expr { string_value: 'a' } } + })pb", + "x != \"a\""}, + {R"pb( + call_expr { + function: '_==_' + args { + call_expr { + function: 'size' + args { ident_expr { name: 'x' } } + } + } + args { + call_expr { + target { ident_expr { name: 'x' } } + function: 'size' + } + } + })pb", + "size(x) == x.size()"}, + + // Long string + {R"pb( + list_expr { + elements { + const_expr { + string_value: 'Loooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooong' + } + } + })pb", + R"(["Loooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooong"])"}})); + +struct UnparserTestCaseTextExpr { + std::string expr; + std::string equiv_expected; +}; + +class UnparserTestTextExpr + : public testing::TestWithParam {}; + +TEST_P(UnparserTestTextExpr, Test) { + Expr expr; + + parser::ParserOptions options; + options.add_macro_calls = true; + options.enable_optional_syntax = true; + options.enable_quoted_identifiers = true; + options.enable_variadic_logical_operators = true; + + ASSERT_OK_AND_ASSIGN(ParsedExpr result, + Parse(GetParam().expr, "unparser", options)); + + ASSERT_OK_AND_ASSIGN(std::string result_expr, Unparse(result)); + + if (!GetParam().equiv_expected.empty()) { + ASSERT_EQ(GetParam().equiv_expected, result_expr); + } else { + ASSERT_EQ(GetParam().expr, result_expr); + } + + if (GetParam().equiv_expected.empty()) { + // parse again, confirm it's the same result + ASSERT_OK_AND_ASSIGN(ParsedExpr result2, + Parse(result_expr, "unparser", options)); + EXPECT_THAT(result, EqualsProto(result2)); + } else { + // We cannot compare the original parsed proto and the equivalent expected + // proto, since the IDs will most likely be different, e.g., due to + // rebalancing logical expressions. + } +} + +// These test cases check that Unparse(Parse(expr)) is idempotent +// (if there is one string in an entry), or equivalent to some other +// form (if there are two strings in an entry). The latter can occur +// especially due to spacing in the expression, or if the logical +// expression balancer modifies an expression. +INSTANTIATE_TEST_SUITE_P( + UnparseCompExpr, UnparserTestTextExpr, + ValuesIn({ + {"a + b - c", ""}, + {"a && b && c && d && e", ""}, + {"a || b && (c || d) && e", ""}, + {"a ? b : c", ""}, + {"a[1][\"b\"]", ""}, + {"x[\"a\"].single_int32 == 23", ""}, + {"a * (b / c) % 0", ""}, + {"a + b * c", ""}, + {"(a + b) * c / (d - e)", ""}, + {"a * b / c % 0", ""}, + {"!true", ""}, + {"-num", ""}, + {"a || b || c || d || e", ""}, + {"-(1 * 2)", ""}, + {"-(1 + 2)", ""}, + {"(x > 5) ? (x - 5) : 0", ""}, + {"size(a ? (b ? c : d) : e)", ""}, + {"a.hello(\"world\")", ""}, + {"zero()", ""}, + {"one(\"a\")", ""}, + {"and(d, 32u)", ""}, + {"max(a, b, 100)", ""}, + {"x != \"a\"", ""}, + {"[]", ""}, + {"[1]", ""}, + {"[\"hello, world\", \"goodbye, world\", \"sure, why not?\"]", ""}, + {"b\"ÿ\"", "b\"\\xc3\\x83\\xc2\\xbf\""}, + {"b'aaa\"bbb'", "b\"aaa\\\"bbb\""}, + {"-42.101", ""}, + {"false", ""}, + {"-405069", ""}, + {"null", ""}, + {"\"hello:\\t'world'\"", ""}, + {"true", ""}, + {"42u", ""}, + {"my_ident", ""}, + {"has(hello.world)", ""}, + {"{}", ""}, + {"{\"a\": a.b.c, b\"b\": bytes(a.b.c)}", ""}, + {"{a: a, b: a.b, c: a.b.c, a ? b : c: false, a || b: true}", ""}, + {"v1alpha1.Expr{}", ""}, + {"v1alpha1.Expr{id: 1, call_expr: v1alpha1.Call_Expr{function: " + "\"name\"}}", + ""}, + {"a.b.c", ""}, + {"a[b][c].name", ""}, + {"(a + b).name", ""}, + {"(a ? b : c).name", ""}, + {"(a ? b : c)[0]", ""}, + {"(a1 && a2) ? b : c", ""}, + {"a ? (b1 || b2) : (c1 && c2)", ""}, + {"(a ? b : c).method(d)", ""}, + + // the following give the expected equivalent representation that + // is to be observed when parsing and decompiling again, note the + // differences in spacing and simplification of logical expressions + {"a+b-c", "a + b - c"}, + {"a ? b : c", "a ? b : c"}, + {"a[ 1 ][\"b\"]", "a[1][\"b\"]"}, + {"(false && !true) || false", "false && !true || false"}, + {"a . b . c", "a.b.c"}, + // here we expect the expression balancer to remove the double negation + {"!!true", "true"}, + + // From protos above + // Constants + {"true", ""}, + {"4", ""}, + {"4u", ""}, + + // Sequences + {"[1, 2u]", ""}, + {"{1: 2u, 2: 3u}", ""}, + + // Messages + {"TestAllTypes{single_int32: 1, single_int64: 2}", ""}, + + // Conditionals + {"false && !true || false", ""}, + {"false && (!true || false)", ""}, + {"(false && !true || false) ? 2 : 3", ""}, + {"(x < 5) ? x : 5", ""}, + {"(x > 5) ? (x - 5) : 0", ""}, + {"(x > 5) ? ((x > 10) ? (x - 10) : 5) : 0", ""}, + {"a in b", ""}, + + // Calculations + {"(1 + 2) * 3", ""}, + {"1 + 2 * 3", ""}, + {"-(1 * 2)", ""}, + + // Comprehensions + {"[1, 2, 3].all(x, x > 0)", ""}, + {"[1, 2, 3].exists(x, x > 0)", ""}, + {"[1, 2, 3].map(x, x >= 2, x * 4)", ""}, + {"[1, 2, 3].exists_one(x, x >= 2)", ""}, + {"[[1], [2], [3]].all(x, x.all(y, y >= 2))", ""}, + {"(has(x.y) ? x.y : []).filter(z, z == \"zed\")", ""}, + + // Macros + {"has(x[\"a\"].single_int32)", ""}, + + // This is a filter expression but is decompiled back to + // map(x, filter_function, x) for which the evaluation is + // equal to filter(x, filter_function). + {"[1, 2, 3].map(x, x >= 2, x)", ""}, + + // Index + {"x[\"a\"].single_int32 == 23", ""}, + {"a[1][\"b\"]", ""}, + + // Functions + {"x != \"a\"", ""}, + {"size(x) == x.size()", ""}, + + // Long string + {R"(["Loooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooong"])", + ""}, + {"a.?b[?0] && a[?c]", ""}, + {"{?\"key\": value}", ""}, + {"[?a, ?b]", ""}, + {"[?a[?b]]", ""}, + {"Msg{?field: value}", ""}, + {"Msg{`in`: value}", ""}, + {"Msg{?`b.c`: value}", ""}, + {"has(a.`b.c`)", ""}, + {"a.`b/c`", ""}, + {"a.?`b/c`", ""}, + {"a && b && c && d", ""}, + {"a || b || c || d", ""}, + })); + +} // namespace +} // namespace google::api::expr diff --git a/tools/descriptor_pool_builder.cc b/tools/descriptor_pool_builder.cc new file mode 100644 index 000000000..390363435 --- /dev/null +++ b/tools/descriptor_pool_builder.cc @@ -0,0 +1,111 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tools/descriptor_pool_builder.h" + +#include +#include + +#include "google/protobuf/descriptor.pb.h" +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "common/minimal_descriptor_database.h" +#include "internal/status_macros.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +namespace { + +absl::Status FindDeps( + std::vector& to_resolve, + absl::flat_hash_set& resolved, + DescriptorPoolBuilder& builder) { + while (!to_resolve.empty()) { + const auto* file = to_resolve.back(); + to_resolve.pop_back(); + if (resolved.contains(file)) { + continue; + } + google::protobuf::FileDescriptorProto file_proto; + file->CopyTo(&file_proto); + // Note: order doesn't matter here as long as all the cross references are + // correct in the final database. + CEL_RETURN_IF_ERROR(builder.AddFileDescriptor(file_proto)); + resolved.insert(file); + for (int i = 0; i < file->dependency_count(); ++i) { + to_resolve.push_back(file->dependency(i)); + } + } + return absl::OkStatus(); +} + +} // namespace + +DescriptorPoolBuilder::StateHolder::StateHolder( + google::protobuf::DescriptorDatabase* base) + : base(base), merged(base, &extensions), pool(&merged) {} + +DescriptorPoolBuilder::DescriptorPoolBuilder() + : state_(std::make_shared( + cel::GetMinimalDescriptorDatabase())) {} + +std::shared_ptr +DescriptorPoolBuilder::Build() && { + auto alias = + std::shared_ptr(state_, &state_->pool); + state_.reset(); + return alias; +} + +absl::Status DescriptorPoolBuilder::AddTransitiveDescriptorSet( + const google::protobuf::Descriptor* absl_nonnull desc) { + absl::flat_hash_set resolved; + std::vector to_resolve{desc->file()}; + return FindDeps(to_resolve, resolved, *this); +} + +absl::Status DescriptorPoolBuilder::AddTransitiveDescriptorSet( + absl::Span descs) { + absl::flat_hash_set resolved; + std::vector to_resolve; + to_resolve.reserve(descs.size()); + for (const google::protobuf::Descriptor* desc : descs) { + to_resolve.push_back(desc->file()); + } + + return FindDeps(to_resolve, resolved, *this); +} + +absl::Status DescriptorPoolBuilder::AddFileDescriptor( + const google::protobuf::FileDescriptorProto& file) { + if (!state_->extensions.Add(file)) { + return absl::InvalidArgumentError( + absl::StrCat("proto descriptor conflict: ", file.name())); + } + return absl::OkStatus(); +} + +absl::Status DescriptorPoolBuilder::AddFileDescriptorSet( + const google::protobuf::FileDescriptorSet& file) { + for (const auto& file : file.file()) { + CEL_RETURN_IF_ERROR(AddFileDescriptor(file)); + } + return absl::OkStatus(); +} + +} // namespace cel diff --git a/tools/descriptor_pool_builder.h b/tools/descriptor_pool_builder.h new file mode 100644 index 000000000..3a57ec2fd --- /dev/null +++ b/tools/descriptor_pool_builder.h @@ -0,0 +1,93 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_TOOLS_DESCRIPTOR_POOL_BUILDER_H_ +#define THIRD_PARTY_CEL_CPP_TOOLS_DESCRIPTOR_POOL_BUILDER_H_ + +#include +#include + +#include "google/protobuf/descriptor.pb.h" +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/types/span.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/descriptor_database.h" + +namespace cel { + +// A helper class for building a descriptor pool from a set proto file +// descriptors. Manages lifetime for the descriptor databases backing +// the pool. +// +// Client must ensure that types are not added multiple times. +// +// Note: in the constructed pool, the definitions for the required types for +// CEL will shadow any added to the builder. Clients should not modify types +// from the google.protobuf package in general, but if they do the behavior of +// the constructed descriptor pool will be inconsistent. +class DescriptorPoolBuilder { + public: + DescriptorPoolBuilder(); + + DescriptorPoolBuilder& operator=(const DescriptorPoolBuilder&) = delete; + DescriptorPoolBuilder(const DescriptorPoolBuilder&) = delete; + DescriptorPoolBuilder& operator=(const DescriptorPoolBuilder&&) = delete; + DescriptorPoolBuilder(DescriptorPoolBuilder&&) = delete; + + ~DescriptorPoolBuilder() = default; + + // Returns a shared pointer to the new descriptor pool that manages the + // underlying descriptor databases backing the pool. + // + // Consumes the builder instance. It is unsafe to make any further changes + // to the descriptor databases after accessing the pool. + std::shared_ptr Build() &&; + + // Utility for adding the transitive dependencies of a message with a linked + // descriptor. + absl::Status AddTransitiveDescriptorSet( + const google::protobuf::Descriptor* absl_nonnull desc); + + absl::Status AddTransitiveDescriptorSet( + absl::Span); + + // Adds a file descriptor set to the pool. Client must ensure that all + // dependencies are satisfied and that files are not added multiple times. + absl::Status AddFileDescriptorSet(const google::protobuf::FileDescriptorSet& files); + + // Adds a single proto file descriptor set to the pool. Client must ensure + // that all dependencies are satisfied and that files are not added multiple + // times. + absl::Status AddFileDescriptor(const google::protobuf::FileDescriptorProto& file); + + private: + struct StateHolder { + explicit StateHolder(google::protobuf::DescriptorDatabase* base); + + google::protobuf::DescriptorDatabase* base; + google::protobuf::SimpleDescriptorDatabase extensions; + google::protobuf::MergedDescriptorDatabase merged; + google::protobuf::DescriptorPool pool; + }; + + explicit DescriptorPoolBuilder(std::shared_ptr state) + : state_(std::move(state)) {} + + std::shared_ptr state_; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_TOOLS_DESCRIPTOR_POOL_BUILDER_H_ diff --git a/tools/descriptor_pool_builder_test.cc b/tools/descriptor_pool_builder_test.cc new file mode 100644 index 000000000..82fa8f699 --- /dev/null +++ b/tools/descriptor_pool_builder_test.cc @@ -0,0 +1,177 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tools/descriptor_pool_builder.h" + +#include + +#include "google/protobuf/descriptor.pb.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "internal/testing.h" +#include "cel/expr/conformance/proto2/test_all_types.pb.h" +#include "cel/expr/conformance/proto2/test_all_types_extensions.pb.h" +#include "google/protobuf/text_format.h" + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::testing::IsNull; +using ::testing::NotNull; + +namespace cel { +namespace { + +TEST(DescriptorPoolBuilderTest, IncludesDefaults) { + DescriptorPoolBuilder builder; + + auto pool = std::move(builder).Build(); + EXPECT_THAT( + pool->FindMessageTypeByName("cel.expr.conformance.proto2.TestAllTypes"), + IsNull()); + + EXPECT_THAT(pool->FindMessageTypeByName("google.protobuf.Timestamp"), + NotNull()); + EXPECT_THAT(pool->FindMessageTypeByName("google.protobuf.Any"), NotNull()); +} + +TEST(DescriptorPoolBuilderTest, AddTransitiveDescriptorSet) { + DescriptorPoolBuilder builder; + ASSERT_THAT(builder.AddTransitiveDescriptorSet( + cel::expr::conformance::proto2::Proto2ExtensionScopedMessage:: + descriptor()), + IsOk()); + + auto pool = std::move(builder).Build(); + EXPECT_THAT( + pool->FindMessageTypeByName("cel.expr.conformance.proto2.TestAllTypes"), + NotNull()); +} + +TEST(DescriptorPoolBuilderTest, AddTransitiveDescriptorSetSpan) { + DescriptorPoolBuilder builder; + const google::protobuf::Descriptor* descs[] = { + cel::expr::conformance::proto2::TestAllTypes::descriptor(), + cel::expr::conformance::proto2::Proto2ExtensionScopedMessage:: + descriptor()}; + ASSERT_THAT(builder.AddTransitiveDescriptorSet(descs), IsOk()); + + auto pool = std::move(builder).Build(); + EXPECT_THAT( + pool->FindMessageTypeByName("cel.expr.conformance.proto2.TestAllTypes"), + NotNull()); +} + +TEST(DescriptorPoolBuilderTest, AddFileDescriptorSet) { + DescriptorPoolBuilder builder; + google::protobuf::FileDescriptorSet file_set; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb( + name: "foo.proto" + package: "cel.test" + dependency: "bar.proto" + message_type { + name: "Foo" + field: { + name: "bar" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".cel.test.Bar" + } + } + )pb", + file_set.add_file())); + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb( + name: "bar.proto" + package: "cel.test" + message_type { + name: "Bar" + field: { + name: "baz" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + } + )pb", + file_set.add_file())); + ASSERT_THAT(builder.AddFileDescriptorSet(file_set), IsOk()); + + auto pool = std::move(builder).Build(); + EXPECT_THAT(pool->FindMessageTypeByName("cel.test.Foo"), NotNull()); + EXPECT_THAT(pool->FindMessageTypeByName("cel.test.Bar"), NotNull()); +} + +TEST(DescriptorPoolBuilderTest, BadRef) { + DescriptorPoolBuilder builder; + google::protobuf::FileDescriptorSet file_set; + // Unfulfilled dependency. + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb( + name: "foo.proto" + package: "cel.test" + dependency: "bar.proto" + message_type { + name: "Foo" + field: { + name: "bar" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".cel.test.Bar" + } + } + )pb", + file_set.add_file())); + // Note: descriptor pool is initialized lazily so this will not lead to an + // error now, but looking up the message will fail. + ASSERT_THAT(builder.AddFileDescriptorSet(file_set), IsOk()); + + auto pool = std::move(builder).Build(); + EXPECT_THAT(pool->FindMessageTypeByName("cel.test.Foo"), IsNull()); +} + +TEST(DescriptorPoolBuilderTest, AddFile) { + DescriptorPoolBuilder builder; + google::protobuf::FileDescriptorProto file; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb( + name: "bar.proto" + package: "cel.test" + message_type { + name: "Bar" + field: { + name: "baz" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + } + )pb", + &file)); + + ASSERT_THAT(builder.AddFileDescriptor(file), IsOk()); + // Duplicate file. + ASSERT_THAT(builder.AddFileDescriptor(file), + StatusIs(absl::StatusCode::kInvalidArgument)); + + // In this specific case, we know that the duplicate is the same so + // the pool will still be valid. + auto pool = std::move(builder).Build(); + EXPECT_THAT(pool->FindMessageTypeByName("cel.test.Bar"), NotNull()); +} + +} // namespace +} // namespace cel diff --git a/tools/flatbuffers_backed_impl.cc b/tools/flatbuffers_backed_impl.cc index 55f3e4852..2ee226859 100644 --- a/tools/flatbuffers_backed_impl.cc +++ b/tools/flatbuffers_backed_impl.cc @@ -2,6 +2,11 @@ #include +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/types/optional.h" +#include "eval/public/cel_value.h" #include "flatbuffers/flatbuffers.h" namespace google { @@ -34,7 +39,7 @@ class FlatBuffersListImpl : public CelList { class StringListImpl : public CelList { public: - StringListImpl( + explicit StringListImpl( const flatbuffers::Vector>* list) : list_(list) {} int size() const override { return list_ ? list_->size() : 0; } @@ -80,12 +85,29 @@ class ObjectStringIndexedMapImpl : public CelMap { schema_(schema), object_(object), index_(index) { - keys_.parent_ = this; + keys_.parent = this; } + int size() const override { return list_ ? list_->size() : 0; } + + absl::StatusOr Has(const CelValue& key) const override { + auto lookup_result = (*this)[key]; + if (!lookup_result.has_value()) { + return false; + } + auto result = *lookup_result; + if (result.IsError()) { + return *(result.ErrorOrDie()); + } + return true; + } + absl::optional operator[](CelValue cel_key) const override { if (!cel_key.IsString()) { - return {}; + return CreateErrorValue( + arena_, absl::InvalidArgumentError( + absl::StrCat("Invalid map key type: '", + CelValue::TypeName(cel_key.type()), "'"))); } const absl::string_view key = cel_key.StringOrDie().value(); const auto it = std::lower_bound( @@ -105,23 +127,24 @@ class ObjectStringIndexedMapImpl : public CelMap { arena_, **it, schema_, object_, arena_)); } } - return {}; + return std::nullopt; } - const CelList* ListKeys() const override { return &keys_; } + + absl::StatusOr ListKeys() const override { return &keys_; } private: struct KeyList : public CelList { - int size() const override { return parent_->size(); } + int size() const override { return parent->size(); } CelValue operator[](int index) const override { - auto value = flatbuffers::GetFieldS(*(parent_->list_->Get(index)), - parent_->index_); + auto value = + flatbuffers::GetFieldS(*(parent->list_->Get(index)), parent->index_); if (value == nullptr) { return CelValue::CreateStringView(absl::string_view()); } return CelValue::CreateStringView( absl::string_view(value->c_str(), value->size())); } - ObjectStringIndexedMapImpl* parent_; + ObjectStringIndexedMapImpl* parent; }; google::protobuf::Arena* arena_; const flatbuffers::Vector>* list_; @@ -143,14 +166,29 @@ const reflection::Field* findStringKeyField(const reflection::Object& object) { } // namespace +absl::StatusOr FlatBuffersMapImpl::Has(const CelValue& key) const { + auto lookup_result = (*this)[key]; + if (!lookup_result.has_value()) { + return false; + } + auto result = *lookup_result; + if (result.IsError()) { + return *(result.ErrorOrDie()); + } + return true; +} + absl::optional FlatBuffersMapImpl::operator[]( CelValue cel_key) const { if (!cel_key.IsString()) { - return {}; + return CreateErrorValue( + arena_, absl::InvalidArgumentError( + absl::StrCat("Invalid map key type: '", + CelValue::TypeName(cel_key.type()), "'"))); } - auto field = keys_.fields_->LookupByKey(cel_key.StringOrDie().value().data()); + auto field = keys_.fields->LookupByKey(cel_key.StringOrDie().value().data()); if (field == nullptr) { - return {}; + return std::nullopt; } switch (field->type()->base_type()) { case reflection::Byte: @@ -285,15 +323,15 @@ absl::optional FlatBuffersMapImpl::operator[]( } default: // Unsupported vector base types - return {}; + return std::nullopt; } break; } default: // Unsupported types: enums, unions, arrays - return {}; + return std::nullopt; } - return {}; + return std::nullopt; } const CelMap* CreateFlatBuffersBackedObject(const uint8_t* flatbuf, diff --git a/tools/flatbuffers_backed_impl.h b/tools/flatbuffers_backed_impl.h index 2fe9b9b02..7051ef5d5 100644 --- a/tools/flatbuffers_backed_impl.h +++ b/tools/flatbuffers_backed_impl.h @@ -15,21 +15,28 @@ class FlatBuffersMapImpl : public CelMap { const reflection::Schema& schema, const reflection::Object& object, google::protobuf::Arena* arena) : arena_(arena), table_(table), schema_(schema) { - keys_.fields_ = object.fields(); + keys_.fields = object.fields(); } - int size() const override { return keys_.fields_->size(); } + + int size() const override { return keys_.fields->size(); } + + absl::StatusOr Has(const CelValue& key) const override; + absl::optional operator[](CelValue cel_key) const override; - const CelList* ListKeys() const override { return &keys_; } + + // Import base class signatures to bypass GCC warning/error. + using CelMap::ListKeys; + absl::StatusOr ListKeys() const override { return &keys_; } private: struct FieldList : public CelList { - int size() const override { return fields_->size(); } + int size() const override { return fields->size(); } CelValue operator[](int index) const override { - auto name = fields_->Get(index)->name(); + auto name = fields->Get(index)->name(); return CelValue::CreateStringView( absl::string_view(name->c_str(), name->size())); } - const flatbuffers::Vector>* fields_; + const flatbuffers::Vector>* fields; }; FieldList keys_; google::protobuf::Arena* arena_; diff --git a/tools/flatbuffers_backed_impl_test.cc b/tools/flatbuffers_backed_impl_test.cc index e2328f332..55589bfd5 100644 --- a/tools/flatbuffers_backed_impl_test.cc +++ b/tools/flatbuffers_backed_impl_test.cc @@ -1,7 +1,9 @@ #include "tools/flatbuffers_backed_impl.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" +#include + +#include "internal/status_macros.h" +#include "internal/testing.h" #include "flatbuffers/idl.h" #include "flatbuffers/reflection.h" @@ -12,10 +14,9 @@ namespace runtime { namespace { -using google::protobuf::Arena; - constexpr char kReflectionBufferPath[] = - "tools/flatbuffers.bfbs"; + "tools/testdata/" + "flatbuffers.bfbs"; constexpr absl::string_view kByteField = "f_byte"; constexpr absl::string_view kUbyteField = "f_ubyte"; @@ -70,7 +71,7 @@ class FlatBuffersTest : public testing::Test { parser_.builder_.GetBufferPointer(), *schema_, &arena_); EXPECT_NE(nullptr, value); EXPECT_EQ(kNumFields, value->size()); - const CelList* keys = value->ListKeys(); + const CelList* keys = value->ListKeys().value(); EXPECT_NE(nullptr, keys); EXPECT_EQ(kNumFields, keys->size()); EXPECT_TRUE((*keys)[2].IsString()); @@ -103,86 +104,92 @@ TEST_F(FlatBuffersTest, PrimitiveFields) { { auto f = value[CelValue::CreateStringView(kByteField)]; EXPECT_TRUE(f.has_value()); - EXPECT_TRUE(f.value().IsInt64()); - EXPECT_EQ(-1, f.value().Int64OrDie()); + EXPECT_TRUE(f->IsInt64()); + EXPECT_EQ(-1, f->Int64OrDie()); } { auto uf = value[CelValue::CreateStringView(kUbyteField)]; EXPECT_TRUE(uf.has_value()); - EXPECT_TRUE(uf.value().IsUint64()); - EXPECT_EQ(1, uf.value().Uint64OrDie()); + EXPECT_TRUE(uf->IsUint64()); + EXPECT_EQ(1, uf->Uint64OrDie()); } // short { auto f = value[CelValue::CreateStringView(kShortField)]; EXPECT_TRUE(f.has_value()); - EXPECT_TRUE(f.value().IsInt64()); - EXPECT_EQ(-2, f.value().Int64OrDie()); + EXPECT_TRUE(f->IsInt64()); + EXPECT_EQ(-2, f->Int64OrDie()); } { auto uf = value[CelValue::CreateStringView(kUshortField)]; EXPECT_TRUE(uf.has_value()); - EXPECT_TRUE(uf.value().IsUint64()); - EXPECT_EQ(2, uf.value().Uint64OrDie()); + EXPECT_TRUE(uf->IsUint64()); + EXPECT_EQ(2, uf->Uint64OrDie()); } // int { auto f = value[CelValue::CreateStringView(kIntField)]; EXPECT_TRUE(f.has_value()); - EXPECT_TRUE(f.value().IsInt64()); - EXPECT_EQ(-3, f.value().Int64OrDie()); + EXPECT_TRUE(f->IsInt64()); + EXPECT_EQ(-3, f->Int64OrDie()); } { auto uf = value[CelValue::CreateStringView(kUintField)]; EXPECT_TRUE(uf.has_value()); - EXPECT_TRUE(uf.value().IsUint64()); - EXPECT_EQ(3, uf.value().Uint64OrDie()); + EXPECT_TRUE(uf->IsUint64()); + EXPECT_EQ(3, uf->Uint64OrDie()); } // long { auto f = value[CelValue::CreateStringView(kLongField)]; EXPECT_TRUE(f.has_value()); - EXPECT_TRUE(f.value().IsInt64()); - EXPECT_EQ(-4, f.value().Int64OrDie()); + EXPECT_TRUE(f->IsInt64()); + EXPECT_EQ(-4, f->Int64OrDie()); } { auto uf = value[CelValue::CreateStringView(kUlongField)]; EXPECT_TRUE(uf.has_value()); - EXPECT_TRUE(uf.value().IsUint64()); - EXPECT_EQ(4, uf.value().Uint64OrDie()); + EXPECT_TRUE(uf->IsUint64()); + EXPECT_EQ(4, uf->Uint64OrDie()); } // float and double { auto f = value[CelValue::CreateStringView(kFloatField)]; EXPECT_TRUE(f.has_value()); - EXPECT_TRUE(f.value().IsDouble()); - EXPECT_EQ(5.0, f.value().DoubleOrDie()); + EXPECT_TRUE(f->IsDouble()); + EXPECT_EQ(5.0, f->DoubleOrDie()); } { auto f = value[CelValue::CreateStringView(kDoubleField)]; EXPECT_TRUE(f.has_value()); - EXPECT_TRUE(f.value().IsDouble()); - EXPECT_EQ(6.0, f.value().DoubleOrDie()); + EXPECT_TRUE(f->IsDouble()); + EXPECT_EQ(6.0, f->DoubleOrDie()); } // bool { auto f = value[CelValue::CreateStringView(kBoolField)]; EXPECT_TRUE(f.has_value()); - EXPECT_TRUE(f.value().IsBool()); - EXPECT_EQ(false, f.value().BoolOrDie()); + EXPECT_TRUE(f->IsBool()); + EXPECT_EQ(false, f->BoolOrDie()); } // string { auto f = value[CelValue::CreateStringView(kStringField)]; EXPECT_TRUE(f.has_value()); - EXPECT_TRUE(f.value().IsString()); - EXPECT_EQ("test", f.value().StringOrDie().value()); + EXPECT_TRUE(f->IsString()); + EXPECT_EQ("test", f->StringOrDie().value()); } - // missing field + // bad field type { - auto f = value[CelValue::CreateInt64(1)]; - EXPECT_FALSE(f.has_value()); + CelValue bad_field = CelValue::CreateInt64(1); + auto f = value[bad_field]; + EXPECT_TRUE(f.has_value()); + EXPECT_TRUE(f->IsError()); + auto presence = value.Has(bad_field); + EXPECT_FALSE(presence.ok()); + EXPECT_EQ(presence.status().code(), absl::StatusCode::kInvalidArgument); } + // missing field { auto f = value[CelValue::CreateStringView(kUnknownField)]; EXPECT_FALSE(f.has_value()); @@ -195,29 +202,29 @@ TEST_F(FlatBuffersTest, PrimitiveFieldDefaults) { { auto f = value[CelValue::CreateStringView(kByteField)]; EXPECT_TRUE(f.has_value()); - EXPECT_TRUE(f.value().IsInt64()); - EXPECT_EQ(0, f.value().Int64OrDie()); + EXPECT_TRUE(f->IsInt64()); + EXPECT_EQ(0, f->Int64OrDie()); } // short { auto f = value[CelValue::CreateStringView(kShortField)]; EXPECT_TRUE(f.has_value()); - EXPECT_TRUE(f.value().IsInt64()); - EXPECT_EQ(150, f.value().Int64OrDie()); + EXPECT_TRUE(f->IsInt64()); + EXPECT_EQ(150, f->Int64OrDie()); } // bool { auto f = value[CelValue::CreateStringView(kBoolField)]; EXPECT_TRUE(f.has_value()); - EXPECT_TRUE(f.value().IsBool()); - EXPECT_EQ(true, f.value().BoolOrDie()); + EXPECT_TRUE(f->IsBool()); + EXPECT_EQ(true, f->BoolOrDie()); } // string { auto f = value[CelValue::CreateStringView(kStringField)]; EXPECT_TRUE(f.has_value()); - EXPECT_TRUE(f.value().IsString()); - EXPECT_EQ("", f.value().StringOrDie().value()); + EXPECT_TRUE(f->IsString()); + EXPECT_EQ("", f->StringOrDie().value()); } } @@ -228,22 +235,47 @@ TEST_F(FlatBuffersTest, ObjectField) { f_int: 16 } })"); - auto f = value[CelValue::CreateStringView(kObjField)]; + CelValue field = CelValue::CreateStringView(kObjField); + auto presence = value.Has(field); + EXPECT_OK(presence); + EXPECT_TRUE(*presence); + auto f = value[field]; EXPECT_TRUE(f.has_value()); - EXPECT_TRUE(f.value().IsMap()); - const CelMap& m = *f.value().MapOrDie(); + EXPECT_TRUE(f->IsMap()); + const CelMap& m = *f->MapOrDie(); EXPECT_EQ(2, m.size()); { - auto mf = m[CelValue::CreateStringView(kStringField)]; + auto obj_field = CelValue::CreateStringView(kStringField); + auto member_presence = m.Has(obj_field); + EXPECT_OK(member_presence); + EXPECT_TRUE(*member_presence); + auto mf = m[obj_field]; EXPECT_TRUE(mf.has_value()); - EXPECT_TRUE(mf.value().IsString()); - EXPECT_EQ("entry", mf.value().StringOrDie().value()); + EXPECT_TRUE(mf->IsString()); + EXPECT_EQ("entry", mf->StringOrDie().value()); } { - auto mf = m[CelValue::CreateStringView(kIntField)]; + auto obj_field = CelValue::CreateStringView(kIntField); + auto member_presence = m.Has(obj_field); + EXPECT_OK(member_presence); + EXPECT_TRUE(*member_presence); + auto mf = m[obj_field]; EXPECT_TRUE(mf.has_value()); - EXPECT_TRUE(mf.value().IsInt64()); - EXPECT_EQ(16, mf.value().Int64OrDie()); + EXPECT_TRUE(mf->IsInt64()); + EXPECT_EQ(16, mf->Int64OrDie()); + } + { + std::string undefined = "f_undefined"; + CelValue undefined_field = CelValue::CreateStringView(undefined); + auto presence = m.Has(undefined_field); + EXPECT_OK(presence); + EXPECT_FALSE(*presence); + auto v = m[undefined_field]; + EXPECT_FALSE(v.has_value()); + + presence = m.Has(CelValue::CreateBool(false)); + EXPECT_FALSE(presence.ok()); + EXPECT_EQ(presence.status().code(), absl::StatusCode::kInvalidArgument); } } @@ -251,7 +283,7 @@ TEST_F(FlatBuffersTest, ObjectFieldDefault) { const CelMap& value = loadJson("{}"); auto f = value[CelValue::CreateStringView(kObjField)]; EXPECT_TRUE(f.has_value()); - EXPECT_TRUE(f.value().IsNull()); + EXPECT_TRUE(f->IsNull()); } TEST_F(FlatBuffersTest, PrimitiveVectorFields) { @@ -273,29 +305,29 @@ TEST_F(FlatBuffersTest, PrimitiveVectorFields) { { auto f = value[CelValue::CreateStringView(kBytesField)]; EXPECT_TRUE(f.has_value()); - EXPECT_TRUE(f.value().IsBytes()); - EXPECT_EQ("\x9F", f.value().BytesOrDie().value()); + EXPECT_TRUE(f->IsBytes()); + EXPECT_EQ("\x9F", f->BytesOrDie().value()); } { auto uf = value[CelValue::CreateStringView(kUbytesField)]; EXPECT_TRUE(uf.has_value()); - EXPECT_TRUE(uf.value().IsBytes()); - EXPECT_EQ("abc", uf.value().BytesOrDie().value()); + EXPECT_TRUE(uf->IsBytes()); + EXPECT_EQ("abc", uf->BytesOrDie().value()); } // short { auto f = value[CelValue::CreateStringView(kShortsField)]; EXPECT_TRUE(f.has_value()); - EXPECT_TRUE(f.value().IsList()); - const CelList& l = *f.value().ListOrDie(); + EXPECT_TRUE(f->IsList()); + const CelList& l = *f->ListOrDie(); EXPECT_EQ(1, l.size()); EXPECT_EQ(-2, l[0].Int64OrDie()); } { auto uf = value[CelValue::CreateStringView(kUshortsField)]; EXPECT_TRUE(uf.has_value()); - EXPECT_TRUE(uf.value().IsList()); - const CelList& l = *uf.value().ListOrDie(); + EXPECT_TRUE(uf->IsList()); + const CelList& l = *uf->ListOrDie(); EXPECT_EQ(1, l.size()); EXPECT_EQ(2, l[0].Uint64OrDie()); } @@ -303,16 +335,16 @@ TEST_F(FlatBuffersTest, PrimitiveVectorFields) { { auto f = value[CelValue::CreateStringView(kIntsField)]; EXPECT_TRUE(f.has_value()); - EXPECT_TRUE(f.value().IsList()); - const CelList& l = *f.value().ListOrDie(); + EXPECT_TRUE(f->IsList()); + const CelList& l = *f->ListOrDie(); EXPECT_EQ(1, l.size()); EXPECT_EQ(-3, l[0].Int64OrDie()); } { auto uf = value[CelValue::CreateStringView(kUintsField)]; EXPECT_TRUE(uf.has_value()); - EXPECT_TRUE(uf.value().IsList()); - const CelList& l = *uf.value().ListOrDie(); + EXPECT_TRUE(uf->IsList()); + const CelList& l = *uf->ListOrDie(); EXPECT_EQ(1, l.size()); EXPECT_EQ(3, l[0].Uint64OrDie()); } @@ -320,16 +352,16 @@ TEST_F(FlatBuffersTest, PrimitiveVectorFields) { { auto f = value[CelValue::CreateStringView(kLongsField)]; EXPECT_TRUE(f.has_value()); - EXPECT_TRUE(f.value().IsList()); - const CelList& l = *f.value().ListOrDie(); + EXPECT_TRUE(f->IsList()); + const CelList& l = *f->ListOrDie(); EXPECT_EQ(1, l.size()); EXPECT_EQ(-4, l[0].Int64OrDie()); } { auto uf = value[CelValue::CreateStringView(kUlongsField)]; EXPECT_TRUE(uf.has_value()); - EXPECT_TRUE(uf.value().IsList()); - const CelList& l = *uf.value().ListOrDie(); + EXPECT_TRUE(uf->IsList()); + const CelList& l = *uf->ListOrDie(); EXPECT_EQ(1, l.size()); EXPECT_EQ(4, l[0].Uint64OrDie()); } @@ -337,16 +369,16 @@ TEST_F(FlatBuffersTest, PrimitiveVectorFields) { { auto f = value[CelValue::CreateStringView(kFloatsField)]; EXPECT_TRUE(f.has_value()); - EXPECT_TRUE(f.value().IsList()); - const CelList& l = *f.value().ListOrDie(); + EXPECT_TRUE(f->IsList()); + const CelList& l = *f->ListOrDie(); EXPECT_EQ(1, l.size()); EXPECT_EQ(5.0, l[0].DoubleOrDie()); } { auto f = value[CelValue::CreateStringView(kDoublesField)]; EXPECT_TRUE(f.has_value()); - EXPECT_TRUE(f.value().IsList()); - const CelList& l = *f.value().ListOrDie(); + EXPECT_TRUE(f->IsList()); + const CelList& l = *f->ListOrDie(); EXPECT_EQ(1, l.size()); EXPECT_EQ(6.0, l[0].DoubleOrDie()); } @@ -354,8 +386,8 @@ TEST_F(FlatBuffersTest, PrimitiveVectorFields) { { auto f = value[CelValue::CreateStringView(kBoolsField)]; EXPECT_TRUE(f.has_value()); - EXPECT_TRUE(f.value().IsList()); - const CelList& l = *f.value().ListOrDie(); + EXPECT_TRUE(f->IsList()); + const CelList& l = *f->ListOrDie(); EXPECT_EQ(1, l.size()); EXPECT_EQ(false, l[0].BoolOrDie()); } @@ -363,8 +395,8 @@ TEST_F(FlatBuffersTest, PrimitiveVectorFields) { { auto f = value[CelValue::CreateStringView(kStringsField)]; EXPECT_TRUE(f.has_value()); - EXPECT_TRUE(f.value().IsList()); - const CelList& l = *f.value().ListOrDie(); + EXPECT_TRUE(f->IsList()); + const CelList& l = *f->ListOrDie(); EXPECT_EQ(1, l.size()); EXPECT_EQ("test", l[0].StringOrDie().value()); } @@ -381,24 +413,32 @@ TEST_F(FlatBuffersTest, ObjectVectorField) { })"); auto f = value[CelValue::CreateStringView(kObjsField)]; EXPECT_TRUE(f.has_value()); - EXPECT_TRUE(f.value().IsList()); - const CelList& l = *f.value().ListOrDie(); + EXPECT_TRUE(f->IsList()); + const CelList& l = *f->ListOrDie(); EXPECT_EQ(2, l.size()); { EXPECT_TRUE(l[0].IsMap()); const CelMap& m = *l[0].MapOrDie(); EXPECT_EQ(2, m.size()); { - auto mf = m[CelValue::CreateStringView(kStringField)]; + CelValue field = CelValue::CreateStringView(kStringField); + auto presence = m.Has(field); + EXPECT_OK(presence); + EXPECT_TRUE(*presence); + auto mf = m[field]; EXPECT_TRUE(mf.has_value()); - EXPECT_TRUE(mf.value().IsString()); - EXPECT_EQ("entry", mf.value().StringOrDie().value()); + EXPECT_TRUE(mf->IsString()); + EXPECT_EQ("entry", mf->StringOrDie().value()); } { - auto mf = m[CelValue::CreateStringView(kIntField)]; + CelValue field = CelValue::CreateStringView(kIntField); + auto presence = m.Has(field); + EXPECT_OK(presence); + EXPECT_TRUE(*presence); + auto mf = m[field]; EXPECT_TRUE(mf.has_value()); - EXPECT_TRUE(mf.value().IsInt64()); - EXPECT_EQ(16, mf.value().Int64OrDie()); + EXPECT_TRUE(mf->IsInt64()); + EXPECT_EQ(16, mf->Int64OrDie()); } } { @@ -406,16 +446,35 @@ TEST_F(FlatBuffersTest, ObjectVectorField) { const CelMap& m = *l[1].MapOrDie(); EXPECT_EQ(2, m.size()); { - auto mf = m[CelValue::CreateStringView(kStringField)]; + CelValue field = CelValue::CreateStringView(kStringField); + auto presence = m.Has(field); + EXPECT_OK(presence); + // Note, the presence checks on flat buffers seem to only apply to whether + // the field is defined. + EXPECT_TRUE(*presence); + auto mf = m[field]; EXPECT_TRUE(mf.has_value()); - EXPECT_TRUE(mf.value().IsString()); - EXPECT_EQ("", mf.value().StringOrDie().value()); + EXPECT_TRUE(mf->IsString()); + EXPECT_EQ("", mf->StringOrDie().value()); } { - auto mf = m[CelValue::CreateStringView(kIntField)]; + CelValue field = CelValue::CreateStringView(kIntField); + auto presence = m.Has(field); + EXPECT_OK(presence); + EXPECT_TRUE(*presence); + auto mf = m[field]; EXPECT_TRUE(mf.has_value()); - EXPECT_TRUE(mf.value().IsInt64()); - EXPECT_EQ(32, mf.value().Int64OrDie()); + EXPECT_TRUE(mf->IsInt64()); + EXPECT_EQ(32, mf->Int64OrDie()); + } + { + std::string undefined = "f_undefined"; + CelValue field = CelValue::CreateStringView(undefined); + auto presence = m.Has(field); + EXPECT_OK(presence); + EXPECT_FALSE(*presence); + auto mf = m[field]; + EXPECT_FALSE(mf.has_value()); } } } @@ -426,25 +485,25 @@ TEST_F(FlatBuffersTest, VectorFieldDefaults) { kIntsField, kBoolsField, kStringsField, kObjsField}) { auto f = value[CelValue::CreateStringView(field)]; EXPECT_TRUE(f.has_value()); - EXPECT_TRUE(f.value().IsList()); - const CelList& l = *f.value().ListOrDie(); + EXPECT_TRUE(f->IsList()); + const CelList& l = *f->ListOrDie(); EXPECT_EQ(0, l.size()); } { auto f = value[CelValue::CreateStringView(kIndexedField)]; EXPECT_TRUE(f.has_value()); - EXPECT_TRUE(f.value().IsMap()); - const CelMap& m = *f.value().MapOrDie(); + EXPECT_TRUE(f->IsMap()); + const CelMap& m = *f->MapOrDie(); EXPECT_EQ(0, m.size()); - EXPECT_EQ(0, m.ListKeys()->size()); + EXPECT_EQ(0, (*m.ListKeys())->size()); } { auto f = value[CelValue::CreateStringView(kBytesField)]; EXPECT_TRUE(f.has_value()); - EXPECT_TRUE(f.value().IsBytes()); - EXPECT_EQ("", f.value().BytesOrDie().value()); + EXPECT_TRUE(f->IsBytes()); + EXPECT_EQ("", f->BytesOrDie().value()); } } @@ -471,10 +530,10 @@ TEST_F(FlatBuffersTest, IndexedObjectVectorField) { })"); auto f = value[CelValue::CreateStringView(kIndexedField)]; EXPECT_TRUE(f.has_value()); - EXPECT_TRUE(f.value().IsMap()); - const CelMap& m = *f.value().MapOrDie(); + EXPECT_TRUE(f->IsMap()); + const CelMap& m = *f->MapOrDie(); EXPECT_EQ(4, m.size()); - const CelList& l = *m.ListKeys(); + const CelList& l = *m.ListKeys().value(); EXPECT_EQ(4, l.size()); EXPECT_TRUE(l[0].IsString()); EXPECT_TRUE(l[1].IsString()); @@ -492,15 +551,15 @@ TEST_F(FlatBuffersTest, IndexedObjectVectorField) { for (const std::string& key : std::vector{a, b, c, d}) { auto v = m[CelValue::CreateString(&key)]; EXPECT_TRUE(v.has_value()); - const CelMap& vm = *v.value().MapOrDie(); + const CelMap& vm = *v->MapOrDie(); EXPECT_EQ(2, vm.size()); auto vf = vm[CelValue::CreateStringView(kStringField)]; EXPECT_TRUE(vf.has_value()); - EXPECT_TRUE(vf.value().IsString()); - EXPECT_EQ(key, vf.value().StringOrDie().value()); + EXPECT_TRUE(vf->IsString()); + EXPECT_EQ(key, vf->StringOrDie().value()); auto vi = vm[CelValue::CreateStringView(kIntField)]; EXPECT_TRUE(vi.has_value()); - EXPECT_TRUE(vi.value().IsInt64()); + EXPECT_TRUE(vi->IsInt64()); } { @@ -522,17 +581,39 @@ TEST_F(FlatBuffersTest, IndexedObjectVectorFieldDefaults) { } ] })"); - auto f = value[CelValue::CreateStringView(kIndexedField)]; + CelValue field = CelValue::CreateStringView(kIndexedField); + auto presence = value.Has(field); + EXPECT_OK(presence); + EXPECT_TRUE(*presence); + auto f = value[field]; EXPECT_TRUE(f.has_value()); - EXPECT_TRUE(f.value().IsMap()); - const CelMap& m = *f.value().MapOrDie(); + EXPECT_TRUE(f->IsMap()); + const CelMap& m = *f->MapOrDie(); + EXPECT_EQ(1, m.size()); - const CelList& l = *m.ListKeys(); + const CelList& l = *m.ListKeys().value(); EXPECT_EQ(1, l.size()); EXPECT_TRUE(l[0].IsString()); EXPECT_EQ("", l[0].StringOrDie().value()); - auto v = m[CelValue::CreateStringView(absl::string_view())]; + + CelValue map_field = CelValue::CreateStringView(absl::string_view()); + presence = m.Has(map_field); + EXPECT_OK(presence); + EXPECT_TRUE(*presence); + auto v = m[map_field]; EXPECT_TRUE(v.has_value()); + + std::string undefined = "f_undefined"; + CelValue undefined_field = CelValue::CreateStringView(undefined); + presence = m.Has(undefined_field); + EXPECT_OK(presence); + EXPECT_FALSE(*presence); + v = m[undefined_field]; + EXPECT_FALSE(v.has_value()); + + presence = m.Has(CelValue::CreateBool(false)); + EXPECT_FALSE(presence.ok()); + EXPECT_EQ(presence.status().code(), absl::StatusCode::kInvalidArgument); } } // namespace diff --git a/tools/navigable_ast.cc b/tools/navigable_ast.cc new file mode 100644 index 000000000..0de2d86c6 --- /dev/null +++ b/tools/navigable_ast.cc @@ -0,0 +1,205 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tools/navigable_ast.h" + +#include +#include +#include +#include +#include +#include + +#include "cel/expr/checked.pb.h" +#include "absl/container/flat_hash_map.h" +#include "absl/functional/any_invocable.h" +#include "absl/memory/memory.h" +#include "common/ast/navigable_ast_internal.h" +#include "eval/public/ast_traverse.h" +#include "eval/public/ast_visitor.h" +#include "eval/public/ast_visitor_base.h" +#include "eval/public/source_position.h" + +namespace cel { + +namespace { + +using ::cel::expr::Expr; +using ::google::api::expr::runtime::AstTraverse; +using ::google::api::expr::runtime::SourcePosition; + +using AstNode = NavigableProtoAstNode; +using NavigableAstNodeData = + common_internal::NavigableAstNodeData; +using NavigableAstMetadata = + common_internal::NavigableAstMetadata; + +NodeKind GetNodeKind(const Expr& expr) { + switch (expr.expr_kind_case()) { + case Expr::kConstExpr: + return NodeKind::kConstant; + case Expr::kIdentExpr: + return NodeKind::kIdent; + case Expr::kSelectExpr: + return NodeKind::kSelect; + case Expr::kCallExpr: + return NodeKind::kCall; + case Expr::kListExpr: + return NodeKind::kList; + case Expr::kStructExpr: + if (!expr.struct_expr().message_name().empty()) { + return NodeKind::kStruct; + } else { + return NodeKind::kMap; + } + case Expr::kComprehensionExpr: + return NodeKind::kComprehension; + case Expr::EXPR_KIND_NOT_SET: + default: + return NodeKind::kUnspecified; + } +} + +// Get the traversal relationship from parent to the given node. +// Note: these depend on the ast_visitor utility's traversal ordering. +ChildKind GetChildKind(const NavigableAstNodeData& parent_node, + size_t child_index) { + constexpr size_t kComprehensionRangeArgIndex = + google::api::expr::runtime::ITER_RANGE; + constexpr size_t kComprehensionInitArgIndex = + google::api::expr::runtime::ACCU_INIT; + constexpr size_t kComprehensionConditionArgIndex = + google::api::expr::runtime::LOOP_CONDITION; + constexpr size_t kComprehensionLoopStepArgIndex = + google::api::expr::runtime::LOOP_STEP; + constexpr size_t kComprehensionResultArgIndex = + google::api::expr::runtime::RESULT; + + switch (parent_node.node_kind) { + case NodeKind::kStruct: + return ChildKind::kStructValue; + case NodeKind::kMap: + if (child_index % 2 == 0) { + return ChildKind::kMapKey; + } + return ChildKind::kMapValue; + case NodeKind::kList: + return ChildKind::kListElem; + case NodeKind::kSelect: + return ChildKind::kSelectOperand; + case NodeKind::kCall: + if (child_index == 0 && parent_node.expr->call_expr().has_target()) { + return ChildKind::kCallReceiver; + } + return ChildKind::kCallArg; + case NodeKind::kComprehension: + switch (child_index) { + case kComprehensionRangeArgIndex: + return ChildKind::kComprehensionRange; + case kComprehensionInitArgIndex: + return ChildKind::kComprehensionInit; + case kComprehensionConditionArgIndex: + return ChildKind::kComprehensionCondition; + case kComprehensionLoopStepArgIndex: + return ChildKind::kComprehensionLoopStep; + case kComprehensionResultArgIndex: + return ChildKind::kComprensionResult; + default: + return ChildKind::kUnspecified; + } + default: + return ChildKind::kUnspecified; + } +} + +class NavigableExprBuilderVisitor + : public google::api::expr::runtime::AstVisitorBase { + public: + NavigableExprBuilderVisitor( + absl::AnyInvocable()> node_factory, + absl::AnyInvocable node_data_accessor) + : node_factory_(std::move(node_factory)), + node_data_accessor_(std::move(node_data_accessor)), + metadata_(std::make_unique()) {} + + NavigableAstNodeData& NodeDataAt(size_t index) { + return node_data_accessor_(*metadata_->nodes[index]); + } + + void PreVisitExpr(const Expr* expr, const SourcePosition* position) override { + NavigableProtoAstNode* parent = + parent_stack_.empty() ? nullptr + : metadata_->nodes[parent_stack_.back()].get(); + size_t index = metadata_->nodes.size(); + metadata_->nodes.push_back(node_factory_()); + NavigableProtoAstNode* node = metadata_->nodes[index].get(); + auto& node_data = NodeDataAt(index); + node_data.parent = parent; + node_data.expr = expr; + node_data.parent_relation = ChildKind::kUnspecified; + node_data.node_kind = GetNodeKind(*expr); + node_data.tree_size = 1; + node_data.height = 1; + node_data.index = index; + node_data.child_index = -1; + node_data.metadata = metadata_.get(); + + metadata_->id_to_node.insert({expr->id(), node}); + metadata_->expr_to_node.insert({expr, node}); + if (!parent_stack_.empty()) { + auto& parent_node_data = NodeDataAt(parent_stack_.back()); + size_t child_index = parent_node_data.children.size(); + parent_node_data.children.push_back(node); + node_data.parent_relation = GetChildKind(parent_node_data, child_index); + node_data.child_index = child_index; + } + parent_stack_.push_back(index); + } + + void PostVisitExpr(const Expr* expr, + const SourcePosition* position) override { + size_t idx = parent_stack_.back(); + parent_stack_.pop_back(); + metadata_->postorder.push_back(metadata_->nodes[idx].get()); + NavigableAstNodeData& node = NodeDataAt(idx); + if (!parent_stack_.empty()) { + auto& parent_node_data = NodeDataAt(parent_stack_.back()); + parent_node_data.tree_size += node.tree_size; + parent_node_data.height = + std::max(parent_node_data.height, node.height + 1); + } + } + + std::unique_ptr Consume() && { + return std::move(metadata_); + } + + private: + absl::AnyInvocable()> node_factory_; + absl::AnyInvocable node_data_accessor_; + std::unique_ptr metadata_; + std::vector parent_stack_; +}; + +} // namespace + +NavigableProtoAst NavigableProtoAst::Build(const Expr& expr) { + NavigableExprBuilderVisitor visitor( + []() { return absl::WrapUnique(new AstNode()); }, + [](AstNode& node) -> NavigableAstNodeData& { return node.data_; }); + AstTraverse(&expr, /*source_info=*/nullptr, &visitor); + return NavigableProtoAst(std::move(visitor).Consume()); +} + +} // namespace cel diff --git a/tools/navigable_ast.h b/tools/navigable_ast.h new file mode 100644 index 000000000..1ebf6883c --- /dev/null +++ b/tools/navigable_ast.h @@ -0,0 +1,169 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_TOOLS_NAVIGABLE_AST_H_ +#define THIRD_PARTY_CEL_CPP_TOOLS_NAVIGABLE_AST_H_ + + +#include "cel/expr/syntax.pb.h" +#include "common/ast/navigable_ast_internal.h" +#include "common/ast/navigable_ast_kinds.h" // IWYU pragma: export + +namespace cel { + +class NavigableProtoAst; +class NavigableProtoAstNode; + +namespace common_internal { + +struct ProtoAstTraits { + using ExprType = cel::expr::Expr; + using AstType = NavigableProtoAst; + using NodeType = NavigableProtoAstNode; +}; + +} // namespace common_internal + +// Wrapper around a CEL AST node that exposes traversal information. +class NavigableProtoAstNode : public common_internal::NavigableAstNodeBase< + common_internal::ProtoAstTraits> { + private: + using Base = + common_internal::NavigableAstNodeBase; + + public: + // A const Span like type that provides pre-order traversal for a sub tree. + // provides .begin() and .end() returning bidirectional iterators to + // const AstNode&. + using PreorderRange = Base::PreorderRange; + + // A const Span like type that provides post-order traversal for a sub tree. + // provides .begin() and .end() returning bidirectional iterators to + // const AstNode&. + using PostorderRange = Base::PostorderRange; + + // The parent of this node or nullptr if it is a root. + using Base::parent; + + // The ptr to the backing Expr in the source AST. + // + // This may dangle if the source AST is mutated or destroyed. + using Base::expr; + + // The index of this node in the parent's children. -1 if this is a root. + using Base::child_index; + + // The type of traversal from parent to this node. + using Base::parent_relation; + + // The type of this node, analogous to Expr::ExprKindCase. + using Base::node_kind; + + // The number of nodes in the tree rooted at this node (including self). + using Base::tree_size; + + // The height of this node in the tree (the number of descendants including + // self on the longest path). + using Base::height; + + // The children of this node in their natural order. + using Base::children; + + // Range over the descendants of this node (including self) using preorder + // semantics. Each node is visited immediately before all of its descendants. + // + // example: + // for (const cel::NavigableProtoAstNode& node : + // ast.Root().DescendantsPreorder()) { + // ... + // } + // + // Children are traversed in their natural order: + // - call arguments are traversed in order (receiver if present is first) + // - list elements are traversed in order + // - maps are traversed in order (alternating key, value per entry) + // - comprehensions are traversed in the order: range, accu_init, condition, + // step, result + using Base::DescendantsPreorder; + + // Range over the descendants of this node (including self) using postorder + // semantics. Each node is visited immediately after all of its descendants. + using Base::DescendantsPostorder; + + private: + friend class NavigableProtoAst; + + NavigableProtoAstNode() = default; +}; + +// NavigableExpr provides a view over a CEL AST that allows for generalized +// traversal. The traversal structures are eagerly built on construction, +// requiring a full traversal of the AST. This is intended for use in tools that +// might require random access or multiple passes over the AST, amortizing the +// cost of building the traversal structures. +// +// Pointers to AstNodes are owned by this instance and must not outlive it. +// +// `NavigableAst` and Navigable nodes are independent of the input Expr and may +// outlive it, but may contain dangling pointers if the input Expr is modified +// or destroyed. +class NavigableProtoAst : public common_internal::NavigableAstBase< + common_internal::ProtoAstTraits> { + private: + using Base = + common_internal::NavigableAstBase; + + public: + static NavigableProtoAst Build(const cel::expr::Expr& expr); + + // Default constructor creates an empty instance. + // + // Operations other than equality are undefined on an empty instance. + // + // This is intended for composed object construction, a new NavigableProtoAst + // should be obtained from the Build factory function. + NavigableProtoAst() = default; + + // Move only. + NavigableProtoAst(const NavigableProtoAst&) = delete; + NavigableProtoAst& operator=(const NavigableProtoAst&) = delete; + NavigableProtoAst(NavigableProtoAst&&) = default; + NavigableProtoAst& operator=(NavigableProtoAst&&) = default; + + // Return ptr to the AST node with id if present. Otherwise returns nullptr. + // + // If ids are non-unique, the first pre-order node encountered with id is + // returned. + using Base::FindId; + + // Return ptr to the AST node representing the given Expr node. + using Base::FindExpr; + + // Returns the root of the AST. + using Base::Root; + + // Return whether the source AST used unique IDs for each node. + // + // This is typically the case, but older versions of the parsers didn't + // guarantee uniqueness for nodes generated by some macros and ASTs modified + // outside of CEL's parse/type check may not have unique IDs. + using Base::IdsAreUnique; + + private: + using Base::Base; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_TOOLS_NAVIGABLE_AST_H_ diff --git a/tools/navigable_ast_test.cc b/tools/navigable_ast_test.cc new file mode 100644 index 000000000..a42f1d5fc --- /dev/null +++ b/tools/navigable_ast_test.cc @@ -0,0 +1,396 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tools/navigable_ast.h" + +#include +#include + +#include "cel/expr/syntax.pb.h" +#include "base/builtins.h" +#include "internal/testing.h" +#include "parser/parser.h" + +namespace cel { +namespace { + +using ::cel::expr::Expr; +using ::google::api::expr::parser::Parse; +using ::testing::ElementsAre; +using ::testing::IsEmpty; +using ::testing::Pair; +using ::testing::SizeIs; + +TEST(NavigableProtoAst, Basic) { + Expr const_node; + const_node.set_id(1); + const_node.mutable_const_expr()->set_int64_value(42); + + NavigableProtoAst ast = NavigableProtoAst::Build(const_node); + EXPECT_TRUE(ast.IdsAreUnique()); + + const NavigableProtoAstNode& root = ast.Root(); + + EXPECT_EQ(root.expr(), &const_node); + EXPECT_THAT(root.children(), IsEmpty()); + EXPECT_TRUE(root.parent() == nullptr); + EXPECT_EQ(root.child_index(), -1); + EXPECT_EQ(root.node_kind(), NodeKind::kConstant); + EXPECT_EQ(root.parent_relation(), ChildKind::kUnspecified); +} + +TEST(NavigableProtoAst, DefaultCtorEmpty) { + Expr const_node; + const_node.set_id(1); + const_node.mutable_const_expr()->set_int64_value(42); + + NavigableProtoAst ast = NavigableProtoAst::Build(const_node); + EXPECT_EQ(ast, ast); + + NavigableProtoAst empty; + + EXPECT_NE(ast, empty); + EXPECT_EQ(empty, empty); + + EXPECT_TRUE(static_cast(ast)); + EXPECT_FALSE(static_cast(empty)); + + NavigableProtoAst moved = std::move(ast); + EXPECT_EQ(ast, empty); + EXPECT_FALSE(static_cast(ast)); + EXPECT_TRUE(static_cast(moved)); +} + +TEST(NavigableProtoAst, FindById) { + Expr const_node; + const_node.set_id(1); + const_node.mutable_const_expr()->set_int64_value(42); + + NavigableProtoAst ast = NavigableProtoAst::Build(const_node); + + const NavigableProtoAstNode& root = ast.Root(); + + EXPECT_EQ(ast.FindId(const_node.id()), &root); + EXPECT_EQ(ast.FindId(-1), nullptr); +} + +MATCHER_P(AstNodeWrapping, expr, "") { + const NavigableProtoAstNode* ptr = arg; + return ptr != nullptr && ptr->expr() == expr; +} + +TEST(NavigableProtoAst, ToleratesNonUnique) { + Expr call_node; + call_node.set_id(1); + call_node.mutable_call_expr()->set_function(cel::builtin::kNot); + Expr* const_node = call_node.mutable_call_expr()->add_args(); + const_node->mutable_const_expr()->set_bool_value(false); + const_node->set_id(1); + + NavigableProtoAst ast = NavigableProtoAst::Build(call_node); + + const NavigableProtoAstNode& root = ast.Root(); + + EXPECT_EQ(ast.FindId(1), &root); + EXPECT_EQ(ast.FindExpr(&call_node), &root); + EXPECT_FALSE(ast.IdsAreUnique()); + EXPECT_THAT(ast.FindExpr(const_node), AstNodeWrapping(const_node)); +} + +TEST(NavigableProtoAst, FindByExprPtr) { + Expr const_node; + const_node.set_id(1); + const_node.mutable_const_expr()->set_int64_value(42); + + NavigableProtoAst ast = NavigableProtoAst::Build(const_node); + + const NavigableProtoAstNode& root = ast.Root(); + + EXPECT_EQ(ast.FindExpr(&const_node), &root); + EXPECT_EQ(ast.FindExpr(&Expr::default_instance()), nullptr); +} + +TEST(NavigableProtoAst, Children) { + ASSERT_OK_AND_ASSIGN(auto parsed_expr, Parse("1 + 2")); + + NavigableProtoAst ast = NavigableProtoAst::Build(parsed_expr.expr()); + const NavigableProtoAstNode& root = ast.Root(); + + EXPECT_EQ(root.expr(), &parsed_expr.expr()); + EXPECT_THAT(root.children(), SizeIs(2)); + EXPECT_TRUE(root.parent() == nullptr); + EXPECT_EQ(root.child_index(), -1); + EXPECT_EQ(root.parent_relation(), ChildKind::kUnspecified); + EXPECT_EQ(root.node_kind(), NodeKind::kCall); + + EXPECT_THAT( + root.children(), + ElementsAre(AstNodeWrapping(&parsed_expr.expr().call_expr().args(0)), + AstNodeWrapping(&parsed_expr.expr().call_expr().args(1)))); + + ASSERT_THAT(root.children(), SizeIs(2)); + const auto* child1 = root.children()[0]; + EXPECT_EQ(child1->child_index(), 0); + EXPECT_EQ(child1->parent(), &root); + EXPECT_EQ(child1->parent_relation(), ChildKind::kCallArg); + EXPECT_EQ(child1->node_kind(), NodeKind::kConstant); + EXPECT_THAT(child1->children(), IsEmpty()); + + const auto* child2 = root.children()[1]; + EXPECT_EQ(child2->child_index(), 1); +} + +TEST(NavigableProtoAst, UnspecifiedExpr) { + Expr expr; + expr.set_id(1); + NavigableProtoAst ast = NavigableProtoAst::Build(expr); + const NavigableProtoAstNode& root = ast.Root(); + + EXPECT_EQ(root.expr(), &expr); + EXPECT_THAT(root.children(), SizeIs(0)); + EXPECT_TRUE(root.parent() == nullptr); + EXPECT_EQ(root.child_index(), -1); + EXPECT_EQ(root.node_kind(), NodeKind::kUnspecified); +} + +TEST(NavigableProtoAst, ParentRelationSelect) { + ASSERT_OK_AND_ASSIGN(auto parsed_expr, Parse("a.b")); + + NavigableProtoAst ast = NavigableProtoAst::Build(parsed_expr.expr()); + const NavigableProtoAstNode& root = ast.Root(); + + ASSERT_THAT(root.children(), SizeIs(1)); + const auto* child = root.children()[0]; + + EXPECT_EQ(child->parent_relation(), ChildKind::kSelectOperand); + EXPECT_EQ(child->node_kind(), NodeKind::kIdent); +} + +TEST(NavigableProtoAst, ParentRelationCallReceiver) { + ASSERT_OK_AND_ASSIGN(auto parsed_expr, Parse("a.b()")); + + NavigableProtoAst ast = NavigableProtoAst::Build(parsed_expr.expr()); + const NavigableProtoAstNode& root = ast.Root(); + + ASSERT_THAT(root.children(), SizeIs(1)); + const auto* child = root.children()[0]; + + EXPECT_EQ(child->parent_relation(), ChildKind::kCallReceiver); + EXPECT_EQ(child->node_kind(), NodeKind::kIdent); +} + +TEST(NavigableProtoAst, ParentRelationCreateStruct) { + ASSERT_OK_AND_ASSIGN(auto parsed_expr, + Parse("com.example.Type{field: '123'}")); + + NavigableProtoAst ast = NavigableProtoAst::Build(parsed_expr.expr()); + const NavigableProtoAstNode& root = ast.Root(); + + EXPECT_EQ(root.node_kind(), NodeKind::kStruct); + ASSERT_THAT(root.children(), SizeIs(1)); + const auto* child = root.children()[0]; + + EXPECT_EQ(child->parent_relation(), ChildKind::kStructValue); + EXPECT_EQ(child->node_kind(), NodeKind::kConstant); +} + +TEST(NavigableProtoAst, ParentRelationCreateMap) { + ASSERT_OK_AND_ASSIGN(auto parsed_expr, Parse("{'a': 123}")); + + NavigableProtoAst ast = NavigableProtoAst::Build(parsed_expr.expr()); + const NavigableProtoAstNode& root = ast.Root(); + + EXPECT_EQ(root.node_kind(), NodeKind::kMap); + ASSERT_THAT(root.children(), SizeIs(2)); + const auto* key = root.children()[0]; + const auto* value = root.children()[1]; + + EXPECT_EQ(key->parent_relation(), ChildKind::kMapKey); + EXPECT_EQ(key->node_kind(), NodeKind::kConstant); + + EXPECT_EQ(value->parent_relation(), ChildKind::kMapValue); + EXPECT_EQ(value->node_kind(), NodeKind::kConstant); +} + +TEST(NavigableProtoAst, ParentRelationCreateList) { + ASSERT_OK_AND_ASSIGN(auto parsed_expr, Parse("[123]")); + + NavigableProtoAst ast = NavigableProtoAst::Build(parsed_expr.expr()); + const NavigableProtoAstNode& root = ast.Root(); + + EXPECT_EQ(root.node_kind(), NodeKind::kList); + ASSERT_THAT(root.children(), SizeIs(1)); + const auto* child = root.children()[0]; + + EXPECT_EQ(child->parent_relation(), ChildKind::kListElem); + EXPECT_EQ(child->node_kind(), NodeKind::kConstant); +} + +TEST(NavigableProtoAst, ParentRelationComprehension) { + ASSERT_OK_AND_ASSIGN(auto parsed_expr, Parse("[1].all(x, x < 2)")); + + NavigableProtoAst ast = NavigableProtoAst::Build(parsed_expr.expr()); + const NavigableProtoAstNode& root = ast.Root(); + + EXPECT_EQ(root.node_kind(), NodeKind::kComprehension); + ASSERT_THAT(root.children(), SizeIs(5)); + const auto* range = root.children()[0]; + const auto* init = root.children()[1]; + const auto* condition = root.children()[2]; + const auto* step = root.children()[3]; + const auto* finish = root.children()[4]; + + EXPECT_EQ(range->parent_relation(), ChildKind::kComprehensionRange); + EXPECT_EQ(init->parent_relation(), ChildKind::kComprehensionInit); + EXPECT_EQ(condition->parent_relation(), ChildKind::kComprehensionCondition); + EXPECT_EQ(step->parent_relation(), ChildKind::kComprehensionLoopStep); + EXPECT_EQ(finish->parent_relation(), ChildKind::kComprensionResult); +} + +TEST(NavigableProtoAst, DescendantsPostorder) { + ASSERT_OK_AND_ASSIGN(auto parsed_expr, Parse("1 + (x * 3)")); + + NavigableProtoAst ast = NavigableProtoAst::Build(parsed_expr.expr()); + const NavigableProtoAstNode& root = ast.Root(); + + EXPECT_EQ(root.node_kind(), NodeKind::kCall); + + std::vector constants; + std::vector node_kinds; + + for (const NavigableProtoAstNode& node : root.DescendantsPostorder()) { + if (node.node_kind() == NodeKind::kConstant) { + constants.push_back(node.expr()->const_expr().int64_value()); + } + node_kinds.push_back(node.node_kind()); + } + + EXPECT_THAT(node_kinds, ElementsAre(NodeKind::kConstant, NodeKind::kIdent, + NodeKind::kConstant, NodeKind::kCall, + NodeKind::kCall)); + EXPECT_THAT(constants, ElementsAre(1, 3)); +} + +TEST(NavigableProtoAst, DescendantsPreorder) { + ASSERT_OK_AND_ASSIGN(auto parsed_expr, Parse("1 + (x * 3)")); + + NavigableProtoAst ast = NavigableProtoAst::Build(parsed_expr.expr()); + const NavigableProtoAstNode& root = ast.Root(); + + EXPECT_EQ(root.node_kind(), NodeKind::kCall); + + std::vector constants; + std::vector node_kinds; + + for (const NavigableProtoAstNode& node : root.DescendantsPreorder()) { + if (node.node_kind() == NodeKind::kConstant) { + constants.push_back(node.expr()->const_expr().int64_value()); + } + node_kinds.push_back(node.node_kind()); + } + + EXPECT_THAT(node_kinds, + ElementsAre(NodeKind::kCall, NodeKind::kConstant, NodeKind::kCall, + NodeKind::kIdent, NodeKind::kConstant)); + EXPECT_THAT(constants, ElementsAre(1, 3)); +} + +TEST(NavigableProtoAst, DescendantsPreorderComprehension) { + ASSERT_OK_AND_ASSIGN(auto parsed_expr, Parse("[1, 2, 3].map(x, x + 1)")); + + NavigableProtoAst ast = NavigableProtoAst::Build(parsed_expr.expr()); + const NavigableProtoAstNode& root = ast.Root(); + + EXPECT_EQ(root.node_kind(), NodeKind::kComprehension); + + std::vector> node_kinds; + + for (const NavigableProtoAstNode& node : root.DescendantsPreorder()) { + node_kinds.push_back( + std::make_pair(node.node_kind(), node.parent_relation())); + } + + EXPECT_THAT( + node_kinds, + ElementsAre(Pair(NodeKind::kComprehension, ChildKind::kUnspecified), + Pair(NodeKind::kList, ChildKind::kComprehensionRange), + Pair(NodeKind::kConstant, ChildKind::kListElem), + Pair(NodeKind::kConstant, ChildKind::kListElem), + Pair(NodeKind::kConstant, ChildKind::kListElem), + Pair(NodeKind::kList, ChildKind::kComprehensionInit), + Pair(NodeKind::kConstant, ChildKind::kComprehensionCondition), + Pair(NodeKind::kCall, ChildKind::kComprehensionLoopStep), + Pair(NodeKind::kIdent, ChildKind::kCallArg), + Pair(NodeKind::kList, ChildKind::kCallArg), + Pair(NodeKind::kCall, ChildKind::kListElem), + Pair(NodeKind::kIdent, ChildKind::kCallArg), + Pair(NodeKind::kConstant, ChildKind::kCallArg), + Pair(NodeKind::kIdent, ChildKind::kComprensionResult))); +} + +TEST(NavigableProtoAst, TreeSize) { + ASSERT_OK_AND_ASSIGN(auto parsed_expr, Parse("[1, 2, 3].map(x, x + 1)")); + + NavigableProtoAst ast = NavigableProtoAst::Build(parsed_expr.expr()); + const NavigableProtoAstNode& root = ast.Root(); + + EXPECT_EQ(root.node_kind(), NodeKind::kComprehension); + + std::vector> node_kinds; + + EXPECT_EQ(root.tree_size(), 14); + auto it = root.DescendantsPostorder().begin(); + EXPECT_EQ(it->tree_size(), 1); +} + +TEST(NavigableProtoAst, Height) { + ASSERT_OK_AND_ASSIGN(auto parsed_expr, Parse("[1, 2, 3].map(x, x + 1)")); + + NavigableProtoAst ast = NavigableProtoAst::Build(parsed_expr.expr()); + const NavigableProtoAstNode& root = ast.Root(); + + EXPECT_EQ(root.node_kind(), NodeKind::kComprehension); + + std::vector> node_kinds; + + EXPECT_EQ(root.height(), 5); + auto it = root.DescendantsPostorder().begin(); + EXPECT_EQ(it->height(), 1); +} + +TEST(NavigableProtoAst, DescendantsPreorderCreateMap) { + ASSERT_OK_AND_ASSIGN(auto parsed_expr, Parse("{'key1': 1, 'key2': 2}")); + + NavigableProtoAst ast = NavigableProtoAst::Build(parsed_expr.expr()); + const NavigableProtoAstNode& root = ast.Root(); + + EXPECT_EQ(root.node_kind(), NodeKind::kMap); + + std::vector> node_kinds; + + for (const NavigableProtoAstNode& node : root.DescendantsPreorder()) { + node_kinds.push_back( + std::make_pair(node.node_kind(), node.parent_relation())); + } + + EXPECT_THAT(node_kinds, + ElementsAre(Pair(NodeKind::kMap, ChildKind::kUnspecified), + Pair(NodeKind::kConstant, ChildKind::kMapKey), + Pair(NodeKind::kConstant, ChildKind::kMapValue), + Pair(NodeKind::kConstant, ChildKind::kMapKey), + Pair(NodeKind::kConstant, ChildKind::kMapValue))); +} + +} // namespace +} // namespace cel diff --git a/tools/proto_to_predicate.cc b/tools/proto_to_predicate.cc new file mode 100644 index 000000000..8c89ee2f0 --- /dev/null +++ b/tools/proto_to_predicate.cc @@ -0,0 +1,459 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tools/proto_to_predicate.h" + +#include +#include +#include +#include +#include +#include + +#include "google/protobuf/descriptor.pb.h" +#include "absl/log/absl_log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "common/ast.h" +#include "common/expr.h" +#include "common/expr_factory.h" +#include "common/operators.h" +#include "internal/status_macros.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" +#include "google/protobuf/reflection.h" + +namespace cel::tools { + +using ::google::api::expr::common::CelOperator; +using ::google::protobuf::FieldDescriptor; +using ::google::protobuf::Message; +using ::google::protobuf::Reflection; + +class ProtoToPredicateBuilder final : private ExprFactory { + public: + ProtoToPredicateBuilder() : id_(1) {} + + absl::StatusOr Build(absl::string_view input_name, + const Message& message) { + std::vector predicates; + Expr base_expr = NewIdent(NextId(), input_name); + + CEL_RETURN_IF_ERROR(Walk(message, base_expr, predicates)); + Expr root = LogicalAnd(predicates); + return Ast(std::move(root), std::move(source_info_)); + } + + absl::StatusOr Build(absl::string_view input_name, + absl::Span messages) { + if (messages.empty()) { + return Ast(NewBoolConst(NextId(), true), std::move(source_info_)); + } + + std::vector message_asts; + message_asts.reserve(messages.size()); + for (const auto* message : messages) { + std::vector predicates; + Expr base_expr = NewIdent(NextId(), input_name); + + CEL_RETURN_IF_ERROR(Walk(*message, base_expr, predicates)); + message_asts.push_back(LogicalAnd(predicates)); + } + + return Ast(LogicalOr(message_asts), std::move(source_info_)); + } + + private: + // Retrieves the "match_path" string option from the field options if + // defined, returning an empty string otherwise. + std::string GetMatchPath(const ::google::protobuf::FieldDescriptor* field) { + const ::google::protobuf::Message& options = field->options(); + const ::google::protobuf::Reflection* refl = options.GetReflection(); + std::vector fields; + refl->ListFields(options, &fields); + for (const auto* f : fields) { + if (f->name() == "match_path") { + return refl->GetString(options, f); + } + } + return ""; + } + + // Parses a dot-separated string representation of a path (e.g. "dest.region") + // and builds a corresponding select chain AST. + Expr ParseAndBuildPath(absl::string_view path_str) { + std::vector parts = absl::StrSplit(path_str, '.'); + Expr e = NewIdent(NextId(), parts[0]); + for (size_t i = 1; i < parts.size(); ++i) { + e = NewSelect(NextId(), std::move(e), parts[i]); + } + return e; + } + ExprId NextId() { return id_++; } + + // --------------------------------------------------------------------------- + // Field value extraction + // --------------------------------------------------------------------------- + + // Converts a singular field value to a CEL constant expression. + Expr PrimitiveToExpr(ExprId expr_id, const Message& message, + const Reflection* reflection, + const FieldDescriptor* field) { + switch (field->cpp_type()) { + case FieldDescriptor::CPPTYPE_INT32: + return NewIntConst(expr_id, reflection->GetInt32(message, field)); + case FieldDescriptor::CPPTYPE_INT64: + return NewIntConst(expr_id, reflection->GetInt64(message, field)); + case FieldDescriptor::CPPTYPE_UINT32: + return NewUintConst(expr_id, reflection->GetUInt32(message, field)); + case FieldDescriptor::CPPTYPE_UINT64: + return NewUintConst(expr_id, reflection->GetUInt64(message, field)); + case FieldDescriptor::CPPTYPE_DOUBLE: + return NewDoubleConst(expr_id, reflection->GetDouble(message, field)); + case FieldDescriptor::CPPTYPE_FLOAT: + return NewDoubleConst(expr_id, reflection->GetFloat(message, field)); + case FieldDescriptor::CPPTYPE_BOOL: + return NewBoolConst(expr_id, reflection->GetBool(message, field)); + case FieldDescriptor::CPPTYPE_ENUM: + return NewIntConst(expr_id, reflection->GetEnumValue(message, field)); + case FieldDescriptor::CPPTYPE_STRING: { + std::string str_val = reflection->GetString(message, field); + if (field->type() == FieldDescriptor::TYPE_BYTES) { + return NewBytesConst(expr_id, std::move(str_val)); + } + return NewStringConst(expr_id, std::move(str_val)); + } + default: + // Log a warning as message should be handled by Walk. + ABSL_LOG(WARNING) << "PrimitiveToExpr: Unhandled field type: " + << FieldDescriptor::TypeName(field->type()); + break; + } + return NewNullConst(expr_id); + } + + Expr PrimitiveToExpr(const Message& message, const Reflection* reflection, + const FieldDescriptor* field) { + return PrimitiveToExpr(NextId(), message, reflection, field); + } + + // Converts a repeated field element to a CEL constant expression. + Expr RepeatedPrimitiveToExpr(const Message& message, + const Reflection* reflection, + const FieldDescriptor* field, int index) { + const ExprId id = NextId(); + switch (field->cpp_type()) { + case FieldDescriptor::CPPTYPE_INT32: + return NewIntConst(id, + reflection->GetRepeatedInt32(message, field, index)); + case FieldDescriptor::CPPTYPE_INT64: + return NewIntConst(id, + reflection->GetRepeatedInt64(message, field, index)); + case FieldDescriptor::CPPTYPE_UINT32: + return NewUintConst( + id, reflection->GetRepeatedUInt32(message, field, index)); + case FieldDescriptor::CPPTYPE_UINT64: + return NewUintConst( + id, reflection->GetRepeatedUInt64(message, field, index)); + case FieldDescriptor::CPPTYPE_DOUBLE: + return NewDoubleConst( + id, reflection->GetRepeatedDouble(message, field, index)); + case FieldDescriptor::CPPTYPE_FLOAT: + return NewDoubleConst( + id, reflection->GetRepeatedFloat(message, field, index)); + case FieldDescriptor::CPPTYPE_BOOL: + return NewBoolConst(id, + reflection->GetRepeatedBool(message, field, index)); + case FieldDescriptor::CPPTYPE_ENUM: + return NewIntConst( + id, reflection->GetRepeatedEnumValue(message, field, index)); + case FieldDescriptor::CPPTYPE_STRING: { + std::string str_val = + reflection->GetRepeatedString(message, field, index); + if (field->type() == FieldDescriptor::TYPE_BYTES) { + return NewBytesConst(id, std::move(str_val)); + } + return NewStringConst(id, std::move(str_val)); + } + default: + break; + } + return NewNullConst(id); + } + + // --------------------------------------------------------------------------- + // Expression construction helpers + // --------------------------------------------------------------------------- + + // Creates a binary operator call: `lhs rhs`. + Expr ConstructBinaryOp(absl::string_view op, Expr lhs, Expr rhs) { + std::vector args = {std::move(lhs), std::move(rhs)}; + return NewCall(NextId(), op, std::move(args)); + } + + Expr ConstructEquality(Expr lhs, Expr rhs) { + return ConstructBinaryOp(CelOperator::EQUALS, std::move(lhs), + std::move(rhs)); + } + + Expr LogicalOr(std::vector& exprs) { + return LogicalOp(CelOperator::LOGICAL_OR, exprs); + } + + Expr LogicalAnd(std::vector& exprs) { + return LogicalOp(CelOperator::LOGICAL_AND, exprs); + } + + // Left-folds a vector of expressions with a binary operator. + // Requires: `exprs` is non-empty. + Expr LogicalOp(absl::string_view op, std::vector& exprs) { + if (exprs.empty()) { + return NewBoolConst(NextId(), true); + } + if (exprs.size() == 1) { + return std::move(exprs[0]); + } + return NewCall(NextId(), op, std::move(exprs)); + } + + // --------------------------------------------------------------------------- + // Map field predicate (extracted from Walk for readability) + // --------------------------------------------------------------------------- + + // Builds the predicate for a map field to assert that all key-value pairs + // specified in the policy are present in the input map field: + // "key" in input.map && input.map["key"] == value + absl::Status WalkMapField(const Reflection* reflection, + const Message& message, + const FieldDescriptor* field, const Expr& base_expr, + int size, std::vector& predicates) { + const FieldDescriptor* const key_field = + field->message_type()->FindFieldByName("key"); + const FieldDescriptor* const value_field = + field->message_type()->FindFieldByName("value"); + + Expr map_path = NewSelect(NextId(), base_expr, field->name()); + + struct MapEntry { + const Message* message; + }; + std::vector entries; + entries.reserve(size); + for (int i = 0; i < size; ++i) { + entries.push_back({&reflection->GetRepeatedMessage(message, field, i)}); + } + + if (!entries.empty()) { + const Reflection* const entry_ref = entries[0].message->GetReflection(); + std::sort(entries.begin(), entries.end(), + [entry_ref, key_field](const MapEntry& a, const MapEntry& b) { + switch (key_field->cpp_type()) { + case FieldDescriptor::CPPTYPE_INT32: + return entry_ref->GetInt32(*a.message, key_field) < + entry_ref->GetInt32(*b.message, key_field); + case FieldDescriptor::CPPTYPE_INT64: + return entry_ref->GetInt64(*a.message, key_field) < + entry_ref->GetInt64(*b.message, key_field); + case FieldDescriptor::CPPTYPE_UINT32: + return entry_ref->GetUInt32(*a.message, key_field) < + entry_ref->GetUInt32(*b.message, key_field); + case FieldDescriptor::CPPTYPE_UINT64: + return entry_ref->GetUInt64(*a.message, key_field) < + entry_ref->GetUInt64(*b.message, key_field); + case FieldDescriptor::CPPTYPE_BOOL: + return !entry_ref->GetBool(*a.message, key_field) && + entry_ref->GetBool(*b.message, key_field); + case FieldDescriptor::CPPTYPE_STRING: + return entry_ref->GetString(*a.message, key_field) < + entry_ref->GetString(*b.message, key_field); + default: + return false; + } + }); + } + + std::vector map_checks; + map_checks.reserve(size); + for (const auto& entry : entries) { + const Message& entry_msg = *entry.message; + const Reflection* const entry_ref = entry_msg.GetReflection(); + + Expr key_expr = PrimitiveToExpr(entry_msg, entry_ref, key_field); + + // Represents `"key" in input.map` to assert the key exists. + Expr in_check = NewCall(NextId(), CelOperator::IN, + std::vector{key_expr, map_path}); + // Represents `input.map["key"]` to lookup the value. + Expr lookup_path = NewCall(NextId(), CelOperator::INDEX, + std::vector{map_path, key_expr}); + + if (value_field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { + const Message& value_msg = + entry_ref->GetMessage(entry_msg, value_field); + std::vector val_predicates; + CEL_RETURN_IF_ERROR(Walk(value_msg, lookup_path, val_predicates)); + + if (!val_predicates.empty()) { + // Represents `"key" in input.map && (nested message fields check...)` + map_checks.push_back(std::move(in_check)); + map_checks.insert(map_checks.end(), + std::make_move_iterator(val_predicates.begin()), + std::make_move_iterator(val_predicates.end())); + } else { + // Represents `"key" in input.map` if nested message is empty. + map_checks.push_back(std::move(in_check)); + } + } else { + Expr value_expr = PrimitiveToExpr(entry_msg, entry_ref, value_field); + // Represents `input.map["key"] == value` + Expr eq_check = + ConstructEquality(std::move(lookup_path), std::move(value_expr)); + + // Represents `"key" in input.map && input.map["key"] == value` + map_checks.push_back(std::move(in_check)); + map_checks.push_back(std::move(eq_check)); + } + } + + predicates.push_back(LogicalAnd(map_checks)); + return absl::OkStatus(); + } + + // --------------------------------------------------------------------------- + // Repeated field predicate (extracted from Walk for readability) + // --------------------------------------------------------------------------- + + // Builds predicates for a repeated field: + // - Repeated Messages are mapped to a logical OR (||) of the generated + // predicates for each message. + // - Repeated Primitives are mapped either to: + // - `lhs in [values]` if a "match_path" option is specified. + // - `value in input.field` conjoined with && for each value otherwise. + absl::Status WalkRepeatedField(const Reflection* reflection, + const Message& message, + const FieldDescriptor* field, + const Expr& base_expr, int size, + std::vector& predicates) { + if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { + std::vector message_asts; + message_asts.reserve(size); + for (int i = 0; i < size; ++i) { + const Message& sub_message = + reflection->GetRepeatedMessage(message, field, i); + std::vector sub_predicates; + Expr sub_base = NewSelect(NextId(), base_expr, field->name()); + CEL_RETURN_IF_ERROR(Walk(sub_message, sub_base, sub_predicates)); + message_asts.push_back(LogicalAnd(sub_predicates)); + } + // Represents alternate message predicates conjoined with OR: `msg_1 || + // msg_2 || ...` + predicates.push_back(LogicalOr(message_asts)); + return absl::OkStatus(); + } + + std::vector elements; + elements.reserve(size); + for (int i = 0; i < size; ++i) { + elements.push_back(NewListElement( + RepeatedPrimitiveToExpr(message, reflection, field, i))); + } + Expr literal_list = NewList(NextId(), std::move(elements)); + + std::string match_path_val = GetMatchPath(field); + if (!match_path_val.empty()) { + Expr lhs = ParseAndBuildPath(match_path_val); + // Represents `lhs in [values]` check (e.g. `dest.region in ["us-east", + // "us-west"]`). + predicates.push_back( + NewCall(NextId(), CelOperator::IN, + std::vector{std::move(lhs), std::move(literal_list)})); + return absl::OkStatus(); + } + + Expr map_path = NewSelect(NextId(), base_expr, field->name()); + std::vector element_checks; + element_checks.reserve(size); + for (int i = 0; i < size; ++i) { + Expr elem_expr = RepeatedPrimitiveToExpr(message, reflection, field, i); + // Represents `value in input.field` check. + Expr in_check = + NewCall(NextId(), CelOperator::IN, + std::vector{std::move(elem_expr), map_path}); + element_checks.push_back(std::move(in_check)); + } + // Represents `"val1" in input.list && "val2" in input.list && ...` + predicates.push_back(LogicalAnd(element_checks)); + + return absl::OkStatus(); + } + + // --------------------------------------------------------------------------- + // Recursive message walk + // --------------------------------------------------------------------------- + + absl::Status Walk(const Message& message, const Expr& base_expr, + std::vector& predicates) { + const Reflection* const reflection = message.GetReflection(); + std::vector fields; + reflection->ListFields(message, &fields); + + for (const auto* field : fields) { + if (field->is_map()) { + const int size = reflection->FieldSize(message, field); + if (size > 0) { + CEL_RETURN_IF_ERROR(WalkMapField(reflection, message, field, + base_expr, size, predicates)); + } + } else if (field->is_repeated()) { + const int size = reflection->FieldSize(message, field); + if (size > 0) { + CEL_RETURN_IF_ERROR(WalkRepeatedField(reflection, message, field, + base_expr, size, predicates)); + } + } else if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { + const Message& sub_message = reflection->GetMessage(message, field); + Expr field_path = NewSelect(NextId(), base_expr, field->name()); + CEL_RETURN_IF_ERROR(Walk(sub_message, field_path, predicates)); + } else { + // Primitive field: base_expr.field == + Expr field_path = NewSelect(NextId(), base_expr, field->name()); + predicates.push_back( + ConstructEquality(std::move(field_path), + PrimitiveToExpr(message, reflection, field))); + } + } + return absl::OkStatus(); + } + + ExprId id_; + SourceInfo source_info_; +}; + +absl::StatusOr ProtoToPredicateAst(absl::string_view input_name, + const ::google::protobuf::Message& message) { + ProtoToPredicateBuilder builder; + return builder.Build(input_name, message); +} + +absl::StatusOr ProtoToPredicateAst( + absl::string_view input_name, + absl::Span messages) { + ProtoToPredicateBuilder builder; + return builder.Build(input_name, messages); +} + +} // namespace cel::tools diff --git a/tools/proto_to_predicate.h b/tools/proto_to_predicate.h new file mode 100644 index 000000000..ed01cb1e8 --- /dev/null +++ b/tools/proto_to_predicate.h @@ -0,0 +1,48 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_TOOLS_PROTO_TO_PREDICATE_H_ +#define THIRD_PARTY_CEL_CPP_TOOLS_PROTO_TO_PREDICATE_H_ + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "common/ast.h" +#include "google/protobuf/message.h" + +namespace cel::tools { + +// Translates a Protocol Buffer message into a CEL AST representing a predicate. +// +// NOTE: The protocol message schemas used for policy definition should use +// `proto2` or `editions` (and not `proto3` implicit presence) to ensure correct +// behavior, as this library relies on field presence (via reflection) to +// identify which fields are explicitly set by the policy. +absl::StatusOr ProtoToPredicateAst(absl::string_view input_name, + const ::google::protobuf::Message& message); + +// Translates a list of Protocol Buffer messages into a CEL AST representing a +// conjoined or alternate predicate. +// +// NOTE: The protocol message schemas used for policy definition should use +// `proto2` or `editions` (and not `proto3` implicit presence) to ensure correct +// behavior, as this library relies on field presence (via reflection) to +// identify which fields are explicitly set by the policy. +absl::StatusOr ProtoToPredicateAst( + absl::string_view input_name, + absl::Span messages); + +} // namespace cel::tools + +#endif // THIRD_PARTY_CEL_CPP_TOOLS_PROTO_TO_PREDICATE_H_ diff --git a/tools/proto_to_predicate_test.cc b/tools/proto_to_predicate_test.cc new file mode 100644 index 000000000..80ad140c7 --- /dev/null +++ b/tools/proto_to_predicate_test.cc @@ -0,0 +1,593 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tools/proto_to_predicate.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "common/ast.h" +#include "common/ast_proto.h" +#include "common/value.h" +#include "env/config.h" +#include "env/env_runtime.h" +#include "env/env_yaml.h" +#include "env/runtime_std_extensions.h" +#include "eval/testutil/test_message.pb.h" +#include "extensions/protobuf/value.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "parser/parser.h" +#include "runtime/activation.h" +#include "runtime/runtime.h" +#include "tools/cel_unparser.h" +#include "tools/testdata/test_policy.pb.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/json/json.h" +#include "google/protobuf/message.h" +#include "google/protobuf/text_format.h" + +namespace cel::tools { +namespace { + +using ::absl_testing::IsOk; +using ::google::api::expr::runtime::TestMessage; + +constexpr absl::string_view kEnvYaml = R"( +name: "test" +extensions: + - name: "bindings" + - name: "optional" +variables: + - name: "input" + type: "google.api.expr.runtime.TestMessage" +)"; + +TestMessage ParseTestMessage(absl::string_view textproto) { + TestMessage msg; + google::protobuf::TextFormat::ParseFromString(textproto, &msg); + return msg; +} + +absl::StatusOr EvaluatePredicate(const cel::Ast& ast, + const TestMessage& input) { + auto descriptor_pool = cel::internal::GetSharedTestingDescriptorPool(); + + CEL_ASSIGN_OR_RETURN(cel::Config config, + cel::EnvConfigFromYaml(std::string(kEnvYaml))); + + cel::EnvRuntime env_runtime; + env_runtime.SetDescriptorPool(descriptor_pool); + cel::RegisterStandardExtensions(env_runtime); + env_runtime.SetConfig(config); + + CEL_ASSIGN_OR_RETURN(std::unique_ptr runtime, + env_runtime.NewRuntime()); + CEL_ASSIGN_OR_RETURN(std::unique_ptr program, + runtime->CreateProgram(std::make_unique(ast))); + + google::protobuf::Arena arena; + cel::Activation activation; + CEL_ASSIGN_OR_RETURN( + cel::Value val, cel::extensions::ProtoMessageToValue( + input, descriptor_pool.get(), + google::protobuf::MessageFactory::generated_factory(), &arena)); + activation.InsertOrAssignValue("input", val); + + CEL_ASSIGN_OR_RETURN(cel::Value result, + program->Evaluate(&arena, activation)); + if (!result.IsBool()) { + return absl::InvalidArgumentError( + "Predicate evaluate result must be a boolean value."); + } + return result.GetBool(); +} + +struct TestCase { + std::string name; + std::vector input_textprotos; + std::string expected_unparsed; + std::string eval_textproto; + bool expected_eval_result = true; + // If true, skip the eval step of the test. This is useful for tests where + // the expected expression does not share the same type structure as the + // input proto, such as empty messages. + bool skip_eval = false; +}; + +class ProtoToPredicateTest : public ::testing::TestWithParam {}; + +TEST_P(ProtoToPredicateTest, ConformanceTests) { + const TestCase& param = GetParam(); + + std::vector input_messages; + input_messages.reserve(param.input_textprotos.size()); + for (const auto& proto_str : param.input_textprotos) { + input_messages.push_back(ParseTestMessage(proto_str)); + } + + std::vector ptr_messages; + ptr_messages.reserve(input_messages.size()); + for (const auto& msg : input_messages) { + ptr_messages.push_back(&msg); + } + + absl::StatusOr ast_or; + if (input_messages.size() == 1) { + ast_or = ProtoToPredicateAst("input", input_messages[0]); + } else { + ast_or = ProtoToPredicateAst("input", absl::MakeSpan(ptr_messages)); + } + + ASSERT_THAT(ast_or, IsOk()); + cel::Ast ast = std::move(*ast_or); + + cel::expr::ParsedExpr parsed_expr; + ASSERT_THAT(cel::AstToParsedExpr(ast, &parsed_expr), IsOk()); + ASSERT_OK_AND_ASSIGN(auto unparsed, google::api::expr::Unparse(parsed_expr)); + + EXPECT_EQ(unparsed, param.expected_unparsed); + + if (!param.skip_eval) { + TestMessage eval_msg = ParseTestMessage(param.eval_textproto); + ASSERT_OK_AND_ASSIGN(bool eval_result, EvaluatePredicate(ast, eval_msg)); + EXPECT_EQ(eval_result, param.expected_eval_result); + } +} + +INSTANTIATE_TEST_SUITE_P( + ProtoToPredicateSubCases, ProtoToPredicateTest, + testing::Values( + TestCase{ + .name = "EmptyMessageTest", + .input_textprotos = {""}, + .expected_unparsed = "true", + .eval_textproto = "", + }, + TestCase{ + .name = "EmptyMessagesListTest", + .input_textprotos = {}, + .expected_unparsed = "true", + .eval_textproto = "", + }, + TestCase{ + .name = "PrimitivesTest", + .input_textprotos = {R"pb( + int32_value: 42 string_value: "hello" + )pb"}, + .expected_unparsed = + "input.int32_value == 42 && input.string_value == \"hello\"", + .eval_textproto = R"pb( + int32_value: 42 string_value: "hello" + )pb", + }, + TestCase{ + .name = "AllPrimitivesTest", + .input_textprotos = {R"pb( + int32_value: 42 + int64_value: 43 + uint32_value: 44 + uint64_value: 45 + float_value: 46.5 + double_value: 47.5 + bool_value: true + enum_value: TEST_ENUM_1 + string_value: "hello" + bytes_value: "world" + )pb"}, + .expected_unparsed = + "input.int32_value == 42 && input.int64_value == 43 && " + "input.uint32_value == 44u && input.uint64_value == 45u && " + "input.float_value == 46.5 && input.double_value == 47.5 && " + "input.string_value == \"hello\" && " + "input.bytes_value == b\"world\" && " + "input.bool_value == true && " + "input.enum_value == 1", + .eval_textproto = R"pb( + int32_value: 42 + int64_value: 43 + uint32_value: 44 + uint64_value: 45 + float_value: 46.5 + double_value: 47.5 + bool_value: true + enum_value: TEST_ENUM_1 + string_value: "hello" + bytes_value: "world" + )pb", + }, + TestCase{ + .name = "NestedMessageTest", + .input_textprotos = {R"pb( + message_value: { int32_value: 42 } + )pb"}, + .expected_unparsed = "input.message_value.int32_value == 42", + .eval_textproto = R"pb( + message_value: { int32_value: 42 } + )pb", + }, + TestCase{ + .name = "RepeatedFieldTest", + .input_textprotos = {R"pb( + int32_list: [ 1, 2 ] + )pb"}, + .expected_unparsed = + "1 in input.int32_list && 2 in input.int32_list", + .eval_textproto = R"pb( + int32_list: [ 1, 2 ] + )pb", + }, + TestCase{ + .name = "RepeatedFieldSingleElementTest", + .input_textprotos = {R"pb( + int32_list: [ 42 ] + )pb"}, + .expected_unparsed = "42 in input.int32_list", + .eval_textproto = R"pb( + int32_list: [ 42 ] + )pb", + }, + TestCase{ + .name = "RepeatedFieldEmptyTest", + .input_textprotos = {R"pb( + int32_list: [] + )pb"}, + .expected_unparsed = "true", + .eval_textproto = R"pb( + int32_list: [] + )pb", + }, + TestCase{ + .name = "ListFieldEvalNegative", + .input_textprotos = {R"pb( + int32_list: [ 1, 2 ] + )pb"}, + .expected_unparsed = + "1 in input.int32_list && 2 in input.int32_list", + .eval_textproto = R"pb( + int32_list: [ 1, 3 ] + )pb", + .expected_eval_result = false, + }, + TestCase{ + .name = "SingleRepeatedFieldAllPrimitivesTest", + .input_textprotos = {R"pb( + int32_list: [ 42 ] + int64_list: [ 43 ] + uint32_list: [ 44 ] + uint64_list: [ 45 ] + float_list: [ 46.5 ] + double_list: [ 47.5 ] + bool_list: [ true ] + enum_list: [ TEST_ENUM_1 ] + string_list: [ "hello" ] + bytes_list: [ "world" ] + )pb"}, + .expected_unparsed = "42 in input.int32_list && " + "43 in input.int64_list && " + "44u in input.uint32_list && " + "45u in input.uint64_list && " + "46.5 in input.float_list && " + "47.5 in input.double_list && " + "\"hello\" in input.string_list && " + "b\"world\" in input.bytes_list && " + "true in input.bool_list && " + "1 in input.enum_list", + .eval_textproto = R"pb( + int32_list: [ 42 ] + int64_list: [ 43 ] + uint32_list: [ 44 ] + uint64_list: [ 45 ] + float_list: [ 46.5 ] + double_list: [ 47.5 ] + bool_list: [ true ] + enum_list: [ TEST_ENUM_1 ] + string_list: [ "hello" ] + bytes_list: [ "world" ] + )pb", + }, + TestCase{ + .name = "MultipleRepeatedFieldAllPrimitivesTest", + .input_textprotos = {R"pb( + int32_list: [ 42, 142 ] + int64_list: [ 43, 143 ] + uint32_list: [ 44, 144 ] + uint64_list: [ 45, 145 ] + float_list: [ 46.5, 146.5 ] + double_list: [ 47.5, 147.5 ] + bool_list: [ true, false ] + enum_list: [ TEST_ENUM_1, TEST_ENUM_2 ] + string_list: [ "hello", "universe" ] + bytes_list: [ "world", "space" ] + )pb"}, + .expected_unparsed = + "42 in input.int32_list && 142 in input.int32_list && " + "43 in input.int64_list && 143 in input.int64_list && " + "44u in input.uint32_list && 144u in input.uint32_list && " + "45u in input.uint64_list && 145u in input.uint64_list && " + "46.5 in input.float_list && 146.5 in input.float_list && " + "47.5 in input.double_list && 147.5 in input.double_list && " + "\"hello\" in input.string_list && \"universe\" in " + "input.string_list && " + "b\"world\" in input.bytes_list && b\"space\" in " + "input.bytes_list && " + "true in input.bool_list && false in input.bool_list && " + "1 in input.enum_list && 2 in input.enum_list", + .eval_textproto = R"pb( + int32_list: [ 42, 142 ] + int64_list: [ 43, 143 ] + uint32_list: [ 44, 144 ] + uint64_list: [ 45, 145 ] + float_list: [ 46.5, 146.5 ] + double_list: [ 47.5, 147.5 ] + bool_list: [ true, false ] + enum_list: [ TEST_ENUM_1, TEST_ENUM_2 ] + string_list: [ "hello", "universe" ] + bytes_list: [ "world", "space" ] + )pb", + }, + TestCase{ + .name = "MapFieldTest", + .input_textprotos = {R"pb( + string_int32_map: { key: "foo" value: 1 } + string_int32_map: { key: "bar" value: 2 } + )pb"}, + .expected_unparsed = "\"bar\" in input.string_int32_map && " + "input.string_int32_map[\"bar\"] == 2 && " + "\"foo\" in input.string_int32_map && " + "input.string_int32_map[\"foo\"] == 1", + .eval_textproto = R"pb( + string_int32_map: { key: "foo" value: 1 } + string_int32_map: { key: "bar" value: 2 } + )pb", + }, + TestCase{ + .name = "MapFieldEvalNegativeVal", + .input_textprotos = {R"pb( + string_int32_map: { key: "foo" value: 1 } + string_int32_map: { key: "bar" value: 2 } + )pb"}, + .expected_unparsed = "\"bar\" in input.string_int32_map && " + "input.string_int32_map[\"bar\"] == 2 && " + "\"foo\" in input.string_int32_map && " + "input.string_int32_map[\"foo\"] == 1", + .eval_textproto = R"pb( + string_int32_map: { key: "foo" value: 1 } + string_int32_map: { key: "bar" value: 3 } + )pb", + .expected_eval_result = false, + }, + TestCase{ + .name = "MapFieldEvalNegativeNoKey", + .input_textprotos = {R"pb( + string_int32_map: { key: "foo" value: 1 } + string_int32_map: { key: "bar" value: 2 } + )pb"}, + .expected_unparsed = "\"bar\" in input.string_int32_map && " + "input.string_int32_map[\"bar\"] == 2 && " + "\"foo\" in input.string_int32_map && " + "input.string_int32_map[\"foo\"] == 1", + .eval_textproto = R"pb( + string_int32_map: { key: "foo" value: 1 } + )pb", + .expected_eval_result = false, + }, + TestCase{ + .name = "MapFieldIntKeySortingTest", + .input_textprotos = {R"pb( + int32_int32_map: { key: 10 value: 100 } + int32_int32_map: { key: 5 value: 50 } + int32_int32_map: { key: 8 value: 80 } + )pb"}, + .expected_unparsed = "5 in input.int32_int32_map && " + "input.int32_int32_map[5] == 50 && " + "8 in input.int32_int32_map && " + "input.int32_int32_map[8] == 80 && " + "10 in input.int32_int32_map && " + "input.int32_int32_map[10] == 100", + .eval_textproto = R"pb( + int32_int32_map: { key: 5 value: 50 } + int32_int32_map: { key: 8 value: 80 } + int32_int32_map: { key: 10 value: 100 } + )pb", + }, + TestCase{ + .name = "MultipleMessagesTest", + .input_textprotos = {R"pb( + int32_value: 42 + )pb", + R"pb( + int32_value: 41 string_value: "hello" + )pb"}, + .expected_unparsed = + "input.int32_value == 42 || input.int32_value == 41 && " + "input.string_value == \"hello\"", + .eval_textproto = R"pb( + int32_value: 41 string_value: "hello" + )pb", + }, + TestCase{ + .name = "RepeatedMessageFieldTest", + .input_textprotos = {R"pb( + message_list: + [ { int32_value: 42 } + , { int32_value: 43 }] + )pb"}, + .expected_unparsed = "input.message_list.int32_value == 42 || " + "input.message_list.int32_value == 43", + .skip_eval = true, + }, + TestCase{ + .name = "RepeatedMessageSingleElementTest", + .input_textprotos = {R"pb( + message_list: + [ { int32_value: 42 }] + )pb"}, + .expected_unparsed = "input.message_list.int32_value == 42", + .skip_eval = true, + })); + +struct PolicyTestCase { + std::string name; + std::string json_input; + std::string expected_unparsed; +}; + +class PolicyJsonTest : public ::testing::TestWithParam {}; + +TEST_P(PolicyJsonTest, Conformance) { + const PolicyTestCase& param = GetParam(); + + cel::cpp::tools::Policy policy; + google::protobuf::json::ParseOptions options; + options.ignore_unknown_fields = true; + auto status = + google::protobuf::json::JsonStringToMessage(param.json_input, &policy, options); + ASSERT_THAT(status, IsOk()) << "Failed to parse JSON: " << param.json_input; + + absl::StatusOr ast_or; + std::vector ptr_messages; + ptr_messages.reserve(policy.destinations_size()); + for (const auto& dest : policy.destinations()) { + ptr_messages.push_back(&dest); + } + + if (ptr_messages.empty()) { + auto parsed_expr_or = google::api::expr::parser::Parse("false"); + ASSERT_THAT(parsed_expr_or, IsOk()); + auto ast_ptr_or = cel::CreateAstFromParsedExpr(*parsed_expr_or); + ASSERT_THAT(ast_ptr_or, IsOk()); + ast_or = std::move(**ast_ptr_or); + } else if (ptr_messages.size() == 1) { + ast_or = ProtoToPredicateAst("dest", *ptr_messages[0]); + } else { + ast_or = ProtoToPredicateAst("dest", absl::MakeSpan(ptr_messages)); + } + + ASSERT_THAT(ast_or, IsOk()); + cel::Ast ast = std::move(*ast_or); + + cel::expr::ParsedExpr parsed_expr; + ASSERT_THAT(cel::AstToParsedExpr(ast, &parsed_expr), IsOk()); + ASSERT_OK_AND_ASSIGN(auto unparsed, google::api::expr::Unparse(parsed_expr)); + + EXPECT_EQ(unparsed, param.expected_unparsed); +} + +INSTANTIATE_TEST_SUITE_P( + PolicyJsonSubCases, PolicyJsonTest, + testing::Values( + PolicyTestCase{ + .name = "SimpleMatch", + .json_input = + R"({ "destinations": [ { "agent": { "id": "agent-007" } } ] })", + .expected_unparsed = "dest.agent.name == \"agent-007\"", + }, + PolicyTestCase{ + .name = "MultipleFields", + .json_input = + R"({ "destinations": [ { + "tool": { + "name": "admin_tool", + "annotations": { + "read_only_hint": false + } + } + } + ] })", + .expected_unparsed = + "dest.tool.name == \"admin_tool\" && " + "dest.tool.annotations.read_only_hint == false", + }, + PolicyTestCase{ + .name = "RepeatedMessages", + .json_input = + R"({ "destinations": [ + { "agent": { "id": "worker-1" } }, + { "agent": { "id": "worker-2" } }, + ] })", + .expected_unparsed = "dest.agent.name == \"worker-1\" || " + "dest.agent.name == \"worker-2\"", + }, + PolicyTestCase{ + .name = "RepeatedPrimitiveArraySingleElement", + .json_input = + R"({ "destinations": [ { + "tool": { + "role_members": { + "admin": { + "principals": ["alice"] + } + } + } + } ] })", + .expected_unparsed = + "\"admin\" in dest.tool.role_members && " + "\"alice\" in dest.tool.role_members[\"admin\"].principals", + }, + PolicyTestCase{ + .name = "RepeatedArrayEmpty", + .json_input = R"({ "destinations": [ { "tool": { } } ] })", + .expected_unparsed = "true", + }, + PolicyTestCase{ + .name = "MapEquality", + .json_input = + R"({ "destinations": [ + { "tool": { + "name": "shell", + "labels": { + "cluster": "us-central1", + "project": "dev" + } + } + } ] })", + .expected_unparsed = + "dest.tool.name == \"shell\" && \"cluster\" in " + "dest.tool.labels && dest.tool.labels[\"cluster\"] == " + "\"us-central1\" && \"project\" in dest.tool.labels && " + "dest.tool.labels[\"project\"] == \"dev\"", + }, + PolicyTestCase{ + .name = "NestedMapEquality", + .json_input = + R"({ "destinations": [ + { "tool": { + "role_members": { + "admin": { + "all_users": true + } + } + } } + ] })", + .expected_unparsed = + "\"admin\" in dest.tool.role_members && " + "dest.tool.role_members[\"admin\"].all_users == true", + }, + PolicyTestCase{ + .name = "EmptyPolicy", + .json_input = "{}", + .expected_unparsed = "false", + })); + +} // namespace +} // namespace cel::tools diff --git a/tools/testdata/BUILD b/tools/testdata/BUILD new file mode 100644 index 000000000..c88c9c478 --- /dev/null +++ b/tools/testdata/BUILD @@ -0,0 +1,59 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@com_github_google_flatbuffers//:build_defs.bzl", "flatbuffer_library_public") +load("@com_google_protobuf//bazel:cc_proto_library.bzl", "cc_proto_library") +load("@com_google_protobuf//bazel:proto_library.bzl", "proto_library") +load("@rules_cc//cc:cc_library.bzl", "cc_library") + +licenses(["notice"]) + +package(default_visibility = ["//visibility:public"]) + +flatbuffer_library_public( + name = "flatbuffers_test", + srcs = ["flatbuffers.fbs"], + outs = ["flatbuffers_generated.h"], + language_flag = "-c", + reflection_name = "flatbuffers_reflection", +) + +filegroup( + name = "coverage_testdata", + srcs = [ + "coverage_example.textproto", + "exists_macro.textproto", + ], +) + +cc_library( + name = "flatbuffers_test_cc", + srcs = [":flatbuffers_test"], + hdrs = [":flatbuffers_test"], + features = ["-parse_headers"], + linkstatic = True, + deps = ["@com_github_google_flatbuffers//:runtime_cc"], +) + +proto_library( + name = "test_policy_proto", + srcs = ["test_policy.proto"], + visibility = ["//tools:__subpackages__"], +) + +cc_proto_library( + name = "test_policy_cc_proto", + visibility = ["//tools:__subpackages__"], + deps = [":test_policy_proto"], +) diff --git a/tools/testdata/checked_expr_and.textproto b/tools/testdata/checked_expr_and.textproto new file mode 100644 index 000000000..317b4419a --- /dev/null +++ b/tools/testdata/checked_expr_and.textproto @@ -0,0 +1,73 @@ +# proto-file: google3/google/api/expr/checked.proto +# proto-message: CheckedExpr +# x && y +reference_map { + key: 1 + value { + name: "x" + } +} +reference_map { + key: 2 + value { + name: "y" + } +} +reference_map { + key: 3 + value { + overload_id: "logical_and" + } +} +type_map { + key: 1 + value { + primitive: BOOL + } +} +type_map { + key: 2 + value { + primitive: BOOL + } +} +type_map { + key: 3 + value { + primitive: BOOL + } +} +expr { + id: 3 + call_expr { + function: "_&&_" + args { + id: 1 + ident_expr { + name: "x" + } + } + args { + id: 2 + ident_expr { + name: "y" + } + } + } +} +source_info { + location: "" + line_offsets: 7 + positions { + key: 1 + value: 0 + } + positions { + key: 2 + value: 5 + } + positions { + key: 3 + value: 2 + } +} diff --git a/tools/testdata/const_str.textproto b/tools/testdata/const_str.textproto new file mode 100644 index 000000000..ca8a8986d --- /dev/null +++ b/tools/testdata/const_str.textproto @@ -0,0 +1,23 @@ +# proto-file: google3/google/api/expr/checked.proto +# proto-message: CheckedExpr +type_map { + key: 1 + value { + primitive: STRING + } +} +expr { + id: 1 + const_expr { + string_value: "127.0.0.1" + } +} +source_info { + location: "" + line_offsets: 12 + positions { + key: 1 + value: 0 + } +} + diff --git a/tools/testdata/coverage_example.textproto b/tools/testdata/coverage_example.textproto new file mode 100644 index 000000000..39490586a --- /dev/null +++ b/tools/testdata/coverage_example.textproto @@ -0,0 +1,494 @@ +# proto-file: google3/google/api/expr/checked.proto +# proto-message: CheckedExpr +# +# int1 < int2 && +# (43 > 42) && +# !(bool1 || bool2) && +# 4 / int_divisor >= 1 && +# (ternary_c ? ternary_t : ternary_f) +reference_map: { + key: 1 + value: { + name: "int1" + } +} +reference_map: { + key: 2 + value: { + overload_id: "less_int64" + } +} +reference_map: { + key: 3 + value: { + name: "int2" + } +} +reference_map: { + key: 5 + value: { + overload_id: "greater_int64" + } +} +reference_map: { + key: 7 + value: { + overload_id: "logical_and" + } +} +reference_map: { + key: 8 + value: { + overload_id: "logical_not" + } +} +reference_map: { + key: 9 + value: { + name: "bool1" + } +} +reference_map: { + key: 10 + value: { + name: "bool2" + } +} +reference_map: { + key: 11 + value: { + overload_id: "logical_or" + } +} +reference_map: { + key: 12 + value: { + overload_id: "logical_and" + } +} +reference_map: { + key: 14 + value: { + overload_id: "divide_int64" + } +} +reference_map: { + key: 15 + value: { + name: "int_divisor" + } +} +reference_map: { + key: 16 + value: { + overload_id: "greater_equals_int64" + } +} +reference_map: { + key: 18 + value: { + overload_id: "logical_and" + } +} +reference_map: { + key: 19 + value: { + name: "ternary_c" + } +} +reference_map: { + key: 20 + value: { + overload_id: "conditional" + } +} +reference_map: { + key: 21 + value: { + name: "ternary_t" + } +} +reference_map: { + key: 22 + value: { + name: "ternary_f" + } +} +reference_map: { + key: 23 + value: { + overload_id: "logical_and" + } +} +type_map: { + key: 1 + value: { + primitive: INT64 + } +} +type_map: { + key: 2 + value: { + primitive: BOOL + } +} +type_map: { + key: 3 + value: { + primitive: INT64 + } +} +type_map: { + key: 4 + value: { + primitive: INT64 + } +} +type_map: { + key: 5 + value: { + primitive: BOOL + } +} +type_map: { + key: 6 + value: { + primitive: INT64 + } +} +type_map: { + key: 7 + value: { + primitive: BOOL + } +} +type_map: { + key: 8 + value: { + primitive: BOOL + } +} +type_map: { + key: 9 + value: { + primitive: BOOL + } +} +type_map: { + key: 10 + value: { + primitive: BOOL + } +} +type_map: { + key: 11 + value: { + primitive: BOOL + } +} +type_map: { + key: 12 + value: { + primitive: BOOL + } +} +type_map: { + key: 13 + value: { + primitive: INT64 + } +} +type_map: { + key: 14 + value: { + primitive: INT64 + } +} +type_map: { + key: 15 + value: { + primitive: INT64 + } +} +type_map: { + key: 16 + value: { + primitive: BOOL + } +} +type_map: { + key: 17 + value: { + primitive: INT64 + } +} +type_map: { + key: 18 + value: { + primitive: BOOL + } +} +type_map: { + key: 19 + value: { + primitive: BOOL + } +} +type_map: { + key: 20 + value: { + primitive: BOOL + } +} +type_map: { + key: 21 + value: { + primitive: BOOL + } +} +type_map: { + key: 22 + value: { + primitive: BOOL + } +} +type_map: { + key: 23 + value: { + primitive: BOOL + } +} +source_info: { + location: "" + line_offsets: 109 + positions: { + key: 1 + value: 0 + } + positions: { + key: 2 + value: 5 + } + positions: { + key: 3 + value: 7 + } + positions: { + key: 4 + value: 16 + } + positions: { + key: 5 + value: 19 + } + positions: { + key: 6 + value: 21 + } + positions: { + key: 7 + value: 12 + } + positions: { + key: 8 + value: 28 + } + positions: { + key: 9 + value: 30 + } + positions: { + key: 10 + value: 39 + } + positions: { + key: 11 + value: 36 + } + positions: { + key: 12 + value: 25 + } + positions: { + key: 13 + value: 49 + } + positions: { + key: 14 + value: 51 + } + positions: { + key: 15 + value: 53 + } + positions: { + key: 16 + value: 65 + } + positions: { + key: 17 + value: 68 + } + positions: { + key: 18 + value: 46 + } + positions: { + key: 19 + value: 74 + } + positions: { + key: 20 + value: 84 + } + positions: { + key: 21 + value: 86 + } + positions: { + key: 22 + value: 98 + } + positions: { + key: 23 + value: 70 + } +} +expr: { + id: 18 + call_expr: { + function: "_&&_" + args: { + id: 12 + call_expr: { + function: "_&&_" + args: { + id: 7 + call_expr: { + function: "_&&_" + args: { + id: 2 + call_expr: { + function: "_<_" + args: { + id: 1 + ident_expr: { + name: "int1" + } + } + args: { + id: 3 + ident_expr: { + name: "int2" + } + } + } + } + args: { + id: 5 + call_expr: { + function: "_>_" + args: { + id: 4 + const_expr: { + int64_value: 43 + } + } + args: { + id: 6 + const_expr: { + int64_value: 42 + } + } + } + } + } + } + args: { + id: 8 + call_expr: { + function: "!_" + args: { + id: 11 + call_expr: { + function: "_||_" + args: { + id: 9 + ident_expr: { + name: "bool1" + } + } + args: { + id: 10 + ident_expr: { + name: "bool2" + } + } + } + } + } + } + } + } + args: { + id: 23 + call_expr: { + function: "_&&_" + args: { + id: 16 + call_expr: { + function: "_>=_" + args: { + id: 14 + call_expr: { + function: "_/_" + args: { + id: 13 + const_expr: { + int64_value: 4 + } + } + args: { + id: 15 + ident_expr: { + name: "int_divisor" + } + } + } + } + args: { + id: 17 + const_expr: { + int64_value: 1 + } + } + } + } + args: { + id: 20 + call_expr: { + function: "_?_:_" + args: { + id: 19 + ident_expr: { + name: "ternary_c" + } + } + args: { + id: 21 + ident_expr: { + name: "ternary_t" + } + } + args: { + id: 22 + ident_expr: { + name: "ternary_f" + } + } + } + } + } + } + } +} diff --git a/tools/testdata/exists_macro.textproto b/tools/testdata/exists_macro.textproto new file mode 100644 index 000000000..2cc2043e8 --- /dev/null +++ b/tools/testdata/exists_macro.textproto @@ -0,0 +1,319 @@ +# proto-file: google3/google/api/expr/checked.proto +# proto-message: CheckedExpr + +# [1].exists(x, x == 1) +reference_map: { + key: 5 + value: { + name: "x" + } +} +reference_map: { + key: 6 + value: { + overload_id: "equals" + } +} +reference_map: { + key: 9 + value: { + name: "__result__" + } +} +reference_map: { + key: 10 + value: { + overload_id: "logical_not" + } +} +reference_map: { + key: 11 + value: { + overload_id: "not_strictly_false" + } +} +reference_map: { + key: 12 + value: { + name: "__result__" + } +} +reference_map: { + key: 13 + value: { + overload_id: "logical_or" + } +} +reference_map: { + key: 14 + value: { + name: "__result__" + } +} +type_map: { + key: 1 + value: { + list_type: { + elem_type: { + primitive: INT64 + } + } + } +} +type_map: { + key: 2 + value: { + primitive: INT64 + } +} +type_map: { + key: 5 + value: { + primitive: INT64 + } +} +type_map: { + key: 6 + value: { + primitive: BOOL + } +} +type_map: { + key: 7 + value: { + primitive: INT64 + } +} +type_map: { + key: 8 + value: { + primitive: BOOL + } +} +type_map: { + key: 9 + value: { + primitive: BOOL + } +} +type_map: { + key: 10 + value: { + primitive: BOOL + } +} +type_map: { + key: 11 + value: { + primitive: BOOL + } +} +type_map: { + key: 12 + value: { + primitive: BOOL + } +} +type_map: { + key: 13 + value: { + primitive: BOOL + } +} +type_map: { + key: 14 + value: { + primitive: BOOL + } +} +type_map: { + key: 15 + value: { + primitive: BOOL + } +} +source_info: { + location: "" + line_offsets: 22 + positions: { + key: 1 + value: 0 + } + positions: { + key: 2 + value: 1 + } + positions: { + key: 3 + value: 10 + } + positions: { + key: 4 + value: 11 + } + positions: { + key: 5 + value: 14 + } + positions: { + key: 6 + value: 16 + } + positions: { + key: 7 + value: 19 + } + positions: { + key: 8 + value: 10 + } + positions: { + key: 9 + value: 10 + } + positions: { + key: 10 + value: 10 + } + positions: { + key: 11 + value: 10 + } + positions: { + key: 12 + value: 10 + } + positions: { + key: 13 + value: 10 + } + positions: { + key: 14 + value: 10 + } + positions: { + key: 15 + value: 10 + } + macro_calls: { + key: 15 + value: { + call_expr: { + target: { + id: 1 + list_expr: { + elements: { + id: 2 + const_expr: { + int64_value: 1 + } + } + } + } + function: "exists" + args: { + id: 4 + ident_expr: { + name: "x" + } + } + args: { + id: 6 + call_expr: { + function: "_==_" + args: { + id: 5 + ident_expr: { + name: "x" + } + } + args: { + id: 7 + const_expr: { + int64_value: 1 + } + } + } + } + } + } + } +} +expr: { + id: 15 + comprehension_expr: { + iter_var: "x" + iter_range: { + id: 1 + list_expr: { + elements: { + id: 2 + const_expr: { + int64_value: 1 + } + } + } + } + accu_var: "__result__" + accu_init: { + id: 8 + const_expr: { + bool_value: false + } + } + loop_condition: { + id: 11 + call_expr: { + function: "@not_strictly_false" + args: { + id: 10 + call_expr: { + function: "!_" + args: { + id: 9 + ident_expr: { + name: "__result__" + } + } + } + } + } + } + loop_step: { + id: 13 + call_expr: { + function: "_||_" + args: { + id: 12 + ident_expr: { + name: "__result__" + } + } + args: { + id: 6 + call_expr: { + function: "_==_" + args: { + id: 5 + ident_expr: { + name: "x" + } + } + args: { + id: 7 + const_expr: { + int64_value: 1 + } + } + } + } + } + } + result: { + id: 14 + ident_expr: { + name: "__result__" + } + } + } +} diff --git a/tools/testdata/macro_multiple_references.textproto b/tools/testdata/macro_multiple_references.textproto new file mode 100644 index 000000000..1ad355c5a --- /dev/null +++ b/tools/testdata/macro_multiple_references.textproto @@ -0,0 +1,396 @@ +# proto-file: google3/google/api/expr/checked.proto +# proto-message: CheckedExpr +# has(msg.old_field) || has(msg.old_field) || +# math.least(msg.old_field, msg.old_field) < 0 +reference_map: { + key: 2 + value: { + name: "msg" + } +} +reference_map: { + key: 6 + value: { + name: "msg" + } +} +reference_map: { + key: 9 + value: { + overload_id: "logical_or" + } +} +reference_map: { + key: 12 + value: { + name: "msg" + } +} +reference_map: { + key: 14 + value: { + name: "msg" + } +} +reference_map: { + key: 16 + value: { + overload_id: "math_@min_int_int" + } +} +reference_map: { + key: 17 + value: { + overload_id: "less_int64" + } +} +reference_map: { + key: 19 + value: { + overload_id: "logical_or" + } +} +type_map: { + key: 2 + value: { + map_type: { + key_type: { + primitive: STRING + } + value_type: { + primitive: INT64 + } + } + } +} +type_map: { + key: 4 + value: { + primitive: BOOL + } +} +type_map: { + key: 6 + value: { + map_type: { + key_type: { + primitive: STRING + } + value_type: { + primitive: INT64 + } + } + } +} +type_map: { + key: 8 + value: { + primitive: BOOL + } +} +type_map: { + key: 9 + value: { + primitive: BOOL + } +} +type_map: { + key: 12 + value: { + map_type: { + key_type: { + primitive: STRING + } + value_type: { + primitive: INT64 + } + } + } +} +type_map: { + key: 13 + value: { + primitive: INT64 + } +} +type_map: { + key: 14 + value: { + map_type: { + key_type: { + primitive: STRING + } + value_type: { + primitive: INT64 + } + } + } +} +type_map: { + key: 15 + value: { + primitive: INT64 + } +} +type_map: { + key: 16 + value: { + primitive: INT64 + } +} +type_map: { + key: 17 + value: { + primitive: BOOL + } +} +type_map: { + key: 18 + value: { + primitive: INT64 + } +} +type_map: { + key: 19 + value: { + primitive: BOOL + } +} +source_info: { + location: "" + line_offsets: 89 + positions: { + key: 1 + value: 3 + } + positions: { + key: 2 + value: 4 + } + positions: { + key: 3 + value: 7 + } + positions: { + key: 4 + value: 3 + } + positions: { + key: 5 + value: 25 + } + positions: { + key: 6 + value: 26 + } + positions: { + key: 7 + value: 29 + } + positions: { + key: 8 + value: 25 + } + positions: { + key: 9 + value: 19 + } + positions: { + key: 10 + value: 44 + } + positions: { + key: 11 + value: 54 + } + positions: { + key: 12 + value: 55 + } + positions: { + key: 13 + value: 58 + } + positions: { + key: 14 + value: 70 + } + positions: { + key: 15 + value: 73 + } + positions: { + key: 16 + value: 54 + } + positions: { + key: 17 + value: 85 + } + positions: { + key: 18 + value: 87 + } + positions: { + key: 19 + value: 41 + } + macro_calls: { + key: 4 + value: { + call_expr: { + function: "has" + args: { + id: 3 + select_expr: { + operand: { + id: 2 + ident_expr: { + name: "msg" + } + } + field: "old_field" + } + } + } + } + } + macro_calls: { + key: 8 + value: { + call_expr: { + function: "has" + args: { + id: 7 + select_expr: { + operand: { + id: 6 + ident_expr: { + name: "msg" + } + } + field: "old_field" + } + } + } + } + } + macro_calls: { + key: 16 + value: { + call_expr: { + target: { + id: 10 + ident_expr: { + name: "math" + } + } + function: "least" + args: { + id: 13 + select_expr: { + operand: { + id: 12 + ident_expr: { + name: "msg" + } + } + field: "old_field" + } + } + args: { + id: 15 + select_expr: { + operand: { + id: 14 + ident_expr: { + name: "msg" + } + } + field: "old_field" + } + } + } + } + } +} +expr: { + id: 19 + call_expr: { + function: "_||_" + args: { + id: 9 + call_expr: { + function: "_||_" + args: { + id: 4 + select_expr: { + operand: { + id: 2 + ident_expr: { + name: "msg" + } + } + field: "old_field" + test_only: true + } + } + args: { + id: 8 + select_expr: { + operand: { + id: 6 + ident_expr: { + name: "msg" + } + } + field: "old_field" + test_only: true + } + } + } + } + args: { + id: 17 + call_expr: { + function: "_<_" + args: { + id: 16 + call_expr: { + function: "math.@min" + args: { + id: 13 + select_expr: { + operand: { + id: 12 + ident_expr: { + name: "msg" + } + } + field: "old_field" + } + } + args: { + id: 15 + select_expr: { + operand: { + id: 14 + ident_expr: { + name: "msg" + } + } + field: "old_field" + } + } + } + } + args: { + id: 18 + const_expr: { + int64_value: 0 + } + } + } + } + } +} diff --git a/tools/testdata/macro_nested_macro_call.textproto b/tools/testdata/macro_nested_macro_call.textproto new file mode 100644 index 000000000..11bdf7f6f --- /dev/null +++ b/tools/testdata/macro_nested_macro_call.textproto @@ -0,0 +1,257 @@ +# proto-file: google3/google/api/expr/checked.proto +# proto-message: CheckedExpr +# math.least(has(msg.old_field) ? msg.old_field : 0, 1) +reference_map: { + key: 4 + value: { + name: "msg" + } +} +reference_map: { + key: 7 + value: { + overload_id: "conditional" + } +} +reference_map: { + key: 8 + value: { + name: "msg" + } +} +reference_map: { + key: 12 + value: { + overload_id: "math_@min_int_int" + } +} +type_map: { + key: 4 + value: { + map_type: { + key_type: { + primitive: STRING + } + value_type: { + primitive: INT64 + } + } + } +} +type_map: { + key: 6 + value: { + primitive: BOOL + } +} +type_map: { + key: 7 + value: { + primitive: INT64 + } +} +type_map: { + key: 8 + value: { + map_type: { + key_type: { + primitive: STRING + } + value_type: { + primitive: INT64 + } + } + } +} +type_map: { + key: 9 + value: { + primitive: INT64 + } +} +type_map: { + key: 10 + value: { + primitive: INT64 + } +} +type_map: { + key: 11 + value: { + primitive: INT64 + } +} +type_map: { + key: 12 + value: { + primitive: INT64 + } +} +source_info: { + location: "" + line_offsets: 54 + positions: { + key: 1 + value: 0 + } + positions: { + key: 2 + value: 10 + } + positions: { + key: 3 + value: 14 + } + positions: { + key: 4 + value: 15 + } + positions: { + key: 5 + value: 18 + } + positions: { + key: 6 + value: 14 + } + positions: { + key: 7 + value: 30 + } + positions: { + key: 8 + value: 32 + } + positions: { + key: 9 + value: 35 + } + positions: { + key: 10 + value: 48 + } + positions: { + key: 11 + value: 51 + } + positions: { + key: 12 + value: 10 + } + macro_calls: { + key: 6 + value: { + call_expr: { + function: "has" + args: { + id: 5 + select_expr: { + operand: { + id: 4 + ident_expr: { + name: "msg" + } + } + field: "old_field" + } + } + } + } + } + macro_calls: { + key: 12 + value: { + call_expr: { + target: { + id: 1 + ident_expr: { + name: "math" + } + } + function: "least" + args: { + id: 7 + call_expr: { + function: "_?_:_" + args: { + id: 6 + } + args: { + id: 9 + select_expr: { + operand: { + id: 8 + ident_expr: { + name: "msg" + } + } + field: "old_field" + } + } + args: { + id: 10 + const_expr: { + int64_value: 0 + } + } + } + } + args: { + id: 11 + const_expr: { + int64_value: 1 + } + } + } + } + } +} +expr: { + id: 12 + call_expr: { + function: "math.@min" + args: { + id: 7 + call_expr: { + function: "_?_:_" + args: { + id: 6 + select_expr: { + operand: { + id: 4 + ident_expr: { + name: "msg" + } + } + field: "old_field" + test_only: true + } + } + args: { + id: 9 + select_expr: { + operand: { + id: 8 + ident_expr: { + name: "msg" + } + } + field: "old_field" + } + } + args: { + id: 10 + const_expr: { + int64_value: 0 + } + } + } + } + args: { + id: 11 + const_expr: { + int64_value: 1 + } + } + } +} diff --git a/tools/testdata/macro_single_reference.textproto b/tools/testdata/macro_single_reference.textproto new file mode 100644 index 000000000..f34c21ad9 --- /dev/null +++ b/tools/testdata/macro_single_reference.textproto @@ -0,0 +1,81 @@ +# proto-file: google3/google/api/expr/checked.proto +# proto-message: CheckedExpr +# has(msg.old_field) +reference_map: { + key: 2 + value: { + name: "msg" + } +} +type_map: { + key: 2 + value: { + map_type: { + key_type: { + primitive: STRING + } + value_type: { + primitive: STRING + } + } + } +} +type_map: { + key: 4 + value: { + primitive: BOOL + } +} +source_info: { + location: "" + line_offsets: 15 + positions: { + key: 1 + value: 3 + } + positions: { + key: 2 + value: 4 + } + positions: { + key: 3 + value: 7 + } + positions: { + key: 4 + value: 3 + } + macro_calls: { + key: 4 + value: { + call_expr: { + function: "has" + args: { + id: 3 + select_expr: { + operand: { + id: 2 + ident_expr: { + name: "msg" + } + } + field: "old_field" + } + } + } + } + } +} +expr: { + id: 4 + select_expr: { + operand: { + id: 2 + ident_expr: { + name: "msg" + } + } + field: "old_field" + test_only: true + } +} diff --git a/tools/testdata/msg_new_field.textproto b/tools/testdata/msg_new_field.textproto new file mode 100644 index 000000000..3676d03a0 --- /dev/null +++ b/tools/testdata/msg_new_field.textproto @@ -0,0 +1,52 @@ +# proto-file: google3/google/api/expr/checked.proto +# proto-message: CheckedExpr +# msg.new_field +reference_map: { + key: 1 + value: { + name: "msg" + } +} +type_map: { + key: 1 + value: { + map_type: { + key_type: { + primitive: STRING + } + value_type: { + primitive: STRING + } + } + } +} +type_map: { + key: 2 + value: { + primitive: STRING + } +} +source_info: { + location: "" + line_offsets: 10 + positions: { + key: 1 + value: 0 + } + positions: { + key: 2 + value: 3 + } +} +expr: { + id: 2 + select_expr: { + operand: { + id: 1 + ident_expr: { + name: "msg" + } + } + field: "new_field" + } +} diff --git a/tools/testdata/msg_new_field_int.textproto b/tools/testdata/msg_new_field_int.textproto new file mode 100644 index 000000000..c7fd9bb43 --- /dev/null +++ b/tools/testdata/msg_new_field_int.textproto @@ -0,0 +1,52 @@ +# proto-file: google3/google/api/expr/checked.proto +# proto-message: CheckedExpr +# msg.new_field +reference_map: { + key: 1 + value: { + name: "msg" + } +} +type_map: { + key: 1 + value: { + map_type: { + key_type: { + primitive: STRING + } + value_type: { + primitive: INT64 + } + } + } +} +type_map: { + key: 2 + value: { + primitive: INT64 + } +} +source_info: { + location: "" + line_offsets: 14 + positions: { + key: 1 + value: 0 + } + positions: { + key: 2 + value: 3 + } +} +expr: { + id: 2 + select_expr: { + operand: { + id: 1 + ident_expr: { + name: "msg" + } + } + field: "new_field" + } +} diff --git a/tools/testdata/test_policy.proto b/tools/testdata/test_policy.proto new file mode 100644 index 000000000..b5d424c04 --- /dev/null +++ b/tools/testdata/test_policy.proto @@ -0,0 +1,73 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Test schema representing client-configured policies. +// It is used by the `proto_to_predicate` tool to translate Protobuf policies +// into CEL predicates. +edition = "2023"; + +package cel.cpp.tools; + +option cc_enable_arenas = true; + +// Represents the targeted client agent. +message Agent { + string name = 1 [json_name = "id"]; +} + +// Specifies additional metadata tool annotations. +message ToolAnnotations { + bool read_only_hint = 1; +} + +// Represents a mapped nested message entry value inside map fields. +message Members { + repeated string principals = 1; + + repeated string regions = 2; + + bool all_users = 3; + + bool all_authenticated_users = 4; +} + +// Represents a metadata tool block. +message Tool { + // The name of the tool. + string name = 1; + + // Additional metadata annotations for the tool. + ToolAnnotations annotations = 2; + + // A string-to-string map, transpiled as conjoined existence and equality + // checks. + map labels = 3; + + // A map with string keys representing roles and Member instances as values. + map role_members = 4; +} + +// Represents a policy mapping destination block. +message Target { + oneof kind { + Agent agent = 1; + Tool tool = 2; + } +} + +// Represents the top-level policy containing multiple alternate destination +// rules. +message Policy { + repeated Target destinations = 1; +} diff --git a/v1beta1/BUILD b/v1beta1/BUILD deleted file mode 100644 index 02cd9fad9..000000000 --- a/v1beta1/BUILD +++ /dev/null @@ -1,55 +0,0 @@ -# Description -# Libraries for working with the v1beta1 API. -# -# Uses the namespace google::api:expr::v1beta1. - -package(default_visibility = ["//visibility:public"]) - -licenses(["notice"]) # Apache 2.0 - -cc_library( - name = "converters", - srcs = [ - "converters.cc", - ], - hdrs = [ - "converters.h", - ], - deps = [ - "//common:converters", - "//common:macros", - "//common:value", - "//internal:holder", - "//internal:map_impl", - "//internal:proto_util", - "//internal:status_util", - "//protoutil:type_registry", - "@com_google_absl//absl/container:node_hash_map", - "@com_google_absl//absl/strings", - "@com_google_googleapis//google/api/expr/v1beta1:eval_cc_proto", - "@com_google_googleapis//google/api/expr/v1beta1:value_cc_proto", - "@com_google_googleapis//google/rpc:code_cc_proto", - "@com_google_protobuf//:protobuf", - ], -) - -cc_test( - name = "converters_test", - srcs = ["converters_test.cc"], - data = [ - "@com_google_cel_spec//testdata", - ], - deps = [ - ":converters", - "//common:value", - "//internal:status_util", - "//protoutil:converters", - "//protoutil:type_registry", - "//testutil:test_data_io", - "//testutil:test_data_util", - "@com_google_cel_spec//testdata:test_data_cc_proto", - "@com_google_googleapis//google/type:money_cc_proto", - "@com_google_googletest//:gtest_main", - "@com_google_protobuf//:protobuf", - ], -) diff --git a/v1beta1/converters.cc b/v1beta1/converters.cc deleted file mode 100644 index 97e6b3b8f..000000000 --- a/v1beta1/converters.cc +++ /dev/null @@ -1,379 +0,0 @@ -#include "v1beta1/converters.h" - -#include "google/protobuf/duration.pb.h" -#include "google/protobuf/struct.pb.h" -#include "google/protobuf/timestamp.pb.h" -#include "google/rpc/code.pb.h" -#include "absl/container/node_hash_map.h" -#include "absl/strings/str_cat.h" -#include "common/converters.h" -#include "common/macros.h" -#include "internal/holder.h" -#include "internal/map_impl.h" -#include "internal/proto_util.h" -#include "internal/status_util.h" - -namespace google { -namespace api { -namespace expr { -namespace v1beta1 { - -using expr::internal::EncodeDuration; -using expr::internal::EncodeTime; -using expr::internal::IsOk; -using expr::internal::OkStatus; -using protoutil::TypeRegistry; - -namespace { - -/** A visitor that encodes ExprValues. */ -struct ToExprValue { - common::Value from; - v1beta1::ExprValue* result; - - google::rpc::Status operator()(std::nullptr_t value) { - result->mutable_value()->set_null_value( - ::google::protobuf::NullValue::NULL_VALUE); - return OkStatus(); - } - - google::rpc::Status operator()(bool value) { - result->mutable_value()->set_bool_value(value); - return OkStatus(); - } - - google::rpc::Status operator()(int64_t value) { - result->mutable_value()->set_int64_value(value); - return OkStatus(); - } - - google::rpc::Status operator()(uint64_t value) { - result->mutable_value()->set_uint64_value(value); - return OkStatus(); - } - - google::rpc::Status operator()(double value) { - result->mutable_value()->set_double_value(value); - return OkStatus(); - } - - google::rpc::Status operator()(const common::EnumValue& value) { - auto enum_value = result->mutable_value()->mutable_enum_value(); - enum_value->set_value(value.value()); - enum_value->set_type(std::string(value.type().full_name())); - return OkStatus(); - } - - google::rpc::Status operator()(absl::string_view value) { - if (from.kind() == common::Value::Kind::kBytes) { - result->mutable_value()->set_bytes_value(std::string(value)); - } else { - result->mutable_value()->set_string_value(std::string(value)); - } - return OkStatus(); - } - - bool CheckAndEncodeIfError(const google::rpc::Status& value) { - if (value.code() != google::rpc::Code::OK) { - *result->mutable_error()->add_errors() = value; - return false; - } - return true; - } - - void EncodeMessage(const google::protobuf::Message& value) { - result->mutable_value()->mutable_object_value()->PackFrom(value); - } - - google::rpc::Status EncodeValue(const common::Value& value, - v1beta1::Value* sub_value) { - ExprValue expr_value; - auto status = ValueTo(value, &expr_value); - if (IsOk(status)) { - sub_value->Swap(expr_value.mutable_value()); - return status; - } - *result = expr_value; - return status; - } - - google::rpc::Status operator()(absl::Duration value) { - google::protobuf::Duration duration; - auto status = EncodeDuration(value, &duration); - if (CheckAndEncodeIfError(status)) { - EncodeMessage(duration); - } - return status; - } - - google::rpc::Status operator()(absl::Time value) { - google::protobuf::Timestamp time; - auto status = EncodeTime(value, &time); - if (CheckAndEncodeIfError(status)) { - EncodeMessage(time); - } - return status; - } - - google::rpc::Status operator()(const common::List& value) { - auto& list_value = *result->mutable_value()->mutable_list_value(); - return value.ForEach([this, &list_value](const common::Value& elem) { - return EncodeValue(elem, list_value.add_values()); - }); - } - - google::rpc::Status operator()(const common::Map& value) { - auto& map_value = *result->mutable_value()->mutable_map_value(); - return value.ForEach([this, &map_value](const common::Value& key, - const common::Value& value) { - auto& entry = *map_value.add_entries(); - RETURN_IF_STATUS_ERROR(EncodeValue(key, entry.mutable_key())); - return EncodeValue(value, entry.mutable_value()); - }); - } - - google::rpc::Status operator()(const common::Object& value) { - value.To(result->mutable_value()->mutable_object_value()); - return OkStatus(); - } - - google::rpc::Status operator()(const common::Type& value) { - result->mutable_value()->set_type_value(std::string(value.full_name())); - return OkStatus(); - } - - google::rpc::Status operator()(const common::Unknown& value) { - auto& unknown = *result->mutable_unknown(); - for (const auto& id : value.ids()) { - unknown.add_exprs()->set_id(id.value()); - } - return OkStatus(); - } - - google::rpc::Status operator()(const common::Error& value) { - auto& error_set = *result->mutable_error(); - for (const auto& error : value.errors()) { - *error_set.add_errors() = error; - } - return OkStatus(); - } -}; - -/** - * Creates a new common::Value potentially with a reference on parent, if not - * null. - */ -common::Value ValueFor(const v1beta1::Value* value, common::ParentRef parent, - const TypeRegistry* registry); - -template -class ListValue final : public common::List { - public: - template - explicit ListValue(const TypeRegistry* registry, Args&&... args) - : registry_(registry), holder_(std::forward(args)...) {} - - std::size_t size() const override { return holder_.value().values_size(); } - - common::Value Get(std::size_t index) const override { - if (index >= static_cast(holder_.value().values_size())) { - return common::Value::FromError( - internal::OutOfRangeError(index, holder_.value().values_size())); - } - return ValueFor(&holder_.value().values(index), SelfRefProvider(), - registry_); - } - - google::rpc::Status ForEach( - const std::function& call) - const override { - auto ref = SelfRefProvider(); - for (const auto& elem : holder_.value().values()) { - RETURN_IF_STATUS_ERROR(call(ValueFor(&elem, ref, registry_))); - } - return OkStatus(); - } - - bool owns_value() const override { return HolderPolicy::kOwnsValue; } - - private: - const TypeRegistry* registry_; - internal::Holder holder_; -}; - -using ListValueCopy = ListValue; -using ListValueOwned = ListValue; - -common::Value BuildMapFor(const v1beta1::MapValue* map_value, - common::ParentRef parent, - const TypeRegistry* registry) { - absl::node_hash_map result; - for (const auto& entry : map_value->entries()) { - result.emplace(ValueFor(&entry.key(), parent, registry), - ValueFor(&entry.value(), parent, registry)); - } - // The keys and values grabbed a ref on parent if needed, so we don't need one - // separately. - return common::Value::MakeMap(std::move(result)); -} - -common::Value BuildMapFrom(v1beta1::MapValue&& map_value, - const TypeRegistry* registry) { - absl::node_hash_map result; - for (v1beta1::MapValue::Entry& entry : *map_value.mutable_entries()) { - result.emplace( - ValueFrom(absl::WrapUnique(entry.release_key()), registry), - ValueFrom(absl::WrapUnique(entry.release_value()), registry)); - } - return common::Value::MakeMap(std::move(result)); -} - -common::Value BuildMapFrom(const v1beta1::MapValue& map_value, - const TypeRegistry* registry) { - absl::node_hash_map result; - for (auto& entry : map_value.entries()) { - result.emplace(ValueFrom(entry.key(), registry), - ValueFrom(entry.value(), registry)); - } - return common::Value::MakeMap(std::move(result)); -} - -} // namespace - -common::Value ValueFrom(const v1beta1::Value& value, - const TypeRegistry* registry) { - switch (value.kind_case()) { - case v1beta1::Value::kNullValue: - return common::Value::NullValue(); - case v1beta1::Value::kBoolValue: - return common::Value::FromBool(value.bool_value()); - case v1beta1::Value::kInt64Value: - return common::Value::FromInt(value.int64_value()); - case v1beta1::Value::kUint64Value: - return common::Value::FromUInt(value.uint64_value()); - case v1beta1::Value::kDoubleValue: - return common::Value::FromDouble(value.double_value()); - case v1beta1::Value::kStringValue: - return common::Value::FromString(value.string_value()); - case v1beta1::Value::kBytesValue: - return common::Value::FromBytes(value.bytes_value()); - case v1beta1::Value::kTypeValue: - return common::Value::FromType(value.type_value()); - case v1beta1::Value::kListValue: - return common::Value::MakeList(registry, - value.list_value()); - case v1beta1::Value::kObjectValue: - return registry->ValueFrom(value.object_value()); - case v1beta1::Value::kMapValue: - return BuildMapFrom(value.map_value(), registry); - default: - return common::Value::FromError(internal::UnimplementedError( - absl::StrCat("Unimplemented value kind: ", value.kind_case()))); - } -} - -common::Value ValueFrom(const v1beta1::ExprValue& value, - const TypeRegistry* registry) { - switch (value.kind_case()) { - case v1beta1::ExprValue::kValue: - return ValueFrom(value.value(), registry); - case v1beta1::ExprValue::kError: - return common::Value::FromError(common::Error(value.error().errors())); - case v1beta1::ExprValue::kUnknown: { - std::vector ids; - ids.reserve(value.unknown().exprs_size()); - for (const auto& id_ref : value.unknown().exprs()) { - ids.emplace_back(id_ref.id()); - } - return common::Value::FromUnknown(common::Unknown(ids)); - } - default: - return common::Value::FromError(internal::UnimplementedError( - absl::StrCat("Unimplemented expr value kind: ", value.kind_case()))); - } -} - -common::Value ValueFrom(v1beta1::Value&& value, const TypeRegistry* registry) { - switch (value.kind_case()) { - case v1beta1::Value::kListValue: - return common::Value::MakeList( - registry, absl::WrapUnique(value.release_list_value())); - case v1beta1::Value::kMapValue: - return BuildMapFrom(std::move(*value.mutable_map_value()), registry); - default: - // All other cases do not take advantage of the rvalue. - return ValueFrom(value, registry); - } -} - -common::Value ValueFrom(v1beta1::ExprValue&& value, - const TypeRegistry* registry) { - switch (value.kind_case()) { - case v1beta1::ExprValue::kValue: - return ValueFrom(absl::WrapUnique(value.release_value()), registry); - default: - // All other cases cannot take advantage of the rvalue. - return ValueFrom(value, registry); - } -} - -common::Value ValueFor(const v1beta1::ExprValue* value, - const TypeRegistry* registry) { - switch (value->kind_case()) { - case v1beta1::ExprValue::kValue: - return ValueFor(&value->value(), registry); - default: - // All others can't take advantage of the unowned value. - return ValueFrom(*value, registry); - } -} - -common::Value ValueFrom(std::unique_ptr value, - const TypeRegistry* registry) { - return ValueFrom(std::move(*value), registry); -} - -common::Value ValueFrom(std::unique_ptr value, - const TypeRegistry* registry) { - return ValueFrom(std::move(*value), registry); -} - -common::Value ValueFor(const v1beta1::Value* value, - const TypeRegistry* registry) { - return ValueFor(value, common::NoParent(), registry); -} - -google::rpc::Status ValueTo(const common::Value& value, - v1beta1::ExprValue* result) { - return value.visit(ToExprValue{value, result}); -} - -namespace { - -common::Value ValueFor(const v1beta1::Value* value, common::ParentRef parent, - const TypeRegistry* registry) { - if (parent == absl::nullopt) { - return ValueFrom(*value, registry); - } - switch (value->kind_case()) { - case v1beta1::Value::kListValue: - if (parent->RequiresReference()) { - return common::Value::MakeList>>( - registry, parent->GetRef(), &value->list_value()); - } else { - return common::Value::MakeList>( - registry, &value->list_value()); - } - case v1beta1::Value::kMapValue: - return BuildMapFor(&value->map_value(), parent, registry); - default: - return ValueFrom(*value, registry); - } -} - -} // namespace -} // namespace v1beta1 -} // namespace expr -} // namespace api -} // namespace google diff --git a/v1beta1/converters.h b/v1beta1/converters.h deleted file mode 100644 index 549e5b3e4..000000000 --- a/v1beta1/converters.h +++ /dev/null @@ -1,50 +0,0 @@ -#ifndef THIRD_PARTY_CEL_CPP_COMMON_V1_BETA1_CONVERTERS_H_ -#define THIRD_PARTY_CEL_CPP_COMMON_V1_BETA1_CONVERTERS_H_ - -#include "google/api/expr/v1beta1/eval.pb.h" -#include "google/api/expr/v1beta1/value.pb.h" -#include "common/converters.h" -#include "common/value.h" -#include "protoutil/type_registry.h" - -namespace google { -namespace api { -namespace expr { -namespace v1beta1 { - -/** Decode a v1beta1::Value. */ -common::Value ValueFrom(const v1beta1::Value& value, - const protoutil::TypeRegistry* registry); -/** Decode a v1beta1::Value. */ -common::Value ValueFrom(v1beta1::Value&& value, - const protoutil::TypeRegistry* registry); -/** Decode a v1beta1::Value. */ -common::Value ValueFrom(std::unique_ptr value, - const protoutil::TypeRegistry* registry); -/** Decode a v1beta1::Value. */ -common::Value ValueFor(const v1beta1::Value* value, - const protoutil::TypeRegistry* registry); - -/** Decode a v1beta1::ExprValue. */ -common::Value ValueFrom(const v1beta1::ExprValue& value, - const protoutil::TypeRegistry* registry); -/** Decode a v1beta1::ExprValue. */ -common::Value ValueFrom(v1beta1::ExprValue&& value, - const protoutil::TypeRegistry* registry); -/** Decode a v1beta1::ExprValue. */ -common::Value ValueFrom(std::unique_ptr value, - const protoutil::TypeRegistry* registry); -/** Decode a v1beta1::ExprValue. */ -common::Value ValueFor(const v1beta1::ExprValue* value, - const protoutil::TypeRegistry* registry); - -/** Encode a v1beta1::ExprValue. */ -google::rpc::Status ValueTo(const common::Value& value, - v1beta1::ExprValue* result); - -} // namespace v1beta1 -} // namespace expr -} // namespace api -} // namespace google - -#endif // THIRD_PARTY_CEL_CPP_COMMON_V1_BETA1_CONVERTERS_H_ diff --git a/v1beta1/converters_test.cc b/v1beta1/converters_test.cc deleted file mode 100644 index 5027171fa..000000000 --- a/v1beta1/converters_test.cc +++ /dev/null @@ -1,212 +0,0 @@ -#include "v1beta1/converters.h" - -#include "google/protobuf/empty.pb.h" -#include "google/type/money.pb.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "common/value.h" -#include "internal/status_util.h" -#include "protoutil/converters.h" -#include "protoutil/type_registry.h" -#include "testutil/test_data_io.h" -#include "testutil/test_data_util.h" -#include "testdata/test_data.pb.h" - -namespace google { -namespace api { -namespace expr { -namespace v1beta1 { - -using protoutil::TypeRegistry; -using testdata::TestValue; - -namespace { - -const TypeRegistry* kReg = []() { - auto* reg = new TypeRegistry; - protoutil::RegisterConvertersWith(reg); - return reg; -}(); - -class ValueTest : public ::testing::TestWithParam { - public: - ValueTest() { v1beta1::InitValueDifferencer(&v1beta1_differ_); } - - ::testing::AssertionResult IsEquiv(const v1beta1::ExprValue& lhs, - const v1beta1::ExprValue& rhs) { - std::string diff; - v1beta1_differ_.ReportDifferencesToString(&diff); - if (v1beta1_differ_.Compare(lhs, rhs)) { - return ::testing::AssertionSuccess(); - } - return ::testing::AssertionFailure() << diff; - } - - private: - google::protobuf::util::MessageDifferencer v1beta1_differ_; -}; - -TEST_P(ValueTest, SelfEqual) { - for (const auto& lhs : GetParam().v1beta1()) { - SCOPED_TRACE(lhs.ShortDebugString()); - common::Value lhs_val = ValueFrom(lhs, kReg); - for (const auto& rhs : GetParam().v1beta1()) { - SCOPED_TRACE(rhs.ShortDebugString()); - common::Value rhs_val = ValueFrom(rhs, kReg); - EXPECT_EQ(lhs_val.hash_code(), rhs_val.hash_code()); - EXPECT_EQ(lhs_val, rhs_val); - } - } -} - -TEST_P(ValueTest, RoundTrip_FromRef) { - for (const auto& expected : GetParam().v1beta1()) { - SCOPED_TRACE(expected.ShortDebugString()); - auto cel_value = ValueFrom(expected, kReg); - EXPECT_TRUE(cel_value.owns_value()); - v1beta1::ExprValue actual; - ValueTo(cel_value, &actual); - EXPECT_TRUE(IsEquiv(expected, actual)); - } -} - -TEST_P(ValueTest, RoundTrip_FromPtr) { - for (const auto& expected : GetParam().v1beta1()) { - SCOPED_TRACE(expected.ShortDebugString()); - auto cel_value = ValueFrom(absl::make_unique(expected), kReg); - EXPECT_TRUE(cel_value.owns_value()); - v1beta1::ExprValue actual; - ValueTo(cel_value, &actual); - EXPECT_TRUE(IsEquiv(expected, actual)); - } -} - -TEST_P(ValueTest, RoundTrip_FromMove) { - for (const auto& expected : GetParam().v1beta1()) { - SCOPED_TRACE(expected.ShortDebugString()); - auto cel_value = ValueFrom(ExprValue(expected), kReg); - EXPECT_TRUE(cel_value.owns_value()); - v1beta1::ExprValue actual; - ValueTo(cel_value, &actual); - EXPECT_TRUE(IsEquiv(expected, actual)); - } -} - -TEST_P(ValueTest, RoundTrip_For) { - for (const auto& expected : GetParam().v1beta1()) { - SCOPED_TRACE(expected.ShortDebugString()); - auto cel_value = v1beta1::ValueFor(&expected, kReg); - v1beta1::ExprValue actual; - ValueTo(cel_value, &actual); - EXPECT_TRUE(IsEquiv(expected, actual)); - } -} - -INSTANTIATE_TEST_SUITE_P( - UniqueValues, ValueTest, - ::testing::ValuesIn( - testutil::ReadTestData("unique_values").test_values().values()), - testutil::TestDataParamName()); - -class UniqueValueTest - : public ::testing::TestWithParam> { - public: -}; - -TEST_P(UniqueValueTest, NotEqual) { - for (const auto& lhs : GetParam().first.v1beta1()) { - SCOPED_TRACE(lhs.ShortDebugString()); - auto lhs_value = ValueFor(&lhs, kReg); - for (const auto& rhs : GetParam().second.v1beta1()) { - SCOPED_TRACE(rhs.ShortDebugString()); - auto rhs_value = ValueFor(&rhs, kReg); - EXPECT_NE(lhs_value, rhs_value); - } - } -} -INSTANTIATE_TEST_SUITE_P( - All, UniqueValueTest, - ::testing::ValuesIn(testutil::AllPairs( - testutil::ReadTestData("unique_values").test_values())), - testutil::TestDataParamName()); - -TEST(ConvertersTest, List) { - auto value = - ValueFrom(testutil::NewListValue(1, 2u, 3.0, "four").v1beta1(0), kReg); - const auto& list = value.list_value(); - - auto error = common::Value::FromError(expr::internal::OutOfRangeError(4, 4)); - EXPECT_EQ(4, list.size()); - EXPECT_EQ(common::Value::FromInt(1), list.Get(0)); - EXPECT_EQ(common::Value::FromUInt(2), list.Get(1)); - EXPECT_EQ(common::Value::FromDouble(3), list.Get(2)); - EXPECT_EQ(common::Value::ForString("four"), list.Get(3)); - EXPECT_EQ(error, list.Get(4)); - - EXPECT_EQ(common::Value::FromBool(true), - list.Contains(common::Value::FromInt(1))); - EXPECT_EQ(common::Value::FromBool(true), - list.Contains(common::Value::FromUInt(2))); - EXPECT_EQ(common::Value::FromBool(true), - list.Contains(common::Value::FromDouble(3))); - EXPECT_EQ(common::Value::FromBool(true), - list.Contains(common::Value::ForString("four"))); - EXPECT_EQ(error, list.Contains(error)); - - EXPECT_EQ(common::Value::FromBool(false), - list.Contains(common::Value::FromUInt(1))); - EXPECT_EQ(common::Value::FromBool(false), - list.Contains(common::Value::FromInt(2))); - EXPECT_EQ(common::Value::FromBool(false), - list.Contains(common::Value::FromInt(3))); - EXPECT_EQ(common::Value::FromBool(false), - list.Contains(common::Value::FromInt(4))); - - int i = 0; - list.ForEach([&i, &list](const common::Value& elem) { - EXPECT_EQ(list.Get(i++), elem); - return internal::OkStatus(); - }); - EXPECT_EQ(i, 4); -} - -TEST(ConvertersTest, Map) { - auto value = - ValueFrom(testutil::NewMapValue(1, 2u, 3.0, "four").v1beta1(0), kReg); - const auto& map = value.map_value(); - - auto error1 = common::Value::FromError(expr::internal::NoSuchKey("2u")); - auto error2 = common::Value::FromError(expr::internal::NoSuchKey("\"four\"")); - EXPECT_EQ(2, map.size()); - EXPECT_EQ(common::Value::FromUInt(2), map.Get(common::Value::FromInt(1))); - EXPECT_EQ(error1, map.Get(common::Value::FromUInt(2))); - EXPECT_EQ(common::Value::ForString("four"), - map.Get(common::Value::FromDouble(3))); - EXPECT_EQ(error2, map.Get(common::Value::ForString("four"))); - - int i = 0; - map.ForEach([&i, &map](const common::Value& key, const common::Value& value) { - i++; - EXPECT_EQ(value, map.Get(key)); - return expr::internal::OkStatus(); - }); - EXPECT_EQ(i, 2); -} - -TEST(ConvertersTest, BadValue) { - v1beta1::ExprValue result; - auto bad_value = common::Value::FromTime(absl::InfiniteFuture()); - auto status = ValueTo(bad_value, &result); - auto expected = common::Value::FromError( - expr::internal::InvalidArgumentError("time above max")); - // Status returns the expected error code. - EXPECT_EQ(common::Value::FromError(status), expected); - // The result also encodes the error. - EXPECT_EQ(expected, ValueFrom(result, kReg)); -} - -} // namespace -} // namespace v1beta1 -} // namespace expr -} // namespace api -} // namespace google diff --git a/validator/BUILD b/validator/BUILD new file mode 100644 index 000000000..9910a6b97 --- /dev/null +++ b/validator/BUILD @@ -0,0 +1,214 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "validator", + srcs = ["validator.cc"], + hdrs = ["validator.h"], + deps = [ + "//checker:type_check_issue", + "//checker:validation_result", + "//common:ast", + "//common:navigable_ast", + "//common:source", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + ], +) + +cc_test( + name = "validator_test", + srcs = ["validator_test.cc"], + deps = [ + ":validator", + "//checker:type_check_issue", + "//common:ast", + "//common:expr", + "//common:source", + "//internal:testing", + "@com_google_absl//absl/strings:string_view", + ], +) + +cc_test( + name = "timestamp_literal_validator_test", + srcs = ["timestamp_literal_validator_test.cc"], + deps = [ + ":timestamp_literal_validator", + ":validator", + "//checker:validation_result", + "//compiler", + "//compiler:compiler_factory", + "//compiler:standard_library", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + ], +) + +cc_library( + name = "timestamp_literal_validator", + srcs = ["timestamp_literal_validator.cc"], + hdrs = ["timestamp_literal_validator.h"], + deps = [ + ":validator", + "//common:constant", + "//common:navigable_ast", + "//common:standard_definitions", + "//internal:time", + "//tools:navigable_ast", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/time", + ], +) + +cc_library( + name = "ast_depth_validator", + srcs = ["ast_depth_validator.cc"], + hdrs = ["ast_depth_validator.h"], + deps = [ + ":validator", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "homogeneous_literal_validator", + srcs = ["homogeneous_literal_validator.cc"], + hdrs = ["homogeneous_literal_validator.h"], + deps = [ + ":validator", + "//common:ast", + "//common:expr", + "//common:navigable_ast", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "regex_validator", + srcs = ["regex_validator.cc"], + hdrs = ["regex_validator.h"], + deps = [ + ":validator", + "//common:ast", + "//common:constant", + "//common:expr", + "//common:navigable_ast", + "//common:standard_definitions", + "//internal:re2_options", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/strings", + "@com_googlesource_code_re2//:re2", + ], +) + +cc_test( + name = "homogeneous_literal_validator_test", + srcs = ["homogeneous_literal_validator_test.cc"], + deps = [ + ":homogeneous_literal_validator", + ":validator", + "//checker:validation_result", + "//common:decl", + "//common:type", + "//compiler", + "//compiler:compiler_factory", + "//compiler:optional", + "//compiler:standard_library", + "//extensions:strings", + "//internal:status_macros", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + ], +) + +cc_test( + name = "ast_depth_validator_test", + srcs = ["ast_depth_validator_test.cc"], + deps = [ + ":ast_depth_validator", + ":validator", + "//checker:type_check_issue", + "//compiler", + "//compiler:compiler_factory", + "//compiler:standard_library", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/log:absl_check", + ], +) + +cc_test( + name = "regex_validator_test", + srcs = ["regex_validator_test.cc"], + deps = [ + ":regex_validator", + ":validator", + "//common:decl", + "//common:type", + "//compiler", + "//compiler:compiler_factory", + "//compiler:standard_library", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/status:statusor", + ], +) + +cc_library( + name = "comprehension_nesting_validator", + srcs = ["comprehension_nesting_validator.cc"], + hdrs = ["comprehension_nesting_validator.h"], + deps = [ + ":validator", + "//common:expr", + "//common:navigable_ast", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "comprehension_nesting_validator_test", + srcs = ["comprehension_nesting_validator_test.cc"], + deps = [ + ":comprehension_nesting_validator", + ":validator", + "//compiler", + "//compiler:compiler_factory", + "//compiler:standard_library", + "//extensions:bindings_ext", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/status:statusor", + ], +) + +licenses(["notice"]) diff --git a/validator/ast_depth_validator.cc b/validator/ast_depth_validator.cc new file mode 100644 index 000000000..0f6b8d93d --- /dev/null +++ b/validator/ast_depth_validator.cc @@ -0,0 +1,34 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "validator/ast_depth_validator.h" + +#include "absl/strings/str_cat.h" +#include "validator/validator.h" + +namespace cel { + +Validation AstDepthValidator(int max_depth) { + return Validation([max_depth](ValidationContext& context) { + int height = context.navigable_ast().Root().height(); + if (height > max_depth) { + context.ReportError(absl::StrCat("AST depth ", height, + " exceeds maximum of ", max_depth)); + return false; + } + return true; + }); +} + +} // namespace cel diff --git a/validator/ast_depth_validator.h b/validator/ast_depth_validator.h new file mode 100644 index 000000000..a640af12e --- /dev/null +++ b/validator/ast_depth_validator.h @@ -0,0 +1,27 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_VALIDATOR_AST_DEPTH_VALIDATOR_H_ +#define THIRD_PARTY_CEL_CPP_VALIDATOR_AST_DEPTH_VALIDATOR_H_ +#include "validator/validator.h" + +namespace cel { + +// Returns a `Validation` that checks the AST depth is less than or equal to +// max_depth. +Validation AstDepthValidator(int max_depth); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_VALIDATOR_AST_DEPTH_VALIDATOR_H_ diff --git a/validator/ast_depth_validator_test.cc b/validator/ast_depth_validator_test.cc new file mode 100644 index 000000000..eda59b40d --- /dev/null +++ b/validator/ast_depth_validator_test.cc @@ -0,0 +1,81 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "validator/ast_depth_validator.h" + +#include +#include + +#include "absl/log/absl_check.h" +#include "checker/type_check_issue.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "compiler/standard_library.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "validator/validator.h" + +namespace cel { +namespace { + +std::unique_ptr CreateCompiler() { + auto builder = NewCompilerBuilder(internal::GetSharedTestingDescriptorPool()); + ABSL_CHECK_OK(builder); + ABSL_CHECK_OK((*builder)->AddLibrary(StandardCompilerLibrary())); + auto compiler = (*builder)->Build(); + ABSL_CHECK_OK(compiler); + return *std::move(compiler); +} + +TEST(AstDepthValidatorTest, Basic) { + auto compiler = CreateCompiler(); + ASSERT_OK_AND_ASSIGN(auto result, compiler->Compile("1 + 2 + 3")); + + Validator validator; + validator.AddValidation(AstDepthValidator(10)); + auto output = validator.Validate(*result.GetAst()); + EXPECT_TRUE(output.valid); + + Validator validator2; + validator2.AddValidation(AstDepthValidator(2)); + output = validator2.Validate(*result.GetAst()); + EXPECT_FALSE(output.valid); + EXPECT_THAT(output.issues, + testing::Contains(testing::Property( + &TypeCheckIssue::message, + testing::Eq("AST depth 3 exceeds maximum of 2")))); +} + +TEST(AstDepthValidatorTest, Nested) { + auto compiler = CreateCompiler(); + ASSERT_OK_AND_ASSIGN(auto result, + compiler->Compile("1 + (2 + (3 + (4 + 5)))")); + + Validator validator; + validator.AddValidation(AstDepthValidator(10)); + auto output = validator.Validate(*result.GetAst()); + EXPECT_TRUE(output.valid); + + Validator validator2; + validator2.AddValidation(AstDepthValidator(4)); + output = validator2.Validate(*result.GetAst()); + EXPECT_FALSE(output.valid); + EXPECT_THAT(output.issues, + testing::Contains(testing::Property( + &TypeCheckIssue::message, + testing::Eq("AST depth 5 exceeds maximum of 4")))); +} + +} // namespace +} // namespace cel diff --git a/validator/comprehension_nesting_validator.cc b/validator/comprehension_nesting_validator.cc new file mode 100644 index 000000000..81c47cbc3 --- /dev/null +++ b/validator/comprehension_nesting_validator.cc @@ -0,0 +1,72 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "validator/comprehension_nesting_validator.h" + +#include "absl/log/absl_check.h" +#include "absl/strings/str_cat.h" +#include "common/expr.h" +#include "common/navigable_ast.h" +#include "validator/validator.h" + +namespace cel { + +namespace { + +bool IsEmptyRangeComprehension(const NavigableAstNode& node) { + ABSL_DCHECK(node.expr()->has_comprehension_expr()); + const auto& comp = node.expr()->comprehension_expr(); + return comp.has_iter_range() && comp.iter_range().has_list_expr() && + comp.iter_range().list_expr().elements().empty(); +} + +} // namespace + +Validation ComprehensionNestingLimitValidator(int limit) { + return Validation( + [limit](ValidationContext& context) -> bool { + bool is_valid = true; + for (const auto& node : + context.navigable_ast().Root().DescendantsPostorder()) { + if (node.node_kind() != NodeKind::kComprehension) { + continue; + } + if (IsEmptyRangeComprehension(node)) { + continue; + } + + int count = 0; + const NavigableAstNode* current = &node; + while (current != nullptr) { + if (current->node_kind() == NodeKind::kComprehension && + !IsEmptyRangeComprehension(*current)) { + count++; + } + current = current->parent(); + } + if (count > limit) { + context.ReportErrorAt( + node.expr()->id(), + absl::StrCat("comprehension nesting level of ", count, + " exceeds limit of ", limit)); + is_valid = false; + break; + } + } + return is_valid; + }, + "cel.validator.comprehension_nesting_limit"); +} + +} // namespace cel diff --git a/validator/comprehension_nesting_validator.h b/validator/comprehension_nesting_validator.h new file mode 100644 index 000000000..4dab78db0 --- /dev/null +++ b/validator/comprehension_nesting_validator.h @@ -0,0 +1,31 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_VALIDATOR_COMPREHENSION_NESTING_VALIDATOR_H_ +#define THIRD_PARTY_CEL_CPP_VALIDATOR_COMPREHENSION_NESTING_VALIDATOR_H_ + +#include "validator/validator.h" + +namespace cel { + +// Returns a `Validation` that checks that comprehensions are not nested beyond +// the specified limit. +// +// Comprehensions with an empty iteration range (e.g. `cel.bind`) do not count +// towards the nesting limit. +Validation ComprehensionNestingLimitValidator(int limit); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_VALIDATOR_COMPREHENSION_NESTING_VALIDATOR_H_ diff --git a/validator/comprehension_nesting_validator_test.cc b/validator/comprehension_nesting_validator_test.cc new file mode 100644 index 000000000..c1b47f82d --- /dev/null +++ b/validator/comprehension_nesting_validator_test.cc @@ -0,0 +1,96 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "validator/comprehension_nesting_validator.h" + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "compiler/standard_library.h" +#include "extensions/bindings_ext.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "validator/validator.h" + +namespace cel { +namespace { + +using ::testing::HasSubstr; + +absl::StatusOr> StdLibCompiler() { + CEL_ASSIGN_OR_RETURN( + auto builder, + NewCompilerBuilder(internal::GetSharedTestingDescriptorPool())); + CEL_RETURN_IF_ERROR(builder->AddLibrary(StandardCompilerLibrary())); + CEL_RETURN_IF_ERROR( + builder->AddLibrary(cel::extensions::BindingsCompilerLibrary())); + return builder->Build(); +} + +struct TestCase { + std::string expression; + int limit; + bool valid; + std::string error_substr = ""; +}; + +using ComprehensionNestingValidatorTest = testing::TestWithParam; + +TEST_P(ComprehensionNestingValidatorTest, Validate) { + const auto& test_case = GetParam(); + Validator validator; + validator.AddValidation(ComprehensionNestingLimitValidator(test_case.limit)); + + ASSERT_OK_AND_ASSIGN(auto compiler, StdLibCompiler()); + auto result_or = compiler->Compile(test_case.expression); + if (!result_or.ok()) { + GTEST_SKIP() << "Expression failed to compile: " << test_case.expression + << " " << result_or.status().message(); + } + auto result = std::move(result_or).value(); + + validator.UpdateValidationResult(result); + + EXPECT_EQ(result.IsValid(), test_case.valid) + << "Expression: " << test_case.expression + << " Limit: " << test_case.limit; + if (!test_case.valid) { + EXPECT_THAT(result.FormatError(), HasSubstr(test_case.error_substr)); + } +} + +INSTANTIATE_TEST_SUITE_P( + ComprehensionNestingValidatorTest, ComprehensionNestingValidatorTest, + testing::Values( + TestCase{"[1, 2].all(x, x > 0)", 1, true}, + TestCase{"[1, 2].all(x, [1, 2].all(y, x > y))", 1, false, + "comprehension nesting level of 2 exceeds limit of 1"}, + TestCase{"[1, 2].all(x, [1, 2].all(y, x > y))", 2, true}, + // Empty range comprehension (does not count) + TestCase{"[].all(x, [1, 2].all(y, y > 0))", 1, true}, + TestCase{"cel.bind(x, [1, 2].all(y, y > 0), [1, 2].all(z, z > 0))", 1, + true}, + // Nested empty range comprehensions + TestCase{"[].all(x, [].all(y, true))", 0, true}, + // Deeply nested mixed + TestCase{"[1].all(x, [].all(y, [2].all(z, true)))", 1, false, + "comprehension nesting level of 2 exceeds limit of 1"}, + TestCase{"[1].all(x, [].all(y, [2].all(z, true)))", 2, true})); + +} // namespace +} // namespace cel diff --git a/validator/homogeneous_literal_validator.cc b/validator/homogeneous_literal_validator.cc new file mode 100644 index 000000000..4a490dea2 --- /dev/null +++ b/validator/homogeneous_literal_validator.cc @@ -0,0 +1,190 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "validator/homogeneous_literal_validator.h" + +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "common/ast.h" +#include "common/expr.h" +#include "common/navigable_ast.h" +#include "validator/validator.h" + +namespace cel { + +namespace { + +bool InExemptFunction(const NavigableAstNode& node, + const std::vector& exempt_functions) { + const NavigableAstNode* parent = node.parent(); + while (parent != nullptr) { + if (parent->node_kind() == NodeKind::kCall) { + absl::string_view fn_name = parent->expr()->call_expr().function(); + for (const auto& exempt : exempt_functions) { + if (exempt == fn_name) { + return true; + } + } + } + parent = parent->parent(); + } + return false; +} + +bool IsOptional(const TypeSpec& t) { + return t.has_abstract_type() && t.abstract_type().name() == "optional_type"; +} + +const TypeSpec& GetOptionalParameter(const TypeSpec& t) { + return t.abstract_type().parameter_types()[0]; +} + +void TypeMismatch(ValidationContext& context, int64_t id, + const TypeSpec& expected, const TypeSpec& actual) { + context.ReportErrorAt( + id, absl::StrCat("expected type '", FormatTypeSpec(expected), + "' but found '", FormatTypeSpec(actual), "'")); +} + +bool TypeEquiv(const TypeSpec& a, const TypeSpec& b) { + if (a == b) { + return true; + } + + if (a.has_error() || b.has_error()) { + // Don't report mismatch if there's an error (type checking failed for the + // expression). + return true; + } + + if (a.has_wrapper() && b.has_primitive()) { + return a.wrapper() == b.primitive(); + } else if (a.has_primitive() && b.has_wrapper()) { + return a.primitive() == b.wrapper(); + } + + if (a.has_list_type() && b.has_list_type()) { + return TypeEquiv(a.list_type().elem_type(), b.list_type().elem_type()); + } + + if (a.has_map_type() && b.has_map_type()) { + return TypeEquiv(a.map_type().key_type(), b.map_type().key_type()) && + TypeEquiv(a.map_type().value_type(), b.map_type().value_type()); + } + + if (a.has_abstract_type() && b.has_abstract_type() && + a.abstract_type().name() == b.abstract_type().name() && + a.abstract_type().parameter_types().size() == + b.abstract_type().parameter_types().size()) { + for (int i = 0; i < a.abstract_type().parameter_types().size(); ++i) { + if (!TypeEquiv(a.abstract_type().parameter_types()[i], + b.abstract_type().parameter_types()[i])) { + return false; + } + } + return true; + } + + return false; +} + +} // namespace + +Validation HomogeneousLiteralValidator( + std::vector exempt_functions) { + return Validation([exempt_functions = std::move(exempt_functions)]( + ValidationContext& context) -> bool { + bool valid = true; + for (const auto& node : + context.navigable_ast().Root().DescendantsPostorder()) { + if (node.node_kind() == NodeKind::kList) { + if (InExemptFunction(node, exempt_functions)) { + continue; + } + const auto& list_expr = node.expr()->list_expr(); + const auto& elements = list_expr.elements(); + const TypeSpec* expected_type = nullptr; + + for (const auto& element : elements) { + int64_t id = element.expr().id(); + const TypeSpec& actual_type = context.ast().GetTypeOrDyn(id); + const TypeSpec* type_to_check = &actual_type; + + if (element.optional() && IsOptional(actual_type)) { + type_to_check = &GetOptionalParameter(actual_type); + } + + if (expected_type == nullptr) { + expected_type = type_to_check; + continue; + } + + if (!(TypeEquiv(*expected_type, *type_to_check))) { + TypeMismatch(context, id, *expected_type, *type_to_check); + valid = false; + break; + } + } + } else if (node.node_kind() == NodeKind::kMap) { + if (InExemptFunction(node, exempt_functions)) { + continue; + } + const auto& map_expr = node.expr()->map_expr(); + const auto& entries = map_expr.entries(); + const TypeSpec* expected_key_type = nullptr; + const TypeSpec* expected_value_type = nullptr; + + for (const auto& entry : entries) { + int64_t key_id = entry.key().id(); + int64_t val_id = entry.value().id(); + const TypeSpec& actual_key_type = context.ast().GetTypeOrDyn(key_id); + const TypeSpec& actual_val_type = context.ast().GetTypeOrDyn(val_id); + const TypeSpec* key_type_to_check = &actual_key_type; + const TypeSpec* val_type_to_check = &actual_val_type; + + if (entry.optional() && IsOptional(actual_val_type)) { + val_type_to_check = &GetOptionalParameter(actual_val_type); + } + + if (expected_key_type == nullptr) { + expected_key_type = key_type_to_check; + expected_value_type = val_type_to_check; + continue; + } + + if (!(TypeEquiv(*expected_key_type, *key_type_to_check))) { + TypeMismatch(context, key_id, *expected_key_type, + *key_type_to_check); + valid = false; + break; + } + if (!(TypeEquiv(*expected_value_type, *val_type_to_check))) { + TypeMismatch(context, val_id, *expected_value_type, + *val_type_to_check); + valid = false; + break; + } + } + } + } + return valid; + }); +} + +} // namespace cel diff --git a/validator/homogeneous_literal_validator.h b/validator/homogeneous_literal_validator.h new file mode 100644 index 000000000..e37648a25 --- /dev/null +++ b/validator/homogeneous_literal_validator.h @@ -0,0 +1,38 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_VALIDATOR_HOMOGENEOUS_LITERAL_VALIDATOR_H_ +#define THIRD_PARTY_CEL_CPP_VALIDATOR_HOMOGENEOUS_LITERAL_VALIDATOR_H_ + +#include +#include + +#include "validator/validator.h" + +namespace cel { + +// Returns a `Validation` that checks that all literals in map or list literals +// are the same type. If the list or map is part of an argument to an exempted +// function, it is not checked. +Validation HomogeneousLiteralValidator( + std::vector exempt_functions); + +inline Validation HomogeneousLiteralValidator() { + // Default to exempting the strings extension "format" function. + return HomogeneousLiteralValidator({"format"}); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_VALIDATOR_HOMOGENEOUS_LITERAL_VALIDATOR_H_ diff --git a/validator/homogeneous_literal_validator_test.cc b/validator/homogeneous_literal_validator_test.cc new file mode 100644 index 000000000..b027fa4b0 --- /dev/null +++ b/validator/homogeneous_literal_validator_test.cc @@ -0,0 +1,145 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "validator/homogeneous_literal_validator.h" + +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "checker/validation_result.h" +#include "common/decl.h" +#include "common/type.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "compiler/optional.h" +#include "compiler/standard_library.h" +#include "extensions/strings.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "validator/validator.h" + +namespace cel { +namespace { + +using ::testing::HasSubstr; + +absl::StatusOr> StdLibCompiler() { + CEL_ASSIGN_OR_RETURN( + auto builder, + NewCompilerBuilder(internal::GetSharedTestingDescriptorPool())); + builder->AddLibrary(StandardCompilerLibrary()).IgnoreError(); + builder->AddLibrary(OptionalCompilerLibrary()).IgnoreError(); + builder->AddLibrary(extensions::StringsCompilerLibrary()).IgnoreError(); + cel::Type message_type = cel::Type::Message( + builder->GetCheckerBuilder().descriptor_pool()->FindMessageTypeByName( + "cel.expr.conformance.proto3.TestAllTypes")); + CEL_RETURN_IF_ERROR(builder->GetCheckerBuilder().AddVariable( + MakeVariableDecl("msg", message_type))); + return builder->Build(); +} + +struct TestCase { + std::string expression; + bool valid; + std::string error_substr = ""; +}; + +using HomogeneousLiteralValidatorTest = testing::TestWithParam; + +TEST_P(HomogeneousLiteralValidatorTest, Validate) { + const auto& test_case = GetParam(); + Validator validator; + validator.AddValidation(HomogeneousLiteralValidator()); + + ASSERT_OK_AND_ASSIGN(auto compiler, StdLibCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, compiler->Compile(test_case.expression)); + validator.UpdateValidationResult(result); + + EXPECT_EQ(result.IsValid(), test_case.valid); + if (!test_case.valid) { + EXPECT_THAT(result.FormatError(), HasSubstr(test_case.error_substr)); + } +} + +INSTANTIATE_TEST_SUITE_P( + HomogeneousLiteralValidatorTest, HomogeneousLiteralValidatorTest, + testing::Values( + // Lists + TestCase{"[1, 2, 3]", true}, TestCase{"['a', 'b', 'c']", true}, + TestCase{"[1, 'a']", false, "expected type 'int' but found 'string'"}, + TestCase{"[1, 2, 'a']", false, + "expected type 'int' but found 'string'"}, + TestCase{"[[1], [2]]", true}, + TestCase{"[[1], ['a']]", false, + "expected type 'list(int)' but found 'list(string)'"}, + + // Dyn casts + TestCase{"[dyn(1), dyn('a')]", true, ""}, + TestCase{"[dyn(1), 2]", false, "expected type 'dyn' but found 'int'"}, + + // Maps + TestCase{"{1: 'a', 2: 'b'}", true}, TestCase{"{'a': 1, 'b': 2}", true}, + TestCase{"{1: 'a', 'b': 2}", false, + "expected type 'int' but found 'string'"}, + TestCase{"{1: 'a', 2: 3}", false, + "expected type 'string' but found 'int'"}, + + // Optionals + TestCase{"[optional.of(1), optional.of(2)]", true}, + TestCase{"[optional.of(1), optional.of('b')]", false, + "expected type 'optional_type(int)' but found " + "'optional_type(string)'"}, + + TestCase{"[?optional.of(1), ?optional.of(2)]", true}, + TestCase{"[?optional.of(1), ?optional.of('a')]", false, + "expected type 'int' but found 'string'"}, + TestCase{"{?1: optional.of('a'), ?2: optional.none()}", true}, + TestCase{"{?1: optional.of('a'), ?2: optional.of(1)}", false, + "expected type 'string' but found 'int'"}, + + // Exempted Functions + TestCase{"'%v %v'.format([1, 'a'])", true}, + + // Mixed Primitives and Wrappers + TestCase{"[1, msg.single_int64_wrapper]", true}, + TestCase{"[msg.single_int64_wrapper, 1]", true}, + TestCase{"['foo', msg.single_string_wrapper]", true}, + TestCase{"[msg.single_string_wrapper, 'foo']", true}, + TestCase{"{1: msg.single_int64_wrapper, 2: 3}", true}, + TestCase{"{1: 2, 2: msg.single_int64_wrapper}", true}, + TestCase{"[[1], [msg.single_int64_wrapper]]", true}, + TestCase{"[optional.of(1), optional.of(msg.single_int64_wrapper)]", + true}, + TestCase{"[1, msg.single_string_wrapper]", false, + "expected type 'int' but found 'wrapper(string)'"}, + TestCase{"[msg.single_int64_wrapper, 'foo']", false, + "expected type 'wrapper(int)' but found 'string'"}, + TestCase{"[msg.single_int64_wrapper, msg.single_string_wrapper]", false, + "expected type 'wrapper(int)' but found 'wrapper(string)'"}, + + // Nested + TestCase{"[1, [2, 'a']]", false, + "expected type 'int' but found 'string'"}, + TestCase{"[[1, 2], [3, 4]]", true, ""}, + TestCase{"[{1: 2}, {'foo': 3}]", false, + "expected type 'map(int, int)' but found 'map(string, int)'"}, + TestCase{"[{1: 2}, {3: 'foo'}]", false, + "expected type 'map(int, int)' but found 'map(int, string)'"}, + TestCase{"[{1: 2}, {3: 4}]", true, ""})); + +} // namespace +} // namespace cel diff --git a/validator/regex_validator.cc b/validator/regex_validator.cc new file mode 100644 index 000000000..df92bfb1e --- /dev/null +++ b/validator/regex_validator.cc @@ -0,0 +1,96 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "validator/regex_validator.h" + +#include +#include + +#include "absl/log/absl_check.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "common/constant.h" +#include "common/expr.h" +#include "common/navigable_ast.h" +#include "internal/re2_options.h" +#include "validator/validator.h" +#include "re2/re2.h" + +namespace cel { + +namespace { + +bool CheckPattern(ValidationContext& context, const NavigableAstNode& node, + int arg_index) { + ABSL_DCHECK(node.expr()->has_call_expr()); + const auto& call_expr = node.expr()->call_expr(); + + const Expr* pattern_expr = nullptr; + + if (call_expr.has_target()) { + if (arg_index == 0) { + pattern_expr = &call_expr.target(); + } else if (call_expr.args().size() > arg_index - 1) { + pattern_expr = &call_expr.args()[arg_index - 1]; + } + } else if (call_expr.args().size() > arg_index) { + pattern_expr = &call_expr.args()[arg_index]; + } + + if (pattern_expr == nullptr || !pattern_expr->has_const_expr()) { + return true; + } + + const auto& const_expr = pattern_expr->const_expr(); + if (!const_expr.has_string_value()) { + return true; + } + + absl::string_view pattern_string = const_expr.string_value(); + RE2 re(pattern_string, internal::MakeRE2Options()); + if (!re.ok()) { + context.ReportErrorAt( + pattern_expr->id(), + absl::StrCat("invalid regular expression: ", re.error())); + return false; + } + return true; +} + +} // namespace + +Validation RegexPatternValidator( + absl::string_view id, std::vector config) { + return Validation( + [config = std::move(config)](ValidationContext& context) -> bool { + bool result = true; + for (const auto& node : + context.navigable_ast().Root().DescendantsPostorder()) { + if (node.node_kind() == NodeKind::kCall) { + for (const auto& config : config) { + if (node.expr()->call_expr().function() == config.function_name) { + if (!CheckPattern(context, node, config.pattern_arg_index)) { + result = false; + } + break; + } + } + } + } + return result; + }, + id); +} + +} // namespace cel diff --git a/validator/regex_validator.h b/validator/regex_validator.h new file mode 100644 index 000000000..15ee1755e --- /dev/null +++ b/validator/regex_validator.h @@ -0,0 +1,53 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_VALIDATOR_REGEX_VALIDATOR_H_ +#define THIRD_PARTY_CEL_CPP_VALIDATOR_REGEX_VALIDATOR_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "common/standard_definitions.h" +#include "validator/validator.h" + +namespace cel { + +// Configuration for the regex pattern validator. +struct RegexPatternValidatorConfig { + // The resolved function name. + std::string function_name; + // the index of the pattern argument (counting the receiver as arg 0 if + // present). + int pattern_arg_index; +}; + +// Returns a `Validation` that checks all calls to the given regex functions +// It validates that the specified argument is a valid regular expression if it +// is a literal string. +Validation RegexPatternValidator( + absl::string_view id, std::vector config); + +// Returns a `Validation` that checks all calls to the CEL `matches` function. +// It validates that if the pattern is a literal string, it is a valid regular +// expression. +inline Validation MatchesValidator() { + return RegexPatternValidator( + "cel.validator.matches", + {{std::string(StandardFunctions::kRegexMatch), 1}}); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_VALIDATOR_REGEX_VALIDATOR_H_ diff --git a/validator/regex_validator_test.cc b/validator/regex_validator_test.cc new file mode 100644 index 000000000..cfab1468d --- /dev/null +++ b/validator/regex_validator_test.cc @@ -0,0 +1,91 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "validator/regex_validator.h" + +#include +#include + +#include "absl/status/statusor.h" +#include "common/decl.h" +#include "common/type.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "compiler/standard_library.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "validator/validator.h" + +namespace cel { +namespace { + +using ::testing::HasSubstr; + +absl::StatusOr> StdLibCompiler() { + CEL_ASSIGN_OR_RETURN( + auto builder, + NewCompilerBuilder(internal::GetSharedTestingDescriptorPool())); + builder->AddLibrary(StandardCompilerLibrary()).IgnoreError(); + CEL_RETURN_IF_ERROR(builder->GetCheckerBuilder().AddVariable( + MakeVariableDecl("p", StringType()))); + return builder->Build(); +} + +struct TestCase { + std::string expression; + bool valid; + std::string error_substr = ""; +}; + +using MatchesValidatorTest = testing::TestWithParam; + +TEST_P(MatchesValidatorTest, Validate) { + const auto& test_case = GetParam(); + Validator validator; + validator.AddValidation(MatchesValidator()); + + ASSERT_OK_AND_ASSIGN(auto compiler, StdLibCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, compiler->Compile(test_case.expression)); + + validator.UpdateValidationResult(result); + + EXPECT_EQ(result.IsValid(), test_case.valid) + << "Expression: " << test_case.expression; + if (!test_case.valid) { + EXPECT_THAT(result.FormatError(), HasSubstr(test_case.error_substr)); + } +} + +INSTANTIATE_TEST_SUITE_P( + MatchesValidatorTest, MatchesValidatorTest, + testing::Values( + // Member calls + TestCase{"'hello'.matches('h.*')", true}, + TestCase{"'hello'.matches('h[')", false, "invalid regular expression"}, + TestCase{"'hello'.matches('h(a|b)')", true}, + TestCase{"'hello'.matches('h(a|b')", false, + "invalid regular expression"}, + // Global calls + TestCase{"matches('hello', 'h.*')", true}, + TestCase{"matches('hello', 'h[')", false, "invalid regular expression"}, + // Non-literal patterns (should not report regex errors) + TestCase{"'hello'.matches(p)", true}, + TestCase{"'hello'.matches('h' + 'ello')", true}, + TestCase{"'hello'.matches(dyn(1))", true}, + + // Empty pattern + TestCase{"'hello'.matches('')", true})); + +} // namespace +} // namespace cel diff --git a/validator/timestamp_literal_validator.cc b/validator/timestamp_literal_validator.cc new file mode 100644 index 000000000..8b9b76ebb --- /dev/null +++ b/validator/timestamp_literal_validator.cc @@ -0,0 +1,134 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "validator/timestamp_literal_validator.h" + +#include "absl/base/no_destructor.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "common/constant.h" +#include "common/navigable_ast.h" +#include "common/standard_definitions.h" +#include "internal/time.h" +#include "tools/navigable_ast.h" +#include "validator/validator.h" + +namespace cel { +namespace { + +bool ValidateTimestamps(ValidationContext& context) { + bool valid = true; + for (const auto& node : + context.navigable_ast().Root().DescendantsPostorder()) { + if (node.node_kind() != NodeKind::kCall || + node.expr()->call_expr().function() != StandardFunctions::kTimestamp) { + continue; + } + if (node.children().size() != 1) { + // Checker should have already reported an error. + continue; + } + const NavigableAstNode& child = *node.children()[0]; + if (child.node_kind() != NodeKind::kConstant) { + // Not a literal, so nothing to do. + continue; + } + absl::Time ts; + const Constant& constant = child.expr()->const_expr(); + if (constant.has_string_value()) { + absl::string_view timestamp_str = + child.expr()->const_expr().string_value(); + if (!absl::ParseTime(absl::RFC3339_full, timestamp_str, &ts, nullptr)) { + context.ReportErrorAt(child.expr()->id(), "invalid timestamp literal"); + valid = false; + continue; + } + } else if (constant.has_int_value()) { + ts = absl::FromUnixSeconds(constant.int_value()); + } else { + // Checker should have already reported an error. + continue; + } + + if (absl::Status status = internal::ValidateTimestamp(ts); !status.ok()) { + context.ReportErrorAt( + child.expr()->id(), + absl::StrCat("invalid timestamp literal: ", status.message())); + valid = false; + } + } + + return valid; +} + +bool ValidateDurations(ValidationContext& context) { + bool valid = true; + for (const auto& node : + context.navigable_ast().Root().DescendantsPostorder()) { + if (node.node_kind() != NodeKind::kCall || + node.expr()->call_expr().function() != StandardFunctions::kDuration) { + continue; + } + if (node.children().size() != 1) { + // Checker should have already reported an error. + continue; + } + const NavigableAstNode& child = *node.children()[0]; + if (child.node_kind() != NodeKind::kConstant) { + // Not a literal, so nothing to do. + continue; + } + const Constant& constant = child.expr()->const_expr(); + if (!constant.has_string_value()) { + continue; + } + absl::Duration duration; + + absl::string_view duration_str = child.expr()->const_expr().string_value(); + if (!absl::ParseDuration(duration_str, &duration)) { + context.ReportErrorAt(child.expr()->id(), "invalid duration literal"); + valid = false; + continue; + } + + if (absl::Status status = internal::ValidateDuration(duration); + !status.ok()) { + context.ReportErrorAt( + child.expr()->id(), + absl::StrCat("invalid duration literal: ", status.message())); + valid = false; + } + } + + return valid; +} + +} // namespace + +const Validation& TimestampLiteralValidator() { + static const absl::NoDestructor kInstance( + ValidateTimestamps, "cel.validator.timestamp"); + return *kInstance; +} + +// Returns a validator that checks duration literals. +const Validation& DurationLiteralValidator() { + static const absl::NoDestructor kInstance( + ValidateDurations, "cel.validator.duration"); + return *kInstance; +} + +} // namespace cel diff --git a/validator/timestamp_literal_validator.h b/validator/timestamp_literal_validator.h new file mode 100644 index 000000000..6d2a39318 --- /dev/null +++ b/validator/timestamp_literal_validator.h @@ -0,0 +1,29 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_VALIDATOR_TIMESTAMP_LITERAL_VALIDATOR_H_ +#define THIRD_PARTY_CEL_CPP_VALIDATOR_TIMESTAMP_LITERAL_VALIDATOR_H_ + +#include "validator/validator.h" +namespace cel { + +// Returns a `Validation` that checks timestamp literals are valid for CEL. +const Validation& TimestampLiteralValidator(); + +// Returns a `Validation` that checks duration literals are valid for CEL. +const Validation& DurationLiteralValidator(); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_VALIDATOR_TIMESTAMP_LITERAL_VALIDATOR_H_ diff --git a/validator/timestamp_literal_validator_test.cc b/validator/timestamp_literal_validator_test.cc new file mode 100644 index 000000000..136f7d645 --- /dev/null +++ b/validator/timestamp_literal_validator_test.cc @@ -0,0 +1,146 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "validator/timestamp_literal_validator.h" + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "checker/validation_result.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "compiler/standard_library.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "validator/validator.h" + +namespace cel { +namespace { + +using ::testing::HasSubstr; + +absl::StatusOr> StdLibCompiler() { + auto builder = + NewCompilerBuilder(internal::GetSharedTestingDescriptorPool()).value(); + builder->AddLibrary(StandardCompilerLibrary()).IgnoreError(); + return builder->Build(); +} + +class TimestampLiteralValidatorTest : public ::testing::Test { + protected: + TimestampLiteralValidatorTest() { + validator_.AddValidation(TimestampLiteralValidator()); + } + + std::unique_ptr compiler_; + Validator validator_; +}; + +TEST(TimestampLiteralValidatorTest, FormatsIssues) { + Validator validator; + validator.AddValidation(TimestampLiteralValidator()); + + ASSERT_OK_AND_ASSIGN(auto compiler, StdLibCompiler()); + ASSERT_OK_AND_ASSIGN(cel::ValidationResult result, + compiler->Compile("timestamp('invalid')")); + + validator.UpdateValidationResult(result); + + EXPECT_FALSE(result.IsValid()); + EXPECT_EQ(result.FormatError(), + R"(ERROR: :1:11: invalid timestamp literal + | timestamp('invalid') + | ..........^)"); +} + +TEST(TimestampLiteralValidatorTest, AccumulatesIssues) { + Validator validator; + validator.AddValidation(TimestampLiteralValidator()); + validator.AddValidation(DurationLiteralValidator()); + + constexpr absl::string_view kExpression = R"cel( + [ timestamp('invalid'), + timestamp('9999-12-31T23:59:59Z'), + timestamp('10000-01-01T00:00:00Z') + ].all(t, + t - timestamp(0) < duration('10000s') && + t - timestamp(0) > duration("invalid") + ))cel"; + ASSERT_OK_AND_ASSIGN(auto compiler, StdLibCompiler()); + ASSERT_OK_AND_ASSIGN(cel::ValidationResult result, + compiler->Compile(kExpression)); + + validator.UpdateValidationResult(result); + + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.FormatError(), + AllOf(HasSubstr("2:17: invalid timestamp literal"), + HasSubstr("4:17: invalid timestamp literal"), + HasSubstr("7:35: invalid duration literal"))); +} + +struct TestCase { + std::string expression; + bool valid; + std::string error_substr = ""; +}; + +using TimestampLiteralValidatorParameterizedTest = + testing::TestWithParam; + +TEST_P(TimestampLiteralValidatorParameterizedTest, Validate) { + const auto& test_case = GetParam(); + Validator validator; + validator.AddValidation(TimestampLiteralValidator()); + validator.AddValidation(DurationLiteralValidator()); + + ASSERT_OK_AND_ASSIGN(auto compiler, StdLibCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, compiler->Compile(test_case.expression)); + validator.UpdateValidationResult(result); + + EXPECT_EQ(result.IsValid(), test_case.valid); + if (!test_case.valid) { + EXPECT_THAT(result.FormatError(), HasSubstr(test_case.error_substr)); + } +} + +INSTANTIATE_TEST_SUITE_P( + TimestampLiteralValidatorParameterizedTest, + TimestampLiteralValidatorParameterizedTest, + ::testing::Values( + TestCase{"timestamp('2023-01-01T00:00:00Z')", true}, + TestCase{"timestamp('9999-12-31T23:59:59Z')", true}, + TestCase{"timestamp('invalid')", false, "invalid timestamp literal"}, + TestCase{"timestamp('10000-01-01T00:00:00Z')", false, + "invalid timestamp literal"}, + TestCase{"timestamp(0)", true}, + TestCase{"timestamp(-62135596801)", false, + "invalid timestamp literal: Timestamp \"0-12-31T23:59:59Z\" " + "below minimum allowed timestamp \"1-01-01T00:00:00Z\""}, + TestCase{"timestamp(253402300800)", false, + "invalid timestamp literal: Timestamp " + "\"10000-01-01T00:00:00Z\" above maximum allowed timestamp " + "\"9999-12-31T23:59:59.999999999Z\""}, + TestCase{"duration('1s')", true}, + TestCase{"duration('invalid')", false, "invalid duration literal"}, + TestCase{"duration('-1000000000000s')", false, + "below minimum allowed duration"}, + TestCase{"duration('1000000000000s')", false, + "above maximum allowed duration"})); + +} // namespace +} // namespace cel diff --git a/validator/validator.cc b/validator/validator.cc new file mode 100644 index 000000000..e000c71e8 --- /dev/null +++ b/validator/validator.cc @@ -0,0 +1,85 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "validator/validator.h" + +#include +#include +#include + +#include "absl/log/absl_check.h" +#include "absl/strings/string_view.h" +#include "checker/type_check_issue.h" +#include "checker/validation_result.h" +#include "common/ast.h" +#include "common/source.h" + +namespace cel { + +void Validator::AddValidation(Validation validation) { + ABSL_DCHECK(validation); + if (!validation) return; + validations_.push_back(std::move(validation)); +} + +Validator::ValidationOutput Validator::Validate(const Ast& ast) const { + ValidationOutput result; + ValidationContext context(ast); + for (const auto& validation : validations_) { + if (!validation(context)) { + result.valid = false; + } + } + result.issues = context.ReleaseIssues(); + return result; +} + +void Validator::UpdateValidationResult(ValidationResult& in) const { + if (!in.IsValid() || in.GetAst() == nullptr) { + // If the result is already decided invalid, just return it. + return; + } + + auto result = Validate(*in.GetAst()); + if (!result.valid) { + in.ReleaseAst().IgnoreError(); + } + for (auto& issue : result.issues) { + in.AddIssue(std::move(issue)); + } +} + +void ValidationContext::ReportWarningAt(int64_t id, absl::string_view message) { + issues_.push_back(TypeCheckIssue(TypeCheckIssue::Severity::kWarning, + ast_.ComputeSourceLocation(id), + std::string(message))); +} + +void ValidationContext::ReportErrorAt(int64_t id, absl::string_view message) { + issues_.push_back(TypeCheckIssue(TypeCheckIssue::Severity::kError, + ast_.ComputeSourceLocation(id), + std::string(message))); +} + +void ValidationContext::ReportWarning(absl::string_view message) { + issues_.push_back(TypeCheckIssue(TypeCheckIssue::Severity::kWarning, + SourceLocation{}, std::string(message))); +} + +void ValidationContext::ReportError(absl::string_view message) { + issues_.push_back(TypeCheckIssue(TypeCheckIssue::Severity::kError, + SourceLocation{}, std::string(message))); +} + +} // namespace cel diff --git a/validator/validator.h b/validator/validator.h new file mode 100644 index 000000000..a278bd44f --- /dev/null +++ b/validator/validator.h @@ -0,0 +1,151 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_VALIDATOR_VALIDATOR_H_ +#define THIRD_PARTY_CEL_CPP_VALIDATOR_VALIDATOR_H_ + +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/functional/any_invocable.h" +#include "absl/log/absl_check.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "checker/type_check_issue.h" +#include "checker/validation_result.h" +#include "common/ast.h" +#include "common/navigable_ast.h" +namespace cel { + +// Context for a validation pass. +// +// Assumed to be scoped to a Validator::Validate() call. Instances must not +// outlive the `ast` passed to the constructor. +class ValidationContext { + public: + explicit ValidationContext(const Ast& ast ABSL_ATTRIBUTE_LIFETIME_BOUND) + : ast_(ast) {} + + const Ast& ast() const { return ast_; } + const NavigableAst& navigable_ast() const { + if (!navigable_ast_) { + navigable_ast_ = NavigableAst::Build(ast_.root_expr()); + } + return navigable_ast_; + } + + void ReportWarningAt(int64_t id, absl::string_view message); + void ReportErrorAt(int64_t id, absl::string_view message); + void ReportWarning(absl::string_view message); + void ReportError(absl::string_view message); + + std::vector ReleaseIssues() { + auto out = std::move(issues_); + issues_.clear(); + return out; + } + + private: + const Ast& ast_; + mutable NavigableAst navigable_ast_; + std::vector issues_; +}; + +// A single validation to apply to an AST. +// +// May be empty if default constructed or moved from. +// use operator bool() to check if the validation is empty. +class Validation { + public: + // Tests the AST reports any issues to the context. + // + // Returns false if the AST is invalid. + // + // The same instance is used across Validate() so must be thread safe + // (typically stateless). + using ImplFunction = + absl::AnyInvocable; + + Validation() = default; + explicit Validation(ImplFunction impl); + Validation(ImplFunction impl, absl::string_view id); + + const ImplFunction& impl() const { + ABSL_DCHECK(rep_ != nullptr); + return rep_->impl; + } + + absl::string_view id() const { + ABSL_DCHECK(rep_ != nullptr); + return rep_->id; + } + + bool operator()(ValidationContext& context) const { + ABSL_DCHECK(rep_ != nullptr); + return rep_->impl(context); + } + + explicit operator bool() const { return rep_ != nullptr; } + + private: + struct Rep { + ImplFunction impl; + // Optional id if supported in environment config. + std::string id; + }; + + std::shared_ptr rep_; +}; + +// A validator checks a set of semantic rules for a given AST. +class Validator { + public: + Validator() = default; + + void AddValidation(Validation validation); + absl::Span validations() const { return validations_; } + + struct ValidationOutput { + bool valid = true; + std::vector issues; + }; + + // Validates the given AST by applying all of the validations. + ValidationOutput Validate(const Ast& ast) const; + + // Validates the given AST, updating the validation result in place. + // + // Used to apply validators to the output of the type checker. + void UpdateValidationResult(ValidationResult& in) const; + + private: + std::vector validations_; +}; + +// Implementation details. +inline Validation::Validation(ImplFunction impl) + : rep_(std::make_shared( + Validation::Rep{std::move(impl)})) {} + +inline Validation::Validation(ImplFunction impl, absl::string_view id) + : rep_(std::make_shared( + Validation::Rep{std::move(impl), std::string(id)})) {} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_VALIDATOR_VALIDATOR_H_ diff --git a/validator/validator_test.cc b/validator/validator_test.cc new file mode 100644 index 000000000..744475ec1 --- /dev/null +++ b/validator/validator_test.cc @@ -0,0 +1,85 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "validator/validator.h" + +#include +#include + +#include "absl/strings/string_view.h" +#include "checker/type_check_issue.h" +#include "common/ast.h" +#include "common/expr.h" +#include "common/source.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +using ::testing::ElementsAre; +using ::testing::Eq; +using ::testing::Property; + +TEST(ValidatorTest, AddValidationAndValidate) { + Validator validator; + validator.AddValidation(Validation([](ValidationContext& context) { + context.ReportError("error 1"); + return false; + })); + validator.AddValidation(Validation([](ValidationContext& context) { + context.ReportWarning("warning 1"); + return true; + })); + + Ast ast; + auto output = validator.Validate(ast); + + EXPECT_FALSE(output.valid); + EXPECT_THAT(output.issues, + ElementsAre(Property(&TypeCheckIssue::message, Eq("error 1")), + Property(&TypeCheckIssue::message, Eq("warning 1")))); + EXPECT_EQ(output.issues[0].severity(), TypeCheckIssue::Severity::kError); + EXPECT_EQ(output.issues[1].severity(), TypeCheckIssue::Severity::kWarning); +} + +TEST(ValidatorTest, ReportAt) { + Validator validator; + validator.AddValidation(Validation([](ValidationContext& context) { + context.ReportErrorAt(1, "error at 1"); + context.ReportWarningAt(2, "warning at 2"); + return false; + })); + + Expr expr; + expr.set_id(1); + SourceInfo source_info; + source_info.mutable_positions()[1] = 10; + source_info.mutable_positions()[2] = 20; + source_info.set_line_offsets({15, 25}); + + Ast ast(std::move(expr), std::move(source_info)); + auto output = validator.Validate(ast); + + EXPECT_FALSE(output.valid); + ASSERT_EQ(output.issues.size(), 2); + + EXPECT_EQ(output.issues[0].location().line, 1); + EXPECT_EQ(output.issues[0].location().column, 10); + + EXPECT_EQ(output.issues[1].location().line, 2); + EXPECT_EQ(output.issues[1].location().column, 5); +} + +} // namespace +} // namespace cel